From 81a68c3b9aa8c9460229b80b54e2665a58b4a06d Mon Sep 17 00:00:00 2001 From: Bradley Laney Date: Mon, 6 Jul 2020 16:07:18 -0400 Subject: [PATCH 01/12] feat: remove steps from pytorch callbacks --- docs/reference/api/pytorch.txt | 2 +- harness/determined/pytorch/_callback.py | 39 ++++++++++++++----- harness/determined/pytorch/_pytorch_trial.py | 32 ++++++++++----- .../experiment/fixtures/pytorch_xor_model.py | 24 ++++++++---- .../experiment/pytorch/test_pytorch_trial.py | 24 ++++++++---- 5 files changed, 86 insertions(+), 35 deletions(-) diff --git a/docs/reference/api/pytorch.txt b/docs/reference/api/pytorch.txt index ae81fb0e13e..343420e0661 100644 --- a/docs/reference/api/pytorch.txt +++ b/docs/reference/api/pytorch.txt @@ -116,7 +116,7 @@ class with ``PyTorchTrial``, implement the following callback: context.get_optimizer(), "min", verbose=True ) # customize arguments as desired here - def on_validation_step_end(self, metrics): + def on_validation_end(self, metrics): self.reduce_lr.step(metrics["validation_error"]) def state_dict(self): diff --git a/harness/determined/pytorch/_callback.py b/harness/determined/pytorch/_callback.py index d41f5e3a00e..037e89d9fd0 100644 --- a/harness/determined/pytorch/_callback.py +++ b/harness/determined/pytorch/_callback.py @@ -16,20 +16,41 @@ class PyTorchCallback: .. warning:: If distributed training is enabled, every GPU will execute a copy of this callback - (except for :meth:`on_validation_step_end` and :meth:`on_checkpoint_end`). To + (except for :meth:`on_validation_end` and :meth:`on_checkpoint_end`). To configure a callback implementation to execute on a subset of GPUs, please condition your implementation on ``trial.context.distributed.get_rank()``. """ - def on_train_step_start(self, step_id: int) -> None: + def on_batch_start(self, batch_idx: int) -> None: """ - Run before every training step begins. + Run before every batch is trained. + + """ + pass + + def on_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None: + """ + Run after every batch is trained. + + .. warning:: + If distributed training is enabled, every GPU will execute a copy of + this callback at the end of every training step. If + ``optimizations.average_training_metrics`` is enabled, then the + ``metrics`` will be averaged across all GPUs before the callback is + executed. If ``optimizations.average_training_metrics`` is + disabled, then the ``metrics`` will be local to the GPU. + """ + pass + + def on_epoch_start(self, epoch_idx: int) -> None: + """ + Run before every epoch begins. """ pass - def on_train_step_end(self, step_id: int, metrics: Dict[str, Any]) -> None: + def on_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None: """ - Run after every training step ends. + Run after every epoch ends. .. warning:: If distributed training is enabled, every GPU will execute a copy of @@ -50,15 +71,15 @@ def on_before_optimizer_step(self, parameters: Iterator) -> None: # TODO(DET-3267): deprecate this when releasing pytorch flexible primitives. pass - def on_validation_step_start(self) -> None: + def on_validation_start(self) -> None: """ - Run before every validation step begins. + Run before every validation begins. """ pass - def on_validation_step_end(self, metrics: Dict[str, Any]) -> None: + def on_validation_end(self, metrics: Dict[str, Any]) -> None: """ - Run after every validation step ends. + Run after every validation ends. .. warning:: This callback only executes on the chief GPU when doing distributed training. 
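The docs hunk above shows only the two lines that change; a complete callback written against the renamed ``on_validation_end`` hook might look like the sketch below. The class name, the ``build_callbacks`` wiring, and the ``validation_error`` metric key are illustrative assumptions taken from the surrounding docs fragment, not part of this patch:

    import torch
    from determined.pytorch import PyTorchCallback


    class ReduceLROnPlateauCallback(PyTorchCallback):
        """Steps a ReduceLROnPlateau scheduler from the renamed on_validation_end hook."""

        def __init__(self, context) -> None:
            # context.get_optimizer() mirrors the docs snippet above; customize the
            # scheduler arguments as desired.
            self.reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
                context.get_optimizer(), "min", verbose=True
            )

        def on_validation_end(self, metrics):
            # Called once per validation; under distributed training only the chief
            # GPU executes this hook.
            self.reduce_lr.step(metrics["validation_error"])

        def state_dict(self):
            return self.reduce_lr.state_dict()

        def load_state_dict(self, state_dict):
            self.reduce_lr.load_state_dict(state_dict)

A trial would typically return such a callback from its ``build_callbacks`` method, e.g. ``{"reduce_lr": ReduceLROnPlateauCallback(self.context)}``.
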
diff --git a/harness/determined/pytorch/_pytorch_trial.py b/harness/determined/pytorch/_pytorch_trial.py index f08b650c71d..d5bc0635b57 100644 --- a/harness/determined/pytorch/_pytorch_trial.py +++ b/harness/determined/pytorch/_pytorch_trial.py @@ -229,6 +229,12 @@ def run(self) -> None: def get_epoch_idx(self, batch_id: int) -> int: return batch_id // len(self.training_loader) + def is_epoch_start(self, batch_id: int) -> bool: + return batch_id % len(self.training_loader) == 0 + + def is_epoch_end(self, batch_id: int) -> bool: + return batch_id % len(self.training_loader) == len(self.training_loader) - 1 + def _average_training_metrics( self, per_batch_metrics: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: @@ -295,9 +301,6 @@ def _train_for_step( for model in self.context.models: model.train() - for callback in self.callbacks.values(): - callback.on_train_step_start(step_id) - start = total_batches_processed end = start + num_batches @@ -305,6 +308,13 @@ def _train_for_step( num_inputs = 0 for batch_idx in range(start, end): + for callback in self.callbacks.values(): + callback.on_batch_start(batch_idx) + + if self.is_epoch_start(batch_idx): + for callback in self.callbacks.values(): + callback.on_epoch_start(self.get_epoch_idx(batch_idx)) + batch = next(self.training_iterator) num_inputs += data_length(batch) batch = self.context._to_device(batch) @@ -342,6 +352,13 @@ def _train_for_step( check.is_in("loss", tr_metrics, 'Please include "loss" in your training metrics.') per_batch_metrics.append(tr_metrics) + for callback in self.callbacks.values(): + callback.on_batch_end(batch_idx, tr_metrics) + + if self.is_epoch_end(batch_idx): + for callback in self.callbacks.values(): + callback.on_epoch_end(self.get_epoch_idx(batch_idx), tr_metrics) + # Aggregate and reduce training metrics from all the training processes. if self.hvd_config.use and self.hvd_config.average_training_metrics: per_batch_metrics = self._average_training_metrics(per_batch_metrics) @@ -349,9 +366,6 @@ def _train_for_step( num_inputs *= hvd.size() metrics = det.util.make_metrics(num_inputs, per_batch_metrics) - for callback in self.callbacks.values(): - callback.on_train_step_end(step_id, metrics) - if not self.is_chief: # The training metrics are reported only in the chief process. 
return workload.Skipped() @@ -375,7 +389,7 @@ def _compute_validation_metrics(self) -> workload.Response: model.eval() for callback in self.callbacks.values(): - callback.on_validation_step_start() + callback.on_validation_start() num_inputs = 0 metrics = {} # type: Optional[Dict[str, Any]] @@ -436,7 +450,7 @@ def _compute_validation_metrics(self) -> workload.Response: if self.hvd_config.use and any( map( - lambda c: util.is_overridden(c.on_validation_step_end, _callback.PyTorchCallback), + lambda c: util.is_overridden(c.on_validation_end, _callback.PyTorchCallback), self.callbacks.values(), ) ): @@ -447,7 +461,7 @@ def _compute_validation_metrics(self) -> workload.Response: metrics = hvd.broadcast_object(metrics, root_rank=0) for callback in self.callbacks.values(): - callback.on_validation_step_end(cast(Dict[str, Any], metrics)) + callback.on_validation_end(cast(Dict[str, Any], metrics)) if not self.is_chief: return workload.Skipped() diff --git a/harness/tests/experiment/fixtures/pytorch_xor_model.py b/harness/tests/experiment/fixtures/pytorch_xor_model.py index caa8546105a..ef5a6bf4527 100644 --- a/harness/tests/experiment/fixtures/pytorch_xor_model.py +++ b/harness/tests/experiment/fixtures/pytorch_xor_model.py @@ -303,22 +303,30 @@ def evaluation_reducer(self) -> Dict[str, det.pytorch.Reducer]: class Counter(det.pytorch.PyTorchCallback): def __init__(self) -> None: - self.train_steps_started = 0 - self.train_steps_ended = 0 + self.batches_started = 0 + self.batches_ended = 0 + self.epochs_started = 0 + self.epochs_ended = 0 self.validation_steps_started = 0 self.validation_steps_ended = 0 self.checkpoints_ended = 0 - def on_train_step_start(self, step_id: int) -> None: - self.train_steps_started += 1 + def on_batch_start(self, step_id: int) -> None: + self.batches_started += 1 - def on_train_step_end(self, step_id: int, metrics: Dict[str, Any]) -> None: - self.train_steps_ended += 1 + def on_batch_end(self, step_id: int, metrics: Dict[str, Any]) -> None: + self.batches_ended += 1 - def on_validation_step_start(self) -> None: + def on_epoch_start(self, step_id: int) -> None: + self.epochs_started += 1 + + def on_epoch_end(self, step_id: int, metrics: Dict[str, Any]) -> None: + self.epochs_ended += 1 + + def on_validation_start(self) -> None: self.validation_steps_started += 1 - def on_validation_step_end(self, metrics: Dict[str, Any]) -> None: + def on_validation_end(self, metrics: Dict[str, Any]) -> None: self.validation_steps_ended += 1 def on_checkpoint_end(self, checkpoint_dir: str): diff --git a/harness/tests/experiment/pytorch/test_pytorch_trial.py b/harness/tests/experiment/pytorch/test_pytorch_trial.py index e52c5da8b2b..d43fd8654c6 100644 --- a/harness/tests/experiment/pytorch/test_pytorch_trial.py +++ b/harness/tests/experiment/pytorch/test_pytorch_trial.py @@ -452,8 +452,10 @@ def test_callbacks(self, tmp_path: pathlib.Path) -> None: ) controller._train_for_step(1, 1, 0) assert controller.trial.counter.__dict__ == { - "train_steps_started": 1, - "train_steps_ended": 1, + "batches_started": 1, + "batches_ended": 1, + "epochs_started": 1, + "epochs_ended": 1, "validation_steps_started": 0, "validation_steps_ended": 0, "checkpoints_ended": 0, @@ -461,8 +463,10 @@ def test_callbacks(self, tmp_path: pathlib.Path) -> None: controller._compute_validation_metrics() assert controller.trial.counter.__dict__ == { - "train_steps_started": 1, - "train_steps_ended": 1, + "batches_started": 1, + "batches_ended": 1, + "epochs_started": 1, + "epochs_ended": 1, "validation_steps_started": 1, 
"validation_steps_ended": 1, "checkpoints_ended": 0, @@ -470,8 +474,10 @@ def test_callbacks(self, tmp_path: pathlib.Path) -> None: controller._save(checkpoint_dir) assert controller.trial.counter.__dict__ == { - "train_steps_started": 1, - "train_steps_ended": 1, + "batches_started": 1, + "batches_ended": 1, + "epochs_started": 1, + "epochs_ended": 1, "validation_steps_started": 1, "validation_steps_ended": 1, "checkpoints_ended": 1, @@ -487,8 +493,10 @@ def test_callbacks(self, tmp_path: pathlib.Path) -> None: ) controller._load() assert controller.trial.counter.__dict__ == { - "train_steps_started": 1, - "train_steps_ended": 1, + "batches_started": 1, + "batches_ended": 1, + "epochs_started": 1, + "epochs_ended": 1, "validation_steps_started": 1, "validation_steps_ended": 1, "checkpoints_ended": 0, From aac12aba4941164af8704288593b328ce5b73843 Mon Sep 17 00:00:00 2001 From: Bradley Laney Date: Mon, 6 Jul 2020 16:49:08 -0400 Subject: [PATCH 02/12] remove misleading warning from docstring --- harness/determined/pytorch/_callback.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/harness/determined/pytorch/_callback.py b/harness/determined/pytorch/_callback.py index 037e89d9fd0..99b95546c99 100644 --- a/harness/determined/pytorch/_callback.py +++ b/harness/determined/pytorch/_callback.py @@ -31,14 +31,6 @@ def on_batch_start(self, batch_idx: int) -> None: def on_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None: """ Run after every batch is trained. - - .. warning:: - If distributed training is enabled, every GPU will execute a copy of - this callback at the end of every training step. If - ``optimizations.average_training_metrics`` is enabled, then the - ``metrics`` will be averaged across all GPUs before the callback is - executed. If ``optimizations.average_training_metrics`` is - disabled, then the ``metrics`` will be local to the GPU. """ pass @@ -51,14 +43,6 @@ def on_epoch_start(self, epoch_idx: int) -> None: def on_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None: """ Run after every epoch ends. - - .. warning:: - If distributed training is enabled, every GPU will execute a copy of - this callback at the end of every training step. If - ``optimizations.average_training_metrics`` is enabled, then the - ``metrics`` will be averaged across all GPUs before the callback is - executed. If ``optimizations.average_training_metrics`` is - disabled, then the ``metrics`` will be local to the GPU. """ pass From 681f9c8f713e449944ef12579276593d33ca2a22 Mon Sep 17 00:00:00 2001 From: Bradley Laney Date: Mon, 6 Jul 2020 16:49:46 -0400 Subject: [PATCH 03/12] remove excess whitespace --- harness/determined/pytorch/_callback.py | 1 - 1 file changed, 1 deletion(-) diff --git a/harness/determined/pytorch/_callback.py b/harness/determined/pytorch/_callback.py index 99b95546c99..95327d22020 100644 --- a/harness/determined/pytorch/_callback.py +++ b/harness/determined/pytorch/_callback.py @@ -24,7 +24,6 @@ class PyTorchCallback: def on_batch_start(self, batch_idx: int) -> None: """ Run before every batch is trained. 
- """ pass From 6498fdac51cf613152e6cdc4987d5ee10e730fac Mon Sep 17 00:00:00 2001 From: Bradley Laney Date: Tue, 7 Jul 2020 10:41:19 -0400 Subject: [PATCH 04/12] rename and switch order of callbacks --- harness/determined/pytorch/_callback.py | 8 ++++---- harness/determined/pytorch/_pytorch_trial.py | 8 ++++---- harness/tests/experiment/fixtures/pytorch_xor_model.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/harness/determined/pytorch/_callback.py b/harness/determined/pytorch/_callback.py index 95327d22020..3fa1a79b672 100644 --- a/harness/determined/pytorch/_callback.py +++ b/harness/determined/pytorch/_callback.py @@ -23,13 +23,13 @@ class PyTorchCallback: def on_batch_start(self, batch_idx: int) -> None: """ - Run before every batch is trained. + Run before every batch begins training. """ pass def on_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None: """ - Run after every batch is trained. + Run after every batch finishes training. """ pass @@ -39,9 +39,9 @@ def on_epoch_start(self, epoch_idx: int) -> None: """ pass - def on_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None: + def on_train_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None: """ - Run after every epoch ends. + Run after every epoch finishes training. """ pass diff --git a/harness/determined/pytorch/_pytorch_trial.py b/harness/determined/pytorch/_pytorch_trial.py index d5bc0635b57..d17d48b5f63 100644 --- a/harness/determined/pytorch/_pytorch_trial.py +++ b/harness/determined/pytorch/_pytorch_trial.py @@ -308,13 +308,13 @@ def _train_for_step( num_inputs = 0 for batch_idx in range(start, end): - for callback in self.callbacks.values(): - callback.on_batch_start(batch_idx) - if self.is_epoch_start(batch_idx): for callback in self.callbacks.values(): callback.on_epoch_start(self.get_epoch_idx(batch_idx)) + for callback in self.callbacks.values(): + callback.on_batch_start(batch_idx) + batch = next(self.training_iterator) num_inputs += data_length(batch) batch = self.context._to_device(batch) @@ -357,7 +357,7 @@ def _train_for_step( if self.is_epoch_end(batch_idx): for callback in self.callbacks.values(): - callback.on_epoch_end(self.get_epoch_idx(batch_idx), tr_metrics) + callback.on_train_epoch_end(self.get_epoch_idx(batch_idx), tr_metrics) # Aggregate and reduce training metrics from all the training processes. 
if self.hvd_config.use and self.hvd_config.average_training_metrics: diff --git a/harness/tests/experiment/fixtures/pytorch_xor_model.py b/harness/tests/experiment/fixtures/pytorch_xor_model.py index ef5a6bf4527..2e8d5c24fe0 100644 --- a/harness/tests/experiment/fixtures/pytorch_xor_model.py +++ b/harness/tests/experiment/fixtures/pytorch_xor_model.py @@ -320,7 +320,7 @@ def on_batch_end(self, step_id: int, metrics: Dict[str, Any]) -> None: def on_epoch_start(self, step_id: int) -> None: self.epochs_started += 1 - def on_epoch_end(self, step_id: int, metrics: Dict[str, Any]) -> None: + def on_train_epoch_end(self, step_id: int, metrics: Dict[str, Any]) -> None: self.epochs_ended += 1 def on_validation_start(self) -> None: From b16ead17f19b5ed2c754a2011e32385713015dc6 Mon Sep 17 00:00:00 2001 From: Bradley Laney Date: Tue, 7 Jul 2020 10:56:25 -0400 Subject: [PATCH 05/12] better docstrings for pytorch callbacks --- harness/determined/pytorch/_callback.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/harness/determined/pytorch/_callback.py b/harness/determined/pytorch/_callback.py index 3fa1a79b672..a5d755a957f 100644 --- a/harness/determined/pytorch/_callback.py +++ b/harness/determined/pytorch/_callback.py @@ -29,7 +29,12 @@ def on_batch_start(self, batch_idx: int) -> None: def on_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None: """ - Run after every batch finishes training. + Run after every batch finishes training. ``metrics`` are the metrics returned from calling + ``PyTorchTrail.train_batch`` on batch ``batch_idx``. + + .. warning:: + If distributed training is enabled, every GPU will execute a copy of this callback at + the end of every training step and the ``metrics``will be local to the GPU. """ pass @@ -41,7 +46,12 @@ def on_epoch_start(self, epoch_idx: int) -> None: def on_train_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None: """ - Run after every epoch finishes training. + Run after every epoch finishes training. ``metrics`` are the metrics returned from calling + ``PyTorchTrail.train_batch`` on the last batch of epoch ``epoch_idx``. + + .. warning:: + If distributed training is enabled, every GPU will execute a copy of this callback at + the end of every training step and the ``metrics``will be local to the GPU. """ pass From 13b3b3018b6045e4d313ca923144b2429d4c9c06 Mon Sep 17 00:00:00 2001 From: Bradley Laney Date: Tue, 7 Jul 2020 11:18:04 -0400 Subject: [PATCH 06/12] fix docs --- harness/determined/pytorch/_callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/harness/determined/pytorch/_callback.py b/harness/determined/pytorch/_callback.py index a5d755a957f..33993586a08 100644 --- a/harness/determined/pytorch/_callback.py +++ b/harness/determined/pytorch/_callback.py @@ -34,7 +34,7 @@ def on_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None: .. warning:: If distributed training is enabled, every GPU will execute a copy of this callback at - the end of every training step and the ``metrics``will be local to the GPU. + the end of every training step and the ``metrics`` will be local to the GPU. """ pass @@ -51,7 +51,7 @@ def on_train_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None: .. warning:: If distributed training is enabled, every GPU will execute a copy of this callback at - the end of every training step and the ``metrics``will be local to the GPU. + the end of every training step and the ``metrics`` will be local to the GPU. 
""" pass From 41051caaf02cca2a108227f194dc1e63f937feb5 Mon Sep 17 00:00:00 2001 From: Bradley Laney Date: Tue, 7 Jul 2020 14:57:33 -0400 Subject: [PATCH 07/12] rename callbacks to be more semantically correct given future changes --- harness/determined/pytorch/_callback.py | 6 +++--- harness/determined/pytorch/_pytorch_trial.py | 6 +++--- harness/tests/experiment/fixtures/pytorch_xor_model.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/harness/determined/pytorch/_callback.py b/harness/determined/pytorch/_callback.py index 33993586a08..9a8bc31a933 100644 --- a/harness/determined/pytorch/_callback.py +++ b/harness/determined/pytorch/_callback.py @@ -21,13 +21,13 @@ class PyTorchCallback: your implementation on ``trial.context.distributed.get_rank()``. """ - def on_batch_start(self, batch_idx: int) -> None: + def on_train_batch_start(self, batch_idx: int) -> None: """ Run before every batch begins training. """ pass - def on_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None: + def on_train_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None: """ Run after every batch finishes training. ``metrics`` are the metrics returned from calling ``PyTorchTrail.train_batch`` on batch ``batch_idx``. @@ -38,7 +38,7 @@ def on_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None: """ pass - def on_epoch_start(self, epoch_idx: int) -> None: + def on_train_epoch_start(self, epoch_idx: int) -> None: """ Run before every epoch begins. """ diff --git a/harness/determined/pytorch/_pytorch_trial.py b/harness/determined/pytorch/_pytorch_trial.py index d17d48b5f63..e7b18bed083 100644 --- a/harness/determined/pytorch/_pytorch_trial.py +++ b/harness/determined/pytorch/_pytorch_trial.py @@ -310,10 +310,10 @@ def _train_for_step( for batch_idx in range(start, end): if self.is_epoch_start(batch_idx): for callback in self.callbacks.values(): - callback.on_epoch_start(self.get_epoch_idx(batch_idx)) + callback.on_train_epoch_start(self.get_epoch_idx(batch_idx)) for callback in self.callbacks.values(): - callback.on_batch_start(batch_idx) + callback.on_train_batch_start(batch_idx) batch = next(self.training_iterator) num_inputs += data_length(batch) @@ -353,7 +353,7 @@ def _train_for_step( per_batch_metrics.append(tr_metrics) for callback in self.callbacks.values(): - callback.on_batch_end(batch_idx, tr_metrics) + callback.on_train_batch_end(batch_idx, tr_metrics) if self.is_epoch_end(batch_idx): for callback in self.callbacks.values(): diff --git a/harness/tests/experiment/fixtures/pytorch_xor_model.py b/harness/tests/experiment/fixtures/pytorch_xor_model.py index 2e8d5c24fe0..b25d8393c73 100644 --- a/harness/tests/experiment/fixtures/pytorch_xor_model.py +++ b/harness/tests/experiment/fixtures/pytorch_xor_model.py @@ -311,13 +311,13 @@ def __init__(self) -> None: self.validation_steps_ended = 0 self.checkpoints_ended = 0 - def on_batch_start(self, step_id: int) -> None: + def on_train_batch_start(self, step_id: int) -> None: self.batches_started += 1 - def on_batch_end(self, step_id: int, metrics: Dict[str, Any]) -> None: + def on_train_batch_end(self, step_id: int, metrics: Dict[str, Any]) -> None: self.batches_ended += 1 - def on_epoch_start(self, step_id: int) -> None: + def on_train_epoch_start(self, step_id: int) -> None: self.epochs_started += 1 def on_train_epoch_end(self, step_id: int, metrics: Dict[str, Any]) -> None: From 85d45c20c0c0821500faf5af5f32f6f9148d7aae Mon Sep 17 00:00:00 2001 From: Bradley Laney Date: Tue, 7 Jul 2020 21:57:20 -0400 
Subject: [PATCH 08/12] remove some callbacks, bring back others with deprecation warning

---
 harness/determined/pytorch/_callback.py      | 53 ++++++------------
 harness/determined/pytorch/_pytorch_trial.py | 29 +++-------
 .../experiment/fixtures/pytorch_xor_model.py | 16 ------
 .../experiment/pytorch/test_pytorch_trial.py | 16 ------
 4 files changed, 25 insertions(+), 89 deletions(-)

diff --git a/harness/determined/pytorch/_callback.py b/harness/determined/pytorch/_callback.py
index 9a8bc31a933..513cf2c0d06 100644
--- a/harness/determined/pytorch/_callback.py
+++ b/harness/determined/pytorch/_callback.py
@@ -16,63 +16,44 @@ class PyTorchCallback:
 
     .. warning::
         If distributed training is enabled, every GPU will execute a copy of this callback
-        (except for :meth:`on_validation_end` and :meth:`on_checkpoint_end`). To
-        configure a callback implementation to execute on a subset of GPUs, please condition
-        your implementation on ``trial.context.distributed.get_rank()``.
+        (except for :meth:`on_validation_end`, :meth:`on_validation_step_end` and
+        :meth:`on_checkpoint_end`). To configure a callback implementation to execute on a subset of
+        GPUs, please condition your implementation on ``trial.context.distributed.get_rank()``.
     """
 
-    def on_train_batch_start(self, batch_idx: int) -> None:
-        """
-        Run before every batch begins training.
-        """
-        pass
-
-    def on_train_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None:
+    def on_before_optimizer_step(self, parameters: Iterator) -> None:
        """
-        Run after every batch finishes training. ``metrics`` are the metrics returned from calling
-        ``PyTorchTrail.train_batch`` on batch ``batch_idx``.
-
-        .. warning::
-            If distributed training is enabled, every GPU will execute a copy of this callback at
-            the end of every training step and the ``metrics`` will be local to the GPU.
+        Run before every before ``optimizer.step()``. For multi-GPU training, executes
+        after gradient updates have been communicated. Typically used to perform gradient
+        clipping.
         """
+        # TODO(DET-3267): deprecate this when releasing pytorch flexible primitives.
         pass
 
-    def on_train_epoch_start(self, epoch_idx: int) -> None:
+    def on_validation_start(self) -> None:
         """
-        Run before every epoch begins.
+        Run before every validation begins.
         """
         pass
 
-    def on_train_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None:
+    def on_validation_end(self, metrics: Dict[str, Any]) -> None:
         """
-        Run after every epoch finishes training. ``metrics`` are the metrics returned from calling
-        ``PyTorchTrail.train_batch`` on the last batch of epoch ``epoch_idx``.
+        Run after every validation ends.
 
         .. warning::
-            If distributed training is enabled, every GPU will execute a copy of this callback at
-            the end of every training step and the ``metrics`` will be local to the GPU.
-        """
-        pass
-
-    def on_before_optimizer_step(self, parameters: Iterator) -> None:
-        """
-        Run before every before ``optimizer.step()``. For multi-GPU training, executes
-        after gradient updates have been communicated. Typically used to perform gradient
-        clipping.
+            This callback only executes on the chief GPU when doing distributed training.
         """
-        # TODO(DET-3267): deprecate this when releasing pytorch flexible primitives.
         pass
 
-    def on_validation_start(self) -> None:
+    def on_validation_step_start(self) -> None:
         """
-        Run before every validation begins.
+        Run before every validation step begins.
""" pass - def on_validation_end(self, metrics: Dict[str, Any]) -> None: + def on_validation_step_end(self, metrics: Dict[str, Any]) -> None: """ - Run after every validation ends. + Run after every validation step ends. .. warning:: This callback only executes on the chief GPU when doing distributed training. diff --git a/harness/determined/pytorch/_pytorch_trial.py b/harness/determined/pytorch/_pytorch_trial.py index e7b18bed083..a7553b28368 100644 --- a/harness/determined/pytorch/_pytorch_trial.py +++ b/harness/determined/pytorch/_pytorch_trial.py @@ -229,12 +229,6 @@ def run(self) -> None: def get_epoch_idx(self, batch_id: int) -> int: return batch_id // len(self.training_loader) - def is_epoch_start(self, batch_id: int) -> bool: - return batch_id % len(self.training_loader) == 0 - - def is_epoch_end(self, batch_id: int) -> bool: - return batch_id % len(self.training_loader) == len(self.training_loader) - 1 - def _average_training_metrics( self, per_batch_metrics: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: @@ -308,13 +302,6 @@ def _train_for_step( num_inputs = 0 for batch_idx in range(start, end): - if self.is_epoch_start(batch_idx): - for callback in self.callbacks.values(): - callback.on_train_epoch_start(self.get_epoch_idx(batch_idx)) - - for callback in self.callbacks.values(): - callback.on_train_batch_start(batch_idx) - batch = next(self.training_iterator) num_inputs += data_length(batch) batch = self.context._to_device(batch) @@ -352,13 +339,6 @@ def _train_for_step( check.is_in("loss", tr_metrics, 'Please include "loss" in your training metrics.') per_batch_metrics.append(tr_metrics) - for callback in self.callbacks.values(): - callback.on_train_batch_end(batch_idx, tr_metrics) - - if self.is_epoch_end(batch_idx): - for callback in self.callbacks.values(): - callback.on_train_epoch_end(self.get_epoch_idx(batch_idx), tr_metrics) - # Aggregate and reduce training metrics from all the training processes. if self.hvd_config.use and self.hvd_config.average_training_metrics: per_batch_metrics = self._average_training_metrics(per_batch_metrics) @@ -450,7 +430,8 @@ def _compute_validation_metrics(self) -> workload.Response: if self.hvd_config.use and any( map( - lambda c: util.is_overridden(c.on_validation_end, _callback.PyTorchCallback), + lambda c: util.is_overridden(c.on_validation_end, _callback.PyTorchCallback) + or util.is_overridden(c.on_validation_step_end, _callback.PyTorchCallback), self.callbacks.values(), ) ): @@ -460,6 +441,12 @@ def _compute_validation_metrics(self) -> workload.Response: ) metrics = hvd.broadcast_object(metrics, root_rank=0) + for callback in self.callbacks.values(): + logging.warning( + "on_validation_step_end is now deprecated, please use on_validation_end instead." 
+ ) + callback.on_validation_step_end(cast(Dict[str, Any], metrics)) + for callback in self.callbacks.values(): callback.on_validation_end(cast(Dict[str, Any], metrics)) diff --git a/harness/tests/experiment/fixtures/pytorch_xor_model.py b/harness/tests/experiment/fixtures/pytorch_xor_model.py index b25d8393c73..b6931d098b7 100644 --- a/harness/tests/experiment/fixtures/pytorch_xor_model.py +++ b/harness/tests/experiment/fixtures/pytorch_xor_model.py @@ -303,26 +303,10 @@ def evaluation_reducer(self) -> Dict[str, det.pytorch.Reducer]: class Counter(det.pytorch.PyTorchCallback): def __init__(self) -> None: - self.batches_started = 0 - self.batches_ended = 0 - self.epochs_started = 0 - self.epochs_ended = 0 self.validation_steps_started = 0 self.validation_steps_ended = 0 self.checkpoints_ended = 0 - def on_train_batch_start(self, step_id: int) -> None: - self.batches_started += 1 - - def on_train_batch_end(self, step_id: int, metrics: Dict[str, Any]) -> None: - self.batches_ended += 1 - - def on_train_epoch_start(self, step_id: int) -> None: - self.epochs_started += 1 - - def on_train_epoch_end(self, step_id: int, metrics: Dict[str, Any]) -> None: - self.epochs_ended += 1 - def on_validation_start(self) -> None: self.validation_steps_started += 1 diff --git a/harness/tests/experiment/pytorch/test_pytorch_trial.py b/harness/tests/experiment/pytorch/test_pytorch_trial.py index d43fd8654c6..71276297144 100644 --- a/harness/tests/experiment/pytorch/test_pytorch_trial.py +++ b/harness/tests/experiment/pytorch/test_pytorch_trial.py @@ -452,10 +452,6 @@ def test_callbacks(self, tmp_path: pathlib.Path) -> None: ) controller._train_for_step(1, 1, 0) assert controller.trial.counter.__dict__ == { - "batches_started": 1, - "batches_ended": 1, - "epochs_started": 1, - "epochs_ended": 1, "validation_steps_started": 0, "validation_steps_ended": 0, "checkpoints_ended": 0, @@ -463,10 +459,6 @@ def test_callbacks(self, tmp_path: pathlib.Path) -> None: controller._compute_validation_metrics() assert controller.trial.counter.__dict__ == { - "batches_started": 1, - "batches_ended": 1, - "epochs_started": 1, - "epochs_ended": 1, "validation_steps_started": 1, "validation_steps_ended": 1, "checkpoints_ended": 0, @@ -474,10 +466,6 @@ def test_callbacks(self, tmp_path: pathlib.Path) -> None: controller._save(checkpoint_dir) assert controller.trial.counter.__dict__ == { - "batches_started": 1, - "batches_ended": 1, - "epochs_started": 1, - "epochs_ended": 1, "validation_steps_started": 1, "validation_steps_ended": 1, "checkpoints_ended": 1, @@ -493,10 +481,6 @@ def test_callbacks(self, tmp_path: pathlib.Path) -> None: ) controller._load() assert controller.trial.counter.__dict__ == { - "batches_started": 1, - "batches_ended": 1, - "epochs_started": 1, - "epochs_ended": 1, "validation_steps_started": 1, "validation_steps_ended": 1, "checkpoints_ended": 0, From 49e813315c9831090a5222884d71d90a722d16b2 Mon Sep 17 00:00:00 2001 From: Bradley Laney Date: Tue, 7 Jul 2020 22:07:12 -0400 Subject: [PATCH 09/12] call other callback too --- harness/determined/pytorch/_pytorch_trial.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/harness/determined/pytorch/_pytorch_trial.py b/harness/determined/pytorch/_pytorch_trial.py index a7553b28368..3600542f975 100644 --- a/harness/determined/pytorch/_pytorch_trial.py +++ b/harness/determined/pytorch/_pytorch_trial.py @@ -368,6 +368,12 @@ def _compute_validation_metrics(self) -> workload.Response: for model in self.context.models: model.eval() + for callback in 
self.callbacks.values(): + logging.warning( + "on_validation_step_start is now deprecated, please use on_validation_start instead." + ) + callback.on_validation_step_start() + for callback in self.callbacks.values(): callback.on_validation_start() From 662c189298f657472de95222f3d3e2f19747f081 Mon Sep 17 00:00:00 2001 From: Bradley Laney Date: Tue, 7 Jul 2020 22:09:46 -0400 Subject: [PATCH 10/12] remove punct --- harness/determined/pytorch/_pytorch_trial.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/harness/determined/pytorch/_pytorch_trial.py b/harness/determined/pytorch/_pytorch_trial.py index 3600542f975..7bb8c028c1f 100644 --- a/harness/determined/pytorch/_pytorch_trial.py +++ b/harness/determined/pytorch/_pytorch_trial.py @@ -370,7 +370,7 @@ def _compute_validation_metrics(self) -> workload.Response: for callback in self.callbacks.values(): logging.warning( - "on_validation_step_start is now deprecated, please use on_validation_start instead." + "on_validation_step_start is now deprecated, please use on_validation_start instead" ) callback.on_validation_step_start() @@ -449,7 +449,7 @@ def _compute_validation_metrics(self) -> workload.Response: for callback in self.callbacks.values(): logging.warning( - "on_validation_step_end is now deprecated, please use on_validation_end instead." + "on_validation_step_end is now deprecated, please use on_validation_end instead" ) callback.on_validation_step_end(cast(Dict[str, Any], metrics)) From 85b8d7a2e1468c1a9973f81946ece909f1aa1986 Mon Sep 17 00:00:00 2001 From: Bradley Laney Date: Tue, 7 Jul 2020 22:32:38 -0400 Subject: [PATCH 11/12] add is_epoch_start/is_epoch_end to PytorchTrialContext --- harness/determined/pytorch/_pytorch_context.py | 17 ++++++++++++++++- harness/determined/pytorch/_pytorch_trial.py | 1 + 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/harness/determined/pytorch/_pytorch_context.py b/harness/determined/pytorch/_pytorch_context.py index 70bd568cf5e..e2a558b6069 100644 --- a/harness/determined/pytorch/_pytorch_context.py +++ b/harness/determined/pytorch/_pytorch_context.py @@ -48,11 +48,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # Track whether a warning logging category has already been issued to the user. self.warning_logged = {_WarningLogs.FAILED_MOVING_TO_DEVICE: False} - # The following three attributes are initialized during the lifetime of + # The following attributes are initialized during the lifetime of # a PyTorchTrialContext. 
self.models = [] # type: List[nn.Module] self.optimizers = [] # type: List[torch.optim.Optimizer] # type: ignore self.lr_schedulers = [] # type: List[pytorch.LRScheduler] + self._epoch_len = None # type: Optional[int] # Use a main model to contain all of the models because when using horovod # to broadcast the states of models we want to avoid name conflicts for these @@ -457,3 +458,17 @@ def _step_optimizer( else: optimizer.step() optimizer.zero_grad() + + def is_epoch_start(self) -> bool: + if self._current_batch_idx is None: + raise det.errors.InternalException("Training hasn't started.") + if self._current_epoch_len is None: + raise det.errors.InternalException("Training DataLoader uninitialized.") + return self._current_batch_idx % self._epoch_len == 0 + + def is_epoch_end(self) -> bool: + if self._current_batch_idx is None: + raise det.errors.InternalException("Training hasn't started.") + if self._current_epoch_len is None: + raise det.errors.InternalException("Training DataLoader uninitialized.") + return self._current_batch_idx % self._epoch_len == self._epoch_len - 1 diff --git a/harness/determined/pytorch/_pytorch_trial.py b/harness/determined/pytorch/_pytorch_trial.py index 7bb8c028c1f..54fe55d4410 100644 --- a/harness/determined/pytorch/_pytorch_trial.py +++ b/harness/determined/pytorch/_pytorch_trial.py @@ -189,6 +189,7 @@ def _set_data_loaders(self) -> None: self.training_loader = self.trial.build_training_data_loader().get_data_loader( repeat=True, skip=skip_batches, num_replicas=nreplicas, rank=rank ) + self.context._epoch_len = len(self.training_loader) validation_dataset = self.trial.build_validation_data_loader() if self._evaluate_batch_defined(): From efd3bffdc40419c70d922fb78e8f2f5af96c2528 Mon Sep 17 00:00:00 2001 From: Bradley Laney Date: Wed, 8 Jul 2020 10:15:32 -0400 Subject: [PATCH 12/12] better docstrings/comments --- harness/determined/pytorch/_callback.py | 2 ++ harness/determined/pytorch/_pytorch_context.py | 16 ++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/harness/determined/pytorch/_callback.py b/harness/determined/pytorch/_callback.py index 513cf2c0d06..b64a816ce8e 100644 --- a/harness/determined/pytorch/_callback.py +++ b/harness/determined/pytorch/_callback.py @@ -49,6 +49,7 @@ def on_validation_step_start(self) -> None: """ Run before every validation step begins. """ + # TODO(DET-3555): remove this once it has been deprecated long enough. pass def on_validation_step_end(self, metrics: Dict[str, Any]) -> None: @@ -58,6 +59,7 @@ def on_validation_step_end(self, metrics: Dict[str, Any]) -> None: .. warning:: This callback only executes on the chief GPU when doing distributed training. """ + # TODO(DET-3555): remove this once it has been deprecated long enough. pass def on_checkpoint_end(self, checkpoint_dir: str) -> None: diff --git a/harness/determined/pytorch/_pytorch_context.py b/harness/determined/pytorch/_pytorch_context.py index e2a558b6069..90fef68bff7 100644 --- a/harness/determined/pytorch/_pytorch_context.py +++ b/harness/determined/pytorch/_pytorch_context.py @@ -460,15 +460,27 @@ def _step_optimizer( optimizer.zero_grad() def is_epoch_start(self) -> bool: + """ + Returns true if the current batch is the first batch of the epoch. + + .. warning:: + Not accurate for variable size epochs. 
+ """ if self._current_batch_idx is None: raise det.errors.InternalException("Training hasn't started.") - if self._current_epoch_len is None: + if self._epoch_len is None: raise det.errors.InternalException("Training DataLoader uninitialized.") return self._current_batch_idx % self._epoch_len == 0 def is_epoch_end(self) -> bool: + """ + Returns true if the current batch is the last batch of the epoch. + + .. warning:: + Not accurate for variable size epochs. + """ if self._current_batch_idx is None: raise det.errors.InternalException("Training hasn't started.") - if self._current_epoch_len is None: + if self._epoch_len is None: raise det.errors.InternalException("Training DataLoader uninitialized.") return self._current_batch_idx % self._epoch_len == self._epoch_len - 1
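
Taken together, the last two patches replace the removed epoch callbacks with explicit queries on the trial context: once ``_set_data_loaders`` records ``len(self.training_loader)`` in ``context._epoch_len``, trial code can call ``self.context.is_epoch_start()`` or ``self.context.is_epoch_end()`` from inside ``train_batch``. The boundary test itself is plain modular arithmetic on the current batch index; a standalone sketch of the same logic (with hypothetical helper names, and assuming fixed-size epochs as the docstrings warn):

    def is_epoch_start(current_batch_idx: int, epoch_len: int) -> bool:
        # First batch of an epoch: indices 0, epoch_len, 2 * epoch_len, ...
        return current_batch_idx % epoch_len == 0


    def is_epoch_end(current_batch_idx: int, epoch_len: int) -> bool:
        # Last batch of an epoch: indices epoch_len - 1, 2 * epoch_len - 1, ...
        return current_batch_idx % epoch_len == epoch_len - 1


    # With an epoch of 4 batches, batch indices 0..7 behave as follows:
    #   index:   0      1      2      3      4      5      6      7
    #   start?   True   False  False  False  True   False  False  False
    #   end?     False  False  False  True   False  False  False  True
    assert is_epoch_start(0, 4) and is_epoch_end(3, 4)
    assert is_epoch_start(4, 4) and is_epoch_end(7, 4)

A typical use might be stepping a per-epoch LR scheduler from ``train_batch`` whenever ``self.context.is_epoch_end()`` returns true, which covers what the removed ``on_train_epoch_end`` hook previously enabled.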