diff --git a/proxyreader_test.go b/proxyreader_test.go index eaa972d..e56e82a 100644 --- a/proxyreader_test.go +++ b/proxyreader_test.go @@ -27,16 +27,6 @@ return r.Reader.Read(p) } -type testWriterTo struct { - *testReader - called bool -} - -func (wt *testWriterTo) WriteTo(w io.Writer) (n int64, err error) { - wt.called = true - return wt.Reader.(io.WriterTo).WriteTo(w) -} - func TestProxyReader(t *testing.T) { p := mpb.New(mpb.WithOutput(io.Discard)) @@ -61,11 +51,48 @@ } } +type testReadCloser struct { + io.Reader + called bool +} + +func (r *testReadCloser) Close() error { + r.called = true + return nil +} + +func TestProxyReadCloser(t *testing.T) { + p := mpb.New(mpb.WithOutput(io.Discard)) + + reader := &testReadCloser{strings.NewReader(content), false} + + bar := p.AddBar(int64(len(content))) + + rc := bar.ProxyReader(reader) + _, _ = io.Copy(io.Discard, rc) + _ = rc.Close() + + p.Wait() + + if !reader.called { + t.Error("Close not called") + } +} + +type testWriterTo struct { + io.Reader + called bool +} + +func (wt *testWriterTo) WriteTo(w io.Writer) (n int64, err error) { + wt.called = true + return wt.Reader.(io.WriterTo).WriteTo(w) +} + func TestProxyWriterTo(t *testing.T) { p := mpb.New(mpb.WithOutput(io.Discard)) - var reader io.Reader = strings.NewReader(content) - writerTo := &testWriterTo{&testReader{reader, false}, false} + writerTo := &testWriterTo{strings.NewReader(content), false} bar := p.AddBar(int64(len(content)))