[Pipeline Inference] Merge pp with tp (hpcaitech#4993)
* refactor pipeline into new CaiInferEngine

* update llama modeling forward

* merge tp with pp

* update docstring

* optimize test workflow and example

* fix typo

* add assert and todo
FoolPlayer authored and flybird11111 committed Nov 10, 2023
1 parent 5a7a47b commit 8ae3722
Showing 12 changed files with 269 additions and 204 deletions.
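
The commit message above folds tensor parallelism into the former pipeline-only inference engine, exposing a single `CaiInferEngine` that takes both `tp_size` and `pp_size`. The sketch below pieces together a minimal multi-GPU run from the docstring example and the new constructor arguments in the engine diff further down; the 4-GPU layout (`tp_size=2, pp_size=2`), the checkpoint path, the `torchrun` launch, and the final decode step are illustrative assumptions, not part of the commit.

```python
# Hedged usage sketch of the merged TP+PP engine introduced by this commit.
# Assumptions (not from the commit): 4 GPUs launched via
# `torchrun --nproc_per_node 4 run.py`, tp_size=2 and pp_size=2 so that
# tp_size * pp_size == world_size, and a local Llama checkpoint path.
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy

colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("path/to/llama")    # assumed path
tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")  # assumed path

# tp_size * pp_size must equal the global world size (asserted in __init__ below).
engine = CaiInferEngine(
    tp_size=2,
    pp_size=2,
    model=model,
    model_policy=LlamaModelInferPolicy(),
    max_batch_size=4,
    max_input_len=32,
    max_output_len=32,
)

data = tokenizer(["Introduce a landmark in China "], return_tensors="pt")
output = engine.inference(data)  # assumed to return generated token ids (see inference() in the diff below)
if output is not None:
    print(tokenizer.batch_decode(output))
```
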
6 changes: 3 additions & 3 deletions colossalai/inference/__init__.py
@@ -1,4 +1,4 @@
from .pipeline import PPInferEngine
from .hybridengine import CaiInferEngine
from .hybridengine.polices import LlamaModelInferPolicy


__all__ = ['PPInferEngine']
__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"]
3 changes: 3 additions & 0 deletions colossalai/inference/hybridengine/__init__.py
@@ -0,0 +1,3 @@
from .engine import CaiInferEngine

__all__ = ["CaiInferEngine"]
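
Taken together, the two `__init__.py` diffs above retire the `PPInferEngine` export in favour of the hybrid engine's names; a minimal check of the new public import surface:

```python
# Exports after this commit, per the updated __all__ lists above.
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
```
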
colossalai/inference/hybridengine/engine.py
@@ -1,4 +1,5 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers.tokenization_utils_base import BatchEncoding

@@ -8,31 +9,35 @@
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

from ..pipeline.microbatch_manager import MicroBatchManager
from ..tensor_parallel.kvcache_manager import MemoryManager
from .microbatch_manager import MicroBatchManager

PP_AXIS, TP_AXIS = 0, 1

class PPInferEngine:
_supported_models = [
"LlamaForCausalLM",
]


class CaiInferEngine:
"""
PPInferEngine is a class that handles the pipeline parallel inference.
CaiInferEngine is a class that handles the pipeline parallel inference.
Args:
pp_size (int): the number of pipeline stages.
pp_model (`nn.Module`): the model already in pipeline parallelism style.
tp_size (int): the size of tensor parallelism.
pp_size (int): the size of pipeline parallelism.
model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model.
micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
new_length (int): the new length of the input sequence.
early_stopping (bool): whether to stop early.
max_batch_size (int): the maximum batch size.
max_input_len (int): the maximum input length.
max_output_len (int): the maximum output length.
Example:
```python
from colossalai.inference import PPInferEngine
from colossalai.inference import CaiInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer
@@ -42,7 +47,7 @@ class PPInferEngine:
model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
# assume the model runs inference with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)
inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())
input = ["Introduce a landmark in China ","Introduce a landmark in China "]
data = tokenizer(input, return_tensors='pt')
@@ -54,12 +59,11 @@ class PPInferEngine:

def __init__(
self,
pp_size: int,
tp_size: int = 1,
pp_size: int = 1,
dtype: str = "fp16",
pp_model: nn.Module = None,
model: nn.Module = None,
model_policy: Policy = None,
new_length: int = 32,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
max_batch_size: int = 4,
@@ -71,12 +75,21 @@ def __init__(
do_sample: bool = False,
num_beams: int = 1,
) -> None:
assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
assert (
tp_size * pp_size == dist.get_world_size()
), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
assert model and model_policy, "Both model and model_policy should be provided."
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"

max_output_len = max(max_output_len, max_input_len + new_length)
assert max_batch_size <= 64, "Max batch size exceeds the constraint"
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"

# TODO: support tensor-parallel-only inference
assert pp_size > 1, "Tensor-parallel-only inference is not supported yet."
self.pp_size = pp_size
self.tp_size = tp_size

if dtype == "fp16":
self.dtype = torch.float16
model.half()
@@ -85,24 +98,29 @@ def __init__(
model.to(torch.bfloat16)
else:
self.dtype = torch.float32
self.pg_mesh = ProcessGroupMesh(pp_size)
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.model = pp_model or self._shardformer(model, model_policy)
self.cache_manager_list = [
self._init_manager(max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
]
self.mb_manager = MicroBatchManager(
self.stage_manager.stage,
new_length,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)

# Init pg mesh
pg_mesh = ProcessGroupMesh(pp_size, tp_size)

stage_manager = None
if pp_size > 1:
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
self.cache_manager_list = [
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
]
self.mb_manager = MicroBatchManager(
stage_manager.stage,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)

self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))

def inference(self, input_list):
"""
@@ -124,10 +142,10 @@ def inference(self, input_list):
else:
return out

def _shardformer(self, model, model_policy):
def _shardformer(self, model, model_policy, stage_manager, tp_group):
shardconfig = ShardConfig(
tensor_parallel_process_group=None,
pipeline_stage_manager=self.stage_manager,
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=False,
enable_fused_normalization=False,
enable_all_optimization=False,
@@ -139,14 +157,12 @@ def _shardformer(self, model, model_policy):
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()

def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
head_num = self.model.config.num_attention_heads
head_dim = model.config.hidden_size // model.config.num_attention_heads
head_num = model.config.num_attention_heads
num_hidden_layers = (
self.model.config.num_hidden_layers
if hasattr(self.model.config, "num_hidden_layers")
else self.model.config.num_layers
model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
)
layer_num = num_hidden_layers // self.pp_size
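
The `_init_manager` change above passes the raw `model` in explicitly and sizes one KV-cache manager per pipeline stage. A standalone sketch of that sizing arithmetic follows; the Llama-7B-style config values and the example call are illustrative assumptions, and the actual memory-manager construction sits outside the visible part of the diff.

```python
# Standalone sketch of the per-stage cache sizing arithmetic in `_init_manager` above.
# The Llama-7B-style config values below are illustrative assumptions.
from dataclasses import dataclass


@dataclass
class DummyConfig:
    hidden_size: int = 4096         # assumed Llama-7B-like values
    num_attention_heads: int = 32
    num_hidden_layers: int = 32     # the diff also accepts configs that name this `num_layers`


def cache_sizes(config: DummyConfig, pp_size: int, max_batch_size: int,
                max_input_len: int, max_output_len: int):
    # Same formulas as `_init_manager` in the diff above.
    max_total_token_num = max_batch_size * (max_input_len + max_output_len)
    head_dim = config.hidden_size // config.num_attention_heads
    head_num = config.num_attention_heads
    layer_num = config.num_hidden_layers // pp_size  # each pipeline stage caches only its own layers
    return max_total_token_num, head_dim, head_num, layer_num


# Example: pp_size=2, max_batch_size=4, max_input_len=32, max_output_len=32
# -> (256, 128, 32, 16): 256 cached tokens per batch, 128-dim heads, 32 heads,
#    and 16 transformer layers' worth of KV cache per pipeline stage.
print(cache_sizes(DummyConfig(), pp_size=2, max_batch_size=4, max_input_len=32, max_output_len=32))
```
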
