feat: support gradient clipping in PyTorchTrial via callbacks (#615)
Breaking Change: we no longer accept gradient clipping as a special
hyperparameter for PyTorchTrial.
aaron276h authored Jun 3, 2020
1 parent 80e39d0 commit 25e725e
Showing 6 changed files with 110 additions and 51 deletions.
12 changes: 12 additions & 0 deletions docs/reference/api/pytorch.txt
@@ -132,6 +132,18 @@ Then, implement the ``build_callbacks`` function in ``PyTorchTrial``:
def build_callbacks(self):
return {"reduce_lr": ReduceLROnPlateauEveryValidationStep(self.context)}


``Gradient Clipping``
^^^^^^^^^^^^^^^^^^^^^

To perform gradient clipping, Determined provides two pre-made callback classes:

.. autoclass:: determined.pytorch.ClipGradsL2Norm
:members:

.. autoclass:: determined.pytorch.ClipGradsL2Value
:members:
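
For example, a ``build_callbacks`` implementation that enables norm-based clipping
might look like the following. This is a minimal sketch: the ``clip_grad_l2_norm``
hyperparameter name is illustrative, and ``ClipGradsL2Norm`` is assumed to be
imported from ``determined.pytorch``.

.. code:: python

    def build_callbacks(self):
        # Clip gradients to a maximum L2 norm taken from a hyperparameter
        # (the hyperparameter name here is illustrative, not required).
        clip_value = self.context.get_hparam("clip_grad_l2_norm")
        return {"clip_grads": ClipGradsL2Norm(clip_value)}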

Examples
--------

2 changes: 1 addition & 1 deletion harness/determined/pytorch/__init__.py
@@ -9,7 +9,7 @@
data_length,
to_device,
)
from determined.pytorch._callback import PyTorchCallback
from determined.pytorch._callback import PyTorchCallback, ClipGradsL2Norm, ClipGradsL2Value
from determined.pytorch._lr_scheduler import LRScheduler, _LRHelper
from determined.pytorch._reducer import Reducer, _reduce_metrics
from determined.pytorch._pytorch_context import PyTorchTrialContext
38 changes: 37 additions & 1 deletion harness/determined/pytorch/_callback.py
@@ -1,4 +1,6 @@
from typing import Any, Dict
from typing import Any, Dict, Iterator

import torch


class PyTorchCallback:
@@ -39,6 +41,14 @@ def on_train_step_end(self, step_id: int, metrics: Dict[str, Any]) -> None:
"""
pass

def on_before_optimizer_step(self, parameters: Iterator) -> None:
"""
Run before every `optimizer.step()`. For multi-GPU training, this executes
after gradient updates have been communicated. Typically used to perform
gradient clipping.
"""
pass

def on_validation_step_start(self) -> None:
"""
Run before every validation step begins.
@@ -75,3 +85,29 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
Load the state of this callback using the deserialized ``state_dict``.
"""
pass


class ClipGradsL2Norm(PyTorchCallback):
"""
Callback that performs gradient clipping using
`L2 Norm <https://pytorch.org/docs/stable/nn.html#clip-grad-norm>`_.
"""

def __init__(self, clip_value: float) -> None:
self._clip_value = clip_value

def on_before_optimizer_step(self, parameters: Iterator) -> None:
torch.nn.utils.clip_grad_norm_(parameters, self._clip_value) # type: ignore


class ClipGradsL2Value(PyTorchCallback):
"""
Callback that performs gradient clipping using
`L2 Value <https://pytorch.org/docs/stable/nn.html#clip-grad-value>`_.
"""

def __init__(self, clip_value: float) -> None:
self._clip_value = clip_value

def on_before_optimizer_step(self, parameters: Iterator) -> None:
torch.nn.utils.clip_grad_value_(parameters, self._clip_value) # type: ignore
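
As a point of comparison, a custom callback can hook the same
``on_before_optimizer_step`` event. The sketch below is not part of this commit;
the ``LogGradNorm`` name is illustrative, and it simply logs the total gradient
L2 norm before each optimizer step:

import logging
from typing import Iterator

import torch

from determined.pytorch import PyTorchCallback


class LogGradNorm(PyTorchCallback):
    """Hypothetical callback: log the total L2 norm of all gradients."""

    def on_before_optimizer_step(self, parameters: Iterator) -> None:
        # Collect gradients that exist for this step.
        grads = [p.grad.detach() for p in parameters if p.grad is not None]
        if not grads:
            return
        # Total L2 norm over all parameter gradients.
        total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2).item()
        logging.info("gradient L2 norm before optimizer.step(): %.4f", total_norm)

Note that if ``parameters`` is a lazy iterator, a callback that consumes it (as
the built-in clipping callbacks also do) may exhaust it for callbacks that run
after it, so it is safest to register only one callback that needs the gradients.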
28 changes: 9 additions & 19 deletions harness/determined/pytorch/_pytorch_trial.py
@@ -288,24 +288,6 @@ def _average_gradients(parameters: Any, divisor: int) -> None:
for p in filter(lambda param: param.grad is not None, parameters):
p.grad.data.div_(divisor_value)

def _clip_grads(self, parameters: Any) -> None:
# TODO: Support clip by norm other than L2.
clip_grad_l2_norm = self.env.hparams.get("clip_grad_l2_norm", None)
clip_by_val = self.env.hparams.get("clip_grad_val", None)
check.false(
clip_grad_l2_norm is not None and clip_by_val is not None,
"Please specify either `clip_grad_l2_norm` or `clip_by_val` "
"in your hparams, not both.",
)
if clip_grad_l2_norm is not None:
logging.debug(f"Clipping gradients by L2 norm of: {clip_grad_l2_norm}.")
torch.nn.utils.clip_grad_norm_(parameters, clip_grad_l2_norm) # type: ignore
elif clip_by_val is not None:
logging.debug(f"Clipping gradients by value of: {clip_by_val}.")
torch.nn.utils.clip_grad_value_(parameters, clip_by_val) # type: ignore
else:
logging.debug("No gradient clipping enabled.")

def _average_training_metrics(
self, per_batch_metrics: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
@@ -416,7 +398,15 @@ def _train_for_step(self, step_id: int, batches_per_step: int) -> workload.Respo
parameters=parameters, divisor=self.hvd_config.aggregation_frequency
)

self._clip_grads(parameters)
# TODO: Remove this check in v0.12.8.
check.false(
self.env.hparams.get("clip_grad_l2_norm", None)
or self.env.hparams.get("clip_grad_val", None),
"Please specify gradient clipping via callbacks.",
)

for callback in self.callbacks.values():
callback.on_before_optimizer_step(parameters)

if self.hvd_config.use:
with self.context.optimizer.skip_synchronize():
73 changes: 47 additions & 26 deletions harness/tests/experiment/fixtures/pytorch_xor_model.py
@@ -6,7 +6,7 @@
from torch.utils.data import TensorDataset

import determined as det
from determined.pytorch import DataLoader, LRScheduler, PyTorchTrial, TorchData, reset_parameters
from determined import pytorch
from determined_common import check


@@ -51,22 +51,22 @@ def __init__(self, context):
nn.Linear(context.get_hparam("hidden_size"), 1),
nn.Sigmoid(),
)
reset_parameters(self.main_net)
pytorch.reset_parameters(self.main_net)

def forward(self, model_input: Any):
return self.main_net(model_input)


def xor_data_loader(batch_size: int) -> DataLoader:
def xor_data_loader(batch_size: int) -> pytorch.DataLoader:
training_data = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
training_data = torch.Tensor(training_data)
training_labels = np.array([0, 1, 1, 0], dtype=np.float32)
training_labels = torch.Tensor(training_labels)
training = TensorDataset(training_data, training_labels)
return DataLoader(training, batch_size=batch_size)
return pytorch.DataLoader(training, batch_size=batch_size)


class BaseXORTrial(PyTorchTrial):
class BaseXORTrial(pytorch.PyTorchTrial):
"""
Models a lightweight neural network model with one hidden layer to
learn a binary XOR function. See Deep Learning Book, chapter 6.1 for
@@ -85,23 +85,23 @@ def optimizer(self, model: nn.Module) -> torch.optim.Optimizer:
return torch.optim.SGD(model.parameters(), self.context.get_hparam("learning_rate"))

def train_batch(
self, batch: TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
self, batch: pytorch.TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
) -> Dict[str, torch.Tensor]:
data, labels = batch
output = model(data)
loss = torch.nn.functional.binary_cross_entropy(output, labels.view(-1, 1))

return {"loss": loss}

def build_training_data_loader(self) -> DataLoader:
def build_training_data_loader(self) -> pytorch.DataLoader:
return xor_data_loader(self.context.get_per_slot_batch_size())

def build_validation_data_loader(self) -> DataLoader:
def build_validation_data_loader(self) -> pytorch.DataLoader:
return xor_data_loader(self.context.get_per_slot_batch_size())


class XORTrial(BaseXORTrial):
def evaluate_batch(self, batch: TorchData, model: nn.Module) -> Dict[str, Any]:
def evaluate_batch(self, batch: pytorch.TorchData, model: nn.Module) -> Dict[str, Any]:
data, labels = batch
output = model(data)
loss = error_rate(output, labels)
@@ -134,15 +134,15 @@ def build_model(self) -> nn.Module:
return XORNetMulti(self.context)

def train_batch(
self, batch: TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
self, batch: pytorch.TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
) -> Dict[str, torch.Tensor]:
data, labels = batch
output = model(data)
loss = nn.functional.binary_cross_entropy(output["output"], labels.view(-1, 1))

return {"loss": loss}

def evaluate_batch(self, batch: TorchData, model: nn.Module) -> Dict[str, Any]:
def evaluate_batch(self, batch: pytorch.TorchData, model: nn.Module) -> Dict[str, Any]:
data, labels = batch
output = model(data)
error = binary_error_rate(output["output"], labels)
@@ -152,7 +152,7 @@ def evaluate_batch(self, batch: TorchData, model: nn.Module) -> Dict[str, Any]:

class XORTrialWithTrainingMetrics(XORTrialMulti):
def train_batch(
self, batch: TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
self, batch: pytorch.TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
) -> Dict[str, torch.Tensor]:
data, labels = batch
output = model(data)
@@ -164,7 +164,7 @@ def train_batch(


class XORTrialWithMultiValidation(XORTrialMulti):
def evaluate_batch(self, batch: TorchData, model: nn.Module) -> Dict[str, Any]:
def evaluate_batch(self, batch: pytorch.TorchData, model: nn.Module) -> Dict[str, Any]:
data, labels = batch
output = model(data)
accuracy = error_rate(output["output"], labels)
@@ -173,7 +173,7 @@ def evaluate_batch(self, batch: TorchData, model: nn.Module) -> Dict[str, Any]:
return {"accuracy": accuracy, "binary_error": binary_error}


class XORTrialWithNonScalarValidation(PyTorchTrial):
class XORTrialWithNonScalarValidation(pytorch.PyTorchTrial):
def __init__(self, context: det.TrialContext) -> None:
self.context = context

@@ -183,14 +183,14 @@ def build_model(self) -> nn.Module:
def optimizer(self, model: nn.Module) -> torch.optim.Optimizer:
return torch.optim.SGD(model.parameters(), self.context.get_hparam("learning_rate"))

def build_training_data_loader(self) -> DataLoader:
def build_training_data_loader(self) -> pytorch.DataLoader:
return xor_data_loader(self.context.get_per_slot_batch_size())

def build_validation_data_loader(self) -> DataLoader:
def build_validation_data_loader(self) -> pytorch.DataLoader:
return xor_data_loader(self.context.get_per_slot_batch_size())

def train_batch(
self, batch: TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
self, batch: pytorch.TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
) -> Dict[str, torch.Tensor]:
data, labels = batch
output = model(data)
@@ -244,16 +244,20 @@ def set_lr(self, lr: float) -> None:
class XORTrialStepEveryEpoch(XORTrialMulti):
def create_lr_scheduler(self, optimizer):
self.scheduler = ModifyableLRSchedule(optimizer)
return LRScheduler(self.scheduler, step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH)
return pytorch.LRScheduler(
self.scheduler, step_mode=pytorch.LRScheduler.StepMode.STEP_EVERY_EPOCH
)


class XORTrialRestoreLR(XORTrialMulti):
def create_lr_scheduler(self, optimizer):
self.scheduler = ModifyableLRSchedule(optimizer)
return LRScheduler(self.scheduler, step_mode=LRScheduler.StepMode.STEP_EVERY_BATCH)
return pytorch.LRScheduler(
self.scheduler, step_mode=pytorch.LRScheduler.StepMode.STEP_EVERY_BATCH
)

def train_batch(
self, batch: TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
self, batch: pytorch.TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
) -> Dict[str, torch.Tensor]:
metrics = super().train_batch(batch, model, epoch_idx, batch_idx)
lr = self.scheduler.get_last_lr()[0]
@@ -265,10 +269,12 @@ def train_batch(
class XORTrialUserStepLRFail(XORTrialMulti):
def create_lr_scheduler(self, optimizer):
self.scheduler = ModifyableLRSchedule(optimizer)
return LRScheduler(self.scheduler, step_mode=LRScheduler.StepMode.STEP_EVERY_BATCH)
return pytorch.LRScheduler(
self.scheduler, step_mode=pytorch.LRScheduler.StepMode.STEP_EVERY_BATCH
)

def train_batch(
self, batch: TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
self, batch: pytorch.TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
) -> Dict[str, torch.Tensor]:
metrics = super().train_batch(batch, model, epoch_idx, batch_idx)
self.scheduler.step()
@@ -278,10 +284,12 @@ def train_batch(
class XORTrialUserStepLR(XORTrialMulti):
def create_lr_scheduler(self, optimizer):
self.scheduler = ModifyableLRSchedule(optimizer)
return LRScheduler(self.scheduler, step_mode=LRScheduler.StepMode.MANUAL_STEP)
return pytorch.LRScheduler(
self.scheduler, step_mode=pytorch.LRScheduler.StepMode.MANUAL_STEP
)

def train_batch(
self, batch: TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
self, batch: pytorch.TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
) -> Dict[str, torch.Tensor]:
metrics = super().train_batch(batch, model, epoch_idx, batch_idx)
self.scheduler.step()
@@ -334,17 +342,30 @@ def build_callbacks(self) -> Dict[str, det.pytorch.PyTorchCallback]:

class XORTrialAccessContext(XORTrialStepEveryEpoch):
def train_batch(
self, batch: TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
self, batch: pytorch.TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
) -> Dict[str, torch.Tensor]:
assert self.context.get_model()
assert self.context.get_optimizer()
assert self.context.get_lr_scheduler()

return super().train_batch(batch, model, epoch_idx, batch_idx)

def evaluate_batch(self, batch: TorchData, model: nn.Module) -> Dict[str, Any]:
def evaluate_batch(self, batch: pytorch.TorchData, model: nn.Module) -> Dict[str, Any]:
assert self.context.get_model()
assert self.context.get_optimizer()
assert self.context.get_lr_scheduler()

return super().evaluate_batch(batch, model)


class XORTrialGradClipping(XORTrial):
def build_callbacks(self) -> Dict[str, pytorch.PyTorchCallback]:
hparams = self.context.get_hparams()

if "gradient_clipping_l2_norm" in hparams:
return {"grad_clip": pytorch.ClipGradsL2Norm(hparams["gradient_clipping_l2_norm"])}

elif "gradient_clipping_value" in hparams:
return {"grad_clip": pytorch.ClipGradsL2Value(hparams["gradient_clipping_value"])}

return {}
8 changes: 4 additions & 4 deletions harness/tests/experiment/pytorch/test_pytorch_trial.py
@@ -415,9 +415,9 @@ def make_workloads(tag: str) -> workload.Stream:
)
controller.run()

updated_hparams = {"clip_grad_l2_norm": 0.0001, **self.hparams}
updated_hparams = {"gradient_clipping_l2_norm": 0.0001, **self.hparams}
controller = utils.make_trial_controller_from_trial_implementation(
trial_class=pytorch_xor_model.XORTrialMulti,
trial_class=pytorch_xor_model.XORTrialGradClipping,
hparams=updated_hparams,
workloads=make_workloads("clipped_by_norm"),
trial_seed=self.trial_seed,
@@ -431,9 +431,9 @@ def make_workloads(tag: str) -> workload.Stream:
continue
assert original["loss"] != clipped["loss"]

updated_hparams = {"clip_grad_val": 0.0001, **self.hparams}
updated_hparams = {"gradient_clipping_value": 0.0001, **self.hparams}
controller = utils.make_trial_controller_from_trial_implementation(
trial_class=pytorch_xor_model.XORTrialMulti,
trial_class=pytorch_xor_model.XORTrialGradClipping,
hparams=updated_hparams,
workloads=make_workloads("clipped_by_val"),
trial_seed=self.trial_seed,