From 144c3167f8e3d1fba42f9f526d80fb902aba519a Mon Sep 17 00:00:00 2001 From: Carolina Calderon Date: Thu, 25 Jan 2024 17:04:49 -0500 Subject: [PATCH] wip --- master/internal/api_checkpoint_intg_test.go | 2 +- master/internal/api_experiment.go | 2 +- master/internal/api_experiment_intg_test.go | 6 +- master/internal/api_tasks_intg_test.go | 18 +- master/internal/api_trials_intg_test.go | 23 +- master/internal/checkpoint_gc.go | 10 +- master/internal/command/command.go | 7 +- master/internal/core.go | 8 +- master/internal/db/database.go | 13 -- master/internal/db/postgres.go | 38 ---- .../internal/db/postgres_cluster_intg_test.go | 10 +- master/internal/db/postgres_task_logs.go | 13 +- master/internal/db/postgres_tasks.go | 214 +++++++----------- .../internal/db/postgres_tasks_intg_test.go | 116 +++++----- master/internal/db/postgres_test_utils.go | 74 +----- master/internal/db/postgres_trial.go | 27 +-- .../internal/db/postgres_trial_intg_test.go | 14 +- master/internal/db/setup.go | 2 +- master/internal/populate_metrics.go | 2 +- master/internal/rm/agentrm/agent.go | 7 +- .../internal/rm/agentrm/agent_state_test.go | 4 +- master/internal/task/allocation.go | 43 ++-- .../internal/task/allocation_service_test.go | 47 ++-- master/internal/telemetry/reports.go | 5 +- master/internal/telemetry/telemetry_test.go | 2 +- master/internal/trial.go | 16 +- .../integration/api/api_checkpoints_test.go | 7 +- .../test/integration/api/api_trials_test.go | 4 +- 28 files changed, 290 insertions(+), 444 deletions(-) diff --git a/master/internal/api_checkpoint_intg_test.go b/master/internal/api_checkpoint_intg_test.go index 3d7a6830dd9..2704aafba9b 100644 --- a/master/internal/api_checkpoint_intg_test.go +++ b/master/internal/api_checkpoint_intg_test.go @@ -42,7 +42,7 @@ func createVersionTwoCheckpoint( ResourcePool: "somethingelse", StartTime: ptrs.Ptr(time.Now().UTC().Truncate(time.Millisecond)), } - require.NoError(t, api.m.db.AddAllocation(aIn)) + require.NoError(t, db.AddAllocation(ctx, aIn)) checkpoint := &model.CheckpointV2{ ID: 0, diff --git a/master/internal/api_experiment.go b/master/internal/api_experiment.go index a4132be2c91..afe04fd88e7 100644 --- a/master/internal/api_experiment.go +++ b/master/internal/api_experiment.go @@ -2737,7 +2737,7 @@ func (a *apiServer) createTrialTx( nil, 0) - if err := a.m.db.AddTask(&model.Task{ + if err := db.AddTask(ctx, &model.Task{ TaskID: taskID, TaskType: model.TaskTypeTrial, StartTime: time.Now(), diff --git a/master/internal/api_experiment_intg_test.go b/master/internal/api_experiment_intg_test.go index 98848d85c84..fab8ebde137 100644 --- a/master/internal/api_experiment_intg_test.go +++ b/master/internal/api_experiment_intg_test.go @@ -209,7 +209,7 @@ func TestGetTaskContextDirectoryExperiment(t *testing.T) { func TestGetTaskContextDirectoryTask(t *testing.T) { api, _, ctx := setupAPITest(t, nil) task := &model.Task{TaskType: model.TaskTypeNotebook, TaskID: model.NewTaskID()} - require.NoError(t, api.m.db.AddTask(task)) + require.NoError(t, db.AddTask(ctx, task)) expectedContextDirectory := []byte("expectedContextDirectory") _, err := db.Bun().NewInsert().Model(&model.TaskContextDirectory{ @@ -539,7 +539,7 @@ func TestGetExperiments(t *testing.T) { require.NoError(t, api.m.db.AddExperiment(exp0, activeConfig0)) for i := 0; i < 3; i++ { task := &model.Task{TaskType: model.TaskTypeTrial, TaskID: model.NewTaskID()} - require.NoError(t, api.m.db.AddTask(task)) + require.NoError(t, db.AddTask(ctx, task)) require.NoError(t, db.AddTrial(ctx, &model.Trial{ State: model.PausedState, ExperimentID: exp0.ID, @@ -791,7 +791,7 @@ func TestSearchExperiments(t *testing.T) { // Trial without validations doesn't cause issues. noValidationsExp := createTestExpWithProjectID(t, api, curUser, projectIDInt) task := &model.Task{TaskType: model.TaskTypeTrial, TaskID: model.NewTaskID()} - require.NoError(t, api.m.db.AddTask(task)) + require.NoError(t, db.AddTask(ctx, task)) require.NoError(t, db.AddTrial(ctx, &model.Trial{ State: model.PausedState, ExperimentID: noValidationsExp.ID, diff --git a/master/internal/api_tasks_intg_test.go b/master/internal/api_tasks_intg_test.go index 9eab886637b..9bcbd12ab63 100644 --- a/master/internal/api_tasks_intg_test.go +++ b/master/internal/api_tasks_intg_test.go @@ -113,10 +113,10 @@ func mockNotebookWithWorkspaceID( TaskID: model.NewTaskID(), TaskType: model.TaskTypeNotebook, } - require.NoError(t, api.m.db.AddTask(nb)) + require.NoError(t, db.AddTask(ctx, nb)) allocationID := model.AllocationID(string(nb.TaskID) + ".1") - require.NoError(t, api.m.db.AddAllocation(&model.Allocation{ + require.NoError(t, db.AddAllocation(ctx, &model.Allocation{ TaskID: nb.TaskID, AllocationID: allocationID, })) @@ -324,7 +324,7 @@ func TestAddAllocationAcceleratorData(t *testing.T) { TaskType: model.TaskTypeTrial, StartTime: time.Now().UTC().Truncate(time.Millisecond), } - require.NoError(t, api.m.db.AddTask(task), "failed to add task") + require.NoError(t, db.AddTask(ctx, task), "failed to add task") aID := tID + "-1" a := &model.Allocation{ @@ -333,7 +333,7 @@ func TestAddAllocationAcceleratorData(t *testing.T) { Slots: 1, ResourcePool: "default", } - require.NoError(t, api.m.db.AddAllocation(a), "failed to add allocation") + require.NoError(t, db.AddAllocation(ctx, a), "failed to add allocation") accData := &model.AcceleratorData{ ContainerID: uuid.NewString(), AllocationID: model.AllocationID(aID), @@ -362,7 +362,7 @@ func TestGetAllocationAcceleratorDataWithNoData(t *testing.T) { TaskType: model.TaskTypeTrial, StartTime: time.Now().UTC().Truncate(time.Millisecond), } - require.NoError(t, api.m.db.AddTask(task), "failed to add task") + require.NoError(t, db.AddTask(ctx, task), "failed to add task") aID := tID + "-1" a := &model.Allocation{ @@ -371,7 +371,7 @@ func TestGetAllocationAcceleratorDataWithNoData(t *testing.T) { Slots: 1, ResourcePool: "default", } - require.NoError(t, api.m.db.AddAllocation(a), "failed to add allocation") + require.NoError(t, db.AddAllocation(ctx, a), "failed to add allocation") resp, err := api.GetTaskAcceleratorData(ctx, &apiv1.GetTaskAcceleratorDataRequest{TaskId: tID.String()}) @@ -390,7 +390,7 @@ func TestGetAllocationAcceleratorData(t *testing.T) { TaskType: model.TaskTypeTrial, StartTime: time.Now().UTC().Truncate(time.Millisecond), } - require.NoError(t, api.m.db.AddTask(task), "failed to add task") + require.NoError(t, db.AddTask(ctx, task), "failed to add task") aID1 := tID + "-1" a1 := &model.Allocation{ @@ -399,7 +399,7 @@ func TestGetAllocationAcceleratorData(t *testing.T) { Slots: 1, ResourcePool: "default", } - require.NoError(t, api.m.db.AddAllocation(a1), "failed to add allocation") + require.NoError(t, db.AddAllocation(ctx, a1), "failed to add allocation") accData := &model.AcceleratorData{ ContainerID: uuid.NewString(), AllocationID: model.AllocationID(aID1), @@ -418,7 +418,7 @@ func TestGetAllocationAcceleratorData(t *testing.T) { Slots: 1, ResourcePool: "default", } - require.NoError(t, api.m.db.AddAllocation(a2), "failed to add allocation") + require.NoError(t, db.AddAllocation(ctx, a2), "failed to add allocation") resp, err := api.GetTaskAcceleratorData(ctx, &apiv1.GetTaskAcceleratorDataRequest{TaskId: tID.String()}) diff --git a/master/internal/api_trials_intg_test.go b/master/internal/api_trials_intg_test.go index de1c367b6d2..95fa0007042 100644 --- a/master/internal/api_trials_intg_test.go +++ b/master/internal/api_trials_intg_test.go @@ -38,6 +38,7 @@ var inferenceMetricGroup = "inference" func createTestTrial( t *testing.T, api *apiServer, curUser model.User, ) (*model.Trial, *model.Task) { + ctx := context.Background() exp := createTestExpWithProjectID(t, api, curUser, 1) task := &model.Task{ @@ -46,17 +47,17 @@ func createTestTrial( StartTime: time.Now(), TaskID: trialTaskID(exp.ID, model.NewRequestID(rand.Reader)), } - require.NoError(t, api.m.db.AddTask(task)) + require.NoError(t, db.AddTask(ctx, task)) trial := &model.Trial{ StartTime: time.Now(), State: model.PausedState, ExperimentID: exp.ID, } - require.NoError(t, db.AddTrial(context.TODO(), trial, task.TaskID)) + require.NoError(t, db.AddTrial(ctx, trial, task.TaskID)) // Return trial exactly the way the API will generally get it. - outTrial, err := db.TrialByID(context.TODO(), trial.ID) + outTrial, err := db.TrialByID(ctx, trial.ID) require.NoError(t, err) return outTrial, task } @@ -751,7 +752,7 @@ func TestTrialProtoTaskIDs(t *testing.T) { StartTime: task0.StartTime.Add(time.Second), TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)), } - require.NoError(t, api.m.db.AddTask(task1)) + require.NoError(t, db.AddTask(ctx, task1)) task2 := &model.Task{ TaskType: model.TaskTypeTrial, @@ -759,7 +760,7 @@ func TestTrialProtoTaskIDs(t *testing.T) { StartTime: task1.StartTime.Add(time.Second), TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)), } - require.NoError(t, api.m.db.AddTask(task2)) + require.NoError(t, db.AddTask(ctx, task2)) _, err = db.Bun().NewInsert().Model(&[]model.TrialTaskID{ {TrialID: trial.ID, TaskID: task1.TaskID}, @@ -834,7 +835,7 @@ func TestExperimentIDFromTrialTaskID(t *testing.T) { StartTime: time.Now(), TaskID: model.TaskID(uuid.New().String()), } - require.NoError(t, api.m.db.AddTask(task)) + require.NoError(t, db.AddTask(context.Background(), task)) _, err = experimentIDFromTrialTaskID(notTrialTask.TaskID) require.ErrorIs(t, err, errIsNotTrialTaskID) @@ -852,7 +853,7 @@ func TestTrialLogsBackported(t *testing.T) { StartTime: time.Now(), TaskID: model.TaskID(fmt.Sprintf("backported.%d", exp.ID)), } - require.NoError(t, api.m.db.AddTask(task)) + require.NoError(t, db.AddTask(ctx, task)) trial := &model.Trial{ StartTime: time.Now(), @@ -889,7 +890,7 @@ func TestTrialLogs(t *testing.T) { StartTime: task0.StartTime.Add(time.Second), TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)), } - require.NoError(t, api.m.db.AddTask(task1)) + require.NoError(t, db.AddTask(ctx, task1)) task2 := &model.Task{ TaskType: model.TaskTypeTrial, @@ -897,7 +898,7 @@ func TestTrialLogs(t *testing.T) { StartTime: task1.StartTime.Add(time.Second), TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)), } - require.NoError(t, api.m.db.AddTask(task2)) + require.NoError(t, db.AddTask(ctx, task2)) _, err := db.Bun().NewInsert().Model(&[]model.TrialTaskID{ {TrialID: trial.ID, TaskID: task1.TaskID}, @@ -985,7 +986,7 @@ func TestTrialLogFields(t *testing.T) { StartTime: task0.StartTime.Add(time.Second), TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)), } - require.NoError(t, api.m.db.AddTask(task1)) + require.NoError(t, db.AddTask(ctx, task1)) task2 := &model.Task{ TaskType: model.TaskTypeTrial, @@ -993,7 +994,7 @@ func TestTrialLogFields(t *testing.T) { StartTime: task1.StartTime.Add(time.Second), TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)), } - require.NoError(t, api.m.db.AddTask(task2)) + require.NoError(t, db.AddTask(ctx, task2)) _, err := db.Bun().NewInsert().Model(&[]model.TrialTaskID{ {TrialID: trial.ID, TaskID: task1.TaskID}, diff --git a/master/internal/checkpoint_gc.go b/master/internal/checkpoint_gc.go index 84f900771dc..94309405d72 100644 --- a/master/internal/checkpoint_gc.go +++ b/master/internal/checkpoint_gc.go @@ -1,6 +1,7 @@ package internal import ( + "context" "fmt" "strings" "time" @@ -26,7 +27,7 @@ const fullDeleteGlob = "**/*" func runCheckpointGCTask( rm rm.ResourceManager, - db *db.PgDB, + pgDB *db.PgDB, taskID model.TaskID, jobID model.JobID, jobSubmissionTime time.Time, @@ -82,7 +83,8 @@ func runCheckpointGCTask( }) syslog := logrus.WithField("component", "checkpointgc").WithFields(logCtx.Fields()) - if err := db.AddTask(&model.Task{ + ctx := context.Background() + if err := db.AddTask(ctx, &model.Task{ TaskID: taskID, TaskType: model.TaskTypeCheckpointGC, StartTime: time.Now().UTC(), @@ -97,7 +99,7 @@ func runCheckpointGCTask( resultChan := make(chan error, 1) onExit := func(ae *task.AllocationExited) { - if err := db.CompleteTask(taskID, time.Now().UTC()); err != nil { + if err := db.CompleteTask(ctx, taskID, time.Now().UTC()); err != nil { syslog.WithError(err).Error("marking GC task complete") } if err := tasklist.GroupPriorityChangeRegistry.Delete(gcJobID); err != nil { @@ -119,7 +121,7 @@ func runCheckpointGCTask( SingleAgent: true, }, ResourcePool: rp, - }, db, rm, gcSpec, onExit) + }, pgDB, rm, gcSpec, onExit) if err != nil { return err } diff --git a/master/internal/command/command.go b/master/internal/command/command.go index e148b15397b..740f14e8c72 100644 --- a/master/internal/command/command.go +++ b/master/internal/command/command.go @@ -230,10 +230,11 @@ func (c *Command) OnExit(ae *task.AllocationExited) { c.exitStatus = ae - if err := c.db.CompleteTask(c.taskID, time.Now().UTC()); err != nil { + ctx := context.Background() + if err := db.CompleteTask(ctx, c.taskID, time.Now().UTC()); err != nil { c.syslog.WithError(err).Error("marking task complete") } - if err := user.DeleteSessionByToken(context.TODO(), c.GenericCommandSpec.Base.UserSessionToken); err != nil { + if err := user.DeleteSessionByToken(ctx, c.GenericCommandSpec.Base.UserSessionToken); err != nil { c.syslog.WithError(err).Errorf( "failure to delete user session for task: %v", c.taskID) } @@ -251,7 +252,7 @@ func (c *Command) garbageCollect() { } if c.exitStatus == nil { - if err := c.db.CompleteTask(c.taskID, time.Now().UTC()); err != nil { + if err := db.CompleteTask(context.Background(), c.taskID, time.Now().UTC()); err != nil { c.syslog.WithError(err).Error("marking task complete") } } diff --git a/master/internal/core.go b/master/internal/core.go index 069bbdd6574..eb637b58076 100644 --- a/master/internal/core.go +++ b/master/internal/core.go @@ -791,9 +791,9 @@ func (m *Master) restoreNonTerminalExperiments() error { return nil } -func (m *Master) closeOpenAllocations() error { +func (m *Master) closeOpenAllocations(ctx context.Context) error { allocationIds := task.DefaultService.GetAllAllocationIDs() - if err := m.db.CloseOpenAllocations(allocationIds); err != nil { + if err := db.CloseOpenAllocations(ctx, allocationIds); err != nil { return err } return nil @@ -1081,11 +1081,11 @@ func (m *Master) Run(ctx context.Context, gRPCLogInitDone chan struct{}) error { return err } - if err = m.closeOpenAllocations(); err != nil { + if err = m.closeOpenAllocations(ctx); err != nil { return err } - if err = m.db.EndAllTaskStats(); err != nil { + if err = db.EndAllTaskStats(ctx); err != nil { return err } diff --git a/master/internal/db/database.go b/master/internal/db/database.go index 3166a3dd6f3..f422e3086b3 100644 --- a/master/internal/db/database.go +++ b/master/internal/db/database.go @@ -43,13 +43,8 @@ type DB interface { id int, experimentBest, trialBest, trialLatest int, ) ([]uuid.UUID, error) - AddTask(t *model.Task) error - UpdateTrial(id int, newState model.State) error UpdateTrialRunnerState(id int, state string) error UpdateTrialRunnerMetadata(id int, md *trialv1.TrialRunnerMetadata) error - AddAllocation(a *model.Allocation) error - CompleteAllocation(a *model.Allocation) error - CompleteAllocationTelemetry(aID model.AllocationID) ([]byte, error) TrialRunIDAndRestarts(trialID int) (int, int, error) UpdateTrialRunID(id, runID int) error UpdateTrialRestarts(id, restarts int) error @@ -89,11 +84,6 @@ type DB interface { trials []*apiv1.TrialsSnapshotResponse_Trial, endTime time.Time, err error) TopTrialsByTrainingLength(experimentID int, maxTrials int, metric string, smallerIsBetter bool) (trials []int32, err error) - StartAllocationSession(allocationID model.AllocationID, owner *model.User) (string, error) - DeleteAllocationSession(allocationID model.AllocationID) error - UpdateAllocationState(allocation model.Allocation) error - UpdateAllocationStartTime(allocation model.Allocation) error - UpdateAllocationProxyAddress(allocation model.Allocation) error ExperimentSnapshot(experimentID int) ([]byte, int, error) SaveSnapshot( experimentID int, version int, experimentSnapshot []byte, @@ -114,9 +104,6 @@ type DB interface { RecordInstanceStats(a *model.InstanceStats) error EndInstanceStats(a *model.InstanceStats) error EndAllInstanceStats() error - EndAllTaskStats() error - RecordTaskEndStats(stats *model.TaskStats) error - RecordTaskStats(stats *model.TaskStats) error UpdateJobPosition(jobID model.JobID, position decimal.Decimal) error } diff --git a/master/internal/db/postgres.go b/master/internal/db/postgres.go index 65980b69f6a..bfcfdb1419d 100644 --- a/master/internal/db/postgres.go +++ b/master/internal/db/postgres.go @@ -211,25 +211,6 @@ func (db *PgDB) Close() error { return db.sql.Close() } -// namedGet is a convenience method for a named query for a single value. -func (db *PgDB) namedGet(dest interface{}, query string, arg interface{}) error { - nstmt, err := db.sql.PrepareNamed(query) - if err != nil { - return errors.Wrapf(err, "error preparing query %s", query) - } - - defer nstmt.Close() - - if sErr := nstmt.QueryRowx(arg).Scan(dest); sErr != nil { - err = errors.Wrapf(sErr, "error scanning query %s", query) - } - if cErr := nstmt.Close(); cErr != nil && err != nil { - err = errors.Wrap(cErr, "error closing named DB statement") - } - - return err -} - // namedExecOne is a convenience method for a NamedExec that should affect only one row. func (db *PgDB) namedExecOne(query string, arg interface{}) error { res, err := db.sql.NamedExec(query, arg) @@ -249,25 +230,6 @@ func (db *PgDB) namedExecOne(query string, arg interface{}) error { return nil } -// namedExecOne is a convenience method for a NamedExec that should affect only one row. -func namedExecOne(tx *sqlx.Tx, query string, arg interface{}) error { - res, err := tx.NamedExec(query, arg) - if err != nil { - return errors.Wrapf(err, "error in query %v \narg %v", query, arg) - } - num, err := res.RowsAffected() - if err != nil { - return errors.Wrapf( - err, - "error checking rows affected for query %v\n arg %v", - query, arg) - } - if num != 1 { - return errors.Errorf("error: %v rows affected on query %v \narg %v", num, query, arg) - } - return nil -} - func queryBinds(fields []string) []string { binds := make([]string, 0, len(fields)) for _, field := range fields { diff --git a/master/internal/db/postgres_cluster_intg_test.go b/master/internal/db/postgres_cluster_intg_test.go index fdc0cae5e86..4ca1e1fdcb7 100644 --- a/master/internal/db/postgres_cluster_intg_test.go +++ b/master/internal/db/postgres_cluster_intg_test.go @@ -4,6 +4,7 @@ package db import ( + "context" "testing" "time" @@ -46,7 +47,8 @@ func TestClusterAPI(t *testing.T) { StartTime: time.Now().UTC().Truncate(time.Millisecond), } - err = db.AddTask(tIn) + ctx := context.Background() + err = AddTask(ctx, tIn) require.NoError(t, err, "failed to add task") // Add an allocation @@ -59,7 +61,7 @@ func TestClusterAPI(t *testing.T) { StartTime: ptrs.Ptr(time.Now().UTC().Truncate(time.Millisecond)), } - err = db.AddAllocation(aIn) + err = AddAllocation(ctx, aIn) require.NoError(t, err, "failed to add allocation") // Add a cluster heartbeat after allocation, so it is as if the master died with it open. @@ -74,10 +76,10 @@ func TestClusterAPI(t *testing.T) { "Retrieved cluster heartbeat doesn't match the correct time") // Don't complete the above allocation and call CloseOpenAllocations - require.NoError(t, db.CloseOpenAllocations(nil)) + require.NoError(t, CloseOpenAllocations(ctx, nil)) // Retrieve the open allocation and check if end time is set to cluster_heartbeat - aOut, err := db.AllocationByID(aIn.AllocationID) + aOut, err := AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err) require.NotNil(t, aOut, "aOut is Nil") require.NotNil(t, aOut.EndTime, "aOut.EndTime is Nil") diff --git a/master/internal/db/postgres_task_logs.go b/master/internal/db/postgres_task_logs.go index cd90fbbaeca..144c4d70658 100644 --- a/master/internal/db/postgres_task_logs.go +++ b/master/internal/db/postgres_task_logs.go @@ -1,12 +1,11 @@ package db import ( + "context" "fmt" "strings" "time" - "github.com/pkg/errors" - "github.com/determined-ai/determined/master/internal/api" "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/proto/pkg/apiv1" @@ -99,7 +98,7 @@ VALUES } if _, err := db.sql.Exec(text.String(), args...); err != nil { - return errors.Wrapf(err, "error inserting %d task logs", len(logs)) + return fmt.Errorf("error inserting %d task logs: %w", len(logs), err) } return nil @@ -107,11 +106,9 @@ VALUES // DeleteTaskLogs deletes the logs for the given tasks. func (db *PgDB) DeleteTaskLogs(ids []model.TaskID) error { - if _, err := db.sql.Exec(` -DELETE FROM task_logs -WHERE task_id IN (SELECT unnest($1::text [])::text); -`, ids); err != nil { - return errors.Wrapf(err, "error deleting task logs for task %v", ids) + if _, err := Bun().NewDelete().Table("task_logs"). + Where("task_id IN (SELECT unnest(?::text [])::text)", ids).Exec(context.Background()); err != nil { + return fmt.Errorf("error deleting task logs for task %v: %w", ids, err) } return nil } diff --git a/master/internal/db/postgres_tasks.go b/master/internal/db/postgres_tasks.go index cfc7d7741af..18167a60dec 100644 --- a/master/internal/db/postgres_tasks.go +++ b/master/internal/db/postgres_tasks.go @@ -7,7 +7,6 @@ import ( "strings" "time" - "github.com/jmoiron/sqlx" "github.com/o1egl/paseto" "github.com/pkg/errors" "github.com/uptrace/bun" @@ -16,25 +15,15 @@ import ( ) // initAllocationSessions purges sessions of all closed allocations. -func (db *PgDB) initAllocationSessions() error { - _, err := db.sql.Exec(` -DELETE FROM allocation_sessions WHERE allocation_id in ( - SELECT allocation_id FROM allocations - WHERE start_time IS NOT NULL AND end_time IS NOT NULL -)`) +func initAllocationSessions(ctx context.Context) error { + subq := Bun().NewSelect().Table("allocations"). + Column("allocation_id").Where("start_time IS NOT NULL AND end_time IS NOT NULL") + _, err := Bun().NewDelete().Table("allocation_sessions").Where("allocation_id in (?)", subq).Exec(ctx) return err } -// AddTask UPSERT's the existence of a task. -// -// TODO(ilia): deprecate and use module function instead. -func (db *PgDB) AddTask(t *model.Task) error { - return AddTask(context.TODO(), t) -} - // AddTask UPSERT's the existence of a task. func AddTask(ctx context.Context, t *model.Task) error { - // Since AddTaskTx is a single query, RunInTx is an overkill. return AddTaskTx(ctx, Bun(), t) } @@ -97,29 +86,9 @@ func TaskCompleted(ctx context.Context, tID model.TaskID) (bool, error) { } // CompleteTask persists the completion of a task. -func (db *PgDB) CompleteTask(tID model.TaskID, endTime time.Time) error { - return completeTask(db.sql, tID, endTime) -} - -func completeTask(ex sqlx.Execer, tID model.TaskID, endTime time.Time) error { - if _, err := ex.Exec(` -UPDATE tasks -SET end_time = $2 -WHERE task_id = $1 - `, tID, endTime); err != nil { - return errors.Wrap(err, "completing task") - } - return nil -} - -func completeTrialsTasks(ex sqlx.Execer, trialID int, endTime time.Time) error { - if _, err := ex.Exec(` -UPDATE tasks -SET end_time = $2 -FROM trial_id_task_id -WHERE trial_id_task_id.task_id = tasks.task_id - AND trial_id_task_id.trial_id = $1 - AND end_time IS NULL`, trialID, endTime); err != nil { +func CompleteTask(ctx context.Context, tID model.TaskID, endTime time.Time) error { + if _, err := Bun().NewUpdate().Table("tasks").Set("end_time = ?", endTime). + Where("task_id = ?", tID).Exec(ctx); err != nil { return fmt.Errorf("completing task: %w", err) } return nil @@ -128,75 +97,68 @@ WHERE trial_id_task_id.task_id = tasks.task_id // AddAllocation upserts the existence of an allocation. Allocation IDs may conflict in the event // the master restarts and the trial run ID increment is not persisted, but it is the same // allocation so this is OK. -func (db *PgDB) AddAllocation(a *model.Allocation) error { - return db.namedExecOne(` -INSERT INTO allocations - (task_id, allocation_id, slots, resource_pool, start_time, state, ports) -VALUES - (:task_id, :allocation_id, :slots, :resource_pool, :start_time, :state, :ports) -ON CONFLICT - (allocation_id) -DO UPDATE SET - task_id=EXCLUDED.task_id, slots=EXCLUDED.slots, resource_pool=EXCLUDED.resource_pool, - start_time=EXCLUDED.start_time, state=EXCLUDED.state, ports=EXCLUDED.ports -`, a) +func AddAllocation(ctx context.Context, a *model.Allocation) error { + _, err := Bun().NewInsert().Table("allocations").Model(a).On("CONFLICT (allocation_id) DO UPDATE"). + Set("task_id=EXCLUDED.task_id, slots=EXCLUDED.slots"). + Set("resource_pool=EXCLUDED.resource_pool,start_time=EXCLUDED.start_time"). + Set("state=EXCLUDED.state, ports=EXCLUDED.ports").Exec(ctx) + return err } // AddAllocationExitStatus adds the allocation exit status to the allocations table. func AddAllocationExitStatus(ctx context.Context, a *model.Allocation) error { - _, err := Bun().NewUpdate(). - Model(a). + if _, err := Bun().NewUpdate().Model(a). Column("exit_reason", "exit_error", "status_code"). - Where("allocation_id = ?", a.AllocationID). - Exec(ctx) - if err != nil { + Where("allocation_id = ?", a.AllocationID).Exec(ctx); err != nil { return fmt.Errorf("adding allocation exit status to db: %w", err) } return nil } // CompleteAllocation persists the end of an allocation lifetime. -func (db *PgDB) CompleteAllocation(a *model.Allocation) error { +func CompleteAllocation(ctx context.Context, a *model.Allocation) error { if a.StartTime == nil { a.StartTime = a.EndTime } - _, err := db.sql.Exec(` -UPDATE allocations -SET start_time = $2, end_time = $3 -WHERE allocation_id = $1`, a.AllocationID, a.StartTime, a.EndTime) + _, err := Bun().NewUpdate().Table("allocations").Model(a). + Set("start_time = ?, end_time = ?", a.StartTime, a.EndTime). + Where("allocation_id = ?", a.AllocationID).Exec(ctx) return err } // CompleteAllocationTelemetry returns the analytics of an allocation for the telemetry. -func (db *PgDB) CompleteAllocationTelemetry(aID model.AllocationID) ([]byte, error) { - return db.rawQuery(` -SELECT json_build_object( - 'allocation_id', a.allocation_id, - 'job_id', t.job_id, - 'task_type', t.task_type, - 'duration_sec', COALESCE(EXTRACT(EPOCH FROM (a.end_time - a.start_time)), 0) -) -FROM allocations as a JOIN tasks as t -ON a.task_id = t.task_id -WHERE a.allocation_id = $1; -`, aID) +func CompleteAllocationTelemetry(ctx context.Context, aID model.AllocationID) ([]byte, error) { + var res []byte + err := Bun().NewRaw(` + SELECT json_build_object( + 'allocation_id', a.allocation_id, + 'job_id', t.job_id, + 'task_type', t.task_type, + 'duration_sec', COALESCE(EXTRACT(EPOCH FROM (a.end_time - a.start_time)), 0) + ) + FROM allocations as a JOIN tasks as t + ON a.task_id = t.task_id + WHERE a.allocation_id = ?`, aID).Scan(ctx, &res) + return res, err } // AllocationByID retrieves an allocation by its ID. -func (db *PgDB) AllocationByID(aID model.AllocationID) (*model.Allocation, error) { +func AllocationByID(ctx context.Context, aID model.AllocationID) (*model.Allocation, error) { var a model.Allocation - if err := Bun().NewSelect().Model(&a).Where("allocation_id = ?", aID). - Scan(context.TODO()); err != nil { + if err := Bun().NewSelect().Table("allocations"). + Where("allocation_id = ?", aID).Scan(ctx, &a); err != nil { return nil, err } return &a, nil } // StartAllocationSession creates a row in the allocation_sessions table. -func (db *PgDB) StartAllocationSession( - allocationID model.AllocationID, owner *model.User, +func StartAllocationSession( + ctx context.Context, + allocationID model.AllocationID, + owner *model.User, ) (string, error) { if owner == nil { return "", errors.New("owner cannot be nil for allocation session") @@ -207,76 +169,63 @@ func (db *PgDB) StartAllocationSession( OwnerID: &owner.ID, } - query := ` -INSERT INTO allocation_sessions (allocation_id, owner_id) VALUES - (:allocation_id, :owner_id) RETURNING id` - if err := db.namedGet(&taskSession.ID, query, *taskSession); err != nil { + if _, err := Bun().NewInsert().Table("allocation_sessions"). + Model(taskSession).Returning("id").Exec(ctx, &taskSession.ID); err != nil { return "", err } v2 := paseto.NewV2() token, err := v2.Sign(GetTokenKeys().PrivateKey, taskSession, nil) if err != nil { - return "", errors.Wrap(err, "failed to generate task authentication token") + return "", fmt.Errorf("failed to generate task authentication token: %w", err) } return token, nil } // DeleteAllocationSession deletes the task session with the given AllocationID. -func (db *PgDB) DeleteAllocationSession(allocationID model.AllocationID) error { - _, err := db.sql.Exec( - "DELETE FROM allocation_sessions WHERE allocation_id=$1", allocationID) +func DeleteAllocationSession(ctx context.Context, allocationID model.AllocationID) error { + _, err := Bun().NewDelete().Table("allocation_sessions").Where("allocation_id = ?", allocationID).Exec(ctx) return err } // UpdateAllocationState stores the latest task state and readiness. -func (db *PgDB) UpdateAllocationState(a model.Allocation) error { - _, err := db.sql.Exec(` - UPDATE allocations - SET state=$2, is_ready=$3 - WHERE allocation_id=$1 - `, a.AllocationID, a.State, a.IsReady) +func UpdateAllocationState(ctx context.Context, a model.Allocation) error { + _, err := Bun().NewUpdate().Table("allocations"). + Set("state = ?, is_ready = ?", a.State, a.IsReady). + Where("allocation_id = ?", a.AllocationID).Exec(ctx) + return err } // UpdateAllocationPorts stores the latest task state and readiness. -func UpdateAllocationPorts(a model.Allocation) error { +func UpdateAllocationPorts(ctx context.Context, a model.Allocation) error { _, err := Bun().NewUpdate().Table("allocations"). Set("ports = ?", a.Ports). Where("allocation_id = ?", a.AllocationID). - Exec(context.TODO()) + Exec(ctx) return err } // UpdateAllocationStartTime stores the latest start time. -func (db *PgDB) UpdateAllocationStartTime(a model.Allocation) error { - _, err := db.sql.Exec(` - UPDATE allocations - SET start_time = $2 - WHERE allocation_id = $1 - `, a.AllocationID, a.StartTime) +func UpdateAllocationStartTime(ctx context.Context, a model.Allocation) error { + _, err := Bun().NewUpdate().Table("allocations"). + Set("start_time = ?", a.StartTime).Where("allocation_id = ?", a.AllocationID).Exec(ctx) return err } // UpdateAllocationProxyAddress stores the proxy address. -func (db *PgDB) UpdateAllocationProxyAddress(a model.Allocation) error { - _, err := db.sql.Exec(` - UPDATE allocations - SET proxy_address = $2 - WHERE allocation_id = $1 - `, a.AllocationID, a.ProxyAddress) +func UpdateAllocationProxyAddress(ctx context.Context, a model.Allocation) error { + _, err := Bun().NewUpdate().Table("allocations").Set("proxy_address = ?", a.ProxyAddress). + Where("allocation_id = ?", a.AllocationID).Exec(ctx) return err } // CloseOpenAllocations finds all allocations that were open when the master crashed // and adds an end time. -func (db *PgDB) CloseOpenAllocations(exclude []model.AllocationID) error { - if _, err := db.sql.Exec(` - UPDATE allocations - SET start_time = cluster_heartbeat FROM cluster_id - WHERE start_time is NULL`); err != nil { - return errors.Wrap(err, - "setting start time to cluster heartbeat when it's assigned to zero value") +func CloseOpenAllocations(ctx context.Context, exclude []model.AllocationID) error { + if _, err := Bun().NewUpdate().Table("allocations").Set("start_time = cluster_heartbeat FROM cluster_id"). + Where("start_time is NULL").Exec(ctx); err != nil { + return fmt.Errorf("setting start time to cluster heartbeat when it's assigned to zero value: %w", err) } excludedFilter := "" @@ -289,36 +238,33 @@ func (db *PgDB) CloseOpenAllocations(exclude []model.AllocationID) error { excludedFilter = strings.Join(excludeStr, ",") } - if _, err := db.sql.Exec(` - UPDATE allocations - SET end_time = greatest(cluster_heartbeat, start_time), state = 'TERMINATED' - FROM cluster_id - WHERE end_time IS NULL AND - ($1 = '' OR allocation_id NOT IN ( - SELECT unnest(string_to_array($1, ','))))`, excludedFilter); err != nil { - return errors.Wrap(err, "closing old allocations") + if _, err := Bun().NewUpdate().Table("allocations, cluster_id"). + Set("end_time = greatest(cluster_heartbeat, start_time), state = 'TERMINATED'"). + Where("end_time IS NULL AND (? = '' OR allocation_id NOT IN (SELECT unnest(string_to_array(?, ','))))", + excludedFilter).Exec(ctx); err != nil { + return fmt.Errorf("closing old allocations: %w", err) } return nil } // RecordTaskStats record stats for tasks. -func (db *PgDB) RecordTaskStats(stats *model.TaskStats) error { - return RecordTaskStatsBun(stats) +func RecordTaskStats(ctx context.Context, stats *model.TaskStats) error { + return RecordTaskStatsBun(ctx, stats) } // RecordTaskStatsBun record stats for tasks with bun. -func RecordTaskStatsBun(stats *model.TaskStats) error { +func RecordTaskStatsBun(ctx context.Context, stats *model.TaskStats) error { _, err := Bun().NewInsert().Model(stats).Exec(context.TODO()) return err } // RecordTaskEndStats record end stats for tasks. -func (db *PgDB) RecordTaskEndStats(stats *model.TaskStats) error { - return RecordTaskEndStatsBun(stats) +func RecordTaskEndStats(ctx context.Context, stats *model.TaskStats) error { + return RecordTaskEndStatsBun(ctx, stats) } // RecordTaskEndStatsBun record end stats for tasks with bun. -func RecordTaskEndStatsBun(stats *model.TaskStats) error { +func RecordTaskEndStatsBun(ctx context.Context, stats *model.TaskStats) error { query := Bun().NewUpdate().Model(stats).Column("end_time"). Where("allocation_id = ?", stats.AllocationID). Where("event_type = ?", stats.EventType). @@ -333,7 +279,7 @@ func RecordTaskEndStatsBun(stats *model.TaskStats) error { query = query.Where("container_id = ?", stats.ContainerID) } - if _, err := query.Exec(context.TODO()); err != nil { + if _, err := query.Exec(ctx); err != nil { return fmt.Errorf("recording task end stats %+v: %w", stats, err) } @@ -341,14 +287,12 @@ func RecordTaskEndStatsBun(stats *model.TaskStats) error { } // EndAllTaskStats called at master starts, in case master previously crashed. -func (db *PgDB) EndAllTaskStats() error { - _, err := db.sql.Exec(` -UPDATE task_stats SET end_time = greatest(cluster_heartbeat, task_stats.start_time) -FROM cluster_id, allocations -WHERE allocations.allocation_id = task_stats.allocation_id -AND allocations.end_time IS NOT NULL -AND task_stats.end_time IS NULL`) - if err != nil { +func EndAllTaskStats(ctx context.Context) error { + if _, err := Bun().NewUpdate().Table("task_stats", "cluster_id, allocations"). + Set("end_time = greatest(cluster_heartbeat, task_stats.start_time)"). + Where("allocations.allocation_id = task_stats.allocation_id"). + Where("allocations_end_time IS NOT NULL AND task_stats.end_time IS NULL"). + Exec(ctx); err != nil { return fmt.Errorf("ending all task stats: %w", err) } diff --git a/master/internal/db/postgres_tasks_intg_test.go b/master/internal/db/postgres_tasks_intg_test.go index 31e6b498625..8d782f48021 100644 --- a/master/internal/db/postgres_tasks_intg_test.go +++ b/master/internal/db/postgres_tasks_intg_test.go @@ -86,7 +86,7 @@ func TestJobTaskAndAllocationAPI(t *testing.T) { TaskType: model.TaskTypeTrial, StartTime: time.Now().UTC().Truncate(time.Millisecond), } - err = db.AddTask(tIn) + err = AddTask(ctx, tIn) require.NoError(t, err, "failed to add task") // Retrieve it back and make sure the mapping is exhaustive. @@ -96,7 +96,7 @@ func TestJobTaskAndAllocationAPI(t *testing.T) { // Complete it. tIn.EndTime = ptrs.Ptr(time.Now().UTC().Truncate(time.Millisecond)) - err = db.CompleteTask(tID, *tIn.EndTime) + err = CompleteTask(ctx, tID, *tIn.EndTime) require.NoError(t, err, "failed to mark task completed") // Re-retrieve it back and make sure the mapping is still exhaustive. @@ -120,7 +120,7 @@ func TestJobTaskAndAllocationAPI(t *testing.T) { StartTime: ptrs.Ptr(time.Now().UTC().Truncate(time.Millisecond)), Ports: ports, } - err = db.AddAllocation(aIn) + err = AddAllocation(ctx, aIn) require.NoError(t, err, "failed to add allocation") // Update ports @@ -129,39 +129,41 @@ func TestJobTaskAndAllocationAPI(t *testing.T) { ports["inter_train_process_comm_port2"] = 0 ports["c10d_port"] = 0 aIn.Ports = ports - err = UpdateAllocationPorts(*aIn) + err = UpdateAllocationPorts(ctx, *aIn) require.NoError(t, err, "failed to update port offset") // Retrieve it back and make sure the mapping is exhaustive. - aOut, err := db.AllocationByID(aIn.AllocationID) + aOut, err := AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err, "failed to retrieve allocation") require.True(t, reflect.DeepEqual(aIn, aOut), pprintedExpect(aIn, aOut)) // Complete it. aIn.EndTime = ptrs.Ptr(time.Now().UTC().Truncate(time.Millisecond)) - err = db.CompleteAllocation(aIn) + err = CompleteAllocation(ctx, aIn) require.NoError(t, err, "failed to mark allocation completed") // Re-retrieve it back and make sure the mapping is still exhaustive. - aOut, err = db.AllocationByID(aIn.AllocationID) + aOut, err = AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err, "failed to re-retrieve allocation") require.True(t, reflect.DeepEqual(aIn, aOut), pprintedExpect(aIn, aOut)) } func TestRecordAndEndTaskStats(t *testing.T) { + ctx := context.Background() require.NoError(t, etc.SetRootPath(RootFromDB)) db := MustResolveTestPostgres(t) MustMigrateTestPostgres(t, db, MigrationsFromDB) tID := model.NewTaskID() - require.NoError(t, db.AddTask(&model.Task{ - TaskID: tID, - TaskType: model.TaskTypeTrial, - StartTime: time.Now().UTC().Truncate(time.Millisecond), - }), "failed to add task") + require.NoError(t, AddTask(context.Background(), + &model.Task{ + TaskID: tID, + TaskType: model.TaskTypeTrial, + StartTime: time.Now().UTC().Truncate(time.Millisecond), + }), "failed to add task") allocationID := model.AllocationID(tID + "allocationID") - require.NoError(t, db.AddAllocation(&model.Allocation{ + require.NoError(t, AddAllocation(context.Background(), &model.Allocation{ TaskID: tID, AllocationID: allocationID, }), "failed to add allocation") @@ -177,10 +179,10 @@ func TestRecordAndEndTaskStats(t *testing.T) { if i == 0 { taskStats.ContainerID = nil } - require.NoError(t, RecordTaskStatsBun(taskStats)) + require.NoError(t, RecordTaskStatsBun(ctx, taskStats)) taskStats.EndTime = ptrs.Ptr(time.Now().Truncate(time.Millisecond)) - require.NoError(t, RecordTaskEndStatsBun(taskStats)) + require.NoError(t, RecordTaskEndStatsBun(ctx, taskStats)) expected = append(expected, taskStats) } @@ -193,7 +195,7 @@ func TestRecordAndEndTaskStats(t *testing.T) { require.ElementsMatch(t, expected, actual) - err = db.EndAllTaskStats() + err = EndAllTaskStats(ctx) require.NoError(t, err) } @@ -209,11 +211,12 @@ func TestNonExperimentTasksContextDirectory(t *testing.T) { // Nil context directory. tID := model.NewTaskID() - require.NoError(t, db.AddTask(&model.Task{ - TaskID: tID, - TaskType: model.TaskTypeNotebook, - StartTime: time.Now().UTC().Truncate(time.Millisecond), - }), "failed to add task") + require.NoError(t, AddTask(context.Background(), + &model.Task{ + TaskID: tID, + TaskType: model.TaskTypeNotebook, + StartTime: time.Now().UTC().Truncate(time.Millisecond), + }), "failed to add task") require.NoError(t, AddNonExperimentTasksContextDirectory(ctx, tID, nil)) @@ -223,11 +226,12 @@ func TestNonExperimentTasksContextDirectory(t *testing.T) { // Non nil context directory. tID = model.NewTaskID() - require.NoError(t, db.AddTask(&model.Task{ - TaskID: tID, - TaskType: model.TaskTypeNotebook, - StartTime: time.Now().UTC().Truncate(time.Millisecond), - }), "failed to add task") + require.NoError(t, AddTask(context.Background(), + &model.Task{ + TaskID: tID, + TaskType: model.TaskTypeNotebook, + StartTime: time.Now().UTC().Truncate(time.Millisecond), + }), "failed to add task") expectedDir := []byte{3, 2, 1} require.NoError(t, AddNonExperimentTasksContextDirectory(ctx, tID, expectedDir)) @@ -238,6 +242,7 @@ func TestNonExperimentTasksContextDirectory(t *testing.T) { } func TestAllocationState(t *testing.T) { + ctx := context.Background() require.NoError(t, etc.SetRootPath(RootFromDB)) db := MustResolveTestPostgres(t) MustMigrateTestPostgres(t, db, MigrationsFromDB) @@ -259,7 +264,7 @@ func TestAllocationState(t *testing.T) { TaskType: model.TaskTypeTrial, StartTime: time.Now().UTC().Truncate(time.Millisecond), } - require.NoError(t, db.AddTask(task), "failed to add task") + require.NoError(t, AddTask(context.Background(), task), "failed to add task") s := state a := &model.Allocation{ @@ -268,7 +273,7 @@ func TestAllocationState(t *testing.T) { ResourcePool: "default", State: &s, } - require.NoError(t, db.AddAllocation(a), "failed to add allocation") + require.NoError(t, AddAllocation(context.Background(), a), "failed to add allocation") // Update allocation to every possible state. testNoUpdate := true @@ -278,7 +283,7 @@ func TestAllocationState(t *testing.T) { j-- // Go to first iteration of loop after this. } else { a.State = &states[j] - require.NoError(t, db.UpdateAllocationState(*a), + require.NoError(t, UpdateAllocationState(ctx, *a), "failed to update allocation state") } @@ -437,15 +442,15 @@ func TestAddNonExperimentTasksContextDirectory(t *testing.T) { } func TestTaskCompleted(t *testing.T) { - ctx := context.Background() db := MustResolveTestPostgres(t) + ctx := context.Background() tIn := RequireMockTask(t, db, nil) completed, err := TaskCompleted(ctx, tIn.TaskID) require.False(t, completed) require.NoError(t, err) - err = db.CompleteTask(tIn.TaskID, time.Now().UTC().Truncate(time.Millisecond)) + err = CompleteTask(ctx, tIn.TaskID, time.Now().UTC().Truncate(time.Millisecond)) require.NoError(t, err) completed, err = TaskCompleted(ctx, tIn.TaskID) @@ -455,6 +460,7 @@ func TestTaskCompleted(t *testing.T) { func TestAddAllocation(t *testing.T) { db := MustResolveTestPostgres(t) + tIn := RequireMockTask(t, db, nil) a := model.Allocation{ AllocationID: model.AllocationID(fmt.Sprintf("%s-1", tIn.TaskID)), @@ -463,7 +469,7 @@ func TestAddAllocation(t *testing.T) { State: ptrs.Ptr(model.AllocationStateTerminated), } - err := db.AddAllocation(&a) + err := AddAllocation(context.Background(), &a) require.NoError(t, err, "failed to add allocation") var res model.Allocation @@ -479,6 +485,7 @@ func TestAddAllocation(t *testing.T) { func TestAddAllocationExitStatus(t *testing.T) { db := MustResolveTestPostgres(t) + ctx := context.Background() tIn := RequireMockTask(t, db, nil) aIn := RequireMockAllocation(t, db, tIn.TaskID) @@ -493,7 +500,7 @@ func TestAddAllocationExitStatus(t *testing.T) { err := AddAllocationExitStatus(context.Background(), aIn) require.NoError(t, err) - res, err := db.AllocationByID(aIn.AllocationID) + res, err := AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn.ExitErr, res.ExitErr) require.Equal(t, aIn.ExitReason, res.ExitReason) @@ -508,10 +515,10 @@ func TestCompleteAllocation(t *testing.T) { aIn.EndTime = ptrs.Ptr(time.Now().UTC()) - err := db.CompleteAllocation(aIn) + err := CompleteAllocation(context.Background(), aIn) require.NoError(t, err) - res, err := db.AllocationByID(aIn.AllocationID) + res, err := AllocationByID(context.Background(), aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn.EndTime, res.EndTime) } @@ -522,7 +529,7 @@ func TestCompleteAllocationTelemetry(t *testing.T) { tIn := RequireMockTask(t, db, nil) aIn := RequireMockAllocation(t, db, tIn.TaskID) - bytes, err := db.CompleteAllocationTelemetry(aIn.AllocationID) + bytes, err := CompleteAllocationTelemetry(context.Background(), aIn.AllocationID) require.NoError(t, err) require.Contains(t, string(bytes), string(aIn.AllocationID)) require.Contains(t, string(bytes), string(*tIn.JobID)) @@ -535,19 +542,20 @@ func TestAllocationByID(t *testing.T) { tIn := RequireMockTask(t, db, nil) aIn := RequireMockAllocation(t, db, tIn.TaskID) - a, err := db.AllocationByID(aIn.AllocationID) + a, err := AllocationByID(context.Background(), aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn, a) } func TestAllocationSessionFlow(t *testing.T) { + ctx := context.Background() db := MustResolveTestPostgres(t) uIn := RequireMockUser(t, db) tIn := RequireMockTask(t, db, nil) aIn := RequireMockAllocation(t, db, tIn.TaskID) - tok, err := db.StartAllocationSession(aIn.AllocationID, &uIn) + tok, err := StartAllocationSession(ctx, aIn.AllocationID, &uIn) require.NoError(t, err) require.NotNil(t, tok) @@ -557,14 +565,14 @@ func TestAllocationSessionFlow(t *testing.T) { running := model.AllocationStatePulling aIn.State = &running - err = db.UpdateAllocationState(*aIn) + err = UpdateAllocationState(ctx, *aIn) require.NoError(t, err) - a, err := db.AllocationByID(aIn.AllocationID) + a, err := AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn.State, a.State) - err = db.DeleteAllocationSession(aIn.AllocationID) + err = DeleteAllocationSession(ctx, aIn.AllocationID) require.NoError(t, err) as, err = allocationSessionByID(t, aIn.AllocationID) @@ -574,15 +582,16 @@ func TestAllocationSessionFlow(t *testing.T) { func TestUpdateAllocation(t *testing.T) { db := MustResolveTestPostgres(t) + ctx := context.Background() tIn := RequireMockTask(t, db, nil) aIn := RequireMockAllocation(t, db, tIn.TaskID) // Testing UpdateAllocation Ports aIn.Ports = map[string]int{"abc": 123, "def": 456} - err := UpdateAllocationPorts(*aIn) + err := UpdateAllocationPorts(ctx, *aIn) require.NoError(t, err) - a, err := db.AllocationByID(aIn.AllocationID) + a, err := AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn.Ports, a.Ports) @@ -590,10 +599,10 @@ func TestUpdateAllocation(t *testing.T) { newStartTime := ptrs.Ptr(time.Now().UTC()) aIn.StartTime = newStartTime - err = db.UpdateAllocationStartTime(*aIn) + err = UpdateAllocationStartTime(ctx, *aIn) require.NoError(t, err) - a, err = db.AllocationByID(aIn.AllocationID) + a, err = AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn.StartTime, a.StartTime) @@ -601,10 +610,10 @@ func TestUpdateAllocation(t *testing.T) { proxyAddr := "here" aIn.ProxyAddress = &proxyAddr - err = db.UpdateAllocationProxyAddress(*aIn) + err = UpdateAllocationProxyAddress(ctx, *aIn) require.NoError(t, err) - a, err = db.AllocationByID(aIn.AllocationID) + a, err = AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn.ProxyAddress, a.ProxyAddress) } @@ -612,10 +621,11 @@ func TestUpdateAllocation(t *testing.T) { func TestCloseOpenAllocations(t *testing.T) { db := MustResolveTestPostgres(t) + ctx := context.Background() + // Create test allocations, with a NULL end time. t1In := RequireMockTask(t, db, nil) a1In := RequireMockAllocation(t, db, t1In.TaskID) - t2In := RequireMockTask(t, db, nil) a2In := RequireMockAllocation(t, db, t2In.TaskID) @@ -625,22 +635,22 @@ func TestCloseOpenAllocations(t *testing.T) { a2In.State = &terminated // Close only a2In open allocations (filter out the rest). - err := db.CloseOpenAllocations([]model.AllocationID{a1In.AllocationID}) + err := CloseOpenAllocations(ctx, []model.AllocationID{a1In.AllocationID}) require.NoError(t, err) - a1, err := db.AllocationByID(a1In.AllocationID) + a1, err := AllocationByID(ctx, a1In.AllocationID) require.NoError(t, err) require.Nil(t, a1.EndTime) - a2, err := db.AllocationByID(a2In.AllocationID) + a2, err := AllocationByID(ctx, a2In.AllocationID) require.NoError(t, err) require.NotNil(t, a2.EndTime) // Close the rest of the open allocations. - err = db.CloseOpenAllocations([]model.AllocationID{}) + err = CloseOpenAllocations(ctx, []model.AllocationID{}) require.NoError(t, err) - a1, err = db.AllocationByID(a1In.AllocationID) + a1, err = AllocationByID(ctx, a1In.AllocationID) require.NoError(t, err) require.NotNil(t, a1.EndTime) } diff --git a/master/internal/db/postgres_test_utils.go b/master/internal/db/postgres_test_utils.go index 57fce9ca107..2444e00bc4b 100644 --- a/master/internal/db/postgres_test_utils.go +++ b/master/internal/db/postgres_test_utils.go @@ -180,79 +180,11 @@ func RequireMockJob(t *testing.T, db *PgDB, userID *model.UserID) model.JobID { OwnerID: userID, QPos: decimal.New(0, 0), } - err := db.AddJob(jIn) + err := AddJobTx(context.TODO(), Bun(), jIn) require.NoError(t, err, "failed to add job") return jID } -// RequireMockCommandID creates a mock command and returns a command ID. -func RequireMockCommandID(t *testing.T, db *PgDB, userID model.UserID) model.TaskID { - task := RequireMockTask(t, db, &userID) - alloc := RequireMockAllocation(t, db, task.TaskID) - - mockCommand := struct { - bun.BaseModel `bun:"table:command_state"` - - TaskID model.TaskID - AllocationID model.AllocationID - GenericCommandSpec map[string]any - }{ - TaskID: task.TaskID, - AllocationID: alloc.AllocationID, - GenericCommandSpec: map[string]any{ - "TaskType": model.TaskTypeCommand, - "Metadata": map[string]any{ - "workspace_id": 1, - }, - "Base": map[string]any{ - "Owner": map[string]any{ - "id": userID, - }, - }, - }, - } - _, err := Bun().NewInsert().Model(&mockCommand).Exec(context.TODO()) - require.NoError(t, err) - - return task.TaskID -} - -// RequireMockTensorboardID creates a mock tensorboard and returns a tensorboard ID. -func RequireMockTensorboardID( - t *testing.T, db *PgDB, userID model.UserID, expIDs, trialIDs []int, -) model.TaskID { - task := RequireMockTask(t, db, &userID) - alloc := RequireMockAllocation(t, db, task.TaskID) - - mockTensorboard := struct { - bun.BaseModel `bun:"table:command_state"` - - TaskID model.TaskID - AllocationID model.AllocationID - GenericCommandSpec map[string]any - }{ - TaskID: task.TaskID, - AllocationID: alloc.AllocationID, - GenericCommandSpec: map[string]any{ - "TaskType": model.TaskTypeTensorboard, - "Metadata": map[string]any{ - "workspace_id": 1, - "experiment_ids": expIDs, - "trial_ids": trialIDs, - }, - "Base": map[string]any{ - "Owner": map[string]any{ - "id": userID, - }, - }, - }, - } - _, err := Bun().NewInsert().Model(&mockTensorboard).Exec(context.TODO()) - require.NoError(t, err) - - return task.TaskID -} - // RequireMockWorkspaceID returns a mock workspace ID. func RequireMockWorkspaceID(t *testing.T, db *PgDB) int { mockWorkspace := struct { @@ -315,7 +247,7 @@ func RequireMockTask(t *testing.T, db *PgDB, userID *model.UserID) *model.Task { TaskType: model.TaskTypeTrial, StartTime: time.Now().UTC().Truncate(time.Millisecond), } - err := db.AddTask(tIn) + err := AddTask(context.Background(), tIn) require.NoError(t, err, "failed to add task") return tIn } @@ -470,7 +402,7 @@ func RequireMockAllocation(t *testing.T, db *PgDB, tID model.TaskID) *model.Allo StartTime: ptrs.Ptr(time.Now().UTC()), State: ptrs.Ptr(model.AllocationStateTerminated), } - err := db.AddAllocation(&a) + err := AddAllocation(context.Background(), &a) require.NoError(t, err, "failed to add allocation") return &a } diff --git a/master/internal/db/postgres_trial.go b/master/internal/db/postgres_trial.go index 7f57fcb38f5..9d7341a658b 100644 --- a/master/internal/db/postgres_trial.go +++ b/master/internal/db/postgres_trial.go @@ -123,10 +123,10 @@ func TrialByTaskID(ctx context.Context, taskID model.TaskID) (*model.Trial, erro // UpdateTrial updates an existing trial. Fields that are nil or zero are not // updated. end_time is set if the trial moves to a terminal state. -func (db *PgDB) UpdateTrial(id int, newState model.State) error { - trial, err := TrialByID(context.TODO(), id) +func UpdateTrial(ctx context.Context, id int, newState model.State) error { + trial, err := TrialByID(ctx, id) if err != nil { - return errors.Wrapf(err, "error finding trial %v to update", id) + return fmt.Errorf("error finding trial %v to update: %w", id, err) } if trial.State == newState { @@ -134,8 +134,8 @@ func (db *PgDB) UpdateTrial(id int, newState model.State) error { } if !model.TrialTransitions[trial.State][newState] { - return errors.Errorf("illegal transition %v -> %v for trial %v", - trial.State, newState, trial.ID) + return fmt.Errorf("illegal transition %v -> %v for trial %v: %w", + trial.State, newState, trial.ID, err) } toUpdate := []string{"state"} trial.State = newState @@ -145,19 +145,20 @@ func (db *PgDB) UpdateTrial(id int, newState model.State) error { toUpdate = append(toUpdate, "end_time") } - return db.withTransaction("update_trial", func(tx *sqlx.Tx) error { + return Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { // Only the trial actor updates this row, and it does so in a serialized // fashion already, so this transaction is more a matter of atomicity. - if err := namedExecOne(tx, fmt.Sprintf(` -UPDATE runs -%v -WHERE id = :id`, SetClause(toUpdate)), trial); err != nil { - return errors.Wrapf(err, "error updating (%v) in trial %v", - strings.Join(toUpdate, ", "), id) + if _, err := tx.NewUpdate().Table("trials"). + Column("?", strings.Join(toUpdate, ",")).Where("id = ?", trial.ID).Exec(ctx); err != nil { + return fmt.Errorf("error updating (%v) in trial %v: %w", strings.Join(toUpdate, ", "), id, err) } if model.TerminalStates[newState] && trial.EndTime != nil { - return completeTrialsTasks(tx, id, *trial.EndTime) + if _, err := tx.NewUpdate().Table("trial_id_task_id").Set("end_time = ?", *trial.EndTime). + Where("trial_id_task_id.task_id = tasks.task_id AND trial_id_task_id.trial_id = ?", id). + Exec(ctx); err != nil { + return fmt.Errorf("completing task: %w", err) + } } return nil diff --git a/master/internal/db/postgres_trial_intg_test.go b/master/internal/db/postgres_trial_intg_test.go index 3b8d3716545..4bb7c09d8eb 100644 --- a/master/internal/db/postgres_trial_intg_test.go +++ b/master/internal/db/postgres_trial_intg_test.go @@ -859,9 +859,9 @@ func TestProtoGetTrial(t *testing.T) { StartTime: ptrs.Ptr(startTime.Add(time.Duration(i) * time.Second)), EndTime: ptrs.Ptr(startTime.Add(time.Duration(i+1) * time.Second)), } - err = db.AddAllocation(a) + err = AddAllocation(ctx, a) require.NoError(t, err, "failed to add allocation") - err = db.CompleteAllocation(a) + err = CompleteAllocation(ctx, a) require.NoError(t, err, "failed to complete allocation") } @@ -909,7 +909,7 @@ func TestAddValidationMetricsDupeCheckpoints(t *testing.T) { TaskID: task.TaskID, StartTime: ptrs.Ptr(time.Now()), } - require.NoError(t, db.AddAllocation(a)) + require.NoError(t, AddAllocation(ctx, a)) // Report training metrics. require.NoError(t, db.AddTrainingMetrics(ctx, &trialv1.TrialMetrics{ @@ -934,7 +934,7 @@ func TestAddValidationMetricsDupeCheckpoints(t *testing.T) { TaskID: task.TaskID, StartTime: ptrs.Ptr(time.Now()), } - require.NoError(t, db.AddAllocation(a)) + require.NoError(t, AddAllocation(ctx, a)) require.NoError(t, db.UpdateTrialRunID(tr.ID, 1)) // Now trial runs validation. @@ -1008,7 +1008,7 @@ func TestBatchesProcessedNRollbacks(t *testing.T) { TaskID: task.TaskID, StartTime: ptrs.Ptr(time.Now()), } - err = db.AddAllocation(a) + err = AddAllocation(ctx, a) require.NoError(t, err, "failed to add allocation") metrics, err := structpb.NewStruct(map[string]any{"loss": 10}) @@ -1132,7 +1132,7 @@ func TestGenericMetricsIO(t *testing.T) { TaskID: task.TaskID, StartTime: ptrs.Ptr(time.Now()), } - err = db.AddAllocation(a) + err = AddAllocation(ctx, a) require.NoError(t, err, "failed to add allocation") metrics, err := structpb.NewStruct(map[string]any{ @@ -1255,7 +1255,7 @@ func TestConcurrentMetricUpdate(t *testing.T) { TaskID: task.TaskID, StartTime: ptrs.Ptr(time.Now()), } - err := db.AddAllocation(a) + err := AddAllocation(ctx, a) require.NoError(t, err, "failed to add allocation") dbTr, err := TrialByID(ctx, tr.ID) diff --git a/master/internal/db/setup.go b/master/internal/db/setup.go index 5fbe7dca9aa..a2bffc8e4c3 100644 --- a/master/internal/db/setup.go +++ b/master/internal/db/setup.go @@ -98,7 +98,7 @@ func Setup(opts *config.DBConfig) (*PgDB, error) { return nil, err } - if err = db.initAllocationSessions(); err != nil { + if err = initAllocationSessions(context.Background()); err != nil { return nil, err } return db, nil diff --git a/master/internal/populate_metrics.go b/master/internal/populate_metrics.go index 973b6c9e986..c03c3291397 100644 --- a/master/internal/populate_metrics.go +++ b/master/internal/populate_metrics.go @@ -250,7 +250,7 @@ func PopulateExpTrialsMetrics(pgdb *db.PgDB, masterConfig *config.Config, trivia TaskType: model.TaskTypeTrial, StartTime: time.Now().UTC().Truncate(time.Millisecond), } - if err = pgdb.AddTask(tIn); err != nil { + if err = db.AddTask(ctx, tIn); err != nil { return err } diff --git a/master/internal/rm/agentrm/agent.go b/master/internal/rm/agentrm/agent.go index fdd98a83b36..218f9f321e1 100644 --- a/master/internal/rm/agentrm/agent.go +++ b/master/internal/rm/agentrm/agent.go @@ -1,6 +1,7 @@ package agentrm import ( + "context" "fmt" "net" "os" @@ -593,6 +594,8 @@ func (a *agent) HandleIncomingWebsocketMessage(msg *aproto.MasterMessage) { a.mu.Lock() defer a.mu.Unlock() + ctx := context.Background() + switch { case msg.AgentStarted != nil: a.syslog.Infof("agent connected ip: %v resource pool: %s slots: %d", @@ -647,9 +650,9 @@ func (a *agent) HandleIncomingWebsocketMessage(msg *aproto.MasterMessage) { if a.taskNeedsRecording(msg.ContainerStatsRecord) { var err error if msg.ContainerStatsRecord.EndStats { - err = db.RecordTaskEndStatsBun(msg.ContainerStatsRecord.Stats) + err = db.RecordTaskEndStatsBun(ctx, msg.ContainerStatsRecord.Stats) } else { - err = db.RecordTaskStatsBun(msg.ContainerStatsRecord.Stats) + err = db.RecordTaskStatsBun(ctx, msg.ContainerStatsRecord.Stats) } if err != nil { a.syslog.Errorf("error recording task stats %s", err) diff --git a/master/internal/rm/agentrm/agent_state_test.go b/master/internal/rm/agentrm/agent_state_test.go index a2f389c9bd3..51f7ae15d0e 100644 --- a/master/internal/rm/agentrm/agent_state_test.go +++ b/master/internal/rm/agentrm/agent_state_test.go @@ -75,7 +75,7 @@ func TestAgentStatePersistence(t *testing.T) { // Run through some container states. tID := model.TaskID(uuid.NewString()) - err = db.SingleDB().AddTask(&model.Task{ + err = db.AddTask(context.Background(), &model.Task{ TaskID: tID, JobID: nil, TaskType: model.TaskTypeCommand, @@ -85,7 +85,7 @@ func TestAgentStatePersistence(t *testing.T) { require.NoError(t, err) aID := model.AllocationID(uuid.NewString()) - err = db.SingleDB().AddAllocation(&model.Allocation{ + err = db.AddAllocation(context.Background(), &model.Allocation{ AllocationID: aID, TaskID: tID, Slots: 2, diff --git a/master/internal/task/allocation.go b/master/internal/task/allocation.go index dc957a1c7cd..05c7219c3ce 100644 --- a/master/internal/task/allocation.go +++ b/master/internal/task/allocation.go @@ -323,7 +323,7 @@ func (a *allocation) Signal(sig AllocationSignal, reason string) { // SetProxyAddress sets the proxy address of the allocation and sets up proxies for any services // it provides. -func (a *allocation) SetProxyAddress(_ context.Context, address string) error { +func (a *allocation) SetProxyAddress(ctx context.Context, address string) error { a.mu.Lock() defer a.mu.Unlock() @@ -332,7 +332,7 @@ func (a *allocation) SetProxyAddress(_ context.Context, address string) error { return nil } a.model.ProxyAddress = &address - if err := a.db.UpdateAllocationProxyAddress(a.model); err != nil { + if err := db.UpdateAllocationProxyAddress(ctx, a.model); err != nil { a.crash(err) return err } @@ -347,12 +347,12 @@ func (a *allocation) SendContainerLog(log *sproto.ContainerLog) { } // SetWaiting moves the allocation to the waiting state if it has not progressed past it yet. -func (a *allocation) SetWaiting(_ context.Context) error { +func (a *allocation) SetWaiting(ctx context.Context) error { a.mu.Lock() defer a.mu.Unlock() a.setMostProgressedModelState(model.AllocationStateWaiting) - if err := a.db.UpdateAllocationState(a.model); err != nil { + if err := db.UpdateAllocationState(ctx, a.model); err != nil { a.crash(err) return err } @@ -361,7 +361,7 @@ func (a *allocation) SetWaiting(_ context.Context) error { // SetReady sets the ready bit and moves the allocation to the running state if it has not // progressed past it already. -func (a *allocation) SetReady(_ context.Context) error { +func (a *allocation) SetReady(ctx context.Context) error { a.mu.Lock() defer a.mu.Unlock() @@ -371,7 +371,7 @@ func (a *allocation) SetReady(_ context.Context) error { a.sendTaskLog(&model.TaskLog{Log: fmt.Sprintf("Service of %s is available", a.req.Name)}) a.setMostProgressedModelState(model.AllocationStateRunning) a.model.IsReady = ptrs.Ptr(true) - if err := a.db.UpdateAllocationState(a.model); err != nil { + if err := db.UpdateAllocationState(ctx, a.model); err != nil { a.crash(err) return err } @@ -388,7 +388,7 @@ func (a *allocation) persistRendezvousComplete() error { if a.model.IsReady == nil || (a.model.IsReady != nil && !*a.model.IsReady) { a.syslog.Info("all containers are connected successfully (task container state changed)") a.model.IsReady = ptrs.Ptr(true) - if err := a.db.UpdateAllocationState(a.model); err != nil { + if err := db.UpdateAllocationState(context.Background(), a.model); err != nil { return err } } @@ -520,7 +520,7 @@ func (a *allocation) requestResources() (*sproto.ResourcesSubscription, error) { a.syslog.Debug("requestResources add allocation") a.setModelState(model.AllocationStatePending) - if err := a.db.AddAllocation(&a.model); err != nil { + if err := db.AddAllocation(context.Background(), &a.model); err != nil { return nil, errors.Wrap(err, "saving trial allocation") } @@ -573,7 +573,7 @@ func (a *allocation) finalize( } a.setMostProgressedModelState(model.AllocationStateTerminated) - if err := a.db.UpdateAllocationState(a.model); err != nil { + if err := db.UpdateAllocationState(context.Background(), a.model); err != nil { a.syslog.WithError(err).Error("failed to set allocation state to terminated") } a.purgeRestorableResources() @@ -617,12 +617,12 @@ func (a *allocation) resourcesAllocated(msg *sproto.ResourcesAllocated) error { } }) - if err := a.db.UpdateAllocationState(a.model); err != nil { + if err := db.UpdateAllocationState(context.Background(), a.model); err != nil { return errors.Wrap(err, "updating allocation state") } now := time.Now().UTC() - err = a.db.RecordTaskStats(&model.TaskStats{ + err = db.RecordTaskStats(context.Background(), &model.TaskStats{ AllocationID: msg.ID, EventType: "QUEUED", StartTime: &msg.JobSubmissionTime, @@ -672,7 +672,7 @@ func (a *allocation) resourcesAllocated(msg *sproto.ResourcesAllocated) error { } else { spec := a.specifier.ToTaskSpec() - token, err := a.db.StartAllocationSession(a.model.AllocationID, spec.Owner) + token, err := db.StartAllocationSession(context.Background(), a.model.AllocationID, spec.Owner) if err != nil { return errors.Wrap(err, "starting a new allocation session") } @@ -687,7 +687,7 @@ func (a *allocation) resourcesAllocated(msg *sproto.ResourcesAllocated) error { } }) - err = db.UpdateAllocationPorts(a.model) + err = db.UpdateAllocationPorts(context.Background(), a.model) if err != nil { return fmt.Errorf("updating allocation db") } @@ -841,7 +841,7 @@ func (a *allocation) resourcesStateChanged(msg *sproto.ResourcesStateChanged) { } } - if err := a.db.UpdateAllocationState(a.model); err != nil { + if err := db.UpdateAllocationState(context.Background(), a.model); err != nil { a.syslog.Error(err) } } @@ -851,7 +851,9 @@ func (a *allocation) restoreResourceFailure(msg *sproto.ResourcesRestoreError) { a.syslog.Debugf("allocation resource failure") a.setMostProgressedModelState(model.AllocationStateTerminating) - if err := a.db.UpdateAllocationState(a.model); err != nil { + ctx := context.Background() + + if err := db.UpdateAllocationState(ctx, a.model); err != nil { a.syslog.Error(err) } @@ -869,7 +871,7 @@ func (a *allocation) restoreResourceFailure(msg *sproto.ResourcesRestoreError) { a.model.EndTime = ptrs.Ptr(time.Now().UTC()) } - if err := a.db.CompleteAllocation(&a.model); err != nil { + if err := db.CompleteAllocation(ctx, &a.model); err != nil { a.syslog.WithError(err).Error("failed to mark allocation completed") } @@ -1191,7 +1193,7 @@ func (a *allocation) markResourcesStarted() { } else { a.sendTaskLog(&model.TaskLog{Log: fmt.Sprintf("%s was assigned to an agent", a.req.Name)}) } - if err := a.db.UpdateAllocationStartTime(a.model); err != nil { + if err := db.UpdateAllocationStartTime(context.Background(), a.model); err != nil { a.syslog. WithError(err). Errorf("allocation will not be properly accounted for") @@ -1200,18 +1202,19 @@ func (a *allocation) markResourcesStarted() { // markResourcesReleased persists completion information. func (a *allocation) markResourcesReleased() { - if err := a.db.DeleteAllocationSession(a.model.AllocationID); err != nil { + ctx := context.Background() + if err := db.DeleteAllocationSession(ctx, a.model.AllocationID); err != nil { a.syslog.WithError(err).Error("error deleting allocation session") } if a.model.StartTime == nil { return } a.model.EndTime = ptrs.Ptr(time.Now().UTC()) - if err := a.db.CompleteAllocation(&a.model); err != nil { + if err := db.CompleteAllocation(ctx, &a.model); err != nil { a.syslog.WithError(err).Error("failed to mark allocation completed") } - telemetry.ReportAllocationTerminal(a.db, a.model, a.resources.firstDevice()) + telemetry.ReportAllocationTerminal(a.model, a.resources.firstDevice()) } func (a *allocation) purgeRestorableResources() { diff --git a/master/internal/task/allocation_service_test.go b/master/internal/task/allocation_service_test.go index 8b42cd2eb9a..2448be891d5 100644 --- a/master/internal/task/allocation_service_test.go +++ b/master/internal/task/allocation_service_test.go @@ -44,7 +44,7 @@ func TestRestoreFailed(t *testing.T) { FailureType: sproto.RestoreError, ErrMsg: "things weren't there", }) - requireTerminated(t, db, id, exitFuture) + requireTerminated(t, id, exitFuture) } func TestInvalidResourcesRequest(t *testing.T) { @@ -55,7 +55,7 @@ func TestInvalidResourcesRequest(t *testing.T) { q.Put(&sproto.InvalidResourcesRequestError{ Cause: fmt.Errorf("eternal gke quota error"), }) - requireTerminated(t, db, id, exitFuture) + requireTerminated(t, id, exitFuture) } type checkWriter struct { @@ -100,7 +100,7 @@ func TestSetReady(t *testing.T) { err := DefaultService.SetReady(context.TODO(), id) require.NoError(t, err) - state, dbState := requireState(t, db, id, model.AllocationStateRunning) + state, dbState := requireState(t, id, model.AllocationStateRunning) require.True(t, state.Ready) require.NotNil(t, dbState.IsReady) require.True(t, *dbState.IsReady) @@ -113,7 +113,7 @@ func TestSetWaiting(t *testing.T) { err := DefaultService.SetWaiting(context.TODO(), id) require.NoError(t, err) - requireState(t, db, id, model.AllocationStateWaiting) + requireState(t, id, model.AllocationStateWaiting) } func TestSetProxyAddress(t *testing.T) { @@ -130,7 +130,7 @@ func TestSetProxyAddress(t *testing.T) { err := DefaultService.SetProxyAddress(context.TODO(), id, addr) require.NoError(t, err) - _, dbState := requireState(t, db, id, model.AllocationStatePending) + _, dbState := requireState(t, id, model.AllocationStatePending) require.NotNil(t, dbState.ProxyAddress) require.Equal(t, addr, *dbState.ProxyAddress) @@ -164,7 +164,7 @@ func TestServiceRendezvous(t *testing.T) { }, }, }) - requireState(t, db, id, model.AllocationStateRunning) + requireState(t, id, model.AllocationStateRunning) info, err := DefaultService.WatchRendezvous(context.TODO(), id, rID) require.NoError(t, err) @@ -217,7 +217,7 @@ func TestGracefullyTerminateAfterRestart(t *testing.T) { }, }, }) - requireState(t, pgDB, ar.AllocationID, model.AllocationStateRunning) + requireState(t, ar.AllocationID, model.AllocationStateRunning) t.Log("do rendezvous (sets ready bit)") info, err := DefaultService.WatchRendezvous(context.TODO(), ar.AllocationID, rID) @@ -302,7 +302,7 @@ func TestAllGather(t *testing.T) { }, }, }) - requireState(t, db, id, model.AllocationStateRunning) + requireState(t, id, model.AllocationStateRunning) wID := uuid.New() msg := "hello world" @@ -354,14 +354,14 @@ func TestPreemption(t *testing.T) { ResourcesID: rID, ResourcesState: sproto.Starting, }) - requireState(t, db, id, model.AllocationStateStarting) + requireState(t, id, model.AllocationStateStarting) q.Put(&sproto.ResourcesStateChanged{ ResourcesID: rID, ResourcesState: sproto.Running, ResourcesStarted: &sproto.ResourcesStarted{}, }) - requireState(t, db, id, model.AllocationStateRunning) + requireState(t, id, model.AllocationStateRunning) err := DefaultService.SetReady(context.Background(), id) require.NoError(t, err) @@ -380,7 +380,7 @@ func TestPreemption(t *testing.T) { ResourcesState: sproto.Terminated, ResourcesStopped: &sproto.ResourcesStopped{}, }) - requireTerminated(t, db, id, exitFuture) + requireTerminated(t, id, exitFuture) }) } } @@ -410,7 +410,7 @@ func TestSignalBeforeLaunch(t *testing.T) { err := DefaultService.Signal(id, tt.args.sig, "some severe reason") require.NoError(t, err) - exit := requireTerminated(t, db, id, exitFuture) + exit := requireTerminated(t, id, exitFuture) require.NoError(t, exit.Err) require.True(t, rm.AssertExpectations(t), "rm didn't receive release in time") }) @@ -444,7 +444,7 @@ func TestSignalBeforeReady(t *testing.T) { err := DefaultService.Signal(id, tt.args.sig, "some severe reason") require.NoError(t, err) - exit := requireTerminated(t, db, id, exitFuture) + exit := requireTerminated(t, id, exitFuture) require.NoError(t, exit.Err) require.True(t, rm.AssertExpectations(t), "rm didn't receive release in time") }) @@ -463,7 +463,7 @@ func TestSetResourcesDaemon(t *testing.T) { for _, rID := range ranked[1:] { err := DefaultService.SetResourcesAsDaemon(context.TODO(), id, rID) require.NoError(t, err) - requireState(t, db, id, model.AllocationStateAssigned) // should still be running + requireState(t, id, model.AllocationStateAssigned) // should still be running } t.Log("daemon exit should wait on chief") @@ -472,7 +472,7 @@ func TestSetResourcesDaemon(t *testing.T) { ResourcesState: sproto.Terminated, ResourcesStopped: &sproto.ResourcesStopped{}, }) - requireState(t, db, id, model.AllocationStateTerminating) + requireState(t, id, model.AllocationStateTerminating) require.False(t, waitForCondition(time.Second, func() bool { return exitFuture.Load() != nil }), "allocation exited prematurely") @@ -484,7 +484,7 @@ func TestSetResourcesDaemon(t *testing.T) { ResourcesStopped: &sproto.ResourcesStopped{}, }) - exit := requireTerminated(t, db, id, exitFuture) + exit := requireTerminated(t, id, exitFuture) require.NoError(t, exit.Err) require.True(t, resources[ranked[2]].AssertExpectations(t), "daemon wasn't killed") require.True(t, rm.AssertExpectations(t), "rm didn't receive release in time") @@ -517,7 +517,7 @@ func TestRestore(t *testing.T) { restoredAr := stubAllocateRequest(restoredTask) restoredAr.Restore = true - err := pgDB.AddAllocation(&model.Allocation{ + err := db.AddAllocation(context.Background(), &model.Allocation{ AllocationID: restoredAr.AllocationID, TaskID: restoredAr.TaskID, Slots: restoredAr.SlotsNeeded, @@ -661,7 +661,7 @@ func requireAssignedMany( ResourcePool: stubResourcePoolName, Resources: assigned, }) - requireState(t, db, id, model.AllocationStateAssigned) + requireState(t, id, model.AllocationStateAssigned) return resources } @@ -677,12 +677,11 @@ func requireKilled( } _ = DefaultService.Signal(id, KillAllocation, "cleanup for tests") - return requireTerminated(t, db, id, exitFuture) + return requireTerminated(t, id, exitFuture) } func requireTerminated( t *testing.T, - db *db.PgDB, id model.AllocationID, exitFuture *atomic.Pointer[AllocationExited], ) *AllocationExited { @@ -691,17 +690,16 @@ func requireTerminated( }), "allocation did not exit in time") exit := exitFuture.Load() require.True(t, exit.FinalState.State == model.AllocationStateTerminated) - requireDBState(t, db, id, model.AllocationStateTerminated) + requireDBState(t, id, model.AllocationStateTerminated) return exit } func requireState( t *testing.T, - db *db.PgDB, id model.AllocationID, state model.AllocationState, ) (AllocationState, *model.Allocation) { - return requireAllocationState(t, id, state), requireDBState(t, db, id, state) + return requireAllocationState(t, id, state), requireDBState(t, id, state) } func requireAllocationState( @@ -735,11 +733,10 @@ func requireAllocationState( func requireDBState( t *testing.T, - db *db.PgDB, id model.AllocationID, expected model.AllocationState, ) *model.Allocation { - dbState, err := db.AllocationByID(id) + dbState, err := db.AllocationByID(context.Background(), id) require.NoError(t, err) require.NotNil(t, dbState.State) require.Equal(t, expected, *dbState.State) diff --git a/master/internal/telemetry/reports.go b/master/internal/telemetry/reports.go index 10249eeb282..4d85649fce4 100644 --- a/master/internal/telemetry/reports.go +++ b/master/internal/telemetry/reports.go @@ -1,6 +1,7 @@ package telemetry import ( + "context" "crypto/rand" "encoding/json" "math/big" @@ -131,9 +132,9 @@ func ReportExperimentCreated(id int, config expconf.ExperimentConfig) { } // ReportAllocationTerminal reports that an allocation ends. -func ReportAllocationTerminal(db db.DB, a model.Allocation, d *device.Device, +func ReportAllocationTerminal(a model.Allocation, d *device.Device, ) { - res, err := db.CompleteAllocationTelemetry(a.AllocationID) + res, err := db.CompleteAllocationTelemetry(context.Background(), a.AllocationID) if err != nil { syslog.WithError(err).Warn("failed to fetch allocation telemetry") return diff --git a/master/internal/telemetry/telemetry_test.go b/master/internal/telemetry/telemetry_test.go index d0acb098072..699901736ae 100644 --- a/master/internal/telemetry/telemetry_test.go +++ b/master/internal/telemetry/telemetry_test.go @@ -41,7 +41,7 @@ func TestTelemetry(t *testing.T) { reportMasterTick(db, rm) ReportProvisionerTick([]*model.Instance{}, "test-instance") ReportExperimentCreated(1, schemas.WithDefaults(createExpConfig())) - ReportAllocationTerminal(db, model.Allocation{}, &device.Device{}) + ReportAllocationTerminal(model.Allocation{}, &device.Device{}) ReportExperimentStateChanged(db, &model.Experiment{}) ReportUserCreated(true, true) ReportUserCreated(false, false) diff --git a/master/internal/trial.go b/master/internal/trial.go index fb817222269..66f3447d769 100644 --- a/master/internal/trial.go +++ b/master/internal/trial.go @@ -301,13 +301,14 @@ func (t *trial) create() error { t.warmStartCheckpoint, int64(t.searcher.Create.TrialSeed), ) + ctx := context.Background() - err := t.addTask() + err := t.addTask(ctx) if err != nil { return err } - err = db.AddTrial(context.TODO(), m, t.taskID) + err = db.AddTrial(ctx, m, t.taskID) if err != nil { return errors.Wrap(err, "failed to save trial to database") } @@ -349,14 +350,15 @@ func (t *trial) continueSetup(continueFromTrialID *int) error { t.taskID = model.TaskID(fmt.Sprintf("%s-%d", t.taskID, len(trialIDTaskIDs))) - err = t.addTask() + ctx := context.Background() + err = t.addTask(ctx) if err != nil { return err } if _, err := db.Bun(). NewInsert(). Model(&model.TrialTaskID{TrialID: t.id, TaskID: t.taskID}). - Exec(context.TODO()); err != nil { + Exec(ctx); err != nil { return fmt.Errorf("adding trial ID task ID relationship: %w", err) } return nil @@ -473,8 +475,8 @@ func (t *trial) maybeAllocateTask() error { return nil } -func (t *trial) addTask() error { - return t.db.AddTask(&model.Task{ +func (t *trial) addTask(ctx context.Context) error { + return db.AddTask(ctx, &model.Task{ TaskID: t.taskID, TaskType: model.TaskTypeTrial, StartTime: t.jobSubmissionTime, // TODO: Why is this the job submission time..? @@ -696,7 +698,7 @@ func (t *trial) transition(s model.StateWithReason) error { if t.state != s.State { t.syslog.Infof("trial changed from state %s to %s", t.state, s.State) if t.idSet { - if err := t.db.UpdateTrial(t.id, s.State); err != nil { + if err := db.UpdateTrial(context.Background(), t.id, s.State); err != nil { return fmt.Errorf("updating trial with end state (%s, %s): %w", s.State, s.InformationalReason, err) } } diff --git a/master/test/integration/api/api_checkpoints_test.go b/master/test/integration/api/api_checkpoints_test.go index 6e293fda969..98508ca5ce0 100644 --- a/master/test/integration/api/api_checkpoints_test.go +++ b/master/test/integration/api/api_checkpoints_test.go @@ -366,6 +366,7 @@ func createPrereqs(t *testing.T, pgDB *db.PgDB) ( err := pgDB.AddExperiment(experiment, activeConfig) assert.NilError(t, err, "failed to insert experiment") + ctx := context.Background() task := db.RequireMockTask(t, pgDB, experiment.OwnerID) trial := &model.Trial{ ExperimentID: experiment.ID, @@ -373,7 +374,7 @@ func createPrereqs(t *testing.T, pgDB *db.PgDB) ( StartTime: time.Now(), } - err = db.AddTrial(context.TODO(), trial, task.TaskID) + err = db.AddTrial(ctx, trial, task.TaskID) assert.NilError(t, err, "failed to insert trial") t.Logf("Created trial=%v", trial) @@ -384,9 +385,9 @@ func createPrereqs(t *testing.T, pgDB *db.PgDB) ( StartTime: ptrs.Ptr(startTime), EndTime: ptrs.Ptr(startTime.Add(time.Duration(1) * time.Second)), } - err = pgDB.AddAllocation(a) + err = db.AddAllocation(ctx, a) assert.NilError(t, err, "failed to add allocation") - err = pgDB.CompleteAllocation(a) + err = db.CompleteAllocation(ctx, a) assert.NilError(t, err, "failed to complete allocation") return experiment, trial, a diff --git a/master/test/integration/api/api_trials_test.go b/master/test/integration/api/api_trials_test.go index 7e674b2bec3..c7f4ed76044 100644 --- a/master/test/integration/api/api_trials_test.go +++ b/master/test/integration/api/api_trials_test.go @@ -262,9 +262,9 @@ func trialProfilerMetricsTests( assert.Assert(t, origEqRecv, "received:\nt\t%s\noriginal:\n\t%s", bRecv, bOrig) } - err = pgDB.UpdateTrial(trial.ID, model.StoppingCompletedState) + err = db.UpdateTrial(ctx, trial.ID, model.StoppingCompletedState) assert.NilError(t, err, "failed to update trial state") - err = pgDB.UpdateTrial(trial.ID, model.CompletedState) + err = db.UpdateTrial(ctx, trial.ID, model.CompletedState) assert.NilError(t, err, "failed to update trial state") _, err = tlCl.Recv()