package mpb_test
import (
"bytes"
"context"
"errors"
"io"
"math/rand"
"strings"
"testing"
"time"
"github.com/vbauerster/mpb/v8"
"github.com/vbauerster/mpb/v8/decor"
)
const (
timeout = 300 * time.Millisecond
)
func init() {
rand.Seed(time.Now().UnixNano())
}
func TestBarCount(t *testing.T) {
p := mpb.New(mpb.WithOutput(io.Discard))
b := p.AddBar(0, mpb.BarRemoveOnComplete())
if count := p.BarCount(); count != 1 {
t.Errorf("BarCount want: %d, got: %d\n", 1, count)
}
b.SetTotal(100, true)
b.Wait()
if count := p.BarCount(); count != 0 {
t.Errorf("BarCount want: %d, got: %d\n", 0, count)
}
p.Wait()
}
func TestBarAbort(t *testing.T) {
shutdown := make(chan struct{})
p := mpb.New(mpb.WithShutdownNotifier(shutdown), mpb.WithOutput(io.Discard))
n := 2
bars := make([]*mpb.Bar, n)
for i := 0; i < n; i++ {
b := p.AddBar(100)
switch i {
case n - 1:
var abortCalledTimes int
for j := 0; !b.Aborted(); j++ {
if j >= 10 {
b.Abort(true)
abortCalledTimes++
} else {
b.Increment()
}
}
if abortCalledTimes != 1 {
t.Errorf("Expected abortCalledTimes: %d, got: %d\n", 1, abortCalledTimes)
}
b.Wait()
count := p.BarCount()
if count != 1 {
t.Errorf("BarCount want: %d, got: %d\n", 1, count)
}
default:
go func() {
for !b.Completed() {
b.Increment()
time.Sleep(randomDuration(100 * time.Millisecond))
}
}()
}
bars[i] = b
}
go p.Wait()
bars[0].Abort(false)
select {
case <-shutdown:
case <-time.After(timeout):
t.Errorf("Progress didn't shutdown after %v", timeout)
}
}
func TestWithContext(t *testing.T) {
shutdown := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
p := mpb.NewWithContext(ctx,
mpb.WithShutdownNotifier(shutdown),
mpb.WithOutput(io.Discard),
)
_ = p.AddBar(0) // never complete bar
_ = p.AddBar(0) // never complete bar
go func() {
time.Sleep(randomDuration(100 * time.Millisecond))
cancel()
p.Wait()
}()
select {
case <-shutdown:
case <-time.After(timeout):
t.Errorf("Progress didn't shutdown after %v", timeout)
}
}
func TestProgressShutdownsWithErrFiller(t *testing.T) {
var debug bytes.Buffer
shutdown := make(chan struct{})
p := mpb.New(
mpb.WithShutdownNotifier(shutdown),
mpb.WithOutput(io.Discard),
mpb.WithDebugOutput(&debug),
)
testError := errors.New("test error")
bar := p.AddBar(100,
mpb.BarFillerMiddleware(func(base mpb.BarFiller) mpb.BarFiller {
return mpb.BarFillerFunc(func(w io.Writer, st decor.Statistics) error {
if st.Current >= 42 {
return testError
}
return base.Fill(w, st)
})
}),
)
go func() {
for bar.IsRunning() {
bar.Increment()
}
}()
select {
case <-shutdown:
if err := strings.TrimSpace(debug.String()); err != testError.Error() {
t.Errorf("Expected err: %q, got %q\n", testError.Error(), err)
}
case <-time.After(timeout):
t.Errorf("Progress didn't shutdown after %v", timeout)
}
p.Wait()
}
func randomDuration(max time.Duration) time.Duration {
return time.Duration(rand.Intn(10)+1) * max / 10
}