Skip to content

Commit

Permalink
feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
NicholasBlaskey committed Nov 29, 2023
1 parent 19659b0 commit 886ec10
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 22 deletions.
8 changes: 4 additions & 4 deletions master/internal/api_checkpoint_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ func TestCheckpointsOnArchivedSteps(t *testing.T) {
trial, task := createTestTrial(t, api, curUser)
for _, shouldArchive := range []bool{false, true} {
if shouldArchive {
_, err := db.Bun().NewUpdate().Table("runs").
Set("restart_id = 1").
Where("id = ?", trial.ID).
Exec(ctx)
_, err := db.Bun().NewUpdate().Table("runs"). // TODO(nick-runs) call runs package.
Set("restart_id = 1").
Where("id = ?", trial.ID).
Exec(ctx)
require.NoError(t, err)
trialRunID++
}
Expand Down
18 changes: 9 additions & 9 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1544,10 +1544,10 @@ func (a *apiServer) ContinueExperiment(
"experiment has been completed, cannot continue this experiment")
}
} else if isSingle && len(trialsResp.Trials) > 0 {
if _, err := tx.NewUpdate().Table("runs").
Set("state = ?", model.PausedState).
Where("id = ?", trialsResp.Trials[0].Id).
Exec(ctx); err != nil {
if _, err := tx.NewUpdate().Table("runs"). // TODO(nick-runs) call runs package.
Set("state = ?", model.PausedState).
Where("id = ?", trialsResp.Trials[0].Id).
Exec(ctx); err != nil {
return fmt.Errorf("changing trial state to PAUSED: %w", err)
}
}
Expand Down Expand Up @@ -1592,11 +1592,11 @@ func (a *apiServer) ContinueExperiment(
trialIDs = append(trialIDs, t.Id)
}
if len(trialIDs) > 0 {
if _, err := tx.NewUpdate().Table("runs").
Set("restarts = 0").
Set("end_time = null").
Where("id IN (?)", bun.In(trialIDs)).
Exec(ctx); err != nil {
if _, err := tx.NewUpdate().Table("runs"). // TODO(nick-runs) call runs package.
Set("restarts = 0").
Set("end_time = null").
Where("id IN (?)", bun.In(trialIDs)).
Exec(ctx); err != nil {
return fmt.Errorf("zeroing out trial restarts: %w", err)
}
}
Expand Down
2 changes: 1 addition & 1 deletion master/internal/trials/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (a *TrialsAPIServer) StartTrial(
}
}

run := model.Run{ID: trialID}
run := model.Run{ID: trialID} // TODO(nick-runs) call runs package.
_, err := tx.NewUpdate().Model(&run).WherePK().
Set("restart_id = restart_id + 1").
Set("state = ?", model.RunningState).
Expand Down
47 changes: 47 additions & 0 deletions master/pkg/model/experiment_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package model

import (
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/determined-ai/determined/master/pkg/ptrs"
)

func TestToRun(t *testing.T) {
trial := &Trial{
ID: 3,
RequestID: ptrs.Ptr(RequestID(uuid.New())),
ExperimentID: 4,
State: CompletedState,
StartTime: time.Now(),
EndTime: ptrs.Ptr(time.Now()),
HParams: map[string]any{"test": "test"},
WarmStartCheckpointID: ptrs.Ptr(2),
Seed: 12,
TotalBatches: 15,
ExternalTrialID: ptrs.Ptr("ext"),
RunID: 19,
LastActivity: ptrs.Ptr(time.Now()),
}

expected := &Run{
ID: trial.ID,
RequestID: trial.RequestID,
ExperimentID: trial.ExperimentID,
State: trial.State,
StartTime: trial.StartTime,
EndTime: trial.EndTime,
HParams: trial.HParams,
WarmStartCheckpointID: trial.WarmStartCheckpointID,
Seed: trial.Seed,
TotalBatches: trial.TotalBatches,
ExternalRunID: trial.ExternalTrialID,
RestartID: 19,
LastActivity: trial.LastActivity,
}

require.Equal(t, expected, trial.ToRun())
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,3 @@ DROP VIEW trials;
ALTER TABLE runs RENAME COLUMN restart_id TO run_id;
ALTER TABLE runs RENAME COLUMN external_run_id TO external_trial_id;
ALTER TABLE runs RENAME TO trials;

DROP TABLE dummy;
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@ ALTER TABLE trials RENAME TO runs;
ALTER TABLE runs RENAME COLUMN run_id TO restart_id;
ALTER TABLE runs RENAME COLUMN external_trial_id TO external_run_id;

CREATE TABLE dummy (
xyz INT
);
INSERT INTO dummy (xyz) VALUES (1);

CREATE VIEW trials AS
SELECT
id AS id,
Expand Down Expand Up @@ -48,4 +43,4 @@ SELECT

-- warm_start_checkpoint_id will eventually be in the runs to checkpoint MTM.
warm_start_checkpoint_id AS warm_start_checkpoint_id
FROM runs, dummy; -- FROM dummy is a hack to make this view not insertable.
FROM runs;

0 comments on commit 886ec10

Please sign in to comment.