Skip to content

Commit

Permalink
Update lock timeout calculation on heartbeat (#284)
Browse files Browse the repository at this point in the history
Use the reqTime to calculate the new lock timeout on hearbeat.
This is not perfect as cannot know the exact tick that the server
updates the database, but the assertion in the release lock
validation will be correct because we err on the side of under
calculating the timeout time.
  • Loading branch information
dfarr authored Apr 11, 2024
1 parent b67b3a9 commit 2054353
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 29 deletions.
5 changes: 3 additions & 2 deletions test/dst/dst.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,12 @@ func (d *DST) Run(r *rand.Rand, api api.API, aio aio.AIO, system *system.System,
Metadata: metadata,
Submission: req,
Callback: func(res *t_api.Response, err error) {
modelErr := model.Step(t, req, res, err)
resTime := t
modelErr := model.Step(reqTime, resTime, req, res, err)
if modelErr != nil {
errs = append(errs, modelErr)
}
slog.Info("DST", "t", fmt.Sprintf("%d|%d", reqTime, t), "tid", metadata.TransactionId, "req", req, "res", res, "err", err, "ok", modelErr == nil)
slog.Info("DST", "t", fmt.Sprintf("%d|%d", reqTime, resTime), "tid", metadata.TransactionId, "req", req, "res", res, "err", err, "ok", modelErr == nil)
},
})

Expand Down
46 changes: 23 additions & 23 deletions test/dst/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ type Schedules map[string]*ScheduleModel
type Subscriptions map[string]*SubscriptionModel
type Locks map[string]*LockModel
type Tasks map[string]*TaskModel
type ResponseValidator func(int64, *t_api.Request, *t_api.Response) error
type ResponseValidator func(int64, int64, *t_api.Request, *t_api.Response) error

