diff --git a/proxyreader.go b/proxyreader.go index 5a46ea7..0e87c2e 100644 --- a/proxyreader.go +++ b/proxyreader.go @@ -18,11 +18,10 @@ type proxyWriterTo struct { proxyReader - wt io.WriterTo } func (x proxyWriterTo) WriteTo(w io.Writer) (int64, error) { - n, err := x.wt.WriteTo(w) + n, err := x.ReadCloser.(io.WriterTo).WriteTo(w) x.bar.IncrInt64(n) return n, err } @@ -42,12 +41,11 @@ type ewmaProxyWriterTo struct { ewmaProxyReader - wt proxyWriterTo } func (x ewmaProxyWriterTo) WriteTo(w io.Writer) (int64, error) { start := time.Now() - n, err := x.wt.WriteTo(w) + n, err := x.ReadCloser.(io.WriterTo).WriteTo(w) if n > 0 { x.bar.DecoratorEwmaUpdate(time.Since(start)) } @@ -58,14 +56,13 @@ pr := proxyReader{toReadCloser(r), b} if hasEwma { epr := ewmaProxyReader{pr} - if wt, ok := r.(io.WriterTo); ok { - pwt := proxyWriterTo{pr, wt} - return ewmaProxyWriterTo{epr, pwt} + if _, ok := r.(io.WriterTo); ok { + return ewmaProxyWriterTo{epr} } return epr } - if wt, ok := r.(io.WriterTo); ok { - return proxyWriterTo{pr, wt} + if _, ok := r.(io.WriterTo); ok { + return proxyWriterTo{pr} } return pr } diff --git a/proxyreader_test.go b/proxyreader_test.go index 0f7f058..eaa972d 100644 --- a/proxyreader_test.go +++ b/proxyreader_test.go @@ -29,30 +29,30 @@ type testWriterTo struct { *testReader - wt io.WriterTo + called bool } -func (wt testWriterTo) WriteTo(w io.Writer) (n int64, err error) { +func (wt *testWriterTo) WriteTo(w io.Writer) (n int64, err error) { wt.called = true - return wt.wt.WriteTo(w) + return wt.Reader.(io.WriterTo).WriteTo(w) } func TestProxyReader(t *testing.T) { p := mpb.New(mpb.WithOutput(io.Discard)) - tReader := &testReader{strings.NewReader(content), false} + reader := &testReader{strings.NewReader(content), false} bar := p.AddBar(int64(len(content))) var buf bytes.Buffer - _, err := io.Copy(&buf, bar.ProxyReader(tReader)) + _, err := io.Copy(&buf, bar.ProxyReader(reader)) if err != nil { t.Errorf("Error copying from reader: %+v\n", err) } p.Wait() - if !tReader.called { + if !reader.called { t.Error("Read not called") } @@ -65,19 +65,19 @@ p := mpb.New(mpb.WithOutput(io.Discard)) var reader io.Reader = strings.NewReader(content) - tWriterTo := testWriterTo{&testReader{reader, false}, reader.(io.WriterTo)} + writerTo := &testWriterTo{&testReader{reader, false}, false} bar := p.AddBar(int64(len(content))) var buf bytes.Buffer - _, err := io.Copy(&buf, bar.ProxyReader(tWriterTo)) + _, err := io.Copy(&buf, bar.ProxyReader(writerTo)) if err != nil { t.Errorf("Error copying from reader: %+v\n", err) } p.Wait() - if !tWriterTo.called { + if !writerTo.called { t.Error("WriteTo not called") }