Skip to content

Commit

Permalink
support all2all fp8
Browse files Browse the repository at this point in the history
  • Loading branch information
flybird11111 committed Jul 30, 2024
1 parent 5fd0592 commit 9043fba
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions colossalai/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,42 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
tensor_out = torch.cat(tensor_list, dim=0)
tensor.data = tensor_out.view(input_shape).to(input_type)

def all_to_all_single_fp8(output, input, output_tensor_list, input_tensor_list, fp8_format="e5m2", group=None, async_op=False) -> None:
r"""
This is an in-place operation for compressed all_reduce using fp8.
It works like dist.all_reduce but during communication the data is cast to fp8 format.
Args:
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
fp8_format: e4m3 or e5m2
Returns:
None
"""

world_size = dist.get_world_size(group=group)
input_type = input.dtype
input_shape = input.shape
input_device = input.device
input = input.flatten()

fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2

ret, scale = cast_to_fp8(input, fp8_format=fp8_format)

inp = ret.view(torch.uint8)
input_chunks = torch.split(inp, input_tensor_list)

output_chunks = [torch.empty((output_tensor_list[i]*np.prod(input_shape[1:]),), device=input_device, dtype=input_type) for i in range(world_size)]

dist.all_to_all(output_chunks, input_chunks, group=group)
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
dist.all_gather(scale_list, scale, group=group)
for scale, out in zip(scale_list, output_chunks):
out = out.view(fp8_type)
out = cast_from_fp8(out, scale, input_type)

tensor_out = torch.cat(output_chunks, dim=0)
output.data = tensor_out.to(input_type)


def cast_to_fp8_pipeline(inp: Any) -> None:
"""
Expand Down

0 comments on commit 9043fba

Please sign in to comment.