diff --git a/progress_test.go b/progress_test.go index 64e0fa6..bf745ac 100644 --- a/progress_test.go +++ b/progress_test.go @@ -5,7 +5,6 @@ "context" "io/ioutil" "math/rand" - "sync" "testing" "time" @@ -20,59 +19,65 @@ func TestBarCount(t *testing.T) { p := mpb.New(mpb.WithOutput(ioutil.Discard)) - var wg sync.WaitGroup - wg.Add(1) + check := make(chan struct{}) b := p.AddBar(100) go func() { - rng := rand.New(rand.NewSource(time.Now().UnixNano())) for i := 0; i < 100; i++ { - if i == 33 { - wg.Done() + if i == 10 { + close(check) } b.Increment() - time.Sleep((time.Duration(rng.Intn(10)+1) * (10 * time.Millisecond)) / 2) + time.Sleep((time.Duration(rand.Intn(10)+1) * (10 * time.Millisecond)) / 2) } }() - wg.Wait() + <-check count := p.BarCount() if count != 1 { t.Errorf("BarCount want: %q, got: %q\n", 1, count) } - b.Abort(true) + b.Abort(false) p.Wait() } func TestBarAbort(t *testing.T) { + n := 2 p := mpb.New(mpb.WithOutput(ioutil.Discard)) - - var wg sync.WaitGroup - wg.Add(1) - bars := make([]*mpb.Bar, 3) - for i := 0; i < 3; i++ { + bars := make([]*mpb.Bar, n) + for i := 0; i < n; i++ { b := p.AddBar(100) - rng := rand.New(rand.NewSource(time.Now().UnixNano())) - go func(n int) { - for i := 0; !b.Completed(); i++ { - if n == 0 && i >= 33 { + switch i { + case n - 1: + var abortCalledTimes int + for j := 0; !b.Completed(); j++ { + if j >= 33 { b.Abort(true) - wg.Done() + abortCalledTimes++ + } else { + b.Increment() + time.Sleep((time.Duration(rand.Intn(10)+1) * (10 * time.Millisecond)) / 2) } - b.Increment() - time.Sleep((time.Duration(rng.Intn(10)+1) * (10 * time.Millisecond)) / 2) } - }(i) + if abortCalledTimes != 1 { + t.Errorf("Expected abortCalledTimes: %d, got: %d\n", 1, abortCalledTimes) + } + count := p.BarCount() + if count != 1 { + t.Errorf("BarCount want: %d, got: %d\n", 1, count) + } + default: + go func() { + for !b.Completed() { + b.Increment() + time.Sleep((time.Duration(rand.Intn(10)+1) * (10 * time.Millisecond)) / 2) + } + }() + } bars[i] = b } - wg.Wait() - count := p.BarCount() - if count != 2 { - t.Errorf("BarCount want: %d, got: %d\n", 2, count) - } - bars[1].Abort(true) - bars[2].Abort(true) + bars[0].Abort(false) p.Wait() } @@ -140,7 +145,6 @@ ) go func() { <-ready - rng := rand.New(rand.NewSource(time.Now().UnixNano())) for i := 0; i < total; i++ { start := time.Now() if id := bar.ID(); id > 1 && i >= 42 { @@ -150,7 +154,7 @@ bar.Abort(false) } } - time.Sleep((time.Duration(rng.Intn(10)+1) * (50 * time.Millisecond)) / 2) + time.Sleep((time.Duration(rand.Intn(10)+1) * (50 * time.Millisecond)) / 2) bar.IncrInt64(rand.Int63n(5) + 1) bar.DecoratorEwmaUpdate(time.Since(start)) }