Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added context #516

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions chain.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cron

import (
"context"
"fmt"
"runtime"
"sync"
Expand All @@ -24,9 +25,12 @@ func NewChain(c ...JobWrapper) Chain {
// Then decorates the given job with all JobWrappers in the chain.
//
// This:
// NewChain(m1, m2, m3).Then(job)
//
// NewChain(m1, m2, m3).Then(job)
//
// is equivalent to:
// m1(m2(m3(job)))
//
// m1(m2(m3(job)))
func (c Chain) Then(j Job) Job {
for i := range c.wrappers {
j = c.wrappers[len(c.wrappers)-i-1](j)
Expand All @@ -37,7 +41,7 @@ func (c Chain) Then(j Job) Job {
// Recover panics in wrapped jobs and log them with the provided logger.
func Recover(logger Logger) JobWrapper {
return func(j Job) Job {
return FuncJob(func() {
return FuncJob(func(ctx context.Context) {
defer func() {
if r := recover(); r != nil {
const size = 64 << 10
Expand All @@ -50,7 +54,7 @@ func Recover(logger Logger) JobWrapper {
logger.Error(err, "panic", "stack", "...\n"+string(buf))
}
}()
j.Run()
j.Run(ctx)
})
}
}
Expand All @@ -61,14 +65,14 @@ func Recover(logger Logger) JobWrapper {
func DelayIfStillRunning(logger Logger) JobWrapper {
return func(j Job) Job {
var mu sync.Mutex
return FuncJob(func() {
return FuncJob(func(ctx context.Context) {
start := time.Now()
mu.Lock()
defer mu.Unlock()
if dur := time.Since(start); dur > time.Minute {
logger.Info("delay", "duration", dur)
}
j.Run()
j.Run(ctx)
})
}
}
Expand All @@ -79,11 +83,11 @@ func SkipIfStillRunning(logger Logger) JobWrapper {
return func(j Job) Job {
var ch = make(chan struct{}, 1)
ch <- struct{}{}
return FuncJob(func() {
return FuncJob(func(ctx context.Context) {
select {
case v := <-ch:
defer func() { ch <- v }()
j.Run()
j.Run(ctx)
default:
logger.Info("skip")
}
Expand Down
47 changes: 24 additions & 23 deletions chain_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cron

import (
"context"
"io/ioutil"
"log"
"reflect"
Expand All @@ -11,7 +12,7 @@ import (

func appendingJob(slice *[]int, value int) Job {
var m sync.Mutex
return FuncJob(func() {
return FuncJob(func(ctx context.Context) {
m.Lock()
*slice = append(*slice, value)
m.Unlock()
Expand All @@ -20,9 +21,9 @@ func appendingJob(slice *[]int, value int) Job {

func appendingWrapper(slice *[]int, value int) JobWrapper {
return func(j Job) Job {
return FuncJob(func() {
appendingJob(slice, value).Run()
j.Run()
return FuncJob(func(ctx context.Context) {
appendingJob(slice, value).Run(context.Background())
j.Run(context.Background())
})
}
}
Expand All @@ -35,14 +36,14 @@ func TestChain(t *testing.T) {
append3 = appendingWrapper(&nums, 3)
append4 = appendingJob(&nums, 4)
)
NewChain(append1, append2, append3).Then(append4).Run()
NewChain(append1, append2, append3).Then(append4).Run(context.Background())
if !reflect.DeepEqual(nums, []int{1, 2, 3, 4}) {
t.Error("unexpected order of calls:", nums)
}
}

func TestChainRecover(t *testing.T) {
panickingJob := FuncJob(func() {
panickingJob := FuncJob(func(ctx context.Context) {
panic("panickingJob panics")
})

Expand All @@ -53,19 +54,19 @@ func TestChainRecover(t *testing.T) {
}
}()
NewChain().Then(panickingJob).
Run()
Run(context.Background())
})

t.Run("Recovering JobWrapper recovers", func(t *testing.T) {
NewChain(Recover(PrintfLogger(log.New(ioutil.Discard, "", 0)))).
Then(panickingJob).
Run()
Run(context.Background())
})

t.Run("composed with the *IfStillRunning wrappers", func(t *testing.T) {
NewChain(Recover(PrintfLogger(log.New(ioutil.Discard, "", 0)))).
Then(panickingJob).
Run()
Run(context.Background())
})
}

Expand All @@ -76,7 +77,7 @@ type countJob struct {
delay time.Duration
}

func (j *countJob) Run() {
func (j *countJob) Run(ctx context.Context) {
j.m.Lock()
j.started++
j.m.Unlock()
Expand All @@ -103,7 +104,7 @@ func TestChainDelayIfStillRunning(t *testing.T) {
t.Run("runs immediately", func(t *testing.T) {
var j countJob
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
if c := j.Done(); c != 1 {
t.Errorf("expected job run once, immediately, got %d", c)
Expand All @@ -114,9 +115,9 @@ func TestChainDelayIfStillRunning(t *testing.T) {
var j countJob
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(time.Millisecond)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}()
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
if c := j.Done(); c != 2 {
Expand All @@ -129,9 +130,9 @@ func TestChainDelayIfStillRunning(t *testing.T) {
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(time.Millisecond)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}()

// After 5ms, the first job is still in progress, and the second job was
Expand All @@ -157,7 +158,7 @@ func TestChainSkipIfStillRunning(t *testing.T) {
t.Run("runs immediately", func(t *testing.T) {
var j countJob
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
if c := j.Done(); c != 1 {
t.Errorf("expected job run once, immediately, got %d", c)
Expand All @@ -168,9 +169,9 @@ func TestChainSkipIfStillRunning(t *testing.T) {
var j countJob
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(time.Millisecond)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}()
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
if c := j.Done(); c != 2 {
Expand All @@ -183,9 +184,9 @@ func TestChainSkipIfStillRunning(t *testing.T) {
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(time.Millisecond)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}()

// After 5ms, the first job is still in progress, and the second job was
Expand All @@ -209,7 +210,7 @@ func TestChainSkipIfStillRunning(t *testing.T) {
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
for i := 0; i < 11; i++ {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}
time.Sleep(200 * time.Millisecond)
done := j.Done()
Expand All @@ -226,8 +227,8 @@ func TestChainSkipIfStillRunning(t *testing.T) {
wrappedJob1 := chain.Then(&j1)
wrappedJob2 := chain.Then(&j2)
for i := 0; i < 11; i++ {
go wrappedJob1.Run()
go wrappedJob2.Run()
go wrappedJob1.Run(context.Background())
go wrappedJob2.Run(context.Background())
}
time.Sleep(100 * time.Millisecond)
var (
Expand Down
47 changes: 26 additions & 21 deletions cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type ScheduleParser interface {

// Job is an interface for submitted cron jobs.
type Job interface {
Run()
Run(context.Context)
}

// Schedule describes a job's duty cycle.
Expand Down Expand Up @@ -97,17 +97,17 @@ func (s byTime) Less(i, j int) bool {
//
// Available Settings
//
// Time Zone
// Description: The time zone in which schedules are interpreted
// Default: time.Local
// Time Zone
// Description: The time zone in which schedules are interpreted
// Default: time.Local
//
// Parser
// Description: Parser converts cron spec strings into cron.Schedules.
// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron
// Parser
// Description: Parser converts cron spec strings into cron.Schedules.
// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron
//
// Chain
// Description: Wrap submitted jobs to customize behavior.
// Default: A chain that recovers panics and logs them to stderr.
// Chain
// Description: Wrap submitted jobs to customize behavior.
// Default: A chain that recovers panics and logs them to stderr.
//
// See "cron.With*" to modify the default behavior.
func New(opts ...Option) *Cron {
Expand All @@ -131,14 +131,14 @@ func New(opts ...Option) *Cron {
}

// FuncJob is a wrapper that turns a func() into a cron.Job
type FuncJob func()
type FuncJob func(ctx context.Context)

func (f FuncJob) Run() { f() }
func (f FuncJob) Run(ctx context.Context) { f(ctx) }

// AddFunc adds a func to the Cron to be run on the given schedule.
// The spec is parsed using the time zone of this Cron instance as the default.
// An opaque ID is returned that can be used to later remove it.
func (c *Cron) AddFunc(spec string, cmd func()) (EntryID, error) {
func (c *Cron) AddFunc(spec string, cmd func(context.Context)) (EntryID, error) {
return c.AddJob(spec, FuncJob(cmd))
}

Expand Down Expand Up @@ -212,31 +212,31 @@ func (c *Cron) Remove(id EntryID) {
}

// Start the cron scheduler in its own goroutine, or no-op if already started.
func (c *Cron) Start() {
func (c *Cron) Start(ctx context.Context) {
c.runningMu.Lock()
defer c.runningMu.Unlock()
if c.running {
return
}
c.running = true
go c.run()
go c.run(ctx)
}

// Run the cron scheduler, or no-op if already running.
func (c *Cron) Run() {
func (c *Cron) Run(ctx context.Context) {
c.runningMu.Lock()
if c.running {
c.runningMu.Unlock()
return
}
c.running = true
c.runningMu.Unlock()
c.run()
c.run(ctx)
}

// run the scheduler.. this is private just due to the need to synchronize
// access to the 'running' state variable.
func (c *Cron) run() {
func (c *Cron) run(ctx context.Context) {
c.logger.Info("start")

// Figure out the next activation times for each entry.
Expand Down Expand Up @@ -270,7 +270,7 @@ func (c *Cron) run() {
if e.Next.After(now) || e.Next.IsZero() {
break
}
c.startJob(e.WrappedJob)
c.startJob(ctx, e.WrappedJob)
e.Prev = e.Next
e.Next = e.Schedule.Next(now)
c.logger.Info("run", "now", now, "entry", e.ID, "next", e.Next)
Expand All @@ -292,6 +292,11 @@ func (c *Cron) run() {
c.logger.Info("stop")
return

case <-ctx.Done():
timer.Stop()
c.logger.Info("context canceled")
return

case id := <-c.remove:
timer.Stop()
now = c.now()
Expand All @@ -305,11 +310,11 @@ func (c *Cron) run() {
}

// startJob runs the given job in a new goroutine.
func (c *Cron) startJob(j Job) {
func (c *Cron) startJob(ctx context.Context, j Job) {
c.jobWaiter.Add(1)
go func() {
defer c.jobWaiter.Done()
j.Run()
j.Run(ctx)
}()
}

Expand Down
Loading