Skip to content

Commit

Permalink
add sequence parallel to gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
flybird11111 committed Oct 24, 2023
1 parent 503c25e commit 9bb81f7
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 46 deletions.
68 changes: 43 additions & 25 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,16 @@ class GeminiPlugin(DPPluginBase):
max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
use_tensor_parallel (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
tp_size (int, optional): If 'use_tensor_parallel' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
use_fused_layernorm (bool, optional): Whether to use fused layernorm operator, which is implemented in Shardformer. Used when 'use_tensor_parallel' is True. Default to False.
use_flash_attention (bool, optional): Whether to use flash attention, which is implemented in Shardformer. Used when 'use_tensor_parallel' is True. Default to False.
enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
"""

Expand Down Expand Up @@ -341,10 +347,14 @@ def __init__(
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
use_tensor_parallel: bool = False,
enable_tensor_parallelism: bool = False,
tp_size: int = 1,
use_fused_layernorm: bool = False,
use_flash_attention: bool = False,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_sequence_parallelism: bool = False,
enable_jit_fused: bool = False,
enable_sequence_overlap: bool = False,
verbose: bool = False
) -> None:
super().__init__()
Expand Down Expand Up @@ -383,10 +393,14 @@ def __init__(
max_norm=max_norm,
norm_type=norm_type,
)
self.use_tensor_parallel = use_tensor_parallel
self.tp_size = tp_size if self.use_tensor_parallel else 1
self.use_fused_layernorm = use_fused_layernorm if self.use_tensor_parallel else False
self.use_flash_attention = use_flash_attention if self.use_tensor_parallel else False
self.enable_tensor_parallelism = enable_tensor_parallelism
self.tp_size = tp_size if self.enable_tensor_parallelism else 1
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_overlap = enable_sequence_overlap
self.verbose = verbose

def support_no_sync(self) -> bool:
Expand Down Expand Up @@ -426,20 +440,24 @@ def configure(
# wrap the model with Gemini
self.dp_group = None
self.tp_group = None
if self.use_tensor_parallel:
try:
dp_size = dist.get_world_size() // self.tp_size
self.pg_mesh = ProcessGroupMesh(dp_size, self.tp_size)
self.dp_group = self.pg_mesh.get_group_along_axis(0)
self.tp_group = self.pg_mesh.get_group_along_axis(1)
shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group,
enable_tensor_parallelism=True,
enable_fused_normalization=self.use_fused_layernorm,
enable_flash_attention=self.use_flash_attention)
shardformer = ShardFormer(shard_config)
model, _ = shardformer.optimize(model)
except NotImplementedError as e:
print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.")
try:
dp_size = dist.get_world_size() // self.tp_size
assert dp_size > 1, f"the size of DP group should greater than 1. Please reduce the TP group size."
self.pg_mesh = ProcessGroupMesh(dp_size, self.tp_size)
self.dp_group = self.pg_mesh.get_group_along_axis(0)
self.tp_group = self.pg_mesh.get_group_along_axis(1)
shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group,
enable_tensor_parallelism=self.enable_tensor_parallelism,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=self.enable_sequence_parallelism,
enable_sequence_overlap=self.enable_sequence_overlap)
shardformer = ShardFormer(shard_config)
model, _ = shardformer.optimize(model)
except NotImplementedError as e:
print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.")

model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)

Expand Down
19 changes: 14 additions & 5 deletions colossalai/shardformer/layer/_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
ctx.save_for_backward(input_, weight)
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
Expand All @@ -180,12 +180,16 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter,

@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)

if not overlap:
input_parallel = _gather(input_, dim, process_group)

Expand Down Expand Up @@ -299,7 +303,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
ctx.save_for_backward(input_, weight)
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
Expand All @@ -316,12 +320,17 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter,

@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
weight = weight.view(weight.shape)
if use_bias:
bias = bias.view(bias.shape)

if not overlap:
input_parallel = _gather(input_, dim, process_group)

Expand Down Expand Up @@ -467,7 +476,7 @@ def backward(ctx, grad_output):


class HookParameter(torch.autograd.Function):
"In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"
"""In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(weight, bias)
Expand Down
3 changes: 2 additions & 1 deletion colossalai/shardformer/modeling/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ def forward(
):
fused_qkv = self.query_key_value(hidden_states)
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, tgt_len, _ = query_layer.size()
batch_size, tgt_len, _, _ = query_layer.size()

_, kv_length, _, _ = key_layer.size()

Expand Down Expand Up @@ -755,6 +755,7 @@ def forward(
attention_numerical_mask = torch.masked_fill(
attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min
)
attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype)

context_layer = me_attention(
query_layer,
Expand Down
2 changes: 1 addition & 1 deletion colossalai/zero/gemini/gemini_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(
self.param_info = param_info
self.tp_group = tp_group
self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
self.tp_rank = dist.get_rank(tp_group)
self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else None
self.verbose = verbose
self.param_groups_backup = list()

Expand Down
18 changes: 11 additions & 7 deletions tests/test_booster/test_plugin/test_gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
from tests.kit.model_zoo import model_zoo


def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_parallel) -> Optional[str]:
def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:
try:
if init_method == "lazy":
ctx = LazyInitContext()
else:
ctx = nullcontext()
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, use_tensor_parallel=use_tensor_parallel, use_fused_layernorm=True, use_flash_attention=True)
enable_all_optimization = True if enable_tensor_parallelism else False
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization)
booster = Booster(plugin=plugin)
with ctx:
model = model_fn()
Expand All @@ -47,7 +48,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_p
optimizer.step()

