@@ -162,20 +184,20 @@ Model performance in [Anthropics paper](https://arxiv.org/abs/2204.05862):
We also train a reward model based on LLaMA-7B, which reaches an accuracy (ACC) of 72.06% after 1 epoch, performing almost the same as Anthropic's best RM.
### Arg List
-- --strategy: the strategy using for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
-- --model: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
-- --pretrain: pretrain model, type=str, default=None
-- --model_path: the path of rm model(if continue to train), type=str, default=None
-- --save_path: path to save the model, type=str, default='output'
-- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False
-- --max_epochs: max epochs for training, type=int, default=3
-- --dataset: dataset name, type=str, choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static']
-- --subset: subset of the dataset, type=str, default=None
-- --batch_size: batch size while training, type=int, default=4
-- --lora_rank: low-rank adaptation matrices rank, type=int, default=0
-- --loss_func: which kind of loss function, choices=['log_sig', 'log_exp']
-- --max_len: max sentence length for generation, type=int, default=512
-- --test: whether is only testing, if it's true, the dataset will be small
+
+- `--strategy`: the strategy used for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
+- `--model`: model type, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
+- `--pretrain`: pretrained model, type=str, default=None
+- `--model_path`: the path of the reward model (if continuing to train), type=str, default=None
+- `--save_path`: path to save the model, type=str, default='output'
+- `--need_optim_ckpt`: whether to save the optimizer checkpoint, type=bool, default=False
+- `--max_epochs`: max epochs for training, type=int, default=3
+- `--dataset`: dataset name, type=str, choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static']
+- `--subset`: subset of the dataset, type=str, default=None
+- `--batch_size`: batch size while training, type=int, default=4
+- `--lora_rank`: rank of the low-rank adaptation matrices, type=int, default=0
+- `--loss_func`: the loss function to use, choices=['log_sig', 'log_exp']
+- `--max_len`: max sentence length for generation, type=int, default=512
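+
+A minimal invocation sketch is shown below (assuming the Stage 2 training script is `train_reward_model.py` in this examples directory; the model path is a placeholder, so adjust paths and flags to your setup):
+
+```bash
+torchrun --standalone --nproc_per_node=2 train_reward_model.py \
+    --model 'llama' \
+    --pretrain "/path/to/LLaMa-7B/" \
+    --strategy colossalai_zero2 \
+    --dataset 'Anthropic/hh-rlhf' \
+    --loss_func 'log_sig' \
+    --batch_size 4 \
+    --max_epochs 1 \
+    --save_path 'rm_ckpt'
+```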
## Stage3 - Training model using prompts with RL
@@ -186,53 +208,89 @@ Stage3 uses reinforcement learning algorithm, which is the most complex part of
You can run the `examples/train_prompts.sh` to start PPO training.
+
You can also use the following command to start PPO training.
[[Stage3 tutorial video]](https://www.youtube.com/watch?v=Z8wwSHxPL9g)
-```
+```bash
torchrun --standalone --nproc_per_node=4 train_prompts.py \
- --pretrain "/path/to/LLaMa-7B/" \
- --model 'llama' \
- --strategy colossalai_zero2 \
- --prompt_dataset /path/to/your/prompt_dataset \
- --pretrain_dataset /path/to/your/pretrain_dataset \
- --rm_pretrain /your/pretrain/rm/definition \
- --rm_path /your/rm/model/path
+ --pretrain "/path/to/LLaMa-7B/" \
+ --model 'llama' \
+ --strategy colossalai_zero2 \
+ --prompt_dataset /path/to/your/prompt_dataset \
+ --pretrain_dataset /path/to/your/pretrain_dataset \
+ --rm_pretrain /your/pretrain/rm/definition \
+ --rm_path /your/rm/model/path
```
Prompt dataset: the instruction dataset mentioned in the figure above, which contains the instructions. For example, you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/generate_prompt_dataset.py) that samples `instinwild_en.json` or `instinwild_ch.json` in [InstructionWild](https://github.com/XueFuzhao/InstructionWild/tree/main/data#instructwild-data) to generate the prompt dataset.
Pretrain dataset: the pretraining dataset containing the instructions and corresponding responses. For example, you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) from stage 1 supervised instructs tuning.
+**Note**: the required datasets must follow the formats below:
+
+- `pretrain dataset`
+
+ ```json
+ [
+ {
+ "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
+ "input": "",
+ "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+ "id": 0
+ },
+ ...
+ ]
+ ```
+
+- `prompt dataset`
+
+ ```json
+ [
+ {
+ "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
+ "id": 0
+ },
+ {
+ "instruction": "Write a descriptive paragraph about a memorable vacation you went on",
+ "id": 1
+ },
+ ...
+ ]
+ ```
+
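+A quick way to sanity-check your own files against these formats (a minimal sketch, not part of the repo's scripts; the file paths are placeholders):
+
+```python
+import json
+
+def check_records(path: str, required_keys: set) -> None:
+    """Load a JSON list and assert that every record carries the required keys."""
+    with open(path, encoding="utf-8") as f:
+        records = json.load(f)
+    for rec in records:
+        missing = required_keys - rec.keys()
+        assert not missing, f"record {rec.get('id')} is missing keys: {missing}"
+
+# Pretrain dataset: instruction/input/output/id per record (see the format above).
+check_records("/path/to/your/pretrain_dataset.json", {"instruction", "input", "output", "id"})
+# Prompt dataset: instruction/id per record (see the format above).
+check_records("/path/to/your/prompt_dataset.json", {"instruction", "id"})
+```
+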
### Arg List
-- --strategy: the strategy using for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
-- --model: model type of actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
-- --pretrain: pretrain model, type=str, default=None
-- --rm_model: reward model type, type=str, choices=['gpt2', 'bloom', 'opt', 'llama'], default=None
-- --rm_pretrain: pretrain model for reward model, type=str, default=None
-- --rm_path: the path of rm model, type=str, default=None
-- --save_path: path to save the model, type=str, default='output'
-- --prompt_dataset: path of the prompt dataset, type=str, default=None
-- --pretrain_dataset: path of the ptx dataset, type=str, default=None
-- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False
-- --num_episodes: num of episodes for training, type=int, default=10
-- --num_update_steps: number of steps to update policy per episode, type=int
-- --num_collect_steps: number of steps to collect experience per episode, type=int
-- --train_batch_size: batch size while training, type=int, default=8
-- --ptx_batch_size: batch size to compute ptx loss, type=int, default=1
-- --experience_batch_size: batch size to make experience, type=int, default=8
-- --lora_rank: low-rank adaptation matrices rank, type=int, default=0
-- --kl_coef: kl_coef using for computing reward, type=float, default=0.1
-- --ptx_coef: ptx_coef using for computing policy loss, type=float, default=0.9
+
+- `--strategy`: the strategy used for training, choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], default='colossalai_zero2'
+- `--model`: model type of the actor, choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom'
+- `--pretrain`: pretrained model, type=str, default=None
+- `--rm_model`: reward model type, type=str, choices=['gpt2', 'bloom', 'opt', 'llama'], default=None
+- `--rm_pretrain`: pretrained model for the reward model, type=str, default=None
+- `--rm_path`: the path of the reward model, type=str, default=None
+- `--save_path`: path to save the model, type=str, default='output'
+- `--prompt_dataset`: path of the prompt dataset, type=str, default=None
+- `--pretrain_dataset`: path of the ptx dataset, type=str, default=None
+- `--need_optim_ckpt`: whether to save the optimizer checkpoint, type=bool, default=False
+- `--num_episodes`: number of episodes for training, type=int, default=10
+- `--num_update_steps`: number of steps to update the policy per episode, type=int
+- `--num_collect_steps`: number of steps to collect experience per episode, type=int
+- `--train_batch_size`: batch size while training, type=int, default=8
+- `--ptx_batch_size`: batch size for computing the ptx loss, type=int, default=1
+- `--experience_batch_size`: batch size for generating experience, type=int, default=8
+- `--lora_rank`: rank of the low-rank adaptation matrices, type=int, default=0
+- `--kl_coef`: KL coefficient used for computing the reward, type=float, default=0.1
+- `--ptx_coef`: ptx coefficient used for computing the policy loss, type=float, default=0.9
## Inference example - After Stage3
+
We support different inference options, including int8 and int4 quantization.
For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
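+
+As a rough illustration only: 8-bit loading of a trained actor checkpoint with plain `transformers` + `bitsandbytes` looks like the sketch below (this is not the repo's inference script; the checkpoint path is a placeholder and a recent `transformers`/`bitsandbytes` install is assumed):
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_path = "/path/to/your/actor/checkpoint"  # placeholder path
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+# load_in_8bit quantizes the linear layers via bitsandbytes at load time
+model = AutoModelForCausalLM.from_pretrained(model_path, load_in_8bit=True, device_map="auto")
+
+prompt = "Instruction: write a haiku about the sea.\nResponse:"
+inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+print(tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0], skip_special_tokens=True))
+```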
-
## Attention
+
The examples are demos for the whole training process. You need to change the hyper-parameters to achieve good performance.
#### data
+
- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
- [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [ ] [openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback)
@@ -242,14 +300,16 @@ The examples are demos for the whole training process.You need to change the hyp
## Support Model
### GPT
-- [x] GPT2-S (s)
-- [x] GPT2-M (m)
-- [x] GPT2-L (l)
-- [x] GPT2-XL (xl)
-- [x] GPT2-4B (4b)
-- [ ] GPT2-6B (6b)
+
+- [x] GPT2-S (s)
+- [x] GPT2-M (m)
+- [x] GPT2-L (l)
+- [x] GPT2-XL (xl)
+- [x] GPT2-4B (4b)
+- [ ] GPT2-6B (6b)
### BLOOM
+
- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
- [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
@@ -257,6 +317,7 @@ The examples are demos for the whole training process.You need to change the hyp
- [ ] [BLOOM-175b](https://huggingface.co/bigscience/bloom)
### OPT
+
- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m)
- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m)
- [x] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b)
@@ -266,10 +327,11 @@ The examples are demos for the whole training process.You need to change the hyp
- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b)
### [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
-- [x] LLaMA-7B
-- [x] LLaMA-13B
-- [ ] LLaMA-33B
-- [ ] LLaMA-65B
+
+- [x] LLaMA-7B
+- [x] LLaMA-13B
+- [ ] LLaMA-33B
+- [ ] LLaMA-65B
## Add your own models
@@ -282,12 +344,12 @@ if it is supported in huggingface [transformers](https://github.com/huggingface/
r you can build your own model by yourself.
### Actor model
-```
+
+```python
from ..base import Actor
from transformers.models.coati import CoatiModel
class CoatiActor(Actor):
-
def __init__(self,
pretrained: Optional[str] = None,
checkpoint: bool = False,
@@ -302,7 +364,8 @@ class CoatiActor(Actor):
```
### Reward model
-```
+
+```python
from ..base import RewardModel
from transformers.models.coati import CoatiModel
@@ -325,12 +388,11 @@ class CoatiRM(RewardModel):
### Critic model
-```
+```python
from ..base import Critic
from transformers.models.coati import CoatiModel
class CoatiCritic(Critic):
-
def __init__(self,
pretrained: Optional[str] = None,
checkpoint: bool = False,
diff --git a/applications/Chat/examples/community/README.md b/applications/Chat/examples/community/README.md
index cd7b9d99bf06..e14ac1767fc1 100644
--- a/applications/Chat/examples/community/README.md
+++ b/applications/Chat/examples/community/README.md
@@ -1,5 +1,9 @@
+:warning: **This content may be outdated since the major update of ColossalChat. We will update this content soon.**
+
# Community Examples
+
---
+
We are thrilled to announce the latest updates to ColossalChat, an open-source solution for cloning ChatGPT with a complete RLHF (Reinforcement Learning with Human Feedback) pipeline.
As Colossal-AI undergoes major updates, we are actively maintaining ColossalChat to stay aligned with the project's progress. With the introduction of community-driven examples, we aim to create a collaborative platform for developers to contribute exotic features built on top of ColossalChat.
@@ -14,11 +18,12 @@ For more information about community pipelines, please have a look at this [issu
Community examples consist of both inference and training examples that have been added by the community. Please have a look at the following table to get an overview of all community examples. Click on the Code Example to get a copy-and-paste ready code example that you can try out. If a community example doesn't work as expected, please open an issue and ping the author on it.
-| Example | Description | Code Example | Colab | Author |
-|:---------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------:|
-| Peft | Adding Peft support for SFT and Prompts model training | [Huggingface Peft](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/peft) | - | [YY Lin](https://github.com/yynil) |
-| Train prompts on Ray | A Ray based implementation of Train prompts example | [Training On Ray](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/ray) | - | [MisterLin1995](https://github.com/MisterLin1995) |
-|...|...|...|...|...|
+| Example | Description | Code Example | Colab | Author |
+| :------------------- | :----------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------- | :---- | ------------------------------------------------: |
+| Peft | Adding Peft support for SFT and Prompts model training | [Huggingface Peft](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/peft) | - | [YY Lin](https://github.com/yynil) |
+| Train prompts on Ray | A Ray based implementation of Train prompts example | [Training On Ray](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/ray) | - | [MisterLin1995](https://github.com/MisterLin1995) |
+| ... | ... | ... | ... | ... |
### How to get involved
+
To join our community-driven initiative, please visit the [ColossalChat GitHub repository](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples), review the provided information, and explore the codebase. To contribute, create a new issue outlining your proposed feature or enhancement, and our team will review and provide feedback. We look forward to collaborating with you on this exciting project!
diff --git a/applications/Chat/examples/community/peft/README.md b/applications/Chat/examples/community/peft/README.md
index 844bfd3d22c3..8b2edc48cd99 100644
--- a/applications/Chat/examples/community/peft/README.md
+++ b/applications/Chat/examples/community/peft/README.md
@@ -1,3 +1,5 @@
+:warning: **This content may be outdated since the major update of ColossalChat. We will update this content soon.**
+
# Add Peft support for SFT and Prompts model training
The original implementation just adopts loralib and merges the layers into the final model. The Hugging Face peft is a better LoRA implementation and can easily be trained and distributed.
@@ -5,7 +7,9 @@ The original implementation just adopts the loralib and merges the layers into t
Since the reward model is relatively small, I just keep it as the original one. I suggest training the full model to get a proper reward/critic model.
# Preliminary installation
+
Since the current PyPI peft package (0.2) has some bugs, please install the peft package from source.
+
```
git clone https://github.com/huggingface/peft
cd peft
@@ -13,6 +17,7 @@ pip install .
```
# Usage
+
For SFT training, just call train_peft_sft.py; a minimal launch sketch is given below.

Its arguments are almost identical to train_sft.py, except that it adds a new eval_dataset argument if you have an evaluation data file. The data file is just a plain data file; please check the format in easy_dataset.py.
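+
+A minimal launch sketch (assuming train_peft_sft.py keeps train_sft.py's core flags and adds `--eval_dataset`; check the script for the exact argument names, and the paths below are placeholders):
+
+```bash
+torchrun --standalone --nproc_per_node=1 train_peft_sft.py \
+    --model 'bloom' \
+    --pretrain 'bigscience/bloom-560m' \
+    --strategy colossalai_zero2 \
+    --dataset /path/to/train_data.txt \
+    --eval_dataset /path/to/eval_data.txt \
+    --save_path 'sft_peft_ckpt'
+```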
@@ -21,4 +26,5 @@ For stage-3 rlhf training, call train_peft_prompts.py.
Its arguments are almost identical to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretraining data files. The models are included in easy_models.py. Currently only BLOOM models are tested, but technically gpt2/opt/llama should be supported.
# Dataformat
+
Please refer to the formats in test_sft.txt, test_prompts.txt and test_pretrained.txt.
diff --git a/applications/Chat/examples/community/ray/README.md b/applications/Chat/examples/community/ray/README.md
index 64360bd73ddc..a679a58336a7 100644
--- a/applications/Chat/examples/community/ray/README.md
+++ b/applications/Chat/examples/community/ray/README.md
@@ -1,17 +1,31 @@
+:warning: **This content may be outdated since the major update of ColossalChat. We will update this content soon.**
+
# ColossalAI on Ray
+
## Abstract
+
This is an experimental effort to run ColossalAI Chat training on Ray.
+
## How to use?
+
### 1. Setup Ray clusters
+
Please follow the official [Ray cluster setup instructions](https://docs.ray.io/en/latest/cluster/getting-started.html) to set up a cluster with GPU support. Record the cluster's API server endpoint; it should be something similar to http://your.head.node.addrees:8265
+
### 2. Clone repo
+
Clone this project:
+
```shell
git clone https://github.com/hpcaitech/ColossalAI.git
```
+
### 3. Submit the ray job
+
```shell
python applications/Chat/examples/community/ray/ray_job_script.py http://your.head.node.addrees:8265
```
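+
+Under the hood this is a regular Ray job submission. A rough sketch of the equivalent call with Ray's Python SDK (the entrypoint below is illustrative; the actual `ray_job_script.py` may configure more, such as the runtime environment and training arguments):
+
+```python
+from ray.job_submission import JobSubmissionClient
+
+client = JobSubmissionClient("http://your.head.node.addrees:8265")
+job_id = client.submit_job(
+    # illustrative entrypoint; replace with the training command you want to run on the cluster
+    entrypoint="python applications/Chat/examples/train_sft.py --model bloom --strategy colossalai_zero2",
+    runtime_env={"working_dir": "./ColossalAI"},
+)
+print(f"Submitted job: {job_id}")
+```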
+
### 4. View your job on the Ray Dashboard
+
Open your Ray cluster dashboard at http://your.head.node.addrees:8265 to view your submitted training job.
diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt
index 40e6edc7ea73..5d0f9f927d17 100644
--- a/applications/Chat/examples/requirements.txt
+++ b/applications/Chat/examples/requirements.txt
@@ -1,2 +1,3 @@
pandas>=1.4.1
sentencepiece
+colossalai==0.3.1
\ No newline at end of file
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index 7585cf3ed0da..f068ea2bf5de 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -9,13 +9,15 @@
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
+from coati.models.chatglm import ChatGLMActor
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.trainer import get_scheduler
@@ -58,6 +60,8 @@ def train(args):
model = LlamaActor(pretrained=args.pretrain,
lora_rank=args.lora_rank,
checkpoint=args.grad_checkpoint)
+ elif args.model == 'chatglm':
+ model = ChatGLMActor(pretrained=args.pretrain)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -81,6 +85,9 @@ def train(args):
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer)
tokenizer.eos_token = '<\s>'
tokenizer.pad_token = tokenizer.unk_token
+ elif args.model == 'chatglm':
+ tokenizer = ChatGLMTokenizer.from_pretrained(
+ "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True)
else:
raise ValueError(f'Unsupported model "{args.model}"')
@@ -99,7 +106,6 @@ def train(args):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
else:
optim = Adam(model.parameters(), lr=args.lr)
-
logger = get_dist_logger()
# configure dataset
@@ -185,7 +191,7 @@ def train(args):
parser.add_argument('--strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
default='colossalai_zero2')
- parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
+ parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom')
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default=None)
diff --git a/applications/Chat/inference/README.md b/applications/Chat/inference/README.md
index 4848817e0fd1..eea4ef5b86ca 100644
--- a/applications/Chat/inference/README.md
+++ b/applications/Chat/inference/README.md
@@ -20,21 +20,21 @@ Tha data is from [LLaMA Int8 4bit ChatBot Guide v2](https://rentry.org/llama-tar
### 8-bit
-| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples |
-| :---: | :---: | :---: | :---: | :---: |
-| LLaMA-7B | 9.2GB | 10GB | 24GB | 3060 12GB, RTX 3080 10GB, RTX 3090 |
-| LLaMA-13B | 16.3GB | 20GB | 32GB | RTX 3090 Ti, RTX 4090 |
-| LLaMA-30B | 36GB | 40GB | 64GB | A6000 48GB, A100 40GB |
-| LLaMA-65B | 74GB | 80GB | 128GB | A100 80GB |
+| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples |
+| :-------: | :---------: | :-----------------: | :----------: | :--------------------------------: |
+| LLaMA-7B | 9.2GB | 10GB | 24GB | 3060 12GB, RTX 3080 10GB, RTX 3090 |
+| LLaMA-13B | 16.3GB | 20GB | 32GB | RTX 3090 Ti, RTX 4090 |
+| LLaMA-30B | 36GB | 40GB | 64GB | A6000 48GB, A100 40GB |
+| LLaMA-65B | 74GB | 80GB | 128GB | A100 80GB |
### 4-bit
-| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples |
-| :---: | :---: | :---: | :---: | :---: |
-| LLaMA-7B | 3.5GB | 6GB | 16GB | RTX 1660, 2060, AMD 5700xt, RTX 3050, 3060 |
-| LLaMA-13B | 6.5GB | 10GB | 32GB | AMD 6900xt, RTX 2060 12GB, 3060 12GB, 3080, A2000 |
-| LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 |
-| LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada |
+| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples |
+| :-------: | :---------: | :-----------------: | :----------: | :--------------------------------------------------------: |
+| LLaMA-7B | 3.5GB | 6GB | 16GB | RTX 1660, 2060, AMD 5700xt, RTX 3050, 3060 |
+| LLaMA-13B | 6.5GB | 10GB | 32GB | AMD 6900xt, RTX 2060 12GB, 3060 12GB, 3080, A2000 |
+| LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 |
+| LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada |
## General setup
diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt
index e079f8a6038d..eb1a77875acb 100644
--- a/applications/Chat/requirements-test.txt
+++ b/applications/Chat/requirements-test.txt
@@ -1 +1,2 @@
pytest
+colossalai==0.3.1
\ No newline at end of file
diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt
index af7ff67861eb..e5f5ca0932a8 100644
--- a/applications/Chat/requirements.txt
+++ b/applications/Chat/requirements.txt
@@ -2,7 +2,7 @@ transformers>=4.20.1
tqdm
datasets
loralib
-colossalai>=0.2.4
+colossalai==0.3.1
torch<2.0.0, >=1.12.1
langchain
tokenizers
diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py
index 64ea1178cd0d..f9dee1bae935 100644
--- a/applications/Chat/tests/test_dataset.py
+++ b/applications/Chat/tests/test_dataset.py
@@ -11,32 +11,46 @@
from datasets import load_dataset
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
SFT_DATASET = [
{
- "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
- "input": "",
- "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
- "id": 0
+ "instruction":
+ "Provide a list of the top 10 most popular mobile games in Asia",
+ "input":
+ "",
+ "output":
+ "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
+ "id":
+ 0
},
{
- "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
- "input": "",
- "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
- "id": 1
+ "instruction":
+ "Please provide an action plan for reducing carbon footprint on a corporate level",
+ "input":
+ "",
+ "output":
+ "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
+ "id":
+ 1
},
{
- "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
- "input": "",
- "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
- "id": 2
+ "instruction":
+ "Write a persuasive email to your boss explaining why you should have a pay raise",
+ "input":
+ "",
+ "output":
+ "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
+ "id":
+ 2
},
]
PROMPT_DATASET = [
{
- "instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
- "id": 0
+ "instruction":
+ "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
+ "id":
+ 0
},
{
"instruction": "Write a descriptive paragraph about a memorable vacation you went on",
@@ -66,14 +80,14 @@ def make_tokenizer(model: str):
elif model == "llama":
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.unk_token
+ elif model == "chatglm":
+ tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
else:
raise ValueError(f"Unsupported model '{model}'")
return tokenizer
-def check_content(input_ids_stripped: torch.Tensor,
- tokenizer: PreTrainedTokenizer,
- model: str):
+def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str):
if model == "opt":
# NOTE: Contrary to GPT2, OPT adds the EOS token to the beginning of every prompt.
assert input_ids_stripped[0] == tokenizer.eos_token_id
@@ -81,22 +95,25 @@ def check_content(input_ids_stripped: torch.Tensor,
elif model == "llama":
assert input_ids_stripped[0] == tokenizer.bos_token_id
input_ids_stripped = input_ids_stripped[1:]
-
+ elif model == "chatglm":
+ assert input_ids_stripped[0] == tokenizer.bos_token_id
+ assert input_ids_stripped[-1] == tokenizer.eos_token_id
+ input_ids_stripped = input_ids_stripped[1:-1]
assert torch.all(input_ids_stripped != tokenizer.pad_token_id)
assert torch.all(input_ids_stripped != tokenizer.bos_token_id)
assert torch.all(input_ids_stripped != tokenizer.eos_token_id)
assert input_ids_stripped != tokenizer.sep_token_id
assert input_ids_stripped != tokenizer.cls_token_id
- assert input_ids_stripped != tokenizer.mask_token_id
+ if model == "chatglm":
+ assert torch.all(input_ids_stripped != tokenizer.mask_token_id)
+ else:
+ assert input_ids_stripped != tokenizer.mask_token_id
-@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize("max_length", [32, 1024])
@pytest.mark.parametrize("max_datasets_size", [2])
-def test_prompt_dataset(model: str,
- max_datasets_size: int,
- max_length: int):
+def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
with tempfile.TemporaryDirectory() as tmp_dir:
dataset_name = "prompt_dataset.json"
with open(os.path.join(tmp_dir, dataset_name), "w") as f:
@@ -119,19 +136,12 @@ def test_prompt_dataset(model: str,
check_content(input_ids.masked_select(attention_mask), tokenizer, model)
-@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
-@pytest.mark.parametrize(["dataset_path", "subset"], [
- ("Anthropic/hh-rlhf", "harmless-base"),
- ("Dahoas/rm-static", None)
-])
+@pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"),
+ ("Dahoas/rm-static", None)])
@pytest.mark.parametrize("max_datasets_size", [32])
@pytest.mark.parametrize("max_length", [32, 1024])
-def test_reward_dataset(model: str,
- dataset_path: str,
- subset: Optional[str],
- max_datasets_size: int,
- max_length: int):
+def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int):
data = load_dataset(dataset_path, data_dir=subset)
assert max_datasets_size <= len(data["train"]) \
and max_datasets_size <= len(data["test"])
@@ -188,15 +198,12 @@ def test_reward_dataset(model: str,
assert torch.all(r_mask)
-@pytest.mark.cpu
-@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
+
+@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
@pytest.mark.parametrize("max_dataset_size", [2])
@pytest.mark.parametrize("max_length", [32, 1024])
-def test_sft_dataset(model: str,
- dataset_path: Optional[str],
- max_dataset_size: int,
- max_length: int):
+def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int):
tokenizer = make_tokenizer(model)
if dataset_path == "yizhongw/self_instruct":
data = load_dataset(dataset_path, "super_natural_instructions")
@@ -213,6 +220,19 @@ def test_sft_dataset(model: str,
max_length=max_length)
assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET))
+ if isinstance(tokenizer, ChatGLMTokenizer):
+ for i in range(max_dataset_size):
+ assert isinstance(sft_dataset[i], dict)
+ assert list(sft_dataset[i].keys()) == ["input_ids", "labels"]
+ input_ids = sft_dataset[i]["input_ids"]
+ labels = sft_dataset[i]["labels"]
+ assert input_ids.shape == labels.shape == torch.Size([max_length])
+
+ ignore_mask = labels == IGNORE_INDEX
+ assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id
+ check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model)
+ return
+
for i in range(max_dataset_size):
assert isinstance(sft_dataset[i], dict)
assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"]
@@ -232,10 +252,7 @@ def test_sft_dataset(model: str,
if __name__ == "__main__":
- test_sft_dataset(model="bloom",
- dataset_path="yizhongw/self_instruct",
- max_dataset_size=2,
- max_length=256)
+ test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256)
test_reward_dataset(model="gpt2",
dataset_path="Anthropic/hh-rlhf",
@@ -246,3 +263,4 @@ def test_sft_dataset(model: str,
test_prompt_dataset(model="opt",
max_datasets_size=2,
max_length=128)
+
diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py
index bd6b3e8a5ad1..b98b3615cd28 100644
--- a/applications/Chat/tests/test_models.py
+++ b/applications/Chat/tests/test_models.py
@@ -9,22 +9,26 @@
from coati.models.generation import generate
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
+from coati.models.chatglm import ChatGLMActor
from coati.models.lora import LoraLinear, convert_to_lora_module
from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
-
-@pytest.mark.gpu
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32])
-@pytest.mark.parametrize("actor_maker", [
- lambda: BLOOMActor(),
- lambda: GPTActor(),
+@pytest.mark.parametrize(
+ "actor_maker",
+ [
+ lambda: BLOOMActor(),
+ lambda: GPTActor(),
# HACK: skip llama due to long execution time
# lambda: LlamaActor(),
- lambda: OPTActor()
+ lambda: OPTActor(),
+ # lambda: ChatGLMActor(),
])
+
@pytest.mark.parametrize("generate_kwargs", [{
"max_length": 64,
"use_cache": True,
@@ -32,23 +36,15 @@
"temperature": 1.0,
"top_k": 50,
}])
-def test_generation(actor_maker: Callable[[], Actor],
- batch_size: int,
- seq_len: int,
- generate_kwargs: Dict[str, Any]
- ):
+def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
actor = actor_maker()
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
assert sequences.shape == (batch_size, generate_kwargs["max_length"])
-@pytest.mark.cpu
def test_utils():
- fn_input = {
- "tensor": torch.ones((10, )),
- "mask": torch.randint(0, 2, (10, ))
- }
+ fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))}
fn_output = masked_mean(dim=0, **fn_input)
assert fn_output.dim() == 0
assert torch.allclose(fn_output, torch.tensor(1.0))
@@ -56,14 +52,14 @@ def test_utils():
batch_size = 4
num_labels = 10
fn_input = {
- "r": torch.ones((batch_size, )),
+ "r": torch.ones((batch_size,)),
"kl_coef": 1.0,
"log_probs": torch.randn((batch_size, num_labels)),
"log_probs_base": torch.randn((batch_size, num_labels)),
"action_mask": torch.randint(0, 2, (batch_size, num_labels))
}
fn_output = compute_reward(**fn_input)
- assert fn_output.shape == (batch_size, )
+ assert fn_output.shape == (batch_size,)
batch_size = 4
seq_len = 32
@@ -80,17 +76,11 @@ def test_utils():
assert fn_output.shape == (batch_size, num_actions)
-@pytest.mark.cpu
@pytest.mark.parametrize("lora_rank", [4])
@pytest.mark.parametrize("num_dim", [32])
@pytest.mark.parametrize("num_layers", [4])
-def test_lora(lora_rank: int,
- num_dim: int,
- num_layers: int):
- model = nn.ModuleList(
- [nn.Linear(num_dim, num_dim)
- for _ in range(num_layers)]
- )
+def test_lora(lora_rank: int, num_dim: int, num_layers: int):
+ model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)])
lora_model = convert_to_lora_module(model, lora_rank)
assert isinstance(lora_model, nn.ModuleList)
for i in range(num_layers):
@@ -103,8 +93,7 @@ def test_lora(lora_rank: int,
assert isinstance(lora_model[i], LoraLinear)
assert torch.allclose(old_model[i].weight, lora_model[i].weight)
assert torch.allclose(old_model[i].bias, lora_model[i].bias)
- assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
- lora_model[i].lora_B @ lora_model[i].lora_A)
+ assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A)
optimizer = torch.optim.Adam(lora_model.parameters())
x = torch.randn(8, num_dim)
for i in range(num_layers):
@@ -120,21 +109,22 @@ def test_lora(lora_rank: int,
lora_model[i].lora_B @ lora_model[i].lora_A)
-@pytest.mark.cpu
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
-@pytest.mark.parametrize("models_maker", [
- lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
- lambda: (GPTActor(), GPTCritic(), GPTRM()),
+@pytest.mark.parametrize(
+ "models_maker",
+ [
+ lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
+ lambda: (GPTActor(), GPTCritic(), GPTRM()),
# HACK: skip llama due to long execution time
# lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
lambda: (OPTActor(), OPTCritic(), OPTRM()),
+ lambda: (ChatGLMActor(), None, None),
])
@torch.no_grad()
def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
batch_size: int,
seq_len: int):
-
actor_input = {
"input_ids": torch.randint(0, 100, (batch_size, seq_len)),
"attention_mask": torch.randint(0, 2, (batch_size, seq_len))
@@ -150,29 +140,36 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
}
actor, critic, rm = models_maker()
+ if isinstance(actor, ChatGLMActor):
+ actor = actor.float()
+ tokenizer = ChatGLMTokenizer.from_pretrained( "THUDM/chatglm-6b", trust_remote_code=True)
+ chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1)
+ actor_input ={
+ "input_ids": torch.cat((torch.randint(0, 100, (batch_size, seq_len//2)), chatglm_special_token, torch.randint(0, 100, (batch_size, seq_len//2 - 2))), dim=1),
+ "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))
+ }
assert isinstance(actor, Actor)
base_actor_model = get_base_model(actor)
- assert isinstance(critic, Critic)
- base_critic_model = get_base_model(critic)
- assert isinstance(rm, RewardModel)
- base_rm_model = get_base_model(rm)
-
actor_output = actor(**actor_input)
- critic_output = critic(**critic_input)
- rm_output = rm(**rm_input)
-
assert actor_output.logits.shape[:2] == (batch_size, seq_len)
- assert critic_output.shape == (batch_size, )
- assert rm_output.shape == (batch_size, )
+
+ if critic:
+ assert isinstance(critic, Critic)
+ base_critic_model = get_base_model(critic)
+ critic_output = critic(**critic_input)
+ assert critic_output.shape == (batch_size, )
+
+ if rm:
+ assert isinstance(rm, RewardModel)
+ base_rm_model = get_base_model(rm)
+ rm_output = rm(**rm_input)
+ assert rm_output.shape == (batch_size, )
-@pytest.mark.cpu
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize("num_labels", [100])
-def test_loss(batch_size: int,
- seq_len: int,
- num_labels: int):
+def test_loss(batch_size: int, seq_len: int, num_labels: int):
loss = GPTLMLoss()
loss_input = {
"logits": torch.randn(batch_size, seq_len, num_labels),
@@ -182,54 +179,43 @@ def test_loss(batch_size: int,
loss = PolicyLoss()
loss_input = {
- "log_probs": torch.randn(batch_size, ),
- "old_log_probs": torch.randn(batch_size, ),
- "advantages": torch.randn(batch_size, )
+ "log_probs": torch.randn(batch_size,),
+ "old_log_probs": torch.randn(batch_size,),
+ "advantages": torch.randn(batch_size,)
}
loss_output = loss(**loss_input)
loss = ValueLoss()
loss_input = {
- "values": torch.randn(batch_size, ),
- "old_values": torch.randn(batch_size, ),
- "reward": torch.randn(batch_size, )
+ "values": torch.randn(batch_size,),
+ "old_values": torch.randn(batch_size,),
+ "reward": torch.randn(batch_size,)
}
loss_output = loss(**loss_input)
loss = LogSigLoss()
loss_input = {
- "chosen_reward": torch.randn(batch_size, ),
- "reject_reward": torch.randn(batch_size, ),
+ "chosen_reward": torch.randn(batch_size,),
+ "reject_reward": torch.randn(batch_size,),
}
loss_output = loss(**loss_input)
loss = LogExpLoss()
loss_input = {
- "chosen_reward": torch.randn(batch_size, ),
- "reject_reward": torch.randn(batch_size, ),
+ "chosen_reward": torch.randn(batch_size,),
+ "reject_reward": torch.randn(batch_size,),
}
loss_output = loss(**loss_input)
if __name__ == "__main__":
- generate_kwargs = dict(max_length=40,
- use_cache=True,
- do_sample=True,
- temperature=1.0,
- top_k=50)
- test_generation(lambda: LlamaActor(),
- batch_size=4,
- seq_len=32,
- generate_kwargs=generate_kwargs)
+ generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50)
+ test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs)
test_utils()
test_lora(lora_rank=2, num_dim=8, num_layers=2)
- test_models(models_maker=lambda: (BLOOMActor(),
- BLOOMCritic(),
- BLOOMRM()),
- batch_size=8,
- seq_len=128)
+ test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
- test_loss(batch_size=8, seq_len=128, num_labels=100)
+ test_loss(batch_size=8, seq_len=128, num_labels=100)
\ No newline at end of file
diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py
new file mode 100644
index 000000000000..626a00c96d04
--- /dev/null
+++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py
@@ -0,0 +1,149 @@
+from typing import Dict, List
+
+import torch
+from torch import Tensor
+from torch.nn import Parameter
+from torch.optim import Optimizer
+
+from colossalai.interface import OptimizerWrapper
+
+from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
+
+
+class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
+
+ def __init__(self,
+ working_params: List[Parameter],
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32) -> None:
+ super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
+ max_scale)
+ self.params = working_params
+
+ def check_local_overflow(self) -> bool:
+ for p in self.params:
+ if p.grad is not None and not torch.isfinite(p.grad).all():
+ return True
+ return False
+
+
+class MixedPrecisionOptimizer(OptimizerWrapper):
+
+ def __init__(self,
+ optim: Optimizer,
+ precision: str = 'fp16',
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0.0):
+ super().__init__(optim)
+ if precision == 'fp16':
+ working_params = []
+ for group in self.optim.param_groups:
+ for p in group['params']:
+ working_params.append(p)
+ self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params,
+ initial_scale=initial_scale,
+ min_scale=min_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ max_scale=max_scale)
+ elif precision == 'bf16':
+ self.mixed_precision = BF16MixedPrecisionMixin()
+ else:
+ raise ValueError(f'Unsupported precision: {precision}')
+ if max_norm > 0.0:
+ raise NotImplementedError('max_norm is not supported yet.')
+ self.max_norm = max_norm
+ self.working_to_master_map: Dict[Parameter, Tensor] = {}
+ self.master_to_working_map: Dict[Tensor, Parameter] = {}
+
+ # create master weights
+ for group in self.optim.param_groups:
+ master_params = []
+ for p in group['params']:
+ if p.requires_grad:
+ master_p = p
+ if p.dtype != torch.float:
+ master_p = p.detach().float()
+ self.working_to_master_map[p] = master_p
+ self.master_to_working_map[master_p] = p
+ master_params.append(master_p)
+ group['params'] = master_params
+
+ def backward(self, loss: Tensor, *args, **kwargs):
+ loss = self.mixed_precision.pre_backward(loss)
+ loss.backward(*args, **kwargs)
+
+ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+ grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
+ tensor.backward(grad)
+
+ def zero_grad(self, *args, **kwargs):
+ for p in self.working_to_master_map.keys():
+ p.grad = None
+ self.mixed_precision.pre_zero_grad()
+ return super().zero_grad(*args, **kwargs)
+
+ def _unscale_and_clip_grads(self, total_norm: float) -> None:
+ div_scale = 1.0
+ if self.mixed_precision is not None:
+ div_scale = self.mixed_precision.get_grad_div_scale()
+
+ if self.max_norm > 0.:
+ # norm is in fact norm*scale
+ clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
+ if clip > 1:
+ div_scale = clip * div_scale
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ p.grad.data.mul_(1. / div_scale)
+
+ def _compute_grad_norm(self) -> float:
+ if self.max_norm <= 0.:
+ return 0.
+ grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
+ if len(grads) == 0:
+ return 0.
+ device = grads[0].device
+ # TODO(ver217): support tp
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
+ return total_norm.item()
+
+ def step(self, *args, **kwargs):
+ if self.mixed_precision.should_skip_step():
+ self.zero_grad()
+ return
+ # prepare grads
+ for group in self.optim.param_groups:
+ for p in group['params']:
+ working_param = self.master_to_working_map[p]
+ if p is working_param:
+ continue
+ if working_param.grad is not None:
+ p.grad = working_param.grad.data.float()
+ working_param.grad = None
+ total_norm = self._compute_grad_norm()
+ self._unscale_and_clip_grads(total_norm)
+ self.optim.step(*args, **kwargs)
+ # update working params
+ for group in self.optim.param_groups:
+ for p in group['params']:
+ working_param = self.master_to_working_map[p]
+ if p is working_param:
+ continue
+ working_param.data.copy_(p.data)
diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py
index d0c328e134ff..5b9f74b132f3 100644
--- a/colossalai/auto_parallel/offload/base_offload_module.py
+++ b/colossalai/auto_parallel/offload/base_offload_module.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn
-from colossalai.nn.parallel.data_parallel import _cast_float
+from colossalai.utils import _cast_float
from colossalai.zero.legacy.gemini.tensor_utils import free_storage
from .region_manager import RegionManager
diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 1a6dc7815176..0ed0742ee57e 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -144,7 +144,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
# DeviceMesh information instructs the scaling of the size value
device_mesh_info = {}
- for dim, dim_size in enumerate(device_mesh.mesh_shape):
+ for dim, dim_size in enumerate(device_mesh.shape):
device_mesh_info[dim] = dim_size
def _extract_target_dim(node):
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
index 8e06cec4f463..730a90d74cf8 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py
@@ -1,5 +1,4 @@
class Registry:
- # TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here
def __init__(self, name):
self.name = name
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index ec3dc7fc143f..fb9dae7c9650 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -1,6 +1,6 @@
import warnings
from contextlib import contextmanager
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Union
import torch
import torch.nn as nn
@@ -14,6 +14,7 @@
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
from .plugin import Plugin
+from .plugin.pp_plugin_base import PipelinePluginBase
__all__ = ['Booster']
@@ -23,32 +24,36 @@ class Booster:
Booster is a high-level API for training neural networks. It provides a unified interface for
training with different precision, accelerator, and plugin.
- Examples:
- ```python
- colossalai.launch(...)
- plugin = GeminiPlugin(...)
- booster = Booster(precision='fp16', plugin=plugin)
-
- model = GPT2()
- optimizer = HybridAdam(model.parameters())
- dataloader = Dataloader(Dataset)
- lr_scheduler = LinearWarmupScheduler()
- criterion = GPTLMLoss()
-
- model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
-
- for epoch in range(max_epochs):
- for input_ids, attention_mask in dataloader:
- outputs = model(input_ids, attention_mask)
- loss = criterion(outputs.logits, input_ids)
- booster.backward(loss, optimizer)
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad()
- ```
+
+ ```python
+ # Following is pseudocode
+
+ colossalai.launch(...)
+ plugin = GeminiPlugin(...)
+ booster = Booster(precision='fp16', plugin=plugin)
+
+ model = GPT2()
+ optimizer = HybridAdam(model.parameters())
+ dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ lr_scheduler = LinearWarmupScheduler()
+ criterion = GPTLMLoss()
+
+ model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)
+
+ for epoch in range(max_epochs):
+ for input_ids, attention_mask in dataloader:
+ outputs = model(input_ids.cuda(), attention_mask.cuda())
+ loss = criterion(outputs.logits, input_ids)
+ booster.backward(loss, optimizer)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+ ```
Args:
- device (str or torch.device): The device to run the training. Default: 'cuda'.
+ device (str or torch.device): The device to run the training. Default: None.
+ If plugin is not used or plugin doesn't control the device,
+ this argument will be set as training device ('cuda' will be used if argument is None).
mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
@@ -56,8 +61,8 @@ class Booster:
"""
def __init__(self,
- device: str = 'cuda',
- mixed_precision: Union[MixedPrecision, str] = None,
+ device: Optional[str] = None,
+ mixed_precision: Optional[Union[MixedPrecision, str]] = None,
plugin: Optional[Plugin] = None) -> None:
if plugin is not None:
assert isinstance(
@@ -67,13 +72,16 @@ def __init__(self,
# set accelerator
if self.plugin and self.plugin.control_device():
self.accelerator = None
- warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
+ if device is not None:
+ warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
else:
+ device = device or 'cuda'
self.accelerator = Accelerator(device)
# set precision
if self.plugin and self.plugin.control_precision():
- warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
+ if mixed_precision is not None:
+ warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
@@ -104,14 +112,19 @@ def boost(
lr_scheduler: Optional[LRScheduler] = None,
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
"""
- Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
+ Wrap and inject features to the passed in model, optimizer, criterion, lr_scheduler, and dataloader.
Args:
- model (nn.Module): The model to be boosted.
- optimizer (Optimizer): The optimizer to be boosted.
- criterion (Callable): The criterion to be boosted.
- dataloader (DataLoader): The dataloader to be boosted.
- lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
+ model (nn.Module): Convert model into a wrapped model for distributive training.
+ The model might be decorated or partitioned by plugin's strategy after execution of this method.
+ optimizer (Optimizer, optional): Convert optimizer into a wrapped optimizer for distributive training.
+ The optimizer's param groups or states might be decorated or partitioned by plugin's strategy after execution of this method. Defaults to None.
+ criterion (Callable, optional): The function that calculates loss. Defaults to None.
+ dataloader (DataLoader, optional): The prepared dataloader for training. Defaults to None.
+ lr_scheduler (LRScheduler, optional): The learning scheduler for training. Defaults to None.
+
+ Returns:
+ List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments.
"""
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
# TODO(FrankLeeeee): consider multi-dataloader case
@@ -132,26 +145,49 @@ def boost(
return model, optimizer, criterion, dataloader, lr_scheduler
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
- """Backward pass.
+ """Execution of backward during training step.
Args:
- loss (torch.Tensor): The loss to be backpropagated.
+ loss (torch.Tensor): The loss for backpropagation.
optimizer (Optimizer): The optimizer to be updated.
"""
- # TODO: implement this method with plugin
+ # TODO(frank lee): implement this method with plugin
optimizer.backward(loss)
def execute_pipeline(self,
data_iter: Iterator,
model: nn.Module,
- criterion: Callable[[torch.Tensor], torch.Tensor],
- optimizer: Optimizer,
+ criterion: Callable[[Any, Any], torch.Tensor],
+ optimizer: Optional[Optimizer] = None,
return_loss: bool = True,
- return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
- # TODO: implement this method
- # run pipeline forward backward pass
- # return loss or outputs if needed
- pass
+ return_outputs: bool = False) -> Dict[str, Any]:
+ """
+ Execute forward & backward when utilizing pipeline parallel.
+ Return loss or Huggingface style model outputs if needed.
+
+ Warning: This function is tailored for the scenario of pipeline parallel.
+ As a result, please don't do the forward/backward pass in the conventional way (model(input)/loss.backward())
+ when doing pipeline parallel training with booster, which will cause unexpected errors.
+
+ Args:
+ data_iter(Iterator): The iterator for getting the next batch of data. Usually there are two ways to obtain this argument:
+ 1. wrap the dataloader to iterator through: iter(dataloader)
+ 2. get the next batch from dataloader, and wrap this batch to iterator: iter([batch])
+ model (nn.Module): The model to execute forward/backward, it should be a model wrapped by a plugin that supports pipeline.
+ criterion: (Callable[[Any, Any], torch.Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+ 'lambda y, x: loss_fn(y)' can turn a normal loss function into a valid two-argument criterion here.
+ optimizer (Optimizer, optional): The optimizer for execution of backward. Can be None when only doing forward (i.e. evaluation). Defaults to None.
+ return_loss (bool, optional): Whether to return loss in the dict returned by this method. Defaults to True.
+ return_outputs (bool, optional): Whether to return Huggingface style model outputs in the dict returned by this method. Defaults to False.
+
+ Returns:
+ Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}.
+ ret_dict['loss'] is the loss of forward if return_loss is set to True, else None.
+ ret_dict['outputs'] is the Huggingface style model outputs during forward if return_outputs is set to True, else None.
+ """
+ assert isinstance(self.plugin,
+ PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
+ return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
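For context, a minimal sketch of how this new `execute_pipeline` API could be driven from a training loop. The `booster`, `model`, `optimizer`, `train_dataloader`, and `loss_fn` names are placeholders, not part of this diff; the sketch assumes they were already prepared and boosted with a pipeline-capable plugin such as `HybridParallelPlugin`:

```python
# Minimal sketch, assuming `booster` was built with a pipeline-capable plugin
# and that model/optimizer/dataloader were boosted beforehand.
def train_epoch(booster, model, optimizer, train_dataloader, loss_fn):
    model.train()
    for batch in train_dataloader:
        data_iter = iter([batch])                              # wrap one batch as an iterator
        criterion = lambda outputs, inputs: loss_fn(outputs)   # two-argument criterion
        ret = booster.execute_pipeline(data_iter,
                                       model,
                                       criterion,
                                       optimizer=optimizer,
                                       return_loss=True)
        loss = ret['loss']    # may be None on non-last pipeline stages
        optimizer.step()
        optimizer.zero_grad()
```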
def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
"""Context manager to disable gradient synchronization across DP process groups.
@@ -168,7 +204,7 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -
assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model, optimizer)
- def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
+ def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
"""Load model from checkpoint.
Args:
@@ -188,7 +224,7 @@ def save_model(self,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
- use_safetensors: bool = False):
+ use_safetensors: bool = False) -> None:
"""Save model to checkpoint.
Args:
@@ -196,7 +232,7 @@ def save_model(self,
checkpoint (str): Path to the checkpoint. It must be a local path.
It is a file path if ``shard=False``. Otherwise, it is a directory path.
shard (bool, optional): Whether to save checkpoint a sharded way.
- If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+ If true, the checkpoint will be a folder in the same format as a Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
@@ -211,7 +247,7 @@ def save_model(self,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors)
- def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+ def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
"""Load optimizer from checkpoint.
Args:
@@ -230,7 +266,7 @@ def save_optimizer(self,
shard: bool = False,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
- size_per_shard: int = 1024):
+ size_per_shard: int = 1024) -> None:
"""
Save optimizer to checkpoint.
@@ -247,7 +283,7 @@ def save_optimizer(self,
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
- def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Save lr scheduler to checkpoint.
Args:
@@ -256,7 +292,7 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
- def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Load lr scheduler from checkpoint.
Args:
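To illustrate the checkpointing entry points whose signatures are annotated above, here is a minimal sketch of saving and restoring sharded checkpoints through the booster. The paths and shard size are illustrative placeholders, not prescribed values:

```python
# Minimal sketch of the booster checkpoint I/O; './ckpt_dir' and './optim_dir'
# are placeholder paths.
booster.save_model(model,
                   checkpoint='./ckpt_dir',    # becomes a folder when shard=True
                   shard=True,                 # Huggingface-style sharded layout
                   size_per_shard=1024)
booster.save_optimizer(optimizer, checkpoint='./optim_dir', shard=True)

# Later, restore both states in place.
booster.load_model(model, checkpoint='./ckpt_dir')
booster.load_optimizer(optimizer, checkpoint='./optim_dir')
```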
diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py
index a3b87b5f11d3..f48bf38bd724 100644
--- a/colossalai/booster/plugin/__init__.py
+++ b/colossalai/booster/plugin/__init__.py
@@ -1,9 +1,10 @@
from .gemini_plugin import GeminiPlugin
+from .hybrid_parallel_plugin import HybridParallelPlugin
from .low_level_zero_plugin import LowLevelZeroPlugin
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
-__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']
+__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin']
import torch
from packaging import version
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 0f5ba6e9a6da..de03ba27bfda 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -1,13 +1,11 @@
import gc
import logging
import os
-import warnings
from pathlib import Path
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
-from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
@@ -16,16 +14,15 @@
from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
- get_shard_filename,
load_shard_state_dict,
+ save_config_file,
save_state_dict,
save_state_dict_shards,
)
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
-from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.zero.gemini import ZeroOptimizer
+from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
from .dp_plugin_base import DPPluginBase
@@ -111,6 +108,7 @@ def save_sharded_model(self,
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
+ save_config_file(model.module, checkpoint_path)
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
@@ -132,11 +130,7 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_
As there is communication when getting state dict, this must be called on all processes.
"""
- # If optimizer is wrapped, unwrap it.
- if isinstance(optimizer, OptimizerWrapper):
- optimizer = optimizer.unwrap()
-
- assert isinstance(optimizer, ZeroOptimizer)
+ assert isinstance(optimizer, GeminiOptimizer)
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -183,11 +177,7 @@ def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Pa
if not os.path.isfile(checkpoint_index_file):
logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
- # If optimizer is wrapped, unwrap it.
- if isinstance(optimizer, OptimizerWrapper):
- optimizer = optimizer.unwrap()
-
- assert isinstance(optimizer, ZeroOptimizer)
+ assert isinstance(optimizer, GeminiOptimizer)
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
@@ -220,47 +210,6 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
super().save_lr_scheduler(lr_scheduler, checkpoint)
-class GeminiModel(ModelWrapper):
-
- def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
- super().__init__(module)
- self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)
-
- def unwrap(self):
- # as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
- return self.module
-
-
-class GeminiOptimizer(OptimizerWrapper):
-
- def __init__(self,
- module: GeminiDDP,
- optimizer: Optimizer,
- zero_optim_config: dict,
- optim_kwargs: dict,
- verbose: bool = False) -> None:
- optimizer = zero_optim_wrapper(module,
- optimizer,
- optim_config=zero_optim_config,
- **optim_kwargs,
- verbose=verbose)
- super().__init__(optimizer)
-
- def backward(self, loss: Tensor, *args, **kwargs):
- self.optim.backward(loss)
-
- def clip_grad_by_norm(self,
- max_norm: Union[float, int],
- norm_type: Union[float, int] = 2,
- error_if_nonfinite: bool = False,
- *args,
- **kwargs) -> Tensor:
- warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
-
- def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
- raise NotImplementedError('Gemini does not support clip_grad_by_value')
-
-
class GeminiPlugin(DPPluginBase):
"""
Plugin for Gemini.
@@ -277,8 +226,20 @@ class GeminiPlugin(DPPluginBase):
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
Args:
- device (torch.device): device to place the model.
- placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
+ chunk_config_dict (dict, optional): chunk configuration dictionary.
+ chunk_init_device (torch.device, optional): device to initialize the chunk.
+ placement_policy (str, optional): "static" and "auto". Defaults to "static".
+ shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
+ If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
+ offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
+ If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0.
+ offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
+ For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
+ If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement.
+ When using static placement, we recommend users to tune `shard_param_frac` first and then `offload_optim_frac`.
+ Defaults to 0.0.
+ warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
+ steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
@@ -310,8 +271,14 @@ class GeminiPlugin(DPPluginBase):
def __init__(
self,
- device: Optional[torch.device] = None,
- placement_policy: str = "cpu",
+ chunk_config_dict: Optional[dict] = None,
+ chunk_init_device: Optional[torch.device] = None,
+ placement_policy: str = "static",
+ shard_param_frac: float = 1.0, # only for static placement
+ offload_optim_frac: float = 0.0, # only for static placement
+ offload_param_frac: float = 0.0, # only for static placement
+ warmup_non_model_data_ratio: float = 0.8, # only for auto placement
+ steady_cuda_cap_ratio: float = 0.9, # only for auto placement
precision: str = "fp16",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
@@ -335,8 +302,14 @@ def __init__(
super().__init__()
assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
self.gemini_config = dict(
- device=(device or get_current_device()),
+ chunk_config_dict=chunk_config_dict,
+ chunk_init_device=(chunk_init_device or get_current_device()),
placement_policy=placement_policy,
+ shard_param_frac=shard_param_frac,
+ offload_optim_frac=offload_optim_frac,
+ offload_param_frac=offload_param_frac,
+ warmup_non_model_data_ratio=warmup_non_model_data_ratio,
+ steady_cuda_cap_ratio=steady_cuda_cap_ratio,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=strict_ddp_mode,
@@ -393,12 +366,15 @@ def configure(
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
# wrap the model with Gemini
- model = GeminiModel(model, self.gemini_config, self.verbose)
+ model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
- optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
- self.verbose)
+ optimizer = GeminiOptimizer(optimizer,
+ model.unwrap(),
+ **self.zero_optim_config,
+ **self.optim_kwargs,
+ verbose=self.verbose)
return model, optimizer, criterion, dataloader, lr_scheduler
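As a quick reference for the argument changes above, a sketch of constructing `GeminiPlugin` with the new placement options. The concrete fractions are illustrative, not recommended values:

```python
# Sketch of the updated GeminiPlugin arguments; fractions are illustrative.
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Static placement: shard all parameters (ZeRO-3-like) and offload half of the
# optimizer states to CPU.
plugin = GeminiPlugin(placement_policy='static',
                      shard_param_frac=1.0,
                      offload_optim_frac=0.5,
                      precision='bf16')

# Auto placement: let Gemini place data based on runtime memory statistics.
# plugin = GeminiPlugin(placement_policy='auto', warmup_non_model_data_ratio=0.8)

booster = Booster(plugin=plugin)
```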
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
new file mode 100644
index 000000000000..125a9ccca1b5
--- /dev/null
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -0,0 +1,520 @@
+import random
+from contextlib import nullcontext
+from functools import partial
+from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from torch.nn import Module, SyncBatchNorm
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
+from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.zero.low_level import LowLevelZeroOptimizer
+
+from .pp_plugin_base import PipelinePluginBase
+
+DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
+
+
+def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
+ if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
+ return x.to(dtype)
+ return x
+
+
+class HybridParallelModule(ModelWrapper):
+
+ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
+ ddp_config: dict) -> None:
+
+ self.stage_manager = shard_config.pipeline_stage_manager
+ self.dp_group = dp_group
+
+ shardformer = ShardFormer(shard_config)
+ module, self.shared_params = shardformer.optimize(module)
+
+ # setting process groups for shared parameters
+ self.shared_param_process_groups = []
+ for shared_param in self.shared_params:
+ if len(shared_param) > 0:
+ self.shared_param_process_groups.append(
+ self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
+
+ # setting mixed_precision
+ self.mixed_precision = None
+ if precision == 'fp16':
+ self.mixed_precision = torch.float16
+ elif precision == 'bf16':
+ self.mixed_precision = torch.bfloat16
+ if self.mixed_precision is not None:
+ module = module.to(self.mixed_precision)
+ module = module.cuda()
+
+ # setting input type cast when using mixed precision
+ self.convert_fn = None
+ if self.mixed_precision is not None:
+ self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision)
+
+ # setting ddp configs
+ if use_ddp:
+ # convert model to sync bn
+ module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
+ # wrap the model with PyTorch DDP
+ module = DDP(module, process_group=dp_group, **ddp_config)
+
+ super().__init__(module)
+
+ def sync_shared_params(self):
+ for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
+ if self.stage_manager.stage in shared_param:
+ param = shared_param[self.stage_manager.stage]
+ dist.all_reduce(param.grad, group=group)
+ dist.barrier()
+
+ def no_sync(self) -> Iterator[None]:
+ # no sync grads across data parallel
+ return nullcontext()
+
+ def sync_grads(self):
+ # sync grad across data parallel
+ if self.dp_group.size() == 1:
+ return
+ for p in self.module.parameters():
+ if p.grad is not None:
+ dist.all_reduce(p.grad, group=self.dp_group)
+ p.grad.div_(self.dp_group.size())
+
+ def forward(self, *args, **kwargs):
+ if self.convert_fn is not None:
+ args = tree_map(self.convert_fn, args)
+ kwargs = tree_map(self.convert_fn, kwargs)
+ return super().forward(*args, **kwargs)
+
+ def unwrap(self):
+ module = super().unwrap()
+ if isinstance(module, DDP):
+ module = module.module
+ return module
+
+
+def get_param_info(optim: Optimizer):
+ # Get a backup of necessary information of parameters for future use, which includes:
+ # 1. A complete param_group, with params in the form of param_id
+ # 2. A mapping from param address (obtained using id(param)) to integer param_id
+ # 3. A mapping from integer param_id to param address.
+ # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding.
+ # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer.
+
+ if optim is None:
+ return {}
+ param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
+ start_index = 0
+ for group in optim.param_groups:
+
+ packed_group = {k: v for k, v in group.items() if k != 'params'}
+ packed_group['params'] = []
+
+ for param_id, param in enumerate(group['params'], start_index):
+ original_shape = param.shape if isinstance(param, torch.Tensor) else None
+ packed_group['params'].append(param_id)
+ param_info['param2id'][id(param)] = param_id
+ param_info['id2param'][param_id] = id(param)
+ param_info['param2shape'][id(param)] = original_shape
+
+ param_info['param_groups'].append(packed_group)
+ start_index += len(group['params'])
+
+ return param_info
+
+
+def init_pipeline_optimizer(optim: Optimizer, model: Module):
+ model_params = set(model.parameters())
+ new_param_groups = []
+ for group in optim.param_groups:
+ params = [p for p in group['params'] if p in model_params]
+ new_param_groups.append({**group, 'params': params})
+ optim.__setstate__({'param_groups': new_param_groups})
+
+
+class HybridParallelNaiveOptimizer(OptimizerWrapper):
+
+ def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
+ self.param_info = param_info
+ if use_pipeline:
+ init_pipeline_optimizer(optim, model)
+ super().__init__(optim)
+
+
+class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
+
+ def __init__(self,
+ optim: Optimizer,
+ model: Module,
+ use_pipeline: bool,
+ param_info: OrderedDict,
+ precision: str = 'fp16',
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0):
+ self.param_info = param_info
+ if use_pipeline:
+ init_pipeline_optimizer(optim, model)
+ super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
+ hysteresis, max_scale, max_norm)
+
+
+class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ model: Module,
+ use_pipeline: bool,
+ param_info: OrderedDict,
+ initial_scale: int = 2**16, # grad scaler config
+ min_scale: int = 1,
+ growth_factor: float = 2.,
+ backoff_factor: float = .5,
+ growth_interval: int = 2000,
+ hysteresis: int = 2,
+ max_scale: int = 2**24,
+ clip_grad_norm: float = 0.0, # grad clipping
+ verbose: bool = False,
+ reduce_bucket_size: int = 1024 * 1024, # communication
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = True,
+ partition_grad: bool = False, # stage 2 flag
+ cpu_offload: bool = False, # cpu offload
+ dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
+ tp_process_group: Optional[ProcessGroup] = None, # if using tp
+ forced_dtype: Optional[torch.dtype] = None):
+ self.param_info = param_info
+ if use_pipeline:
+ init_pipeline_optimizer(optimizer, model)
+ super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
+ hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype,
+ overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group,
+ forced_dtype)
+
+
+class HybridParallelPlugin(PipelinePluginBase):
+ """
+ Plugin for Hybrid Parallel Training.
+ Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
+ The size of tp and pp should be passed in by the user; the size of dp is then automatically computed as dp_size = world_size / (tp_size * pp_size).
+
+ Example:
+ >>> from colossalai.booster import Booster
+ >>> from colossalai.booster.plugin import HybridParallelPlugin
+
+ >>> model, train_dataset, optimizer, criterion = ...
+ >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
+
+ >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ >>> booster = Booster(plugin=plugin)
+ >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+
+ Args:
+ tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
+ pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+ precision (str, optional): Specifies the precision of parameters during training.
+ Automatic mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise the model is trained with 'fp32'.
+ Defaults to 'fp16'.
+ zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
+ When set to 0, ZeRO will not be used. Defaults to 0.
+ enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
+ The optimizations currently include fused normalization, flash attention and JIT fusion.
+ Defaults to False.
+ enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False.
+ enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
+ enable_jit_fused (bool, optional): Whether to switch on JIT fusion. Defaults to False.
+ num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
+ microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
+ Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
+ If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
+ initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
+ min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
+ growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
+ backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
+ growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
+ hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
+ max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
+ max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
+ broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
+ ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
+ find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
+ check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
+ gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
+ static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
+ zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
+ cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
+ communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
+ overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+ """
+
+ def __init__(self,
+ tp_size: int,
+ pp_size: int,
+ precision: str = 'fp16',
+ zero_stage: int = 0,
+ enable_all_optimization: bool = False,
+ enable_fused_normalization: bool = False,
+ enable_flash_attention: bool = False,
+ enable_jit_fused: bool = False,
+ enable_sequence_parallelism: bool = False,
+ enable_sequence_overlap: bool = False,
+ num_microbatches: Optional[int] = None,
+ microbatch_size: Optional[int] = None,
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0,
+ broadcast_buffers: bool = True,
+ ddp_bucket_cap_mb: int = 25,
+ find_unused_parameters: bool = False,
+ check_reduction: bool = False,
+ gradient_as_bucket_view: bool = False,
+ static_graph: bool = False,
+ zero_bucket_size_in_m: int = 12,
+ cpu_offload: bool = False,
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = True) -> None:
+
+ super().__init__()
+ assert dist.get_world_size() % (
+ tp_size * pp_size
+ ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
+
+ if enable_sequence_parallelism:
+ assert tp_size > 1, 'Tensor parallelism must be enabled when using sequence parallelism'
+
+ self.tp_size = tp_size
+ self.pp_size = pp_size
+ self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+ self.precision = precision
+ self.zero_stage = zero_stage
+ self.cpu_offload = cpu_offload
+ self.enable_all_optimization = enable_all_optimization
+ self.enable_fused_normalization = enable_fused_normalization
+ self.enable_flash_attention = enable_flash_attention
+ self.enable_jit_fused = enable_jit_fused
+ self.enable_sequence_parallelism = enable_sequence_parallelism
+ self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
+ self.stage_manager = None
+ self.schedule = None
+ assert zero_stage in (0, 1, 2)
+ if self.pp_size > 1:
+ assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
+ assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
+ self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
+ self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
+ num_microbatches=num_microbatches,
+ microbatch_size=microbatch_size)
+ self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+ self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
+ self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+ self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
+ pipeline_stage_manager=self.stage_manager,
+ enable_tensor_parallelism=self.tp_size > 1,
+ enable_all_optimization=self.enable_all_optimization,
+ enable_fused_normalization=self.enable_fused_normalization,
+ enable_flash_attention=self.enable_flash_attention,
+ enable_jit_fused=self.enable_jit_fused,
+ enable_sequence_parallelism=enable_sequence_parallelism,
+ enable_sequence_overlap=enable_sequence_overlap)
+ self.amp_config = dict(
+ initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ )
+
+ self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
+ bucket_cap_mb=ddp_bucket_cap_mb,
+ find_unused_parameters=find_unused_parameters,
+ check_reduction=check_reduction,
+ gradient_as_bucket_view=gradient_as_bucket_view,
+ static_graph=static_graph)
+
+ self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
+ communication_dtype=communication_dtype,
+ overlap_communication=overlap_communication,
+ cpu_offload=cpu_offload,
+ partition_grad=(self.zero_stage == 2))
+
+ self.max_norm = max_norm
+
+ @property
+ def enable_pipeline_parallelism(self) -> bool:
+ return self.pp_size > 1
+
+ def supported_devices(self) -> List[str]:
+ return ['cuda']
+
+ def supported_precisions(self) -> List[str]:
+ return ['fp16', 'bf16', 'fp32']
+
+ def control_device(self) -> bool:
+ return True
+
+ def control_precision(self) -> bool:
+ return True
+
+ def support_no_sync(self) -> bool:
+ return False
+
+ def control_checkpoint_io(self) -> bool:
+ return True
+
+ def configure(
+ self,
+ model: Module,
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ dataloader: Optional[DataLoader] = None,
+ lr_scheduler: Optional[LRScheduler] = None,
+ ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+ param_info = get_param_info(optimizer)
+ if not isinstance(model, ModelWrapper):
+ use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+ model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
+ self.ddp_config)
+ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+ if self.zero_stage == 0:
+ if self.precision in ['fp16', 'bf16']:
+ optimizer = HybridParallelAMPOptimizer(optimizer,
+ model,
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info,
+ precision=self.precision,
+ max_norm=self.max_norm,
+ **self.amp_config)
+ self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
+ optimizer.master_to_working_map)
+ else:
+ optimizer = HybridParallelNaiveOptimizer(optimizer,
+ model,
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info)
+ else:
+ assert self.dp_size > 1, "ZeRO requires a data parallel size greater than 1."
+ assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
+ optimizer = HybridParallelZeroOptimizer(optimizer,
+ model,
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info,
+ dp_process_group=self.dp_group,
+ tp_process_group=self.tp_group,
+ verbose=True,
+ clip_grad_norm=self.max_norm,
+ **self.zero_config,
+ **self.amp_config)
+ self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
+ optimizer._param_store.master_to_working_param)
+
+ return model, optimizer, criterion, dataloader, lr_scheduler
+
+ def execute_pipeline(self,
+ data_iter: Iterator,
+ model: HybridParallelModule,
+ criterion: Callable[[Any, Any], torch.Tensor],
+ optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer,
+ HybridParallelZeroOptimizer]] = None,
+ return_loss: bool = True,
+ return_outputs: bool = False) -> dict:
+ assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled'
+ # return loss or outputs if needed
+ ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
+ with ctx:
+ outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss,
+ return_outputs)
+ model.sync_shared_params()
+ if isinstance(optimizer, HybridParallelZeroOptimizer):
+ optimizer.sync_grad()
+ else:
+ model.sync_grads()
+ return outputs
+
+ def prepare_dataloader(self,
+ dataset,
+ batch_size,
+ shuffle=False,
+ seed=1024,
+ drop_last=False,
+ pin_memory=False,
+ num_workers=0,
+ **kwargs):
+ r"""
+ Prepare a dataloader for distributed training. The dataloader will be wrapped by
+ `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
+
+
+ Args:
+ dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+ shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+ seed (int, optional): Random worker seed for sampling, defaults to 1024.
+ batch_size (int): The batch size used on each data parallel rank.
+ drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+ is not divisible by the batch size. If False and the size of dataset is not divisible by
+ the batch size, then the last batch will be smaller, defaults to False.
+ pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+ num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
+ kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
+ `DataLoader <https://pytorch.org/docs/stable/generated/torch.utils.data.DataLoader.html#torch.utils.data.DataLoader>`_.
+
+ Returns:
+ :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+ """
+ _kwargs = kwargs.copy()
+ sampler = DistributedSampler(dataset,
+ num_replicas=self.pg_mesh.size(DP_AXIS),
+ rank=self.pg_mesh.coordinate(DP_AXIS),
+ shuffle=shuffle)
+
+ # Deterministic dataloader
+ def seed_worker(worker_id):
+ worker_seed = seed
+ np.random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+ random.seed(worker_seed)
+
+ return DataLoader(dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ worker_init_fn=seed_worker,
+ drop_last=drop_last,
+ pin_memory=pin_memory,
+ num_workers=num_workers,
+ **_kwargs)
+
+ def get_checkpoint_io(self) -> CheckpointIO:
+ self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+ return self.checkpoint_io
+
+ def no_sync(self, model: Module) -> Iterator[None]:
+ raise NotImplementedError
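To make the constraints enforced in `__init__` above concrete, a sketch of configuring the new plugin for combined tensor, pipeline, and ZeRO-1 data parallelism. The values are illustrative only; note that `num_microbatches` (or `microbatch_size`) is required because `pp_size > 1`, and `zero_stage` must be 0 or 1 when pipeline parallelism is enabled:

```python
# Sketch: 2-way TP x 2-way PP x ZeRO-1 on 8 GPUs (dp_size = 8 / (2 * 2) = 2).
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=2,
                              pp_size=2,
                              zero_stage=1,          # must be 0 or 1 with pipeline
                              num_microbatches=4,
                              precision='fp16',      # ZeRO here requires fp16/bf16
                              enable_fused_normalization=True)
booster = Booster(plugin=plugin)

# train_dataset, model, optimizer and criterion are assumed to exist.
# train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8, shuffle=True)
# model, optimizer, criterion, train_dataloader, _ = booster.boost(
#     model, optimizer, criterion, train_dataloader)
```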
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 616b218b2070..9adb4beec9b9 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -3,6 +3,7 @@
import warnings
from functools import partial
from pathlib import Path
+from types import MethodType
from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch
@@ -17,12 +18,17 @@
from colossalai.checkpoint_io.utils import (
get_optimizer_base_filenames,
get_shard_filename,
+ load_param_groups_into_optimizer,
+ load_shard_state_dict,
+ load_states_into_optimizer,
save_param_groups,
save_state_dict,
+ sharded_optimizer_loading_epilogue,
+ unwrap_optimizer,
)
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
-from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer
from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -39,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
+class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
+
+ def __init__(self, module: nn.Module, precision: str) -> None:
+ super().__init__(module)
+ self.dtype = None
+ if precision == 'fp16':
+ self.dtype = torch.float16
+ elif precision == 'bf16':
+ self.dtype = torch.bfloat16
+ if self.dtype is not None:
+ module = module.to(self.dtype)
+ module = module.to(get_current_device())
+ self.module = module
+ self.convert_fn = None
+ if self.dtype is not None:
+ self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+
+ def forward(self, *args, **kwargs):
+ if self.convert_fn is not None:
+ args = tree_map(self.convert_fn, args)
+ kwargs = tree_map(self.convert_fn, kwargs)
+ return super().forward(*args, **kwargs)
+
+ def unwrap(self):
+ # TODO(ver217): this is a workaround for loading model
+ return self
+
+
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -126,44 +160,70 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
index_file_path (str): Path to the index file
prefix (str): Not used.
"""
- super().load_sharded_optimizer(optimizer, index_file_path, prefix)
- current_rank_state_dict = optimizer.optim.state_dict()['state']
- for param_idx, state in current_rank_state_dict.items():
- for k, v in state.items():
- if isinstance(v, torch.Tensor) and k != 'step':
- padding_size = (self.coordinator.world_size -
- v.numel() % self.coordinator.world_size) % self.coordinator.world_size
- with torch.no_grad():
- v = v.flatten()
- if padding_size > 0:
- v = torch.nn.functional.pad(v, [0, padding_size])
- v_list = v.split(v.numel() // self.coordinator.world_size)
- current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
-
-
-class LowLevelZeroModel(ModelWrapper):
-
- def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
- super().__init__(module)
- self.dtype = None
- if precision == 'fp16':
- self.dtype = torch.float16
- elif precision == 'bf16':
- self.dtype = torch.bfloat16
- module = zero_model_wrapper(module, zero_stage=stage)
- if self.dtype is not None:
- module = module.to(self.dtype)
- module = module.to(get_current_device())
- self.module = module
- self.convert_fn = None
- if self.dtype is not None:
- self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
-
- def forward(self, *args, **kwargs):
- if self.convert_fn is not None:
- args = tree_map(self.convert_fn, args)
- kwargs = tree_map(self.convert_fn, kwargs)
- return super().forward(*args, **kwargs)
+ # If optimizer is wrapped, unwrap it.
+ if isinstance(optimizer, OptimizerWrapper):
+ optimizer = unwrap_optimizer(optimizer)
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+ # Load param_groups
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
+ Lacking param group file under current directory.')
+ id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+ checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+ for shard_file in checkpoint_files:
+ state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+ # shard state dict
+ for param_idx, state in state_dict.items():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor) and k != 'step':
+ padding_size = (self.coordinator.world_size -
+ v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ v_list = v.split(v.numel() // self.coordinator.world_size)
+ state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
+ load_states_into_optimizer(optimizer, state_dict, id_map)
+
+ sharded_optimizer_loading_epilogue(optimizer)
+
+ def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
+ use_safetensors: bool):
+ assert isinstance(model, LowLevelZeroModel)
+ super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
+
+ def save_sharded_model(self,
+ model: nn.Module,
+ checkpoint_path: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ max_shard_size: int = 1024,
+ use_safetensors: bool = False):
+ assert isinstance(model, LowLevelZeroModel)
+ super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
+ use_safetensors)
+
+ def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
+ assert isinstance(model, LowLevelZeroModel)
+ super().load_unsharded_model(model.module, checkpoint, strict)
+ model.update_master_params()
+
+ def load_sharded_model(self,
+ model: LowLevelZeroModel,
+ checkpoint_index_file: Path,
+ strict: bool = False,
+ use_safetensors: bool = False,
+ load_sub_module: bool = True):
+ assert isinstance(model, LowLevelZeroModel)
+ super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+ model.update_master_params()
class LowLevelZeroPlugin(DPPluginBase):
@@ -223,22 +283,24 @@ def __init__(
super().__init__()
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
-
+ assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
self.stage = stage
self.precision = precision
- self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
- communication_dtype=communication_dtype,
- overlap_communication=overlap_communication,
- cpu_offload=cpu_offload)
- self.optim_kwargs = dict(initial_scale=initial_scale,
- growth_factor=growth_factor,
- backoff_factor=backoff_factor,
- growth_interval=growth_interval,
- hysteresis=hysteresis,
- min_scale=min_scale,
- max_scale=max_scale,
- max_norm=max_norm,
- norm_type=norm_type)
+ self.zero_optim_kwargs = dict(
+ initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ clip_grad_norm=max_norm,
+ reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
+ communication_dtype=communication_dtype,
+ overlap_communication=overlap_communication,
+ cpu_offload=cpu_offload,
+ partition_grad=(stage == 2),
+ )
self.verbose = verbose
# set class name with stage, for better error message
@@ -269,15 +331,15 @@ def configure(
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
- model = LowLevelZeroModel(model, self.stage, self.precision)
+ model = LowLevelZeroModel(model, self.precision)
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
- optimizer = zero_optim_wrapper(model.unwrap(),
- optimizer,
- optim_config=self.zero_optim_config,
- **self.optim_kwargs,
- verbose=self.verbose)
+ optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
+ **self.zero_optim_kwargs,
+ verbose=self.verbose)
+ # inject update_master_params
+ model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler
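The `MethodType` call above binds the optimizer's `update_master_params` to the model wrapper, so that checkpoint loading can refresh the fp32 master copies without holding a reference to the optimizer. A toy, self-contained illustration of that binding pattern; the classes are stand-ins, not the real plugin objects:

```python
# Toy illustration of the MethodType injection used in configure() above.
from types import MethodType

class ToyOptimizer:
    def update_master_params(self, model):
        print(f"syncing master params for {type(model).__name__}")

class ToyModel:
    pass

optim = ToyOptimizer()
model = ToyModel()
# Bind the optimizer's method to the model instance, mirroring the injection above.
model.update_master_params = MethodType(optim.update_master_params, model)
model.update_master_params()    # -> "syncing master params for ToyModel"
```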
diff --git a/colossalai/booster/plugin/pp_plugin_base.py b/colossalai/booster/plugin/pp_plugin_base.py
new file mode 100644
index 000000000000..f52844db082f
--- /dev/null
+++ b/colossalai/booster/plugin/pp_plugin_base.py
@@ -0,0 +1,21 @@
+from abc import abstractmethod
+from typing import Any, Callable, Iterator, Optional
+
+import torch
+
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+
+from .plugin_base import Plugin
+
+
+class PipelinePluginBase(Plugin):
+
+ @abstractmethod
+ def execute_pipeline(self,
+ data_iter: Iterator,
+ model: ModelWrapper,
+ criterion: Callable[[Any, Any], torch.Tensor],
+ optimizer: Optional[OptimizerWrapper] = None,
+ return_loss: bool = True,
+ return_outputs: bool = False) -> dict:
+ pass
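For orientation, a hedged sketch of what a concrete subclass of the new `PipelinePluginBase` has to provide. Everything below is illustrative; a real plugin must also implement the remaining abstract methods of `Plugin` (`configure`, `control_device`, and so on), and would drive an actual pipeline schedule instead of the plain forward/backward stub shown here:

```python
from typing import Any, Callable, Iterator, Optional

import torch

from colossalai.booster.plugin.pp_plugin_base import PipelinePluginBase
from colossalai.interface import ModelWrapper, OptimizerWrapper


class MyPipelinePlugin(PipelinePluginBase):

    def execute_pipeline(self,
                         data_iter: Iterator,
                         model: ModelWrapper,
                         criterion: Callable[[Any, Any], torch.Tensor],
                         optimizer: Optional[OptimizerWrapper] = None,
                         return_loss: bool = True,
                         return_outputs: bool = False) -> dict:
        # Stub only: assumes the batch is a dict of keyword arguments.
        batch = next(data_iter)
        outputs = model(**batch)
        loss = criterion(outputs, batch)
        if optimizer is not None:
            optimizer.backward(loss)
        return {'loss': loss if return_loss else None,
                'outputs': outputs if return_outputs else None}
```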
diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py
index c25048e25754..07b1f81dace6 100644
--- a/colossalai/checkpoint_io/__init__.py
+++ b/colossalai/checkpoint_io/__init__.py
@@ -1,5 +1,6 @@
from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
+from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO
from .index_file import CheckpointIndexFile
-__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
+__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index 83e4bdcc863b..faaf1d22722a 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -23,6 +23,7 @@
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
+ save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
@@ -78,8 +79,6 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
load_states_into_optimizer(optimizer, state_dict, id_map)
- del state_dict
- gc.collect()
sharded_optimizer_loading_epilogue(optimizer)
@@ -185,6 +184,7 @@ def save_sharded_model(self,
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
+ save_config_file(model, checkpoint_path, is_master=True)
logging.info(f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
new file mode 100644
index 000000000000..fef5b0d16d60
--- /dev/null
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -0,0 +1,702 @@
+import copy
+import gc
+import logging
+import os
+from pathlib import Path
+from shutil import rmtree
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
+from colossalai.interface import OptimizerWrapper
+
+from .general_checkpoint_io import GeneralCheckpointIO
+from .index_file import CheckpointIndexFile
+from .utils import (
+ StateDictSharder,
+ gather_distributed_param,
+ get_model_base_filenames,
+ get_optimizer_base_filenames,
+ is_safetensors_available,
+ load_shard_state_dict,
+ load_state_dict_into_model,
+ load_states_into_optimizer,
+ save_config_file,
+ save_param_groups,
+ save_state_dict_shards,
+ search_tp_partition_dim,
+ sharded_optimizer_loading_epilogue,
+)
+
+try:
+ from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
+except ImportError:
+ _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
+
+
+class HypridParallelCheckpointIO(GeneralCheckpointIO):
+ """
+ CheckpointIO for Hybrid Parallel Training.
+
+ Args:
+ dp_group (ProcessGroup): Process group along data parallel dimension.
+ pp_group (ProcessGroup): Process group along pipeline parallel dimension.
+ tp_group (ProcessGroup): Process group along tensor parallel dimension.
+ zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2].
+ verbose (bool, optional): Whether to print a logging message when saving/loading has been successfully executed. Defaults to True.
+ """
+
+ def __init__(self,
+ dp_group: ProcessGroup,
+ pp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ zero_stage: int,
+ verbose: bool = True) -> None:
+ super().__init__()
+ self.dp_group = dp_group
+ self.pp_group = pp_group
+ self.tp_group = tp_group
+ self.dp_rank = dist.get_rank(self.dp_group)
+ self.tp_rank = dist.get_rank(self.tp_group)
+ self.pp_rank = dist.get_rank(self.pp_group)
+ self.dp_size = dist.get_world_size(dp_group)
+ self.pp_size = dist.get_world_size(pp_group)
+ self.tp_size = dist.get_world_size(tp_group)
+ self.use_zero = (zero_stage > 0)
+ self.verbose = verbose
+ self.working_to_master_map = None
+ self.master_to_working_map = None
+
+ @staticmethod
+ def _model_sharder(model: nn.Module,
+ prefix: str = '',
+ keep_vars: bool = False,
+ size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
+ # An internal method that breaks the state_dict of the model into shards within a limited size.
+
+ state_dict_sharder = StateDictSharder(size_per_shard)
+
+ # Save parameters.
+ for name, param in model.named_parameters():
+ if param is None:
+ continue
+ # Gather tensor pieces when using tensor parallel.
+ param_ = gather_distributed_param(param, keep_vars=False)
+ block, block_size = state_dict_sharder.append_param(prefix + name, param_)
+ if block is not None:
+ yield block, block_size
+
+ # Save buffers.
+ for name, buf in model.named_buffers():
+ if buf is not None and name not in model._non_persistent_buffers_set:
+ buffer = buf if keep_vars else buf.detach()
+ block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
+ if block is not None:
+ yield block, block_size
+
+ # Save extra states.
+ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+ if getattr(model.__class__, "get_extra_state",
+ torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+ extra_state = model.get_extra_state()
+ block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
+ if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+ @staticmethod
+ def _optimizer_sharder(optimizer: OptimizerWrapper,
+ use_zero: bool,
+ dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
+ size_per_shard: int = 1024):
+
+ # An internal method that breaks the state_dict of the optimizer into shards within a limited size.
+
+ state_dict_sharder = StateDictSharder(size_per_shard)
+ param_info = optimizer.param_info
+
+ for param, state in optimizer.optim.state.items():
+
+ if param is None:
+ continue
+
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+
+ param_id = param_info['param2id'][id(working_param)]
+ original_shape = param_info['param2shape'][id(working_param)]
+ state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
+ working_param,
+ original_shape=original_shape,
+ dp_group=dp_group,
+ tp_group=tp_group,
+ use_zero=use_zero,
+ inplace=False)
+
+ block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
+ if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+ def save_sharded_model(self,
+ model: nn.Module,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False) -> None:
+ """
+ Save sharded model checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
+ - Multiple files that store state tensors of models.
+ If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin".
+ If pipeline parallelism is not used, "pytorch_model.-000XX.bin"
+
+
+ Args:
+ model (nn.Module): Model on local device to be saved.
+ checkpoint (str): Checkpointing path which should be a directory path.
+ gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
+ prefix (str, optional): Prefix of file to save. Defaults to None.
+ size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
+ use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+ """
+
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+ # Devices along the same dp_group share the same copies of model.
+ # So only let the device with dp_rank == 0 save the model.
+ if self.dp_rank != 0:
+ return
+
+ # Then collect the sharded parameters & buffers along tp_group.
+ # Only devices with tp_rank == 0 are responsible for model saving.
+ state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
+ weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
+ index_file = CheckpointIndexFile(checkpoint)
+ control_saving = (self.tp_rank == 0)
+
+ if self.pp_size == 1:
+ # When pipeline is not used, save the model shards as in general checkpointIO
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors)
+ if control_saving:
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ save_config_file(model, checkpoint)
+ if self.verbose:
+ logging.info(f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}.")
+
+ else:
+ # When pipeline is used, each stage produces its own shard files and index files.
+ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+ final_index_file_path = copy.deepcopy(save_index_file)
+ tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+ Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+ # Manage filenames of sharded weights and index file for each pipeline stage.
+ weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+ weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
+ save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
+ save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ use_pp_format=True)
+ if control_saving:
+ assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ else:
+ return
+
+ dist.barrier(self.pp_group)
+
+ # The global master rank integrates the index files and clean the folder.
+ if self.pp_rank == 0:
+ final_index_file = CheckpointIndexFile(checkpoint)
+ final_index_file.append_meta_data("total_size", 0)
+
+ for filename in os.listdir(tmp_index_file_folder):
+ stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+ final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+ for weight, weight_filename in stage_index_file.weight_map.items():
+ final_index_file.append_weight_map(weight, weight_filename)
+
+ final_index_file.write_index_file(final_index_file_path)
+ save_config_file(model, checkpoint)
+ rmtree(tmp_index_file_folder)
+ if self.verbose:
+ logging.info(f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {final_index_file_path}.")
+
+ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
+ """
+ Load sharded model with the given path to index file of checkpoint folder.
+
+ Args:
+ model (nn.Module): The model to be loaded.
+ checkpoint_index_file (str): Path to the index file of checkpointing folder.
+ strict (bool, optional): For name matching during loading state_dict. Defaults to False.
+ This argument should be manually set to False since params on the same device might be stored in different files.
+ """
+
+ # Check whether the checkpoint uses safetensors.
+ use_safetensors = False
+ if "safetensors" in checkpoint_index_file.name:
+ use_safetensors = True
+
+ if use_safetensors and not is_safetensors_available():
+ raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+ ckpt_root_path = ckpt_index_file.root_path
+ weight_map = ckpt_index_file.weight_map
+ strict = False
+
+ # Load params & buffers to model.
+ # Keep a record of loaded files so that file will not be repeatedly loaded.
+ loaded_file = set()
+
+ def _load(name: str):
+ if name not in weight_map:
+ raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
+ filename = weight_map[name]
+
+ # If this param/buffer has been loaded before, directly return.
+ if filename in loaded_file:
+ return
+
+ file_path = os.path.join(ckpt_root_path, filename)
+ state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
+ missing_keys = []
+
+ load_state_dict_into_model(model,
+ state_dict,
+ missing_keys=missing_keys,
+ strict=strict,
+ load_sub_module=True)
+ loaded_file.add(filename)
+
+ # Load parameters.
+ for name, _ in model.named_parameters():
+ _load(name)
+
+ # Load buffers.
+ non_persistent_buffers = set()
+ for n, m in model.named_modules():
+ non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set)
+ for name, buf in model.named_buffers():
+ if buf is not None and name not in non_persistent_buffers:
+ _load(name)
+
+ # Load extra states.
+ extra_state_key = _EXTRA_STATE_KEY_SUFFIX
+ if getattr(model.__class__, "get_extra_state",
+ torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+ _load(extra_state_key)
+
+ # Update master params if mixed-precision training is enabled.
+ with torch.no_grad():
+ if self.working_to_master_map is not None:
+ for param in model.parameters():
+ if (param is None) or (id(param) not in self.working_to_master_map):
+ continue
+ master_param = self.working_to_master_map[id(param)]
+ if self.use_zero:
+ # master_param is sharded under Zero setting
+ padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size
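+                        # e.g. a param with 10 elements and dp_size=4 is padded with 2 zeros and split into 4 shards of 3 elements.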
+ if padding_size > 0:
+ padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
+ else:
+ padded_param = param.data.view(-1)
+ sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank]
+ master_param.data.copy_(sharded_param.data)
+ else:
+ master_param.data.copy_(param.data)
+
+ if self.verbose:
+ logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+ def save_sharded_optimizer(self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024):
+ """
+ Save sharded optimizer checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+ - A group file (pytorch_optim_group.bin) recording information of param_groups
+ - Multiple files that store state tensors of optimizers.
+          If pipeline parallelism is used, the filenames are in the form of "pytorch_optim-stage-000XX-shard-000XX.bin".
+          If pipeline parallelism is not used, they are in the form of "pytorch_optim-000XX.bin".
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+ checkpoint (str): Path to save optimizer state_dict
+ gather_dtensor (bool): Whether to gather_dtensor, not used
+            prefix (str): Prefix of the files to save
+ size_per_shard (int): Max file size of each file shard that store state tensors
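+
+        Example (illustrative, assuming the default filenames above) of the final files under `checkpoint`
+        when pp_size=2 and each stage writes a single shard:
+          - pytorch_optim.bin.index.json
+          - pytorch_optim_group.bin
+          - pytorch_optim-stage-00001-shard-00001.bin
+          - pytorch_optim-stage-00002-shard-00001.bin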
+ """
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+ # Devices along the same dp_group share the same copies of states when zero is not used.
+        # In this case only let the device with dp_rank == 0 save the states.
+ if not self.use_zero and self.dp_rank != 0:
+ return
+
+ # Then collect the sharded states along dp_group(if using zero)/tp_group.
+ # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
+ state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(
+ optimizer,
+ use_zero=self.use_zero,
+ dp_group=self.dp_group,
+ tp_group=self.tp_group,
+ master_to_working_map=self.master_to_working_map,
+ size_per_shard=size_per_shard)
+ states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+ control_saving = (self.dp_rank == 0 and self.tp_rank == 0)
+
+ if self.pp_size == 1:
+ # When pipeline is not used, save the optimizer shards as in general checkpointIO
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving)
+
+ if control_saving:
+ # Store param groups.
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ save_param_groups(optimizer.param_info, group_file_path)
+ # Store index file.
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ if self.verbose:
+                logging.info(f"The optimizer is split into checkpoint shards. "
+                             f"You can find where each parameter's states have been saved in the "
+                             f"index located at {save_index_file}.")
+
+ else:
+ # When pipeline is used, each stage produces its own shard files and index files.
+ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+ final_index_file_path = copy.deepcopy(save_index_file)
+ tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+ Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+ # Manage filenames of sharded weights and index file for each pipeline stage.
+ states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+ save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
+ save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+ total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ use_pp_format=True)
+
+ if control_saving:
+ assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ else:
+ return
+
+ dist.barrier(self.pp_group)
+
+            # The global master rank integrates the index files and cleans the folder.
+ if self.pp_rank == 0:
+
+ final_index_file = CheckpointIndexFile(checkpoint)
+ final_index_file.append_meta_data("total_size", 0)
+
+ for filename in os.listdir(tmp_index_file_folder):
+ stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+ final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+ for param_id, state_filename in stage_index_file.weight_map.items():
+ final_index_file.append_weight_map(param_id, state_filename)
+
+ # Store param groups.
+ final_index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ save_param_groups(optimizer.param_info, group_file_path)
+
+ final_index_file.write_index_file(final_index_file_path)
+ rmtree(tmp_index_file_folder)
+
+ if self.verbose:
+                logging.info(f"The optimizer is split into checkpoint shards. "
+                             f"You can find where each parameter's states have been saved in the "
+                             f"index located at {final_index_file_path}.")
+
+ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
+ """
+ Load sharded optimizer with the given path to index file of checkpoint folder.
+
+ Args:
+ optimizer (OptimizerWrapper): The optimizer to be loaded.
+ checkpoint_index_file (str): Path to the index file of checkpointing folder.
+ prefix (str): Not used.
+ """
+
+ def _get_param_id_from_optimizer_param(param: torch.Tensor,
+ master_to_working_map: Optional[Dict[int, torch.Tensor]] = None):
+ if master_to_working_map is not None:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ return optimizer.param_info['param2id'][id(working_param)]
+
+        # id_map is a mapping from param ids kept by the current pipeline stage to their corresponding parameter objects.
+        # When Zero is used, the mapped parameter objects should be fp32 master parameters.
+        # IDs are obtained through the param2id mapping saved earlier in optimizer.param_info.
+ id_map = {}
+ for pg in optimizer.optim.param_groups:
+ for param in pg['params']:
+ param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+ id_map[param_id] = param
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+ ckpt_root_path = ckpt_index_file.root_path
+ weight_map = ckpt_index_file.weight_map
+ weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
+
+ # Load param_groups
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
+ Lacking param group file under current directory.')
+ saved_groups = torch.load(param_group_path)
+
+ updated_groups = []
+ for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+ # obtain updated param group
+ new_pg = copy.deepcopy(saved_pg)
+            new_pg['params'] = old_pg['params']    # The parameters in the same group shouldn't change.
+ updated_groups.append(new_pg)
+ optimizer.optim.__dict__.update({'param_groups': updated_groups})
+
+ # Load saved states to optimizer.
+        # Keep a record of loaded files so that files will not be repeatedly loaded.
+ loaded_file = set()
+ for pg in optimizer.optim.param_groups:
+ for param in pg['params']:
+ if param is None:
+ continue
+ param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+ if param_id not in weight_map:
+ continue
+ filename = weight_map[param_id]
+
+                # If this param's states have been loaded before, skip it.
+ if filename in loaded_file:
+ continue
+
+ file_path = os.path.join(ckpt_root_path, filename)
+ state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+ load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+ loaded_file.add(filename)
+
+ # Then shard the loaded optimizer states if using tp/zero.
+ for param, state in optimizer.optim.state.items():
+ device = param.device
+ if self.master_to_working_map is not None:
+ working_param = self.master_to_working_map[id(param)]
+ else:
+ working_param = param
+ original_shape = optimizer.param_info['param2shape'][id(working_param)]
+ sharded_state = self.shard_from_complete_optimizer_state(state,
+ current_shape=working_param.shape,
+ original_shape=original_shape,
+ device=device,
+ inplace=True)
+ optimizer.optim.state[param] = sharded_state
+
+ sharded_optimizer_loading_epilogue(optimizer.optim)
+ if self.verbose:
+ logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+ # TODO(Baizhou): support this feature after implementing complete state_dict collection
+ raise NotImplementedError
+
+ def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+ # TODO(Baizhou): support this feature after implementing complete state_dict collection
+ raise NotImplementedError
+
+ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+ # TODO(Baizhou): support this feature after implementing complete state_dict collection
+ raise NotImplementedError
+
+ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+ # TODO(Baizhou): support this feature after implementing complete state_dict collection
+ raise NotImplementedError
+
+ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ """
+ Save lr scheduler to checkpoint but only on master process.
+ """
+ if self.coordinator.is_master():
+ super().save_lr_scheduler(lr_scheduler, checkpoint)
+
+ def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
+ master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]):
+ """
+        Create mappings between working params (for forward/backward) and master params (for optimizer update) from the passed-in mappings.
+        These mappings can only be created when mixed precision is used.
+        The created mappings map integer parameter addresses (ids) to parameter objects.
+
+ Args:
+            working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameter objects/addresses to master parameter objects.
+            master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameter objects/addresses to working parameter objects.
+ """
+ self.working_to_master_map = dict()
+ for k, v in working_to_master_map.items():
+ if isinstance(k, torch.Tensor):
+ self.working_to_master_map[id(k)] = v
+ elif isinstance(k, int):
+ self.working_to_master_map[k] = v
+ else:
+ raise ValueError(
+ f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
+
+ self.master_to_working_map = dict()
+ for k, v in master_to_working_map.items():
+ if isinstance(k, torch.Tensor):
+ self.master_to_working_map[id(k)] = v
+ elif isinstance(k, int):
+ self.master_to_working_map[k] = v
+ else:
+ raise ValueError(
+ f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
+
+ @staticmethod
+ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size,
+ dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool,
+ inplace: bool) -> OrderedDict:
+ """
+ With given parameter and its optimizer states, gather the complete optimizer state for saving.
+
+ Args:
+ state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
+ param (torch.Tensor): The given parameter. It should be working_param when using Zero.
+ original_shape (torch.Size): The size of parameter before sharding.
+ dp_group (ProcessGroup): The process group of data parallel.
+ tp_group (ProcessGroup): The process group of tensor parallel.
+ use_zero (bool): Whether Zero is used.
+ inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
+
+ Returns:
+ OrderedDict: The complete optimizer state of given parameter.
+ """
+ dp_size = dist.get_world_size(dp_group)
+ tp_size = dist.get_world_size(tp_group)
+ current_shape = param.shape
+ state_ = state if inplace else copy.deepcopy(state)
+
+ for k, v in state_.items():
+ if isinstance(v, torch.Tensor) and k != 'step':
+
+ # First gather Zero shards.
+ if use_zero:
+ v = v.cuda()
+ gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
+ dist.all_gather(gather_tensor, v, group=dp_group)
+ v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param)
+
+ # Then gather TP shards.
+ partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
+ if partition_dim is not None:
+ gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
+ dist.all_gather(gather_tensor, v, group=tp_group)
+ v = torch.cat(gather_tensor, dim=partition_dim)
+
+ state_[k] = v.detach().clone().cpu()
+
+ return state_
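+
+    # Illustrative example: with dp_size=2, tp_size=2 and a working param of shape (4, 8) that was
+    # sharded from an original (4, 16) along dim 1, each Zero rank holds 16 elements of 'exp_avg';
+    # the two dp shards are all-gathered and reshaped to (4, 8), then the two tp shards are
+    # concatenated along dim 1 to recover the complete (4, 16) optimizer state.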
+
+ def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size,
+ original_shape: torch.Size, device: torch.device,
+ inplace: bool) -> OrderedDict:
+ """
+ With complete optimizer states of a specific parameter loaded from checkpoint,
+ slice out the sharded optimizer states kept by current device.
+
+ Args:
+ state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
+ current_shape (torch.Size): The size of parameter after sharding.
+ original_shape (torch.Size): The size of parameter before sharding.
+ device (torch.device): The destination device of loaded optimizer states.
+ inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
+
+ Returns:
+ OrderedDict: The sharded optimizer state of the given parameter.
+ """
+ state_ = state if inplace else copy.deepcopy(state)
+
+ for k, v in state_.items():
+ if isinstance(v, torch.Tensor) and k != 'step':
+
+ # Shard state along tensor parallel group.
+ partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
+ if partition_dim is not None:
+ slice_size = current_shape[partition_dim]
+ v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
+
+ # Shard state along data parallel group when using Zero.
+ if self.use_zero:
+ padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ slice_size = v.numel() // self.dp_size
+ v = v.split(slice_size, dim=0)[self.dp_rank]
+
+ state_[k] = v.detach().clone().to(device)
+
+ return state_
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 8837776aee4d..3441eca38ce7 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -1,4 +1,5 @@
# coding=utf-8
+import copy
import os
import re
from collections import abc as container_abcs
@@ -11,9 +12,14 @@
import torch.nn as nn
from torch.optim import Optimizer
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.tensor.d_tensor import is_distributed_tensor
+from colossalai.tensor.d_tensor import (
+ is_customized_distributed_tensor,
+ is_distributed_tensor,
+ to_global,
+ to_global_for_customized_distributed_tensor,
+)
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
@@ -88,8 +94,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
return False
+def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]:
+ """
+    Given the current shape of a parameter and its shape before sharding,
+    return the dimension along which the parameter is sharded when tensor parallelism is used.
+    If the parameter is not sharded along any dimension, return None.
+
+ Args:
+ current_shape (torch.Size): The current shape of parameter after sharding.
+ original_shape (torch.Size): The shape of parameter before sharding.
+ tp_size (int): The size of tp group.
+
+ Returns:
+ Optional[int]: The dimension along which parameter is partitioned.
+ """
+ partition_dim = None
+ for dim, length in enumerate(original_shape):
+ if length > current_shape[dim]:
+ partition_dim = dim
+ break
+ if partition_dim is not None:
+ assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \
+ f"The parameter isn't evenly distributed among tensor parallel group: \
+ shape before sharding {original_shape}, shape after sharding {current_shape}"
+
+ return partition_dim
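+
+# Illustrative example: a parameter of original shape (1024, 1024) sharded along dim 1 across a
+# tensor parallel group of size 4 has current shape (1024, 256), so
+# search_tp_partition_dim((1024, 256), (1024, 1024), 4) returns 1; an unsharded parameter returns None.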
+
+
# ======================================
-# Helper functions for saving shard file
+# Helper classes and functions for saving shard file
# ======================================
def unwrap_optimizer(optimizer: OptimizerWrapper):
'''
@@ -104,12 +137,97 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
return unwrapped_optim
+class StateDictSharder:
+
+ def __init__(self, size_per_shard: int) -> None:
+ self.max_shard_size = size_per_shard
+ self.current_block = OrderedDict()
+ self.current_block_size = 0
+
+ def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
+
+ tensor_size = calculate_tensor_size(tensor)
+ ret_block = None
+ ret_block_size = 0
+
+ # before we return the current block and create a new block,
+ # we need to ensure that the current block is not empty
+ if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
+ ret_block = self.current_block
+ ret_block_size = self.current_block_size
+ self.current_block = OrderedDict()
+ self.current_block_size = 0
+
+ self.current_block[name] = tensor
+ self.current_block_size += tensor_size
+ return ret_block, ret_block_size
+
+ def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
+
+        # A state might contain more than one tensor.
+ # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
+ state_size = 0
+ isDTensor = False
+ for state_tensor in state.values():
+
+            # When state_tensor is not of Tensor class,
+            # e.g., an SGD optimizer with momentum set to 0 can have None as state,
+            # the calculation of tensor size should be skipped to avoid errors.
+ if not isinstance(state_tensor, torch.Tensor):
+ continue
+
+ # If the states are stored as DTensors, mark isDTensor as true.
+ if is_distributed_tensor(state_tensor):
+ isDTensor = True
+ state_size += calculate_tensor_size(state_tensor)
+
+ ret_block = None
+ ret_block_size = 0
+
+ # directly return if state is stored as distributed tensor
+ if isDTensor:
+ return ret_block, ret_block_size
+
+ # before we return the current block and create a new block,
+ # we need to ensure that the current block is not empty
+ if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0:
+ ret_block = self.current_block
+ ret_block_size = self.current_block_size
+ self.current_block = OrderedDict()
+ self.current_block_size = 0
+
+ self.current_block[param_id] = state
+ self.current_block_size += state_size
+ return ret_block, ret_block_size
+
+
+def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor:
+ """
+    Gather the complete parameter for saving if the passed-in param is distributed under the tp setting.
+
+ Args:
+ param (torch.Tensor): A model parameter, might be d_tensor.
+ keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False.
+
+ Returns:
+ torch.Tensor: the complete parameter
+ """
+ param_ = param if keep_vars else param.detach()
+ if is_distributed_tensor(param_):
+ return to_global(param_)
+ elif is_customized_distributed_tensor(param_):
+ return to_global_for_customized_distributed_tensor(param_)
+ else:
+ return param_
+
+
def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
- use_safetensors: bool = False) -> int:
+ use_safetensors: bool = False,
+ use_pp_format: bool = False) -> int:
'''
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args:
@@ -117,18 +235,21 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards.
- is_master (bool): Whether current rank is master.
- use_safetensors (bool): Whether to use safetensors to save checkpoint.
+ is_master (bool): Whether current rank is main process.
+ use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
+        use_pp_format (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
Returns:
int: the total size of shards
'''
total_size = 0
+ shard_filenames = []
for idx, shard_pair in enumerate(sharded_state_dict):
+ shard, current_size = shard_pair
if not is_master:
+ del shard
continue
- shard, current_size = shard_pair
shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size
for key in shard.keys():
@@ -137,6 +258,11 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]]
# Only save on master rank.
save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
+ shard_filenames.append(shard_file)
+ del shard
+
+    # Clean the folder, deleting unneeded files.
+ clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
return total_size
@@ -146,28 +272,17 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
- current_block = {}
- current_block_size = 0
+ state_dict_sharder = StateDictSharder(max_shard_size)
for key, weight in state_dict.items():
- ret_block = None
- ret_block_size = 0
if not is_distributed_tensor(weight):
- weight_size = calculate_tensor_size(weight)
-
- # If this weight is going to tip up over the maximal size, we split.
- if current_block_size + weight_size > max_shard_size and current_block_size > 0:
- ret_block = current_block
- ret_block_size = current_block_size
- current_block = {}
- current_block_size = 0
- current_block[key] = weight
- current_block_size += weight_size
+ block, block_size = state_dict_sharder.append_param(key, weight)
- if ret_block != None:
- yield ret_block, ret_block_size
+ if block != None:
+ yield block, block_size
- yield current_block, current_block_size
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
@@ -178,47 +293,212 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
states = state_dict['state']
-
- current_block = {}
- current_block_size = 0
+ state_dict_sharder = StateDictSharder(max_shard_size)
for param_id, state in states.items():
+ block, block_size = state_dict_sharder.append_optim_state(param_id, state)
+ if block != None:
+ yield block, block_size
- ret_block = None
- ret_block_size = 0
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
- # A state might contain more than one tensors.
- # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
- state_size = 0
- isDTensor = False
- for state_tensor in state.values():
- # When state_tensor is not of Tensor class,
- # e.g., a SGD optimizer with momentum set to 0 can have None as state
- # The calculation of tensor size should be skipped to avoid error.
- if not isinstance(state_tensor, torch.Tensor):
- continue
+# ======================================
+# Helper functions for saving state dict
+# ======================================
- # If the states are stored as DTensors, mark isDTensor as true.
- if is_distributed_tensor(state_tensor):
- isDTensor = True
- state_size += calculate_tensor_size(state_tensor)
- if not isDTensor:
+def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
+ """
+ Save state dict to checkpoint.
+
+ Args:
+ state_dict (dict): state dict.
+ checkpoint_file_path (str): path to the checkpoint file.
+ use_safetensors (bool): whether to use safetensors to save the checkpoint.
+ """
+ if use_safetensors:
+ assert is_safetensors_available(), "safetensors is not available."
+ assert checkpoint_file_path.endswith('.safetensors'), \
+ "safetensors only supports .safetensors suffix for checkpoint file."
+ from safetensors.torch import save_file as safe_save_file
+ safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
+ else:
+ torch.save(state_dict, checkpoint_file_path)
+
+
+def save_param_groups(state_dict: dict, group_file_path: str) -> None:
+ """
+ Save information of param_groups to given file path.
+
+ Args:
+ state_dict (dict): state dict.
+ group_file_path (str): path to the group file.
+ """
+ param_groups = state_dict["param_groups"]
+ torch.save(param_groups, group_file_path)
+
+
+def clean_folder(checkpoint_path: str,
+ weights_name: str,
+ shard_filenames: List[str],
+ is_master: bool = True,
+ use_pp_format: bool = False):
+ """
+ Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
+
+ Args:
+ checkpoint_path (str): Path to the checkpoint directory.
+ weights_name (str): Decides the prefix of filenames of weight shards.
+ shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
+ is_master (bool, optional): Whether current rank is main process. Defaults to True.
+        use_pp_format (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
+
+ """
+ if is_master:
+ for filename in os.listdir(checkpoint_path):
+ full_filename = os.path.join(checkpoint_path, filename)
+ weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+ filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+ if not use_pp_format:
+ reg = re.compile(r"(.*?)-\d{5}")
+ else:
+ # When this checkpoint is created by pipeline parallel process, the pattern is a little different.
+ reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
+ if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
+ and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
+ os.remove(full_filename)
+
+
+def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
+ """
+ Save config.json/generation_config.json if model is a Huggingface pretrained model.
+ This method can only be called when a model is saved in a sharded way.
+
+ Args:
+ model (nn.Module): The model whose config should be saved if it's a huggingface model.
+ checkpoint_path (str): Path to the checkpoint directory.
+ is_master (bool): Whether current rank is main process.
+ """
+ try:
+ from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
+ from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
+ except ImportError:
+ return
+ if not isinstance(model, PreTrainedModel):
+ return
+
+ model = unwrap_huggingface_model(model)
+
+ # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
+ dtype = get_parameter_dtype(model)
+ model.config.torch_dtype = str(dtype).split(".")[1]
+
+ # Attach architecture to the config
+ model.config.architectures = [model.__class__.__name__]
+
+ # Save the config
+ if is_master:
+ model.config.save_pretrained(checkpoint_path)
+ if model.can_generate():
+ model.generation_config.save_pretrained(checkpoint_path)
+
+
+def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
+ """
+ Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
+ only one tensor.
+
+ Args:
+        name (str): name of the tensor.
+        tensor (Tensor): tensor to be saved.
+        index_file (CheckpointIndexFile): the index file of the checkpoint folder to be updated.
+        use_safetensors (bool): whether to use safetensors to save the checkpoint.
+ """
+ root_path = index_file.root_path
+ output_root_path = root_path.joinpath('dtensor')
+
+ # create directory
+ output_root_path.mkdir(exist_ok=True)
+
+ # save tensor to this directory
+ # TODO(YuliangLiu): get index of the tensor shard
+ # e.g. index =
+ index = 0
+
+ # save tensor to file
+ ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
+ ckpt_file_path = output_root_path.joinpath(ckpt_file_name)
+
+ # dtensor ckpt file always contains only one tensor
+ state_dict = {name: tensor}
+ save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)
+
+ # update the weight map
+ # * means all shards
+ ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
+ index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
+
+
+def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
+ """
+ Get checkpoint file suffix.
+
+ Args:
+ use_safetensors (bool): whether to use safetensors to save the checkpoint.
+
+ Returns:
+ str: checkpoint file suffix.
+ """
+ if use_safetensors:
+ return '.safetensors'
+ else:
+ return '.bin'
+
+
+def generate_checkpoint_shard_file_name(index: int,
+ total_number: int,
+ use_safetensors: bool,
+ prefix: str = None) -> str:
+ """
+ Generate checkpoint shard file name.
+
+ Args:
+ index (int): index of the shard.
+ total_number (int): total number of shards.
+ use_safetensors (bool): whether to use safetensors to save the checkpoint.
+ prefix (str): prefix of the shard file name. Default: None.
+
+ Returns:
+ str: checkpoint shard file name.
+ """
+ suffix = get_checkpoint_file_suffix(use_safetensors)
+
+ if prefix is None:
+        return f"{index:05d}-of-{total_number:05d}{suffix}"
+ else:
+        return f"{prefix}-{index:05d}-of-{total_number:05d}{suffix}"
+
+
+def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
+ """
+ Generate dtensor file name.
- if current_block_size + state_size > max_shard_size and current_block_size > 0:
- ret_block = current_block
- ret_block_size = current_block_size
- current_block = {}
- current_block_size = 0
+ Args:
+ param_name (str): name of the distributed parameter.
+ index (int): index of the shard.
+ use_safetensors (bool): whether to use safetensors to save the checkpoint.
- current_block[param_id] = state
- current_block_size += state_size
+ Returns:
+ str: dtensor file name.
+ """
+ suffix = get_checkpoint_file_suffix(use_safetensors)
+    return f'{param_name}.{index}{suffix}'
- if ret_block != None:
- yield ret_block, ret_block_size
- yield current_block, current_block_size
+# ========================================
+# Helper functions for loading state dict
+# ========================================
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
@@ -237,7 +517,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
return safe_load_file(checkpoint_file)
else:
- return torch.load(checkpoint_file)
+ return torch.load(checkpoint_file, map_location=torch.device('cpu'))
def load_state_dict_into_model(model: nn.Module,
@@ -297,7 +577,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Load list of param_groups from given file path.
# The params in saved_groups are in the form of integer indices.
- saved_groups = torch.load(param_group_path)
+ saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
if not isinstance(saved_groups, List):
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
@@ -331,17 +611,21 @@ def update_group(group, new_group):
return id_map
-def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
+def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False):
r"""Copies states from `state_dict` into an Optimizer object.
Args:
optimizer(Optimizer): An initialized Optimizer object to be loaded
- state_dict(dict): a mapping from tensor index (an integer)
+ state_dict(dict): A mapping from tensor index (an integer)
to its states to be loaded (a mapping from state name to a tensor).
- id_map(dict): a mapping from tensor index (an integer)
+ id_map(dict): A mapping from tensor index (an integer)
to its corresponding parameter (a tensor) whose states will be updated.
+ strict(bool, optional): If set to True, only load the parameters with its id in id_map. Defaults to False.
"""
+ # Ensure that the keys of state_dict are integers.
+ state_dict = {int(k): v for k, v in state_dict.items()}
+
def cast(param, value, key=None):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
@@ -368,7 +652,7 @@ def cast(param, value, key=None):
if k in id_map:
param = id_map[k]
new_states[param] = cast(param, v)
- else:
+ elif not strict:
new_states[k] = v
optimizer.state.update(new_states)
@@ -386,165 +670,6 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
optimizer.defaults.setdefault('differentiable', False)
-# ======================================
-# Helper functions for saving state dict
-# ======================================
-
-
-def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
- """
- Save state dict to checkpoint.
-
- Args:
- state_dict (dict): state dict.
- checkpoint_file_path (str): path to the checkpoint file.
- use_safetensors (bool): whether to use safetensors to save the checkpoint.
- """
- if use_safetensors:
- assert is_safetensors_available(), "safetensors is not available."
- assert checkpoint_file_path.endswith('.safetensors'), \
- "safetensors only supports .safetensors suffix for checkpoint file."
- from safetensors.torch import save_file as safe_save_file
- safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
- else:
- torch.save(state_dict, checkpoint_file_path)
-
-
-def save_param_groups(state_dict: dict, group_file_path: str) -> None:
- """
- Save information of param_groups to given file path.
-
- Args:
- state_dict (dict): state dict.
- group_file_path (str): path to the group file.
- """
- param_groups = state_dict["param_groups"]
- torch.save(param_groups, group_file_path)
-
-
-def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
- """
- Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
- only one tensor.
-
- Args:
- tensor (Tensor): tensor to be saved.
- index_file (CheckpointIndexFile): path to the checkpoint file.
- size_per_shard (int): size per shard in MB.
- """
- root_path = index_file.root_path
- output_root_path = root_path.joinpath('dtensor')
-
- # create directory
- output_root_path.mkdir(exist_ok=True)
-
- # save tensor to this directory
- # TODO(YuliangLiu): get index of the tensor shard
- # e.g. index =
- index = 0
-
- # save tensor to file
- ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
- ckpt_file_path = output_root_path.joinpath(ckpt_file_name)
-
- # dtensor ckpt file always contains only one tensor
- state_dict = {name: tensor}
- save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)
-
- # update the weight map
- # * means all shards
- ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
- index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
-
-
-def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
- """
- Get checkpoint file suffix.
-
- Args:
- use_safetensors (bool): whether to use safetensors to save the checkpoint.
-
- Returns:
- str: checkpoint file suffix.
- """
- if use_safetensors:
- return '.safetensors'
- else:
- return '.bin'
-
-
-def generate_checkpoint_shard_file_name(index: int,
- total_number: int,
- use_safetensors: bool,
- prefix: str = None) -> str:
- """
- Generate checkpoint shard file name.
-
- Args:
- index (int): index of the shard.
- total_number (int): total number of shards.
- use_safetensors (bool): whether to use safetensors to save the checkpoint.
- prefix (str): prefix of the shard file name. Default: None.
-
- Returns:
- str: checkpoint shard file name.
- """
- suffix = get_checkpoint_file_suffix(use_safetensors)
-
- if prefix is None:
- return f"{index:05d}-of-{total_number:05d}.{suffix}"
- else:
- return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}"
-
-
-def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
- """
- Generate dtensor file name.
-
- Args:
- param_name (str): name of the distributed parameter.
- index (int): index of the shard.
- use_safetensors (bool): whether to use safetensors to save the checkpoint.
-
- Returns:
- str: dtensor file name.
- """
- suffix = get_checkpoint_file_suffix(use_safetensors)
- return f'{param_name}.{index}.{suffix}'
-
-
-def save_state_dict_as_shard(
- state_dict: dict,
- checkpoint_path: str,
- index: int,
- total_number: int,
- use_safetensors: bool,
- prefix: str = None,
-) -> None:
- """
- Save state dict as shard.
-
- Args:
- state_dict (dict): state dict.
- checkpoint_path (str): path to the checkpoint file.
- index (int): index of the shard.
- total_number (int): total number of shards.
- prefix (str): prefix of the shard file name.
- use_safetensors (bool): whether to use safetensors to save the checkpoint.
- """
- # generate the shard name
- shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix)
- shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute()
-
- # save the shard
- save_state_dict(state_dict, str(shard_file_path), use_safetensors)
-
-
-# ========================================
-# Helper functions for loading state dict
-# ========================================
-
-
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
"""
Check whether the checkpoint has an index file.
@@ -608,7 +733,7 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
- return torch.load(checkpoint_file_path)
+ return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
@@ -654,5 +779,5 @@ def get_shard_filename(weights_name: str, idx: int):
get shard file name
"""
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
- shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
+ shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
return shard_file
diff --git a/colossalai/cli/benchmark/models.py b/colossalai/cli/benchmark/models.py
index f8fd1c41a059..385b485b6016 100644
--- a/colossalai/cli/benchmark/models.py
+++ b/colossalai/cli/benchmark/models.py
@@ -1,6 +1,6 @@
import torch
-import colossalai.nn as col_nn
+import colossalai.legacy.nn as col_nn
class MLP(torch.nn.Module):
diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py
index 5e74c2c4f5b8..d2d02811ac9d 100644
--- a/colossalai/cli/launcher/run.py
+++ b/colossalai/cli/launcher/run.py
@@ -265,6 +265,10 @@ def launch_multi_processes(args: Config) -> None:
# establish remote connection
runner.connect(host_info_list=active_device_pool, workdir=curr_path, env=env)
+ # overwrite master addr when num_nodes > 1 and not specified
+ if len(active_device_pool) > 1 and args.master_addr == "127.0.0.1":
+ args.master_addr = active_device_pool.hostinfo_list[0].hostname
+
# execute distributed launching command
for node_id, hostinfo in enumerate(active_device_pool):
cmd = get_launch_command(master_addr=args.master_addr,
diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py
index 2fbdfd3cc999..44f571ca2501 100644
--- a/colossalai/cluster/__init__.py
+++ b/colossalai/cluster/__init__.py
@@ -1,5 +1,6 @@
from .device_mesh_manager import DeviceMeshManager
from .dist_coordinator import DistCoordinator
from .process_group_manager import ProcessGroupManager
+from .process_group_mesh import ProcessGroupMesh
-__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager']
+__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager', 'ProcessGroupMesh']
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
new file mode 100644
index 000000000000..623160003767
--- /dev/null
+++ b/colossalai/cluster/process_group_mesh.py
@@ -0,0 +1,209 @@
+import itertools
+from functools import reduce
+from operator import mul
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+
+def prod(nums: List[int]) -> int:
+ """Product of a list of numbers.
+
+ Args:
+ nums (List[int]): A list of numbers.
+
+ Returns:
+ int: The product of the numbers.
+ """
+ return reduce(mul, nums)
+
+
+class ProcessGroupMesh:
+    """A helper class to manage the process group mesh. It only describes how to organize process groups and is decoupled from the parallel method.
+    It just initializes process groups and caches them. The parallel method should manage them and use them to do the parallel computation.
+
+    We use an N-D tuple to represent the process group mesh, and an N-D coordinate to represent each process.
+ For example, ``(0, 1, 0)`` represents the process whose rank is 2 in a 3D process group mesh with size ``(2, 2, 2)``.
+
+ Args:
+ *size (int): The size of each dimension of the process group mesh. The product of the size must be equal to the world size.
+
+ Attributes:
+ shape (Tuple[int, ...]): The shape of the process group mesh.
+ rank (int): The rank of the current process.
+ """
+
+ def __init__(self, *size: int) -> None:
+ assert dist.is_initialized(), "Please initialize torch.distributed first."
+ assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size."
+ self._shape = size
+ self._rank = dist.get_rank()
+ self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
+ self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
+ self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
+
+ @property
+ def shape(self) -> Tuple[int, ...]:
+ return self._shape
+
+ @property
+ def rank(self) -> int:
+ return self._rank
+
+ def size(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
+ """Get the size of the process group mesh.
+
+ Args:
+ dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None.
+
+ Returns:
+ Union[int, Tuple[int, ...]]: Size of the target dimension or the whole process group mesh.
+ """
+ if dim is None:
+ return self._shape
+ else:
+ return self._shape[dim]
+
+ def coordinate(self, dim: Optional[int] = None) -> Union[int, Tuple[int, ...]]:
+ """Get the coordinate of the process group mesh.
+
+ Args:
+ dim (Optional[int], optional): Dimension of the process group mesh. `None` means all dimensions. Defaults to None.
+
+ Returns:
+ Union[int, Tuple[int, ...]]: Coordinate of the target dimension or the whole process group mesh.
+ """
+ if dim is None:
+ return self._coord
+ else:
+ return self._coord[dim]
+
+ @staticmethod
+ def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
+ """Convert a rank to a coordinate.
+
+ Args:
+ rank (int): Rank to be converted.
+ shape (Tuple[int, ...]): Shape of the process group mesh.
+
+ Returns:
+ Tuple[int, ...]: Coordinate of the rank.
+ """
+ return np.unravel_index(rank, shape)
+
+ @staticmethod
+ def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int:
+ """Convert a coordinate to a rank.
+ mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
+        With 'wrap', an out-of-range index is wrapped around.
+        For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns i % 2.
+
+ Args:
+            coord (Tuple[int, ...]): Coordinate to be converted.
+ shape (Tuple[int, ...]): Shape of the process group mesh.
+ mode (Optional[str]): The mode for numpy.ravel_multi_index.
+
+ Returns:
+ int: Rank of the coordinate.
+ """
+
+ assert mode in ["raise", "wrap", "clip"]
+ return np.ravel_multi_index(coord, shape, mode)
+
+ def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
+        """Get the process group with the given ranks. If the process group doesn't exist, it will be created.
+
+ Args:
+ ranks_in_group (List[int]): Ranks in the process group.
+ backend (Optional[str], optional): Backend of the process group. Defaults to None.
+
+ Returns:
+ ProcessGroup: The process group with the given ranks.
+ """
+ ranks_in_group = sorted(ranks_in_group)
+ if tuple(ranks_in_group) not in self._group_to_ranks:
+ group = dist.new_group(ranks_in_group, backend=backend)
+ self._ranks_to_group[tuple(ranks_in_group)] = group
+ self._group_to_ranks[group] = tuple(ranks_in_group)
+ return self._ranks_to_group[tuple(ranks_in_group)]
+
+ def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
+ """Get the ranks in the given process group. The process group must be created by this class.
+
+ Args:
+ group (ProcessGroup): The process group.
+
+ Returns:
+ List[int]: Ranks in the process group.
+ """
+ return list(self._group_to_ranks[group])
+
+ @staticmethod
+ def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int,
+ indices_at_axis: List[int]) -> List[Tuple[int, ...]]:
+ """Get coordinates along the given axis.
+
+ Args:
+ base_coord (Tuple[int, ...]): Base coordinate which the coordinates along the axis are based on.
+ axis (int): Axis along which the coordinates are generated.
+ indices_at_axis (List[int]): Indices at the axis.
+
+ Returns:
+ List[Tuple[int, ...]]: Coordinates along the axis.
+ """
+ coords_in_group = []
+ for idx in indices_at_axis:
+ coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1:])
+ return coords_in_group
+
+ def create_group_along_axis(self,
+ axis: int,
+ indices_at_axis: Optional[List[int]] = None,
+ backend: Optional[str] = None) -> ProcessGroup:
+ """Create all process groups along the given axis, and return the one which the current process belongs to.
+
+ Args:
+ axis (int): Axis along which the process groups are created.
+ indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
+ backend (Optional[str], optional): Backend of the process group. Defaults to None.
+
+ Returns:
+ ProcessGroup: The process group along the given axis which the current process belongs to.
+ """
+ indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
+ reduced_shape = list(self._shape)
+ # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
+ reduced_shape[axis] = 1
+ target_group = None
+ # use Cartesian product to generate all combinations of coordinates
+ for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
+ coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
+ ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+ group = self.get_group(ranks_in_group, backend=backend)
+ if self._rank in ranks_in_group:
+ target_group = group
+ return target_group
+
+ def get_group_along_axis(self,
+ axis: int,
+ indices_at_axis: Optional[List[int]] = None,
+ backend: Optional[str] = None) -> ProcessGroup:
+ """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
+
+ Args:
+ axis (int): Axis along which the process groups are created.
+ indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
+ backend (Optional[str], optional): Backend of the process group. Defaults to None.
+
+ Returns:
+ ProcessGroup: The process group along the given axis which the current process belongs to.
+ """
+ indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
+ coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
+ ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+ if ranks_in_group not in self._ranks_to_group:
+ # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
+ return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
+ return self._ranks_to_group[ranks_in_group]
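+
+
+# Illustrative example: with 8 processes arranged as ProcessGroupMesh(2, 2, 2), the process of
+# global rank 5 has coordinate (1, 0, 1); calling mesh.get_group_along_axis(1) on that process
+# returns the process group containing ranks 5 and 7, i.e. the two processes that differ only along axis 1.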
diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py
index 003f0cdd91b6..7186f052ecec 100644
--- a/colossalai/context/parallel_context.py
+++ b/colossalai/context/parallel_context.py
@@ -15,8 +15,8 @@
from colossalai.context.config import Config
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from colossalai.logging import get_dist_logger
-from colossalai.registry import DIST_GROUP_INITIALIZER
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/context/process_group_initializer/initializer_1d.py
index 4c05028041ce..ba601d0bf61a 100644
--- a/colossalai/context/process_group_initializer/initializer_1d.py
+++ b/colossalai/context/process_group_initializer/initializer_1d.py
@@ -2,8 +2,9 @@
# -*- encoding: utf-8 -*-
import torch.distributed as dist
+
from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_2d.py b/colossalai/context/process_group_initializer/initializer_2d.py
index 7fbe3be5901f..999cd5f0cfc6 100644
--- a/colossalai/context/process_group_initializer/initializer_2d.py
+++ b/colossalai/context/process_group_initializer/initializer_2d.py
@@ -3,7 +3,7 @@
import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/context/process_group_initializer/initializer_2p5d.py
index 6b6fdc5d715c..b92ae2eec07e 100644
--- a/colossalai/context/process_group_initializer/initializer_2p5d.py
+++ b/colossalai/context/process_group_initializer/initializer_2p5d.py
@@ -4,9 +4,10 @@
import math
import torch.distributed as dist
+
from colossalai.context import Config
from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/context/process_group_initializer/initializer_3d.py
index 1ed8eec86efc..6bca05ad7d5f 100644
--- a/colossalai/context/process_group_initializer/initializer_3d.py
+++ b/colossalai/context/process_group_initializer/initializer_3d.py
@@ -6,7 +6,7 @@
import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_data.py b/colossalai/context/process_group_initializer/initializer_data.py
index 9715ebff7f00..b9dec4541dad 100644
--- a/colossalai/context/process_group_initializer/initializer_data.py
+++ b/colossalai/context/process_group_initializer/initializer_data.py
@@ -3,7 +3,7 @@
from torch import distributed as dist
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_model.py b/colossalai/context/process_group_initializer/initializer_model.py
index 99b9cc0d4edc..614ba372fbcc 100644
--- a/colossalai/context/process_group_initializer/initializer_model.py
+++ b/colossalai/context/process_group_initializer/initializer_model.py
@@ -2,9 +2,11 @@
# -*- encoding: utf-8 -*-
import torch.distributed as dist
-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
+
from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
diff --git a/colossalai/context/process_group_initializer/initializer_pipeline.py b/colossalai/context/process_group_initializer/initializer_pipeline.py
index 0ddb52f63e22..e093333ad18a 100644
--- a/colossalai/context/process_group_initializer/initializer_pipeline.py
+++ b/colossalai/context/process_group_initializer/initializer_pipeline.py
@@ -3,7 +3,7 @@
from torch import distributed as dist
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
diff --git a/colossalai/context/process_group_initializer/initializer_sequence.py b/colossalai/context/process_group_initializer/initializer_sequence.py
index 251a2940778a..a6e26b6bcaa9 100644
--- a/colossalai/context/process_group_initializer/initializer_sequence.py
+++ b/colossalai/context/process_group_initializer/initializer_sequence.py
@@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*-
import torch.distributed as dist
-from colossalai.registry import DIST_GROUP_INITIALIZER
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
from .initializer_tensor import Initializer_Tensor
diff --git a/colossalai/context/process_group_initializer/initializer_tensor.py b/colossalai/context/process_group_initializer/initializer_tensor.py
index d2b5be9cfffb..3be89e52a812 100644
--- a/colossalai/context/process_group_initializer/initializer_tensor.py
+++ b/colossalai/context/process_group_initializer/initializer_tensor.py
@@ -3,9 +3,10 @@
import torch.distributed as dist
-from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
+
from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
new file mode 100644
index 000000000000..9a965dc982a4
--- /dev/null
+++ b/colossalai/inference/README.md
@@ -0,0 +1,117 @@
+# 🚀 Colossal-Inference
+
+## Table of contents
+
+- [Introduction](#introduction)
+- [Design](#design)
+- [Pipeline of inference](#pipeline-of-inference)
+- [Roadmap of our implementation](#roadmap-of-our-implementation)
+- [Get started](#get-started)
+- [Performance](#performance)
+
+## Introduction
+
+`Colossal Inference` is the inference framework designed by Colossal-AI, featuring high performance, stability and ease of use. It incorporates the strengths of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and flash attention, while combining them with the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users.
+
+## Design
+
+Colossal Inference is composed of three main components:
+
+1. High-performance kernels and ops, inspired by existing libraries and modified accordingly.
+2. An efficient memory management mechanism, which includes a key-value cache manager and allows zero memory waste during inference.
+   1. `cache manager`: serves as a memory manager for the key-value cache; it integrates functions such as memory allocation, indexing and release.
+   2. `batch_infer_state`: holds all essential information of a batch inference and is updated every batch.
+3. A high-level inference engine combined with `Shardformer`, which allows our inference framework to easily invoke and utilize various parallel methods (a minimal usage sketch follows this list).
+   1. `engine.TPInferEngine`: a high-level interface that integrates with Shardformer, especially for multi-GPU (tensor parallel) inference.
+   2. `modeling.llama.LlamaInferenceForwards`: contains the `forward` methods for llama inference.
+   3. `policies.llama.LlamaModelInferPolicy`: contains the policies for `llama` models, used by `shardformer` to partition the model forward pass for tensor parallelism.
+
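+Below is a minimal usage sketch of how these components fit together, adapted from the `TPInferEngine` docstring. The checkpoint path, prompt, batch/length limits and generation arguments are placeholders to replace with your own, and a distributed environment (e.g. launched via `torchrun`) is assumed when tensor parallelism is enabled:
+
+```python
+import torch
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
+from colossalai.inference.tensor_parallel import TPInferEngine
+from colossalai.shardformer import ShardConfig
+
+# placeholders: point these at your own checkpoint and prompts
+model = LlamaForCausalLM.from_pretrained("/path/to/llama-7b", torch_dtype=torch.float16)
+tokenizer = LlamaTokenizer.from_pretrained("/path/to/llama-7b")
+
+shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
+engine = TPInferEngine(model, shard_config, max_batch_size=4, max_input_len=1024, max_output_len=128)
+
+inputs = tokenizer(["Introduce landmarks in Beijing"], return_tensors="pt")
+outputs = engine.generate(inputs, do_sample=False)    # returns input ids followed by generated ids
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
+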
+## Pipeline of inference
+
+In this section we discuss how Colossal Inference works and integrates with `Shardformer`. The details can be found in our code.
+
+![Colossal-Inference](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Colossal-inference.png)
+
+## Roadmap of our implementation
+
+- [x] Design cache manager and batch infer state
+- [x] Design the TPInferEngine to integrate with `Shardformer`
+- [x] Register corresponding high-performance `kernels` and `ops`
+- [x] Design policies and forwards (e.g. `Llama` and `Bloom`)
+ - [x] policy
+ - [x] context forward
+ - [x] token forward
+- [ ] Replace the kernels with `faster-transformer` in token-forward stage
+- [ ] Support all models
+ - [x] Llama
+ - [x] Bloom
+ - [ ] Chatglm2
+- [ ] Benchmarking for all models
+
+## Get started
+
+### Installation
+
+```bash
+pip install -e .
+```
+
+### Requirements
+
+Dependencies:
+
+```bash
+pytorch = 1.13.1 (gpu)
+cuda >= 11.6
+transformers = 4.30.2
+triton == 2.0.0.dev20221202
+# to install vllm, please use this branch: https://github.com/tiandiao123/vllm/tree/setup_branch
+vllm
+# to install flash-attention, please use commit hash 67ae6fd74b4bc99c36b2ce524cf139c35663793c
+flash-attention
+```
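+
+A possible installation sequence for the pinned dependencies above (a sketch only, not a verified recipe; building vllm and flash-attention from source requires a matching CUDA toolkit):
+
+```bash
+pip install "transformers==4.30.2" "triton==2.0.0.dev20221202"
+# vllm: build from the branch mentioned above
+git clone -b setup_branch https://github.com/tiandiao123/vllm.git && pip install -e ./vllm
+# flash-attention: check out the commit hash listed above from the flash-attention repository and install it the same way
+```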
+
+### Docker
+
+You can use `docker run` to set up the environment in a Docker container:
+
+```bash
+# env: python==3.8, cuda 11.6, pytorch == 1.13.1, triton==2.0.0.dev20221202, vllm kernels support, flash-attention-2 kernels support
+docker pull hpcaitech/colossalai-inference:v2
+docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash
+```
+
+### Dive into fast-inference!
+
+Example files are located in:
+
+```bash
+cd colossalai.examples
+python xx
+```
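+
+Since `TPInferEngine` relies on Shardformer's tensor parallelism, the example scripts are typically launched with a distributed launcher; for example (the script name below is a placeholder for one of the example files):
+
+```bash
+torchrun --standalone --nproc_per_node=2 your_inference_script.py
+```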
+
+## Performance
+
+### Environment
+
+We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughput` between `colossal-inference` and the original `hugging-face torch fp16`.
+
+For various models, experiments were conducted using multiple batch sizes under a consistent model configuration of `7 billion (7B)` parameters, `1024` input length, and `128` output length. The results are as follows (due to time constraints, the evaluation has so far been performed only on a single `A100` GPU; multi-GPU performance will be addressed in the future):
+
+### Single GPU Performance
+
+The stats below are measured on a single A100 GPU. Token latency is calculated from the average of the context-forward and decoding-forward processes, i.e. both stages are combined when computing token generation times. We are actively developing new features and methods to further optimize the performance of LLM models. Please stay tuned.
+
+#### Llama
+
+| batch_size | 8 | 16 | 32 |
+| :---------------------: | :----: | :----: | :----: |
+| hugging-face torch fp16 | 199.12 | 246.56 | 278.4 |
+| colossal-inference | 326.4 | 582.72 | 816.64 |
+
+![llama](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-llama7b.png)
+
+#### Bloom
+
+| batch_size | 8 | 16 | 32 |
+| :---------------------: | :----: | :----: | :----: |
+| hugging-face torch fp16 | 189.68 | 226.66 | 249.61 |
+| colossal-inference | 323.28 | 538.52 | 611.64 |
+
+![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-bloom7b.png)
+
+The results of more models are coming soon!
diff --git a/tests/test_layers/test_1d/checks_1d/__init__.py b/colossalai/inference/__init__.py
similarity index 100%
rename from tests/test_layers/test_1d/checks_1d/__init__.py
rename to colossalai/inference/__init__.py
diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py
new file mode 100644
index 000000000000..e467b4c73e6b
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/__init__.py
@@ -0,0 +1,4 @@
+from .engine import TPInferEngine
+from .kvcache_manager import MemoryManager
+
+__all__ = ['MemoryManager', 'TPInferEngine']
diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py
new file mode 100644
index 000000000000..2bff9317283e
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/batch_infer_state.py
@@ -0,0 +1,55 @@
+# might want to consider combining this with InferenceConfig in colossalai/ppinference/inference_config.py later
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+
+from .kvcache_manager import MemoryManager
+
+
+@dataclass
+class BatchInferState:
+ r"""
+ Information to be passed and used for a batch of inputs during
+ a single model forward
+ """
+ batch_size: int
+ max_len_in_batch: int
+
+ cache_manager: MemoryManager = None
+
+ block_loc: torch.Tensor = None
+ start_loc: torch.Tensor = None
+ seq_len: torch.Tensor = None
+ past_key_values_len: int = None
+
+ is_context_stage: bool = False
+ context_mem_index: torch.Tensor = None
+ decode_is_contiguous: bool = None
+ decode_mem_start: int = None
+ decode_mem_end: int = None
+ decode_mem_index: torch.Tensor = None
+ decode_layer_id: int = None
+
+ device: torch.device = torch.device('cuda')
+
+ @property
+ def total_token_num(self):
+ # return self.batch_size * self.max_len_in_batch
+ assert self.seq_len is not None and self.seq_len.size(0) > 0
+ return int(torch.sum(self.seq_len))
+
+ def set_cache_manager(self, manager: MemoryManager):
+ self.cache_manager = manager
+
+ @staticmethod
+ def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int,
+ alloc_mem_index: torch.Tensor):
+        """ in-place update block loc mapping based on the sequence length of the inputs in the current batch"""
+ start_index = 0
+ seq_len_numpy = seq_len.cpu().numpy()
+ for i, cur_seq_len in enumerate(seq_len_numpy):
+ b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index +
+ cur_seq_len]
+ start_index += cur_seq_len
+ return
diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
new file mode 100644
index 000000000000..a5a55702ade0
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -0,0 +1,294 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+from transformers import BloomForCausalLM, LlamaForCausalLM
+from transformers.generation import GenerationConfig
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.tokenization_utils_base import BatchEncoding
+
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
+
+from .batch_infer_state import BatchInferState
+from .kvcache_manager import MemoryManager
+
+DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
+
+_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM']
+
+
+class TPInferEngine:
+ """Engine class for tensor parallel inference.
+
+ Args:
+ model (Module): original model, e.g. huggingface CausalLM
+ shard_config (ShardConfig): The config for sharding original model
+ max_batch_size (int): maximum batch size
+ max_input_len (int): maximum input length of sequence
+ max_output_len (int): maximum output length of output tokens
+ dtype (torch.dtype): datatype used to init KV cache space
+ device (str): device the KV cache of engine to be initialized on
+
+ Examples:
+ >>> # define model and shard config for your inference
+ >>> model = ...
+ >>> generate_kwargs = ...
+ >>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
+ >>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+ >>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
+ """
+
+ def __init__(self,
+ model: nn.Module,
+ shard_config: ShardConfig,
+ max_batch_size: int,
+ max_input_len: int,
+ max_output_len: int,
+ dtype: torch.dtype = torch.float16,
+ device: str = 'cuda') -> None:
+ self.max_batch_size = max_batch_size
+ self.max_input_len = max_input_len
+ self.max_output_len = max_output_len
+ self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len)
+
+        # Constraints related to device specs and the model
+ # This may change into an optional arg in the future
+ assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
+ assert self.max_input_len + self.max_output_len <= 4096, "Max length exceeds the constraint"
+
+ self.dtype = dtype
+
+ self.head_dim = model.config.hidden_size // model.config.num_attention_heads
+ self.head_num = model.config.num_attention_heads
+ self.layer_num = model.config.num_hidden_layers
+
+ self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
+ self.cache_manager = None
+
+ self.shard_config = shard_config
+ self.model = None
+ # optimize the original model by sharding with ShardFormer
+ self._optimize_model(model=model.to(device))
+
+ def _init_manager(self) -> None:
+ assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
+ assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
+ self.head_num //= self.tp_size # update sharded number of heads
+ self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
+ self.layer_num)
+
+ def _optimize_model(self, model: nn.Module) -> None:
+ """
+ Optimize the original model by sharding with ShardFormer.
+ In further generation, use the sharded model instead of original model.
+ """
+ # NOTE we will change to use an inference config later with additional attrs we want
+ assert self.shard_config.inference_only is True
+ shardformer = ShardFormer(shard_config=self.shard_config)
+ self._prepare_with_shard_config(shard_config=self.shard_config)
+ self._shard_model_by(shardformer, model)
+
+ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
+ """ Prepare the engine with a given ShardConfig.
+
+ Args:
+ shard_config (ShardConfig): shard config given to specify settings of the engine.
+ If not provided, a default ShardConfig with tp size 1 will be created.
+ """
+ self.tp_size = 1
+ if shard_config is None:
+ shard_config = ShardConfig(
+ tensor_parallel_process_group=None,
+ pipeline_stage_manager=None,
+ enable_tensor_parallelism=False,
+ enable_fused_normalization=False,
+ enable_all_optimization=False,
+ enable_flash_attention=False,
+ enable_jit_fused=False,
+ inference_only=True,
+ )
+ else:
+ shard_config.inference_only = True
+ shard_config.pipeline_stage_manager = None
+ if shard_config.enable_tensor_parallelism:
+ self.tp_size = shard_config.tensor_parallel_size
+ self._init_manager()
+
+ return shard_config
+
+ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
+ """ Shard original model by the given ShardFormer and store the sharded model. """
+ assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \
+ "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
+ model_name = model.__class__.__name__
+ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
+ policy = get_autopolicy(model, inference_only=True)
+ self.model, _ = shardformer.optimize(model, policy)
+ self.model = self.model.cuda()
+
+ @property
+ def supported_models(self) -> List[str]:
+ return _supported_models
+
+ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor:
+ """Generate token sequence.
+
+ Args:
+ input_tokens: could be one of the following types
+ 1. BatchEncoding or dict (e.g. tokenizer batch_encode)
+ 2. list of input token ids (e.g. appended result of tokenizer encode)
+ 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
+ Returns:
+ torch.Tensor: The returned sequence is given inputs + generated_tokens.
+ """
+ if isinstance(input_tokens, torch.Tensor):
+ input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool))
+ for t in input_tokens:
+ if torch.is_tensor(input_tokens[t]):
+ input_tokens[t] = input_tokens[t].cuda()
+ if 'max_new_tokens' not in generate_kwargs:
+ generate_kwargs.update(max_new_tokens=self.max_output_len)
+
+ return self._generate_by_set_infer_state(input_tokens, **generate_kwargs)
+
+ def prepare_batch_state(self, inputs) -> BatchInferState:
+ """
+        Create and prepare BatchInferState used for inference during model forward,
+ by processing each sequence of the given inputs.
+
+ Args:
+ inputs: should be one of the following types
+ 1. BatchEncoding or dict (e.g. tokenizer batch_encode)
+ 2. list of input token ids (e.g. appended result of tokenizer encode)
+ 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
+            NOTE For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve
+                the actual length (e.g. number of tokens) of each input without an attention mask.
+                Hence, for a torch.Tensor with shape [bs, l] where bs > 1, we assume
+                all the inputs in the batch have the maximum length l.
+ Returns:
+ BatchInferState: the states for the current batch during inference
+ """
+ if not isinstance(inputs, (BatchEncoding, dict, list, torch.Tensor)):
+ raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state")
+
+ input_ids_list = None
+ attention_mask = None
+
+ if isinstance(inputs, (BatchEncoding, dict)):
+ input_ids_list = inputs['input_ids']
+ attention_mask = inputs['attention_mask']
+ else:
+ input_ids_list = inputs
+ if isinstance(input_ids_list[0], int): # for a single input
+ input_ids_list = [input_ids_list]
+ attention_mask = [attention_mask] if attention_mask is not None else attention_mask
+
+ batch_size = len(input_ids_list)
+
+ seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
+ seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
+ start_index = 0
+
+ max_len_in_batch = -1
+ if isinstance(inputs, (BatchEncoding, dict)):
+ for i, attn_mask in enumerate(attention_mask):
+ curr_seq_len = len(attn_mask)
+ # if isinstance(attn_mask, torch.Tensor):
+ # curr_seq_len = int(torch.sum(attn_mask))
+ # else:
+ # curr_seq_len = int(sum(attn_mask))
+ seq_lengths[i] = curr_seq_len
+ seq_start_indexes[i] = start_index
+ start_index += curr_seq_len
+ max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
+ else:
+ length = max(len(input_id) for input_id in input_ids_list)
+ for i, input_ids in enumerate(input_ids_list):
+ curr_seq_len = length
+ seq_lengths[i] = curr_seq_len
+ seq_start_indexes[i] = start_index
+ start_index += curr_seq_len
+ max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
+ block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda')
+ batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
+ batch_infer_state.seq_len = seq_lengths.to('cuda')
+ batch_infer_state.start_loc = seq_start_indexes.to('cuda')
+ batch_infer_state.block_loc = block_loc
+ batch_infer_state.decode_layer_id = 0
+ batch_infer_state.past_key_values_len = 0
+ batch_infer_state.is_context_stage = True
+ batch_infer_state.set_cache_manager(self.cache_manager)
+ return batch_infer_state
+
+ @torch.no_grad()
+ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
+ """
+ Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate
+
+ Args:
+ inputs: should be one of the following types
+ 1. BatchEncoding or dict (e.g. tokenizer batch_encode)
+ 2. list of input token ids (e.g. appended result of tokenizer encode)
+ 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
+ """
+
+ # for testing, always use sharded model
+ assert self.model is not None, "sharded model does not exist"
+
+ batch_infer_state = self.prepare_batch_state(input_tokens)
+ assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit"
+
+ # set BatchInferState for the current batch as attr to model
+ # NOTE this is not a preferable way to pass BatchInferState during inference
+ # we might want to rewrite generate function (e.g. _generate_by_pass_infer_state)
+ # and pass BatchInferState via model forward
+ model = self.model
+ if isinstance(model, LlamaForCausalLM):
+ model = self.model.model
+ elif isinstance(model, BloomForCausalLM):
+ model = self.model.transformer
+ setattr(model, 'infer_state', batch_infer_state)
+
+ outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
+
+ # NOTE In future development, we're going to let the scheduler to handle the cache,
+ # instead of freeing space explicitly at the end of generation
+ self.cache_manager.free_all()
+
+ return outputs
+
+ # TODO might want to implement the func that generates output tokens by passing BatchInferState
+ # as an arg into model.forward.
+ # It requires rewriting model generate and replacing model forward.
+ @torch.no_grad()
+ def _generate_by_pass_infer_state(self,
+ input_tokens,
+ max_out_length: int,
+ generation_config: Optional[GenerationConfig] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ **model_kwargs) -> torch.Tensor:
+
+ raise NotImplementedError("generate by passing BatchInferState is not implemented.")
+
+ # might want to use in rewritten generate method: use after model.forward
+ # BatchInferState is created and kept during generation
+ # after each iter of model forward, we should update BatchInferState
+ def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
+ batch_size = infer_state.batch_size
+ device = infer_state.start_loc.device
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device)
+ infer_state.seq_len += 1
+
+ # might want to create a sequence pool
+ # add a single request/sequence/input text at a time and record its length
+ # In other words, store the actual length of input tokens representing a single input text
+ # E.g. "Introduce landmarks in Beijing"
+ # => add request
+ # => record token length and other necessary information to be used
+ # => engine hold all these necessary information until `generate` (or other name) is called,
+ # => put information already recorded in batchinferstate and pass it to model forward
+ # => clear records in engine
+    def add_request(self):
+ raise NotImplementedError()
diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py
new file mode 100644
index 000000000000..274c01841279
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/kvcache_manager.py
@@ -0,0 +1,101 @@
+# Adapted from lightllm/common/mem_manager.py
+# of the ModelTC/lightllm GitHub repository
+# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
+
+import torch
+from transformers.utils import logging
+
+
+class MemoryManager:
+ r"""
+ Manage token block indexes and allocate physical memory for key and value cache
+
+ Args:
+ size: maximum token number used as the size of key and value buffer
+ dtype: data type of cached key and value
+ head_num: number of heads the memory manager is responsible for
+ head_dim: embedded size per head
+ layer_num: the number of layers in the model
+ device: device used to store the key and value cache
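+
+    Examples:
+        >>> import torch
+        >>> # usage sketch with hypothetical sizes, assuming a llama-7b-like config on a single GPU
+        >>> manager = MemoryManager(4096, torch.float16, 32, 128, 32)
+        >>> idx = manager.alloc(16)    # indexes of 16 free cache slots, or None if not enough space is left
+        >>> manager.free(idx)          # return those slots to the pool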
+ """
+
+ def __init__(self,
+ size: int,
+ dtype: torch.dtype,
+ head_num: int,
+ head_dim: int,
+ layer_num: int,
+ device: torch.device = torch.device('cuda')):
+ self.logger = logging.get_logger(__name__)
+ self.available_size = size
+ self.past_key_values_length = 0
+ self._init_mem_states(size, device)
+ self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
+
+ def _init_mem_states(self, size, device):
+ """ Initialize tensors used to manage memory states """
+ self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
+ self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
+ self.indexes = torch.arange(0, size, dtype=torch.long, device=device)
+
+ def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
+ """ Initialize key buffer and value buffer on specified device """
+ self.key_buffer = [
+ torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
+ ]
+ self.value_buffer = [
+ torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
+ ]
+
+ @torch.no_grad()
+ def alloc(self, required_size):
+ """ allocate space of required_size by providing indexes representing available physical spaces """
+ if required_size > self.available_size:
+            self.logger.warning(f"Not enough cache: required_size {required_size} "
+ f"left_size {self.available_size}")
+ return None
+ torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
+ select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
+ select_index = self.indexes[select_index]
+ self.mem_state[select_index] = 0
+ self.available_size -= len(select_index)
+ return select_index
+
+ @torch.no_grad()
+ def alloc_contiguous(self, required_size):
+ """ allocate contiguous space of required_size """
+ if required_size > self.available_size:
+            self.logger.warning(f"Not enough cache: required_size {required_size} "
+ f"left_size {self.available_size}")
+ return None
+ torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
+ sum_size = len(self.mem_cum_sum)
+ loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size +
+ 1] + self.mem_state[0:sum_size -
+ required_size + 1]
+ can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size]
+ if can_used_loc.shape[0] == 0:
+            self.logger.info(f"Not enough contiguous cache: required_size {required_size} "
+ f"left_size {self.available_size}")
+ return None
+ start_loc = can_used_loc[0]
+ select_index = self.indexes[start_loc:start_loc + required_size]
+ self.mem_state[select_index] = 0
+ self.available_size -= len(select_index)
+ start = start_loc.item()
+ end = start + required_size
+ return select_index, start, end
+
+ @torch.no_grad()
+ def free(self, free_index):
+ """ free memory by updating memory states based on given indexes """
+ self.available_size += free_index.shape[0]
+ self.mem_state[free_index] = 1
+
+ @torch.no_grad()
+ def free_all(self):
+ """ free all memory by updating memory states """
+ self.available_size = len(self.mem_state)
+ self.mem_state[:] = 1
+ self.past_key_values_length = 0
+ self.logger.info("freed all space of memory manager")
diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py
new file mode 100644
index 000000000000..7a98b033f37e
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/__init__.py
@@ -0,0 +1,4 @@
+from .bloom import BloomInferenceForwards
+from .llama import LlamaInferenceForwards
+
+__all__ = ['BloomInferenceForwards', 'LlamaInferenceForwards']
diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py
new file mode 100644
index 000000000000..9768fc425628
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -0,0 +1,521 @@
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+from torch.nn import CrossEntropyLoss
+from torch.nn import functional as F
+from transformers.models.bloom.modeling_bloom import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BloomAttention,
+ BloomBlock,
+ BloomForCausalLM,
+ BloomModel,
+ CausalLMOutputWithCrossAttentions,
+)
+from transformers.utils import logging
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd
+from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
+from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+
+
+def generate_alibi(n_head, dtype=torch.float16):
+ """
+ This method is adapted from `_generate_alibi` function
+ in `lightllm/models/bloom/layer_weights/transformer_layer_weight.py`
+ of the ModelTC/lightllm GitHub repository.
+ This method is originally the `build_alibi_tensor` function
+ in `transformers/models/bloom/modeling_bloom.py`
+ of the huggingface/transformers GitHub repository.
+ """
+
+ def get_slopes_power_of_2(n):
+ start = 2**(-(2**-(math.log2(n) - 3)))
+ return [start * start**i for i in range(n)]
+
+ def get_slopes(n):
+ if math.log2(n).is_integer():
+ return get_slopes_power_of_2(n)
+ else:
+ closest_power_of_2 = 2**math.floor(math.log2(n))
+ slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2)
+ slopes_double = get_slopes(2 * closest_power_of_2)
+ slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2]
+ return slopes_combined
+
+ slopes = get_slopes(n_head)
+ return torch.tensor(slopes, dtype=dtype)
+
+
+class BloomInferenceForwards:
+ """
+    This class serves as a micro library for bloom inference forwards.
+ We intend to replace the forward methods for BloomForCausalLM, BloomModel, BloomBlock, and BloomAttention,
+ as well as prepare_inputs_for_generation method for BloomForCausalLM.
+ For future improvement, we might want to skip replacing methods for BloomForCausalLM,
+ and call BloomModel.forward iteratively in TpInferEngine
+ """
+
+ @staticmethod
+ def bloom_model_forward(
+ self: BloomModel,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: Optional[BatchInferState] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+
+ logger = logging.get_logger(__name__)
+
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ # still need to keep past_key_values to fit original forward flow
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.h))
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+        # NOTE determine if BatchInferState is passed in via arg;
+        # if not, get the attr bound to the model.
+        # We might want to remove setattr later.
+ if infer_state is None:
+ assert hasattr(self, 'infer_state')
+ infer_state = self.infer_state
+
+ # Compute alibi tensor: check build_alibi_tensor documentation
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+ # if self.cache_manager.past_key_values_length > 0:
+ if infer_state.cache_manager.past_key_values_length > 0:
+ # update the past key values length in cache manager,
+ # NOTE use BatchInferState.past_key_values_length instead the one in cache manager
+ past_key_values_length = infer_state.cache_manager.past_key_values_length
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ # infer_state.cache_manager = self.cache_manager
+
+ if use_cache and seq_length != 1:
+ # prefill stage
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
+ infer_state.context_mem_index)
+ else:
+ infer_state.is_context_stage = False
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ else:
+                print(" *** Encountered non-contiguous allocation")
+ print(
+ f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
+ )
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
+ # NOTE revise: we might want to store a single 1D alibi(length is #heads) in model,
+ # or store to BatchInferState to prevent re-calculating
+ # When we have multiple process group (e.g. dp together with tp), we need to pass the pg to here
+ # alibi = generate_alibi(self.num_heads).contiguous().cuda()
+ tp_size = dist.get_world_size()
+ curr_tp_rank = dist.get_rank()
+ alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) *
+ self.num_heads].cuda()
+ causal_mask = self._prepare_attn_mask(
+ attention_mask,
+ input_shape=(batch_size, seq_length),
+ past_key_values_length=past_key_values_length,
+ )
+
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ # NOTE: currently our KV cache manager does not handle this condition
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ alibi,
+ causal_mask,
+ layer_past,
+ head_mask[i],
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=causal_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ infer_state=infer_state,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+ # Add last hidden state
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # update indices of kv cache block
+ # NOT READY FOR PRIME TIME
+        # might want to remove this part; instead, it is better to pass the BatchInferState from model forward,
+        # and update this information in engine.generate after the model forward is called
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+ infer_state.decode_layer_id = 0
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents, # should always be (None, None, ..., None)
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ @staticmethod
+ def bloom_for_causal_lm_forward(self: BloomForCausalLM,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: Optional[BatchInferState] = None,
+ **deprecated_arguments):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ """
+ logger = logging.get_logger(__name__)
+
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ infer_state=infer_state)
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(lm_logits.device)
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ batch_size, seq_length, vocab_size = shift_logits.shape
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size),
+ shift_labels.view(batch_size * seq_length))
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def bloom_for_causal_lm_prepare_inputs_for_generation(
+ self: BloomForCausalLM,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> dict:
+ # only last token for input_ids if past is not None
+ if past_key_values:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+
+ # NOTE we won't use past key values here
+        # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
+ # if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+ # past_key_values = self._convert_to_bloom_cache(past_key_values)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update({
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ })
+ return model_inputs
+
+ @staticmethod
+ def bloom_block_forward(
+ self: BloomBlock,
+ hidden_states: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+ # hidden_states: [batch_size, seq_length, hidden_size]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+
+ # Layer norm post the self attention.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ # Self attention.
+ attn_outputs = self.self_attention(
+ layernorm_output,
+ residual,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ alibi=alibi,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ infer_state=infer_state,
+ )
+
+ attention_output = attn_outputs[0]
+
+ outputs = attn_outputs[1:]
+
+ layernorm_output = self.post_attention_layernorm(attention_output)
+
+ # Get residual
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = attention_output
+
+ # MLP.
+ output = self.mlp(layernorm_output, residual)
+
+ if use_cache:
+ outputs = (output,) + outputs
+ else:
+ outputs = (output,) + outputs[1:]
+
+ return outputs # hidden_states, present, attentions
+
+ @staticmethod
+ def bloom_attention_forward(
+ self: BloomAttention,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
+
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+ batch_size, q_length, H, D_HEAD = query_layer.shape
+        k = key_layer.reshape(-1, H, D_HEAD)    # batch_size * q_length, H, D_HEAD, q_length == 1
+        v = value_layer.reshape(-1, H, D_HEAD)    # batch_size * q_length, H, D_HEAD, q_length == 1
+
+ mem_manager = infer_state.cache_manager
+ layer_id = infer_state.decode_layer_id
+
+ if layer_id == 0: # once per model.forward
+ infer_state.cache_manager.past_key_values_length += q_length # += 1
+
+ if infer_state.is_context_stage:
+ # context process
+ max_input_len = q_length
+ b_start_loc = infer_state.start_loc
+ b_seq_len = infer_state.seq_len[:batch_size]
+ q = query_layer.reshape(-1, H, D_HEAD)
+
+ copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id])
+
+ # output = self.output[:batch_size*q_length, :, :]
+ output = torch.empty_like(q)
+
+ bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
+
+ context_layer = output.view(batch_size, q_length, H * D_HEAD)
+ else:
+ # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+ # need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD)
+ assert q_length == 1, "for non-context process, we only support q_length == 1"
+ q = query_layer.reshape(-1, H, D_HEAD)
+
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[layer_id][
+ infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ cache_v = infer_state.cache_manager.value_buffer[layer_id][
+ infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ cache_k.copy_(k)
+ cache_v.copy_(v)
+ else:
+ # if decode is not contiguous, use triton kernel to copy key and value cache
+ # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
+ copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id])
+
+ b_start_loc = infer_state.start_loc
+ b_loc = infer_state.block_loc
+ b_seq_len = infer_state.seq_len
+ output = torch.empty_like(q)
+ token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc,
+ b_start_loc, b_seq_len, infer_state.cache_manager.past_key_values_length, alibi)
+
+ context_layer = output.view(batch_size, q_length, H * D_HEAD)
+
+ # update layer id
+ infer_state.decode_layer_id += 1
+
+ # NOTE: always set present as none for now, instead of returning past key value to the next decoding,
+ # we create the past key value pair from the cache manager
+ present = None
+
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ slices = self.hidden_size / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + F.linear(
+ context_layer[:, :, int(i * slices):int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+ # dropout is not required here during inference
+ output_tensor = residual + output_tensor
+
+ outputs = (output_tensor, present)
+ assert output_attentions is False, "we do not support output_attentions at this time"
+
+ return outputs
diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
new file mode 100644
index 000000000000..219cd1ae0d0e
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -0,0 +1,359 @@
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
+from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
+from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
+from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+
+try:
+ from vllm import layernorm_ops, pos_encoding_ops
+ rms_norm = layernorm_ops.rms_norm
+ rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
+ HAS_VLLM_KERNERL = True
+except:
+ print("fall back to original rotary_embedding_neox of huggingface")
+ print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
+ print(
+        "if you failed to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
+ )
+ HAS_VLLM_KERNERL = False
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., :x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
+ copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
+ return
+
+
+class LlamaInferenceForwards:
+ """
+ This class holds forwards for llama inference.
+ We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM.
+ """
+
+ @staticmethod
+ def llama_model_forward(
+ self: LlamaModel,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+
+ batch_size = input_ids.shape[0] # input_ids.shape[0]
+
+ infer_state = self.infer_state
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ # NOT READY FOR PRIME TIME
+            # dummy but works; revise it
+ past_key_values_length = infer_state.cache_manager.past_key_values_length
+ # past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ # NOTE: differentiate with prefill stage
+ # block_loc require different value-assigning method for two different stage
+ if use_cache and seq_length != 1:
+            # NOTE assume prefill stage
+ # allocate memory block
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
+ infer_state.context_mem_index)
+ else:
+ infer_state.is_context_stage = False
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ else:
+                print(" *** Encountered non-contiguous allocation")
+ print(
+ f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
+ )
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device)
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if infer_state.is_context_stage:
+
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1)
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1)
+ else:
+ seq_len = infer_state.seq_len
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=inputs_embeds.device)
+
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
+ past_key_values_length)
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ infer_state.decode_layer_id = 0
+
+ for idx, decoder_layer in enumerate(self.layers):
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+ # NOTE: modify here for passing args to decoder layer
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+ infer_state.decode_layer_id += 1
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ hidden_states = self.norm(hidden_states)
+ next_cache = next_decoder_cache if use_cache else None
+
+ # update indices
+ # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ @staticmethod
+ def llama_decoder_layer_forward(
+ self: LlamaDecoderLayer,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ infer_state: Optional[BatchInferState] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+ @staticmethod
+ def llama_flash_attn_kvcache_forward(
+ self: LlamaAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+ assert use_cache is True, "use_cache should be set to True when using this llama attention"
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # NOTE: we may want a better way to handle the transposed k and v
+ # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head]
+ # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+
+ # NOTE: may need revising
+ # we need some way to record the length of the past key/value cache
+ # since we do not return past_key_value_cache right now
+ if infer_state.decode_layer_id == 0: # once per model.forward
+ infer_state.cache_manager.past_key_values_length += q_len # seq_len
+
+ cos, sin = infer_state.position_cos, infer_state.position_sin
+ # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
+
+ rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+ rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+
+ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
+ copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
+ return
+
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+ key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
+ value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
+
+ if infer_state.is_context_stage:
+ # first token generation
+
+ # copy key and value calculated in current step to memory manager
+ _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index,
+ infer_state.cache_manager)
+
+ attn_output = torch.empty_like(query_states)
+
+ llama_context_attn_fwd(query_states, key_states, value_states, attn_output, infer_state.start_loc,
+ infer_state.seq_len, infer_state.cache_manager.past_key_values_length)
+ else:
+
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+ cache_k.copy_(key_states)
+ cache_v.copy_(value_states)
+ else:
+ # if decode is not contiguous, use triton kernel to copy key and value cache
+ # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head
+ _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states,
+ infer_state.decode_mem_index, infer_state.cache_manager)
+
+ # second token and follows
+ # kv = torch.stack((key_states, value_states), dim=2)
+ # (batch_size, seqlen, nheads, headdim)
+ attn_output = torch.empty_like(query_states)
+
+ token_attention_fwd(query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+ infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output,
+ infer_state.block_loc, infer_state.start_loc, infer_state.seq_len,
+ infer_state.cache_manager.past_key_values_length)
+
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ # return past_key_value as None
+ return attn_output, None, None
+
+
+def get_llama_vllm_rmsnorm_forward():
+
+ if HAS_VLLM_KERNERL:
+
+ def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+ x = hidden_states
+ out = torch.empty_like(x)
+ rms_norm(
+ out,
+ x,
+ self.weight.data,
+ self.variance_epsilon,
+ )
+
+ return out
+
+ return _vllm_rmsnorm_forward
+ else:
+ return None
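
The attention forward above separates a context (prefill) stage, where the whole prompt's keys and values are written into the cache at once, from a decode stage, where a single token's key/value pair is appended per generation step. Below is a minimal pure-PyTorch sketch of that two-stage cache flow, using made-up shapes and helper names rather than the ColossalAI API:

```python
import torch

# hedged sketch: one sequence, one layer, cache laid out as [num_tokens, num_heads, head_dim]
max_tokens, num_heads, head_dim = 64, 8, 16
key_cache = torch.zeros(max_tokens, num_heads, head_dim)
value_cache = torch.zeros(max_tokens, num_heads, head_dim)

def prefill(k: torch.Tensor, v: torch.Tensor) -> int:
    """Context stage: copy the whole prompt's K/V into the cache."""
    seq_len = k.shape[0]
    key_cache[:seq_len] = k
    value_cache[:seq_len] = v
    return seq_len

def decode_step(k: torch.Tensor, v: torch.Tensor, cur_len: int) -> int:
    """Decode stage: append one token's K/V and return the new cache length."""
    key_cache[cur_len] = k
    value_cache[cur_len] = v
    return cur_len + 1

cur_len = prefill(torch.randn(10, num_heads, head_dim), torch.randn(10, num_heads, head_dim))
cur_len = decode_step(torch.randn(num_heads, head_dim), torch.randn(num_heads, head_dim), cur_len)
```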
diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py
new file mode 100644
index 000000000000..48f8db62c32a
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/policies/__init__.py
@@ -0,0 +1,4 @@
+from .bloom import BloomModelInferPolicy
+from .llama import LlamaModelInferPolicy
+
+__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy']
diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py
new file mode 100644
index 000000000000..63791fe27284
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/policies/bloom.py
@@ -0,0 +1,66 @@
+from functools import partial
+
+import torch
+from torch.nn import LayerNorm
+
+from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
+
+from ..modeling.bloom import BloomInferenceForwards
+
+try:
+ from colossalai.kernel.triton.fused_layernorm import layer_norm
+ HAS_TRITON_NORM = True
+except ImportError:
+ print("please install triton from https://github.com/openai/triton")
+ HAS_TRITON_NORM = False
+
+
+def get_triton_layernorm_forward():
+ if HAS_TRITON_NORM:
+
+ def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor):
+ return layer_norm(hidden_states, self.weight.data, self.bias, self.eps)
+
+ return _triton_layernorm_forward
+ else:
+ return None
+
+
+class BloomModelInferPolicy(BloomForCausalLMPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
+ policy = super().module_policy()
+ # NOTE set inference mode to shard config
+ self.shard_config._infer()
+
+ method_replacement = {
+ 'forward': BloomInferenceForwards.bloom_for_causal_lm_forward,
+ 'prepare_inputs_for_generation': BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation
+ }
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=BloomForCausalLM)
+
+ method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)
+
+ method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)
+
+ method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=BloomAttention)
+
+ if HAS_TRITON_NORM:
+ infer_method = get_triton_layernorm_forward()
+ method_replacement = {'forward': partial(infer_method)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=LayerNorm)
+
+ return policy
diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py
new file mode 100644
index 000000000000..e819f2a8810c
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/policies/llama.py
@@ -0,0 +1,70 @@
+from functools import partial
+import torch
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaModel,
+ LlamaRMSNorm
+)
+
+# import colossalai
+from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
+from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
+
+try:
+ from colossalai.kernel.triton.rms_norm import rmsnorm_forward
+ HAS_TRITON_RMSNORM = True
+except ImportError:
+ print("please install triton from https://github.com/openai/triton")
+ HAS_TRITON_RMSNORM = False
+
+
+def get_triton_rmsnorm_forward():
+ if HAS_TRITON_RMSNORM:
+ def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+ return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
+
+ return _triton_rmsnorm_forward
+ else:
+ return None
+
+class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ self.shard_config._infer()
+
+ infer_forward = LlamaInferenceForwards.llama_model_forward
+ method_replacement = {'forward': partial(infer_forward)}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
+
+ infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
+ method_replacement = {'forward': partial(infer_forward)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=LlamaDecoderLayer)
+
+ infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
+ method_replacement = {'forward': partial(infer_forward)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=LlamaAttention)
+
+ infer_forward = None
+ if HAS_TRITON_RMSNORM:
+ infer_forward = get_triton_rmsnorm_forward()
+ else:
+ # NOTE: adding rms_norm from the CUDA kernels caused a precision issue; to be fixed by @tiandiao123
+ infer_forward = get_llama_vllm_rmsnorm_forward()
+
+ if infer_forward is not None:
+ method_replacement = {'forward': partial(infer_forward)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=LlamaRMSNorm)
+
+ return policy
+
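
The policy registers replacement `forward` functions for `LlamaModel`, `LlamaDecoderLayer`, `LlamaAttention` and, when a kernel is available, `LlamaRMSNorm`. The underlying mechanism is ordinary method replacement: an unbound function is bound to the module instance so it runs in place of the original `forward`. A standalone sketch of that pattern with a hypothetical module, not the actual shardformer machinery:

```python
from functools import partial

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)

def fast_forward(self: TinyBlock, x):
    # stand-in for an optimized inference forward
    return self.proj(x) * 2

block = TinyBlock()
# bind the replacement to this instance, mirroring how the policy registers
# {'forward': partial(infer_forward)} for a target module class
block.forward = partial(fast_forward, block)
out = block.forward(torch.randn(2, 4))
```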
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index dc0df0517508..a1694e059fb4 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -17,13 +17,13 @@
from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.core import global_context as gpc
-from colossalai.engine import Engine
-from colossalai.engine.gradient_accumulation import accumulate_gradient
-from colossalai.engine.schedule import (
+from colossalai.legacy.builder.builder import build_gradient_handler
+from colossalai.legacy.engine import Engine
+from colossalai.legacy.engine.gradient_accumulation import accumulate_gradient
+from colossalai.legacy.engine.schedule import (
InterleavedPipelineSchedule,
NonPipelineSchedule,
PipelineSchedule,
diff --git a/colossalai/interface/__init__.py b/colossalai/interface/__init__.py
index 8c658e375146..1c3199fc1aff 100644
--- a/colossalai/interface/__init__.py
+++ b/colossalai/interface/__init__.py
@@ -1,4 +1,4 @@
-from .model import ModelWrapper
+from .model import AMPModelMixin, ModelWrapper
from .optimizer import OptimizerWrapper
-__all__ = ['OptimizerWrapper', 'ModelWrapper']
+__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py
index a067d7671ce7..7b3d9435d255 100644
--- a/colossalai/interface/model.py
+++ b/colossalai/interface/model.py
@@ -23,3 +23,14 @@ def unwrap(self):
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
+
+
+class AMPModelMixin:
+ """This mixin class defines the interface for AMP training.
+ """
+
+ def update_master_params(self):
+ """
+ Update the master parameters for AMP training.
+ """
+ pass
diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py
index 0eaf2e1ef8ba..bc270b1d9c89 100644
--- a/colossalai/interface/optimizer.py
+++ b/colossalai/interface/optimizer.py
@@ -1,5 +1,6 @@
from typing import Union
+import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
@@ -53,6 +54,9 @@ def backward(self, loss: Tensor, *args, **kwargs):
"""
loss.backward(*args, **kwargs)
+ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+ torch.autograd.backward(tensor, grad)
+
def state_dict(self):
"""
Returns the optimizer state.
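
The new `backward_by_grad` starts backpropagation from a non-scalar tensor given an upstream gradient, which is what `torch.autograd.backward(tensor, grad)` (or equivalently `tensor.backward(grad)`) does. A small illustration:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x * 2
grad_y = torch.ones_like(y)          # upstream gradient dL/dy
torch.autograd.backward(y, grad_y)   # what backward_by_grad(y, grad_y) performs
print(x.grad)                        # tensor([2., 2., 2.])
```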
diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py
index 8933fc0a3c2f..a99cb497c3e7 100644
--- a/colossalai/kernel/__init__.py
+++ b/colossalai/kernel/__init__.py
@@ -1,7 +1,14 @@
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention
+from .triton import llama_context_attn_fwd, bloom_context_attn_fwd
+from .triton import softmax
+from .triton import copy_kv_cache_to_dest
__all__ = [
"LayerNorm",
"FusedScaleMaskSoftmax",
"MultiHeadAttention",
+ "llama_context_attn_fwd",
+ "bloom_context_attn_fwd",
+ "softmax",
+ "copy_kv_cache_to_dest",
]
diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py
index 4910717b5723..e0136d86e561 100644
--- a/colossalai/kernel/cuda_native/__init__.py
+++ b/colossalai/kernel/cuda_native/__init__.py
@@ -1,8 +1,9 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .mha.mha import ColoAttention
from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
+from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
__all__ = [
- 'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention'
+ 'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention',
+ 'AttnMaskType'
]
diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py
index e83beb8b2429..8a898080877c 100644
--- a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py
+++ b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py
@@ -2,7 +2,13 @@
HAS_MEM_EFF_ATTN = False
try:
- from xformers.ops.fmha import memory_efficient_attention
+ from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
+ from xformers.ops.fmha.attn_bias import (
+ BlockDiagonalCausalMask,
+ BlockDiagonalMask,
+ LowerTriangularMask,
+ LowerTriangularMaskWithTensorBias,
+ )
HAS_MEM_EFF_ATTN = True
except ImportError:
warnings.warn('please install xformers from https://github.com/facebookresearch/xformers')
@@ -16,13 +22,6 @@
from typing import Optional
import torch
- from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
- from xformers.ops.fmha.attn_bias import (
- BlockDiagonalCausalMask,
- BlockDiagonalMask,
- LowerTriangularMask,
- LowerTriangularMaskWithTensorBias,
- )
from .utils import SeqLenInfo
diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py
index e20c08b051ed..8eb4e0c880a0 100644
--- a/colossalai/kernel/jit/option.py
+++ b/colossalai/kernel/jit/option.py
@@ -1,6 +1,6 @@
import torch
-from colossalai.nn.layer.colossalai_layer import Embedding, Linear
+from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
from colossalai.utils import get_current_device
from .bias_dropout_add import bias_dropout_add_fused_train
diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
new file mode 100644
index 000000000000..eb0335c01ce2
--- /dev/null
+++ b/colossalai/kernel/triton/__init__.py
@@ -0,0 +1,5 @@
+from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
+from .copy_kv_cache_dest import copy_kv_cache_to_dest
+from .fused_layernorm import layer_norm
+from .rms_norm import rmsnorm_forward
+from .softmax import softmax
diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py
new file mode 100644
index 000000000000..38db2048c6a4
--- /dev/null
+++ b/colossalai/kernel/triton/context_attention.py
@@ -0,0 +1,184 @@
+import torch
+import math
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+
+if HAS_TRITON:
+ '''
+ this function is modified from
+ https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
+ '''
+ @triton.jit
+ def _context_flash_attention_kernel(
+ Q, K, V, sm_scale,
+ B_Start_Loc, B_Seqlen,
+ TMP,
+ alibi_ptr,
+ Out,
+ stride_qbs, stride_qh, stride_qd,
+ stride_kbs, stride_kh, stride_kd,
+ stride_vbs, stride_vh, stride_vd,
+ stride_obs, stride_oh, stride_od,
+ stride_tmp_b, stride_tmp_h, stride_tmp_s,
+ # suggested values: 64, 128, 256, 512
+ BLOCK_M: tl.constexpr,
+ BLOCK_DMODEL: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ ):
+
+ batch_id = tl.program_id(0)
+ cur_head = tl.program_id(1)
+ start_m = tl.program_id(2)
+
+ # initialize offsets
+ offs_n = tl.arange(0, BLOCK_N)
+ offs_d = tl.arange(0, BLOCK_DMODEL)
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+
+ # get batch info
+ cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
+ cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
+ block_start_loc = BLOCK_M * start_m
+
+ load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
+ q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
+
+ k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
+ v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
+ t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
+
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+ if alibi_ptr is not None:
+ alibi_m = tl.load(alibi_ptr + cur_head)
+
+ block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
+
+ for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+ start_n = tl.multiple_of(start_n, BLOCK_N)
+ k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
+ mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
+
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ qk += tl.dot(q, k)
+ qk *= sm_scale
+
+ if alibi_ptr is not None:
+ alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
+ qk -= alibi_loc * alibi_m
+
+ qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+
+ m_ij = tl.max(qk, 1)
+ p = tl.exp(qk - m_ij[:, None])
+ l_ij = tl.sum(p, 1)
+ # -- update m_i and l_i
+ m_i_new = tl.maximum(m_i, m_ij)
+ alpha = tl.exp(m_i - m_i_new)
+ beta = tl.exp(m_ij - m_i_new)
+ l_i_new = alpha * l_i + beta * l_ij
+ # -- update output accumulator --
+ # scale p
+ p_scale = beta / l_i_new
+ p = p * p_scale[:, None]
+ # scale acc
+ acc_scale = l_i / l_i_new * alpha
+ tl.store(t_ptrs, acc_scale)
+ acc_scale = tl.load(t_ptrs)
+ acc = acc * acc_scale[:, None]
+ # update acc
+ v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
+ mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
+
+ p = p.to(v.dtype)
+ acc += tl.dot(p, v)
+ # update m_i and l_i
+ l_i = l_i_new
+ m_i = m_i_new
+
+ off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
+ out_ptrs = Out + off_o
+ tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
+ return
+
+
+ @torch.no_grad()
+ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
+ BLOCK = 128
+ # shape constraints
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+ assert Lq == Lk, "context process only supports equal query, key, value length"
+ assert Lk == Lv, "context process only supports equal query, key, value length"
+ assert Lk in {16, 32, 64, 128}
+
+ sm_scale = 1.0 / math.sqrt(Lk)
+ batch, head = b_seq_len.shape[0], q.shape[1]
+
+ grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
+
+ num_warps = 4 if Lk <= 64 else 8
+
+ tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
+
+ _context_flash_attention_kernel[grid](
+ q, k, v, sm_scale,
+ b_start_loc, b_seq_len,
+ tmp,
+ alibi,
+ o,
+ q.stride(0), q.stride(1), q.stride(2),
+ k.stride(0), k.stride(1), k.stride(2),
+ v.stride(0), v.stride(1), v.stride(2),
+ o.stride(0), o.stride(1), o.stride(2),
+ tmp.stride(0), tmp.stride(1), tmp.stride(2),
+ # manually setting this block size; an autotune config could speed this up further
+ BLOCK_M=BLOCK,
+ BLOCK_DMODEL=Lk,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ return
+
+ @torch.no_grad()
+ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
+ BLOCK = 128
+ # shape constraints
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+ assert Lq == Lk, "context process only supports equal query, key, value length"
+ assert Lk == Lv, "context process only supports equal query, key, value length"
+ assert Lk in {16, 32, 64, 128}
+
+ sm_scale = 1.0 / math.sqrt(Lk)
+ batch, head = b_seq_len.shape[0], q.shape[1]
+
+ grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
+
+ tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
+ num_warps = 4 if Lk <= 64 else 8
+ # num_warps = 4
+ _context_flash_attention_kernel[grid](
+ q, k, v, sm_scale, b_start_loc, b_seq_len,
+ tmp,
+ None,
+ o,
+ q.stride(0), q.stride(1), q.stride(2),
+ k.stride(0), k.stride(1), k.stride(2),
+ v.stride(0), v.stride(1), v.stride(2),
+ o.stride(0), o.stride(1), o.stride(2),
+ tmp.stride(0), tmp.stride(1), tmp.stride(2),
+ BLOCK_M=BLOCK,
+ BLOCK_DMODEL=Lk,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ return
\ No newline at end of file
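
`llama_context_attn_fwd` (and the alibi-aware bloom variant) computes causal attention over variable-length sequences that are packed along the token dimension, with `b_start_loc` and `b_seq_len` marking each sequence's slice. A hedged pure-PyTorch reference of the same computation, useful for unit-testing the kernel; the names are illustrative:

```python
import math

import torch

def context_attn_reference(q, k, v, b_start_loc, b_seq_len):
    """q, k, v: [total_tokens, num_heads, head_dim]; one causal attention per packed sequence."""
    out = torch.empty_like(q)
    scale = 1.0 / math.sqrt(q.shape[-1])
    for start, length in zip(b_start_loc.tolist(), b_seq_len.tolist()):
        qi = q[start:start + length].transpose(0, 1)    # [heads, len, dim]
        ki = k[start:start + length].transpose(0, 1)
        vi = v[start:start + length].transpose(0, 1)
        scores = qi @ ki.transpose(-1, -2) * scale
        causal = torch.tril(torch.ones(length, length, dtype=torch.bool, device=q.device))
        scores = scores.masked_fill(~causal, float('-inf'))
        out[start:start + length] = (scores.softmax(-1) @ vi).transpose(0, 1)
    return out
```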
diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py
new file mode 100644
index 000000000000..c1eaa8a10ed1
--- /dev/null
+++ b/colossalai/kernel/triton/copy_kv_cache_dest.py
@@ -0,0 +1,69 @@
+import torch
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+ @triton.jit
+ def _fwd_copy_kv_cache_dest(
+ kv_cache_ptr, dest_index_ptr,
+ out,
+ stride_k_bs,
+ stride_k_h,
+ stride_k_d,
+ stride_o_bs,
+ stride_o_h,
+ stride_o_d,
+ head_num,
+ BLOCK_DMODEL: tl.constexpr,
+ BLOCK_HEAD: tl.constexpr
+ ):
+ cur_index = tl.program_id(0)
+ offs_h = tl.arange(0, BLOCK_HEAD)
+ offs_d = tl.arange(0, BLOCK_DMODEL)
+
+ dest_index = tl.load(dest_index_ptr + cur_index)
+
+ cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
+ k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets
+
+ o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
+ o_ptrs = out + dest_index * stride_o_bs + o_offsets
+
+ k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
+ tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
+ return
+
+
+ @torch.no_grad()
+ def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
+ seq_len = dest_index_ptr.shape[0]
+ head_num = k_ptr.shape[1]
+ head_dim = k_ptr.shape[2]
+ assert head_num == out.shape[1], "head_num should be the same for k_ptr and out"
+ assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"
+
+ num_warps = 2
+
+ _fwd_copy_kv_cache_dest[(seq_len,)](
+ k_ptr, dest_index_ptr, out,
+ k_ptr.stride(0),
+ k_ptr.stride(1),
+ k_ptr.stride(2),
+ out.stride(0),
+ out.stride(1),
+ out.stride(2),
+ head_num,
+ BLOCK_DMODEL=head_dim,
+ BLOCK_HEAD=triton.next_power_of_2(head_num),
+ num_warps=num_warps,
+ num_stages=2,
+ )
+ return
+
+
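
`copy_kv_cache_to_dest` is a scatter along the token dimension: row `i` of the source buffer is written to row `dest_index_ptr[i]` of the destination cache. The PyTorch equivalent, handy as a correctness reference for the kernel, is a single indexed assignment:

```python
import torch

def copy_kv_cache_reference(k_src: torch.Tensor, dest_index: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    # k_src: [seq_len, head_num, head_dim], dest_index: [seq_len], cache: [cache_size, head_num, head_dim]
    cache[dest_index.long()] = k_src
    return cache
```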
diff --git a/colossalai/kernel/triton/fused_layernorm.py b/colossalai/kernel/triton/fused_layernorm.py
new file mode 100644
index 000000000000..99800acfbb92
--- /dev/null
+++ b/colossalai/kernel/triton/fused_layernorm.py
@@ -0,0 +1,83 @@
+import torch
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+ # CREDITS: These functions are adapted from the Triton tutorial
+ # https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+
+ @triton.jit
+ def _layer_norm_fwd_fused(
+ X, # pointer to the input
+ Y, # pointer to the output
+ W, # pointer to the weights
+ B, # pointer to the biases
+ stride, # how much to increase the pointer when moving by 1 row
+ N, # number of columns in X
+ eps, # epsilon to avoid division by zero
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ # Map the program id to the row of X and Y it should compute.
+ row = tl.program_id(0)
+ Y += row * stride
+ X += row * stride
+ # Compute mean
+ mean = 0
+ _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+ _mean += a
+ mean = tl.sum(_mean, axis=0) / N
+ # Compute variance
+ _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+ x = tl.where(cols < N, x - mean, 0.)
+ _var += x * x
+ var = tl.sum(_var, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ # Normalize and apply linear transformation
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ mask = cols < N
+ w = tl.load(W + cols, mask=mask)
+ b = tl.load(B + cols, mask=mask)
+ x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
+ x_hat = (x - mean) * rstd
+ y = x_hat * w + b
+ # Write output
+ tl.store(Y + cols, y.to(tl.float16), mask=mask)
+
+ @torch.no_grad()
+ def layer_norm(x, weight, bias, eps):
+ # allocate output
+ y = torch.empty_like(x)
+ # reshape input data into 2D tensor
+ x_arg = x.reshape(-1, x.shape[-1])
+ M, N = x_arg.shape
+ # Less than 64KB per feature: enqueue fused kernel
+ MAX_FUSED_SIZE = 65536 // x.element_size()
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+ if N > BLOCK_SIZE:
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+ # heuristics for number of warps
+ num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+ # enqueue kernel
+ _layer_norm_fwd_fused[(M,)](x_arg,
+ y,
+ weight,
+ bias,
+ x_arg.stride(0),
+ N,
+ eps,
+ BLOCK_SIZE=BLOCK_SIZE,
+ num_warps=num_warps)
+ return y
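
Because the fused kernel reproduces standard layer normalization, its output can be sanity-checked against `torch.nn.functional.layer_norm`. A hedged check (it assumes a CUDA device with Triton installed, and uses a loose tolerance since the kernel stores its result in fp16):

```python
import torch
import torch.nn.functional as F

from colossalai.kernel.triton.fused_layernorm import layer_norm

x = torch.randn(8, 1024, device='cuda', dtype=torch.float16)
w = torch.rand(1024, device='cuda', dtype=torch.float16)
b = torch.rand(1024, device='cuda', dtype=torch.float16)

y_triton = layer_norm(x, w, b, 1e-5)
y_torch = F.layer_norm(x, (1024,), weight=w, bias=b, eps=1e-5)
torch.testing.assert_close(y_triton, y_torch, rtol=1e-2, atol=1e-2)
```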
diff --git a/colossalai/kernel/triton/rms_norm.py b/colossalai/kernel/triton/rms_norm.py
new file mode 100644
index 000000000000..1fb79115f8ce
--- /dev/null
+++ b/colossalai/kernel/triton/rms_norm.py
@@ -0,0 +1,72 @@
+import torch
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+
+if HAS_TRITON:
+ '''
+ this kernel function is modified from
+ https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
+ '''
+ @triton.jit
+ def _rms_norm_fwd_fused(
+ X, # pointer to the input
+ Y, # pointer to the output
+ W, # pointer to the weights
+ stride, # how much to increase the pointer when moving by 1 row
+ N, # number of columns in X
+ eps, # epsilon to avoid division by zero
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ # Map the program id to the row of X and Y it should compute.
+ row = tl.program_id(0)
+ Y += row * stride
+ X += row * stride
+ # Compute variance
+ _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+ _var += x * x
+ var = tl.sum(_var, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ # Normalize and apply linear transformation
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ mask = cols < N
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
+ x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
+ x_hat = x * rstd
+ y = x_hat * w
+ # Write output
+ tl.store(Y + cols, y.to(tl.float16), mask=mask)
+
+
+ def rmsnorm_forward(x, weight, eps):
+ # allocate output
+ y = torch.empty_like(x)
+ # reshape input data into 2D tensor
+ x_arg = x.view(-1, x.shape[-1])
+ M, N = x_arg.shape
+ # Less than 64KB per feature: enqueue fused kernel
+ MAX_FUSED_SIZE = 65536 // x.element_size()
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+ # print("BLOCK_SIZE:", BLOCK_SIZE)
+ if N > BLOCK_SIZE:
+ raise RuntimeError("This rms norm doesn't support feature dim >= 64KB.")
+ # heuristics for number of warps
+ num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+ # print(BLOCK_SIZE, num_warps, "block_size, numwarps")
+ BLOCK_SIZE = 16384    # hard-coded override of the block size computed above
+ num_warps = 8
+ # enqueue kernel
+ _rms_norm_fwd_fused[(M,)](x_arg, y, weight,
+ x_arg.stride(0), N, eps,
+ BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
+ return y
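
`rmsnorm_forward` implements RMSNorm: each row is rescaled by the reciprocal of its root-mean-square before the learned weight is applied, with no mean subtraction and no bias. A reference implementation to compare against (use a loose tolerance, since the kernel casts its output to fp16):

```python
import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    x_fp32 = x.to(torch.float32)
    variance = x_fp32.pow(2).mean(-1, keepdim=True)
    x_normed = x_fp32 * torch.rsqrt(variance + eps)
    return (x_normed * weight.to(torch.float32)).to(x.dtype)
```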
diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py
new file mode 100644
index 000000000000..d9d1b2bcf026
--- /dev/null
+++ b/colossalai/kernel/triton/rotary_embedding_kernel.py
@@ -0,0 +1,93 @@
+# Adapted from ModelTC https://github.com/ModelTC/lightllm
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _rotary_kernel(
+ q,
+ Cos,
+ Sin,
+ q_bs_stride,
+ q_h_stride,
+ q_d_stride,
+ cos_bs_stride,
+ cos_d_stride,
+ total_len,
+ HEAD_NUM: tl.constexpr,
+ BLOCK_HEAD: tl.constexpr,
+ BLOCK_SEQ: tl.constexpr,
+ HEAD_DIM: tl.constexpr,
+):
+ current_head_index = tl.program_id(0)
+ current_seq_index = tl.program_id(1)
+
+ current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
+ current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
+
+ dim_range0 = tl.arange(0, HEAD_DIM // 2)
+ dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
+
+ off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
+ None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride
+ off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
+ None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride
+
+ off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride
+
+ q0 = tl.load(q + off_q0,
+ mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+ other=0.0)
+ q1 = tl.load(q + off_q1,
+ mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+ other=0.0)
+
+ cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
+ sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
+
+ out0 = q0 * cos - q1 * sin
+ out1 = q0 * sin + q1 * cos
+
+ tl.store(q + off_q0,
+ out0,
+ mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
+ tl.store(q + off_q1,
+ out1,
+ mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
+
+ return
+
+
+@torch.no_grad()
+def rotary_embedding_fwd(q, cos, sin):
+ total_len = q.shape[0]
+ head_num = q.shape[1]
+ head_dim = q.shape[2]
+ assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
+ BLOCK_HEAD = 4
+ BLOCK_SEQ = 32
+ grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
+ if head_dim >= 128:
+ num_warps = 8
+ else:
+ num_warps = 4
+
+ _rotary_kernel[grid](
+ q,
+ cos,
+ sin,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ cos.stride(0),
+ cos.stride(1),
+ total_len,
+ HEAD_NUM=head_num,
+ BLOCK_HEAD=BLOCK_HEAD,
+ BLOCK_SEQ=BLOCK_SEQ,
+ HEAD_DIM=head_dim,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ return
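
`rotary_embedding_fwd` rotates each head's features in place, pairing element `d` with element `d + head_dim // 2` and using per-position `cos`/`sin` tables of shape `[total_len, head_dim // 2]`. An out-of-place reference of the same rotation:

```python
import torch

def rotary_reference(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q: [total_len, head_num, head_dim]; cos, sin: [total_len, head_dim // 2]
    half = q.shape[-1] // 2
    q0, q1 = q[..., :half], q[..., half:]
    cos = cos[:, None, :]    # broadcast over heads
    sin = sin[:, None, :]
    return torch.cat((q0 * cos - q1 * sin, q0 * sin + q1 * cos), dim=-1)
```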
diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/self_attention_nofusion.py
similarity index 57%
rename from colossalai/kernel/triton/ops.py
rename to colossalai/kernel/triton/self_attention_nofusion.py
index 5e8d4ba3ec99..6ae54dcb0b38 100644
--- a/colossalai/kernel/triton/ops.py
+++ b/colossalai/kernel/triton/self_attention_nofusion.py
@@ -11,10 +11,11 @@
if HAS_TRITON:
from .qkv_matmul_kernel import qkv_gemm_4d_kernel
- from .softmax_kernel import softmax_kernel
+ from .softmax import softmax_kernel
- def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float):
- r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
+ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+ input_mask: torch.Tensor, scale: float):
+ r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
Args:
q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
@@ -36,39 +37,49 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t
# head_size * num_of_head
d_model = q.shape[-1] * q.shape[-2]
- score_output = torch.empty(
- (batches, H, M, N), device=q.device, dtype=q.dtype)
+ score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)
grid = lambda meta: (
batches,
H,
- triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
- triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+ triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
qkv_gemm_4d_kernel[grid](
- q, k, score_output,
- M, N, K,
- q.stride(0), q.stride(2), q.stride(1), q.stride(3),
- k.stride(0), k.stride(2), k.stride(3), k.stride(1),
- score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
+ q,
+ k,
+ score_output,
+ M,
+ N,
+ K,
+ q.stride(0),
+ q.stride(2),
+ q.stride(1),
+ q.stride(3),
+ k.stride(0),
+ k.stride(2),
+ k.stride(3),
+ k.stride(1),
+ score_output.stride(0),
+ score_output.stride(1),
+ score_output.stride(2),
+ score_output.stride(3),
scale=scale,
- # currently manually setting, later on we can use auto-tune config to match best setting
+ # currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
GROUP_SIZE_M=8,
)
-
- softmax_output = torch.empty(
- score_output.shape, device=score_output.device, dtype=score_output.dtype)
+
+ softmax_output = torch.empty(score_output.shape, device=score_output.device, dtype=score_output.dtype)
score_output_shape = score_output.shape
score_output = score_output.view(-1, score_output.shape[-1])
n_rows, n_cols = score_output.shape
if n_rows <= 350000:
-
+
block_size = max(triton.next_power_of_2(n_cols), 2)
num_warps = 4
if block_size >= 4096:
@@ -78,37 +89,39 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t
else:
num_warps = 4
- softmax_kernel[(n_rows, )](
+ softmax_kernel[(n_rows,)](
softmax_output,
score_output,
score_output.stride(0),
n_cols,
- mask_ptr = input_mask,
+ mask_ptr=input_mask,
num_warps=num_warps,
BLOCK_SIZE=block_size,
)
else:
- #TODO: change softmax kernel functions to make it suitable for large size dimension
+ # NOTE: change softmax kernel functions to make it suitable for large size dimension
softmax_output = torch.nn.functional.softmax(score_output, dim=-1)
softmax_output = softmax_output.view(*score_output_shape)
batches, H, M, K = softmax_output.shape
N = v.shape[-1]
- output = torch.empty(
- (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)
+ output = torch.empty((batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)
grid = lambda meta: (
batches,
H,
- triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
- triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+ triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
qkv_gemm_4d_kernel[grid](
- softmax_output, v, output,
- M, N, K,
+ softmax_output,
+ v,
+ output,
+ M,
+ N,
+ K,
softmax_output.stride(0),
softmax_output.stride(1),
softmax_output.stride(2),
@@ -129,7 +142,6 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t
)
return output.view(batches, -1, d_model)
-
def self_attention_compute_using_triton(qkv,
input_mask,
layer_past,
@@ -152,58 +164,6 @@ def self_attention_compute_using_triton(qkv,
k = k.view(batches, -1, num_of_heads, head_size)
v = v.view(batches, -1, num_of_heads, head_size)
- data_output_triton = self_attention_forward_without_fusion(
- q, k, v, input_mask, scale)
+ data_output_triton = self_attention_forward_without_fusion(q, k, v, input_mask, scale)
return data_output_triton
-
-
- def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
- if mask is not None:
- assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask"
- assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention"
-
- hidden_dim = input.shape[-1]
- output = torch.empty_like(input)
- input = input.view(-1, hidden_dim)
- if mask is not None:
- mask = mask.view(-1, hidden_dim)
- assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same"
-
- num_rows, num_cols = input.shape
- block_size = max(triton.next_power_of_2(num_cols), 2)
- num_warps = 16
- if block_size >= 4096:
- num_warps = 16
- elif block_size >= 2048:
- num_warps = 8
- else:
- num_warps = 4
-
- if num_rows <= 350000:
- grid = (num_rows,)
- softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
- else:
- grid = lambda meta: ()
-
- grid = lambda meta: (
- triton.cdiv(num_rows, meta["BLOCK_M"]),
- )
-
- BLOCK_M = 32
- if block_size >= 4096:
- BLOCK_M = 4
- elif block_size >= 2048:
- BLOCK_M = 8
-
- softmax_kernel_2[grid](output_ptr = output,
- input_ptr = input,
- row_stride = input.stride(0),
- n_rows = num_rows,
- n_cols = num_cols,
- mask_ptr = mask,
- # currently manually setting up size
- BLOCK_M = 32,
- BLOCK_SIZE = block_size)
-
- return output
\ No newline at end of file
diff --git a/colossalai/kernel/triton/softmax.py b/colossalai/kernel/triton/softmax.py
new file mode 100644
index 000000000000..c65adaf40dda
--- /dev/null
+++ b/colossalai/kernel/triton/softmax.py
@@ -0,0 +1,96 @@
+import torch
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+ '''
+ softmax kernel is modified based on
+ https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
+ '''
+ @triton.jit
+ def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
+ r""" the kernel function for implementing softmax operator
+ Args:
+ output_ptr: the output after finishing softmax operation, (N, hidden_dim)
+ input_ptr: the tensor of input, shape should be (N, hidden_dim)
+ n_cols(tl.constexpr): the number of cols of input
+ BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
+ """
+ row_idx = tl.program_id(0)
+ row_start_ptr = input_ptr + row_idx * row_stride
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ input_ptrs = row_start_ptr + col_offsets
+ row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
+ row_minus_max = row - tl.max(row, axis=0)
+
+ if mask_ptr is not None:
+ # load mask into SRAM
+ mask_ptrs = (mask_ptr + (row_idx * row_stride)) + col_offsets
+ mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
+
+ # update
+ row_minus_max = row_minus_max + mask
+
+ numerator = tl.exp(row_minus_max)
+ denominator = tl.sum(numerator, axis=0)
+ softmax_output = numerator / denominator
+ output_row_start_ptr = output_ptr + row_idx * row_stride
+ output_ptrs = output_row_start_ptr + col_offsets
+ # Write back output to DRAM
+ tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
+
+
+ def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
+ if mask is not None:
+ assert input.shape[-1] == mask.shape[-1], "the last dimensions of input and mask should be the same"
+ assert dim == -1 or dim == len(input.shape) - 1, "currently the softmax layer only supports the last dimension"
+
+ hidden_dim = input.shape[-1]
+ output = torch.empty_like(input)
+ input = input.view(-1, hidden_dim)
+ if mask is not None:
+ mask = mask.view(-1, hidden_dim)
+ assert input.shape[0] == mask.shape[0], "the first dimension of mask and input should be the same"
+
+ num_rows, num_cols = input.shape
+ block_size = max(triton.next_power_of_2(num_cols), 2)
+ num_warps = 16
+ if block_size >= 4096:
+ num_warps = 16
+ elif block_size >= 2048:
+ num_warps = 8
+ else:
+ num_warps = 4
+
+ if num_rows <= 350000:
+ grid = (num_rows,)
+ softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
+ else:
+ # NOTE: a row-blocked kernel (processing BLOCK_M rows per program) would be preferable
+ # for very large row counts; until one is added, launch one program per row here as well
+ grid = (num_rows,)
+ softmax_kernel[grid](output, input, input.stride(0), num_cols, mask,
+ BLOCK_SIZE=block_size, num_warps=num_warps)
+
+ return output
\ No newline at end of file
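
When a mask is given, the kernel adds it to the max-shifted logits before exponentiating, so with an additive `-inf`-style mask the result should match `torch.softmax(input + mask, dim=-1)`. A hedged sanity check (assumes a CUDA device with Triton installed):

```python
import torch

from colossalai.kernel.triton.softmax import softmax

x = torch.randn(4, 512, device='cuda')
mask = torch.zeros(4, 512, device='cuda')
mask[:, 256:] = float('-inf')    # additive mask hiding the second half of each row

y_triton = softmax(x, mask)
y_torch = torch.softmax(x + mask, dim=-1)
torch.testing.assert_close(y_triton, y_torch, rtol=1e-3, atol=1e-3)
```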
diff --git a/colossalai/kernel/triton/softmax_kernel.py b/colossalai/kernel/triton/softmax_kernel.py
deleted file mode 100644
index c215890badff..000000000000
--- a/colossalai/kernel/triton/softmax_kernel.py
+++ /dev/null
@@ -1,44 +0,0 @@
-try:
- import triton
- import triton.language as tl
- HAS_TRITON = True
-except ImportError:
- HAS_TRITON = False
- print("please install triton from https://github.com/openai/triton")
-
-if HAS_TRITON:
- '''
- softmax kernel is modified based on
- https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
- '''
- @triton.jit
- def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
- r""" the kernel function for implementing softmax operator
- Args:
- output_ptr: the output after finishing softmax operation, (N, hidden_dim)
- input_ptr: the tensor of input, shape should be (N, hidden_dim)
- n_cols(tl.constexpr): the number of cols of input
- BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
- """
- row_idx = tl.program_id(0)
- row_start_ptr = input_ptr + row_idx * row_stride
- col_offsets = tl.arange(0, BLOCK_SIZE)
- input_ptrs = row_start_ptr + col_offsets
- row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
- row_minus_max = row - tl.max(row, axis=0)
-
- if mask_ptr is not None:
- # load mask into SRAM
- mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets
- mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
-
- # update
- row_minus_max = row_minus_max + mask
-
- numerator = tl.exp(row_minus_max)
- denominator = tl.sum(numerator, axis=0)
- softmax_output = numerator / denominator
- output_row_start_ptr = output_ptr + row_idx * row_stride
- output_ptrs = output_row_start_ptr + col_offsets
- # Write back output to DRAM
- tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
\ No newline at end of file
diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py
new file mode 100644
index 000000000000..c6b25f4abcec
--- /dev/null
+++ b/colossalai/kernel/triton/token_attention_kernel.py
@@ -0,0 +1,333 @@
+# Adapted from ModelTC https://github.com/ModelTC/lightllm
+
+import math
+
+import torch
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+
+ @triton.jit
+ def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
+ attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride,
+ q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride,
+ attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr):
+ current_batch = tl.program_id(0)
+ current_head = tl.program_id(1)
+ start_n = tl.program_id(2)
+
+ offs_d = tl.arange(0, HEAD_DIM)
+ current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
+ current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
+
+ current_batch_start_index = max_kv_cache_len - current_batch_seq_len
+ current_batch_end_index = max_kv_cache_len
+
+ off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
+
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ block_start_index = start_n * BLOCK_N
+ block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)
+
+ for start_mark in range(0, block_mask, 1):
+ q = tl.load(Q + off_q + start_mark)
+ offs_n_new = current_batch_start_index + offs_n
+ k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
+ mask=offs_n_new < current_batch_end_index,
+ other=0)
+ off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
+ k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
+ att_value = tl.sum(q[None, :] * k, 1)
+ att_value *= sm_scale
+ off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
+ tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
+ return
+
+ @triton.jit
+ def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen,
+ max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride,
+ q_batch_stride, q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride,
+ k_head_dim_stride, attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr,
+ BLOCK_N: tl.constexpr):
+ current_batch = tl.program_id(0)
+ current_head = tl.program_id(1)
+ start_n = tl.program_id(2)
+
+ offs_d = tl.arange(0, HEAD_DIM)
+ current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
+ current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
+
+ current_batch_start_index = max_kv_cache_len - current_batch_seq_len
+ current_batch_end_index = max_kv_cache_len
+
+ off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
+
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ block_start_index = start_n * BLOCK_N
+ block_mask = tl.where(block_start_index < current_batch_seq_len, 1, 0)
+
+ for start_mark in range(0, block_mask, 1):
+ alibi_m = tl.load(alibi + current_head)
+ q = tl.load(Q + off_q + start_mark)
+ offs_n_new = current_batch_start_index + offs_n
+ k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
+ mask=offs_n_new < current_batch_end_index,
+ other=0)
+ off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
+ k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
+ att_value = tl.sum(q[None, :] * k, 1)
+ att_value *= sm_scale
+ att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n)
+ off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
+ tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
+ return
+
+ @torch.no_grad()
+ def token_attn_fwd_1(q,
+ k,
+ attn_out,
+ kv_cache_loc,
+ kv_cache_start_loc,
+ kv_cache_seqlen,
+ max_kv_cache_len,
+ alibi=None):
+ BLOCK = 32
+ # shape constraints
+ q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
+ assert q_head_dim == k_head_dim
+ assert k_head_dim in {16, 32, 64, 128}
+ sm_scale = 1.0 / (k_head_dim**0.5)
+
+ batch, head_num = kv_cache_loc.shape[0], q.shape[1]
+
+ grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK))
+
+ num_warps = 4 if k_head_dim <= 64 else 8
+ num_warps = 2
+
+ if alibi is not None:
+ _token_attn_1_alibi_kernel[grid](
+ q,
+ k,
+ sm_scale,
+ alibi,
+ kv_cache_loc,
+ kv_cache_start_loc,
+ kv_cache_seqlen,
+ max_kv_cache_len,
+ attn_out,
+ kv_cache_loc.stride(0),
+ kv_cache_loc.stride(1),
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ attn_out.stride(0),
+ attn_out.stride(1),
+ HEAD_DIM=k_head_dim,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ else:
+ _token_attn_1_kernel[grid](
+ q,
+ k,
+ sm_scale,
+ kv_cache_loc,
+ kv_cache_start_loc,
+ kv_cache_seqlen,
+ max_kv_cache_len,
+ attn_out,
+ kv_cache_loc.stride(0),
+ kv_cache_loc.stride(1),
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ attn_out.stride(0),
+ attn_out.stride(1),
+ HEAD_DIM=k_head_dim,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ return
+
+ @triton.jit
+ def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out,
+ logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride,
+ BLOCK_SIZE: tl.constexpr):
+ current_batch = tl.program_id(0)
+ current_head = tl.program_id(1)
+
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
+ current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
+
+ row = tl.load(softmax_logics + current_head * logics_head_dim_stride +
+ (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
+ mask=col_offsets < current_batch_seq_len,
+ other=-float('inf')).to(tl.float32)
+
+ row_minus_max = row - tl.max(row, axis=0)
+ numerator = tl.exp(row_minus_max)
+ denominator = tl.sum(numerator, axis=0)
+ softmax_output = numerator / denominator
+
+ tl.store(softmax_prob_out + current_head * prob_head_dim_stride +
+ (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
+ softmax_output,
+ mask=col_offsets < current_batch_seq_len)
+ return
+
+ @torch.no_grad()
+ def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len):
+ BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len)
+ batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0]
+
+ num_warps = 4
+ if BLOCK_SIZE >= 2048:
+ num_warps = 8
+ if BLOCK_SIZE >= 4096:
+ num_warps = 16
+
+ _token_attn_softmax_fwd[(batch, head_num)](
+ softmax_logics,
+ kv_cache_start_loc,
+ kv_cache_seqlen,
+ softmax_prob_out,
+ softmax_logics.stride(0),
+ softmax_logics.stride(1),
+ softmax_prob_out.stride(0),
+ softmax_prob_out.stride(1),
+ num_warps=num_warps,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+ return
+
+ @triton.jit
+ def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len,
+ kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride,
+ v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride,
+ attn_out_head_stride, attn_out_head_dim_stride, HEAD_DIM: tl.constexpr,
+ BLOCK_N: tl.constexpr):
+ current_batch = tl.program_id(0)
+ current_head = tl.program_id(1)
+
+ offs_n = tl.arange(0, BLOCK_N)
+ offs_d = tl.arange(0, HEAD_DIM)
+ current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
+ current_batch_start_index = max_kv_cache_len - current_batch_seq_len
+ current_batch_end_index = current_batch_seq_len
+ current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
+
+ v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
+ p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride
+ v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride
+
+ acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
+ for start_n in range(0, current_batch_seq_len, BLOCK_N):
+ start_n = tl.multiple_of(start_n, BLOCK_N)
+ p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride,
+ mask=(start_n + offs_n) < current_batch_seq_len,
+ other=0.0)
+ v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
+ mask=(start_n + offs_n) < current_batch_seq_len,
+ other=0.0)
+ v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride,
+ mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
+ other=0.0)
+ acc += tl.sum(p_value[:, None] * v_value, 0)
+
+ acc = acc.to(tl.float16)
+ off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride
+ out_ptrs = attn_out + off_o
+ tl.store(out_ptrs, acc)
+ return
+
+ @torch.no_grad()
+ def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len):
+ if triton.__version__ >= "2.1.0":
+ BLOCK = 128
+ else:
+ BLOCK = 64
+ batch, head = kv_cache_loc.shape[0], v.shape[1]
+ grid = (batch, head)
+ num_warps = 4
+ dim = v.shape[-1]
+
+ _token_attn_2_kernel[grid](
+ prob,
+ v,
+ attn_out,
+ kv_cache_loc,
+ kv_cache_start_loc,
+ kv_cache_seqlen,
+ max_kv_cache_len,
+ kv_cache_loc.stride(0),
+ kv_cache_loc.stride(1),
+ prob.stride(0),
+ prob.stride(1),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ attn_out.stride(0),
+ attn_out.stride(1),
+ attn_out.stride(2),
+ HEAD_DIM=dim,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ return
+
+ @torch.no_grad()
+ def token_attention_fwd(q,
+ k,
+ v,
+ attn_out,
+ kv_cache_loc,
+ kv_cache_start_loc,
+ kv_cache_seq_len,
+ max_len_in_batch,
+ alibi=None):
+ head_num = k.shape[1]
+ batch_size = kv_cache_seq_len.shape[0]
+ calcu_shape1 = (batch_size, head_num, k.shape[2])
+ total_token_num = k.shape[0]
+
+ att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
+
+ token_attn_fwd_1(q.view(calcu_shape1),
+ k,
+ att_m_tensor,
+ kv_cache_loc,
+ kv_cache_start_loc,
+ kv_cache_seq_len,
+ max_len_in_batch,
+ alibi=alibi)
+
+ prob = torch.empty_like(att_m_tensor)
+
+ token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
+ att_m_tensor = None
+ token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len,
+ max_len_in_batch)
+
+ prob = None
+
+ return
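
`token_attention_fwd` performs single-query (decode-step) attention against the cached keys and values in three passes: per-token dot-product scores (`token_attn_fwd_1`), a softmax over each sequence's cached length (`token_attn_softmax_fwd`), and a weighted sum over the cached values (`token_attn_fwd_2`). A hedged per-sequence reference of the overall computation with a dense, non-paged cache layout:

```python
import math

import torch

def token_attention_reference(q, k_cache, v_cache, seq_lens):
    # q: [batch, heads, dim] -- one new query token per sequence
    # k_cache, v_cache: [batch, max_len, heads, dim]; seq_lens: [batch]
    batch, heads, dim = q.shape
    out = torch.empty_like(q)
    scale = 1.0 / math.sqrt(dim)
    for b in range(batch):
        n = int(seq_lens[b])
        k = k_cache[b, :n].transpose(0, 1)                        # [heads, n, dim]
        v = v_cache[b, :n].transpose(0, 1)
        scores = (k @ q[b].unsqueeze(-1)).squeeze(-1) * scale     # [heads, n]
        probs = scores.softmax(-1)
        out[b] = (probs.unsqueeze(1) @ v).squeeze(1)              # [heads, dim]
    return out
```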
diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index 1f5345015bf2..e071563c045a 100644
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -6,6 +6,7 @@
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
+from torch.nn import Parameter
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
@@ -99,8 +100,11 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: the converted tensor
"""
- cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
+ cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
tensor.__class__ = cls_to_become
+ if cls_to_become is Parameter:
+ # to fit UninitializedParameter
+ delattr(tensor, '_is_param')
tensor.data = target
tensor.requires_grad = target.requires_grad
# subclass of torch.Tensor does not have tolist() method
@@ -198,10 +202,10 @@ def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> to
def clean(self) -> None:
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
"""
- self._factory_method = None
- self._op_buffer = None
- self._materialized_data = None
- self._meta_data = None
+ delattr(self, '_factory_method')
+ delattr(self, '_op_buffer')
+ delattr(self, '_materialized_data')
+ delattr(self, '_meta_data')
@staticmethod
def _replace_with_materialized(x):
@@ -350,20 +354,19 @@ def __deepcopy__(self, memo):
def factory_fn():
# if self is materialized, return self
new_tensor = self.materialize() if type(self) is LazyTensor else self
- copied = new_tensor.detach().clone()
- if new_tensor.requires_grad:
- copied.requires_grad_()
- return copied
+ return _copy_tensor(new_tensor, new_tensor.requires_grad)
if self._materialized_data is not None:
# self is early materialized
- copied = self._materialized_data.detach().clone()
- if self.requires_grad:
- copied.requires_grad_()
+ copied = _copy_tensor(self._materialized_data, self.requires_grad)
target = LazyTensor(lambda: None, concrete_data=copied)
else:
target = LazyTensor(factory_fn, meta_data=self._meta_data)
+ if isinstance(self, Parameter):
+ # hack isinstance check of parameter
+ target._is_param = True
+
memo[id(self)] = target
return target
@@ -408,6 +411,10 @@ def tolist(self) -> list:
def __hash__(self):
return id(self)
+ def __rpow__(self, other):
+ dtype = torch.result_type(self, other)
+ return torch.tensor(other, dtype=dtype, device=self.device)**self
+
class LazyInitContext:
"""Context manager for lazy initialization. Enables initializing the model without allocating real memory.
@@ -536,7 +543,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
@staticmethod
def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
- """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+ """Initialize all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args:
module (nn.Module): Target ``nn.Module``
@@ -553,7 +560,7 @@ def distribute(module: nn.Module,
device_mesh: DeviceMesh,
sharding_spec_dict: Dict[str, ShardingSpec],
verbose: bool = False) -> nn.Module:
- """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+ """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args:
module (nn.Module): Target ``nn.Module``
@@ -625,3 +632,9 @@ def _is_int_tuple(args) -> bool:
if not isinstance(x, int):
return False
return True
+
+
+def _copy_tensor(tensor: Tensor, requires_grad: bool) -> Tensor:
+ copied = tensor.data.clone()
+ copied.requires_grad = requires_grad
+ return copied
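Taken together, the lazy_init changes keep deep-copied lazy parameters recognisable as parameters (via the `_is_param` flag), release internal buffers with `delattr` in `clean()`, and add `__rpow__` so expressions such as `2 ** tensor` dispatch correctly on lazy tensors. A minimal sketch of the deepcopy-then-materialize path follows; the public `colossalai.lazy` import path is an assumption.

```python
import copy

import torch.nn as nn

from colossalai.lazy import LazyInitContext  # assumed public import path for the patched module

with LazyInitContext():
    model = nn.Linear(8, 8)            # parameters are LazyTensors; no real memory is allocated yet

model_copy = copy.deepcopy(model)      # __deepcopy__ sets _is_param so the copies still pass isinstance(..., Parameter)

LazyInitContext.materialize(model)     # LazyTensors become ordinary nn.Parameter instances in place
LazyInitContext.materialize(model_copy)
assert isinstance(model.weight, nn.Parameter) and isinstance(model_copy.weight, nn.Parameter)
```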
diff --git a/tests/test_layers/test_2d/checks_2d/__init__.py b/colossalai/legacy/__init__.py
similarity index 100%
rename from tests/test_layers/test_2d/checks_2d/__init__.py
rename to colossalai/legacy/__init__.py
diff --git a/colossalai/builder/__init__.py b/colossalai/legacy/builder/__init__.py
similarity index 100%
rename from colossalai/builder/__init__.py
rename to colossalai/legacy/builder/__init__.py
diff --git a/colossalai/builder/builder.py b/colossalai/legacy/builder/builder.py
similarity index 96%
rename from colossalai/builder/builder.py
rename to colossalai/legacy/builder/builder.py
index 4a907601327c..ff14f46dc61f 100644
--- a/colossalai/builder/builder.py
+++ b/colossalai/legacy/builder/builder.py
@@ -3,7 +3,7 @@
import inspect
-from colossalai.registry import *
+from colossalai.legacy.registry import *
def build_from_config(module, config: dict):
@@ -71,7 +71,7 @@ def build_gradient_handler(config, model, optimizer):
optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler
Returns:
- An object of :class:`colossalai.engine.BaseGradientHandler`
+ An object of :class:`colossalai.legacy.engine.BaseGradientHandler`
"""
config_ = config.copy()
config_['model'] = model
diff --git a/colossalai/communication/__init__.py b/colossalai/legacy/communication/__init__.py
similarity index 53%
rename from colossalai/communication/__init__.py
rename to colossalai/legacy/communication/__init__.py
index 220481b7af15..88ad0487b785 100644
--- a/colossalai/communication/__init__.py
+++ b/colossalai/legacy/communication/__init__.py
@@ -1,9 +1,17 @@
-from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce
-from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward,
- send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
- recv_forward, recv_backward)
+from .collective import all_gather, all_reduce, broadcast, reduce, reduce_scatter
+from .p2p import (
+ recv_backward,
+ recv_forward,
+ send_backward,
+ send_backward_recv_backward,
+ send_backward_recv_forward,
+ send_forward,
+ send_forward_backward_recv_forward_backward,
+ send_forward_recv_backward,
+ send_forward_recv_forward,
+)
from .ring import ring_forward
-from .utils import send_obj_meta, recv_obj_meta
+from .utils import recv_obj_meta, send_obj_meta
__all__ = [
'all_gather',
diff --git a/colossalai/communication/collective.py b/colossalai/legacy/communication/collective.py
similarity index 100%
rename from colossalai/communication/collective.py
rename to colossalai/legacy/communication/collective.py
diff --git a/colossalai/communication/p2p.py b/colossalai/legacy/communication/p2p.py
similarity index 100%
rename from colossalai/communication/p2p.py
rename to colossalai/legacy/communication/p2p.py
diff --git a/colossalai/communication/p2p_v2.py b/colossalai/legacy/communication/p2p_v2.py
similarity index 100%
rename from colossalai/communication/p2p_v2.py
rename to colossalai/legacy/communication/p2p_v2.py
diff --git a/colossalai/communication/ring.py b/colossalai/legacy/communication/ring.py
similarity index 100%
rename from colossalai/communication/ring.py
rename to colossalai/legacy/communication/ring.py
diff --git a/colossalai/communication/utils.py b/colossalai/legacy/communication/utils.py
similarity index 100%
rename from colossalai/communication/utils.py
rename to colossalai/legacy/communication/utils.py
diff --git a/colossalai/engine/__init__.py b/colossalai/legacy/engine/__init__.py
similarity index 100%
rename from colossalai/engine/__init__.py
rename to colossalai/legacy/engine/__init__.py
diff --git a/colossalai/engine/_base_engine.py b/colossalai/legacy/engine/_base_engine.py
similarity index 97%
rename from colossalai/engine/_base_engine.py
rename to colossalai/legacy/engine/_base_engine.py
index db27ad0e8abe..9af4469f403f 100644
--- a/colossalai/engine/_base_engine.py
+++ b/colossalai/legacy/engine/_base_engine.py
@@ -8,11 +8,17 @@
from torch.nn import Module
from torch.nn.modules.loss import _Loss
-from colossalai.engine.gradient_handler import BaseGradientHandler
-from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
+from colossalai.legacy.engine.gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.schedule import (
+ BaseSchedule,
+ InterleavedPipelineSchedule,
+ NonPipelineSchedule,
+ PipelineSchedule,
+)
from colossalai.logging import get_dist_logger
-from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
+
class Engine:
"""Basic engine class for training and evaluation. It runs a specific process method
diff --git a/colossalai/engine/gradient_accumulation/__init__.py b/colossalai/legacy/engine/gradient_accumulation/__init__.py
similarity index 94%
rename from colossalai/engine/gradient_accumulation/__init__.py
rename to colossalai/legacy/engine/gradient_accumulation/__init__.py
index 4cb6f4ad7384..670c26d06e55 100644
--- a/colossalai/engine/gradient_accumulation/__init__.py
+++ b/colossalai/legacy/engine/gradient_accumulation/__init__.py
@@ -4,7 +4,7 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler
from ._gradient_accumulation import (
GradAccumDataloader,
@@ -33,7 +33,7 @@ def accumulate_gradient(model: nn.Module,
dataloader (:class:`torch.utils.data.DataLoader` or iterable objects):
your dataloader object, would be called like iter(dataloader)
accumulate_size (int): the number of steps to accumulate gradients
- gradient_handlers (List[:class:`colossalai.engine.BaseGradientHandler`]):
+ gradient_handlers (List[:class:`colossalai.legacy.engine.BaseGradientHandler`]):
list of gradient handler objects. Default is None.
lr_scheduler (`torch.optim.lr_scheduler` or `colossalai.nn.lr_scheduler`):
your ``lr_scheduler`` object for gradient accumulation. Defaults to None.
diff --git a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py
similarity index 98%
rename from colossalai/engine/gradient_accumulation/_gradient_accumulation.py
rename to colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py
index cf66be1cd821..c466f7e2d03b 100644
--- a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py
+++ b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py
@@ -10,7 +10,7 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
-from colossalai.engine import BaseGradientHandler
+from colossalai.legacy.engine import BaseGradientHandler
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import conditional_context
@@ -262,7 +262,7 @@ class GradAccumGradientHandler:
before accumulation size is reached.
Args:
- grad_handler (:class:`colossalai.engine.BaseGradientHandler`):
+ grad_handler (:class:`colossalai.legacy.engine.BaseGradientHandler`):
Your ``gradient_handler`` object for gradient accumulation, would be called when achieving `accumulate_size`.
accumulate_size (int): The number of steps to accumulate gradients.
diff --git a/colossalai/engine/gradient_handler/__init__.py b/colossalai/legacy/engine/gradient_handler/__init__.py
similarity index 100%
rename from colossalai/engine/gradient_handler/__init__.py
rename to colossalai/legacy/engine/gradient_handler/__init__.py
diff --git a/colossalai/engine/gradient_handler/_base_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py
similarity index 100%
rename from colossalai/engine/gradient_handler/_base_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py
diff --git a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
similarity index 90%
rename from colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
index 5cc7169c5a9f..c5da2e55a0ed 100644
--- a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py
@@ -1,7 +1,7 @@
+from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
-from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
similarity index 94%
rename from colossalai/engine/gradient_handler/_moe_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
index b499345d4e18..395d83da0478 100644
--- a/colossalai/engine/gradient_handler/_moe_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py
@@ -1,9 +1,9 @@
from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.utils.moe import get_moe_epsize_param_dict
-from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
diff --git a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
similarity index 97%
rename from colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
index 5b49a9c0360d..7d4d9d73afc8 100644
--- a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
@@ -7,7 +7,7 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
diff --git a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
similarity index 90%
rename from colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
index ea4f0fbb1c71..41098ab39d0c 100644
--- a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py
@@ -1,7 +1,7 @@
+from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
-from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
diff --git a/colossalai/engine/gradient_handler/_zero_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py
similarity index 92%
rename from colossalai/engine/gradient_handler/_zero_gradient_handler.py
rename to colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py
index 19fd1e97f86f..4ca7cd0b0702 100644
--- a/colossalai/engine/gradient_handler/_zero_gradient_handler.py
+++ b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py
@@ -1,4 +1,4 @@
-from colossalai.registry import GRADIENT_HANDLER
+from colossalai.legacy.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
diff --git a/colossalai/engine/gradient_handler/utils.py b/colossalai/legacy/engine/gradient_handler/utils.py
similarity index 100%
rename from colossalai/engine/gradient_handler/utils.py
rename to colossalai/legacy/engine/gradient_handler/utils.py
diff --git a/colossalai/engine/schedule/__init__.py b/colossalai/legacy/engine/schedule/__init__.py
similarity index 100%
rename from colossalai/engine/schedule/__init__.py
rename to colossalai/legacy/engine/schedule/__init__.py
diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/legacy/engine/schedule/_base_schedule.py
similarity index 98%
rename from colossalai/engine/schedule/_base_schedule.py
rename to colossalai/legacy/engine/schedule/_base_schedule.py
index a2d50041127a..7505a3eb20e3 100644
--- a/colossalai/engine/schedule/_base_schedule.py
+++ b/colossalai/legacy/engine/schedule/_base_schedule.py
@@ -95,7 +95,7 @@ def forward_backward_step(self,
"""The process function over a batch of dataset for training or evaluation.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
forward_only (bool): If True, the process won't include backward.
return_loss (bool, optional): If False, the loss won't be returned.
diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py
similarity index 97%
rename from colossalai/engine/schedule/_non_pipeline_schedule.py
rename to colossalai/legacy/engine/schedule/_non_pipeline_schedule.py
index b9239d928a7b..b67893c1a0bb 100644
--- a/colossalai/engine/schedule/_non_pipeline_schedule.py
+++ b/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py
@@ -54,7 +54,7 @@ def forward_backward_step(self,
The returned labels and loss will None if :attr:`return_loss` is False.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
If True, the model is run for the forward pass, else back propagation will be executed.
diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
similarity index 98%
rename from colossalai/engine/schedule/_pipeline_schedule.py
rename to colossalai/legacy/engine/schedule/_pipeline_schedule.py
index 9fc301a26559..4571fd679e8c 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
@@ -6,7 +6,7 @@
import torch.cuda
-import colossalai.communication as comm
+import colossalai.legacy.communication as comm
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@@ -236,7 +236,7 @@ def _forward_step(self, engine, input_obj, return_tensors, return_output_label=T
Returns output tensor. This is a helper function and can be ignored by users.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
return_output_label (bool, optional): Whether returns output labels.
@@ -274,7 +274,7 @@ def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
This is a helper function and can be ignored by users.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage.
output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage.
output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage.
@@ -314,7 +314,7 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo
Returns a tuple with losses if the last stage, an empty tuple otherwise.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
Whether run forward step only. Default is false. If true, no backward will be run.
@@ -518,7 +518,7 @@ def _forward_step(self,
Returns output tensor. This is a helper function and can be ignored by users.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
model_chunk_id (int): The id of model chunks.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
@@ -555,7 +555,7 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo
communication between pipeline stages as needed.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
Whether run forward step only. Default is false. If true, no backward will be run.
diff --git a/colossalai/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
similarity index 96%
rename from colossalai/engine/schedule/_pipeline_schedule_v2.py
rename to colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
index 89e45c7aacec..385c615372f5 100644
--- a/colossalai/engine/schedule/_pipeline_schedule_v2.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
@@ -5,10 +5,10 @@
import torch.cuda
-import colossalai.communication.p2p_v2 as comm
-from colossalai import engine
+import colossalai.legacy.communication.p2p_v2 as comm
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.engine import Engine
from colossalai.utils.cuda import get_current_device
from ._pipeline_schedule import PipelineSchedule
@@ -60,7 +60,7 @@ def data_process_func(stage_output, dataloader_output):
"""
def forward_backward_step(self,
- engine: engine.Engine,
+ engine: Engine,
data_iter: Iterable,
forward_only=False,
return_loss=True,
@@ -69,7 +69,7 @@ def forward_backward_step(self,
Returns a tuple with losses if the last stage, an empty tuple otherwise.
Args:
- engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+ engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
Whether run forward step only. Default is false. If true, no backward will be run.
diff --git a/colossalai/legacy/nn/__init__.py b/colossalai/legacy/nn/__init__.py
new file mode 100644
index 000000000000..500162901905
--- /dev/null
+++ b/colossalai/legacy/nn/__init__.py
@@ -0,0 +1,4 @@
+from ._ops import *
+from .layer import *
+from .loss import *
+from .metric import *
diff --git a/colossalai/nn/_ops/__init__.py b/colossalai/legacy/nn/_ops/__init__.py
similarity index 100%
rename from colossalai/nn/_ops/__init__.py
rename to colossalai/legacy/nn/_ops/__init__.py
diff --git a/colossalai/nn/_ops/_utils.py b/colossalai/legacy/nn/_ops/_utils.py
similarity index 99%
rename from colossalai/nn/_ops/_utils.py
rename to colossalai/legacy/nn/_ops/_utils.py
index 24877bbb552f..131c2154771b 100644
--- a/colossalai/nn/_ops/_utils.py
+++ b/colossalai/legacy/nn/_ops/_utils.py
@@ -4,7 +4,7 @@
import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.nn.layer.utils import divide
+from colossalai.legacy.nn.layer.utils import divide
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
GeneralTensor = Union[ColoTensor, torch.Tensor]
@@ -232,7 +232,7 @@ def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int):
return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim)
-### table wise embedding shard
+# table wise embedding shard
def _all_to_all_for_tablewise(x: torch.Tensor,
diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/legacy/nn/_ops/addmm.py
similarity index 100%
rename from colossalai/nn/_ops/addmm.py
rename to colossalai/legacy/nn/_ops/addmm.py
diff --git a/colossalai/nn/_ops/batch_norm.py b/colossalai/legacy/nn/_ops/batch_norm.py
similarity index 100%
rename from colossalai/nn/_ops/batch_norm.py
rename to colossalai/legacy/nn/_ops/batch_norm.py
diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/legacy/nn/_ops/element_wise.py
similarity index 100%
rename from colossalai/nn/_ops/element_wise.py
rename to colossalai/legacy/nn/_ops/element_wise.py
diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/legacy/nn/_ops/embedding.py
similarity index 98%
rename from colossalai/nn/_ops/embedding.py
rename to colossalai/legacy/nn/_ops/embedding.py
index a045f305b5dc..b145d1763380 100644
--- a/colossalai/nn/_ops/embedding.py
+++ b/colossalai/legacy/nn/_ops/embedding.py
@@ -1,8 +1,10 @@
-import torch.nn.functional as F
from typing import Optional
+
+import torch.nn.functional as F
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec
from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, \
- ReplicaSpec
+
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/legacy/nn/_ops/embedding_bag.py
similarity index 97%
rename from colossalai/nn/_ops/embedding_bag.py
rename to colossalai/legacy/nn/_ops/embedding_bag.py
index 0026f579b6dc..9a656d5871a3 100644
--- a/colossalai/nn/_ops/embedding_bag.py
+++ b/colossalai/legacy/nn/_ops/embedding_bag.py
@@ -1,9 +1,11 @@
-import torch.nn.functional as F
from typing import Optional
+
+import torch.nn.functional as F
from torch import Tensor
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, \
- ShardSpec, ReplicaSpec
+
from ._utils import GeneralTensor, convert_to_colo_tensor
diff --git a/colossalai/nn/_ops/layernorm.py b/colossalai/legacy/nn/_ops/layernorm.py
similarity index 92%
rename from colossalai/nn/_ops/layernorm.py
rename to colossalai/legacy/nn/_ops/layernorm.py
index 2b761b84e3ee..9960c5d48096 100644
--- a/colossalai/nn/_ops/layernorm.py
+++ b/colossalai/legacy/nn/_ops/layernorm.py
@@ -1,7 +1,10 @@
from typing import List, Optional
+
import torch.nn.functional as F
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec, distspec
from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec
+
from ._utils import GeneralTensor, convert_to_colo_tensor
diff --git a/colossalai/nn/_ops/linear.py b/colossalai/legacy/nn/_ops/linear.py
similarity index 100%
rename from colossalai/nn/_ops/linear.py
rename to colossalai/legacy/nn/_ops/linear.py
diff --git a/colossalai/nn/_ops/loss.py b/colossalai/legacy/nn/_ops/loss.py
similarity index 96%
rename from colossalai/nn/_ops/loss.py
rename to colossalai/legacy/nn/_ops/loss.py
index 1e54f662859c..90efbfa36f2a 100644
--- a/colossalai/nn/_ops/loss.py
+++ b/colossalai/legacy/nn/_ops/loss.py
@@ -1,9 +1,12 @@
+from typing import Optional
+
import torch
import torch.nn.functional as F
-from typing import Optional
-from colossalai.tensor.op_wrapper import colo_op_impl
+
+from colossalai.legacy.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
from colossalai.tensor import ColoTensor, ColoTensorSpec
-from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
+from colossalai.tensor.op_wrapper import colo_op_impl
+
from ._utils import GeneralTensor, convert_to_colo_tensor
diff --git a/colossalai/nn/_ops/view.py b/colossalai/legacy/nn/_ops/view.py
similarity index 100%
rename from colossalai/nn/_ops/view.py
rename to colossalai/legacy/nn/_ops/view.py
diff --git a/colossalai/legacy/nn/layer/__init__.py b/colossalai/legacy/nn/layer/__init__.py
new file mode 100644
index 000000000000..86961dd933a7
--- /dev/null
+++ b/colossalai/legacy/nn/layer/__init__.py
@@ -0,0 +1,9 @@
+from .colossalai_layer import *
+from .parallel_1d import *
+from .parallel_2d import *
+from .parallel_2p5d import *
+from .parallel_3d import *
+from .parallel_sequence import *
+from .utils import *
+from .vanilla import *
+from .wrapper import *
diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/legacy/nn/layer/base_layer.py
similarity index 100%
rename from colossalai/nn/layer/base_layer.py
rename to colossalai/legacy/nn/layer/base_layer.py
diff --git a/colossalai/nn/layer/colossalai_layer/__init__.py b/colossalai/legacy/nn/layer/colossalai_layer/__init__.py
similarity index 97%
rename from colossalai/nn/layer/colossalai_layer/__init__.py
rename to colossalai/legacy/nn/layer/colossalai_layer/__init__.py
index 2ae1b07a75b2..ed743820ddbc 100644
--- a/colossalai/nn/layer/colossalai_layer/__init__.py
+++ b/colossalai/legacy/nn/layer/colossalai_layer/__init__.py
@@ -1,7 +1,7 @@
-from ._utils import partition_batch
-from .dropout import Dropout
-from .embedding import Embedding, PatchEmbedding
-from .linear import Classifier, Linear
-from .normalization import LayerNorm
-
-__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch']
+from ._utils import partition_batch
+from .dropout import Dropout
+from .embedding import Embedding, PatchEmbedding
+from .linear import Classifier, Linear
+from .normalization import LayerNorm
+
+__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch']
diff --git a/colossalai/nn/layer/colossalai_layer/_utils.py b/colossalai/legacy/nn/layer/colossalai_layer/_utils.py
similarity index 100%
rename from colossalai/nn/layer/colossalai_layer/_utils.py
rename to colossalai/legacy/nn/layer/colossalai_layer/_utils.py
diff --git a/colossalai/nn/layer/colossalai_layer/dropout.py b/colossalai/legacy/nn/layer/colossalai_layer/dropout.py
similarity index 100%
rename from colossalai/nn/layer/colossalai_layer/dropout.py
rename to colossalai/legacy/nn/layer/colossalai_layer/dropout.py
diff --git a/colossalai/nn/layer/colossalai_layer/embedding.py b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py
similarity index 97%
rename from colossalai/nn/layer/colossalai_layer/embedding.py
rename to colossalai/legacy/nn/layer/colossalai_layer/embedding.py
index e5c9c46e0ff1..28bcb7ffefb0 100644
--- a/colossalai/nn/layer/colossalai_layer/embedding.py
+++ b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py
@@ -1,151 +1,152 @@
-import math
-from typing import Callable
-
-from colossalai.utils import get_current_device
-from torch import dtype, nn
-
-from ... import init as init
-from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D
-from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D
-from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D
-from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D
-from ..utils import get_tensor_parallel_mode
-from ..vanilla import VanillaPatchEmbedding
-from ._utils import ColossalaiModule
-
-_parallel_embedding = {
- '1d': Embedding1D,
- '2d': Embedding2D,
- '2.5d': Embedding2p5D,
- '3d': Embedding3D,
-}
-
-_vocab_parallel_embedding = {
- '1d': VocabParallelEmbedding1D,
- '2d': VocabParallelEmbedding2D,
- '2.5d': VocabParallelEmbedding2p5D,
- '3d': VocabParallelEmbedding3D
-}
-
-_parallel_patchembedding = {
- None: VanillaPatchEmbedding,
- '1d': PatchEmbedding1D,
- '2d': PatchEmbedding2D,
- '2.5d': PatchEmbedding2p5D,
- '3d': PatchEmbedding3D
-}
-
-
-class Embedding(ColossalaiModule):
- r"""Embedding for colossalai.
-
- Args:
- num_embeddings (int): number of embeddings.
- embedding_dim (int): dimension of embedding.
- padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
- therefore, the embedding vector at padding_idx is not updated during training,
- i.e. it remains as a fixed “pad”, defaults to None.
- dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
- weight_initializer (:class:`typing.Callable`, optional):
- he initializer of weight, defaults to normal initializer.
-
- The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain:
- ::
-
- max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
- renormalized to have norm max_norm. Note: this will modify weight in-place.
- norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
- scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
- of frequency of the words in the mini-batch. Default False.
- sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
-
- More details about ``args`` and ``kwargs`` could be found in
- `Embedding `_.
-
- More details about ``initializer`` please refer to
- `init `_
- """
-
- def __init__(self,
- num_embeddings: int,
- embedding_dim: int,
- padding_idx: int = None,
- dtype: dtype = None,
- weight_initializer: Callable = init.normal_(),
- vocab_parallel_limit: int = 2048,
- *args,
- **kwargs) -> None:
- tensor_parallel = get_tensor_parallel_mode()
- if tensor_parallel is None:
- embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args,
- **kwargs).to(dtype).to(get_current_device())
- weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
- elif num_embeddings <= vocab_parallel_limit:
- embed = _parallel_embedding[tensor_parallel](
- num_embeddings,
- embedding_dim,
- padding_idx=padding_idx,
- dtype=dtype,
- weight_initializer=weight_initializer,
- *args,
- **kwargs,
- )
- else:
- embed = _vocab_parallel_embedding[tensor_parallel](
- num_embeddings,
- embedding_dim,
- padding_idx=padding_idx,
- dtype=dtype,
- weight_initializer=weight_initializer,
- *args,
- **kwargs,
- )
- super().__init__(embed)
-
-
-class PatchEmbedding(ColossalaiModule):
- """2D Image to Patch Embedding.
-
- Args:
- img_size (int): image size.
- patch_size (int): patch size.
- in_chans (int): number of channels of input image.
- embed_size (int): size of embedding.
- dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
- flatten (bool, optional): whether to flatten output tensor, defaults to True.
- weight_initializer (:class:`typing.Callable`, optional):
- The initializer of weight, defaults to kaiming uniform initializer.
- bias_initializer (:class:`typing.Callable`, optional):
- The initializer of bias, defaults to xavier uniform initializer.
- position_embed_initializer (:class:`typing.Callable`, optional):
- The initializer of position embedding, defaults to zeros initializer.
-
- More details about ``initializer`` please refer to
- `init `_.
- """
-
- def __init__(
- self,
- img_size: int,
- patch_size: int,
- in_chans: int,
- embed_size: int,
- dtype: dtype = None,
- flatten: bool = True,
- weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
- bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
- position_embed_initializer: Callable = init.zeros_()
- ) -> None:
- tensor_parallel = get_tensor_parallel_mode()
- embed = _parallel_patchembedding[tensor_parallel](
- img_size,
- patch_size,
- in_chans,
- embed_size,
- dtype=dtype,
- flatten=flatten,
- weight_initializer=weight_initializer,
- bias_initializer=bias_initializer,
- position_embed_initializer=position_embed_initializer,
- )
- super().__init__(embed)
+import math
+from typing import Callable
+
+from torch import dtype, nn
+
+from colossalai.nn import init
+from colossalai.utils import get_current_device
+
+from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D
+from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D
+from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D
+from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D
+from ..utils import get_tensor_parallel_mode
+from ..vanilla import VanillaPatchEmbedding
+from ._utils import ColossalaiModule
+
+_parallel_embedding = {
+ '1d': Embedding1D,
+ '2d': Embedding2D,
+ '2.5d': Embedding2p5D,
+ '3d': Embedding3D,
+}
+
+_vocab_parallel_embedding = {
+ '1d': VocabParallelEmbedding1D,
+ '2d': VocabParallelEmbedding2D,
+ '2.5d': VocabParallelEmbedding2p5D,
+ '3d': VocabParallelEmbedding3D
+}
+
+_parallel_patchembedding = {
+ None: VanillaPatchEmbedding,
+ '1d': PatchEmbedding1D,
+ '2d': PatchEmbedding2D,
+ '2.5d': PatchEmbedding2p5D,
+ '3d': PatchEmbedding3D
+}
+
+
+class Embedding(ColossalaiModule):
+ r"""Embedding for colossalai.
+
+ Args:
+ num_embeddings (int): number of embeddings.
+ embedding_dim (int): dimension of embedding.
+ padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
+ therefore, the embedding vector at padding_idx is not updated during training,
+ i.e. it remains as a fixed “pad”, defaults to None.
+ dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+ weight_initializer (:class:`typing.Callable`, optional):
+ The initializer of weight, defaults to normal initializer.
+
+ The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain:
+ ::
+
+ max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
+ renormalized to have norm max_norm. Note: this will modify weight in-place.
+ norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
+ scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
+ of frequency of the words in the mini-batch. Default False.
+ sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False.
+
+ More details about ``args`` and ``kwargs`` could be found in
+ `Embedding `_.
+
+ More details about ``initializer`` please refer to
+ `init `_
+ """
+
+ def __init__(self,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: int = None,
+ dtype: dtype = None,
+ weight_initializer: Callable = init.normal_(),
+ vocab_parallel_limit: int = 2048,
+ *args,
+ **kwargs) -> None:
+ tensor_parallel = get_tensor_parallel_mode()
+ if tensor_parallel is None:
+ embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args,
+ **kwargs).to(dtype).to(get_current_device())
+ weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
+ elif num_embeddings <= vocab_parallel_limit:
+ embed = _parallel_embedding[tensor_parallel](
+ num_embeddings,
+ embedding_dim,
+ padding_idx=padding_idx,
+ dtype=dtype,
+ weight_initializer=weight_initializer,
+ *args,
+ **kwargs,
+ )
+ else:
+ embed = _vocab_parallel_embedding[tensor_parallel](
+ num_embeddings,
+ embedding_dim,
+ padding_idx=padding_idx,
+ dtype=dtype,
+ weight_initializer=weight_initializer,
+ *args,
+ **kwargs,
+ )
+ super().__init__(embed)
+
+
+class PatchEmbedding(ColossalaiModule):
+ """2D Image to Patch Embedding.
+
+ Args:
+ img_size (int): image size.
+ patch_size (int): patch size.
+ in_chans (int): number of channels of input image.
+ embed_size (int): size of embedding.
+ dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+ flatten (bool, optional): whether to flatten output tensor, defaults to True.
+ weight_initializer (:class:`typing.Callable`, optional):
+ The initializer of weight, defaults to kaiming uniform initializer.
+ bias_initializer (:class:`typing.Callable`, optional):
+ The initializer of bias, defaults to xavier uniform initializer.
+ position_embed_initializer (:class:`typing.Callable`, optional):
+ The initializer of position embedding, defaults to zeros initializer.
+
+ More details about ``initializer`` please refer to
+ `init `_.
+ """
+
+ def __init__(
+ self,
+ img_size: int,
+ patch_size: int,
+ in_chans: int,
+ embed_size: int,
+ dtype: dtype = None,
+ flatten: bool = True,
+ weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+ bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+ position_embed_initializer: Callable = init.zeros_()
+ ) -> None:
+ tensor_parallel = get_tensor_parallel_mode()
+ embed = _parallel_patchembedding[tensor_parallel](
+ img_size,
+ patch_size,
+ in_chans,
+ embed_size,
+ dtype=dtype,
+ flatten=flatten,
+ weight_initializer=weight_initializer,
+ bias_initializer=bias_initializer,
+ position_embed_initializer=position_embed_initializer,
+ )
+ super().__init__(embed)
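For reference, the dispatch logic of the reformatted `Embedding` wrapper above: with no tensor-parallel mode configured it wraps a plain `nn.Embedding`; otherwise it picks the `_parallel_embedding` variant while `num_embeddings` stays at or below `vocab_parallel_limit` and the `_vocab_parallel_embedding` variant beyond that. A hedged usage sketch, where the launch configuration and the post-move import path are assumptions:

```python
# Assumes colossalai.launch(...) has already been called with a tensor-parallel
# config, e.g. parallel=dict(tensor=dict(mode='1d', size=2)); without one the
# wrapper silently falls back to torch.nn.Embedding.
from colossalai.legacy.nn import Embedding  # assumed import path after the move to colossalai.legacy

small_vocab = Embedding(1000, 512)    # num_embeddings <= vocab_parallel_limit (2048) -> Embedding1D/2D/2p5D/3D
large_vocab = Embedding(50304, 512)   # num_embeddings >  vocab_parallel_limit        -> VocabParallelEmbedding1D/2D/2p5D/3D
```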
diff --git a/colossalai/nn/layer/colossalai_layer/linear.py b/colossalai/legacy/nn/layer/colossalai_layer/linear.py
similarity index 99%
rename from colossalai/nn/layer/colossalai_layer/linear.py
rename to colossalai/legacy/nn/layer/colossalai_layer/linear.py
index 3e0c6e285c1c..c05ceb66ce25 100644
--- a/colossalai/nn/layer/colossalai_layer/linear.py
+++ b/colossalai/legacy/nn/layer/colossalai_layer/linear.py
@@ -4,9 +4,9 @@
from torch import dtype, nn
+from colossalai.nn import init
from colossalai.utils import get_current_device
-from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
diff --git a/colossalai/nn/layer/colossalai_layer/normalization.py b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py
similarity index 97%
rename from colossalai/nn/layer/colossalai_layer/normalization.py
rename to colossalai/legacy/nn/layer/colossalai_layer/normalization.py
index 86861d30214a..f8e317e723f1 100644
--- a/colossalai/nn/layer/colossalai_layer/normalization.py
+++ b/colossalai/legacy/nn/layer/colossalai_layer/normalization.py
@@ -1,41 +1,42 @@
-from colossalai.utils import get_current_device
-from torch import nn
-
-from ..parallel_1d import LayerNorm1D
-from ..parallel_2d import LayerNorm2D
-from ..parallel_2p5d import LayerNorm2p5D
-from ..parallel_3d import LayerNorm3D
-from ..utils import get_tensor_parallel_mode
-from ..vanilla import VanillaLayerNorm
-from ._utils import ColossalaiModule
-
-_parallel_layernorm = {
- None: VanillaLayerNorm,
- "1d": LayerNorm1D,
- "2d": LayerNorm2D,
- "2.5d": LayerNorm2p5D,
- "3d": LayerNorm3D,
-}
-
-
-class LayerNorm(ColossalaiModule):
- r"""Layer Normalization for colossalai.
-
- Args:
- normalized_shape (int): input shape from an expected input of size.
- :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
- \times \ldots \times \text{normalized_shape}[-1]]`
- If a single integer is used, it is treated as a singleton list, and this module will
- normalize over the last dimension which is expected to be of that specific size.
- eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
- bias (bool, optional): Whether to add a bias, defaults to ``True``.
- dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
- """
-
- def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
- tensor_parallel = get_tensor_parallel_mode()
- if tensor_parallel is None:
- norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
- else:
- norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
- super().__init__(norm)
+from torch import nn
+
+from colossalai.utils import get_current_device
+
+from ..parallel_1d import LayerNorm1D
+from ..parallel_2d import LayerNorm2D
+from ..parallel_2p5d import LayerNorm2p5D
+from ..parallel_3d import LayerNorm3D
+from ..utils import get_tensor_parallel_mode
+from ..vanilla import VanillaLayerNorm
+from ._utils import ColossalaiModule
+
+_parallel_layernorm = {
+ None: VanillaLayerNorm,
+ "1d": LayerNorm1D,
+ "2d": LayerNorm2D,
+ "2.5d": LayerNorm2p5D,
+ "3d": LayerNorm3D,
+}
+
+
+class LayerNorm(ColossalaiModule):
+ r"""Layer Normalization for colossalai.
+
+ Args:
+ normalized_shape (int): input shape from an expected input of size.
+ :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+ \times \ldots \times \text{normalized_shape}[-1]]`
+ If a single integer is used, it is treated as a singleton list, and this module will
+ normalize over the last dimension which is expected to be of that specific size.
+ eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+ bias (bool, optional): Whether to add a bias, defaults to ``True``.
+ dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+ """
+
+ def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
+ tensor_parallel = get_tensor_parallel_mode()
+ if tensor_parallel is None:
+ norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
+ else:
+ norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
+ super().__init__(norm)
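The reformatted `LayerNorm` wrapper follows the same pattern: with no tensor-parallel mode it builds a plain `torch.nn.LayerNorm`, otherwise it instantiates the matching `LayerNorm1D/2D/2p5D/3D`. A minimal sketch under the same assumptions as the embedding example above:

```python
from colossalai.legacy.nn import LayerNorm  # assumed import path after the move to colossalai.legacy

norm = LayerNorm(1024, eps=1e-5)  # resolves to nn.LayerNorm or LayerNorm{1D,2D,2p5D,3D} depending on the configured mode
```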
diff --git a/colossalai/legacy/nn/layer/parallel_1d/__init__.py b/colossalai/legacy/nn/layer/parallel_1d/__init__.py
new file mode 100644
index 000000000000..9cffd4d339f5
--- /dev/null
+++ b/colossalai/legacy/nn/layer/parallel_1d/__init__.py
@@ -0,0 +1,17 @@
+from .layers import (
+ Classifier1D,
+ Dropout1D,
+ Embedding1D,
+ LayerNorm1D,
+ Linear1D,
+ Linear1D_Col,
+ Linear1D_Row,
+ PatchEmbedding1D,
+ VocabParallelClassifier1D,
+ VocabParallelEmbedding1D,
+)
+
+__all__ = [
+ 'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D',
+ 'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D'
+]
diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py
similarity index 100%
rename from colossalai/nn/layer/parallel_1d/_operation.py
rename to colossalai/legacy/nn/layer/parallel_1d/_operation.py
diff --git a/colossalai/nn/layer/parallel_1d/_utils.py b/colossalai/legacy/nn/layer/parallel_1d/_utils.py
similarity index 99%
rename from colossalai/nn/layer/parallel_1d/_utils.py
rename to colossalai/legacy/nn/layer/parallel_1d/_utils.py
index 1212d595635d..fddf4e73db51 100644
--- a/colossalai/nn/layer/parallel_1d/_utils.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/_utils.py
@@ -3,6 +3,7 @@
import torch
import torch.distributed as dist
+
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
@@ -124,7 +125,7 @@ def backward(ctx, grad_output):
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the corresponding chuck to the rank.
-
+
Args:
input_: input matrix.
parallel_mode: parallel mode.
diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py
similarity index 99%
rename from colossalai/nn/layer/parallel_1d/layers.py
rename to colossalai/legacy/nn/layer/parallel_1d/layers.py
index 406173a18c60..c0a169c1596f 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py
@@ -10,13 +10,13 @@
from torch import Tensor
from torch.nn.parameter import Parameter
-from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm
+from colossalai.legacy.communication import broadcast
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (
broadcast_state_dict,
gather_tensor_parallel_state_dict,
diff --git a/colossalai/nn/layer/parallel_2d/__init__.py b/colossalai/legacy/nn/layer/parallel_2d/__init__.py
similarity index 59%
rename from colossalai/nn/layer/parallel_2d/__init__.py
rename to colossalai/legacy/nn/layer/parallel_2d/__init__.py
index 5562d1a70036..9c65f3608710 100644
--- a/colossalai/nn/layer/parallel_2d/__init__.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/__init__.py
@@ -1,6 +1,13 @@
from ._operation import reduce_by_batch_2d, split_batch_2d
-from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D,
- VocabParallelEmbedding2D)
+from .layers import (
+ Classifier2D,
+ Embedding2D,
+ LayerNorm2D,
+ Linear2D,
+ PatchEmbedding2D,
+ VocabParallelClassifier2D,
+ VocabParallelEmbedding2D,
+)
__all__ = [
'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D',
diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/legacy/nn/layer/parallel_2d/_operation.py
similarity index 98%
rename from colossalai/nn/layer/parallel_2d/_operation.py
rename to colossalai/legacy/nn/layer/parallel_2d/_operation.py
index 306577dbd933..fa9b49bcf53f 100644
--- a/colossalai/nn/layer/parallel_2d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/_operation.py
@@ -2,13 +2,14 @@
import torch
import torch.distributed as dist
-from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter)
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
+
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce, reduce_scatter
+from colossalai.utils import get_current_device
def matmul_2d(
@@ -226,9 +227,9 @@ def forward(
col_group = gpc.get_group(col_parallel_mode)
src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
opa = [None] * 2
opb = [None] * 2
@@ -351,9 +352,9 @@ def forward(
col_group = gpc.get_group(col_parallel_mode)
src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
opb = [None] * 2
opr = [None] * 2
@@ -484,9 +485,9 @@ def forward(
col_group = gpc.get_group(col_parallel_mode)
src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
- pipeline_parallel_rank * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
opa = [None] * 2
opr = [None] * 2
diff --git a/colossalai/nn/layer/parallel_2d/_utils.py b/colossalai/legacy/nn/layer/parallel_2d/_utils.py
similarity index 100%
rename from colossalai/nn/layer/parallel_2d/_utils.py
rename to colossalai/legacy/nn/layer/parallel_2d/_utils.py
diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py
similarity index 99%
rename from colossalai/nn/layer/parallel_2d/layers.py
rename to colossalai/legacy/nn/layer/parallel_2d/layers.py
index f3a4d2bbbc32..b458d15c78e7 100644
--- a/colossalai/nn/layer/parallel_2d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py
@@ -5,21 +5,30 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from colossalai.communication import broadcast
+from torch import Tensor
+from torch.nn import Parameter
+
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.communication import broadcast
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict
from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d,
- reduce_scatter_tensor_2d, split_batch_2d)
+from ._operation import (
+ Matmul_AB_2D,
+ Matmul_ABT_2D,
+ add_bias_2d,
+ all_gather_tensor_2d,
+ classifier_2d,
+ layernorm_2d,
+ reduce_scatter_tensor_2d,
+ split_batch_2d,
+)
from ._utils import assert_summa_initialization, get_summa_dim_from_env
diff --git a/colossalai/nn/layer/parallel_2p5d/__init__.py b/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py
similarity index 59%
rename from colossalai/nn/layer/parallel_2p5d/__init__.py
rename to colossalai/legacy/nn/layer/parallel_2p5d/__init__.py
index bec3b1c4b0b8..23e47e6ed06b 100644
--- a/colossalai/nn/layer/parallel_2p5d/__init__.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py
@@ -1,6 +1,13 @@
from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
-from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D,
- VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D)
+from .layers import (
+ Classifier2p5D,
+ Embedding2p5D,
+ LayerNorm2p5D,
+ Linear2p5D,
+ PatchEmbedding2p5D,
+ VocabParallelClassifier2p5D,
+ VocabParallelEmbedding2p5D,
+)
__all__ = [
'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
similarity index 99%
rename from colossalai/nn/layer/parallel_2p5d/_operation.py
rename to colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
index 5a0f537cd6d9..55defa4a328d 100644
--- a/colossalai/nn/layer/parallel_2p5d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py
@@ -2,12 +2,13 @@
import torch
import torch.distributed as dist
-from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
from colossalai.utils import get_current_device
-from torch import Tensor
-from torch.cuda.amp import custom_bwd, custom_fwd
def get_parallel_group(parallel_mode: ParallelMode):
diff --git a/colossalai/nn/layer/parallel_2p5d/_utils.py b/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py
similarity index 100%
rename from colossalai/nn/layer/parallel_2p5d/_utils.py
rename to colossalai/legacy/nn/layer/parallel_2p5d/_utils.py
diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
similarity index 99%
rename from colossalai/nn/layer/parallel_2p5d/layers.py
rename to colossalai/legacy/nn/layer/parallel_2p5d/layers.py
index f849cbbe7b0d..04acc2bb0f4c 100644
--- a/colossalai/nn/layer/parallel_2p5d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
@@ -5,22 +5,34 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from colossalai.communication import broadcast
+from torch import Tensor
+from torch.nn import Parameter
+
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.communication import broadcast
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
- partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+ broadcast_state_dict,
+ gather_tensor_parallel_state_dict,
+ partition_tensor_parallel_state_dict,
+)
from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d,
- layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d)
+from ._operation import (
+ Matmul_AB_2p5D,
+ Matmul_ABT_2p5D,
+ add_bias_2p5d,
+ all_gather_tensor_2p5d,
+ classifier_2p5d,
+ layernorm_2p5d,
+ reduce_scatter_tensor_2p5d,
+ split_batch_2p5d,
+)
from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env
diff --git a/colossalai/nn/layer/parallel_3d/__init__.py b/colossalai/legacy/nn/layer/parallel_3d/__init__.py
similarity index 62%
rename from colossalai/nn/layer/parallel_3d/__init__.py
rename to colossalai/legacy/nn/layer/parallel_3d/__init__.py
index 9ae255b449ee..17fe8403c585 100644
--- a/colossalai/nn/layer/parallel_3d/__init__.py
+++ b/colossalai/legacy/nn/layer/parallel_3d/__init__.py
@@ -1,6 +1,13 @@
from ._operation import reduce_by_batch_3d, split_batch_3d, split_tensor_3d
-from .layers import (Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VocabParallelClassifier3D,
- VocabParallelEmbedding3D)
+from .layers import (
+ Classifier3D,
+ Embedding3D,
+ LayerNorm3D,
+ Linear3D,
+ PatchEmbedding3D,
+ VocabParallelClassifier3D,
+ VocabParallelEmbedding3D,
+)
__all__ = [
'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D',
diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/legacy/nn/layer/parallel_3d/_operation.py
similarity index 99%
rename from colossalai/nn/layer/parallel_3d/_operation.py
rename to colossalai/legacy/nn/layer/parallel_3d/_operation.py
index 5dc9a242851f..ca0b0e62783a 100755
--- a/colossalai/nn/layer/parallel_3d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_3d/_operation.py
@@ -7,10 +7,10 @@
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
-from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
from ._utils import get_parallel_mode_from_env, push_async_grad
diff --git a/colossalai/nn/layer/parallel_3d/_utils.py b/colossalai/legacy/nn/layer/parallel_3d/_utils.py
similarity index 100%
rename from colossalai/nn/layer/parallel_3d/_utils.py
rename to colossalai/legacy/nn/layer/parallel_3d/_utils.py
diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py
similarity index 99%
rename from colossalai/nn/layer/parallel_3d/layers.py
rename to colossalai/legacy/nn/layer/parallel_3d/layers.py
index 99b0c3f8b7ec..b815a842ca52 100644
--- a/colossalai/nn/layer/parallel_3d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py
@@ -8,14 +8,14 @@
from torch import Tensor
from torch.nn import Parameter
-from colossalai.communication import all_reduce, broadcast
from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.communication import all_reduce, broadcast
+from colossalai.legacy.nn.layer.base_layer import ParallelLayer
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.nn.layer.base_layer import ParallelLayer
-from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (
broadcast_state_dict,
gather_tensor_parallel_state_dict,
diff --git a/colossalai/nn/layer/parallel_sequence/__init__.py b/colossalai/legacy/nn/layer/parallel_sequence/__init__.py
similarity index 74%
rename from colossalai/nn/layer/parallel_sequence/__init__.py
rename to colossalai/legacy/nn/layer/parallel_sequence/__init__.py
index 4fa9eed6f34b..d92d66d40a8e 100644
--- a/colossalai/nn/layer/parallel_sequence/__init__.py
+++ b/colossalai/legacy/nn/layer/parallel_sequence/__init__.py
@@ -1,4 +1,4 @@
-from ._operation import RingQK, RingAV
+from ._operation import RingAV, RingQK
from .layers import TransformerSelfAttentionRing
__all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK']
diff --git a/colossalai/nn/layer/parallel_sequence/_operation.py b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py
similarity index 97%
rename from colossalai/nn/layer/parallel_sequence/_operation.py
rename to colossalai/legacy/nn/layer/parallel_sequence/_operation.py
index fc80494224c6..fcf2962017a3 100644
--- a/colossalai/nn/layer/parallel_sequence/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py
@@ -3,13 +3,13 @@
import torch
from torch import distributed as dist
+from torch.cuda.amp import custom_bwd, custom_fwd

-from colossalai.communication import ring_forward
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range
+from colossalai.legacy.communication import ring_forward
+from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range
from colossalai.utils import get_current_device
-from torch.cuda.amp import custom_bwd, custom_fwd
class RingQK(torch.autograd.Function):
diff --git a/colossalai/nn/layer/parallel_sequence/_utils.py b/colossalai/legacy/nn/layer/parallel_sequence/_utils.py
similarity index 100%
rename from colossalai/nn/layer/parallel_sequence/_utils.py
rename to colossalai/legacy/nn/layer/parallel_sequence/_utils.py
diff --git a/colossalai/nn/layer/parallel_sequence/layers.py b/colossalai/legacy/nn/layer/parallel_sequence/layers.py
similarity index 98%
rename from colossalai/nn/layer/parallel_sequence/layers.py
rename to colossalai/legacy/nn/layer/parallel_sequence/layers.py
index 0887f8389dbe..e44e61c2fb7d 100644
--- a/colossalai/nn/layer/parallel_sequence/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_sequence/layers.py
@@ -2,20 +2,20 @@
# -*- encoding: utf-8 -*-
import math
-import colossalai
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
+import colossalai
+from colossalai.context import seed
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
-from colossalai.registry import LAYERS
-from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
from colossalai.kernel import FusedScaleMaskSoftmax
-from colossalai.context import seed
+from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK
+from colossalai.legacy.registry import LAYERS
@LAYERS.register_module
diff --git a/colossalai/legacy/nn/layer/utils/__init__.py b/colossalai/legacy/nn/layer/utils/__init__.py
new file mode 100644
index 000000000000..56e969bfd0bd
--- /dev/null
+++ b/colossalai/legacy/nn/layer/utils/__init__.py
@@ -0,0 +1,15 @@
+from .common import (
+ ACT2FN,
+ CheckpointModule,
+ _ntuple,
+ divide,
+ get_tensor_parallel_mode,
+ set_tensor_parallel_attribute_by_partition,
+ set_tensor_parallel_attribute_by_size,
+ to_2tuple,
+)
+
+__all__ = [
+ 'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
+ 'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
+]
diff --git a/colossalai/nn/layer/utils/common.py b/colossalai/legacy/nn/layer/utils/common.py
similarity index 99%
rename from colossalai/nn/layer/utils/common.py
rename to colossalai/legacy/nn/layer/utils/common.py
index f2297304fdc9..d8f3ad2a7eca 100644
--- a/colossalai/nn/layer/utils/common.py
+++ b/colossalai/legacy/nn/layer/utils/common.py
@@ -6,10 +6,11 @@
import numpy as np
import torch
+from torch import Tensor, nn
+
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.utils import checkpoint
-from torch import Tensor, nn
class CheckpointModule(nn.Module):
diff --git a/colossalai/nn/layer/vanilla/__init__.py b/colossalai/legacy/nn/layer/vanilla/__init__.py
similarity index 100%
rename from colossalai/nn/layer/vanilla/__init__.py
rename to colossalai/legacy/nn/layer/vanilla/__init__.py
diff --git a/colossalai/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py
similarity index 99%
rename from colossalai/nn/layer/vanilla/layers.py
rename to colossalai/legacy/nn/layer/vanilla/layers.py
index 225aed3916a6..0e11fc4d0dab 100644
--- a/colossalai/nn/layer/vanilla/layers.py
+++ b/colossalai/legacy/nn/layer/vanilla/layers.py
@@ -8,8 +8,8 @@
from torch.nn.parameter import Parameter
from colossalai.context import seed
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.registry import LAYERS
from colossalai.utils.cuda import get_current_device
from ..utils import to_2tuple
diff --git a/colossalai/nn/layer/wrapper/__init__.py b/colossalai/legacy/nn/layer/wrapper/__init__.py
similarity index 100%
rename from colossalai/nn/layer/wrapper/__init__.py
rename to colossalai/legacy/nn/layer/wrapper/__init__.py
diff --git a/colossalai/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py
similarity index 99%
rename from colossalai/nn/layer/wrapper/pipeline_wrapper.py
rename to colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py
index ef1d794cc68f..68fea8622c5c 100644
--- a/colossalai/nn/layer/wrapper/pipeline_wrapper.py
+++ b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py
@@ -1,6 +1,8 @@
-import torch.nn as nn
-import torch.distributed as dist
from typing import List, Tuple, Union
+
+import torch.distributed as dist
+import torch.nn as nn
+
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
diff --git a/colossalai/legacy/nn/loss/__init__.py b/colossalai/legacy/nn/loss/__init__.py
new file mode 100644
index 000000000000..1bd8872d9c3a
--- /dev/null
+++ b/colossalai/legacy/nn/loss/__init__.py
@@ -0,0 +1,41 @@
+from torch import nn
+from torch.nn.modules.loss import *
+from torch.nn.modules.loss import _Loss
+
+from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.nn.layer.utils import get_tensor_parallel_mode
+
+from .loss_1d import VocabParallelCrossEntropyLoss1D
+from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D
+from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D
+from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D
+
+_parallel_cross_entropy = {
+ '2d': CrossEntropyLoss2D,
+ '2.5d': CrossEntropyLoss2p5D,
+ '3d': CrossEntropyLoss3D,
+}
+
+_vocab_parallel_cross_entropy = {
+ '1d': VocabParallelCrossEntropyLoss1D,
+ '2d': VocabParallelCrossEntropyLoss2D,
+ '2.5d': VocabParallelCrossEntropyLoss2p5D,
+ '3d': VocabParallelCrossEntropyLoss3D,
+}
+
+
+class CrossEntropyLoss(_Loss):
+
+ def __init__(self, reduction: bool = True, *args, **kwargs):
+ super().__init__()
+ tensor_parallel = get_tensor_parallel_mode()
+ if tensor_parallel is not None and env.vocab_parallel:
+ self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
+ elif tensor_parallel is None or tensor_parallel == '1d':
+ reduction = 'mean' if reduction else 'none'
+ self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
+ else:
+ self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
+
+ def forward(self, *args):
+ return self.loss(*args)
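
For orientation, the new `colossalai/legacy/nn/loss/__init__.py` shown above picks a cross-entropy implementation from a lookup table keyed by the active tensor-parallel mode and falls back to `torch.nn.CrossEntropyLoss` for the `None`/`'1d'` cases. Below is a minimal, self-contained sketch of that dispatch pattern in plain PyTorch; the `build_cross_entropy` helper and the mode strings are illustrative and are not part of the ColossalAI API.

```python
import torch
import torch.nn as nn

# Illustrative stand-ins for the parallel loss classes registered in the diff above;
# in ColossalAI these would be CrossEntropyLoss2D / CrossEntropyLoss2p5D / CrossEntropyLoss3D.
_parallel_cross_entropy = {}


def build_cross_entropy(tensor_parallel_mode=None, reduction=True):
    """Mirror the dispatcher above: fall back to torch.nn.CrossEntropyLoss
    when there is no tensor parallelism (or the mode is '1d')."""
    if tensor_parallel_mode is None or tensor_parallel_mode == '1d':
        return nn.CrossEntropyLoss(reduction='mean' if reduction else 'none')
    return _parallel_cross_entropy[tensor_parallel_mode](reduction=reduction)


logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
print(build_cross_entropy(None)(logits, targets).item())
```
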
diff --git a/colossalai/nn/loss/loss_1d.py b/colossalai/legacy/nn/loss/loss_1d.py
similarity index 96%
rename from colossalai/nn/loss/loss_1d.py
rename to colossalai/legacy/nn/loss/loss_1d.py
index dd548c1d3dd4..8c9483fccaec 100644
--- a/colossalai/nn/loss/loss_1d.py
+++ b/colossalai/legacy/nn/loss/loss_1d.py
@@ -1,105 +1,106 @@
-import torch
-import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.registry import LOSSES
-from torch.cuda.amp import custom_bwd, custom_fwd
-from torch.nn.modules.loss import _Loss
-
-
-class _VocabParallelCrossEntropy1D(torch.autograd.Function):
-
- @staticmethod
- @custom_fwd(cast_inputs=torch.float32)
- def forward(ctx, vocab_parallel_logits, targets, process_group):
- if process_group is None:
- process_group = gpc.get_group(ParallelMode.PARALLEL_1D)
-
- # Maximum value along vocab dimension across all GPUs.
- logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
- torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
- # Subtract the maximum value.
- vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
-
- # Get the partition's vocab indices
- partition_vocab_size = vocab_parallel_logits.size()[-1]
- rank = dist.get_rank(process_group)
- vocab_start_index = partition_vocab_size * rank
- vocab_end_index = vocab_start_index + partition_vocab_size
-
- # Create a mask of valid vocab ids (1 means it needs to be masked).
- target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index)
- masked_target = targets.clone() - vocab_start_index
- masked_target[target_mask] = 0
-
- # Get predicted-logits = logits[target].
- # For Simplicity, we convert logits to a 2-D tensor with size
- # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
- logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
- masked_target_1d = masked_target.view(-1)
- arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
- predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
- predicted_logits_1d = predicted_logits_1d.clone().contiguous()
- predicted_logits = predicted_logits_1d.view_as(targets)
- predicted_logits[target_mask] = 0.0
- # All reduce is needed to get the chunks from other GPUs.
- torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
-
- # Sum of exponential of logits along vocab dimension across all GPUs.
- exp_logits = torch.exp(vocab_parallel_logits)
- sum_exp_logits = exp_logits.sum(dim=-1)
- torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
-
- # Loss = log(sum(exp(logits))) - predicted-logit.
- loss = torch.log(sum_exp_logits) - predicted_logits
- # Store softmax, target-mask and masked-target for backward pass.
- exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
- ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
- return loss
-
- @staticmethod
- @custom_bwd
- def backward(ctx, grad_output):
-
- # Retrieve tensors from the forward path.
- softmax, target_mask, masked_target_1d = ctx.saved_tensors
-
- # All the inputs have softmax as their gradient.
- grad_input = softmax
- # For simplicity, work with the 2D gradient.
- partition_vocab_size = softmax.size()[-1]
- grad_2d = grad_input.view(-1, partition_vocab_size)
-
- # Add the gradient from matching classes.
- arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
- grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
-
- # Finally elementwise multiplication with the output gradients.
- grad_input.mul_(grad_output.unsqueeze(dim=-1))
-
- return grad_input, None, None
-
-
-@LOSSES.register_module
-class VocabParallelCrossEntropyLoss1D(_Loss):
- """Vocab parallel cross entropy loss for 1D parallelism.
-
- Args:
- reduction (bool, optional): whether to average the loss, defaults to True.
- """
-
- def __init__(self, reduction=True):
- super().__init__()
- self.reduction_mean = reduction
-
- def forward(self, logits, targets, process_group=None):
- """Calculate loss between logits and targets.
-
- Args:
- logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
- targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
- """
- loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
- if self.reduction_mean:
- loss = loss.mean()
- return loss
+import torch
+import torch.distributed as dist
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn.modules.loss import _Loss
+
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
+
+
+class _VocabParallelCrossEntropy1D(torch.autograd.Function):
+
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float32)
+ def forward(ctx, vocab_parallel_logits, targets, process_group):
+ if process_group is None:
+ process_group = gpc.get_group(ParallelMode.PARALLEL_1D)
+
+ # Maximum value along vocab dimension across all GPUs.
+ logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+ torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
+ # Subtract the maximum value.
+ vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+
+ # Get the partition's vocab indices
+ partition_vocab_size = vocab_parallel_logits.size()[-1]
+ rank = dist.get_rank(process_group)
+ vocab_start_index = partition_vocab_size * rank
+ vocab_end_index = vocab_start_index + partition_vocab_size
+
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
+ target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index)
+ masked_target = targets.clone() - vocab_start_index
+ masked_target[target_mask] = 0
+
+ # Get predicted-logits = logits[target].
+ # For Simplicity, we convert logits to a 2-D tensor with size
+ # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+ logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+ masked_target_1d = masked_target.view(-1)
+ arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
+ predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+ predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+ predicted_logits = predicted_logits_1d.view_as(targets)
+ predicted_logits[target_mask] = 0.0
+ # All reduce is needed to get the chunks from other GPUs.
+ torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
+
+ # Sum of exponential of logits along vocab dimension across all GPUs.
+ exp_logits = torch.exp(vocab_parallel_logits)
+ sum_exp_logits = exp_logits.sum(dim=-1)
+ torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)
+
+ # Loss = log(sum(exp(logits))) - predicted-logit.
+ loss = torch.log(sum_exp_logits) - predicted_logits
+ # Store softmax, target-mask and masked-target for backward pass.
+ exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+ ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+ return loss
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, grad_output):
+
+ # Retrieve tensors from the forward path.
+ softmax, target_mask, masked_target_1d = ctx.saved_tensors
+
+ # All the inputs have softmax as their gradient.
+ grad_input = softmax
+ # For simplicity, work with the 2D gradient.
+ partition_vocab_size = softmax.size()[-1]
+ grad_2d = grad_input.view(-1, partition_vocab_size)
+
+ # Add the gradient from matching classes.
+ arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+ grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
+
+ # Finally elementwise multiplication with the output gradients.
+ grad_input.mul_(grad_output.unsqueeze(dim=-1))
+
+ return grad_input, None, None
+
+
+@LOSSES.register_module
+class VocabParallelCrossEntropyLoss1D(_Loss):
+ """Vocab parallel cross entropy loss for 1D parallelism.
+
+ Args:
+ reduction (bool, optional): whether to average the loss, defaults to True.
+ """
+
+ def __init__(self, reduction=True):
+ super().__init__()
+ self.reduction_mean = reduction
+
+ def forward(self, logits, targets, process_group=None):
+ """Calculate loss between logits and targets.
+
+ Args:
+ logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+ targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+ """
+ loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
+ if self.reduction_mean:
+ loss = loss.mean()
+ return loss
diff --git a/colossalai/nn/loss/loss_2d.py b/colossalai/legacy/nn/loss/loss_2d.py
similarity index 96%
rename from colossalai/nn/loss/loss_2d.py
rename to colossalai/legacy/nn/loss/loss_2d.py
index 7da8b2d697fa..6191602b71ee 100644
--- a/colossalai/nn/loss/loss_2d.py
+++ b/colossalai/legacy/nn/loss/loss_2d.py
@@ -1,15 +1,16 @@
import torch
import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
-from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
-from colossalai.registry import LOSSES
-from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
+from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization
+from colossalai.legacy.registry import LOSSES
+from colossalai.utils import get_current_device
+
@LOSSES.register_module
class CrossEntropyLoss2D(_Loss):
diff --git a/colossalai/nn/loss/loss_2p5d.py b/colossalai/legacy/nn/loss/loss_2p5d.py
similarity index 95%
rename from colossalai/nn/loss/loss_2p5d.py
rename to colossalai/legacy/nn/loss/loss_2p5d.py
index 63dc4f33ad32..2746b201152c 100644
--- a/colossalai/nn/loss/loss_2p5d.py
+++ b/colossalai/legacy/nn/loss/loss_2p5d.py
@@ -1,15 +1,16 @@
import torch
import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
-from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
-from colossalai.registry import LOSSES
-from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
+from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
+from colossalai.legacy.registry import LOSSES
+from colossalai.utils import get_current_device
+
@LOSSES.register_module
class CrossEntropyLoss2p5D(_Loss):
diff --git a/colossalai/nn/loss/loss_3d.py b/colossalai/legacy/nn/loss/loss_3d.py
similarity index 95%
rename from colossalai/nn/loss/loss_3d.py
rename to colossalai/legacy/nn/loss/loss_3d.py
index f27d57ad6c99..2aeb1bd9825d 100644
--- a/colossalai/nn/loss/loss_3d.py
+++ b/colossalai/legacy/nn/loss/loss_3d.py
@@ -1,15 +1,16 @@
import torch
import torch.distributed as dist
-from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
-from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
-from colossalai.registry import LOSSES
-from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
+from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.core import global_context as gpc
+from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
+from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
+from colossalai.legacy.registry import LOSSES
+from colossalai.utils import get_current_device
+
@LOSSES.register_module
class CrossEntropyLoss3D(_Loss):
diff --git a/colossalai/nn/metric/__init__.py b/colossalai/legacy/nn/metric/__init__.py
similarity index 87%
rename from colossalai/nn/metric/__init__.py
rename to colossalai/legacy/nn/metric/__init__.py
index 00833b6119c1..76c6dac89c5b 100644
--- a/colossalai/nn/metric/__init__.py
+++ b/colossalai/legacy/nn/metric/__init__.py
@@ -1,26 +1,28 @@
-from torch import nn
-
-from ._utils import calc_acc
-from .accuracy_2d import Accuracy2D
-from .accuracy_2p5d import Accuracy2p5D
-from .accuracy_3d import Accuracy3D
-from colossalai.nn.layer.utils import get_tensor_parallel_mode
-
-_parallel_accuracy = {
- '2d': Accuracy2D,
- '2.5d': Accuracy2p5D,
- '3d': Accuracy3D,
-}
-
-
-class Accuracy(nn.Module):
- def __init__(self):
- super().__init__()
- tensor_parallel = get_tensor_parallel_mode()
- if tensor_parallel not in _parallel_accuracy:
- self.acc = calc_acc
- else:
- self.acc = _parallel_accuracy[tensor_parallel]()
-
- def forward(self, *args):
- return self.acc(*args)
+from torch import nn
+
+from colossalai.legacy.nn.layer.utils import get_tensor_parallel_mode
+
+from ._utils import calc_acc
+from .accuracy_2d import Accuracy2D
+from .accuracy_2p5d import Accuracy2p5D
+from .accuracy_3d import Accuracy3D
+
+_parallel_accuracy = {
+ '2d': Accuracy2D,
+ '2.5d': Accuracy2p5D,
+ '3d': Accuracy3D,
+}
+
+
+class Accuracy(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ tensor_parallel = get_tensor_parallel_mode()
+ if tensor_parallel not in _parallel_accuracy:
+ self.acc = calc_acc
+ else:
+ self.acc = _parallel_accuracy[tensor_parallel]()
+
+ def forward(self, *args):
+ return self.acc(*args)
diff --git a/colossalai/nn/metric/_utils.py b/colossalai/legacy/nn/metric/_utils.py
similarity index 95%
rename from colossalai/nn/metric/_utils.py
rename to colossalai/legacy/nn/metric/_utils.py
index eac591b64c65..8706ffc101b0 100644
--- a/colossalai/nn/metric/_utils.py
+++ b/colossalai/legacy/nn/metric/_utils.py
@@ -1,7 +1,7 @@
-import torch
-
-
-def calc_acc(logits, targets):
- preds = torch.argmax(logits, dim=-1)
- correct = torch.sum(targets == preds)
- return correct
+import torch
+
+
+def calc_acc(logits, targets):
+ preds = torch.argmax(logits, dim=-1)
+ correct = torch.sum(targets == preds)
+ return correct
diff --git a/colossalai/nn/metric/accuracy_2d.py b/colossalai/legacy/nn/metric/accuracy_2d.py
similarity index 89%
rename from colossalai/nn/metric/accuracy_2d.py
rename to colossalai/legacy/nn/metric/accuracy_2d.py
index a86832973cfd..838c48834a96 100644
--- a/colossalai/nn/metric/accuracy_2d.py
+++ b/colossalai/legacy/nn/metric/accuracy_2d.py
@@ -1,7 +1,8 @@
import torch
-from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from torch import nn
+from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
+
from ._utils import calc_acc
diff --git a/colossalai/nn/metric/accuracy_2p5d.py b/colossalai/legacy/nn/metric/accuracy_2p5d.py
similarity index 88%
rename from colossalai/nn/metric/accuracy_2p5d.py
rename to colossalai/legacy/nn/metric/accuracy_2p5d.py
index 3044da065de1..183380cd9846 100644
--- a/colossalai/nn/metric/accuracy_2p5d.py
+++ b/colossalai/legacy/nn/metric/accuracy_2p5d.py
@@ -1,7 +1,8 @@
import torch
-from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from torch import nn
+from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
+
from ._utils import calc_acc
diff --git a/colossalai/nn/metric/accuracy_3d.py b/colossalai/legacy/nn/metric/accuracy_3d.py
similarity index 85%
rename from colossalai/nn/metric/accuracy_3d.py
rename to colossalai/legacy/nn/metric/accuracy_3d.py
index 5506fc1d2ffc..1aaac73ecabd 100644
--- a/colossalai/nn/metric/accuracy_3d.py
+++ b/colossalai/legacy/nn/metric/accuracy_3d.py
@@ -1,33 +1,35 @@
-import torch
-from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
-from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
-from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
-from torch import nn
-
-from ._utils import calc_acc
-
-
-class Accuracy3D(nn.Module):
- """Accuracy for 3D parallelism
- """
- def __init__(self):
- super().__init__()
- self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
- self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
-
- def forward(self, logits, targets):
- """Calculate the accuracy of predicted labels.
-
- Args:
- logits (:class:`torch.tensor`): Predicted labels.
- targets (:class:`torch.tensor`): True labels from data.
-
- Returns:
- float: the accuracy of prediction.
- """
- with torch.no_grad():
- targets = split_tensor_3d(targets, 0, self.weight_parallel_mode)
- targets = split_tensor_3d(targets, 0, self.input_parallel_mode)
- correct = calc_acc(logits, targets)
- correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode)
- return correct
+import torch
+from torch import nn
+
+from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
+from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
+
+from ._utils import calc_acc
+
+
+class Accuracy3D(nn.Module):
+ """Accuracy for 3D parallelism
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
+ self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
+
+ def forward(self, logits, targets):
+ """Calculate the accuracy of predicted labels.
+
+ Args:
+ logits (:class:`torch.tensor`): Predicted labels.
+ targets (:class:`torch.tensor`): True labels from data.
+
+ Returns:
+ float: the accuracy of prediction.
+ """
+ with torch.no_grad():
+ targets = split_tensor_3d(targets, 0, self.weight_parallel_mode)
+ targets = split_tensor_3d(targets, 0, self.input_parallel_mode)
+ correct = calc_acc(logits, targets)
+ correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode)
+ return correct
diff --git a/colossalai/nn/parallel/__init__.py b/colossalai/legacy/nn/parallel/__init__.py
similarity index 100%
rename from colossalai/nn/parallel/__init__.py
rename to colossalai/legacy/nn/parallel/__init__.py
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/legacy/nn/parallel/data_parallel.py
similarity index 100%
rename from colossalai/nn/parallel/data_parallel.py
rename to colossalai/legacy/nn/parallel/data_parallel.py
diff --git a/colossalai/nn/parallel/layers/__init__.py b/colossalai/legacy/nn/parallel/layers/__init__.py
similarity index 56%
rename from colossalai/nn/parallel/layers/__init__.py
rename to colossalai/legacy/nn/parallel/layers/__init__.py
index 29b8353e63c5..f38124efedf7 100644
--- a/colossalai/nn/parallel/layers/__init__.py
+++ b/colossalai/legacy/nn/parallel/layers/__init__.py
@@ -1,10 +1,17 @@
+from .cache_embedding import (
+ CachedEmbeddingBag,
+ CachedParamMgr,
+ EvictionStrategy,
+ LimitBuffIndexCopyer,
+ ParallelCachedEmbeddingBag,
+ ParallelCachedEmbeddingBagTablewise,
+ ParallelCachedEmbeddingBagTablewiseSpiltCache,
+ TablewiseEmbeddingBagConfig,
+)
from .colo_module import ColoModule
-from .linear import ColoLinear
from .embedding import ColoEmbedding
-from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
-
-from .cache_embedding import CachedEmbeddingBag, ParallelCachedEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
- ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewiseSpiltCache
+from .linear import ColoLinear
+from .module_utils import check_colo_module, get_colo_module, init_colo_module, is_colo_module, register_colo_module
__all__ = [
'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
diff --git a/colossalai/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py
similarity index 100%
rename from colossalai/nn/parallel/layers/cache_embedding/__init__.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py
index 5bbc931a79dc..d87930c1c6b3 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/__init__.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py
@@ -1,8 +1,8 @@
from .cache_mgr import CachedParamMgr, EvictionStrategy
-from .copyer import LimitBuffIndexCopyer
from .cached_embedding import CachedEmbeddingBag
-from .parallel_cached_embedding import ParallelCachedEmbeddingBag
+from .copyer import LimitBuffIndexCopyer
from .embedding_config import TablewiseEmbeddingBagConfig
+from .parallel_cached_embedding import ParallelCachedEmbeddingBag
from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise
from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache
diff --git a/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py
similarity index 99%
rename from colossalai/nn/parallel/layers/cache_embedding/base_embedding.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py
index 705835a0ed22..9558c541e703 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py
@@ -1,4 +1,5 @@
import abc
+
import torch.nn as nn
diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py
similarity index 99%
rename from colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py
index a6159856dcce..16530c4ce7b8 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py
@@ -1,12 +1,14 @@
+import sys
+from contextlib import contextmanager
+from enum import Enum
+from typing import List, Optional
+
import numpy as np
import torch
-from torch.profiler import record_function
-from typing import List, Optional
from contexttimer import Timer
+from torch.profiler import record_function
+
from .copyer import LimitBuffIndexCopyer
-from enum import Enum
-import sys
-from contextlib import contextmanager
class EvictionStrategy(Enum):
@@ -35,7 +37,7 @@ def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None:
class CachedParamMgr(torch.nn.Module):
"""
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
- CPU maintains the entire original weight.
+ CPU maintains the entire original weight.
CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`.
During training, GPU needs to transmit embedding rows between CPU and GPU.
Args:
@@ -115,7 +117,7 @@ def timer(self, name):
self._elapsed_dict[name] += t.elapsed
def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
- """_find_evict_gpu_idxs
+ """_find_evict_gpu_idxs
Find the gpu idxs to be evicted, according to their freq.
Args:
evict_num (int): how many rows has to be evicted
@@ -202,7 +204,7 @@ def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7
"""reorder
reorder the weight according to ids' frequency in dataset before training.
Execute only once before training, also known as warmup phase.
-
+
Note:
If you would like to use the DATASET as the eviction strategy, you must call this function.
Note:
@@ -516,7 +518,7 @@ def _evict(self) -> int:
"""
deprecated
evict one row from cuda to cpu.
- Returns:
+ Returns:
(int) : the slot id be evicted.
"""
mask = torch.logical_or(torch.isin(self.cached_idx_map, self.evict_backlist), self.cached_idx_map == -1)
diff --git a/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py
similarity index 98%
rename from colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py
index a74cb8d94bab..bc7d178906da 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py
@@ -1,10 +1,11 @@
+from typing import Iterator, List, Optional, Tuple, Union
+
import torch
import torch.nn.functional as F
-from typing import List, Optional, Iterator, Tuple, Union
+from torch.nn.parameter import Parameter
from .base_embedding import BaseEmbeddingBag
from .cache_mgr import CachedParamMgr, EvictionStrategy
-from torch.nn.parameter import Parameter
class CachedEmbeddingBag(BaseEmbeddingBag):
@@ -27,7 +28,7 @@ class CachedEmbeddingBag(BaseEmbeddingBag):
include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format.. Defaults to False.
dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32.
device (torch.device, optional): device type to the cpu weight. Defaults to None meaning cpu.
- cache_ratio (float, float): cache ratio of the #cuda_weight_row / #cpu_weight_row
+ cache_ratio (float, float): cache ratio of the #cuda_weight_row / #cpu_weight_row
ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occurs in dataset. Defaults to None.
warmup_ratio (float, optional): the ratio of cuda cache is warmuped with. Defaults to 0.7.
buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0.
@@ -85,10 +86,10 @@ def _preprocess(self,
buffer_size=50_000,
pin_weight=False):
"""
- Called after initialized.
+ Called after initialized.
Reorder the weight rows according to the ids_freq_mapping.
Then, let the weights of the Module be managed by a CachedParamMgr.
-
+
Args:
cuda_row_num (int): number of rows can be hosted in CUDA memory
ids_freq_mapping (List[int]): a list, idx is id number, value is freq
diff --git a/colossalai/nn/parallel/layers/cache_embedding/copyer.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py
similarity index 97%
rename from colossalai/nn/parallel/layers/cache_embedding/copyer.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py
index aa1f794482f9..804a07f88207 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/copyer.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py
@@ -3,7 +3,7 @@
class LimitBuffIndexCopyer(object):
- """LimitBuffIndexCopyer
+ """LimitBuffIndexCopyer
Index Copy using limited temp buffer on CUDA.
Args:
@@ -15,7 +15,7 @@ def __init__(self, size: int) -> None:
@torch.no_grad()
def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor):
- """copy
+ """copy
src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index]
The valid rows in the src tensor are continuous, while rows in tgt tensor is scattered.
diff --git a/colossalai/nn/parallel/layers/cache_embedding/embedding_config.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py
similarity index 100%
rename from colossalai/nn/parallel/layers/cache_embedding/embedding_config.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
similarity index 96%
rename from colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
index d7f77e195f4b..79d7672b26bc 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py
@@ -1,12 +1,13 @@
+from typing import Iterator, List, Optional, Tuple
+
import torch
import torch.nn.functional as F
-from typing import List, Optional, Iterator, Tuple
-from .cached_embedding import CachedEmbeddingBag
-from colossalai.nn._ops._utils import dual_all_to_all
+from colossalai.legacy.nn._ops._utils import dual_all_to_all
+from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec
-from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
from .cache_mgr import CachedParamMgr, EvictionStrategy
+from .cached_embedding import CachedEmbeddingBag
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
similarity index 99%
rename from colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
index 949f85ad4baf..116d836b7139 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py
@@ -1,15 +1,16 @@
+import time
+from typing import List
+
import torch
import torch.distributed as dist
import torch.nn.functional as F
-from .cached_embedding import CachedEmbeddingBag
-from .cache_mgr import EvictionStrategy
-from .embedding_config import TablewiseEmbeddingBagConfig
+from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise
from colossalai.tensor import ProcessGroup
-from colossalai.nn._ops._utils import dual_all_to_all_tablewise
-from typing import List
-import time
+from .cache_mgr import EvictionStrategy
+from .cached_embedding import CachedEmbeddingBag
+from .embedding_config import TablewiseEmbeddingBagConfig
class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag):
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
similarity index 99%
rename from colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
rename to colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
index 80a54b4fadd4..0014c784fba1 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
+++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py
@@ -1,17 +1,17 @@
+import abc
+from typing import List
+
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.profiler import record_function
-from .cached_embedding import CachedEmbeddingBag
-
+from colossalai.legacy.nn._ops._utils import dual_all_to_all_tablewise
from colossalai.tensor import ProcessGroup
-from colossalai.nn._ops._utils import dual_all_to_all_tablewise
-from .embedding_config import TablewiseEmbeddingBagConfig
-from .cache_mgr import EvictionStrategy
-from typing import List
-import abc
+from .cache_mgr import EvictionStrategy
+from .cached_embedding import CachedEmbeddingBag
+from .embedding_config import TablewiseEmbeddingBagConfig
class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module):
diff --git a/colossalai/nn/parallel/layers/colo_module.py b/colossalai/legacy/nn/parallel/layers/colo_module.py
similarity index 98%
rename from colossalai/nn/parallel/layers/colo_module.py
rename to colossalai/legacy/nn/parallel/layers/colo_module.py
index 8f0f5d5f520a..a0a3eb40cf08 100644
--- a/colossalai/nn/parallel/layers/colo_module.py
+++ b/colossalai/legacy/nn/parallel/layers/colo_module.py
@@ -1,6 +1,7 @@
-from colossalai.tensor.distspec import _DistSpec
+from typing import Dict, List
+
from colossalai.tensor import ComputePattern
-from typing import List, Dict
+from colossalai.tensor.distspec import _DistSpec
class ColoModule(object):
diff --git a/colossalai/nn/parallel/layers/embedding.py b/colossalai/legacy/nn/parallel/layers/embedding.py
similarity index 92%
rename from colossalai/nn/parallel/layers/embedding.py
rename to colossalai/legacy/nn/parallel/layers/embedding.py
index ccacc1ead297..3e4e7ffd8de7 100644
--- a/colossalai/nn/parallel/layers/embedding.py
+++ b/colossalai/legacy/nn/parallel/layers/embedding.py
@@ -1,5 +1,6 @@
+from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
+
from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
class ColoEmbedding(ColoModule):
diff --git a/colossalai/nn/parallel/layers/linear.py b/colossalai/legacy/nn/parallel/layers/linear.py
similarity index 93%
rename from colossalai/nn/parallel/layers/linear.py
rename to colossalai/legacy/nn/parallel/layers/linear.py
index 84a8c042587d..e391cf808933 100644
--- a/colossalai/nn/parallel/layers/linear.py
+++ b/colossalai/legacy/nn/parallel/layers/linear.py
@@ -1,5 +1,6 @@
+from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec
+
from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec
class ColoLinear(ColoModule):
diff --git a/colossalai/nn/parallel/layers/module_utils.py b/colossalai/legacy/nn/parallel/layers/module_utils.py
similarity index 99%
rename from colossalai/nn/parallel/layers/module_utils.py
rename to colossalai/legacy/nn/parallel/layers/module_utils.py
index 38d128cc705e..191266fa70fd 100644
--- a/colossalai/nn/parallel/layers/module_utils.py
+++ b/colossalai/legacy/nn/parallel/layers/module_utils.py
@@ -1,9 +1,11 @@
from typing import Dict
-from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup
-from colossalai.tensor import distspec
-from . import ColoModule
+
import torch
+from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup, distspec
+
+from . import ColoModule
+
_COLOSSAL_MODULES: Dict[type, ColoModule] = {}
diff --git a/colossalai/nn/parallel/reducer.py b/colossalai/legacy/nn/parallel/reducer.py
similarity index 100%
rename from colossalai/nn/parallel/reducer.py
rename to colossalai/legacy/nn/parallel/reducer.py
diff --git a/colossalai/registry/__init__.py b/colossalai/legacy/registry/__init__.py
similarity index 100%
rename from colossalai/registry/__init__.py
rename to colossalai/legacy/registry/__init__.py
diff --git a/colossalai/registry/registry.py b/colossalai/legacy/registry/registry.py
similarity index 98%
rename from colossalai/registry/registry.py
rename to colossalai/legacy/registry/registry.py
index 8a4173f7ab99..50d6b74c5617 100644
--- a/colossalai/registry/registry.py
+++ b/colossalai/legacy/registry/registry.py
@@ -6,7 +6,7 @@
class Registry:
- """This is a registry class used to register classes and modules so that a universal
+ """This is a registry class used to register classes and modules so that a universal
object builder can be enabled.
Args:
@@ -42,7 +42,7 @@ def register_module(self, module_class):
return module_class
def get_module(self, module_name: str):
- """Retrieves a module with name `module_name` and returns the module if it has
+ """Retrieves a module with name `module_name` and returns the module if it has
already been registered before.
Args:
diff --git a/colossalai/trainer/__init__.py b/colossalai/legacy/trainer/__init__.py
similarity index 100%
rename from colossalai/trainer/__init__.py
rename to colossalai/legacy/trainer/__init__.py
diff --git a/colossalai/trainer/_trainer.py b/colossalai/legacy/trainer/_trainer.py
similarity index 98%
rename from colossalai/trainer/_trainer.py
rename to colossalai/legacy/trainer/_trainer.py
index bfe1c403fd48..1847e56222a1 100644
--- a/colossalai/trainer/_trainer.py
+++ b/colossalai/legacy/trainer/_trainer.py
@@ -1,14 +1,13 @@
-from typing import Union, List, Any
+from typing import Any, List, Union
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
-from colossalai.engine import Engine
+from colossalai.legacy.engine import Engine
+from colossalai.legacy.trainer.hooks import BaseHook
from colossalai.logging import DistributedLogger
-from colossalai.utils import MultiTimer
-from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
-from colossalai.trainer.hooks import BaseHook
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0
class Trainer:
diff --git a/colossalai/trainer/hooks/__init__.py b/colossalai/legacy/trainer/hooks/__init__.py
similarity index 75%
rename from colossalai/trainer/hooks/__init__.py
rename to colossalai/legacy/trainer/hooks/__init__.py
index 4d36093833d9..bf9cc6421b67 100644
--- a/colossalai/trainer/hooks/__init__.py
+++ b/colossalai/legacy/trainer/hooks/__init__.py
@@ -1,7 +1,12 @@
from ._base_hook import BaseHook
from ._checkpoint_hook import SaveCheckpointHook
-from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook,
- TensorboardHook)
+from ._log_hook import (
+ LogMemoryByEpochHook,
+ LogMetricByEpochHook,
+ LogMetricByStepHook,
+ LogTimingByEpochHook,
+ TensorboardHook,
+)
from ._lr_scheduler_hook import LRSchedulerHook
from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook
diff --git a/colossalai/trainer/hooks/_base_hook.py b/colossalai/legacy/trainer/hooks/_base_hook.py
similarity index 100%
rename from colossalai/trainer/hooks/_base_hook.py
rename to colossalai/legacy/trainer/hooks/_base_hook.py
diff --git a/colossalai/trainer/hooks/_checkpoint_hook.py b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py
similarity index 96%
rename from colossalai/trainer/hooks/_checkpoint_hook.py
rename to colossalai/legacy/trainer/hooks/_checkpoint_hook.py
index 3bcb32cd2dcb..6b150d29139f 100644
--- a/colossalai/trainer/hooks/_checkpoint_hook.py
+++ b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py
@@ -1,11 +1,12 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
-from colossalai.logging import get_dist_logger
-from colossalai.registry import HOOKS
-from colossalai.trainer.hooks import BaseHook
+from colossalai.legacy.registry import HOOKS
+from colossalai.legacy.trainer.hooks import BaseHook
+from colossalai.logging import get_dist_logger
from colossalai.utils.checkpointing import save_checkpoint
+
from ._lr_scheduler_hook import LRSchedulerHook
diff --git a/colossalai/trainer/hooks/_commons_.py b/colossalai/legacy/trainer/hooks/_commons_.py
similarity index 100%
rename from colossalai/trainer/hooks/_commons_.py
rename to colossalai/legacy/trainer/hooks/_commons_.py
diff --git a/colossalai/trainer/hooks/_log_hook.py b/colossalai/legacy/trainer/hooks/_log_hook.py
similarity index 98%
rename from colossalai/trainer/hooks/_log_hook.py
rename to colossalai/legacy/trainer/hooks/_log_hook.py
index 5b1f33983422..7d9ad19aa9e9 100644
--- a/colossalai/trainer/hooks/_log_hook.py
+++ b/colossalai/legacy/trainer/hooks/_log_hook.py
@@ -3,17 +3,17 @@
import os
import os.path as osp
-
from typing import List
+
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.registry import HOOKS
+from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
from colossalai.logging import DistributedLogger
-from colossalai.utils import report_memory_usage, is_dp_rank_0, \
- is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage
+
from ._base_hook import BaseHook
from ._commons_ import _format_number
-from colossalai.trainer.hooks._metric_hook import ThroughputMetric
class LogByEpochHook(BaseHook):
diff --git a/colossalai/trainer/hooks/_lr_scheduler_hook.py b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
similarity index 97%
rename from colossalai/trainer/hooks/_lr_scheduler_hook.py
rename to colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
index c6da33442dc3..6d60966da12a 100644
--- a/colossalai/trainer/hooks/_lr_scheduler_hook.py
+++ b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
@@ -1,6 +1,7 @@
-from colossalai.registry import HOOKS
from torch import Tensor
+from colossalai.legacy.registry import HOOKS
+
from ._metric_hook import LearningRateMetric, MetricHook
diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py
similarity index 97%
rename from colossalai/trainer/hooks/_metric_hook.py
rename to colossalai/legacy/trainer/hooks/_metric_hook.py
index 526d6c746ec6..f1bd19387cb5 100644
--- a/colossalai/trainer/hooks/_metric_hook.py
+++ b/colossalai/legacy/trainer/hooks/_metric_hook.py
@@ -6,10 +6,11 @@
import torch
import torch.distributed as dist
-from colossalai.communication import all_reduce
+
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.communication import all_reduce
+from colossalai.legacy.registry import HOOKS
from colossalai.utils import get_current_device, is_no_pp_or_last_stage
from ._base_hook import BaseHook
@@ -19,8 +20,8 @@
class Metric(ABC):
"""A basic class of metric collectors. It collects a specific
metric during training or evaluation and would always be used with
- :class:`MetricHook` to help it update its states and show the
- metric. So please use corresponding hook class to make the metric
+ :class:`MetricHook` to help it update its states and show the
+ metric. So please use corresponding hook class to make the metric
collector works.
Args:
@@ -220,9 +221,9 @@ def is_better(a, b) -> bool:
class MetricHook(BaseHook):
- """Specialized hook classes for :class:`Metric`.
- Some help metric collectors initialize, reset and
- update their states. Others are used to display and
+ """Specialized hook classes for :class:`Metric`.
+ Some help metric collectors initialize, reset and
+ update their states. Others are used to display and
record the metric.
Args:
@@ -355,7 +356,7 @@ def get_last_step_value(self) -> float:
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
else:
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
- gpc.get_world_size(ParallelMode.DATA)
+ gpc.get_world_size(ParallelMode.DATA)
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
@@ -366,7 +367,7 @@ def get_last_step_info(self) -> str:
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
else:
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
- gpc.get_world_size(ParallelMode.DATA)
+ gpc.get_world_size(ParallelMode.DATA)
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py
index af7b7de54a8d..f9abe4a2a2b6 100644
--- a/colossalai/logging/logger.py
+++ b/colossalai/logging/logger.py
@@ -6,8 +6,7 @@
from pathlib import Path
from typing import List, Union
-import colossalai
-from colossalai.context.parallel_mode import ParallelMode
+import torch.distributed as dist
class DistributedLogger:
@@ -63,6 +62,7 @@ def __init__(self, name):
self._logger.propagate = False
DistributedLogger.__instances[name] = self
+ self.rank = dist.get_rank() if dist.is_initialized() else 0
@staticmethod
def __get_call_info():
@@ -109,16 +109,10 @@ def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INF
# create log directory
path.mkdir(parents=True, exist_ok=True)
- # set the default file name if path is a directory
- if not colossalai.core.global_context.is_initialized(ParallelMode.GLOBAL):
- rank = 0
- else:
- rank = colossalai.core.global_context.get_global_rank()
-
if suffix is not None:
- log_file_name = f'rank_{rank}_{suffix}.log'
+ log_file_name = f'rank_{self.rank}_{suffix}.log'
else:
- log_file_name = f'rank_{rank}.log'
+ log_file_name = f'rank_{self.rank}.log'
path = path.joinpath(log_file_name)
# add file handler
@@ -128,19 +122,14 @@ def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INF
file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler)
- def _log(self,
- level,
- message: str,
- parallel_mode: ParallelMode = ParallelMode.GLOBAL,
- ranks: List[int] = None) -> None:
+ def _log(self, level, message: str, ranks: List[int] = None) -> None:
if ranks is None:
getattr(self._logger, level)(message)
else:
- local_rank = colossalai.core.global_context.get_local_rank(parallel_mode)
- if local_rank in ranks:
+ if self.rank in ranks:
getattr(self._logger, level)(message)
- def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+ def info(self, message: str, ranks: List[int] = None) -> None:
"""Log an info message.
Args:
@@ -150,10 +139,10 @@ def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL,
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
- self._log('info', message_prefix, parallel_mode, ranks)
- self._log('info', message, parallel_mode, ranks)
+ self._log('info', message_prefix, ranks)
+ self._log('info', message, ranks)
- def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+ def warning(self, message: str, ranks: List[int] = None) -> None:
"""Log a warning message.
Args:
@@ -163,10 +152,10 @@ def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBA
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
- self._log('warning', message_prefix, parallel_mode, ranks)
- self._log('warning', message, parallel_mode, ranks)
+ self._log('warning', message_prefix, ranks)
+ self._log('warning', message, ranks)
- def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+ def debug(self, message: str, ranks: List[int] = None) -> None:
"""Log a debug message.
Args:
@@ -176,10 +165,10 @@ def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL,
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
- self._log('debug', message_prefix, parallel_mode, ranks)
- self._log('debug', message, parallel_mode, ranks)
+ self._log('debug', message_prefix, ranks)
+ self._log('debug', message, ranks)
- def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
+ def error(self, message: str, ranks: List[int] = None) -> None:
"""Log an error message.
Args:
@@ -189,5 +178,5 @@ def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL,
ranks (List[int]): List of parallel ranks.
"""
message_prefix = "{}:{} {}".format(*self.__get_call_info())
- self._log('error', message_prefix, parallel_mode, ranks)
- self._log('error', message, parallel_mode, ranks)
+ self._log('error', message_prefix, ranks)
+ self._log('error', message, ranks)
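
With this change the logger no longer depends on ColossalAI's global parallel context: rank filtering is based purely on `torch.distributed`. The sketch below uses hypothetical helper names (not part of the library) and only mirrors the new behaviour of `_log` and the `ranks` argument.

```python
# Minimal sketch (hypothetical helper, not the library API): how the refactored logger
# decides whether to emit a message, now keyed off torch.distributed instead of gpc.
import logging
from typing import List, Optional

import torch.distributed as dist

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sketch")


def log_on_ranks(message: str, ranks: Optional[List[int]] = None) -> None:
    # rank falls back to 0 when torch.distributed is not initialized,
    # mirroring `self.rank = dist.get_rank() if dist.is_initialized() else 0`
    rank = dist.get_rank() if dist.is_initialized() else 0
    if ranks is None or rank in ranks:
        logger.info(message)


log_on_ranks("visible on every rank")
log_on_ranks("visible only on rank 0", ranks=[0])
```
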
diff --git a/colossalai/nn/__init__.py b/colossalai/nn/__init__.py
index 910ad203180c..c6c4d3042556 100644
--- a/colossalai/nn/__init__.py
+++ b/colossalai/nn/__init__.py
@@ -1,6 +1,5 @@
-from ._ops import *
+from .init import *
from .layer import *
from .loss import *
from .lr_scheduler import *
-from .metric import *
from .optimizer import *
diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py
index b705632f8040..edd986ef5e82 100644
--- a/colossalai/nn/layer/__init__.py
+++ b/colossalai/nn/layer/__init__.py
@@ -1,10 +1,2 @@
-from .colossalai_layer import *
-from .parallel_1d import *
-from .parallel_2d import *
-from .parallel_2p5d import *
-from .parallel_3d import *
-from .parallel_sequence import *
from .moe import *
from .utils import *
-from .vanilla import *
-from .wrapper import *
diff --git a/colossalai/nn/layer/parallel_1d/__init__.py b/colossalai/nn/layer/parallel_1d/__init__.py
deleted file mode 100644
index 2353851df665..000000000000
--- a/colossalai/nn/layer/parallel_1d/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .layers import (Classifier1D, Dropout1D, Embedding1D, LayerNorm1D, Linear1D, Linear1D_Col, Linear1D_Row,
- PatchEmbedding1D, VocabParallelClassifier1D, VocabParallelEmbedding1D)
-
-__all__ = [
- 'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D',
- 'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D'
-]
diff --git a/colossalai/nn/layer/utils.py b/colossalai/nn/layer/utils.py
new file mode 100644
index 000000000000..dc12ff8daa4e
--- /dev/null
+++ b/colossalai/nn/layer/utils.py
@@ -0,0 +1,14 @@
+def divide(numerator, denominator):
+ """Only allow exact division.
+
+ Args:
+ numerator (int): Numerator of the division.
+ denominator (int): Denominator of the division.
+
+ Returns:
+ int: the result of exact division.
+ """
+ assert denominator != 0, 'denominator can not be zero'
+ assert numerator % denominator == 0, \
+ '{} is not divisible by {}'.format(numerator, denominator)
+ return numerator // denominator
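
The new `colossalai/nn/layer/utils.py` keeps only the `divide` helper, which enforces exact division (typically used to compute per-partition sizes when a dimension is sharded). A small usage sketch, with the function copied verbatim from the file above so it runs standalone:

```python
# Usage sketch for the `divide` helper added above: it only allows exact division.
def divide(numerator, denominator):
    assert denominator != 0, 'denominator can not be zero'
    assert numerator % denominator == 0, \
        '{} is not divisible by {}'.format(numerator, denominator)
    return numerator // denominator


hidden_size, tp_world_size = 4096, 8
per_partition = divide(hidden_size, tp_world_size)    # -> 512
print(per_partition)
# divide(4096, 3) would raise: AssertionError: 4096 is not divisible by 3
```
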
diff --git a/colossalai/nn/layer/utils/__init__.py b/colossalai/nn/layer/utils/__init__.py
deleted file mode 100644
index 7e999ee82149..000000000000
--- a/colossalai/nn/layer/utils/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode,
- set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple)
-
-__all__ = [
- 'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
- 'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
-]
diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py
index 373e4ec9468b..ee2add48ab91 100644
--- a/colossalai/nn/loss/__init__.py
+++ b/colossalai/nn/loss/__init__.py
@@ -1,41 +1 @@
-from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.nn.layer.utils import get_tensor_parallel_mode
-from torch import nn
-from torch.nn.modules.loss import *
-from torch.nn.modules.loss import _Loss
-
-from .loss_1d import VocabParallelCrossEntropyLoss1D
-from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D
-from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D
-from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D
from .loss_moe import MoeCrossEntropyLoss, MoeLoss
-
-_parallel_cross_entropy = {
- '2d': CrossEntropyLoss2D,
- '2.5d': CrossEntropyLoss2p5D,
- '3d': CrossEntropyLoss3D,
-}
-
-_vocab_parallel_cross_entropy = {
- '1d': VocabParallelCrossEntropyLoss1D,
- '2d': VocabParallelCrossEntropyLoss2D,
- '2.5d': VocabParallelCrossEntropyLoss2p5D,
- '3d': VocabParallelCrossEntropyLoss3D,
-}
-
-
-class CrossEntropyLoss(_Loss):
-
- def __init__(self, reduction: bool = True, *args, **kwargs):
- super().__init__()
- tensor_parallel = get_tensor_parallel_mode()
- if tensor_parallel is not None and env.vocab_parallel:
- self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
- elif tensor_parallel is None or tensor_parallel == '1d':
- reduction = 'mean' if reduction else 'none'
- self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
- else:
- self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
-
- def forward(self, *args):
- return self.loss(*args)
diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/nn/loss/loss_moe.py
index a8b18a3e37ee..40cea788c3c3 100644
--- a/colossalai/nn/loss/loss_moe.py
+++ b/colossalai/nn/loss/loss_moe.py
@@ -1,80 +1,81 @@
-import torch.nn as nn
-from colossalai.registry import LOSSES
-from torch.nn.modules.loss import _Loss
-from colossalai.context.moe_context import MOE_CONTEXT
-
-
-@LOSSES.register_module
-class MoeCrossEntropyLoss(_Loss):
- r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
-
- Args:
- input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
- target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
- aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01.
-
- The ``args`` and ``kwargs`` should include parameters below:
- ::
-
- weight (Tensor, optional)
- size_average (bool, optional)
- ignore_index (int, optional)
- reduce (bool, optional)
- reduction (str, optional)
- label_smoothing (float, optional)
-
- More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
- `Cross_entropy `_.
- """
-
- def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
- super().__init__()
- self.loss = nn.CrossEntropyLoss(*args, **kwargs)
- self.aux_weight = aux_weight
-
- def forward(self, *args):
- """
- The ``args`` should at least include parameters below:
- ::
-
- input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
- target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-
- More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
- `Cross_entropy `_.
- """
- main_loss = self.loss(*args)
- aux_loss = MOE_CONTEXT.get_loss()
- return main_loss + self.aux_weight * aux_loss
-
-
-@LOSSES.register_module
-class MoeLoss(_Loss):
- """A wrapper class for any loss module to add with auxiliary loss.
-
- Args:
- aux_weight (float): Weight of auxiliary loss in total loss.
- loss_fn (``Callable``): Loss function.
- args (list): Args in loss function.
- kwargs (dict): Kwargs in loss function
- """
-
- def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
- super().__init__()
- self.loss_fn = loss_fn(*args, **kwargs)
- self.aux_weight = aux_weight
-
- def forward(self, *args, **kwargs):
- """
- The ``args`` and ``kwargs`` should at least include parameters below:
- ::
-
- input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
- target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-
- Note:
- The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
- """
- main_loss = self.loss_fn(*args, **kwargs)
- aux_loss = MOE_CONTEXT.get_loss()
- return main_loss + self.aux_weight * aux_loss
+import torch.nn as nn
+from torch.nn.modules.loss import _Loss
+
+from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.legacy.registry import LOSSES
+
+
+@LOSSES.register_module
+class MoeCrossEntropyLoss(_Loss):
+    r"""torch.nn.CrossEntropyLoss combined with the MoE auxiliary loss.
+
+ Args:
+ input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+ target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+        aux_weight (float, optional): Weight of the auxiliary loss in the total loss. Defaults to 0.01.
+
+ The ``args`` and ``kwargs`` should include parameters below:
+ ::
+
+ weight (Tensor, optional)
+ size_average (bool, optional)
+ ignore_index (int, optional)
+ reduce (bool, optional)
+ reduction (str, optional)
+ label_smoothing (float, optional)
+
+ More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
+ `Cross_entropy `_.
+ """
+
+ def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
+ super().__init__()
+ self.loss = nn.CrossEntropyLoss(*args, **kwargs)
+ self.aux_weight = aux_weight
+
+ def forward(self, *args):
+ """
+ The ``args`` should at least include parameters below:
+ ::
+
+ input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+ target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+
+ More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
+ `Cross_entropy `_.
+ """
+ main_loss = self.loss(*args)
+ aux_loss = MOE_CONTEXT.get_loss()
+ return main_loss + self.aux_weight * aux_loss
+
+
+@LOSSES.register_module
+class MoeLoss(_Loss):
+    """A wrapper class that adds the MoE auxiliary loss to any loss module.
+
+ Args:
+ aux_weight (float): Weight of auxiliary loss in total loss.
+ loss_fn (``Callable``): Loss function.
+ args (list): Args in loss function.
+ kwargs (dict): Kwargs in loss function
+ """
+
+ def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
+ super().__init__()
+ self.loss_fn = loss_fn(*args, **kwargs)
+ self.aux_weight = aux_weight
+
+ def forward(self, *args, **kwargs):
+ """
+ The ``args`` and ``kwargs`` should at least include parameters below:
+ ::
+
+ input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+ target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+
+ Note:
+ The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
+ """
+ main_loss = self.loss_fn(*args, **kwargs)
+ aux_loss = MOE_CONTEXT.get_loss()
+ return main_loss + self.aux_weight * aux_loss
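
Both loss classes follow the same pattern: compute the main loss, then add `aux_weight` times the auxiliary loss collected from `MOE_CONTEXT.get_loss()`. The standalone sketch below reproduces that wrapper pattern without the MoE context; the class name and the explicit `aux_loss` argument are illustrative assumptions, not library API.

```python
# Standalone sketch of the wrapper pattern used by MoeLoss above: a base criterion plus a
# weighted auxiliary term. Here the auxiliary loss is passed in explicitly instead of being
# read from MOE_CONTEXT.get_loss().
import torch
import torch.nn as nn


class AuxWeightedLoss(nn.Module):
    def __init__(self, aux_weight: float, loss_fn: nn.Module):
        super().__init__()
        self.aux_weight = aux_weight
        self.loss_fn = loss_fn

    def forward(self, logits, target, aux_loss):
        main_loss = self.loss_fn(logits, target)
        return main_loss + self.aux_weight * aux_loss


criterion = AuxWeightedLoss(aux_weight=0.01, loss_fn=nn.CrossEntropyLoss())
logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
aux = torch.tensor(0.3)    # stand-in for the MoE load-balancing loss
print(criterion(logits, target, aux))
```
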
diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py
index aab523bef8b3..fb587e1a1341 100644
--- a/colossalai/nn/lr_scheduler/cosine.py
+++ b/colossalai/nn/lr_scheduler/cosine.py
@@ -1,10 +1,8 @@
from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
-from colossalai.registry import LR_SCHEDULERS
from .delayed import DelayerScheduler, WarmupDelayerScheduler, WarmupScheduler
-@LR_SCHEDULERS.register_module
class CosineAnnealingLR(_CosineAnnealingLR):
r"""Set the learning rate of each parameter group using a cosine annealing
schedule, where :math:`\eta_{max}` is set to the initial lr and
@@ -48,7 +46,6 @@ def __init__(self, optimizer, total_steps: int, eta_min: int = 0, last_epoch: in
super().__init__(optimizer, total_steps, eta_min=eta_min, last_epoch=last_epoch)
-@LR_SCHEDULERS.register_module
class CosineAnnealingWarmupLR(WarmupScheduler):
"""Cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be applied.
@@ -69,7 +66,6 @@ def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min:
super().__init__(optimizer, warmup_steps, base_scheduler)
-@LR_SCHEDULERS.register_module
class FlatAnnealingLR(DelayerScheduler):
"""Flat and cosine annealing learning rate scheduler. The learning rate will be a fixed value before starting decay.
@@ -90,7 +86,6 @@ def __init__(self, optimizer, total_steps: int, pct_start: float = 0.72, last_ep
super().__init__(optimizer, flat_steps, base_scheduler, last_epoch=last_epoch)
-@LR_SCHEDULERS.register_module
class FlatAnnealingWarmupLR(WarmupDelayerScheduler):
"""Flat and cosine annealing learning rate scheduler with learning rate warmup. A linear warmup schedule will be
applied, and then the learning rate will be a fixed value before starting decay.
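
With the `@LR_SCHEDULERS.register_module` decorators removed, these schedulers are plain classes that are instantiated directly. A hedged usage sketch for the warmup variant, based on the constructor shown above; the import path is assumed from the package layout (`colossalai/nn/lr_scheduler/cosine.py`):

```python
# Hedged usage sketch: construct the scheduler directly instead of going through the registry.
import torch
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR   # import path assumed

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# warm up linearly for 100 steps, then follow a cosine annealing curve down to eta_min
scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=1000, warmup_steps=100, eta_min=1e-6)

for step in range(1000):
    optimizer.step()
    scheduler.step()
```
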
diff --git a/colossalai/nn/lr_scheduler/linear.py b/colossalai/nn/lr_scheduler/linear.py
index 556938b8a60c..21a865e4c12b 100644
--- a/colossalai/nn/lr_scheduler/linear.py
+++ b/colossalai/nn/lr_scheduler/linear.py
@@ -1,9 +1,6 @@
from torch.optim.lr_scheduler import _LRScheduler
-from colossalai.registry import LR_SCHEDULERS
-
-@LR_SCHEDULERS.register_module
class LinearWarmupLR(_LRScheduler):
"""Linearly warmup learning rate and then linearly decay.
diff --git a/colossalai/nn/lr_scheduler/multistep.py b/colossalai/nn/lr_scheduler/multistep.py
index 29531a9e3855..c428c911c94d 100644
--- a/colossalai/nn/lr_scheduler/multistep.py
+++ b/colossalai/nn/lr_scheduler/multistep.py
@@ -2,11 +2,9 @@
from torch.optim.lr_scheduler import MultiStepLR as _MultiStepLR
-from colossalai.registry import LR_SCHEDULERS
from .delayed import WarmupScheduler
-@LR_SCHEDULERS.register_module
class MultiStepLR(_MultiStepLR):
"""Decays the learning rate of each parameter group by gamma once the
number of epoch reaches one of the milestones. Notice that such decay can
@@ -32,7 +30,6 @@ def __init__(self,
super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)
-@LR_SCHEDULERS.register_module
class MultiStepWarmupLR(WarmupScheduler):
"""Multistep learning rate scheduler with warmup.
diff --git a/colossalai/nn/lr_scheduler/onecycle.py b/colossalai/nn/lr_scheduler/onecycle.py
index 8007fd36008e..6835b3ee1cf2 100644
--- a/colossalai/nn/lr_scheduler/onecycle.py
+++ b/colossalai/nn/lr_scheduler/onecycle.py
@@ -1,9 +1,6 @@
from torch.optim.lr_scheduler import OneCycleLR as _OneCycleLR
-from colossalai.registry import LR_SCHEDULERS
-
-@LR_SCHEDULERS.register_module
class OneCycleLR(_OneCycleLR):
r"""Sets the learning rate of each parameter group according to the
1cycle learning rate policy. The 1cycle policy anneals the learning
diff --git a/colossalai/nn/lr_scheduler/poly.py b/colossalai/nn/lr_scheduler/poly.py
index 16352bc5175f..4f2249720ef6 100644
--- a/colossalai/nn/lr_scheduler/poly.py
+++ b/colossalai/nn/lr_scheduler/poly.py
@@ -1,10 +1,8 @@
from torch.optim.lr_scheduler import _LRScheduler
-from colossalai.registry import LR_SCHEDULERS
from .delayed import WarmupScheduler
-@LR_SCHEDULERS.register_module
class PolynomialLR(_LRScheduler):
"""Polynomial learning rate scheduler.
@@ -40,7 +38,6 @@ def _get_closed_form_lr(self):
for base_lr in self.base_lrs]
-@LR_SCHEDULERS.register_module
class PolynomialWarmupLR(WarmupScheduler):
"""Polynomial learning rate scheduler with warmup.
diff --git a/colossalai/nn/lr_scheduler/torch.py b/colossalai/nn/lr_scheduler/torch.py
index 05d2a49c1ea5..8846e13c7511 100644
--- a/colossalai/nn/lr_scheduler/torch.py
+++ b/colossalai/nn/lr_scheduler/torch.py
@@ -1,12 +1,9 @@
+from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
from torch.optim.lr_scheduler import LambdaLR as _LambdaLR
from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR
from torch.optim.lr_scheduler import StepLR as _StepLR
-from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR
-
-from colossalai.registry import LR_SCHEDULERS
-@LR_SCHEDULERS.register_module
class LambdaLR(_LambdaLR):
"""Sets the learning rate of each parameter group to the initial lr
times a given function. When last_epoch=-1, sets initial lr as lr.
@@ -24,7 +21,6 @@ def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1)
super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
-@LR_SCHEDULERS.register_module
class MultiplicativeLR(_MultiplicativeLR):
"""Multiply the learning rate of each parameter group by the factor given
in the specified function. When last_epoch=-1, sets initial lr as lr.
@@ -42,7 +38,6 @@ def __init__(self, optimizer, total_steps, lr_lambda=None, last_epoch: int = -1)
super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
-@LR_SCHEDULERS.register_module
class StepLR(_StepLR):
"""Decays the learning rate of each parameter group by gamma every
step_size epochs. Notice that such decay can happen simultaneously with
@@ -61,7 +56,6 @@ def __init__(self, optimizer, total_steps, step_size: int = 1, gamma: float = 0.
super().__init__(optimizer, step_size, gamma=gamma, last_epoch=last_epoch)
-@LR_SCHEDULERS.register_module
class ExponentialLR(_ExponentialLR):
"""Decays the learning rate of each parameter group by gamma every epoch.
When last_epoch=-1, sets initial lr as lr
diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 3a6d37103398..9767fcb8b1e2 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -4,12 +4,10 @@
import torch
from colossalai.kernel.op_builder import CPUAdamBuilder
-from colossalai.registry import OPTIMIZERS
from .nvme_optimizer import NVMeOptimizer
-@OPTIMIZERS.register_module
class CPUAdam(NVMeOptimizer):
"""Implements Adam algorithm.
diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py
index 82a6250f1fd1..3a05a34f52d2 100644
--- a/colossalai/nn/optimizer/fused_adam.py
+++ b/colossalai/nn/optimizer/fused_adam.py
@@ -8,11 +8,9 @@
'''
import torch
-from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier
-@OPTIMIZERS.register_module
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py
index 72520064e98b..a2807d70f454 100644
--- a/colossalai/nn/optimizer/fused_lamb.py
+++ b/colossalai/nn/optimizer/fused_lamb.py
@@ -1,11 +1,9 @@
# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_lamb.py
import torch
-from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier
-@OPTIMIZERS.register_module
class FusedLAMB(torch.optim.Optimizer):
"""Implements LAMB algorithm.
diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py
index 468713b223c1..59a93a8be9c7 100644
--- a/colossalai/nn/optimizer/fused_sgd.py
+++ b/colossalai/nn/optimizer/fused_sgd.py
@@ -2,11 +2,9 @@
import torch
from torch.optim.optimizer import Optimizer, required
-from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier
-@OPTIMIZERS.register_module
class FusedSGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py
index 84903ac36832..e08df410effe 100644
--- a/colossalai/nn/optimizer/hybrid_adam.py
+++ b/colossalai/nn/optimizer/hybrid_adam.py
@@ -4,13 +4,11 @@
from torch.optim import Adam
from colossalai.kernel.op_builder import FusedOptimBuilder
-from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier
from .cpu_adam import CPUAdam
-@OPTIMIZERS.register_module
class HybridAdam(CPUAdam):
"""Implements Adam algorithm.
diff --git a/colossalai/nn/optimizer/lamb.py b/colossalai/nn/optimizer/lamb.py
index 399ad39b6658..d5de267f73ee 100644
--- a/colossalai/nn/optimizer/lamb.py
+++ b/colossalai/nn/optimizer/lamb.py
@@ -5,10 +5,7 @@
import torch
from torch.optim import Optimizer
-from colossalai.registry import OPTIMIZERS
-
-@OPTIMIZERS.register_module
class Lamb(Optimizer):
r"""Implements Lamb algorithm.
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
diff --git a/colossalai/nn/optimizer/lars.py b/colossalai/nn/optimizer/lars.py
index 212f66671a0d..58393fdae4bf 100644
--- a/colossalai/nn/optimizer/lars.py
+++ b/colossalai/nn/optimizer/lars.py
@@ -5,10 +5,7 @@
import torch
from torch.optim import Optimizer
-from colossalai.registry import OPTIMIZERS
-
-@OPTIMIZERS.register_module
class Lars(Optimizer):
r"""Implements the LARS optimizer from `"Large batch training of convolutional networks"
`_.
@@ -22,28 +19,24 @@ class Lars(Optimizer):
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
"""
- def __init__(
- self,
- params: Iterable[torch.nn.Parameter],
- lr=1e-3,
- momentum=0,
- eeta=1e-3,
- weight_decay=0,
- epsilon=0.0
- ) -> None:
+ def __init__(self,
+ params: Iterable[torch.nn.Parameter],
+ lr=1e-3,
+ momentum=0,
+ eeta=1e-3,
+ weight_decay=0,
+ epsilon=0.0) -> None:
if not isinstance(lr, float) or lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
- raise ValueError(
- "Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if eeta <= 0 or eeta > 1:
raise ValueError("Invalid eeta value: {}".format(eeta))
if epsilon < 0:
raise ValueError("Invalid epsilon value: {}".format(epsilon))
- defaults = dict(lr=lr, momentum=momentum,
- weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True)
+ defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True)
super().__init__(params, defaults)
@@ -76,11 +69,9 @@ def step(self, closure=None):
if lars:
w_norm = torch.norm(p)
g_norm = torch.norm(p.grad)
- trust_ratio = torch.where(
- w_norm > 0 and g_norm > 0,
- eeta * w_norm / (g_norm + weight_decay * w_norm + eps),
- torch.ones_like(w_norm)
- )
+ trust_ratio = torch.where(w_norm > 0 and g_norm > 0,
+ eeta * w_norm / (g_norm + weight_decay * w_norm + eps),
+ torch.ones_like(w_norm))
trust_ratio.clamp_(0.0, 50)
scaled_lr *= trust_ratio.item()
if weight_decay != 0:
@@ -90,8 +81,7 @@ def step(self, closure=None):
if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
- buf = param_state['momentum_buffer'] = torch.clone(
- decayed_grad).detach()
+ buf = param_state['momentum_buffer'] = torch.clone(decayed_grad).detach()
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(decayed_grad)
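
The LARS changes above are purely cosmetic; the update rule is unchanged: when both the weight norm and the gradient norm are positive, the layer-wise learning rate is scaled by `trust_ratio = eeta * ||w|| / (||g|| + weight_decay * ||w|| + eps)`, clamped to `[0, 50]`. A small numeric sketch of that computation (values are made up for illustration):

```python
# Numeric sketch of the LARS trust ratio computed in Lars.step() above.
import torch

w = torch.randn(128, 64)
g = torch.randn(128, 64) * 0.01
eeta, weight_decay, eps = 1e-3, 1e-4, 0.0

w_norm, g_norm = torch.norm(w), torch.norm(g)
trust_ratio = torch.where(
    (w_norm > 0) & (g_norm > 0),
    eeta * w_norm / (g_norm + weight_decay * w_norm + eps),
    torch.ones_like(w_norm),
)
trust_ratio.clamp_(0.0, 50)
scaled_lr = 1e-2 * trust_ratio.item()    # layer-wise learning rate actually applied
print(trust_ratio.item(), scaled_lr)
```
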
diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
new file mode 100644
index 000000000000..aed85cf91512
--- /dev/null
+++ b/colossalai/pipeline/p2p.py
@@ -0,0 +1,222 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import io
+import pickle
+import re
+from typing import Any, List, Optional, Union
+
+import torch
+import torch.distributed as dist
+from packaging.version import Version
+from torch.distributed import ProcessGroup
+from torch.distributed import distributed_c10d as c10d
+
+from .stage_manager import PipelineStageManager
+
+_unpickler = pickle.Unpickler
+
+
+def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object:
+    """Transform a tensor back into a Python object by unpickling it.
+    The device info embedded in the byte stream is rewritten to the current device before unpickling.
+
+    Args:
+        tensor (:class:`torch.Tensor`): tensor holding the pickled bytes
+        tensor_size (:class:`torch.Size`): number of valid bytes in the tensor
+
+    Returns:
+        Any: the unpickled object
+ """
+ buf = tensor.numpy().tobytes()[:tensor_size]
+ if b'cuda' in buf:
+ buf_array = bytearray(buf)
+ device_index = torch.cuda.current_device()
+ # There might be more than one output tensors during forward
+ for cuda_str in re.finditer(b'cuda', buf_array):
+ pos = cuda_str.start()
+ buf_array[pos + 5] = 48 + device_index
+ buf = bytes(buf_array)
+
+ io_bytes = io.BytesIO(buf)
+ byte_pickler = _unpickler(io_bytes)
+ unpickle = byte_pickler.load()
+
+ return unpickle
+
+
+def _broadcast_object_list(object_list: List[Any],
+ src: int,
+ group: ProcessGroup,
+ device: Optional[Union[torch.device, str, int]] = None):
+    """A modified version of ``broadcast_object_list`` from ``torch.distributed``.
+    The only difference is that each object is moved to the correct device after being unpickled.
+    If the local rank equals ``src``, the object list is broadcast from this rank; otherwise the
+    object list is updated in place with the data received from rank ``src``.
+
+    Args:
+        object_list (List[Any]): list of objects to broadcast
+        src (int): source rank of the broadcast
+        group (ProcessGroup): process group used for the broadcast
+        device (:class:`torch.device`, optional): device used for the broadcast; defaults to the current device
+
+ """
+
+ if c10d._rank_not_in_group(group):
+ c10d._warn_not_in_group("broadcast_object_list")
+ return
+
+ is_nccl_backend = c10d._check_for_nccl_backend(group)
+ current_device = None
+
+ if device is not None:
+ if is_nccl_backend and device.type != "cuda":
+ raise ValueError("device type must be cuda for nccl backend")
+ current_device = device
+ else:
+ current_device = torch.device("cpu")
+ if is_nccl_backend:
+ current_device = torch.device("cuda", torch.cuda.current_device())
+
+ my_rank = dist.get_rank()
+ # Serialize object_list elements to tensors on src rank.
+ if my_rank == src:
+ if Version(torch.__version__) >= Version("1.13.0"):
+ tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
+ else:
+ tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
+ object_sizes_tensor = torch.cat(size_list)
+ else:
+ object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
+
+ if is_nccl_backend:
+ object_sizes_tensor = object_sizes_tensor.to(current_device)
+
+ # Broadcast object sizes
+ c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False)
+
+ # Concatenate and broadcast serialized object tensors
+ if my_rank == src:
+ object_tensor = torch.cat(tensor_list)
+ else:
+ object_tensor = torch.empty( # type: ignore[call-overload]
+ torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type]
+ dtype=torch.uint8,
+ )
+
+ if is_nccl_backend:
+ object_tensor = object_tensor.to(current_device)
+
+ c10d.broadcast(object_tensor, src=src, group=group, async_op=False)
+
+ # Deserialize objects using their stored sizes.
+ offset = 0
+
+ if my_rank != src:
+ for i, obj_size in enumerate(object_sizes_tensor):
+ obj_view = object_tensor[offset:offset + obj_size]
+ obj_view = obj_view.type(torch.uint8)
+ if obj_view.device != torch.device("cpu"):
+ obj_view = obj_view.cpu()
+ offset += obj_size
+ # unpickle
+ unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)
+
+                # fix device inconsistency between the sender and this rank
+ if isinstance(unpickle_object,
+ torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
+ unpickle_object = unpickle_object.cuda()
+
+ object_list[i] = unpickle_object
+
+
+def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
+    """Send any picklable object to the dst rank.
+
+    Args:
+        object (Any): object to be sent
+        src (int): rank of the sender
+        dst (int): rank of the destination
+        group (ProcessGroup): process group used for the transfer
+
+ Returns:
+ None
+ """
+ # then broadcast safely
+ _broadcast_object_list([object], src, group)
+
+
+def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
+    """Receive any picklable object from the src rank.
+
+    Args:
+        src (int): source rank of the data; the local rank receives data from this rank
+        dst (int): rank of the receiver
+        group (ProcessGroup): process group used for the transfer
+
+ Returns:
+ Any: Object received from src.
+ """
+ object_list = [None]
+ _broadcast_object_list(object_list, src, group)
+
+ return object_list[0]
+
+
+class PipelineP2PCommunication:
+
+ def __init__(self, stage_manager: PipelineStageManager) -> None:
+ self.stage_manager = stage_manager
+
+ def recv_forward(self, prev_rank: int = None) -> Any:
+ """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
+
+ Args:
+ prev_rank (int, optional): The rank of the source of the tensor.
+
+ Returns:
+ Any: The input tensor or input tensor list.
+ """
+ if prev_rank is None:
+ prev_rank = self.stage_manager.get_prev_rank()
+ cur_rank = self.stage_manager.get_rank()
+ input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
+
+ return input_tensor
+
+ def recv_backward(self, next_rank: int = None) -> Any:
+ """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
+
+ Args:
+ next_rank (int, optional): The rank of the source of the tensor.
+
+ Returns:
+ Any: The input gradient tensor or gradient tensor list.
+ """
+ if next_rank is None:
+ next_rank = self.stage_manager.get_next_rank()
+ cur_rank = self.stage_manager.get_rank()
+ output_tensor_grad = _recv_object(next_rank, cur_rank,
+ self.stage_manager.get_p2p_process_group(next_rank, cur_rank))
+
+ return output_tensor_grad
+
+ def send_forward(self, output_object: Any, next_rank: int = None) -> None:
+ """Sends the input tensor to the next stage in pipeline.
+
+ Args:
+ output_object (Any): Object to be sent.
+ next_rank (int, optional): The rank of the recipient of the tensor.
+ """
+ if next_rank is None:
+ next_rank = self.stage_manager.get_next_rank()
+ cur_rank = self.stage_manager.get_rank()
+ _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
+
+ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
+ """Sends the gradient tensor to the previous stage in pipeline.
+
+ Args:
+ input_object (Any): Object to be sent.
+ prev_rank (int, optional): The rank of the recipient of the tensor
+ """
+ if prev_rank is None:
+ prev_rank = self.stage_manager.get_prev_rank()
+ cur_rank = self.stage_manager.get_rank()
+ _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
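
The key trick in the new `p2p.py` is `_cuda_safe_tensor_to_object`: it rewrites the device index inside the pickled byte stream so that tensors serialized as, say, `cuda:0` on the sender are restored on the receiver's current device. The standalone sketch below demonstrates the byte-patching idea on a plain pickled string; the variable names are illustrative only.

```python
# Standalone sketch of the byte-patching idea used by _cuda_safe_tensor_to_object:
# the pickled payload contains literal b'cuda:<index>' strings, and the digit after
# 'cuda:' is overwritten with the receiver's device index before unpickling.
import io
import pickle
import re

payload = {"device": "cuda:0", "note": "serialized on the sender"}
buf = pickle.dumps(payload)

target_device_index = 1            # stand-in for torch.cuda.current_device() on the receiver
buf_array = bytearray(buf)
for match in re.finditer(b"cuda", buf_array):
    pos = match.start()
    buf_array[pos + 5] = 48 + target_device_index    # ord('0') == 48, patch the digit

patched = pickle.Unpickler(io.BytesIO(bytes(buf_array))).load()
print(patched["device"])    # -> 'cuda:1'
```
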
diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/pipeline/pipelinable.py
index 79913987b7cc..ba8b1591da9d 100644
--- a/colossalai/pipeline/pipelinable.py
+++ b/colossalai/pipeline/pipelinable.py
@@ -1,15 +1,24 @@
-import torch
import inspect
-from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
-from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, \
- build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, \
- call_module, customized_partition
-from colossalai.nn.layer.utils import CheckpointModule
-from colossalai.tensor import ColoParameter
-from colossalai.core import global_context as gpc
+import torch
+
from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.nn.layer.utils import CheckpointModule
+from colossalai.tensor import ColoParameter
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
+
from .layer_spec import LayerSpec
+from .utils import (
+ build_kwargs_for_function,
+ build_kwargs_for_module,
+ call_module,
+ customized_partition,
+ exec_func_with_kwargs,
+ exec_funcs_with_kwargs,
+ partition_balanced,
+ partition_uniform,
+)
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
diff --git a/colossalai/pipeline/schedule/__init__.py b/colossalai/pipeline/schedule/__init__.py
new file mode 100644
index 000000000000..8b13413b1a31
--- /dev/null
+++ b/colossalai/pipeline/schedule/__init__.py
@@ -0,0 +1,7 @@
+from .base import PipelineSchedule
+from .one_f_one_b import OneForwardOneBackwardSchedule
+
+__all__ = [
+ 'PipelineSchedule',
+ 'OneForwardOneBackwardSchedule',
+]
diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py
new file mode 100644
index 000000000000..583558551b3c
--- /dev/null
+++ b/colossalai/pipeline/schedule/_utils.py
@@ -0,0 +1,184 @@
+from collections import OrderedDict
+from typing import Any, List, Optional, Tuple
+
+import torch
+import torch.cuda
+from torch.nn import Module
+from torch.utils._pytree import (
+ SUPPORTED_NODES,
+ LeafSpec,
+ TreeSpec,
+ _is_leaf,
+ _register_pytree_node,
+ tree_flatten,
+ tree_map,
+ tree_unflatten,
+)
+
+
+# this registration is needed for torch versions below 1.13.1 and may be removed in the future
+def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]:
+ return list(d.values()), list(d.keys())
+
+
+def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]':
+ return OrderedDict((key, value) for key, value in zip(context, values))
+
+
+_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
+
+
+def tree_map_hf(fn: Any, pytree: Any):
+ flat_args, spec = tree_flatten_hf(pytree)
+ return tree_unflatten([fn(i) for i in flat_args], spec)
+
+
+# use this flatten function to handle ModelingOutput class instances (OrderedDict subclasses).
+def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
+    """Flattens a pytree into a list of values and a TreeSpec that can be used
+ to reconstruct the pytree.
+ """
+ if isinstance(pytree, OrderedDict):
+ node_type = OrderedDict
+ flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
+ child_pytrees, context = flatten_fn(pytree)
+
+ # Recursively flatten the children
+ result: List[Any] = []
+ children_specs: List['TreeSpec'] = []
+ for child in child_pytrees:
+ flat, child_spec = tree_flatten_hf(child)
+ result += flat
+ children_specs.append(child_spec)
+ return result, TreeSpec(node_type, context, children_specs)
+ else:
+ result, tree_spec = tree_flatten(pytree)
+ return result, tree_spec
+
+
+def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
+ """Move object to device if it is a tensor.
+
+ Args:
+ x (Any): Object to be moved.
+ device (Optional[torch.device], optional): Target device. Defaults to None.
+
+ Returns:
+ Any: Moved object.
+ """
+ if isinstance(x, torch.Tensor):
+ return x.to(device)
+ return x
+
+
+def get_batch_size(batch: Any) -> int:
+ """Get the batch size (size of dimension-0) of the first tensor in the batch.
+
+ Args:
+ batch (Any): Batch to be inspected.
+
+ Raises:
+ RuntimeError: If no tensor is found in the batch.
+
+ Returns:
+ int: Batch size.
+ """
+ data_list, _ = tree_flatten(batch)
+ for data in data_list:
+ if isinstance(data, torch.Tensor):
+ return data.size(0)
+ raise RuntimeError('No tensor found in the batch')
+
+
+def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any:
+ """Get a micro batch of the original batch.
+
+ Args:
+ batch (Any): Batch to be sliced.
+ start (int): Start index of the micro batch.
+ micro_batch_size (int): Size of the micro batch.
+
+ Returns:
+ Any: Target micro batch.
+ """
+
+ def _get_tensor_slice(x: Any):
+ if isinstance(x, torch.Tensor):
+ return x[start:start + micro_batch_size]
+ return x
+
+ return tree_map(_get_tensor_slice, batch)
+
+
+def model_forward(model: Module, data: Any, internal_inputs: Optional[dict]) -> Any:
+ """Call model forward function with data and internal inputs.
+
+ Args:
+ model (Module): Model to be called.
+ data (Any): Data loaded from data iterator.
+ internal_inputs (Optional[dict]): Data from previous stage. It must be a dict or None if it's the first stage.
+
+ Returns:
+ Any: Outputs of the model.
+ """
+ if internal_inputs is None:
+ internal_inputs = {}
+ if isinstance(data, (list, tuple)):
+ return model(*data, **internal_inputs)
+ elif isinstance(data, dict):
+ return model(**data, **internal_inputs)
+ return model(data, **internal_inputs)
+
+
+def retain_grad(x: Any) -> None:
+ """Call retain_grad() on a tensor.
+
+ Args:
+ x (Any): Object to be called.
+ """
+ if isinstance(x, torch.Tensor) and x.requires_grad:
+ x.retain_grad()
+
+
+def detach(x: Any) -> Any:
+ """Call detach() on a tensor.
+
+ Args:
+ x (Any): Object to be called.
+
+ Returns:
+ Any: The detached object.
+ """
+ if isinstance(x, torch.Tensor):
+ return x.detach()
+ return x
+
+
+def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
+ """Merge micro batches into a batch.
+
+ Args:
+        data (List[Any]): A list of micro batches.
+        batch_size_dim (int, optional): Dimension along which tensors are concatenated. Defaults to 0.
+
+    Returns:
+        Any: The merged batch.
+ """
+ if len(data) == 0:
+ return
+ flattened_data = []
+ tree_spec = None
+ for d in data:
+ # elems should be an instance of OrderedDict
+ elems, tree_spec = tree_flatten_hf(d)
+ flattened_data.append(elems)
+ merged_data = []
+
+ for elem_batch in zip(*flattened_data):
+ if isinstance(elem_batch[0], torch.Tensor):
+ if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs
+ merged_data.append(None)
+ else:
+ merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
+ else:
+ merged_data.append(list(elem_batch))
+ return tree_unflatten(merged_data, tree_spec)
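
Because these helpers go through `torch.utils._pytree`, a batch can be an arbitrary nesting of dicts, lists and tensors. The sketch below reproduces `get_batch_size` and `get_micro_batch` verbatim from the file above so it runs without ColossalAI installed:

```python
# Quick demonstration of the micro-batching helpers defined above, reproduced standalone.
import torch
from torch.utils._pytree import tree_flatten, tree_map


def get_batch_size(batch):
    data_list, _ = tree_flatten(batch)
    for data in data_list:
        if isinstance(data, torch.Tensor):
            return data.size(0)
    raise RuntimeError('No tensor found in the batch')


def get_micro_batch(batch, start, micro_batch_size):
    def _get_tensor_slice(x):
        if isinstance(x, torch.Tensor):
            return x[start:start + micro_batch_size]
        return x
    return tree_map(_get_tensor_slice, batch)


batch = {"input_ids": torch.arange(8).reshape(8, 1), "labels": torch.arange(8)}
num_microbatches = 4
micro_bs = get_batch_size(batch) // num_microbatches      # -> 2
first = get_micro_batch(batch, 0, micro_bs)
print(first["input_ids"].shape, first["labels"])          # torch.Size([2, 1]) tensor([0, 1])
```
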
diff --git a/colossalai/pipeline/schedule/base.py b/colossalai/pipeline/schedule/base.py
new file mode 100644
index 000000000000..b0fa6e6ad2b8
--- /dev/null
+++ b/colossalai/pipeline/schedule/base.py
@@ -0,0 +1,35 @@
+from typing import Any, Callable, Iterable, Optional
+
+from torch import Tensor
+from torch.nn import Module
+
+from colossalai.interface import OptimizerWrapper
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+
+class PipelineSchedule:
+
+ def __init__(self, stage_manager: PipelineStageManager) -> None:
+ self.stage_manager = stage_manager
+
+ def forward_backward_step(self,
+ model: Module,
+ data_iter: Iterable,
+ criterion: Callable[[Any, Any], Tensor],
+ optimizer: Optional[OptimizerWrapper] = None,
+ return_loss: bool = False,
+ return_outputs: bool = False) -> dict:
+ """Forward and backward step for pipeline training.
+
+ Args:
+ model (Module): Model to be trained.
+ data_iter (Iterable): Data iterator.
+ criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+ optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+            return_loss (bool, optional): Whether to return loss. Defaults to False.
+            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
+
+ Returns:
+ dict: A dict with keys: 'loss' and 'outputs'.
+ """
+ raise NotImplementedError
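
`PipelineSchedule` only fixes the interface; the concrete schedules below must slice the batch into micro batches, drive forward/backward, and return a dict with the keys `'loss'` and `'outputs'`. The toy sketch that follows is hypothetical: it runs on a single process with no pipeline communication and is meant only to illustrate that contract, not the real schedules.

```python
# Toy illustration (single process, no p2p communication) of the forward_backward_step
# contract: iterate over micro batches, accumulate the scaled loss, return {'loss', 'outputs'}.
import torch


def naive_forward_backward_step(model, data_iter, criterion, optimizer=None,
                                return_loss=True, num_microbatches=4):
    batch_x, batch_y = next(data_iter)
    micro_size = batch_x.size(0) // num_microbatches
    accum_loss = torch.zeros(1)
    outputs = []
    for i in range(num_microbatches):
        x = batch_x[i * micro_size:(i + 1) * micro_size]
        y = batch_y[i * micro_size:(i + 1) * micro_size]
        out = model(x)
        loss = criterion(out, y) / num_microbatches   # scale so the sum matches the full-batch loss
        if optimizer is not None:
            loss.backward()
        accum_loss += loss.detach()
        outputs.append(out.detach())
    return {'loss': accum_loss if return_loss else None, 'outputs': torch.cat(outputs)}


model = torch.nn.Linear(8, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_iter = iter([(torch.randn(16, 8), torch.randint(0, 3, (16,)))])
result = naive_forward_backward_step(model, data_iter, torch.nn.CrossEntropyLoss(), optimizer)
optimizer.step()
print(result['loss'], result['outputs'].shape)
```
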
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
new file mode 100644
index 000000000000..6fdb09be5f32
--- /dev/null
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -0,0 +1,372 @@
+from functools import partial
+from typing import Any, Callable, Iterable, List, Optional, Union
+
+import torch
+import torch.cuda
+from torch.nn import Module
+from torch.utils._pytree import tree_map
+
+from colossalai.interface import OptimizerWrapper
+from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.utils.cuda import get_current_device
+
+from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
+from .base import PipelineSchedule
+
+
+class InterleavedSchedule(PipelineSchedule):
+
+ def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
+ self.num_model_chunks = num_model_chunks
+ assert num_microbatches % self.num_model_chunks == 0, \
+ "Number of microbatches should be an integer multiple of number of model chunks"
+ super().__init__(stage_manager)
+ self.comm = PipelineP2PCommunication(stage_manager)
+ self.num_microbatches = num_microbatches
+ self.batch: Optional[Any] = None
+ self.batch_size: Optional[int] = None
+ self.microbatch_offset: Optional[int] = None
+ self.microbatch_size: Optional[int] = None
+
+ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
+ """Load a batch from data iterator.
+
+ Args:
+ data_iter (Iterable): Data iterator.
+ device (Optional[torch.device], optional): Target device. Defaults to None.
+ """
+ batch = next(data_iter)
+ if device is not None:
+ batch = tree_map(partial(to_device, device=device), batch)
+ self.batch = batch
+ self.batch_size = get_batch_size(batch)
+ self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
+ assert self.batch_size % self.num_microbatches == 0, \
+            "Batch size should be divisible by the number of microbatches"
+ self.microbatch_size = self.batch_size // self.num_microbatches
+
+ def load_micro_batch(self, model_chunk_id: int) -> Any:
+ """Load a micro batch from the current batch.
+
+ Args:
+            model_chunk_id (int): The current model chunk index.
+
+ Returns:
+ Any: Micro batch.
+ """
+ micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
+ self.microbatch_offset[model_chunk_id] += self.microbatch_size
+ return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+
+ def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
+ """Helper method to get the model chunk ID given the iteration number.
+
+ Args:
+ microbatch_id (int): the current microbatch idx
+ forward (bool): if is the forward process
+
+ Returns:
+ int: The model chunk idx of the input microbatch_id
+ """
+ microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
+ model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
+ if not forward:
+ model_chunk_id = (self.num_model_chunks - model_chunk_id - 1)
+ return model_chunk_id
+
+ def is_first_stage(self, model_chunk_id: int) -> bool:
+ """Is the current virtual stage the first stage
+
+ Args:
+ model_chunk_id (int): The current model chunk idx.
+
+ Returns:
+ bool: Whether the current virtual stage is the first stage.
+ """
+ if self.stage_manager.is_first_stage() and model_chunk_id == 0:
+ return True
+ return False
+
+ def is_last_stage(self, model_chunk_id: int) -> bool:
+ """Is the current virtual stage the last stage
+
+ Args:
+ model_chunk_id (int): The current model chunk idx.
+
+ Returns:
+ bool: Whether the current virtual stage is the last stage.
+ """
+ if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1:
+ return True
+ return False
+
+ def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
+ """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
+ For interleaved 1F1B.
+
+ Args:
+ model_chunk_id (int): The current model chunk idx.
+ prev_rank (int, optional): The rank of the source of the tensor.
+
+ Returns:
+ Any: The input tensor or input tensor list.
+ """
+ if self.is_first_stage(model_chunk_id):
+ input_tensor = None
+ else:
+ input_tensor = self.comm.recv_forward(prev_rank)
+
+ return input_tensor
+
+ def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
+ """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
+ For interleaved 1F1B.
+
+ Args:
+ model_chunk_id (int): The current model chunk idx.
+ next_rank (int, optional): The rank of the source of the tensor.
+
+ Returns:
+ Any: The input gradient tensor or gradient tensor list.
+ """
+ if self.is_last_stage(model_chunk_id):
+ output_tensor_grad = None
+ else:
+ output_tensor_grad = self.comm.recv_backward(next_rank)
+
+ return output_tensor_grad
+
+ def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
+ """Sends the input tensor to the next stage in pipeline.
+ For interleaved 1F1B.
+
+ Args:
+ model_chunk_id (int): The current model chunk idx.
+ output_object (Any): Object to be sent.
+ next_rank (int, optional): The rank of the recipient of the tensor.
+ """
+ if not self.is_last_stage(model_chunk_id):
+ self.comm.send_forward(output_object, next_rank)
+
+ def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
+ """Sends the gradient tensor to the previous stage in pipeline.
+ For interleaved 1F1B.
+
+ Args:
+ model_chunk_id (int): The current model chunk idx.
+ input_object (Any): Object to be sent.
+ prev_rank (int, optional): The rank of the recipient of the tensor
+ """
+ if not self.is_first_stage(model_chunk_id):
+ self.comm.send_backward(input_object, prev_rank)
+
+ def forward_step(self,
+ model_chunk: Module,
+ model_chunk_id: int,
+ input_obj: Optional[dict],
+ criterion: Callable,
+ accum_loss: Optional[torch.Tensor] = None,
+ outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]:
+ """Forward one step of the pipeline
+ Args:
+            model_chunk (Module): The model chunk to be run.
+            model_chunk_id (int): The current model chunk index.
+ input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
+ criterion (Callable): Criterion to calculate loss.
+ accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
+ outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
+
+ Returns:
+ Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
+ """
+ micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
+
+ # for the first stage, input_obj is None
+        # for the non-first stage, input_obj is the output of the previous stage and it must be a dict
+ output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
+
+ if self.is_last_stage(model_chunk_id):
+ loss = criterion(output_obj, micro_batch) / self.num_microbatches
+ if accum_loss is not None:
+ accum_loss.add_(loss.detach())
+ if outputs is not None:
+ outputs.append(tree_map(detach, output_obj))
+ return loss
+ else:
+ return output_obj
+
+ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
+ output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]:
+ """Backward one step of the pipeline
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to update the model
+ input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
+ output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
+ output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
+
+ Returns:
+ Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
+ """
+
+ # Retain the grad on the input_obj.
+ tree_map(retain_grad, input_obj)
+
+ # Backward pass.
+ if output_obj_grad is None:
+ optimizer.backward(output_obj)
+ else:
+ if "backward_tensor_keys" not in output_obj:
+ for k, grad in output_obj_grad.items():
+ optimizer.backward_by_grad(output_obj[k], grad)
+ else:
+ for k, grad in output_obj_grad.items():
+ output_obj[k].grad = grad
+ for k in output_obj["backward_tensor_keys"]:
+ tensor_to_backward = output_obj[k]
+ optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+
+ # Collect the grad of the input_obj.
+ input_obj_grad = None
+ if input_obj is not None:
+ input_obj_grad = {}
+ for k, v in input_obj.items():
+ if isinstance(v, torch.Tensor) and v.grad is not None:
+ input_obj_grad[k] = v.grad
+ return input_obj_grad
+
+ def forward_backward_step(self,
+ model_chunk: Module,
+ data_iter: Iterable,
+ criterion: Callable[..., Any],
+ optimizer: Optional[OptimizerWrapper] = None,
+ return_loss: bool = False,
+ return_outputs: bool = False) -> dict:
+ """Runs interleaved 1F1B schedule, with communication between pipeline stages.
+
+ Args:
+ model_chunk (List[Module]): Model Chunk to be trained.
+ data_iter (Iterable): Data iterator.
+ criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
+ optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+            return_loss (bool, optional): Whether to return loss. Defaults to False.
+            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
+
+ Returns:
+ dict: A dict with keys: 'loss' and 'outputs'.
+ """
+ forward_only = not torch.is_grad_enabled()
+ if optimizer is None:
+ assert forward_only, "Optimizer should be passed when doing backward."
+
+ self.load_batch(data_iter)
+ num_model_chunks = len(model_chunk)
+
+        # num_warmup_microbatches is the number of warmup steps during which not all stages are active yet
+ num_microbatches = self.num_microbatches * num_model_chunks
+ if forward_only:
+ num_warmup_microbatches = num_microbatches
+ else:
+ num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
+ num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages
+ num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
+
+ num_microbatches_remaining = num_microbatches - num_warmup_microbatches
+
+ # Input, output tensors only need to be saved when doing backward passes
+ input_objs = None
+ output_objs = None
+
+ if not forward_only:
+ input_objs = [[] for _ in range(num_model_chunks)]
+ output_objs = [[] for _ in range(num_model_chunks)]
+
+ outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
+
+ if return_loss and self.stage_manager.is_last_stage():
+ accum_loss = torch.zeros(1, device=get_current_device())
+ else:
+ accum_loss = None
+
+ # for ranks except the first one, get into recv state
+ # print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining)
+ input_obj = self.recv_forward(0)
+ input_objs[0].append(input_obj)
+ # Run warmup forward passes.
+ for i in range(num_warmup_microbatches):
+ model_chunk_id = self.get_model_chunk_id(i, forward=True)
+
+            # recv first on the first rank to avoid sending and receiving at the same time
+ if self.stage_manager.is_first_stage():
+ input_obj = self.recv_forward(model_chunk_id)
+ output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+ self.send_forward(model_chunk_id, output_obj)
+ if not forward_only:
+ input_objs[model_chunk_id].append(input_obj)
+ output_objs[model_chunk_id].append(output_obj)
+ else:
+ output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+ if not forward_only:
+ output_objs[model_chunk_id].append(output_obj)
+ self.send_forward(model_chunk_id, output_obj)
+ if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches:
+ break
+ else:
+ model_chunk_id = self.get_model_chunk_id(i + 1, forward=True)
+
+ input_obj = self.recv_forward(model_chunk_id)
+ if not forward_only:
+ input_objs[model_chunk_id].append(input_obj)
+
+ # Run 1F1B in steady state.
+ for i in range(num_microbatches_remaining):
+ model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
+ last_iteration = (i == (num_microbatches_remaining - 1))
+
+ output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
+ if forward_only:
+ self.send_forward(model_chunk_id, output_obj)
+
+ if not last_iteration:
+ input_obj = self.recv_forward(model_chunk_id)
+
+ else:
+ self.send_forward(model_chunk_id, output_obj)
+ # Add input_obj and output_obj to end of list.
+ input_objs[model_chunk_id].append(input_obj)
+ output_objs[model_chunk_id].append(output_obj)
+
+ model_chunk_id = self.get_model_chunk_id(i, forward=False)
+ output_obj_grad = self.recv_backward(model_chunk_id)
+
+                # Pop input_obj and output_obj from the start of the list for
+ # the backward pass.
+ input_obj = input_objs[model_chunk_id].pop(0)
+ output_obj = output_objs[model_chunk_id].pop(0)
+
+ # backward
+ input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
+
+ if last_iteration:
+ input_obj = None
+ else:
+ model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True)
+ input_obj = self.recv_forward(model_chunk_id)
+ model_chunk_id = self.get_model_chunk_id(i, forward=False)
+ self.send_backward(model_chunk_id, input_obj_grad)
+
+ # Run cooldown backward passes.
+ if not forward_only:
+ for i in range(num_microbatches_remaining, num_microbatches):
+ model_chunk_id = self.get_model_chunk_id(i, forward=False)
+ # print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}")
+ input_obj = input_objs[model_chunk_id].pop(0)
+ output_obj = output_objs[model_chunk_id].pop(0)
+
+ output_obj_grad = self.recv_backward(model_chunk_id)
+ input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
+ self.send_backward(model_chunk_id, input_obj_grad)
+
+ if outputs is not None:
+ outputs = merge_batch(outputs)
+ return {'loss': accum_loss, 'outputs': outputs}
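
In the interleaved schedule, `get_model_chunk_id` maps a microbatch index to the virtual stage (model chunk) that should process it, and the mapping is reversed for the backward pass. The sketch below reproduces that mapping standalone for a toy configuration, purely to make the cycling pattern visible:

```python
# Standalone reproduction of InterleavedSchedule.get_model_chunk_id for a toy configuration.
def get_model_chunk_id(microbatch_id, num_stages, num_model_chunks, forward=True):
    microbatch_id_in_group = microbatch_id % (num_stages * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // num_stages
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id


num_stages, num_model_chunks = 4, 2
print([get_model_chunk_id(i, num_stages, num_model_chunks) for i in range(16)])
# forward:  [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1]
print([get_model_chunk_id(i, num_stages, num_model_chunks, forward=False) for i in range(16)])
# backward: [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]
```
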
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
new file mode 100644
index 000000000000..fbd0f9f0d4c0
--- /dev/null
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -0,0 +1,320 @@
+from functools import partial
+from typing import Any, Callable, Iterable, List, Optional, Union
+
+import torch
+import torch.cuda
+from torch.nn import Module
+from torch.utils._pytree import tree_map
+
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.pipeline.p2p import PipelineP2PCommunication
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.utils.cuda import get_current_device
+
+from ._utils import (
+ detach,
+ get_batch_size,
+ get_micro_batch,
+ merge_batch,
+ model_forward,
+ retain_grad,
+ to_device,
+ tree_map_hf,
+)
+from .base import PipelineSchedule
+
+
+class OneForwardOneBackwardSchedule(PipelineSchedule):
+
+ def __init__(self,
+ stage_manager: PipelineStageManager,
+ num_microbatches: Optional[int] = None,
+ microbatch_size: Optional[int] = None) -> None:
+ """1F1B pipeline schedule.
+
+ Args:
+ stage_manager (PipelineStageManager): Pipeline stage manager
+ num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None.
+ microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
+ """
+ super().__init__(stage_manager)
+ assert num_microbatches is not None or microbatch_size is not None, \
+ "Either num_microbatches or microbatch_size should be provided"
+ self.comm = PipelineP2PCommunication(stage_manager)
+ self.num_microbatches = num_microbatches
+ self.microbatch_size = microbatch_size
+ self.batch: Optional[Any] = None
+ self.batch_size: Optional[int] = None
+ self.microbatch_offset: Optional[int] = None
+ self._use_microbatch_size = num_microbatches is None
+
+ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
+ """Load a batch from data iterator.
+
+ Args:
+ data_iter (Iterable): Data iterator.
+ device (Optional[torch.device], optional): Target device. Defaults to None.
+ """
+ batch = next(data_iter)
+ if device is not None:
+ batch = tree_map(partial(to_device, device=device), batch)
+ self.batch = batch
+ self.batch_size = get_batch_size(batch)
+ self.microbatch_offset = 0
+ if not self._use_microbatch_size:
+ assert self.batch_size % self.num_microbatches == 0, \
+ "Batch size should divided by the number of microbatches"
+ self.microbatch_size = self.batch_size // self.num_microbatches
+ else:
+ assert self.batch_size % self.microbatch_size == 0, \
+ "Batch size should divided by the microbatch size"
+ self.num_microbatches = self.batch_size // self.microbatch_size
+
+ def load_micro_batch(self) -> Any:
+ """Load a micro batch from the current batch.
+
+ Returns:
+ Any: Micro batch.
+ """
+ micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
+ self.microbatch_offset += self.microbatch_size
+ return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+
+ def recv_forward(self, prev_rank: int = None) -> Any:
+ """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
+ For 1F1B.
+
+ Args:
+ prev_rank (int, optional): The rank of the source of the tensor.
+
+ Returns:
+ Any: The input tensor or input tensor list.
+ """
+ if self.stage_manager.is_first_stage():
+ input_tensor = None
+ else:
+ input_tensor = self.comm.recv_forward(prev_rank)
+
+ return input_tensor
+
+ def recv_backward(self, next_rank: int = None) -> Any:
+ """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
+ For 1F1B.
+
+ Args:
+ next_rank (int, optional): The rank of the source of the tensor.
+
+ Returns:
+ Any: The input gradient tensor or gradient tensor list.
+ """
+ if self.stage_manager.is_last_stage():
+ output_tensor_grad = None
+ else:
+ output_tensor_grad = self.comm.recv_backward(next_rank)
+
+ return output_tensor_grad
+
+ def send_forward(self, output_object: Any, next_rank: int = None) -> None:
+ """Sends the input tensor to the next stage in pipeline.
+ For 1F1B.
+
+ Args:
+ output_object (Any): Object to be sent.
+ next_rank (int, optional): The rank of the recipient of the tensor.
+ """
+ if not self.stage_manager.is_last_stage():
+ self.comm.send_forward(output_object, next_rank)
+
+ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
+ """Sends the gradient tensor to the previous stage in pipeline.
+ For 1F1B.
+
+ Args:
+ input_object (Any): Object to be sent.
+ prev_rank (int, optional): The rank of the recipient of the tensor
+ """
+ if not self.stage_manager.is_first_stage():
+ self.comm.send_backward(input_object, prev_rank)
+
+ def forward_step(self,
+ model: Module,
+ input_obj: Optional[dict],
+ criterion: Callable,
+ accum_loss: Optional[torch.Tensor] = None,
+ outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]:
+ """Forward one step of the pipeline
+
+ Args:
+ model (Module): Model to be run
+ input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
+ criterion (Callable): Criterion to calculate loss.
+ accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
+ outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
+
+ Returns:
+ Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
+ """
+ micro_batch = self.load_micro_batch()
+ # for the first stage, input_obj is None
+ # for non-first stages, input_obj is the output of the previous stage and must be a dict
+ output_obj = model_forward(model, micro_batch, input_obj)
+ if self.stage_manager.is_last_stage():
+
+ loss = criterion(output_obj, micro_batch) / self.num_microbatches
+ if accum_loss is not None:
+ accum_loss.add_(loss.detach())
+ if outputs is not None:
+ outputs.append(tree_map_hf(detach, output_obj))
+ return loss
+ else:
+ return output_obj
+
+ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
+ output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]:
+ """Backward one step of the pipeline
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to update the model
+ input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
+ output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
+ output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
+
+ Returns:
+ Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
+ """
+
+ # Retain the grad on the input_obj.
+ tree_map(retain_grad, input_obj)
+ # Backward pass.
+ if output_obj_grad is None:
+ optimizer.backward(output_obj)
+ else:
+ if "backward_tensor_keys" not in output_obj:
+ for k, grad in output_obj_grad.items():
+ optimizer.backward_by_grad(output_obj[k], grad)
+ else:
+ for k, grad in output_obj_grad.items():
+ output_obj[k].grad = grad
+ for k in output_obj["backward_tensor_keys"]:
+ tensor_to_backward = output_obj[k]
+ optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
+
+ # Collect the grad of the input_obj.
+ input_obj_grad = None
+ if input_obj is not None:
+ input_obj_grad = {}
+ for k, v in input_obj.items():
+ if isinstance(v, torch.Tensor) and v.grad is not None:
+ input_obj_grad[k] = v.grad
+ return input_obj_grad
+
+ def forward_backward_step(self,
+ model: Module,
+ data_iter: Iterable,
+ criterion: Callable[..., Any],
+ optimizer: Optional[OptimizerWrapper] = None,
+ return_loss: bool = False,
+ return_outputs: bool = False) -> dict:
+ """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
+
+ Args:
+ model (Module): Model to be trained.
+ data_iter (Iterable): Data iterator.
+ criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
+ optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+ return_loss (bool, optional): Whether to return loss. Defaults to False.
+ return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
+
+ Returns:
+ dict: A dict with keys: 'loss' and 'outputs'.
+ """
+ forward_only = not torch.is_grad_enabled()
+ if optimizer is None:
+ assert forward_only, "Optimizer should be passed when doing backward."
+
+ self.load_batch(data_iter)
+
+ # num_warmup_microbatches is the number of warmup steps, during which not all pipeline stages are active yet
+ num_warmup_microbatches = self.stage_manager.num_stages - self.stage_manager.stage - 1
+ num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
+ num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
+
+ # Input, output tensors only need to be saved when doing backward passes
+ input_objs = None
+ output_objs = None
+
+ if not forward_only:
+ input_objs = []
+ output_objs = []
+
+ outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
+ if return_loss and self.stage_manager.is_last_stage():
+ accum_loss = torch.zeros(1, device=get_current_device())
+ else:
+ accum_loss = None
+
+ # Run warmup forward passes.
+ for i in range(num_warmup_microbatches):
+ input_obj = self.recv_forward()
+
+ output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
+
+ self.send_forward(output_obj)
+
+ if not forward_only:
+ input_objs.append(input_obj)
+ output_objs.append(output_obj)
+
+ # Before running 1F1B, need to receive first forward tensor.
+ # If all microbatches are run in warmup / cooldown phase, then no need to
+ # receive this tensor here.
+ if num_microbatches_remaining > 0:
+ input_obj = self.recv_forward()
+
+ # Run 1F1B in steady state.
+ for i in range(num_microbatches_remaining):
+ last_iteration = (i == (num_microbatches_remaining - 1))
+
+ output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
+ if forward_only:
+ self.send_forward(output_obj)
+
+ if not last_iteration:
+ input_obj = self.recv_forward()
+
+ else:
+ # TODO adjust here
+ self.send_forward(output_obj)
+ output_obj_grad = self.recv_backward()
+
+ # Add input_obj and output_obj to end of list.
+ input_objs.append(input_obj)
+ output_objs.append(output_obj)
+
+ # Pop input_obj and output_obj from the start of the list for
+ # the backward pass.
+ input_obj = input_objs.pop(0)
+ output_obj = output_objs.pop(0)
+ input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
+
+ if last_iteration:
+ input_obj = None
+ else:
+ input_obj = self.recv_forward()
+ self.send_backward(input_obj_grad)
+
+ # Run cooldown backward passes.
+ if not forward_only:
+ for i in range(num_warmup_microbatches):
+ input_obj = input_objs.pop(0)
+ output_obj = output_objs.pop(0)
+
+ output_obj_grad = self.recv_backward()
+ input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
+ self.send_backward(input_obj_grad)
+
+ if outputs is not None:
+ if isinstance(model, ModelWrapper):
+ model = model.unwrap()
+ outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0))
+ return {'loss': accum_loss, 'outputs': outputs}
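A rough usage sketch of the schedule above follows. It assumes that `stage_manager`, `model`, `optimizer_wrapper`, and `train_dataloader` are created by the surrounding training script and that the model returns an output object with a `loss` field on the last stage; these names are illustrative, not part of this patch.

```python
# Hypothetical driver code for OneForwardOneBackwardSchedule (names are assumptions).
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=4)

def criterion(outputs, inputs):
    # Assumes the model's forward returns an output object with a `.loss` attribute.
    return outputs.loss

data_iter = iter(train_dataloader)
result = schedule.forward_backward_step(model,
                                        data_iter,
                                        criterion,
                                        optimizer=optimizer_wrapper,
                                        return_loss=True,
                                        return_outputs=False)
# `result['loss']` is only populated on the last pipeline stage.
```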
diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py
new file mode 100644
index 000000000000..6ba7dc629958
--- /dev/null
+++ b/colossalai/pipeline/stage_manager.py
@@ -0,0 +1,136 @@
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Tuple
+
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from colossalai.cluster import ProcessGroupMesh
+
+
+class PipelineStageManager:
+ """PipelineStageManager is a helper class to manage pipeline stages.
+
+ Args:
+ pg_mesh (ProcessGroupMesh): Process group mesh.
+ pipeline_axis (int): The axis along which the pipeline is constructed.
+
+ Attributes:
+ num_stages (int): Number of stages in the pipeline.
+ stage (int): The current stage.
+ """
+
+ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None:
+ self.pg_mesh = pg_mesh
+ self.pipeline_axis = pipeline_axis
+ self.prev_rank: Optional[Tuple[int, ...]] = None
+ self.next_rank: Optional[Tuple[int, ...]] = None
+ self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
+ # init prev and next coord
+ coord = self.pg_mesh.coordinate()
+ # the prev rank of rank0 is the last rank
+ prev_coord = coord[: self.pipeline_axis] + \
+ (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:]
+ self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode='wrap')
+ # the next rank of the last rank is rank0
+ next_coord = coord[: self.pipeline_axis] + \
+ (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:]
+ self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode='wrap')
+
+ # init p2p process groups
+ stages = list(range(self.num_stages))
+ for prev, cur in zip(stages[:-1], stages[1:]):
+ group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur])
+ if self.stage in [prev, cur]:
+ ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
+ self.p2p_groups[tuple(ranks_in_group)] = group
+
+ if is_virtual:
+ # add the process group of the first rank and the last rank
+ # only used in interleaved pipeline for now
+ group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
+ if self.stage in [stages[0], stages[-1]]:
+ ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
+ self.p2p_groups[tuple(ranks_in_group)] = group
+
+ def is_first_stage(self) -> bool:
+ """Is the current stage the first stage.
+
+ Returns:
+ bool: Whether the current stage is the first stage.
+ """
+ return self.stage == 0
+
+ def is_last_stage(self) -> bool:
+ """Is the current stage the last stage.
+
+ Returns:
+ bool: Whether the current stage is the last stage.
+ """
+ return self.stage == self.num_stages - 1
+
+ @property
+ def num_stages(self) -> int:
+ """Number of stages in the pipeline.
+
+ Returns:
+ int: Number of stages in the pipeline.
+ """
+ return self.pg_mesh.size(self.pipeline_axis)
+
+ @property
+ def stage(self) -> int:
+ """Current stage.
+
+ Returns:
+ int: Current stage.
+ """
+ return self.pg_mesh.coordinate(self.pipeline_axis)
+
+ def get_rank(self) -> int:
+ """Get the rank of the current process.
+
+ Returns:
+ int: Rank of the current process.
+ """
+ return dist.get_rank()
+
+ def get_prev_rank(self) -> int:
+ """Get the rank of the previous stage.
+
+ Returns:
+ int: Rank of the previous stage.
+ """
+ return self.prev_rank
+
+ def get_next_rank(self) -> int:
+ """Get the rank of the next stage.
+
+ Returns:
+ int: Rank of the next stage.
+ """
+ return self.next_rank
+
+ def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
+ """Get the p2p process group between two ranks. The order of the two ranks does not matter.
+
+ Args:
+ first_rank (int): The first rank.
+ second_rank (int): The second rank.
+
+ Returns:
+ ProcessGroup: P2P process group between the two ranks.
+ """
+ if first_rank > second_rank:
+ first_rank, second_rank = second_rank, first_rank
+ return self.p2p_groups[(first_rank, second_rank)]
+
+ def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
+ """Get the process group of the given stages.
+
+ Args:
+ stages (List[int]): List of stages.
+
+ Returns:
+ ProcessGroup: Process group of the given stages.
+ """
+ return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
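A minimal construction sketch for the stage manager is given below; the mesh shape (data parallel x pipeline parallel) and the choice of axis 1 for the pipeline are assumptions for illustration only.

```python
# Hypothetical setup: 2-way data parallelism, pipeline parallelism on axis 1.
import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager

dp_size = 2
pp_size = dist.get_world_size() // dp_size
pg_mesh = ProcessGroupMesh(dp_size, pp_size)
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=1)

print(stage_manager.stage, stage_manager.num_stages,
      stage_manager.is_first_stage(), stage_manager.is_last_stage())
```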
diff --git a/colossalai/pipeline/utils.py b/colossalai/pipeline/utils.py
index ac8a3ad7d1db..be8428692756 100644
--- a/colossalai/pipeline/utils.py
+++ b/colossalai/pipeline/utils.py
@@ -1,12 +1,13 @@
import heapq
import inspect
+from collections import OrderedDict
+from typing import List
+
import torch
+from colossalai.legacy.nn.layer.utils import CheckpointModule
from colossalai.logging import get_dist_logger
-from colossalai.nn.layer.utils import CheckpointModule
-from typing import List
-from collections import OrderedDict
def _binary_partition(weights: List, start: int, end: int):
"""Returns the binary partition position of `weights`, given the start
@@ -162,7 +163,7 @@ def build_kwargs_for_module(function, input_tensor, kw_dict):
kwargs_offset = 1
elif isinstance(input_tensor, (tuple, OrderedDict)):
#assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
- # Huggingface will take their own structures based on OrderedDict as the output
+ # Huggingface will take their own structures based on OrderedDict as the output
# between layers so we've to close this check.
kwargs_offset = len(input_tensor)
args_name_list = list(sig.parameters.keys())
@@ -256,7 +257,7 @@ def call_module(module, args=None, kwargs=None):
def customized_partition(exec_seq):
'''
- This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an
+ This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an
annotation to note the partition point.
'''
customized_parts = {}
diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index bf4215c52980..559f9a56f61e 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -30,27 +30,48 @@
### Quick Start
-The sample API usage is given below:
+The sample API usage is given below (if you enable flash attention, please install `flash_attn`; in addition, xformers' `cutlass_op` provides a supplementary optimization):
```python
-from colossalai.shardformer import ShardConfig, Shard
+from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertForMaskedLM
+import colossalai
# launch colossalai
-colossalai.launch_from_torch()
+colossalai.launch_from_torch(config={})
# create model
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
# create huggingface model as normal
-shard_config = ShardConfig()
+shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
+ pipeline_stage_manager=stage_manager,
+ enable_tensor_parallelism=True,
+ enable_fused_normalization=True,
+ enable_flash_attention=True,
+ enable_jit_fused=True,
+ enable_sequence_parallelism=True,
+ enable_sequence_overlap=True)
+
shard_former = ShardFormer(shard_config=shard_config)
-sharded_model = shard_former.optimize(model).to('cuda')
+sharded_model, shared_params = shard_former.optimize(model)
+sharded_model = sharded_model.cuda()
# do everything like normal
...
```
+Shardformer configuration:
+
+- `tensor_parallel_process_group`: the process group for tensor parallelism; required when tensor parallelism is enabled.
+- `pipeline_stage_manager`: if pipeline parallelism is used, a pipeline stage manager must be specified to handle inter-stage communication.
+{{ autodoc:colossalai.pipeline.stage_manager.PipelineStageManager }}
+- `enable_tensor_parallelism`: whether to enable tensor parallelism, which partitions the model weights along columns or rows.
+- `enable_fused_normalization`: whether to use the Apex fused LayerNorm kernel.
+- `enable_flash_attention`: whether to use flash attention.
+- `enable_jit_fused`: whether to use JIT-fused operators.
+- `enable_sequence_parallelism`: whether to use sequence parallelism, which partitions the non-tensor-parallel regions along the sequence dimension.
+- `enable_sequence_overlap`: whether to overlap computation and communication in sequence parallelism; used together with `enable_sequence_parallelism`.
+
### Write your own policy
@@ -82,29 +103,30 @@ We will follow this roadmap to develop Shardformer:
- [x] API Implementation
- [x] Unit Testing
- [ ] Policy Implementation
- - [ ] Hugging Face
- - [ ] NLP
- - [x] BERT
- - [x] T5
- - [x] LlaMa
- - [x] GPT2
- - [x] OPT
- - [x] BLOOM
- - [ ] GLM
- - [ ] RoBERTa
- - [ ] ALBERT
- - [ ] ERNIE
- - [ ] GPT Neo
- - [ ] GPT-J
- - [ ] CV
- - [x] ViT
- - [ ] BEiT
- - [ ] SwinTransformer
- - [ ] SwinTransformer V2
- - [ ] Audio
- - [ ] Whisper
- - [ ] Multi-modal
- - [ ] To be added
+
+| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
+| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
+| bert | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| t5 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| llama V1/V2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| gpt2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| opt | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
+| vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+
## 💡 API Design
@@ -271,41 +293,36 @@ class ShardFormer:
Example:
+ org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
+ shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
- shard_former.init_distributed()
- model = shard_former.optimize(model, policy=policy)
- dataloader = shard_former.shard_dataset(dataset)
+ model, shared_params = shard_former.optimize(org_model)
"""
def __init__(self, shard_config: ShardConfig):
"""
Do two things:
- 1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
+ 1. Create a distribute coordinator
2. serve as a store for shard config
"""
self.shard_config = shard_config
- self.pg_manager = None
+ self.coordinator = DistCoordinator()
- def init_distributed(self) -> colossalai.cluster.ProcessGroupManager:
- """
- Initialize the distributed process group according to the
- """
- pg_manager = ...
- self.pg_manager = pg_manager
- return pg_manager
+ def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
+ r"""
+ This method will optimize the model based on the given policy.
- def shard_model(self, model: torch.nn.Module,policy: Policy) -> torch.nn.Module:
- """
- Shard model for TP and PP
- """
- ...
+ Args:
+ model (`torch.nn.Model`): the origin huggingface model
+ shard_config (`ShardConfig`): the config for distribute information
+ policy (`Policy`): the custom policy for sharding
- def shard_dataset(self, dataset: Dataset) -> Dataloader:
- """
- Shard dataset for DP
+ Returns: the sharded model and the shared parameters
"""
- ...
+ sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
+ shared_params = sharder.shard()
+ return model, shared_params
```
## ⌨️ Development Notes
@@ -372,16 +389,66 @@ pytest tests/test_shardformer
### System Performance
-To be added.
+We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate the performance improvement of Shardformer. We compared the training time between the original model and the sharded model.
+
+We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.
+
+In the case of using 2 GPUs, the training times are as follows.
+| N_CTX | org_model | shard_model |
+| :------: | :-----: | :-----: |
+| 256 | 11.2ms | 17.2ms |
+| 512 | 9.8ms | 19.5ms |
+| 1024 | 19.6ms | 18.9ms |
+| 2048 | 46.6ms | 30.8ms |
+| 4096 | 160.5ms | 90.4ms |
+
+
+
+
+
+
+
+In the case of using 4 GPUs, the training times are as follows.
+
+| N_CTX | org_model | shard_model |
+| :------: | :-----: | :-----: |
+| 256 | 10.0ms | 21.1ms |
+| 512 | 11.5ms | 20.2ms |
+| 1024 | 22.1ms | 20.6ms |
+| 2048 | 46.9ms | 24.8ms |
+| 4096 | 160.4ms | 68.0ms |
+
+
+
+
+
+
+
+
+
+As shown in the figures above, when the sequence length is around 1000 or greater, the parallel optimization of Shardformer for long sequences starts to become evident.
### Convergence
-To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results.
-| accuracy | f1 | loss | GPU number | model shard |
+To validate that training the model using shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both the shardformer and non-shardformer approaches. The shardformer run uses pipeline parallelism together with data parallelism (ZeRO-1). We then compared the accuracy, loss, and F1 score of the training results.
+
+The configuration is as follows:
+```python
+batch_size = 2
+epoch = 3
+lr = 2.4e-5
+accumulation_steps = 8
+warmup_fraction = 0.03
+```
+
+
+
+| accuracy | f1 | loss | GPU number | model sharded |
| :------: | :-----: | :-----: | :--------: | :---------: |
-| 0.82594 | 0.87441 | 0.09913 | 4 | True |
-| 0.81884 | 0.87299 | 0.10120 | 2 | True |
-| 0.81855 | 0.87124 | 0.10357 | 1 | False |
+| 0.82971 | 0.87713 | 0.23194 | 4 | True |
+| 0.83797 | 0.88006 | 0.22683 | 2 | True |
+| 0.84521 | 0.88700 | 0.21822 | 1 | False |
+
Overall, the results demonstrate that using shardformers during model training does not affect the convergence.
diff --git a/colossalai/shardformer/_utils.py b/colossalai/shardformer/_utils.py
index 4ad877e72357..c553080de0a0 100644
--- a/colossalai/shardformer/_utils.py
+++ b/colossalai/shardformer/_utils.py
@@ -1,25 +1,57 @@
import re
-def get_obj_list_element(obj, a):
+def get_obj_list_element(obj, attr: str):
r"""
Get the element of the list in the object
+
+ If the attr is a normal attribute name, return that attribute of the object.
+ If the attr contains a list index, return the indexed element, e.g. `layers[0]`.
+
+ Args:
+ obj (object): The object to read from
+ attr (str): The attribute name, optionally ending with a list index like `layers[0]`
+
"""
re_pattern = r'\[\d+\]'
prog = re.compile(re_pattern)
- result = prog.search(a)
+ result = prog.search(attr)
if result:
matched_brackets = result.group()
matched_index = matched_brackets.replace('[', '')
matched_index = matched_index.replace(']', '')
- a_ = a.replace(matched_brackets, '')
- container_obj = getattr(obj, a_)
+ attr_ = attr.replace(matched_brackets, '')
+ container_obj = getattr(obj, attr_)
obj = container_obj[int(matched_index)]
else:
- obj = getattr(obj, a)
+ obj = getattr(obj, attr)
return obj
+def set_obj_list_element(obj, attr: str, value):
+ r"""
+ Set an indexed element (or a plain attribute) of the object to the given value.
+
+ Used like set_obj_list_element(obj, 'layers[0]', new_layer), which sets obj.layers[0] to new_layer.
+
+ Args:
+ obj (object): The object to modify
+ attr (str): The attribute name, optionally including a list index like `layers[0]`
+ """
+ re_pattern = r'\[\d+\]'
+ prog = re.compile(re_pattern)
+ result = prog.search(attr)
+ if result:
+ matched_brackets = result.group()
+ matched_index = matched_brackets.replace('[', '')
+ matched_index = matched_index.replace(']', '')
+ attr_ = attr.replace(matched_brackets, '')
+ container_obj = getattr(obj, attr_)
+ container_obj[int(matched_index)] = value
+ else:
+ setattr(obj, attr, value)
+
+
def hasattr_(obj, attr: str):
r"""
Check whether the object has the multi sublevel attr
@@ -56,7 +88,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
if ignore:
return
raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
- setattr(obj, attrs[-1], value)
+ set_obj_list_element(obj, attrs[-1], value)
def getattr_(obj, attr: str, ignore: bool = False):
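To make the indexed-attribute behaviour concrete, here is a small standalone sketch of what `setattr_` with a bracketed index enables. The toy module and names are made up for illustration; only the final assignment is what the helper ultimately performs.

```python
import torch.nn as nn

# Toy container so that 'layers[0]'-style attribute strings make sense.
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])

model = Block()
# With the helpers above, replacing an indexed submodule would look roughly like
#   setattr_(model, 'layers[0]', nn.Linear(4, 8))
# which resolves the 'layers' container and assigns index 0, i.e. equivalent to:
model.layers[0] = nn.Linear(4, 8)
print(type(model.layers[0]).__name__, model.layers[0].out_features)    # Linear 8
```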
diff --git a/colossalai/shardformer/examples/shardformer_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py
similarity index 95%
rename from colossalai/shardformer/examples/shardformer_benchmark.py
rename to colossalai/shardformer/examples/convergence_benchmark.py
index de82305b2547..81be2017855c 100644
--- a/colossalai/shardformer/examples/shardformer_benchmark.py
+++ b/colossalai/shardformer/examples/convergence_benchmark.py
@@ -49,9 +49,12 @@ def train(args):
# if multiple GPUs, shard the model
if dist.get_world_size() > 1:
- shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm)
+ tp_group = dist.new_group(backend='nccl')
+ shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
+ enable_tensor_parallelism=True,
+ enable_all_optimization=True)
shard_former = ShardFormer(shard_config=shard_config)
- model = shard_former.optimize(model)
+ model, _ = shard_former.optimize(model)
optim = Adam(model.parameters(), lr=args.lr)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
diff --git a/colossalai/shardformer/examples/shardformer_benchmark.sh b/colossalai/shardformer/examples/convergence_benchmark.sh
similarity index 68%
rename from colossalai/shardformer/examples/shardformer_benchmark.sh
rename to colossalai/shardformer/examples/convergence_benchmark.sh
index f42b19a32d35..22f13a7cf827 100644
--- a/colossalai/shardformer/examples/shardformer_benchmark.sh
+++ b/colossalai/shardformer/examples/convergence_benchmark.sh
@@ -1,7 +1,7 @@
-torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \
+torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
--model "bert" \
--pretrain "bert-base-uncased" \
- --max_epochs 1 \
+ --max_epochs 3 \
--batch_size 2 \
--lr 2.4e-5 \
--fused_layernorm False \
diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py
new file mode 100644
index 000000000000..2f186709d946
--- /dev/null
+++ b/colossalai/shardformer/examples/performance_benchmark.py
@@ -0,0 +1,88 @@
+"""
+Shardformer Benchmark
+"""
+import torch
+import torch.distributed as dist
+import transformers
+import triton
+
+import colossalai
+from colossalai.shardformer import ShardConfig, ShardFormer
+
+
+def data_gen(batch_size, seq_length):
+ input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long)
+ attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def data_gen_for_sequence_classification(batch_size, seq_length):
+ # sequence classification data gen
+ # `labels` here is a single class id per sample
+ data = data_gen(batch_size, seq_length)
+ data['labels'] = torch.ones((batch_size), dtype=torch.long)
+ return data
+
+
+MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
+ hidden_size=128,
+ intermediate_size=256,
+ num_attention_heads=4,
+ max_position_embeddings=128,
+ num_labels=16,
+ pad_token_id=2)
+BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
+model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)
+
+# vary seq length for fixed head and batch=4
+configs = [
+ triton.testing.Benchmark(x_names=['N_CTX'],
+ x_vals=[2**i for i in range(8, 13)],
+ line_arg='provider',
+ line_vals=['org_model', 'shard_model'],
+ line_names=['org_model', 'shard_model'],
+ styles=[('red', '-'), ('blue', '-')],
+ ylabel='ms',
+ plot_name=f'llama_for_sequence_classification-batch-{BATCH}',
+ args={
+ 'BATCH': BATCH,
+ 'dtype': torch.float16,
+ 'model_func': model_func
+ })
+]
+
+
+def train(model, data):
+ output = model(**data)
+ loss = output.logits.mean()
+ loss.backward()
+
+
+@triton.testing.perf_report(configs)
+def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"):
+ warmup = 10
+ rep = 100
+ # prepare data
+ data = data_gen_for_sequence_classification(BATCH, N_CTX)
+ data = {k: v.cuda() for k, v in data.items()}
+ model = model_func().to(device)
+ model.train()
+ if provider == "org_model":
+ fn = lambda: train(model, data)
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+ return ms
+ if provider == "shard_model":
+ shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
+ shard_former = ShardFormer(shard_config=shard_config)
+ sharded_model, _ = shard_former.optimize(model)
+ sharded_model = sharded_model.cuda()
+ fn = lambda: train(sharded_model, data)
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+ return ms
+
+
+# start benchmark, command:
+# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
+if __name__ == "__main__":
+ colossalai.launch_from_torch({})
+ bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0)
diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 7fad4948dfd0..c4586d18b90c 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -3,10 +3,11 @@
from .linear import Linear1D_Col, Linear1D_Row
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm
-from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
+from .parallel_module import ParallelModule
+from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
__all__ = [
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
- 'FusedLayerNorm', 'FusedRMSNorm'
+ 'FusedLayerNorm', 'FusedRMSNorm', 'FusedLinear1D_Col', 'ParallelModule'
]
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 7e97bee01b33..45b305733813 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -1,3 +1,5 @@
+from typing import Any
+
import torch
import torch.distributed as dist
import torch.nn.functional as F
@@ -141,6 +143,240 @@ def backward(ctx, grad_output):
return grad_input, grad_weight, grad_bias, None, None, None
+class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
+ """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+
+ Args:
+ input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+ process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+ overlap (`bool`): Whether to overlap the all-gather op with gradient computation in backward.
+
+ """
+
+ @staticmethod
+ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
+ ctx.save_for_backward(input_, weight)
+ ctx.use_bias = bias is not None
+ ctx.process_group = process_group
+ ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
+ ctx.dim = dim
+ ctx.overlap = overlap
+
+ input_parallel = _gather(input_, dim, process_group)
+
+ if bias is not None:
+ output = F.linear(input_parallel, weight, bias)
+ else:
+ output = F.linear(input_parallel, weight)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_, weight = ctx.saved_tensors
+ use_bias = ctx.use_bias
+ dim = ctx.dim
+ process_group = ctx.process_group
+ overlap = ctx.overlap
+
+ if not overlap:
+ input_parallel = _gather(input_, dim, process_group)
+
+ total_input = input_parallel
+ grad_input = grad_output.matmul(weight)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ total_input = total_input.view(-1, total_input.shape[-1])
+
+ if ctx.async_grad_reduce_scatter:
+ # Asynchronous reduce-scatter
+ input_list = [
+ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+ ]
+ output = torch.empty(input_.shape, dtype=input_parallel.dtype,
+ device=input_parallel.device).contiguous()
+ handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+ # Delay the start of weight gradient computation shortly (3us) to have
+ # reduce-scatter scheduled first and have GPU resources allocated
+ _ = torch.empty(1, device=grad_output.device) + 1
+
+ grad_weight = grad_output.t().matmul(total_input)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ if ctx.async_grad_reduce_scatter:
+ handle.wait()
+
+ else:
+ input_ = input_.contiguous()
+ world_size = dist.get_world_size(process_group)
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+
+ # do all-gather in an async way
+ gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
+ # calculate gradient and prepare data asynchronously with all-gather
+ # calculate
+ grad_input = grad_output.matmul(weight)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+ # prepare data
+ input_list = [
+ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+ ]
+ output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
+ # wait until all-gather finished
+ gather_handle.wait()
+
+ # do reduce-scatter in async way
+ reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+ input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
+ # calculate gradient
+ if len(input_parallel.shape) > 2:
+ input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+ grad_weight = grad_output.t().matmul(input_parallel)
+ # wait until reduce-scatter finished
+ reducescatter_handle.wait()
+
+ return output, grad_weight, grad_bias, None, None, None, None
+
+
+class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
+ """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+
+ Args:
+ input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+ process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+
+ """
+
+ @staticmethod
+ def forward(ctx, input_, process_group, dim):
+ ctx.dim = dim
+ ctx.process_group = process_group
+
+ # do reduce-scatter
+ new_shape = list(input_.shape)
+ assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
+ f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
+ new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
+ input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
+ output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
+ dist.reduce_scatter(output, input_list, group=process_group)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ dim = ctx.dim
+ process_group = ctx.process_group
+
+ return _gather(grad_output, dim, process_group), None, None
+
+
+class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
+ """
+ This class is designed for matmul operation with gather forward and reduce-scatter backward.
+
+ Args:
+ input_ (`torch.Tensor`): input matrix.
+ dim (int): the dimension to perform split and gather
+ process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
+
+ """
+
+ @staticmethod
+ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
+ ctx.save_for_backward(input_, weight)
+ ctx.use_bias = bias is not None
+ ctx.process_group = process_group
+ ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
+ ctx.dim = dim
+ ctx.overlap = overlap
+
+ input_parallel = _gather(input_, dim, process_group)
+
+ output = torch.matmul(input_parallel, weight)
+
+ if bias is not None:
+ output = output + bias
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input_, weight = ctx.saved_tensors
+ use_bias = ctx.use_bias
+ dim = ctx.dim
+ process_group = ctx.process_group
+ overlap = ctx.overlap
+
+ if not overlap:
+ input_parallel = _gather(input_, dim, process_group)
+
+ total_input = input_parallel
+ grad_input = grad_output.matmul(weight.T)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ total_input = total_input.view(-1, total_input.shape[-1])
+
+ if ctx.async_grad_reduce_scatter:
+ # Asynchronous reduce-scatter
+ input_list = [
+ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+ ]
+ output = torch.empty(input_.shape, dtype=input_parallel.dtype,
+ device=input_parallel.device).contiguous()
+ handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+ # Delay the start of weight gradient computation shortly (3us) to have
+ # reduce-scatter scheduled first and have GPU resources allocated
+ _ = torch.empty(1, device=grad_output.device) + 1
+
+ grad_weight = total_input.t().matmul(grad_output)
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+ if ctx.async_grad_reduce_scatter:
+ handle.wait()
+
+ else:
+ world_size = dist.get_world_size(process_group)
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+
+ # do all-gather in an async way
+ gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
+ # calculate gradient and prepare data asynchronously with all-gather
+ # calculate
+ grad_input = grad_output.matmul(weight.T)
+ grad_output = grad_output.contiguous()
+ # Convert the tensor shapes to 2D for execution compatibility
+ if len(grad_output.shape) > 2:
+ grad_output = grad_output.view(-1, grad_output.shape[-1])
+ grad_bias = grad_output.sum(dim=0) if use_bias else None
+ # prepare data
+ input_list = [
+ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
+ ]
+ output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
+ # wait until all-gather finished
+ gather_handle.wait()
+
+ # do reduce-scatter in async way
+ reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
+ input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
+ # calculate gradient
+ if len(input_parallel.shape) > 2:
+ input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+ grad_weight = input_parallel.t().matmul(grad_output)
+ # wait until reduce-scatter finished
+ reducescatter_handle.wait()
+
+ return output, grad_weight, grad_bias, None, None, None, None
+
+
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the corresponding chuck to the rank.
@@ -200,6 +436,26 @@ def backward(ctx, grad_output):
return _reduce(grad_output, ctx.process_group), None
+class _GatherForwardSplitBackward(torch.autograd.Function):
+ """Gather the input from model parallel region and concatenate.
+
+ Args:
+ input_: input matrix.
+ parallel_mode: parallel mode.
+ dim: dimension
+ """
+
+ @staticmethod
+ def forward(ctx, input_, dim, process_group):
+ ctx.process_group = process_group
+ ctx.dim = dim
+ return _gather(input_, dim, process_group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _split(grad_output, ctx.dim, ctx.process_group), None, None
+
+
def _reduce(input_, process_group):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
@@ -235,9 +491,8 @@ def _gather(input_, dim=-1, process_group=None):
return input_
# all gather
- rank = dist.get_rank(process_group)
+ input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
- tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=process_group)
# concat
@@ -246,24 +501,27 @@ def _gather(input_, dim=-1, process_group=None):
return output
-class _GatherForwardSplitBackward(torch.autograd.Function):
- """Gather the input from model parallel region and concatenate.
+def _reduce_scatter(input_, dim=1, process_group=None):
+ """ Do reduce-scatter operation.
Args:
- input_: input matrix.
- parallel_mode: parallel mode.
- dim: dimension
+ input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+ dim (int): The dimension to perform reduce-scatter.
+ process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
"""
+ world_size = dist.get_world_size(process_group)
+ if world_size == 1:
+ return input_
- @staticmethod
- def forward(ctx, input_, dim, process_group):
- ctx.process_group = process_group
- ctx.dim = dim
- return _gather(input_, dim, process_group)
+ # reduce-scatter
+ new_shape = list(input_.shape)
+ assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
+ f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
+ new_shape[dim] = new_shape[dim] // world_size
+ input_list = [item.contiguous() for item in torch.chunk(input_, world_size, dim=dim)]
+ output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
+ dist.reduce_scatter(output, input_list, group=process_group)
- @staticmethod
- def backward(ctx, grad_output):
- return _split(grad_output, ctx.dim, ctx.process_group), None, None
+ return output
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
@@ -274,6 +532,22 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
+def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
+ overlap):
+ return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
+ async_grad_reduce_scatter, dim, overlap)
+
+
+def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
+ return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
+
+
+def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
+ overlap):
+ return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
+ async_grad_reduce_scatter, dim, overlap)
+
+
def gather_forward_split_backward(input_, dim, process_group):
return _GatherForwardSplitBackward.apply(input_, dim, process_group)
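The shape bookkeeping behind the sequence-parallel autograd functions above can be mimicked on a single process. The sketch below uses made-up sizes and plain tensor ops to stand in for all-gather and reduce-scatter; it illustrates the forward/backward shapes rather than calling the distributed functions themselves.

```python
import torch

tp_size, batch, seq_len, hidden = 2, 4, 16, 32
local_input = torch.randn(batch, seq_len // tp_size, hidden)     # per-rank sequence slice

# Forward of linear_gather_forward_reducescatter_backward: all-gather along the
# sequence dim so the linear layer sees the full sequence (mimicked by a concat).
gathered = torch.cat([local_input] * tp_size, dim=1)             # (4, 16, 32)

# Backward: the gradient w.r.t. the gathered input is reduce-scattered back, so each
# rank ends up with a summed chunk matching its local slice.
grad_wrt_gathered = torch.randn_like(gathered)
local_grad = torch.stack(torch.chunk(grad_wrt_gathered, tp_size, dim=1)).sum(dim=0)
assert local_grad.shape == local_input.shape                     # (4, 8, 32)
```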
diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py
index db39a457b7fd..847ca175ad57 100644
--- a/colossalai/shardformer/layer/embedding.py
+++ b/colossalai/shardformer/layer/embedding.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-from typing import Callable, List, Union
+from typing import Callable, List, Optional, Union
import torch
import torch.distributed as dist
@@ -9,11 +9,16 @@
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
+from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
+from colossalai.tensor.d_tensor.api import (
+ is_distributed_tensor,
+ shard_colwise,
+ shard_rowwise,
+ sharded_tensor_to_existing_param,
+)
from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule
@@ -60,6 +65,7 @@ def __init__(self,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = True,
+ weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
@@ -74,18 +80,24 @@ def __init__(self,
self.embed_kwargs = kwargs
self.gather_output = gather_output
- # Parameters.
- factory_kwargs = {'device': device, 'dtype': dtype}
- weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
- sharded_weight = shard_colwise(weight, process_group)
- self.weight = sharded_tensor_to_param(sharded_weight)
-
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
- with self.randomizer.fork_rng(enable_cpu=True):
- self.reset_parameters(weight_initializer)
+ # Parameters.
+ if weight is None:
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
+ else:
+ weight.data = weight.data.to(device=device, dtype=dtype)
+ self.weight = weight
+ if not is_distributed_tensor(self.weight):
+ sharded_weight = shard_colwise(self.weight.data, process_group)
+ sharded_tensor_to_existing_param(sharded_weight, self.weight)
+
+ if weight is None:
+ with self.randomizer.fork_rng(enable_cpu=True):
+ self.reset_parameters(weight_initializer)
@staticmethod
def from_native_module(module: nn.Embedding,
@@ -95,6 +107,7 @@ def from_native_module(module: nn.Embedding,
r"""
Build a 1D parallelized Embedding from a native nn.Embedding module.
"""
+ LazyInitContext.materialize(module)
# get the attributes
num_embedding = module.num_embeddings
embedding_dim = module.embedding_dim
@@ -120,14 +133,10 @@ def from_native_module(module: nn.Embedding,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
+ weight=module.weight,
*args,
**kwargs)
- # copy the weight
- with torch.no_grad():
- sharded_weight = shard_colwise(module.weight.data, process_group)
- embedding.weight.copy_(sharded_weight)
-
return embedding
def reset_parameters(self, weight_initializer) -> None:
@@ -142,7 +151,6 @@ def _fill_padding_idx_with_zero(self) -> None:
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-
if self.gather_output:
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
return output
@@ -187,13 +195,13 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
+ weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
- self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.process_group = process_group
@@ -206,16 +214,26 @@ def __init__(self,
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
- # parameter
- factory_kwargs = {'device': device, 'dtype': dtype}
- weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
- sharded_weight = shard_rowwise(weight, process_group)
- self.weight = sharded_tensor_to_param(sharded_weight)
+ # padding index
+ self.padding_idx = self._select_padding_idx(padding_idx)
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
- self.reset_parameters(weight_initializer)
+
+ # parameter
+ if weight is None:
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
+ else:
+ weight.data = weight.data.to(device=device, dtype=dtype)
+ self.weight = weight
+ if not is_distributed_tensor(self.weight):
+ sharded_weight = shard_rowwise(self.weight.data, process_group)
+ sharded_tensor_to_existing_param(sharded_weight, self.weight)
+
+ if weight is None:
+ self.reset_parameters(weight_initializer)
@staticmethod
def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
@@ -223,6 +241,7 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup,
r"""
Convert a native pytorch embedding module to a parallel module.
"""
+ LazyInitContext.materialize(module)
# get the origin attributes
num_embeddings = module.num_embeddings
embedding_dim = module.embedding_dim
@@ -241,13 +260,9 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup,
padding_idx=padding_idx,
device=device,
process_group=process_group,
+ weight=module.weight,
*args,
**kwargs)
- with torch.no_grad():
- # shard and slice the weight along the vocabulary(num_embeddings) dimension
- # the shape of the weight is (num_embeddings, embedding_dim)
- shard_weight = shard_rowwise(module.weight.data, process_group)
- vocab_embedding_1d.weight.data.copy_(shard_weight)
return vocab_embedding_1d
@@ -263,6 +278,15 @@ def _fill_padding_idx_with_zero(self) -> None:
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
+ def _select_padding_idx(self, padding_idx: int):
+ # select padding index according to the rank
+ if padding_idx is None:
+ return None
+ elif padding_idx < self.vocab_end_index and padding_idx >= self.vocab_start_index:
+ return padding_idx - self.vocab_start_index
+ else:
+ return None
+
def forward(self, input_: Tensor) -> Tensor:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
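The masking logic above can be illustrated with a single-process sketch of a vocab-parallel lookup; all sizes and the rank are made up, and the final all-reduce (`reduce_forward`) is only described in a comment.

```python
import torch

vocab_size, hidden, tp_size, rank = 8, 4, 2, 1
per_rank = vocab_size // tp_size                           # 4 rows per rank
vocab_start, vocab_end = rank * per_rank, (rank + 1) * per_rank    # rank 1 owns ids [4, 8)

weight_shard = torch.randn(per_rank, hidden)               # this rank's slice of the table
tokens = torch.tensor([1, 5, 7])

mask = (tokens < vocab_start) | (tokens >= vocab_end)      # ids owned by other ranks
local_ids = (tokens - vocab_start).masked_fill(mask, 0)
out = weight_shard[local_ids]
out[mask] = 0.0                                            # zero rows for foreign tokens
# In the real layer, an all-reduce across ranks then makes every token pick up its
# embedding from exactly one rank.
```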
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index 26ba5883c64f..111d51b3f8d8 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*-
import math
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -12,12 +12,20 @@
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
+from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
+from colossalai.tensor.d_tensor.api import (
+ is_distributed_tensor,
+ shard_colwise,
+ shard_rowwise,
+ sharded_tensor_to_existing_param,
+)
from ._operation import (
gather_forward_split_backward,
+ linear_gather_forward_reducescatter_backward,
+ linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
reduce_forward,
split_forward_gather_backward,
@@ -44,6 +52,8 @@ class Linear1D_Col(ParallelModule):
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
+ seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+ overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
@@ -63,7 +73,12 @@ def __init__(self,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = False,
+ seq_parallel: bool = False,
+ seq_parallel_dim: int = 1,
+ overlap: torch.cuda.Stream = None,
skip_bias_add: bool = False,
+ weight: Optional[Parameter] = None,
+ bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
@@ -72,6 +87,9 @@ def __init__(self,
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
+ self.seq_parallel = seq_parallel
+ self.seq_parallel_dim = seq_parallel_dim
+ self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
@@ -79,26 +97,42 @@ def __init__(self,
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
- # Parameters.
- factory_kwargs = {'device': device, 'dtype': dtype}
+ # offset the seed with randomizer index and rank
+ seed = torch.random.initial_seed()
+ self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
- weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
- sharded_weight = shard_rowwise(weight, self.process_group)
- self.weight = sharded_tensor_to_param(sharded_weight)
+ # sanity check
+ if weight is not None:
+ assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+ else:
+ assert bias_ is None, 'bias_ must be None if weight is None'
+
+ # Parameters.
+ if weight is None:
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+ else:
+ weight.data = weight.data.to(device=device, dtype=dtype)
+ self.weight = weight
+ if not is_distributed_tensor(self.weight):
+ sharded_weight = shard_rowwise(self.weight.data, self.process_group)
+ sharded_tensor_to_existing_param(sharded_weight, self.weight)
if bias:
- bias = torch.empty(self.out_features, **factory_kwargs)
- sharded_bias = shard_colwise(bias, self.process_group)
- self.bias = sharded_tensor_to_param(sharded_bias)
+ if bias_ is None:
+ self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+ else:
+ bias_.data = bias_.data.to(device=device, dtype=dtype)
+ self.bias = bias_
+ if not is_distributed_tensor(self.bias):
+ sharded_bias = shard_colwise(self.bias.data, self.process_group)
+ sharded_tensor_to_existing_param(sharded_bias, self.bias)
else:
self.bias = None
- # offset the seed with randomizer index and rank
- seed = torch.random.initial_seed()
- self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
-
- # init weights
- self.reset_parameters(weight_initializer, bias_initializer)
+ if weight is None:
+ # init weights
+ self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
@@ -106,6 +140,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
+ LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
@@ -118,22 +153,24 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
+ tp_size = dist.get_world_size(process_group)
+ if out_features < tp_size:
+ return module
+
+ if out_features % tp_size != 0:
+ raise ValueError(
+            f"The size of out_features:{out_features} is not an integer multiple of tensor parallel size: {tp_size}!")
+
linear_1d = Linear1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
+ weight=module.weight,
+ bias_=module.bias,
*args,
**kwargs)
- with torch.no_grad():
- # the weigh to the linear layer is a transpose
- # thus shard on row is equal to shard on column
- sharded_weight = shard_rowwise(module.weight.data, process_group)
- linear_1d.weight.data.copy_(sharded_weight)
- if bias:
- sharded_bias = shard_colwise(module.bias.data, process_group)
- linear_1d.bias.copy_(sharded_bias)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@@ -153,7 +190,12 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
- output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+ if self.seq_parallel:
+ output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
+ self.process_group, True,
+ self.seq_parallel_dim, self.overlap)
+ else:
+ output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output:
# All-gather across the partitions.
@@ -176,6 +218,8 @@ class Linear1D_Row(ParallelModule):
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
+ process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+        seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
@@ -194,8 +238,12 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
+ seq_parallel: bool = False,
+ seq_parallel_dim: int = 1,
parallel_input: bool = True,
skip_bias_add: bool = False,
+ weight: Optional[Parameter] = None,
+ bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1):
@@ -209,32 +257,51 @@ def __init__(self,
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
+ self.seq_parallel = seq_parallel
+ self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
+ # offset the seed with randomizer index and rank
+ seed = torch.random.initial_seed()
+ self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+ # sanity check
+ if weight is not None:
+ assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+ else:
+ assert bias_ is None, 'bias_ must be None if weight is None'
+
# Parameters.
- # Initialize weight.
- factory_kwargs = {'device': device, 'dtype': dtype}
- weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
- sharded_weight = shard_colwise(weight, self.process_group)
- self.weight = sharded_tensor_to_param(sharded_weight)
+ if weight is None:
+ # Initialize weight.
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+ else:
+ weight.data = weight.data.to(device=device, dtype=dtype)
+ self.weight = weight
+ if not is_distributed_tensor(self.weight):
+ sharded_weight = shard_colwise(self.weight.data, self.process_group)
+ sharded_tensor_to_existing_param(sharded_weight, self.weight)
if self.stream_chunk_num > 1:
# TODO() work for inference only
self.chunk_weight()
+
if bias:
- self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+ if bias_ is None:
+ self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+ else:
+ bias_.data = bias_.data.to(device=device, dtype=dtype)
+ self.bias = bias_
else:
self.bias = None
- # offset the seed with randomizer index and rank
- seed = torch.random.initial_seed()
- self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
-
- with self.randomizer.fork_rng(enable_cpu=True):
- self.reset_parameters(weight_initializer, bias_initializer)
+ if weight is None:
+ with self.randomizer.fork_rng(enable_cpu=True):
+ self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
@@ -242,6 +309,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
+ LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
@@ -254,24 +322,24 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
+ tp_size = dist.get_world_size(process_group)
+ if in_features < tp_size:
+ return module
+
+ if in_features % tp_size != 0:
+ raise ValueError(
+            f"The size of in_features:{in_features} is not an integer multiple of tensor parallel size: {tp_size}!")
+
linear_1d = Linear1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
+ weight=module.weight,
+ bias_=module.bias,
*args,
**kwargs)
- # TODO: copy the sharded weights
- with torch.no_grad():
- # the weigh to the linear layer is a transpose
- # thus shard on col is equal to shard on row
- sharded_weight = shard_colwise(module.weight.data, process_group)
- linear_1d.weight.data.copy_(sharded_weight)
-
- if bias:
- linear_1d.bias.copy_(module.bias.data)
-
return linear_1d
def chunk_weight(self):
@@ -326,7 +394,11 @@ def forward(self, input_: Tensor) -> Tensor:
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
- output = reduce_forward(output_parallel, self.process_group)
+ if self.seq_parallel:
+ output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group,
+ self.seq_parallel_dim)
+ else:
+ output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
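
The `from_native_module` changes above hand the original `nn.Linear` parameters directly to `Linear1D_Col`/`Linear1D_Row`, which then shard them in place (after the new tensor-parallel size guards) instead of allocating fresh weights and copying. The snippet below is a minimal, self-contained sketch of that column-parallel sharding pattern in plain PyTorch; `shard_linear_colwise` and its 2-way example are illustrative assumptions, not the ColossalAI API.

```python
# Hypothetical sketch (plain PyTorch, not the ColossalAI helpers): mimics the
# guards and the row-wise weight split that Linear1D_Col.from_native_module
# now applies before reusing the original nn.Linear parameters in place.
import torch
import torch.nn as nn


def shard_linear_colwise(module: nn.Linear, tp_size: int, tp_rank: int) -> nn.Linear:
    out_features = module.out_features
    # Too few output features to split: keep the module unsharded.
    if out_features < tp_size:
        return module
    if out_features % tp_size != 0:
        raise ValueError(f"out_features={out_features} is not divisible by tp_size={tp_size}")
    # Column parallelism splits the (out_features, in_features) weight along dim 0,
    # so each rank keeps out_features // tp_size rows (and the matching bias slice).
    shard = out_features // tp_size
    with torch.no_grad():
        module.weight.data = module.weight.data[tp_rank * shard:(tp_rank + 1) * shard].clone()
        if module.bias is not None:
            module.bias.data = module.bias.data[tp_rank * shard:(tp_rank + 1) * shard].clone()
    return module


# A 2-way split of an 8-output-feature linear keeps 4 rows per rank.
layer = nn.Linear(16, 8)
sharded = shard_linear_colwise(layer, tp_size=2, tp_rank=0)
assert sharded.weight.shape == (4, 16)
```
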
diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index b27307154a76..0aea295664a7 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -4,6 +4,8 @@
import torch
import torch.nn as nn
+from colossalai.lazy import LazyInitContext
+
__all__ = ['FusedLayerNorm', 'FusedRMSNorm']
FAST_LAYERNORM_SUPPORTED_SIZE = [
@@ -35,6 +37,7 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
raise ImportError(
'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')
+ LazyInitContext.materialize(module)
# get the attributes of the module
normalized_shape = module.normalized_shape
eps = module.eps
@@ -57,10 +60,8 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps,
elementwise_affine=elementwise_affine).to(dtype).to(device)
- with torch.no_grad():
- # copy weight and bias
- layernorm.weight.copy_(module.weight)
- layernorm.bias.copy_(module.bias)
+ layernorm.weight = module.weight
+ layernorm.bias = module.bias
return layernorm
@@ -84,6 +85,7 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
)
+ LazyInitContext.materialize(module)
# to check if it is huggingface LlamaRMSNorm
if module.__class__.__name__ == "LlamaRMSNorm":
normalized_shape = module.weight.shape[0]
@@ -97,8 +99,6 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
- with torch.no_grad():
- # copy weight and bias
- rmsnorm.weight.copy_(module.weight)
+ rmsnorm.weight = module.weight
return rmsnorm
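
For the fused norms, the diff above replaces the old `copy_` calls with direct parameter reassignment after `LazyInitContext.materialize`, so the Apex module adopts the very same `Parameter` objects rather than duplicates. A small plain-PyTorch sketch of the difference (no Apex involved; the two `nn.LayerNorm` instances are stand-ins):

```python
# Hypothetical sketch: copying weights duplicates storage, while reassignment
# shares the original Parameter objects, as the FusedLayerNorm/FusedRMSNorm
# conversion above now does.
import torch
import torch.nn as nn

original = nn.LayerNorm(64)
replacement = nn.LayerNorm(64)

# Copying creates an independent tensor; the two layers no longer share storage.
with torch.no_grad():
    replacement.weight.copy_(original.weight)
assert replacement.weight is not original.weight

# Reassigning registers the very same Parameter on the new module, so the
# materialized (lazy-init) parameters are adopted without an extra copy.
replacement.weight = original.weight
replacement.bias = original.bias
assert replacement.weight is original.weight
```
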
diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py
index bda147b121ab..4f391920e29b 100644
--- a/colossalai/shardformer/layer/parallel_module.py
+++ b/colossalai/shardformer/layer/parallel_module.py
@@ -10,6 +10,7 @@
from torch.distributed import ProcessGroup
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
+from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.tensor.d_tensor import (
distribute_tensor,
distribute_tensor_with_customization,
@@ -56,13 +57,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
"""
for name, param in self._parameters.items():
if param is not None:
- param_ = param if keep_vars else param.detach()
- if is_distributed_tensor(param_):
- destination[prefix + name] = to_global(param_)
- elif is_customized_distributed_tensor(param_):
- destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
- else:
- destination[prefix + name] = param_
+ destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 9d51670c65dd..5ce77805f9b8 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -2,27 +2,32 @@
# -*- encoding: utf-8 -*-
import math
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
-import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
+from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import (
- customized_distributed_tensor_to_param,
+ customized_distributed_tensor_to_existing_param,
distribute_tensor_with_customization,
+ is_customized_distributed_tensor,
+ is_distributed_tensor,
shard_rowwise,
- sharded_tensor_to_param,
+ sharded_tensor_to_existing_param,
)
from ._operation import (
gather_forward_split_backward,
+ linear_reducescatter_forward_gather_backward,
+ linear_with_async_comm,
+ matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm,
reduce_backward,
reduce_forward,
@@ -31,7 +36,7 @@
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset
-__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row']
+__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row', 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row']
# ====================================
# For GPT Only
@@ -147,6 +152,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
device (`torch.device`): The device of parameters, defaults to None.
        n_fused (int): The number of items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+        seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
@@ -170,8 +176,12 @@ def __init__(self,
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
+ seq_parallel: bool = False,
+ overlap: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
+ weight: Optional[Parameter] = None,
+ bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
@@ -180,6 +190,8 @@ def __init__(self,
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
+ self.seq_parallel = seq_parallel
+ self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
@@ -189,40 +201,56 @@ def __init__(self,
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
+ # offset the seed with randomizer index and rank
+ seed = torch.random.initial_seed()
+ self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+ # sanity check
+ if weight is not None:
+ assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+ else:
+ assert bias_ is None, 'bias_ must be None if weight is None'
+
# Parameters.
- # Initialize weight.
- factory_kwargs = {'device': device, 'dtype': dtype}
- weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
+ if weight is None:
+ # Initialize weight.
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs))
+ else:
+ weight.data = weight.data.to(device=device, dtype=dtype)
+ self.weight = weight
def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
def gather_fn(tensor):
- return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True)
+ return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
- with torch.no_grad():
- sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
- self.weight = customized_distributed_tensor_to_param(sharded_weight)
+ if not is_customized_distributed_tensor(self.weight):
+ with torch.no_grad():
+ sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
+ customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
if bias:
- bias = torch.empty(self.out_features, **factory_kwargs)
-
- with torch.no_grad():
- sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
- self.bias = customized_distributed_tensor_to_param(sharded_bias)
+ if bias_ is None:
+ self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+ else:
+ bias_.data = bias_.data.to(device=device, dtype=dtype)
+ self.bias = bias_
+ if not is_customized_distributed_tensor(self.bias):
+ with torch.no_grad():
+ sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn)
+ customized_distributed_tensor_to_existing_param(sharded_bias, self.bias)
else:
self.bias = None
- # offset the seed with randomizer index and rank
- seed = torch.random.initial_seed()
- self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
-
- # init weights
- self.reset_parameters(weight_initializer, bias_initializer)
+ if weight is None:
+ # init weights
+ self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
- def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
- *args, **kwargs) -> ParallelModule:
+ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+ **kwargs) -> ParallelModule:
r"""
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
@@ -231,6 +259,7 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
"""
+ LazyInitContext.materialize(module)
# get the attributes
in_features = module.weight.shape[0]
out_features = module.weight.shape[1]
@@ -243,29 +272,24 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
+ tp_size = dist.get_world_size(process_group)
+ if out_features < tp_size:
+ return module
+
+ if out_features % tp_size != 0:
+ raise ValueError(
+            f"The size of out_features:{out_features} is not an integer multiple of tensor parallel size: {tp_size}!")
+
linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
+ weight=module.weight,
+ bias_=module.bias,
*args,
**kwargs)
- # TODO: copy the sharded weights
- with torch.no_grad():
- sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
- n_fused=n_fused,
- process_group=process_group,
- is_transposed=True)
- linear_1d.weight.data.copy_(sharded_weight.data)
-
- if bias:
- sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
- n_fused=n_fused,
- process_group=process_group,
- is_transposed=True)
- linear_1d.bias.data.copy_(sharded_bias.data)
-
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@@ -279,15 +303,19 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
assert input_.shape[-1] == self.weight.shape[0], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
- # Set up backprop all-reduce.
- input_parallel = reduce_backward(input_, self.process_group)
- # input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
- output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
- self.async_communication)
+ if self.seq_parallel:
+ input_parallel = input_
+ output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
+ self.process_group, True, 1, self.overlap)
+ else:
+ # Set up backprop all-reduce.
+ input_parallel = reduce_backward(input_, self.process_group)
+ output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
+ self.async_communication)
if self.gather_output:
# All-gather across the partitions.
@@ -312,6 +340,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
+        seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
@@ -329,8 +358,11 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
+ seq_parallel: bool = False,
parallel_input: bool = True,
skip_bias_add: bool = False,
+ weight: Optional[Parameter] = None,
+ bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1):
@@ -344,35 +376,52 @@ def __init__(self,
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
+ self.seq_parallel = seq_parallel
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
+ # offset the seed with randomizer index and rank
+ seed = torch.random.initial_seed()
+ self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, self.num_partitions)
+ # sanity check
+ if weight is not None:
+ assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+ else:
+ assert bias_ is None, 'bias_ must be None if weight is None'
+
# Parameters.
- # Initialize weight.
- factory_kwargs = {'device': device, 'dtype': dtype}
- weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
- sharded_weight = shard_rowwise(weight, self.process_group)
- self.weight = sharded_tensor_to_param(sharded_weight)
+ if weight is None:
+ # Initialize weight.
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs))
+ else:
+ weight.data = weight.data.to(device=device, dtype=dtype)
+ self.weight = weight
+ if not is_distributed_tensor(self.weight):
+ sharded_weight = shard_rowwise(self.weight.data, self.process_group)
+ sharded_tensor_to_existing_param(sharded_weight, self.weight)
if self.stream_chunk_num > 1:
# TODO() work for inference only
self.chunk_weight()
if bias:
- self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+ if bias_ is None:
+ self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+ else:
+ bias_.data = bias_.data.to(device=device, dtype=dtype)
+ self.bias = bias_
else:
self.bias = None
- # offset the seed with randomizer index and rank
- seed = torch.random.initial_seed()
- self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
-
- # init weights
- self.reset_parameters(weight_initializer, bias_initializer)
+ if weight is None:
+ # init weights
+ self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
@@ -380,6 +429,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
+ LazyInitContext.materialize(module)
# get the attributes
in_features = module.weight.shape[0]
out_features = module.weight.shape[1]
@@ -392,24 +442,24 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
f'Expected only one process group, got {len(process_group)}.'
process_group = process_group[0]
+ tp_size = dist.get_world_size(process_group)
+ if in_features < tp_size:
+ return module
+
+ if in_features % tp_size != 0:
+ raise ValueError(
+            f"The size of in_features:{in_features} is not an integer multiple of tensor parallel size: {tp_size}!")
+
linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
+ weight=module.weight,
+ bias_=module.bias,
*args,
**kwargs)
- # TODO: copy the sharded weights
- with torch.no_grad():
- # the weigh to the linear layer is a transpose
- # thus shard on col is equal to shard on row
- sharded_weight = shard_rowwise(module.weight.data, process_group)
- linear_1d.weight.data.copy_(sharded_weight.data)
-
- if bias:
- linear_1d.bias.copy_(module.bias.data)
-
return linear_1d
def chunk_weight(self):
@@ -428,21 +478,21 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None:
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
origin_device = self.bias.device
- self.bias = self.bias.cuda()
+ self.bias.data = self.bias.cuda()
dist.broadcast(self.bias, src=src_rank, group=self.process_group)
- self.bias = self.bias.to(origin_device)
+ self.bias.data = self.bias.to(origin_device)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[0], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
- input_.shape, self.weight.shape, self.weight.shape[-1])
+ input_.shape, self.weight.shape, self.weight.shape[0])
input_ = input_
else:
assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
- input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
+ input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions)
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
if self.stream_chunk_num > 1:
@@ -463,7 +513,10 @@ def forward(self, input_: Tensor) -> Tensor:
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = torch.matmul(input_, self.weight)
- output = reduce_forward(output_parallel, self.process_group)
+ if self.seq_parallel:
+ output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
+ else:
+ output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
@@ -471,3 +524,194 @@ def forward(self, input_: Tensor) -> Tensor:
return output
else:
return output, self.bias
+
+
+# ====================================
+# For Fused torch.nn.Linear
+# ====================================
+
+
+class FusedLinear1D_Col(ParallelModule):
+ r"""Fused Linear layer with column parallelism.
+
+ The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
+    its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to adapt a fused `torch.nn.Linear` layer (e.g. fused QKV) found in standard Hugging Face modules, such as SAM.
+
+ Args:
+ in_features (int): size of each input sample.
+ out_features (int): size of each output sample.
+ bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+ dtype (`torch.dtype`): The dtype of parameters, defaults to None.
+ device (`torch.device`): The device of parameters, defaults to None.
+        n_fused (int): The number of items fused, defaults to 3 (QKV).
+ process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
+ gather_output (bool, optional): If true, call all-gather on output and make Y available
+ to all GPUs, otherwise, every GPU will have its output
+ which is :math:`Y_i = XA_i`, defaults to False
+ skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
+ which is preserved for kernel fusion, defaults to False
+ weight_initializer (`typing.Callable`):
+ The initializer of weight, defaults to kaiming uniform initializer.
+ bias_initializer (`typing.Callable`):
+ The initializer of bias, defaults to xavier uniform initializer.
+
+    For more details about ``initializer``, please refer to
+ `init `_.
+ """
+
+ def __init__(self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ dtype: torch.dtype = None,
+ device: torch.device = None,
+ process_group: ProcessGroup = None,
+ async_communication: bool = False,
+ gather_output: bool = False,
+ skip_bias_add: bool = False,
+ n_fused: int = 3,
+ weight: Optional[Parameter] = None,
+ bias_: Optional[Parameter] = None,
+ weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+ bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+ super().__init__()
+ # Keep input parameters
+ self.in_features = in_features
+ self.out_features = out_features
+ self.gather_output = gather_output
+ self.skip_bias_add = skip_bias_add
+ self.device = device
+ self.n_fused = n_fused
+ self.process_group = process_group
+ self.async_communication = async_communication
+
+ if skip_bias_add and not bias:
+ raise ValueError('cannot skip bias addition if bias is None')
+
+ # offset the seed with randomizer index and rank
+ seed = torch.random.initial_seed()
+ self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+ # sanity check
+ if weight is not None:
+ assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
+ else:
+ assert bias_ is None, 'bias_ must be None if weight is None'
+
+ # Parameters.
+ if weight is None:
+ # Initialize weight.
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+ else:
+ weight.data = weight.data.to(device=device, dtype=dtype)
+ self.weight = weight
+
+ def shard_fn(tensor):
+ return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
+
+ def gather_fn(tensor):
+ return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
+
+ if not is_customized_distributed_tensor(self.weight):
+ with torch.no_grad():
+ sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
+ customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
+
+ if bias:
+ if bias_ is None:
+ self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+ else:
+ bias_.data = bias_.data.to(device=device, dtype=dtype)
+ self.bias = bias_
+ if not is_customized_distributed_tensor(self.bias):
+ with torch.no_grad():
+ sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn)
+ customized_distributed_tensor_to_existing_param(sharded_bias, self.bias)
+ else:
+ self.bias = None
+
+ if weight is None:
+ # init weights
+ self.reset_parameters(weight_initializer, bias_initializer)
+
+ @staticmethod
+ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
+ *args, **kwargs) -> ParallelModule:
+ r"""
+        Convert a fused `torch.nn.Linear` layer to a parallelized linear layer.
+
+ Args:
+ module (`nn.Linear`): The module to be converted.
+ process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
+            n_fused (int): The number of layers to be fused. Typically, Q, K and V are fused in one weight.
+ """
+ # get the attributes
+ in_features = module.in_features
+ out_features = module.out_features
+ bias = module.bias is not None
+ device = module.weight.device
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, \
+ f'Expected only one process group, got {len(process_group)}.'
+ process_group = process_group[0]
+
+ linear_1d = FusedLinear1D_Col(in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ device=device,
+ process_group=process_group,
+ weight=module.weight,
+ bias_=module.bias,
+ *args,
+ **kwargs)
+
+ # # TODO: copy the sharded weights
+ # with torch.no_grad():
+ # sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
+ # n_fused=n_fused,
+ # process_group=process_group,
+ # is_transposed=False)
+ # linear_1d.weight.data.copy_(sharded_weight.data)
+
+ # if bias:
+ # sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
+ # n_fused=n_fused,
+ # process_group=process_group,
+ # is_transposed=False)
+ # linear_1d.bias.data.copy_(sharded_bias.data)
+ return linear_1d
+
+ def reset_parameters(self, weight_initializer, bias_initializer) -> None:
+ with self.randomizer.fork_rng(enable_cpu=True):
+ fan_in, fan_out = self.in_features, self.out_features
+ weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
+ if self.bias is not None:
+ bias_initializer(self.bias, fan_in=fan_in)
+
+ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
+ assert input_.shape[-1] == self.weight.shape[-1], \
+ 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
+            'Invalid shapes in FusedLinear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
+ # Set up backprop all-reduce.
+ # input_parallel = reduce_backward(input_, self.process_group)
+ input_parallel = input_
+
+ # Matrix multiply.
+ bias = self.bias if not self.skip_bias_add else None
+
+ output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+
+ if self.gather_output:
+ # All-gather across the partitions.
+ output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+ else:
+ output = output_parallel
+
+ if self.skip_bias_add:
+ return output, self.bias
+ else:
+ return output
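
The fused-QKV layers above rely on `split_fused_qkv_in_gpt2_style`/`gather_fused_qkv_in_gpt2_style` to shard a fused `[Q | K | V]` weight so that every tensor-parallel rank keeps its own slice of each of Q, K and V, rather than a contiguous third of the fused matrix. The helper below is a hypothetical, stand-alone illustration of that split (it is not the ColossalAI function and ignores the `is_transposed` handling):

```python
# Hypothetical sketch of a "GPT-2 style" fused-QKV split across tensor-parallel ranks.
import torch


def split_fused_qkv(weight: torch.Tensor, n_fused: int, tp_size: int, tp_rank: int) -> torch.Tensor:
    # weight is fused along its last dimension as [Q | K | V] (n_fused chunks).
    chunks = weight.chunk(n_fused, dim=-1)                         # [Q, K, V]
    shards = [c.chunk(tp_size, dim=-1)[tp_rank] for c in chunks]   # [Q_i, K_i, V_i]
    return torch.cat(shards, dim=-1)                               # re-fused per-rank slice


fused = torch.arange(12.).reshape(1, 12)   # pretend Q = 0..3, K = 4..7, V = 8..11
rank0 = split_fused_qkv(fused, n_fused=3, tp_size=2, tp_rank=0)
# Rank 0 keeps the first half of each of Q, K and V: [0, 1, 4, 5, 8, 9].
assert rank0.tolist() == [[0., 1., 4., 5., 8., 9.]]
```
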
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index f2ac6563c46f..577bef076a7e 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -29,8 +29,6 @@ class Randomizer:
_INDEX = 0
def __init__(self, seed: int):
- # TODO: remove colossalai.context.random
-
self.seed = seed
# Handle CUDA rng state
@@ -122,6 +120,13 @@ def increment_index():
"""
Randomizer._INDEX += 1
+ @staticmethod
+ def reset_index():
+ """
+ Reset the index to zero.
+ """
+ Randomizer._INDEX = 0
+
@staticmethod
def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
"""
diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
new file mode 100644
index 000000000000..30855a622adb
--- /dev/null
+++ b/colossalai/shardformer/modeling/bert.py
@@ -0,0 +1,1285 @@
+import math
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.models.bert.modeling_bert import (
+ BertForMaskedLM,
+ BertForMultipleChoice,
+ BertForNextSentencePrediction,
+ BertForPreTraining,
+ BertForPreTrainingOutput,
+ BertForQuestionAnswering,
+ BertForSequenceClassification,
+ BertForTokenClassification,
+ BertLMHeadModel,
+ BertModel,
+)
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
+
+
+class BertPipelineForwards:
+ '''
+ This class serves as a micro library for forward function substitution of Bert models
+    This class serves as a micro library for forward function substitution of BERT models
+    under the pipeline setting.
+
+ @staticmethod
+ def bert_model_forward(
+ self: BertModel,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+        # TODO(jianghai): add explanation of the output here.
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ """
+ logger = logging.get_logger(__name__)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if stage_manager.is_first_stage():
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+ else:
+ input_shape = hidden_states.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = hidden_states.device
+
+        # TODO(jianghai): the recorded kv-value tensors are left as () or None for now; this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+ use_cache = False
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+ attention_mask = extended_attention_mask
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+ hidden_states = hidden_states if hidden_states is not None else None
+
+ if stage_manager.is_first_stage():
+ hidden_states = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+        # inherited from bert_layer; this should be changed when we add the feature to record hidden_states
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ if self.encoder.gradient_checkpointing and self.encoder.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+ next_decoder_cache = () if use_cache else None
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ # layer_outputs
+ layer_outputs = hidden_states if hidden_states is not None else None
+
+ # split the input tensor along sequence dimension
+ # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+ if shard_config is not None and shard_config.enable_sequence_parallelism:
+ hidden_states = split_forward_gather_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = split_forward_gather_backward(
+ encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
+
+ for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
+ if stage_manager.is_first_stage() and idx == 0:
+ encoder_attention_mask = encoder_extended_attention_mask
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[idx] if head_mask is not None else None
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.encoder.gradient_checkpointing and self.encoder.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + \
+ (layer_outputs[2],)
+
+        # When sequence parallelism is enabled, gather the output tensor in the forward pass and split it in the backward pass
+ if shard_config is not None and shard_config.enable_sequence_parallelism:
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # end of a stage loop
+ sequence_output = hidden_states if hidden_states is not None else None
+
+ if stage_manager.is_last_stage():
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+ if not return_dict:
+ return (sequence_output, pooled_output) + layer_outputs[1:]
+ # return dict is not supported at this moment
+ else:
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+ # output of non-first and non-last stages: must be a dict
+ else:
+            # intermediate stages always return a dict
+ return {
+ 'hidden_states': hidden_states,
+ }
+
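
The forward above follows the staging contract used throughout this new file: the first pipeline stage builds the embeddings, intermediate stages only pass `{'hidden_states': ...}` forward, and the last stage assembles the final model output. A toy, framework-free sketch of that contract (the class and its layer slicing are illustrative assumptions, not ColossalAI code):

```python
# Hypothetical sketch of the first/intermediate/last-stage pattern used by
# BertPipelineForwards.bert_model_forward.
from typing import Optional

import torch
import torch.nn as nn


class TinyStagedModel(nn.Module):
    """A toy 4-layer model cut into pipeline stages by a (start, end) layer range."""

    def __init__(self, hidden: int = 8, num_layers: int = 4):
        super().__init__()
        self.embed = nn.Linear(hidden, hidden)
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))

    def forward(self, x: Optional[torch.Tensor], hidden_states: Optional[torch.Tensor],
                stage_index: tuple, is_first: bool, is_last: bool):
        if is_first:
            hidden_states = self.embed(x)            # only the first stage runs the embedding
        start, end = stage_index
        for layer in self.layers[start:end]:         # each stage owns a slice of the layers
            hidden_states = layer(hidden_states)
        if is_last:
            return hidden_states                     # last stage returns the real output
        return {'hidden_states': hidden_states}      # intermediate stages return a dict


model = TinyStagedModel()
x = torch.randn(2, 8)
mid = model(x, None, stage_index=(0, 2), is_first=True, is_last=False)
out = model(None, mid['hidden_states'], stage_index=(2, 4), is_first=False, is_last=True)
assert out.shape == (2, 8)
```
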
+ @staticmethod
+ def bert_for_pretraining_forward(
+ self: BertForPreTraining,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ next_sentence_label: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        # TODO(jianghai): the recorded kv-value tensors are left as () or None for now; this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ outputs = BertPipelineForwards.bert_model_forward(
+ self.bert,
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states if hidden_states is not None else None,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ )
+ past_key_values = None
+ all_hidden_states = None
+ all_self_attentions = None
+ all_cross_attentions = None
+
+ if stage_manager.is_last_stage():
+ sequence_output, pooled_output = outputs[:2]
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+ # the last stage for pretraining model
+ total_loss = None
+ if labels is not None and next_sentence_label is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+ total_loss = masked_lm_loss + next_sentence_loss
+
+ if not return_dict:
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return BertForPreTrainingOutput(
+ loss=total_loss,
+ prediction_logits=prediction_scores,
+ seq_relationship_logits=seq_relationship_score,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get('hidden_states')
+
+            # intermediate stages always return a dict
+ return {
+ 'hidden_states': hidden_states,
+ }
+
+ @staticmethod
+ def bert_lm_head_model_forward(
+ self: BertLMHeadModel,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
+ ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ """
+ logger = logging.get_logger(__name__)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ use_cache = False
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ outputs = BertPipelineForwards.bert_model_forward(
+ self.bert,
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states if hidden_states is not None else None,
+ stage_index=stage_index,
+ shard_config=shard_config)
+ past_key_values = None
+ all_hidden_states = None
+ all_self_attentions = None
+ all_cross_attentions = None
+
+ if stage_manager.is_last_stage():
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+ else:
+ hidden_states = outputs.get('hidden_states')
+            # intermediate stages always return a dict
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def bert_for_masked_lm_forward(
+ self: BertForMaskedLM,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ hidden_states: Optional[torch.Tensor] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ """
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ outputs = BertPipelineForwards.bert_model_forward(
+ self.bert,
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ hidden_states=hidden_states,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ )
+
+ if stage_manager.is_last_stage():
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def bert_for_next_sentence_prediction_forward(
+ self: BertForNextSentencePrediction,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ hidden_states: Optional[torch.Tensor] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ **kwargs,
+ ):
+ # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
+ (see `input_ids` docstring). Indices should be in `[0, 1]`:
+
+ - 0 indicates sequence B is a continuation of sequence A,
+ - 1 indicates sequence B is a random sequence.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, BertForNextSentencePrediction
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+ >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
+
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
+
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
+ >>> logits = outputs.logits
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
+ ```
+ """
+ logger = logging.get_logger(__name__)
+
+ if "next_sentence_label" in kwargs:
+ warnings.warn(
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
+ " `labels` instead.",
+ FutureWarning,
+ )
+ labels = kwargs.pop("next_sentence_label")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ outputs = BertPipelineForwards.bert_model_forward(self.bert,
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ hidden_states=hidden_states,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=shard_config)
+
+ if stage_manager.is_last_stage():
+ pooled_output = outputs[1]
+ seq_relationship_scores = self.cls(pooled_output)
+
+ next_sentence_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
+
+ if not return_dict:
+ output = (seq_relationship_scores,) + outputs[2:]
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
+
+ return NextSentencePredictorOutput(
+ loss=next_sentence_loss,
+ logits=seq_relationship_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get('hidden_states')
+ # intermediate stages always return a dict
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def bert_for_sequence_classification_forward(
+ self: BertForSequenceClassification,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ hidden_states: Optional[torch.Tensor] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ outputs = BertPipelineForwards.bert_model_forward(self.bert,
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ hidden_states=hidden_states,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=shard_config)
+
+ if stage_manager.is_last_stage():
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def bert_for_token_classification_forward(
+ self: BertForTokenClassification,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ hidden_states: Optional[torch.Tensor] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ outputs = BertPipelineForwards.bert_model_forward(self.bert,
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ hidden_states=hidden_states,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=shard_config)
+
+ if stage_manager.is_last_stage():
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def bert_for_multiple_choice_forward(
+ self: BertForMultipleChoice,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ hidden_states: Optional[torch.Tensor] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ """
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ # in our pipeline design, input_ids are copied to every stage and shouldn't be None
+ # the input_ids for the multiple choice model have shape [batch_size, num_choices, sequence_length]
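+ # the choice dimension is flattened below so the encoder sees [batch_size * num_choices, sequence_length]; logits are reshaped back to [batch_size, num_choices] on the last stage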
+ if stage_manager.is_last_stage():
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None else None)
+
+ outputs = BertPipelineForwards.bert_model_forward(
+ self.bert,
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ hidden_states=hidden_states,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ )
+ if stage_manager.is_last_stage():
+ pooled_output = outputs[1]
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def bert_for_question_answering_forward(
+ self: BertForQuestionAnswering,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ hidden_states: Optional[torch.Tensor] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ # NOTE: the args start_positions and end_positions are used only in the last stage
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ outputs = BertPipelineForwards.bert_model_forward(self.bert,
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ hidden_states=hidden_states,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=shard_config)
+ if stage_manager.is_last_stage():
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the split adds a dimension; remove it
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+
+def get_bert_flash_attention_forward():
+
+ try:
+ from xformers.ops import memory_efficient_attention as me_attention
+ except ImportError:
+ raise ImportError("The xformers module is not installed. Please install it to use flash attention.")
+ from transformers.models.bert.modeling_bert import BertAttention
+
+ def forward(
+ self: BertAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_layer = past_key_value[0]
+ value_layer = past_key_value[1]
+ attention_mask = encoder_attention_mask
+ elif is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ use_cache = past_key_value is not None
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_layer, value_layer)
+
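+ # relative position scores (if any) are folded into an additive attention bias so they can be passed to xformers' memory_efficient_attention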
+ final_attention_mask = None
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+ if use_cache:
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ else:
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ final_attention_mask = relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ final_attention_mask = relative_position_scores_query + relative_position_scores_key
+
+ scale = 1 / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ if final_attention_mask is not None:
+ final_attention_mask = final_attention_mask * scale + attention_mask
+ else:
+ final_attention_mask = attention_mask
+
+ if final_attention_mask is not None:
+ batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
+ tgt_len = key_layer.size()[2]
+ final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len,
+ tgt_len).contiguous()
+
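+ # xformers' memory_efficient_attention expects [batch, seq_len, num_heads, head_dim] tensors, hence the permutes below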
+ query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
+ key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
+ value_layer = value_layer.permute(0, 2, 1, 3).contiguous()
+
+ context_layer = me_attention(query_layer,
+ key_layer,
+ value_layer,
+ attn_bias=final_attention_mask,
+ p=self.dropout.p,
+ scale=scale)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, None)
+
+ if self.is_decoder:
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_bert_self_output_forward():
+
+ from transformers.models.bert.modeling_bert import BertSelfOutput
+
+ def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
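+ # dropout_add fuses dropout and the residual addition into a single JIT-compiled operation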
+ hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+ return forward
+
+
+def get_jit_fused_bert_output_forward():
+
+ from transformers.models.bert.modeling_bert import BertOutput
+
+ def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+ return forward
+
+
+def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+ # split the input tensor along sequence dimension
+ # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+ embedding_output = split_forward_gather_backward(embedding_output,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = split_forward_gather_backward(
+ encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = encoder_outputs[0]
+
+ # when sequence parallelism is used, gather the output tensor in the forward pass and split it in the backward pass
+ sequence_output = gather_forward_split_backward(sequence_output,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+ return forward
diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py
new file mode 100644
index 000000000000..69730fd3d254
--- /dev/null
+++ b/colossalai/shardformer/modeling/blip2.py
@@ -0,0 +1,120 @@
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+
+def forward_fn():
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+
+ mixed_qkv = self.qkv(hidden_states)
+
+ # modified from original code, which is:
+ # mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
+ # 2, 0, 3, 1, 4
+ # )
+ # to:
+ mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ query_states, key_states, value_states = (
+ mixed_qkv[0],
+ mixed_qkv[1],
+ mixed_qkv[2],
+ )
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
+
+ attention_scores = attention_scores * self.scale
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
+
+ new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
+ context_layer = context_layer.reshape(new_context_layer_shape)
+
+ output = self.projection(context_layer)
+
+ outputs = (output, attention_probs) if output_attentions else (output, None)
+
+ return outputs
+
+ return forward
+
+
+def get_blip2_flash_attention_forward():
+
+ from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
+
+ from colossalai.kernel.cuda_native import ColoAttention
+
+ def forward(
+ self: Blip2Attention,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+ mixed_qkv = self.qkv(hidden_states)
+ mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
+ query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
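+ # after this permute, query/key/value each have shape [batch, seq_len, num_heads, head_dim]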
+
+ attention = ColoAttention(embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout.p,
+ scale=self.scale)
+ context_layer = attention(query_states, key_states, value_states)
+
+ output = self.projection(context_layer)
+ outputs = (output, None)
+
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_blip2_QFormer_self_output_forward():
+
+ from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput
+
+ def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+ return forward
+
+
+def get_jit_fused_blip2_QFormer_output_forward():
+
+ from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput
+
+ def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+ return forward
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index a3d774ff2abb..66f24dc6088b 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -1,6 +1,32 @@
+import warnings
+from typing import List, Optional, Tuple, Union
+
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import functional as F
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.models.bloom.modeling_bloom import (
+ BloomForCausalLM,
+ BloomForQuestionAnswering,
+ BloomForSequenceClassification,
+ BloomForTokenClassification,
+ BloomModel,
+)
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.shard import ShardConfig
+
+logger = logging.get_logger(__name__)
def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
@@ -67,3 +93,984 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int,
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
return build_bloom_alibi_tensor
+
+
+class BloomPipelineForwards:
+ '''
+ This class serves as a micro library for bloom pipeline forwards.
+ '''
+
+ @staticmethod
+ def bloom_model_forward(
+ self: BloomModel,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']:
+
+ logger = logging.get_logger(__name__)
+
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # add warnings here
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+ use_cache = False
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ # case: first pipeline stage
+ if stage_manager.is_first_stage():
+ # check input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+ # hidden_states is initialized in the first stage and then passed to the next stage
+ else:
+ input_shape = hidden_states.shape[:-1]
+ batch_size, seq_length = input_shape
+
+ # extra recording tensor should be generated in the first stage
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.h))
+ # Compute alibi tensor: check the build_alibi_tensor documentation; it is built in every stage
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+ if past_key_values[0] is not None:
+ past_key_values_length = past_key_values[0][0].shape[2] # source_len
+
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
+ alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+
+ # causal_mask is constructed in every stage; its inputs are passed along through the stages
+ causal_mask = self._prepare_attn_mask(
+ attention_mask,
+ input_shape=(batch_size, seq_length),
+ past_key_values_length=past_key_values_length,
+ )
+
+ # split the input tensor along sequence dimension
+ # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+ if shard_config.enable_sequence_parallelism:
+ hidden_states = split_forward_gather_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
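+ # only the decoder blocks assigned to this pipeline stage (self.h[start_idx:end_idx]) are executed here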
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]),
+ start=start_idx):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ alibi,
+ causal_mask,
+ layer_past,
+ head_mask[i],
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=causal_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ )
+
+ hidden_states = outputs[0]
+
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + \
+ (outputs[2 if use_cache else 1],)
+
+ # when sequence parallelism is used, gather the output tensor in the forward pass and split it in the backward pass
+ if shard_config.enable_sequence_parallelism:
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ if stage_manager.is_last_stage():
+ # Add last hidden state
+ hidden_states = self.ln_f(hidden_states)
+
+ # TODO(jianghai): deal with all_hidden_states, all_self_attentions, presents
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if stage_manager.is_last_stage():
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ # attention_mask is not returned; presents = past_key_values
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+ else:
+ # always return a dict for intermediate stages
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def bloom_for_causal_lm_forward(self: BloomForCausalLM,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ **deprecated_arguments):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+ """
+ logger = logging.get_logger(__name__)
+
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config)
+ past_key_values = None
+ all_hidden_states = None
+ all_self_attentions = None
+ all_cross_attentions = None
+ if stage_manager.is_last_stage():
+ hidden_states = transformer_outputs[0]
+ lm_logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(lm_logits.device)
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ batch_size, seq_length, vocab_size = shift_logits.shape
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size),
+ shift_labels.view(batch_size * seq_length))
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+ else:
+ hidden_states = transformer_outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def bloom_for_sequence_classification_forward(
+ self: BloomForSequenceClassification,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ **deprecated_arguments,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ logger = logging.get_logger(__name__)
+
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ transformer_outputs = BloomPipelineForwards.bloom_model_forward(
+ self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ )
+ past_key_values = None
+ all_hidden_states = None
+ all_self_attentions = None
+ all_cross_attentions = None
+ if stage_manager.is_last_stage():
+ batch_size = hidden_states.shape[0]
+ # replace the stage input with the final transformer output
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+ else:
+ sequence_lengths = -1
+ logger.warning(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`")
+
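+ # pool by taking the logits at the last non-padding token of each sequence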
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+ else:
+ hidden_states = transformer_outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def bloom_for_token_classification_forward(
+ self: BloomForTokenClassification,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ **deprecated_arguments,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ logger = logging.get_logger(__name__)
+
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ transformer_outputs = BloomPipelineForwards.bloom_model_forward(
+ self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ )
+ past_key_values = None
+ all_hidden_states = None
+ all_self_attentions = None
+ all_cross_attentions = None
+
+ if stage_manager.is_last_stage():
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ batch_size, seq_length = labels.shape
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels),
+ labels.view(batch_size * seq_length))
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+ else:
+ hidden_states = transformer_outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def bloom_for_question_answering_forward(
+ self: BloomForQuestionAnswering,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+ """
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        # TODO(jianghai): the recorded kv-cache tensors are left as () or None for now; this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ outputs = BloomPipelineForwards.bloom_model_forward(
+ self.transformer,
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ )
+ past_key_values = None
+ all_hidden_states = None
+ all_self_attentions = None
+ all_cross_attentions = None
+
+ if stage_manager.is_last_stage():
+ sequence_output = outputs[0]
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+                # If we are on multi-GPU, split adds a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+
+def get_bloom_flash_attention_forward(enabel_jit_fused=False):
+
+ try:
+ from xformers.ops import memory_efficient_attention as me_attention
+    except ImportError:
+ raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
+ from transformers.models.bloom.modeling_bloom import BloomAttention
+
+ def forward(
+ self: BloomAttention,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+
+ fused_qkv = self.query_key_value(hidden_states)
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+ batch_size, tgt_len, _ = query_layer.size()
+
+ _, kv_length, _, _ = key_layer.size()
+
+ proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
+ query_layer = query_layer.contiguous().view(*proj_shape)
+ key_layer = key_layer.contiguous().view(*proj_shape)
+ value_layer = value_layer.contiguous().view(*proj_shape)
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+            # concatenate along the sequence dimension (dim=1):
+            # - key:   [batch_size, kv_length, num_heads, head_dim]
+            # - value: [batch_size, kv_length, num_heads, head_dim]
+ key_layer = torch.cat((past_key, key_layer), dim=1)
+ value_layer = torch.cat((past_value, value_layer), dim=1)
+
+ if use_cache is True:
+ present = (key_layer, value_layer)
+ else:
+ present = None
+
+ tgt_len = key_layer.size()[1]
+
+ attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length),
+ dtype=torch.float32,
+ device=query_layer.device,
+ requires_grad=True)
+ attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1,
+ kv_length) * self.beta
+ attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask,
+ torch.finfo(torch.float32).min)
+
+ context_layer = me_attention(query_layer,
+ key_layer,
+ value_layer,
+ attn_bias=attention_numerical_mask,
+ scale=self.inv_norm_factor,
+ p=self.attention_dropout.p)
+ context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ slices = self.hidden_size / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + F.linear(
+ context_layer[:, :, int(i * slices):int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+        # TODO: replace this with the fused bias_dropout_add JIT function
+ output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+ outputs = (output_tensor, present, None)
+
+ return outputs
+
+ return forward
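+
+
+# Usage sketch (illustrative only, not part of this patch's API): the factory above returns a
+# plain function, so a caller — typically a shardformer policy — can graft it onto an existing
+# `BloomAttention` instance (here a hypothetical `attn_module`):
+#
+#     import types
+#     attn_module.forward = types.MethodType(get_bloom_flash_attention_forward(), attn_module)
+#
+# after which subsequent calls dispatch to the xformers-backed implementation.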
+
+
+def get_jit_fused_bloom_attention_forward():
+
+ from transformers.models.bloom.modeling_bloom import BloomAttention
+
+ def forward(
+ self: BloomAttention,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ):
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
+
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+
+ batch_size, q_length, _, _ = query_layer.shape
+
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+ key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ # concatenate along seq_length dimension:
+ # - key: [batch_size * self.num_heads, head_dim, kv_length]
+ # - value: [batch_size * self.num_heads, kv_length, head_dim]
+ key_layer = torch.cat((past_key, key_layer), dim=2)
+ value_layer = torch.cat((past_value, value_layer), dim=1)
+
+ _, _, kv_length = key_layer.shape
+
+ if use_cache is True:
+ present = (key_layer, value_layer)
+ else:
+ present = None
+
+ # [batch_size * num_heads, q_length, kv_length]
+ # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
+ matmul_result = alibi.baddbmm(
+ batch1=query_layer,
+ batch2=key_layer,
+ beta=self.beta,
+ alpha=self.inv_norm_factor,
+ )
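+        # baddbmm computes beta * alibi + alpha * (query_layer @ key_layer), so the ALiBi
+        # positional bias is folded into the same batched matmul that produces the raw scores.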
+
+ # change view to [batch_size, num_heads, q_length, kv_length]
+ attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
+
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+ input_dtype = attention_scores.dtype
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+ if input_dtype == torch.float16:
+ attention_scores = attention_scores.to(torch.float)
+ attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
+ attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
+
+ # [batch_size, num_heads, q_length, kv_length]
+ attention_probs = self.attention_dropout(attention_probs)
+
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ # change view [batch_size x num_heads, q_length, kv_length]
+ attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
+
+ # matmul: [batch_size * num_heads, q_length, head_dim]
+ context_layer = torch.bmm(attention_probs_reshaped, value_layer)
+
+ # change view [batch_size, num_heads, q_length, head_dim]
+ context_layer = self._merge_heads(context_layer)
+
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ slices = self.hidden_size / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + F.linear(
+ context_layer[:, :, int(i * slices):int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+ output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+
+ outputs = (output_tensor, present)
+ if output_attentions:
+ outputs += (attention_probs,)
+
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_bloom_mlp_forward():
+
+ from transformers.models.bloom.modeling_bloom import BloomMLP
+
+ def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
+
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ intermediate_output = torch.zeros_like(residual)
+ slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
+ for i in range(self.pretraining_tp):
+ intermediate_output = intermediate_output + F.linear(
+ hidden_states[:, :, int(i * slices):int((i + 1) * slices)],
+ self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)],
+ )
+ else:
+ intermediate_output = self.dense_4h_to_h(hidden_states)
+ output = self.dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
+ return output
+
+ return forward
+
+
+def get_jit_fused_bloom_gelu_forward():
+
+ from transformers.models.bloom.modeling_bloom import BloomGelu
+
+ from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
+
+ def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor:
+ bias = torch.zeros_like(x)
+ if self.training:
+ return JitGeLUFunction.apply(x, bias)
+ else:
+ return self.bloom_gelu_forward(x, bias)
+
+ return forward
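+
+
+# Note: a zero bias tensor is passed above so that the fused bias+GeLU JIT kernel (which
+# computes gelu(x + bias)) effectively reduces to a plain GeLU; the fused kernel is only used
+# during training, and the non-fused path is kept for inference.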
+
+
+def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
+
+ from transformers import BloomModel
+
+ def forward(
+ self: BloomModel,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+ if deprecated_arguments.pop("position_ids", False) is not False:
+            # `position_ids` could have been `torch.Tensor` or `None`, so defaulting the pop to `False` lets us detect whether users explicitly passed `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if past_key_values is None:
+ past_key_values = tuple([None] * len(self.h))
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+ # Compute alibi tensor: check build_alibi_tensor documentation
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+ if past_key_values[0] is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
+ alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+
+ causal_mask = self._prepare_attn_mask(
+ attention_mask,
+ input_shape=(batch_size, seq_length),
+ past_key_values_length=past_key_values_length,
+ )
+ # split the input tensor along sequence dimension
+ # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+ hidden_states = split_forward_gather_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ alibi,
+ causal_mask,
+ layer_past,
+ head_mask[i],
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=causal_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+        # When sequence parallelism is done, gather the output tensor in the forward pass and split it in the backward pass
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+ # Add last hidden state
+ hidden_states = self.ln_f(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ return forward
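+
+
+# Shape sketch for the sequence-parallel path above (illustrative numbers): with a tensor
+# parallel size of 4 and seq_len = 2048, each rank runs the transformer blocks on a
+# [batch_size, 512, hidden_size] slice produced by `split_forward_gather_backward`, and
+# `gather_forward_split_backward` restores the full [batch_size, 2048, hidden_size] tensor
+# before the final layer norm.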
diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
new file mode 100644
index 000000000000..16dcf87c8cfc
--- /dev/null
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -0,0 +1,399 @@
+""" PyTorch ChatGLM model. """
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss, LayerNorm
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
+ ChatGLMForConditionalGeneration,
+ ChatGLMModel,
+ GLMBlock,
+)
+
+
+def get_flash_core_attention_forward():
+
+ from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+
+ from .chatglm2_6b.modeling_chatglm import CoreAttention
+
+ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
+ pytorch_major_version = int(torch.__version__.split(".")[0])
+ if pytorch_major_version >= 2:
+ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
+ key_layer,
+ value_layer,
+ is_causal=True)
+ else:
+ if attention_mask is not None:
+ attention_mask = ~attention_mask
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
+ attention_mask)
+ context_layer = context_layer.permute(2, 0, 1, 3)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+ context_layer = context_layer.reshape(*new_context_layer_shape)
+ else:
+ # Raw attention scores
+ query_layer = query_layer.permute(1, 0, 2, 3).contiguous()
+ key_layer = key_layer.permute(1, 0, 2, 3).contiguous()
+ value_layer = value_layer.permute(1, 0, 2, 3).contiguous()
+
+ scale = 1.0 / self.norm_factor
+ if self.coeff is not None:
+ scale = scale * self.coeff
+
+ flash_attention_mask = None
+ attn_mask_type = None
+ if attention_mask is None:
+ attn_mask_type = AttnMaskType.causal
+ else:
+ flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+ attn_mask_type = AttnMaskType.paddedcausal
+
+ attention = ColoAttention(embed_dim=self.hidden_size_per_partition,
+ num_heads=self.num_attention_heads_per_partition,
+ dropout=self.attention_dropout.p,
+ scale=scale)
+ context_layer = attention(query_layer,
+ key_layer,
+ value_layer,
+ attn_mask=flash_attention_mask,
+ attn_mask_type=attn_mask_type)
+
+ context_layer = context_layer.permute(1, 0, -1).contiguous()
+
+ return context_layer
+
+ return forward
+
+
+def get_jit_fused_glm_block_forward():
+
+ from .chatglm2_6b.modeling_chatglm import GLMBlock
+
+ def forward(
+ self: GLMBlock,
+ hidden_states,
+ attention_mask,
+ rotary_pos_emb,
+ kv_cache=None,
+ use_cache=True,
+ ):
+ # hidden_states: [s, b, h]
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+ # Self attention.
+ attention_output, kv_cache = self.self_attention(
+ layernorm_output,
+ attention_mask,
+ rotary_pos_emb,
+ kv_cache=kv_cache,
+ use_cache=use_cache,
+ )
+
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ layernorm_input = self.dropout_add(attention_output, residual, self.hidden_dropout, self.training)
+
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+ # MLP.
+ mlp_output = self.mlp(layernorm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+ output = self.dropout_add(mlp_output, residual, self.hidden_dropout, self.training)
+
+ return output, kv_cache
+
+ return forward
+
+
+class ChatGLMPipelineForwards:
+ '''
+ This class serves as a micro library for ChatGLM model forwards under pipeline parallelism.
+ '''
+
+ @staticmethod
+ def chatglm_model_forward(
+ self: ChatGLMModel,
+ input_ids,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ full_attention_mask: Optional[torch.BoolTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ logger = logging.get_logger(__name__)
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        # TODO(jianghai): the recorded kv-cache tensors are left as () or None for now; this feature may be added in the future.
+ if past_key_values:
+ logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+ past_key_values = None
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+ use_cache = False
+ if stage_manager.is_first_stage():
+ batch_size, seq_length = input_ids.shape
+ if inputs_embeds is None:
+ inputs_embeds = self.embedding(input_ids)
+ hidden_states = inputs_embeds
+ else:
+ seq_length, batch_size = hidden_states.shape[:2]
+ if self.pre_seq_len is not None:
+ if past_key_values is None:
+ past_key_values = self.get_prompt(batch_size=batch_size,
+ device=input_ids.device,
+ dtype=inputs_embeds.dtype)
+ if attention_mask is not None:
+ attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask],
+ dim=-1)
+ if full_attention_mask is None:
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+ # Rotary positional embeddings
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+ if position_ids is not None:
+ rotary_pos_emb = rotary_pos_emb[position_ids]
+ else:
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+ if not past_key_values:
+ past_key_values = [None for _ in range(self.num_layers)]
+ presents = () if use_cache else None
+ if self.encoder.gradient_checkpointing and self.encoder.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+ all_self_attentions = None
+ all_hidden_states = () if output_hidden_states else None
+ start_idx, end_idx = stage_index[0], stage_index[1]
+
+ if shard_config.enable_sequence_parallelism:
+ hidden_states = split_forward_gather_backward(hidden_states,
+ dim=0,
+ process_group=shard_config.tensor_parallel_process_group)
+ for idx in range(start_idx, end_idx):
+ layer = self.encoder._get_layer(idx)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ if self.encoder.gradient_checkpointing and self.encoder.training:
+ layer_ret = torch.utils.checkpoint.checkpoint(layer, hidden_states, attention_mask, rotary_pos_emb,
+ past_key_values[idx], use_cache)
+ else:
+ layer_ret = layer(hidden_states,
+ full_attention_mask,
+ rotary_pos_emb,
+ kv_cache=past_key_values[idx],
+ use_cache=use_cache)
+ hidden_states, kv_cache = layer_ret
+ if use_cache:
+ presents = presents + (kv_cache,)
+
+ if shard_config.enable_sequence_parallelism:
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=0,
+ process_group=shard_config.tensor_parallel_process_group)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+ if stage_manager.is_last_stage():
+ # final layer_norm
+ if self.encoder.post_layer_norm:
+ hidden_states = self.encoder.final_layernorm(hidden_states)
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+ else:
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ return_last_logit: Optional[bool] = False,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None):
+ logger = logging.get_logger(__name__)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+ transformer_outputs = ChatGLMPipelineForwards.chatglm_model_forward(
+ self.transformer,
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ )
+ if stage_manager.is_last_stage():
+ hidden_states = transformer_outputs[0]
+ if return_last_logit:
+ hidden_states = hidden_states[-1:]
+ lm_logits = self.transformer.output_layer(hidden_states)
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
+ loss = None
+ if labels is not None:
+ lm_logits = lm_logits.to(torch.float32)
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+ lm_logits = lm_logits.to(hidden_states.dtype)
+ loss = loss.to(hidden_states.dtype)
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+ else:
+ return transformer_outputs
+
+
+def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
+
+ def forward(
+ self,
+ input_ids,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ full_attention_mask: Optional[torch.BoolTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+
+ batch_size, seq_length = input_ids.shape
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embedding(input_ids)
+
+ if self.pre_seq_len is not None:
+ if past_key_values is None:
+ past_key_values = self.get_prompt(
+ batch_size=batch_size,
+ device=input_ids.device,
+ dtype=inputs_embeds.dtype,
+ )
+ if attention_mask is not None:
+ attention_mask = torch.cat(
+ [
+ attention_mask.new_ones((batch_size, self.pre_seq_len)),
+ attention_mask,
+ ],
+ dim=-1,
+ )
+
+ if full_attention_mask is None:
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+
+ # Rotary positional embeddings
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+ if position_ids is not None:
+ rotary_pos_emb = rotary_pos_emb[position_ids]
+ else:
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+
+ # Run encoder.
+ # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
+ inputs_embeds = split_forward_gather_backward(inputs_embeds,
+ dim=0,
+ process_group=shard_config.tensor_parallel_process_group)
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+ inputs_embeds,
+ full_attention_mask,
+ rotary_pos_emb=rotary_pos_emb,
+ kv_caches=past_key_values,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ )
+
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=0,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ if not return_dict:
+ return tuple(v for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ ] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ return forward
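+
+
+# Shape sketch (illustrative): ChatGLM uses a sequence-first layout, so with a tensor parallel
+# size of 4 and seq_len = 2048 each rank feeds the encoder a [512, batch_size, hidden_size]
+# slice (split along dim=0), and the gather afterwards restores [2048, batch_size, hidden_size].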
diff --git a/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py
new file mode 100644
index 000000000000..3e78732be2da
--- /dev/null
+++ b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py
@@ -0,0 +1,58 @@
+from transformers import PretrainedConfig
+
+
+class ChatGLMConfig(PretrainedConfig):
+ model_type = "chatglm"
+
+ def __init__(self,
+ num_layers=28,
+ padded_vocab_size=65024,
+ hidden_size=4096,
+ ffn_hidden_size=13696,
+ kv_channels=128,
+ num_attention_heads=32,
+ seq_length=2048,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ layernorm_epsilon=1e-5,
+ rmsnorm=True,
+ apply_residual_connection_post_layernorm=False,
+ post_layer_norm=True,
+ add_bias_linear=False,
+ add_qkv_bias=False,
+ bias_dropout_fusion=True,
+ multi_query_attention=False,
+ multi_query_group_num=1,
+ apply_query_key_layer_scaling=True,
+ attention_softmax_in_fp32=True,
+ fp32_residual_connection=False,
+ quantization_bit=0,
+ pre_seq_len=None,
+ prefix_projection=False,
+ **kwargs):
+ self.num_layers = num_layers
+ self.vocab_size = padded_vocab_size
+ self.padded_vocab_size = padded_vocab_size
+ self.hidden_size = hidden_size
+ self.ffn_hidden_size = ffn_hidden_size
+ self.kv_channels = kv_channels
+ self.num_attention_heads = num_attention_heads
+ self.seq_length = seq_length
+ self.hidden_dropout = hidden_dropout
+ self.attention_dropout = attention_dropout
+ self.layernorm_epsilon = layernorm_epsilon
+ self.rmsnorm = rmsnorm
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+ self.post_layer_norm = post_layer_norm
+ self.add_bias_linear = add_bias_linear
+ self.add_qkv_bias = add_qkv_bias
+ self.bias_dropout_fusion = bias_dropout_fusion
+ self.multi_query_attention = multi_query_attention
+ self.multi_query_group_num = multi_query_group_num
+ self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+ self.fp32_residual_connection = fp32_residual_connection
+ self.quantization_bit = quantization_bit
+ self.pre_seq_len = pre_seq_len
+ self.prefix_projection = prefix_projection
+ super().__init__(**kwargs)
diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
new file mode 100644
index 000000000000..a21ee0231422
--- /dev/null
+++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
@@ -0,0 +1,1373 @@
+"""
+The ChatGLM2-6B License
+
+1. Definitions
+
+“Licensor” means the ChatGLM2-6B Model Team that distributes its Software.
+
+“Software” means the ChatGLM2-6B model parameters made available under this license.
+
+2. License Grant
+
+Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes.
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+3. Restriction
+
+You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes.
+
+You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
+
+4. Disclaimer
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+5. Limitation of Liability
+
+EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
+
+6. Dispute Resolution
+
+This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
+
+Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
+"""
+""" PyTorch ChatGLM model. """
+
+import copy
+import math
+import re
+import sys
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn.utils import skip_init
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+from .configuration_chatglm import ChatGLMConfig
+
+# flags required to enable jit fusion kernels
+
+if sys.platform != "darwin":
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._jit_set_profiling_executor(False)
+ torch._C._jit_override_can_fuse_on_cpu(True)
+ torch._C._jit_override_can_fuse_on_gpu(True)
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B"
+_CONFIG_FOR_DOC = "ChatGLM6BConfig"
+
+CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "THUDM/chatglm2-6b",
+ # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
+]
+
+
+def default_init(cls, *args, **kwargs):
+ return cls(*args, **kwargs)
+
+
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
+ scores.zero_()
+ scores[..., 5] = 5e4
+ return scores
+
+
+class PrefixEncoder(torch.nn.Module):
+ """
+ The torch.nn model to encode the prefix
+ Input shape: (batch-size, prefix-length)
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
+ """
+
+ def __init__(self, config: ChatGLMConfig):
+ super().__init__()
+ self.prefix_projection = config.prefix_projection
+ if self.prefix_projection:
+ # Use a two-layer MLP to encode the prefix
+ kv_size = (config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
+ self.trans = torch.nn.Sequential(
+ torch.nn.Linear(kv_size, config.hidden_size),
+ torch.nn.Tanh(),
+ torch.nn.Linear(config.hidden_size, kv_size),
+ )
+ else:
+ self.embedding = torch.nn.Embedding(
+ config.pre_seq_len,
+ config.num_layers * config.kv_channels * config.multi_query_group_num * 2,
+ )
+
+ def forward(self, prefix: torch.Tensor):
+ if self.prefix_projection:
+ prefix_tokens = self.embedding(prefix)
+ past_key_values = self.trans(prefix_tokens)
+ else:
+ past_key_values = self.embedding(prefix)
+ return past_key_values
+
+
+def split_tensor_along_last_dim(
+ tensor: torch.Tensor,
+ num_partitions: int,
+ contiguous_split_chunks: bool = False,
+) -> List[torch.Tensor]:
+ """Split a tensor along its last dimension.
+
+ Arguments:
+ tensor: input tensor.
+ num_partitions: number of partitions to split the tensor
+ contiguous_split_chunks: If True, make each chunk contiguous
+ in memory.
+
+ Returns:
+ A list of Tensors
+ """
+ # Get the size and dimension.
+ last_dim = tensor.dim() - 1
+ last_dim_size = tensor.size()[last_dim] // num_partitions
+ # Split.
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+ # Note: torch.split does not create contiguous tensors by default.
+ if contiguous_split_chunks:
+ return tuple(chunk.contiguous() for chunk in tensor_list)
+
+ return tensor_list
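+
+
+# Example (illustrative): splitting a fused QKV tensor of shape [sq, b, 3 * hn] with
+# num_partitions=3 yields three [sq, b, hn] views; note that torch.split returns
+# non-contiguous views unless contiguous_split_chunks=True is requested.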
+
+
+class RotaryEmbedding(nn.Module):
+
+ def __init__(self, dim, original_impl=False, device=None, dtype=None):
+ super().__init__()
+ inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.dim = dim
+ self.original_impl = original_impl
+
+ def forward_impl(
+ self,
+ seq_len: int,
+ n_elem: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ base: int = 10000,
+ ):
+ """Enhanced Transformer with Rotary Position Embedding.
+
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
+ transformers/rope/__init__.py. MIT License:
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
+ """
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+ theta = 1.0 / (base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
+
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
+ seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
+
+ # Calculate the product of position index and $\theta_i$
+ idx_theta = torch.outer(seq_idx, theta).float()
+
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
+
+ # this is to mimic the behaviour of complex32, else we will get different results
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
+ return cache
+
+ def forward(self, max_seq_len, offset=0):
+ return self.forward_impl(
+ max_seq_len,
+ self.dim,
+ dtype=self.inv_freq.dtype,
+ device=self.inv_freq.device,
+ )
+
+
+@torch.jit.script
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+ # x: [sq, b, np, hn]
+ sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
+ rot_dim = rope_cache.shape[-2] * 2
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
+ # truncate to support variable sizes
+ rope_cache = rope_cache[:sq]
+ xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
+ rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
+ ],
+ -1,
+ )
+ x_out2 = x_out2.flatten(3)
+ return torch.cat((x_out2, x_pass), dim=-1)
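+
+
+# In effect, each consecutive channel pair (x0, x1) within the first `rot_dim` channels is
+# rotated by the cached angle for its position: (x0 * cos - x1 * sin, x1 * cos + x0 * sin),
+# while the remaining channels (`x_pass`) are passed through unchanged.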
+
+
+class RMSNorm(torch.nn.Module):
+
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
+ super().__init__()
+ self.elementwise_affine = True
+ self.normalized_shape = normalized_shape
+ self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
+ self.eps = eps
+
+ def forward(self, hidden_states: torch.Tensor):
+ input_dtype = hidden_states.dtype
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+ return (self.weight * hidden_states).to(input_dtype)
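+
+
+# RMSNorm computes y = weight * x / sqrt(mean(x ** 2) + eps); the mean of squares is taken over
+# the last dimension in float32, and the result is cast back to the input dtype.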
+
+
+class CoreAttention(torch.nn.Module):
+
+ def __init__(self, config: ChatGLMConfig, layer_number):
+ super(CoreAttention, self).__init__()
+
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
+ if self.apply_query_key_layer_scaling:
+ self.attention_softmax_in_fp32 = True
+ self.layer_number = max(1, layer_number)
+
+ projection_size = config.kv_channels * config.num_attention_heads
+
+ # Per attention head and per partition values.
+ self.hidden_size_per_partition = projection_size
+ self.hidden_size_per_attention_head = (projection_size // config.num_attention_heads)
+ self.num_attention_heads_per_partition = config.num_attention_heads
+
+ coeff = None
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+ if self.apply_query_key_layer_scaling:
+ coeff = self.layer_number
+ self.norm_factor *= coeff
+ self.coeff = coeff
+
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
+
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
+ pytorch_major_version = int(torch.__version__.split(".")[0])
+ if pytorch_major_version >= 2:
+ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
+ key_layer,
+ value_layer,
+ is_causal=True)
+ else:
+ if attention_mask is not None:
+ attention_mask = ~attention_mask
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
+ attention_mask)
+ context_layer = context_layer.permute(2, 0, 1, 3)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+ context_layer = context_layer.reshape(*new_context_layer_shape)
+ else:
+ # Raw attention scores
+
+ # [b, np, sq, sk]
+ output_size = (
+ query_layer.size(1),
+ query_layer.size(2),
+ query_layer.size(0),
+ key_layer.size(0),
+ )
+
+ # [sq, b, np, hn] -> [sq, b * np, hn]
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
+ # [sk, b, np, hn] -> [sk, b * np, hn]
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
+
+            # preallocating input tensor: [b * np, sq, sk]
+ matmul_input_buffer = torch.empty(
+ output_size[0] * output_size[1],
+ output_size[2],
+ output_size[3],
+ dtype=query_layer.dtype,
+ device=query_layer.device,
+ )
+
+ # Raw attention scores. [b * np, sq, sk]
+ matmul_result = torch.baddbmm(
+ matmul_input_buffer,
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
+ beta=0.0,
+ alpha=(1.0 / self.norm_factor),
+ )
+
+ # change view to [b, np, sq, sk]
+ attention_scores = matmul_result.view(*output_size)
+
+ # ===========================
+ # Attention probs and dropout
+ # ===========================
+
+ # attention scores and attention mask [b, np, sq, sk]
+ if self.attention_softmax_in_fp32:
+ attention_scores = attention_scores.float()
+ if self.coeff is not None:
+ attention_scores = attention_scores * self.coeff
+ if (attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]):
+ attention_mask = torch.ones(
+ output_size[0],
+ 1,
+ output_size[2],
+ output_size[3],
+ device=attention_scores.device,
+ dtype=torch.bool,
+ )
+ attention_mask.tril_()
+ attention_mask = ~attention_mask
+ if attention_mask is not None:
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
+ attention_probs = F.softmax(attention_scores, dim=-1)
+ attention_probs = attention_probs.type_as(value_layer)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.attention_dropout(attention_probs)
+ # =========================
+ # Context layer. [sq, b, hp]
+ # =========================
+
+ # value_layer -> context layer.
+ # [sk, b, np, hn] --> [b, np, sq, hn]
+
+ # context layer shape: [b, np, sq, hn]
+ output_size = (
+ value_layer.size(1),
+ value_layer.size(2),
+ query_layer.size(0),
+ value_layer.size(3),
+ )
+ # change view [sk, b * np, hn]
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
+ # change view [b * np, sq, sk]
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+ # matmul: [b * np, sq, hn]
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+ # change view [b, np, sq, hn]
+ context_layer = context_layer.view(*output_size)
+ # [b, np, sq, hn] --> [sq, b, np, hn]
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+ # [sq, b, np, hn] --> [sq, b, hp]
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ return context_layer
+
+
+class SelfAttention(torch.nn.Module):
+ """Parallel self-attention layer abstract class.
+
+ Self-attention layer takes input with size [s, b, h]
+ and returns output of the same size.
+ """
+
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
+ super(SelfAttention, self).__init__()
+ self.layer_number = max(1, layer_number)
+
+ self.projection_size = config.kv_channels * config.num_attention_heads
+ # Per attention head and per partition values.
+ self.hidden_size_per_attention_head = (self.projection_size // config.num_attention_heads)
+ self.num_attention_heads_per_partition = config.num_attention_heads
+
+ self.multi_query_attention = config.multi_query_attention
+ self.qkv_hidden_size = 3 * self.projection_size
+ if self.multi_query_attention:
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
+ self.qkv_hidden_size = (self.projection_size +
+ 2 * self.hidden_size_per_attention_head * config.multi_query_group_num)
+ self.query_key_value = nn.Linear(
+ config.hidden_size,
+ self.qkv_hidden_size,
+ bias=config.add_bias_linear or config.add_qkv_bias,
+ device=device,
+ **_config_to_kwargs(config),
+ )
+
+ self.core_attention = CoreAttention(config, self.layer_number)
+
+ # Output.
+ self.dense = nn.Linear(
+ self.projection_size,
+ config.hidden_size,
+ bias=config.add_bias_linear,
+ device=device,
+ **_config_to_kwargs(config),
+ )
+
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
+ if self.multi_query_attention:
+ num_attention_heads = self.num_multi_query_groups_per_partition
+ else:
+ num_attention_heads = self.num_attention_heads_per_partition
+ return torch.empty(
+ inference_max_sequence_len,
+ batch_size,
+ num_attention_heads,
+ self.hidden_size_per_attention_head,
+ dtype=dtype,
+ device=device,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ rotary_pos_emb,
+ kv_cache=None,
+ use_cache=True,
+ ):
+ # hidden_states: [sq, b, h]
+
+ # =================================================
+ # Pre-allocate memory for key-values for inference.
+ # =================================================
+ # =====================
+ # Query, Key, and Value
+ # =====================
+
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+ mixed_x_layer = self.query_key_value(hidden_states)
+
+ if self.multi_query_attention:
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
+ [
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+ ],
+ dim=-1,
+ )
+ query_layer = query_layer.view(query_layer.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ ))
+ key_layer = key_layer.view(key_layer.size()[:-1] + (
+ self.num_multi_query_groups_per_partition,
+ self.hidden_size_per_attention_head,
+ ))
+ value_layer = value_layer.view(value_layer.size()[:-1] + (
+ self.num_multi_query_groups_per_partition,
+ self.hidden_size_per_attention_head,
+ ))
+ else:
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head,
+ )
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+ # apply relative positional encoding (rotary embedding)
+ if rotary_pos_emb is not None:
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
+
+ # adjust key and value for inference
+ if kv_cache is not None:
+ cache_k, cache_v = kv_cache
+ key_layer = torch.cat((cache_k, key_layer), dim=0)
+ value_layer = torch.cat((cache_v, value_layer), dim=0)
+ if use_cache:
+ kv_cache = (key_layer, value_layer)
+ else:
+ kv_cache = None
+
+ if self.multi_query_attention:
+ key_layer = key_layer.unsqueeze(-2)
+ key_layer = key_layer.expand(
+ -1,
+ -1,
+ -1,
+ self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
+ -1,
+ )
+ key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ ))
+ value_layer = value_layer.unsqueeze(-2)
+ value_layer = value_layer.expand(
+ -1,
+ -1,
+ -1,
+ self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
+ -1,
+ )
+ value_layer = value_layer.contiguous().view(value_layer.size()[:2] + (
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ ))
+
+ # ==================================
+ # core attention computation
+ # ==================================
+
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
+
+ # =================
+ # Output. [sq, b, h]
+ # =================
+
+ output = self.dense(context_layer)
+
+ return output, kv_cache
+
+
+def _config_to_kwargs(args):
+ common_kwargs = {
+ "dtype": args.torch_dtype,
+ }
+ return common_kwargs
+
+
+class MLP(torch.nn.Module):
+ """MLP.
+
+ MLP will take the input with h hidden state, project it to 4*h
+ hidden dimension, perform nonlinear transformation, and project the
+ state back into h hidden dimension.
+ """
+
+ def __init__(self, config: ChatGLMConfig, device=None):
+ super(MLP, self).__init__()
+
+ self.add_bias = config.add_bias_linear
+
+        # Project to 4h. If using swiglu, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
+ self.dense_h_to_4h = nn.Linear(
+ config.hidden_size,
+ config.ffn_hidden_size * 2,
+ bias=self.add_bias,
+ device=device,
+ **_config_to_kwargs(config),
+ )
+
+ def swiglu(x):
+ x = torch.chunk(x, 2, dim=-1)
+ return F.silu(x[0]) * x[1]
+
+ self.activation_func = swiglu
+
+ # Project back to h.
+ self.dense_4h_to_h = nn.Linear(
+ config.ffn_hidden_size,
+ config.hidden_size,
+ bias=self.add_bias,
+ device=device,
+ **_config_to_kwargs(config),
+ )
+
+ def forward(self, hidden_states):
+ # [s, b, 4hp]
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
+ intermediate_parallel = self.activation_func(intermediate_parallel)
+ # [s, b, h]
+ output = self.dense_4h_to_h(intermediate_parallel)
+ return output
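+
+
+# SwiGLU sketch: `dense_h_to_4h` produces 2 * ffn_hidden_size channels, which `swiglu` chunks
+# into (a, b) and combines as silu(a) * b, following the GLU-variants paper linked above.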
+
+
+class GLMBlock(torch.nn.Module):
+ """A single transformer layer.
+
+ Transformer layer takes input with size [s, b, h] and returns an
+ output of the same size.
+ """
+
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
+ super(GLMBlock, self).__init__()
+ self.layer_number = layer_number
+
+ self.apply_residual_connection_post_layernorm = (config.apply_residual_connection_post_layernorm)
+
+ self.fp32_residual_connection = config.fp32_residual_connection
+
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+ # Layernorm on the input data.
+ self.input_layernorm = LayerNormFunc(
+ config.hidden_size,
+ eps=config.layernorm_epsilon,
+ device=device,
+ dtype=config.torch_dtype,
+ )
+
+ # Self attention.
+ self.self_attention = SelfAttention(config, layer_number, device=device)
+ self.hidden_dropout = config.hidden_dropout
+
+ # Layernorm on the attention output
+ self.post_attention_layernorm = LayerNormFunc(
+ config.hidden_size,
+ eps=config.layernorm_epsilon,
+ device=device,
+ dtype=config.torch_dtype,
+ )
+
+ # MLP
+ self.mlp = MLP(config, device=device)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ rotary_pos_emb,
+ kv_cache=None,
+ use_cache=True,
+ ):
+ # hidden_states: [s, b, h]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+ # Self attention.
+ attention_output, kv_cache = self.self_attention(
+ layernorm_output,
+ attention_mask,
+ rotary_pos_emb,
+ kv_cache=kv_cache,
+ use_cache=use_cache,
+ )
+
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
+ layernorm_input = residual + layernorm_input
+
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+ # MLP.
+ mlp_output = self.mlp(layernorm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
+ output = residual + output
+
+ return output, kv_cache
+
+
+class GLMTransformer(torch.nn.Module):
+ """Transformer class."""
+
+ def __init__(self, config: ChatGLMConfig, device=None):
+ super(GLMTransformer, self).__init__()
+
+ self.fp32_residual_connection = config.fp32_residual_connection
+ self.post_layer_norm = config.post_layer_norm
+
+ # Number of layers.
+ self.num_layers = config.num_layers
+
+ # Transformer layers.
+ def build_layer(layer_number):
+ return GLMBlock(config, layer_number, device=device)
+
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
+
+ if self.post_layer_norm:
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+ # Final layer norm before output.
+ self.final_layernorm = LayerNormFunc(
+ config.hidden_size,
+ eps=config.layernorm_epsilon,
+ device=device,
+ dtype=config.torch_dtype,
+ )
+
+ self.gradient_checkpointing = False
+
+ def _get_layer(self, layer_number):
+ return self.layers[layer_number]
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ rotary_pos_emb,
+ kv_caches=None,
+ use_cache: Optional[bool] = True,
+ output_hidden_states: Optional[bool] = False,
+ ):
+ if not kv_caches:
+ kv_caches = [None for _ in range(self.num_layers)]
+ presents = () if use_cache else None
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+ all_self_attentions = None
+ all_hidden_states = () if output_hidden_states else None
+ for index in range(self.num_layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer = self._get_layer(index)
+ if self.gradient_checkpointing and self.training:
+ layer_ret = torch.utils.checkpoint.checkpoint(
+ layer,
+ hidden_states,
+ attention_mask,
+ rotary_pos_emb,
+ kv_caches[index],
+ use_cache,
+ )
+ else:
+ layer_ret = layer(
+ hidden_states,
+ attention_mask,
+ rotary_pos_emb,
+ kv_cache=kv_caches[index],
+ use_cache=use_cache,
+ )
+ hidden_states, kv_cache = layer_ret
+ if use_cache:
+ presents = presents + (kv_cache,)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # Final layer norm.
+ if self.post_layer_norm:
+ hidden_states = self.final_layernorm(hidden_states)
+
+ return hidden_states, presents, all_hidden_states, all_self_attentions
+
+
+class ChatGLMPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and
+ a simple interface for downloading and loading pretrained models.
+ """
+
+ is_parallelizable = False
+ supports_gradient_checkpointing = True
+ config_class = ChatGLMConfig
+ base_model_prefix = "transformer"
+ _no_split_modules = ["GLMBlock"]
+
+ def _init_weights(self, module: nn.Module):
+ """Initialize the weights."""
+ return
+
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
+ batch_size, seq_length = input_ids.shape
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
+ full_attention_mask.tril_()
+ past_length = 0
+ if past_key_values:
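+ # The kv caches in this model are sequence-first ([seq, batch, heads, head_dim]), so dim 0 is the number of cached tokens.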
+ past_length = past_key_values[0][0].shape[0]
+ if past_length:
+ full_attention_mask = torch.cat(
+ (
+ torch.ones(batch_size, seq_length, past_length, device=input_ids.device),
+ full_attention_mask,
+ ),
+ dim=-1,
+ )
+ if padding_mask is not None:
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+ if not past_length and padding_mask is not None:
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
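+ # Convert to a boolean mask where True marks positions that must not be attended to.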
+ full_attention_mask = (full_attention_mask < 0.5).bool()
+ full_attention_mask.unsqueeze_(1)
+ return full_attention_mask
+
+ def get_position_ids(self, input_ids, device):
+ batch_size, seq_length = input_ids.shape
+ position_ids = (torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1))
+ return position_ids
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, GLMTransformer):
+ module.gradient_checkpointing = value
+
+
+class Embedding(torch.nn.Module):
+ """Language model embeddings."""
+
+ def __init__(self, config: ChatGLMConfig, device=None):
+ super(Embedding, self).__init__()
+
+ self.hidden_size = config.hidden_size
+ # Word embeddings (parallel).
+ self.word_embeddings = nn.Embedding(
+ config.padded_vocab_size,
+ self.hidden_size,
+ dtype=config.torch_dtype,
+ device=device,
+ )
+ self.fp32_residual_connection = config.fp32_residual_connection
+
+ def forward(self, input_ids):
+ # Embeddings.
+ words_embeddings = self.word_embeddings(input_ids)
+ embeddings = words_embeddings
+ # Data format change to avoid explicit transposes : [b s h] --> [s b h].
+ embeddings = embeddings.transpose(0, 1).contiguous()
+ # If the input flag for fp32 residual connection is set, convert to float.
+ if self.fp32_residual_connection:
+ embeddings = embeddings.float()
+ return embeddings
+
+
+class ChatGLMModel(ChatGLMPreTrainedModel):
+
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
+ super().__init__(config)
+ if empty_init:
+ init_method = skip_init
+ else:
+ init_method = default_init
+ init_kwargs = {}
+ if device is not None:
+ init_kwargs["device"] = device
+ self.embedding = init_method(Embedding, config, **init_kwargs)
+ self.num_layers = config.num_layers
+ self.multi_query_group_num = config.multi_query_group_num
+ self.kv_channels = config.kv_channels
+
+ # Rotary positional embeddings
+ self.seq_length = config.seq_length
+ rotary_dim = (config.hidden_size //
+ config.num_attention_heads if config.kv_channels is None else config.kv_channels)
+
+ self.rotary_pos_emb = RotaryEmbedding(
+ rotary_dim // 2,
+ original_impl=config.original_rope,
+ device=device,
+ dtype=config.torch_dtype,
+ )
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
+ self.output_layer = init_method(
+ nn.Linear,
+ config.hidden_size,
+ config.padded_vocab_size,
+ bias=False,
+ dtype=config.torch_dtype,
+ **init_kwargs,
+ )
+ self.pre_seq_len = config.pre_seq_len
+ self.prefix_projection = config.prefix_projection
+ if self.pre_seq_len is not None:
+ for param in self.parameters():
+ param.requires_grad = False
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+ self.dropout = torch.nn.Dropout(0.1)
+
+ def get_input_embeddings(self):
+ return self.embedding.word_embeddings
+
+ def get_prompt(self, batch_size, device, dtype=torch.half):
+ prefix_tokens = (self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device))
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.num_layers * 2,
+ self.multi_query_group_num,
+ self.kv_channels,
+ )
+ # seq_len, b, nh, hidden_size
+ past_key_values = self.dropout(past_key_values)
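+ # Rearrange into per-layer prefix caches: a tuple of num_layers tensors, each [2, pre_seq_len, batch, multi_query_group_num, kv_channels].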
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ full_attention_mask: Optional[torch.BoolTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+
+ batch_size, seq_length = input_ids.shape
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embedding(input_ids)
+
+ if self.pre_seq_len is not None:
+ if past_key_values is None:
+ past_key_values = self.get_prompt(
+ batch_size=batch_size,
+ device=input_ids.device,
+ dtype=inputs_embeds.dtype,
+ )
+ if attention_mask is not None:
+ attention_mask = torch.cat(
+ [
+ attention_mask.new_ones((batch_size, self.pre_seq_len)),
+ attention_mask,
+ ],
+ dim=-1,
+ )
+
+ if full_attention_mask is None:
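+ # A full (4D) mask is only needed when the input contains padding, or when a cache is used with more than one new token.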
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+
+ # Rotary positional embeddings
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+ if position_ids is not None:
+ rotary_pos_emb = rotary_pos_emb[position_ids]
+ else:
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+
+ # Run encoder.
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+ inputs_embeds,
+ full_attention_mask,
+ rotary_pos_emb=rotary_pos_emb,
+ kv_caches=past_key_values,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ )
+
+ if not return_dict:
+ return tuple(v for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ ] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ def quantize(self, weight_bit_width: int):
+ from .quantization import quantize
+
+ quantize(self.encoder, weight_bit_width)
+ return self
+
+
+class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
+
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
+ super().__init__(config)
+
+ self.max_sequence_length = config.max_length
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+ self.config = config
+ self.quantized = False
+
+ if self.config.quantization_bit:
+ self.quantize(self.config.quantization_bit, empty_init=True)
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: Dict[str, Any],
+ is_encoder_decoder: bool = False,
+ standardize_cache_format: bool = False,
+ ) -> Dict[str, Any]:
+ # update past_key_values
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+ outputs, standardize_cache_format=standardize_cache_format)
+
+ # update attention mask
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = torch.cat(
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
+ dim=-1,
+ )
+
+ # update position ids
+ if "position_ids" in model_kwargs:
+ position_ids = model_kwargs["position_ids"]
+ new_position_id = position_ids[..., -1:].clone()
+ new_position_id += 1
+ model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
+
+ model_kwargs["is_first_forward"] = False
+ return model_kwargs
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ is_first_forward: bool = True,
+ **kwargs,
+ ) -> dict:
+ # only last token for input_ids if past is not None
+ if position_ids is None:
+ position_ids = self.get_position_ids(input_ids, device=input_ids.device)
+ if not is_first_forward:
+ position_ids = position_ids[..., -1:]
+ input_ids = input_ids[:, -1:]
+ return {
+ "input_ids": input_ids,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ "return_last_logit": True,
+ }
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ return_last_logit: Optional[bool] = False,
+ ):
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+
+ transformer_outputs = self.transformer(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = transformer_outputs[0]
+ if return_last_logit:
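+ # hidden_states is sequence-first ([seq, batch, hidden]); keep only the last position when just the final logit is needed.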
+ hidden_states = hidden_states[-1:]
+ lm_logits = self.transformer.output_layer(hidden_states)
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
+
+ loss = None
+ if labels is not None:
+ lm_logits = lm_logits.to(torch.float32)
+
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ lm_logits = lm_logits.to(hidden_states.dtype)
+ loss = loss.to(hidden_states.dtype)
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...],
+ beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+ """
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+ beam_idx at every generation step.
+
+ Output shares the same memory storage as `past`.
+ """
+ return tuple((
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
+ ) for layer_past in past)
+
+ def process_response(self, response):
+ response = response.strip()
+ response = response.replace("[[训练时间]]", "2023年")
+ return response
+
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
+ prompt = tokenizer.build_prompt(query, history=history)
+ inputs = tokenizer([prompt], return_tensors="pt")
+ inputs = inputs.to(self.device)
+ return inputs
+
+ def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
+ if history:
+ prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
+ input_ids = tokenizer.encode(prompt, add_special_tokens=False)
+ input_ids = input_ids[1:]
+ inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
+ else:
+ prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
+ inputs = tokenizer([prompt], return_tensors="pt")
+ inputs = inputs.to(self.device)
+ return inputs
+
+ @torch.no_grad()
+ def chat(
+ self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = None,
+ max_length: int = 8192,
+ num_beams=1,
+ do_sample=True,
+ top_p=0.8,
+ temperature=0.8,
+ logits_processor=None,
+ **kwargs,
+ ):
+ if history is None:
+ history = []
+ if logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InvalidScoreLogitsProcessor())
+ gen_kwargs = {
+ "max_length": max_length,
+ "num_beams": num_beams,
+ "do_sample": do_sample,
+ "top_p": top_p,
+ "temperature": temperature,
+ "logits_processor": logits_processor,
+ **kwargs,
+ }
+ inputs = self.build_inputs(tokenizer, query, history=history)
+ outputs = self.generate(**inputs, **gen_kwargs)
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+ response = tokenizer.decode(outputs)
+ response = self.process_response(response)
+ history = history + [(query, response)]
+ return response, history
+
+ @torch.no_grad()
+ def stream_chat(
+ self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = None,
+ past_key_values=None,
+ max_length: int = 8192,
+ do_sample=True,
+ top_p=0.8,
+ temperature=0.8,
+ logits_processor=None,
+ return_past_key_values=False,
+ **kwargs,
+ ):
+ if history is None:
+ history = []
+ if logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ logits_processor.append(InvalidScoreLogitsProcessor())
+ gen_kwargs = {
+ "max_length": max_length,
+ "do_sample": do_sample,
+ "top_p": top_p,
+ "temperature": temperature,
+ "logits_processor": logits_processor,
+ **kwargs,
+ }
+ if past_key_values is None and not return_past_key_values:
+ inputs = self.build_inputs(tokenizer, query, history=history)
+ else:
+ inputs = self.build_stream_inputs(tokenizer, query, history=history)
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[0]
+ if self.transformer.pre_seq_len is not None:
+ past_length -= self.transformer.pre_seq_len
+ inputs.position_ids += past_length
+ attention_mask = inputs.attention_mask
+ attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
+ inputs["attention_mask"] = attention_mask
+ for outputs in self.stream_generate(
+ **inputs,
+ past_key_values=past_key_values,
+ return_past_key_values=return_past_key_values,
+ **gen_kwargs,
+ ):
+ if return_past_key_values:
+ outputs, past_key_values = outputs
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
+ response = tokenizer.decode(outputs)
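+ # "�" indicates an incomplete multi-byte character; wait for more tokens before emitting the response.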
+ if response and response[-1] != "�":
+ response = self.process_response(response)
+ new_history = history + [(query, response)]
+ if return_past_key_values:
+ yield response, new_history, past_key_values
+ else:
+ yield response, new_history
+
+ @torch.no_grad()
+ def stream_generate(
+ self,
+ input_ids,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+ return_past_key_values=False,
+ **kwargs,
+ ):
+ batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+
+ if generation_config is None:
+ generation_config = self.generation_config
+ generation_config = copy.deepcopy(generation_config)
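+ # update() consumes kwargs that match generation-config fields and returns the remaining ones as model kwargs.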
+ model_kwargs = generation_config.update(**kwargs)
+ bos_token_id, eos_token_id = (
+ generation_config.bos_token_id,
+ generation_config.eos_token_id,
+ )
+
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+
+ has_default_max_length = (kwargs.get("max_length") is None and generation_config.max_length is not None)
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ warnings.warn(
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif generation_config.max_new_tokens is not None:
+ generation_config.max_length = (generation_config.max_new_tokens + input_ids_seq_length)
+ if not has_default_max_length:
+ logger.warning(
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+ "Please refer to the documentation for more information. "
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+ UserWarning,
+ )
+
+ if input_ids_seq_length >= generation_config.max_length:
+ input_ids_string = ("decoder_input_ids" if self.config.is_encoder_decoder else "input_ids")
+ logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`.")
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = (logits_processor if logits_processor is not None else LogitsProcessorList())
+ stopping_criteria = (stopping_criteria if stopping_criteria is not None else StoppingCriteriaList())
+
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=input_ids,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ )
+
+ stopping_criteria = self._get_stopping_criteria(generation_config=generation_config,
+ stopping_criteria=stopping_criteria)
+ logits_warper = self._get_logits_warper(generation_config)
+
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+ scores = None
+ while True:
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+ # forward pass to get next token
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=False,
+ output_hidden_states=False,
+ )
+
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+
+ # sample
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
+ if generation_config.do_sample:
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ next_tokens = torch.argmax(probs, dim=-1)
+
+ # update generated ids, model inputs, and length for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ model_kwargs = self._update_model_kwargs_for_generation(outputs,
+ model_kwargs,
+ is_encoder_decoder=self.config.is_encoder_decoder)
+ unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
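+ # Stream the partial result: the full sequence generated so far is yielded after every decoding step.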
+ if return_past_key_values:
+ yield input_ids, outputs.past_key_values
+ else:
+ yield input_ids
+ # stop when each sentence is finished, or if we exceed the maximum length
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+ break
+
+ def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
+ if bits == 0:
+ return
+
+ from .quantization import quantize
+
+ if self.quantized:
+ logger.info("Already quantized.")
+ return self
+
+ self.quantized = True
+
+ self.config.quantization_bit = bits
+
+ self.transformer.encoder = quantize(
+ self.transformer.encoder,
+ bits,
+ empty_init=empty_init,
+ device=device,
+ **kwargs,
+ )
+ return self
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
new file mode 100644
index 000000000000..bc99be4cc391
--- /dev/null
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -0,0 +1,988 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.models.gpt2.modeling_gpt2 import (
+ GPT2DoubleHeadsModel,
+ GPT2DoubleHeadsModelOutput,
+ GPT2ForQuestionAnswering,
+ GPT2ForSequenceClassification,
+ GPT2ForTokenClassification,
+ GPT2LMHeadModel,
+ GPT2Model,
+)
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
+from colossalai.shardformer.shard import ShardConfig
+
+
+class GPT2PipelineForwards:
+ '''
+ This class serves as a micro library for forward function substitution of GPT2 models
+ under a pipeline-parallel setting.
+ '''
+
+ @staticmethod
+ def gpt2_model_forward(
+ self: GPT2Model,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+
+ # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
+ # Please refer to original code of transformers for more details.
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ logger = logging.get_logger(__name__)
+
+ # Preprocess passed in arguments
+ # TODO(baizhou): kv-value tensors are currently not recorded (left as () or None); this feature may be added in the future.
+ if past_key_values:
+ logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+ past_key_values = None
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+ use_cache = False
+
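+ # Only the first pipeline stage consumes input_ids / inputs_embeds; later stages receive hidden_states from the previous stage.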
+ if stage_manager.is_first_stage():
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+ else:
+ if hidden_states is None:
+ raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
+ input_shape = hidden_states.size()[:-1]
+ batch_size = input_shape[0]
+ device = hidden_states.device
+ hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
+
+ # GPT2Attention mask.
+ if attention_mask is not None:
+ if batch_size <= 0:
+ raise ValueError("batch_size has to be defined and > 0")
+ attention_mask = attention_mask.view(batch_size, -1)
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if stage_manager.is_first_stage():
+ if position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1])
+ else:
+ position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+ hidden_states = self.drop(hidden_states)
+
+ output_shape = input_shape + (hidden_states.size(-1),)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+
+ # split the input tensor along sequence dimension
+ # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+ if shard_config.enable_sequence_parallelism:
+ hidden_states = split_forward_gather_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ # Going through held blocks.
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ for i in range(start_idx, end_idx):
+ block = self.h[i]
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(hidden_states.device)
+ if isinstance(head_mask, torch.Tensor):
+ head_mask = head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache, output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ None,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=None,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+ # When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
+ if shard_config.enable_sequence_parallelism:
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ if stage_manager.is_last_stage():
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(output_shape)
+
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if stage_manager.is_last_stage():
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+ if v is not None)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+ else:
+ # always return dict for intermediate stage
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def gpt2_lmhead_model_forward(
+ self: GPT2LMHeadModel,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+
+ This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
+ Please refer to original code of transformers for more details.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config)
+
+ # If not at the last stage, return hidden_states as in GPT2Model
+ if not stage_manager.is_last_stage():
+ return {'hidden_states': outputs['hidden_states']}
+
+ hidden_states = outputs[0]
+ lm_logits = self.lm_head(hidden_states)
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(lm_logits.device)
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ @staticmethod
+ def gpt2_double_heads_model_forward(
+ self: GPT2DoubleHeadsModel,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ mc_token_ids: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ mc_labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
+ r"""
+ mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
+ Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
+ 1]`.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
+ `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
+ mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
+ where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
+
+ This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel.forward.
+ Please refer to original code of transformers for more details.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config)
+
+ # If not at the last stage, return hidden_states as in GPT2Model
+ if not stage_manager.is_last_stage():
+ return {'hidden_states': outputs['hidden_states']}
+
+ hidden_states = outputs[0]
+ lm_logits = self.lm_head(hidden_states)
+ mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
+
+ mc_loss = None
+ if mc_labels is not None:
+ loss_fct = CrossEntropyLoss()
+ mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
+ lm_loss = None
+ if labels is not None:
+ labels = labels.to(lm_logits.device)
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits, mc_logits) + outputs[1:]
+ if mc_loss is not None:
+ output = (mc_loss,) + output
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return GPT2DoubleHeadsModelOutput(
+ loss=lm_loss,
+ mc_loss=mc_loss,
+ logits=lm_logits,
+ mc_logits=mc_logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ @staticmethod
+ def gpt2_for_question_answering_forward(
+ self: GPT2ForQuestionAnswering,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+
+ # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward.
+ # Please refer to original code of transformers for more details.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config)
+
+ # If not at the last stage, return hidden_states as in GPT2Model
+ if not stage_manager.is_last_stage():
+ return {'hidden_states': outputs['hidden_states']}
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the split may add an extra dimension; squeeze it
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ @staticmethod
+ def gpt2_for_token_classification_forward(
+ self: GPT2ForTokenClassification,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None) -> Union[Dict, Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification.forward.
+ # Please refer to original code of transformers for more details.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config)
+
+ # If not at the last stage, return hidden_states as in GPT2Model
+ if not stage_manager.is_last_stage():
+ return {'hidden_states': outputs['hidden_states']}
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ @staticmethod
+ def gpt2_for_sequence_classification_forward(
+ self: GPT2ForSequenceClassification,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward.
+ # Please refer to original code of transformers for more details.
+ """
+ logger = logging.get_logger(__name__)
+
+ if input_ids is not None:
+ batch_size, _ = input_ids.shape[:2]
+ else:
+ batch_size, _ = hidden_states.shape[:2]
+ assert (self.config.pad_token_id is not None
+ or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined."
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config)
+
+ # If not at the last stage, return hidden_states as in GPT2Model
+ if not stage_manager.is_last_stage():
+ return {'hidden_states': outputs['hidden_states']}
+
+ hidden_states = outputs[0]
+ logits = self.score(hidden_states)
+
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+ else:
+ sequence_lengths = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`")
+
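+ # Pool by taking the logits at the last non-padding token of each sequence (or the last position when no pad token is defined).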
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+def get_gpt2_flash_attention_forward():
+
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
+
+ from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+
+ def split_heads(tensor, num_heads, attn_head_size):
+ """
+ Splits hidden_size dim into attn_head_size and num_heads
+ """
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+ tensor = tensor.view(new_shape)
+ return tensor
+
+ def forward(
+ self: GPT2Attention,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+
+ if encoder_hidden_states is not None:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")
+
+ query = self.q_attn(hidden_states)
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+ attention_mask = encoder_attention_mask
+ else:
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+ query = split_heads(query, self.num_heads, self.head_dim)
+ key = split_heads(key, self.num_heads, self.head_dim)
+ value = split_heads(value, self.num_heads, self.head_dim)
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ key = torch.cat((past_key, key), dim=1)
+ value = torch.cat((past_value, value), dim=1)
+
+ if use_cache is True:
+ present = (key, value)
+ else:
+ present = None
+
+ if not self.is_cross_attention:
+ attn_mask_type = AttnMaskType.causal
+ flash_attention_mask = None
+ if attention_mask is not None:
+ if attn_mask_type == AttnMaskType.causal:
+ attn_mask_type = AttnMaskType.paddedcausal
+ else:
+ attn_mask_type = AttnMaskType.padding
+ flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+
+ scale = value.size(-1)**-0.5
+ if self.scale_attn_by_inverse_layer_idx:
+ scale = scale * (1 / float(self.layer_idx + 1))
+
+ # use coloattention
+ attention = ColoAttention(embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ dropout=self.attn_dropout.p,
+ scale=scale)
+
+ attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
+
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+ outputs = (attn_output, present, None)
+
+ return outputs
+
+ return forward
+
+
+def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+ if position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1])
+
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = tuple([None] * len(self.h))
+ else:
+ past_length = past_key_values[0][0].size(-2)
+ if position_ids is None:
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+ # GPT2Attention mask.
+ if attention_mask is not None:
+ if batch_size <= 0:
+ raise ValueError("batch_size has to be defined and > 0")
+ attention_mask = attention_mask.view(batch_size, -1)
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds
+
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+
+ output_shape = input_shape + (hidden_states.size(-1),)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger = logging.get_logger(__name__)
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+
+ # split the input tensor along sequence dimension
+ # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
+ hidden_states = split_forward_gather_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
+ if layer_past is not None:
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(hidden_states.device)
+ if isinstance(head_mask, torch.Tensor):
+ head_mask = head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache, output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ None,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+        # When sequence parallelism is done, gather the output tensor in forward and split it in backward
+ hidden_states = gather_forward_split_backward(hidden_states,
+ dim=1,
+ process_group=shard_config.tensor_parallel_process_group)
+
+ hidden_states = self.ln_f(hidden_states)
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+ if v is not None)
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+ return forward
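The two factories above return drop-in `forward` replacements rather than patching anything themselves. Purely as an illustration (not part of this diff), the sketch below shows how such a factory-built forward could be bound onto an existing `GPT2Model` instance; the helper name and wiring are hypothetical, and the actual binding is handled by the shardformer policies.

```python
from types import MethodType

def bind_sequence_parallel_forward(gpt2_model, shard_config):
    # Hypothetical helper: attach the sequence-parallel forward produced by the
    # factory above to a concrete GPT2Model instance.
    gpt2_model.forward = MethodType(gpt2_sequence_parallel_forward_fn(shard_config), gpt2_model)
    return gpt2_model
```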
diff --git a/colossalai/shardformer/modeling/jit.py b/colossalai/shardformer/modeling/jit.py
new file mode 100644
index 000000000000..6434348ef823
--- /dev/null
+++ b/colossalai/shardformer/modeling/jit.py
@@ -0,0 +1,34 @@
+import torch
+
+
+def get_dropout_add_func():
+
+ from transformers.models.bloom.modeling_bloom import dropout_add
+
+ def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
+ return dropout_add(x, residual, prob, training)
+
+ return self_dropout_add
+
+
+def get_jit_fused_dropout_add_func():
+
+ from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train
+
+ def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
+ bias = torch.zeros_like(x)
+ if training:
+ return bias_dropout_add_fused_train(x, bias, residual, prob)
+ return bias_dropout_add_fused_inference(x, bias, residual, prob)
+
+ return self_dropout_add
+
+
+def get_jit_fused_gelu_forward_func():
+
+ from colossalai.kernel.jit.bias_gelu import bias_gelu
+
+ def bloom_gelu_forward(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
+ return bias_gelu(bias, x)
+
+ return bloom_gelu_forward
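For reference only (not part of the diff): the fused kernels returned above are expected to follow the plain-PyTorch semantics sketched below, where dropout is applied to the incoming branch before the residual add, and the GeLU uses the tanh approximation. This assumes PyTorch >= 1.12 for the `approximate` argument of `F.gelu`.

```python
import torch
import torch.nn.functional as F

def dropout_add_reference(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # Dropout is applied to x only, then the residual branch is added back.
    return residual + F.dropout(x, p=prob, training=training)

def bias_gelu_reference(bias: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Add the bias, then apply the tanh-approximated GeLU (as the fused kernel does).
    return F.gelu(x + bias, approximate="tanh")
```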
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
new file mode 100644
index 000000000000..ff622c306c59
--- /dev/null
+++ b/colossalai/shardformer/modeling/llama.py
@@ -0,0 +1,471 @@
+import warnings
+from typing import Callable, List, Optional, Tuple
+
+import torch
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+
+class LlamaPipelineForwards:
+ '''
+ This class serves as a micro library for forward function substitution of Llama models
+ under pipeline setting.
+ '''
+
+ @staticmethod
+ def llama_model_forward(
+ self: LlamaModel,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ):
+ logger = logging.get_logger(__name__)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if stage_manager.is_first_stage():
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = inputs_embeds
+ else:
+ input_shape = hidden_states.shape[:-1]
+ batch_size, seq_length = input_shape
+ device = hidden_states.device
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+ use_cache = False
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ position_ids = torch.arange(past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device)
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # embed positions, for the first stage, hidden_states is the input embeddings,
+ # for the other stages, hidden_states is the output of the previous stage
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=hidden_states.device)
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states,
+ past_key_values_length)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if stage_manager.is_last_stage():
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ next_cache = next_decoder_cache if use_cache else None
+ if stage_manager.is_last_stage():
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+        # always return dict for intermediate stage
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def llama_for_causal_lm_forward(
+ self: LlamaForCausalLM,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ):
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ logger = logging.get_logger(__name__)
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = LlamaPipelineForwards.llama_model_forward(
+ self.model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ )
+ past_key_values = None
+ all_hidden_states = None
+ all_self_attentions = None
+ all_cross_attentions = None
+
+ if stage_manager.is_last_stage():
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def llama_for_sequence_classification_forward(
+ self: LlamaForSequenceClassification,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ transformer_outputs = LlamaPipelineForwards.llama_model_forward(
+ self.model,
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ )
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ batch_size = inputs_embeds.shape[0]
+ else:
+ batch_size = hidden_states.shape[0]
+
+ if stage_manager.is_last_stage():
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ else:
+ hidden_states = transformer_outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+
+def get_llama_flash_attention_forward():
+
+ from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+
+ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
+
+ llama_version = 2
+ try:
+ from transformers.models.llama.modeling_llama import repeat_kv
+    except ImportError:
+        warnings.warn("Detected LLaMA v1: the repeat_kv function is not available.")
+ llama_version = 1
+
+ def forward(
+ self: LlamaAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+ assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ if llama_version == 2:
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
+ query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
+ key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
+ value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
+
+ flash_attention_mask = None
+ attn_mask_type = AttnMaskType.causal
+        if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
+ flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+ attn_mask_type = AttnMaskType.paddedcausal
+
+ attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
+ attn_output = attention(query_states,
+ key_states,
+ value_states,
+ attn_mask=flash_attention_mask,
+ attn_mask_type=attn_mask_type)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+ return forward
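Both `llama_for_causal_lm_forward` above and the OPT counterpart below compute the standard shift-by-one language-modeling loss on the last pipeline stage. A small standalone restatement of that computation, for illustration only (the function name is not part of the codebase):

```python
import torch
from torch.nn import CrossEntropyLoss

def shifted_causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # Position i predicts token i + 1, so drop the last logit and the first label.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    return CrossEntropyLoss()(shift_logits, shift_labels)
```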
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
new file mode 100644
index 000000000000..ad088f3702e5
--- /dev/null
+++ b/colossalai/shardformer/modeling/opt.py
@@ -0,0 +1,666 @@
+import random
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.models.opt.modeling_opt import (
+ OPTForCausalLM,
+ OPTForQuestionAnswering,
+ OPTForSequenceClassification,
+ OPTModel,
+)
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+
+class OPTPipelineForwards:
+ '''
+ This class serves as a micro library for forward function substitution of OPT models
+ under pipeline setting.
+ '''
+
+ @staticmethod
+ def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ from transformers.models.opt.modeling_opt import _make_causal_mask
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ _dtype,
+ device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype,
+ tgt_len=input_shape[-1]).to(device)
+ combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
+ combined_attention_mask)
+
+ return combined_attention_mask
+
+ @staticmethod
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+ @staticmethod
+ def opt_model_forward(
+ self: OPTModel,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ '''
+ This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
+ '''
+
+ from transformers.modeling_outputs import BaseModelOutputWithPast
+ from transformers.utils import logging
+ logger = logging.get_logger(__name__)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ decoder = self.decoder
+ if stage_manager.is_first_stage():
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ batch_size, seq_length = input_shape
+
+ if inputs_embeds is None:
+ inputs_embeds = decoder.embed_tokens(input_ids)
+
+ if decoder.project_in is not None:
+ inputs_embeds = decoder.project_in(inputs_embeds)
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ _dtype = inputs_embeds.dtype
+
+ else:
+ if hidden_states is None:
+                raise ValueError("hidden_states shouldn't be None for intermediate stages.")
+ input_shape = hidden_states.size()[:-1]
+ batch_size, seq_length = input_shape[0], input_shape[1]
+ device = hidden_states.device
+ _dtype = hidden_states.dtype
+
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+ # required mask seq length can be calculated via length of past
+ mask_seq_length = past_key_values_length + seq_length
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
+ elif attention_mask.shape[1] != mask_seq_length:
+ raise ValueError(
+ f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+ f"{mask_seq_length} (sum of the lengths of current and past inputs)")
+
+ causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype,
+ device, past_key_values_length)
+
+ if stage_manager.is_first_stage():
+ pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
+ hidden_states = inputs_embeds + pos_embeds
+
+ if decoder.gradient_checkpointing and decoder.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+ # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if past_key_values:
+ logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+ past_key_values = None
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != (len(decoder.layers)):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}.")
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+
+ torch.cuda.set_device(device)
+
+ for idx in range(start_idx, end_idx):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ decoder_layer = decoder.layers[idx]
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ dropout_probability = random.uniform(0, 1)
+ if decoder.training and (dropout_probability < decoder.layerdrop):
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if decoder.gradient_checkpointing and decoder.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ causal_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if stage_manager.is_last_stage():
+ if decoder.final_layer_norm is not None:
+ hidden_states = decoder.final_layer_norm(hidden_states)
+ if decoder.project_out is not None:
+ hidden_states = decoder.project_out(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ if stage_manager.is_last_stage():
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+ else:
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def opt_for_causal_lm_forward(
+ self: OPTForCausalLM,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward.
+ Please refer to original code of transformers for more details.
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = OPTPipelineForwards.opt_model_forward(
+ self.model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ )
+ if stage_manager.is_last_stage():
+ logits = self.lm_head(outputs[0]).contiguous()
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def opt_for_sequence_classification_forward(
+ self: OPTForSequenceClassification,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward.
+ Please refer to original code of transformers for more details.
+ """
+
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index)
+
+ if stage_manager.is_last_stage():
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ batch_size = input_ids.shape[0] if input_ids is not None else hidden_states.shape[0]
+
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+ else:
+ sequence_lengths = -1
+ logger.warning(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`")
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+ else:
+ hidden_states = transformer_outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+ @staticmethod
+ def opt_for_question_answering_forward(
+ self: OPTForQuestionAnswering,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+ r"""
+ This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward.
+ Please refer to original code of transformers for more details.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index)
+ if stage_manager.is_last_stage():
+ hidden_states = transformer_outputs[0]
+
+ logits = self.qa_outputs(hidden_states)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + transformer_outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+ else:
+ hidden_states = transformer_outputs.get('hidden_states')
+ return {'hidden_states': hidden_states}
+
+
+def get_opt_flash_attention_forward():
+
+ from transformers.models.opt.modeling_opt import OPTAttention
+
+ from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+
+ def forward(
+ self: OPTAttention,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, _ = hidden_states.size()
+
+ attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
+ # get query proj
+ query_states = self.q_proj(hidden_states).view(*attention_input_shape)
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k, v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape)
+ value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self.k_proj(key_value_states).view(*attention_input_shape)
+ value_states = self.v_proj(key_value_states).view(*attention_input_shape)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self.k_proj(hidden_states).view(*attention_input_shape)
+ value_states = self.v_proj(hidden_states).view(*attention_input_shape)
+ key_states = torch.cat([past_key_value[0], key_states], dim=1)
+ value_states = torch.cat([past_key_value[1], value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self.k_proj(hidden_states).view(*attention_input_shape)
+ value_states = self.v_proj(hidden_states).view(*attention_input_shape)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ src_len = key_states.size(1)
+        if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}")
+
+ flash_attention_mask = None
+ attn_mask_type = AttnMaskType.causal
+        if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}")
+ flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+ attn_mask_type = AttnMaskType.paddedcausal
+
+ attention = ColoAttention(embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout,
+ scale=self.scaling)
+ attn_output = attention(query_states,
+ key_states,
+ value_states,
+ attn_mask=flash_attention_mask,
+ attn_mask_type=attn_mask_type)
+
+ attn_output = self.out_proj(attn_output)
+ return attn_output, None, past_key_value
+
+ return forward
+
+
+def get_jit_fused_opt_decoder_layer_forward():
+
+ from transformers.models.opt.modeling_opt import OPTDecoderLayer
+
+ def forward(
+ self: OPTDecoderLayer,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ hidden_states_shape = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
+ residual = hidden_states
+
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
+ if self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+
+ hidden_states = self.fc2(hidden_states)
+
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training).view(hidden_states_shape)
+
+ # 350m applies layer norm AFTER attention
+ if not self.do_layer_norm_before:
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+ return forward
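The fused decoder-layer forward above calls `self.dropout_add(...)`, i.e. it assumes the layer has been given a bound `dropout_add` method such as the one returned by `get_jit_fused_dropout_add_func()` in `modeling/jit.py`. A hypothetical sketch of that wiring follows; the real patching is performed by the shardformer policies, not by this helper.

```python
from types import MethodType

def patch_opt_decoder_layer(layer):
    # Illustrative only: give the layer a bound dropout_add, then swap in the fused forward.
    layer.dropout_add = MethodType(get_jit_fused_dropout_add_func(), layer)
    layer.forward = MethodType(get_jit_fused_opt_decoder_layer_forward(), layer)
    return layer
```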
diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py
new file mode 100644
index 000000000000..c40c02ec411a
--- /dev/null
+++ b/colossalai/shardformer/modeling/sam.py
@@ -0,0 +1,203 @@
+import math
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+
+def forward_fn():
+
+ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
+ batch_size, height, width, _ = hidden_states.shape
+ # qkv with shape (3, batch_size, nHead, height * width, channel)
+ qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads,
+ -1).permute(2, 0, 3, 1, 4))
+ # q, k, v with shape (batch_size * nHead, height * width, channel)
+ query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
+
+ attn_weights = (query * self.scale) @ key.transpose(-2, -1)
+
+ if self.use_rel_pos:
+ attn_weights = self.add_decomposed_rel_pos(attn_weights, query, self.rel_pos_h, self.rel_pos_w,
+ (height, width), (height, width))
+
+ attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
+
+ # replace dropout process with added DropoutForParallelInput layer
+ # origin code:
+ # attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_probs = self.dropout_layer(attn_weights)
+
+ attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
+ attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
+
+ attn_output = self.proj(attn_output)
+
+ if output_attentions:
+ outputs = (attn_output, attn_weights)
+ else:
+ outputs = (attn_output, None)
+
+ return outputs
+
+ return forward
+
+
+def get_sam_flash_attention_forward():
+
+ from transformers.models.sam.modeling_sam import SamAttention
+ try:
+ from xformers.ops import memory_efficient_attention as me_attention
+    except ImportError:
+ raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
+
+ def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
+ batch, point_batch_size, n_tokens, channel = hidden_states.shape
+ c_per_head = channel // num_attention_heads
+ hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
+ return hidden_states
+
+ def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
+ batch, n_tokens, n_heads, c_per_head = hidden_states.shape
+ return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
+
+ def forward(self: SamAttention,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ attention_similarity: Tensor = None) -> Tensor:
+ # Input projections
+ query = self.q_proj(query)
+ key = self.k_proj(key)
+ value = self.v_proj(value)
+
+ point_batch_size = query.shape[1]
+ # Separate into heads
+ query = _separate_heads(query, self.num_attention_heads)
+ key = _separate_heads(key, self.num_attention_heads)
+ value = _separate_heads(value, self.num_attention_heads)
+
+ # SamAttention
+ _, _, _, c_per_head = query.shape
+ bias = None
+ if attention_similarity is not None:
+ bias = attention_similarity
+
+ scale = 1.0 / math.sqrt(c_per_head)
+ out = me_attention(query, key, value, attn_bias=bias, scale=scale)
+
+ out = _recombine_heads(out, point_batch_size)
+ out = self.out_proj(out)
+
+ return out
+
+ return forward
+
+
+def get_sam_vision_flash_attention_forward():
+
+ from transformers.models.sam.modeling_sam import SamVisionAttention
+ try:
+ from xformers.ops import memory_efficient_attention as me_attention
+    except ImportError:
+ raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
+
+ def add_decomposed_rel_pos(
+ query: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: Tuple[int, int],
+ k_size: Tuple[int, int],
+ ) -> torch.Tensor:
+ """
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+
+ Args:
+ query (`torch.Tensor`):
+ query q in the attention layer with shape (batch_size, query_height * query_width, channel).
+ rel_pos_h (`torch.Tensor`):
+ relative position embeddings (Lh, channel) for height axis.
+ rel_pos_w (`torch.Tensor`):
+ relative position embeddings (Lw, channel) for width axis.
+ q_size (tuple):
+ spatial sequence size of query q with (query_height, query_width).
+ k_size (tuple):
+ spatial sequence size of key k with (key_height, key_width).
+
+ Returns:
+            rel_pos (`torch.Tensor`):
+                decomposed relative positional embeddings, used as an attention bias.
+ """
+
+ query_height, query_width = q_size
+ key_height, key_width = k_size
+ relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
+ relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)
+
+ batch_size, _, nHead, dim = query.shape
+ reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
+ rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+ rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
+ return rel_pos
+
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+
+ Args:
+ q_size (int):
+ size of the query.
+ k_size (int):
+ size of key k.
+ rel_pos (`torch.Tensor`):
+ relative position embeddings (L, channel).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+ def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
+ batch_size, height, width, _ = hidden_states.shape
+        # qkv with shape (3, batch_size, height * width, nHead, channel)
+ qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads,
+ -1).permute(2, 0, 1, 3, 4))
+
+ query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)
+
+ rel_pos = None
+ if self.use_rel_pos:
+ rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))
+
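+        # The decomposed relative-position bias (if enabled) is applied inside the fused kernel via
+        # attn_bias; dropout probability and scaling are likewise handled by memory_efficient_attention.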
+ attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale)
+
+ attn_output = attn_output.reshape(batch_size, height, width, -1)
+
+ attn_output = self.proj(attn_output)
+
+ outputs = (attn_output, None)
+
+ return outputs
+
+ return forward
diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
new file mode 100644
index 000000000000..9cc071f91dfc
--- /dev/null
+++ b/colossalai/shardformer/modeling/t5.py
@@ -0,0 +1,786 @@
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from torch.utils.checkpoint import checkpoint
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+
+class T5PipelineForwards:
+ '''
+ This class serves as a micro library for forward function substitution of
+ T5 models under pipeline setting.
+ '''
+
+ @staticmethod
+ def t5_stack_forward(
+ self: T5Stack,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ position_bias: Optional[torch.Tensor] = None,
+ encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+
+ # This function is modified on the basis of transformers.models.t5.modeling_t5.T5Stack.forward.
+ # Please refer to original code of transformers for more details.
+
+ logger = logging.get_logger(__name__)
+
+ # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if past_key_values:
+ logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+ past_key_values = None
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+ use_cache = False
+        if use_cache is True:
+            if not self.is_decoder:
+                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+ stage = stage_manager.stage
+ in_decoder = self.is_decoder
+ if in_decoder != (stage >= decoder_starting_stage):
+ raise ValueError("Config in T5Stack is not aligned with pipeline setting.")
+
+        # at_first_stage: current stage is the first stage of encoder/decoder, taking input_ids/inputs_embeds
+ # at_last_stage: current stage is the last stage of encoder/decoder, making outputs the same form as huggingface
+ at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
+ at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
+
+ # Process inputs if at the first stage of encoder/decoder.
+ if at_first_stage:
+ if input_ids is not None and inputs_embeds is not None:
+ err_msg_prefix = "decoder_" if in_decoder else ""
+ raise ValueError(
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ err_msg_prefix = "decoder_" if in_decoder else ""
+ raise ValueError(
+ f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
+ if inputs_embeds is None:
+ if self.embed_tokens is None:
+ raise ValueError("You have to initialize the model with valid token embeddings")
+ inputs_embeds = self.embed_tokens(input_ids)
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ hidden_states = self.dropout(inputs_embeds)
+ else:
+ if hidden_states is None:
+ raise ValueError(
+ "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
+ input_shape = hidden_states.size()[:-1]
+ batch_size, seq_length = input_shape[0], input_shape[1]
+ device = hidden_states.device
+
+ # required mask seq length can be calculated via length of past
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
+
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
+ if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
+ encoder_seq_length = encoder_hidden_states.shape[1]
+ encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long)
+
+ # initialize past_key_values with `None` if past does not exist
+ if past_key_values is None:
+ past_key_values = [None] * len(self.block)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+ present_key_value_states = () if use_cache else None
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+
+ # Going through held blocks.
+ start_idx, end_idx = stage_index[0], stage_index[1]
+
+ for i in range(start_idx, end_idx):
+
+ past_key_value = past_key_values[i]
+ layer_module = self.block[i]
+ layer_head_mask = head_mask[i]
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
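+            # Make sure the current CUDA device matches the device holding this stage's activations
+            # before launching this block's kernels.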
+ torch.cuda.set_device(hidden_states.device)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ return tuple(module(*inputs, use_cache, output_attentions))
+
+ return custom_forward
+
+ layer_outputs = checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ extended_attention_mask,
+ position_bias,
+ encoder_hidden_states,
+ encoder_extended_attention_mask,
+ encoder_decoder_position_bias,
+ layer_head_mask,
+ cross_attn_layer_head_mask,
+ None, # past_key_value is always None with gradient checkpointing
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ layer_head_mask=layer_head_mask,
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ # layer_outputs is a tuple with:
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+
+ if use_cache is False or use_cache is None:
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+ hidden_states, present_key_value_state = layer_outputs[:2]
+
+ # We share the position biases between the layers - the first layer store them
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ position_bias = layer_outputs[2]
+
+ if in_decoder and encoder_hidden_states is not None:
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
+ # append next layer key value states
+ if use_cache:
+ present_key_value_states = present_key_value_states + (present_key_value_state,)
+
+ # last layer
+ if at_last_stage:
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ if not return_dict:
+ return tuple(v for v in [
+ hidden_states,
+ present_key_value_states,
+ all_hidden_states,
+ all_attentions,
+ all_cross_attentions,
+ ] if v is not None)
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=present_key_value_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+ else:
+ return {
+ 'hidden_states': hidden_states,
+ 'position_bias': position_bias,
+ 'encoder_decoder_position_bias': encoder_decoder_position_bias,
+ 'backward_tensor_keys': ['hidden_states']
+ }
+
+ @staticmethod
+ def t5_model_forward(
+ self: T5Model,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ position_bias: Optional[torch.Tensor] = None,
+ encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+ backward_tensor_keys: Optional[List[str]] = None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
+
+ # This function is modified on the basis of transformers.models.t5.modeling_t5.T5Model.forward.
+ # Please refer to original code of transformers for more details.
+
+ __HEAD_MASK_WARNING_MSG = """
+ The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+ `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+ If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
+ num_heads)`.
+ """
+
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ logger = logging.get_logger(__name__)
+
+ # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if past_key_values:
+ logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+ past_key_values = None
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+ use_cache = False
+
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+ if head_mask is not None and decoder_head_mask is None:
+ if self.config.num_layers == self.config.num_decoder_layers:
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+ decoder_head_mask = head_mask
+
+ in_decoder = stage_manager.stage >= decoder_starting_stage
+ # Stage is in encoder, directly return the output of t5_stack_forward
+ if not in_decoder:
+ encoder_outputs = T5PipelineForwards.t5_stack_forward(
+ self.encoder,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ position_bias=position_bias,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+ if stage_manager.stage == decoder_starting_stage - 1:
+ # last stage of encoder
+ return {'encoder_hidden_states': encoder_outputs[0]}
+ else:
+ return encoder_outputs
+
+ at_last_decoder_stage = stage_manager.is_last_stage()
+ at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
+
+ if encoder_outputs is not None:
+ encoder_hidden_states = encoder_outputs[0]
+ elif encoder_hidden_states is None:
+ raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
+
+ if not at_first_decoder_stage and hidden_states is None:
+ raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
+
+ # Decode
+ decoder_outputs = T5PipelineForwards.t5_stack_forward(
+ self.decoder,
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ past_key_values=past_key_values,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ position_bias=position_bias,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+
+ # Directly return outputs of overloaded T5Stack forward if not at last stage.
+ if not at_last_decoder_stage:
+ # encoder_hidden_states should be passed to the next stage
+ decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
+ return decoder_outputs
+
+ if not return_dict:
+            return decoder_outputs + (encoder_hidden_states,)
+ else:
+ return Seq2SeqModelOutput(last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_hidden_states)
+
+ @staticmethod
+ def t5_for_conditional_generation_forward(
+ self: T5ForConditionalGeneration,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ position_bias: Optional[torch.Tensor] = None,
+ encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+ backward_tensor_keys: Optional[List[str]] = None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+
+ # This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward.
+ # Please refer to original code of transformers for more details.
+
+ __HEAD_MASK_WARNING_MSG = """
+ The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+ `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+ If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
+ num_heads)`.
+ """
+
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ logger = logging.get_logger(__name__)
+
+ # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if past_key_values:
+ logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+ past_key_values = None
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+ use_cache = False
+
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+ if head_mask is not None and decoder_head_mask is None:
+ if self.config.num_layers == self.config.num_decoder_layers:
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+ decoder_head_mask = head_mask
+
+ in_decoder = stage_manager.stage >= decoder_starting_stage
+
+ # Stage is in encoder, directly return the output of t5_stack_forward
+ if not in_decoder:
+ encoder_outputs = T5PipelineForwards.t5_stack_forward(
+ self.encoder,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ position_bias=position_bias,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+ if stage_manager.stage == decoder_starting_stage - 1:
+ # last stage of encoder
+ return {'encoder_hidden_states': encoder_outputs[0]}
+ else:
+ return encoder_outputs
+
+ at_last_decoder_stage = stage_manager.is_last_stage()
+ at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
+
+ if encoder_outputs is not None:
+ encoder_hidden_states = encoder_outputs[0]
+ elif encoder_hidden_states is None:
+ raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
+
+ if not at_first_decoder_stage and hidden_states is None:
+ raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
+
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+ # get decoder inputs from shifting lm labels to the right
+ decoder_input_ids = self._shift_right(labels)
+
+ # Decode
+ decoder_outputs = T5PipelineForwards.t5_stack_forward(
+ self.decoder,
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ past_key_values=past_key_values,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ position_bias=position_bias,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+
+ # Directly return outputs of overloaded T5Stack forward if not at last stage.
+ if not at_last_decoder_stage:
+ # encoder_hidden_states should be passed to the next stage
+ decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
+ return decoder_outputs
+
+ sequence_output = decoder_outputs[0]
+
+ if self.config.tie_word_embeddings:
+ # Rescale output before projecting on vocab
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ sequence_output = sequence_output * (self.model_dim**-0.5)
+
+ lm_logits = self.lm_head(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ # move labels to correct device to enable PP
+ labels = labels.to(lm_logits.device)
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+
+ if not return_dict:
+            output = (lm_logits,) + decoder_outputs[1:] + (encoder_hidden_states,)
+ return ((loss,) + output) if loss is not None else output
+
+ return Seq2SeqLMOutput(loss=loss,
+ logits=lm_logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_hidden_states)
+
+ @staticmethod
+ def t5_encoder_model_forward(
+ self: T5EncoderModel,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ position_bias: Optional[torch.Tensor] = None,
+ encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+ backward_tensor_keys: Optional[List[str]] = None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+ r"""
+ This function is modified on the basis of transformers.models.t5.modeling_gpt2.T5EncoderModel.forward.
+ Please refer to original code of transformers for more details.
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = T5PipelineForwards.t5_stack_forward(self.encoder,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ position_bias=position_bias,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+
+ return outputs
+
+
+def get_t5_flash_attention_forward():
+
+ try:
+ from xformers.ops import memory_efficient_attention as me_attention
+    except ImportError:
+ raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
+ from transformers.models.t5.modeling_t5 import T5Attention
+
+ def forward(
+ self: T5Attention,
+ hidden_states: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ key_value_states: Optional[torch.Tensor] = None,
+ position_bias: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ query_length: Optional[int] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+ """
+ # Input is (batch_size, seq_length, dim)
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+ batch_size, seq_length = hidden_states.shape[:2]
+
+ real_seq_length = seq_length
+
+ if past_key_value is not None:
+ if len(past_key_value) != 2:
+ raise ValueError(
+                    f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
+ )
+ real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
+
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+
+ def shape(states):
+ """projection"""
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
+
+ def unshape(states):
+ """reshape"""
+ return states.view(batch_size, -1, self.inner_dim)
+
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
+ """projects hidden states correctly to key/query states"""
+ if key_value_states is None:
+ # self-attn
+                # (batch_size, seq_length, n_heads, dim_per_head)
+ hidden_states = shape(proj_layer(hidden_states))
+ elif past_key_value is None:
+ # cross-attn
+                # (batch_size, seq_length, n_heads, dim_per_head)
+ hidden_states = shape(proj_layer(key_value_states))
+
+ if past_key_value is not None:
+ if key_value_states is None:
+ # self-attn
+                    # (batch_size, key_length, n_heads, dim_per_head)
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
+ elif past_key_value.shape[1] != key_value_states.shape[1]:
+ # checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ # cross-attn
+                    # (batch_size, seq_length, n_heads, dim_per_head)
+ hidden_states = shape(proj_layer(key_value_states))
+ else:
+ # cross-attn
+ hidden_states = past_key_value
+ return hidden_states
+
+ # get query states
+        query_states = shape(self.q(hidden_states))    # (batch_size, seq_length, n_heads, dim_per_head)
+
+ # get key/value states
+ key_states = project(hidden_states, self.k, key_value_states,
+ past_key_value[0] if past_key_value is not None else None)
+ value_states = project(hidden_states, self.v, key_value_states,
+ past_key_value[1] if past_key_value is not None else None)
+
+ if position_bias is None:
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length),
+ device=query_states.device,
+ dtype=query_states.dtype)
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+ position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device)
+
+ # if key and values are already calculated
+ # we want only the last query position bias
+ if past_key_value is not None:
+ position_bias = position_bias[:, :, -hidden_states.size(1):, :]
+
+ if mask is not None:
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
+
+ if self.pruned_heads:
+ mask = torch.ones(position_bias.shape[1])
+ mask[list(self.pruned_heads)] = 0
+ position_bias_masked = position_bias[:, mask.bool()]
+ else:
+ position_bias_masked = position_bias
+
+ position_bias_masked = position_bias_masked.contiguous()
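+        # position_bias_masked already contains the additive attention mask (if any), so it is handed
+        # to the fused kernel as attn_bias; scale=1.0 matches T5, which does not rescale Q @ K^T.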
+ attn_output = me_attention(query_states,
+ key_states,
+ value_states,
+ attn_bias=position_bias_masked,
+ p=self.dropout,
+ scale=1.0)
+ attn_output = unshape(attn_output)
+ attn_output = self.o(attn_output)
+
+ present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
+
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_T5_layer_ff_forward():
+
+ from transformers.models.t5.modeling_t5 import T5LayerFF
+
+ def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor:
+ forwarded_states = self.layer_norm(hidden_states)
+ forwarded_states = self.DenseReluDense(forwarded_states)
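+        # dropout_add is the jit-fused dropout + residual add injected by shardformer,
+        # equivalent to hidden_states + dropout(forwarded_states).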
+ hidden_states = self.dropout_add(forwarded_states, hidden_states, self.dropout.p, self.dropout.training)
+ return hidden_states
+
+ return forward
+
+
+def get_T5_layer_self_attention_forward():
+
+ from transformers.models.t5.modeling_t5 import T5LayerSelfAttention
+
+ def forward(
+ self: T5LayerSelfAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_bias: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.SelfAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+ return forward
+
+
+def get_T5_layer_cross_attention_forward():
+
+ from transformers.models.t5.modeling_t5 import T5LayerCrossAttention
+
+ def forward(
+ self: T5LayerCrossAttention,
+ hidden_states: torch.Tensor,
+ key_value_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_bias: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ use_cache: bool = False,
+ query_length: Optional[int] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+ normed_hidden_states = self.layer_norm(hidden_states)
+ attention_output = self.EncDecAttention(
+ normed_hidden_states,
+ mask=attention_mask,
+ key_value_states=key_value_states,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ query_length=query_length,
+ output_attentions=output_attentions,
+ )
+ layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
+ return outputs
+
+ return forward
diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
new file mode 100644
index 000000000000..2ce52163ac32
--- /dev/null
+++ b/colossalai/shardformer/modeling/vit.py
@@ -0,0 +1,385 @@
+import math
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+import torch
+from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+
+def _encoder_forward(
+ encoder: ViTEncoder,
+ start_idx: int,
+ end_idx: int,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ stage_manager: PipelineStageManager = None,
+) -> Union[tuple, BaseModelOutput]:
+
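+    # Run only the encoder layers [start_idx, end_idx) held by the current pipeline stage;
+    # intermediate stages return the raw hidden states, while the last stage packages them
+    # as a tuple or BaseModelOutput.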
+ for i in range(start_idx, end_idx):
+ layer_module = encoder.layer[i]
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if encoder.gradient_checkpointing and encoder.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ return module(*inputs, False)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ layer_head_mask,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, layer_head_mask, False)
+
+ hidden_states = layer_outputs[0]
+ if not stage_manager.is_last_stage():
+ return hidden_states
+ else:
+ if not return_dict:
+            return (hidden_states,)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=None,
+ attentions=None,
+ )
+
+
+def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
+
+ from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling
+
+ def pp_forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ logger = logging.get_logger(__name__)
+
+ # Preprocess passed in arguments
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if stage_manager.is_first_stage():
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # TODO(FoolPlayer): maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+ if pixel_values.dtype != expected_dtype:
+ pixel_values = pixel_values.to(expected_dtype)
+
+ embedding_output = self.embeddings(pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ interpolate_pos_encoding=interpolate_pos_encoding)
+ else:
+ assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"
+
+ # Go through encoder
+ if not stage_manager.is_last_stage():
+ hidden_states = _encoder_forward(
+ encoder=self.encoder,
+ start_idx=stage_index[0],
+ end_idx=stage_index[1],
+                hidden_states=embedding_output if stage_manager.is_first_stage() else hidden_states,
+ head_mask=head_mask,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ )
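+            # Intermediate stages only hand the partial hidden states over to the next stage.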
+ return {'hidden_states': hidden_states}
+ else:
+ encoder_outputs = _encoder_forward(
+ encoder=self.encoder,
+ start_idx=stage_index[0],
+ end_idx=stage_index[1],
+ hidden_states=hidden_states,
+ head_mask=head_mask,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ )
+
+ # Go through rest layers
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+ return head_outputs + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ return pp_forward
+
+
+def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
+
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from transformers.models.vit.modeling_vit import ImageClassifierOutput
+
+ def pp_forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ ) -> Union[tuple, ImageClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if not stage_manager.is_first_stage():
+ assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"
+
+ outputs = self.vit(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ hidden_states=hidden_states,
+ )
+
+ # not last stage, return hidden_states
+ if not stage_manager.is_last_stage():
+ return outputs
+ else:
+ sequence_output = outputs[0]
+
+ # last stage
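+            # The classification head consumes the representation of the [CLS] token (position 0).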
+ logits = self.classifier(sequence_output[:, 0, :])
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ return pp_forward
+
+
+def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
+
+ import math
+
+ import torch.nn as nn
+ from transformers.models.vit.modeling_vit import ImageClassifierOutput, MaskedImageModelingOutput
+
+ def pp_forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+    ) -> Union[tuple, MaskedImageModelingOutput]:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+ Returns:
+
+ Examples:
+ ```python
+ >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
+ >>> import torch
+ >>> from PIL import Image
+ >>> import requests
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+ >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")
+
+ >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
+ >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
+ >>> # create random boolean mask of shape (batch_size, num_patches)
+ >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
+
+ >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
+ >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
+ >>> list(reconstructed_pixel_values.shape)
+ [1, 3, 224, 224]
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
+ raise ValueError(
+ "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
+ "the reconstructed image has the same dimensions as the input."
+ f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}.")
+
+ if not stage_manager.is_first_stage():
+ assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"
+
+ outputs = self.vit(pixel_values,
+ bool_masked_pos=bool_masked_pos,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ hidden_states=hidden_states)
+ if not stage_manager.is_last_stage():
+ return outputs
+ else:
+ sequence_output = outputs[0]
+
+ # Reshape to (batch_size, num_channels, height, width)
+ sequence_output = sequence_output[:, 1:]
+ batch_size, sequence_length, num_channels = sequence_output.shape
+ height = width = math.floor(sequence_length**0.5)
+ sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
+
+ # Reconstruct pixel values
+ reconstructed_pixel_values = self.decoder(sequence_output)
+
+ masked_im_loss = None
+ if bool_masked_pos is not None:
+ size = self.config.image_size // self.config.patch_size
+ bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
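+                # Upsample the patch-level mask to pixel resolution so the L1 loss below is computed
+                # only over the masked pixels.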
+ mask = (bool_masked_pos.repeat_interleave(self.config.patch_size,
+ 1).repeat_interleave(self.config.patch_size,
+ 2).unsqueeze(1).contiguous())
+ reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
+ masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
+
+ if not return_dict:
+ output = (reconstructed_pixel_values,) + outputs[1:]
+ return ((masked_im_loss,) + output) if masked_im_loss is not None else output
+
+ return MaskedImageModelingOutput(
+ loss=masked_im_loss,
+ reconstruction=reconstructed_pixel_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ return pp_forward
+
+
+def get_vit_flash_self_attention_forward():
+
+ from transformers.models.vit.modeling_vit import ViTSelfAttention
+
+ from colossalai.kernel.cuda_native import ColoAttention
+
+ def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
+ x = x.view(new_x_shape)
+ return x
+
+ def forward(self: ViTSelfAttention,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size)
+ value_layer = transpose_for_scores(self.value(hidden_states), self.num_attention_heads,
+ self.attention_head_size)
+ query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size)
+
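+        # transpose_for_scores keeps the (batch, seq_len, num_heads, head_size) layout consumed by
+        # ColoAttention here, instead of transposing to (batch, num_heads, seq_len, head_size) as in HF.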
+ scale = 1.0 / math.sqrt(self.attention_head_size)
+ attention = ColoAttention(embed_dim=self.all_head_size,
+ num_heads=self.num_attention_heads,
+ dropout=self.dropout.p,
+ scale=scale)
+ context_layer = attention(query_layer, key_layer, value_layer)
+
+ outputs = (context_layer,)
+
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_vit_output_forward():
+
+ from transformers.models.vit.modeling_vit import ViTOutput
+
+ def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
+ return hidden_states
+
+ return forward
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
new file mode 100644
index 000000000000..62f8f7b4763e
--- /dev/null
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -0,0 +1,962 @@
+import logging
+import random
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+ SequenceClassifierOutput,
+)
+from transformers.models.whisper.modeling_whisper import (
+ WhisperEncoder,
+ WhisperForAudioClassification,
+ WhisperForConditionalGeneration,
+ WhisperModel,
+)
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+
+def get_whisper_flash_attention_forward():
+
+ from transformers.models.whisper.modeling_whisper import WhisperAttention
+
+ from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+
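+    # Unlike HF's WhisperAttention._shape, this helper keeps the (bsz, seq_len, num_heads, head_dim)
+    # layout without transposing, which is what ColoAttention consumes.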
+ def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
+ return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
+
+ def forward(
+ self: WhisperAttention,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (is_cross_attention and past_key_value is not None
+ and past_key_value[0].shape[1] == key_value_states.shape[1]):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
+ value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+ value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+ key_states = torch.cat([past_key_value[0], key_states], dim=1)
+ value_states = torch.cat([past_key_value[1], value_states], dim=1)
+ else:
+ # self_attention
+ key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+ value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ # get query proj
+ query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim)
+
+ src_len = key_states.size(1)
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}")
+
+ attn_type = None
+ flash_attention_mask = None
+
+ if self.is_decoder:
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
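+                # Flip HF's additive decoder mask (0 = keep, large negative = masked) into a boolean
+                # mask taken from its last query row, True at non-padded key positions; causality is
+                # expressed via AttnMaskType.paddedcausal below rather than through the mask itself.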
+ flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
+ attn_type = AttnMaskType.paddedcausal
+
+ attention = ColoAttention(embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ dropout=self.dropout,
+ scale=self.scaling)
+ attn_output = attention(query_states,
+ key_states,
+ value_states,
+ attn_mask=flash_attention_mask,
+ attn_mask_type=attn_type)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+ return forward
+
+
+def get_jit_fused_whisper_encoder_layer_forward():
+
+ from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
+
+ def forward(
+ self: WhisperEncoderLayer,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_head_mask: torch.Tensor,
+ output_attentions: bool = False,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+ if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any()
+ or torch.isnan(hidden_states).any()):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+ return forward
+
+
+def get_jit_fused_whisper_decoder_layer_forward():
+
+ from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer
+
+ def forward(
+ self: WhisperDecoderLayer,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ ) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+ size `(decoder_attention_heads,)`.
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=self_attn_past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
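+        # jit-fused dropout + residual add, equivalent to residual + dropout(hidden_states).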
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+ # Cross-Attention Block
+ cross_attn_present_key_value = None
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=cross_attn_past_key_value,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+ # add cross-attn to positions 3,4 of present_key_value tuple
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+ return forward
+
+
+class WhisperPipelineForwards:
+ '''
+    This class serves as a micro library for forward function substitution of Whisper models
+ under pipeline setting.
+ '''
+
+ @staticmethod
+ def whisper_encoder_forward(
+ self: WhisperEncoder,
+ input_features,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_states=None,
+ all_attentions=None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ):
+ r"""
+ Args:
+ input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
+ Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
+ obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
+ `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
+ `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
+ and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
+            attention_mask (`torch.Tensor`, *optional*):
+                Whisper does not support masking of the `input_features`; this argument is preserved for compatibility,
+                but it is not used. By default, the silence in the input log mel spectrogram is ignored.
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ logger = logging.get_logger(__name__)
+
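+        # encoder layers live on stages [0, decoder_starting_stage - 1]; the last encoder stage is the one
+        # just before the first decoder stage, e.g. with 4 pipeline stages and decoder_starting_stage=2 the
+        # encoder runs on stages 0-1 and the decoder on stages 2-3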
+ stage = stage_manager.stage
+ at_first_stage = (stage == 0)
+ at_last_stage = (stage == decoder_starting_stage - 1)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Process inputs if at the first stage of encoder.
+ if at_first_stage:
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features))
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
+
+ inputs_embeds = inputs_embeds.permute(0, 2, 1)
+ embed_pos = self.embed_positions.weight
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ assert head_mask.size()[0] == (
+ len(self.layers)
+ ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+
+ else:
+ if hidden_states is None:
+ raise ValueError(
+ "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+
+ for idx in range(start_idx, end_idx):
+ encoder_layer = self.layers[idx]
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
+ layer_outputs = (None, None)
+ else:
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ None,
+ (head_mask[idx] if head_mask is not None else None),
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ None,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if at_last_stage:
+ hidden_states = self.layer_norm(hidden_states)
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(last_hidden_state=hidden_states,
+ hidden_states=encoder_states,
+ attentions=all_attentions)
+
+ else:
+ return {'hidden_states': hidden_states, 'head_mask': head_mask}
+
+ @staticmethod
+ def whisper_decoder_forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ):
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
+ on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ logger = logging.get_logger(__name__)
+ stage = stage_manager.stage
+ at_first_stage = (stage == decoder_starting_stage)
+ at_last_stage = (stage == stage_manager.num_stages - 1)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ assert attn_mask.size()[0] == (len(self.layers)), (
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}.")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if at_first_stage:
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # embed positions
+ if input_ids is not None:
+ positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+ else:
+ positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
+
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds,
+ past_key_values_length)
+
+ hidden_states = inputs_embeds + positions
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+ )
+ use_cache = False
+
+ else:
+
+ if hidden_states is None:
+ raise ValueError(
+ "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
+ input_shape = hidden_states.size()[:-1]
+
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, hidden_states,
+ past_key_values_length)
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+
+ for idx in range(start_idx, end_idx):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ decoder_layer = self.layers[idx]
+
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ dropout_probability = random.uniform(0, 1)
+ if self.training and (dropout_probability < self.layerdrop):
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, use_cache)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ None, # encoder attention mask
+ head_mask[idx] if head_mask is not None else None,
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ None, # past_key_value
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(cross_attn_head_mask[idx]
+ if cross_attn_head_mask is not None else None),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = layer_outputs[0]
+
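+            # `layer_outputs` is (hidden_states, [self_attn_weights, cross_attn_weights,] [present_key_value,]):
+            # with use_cache, the cached key/values sit at index 3 when output_attentions is True, otherwise at index 1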
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ if at_last_stage:
+ hidden_states = self.layer_norm(hidden_states)
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None)
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+ else:
+ return {
+ 'head_mask': head_mask,
+ 'cross_attn_head_mask': cross_attn_head_mask,
+ 'hidden_states': hidden_states,
+ }
+
+ @staticmethod
+ def whisper_model_forward(
+ self: WhisperModel,
+ input_features: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+ ```python
+ >>> import torch
+ >>> from transformers import AutoFeatureExtractor, WhisperModel
+ >>> from datasets import load_dataset
+
+ >>> model = WhisperModel.from_pretrained("openai/whisper-base")
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
+ >>> input_features = inputs.input_features
+ >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
+ >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
+ >>> list(last_hidden_state.shape)
+ [1, 2, 512]
+ ```"""
+        logger = logging.get_logger(__name__)
+
+        # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+        if past_key_values:
+ logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
+ past_key_values = None
+ if output_attentions:
+ logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+ use_cache = False
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
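+        # stages before decoder_starting_stage run encoder layers; the remaining stages run the decoder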
+ in_decoder = stage_manager.stage >= decoder_starting_stage
+ if not in_decoder:
+ if encoder_outputs is None:
+ input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
+
+ encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward(
+ self.encoder,
+ input_features,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+
+ if stage_manager.stage == decoder_starting_stage - 1:
+ # last stage of encoder
+ return {'encoder_hidden_states': encoder_outputs[0]}
+ else:
+ return encoder_outputs
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ at_last_decoder_stage = stage_manager.is_last_stage()
+ at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
+ if encoder_outputs is not None:
+ encoder_hidden_states = encoder_outputs[0]
+ elif encoder_hidden_states is None:
+ raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
+
+ if not at_first_decoder_stage and hidden_states is None:
+ raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
+
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward(self.decoder,
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+
+ # Directly return outputs of overloaded Whisper forward if not at last stage.
+ if not at_last_decoder_stage:
+ # encoder_hidden_states should be passed to the next stage
+ decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
+ return decoder_outputs
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_hidden_states,
+ )
+
+ @staticmethod
+ def whisper_for_conditional_generation_forward(
+ self: WhisperForConditionalGeneration,
+ input_features: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
+ or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
+ only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
+ >>> from datasets import load_dataset
+
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
+ >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+ >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
+ >>> input_features = inputs.input_features
+
+ >>> generated_ids = model.generate(inputs=input_features)
+
+ >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ >>> transcription
+ ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id,
+ self.config.decoder_start_token_id)
+ in_decoder = stage_manager.stage >= decoder_starting_stage
+ at_last_decoder_stage = stage_manager.is_last_stage()
+ outputs = WhisperPipelineForwards.whisper_model_forward(self.model,
+ input_features,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+ if not in_decoder:
+ return outputs
+
+ if not at_last_decoder_stage:
+ # encoder_hidden_states should be passed to the next stage
+ outputs['encoder_hidden_states'] = encoder_hidden_states
+ return outputs
+
+ lm_logits = self.proj_out(outputs[0])
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # move labels to correct device to enable PP
+ labels = labels.to(lm_logits.device)
+ loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+ @staticmethod
+ def whisper_for_audio_classification_forward(
+ self: WhisperForAudioClassification,
+ input_features: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_states=None,
+ all_attentions=None,
+ stage_index: Optional[List[int]] = None,
+ decoder_starting_stage: Optional[int] = None,
+ ):
+ r"""
+        This function is adapted from transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward.
+        Please refer to the original transformers implementation for more details.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # audio_classification only holds encoder
+ encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward(
+ self.encoder,
+ input_features,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage,
+ )
+
+ if not stage_manager.is_last_stage():
+ return encoder_outputs
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = torch.stack(encoder_outputs, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = encoder_outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+ pooled_output = hidden_states.mean(dim=1)
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # move labels to correct device to enable PP
+ labels = labels.to(logits.device)
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + encoder_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/auto_policy.py
similarity index 61%
rename from colossalai/shardformer/policies/autopolicy.py
rename to colossalai/shardformer/policies/auto_policy.py
index 085e3150c697..49613ffb37e0 100644
--- a/colossalai/shardformer/policies/autopolicy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -1,9 +1,10 @@
import importlib
from dataclasses import dataclass
+from typing import Optional
import torch.nn as nn
-from .basepolicy import Policy
+from .base_policy import Policy
__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@@ -29,7 +30,7 @@ class PolicyLocation:
"transformers.models.bert.modeling_bert.BertModel":
PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
"transformers.models.bert.modeling_bert.BertForPreTraining":
- PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"),
+ PolicyLocation(file_name="bert", class_name="BertForPreTrainingPolicy"),
"transformers.models.bert.modeling_bert.BertLMHeadModel":
PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"),
"transformers.models.bert.modeling_bert.BertForMaskedLM":
@@ -42,10 +43,12 @@ class PolicyLocation:
PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
"transformers.models.bert.modeling_bert.BertForMultipleChoice":
PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
+ "transformers.models.bert.modeling_bert.BertForQuestionAnswering":
+ PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"),
# LLaMA
"transformers.models.llama.modeling_llama.LlamaModel":
- PolicyLocation(file_name="llama", class_name="LlamaPolicy"),
+ PolicyLocation(file_name="llama", class_name="LlamaModelPolicy"),
"transformers.models.llama.modeling_llama.LlamaForCausalLM":
PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"),
"transformers.models.llama.modeling_llama.LlamaForSequenceClassification":
@@ -66,11 +69,21 @@ class PolicyLocation:
PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel":
PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"),
+ "transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering":
+ PolicyLocation(file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification":
PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),
+ # ViT
+ "transformers.models.vit.modeling_vit.ViTModel":
+ PolicyLocation(file_name="vit", class_name="ViTModelPolicy"),
+ "transformers.models.vit.modeling_vit.ViTForImageClassification":
+ PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"),
+ "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling":
+ PolicyLocation(file_name="vit", class_name="ViTForMaskedImageModelingPolicy"),
+
# OPT
"transformers.models.opt.modeling_opt.OPTModel":
PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
@@ -92,14 +105,54 @@ class PolicyLocation:
PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"),
"transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering":
PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"),
+
+ # Whisper
+ "transformers.models.whisper.modeling_whisper.WhisperModel":
+ PolicyLocation(file_name="whisper", class_name="WhisperModelPolicy"),
+ "transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration":
+ PolicyLocation(file_name="whisper", class_name="WhisperForConditionalGenerationPolicy"),
+ "transformers.models.whisper.modeling_whisper.WhisperForAudioClassification":
+ PolicyLocation(file_name="whisper", class_name="WhisperForAudioClassificationPolicy"),
+
+ # Sam
+ "transformers.models.sam.modeling_sam.SamModel":
+ PolicyLocation(file_name="sam", class_name="SamModelPolicy"),
+
+ # Blip2
+ "transformers.models.blip_2.modeling_blip_2.Blip2Model":
+ PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"),
+ "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration":
+ PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"),
+
+ # ChatGLM
+ "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel":
+ PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"),
+ "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration":
+ PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"),
+}
+
+_INFER_POLICY_LIST = {
+ # LlaMa
+ "transformers.models.llama.modeling_llama.LlamaModel":
+ PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"),
+ "transformers.models.llama.modeling_llama.LlamaForCausalLM":
+ PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"),
+ # Bloom
+ "transformers.models.bloom.modeling_bloom.BloomModel":
+ PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"),
+ "transformers.models.bloom.modeling_bloom.BloomForCausalLM":
+ PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"),
}
-def import_policy(policy_location: PolicyLocation) -> Policy:
+def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
"""
Dynamically import a Policy class based on the policy location.
"""
- module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
+ if inference_only:
+ module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}"
+ else:
+ module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
module = importlib.import_module(module_name)
return getattr(module, policy_location.class_name)
@@ -115,7 +168,7 @@ def _fullname(obj):
return module + '.' + klass.__qualname__
-def get_autopolicy(model: nn.Module) -> Policy:
+def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
r"""
Return the auto policy for the model
@@ -126,12 +179,15 @@ def get_autopolicy(model: nn.Module) -> Policy:
:class:`Policy`: The auto policy for the model
"""
full_name = _fullname(model)
- policy_location = _POLICY_LIST.get(full_name, None)
+ if inference_only:
+ policy_location = _INFER_POLICY_LIST.get(full_name, None)
+ else:
+ policy_location = _POLICY_LIST.get(full_name, None)
if policy_location is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
)
else:
- policy = import_policy(policy_location)
+ policy = import_policy(policy_location, inference_only)
return policy()
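+# Illustrative usage: for a `transformers` LlamaForCausalLM instance, `get_autopolicy(model)` returns a
+# LlamaForCausalLMPolicy(), while `get_autopolicy(model, inference_only=True)` returns a LlamaModelInferPolicy().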
diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/base_policy.py
similarity index 62%
rename from colossalai/shardformer/policies/basepolicy.py
rename to colossalai/shardformer/policies/base_policy.py
index 2d347542fa7a..961c6a5259fe 100644
--- a/colossalai/shardformer/policies/basepolicy.py
+++ b/colossalai/shardformer/policies/base_policy.py
@@ -2,21 +2,21 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Union
+import numpy as np
import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+from ..layer.parallel_module import ParallelModule
from ..shard.shard_config import ShardConfig
__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]
-class ParallelModule():
-
- def __init__(self):
- pass
-
-
@dataclass
class SubModuleReplacementDescription:
r"""
@@ -71,9 +71,8 @@ class Policy(ABC):
"""
def __init__(self) -> None:
- self.shard_config = None
- self.model = None
- self.shard_config = None
+ self.shard_config: Optional[ShardConfig] = None
+ self.model: Optional[Module] = None
def set_model(self, model: nn.Module) -> None:
r"""
@@ -94,6 +93,12 @@ def set_shard_config(self, shard_config: ShardConfig) -> None:
self.shard_config = shard_config
self.config_sanity_check()
+ @property
+ def pipeline_stage_manager(self) -> Optional[PipelineStageManager]:
+ if self.shard_config is not None:
+ return self.shard_config.pipeline_stage_manager
+ return None
+
@abstractmethod
def config_sanity_check(self):
"""
@@ -146,8 +151,78 @@ def append_or_create_submodule_replacement(
# append or create a new description
if target_key in policy:
- policy[target_key].sub_module_replacement.extend(description)
+ if policy[target_key].sub_module_replacement is None:
+ policy[target_key].sub_module_replacement = description
+ else:
+ policy[target_key].sub_module_replacement.extend(description)
else:
policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
return policy
+
+ def append_or_create_method_replacement(
+ self, description: Dict[str, Callable], policy: Dict[Union[str, nn.Module], ModulePolicyDescription],
+ target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+ r"""
+ Append or create a new method replacement description to the policy for the given key.
+
+ Args:
+            description (Dict[str, Callable]): the method replacement description to be appended
+ policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
+ target_key (Union[str, nn.Module]): the key of the policy to be updated
+ """
+ if target_key in policy:
+ if policy[target_key].method_replacement is None:
+ policy[target_key].method_replacement = description
+ else:
+ policy[target_key].method_replacement.update(description)
+ else:
+ policy[target_key] = ModulePolicyDescription(method_replacement=description)
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get layers that should be held in current stage. This method should be implemented by subclass.
+
+ Returns:
+ List[Module]: List of layers that should be hold in current stage
+ """
+ raise NotImplementedError
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """Get parameters that should be shared across stages. This method should be implemented by subclass.
+
+ Returns:
+ List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
+ """
+ return []
+
+ @staticmethod
+ def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
+ """Divide layers into stages
+
+ """
+ quotient = num_layers // num_stages
+ remainder = num_layers % num_stages
+
+ # calculate the num_layers per stage
+ layers_per_stage = [quotient] * num_stages
+
+ # deal with the rest layers
+ if remainder > 0:
+ start_position = num_stages // 2 - remainder // 2
+ for i in range(start_position, start_position + remainder):
+ layers_per_stage[i] += 1
+ return layers_per_stage
+
+ @staticmethod
+ def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
+ """
+ get the start index and end index of layers for each stage.
+ """
+ num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
+
+ start_idx = num_layers_per_stage_accumulated[stage]
+ end_idx = num_layers_per_stage_accumulated[stage + 1]
+
+ return [start_idx, end_idx]
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 9c2736cc64d3..a141b7bd8fdf 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -1,14 +1,27 @@
+from functools import partial
+from typing import Callable, Dict, List
+
import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from ..modeling.bert import (
+ BertPipelineForwards,
+ bert_sequence_parallel_forward_fn,
+ get_bert_flash_attention_forward,
+ get_jit_fused_bert_output_forward,
+ get_jit_fused_bert_self_output_forward,
+)
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
- 'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
+    'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
- 'BertForMultipleChoicePolicy'
+ 'BertForMultipleChoicePolicy', 'BertForQuestionAnsweringPolicy'
]
@@ -23,18 +36,27 @@ def preprocess(self):
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
# TODO:
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ if self.shard_config.enable_tensor_parallelism:
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self):
- from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
+ from transformers.models.bert.modeling_bert import (
+ BertEmbeddings,
+ BertLayer,
+ BertModel,
+ BertOutput,
+ BertSelfAttention,
+ BertSelfOutput,
+ )
policy = {}
-
+ use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+ overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
"attention.self.all_head_size":
@@ -50,14 +72,26 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap
+ },
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap
+ },
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap
+ },
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
@@ -66,6 +100,7 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
+ kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@@ -74,10 +109,15 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap
+ },
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
+ kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="output.dropout",
@@ -96,6 +136,12 @@ def module_policy(self):
)
])
+ if use_sequence_parallel:
+ self.append_or_create_method_replacement(
+ description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)},
+ policy=policy,
+ target_key=BertModel)
+
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle bert layer
@@ -111,7 +157,6 @@ def module_policy(self):
],
policy=policy,
target_key=BertLayer)
-
# handle embedding layer
self.append_or_create_submodule_replacement(
description=[SubModuleReplacementDescription(
@@ -120,6 +165,30 @@ def module_policy(self):
)],
policy=policy,
target_key=BertEmbeddings)
+
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_bert_flash_attention_forward(),
+ },
+ policy=policy,
+ target_key=BertSelfAttention)
+
+ # use jit operator
+ if self.shard_config.enable_jit_fused:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_bert_self_output_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=BertSelfOutput)
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_bert_output_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=BertOutput)
+
return policy
def add_lm_head_policy(self, base_policy):
@@ -143,9 +212,66 @@ def add_lm_head_policy(self, base_policy):
target_key=BertLMPredictionHead)
return base_policy
+ def add_lm_prediction_policy(self, base_policy):
+ from transformers.models.bert.modeling_bert import BertLMPredictionHead
+ method_replacement = {
+ '_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict,
+ '_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict,
+ }
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=base_policy,
+ target_key=BertLMPredictionHead)
+ return base_policy
+
def postprocess(self):
return self.model
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if self.pipeline_stage_manager:
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == "BertModel":
+ module = self.model
+ else:
+ module = self.model.bert
+
+ layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+ stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ method_replacement = {
+ 'forward':
+ partial(new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config)
+ }
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=model_cls)
+
+ return
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == 'BertModel':
+ module = self.model
+ else:
+ module = self.model.bert
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embeddings)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.encoder.layer[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.pooler)
+
+ return held_layers
+
# BertModel
class BertModelPolicy(BertPolicy):
@@ -153,24 +279,61 @@ class BertModelPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
+ def module_policy(self):
+ policy = super().module_policy()
+ from transformers.models.bert.modeling_bert import BertModel
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BertModel,
+ new_forward=BertPipelineForwards.bert_model_forward,
+ policy=policy)
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ held_layers = super().get_held_layers()
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in bert model"""
+ return []
+
# BertForPreTraining
-class BertForPretrainingPolicy(BertPolicy):
+class BertForPreTrainingPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
- module_policy = super().module_policy()
- module_policy = self.add_lm_head_policy(module_policy)
- return module_policy
+ policy = super().module_policy()
+ policy = self.add_lm_head_policy(policy)
+ policy = self.add_lm_prediction_policy(policy)
+ from transformers.models.bert.modeling_bert import BertForPreTraining
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BertForPreTraining,
+ new_forward=BertPipelineForwards.bert_for_pretraining_forward,
+ policy=policy)
+ return policy
- def postprocess(self):
- binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
- for k, v in binding_map.items():
- param = getattr_(self.model, k)
- setattr_(self.model, v, param)
- return self.model
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage"""
+ held_layers = super().get_held_layers()
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.cls)
+
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ model = self.model
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+ if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight):
+ # tie weights
+ return [{
+ 0: model.bert.embeddings.word_embeddings.weight,
+ self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight
+ }]
+ return []
# BertLMHeadModel
@@ -180,16 +343,36 @@ def __init__(self) -> None:
super().__init__()
def module_policy(self):
- module_policy = super().module_policy()
- module_policy = self.add_lm_head_policy(module_policy)
- return module_policy
+ policy = super().module_policy()
+ policy = self.add_lm_head_policy(policy)
+ policy = self.add_lm_prediction_policy(policy)
+ from transformers.models.bert.modeling_bert import BertLMHeadModel
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BertLMHeadModel,
+ new_forward=BertPipelineForwards.bert_lm_head_model_forward,
+ policy=policy)
+ return policy
- def postprocess(self):
- binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
- for k, v in binding_map.items():
- param = getattr_(self.model, k)
- setattr_(self.model, v, param)
- return self.model
+ def get_held_layers(self) -> List[Module]:
+ """
+ get pipeline layers for current stage
+ """
+ held_layers = super().get_held_layers()
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.cls)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ bert_model = self.model.bert
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+ if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
+ # tie weights
+ return [{
+ 0: bert_model.embeddings.word_embeddings.weight,
+ self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
+ }]
+ return []
# BertForMaskedLM
@@ -199,16 +382,36 @@ def __init__(self) -> None:
super().__init__()
def module_policy(self):
- module_policy = super().module_policy()
- module_policy = self.add_lm_head_policy(module_policy)
- return module_policy
+ policy = super().module_policy()
+ policy = self.add_lm_head_policy(policy)
+ policy = self.add_lm_prediction_policy(policy)
+ from transformers.models.bert.modeling_bert import BertForMaskedLM
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BertForMaskedLM,
+ new_forward=BertPipelineForwards.bert_for_masked_lm_forward,
+ policy=policy)
+ return policy
- def postprocess(self):
- binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
- for k, v in binding_map.items():
- param = getattr_(self.model, k)
- setattr_(self.model, v, param)
- return self.model
+ def get_held_layers(self) -> List[Module]:
+ """
+ get pipeline layers for current stage
+ """
+ held_layers = super().get_held_layers()
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.cls)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ bert_model = self.model.bert
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+ if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
+ # tie weights
+ return [{
+ 0: bert_model.embeddings.word_embeddings.weight,
+ self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
+ }]
+ return []
# BertForSequenceClassification
@@ -220,7 +423,7 @@ def __init__(self) -> None:
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForSequenceClassification
- module_policy = super().module_policy()
+ policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
@@ -232,8 +435,28 @@ def module_policy(self):
)
])
}
- module_policy.update(addon_module)
- return module_policy
+ policy.update(addon_module)
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BertForSequenceClassification,
+ new_forward=BertPipelineForwards.bert_for_sequence_classification_forward,
+ policy=policy)
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """
+ get pipeline layers for current stage
+ """
+ held_layers = super().get_held_layers()
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.dropout)
+ held_layers.append(self.model.classifier)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ # no shared params for sequence classification model
+ return []
# BertForTokenClassification
@@ -245,7 +468,7 @@ def __init__(self) -> None:
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForTokenClassification
- module_policy = super().module_policy()
+ policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
@@ -257,8 +480,28 @@ def module_policy(self):
)
])
}
- module_policy.update(addon_module)
- return module_policy
+ policy.update(addon_module)
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BertForTokenClassification,
+ new_forward=BertPipelineForwards.bert_for_token_classification_forward,
+ policy=policy)
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """
+ get pipeline layers for current stage
+ """
+ held_layers = super().get_held_layers()
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.dropout)
+ held_layers.append(self.model.classifier)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        # no shared params for token classification model
+ return []
# BertForNextSentencePrediction
@@ -267,6 +510,30 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
+ def module_policy(self):
+ policy = super().module_policy()
+ from transformers.models.bert.modeling_bert import BertForNextSentencePrediction
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BertForNextSentencePrediction,
+ new_forward=BertPipelineForwards.bert_for_next_sentence_prediction_forward,
+ policy=policy)
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """
+ get pipeline layers for current stage
+ """
+ held_layers = super().get_held_layers()
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.cls)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        # no shared params for next sentence prediction model
+ return []
+
# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):
@@ -277,7 +544,7 @@ def __init__(self) -> None:
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForMultipleChoice
- module_policy = super().module_policy()
+ policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
@@ -289,5 +556,55 @@ def module_policy(self):
)
])
}
- module_policy.update(addon_module)
- return module_policy
+ policy.update(addon_module)
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BertForMultipleChoice,
+ new_forward=BertPipelineForwards.bert_for_multiple_choice_forward,
+ policy=policy)
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """
+ get pipeline layers for current stage
+ """
+ held_layers = super().get_held_layers()
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.dropout)
+ held_layers.append(self.model.classifier)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        # no shared params for multiple choice model
+ return []
+
+
+class BertForQuestionAnsweringPolicy(BertPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ from transformers.models.bert.modeling_bert import BertForQuestionAnswering
+ policy = super().module_policy()
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BertForQuestionAnswering,
+ new_forward=BertPipelineForwards.bert_for_question_answering_forward,
+ policy=policy)
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """
+ get pipeline layers for current stage
+ """
+ held_layers = super().get_held_layers()
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.qa_outputs)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        # no shared params for question answering model
+ return []
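Each task policy above follows the same `get_held_layers` pattern: the base policy keeps this stage's slice of encoder layers, and the subclass appends its task head only on the last stage. A rough sketch of the idea with a toy stage manager (all names below are hypothetical placeholders, not ColossalAI classes):

```python
import torch.nn as nn

class ToyStageManager:
    """Minimal stand-in for a pipeline stage manager: a stage id and a stage count."""
    def __init__(self, stage: int, num_stages: int):
        self.stage, self.num_stages = stage, num_stages
    def is_last_stage(self) -> bool:
        return self.stage == self.num_stages - 1

def held_layers_for_stage(blocks: nn.ModuleList, head: nn.Module, sm: ToyStageManager):
    # Even split of the transformer blocks, task head only on the last stage,
    # mirroring the get_held_layers pattern in the policies above.
    per_stage = len(blocks) // sm.num_stages
    start, end = sm.stage * per_stage, (sm.stage + 1) * per_stage
    held = list(blocks[start:end])
    if sm.is_last_stage():
        held.append(head)
    return held

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])
head = nn.Linear(8, 2)
print(len(held_layers_for_stage(blocks, head, ToyStageManager(3, 4))))   # 3 blocks + head = 4
```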
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
new file mode 100644
index 000000000000..2e5388ab0490
--- /dev/null
+++ b/colossalai/shardformer/policies/blip2.py
@@ -0,0 +1,326 @@
+import torch.nn as nn
+
+import colossalai.shardformer.layer as col_nn
+
+from .._utils import getattr_, setattr_
+from ..modeling.blip2 import (
+ forward_fn,
+ get_blip2_flash_attention_forward,
+ get_jit_fused_blip2_QFormer_output_forward,
+ get_jit_fused_blip2_QFormer_self_output_forward,
+)
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+__all__ = ['BlipPolicy', 'Blip2ModelPolicy', 'Blip2ForConditionalGenerationPolicy']
+
+
+class BlipPolicy(Policy):
+
+ def config_sanity_check(self):
+ pass
+
+ def preprocess(self):
+ # reshape the embedding layer
+ r"""
+ Reshape the Embedding layer to make the embedding dimension divisible by world_size
+ """
+ # TODO:
+ vocab_size = self.model.config.qformer_config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
+ return self.model
+
+ def module_policy(self):
+ from transformers.models.blip_2.modeling_blip_2 import (
+ Blip2Attention,
+ Blip2EncoderLayer,
+ Blip2QFormerLayer,
+ Blip2QFormerModel,
+ Blip2QFormerOutput,
+ Blip2QFormerSelfOutput,
+ Blip2VisionModel,
+ )
+ from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM
+
+ policy = {}
+
+ if self.shard_config.enable_tensor_parallelism:
+ policy[Blip2EncoderLayer] = ModulePolicyDescription(attribute_replacement={
+ "self_attn.num_heads":
+ self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attn.embed_dim":
+ self.model.config.vision_config.hidden_size // self.shard_config.tensor_parallel_size,
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.qkv",
+ target_module=col_nn.FusedLinear1D_Col,
+ kwargs={
+ "n_fused": 3,
+ }),
+ SubModuleReplacementDescription(
+ suffix="self_attn.projection",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.fc1",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.fc2",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ ])
+
+ policy[Blip2QFormerModel] = ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ ])
+
+ policy[Blip2QFormerLayer] = ModulePolicyDescription(attribute_replacement={
+ "attention.attention.num_attention_heads":
+ self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "attention.attention.all_head_size":
+ self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size,
+ "crossattention.attention.num_attention_heads":
+ self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "crossattention.attention.all_head_size":
+ self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size,
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="attention.attention.query",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attention.attention.key",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attention.attention.value",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attention.attention.dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attention.output.dense",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attention.output.dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="crossattention.attention.query",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="crossattention.attention.key",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="crossattention.attention.value",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="crossattention.attention.dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="crossattention.output.dense",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="crossattention.output.dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="intermediate_query.dense",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="output_query.dense",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="output_query.dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ )
+ ])
+
+ policy[OPTDecoderLayer] = ModulePolicyDescription(attribute_replacement={
+ "self_attn.embed_dim":
+ self.model.config.text_config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads":
+ self.model.config.text_config.num_attention_heads // self.shard_config.tensor_parallel_size
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.out_proj",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="fc1",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="fc2",
+ target_module=col_nn.Linear1D_Row,
+ )
+ ])
+
+ policy[OPTForCausalLM] = ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="model.decoder.embed_tokens",
+ target_module=col_nn.VocabParallelEmbedding1D,
+ ),
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=col_nn.Linear1D_Col,
+ kwargs={"gather_output": True},
+ ),
+ ])
+
+ policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
+
+ # optimization configuration
+ if self.shard_config.enable_fused_normalization:
+ # Handle Blip2EncoderLayer layer
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(
+ suffix="layer_norm1",
+ target_module=col_nn.FusedLayerNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="layer_norm2",
+ target_module=col_nn.FusedLayerNorm,
+ )
+ ],
+ policy=policy,
+ target_key=Blip2EncoderLayer)
+
+ # handle Blip2VisionModel layer
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(
+ suffix="post_layernorm",
+ target_module=col_nn.FusedLayerNorm,
+ )
+ ],
+ policy=policy,
+ target_key=Blip2VisionModel)
+
+ # handle Blip2VisionModel layer
+ self.append_or_create_submodule_replacement(
+ description=[SubModuleReplacementDescription(
+ suffix="layernorm",
+ target_module=col_nn.FusedLayerNorm,
+ )],
+ policy=policy,
+ target_key=Blip2QFormerModel)
+
+ # handle Blip2QFormerLayer layer
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(
+ suffix="attention.output.LayerNorm",
+ target_module=col_nn.FusedLayerNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="crossattention.output.LayerNorm",
+ target_module=col_nn.FusedLayerNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="output_query.LayerNorm",
+ target_module=col_nn.FusedLayerNorm,
+ )
+ ],
+ policy=policy,
+ target_key=Blip2QFormerLayer)
+
+ # handle OPTForCausalLM layer
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(
+ suffix="model.decoder.final_layer_norm",
+ target_module=col_nn.FusedLayerNorm,
+ )
+ ],
+ policy=policy,
+ target_key=OPTForCausalLM)
+
+ # handle OPTDecoderLayer layer
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(
+ suffix="self_attn_layer_norm",
+ target_module=col_nn.FusedLayerNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="final_layer_norm",
+ target_module=col_nn.FusedLayerNorm,
+ )
+ ],
+ policy=policy,
+ target_key=OPTDecoderLayer)
+
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_blip2_flash_attention_forward(),
+ },
+ policy=policy,
+ target_key=Blip2Attention)
+
+ # use jit operator
+ if self.shard_config.enable_jit_fused:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=Blip2QFormerSelfOutput)
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_blip2_QFormer_output_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=Blip2QFormerOutput)
+
+ return policy
+
+ def postprocess(self):
+ return self.model
+
+
+# Blip2Model
+class Blip2ModelPolicy(BlipPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+
+# Blip2ForConditionalGeneration
+class Blip2ForConditionalGenerationPolicy(BlipPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
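The `Linear1D_Col` / `Linear1D_Row` replacements used throughout this new policy apply the usual column-then-row tensor-parallel split of back-to-back linears (and `FusedLinear1D_Col` with `n_fused=3` applies the column split to the packed QKV projection). The following sketch only demonstrates the underlying math on plain tensors in one process, with no communication:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 16)                    # (batch, hidden)
w1 = torch.randn(64, 16)                  # first linear, hidden -> 64
w2 = torch.randn(16, 64)                  # second linear, 64 -> hidden

full = x @ w1.t() @ w2.t()                # unsharded reference

world_size = 2
# Column-parallel: split the *output* dimension of the first weight across ranks.
w1_shards = w1.chunk(world_size, dim=0)
# Row-parallel: split the *input* dimension of the second weight across ranks.
w2_shards = w2.chunk(world_size, dim=1)

# Each rank computes a partial result; summing the partials (the all-reduce
# performed by a real row-parallel linear) recovers the unsharded output.
partials = [x @ a.t() @ b.t() for a, b in zip(w1_shards, w2_shards)]
sharded = sum(partials)
print(torch.allclose(full, sharded, atol=1e-5))   # True
```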
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index a0b5340f72bc..7c418d02bcb6 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -1,10 +1,24 @@
+from functools import partial
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
-from ..modeling.bloom import build_bloom_alibi_tensor_fn
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from ..modeling.bloom import (
+ BloomPipelineForwards,
+ build_bloom_alibi_tensor_fn,
+ get_bloom_flash_attention_forward,
+ get_bloom_sequence_parallel_forward_fn,
+ get_jit_fused_bloom_attention_forward,
+ get_jit_fused_bloom_gelu_forward,
+ get_jit_fused_bloom_mlp_forward,
+)
+from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class BloomPolicy(Policy):
@@ -17,18 +31,21 @@ def preprocess(self):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ if self.shard_config.enable_tensor_parallelism:
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self):
- from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
+ from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomGelu, BloomMLP, BloomModel
policy = {}
+ use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+ overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -39,11 +56,14 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
- ),
+ kwargs={
+ 'seq_parallel': use_sequence_parallel,
+ 'overlap': overlap
+ }),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
- ),
+ kwargs={'seq_parallel': use_sequence_parallel}),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
@@ -51,11 +71,14 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
- ),
+ kwargs={
+ 'seq_parallel': use_sequence_parallel,
+ 'overlap': overlap
+ }),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
- ),
+ kwargs={'seq_parallel': use_sequence_parallel}),
])
policy[BloomModel] = ModulePolicyDescription(
@@ -102,14 +125,117 @@ def module_policy(self):
policy=policy,
target_key=BloomBlock)
+ if use_sequence_parallel:
+ self.append_or_create_method_replacement(
+ description={'forward': get_bloom_sequence_parallel_forward_fn(self.shard_config)},
+ policy=policy,
+ target_key=BloomModel)
+
+ if self.shard_config.enable_flash_attention:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_bloom_flash_attention_forward(),
+ 'dropout_add': get_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=BloomAttention)
+
+ # enable jit fused operator
+ if self.shard_config.enable_jit_fused:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_bloom_attention_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=BloomAttention)
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_bloom_mlp_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=BloomMLP)
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_bloom_gelu_forward(),
+ 'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
+ },
+ policy=policy,
+ target_key=BloomGelu)
+
return policy
def postprocess(self):
return self.model
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if self.pipeline_stage_manager:
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == "BloomModel":
+ module = self.model
+ else:
+ module = self.model.transformer
+
+ layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
+ stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ method_replacement = {
+ 'forward':
+ partial(new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config)
+ }
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=model_cls)
+ return
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == 'BloomModel':
+ module = self.model
+ else:
+ module = self.model.transformer
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.word_embeddings)
+ held_layers.append(module.word_embeddings_layernorm)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.h[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.ln_f)
+
+ return held_layers
+
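`get_held_layers` above relies on `distribute_layers` / `get_stage_index` to carve the block list into per-stage slices. A simplified sketch of such a split follows; the balancing rule is an assumption, and the real `Policy.distribute_layers` may assign the remainder differently:

```python
def distribute_layers(num_layers: int, num_stages: int):
    """Roughly even split; here earlier stages absorb the remainder."""
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if i < rem else 0) for i in range(num_stages)]

def get_stage_index(layers_per_stage, stage):
    """Return the (start, end) slice of blocks owned by this stage."""
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

layers_per_stage = distribute_layers(30, 4)                      # [8, 8, 7, 7]
print(layers_per_stage, get_stage_index(layers_per_stage, 2))    # (16, 23)
```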
class BloomModelPolicy(BloomPolicy):
- pass
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ from transformers.models.bloom.modeling_bloom import BloomModel
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BloomModel,
+ new_forward=BloomPipelineForwards.bloom_model_forward,
+ policy=policy)
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """
+ get pipeline layers for current stage
+ """
+ held_layers = super().get_held_layers()
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in bloom model."""
+ return []
class BloomForCausalLMPolicy(BloomPolicy):
@@ -124,21 +250,30 @@ def module_policy(self):
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=BloomForCausalLM)
-
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BloomForCausalLM,
+ new_forward=BloomPipelineForwards.bloom_for_causal_lm_forward,
+ policy=policy)
return policy
- def postprocess(self):
- binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
-
- for k, v in binding_map.items():
- param = getattr_(self.model, k)
-
- if not isinstance(param, nn.Parameter):
- param = nn.Parameter(param)
-
- # tie weights
- setattr_(self.model, v, param)
- return self.model
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.lm_head)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ bloom_model = self.model
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+ if id(bloom_model.transformer.word_embeddings.weight) == id(bloom_model.lm_head.weight):
+ # tie weights
+ return [{
+ 0: bloom_model.transformer.word_embeddings.weight,
+ self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight
+ }]
+ return []
class BloomForSequenceClassificationPolicy(BloomPolicy):
@@ -153,9 +288,24 @@ def module_policy(self):
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=BloomForSequenceClassification)
-
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BloomForSequenceClassification,
+ new_forward=BloomPipelineForwards.bloom_for_sequence_classification_forward,
+ policy=policy)
return policy
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.score)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in bloom for sequence classification model"""
+ return []
+
class BloomForTokenClassificationPolicy(BloomPolicy):
@@ -176,10 +326,46 @@ def module_policy(self):
],
policy=policy,
target_key=BloomForTokenClassification)
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BloomForTokenClassification,
+ new_forward=BloomPipelineForwards.bloom_for_token_classification_forward,
+ policy=policy)
return policy
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.dropout)
+ held_layers.append(self.model.classifier)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in bloom for token classification model"""
+ return []
+
class BloomForQuestionAnsweringPolicy(BloomPolicy):
# No head sharding as the output features is only 2
- pass
+ def module_policy(self):
+ from transformers.models.bloom.modeling_bloom import BloomForQuestionAnswering
+ policy = super().module_policy()
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=BloomForQuestionAnswering,
+ new_forward=BloomPipelineForwards.bloom_for_question_answering_forward,
+ policy=policy)
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ held_layers = super().get_held_layers()
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.qa_outputs)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in bloom for question answering model"""
+ return []
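The `preprocess` step near the top of this file pads the vocabulary so the embedding can be split evenly across tensor-parallel ranks. The padding arithmetic, shown standalone (the helper name is made up for illustration):

```python
def padded_vocab_size(vocab_size: int, world_size: int) -> int:
    """Pad vocab_size up to the next multiple of world_size so the embedding
    rows can be sharded evenly across tensor-parallel ranks."""
    if vocab_size % world_size != 0:
        vocab_size = vocab_size + world_size - vocab_size % world_size
    return vocab_size

print(padded_vocab_size(250880, 8))   # 250880, already divisible
print(padded_vocab_size(50257, 8))    # 50264
```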
diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
new file mode 100644
index 000000000000..44898847056a
--- /dev/null
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -0,0 +1,262 @@
+from functools import partial
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch.nn as nn
+from torch import Tensor
+from transformers.modeling_outputs import BaseModelOutputWithPast
+
+import colossalai.shardformer.layer as col_nn
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
+from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
+ ChatGLMForConditionalGeneration,
+ ChatGLMModel,
+ GLMBlock,
+)
+
+from ..modeling.chatglm2 import (
+ get_chatglm_sequence_parallel_forward_fn,
+ get_flash_core_attention_forward,
+ get_jit_fused_glm_block_forward,
+)
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+__all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
+
+
+class ChatGLMPolicy(Policy):
+
+ def config_sanity_check(self):
+ pass
+
+ def preprocess(self):
+ # Resize embedding
+ if self.shard_config.enable_tensor_parallelism:
+ vocab_size = self.model.config.padded_vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
+
+ if self.pipeline_stage_manager is not None:
+ # the batch_size_dim is bounded to Model
+ bsz_dim = 1
+ setattr(self.model, 'batch_size_dim', bsz_dim)
+
+ return self.model
+
+ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+
+ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
+
+ policy = {}
+
+ use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+ overlap = self.shard_config.enable_sequence_overlap
+ if self.shard_config.enable_tensor_parallelism:
+ policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={},
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="embedding.word_embeddings",
+ target_module=col_nn.VocabParallelEmbedding1D,
+ )
+ ])
+
+ policy[GLMBlock] = ModulePolicyDescription(
+ attribute_replacement={
+ "self_attention.num_attention_heads_per_partition":
+ self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attention.projection_size":
+ (self.model.config.kv_channels * self.model.config.num_attention_heads) //
+ self.shard_config.tensor_parallel_size,
+ "self_attention.qkv_hidden_size":
+ (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
+ self.shard_config.tensor_parallel_size,
+ "self_attention.core_attention.num_attention_heads_per_partition":
+ self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attention.core_attention.hidden_size_per_partition":
+ self.model.config.kv_channels * self.model.config.num_attention_heads //
+ self.shard_config.tensor_parallel_size,
+ },
+ param_replacement=[],
+ sub_module_replacement=[
+ SubModuleReplacementDescription(suffix="self_attention.query_key_value",
+ target_module=col_nn.Linear1D_Col,
+ kwargs={
+ 'seq_parallel': use_sequence_parallel,
+ 'seq_parallel_dim': 0,
+ 'overlap': overlap
+ }),
+ SubModuleReplacementDescription(suffix="self_attention.dense",
+ target_module=col_nn.Linear1D_Row,
+ kwargs={
+ 'seq_parallel': use_sequence_parallel,
+ 'seq_parallel_dim': 0
+ }),
+ SubModuleReplacementDescription(
+ suffix="self_attention.core_attention.attention_dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ ])
+
+ # optimization configuration
+ if self.shard_config.enable_fused_normalization:
+ if not self.model.config.rmsnorm:
+
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),
+ SubModuleReplacementDescription(suffix="post_attention_layernorm",
+ target_module=col_nn.FusedLayerNorm)
+ ],
+ policy=policy,
+ target_key=GLMBlock)
+
+ if self.model.config.post_layer_norm:
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(suffix="encoder.final_layernorm",
+ target_module=col_nn.FusedLayerNorm)
+ ],
+ policy=policy,
+ target_key=ChatGLMModel)
+
+ else:
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
+ SubModuleReplacementDescription(suffix="post_attention_layernorm",
+ target_module=col_nn.FusedRMSNorm)
+ ],
+ policy=policy,
+ target_key=GLMBlock)
+
+ if self.model.config.post_layer_norm:
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(suffix="encoder.final_layernorm",
+ target_module=col_nn.FusedRMSNorm)
+ ],
+ policy=policy,
+ target_key=ChatGLMModel)
+
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_flash_core_attention_forward(),
+ },
+ policy=policy,
+ target_key=CoreAttention)
+
+ # use sequence parallel
+ if use_sequence_parallel:
+ self.append_or_create_method_replacement(
+ description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
+ policy=policy,
+ target_key=ChatGLMModel)
+
+ # use jit fused operator
+ if self.shard_config.enable_jit_fused:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_glm_block_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=GLMBlock)
+
+ return policy
+
+ def postprocess(self):
+ return self.model
+
+ def get_held_layers(self) -> List[nn.Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == 'ChatGLMModel':
+ module = self.model
+ else:
+ module = self.model.transformer
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embedding)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.encoder.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ if module.encoder.post_layer_norm:
+ held_layers.append(module.encoder.final_layernorm)
+
+ # rotary_pos_emb is needed for all stages
+ held_layers.append(module.rotary_pos_emb)
+
+ return held_layers
+
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if not self.pipeline_stage_manager:
+ raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == 'ChatGLMModel':
+ module = self.model
+ else:
+ module = self.model.transformer
+
+ layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages)
+ stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ method_replacement = {
+ 'forward':
+ partial(new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config)
+ }
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
+
+
+class ChatGLMModelPolicy(ChatGLMPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=ChatGLMModel,
+ new_forward=ChatGLMPipelineForwards.chatglm_model_forward,
+ policy=policy)
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ return super().get_held_layers()
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in ChatGLMModel."""
+ return []
+
+
+class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
+
+ def module_policy(self):
+ policy = super().module_policy()
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=ChatGLMForConditionalGeneration,
+ new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
+ policy=policy)
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.transformer.output_layer)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in ChatGLMForConditionalGenerationModel."""
+ return []
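`set_pipeline_forward` in this policy bakes the stage information into the HuggingFace forward with `functools.partial` and then registers it as a method replacement. A minimal sketch of that rebinding trick on a toy class (the `pipeline_forward` body is a placeholder, not the real staged forward):

```python
import types
from functools import partial

class TinyModel:
    def forward(self, hidden_states):
        return hidden_states + 1

def pipeline_forward(self, hidden_states, *, stage_manager=None, stage_index=(0, 0)):
    # Hypothetical per-stage forward: only this stage's slice of layers runs.
    start, end = stage_index
    return hidden_states + (end - start)

model = TinyModel()
# Bake the stage information into the forward, then rebind it on the instance,
# mimicking append_or_create_method_replacement + partial in the policy above.
bound = partial(pipeline_forward, stage_manager=None, stage_index=(4, 8))
model.forward = types.MethodType(bound, model)
print(model.forward(10))   # 14
```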
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 549cdbf87a80..5093fd469af8 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -1,9 +1,13 @@
-import torch.nn as nn
+from functools import partial
+from typing import Callable, Dict, List
+
+from torch import Tensor, nn
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
@@ -21,66 +25,80 @@ def preprocess(self):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ if self.shard_config.enable_tensor_parallelism:
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self):
- from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
policy = {}
-
+ use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+ overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
+ SubModuleReplacementDescription(
+ suffix="drop",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
])
- policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
- "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
- },
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="attn.c_attn",
- target_module=col_nn.GPT2FusedLinearConv1D_Col,
- kwargs={
- "n_fused": 3,
- },
- ),
- SubModuleReplacementDescription(
- suffix="attn.c_proj",
- target_module=col_nn.GPT2FusedLinearConv1D_Row,
- ),
- SubModuleReplacementDescription(
- suffix="mlp.c_fc",
- target_module=col_nn.GPT2FusedLinearConv1D_Col,
- kwargs={
- "n_fused": 1,
- },
- ),
- SubModuleReplacementDescription(
- suffix="mlp.c_proj",
- target_module=col_nn.GPT2FusedLinearConv1D_Row,
- ),
- SubModuleReplacementDescription(
- suffix="attn.attn_dropout",
- target_module=col_nn.DropoutForParallelInput,
- ),
- SubModuleReplacementDescription(
- suffix="attn.resid_dropout",
- target_module=col_nn.DropoutForParallelInput,
- ),
- SubModuleReplacementDescription(
- suffix="mlp.dropout",
- target_module=col_nn.DropoutForParallelInput,
- ),
- ])
+
+ policy[GPT2Block] = ModulePolicyDescription(
+ attribute_replacement={
+ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="attn.c_attn",
+ target_module=col_nn.GPT2FusedLinearConv1D_Col,
+ kwargs={
+ "n_fused": 3,
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap
+ },
+ ),
+ SubModuleReplacementDescription(suffix="attn.c_proj",
+ target_module=col_nn.GPT2FusedLinearConv1D_Row,
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ }),
+ SubModuleReplacementDescription(
+ suffix="mlp.c_fc",
+ target_module=col_nn.GPT2FusedLinearConv1D_Col,
+ kwargs={
+ "n_fused": 1,
+ "seq_parallel": use_sequence_parallel,
+ "overlap": overlap
+ },
+ ),
+ SubModuleReplacementDescription(suffix="mlp.c_proj",
+ target_module=col_nn.GPT2FusedLinearConv1D_Row,
+ kwargs={
+ "seq_parallel": use_sequence_parallel,
+ }),
+ SubModuleReplacementDescription(
+ suffix="attn.attn_dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attn.resid_dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ ])
# optimization configuration
if self.shard_config.enable_fused_normalization:
@@ -106,11 +124,66 @@ def module_policy(self):
],
policy=policy,
target_key=GPT2Block)
+
+ if self.shard_config.enable_flash_attention:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_gpt2_flash_attention_forward(),
+ },
+ policy=policy,
+ target_key=GPT2Attention)
+
+ if self.shard_config.enable_sequence_parallelism:
+ policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
+
return policy
def postprocess(self):
return self.model
+ def get_held_layers(self) -> List[nn.Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == 'GPT2Model':
+ module = self.model
+ else:
+ module = self.model.transformer
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.wte)
+ held_layers.append(module.wpe)
+ held_layers.append(module.drop)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.h[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.ln_f)
+ return held_layers
+
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if not self.pipeline_stage_manager:
+ raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == 'GPT2Model':
+ module = self.model
+ else:
+ module = self.model.transformer
+
+ layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
+ stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ method_replacement = {
+ 'forward':
+ partial(new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=self.shard_config)
+ }
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
+
# GPT2Model
class GPT2ModelPolicy(GPT2Policy):
@@ -118,6 +191,24 @@ class GPT2ModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
+ def module_policy(self):
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Model
+
+ policy = super().module_policy()
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=GPT2Model,
+ new_forward=GPT2PipelineForwards.gpt2_model_forward,
+ policy=policy)
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ return super().get_held_layers()
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in GPT2Model."""
+ return []
+
# GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
@@ -139,17 +230,31 @@ def module_policy(self):
])
}
module_policy.update(addon_module)
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=GPT2LMHeadModel,
+ new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
+ policy=module_policy)
return module_policy
- def postprocess(self):
- binding_map = {"transformer.wte.weight": "lm_head.weight"}
- for k, v in binding_map.items():
- param = getattr_(self.model, k)
- setattr_(self.model, v, param)
- return self.model
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.lm_head)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ '''The weights of wte and lm_head are shared.'''
+ module = self.model
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager is not None:
+ if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
+ first_stage, last_stage = 0, stage_manager.num_stages - 1
+ return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+ return []
-# GPT22DoubleHeadsModel
+# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
def __init__(self) -> None:
@@ -169,14 +274,64 @@ def module_policy(self):
])
}
module_policy.update(addon_module)
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel,
+ new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward,
+ policy=module_policy)
+
return module_policy
- def postprocess(self):
- binding_map = {"transformer.wte.weight": "lm_head.weight"}
- for k, v in binding_map.items():
- param = getattr_(self.model, k)
- setattr_(self.model, v, param)
- return self.model
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ multiple_choice_head = self.model.multiple_choice_head
+ held_layers.append(self.model.lm_head)
+ held_layers.append(multiple_choice_head.summary)
+ held_layers.append(multiple_choice_head.activation)
+ held_layers.append(multiple_choice_head.first_dropout)
+ held_layers.append(multiple_choice_head.last_dropout)
+
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ '''The weights of wte and lm_head are shared.'''
+ module = self.model
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager is not None:
+ if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
+ first_stage, last_stage = 0, stage_manager.num_stages - 1
+ return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+ return []
+
+
+# GPT2ForQuestionAnswering
+class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering
+
+ module_policy = super().module_policy()
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering,
+ new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward,
+ policy=module_policy)
+
+ return module_policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.qa_outputs)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ '''No shared_params in gpt2 for QA.'''
+ return []
# GPT2ForTokenClassification
@@ -185,9 +340,61 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
+ def module_policy(self):
+ from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification
+
+ module_policy = super().module_policy()
+
+ if self.shard_config.enable_tensor_parallelism:
+ addon_module = {
+ GPT2ForTokenClassification:
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput)
+ ])
+ }
+ module_policy.update(addon_module)
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=GPT2ForTokenClassification,
+ new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward,
+ policy=module_policy)
+ return module_policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.dropout)
+ held_layers.append(self.model.classifier)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in GPT2ForTokenClassification."""
+ return []
+
# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
+
+ def module_policy(self):
+ from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification
+
+ module_policy = super().module_policy()
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification,
+ new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward,
+ policy=module_policy)
+ return module_policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.score)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in GPT2ForTokenClassification."""
+ return []
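The `seq_parallel` / `overlap` kwargs threaded into the GPT-2 linear replacements enable sequence parallelism, which shards the sequence dimension across tensor-parallel ranks for the token-wise parts of the block. A toy, single-process illustration of why that sharding is lossless for per-token operations such as layer norm (no actual communication here):

```python
import torch

torch.manual_seed(0)
world_size = 4
x = torch.randn(2, 16, 32)                 # (batch, seq_len, hidden)

# Sequence parallelism shards the *sequence* dimension across ranks, so
# token-wise work only touches seq_len / world_size tokens per rank.
chunks = x.chunk(world_size, dim=1)
per_rank = [torch.nn.functional.layer_norm(c, (32,)) for c in chunks]

# Gathering the chunks back (an all-gather in a real run) matches the
# unsharded computation, since layer norm is independent per token.
gathered = torch.cat(per_rank, dim=1)
print(torch.allclose(gathered, torch.nn.functional.layer_norm(x, (32,))))   # True
```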
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 157785bdcf13..cc131e8168fc 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -1,10 +1,15 @@
-from typing import Dict, Union
+import warnings
+from functools import partial
+from typing import Callable, Dict, List, Union
import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
@@ -15,29 +20,38 @@ def config_sanity_check(self):
pass
def preprocess(self):
- # Resize embedding
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
+ if self.shard_config.enable_tensor_parallelism:
+ # Resize embedding
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
policy = {}
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
+
if self.shard_config.enable_tensor_parallelism:
+ decoder_attribute_replacement = {
+ "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ }
+ if getattr(self.model.config, "num_key_value_heads", False):
+ decoder_attribute_replacement["self_attn.num_key_value_heads"] = \
+ self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
+
policy[LlamaDecoderLayer] = ModulePolicyDescription(
- attribute_replacement={
- "self_attn.hidden_size":
- self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- "self_attn.num_heads":
- self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
- },
+ attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
@@ -99,11 +113,83 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy=policy,
target_key=LlamaModel)
+ if self.shard_config.enable_flash_attention:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_llama_flash_attention_forward(),
+ },
+ policy=policy,
+ target_key=LlamaAttention)
+
return policy
def postprocess(self):
return self.model
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if self.pipeline_stage_manager:
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == "LlamaModel":
+ module = self.model
+ else:
+ module = self.model.model
+
+ layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
+ stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=model_cls)
+
+ return
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == 'LlamaModel':
+ module = self.model
+ else:
+ module = self.model.model
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embed_tokens)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.norm)
+
+ return held_layers
+
+
+class LlamaModelPolicy(LlamaPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ from transformers.models.llama.modeling_llama import LlamaModel
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(model_cls=LlamaModel,
+ new_forward=LlamaPipelineForwards.llama_model_forward,
+ policy=policy)
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ held_layers = super().get_held_layers()
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in llama model"""
+ return []
+
class LlamaForCausalLMPolicy(LlamaPolicy):
@@ -122,8 +208,35 @@ def module_policy(self):
])
}
policy.update(new_item)
+
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(model_cls=LlamaForCausalLM,
+ new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward,
+ policy=policy)
+
return policy
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.lm_head)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ llama_model = self.model.model
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+ if id(llama_model.embed_tokens.weight) == id(
+ self.model.lm_head.weight) and self.pipeline_stage_manager.num_stages > 1:
+ # tie weights
+ return [{
+ 0: llama_model.embed_tokens.weight,
+ self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight
+ }]
+ return []
+
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
@@ -142,4 +255,22 @@ def module_policy(self):
])
}
policy.update(new_item)
+ # to be confirmed
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(model_cls=LlamaForSequenceClassification,
+ new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward,
+ policy=policy)
return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.score)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in llama for sequence classification model"""
+ return []
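The new `num_key_value_heads` attribute replacement covers Llama variants with grouped-query attention: under tensor parallelism both the query-head count and the KV-head count are divided by the parallel size, so both must be divisible by it. A small sketch of that bookkeeping (the 64/8 head counts are illustrative of a Llama-2-70B-style config, and the helper name is made up):

```python
def shard_attention_heads(num_heads: int, num_key_value_heads: int, tp_size: int):
    """Divide query heads and grouped key/value heads across tensor-parallel
    ranks; both counts must be divisible by tp_size, which is why the policy
    only rewrites num_key_value_heads when the config actually defines it."""
    assert num_heads % tp_size == 0 and num_key_value_heads % tp_size == 0
    return num_heads // tp_size, num_key_value_heads // tp_size

# Illustrative GQA config: 64 query heads, 8 KV heads, tensor parallel size 4.
print(shard_attention_heads(64, 8, tp_size=4))   # (16, 2)
```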
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index b87db53f45f1..abe491bfaace 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -1,7 +1,16 @@
+import warnings
+from functools import partial
+from typing import Callable, Dict, List
+
+import torch.nn as nn
+from torch import Tensor, nn
+
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
-from .._utils import getattr_, setattr_
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from .._utils import getattr_
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy',
@@ -19,17 +28,21 @@ def preprocess(self):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ if self.shard_config.enable_tensor_parallelism:
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
policy = {}
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
@@ -89,17 +102,91 @@ def module_policy(self):
policy=policy,
target_key=OPTDecoderLayer)
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_opt_flash_attention_forward(),
+ },
+ policy=policy,
+ target_key=OPTAttention)
+
+ # use jit fused operator
+ if self.shard_config.enable_jit_fused:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_opt_decoder_layer_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=OPTDecoderLayer)
+
return policy
def postprocess(self):
return self.model
+ def get_held_layers(self) -> List[nn.Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == 'OPTModel':
+ module = self.model.decoder
+ else:
+ module = self.model.model.decoder
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embed_tokens)
+ held_layers.append(module.embed_positions)
+ held_layers.append(module.project_in)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.final_layer_norm)
+ held_layers.append(module.project_out)
+ return held_layers
+
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if self.pipeline_stage_manager:
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == 'OPTModel':
+ module = self.model.decoder
+ else:
+ module = self.model.model.decoder
+
+ layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
+ stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=model_cls)
+
class OPTModelPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
+ def module_policy(self):
+ from transformers.models.opt.modeling_opt import OPTModel
+
+ policy = super().module_policy()
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=OPTModel,
+ new_forward=OPTPipelineForwards.opt_model_forward,
+ policy=policy)
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ return super().get_held_layers()
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in OPTModel."""
+ return []
+
class OPTForCausalLMPolicy(OPTPolicy):
@@ -107,23 +194,42 @@ def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForCausalLM
policy = super().module_policy()
-
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy,
target_key=OPTForCausalLM)
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=OPTForCausalLM,
+ new_forward=OPTPipelineForwards.opt_for_causal_lm_forward,
+ policy=policy)
+
return policy
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.lm_head)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ opt_model = self.model
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+ num_stages = self.pipeline_stage_manager.num_stages
+ if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
+ return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}]
+ return []
+
def postprocess(self):
- binding_map = {
- 'model.decoder.embed_tokens': 'lm_head',
- }
+ if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
+ binding_map = {
+ 'model.decoder.embed_tokens': 'lm_head',
+ }
- for k, v in binding_map.items():
- src_mod = getattr_(self.model, k)
- dst_mod = getattr_(self.model, v)
- dst_mod.weight = src_mod.weight
+ for k, v in binding_map.items():
+ src_mod = getattr_(self.model, k)
+ dst_mod = getattr_(self.model, v)
+ dst_mod.weight = src_mod.weight
return self.model
@@ -133,8 +239,50 @@ class OPTForSequenceClassificationPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
+ def module_policy(self):
+ from transformers.models.opt.modeling_opt import OPTForSequenceClassification
+
+ policy = super().module_policy()
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=OPTForSequenceClassification,
+ new_forward=OPTPipelineForwards.opt_for_sequence_classification_forward,
+ policy=policy)
+
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.score)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ "no shared params in OPTForSequenceClassification"
+ return []
+
class OPTForQuestionAnsweringPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
+
+ def module_policy(self):
+ from transformers.models.opt.modeling_opt import OPTForQuestionAnswering
+
+ policy = super().module_policy()
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(model_cls=OPTForQuestionAnswering,
+ new_forward=OPTPipelineForwards.opt_for_question_answering_forward,
+ policy=policy)
+
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.qa_outputs)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ "no shared params in OPTForSequenceClassification"
+ return []
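
The `get_held_layers` implementations above lean on `Policy.distribute_layers` and `Policy.get_stage_index`, which live in `base_policy.py` and are not part of this diff. The sketch below is an assumed, simplified re-implementation that splits layers as evenly as possible, just to illustrate the slicing the OPT policy performs on `module.layers`:

```python
# Assumed behaviour of Policy.distribute_layers / Policy.get_stage_index
# (simplified sketch; the real helpers live in base_policy.py and may differ in detail).
from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    """Split num_layers across num_stages as evenly as possible."""
    quotient, remainder = divmod(num_layers, num_stages)
    # hand the remainder out one layer at a time to the first stages
    return [quotient + (1 if stage < remainder else 0) for stage in range(num_stages)]

def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    """Return the [start, end) slice of layers held by `stage`."""
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

layers_per_stage = distribute_layers(24, 4)     # e.g. a 24-layer OPT decoder on 4 stages
print(layers_per_stage)                          # [6, 6, 6, 6]
print(get_stage_index(layers_per_stage, 2))      # (12, 18) -> module.layers[12:18]
```
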
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
new file mode 100644
index 000000000000..9753d5a737b9
--- /dev/null
+++ b/colossalai/shardformer/policies/sam.py
@@ -0,0 +1,223 @@
+import torch.nn as nn
+
+import colossalai.shardformer.layer as col_nn
+
+from .._utils import getattr_, setattr_
+from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+__all__ = ['SamPolicy', 'SamModelPolicy']
+
+
+class SamPolicy(Policy):
+
+ def config_sanity_check(self):
+ pass
+
+ def preprocess(self):
+ return self.model
+
+ def module_policy(self):
+ from transformers.models.sam.modeling_sam import (
+ SamAttention,
+ SamFeedForward,
+ SamTwoWayAttentionBlock,
+ SamTwoWayTransformer,
+ SamVisionAttention,
+ SamVisionLayer,
+ )
+
+ policy = {}
+
+ if self.shard_config.enable_tensor_parallelism:
+ policy[SamVisionLayer] = ModulePolicyDescription(attribute_replacement={
+ "attn.num_attention_heads":
+ self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="attn.qkv",
+ target_module=col_nn.FusedLinear1D_Col,
+ kwargs={
+ "n_fused": 3,
+ },
+ ),
+ SubModuleReplacementDescription(
+ suffix="attn.proj",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.lin1",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.lin2",
+ target_module=col_nn.Linear1D_Row,
+ )
+ ])
+ policy[SamTwoWayAttentionBlock] = ModulePolicyDescription(
+ attribute_replacement={
+ "self_attn.num_attention_heads":
+ self.model.config.mask_decoder_config.num_attention_heads //
+ self.shard_config.tensor_parallel_size,
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.out_proj",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="cross_attn_token_to_image.q_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="cross_attn_token_to_image.k_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="cross_attn_token_to_image.v_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="cross_attn_token_to_image.out_proj",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.lin1",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.lin2",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="cross_attn_image_to_token.q_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="cross_attn_image_to_token.k_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="cross_attn_image_to_token.v_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="cross_attn_image_to_token.out_proj",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ ])
+ policy[SamTwoWayTransformer] = ModulePolicyDescription(attribute_replacement={
+ "final_attn_token_to_image.num_attention_heads":
+ self.model.config.mask_decoder_config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="final_attn_token_to_image.q_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="final_attn_token_to_image.k_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="final_attn_token_to_image.v_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="final_attn_token_to_image.out_proj",
+ target_module=col_nn.Linear1D_Row,
+ )
+ ])
+
+        # add `DropoutForParallelInput` layer to replace the usage of `nn.functional.dropout`
+ policy[SamVisionAttention] = ModulePolicyDescription(attribute_replacement={
+ "dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout)
+ },
+ method_replacement={"forward": forward_fn()},
+ sub_module_replacement=[])
+
+ # optimization configuration
+ if self.shard_config.enable_fused_normalization:
+ # Handle SamVisionLayer
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(
+ suffix="layer_norm1",
+ target_module=col_nn.FusedLayerNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="layer_norm2",
+ target_module=col_nn.FusedLayerNorm,
+ )
+ ],
+ policy=policy,
+ target_key=SamVisionLayer)
+
+ # Handle SamTwoWayAttentionBlock
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(
+ suffix="layer_norm1",
+ target_module=col_nn.FusedLayerNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="layer_norm2",
+ target_module=col_nn.FusedLayerNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="layer_norm3",
+ target_module=col_nn.FusedLayerNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="layer_norm4",
+ target_module=col_nn.FusedLayerNorm,
+ )
+ ],
+ policy=policy,
+ target_key=SamTwoWayAttentionBlock)
+
+ # Handle SamTwoWayTransformer
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(
+ suffix="layer_norm_final_attn",
+ target_module=col_nn.FusedLayerNorm,
+ )
+ ],
+ policy=policy,
+ target_key=SamTwoWayTransformer)
+
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_sam_flash_attention_forward(),
+ },
+ policy=policy,
+ target_key=SamAttention)
+ self.append_or_create_method_replacement(description={
+ 'forward': get_sam_vision_flash_attention_forward(),
+ },
+ policy=policy,
+ target_key=SamVisionAttention)
+
+ return policy
+
+ def postprocess(self):
+ return self.model
+
+
+# SamModel
+class SamModelPolicy(SamPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
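
The attribute replacements in the SAM policy divide the attention head count by `tensor_parallel_size` because, once the q/k/v projections are column-sharded and the output projection is row-sharded, each rank only computes its own slice of the heads. A rough, standalone illustration of the resulting per-rank sizes (shapes only, not library code; it assumes the head count divides evenly):

```python
# Shape-only illustration (assumption) of the head split performed under tensor parallelism:
# column-parallel q/k/v projections followed by a row-parallel output projection.
hidden_size = 768
num_heads = 12
tp_size = 4

assert num_heads % tp_size == 0, "heads must divide evenly across tensor-parallel ranks"

heads_per_rank = num_heads // tp_size        # matches the attribute_replacement above (12 // 4 = 3)
qkv_cols_per_rank = hidden_size // tp_size   # Linear1D_Col shards the output dimension
proj_rows_per_rank = hidden_size // tp_size  # Linear1D_Row shards the input dimension, then all-reduces

print(heads_per_rank, qkv_cols_per_rank, proj_rows_per_rank)   # 3 192 192
```
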
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index cde59ab77042..92cbd3f72b83 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -1,3 +1,10 @@
+import warnings
+from functools import partial
+from typing import Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+from torch import Tensor, nn
+
from colossalai.shardformer.layer import (
DropoutForParallelInput,
Embedding1D,
@@ -6,12 +13,20 @@
Linear1D_Row,
VocabParallelEmbedding1D,
)
-from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
from .._utils import getattr_, setattr_
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.t5 import (
+ T5PipelineForwards,
+ get_jit_fused_T5_layer_ff_forward,
+ get_t5_flash_attention_forward,
+ get_T5_layer_cross_attention_forward,
+ get_T5_layer_self_attention_forward,
+)
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
-__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
+__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
class T5BasePolicy(Policy):
@@ -24,11 +39,12 @@ def preprocess(self):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
+ if self.shard_config.enable_tensor_parallelism:
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self):
@@ -44,6 +60,10 @@ def module_policy(self):
policy = {}
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
if self.shard_config.enable_tensor_parallelism:
policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
@@ -52,7 +72,7 @@ def module_policy(self):
),
SubModuleReplacementDescription(
suffix="embed_tokens",
- target_module=Embedding1D,
+ target_module=VocabParallelEmbedding1D,
)
])
policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[
@@ -106,7 +126,7 @@ def module_policy(self):
])
policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
- suffix="wi_0",
+                    suffix="wi_0",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
@@ -161,35 +181,195 @@ def module_policy(self):
suffix="final_layer_norm", target_module=FusedRMSNorm),
policy=policy,
target_key=T5Stack)
+
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_t5_flash_attention_forward(),
+ },
+ policy=policy,
+ target_key=T5Attention)
+
+ # use jit operator
+ if self.shard_config.enable_jit_fused:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_T5_layer_ff_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=T5LayerFF)
+ self.append_or_create_method_replacement(description={
+ 'forward': get_T5_layer_self_attention_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=T5LayerSelfAttention)
+ self.append_or_create_method_replacement(description={
+ 'forward': get_T5_layer_cross_attention_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=T5LayerCrossAttention)
+
return policy
def postprocess(self):
- binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
-
- for k, v in binding_map:
- mod = getattr_(self.model, k)
- setattr_(self.model, v, mod)
return self.model
+ @staticmethod
+ def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int,
+ num_stages: int) -> Tuple[List[int], int]:
+ """
+ Distribute t5 layers into stages when pipeline parallel is used.
+ Return the layer distribution as a list and the starting stage of decoder.
+        If the decoder doesn't exist, the returned decoder starting stage is set to num_stages.
+ """
+
+ # number of encoder layers must be a positive integer
+ if num_encoder_layers <= 0:
+ raise ValueError("The number of encoder layers for T5 must be a positive integer.")
+
+ # number of layers should be large enough to fill in every stage
+ if num_encoder_layers + num_decoder_layers < num_stages:
+ raise ValueError("The total number of layers can't be smaller than number of stages.")
+
+ # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
+ if num_decoder_layers == 0:
+ return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
+
+        # the number of stages distributed between encoder and decoder is optimized in this way:
+ # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
+ # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1
+ def objective(num_encoder_stages):
+ return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
+
+ num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
+ num_decoder_stages = num_stages - num_encoder_stages
+
+ encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
+ decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
+ return encoder_distribution + decoder_distribution, num_encoder_stages
+
+ @staticmethod
+    def get_t5_stage_index(layers_per_stage: List[int], stage: int,
+                           decoder_starting_stage: int) -> Tuple[int, int]:
+ """
+ Input the distribution of layers among stages, the current stage and the first stage of decoder.
+ Return the starting/ending idx of layers in encoder/decoder
+ """
+ if stage < decoder_starting_stage:
+ return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+ else:
+ return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+
+ def get_held_layers(self) -> List[nn.Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+ stage_manager = self.pipeline_stage_manager
+
+ model = self.model
+ encoder = self.model.encoder
+ decoder = getattr(self.model, 'decoder', None)
+
+ num_encoder_layers = len(encoder.block)
+ num_decoder_layers = len(decoder.block) if decoder else 0
+
+ held_layers = []
+ layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+ num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
+ start_idx, end_idx = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage,
+ decoder_starting_stage)
+
+ if stage_manager.stage < decoder_starting_stage:
+ # current stage is in t5's encoder
+ if stage_manager.is_first_stage():
+ held_layers.append(model.shared)
+ held_layers.append(encoder.embed_tokens)
+ held_layers.append(encoder.dropout)
+ if stage_manager.stage == decoder_starting_stage - 1:
+ held_layers.append(encoder.final_layer_norm)
+ held_layers.append(encoder.dropout)
+ held_layers.extend(encoder.block[start_idx:end_idx])
+ else:
+ # current stage is in t5's decoder
+ if stage_manager.stage == decoder_starting_stage:
+ held_layers.append(decoder.embed_tokens)
+ held_layers.append(decoder.dropout)
+ if stage_manager.is_last_stage():
+ held_layers.append(decoder.final_layer_norm)
+ held_layers.append(decoder.dropout)
+ held_layers.extend(decoder.block[start_idx:end_idx])
+ return held_layers
+
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if not self.pipeline_stage_manager:
+ raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
+ stage_manager = self.pipeline_stage_manager
+
+ encoder = self.model.encoder
+ decoder = getattr(self.model, 'decoder', None)
+
+ num_encoder_layers = len(encoder.block)
+ num_decoder_layers = len(decoder.block) if decoder else 0
+
+ layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+ num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
+ stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
+
+ method_replacement = {
+ 'forward':
+ partial(new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+ }
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
+
class T5ModelPolicy(T5BasePolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
def module_policy(self):
from transformers import T5Model
- base_policy = super().module_policy()
+ policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
- policy=base_policy,
+ policy=policy,
target_key=T5Model)
- return base_policy
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy)
+
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ return super().get_held_layers()
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ module = self.model
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager is not None and stage_manager.num_stages > 1:
+ _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
+ len(module.decoder.block),
+ stage_manager.num_stages)
+
+ if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
+ return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
+ return []
class T5ForConditionalGenerationPolicy(T5BasePolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
def module_policy(self):
from transformers import T5ForConditionalGeneration
@@ -207,43 +387,71 @@ def module_policy(self):
],
policy=policy,
target_key=T5ForConditionalGeneration)
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=T5ForConditionalGeneration,
+ new_forward=T5PipelineForwards.t5_for_conditional_generation_forward,
+ policy=policy)
return policy
- def postprocess(self):
- super().postprocess()
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.lm_head)
+ return held_layers
- binding_map = {"shared": "lm_head"}
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ module = self.model
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager is not None and stage_manager.num_stages > 1:
+ _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
+ len(module.decoder.block),
+ stage_manager.num_stages)
- for k, v in binding_map.items():
- src_mod = getattr_(self.model, k)
- dst_mod = getattr_(self.model, v)
- dst_mod.weight = src_mod.weight
+ shared_params = []
+ shared_embedding = {}
+ if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
+ shared_embedding[0] = module.shared.weight
+ shared_embedding[decoder_starting_stage] = module.decoder.embed_tokens.weight
- return self.model
+ if id(module.lm_head.weight) == id(module.shared.weight):
+ shared_embedding[0] = module.shared.weight
+ shared_embedding[stage_manager.num_stages - 1] = module.lm_head.weight
+
+ if len(shared_embedding) > 0:
+ shared_params.append(shared_embedding)
+
+ return shared_params
+
+ return []
class T5EncoderPolicy(T5BasePolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
def module_policy(self):
from transformers import T5EncoderModel
- base_policy = super().module_policy()
+ policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
- policy=base_policy,
+ policy=policy,
target_key=T5EncoderModel)
- return base_policy
- def postprocess(self):
- binding_map = [
- ["shared", "encoder.embed_tokens"],
- ]
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=T5EncoderModel,
+ new_forward=T5PipelineForwards.t5_encoder_model_forward,
+ policy=policy)
+ return policy
- for k, v in binding_map:
- mod = getattr_(self.model, k)
- setattr_(self.model, v, mod)
- return self.model
+ def get_held_layers(self) -> List[nn.Module]:
+ return super().get_held_layers()
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ return []
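
To make `distribute_t5_layers` concrete: with 12 encoder blocks, 12 decoder blocks and 4 pipeline stages, the objective is minimized at two encoder stages, giving the distribution `[6, 6, 6, 6]` with the decoder starting at stage 2. The standalone sketch below reproduces that calculation; `distribute_layers` here is an assumed even split standing in for `Policy.distribute_layers`:

```python
# Standalone reproduction of the encoder/decoder stage split used by distribute_t5_layers.
# distribute_layers below is an assumed even split, standing in for Policy.distribute_layers.
import numpy as np

def distribute_layers(num_layers, num_stages):
    q, r = divmod(num_layers, num_stages)
    return [q + (1 if s < r else 0) for s in range(num_stages)]

num_encoder_layers, num_decoder_layers, num_stages = 12, 12, 4

def objective(num_encoder_stages):
    return abs(num_encoder_layers / num_encoder_stages
               - num_decoder_layers / (num_stages - num_encoder_stages))

num_encoder_stages = int(np.argmin([objective(i) for i in range(1, num_stages)])) + 1
num_decoder_stages = num_stages - num_encoder_stages

layers_per_stage = (distribute_layers(num_encoder_layers, num_encoder_stages) +
                    distribute_layers(num_decoder_layers, num_decoder_stages))
print(num_encoder_stages, layers_per_stage)   # 2 [6, 6, 6, 6]  (decoder starts at stage 2)
```
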
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index eaebe2eee0ba..b4fb8692e684 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -1,12 +1,22 @@
-from typing import Dict, Union
+import warnings
+from typing import Callable, Dict, List, Union
import torch.nn as nn
-from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
+import colossalai.shardformer.layer as col_nn
+from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col
-from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.vit import (
+ ViTForImageClassification_pipeline_forward,
+ ViTForMaskedImageModeling_pipeline_forward,
+ ViTModel_pipeline_forward,
+ get_jit_fused_vit_output_forward,
+ get_vit_flash_self_attention_forward,
+)
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
-__all__ = ['ViTPolicy']
+__all__ = ['ViTPolicy', 'ViTModelPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy']
class ViTPolicy(Policy):
@@ -15,96 +25,224 @@ def config_sanity_check(self):
pass
def preprocess(self):
- # Resize embedding
- vocab_size = self.model.config.vocab_size
- world_size = self.shard_config.tensor_parallel_size
-
- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)
-
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
- from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
-
- base_policy = {
- ViTEmbeddings:
- ModulePolicyDescription(sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=DropoutForReplicatedInput,
- )
- ]),
- ViTLayer:
- ModulePolicyDescription(attribute_replacement={
- "attention.attention.num_attention_heads":
- self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
- "attention.attention.all_head_size":
- self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- },
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="attention.attention.query",
- target_module=Linear1D_Col,
- ),
- SubModuleReplacementDescription(
- suffix="attention.attention.key",
- target_module=Linear1D_Col,
- ),
- SubModuleReplacementDescription(
- suffix="attention.attention.value",
- target_module=Linear1D_Col,
- ),
- SubModuleReplacementDescription(
- suffix="attention.attention.dropout",
- target_module=DropoutForParallelInput,
- ),
- SubModuleReplacementDescription(
- suffix="attention.output.dense",
- target_module=Linear1D_Row,
- ),
- SubModuleReplacementDescription(
- suffix="attention.output.dropout",
- target_module=DropoutForParallelInput,
- ),
- SubModuleReplacementDescription(
- suffix="intermediate.dense",
- target_module=Linear1D_Col,
- ),
- SubModuleReplacementDescription(
- suffix="output.dense",
- target_module=Linear1D_Row,
- ),
- SubModuleReplacementDescription(
- suffix="output.dropout",
- target_module=DropoutForParallelInput,
- ),
- ]),
- }
-
- # optimization configuration
- if self.shard_config.enable_fused_normalization:
- base_policy[ViTAttention].sub_module_replacement.extend([
- SubModuleReplacementDescription(
- suffix="layernorm_before",
- target_module=FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="layernorm_after",
- target_module=FusedLayerNorm,
- )
- ])
- base_policy[ViTModel].sub_module_replacement.append(
- SubModuleReplacementDescription(
- suffix="layernorm",
- target_module=FusedLayerNorm,
- ))
-
- return base_policy
+
+ from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel, ViTOutput, ViTSelfAttention
+
+ policy = {}
+
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
+ if self.shard_config.enable_tensor_parallelism:
+ policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
+ param_replacement=[],
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=DropoutForReplicatedInput,
+ )
+ ])
+
+ policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={
+ "attention.attention.num_attention_heads":
+ self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "attention.attention.all_head_size":
+ self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ },
+ param_replacement=[],
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="attention.attention.query",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attention.attention.key",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attention.attention.value",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attention.attention.dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attention.output.dense",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="attention.output.dropout",
+ target_module=col_nn.DropoutForReplicatedInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="intermediate.dense",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="output.dense",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="output.dropout",
+ target_module=col_nn.DropoutForReplicatedInput,
+ ),
+ ])
+
+ # use flash attention
+ if self.shard_config.enable_flash_attention:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_vit_flash_self_attention_forward(),
+ },
+ policy=policy,
+ target_key=ViTSelfAttention)
+
+ # use jit fused operator
+ if self.shard_config.enable_jit_fused:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_vit_output_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=ViTOutput)
+ return policy
def new_model_class(self):
return None
def postprocess(self):
return self.model
+
+ def get_held_layers(self) -> List[nn.Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+ if self.model.__class__.__name__ == 'ViTModel':
+ module = self.model
+ else:
+ module = self.model.vit
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embeddings)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.encoder.layer[start_idx:end_idx])
+ return held_layers
+
+ def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict):
+ if self.pipeline_stage_manager:
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == 'ViTModel':
+ module = self.model
+ else:
+ module = self.model.vit
+
+ layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+ stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ method_replacement = {'forward': pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=model_cls)
+
+
+# ViTModel
+class ViTModelPolicy(ViTPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ from transformers.models.vit.modeling_vit import ViTModel
+
+ policy = super().module_policy()
+
+ if self.shard_config.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+ module = self.model
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage():
+ held_layers.append(module.layernorm)
+ held_layers.append(module.pooler)
+
+ return held_layers
+
+
+# ViTForImageClassification
+class ViTForImageClassificationPolicy(ViTPolicy):
+
+ def module_policy(self):
+ from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel
+
+ policy = super().module_policy()
+ if self.shard_config.enable_tensor_parallelism:
+ new_item = {
+ ViTForImageClassification:
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
+ ])
+ }
+ policy.update(new_item)
+
+ if self.shard_config.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
+ self.set_pipeline_forward(model_cls=ViTForImageClassification,
+ pipeline_forward=ViTForImageClassification_pipeline_forward,
+ policy=policy)
+
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+ module = self.model.vit
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage():
+ held_layers.append(module.layernorm)
+ held_layers.append(self.model.classifier)
+
+ return held_layers
+
+
+# ViTForMaskedImageModeling
+class ViTForMaskedImageModelingPolicy(ViTPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel
+
+ policy = super().module_policy()
+
+ if self.shard_config.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy)
+ self.set_pipeline_forward(model_cls=ViTForMaskedImageModeling,
+ pipeline_forward=ViTForMaskedImageModeling_pipeline_forward,
+ policy=policy)
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+
+ module = self.model.vit
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager.is_last_stage():
+ held_layers.append(module.layernorm)
+ held_layers.append(self.model.decoder)
+
+ return held_layers
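
Note that the ViT policies bind the stage information by calling `pipeline_forward(stage_manager=..., stage_index=...)` directly, whereas the OPT and T5 policies use `functools.partial`; this suggests the `ViTModel_pipeline_forward` helpers are factories that return a stage-specific forward. A minimal sketch of that factory pattern (the names and signatures here are assumptions for illustration only):

```python
# Assumed shape of a pipeline-forward factory such as ViTModel_pipeline_forward:
# it captures the stage information and returns the forward actually installed on the model.
from typing import Callable, Tuple

def make_pipeline_forward(stage_manager, stage_index: Tuple[int, int]) -> Callable:
    start_idx, end_idx = stage_index

    def forward(self, hidden_states):
        # only the first stage runs the patch embeddings
        if stage_manager.is_first_stage():
            hidden_states = self.embeddings(hidden_states)
        # every stage runs just the encoder blocks it holds
        for layer in self.encoder.layer[start_idx:end_idx]:
            hidden_states = layer(hidden_states)
        return hidden_states

    return forward

# The policy would then install it roughly as:
# method_replacement = {'forward': make_pipeline_forward(stage_manager, stage_index=(0, 6))}
```
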
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
new file mode 100644
index 000000000000..5d496f08e1db
--- /dev/null
+++ b/colossalai/shardformer/policies/whisper.py
@@ -0,0 +1,495 @@
+import warnings
+from functools import partial
+from typing import Callable, Dict, List, Tuple
+
+import numpy as np
+import torch.nn as nn
+from torch import Tensor
+
+import colossalai.shardformer.layer as col_nn
+
+from .._utils import getattr_, setattr_
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.whisper import (
+ WhisperPipelineForwards,
+ get_jit_fused_whisper_decoder_layer_forward,
+ get_jit_fused_whisper_encoder_layer_forward,
+ get_whisper_flash_attention_forward,
+)
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+__all__ = [
+ 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy',
+ 'WhisperForAudioClassificationPolicy'
+]
+
+
+class WhisperPolicy(Policy):
+
+ def config_sanity_check(self):
+ pass
+
+ def preprocess(self):
+ # reshape the embedding layer
+ r"""
+ Reshape the Embedding layer to make the embedding dimension divisible by world_size
+ """
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
+ return self.model
+
+ def module_policy(self):
+ from transformers.models.whisper.modeling_whisper import (
+ WhisperAttention,
+ WhisperDecoder,
+ WhisperDecoderLayer,
+ WhisperEncoder,
+ WhisperEncoderLayer,
+ )
+
+ policy = {}
+
+ if self.shard_config.enable_sequence_parallelism:
+ self.shard_config.enable_sequence_parallelism = False
+ warnings.warn(
+ "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
+ if self.shard_config.enable_tensor_parallelism:
+ policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
+ "self_attn.embed_dim":
+ self.model.config.d_model // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads":
+ self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size,
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.out_proj",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="fc1",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="fc2",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ ])
+
+ policy[WhisperDecoderLayer] = ModulePolicyDescription(attribute_replacement={
+ "self_attn.embed_dim":
+ self.model.config.d_model // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads":
+ self.model.config.decoder_attention_heads // self.shard_config.tensor_parallel_size,
+ "encoder_attn.embed_dim":
+ self.model.config.d_model // self.shard_config.tensor_parallel_size,
+ "encoder_attn.num_heads":
+ self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size,
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.out_proj",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="encoder_attn.q_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="encoder_attn.k_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="encoder_attn.v_proj",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="encoder_attn.out_proj",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="fc1",
+ target_module=col_nn.Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="fc2",
+ target_module=col_nn.Linear1D_Row,
+ ),
+ ])
+
+ policy[WhisperDecoder] = ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=col_nn.VocabParallelEmbedding1D,
+ ),
+ ])
+
+ # optimization configuration
+ if self.shard_config.enable_fused_normalization:
+ # Handle encoder layer
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(
+ suffix="self_attn_layer_norm",
+ target_module=col_nn.FusedLayerNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="final_layer_norm",
+ target_module=col_nn.FusedLayerNorm,
+ )
+ ],
+ policy=policy,
+ target_key=WhisperEncoderLayer)
+
+ # Handle decoder layer
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(
+ suffix="self_attn_layer_norm",
+ target_module=col_nn.FusedLayerNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="final_layer_norm",
+ target_module=col_nn.FusedLayerNorm,
+ )
+ ],
+ policy=policy,
+ target_key=WhisperDecoderLayer)
+
+ # handle encoder layer
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(
+ suffix="layer_norm",
+ target_module=col_nn.FusedLayerNorm,
+ )
+ ],
+ policy=policy,
+ target_key=WhisperEncoder)
+
+ # handle decoder layer
+ self.append_or_create_submodule_replacement(description=[
+ SubModuleReplacementDescription(
+ suffix="layer_norm",
+ target_module=col_nn.FusedLayerNorm,
+ )
+ ],
+ policy=policy,
+ target_key=WhisperDecoder)
+
+ # enable flash attention
+ if self.shard_config.enable_flash_attention:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_whisper_flash_attention_forward(),
+ },
+ policy=policy,
+ target_key=WhisperAttention)
+
+ # use jit fused operator
+ if self.shard_config.enable_jit_fused:
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_whisper_decoder_layer_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=WhisperDecoderLayer)
+ self.append_or_create_method_replacement(description={
+ 'forward': get_jit_fused_whisper_encoder_layer_forward(),
+ 'dropout_add': get_jit_fused_dropout_add_func(),
+ },
+ policy=policy,
+ target_key=WhisperEncoderLayer)
+
+ return policy
+
+ def add_lm_head_policy(self, base_policy):
+ from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration
+
+ # optimize for tensor parallelism
+ if self.shard_config.enable_tensor_parallelism:
+ self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+ suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
+ policy=base_policy,
+ target_key=WhisperForConditionalGeneration)
+
+ return base_policy
+
+ def postprocess(self):
+ return self.model
+
+ @staticmethod
+ def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int,
+ num_stages: int) -> Tuple[List[int], int]:
+ """
+ Distribute whisper layers into stages when pipeline parallel is used.
+ Return the layer distribution as a list and the starting stage of decoder.
+        If the decoder doesn't exist, the returned decoder starting stage is set to num_stages.
+ """
+
+ # number of encoder layers must be a positive integer
+ if num_encoder_layers <= 0:
+ raise ValueError("The number of encoder layers for whisper must be a positive integer.")
+
+ # number of layers should be large enough to fill in every stage
+ if num_encoder_layers + num_decoder_layers < num_stages:
+ raise ValueError("The total number of layers can't be smaller than number of stages.")
+
+ # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist
+ if num_decoder_layers == 0:
+ return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
+
+        # the number of stages distributed between encoder and decoder is optimized in this way:
+ # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
+ # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1
+ def objective(num_encoder_stages):
+ return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
+
+ num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
+ num_decoder_stages = num_stages - num_encoder_stages
+
+ encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
+ decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
+ return encoder_distribution + decoder_distribution, num_encoder_stages
+
+ @staticmethod
+    def get_whisper_stage_index(layers_per_stage: List[int], stage: int,
+                                decoder_starting_stage: int) -> Tuple[int, int]:
+ """
+ Input the distribution of layers among stages, the current stage and the first stage of decoder.
+ Return the starting/ending idx of layers in encoder/decoder
+ """
+ if stage < decoder_starting_stage:
+ return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+ else:
+ return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+
+ def get_held_layers(self) -> List[nn.Module]:
+
+ assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
+ stage_manager = self.pipeline_stage_manager
+
+ if self.model.__class__.__name__ == 'WhisperModel':
+ model = self.model
+ elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
+ model = self.model.model
+ else:
+ model = None
+
+ if model:
+ encoder = self.model.get_encoder()
+ decoder = self.model.get_decoder()
+ else:
+ # whisper for audio classification holds encoder only
+ encoder = self.model.encoder
+ decoder = None
+
+ num_encoder_layers = len(encoder.layers)
+ if decoder:
+ num_decoder_layers = len(decoder.layers)
+ else:
+ num_decoder_layers = 0
+
+ held_layers = []
+ layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+ num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
+ start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
+ decoder_starting_stage)
+
+ if stage_manager.stage < decoder_starting_stage:
+ # current stage is in whisper's encoder
+ if stage_manager.is_first_stage():
+ held_layers.append(encoder.embed_positions)
+ held_layers.append(encoder.conv1)
+ held_layers.append(encoder.conv2)
+ if stage_manager.stage == decoder_starting_stage - 1:
+ held_layers.append(encoder.layer_norm)
+ held_layers.extend(encoder.layers[start_idx:end_idx])
+ else:
+ # current stage is in whisper's decoder
+            # TODO:(Jianghai) We divide encoder and decoder layers into different parts here;
+            # the case where encoder and decoder are put in the same stage should be added in the future.
+ if stage_manager.stage == decoder_starting_stage:
+ held_layers.append(decoder.embed_tokens)
+ held_layers.append(decoder.embed_positions)
+ if stage_manager.is_last_stage():
+ held_layers.append(decoder.layer_norm)
+ held_layers.extend(decoder.layers[start_idx:end_idx])
+ return held_layers
+
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if not self.pipeline_stage_manager:
+ raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
+ stage_manager = self.pipeline_stage_manager
+
+ if self.model.__class__.__name__ == 'WhisperModel':
+ model = self.model
+ elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
+ model = self.model.model
+ else:
+ model = None
+
+ if model:
+ encoder = self.model.get_encoder()
+ decoder = self.model.get_decoder()
+ else:
+ encoder = self.model.encoder
+ decoder = None
+
+ num_encoder_layers = len(encoder.layers)
+ if decoder:
+ num_decoder_layers = len(decoder.layers)
+ else:
+ num_decoder_layers = 0
+
+ layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
+ num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
+ stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
+ decoder_starting_stage)
+
+ method_replacement = {
+ 'forward':
+ partial(new_forward,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ decoder_starting_stage=decoder_starting_stage)
+ }
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
+
+
+# WhisperModel
+class WhisperModelPolicy(WhisperPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ from transformers import WhisperModel
+ policy = super().module_policy()
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=WhisperModel,
+ new_forward=WhisperPipelineForwards.whisper_model_forward,
+ policy=policy)
+
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ return super().get_held_layers()
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ "no shared params in whisper model"
+ return []
+
+
+# WhisperForConditionalGeneration
+class WhisperForConditionalGenerationPolicy(WhisperPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ from transformers import WhisperForConditionalGeneration
+ policy = super().module_policy()
+ policy = self.add_lm_head_policy(policy)
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration,
+ new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward,
+ policy=policy)
+ return policy
+
+ def postprocess(self):
+ return self.model
+
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.proj_out)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ module = self.model
+ model = module.model
+
+ if model:
+ encoder = self.model.get_encoder()
+ decoder = self.model.get_decoder()
+ else:
+ encoder = self.model.encoder
+ decoder = None
+
+ num_encoder_layers = len(encoder.layers)
+ if decoder:
+ num_decoder_layers = len(decoder.layers)
+ else:
+ num_decoder_layers = 0
+
+ stage_manager = self.pipeline_stage_manager
+ if stage_manager is not None and stage_manager.num_stages > 1:
+ _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers,
+ stage_manager.num_stages)
+ shared_params = []
+ shared_embedding = {}
+            if id(module.proj_out.weight) == id(model.decoder.embed_tokens.weight):
+                shared_embedding[decoder_starting_stage] = model.decoder.embed_tokens.weight
+                shared_embedding[stage_manager.num_stages - 1] = module.proj_out.weight
+ if len(shared_embedding) > 0:
+ shared_params.append(shared_embedding)
+ return shared_params
+ return []
+
+
+# WhisperForAudioClassification
+class WhisperForAudioClassificationPolicy(WhisperPolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def preprocess(self):
+ return self.model
+
+ def module_policy(self):
+ from transformers import WhisperForAudioClassification
+ policy = super().module_policy()
+
+ if self.pipeline_stage_manager is not None:
+ self.set_pipeline_forward(model_cls=WhisperForAudioClassification,
+ new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward,
+ policy=policy)
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ held_layers = super().get_held_layers()
+ if self.pipeline_stage_manager.is_last_stage():
+ held_layers.append(self.model.projector)
+ held_layers.append(self.model.classifier)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ return []
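
In `add_lm_head_policy`, the conditional-generation head `proj_out` is replaced by `Linear1D_Col` with `gather_output=True`, so each rank's vocabulary shard of the logits is gathered back into the full vocabulary before loss computation or decoding. A shape-level illustration of what that means (an assumption-laden sketch, with a plain `torch.cat` standing in for the all-gather):

```python
# Shape-level illustration (assumption) of a column-parallel head with gather_output=True.
import torch

hidden, vocab, tp_size = 16, 100, 4
x = torch.randn(2, hidden)                       # (batch, hidden), replicated on every rank
w_shard = torch.randn(vocab // tp_size, hidden)  # each rank holds vocab // tp_size output rows

local_logits = x @ w_shard.t()                   # (2, 25) on each rank
# gather_output=True conceptually all-gathers the shards along the vocab dimension;
# torch.cat here is a single-process stand-in for that collective.
full_logits = torch.cat([local_logits] * tp_size, dim=-1)
print(full_logits.shape)                         # torch.Size([2, 100])
```
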
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 83c08d275df3..4380ac30814d 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -1,8 +1,11 @@
from dataclasses import dataclass
+from typing import Optional
import torch.distributed as dist
from torch.distributed import ProcessGroup
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
__all__ = ['ShardConfig']
@@ -12,17 +15,27 @@ class ShardConfig:
The config for sharding the huggingface model
Args:
- tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
+ tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism, defaults to None, which is the global process group.
+ pipeline_stage_manager (Optional[PipelineStageManager]): The pipeline stage manager, defaults to None, which means no pipeline.
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
+ enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False.
+ enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False.
"""
- tensor_parallel_process_group: ProcessGroup = None
+ tensor_parallel_process_group: Optional[ProcessGroup] = None
+ pipeline_stage_manager: Optional[PipelineStageManager] = None
enable_tensor_parallelism: bool = True
enable_fused_normalization: bool = False
enable_all_optimization: bool = False
+ enable_flash_attention: bool = False
+ enable_jit_fused: bool = False
+ enable_sequence_parallelism: bool = False
+ enable_sequence_overlap: bool = False
+ inference_only: bool = False
- # TODO: add support for tensor parallel
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@@ -34,12 +47,16 @@ def tensor_parallel_size(self):
return self._tensor_parallel_size
def __post_init__(self):
+ if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
+ raise ValueError(
+ "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True")
+ if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
+ raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")
if not self.enable_tensor_parallelism:
self._tensor_parallel_size = 1
else:
# get the parallel size
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
-
# turn on all optimization if all_optimization is set to True
if self.enable_all_optimization:
self._turn_on_all_optimization()
@@ -50,3 +67,13 @@ def _turn_on_all_optimization(self):
"""
# you can add all the optimization flag here
self.enable_fused_normalization = True
+ self.enable_flash_attention = True
+ self.enable_jit_fused = True
+ self.enable_sequence_parallelism = True
+ self.enable_sequence_overlap = True
+
+ def _infer(self):
+ """
+ Set default params for inference.
+ """
+ assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
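
With the new fields in place, a configuration that turns on the fused/flash optimizations might look like the sketch below. The field names come from the dataclass above, but whether a given flag takes effect still depends on each model policy (OPT, T5, ViT and Whisper currently ignore `enable_sequence_parallelism`, as their warnings show), and constructing the config with tensor parallelism enabled assumes `torch.distributed` has already been initialized:

```python
# Hedged usage sketch of the extended ShardConfig (field names from the dataclass above).
# Assumes torch.distributed has been initialized, since __post_init__ queries the world size.
from colossalai.shardformer.shard.shard_config import ShardConfig

shard_config = ShardConfig(
    enable_tensor_parallelism=True,
    enable_fused_normalization=True,
    enable_flash_attention=True,
    enable_jit_fused=True,
    enable_sequence_parallelism=False,   # sequence parallelism requires tensor parallelism (see __post_init__)
)

# ShardConfig(enable_tensor_parallelism=False, enable_sequence_parallelism=True)
# would raise ValueError in __post_init__.
```
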
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 201e0a08cbfe..7592069a2dd9 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -1,11 +1,16 @@
-from typing import Any, Callable, Dict, List, Union
+from types import MethodType
+from typing import Any, Callable, Dict, List, Optional, Set, Union
import torch.nn as nn
+from torch import Tensor
+
+from colossalai.lazy import LazyInitContext
from .._utils import getattr_, setattr_
-from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Policy, SubModuleReplacementDescription
+from ..policies.auto_policy import get_autopolicy
+from ..policies.base_policy import Policy, SubModuleReplacementDescription
from .shard_config import ShardConfig
+from .utils import set_tensors_to_none
__all__ = ['ModelSharder', 'shard_model']
@@ -22,18 +27,23 @@ class ModelSharder(object):
def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
self.model = model
- self.policy = get_autopolicy(self.model) if policy is None else policy
+ self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
self.shard_config = shard_config
- def shard(self) -> None:
+ def shard(self) -> List[Dict[int, Tensor]]:
r"""
Shard the model according to the policy
"""
self.policy.set_model(self.model)
self.policy.set_shard_config(self.shard_config)
self._preprocess()
- self._replace_module()
+ # get shared params before releasing unheld layers; this avoids misjudging shared params (None is None)
+ shared_params = self.policy.get_shared_params()
+ held_layers = self._release_unheld_layers()
+ self._replace_module(include=held_layers)
+ self._materialize()
self._postprocess()
+ return shared_params
def _preprocess(self) -> None:
self.model = self.policy.preprocess()
@@ -41,7 +51,7 @@ def _preprocess(self) -> None:
def _postprocess(self) -> None:
self.model = self.policy.postprocess()
- def _replace_module(self,) -> None:
+ def _replace_module(self, include: Optional[Set[nn.Module]] = None) -> None:
r"""
Replace the module according to the policy, and replace the module one by one
@@ -54,8 +64,13 @@ def _replace_module(self,) -> None:
param_replacement = module_description.param_replacement
sub_module_replacement = module_description.sub_module_replacement
method_replacement = module_description.method_replacement
- self._recursive_replace_layer(self.model, layer_cls, attr_replacement, param_replacement,
- method_replacement, sub_module_replacement)
+ self._recursive_replace_layer(self.model,
+ layer_cls,
+ attr_replacement,
+ param_replacement,
+ method_replacement,
+ sub_module_replacement,
+ include=include)
def _recursive_replace_layer(
self,
@@ -64,35 +79,43 @@ def _recursive_replace_layer(
attr_replacement: Dict[str, Any],
param_replacement: List[Callable],
method_replacement: Dict[str, Callable],
- sub_module_replacement: List[Callable],
+ sub_module_replacement: List[SubModuleReplacementDescription],
+ include: Optional[Set[nn.Module]] = None,
) -> None:
r"""
Reverse the replace layer operation
Args:
- layer (torch.nn.Module): The object of layer to shard
- origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name.
- attr_replacement (Dict): The attribute dict to modify
+ module (torch.nn.Module): The object of layer to shard
+ origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name
+ attr_replacement (Dict[str, Any]): The attribute dict to modify
param_replacement (List[Callable]): The function list to get parameter shard information in policy
- sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
+ method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
+ sub_module_replacement (List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy
+ include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
"""
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
(module.__class__ == origin_cls):
if attr_replacement is not None:
self._replace_attr(module, attr_replacement)
- if param_replacement is not None:
+ if param_replacement is not None and (include is None or module in include):
self._replace_param(module, param_replacement)
if method_replacement is not None:
self._replace_method(module, method_replacement)
if sub_module_replacement is not None:
- self._replace_sub_module(module, sub_module_replacement)
+ self._replace_sub_module(module, sub_module_replacement, include)
for name, child in module.named_children():
- self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement,
- sub_module_replacement)
+ self._recursive_replace_layer(child,
+ origin_cls,
+ attr_replacement,
+ param_replacement,
+ method_replacement,
+ sub_module_replacement,
+ include=include)
def _replace_attr(
self,
@@ -103,7 +126,7 @@ def _replace_attr(
Replace the attribute of the layer
Args:
- layer (:class:`torch.nn.Module`): The object of layer to shard
+ module (:class:`torch.nn.Module`): The object of layer to shard
attr_replacement (Dict): The attribute dict to modify
"""
for k, v in attr_replacement.items():
@@ -118,7 +141,7 @@ def _replace_param(
Replace the parameter of the layer
Args:
- layer (:class:`torch.nn.Module`): The object of layer to shard
+ module (:class:`torch.nn.Module`): The object of layer to shard
param_replacement (List[Callable]): The function list to get parameter shard information in policy
"""
for param_func in param_replacement:
@@ -127,20 +150,20 @@ def _replace_param(
def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
for method_name, new_method in method_replacement.items():
# bind the new method to the module
- setattr(module, method_name, new_method.__get__(module, module.__class__))
+ bound_method = MethodType(new_method, module)
+ setattr(module, method_name, bound_method)
- def _replace_sub_module(
- self,
- org_layer: nn.Module,
- sub_module_replacement: List[SubModuleReplacementDescription],
- ) -> None:
+ def _replace_sub_module(self,
+ org_layer: nn.Module,
+ sub_module_replacement: List[SubModuleReplacementDescription],
+ include: Optional[Set[nn.Module]] = None) -> None:
r"""
Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
Args:
org_layer (torch.nn.Module): The origin layer object to shard
sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list
-
+ include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
"""
for description in sub_module_replacement:
suffix = description.suffix
@@ -149,9 +172,12 @@ def _replace_sub_module(
assert target_module is not None, 'target_module should not be None'
- # TODO: support different parallel mode
native_sub_module = getattr_(org_layer, suffix, ignore=True)
+ # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
+ if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
+ continue
+
assert not isinstance(native_sub_module, target_module), \
f"The module with suffix {suffix} has been replaced, please check the policy"
@@ -172,3 +198,33 @@ def _replace_sub_module(
)
setattr_(org_layer, suffix, replace_layer)
+
+ def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]:
+
+ def collect_sub_modules(module: nn.Module):
+ if module is None:
+ return
+ recursive_held_layers.append(module)
+ for name, child in module.named_children():
+ collect_sub_modules(child)
+
+ recursive_held_layers = []
+ for module in held_layers:
+ collect_sub_modules(module)
+ return recursive_held_layers
+
+ def _release_unheld_layers(self) -> Optional[Set[nn.Module]]:
+ r"""
+ Release the unheld layers in the model
+ """
+ if self.shard_config and self.shard_config.pipeline_stage_manager:
+ held_layers = self.policy.get_held_layers()
+ set_tensors_to_none(self.model, exclude=set(held_layers))
+ return set(self._get_recursive_held_layers(held_layers))
+ return None
+
+ def _materialize(self) -> None:
+ r"""
+ Materialize the model if lazy initialization is used
+ """
+ LazyInitContext.materialize(self.model)
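The `_replace_method` change above swaps `new_method.__get__(...)` for `types.MethodType`; both bind a plain function to a module instance. A small standalone sketch of the idea (the names here are illustrative, not from the codebase):

```python
import torch
import torch.nn as nn
from types import MethodType

def doubled_forward(self, x):            # illustrative replacement method
    return x * 2

layer = nn.Identity()
# equivalent to doubled_forward.__get__(layer, nn.Identity): bind the function to this instance
layer.forward = MethodType(doubled_forward, layer)
assert layer(torch.ones(1)).item() == 2.0
```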
diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py
index 3fce12463414..7a0d75bf2f2a 100644
--- a/colossalai/shardformer/shard/shardformer.py
+++ b/colossalai/shardformer/shard/shardformer.py
@@ -1,8 +1,11 @@
+from typing import Dict, List, Tuple
+
import torch.nn as nn
+from torch import Tensor
from colossalai.cluster import DistCoordinator
-from ..policies.basepolicy import Policy
+from ..policies.base_policy import Policy
from .shard_config import ShardConfig
from .sharder import ModelSharder
@@ -24,7 +27,7 @@ class ShardFormer:
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
- model = shard_former.optimize(org_model)
+ model, shared_params = shard_former.optimize(org_model)
```
"""
@@ -32,7 +35,7 @@ def __init__(self, shard_config: ShardConfig):
self.coordinator = DistCoordinator()
self.shard_config = shard_config
- def optimize(self, model: nn.Module, policy: Policy = None):
+ def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
r"""
This method will optimize the model based on the given policy.
@@ -40,7 +43,9 @@ def optimize(self, model: nn.Module, policy: Policy = None):
model (`torch.nn.Model`): the origin huggingface model
shard_config (`ShardConfig`): the config for distribute information
policy (`Policy`): the custom policy for sharding
+
+ Returns: the sharded model and the shared parameters
"""
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
- sharder.shard()
- return model
+ shared_params = sharder.shard()
+ return model, shared_params
diff --git a/colossalai/shardformer/shard/utils.py b/colossalai/shardformer/shard/utils.py
new file mode 100644
index 000000000000..2bac37bfedda
--- /dev/null
+++ b/colossalai/shardformer/shard/utils.py
@@ -0,0 +1,19 @@
+from typing import Set
+
+import torch.nn as nn
+
+
+def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None:
+ """Set all parameters and buffers of model to None
+
+ Args:
+ model (nn.Module): The model to set
+ """
+ if model in exclude:
+ return
+ for child in model.children():
+ set_tensors_to_none(child, exclude=exclude)
+ for n, p in model.named_parameters(recurse=False):
+ setattr(model, n, None)
+ for n, buf in model.named_buffers(recurse=False):
+ setattr(model, n, None)
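A brief sketch of how the new `set_tensors_to_none` helper behaves; a toy model is used here, whereas in the sharder the `exclude` set is driven by the policy's held layers:

```python
import torch.nn as nn
from colossalai.shardformer.shard.utils import set_tensors_to_none  # path per the new file above

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
held = {model[0]}                        # layers this pipeline stage keeps
set_tensors_to_none(model, exclude=held)

assert model[0].weight is not None       # excluded module keeps its parameters
assert model[1].weight is None           # unheld module's tensors are released
```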
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index b384579feb35..076661a08824 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -3,9 +3,15 @@
import torch
from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor.const import TensorType
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.tensor.tensor_spec import ColoTensorSpec
+
+from .colo_tensor import _convert_output
+
+WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
+
+
+def is_no_hook_op(func) -> bool:
+ return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS
def filter_colo_parameters(*args, **kwargs):
@@ -41,53 +47,25 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
"""
- def __new__(cls,
- data: Optional[torch.Tensor] = None,
- requires_grad: bool = True,
- spec: ColoTensorSpec = None) -> 'ColoParameter':
+ def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> 'ColoParameter':
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
- def __init__(self,
- data: Optional[torch.Tensor] = None,
- requires_grad: bool = True,
- spec: ColoTensorSpec = None) -> None:
- ColoTensor.__init__(self, data, spec)
- self._type = TensorType.MODEL
- # a list contains modules sharing this ColoParameter with others.
- self._shared_param_modules = []
-
- @property
- def shared_param_modules(self):
- return self._shared_param_modules
-
- @staticmethod
- def from_torch_tensor(tensor: torch.Tensor,
- requires_grad: bool = True,
- spec: ColoTensorSpec = None) -> 'ColoParameter':
- tensor = tensor.as_subclass(ColoParameter)
- tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
- return tensor
-
- def __repr__(self):
- return super(ColoParameter, self).__repr__()
-
@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):
- if ColoParamOpHookManager.has_hook():
- if not func.__name__.startswith('__'):
- if kwargs is None:
- kwargs = {}
- params = filter_colo_parameters(*args, **kwargs)
- if len(params) > 0:
- with torch._C.DisableTorchFunction():
- new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
- args, kwargs = replace_args(args, kwargs, new_args)
- ret = super().__torch_function__(func, types, args, kwargs)
- with torch._C.DisableTorchFunction():
- ret = ColoParamOpHookManager.post_op(params, ret)
- return ret
+ if kwargs is None:
+ kwargs = {}
+ if ColoParamOpHookManager.has_hook() and not is_no_hook_op(func):
+ params = filter_colo_parameters(*args, **kwargs)
+ if len(params) > 0:
+ with torch._C.DisableTorchFunction():
+ new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
+ args, kwargs = replace_args(args, kwargs, new_args)
+ ret = super().__torch_function__(func, types, args, kwargs)
+ with torch._C.DisableTorchFunction():
+ ret = ColoParamOpHookManager.post_op(params, ret)
+ return _convert_output(ret, func)
return super().__torch_function__(func, types, args, kwargs)
def __deepcopy__(self, memo):
@@ -96,9 +74,7 @@ def __deepcopy__(self, memo):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
- tensor = ColoParameter(data,
- self.requires_grad,
- spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec))
+ tensor = ColoParameter(data, self.requires_grad)
memo[id(self)] = tensor
return tensor
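The whitelist above decides which ops still trigger parameter op hooks: dunder methods are skipped, except `__getitem__`. A self-contained sketch of that predicate:

```python
import torch

WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}

def is_no_hook_op(func) -> bool:
    # dunder methods skip the hooks unless explicitly whitelisted
    return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS

assert is_no_hook_op(torch.Tensor.__repr__)          # dunder, not whitelisted: hooks skipped
assert not is_no_hook_op(torch.Tensor.__getitem__)   # indexing still runs the hooks
assert not is_no_hook_op(torch.Tensor.add)           # ordinary ops run the hooks
```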
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 4d762076461d..a20a1444a406 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -1,17 +1,14 @@
-import operator
-from copy import copy
-from functools import lru_cache, reduce
-from typing import Callable, Optional, Set
+from functools import lru_cache
+from typing import Callable, Set
import torch
-from colossalai.tensor.dist_spec_mgr import DistSpecManager
-from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
-from colossalai.tensor.process_group import ProcessGroup
-from colossalai.tensor.tensor_spec import ColoTensorSpec
-
-from .const import TensorType
-from .op_wrapper import _COLOSSAL_OPS
+INPALCE_MAPPING = {
+ torch.Tensor.add_: torch.Tensor.add,
+ torch.Tensor.sub_: torch.Tensor.sub,
+ torch.Tensor.mul_: torch.Tensor.mul,
+ torch.Tensor.div_: torch.Tensor.div
+}
@lru_cache(None)
@@ -25,61 +22,37 @@ def _get_my_nowrap_functions() -> Set[Callable]:
}
-def _convert_output(output, colo_spec: ColoTensorSpec):
- if type(output) == torch.Tensor:
- return ColoTensor.from_torch_tensor(output, colo_spec)
+def _convert(output):
+ if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
+ output.__class__ = ColoTensor
elif isinstance(output, (list, tuple)):
- return type(output)(_convert_output(o, colo_spec) for o in output)
- else:
- return output
+ output = type(output)(_convert(o) for o in output)
+ return output
-def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:
- for elem in args:
- if isinstance(elem, ColoTensor):
- pg = elem.get_process_group()
- dp = elem.dist_spec
- return ColoTensorSpec(pg, dp)
- elif isinstance(elem, (list, tuple)):
- spec = _get_spec_from_args(elem, {})
- if spec is not None:
- return spec
- for k, v in kwargs.items():
- if isinstance(v, ColoTensor):
- pg = v.get_process_group()
- dp = v.dist_spec
- return ColoTensorSpec(pg, dp)
- return None
+def _convert_output(output, func):
+ if func in _get_my_nowrap_functions():
+ return output
+ return _convert(output)
class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
- The Colotensor can be initialized with a PyTorch tensor in the following ways.
-
- >>> pg = ProcessGroup()
- >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
- >>> # The tensor passed in is a tensor after sharding but not a global tensor.
- >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
- >>> dims=[0],
- >>> num_partitions=[world_size])
- >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
- >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
+ It is only used to trigger the torch function hook.
Args:
data (torch.Tensor): a torch tensor used as the payload the colotensor.
- spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
"""
torch_major = int(torch.__version__.split('.')[0])
torch_minor = int(torch.__version__.split('.')[1])
- def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
+ def __new__(cls, data: torch.Tensor) -> 'ColoTensor':
"""
The signature of the __new__ has to be consistent with the torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload the colotensor.
- spec (TensorSpec, optional): the tensor spec of initialization.
Returns:
ColoTensor: a ColoTensor wrappers the data.
@@ -88,86 +61,6 @@ def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, data.requires_grad)
- def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None:
- # If not set spec, use a DP process group and replicate dist spec
- if spec is None:
- self.has_initialized = False
- self.dist_spec = ReplicaSpec()
- self.compute_spec = None
- self.process_group = ProcessGroup()
- else:
- self.has_initialized = True
- self.dist_spec = spec.dist_attr
- self.compute_spec = spec.compute_attr
- if spec.pg is None:
- self.process_group = ProcessGroup()
- else:
- self.process_group = spec.pg
-
- self._type = TensorType.NONMODEL
-
- def has_compute_spec(self) -> bool:
- return self.compute_spec is not None
-
- def is_model_data(self) -> bool:
- return self._type == TensorType.MODEL
-
- def get_process_group(self) -> 'ProcessGroup':
- return self.process_group
-
- def set_process_group(self, pg: ProcessGroup):
- """set_process_group
- change the pg of the ColoTensor. Note that the valid use cases is limited.
- It works for the target pg is DP and TP only and current dist spec of the Tensor is Replica.
-
- Args:
- pg (ProcessGroup): target pg
-
- """
- assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
- # if the new pg is the same as the old pg, just returns
- if self.process_group == pg:
- return
- assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
- "Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1"
- assert self.dist_spec.placement.value == 'r', \
- "Can not set_process_group on a ColoTensor whose dist spec is not Replica"
-
- self.process_group = pg
-
- def get_tp_world_size(self) -> int:
- return self.process_group.tp_world_size()
-
- def get_dp_world_size(self) -> int:
- """get_dp_world_size
- get the dp world size of the tensor.
-
- Returns:
- int: dp world size
- """
- return self.process_group.dp_world_size()
-
- def set_dist_spec(self, dist_spec: _DistSpec):
- """set_dist_spec
- set dist spec and change the payloads.
-
- Args:
- dist_spec (_DistSpec): target dist spec.
- """
- assert isinstance(dist_spec, _DistSpec)
- assert self.process_group is not None
- self._redistribute(dist_spec)
-
- def set_tensor_spec(self, dist_spec, compute_spec):
- if dist_spec is not None:
- assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
- self.set_dist_spec(dist_spec)
- if compute_spec is not None:
- self.compute_spec = compute_spec
-
- def has_compute_pattern(self, compute_pattern):
- return self.compute_spec.compute_pattern == compute_pattern
-
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
@@ -175,9 +68,6 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
if not all(issubclass(cls, t) for t in types):
return NotImplemented
- global _COLOSSAL_OPS
- if func in _COLOSSAL_OPS:
- func = _COLOSSAL_OPS[func]
if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
# in order to trigger pre-op hook in the forward of checkpoint module
@@ -189,94 +79,16 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
return backward_tensor.backward(**tensor_kwargs)
+ # replace the in-place function
+ if func in INPALCE_MAPPING:
+ func = INPALCE_MAPPING[func]
+ # set the 'inplace' kwargs to False
+ if 'inplace' in kwargs:
+ kwargs['inplace'] = False
+
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
- if func in _get_my_nowrap_functions():
- return ret
- else:
- colo_spec = _get_spec_from_args(args, kwargs)
- return _convert_output(ret, colo_spec)
-
- def __repr__(self):
- output_list = [super(ColoTensor, self).__repr__()]
- output_list.append(str(self.process_group))
- output_list.append(str(self.dist_spec))
- if self.compute_spec is not None:
- output_list.append(str(self.compute_spec))
- return "\n".join(output_list)
-
- def _redistribute(self, dist_spec: _DistSpec) -> None:
- """_redistribute
- Note the function will not handle the logic of backward propagation!
- It is used during model tensor initializations as an internal function.
-
- Args:
- dist_spec (_DistSpec): the target dist. spec.
- """
- assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
- with DistSpecManager.no_grad():
- self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
- self.dist_spec = dist_spec
-
- def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
- """redistribute
- Redistribute the tensor among processes. The rule is like this:
-
- 1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the
- DP process group not changed.
-
- 2. If the pg is not not None and not equal to the current process group.
- First, convert the tensor as replicated among the TP process group.
- Second, reset the process group to the new pg.
- Third, convert the tensor (new replicated both among the tp process group) to the new dist_spec.
-
- Args:
- dist_spec (_DistSpec): the new dist spec.
- pg (Optional[ProcessGroup], optional): the new process group . Defaults to None.
-
- Returns:
- ColoTensor: a redistributed colotensor
- """
- if pg is not None and pg != self.get_process_group():
- # if the pg is not equal, convert the current tensor to replicated
- handled = self.redistribute(ReplicaSpec())
- else:
- handled = self
- pg = self.process_group
-
- ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg)
- return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))
-
- def to_replicate_(self):
- """to_replicate_
-
- an inline member function, converting dist spec of the tensor to REPLICATE
- """
- self._redistribute(dist_spec=ReplicaSpec())
-
- def to_replicate(self) -> 'ColoTensor':
- """to_replicate
-
- converting dist spec of the tensor to ReplicaSpec()
- """
- return self.redistribute(ReplicaSpec())
-
- @staticmethod
- def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
- """from_torch_tensor
-
- A static method builds a `ColoTensor` from a PyTorch Tensor.
-
- Args:
- tensor (torch.Tensor): the pytorch tensor, which is a local tensor for this rank not a global tensor.
- spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None.
-
- Returns:
- ColoTensor: a ColoTensor
- """
- tensor = tensor.as_subclass(ColoTensor)
- tensor.__init__(tensor, spec=spec)
- return tensor
+ return _convert_output(ret, func)
def __deepcopy__(self, memo):
if id(self) in memo:
@@ -284,60 +96,6 @@ def __deepcopy__(self, memo):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
- tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec)))
+ tensor = ColoTensor(data)
memo[id(self)] = tensor
return tensor
-
- # override builtin functions which must use tensor in replicate placement #
-
- def size_local(self, *args) -> torch.Size:
- with torch._C.DisableTorchFunction():
- return super().size(*args)
-
- def size_global(self, *args) -> torch.Size:
- """size_global
-
- override the torch building size()
- the shape passed in must be in a replicate placement.
-
- Returns:
- torch.Size: the global tensor shape
- """
- if self.is_replicate():
- return self.size_local(*args)
- spec = self.dist_spec
- dims = spec.dims
- num_partitions = spec.num_partitions
- # import inspect
- # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
- size_list = list(self.size_local())
- for dim, num_partition in zip(dims, num_partitions):
- size_list[dim] *= num_partition
- if args == ():
- return torch.Size(size_list)
- else:
- return size_list[args[0]]
-
- def numel_global(self):
- """Returns the number of elements in the tensor when it's replicated.
- """
- return reduce(operator.mul, self.size_global(), 1)
-
- # Some API for dist spec check
-
- def is_replicate(self):
- return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
- or (len(self.dist_spec.num_partitions) == 1
- and self.dist_spec.num_partitions[0] == 1) \
- or (self.process_group.tp_world_size() == 1)
-
- def is_shard_1dcol(self):
- return self.dist_spec.placement == DistPlacementPattern.SHARD \
- and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
-
- def is_shard_1drow(self):
- return self.dist_spec.placement == DistPlacementPattern.SHARD \
- and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
-
- def is_sharded(self):
- return self.dist_spec.placement == DistPlacementPattern.SHARD
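The mapping added above (spelled `INPALCE_MAPPING` in the diff) lets `__torch_function__` redirect in-place tensor ops to their out-of-place counterparts, so the wrapper can hand back a freshly wrapped result. A standalone sketch of the redirection, outside any subclass:

```python
import torch

# same table as above, under a corrected local name for this sketch
INPLACE_TO_OUT_OF_PLACE = {
    torch.Tensor.add_: torch.Tensor.add,
    torch.Tensor.sub_: torch.Tensor.sub,
    torch.Tensor.mul_: torch.Tensor.mul,
    torch.Tensor.div_: torch.Tensor.div,
}

func = torch.Tensor.add_
args = (torch.ones(2), torch.full((2,), 3.0))
func = INPLACE_TO_OUT_OF_PLACE.get(func, func)   # add_ -> add
out = func(*args)                                # out-of-place: the inputs stay untouched
assert torch.equal(out, torch.full((2,), 4.0))
assert torch.equal(args[0], torch.ones(2))
```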
diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py
index 95a44e09e16a..9848e4ca423e 100644
--- a/colossalai/tensor/d_tensor/api.py
+++ b/colossalai/tensor/d_tensor/api.py
@@ -16,6 +16,11 @@
layout_converter = LayoutConverter()
+def clear_layout_converter():
+ global layout_converter
+ layout_converter.cached_solution.clear()
+
+
def is_distributed_tensor(tensor: torch.Tensor) -> bool:
"""
Check whether the given tensor is a distributed tensor.
@@ -235,6 +240,14 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
return param
+def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
+ assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+ param.data = dtensor
+ # make it distributed as well
+ param.dist_layout = dtensor.dist_layout
+ _hijack_detach_and_clone(param)
+
+
def compute_global_numel(dtensor: torch.Tensor) -> int:
"""
Compute the global number of elements in the distributed tensor.
@@ -432,3 +445,15 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad:
param.gather_fn = dtensor.gather_fn
_hijack_detach_and_clone_for_customized_distributed_tensor(param)
return param
+
+
+def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter):
+ """
+ Convert the given customized distributed tensor to an existing parameter.
+ """
+ assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
+
+ param.data = dtensor.data
+ param.shard_fn = dtensor.shard_fn
+ param.gather_fn = dtensor.gather_fn
+ _hijack_detach_and_clone_for_customized_distributed_tensor(param)
diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py
index c968050de49d..4740a316b7f5 100644
--- a/colossalai/tensor/dist_spec_mgr.py
+++ b/colossalai/tensor/dist_spec_mgr.py
@@ -2,7 +2,6 @@
import torch
import torch.distributed as dist
-# from colossalai.nn.layer.utils import divide
from numpy import prod
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py
index 8ed8176d996a..e37859bac0c3 100644
--- a/colossalai/tensor/param_op_hook.py
+++ b/colossalai/tensor/param_op_hook.py
@@ -3,9 +3,7 @@
from typing import Any, List, Tuple
import torch
-
-from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor.tensor_spec import ColoTensorSpec
+from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
class ColoParamOpHook(ABC):
@@ -82,26 +80,18 @@ def _trigger_post_backward(params: List[torch.Tensor]) -> None:
@staticmethod
def pre_op(params: List[torch.Tensor], *args: Any) -> list:
ColoParamOpHookManager._trigger_pre_forward(params)
- grad_args, rear_args = _get_grad_args(*args)
- colo_info = _get_colo_tensors_info(*grad_args)
- rets = PreFwdPostBwd.apply(params, *grad_args)
- update_args = _update_colo_tensors(colo_info, *rets)
- if rear_args is None:
- return update_args
- else:
- arg_zero = (tuple(update_args),)
- return arg_zero + rear_args
+ # auto grad function can only recognize torch.Tensor, thus we have to flatten the input
+ # if one of the inputs requires grad, all the outputs will be treated as requiring grad
+ # and will have a grad fn even if the corresponding input does not require grad
+ # we have to extract tensors requiring grad into flat list and then merge them back
+ grad_args, other_args, grad_flags, spec = _flatten_grad_args(args)
+ new_grad_args = PreFwdPostBwd.apply(params, *grad_args)
+ return _merge_args(new_grad_args, other_args, grad_flags, spec)
@staticmethod
def post_op(params: List[torch.Tensor], arg: Any) -> Any:
ColoParamOpHookManager._trigger_post_forward(params)
- colo_info = _get_colo_tensors_info(arg)
- ret = PostFwdPreBwd.apply(params, arg)
- res = _update_colo_tensors(colo_info, ret)
- if len(res) == 1:
- return res[0]
- else:
- return res
+ return PostFwdPreBwd.apply(params, arg)
@staticmethod
def has_hook() -> bool:
@@ -141,57 +131,24 @@ def _is_grad_tensor(obj) -> bool:
return False
-def _has_grad_tensor(obj) -> bool:
- if isinstance(obj, tuple) or isinstance(obj, list):
- for x in obj:
- if _has_grad_tensor(x):
- return True
- return False
- elif isinstance(obj, dict):
- for x in obj.values():
- if _has_grad_tensor(x):
- return True
- return False
- else:
- return _is_grad_tensor(obj)
-
-
-def _get_grad_args(*args):
- # if there is no grad tensors, do nothing
- if not _has_grad_tensor(args):
- return args, None
- # returns the identical args if there is a grad tensor
- for obj in args:
- if _is_grad_tensor(obj):
- return args, None
- # otherwise, the first argument should be a tuple of grad tensors
- # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
- arg_zero = args[0]
- if not isinstance(arg_zero, tuple):
- raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
- check_grad_flag = False
- for obj in arg_zero:
- check_grad_flag |= _is_grad_tensor(obj)
- if not check_grad_flag:
- raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
- return arg_zero, args[1:]
-
-
-def _get_colo_tensors_info(*args) -> list:
- info = []
- for arg in args:
- if isinstance(arg, ColoTensor):
- info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
+def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
+ flat_args, spec = tree_flatten(args)
+ grad_args = []
+ other_args = []
+ grad_flags = []
+ for arg in flat_args:
+ flag = _is_grad_tensor(arg)
+ grad_flags.append(flag)
+ if flag:
+ grad_args.append(arg)
else:
- info.append(None)
- return info
-
-
-def _update_colo_tensors(info, *args) -> list:
- ret = []
- for t_info, arg in zip(info, args):
- if t_info is not None:
- t_cls, spec = t_info
- arg = t_cls.from_torch_tensor(arg, spec=spec)
- ret.append(arg)
- return ret
+ other_args.append(arg)
+ assert len(grad_args) > 0
+ return grad_args, other_args, grad_flags, spec
+
+
+def _merge_args(grad_args, other_args, grad_flags, spec):
+ grad_iter = iter(grad_args)
+ other_iter = iter(other_args)
+ flat_args = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags]
+ return tree_unflatten(flat_args, spec)
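A compact sketch of the flatten/merge round trip used above: pytree flattening pulls the grad-requiring tensors into a flat list (so an `autograd.Function` can see them), the rest are set aside, and `tree_unflatten` restores the original nested structure afterwards:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

args = (torch.randn(2, requires_grad=True), {'scale': 2.0, 'bias': torch.zeros(2)})
flat, spec = tree_flatten(args)
grad_flags = [isinstance(x, torch.Tensor) and x.requires_grad for x in flat]
grad_args = [x for x, keep in zip(flat, grad_flags) if keep]
other_args = [x for x, keep in zip(flat, grad_flags) if not keep]

# ... an autograd.Function such as PreFwdPostBwd would be applied to grad_args here ...

grad_iter, other_iter = iter(grad_args), iter(other_args)
merged = tree_unflatten([next(grad_iter) if keep else next(other_iter) for keep in grad_flags], spec)
assert merged[1]['scale'] == 2.0                 # nested structure is reconstructed intact
```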
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index 7b2e8480c66c..6f9717d353e6 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -1,12 +1,14 @@
from .activation_checkpoint import checkpoint
from .checkpointing import load_checkpoint, save_checkpoint
from .common import (
+ _cast_float,
clip_grad_norm_fp32,
conditional_context,
copy_tensor_parallel_attributes,
count_zeros_fp32,
disposable,
ensure_path_exists,
+ free_storage,
is_ddp_ignored,
is_dp_rank_0,
is_model_parallel_parameter,
@@ -72,4 +74,6 @@
'disposable',
'colo_set_cpu_memory_capacity',
'colo_get_cpu_memory_capacity',
+ '_cast_float',
+ 'free_storage',
]
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 8022e84dc24b..998901708239 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -470,3 +470,22 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
+
+
+def free_storage(data: torch.Tensor) -> None:
+ """Free underlying storage of a Tensor."""
+ if data.storage().size() > 0:
+ # Since we're modifying the Tensor's Storage directly, make sure the Tensor
+ # is the sole occupant of the Storage.
+ assert data.storage_offset() == 0
+ data.storage().resize_(0)
+
+
+def _cast_float(args, dtype: torch.dtype):
+ if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
+ args = args.to(dtype)
+ elif isinstance(args, (list, tuple)):
+ args = type(args)(_cast_float(t, dtype) for t in args)
+ elif isinstance(args, dict):
+ args = {k: _cast_float(v, dtype) for k, v in args.items()}
+ return args
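A quick sketch of the two helpers added above, using the re-exported names (assuming a ColossalAI version that matches this diff):

```python
import torch
from colossalai.utils import _cast_float, free_storage

buf = torch.randn(1024)
free_storage(buf)                        # drops the underlying storage; the tensor object survives
assert buf.storage().size() == 0

batch = {'x': torch.randn(2, 3), 'mask': torch.ones(2, 3, dtype=torch.bool)}
batch = _cast_float(batch, torch.float16)
assert batch['x'].dtype == torch.float16     # floating-point tensors are cast
assert batch['mask'].dtype == torch.bool     # non-float tensors pass through unchanged
```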
diff --git a/colossalai/utils/data_sampler/data_parallel_sampler.py b/colossalai/utils/data_sampler/data_parallel_sampler.py
index 2318e07a7f8d..881ddde78648 100644
--- a/colossalai/utils/data_sampler/data_parallel_sampler.py
+++ b/colossalai/utils/data_sampler/data_parallel_sampler.py
@@ -4,20 +4,18 @@
import math
import random
-import numpy as np
-from typing import TypeVar, Iterator
+from typing import Iterator, TypeVar
+import numpy as np
import torch
-from torch.utils.data import Sampler, Dataset, DataLoader
+from torch.utils.data import DataLoader, Dataset, Sampler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import DATA_SAMPLERS
T_co = TypeVar('T_co', covariant=True)
-@DATA_SAMPLERS.register_module
class DataParallelSampler(Sampler):
"""A data sampler for distributed data parallelism.
@@ -30,11 +28,7 @@ class DataParallelSampler(Sampler):
the batch size, then the last batch will be smaller, defaults to False.
"""
- def __init__(self,
- dataset: Dataset,
- shuffle: bool = False,
- seed: int = 0,
- drop_last: bool = False) -> None:
+ def __init__(self, dataset: Dataset, shuffle: bool = False, seed: int = 0, drop_last: bool = False) -> None:
self.dataset = dataset
self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
self.rank = gpc.get_local_rank(ParallelMode.DATA)
@@ -54,8 +48,7 @@ def __init__(self,
self.num_replicas # type: ignore[arg-type]
)
else:
- self.num_samples = math.ceil(
- len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
+ self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
@@ -72,7 +65,7 @@ def __iter__(self) -> Iterator[T_co]:
# set_epoch manually
self.epoch += 1
else:
- indices = list(range(len(self.dataset))) # type: ignore[arg-type]
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
@@ -80,8 +73,7 @@ def __iter__(self) -> Iterator[T_co]:
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
- indices += (indices * math.ceil(padding_size /
- len(indices)))[:padding_size]
+ indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
@@ -109,8 +101,8 @@ def set_epoch(self, epoch: int) -> None:
def get_dataloader(dataset,
shuffle=False,
- seed=1024,
- add_sampler=True,
+ seed=1024,
+ add_sampler=True,
drop_last=False,
pin_memory=False,
num_workers=0,
diff --git a/colossalai/utils/profiler/profiler.py b/colossalai/utils/profiler/profiler.py
index 8f43a0b96de0..3026d723deb0 100644
--- a/colossalai/utils/profiler/profiler.py
+++ b/colossalai/utils/profiler/profiler.py
@@ -1,17 +1,17 @@
-import os
-from typing import List
-from colossalai.engine import Engine
-from torch.profiler import profile as torch_profile
-from torch.profiler.profiler import ProfilerAction
-from typing import Any, Callable, Iterable, Optional
-from torch.autograd import ProfilerActivity
+import gzip
import json
import os
import tempfile
-import gzip
+from typing import Any, Callable, Iterable, List, Optional
+
+from torch.autograd import ProfilerActivity
+from torch.profiler import profile as torch_profile
+from torch.profiler.profiler import ProfilerAction
+
+from colossalai.legacy.engine import Engine
+from colossalai.logging import get_dist_logger
from colossalai.utils.profiler.extention import ProfilerExtension
from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention
-from colossalai.logging import get_dist_logger
class profile(torch_profile):
diff --git a/colossalai/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/utils/profiler/stateful_tensor_mem_extention.py
index 127055c8c1ef..412bd7277eee 100644
--- a/colossalai/utils/profiler/stateful_tensor_mem_extention.py
+++ b/colossalai/utils/profiler/stateful_tensor_mem_extention.py
@@ -1,12 +1,14 @@
import os
import threading
import time
-import torch
from enum import Enum
from typing import List
-from colossalai.gemini.stateful_tensor import StatefulTensor
+
+import torch
+
from colossalai.gemini.ophooks import BaseOpHook
-from colossalai.engine import Engine
+from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.legacy.engine import Engine
from colossalai.utils.profiler.extention import ProfilerExtension
diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py
index 3465079e4fbb..4991241b8df1 100644
--- a/colossalai/zero/__init__.py
+++ b/colossalai/zero/__init__.py
@@ -2,8 +2,7 @@
ColoInitContext,
GeminiAdamOptimizer,
GeminiDDP,
- ZeroDDP,
- ZeroOptimizer,
+ GeminiOptimizer,
get_static_torch_model,
post_process_colo_init_ctx,
)
@@ -11,6 +10,6 @@
from .wrapper import zero_model_wrapper, zero_optim_wrapper
__all__ = [
- 'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
+ 'GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
]
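In short, these renames retire the old ZeRO entry points in favor of the Gemini names. Downstream imports would migrate roughly as follows (a sketch against the post-diff package):

```python
# before this change: from colossalai.zero import ZeroDDP, ZeroOptimizer
from colossalai.zero import GeminiDDP, GeminiOptimizer, GeminiAdamOptimizer
```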
diff --git a/colossalai/zero/gemini/__init__.py b/colossalai/zero/gemini/__init__.py
index 60f85ca2f540..7ac6a9be4140 100644
--- a/colossalai/zero/gemini/__init__.py
+++ b/colossalai/zero/gemini/__init__.py
@@ -1,11 +1,11 @@
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
-from .gemini_ddp import GeminiDDP, ZeroDDP
+from .gemini_ddp import GeminiDDP
from .gemini_mgr import GeminiManager
-from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer
+from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer
from .utils import get_static_torch_model
__all__ = [
- 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP',
- 'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
+ 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'GeminiDDP',
+ 'get_static_torch_model', 'GeminiAdamOptimizer', 'GeminiOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
]
diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index 51da9be2b1f8..3e7403adb53b 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -4,8 +4,8 @@
import torch
import torch.distributed as dist
+from torch.distributed import ProcessGroup
-from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.utils import get_current_device
@@ -55,7 +55,7 @@ class Chunk:
def __init__(self,
chunk_size: int,
- process_group: ColoProcessGroup,
+ process_group: ProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
@@ -69,7 +69,7 @@ def __init__(self,
Args:
chunk_size (int): the number of elements in the chunk
- process_group (ColoProcessGroup): the process group of this chunk
+ process_group (ProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
The default value is None, which is the current GPU
@@ -83,7 +83,7 @@ def __init__(self,
self.chunk_size = chunk_size
self.utilized_size = 0
- self.torch_pg = process_group.dp_process_group()
+ self.torch_pg = process_group
self.pg_size = dist.get_world_size(self.torch_pg)
self.pg_rank = dist.get_rank(self.torch_pg)
@@ -218,7 +218,7 @@ def can_release(self) -> bool:
return False
else:
return self.tensor_state_cnter[TensorState.HOLD] + \
- self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
+ self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
@property
def can_reduce(self):
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 38d34f14863e..1e96234326a9 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -2,8 +2,9 @@
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
-from colossalai.tensor import ColoTensor
from colossalai.utils import get_current_device
from .chunk import Chunk, ChunkFullError, TensorState
@@ -27,16 +28,17 @@ def __init__(self, chunk_configuration, init_device: Optional[torch.device] = No
self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
v['init_device'] = self.device
- self.chunk_groups: Dict[str, Deque] = dict()
+ self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
self.accessed_chunks: Set[Chunk] = set()
self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def register_tensor(self,
- tensor: ColoTensor,
+ tensor: torch.Tensor,
group_type: str,
config_key: int,
+ process_group: ProcessGroup,
cpu_offload: bool = False,
pin_memory: bool = False) -> None:
"""
@@ -51,7 +53,7 @@ def register_tensor(self,
pin_memory: whether the chunk is pinned in the cpu memory
"""
assert tensor not in self.tensor_chunk_map
- assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
+ assert isinstance(tensor, torch.Tensor), "Please feed Tensor to this ChunkManager"
assert config_key in self.dp_degree_chunk_size_dict
chunk_size = self.dp_degree_chunk_size_dict[config_key]
@@ -73,12 +75,12 @@ def register_tensor(self,
if tensor.numel() > chunk_size:
chunk_size = tensor.numel()
- dp_size = tensor.get_dp_world_size()
+ dp_size = dist.get_world_size(process_group)
chunk_size = chunk_size + (-chunk_size % dp_size)
chunk = Chunk(
chunk_size=chunk_size,
- process_group=tensor.process_group,
+ process_group=process_group,
dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
pin_memory=pin_memory,
@@ -220,7 +222,7 @@ def __repr__(self) -> str:
msg.append(f'[{i}] {chunk}\n')
return ''.join(msg)
- def __get_chunk_group(self, group_name: str) -> Deque:
+ def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
"""Register a chunk group.
"""
if group_name not in self.chunk_groups:
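A hedged sketch of registering a tensor with the updated `ChunkManager` API. The configuration format is inferred from the constructor above, `torch.distributed` is assumed to be initialized, and the import relies on the `gemini/__init__.py` export shown earlier:

```python
import torch
import torch.distributed as dist
from colossalai.zero.gemini import ChunkManager   # exported per gemini/__init__.py above

pg = dist.group.WORLD
world_size = dist.get_world_size(pg)

# one chunk group keyed by the data-parallel degree, with a 1M-element chunk size
manager = ChunkManager({world_size: {'chunk_size': 1024 * 1024}})
param = torch.nn.Parameter(torch.randn(1024))
manager.register_tensor(tensor=param,
                        group_type='fp16_param',
                        config_key=world_size,
                        process_group=pg)          # plain torch ProcessGroup, not ColoProcessGroup
manager.close_all_groups()
```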
diff --git a/colossalai/zero/gemini/chunk/search_utils.py b/colossalai/zero/gemini/chunk/search_utils.py
index 6c3d4f9a1b41..abaca5f8294d 100644
--- a/colossalai/zero/gemini/chunk/search_utils.py
+++ b/colossalai/zero/gemini/chunk/search_utils.py
@@ -4,6 +4,7 @@
import numpy as np
import torch.distributed as dist
import torch.nn as nn
+from torch.distributed import ProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.utils import is_ddp_ignored
@@ -59,7 +60,7 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
return left + acc
-def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
+def _tensor_numel(local_param: ColoParameter) -> int:
"""_tensor_numel
Get the number of elements of a tensor.
@@ -71,15 +72,12 @@ def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
Returns:
int: the number of elements.
"""
- if strict_ddp_flag and type(local_param) is ColoParameter:
- return local_param.numel_global()
- else:
- # if local_param is not ColoParameter, we assume it's replicated
- return local_param.numel()
+ # TODO(ver217): support dtensor here
+ return local_param.numel()
def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
- strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]:
+ process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]:
"""classify_params_by_dp_degree
Classify the parameters by their dp degree
@@ -97,13 +95,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
# assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
if is_ddp_ignored(param):
continue
-
- if strict_ddp_flag or type(param) is not ColoParameter:
- # if model is not initialized with ColoInitContext, we assume it's replicated
- # TODO(ver217): integrate DTensor
- param_key = dist.get_world_size()
- else:
- param_key = param.process_group.dp_world_size()
+ param_key = dist.get_world_size(process_group)
if param_key not in params_dict:
params_dict[param_key] = []
@@ -119,6 +111,7 @@ def search_chunk_configuration(
min_chunk_size_m: float = 32,
filter_exlarge_params: bool = True,
strict_ddp_flag: bool = False,
+ process_group: Optional[ProcessGroup] = None,
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
"""search_chunk_configuration
@@ -149,7 +142,7 @@ def search_chunk_configuration(
min_chunk_size = round(min_chunk_size_m * 1024**2)
assert search_range >= 0
- params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
+ params_dict = classify_params_by_dp_degree(param_order, process_group)
size_lcm = np.lcm.reduce(list(params_dict.keys()))
config_dict: Dict[int, Dict] = dict()
total_param_size = 0
@@ -157,7 +150,7 @@ def search_chunk_configuration(
size_dict: Dict[int, List[int]] = dict()
for dp_degree in params_dict:
params_list = params_dict[dp_degree]
- size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list]
+ size_list = [_tensor_numel(p) for p in params_list]
group_acc_size = sum(size_list)
total_param_size += group_acc_size
diff --git a/colossalai/zero/gemini/colo_init_context.py b/colossalai/zero/gemini/colo_init_context.py
index 75f8576ca477..dad852a34a71 100644
--- a/colossalai/zero/gemini/colo_init_context.py
+++ b/colossalai/zero/gemini/colo_init_context.py
@@ -87,7 +87,7 @@ def __init__(self,
self._default_dist_spec = default_dist_spec
def _register_colo_modules(self):
- from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
+ from colossalai.legacy.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
register_colo_module(torch.nn.Linear, ColoLinear())
register_colo_module(torch.nn.Embedding, ColoEmbedding())
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 08384ee82d0b..918b08cd3150 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -2,21 +2,21 @@
from collections import OrderedDict
from contextlib import nullcontext
from functools import partial
-from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
+from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import _get_default_group
-from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.checkpoint_io.utils import StateDictSharder, calculate_tensor_size
+from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
-from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
-from colossalai.tensor import ProcessGroup as ColoProcessGroup
-from colossalai.tensor import ReplicaSpec
-from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
+from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import get_current_device, is_ddp_ignored
+from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
@@ -30,14 +30,13 @@
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
__all__ = [
- 'ZeroDDP',
'GeminiDDP',
]
-class ZeroDDP(ColoDDP):
- """ZeRO DDP for ColoTensor.
- Warning: Nested ZeroDDP is not supported now.
+class GeminiDDP(ModelWrapper):
+ """ZeRO DDP.
+ Warning: Nested GeminiDDP is not supported now.
It is designed to be used with ChunkManager and GeminiManager.
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
@@ -54,20 +53,54 @@ class ZeroDDP(ColoDDP):
mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
"""
- def __init__(self,
- module: torch.nn.Module,
- gemini_manager: GeminiManager,
- pin_memory: bool = False,
- force_outputs_fp32: bool = False,
- strict_ddp_mode: bool = False,
- scatter_after_inference: bool = True,
- mixed_precision: torch.dtype = torch.float16) -> None:
+ def __init__(
+ self,
+ module: torch.nn.Module,
+ chunk_config_dict: Optional[dict] = None,
+ chunk_init_device: torch.device = torch.device('cpu'),
+ placement_policy: str = "static",
+ shard_param_frac: float = 1.0, # only for static placement
+ offload_optim_frac: float = 0.0, # only for static placement
+ offload_param_frac: float = 0.0, # only for static placement
+ warmup_non_model_data_ratio: float = 0.8, # only for auto placement
+ steady_cuda_cap_ratio: float = 0.9, # only for auto placement
+ search_range_m: int = 32, # chunk search options
+ hidden_dim: Optional[int] = None, # chunk search options
+ min_chunk_size_m: float = 32, # chunk search options
+ pin_memory: bool = False,
+ force_outputs_fp32: bool = False,
+ strict_ddp_mode: bool = False,
+ scatter_after_inference: bool = True,
+ mixed_precision: torch.dtype = torch.float16,
+ process_group: Optional[ProcessGroup] = None,
+ memstats: Optional[MemStats] = None, # gemini memory stats
+ verbose: bool = False) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
- self.gemini_manager = gemini_manager
- self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
+ if chunk_config_dict is not None:
+ self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
+ else:
+ # some ugly hotfix for the compatibility with Lightning
+ if search_range_m is None:
+ search_range_m = 32
+ self.chunk_manager = init_chunk_manager(model=module,
+ init_device=chunk_init_device,
+ hidden_dim=hidden_dim,
+ search_range_m=search_range_m,
+ min_chunk_size_m=min_chunk_size_m,
+ strict_ddp_flag=strict_ddp_mode,
+ process_group=process_group,
+ verbose=verbose)
+ self.gemini_manager = GeminiManager(placement_policy,
+ self.chunk_manager,
+ memstats,
+ shard_param_frac=shard_param_frac,
+ offload_optim_frac=offload_optim_frac,
+ offload_param_frac=offload_param_frac,
+ warmup_non_model_data_ratio=warmup_non_model_data_ratio,
+ steady_cuda_cap_ratio=steady_cuda_cap_ratio)
self.force_outputs_fp32 = force_outputs_fp32
- self.param_op_hook = GeminiZeROHook(gemini_manager)
- self.fp32_params: List[ColoTensor] = list()
+ self.param_op_hook = GeminiZeROHook(self.gemini_manager)
+ self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -75,6 +108,7 @@ def __init__(self,
self.name2param: Dict[str, nn.Parameter] = dict()
self.scatter_after_inference = scatter_after_inference
self.mixed_precision = mixed_precision
+ self.dp_process_group = process_group or _get_default_group()
self._logger = get_dist_logger()
@@ -88,20 +122,67 @@ def __init__(self,
for p in module.parameters():
param_order.append(p)
- self._init_chunks(param_order=param_order,
- strict_ddp_mode=strict_ddp_mode,
- cpu_offload=self.gemini_manager.policy_name != 'cuda',
- pin_memory=pin_memory)
-
for name, param in module.named_parameters():
self.param2name[param] = name
for m_name, m_var in module.named_modules():
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
- super().__init__(module, process_group=ColoProcessGroup())
+
+ self._init_chunks(param_order=param_order,
+ strict_ddp_mode=strict_ddp_mode,
+ cpu_offload=self.gemini_manager.policy_name != 'cuda',
+ pin_memory=pin_memory)
+ super().__init__(module)
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
self._cast_buffers()
+ # register grad hook
+ for p in module.parameters():
+ if is_ddp_ignored(p):
+ continue
+ if p.requires_grad:
+ p.register_hook(partial(self.grad_handle, p))
+
+ def parameters(self, recurse: bool = True):
+ return self.module.parameters(recurse)
+
+ def named_parameters(self, prefix: str = '', recurse: bool = True):
+ return self.module.named_parameters(prefix, recurse)
+
+ def named_buffers(self, prefix: str = '', recurse: bool = True):
+ return self.module.named_buffers(prefix, recurse)
+
+ def named_children(self):
+ return self.module.named_children()
+
+ def named_modules(self,
+ memo: Optional[Set[torch.nn.Module]] = None,
+ prefix: str = '',
+ remove_duplicate: bool = True):
+ return self.module.named_modules(memo, prefix, remove_duplicate)
+
+ @staticmethod
+ def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
+ """Sets parameters to be ignored by DDP.
+ This method must be called before initializing ColoDDP.
+
+ Example:
+ >>> params_to_ignore = []
+ >>> for p in module.parameters():
+ >>> if should_ignore(p):
+ >>> params_to_ignore.append(p)
+ >>> ColoDDP.set_params_to_ignore(params_to_ignore)
+ >>> module = ColoDDP(module)
+
+ Args:
+ params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
+ """
+ for p in params_to_ignore:
+ p._ddp_to_ignore = True
+
+ def unwrap(self):
+ # as save/load state dict is overridden, only return self
+ return self
def _get_non_persistent_buffers_set(self,
module,
@@ -207,7 +288,7 @@ def _post_backward(self):
error_params.append(self.param2name[param])
error_str = "\n\t".join(error_params)
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
- "The most possible reason is that the model is not compatible with ZeroDDP.\n",
+ "The most possible reason is that the model is not compatible with GeminiDDP.\n",
f"{error_str}")
self._setup_grads_ptr()
self._logger.debug(
@@ -227,6 +308,7 @@ def backward_by_grad(self, tensor, grad):
self._post_backward()
def grad_handle(self, p, grad):
+ setattr(p, "_gemini_reduced", True)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
@@ -533,7 +615,7 @@ def load_fp32_parameter(chunk_slice, data):
for chunk_32 in chunk_list:
chunk_16 = chunk_32.paired_chunk
assert chunk_16 is not None
- chunk_16.optim_update()
+ chunk_16.payload.copy_(chunk_32.payload)
for name, buf in persistent_buffers.items():
if buf is not None:
@@ -557,17 +639,11 @@ def load_fp32_parameter(chunk_slice, data):
unexpected_keys.append(key)
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
- ddp_pg = ColoProcessGroup()
+ dp_world_size = dist.get_world_size(self.dp_process_group)
for p in param_order.generate():
self._preprocess_param(p)
assert type(p) is ColoParameter
- # gather sharded parameters in the strict ddp mode
- if strict_ddp_mode:
- if not p.is_replicate():
- p.set_dist_spec(ReplicaSpec())
- p.set_process_group(pg=ddp_pg)
-
# ignore the parameters with no gradient
if not p.requires_grad:
self.set_params_to_ignore([p])
@@ -578,38 +654,37 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi
continue
# create a fp32 parameter
- fp32_data = p.data.float()
- fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
+ fp32_p = p.data.float()
# create a fp16 parameter
p.data = p.data.to(self.mixed_precision)
# register the fp16 parameter and fp32 parameter in the chunk manager
- dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.register_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
+ process_group=self.dp_process_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.register_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
+ process_group=self.dp_process_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp16_params.append(p)
self.fp32_params.append(fp32_p)
- self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
+ self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)
+ # move master weights to their corresponding devices and set up paired chunks
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
-
- # keep gathered chunks are in CUDA
- if chunk_16.keep_gathered:
- self.grads_device[p] = get_current_device()
+ if chunk_32.device_type != self.grads_device[p].type:
+ self.chunk_manager.move_chunk(chunk_32, self.grads_device[p])
def _cast_buffers(self):
for buffer in self.module.buffers():
@@ -657,7 +732,7 @@ def state_dict_shard(self,
Yields:
Iterator[OrderedDict]: A generator of state dict shard
"""
- sharder = _StateDictSharder(max_shard_size)
+ sharder = StateDictSharder(max_shard_size)
# get the mapping between copies and fp16 parameters
fp16_to_fp32 = dict()
@@ -679,7 +754,7 @@ def state_dict_shard(self,
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
gathered_param = gathered_param_buffer.pop(fp32_param)
- block, block_size = sharder.append(prefix + name, gathered_param)
+ block, block_size = sharder.append_param(prefix + name, gathered_param)
if block is not None:
yield block, block_size
@@ -690,7 +765,7 @@ def state_dict_shard(self,
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
- block, block_size = sharder.append(prefix + name, buffer)
+ block, block_size = sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
# save extra states
@@ -698,96 +773,8 @@ def state_dict_shard(self,
if getattr(self.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
extra_state = self.get_extra_state()
- block, block_size = sharder.append(extra_state_key, extra_state)
+ block, block_size = sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
yield sharder.current_block, sharder.current_block_size
-
-
-class _StateDictSharder:
-
- def __init__(self, max_shard_size: int) -> None:
- self.max_shard_size = max_shard_size
- self.current_block = OrderedDict()
- self.current_block_size = 0
-
- def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
- tensor_size = calculate_tensor_size(tensor)
- ret_block = None
- ret_block_size = 0
-
- # before we return the current block and create a new block,
- # we need to ensure that the current block is not empty
- if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
- ret_block = self.current_block
- ret_block_size = self.current_block_size
- self.current_block = OrderedDict()
- self.current_block_size = 0
- self.current_block[name] = tensor
- self.current_block_size += tensor_size
- return ret_block, ret_block_size
-
-
-class GeminiDDP(ZeroDDP):
-
- def __init__(self,
- module: torch.nn.Module,
- device: torch.device,
- placement_policy: str = "cpu",
- pin_memory: bool = False,
- force_outputs_fp32: bool = False,
- strict_ddp_mode: bool = False,
- scatter_after_inference: bool = True,
- search_range_m: int = 32,
- hidden_dim: Optional[int] = None,
- min_chunk_size_m: float = 32,
- memstats: Optional[MemStats] = None,
- mixed_precision: torch.dtype = torch.float16,
- verbose: bool = False) -> None:
- """
- A torch.Module wrapper using ZeRO-DP and Gemini.
- ZeRO is for parallel. Gemini is for memory management.
- WARNING: The class will modify the module inline!
-
- Example:
- model is initialized under the context of ColoInitContext
- >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
- >>> logits = model(x)
- >>> loss = criterion(logits, labels)
- >>> model.backward(loss)
-
- Args:
- module (torch.nn.Module): the model to be wrapped.
- device (torch.device): device to place the model.
- placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
- pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
- force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
- search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
- hidden_dim (int, optional): the hidden dimension of DNN.
- Users can provide this argument to speed up searching.
- If users do not know this argument before training, it is ok. We will use a default value 1024.
- min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
- If the aggregate size of parameters is still smaller than the minimum chunk size,
- all parameters will be compacted into one small chunk.
- memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
- """
- # some ugly hotfix for the compatibility with Lightning
- if search_range_m is None:
- search_range_m = 32
-
- chunk_manager = init_chunk_manager(model=module,
- init_device=device,
- hidden_dim=hidden_dim,
- search_range_m=search_range_m,
- min_chunk_size_m=min_chunk_size_m,
- strict_ddp_flag=strict_ddp_mode,
- verbose=verbose)
- gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
- super().__init__(module,
- gemini_manager,
- pin_memory,
- force_outputs_fp32,
- strict_ddp_mode,
- scatter_after_inference,
- mixed_precision=mixed_precision)
diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py
index c38e6eff840d..b8e4717908f7 100644
--- a/colossalai/zero/gemini/gemini_mgr.py
+++ b/colossalai/zero/gemini/gemini_mgr.py
@@ -1,6 +1,6 @@
import functools
from time import time
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
import torch
@@ -26,7 +26,11 @@ class GeminiManager:
memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
"""
- def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
+ def __init__(self,
+ placement_policy: str,
+ chunk_manager: ChunkManager,
+ memstats: Optional[MemStats] = None,
+ **placement_kwargs) -> None:
assert placement_policy in PlacementPolicyFactory.get_policy_names()
self.policy_name = placement_policy
@@ -37,7 +41,7 @@ def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats:
self._memstats = memstats
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
self._memstats) if policy_cls.need_mem_stats else None
- self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
+ self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1
@@ -133,10 +137,6 @@ def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
if self._warmup and self._placement_policy.need_mem_stats:
self._compute_list.append(chunks)
- @property
- def default_device(self):
- return self._placement_policy.get_default_device()
-
def sample_overall_data(self):
if self._mem_stats_collector:
self._mem_stats_collector.sample_overall_data()
@@ -159,6 +159,6 @@ def cuda_margin_mem(self) -> Optional[float]:
def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats
- @staticmethod
- def get_default_device(policy_name: str) -> torch.device:
- return PlacementPolicyFactory.get_default_device(policy_name)
+ def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
+ torch.device]) -> None:
+ self._placement_policy.setup_grads_device(params, grads_device_map)
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 7d0db6b1fa23..0c593deff225 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -1,9 +1,8 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
-import gc
import math
import warnings
-from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple
+from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
import torch
import torch.distributed as dist
@@ -11,16 +10,17 @@
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
-from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder
+from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
+from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
-from .gemini_ddp import ZeroDDP
+from .gemini_ddp import GeminiDDP
-__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
+__all__ = ['GeminiOptimizer', 'GeminiAdamOptimizer']
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
@@ -28,7 +28,7 @@
class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(self,
- module: ZeroDDP,
+ module: GeminiDDP,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
@@ -47,11 +47,11 @@ def pre_zero_grad(self) -> None:
self.module.overflow_counter = 0
-class ZeroOptimizer(ColossalaiOptimizer):
- """A wrapper for optimizer. ``ZeroDDP`` and ``ZeroOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3).
+class GeminiOptimizer(OptimizerWrapper):
+ """A wrapper for optimizer. ``GeminiDDP`` and ``GeminiOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3).
Note:
- You must use ``ZeroDDP`` with ``ZeroOptimizer``.
+ You must use ``GeminiDDP`` with ``GeminiOptimizer``.
Note:
Make sure you set ``placement_policy`` of ``GeminiManager`` to `"auto"`,
@@ -59,7 +59,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
Args:
optim (Optimizer): An Optimizer instance.
- module (ZeroDDP): A ``ZeroDDP`` instance.
+ module (GeminiDDP): A ``GeminiDDP`` instance.
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
which will be used when using hybrid CPU optimizer.
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
@@ -71,15 +71,15 @@ class ZeroOptimizer(ColossalaiOptimizer):
growth_interval (float, optional): Growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): Hysteresis used by DynamicGradScaler. Defaults to 2.
max_scale (int, optional): Max_scale used by DynamicGradScaler. Defaults to 2**32.
- clipping_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
+ max_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0)
- is supported in ZeroOptimizer. Defaults to 2.0.
+ is supported in GeminiOptimizer. Defaults to 2.0.
verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False.
"""
def __init__(self,
optim: Optimizer,
- module: ZeroDDP,
+ module: GeminiDDP,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
@@ -88,12 +88,12 @@ def __init__(self,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
- clipping_norm: float = 0.0,
+ max_norm: float = 0.0,
norm_type: float = 2.0,
verbose: bool = False,
**defaults: Any):
super().__init__(optim)
- assert isinstance(module, ZeroDDP)
+ assert isinstance(module, GeminiDDP)
assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \
f"{_AVAIL_OPTIM_LIST}"
self.module = module
@@ -102,8 +102,8 @@ def __init__(self,
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set()
- self.clipping_flag = clipping_norm > 0.0
- self.max_norm = clipping_norm
+ self.clipping_flag = max_norm > 0.0
+ self.max_norm = max_norm
self.verbose = verbose
self.param_groups_backup = list()
@@ -112,7 +112,7 @@ def __init__(self,
self.id_to_fake_params: Dict[int, Parameter] = dict()
if self.clipping_flag:
- assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
+ assert norm_type == 2.0, "GeminiOptimizer only supports L2 norm now"
ddp_param_list = []
for name, param in module.named_parameters():
@@ -468,11 +468,6 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
shard_size)
- # Clean gathered states
- for state_shard in gathered_state_shards:
- del state_shard[0]
- gc.collect()
-
# Reshape tensors
if is_collector:
for state_name, state_tensor in collected_states.items():
@@ -697,52 +692,31 @@ def state_shard(self,
Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
"""
- current_block = {}
- current_block_size = 0
-
+ sharder = StateDictSharder(max_shard_size)
for param_id in self.id_to_real_params.keys():
dist.barrier()
state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
- ret_block = None
- ret_block_size = 0
-
- # A state might contain more than one tensors.
- # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
- state_size = 0
- isDTensor = False
- for state_tensor in state.values():
-
- # When state_tensor is not of Tensor class,
- # e.g., a SGD optimizer with momentum set to 0 can have None as state
- # The calculation of tensor size should be skipped to avoid error.
- if not isinstance(state_tensor, torch.Tensor):
- continue
-
- # If the states are stored as DTensors, mark isDTensor as true.
- if is_distributed_tensor(state_tensor):
- isDTensor = True
- state_size += calculate_tensor_size(state_tensor)
-
- if not isDTensor:
-
- if current_block_size + state_size > max_shard_size and current_block_size > 0:
- ret_block = current_block
- ret_block_size = current_block_size
- current_block = {}
- current_block_size = 0
+ block, block_size = sharder.append_optim_state(param_id, state)
+ if block is not None:
+ yield block, block_size
- current_block[param_id] = state
- current_block_size += state_size
+ yield sharder.current_block, sharder.current_block_size
- if ret_block != None:
- yield ret_block, ret_block_size
+ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
+ raise NotImplementedError('Gemini does not support clip_grad_by_value')
- yield current_block, current_block_size
+ def clip_grad_by_norm(self,
+ max_norm: Union[float, int],
+ norm_type: Union[float, int] = 2,
+ error_if_nonfinite: bool = False,
+ *args,
+ **kwargs) -> torch.Tensor:
+ warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
-class GeminiAdamOptimizer(ZeroOptimizer):
+class GeminiAdamOptimizer(GeminiOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
optimizer = HybridAdam(model.parameters(), **defaults)
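
A brief usage sketch of the renamed wrapper under the new argument name; `gemini_model` is assumed to be a module already wrapped in `GeminiDDP`, and the hyper-parameters are illustrative:

```python
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer

# `gemini_model` is assumed to be a GeminiDDP-wrapped module.
optim = HybridAdam(gemini_model.parameters(), lr=1e-3)
# `max_norm` replaces the old `clipping_norm` argument; 0.0 keeps clipping disabled.
optim = GeminiOptimizer(optim, gemini_model, max_norm=1.0)
```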
diff --git a/colossalai/zero/gemini/memory_tracer/memory_stats.py b/colossalai/zero/gemini/memory_tracer/memory_stats.py
index 41d7e5754e96..02de6ecb97a9 100644
--- a/colossalai/zero/gemini/memory_tracer/memory_stats.py
+++ b/colossalai/zero/gemini/memory_tracer/memory_stats.py
@@ -9,7 +9,7 @@ class MemStats(object):
def __init__(self) -> None:
"""
- Store the non model data statistics used for Gemini and ZeroOptimizer.
+ Store the non model data statistics used for Gemini and GeminiOptimizer.
"""
# (preop_step, List[param])
self._step_param_dict = dict()
diff --git a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
index 0c9eac8b63e3..e5466965cc48 100644
--- a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
+++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py
@@ -1,7 +1,7 @@
import torch.nn
-from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
+from colossalai.utils import _cast_float
from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import (
GradMemStats,
GradMemTracerHook,
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index 84a868872f88..cd775da5e11f 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -1,4 +1,5 @@
import functools
+import warnings
from abc import ABC, abstractmethod
from time import time
from typing import Dict, List, Optional, Tuple, Type
@@ -7,6 +8,7 @@
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
+from colossalai.zero.gemini.chunk import Chunk
from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector
@@ -17,7 +19,8 @@ class PlacementPolicy(ABC):
def __init__(self,
chunk_manager: ChunkManager,
- mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
+ mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+ **kwargs) -> None:
self.chunk_manager = chunk_manager
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
@@ -25,57 +28,87 @@ def __init__(self,
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
raise NotImplementedError
- @staticmethod
- def get_default_device() -> torch.device:
- return torch.device('cpu')
+ @abstractmethod
+ def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
+ torch.device]) -> None:
+ raise NotImplementedError
-class CPUPlacementPolicy(PlacementPolicy):
+class StaticPlacementPolicy(PlacementPolicy):
def __init__(self,
chunk_manager: ChunkManager,
- mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
+ mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+ shard_param_frac: float = 1.0,
+ offload_optim_frac: float = 0.0,
+ offload_param_frac: float = 0.0,
+ **kwargs) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+ if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
+ warnings.warn('offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0')
+ offload_param_frac = 0.0
+ self.shard_param_frac = shard_param_frac
+ self.offload_optim_frac = offload_optim_frac
+ self.offload_param_frac = offload_param_frac
+ # these should be initialized in setup_grads_device
+ self.keep_gathered_chunk_mem = 0.0
+ self.keep_cuda_chunk_mem = 0.0
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
- volume = 0
- start = time()
+ can_shard_chunk_mem = sum(chunk.chunk_mem for chunk in can_evict_chunks)
+ can_offload_chunk_mem = can_shard_chunk_mem
for chunk in can_evict_chunks:
+ if can_shard_chunk_mem <= self.keep_gathered_chunk_mem:
+ break
self.chunk_manager.release_chunk(chunk)
+ # real saved mem is chunk_mem - shard_mem, for simplicity we use chunk_mem
+ can_shard_chunk_mem -= chunk.chunk_mem
+ for chunk in can_evict_chunks:
+ if can_offload_chunk_mem <= self.keep_cuda_chunk_mem:
+ break
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
- volume += chunk.chunk_mem
- return volume, time() - start
-
-
-class CUDAPlacementPolicy(PlacementPolicy):
-
- def __init__(self,
- chunk_manager: ChunkManager,
- mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
- assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
- super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
-
- def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
- return 0, 0
-
- @staticmethod
- def get_default_device() -> torch.device:
- return get_current_device()
+ # real saved mem is shard_mem, for simplicity we use chunk_mem
+ can_offload_chunk_mem -= chunk.chunk_mem
+ return 0, 0.0
+
+ def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
+ torch.device]) -> None:
+ total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params)
+
+ offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac
+ offloaded_optim_chunk_mem = 0
+ chunks = set(self.chunk_manager.get_chunk(p) for p in params)
+ for chunk in chunks:
+ params = chunk.get_tensors()
+ # init offload optim settings
+ # chunks with keep_gathered stay on CUDA
+ if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
+ device = get_current_device()
+ else:
+ device = torch.device('cpu')
+ # real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
+ offloaded_optim_chunk_mem += chunk.chunk_mem
+ for p in params:
+ grads_device_map[p] = device
+ self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
+ self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
class AutoPlacementPolicy(PlacementPolicy):
-
need_mem_stats: bool = True
- # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
- # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
- # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
- _warmup_non_model_data_ratio: float = 0.8
- _steady_cuda_cap_ratio: float = 0.9
def __init__(self,
chunk_manager: ChunkManager,
- mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
+ mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+ warmup_non_model_data_ratio: float = 0.8,
+ steady_cuda_cap_ratio: float = 0.9,
+ **kwargs) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+ # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
+ # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
+ # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
+ self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
+ self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
def evict_tensors(self,
can_evict_chunks: List[Chunk],
@@ -105,11 +138,11 @@ def evict_tensors(self,
used_cuda_model_data = self.chunk_manager.total_mem['cuda']
if warmup:
# We designate a part of CUDA memory for model data in warmup iterations.
- max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio
+ max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
else:
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
- cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio
+ cuda_capacity *= self._steady_cuda_cap_ratio
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
freed_cuda_model_data = 0
@@ -145,89 +178,22 @@ def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_li
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx]
- @staticmethod
- def set_warmup_non_model_data_ratio(ratio: float) -> None:
- ratio = float(ratio)
- assert 0.0 < ratio < 1.0
- AutoPlacementPolicy._warmup_non_model_data_ratio = ratio
-
- @staticmethod
- def set_steady_cuda_cap_ratio(ratio: float) -> None:
- ratio = float(ratio)
- assert 0.0 < ratio < 1.0
- AutoPlacementPolicy._steady_cuda_cap_ratio = ratio
-
-
-class ConstPlacementPolicy(PlacementPolicy):
-
- need_mem_stats: bool = False
- _accessed_memory_boundary = 512 * 1024**2
-
- def __init__(self,
- chunk_manager: ChunkManager,
- mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
- super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
-
- def evict_tensors(self,
- can_evict_chunks: List[Chunk],
- cuda_demand: int = 0,
- warmup: bool = True,
- compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
- compute_idx: int = 0,
- **kwargs) -> Tuple[int, float]:
- """
- See the docstrings in the class `AutoPlacementPolicy`.
- """
- start = time()
- used_accessed_memory = self.chunk_manager.accessed_mem
- avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory
- freed_accessed_memory = 0
-
- if avail_accessed_memory < cuda_demand:
- to_free_memory = cuda_demand - avail_accessed_memory
- to_free_chunks = can_evict_chunks
-
- if not warmup:
- # sort all chunks
- to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
-
- for chunk in to_free_chunks:
- if freed_accessed_memory >= to_free_memory:
- break
-
- self.chunk_manager.release_chunk(chunk)
- self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
- freed_accessed_memory += chunk.chunk_mem
-
- if freed_accessed_memory < to_free_memory:
- raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
- f"Need {to_free_memory}, freed {freed_accessed_memory}")
- return freed_accessed_memory, time() - start
-
- @staticmethod
- @functools.lru_cache(maxsize=None)
- def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
- next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
- for i in range(len(compute_list) - 1, compute_idx, -1):
- for chunk in compute_list[i]:
- if chunk in next_compute_idx:
- next_compute_idx[chunk] = i
- next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
- return [t for (t, idx) in next_compute_idx]
-
- @staticmethod
- def set_const_memory_boundary(cuda_memory_mb: int) -> None:
- boundary = int(cuda_memory_mb * 1024**2)
- assert boundary > 0
- ConstPlacementPolicy._accessed_memory_boundary = boundary
+ def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
+ torch.device]) -> None:
+ for p in params:
+ chunk = self.chunk_manager.get_chunk(p)
+ # init offload optim settings
+ # chunks with keep_gathered stay on CUDA
+ if chunk.keep_gathered:
+ grads_device_map[p] = get_current_device()
+ else:
+ grads_device_map[p] = torch.device('cpu')
class PlacementPolicyFactory:
policies: Dict[str, Type[PlacementPolicy]] = {
- 'cpu': CPUPlacementPolicy,
- 'cuda': CUDAPlacementPolicy,
'auto': AutoPlacementPolicy,
- 'const': ConstPlacementPolicy
+ 'static': StaticPlacementPolicy,
}
@staticmethod
@@ -239,8 +205,3 @@ def create(policy_name: str) -> Type[PlacementPolicy]:
@staticmethod
def get_policy_names():
return tuple(PlacementPolicyFactory.policies.keys())
-
- @staticmethod
- def get_default_device(policy_name: str) -> torch.device:
- policy_cls = PlacementPolicyFactory.create(policy_name)
- return policy_cls.get_default_device()
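
Since `GeminiManager` now forwards `**placement_kwargs` to the selected policy, the new `'static'` policy can be configured roughly as sketched below. The fraction values are illustrative, `model` is assumed to be an already-built module, and `init_chunk_manager` is assumed importable from the chunk subpackage as used in `gemini_ddp.py`:

```python
import torch

from colossalai.zero.gemini.chunk import init_chunk_manager
from colossalai.zero.gemini.gemini_mgr import GeminiManager

# `model` is assumed to be an already-constructed torch.nn.Module.
chunk_manager = init_chunk_manager(model=model, init_device=torch.device('cuda'))
gemini_manager = GeminiManager(
    'static', chunk_manager,
    shard_param_frac=1.0,    # keep every parameter chunk sharded (ZeRO-3 style)
    offload_optim_frac=0.5,  # place roughly half of the optimizer states on CPU
    offload_param_frac=0.0,  # keep parameter chunks on CUDA
)
```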
diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py
index 6f4a253b504b..0d92d32e5603 100644
--- a/colossalai/zero/gemini/utils.py
+++ b/colossalai/zero/gemini/utils.py
@@ -64,13 +64,13 @@ def get_static_torch_model(zero_ddp_model,
device=torch.device("cpu"),
dtype=torch.float32,
only_rank_0=True) -> torch.nn.Module:
- """Get a static torch.nn.Module model from the given ZeroDDP module.
- You should notice that the original ZeroDDP model is not modified.
+ """Get a static torch.nn.Module model from the given GeminiDDP module.
+ Note that the original GeminiDDP model is not modified.
Thus, you can use the original model in further training.
But you should not use the returned torch model to train, this can cause unexpected errors.
Args:
- zero_ddp_model (ZeroDDP): a zero ddp model
+ zero_ddp_model (GeminiDDP): a zero ddp model
device (torch.device): the device of the final torch model
dtype (torch.dtype): the dtype of the final torch model
only_rank_0 (bool): if True, only rank0 has the converted torch model
@@ -78,8 +78,8 @@ def get_static_torch_model(zero_ddp_model,
Returns:
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
"""
- from colossalai.zero.gemini.gemini_ddp import ZeroDDP
- assert isinstance(zero_ddp_model, ZeroDDP)
+ from colossalai.zero.gemini.gemini_ddp import GeminiDDP
+ assert isinstance(zero_ddp_model, GeminiDDP)
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
colo_model = zero_ddp_model.module
diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py b/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py
index 8f8fec64924e..d68a9dc6458f 100644
--- a/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py
+++ b/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py
@@ -1,6 +1,6 @@
import torch
-from colossalai.registry import OPHOOKS
+from colossalai.legacy.registry import OPHOOKS
from . import BaseOpHook
diff --git a/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py b/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py
index a2a62fb9788a..6b76a2116a49 100644
--- a/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py
+++ b/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py
@@ -1,6 +1,6 @@
import torch
-from colossalai.registry import OPHOOKS
+from colossalai.legacy.registry import OPHOOKS
from . import BaseOpHook
diff --git a/colossalai/zero/legacy/sharded_model/zero_hook.py b/colossalai/zero/legacy/sharded_model/zero_hook.py
index 50f4bdfc775d..1815bee3a9e0 100644
--- a/colossalai/zero/legacy/sharded_model/zero_hook.py
+++ b/colossalai/zero/legacy/sharded_model/zero_hook.py
@@ -3,8 +3,8 @@
import torch
import torch.distributed as dist
+from colossalai.legacy.registry import OPHOOKS
from colossalai.logging import get_dist_logger
-from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py
index 98f1b78d0049..0ab10e25d407 100644
--- a/colossalai/zero/low_level/bookkeeping/bucket_store.py
+++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -13,15 +13,20 @@ class BucketStore(BaseStore):
def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
- # init and reset
+ # init
self.current_group_id = 0
+ self._num_elements_in_bucket = 0
# mapping gradient slices and parameters
self.grad_to_param_mapping = dict()
+ self._grad_in_bucket = dict()
self._param_list = []
self._padding_size = []
+ for rank in range(self._world_size):
+ self._grad_in_bucket[rank] = []
- self.reset()
+ # offset_list records number of tensors in the bucket before each reduction
+ self.offset_list = [0]
def num_elements_in_bucket(self) -> int:
"""Return the total number of elements in bucket
@@ -32,6 +37,12 @@ def num_elements_in_bucket(self) -> int:
return self._num_elements_in_bucket
+ def reset_num_elements_in_bucket(self):
+ """Set the number of elements in bucket to zero.
+ """
+
+ self._num_elements_in_bucket = 0
+
def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
"""Add a param to bucket and record the padding size of a param for gradient padding
@@ -46,28 +57,32 @@ def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
self._num_elements_in_bucket += (param.numel() + padding_size)
self.current_group_id = group_id
+ # number of tensors in current bucket
+ self.offset_list[-1] += 1
+
def build_grad_in_bucket(self):
"""Orgnize parameters' gradient(padding and split), follows the paramters' splitting method
Data structure of self._grad_in_bucket:
{
rank0: [grad0_rank0, grad1_rank0, ...]
- rank1: [grad1_rank1, grad1_rank1, ...]
+ rank1: [grad0_rank1, grad1_rank1, ...]
}
"""
-
for param, padding_size in zip(self._param_list, self._padding_size):
- with torch.no_grad():
- grad = param.grad.detach().flatten()
- if padding_size > 0:
- grad = torch.nn.functional.pad(grad, [0, padding_size])
- grad_list = grad.split(grad.numel() // self._world_size)
- for rank in range(self._world_size):
- grad_current_rank = grad_list[rank].detach()
- self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
- self._grad_in_bucket[rank].append(grad_current_rank)
+ grad = param.grad.clone().detach().flatten()
+ if padding_size > 0:
+ with torch.no_grad():
+ grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size])
+ grad_list = grad.split(grad.numel() // self._world_size)
+ for rank in range(self._world_size):
+ grad_current_rank = grad_list[rank].clone().detach()
+ self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
+ self._grad_in_bucket[rank].append(grad_current_rank)
param.grad = None
+ self.offset_list.append(0)
+
def get_grad(self) -> Dict:
"""Return the dictionary of gradients slices, of which the keys are ranks
@@ -104,10 +119,12 @@ def get_param_id_of_grad(self, grad: Tensor) -> int:
return self.grad_to_param_mapping[id(grad)]
def reset(self):
- self.grad_to_param_mapping = dict()
- self._num_elements_in_bucket = 0
- self._param_list = []
- self._padding_size = []
- self._grad_in_bucket = dict()
+ """Reset the bucket storage after reduction, only release the tensors have been reduced
+ """
+ cur_offset = self.offset_list.pop(0)
+ self._param_list = self._param_list[cur_offset:]
+ self._padding_size = self._padding_size[cur_offset:]
+ for _ in range(cur_offset):
+ del self.grad_to_param_mapping[next(iter(self.grad_to_param_mapping))]
for rank in range(self._world_size):
- self._grad_in_bucket[rank] = []
+ self._grad_in_bucket[rank] = self._grad_in_bucket[rank][cur_offset:]
diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py
index 0b86ec8ca89e..2890b329a642 100644
--- a/colossalai/zero/low_level/bookkeeping/gradient_store.py
+++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -57,8 +57,8 @@ def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: in
self._grads_of_params[group_id][param_id].append(grad)
def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
- """For old gradient accumulation, not in use now.
- Add a gradient slice on an existing slice of the parameter's gradient
+ """Add a gradient slice on an existing slice of the parameter's gradient
+ Used when no_sync is not activated.
Args:
grad (Tensor): The split gradient to append to list
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 2b3f50ed4fd4..0bdd6a3e2370 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -6,6 +6,7 @@
import torch
import torch.distributed as dist
+import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
@@ -80,9 +81,6 @@ def __init__(
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None):
- # TODO:
- # 1. state_dict for checkpoint IO
-
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
@@ -242,10 +240,19 @@ def _attach_reduction_hook(self):
def _run_reduction(self):
if self._bucket_store.num_elements_in_bucket() > 0:
self._bucket_store.build_grad_in_bucket()
+
flat_grads = self._bucket_store.get_flatten_grad()
flat_grads /= self._world_size
+
+ # ready to add other tensors to bucket
+ self._bucket_store.reset_num_elements_in_bucket()
+
if self._overlap_communication:
stream = self._comm_stream
+ # record the comm stream so flat_grads' memory is not reused by the default stream too early
+ flat_grads.record_stream(stream)
+ # wait for ops in the default stream to finish before reducing
+ stream.wait_stream(torch.cuda.current_stream())
else:
stream = torch.cuda.current_stream()
@@ -268,7 +275,11 @@ def _run_reduction(self):
sync_tensor(flat_grads_per_rank[rank], grad_list)
for grad in grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
- self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+ if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id,
+ param_id)) < self._world_size:
+ self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+ else:
+ self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
@@ -282,7 +293,10 @@ def _run_reduction(self):
sync_tensor(recieved_grad, grad_in_bucket_current_rank)
for grad in grad_in_bucket_current_rank:
param_id = self._bucket_store.get_param_id_of_grad(grad)
- self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+ if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
+ self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+ else:
+ self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
self._bucket_store.reset()
@@ -294,7 +308,7 @@ def _add_to_bucket(self, param, group_id):
# or got a grad of param from another group
# after reduction, the bucket will be empty
if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \
- group_id != self._bucket_store.current_group_id:
+ group_id != self._bucket_store.current_group_id:
self._run_reduction()
padding_size = self._param_store.get_param_padding_size(param)
@@ -306,7 +320,7 @@ def _add_to_bucket(self, param, group_id):
def backward(self, loss, retain_graph=False):
assert not(self._partition_grads and not self.require_grad_sync), \
- "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
+ "ZeRO2(partition_grads) and no_sync are not compatible"
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)
@@ -324,6 +338,24 @@ def backward(self, loss, retain_graph=False):
self.zero_grad()
+ def backward_by_grad(self, tensor, grad):
+ assert not(self._partition_grads and not self.require_grad_sync), \
+ "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
+
+ if self.mixed_precision_mixin is not None:
+ grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
+ torch.autograd.backward(tensor, grad)
+
+ if not self.require_grad_sync:
+ return
+ self._reduce_grad(self._partition_grads)
+
+ # clear reduced grads
+ if self._overlap_communication:
+ torch.cuda.synchronize()
+
+ self.zero_grad()
+
def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradient
@@ -349,7 +381,6 @@ def zero_grad(self, set_to_none=True):
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
-
if not self.require_grad_sync:
return
@@ -512,9 +543,12 @@ def state_dict(self) -> Dict:
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
working_param = self._param_store.master_to_working_param[id(param)]
- gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
- dist.all_gather(gather_tensor, v, group=self.dp_pg)
- param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+ gather_tensor = [
+ torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)
+ ]
+ dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
+ param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(
+ working_param).cpu()
zero_state[param][k] = param_state
states_dict = self._pack_state(zero_state)
@@ -537,10 +571,9 @@ def load_state_dict(self, state_dict: Dict):
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self._world_size)
- zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach()
+ zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()
self.optim.load_state_dict(zero_state_dict)
- zero_state_dict = dict()
def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
@@ -569,9 +602,10 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != 'step':
- state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
- dist.all_gather(state_tensor, v, group=self.dp_pg)
- state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+ state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)]
+ dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
+ state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(
+ working_param).cpu()
current_block_size += state_tensor.numel()
current_block[k] = state_tensor
@@ -584,3 +618,19 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i
ret_block_size += current_block_size
yield ret_block, ret_block_size
+
+ def update_master_params(self, model: nn.Module) -> None:
+ """Update master params from working params
+
+ Args:
+ model (nn.Module): The model to update master params
+ """
+ for p in model.parameters():
+ p_id = id(p)
+ if p_id in self._param_store.working_to_master_param:
+ master_param = self._param_store.working_to_master_param[p_id]
+ padding_size = self._param_store.get_param_padding_size(p)
+ working_param = p.data.view(-1)
+ if padding_size > 0:
+ working_param = torch.nn.functional.pad(working_param, [0, padding_size])
+ master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
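
A short sketch of the intended use of `update_master_params`: after the working weights are overwritten, for example by loading a checkpoint into the wrapped model, this rank's fp32 master shards need to be refreshed. The variable names and checkpoint path below are illustrative:

```python
import torch

# `model` is the module trained with this LowLevelZeroOptimizer (`optimizer`).
state = torch.load('checkpoint.pt', map_location='cpu')  # illustrative path
model.load_state_dict(state)
# re-split the refreshed working parameters into this rank's master shards
optimizer.update_master_params(model)
```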
diff --git a/colossalai/zero/low_level/readme.md b/colossalai/zero/low_level/readme.md
index aa92159d8022..b960a436219d 100644
--- a/colossalai/zero/low_level/readme.md
+++ b/colossalai/zero/low_level/readme.md
@@ -1,5 +1,41 @@
# Low Level ZeRO
>Low Level ZeRO == ZeRO-DP stage 1 and 2, we would denote it as ZeRO.
+## Examples of ZeRO and gradient accumulation
+
+The code below shows only a typical gradient accumulation workflow; many details, such as loss handling, are omitted.
+
+```python
+# examples of ZeRO1 with gradient accumulation
+...
+outputs = model(input)
+loss = SomeLoss(outputs)
+if (idx + 1) % ACCUMULATE_STEP != 0:
+    with booster.no_sync(model, optimizer):
+        # under this context, gradients are not synchronized during backward,
+        # leaving each rank with its own local gradients.
+        # This saves communication time on the backward pass.
+        booster.backward(loss, optimizer)
+        continue
+else:
+    # sync all the accumulated gradients
+    booster.backward(loss, optimizer)
+    optimizer.step()
+    ...
+```
+
+```python
+# example of ZeRO2 with gradient accumulation
+
+...
+outputs = model(input)
+loss = SomeLoss(outputs)
+# ZeRO2 partitions the gradients, so accumulation can NOT skip the gradient sync (no_sync is unsupported).
+booster.backward(loss, optimizer)
+if (idx + 1) % ACCUMULATE_STEP == 0:
+ optimizer.step()
+...
+```
+
## Design:
### Notion
@@ -25,11 +61,11 @@ The data structure looks like this:
```
After that, the gradients would be flattened by rank, and the data structure looks like this:
```
-# g-0 means flatten([g-00, g-10])
+# g-X0 means flatten([g-00, g-10])
{
-0: [g-0],
-1: [g-1],
-2: [g-2]
+0: [g-X0],
+1: [g-X1],
+2: [g-X2]
}
```
For zero1, we iterate the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`.
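
A rough sketch of that last step, assuming `grads_per_rank[r]` holds the flattened slice `g-Xr` from the diagram above and `torch.distributed` is already initialized (`zero_stage` is an illustrative flag, not an attribute of the store):

```python
import torch
import torch.distributed as dist

world_size = dist.get_world_size()
rank = dist.get_rank()

if zero_stage == 1:
    # ZeRO-1: every rank keeps the full reduced gradient,
    # so each flattened slice is all-reduced.
    for r in range(world_size):
        dist.all_reduce(grads_per_rank[r])
else:
    # ZeRO-2: each rank only keeps the slice it owns,
    # so a single reduce-scatter is enough.
    inputs = [grads_per_rank[r] for r in range(world_size)]
    own_slice = torch.empty_like(inputs[rank])
    dist.reduce_scatter(own_slice, inputs)
```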
diff --git a/colossalai/zero/wrapper.py b/colossalai/zero/wrapper.py
index 3e48f49fa305..90325fe0a704 100644
--- a/colossalai/zero/wrapper.py
+++ b/colossalai/zero/wrapper.py
@@ -109,6 +109,6 @@ def zero_optim_wrapper(model: nn.Module,
config_dict['clip_grad_norm'] = max_norm
return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose)
else:
- from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer
+ from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer
config_dict['clipping_norm'] = max_norm
- return ZeroOptimizer(optimizer, model, **config_dict, verbose=verbose)
+ return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index a1e136ee58a5..26d3fab1b6d7 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -18,7 +18,7 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/*
# install torch
-RUN conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
+RUN conda install -y pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
# install ninja
RUN apt-get update && \
@@ -43,8 +43,9 @@ RUN git clone -b ${VERSION} https://github.com/hpcaitech/ColossalAI.git \
RUN pip install --no-cache-dir titans
# install tensornvme
-RUN conda install cmake && \
+RUN conda install -y cmake && \
git clone https://github.com/hpcaitech/TensorNVMe.git && \
cd TensorNVMe && \
+ apt update -y && apt install -y libaio-dev && \
pip install -r requirements.txt && \
pip install -v --no-cache-dir .
diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md
index 945ca4080413..dda4f86a29a0 100644
--- a/docs/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -24,6 +24,7 @@
## 新闻
+* [2023/09] [70 Billion Parameter LLaMA2 Model Training Accelerated by 195%](https://www.hpc-ai.tech/blog/70b-llama2-training)
* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)
* [2023/07] [65B Model Pretraining Accelerated by 38%, Best Practices for Building LLaMA-Like Base Models Open-Source](https://www.hpc-ai.tech/blog/large-model-pretraining)
* [2023/03] [ColossalChat: An Open-Source Solution for Cloning ChatGPT With a Complete RLHF Pipeline](https://medium.com/@yangyou_berkeley/colossalchat-an-open-source-solution-for-cloning-chatgpt-with-a-complete-rlhf-pipeline-5edf08fb538b)
@@ -49,7 +50,7 @@