Skip to content

Commit

Permalink
chore: add kill trial endpoint [DET-3739] (#1071)
Browse files Browse the repository at this point in the history
  • Loading branch information
hamidzr authored Aug 17, 2020
1 parent 48c6230 commit c0b3fe5
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 5 deletions.
21 changes: 21 additions & 0 deletions master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/grpc"
"github.com/determined-ai/determined/master/pkg/actor"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/checkpointv1"
Expand Down Expand Up @@ -137,3 +138,23 @@ func (a *apiServer) GetTrialCheckpoints(

return resp, a.paginate(&resp.Pagination, &resp.Checkpoints, req.Offset, req.Limit)
}

func (a *apiServer) KillTrial(
ctx context.Context, req *apiv1.KillTrialRequest,
) (*apiv1.KillTrialResponse, error) {
ok, err := a.m.db.CheckTrialExists(int(req.Id))
switch {
case err != nil:
return nil, status.Errorf(codes.Internal, "failed to check if trial exists: %s", err)
case !ok:
return nil, status.Errorf(codes.NotFound, "trial %d not found", req.Id)
}

resp := apiv1.KillTrialResponse{}
addr := actor.Addr("trials", req.Id).String()
err = a.actorRequest(addr, req, &resp)
if status.Code(err) == codes.NotFound {
return &apiv1.KillTrialResponse{}, nil
}
return &resp, err
}
13 changes: 13 additions & 0 deletions master/internal/db/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,19 @@ EXISTS(
return exists, err
}

// CheckTrialExists checks if the trial exists.
func (db *PgDB) CheckTrialExists(id int) (bool, error) {
var exists bool
err := db.sql.QueryRow(`
SELECT
EXISTS(
select id
FROM trials
WHERE id = $1
)`, id).Scan(&exists)
return exists, err
}

// ExperimentCheckpointsRaw returns a JSON string describing checkpoints for a given experiment,
// either all of them or the best subset.
func (db *PgDB) ExperimentCheckpointsRaw(id int, numBest *int) ([]byte, error) {
Expand Down
17 changes: 13 additions & 4 deletions master/internal/trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/determined-ai/determined/master/pkg/ssh"
"github.com/determined-ai/determined/master/pkg/tasks"
"github.com/determined-ai/determined/master/pkg/union"
"github.com/determined-ai/determined/proto/pkg/apiv1"
)

const (
Expand Down Expand Up @@ -225,6 +226,13 @@ func newTrial(
}
}

func (t *trial) killTrial(ctx *actor.Context) {
t.killed = true
if t.task != nil {
ctx.Tell(t.rp, scheduler.TerminateTask{TaskID: t.task.ID, Forcible: true})
}
}

func (t *trial) Receive(ctx *actor.Context) error {
switch msg := ctx.Message().(type) {
case actor.PreStart:
Expand Down Expand Up @@ -364,10 +372,7 @@ func (t *trial) runningReceive(ctx *actor.Context) error {
t.processTaskTerminated(ctx, msg)

case killTrial:
t.killed = true
if t.task != nil {
ctx.Tell(t.rp, scheduler.TerminateTask{TaskID: t.task.ID, Forcible: true})
}
t.killTrial(ctx)

case terminateTimeout:
if t.task != nil && msg.runID == t.runID {
Expand All @@ -384,6 +389,10 @@ func (t *trial) runningReceive(ctx *actor.Context) error {

case actor.ChildStopped:

case *apiv1.KillTrialRequest:
t.killTrial(ctx)
ctx.Respond(&apiv1.KillTrialResponse{})

default:
return actor.ErrUnexpectedMessage(ctx)
}
Expand Down
7 changes: 6 additions & 1 deletion proto/src/determined/api/v1/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,12 @@ service Determined {
// Stream Trial logs.
rpc TrialLogs(TrialLogsRequest) returns (stream TrialLogsResponse) {
option (google.api.http) = {get: "/api/v1/trials/{trial_id}/logs"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Experiments"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: ["Experiments", "Trials"]};
}
// Kill an trial.
rpc KillTrial(KillTrialRequest) returns (KillTrialResponse) {
option (google.api.http) = {post: "/api/v1/trials/{id}/kill" };
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: ["Experiments", "Trials"]};
}

// Get a list of checkpoints for a trial.
Expand Down
9 changes: 9 additions & 0 deletions proto/src/determined/api/v1/trial.proto
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,12 @@ message GetTrialCheckpointsResponse {
// Pagination information of the full dataset.
Pagination pagination = 2;
}

// Kill an trial.
message KillTrialRequest {
// The trial id
int32 id = 1;
}
// Response to KillTrialRequest.
message KillTrialResponse {}

0 comments on commit c0b3fe5

Please sign in to comment.