This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

add python class ModelServer
Signed-off-by: Yu, Zhentao <[email protected]>
zhentaoyu committed Mar 25, 2024
1 parent db2d0f3 commit af602b3
Showing 3 changed files with 126 additions and 93 deletions.
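In short, the commit moves backend import, config loading, and model-type resolution out of the `Model` class into module-level helpers (`_import_package`, `_get_model_config`, `_get_model_type`, `_filter_model_args`, `get_cpp_module`) and adds a Python `ModelServer` class that wraps the bound C++ server. A minimal usage sketch of the new wrapper, distilled from the example script changed below; the model id and the quantized model file path are hypothetical placeholders, not values from the diff:

```python
# Minimal usage sketch of the new ModelServer wrapper, distilled from the
# example script changed in this commit. The model id and the quantized
# model file path are hypothetical placeholders, not values from the diff.
from transformers import AutoTokenizer
from neural_speed import ModelServer

model_name = "meta-llama/Llama-2-7b-hf"          # assumption: any supported HF model id
model_file = "runtime_outs/ne_llama_q_int4.bin"  # assumption: quantized beforehand via Model

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

def f_response(res, working):
    # Callback the C++ server invokes with finished generations; the exact
    # shape of `res` is defined by the bound C++ type, not shown in this diff.
    print(res, working)

s = ModelServer(model_name,
                f_response,
                model_file,
                max_new_tokens=128,
                continuous_batching=True,
                max_request_num=8)

token_ids = tokenizer("What is the meaning of life?", return_tensors='pt').input_ids.tolist()
s.issueQuery(0, token_ids)  # new signature: an index plus token ids, no cpp.Query wrapper
```

Compared with the replaced API, callers no longer import `neural_speed.llama_cpp` directly or wrap inputs in `cpp.Query`; the backend is resolved from the model name via `get_cpp_module`.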
140 changes: 86 additions & 54 deletions neural_speed/__init__.py
@@ -22,21 +22,7 @@
model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"}
max_request_num_default = 1


class Model:

def __init__(self):
self.module = None
self.model = None
self.model_type = None
self.bin_file = None
self.generate_round = 0
self.max_request_num = -1
self.reinit_from_bin = False

def __import_package(self, model_type):
if self.module:
return
def _import_package(model_type):
if model_type == "gptj":
import neural_speed.gptj_cpp as cpp_model
elif model_type == "falcon":
@@ -81,28 +67,62 @@ def __import_package(self, model_type):
import neural_speed.mixtral_cpp as cpp_model
else:
raise TypeError("Unsupported model type {}!".format(model_type))
self.module = cpp_model

@staticmethod
def get_model_type(model_config):
model_type = model_maps.get(model_config.model_type, model_config.model_type)
if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
model_type = "chatglm2"

# For ChatGLM3
if model_type == "chatglm" and "chatglm3" in model_config._name_or_path:
# due to the same model architecture.
model_type = "chatglm2"
return cpp_model

def _get_model_config(model_name, model_hub="huggingface"):
if model_hub == "modelscope":
from modelscope import AutoConfig
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
else:
from transformers import AutoConfig
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
return config

def _get_model_type(model_config):
model_type = model_maps.get(model_config.model_type, model_config.model_type)
if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
model_type = "chatglm2"

# For ChatGLM3
if model_type == "chatglm" and "chatglm3" in model_config._name_or_path:
# due to the same model architecture.
model_type = "chatglm2"

# for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ
if model_type == "RefinedWebModel" or model_type == "RefinedWeb":
model_type = "falcon"

# for TheBloke/phi-2-GPTQ
if model_type == "phi-msft":
model_type = "phi"

return model_type

def _filter_model_args(valid_args, **input_kwargs):
invalid_args = []
for k in input_kwargs.keys():
if k not in valid_args:
invalid_args.append(k)
for k in invalid_args:
input_kwargs.pop(k)
return input_kwargs

# for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ
if model_type == "RefinedWebModel" or model_type == "RefinedWeb":
model_type = "falcon"
def get_cpp_module(model_name, model_hub="huggingface"):
model_config = _get_model_config(model_name, model_hub=model_hub)
model_type = _get_model_type(model_config)
cpp_module = _import_package(model_type)
return cpp_module

# for TheBloke/phi-2-GPTQ
if model_type == "phi-msft":
model_type = "phi"
class Model:

return model_type
def __init__(self):
self.module = None
self.model = None
self.model_type = None
self.bin_file = None
self.generate_round = 0
self.max_request_num = -1
self.reinit_from_bin = False

def init(self,
model_name,
@@ -117,15 +137,11 @@ def init(self,
compute_dtype="int8",
use_ggml=False,
model_hub="huggingface"):
if model_hub == "modelscope":
from modelscope import AutoConfig
self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
else:
from transformers import AutoConfig
self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model_type = Model.get_model_type(self.config)
self.config = _get_model_config(model_name, model_hub=model_hub)
model_type = _get_model_type(self.config)
self.model_type = model_type
self.__import_package(model_type)
if self.module is None:
self.module = _import_package(model_type)

# check cache and quantization
output_path = "runtime_outs"
@@ -182,7 +198,8 @@ def init(self,
os.remove(fp32_bin)

def init_from_bin(self, model_type, model_path, **generate_kwargs):
self.__import_package(model_type)
if self.module is None:
self.module = _import_package(model_type)
self.model = self.module.Model()

if self.max_request_num == -1:
@@ -272,10 +289,16 @@ def get_scratch_size_ratio(size):
else:
generate_kwargs["scratch_size_ratio"] = 35

self.model.init_model(model_path, **self._filter_model_init_args(**generate_kwargs))
valid_args = {"max_new_tokens", "n_batch", "ctx_size", "seed", "threads", "repetition_penalty",
"num_beams", "do_sample", "top_k", "top_p", "temperature", "min_new_tokens",
"length_penalty", "early_stopping", "n_keep", "n_discard", "shift_roped_k",
"batch_size","pad_token", "memory_dtype", "continuous_batching", "max_request_num",
"scratch_size_ratio"}
self.model.init_model(model_path, **_filter_model_args(valid_args, **generate_kwargs))

def quant_model(self, model_type, model_path, out_path, **quant_kwargs):
self.__import_package(model_type)
if self.module is None:
self.module = _import_package(model_type)
self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs)

def generate(self,
@@ -435,16 +458,25 @@ def _get_model_input_list(self, input_ids, **kwargs):
input_list = input_ids.tolist()
return input_list

def _filter_model_init_args(self, **init_kwargs):

class ModelServer:
def __init__(self, model_name, response_function, model_path, **server_kwargs):
if not os.path.exists(model_path):
raise ValueError("model file {} does not exist.".format(model_path))
self.module = get_cpp_module(model_name)
valid_args = {"max_new_tokens", "n_batch", "ctx_size", "seed", "threads", "repetition_penalty",
"num_beams", "do_sample", "top_k", "top_p", "temperature", "min_new_tokens",
"length_penalty", "early_stopping", "n_keep", "n_discard", "shift_roped_k",
"batch_size","pad_token", "memory_dtype", "continuous_batching", "max_request_num",
"scratch_size_ratio"}
invalid_args = []
for k in init_kwargs.keys():
if k not in valid_args:
invalid_args.append(k)
for k in invalid_args:
init_kwargs.pop(k)
return init_kwargs
"scratch_size_ratio", "return_prompt", "print_log", "init_cb"}
self.cpp_server = self.module.ModelServer(response_function,
model_path,
**_filter_model_args(valid_args, **server_kwargs))

def issueQuery(self, index, token_ids):
self.cpp_server.issueQuery([self.module.Query(index, token_ids)])

def Empty(self):
return self.cpp_server.Empty()

__all__ = ["get_cpp_module", "Model", "ModelServer"]
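
`get_cpp_module`, newly exported in `__all__`, resolves a model name to its architecture-specific C++ binding (loading the config via `AutoConfig` along the way) and is what `ModelServer` calls in its constructor. A short sketch of calling it directly; the model id is a placeholder:

```python
from neural_speed import get_cpp_module

# Runs the config -> model type -> C++ module chain added above.
cpp_module = get_cpp_module("meta-llama/Llama-2-7b-hf", model_hub="huggingface")
print(cpp_module.__name__)  # e.g. "neural_speed.llama_cpp" for a LLaMA-family model
```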
45 changes: 23 additions & 22 deletions scripts/python_api_example_for_model_server.py
@@ -2,7 +2,7 @@
import argparse
from pathlib import Path
from typing import List, Optional
import neural_speed.llama_cpp as cpp
from neural_speed import ModelServer
from transformers import AutoTokenizer


@@ -92,30 +92,31 @@ def f_response(res, working):
print("=====================================")

added_count = 0
s = cpp.ModelServer(f_response,
str(args.model_path),
max_new_tokens=args.max_new_tokens,
num_beams=args.num_beams,
min_new_tokens=args.min_new_tokens,
early_stopping=args.early_stopping,
do_sample=args.do_sample,
continuous_batching=True,
return_prompt=args.return_prompt,
threads=args.threads,
max_request_num=args.max_request_num,
print_log=args.print_log,
scratch_size_ratio = args.scratch_size_ratio,
memory_dtype= args.memory_dtype,
ctx_size=args.ctx_size,
seed=args.seed,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repeat_penalty,
temperature=args.temperature,
s = ModelServer(args.model_name,
f_response,
str(args.model_path),
max_new_tokens=args.max_new_tokens,
num_beams=args.num_beams,
min_new_tokens=args.min_new_tokens,
early_stopping=args.early_stopping,
do_sample=args.do_sample,
continuous_batching=True,
return_prompt=args.return_prompt,
threads=args.threads,
max_request_num=args.max_request_num,
print_log=args.print_log,
scratch_size_ratio=args.scratch_size_ratio,
memory_dtype=args.memory_dtype,
ctx_size=args.ctx_size,
seed=args.seed,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repeat_penalty,
temperature=args.temperature,
)
for i in range(len(prompts)):
p_token_ids = tokenizer(prompts[i], return_tensors='pt').input_ids.tolist()
s.issueQuery([cpp.Query(i, p_token_ids)])
s.issueQuery(i, p_token_ids)
added_count += 1
time.sleep(2) # adjust query sending time interval

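The displayed hunk of the example script stops right after the queries are issued; the part that waits for the server to drain lies outside the shown range. A typical pattern, sketched here as an assumption built on the `Empty()` method the new `ModelServer` exposes:

```python
import time

# Poll until the server reports no pending work before exiting.
while not s.Empty():
    time.sleep(1)
print("Processed {} queries.".format(added_count))
```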
34 changes: 17 additions & 17 deletions tests/test_model_server.py
@@ -15,8 +15,7 @@
import time
import unittest
import shutil
from neural_speed import Model
import neural_speed.llama_cpp as cpp
from neural_speed import Model, ModelServer
from transformers import AutoTokenizer

class TestModelServer(unittest.TestCase):
@@ -91,24 +90,25 @@ def f_response(res, working):
print("============={} {} MODEL SERVER TESTING========".format(log_map[md],
log_map[policy]), flush=True)
added_count = 0
s = cpp.ModelServer(f_response,
model_path,
max_new_tokens=128,
num_beams=4 if policy == "beam" else 1,
min_new_tokens=30,
early_stopping=True,
do_sample=False,
continuous_batching=True,
return_prompt=True,
max_request_num=8,
threads=56,
print_log=False,
scratch_size_ratio = 1.0,
memory_dtype= md,
s = ModelServer(model_name,
f_response,
model_path,
max_new_tokens=128,
num_beams=4 if policy == "beam" else 1,
min_new_tokens=30,
early_stopping=True,
do_sample=False,
continuous_batching=True,
return_prompt=True,
max_request_num=8,
threads=56,
print_log=False,
scratch_size_ratio=1.0,
memory_dtype=md,
)
for i in range(len(prompts)):
p_token_ids = tokenizer(prompts[i], return_tensors='pt').input_ids.tolist()
s.issueQuery([cpp.Query(i, p_token_ids)])
s.issueQuery(i, p_token_ids)
added_count += 1
time.sleep(2) # adjust query sending time interval

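For context, the quantized file that `model_path` points at in this test is normally produced with the `Model` class from the same package. A minimal sketch under that assumption, using only the keyword arguments visible in this diff; the model id is a placeholder, and `init()` accepts further quantization arguments that the hunk above truncates:

```python
from neural_speed import Model

model = Model()
# init() loads the Hugging Face config, converts/quantizes the checkpoint,
# and caches the binaries under "runtime_outs" (see the __init__.py hunk).
model.init("meta-llama/Llama-2-7b-hf",
           compute_dtype="int8",
           use_ggml=False,
           model_hub="huggingface")
```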
