Skip to content

Commit

Permalink
docs: add documentation for EsatimatorTrial callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
aaron276h committed May 27, 2020
1 parent 454957b commit 0978df0
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions docs/reference/api/estimator.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,40 @@ API or Native API.
:members: cache_train_dataset, cache_validation_dataset
:member-order: bysource


Callbacks
~~~~~~~~~

To execute arbitrary Python functionality during the lifecycle of a
``EstimatorTrial``, ``determined.estimator.RunHook`` extends
`tf.estimator.SessionRunHook <https://www.tensorflow.org/api_docs/python/tf/estimator/SessionRunHook/>`_.

.. autoclass:: determined.estimator.RunHook
:members: on_checkpoint_load, on_checkpoint_end


Example usage of ``determined.estimator.RunHook`` which adds custom metadata checkpoints:

.. code:: python

class MyHook(determined.estimator.RunHook):
def __init__(self, context, metadata) -> None:
self._context = context
self._metadata = metadata

def on_checkpoint_end(self, checkpoint_path) -> None:
with open(str(checkpoint_dir.joinpath("metadata.txt")), "w") as fp:
fp.write(self._metadata)

class MyEstimatorTrial(determined.estimator.EstimatorTrial) -> None:
...

def build_train_spec(self) -> tf.estimator.TrainSpec:
return tf.estimator.TrainSpec(
make_input_fn(),
hooks=[MyHook(self.context, "my_metadata")],
)

Examples
--------

Expand Down

0 comments on commit 0978df0

Please sign in to comment.