diff --git a/internal/kernel/scheduler/scheduler.go b/internal/kernel/scheduler/scheduler.go index 67c14f67..1914dfd9 100644 --- a/internal/kernel/scheduler/scheduler.go +++ b/internal/kernel/scheduler/scheduler.go @@ -15,17 +15,17 @@ type S interface { } type Scheduler struct { - aio aio.AIO - time int64 - metrics *metrics.Metrics - coroutines []*suspendableCoroutine + aio aio.AIO + time int64 + metrics *metrics.Metrics + runnable []*runnableCoroutine + suspended []*Coroutine[*t_aio.Completion, *t_aio.Submission] } -type suspendableCoroutine struct { +type runnableCoroutine struct { *Coroutine[*t_aio.Completion, *t_aio.Submission] - next *t_aio.Completion - error error - suspended bool + next *t_aio.Completion + error error } func NewScheduler(aio aio.AIO, metrics *metrics.Metrics) *Scheduler { @@ -44,13 +44,12 @@ func (s *Scheduler) Add(coroutine *Coroutine[*t_aio.Completion, *t_aio.Submissio coroutine.Scheduler = s // wrap in suspendable coroutine - s.coroutines = append(s.coroutines, &suspendableCoroutine{ + s.runnable = append(s.runnable, &runnableCoroutine{ Coroutine: coroutine, }) } func (s *Scheduler) Tick(t int64, batchSize int) { - coroutines := []*suspendableCoroutine{} s.time = t // dequeue cqes @@ -59,42 +58,48 @@ func (s *Scheduler) Tick(t int64, batchSize int) { } // enqueue sqes - for _, coroutine := range s.coroutines { - if coroutine.suspended { - continue - } + for _, coroutine := range s.runnable { + coroutine := coroutine // bind to local variable for callback if submission := coroutine.Resume(coroutine.next, coroutine.error); !submission.Done { + // suspend + s.suspended = append(s.suspended, coroutine.Coroutine) + s.aio.Enqueue(&bus.SQE[t_aio.Submission, t_aio.Completion]{ Tags: submission.Value.Kind.String(), Submission: submission.Value, Callback: func(completion *t_aio.Completion, err error) { // unsuspend - coroutine.next = completion - coroutine.error = err - coroutine.suspended = false + s.runnable = append(s.runnable, &runnableCoroutine{ + Coroutine: coroutine.Coroutine, + next: completion, + error: err, + }) + + for i, c := range s.suspended { + if c == coroutine.Coroutine { + s.suspended = append(s.suspended[:i], s.suspended[i+1:]...) + break + } + } }, }) - - // suspend - coroutine.suspended = true - coroutines = append(coroutines, coroutine) } else { + slog.Debug("scheduler:rmv", "coroutine", coroutine.name) + s.metrics.CoroutinesInFlight.WithLabelValues(coroutine.name).Dec() + // call onDone functions for _, f := range coroutine.onDone { f() } - - slog.Debug("scheduler:rmv", "coroutine", coroutine.name) - s.metrics.CoroutinesInFlight.WithLabelValues(coroutine.name).Dec() } } // flush s.aio.Flush(t) - // discard done coroutines - s.coroutines = coroutines + // clear runnable + s.runnable = nil } func (s *Scheduler) Time() int64 { @@ -102,5 +107,5 @@ func (s *Scheduler) Time() int64 { } func (s *Scheduler) Done() bool { - return len(s.coroutines) == 0 + return len(s.runnable) == 0 && len(s.suspended) == 0 }