[Fix] Support gradient accumulation for DDP #173

Merged · 3 commits · Apr 13, 2023
24 changes: 7 additions & 17 deletions oslo/torch/nn/parallel/data_parallel/data_parallel.py
@@ -17,16 +17,7 @@
from oslo.torch.nn.parallel.data_parallel._utils import is_ddp_ignored


def free_storage(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
if data.storage().size() > 0:
# Since we're modifying the Tensor's Storage directly, make sure the Tensor
# is the sole occupant of the Storage.
assert data.storage_offset() == 0
data.storage().resize_(0)


class _BackwardFunction(torch.autograd.Function):
class _DistributedBackwardFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, module, *inputs):
ctx.module = module
@@ -64,7 +55,7 @@ def DistributedDataParallel(


class _DistributedDataParallel(OsloParallelWrapper):
"""Distributed data parallel wrapper for Oslo.
"""Distributed data parallel wrapper for OSLO.
Example:
>>> from oslo.torch.nn.parallel import DistributedDataParallel as DDP
>>> model = torch.nn.Linear(20, 1)
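The docstring's example is cut off by the collapsed diff. As a rough illustration of what this PR enables, here is a hypothetical sketch of running gradient accumulation through the wrapper; the ParallelContext factory, the wrapper's keyword arguments, and the exact oslo.ready() call are assumptions based on identifiers visible in this diff, not a verified API.

# Hypothetical usage sketch (not part of this PR's code).
import torch
import oslo
from oslo.torch.nn.parallel import DistributedDataParallel as DDP

parallel_context = oslo.ParallelContext.from_torch(data_parallel_size=2)  # assumed factory
model = DDP(torch.nn.Linear(20, 1), parallel_context=parallel_context)    # assumed signature
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
oslo.ready(model, parallel_context)  # referenced by the CPU-group comment further down

accumulation_steps = 4
loader = [(torch.randn(8, 20), torch.randn(8, 1)) for _ in range(32)]  # dummy data

for step, (x, y) in enumerate(loader):
    loss = torch.nn.functional.mse_loss(model(x), y) / accumulation_steps
    loss.backward()  # per-parameter hooks stash reduced grads in p._saved_grad
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()       # p.grad was restored from p._saved_grad in _post_backward
        optimizer.zero_grad()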
@@ -117,14 +108,15 @@ def forward(self, *args, **kwargs):
{
k: v
for k, v in zip(
inputs.keys(), _BackwardFunction.apply(self, *inputs.values())
inputs.keys(),
_DistributedBackwardFunction.apply(self, *inputs.values()),
)
}
)

if isinstance(inputs, torch.Tensor):
inputs = (inputs,)
return _BackwardFunction.apply(self, *inputs)
return _DistributedBackwardFunction.apply(self, *inputs)

def _pre_backward(self):
pass
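Routing tensors through a custom autograd.Function such as _DistributedBackwardFunction is the standard trick for giving a wrapper module a callback during the backward pass. A generic, self-contained illustration of the idea follows; it shows the pattern only, not the PR's exact implementation, and the callback name is hypothetical.

# Generic backward-callback pattern (illustrative names only).
import torch

class _BackwardBoundary(torch.autograd.Function):
    @staticmethod
    def forward(ctx, wrapper, *tensors):
        ctx.wrapper = wrapper
        # Returning the tensors through this Function attaches it to the
        # autograd graph, so its backward() becomes a hook point.
        return tensors if len(tensors) > 1 else tensors[0]

    @staticmethod
    def backward(ctx, *grads):
        ctx.wrapper._on_backward()  # hypothetical callback on the wrapper
        # No gradient for the wrapper argument; pass tensor grads through.
        return (None, *grads)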
@@ -138,13 +130,11 @@ def _post_backward(self):
for p in self.module.parameters():
if is_ddp_ignored(p):
continue
if p.grad.device.type != "cpu":
p.grad = p._saved_grad
p.grad = p._saved_grad

def grad_handle(self, p, grad):
if grad.device.type != "cpu":
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
if self.dp_world_size > 1:
grad = grad / self.dp_world_size
self.comm_stream.wait_stream(torch.cuda.current_stream())
@@ -161,7 +151,7 @@ def grad_handle(self, p, grad):
return empty_grad

else:
# You must model.to('cpu') after oslo.ready() to use cpu.
# You must assign the model to CPU after invoking ``oslo.ready()``.
dist.all_reduce(
grad, group=self.parallel_context.get_cpu_group(ParallelMode.DATA)
)
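This hook is where gradient accumulation is made to work with DDP: the reduced gradient is stashed on the parameter as p._saved_grad instead of relying on autograd's .grad, and _post_backward copies it back once the backward pass finishes. A simplified, single-stream sketch of that pattern follows; the register_hook wiring and the in-place accumulation are assumptions, and the real grad_handle additionally uses a dedicated CUDA stream and the parallel context's process groups.

# Simplified sketch of the save-and-restore gradient pattern (not the PR's code).
import torch
import torch.distributed as dist

def make_grad_hook(param, world_size):
    def hook(grad):
        grad = grad / world_size
        dist.all_reduce(grad)  # summing grad / world_size averages across ranks
        if getattr(param, "_saved_grad", None) is None:
            param._saved_grad = grad.clone()   # first micro-batch
        else:
            param._saved_grad.add_(grad)       # accumulate later micro-batches
        # Hand autograd a throwaway gradient; _post_backward later restores
        # param.grad from param._saved_grad.
        return torch.zeros_like(grad)
    return hook

# Wiring sketch:
# for p in model.parameters():
#     p.register_hook(make_grad_hook(p, dist.get_world_size()))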