diff --git a/bar.go b/bar.go index b508428..2523b2f 100644 --- a/bar.go +++ b/bar.go @@ -26,6 +26,7 @@ stateReqCh chan chan state decoratorCh chan *decorator flushedCh chan struct{} + removeReqCh chan struct{} done chan struct{} lastState state @@ -66,6 +67,7 @@ stateReqCh: make(chan chan state), decoratorCh: make(chan *decorator), flushedCh: make(chan struct{}), + removeReqCh: make(chan struct{}), done: make(chan struct{}), } go b.server(ctx, wg, total) @@ -160,12 +162,6 @@ b.stateReqCh <- ch state := <-ch return state.current -} - -func (b *Bar) stop() { - if !b.isDone() { - close(b.done) - } } // InProgress returns true, while progress is running @@ -238,16 +234,22 @@ case state.trimRightSpace = <-b.trimRightCh: case <-b.flushedCh: if completed { - b.lastState = state - b.stop() + b.stop(&state) return } + case <-b.removeReqCh: + b.stop(&state) + return case <-ctx.Done(): - b.lastState = state - b.stop() + b.stop(&state) return } } +} + +func (b *Bar) stop(s *state) { + b.lastState = *s + close(b.done) } func (b *Bar) draw(s state, termWidth int) []byte { @@ -327,6 +329,12 @@ return buf } + +// func (b *Bar) closeDone() { +// if !b.isDone() { +// close(b.done) +// } +// } func (b *Bar) isDone() bool { select { diff --git a/progress.go b/progress.go index 8b0d190..c50c6e0 100644 --- a/progress.go +++ b/progress.go @@ -1,6 +1,7 @@ package mpb import ( + "context" "errors" "io" "os" @@ -35,6 +36,8 @@ // Progress represents the container that renders Progress bars type Progress struct { + // Context for canceling bars rendering + ctx context.Context // WaitGroup for internal rendering sync wg *sync.WaitGroup @@ -42,11 +45,11 @@ width int sort SortType - op chan *operation + operationCh chan *operation rrChangeReqCh chan time.Duration outChangeReqCh chan io.Writer - countReqCh chan chan int - allDone chan struct{} + barCountReqCh chan chan int + done chan struct{} } type operation struct { @@ -55,16 +58,22 @@ result chan bool } -// New returns a new progress bar with defaults -func New() *Progress { +// New creates new Progress instance, which will orchestrate bars rendering +// process. It acceepts context.Context, for cancellation. +// If you don't plan to cancel, it is safe to feed with nil +func New(ctx context.Context) *Progress { + if ctx == nil { + ctx = context.Background() + } p := &Progress{ width: 70, - op: make(chan *operation), + operationCh: make(chan *operation), rrChangeReqCh: make(chan time.Duration), outChangeReqCh: make(chan io.Writer), - countReqCh: make(chan chan int), - allDone: make(chan struct{}), + barCountReqCh: make(chan chan int), + done: make(chan struct{}), wg: new(sync.WaitGroup), + ctx: ctx, } go p.server(cwriter.New(os.Stdout), time.NewTicker(rr*time.Millisecond)) return p @@ -82,7 +91,7 @@ // SetOut sets underlying writer of progress. Default is os.Stdout // pancis, if called on stopped Progress instance, i.e after Stop() func (p *Progress) SetOut(w io.Writer) *Progress { - if p.isAllDone() { + if p.isDone() { panic(ErrCallAfterStop) } if w == nil { @@ -95,12 +104,23 @@ // RefreshRate overrides default (30ms) refreshRate value // pancis, if called on stopped Progress instance, i.e after Stop() func (p *Progress) RefreshRate(d time.Duration) *Progress { - if p.isAllDone() { + if p.isDone() { panic(ErrCallAfterStop) } p.rrChangeReqCh <- d return p } + +// func (p *Progress) WithContext(ctx context.Context) *Progress { +// if p.BarCount() > 0 { +// panic("cannot apply ctx after AddBar has been called") +// } +// if ctx == nil { +// panic("nil context") +// } +// p.ctx = ctx +// return p +// } // WithSort sorts the bars, while redering func (p *Progress) WithSort(sort SortType) *Progress { @@ -110,13 +130,13 @@ // AddBar creates a new progress bar and adds to the container // pancis, if called on stopped Progress instance, i.e after Stop() -func (p *Progress) AddBar(total int) *Bar { - if p.isAllDone() { +func (p *Progress) AddBar(total int64) *Bar { + if p.isDone() { panic(ErrCallAfterStop) } result := make(chan bool) - bar := newBar(total, p.width, p.wg) - p.op <- &operation{opBarAdd, bar, result} + bar := newBar(p.ctx, p.wg, total, p.width) + p.operationCh <- &operation{opBarAdd, bar, result} if <-result { p.wg.Add(1) } @@ -126,31 +146,31 @@ // RemoveBar removes bar at any time // pancis, if called on stopped Progress instance, i.e after Stop() func (p *Progress) RemoveBar(b *Bar) bool { - if p.isAllDone() { + if p.isDone() { panic(ErrCallAfterStop) } result := make(chan bool) - p.op <- &operation{opBarRemove, b, result} + p.operationCh <- &operation{opBarRemove, b, result} return <-result } -// BarsCount returns bars count in the container -// pancis, if called on stopped Progress instance, i.e after Stop() -func (p *Progress) BarsCount() int { - if p.isAllDone() { +// BarCount returns bars count in the container. +// Pancis if called on stopped Progress instance, i.e after Stop() +func (p *Progress) BarCount() int { + if p.isDone() { panic(ErrCallAfterStop) } respCh := make(chan int) - p.countReqCh <- respCh + p.barCountReqCh <- respCh return <-respCh } // Stop waits for bars to finish rendering and stops the rendering goroutine func (p *Progress) Stop() { - if !p.isAllDone() { - close(p.allDone) - p.wg.Wait() - close(p.op) + p.wg.Wait() + if !p.isDone() { + close(p.done) + close(p.operationCh) } } @@ -162,12 +182,9 @@ case w := <-p.outChangeReqCh: cw.Flush() cw = cwriter.New(w) - case op, ok := <-p.op: + case op, ok := <-p.operationCh: if !ok { t.Stop() - for _, b := range bars { - b.Stop() - } return } switch op.kind { @@ -180,13 +197,13 @@ if b == op.bar { bars = append(bars[:i], bars[i+1:]...) ok = true - b.Stop() + b.removeReqCh <- struct{}{} break } } op.result <- ok } - case respCh := <-p.countReqCh: + case respCh := <-p.barCountReqCh: respCh <- len(bars) case <-t.C: width, _ := cwriter.TerminalWidth() @@ -210,13 +227,17 @@ case d := <-p.rrChangeReqCh: t.Stop() t = time.NewTicker(d) + case <-p.ctx.Done(): + t.Stop() + close(p.done) + return } } } -func (p *Progress) isAllDone() bool { +func (p *Progress) isDone() bool { select { - case <-p.allDone: + case <-p.done: return true default: return false diff --git a/progress_test.go b/progress_test.go index fe4b443..3f60ed6 100644 --- a/progress_test.go +++ b/progress_test.go @@ -7,7 +7,7 @@ func TestAddBar(t *testing.T) { var buf bytes.Buffer - p := New().SetWidth(60).SetOut(&buf) + p := New(nil).SetWidth(60).SetOut(&buf) count := p.BarsCount() if count != 0 { t.Errorf("Count want: %q, got: %q\n", 0, count) @@ -24,7 +24,7 @@ } func TestRemoveBar(t *testing.T) { - p := New() + p := New(nil) b := p.AddBar(10) if !p.RemoveBar(b) {