diff --git a/bar.go b/bar.go index d2ab14f..cf9413e 100644 --- a/bar.go +++ b/bar.go @@ -97,6 +97,22 @@ select { case b.operateState <- func(s *bState) { result <- len(s.ewmaDecorators) != 0 }: return newProxyReader(r, b, <-result) + case <-b.done: + return nil + } +} + +// ProxyWriter wraps io.Writer with metrics required for progress tracking. +// If bar is already completed or aborted, returns nil. +// Panics if `w` is nil. +func (b *Bar) ProxyWriter(w io.Writer) io.WriteCloser { + if w == nil { + panic("expected non nil io.Writer") + } + result := make(chan bool) + select { + case b.operateState <- func(s *bState) { result <- len(s.ewmaDecorators) != 0 }: + return newProxyWriter(w, b, <-result) case <-b.done: return nil } diff --git a/proxywriter.go b/proxywriter.go new file mode 100644 index 0000000..96fe452 --- /dev/null +++ b/proxywriter.go @@ -0,0 +1,98 @@ +package mpb + +import ( + "io" + "time" +) + +type proxyWriter struct { + io.WriteCloser + bar *Bar +} + +func (x proxyWriter) Write(p []byte) (int, error) { + n, err := x.WriteCloser.Write(p) + x.bar.IncrBy(n) + return n, err +} + +type proxyReaderFrom struct { + proxyWriter +} + +func (x proxyReaderFrom) ReadFrom(r io.Reader) (int64, error) { + n, err := x.WriteCloser.(io.ReaderFrom).ReadFrom(r) + x.bar.IncrInt64(n) + return n, err +} + +type ewmaProxyWriter struct { + proxyWriter +} + +func (x ewmaProxyWriter) Write(p []byte) (int, error) { + start := time.Now() + n, err := x.proxyWriter.Write(p) + if n > 0 { + x.bar.DecoratorEwmaUpdate(time.Since(start)) + } + return n, err +} + +type ewmaProxyReaderFrom struct { + ewmaProxyWriter +} + +func (x ewmaProxyReaderFrom) ReadFrom(r io.Reader) (int64, error) { + start := time.Now() + n, err := x.WriteCloser.(io.ReaderFrom).ReadFrom(r) + if n > 0 { + x.bar.DecoratorEwmaUpdate(time.Since(start)) + } + return n, err +} + +func newProxyWriter(w io.Writer, b *Bar, hasEwma bool) io.WriteCloser { + pw := proxyWriter{toWriteCloser(w), b} + if hasEwma { + epw := ewmaProxyWriter{pw} + if _, ok := w.(io.ReaderFrom); ok { + return ewmaProxyReaderFrom{epw} + } + return epw + } + if _, ok := w.(io.ReaderFrom); ok { + return proxyReaderFrom{pw} + } + return pw +} + +func toWriteCloser(w io.Writer) io.WriteCloser { + if wc, ok := w.(io.WriteCloser); ok { + return wc + } + return toNopWriteCloser(w) +} + +func toNopWriteCloser(w io.Writer) io.WriteCloser { + if _, ok := w.(io.ReaderFrom); ok { + return nopWriteCloserReaderFrom{w} + } + return nopWriteCloser{w} +} + +type nopWriteCloser struct { + io.Writer +} + +func (nopWriteCloser) Close() error { return nil } + +type nopWriteCloserReaderFrom struct { + io.Writer +} + +func (nopWriteCloserReaderFrom) Close() error { return nil } + +func (c nopWriteCloserReaderFrom) ReadFrom(r io.Reader) (int64, error) { + return c.Writer.(io.ReaderFrom).ReadFrom(r) +} diff --git a/proxywriter_test.go b/proxywriter_test.go new file mode 100644 index 0000000..aafc6bc --- /dev/null +++ b/proxywriter_test.go @@ -0,0 +1,120 @@ +package mpb_test + +import ( + "bytes" + "io" + "strings" + "testing" + + "github.com/vbauerster/mpb/v8" +) + +type testWriter struct { + io.Writer + called bool +} + +func (w *testWriter) Write(p []byte) (n int, err error) { + w.called = true + return w.Writer.Write(p) +} + +func TestProxyWriter(t *testing.T) { + p := mpb.New(mpb.WithOutput(io.Discard)) + + var buf bytes.Buffer + tw := &testWriter{&buf, false} + + bar := p.AddBar(int64(len(content))) + + _, err := io.Copy(bar.ProxyWriter(tw), strings.NewReader(content)) + if err != nil { + t.Errorf("io.Copy: %s\n", err.Error()) + } + + p.Wait() + + if !tw.called { + t.Error("Read not called") + } + + if got := buf.String(); got != content { + t.Errorf("Expected content: %s, got: %s\n", content, got) + } +} + +type testWriteCloser struct { + io.Writer + called bool +} + +func (w *testWriteCloser) Close() error { + w.called = true + return nil +} + +func TestProxyWriteCloser(t *testing.T) { + p := mpb.New(mpb.WithOutput(io.Discard)) + + var buf bytes.Buffer + tw := &testWriteCloser{&buf, false} + + bar := p.AddBar(int64(len(content))) + + wc := bar.ProxyWriter(tw) + _, err := io.Copy(wc, strings.NewReader(content)) + if err != nil { + t.Errorf("io.Copy: %s\n", err.Error()) + } + _ = wc.Close() + + p.Wait() + + if !tw.called { + t.Error("Close not called") + } +} + +type testWriterReadFrom struct { + io.Writer + called bool +} + +func (w *testWriterReadFrom) ReadFrom(r io.Reader) (n int64, err error) { + w.called = true + return w.Writer.(io.ReaderFrom).ReadFrom(r) +} + +type dumbReader struct { + r *strings.Reader +} + +func (r dumbReader) Read(p []byte) (int, error) { + return r.r.Read(p) +} + +func TestProxyWriterReadFrom(t *testing.T) { + p := mpb.New(mpb.WithOutput(io.Discard)) + + var buf bytes.Buffer + tw := &testWriterReadFrom{&buf, false} + + bar := p.New(int64(len(content)), mpb.NopStyle()) + + // To trigger ReadFrom, WriteTo needs to be hidden, hence a dumb wrapper + dr := dumbReader{strings.NewReader(content)} + _, err := io.Copy(bar.ProxyWriter(tw), dr) + if err != nil { + t.Errorf("io.Copy: %s\n", err.Error()) + } + + p.Wait() + + if !tw.called { + t.Error("ReadFrom not called") + } + + if got := buf.String(); got != content { + t.Errorf("Expected content: %s, got: %s\n", content, got) + } +}