
[inference] chatglm2 infer demo #4724

Merged: 23 commits, Sep 22, 2023
35 changes: 30 additions & 5 deletions colossalai/inference/tensor_parallel/engine.py
@@ -15,7 +15,13 @@

DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2

_supported_models = ["LlamaForCausalLM", "LlamaModel", "BloomForCausalLM"]
_supported_models = [
"LlamaForCausalLM",
"LlamaModel",
"BloomForCausalLM",
"ChatGLMModel",
"ChatGLMForConditionalGeneration",
]


class TPInferEngine:
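The allow-list above now includes the two ChatGLM2 entry points. As a hedged sketch (the exact call site is not shown in this diff), the engine is assumed to gate sharding on the model's class name roughly like this:

def _assert_model_supported(model) -> None:
    # Illustrative check, not the PR's exact code: match the model's class name
    # against the _supported_models allow-list before applying the shard policy.
    name = model.__class__.__name__
    assert name in _supported_models, f"{name} is not supported for TP inference"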
@@ -63,7 +69,13 @@ def __init__(

self.head_dim = model.config.hidden_size // model.config.num_attention_heads
self.head_num = model.config.num_attention_heads
self.layer_num = model.config.num_hidden_layers
num_hidden_layers = (
model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
)
self.layer_num = num_hidden_layers
self.multi_query_group_num = (
model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0
)

self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None
@@ -77,9 +89,22 @@ def _init_manager(self) -> None:
assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
self.head_num //= self.tp_size # update sharded number of heads
self.cache_manager = MemoryManager(
self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
)
if self.multi_query_group_num:
# NOTE the logic of MQA tensor parallelism should be specified.
assert (
self.multi_query_group_num % self.tp_size == 0
), f"Cannot shard {self.multi_query_group_num} query groups with tp size {self.tp_size}"
self.cache_manager = MemoryManager(
self.max_total_token_num,
self.dtype,
self.multi_query_group_num // self.tp_size,
self.head_dim,
self.layer_num,
)
else:
self.cache_manager = MemoryManager(
self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
)

def _optimize_model(self, model: nn.Module) -> None:
"""
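Taken together, the engine changes above encode two rules: read the layer count from whichever attribute the config exposes (num_hidden_layers for Llama/Bloom, num_layers for ChatGLM2), and size the KV cache by the number of multi-query groups instead of the full head count when the model uses multi-query attention. Below is a standalone sketch of that sizing rule, with a SimpleNamespace standing in for a real config; the helper name and example values are illustrative, not part of the PR:

from types import SimpleNamespace


def kv_cache_shape_per_rank(config, tp_size: int, max_total_token_num: int):
    # Mirror the fallback in __init__: ChatGLM2 configs expose num_layers,
    # Llama/Bloom-style configs expose num_hidden_layers.
    num_layers = config.num_hidden_layers if hasattr(config, "num_hidden_layers") else config.num_layers
    head_dim = config.hidden_size // config.num_attention_heads
    # MQA models cache only multi_query_group_num KV heads; dense-attention
    # models cache one KV head per attention head.
    kv_heads = getattr(config, "multi_query_group_num", 0) or config.num_attention_heads
    assert kv_heads % tp_size == 0, f"Cannot shard {kv_heads} KV heads with tp size {tp_size}"
    # MemoryManager keeps one K and one V buffer of this shape per layer.
    return (max_total_token_num, kv_heads // tp_size, head_dim), num_layers


# ChatGLM2-6B-like numbers: 2 KV groups shared by 32 attention heads.
cfg = SimpleNamespace(num_layers=28, hidden_size=4096, num_attention_heads=32, multi_query_group_num=2)
print(kv_cache_shape_per_rank(cfg, tp_size=2, max_total_token_num=2048))  # ((2048, 1, 128), 28)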
5 changes: 4 additions & 1 deletion colossalai/inference/tensor_parallel/modeling/__init__.py
@@ -1,4 +1,7 @@
import _utils

from .bloom import BloomInferenceForwards
from .chatglm2 import ChatGLM2InferenceForwards
from .llama import LlamaInferenceForwards

__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards"]
__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards", "ChatGLM2InferenceForwards"]
10 changes: 10 additions & 0 deletions colossalai/inference/tensor_parallel/modeling/_utils.py
@@ -0,0 +1,10 @@
"""
Utils for model inference
"""
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest


def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
return
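For reference, a plain-PyTorch sketch of what this helper does: the Triton kernel copy_kv_cache_to_dest performs the same scatter on the GPU. The buffer names and shapes below are assumptions chosen for illustration, not the engine's real objects:

import torch

# Stand-in per-layer cache buffers mirroring MemoryManager.key_buffer / value_buffer:
# one [num_slots, num_kv_heads, head_dim] tensor per transformer layer.
num_slots, num_kv_heads, head_dim, num_layers = 64, 2, 128, 28
key_buffer = [torch.zeros(num_slots, num_kv_heads, head_dim) for _ in range(num_layers)]
value_buffer = [torch.zeros(num_slots, num_kv_heads, head_dim) for _ in range(num_layers)]


def copy_kv_to_mem_cache_naive(layer_id, key, value, context_mem_index):
    # Scatter this step's K/V rows into the cache slots allocated for them.
    key_buffer[layer_id][context_mem_index] = key
    value_buffer[layer_id][context_mem_index] = value


num_tokens = 4
slots = torch.arange(num_tokens)  # slot indices the memory manager handed out
copy_kv_to_mem_cache_naive(
    0,
    torch.randn(num_tokens, num_kv_heads, head_dim),
    torch.randn(num_tokens, num_kv_heads, head_dim),
    slots,
)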