docs: checkpoint metadata [DET-3211] (#671)
Documentation for the arbitrary user checkpoint metadata feature.
sidneyw authored Jun 11, 2020
1 parent 3b12c87 commit 39d7f18
Showing 2 changed files with 82 additions and 17 deletions.
50 changes: 35 additions & 15 deletions common/determined_common/experimental/checkpoint/_checkpoint.py
@@ -42,24 +42,27 @@ def __init__(
):
"""
Arguments:
uuid (string): UUID of the checkpoint.
experiment_config: The experiment configuration related to the checkpoint.
experiment_id: Trial ID for the trial related to the checkpoint.
trial_id: Trial ID for the trial related to the checkpoint.
hparams: Hyperparameter values fro the trial related to the checkpoint.
batch_number: Batch number of the checkpoint.
start_time: Timestamp of when the checkpoint began being saved to
uuid (string): UUID of this checkpoint.
experiment_config (dict): The configuration of the experiment that
created this checkpoint.
experiment_id (int): The ID of the experiment that created this checkpoint.
trial_id (int): The ID of the trial that created this checkpoint.
hparams (dict): Hyperparameter values for the trial that created
this checkpoint.
batch_number (int): Batch number of the checkpoint.
start_time (string): Timestamp of when the checkpoint began being saved to
persistent storage.
end_time: Timestamp of when the checkpoint completed being saved to
end_time (string): Timestamp of when the checkpoint completed being saved to
persistent storage.
resources: Dictionary of file paths to file sizes in bytes of all
resources (dict): Dictionary mapping file paths to file sizes in bytes for all
files related to the checkpoint.
validation: Dictionary of validation metric names to their values.
framework: The framework of the trial ie. tensorflow, torch.
format: The format of the checkpoint ie h5, saved_model, pickle.
determined_version: the version of Determined the checkpoint was taken with.
metadata: User defined metadata associated with the checkpoint.
master: The address of the determined master instance.
validation (dict): Dictionary of validation metric names to their values.
framework (string, optional): The framework of the trial, e.g., tensorflow, torch.
format (string, optional): The format of the checkpoint, e.g., h5, saved_model, pickle.
determined_version (string, optional): The version of Determined the
checkpoint was taken with.
metadata (dict, optional): User-defined metadata associated with the checkpoint.
master (string, optional): The address of the Determined master instance.
"""

self.uuid = uuid
@@ -185,6 +188,15 @@ def load(
return Checkpoint.load_from_path(ckpt_path, tags=tags)

def add_metadata(self, metadata: Dict[str, Any]) -> None:
"""
Adds metadata to the checkpoint. The argument must be a JSON-serializable
dictionary. If a top-level key in the metadata argument already exists in
the checkpoint metadata, the entire tree beneath that key is replaced by
the passed metadata value.
Arguments:
metadata (dict): Dictionary of metadata to add to the checkpoint.
"""
if self._master:
r = api.post(
self._master,
@@ -194,6 +206,14 @@ def add_metadata(self, metadata: Dict[str, Any]) -> None:
self.metadata = r.json()

def remove_metadata(self, keys: List[str]) -> None:
"""
Removes the top-level keys of the checkpoint metadata that correspond to the
keys passed as arguments. If a provided key does not exist in the checkpoint
metadata, this method is a no-op.
Arguments:
keys (List[string]): Top level keys of the checkpoint metadata to remove.
"""
if self._master:
r = api.delete(
self._master, "checkpoints/{}/metadata".format(self.uuid), params={"keys": keys}
49 changes: 47 additions & 2 deletions docs/how-to/use-trained-models.txt
@@ -10,6 +10,7 @@ This guide discusses:

#. Querying trained model checkpoints from trials and experiments.
#. Loading models into memory in a Python process.
#. Storing metadata associated with checkpoints.
#. Using the Determined CLI to download checkpoints to disk.

Querying Checkpoints
@@ -140,8 +141,9 @@ memory in a Python process, as shown in the following snippet.
checkpoint = Determined.get_experiment(id).top_checkpoint()
model = checkpoint.load()

TensorFlow checkpoints are saved in the ``saved_model`` format and are loaded
as trackable objects (see documentation for `tf.compat.v1.saved_model.load_v2
TensorFlow checkpoints are saved in either the ``saved_model`` or ``h5``
formats and are loaded as trackable objects (see documentation for
`tf.compat.v1.saved_model.load_v2
<https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/saved_model/load_v2>`__
for details).
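
As a rough sketch of what the loaded trackable object provides for a
``saved_model`` checkpoint, you can inspect its exported signatures. The
``serving_default`` signature name below is an assumption and depends on how
the model was exported.

.. code:: python

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint()
model = checkpoint.load()

# List the serving signatures exported with the SavedModel; the
# "serving_default" name is only an assumption and may differ.
print(list(model.signatures.keys()))
predict_fn = model.signatures.get("serving_default")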

@@ -151,6 +153,49 @@ PyTorch checkpoints are saved using `pickle
<https://pytorch.org/docs/stable/notes/serialization.html>`__ for
details).
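
The following is a minimal sketch of running inference with a loaded PyTorch
checkpoint, assuming the trial's model is a standard ``torch.nn.Module``; the
input shape is a placeholder and should match your model.

.. code:: python

import torch

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint()
model = checkpoint.load()

# Put the loaded module into evaluation mode and run a forward pass.
model.eval()
with torch.no_grad():
    # Placeholder input shape -- substitute your model's expected input.
    output = model(torch.zeros(1, 3, 224, 224))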

Storing Checkpoint Metadata
---------------------------
You may store metadata related to a checkpoint via the Python API. This feature
is useful for storing post-training metrics, labels, information related to
deployment, etc.

.. code:: python

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint()
checkpoint.add_metadata({"environment": "production"})

# Metadata will be stored in Determined and accessible on the checkpoint object.
print(checkpoint.metadata)
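
Because the metadata is persisted by the Determined master, it should also be
available when the checkpoint is fetched again later, e.g. from a different
process (a sketch, assuming the same experiment is queried):

.. code:: python

from determined.experimental import Determined

# Re-fetch the checkpoint and read back the stored metadata.
checkpoint = Determined().get_experiment(id).top_checkpoint()
print(checkpoint.metadata)  # expected to include {"environment": "production"}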

You may store an arbitrarily nested dictionary using the :meth:`~determined.experimental.Checkpoint.add_metadata`
method. If a top-level key already exists, the entire tree beneath it will
be overwritten.

.. code:: python

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint()
checkpoint.add_metadata({"metrics": {"loss": 0.12}})
checkpoint.add_metadata({"metrics": {"acc": 0.92}})

print(checkpoint.metadata) # will output: {"metrics": {"acc": 0.92}}
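
If you need to extend a nested value rather than replace it, one possible
workaround (a sketch, not a dedicated API) is to merge the update into the
current ``checkpoint.metadata`` on the client before calling
:meth:`~determined.experimental.Checkpoint.add_metadata` again:

.. code:: python

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint()
checkpoint.add_metadata({"metrics": {"loss": 0.12}})

# Merge the new value into the existing top-level tree before re-adding it,
# so the earlier "loss" entry is preserved.
metrics = dict(checkpoint.metadata.get("metrics", {}))
metrics["acc"] = 0.92
checkpoint.add_metadata({"metrics": metrics})

print(checkpoint.metadata)  # will output: {"metrics": {"loss": 0.12, "acc": 0.92}}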

You may remove metadata via the
:meth:`~determined.experimental.Checkpoint.remove_metadata` method. The method
accepts a list of top-level keys. The entire tree beneath each passed key will
be deleted.

.. code:: python

from determined.experimental import Determined

checkpoint = Determined().get_experiment(id).top_checkpoint()
checkpoint.remove_metadata(["metrics"])


Download Checkpoints via the CLI
--------------------------------
Determined offers the following CLI commands for downloading checkpoints locally:
