[fp8] support all-gather flat tensor #5932

Merged
76 changes: 76 additions & 0 deletions colossalai/quantization/fp8.py
@@ -1,5 +1,6 @@
from typing import Any

import numpy as np
import torch
import torch.distributed as dist

@@ -202,3 +203,78 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
output.data = summed_out


def split_chunk_by_channel(
chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1
):
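"""Split a rank's flat chunk at the channel (row) boundaries it spans.

Returns the chunk as a tuple of sub-tensors, one per (possibly partial) channel covered by this rank.
"""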
offset = chunk.numel() * rank
end = offset + chunk.numel()
break_points = [x for x in range(0, channel_size * num_channels + 1, channel_size) if offset <= x <= end]
if len(break_points) == 0 or break_points[0] > offset:
break_points.insert(0, offset)
if break_points[-1] < end:
break_points.append(end)
sizes = [b - a for a, b in zip(break_points[:-1], break_points[1:])]
return chunk.split(sizes)


def all_gather_into_tensor_flat_fp8(
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
output_shape: torch.Size,
group: dist.ProcessGroup,
fp8_format: str = "e4m3",
):
"""all gather into tensor in fp8 format

Args:
output_tensor (torch.Tensor): output tensor, which is flattened
input_tensor (torch.Tensor): input tensor, which is flattened
group (dist.ProcessGroup): process group
fp8_format (str, optional): fp8 format, e4m3 or e5m2. Defaults to "e4m3".
"""
assert input_tensor.dim() == 1 and output_tensor.dim() == 1, "input/output tensor should be flattened"
world_size = dist.get_world_size(group)
assert (
output_tensor.numel() == input_tensor.numel() * world_size
), "output tensor size should be world_size times the input tensor size"

input_type = output_tensor.dtype

fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
fp8_max = torch.finfo(fp8_type).max

if len(output_shape) == 2:
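# 2-D case: compute an amax per output channel (row) so each row gets its own fp8 scale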
per_channel_max = torch.zeros(output_shape[0], device=output_tensor.device, dtype=torch.float)
num_channels, channel_size = output_shape
rank = dist.get_rank(group)
channel_start_idx = (input_tensor.numel() * rank) // channel_size
per_channel_splits = split_chunk_by_channel(input_tensor, channel_size, num_channels, rank, world_size)
for i, per_channel_split in enumerate(per_channel_splits):
idx = i + channel_start_idx
if idx < num_channels:
per_channel_max[idx] = per_channel_split.abs().max().float()
dist.all_reduce(per_channel_max, op=dist.ReduceOp.MAX, group=group)
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max
fp8_input = input_tensor.float()
fp8_per_channel_splits = split_chunk_by_channel(fp8_input, channel_size, num_channels, rank, world_size)
for i, per_channel_split in enumerate(fp8_per_channel_splits):
idx = i + channel_start_idx
if idx < num_channels:
per_channel_split.mul_(scale[idx])
fp8_input = fp8_input.to(fp8_type)
else:
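# otherwise fall back to a single per-tensor scale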
per_tensor_max = input_tensor.abs().max().float()
dist.all_reduce(per_tensor_max, op=dist.ReduceOp.MAX, group=group)
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
scale = fp8_max / per_tensor_max
fp8_input = (scale * input_tensor.float()).to(fp8_type)
scale_inv = 1.0 / scale
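# all-gather the fp8 payload as raw bytes, then dequantize only the valid (unpadded) region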
buffer = torch.empty_like(output_tensor, dtype=fp8_type)
dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
numel = np.prod(output_shape)
valid_buffer = buffer[:numel].reshape(output_shape)
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type)
output_tensor[:numel].copy_(valid_buffer.view(-1))
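For reference, here is a minimal single-process sketch (not part of this diff) of how split_chunk_by_channel recovers channel boundaries from one rank's flat shard. The (3, 7) shape and world size of 4 are illustrative and mirror the test below.

import torch
import torch.nn.functional as F

from colossalai.quantization.fp8 import split_chunk_by_channel

num_channels, channel_size = 3, 7  # logical shape (3, 7), 21 elements
world_size = 4
flat = torch.arange(num_channels * channel_size, dtype=torch.float32)
flat = F.pad(flat, (0, world_size - flat.numel() % world_size))  # pad 21 -> 24
chunks = flat.chunk(world_size)  # 4 equal shards of 6 elements

for rank, chunk in enumerate(chunks):
    splits = split_chunk_by_channel(chunk, channel_size, num_channels, rank, world_size)
    channel_start_idx = (chunk.numel() * rank) // channel_size
    print(rank, channel_start_idx, [s.numel() for s in splits])
# rank 0 -> starts in row 0, split sizes [6]     (part of row 0)
# rank 1 -> starts in row 0, split sizes [1, 5]  (end of row 0, start of row 1)
# rank 2 -> starts in row 1, split sizes [2, 4]  (end of row 1, start of row 2)
# rank 3 -> starts in row 2, split sizes [3, 3]  (end of row 2, then padding; idx >= num_channels is skipped)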
40 changes: 40 additions & 0 deletions tests/test_fp8/test_fp8_allgather.py
@@ -0,0 +1,40 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_4gpu(shape, dtype):
world_size = dist.get_world_size()
rank = dist.get_rank()
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
flat_padded_x = x.view(-1)
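# pad the flat tensor so it splits evenly across ranks; the padded tail is ignored in the comparison below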
if flat_padded_x.size(0) % world_size != 0:
pad_size = world_size - flat_padded_x.size(0) % world_size
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
output = torch.empty_like(flat_padded_x)
chunk = flat_padded_x.chunk(world_size)[rank].clone()
all_gather_into_tensor_flat_fp8(output, chunk, x.shape, group=_get_default_group())
assert_close(output[: x.numel()], x.view(-1), rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_4gpu()


@rerun_if_address_is_in_use()
def test_all_gather():
spawn(run_dist, 4)


if __name__ == "__main__":
test_all_gather()
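A side note on the rtol=0.1 / atol=0.1 tolerance above: e4m3 keeps only 3 mantissa bits, so per-element relative errors of up to a few percent are expected. The local, non-distributed sketch below (illustration only, assuming a PyTorch build with native float8 dtypes) mirrors the per-tensor scale/cast round trip that the non-2-D branch of all_gather_into_tensor_flat_fp8 applies around the collective.

import torch

fp8_type = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_type).max

x = torch.randn(1024, dtype=torch.bfloat16)
amax = x.abs().max().float()
scale = fp8_max / torch.where(amax > 0, amax, torch.ones_like(amax))
q = (x.float() * scale).to(fp8_type)      # what each rank would send
x_hat = (q.float() / scale).to(x.dtype)   # dequantize on the receiving side
print((x - x_hat).abs().max())            # small, but well above bf16 rounding error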