diff --git a/bar.go b/bar.go index 582af1a..a5f514c 100644 --- a/bar.go +++ b/bar.go @@ -144,12 +144,11 @@ } // ProxyReader allows progress tracking against provided io.Reader. -func (b *Bar) ProxyReader(r io.Reader) *Reader { - proxyReader := &Reader{ +func (b *Bar) ProxyReader(r io.Reader) io.Reader { + return &Reader{ Reader: r, bar: b, } - return proxyReader } // ID returs id of the bar. diff --git a/proxyreader_test.go b/proxyreader_test.go index c85fb2e..d1c0862 100644 --- a/proxyreader_test.go +++ b/proxyreader_test.go @@ -4,8 +4,6 @@ "bytes" "io" "io/ioutil" - "net/http" - "net/http/httptest" "strings" "testing" @@ -30,6 +28,10 @@ bar := p.AddBar(100, mpb.BarTrim()) preader := bar.ProxyReader(reader) + if _, ok := preader.(io.Closer); !ok { + t.Error("type assertion to io.Closer is not ok") + } + written, err := io.Copy(ioutil.Discard, preader) if err != nil { t.Errorf("Error copying from reader: %+v\n", err) @@ -40,49 +42,4 @@ if written != int64(total) { t.Errorf("Expected written: %d, got: %d\n", total, written) } - - // underlying reader is not Closer - err = preader.Close() - if err != nil { - t.Errorf("Expected nil error, got: %+v\n", err) - } } - -func TestProxyReaderCloser(t *testing.T) { - var buf bytes.Buffer - p := mpb.New(mpb.WithOutput(&buf)) - - ts := setupTestHttpServer(content) - defer ts.Close() - - url := ts.URL + "/test" - resp, err := http.Get(url) - if err != nil { - t.Errorf("Test server get failure: %s\n", url) - } - - total := resp.ContentLength - bar := p.AddBar(total, mpb.BarTrim()) - reader := bar.ProxyReader(resp.Body) - - // calling reader.Close() will call resp.Body.Close() implicitly - err = reader.Close() - if err != nil { - t.Logf("Error closing resp.Body over reader.Close: %+v\n", err) - t.FailNow() - } - - // reading from closed resp.Body - _, err = io.Copy(ioutil.Discard, reader) - if err == nil { - t.Error("Expected read on closed response body error!") - } -} - -func setupTestHttpServer(content string) *httptest.Server { - mux := http.NewServeMux() - mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, content) - }) - return httptest.NewServer(mux) -}