Skip to content

Commit

Permalink
feat: clean up swagger spec (determined-ai#823)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yoni Ben-tzur authored Jul 4, 2020
1 parent afc6e3f commit a9d7007
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 55 deletions.
10 changes: 2 additions & 8 deletions common/determined_common/experimental/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ def add_metadata(self, metadata: Dict[str, Any]) -> None:
api.patch(
self._master,
"/api/v1/models/{}".format(self.name),
body={
"model": {"metadata": self.metadata},
"update_mask": {"paths": ["model.metadata"]},
},
body={"model": {"metadata": self.metadata}},
)

def remove_metadata(self, keys: List[str]) -> None:
Expand All @@ -78,10 +75,7 @@ def remove_metadata(self, keys: List[str]) -> None:
api.patch(
self._master,
"/api/v1/models/{}".format(self.name),
body={
"model": {"metadata": self.metadata},
"update_mask": {"paths": ["model.metadata"]},
},
body={"model": {"metadata": self.metadata}},
)

def to_json(self) -> Dict[str, Any]:
Expand Down
8 changes: 4 additions & 4 deletions master/internal/api_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ func (a *apiServer) PatchModel(

paths := req.UpdateMask.GetPaths()
for _, path := range paths {
switch path {
case "model.description":
switch {
case path == "model.description":
m.Description = req.Model.Description
case "model.metadata":
case strings.HasPrefix(path, "model.metadata"):
m.Metadata = req.Model.Metadata
default:
case !strings.HasPrefix(path, "update_mask"):
return nil, status.Errorf(
codes.InvalidArgument,
"only description and metadata fields are mutable. cannot update %s", path)
Expand Down
1 change: 1 addition & 0 deletions proto/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ build: $(patsubst %.proto, %.pb.go, $(shell find src/determined -type f -name '*
rm -rf build/proto
mkdir -p build/swagger
protoc -I src src/determined/api/v1/api.proto --swagger_out=logtostderr=true:build/swagger
python3 scripts/swagger.py build/swagger

.PHONY: check
check:
Expand Down
75 changes: 75 additions & 0 deletions proto/scripts/swagger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import json
import os
import sys


def capitalize(s: str) -> str:
if len(s) <= 1:
return s.title()
return s[0].upper() + s[1:]


def clean(fn: str) -> None:
with open(fn, "r") as fp:
spec = json.load(fp)

# Add tag descriptions.
spec["tags"] = [
{
"name": "Authentication",
"description": "Login and logout of the cluster",
},
{
"name": "Users",
"description": "Manage users",
},
{
"name": "Cluster",
"description": "Manage cluster components",
},
{
"name": "Experiments",
"description": "Manage experiments",
},
{
"name": "Templates",
"description": "Manage templates",
},
{
"name": "Models",
"description": "Manage models",
}
]

# Update path names to be consistent.
paths = {}
for key, value in spec["paths"].items():
paths[key.replace(".", "_")] = value
spec["paths"] = paths

del spec["definitions"]["protobufFieldMask"]
for key, value in spec["definitions"].items():
# Remove definitions that should be hidden from the user.
if key == "protobufAny":
value["title"] = "Object"
elif key == "protobufNullValue":
value["title"] = "NullValue"

# Clean up titles.
if "title" not in value:
value["title"] = "".join(capitalize(k) for k in key.split(sep="v1"))

with open(fn, "w") as fp:
json.dump(spec, fp)


def main() -> None:
files = []
for r, d, f in os.walk(sys.argv[1]):
for file in f:
if file.endswith(".json"):
clean(os.path.join(r, file))


if __name__ == '__main__':
main()
4 changes: 2 additions & 2 deletions proto/src/determined/api/v1/agent.proto
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import "determined/api/v1/pagination.proto";

import "determined/agent/v1/agent.proto";

// Get a set of agents from the cluster
// Get a set of agents from the cluster.
message GetAgentsRequest {
// Sorts agents by the given field
// Sorts agents by the given field.
enum SortBy {
// Returns agents in an unsorted list.
SORT_BY_UNSPECIFIED = 0;
Expand Down
159 changes: 126 additions & 33 deletions proto/src/determined/api/v1/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,73 +17,166 @@ import "determined/api/v1/user.proto";

option (grpc.gateway.protoc_gen_swagger.options.openapiv2_swagger) = {
info: {
title: "Determined API"
version: "1.0"
};
title: "Determined API",
description: "Determined helps deep learning teams train models more quickly, easily share GPU resources, and effectively collaborate. Determined allows deep learning engineers to focus on building and training models at scale, without needing to worry about DevOps or writing custom code for common tasks like fault tolerance or experiment tracking.\n\nYou can think of Determined as a platform that bridges the gap between tools like TensorFlow and PyTorch --- which work great for a single researcher with a single GPU --- to the challenges that arise when doing deep learning at scale, as teams, clusters, and data sets all increase in size.",
version: "1.0",
contact: {
name: "Determined AI",
url: "https://determined.ai/",
email: "[email protected]",
}
license: {
name: "Apache 2.0",
url: "http://www.apache.org/licenses/LICENSE-2.0.html"
}
}
schemes: [HTTP, HTTPS],
external_docs: {
description: "Determined AI Documentation",
url: "https://docs.determined.ai/",
}
};

// Determined is the official v1 of the Determined API.
service Determined {
// Login the user.
rpc Login(LoginRequest) returns (LoginResponse) { option (google.api.http) = {post: "/api/v1/auth/login" body: "*"}; }
rpc Login(LoginRequest) returns (LoginResponse) {
option (google.api.http) = {post: "/api/v1/auth/login" body: "*"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Authentication"};
}
// Get the current user.
rpc CurrentUser(CurrentUserRequest) returns (CurrentUserResponse) { option (google.api.http) = {get: "/api/v1/auth/user"}; }
rpc CurrentUser(CurrentUserRequest) returns (CurrentUserResponse) {
option (google.api.http) = {get: "/api/v1/auth/user"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Authentication"};
}
// Logout the user.
rpc Logout(LogoutRequest) returns (LogoutResponse) { option (google.api.http) = {post: "/api/v1/auth/logout"}; }
rpc Logout(LogoutRequest) returns (LogoutResponse) {
option (google.api.http) = {post: "/api/v1/auth/logout"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Authentication"};
}

// Get a list of users
rpc GetUsers(GetUsersRequest) returns (GetUsersResponse) { option (google.api.http) = {get: "/api/v1/users"}; }
rpc GetUsers(GetUsersRequest) returns (GetUsersResponse) {
option (google.api.http) = {get: "/api/v1/users"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Users"};
}
// Get the requested user.
rpc GetUser(GetUserRequest) returns (GetUserResponse) { option (google.api.http) = {get: "/api/v1/users/{username}"}; }
rpc GetUser(GetUserRequest) returns (GetUserResponse) {
option (google.api.http) = {get: "/api/v1/users/{username}"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Users"};
}
// Create a new user.
rpc PostUser(PostUserRequest) returns (PostUserResponse) { option (google.api.http) = {post: "/api/v1/users" body: "*"}; }
rpc PostUser(PostUserRequest) returns (PostUserResponse) {
option (google.api.http) = {post: "/api/v1/users" body: "*"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Users"};
}
// Set the requested user's passwords.
rpc SetUserPassword(SetUserPasswordRequest) returns (SetUserPasswordResponse) { option (google.api.http) = {post: "/api/v1/users/{username}/password" body: "password"}; }
rpc SetUserPassword(SetUserPasswordRequest) returns (SetUserPasswordResponse) {
option (google.api.http) = {post: "/api/v1/users/{username}/password" body: "password"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Users"};
}

// Get master information.
rpc GetMaster(GetMasterRequest) returns (GetMasterResponse) { option (google.api.http) = {get: "/api/v1/master"}; }

rpc GetMaster(GetMasterRequest) returns (GetMasterResponse) {
option (google.api.http) = {get: "/api/v1/master"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Cluster"};
}
// Get a set of agents from the cluster
rpc GetAgents(GetAgentsRequest) returns (GetAgentsResponse) { option (google.api.http) = {get: "/api/v1/agents"}; }
rpc GetAgents(GetAgentsRequest) returns (GetAgentsResponse) {
option (google.api.http) = {get: "/api/v1/agents"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Cluster"};
}
// Get the requested agent.
rpc GetAgent(GetAgentRequest) returns (GetAgentResponse) { option (google.api.http) = {get: "/api/v1/agents/{agent_id}"}; }
rpc GetAgent(GetAgentRequest) returns (GetAgentResponse) {
option (google.api.http) = {get: "/api/v1/agents/{agent_id}"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Cluster"};
}
// Get the set of slots for the agent with the given id.
rpc GetSlots(GetSlotsRequest) returns (GetSlotsResponse) { option (google.api.http) = {get: "/api/v1/agents/{agent_id}/slots"}; }
rpc GetSlots(GetSlotsRequest) returns (GetSlotsResponse) {
option (google.api.http) = {get: "/api/v1/agents/{agent_id}/slots"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Cluster"};
}
// Get the requested slot for the agent with the given id.
rpc GetSlot(GetSlotRequest) returns (GetSlotResponse) { option (google.api.http) = {get: "/api/v1/agents/{agent_id}/slots/{slot_id}"}; }

rpc GetSlot(GetSlotRequest) returns (GetSlotResponse) {
option (google.api.http) = {get: "/api/v1/agents/{agent_id}/slots/{slot_id}"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Cluster"};
}
// Enable the agent.
rpc EnableAgent(EnableAgentRequest) returns (EnableAgentResponse) { option (google.api.http) = {post: "/api/v1/agents/{agent_id}/enable"}; }
rpc EnableAgent(EnableAgentRequest) returns (EnableAgentResponse) {
option (google.api.http) = {post: "/api/v1/agents/{agent_id}/enable"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Cluster"};
}
// Disable the agent.
rpc DisableAgent(DisableAgentRequest) returns (DisableAgentResponse) { option (google.api.http) = {post: "/api/v1/agents/{agent_id}/disable"}; }
rpc DisableAgent(DisableAgentRequest) returns (DisableAgentResponse) {
option (google.api.http) = {post: "/api/v1/agents/{agent_id}/disable"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Cluster"};
}
// Enable the slot.
rpc EnableSlot(EnableSlotRequest) returns (EnableSlotResponse) { option (google.api.http) = {post: "/api/v1/agents/{agent_id}/slots/{slot_id}/enable"}; }
rpc EnableSlot(EnableSlotRequest) returns (EnableSlotResponse) {
option (google.api.http) = {post: "/api/v1/agents/{agent_id}/slots/{slot_id}/enable"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Cluster"};
}
// Disable the slot.
rpc DisableSlot(DisableSlotRequest) returns (DisableSlotResponse) { option (google.api.http) = {post: "/api/v1/agents/{agent_id}/slots/{slot_id}/disable"}; }
rpc DisableSlot(DisableSlotRequest) returns (DisableSlotResponse) {
option (google.api.http) = {post: "/api/v1/agents/{agent_id}/slots/{slot_id}/disable"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Cluster"};
}

// Get a list of experiments.
rpc GetExperiments(GetExperimentsRequest) returns (GetExperimentsResponse) { option (google.api.http) = {get: "/api/v1/experiments"}; }
rpc GetExperiments(GetExperimentsRequest) returns (GetExperimentsResponse) {
option (google.api.http) = {get: "/api/v1/experiments"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Experiments"};
}
// Preview hyperparameter search.
rpc PreviewHPSearch(PreviewHPSearchRequest) returns (PreviewHPSearchResponse) { option (google.api.http) = {post: "/api/v1/preview-hp-search" body: "*"}; }

rpc PreviewHPSearch(PreviewHPSearchRequest) returns (PreviewHPSearchResponse) {
option (google.api.http) = {post: "/api/v1/preview-hp-search" body: "*"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Experiments"};
}
// Stream Trial logs.
rpc TrialLogs(TrialLogsRequest) returns (stream TrialLogsResponse) { option (google.api.http) = {get: "/api/v1/trials/{trial_id}/logs"}; }
rpc TrialLogs(TrialLogsRequest) returns (stream TrialLogsResponse) {
option (google.api.http) = {get: "/api/v1/trials/{trial_id}/logs"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Experiments"};
}

// Get a list of templates.
rpc GetTemplates(GetTemplatesRequest) returns (GetTemplatesResponse) { option (google.api.http) = {get: "/api/v1/templates"}; }
rpc GetTemplates(GetTemplatesRequest) returns (GetTemplatesResponse) {
option (google.api.http) = {get: "/api/v1/templates"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Templates"};
}
// Get the requested template.
rpc GetTemplate(GetTemplateRequest) returns (GetTemplateResponse) { option (google.api.http) = {get: "/api/v1/templates/{template_name}"}; }
rpc GetTemplate(GetTemplateRequest) returns (GetTemplateResponse) {
option (google.api.http) = {get: "/api/v1/templates/{template_name}"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Templates"};
}
// Update the requested template. If one does not exist, a new template is created
rpc PutTemplate(PutTemplateRequest) returns (PutTemplateResponse) { option (google.api.http) = {put: "/api/v1/templates/{template.name}" body: "template"}; }
rpc PutTemplate(PutTemplateRequest) returns (PutTemplateResponse) {
option (google.api.http) = {put: "/api/v1/templates/{template.name}" body: "template"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Templates"};
}
// Delete the template with the given id.
rpc DeleteTemplate(DeleteTemplateRequest) returns (DeleteTemplateResponse) { option (google.api.http) = {delete: "/api/v1/templates/{template_name}"}; }
rpc DeleteTemplate(DeleteTemplateRequest) returns (DeleteTemplateResponse) {
option (google.api.http) = {delete: "/api/v1/templates/{template_name}"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Templates"};
}

// Get the requested model.
rpc GetModel(GetModelRequest) returns (GetModelResponse) { option (google.api.http) = {get: "/api/v1/models/{model_name}"}; }
rpc GetModel(GetModelRequest) returns (GetModelResponse) {
option (google.api.http) = {get: "/api/v1/models/{model_name}"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Models"};
}
// Create a model in the registry.
rpc PostModel(PostModelRequest) returns (PostModelResponse) { option (google.api.http) = {post: "/api/v1/models/{model.name}" body: "model"}; }
rpc PostModel(PostModelRequest) returns (PostModelResponse) {
option (google.api.http) = {post: "/api/v1/models/{model.name}" body: "model"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Models"};
}
// Update model fields
rpc PatchModel(PatchModelRequest) returns (PatchModelResponse) { option (google.api.http) = {patch: "/api/v1/models/{model.name}" body: "*"}; }
rpc PatchModel(PatchModelRequest) returns (PatchModelResponse) {
option (google.api.http) = {patch: "/api/v1/models/{model.name}" body: "model"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Models"};
}
// Get the requested model.
rpc GetModels(GetModelsRequest) returns (GetModelsResponse) { option (google.api.http) = {get: "/api/v1/models"}; }
rpc GetModels(GetModelsRequest) returns (GetModelsResponse) {
option (google.api.http) = {get: "/api/v1/models"};
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_operation) = {tags: "Models"};
}
}
3 changes: 1 addition & 2 deletions proto/src/determined/api/v1/model.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package determined.api.v1;
option go_package = "github.com/determined-ai/determined/proto/pkg/apiv1";

import "google/protobuf/field_mask.proto";
import "google/protobuf/struct.proto";

import "determined/api/v1/pagination.proto";
import "determined/model/v1/model.proto";
Expand Down Expand Up @@ -54,7 +53,7 @@ message GetModelsRequest {
string description = 6;
}

// Response to GetModelsRequest
// Response to GetModelsRequest.
message GetModelsResponse {
// The list of returned models.
repeated determined.model.v1.Model models = 1;
Expand Down
4 changes: 0 additions & 4 deletions proto/src/determined/api/v1/trial.proto
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@ syntax = "proto3";
package determined.api.v1;
option go_package = "github.com/determined-ai/determined/proto/pkg/apiv1";

import "google/protobuf/wrappers.proto";

import "determined/api/v1/pagination.proto";

// Stream Trial logs.
message TrialLogsRequest {
// The id of the trial.
Expand Down
8 changes: 7 additions & 1 deletion proto/src/determined/model/v1/model.proto
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@ option go_package = "github.com/determined-ai/determined/proto/pkg/modelv1";

import "google/protobuf/struct.proto";
import "google/protobuf/timestamp.proto";
import "protoc-gen-swagger/options/annotations.proto";

// Model is a named collection of model versions.
message Model {
option (grpc.gateway.protoc_gen_swagger.options.openapiv2_schema) = {
json_schema: {
required: ["name", "metadata", "creation_time", "last_updated_time"]
}
};
// The name of the model.
string name = 1;
string name = 1 [(grpc.gateway.protoc_gen_swagger.options.openapiv2_field) = {min_length: 1}];
// The description of the model.
string description = 2;
// The user-defined metadata of the model.
Expand Down
2 changes: 1 addition & 1 deletion webui/react/src/services/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import {
import { serverAddress } from 'utils/routes';
import { isExperimentTask } from 'utils/task';

export const sApi = new DetSwagger.DeterminedApi(undefined, serverAddress());
export const sApi = new DetSwagger.AuthenticationApi(undefined, serverAddress());

/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const isAuthFailure = (e: any): boolean => {
Expand Down

0 comments on commit a9d7007

Please sign in to comment.