
Commit

[Pipeline inference] Combine kvcache with pipeline inference (hpcaitech#4938)

* merge kvcache with pipeline inference and refactor the code structure

* support ppsize > 2

* refactor pipeline code

* do pre-commit

* modify benchmark

* fix bench mark

* polish code

* add docstring and update readme

* refactor the code

* fix some logic bug of ppinfer

* polish readme

* fix typo

* skip infer test
FoolPlayer authored and flybird11111 committed Nov 9, 2023
1 parent fe79560 commit a610046
Showing 19 changed files with 881 additions and 704 deletions.
3 changes: 2 additions & 1 deletion colossalai/inference/__init__.py
@@ -1,3 +1,4 @@
from .pipeline import PPInferEngine

__all__ = ["PPInferEngine"]

__all__ = ['PPInferEngine']
75 changes: 37 additions & 38 deletions colossalai/inference/pipeline/README.md
@@ -17,7 +17,7 @@
Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py).

1. `PPInferEngine` is the high-level API for users. It is responsible for the following tasks:
- Initialize the pipeline inference environment with `PipelineStageManager` and mdoel with `ShardFormer`.
- Initialize the pipeline inference environment with `PipelineStageManager` and model with `ShardFormer`.
- Run the pipeline inference model.

2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks:
@@ -31,54 +31,53 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManag

### Example
```python
from colossalai.pipeline import PPInferEngine
# Suppose the pipeline size is 2, and use fp16 to do infenrence. Use Llama as an example.
model = LlamaForCausalLM.from_pretrained('/path/to/model')
inputs = tokenizer("Hello, my dog is cute", "What a good day", return_tensors="pt")
engine = PPInferEngine(
pp_size=2,
dtype='fp16',
micro_batch_size=1,
new_length=10,
model=model,
model_policy=LlamaForCausalLMPipelinePolicy())

output = engine.inference([inputs])
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

```
colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("/path/to/model")
tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")

### Quick start
```shell
cd benchmark
sh run.sh
# assume the model is inferred with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=32)

input = ["Introduce a landmark in London","Introduce a landmark in Singapore"]
data = tokenizer(input, return_tensors='pt')
output = inferengine.inference(data.to('cuda'))
print(tokenizer.batch_decode(output))
```

## Performance

We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2*A10, 20G.
We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughput` of `Pipeline Inference` with the `Hugging Face` pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G.
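A note on reading the tables: throughput here is most naturally interpreted as generated tokens per second over the whole end-to-end run (the exact accounting lives in `print_details_info` in `benchmark.py`, which is not shown in full in this diff). A minimal sketch of that interpretation:

```python
# Hypothetical helper illustrating how a tokens/s figure of the kind tabulated
# below can be derived from a timed generation run. The exact accounting used
# by benchmark.py is not shown in this diff.
def tokens_per_second(batch_size: int, new_tokens: int, end2end_seconds: float) -> float:
    # total generated tokens across the batch divided by wall-clock time
    return batch_size * new_tokens / end2end_seconds

# e.g. 32 sequences x 128 new tokens finishing in ~6.1 s comes out to roughly
# 670 tokens/s, the order of the A800 7b figure below
print(tokens_per_second(32, 128, 6.1))
```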

### Llama Throughput(tokens/s)
### Llama Throughput (tokens/s) | input length=1024, output length=128

#### 7b, fp16
#### A10 7b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM |
| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 |
| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM |
| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |

#### 7b, fp32
#### A10 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 |
| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM |
| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 |
| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM |
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |

#### 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 |
| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM |
| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 |
| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM |

#### A800 7b, fp16
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |


#### A800 13b, fp16
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 |
| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |
7 changes: 5 additions & 2 deletions colossalai/inference/pipeline/benchmark/benchmark.py
Expand Up @@ -7,7 +7,7 @@

import colossalai
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024
@@ -117,8 +117,11 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
micro_batch_size=args.mb_size,
new_length=args.new_length,
model=model,
model_policy=LlamaForCausalLMPipelinePolicy(),
model_policy=LlamaModelInferPolicy(),
verbose=True,
max_batch_size=args.mb_size,
max_input_len=args.seq_len,
max_output_len=args.seq_len + args.new_length + 256,
)
data = data_gen(args.batch_size, args.seq_len)

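The benchmark calls `data_gen(args.batch_size, args.seq_len)`, whose body is outside this hunk. A minimal sketch of what such a generator could look like for a synthetic Llama run — the vocabulary size and field names are assumptions, not taken from this commit:

```python
import torch

# Hypothetical stand-in for the benchmark's data_gen (not shown in this hunk):
# builds a random token batch of shape (batch_size, seq_len) plus a full
# attention mask, returned as a dict that the engine's inference() accepts.
def data_gen(batch_size: int, seq_len: int, vocab_size: int = 32000):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    attention_mask = torch.ones_like(input_ids)
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```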
6 changes: 3 additions & 3 deletions colossalai/inference/pipeline/benchmark/run.sh
@@ -1,7 +1,7 @@
script_dir=$(cd "$(dirname "$0")" && pwd)
cd "${script_dir}"

# 7b, fp32, 2 gpu, 1024, 128
# 7b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8 16; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
@@ -13,7 +13,7 @@ for BATCH_SIZE in 2 4 8 16; do
--pp_size=2
done

# 7b, fp32, 2 gpu, 512, 512
# 7b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16 32; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
@@ -25,7 +25,7 @@ for BATCH_SIZE in 2 4 8 16 32; do
--pp_size=2
done

# 7b, fp32, 2 gpu, 1024, 128
# 7b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="13b" \
91 changes: 74 additions & 17 deletions colossalai/inference/pipeline/engine.py
@@ -1,12 +1,14 @@
import torch
import torch.nn as nn
from transformers.tokenization_utils_base import BatchEncoding

from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.schedule.generate import GenerateSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

from ..tensor_parallel.kvcache_manager import MemoryManager
from .microbatch_manager import MicroBatchManager


@@ -23,20 +25,29 @@ class PPInferEngine:
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
new_length (int): the new length of the input sequence.
early_stopping (bool): whether to stop early.
max_batch_size (int): the maximum batch size.
max_input_len (int): the maximum input length.
max_output_len (int): the maximum output length.
Example:
```python
from colossalai.ppinference import PPInferEngine
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer
model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
# assume the model is infered with 4 pipeline stages
inferengine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding})
colossalai.launch_from_torch(config={})
model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
# assume the model is inferred with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)
input = ["Introduce a landmark in China ","Introduce a landmark in China "]
data = tokenizer(input, return_tensors='pt')
output = inferengine.inference([data.to('cuda').data])
input = ["Hello, my dog is cute, and I like"]
tokenized_input = tokenizer(input, return_tensors='pt')
output = inferengine.inference(tokenized_input)
```
"""
@@ -51,31 +62,63 @@ def __init__(
new_length: int = 32,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
max_batch_size: int = 4,
max_input_len: int = 32,
max_output_len: int = 32,
verbose: bool = False,
# TODO: implement early_stopping, and various generation options
early_stopping: bool = False,
do_sample: bool = False,
num_beams: int = 1,
) -> None:
assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"

max_output_len = max(max_output_len, max_input_len + new_length)

self.pp_size = pp_size
if dtype == "fp16":
self.dtype = torch.float16
model.half()
elif dtype == "bf16":
self.dtype = torch.bfloat16
model.to(torch.bfloat16)
else:
self.dtype = torch.float32
self.pg_mesh = ProcessGroupMesh(pp_size)
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.model = pp_model or self._shardformer(model, model_policy)
self.cache_manager_list = [
self._init_manager(max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
]
self.mb_manager = MicroBatchManager(
self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size
self.stage_manager.stage,
new_length,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)

assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
if dtype == "fp16":
model.half()
elif dtype == "bf16":
model.to(torch.bfloat16)
self.model = pp_model or self._shardformer(model, model_policy)

def inference(self, input_list):
out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
"""
Args:
input_list (BatchEncoding or dict): the tokenized input batch; a `BatchEncoding` is unwrapped to its underlying `dict` before scheduling.
Returns:
out (list): a list of output data, each element is a list of tokens.
timestamp (float): the time cost of the inference, only return when verbose is `True`.
"""
assert isinstance(
input_list, (BatchEncoding, dict)
), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
if isinstance(input_list, BatchEncoding):
input_list = input_list.data
out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
if self.verbose:
return out, timestamp
else:
@@ -95,3 +138,17 @@ def _shardformer(self, model, model_policy):
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()

def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
head_num = self.model.config.num_attention_heads
num_hidden_layers = (
self.model.config.num_hidden_layers
if hasattr(self.model.config, "num_hidden_layers")
else self.model.config.num_layers
)
layer_num = num_hidden_layers // self.pp_size

cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
return cache_manager
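For intuition about the memory this reserves, here is a back-of-envelope estimate of the per-stage KV-cache footprint implied by `_init_manager`'s sizing. The factor of 2 for keys and values and the dense fp16 layout are assumptions about `MemoryManager`'s preallocation, not taken from this diff:

```python
# Rough per-stage KV-cache footprint implied by the sizing in _init_manager.
# Assumes the manager preallocates K and V (factor of 2) for every local layer;
# the real MemoryManager layout may differ.
def kvcache_bytes_per_stage(max_batch_size: int, max_input_len: int, max_output_len: int,
                            hidden_size: int, num_heads: int, num_layers: int,
                            pp_size: int, dtype_bytes: int = 2) -> int:  # fp16 -> 2 bytes
    max_total_token_num = max_batch_size * (max_input_len + max_output_len)
    head_dim = hidden_size // num_heads
    layer_num = num_layers // pp_size
    return 2 * max_total_token_num * num_heads * head_dim * layer_num * dtype_bytes

# Llama-7B (hidden 4096, 32 heads, 32 layers) split over 2 stages, with the engine
# defaults after the max(max_output_len, max_input_len + new_length) adjustment:
print(kvcache_bytes_per_stage(4, 32, 64, 4096, 32, 32, 2) / 1024**3, "GiB")
```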