diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 3fea95ef8030..61e5e40078b8 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -304,10 +304,16 @@ class GeminiPlugin(DPPluginBase): max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. norm_type (float, optional): norm_type used for `clip_grad_norm`. - use_tensor_parallel (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False. - tp_size (int, optional): If 'use_tensor_parallel' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1. - use_fused_layernorm (bool, optional): Whether to use fused layernorm operator, which is implemented in Shardformer. Used when 'use_tensor_parallel' is True. Default to False. - use_flash_attention (bool, optional): Whether to use flash attention, which is implemented in Shardformer. Used when 'use_tensor_parallel' is True. Default to False. + enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False. + tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False. """ @@ -341,10 +347,14 @@ def __init__( max_scale: float = 2**32, max_norm: float = 0.0, norm_type: float = 2.0, - use_tensor_parallel: bool = False, + enable_tensor_parallelism: bool = False, tp_size: int = 1, - use_fused_layernorm: bool = False, - use_flash_attention: bool = False, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_sequence_parallelism: bool = False, + enable_jit_fused: bool = False, + enable_sequence_overlap: bool = False, verbose: bool = False ) -> None: super().__init__() @@ -383,10 +393,14 @@ def __init__( max_norm=max_norm, norm_type=norm_type, ) - self.use_tensor_parallel = use_tensor_parallel - self.tp_size = tp_size if self.use_tensor_parallel else 1 - self.use_fused_layernorm = use_fused_layernorm if self.use_tensor_parallel else False - self.use_flash_attention = use_flash_attention if self.use_tensor_parallel else False + self.enable_tensor_parallelism = enable_tensor_parallelism + self.tp_size = tp_size if self.enable_tensor_parallelism else 1 + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_overlap = enable_sequence_overlap self.verbose = verbose def support_no_sync(self) -> bool: @@ -426,20 +440,24 @@ def configure( # wrap the model with Gemini self.dp_group = None self.tp_group = None - if self.use_tensor_parallel: - try: - dp_size = dist.get_world_size() // self.tp_size - self.pg_mesh = ProcessGroupMesh(dp_size, self.tp_size) - self.dp_group = self.pg_mesh.get_group_along_axis(0) - self.tp_group = self.pg_mesh.get_group_along_axis(1) - shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group, - enable_tensor_parallelism=True, - enable_fused_normalization=self.use_fused_layernorm, - enable_flash_attention=self.use_flash_attention) - shardformer = ShardFormer(shard_config) - model, _ = shardformer.optimize(model) - except NotImplementedError as e: - print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.") + try: + dp_size = dist.get_world_size() // self.tp_size + assert dp_size > 1, f"the size of DP group should greater than 1. Please reduce the TP group size." + self.pg_mesh = ProcessGroupMesh(dp_size, self.tp_size) + self.dp_group = self.pg_mesh.get_group_along_axis(0) + self.tp_group = self.pg_mesh.get_group_along_axis(1) + shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group, + enable_tensor_parallelism=self.enable_tensor_parallelism, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=self.enable_sequence_parallelism, + enable_sequence_overlap=self.enable_sequence_overlap) + shardformer = ShardFormer(shard_config) + model, _ = shardformer.optimize(model) + except NotImplementedError as e: + print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.") model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 92014064433d..0d8c3d453ce1 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -162,7 +162,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): @staticmethod def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): - ctx.save_for_backward(input_, weight) + ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter @@ -180,12 +180,16 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, @staticmethod def backward(ctx, grad_output): - input_, weight = ctx.saved_tensors + input_, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group overlap = ctx.overlap + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + if use_bias: + bias = bias.view(bias.shape) + if not overlap: input_parallel = _gather(input_, dim, process_group) @@ -299,7 +303,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): @staticmethod def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): - ctx.save_for_backward(input_, weight) + ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter @@ -316,12 +320,17 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, @staticmethod def backward(ctx, grad_output): - input_, weight = ctx.saved_tensors + input_, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group overlap = ctx.overlap + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + weight = weight.view(weight.shape) + if use_bias: + bias = bias.view(bias.shape) + if not overlap: input_parallel = _gather(input_, dim, process_group) @@ -467,7 +476,7 @@ def backward(ctx, grad_output): class HookParameter(torch.autograd.Function): - "In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm" + """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm""" @staticmethod def forward(ctx, input, weight, bias): ctx.save_for_backward(weight, bias) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 1bf87e80a461..cd8a023306dc 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -719,7 +719,7 @@ def forward( ): fused_qkv = self.query_key_value(hidden_states) (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, tgt_len, _ = query_layer.size() + batch_size, tgt_len, _, _ = query_layer.size() _, kv_length, _, _ = key_layer.size() @@ -755,6 +755,7 @@ def forward( attention_numerical_mask = torch.masked_fill( attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min ) + attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype) context_layer = me_attention( query_layer, diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 7f97c0a82ed9..ae13866eb6aa 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -116,7 +116,7 @@ def __init__( self.param_info = param_info self.tp_group = tp_group self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1 - self.tp_rank = dist.get_rank(tp_group) + self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else None self.verbose = verbose self.param_groups_backup = list() diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index f5f9aee7a8b6..7954ce7b1b15 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -15,13 +15,14 @@ from tests.kit.model_zoo import model_zoo -def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_parallel) -> Optional[str]: +def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]: try: if init_method == "lazy": ctx = LazyInitContext() else: ctx = nullcontext() - plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, use_tensor_parallel=use_tensor_parallel, use_fused_layernorm=True, use_flash_attention=True) + enable_all_optimization = True if enable_tensor_parallelism else False + plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization) booster = Booster(plugin=plugin) with ctx: model = model_fn() @@ -47,7 +48,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_p optimizer.step() except Exception as e: - # raise e + raise e return repr(e) @@ -57,8 +58,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_p @parameterize("subset", ["torchvision", "transformers", "diffusers"]) @parameterize("init_method", ["none"]) -@parameterize("use_tensor_parallel", [True, False]) -def check_gemini_plugin(subset: str, init_method: str = "none", use_tensor_parallel: bool = True, early_stop: bool = True): +@parameterize("enable_tensor_parallelism", [True, False]) +def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True): """check gemini plugin over model zoo Args: @@ -120,9 +121,12 @@ def check_gemini_plugin(subset: str, init_method: str = "none", use_tensor_paral # TODO debug blip2 when using tp, something wrong with shift_logits's shape if "transformers_blip2" in name: - use_tensor_parallel = False + enable_tensor_parallelism = False - err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_parallel) + # if name is not "transformers_bloom": + # continue + + err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) torch.cuda.empty_cache() if err is None: passed_models.append(name) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 3f5d663431d4..2220d9e5b96a 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -37,9 +37,10 @@ @parameterize("placement_config", MODEL_PLACEMENT_CONFIGS) @parameterize("model_name", ["transformers_bert_for_sequence_classification"]) @parameterize("use_safetensors", [False, True]) -@parameterize("use_tensor_parallel", [True, False]) +@parameterize("enable_tensor_parallelism", [True, False]) @parameterize("tp_size", [2]) -def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, use_tensor_parallel: bool, tp_size: int): +@parameterize("enable_all_optimization", [True, False]) +def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int, enable_all_optimization: bool): from transformers import BertForSequenceClassification (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) @@ -49,7 +50,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b pretrained_path = os.path.join(tempdir, "pretrained") bert_model.config.save_pretrained(save_directory=pretrained_path) - plugin = GeminiPlugin(**placement_config, use_tensor_parallel=use_tensor_parallel, tp_size=tp_size, use_fused_layernorm=True) + plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization) booster = Booster(plugin=plugin) bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 @@ -68,12 +69,13 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b @parameterize("shard", [True, False]) @parameterize("model_name", ["transformers_gpt"]) @parameterize("size_per_shard", [32]) -@parameterize("use_tensor_parallel", [True, False]) +@parameterize("enable_tensor_parallelism", [True, False]) @parameterize("tp_size", [2]) -def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, use_tensor_parallel: bool, tp_size: int): +@parameterize("enable_all_optimization", [True, False]) +def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int, enable_all_optimization: bool): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() - plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), use_tensor_parallel=use_tensor_parallel, tp_size=tp_size, use_fused_layernorm=True) + plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization) booster = Booster(plugin=plugin) model = model_fn() @@ -148,7 +150,7 @@ def run_dist(rank, world_size, port): colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() exam_state_dict_with_origin() - exam_lazy_from_pretrained() + # exam_lazy_from_pretrained() @pytest.mark.dist