[Inference] Finish dynamic batching offline test #4948

Merged 2 commits on Oct 19, 2023
2 changes: 1 addition & 1 deletion colossalai/inference/dynamic_batching/ray_dist_init.py
@@ -145,7 +145,7 @@ def step(self):
outputs = results[0] # get any one of the copies
return outputs

def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompt: str):
ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers])

def is_running(self):
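The hunk above moves request_id to the front of the driver's add_req signature. As a usage note (not part of this diff), a caller written against the new order might look like the sketch below; the driver handle, request id, and token ids are illustrative assumptions, and keyword arguments are used so the reordered parameters stay explicit:

from typing import List

from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

def submit_request(driver, request_id: str, prompt_ids: List[int], prompt: str) -> None:
    # `driver` is assumed to be an instance of the Ray driver class edited above.
    # Keyword arguments guard against the old positional order (prompt_ids first).
    driver.add_req(
        request_id=request_id,
        prompt_ids=prompt_ids,
        sampling_params=SamplingParams(),
        prompt=prompt,
    )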
8 changes: 4 additions & 4 deletions colossalai/inference/manager.py
@@ -58,7 +58,7 @@ def __init__(
self.mem_usage_interval = log_stats_interval * 2
self.tokenizer = get_tokenizer(tokenizer_name=self.model)

def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str = ""):
def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
"""
Add new request to req queue, during initialization all requests are held in waiting list.
"""
@@ -75,7 +75,7 @@ def add_input(self, request_id, prompts, sampling_params):
if prompt_len > self.engine.max_input_len:
raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
sampling_params.stop_sentences_to_token_ids(self.tokenizer)
self.add_req(prompt_ids, sampling_params, request_id, prompts)
self.add_req(request_id, prompt_ids, sampling_params, prompts)
return

def abort(self, request_id):
@@ -258,11 +258,11 @@ def clean_up(self):
# this logic should be implemented in the future.
pass

def generate(self, prompts, sampling_params, request_id):
def generate(self, request_id, prompts, sampling_params):
"""
Generate the output of a request.
"""
self.add_input(request_id, sampling_params, prompts)
self.add_input(request_id, prompts, sampling_params)
return self.loop_for_fwd()

def is_running(self):
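Taken together, the manager changes put request_id first in add_req and generate, and fix the argument order passed to add_input. A hedged sketch of end-to-end usage, mirroring the new test further down rather than a documented API; the prompt and default request id are illustrative:

from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

def run_one_prompt(batch_manager, prompt: str, request_id: str = "0"):
    # generate() queues the prompt via add_input and returns the loop_for_fwd()
    # generator, so the caller drains results by iterating it.
    sampling_params = SamplingParams()
    for output in batch_manager.generate(request_id, prompt, sampling_params):
        yield output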
2 changes: 1 addition & 1 deletion colossalai/inference/tensor_parallel/engine.py
@@ -400,8 +400,8 @@ def forward(self, batch_id, is_prefill):
model = self.model.model
elif isinstance(model, BloomForCausalLM):
model = self.model.transformer

setattr(model, "infer_state", infer_state)

output = self.model.forward(input_ids=input_)
logits = output.logits
# bsz, seq_len, vocab_size
3 changes: 1 addition & 2 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -76,14 +76,13 @@ def llama_model_forward(
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

# NOT READY FOR PRIME TIME
# dummy but work, revise it
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1

# NOTE: differentiate with prefill stage
# block_loc require different value-assigning method for two different stage
if use_cache and seq_length != 1:
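For reference on the hunk above: in the context (prefill) stage there is no cache yet, so the offset is 0, while decode steps resume from the max_len_in_batch - 1 tokens already cached. The sketch below shows the standard Hugging Face-style position-id arithmetic such an offset typically feeds; it is an illustration under that assumption, not code from this file:

import torch

def make_position_ids(is_context_stage: bool, max_len_in_batch: int, seq_length: int, device: str = "cpu") -> torch.Tensor:
    # Prefill starts at position 0; each decode step resumes at the last cached
    # position, i.e. max_len_in_batch - 1.
    past_key_values_length = 0 if is_context_stage else max_len_in_batch - 1
    return torch.arange(
        past_key_values_length,
        seq_length + past_key_values_length,
        dtype=torch.long,
        device=device,
    ).unsqueeze(0)  # shape (1, seq_length), broadcastable over the batch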
@@ -1,11 +1,12 @@
from dataclasses import dataclass

import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from dataclasses import dataclass
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
@@ -19,17 +20,26 @@
MAX_OUTPUT_LEN = 16
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@dataclass
class args:
max_total_token_num: int
batch_max_tokens: int
model: str
eos_id: int
disable_log_stats: bool
log_stats_interval: int


def run():
arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10)
arg = args(
max_total_token_num=42,
model="llama",
batch_max_tokens=42,
eos_id=0,
disable_log_stats=False,
log_stats_interval=10,
)
sampling_params = SamplingParams()

req1 = Req(0, [0, 0, 10, 6, 8], sampling_params)
@@ -43,14 +53,18 @@ def run():
waiting_list.append(req3)
waiting_list.append(req4)

llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=30000, hidden_size=1024)
model = LlamaForCausalLM(llama_config)
model = model.half()

shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)

infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)

ans_gen = batch_manager.generate(request_id=5, prompts="hello", sampling_params=sampling_params)
for result in ans_gen:
assert result is not None


def check_dynamic_forward(rank, world_size, port):