diff --git a/bar.go b/bar.go index a61ddab..8cbd625 100644 --- a/bar.go +++ b/bar.go @@ -2,6 +2,7 @@ import ( "bytes" + "context" "fmt" "io" "io/ioutil" @@ -76,11 +77,11 @@ ) func newBar( + ctx context.Context, wg *sync.WaitGroup, filler Filler, id, width int, total int64, - cancel <-chan struct{}, options ...BarOption, ) *Bar { if total <= 0 { @@ -124,7 +125,7 @@ b.priority = b.runningBar.priority } - go b.serve(wg, s, cancel) + go b.serve(ctx, wg, s) return b } @@ -249,8 +250,9 @@ } } -func (b *Bar) serve(wg *sync.WaitGroup, s *bState, cancel <-chan struct{}) { +func (b *Bar) serve(ctx context.Context, wg *sync.WaitGroup, s *bState) { defer wg.Done() + cancel := ctx.Done() for { select { case op := <-b.operateState: diff --git a/examples/cancel/main.go b/examples/cancel/main.go index 001e201..9da8ed3 100644 --- a/examples/cancel/main.go +++ b/examples/cancel/main.go @@ -1,5 +1,3 @@ -//+build go1.7 - package main import ( diff --git a/options.go b/options.go index a477948..a0e5bf8 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,7 @@ package mpb import ( + "context" "io" "sync" "time" @@ -9,7 +10,7 @@ ) // ProgressOption is a function option which changes the default behavior of -// progress pool, if passed to mpb.New(...ProgressOption) +// progress pool, if passed to mpb.New(...ProgressOption). type ProgressOption func(*pState) // WithWaitGroup provides means to have a single joint point. @@ -22,7 +23,7 @@ } } -// WithWidth overrides default width 80 +// WithWidth overrides default width 80. func WithWidth(w int) ProgressOption { return func(s *pState) { if w >= 0 { @@ -31,7 +32,7 @@ } } -// WithRefreshRate overrides default 120ms refresh rate +// WithRefreshRate overrides default 120ms refresh rate. func WithRefreshRate(d time.Duration) ProgressOption { return func(s *pState) { if d < 10*time.Millisecond { @@ -49,11 +50,13 @@ } } -// WithCancel provide your cancel channel, -// which you plan to close at some point. -func WithCancel(ch <-chan struct{}) ProgressOption { +// WithContext provided context will be used for cancellation purposes. +func WithContext(ctx context.Context) ProgressOption { return func(s *pState) { - s.cancel = ch + if ctx == nil { + return + } + s.ctx = ctx } } @@ -64,7 +67,7 @@ } } -// WithOutput overrides default output os.Stdout +// WithOutput overrides default output os.Stdout. func WithOutput(w io.Writer) ProgressOption { return func(s *pState) { if w == nil { diff --git a/options_go1.7.go b/options_go1.7.go deleted file mode 100644 index ca9a5ba..0000000 --- a/options_go1.7.go +++ /dev/null @@ -1,15 +0,0 @@ -//+build go1.7 - -package mpb - -import "context" - -// WithContext provided context will be used for cancellation purposes -func WithContext(ctx context.Context) ProgressOption { - return func(s *pState) { - if ctx == nil { - panic("ctx must not be nil") - } - s.cancel = ctx.Done() - } -} diff --git a/progress.go b/progress.go index 79c0a9b..8b30185 100644 --- a/progress.go +++ b/progress.go @@ -2,6 +2,7 @@ import ( "container/heap" + "context" "fmt" "io" "io/ioutil" @@ -40,10 +41,10 @@ pMatrix map[int][]chan int aMatrix map[int][]chan int - // following are provided by user + // following are provided/overrided by user + ctx context.Context uwg *sync.WaitGroup manualRefreshCh <-chan time.Time - cancel <-chan struct{} shutdownNotifier chan struct{} waitBars map[*Bar]*Bar debugOut io.Writer @@ -55,6 +56,7 @@ pq := make(priorityQueue, 0) heap.Init(&pq) s := &pState{ + ctx: context.Background(), bHeap: &pq, width: pwidth, cw: cwriter.New(os.Stdout), @@ -101,7 +103,7 @@ result := make(chan *Bar) select { case p.operateState <- func(s *pState) { - b := newBar(p.wg, filler, s.idCounter, s.width, total, s.cancel, options...) + b := newBar(s.ctx, p.wg, filler, s.idCounter, s.width, total, options...) if b.runningBar != nil { s.waitBars[b.runningBar] = b } else { diff --git a/progress_go1.7_test.go b/progress_go1.7_test.go deleted file mode 100644 index 6b4adda..0000000 --- a/progress_go1.7_test.go +++ /dev/null @@ -1,50 +0,0 @@ -//+build go1.7 - -package mpb_test - -import ( - "context" - "io/ioutil" - "testing" - "time" - - "github.com/vbauerster/mpb" -) - -func TestWithContext(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - shutdown := make(chan struct{}) - p := mpb.New( - mpb.WithOutput(ioutil.Discard), - mpb.WithContext(ctx), - mpb.WithShutdownNotifier(shutdown), - ) - - total := 1000 - numBars := 3 - bars := make([]*mpb.Bar, 0, numBars) - for i := 0; i < numBars; i++ { - bar := p.AddBar(int64(total)) - bars = append(bars, bar) - go func() { - for !bar.Completed() { - time.Sleep(randomDuration(40 * time.Millisecond)) - bar.Increment() - } - }() - } - - time.AfterFunc(100*time.Millisecond, cancel) - - p.Wait() - for _, bar := range bars { - if bar.Current() >= int64(total) { - t.Errorf("bar %d: total = %d, current = %d\n", bar.ID(), total, bar.Current()) - } - } - select { - case <-shutdown: - case <-time.After(100 * time.Millisecond): - t.Error("Progress didn't stop") - } -} diff --git a/progress_test.go b/progress_test.go index 2331a3f..de34785 100644 --- a/progress_test.go +++ b/progress_test.go @@ -2,6 +2,7 @@ import ( "bytes" + "context" "fmt" "io/ioutil" "math/rand" @@ -9,6 +10,7 @@ "testing" "time" + "github.com/vbauerster/mpb" . "github.com/vbauerster/mpb" "github.com/vbauerster/mpb/cwriter" ) @@ -80,35 +82,38 @@ p.Wait() } -func TestWithCancel(t *testing.T) { - cancel := make(chan struct{}) +func TestWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) shutdown := make(chan struct{}) - p := New( - WithOutput(ioutil.Discard), - WithCancel(cancel), - WithShutdownNotifier(shutdown), + p := mpb.New( + mpb.WithOutput(ioutil.Discard), + mpb.WithContext(ctx), + mpb.WithRefreshRate(50*time.Millisecond), + mpb.WithShutdownNotifier(shutdown), ) - for i := 0; i < 2; i++ { - bar := p.AddBar(int64(1000), BarID(i)) + total := 10000 + numBars := 3 + bars := make([]*mpb.Bar, 0, numBars) + for i := 0; i < numBars; i++ { + bar := p.AddBar(int64(total)) + bars = append(bars, bar) go func() { for !bar.Completed() { + bar.Increment() time.Sleep(randomDuration(100 * time.Millisecond)) - bar.Increment() } }() } - time.AfterFunc(100*time.Millisecond, func() { - close(cancel) - }) + time.Sleep(50 * time.Millisecond) + cancel() p.Wait() - select { case <-shutdown: - case <-time.After(200 * time.Millisecond): - t.FailNow() + case <-time.After(100 * time.Millisecond): + t.Error("Progress didn't stop") } }