func (p Promises) Get(id string) *PromiseModel {
if _, ok := p[id]; !ok {
Expand Down Expand Up @@ -135,7 +135,7 @@ func (m *Model) addCursor(next *t_api.Request) {

// Validation

func (m *Model) Step(t int64, req *t_api.Request, res *t_api.Response, err error) error {
func (m *Model) Step(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response, err error) error {
if err != nil {
var resErr *t_api.ResonateError
if !errors.As(err, &resErr) {
Expand All @@ -162,15 +162,15 @@ func (m *Model) Step(t int64, req *t_api.Request, res *t_api.Response, err error
}

if f, ok := m.responses[req.Kind]; ok {
return f(t, req, res)
return f(reqTime, resTime, req, res)
}

return nil
}

// PROMISES

func (m *Model) ValidateReadPromise(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateReadPromise(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
pm := m.promises.Get(req.ReadPromise.Id)

switch res.ReadPromise.Status {
Expand All @@ -192,7 +192,7 @@ func (m *Model) ValidateReadPromise(t int64, req *t_api.Request, res *t_api.Resp
}
}

func (m *Model) ValidateSearchPromises(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateSearchPromises(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
if res.SearchPromises.Cursor != nil {
m.addCursor(&t_api.Request{
Kind: t_api.SearchPromises,
Expand Down Expand Up @@ -239,7 +239,7 @@ func (m *Model) ValidateSearchPromises(t int64, req *t_api.Request, res *t_api.R
}
}

func (m *Model) ValidateCreatePromise(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateCreatePromise(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
pm := m.promises.Get(req.CreatePromise.Id)

switch res.CreatePromise.Status {
Expand Down Expand Up @@ -296,7 +296,7 @@ func (m *Model) ValidateCreatePromise(t int64, req *t_api.Request, res *t_api.Re
}
}

func (m *Model) ValidateCompletePromise(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateCompletePromise(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
pm := m.promises.Get(req.CompletePromise.Id)

switch res.CompletePromise.Status {
Expand Down Expand Up @@ -354,7 +354,7 @@ func (m *Model) ValidateCompletePromise(t int64, req *t_api.Request, res *t_api.

// SCHEDULES

func (m *Model) ValidateReadSchedule(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateReadSchedule(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
sm := m.schedules.Get(req.ReadSchedule.Id)

switch res.ReadSchedule.Status {
Expand Down Expand Up @@ -382,7 +382,7 @@ func (m *Model) ValidateReadSchedule(t int64, req *t_api.Request, res *t_api.Res
}
}

func (m *Model) ValidateSearchSchedules(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateSearchSchedules(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
if res.SearchSchedules.Cursor != nil {
m.addCursor(&t_api.Request{
Kind: t_api.SearchSchedules,
Expand Down Expand Up @@ -430,7 +430,7 @@ func (m *Model) ValidateSearchSchedules(t int64, req *t_api.Request, res *t_api.
}
}

func (m *Model) ValidateCreateSchedule(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateCreateSchedule(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
sm := m.schedules.Get(req.CreateSchedule.Id)

switch res.CreateSchedule.Status {
Expand All @@ -454,7 +454,7 @@ func (m *Model) ValidateCreateSchedule(t int64, req *t_api.Request, res *t_api.R
}
}

func (m *Model) ValidateDeleteSchedule(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateDeleteSchedule(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
sm := m.schedules.Get(req.DeleteSchedule.Id)

switch res.DeleteSchedule.Status {
Expand All @@ -470,7 +470,7 @@ func (m *Model) ValidateDeleteSchedule(t int64, req *t_api.Request, res *t_api.R

// SUBSCRIPTIONS

func (m *Model) ValidateReadSubscriptions(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateReadSubscriptions(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
if res.ReadSubscriptions.Cursor != nil {
m.addCursor(&t_api.Request{
Kind: t_api.ReadSubscriptions,
Expand All @@ -497,7 +497,7 @@ func (m *Model) ValidateReadSubscriptions(t int64, req *t_api.Request, res *t_ap
}
}

func (m *Model) ValidateCreateSubscription(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateCreateSubscription(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
pm := m.promises.Get(req.CreateSubscription.PromiseId)
sm := pm.subscriptions.Get(req.CreateSubscription.Id)

Expand All @@ -513,7 +513,7 @@ func (m *Model) ValidateCreateSubscription(t int64, req *t_api.Request, res *t_a
}
}

func (m *Model) ValidateDeleteSubscription(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateDeleteSubscription(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
pm := m.promises.Get(req.DeleteSubscription.PromiseId)
sm := pm.subscriptions.Get(req.DeleteSubscription.Id)
switch res.DeleteSubscription.Status {
Expand All @@ -530,7 +530,7 @@ func (m *Model) ValidateDeleteSubscription(t int64, req *t_api.Request, res *t_a

// LOCKS

func (m *Model) ValidateAcquireLock(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateAcquireLock(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
lm := m.locks.Get(req.AcquireLock.ResourceId)

switch res.AcquireLock.Status {
Expand All @@ -552,7 +552,7 @@ func (m *Model) ValidateAcquireLock(t int64, req *t_api.Request, res *t_api.Resp
}
}

func (m *Model) ValidateHeartbeatLocks(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateHeartbeatLocks(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
switch res.HeartbeatLocks.Status {
case t_api.StatusOK:
if res.HeartbeatLocks.LocksAffected == 0 {
Expand Down Expand Up @@ -582,7 +582,7 @@ func (m *Model) ValidateHeartbeatLocks(t int64, req *t_api.Request, res *t_api.R
if l.lock.ProcessId == req.HeartbeatLocks.ProcessId {
// update local model for processId's locks
owned := m.locks.Get(l.lock.ResourceId)
owned.lock.ExpiresAt = owned.lock.ExpiresAt + owned.lock.ExpiryInMilliseconds
owned.lock.ExpiresAt = reqTime + owned.lock.ExpiryInMilliseconds
}
}

Expand All @@ -592,7 +592,7 @@ func (m *Model) ValidateHeartbeatLocks(t int64, req *t_api.Request, res *t_api.R
}
}

func (m *Model) ValidateReleaseLock(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateReleaseLock(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
lm := m.locks.Get(req.ReleaseLock.ResourceId)

switch res.ReleaseLock.Status {
Expand All @@ -610,7 +610,7 @@ func (m *Model) ValidateReleaseLock(t int64, req *t_api.Request, res *t_api.Resp
}

// if lock belongs to the same executionId it must have timedout.
if lm.lock.ExpiresAt > t {
if lm.lock.ExpiresAt > resTime {
return fmt.Errorf("executionId %s still has the lock for resourceId %s", req.ReleaseLock.ExecutionId, req.ReleaseLock.ResourceId)
}
lm.lock = nil
Expand All @@ -626,7 +626,7 @@ func (m *Model) ValidateReleaseLock(t int64, req *t_api.Request, res *t_api.Resp

// TASKS

func (m *Model) ValidateClaimTask(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateClaimTask(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
tm := m.tasks.Get(req.ClaimTask.TaskId)

switch res.ClaimTask.Status {
Expand All @@ -649,7 +649,7 @@ func (m *Model) ValidateClaimTask(t int64, req *t_api.Request, res *t_api.Respon
}
return nil
case t_api.StatusTaskAlreadyTimedOut:
if tm.task.PromiseTimeout > t {
if tm.task.PromiseTimeout > resTime {
return fmt.Errorf("task %s has not yet timed out", tm.task)
}
return nil
Expand All @@ -660,7 +660,7 @@ func (m *Model) ValidateClaimTask(t int64, req *t_api.Request, res *t_api.Respon
}
}

func (m *Model) ValidateCompleteTask(t int64, req *t_api.Request, res *t_api.Response) error {
func (m *Model) ValidateCompleteTask(reqTime int64, resTime int64, req *t_api.Request, res *t_api.Response) error {
tm := m.tasks.Get(req.CompleteTask.TaskId)

switch res.CompleteTask.Status {
Expand All @@ -680,7 +680,7 @@ func (m *Model) ValidateCompleteTask(t int64, req *t_api.Request, res *t_api.Res
}
return nil
case t_api.StatusTaskAlreadyTimedOut:
if tm.task.PromiseTimeout > t {
if tm.task.PromiseTimeout > resTime {
return fmt.Errorf("task %s has not yet timed out", tm.task)
}
return nil
Expand Down
8 changes: 4 additions & 4 deletions test/dst/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestModelStep(t *testing.T) {

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
err := tc.m.Step(0, tc.req, tc.res, tc.err)
err := tc.m.Step(0, 0, tc.req, tc.res, tc.err)
if tc.wantErr {
assert.Error(t, err)
} else {
Expand Down Expand Up @@ -111,16 +111,16 @@ func TestModelValidateReadSchedule(t *testing.T) {
},
},
}
err := model.ValidateReadSchedule(0, request, response)
err := model.ValidateReadSchedule(0, 0, request, response)
assert.NoError(t, err)

// Test case 2: Schedule not found
response.ReadSchedule.Status = t_api.StatusScheduleNotFound
err = model.ValidateReadSchedule(0, request, response)
err = model.ValidateReadSchedule(0, 0, request, response)
assert.NoError(t, err) // Expecting no error for schedule not found

// Test case 3: Unexpected response status
response.ReadSchedule.Status = t_api.StatusNoContent
err = model.ValidateReadSchedule(0, request, response)
err = model.ValidateReadSchedule(0, 0, request, response)
assert.Error(t, err) // Expecting error for unexpected response status
}

0 comments on commit 2054353

Please sign in to comment.