except Exception as e:
# raise e
raise e
return repr(e)


Expand All @@ -57,8 +58,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_p

@parameterize("subset", ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
@parameterize("use_tensor_parallel", [True, False])
def check_gemini_plugin(subset: str, init_method: str = "none", use_tensor_parallel: bool = True, early_stop: bool = True):
@parameterize("enable_tensor_parallelism", [True, False])
def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True):
"""check gemini plugin over model zoo
Args:
Expand Down Expand Up @@ -120,9 +121,12 @@ def check_gemini_plugin(subset: str, init_method: str = "none", use_tensor_paral

# TODO debug blip2 when using tp, something wrong with shift_logits's shape
if "transformers_blip2" in name:
use_tensor_parallel = False
enable_tensor_parallelism = False

err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_parallel)
# if name is not "transformers_bloom":
# continue

err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)
Expand Down
16 changes: 9 additions & 7 deletions tests/test_checkpoint_io/test_gemini_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@
@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_safetensors", [False, True])
@parameterize("use_tensor_parallel", [True, False])
@parameterize("enable_tensor_parallelism", [True, False])
@parameterize("tp_size", [2])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, use_tensor_parallel: bool, tp_size: int):
@parameterize("enable_all_optimization", [True, False])
def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int, enable_all_optimization: bool):
from transformers import BertForSequenceClassification

(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
Expand All @@ -49,7 +50,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
pretrained_path = os.path.join(tempdir, "pretrained")
bert_model.config.save_pretrained(save_directory=pretrained_path)

plugin = GeminiPlugin(**placement_config, use_tensor_parallel=use_tensor_parallel, tp_size=tp_size, use_fused_layernorm=True)
plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
booster = Booster(plugin=plugin)
bert_model, _, _, _, _ = booster.boost(bert_model)
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
Expand All @@ -68,12 +69,13 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_gpt"])
@parameterize("size_per_shard", [32])
@parameterize("use_tensor_parallel", [True, False])
@parameterize("enable_tensor_parallelism", [True, False])
@parameterize("tp_size", [2])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, use_tensor_parallel: bool, tp_size: int):
@parameterize("enable_all_optimization", [True, False])
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int, enable_all_optimization: bool):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), use_tensor_parallel=use_tensor_parallel, tp_size=tp_size, use_fused_layernorm=True)
plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
booster = Booster(plugin=plugin)

model = model_fn()
Expand Down Expand Up @@ -148,7 +150,7 @@ def run_dist(rank, world_size, port):
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_state_dict()
exam_state_dict_with_origin()
exam_lazy_from_pretrained()
# exam_lazy_from_pretrained()


@pytest.mark.dist
Expand Down

0 comments on commit 9bb81f7

Please sign in to comment.