Commit

fix: correctly use mixed precision with multi-GPU in PyTorchTrial [DET-3285] (#699)

* fix: correctly support mixed precision for multi-GPU PyTorchTrial
* test: re-enable AMP parallel test

When performing mixed precision multi-GPU training, we need to finalize
gradient update communication prior to unscaling.
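
For context, a minimal sketch of the ordering this change enforces when apex AMP is combined with a Horovod DistributedOptimizer. This is not the trial code itself; `optimizer` and `loss` here are stand-ins for an apex-initialized optimizer and a scalar loss:

    from apex import amp

    def backward_with_amp_and_horovod(optimizer, loss, distributed=True):
        # Scale the loss so FP16 gradients do not underflow; apex unscales
        # the gradients when the context manager exits.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
            if distributed:
                # Finish Horovod's gradient allreduce *before* unscaling,
                # i.e. before this context manager exits. Synchronizing only
                # after the `with` block would unscale gradients that are
                # still being communicated.
                optimizer.synchronize()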
aaron276h authored Jun 12, 2020
1 parent 5aa0eb3 commit 8f7b68e
Showing 2 changed files with 9 additions and 5 deletions.
e2e_tests/tests/experiment/test_pytorch.py (3 changes: 0 additions & 3 deletions)
@@ -99,9 +99,6 @@ def test_pytorch_const_parallel(aggregation_frequency: int, use_amp: bool) -> None:
     if use_amp and aggregation_frequency > 1:
         pytest.skip("Mixed precision is not support with aggregation frequency > 1.")
 
-    if use_amp:
-        pytest.skip("AMP support NaNs right now, disabling until this is fixed.")
-
     config = conf.load_config(conf.official_examples_path("trial/mnist_pytorch/const.yaml"))
     config = conf.set_slots_per_trial(config, 8)
     config = conf.set_native_parallel(config, False)
harness/determined/pytorch/_pytorch_trial.py (11 changes: 9 additions & 2 deletions)
@@ -389,13 +389,20 @@ def _train_for_step(self, step_id: int, batches_per_step: int) -> workload.Response:
         if self.use_amp():
             with apex.amp.scale_loss(loss, self.context.optimizer) as scaled_loss:
                 scaled_loss.backward()
+                if self.hvd_config.use and communicate_and_update:
+                    # When using horovod, we need to finish communicating gradient
+                    # updates before they are unscaled which happens when we exit
+                    # of this context manager.
+                    self.context.optimizer.synchronize()
         else:
             loss.backward()
 
-        if communicate_and_update:
-            if self.hvd_config.use:
+            # Communication needs to be synchronized so that is completed
+            # before we apply gradient clipping and `step()`.
+            if communicate_and_update and self.hvd_config.use:
                 self.context.optimizer.synchronize()
 
         if communicate_and_update:
             parameters = (
                 self.context.model.parameters()
                 if not self.use_amp()
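
As a rough sketch of what follows this hunk (hedged; `model`, `optimizer`, and the flags below are illustrative stand-ins, not the trial's attributes), gradient clipping and `step()` then operate on the already-synchronized gradients, clipping `amp.master_params` when mixed precision is enabled:

    import torch
    from apex import amp

    def clip_and_step(model, optimizer, use_amp, use_hvd, clip_norm=1.0):
        # Under AMP, clip the FP32 master gradients; otherwise clip the
        # model's own gradients.
        parameters = model.parameters() if not use_amp else amp.master_params(optimizer)
        torch.nn.utils.clip_grad_norm_(parameters, clip_norm)
        if use_hvd:
            # Gradients were already synchronized during backward, so tell
            # the Horovod DistributedOptimizer not to synchronize again
            # inside step().
            with optimizer.skip_synchronize():
                optimizer.step()
        else:
            optimizer.step()
        optimizer.zero_grad()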
