diff --git a/progress_test.go b/progress_test.go index b81576e..12c5d60 100644 --- a/progress_test.go +++ b/progress_test.go @@ -1,9 +1,12 @@ package mpb_test import ( + "bytes" "context" + "errors" "io" "math/rand" + "strings" "testing" "time" @@ -181,6 +184,44 @@ } } +func TestProgressShutdownsWithErrFiller(t *testing.T) { + var debug bytes.Buffer + shutdown := make(chan struct{}) + p := mpb.New( + mpb.WithShutdownNotifier(shutdown), + mpb.WithOutput(io.Discard), + mpb.WithDebugOutput(&debug), + ) + + testError := errors.New("test error") + bar := p.AddBar(100, + mpb.BarFillerMiddleware(func(base mpb.BarFiller) mpb.BarFiller { + return mpb.BarFillerFunc(func(w io.Writer, st decor.Statistics) error { + if st.Current >= 42 { + return testError + } + return base.Fill(w, st) + }) + }), + ) + + for bar.IsRunning() { + time.Sleep(randomDuration(100 * time.Millisecond)) + bar.Increment() + } + + go p.Wait() + + select { + case <-shutdown: + if err := strings.TrimSpace(debug.String()); err != testError.Error() { + t.Errorf("Expected err: %q, got %q\n", testError.Error(), err) + } + case <-time.After(timeout): + t.Errorf("Progress didn't shutdown after %v", timeout) + } +} + func randomDuration(max time.Duration) time.Duration { return time.Duration(rand.Intn(10)+1) * max / 10 }