diff --git a/progress.go b/progress.go index efb8ab3..5902e73 100644 --- a/progress.go +++ b/progress.go @@ -25,12 +25,11 @@ type Progress struct { ctx context.Context uwg *sync.WaitGroup - cwg *sync.WaitGroup bwg *sync.WaitGroup operateState chan func(*pState) interceptIo chan func(io.Writer) done chan struct{} - once sync.Once + shutdown chan struct{} cancel func() } @@ -73,15 +72,14 @@ // method has been called. func NewWithContext(ctx context.Context, options ...ContainerOption) *Progress { s := &pState{ - rows: make([]io.Reader, 0, 64), - pool: make([]*Bar, 0, 64), - refreshRate: defaultRefreshRate, - popPriority: math.MinInt32, - manualRefresh: make(chan interface{}), - shutdownNotifier: make(chan struct{}), - queueBars: make(map[*Bar]*Bar), - output: os.Stdout, - debugOut: io.Discard, + rows: make([]io.Reader, 0, 64), + pool: make([]*Bar, 0, 64), + refreshRate: defaultRefreshRate, + popPriority: math.MinInt32, + manualRefresh: make(chan interface{}), + queueBars: make(map[*Bar]*Bar), + output: os.Stdout, + debugOut: io.Discard, } for _, opt := range options { @@ -94,7 +92,6 @@ p := &Progress{ ctx: ctx, uwg: s.uwg, - cwg: new(sync.WaitGroup), bwg: new(sync.WaitGroup), operateState: make(chan func(*pState)), interceptIo: make(chan func(io.Writer)), @@ -102,7 +99,13 @@ cancel: cancel, } - p.cwg.Add(1) + if s.shutdownNotifier != nil { + p.shutdown = s.shutdownNotifier + s.shutdownNotifier = nil + } else { + p.shutdown = make(chan struct{}) + } + go p.serve(s, cwriter.New(s.output)) return p } @@ -225,12 +228,8 @@ p.uwg.Wait() } - // wait for bars to quit, if any p.bwg.Wait() - // shutdown - p.once.Do(p.shutdown) - // wait for container to quit - p.cwg.Wait() + p.Shutdown() } // Shutdown cancels any running bar immediately and then shutdowns (*Progress) @@ -238,17 +237,42 @@ // are doing. Proper way to shutdown is to call (*Progress).Wait() instead. func (p *Progress) Shutdown() { p.cancel() - p.bwg.Wait() - p.once.Do(p.shutdown) - p.cwg.Wait() -} - -func (p *Progress) shutdown() { - close(p.done) + <-p.shutdown +} + +func (p *Progress) newTicker(s *pState) chan time.Time { + ch := make(chan time.Time) + go func() { + var autoRefresh <-chan time.Time + if !s.disableAutoRefresh && !s.outputDiscarded { + if s.renderDelay != nil { + <-s.renderDelay + } + ticker := time.NewTicker(s.refreshRate) + defer ticker.Stop() + autoRefresh = ticker.C + } + for { + select { + case t := <-autoRefresh: + ch <- t + case x := <-s.manualRefresh: + if t, ok := x.(time.Time); ok { + ch <- t + } else { + ch <- time.Now() + } + case <-p.ctx.Done(): + close(p.done) + return + } + } + }() + return ch } func (p *Progress) serve(s *pState, cw *cwriter.Writer) { - defer p.cwg.Done() + defer close(p.shutdown) render := func() error { if s.bHeap.Len() == 0 { @@ -257,7 +281,7 @@ return s.render(cw) } - refreshCh := s.newTicker(p.done) + refreshCh := p.newTicker(s) for { select { @@ -268,18 +292,12 @@ case <-refreshCh: err := render() if err != nil { - go func() { - p.bwg.Wait() - p.once.Do(p.shutdown) - }() - render = func() error { - s.heapUpdated = false - return nil - } + s.heapUpdated = false + render = func() error { return nil } _, _ = fmt.Fprintln(s.debugOut, err) p.cancel() // cancel all bars } - case <-s.shutdownNotifier: + case <-p.done: for s.heapUpdated { err := render() if err != nil { @@ -323,6 +341,7 @@ b := heap.Pop(&s.bHeap).(*Bar) frame := <-b.frameCh if frame.err != nil { + s.rows = s.rows[:0] return frame.err } var usedRows int @@ -400,37 +419,6 @@ return err } -func (s *pState) newTicker(done <-chan struct{}) chan time.Time { - ch := make(chan time.Time) - go func() { - var autoRefresh <-chan time.Time - if !s.disableAutoRefresh && !s.outputDiscarded { - if s.renderDelay != nil { - <-s.renderDelay - } - ticker := time.NewTicker(s.refreshRate) - defer ticker.Stop() - autoRefresh = ticker.C - } - for { - select { - case t := <-autoRefresh: - ch <- t - case x := <-s.manualRefresh: - if t, ok := x.(time.Time); ok { - ch <- t - } else { - ch <- time.Now() - } - case <-done: - close(s.shutdownNotifier) - return - } - } - }() - return ch -} - func (s *pState) updateSyncMatrix() { s.pMatrix = make(map[int][]chan int) s.aMatrix = make(map[int][]chan int) @@ -484,8 +472,8 @@ } func syncWidth(wg *sync.WaitGroup, matrix map[int][]chan int) { + wg.Add(len(matrix)) for _, column := range matrix { - wg.Add(1) go maxWidthDistributor(wg, column) } }