diff --git a/bar.go b/bar.go index a1c3a38..3d31501 100644 --- a/bar.go +++ b/bar.go @@ -141,7 +141,11 @@ if !ok { rc = ioutil.NopCloser(r) } - return &proxyReader{rc, b, time.Now()} + prox := &proxyReader{rc, b, time.Now()} + if wt, ok := r.(io.WriterTo); ok { + return &proxyWriterTo{prox, wt} + } + return prox } // ID returs id of the bar. diff --git a/proxyreader.go b/proxyreader.go index 425e57a..1034fe6 100644 --- a/proxyreader.go +++ b/proxyreader.go @@ -5,24 +5,41 @@ "time" ) -// proxyReader is io.Reader wrapper, for proxy read bytes type proxyReader struct { io.ReadCloser bar *Bar iT time.Time } -func (pr *proxyReader) Read(p []byte) (n int, err error) { - n, err = pr.ReadCloser.Read(p) +func (prox *proxyReader) Read(p []byte) (n int, err error) { + n, err = prox.ReadCloser.Read(p) if n > 0 { - pr.bar.IncrBy(n, time.Since(pr.iT)) - pr.iT = time.Now() + prox.bar.IncrBy(n, time.Since(prox.iT)) + prox.iT = time.Now() } if err == io.EOF { go func() { - current := pr.bar.Current() - pr.bar.SetTotal(current, true) + prox.bar.SetTotal(prox.bar.Current(), true) }() } return } + +type proxyWriterTo struct { + *proxyReader + wt io.WriterTo +} + +func (prox *proxyWriterTo) WriteTo(w io.Writer) (n int64, err error) { + n, err = prox.wt.WriteTo(w) + if n > 0 { + prox.bar.IncrInt64(n, time.Since(prox.iT)) + prox.iT = time.Now() + } + if err == io.EOF { + go func() { + prox.bar.SetTotal(prox.bar.Current(), true) + }() + } + return +} diff --git a/proxyreader_test.go b/proxyreader_test.go index 280e54a..2042e9e 100644 --- a/proxyreader_test.go +++ b/proxyreader_test.go @@ -1,6 +1,7 @@ package mpb_test import ( + "bytes" "io" "io/ioutil" "strings" @@ -28,26 +29,62 @@ } func TestProxyReader(t *testing.T) { - p := mpb.New(mpb.WithOutput(ioutil.Discard)) - reader := &testReader{Reader: strings.NewReader(content)} + tReader := &testReader{strings.NewReader(content), false} - total := len(content) - bar := p.AddBar(100, mpb.TrimSpace()) + bar := p.AddBar(int64(len(content)), mpb.TrimSpace()) - written, err := io.Copy(ioutil.Discard, bar.ProxyReader(reader)) + var buf bytes.Buffer + _, err := io.Copy(&buf, bar.ProxyReader(tReader)) if err != nil { t.Errorf("Error copying from reader: %+v\n", err) } p.Wait() - if !reader.called { + if !tReader.called { t.Error("Read not called") } - if written != int64(total) { - t.Errorf("Expected written: %d, got: %d\n", total, written) + if got := buf.String(); got != content { + t.Errorf("Expected content: %s, got: %s\n", content, got) } } + +type testWriterTo struct { + io.Reader + wt io.WriterTo + called bool +} + +func (wt *testWriterTo) WriteTo(w io.Writer) (n int64, err error) { + wt.called = true + return wt.wt.WriteTo(w) +} + +func TestProxyWriterTo(t *testing.T) { + p := mpb.New(mpb.WithOutput(ioutil.Discard)) + + var reader io.Reader = strings.NewReader(content) + wt := reader.(io.WriterTo) + tReader := &testWriterTo{reader, wt, false} + + bar := p.AddBar(int64(len(content)), mpb.TrimSpace()) + + var buf bytes.Buffer + _, err := io.Copy(&buf, bar.ProxyReader(tReader)) + if err != nil { + t.Errorf("Error copying from reader: %+v\n", err) + } + + p.Wait() + + if !tReader.called { + t.Error("WriteTo not called") + } + + if got := buf.String(); got != content { + t.Errorf("Expected content: %s, got: %s\n", content, got) + } +}