Skip to content

Commit

Permalink
fix: bugs in refactored task queue and improved coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
garethgeorge committed Apr 11, 2024
1 parent eb07230 commit 834b74f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 20 deletions.
2 changes: 1 addition & 1 deletion internal/queue/timepriorityqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (t *TimePriorityQueue[T]) Enqueue(at time.Time, priority int, v T) {
func (t *TimePriorityQueue[T]) Dequeue(ctx context.Context) T {
t.mu.Lock()
for {
for t.tqueue.Len() > 0 {
for t.tqueue.heap.Len() > 0 {
thead := t.tqueue.Peek() // peek at the head of the time queue
if thead.at.Before(time.Now()) {
tqe := heap.Pop(&t.tqueue.heap).(timeQueueEntry[priorityEntry[T]])
Expand Down
31 changes: 31 additions & 0 deletions internal/queue/timepriorityqueue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package queue

import (
"context"
"math/rand"
"testing"
"time"
)
Expand Down Expand Up @@ -53,3 +54,33 @@ func TestTPQMixedReadinessStates(t *testing.T) {
}
}
}

func TestTPQStress(t *testing.T) {
tpq := NewTimePriorityQueue[int]()
start := time.Now()

totalEnqueued := 0
totalEnqueuedSum := 0

go func() {
ctx, _ := context.WithDeadline(context.Background(), start.Add(1*time.Second))
for ctx.Err() == nil {
v := rand.Intn(100)
tpq.Enqueue(time.Now().Add(time.Duration(rand.Intn(1000)-500)*time.Millisecond), rand.Intn(5), v)
totalEnqueuedSum += v
totalEnqueued++
}
}()

ctx, _ := context.WithDeadline(context.Background(), start.Add(3*time.Second))
totalDequeued := 0
sum := 0
for ctx.Err() == nil || totalDequeued < totalEnqueued {
sum += tpq.Dequeue(ctx)
totalDequeued++
}

if sum != totalEnqueuedSum {
t.Errorf("expected sum to be %d, got %d", totalEnqueuedSum, sum)
}
}
36 changes: 17 additions & 19 deletions internal/queue/timequeue.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"container/heap"
"context"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -13,7 +14,7 @@ type TimeQueue[T any] struct {

dequeueMu sync.Mutex
mu sync.Mutex
notify chan struct{}
notify atomic.Pointer[chan struct{}]
}

func NewTimeQueue[T any]() *TimeQueue[T] {
Expand All @@ -25,10 +26,13 @@ func NewTimeQueue[T any]() *TimeQueue[T] {
func (t *TimeQueue[T]) Enqueue(at time.Time, v T) {
t.mu.Lock()
heap.Push(&t.heap, timeQueueEntry[T]{at, v})
if t.notify != nil {
t.notify <- struct{}{}
}
t.mu.Unlock()
if n := t.notify.Load(); n != nil {
select {
case *n <- struct{}{}:
default:
}
}
}

func (t *TimeQueue[T]) Len() int {
Expand Down Expand Up @@ -63,30 +67,24 @@ func (t *TimeQueue[T]) Dequeue(ctx context.Context) T {
t.dequeueMu.Lock()
defer t.dequeueMu.Unlock()

t.mu.Lock()
t.notify = make(chan struct{}, 1)
defer func() {
t.mu.Lock()
close(t.notify)
t.notify = nil
t.mu.Unlock()
}()
t.mu.Unlock()
notify := make(chan struct{}, 1)
t.notify.Store(&notify)
defer t.notify.Store(nil)

for {
t.mu.Lock()

var wait time.Duration
if t.heap.Len() == 0 {
wait = 3 * time.Minute
} else {
if t.heap.Len() > 0 {
val := t.heap.Peek()
wait = time.Until(val.at)
if wait <= 0 {
t.mu.Unlock()
defer t.mu.Unlock()
return heap.Pop(&t.heap).(timeQueueEntry[T]).v
}
}
if wait == 0 || wait > 3*time.Minute {
wait = 3 * time.Minute
}
t.mu.Unlock()

timer := time.NewTimer(wait)
Expand All @@ -101,7 +99,7 @@ func (t *TimeQueue[T]) Dequeue(ctx context.Context) T {
}
t.mu.Unlock()
return val.v
case <-t.notify: // new task was added, loop again to ensure we have the earliest task.
case <-notify: // new task was added, loop again to ensure we have the earliest task.
if !timer.Stop() {
<-timer.C
}
Expand Down

0 comments on commit 834b74f

Please sign in to comment.