From 42c737d8ef3a928107a94f4239754e8381452640 Mon Sep 17 00:00:00 2001 From: Carolina Calderon Date: Mon, 29 Jan 2024 13:57:41 -0500 Subject: [PATCH] undoing task_logs split --- master/internal/db/postgres_task_logs.go | 146 ------------------ .../db/postgres_task_logs_intg_test.go | 84 ---------- master/internal/db/postgres_tasks.go | 138 +++++++++++++++++ .../internal/db/postgres_tasks_intg_test.go | 101 ++++++------ 4 files changed, 195 insertions(+), 274 deletions(-) delete mode 100644 master/internal/db/postgres_task_logs.go delete mode 100644 master/internal/db/postgres_task_logs_intg_test.go diff --git a/master/internal/db/postgres_task_logs.go b/master/internal/db/postgres_task_logs.go deleted file mode 100644 index 144c4d70658c..000000000000 --- a/master/internal/db/postgres_task_logs.go +++ /dev/null @@ -1,146 +0,0 @@ -package db - -import ( - "context" - "fmt" - "strings" - "time" - - "github.com/determined-ai/determined/master/internal/api" - "github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/proto/pkg/apiv1" -) - -// taskLogsFieldMap is used to map fields in filters to expressions. This was used historically -// in trial logs to either read timestamps or regex them out of logs. -var taskLogsFieldMap = map[string]string{} - -type taskLogsFollowState struct { - // The last ID returned by the query. Historically the trial logs API when streaming - // repeatedly made a request like SELECT ... FROM trial_logs ... ORDER BY k OFFSET N LIMIT M. - // Since offset is less than optimal (no filtering is done during the initial - // index scan), we at least pass Postgres the ID and let it begin after a certain ID rather - // than offset N into the query. - id int64 -} - -// TaskLogs takes a task ID and log offset, limit and filters and returns matching logs. 
-func (db *PgDB) TaskLogs( - taskID model.TaskID, limit int, fs []api.Filter, order apiv1.OrderBy, followState interface{}, -) ([]*model.TaskLog, interface{}, error) { - if followState != nil { - fs = append(fs, api.Filter{ - Field: "id", - Operation: api.FilterOperationGreaterThan, - Values: []int64{followState.(*taskLogsFollowState).id}, - }) - } - - params := []interface{}{taskID, limit} - fragment, params := filtersToSQL(fs, params, taskLogsFieldMap) - query := fmt.Sprintf(` -SELECT - l.id, - l.task_id, - l.allocation_id, - l.agent_id, - l.container_id, - l.rank_id, - l.timestamp, - l.level, - l.stdtype, - l.source, - l.log -FROM task_logs l -WHERE l.task_id = $1 -%s -ORDER BY l.id %s LIMIT $2 -`, fragment, OrderByToSQL(order)) - - var b []*model.TaskLog - if err := db.queryRows(query, &b, params...); err != nil { - return nil, nil, err - } - - if len(b) > 0 { - lastLog := b[len(b)-1] - followState = &taskLogsFollowState{id: int64(*lastLog.ID)} - } - - return b, followState, nil -} - -// AddTaskLogs adds a list of *model.TaskLog objects to the database with automatic IDs. -func (db *PgDB) AddTaskLogs(logs []*model.TaskLog) error { - if len(logs) == 0 { - return nil - } - - var text strings.Builder - text.WriteString(` -INSERT INTO task_logs - (task_id, allocation_id, log, agent_id, container_id, rank_id, timestamp, level, stdtype, source) -VALUES -`) - - args := make([]interface{}, 0, len(logs)*10) - - for i, log := range logs { - if i > 0 { - text.WriteString(",") - } - // TODO(brad): We can do better. 
- fmt.Fprintf(&text, " ($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)", - i*10+1, i*10+2, i*10+3, i*10+4, i*10+5, i*10+6, i*10+7, i*10+8, i*10+9, i*10+10) - - args = append(args, log.TaskID, log.AllocationID, []byte(log.Log), log.AgentID, log.ContainerID, - log.RankID, log.Timestamp, log.Level, log.StdType, log.Source) - } - - if _, err := db.sql.Exec(text.String(), args...); err != nil { - return fmt.Errorf("error inserting %d task logs: %w", len(logs), err) - } - - return nil -} - -// DeleteTaskLogs deletes the logs for the given tasks. -func (db *PgDB) DeleteTaskLogs(ids []model.TaskID) error { - if _, err := Bun().NewDelete().Table("task_logs"). - Where("task_id IN (SELECT unnest(?::text [])::text)", ids).Exec(context.Background()); err != nil { - return fmt.Errorf("error deleting task logs for task %v: %w", ids, err) - } - return nil -} - -// TaskLogsCount returns the number of logs in postgres for the given task. -func (db *PgDB) TaskLogsCount(taskID model.TaskID, fs []api.Filter) (int, error) { - params := []interface{}{taskID} - fragment, params := filtersToSQL(fs, params, taskLogsFieldMap) - query := fmt.Sprintf(` -SELECT count(*) -FROM task_logs -WHERE task_id = $1 -%s -`, fragment) - var count int - if err := db.sql.QueryRow(query, params...).Scan(&count); err != nil { - return 0, err - } - return count, nil -} - -// TaskLogsFields returns the unique fields that can be filtered on for the given task. -func (db *PgDB) TaskLogsFields(taskID model.TaskID) (*apiv1.TaskLogsFieldsResponse, error) { - var fields apiv1.TaskLogsFieldsResponse - err := db.QueryProto("get_task_logs_fields", &fields, taskID) - return &fields, err -} - -// MaxTerminationDelay is the max delay before a consumer can be sure all logs have been recevied. -// For Postgres, we don't need to wait very long at all; this was a hypothetical cap on fluent -// to DB latency prior to fluent's deprecation. // to DB latency prior to fluent's deprecation. 
-func (db *PgDB) MaxTerminationDelay() time.Duration { - // TODO: K8s logs can take a bit to get to us, so much so we should investigate. - return 5 * time.Second -} diff --git a/master/internal/db/postgres_task_logs_intg_test.go b/master/internal/db/postgres_task_logs_intg_test.go deleted file mode 100644 index 9406dbc1955b..000000000000 --- a/master/internal/db/postgres_task_logs_intg_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package db - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/determined-ai/determined/master/internal/api" - "github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/proto/pkg/apiv1" -) - -func TestTaskLogsFlow(t *testing.T) { - db := MustResolveTestPostgres(t) - t1In := RequireMockTask(t, db, nil) - t2In := RequireMockTask(t, db, nil) - - // Test AddTaskLogs & TaskLogCounts - taskLog1 := RequireMockTaskLog(t, db, t1In.TaskID, "1") - taskLog2 := RequireMockTaskLog(t, db, t1In.TaskID, "2") - taskLog3 := RequireMockTaskLog(t, db, t2In.TaskID, "3") - - // Try adding only taskLog1, and count only 1 log. - err := db.AddTaskLogs([]*model.TaskLog{taskLog1}) - require.NoError(t, err) - - count, err := db.TaskLogsCount(t1In.TaskID, []api.Filter{}) - require.NoError(t, err) - require.Equal(t, count, 1) - - // Try adding the rest of the Task logs, and count 2 for t1In.TaskID, and 1 for t2In.TaskID - err = db.AddTaskLogs([]*model.TaskLog{taskLog2, taskLog3}) - require.NoError(t, err) - - count, err = db.TaskLogsCount(t1In.TaskID, []api.Filter{}) - require.NoError(t, err) - require.Equal(t, count, 2) - - count, err = db.TaskLogsCount(t2In.TaskID, []api.Filter{}) - require.NoError(t, err) - require.Equal(t, count, 1) - - // Test TaskLogsFields. 
- resp, err := db.TaskLogsFields(t1In.TaskID) - require.NoError(t, err) - require.ElementsMatch(t, resp.AgentIds, []string{"testing-agent-1", "testing-agent-2"}) - require.ElementsMatch(t, resp.ContainerIds, []string{"1", "2"}) - - // Test TaskLogs. - // Get 1 task log matching t1In task ID. - logs, _, err := db.TaskLogs(t1In.TaskID, 1, []api.Filter{}, apiv1.OrderBy_ORDER_BY_UNSPECIFIED, nil) - require.NoError(t, err) - require.Equal(t, 1, len(logs)) - require.Equal(t, logs[0].TaskID, string(t1In.TaskID)) - require.Contains(t, []string{"1", "2"}, *logs[0].ContainerID) - - // Get up to 5 tasks matching t2In task ID -- receive only 2. - logs, _, err = db.TaskLogs(t1In.TaskID, 5, []api.Filter{}, apiv1.OrderBy_ORDER_BY_UNSPECIFIED, nil) - require.NoError(t, err) - require.Equal(t, 2, len(logs)) - - // Test DeleteTaskLogs. - err = db.DeleteTaskLogs([]model.TaskID{t2In.TaskID}) - require.NoError(t, err) - - count, err = db.TaskLogsCount(t2In.TaskID, []api.Filter{}) - require.NoError(t, err) - require.Equal(t, 0, count) -} - -func RequireMockTaskLog(t *testing.T, db *PgDB, tID model.TaskID, suffix string) *model.TaskLog { - mockA := RequireMockAllocation(t, db, tID) - agentID := fmt.Sprintf("testing-agent-%s", suffix) - containerID := suffix - log := &model.TaskLog{ - TaskID: string(tID), - AllocationID: (*string)(&mockA.AllocationID), - Log: fmt.Sprintf("this is a log for task %s-%s", tID, suffix), - AgentID: &agentID, - ContainerID: &containerID, - } - return log -} diff --git a/master/internal/db/postgres_tasks.go b/master/internal/db/postgres_tasks.go index 18167a60dec5..c3afe73d6dd1 100644 --- a/master/internal/db/postgres_tasks.go +++ b/master/internal/db/postgres_tasks.go @@ -11,7 +11,9 @@ import ( "github.com/pkg/errors" "github.com/uptrace/bun" + "github.com/determined-ai/determined/master/internal/api" "github.com/determined-ai/determined/master/pkg/model" + "github.com/determined-ai/determined/proto/pkg/apiv1" ) // initAllocationSessions purges sessions of all 
closed allocations. @@ -298,3 +300,139 @@ func EndAllTaskStats(ctx context.Context) error { return nil } + +// taskLogsFieldMap is used to map fields in filters to expressions. This was used historically +// in trial logs to either read timestamps or regex them out of logs. +var taskLogsFieldMap = map[string]string{} + +type taskLogsFollowState struct { + // The last ID returned by the query. Historically the trial logs API when streaming + // repeatedly made a request like SELECT ... FROM trial_logs ... ORDER BY k OFFSET N LIMIT M. + // Since offset is less than optimal (no filtering is done during the initial + // index scan), we at least pass Postgres the ID and let it begin after a certain ID rather + // than offset N into the query. + id int64 +} + +// TaskLogs takes a task ID and log offset, limit and filters and returns matching logs. +func (db *PgDB) TaskLogs( + taskID model.TaskID, limit int, fs []api.Filter, order apiv1.OrderBy, followState interface{}, +) ([]*model.TaskLog, interface{}, error) { + if followState != nil { + fs = append(fs, api.Filter{ + Field: "id", + Operation: api.FilterOperationGreaterThan, + Values: []int64{followState.(*taskLogsFollowState).id}, + }) + } + + params := []interface{}{taskID, limit} + fragment, params := filtersToSQL(fs, params, taskLogsFieldMap) + query := fmt.Sprintf(` +SELECT + l.id, + l.task_id, + l.allocation_id, + l.agent_id, + l.container_id, + l.rank_id, + l.timestamp, + l.level, + l.stdtype, + l.source, + l.log +FROM task_logs l +WHERE l.task_id = $1 +%s +ORDER BY l.id %s LIMIT $2 +`, fragment, OrderByToSQL(order)) + + var b []*model.TaskLog + if err := db.queryRows(query, &b, params...); err != nil { + return nil, nil, err + } + + if len(b) > 0 { + lastLog := b[len(b)-1] + followState = &taskLogsFollowState{id: int64(*lastLog.ID)} + } + + return b, followState, nil +} + +// AddTaskLogs bulk-inserts a list of *model.TaskLog objects to the database with automatic IDs. 
+func (db *PgDB) AddTaskLogs(logs []*model.TaskLog) error { + if len(logs) == 0 { + return nil + } + + var text strings.Builder + text.WriteString(` +INSERT INTO task_logs + (task_id, allocation_id, log, agent_id, container_id, rank_id, timestamp, level, stdtype, source) +VALUES +`) + + args := make([]interface{}, 0, len(logs)*10) + + for i, log := range logs { + if i > 0 { + text.WriteString(",") + } + // TODO(brad): We can do better. + fmt.Fprintf(&text, " ($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)", + i*10+1, i*10+2, i*10+3, i*10+4, i*10+5, i*10+6, i*10+7, i*10+8, i*10+9, i*10+10) + + args = append(args, log.TaskID, log.AllocationID, []byte(log.Log), log.AgentID, log.ContainerID, + log.RankID, log.Timestamp, log.Level, log.StdType, log.Source) + } + + if _, err := db.sql.Exec(text.String(), args...); err != nil { + return errors.Wrapf(err, "error inserting %d task logs", len(logs)) + } + + return nil +} + +// DeleteTaskLogs deletes the logs for the given tasks. +func (db *PgDB) DeleteTaskLogs(ids []model.TaskID) error { + if _, err := db.sql.Exec(` +DELETE FROM task_logs +WHERE task_id IN (SELECT unnest($1::text [])::text); +`, ids); err != nil { + return errors.Wrapf(err, "error deleting task logs for task %v", ids) + } + return nil +} + +// TaskLogsCount returns the number of logs in postgres for the given task. +func (db *PgDB) TaskLogsCount(taskID model.TaskID, fs []api.Filter) (int, error) { + params := []interface{}{taskID} + fragment, params := filtersToSQL(fs, params, taskLogsFieldMap) + query := fmt.Sprintf(` +SELECT count(*) +FROM task_logs +WHERE task_id = $1 +%s +`, fragment) + var count int + if err := db.sql.QueryRow(query, params...).Scan(&count); err != nil { + return 0, err + } + return count, nil +} + +// TaskLogsFields returns the unique fields that can be filtered on for the given task. 
+func (db *PgDB) TaskLogsFields(taskID model.TaskID) (*apiv1.TaskLogsFieldsResponse, error) { + var fields apiv1.TaskLogsFieldsResponse + err := db.QueryProto("get_task_logs_fields", &fields, taskID) + return &fields, err +} + +// MaxTerminationDelay is the max delay before a consumer can be sure all logs have been received. +// For Postgres, we don't need to wait very long at all; this was a hypothetical cap on fluent +// to DB latency prior to fluent's deprecation. +func (db *PgDB) MaxTerminationDelay() time.Duration { + // TODO: K8s logs can take a bit to get to us, so much so we should investigate. + return 5 * time.Second +} diff --git a/master/internal/db/postgres_tasks_intg_test.go b/master/internal/db/postgres_tasks_intg_test.go index 09eded41b3d1..f3cdfd1f893e 100644 --- a/master/internal/db/postgres_tasks_intg_test.go +++ b/master/internal/db/postgres_tasks_intg_test.go @@ -90,7 +90,7 @@ func TestJobTaskAndAllocationAPI(t *testing.T) { TaskType: model.TaskTypeTrial, StartTime: time.Now().UTC().Truncate(time.Millisecond), } - err = db.AddTask(tIn) + err = AddTask(ctx, tIn) require.NoError(t, err, "failed to add task") // Retrieve it back and make sure the mapping is exhaustive. @@ -100,7 +100,7 @@ func TestJobTaskAndAllocationAPI(t *testing.T) { // Complete it. tIn.EndTime = ptrs.Ptr(time.Now().UTC().Truncate(time.Millisecond)) - err = db.CompleteTask(tID, *tIn.EndTime) + err = CompleteTask(ctx, tID, *tIn.EndTime) require.NoError(t, err, "failed to mark task completed") // Re-retrieve it back and make sure the mapping is still exhaustive. 
@@ -124,7 +124,7 @@ func TestJobTaskAndAllocationAPI(t *testing.T) { StartTime: ptrs.Ptr(time.Now().UTC().Truncate(time.Millisecond)), Ports: ports, } - err = db.AddAllocation(aIn) + err = AddAllocation(ctx, aIn) require.NoError(t, err, "failed to add allocation") // Update ports @@ -133,35 +133,37 @@ func TestJobTaskAndAllocationAPI(t *testing.T) { ports["inter_train_process_comm_port2"] = 0 ports["c10d_port"] = 0 aIn.Ports = ports - err = UpdateAllocationPorts(*aIn) + err = UpdateAllocationPorts(ctx, *aIn) require.NoError(t, err, "failed to update port offset") // Retrieve it back and make sure the mapping is exhaustive. - aOut, err := db.AllocationByID(aIn.AllocationID) + aOut, err := AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err, "failed to retrieve allocation") require.True(t, reflect.DeepEqual(aIn, aOut), pprintedExpect(aIn, aOut)) // Complete it. aIn.EndTime = ptrs.Ptr(time.Now().UTC().Truncate(time.Millisecond)) - err = db.CompleteAllocation(aIn) + err = CompleteAllocation(ctx, aIn) require.NoError(t, err, "failed to mark allocation completed") // Re-retrieve it back and make sure the mapping is still exhaustive. 
- aOut, err = db.AllocationByID(aIn.AllocationID) + aOut, err = AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err, "failed to re-retrieve allocation") require.True(t, reflect.DeepEqual(aIn, aOut), pprintedExpect(aIn, aOut)) } func TestRecordAndEndTaskStats(t *testing.T) { + ctx := context.Background() + tID := model.NewTaskID() - require.NoError(t, db.AddTask(&model.Task{ + require.NoError(t, AddTask(ctx, &model.Task{ TaskID: tID, TaskType: model.TaskTypeTrial, StartTime: time.Now().UTC().Truncate(time.Millisecond), }), "failed to add task") allocationID := model.AllocationID(tID + "allocationID") - require.NoError(t, db.AddAllocation(&model.Allocation{ + require.NoError(t, AddAllocation(ctx, &model.Allocation{ TaskID: tID, AllocationID: allocationID, }), "failed to add allocation") @@ -177,10 +179,10 @@ func TestRecordAndEndTaskStats(t *testing.T) { if i == 0 { taskStats.ContainerID = nil } - require.NoError(t, RecordTaskStatsBun(taskStats)) + require.NoError(t, RecordTaskStatsBun(ctx, taskStats)) taskStats.EndTime = ptrs.Ptr(time.Now().Truncate(time.Millisecond)) - require.NoError(t, RecordTaskEndStatsBun(taskStats)) + require.NoError(t, RecordTaskEndStatsBun(ctx, taskStats)) expected = append(expected, taskStats) } @@ -193,7 +195,7 @@ func TestRecordAndEndTaskStats(t *testing.T) { require.ElementsMatch(t, expected, actual) - err = db.EndAllTaskStats() + err = EndAllTaskStats(ctx) require.NoError(t, err) } @@ -206,7 +208,7 @@ func TestNonExperimentTasksContextDirectory(t *testing.T) { // Nil context directory. tID := model.NewTaskID() - require.NoError(t, db.AddTask(&model.Task{ + require.NoError(t, AddTask(ctx, &model.Task{ TaskID: tID, TaskType: model.TaskTypeNotebook, StartTime: time.Now().UTC().Truncate(time.Millisecond), @@ -220,7 +222,7 @@ func TestNonExperimentTasksContextDirectory(t *testing.T) { // Non nil context directory. 
tID = model.NewTaskID() - require.NoError(t, db.AddTask(&model.Task{ + require.NoError(t, AddTask(ctx, &model.Task{ TaskID: tID, TaskType: model.TaskTypeNotebook, StartTime: time.Now().UTC().Truncate(time.Millisecond), @@ -235,6 +237,8 @@ func TestNonExperimentTasksContextDirectory(t *testing.T) { } func TestAllocationState(t *testing.T) { + ctx := context.Background() + // Add an allocation of every possible state. states := []model.AllocationState{ model.AllocationStatePending, @@ -252,7 +256,7 @@ func TestAllocationState(t *testing.T) { TaskType: model.TaskTypeTrial, StartTime: time.Now().UTC().Truncate(time.Millisecond), } - require.NoError(t, db.AddTask(task), "failed to add task") + require.NoError(t, AddTask(ctx, task), "failed to add task") s := state a := &model.Allocation{ @@ -261,7 +265,7 @@ func TestAllocationState(t *testing.T) { ResourcePool: "default", State: &s, } - require.NoError(t, db.AddAllocation(a), "failed to add allocation") + require.NoError(t, AddAllocation(ctx, a), "failed to add allocation") // Update allocation to every possible state. testNoUpdate := true @@ -271,7 +275,7 @@ func TestAllocationState(t *testing.T) { j-- // Go to first iteration of loop after this. 
} else { a.State = &states[j] - require.NoError(t, db.UpdateAllocationState(*a), + require.NoError(t, UpdateAllocationState(ctx, *a), "failed to update allocation state") } @@ -431,7 +435,7 @@ func TestTaskCompleted(t *testing.T) { require.False(t, completed) require.NoError(t, err) - err = db.CompleteTask(tIn.TaskID, time.Now().UTC().Truncate(time.Millisecond)) + err = CompleteTask(ctx, tIn.TaskID, time.Now().UTC().Truncate(time.Millisecond)) require.NoError(t, err) completed, err = TaskCompleted(ctx, tIn.TaskID) @@ -440,6 +444,8 @@ func TestTaskCompleted(t *testing.T) { } func TestAddAllocation(t *testing.T) { + ctx := context.Background() + tIn := RequireMockTask(t, db, nil) a := model.Allocation{ AllocationID: model.AllocationID(fmt.Sprintf("%s-1", tIn.TaskID)), @@ -448,12 +454,10 @@ func TestAddAllocation(t *testing.T) { State: ptrs.Ptr(model.AllocationStateTerminated), } - err := db.AddAllocation(&a) + err := AddAllocation(ctx, &a) require.NoError(t, err, "failed to add allocation") - var res model.Allocation - err = Bun().NewSelect().Table("allocations").Where("allocation_id = ?", string(a.AllocationID)). 
- Scan(context.Background(), &res) + res, err := AllocationByID(ctx, a.AllocationID) require.NoError(t, err) require.Equal(t, a.AllocationID, res.AllocationID) require.Equal(t, a.TaskID, res.TaskID) @@ -462,6 +466,7 @@ func TestAddAllocation(t *testing.T) { } func TestAddAllocationExitStatus(t *testing.T) { + ctx := context.Background() tIn := RequireMockTask(t, db, nil) aIn := RequireMockAllocation(t, db, tIn.TaskID) @@ -473,10 +478,10 @@ func TestAddAllocationExitStatus(t *testing.T) { aIn.ExitErr = &exitErr aIn.StatusCode = &statusCode - err := AddAllocationExitStatus(context.Background(), aIn) + err := AddAllocationExitStatus(ctx, aIn) require.NoError(t, err) - res, err := db.AllocationByID(aIn.AllocationID) + res, err := AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn.ExitErr, res.ExitErr) require.Equal(t, aIn.ExitReason, res.ExitReason) @@ -484,15 +489,17 @@ func TestAddAllocationExitStatus(t *testing.T) { } func TestCompleteAllocation(t *testing.T) { + ctx := context.Background() + tIn := RequireMockTask(t, db, nil) aIn := RequireMockAllocation(t, db, tIn.TaskID) aIn.EndTime = ptrs.Ptr(time.Now().UTC()) - err := db.CompleteAllocation(aIn) + err := CompleteAllocation(ctx, aIn) require.NoError(t, err) - res, err := db.AllocationByID(aIn.AllocationID) + res, err := AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn.EndTime, res.EndTime) } @@ -501,7 +508,7 @@ func TestCompleteAllocationTelemetry(t *testing.T) { tIn := RequireMockTask(t, db, nil) aIn := RequireMockAllocation(t, db, tIn.TaskID) - bytes, err := db.CompleteAllocationTelemetry(aIn.AllocationID) + bytes, err := CompleteAllocationTelemetry(context.Background(), aIn.AllocationID) require.NoError(t, err) require.Contains(t, string(bytes), string(aIn.AllocationID)) require.Contains(t, string(bytes), string(*tIn.JobID)) @@ -512,17 +519,19 @@ func TestAllocationByID(t *testing.T) { tIn := RequireMockTask(t, db, nil) aIn := 
RequireMockAllocation(t, db, tIn.TaskID) - a, err := db.AllocationByID(aIn.AllocationID) + a, err := AllocationByID(context.Background(), aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn, a) } func TestAllocationSessionFlow(t *testing.T) { + ctx := context.Background() + uIn := RequireMockUser(t, db) tIn := RequireMockTask(t, db, nil) aIn := RequireMockAllocation(t, db, tIn.TaskID) - tok, err := db.StartAllocationSession(aIn.AllocationID, &uIn) + tok, err := StartAllocationSession(ctx, aIn.AllocationID, &uIn) require.NoError(t, err) require.NotNil(t, tok) @@ -532,14 +541,14 @@ func TestAllocationSessionFlow(t *testing.T) { running := model.AllocationStatePulling aIn.State = &running - err = db.UpdateAllocationState(*aIn) + err = UpdateAllocationState(ctx, *aIn) require.NoError(t, err) - a, err := db.AllocationByID(aIn.AllocationID) + a, err := AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn.State, a.State) - err = db.DeleteAllocationSession(aIn.AllocationID) + err = DeleteAllocationSession(ctx, aIn.AllocationID) require.NoError(t, err) as, err = allocationSessionByID(t, aIn.AllocationID) @@ -548,15 +557,17 @@ func TestAllocationSessionFlow(t *testing.T) { } func TestUpdateAllocation(t *testing.T) { + ctx := context.Background() + tIn := RequireMockTask(t, db, nil) aIn := RequireMockAllocation(t, db, tIn.TaskID) // Testing UpdateAllocation Ports aIn.Ports = map[string]int{"abc": 123, "def": 456} - err := UpdateAllocationPorts(*aIn) + err := UpdateAllocationPorts(ctx, *aIn) require.NoError(t, err) - a, err := db.AllocationByID(aIn.AllocationID) + a, err := AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn.Ports, a.Ports) @@ -564,10 +575,10 @@ func TestUpdateAllocation(t *testing.T) { newStartTime := ptrs.Ptr(time.Now().UTC()) aIn.StartTime = newStartTime - err = db.UpdateAllocationStartTime(*aIn) + err = UpdateAllocationStartTime(ctx, *aIn) require.NoError(t, err) - a, err = 
db.AllocationByID(aIn.AllocationID) + a, err = AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn.StartTime, a.StartTime) @@ -575,15 +586,17 @@ func TestUpdateAllocation(t *testing.T) { proxyAddr := "here" aIn.ProxyAddress = &proxyAddr - err = db.UpdateAllocationProxyAddress(*aIn) + err = UpdateAllocationProxyAddress(ctx, *aIn) require.NoError(t, err) - a, err = db.AllocationByID(aIn.AllocationID) + a, err = AllocationByID(ctx, aIn.AllocationID) require.NoError(t, err) require.Equal(t, aIn.ProxyAddress, a.ProxyAddress) } func TestCloseOpenAllocations(t *testing.T) { + ctx := context.Background() + // Create test allocations, with a NULL end time. t1In := RequireMockTask(t, db, nil) a1In := RequireMockAllocation(t, db, t1In.TaskID) @@ -597,22 +610,22 @@ func TestCloseOpenAllocations(t *testing.T) { a2In.State = &terminated // Close only a2In open allocations (filter out the rest). - err := db.CloseOpenAllocations([]model.AllocationID{a1In.AllocationID}) + err := CloseOpenAllocations(ctx, []model.AllocationID{a1In.AllocationID}) require.NoError(t, err) - a1, err := db.AllocationByID(a1In.AllocationID) + a1, err := AllocationByID(ctx, a1In.AllocationID) require.NoError(t, err) require.Nil(t, a1.EndTime) - a2, err := db.AllocationByID(a2In.AllocationID) + a2, err := AllocationByID(ctx, a2In.AllocationID) require.NoError(t, err) require.NotNil(t, a2.EndTime) // Close the rest of the open allocations. 
- err = db.CloseOpenAllocations([]model.AllocationID{}) + err = CloseOpenAllocations(ctx, []model.AllocationID{}) require.NoError(t, err) - a1, err = db.AllocationByID(a1In.AllocationID) + a1, err = AllocationByID(ctx, a1In.AllocationID) require.NoError(t, err) require.NotNil(t, a1.EndTime) } @@ -632,7 +645,7 @@ func TestTaskLogsFlow(t *testing.T) { count, err := db.TaskLogsCount(t1In.TaskID, []api.Filter{}) require.NoError(t, err) - require.Equal(t, count, 1) + require.Equal(t, 1, count) // Try adding the rest of the Task logs, and count 2 for t1In.TaskID, and 1 for t2In.TaskID err = db.AddTaskLogs([]*model.TaskLog{taskLog2, taskLog3})