diff --git a/pond.go b/pond.go index e5e6192..90fa70d 100644 --- a/pond.go +++ b/pond.go @@ -353,6 +353,11 @@ func (p *WorkerPool) stop(waitForQueuedTasksToComplete bool) { // Mark pool as stopped atomic.StoreInt32(&p.stopped, 1) + // close tasks channel (only once, in case multiple concurrent calls to StopAndWait are made) + p.tasksCloseOnce.Do(func() { + close(p.tasks) + }) + if waitForQueuedTasksToComplete { // Wait for all queued tasks to complete p.tasksWaitGroup.Wait() @@ -366,11 +371,6 @@ func (p *WorkerPool) stop(waitForQueuedTasksToComplete bool) { // Wait for all workers & purger goroutine to exit p.workersWaitGroup.Wait() - - // close tasks channel (only once, in case multiple concurrent calls to StopAndWait are made) - p.tasksCloseOnce.Do(func() { - close(p.tasks) - }) } // purge represents the work done by the purger goroutine @@ -420,7 +420,7 @@ func (p *WorkerPool) maybeStartWorker(firstTask func()) bool { } // Launch worker goroutine - go worker(p.context, &p.workersWaitGroup, firstTask, p.tasks, p.executeTask) + go worker(p.context, &p.workersWaitGroup, firstTask, p.tasks, p.executeTask, &p.tasksWaitGroup) return true } diff --git a/pond_blackbox_test.go b/pond_blackbox_test.go index ca10b11..e020b0b 100644 --- a/pond_blackbox_test.go +++ b/pond_blackbox_test.go @@ -542,6 +542,47 @@ func TestSubmitWithContext(t *testing.T) { assertEqual(t, int32(0), atomic.LoadInt32(&doneCount)) } +func TestSubmitWithContextCancelWithIdleTasks(t *testing.T) { + + ctx, cancel := context.WithCancel(context.Background()) + + pool := pond.New(1, 5, pond.Context(ctx)) + + var doneCount, taskCount int32 + + // Submit a long-running, cancellable task + pool.Submit(func() { + atomic.AddInt32(&taskCount, 1) + select { + case <-ctx.Done(): + return + case <-time.After(10 * time.Minute): + atomic.AddInt32(&doneCount, 1) + return + } + }) + + // Submit a long-running, cancellable task + pool.Submit(func() { + atomic.AddInt32(&taskCount, 1) + select { + case <-ctx.Done(): + return + case <-time.After(10 * time.Minute): + atomic.AddInt32(&doneCount, 1) + return + } + }) + + // Cancel the context + cancel() + + pool.StopAndWait() + + assertEqual(t, int32(1), atomic.LoadInt32(&taskCount)) + assertEqual(t, int32(0), atomic.LoadInt32(&doneCount)) +} + func TestConcurrentStopAndWait(t *testing.T) { pool := pond.New(1, 5) diff --git a/worker.go b/worker.go index 1677c27..02a7288 100644 --- a/worker.go +++ b/worker.go @@ -6,7 +6,7 @@ import ( ) // worker represents a worker goroutine -func worker(context context.Context, waitGroup *sync.WaitGroup, firstTask func(), tasks <-chan func(), taskExecutor func(func(), bool)) { +func worker(context context.Context, waitGroup *sync.WaitGroup, firstTask func(), tasks <-chan func(), taskExecutor func(func(), bool), taskWaitGroup *sync.WaitGroup) { // If provided, execute the first task immediately, before listening to the tasks channel if firstTask != nil { @@ -20,7 +20,10 @@ func worker(context context.Context, waitGroup *sync.WaitGroup, firstTask func() for { select { case <-context.Done(): - // Pool context was cancelled, exit + // Pool context was cancelled, empty tasks channel and exit + for _ = range tasks { + taskWaitGroup.Done() + } return case task, ok := <-tasks: if task == nil || !ok {