Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: model registry register_version CLI command [DET-3481] #881

Merged
merged 1 commit into from
Jul 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 52 additions & 22 deletions cli/determined_cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from determined_common import api
from determined_common.api.authentication import authentication_required
from determined_common.experimental import Determined, Model, ModelOrderBy, ModelSortBy
from determined_common.experimental import Checkpoint, Determined, Model, ModelOrderBy, ModelSortBy

from . import render
from .declarative_argparse import Arg, Cmd
Expand All @@ -24,6 +24,30 @@ def render_model(model: Model) -> None:
render.tabulate_or_csv(headers, [values], False)


def render_model_version(checkpoint: Checkpoint) -> None:
headers = [
"Version #",
"Trial ID",
"Batch #",
"Checkpoint UUID",
"Validation Metrics",
"Metadata",
]

values = [
[
checkpoint.version,
checkpoint.trial_id,
checkpoint.batch_number,
checkpoint.uuid,
json.dumps(checkpoint.validation, indent=2),
json.dumps(checkpoint.metadata, indent=2),
]
]

render.tabulate_or_csv(headers, values, False)


def list_models(args: Namespace) -> None:
models = Determined(args.master, None).get_models(
sort_by=ModelSortBy[args.sort_by.upper()], order_by=ModelOrderBy[args.order_by.upper()]
Expand Down Expand Up @@ -88,35 +112,31 @@ def create(args: Namespace) -> None:

def describe(args: Namespace) -> None:
model = Determined(args.master, None).get_model(args.name)
ckpt = model.get_version()
checkpoint = model.get_version()

if args.json:
print(json.dumps(model.to_json(), indent=2))
else:
render_model(model)
print("\n")
render_model_version(checkpoint)

headers = [
"Version #",
"Trial ID",
"Batch #",
"Checkpoint UUID",
"Validation Metrics",
"Metadata",
]

print("\n")
values = [
[
ckpt.version,
ckpt.trial_id,
ckpt.batch_number,
ckpt.uuid,
json.dumps(ckpt.validation, indent=2),
json.dumps(ckpt.metadata, indent=2),
]
]
def register_version(args: Namespace) -> None:
if args.json:
resp = api.post(
args.master,
"/api/v1/models/{}/versions".format(args.name),
body={"checkpoint_uuid": args.uuid},
)

render.tabulate_or_csv(headers, values, False)
print(json.dumps(resp.json(), indent=2))
else:
model = Determined(args.master, None).get_model(args.name)
checkpoint = model.register_version(args.uuid)
render_model(model)
print("\n")
render_model_version(checkpoint)


args_description = [
Expand Down Expand Up @@ -148,6 +168,16 @@ def describe(args: Namespace) -> None:
],
is_default=True,
),
Cmd(
"register_version",
register_version,
"register a new version of a model",
[
Arg("name", type=str, help="name of the model"),
Arg("uuid", type=str, help="uuid to register as the next version of the model"),
Arg("--json", action="store_true", help="print as JSON"),
],
),
Cmd(
"describe",
describe,
Expand Down
40 changes: 36 additions & 4 deletions common/determined_common/experimental/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import enum
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, Optional

from determined_common import api
from determined_common.experimental.checkpoint import Checkpoint
Expand Down Expand Up @@ -43,6 +43,13 @@ def __init__(
self.metadata = metadata or {}

def get_version(self, version: int = 0) -> Checkpoint:
"""
Retrieve the checkpoint corresponding to the specified version of the
model. If no version is specified the latest model version is returned.

Arguments:
version (int, optional): the model version number requested.
"""
if version == 0:
resp = api.get(
self._master,
Expand All @@ -66,6 +73,14 @@ def get_version(self, version: int = 0) -> Checkpoint:
return Checkpoint.from_json(data["version"]["checkpoint"], self._master)

def get_versions(self, order_by: ModelOrderBy = ModelOrderBy.DESC) -> List[Checkpoint]:
"""
Get a list of checkpoints corresponding to versions of this model. The
models are sorted by version number and are returned in descending
order by default.

Arguments:
order_by (enum): a member of the ModelOrderBy enum.
"""
resp = api.get(
self._master,
"/api/v1/models/{}/versions/".format(self.name),
Expand All @@ -85,14 +100,31 @@ def get_versions(self, order_by: ModelOrderBy = ModelOrderBy.DESC) -> List[Check
for version in data["versions"]
]

def register_version(self, checkpoint: Checkpoint) -> int:
def register_version(self, checkpoint_uuid: str) -> Checkpoint:
"""
Creats a new model version and returns the
:class:`~determined.experimental.Checkpoint` corresponding to the
version.

Arguments:
checkpoint_uuid: the uuid to associate with the new model version.
"""
resp = api.post(
self._master,
"/api/v1/models/{}/versions".format(self.name),
body={"checkpoint_uuid": checkpoint.uuid},
body={"checkpoint_uuid": checkpoint_uuid},
)

return cast(int, resp.json()["version"]["version"])
data = resp.json()

return Checkpoint.from_json(
{
**data["version"]["checkpoint"],
"version": data["version"]["version"],
"model_name": data["version"]["model"]["name"],
},
self._master,
)

def add_metadata(self, metadata: Dict[str, Any]) -> None:
"""
Expand Down
4 changes: 2 additions & 2 deletions e2e_tests/tests/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ def test_model_registry() -> None:
assert mnist.metadata == {"testing": "override"}

checkpoint = d.get_experiment(exp_id).top_checkpoint()
model_version = mnist.register_version(checkpoint)
assert model_version == 1
model_version = mnist.register_version(checkpoint.uuid)
assert model_version.version == 1
assert mnist.get_version().uuid == checkpoint.uuid

d.create_model("transformer", "all you need is attention")
Expand Down