fix shardformer fp8 communication training degradation
GuangyaoZhang authored and flybird11111 committed Aug 2, 2024
1 parent 176c970 commit 8fb90e1
Showing 1 changed file with 22 additions and 12 deletions.
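Note on the formats chosen below: the patch threads an explicit fp8_format argument through the shardformer communication helpers, with the forward-pass collectives (activation gather / reduce / reduce-scatter) requesting "e4m3" and the backward-pass collectives (gradient gather / reduce) requesting "e5m2". The snippet that follows is not part of the commit; it is a minimal illustration, using PyTorch's built-in fp8 dtypes (torch.float8_e4m3fn and torch.float8_e5m2, available in PyTorch 2.1+), of why the two formats are split this way: e4m3 keeps more mantissa precision but has a narrow range, while e5m2 trades precision for the wider dynamic range that gradients typically need.

import torch

# Illustration only (not from the commit): compare the two fp8 wire formats.
# e4m3 -> 4 exponent bits, 3 mantissa bits: finer precision, max ~448.
# e5m2 -> 5 exponent bits, 2 mantissa bits: coarser precision, max ~57344.
for name, dtype in [("e4m3", torch.float8_e4m3fn), ("e5m2", torch.float8_e5m2)]:
    info = torch.finfo(dtype)
    print(f"{name}: max={info.max}, smallest normal={info.tiny}, eps={info.eps}")

# Gradient-scale values (~1e-4) underflow to zero in e4m3 but survive the
# round trip in e5m2, which is why the backward paths below ask for "e5m2".
g = torch.full((4,), 1e-4)
print(g.to(torch.float8_e4m3fn).to(torch.float32))  # flushed to zero
print(g.to(torch.float8_e5m2).to(torch.float32))    # magnitude preserved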
34 changes: 22 additions & 12 deletions colossalai/shardformer/layer/_operation.py
@@ -95,7 +95,7 @@ def backward(ctx, grad_output):
         total_input = total_input.view(-1, total_input.shape[-1])
 
         if ctx.async_grad_allreduce and fp8_communication:
-            _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication)
+            _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
         elif ctx.async_grad_allreduce:
             # Asynchronous all-reduce
             handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
@@ -566,7 +566,7 @@ def forward(ctx, input_, process_group, dim, fp8_communication=False):
         input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
         output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
         if fp8_communication:
-            reduce_scatter_fp8(output, input_list, group=process_group)
+            reduce_scatter_fp8(output, input_list, group=process_group, fp8_format="e4m3")
         else:
             dist.reduce_scatter(output, input_list, group=process_group)
 
@@ -577,7 +577,12 @@ def backward(ctx, grad_output):
         dim = ctx.dim
         process_group = ctx.process_group
         fp8_communication = ctx.fp8_communication
-        return _gather(grad_output, dim, process_group, fp8_communication=fp8_communication), None, None, None
+        return (
+            _gather(grad_output, dim, process_group, fp8_communication=fp8_communication, fp8_format="e5m2"),
+            None,
+            None,
+            None,
+        )
 
 
 class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -618,7 +623,7 @@ def forward(
             )
 
         else:
-            input_parallel = _gather(input_, dim, process_group, fp8_communication)
+            input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
 
         output = torch.matmul(input_parallel, weight)
 
@@ -641,7 +646,7 @@ def backward(ctx, grad_output):
             bias = bias.view(bias.shape)
 
         if not overlap:
-            input_parallel = _gather(input_, dim, process_group, fp8_communication)
+            input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
 
             total_input = input_parallel
             grad_input = grad_output.matmul(weight.T)
@@ -728,8 +733,13 @@ def backward(ctx, grad_output):
         if ctx.grad_scale is not None:
             grad_output = grad_output * ctx.grad_scale
 
-        # to_cast.append(grad_output.cpu().detach().numpy())
-        return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None
+        return (
+            _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, fp8_format="e5m2"),
+            None,
+            None,
+            None,
+            None,
+        )
 
 
 class _ReduceForward(torch.autograd.Function):
@@ -743,7 +753,7 @@ class _ReduceForward(torch.autograd.Function):
 
     @staticmethod
    def forward(ctx, input_, process_group, fp8_communication=False):
-        return _reduce(input_, process_group, fp8_communication)
+        return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -768,7 +778,7 @@ def forward(ctx, input_, process_group, fp8_communication=False):
     @staticmethod
     def backward(ctx, grad_output):
         fp8_communication = ctx.fp8_communication
-        return _reduce(grad_output, ctx.process_group, fp8_communication), None, None
+        return _reduce(grad_output, ctx.process_group, fp8_communication, fp8_format="e5m2"), None, None
 
 
 class _GatherForwardSplitBackward(torch.autograd.Function):
@@ -786,7 +796,7 @@ def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=
         ctx.dim = dim
         ctx.grad_scale = grad_scale
 
-        return _gather(input_, dim, process_group, fp8_communication=fp8_communication)
+        return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -851,13 +861,13 @@ def hook_parameter_in_backward(input, weight=None, bias=None):
     return HookParameter.apply(input, weight, bias)
 
 
-def _reduce(input_, process_group, fp8_communication=False):
+def _reduce(input_, process_group, fp8_communication=False, fp8_format="e5m2"):
     # skip if only one rank involved
     if dist.get_world_size(process_group) == 1:
         return input_
     else:
         if fp8_communication:
-            all_reduce_fp8(input_, group=process_group)
+            all_reduce_fp8(input_, group=process_group, fp8_format=fp8_format)
         else:
             dist.all_reduce(input_, group=process_group)
         return input_
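For readers unfamiliar with fp8-compressed collectives, the sketch below shows only the general idea that fp8_format selects the wire dtype for the payload. It is a hypothetical single-process simulation, not the actual all_reduce_fp8 / reduce_scatter_fp8 implementation in colossalai.quantization.fp8; the compress helper, the per-tensor scaling, and simulated_all_reduce are all assumptions made for illustration.

import torch

# Hypothetical sketch: quantize each "rank"'s tensor to an fp8 wire format
# with a per-tensor scale, exchange the fp8 payloads, and accumulate the sum
# in float32 on the receiving side. This is NOT ColossalAI's implementation.
FP8_DTYPES = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}

def compress(t: torch.Tensor, fp8_format: str):
    dtype = FP8_DTYPES[fp8_format]
    scale = t.abs().max().clamp(min=1e-12) / torch.finfo(dtype).max
    return (t / scale).to(dtype), scale  # fp8 payload plus its float scale

def simulated_all_reduce(tensors, fp8_format="e5m2"):
    # Decompress to float32 before summing; fp8 values are never added directly.
    return sum(q.to(torch.float32) * s for q, s in (compress(t, fp8_format) for t in tensors))

grads = [torch.randn(8) * 1e-3 for _ in range(4)]  # gradient-like magnitudes
exact = sum(grads)
approx = simulated_all_reduce(grads, fp8_format="e5m2")
print((exact - approx).abs().max())  # small quantization error

Under this kind of scheme, activations can afford the finer-grained "e4m3" encoding while gradients keep "e5m2" so that elements far below the per-tensor maximum are not flushed to zero, matching the usual FP8 training convention and the format choices in the diff above.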