feat: add custom reducers to estimators [DET-3098] (#923)
* feat: support custom reducers for estimators

Custom metrics for estimators are supported via the
context.experimental.make_metric() method, which accepts either a single
reducer function or an instance of the det.estimator.MetricReducer class
for hierarchical metric reduction. The resulting custom metric is returned
as a (value_op, update_op) tuple in accordance with the tf.metrics API,
where the value_op allgathers results from all workers during
distributed evaluation.
rb-determined-ai authored Jul 22, 2020
1 parent 27be82a commit b75442a
Showing 11 changed files with 498 additions and 13 deletions.
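
As context for the diff below, here is a condensed sketch of how the new API is meant to be used inside an Estimator ``model_fn``. It is adapted from the docstring example this commit adds; ``context``, ``predictions``, and ``loss`` stand in for whatever the trial already defines, so this is an illustration rather than a complete trial:

```python
import numpy as np
import tensorflow as tf


def my_mean_reducer(all_batch_metrics):
    # Reduce the per-batch metric values gathered from every worker.
    return np.mean(np.hstack(all_batch_metrics))


def model_fn(features, labels, mode):
    ...
    if mode == tf.estimator.ModeKeys.EVAL:
        # make_metric returns a (value_op, update_op) pair, like the tf.metrics API.
        my_avg_prediction = context.experimental.make_metric(
            metric=predictions, reducer=my_mean_reducer, numpy_dtype=np.float32
        )
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops={"my_avg_prediction": my_avg_prediction}
        )
```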
19 changes: 18 additions & 1 deletion docs/reference/api/estimator.txt
@@ -44,7 +44,23 @@ API or Native API.
accessible via ``context.experimental`` for information related to experimental features.

.. autoclass:: determined.estimator.EstimatorExperimentalContext
:members: cache_train_dataset, cache_validation_dataset
:members: cache_train_dataset, cache_validation_dataset, make_metric
:member-order: bysource


Reducing Metrics
~~~~~~~~~~~~~~~~

Determined supports proper reduction of arbitrary validation metrics during
distributed training by allowing users to define custom reducers for their
metrics. Custom reducers can be either a function or an implementation of the
:class:`determined.estimator.MetricReducer` interface.

See :func:`determined.estimator.EstimatorExperimentalContext.make_metric()` for
more details.

.. autoclass:: determined.estimator.MetricReducer
:members: accumulate, cross_slot_reduce
:member-order: bysource
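
As an aside (not part of the documentation diff), a minimal reducer implementing this interface might look like the following. It is modeled on the ``SumReducer`` test fixture added elsewhere in this commit; the running mean shown here is an illustrative assumption, not an API the commit ships:

```python
import numpy as np

from determined import estimator


class MeanReducer(estimator.MetricReducer):
    """Accumulate a sum and a count on each slot, then combine them across slots."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def accumulate(self, metric):
        # Called once per evaluated batch on each slot with the value of the metric tensor.
        self.sum += float(np.sum(metric))
        self.count += int(np.size(metric))
        return self.sum, self.count

    def cross_slot_reduce(self, per_slot_metrics):
        # Called once with the final accumulate() result from every slot.
        total = sum(s for s, _ in per_slot_metrics)
        count = sum(c for _, c in per_slot_metrics)
        return total / max(count, 1)
```

An instance would be passed as ``reducer=MeanReducer()`` in a ``make_metric`` call, just like the ``SumReducer()`` instance in the test fixture further down.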


@@ -83,6 +99,7 @@ Example usage of ``determined.estimator.RunHook`` which adds custom metadata che
hooks=[MyHook(self.context, "my_metadata")],
)


Examples
--------

21 changes: 21 additions & 0 deletions e2e_tests/tests/experiment/test_tf_estimator.py
@@ -150,6 +150,27 @@ def test_mnist_estimator_data_layer_lfs(tf2: bool) -> None:
run_mnist_estimator_data_layer_test(tf2, "lfs")


@pytest.mark.parallel # type: ignore
@pytest.mark.parametrize("tf2", [True, False]) # type: ignore
def test_custom_reducer_distributed(secrets: Dict[str, str], tf2: bool) -> None:
config = conf.load_config(conf.fixtures_path("estimator_dataset/distributed.yaml"))
# Run with multiple steps to verify we are resetting reducers right.
config = conf.set_max_steps(config, 2)
config = conf.set_slots_per_trial(config, 8)
config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config)

experiment_id = exp.run_basic_test_with_temp_config(
config, conf.fixtures_path("estimator_dataset"), 1
)

trial = exp.experiment_trials(experiment_id)[0]
last_validation = trial["steps"][len(trial["steps"]) - 1]["validation"]
metrics = last_validation["metrics"]["validation_metrics"]
label_sum = 2 * sum(range(16))
assert metrics["label_sum_fn"] == label_sum
assert metrics["label_sum_cls"] == label_sum


@pytest.mark.e2e_gpu # type: ignore
@pytest.mark.parametrize("tf2", [True, False]) # type: ignore
@pytest.mark.parametrize("storage_type", ["s3"]) # type: ignore
4 changes: 3 additions & 1 deletion e2e_tests/tests/fixtures/estimator_dataset/const.yaml
@@ -4,11 +4,13 @@ hyperparameters:
dataset_size: 100
print: true
validation_size: 4
lr: 1
searcher:
name: single
metric: loss
smaller_is_better: true
max_steps: 1
max_restarts: 0
batches_per_step: 1
entrypoint: model:EstimatorDebugTrial
entrypoint: model:EstimatorDatasetTrial
min_validation_period: 1
18 changes: 18 additions & 0 deletions e2e_tests/tests/fixtures/estimator_dataset/distributed.yaml
@@ -0,0 +1,18 @@
description: dataset-experiment-distributed
hyperparameters:
global_batch_size: 8
dataset_size: 100
print: true
validation_size: 16
lr: 0.001
searcher:
name: single
metric: loss
smaller_is_better: true
max_steps: 1
max_restarts: 0
batches_per_step: 1
entrypoint: model:EstimatorDatasetTrial
min_validation_period: 1
resources:
slots_per_trial: 8
46 changes: 40 additions & 6 deletions e2e_tests/tests/fixtures/estimator_dataset/model.py
@@ -31,14 +31,35 @@
from the analytical calculations. Replace this model with a more robust one.
"""
from typing import Any, List

import numpy as np
import tensorflow as tf

from determined.estimator import EstimatorTrial, EstimatorTrialContext
from determined import estimator


def sum_reducer(batch_metrics: List):
"""A function that is able to operate as a custom reducer."""
return np.hstack(batch_metrics).sum()


class SumReducer(estimator.MetricReducer):
"""A class that is able to operate as a custom reducer."""

def __init__(self):
self.sum = 0

def accumulate(self, metric: Any):
self.sum += metric.sum()
return self.sum

-class EstimatorDebugTrial(EstimatorTrial):
-    def __init__(self, context: EstimatorTrialContext):
def cross_slot_reduce(self, per_slot_metrics: List):
return sum(per_slot_metrics)


class EstimatorDatasetTrial(estimator.EstimatorTrial):
def __init__(self, context: estimator.EstimatorTrialContext):
self.context = context
self.hparams = context.get_hparams()

@@ -65,15 +86,28 @@ def model_fn(self, features, labels, mode):
with tf.control_dependencies([print_input, print_output, print_loss]):
loss = tf.identity(loss)

opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1)
opt = self.context.wrap_optimizer(
tf.compat.v1.train.GradientDescentOptimizer(learning_rate=self.hparams["lr"])
)
train_op = opt.minimize(loss=loss, global_step=tf.compat.v1.train.get_global_step())

eval_metrics_ops = None
if mode == tf.estimator.ModeKeys.EVAL:
# Use the custom metrics API.
fn_sum = self.context.experimental.make_metric(labels, sum_reducer, np.float32)
cls_sum = self.context.experimental.make_metric(labels, SumReducer(), np.float32)

eval_metrics_ops = {"label_sum_fn": fn_sum, "label_sum_cls": cls_sum}

return tf.estimator.EstimatorSpec(
mode=mode, loss=loss, train_op=train_op, predictions={"output": output, "prod": prod}
mode=mode,
loss=loss,
train_op=train_op,
predictions={"output": output, "prod": prod},
eval_metric_ops=eval_metrics_ops,
)

def build_estimator(self):
_ = self.context.wrap_optimizer(None)
return tf.estimator.Estimator(
model_fn=self.model_fn,
config=tf.estimator.RunConfig(
1 change: 1 addition & 0 deletions harness/determined/estimator/__init__.py
@@ -6,6 +6,7 @@
EstimatorTrialContext,
ServingInputReceiverFn,
)
from determined.estimator._reducer import MetricReducer, _SimpleMetricReducer, _distributed_metric
from determined.estimator._util import (
_cleanup_after_train_step,
_cleanup_after_validation_step,
95 changes: 92 additions & 3 deletions harness/determined/estimator/_estimator_context.py
@@ -1,10 +1,10 @@
import logging
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import tensorflow as tf

import determined as det
from determined import _data_layer, horovod
from determined import _data_layer, estimator, horovod
from determined.horovod import hvd
from determined_common import check

@@ -19,7 +19,6 @@
seamlessly distribute training across multiple workers when distributed training is configured.
"""


# The optional interface for specifying serving input receiver functions to
# export SavedModels expects the following function type.
ServingInputReceiverFn = Callable[
@@ -134,6 +133,9 @@ class EstimatorExperimentalContext(_data_layer.DataLayerContext):
def __init__(self, env: det.EnvContext, hvd_config: horovod.HorovodContext) -> None:
super().__init__(env=env, hvd_config=hvd_config)
self._allgather_fn = None # type: Optional[Callable[[Any], List]]
# allgather ops are not parallelizable, so we have to strictly order how they are placed in the
# graph via tf.control_dependencies().
self._allgather_ops = [] # type: List[tf.Operation]

def _set_allgather_fn(self, fn: Callable[[Any], List]) -> None:
self._allgather_fn = fn
@@ -142,3 +144,90 @@ def allgather_metrics(self, metrics: Any) -> List:
if self._allgather_fn is None:
raise AssertionError("allgather_metrics must not be called before training begins")
return self._allgather_fn(metrics)

def _build_allgather_op(self, build_op_fn: Callable[[], tf.Operation]) -> tf.Operation:
"""Build an op that uses allgather in a way that is safely sequentialized."""

with tf.compat.v1.control_dependencies(self._allgather_ops):
new_op = build_op_fn()
self._allgather_ops.append(new_op)
return new_op

def _reset_allgather_ops(self) -> None:
"""Every Estimator evaluation happens on a clean graph, so forget the old operations."""
self._allgather_ops = []

def make_metric(
self,
metric: Any,
reducer: Union[Callable[[List[Any]], Any], "estimator.MetricReducer"],
numpy_dtype: Any,
) -> Tuple[tf.Operation, tf.Operation]:
"""
Return an estimator-compatible validation metric which will be calculated properly, even
during distributed evaluation.
During distributed evaluation, many types of metrics calculated via ``tf.metrics`` or
``tf.keras.metrics`` cannot be aggregated properly from the per-slot final metrics
calculated by each separate Estimator replica. One example is ``tf.metrics.auc``, where
the ROC AUC calculated over predictions and labels from a full dataset cannot be derived
from the individual ROC AUC metrics evaluated over several shards of a dataset.
Determined solves this problem by offering customizable metrics which are
Estimator-compatible. For example, ROC AUC could be properly calculated during distributed
evaluation by calling ``sklearn.metrics.roc_auc_score`` in a custom ``reducer`` function
passed to ``make_metric``.
The ``metric`` input can be a tensor, a list of tensors, or a dictionary of tensors.
The ``reducer`` should be either a single function that can calculate the metric from a
list of the per-batch values of ``metric``, or it can be an instance of a
:class:`det.estimator.MetricReducer<determined.estimator.MetricReducer>`.
The ``numpy_dtype`` must be a numpy dtype. It is used internally to determined the output
type of the TensorFlow ``py_func`` to report the final metric result to the Estimator API.
The format of ``numpy_dtype`` should be anything that ``np.dtype()`` accepts.
The primary motivation for passing a function as the reducer is simplicity. Metrics from
all batches will be buffered in memory and passed over the network where they will be
reduced all at once. This introduces some overhead, but it is likely unnoticeable for
scalar metrics or on validation datasets of small or medium size. This single function
strategy may also be desirable for quick prototyping or for calculating metrics that are
difficult or impossible to calculate incrementally.
The primary motivation for passing a ``det.estimator.MetricsReducer`` as the reducer is
performance. ``det.estimator.MetricsReducer`` allows the user to incrementally calculate
the partial metric on each slot, taking advantage of distributed computation, minimizing
memory usage, and minimizing the network communication before the final
``cross_slot_reduce`` operation.
Evaluation performance may be improved by precomputing as much as possible in the graph so
that less computation on the ``metric`` value is required within the reducer.
Example usage where ``reducer`` is a function:
.. code-block:: python
def my_mean_reducer(all_batch_metrics):
# Use hstack in case not all batches are equal length.
return np.mean(np.hstack(all_batch_metrics))
def my_estimator_model_function(features, labels, mode):
...
if mode == tf.estimator.ModeKeys.EVAL:
my_avg_prediction = context.experimental.make_metric(
metric=predictions, reducer=my_mean_reducer, numpy_dtype=np.float32
)
return tf.estimator.EstimatorSpec(
mode,
loss=loss,
eval_metric_ops={"my_avg_prediction": my_avg_prediction},
)
"""
if isinstance(reducer, estimator.MetricReducer):
return estimator._distributed_metric(self, metric, reducer, numpy_dtype)

simple_reducer = estimator._SimpleMetricReducer(reducer)
return estimator._distributed_metric(self, metric, simple_reducer, numpy_dtype)
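
For readers curious how a bare function gets adapted here: the ``_SimpleMetricReducer`` helper lives in ``_reducer.py``, which is not included in this commit view. Based on the docstring above, a plausible adapter (a hypothetical sketch, not the actual source) would buffer every batch value and reduce once at the end:

```python
from typing import Any, Callable, List

from determined import estimator


class SimpleMetricReducer(estimator.MetricReducer):
    """Hypothetical adapter from a reducer function to the MetricReducer interface."""

    def __init__(self, reduce_fn: Callable[[List[Any]], Any]):
        self.reduce_fn = reduce_fn
        self.batch_metrics: List[Any] = []

    def accumulate(self, metric: Any) -> List[Any]:
        # Buffer every per-batch value; nothing is reduced until cross_slot_reduce.
        self.batch_metrics.append(metric)
        return self.batch_metrics

    def cross_slot_reduce(self, per_slot_metrics: List[List[Any]]) -> Any:
        # Flatten the per-slot lists of batch values and reduce them all at once.
        flat = [m for slot in per_slot_metrics for m in slot]
        return self.reduce_fn(flat)
```

This mirrors the trade-off described in the docstring: all batch metrics cross the network, which is simple but heavier than an incremental reducer.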
3 changes: 3 additions & 0 deletions harness/determined/estimator/_estimator_trial.py
@@ -639,6 +639,9 @@ def compute_validation_metrics(self) -> workload.Response:
pathlib.Path(self.estimator._model_dir), self.is_chief
)

# Reset the per-evaluation set of allgather ops in the context.
self.context.experimental._reset_allgather_ops()

if not self.is_chief:
return workload.Skipped()
