diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 82d37bb4cf94..19da348e707d 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -132,7 +132,7 @@ def backward(ctx, grad_output):
         if use_bias:
             bias.view(bias.shape)
 
-        total_input = input
+        total_input = input.contiguous()
         grad_input = grad_output.matmul(weight)
         grad_output = grad_output.contiguous()
         # Convert the tensor shapes to 2D for execution compatibility
diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index 53c151f02f63..34d900d8de94 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -11,7 +11,11 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
 from colossalai.shardformer.layer import AttnMaskType, ColoAttention
-from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.layer._operation import (
+    all_to_all_comm,
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+)
 
 
 def get_flash_core_attention_forward():
@@ -203,6 +207,13 @@ def chatglm_model_forward(
                 dim=0,
                 process_group=shard_config.tensor_parallel_process_group,
             )
+        elif shard_config.sequence_parallelism_mode == "all_to_all":
+            hidden_states = split_forward_gather_backward(
+                hidden_states,
+                dim=0,
+                process_group=shard_config.sequence_parallel_process_group,
+                grad_scale=1 / shard_config.sequence_parallel_size,
+            )
     for idx in range(start_idx, end_idx):
         layer = self.encoder._get_layer(idx)
         if output_hidden_states:
@@ -235,6 +246,13 @@ def chatglm_model_forward(
                 dim=0,
                 process_group=shard_config.tensor_parallel_process_group,
             )
+        elif shard_config.sequence_parallelism_mode == "all_to_all":
+            hidden_states = gather_forward_split_backward(
+                hidden_states,
+                dim=0,
+                process_group=shard_config.sequence_parallel_process_group,
+                grad_scale=shard_config.sequence_parallel_size,
+            )
     if output_hidden_states:
         all_hidden_states = all_hidden_states + (hidden_states,)
     if stage_manager.is_last_stage():
@@ -329,7 +347,9 @@ def chatglm_for_conditional_generation_forward(
         return transformer_outputs
 
 
-def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
+def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode, sp_size, sp_group):
+    logger = logging.get_logger(__name__)
+
     def forward(
         self,
         input_ids,
@@ -381,13 +401,27 @@ def forward(
             rotary_pos_emb = rotary_pos_emb[None, :seq_length]
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
 
+        if sp_mode in ["all_to_all"] and self.training:
+            if use_cache:
+                logger.warning_once(
+                    f"`use_cache=True` is incompatible with sp mode `{sp_mode}`. Setting `use_cache=False`..."
+                )
+                use_cache = False
         # Run encoder.
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] - inputs_embeds = split_forward_gather_backward( - inputs_embeds, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) + if sp_mode in ["split_gather"]: + inputs_embeds = split_forward_gather_backward( + inputs_embeds, + dim=0, + process_group=sp_group, + ) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward( + inputs_embeds, + dim=0, + process_group=sp_group, + grad_scale=1 / sp_size, + ) hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, full_attention_mask, @@ -397,11 +431,19 @@ def forward( output_hidden_states=output_hidden_states, ) - hidden_states = gather_forward_split_backward( - hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group, - ) + if sp_mode in ["split_gather"]: + hidden_states = gather_forward_split_backward( + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, + ) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward( + hidden_states, + dim=0, + process_group=sp_group, + grad_scale=sp_size, + ) if not return_dict: return tuple( @@ -423,3 +465,158 @@ def forward( ) return forward + + +def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, sp_mode, sp_size, sp_group): + from .chatglm2_6b.modeling_chatglm import apply_rotary_pos_emb, split_tensor_along_last_dim + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + if sp_mode is not None: + assert sp_mode in ["all_to_all", "split_gather"], "Invalid sp_mode" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" + + mixed_x_layer = self.query_key_value(hidden_states) + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + sq, bs, _, _ = value_layer.size() + + query_layer = query_layer.reshape(sq, bs, -1) + key_layer = key_layer.reshape(sq, bs, -1) + value_layer = value_layer.reshape(sq, bs, -1) + + query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0) + key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0) + value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0) + + query_layer = 
query_layer.view( + sq * sp_size, + bs, + self.num_attention_heads_per_partition // sp_size, + self.hidden_size_per_attention_head, + ).contiguous() + + key_layer = key_layer.view( + sq * sp_size, + bs, + self.num_attention_heads_per_partition // sp_size, + self.hidden_size_per_attention_head, + ).contiguous() + + value_layer = value_layer.view( + sq * sp_size, + bs, + self.num_attention_heads_per_partition // sp_size, + self.hidden_size_per_attention_head, + ).contiguous() + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + + ( + self.num_attention_heads_per_partition // sp_size, + self.hidden_size_per_attention_head, + ) + ) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + if sp_mode == "all_to_all": + context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0) + + # ================= + # Output. 
[sq, b, h] + # ================= + output = self.dense(context_layer) + + return output, kv_cache + + return forward diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py index 07a7f6cbf8d3..77c12b9dbc83 100644 --- a/colossalai/shardformer/modeling/command.py +++ b/colossalai/shardformer/modeling/command.py @@ -135,6 +135,21 @@ def command_model_forward( ) use_cache = False + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=1 / shard_config.sequence_parallel_size, + ) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -191,6 +206,21 @@ def command_model_forward( if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: + hidden_states = gather_forward_split_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = gather_forward_split_backward( + hidden_states, + dim=1, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=shard_config.sequence_parallel_size, + ) + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 11c26822f50a..90fd0661f600 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -1,6 +1,7 @@ from typing import List, Optional, Tuple, Union import torch +from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutputWithPast, @@ -30,6 +31,11 @@ from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer._operation import ( + all_to_all_comm, + gather_forward_split_backward, + split_forward_gather_backward, +) from colossalai.shardformer.shard import ShardConfig from ..layer import ColoAttention, cross_entropy_1d @@ -162,6 +168,21 @@ def qwen2_model_forward( sliding_window=self.config.sliding_window, ) + if shard_config and shard_config.enable_sequence_parallelism: + if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.tensor_parallel_process_group, + ) + elif shard_config.sequence_parallelism_mode == "all_to_all": + hidden_states = split_forward_gather_backward( + hidden_states, + dim=1, + process_group=shard_config.sequence_parallel_process_group, + grad_scale=1 / shard_config.sequence_parallel_size, + ) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -218,6 +239,20 @@ def qwen2_model_forward( if stage_manager.is_last_stage(): hidden_states = 
self.norm(hidden_states)
+    if shard_config and shard_config.enable_sequence_parallelism:
+        if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
+            hidden_states = gather_forward_split_backward(
+                hidden_states,
+                dim=1,
+                process_group=shard_config.tensor_parallel_process_group,
+            )
+        elif shard_config.sequence_parallelism_mode == "all_to_all":
+            hidden_states = gather_forward_split_backward(
+                hidden_states,
+                dim=1,
+                process_group=shard_config.sequence_parallel_process_group,
+                grad_scale=shard_config.sequence_parallel_size,
+            )
     # add hidden states from the last decoder layer
     if output_hidden_states:
         all_hidden_states += (hidden_states,)
@@ -469,7 +504,7 @@ def qwen2_for_sequence_classification_forward(
         return {"hidden_states": hidden_states}
 
 
-def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
+def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
         self: Qwen2Attention,
         hidden_states: torch.Tensor,
@@ -480,12 +515,28 @@ def forward(
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if sp_mode is not None:
+            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+            assert (sp_size is not None) and (
+                sp_group is not None
+            ), "Must specify sp_size and sp_group for sequence parallel"
+
         bsz, q_len, _ = hidden_states.size()
+        # sp: restore the full sequence length when sequence parallel mode is split_gather or ring
+        if sp_mode in ["split_gather", "ring"]:
+            q_len *= sp_size
 
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
+        # sp: all-to-all communication when introducing sequence parallel
+        if sp_mode == "all_to_all":
+            query_states = all_to_all_comm(query_states, sp_group)
+            key_states = all_to_all_comm(key_states, sp_group)
+            value_states = all_to_all_comm(value_states, sp_group)
+            bsz, q_len, _ = query_states.size()
+
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -538,10 +589,41 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
-        attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
+        if shard_config.enable_flash_attention:
+            assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
+ attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value @@ -549,9 +631,8 @@ def forward( return forward -def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig): +def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) - assert shard_config.enable_flash_attention, "Flash Attention is not enabled." def forward( self, @@ -601,17 +682,26 @@ def forward( # embed positions hidden_states = inputs_embeds - # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) - attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, - hidden_states.dtype, - hidden_states.device, - q_padding_mask=attention_mask, - is_causal=True, - ) + if shard_config.enable_flash_attention: + # in this case, attention_mask is a dict rather than a tensor + mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) + attention_mask = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) - if self.gradient_checkpointing and self.training: + if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
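Note: the attention rewrites above (chatglm2 and qwen2) and the policy changes further down follow the same Ulysses-style all-to-all pattern: outside attention each rank holds `seq_len / sp_size` tokens with all heads, and `all_to_all_comm` swaps that layout to the full sequence with `num_heads / sp_size` heads, which is why the policies divide `num_heads` by `sp_size`. The sketch below is a minimal, single-process illustration of that layout swap, assuming made-up sizes (`sp_size`, `seq_len`, `num_heads`, `head_dim`) and a hypothetical `simulated_all_to_all` helper built from `cat`/`chunk`; it is not the distributed `all_to_all_comm` implementation.

```python
# Single-process illustration of the Ulysses-style all-to-all layout swap used by
# the sequence-parallel attention forwards in this patch. No torch.distributed
# calls: the list `per_rank_q` stands in for the query shard each rank would hold.
import torch

sp_size, bsz, seq_len, num_heads, head_dim = 2, 1, 8, 4, 16  # illustrative sizes only
assert seq_len % sp_size == 0 and num_heads % sp_size == 0

# Before attention: each rank holds seq_len // sp_size tokens and *all* heads.
per_rank_q = [torch.randn(bsz, seq_len // sp_size, num_heads * head_dim) for _ in range(sp_size)]


def simulated_all_to_all(chunks):
    """Stand-in for the forward all-to-all: gather the sequence dim, scatter the hidden dim."""
    full_seq = torch.cat(chunks, dim=1)          # every rank sees the whole sequence
    return list(full_seq.chunk(sp_size, dim=2))  # but only its slice of the heads


q_after = simulated_all_to_all(per_rank_q)
for shard in q_after:
    # Full sequence, num_heads // sp_size heads per rank: attention can now run
    # over the complete sequence locally, which is why the policies shrink num_heads.
    assert shard.shape == (bsz, seq_len, (num_heads // sp_size) * head_dim)

# After attention, the patch applies the reverse all-to-all (scatter the sequence,
# gather the heads) so the layer output is sequence-sharded again.
```

The paired `grad_scale=1 / sp_size` and `grad_scale=sp_size` arguments on the split and gather calls around the decoder stack appear intended to keep gradient magnitudes consistent across the sequence-parallel group; the sketch only covers the forward shape bookkeeping.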
@@ -623,6 +713,11 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = None + if sp_mode in ["ring", "split_gather"]: + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + elif sp_mode == "all_to_all": + inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) @@ -657,6 +752,11 @@ def forward( hidden_states = self.norm(hidden_states) + if sp_mode == "ring" or sp_mode == "split_gather": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + elif sp_mode == "all_to_all": + hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 01aa77e57c00..3877bdac3ae2 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -9,6 +9,7 @@ from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards from ..modeling.chatglm2 import ( + get_chatglm_sequence_parallel_attention_forward, get_chatglm_sequence_parallel_forward_fn, get_flash_core_attention_forward, get_jit_fused_glm_block_forward, @@ -58,14 +59,29 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: norm_cls = col_nn.LayerNorm sp_mode = self.shard_config.sequence_parallelism_mode or None - assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2" + sp_size = self.shard_config.sequence_parallel_size or None + sp_group = self.shard_config.sequence_parallel_process_group or None + if sp_mode == "ring": warnings.warn( f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather" ) sp_mode = "split_gather" overlap = self.shard_config.enable_sequence_overlap - sp_partial_derived = sp_mode == "split_gather" + sp_partial_derived = sp_mode in ["split_gather"] + + if sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + "hidden_size_per_partition": self.model.config.kv_channels + * self.model.config.num_attention_heads + // sp_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + policy["CoreAttention"] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) if self.shard_config.enable_tensor_parallelism: assert ( @@ -179,12 +195,26 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ) # use sequence parallel - if sp_mode == "split_gather": + if self.shard_config.enable_sequence_parallelism: self.append_or_create_method_replacement( - description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, + description={ + "forward": get_chatglm_sequence_parallel_attention_forward( + self.shard_config, sp_mode, sp_size, sp_group + ), + }, policy=policy, - target_key="ChatGLMModel", + target_key="SelfAttention", ) + if self.pipeline_stage_manager is None: + self.append_or_create_method_replacement( + description={ + "forward": get_chatglm_sequence_parallel_forward_fn( + self.shard_config, sp_mode, sp_size, sp_group + ) + }, + policy=policy, + 
target_key="ChatGLMModel", + ) # use jit fused operator if self.shard_config.enable_jit_fused: diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 902baf2e177c..a9b915d10485 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -66,13 +65,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: else: norm_cls = LayerNorm - if self.pipeline_stage_manager is not None: - self.shard_config.enable_sequence_parallelism = False - self.shard_config.enable_sequence_overlap = False - self.shard_config.sequence_parallelism_mode = None - warnings.warn( - f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False" - ) sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 3e427c4a1623..362c14060fd9 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Union @@ -82,9 +81,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: embedding_cls = PaddingEmbedding norm_cls = FusedRMSNorm if self.shard_config.enable_fused_normalization else RMSNorm - if self.shard_config.enable_sequence_parallelism: - self.shard_config.enable_sequence_parallelism = False - warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + sp_mode = self.shard_config.sequence_parallelism_mode or None + sp_size = self.shard_config.sequence_parallel_size or None + sp_group = self.shard_config.sequence_parallel_process_group or None + sp_partial_derived = sp_mode in ["split_gather", "ring"] + if sp_mode == "all_to_all": + decoder_attribute_replacement = { + "num_heads": self.model.config.num_attention_heads // sp_size, + } + if getattr(self.model.config, "num_key_value_heads", False): + decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size + + policy[attn_cls] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + ) if self.shard_config.enable_tensor_parallelism: assert ( @@ -109,30 +119,37 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, + kwargs=dict(seq_parallel_mode=sp_mode), ), SubModuleReplacementDescription( suffix="mlp.down_proj", 
target_module=Linear1D_Row, + kwargs=dict(seq_parallel_mode=sp_mode), ), ], ) @@ -154,10 +171,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="input_layernorm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), SubModuleReplacementDescription( suffix="post_attention_layernorm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), ], policy=policy, @@ -168,16 +187,16 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: description=SubModuleReplacementDescription( suffix="norm", target_module=norm_cls, + kwargs={"sp_partial_derived": sp_partial_derived}, ), policy=policy, target_key=Qwen2Model, ) - # use flash attention - if self.shard_config.enable_flash_attention: + if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: self.append_or_create_method_replacement( description={ - "forward": get_qwen2_flash_attention_forward(self.shard_config), + "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), }, policy=policy, target_key=attn_cls, @@ -186,7 +205,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: # replace qwen2 model forward method self.append_or_create_method_replacement( description={ - "forward": get_qwen2_model_forward_for_flash_attn(self.shard_config), + "forward": get_qwen2_model_forward_for_flash_attn( + self.shard_config, sp_mode, sp_size, sp_group + ), }, policy=policy, target_key=Qwen2Model, diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 6ce020b68ab5..92c077950ecc 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -136,6 +136,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { # Ulysess + Flash attention + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 4, "pp_size": 1, diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 4d66692a4c11..3281b50e1d5d 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -58,6 +58,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # Check the grad when using ZeRO-1 and ZeRO-2 if ( booster.plugin.zero_stage in [1, 2] + and booster.plugin.shard_config.pipeline_stage_manager is None and booster.plugin.shard_config.enable_sequence_parallelism and 
booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all" ): @@ -154,6 +155,45 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ + { # Ulysess + Flash attention + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 1, diff --git a/tests/test_shardformer/test_model/test_shard_qwen2.py b/tests/test_shardformer/test_model/test_shard_qwen2.py index 166b31df967e..160f9c53b68d 100644 --- a/tests/test_shardformer/test_model/test_shard_qwen2.py +++ b/tests/test_shardformer/test_model/test_shard_qwen2.py @@ -180,6 +180,68 @@ def run_qwen2_test(test_config): "zero_stage": 1, "initial_scale": 1, }, + { # Ulysess + Flash attention + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring", + "enable_flash_attention": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 1, + "sp_size": 2, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "all_to_all", + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "num_microbatches": 1, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "split_gather", + "enable_flash_attention": False, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, { "tp_size": 2, "pp_size": 2,