[Pipeline Inference] Merge pp with tp #4993

Merged · 7 commits · Nov 1, 2023
6 changes: 3 additions & 3 deletions colossalai/inference/__init__.py
@@ -1,4 +1,4 @@
from .pipeline import PPInferEngine
from .hybridengine import CaiInferEngine
from .hybridengine.polices import LlamaModelInferPolicy


__all__ = ['PPInferEngine']
__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"]
3 changes: 3 additions & 0 deletions colossalai/inference/hybridengine/__init__.py
@@ -0,0 +1,3 @@
from .engine import CaiInferEngine

__all__ = ["CaiInferEngine"]
(Diff of the hybrid engine module, presumably colossalai/inference/hybridengine/engine.py given the `.engine` import above.)
@@ -1,4 +1,5 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers.tokenization_utils_base import BatchEncoding

@@ -8,31 +9,35 @@
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy

from ..pipeline.microbatch_manager import MicroBatchManager
from ..tensor_parallel.kvcache_manager import MemoryManager
from .microbatch_manager import MicroBatchManager

PP_AXIS, TP_AXIS = 0, 1

class PPInferEngine:
_supported_models = [
"LlamaForCausalLM",
]


class CaiInferEngine:
"""
PPInferEngine is a class that handles the pipeline parallel inference.
CaiInferEngine is a class that handles pipeline-parallel inference, optionally combined with tensor parallelism.

Args:
pp_size (int): the number of pipeline stages.
pp_model (`nn.Module`): the model already in pipeline parallelism style.
tp_size (int): the size of tensor parallelism.
pp_size (int): the size of pipeline parallelism.
model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy used to shard the model with `ShardFormer`.
micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
new_length (int): the new length of the input sequence.
early_stopping (bool): whether to stop early.
max_batch_size (int): the maximum batch size.
max_input_len (int): the maximum input length.
max_output_len (int): the maximum output length.

Example:

```python
from colossalai.inference import PPInferEngine
from colossalai.inference import CaiInferEngine
from colossalai.inference.hybridengine.polices import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer
@@ -42,7 +47,7 @@ class PPInferEngine:
model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
# assume the model is inferred with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)
inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())

input = ["Introduce a landmark in China ","Introduce a landmark in China "]
data = tokenizer(input, return_tensors='pt')
@@ -54,12 +59,11 @@ class PPInferEngine:

def __init__(
self,
pp_size: int,
tp_size: int = 1,
pp_size: int = 1,
dtype: str = "fp16",
pp_model: nn.Module = None,
model: nn.Module = None,
model_policy: Policy = None,
new_length: int = 32,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
max_batch_size: int = 4,
@@ -71,12 +75,21 @@ def __init__(
do_sample: bool = False,
num_beams: int = 1,
) -> None:
assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
assert (
tp_size * pp_size == dist.get_world_size()
), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
assert model and model_policy, "Both model and model_policy should be provided."
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"

max_output_len = max(max_output_len, max_input_len + new_length)
assert max_batch_size <= 64, "Max batch size exceeds the constraint"
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"

# TODO: support tensor-parallel-only inference
assert pp_size > 1, "Tensor-parallel-only inference is not supported yet."
self.pp_size = pp_size
self.tp_size = tp_size

if dtype == "fp16":
self.dtype = torch.float16
model.half()
@@ -85,24 +98,29 @@ def __init__(
model.to(torch.bfloat16)
else:
self.dtype = torch.float32
self.pg_mesh = ProcessGroupMesh(pp_size)
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.model = pp_model or self._shardformer(model, model_policy)
self.cache_manager_list = [
self._init_manager(max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
]
self.mb_manager = MicroBatchManager(
self.stage_manager.stage,
new_length,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)

# Init pg mesh
pg_mesh = ProcessGroupMesh(pp_size, tp_size)

stage_manager = None
if pp_size > 1:
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
self.cache_manager_list = [
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
]
self.mb_manager = MicroBatchManager(
stage_manager.stage,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)

self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))

def inference(self, input_list):
"""
@@ -124,10 +142,10 @@ def inference(self, input_list):
else:
return out

def _shardformer(self, model, model_policy):
def _shardformer(self, model, model_policy, stage_manager, tp_group):
shardconfig = ShardConfig(
tensor_parallel_process_group=None,
pipeline_stage_manager=self.stage_manager,
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=False,
enable_fused_normalization=False,
enable_all_optimization=False,
@@ -139,14 +157,12 @@ def _shardformer(self, model, model_policy):
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()

def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
head_num = self.model.config.num_attention_heads
head_dim = model.config.hidden_size // model.config.num_attention_heads
head_num = model.config.num_attention_heads
num_hidden_layers = (
self.model.config.num_hidden_layers
if hasattr(self.model.config, "num_hidden_layers")
else self.model.config.num_layers
model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
)
layer_num = num_hidden_layers // self.pp_size
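For a concrete sense of what `_init_manager` allocates per pipeline stage, here is a worked sketch of the same arithmetic. The Llama-7B-like config values (hidden_size 4096, 32 heads, 32 layers) and the batch/length settings are illustrative assumptions, not values from this diff:

```python
# Worked example of the sizing arithmetic in _init_manager (illustrative values only).
max_batch_size, max_input_len, max_output_len = 4, 32, 32
pp_size = 2
hidden_size, num_attention_heads, num_hidden_layers = 4096, 32, 32  # assumed Llama-7B-like config

max_total_token_num = max_batch_size * (max_input_len + max_output_len)  # 4 * (32 + 32) = 256
head_dim = hidden_size // num_attention_heads                            # 4096 // 32 = 128
head_num = num_attention_heads                                           # 32
layer_num = num_hidden_layers // pp_size                                 # 32 // 2 = 16 layers per stage

print(max_total_token_num, head_dim, head_num, layer_num)  # 256 128 32 16
```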

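Finally, a hedged end-to-end sketch of the merged pp + tp path, stitched together from the docstring example and the constructor signature shown in this diff. The 4-process launch (`torchrun --nproc_per_node 4`), the `launch_from_torch` call, and the model/tokenizer paths are assumptions for illustration, not taken from the diff:

```python
# Hypothetical driver script for the merged pipeline + tensor parallel path.
# Run with e.g.: torchrun --nproc_per_node 4 run_infer.py
# so that tp_size * pp_size == dist.get_world_size(), as asserted in __init__.
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy

colossalai.launch_from_torch(config={})  # assumed launcher, mirroring other ColossalAI examples

model = LlamaForCausalLM.from_pretrained("your_path_to_model")        # placeholder path
tokenizer = LlamaTokenizer.from_pretrained("your_path_to_tokenizer")  # placeholder path

# 2 pipeline stages x 2 tensor-parallel ranks = 4 processes
engine = CaiInferEngine(
    tp_size=2,
    pp_size=2,
    model=model,
    model_policy=LlamaModelInferPolicy(),
    max_batch_size=4,
    max_input_len=32,
    max_output_len=32,
)

data = tokenizer(["Introduce a landmark in China "], return_tensors="pt")
output = engine.inference(data)  # per the diff, output handling may differ by pipeline stage
if output is not None:
    print(output)
```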