feat: remove steps from pytorch callbacks [DET-3526] #831

Merged
merged 12 commits on Jul 8, 2020
2 changes: 1 addition & 1 deletion docs/reference/api/pytorch.txt
@@ -116,7 +116,7 @@ class with ``PyTorchTrial``, implement the following callback:
context.get_optimizer(), "min", verbose=True
) # customize arguments as desired here

def on_validation_step_end(self, metrics):
def on_validation_end(self, metrics):
self.reduce_lr.step(metrics["validation_error"])

def state_dict(self):
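
For readers skimming the diff, here is a minimal sketch of what the documented callback looks like after this rename. The class name ReduceLRCallback and the constructor wiring are illustrative assumptions; only the scheduler setup, the ``validation_error`` key, and the renamed ``on_validation_end`` hook come from the docs snippet above.

```python
from typing import Any, Dict

import torch
from determined import pytorch as det_pytorch


class ReduceLRCallback(det_pytorch.PyTorchCallback):  # hypothetical name
    def __init__(self, context) -> None:
        # Same scheduler setup as the docs snippet; customize arguments as desired.
        self.reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
            context.get_optimizer(), "min", verbose=True
        )

    def on_validation_end(self, metrics: Dict[str, Any]) -> None:
        # Renamed in this PR from on_validation_step_end.
        self.reduce_lr.step(metrics["validation_error"])

    def state_dict(self) -> Dict[str, Any]:
        # Persist scheduler state so it survives checkpoint/restore.
        return self.reduce_lr.state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.reduce_lr.load_state_dict(state_dict)
```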
36 changes: 20 additions & 16 deletions harness/determined/pytorch/_callback.py
@@ -16,28 +16,32 @@ class PyTorchCallback:

.. warning::
If distributed training is enabled, every GPU will execute a copy of this callback
(except for :meth:`on_validation_step_end` and :meth:`on_checkpoint_end`). To
(except for :meth:`on_validation_end` and :meth:`on_checkpoint_end`). To
configure a callback implementation to execute on a subset of GPUs, please condition
your implementation on ``trial.context.distributed.get_rank()``.
"""

def on_train_step_start(self, step_id: int) -> None:
def on_batch_start(self, batch_idx: int) -> None:
Member

This is really different behavior from what we had before, and it doesn't seem like it would be terribly useful to me, since we already call train_batch on every batch.

I'm not sure what use cases we were trying to solve before, though? Are those use cases still valid?

Git blame says @yoavz wrote these hooks; maybe @shiyuann or @aaron276h knows the answer, though?

And if those use cases are still valid, how would we address them after removing steps from the UX?

Contributor

This was done so that users can make adjustments to the optimizer and model before training. I don't remember the exact use cases, but I do think it makes sense to leave this callback in.

Contributor Author

I chose these changes because they're exactly what is in the ERD; I assumed they had already been discussed and decided. I understand on_batch_start/on_batch_end may not seem very useful, though. I think most of the context behind the original decision is in the #ml-ag Slack channel.
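
To make the "adjust the optimizer before training" use case mentioned above concrete, here is a hedged sketch of a warmup callback built on the renamed hook. The class name, the warmup schedule, and the constructor parameters are all illustrative assumptions, not part of this PR.

```python
from determined import pytorch as det_pytorch


class WarmupCallback(det_pytorch.PyTorchCallback):  # hypothetical example
    def __init__(self, optimizer, warmup_batches: int = 100, base_lr: float = 0.1) -> None:
        self.optimizer = optimizer
        self.warmup_batches = warmup_batches
        self.base_lr = base_lr

    def on_batch_start(self, batch_idx: int) -> None:
        # Linearly ramp the learning rate over the first warmup_batches batches,
        # adjusting the optimizer just before each batch is trained.
        if batch_idx < self.warmup_batches:
            scale = (batch_idx + 1) / self.warmup_batches
            for group in self.optimizer.param_groups:
                group["lr"] = self.base_lr * scale
```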

"""
Run before every training step begins.
Run before every batch is trained.
"""
pass

def on_train_step_end(self, step_id: int, metrics: Dict[str, Any]) -> None:
def on_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None:
"""
Run after every training step ends.
Run after every batch is trained.
Contributor

blocking: you need to add a warning about metrics here. Additionally, we need to decide whether we want to average metrics across GPUs for every batch when optimizations.average_training_metrics is enabled and on_batch_end is used.

Contributor Author

I feel like we shouldn't for on_batch_end, but we should for on_train_epoch_end?

Contributor

I agree with that. The tricky part here is that we may not always have the metrics for the entire epoch (if training was resumed mid-epoch). I would propose that we just don't provide averaged metrics in the training callbacks for now; if/when the need for them arises, we can decide on the proper mechanism to do so.
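
Until such a mechanism exists, a user callback that consumes per-batch metrics can guard against the non-averaged, per-GPU values by conditioning on rank, as the class docstring already suggests. A minimal sketch follows; the class name and the logging behavior are assumptions, not part of this PR.

```python
from typing import Any, Dict

from determined import pytorch as det_pytorch


class ChiefOnlyBatchLogger(det_pytorch.PyTorchCallback):  # hypothetical example
    def __init__(self, context) -> None:
        self.context = context

    def on_batch_end(self, batch_idx: int, metrics: Dict[str, Any]) -> None:
        # Batch metrics here are local to each GPU (not averaged across workers),
        # so only report them from rank 0 to avoid duplicated, inconsistent output.
        if self.context.distributed.get_rank() == 0:
            print(f"batch {batch_idx}: loss={metrics.get('loss')}")
```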

"""
pass

.. warning::
If distributed training is enabled, every GPU will execute a copy of
this callback at the end of every training step. If
``optimizations.average_training_metrics`` is enabled, then the
``metrics`` will be averaged across all GPUs before the callback is
executed. If ``optimizations.average_training_metrics`` is
disabled, then the ``metrics`` will be local to the GPU.
def on_epoch_start(self, epoch_idx: int) -> None:
"""
Run before every epoch begins.
"""
pass

def on_epoch_end(self, epoch_idx: int, metrics: Dict[str, Any]) -> None:
Contributor

blocking: we need to be more descriptive about what these metrics are.

"""
Run after every epoch ends.
Contributor

blocking: We should be more descriptive here about the timing. epoch_end is often taken to mean the end of both training and evaluating on a full dataset, which is not the case for us. It might even be worth renaming this callback to on_train_epoch_end().

Contributor Author

Hadn't thought that would be the expected behavior. Thanks.

"""
pass

@@ -50,15 +54,15 @@ def on_before_optimizer_step(self, parameters: Iterator) -> None:
# TODO(DET-3267): deprecate this when releasing pytorch flexible primitives.
pass

def on_validation_step_start(self) -> None:
def on_validation_start(self) -> None:
"""
Run before every validation step begins.
Run before every validation begins.
"""
pass

def on_validation_step_end(self, metrics: Dict[str, Any]) -> None:
def on_validation_end(self, metrics: Dict[str, Any]) -> None:
"""
Run after every validation step ends.
Run after every validation ends.

.. warning::
This callback only executes on the chief GPU when doing distributed training.
32 changes: 23 additions & 9 deletions harness/determined/pytorch/_pytorch_trial.py
@@ -229,6 +229,12 @@ def run(self) -> None:
def get_epoch_idx(self, batch_id: int) -> int:
return batch_id // len(self.training_loader)

def is_epoch_start(self, batch_id: int) -> bool:
return batch_id % len(self.training_loader) == 0

def is_epoch_end(self, batch_id: int) -> bool:
return batch_id % len(self.training_loader) == len(self.training_loader) - 1
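
As a quick sanity check of the boundary arithmetic above (a reviewer aside, not part of the diff), here is a standalone sketch that assumes a hypothetical training loader of 4 batches per epoch:

```python
# Standalone illustration of the epoch-boundary helpers, assuming a
# hypothetical training loader that yields 4 batches per epoch.
BATCHES_PER_EPOCH = 4

def get_epoch_idx(batch_id: int) -> int:
    return batch_id // BATCHES_PER_EPOCH

def is_epoch_start(batch_id: int) -> bool:
    return batch_id % BATCHES_PER_EPOCH == 0

def is_epoch_end(batch_id: int) -> bool:
    return batch_id % BATCHES_PER_EPOCH == BATCHES_PER_EPOCH - 1

# Batches 0 and 4 open epochs 0 and 1; batches 3 and 7 close them.
assert [b for b in range(8) if is_epoch_start(b)] == [0, 4]
assert [b for b in range(8) if is_epoch_end(b)] == [3, 7]
assert get_epoch_idx(5) == 1
```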

def _average_training_metrics(
self, per_batch_metrics: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
@@ -295,16 +301,20 @@ def _train_for_step(
for model in self.context.models:
model.train()

for callback in self.callbacks.values():
callback.on_train_step_start(step_id)

start = total_batches_processed
end = start + num_batches

per_batch_metrics = [] # type: List[Dict]
num_inputs = 0

for batch_idx in range(start, end):
for callback in self.callbacks.values():
callback.on_batch_start(batch_idx)

if self.is_epoch_start(batch_idx):
for callback in self.callbacks.values():
callback.on_epoch_start(self.get_epoch_idx(batch_idx))

batch = next(self.training_iterator)
num_inputs += data_length(batch)
batch = self.context._to_device(batch)
@@ -342,16 +352,20 @@ def _train_for_step(
check.is_in("loss", tr_metrics, 'Please include "loss" in your training metrics.')
per_batch_metrics.append(tr_metrics)

for callback in self.callbacks.values():
callback.on_batch_end(batch_idx, tr_metrics)

if self.is_epoch_end(batch_idx):
for callback in self.callbacks.values():
callback.on_epoch_end(self.get_epoch_idx(batch_idx), tr_metrics)

# Aggregate and reduce training metrics from all the training processes.
if self.hvd_config.use and self.hvd_config.average_training_metrics:
per_batch_metrics = self._average_training_metrics(per_batch_metrics)
if self.hvd_config.use:
num_inputs *= hvd.size()
metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

for callback in self.callbacks.values():
callback.on_train_step_end(step_id, metrics)

if not self.is_chief:
# The training metrics are reported only in the chief process.
return workload.Skipped()
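
For reviewers, the hook ordering the loop above produces, traced as a simplified sketch rather than the actual harness code. It assumes a hypothetical 2-batch training loader with one training step covering batches 0-1; note that on_batch_start currently fires before on_epoch_start, and on_batch_end before on_epoch_end.

```python
# Illustrative trace of the callback ordering (assumptions: 2-batch loader,
# one step covering batches 0 and 1). Mirrors the loop structure above.
calls = []

BATCHES = 2
for batch_idx in range(BATCHES):
    calls.append(("on_batch_start", batch_idx))
    if batch_idx % BATCHES == 0:                      # is_epoch_start
        calls.append(("on_epoch_start", batch_idx // BATCHES))
    calls.append(("train_batch", batch_idx))
    calls.append(("on_batch_end", batch_idx))
    if batch_idx % BATCHES == BATCHES - 1:            # is_epoch_end
        calls.append(("on_epoch_end", batch_idx // BATCHES))

assert calls == [
    ("on_batch_start", 0),
    ("on_epoch_start", 0),
    ("train_batch", 0),
    ("on_batch_end", 0),
    ("on_batch_start", 1),
    ("train_batch", 1),
    ("on_batch_end", 1),
    ("on_epoch_end", 0),
]
```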
@@ -375,7 +389,7 @@ def _compute_validation_metrics(self) -> workload.Response:
model.eval()

for callback in self.callbacks.values():
callback.on_validation_step_start()
callback.on_validation_start()

num_inputs = 0
metrics = {} # type: Optional[Dict[str, Any]]
@@ -436,7 +450,7 @@ def _compute_validation_metrics(self) -> workload.Response:

if self.hvd_config.use and any(
map(
lambda c: util.is_overridden(c.on_validation_step_end, _callback.PyTorchCallback),
lambda c: util.is_overridden(c.on_validation_end, _callback.PyTorchCallback),
self.callbacks.values(),
)
):
@@ -447,7 +461,7 @@ def _compute_validation_metrics(self) -> workload.Response:
metrics = hvd.broadcast_object(metrics, root_rank=0)

for callback in self.callbacks.values():
callback.on_validation_step_end(cast(Dict[str, Any], metrics))
callback.on_validation_end(cast(Dict[str, Any], metrics))

if not self.is_chief:
return workload.Skipped()
24 changes: 16 additions & 8 deletions harness/tests/experiment/fixtures/pytorch_xor_model.py
@@ -303,22 +303,30 @@ def evaluation_reducer(self) -> Dict[str, det.pytorch.Reducer]:

class Counter(det.pytorch.PyTorchCallback):
def __init__(self) -> None:
self.train_steps_started = 0
self.train_steps_ended = 0
self.batches_started = 0
self.batches_ended = 0
self.epochs_started = 0
self.epochs_ended = 0
self.validation_steps_started = 0
self.validation_steps_ended = 0
self.checkpoints_ended = 0

def on_train_step_start(self, step_id: int) -> None:
self.train_steps_started += 1
def on_batch_start(self, step_id: int) -> None:
self.batches_started += 1

def on_train_step_end(self, step_id: int, metrics: Dict[str, Any]) -> None:
self.train_steps_ended += 1
def on_batch_end(self, step_id: int, metrics: Dict[str, Any]) -> None:
self.batches_ended += 1

def on_validation_step_start(self) -> None:
def on_epoch_start(self, step_id: int) -> None:
self.epochs_started += 1

def on_epoch_end(self, step_id: int, metrics: Dict[str, Any]) -> None:
self.epochs_ended += 1

def on_validation_start(self) -> None:
self.validation_steps_started += 1

def on_validation_step_end(self, metrics: Dict[str, Any]) -> None:
def on_validation_end(self, metrics: Dict[str, Any]) -> None:
self.validation_steps_ended += 1

def on_checkpoint_end(self, checkpoint_dir: str):
24 changes: 16 additions & 8 deletions harness/tests/experiment/pytorch/test_pytorch_trial.py
@@ -452,26 +452,32 @@ def test_callbacks(self, tmp_path: pathlib.Path) -> None:
)
controller._train_for_step(1, 1, 0)
assert controller.trial.counter.__dict__ == {
"train_steps_started": 1,
"train_steps_ended": 1,
"batches_started": 1,
"batches_ended": 1,
"epochs_started": 1,
"epochs_ended": 1,
"validation_steps_started": 0,
"validation_steps_ended": 0,
"checkpoints_ended": 0,
}

controller._compute_validation_metrics()
assert controller.trial.counter.__dict__ == {
"train_steps_started": 1,
"train_steps_ended": 1,
"batches_started": 1,
"batches_ended": 1,
"epochs_started": 1,
"epochs_ended": 1,
"validation_steps_started": 1,
"validation_steps_ended": 1,
"checkpoints_ended": 0,
}

controller._save(checkpoint_dir)
assert controller.trial.counter.__dict__ == {
"train_steps_started": 1,
"train_steps_ended": 1,
"batches_started": 1,
"batches_ended": 1,
"epochs_started": 1,
"epochs_ended": 1,
"validation_steps_started": 1,
"validation_steps_ended": 1,
"checkpoints_ended": 1,
@@ -487,8 +493,10 @@ def test_callbacks(self, tmp_path: pathlib.Path) -> None:
)
controller._load()
assert controller.trial.counter.__dict__ == {
"train_steps_started": 1,
"train_steps_ended": 1,
"batches_started": 1,
"batches_ended": 1,
"epochs_started": 1,
"epochs_ended": 1,
"validation_steps_started": 1,
"validation_steps_ended": 1,
"checkpoints_ended": 0,