From 4ab4520e390fbe8d9d7c831b9c73dda28c01c32f Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Wed, 18 Oct 2023 13:42:46 +0800 Subject: [PATCH 01/14] merge kvcache with pipeline inference and refactor the code structure --- colossalai/inference/__init__.py | 5 +- colossalai/inference/cai_engine/__init__.py | 3 + .../inference/cai_engine/batch_infer_state.py | 120 +++++ .../{pipeline => cai_engine}/engine.py | 52 +- .../inference/cai_engine/kvcache_manager.py | 104 ++++ .../inference/cai_engine/modeling/__init__.py | 3 + .../inference/cai_engine/modeling/_utils.py | 67 +++ .../inference/cai_engine/modeling/llama.py | 477 ++++++++++++++++++ .../inference/cai_engine/policies/__init__.py | 3 + .../inference/cai_engine/policies/llama.py | 139 +++++ colossalai/inference/pipeline/__init__.py | 4 +- .../inference/pipeline/microbatch_manager.py | 102 ++-- .../inference/pipeline/modeling/__init__.py | 0 .../inference/pipeline/modeling/gpt2.py | 280 ---------- .../inference/pipeline/modeling/llama.py | 229 --------- .../inference/pipeline/policy/gpt2_ppinfer.py | 74 --- .../pipeline/policy/llama_ppinfer.py | 48 -- colossalai/pipeline/schedule/generate.py | 65 +-- tests/test_infer/test_pipeline_infer.py | 27 +- 19 files changed, 1068 insertions(+), 734 deletions(-) create mode 100644 colossalai/inference/cai_engine/__init__.py create mode 100644 colossalai/inference/cai_engine/batch_infer_state.py rename colossalai/inference/{pipeline => cai_engine}/engine.py (68%) create mode 100644 colossalai/inference/cai_engine/kvcache_manager.py create mode 100644 colossalai/inference/cai_engine/modeling/__init__.py create mode 100644 colossalai/inference/cai_engine/modeling/_utils.py create mode 100644 colossalai/inference/cai_engine/modeling/llama.py create mode 100644 colossalai/inference/cai_engine/policies/__init__.py create mode 100644 colossalai/inference/cai_engine/policies/llama.py delete mode 100644 colossalai/inference/pipeline/modeling/__init__.py delete mode 100644 colossalai/inference/pipeline/modeling/gpt2.py delete mode 100644 colossalai/inference/pipeline/modeling/llama.py delete mode 100644 colossalai/inference/pipeline/policy/gpt2_ppinfer.py delete mode 100644 colossalai/inference/pipeline/policy/llama_ppinfer.py diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index 35891307e754..5975de786d46 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -1,3 +1,4 @@ -from .pipeline import PPInferEngine +from .cai_engine import CaiInferEngine -__all__ = ["PPInferEngine"] + +__all__ = ['CaiInferEngine'] diff --git a/colossalai/inference/cai_engine/__init__.py b/colossalai/inference/cai_engine/__init__.py new file mode 100644 index 000000000000..a80de0898d06 --- /dev/null +++ b/colossalai/inference/cai_engine/__init__.py @@ -0,0 +1,3 @@ +from .engine import CaiInferEngine + +__all__ = ['CaiInferEngine'] diff --git a/colossalai/inference/cai_engine/batch_infer_state.py b/colossalai/inference/cai_engine/batch_infer_state.py new file mode 100644 index 000000000000..ec70cb6cbfd0 --- /dev/null +++ b/colossalai/inference/cai_engine/batch_infer_state.py @@ -0,0 +1,120 @@ +# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later +from dataclasses import dataclass + +import torch + +from .kvcache_manager import MemoryManager +from transformers.tokenization_utils_base import BatchEncoding + + +@dataclass +class BatchInferState: + r""" + Information to be passed and used for a batch of inputs during + a single model forward + """ + batch_size: int + max_len_in_batch: int + + cache_manager: MemoryManager = None + + block_loc: torch.Tensor = None + start_loc: torch.Tensor = None + seq_len: torch.Tensor = None + past_key_values_len: int = None + + is_context_stage: bool = False + context_mem_index: torch.Tensor = None + decode_is_contiguous: bool = None + decode_mem_start: int = None + decode_mem_end: int = None + decode_mem_index: torch.Tensor = None + decode_layer_id: int = None + + device: torch.device = torch.device("cuda") + + @property + def total_token_num(self): + # return self.batch_size * self.max_len_in_batch + assert self.seq_len is not None and self.seq_len.size(0) > 0 + return int(torch.sum(self.seq_len)) + + def set_cache_manager(self, manager: MemoryManager): + self.cache_manager = manager + + @staticmethod + def init_block_loc( + b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor + ): + """in-place update block loc mapping based on the sequence length of the inputs in current bath""" + start_index = 0 + seq_len_numpy = seq_len.cpu().numpy() + for i, cur_seq_len in enumerate(seq_len_numpy): + b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[ + start_index : start_index + cur_seq_len + ] + start_index += cur_seq_len + return + + @classmethod + def init_from_batch( + cls, + batch: torch.Tensor, + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + ): + if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)): + raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state") + + input_ids_list = None + attention_mask = None + + if isinstance(batch, (BatchEncoding, dict)): + input_ids_list = batch["input_ids"] + attention_mask = batch["attention_mask"] + else: + input_ids_list = batch + if isinstance(input_ids_list[0], int): # for a single input + input_ids_list = [input_ids_list] + attention_mask = [attention_mask] if attention_mask is not None else attention_mask + + batch_size = len(input_ids_list) + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + + max_len_in_batch = -1 + if isinstance(batch, (BatchEncoding, dict)): + for i, attn_mask in enumerate(attention_mask): + curr_seq_len = len(attn_mask) + # if isinstance(attn_mask, torch.Tensor): + # curr_seq_len = int(torch.sum(attn_mask)) + # else: + # curr_seq_len = int(sum(attn_mask)) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + else: + length = max(len(input_id) for input_id in input_ids_list) + for i, input_ids in enumerate(input_ids_list): + curr_seq_len = length + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda") + + return cls( + batch_size=batch_size, + max_len_in_batch=max_len_in_batch, + seq_len=seq_lengths.to('cuda'), + start_loc = seq_start_indexes.to("cuda"), + block_loc = block_loc, + decode_layer_id = 0, + past_key_values_len = 0, + is_context_stage = True, + cache_manager=cache_manager, + ) diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/cai_engine/engine.py similarity index 68% rename from colossalai/inference/pipeline/engine.py rename to colossalai/inference/cai_engine/engine.py index 4f42385caf8f..dc52ac699b5f 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/cai_engine/engine.py @@ -7,10 +7,11 @@ from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy -from .microbatch_manager import MicroBatchManager +from ..pipeline import MicroBatchManager +from .kvcache_manager import MemoryManager -class PPInferEngine: +class CaiInferEngine: """ PPInferEngine is a class that handles the pipeline parallel inference. @@ -51,6 +52,9 @@ def __init__( new_length: int = 32, micro_batch_size: int = 1, micro_batch_buffer_size: int = None, + max_batch_size: int = 4, + max_input_len: int = 32, + max_output_len: int = 32, verbose: bool = False, # TODO: implement early_stopping, and various gerneration options early_stopping: bool = False, @@ -58,22 +62,28 @@ def __init__( num_beams: int = 1, ) -> None: assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided." + assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'" + + max_output_len = max(max_output_len, max_input_len + new_length) + self.pp_size = pp_size + if dtype == 'fp16': + self.dtype = torch.float16 + model.half() + elif dtype == 'bf16': + self.dtype = torch.bfloat16 + model.to(torch.bfloat16) + else: + self.dtype = torch.float32 self.pg_mesh = ProcessGroupMesh(pp_size) self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) - self.mb_manager = MicroBatchManager( - self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size - ) + self.model = pp_model or self._shardformer(model, model_policy) + self.cache_manager_list = [self._init_manager(max_batch_size, max_input_len, max_output_len)]*(micro_batch_buffer_size or pp_size) + self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size, + micro_batch_buffer_size or pp_size, max_input_len, max_output_len, self.cache_manager_list) self.verbose = verbose self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) - assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" - if dtype == "fp16": - model.half() - elif dtype == "bf16": - model.to(torch.bfloat16) - self.model = pp_model or self._shardformer(model, model_policy) - def inference(self, input_list): out, timestamp = self.schedule.generate_step(self.model, iter(input_list)) if self.verbose: @@ -95,3 +105,21 @@ def _shardformer(self, model, model_policy): shardformer = ShardFormer(shard_config=shardconfig) shard_model, _ = shardformer.optimize(model, model_policy) return shard_model.cuda() + + def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None: + max_total_token_num = max_batch_size * (max_input_len + max_output_len) + head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads + head_num = self.model.config.num_attention_heads + num_hidden_layers = ( + self.model.config.num_hidden_layers if hasattr(self.model.config, "num_hidden_layers") else self.model.config.num_layers + ) + layer_num = num_hidden_layers // self.pp_size + + cache_manager = MemoryManager( + max_total_token_num, + self.dtype, + head_num, + head_dim, + layer_num + ) + return cache_manager \ No newline at end of file diff --git a/colossalai/inference/cai_engine/kvcache_manager.py b/colossalai/inference/cai_engine/kvcache_manager.py new file mode 100644 index 000000000000..e74a3a491a7b --- /dev/null +++ b/colossalai/inference/cai_engine/kvcache_manager.py @@ -0,0 +1,104 @@ +# Adapted from lightllm/common/mem_manager.py +# of the ModelTC/lightllm GitHub repository +# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py + +import torch +from transformers.utils import logging + + +class MemoryManager: + r""" + Manage token block indexes and allocate physical memory for key and value cache + + Args: + size: maximum token number used as the size of key and value buffer + dtype: data type of cached key and value + head_num: number of heads the memory manager is responsible for + head_dim: embedded size per head + layer_num: the number of layers in the model + device: device used to store the key and value cache + """ + + def __init__( + self, + size: int, + dtype: torch.dtype, + head_num: int, + head_dim: int, + layer_num: int, + device: torch.device = torch.device("cuda"), + ): + self.logger = logging.get_logger(__name__) + self.available_size = size + self.past_key_values_length = 0 + self._init_mem_states(size, device) + self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) + + def _init_mem_states(self, size, device): + """Initialize tensors used to manage memory states""" + self.mem_state = torch.ones((size,), dtype=torch.bool, device=device) + self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device) + self.indexes = torch.arange(0, size, dtype=torch.long, device=device) + + def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): + """Initialize key buffer and value buffer on specified device""" + self.key_buffer = [ + torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) + ] + self.value_buffer = [ + torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) + ] + + @torch.no_grad() + def alloc(self, required_size): + """allocate space of required_size by providing indexes representing available physical spaces""" + if required_size > self.available_size: + self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") + return None + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) + select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) + select_index = self.indexes[select_index] + self.mem_state[select_index] = 0 + self.available_size -= len(select_index) + return select_index + + @torch.no_grad() + def alloc_contiguous(self, required_size): + """allocate contiguous space of required_size""" + if required_size > self.available_size: + self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") + return None + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) + sum_size = len(self.mem_cum_sum) + loc_sums = ( + self.mem_cum_sum[required_size - 1 :] + - self.mem_cum_sum[0 : sum_size - required_size + 1] + + self.mem_state[0 : sum_size - required_size + 1] + ) + can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size] + if can_used_loc.shape[0] == 0: + self.logger.info( + f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}" + ) + return None + start_loc = can_used_loc[0] + select_index = self.indexes[start_loc : start_loc + required_size] + self.mem_state[select_index] = 0 + self.available_size -= len(select_index) + start = start_loc.item() + end = start + required_size + return select_index, start, end + + @torch.no_grad() + def free(self, free_index): + """free memory by updating memory states based on given indexes""" + self.available_size += free_index.shape[0] + self.mem_state[free_index] = 1 + + @torch.no_grad() + def free_all(self): + """free all memory by updating memory states""" + self.available_size = len(self.mem_state) + self.mem_state[:] = 1 + self.past_key_values_length = 0 + self.logger.info("freed all space of memory manager") diff --git a/colossalai/inference/cai_engine/modeling/__init__.py b/colossalai/inference/cai_engine/modeling/__init__.py new file mode 100644 index 000000000000..239bdebd7efd --- /dev/null +++ b/colossalai/inference/cai_engine/modeling/__init__.py @@ -0,0 +1,3 @@ +from .llama import LlamaInferenceForwards + +__all__ = ["LlamaInferenceForwards"] diff --git a/colossalai/inference/cai_engine/modeling/_utils.py b/colossalai/inference/cai_engine/modeling/_utils.py new file mode 100644 index 000000000000..068b64b4f829 --- /dev/null +++ b/colossalai/inference/cai_engine/modeling/_utils.py @@ -0,0 +1,67 @@ +""" +Utils for model inference +""" +import os + +import torch + +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest + + +def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + """ + This function copies the key and value cache to the memory cache + Args: + layer_id : id of current layer + key_buffer : key cache + value_buffer : value cache + context_mem_index : index of memory cache in kv cache manager + mem_manager : cache manager + """ + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + + +def init_to_get_rotary(self, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + self : Model that holds the rotary positional embedding + base : calculation arg + use_elem : activated when using chatglm-based models + """ + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None) + + if ntk_alpha is not None: + ntk_alpha = float(ntk_alpha) + assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1" + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula + + n_elem = self.config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() diff --git a/colossalai/inference/cai_engine/modeling/llama.py b/colossalai/inference/cai_engine/modeling/llama.py new file mode 100644 index 000000000000..00907ab5f2b3 --- /dev/null +++ b/colossalai/inference/cai_engine/modeling/llama.py @@ -0,0 +1,477 @@ +from typing import List, Optional, Tuple + +import torch +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm +from transformers.utils import logging + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.pipeline.stage_manager import PipelineStageManager + +from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd + +from ._utils import copy_kv_to_mem_cache + +try: + from vllm import layernorm_ops, pos_encoding_ops + + rms_norm = layernorm_ops.rms_norm + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + HAS_VLLM_KERNERL = True +except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + print( + "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch" + ) + HAS_VLLM_KERNERL = False + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaInferenceForwards: + """ + This class holds forwards for llama inference. + We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM. + """ + + @staticmethod + def llama_causal_lm_forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: BatchInferState = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + + # If is first stage and after warmup, go throught lm_head first + if stage_manager.is_first_stage() and hidden_states is not None: + lm_logits = self.lm_head(hidden_states) + return {'logits': lm_logits} + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = LlamaInferenceForwards.llama_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + + return outputs + + @staticmethod + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: BatchInferState = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + # batch_size = input_ids.shape[0] # input_ids.shape[0] + # print(f"[Before] rank:{torch.distributed.get_rank()}\n->{infer_state}") + + # infer_state = self.infer_state + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager is None or stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + assert stage_manager is not None + assert hidden_states is not None, f"hidden_state should not be none in stage {stage_manager.stage}" + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + # NOT READY FOR PRIME TIME + # dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if use_cache and seq_length != 1: + # NOTE assume prefill stage + # allocate memory block + infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + infer_state.is_context_stage = False + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + new_shape = [1] * position_ids.dim() + new_shape[0] = batch_size + position_ids = position_ids.repeat(*new_shape).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if infer_state.is_context_stage: + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + infer_state.decode_layer_id = 0 + + start_idx, end_idx = stage_index[0], stage_index[1] + if past_key_values is None: + past_key_values = tuple([None] * (end_idx - start_idx + 1)) + + for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values): + decoder_layer = self.layers[idx] + # NOTE: modify here for passing args to decoder layer + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + infer_state.decode_layer_id += 1 + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + next_cache = next_decoder_cache if use_cache else None + + # update indices + # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + + # TODO: fix this to necessary return + # if not return_dict: + # return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + # return BaseModelOutputWithPast( + # last_hidden_state=hidden_states, + # past_key_values=next_cache, + # hidden_states=all_hidden_states, + # attentions=all_self_attns, + # ) + # print(f"[After] rank:{torch.distributed.get_rank()}\n->{infer_state}") + return {'hidden_states': hidden_states, 'past_key_values': next_cache} + + + @staticmethod + def llama_decoder_layer_forward( + self: LlamaDecoderLayer, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + @staticmethod + def llama_flash_attn_kvcache_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + assert use_cache is True, "use_cache should be set to True using this llama attention" + + bsz, q_len, _ = hidden_states.size() + + # NOTE might think about better way to handle transposed k and v + # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] + # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + + # NOTE might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len + + cos, sin = infer_state.position_cos, infer_state.position_sin + + rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + key_states = key_states.reshape(-1, self.num_heads, self.head_dim) + value_states = value_states.reshape(-1, self.num_heads, self.head_dim) + + if infer_state.is_context_stage: + # print(f"rank:{torch.distributed.get_rank()}, {infer_state}") + # first token generation + + # copy key and value calculated in current step to memory manager + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + + attn_output = torch.empty_like(query_states) + + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + # print(f"[After prefill] rank:{torch.distributed.get_rank()}, {infer_state}") + + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_states) + cache_v.copy_(value_states) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # second token and follows + # kv = torch.stack((key_states, value_states), dim=2) + # (batch_size, seqlen, nheads, headdim) + attn_output = torch.empty_like(query_states) + + # print(f"rank:{torch.distributed.get_rank()}, {infer_state}") + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + # print(f"rank:{torch.distributed.get_rank()}, {attn_output}") + attn_output = self.o_proj(attn_output) + + # return past_key_value as None + return attn_output, None, None + + +def get_llama_vllm_rmsnorm_forward(): + if HAS_VLLM_KERNERL: + + def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + + return out + + return _vllm_rmsnorm_forward + else: + return None diff --git a/colossalai/inference/cai_engine/policies/__init__.py b/colossalai/inference/cai_engine/policies/__init__.py new file mode 100644 index 000000000000..7271812c5366 --- /dev/null +++ b/colossalai/inference/cai_engine/policies/__init__.py @@ -0,0 +1,3 @@ +from .llama import LlamaModelInferPolicy + +__all__ = ["LlamaModelInferPolicy"] diff --git a/colossalai/inference/cai_engine/policies/llama.py b/colossalai/inference/cai_engine/policies/llama.py new file mode 100644 index 000000000000..865d36d7a601 --- /dev/null +++ b/colossalai/inference/cai_engine/policies/llama.py @@ -0,0 +1,139 @@ +from functools import partial +from typing import List +from torch.nn import Module + +import torch +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm, LlamaForCausalLM + +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription + +# import colossalai +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + +from ..modeling._utils import init_to_get_rotary +from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward + +try: + from colossalai.kernel.triton import rmsnorm_forward + + HAS_TRITON_RMSNORM = True +except: + print("you should install triton from https://github.com/openai/triton") + HAS_TRITON_RMSNORM = False + + +def get_triton_rmsnorm_forward(): + if HAS_TRITON_RMSNORM: + + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) + + return _triton_rmsnorm_forward + else: + return None + + +class LlamaModelInferPolicy(LlamaForCausalLMPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.inference_gptq: + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + ], + ) + + self.shard_config._infer() + + infer_forward = LlamaInferenceForwards.llama_model_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaDecoderLayer + ) + + infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaAttention + ) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward(model_cls=LlamaForCausalLM, + new_forward=LlamaInferenceForwards.llama_causal_lm_forward, + policy=policy) + infer_forward = None + if HAS_TRITON_RMSNORM: + infer_forward = get_triton_rmsnorm_forward() + else: + # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123 + infer_forward = get_llama_vllm_rmsnorm_forward() + + if infer_forward is not None: + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaRMSNorm + ) + + return policy + + def postprocess(self): + init_to_get_rotary(self.model.model) + return self.model + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_first_stage(): + held_layers.append(self.model.lm_head) + return held_layers diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py index 41af9f3ef948..98711b0b18d6 100644 --- a/colossalai/inference/pipeline/__init__.py +++ b/colossalai/inference/pipeline/__init__.py @@ -1,3 +1,3 @@ -from .engine import PPInferEngine +from .microbatch_manager import MicroBatchManager -__all__ = ["PPInferEngine"] +__all__ = ['MicroBatchManager'] diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index 49d1bf3f42cb..ca33c26079a0 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -1,6 +1,7 @@ from enum import Enum from typing import Dict, Tuple - +from ..tensor_parallel.batch_infer_state import BatchInferState +from ..tensor_parallel.kvcache_manager import MemoryManager import torch __all__ = "MicroBatchManager" @@ -27,21 +28,22 @@ class MicroBatchDescription: def __init__( self, inputs_dict: Dict[str, torch.Tensor], - output_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, new_length: int, ) -> None: - assert output_dict.get("hidden_states") is not None - self.mb_length = output_dict["hidden_states"].shape[-2] + self.mb_length = inputs_dict['input_ids'].shape[-1] self.target_length = self.mb_length + new_length - self.kv_cache = () - - def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): - if output_dict is not None: - self._update_kvcache(output_dict["past_key_values"]) + self.infer_state = BatchInferState.init_from_batch( + batch=inputs_dict, + max_input_len=max_input_len, + max_output_len=max_output_len, + cache_manager=cache_manager) + # print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}") - def _update_kvcache(self, kv_cache: Tuple): - assert type(kv_cache) == tuple - self.kv_cache = kv_cache + def update(self, *args, **kwargs): + pass @property def state(self): @@ -80,17 +82,21 @@ class HeadMicroBatchDescription(MicroBatchDescription): """ def __init__( - self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int - ) -> None: - super().__init__(inputs_dict, output_dict, new_length) + self, + inputs_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + new_length: int + ) -> None: + super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length) assert inputs_dict is not None assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None self.input_ids = inputs_dict["input_ids"] self.attn_mask = inputs_dict["attention_mask"] self.new_tokens = None - def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): - super().update(output_dict, new_token) + def update(self, new_token: torch.Tensor = None): if new_token is not None: self._update_newtokens(new_token) if self.state is not Status.DONE and new_token is not None: @@ -125,27 +131,25 @@ class BodyMicroBatchDescription(MicroBatchDescription): Args: inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. - output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. """ def __init__( - self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int - ) -> None: - super().__init__(inputs_dict, output_dict, new_length) - - def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): - super().update(output_dict, new_token) - + self, + inputs_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + new_length: int + ) -> None: + super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length) + @property def cur_length(self): """ When there is no kv_cache, the length is mb_length, otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1 """ - if len(self.kv_cache) == 0: - return self.mb_length - else: - return self.kv_cache[0][0].shape[-2] + 1 + return self.infer_state.seq_len.max().item() class MicroBatchManager: @@ -160,16 +164,34 @@ class MicroBatchManager: """ - def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): + def __init__( + self, + stage: int, + new_length: int, + micro_batch_size: int, + micro_batch_buffer_size: int, + max_input_len: int, + max_output_len: int, + cache_manager_list: MemoryManager + ): self.stage = stage self.new_length = new_length self.micro_batch_size = micro_batch_size self.buffer_size = micro_batch_buffer_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.cache_manager_list = cache_manager_list self.mb_descrption_buffer = {} self.new_tokens_buffer = {} self.idx = 0 - def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): + def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): + if self.stage == 0: + self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length) + else: + self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length) + + def step(self, new_token: torch.Tensor = None): """ Update the state if microbatch manager, 2 conditions. 1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. @@ -181,11 +203,7 @@ def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, ne new_token (torch.Tensor): the new token generated by current stage. """ # Add descrption first if the descrption is None - if inputs_dict is None and output_dict is None and new_token is None: - return Status.PREFILL - if self.mb_descrption_buffer.get(self.idx) is None: - self._add_descrption(inputs_dict, output_dict) - self.cur_descrption.update(output_dict, new_token) + self.cur_descrption.update(new_token) return self.cur_state def export_new_tokens(self): @@ -204,16 +222,12 @@ def is_micro_batch_done(self): def clear(self): self.mb_descrption_buffer.clear() + for cache in self.cache_manager_list: + cache.free_all() def next(self): self.idx = (self.idx + 1) % self.buffer_size - def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]): - if self.stage == 0: - self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length) - else: - self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length) - def _remove_descrption(self): self.mb_descrption_buffer.pop(self.idx) @@ -222,10 +236,10 @@ def cur_descrption(self) -> MicroBatchDescription: return self.mb_descrption_buffer.get(self.idx) @property - def cur_kv_cache(self): + def cur_infer_state(self): if self.cur_descrption is None: return None - return self.cur_descrption.kv_cache + return self.cur_descrption.infer_state @property def cur_state(self): diff --git a/colossalai/inference/pipeline/modeling/__init__.py b/colossalai/inference/pipeline/modeling/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/colossalai/inference/pipeline/modeling/gpt2.py b/colossalai/inference/pipeline/modeling/gpt2.py deleted file mode 100644 index d2bfcb8b6842..000000000000 --- a/colossalai/inference/pipeline/modeling/gpt2.py +++ /dev/null @@ -1,280 +0,0 @@ -from typing import Dict, List, Optional, Tuple, Union - -import torch -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions -from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model -from transformers.utils import logging - -from colossalai.pipeline.stage_manager import PipelineStageManager - - -class GPT2PipelineForwards: - """ - This class serves as a micro library for forward function substitution of GPT2 models - under pipeline setting. - """ - - @staticmethod - def gpt2_model_forward( - self: GPT2Model, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: - # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. - # Please refer to original code of transformers for more details. - logger = logging.get_logger(__name__) - - # Preprocess passed in arguments - if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False - - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - else: - if hidden_states is None: - raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape[0], input_shape[1] - device = hidden_states.device - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if stage_manager.is_first_stage(): - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - else: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - hidden_states = self.drop(hidden_states) - - output_shape = input_shape + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - - # Going through held blocks. - start_idx, end_idx = stage_index[0], stage_index[1] - for i, layer_past in zip(range(start_idx, end_idx), past_key_values): - block = self.h[i] - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - if stage_manager.is_last_stage(): - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - return {"hidden_states": hidden_states, "past_key_values": presents} - - @staticmethod - def gpt2_lmhead_model_forward( - self: GPT2LMHeadModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - - This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. - Please refer to original code of transformers for more details. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # If is first stage and after warmup, go throught lm_head first - if stage_manager.is_first_stage() and hidden_states is not None: - lm_logits = self.lm_head(hidden_states) - return {"logits": lm_logits} - - # Not first stage or before warmup, go through gpt2 model - outputs = GPT2PipelineForwards.gpt2_model_forward( - self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) - - return outputs diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py deleted file mode 100644 index f46e1fbdd7b3..000000000000 --- a/colossalai/inference/pipeline/modeling/llama.py +++ /dev/null @@ -1,229 +0,0 @@ -from typing import List, Optional - -import torch -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel -from transformers.utils import logging - -from colossalai.pipeline.stage_manager import PipelineStageManager - - -class LlamaPipelineForwards: - """ - This class serves as a micro library for forward function substitution of Llama models - under pipeline setting. - """ - - def llama_model_forward( - self: LlamaModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ): - logger = logging.get_logger(__name__) - - # Preprocess passed in arguments - if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False - - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs_embeds - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - # embed positions, for the first stage, hidden_states is the input embeddings, - # for the other stages, hidden_states is the output of the previous stage - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - start_idx, end_idx = stage_index[0], stage_index[1] - if past_key_values is None: - past_key_values = tuple([None] * (end_idx - start_idx + 1)) - - for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values): - decoder_layer = self.layers[idx] - if output_hidden_states: - all_hidden_states += (hidden_states,) - - # past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if stage_manager.is_last_stage(): - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - - # always return dict for imediate stage - return {"hidden_states": hidden_states, "past_key_values": next_cache} - - def llama_for_causal_lm_forward( - self: LlamaForCausalLM, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ): - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" - logger = logging.get_logger(__name__) - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False - - # If is first stage and after warmup, go throught lm_head first - if stage_manager.is_first_stage() and hidden_states is not None: - lm_logits = self.lm_head(hidden_states) - return {"logits": lm_logits} - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = LlamaPipelineForwards.llama_model_forward( - self.model, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) - - return outputs diff --git a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py deleted file mode 100644 index 51e6425b113e..000000000000 --- a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py +++ /dev/null @@ -1,74 +0,0 @@ -from functools import partial -from typing import Callable, Dict, List - -from torch import Tensor, nn - -import colossalai.shardformer.layer as col_nn -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -from colossalai.shardformer.policies.gpt2 import GPT2Policy - -from ..modeling.gpt2 import GPT2PipelineForwards - - -class GPT2LMHeadModelPipelinePolicy(GPT2Policy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel - - module_policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - addon_module = { - GPT2LMHeadModel: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} - ) - ] - ) - } - module_policy.update(addon_module) - - if self.pipeline_stage_manager is not None: - self.set_pipeline_forward( - model_cls=GPT2LMHeadModel, - new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, - policy=module_policy, - ) - return module_policy - - def get_held_layers(self) -> List[nn.Module]: - held_layers = super().get_held_layers() - # make the tie weight lm_head and embedding in the same device to save memory - # if self.pipeline_stage_manager.is_first_stage(): - if self.pipeline_stage_manager.is_first_stage(): - held_layers.append(self.model.lm_head) - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """The weights of wte and lm_head are shared.""" - module = self.model - stage_manager = self.pipeline_stage_manager - if stage_manager is not None: - if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): - first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] - return [] - - def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" - if not self.pipeline_stage_manager: - raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") - stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "GPT2Model": - module = self.model - else: - module = self.model.transformer - - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/inference/pipeline/policy/llama_ppinfer.py b/colossalai/inference/pipeline/policy/llama_ppinfer.py deleted file mode 100644 index 6e12ed61bf7b..000000000000 --- a/colossalai/inference/pipeline/policy/llama_ppinfer.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import List - -from torch.nn import Module - -from colossalai.shardformer.layer import Linear1D_Col -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription -from colossalai.shardformer.policies.llama import LlamaPolicy - -from ..modeling.llama import LlamaPipelineForwards - - -class LlamaForCausalLMPipelinePolicy(LlamaPolicy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers import LlamaForCausalLM - - policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm - new_item = { - LlamaForCausalLM: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) - ) - ] - ) - } - policy.update(new_item) - - if self.pipeline_stage_manager: - # set None as default - self.set_pipeline_forward( - model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy - ) - - return policy - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - stage_manager = self.pipeline_stage_manager - held_layers = super().get_held_layers() - if stage_manager.is_first_stage(): - held_layers.append(self.model.lm_head) - return held_layers diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 1f4bbe9f8dad..66f5e66cef3d 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -93,9 +93,9 @@ def _prepare_inputs_for_interval_stage(self): Returns: dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` """ - model_inputs = ( - {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None - ) + model_inputs = { + 'infer_state': self.mb_manager.cur_descrption.infer_state + } return model_inputs def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): @@ -108,9 +108,9 @@ def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}` """ new_mask = self.mb_manager.cur_descrption.attn_mask - past_key_values = self.mb_manager.cur_descrption.kv_cache + past_key_values = self.mb_manager.cur_descrption.infer_state - return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values) + return dict(input_ids=new_token, attention_mask=new_mask) def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor: last_hidden_state = hidden_state[:, -1] @@ -128,27 +128,38 @@ def _recv_pre_stage(self) -> Any: return self.comm.p2p_recv() return self.comm.recv_forward() + def _init_infer_state_action(self) -> None: + """ + This action is only for no first stage, to load batch and init infer_state. + 1.Load micro_batch 2.Use the current micro_batch to init the current infer_state + """ + inputs_dict = self.load_micro_batch() + self.mb_manager.add_descrption(inputs_dict) + def _load_stage_action(self, model: Module) -> None: """ - In this action, 1.load micro_batch 2.do the forward 3.step to update + This action is only for first stage, load, init and do forward. + 1.load micro_batch 2.do the forward 3.step to update """ inputs_dict = self.load_micro_batch() + self.mb_manager.add_descrption(inputs_dict) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - output_dict = model_forward(model, inputs_dict, None) + interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} + output_dict = model_forward(model, inputs_dict, interval_inputs) - self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict["hidden_states"] + self.action_interval_buffer.hidden_states = output_dict['hidden_states'] def _gen_token_action(self, model: Module): """ - In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update + This action is only for first stage + 1.do the forward with hidden_states to generate new tokens 2.step to update """ hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" - hidden_states = {"hidden_states": hidden_states} - logits = model_forward(model, None, hidden_states) + interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state} + logits = model_forward(model, None, interval_inputs) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) @@ -157,7 +168,7 @@ def _gen_token_action(self, model: Module): ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" new_token = self._get_token_id(logits["logits"]) - self.mb_manager.step(None, None, new_token) + self.mb_manager.step(new_token) self.action_interval_buffer.new_token = new_token self.action_interval_buffer.hidden_states = None @@ -168,20 +179,19 @@ def _head_encoding_action(self, model: Module): new_token = self.action_interval_buffer.new_token assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None" inputs_dict = self._prepare_inputs_for_new_token(new_token) - output_dict = model_forward(model, inputs_dict, None) + interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} + output_dict = model_forward(model, inputs_dict, interval_inputs) - self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict["hidden_states"] + self.action_interval_buffer.hidden_states = output_dict['hidden_states'] def _body_encoding_action(self, model: Module): hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When not first stage, the hidden states should not be None" - inputs_dict = self._prepare_inputs_for_interval_stage() - hidden_states = {"hidden_states": hidden_states} - output_dict = model_forward(model, inputs_dict, hidden_states) + # inputs_dict = self._prepare_inputs_for_interval_stage() + interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state} + output_dict = model_forward(model, None, interval_inputs) - self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict["hidden_states"] + self.action_interval_buffer.hidden_states = output_dict['hidden_states'] def _comm_action(self, recv_pre: bool) -> torch.Tensor: """ @@ -218,6 +228,8 @@ def _gen_action(self, model: Module): actions.append(partial(self._gen_token_action, model)) # other stage else: + if self.mb_manager.cur_state is Status.PREFILL: + actions.append(partial(self._init_infer_state_action)) actions.append(partial(self._comm_action, True)) actions.append(partial(self._body_encoding_action, model)) @@ -309,7 +321,6 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) output_dict = model_forward(model, inputs_dict, None) - self.mb_manager.step(inputs_dict, output_dict, None) # In GENERATE phase else: # Get hidden_states from previous stage @@ -323,21 +334,17 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - assert ( - "logits" in logits - ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" - new_token = self._get_token_id(logits["logits"]) - self.mb_manager.step(None, None, new_token) + assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + new_token = self._get_token_id(logits['logits']) + self.mb_manager.step(new_token) # If the current micro batch is not DONE, go through blocks if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN): inputs_dict = self._prepare_inputs_for_new_token(new_token) output_dict = model_forward(model, inputs_dict, None) - self.mb_manager.step(inputs_dict, output_dict, None) else: assert hidden_states is not None, "When not first stage, the hidden states should not be None" inputs_dict = self._prepare_inputs_for_interval_stage() output_dict = model_forward(model, inputs_dict, hidden_states) - self.mb_manager.step(inputs_dict, output_dict, None) # Current microbatch is not DONE, send hidden_state to next stage if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in ( diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index ad8e32b48bae..b5213c905141 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -4,8 +4,8 @@ import transformers import colossalai -from colossalai.inference.pipeline.engine import PPInferEngine -from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy +from colossalai.inference import CaiInferEngine +from colossalai.inference.cai_engine.policies import LlamaModelInferPolicy from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -24,22 +24,21 @@ def data_gen(): def pipeline_inference_test(pp_size, new_length, micro_batch_size): - model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8)) - engine = PPInferEngine( - pp_size=pp_size, - model=model, - model_policy=GPT2LMHeadModelPipelinePolicy(), - new_length=new_length, - micro_batch_size=micro_batch_size, - ) + model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4)) + + engine = CaiInferEngine(pp_size=pp_size, + model=model, + model_policy=LlamaModelInferPolicy(), + new_length=new_length, + micro_batch_size=micro_batch_size) output = engine.inference([inputs]) if dist.get_rank() == 0: assert len(output[0]) == new_length, f"{len(output)}, {new_length}" -@parameterize("pp_size", [4]) -@parameterize("new_length", [4, 8, 16]) -@parameterize("micro_batch_size", [1, 4]) +@parameterize('pp_size', [2]) +@parameterize('new_length', [4, 8, 16]) +@parameterize('micro_batch_size', [1, 4]) @clear_cache_before_run() def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): pipeline_inference_test(pp_size, new_length, micro_batch_size) @@ -55,7 +54,7 @@ def check_pipeline_inference(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_pipeline_inference(): - spawn(check_pipeline_inference, nprocs=4) + spawn(check_pipeline_inference, nprocs=2) if __name__ == "__main__": From c63ff2b4aaf95e93acf7d3524c7678db457d4b8e Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Wed, 18 Oct 2023 15:21:14 +0800 Subject: [PATCH 02/14] support ppsize > 2 --- colossalai/pipeline/schedule/generate.py | 21 +++++++++++++------- colossalai/shardformer/shard/shard_config.py | 3 ++- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 66f5e66cef3d..db02dab59ca6 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -108,7 +108,6 @@ def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}` """ new_mask = self.mb_manager.cur_descrption.attn_mask - past_key_values = self.mb_manager.cur_descrption.infer_state return dict(input_ids=new_token, attention_mask=new_mask) @@ -187,7 +186,6 @@ def _head_encoding_action(self, model: Module): def _body_encoding_action(self, model: Module): hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When not first stage, the hidden states should not be None" - # inputs_dict = self._prepare_inputs_for_interval_stage() interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state} output_dict = model_forward(model, None, interval_inputs) @@ -320,7 +318,9 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - output_dict = model_forward(model, inputs_dict, None) + self.mb_manager.add_descrption(inputs_dict) + interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} + output_dict = model_forward(model, inputs_dict, interval_inputs) # In GENERATE phase else: # Get hidden_states from previous stage @@ -330,7 +330,8 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t assert ( hidden_states is not None ), "When first stage in GENERATE phase, the hidden states should not be None" - logits = model_forward(model, None, hidden_states) + interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state} + logits = model_forward(model, None, interval_inputs) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) @@ -340,11 +341,17 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t # If the current micro batch is not DONE, go through blocks if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN): inputs_dict = self._prepare_inputs_for_new_token(new_token) - output_dict = model_forward(model, inputs_dict, None) + interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} + output_dict = model_forward(model, inputs_dict, interval_inputs) else: assert hidden_states is not None, "When not first stage, the hidden states should not be None" - inputs_dict = self._prepare_inputs_for_interval_stage() - output_dict = model_forward(model, inputs_dict, hidden_states) + # inputs_dict = self._prepare_inputs_for_interval_stage() + inputs_dict = None + if self.mb_manager.cur_state is Status.PREFILL: + inputs_dict = self.load_micro_batch() + self.mb_manager.add_descrption(inputs_dict) + interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state} + output_dict = model_forward(model, inputs_dict, interval_inputs) # Current microbatch is not DONE, send hidden_state to next stage if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in ( diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index a285874d218b..2aa6139836a5 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -76,4 +76,5 @@ def _infer(self): """ Set default params for inference. """ - assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now" + # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now" + pass From 50979bdda6206bf6091d9585208dc631e272400a Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Wed, 18 Oct 2023 16:11:07 +0800 Subject: [PATCH 03/14] refactor pipeline code --- colossalai/inference/__init__.py | 4 ++-- colossalai/inference/cai_engine/__init__.py | 3 --- colossalai/inference/pipeline/__init__.py | 4 ++-- .../inference/{cai_engine => pipeline}/batch_infer_state.py | 0 colossalai/inference/{cai_engine => pipeline}/engine.py | 4 ++-- .../inference/{cai_engine => pipeline}/kvcache_manager.py | 0 colossalai/inference/pipeline/microbatch_manager.py | 4 ++-- .../inference/{cai_engine => pipeline}/modeling/__init__.py | 0 .../inference/{cai_engine => pipeline}/modeling/_utils.py | 0 .../inference/{cai_engine => pipeline}/modeling/llama.py | 0 .../inference/{cai_engine => pipeline}/policies/__init__.py | 0 .../inference/{cai_engine => pipeline}/policies/llama.py | 0 tests/test_infer/test_pipeline_infer.py | 6 +++--- 13 files changed, 11 insertions(+), 14 deletions(-) delete mode 100644 colossalai/inference/cai_engine/__init__.py rename colossalai/inference/{cai_engine => pipeline}/batch_infer_state.py (100%) rename colossalai/inference/{cai_engine => pipeline}/engine.py (98%) rename colossalai/inference/{cai_engine => pipeline}/kvcache_manager.py (100%) rename colossalai/inference/{cai_engine => pipeline}/modeling/__init__.py (100%) rename colossalai/inference/{cai_engine => pipeline}/modeling/_utils.py (100%) rename colossalai/inference/{cai_engine => pipeline}/modeling/llama.py (100%) rename colossalai/inference/{cai_engine => pipeline}/policies/__init__.py (100%) rename colossalai/inference/{cai_engine => pipeline}/policies/llama.py (100%) diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index 5975de786d46..761e48e5917a 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -1,4 +1,4 @@ -from .cai_engine import CaiInferEngine +from .pipeline import PPInferEngine -__all__ = ['CaiInferEngine'] +__all__ = ['PPInferEngine'] diff --git a/colossalai/inference/cai_engine/__init__.py b/colossalai/inference/cai_engine/__init__.py deleted file mode 100644 index a80de0898d06..000000000000 --- a/colossalai/inference/cai_engine/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .engine import CaiInferEngine - -__all__ = ['CaiInferEngine'] diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py index 98711b0b18d6..aff4568f7d08 100644 --- a/colossalai/inference/pipeline/__init__.py +++ b/colossalai/inference/pipeline/__init__.py @@ -1,3 +1,3 @@ -from .microbatch_manager import MicroBatchManager +from .engine import PPInferEngine -__all__ = ['MicroBatchManager'] +__all__ = ['PPInferEngine'] diff --git a/colossalai/inference/cai_engine/batch_infer_state.py b/colossalai/inference/pipeline/batch_infer_state.py similarity index 100% rename from colossalai/inference/cai_engine/batch_infer_state.py rename to colossalai/inference/pipeline/batch_infer_state.py diff --git a/colossalai/inference/cai_engine/engine.py b/colossalai/inference/pipeline/engine.py similarity index 98% rename from colossalai/inference/cai_engine/engine.py rename to colossalai/inference/pipeline/engine.py index dc52ac699b5f..0e3d4b767db6 100644 --- a/colossalai/inference/cai_engine/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -7,11 +7,11 @@ from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy -from ..pipeline import MicroBatchManager +from .microbatch_manager import MicroBatchManager from .kvcache_manager import MemoryManager -class CaiInferEngine: +class PPInferEngine: """ PPInferEngine is a class that handles the pipeline parallel inference. diff --git a/colossalai/inference/cai_engine/kvcache_manager.py b/colossalai/inference/pipeline/kvcache_manager.py similarity index 100% rename from colossalai/inference/cai_engine/kvcache_manager.py rename to colossalai/inference/pipeline/kvcache_manager.py diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index ca33c26079a0..62a9cae152b0 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Dict, Tuple -from ..tensor_parallel.batch_infer_state import BatchInferState -from ..tensor_parallel.kvcache_manager import MemoryManager +from .batch_infer_state import BatchInferState +from .kvcache_manager import MemoryManager import torch __all__ = "MicroBatchManager" diff --git a/colossalai/inference/cai_engine/modeling/__init__.py b/colossalai/inference/pipeline/modeling/__init__.py similarity index 100% rename from colossalai/inference/cai_engine/modeling/__init__.py rename to colossalai/inference/pipeline/modeling/__init__.py diff --git a/colossalai/inference/cai_engine/modeling/_utils.py b/colossalai/inference/pipeline/modeling/_utils.py similarity index 100% rename from colossalai/inference/cai_engine/modeling/_utils.py rename to colossalai/inference/pipeline/modeling/_utils.py diff --git a/colossalai/inference/cai_engine/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py similarity index 100% rename from colossalai/inference/cai_engine/modeling/llama.py rename to colossalai/inference/pipeline/modeling/llama.py diff --git a/colossalai/inference/cai_engine/policies/__init__.py b/colossalai/inference/pipeline/policies/__init__.py similarity index 100% rename from colossalai/inference/cai_engine/policies/__init__.py rename to colossalai/inference/pipeline/policies/__init__.py diff --git a/colossalai/inference/cai_engine/policies/llama.py b/colossalai/inference/pipeline/policies/llama.py similarity index 100% rename from colossalai/inference/cai_engine/policies/llama.py rename to colossalai/inference/pipeline/policies/llama.py diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index b5213c905141..1d00a92bac68 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -4,8 +4,8 @@ import transformers import colossalai -from colossalai.inference import CaiInferEngine -from colossalai.inference.cai_engine.policies import LlamaModelInferPolicy +from colossalai.inference.pipeline import PPInferEngine +from colossalai.inference.pipeline.policies import LlamaModelInferPolicy from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -26,7 +26,7 @@ def data_gen(): def pipeline_inference_test(pp_size, new_length, micro_batch_size): model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4)) - engine = CaiInferEngine(pp_size=pp_size, + engine = PPInferEngine(pp_size=pp_size, model=model, model_policy=LlamaModelInferPolicy(), new_length=new_length, From 49817d0234681f38a4171fd9a4f6923362e4a0af Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Wed, 18 Oct 2023 16:20:31 +0800 Subject: [PATCH 04/14] do pre-commit --- colossalai/inference/pipeline/__init__.py | 2 +- .../inference/pipeline/batch_infer_state.py | 18 ++--- .../inference/pipeline/benchmark/benchmark.py | 4 +- colossalai/inference/pipeline/engine.py | 39 ++++++----- .../inference/pipeline/microbatch_manager.py | 70 ++++++++++--------- .../inference/pipeline/modeling/llama.py | 64 +++++++++-------- .../inference/pipeline/policies/llama.py | 18 +++-- 7 files changed, 116 insertions(+), 99 deletions(-) diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py index aff4568f7d08..41af9f3ef948 100644 --- a/colossalai/inference/pipeline/__init__.py +++ b/colossalai/inference/pipeline/__init__.py @@ -1,3 +1,3 @@ from .engine import PPInferEngine -__all__ = ['PPInferEngine'] +__all__ = ["PPInferEngine"] diff --git a/colossalai/inference/pipeline/batch_infer_state.py b/colossalai/inference/pipeline/batch_infer_state.py index ec70cb6cbfd0..3084d2473319 100644 --- a/colossalai/inference/pipeline/batch_infer_state.py +++ b/colossalai/inference/pipeline/batch_infer_state.py @@ -2,9 +2,9 @@ from dataclasses import dataclass import torch +from transformers.tokenization_utils_base import BatchEncoding from .kvcache_manager import MemoryManager -from transformers.tokenization_utils_base import BatchEncoding @dataclass @@ -63,7 +63,7 @@ def init_from_batch( max_input_len: int, max_output_len: int, cache_manager: MemoryManager, - ): + ): if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)): raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state") @@ -106,15 +106,15 @@ def init_from_batch( start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda") - + return cls( batch_size=batch_size, max_len_in_batch=max_len_in_batch, - seq_len=seq_lengths.to('cuda'), - start_loc = seq_start_indexes.to("cuda"), - block_loc = block_loc, - decode_layer_id = 0, - past_key_values_len = 0, - is_context_stage = True, + seq_len=seq_lengths.to("cuda"), + start_loc=seq_start_indexes.to("cuda"), + block_loc=block_loc, + decode_layer_id=0, + past_key_values_len=0, + is_context_stage=True, cache_manager=cache_manager, ) diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py index 9c47909f70f0..1e4714563b10 100644 --- a/colossalai/inference/pipeline/benchmark/benchmark.py +++ b/colossalai/inference/pipeline/benchmark/benchmark.py @@ -7,7 +7,7 @@ import colossalai from colossalai.inference import PPInferEngine -from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy +from colossalai.inference.pipeline.policies import LlamaModelInferPolicy GIGABYTE = 1024**3 MEGABYTE = 1024 * 1024 @@ -117,7 +117,7 @@ def print_details_info(timestamps, model_config, args, whole_end2end): micro_batch_size=args.mb_size, new_length=args.new_length, model=model, - model_policy=LlamaForCausalLMPipelinePolicy(), + model_policy=LlamaModelInferPolicy(), verbose=True, ) data = data_gen(args.batch_size, args.seq_len) diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index 0e3d4b767db6..c83361a454d9 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -7,8 +7,8 @@ from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy -from .microbatch_manager import MicroBatchManager from .kvcache_manager import MemoryManager +from .microbatch_manager import MicroBatchManager class PPInferEngine: @@ -62,15 +62,15 @@ def __init__( num_beams: int = 1, ) -> None: assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided." - assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'" - + assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" + max_output_len = max(max_output_len, max_input_len + new_length) self.pp_size = pp_size - if dtype == 'fp16': + if dtype == "fp16": self.dtype = torch.float16 model.half() - elif dtype == 'bf16': + elif dtype == "bf16": self.dtype = torch.bfloat16 model.to(torch.bfloat16) else: @@ -78,9 +78,18 @@ def __init__( self.pg_mesh = ProcessGroupMesh(pp_size) self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) self.model = pp_model or self._shardformer(model, model_policy) - self.cache_manager_list = [self._init_manager(max_batch_size, max_input_len, max_output_len)]*(micro_batch_buffer_size or pp_size) - self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size, - micro_batch_buffer_size or pp_size, max_input_len, max_output_len, self.cache_manager_list) + self.cache_manager_list = [self._init_manager(max_batch_size, max_input_len, max_output_len)] * ( + micro_batch_buffer_size or pp_size + ) + self.mb_manager = MicroBatchManager( + self.stage_manager.stage, + new_length, + micro_batch_size, + micro_batch_buffer_size or pp_size, + max_input_len, + max_output_len, + self.cache_manager_list, + ) self.verbose = verbose self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) @@ -111,15 +120,11 @@ def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads head_num = self.model.config.num_attention_heads num_hidden_layers = ( - self.model.config.num_hidden_layers if hasattr(self.model.config, "num_hidden_layers") else self.model.config.num_layers + self.model.config.num_hidden_layers + if hasattr(self.model.config, "num_hidden_layers") + else self.model.config.num_layers ) layer_num = num_hidden_layers // self.pp_size - cache_manager = MemoryManager( - max_total_token_num, - self.dtype, - head_num, - head_dim, - layer_num - ) - return cache_manager \ No newline at end of file + cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num) + return cache_manager diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index 62a9cae152b0..31671ccb32d5 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -1,8 +1,10 @@ from enum import Enum -from typing import Dict, Tuple +from typing import Dict + +import torch + from .batch_infer_state import BatchInferState from .kvcache_manager import MemoryManager -import torch __all__ = "MicroBatchManager" @@ -33,13 +35,11 @@ def __init__( cache_manager: MemoryManager, new_length: int, ) -> None: - self.mb_length = inputs_dict['input_ids'].shape[-1] + self.mb_length = inputs_dict["input_ids"].shape[-1] self.target_length = self.mb_length + new_length self.infer_state = BatchInferState.init_from_batch( - batch=inputs_dict, - max_input_len=max_input_len, - max_output_len=max_output_len, - cache_manager=cache_manager) + batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager + ) # print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}") def update(self, *args, **kwargs): @@ -82,13 +82,13 @@ class HeadMicroBatchDescription(MicroBatchDescription): """ def __init__( - self, - inputs_dict: Dict[str, torch.Tensor], - max_input_len: int, - max_output_len: int, - cache_manager: MemoryManager, - new_length: int - ) -> None: + self, + inputs_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + new_length: int, + ) -> None: super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length) assert inputs_dict is not None assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None @@ -134,15 +134,15 @@ class BodyMicroBatchDescription(MicroBatchDescription): """ def __init__( - self, - inputs_dict: Dict[str, torch.Tensor], - max_input_len: int, - max_output_len: int, - cache_manager: MemoryManager, - new_length: int - ) -> None: + self, + inputs_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + new_length: int, + ) -> None: super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length) - + @property def cur_length(self): """ @@ -165,15 +165,15 @@ class MicroBatchManager: """ def __init__( - self, - stage: int, - new_length: int, - micro_batch_size: int, - micro_batch_buffer_size: int, - max_input_len: int, - max_output_len: int, - cache_manager_list: MemoryManager - ): + self, + stage: int, + new_length: int, + micro_batch_size: int, + micro_batch_buffer_size: int, + max_input_len: int, + max_output_len: int, + cache_manager_list: MemoryManager, + ): self.stage = stage self.new_length = new_length self.micro_batch_size = micro_batch_size @@ -187,9 +187,13 @@ def __init__( def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): if self.stage == 0: - self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length) + self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription( + inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length + ) else: - self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length) + self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription( + inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length + ) def step(self, new_token: torch.Tensor = None): """ diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py index 00907ab5f2b3..41c304461b01 100644 --- a/colossalai/inference/pipeline/modeling/llama.py +++ b/colossalai/inference/pipeline/modeling/llama.py @@ -1,14 +1,18 @@ from typing import List, Optional, Tuple import torch -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, +) from transformers.utils import logging from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.pipeline.stage_manager import PipelineStageManager - from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd +from colossalai.pipeline.stage_manager import PipelineStageManager from ._utils import copy_kv_to_mem_cache @@ -71,45 +75,45 @@ def llama_causal_lm_forward( stage_index: Optional[List[int]] = None, ): r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - Returns: + Returns: - Example: + Example: - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False # If is first stage and after warmup, go throught lm_head first if stage_manager.is_first_stage() and hidden_states is not None: lm_logits = self.lm_head(hidden_states) - return {'logits': lm_logits} + return {"logits": lm_logits} # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaInferenceForwards.llama_model_forward( @@ -239,7 +243,6 @@ def llama_model_forward( infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - # embed positions if attention_mask is None: attention_mask = torch.ones( @@ -251,8 +254,8 @@ def llama_model_forward( ) # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None + () if output_hidden_states else None + () if output_attentions else None next_decoder_cache = () if use_cache else None infer_state.decode_layer_id = 0 @@ -278,7 +281,7 @@ def llama_model_forward( if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - + if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) next_cache = next_decoder_cache if use_cache else None @@ -299,8 +302,7 @@ def llama_model_forward( # attentions=all_self_attns, # ) # print(f"[After] rank:{torch.distributed.get_rank()}\n->{infer_state}") - return {'hidden_states': hidden_states, 'past_key_values': next_cache} - + return {"hidden_states": hidden_states, "past_key_values": next_cache} @staticmethod def llama_decoder_layer_forward( diff --git a/colossalai/inference/pipeline/policies/llama.py b/colossalai/inference/pipeline/policies/llama.py index 865d36d7a601..9f8c93c61234 100644 --- a/colossalai/inference/pipeline/policies/llama.py +++ b/colossalai/inference/pipeline/policies/llama.py @@ -1,9 +1,15 @@ from functools import partial from typing import List -from torch.nn import Module import torch -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm, LlamaForCausalLM +from torch.nn import Module +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, +) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription @@ -108,9 +114,9 @@ def module_policy(self): if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=LlamaForCausalLM, - new_forward=LlamaInferenceForwards.llama_causal_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy + ) infer_forward = None if HAS_TRITON_RMSNORM: infer_forward = get_triton_rmsnorm_forward() @@ -129,7 +135,7 @@ def module_policy(self): def postprocess(self): init_to_get_rotary(self.model.model) return self.model - + def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager From 8881c9c4d85eb8ddcf0c9a5d5cceb1fa526c0771 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Wed, 18 Oct 2023 16:31:03 +0800 Subject: [PATCH 05/14] modify benchmark --- colossalai/inference/pipeline/benchmark/benchmark.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py index 1e4714563b10..8342ac0a6dfe 100644 --- a/colossalai/inference/pipeline/benchmark/benchmark.py +++ b/colossalai/inference/pipeline/benchmark/benchmark.py @@ -119,6 +119,8 @@ def print_details_info(timestamps, model_config, args, whole_end2end): model=model, model_policy=LlamaModelInferPolicy(), verbose=True, + max_input_len=args.seq_len, + max_output_len=args.new_length, ) data = data_gen(args.batch_size, args.seq_len) From e18ceadaebd66a642ac35440db623833ea1fd2cc Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Wed, 18 Oct 2023 17:12:08 +0800 Subject: [PATCH 06/14] fix bench mark --- colossalai/inference/pipeline/benchmark/benchmark.py | 3 ++- colossalai/inference/pipeline/benchmark/run.sh | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py index 8342ac0a6dfe..6650862bf52f 100644 --- a/colossalai/inference/pipeline/benchmark/benchmark.py +++ b/colossalai/inference/pipeline/benchmark/benchmark.py @@ -119,8 +119,9 @@ def print_details_info(timestamps, model_config, args, whole_end2end): model=model, model_policy=LlamaModelInferPolicy(), verbose=True, + max_batch_size=args.batch_size, max_input_len=args.seq_len, - max_output_len=args.new_length, + max_output_len=args.seq_len + args.new_length + 256, ) data = data_gen(args.batch_size, args.seq_len) diff --git a/colossalai/inference/pipeline/benchmark/run.sh b/colossalai/inference/pipeline/benchmark/run.sh index 7d8da858692f..e3c33bb88db1 100644 --- a/colossalai/inference/pipeline/benchmark/run.sh +++ b/colossalai/inference/pipeline/benchmark/run.sh @@ -1,7 +1,7 @@ script_dir=$(cd "$(dirname "$0")" && pwd) cd "${script_dir}" -# 7b, fp32, 2 gpu, 1024, 128 +# 7b, fp16, 2 gpu, 1024, 128 for BATCH_SIZE in 2 4 8 16; do CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ --model="7b" \ @@ -13,7 +13,7 @@ for BATCH_SIZE in 2 4 8 16; do --pp_size=2 done -# 7b, fp32, 2 gpu, 512, 512 +# 7b, fp16, 2 gpu, 512, 512 for BATCH_SIZE in 2 4 8 16 32; do CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ --model="7b" \ @@ -25,7 +25,7 @@ for BATCH_SIZE in 2 4 8 16 32; do --pp_size=2 done -# 7b, fp32, 2 gpu, 1024, 128 +# 7b, fp16, 2 gpu, 1024, 128 for BATCH_SIZE in 2 4 8; do CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ --model="13b" \ From 5bd91068612ec46dc89929f3a972e8bd710f0a6f Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Wed, 18 Oct 2023 19:39:45 +0800 Subject: [PATCH 07/14] polish code --- colossalai/inference/pipeline/benchmark/benchmark.py | 2 +- colossalai/inference/pipeline/modeling/llama.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py index 6650862bf52f..8392d0a1e579 100644 --- a/colossalai/inference/pipeline/benchmark/benchmark.py +++ b/colossalai/inference/pipeline/benchmark/benchmark.py @@ -119,7 +119,7 @@ def print_details_info(timestamps, model_config, args, whole_end2end): model=model, model_policy=LlamaModelInferPolicy(), verbose=True, - max_batch_size=args.batch_size, + max_batch_size=args.mb_size, max_input_len=args.seq_len, max_output_len=args.seq_len + args.new_length + 256, ) diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py index 41c304461b01..0f2f68d07edf 100644 --- a/colossalai/inference/pipeline/modeling/llama.py +++ b/colossalai/inference/pipeline/modeling/llama.py @@ -410,7 +410,6 @@ def llama_flash_attn_kvcache_forward( infer_state.seq_len, infer_state.cache_manager.past_key_values_length, ) - # print(f"[After prefill] rank:{torch.distributed.get_rank()}, {infer_state}") else: if infer_state.decode_is_contiguous: @@ -439,7 +438,6 @@ def llama_flash_attn_kvcache_forward( # (batch_size, seqlen, nheads, headdim) attn_output = torch.empty_like(query_states) - # print(f"rank:{torch.distributed.get_rank()}, {infer_state}") token_attention_fwd( query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], From bd00085f21b427ade22604361a4ba776e1160931 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Thu, 19 Oct 2023 17:25:49 +0800 Subject: [PATCH 08/14] add docstring and update readme --- colossalai/inference/pipeline/README.md | 67 ++++++++++--------- colossalai/inference/pipeline/engine.py | 25 ++++--- .../inference/pipeline/modeling/llama.py | 20 +----- 3 files changed, 56 insertions(+), 56 deletions(-) diff --git a/colossalai/inference/pipeline/README.md b/colossalai/inference/pipeline/README.md index a90d5d6da182..ff670148f317 100644 --- a/colossalai/inference/pipeline/README.md +++ b/colossalai/inference/pipeline/README.md @@ -33,17 +33,22 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManag ```python from colossalai.pipeline import PPInferEngine # Suppose the pipeline size is 2, and use fp16 to do infenrence. Use Llama as an example. -model = LlamaForCausalLM.from_pretrained('/path/to/model') -inputs = tokenizer("Hello, my dog is cute", "What a good day", return_tensors="pt") -engine = PPInferEngine( - pp_size=2, - dtype='fp16', - micro_batch_size=1, - new_length=10, - model=model, - model_policy=LlamaForCausalLMPipelinePolicy()) - -output = engine.inference([inputs]) +from colossalai.inference import PPInferEngine +from colossalai.inference.pipeline.policies import LlamaModelInferPolicy +import colossalai +from transformers import LlamaForCausalLM, LlamaTokenizer + +colossalai.launch_from_torch(config={}) + +model = LlamaForCausalLM.from_pretrained("path_to_model") +tokenizer = LlamaTokenizer.from_pretrained("path_to_model") +# assume the model is infered with 2 pipeline stages +inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8) + +input = ["Introduce a landmark in China ","Introduce a landmark in China "] +data = tokenizer(input, return_tensors='pt') +output = inferengine.inference([data.to('cuda').data]) + ``` @@ -55,30 +60,32 @@ sh run.sh ## Performance -We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2*A10, 20G. +We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G. -### Llama Throughput(tokens/s) +### Llama Throughput (tokens/s) | input length=1024, output length=128 -#### 7b, fp16 +#### A10 7b, fp16 | batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)| | :---: | :---: | :---: | :---: | :---: | :---: | :---:| -| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM | -| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM | -| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 | -| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM | +| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM | +| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM | -#### 7b, fp32 +#### A10 13b, fp16 | batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | | :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 | -| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM | -| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 | -| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM | +| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 | +| Hugging Face | 23.48 | 37.59 | 53.44 | OOM | -#### 13b, fp16 -| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | -| :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 | -| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM | -| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 | -| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM | + +#### A800 7b, fp16 +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | +| :---: | :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 | +| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 | + + +#### A800 13b, fp16 +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | +| :---: | :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 | +| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 | diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index c83361a454d9..a94910865021 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -24,20 +24,29 @@ class PPInferEngine: micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. new_length (int): the new length of the input sequence. early_stopping (bool): whether to stop early. + max_batch_size (int): the maximum batch size. + max_input_len (int): the maximum input length. + max_output_len (int): the maximum output length. Example: ```python - from colossalai.ppinference import PPInferEngine - from transformers import GPT2LMHeadModel, GPT2Tokenizer + from colossalai.inference import PPInferEngine + from colossalai.inference.pipeline.policies import LlamaModelInferPolicy + import colossalai + from transformers import LlamaForCausalLM, LlamaTokenizer - model = transformers.GPT2LMHeadModel.from_pretrained('gpt2') - # assume the model is infered with 4 pipeline stages - inferengine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding}) + colossalai.launch_from_torch(config={}) + + model = LlamaForCausalLM.from_pretrained("/home/lczyh/share/models/llama-7b-hf") + tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf") + # assume the model is infered with 2 pipeline stages + inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8) + + input = ["Introduce a landmark in China ","Introduce a landmark in China "] + data = tokenizer(input, return_tensors='pt') + output = inferengine.inference([data.to('cuda').data]) - input = ["Hello, my dog is cute, and I like"] - tokenized_input = tokenizer(input, return_tensors='pt') - output = engine.inference([tokenized_input]) ``` """ diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py index 0f2f68d07edf..a914c80e3b70 100644 --- a/colossalai/inference/pipeline/modeling/llama.py +++ b/colossalai/inference/pipeline/modeling/llama.py @@ -1,3 +1,4 @@ +# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py from typing import List, Optional, Tuple import torch @@ -81,24 +82,7 @@ def llama_causal_lm_forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" + """ logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict From 64c1f4f9071899993ee84d894f2b717838f81e3e Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 20 Oct 2023 13:27:01 +0800 Subject: [PATCH 09/14] refactor the code --- colossalai/inference/pipeline/README.md | 2 +- .../inference/pipeline/batch_infer_state.py | 120 ------------------ colossalai/inference/pipeline/engine.py | 20 ++- .../inference/pipeline/kvcache_manager.py | 104 --------------- .../inference/pipeline/microbatch_manager.py | 4 +- .../inference/pipeline/modeling/llama.py | 14 +- colossalai/inference/pipeline/utils.py | 35 ----- .../tensor_parallel/batch_infer_state.py | 61 +++++++++ tests/test_infer/test_pipeline_infer.py | 20 +-- 9 files changed, 103 insertions(+), 277 deletions(-) delete mode 100644 colossalai/inference/pipeline/batch_infer_state.py delete mode 100644 colossalai/inference/pipeline/kvcache_manager.py delete mode 100644 colossalai/inference/pipeline/utils.py diff --git a/colossalai/inference/pipeline/README.md b/colossalai/inference/pipeline/README.md index ff670148f317..18f69e39ab18 100644 --- a/colossalai/inference/pipeline/README.md +++ b/colossalai/inference/pipeline/README.md @@ -47,7 +47,7 @@ inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInfer input = ["Introduce a landmark in China ","Introduce a landmark in China "] data = tokenizer(input, return_tensors='pt') -output = inferengine.inference([data.to('cuda').data]) +output = inferengine.inference(data.to('cuda')) ``` diff --git a/colossalai/inference/pipeline/batch_infer_state.py b/colossalai/inference/pipeline/batch_infer_state.py deleted file mode 100644 index 3084d2473319..000000000000 --- a/colossalai/inference/pipeline/batch_infer_state.py +++ /dev/null @@ -1,120 +0,0 @@ -# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later -from dataclasses import dataclass - -import torch -from transformers.tokenization_utils_base import BatchEncoding - -from .kvcache_manager import MemoryManager - - -@dataclass -class BatchInferState: - r""" - Information to be passed and used for a batch of inputs during - a single model forward - """ - batch_size: int - max_len_in_batch: int - - cache_manager: MemoryManager = None - - block_loc: torch.Tensor = None - start_loc: torch.Tensor = None - seq_len: torch.Tensor = None - past_key_values_len: int = None - - is_context_stage: bool = False - context_mem_index: torch.Tensor = None - decode_is_contiguous: bool = None - decode_mem_start: int = None - decode_mem_end: int = None - decode_mem_index: torch.Tensor = None - decode_layer_id: int = None - - device: torch.device = torch.device("cuda") - - @property - def total_token_num(self): - # return self.batch_size * self.max_len_in_batch - assert self.seq_len is not None and self.seq_len.size(0) > 0 - return int(torch.sum(self.seq_len)) - - def set_cache_manager(self, manager: MemoryManager): - self.cache_manager = manager - - @staticmethod - def init_block_loc( - b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor - ): - """in-place update block loc mapping based on the sequence length of the inputs in current bath""" - start_index = 0 - seq_len_numpy = seq_len.cpu().numpy() - for i, cur_seq_len in enumerate(seq_len_numpy): - b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[ - start_index : start_index + cur_seq_len - ] - start_index += cur_seq_len - return - - @classmethod - def init_from_batch( - cls, - batch: torch.Tensor, - max_input_len: int, - max_output_len: int, - cache_manager: MemoryManager, - ): - if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)): - raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state") - - input_ids_list = None - attention_mask = None - - if isinstance(batch, (BatchEncoding, dict)): - input_ids_list = batch["input_ids"] - attention_mask = batch["attention_mask"] - else: - input_ids_list = batch - if isinstance(input_ids_list[0], int): # for a single input - input_ids_list = [input_ids_list] - attention_mask = [attention_mask] if attention_mask is not None else attention_mask - - batch_size = len(input_ids_list) - - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - start_index = 0 - - max_len_in_batch = -1 - if isinstance(batch, (BatchEncoding, dict)): - for i, attn_mask in enumerate(attention_mask): - curr_seq_len = len(attn_mask) - # if isinstance(attn_mask, torch.Tensor): - # curr_seq_len = int(torch.sum(attn_mask)) - # else: - # curr_seq_len = int(sum(attn_mask)) - seq_lengths[i] = curr_seq_len - seq_start_indexes[i] = start_index - start_index += curr_seq_len - max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - else: - length = max(len(input_id) for input_id in input_ids_list) - for i, input_ids in enumerate(input_ids_list): - curr_seq_len = length - seq_lengths[i] = curr_seq_len - seq_start_indexes[i] = start_index - start_index += curr_seq_len - max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda") - - return cls( - batch_size=batch_size, - max_len_in_batch=max_len_in_batch, - seq_len=seq_lengths.to("cuda"), - start_loc=seq_start_indexes.to("cuda"), - block_loc=block_loc, - decode_layer_id=0, - past_key_values_len=0, - is_context_stage=True, - cache_manager=cache_manager, - ) diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index a94910865021..05b794600198 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from transformers.tokenization_utils_base import BatchEncoding from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.schedule.generate import GenerateSchedule @@ -7,7 +8,7 @@ from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy -from .kvcache_manager import MemoryManager +from ..tensor_parallel.kvcache_manager import MemoryManager from .microbatch_manager import MicroBatchManager @@ -38,7 +39,7 @@ class PPInferEngine: colossalai.launch_from_torch(config={}) - model = LlamaForCausalLM.from_pretrained("/home/lczyh/share/models/llama-7b-hf") + model = LlamaForCausalLM.from_pretrained("your_path_to_model") tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf") # assume the model is infered with 2 pipeline stages inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8) @@ -103,7 +104,20 @@ def __init__( self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) def inference(self, input_list): - out, timestamp = self.schedule.generate_step(self.model, iter(input_list)) + """ + Args: + input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`. + + Returns: + out (list): a list of output data, each element is a list of token. + timestamp (float): the time cost of the inference, only return when verbose is `True`. + """ + assert isinstance( + input_list, (BatchEncoding, dict) + ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}." + if isinstance(input_list, BatchEncoding): + input_list = input_list.data + out, timestamp = self.schedule.generate_step(self.model, iter([input_list])) if self.verbose: return out, timestamp else: diff --git a/colossalai/inference/pipeline/kvcache_manager.py b/colossalai/inference/pipeline/kvcache_manager.py deleted file mode 100644 index e74a3a491a7b..000000000000 --- a/colossalai/inference/pipeline/kvcache_manager.py +++ /dev/null @@ -1,104 +0,0 @@ -# Adapted from lightllm/common/mem_manager.py -# of the ModelTC/lightllm GitHub repository -# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py - -import torch -from transformers.utils import logging - - -class MemoryManager: - r""" - Manage token block indexes and allocate physical memory for key and value cache - - Args: - size: maximum token number used as the size of key and value buffer - dtype: data type of cached key and value - head_num: number of heads the memory manager is responsible for - head_dim: embedded size per head - layer_num: the number of layers in the model - device: device used to store the key and value cache - """ - - def __init__( - self, - size: int, - dtype: torch.dtype, - head_num: int, - head_dim: int, - layer_num: int, - device: torch.device = torch.device("cuda"), - ): - self.logger = logging.get_logger(__name__) - self.available_size = size - self.past_key_values_length = 0 - self._init_mem_states(size, device) - self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) - - def _init_mem_states(self, size, device): - """Initialize tensors used to manage memory states""" - self.mem_state = torch.ones((size,), dtype=torch.bool, device=device) - self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device) - self.indexes = torch.arange(0, size, dtype=torch.long, device=device) - - def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): - """Initialize key buffer and value buffer on specified device""" - self.key_buffer = [ - torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) - ] - self.value_buffer = [ - torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) - ] - - @torch.no_grad() - def alloc(self, required_size): - """allocate space of required_size by providing indexes representing available physical spaces""" - if required_size > self.available_size: - self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") - return None - torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) - select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) - select_index = self.indexes[select_index] - self.mem_state[select_index] = 0 - self.available_size -= len(select_index) - return select_index - - @torch.no_grad() - def alloc_contiguous(self, required_size): - """allocate contiguous space of required_size""" - if required_size > self.available_size: - self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") - return None - torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) - sum_size = len(self.mem_cum_sum) - loc_sums = ( - self.mem_cum_sum[required_size - 1 :] - - self.mem_cum_sum[0 : sum_size - required_size + 1] - + self.mem_state[0 : sum_size - required_size + 1] - ) - can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size] - if can_used_loc.shape[0] == 0: - self.logger.info( - f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}" - ) - return None - start_loc = can_used_loc[0] - select_index = self.indexes[start_loc : start_loc + required_size] - self.mem_state[select_index] = 0 - self.available_size -= len(select_index) - start = start_loc.item() - end = start + required_size - return select_index, start, end - - @torch.no_grad() - def free(self, free_index): - """free memory by updating memory states based on given indexes""" - self.available_size += free_index.shape[0] - self.mem_state[free_index] = 1 - - @torch.no_grad() - def free_all(self): - """free all memory by updating memory states""" - self.available_size = len(self.mem_state) - self.mem_state[:] = 1 - self.past_key_values_length = 0 - self.logger.info("freed all space of memory manager") diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index 31671ccb32d5..2bf52161d611 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -3,8 +3,8 @@ import torch -from .batch_infer_state import BatchInferState -from .kvcache_manager import MemoryManager +from ..tensor_parallel.batch_infer_state import BatchInferState +from ..tensor_parallel.kvcache_manager import MemoryManager __all__ = "MicroBatchManager" diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py index a914c80e3b70..7520c5ed0044 100644 --- a/colossalai/inference/pipeline/modeling/llama.py +++ b/colossalai/inference/pipeline/modeling/llama.py @@ -12,7 +12,7 @@ from transformers.utils import logging from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd +from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.pipeline.stage_manager import PipelineStageManager from ._utils import copy_kv_to_mem_cache @@ -31,6 +31,14 @@ ) HAS_VLLM_KERNERL = False +try: + from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd + + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -363,8 +371,8 @@ def llama_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin - rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) - rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) query_states = query_states.reshape(-1, self.num_heads, self.head_dim) key_states = key_states.reshape(-1, self.num_heads, self.head_dim) diff --git a/colossalai/inference/pipeline/utils.py b/colossalai/inference/pipeline/utils.py deleted file mode 100644 index c26aa4e40b71..000000000000 --- a/colossalai/inference/pipeline/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Set - -import torch.nn as nn - -from colossalai.shardformer._utils import getattr_, setattr_ - - -def set_tensors_to_none(model: nn.Module, include: Set[str] = set()) -> None: - """ - Set all parameters and buffers of model to None - - Args: - model (nn.Module): The model to set - """ - for module_suffix in include: - set_module = getattr_(model, module_suffix) - for n, p in set_module.named_parameters(): - setattr_(set_module, n, None) - for n, buf in set_module.named_buffers(): - setattr_(set_module, n, None) - setattr_(model, module_suffix, None) - - -def get_suffix_name(suffix: str, name: str): - """ - Get the suffix name of the module, as `suffix.name` when name is string or `suffix[name]` when name is a digit, - and 'name' when `suffix` is empty. - - Args: - suffix (str): The suffix of the suffix module - name (str): The name of the current module - """ - point = "" if suffix is "" else "." - suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}" - return suffix_name diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py index de150311cc08..f707a86df37e 100644 --- a/colossalai/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -2,9 +2,11 @@ from dataclasses import dataclass import torch +from transformers.tokenization_utils_base import BatchEncoding from .kvcache_manager import MemoryManager + # adapted from: lightllm/server/router/model_infer/infer_batch.py @dataclass class BatchInferState: @@ -55,3 +57,62 @@ def init_block_loc( ] start_index += cur_seq_len return + + @classmethod + def init_from_batch( + cls, + batch: torch.Tensor, + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + ): + if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)): + raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state") + + input_ids_list = None + attention_mask = None + + if isinstance(batch, (BatchEncoding, dict)): + input_ids_list = batch["input_ids"] + attention_mask = batch["attention_mask"] + else: + input_ids_list = batch + if isinstance(input_ids_list[0], int): # for a single input + input_ids_list = [input_ids_list] + attention_mask = [attention_mask] if attention_mask is not None else attention_mask + + batch_size = len(input_ids_list) + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + + max_len_in_batch = -1 + if isinstance(batch, (BatchEncoding, dict)): + for i, attn_mask in enumerate(attention_mask): + curr_seq_len = len(attn_mask) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + else: + length = max(len(input_id) for input_id in input_ids_list) + for i, input_ids in enumerate(input_ids_list): + curr_seq_len = length + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda") + + return cls( + batch_size=batch_size, + max_len_in_batch=max_len_in_batch, + seq_len=seq_lengths.to("cuda"), + start_loc=seq_start_indexes.to("cuda"), + block_loc=block_loc, + decode_layer_id=0, + past_key_values_len=0, + is_context_stage=True, + cache_manager=cache_manager, + ) diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index 1d00a92bac68..741e0d043394 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -26,19 +26,21 @@ def data_gen(): def pipeline_inference_test(pp_size, new_length, micro_batch_size): model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4)) - engine = PPInferEngine(pp_size=pp_size, - model=model, - model_policy=LlamaModelInferPolicy(), - new_length=new_length, - micro_batch_size=micro_batch_size) - output = engine.inference([inputs]) + engine = PPInferEngine( + pp_size=pp_size, + model=model, + model_policy=LlamaModelInferPolicy(), + new_length=new_length, + micro_batch_size=micro_batch_size, + ) + output = engine.inference(inputs) if dist.get_rank() == 0: assert len(output[0]) == new_length, f"{len(output)}, {new_length}" -@parameterize('pp_size', [2]) -@parameterize('new_length', [4, 8, 16]) -@parameterize('micro_batch_size', [1, 4]) +@parameterize("pp_size", [2]) +@parameterize("new_length", [4, 8, 16]) +@parameterize("micro_batch_size", [1, 4]) @clear_cache_before_run() def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): pipeline_inference_test(pp_size, new_length, micro_batch_size) From 037bda068b53744ce635dc7474d5cdc719c83207 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 23 Oct 2023 15:15:33 +0800 Subject: [PATCH 10/14] fix some logic bug of ppinfer --- colossalai/inference/pipeline/engine.py | 7 ++++--- colossalai/inference/pipeline/modeling/llama.py | 5 +---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index 05b794600198..480ac5dc79fb 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -88,9 +88,10 @@ def __init__( self.pg_mesh = ProcessGroupMesh(pp_size) self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) self.model = pp_model or self._shardformer(model, model_policy) - self.cache_manager_list = [self._init_manager(max_batch_size, max_input_len, max_output_len)] * ( - micro_batch_buffer_size or pp_size - ) + self.cache_manager_list = [ + self._init_manager(max_batch_size, max_input_len, max_output_len) + for _ in range(micro_batch_buffer_size or pp_size) + ] self.mb_manager = MicroBatchManager( self.stage_manager.stage, new_length, diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py index 7520c5ed0044..9c72b02ccef8 100644 --- a/colossalai/inference/pipeline/modeling/llama.py +++ b/colossalai/inference/pipeline/modeling/llama.py @@ -175,11 +175,8 @@ def llama_model_forward( seq_length_with_past = seq_length past_key_values_length = 0 - if past_key_values is not None: - # NOT READY FOR PRIME TIME - # dummy but work, revise it + if infer_state.is_context_stage is False: past_key_values_length = infer_state.cache_manager.past_key_values_length - # past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length # NOTE: differentiate with prefill stage From df5f00a25739889d8b99dba9f6a0b4e3c9c8f85a Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 23 Oct 2023 15:34:08 +0800 Subject: [PATCH 11/14] polish readme --- colossalai/inference/pipeline/README.md | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/colossalai/inference/pipeline/README.md b/colossalai/inference/pipeline/README.md index 18f69e39ab18..735134d98263 100644 --- a/colossalai/inference/pipeline/README.md +++ b/colossalai/inference/pipeline/README.md @@ -31,8 +31,6 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManag ### Example ```python -from colossalai.pipeline import PPInferEngine -# Suppose the pipeline size is 2, and use fp16 to do infenrence. Use Llama as an example. from colossalai.inference import PPInferEngine from colossalai.inference.pipeline.policies import LlamaModelInferPolicy import colossalai @@ -40,22 +38,16 @@ from transformers import LlamaForCausalLM, LlamaTokenizer colossalai.launch_from_torch(config={}) -model = LlamaForCausalLM.from_pretrained("path_to_model") -tokenizer = LlamaTokenizer.from_pretrained("path_to_model") +model = LlamaForCausalLM.from_pretrained("/path/to/model") +tokenizer = LlamaTokenizer.from_pretrained("/path/to/model") + # assume the model is infered with 2 pipeline stages -inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8) +inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=32) -input = ["Introduce a landmark in China ","Introduce a landmark in China "] +input = ["Introduce a landmark in London","Introduce a landmark in Singapore"] data = tokenizer(input, return_tensors='pt') output = inferengine.inference(data.to('cuda')) - - -``` - -### Quick start -```shell -cd benchmark -sh run.sh +print(tokenizer.batch_decode(output)) ``` ## Performance From 85671beb5d68dc9e65031c5ba057c8be6b3e40ee Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 23 Oct 2023 15:46:10 +0800 Subject: [PATCH 12/14] fix typo --- colossalai/inference/pipeline/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/inference/pipeline/README.md b/colossalai/inference/pipeline/README.md index 735134d98263..f9bb35cc4d4c 100644 --- a/colossalai/inference/pipeline/README.md +++ b/colossalai/inference/pipeline/README.md @@ -17,7 +17,7 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py). 1. `PPInderEngine` is the High-Level API for users to use. It is responsible for the following tasks: - - Initialize the pipeline inference environment with `PipelineStageManager` and mdoel with `ShardFormer`. + - Initialize the pipeline inference environment with `PipelineStageManager` and model with `ShardFormer`. - Run the pipeline inference model. 2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks: @@ -41,7 +41,7 @@ colossalai.launch_from_torch(config={}) model = LlamaForCausalLM.from_pretrained("/path/to/model") tokenizer = LlamaTokenizer.from_pretrained("/path/to/model") -# assume the model is infered with 2 pipeline stages +# assume the model is inferred with 2 pipeline stages inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=32) input = ["Introduce a landmark in London","Introduce a landmark in Singapore"] From c28ee57496533a15627383514a22c599212f85e8 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 27 Oct 2023 11:41:14 +0800 Subject: [PATCH 13/14] skip infer test --- tests/test_infer/test_pipeline_infer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index 741e0d043394..972860053c31 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -52,6 +52,7 @@ def check_pipeline_inference(rank, world_size, port): run_pipeline_inference_test() +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() From 0ef477e0b5bdffc49433ead1029891f924998fb0 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 27 Oct 2023 12:02:05 +0800 Subject: [PATCH 14/14] skip infer test --- tests/test_infer/test_pipeline_infer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index 972860053c31..6d02f2b326b4 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -2,12 +2,15 @@ import torch import torch.distributed as dist import transformers +from packaging import version import colossalai from colossalai.inference.pipeline import PPInferEngine from colossalai.inference.pipeline.policies import LlamaModelInferPolicy from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + def data_gen(): input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)