From b75442a90bbbe7be57b22f815864bba7a36fd29e Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 22 Jul 2020 10:13:12 -0600 Subject: [PATCH] feat: add custom reducers to estimators [DET-3098] (#923) * feat: support custom reducers for estimators Custom metrics for estimators are supported via the context.experimental.make_metric() method, which can either accept a single function reducer or a det.estimator.MetricReducer class to do hierarchical metric reduction. The resulting custom metric is returned as a tuple of (value_op, update_op) in accordance with the tf.metrics API, where the value_op will allgather results from all workers during distributed evaluation. --- docs/reference/api/estimator.txt | 19 ++- .../tests/experiment/test_tf_estimator.py | 21 +++ .../fixtures/estimator_dataset/const.yaml | 4 +- .../estimator_dataset/distributed.yaml | 18 ++ .../tests/fixtures/estimator_dataset/model.py | 46 +++++- harness/determined/estimator/__init__.py | 1 + .../estimator/_estimator_context.py | 95 ++++++++++- .../determined/estimator/_estimator_trial.py | 3 + harness/determined/estimator/_reducer.py | 154 ++++++++++++++++++ .../fixtures/estimator_linear_model.py | 111 +++++++++++++ .../tensorflow/test_estimator_trial.py | 39 ++++- 11 files changed, 498 insertions(+), 13 deletions(-) create mode 100644 e2e_tests/tests/fixtures/estimator_dataset/distributed.yaml create mode 100644 harness/determined/estimator/_reducer.py create mode 100644 harness/tests/experiment/fixtures/estimator_linear_model.py diff --git a/docs/reference/api/estimator.txt b/docs/reference/api/estimator.txt index 63945cec4f6..6b79addf88a 100644 --- a/docs/reference/api/estimator.txt +++ b/docs/reference/api/estimator.txt @@ -44,7 +44,23 @@ API or Native API. accessible via ``context.experimental`` for information related to experimental features. .. autoclass:: determined.estimator.EstimatorExperimentalContext - :members: cache_train_dataset, cache_validation_dataset + :members: cache_train_dataset, cache_validation_dataset, make_metric + :member-order: bysource + + +Reducing Metrics +~~~~~~~~~~~~~~~~ + +Determined supports proper reduction of arbitrary validation metrics during +distributed training by allowing users to define custom reducers for their +metrics. Custom reducers can be either a function or an implementation of the +:class:`determined.estimator.MetricReducer` interface. + +See :func:`determined.estimator.EstimatorExperimentalContext.make_metric()` for +more details. + +.. autoclass:: determined.estimator.MetricReducer + :members: accumulate, cross_slot_reduce :member-order: bysource @@ -83,6 +99,7 @@ Example usage of ``determined.estimator.RunHook`` which adds custom metadata che hooks=[MyHook(self.context, "my_metadata")], ) + Examples -------- diff --git a/e2e_tests/tests/experiment/test_tf_estimator.py b/e2e_tests/tests/experiment/test_tf_estimator.py index e2fbf780439..59a8f338862 100644 --- a/e2e_tests/tests/experiment/test_tf_estimator.py +++ b/e2e_tests/tests/experiment/test_tf_estimator.py @@ -150,6 +150,27 @@ def test_mnist_estimator_data_layer_lfs(tf2: bool) -> None: run_mnist_estimator_data_layer_test(tf2, "lfs") +@pytest.mark.parallel # type: ignore +@pytest.mark.parametrize("tf2", [True, False]) # type: ignore +def test_custom_reducer_distributed(secrets: Dict[str, str], tf2: bool) -> None: + config = conf.load_config(conf.fixtures_path("estimator_dataset/distributed.yaml")) + # Run with multiple steps to verify we are resetting reducers right. 
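+    # Each evaluation builds a fresh TF graph, and the harness resets the accumulated
+    # allgather ops between validations via _reset_allgather_ops().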
+ config = conf.set_max_steps(config, 2) + config = conf.set_slots_per_trial(config, 8) + config = conf.set_tf2_image(config) if tf2 else conf.set_tf1_image(config) + + experiment_id = exp.run_basic_test_with_temp_config( + config, conf.fixtures_path("estimator_dataset"), 1 + ) + + trial = exp.experiment_trials(experiment_id)[0] + last_validation = trial["steps"][len(trial["steps"]) - 1]["validation"] + metrics = last_validation["metrics"]["validation_metrics"] + label_sum = 2 * sum(range(16)) + assert metrics["label_sum_fn"] == label_sum + assert metrics["label_sum_cls"] == label_sum + + @pytest.mark.e2e_gpu # type: ignore @pytest.mark.parametrize("tf2", [True, False]) # type: ignore @pytest.mark.parametrize("storage_type", ["s3"]) # type: ignore diff --git a/e2e_tests/tests/fixtures/estimator_dataset/const.yaml b/e2e_tests/tests/fixtures/estimator_dataset/const.yaml index eddb1089913..1b2d19034a6 100644 --- a/e2e_tests/tests/fixtures/estimator_dataset/const.yaml +++ b/e2e_tests/tests/fixtures/estimator_dataset/const.yaml @@ -4,6 +4,7 @@ hyperparameters: dataset_size: 100 print: true validation_size: 4 + lr: 1 searcher: name: single metric: loss @@ -11,4 +12,5 @@ searcher: max_steps: 1 max_restarts: 0 batches_per_step: 1 -entrypoint: model:EstimatorDebugTrial +entrypoint: model:EstimatorDatasetTrial +min_validation_period: 1 diff --git a/e2e_tests/tests/fixtures/estimator_dataset/distributed.yaml b/e2e_tests/tests/fixtures/estimator_dataset/distributed.yaml new file mode 100644 index 00000000000..0b81f723b9e --- /dev/null +++ b/e2e_tests/tests/fixtures/estimator_dataset/distributed.yaml @@ -0,0 +1,18 @@ +description: dataset-experiment-distributed +hyperparameters: + global_batch_size: 8 + dataset_size: 100 + print: true + validation_size: 16 + lr: 0.001 +searcher: + name: single + metric: loss + smaller_is_better: true + max_steps: 1 +max_restarts: 0 +batches_per_step: 1 +entrypoint: model:EstimatorDatasetTrial +min_validation_period: 1 +resources: + slots_per_trial: 8 diff --git a/e2e_tests/tests/fixtures/estimator_dataset/model.py b/e2e_tests/tests/fixtures/estimator_dataset/model.py index e9df407696d..cbdc48a8f51 100644 --- a/e2e_tests/tests/fixtures/estimator_dataset/model.py +++ b/e2e_tests/tests/fixtures/estimator_dataset/model.py @@ -31,14 +31,35 @@ from the analytical calculations. Replace this model with a more robust one. 
""" +from typing import Any, List + import numpy as np import tensorflow as tf -from determined.estimator import EstimatorTrial, EstimatorTrialContext +from determined import estimator + + +def sum_reducer(batch_metrics: List): + """A function that is able to operate as a custom reducer.""" + return np.hstack(batch_metrics).sum() + + +class SumReducer(estimator.MetricReducer): + """A class that is able to operate as a custom reducer.""" + + def __init__(self): + self.sum = 0 + def accumulate(self, metric: Any): + self.sum += metric.sum() + return self.sum -class EstimatorDebugTrial(EstimatorTrial): - def __init__(self, context: EstimatorTrialContext): + def cross_slot_reduce(self, per_slot_metrics: List): + return sum(per_slot_metrics) + + +class EstimatorDatasetTrial(estimator.EstimatorTrial): + def __init__(self, context: estimator.EstimatorTrialContext): self.context = context self.hparams = context.get_hparams() @@ -65,15 +86,28 @@ def model_fn(self, features, labels, mode): with tf.control_dependencies([print_input, print_output, print_loss]): loss = tf.identity(loss) - opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1) + opt = self.context.wrap_optimizer( + tf.compat.v1.train.GradientDescentOptimizer(learning_rate=self.hparams["lr"]) + ) train_op = opt.minimize(loss=loss, global_step=tf.compat.v1.train.get_global_step()) + eval_metrics_ops = None + if mode == tf.estimator.ModeKeys.EVAL: + # Use the custom metrics API. + fn_sum = self.context.experimental.make_metric(labels, sum_reducer, np.float32) + cls_sum = self.context.experimental.make_metric(labels, SumReducer(), np.float32) + + eval_metrics_ops = {"label_sum_fn": fn_sum, "label_sum_cls": cls_sum} + return tf.estimator.EstimatorSpec( - mode=mode, loss=loss, train_op=train_op, predictions={"output": output, "prod": prod} + mode=mode, + loss=loss, + train_op=train_op, + predictions={"output": output, "prod": prod}, + eval_metric_ops=eval_metrics_ops, ) def build_estimator(self): - _ = self.context.wrap_optimizer(None) return tf.estimator.Estimator( model_fn=self.model_fn, config=tf.estimator.RunConfig( diff --git a/harness/determined/estimator/__init__.py b/harness/determined/estimator/__init__.py index bbcecd9aaac..2a59f415bee 100644 --- a/harness/determined/estimator/__init__.py +++ b/harness/determined/estimator/__init__.py @@ -6,6 +6,7 @@ EstimatorTrialContext, ServingInputReceiverFn, ) +from determined.estimator._reducer import MetricReducer, _SimpleMetricReducer, _distributed_metric from determined.estimator._util import ( _cleanup_after_train_step, _cleanup_after_validation_step, diff --git a/harness/determined/estimator/_estimator_context.py b/harness/determined/estimator/_estimator_context.py index a1a887daf64..e5d89877846 100644 --- a/harness/determined/estimator/_estimator_context.py +++ b/harness/determined/estimator/_estimator_context.py @@ -1,10 +1,10 @@ import logging -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import tensorflow as tf import determined as det -from determined import _data_layer, horovod +from determined import _data_layer, estimator, horovod from determined.horovod import hvd from determined_common import check @@ -19,7 +19,6 @@ seamlessly distribute training across multiple workers when distributed training is configured. """ - # The optional interface for specifying serving input receiver functions to # export SavedModels expects the following function type. 
 ServingInputReceiverFn = Callable[
@@ -134,6 +133,9 @@ class EstimatorExperimentalContext(_data_layer.DataLayerContext):
     def __init__(self, env: det.EnvContext, hvd_config: horovod.HorovodContext) -> None:
         super().__init__(env=env, hvd_config=hvd_config)
         self._allgather_fn = None  # type: Optional[Callable[[Any], List]]
+        # allgather calls are not parallelizable, so we have to strictly order how they are
+        # placed in the graph via tf.control_dependencies().
+        self._allgather_ops = []  # type: List[tf.Operation]
 
     def _set_allgather_fn(self, fn: Callable[[Any], List]) -> None:
         self._allgather_fn = fn
@@ -142,3 +144,90 @@ def allgather_metrics(self, metrics: Any) -> List:
         if self._allgather_fn is None:
             raise AssertionError("allgather_metrics must not be called before training begins")
         return self._allgather_fn(metrics)
+
+    def _build_allgather_op(self, build_op_fn: Callable[[], tf.Operation]) -> tf.Operation:
+        """Build an op that uses allgather in a way that is safely sequentialized."""
+
+        with tf.compat.v1.control_dependencies(self._allgather_ops):
+            new_op = build_op_fn()
+        self._allgather_ops.append(new_op)
+        return new_op
+
+    def _reset_allgather_ops(self) -> None:
+        """Every Estimator evaluation happens on a clean graph, so forget the old operations."""
+        self._allgather_ops = []
+
+    def make_metric(
+        self,
+        metric: Any,
+        reducer: Union[Callable[[List[Any]], Any], "estimator.MetricReducer"],
+        numpy_dtype: Any,
+    ) -> Tuple[tf.Operation, tf.Operation]:
+        """
+        Return an estimator-compatible validation metric which will be calculated properly, even
+        during distributed evaluation.
+
+        During distributed evaluation, many types of metrics calculated via ``tf.metrics`` or
+        ``tf.keras.metrics`` cannot be aggregated properly from the per-slot final metrics
+        calculated by each separate Estimator replica. One example is ``tf.metrics.auc``, where
+        the ROC AUC calculated over predictions and labels from a full dataset cannot be derived
+        from the individual ROC AUC metrics evaluated over several shards of a dataset.
+
+        Determined solves this problem by offering customizable metrics which are
+        Estimator-compatible. For example, ROC AUC could be properly calculated during distributed
+        evaluation by calling ``sklearn.metrics.roc_auc_score`` in a custom ``reducer`` function
+        passed to ``make_metric``.
+
+        The ``metric`` input can be a tensor, a list of tensors, or a dictionary of tensors.
+
+        The ``reducer`` should be either a single function that can calculate the metric from a
+        list of the per-batch values of ``metric``, or it can be an instance of a
+        :class:`det.estimator.MetricReducer`.
+
+        The ``numpy_dtype`` must be a numpy dtype. It is used internally to determine the output
+        type of the TensorFlow ``py_func`` that reports the final metric result to the Estimator
+        API. The format of ``numpy_dtype`` should be anything that ``np.dtype()`` accepts.
+
+        The primary motivation for passing a function as the reducer is simplicity. Metrics from
+        all batches will be buffered in memory and passed over the network, where they will be
+        reduced all at once. This introduces some overhead, but it is likely unnoticeable for
+        scalar metrics or on validation datasets of small or medium size. This single-function
+        strategy may also be desirable for quick prototyping or for calculating metrics that are
+        difficult or impossible to calculate incrementally.
+
+        The primary motivation for passing a ``det.estimator.MetricReducer`` as the reducer is
+        performance. ``det.estimator.MetricReducer`` allows the user to incrementally calculate
+        the partial metric on each slot, taking advantage of distributed computation, minimizing
+        memory usage, and minimizing the network communication before the final
+        ``cross_slot_reduce`` operation.
+
+        Evaluation performance may be improved by precomputing as much as possible in the graph so
+        that less computation on the ``metric`` value is required within the reducer.
+
+        Example usage where ``reducer`` is a function:
+
+        .. code-block:: python
+
+            def my_mean_reducer(all_batch_metrics):
+                # Use hstack in case not all batches are equal length.
+                return np.mean(np.hstack(all_batch_metrics))
+
+            def my_estimator_model_function(features, labels, mode):
+                ...
+                if mode == tf.estimator.ModeKeys.EVAL:
+
+                    my_avg_prediction = context.experimental.make_metric(
+                        metric=predictions, reducer=my_mean_reducer, numpy_dtype=np.float32
+                    )
+
+                    return tf.estimator.EstimatorSpec(
+                        mode,
+                        loss=loss,
+                        eval_metric_ops={"my_avg_prediction": my_avg_prediction},
+                    )
+        """
+        if isinstance(reducer, estimator.MetricReducer):
+            return estimator._distributed_metric(self, metric, reducer, numpy_dtype)
+
+        simple_reducer = estimator._SimpleMetricReducer(reducer)
+        return estimator._distributed_metric(self, metric, simple_reducer, numpy_dtype)
diff --git a/harness/determined/estimator/_estimator_trial.py b/harness/determined/estimator/_estimator_trial.py
index 2f8dd59a922..c92fe2c81ed 100644
--- a/harness/determined/estimator/_estimator_trial.py
+++ b/harness/determined/estimator/_estimator_trial.py
@@ -639,6 +639,9 @@ def compute_validation_metrics(self) -> workload.Response:
             pathlib.Path(self.estimator._model_dir), self.is_chief
         )
 
+        # Reset the per-evaluation set of allgather ops in the context.
+        self.context.experimental._reset_allgather_ops()
+
         if not self.is_chief:
             return workload.Skipped()
 
diff --git a/harness/determined/estimator/_reducer.py b/harness/determined/estimator/_reducer.py
new file mode 100644
index 00000000000..b7fb869507a
--- /dev/null
+++ b/harness/determined/estimator/_reducer.py
@@ -0,0 +1,154 @@
+import abc
+from typing import Any, Callable, List, Tuple
+
+import numpy as np
+import tensorflow as tf
+
+from determined import estimator
+
+
+class MetricReducer:
+    """
+    Efficiently aggregating validation metrics across a multi-slot distributed evaluation is done
+    in two steps:
+
+    #. Accumulate metrics from each batch on each slot. In the case of calculating a mean, this
+       might mean keeping a running sum and a count of metrics received.
+
+    #. Reduce metrics from each slot to calculate the final metric. In the case of calculating a
+       mean, this might mean adding up the per-slot sums and dividing the result by the per-slot
+       counts.
+
+    Example implementation and usage:
+
+    .. code:: python
+
+        class MyAvgMetricReducer(estimator.MetricReducer):
+            def __init__(self):
+                self.sum = 0
+                self.counts = 0
+
+            def accumulate(self, metric):
+                self.sum += sum(metric)
+                self.counts += 1
+                return self.sum, self.counts
+
+            def cross_slot_reduce(self, per_slot_metrics):
+                # per_slot_metrics is a list of (sum, counts) tuples
+                # returned by the final self.accumulate() on each slot
+                sums, counts = zip(*per_slot_metrics)
+                return sum(sums) / sum(counts)
+
+        def my_estimator_model_function(features, labels, mode):
+            ...
+            if mode == tf.estimator.ModeKeys.EVAL:
+
+                my_avg_prediction = context.experimental.make_metric(
+                    metric=predictions, reducer=MyAvgMetricReducer(), numpy_dtype=np.float32
+                )
+
+                return tf.estimator.EstimatorSpec(
+                    mode,
+                    loss=loss,
+                    eval_metric_ops={"my_avg_prediction": my_avg_prediction},
+                )
+
+    See also: :func:`determined.estimator.EstimatorExperimentalContext.make_metric`.
+    """
+
+    @abc.abstractmethod
+    def accumulate(self, metric: Any) -> Any:
+        """
+        accumulate is called for each batch in the evaluation dataset. Batches will be distributed
+        across slots, so accumulate will be called many times on each slot.
+
+        accumulate should return the accumulated state. After evaluation is complete, the final
+        return value of accumulate will become an element of the per_slot_metrics argument to
+        cross_slot_reduce.
+
+        In the example of calculating a distributed mean, accumulate might keep a running sum
+        and a count of metrics received:
+
+        .. code:: python
+
+            def accumulate(self, metric):
+                self.sum += metric
+                self.count += 1
+                return self.sum, self.count
+        """
+        pass
+
+    @abc.abstractmethod
+    def cross_slot_reduce(self, per_slot_metrics: List[Any]) -> Any:
+        """
+        cross_slot_reduce is called on the list of results from the final call to accumulate on
+        each slot. per_slot_metrics will be a list of length N, where N is the number of slots in
+        the trial (or 1 in non-distributed training). cross_slot_reduce must return the final
+        metric.
+
+        In the example of calculating a distributed mean, cross_slot_reduce might receive a list
+        of (sum, count) tuples from which it would calculate the overall mean:
+
+        .. code:: python
+
+            def cross_slot_reduce(self, per_slot_metrics):
+                sums, counts = zip(*per_slot_metrics)
+                return np.array(sum(sums) / sum(counts))
+        """
+        pass
+
+
+class _SimpleMetricReducer(MetricReducer):
+    """_SimpleMetricReducer takes a one-step reducer function and converts it to a MetricReducer."""
+
+    def __init__(self, reduce_fn: Callable[[List[Any]], Any]):
+        self.updates = []  # type: List[Any]
+        self.reduce_fn = reduce_fn
+
+    def reset(self) -> None:
+        self.updates = []
+
+    def accumulate(self, metric: Any) -> List[Any]:
+        self.updates.append(metric)
+        return self.updates
+
+    def cross_slot_reduce(self, per_slot_metrics: List[List[Any]]) -> Any:
+        flat_metrics = [item for sublist in per_slot_metrics for item in sublist]
+        return self.reduce_fn(flat_metrics)
+
+
+def _distributed_metric(
+    context: estimator.EstimatorExperimentalContext,
+    metric: Any,
+    reducer: MetricReducer,
+    numpy_dtype: Any,
+) -> Tuple[tf.Operation, tf.Operation]:
+    """
+    _distributed_metric returns a tf.metrics-style tuple of (value_op, update_op). The value_op is
+    apparently read once after all evaluation is completed, which is where we do the allgather and
+    call the user's cross_slot_reduce to calculate the distributed metric.
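+    The update_op runs the user's accumulate() on each batch via a py_func; the value_op is
+    built through the context's _build_allgather_op() so that, when several custom metrics
+    exist, their allgathers are placed in one fixed order on every worker.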
+ """ + if isinstance(numpy_dtype, tf.dtypes.DType): + raise TypeError(f"numpy_dtype parameter must not be a TensorFlow dtype: {numpy_dtype}") + np_dtype = np.dtype(numpy_dtype) + tf_dtype = tf.compat.v1.as_dtype(numpy_dtype) + + last_accumulate = None # type: Any + + def py_update(metric: Any) -> None: + nonlocal last_accumulate + last_accumulate = reducer.accumulate(metric) + + update_op = tf.compat.v1.py_func(py_update, [metric], []) + + def py_value() -> Any: + allgathered = context.allgather_metrics(last_accumulate) + value = reducer.cross_slot_reduce(allgathered) + return np.array(value).astype(np_dtype) + + def build_value_op() -> tf.Operation: + return tf.compat.v1.py_func(py_value, [], tf_dtype) + + value_op = context._build_allgather_op(build_value_op) + + return value_op, update_op diff --git a/harness/tests/experiment/fixtures/estimator_linear_model.py b/harness/tests/experiment/fixtures/estimator_linear_model.py new file mode 100644 index 00000000000..7fa403ab3c3 --- /dev/null +++ b/harness/tests/experiment/fixtures/estimator_linear_model.py @@ -0,0 +1,111 @@ +from typing import Any, List + +import numpy as np +import tensorflow as tf + +from determined import estimator + +TRAINING_LENGTH = 100 +VALIDATION_LENGTH = 10 + + +def validation_label_sum(): + """The custom metrics return a sum of labels of the validation dataset.""" + return sum(range(VALIDATION_LENGTH)) + + +def range_data_loader(batch_size, length): + """Return a dataloader that yields tuples like ({"x": val}, val) for LinearEstimator.""" + data = tf.data.Dataset.range(length).map(lambda x: tf.cast(x, tf.float32)).batch(batch_size) + label = tf.data.Dataset.range(length).map(lambda x: tf.cast(x, tf.float32)).batch(batch_size) + return tf.data.Dataset.zip(({"x": data}, label)) + + +def sum_reducer(batch_metrics: List): + """A function that is able to operate as a custom reducer.""" + return np.hstack(batch_metrics).sum() + + +class SumReducer(estimator.MetricReducer): + """A class that is able to operate as a custom reducer.""" + + def __init__(self): + self.sum = 0 + + def accumulate(self, metric: Any): + self.sum += metric.sum() + return self.sum + + def cross_slot_reduce(self, per_slot_metrics: List): + return sum(per_slot_metrics) + + +class LinearEstimator(estimator.EstimatorTrial): + def __init__(self, context: estimator.EstimatorTrialContext) -> None: + self.context = context + self.hparams = context.get_hparams() + self.batch_size = self.context.get_per_slot_batch_size() + + self.dense = None + + def make_model_fn(self, feature_columns, optimizer): + """Return a one variable linear model. Used by LinearEstimator.""" + + def model_fn(features, labels, mode): + input_layer = tf.compat.v1.feature_column.input_layer(features, feature_columns) + dense = tf.compat.v1.layers.Dense( + units=1, use_bias=False, kernel_initializer=tf.zeros_initializer(), name="my_dense", + ) + output_layer = dense(input_layer) + predictions = tf.squeeze(output_layer, 1) + + if mode == tf.estimator.ModeKeys.PREDICT: + return tf.estimator.EstimatorSpec(mode, predictions=predictions) + + loss = tf.losses.mean_squared_error(labels, predictions) + + if mode == tf.estimator.ModeKeys.EVAL: + # Use the custom metrics API. 
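+                # make_metric() accepts both forms: sum_reducer is a bare function,
+                # while SumReducer() implements the MetricReducer interface; the two
+                # are expected to report identical sums.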
+ fn_sum = self.context.experimental.make_metric(labels, sum_reducer, np.float32) + cls_sum = self.context.experimental.make_metric(labels, SumReducer(), np.float32) + + return tf.estimator.EstimatorSpec( + mode, + loss=loss, + eval_metric_ops={"label_sum_fn": fn_sum, "label_sum_cls": cls_sum}, + ) + + if mode == tf.estimator.ModeKeys.TRAIN: + train_op = optimizer.minimize( + loss, global_step=tf.compat.v1.train.get_global_step() + ) + return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) + + return model_fn + + def build_estimator(self) -> tf.compat.v1.estimator.Estimator: + feature_columns = [tf.feature_column.numeric_column("x", shape=(), dtype=tf.int64)] + optimizer = tf.compat.v1.train.GradientDescentOptimizer( + learning_rate=self.hparams["learning_rate"], + ) + optimizer = self.context.wrap_optimizer(optimizer) + + estimator = tf.compat.v1.estimator.Estimator( + model_fn=self.make_model_fn(feature_columns, optimizer) + ) + + return estimator + + def build_train_spec(self) -> tf.estimator.TrainSpec: + def fn(): + ds = range_data_loader(self.context.get_per_slot_batch_size(), TRAINING_LENGTH) + return self.context.wrap_dataset(ds) + + return tf.estimator.TrainSpec(fn) + + def build_validation_spec(self) -> tf.estimator.EvalSpec: + def fn(): + ds = range_data_loader(self.context.get_per_slot_batch_size(), VALIDATION_LENGTH) + return self.context.wrap_dataset(ds) + + return tf.estimator.EvalSpec(fn) diff --git a/harness/tests/experiment/tensorflow/test_estimator_trial.py b/harness/tests/experiment/tensorflow/test_estimator_trial.py index 510e4f9761a..f2a2830f9f0 100644 --- a/harness/tests/experiment/tensorflow/test_estimator_trial.py +++ b/harness/tests/experiment/tensorflow/test_estimator_trial.py @@ -11,7 +11,7 @@ from determined import workload from determined.exec import harness from tests.experiment import utils # noqa: I100 -from tests.experiment.fixtures import estimator_xor_model +from tests.experiment.fixtures import estimator_linear_model, estimator_xor_model @pytest.fixture( @@ -71,7 +71,6 @@ def _xor_trial_controller( class TestXORTrial: def setup_method(self) -> None: - os.environ["DET_RENDEZVOUS_INFO"] = '{"rank": 0, "addrs": ["localhost"]}' self.hparams = { "hidden_size": 2, "learning_rate": 0.1, @@ -308,6 +307,42 @@ def make_workloads() -> workload.Stream: controller.run() +class TestLinearTrial: + def setup_method(self) -> None: + self.hparams = { + "learning_rate": 0.0001, + "global_batch_size": 4, + } + + def teardown_method(self) -> None: + # Cleanup leftover environment variable state. + for key in harness.ENVIRONMENT_VARIABLE_KEYS: + if key in os.environ: + del os.environ[key] + + def test_custom_reducer(self) -> None: + def make_workloads() -> workload.Stream: + trainer = utils.TrainAndValidate() + + # Test >1 validation to ensure that resetting the allgather_op list is working. 
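+            # compute_validation_metrics() calls _reset_allgather_ops() after every
+            # evaluation, so the second validation must rebuild its ops on a fresh graph.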
+ yield from trainer.send(steps=2, validation_freq=1, batches_per_step=1) + training_metrics, validation_metrics = trainer.result() + + for metrics in validation_metrics: + assert metrics["label_sum_fn"] == estimator_linear_model.validation_label_sum() + assert metrics["label_sum_cls"] == estimator_linear_model.validation_label_sum() + + yield workload.terminate_workload(), [], workload.ignore_workload_response + + controller = utils.make_trial_controller_from_trial_implementation( + trial_class=estimator_linear_model.LinearEstimator, + hparams=self.hparams, + workloads=make_workloads(), + trial_seed=0, + ) + controller.run() + + def test_local_mode() -> None: utils.run_local_test_mode(utils.fixtures_path("estimator_xor_model_native.py"))
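P.S. The make_metric() docstring above motivates the API with ROC AUC but never shows that
case end-to-end. Below is a minimal sketch, not part of the diff, assuming that a
`metric=[labels, predictions]` pair arrives at accumulate() as a single array that unpacks
along its first axis, and that sklearn is available in the image; the name RocAucReducer is
illustrative and does not appear anywhere in this patch.

    import numpy as np
    from sklearn.metrics import roc_auc_score

    from determined import estimator

    class RocAucReducer(estimator.MetricReducer):
        """Collect (labels, predictions) pairs per slot; compute one global AUC at the end."""

        def __init__(self):
            self.labels = []
            self.predictions = []

        def accumulate(self, metric):
            # Assumption: `metric` unpacks into per-batch labels and predictions.
            labels, predictions = metric
            self.labels.append(labels)
            self.predictions.append(predictions)
            return self.labels, self.predictions

        def cross_slot_reduce(self, per_slot_metrics):
            # One (labels, predictions) pair of per-batch lists arrives from each slot.
            labels = np.concatenate([b for lbls, _ in per_slot_metrics for b in lbls])
            preds = np.concatenate([b for _, prds in per_slot_metrics for b in prds])
            return roc_auc_score(labels, preds)

It would be wired into a model_fn like the EVAL branches added above:

    if mode == tf.estimator.ModeKeys.EVAL:
        roc_auc = context.experimental.make_metric(
            metric=[labels, predictions], reducer=RocAucReducer(), numpy_dtype=np.float32
        )
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops={"roc_auc": roc_auc}
        )

Note that ROC AUC cannot be reduced incrementally, so this reducer still buffers every batch;
the memory and network savings of the class form only materialize for metrics that do reduce
incrementally, like the running mean in the MetricReducer docstring.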