diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py
index 7639633eaa79..e3a261f96a86 100644
--- a/colossalai/inference/dynamic_batching/ray_dist_init.py
+++ b/colossalai/inference/dynamic_batching/ray_dist_init.py
@@ -145,7 +145,7 @@ def step(self):
         outputs = results[0]  # get any one of the copies
         return outputs
 
-    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
+    def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompt: str):
         ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers])
 
     def is_running(self):
diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py
index 5af9ae10357d..c43a46ccb71e 100644
--- a/colossalai/inference/manager.py
+++ b/colossalai/inference/manager.py
@@ -58,7 +58,7 @@ def __init__(
         self.mem_usage_interval = log_stats_interval * 2
         self.tokenizer = get_tokenizer(tokenizer_name=self.model)
 
-    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str = ""):
+    def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
        """
        Add new request to req queue, during initialization all requests are held in waiting list.
        """
@@ -75,7 +75,7 @@ def add_input(self, request_id, prompts, sampling_params):
         if prompt_len > self.engine.max_input_len:
             raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
         sampling_params.stop_sentences_to_token_ids(self.tokenizer)
-        self.add_req(prompt_ids, sampling_params, request_id, prompts)
+        self.add_req(request_id, prompt_ids, sampling_params, prompts)
         return
 
     def abort(self, request_id):
@@ -259,11 +259,11 @@ def clean_up(self):
         # this logic should be implemented in the future.
         pass
 
-    def generate(self, prompts, sampling_params, request_id):
+    def generate(self, request_id, prompts, sampling_params):
         """
         Generate the output of a request.
         """
-        self.add_input(request_id, sampling_params, prompts)
+        self.add_input(request_id, prompts, sampling_params)
         return self.loop_for_fwd()
 
     def is_running(self):
""" - self.add_input(request_id, sampling_params, prompts) + self.add_input(request_id, prompts, sampling_params) return self.loop_for_fwd() def is_running(self): diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index a98b96565c50..e75004d506a3 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -400,8 +400,8 @@ def forward(self, batch_id, is_prefill): model = self.model.model elif isinstance(model, BloomForCausalLM): model = self.model.transformer + setattr(model, "infer_state", infer_state) - output = self.model.forward(input_ids=input_) logits = output.logits # bsz, seq_len, vocab_size diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 958868a0974e..7e6978ad815b 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -76,14 +76,13 @@ def llama_model_forward( batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # NOT READY FOR PRIME TIME # dummy but work, revise it if infer_state.is_context_stage: past_key_values_length = 0 else: past_key_values_length = infer_state.max_len_in_batch - 1 - + # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage if use_cache and seq_length != 1: diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py similarity index 80% rename from tests/test_infer/test_dynamic_batching/test_forward.py rename to tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py index 63df491e5b52..9925a80b6e77 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py @@ -1,3 +1,5 @@ +from dataclasses import dataclass + import pytest import torch from packaging import version @@ -5,7 +7,6 @@ from transformers.models.llama.configuration_llama import LlamaConfig import colossalai -from dataclasses import dataclass from colossalai.inference.dynamic_batching.io_struct import Req from colossalai.inference.dynamic_batching.sampling_params import SamplingParams from colossalai.inference.manager import start_dynamic_batching @@ -19,17 +20,26 @@ MAX_OUTPUT_LEN = 16 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + @dataclass class args: max_total_token_num: int batch_max_tokens: int + model: str eos_id: int disable_log_stats: bool log_stats_interval: int def run(): - arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10) + arg = args( + max_total_token_num=42, + model="llama", + batch_max_tokens=42, + eos_id=0, + disable_log_stats=False, + log_stats_interval=10, + ) sampling_params = SamplingParams() req1 = Req(0, [0, 0, 10, 6, 8], sampling_params) @@ -43,14 +53,18 @@ def run(): waiting_list.append(req3) waiting_list.append(req4) - llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=30000, hidden_size=1024) model = LlamaForCausalLM(llama_config) model = model.half() shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else 
     infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
+    batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
+
+    ans_gen = batch_manager.generate(request_id=5, prompts="hello", sampling_params=sampling_params)
+    for result in ans_gen:
+        assert result is not None
 
 
 def check_dynamic_forward(rank, world_size, port):
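
The `engine.py` hunk attaches the per-batch `infer_state` to the wrapped model with `setattr(model, "infer_state", infer_state)` right before calling `self.model.forward(input_ids=input_)`, so the replaced model-level forwards (such as `llama_model_forward`, which branches on `infer_state.is_context_stage` and `infer_state.max_len_in_batch` in the hunk above) can read the batching state as a module attribute instead of taking it through the call signature. Below is a minimal, self-contained sketch of that pattern; `ToyInferState` and `ToyModel` are illustrative stand-ins, not ColossalAI classes.

```python
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class ToyInferState:
    # stand-in for the engine's infer_state: prefill vs. decode metadata
    is_context_stage: bool
    max_len_in_batch: int


class ToyModel(nn.Module):
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # read the state the engine attached before the call, instead of
        # receiving it through the (unchanged) forward signature
        state: ToyInferState = self.infer_state
        past_key_values_length = 0 if state.is_context_stage else state.max_len_in_batch - 1
        return input_ids + past_key_values_length


model = ToyModel()
# the engine-side step the patch adds: attach the state, then call forward as usual
setattr(model, "infer_state", ToyInferState(is_context_stage=False, max_len_in_batch=8))
output = model(torch.zeros(1, 4, dtype=torch.long))  # decode stage -> offset of 7
```

Keeping the stage metadata out of the call signature lets the dynamic-batching engine keep invoking the plain `input_ids`-only forward while the patched LLaMA path derives `past_key_values_length` from `infer_state`, as in the `modeling/llama.py` hunk.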