diff --git a/export_test.go b/export_test.go index 7f5cb84..937ce57 100644 --- a/export_test.go +++ b/export_test.go @@ -2,3 +2,5 @@ // make syncWidth func public in test var SyncWidth = syncWidth + +type PriorityQueue = priorityQueue diff --git a/progress_test.go b/progress_test.go index 08658be..f13b821 100644 --- a/progress_test.go +++ b/progress_test.go @@ -2,6 +2,7 @@ import ( "bytes" + "container/heap" "context" "errors" "io" @@ -164,3 +165,44 @@ t.Errorf("Progress didn't shutdown after %v", timeout) } } + +func TestUpdateBarPriority(t *testing.T) { + shutdown := make(chan interface{}) + ctx, cancel := context.WithCancel(context.Background()) + p := mpb.NewWithContext(ctx, + mpb.WithOutput(io.Discard), + mpb.WithShutdownNotifier(shutdown), + ) + a := p.AddBar(100, mpb.BarPriority(1)) + b := p.AddBar(100, mpb.BarPriority(2)) + c := p.AddBar(100, mpb.BarPriority(3)) + + identity := map[*mpb.Bar]string{ + a: "a", + b: "b", + c: "c", + } + + p.UpdateBarPriority(c, 2) + p.UpdateBarPriority(b, 3) + + cancel() + + bars := (<-shutdown).([]*mpb.Bar) + if l := len(bars); l != 3 { + t.Errorf("Expected len of bars: %d, got: %d", 3, l) + } + + p.Wait() + pq := mpb.PriorityQueue(bars) + + if bar := heap.Pop(&pq).(*mpb.Bar); bar != b { + t.Errorf("Expected bar b, got: %s", identity[bar]) + } + if bar := heap.Pop(&pq).(*mpb.Bar); bar != c { + t.Errorf("Expected bar c, got: %s", identity[bar]) + } + if bar := heap.Pop(&pq).(*mpb.Bar); bar != a { + t.Errorf("Expected bar a, got: %s", identity[bar]) + } +}