From e410e4a18f9f5274c894b925d494bfa091472fcc Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 6 Aug 2024 09:08:22 +0000 Subject: [PATCH 1/6] stage 1 --- csrc/generation/get_padding_offset_v2.cu | 2 +- llm/predict/export_model.py | 3 +- llm/predict/predictor.py | 390 +++++++----------- llm/utils/utils.py | 47 --- .../transformers/llama/modeling.py | 7 +- paddlenlp/quantization/quantization_config.py | 4 +- 6 files changed, 151 insertions(+), 302 deletions(-) diff --git a/csrc/generation/get_padding_offset_v2.cu b/csrc/generation/get_padding_offset_v2.cu index 080764ed9955..737351fa0b5d 100644 --- a/csrc/generation/get_padding_offset_v2.cu +++ b/csrc/generation/get_padding_offset_v2.cu @@ -103,7 +103,7 @@ std::vector GetPaddingOffsetV2InferDtype(const paddle::DataTyp } PD_BUILD_OP(get_padding_offset_v2) - .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"}) + .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"}) .Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"}) .SetKernelFn(PD_KERNEL(GetPaddingOffsetV2)) .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape)) diff --git a/llm/predict/export_model.py b/llm/predict/export_model.py index df8598e05cb8..91b3bc4cb7c3 100644 --- a/llm/predict/export_model.py +++ b/llm/predict/export_model.py @@ -61,7 +61,8 @@ def main(): }, ) predictor.model.config.save_pretrained(export_args.output_path) - predictor.model.generation_config.save_pretrained(export_args.output_path) + predictor.generation_config.save_pretrained(export_args.output_path) + predictor.tokenizer.save_pretrained(export_args.output_path) generate_rank_mapping(os.path.join(export_args.output_path, "rank_mapping.csv")) diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 39ffc6b40fe6..1841ca1eea58 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -31,10 +31,7 @@ from utils.utils import ( dybatch_preprocess, get_alibi_slopes, - get_default_max_decoding_length, - get_default_max_encoding_length, get_infer_model_path, - get_model_max_position_embeddings, get_prefix_tuning_params, init_chat_template, load_real_time_tokens, @@ -67,8 +64,9 @@ class PredictorArgument: model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."}) model_prefix: str = field(default="model", metadata={"help": "the prefix name of static model"}) - src_length: int = field(default=None, metadata={"help": "The max length of source text."}) - max_length: int = field(default=None, metadata={"help": "the max length for decoding."}) + src_length: int = field(default=4096, metadata={"help": "The max length of source text."}) + min_length: int = field(default=1, metadata={"help": "the min length for decoding."}) + max_length: int = field(default=2048, metadata={"help": "the max length for decoding."}) top_k: int = field(default=0, metadata={"help": "top_k parameter for generation"}) top_p: float = field(default=0.7, metadata={"help": "top_p parameter for generation"}) temperature: float = field(default=0.95, metadata={"help": "top_p parameter for generation"}) @@ -118,7 +116,9 @@ class PredictorArgument: block_size: int = field(default=64, metadata={"help": "the block size for cache_kvs."}) cachekv_int8_type: str = field( default=None, - metadata={"help": "If cachekv_int8_type set as `dynamic`, cache kv would be quantized to int8 dynamically. 
If cachekv_int8_type set as `static`, cache kv would be quantized to int8 Statically."}, + metadata={ + "help": "If cachekv_int8_type set as `dynamic`, cache kv would be quantized to int8 dynamically. If cachekv_int8_type set as `static`, cache kv would be quantized to int8 Statically." + }, ) chat_template: str = field( @@ -180,11 +180,11 @@ def init_dist_env(): def get_eos_token_id( tokenizer: PretrainedTokenizer, generation_config: Optional[GenerationConfig] = None -) -> int | List[List[int]]: +) -> List[List[int]]: """get eos_token_id from generation_config or tokenizer Returns: - int | List[int]: eos_token_id to stop the generation + List[int]: eos_token_id to stop the generation """ eos_token_ids = [] if tokenizer.eos_token_id is not None: @@ -390,8 +390,10 @@ def _infer(self, inputs: dict[str, np.ndarray]): return decoded_ids -class InferencePredictorMixin: +class InferencePredictorMixin(BasePredictor): def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): + BasePredictor.__init__(self, config, tokenizer) + self.architectures = self.model_config.architectures[0].lower() self.dtype = config.dtype or self.model_config @@ -461,14 +463,6 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): item.squeeze_(0) for item in paddle.split(prefix_cache, self.num_layers, axis=0) ] - try: - self.generation_config = GenerationConfig.from_pretrained(config.model_name_or_path) - except: - logger.warning( - "Can't find generation config, so it will not use generation_config field in the model config" - ) - self.generation_config = None - def _postprocess(self, predictions, return_tokens=False): if paddle.distributed.get_rank() == 0: tokens: np.ndarray = load_real_time_tokens() @@ -643,7 +637,7 @@ def _preprocess(self, source): return inputs -class StaticInferencePredictor(InferencePredictorMixin, BasePredictor): +class StaticInferencePredictor(InferencePredictorMixin): def __init__( self, config: PredictorArgument, @@ -651,7 +645,6 @@ def __init__( tokenizer: PretrainedTokenizer = None, ): self.cache_kvs_shape = cache_kvs_shape - BasePredictor.__init__(self, config, tokenizer) InferencePredictorMixin.__init__(self, config, tokenizer) self.predictor = self._create_predictor(config) @@ -735,7 +728,7 @@ def _infer(self, inputs): self.predictor.run() -class DygraphInferencePredictor(InferencePredictorMixin, BasePredictor): +class DygraphInferencePredictor(InferencePredictorMixin): def __init__( self, config: PredictorArgument, @@ -743,7 +736,6 @@ def __init__( tokenizer: PretrainedTokenizer = None, ): self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size, config.total_max_length) - BasePredictor.__init__(self, config, tokenizer) InferencePredictorMixin.__init__(self, config, tokenizer) self.model = model @@ -766,8 +758,9 @@ def _infer(self, inputs: dict[str, paddle.Tensor]): return None -class BlockInferencePredictorMixin: +class BlockInferencePredictorMixin(BasePredictor): def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): + BasePredictor.__init__(self, config, tokenizer) self.num_layers = len(self.cache_kvs_shape) // 2 self.num_attention_heads = self.cache_kvs_shape[0][-3] @@ -780,10 +773,7 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): self.dtype = config.dtype or self.model_config.dtype - self.total_max_length = config.src_length + config.max_length self.block_size = config.block_size - self.pre_max_block_num = (self.total_max_length + config.block_size - 1) 
// config.block_size - self.max_block_nums = config.batch_size * self.pre_max_block_num try: self.rope_theta = self.model_config.rope_theta @@ -838,60 +828,71 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): for _ in range(self.num_layers) ] - if config.benchmark: - self.min_length = config.max_length - else: - self.min_length = 2 - - self.free_list = [i for i in range(self.max_block_nums)][::-1] - self.used_list = [[] for _ in range(config.batch_size)] - - def init_inputs(self, config: PredictorArgument): - self.inputs = {} + def pad_batch_data(self, insts): + """Pad the instances to the max sequence length in batch.""" + seq_lens = [] + for i, inst in enumerate(insts): + length = len(inst) + seq_lens.append(length) + self.input_ids[i, :length] = np.array(inst) + return seq_lens + + def init_model_inputs(self, config: PredictorArgument): + self.input_ids = paddle.full( + shape=[config.batch_size, config.total_max_length], fill_value=self.tokenizer.pad_token_id, dtype="int64" + ) + self.model_inputs = {} if config.export_precache: - self.inputs["src_mask"] = (self.pre_cache_mask - 1) * 1e4 - self.inputs["pre_ids"] = paddle.full([config.batch_size, self.total_max_length], -1, dtype="int64") - self.inputs["bad_tokens"] = paddle.to_tensor( - [ - -1, - ], - dtype="int64", - ) - self.inputs["penalty_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=1.0, dtype="float32") - self.inputs["frequency_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32") - self.inputs["presence_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32") + self.model_inputs["src_mask"] = (self.pre_cache_mask - 1) * 1e4 - self.inputs["min_length"] = paddle.full( - shape=[config.batch_size, 1], fill_value=self.min_length, dtype="int64" + self.model_inputs["block_tables"] = paddle.full( + shape=[config.batch_size, (config.total_max_length + config.block_size - 1) // config.block_size], + fill_value=-1, + dtype="int32", ) - self.inputs["max_length"] = paddle.full( - shape=[config.batch_size, 1], fill_value=config.max_length, dtype="int64" + self.model_inputs["top_p"] = paddle.full( + shape=[config.batch_size, 1], fill_value=config.top_p, dtype="float32" + ) + self.model_inputs["temperature"] = paddle.full( + shape=[config.batch_size, 1], fill_value=config.temperature, dtype="float32" + ) + self.model_inputs["eos_token_id"] = paddle.to_tensor( + np.array(get_eos_token_id(self.tokenizer, self.generation_config)).reshape(-1, 1).astype("int64") + ) + self.model_inputs["penalty_score"] = paddle.full( + shape=[config.batch_size, 1], fill_value=config.repetition_penalty, dtype="float32" + ) + self.model_inputs["frequency_score"] = paddle.full( + shape=[config.batch_size, 1], fill_value=0.0, dtype="float32" ) - self.inputs["stop_nums"] = paddle.full(shape=[1], fill_value=config.batch_size, dtype="int64") - self.inputs["rope_emb"] = self._get_rotary_position_embedding( - paddle.arange(self.total_max_length).reshape((1, -1)), self.head_dim, self.rope_theta + self.model_inputs["presence_score"] = paddle.full( + shape=[config.batch_size, 1], fill_value=0.0, dtype="float32" ) - eos_token_id = get_eos_token_id(self.tokenizer, self.generation_config) - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - self.inputs["eos_token_id"] = paddle.to_tensor( - np.array(eos_token_id * config.batch_size).reshape(-1, 1).astype("int64") + self.model_inputs["min_length"] = paddle.full( + shape=[config.batch_size, 1], 
fill_value=config.min_length, dtype="int64" ) + self.model_inputs["max_length"] = paddle.full( + shape=[config.batch_size, 1], fill_value=config.max_length, dtype="int64" + ) + self.model_inputs["rope_emb"] = self._get_rotary_position_embedding( + paddle.arange(config.total_max_length).reshape((1, -1)), self.head_dim, self.rope_theta + ) + self.model_inputs["bad_tokens"] = paddle.to_tensor([-1], dtype="int64") + self.model_inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool") + # bloom model needs src_mask and tgt_mask! if "bloom" in self.architectures: - lower_one_tril = paddle.tril( - paddle.ones(shape=(self.total_max_length, self.total_max_length), dtype=self.dtype) - ) + lower_one_tril = paddle.tril(paddle.ones(shape=(config.total_max_length, config.total_max_length), dtype=self.dtype)) lower_one_tril = lower_one_tril[None, None, :, :] - self.inputs["src_mask"] = lower_one_tril.tile([self.batch_size, 1, 1, 1]) - self.inputs["tgt_mask"] = paddle.full( - shape=[config.batch_size, 1, 1, self.total_max_length], fill_value=1, dtype=self.dtype + self.model_inputs["src_mask"] = lower_one_tril.tile([self.batch_size, 1, 1, 1]) + self.model_inputs["tgt_mask"] = paddle.full( + shape=[config.batch_size, 1, 1, config.total_max_length], fill_value=1, dtype=self.dtype ) - arange_tensor_encoder = paddle.arange(self.total_max_length).astype(self.dtype) + arange_tensor_encoder = paddle.arange(config.total_max_length).astype(self.dtype) alibi_slopes = get_alibi_slopes(self.num_attention_heads) alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder - alibi_encoder = alibi.tile([self.batch_size, 1, self.total_max_length, 1]) + alibi_encoder = alibi.tile([self.batch_size, 1, config.total_max_length, 1]) alibi_decoder = alibi.tile( [ self.batch_size, @@ -900,43 +901,14 @@ def init_inputs(self, config: PredictorArgument): 1, ] ) - # self.inputs["src_mask/tgt_mask"] is read only, will not be updated! - self.inputs["src_mask"] = ( - alibi_encoder + (1 - self.inputs["src_mask"]) * paddle.finfo(self.dtype).min + # self.model_inputs["src_mask/tgt_mask"] is read only, will not be updated! 
+ self.model_inputs["src_mask"] = ( + alibi_encoder + (1 - self.model_inputs["src_mask"]) * paddle.finfo(self.dtype).min ).cast(self.dtype) - self.inputs["tgt_mask"] = ( - alibi_decoder + (1 - self.inputs["tgt_mask"]) * paddle.finfo(self.dtype).min + self.model_inputs["tgt_mask"] = ( + alibi_decoder + (1 - self.model_inputs["tgt_mask"]) * paddle.finfo(self.dtype).min ).cast(self.dtype) - # need update - self.inputs["block_tables"] = paddle.full( - shape=[config.batch_size, self.pre_max_block_num], fill_value=-1, dtype="int32" - ) - self.inputs["input_ids"] = paddle.full( - shape=[config.batch_size, self.total_max_length], fill_value=-1, dtype="int64" - ) - self.inputs["top_p"] = paddle.full(shape=[config.batch_size, 1], fill_value=config.top_p, dtype="float32") - self.inputs["temperature"] = paddle.full(shape=[config.batch_size, 1], fill_value=1.0, dtype="float32") - self.inputs["seq_lens_this_time"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32") - self.inputs["seq_lens_encoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32") - self.inputs["seq_lens_decoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32") - self.inputs["step_idx"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int64") - self.inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=False, dtype="bool") - self.inputs["stop_flags"] = paddle.full(shape=[config.batch_size, 1], fill_value=True, dtype="bool") - self.inputs["next_tokens"] = paddle.full(shape=[config.batch_size, 1], fill_value=-1, dtype="int64") - self.inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool") - free_list = list(range(self.pre_max_block_num - 1, int(self.pre_max_block_num * 0.75) - 1, -1)) - self.inputs["encoder_block_lens"] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32") - self.inputs["step_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32") - self.inputs["step_lens"] = paddle.full(shape=[1], fill_value=0, dtype="int32") - self.inputs["recover_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32") - self.inputs["recover_lens"] = paddle.full(shape=[1], fill_value=0, dtype="int32") - self.inputs["need_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32") - self.inputs["need_block_len"] = paddle.full(shape=[1], fill_value=0, dtype="int32") - self.inputs["used_list_len"] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32") - self.inputs["free_list"] = paddle.to_tensor(free_list, dtype="int32") - self.inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.pre_max_block_num * 0.25, dtype="int32") - def _get_rotary_position_embedding(self, position_ids, head_dim, rope_theta=10000.0): """ Pre-calculate rotary position embedding for position_ids. 
@@ -961,12 +933,13 @@ def _get_rotary_position_embedding(self, position_ids, head_dim, rope_theta=1000 rot_emb[1] = paddle.sin(emb) return rot_emb - def _preprocess(self, source): + def _preprocess(self, input_text: list[str]): if self.tokenizer.chat_template is not None: - source = [source] if isinstance(source, str) else source - source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source] + input_text = [input_text] if isinstance(input_text, str) else input_text + input_text = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in input_text] - for i, text in enumerate(source): + input_ids = [] + for text in input_text: tokens = self.tokenizer( text, return_tensors="np", @@ -977,30 +950,42 @@ def _preprocess(self, source): add_special_tokens=self.tokenizer.chat_template is None or isinstance(self.tokenizer, (ChatGLMv2Tokenizer, ChatGLMTokenizer)), ) - input_ids = tokens["input_ids"][0] - length = len(input_ids) - self.inputs["input_ids"][i : i + 1, :length] = input_ids - self.inputs["penalty_score"][i : i + 1] = self.config.repetition_penalty - self.inputs["frequency_score"][i : i + 1] = 0.0 - self.inputs["presence_score"][i : i + 1] = 0.0 - self.inputs["top_p"][i : i + 1] = self.config.top_p - self.inputs["temperature"][i : i + 1] = self.config.temperature - self.inputs["seq_lens_this_time"][i : i + 1] = length - self.inputs["seq_lens_encoder"][i : i + 1] = length - self.inputs["seq_lens_decoder"][i : i + 1] = 0 - self.inputs["step_idx"][i : i + 1] = 0 - self.inputs["stop_flags"][i : i + 1] = False - self.inputs["not_need_stop"][0] = True - need_block_nums = ( - length + self.config.max_length + self.pre_cache_length + self.block_size - 1 - ) // self.block_size - for bi in range(need_block_nums): - bi_now = self.free_list.pop() - self.used_list[i].append(bi_now) - self.inputs["block_tables"][i : i + 1, bi] = bi_now - - -class DygraphBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor): + input_ids.append(tokens["input_ids"][0]) + + seq_lens = self.pad_batch_data(input_ids) + self.model_inputs["input_ids"] = self.input_ids + + self.model_inputs["block_tables"][:][:] = -1 + free_list = list(range(self.max_block_nums)) + for i in range(self.config.batch_size): + for j in range( + (seq_lens[i] + self.config.max_length + self.config.block_size - 1) // self.config.block_size + ): + used_block_id = free_list.pop() + self.model_inputs["block_tables"][i, j] = used_block_id + + self.model_inputs["seq_lens_this_time"] = paddle.to_tensor(np.array(seq_lens).astype("int32").reshape(-1, 1)) + self.model_inputs["seq_lens_encoder"] = paddle.to_tensor(np.array(seq_lens).astype("int32").reshape(-1, 1)) + self.model_inputs["seq_lens_decoder"] = paddle.full( + shape=[self.config.batch_size, 1], fill_value=0, dtype="int32" + ) + self.model_inputs["step_idx"] = paddle.full(shape=[self.config.batch_size, 1], fill_value=0, dtype="int64") + self.model_inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=True, dtype="bool") + self.model_inputs["stop_flags"] = paddle.full( + shape=[self.config.batch_size, 1], fill_value=False, dtype="bool" + ) + self.model_inputs["stop_nums"] = paddle.full(shape=[1], fill_value=self.config.batch_size, dtype="int64") + self.model_inputs["pre_ids"] = paddle.full( + shape=[self.config.batch_size, self.config.max_length], fill_value=-1, dtype="int64" + ) + self.model_inputs["next_tokens"] = paddle.full(shape=[self.config.batch_size, 1], fill_value=-1, dtype="int64") + + if self.config.mode == "static": + 
for k, v in self.model_inputs.items(): + v.name = k + + +class DygraphBlockInferencePredictor(BlockInferencePredictorMixin): def __init__( self, config: PredictorArgument, @@ -1008,26 +993,23 @@ def __init__( tokenizer: PretrainedTokenizer = None, ): self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size) - BasePredictor.__init__(self, config, tokenizer) BlockInferencePredictorMixin.__init__(self, config, tokenizer) - cachekv_dtype = self.dtype - if config.cachekv_int8_type is not None: - cachekv_dtype = "uint8" + cachekv_dtype = self.dtype if config.cachekv_int8_type is None else "uint8" self.cache_kvs = [paddle.zeros(shape, dtype=cachekv_dtype) for shape in self.cache_kvs_shape] self.model = model - self.init_inputs(config) + self.init_model_inputs(config) if config.export_precache: - self.inputs["pre_caches"] = self.pre_caches + self.model_inputs["pre_caches"] = self.pre_caches if config.cachekv_int8_type == "dynamic": - self.inputs["k_quant_scales"] = self.k_quant_scales - self.inputs["v_quant_scales"] = self.v_quant_scales - self.inputs["k_dequant_scales"] = self.k_dequant_scales - self.inputs["v_dequant_scales"] = self.v_dequant_scales + self.model_inputs["k_quant_scales"] = self.k_quant_scales + self.model_inputs["v_quant_scales"] = self.v_quant_scales + self.model_inputs["k_dequant_scales"] = self.k_dequant_scales + self.model_inputs["v_dequant_scales"] = self.v_dequant_scales - self.inputs["cache_kvs"] = self.cache_kvs + self.model_inputs["cache_kvs"] = self.cache_kvs @paddle.no_grad() def _infer(self, inputs: dict[str, paddle.Tensor]): @@ -1036,7 +1018,7 @@ def _infer(self, inputs: dict[str, paddle.Tensor]): ) @paddle.no_grad() - def predict(self, input_texts: str | list[str], return_tokens=False): + def predict(self, input_texts: list[str], return_tokens=False): self._preprocess(input_texts) result_queue = mp.Queue() @@ -1049,12 +1031,8 @@ def predict(self, input_texts: str | list[str], return_tokens=False): read_res_process = mp.Process(target=read_res, args=[self.model_name_or_path, tensor_queue, result_queue]) read_res_process.start() - while self.inputs["not_need_stop"]: - self._infer(self.inputs) - # reset free_list - for i in range(self.config.batch_size): - self.free_list.extend(self.used_list[i]) - self.used_list[i] = [] + while self.model_inputs["not_need_stop"]: + self._infer(self.model_inputs) outputs = [] output_tokens = [] @@ -1068,7 +1046,7 @@ def predict(self, input_texts: str | list[str], return_tokens=False): return outputs -class StaticBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor): +class StaticBlockInferencePredictor(BlockInferencePredictorMixin): def __init__( self, config: PredictorArgument, @@ -1076,39 +1054,31 @@ def __init__( tokenizer: PretrainedTokenizer = None, ): self.cache_kvs_shape = cache_kvs_shape - BasePredictor.__init__(self, config, tokenizer) BlockInferencePredictorMixin.__init__(self, config, tokenizer) - self.init_inputs(config) + self._create_predictor(config) + + self.init_model_inputs(config) if config.export_precache: for i in range(self.num_layers): - self.inputs["pre_caches_{}".format(i)] = self.pre_caches[i] + self.model_inputs["pre_caches_{}".format(i)] = self.pre_caches[i] - self.cache_kvs = {} - cachekv_dtype = config.dtype - if config.cachekv_int8_type is not None: - cachekv_dtype = "uint8" + cachekv_dtype = config.dtype if config.cachekv_int8_type is None else "uint8" for i in range(len(self.cache_kvs_shape) // 2): - self.cache_kvs["key_caches_{}".format(i)] = paddle.zeros( + 
self.model_inputs["key_caches_{}".format(i)] = paddle.zeros( self.cache_kvs_shape[2 * i], dtype=cachekv_dtype ) - self.cache_kvs["value_caches_{}".format(i)] = paddle.zeros( + self.model_inputs["value_caches_{}".format(i)] = paddle.zeros( self.cache_kvs_shape[2 * i + 1], dtype=cachekv_dtype ) for i in range(self.num_layers): if self.config.cachekv_int8_type == "dynamic": - self.inputs["k_quant_scales_" + str(i)] = self.k_quant_scales[i] - self.inputs["v_quant_scales_" + str(i)] = self.v_quant_scales[i] - self.inputs["k_dequant_scales_" + str(i)] = self.k_dequant_scales[i] - self.inputs["v_dequant_scales_" + str(i)] = self.v_dequant_scales[i] - - self._create_predictor(config) - self.input_names = self.predictor.get_input_names() - - self._share_data() - self.seq_lens_handle = self.predictor.get_input_handle("seq_lens_this_time") + self.model_inputs["k_quant_scales_" + str(i)] = self.k_quant_scales[i] + self.model_inputs["v_quant_scales_" + str(i)] = self.v_quant_scales[i] + self.model_inputs["k_dequant_scales_" + str(i)] = self.k_dequant_scales[i] + self.model_inputs["v_dequant_scales_" + str(i)] = self.v_dequant_scales[i] def _create_predictor(self, predictor_args: PredictorArgument): if not is_paddlenlp_ops_available(): @@ -1163,37 +1133,9 @@ def _create_predictor(self, predictor_args: PredictorArgument): self.predictor = paddle.inference.create_predictor(config) - def _share_data(self): - """ - Share external data for inference predictor. - """ - for name in self.input_names: - if "pre_key_" in name or "pre_value_" in name: - input_tensor = self.predictor.get_input_handle(name) - input_tensor.share_external_data(self.inputs[name]) - continue - if "caches" in name: - input_tensor = self.predictor.get_input_handle(name) - input_tensor.share_external_data(self.cache_kvs[name]) - continue - if "seq_lens_this_time" in name: - continue - input_tensor = self.predictor.get_input_handle(name) - input_tensor.share_external_data(self.inputs[name]) - - def _infer(self): - self.predictor.run() - - def predict(self, input_texts: str | list[str], return_tokens=False): - + def predict(self, input_texts: list[str], return_tokens=False): s_time = time.time() self._preprocess(input_texts) - real_bsz = len(input_texts) - - import copy - - seq_lens_this_time = copy.deepcopy(self.inputs["seq_lens_this_time"][:real_bsz]) - self.seq_lens_handle.share_external_data(seq_lens_this_time) logger.info(f"preprocess spend {time.time() - s_time}") result_queue = mp.Queue() @@ -1207,15 +1149,10 @@ def predict(self, input_texts: str | list[str], return_tokens=False): read_res_process.start() s_time = time.time() - while self.inputs["not_need_stop"]: - self.predictor.run() + while self.model_inputs["not_need_stop"]: + self.predictor.run(list(self.model_inputs.values())) logger.info(f"running spend {time.time() - s_time}") - # reset free_list - for i in range(self.config.batch_size): - self.free_list.extend(self.used_list[i]) - self.used_list[i] = [] - outputs = [] output_tokens = [] while len(outputs) < self.batch_size: @@ -1227,19 +1164,6 @@ def predict(self, input_texts: str | list[str], return_tokens=False): else: return outputs - def _preprocess(self, source): - BlockInferencePredictorMixin._preprocess(self, source) - for i, text in enumerate(source): - tokens = self.tokenizer( - text, return_tensors="np", padding=False, truncation=True, max_length=(self.config.src_length) - ) - input_ids = tokens["input_ids"][0] - length = len(input_ids) - need_block_nums = ( - length + self.config.max_length + self.pre_cache_length 
+ self.block_size - 1 - ) // self.block_size - self.inputs["encoder_block_lens"][i : i + 1] = need_block_nums - def get_ptq_multicards_num(directory): count = 0 @@ -1268,38 +1192,6 @@ def create_predictor( config = AutoConfig.from_pretrained(predictor_args.model_name_or_path) - max_position_embeddings = get_model_max_position_embeddings(config) - if max_position_embeddings is None: - max_position_embeddings = 2048 - logger.warning("Can not retrieval `max_position_embeddings` from config.json, use default value 2048") - - if predictor_args.src_length is None: - if predictor_args.max_length is None: - predictor_args.src_length = get_default_max_encoding_length(config) - predictor_args.max_length = get_default_max_decoding_length(config) - else: - predictor_args.src_length = max_position_embeddings - predictor_args.max_length - if predictor_args.src_length <= 0: - raise ValueError( - f"--max_length<{predictor_args.max_length}> param should be smaller " - f"than max_position_embeddings<{max_position_embeddings}>" - ) - else: - if predictor_args.max_length is None: - predictor_args.max_length = max_position_embeddings - predictor_args.src_length - if predictor_args.max_length <= 0: - raise ValueError( - f"--src_length<{predictor_args.src_length}> param should be smaller " - f"than max_position_embeddings<{max_position_embeddings}>" - ) - else: - if predictor_args.src_length + predictor_args.max_length > max_position_embeddings: - raise ValueError( - f"The sum of src_length<{predictor_args.src_length}> and " - f"max_length<{predictor_args.max_length}> should be smaller than or equal to " - f"the maximum position embedding size<{max_position_embeddings}>" - ) - # update config parameter for inference predictor if predictor_args.decode_strategy == "greedy_search": predictor_args.top_p = 0.0 @@ -1362,8 +1254,10 @@ def create_predictor( config.avx_type = predictor_args.avx_type if config.quantization_config.quant_type is not None: + predictor_args.quant_type = config.quantization_config.quant_type config.quant_type = config.quantization_config.quant_type if "c8" in config.quant_type: + predictor_args.cachekv_int8_type = "static" config.cachekv_int8_type = "static" ptq_multicards_num = get_ptq_multicards_num(config.model_name_or_path) @@ -1631,7 +1525,7 @@ def predict(): target_texts.append("") else: - source_texts = ["你好,请问你是谁?"] * predictor_args.batch_size + source_texts = ["解释一下温故而知新"] * predictor_args.batch_size target_texts = [""] * predictor_args.batch_size batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size) diff --git a/llm/utils/utils.py b/llm/utils/utils.py index a7839d79f5e1..7c5397d90cbd 100644 --- a/llm/utils/utils.py +++ b/llm/utils/utils.py @@ -34,7 +34,6 @@ AutoTokenizer, ChatGLMv2Tokenizer, LlamaForCausalLMPipe, - PretrainedConfig, Qwen2ForCausalLMPipe, ) from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer @@ -725,52 +724,6 @@ def init_chat_template( tokenizer.init_chat_template(chat_template_file) -def get_model_max_position_embeddings(config: PretrainedConfig) -> Optional[int]: - names = [ - "max_position_embeddings", # most of models - "max_sequence_length", # GLM model - "seq_length", # llama model - ] - for name in names: - max_length = config.get(name, None) - if max_length is not None: - return max_length - return None - - -def get_default_max_decoding_length(config: PretrainedConfig, default: int = 1024) -> int: - """get the default max decoding length from config. 
- - Args: - config (PretrainedConfig): the instance of PretrainedConfig - default (int): the default value of max decoding length - - Returns: - int: the default max_length of decoding length - """ - max_position_embeddings = get_model_max_position_embeddings(config) - if max_position_embeddings is None: - return default - return max_position_embeddings // 4 - - -def get_default_max_encoding_length(config: PretrainedConfig, default: int = 1024) -> int: - """get the default max encoding length from config. - - Args: - config (PretrainedConfig): the instance of PretrainedConfig - default (int): the default value of max encoding length - - Returns: - int: the default max_length of encoding length - """ - - max_position_embeddings = get_model_max_position_embeddings(config) - if max_position_embeddings is None: - return default - return max_position_embeddings // 4 * 3 - - def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Queue): tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index a1b47c070f6d..bfd425293bdb 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -349,6 +349,7 @@ def __init__(self, config: LlamaConfig): self.quant_type = config.quant_type self.rope_theta = config.rope_theta + self.use_neox = True self.use_weight_only = False if config.quant_type == "weight_only_int8": @@ -562,7 +563,7 @@ def __init__(self, config: LlamaConfig): cache_v_out_scale_attrs=cache_v_out_scale_attrs, epsilon=self.epsilon, norm_type="rmsnorm", - use_neox_rotary_style=True, + use_neox_rotary_style=self.use_neox, cachekv_int8_type=config.cachekv_int8_type, rank_id=config.tensor_parallel_rank, trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True), @@ -683,7 +684,7 @@ def forward( from paddlenlp_ops import fused_get_rotary_embedding new_rope = fused_get_rotary_embedding( - input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, True + input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, self.use_neox ) with dy2st_nocheck_guard_context(): @@ -1009,7 +1010,7 @@ def set_state_dict(self, state_dict): ) if self.config.cachekv_int8_type == "static": - cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_act_scales.json") + cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_scales.json") if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: cache_scale_json_path = os.path.join( self.quant_model_path, f"cachekv_act_scales_{self.config.tensor_parallel_rank}.json" diff --git a/paddlenlp/quantization/quantization_config.py b/paddlenlp/quantization/quantization_config.py index 0222aeb1ef3e..f5b04e188e15 100644 --- a/paddlenlp/quantization/quantization_config.py +++ b/paddlenlp/quantization/quantization_config.py @@ -64,9 +64,9 @@ def __init__( raise ValueError( f"weight_quantize_algo:{weight_quantize_algo} not in supported list ['weight_only_int8', 'weight_only_int4', 'llm.int8', 'a8w8', 'nf4', 'fp4']" ) - if quant_type is not None and quant_type not in ["weight_only_int8", "weight_only_int4", "a8w8"]: + if quant_type is not None and quant_type not in ["weight_only_int8", "weight_only_int4", "a8w8", "a8w8c8"]: raise ValueError( - f"quant_type:{quant_type} not in supported list ['weight_only_int8', 'weight_only_int4', 
'a8w8']" + f"quant_type:{quant_type} not in supported list ['weight_only_int8', 'weight_only_int4', 'a8w8', 'a8w8c8']" ) self.weight_quantize_algo = weight_quantize_algo self.quant_type = quant_type From 6f9d819a85903ca7b0a2241398d6c53383de70b9 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 6 Aug 2024 14:54:16 +0000 Subject: [PATCH 2/6] update --- llm/predict/export_model.py | 5 ++- llm/predict/predictor.py | 19 +++++----- paddlenlp/experimental/model_utils.py | 26 ++++++++++---- .../transformers/llama/modeling.py | 36 +++++++++---------- 4 files changed, 50 insertions(+), 36 deletions(-) diff --git a/llm/predict/export_model.py b/llm/predict/export_model.py index 91b3bc4cb7c3..83dcc371427e 100644 --- a/llm/predict/export_model.py +++ b/llm/predict/export_model.py @@ -61,7 +61,10 @@ def main(): }, ) predictor.model.config.save_pretrained(export_args.output_path) - predictor.generation_config.save_pretrained(export_args.output_path) + if predictor.generation_config is not None: + predictor.generation_config.save_pretrained(export_args.output_path) + else: + predictor.model.generation_config.save_pretrained(export_args.output_path) predictor.tokenizer.save_pretrained(export_args.output_path) generate_rank_mapping(os.path.join(export_args.output_path, "rank_mapping.csv")) diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 1841ca1eea58..10c634424bba 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -396,7 +396,7 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): self.architectures = self.model_config.architectures[0].lower() - self.dtype = config.dtype or self.model_config + self.dtype = config.dtype or self.model_config.dtype self.pre_ids = paddle.full([config.batch_size, config.total_max_length], -1, dtype="int64") if config.device == "cpu" and config.avx_model: @@ -408,7 +408,6 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): self.tgt_generation_mask = None self.tgt_pos = None else: - self.arange_tensor_encoder = paddle.arange(config.total_max_length, dtype=self.dtype) self.cache_kvs = [paddle.zeros(shape, dtype=self.dtype) for shape in self.cache_kvs_shape] self.num_layers, self.num_attention_heads, self.head_dim = ( len(self.cache_kvs), @@ -548,8 +547,8 @@ def _preprocess(self, source): # alibi encoder alibi_slopes = get_alibi_slopes(self.model_config.n_head) inputs["position_ids"] = paddle.to_tensor(alibi_slopes, dtype="float32") - - alibi = alibi_slopes[None, :, None, None] * self.arange_tensor_encoder + arange_tensor_encoder = paddle.arange(self.config.total_max_length, dtype=self.config.dtype) + alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder if self.model_config.tensor_parallel_degree > 1: block_size = self.model_config.n_head // self.model_config.tensor_parallel_degree @@ -773,8 +772,6 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): self.dtype = config.dtype or self.model_config.dtype - self.block_size = config.block_size - try: self.rope_theta = self.model_config.rope_theta except: @@ -883,19 +880,21 @@ def init_model_inputs(self, config: PredictorArgument): # bloom model needs src_mask and tgt_mask! 
if "bloom" in self.architectures: - lower_one_tril = paddle.tril(paddle.ones(shape=(config.total_max_length, config.total_max_length), dtype=self.dtype)) + lower_one_tril = paddle.tril( + paddle.ones(shape=(config.total_max_length, config.total_max_length), dtype=self.dtype) + ) lower_one_tril = lower_one_tril[None, None, :, :] - self.model_inputs["src_mask"] = lower_one_tril.tile([self.batch_size, 1, 1, 1]) + self.model_inputs["src_mask"] = lower_one_tril.tile([config.batch_size, 1, 1, 1]) self.model_inputs["tgt_mask"] = paddle.full( shape=[config.batch_size, 1, 1, config.total_max_length], fill_value=1, dtype=self.dtype ) arange_tensor_encoder = paddle.arange(config.total_max_length).astype(self.dtype) alibi_slopes = get_alibi_slopes(self.num_attention_heads) alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder - alibi_encoder = alibi.tile([self.batch_size, 1, config.total_max_length, 1]) + alibi_encoder = alibi.tile([config.batch_size, 1, config.total_max_length, 1]) alibi_decoder = alibi.tile( [ - self.batch_size, + config.batch_size, 1, 1, 1, diff --git a/paddlenlp/experimental/model_utils.py b/paddlenlp/experimental/model_utils.py index b5a43eebd387..b187bb3700e9 100644 --- a/paddlenlp/experimental/model_utils.py +++ b/paddlenlp/experimental/model_utils.py @@ -391,7 +391,12 @@ def __init__( class CacheScaleLoader: def __init__( - self, scale_json_file_path="cache_scales.json", key_map_dict=None, num_of_layers=None, num_heads=None + self, + scale_json_file_path="cache_scales.json", + key_map_dict=None, + num_of_layers=None, + num_heads=None, + num_key_value_heads=None, ): with open(scale_json_file_path) as json_file: self.scale_dict = json.load(json_file) @@ -402,12 +407,21 @@ def __init__( scale_type_out = "cache_k_out_scale" else: scale_type_out = "cache_v_out_scale" - self.scale[scale_type] = np.full([num_of_layers, num_heads], fill_value=-1.0) - self.scale[scale_type_out] = np.full([num_of_layers, num_heads], fill_value=-1.0) + self.scale[scale_type] = np.full([num_of_layers, num_key_value_heads], fill_value=-1.0) + self.scale[scale_type_out] = np.full([num_of_layers, num_key_value_heads], fill_value=-1.0) for i in range(num_of_layers): if key_template.replace("#", str(i)) in self.scale_dict.keys(): - self.scale[scale_type][i, :] = [ - 127.0 / num for num in self.scale_dict[key_template.replace("#", str(i))] + if num_heads != num_key_value_heads: + self.scale[scale_type][i, :] = [ + 127.0 / self.scale_dict[key_template.replace("#", str(i))][j] + for j in range(0, num_heads, num_heads // num_key_value_heads) + ] + else: + self.scale[scale_type][i, :] = [ + 127.0 / self.scale_dict[key_template.replace("#", str(i))][j] + for j in range(0, num_key_value_heads) + ] + self.scale[scale_type_out][i, :] = [ + 1.0 / self.scale[scale_type][i, j] for j in range(0, num_key_value_heads) ] - self.scale[scale_type_out][i, :] = [1.0 / self.scale[scale_type][i, j] for j in range(num_heads)] diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index bfd425293bdb..b90355a31afc 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -801,11 +801,9 @@ def set_state_dict(self, state_dict): concated_ffn1_weight = np.concatenate( [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1 ) - ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight) - qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight) + 
qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight).cast(paddle.get_default_dtype()) if self.use_weight_only: - qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight) qkv_weight_tensor = paddle.transpose(qkv_weight_tensor, perm=[1, 0]) qkv_quanted_weight_tensor, qkv_weight_scale_tensor = weight_quantize( qkv_weight_tensor, algo=self.quant_algo @@ -817,11 +815,11 @@ def set_state_dict(self, state_dict): paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8") ) else: - self.transformer_block.qkv_weights[idx].set_value( - qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype) - ) + self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor) - linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)]) + linear_weight_tensor = paddle.to_tensor( + state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)] + ).cast(paddle.get_default_dtype()) if self.use_weight_only: linear_quanted_weight_tensor, linear_weight_scale_tensor = weight_quantize( linear_weight_tensor, algo=self.quant_algo @@ -845,10 +843,9 @@ def set_state_dict(self, state_dict): ) ) else: - self.transformer_block.linear_weights[idx].set_value( - linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype) - ) + self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor) + ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight).cast(paddle.get_default_dtype()) if self.use_weight_only: ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize( ffn1_weight_tensor, algo=self.quant_algo @@ -865,11 +862,11 @@ def set_state_dict(self, state_dict): paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8") ) else: - self.transformer_block.ffn1_weights[idx].set_value( - ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype) - ) + self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor) - ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)]) + ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)]).cast( + paddle.get_default_dtype() + ) if self.use_weight_only: ffn2_quanted_weight_tensor, ffn2_weight_scale_tensor = weight_quantize( ffn2_weight_tensor, algo=self.quant_algo @@ -893,9 +890,7 @@ def set_state_dict(self, state_dict): ) ) else: - self.transformer_block.ffn2_weights[idx].set_value( - ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype) - ) + self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor) if "a8w8" in self.quant_type: if self.shift_smooth_all_linears: @@ -1019,7 +1014,8 @@ def set_state_dict(self, state_dict): cache_scale_json_path, cache_scale_map_dict, num_of_layers=self.config.num_hidden_layers, - num_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, + num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, + num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree, ) for k, v in cache_scales_loader.scale.items(): for i_layer, weight_scale in enumerate(v): @@ -1401,7 +1397,9 @@ def forward( @paddle.no_grad() def set_state_dict(self, state_dict): if "lm_head.weight" in state_dict: - self.lm_head.weight.set_value(state_dict["lm_head.weight"]) + self.lm_head.weight.set_value( + paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype) + ) self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) From 
244802a32b5058c76b161d2939d28b2a35f7121e Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 7 Aug 2024 09:14:12 +0000 Subject: [PATCH 3/6] update --- csrc/generation/encode_rotary_qk.cu | 6 +++--- llm/predict/predictor.py | 10 ++++++++-- llm/utils/utils.py | 7 +++---- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/csrc/generation/encode_rotary_qk.cu b/csrc/generation/encode_rotary_qk.cu index d5f55a172592..3c0860feb1c3 100644 --- a/csrc/generation/encode_rotary_qk.cu +++ b/csrc/generation/encode_rotary_qk.cu @@ -111,6 +111,7 @@ void LaunchRotaryQK(const paddle::Tensor& q, auto cu_stream = q.stream(); dim3 grid(batch_size, head_num, seq_len * rotary_emb_dims); + dim3 grid_k(batch_size, kv_head_num, seq_len * rotary_emb_dims); const int last_dim = dim_head / rotary_emb_dims; auto getBlockSize = [](int dim) { if (dim > 256) { @@ -148,7 +149,6 @@ void LaunchRotaryQK(const paddle::Tensor& q, head_num, seq_len * rotary_emb_dims, last_dim); - dim3 grid_k(batch_size, kv_head_num, seq_len * rotary_emb_dims); RotaryKernel<<>>( k_data, cos_emb, @@ -172,7 +172,7 @@ void LaunchRotaryQK(const paddle::Tensor& q, head_num, seq_len * rotary_emb_dims, last_dim); - NeoXRotaryKernel<<>>( + NeoXRotaryKernel<<>>( k_data, cos_emb, sin_emb, @@ -180,7 +180,7 @@ void LaunchRotaryQK(const paddle::Tensor& q, k_out_data, rotary_emb_dims, batch_size, - head_num, + kv_head_num, seq_len * rotary_emb_dims, last_dim); } diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 10c634424bba..16e2bb4ba825 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -1039,6 +1039,9 @@ def predict(self, input_texts: list[str], return_tokens=False): result = result_queue.get(timeout=1) outputs.append(result[-1]) output_tokens.append(result[-2]) + + read_res_process.terminate() + if return_tokens: return outputs, output_tokens else: @@ -1158,6 +1161,9 @@ def predict(self, input_texts: list[str], return_tokens=False): result = result_queue.get(timeout=1) outputs.append(result[-1]) output_tokens.append(result[-2]) + + read_res_process.terminate() + if return_tokens: return outputs, output_tokens else: @@ -1524,8 +1530,8 @@ def predict(): target_texts.append("") else: - source_texts = ["解释一下温故而知新"] * predictor_args.batch_size - target_texts = [""] * predictor_args.batch_size + source_texts = ["解释一下温故而知新", "解释一下温故而知新"] + target_texts = ["", ""] batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size) batch_target_texts = batchfy_text(target_texts, predictor_args.batch_size) diff --git a/llm/utils/utils.py b/llm/utils/utils.py index 7c5397d90cbd..098248437298 100644 --- a/llm/utils/utils.py +++ b/llm/utils/utils.py @@ -725,11 +725,10 @@ def init_chat_template( def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Queue): - tokenizer = AutoTokenizer.from_pretrained( - model_name_or_path, - ) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) paddle.device.set_device("cpu") + paddle.disable_static() outputs = [] output_tensor = tensor_queue.get(timeout=1) @@ -746,7 +745,7 @@ def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Q output_numpy = output_tensor[2 : bsz + 2].numpy() output_numpy[output_numpy == -1] = 2 outputs.append(output_numpy) - if output_tensor[0, 0] == -1: + if int(output_tensor[0, 0]) == -1: break output = np.concatenate(outputs, axis=1).tolist() seqs = tokenizer.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False) From 147f2c18fbf84675d49e030925dfaa16ae9028de Mon 
Sep 17 00:00:00 2001 From: yuanlehome Date: Thu, 8 Aug 2024 08:49:09 +0000 Subject: [PATCH 4/6] fix ci --- tests/llm/test_predictor.py | 72 ++++++++++++++++++++++++++++--------- 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/tests/llm/test_predictor.py b/tests/llm/test_predictor.py index 093e086bec3f..fe9642628ac5 100644 --- a/tests/llm/test_predictor.py +++ b/tests/llm/test_predictor.py @@ -62,9 +62,9 @@ def setUp(self) -> None: AutoTokenizer.from_pretrained(self.model_name_or_path).save_pretrained(self.output_dir) def test_predictor(self): - self.run_predictor({"inference_model": True}) + self.run_predictor({"inference_model": True, "src_length": 512, "max_length": 256}) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": False}) + self.run_predictor({"inference_model": False, "src_length": 512, "max_length": 256}) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) # compare the generation result of inference & dygraph model @@ -84,10 +84,14 @@ def test_predictor(self): self.assertGreaterEqual(count / len(result_0), 0.4) def test_flash_attention(self): - self.run_predictor({"inference_model": False, "use_flash_attention": False}) + self.run_predictor( + {"inference_model": False, "use_flash_attention": False, "src_length": 512, "max_length": 256} + ) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": False, "use_flash_attention": True}) + self.run_predictor( + {"inference_model": False, "use_flash_attention": True, "src_length": 512, "max_length": 256} + ) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) # compare the generation result of dygraph & flash attention model @@ -108,9 +112,11 @@ def test_flash_attention(self): self.assertEqual(full_match / len(result_0), 1.0) def test_wint8(self): - self.run_predictor({"inference_model": True, "quant_type": "weight_only_int8"}) + self.run_predictor( + {"inference_model": True, "quant_type": "weight_only_int8", "src_length": 512, "max_length": 256} + ) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": False}) + self.run_predictor({"inference_model": False, "src_length": 512, "max_length": 256}) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) assert len(result_0) == len(result_1) @@ -159,9 +165,25 @@ def download_precache_files(self): get_path_from_url_with_filelock(file_url, root_dir=self.output_dir) def test_predictor(self): - self.run_predictor({"inference_model": True, "export_precache": True, "prefix_path": self.output_dir}) + self.run_predictor( + { + "inference_model": True, + "export_precache": True, + "prefix_path": self.output_dir, + "src_length": 512, + "max_length": 256, + } + ) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": False, "export_precache": True, "prefix_path": self.output_dir}) + self.run_predictor( + { + "inference_model": False, + "export_precache": True, + "prefix_path": self.output_dir, + "src_length": 512, + "max_length": 256, + } + ) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) # compare the generation result of inference & dygraph model @@ -231,9 +253,9 @@ def setUp(self) -> None: AutoTokenizer.from_pretrained(self.model_name_or_path).save_pretrained(self.output_dir) def test_blha(self): - self.run_predictor({"inference_model": 
True, "block_attn": True}) + self.run_predictor({"inference_model": True, "block_attn": True, "src_length": 1024, "max_length": 48}) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": False}) + self.run_predictor({"inference_model": False, "src_length": 1024, "max_length": 48}) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) # compare the generation result of inference & dygraph model @@ -253,9 +275,19 @@ def test_blha(self): self.assertGreaterEqual(count / len(result_0), 0.4) def test_wint8(self): - self.run_predictor({"inference_model": True, "quant_type": "weight_only_int8", "block_attn": True}) + self.run_predictor( + { + "inference_model": True, + "quant_type": "weight_only_int8", + "block_attn": True, + "src_length": 512, + "max_length": 256, + } + ) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": True, "quant_type": "weight_only_int8"}) + self.run_predictor( + {"inference_model": True, "quant_type": "weight_only_int8", "src_length": 512, "max_length": 256} + ) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) assert len(result_0) == len(result_1) @@ -274,9 +306,17 @@ def test_wint8(self): self.assertGreaterEqual(count / len(result_0), 0.4) def test_cachekv_int8(self): - self.run_predictor({"inference_model": True, "block_attn": True, "cachekv_int8_type": "dynamic"}) + self.run_predictor( + { + "inference_model": True, + "block_attn": True, + "cachekv_int8_type": "dynamic", + "src_length": 512, + "max_length": 256, + } + ) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": True, "block_attn": True}) + self.run_predictor({"inference_model": True, "block_attn": True, "src_length": 512, "max_length": 256}) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) print(f"result_0 {result_0}, result_1 {result_1}") @@ -311,9 +351,9 @@ def setUp(self) -> None: def test_predictor(self): self.init_dist_env() - self.run_predictor({"inference_model": True}) + self.run_predictor({"inference_model": True, "src_length": 512, "max_length": 256}) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": False}) + self.run_predictor({"inference_model": False, "src_length": 512, "max_length": 256}) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) # compare the generation result of inference & dygraph model From 0946ab9f5466fcf6501b6fcc7a62d7abbda186ea Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Thu, 8 Aug 2024 12:29:55 +0000 Subject: [PATCH 5/6] fix ci --- llm/predict/predictor.py | 23 +++++++++--- llm/utils/utils.py | 13 +++++++ tests/llm/test_predictor.py | 72 ++++++++++--------------------------- 3 files changed, 50 insertions(+), 58 deletions(-) diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index 16e2bb4ba825..536f8ed1861a 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -32,6 +32,7 @@ dybatch_preprocess, get_alibi_slopes, get_infer_model_path, + get_model_max_position_embeddings, get_prefix_tuning_params, init_chat_template, load_real_time_tokens, @@ -64,9 +65,9 @@ class PredictorArgument: model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."}) model_prefix: str = field(default="model", metadata={"help": "the prefix name of static model"}) - src_length: int = 
field(default=4096, metadata={"help": "The max length of source text."}) + src_length: int = field(default=1024, metadata={"help": "The max length of source text."}) min_length: int = field(default=1, metadata={"help": "the min length for decoding."}) - max_length: int = field(default=2048, metadata={"help": "the max length for decoding."}) + max_length: int = field(default=1024, metadata={"help": "the max length for decoding."}) top_k: int = field(default=0, metadata={"help": "top_k parameter for generation"}) top_p: float = field(default=0.7, metadata={"help": "top_p parameter for generation"}) temperature: float = field(default=0.95, metadata={"help": "top_p parameter for generation"}) @@ -1197,6 +1198,20 @@ def create_predictor( config = AutoConfig.from_pretrained(predictor_args.model_name_or_path) + max_position_embeddings = get_model_max_position_embeddings(config) + if max_position_embeddings is None: + max_position_embeddings = predictor_args.src_length + predictor_args.max_length + logger.warning( + f"Can not retrieval `max_position_embeddings` from config.json, use default value {max_position_embeddings}" + ) + else: + if predictor_args.src_length + predictor_args.max_length > max_position_embeddings: + raise ValueError( + f"The sum of src_length<{predictor_args.src_length}> and " + f"max_length<{predictor_args.max_length}> should be smaller than or equal to " + f"the maximum position embedding size<{max_position_embeddings}>" + ) + # update config parameter for inference predictor if predictor_args.decode_strategy == "greedy_search": predictor_args.top_p = 0.0 @@ -1530,8 +1545,8 @@ def predict(): target_texts.append("") else: - source_texts = ["解释一下温故而知新", "解释一下温故而知新"] - target_texts = ["", ""] + source_texts = ["解释一下温故而知新"] * predictor_args.batch_size + target_texts = [""] * predictor_args.batch_size batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size) batch_target_texts = batchfy_text(target_texts, predictor_args.batch_size) diff --git a/llm/utils/utils.py b/llm/utils/utils.py index 098248437298..c784783bc898 100644 --- a/llm/utils/utils.py +++ b/llm/utils/utils.py @@ -724,6 +724,19 @@ def init_chat_template( tokenizer.init_chat_template(chat_template_file) +def get_model_max_position_embeddings(config: PretrainedConfig) -> Optional[int]: + names = [ + "max_position_embeddings", # most of models + "max_sequence_length", # GLM model + "seq_length", # llama model + ] + for name in names: + max_length = config.get(name, None) + if max_length is not None: + return max_length + return None + + def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Queue): tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) diff --git a/tests/llm/test_predictor.py b/tests/llm/test_predictor.py index fe9642628ac5..9354b1f54322 100644 --- a/tests/llm/test_predictor.py +++ b/tests/llm/test_predictor.py @@ -17,7 +17,6 @@ import unittest import paddle -import pytest from parameterized import parameterized_class from paddlenlp.experimental.transformers import QWenForQWenVLInferenceModel @@ -62,9 +61,9 @@ def setUp(self) -> None: AutoTokenizer.from_pretrained(self.model_name_or_path).save_pretrained(self.output_dir) def test_predictor(self): - self.run_predictor({"inference_model": True, "src_length": 512, "max_length": 256}) + self.run_predictor({"inference_model": True, "src_length": 512, "max_length": 48}) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": False, "src_length": 512, 
"max_length": 256}) + self.run_predictor({"inference_model": False, "src_length": 512, "max_length": 48}) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) # compare the generation result of inference & dygraph model @@ -85,12 +84,12 @@ def test_predictor(self): def test_flash_attention(self): self.run_predictor( - {"inference_model": False, "use_flash_attention": False, "src_length": 512, "max_length": 256} + {"inference_model": False, "use_flash_attention": False, "src_length": 512, "max_length": 48} ) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) self.run_predictor( - {"inference_model": False, "use_flash_attention": True, "src_length": 512, "max_length": 256} + {"inference_model": False, "use_flash_attention": True, "src_length": 512, "max_length": 48} ) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) @@ -113,10 +112,10 @@ def test_flash_attention(self): def test_wint8(self): self.run_predictor( - {"inference_model": True, "quant_type": "weight_only_int8", "src_length": 512, "max_length": 256} + {"inference_model": True, "quant_type": "weight_only_int8", "src_length": 512, "max_length": 48} ) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": False, "src_length": 512, "max_length": 256}) + self.run_predictor({"inference_model": False, "src_length": 512, "max_length": 48}) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) assert len(result_0) == len(result_1) @@ -171,7 +170,7 @@ def test_predictor(self): "export_precache": True, "prefix_path": self.output_dir, "src_length": 512, - "max_length": 256, + "max_length": 48, } ) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) @@ -181,7 +180,7 @@ def test_predictor(self): "export_precache": True, "prefix_path": self.output_dir, "src_length": 512, - "max_length": 256, + "max_length": 48, } ) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) @@ -199,41 +198,6 @@ def test_predictor(self): self.assertGreaterEqual(count / len(result_0), 0.8) -class PredictorBaseTest(LLMTest, unittest.TestCase): - def load_test_config(self): - config = load_test_config("./tests/fixtures/llm/predictor.yaml", "inference-predict") - config["model_name_or_path"] = "__internal_testing__/micro-random-llama" - - return config - - def test_create_predictor_with_unexpected_length(self): - from predict.predictor import predict - - config = self.load_test_config() - config.pop("src_length", None) - config.pop("max_length", None) - - with pytest.raises(ValueError, match="--src_length<2048> param should be smaller "): - config["src_length"] = 2048 - - with argv_context_guard(config): - predict() - - with pytest.raises(ValueError, match="--max_length<2048> param should be smaller "): - config.pop("src_length", None) - config["max_length"] = 2048 - - with argv_context_guard(config): - predict() - - with pytest.raises(ValueError, match="The sum of src_length<1025> and"): - config["max_length"] = 1024 - config["src_length"] = 1025 - - with argv_context_guard(config): - predict() - - @parameterized_class( ["model_name_or_path", "model_class"], [ @@ -253,9 +217,9 @@ def setUp(self) -> None: AutoTokenizer.from_pretrained(self.model_name_or_path).save_pretrained(self.output_dir) def test_blha(self): - self.run_predictor({"inference_model": True, "block_attn": True, "src_length": 1024, "max_length": 48}) + self.run_predictor({"inference_model": True, 
"block_attn": True, "src_length": 512, "max_length": 48}) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": False, "src_length": 1024, "max_length": 48}) + self.run_predictor({"inference_model": False, "src_length": 512, "max_length": 48}) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) # compare the generation result of inference & dygraph model @@ -281,12 +245,12 @@ def test_wint8(self): "quant_type": "weight_only_int8", "block_attn": True, "src_length": 512, - "max_length": 256, + "max_length": 48, } ) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) self.run_predictor( - {"inference_model": True, "quant_type": "weight_only_int8", "src_length": 512, "max_length": 256} + {"inference_model": True, "quant_type": "weight_only_int8", "src_length": 512, "max_length": 48} ) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) @@ -298,7 +262,7 @@ def test_wint8(self): count += int(inference_item[: min_length // 2] == no_inference_item[: min_length // 2]) full_match += int(inference_item[:min_length] == no_inference_item[:min_length]) - self.assertGreaterEqual(full_match / len(result_0), 0.6) + self.assertGreaterEqual(full_match / len(result_0), 0.55) if self.model_name_or_path == "__internal_testing__/tiny-fused-chatglm": self.assertGreaterEqual(count / len(result_0), 0.3) @@ -312,11 +276,11 @@ def test_cachekv_int8(self): "block_attn": True, "cachekv_int8_type": "dynamic", "src_length": 512, - "max_length": 256, + "max_length": 48, } ) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": True, "block_attn": True, "src_length": 512, "max_length": 256}) + self.run_predictor({"inference_model": True, "block_attn": True, "src_length": 512, "max_length": 48}) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) print(f"result_0 {result_0}, result_1 {result_1}") @@ -328,7 +292,7 @@ def test_cachekv_int8(self): count += int(inference_item[: min_length // 2] == no_inference_item[: min_length // 2]) full_match += int(inference_item[:min_length] == no_inference_item[:min_length]) - self.assertGreaterEqual(count / len(result_0), 0.15) + self.assertGreaterEqual(count / len(result_0), 0.1) @parameterized_class( @@ -351,9 +315,9 @@ def setUp(self) -> None: def test_predictor(self): self.init_dist_env() - self.run_predictor({"inference_model": True, "src_length": 512, "max_length": 256}) + self.run_predictor({"inference_model": True, "src_length": 512, "max_length": 48}) result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) - self.run_predictor({"inference_model": False, "src_length": 512, "max_length": 256}) + self.run_predictor({"inference_model": False, "src_length": 512, "max_length": 48}) result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) # compare the generation result of inference & dygraph model From 88ea827edcd190e4d98ce1c1d1b53e0e0d8be1c9 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Fri, 9 Aug 2024 08:05:59 +0000 Subject: [PATCH 6/6] fix ut --- tests/fixtures/llm/finetune.yaml | 1 + tests/fixtures/llm/lora.yaml | 1 + tests/fixtures/llm/predictor.yaml | 1 + tests/fixtures/llm/prefix_tuning.yaml | 1 + tests/fixtures/llm/pretrain.yaml | 4 ++++ tests/fixtures/llm/ptq.yaml | 1 + tests/fixtures/llm/vera.yaml | 1 + 7 files changed, 10 insertions(+) diff --git a/tests/fixtures/llm/finetune.yaml b/tests/fixtures/llm/finetune.yaml 
index 08a8fe4a13b1..7e79f9b441a8 100644 --- a/tests/fixtures/llm/finetune.yaml +++ b/tests/fixtures/llm/finetune.yaml @@ -55,6 +55,7 @@ inference-predict: inference-to-static: default: dtype: float16 + max_length: 20 inference-infer: default: diff --git a/tests/fixtures/llm/lora.yaml b/tests/fixtures/llm/lora.yaml index 3c884239614d..5d75cb752682 100644 --- a/tests/fixtures/llm/lora.yaml +++ b/tests/fixtures/llm/lora.yaml @@ -101,6 +101,7 @@ inference-predict: inference-to-static: default: dtype: float16 + max_length: 20 inference-infer: default: diff --git a/tests/fixtures/llm/predictor.yaml b/tests/fixtures/llm/predictor.yaml index 5c08c8f28f83..c59a658a896a 100644 --- a/tests/fixtures/llm/predictor.yaml +++ b/tests/fixtures/llm/predictor.yaml @@ -11,6 +11,7 @@ inference-predict: inference-to-static: default: dtype: float16 + max_length: 40 inference-infer: default: diff --git a/tests/fixtures/llm/prefix_tuning.yaml b/tests/fixtures/llm/prefix_tuning.yaml index de9740a705a8..5de1244030b6 100644 --- a/tests/fixtures/llm/prefix_tuning.yaml +++ b/tests/fixtures/llm/prefix_tuning.yaml @@ -58,6 +58,7 @@ inference-to-static: default: dtype: float16 export_precache: true + max_length: 20 inference-infer: default: diff --git a/tests/fixtures/llm/pretrain.yaml b/tests/fixtures/llm/pretrain.yaml index 1210a7b29cb0..1b4598b1024d 100644 --- a/tests/fixtures/llm/pretrain.yaml +++ b/tests/fixtures/llm/pretrain.yaml @@ -34,6 +34,7 @@ pretrain: inference-predict: default: mode: dynamic + src_length: 512 max_length: 20 batch_size: 2 decode_strategy: greedy_search @@ -43,6 +44,8 @@ inference-predict: inference-to-static: default: dtype: float16 + src_length: 512 + max_length: 20 inference-infer: default: @@ -50,5 +53,6 @@ inference-infer: dtype: float16 batch_size: 2 decode_strategy: greedy_search + src_length: 512 max_length: 20 chat_template: none diff --git a/tests/fixtures/llm/ptq.yaml b/tests/fixtures/llm/ptq.yaml index ee9ae8e56602..ad71c11cb686 100644 --- a/tests/fixtures/llm/ptq.yaml +++ b/tests/fixtures/llm/ptq.yaml @@ -34,6 +34,7 @@ inference-predict: inference-to-static: default: dtype: float16 + max_length: 20 inference-infer: diff --git a/tests/fixtures/llm/vera.yaml b/tests/fixtures/llm/vera.yaml index 72752ddffa91..b988fab4fb7d 100644 --- a/tests/fixtures/llm/vera.yaml +++ b/tests/fixtures/llm/vera.yaml @@ -56,6 +56,7 @@ inference-predict: inference-to-static: default: dtype: float16 + max_length: 20 inference-infer: default:
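
Patch 5/6 drops the PredictorBaseTest that expected per-argument "--src_length<...> param should be smaller" / "--max_length<...> param should be smaller" errors and instead validates the combined length in create_predictor(): src_length + max_length must not exceed the model's maximum position embedding size, which the new get_model_max_position_embeddings() helper reads from the config (falling back to src_length + max_length with a warning when no such field exists). The defaults also drop to src_length=1024, max_length=1024. The sketch below reproduces that check in isolation so it can be run outside PaddleNLP; the dict-based config and the PredictorArgs dataclass are stand-ins for PretrainedConfig and PredictorArgument, not the real classes.

    from dataclasses import dataclass
    from typing import Optional


    def get_model_max_position_embeddings(config: dict) -> Optional[int]:
        # Mirrors the helper added to llm/utils/utils.py in patch 5/6: probe the
        # config keys used by different model families and return the first hit.
        for name in ("max_position_embeddings", "max_sequence_length", "seq_length"):
            max_length = config.get(name, None)
            if max_length is not None:
                return max_length
        return None


    @dataclass
    class PredictorArgs:
        # Stand-in for PredictorArgument; 1024/1024 are the new defaults from patch 5/6.
        src_length: int = 1024
        max_length: int = 1024


    def check_lengths(args: PredictorArgs, config: dict) -> int:
        max_position_embeddings = get_model_max_position_embeddings(config)
        if max_position_embeddings is None:
            # create_predictor() logs a warning here and falls back to the sum.
            return args.src_length + args.max_length
        if args.src_length + args.max_length > max_position_embeddings:
            raise ValueError(
                f"The sum of src_length<{args.src_length}> and "
                f"max_length<{args.max_length}> should be smaller than or equal to "
                f"the maximum position embedding size<{max_position_embeddings}>"
            )
        return max_position_embeddings


    if __name__ == "__main__":
        llama_like = {"max_position_embeddings": 2048}
        print(check_lengths(PredictorArgs(), llama_like))            # 2048: 1024 + 1024 just fits
        try:
            check_lengths(PredictorArgs(src_length=1536), llama_like)
        except ValueError as err:
            print(err)                                               # 1536 + 1024 exceeds 2048

With a 2048-position model the new 1024/1024 defaults sit exactly at the limit, which is presumably why patch 6/6 pins explicit, smaller src_length/max_length values in the test fixtures.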
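
The test changes in patch 5/6 also shorten generation (max_length: 48 in most cases) and relax two agreement thresholds (full_match 0.6 -> 0.55 for the block-attention wint8 test, count 0.15 -> 0.1 for cachekv int8). Those ratios come from the prefix-comparison loop repeated throughout tests/llm/test_predictor.py; the helper below is a distilled, hypothetical version of that inline loop (the tests do not define such a function), shown only to make explicit what "count" and "full_match" measure.

    from typing import List, Tuple


    def prefix_agreement(result_0: List[str], result_1: List[str]) -> Tuple[float, float]:
        # Distilled from the inline loops in tests/llm/test_predictor.py: an item
        # "half matches" when the first min_length // 2 characters agree, and
        # "fully matches" when the whole shorter string is a common prefix.
        assert len(result_0) == len(result_1)
        count = full_match = 0
        for inference_item, no_inference_item in zip(result_0, result_1):
            min_length = min(len(inference_item), len(no_inference_item))
            count += int(inference_item[: min_length // 2] == no_inference_item[: min_length // 2])
            full_match += int(inference_item[:min_length] == no_inference_item[:min_length])
        return count / len(result_0), full_match / len(result_0)


    if __name__ == "__main__":
        half, full = prefix_agreement(
            ["温故而知新的意思是...", "学而时习之"],
            ["温故而知新的含义是...", "学而时习之"],
        )
        print(half, full)  # 1.0 0.5: both first halves agree, only one full prefix does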