[Pipeline inference] Combine kvcache with pipeline inference #4938

Merged
14 commits merged on Oct 27, 2023
3 changes: 2 additions & 1 deletion colossalai/inference/__init__.py
@@ -1,3 +1,4 @@
from .pipeline import PPInferEngine

__all__ = ["PPInferEngine"]

__all__ = ['PPInferEngine']
67 changes: 37 additions & 30 deletions colossalai/inference/pipeline/README.md
@@ -33,17 +33,22 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager`
```python
from colossalai.pipeline import PPInferEngine
# Suppose the pipeline size is 2, and use fp16 to do inference. Use Llama as an example.
model = LlamaForCausalLM.from_pretrained('/path/to/model')
inputs = tokenizer("Hello, my dog is cute", "What a good day", return_tensors="pt")
engine = PPInferEngine(
pp_size=2,
dtype='fp16',
micro_batch_size=1,
new_length=10,
model=model,
model_policy=LlamaForCausalLMPipelinePolicy())

output = engine.inference([inputs])
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("path_to_model")
# assume the model runs inference with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)

input = ["Introduce a landmark in China", "Introduce a landmark in China"]
data = tokenizer(input, return_tensors='pt')
output = inferengine.inference(data.to('cuda'))


```
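The call above returns generated token ids rather than decoded text, so decoding is a separate step. Below is a minimal post-processing sketch; it assumes `output` holds one list of token ids per input sequence (as the engine's docstring describes) and may need adjusting to the exact return format:

```python
# Hedged sketch: decode the generated ids back to strings.
# Assumes `output` is a list with one list of token ids per input sequence.
for ids in output:
    print(tokenizer.decode(ids, skip_special_tokens=True))
```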

@@ -55,30 +60,32 @@ sh run.sh

## Performance

We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughput` between `Pipeline Inference` and the `Hugging Face` pipeline. The test environment is 2 * A10, 20G.
We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughput` between `Pipeline Inference` and the `Hugging Face` pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G.

### Llama Throughput(tokens/s)
### Llama Throughput (tokens/s) | input length=1024, output length=128

#### 7b, fp16
#### A10 7b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM |
| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 |
| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM |
| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM |

#### 7b, fp32
#### A10 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 |
| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM |
| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 |
| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM |
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |

#### 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 |
| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM |
| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 |
| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM |

#### A800 7b, fp16
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |


#### A800 13b, fp16
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 |
| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |
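The throughput numbers above are reported in tokens per second. The exact accounting lives in `benchmark.py`; the sketch below is only an illustrative approximation (the helper name and the wall-clock formula are assumptions, not taken from this diff):

```python
import time

def approx_throughput(engine, data, batch_size: int, new_length: int) -> float:
    """Rough tokens/s estimate: generated tokens in the batch divided by wall-clock time."""
    start = time.time()
    engine.inference(data)
    elapsed = time.time() - start
    return batch_size * new_length / elapsed
```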
7 changes: 5 additions & 2 deletions colossalai/inference/pipeline/benchmark/benchmark.py
Expand Up @@ -7,7 +7,7 @@

import colossalai
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024
@@ -117,8 +117,11 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
micro_batch_size=args.mb_size,
new_length=args.new_length,
model=model,
model_policy=LlamaForCausalLMPipelinePolicy(),
model_policy=LlamaModelInferPolicy(),
verbose=True,
max_batch_size=args.mb_size,
max_input_len=args.seq_len,
max_output_len=args.seq_len + args.new_length + 256,
)
data = data_gen(args.batch_size, args.seq_len)

6 changes: 3 additions & 3 deletions colossalai/inference/pipeline/benchmark/run.sh
@@ -1,7 +1,7 @@
script_dir=$(cd "$(dirname "$0")" && pwd)
cd "${script_dir}"

# 7b, fp32, 2 gpu, 1024, 128
# 7b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8 16; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
@@ -13,7 +13,7 @@ for BATCH_SIZE in 2 4 8 16; do
--pp_size=2
done

# 7b, fp32, 2 gpu, 512, 512
# 7b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16 32; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
@@ -25,7 +25,7 @@ for BATCH_SIZE in 2 4 8 16 32; do
--pp_size=2
done

# 7b, fp32, 2 gpu, 1024, 128
# 13b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="13b" \
90 changes: 73 additions & 17 deletions colossalai/inference/pipeline/engine.py
@@ -1,12 +1,14 @@
import torch
import torch.nn as nn
from transformers.tokenization_utils_base import BatchEncoding

from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.schedule.generate import GenerateSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

from ..tensor_parallel.kvcache_manager import MemoryManager
from .microbatch_manager import MicroBatchManager


@@ -23,20 +25,29 @@ class PPInferEngine:
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
new_length (int): the new length of the input sequence.
early_stopping (bool): whether to stop early.
max_batch_size (int): the maximum batch size.
max_input_len (int): the maximum input length.
max_output_len (int): the maximum output length.

Example:

```python
from colossalai.ppinference import PPInferEngine
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
# assume the model runs inference with 4 pipeline stages
inferengine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding})
colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("your_path_to_model")
# assume the model runs inference with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)

input = ["Introduce a landmark in China ","Introduce a landmark in China "]
data = tokenizer(input, return_tensors='pt')
output = inferengine.inference(data.to('cuda'))

input = ["Hello, my dog is cute, and I like"]
tokenized_input = tokenizer(input, return_tensors='pt')
output = engine.inference([tokenized_input])
```

"""
@@ -51,31 +62,62 @@ def __init__(
new_length: int = 32,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
max_batch_size: int = 4,
max_input_len: int = 32,
max_output_len: int = 32,
verbose: bool = False,
# TODO: implement early_stopping and various generation options
early_stopping: bool = False,
do_sample: bool = False,
num_beams: int = 1,
) -> None:
assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"

max_output_len = max(max_output_len, max_input_len + new_length)

self.pp_size = pp_size
if dtype == "fp16":
self.dtype = torch.float16
model.half()
elif dtype == "bf16":
self.dtype = torch.bfloat16
model.to(torch.bfloat16)
else:
self.dtype = torch.float32
self.pg_mesh = ProcessGroupMesh(pp_size)
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.model = pp_model or self._shardformer(model, model_policy)
self.cache_manager_list = [self._init_manager(max_batch_size, max_input_len, max_output_len)] * (
micro_batch_buffer_size or pp_size
)
self.mb_manager = MicroBatchManager(
self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size
self.stage_manager.stage,
new_length,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)

assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
if dtype == "fp16":
model.half()
elif dtype == "bf16":
model.to(torch.bfloat16)
self.model = pp_model or self._shardformer(model, model_policy)

def inference(self, input_list):
out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
"""
Args:
input_list (BatchEncoding or dict): the tokenized input batch to run inference on.

Returns:
out (list): a list of output data, each element is a list of tokens.
timestamp (float): the time cost of the inference, only returned when `verbose` is `True`.
"""
assert isinstance(
input_list, (BatchEncoding, dict)
), f"Only accept BatchEncoding or dict as input, but got {input_list.__class__.__name__}."
if isinstance(input_list, BatchEncoding):
input_list = input_list.data
out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
if self.verbose:
return out, timestamp
else:
@@ -95,3 +137,17 @@ def _shardformer(self, model, model_policy):
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()

def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> MemoryManager:
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
head_num = self.model.config.num_attention_heads
num_hidden_layers = (
self.model.config.num_hidden_layers
if hasattr(self.model.config, "num_hidden_layers")
else self.model.config.num_layers
)
layer_num = num_hidden_layers // self.pp_size

cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
return cache_manager
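For readers estimating memory use: `_init_manager` above sizes one `MemoryManager` per micro-batch buffer from the batch and sequence-length limits. The sketch below gives a back-of-the-envelope estimate of the per-stage KV-cache footprint; the factor of 2 for key+value and the byte layout are assumptions about `MemoryManager`, not taken from this diff:

```python
def estimate_kv_cache_bytes(
    max_batch_size: int,
    max_input_len: int,
    max_output_len: int,
    hidden_size: int,
    num_attention_heads: int,
    num_hidden_layers: int,
    pp_size: int,
    bytes_per_element: int = 2,  # 2 bytes per element for fp16/bf16
) -> int:
    """Rough per-stage KV-cache size, mirroring the quantities computed in _init_manager."""
    max_total_token_num = max_batch_size * (max_input_len + max_output_len)
    head_dim = hidden_size // num_attention_heads
    layer_num = num_hidden_layers // pp_size  # layers held by one pipeline stage
    # 2 accounts for the key and value tensors kept per token (assumed layout).
    return 2 * max_total_token_num * num_attention_heads * head_dim * layer_num * bytes_per_element


# Example: a Llama-7B-like config (hidden 4096, 32 heads, 32 layers) split over 2 stages,
# with the engine defaults max_batch_size=4, max_input_len=32 and max_output_len raised to 64.
print(estimate_kv_cache_bytes(4, 32, 64, 4096, 32, 32, 2) / 1024**2, "MiB")
```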