Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add preview search to new API #813

Merged
merged 4 commits into from
Jul 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,19 @@ package internal

import (
"context"
"encoding/json"
"fmt"
"strings"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"

"github.com/determined-ai/determined/master/pkg/check"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/searcher"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/experimentv1"
)

func (a *apiServer) GetExperiments(
Expand Down Expand Up @@ -43,3 +53,55 @@ func (a *apiServer) GetExperiments(
a.sort(resp.Experiments, req.OrderBy, req.SortBy, apiv1.GetExperimentsRequest_SORT_BY_ID)
return resp, a.paginate(&resp.Pagination, &resp.Experiments, req.Offset, req.Limit)
}

func (a *apiServer) PreviewHPSearch(
_ context.Context, req *apiv1.PreviewHPSearchRequest) (*apiv1.PreviewHPSearchResponse, error) {
bytes, err := protojson.Marshal(req.Config)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "error parsing experiment config: %s", err)
}
config := model.DefaultExperimentConfig()
if err = json.Unmarshal(bytes, &config); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "error parsing experiment config: %s", err)
}
if err = check.Validate(config.Searcher); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid experiment config: %s", err)
}

sm := searcher.NewSearchMethod(config.Searcher, config.BatchesPerStep)
s := searcher.NewSearcher(req.Seed, sm, config.Hyperparameters)
sim, err := searcher.Simulate(s, nil, searcher.RandomValidation, true, config.Searcher.Metric)
if err != nil {
return nil, err
}
protoSim := &experimentv1.ExperimentSimulation{Seed: req.Seed}
indexes := make(map[string]int)
toProto := func(k searcher.Kind) experimentv1.WorkloadKind {
switch k {
case searcher.RunStep:
return experimentv1.WorkloadKind_WORKLOAD_KIND_RUN_STEP
case searcher.ComputeValidationMetrics:
return experimentv1.WorkloadKind_WORKLOAD_KIND_COMPUTE_VALIDATION_METRICS
case searcher.CheckpointModel:
return experimentv1.WorkloadKind_WORKLOAD_KIND_CHECKPOINT_MODEL
default:
return experimentv1.WorkloadKind_WORKLOAD_KIND_UNSPECIFIED
}
}
for _, result := range sim.Results {
var workloads []experimentv1.WorkloadKind
for _, msg := range result {
w := toProto(msg.Workload.Kind)
workloads = append(workloads, w)
}
hash := fmt.Sprint(workloads)
if i, ok := indexes[hash]; ok {
protoSim.Trials[i].Occurrences++
} else {
protoSim.Trials = append(protoSim.Trials,
&experimentv1.TrialSimulation{Workloads: workloads, Occurrences: 1})
indexes[hash] = len(protoSim.Trials) - 1
}
}
return &apiv1.PreviewHPSearchResponse{Simulation: protoSim}, nil
}
2 changes: 2 additions & 0 deletions proto/src/determined/api/v1/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ service Determined {

// Get a list of experiments.
rpc GetExperiments(GetExperimentsRequest) returns (GetExperimentsResponse) { option (google.api.http) = {get: "/api/v1/experiments"}; }
// Preview hyperparameter search.
rpc PreviewHPSearch(PreviewHPSearchRequest) returns (PreviewHPSearchResponse) { option (google.api.http) = {post: "/api/v1/preview-hp-search" body: "*"}; }

// Stream Trial logs.
rpc TrialLogs(TrialLogsRequest) returns (stream TrialLogsResponse) { option (google.api.http) = {get: "/api/v1/trials/{trial_id}/logs"}; }
Expand Down
15 changes: 15 additions & 0 deletions proto/src/determined/api/v1/experiment.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ package determined.api.v1;
option go_package = "github.com/determined-ai/determined/proto/pkg/apiv1";

import "google/protobuf/wrappers.proto";
import "google/protobuf/struct.proto";

import "determined/api/v1/pagination.proto";
import "determined/experiment/v1/experiment.proto";
import "determined/experiment/v1/searcher.proto";

// Get a list of experiments.
message GetExperimentsRequest {
Expand Down Expand Up @@ -57,3 +59,16 @@ message GetExperimentsResponse {
// Pagination information of the full dataset.
Pagination pagination = 2;
}

// Preview hyperparameter search.
message PreviewHPSearchRequest {
// The experiment config to simulate.
google.protobuf.Struct config = 1;
// The searcher simulation seed.
uint32 seed = 2;
}
// Response to PreviewSearchRequest.
message PreviewHPSearchResponse {
// The resulting simulation.
determined.experiment.v1.ExperimentSimulation simulation = 1;
}
36 changes: 36 additions & 0 deletions proto/src/determined/experiment/v1/searcher.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
syntax = "proto3";

package determined.experiment.v1;
option go_package = "github.com/determined-ai/determined/proto/pkg/experimentv1";

import "google/protobuf/struct.proto";

// WorkloadKind defines the kind of workload that should be executed by trial runners.
enum WorkloadKind {
// Denotes an unknown workload kind.
WORKLOAD_KIND_UNSPECIFIED = 0;
// Signals to a trial runner that it should run a training step.
WORKLOAD_KIND_RUN_STEP = 1;
// Signals to a trial runner it should compute validation metrics.
WORKLOAD_KIND_COMPUTE_VALIDATION_METRICS = 2;
// Signals to the trial runner that the current model state should be checkpointed.
WORKLOAD_KIND_CHECKPOINT_MODEL = 3;
}

// TrialSimulation is a specific sequence of workloads that were run before the trial was completed.
message TrialSimulation {
// The list of workloads that were run before the trial was completed.
repeated WorkloadKind workloads = 1;
// The number of times that this trial configuration has occurred during the simulation.
int32 occurrences = 2;
}

// ExperimentSimulation holds the configuration and results of simulated run of a searcher.
message ExperimentSimulation {
// The simulated experiment config.
google.protobuf.Struct config = 1;
// The searcher simulation seed.
uint32 seed = 2;
// The list of trials in the simulation.
repeated TrialSimulation trials = 3;
}