Skip to content

Commit

Permalink
feat: break workload info from trial endpoint into a new endpoint
Browse files Browse the repository at this point in the history
When a trial has run many workloads, the response to the trial details
endpoint for it can become very large and unwieldy. Since we don't
always need the full set of workloads, we move those into a new endpoint
and just have some useful workload summary information in the original
one.
  • Loading branch information
dzhu committed May 4, 2022
1 parent 02114cf commit ce1fe9b
Show file tree
Hide file tree
Showing 14 changed files with 497 additions and 93 deletions.
4 changes: 2 additions & 2 deletions harness/determined/cli/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def is_number(value: Any) -> bool:


def scalar_training_metrics_names(
workloads: Sequence[bindings.GetTrialResponseWorkloadContainer],
workloads: Sequence[bindings.v1WorkloadContainer],
) -> Set[str]:
"""
Given an experiment history, return the names of training metrics
Expand All @@ -564,7 +564,7 @@ def scalar_training_metrics_names(


def scalar_validation_metrics_names(
workloads: Sequence[bindings.GetTrialResponseWorkloadContainer],
workloads: Sequence[bindings.v1WorkloadContainer],
) -> Set[str]:
for workload in workloads:
if workload.validation:
Expand Down
112 changes: 84 additions & 28 deletions harness/determined/common/api/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,32 +78,6 @@ def to_json(self) -> typing.Any:
"inProgress": self.inProgress if self.inProgress is not None else None,
}

class GetTrialResponseWorkloadContainer:
    """Container holding a single workload record of one kind.

    At most one of ``training``, ``validation``, or ``checkpoint`` is
    expected to be populated; the remaining members stay ``None``.
    """

    def __init__(
        self,
        checkpoint: "typing.Optional[v1CheckpointWorkload]" = None,
        training: "typing.Optional[v1MetricsWorkload]" = None,
        validation: "typing.Optional[v1MetricsWorkload]" = None,
    ):
        self.training = training
        self.validation = validation
        self.checkpoint = checkpoint

    @classmethod
    def from_json(cls, obj: Json) -> "GetTrialResponseWorkloadContainer":
        # Pull each member out once; absent and null JSON fields both map to None.
        raw_training = obj.get("training")
        raw_validation = obj.get("validation")
        raw_checkpoint = obj.get("checkpoint")
        return cls(
            training=None if raw_training is None else v1MetricsWorkload.from_json(raw_training),
            validation=None if raw_validation is None else v1MetricsWorkload.from_json(raw_validation),
            checkpoint=None if raw_checkpoint is None else v1CheckpointWorkload.from_json(raw_checkpoint),
        )

    def to_json(self) -> typing.Any:
        # Serialize members in the same key order the API emits.
        out = {}
        for key, member in (
            ("training", self.training),
            ("validation", self.validation),
            ("checkpoint", self.checkpoint),
        ):
            out[key] = None if member is None else member.to_json()
        return out

class TrialEarlyExitExitedReason(enum.Enum):
EXITED_REASON_UNSPECIFIED = "EXITED_REASON_UNSPECIFIED"
EXITED_REASON_INVALID_HP = "EXITED_REASON_INVALID_HP"
Expand Down Expand Up @@ -299,9 +273,11 @@ def __init__(
bestCheckpoint: "typing.Optional[v1CheckpointWorkload]" = None,
bestValidation: "typing.Optional[v1MetricsWorkload]" = None,
endTime: "typing.Optional[str]" = None,
latestTraining: "typing.Optional[v1MetricsWorkload]" = None,
latestValidation: "typing.Optional[v1MetricsWorkload]" = None,
runnerState: "typing.Optional[str]" = None,
taskId: "typing.Optional[str]" = None,
totalCheckpointSize: "typing.Optional[str]" = None,
wallClockTime: "typing.Optional[float]" = None,
warmStartCheckpointUuid: "typing.Optional[str]" = None,
):
Expand All @@ -315,10 +291,12 @@ def __init__(
self.bestValidation = bestValidation
self.latestValidation = latestValidation
self.bestCheckpoint = bestCheckpoint
self.latestTraining = latestTraining
self.runnerState = runnerState
self.wallClockTime = wallClockTime
self.warmStartCheckpointUuid = warmStartCheckpointUuid
self.taskId = taskId
self.totalCheckpointSize = totalCheckpointSize

@classmethod
def from_json(cls, obj: Json) -> "trialv1Trial":
Expand All @@ -333,10 +311,12 @@ def from_json(cls, obj: Json) -> "trialv1Trial":
bestValidation=v1MetricsWorkload.from_json(obj["bestValidation"]) if obj.get("bestValidation", None) is not None else None,
latestValidation=v1MetricsWorkload.from_json(obj["latestValidation"]) if obj.get("latestValidation", None) is not None else None,
bestCheckpoint=v1CheckpointWorkload.from_json(obj["bestCheckpoint"]) if obj.get("bestCheckpoint", None) is not None else None,
latestTraining=v1MetricsWorkload.from_json(obj["latestTraining"]) if obj.get("latestTraining", None) is not None else None,
runnerState=obj.get("runnerState", None),
wallClockTime=float(obj["wallClockTime"]) if obj.get("wallClockTime", None) is not None else None,
warmStartCheckpointUuid=obj.get("warmStartCheckpointUuid", None),
taskId=obj.get("taskId", None),
totalCheckpointSize=obj.get("totalCheckpointSize", None),
)

def to_json(self) -> typing.Any:
Expand All @@ -351,10 +331,12 @@ def to_json(self) -> typing.Any:
"bestValidation": self.bestValidation.to_json() if self.bestValidation is not None else None,
"latestValidation": self.latestValidation.to_json() if self.latestValidation is not None else None,
"bestCheckpoint": self.bestCheckpoint.to_json() if self.bestCheckpoint is not None else None,
"latestTraining": self.latestTraining.to_json() if self.latestTraining is not None else None,
"runnerState": self.runnerState if self.runnerState is not None else None,
"wallClockTime": dump_float(self.wallClockTime) if self.wallClockTime is not None else None,
"warmStartCheckpointUuid": self.warmStartCheckpointUuid if self.warmStartCheckpointUuid is not None else None,
"taskId": self.taskId if self.taskId is not None else None,
"totalCheckpointSize": self.totalCheckpointSize if self.totalCheckpointSize is not None else None,
}

class v1AckAllocationPreemptionSignalRequest:
Expand Down Expand Up @@ -2146,7 +2128,7 @@ class v1GetTrialResponse:
def __init__(
self,
trial: "trialv1Trial",
workloads: "typing.Sequence[GetTrialResponseWorkloadContainer]",
workloads: "typing.Sequence[v1WorkloadContainer]",
):
self.trial = trial
self.workloads = workloads
Expand All @@ -2155,7 +2137,7 @@ def __init__(
def from_json(cls, obj: Json) -> "v1GetTrialResponse":
return cls(
trial=trialv1Trial.from_json(obj["trial"]),
workloads=[GetTrialResponseWorkloadContainer.from_json(x) for x in obj["workloads"]],
workloads=[v1WorkloadContainer.from_json(x) for x in obj["workloads"]],
)

def to_json(self) -> typing.Any:
Expand All @@ -2164,6 +2146,28 @@ def to_json(self) -> typing.Any:
"workloads": [x.to_json() for x in self.workloads],
}

class v1GetTrialWorkloadsResponse:
    """Response body for GetTrialWorkloads: one page of workloads plus pagination info."""

    def __init__(
        self,
        pagination: "v1Pagination",
        workloads: "typing.Sequence[v1WorkloadContainer]",
    ):
        self.workloads = workloads
        self.pagination = pagination

    @classmethod
    def from_json(cls, obj: Json) -> "v1GetTrialWorkloadsResponse":
        # Both fields are required in the wire format; KeyError on absence is intentional.
        parsed_workloads = [v1WorkloadContainer.from_json(w) for w in obj["workloads"]]
        parsed_pagination = v1Pagination.from_json(obj["pagination"])
        return cls(workloads=parsed_workloads, pagination=parsed_pagination)

    def to_json(self) -> typing.Any:
        return {
            "workloads": [w.to_json() for w in self.workloads],
            "pagination": self.pagination.to_json(),
        }

class v1GetUserResponse:
def __init__(
self,
Expand Down Expand Up @@ -5147,6 +5151,32 @@ def to_json(self) -> typing.Any:
"searcherMetric": dump_float(self.searcherMetric),
}

class v1WorkloadContainer:
    """Tagged union over the three workload kinds a trial can report.

    Exactly one of ``training``, ``validation``, or ``checkpoint`` is
    expected to be set on any given instance; the others are ``None``.
    """

    # Maps JSON field name to the attribute it populates; the order here
    # fixes the key order of the serialized form.
    _FIELDS = ("training", "validation", "checkpoint")

    def __init__(
        self,
        checkpoint: "typing.Optional[v1CheckpointWorkload]" = None,
        training: "typing.Optional[v1MetricsWorkload]" = None,
        validation: "typing.Optional[v1MetricsWorkload]" = None,
    ):
        self.training = training
        self.validation = validation
        self.checkpoint = checkpoint

    @classmethod
    def from_json(cls, obj: Json) -> "v1WorkloadContainer":
        # Decode only the fields that are present and non-null.
        def _decode(key: str, parser: typing.Any) -> typing.Any:
            raw = obj.get(key)
            return None if raw is None else parser.from_json(raw)

        return cls(
            training=_decode("training", v1MetricsWorkload),
            validation=_decode("validation", v1MetricsWorkload),
            checkpoint=_decode("checkpoint", v1CheckpointWorkload),
        )

    def to_json(self) -> typing.Any:
        result = {}
        for name in self._FIELDS:
            member = getattr(self, name)
            result[name] = member.to_json() if member is not None else None
        return result

def post_AckAllocationPreemptionSignal(
session: "client.Session",
*,
Expand Down Expand Up @@ -6438,6 +6468,32 @@ def get_GetTrialCheckpoints(
return v1GetTrialCheckpointsResponse.from_json(_resp.json())
raise APIHttpError("get_GetTrialCheckpoints", _resp)

def get_GetTrialWorkloads(
    session: "client.Session",
    *,
    trialId: int,
    limit: "typing.Optional[int]" = None,
    offset: "typing.Optional[int]" = None,
    orderBy: "typing.Optional[v1OrderBy]" = None,
) -> "v1GetTrialWorkloadsResponse":
    """Fetch one page of workloads for trial ``trialId``.

    ``limit``, ``offset``, and ``orderBy`` are forwarded as query
    parameters when provided; ``None`` values are passed through as-is.

    Raises:
        APIHttpError: if the server responds with a non-200 status.
    """
    query_params = {
        "limit": limit,
        "offset": offset,
        # Enum members are always truthy, so this matches the wire behavior
        # of sending the enum's raw value only when one was supplied.
        "orderBy": None if orderBy is None else orderBy.value,
    }
    _resp = session._do_request(
        method="GET",
        path=f"/api/v1/trials/{trialId}/workloads",
        params=query_params,
        json=None,
        data=None,
        headers=None,
        timeout=None,
    )
    if _resp.status_code != 200:
        raise APIHttpError("get_GetTrialWorkloads", _resp)
    return v1GetTrialWorkloadsResponse.from_json(_resp.json())

def get_GetUser(
session: "client.Session",
*,
Expand Down
39 changes: 37 additions & 2 deletions master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,10 +482,45 @@ func (a *apiServer) GetTrial(_ context.Context, req *apiv1.GetTrialRequest) (
return nil, errors.Wrapf(err, "failed to get trial %d", req.TrialId)
}

switch err := a.m.db.QueryProto(
workloads := apiv1.GetTrialWorkloadsResponse{}
switch err := a.m.db.QueryProtof(
"proto_get_trial_workloads",
[]interface{}{
db.OrderByToSQL(apiv1.OrderBy_ORDER_BY_ASC),
db.OrderByToSQL(apiv1.OrderBy_ORDER_BY_ASC),
},
&workloads,
req.TrialId,
nil,
nil,
); {
case err == db.ErrNotFound:
return nil, status.Errorf(codes.NotFound, "trial %d workloads not found:", req.TrialId)
case err != nil:
return nil, errors.Wrapf(err, "failed to get trial %d workloads", req.TrialId)
}

resp.Workloads = workloads.Workloads

return resp, nil
}

func (a *apiServer) GetTrialWorkloads(_ context.Context, req *apiv1.GetTrialWorkloadsRequest) (
*apiv1.GetTrialWorkloadsResponse, error,
) {
resp := &apiv1.GetTrialWorkloadsResponse{}
limit := &req.Limit
if *limit == 0 {
limit = nil
}

switch err := a.m.db.QueryProtof(
"proto_get_trial_workloads",
&resp.Workloads,
[]interface{}{db.OrderByToSQL(req.OrderBy), db.OrderByToSQL(req.OrderBy)},
resp,
req.TrialId,
req.Offset,
limit,
); {
case err == db.ErrNotFound:
return nil, status.Errorf(codes.NotFound, "trial %d workloads not found:", req.TrialId)
Expand Down
3 changes: 2 additions & 1 deletion master/internal/db/postgres_filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ func filterToParams(f api.Filter) []interface{} {
return params
}

func orderByToSQL(order apiv1.OrderBy) string {
// OrderByToSQL computes the SQL keyword corresponding to the given ordering type.
func OrderByToSQL(order apiv1.OrderBy) string {
switch order {
case apiv1.OrderBy_ORDER_BY_UNSPECIFIED:
return asc
Expand Down
2 changes: 1 addition & 1 deletion master/internal/db/postgres_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ FROM task_logs l
WHERE l.task_id = $1
%s
ORDER BY l.id %s LIMIT $2
`, fragment, orderByToSQL(order))
`, fragment, OrderByToSQL(order))

var b []*model.TaskLog
if err := db.queryRows(query, &b, params...); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion master/internal/db/postgres_trial_logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ FROM trial_logs l
WHERE l.trial_id = $1
%s
ORDER BY l.id %s LIMIT $2
`, fragment, orderByToSQL(order))
`, fragment, OrderByToSQL(order))

var b []*model.TrialLog
if err := db.queryRows(query, &b, params...); err != nil {
Expand Down
46 changes: 32 additions & 14 deletions master/static/srv/proto_get_trial_workloads.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
WITH validations_vt AS (
SELECT row_to_json(r1) AS validation, total_batches
SELECT row_to_json(r1) AS validation, total_batches, end_time
FROM (
SELECT 'STATE_' || v.state as state,
v.end_time,
Expand All @@ -11,7 +11,7 @@ WITH validations_vt AS (
) AS r1
),
trainings_vt AS (
SELECT row_to_json(r1) AS training, total_batches
SELECT row_to_json(r1) AS training, total_batches, end_time
FROM (
SELECT s.end_time,
'STATE_' || s.state as state,
Expand All @@ -23,7 +23,7 @@ trainings_vt AS (
) AS r1
),
checkpoints_vt AS (
SELECT row_to_json(r1) AS checkpoint, total_batches
SELECT row_to_json(r1) AS checkpoint, total_batches, end_time
FROM (
SELECT
'STATE_' || c.state AS state,
Expand All @@ -34,15 +34,33 @@ checkpoints_vt AS (
FROM checkpoints_view c
WHERE c.trial_id = $1
) AS r1
),
workloads AS (
SELECT v.validation::jsonb AS validation,
t.training::jsonb AS training,
c.checkpoint::jsonb AS checkpoint,
coalesce(
t.total_batches,
v.total_batches,
c.total_batches
) AS total_batches,
coalesce(
t.end_time,
v.end_time,
c.end_time
) AS end_time
FROM trainings_vt t
FULL JOIN checkpoints_vt c ON false
FULL JOIN validations_vt v ON false
),
page_info AS (
SELECT public.page_info((SELECT COUNT(*) AS count FROM workloads), $2 :: int, $3 :: int) AS page_info
)
SELECT v.validation::jsonb AS validation,
t.training::jsonb AS training,
c.checkpoint::jsonb AS checkpoint
FROM trainings_vt t
FULL JOIN checkpoints_vt c ON false
FULL JOIN validations_vt v ON false
ORDER BY coalesce(
t.total_batches,
v.total_batches,
c.total_batches
) ASC
SELECT (
SELECT jsonb_agg(w) FROM (SELECT validation, training, checkpoint FROM workloads
ORDER BY total_batches %s, end_time %s
OFFSET (SELECT p.page_info->>'start_index' FROM page_info p)::bigint
LIMIT (SELECT (p.page_info->>'end_index')::bigint - (p.page_info->>'start_index')::bigint FROM page_info p)
) w
) AS workloads,
(SELECT p.page_info FROM page_info p) as pagination
Loading

0 comments on commit ce1fe9b

Please sign in to comment.