diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 2e48a79dc1d7..559f9a56f61e 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -30,27 +30,48 @@

 ### Quick Start

-The sample API usage is given below(If you enable the use of flash attention, please install `flash_attn`. In addition, xformers's `cutlass_op` provide a supplementary optimization, It requires that the sequence length be a multiple of 8.):
+The sample API usage is given below (if you enable flash attention, please install `flash_attn`; in addition, xformers' `cutlass_op` provides a supplementary optimization):

 ```python
-from colossalai.shardformer import ShardConfig, Shard
+from colossalai.shardformer import ShardConfig, ShardFormer
 from transformers import BertForMaskedLM
+from transformers import BertConfig
+import colossalai

 # launch colossalai
-colossalai.launch_from_torch()
+colossalai.launch_from_torch(config={})

 # create model
 config = BertConfig.from_pretrained('bert-base-uncased')
 model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

 # create huggingface model as normal
-shard_config = ShardConfig()
+# tp_group and stage_manager are created beforehand; see "Shardformer Configuration" below
+shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
+                           pipeline_stage_manager=stage_manager,
+                           enable_tensor_parallelism=True,
+                           enable_fused_normalization=True,
+                           enable_flash_attention=True,
+                           enable_jit_fused=True,
+                           enable_sequence_parallelism=True,
+                           enable_sequence_overlap=True)
+
 shard_former = ShardFormer(shard_config=shard_config)
-sharded_model = shard_former.optimize(model).to('cuda')
+sharded_model, shared_params = shard_former.optimize(model)
+sharded_model = sharded_model.to('cuda')

 # do everything like normal
 ...
 ```
+
+#### Shardformer Configuration
+
+- `tensor_parallel_process_group`: the process group used for tensor parallelism; required when tensor parallelism is enabled (a minimal construction sketch is given below).
+- `pipeline_stage_manager`: the stage manager that handles inter-process communication when pipeline parallelism is used.
+{{ autodoc:colossalai.pipeline.stage_manager.PipelineStageManager }}
+- `enable_tensor_parallelism`: whether to use tensor parallelism, partitioning model weights along rows or columns.
+- `enable_fused_normalization`: whether to use the Apex fused layernorm kernel.
+- `enable_flash_attention`: whether to use flash attention.
+- `enable_jit_fused`: whether to use JIT-fused operators.
+- `enable_sequence_parallelism`: whether to use sequence parallelism, partitioning the non-tensor-parallel regions along the sequence dimension.
+- `enable_sequence_overlap`: whether to overlap computation and communication in sequence parallelism; used together with `enable_sequence_parallelism`.
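+
+The Quick Start snippet above assumes that `tp_group` (and, when pipeline parallelism is used, `stage_manager`) has already been created. A minimal, tensor-parallel-only sketch, leaving `pipeline_stage_manager` unset, could look like the following; the single all-rank group here is illustrative, not prescriptive:
+
+```python
+import torch.distributed as dist
+from colossalai.shardformer import ShardConfig, ShardFormer
+
+# assumes the default process group has already been initialized,
+# e.g. via colossalai.launch_from_torch as shown above
+# in this sketch every rank joins one tensor-parallel group
+tp_group = dist.new_group(backend='nccl')
+
+shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
+                           enable_tensor_parallelism=True,
+                           enable_fused_normalization=True)
+shard_former = ShardFormer(shard_config=shard_config)
+```
+
+A real multi-dimensional setup would typically carve the tensor-parallel group out of a larger device mesh and pass a `PipelineStageManager` (documented above) when pipeline parallelism is enabled.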
+
 ### Write your own policy
@@ -82,44 +103,30 @@ We will follow this roadmap to develop Shardformer:
 - [x] API Implementation
 - [x] Unit Testing
 - [ ] Policy Implementation
-  - [ ] Hugging Face
-    - [ ] NLP
-      - [x] BERT
-      - [x] T5
-      - [x] LlaMa
-      - [x] GPT2
-      - [x] OPT
-      - [x] BLOOM
-      - [ ] GLM
-      - [ ] RoBERTa
-      - [ ] ALBERT
-      - [ ] ERNIE
-      - [ ] GPT Neo
-      - [ ] GPT-J
-    - [ ] CV
-      - [x] ViT
-      - [ ] BEiT
-      - [ ] SwinTransformer
-      - [ ] SwinTransformer V2
-    - [ ] Audio
-      - [x] Whisper
-    - [ ] Multi-modal
-      - [x] SAM
-      - [x] BLIP-2
-- [ ] Flash Attention Support
-  - [ ] NLP
-    - [x] BERT
-    - [x] T5
-    - [x] LlaMa
-    - [x] GPT2
-    - [x] OPT
-    - [x] BLOOM
-    - [ ] GLM
-    - [ ] RoBERTa
-    - [ ] ALBERT
-    - [ ] ERNIE
-    - [ ] GPT Neo
-    - [ ] GPT-J
+
+| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
+| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
+| bert | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| t5 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| llama V1/V2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| gpt2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| opt | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+

 ## 💡 API Design
@@ -286,41 +293,36 @@ class ShardFormer:

     Example:

+        org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+        shard_config = ShardConfig()
         shard_former = ShardFormer(shard_config=shard_config)
-        shard_former.init_distributed()
-        model = shard_former.optimize(model, policy=policy)
-        dataloader = shard_former.shard_dataset(dataset)
+        model, shared_params = shard_former.optimize(org_model)

     """

     def __init__(self, shard_config: ShardConfig):
         """
         Do two things:
-        1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
+        1. Create a distributed coordinator
         2. serve as a store for shard config
         """
         self.shard_config = shard_config
-        self.pg_manager = None
+        self.coordinator = DistCoordinator()

-    def init_distributed(self) -> colossalai.cluster.ProcessGroupManager:
-        """
-        Initialize the distributed process group according to the
-        """
-        pg_manager = ...
-        self.pg_manager = pg_manager
-        return pg_manager
+    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
+        r"""
+        This method will optimize the model based on the given policy.
-    def shard_model(self, model: torch.nn.Module,policy: Policy) -> torch.nn.Module:
-        """
-        Shard model for TP and PP
-        """
-        ...
+        Args:
+            model (`torch.nn.Module`): the original huggingface model
+            policy (`Policy`): the custom policy for sharding

-    def shard_dataset(self, dataset: Dataset) -> Dataloader:
+        Returns: the sharded model and the shared parameters
         """
-        Shard dataset for DP
-        """
-        ...
+        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
+        shared_params = sharder.shard()
+        return model, shared_params
 ```

 ## ⌨️ Development Notes
@@ -429,13 +431,24 @@ As shown in the figures above, when the sequence length is around 1000 or greate

 ### Convergence

-To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](../../examples/language/bert/finetune.py) using both shardformer and non-shardformer approaches. The example that utilizes Shardformer simultaneously with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results.
+To validate that training a model with Shardformer does not impact its convergence, we [fine-tuned a BERT model](./examples/convergence_benchmark.py) with and without Shardformer. The sharded runs use Shardformer together with Pipeline Parallelism and Data Parallelism (ZeRO-1). We then compared the accuracy, loss, and F1 score of the training results.
+
+The configuration is as follows:
+
+```python
+batch_size = 2
+epoch = 3
+lr = 2.4e-5
+accumulation_steps = 8
+warmup_fraction = 0.03
+```
+

 | accuracy | f1 | loss | GPU number | model sharded |
 | :------: | :-----: | :-----: | :--------: | :---------: |
-| 0.84589 | 0.88613 | 0.43414 | 4 | True |
-| 0.83594 | 0.88064 | 0.43298 | 1 | False |
+| 0.82971 | 0.87713 | 0.23194 | 4 | True |
+| 0.83797 | 0.88006 | 0.22683 | 2 | True |
+| 0.84521 | 0.88700 | 0.21822 | 1 | False |

 Overall, the results demonstrate that using shardformers during model training does not affect the convergence.
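+
+For reference, the sharded and non-sharded runs differ only in whether the model is passed through ShardFormer before training. A condensed sketch of that branch, mirroring `examples/convergence_benchmark.py`, is shown below:
+
+```python
+import torch.distributed as dist
+from transformers import BertForMaskedLM
+from colossalai.shardformer import ShardConfig, ShardFormer
+
+# assumes the process group has already been initialized (e.g. via colossalai.launch_from_torch)
+model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+
+# shard the model only when more than one GPU is available;
+# the single-GPU baseline trains the unmodified HuggingFace model
+if dist.get_world_size() > 1:
+    tp_group = dist.new_group(backend='nccl')
+    shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
+                               enable_tensor_parallelism=True,
+                               enable_all_optimization=True)
+    model, _ = ShardFormer(shard_config=shard_config).optimize(model)
+```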
diff --git a/colossalai/shardformer/examples/convergence_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py
index de82305b2547..81be2017855c 100644
--- a/colossalai/shardformer/examples/convergence_benchmark.py
+++ b/colossalai/shardformer/examples/convergence_benchmark.py
@@ -49,9 +49,12 @@ def train(args):

     # if multiple GPUs, shard the model
     if dist.get_world_size() > 1:
-        shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm)
+        tp_group = dist.new_group(backend='nccl')
+        shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
+                                   enable_tensor_parallelism=True,
+                                   enable_all_optimization=True)
         shard_former = ShardFormer(shard_config=shard_config)
-        model = shard_former.optimize(model)
+        model, _ = shard_former.optimize(model)

     optim = Adam(model.parameters(), lr=args.lr)
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
diff --git a/colossalai/shardformer/examples/convergence_benchmark.sh b/colossalai/shardformer/examples/convergence_benchmark.sh
index 1c281abcda6d..22f13a7cf827 100644
--- a/colossalai/shardformer/examples/convergence_benchmark.sh
+++ b/colossalai/shardformer/examples/convergence_benchmark.sh
@@ -1,7 +1,7 @@
 torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
 --model "bert" \
 --pretrain "bert-base-uncased" \
---max_epochs 1 \
+--max_epochs 3 \
 --batch_size 2 \
 --lr 2.4e-5 \
 --fused_layernorm False \
diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py
index 9c7b76bcf0a6..2f186709d946 100644
--- a/colossalai/shardformer/examples/performance_benchmark.py
+++ b/colossalai/shardformer/examples/performance_benchmark.py
@@ -29,7 +29,8 @@ def data_gen_for_sequence_classification(batch_size, seq_length):
                              intermediate_size=256,
                              num_attention_heads=4,
                              max_position_embeddings=128,
-                             num_labels=16)
+                             num_labels=16,
+                             pad_token_id=2)
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
 model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)
@@ -73,7 +74,8 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d
     if provider == "shard_model":
         shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
         shard_former = ShardFormer(shard_config=shard_config)
-        sharded_model = shard_former.optimize(model).cuda()
+        sharded_model, _ = shard_former.optimize(model)
+        sharded_model = sharded_model.cuda()
         fn = lambda: train(sharded_model, data)
         ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
         return ms