[gemini] gemini support tensor parallelism. #4942

Merged: 47 commits, Nov 10, 2023
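As a quick orientation for reviewers, here is a minimal usage sketch of the tensor-parallel options this PR adds to GeminiPlugin. It is not taken from the PR: the plugin arguments match the diff below, while the launch code and the GPT-2 model are illustrative stand-ins, and (per the new assertion in GeminiPlugin.__init__) the world size must exceed tp_size so that the data-parallel group has more than one rank.

# Illustrative sketch, not part of this PR. Assumes a torchrun launch with
# world_size >= 2 * tp_size and a Shardformer-supported model (GPT-2 here).
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch(config={})

plugin = GeminiPlugin(
    enable_tensor_parallelism=True,   # new in this PR: Shardformer-based TP inside Gemini
    tp_size=2,                        # size of the tensor-parallel process group
    enable_fused_normalization=True,  # optional Shardformer optimizations
    enable_flash_attention=True,
)
booster = Booster(plugin=plugin)

model = GPT2LMHeadModel(GPT2Config())                       # illustrative model choice
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)      # model is sharded by Shardformer,
                                                            # then wrapped in GeminiDDP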
Commits (47)
dc0dc0b
[colossalai]fix typo
flybird11111 Sep 22, 2023
dd59ca2
[inference] Add smmoothquant for llama (#4904)
Xu-Kai Oct 16, 2023
52707c6
Update flash_attention_patch.py
Orion-Zheng Oct 13, 2023
61ec9f7
[kernel] support pure fp16 for cpu adam and update gemini optim tests…
ver217 Oct 16, 2023
561553b
[format] applied code formatting on changed files in pull request 490…
github-actions[bot] Oct 17, 2023
8d42002
[gemini] support gradient accumulation (#4869)
Fridge003 Oct 17, 2023
da55732
[hotfix] fix torch 2.0 compatibility (#4936)
ver217 Oct 18, 2023
775ea1b
[test] add no master test for low level zero plugin (#4934)
KKZ20 Oct 18, 2023
0074178
[format] applied code formatting on changed files in pull request 482…
github-actions[bot] Oct 18, 2023
907aa98
[nfc] fix some typo with colossalai/ docs/ etc. (#4920)
digger-yu Oct 18, 2023
31fddbc
[Refactor] Integrated some lightllm kernels into token-attention (#4…
tiandiao123 Oct 19, 2023
8633a87
[test] merge old components to test to model zoo (#4945)
ver217 Oct 20, 2023
9d543af
[inference] add reference and fix some bugs (#4937)
Xu-Kai Oct 20, 2023
fe79560
[Inference]ADD Bench Chatglm2 script (#4963)
CjhHa1 Oct 24, 2023
a610046
[Pipeline inference] Combine kvcache with pipeline inference (#4938)
FoolPlayer Oct 27, 2023
3b8137d
updated c++17 compiler flags (#4983)
kurisusnowdeng Oct 27, 2023
9fce43b
[Inference] Dynamic Batching Inference, online and offline (#4953)
CjhHa1 Oct 30, 2023
62eb99f
[Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding …
tiandiao123 Oct 30, 2023
fa1cbd3
fix ColossalEval (#4992)
chengeharrison Oct 31, 2023
3209431
[doc]Update doc for colossal-inference (#4989)
tiandiao123 Oct 31, 2023
f0482f4
[hotfix] Fix the bug where process groups were not being properly rel…
littsk Oct 31, 2023
cd8ad65
[hotfix] fix the bug of repeatedly storing param group (#4951)
Fridge003 Oct 31, 2023
5266946
[doc] add supported feature diagram for hybrid parallel plugin (#4996)
ppt0011 Oct 31, 2023
ab8468c
[Pipeline Inference] Merge pp with tp (#4993)
FoolPlayer Nov 1, 2023
f9c1920
[release] update version (#4995)
ver217 Nov 1, 2023
2043b9d
[gemini] gemini support tp
flybird11111 Oct 18, 2023
da1915d
fix
flybird11111 Oct 19, 2023
9fd9e69
update checkpointIO
flybird11111 Oct 20, 2023
a89f2fd
support fused layernorm
flybird11111 Oct 23, 2023
2406cb0
update fusedlayernorm
flybird11111 Oct 23, 2023
a0509a6
add sequence parallel to gemini
flybird11111 Oct 24, 2023
12cd780
fix
flybird11111 Oct 25, 2023
0110902
fix comments
flybird11111 Oct 25, 2023
86a5eca
fix
flybird11111 Oct 30, 2023
6f13876
fix t5
flybird11111 Oct 30, 2023
5f16e4f
clear cache
flybird11111 Oct 30, 2023
adead50
fix
flybird11111 Oct 31, 2023
ed825dc
activate ci
flybird11111 Oct 31, 2023
37494c3
activate ci
flybird11111 Oct 31, 2023
73da4ca
fix
flybird11111 Nov 1, 2023
cf2bc63
fix
flybird11111 Nov 1, 2023
6c85a9e
fix
flybird11111 Nov 1, 2023
8dd4b41
fix
flybird11111 Nov 1, 2023
3d8319e
revert
flybird11111 Nov 1, 2023
66ffed5
modify tp gather method
flybird11111 Nov 6, 2023
c40c459
fix test
flybird11111 Nov 8, 2023
bc575a2
Merge branch 'main' into gemini-tp
flybird11111 Nov 9, 2023
81 changes: 77 additions & 4 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -5,6 +5,7 @@
from typing import Callable, Iterator, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -19,8 +20,9 @@
save_state_dict,
save_state_dict_shards,
)
from colossalai.cluster import DistCoordinator
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
@@ -32,7 +34,25 @@
SUPPORTED_PRECISION = ["fp16", "bf16"]
PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}

DP_AXIS = 0
TP_AXIS = 1

def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape.

if optim is None:
return {}
param_info = {"id2shape": {}}
start_index = 0
for group in optim.param_groups:
for param_id, param in enumerate(group["params"], start_index):
original_shape = param.shape if isinstance(param, torch.Tensor) else None
param_info["id2shape"][param_id] = original_shape

start_index += len(group["params"])

return param_info
class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
@@ -284,6 +304,16 @@ class GeminiPlugin(DPPluginBase):
max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
enable_tensor_parallelism (bool, optional): Whether to use the tensor parallelism strategy, which is implemented in Shardformer. Defaults to False.
tp_size (int, optional): The size of the tensor parallel process group. Only effective when 'enable_tensor_parallelism' is True. Defaults to 1.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT fusion in Shardformer. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
"""

@@ -317,6 +347,14 @@ def __init__(
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
enable_tensor_parallelism: bool = False,
tp_size: int = 1,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_sequence_parallelism: bool = False,
enable_jit_fused: bool = False,
enable_sequence_overlap: bool = False,
verbose: bool = False,
) -> None:
super().__init__()
@@ -355,8 +393,32 @@ def __init__(
max_norm=max_norm,
norm_type=norm_type,
)
self.enable_tensor_parallelism = enable_tensor_parallelism
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_overlap = enable_sequence_overlap
self.verbose = verbose

self.tp_size = tp_size if self.enable_tensor_parallelism else 1
self.dp_size = dist.get_world_size() // self.tp_size
assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size."
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
enable_tensor_parallelism=self.enable_tensor_parallelism,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=self.enable_sequence_parallelism,
enable_sequence_overlap=self.enable_sequence_overlap,
)

def support_no_sync(self) -> bool:
return False

@@ -380,6 +442,7 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
optimizer_params_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
@@ -391,11 +454,21 @@ def configure(
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

# wrap the model with Gemini
model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
if self.enable_tensor_parallelism:
shardformer = ShardFormer(self.shard_config)
model, _ = shardformer.optimize(model)

model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(
optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
optimizer,
model,
**self.zero_optim_config,
**self.optim_kwargs,
tp_group=self.tp_group,
optimizer_params_info=optimizer_params_info,
verbose=self.verbose,
)

return model, optimizer, criterion, dataloader, lr_scheduler
@@ -407,4 +480,4 @@ def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO()

def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError
raise NotImplementedError
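To make the new process-group wiring in gemini_plugin.py concrete: the plugin arranges all ranks on a (dp_size, tp_size) mesh, with DP_AXIS = 0 and TP_AXIS = 1 as defined above. The sketch below shows the resulting groups for a hypothetical 8-GPU run with tp_size = 2; it assumes the row-major rank layout used by ProcessGroupMesh and is only an illustration.

# Illustrative only: how the (dp_size, tp_size) mesh partitions ranks,
# assuming ProcessGroupMesh's row-major layout.
import numpy as np

world_size, tp_size = 8, 2
dp_size = world_size // tp_size                     # 4, as computed in GeminiPlugin.__init__
mesh = np.arange(world_size).reshape(dp_size, tp_size)
# mesh = [[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7]]

tp_groups = [tuple(row) for row in mesh]    # groups along TP_AXIS: (0,1) (2,3) (4,5) (6,7)
dp_groups = [tuple(col) for col in mesh.T]  # groups along DP_AXIS: (0,2,4,6) (1,3,5,7)
# Each rank runs ZeRO/Gemini data parallelism inside its dp_group and
# Shardformer tensor parallelism inside its tp_group.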
1 change: 1 addition & 0 deletions colossalai/cluster/process_group_mesh.py
@@ -225,3 +225,4 @@ def get_group_along_axis(
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
return self._ranks_to_group[ranks_in_group]

58 changes: 50 additions & 8 deletions colossalai/shardformer/layer/_operation.py
@@ -53,7 +53,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
@@ -62,13 +62,18 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):

if bias is not None:
output = output + bias

return output

@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
weight = weight.view(weight.shape)
bias = bias.view(bias.shape)

total_input = input
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
@@ -100,7 +105,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
@@ -109,13 +114,18 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)

return output

@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)

total_input = input
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
@@ -152,7 +162,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
ctx.save_for_backward(input_, weight)
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
@@ -170,12 +180,16 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter,

@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)

if not overlap:
input_parallel = _gather(input_, dim, process_group)

@@ -289,7 +303,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
ctx.save_for_backward(input_, weight)
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
@@ -306,12 +320,17 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter,

@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
weight = weight.view(weight.shape)
if use_bias:
bias = bias.view(bias.shape)

if not overlap:
input_parallel = _gather(input_, dim, process_group)

@@ -454,6 +473,29 @@ def forward(ctx, input_, dim, process_group):
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None


class HookParameter(torch.autograd.Function):
"""In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(weight, bias)
output = input
return output

@staticmethod
def backward(ctx, grad_output):
weight, bias = ctx.saved_tensors
if weight is not None:
weight = weight.view(weight.shape)
if bias is not None:
bias = bias.view(bias.shape)
return grad_output, None, None


def hook_paramter_in_backward(input, weight=None, bias=None):
return HookParameter.apply(input, weight, bias)



def _reduce(input_, process_group):
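A note on the recurring weight.view(weight.shape) / bias.view(bias.shape) calls added in the backward passes above: they are shape-preserving no-ops, but they force a tensor operation on the parameter so that Gemini's __torch_function__ machinery gets a chance to intercept it, which is what the in-code comments and the new HookParameter / hook_paramter_in_backward helper refer to. The standalone sketch below only demonstrates the interception mechanism; it is not Gemini's actual hook.

# Minimal demonstration that a no-op view still dispatches through
# __torch_function__ on a tensor subclass (stand-in for Gemini's hook).
import torch

class HookedTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print("intercepted:", getattr(func, "__name__", func))  # Gemini would manage chunks here
        return super().__torch_function__(func, types, args, kwargs)

w = torch.randn(4, 4).as_subclass(HookedTensor)
_ = w.view(w.shape)   # a no-op reshape, yet the hook above still fires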
5 changes: 3 additions & 2 deletions colossalai/shardformer/layer/embedding.py
@@ -309,7 +309,8 @@ def forward(self, input_: Tensor) -> Tensor:
)

# Mask the output embedding.
output_parallel[input_mask, :] = 0.0
embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_forward(output_parallel, self.process_group)
output = reduce_forward(embedding_output, self.process_group)
return output
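For context on the embedding change above: the out-of-partition rows are now zeroed on a clone of the lookup result instead of being written in place into output_parallel, and the clone is what gets all-reduced. The sketch below only illustrates the resulting pattern with made-up shapes; the PR itself does not spell out the motivation.

# Illustration of the masking pattern after this change (shapes are made up).
import torch

weight = torch.randn(5, 8, requires_grad=True)              # this rank's vocabulary shard
output_parallel = torch.nn.functional.embedding(torch.tensor([0, 1, 2, 3]), weight)
input_mask = torch.tensor([False, True, False, True])        # tokens owned by other TP ranks

embedding_output = output_parallel.clone()                   # copy first ...
embedding_output[input_mask, :] = 0.0                        # ... then mask the copy in place
# embedding_output is then all-reduced across the TP group (reduce_forward in the diff)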