Skip to content

Commit

Permalink
chore: bunify db/postgres_tasks.go
Browse files Browse the repository at this point in the history
  • Loading branch information
carolinaecalderon committed Jan 29, 2024
1 parent e7c9565 commit fac7217
Show file tree
Hide file tree
Showing 27 changed files with 320 additions and 468 deletions.
2 changes: 1 addition & 1 deletion master/internal/api_checkpoint_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func createVersionTwoCheckpoint(
ResourcePool: "somethingelse",
StartTime: ptrs.Ptr(time.Now().UTC().Truncate(time.Millisecond)),
}
require.NoError(t, api.m.db.AddAllocation(aIn))
require.NoError(t, db.AddAllocation(ctx, aIn))

checkpoint := &model.CheckpointV2{
ID: 0,
Expand Down
2 changes: 1 addition & 1 deletion master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -2737,7 +2737,7 @@ func (a *apiServer) createTrialTx(
nil,
0)

if err := a.m.db.AddTask(&model.Task{
if err := db.AddTask(ctx, &model.Task{
TaskID: taskID,
TaskType: model.TaskTypeTrial,
StartTime: time.Now(),
Expand Down
6 changes: 3 additions & 3 deletions master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func TestGetTaskContextDirectoryExperiment(t *testing.T) {
func TestGetTaskContextDirectoryTask(t *testing.T) {
api, _, ctx := setupAPITest(t, nil)
task := &model.Task{TaskType: model.TaskTypeNotebook, TaskID: model.NewTaskID()}
require.NoError(t, api.m.db.AddTask(task))
require.NoError(t, db.AddTask(ctx, task))

expectedContextDirectory := []byte("expectedContextDirectory")
_, err := db.Bun().NewInsert().Model(&model.TaskContextDirectory{
Expand Down Expand Up @@ -539,7 +539,7 @@ func TestGetExperiments(t *testing.T) {
require.NoError(t, api.m.db.AddExperiment(exp0, activeConfig0))
for i := 0; i < 3; i++ {
task := &model.Task{TaskType: model.TaskTypeTrial, TaskID: model.NewTaskID()}
require.NoError(t, api.m.db.AddTask(task))
require.NoError(t, db.AddTask(ctx, task))
require.NoError(t, db.AddTrial(ctx, &model.Trial{
State: model.PausedState,
ExperimentID: exp0.ID,
Expand Down Expand Up @@ -791,7 +791,7 @@ func TestSearchExperiments(t *testing.T) {
// Trial without validations doesn't cause issues.
noValidationsExp := createTestExpWithProjectID(t, api, curUser, projectIDInt)
task := &model.Task{TaskType: model.TaskTypeTrial, TaskID: model.NewTaskID()}
require.NoError(t, api.m.db.AddTask(task))
require.NoError(t, db.AddTask(ctx, task))
require.NoError(t, db.AddTrial(ctx, &model.Trial{
State: model.PausedState,
ExperimentID: noValidationsExp.ID,
Expand Down
18 changes: 9 additions & 9 deletions master/internal/api_tasks_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ func mockNotebookWithWorkspaceID(
TaskID: model.NewTaskID(),
TaskType: model.TaskTypeNotebook,
}
require.NoError(t, api.m.db.AddTask(nb))
require.NoError(t, db.AddTask(ctx, nb))

allocationID := model.AllocationID(string(nb.TaskID) + ".1")
require.NoError(t, api.m.db.AddAllocation(&model.Allocation{
require.NoError(t, db.AddAllocation(ctx, &model.Allocation{
TaskID: nb.TaskID,
AllocationID: allocationID,
}))
Expand Down Expand Up @@ -324,7 +324,7 @@ func TestAddAllocationAcceleratorData(t *testing.T) {
TaskType: model.TaskTypeTrial,
StartTime: time.Now().UTC().Truncate(time.Millisecond),
}
require.NoError(t, api.m.db.AddTask(task), "failed to add task")
require.NoError(t, db.AddTask(ctx, task), "failed to add task")

aID := tID + "-1"
a := &model.Allocation{
Expand All @@ -333,7 +333,7 @@ func TestAddAllocationAcceleratorData(t *testing.T) {
Slots: 1,
ResourcePool: "default",
}
require.NoError(t, api.m.db.AddAllocation(a), "failed to add allocation")
require.NoError(t, db.AddAllocation(ctx, a), "failed to add allocation")
accData := &model.AcceleratorData{
ContainerID: uuid.NewString(),
AllocationID: model.AllocationID(aID),
Expand Down Expand Up @@ -362,7 +362,7 @@ func TestGetAllocationAcceleratorDataWithNoData(t *testing.T) {
TaskType: model.TaskTypeTrial,
StartTime: time.Now().UTC().Truncate(time.Millisecond),
}
require.NoError(t, api.m.db.AddTask(task), "failed to add task")
require.NoError(t, db.AddTask(ctx, task), "failed to add task")

aID := tID + "-1"
a := &model.Allocation{
Expand All @@ -371,7 +371,7 @@ func TestGetAllocationAcceleratorDataWithNoData(t *testing.T) {
Slots: 1,
ResourcePool: "default",
}
require.NoError(t, api.m.db.AddAllocation(a), "failed to add allocation")
require.NoError(t, db.AddAllocation(ctx, a), "failed to add allocation")

resp, err := api.GetTaskAcceleratorData(ctx,
&apiv1.GetTaskAcceleratorDataRequest{TaskId: tID.String()})
Expand All @@ -390,7 +390,7 @@ func TestGetAllocationAcceleratorData(t *testing.T) {
TaskType: model.TaskTypeTrial,
StartTime: time.Now().UTC().Truncate(time.Millisecond),
}
require.NoError(t, api.m.db.AddTask(task), "failed to add task")
require.NoError(t, db.AddTask(ctx, task), "failed to add task")

aID1 := tID + "-1"
a1 := &model.Allocation{
Expand All @@ -399,7 +399,7 @@ func TestGetAllocationAcceleratorData(t *testing.T) {
Slots: 1,
ResourcePool: "default",
}
require.NoError(t, api.m.db.AddAllocation(a1), "failed to add allocation")
require.NoError(t, db.AddAllocation(ctx, a1), "failed to add allocation")
accData := &model.AcceleratorData{
ContainerID: uuid.NewString(),
AllocationID: model.AllocationID(aID1),
Expand All @@ -418,7 +418,7 @@ func TestGetAllocationAcceleratorData(t *testing.T) {
Slots: 1,
ResourcePool: "default",
}
require.NoError(t, api.m.db.AddAllocation(a2), "failed to add allocation")
require.NoError(t, db.AddAllocation(ctx, a2), "failed to add allocation")

resp, err := api.GetTaskAcceleratorData(ctx,
&apiv1.GetTaskAcceleratorDataRequest{TaskId: tID.String()})
Expand Down
23 changes: 12 additions & 11 deletions master/internal/api_trials_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ var inferenceMetricGroup = "inference"
func createTestTrial(
t *testing.T, api *apiServer, curUser model.User,
) (*model.Trial, *model.Task) {
ctx := context.Background()
exp := createTestExpWithProjectID(t, api, curUser, 1)

task := &model.Task{
Expand All @@ -46,17 +47,17 @@ func createTestTrial(
StartTime: time.Now(),
TaskID: trialTaskID(exp.ID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task))
require.NoError(t, db.AddTask(ctx, task))

trial := &model.Trial{
StartTime: time.Now(),
State: model.PausedState,
ExperimentID: exp.ID,
}
require.NoError(t, db.AddTrial(context.TODO(), trial, task.TaskID))
require.NoError(t, db.AddTrial(ctx, trial, task.TaskID))

// Return trial exactly the way the API will generally get it.
outTrial, err := db.TrialByID(context.TODO(), trial.ID)
outTrial, err := db.TrialByID(ctx, trial.ID)
require.NoError(t, err)
return outTrial, task
}
Expand Down Expand Up @@ -751,15 +752,15 @@ func TestTrialProtoTaskIDs(t *testing.T) {
StartTime: task0.StartTime.Add(time.Second),
TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task1))
require.NoError(t, db.AddTask(ctx, task1))

task2 := &model.Task{
TaskType: model.TaskTypeTrial,
LogVersion: model.TaskLogVersion1,
StartTime: task1.StartTime.Add(time.Second),
TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task2))
require.NoError(t, db.AddTask(ctx, task2))

_, err = db.Bun().NewInsert().Model(&[]model.TrialTaskID{
{TrialID: trial.ID, TaskID: task1.TaskID},
Expand Down Expand Up @@ -834,7 +835,7 @@ func TestExperimentIDFromTrialTaskID(t *testing.T) {
StartTime: time.Now(),
TaskID: model.TaskID(uuid.New().String()),
}
require.NoError(t, api.m.db.AddTask(task))
require.NoError(t, db.AddTask(context.Background(), task))
_, err = experimentIDFromTrialTaskID(notTrialTask.TaskID)
require.ErrorIs(t, err, errIsNotTrialTaskID)

Expand All @@ -852,7 +853,7 @@ func TestTrialLogsBackported(t *testing.T) {
StartTime: time.Now(),
TaskID: model.TaskID(fmt.Sprintf("backported.%d", exp.ID)),
}
require.NoError(t, api.m.db.AddTask(task))
require.NoError(t, db.AddTask(ctx, task))

trial := &model.Trial{
StartTime: time.Now(),
Expand Down Expand Up @@ -889,15 +890,15 @@ func TestTrialLogs(t *testing.T) {
StartTime: task0.StartTime.Add(time.Second),
TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task1))
require.NoError(t, db.AddTask(ctx, task1))

task2 := &model.Task{
TaskType: model.TaskTypeTrial,
LogVersion: model.TaskLogVersion1,
StartTime: task1.StartTime.Add(time.Second),
TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task2))
require.NoError(t, db.AddTask(ctx, task2))

_, err := db.Bun().NewInsert().Model(&[]model.TrialTaskID{
{TrialID: trial.ID, TaskID: task1.TaskID},
Expand Down Expand Up @@ -985,15 +986,15 @@ func TestTrialLogFields(t *testing.T) {
StartTime: task0.StartTime.Add(time.Second),
TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task1))
require.NoError(t, db.AddTask(ctx, task1))

task2 := &model.Task{
TaskType: model.TaskTypeTrial,
LogVersion: model.TaskLogVersion1,
StartTime: task1.StartTime.Add(time.Second),
TaskID: trialTaskID(trial.ExperimentID, model.NewRequestID(rand.Reader)),
}
require.NoError(t, api.m.db.AddTask(task2))
require.NoError(t, db.AddTask(ctx, task2))

_, err := db.Bun().NewInsert().Model(&[]model.TrialTaskID{
{TrialID: trial.ID, TaskID: task1.TaskID},
Expand Down
10 changes: 6 additions & 4 deletions master/internal/checkpoint_gc.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package internal

import (
"context"
"fmt"
"strings"
"time"
Expand All @@ -26,7 +27,7 @@ const fullDeleteGlob = "**/*"

func runCheckpointGCTask(
rm rm.ResourceManager,
db *db.PgDB,
pgDB *db.PgDB,
taskID model.TaskID,
jobID model.JobID,
jobSubmissionTime time.Time,
Expand Down Expand Up @@ -82,7 +83,8 @@ func runCheckpointGCTask(
})
syslog := logrus.WithField("component", "checkpointgc").WithFields(logCtx.Fields())

if err := db.AddTask(&model.Task{
ctx := context.Background()
if err := db.AddTask(ctx, &model.Task{
TaskID: taskID,
TaskType: model.TaskTypeCheckpointGC,
StartTime: time.Now().UTC(),
Expand All @@ -97,7 +99,7 @@ func runCheckpointGCTask(

resultChan := make(chan error, 1)
onExit := func(ae *task.AllocationExited) {
if err := db.CompleteTask(taskID, time.Now().UTC()); err != nil {
if err := db.CompleteTask(ctx, taskID, time.Now().UTC()); err != nil {
syslog.WithError(err).Error("marking GC task complete")
}
if err := tasklist.GroupPriorityChangeRegistry.Delete(gcJobID); err != nil {
Expand All @@ -119,7 +121,7 @@ func runCheckpointGCTask(
SingleAgent: true,
},
ResourcePool: rp,
}, db, rm, gcSpec, onExit)
}, pgDB, rm, gcSpec, onExit)
if err != nil {
return err
}
Expand Down
7 changes: 4 additions & 3 deletions master/internal/command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,11 @@ func (c *Command) OnExit(ae *task.AllocationExited) {

c.exitStatus = ae

if err := c.db.CompleteTask(c.taskID, time.Now().UTC()); err != nil {
ctx := context.Background()
if err := db.CompleteTask(ctx, c.taskID, time.Now().UTC()); err != nil {
c.syslog.WithError(err).Error("marking task complete")
}
if err := user.DeleteSessionByToken(context.TODO(), c.GenericCommandSpec.Base.UserSessionToken); err != nil {
if err := user.DeleteSessionByToken(ctx, c.GenericCommandSpec.Base.UserSessionToken); err != nil {
c.syslog.WithError(err).Errorf(
"failure to delete user session for task: %v", c.taskID)
}
Expand All @@ -251,7 +252,7 @@ func (c *Command) garbageCollect() {
}

if c.exitStatus == nil {
if err := c.db.CompleteTask(c.taskID, time.Now().UTC()); err != nil {
if err := db.CompleteTask(context.Background(), c.taskID, time.Now().UTC()); err != nil {
c.syslog.WithError(err).Error("marking task complete")
}
}
Expand Down
8 changes: 4 additions & 4 deletions master/internal/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -791,9 +791,9 @@ func (m *Master) restoreNonTerminalExperiments() error {
return nil
}

func (m *Master) closeOpenAllocations() error {
func (m *Master) closeOpenAllocations(ctx context.Context) error {
allocationIds := task.DefaultService.GetAllAllocationIDs()
if err := m.db.CloseOpenAllocations(allocationIds); err != nil {
if err := db.CloseOpenAllocations(ctx, allocationIds); err != nil {
return err
}
return nil
Expand Down Expand Up @@ -1081,11 +1081,11 @@ func (m *Master) Run(ctx context.Context, gRPCLogInitDone chan struct{}) error {
return err
}

if err = m.closeOpenAllocations(); err != nil {
if err = m.closeOpenAllocations(ctx); err != nil {
return err
}

if err = m.db.EndAllTaskStats(); err != nil {
if err = db.EndAllTaskStats(ctx); err != nil {
return err
}

Expand Down
13 changes: 0 additions & 13 deletions master/internal/db/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,8 @@ type DB interface {
id int,
experimentBest, trialBest, trialLatest int,
) ([]uuid.UUID, error)
AddTask(t *model.Task) error
UpdateTrial(id int, newState model.State) error
UpdateTrialRunnerState(id int, state string) error
UpdateTrialRunnerMetadata(id int, md *trialv1.TrialRunnerMetadata) error
AddAllocation(a *model.Allocation) error
CompleteAllocation(a *model.Allocation) error
CompleteAllocationTelemetry(aID model.AllocationID) ([]byte, error)
TrialRunIDAndRestarts(trialID int) (int, int, error)
UpdateTrialRunID(id, runID int) error
UpdateTrialRestarts(id, restarts int) error
Expand Down Expand Up @@ -89,11 +84,6 @@ type DB interface {
trials []*apiv1.TrialsSnapshotResponse_Trial, endTime time.Time, err error)
TopTrialsByTrainingLength(experimentID int, maxTrials int, metric string,
smallerIsBetter bool) (trials []int32, err error)
StartAllocationSession(allocationID model.AllocationID, owner *model.User) (string, error)
DeleteAllocationSession(allocationID model.AllocationID) error
UpdateAllocationState(allocation model.Allocation) error
UpdateAllocationStartTime(allocation model.Allocation) error
UpdateAllocationProxyAddress(allocation model.Allocation) error
ExperimentSnapshot(experimentID int) ([]byte, int, error)
SaveSnapshot(
experimentID int, version int, experimentSnapshot []byte,
Expand All @@ -114,9 +104,6 @@ type DB interface {
RecordInstanceStats(a *model.InstanceStats) error
EndInstanceStats(a *model.InstanceStats) error
EndAllInstanceStats() error
EndAllTaskStats() error
RecordTaskEndStats(stats *model.TaskStats) error
RecordTaskStats(stats *model.TaskStats) error
UpdateJobPosition(jobID model.JobID, position decimal.Decimal) error
}

Expand Down
Loading

0 comments on commit fac7217

Please sign in to comment.