diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 0f5dfa0feea4..fb3978276ade 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -95,7 +95,7 @@ def backward(ctx, grad_output):
         total_input = total_input.view(-1, total_input.shape[-1])
 
         if ctx.async_grad_allreduce and fp8_communication:
-            _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication)
+            _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
         elif ctx.async_grad_allreduce:
             # Asynchronous all-reduce
             handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
@@ -566,7 +566,7 @@ def forward(ctx, input_, process_group, dim, fp8_communication=False):
         input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
         output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
         if fp8_communication:
-            reduce_scatter_fp8(output, input_list, group=process_group)
+            reduce_scatter_fp8(output, input_list, group=process_group, fp8_format="e4m3")
         else:
             dist.reduce_scatter(output, input_list, group=process_group)
 
@@ -577,7 +577,12 @@ def backward(ctx, grad_output):
         dim = ctx.dim
         process_group = ctx.process_group
         fp8_communication = ctx.fp8_communication
-        return _gather(grad_output, dim, process_group, fp8_communication=fp8_communication), None, None, None
+        return (
+            _gather(grad_output, dim, process_group, fp8_communication=fp8_communication, fp8_format="e5m2"),
+            None,
+            None,
+            None,
+        )
 
 
 class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -618,7 +623,7 @@ def forward(
             )
 
         else:
-            input_parallel = _gather(input_, dim, process_group, fp8_communication)
+            input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
 
             output = torch.matmul(input_parallel, weight)
 
@@ -641,7 +646,7 @@ def backward(ctx, grad_output):
             bias = bias.view(bias.shape)
 
         if not overlap:
-            input_parallel = _gather(input_, dim, process_group, fp8_communication)
+            input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
 
             total_input = input_parallel
             grad_input = grad_output.matmul(weight.T)
@@ -728,8 +733,13 @@ def backward(ctx, grad_output):
         if ctx.grad_scale is not None:
             grad_output = grad_output * ctx.grad_scale
 
-        # to_cast.append(grad_output.cpu().detach().numpy())
-        return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None
+        return (
+            _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, fp8_format="e5m2"),
+            None,
+            None,
+            None,
+            None,
+        )
 
 
 class _ReduceForward(torch.autograd.Function):
@@ -743,7 +753,7 @@ class _ReduceForward(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, input_, process_group, fp8_communication=False):
-        return _reduce(input_, process_group, fp8_communication)
+        return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -768,7 +778,7 @@ def forward(ctx, input_, process_group, fp8_communication=False):
     @staticmethod
     def backward(ctx, grad_output):
         fp8_communication = ctx.fp8_communication
-        return _reduce(grad_output, ctx.process_group, fp8_communication), None, None
+        return _reduce(grad_output, ctx.process_group, fp8_communication, fp8_format="e5m2"), None, None
 
 
 class _GatherForwardSplitBackward(torch.autograd.Function):
@@ -786,7 +796,7 @@ def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=
         ctx.dim = dim
         ctx.grad_scale = grad_scale
 
-        return _gather(input_, dim, process_group, fp8_communication=fp8_communication)
+        return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -851,13 +861,13 @@ def hook_parameter_in_backward(input, weight=None, bias=None):
     return HookParameter.apply(input, weight, bias)
 
 
-def _reduce(input_, process_group, fp8_communication=False):
+def _reduce(input_, process_group, fp8_communication=False, fp8_format="e5m2"):
     # skip if only one rank involved
     if dist.get_world_size(process_group) == 1:
         return input_
     else:
         if fp8_communication:
-            all_reduce_fp8(input_, group=process_group)
+            all_reduce_fp8(input_, group=process_group, fp8_format=fp8_format)
         else:
             dist.all_reduce(input_, group=process_group)
         return input_
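
Note on the format choice: the hunks consistently pass fp8_format="e4m3" to the forward-path collectives (gathers/reductions of activations) and fp8_format="e5m2" to the backward-path collectives on gradients. e4m3 spends its bits on the mantissa (finer precision, max value 448), while e5m2 spends them on the exponent (coarser precision, max value 57344), which better tolerates the wide dynamic range of gradients. The sketch below is an illustration only, not ColossalAI code; it does not call the fp8 communication primitives from the patch and assumes PyTorch >= 2.1 for the float8 dtypes.

import torch

def fp8_roundtrip(x: torch.Tensor, fp8_format: str = "e4m3") -> torch.Tensor:
    # Cast to the chosen FP8 dtype and back, mimicking the precision loss an
    # FP8 collective introduces on the wire.
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    return x.to(fp8_dtype).to(x.dtype)

# Dynamic-range difference between the two formats.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0

# Forward-style tensor (moderate magnitudes) vs. gradient-style tensor (small magnitudes).
activations = torch.randn(8)
gradients = torch.randn(8) * 1e-4
print((activations - fp8_roundtrip(activations, "e4m3")).abs().max())
print((gradients - fp8_roundtrip(gradients, "e5m2")).abs().max())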