diff --git a/progress_test.go b/progress_test.go index 7e5db19..87446b4 100644 --- a/progress_test.go +++ b/progress_test.go @@ -5,7 +5,6 @@ "context" "errors" "io" - "math/rand" "strings" "testing" "time" @@ -18,114 +17,48 @@ timeout = 300 * time.Millisecond ) -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func TestBarCount(t *testing.T) { - p := mpb.New(mpb.WithOutput(io.Discard)) - - b := p.AddBar(0, mpb.BarRemoveOnComplete()) - - if count := p.BarCount(); count != 1 { - t.Errorf("BarCount want: %d, got: %d\n", 1, count) - } - - b.SetTotal(100, true) - - b.Wait() - - if count := p.BarCount(); count != 0 { - t.Errorf("BarCount want: %d, got: %d\n", 0, count) - } +func TestWithContext(t *testing.T) { + shutdown := make(chan interface{}) + ctx, cancel := context.WithCancel(context.Background()) + p := mpb.NewWithContext(ctx, + mpb.WithShutdownNotifier(shutdown), + ) + _ = p.AddBar(0) // never complete bar + _ = p.AddBar(0) // never complete bar + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() p.Wait() -} - -func TestBarAbort(t *testing.T) { - shutdown := make(chan struct{}) - p := mpb.New(mpb.WithShutdownNotifier(shutdown), mpb.WithOutput(io.Discard)) - n := 2 - bars := make([]*mpb.Bar, n) - for i := 0; i < n; i++ { - b := p.AddBar(100) - switch i { - case n - 1: - var abortCalledTimes int - for j := 0; !b.Aborted(); j++ { - if j >= 10 { - b.Abort(true) - abortCalledTimes++ - } else { - b.Increment() - } - } - if abortCalledTimes != 1 { - t.Errorf("Expected abortCalledTimes: %d, got: %d\n", 1, abortCalledTimes) - } - b.Wait() - count := p.BarCount() - if count != 1 { - t.Errorf("BarCount want: %d, got: %d\n", 1, count) - } - default: - go func() { - for !b.Completed() { - b.Increment() - time.Sleep(randomDuration(100 * time.Millisecond)) - } - }() - } - bars[i] = b - } - - go p.Wait() - - bars[0].Abort(false) select { - case <-shutdown: + case v := <-shutdown: + if l := len(v.([]*mpb.Bar)); l != 2 { + t.Errorf("Expected len of bars: %d, got: %d", 2, l) + } case <-time.After(timeout): t.Errorf("Progress didn't shutdown after %v", timeout) } } -func TestWithContext(t *testing.T) { - shutdown := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) - p := mpb.NewWithContext(ctx, - mpb.WithShutdownNotifier(shutdown), - mpb.WithOutput(io.Discard), - ) - _ = p.AddBar(0) // never complete bar - _ = p.AddBar(0) // never complete bar - go func() { - time.Sleep(randomDuration(100 * time.Millisecond)) - cancel() - p.Wait() - }() - - select { - case <-shutdown: - case <-time.After(timeout): - t.Errorf("Progress didn't shutdown after %v", timeout) - } -} - -func TestProgressShutdownsWithErrFiller(t *testing.T) { +func TestShutdownsWithErrFiller(t *testing.T) { var debug bytes.Buffer - shutdown := make(chan struct{}) + shutdown := make(chan interface{}) p := mpb.New( mpb.WithShutdownNotifier(shutdown), mpb.WithOutput(io.Discard), mpb.WithDebugOutput(&debug), + mpb.ForceAutoRefresh(), ) + var errReturnCount int testError := errors.New("test error") bar := p.AddBar(100, mpb.BarFillerMiddleware(func(base mpb.BarFiller) mpb.BarFiller { return mpb.BarFillerFunc(func(w io.Writer, st decor.Statistics) error { - if st.Current >= 42 { + if st.Current >= 22 { + errReturnCount++ return testError } return base.Fill(w, st) @@ -136,20 +69,97 @@ go func() { for bar.IsRunning() { bar.Increment() + time.Sleep(10 * time.Millisecond) } }() + p.Wait() + + if errReturnCount != 1 { + t.Errorf("Expected errReturnCount: %d, got: %d\n", 1, errReturnCount) + } + select { - case <-shutdown: + case v := <-shutdown: + if l := len(v.([]*mpb.Bar)); l != 0 { + t.Errorf("Expected len of bars: %d, got: %d\n", 0, l) + } if err := strings.TrimSpace(debug.String()); err != testError.Error() { t.Errorf("Expected err: %q, got %q\n", testError.Error(), err) } case <-time.After(timeout): t.Errorf("Progress didn't shutdown after %v", timeout) } - p.Wait() } -func randomDuration(max time.Duration) time.Duration { - return time.Duration(rand.Intn(10)+1) * max / 10 +func TestShutdownAfterBarAbortWithDrop(t *testing.T) { + shutdown := make(chan interface{}) + p := mpb.New( + mpb.WithShutdownNotifier(shutdown), + mpb.WithOutput(io.Discard), + mpb.ForceAutoRefresh(), + ) + b := p.AddBar(100) + + var count int + for i := 0; !b.Aborted(); i++ { + if i >= 10 { + count++ + b.Abort(true) + } else { + b.Increment() + time.Sleep(10 * time.Millisecond) + } + } + + p.Wait() + + if count != 1 { + t.Errorf("Expected count: %d, got: %d", 1, count) + } + + select { + case v := <-shutdown: + if l := len(v.([]*mpb.Bar)); l != 0 { + t.Errorf("Expected len of bars: %d, got: %d", 0, l) + } + case <-time.After(timeout): + t.Errorf("Progress didn't shutdown after %v", timeout) + } } + +func TestShutdownAfterBarAbortWithNoDrop(t *testing.T) { + shutdown := make(chan interface{}) + p := mpb.New( + mpb.WithShutdownNotifier(shutdown), + mpb.WithOutput(io.Discard), + mpb.ForceAutoRefresh(), + ) + b := p.AddBar(100) + + var count int + for i := 0; !b.Aborted(); i++ { + if i >= 10 { + count++ + b.Abort(false) + } else { + b.Increment() + time.Sleep(10 * time.Millisecond) + } + } + + p.Wait() + + if count != 1 { + t.Errorf("Expected count: %d, got: %d", 1, count) + } + + select { + case v := <-shutdown: + if l := len(v.([]*mpb.Bar)); l != 1 { + t.Errorf("Expected len of bars: %d, got: %d", 1, l) + } + case <-time.After(timeout): + t.Errorf("Progress didn't shutdown after %v", timeout) + } +}