[LLM Inference] Refactor BlockInferencePredictor #8879
Changes from all commits: e410e4a, 6f9d819, 244802a, bc75f59, 147f2c1, 0946ab9, 88ea827
```diff
@@ -103,7 +103,7 @@ std::vector<paddle::DataType> GetPaddingOffsetV2InferDtype(const paddle::DataTyp
 }
 
 PD_BUILD_OP(get_padding_offset_v2)
-    .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
+    .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
     .Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"})
    .SetKernelFn(PD_KERNEL(GetPaddingOffsetV2))
    .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape))
```

Review discussion:
- "Fixes the mixed-up order of the operator's input names."
- "Curious how this managed to run correctly before."
- "It only affected the name order; the tensor order was not wrong."
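For readers puzzled by that exchange: a Paddle custom op's `.Inputs({...})` list attaches names to argument positions, while the call site and the kernel binding are positional, which is why the mislabeled names did not break execution. A rough sketch of what a call site could look like (the import path and variable names are assumptions for illustration, not code from this PR):

```python
# Hypothetical call site for the custom op. Tensors are passed positionally in
# the order the CUDA kernel expects, so the old, mislabeled .Inputs() list did
# not change which tensor reached which kernel argument -- only the labels.
from paddlenlp_ops import get_padding_offset_v2  # assumed import path

x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2(
    input_ids,    # [bs, max_len] padded token ids
    cum_offsets,  # cumulative padding offsets per sequence
    token_num,    # total number of real (non-padding) tokens
    seq_lens,     # per-sequence lengths
)
```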
```diff
@@ -61,7 +61,11 @@ def main():
         },
     )
     predictor.model.config.save_pretrained(export_args.output_path)
-    predictor.model.generation_config.save_pretrained(export_args.output_path)
+    if predictor.generation_config is not None:
+        predictor.generation_config.save_pretrained(export_args.output_path)
+    else:
+        predictor.model.generation_config.save_pretrained(export_args.output_path)
 
     predictor.tokenizer.save_pretrained(export_args.output_path)
     generate_rank_mapping(os.path.join(export_args.output_path, "rank_mapping.csv"))
```

Review comment:
- "Fixes a bug where generation_config.json was saved incorrectly."
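The hunk above prefers the predictor-level generation config (which may have been customized for export) and only falls back to the one bundled with the model. A standalone sketch of the same pattern, assuming PaddleNLP's `GenerationConfig` API; the helper name is made up for illustration:

```python
from paddlenlp.generation import GenerationConfig  # assumed import path


def save_generation_config(predictor, output_path: str) -> None:
    """Persist the generation config that inference will actually use."""
    if getattr(predictor, "generation_config", None) is not None:
        # The predictor-level config takes priority; it may override the checkpoint's.
        predictor.generation_config.save_pretrained(output_path)
    else:
        # Fall back to the config shipped with the model checkpoint.
        predictor.model.generation_config.save_pretrained(output_path)
```

Either branch writes a generation_config.json under output_path that can later be reloaded, e.g. with GenerationConfig.from_pretrained(output_path).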
Large diffs are not rendered by default.
```diff
@@ -349,6 +349,7 @@
         self.quant_type = config.quant_type
 
         self.rope_theta = config.rope_theta
+        self.use_neox = True
 
         self.use_weight_only = False
         if config.quant_type == "weight_only_int8":
@@ -562,7 +563,7 @@
             cache_v_out_scale_attrs=cache_v_out_scale_attrs,
             epsilon=self.epsilon,
             norm_type="rmsnorm",
-            use_neox_rotary_style=True,
+            use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,
             rank_id=config.tensor_parallel_rank,
             trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
@@ -683,7 +684,7 @@
         from paddlenlp_ops import fused_get_rotary_embedding
 
         new_rope = fused_get_rotary_embedding(
-            input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, True
+            input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, self.use_neox
        )
 
         with dy2st_nocheck_guard_context():
```
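For context on why `use_neox` is now threaded through instead of a hardcoded `True`: rotary embeddings come in two dimension-pairing layouts, the NeoX style (first half paired with second half) and the interleaved GPT-J style, and both derive their rotation frequencies from `rope_theta`. The NumPy sketch below only illustrates the two layouts; it is not the fused `fused_get_rotary_embedding` kernel, and the function names are invented for illustration:

```python
import numpy as np


def rotary_cos_sin(positions, head_dim, rope_theta=10000.0):
    # Frequencies derived from rope_theta; shapes are [seq_len, head_dim // 2].
    inv_freq = 1.0 / (rope_theta ** (np.arange(0, head_dim, 2) / head_dim))
    angles = np.outer(positions, inv_freq)
    return np.cos(angles), np.sin(angles)


def apply_rotary(x, cos, sin, use_neox=True):
    # x: [seq_len, head_dim]
    if use_neox:
        # NeoX style: rotate the first half of the dims against the second half.
        x1, x2 = np.split(x, 2, axis=-1)
        return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    # Interleaved (GPT-J) style: rotate adjacent even/odd dimension pairs.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```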
```diff
@@ -800,11 +801,9 @@
                 concated_ffn1_weight = np.concatenate(
                     [unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
                 )
-                ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight)
 
-                qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
+                qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight).cast(paddle.get_default_dtype())
                 if self.use_weight_only:
-                    qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
                     qkv_weight_tensor = paddle.transpose(qkv_weight_tensor, perm=[1, 0])
                     qkv_quanted_weight_tensor, qkv_weight_scale_tensor = weight_quantize(
                         qkv_weight_tensor, algo=self.quant_algo
@@ -816,11 +815,11 @@
                         paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8")
                     )
                 else:
-                    self.transformer_block.qkv_weights[idx].set_value(
-                        qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype)
-                    )
+                    self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor)
 
-                linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)])
+                linear_weight_tensor = paddle.to_tensor(
+                    state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)]
+                ).cast(paddle.get_default_dtype())
                 if self.use_weight_only:
                     linear_quanted_weight_tensor, linear_weight_scale_tensor = weight_quantize(
                         linear_weight_tensor, algo=self.quant_algo
```

Review discussion:
- "Why cast to the default dtype?"
- "So that both bf16 and fp16 can run with the same code."
- "Same question." (on the `linear_weight_tensor` cast)
- "Same answer."
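On the review question about casting to `paddle.get_default_dtype()`: checkpoint weights typically arrive as float32 numpy arrays, and casting them to whatever default dtype the process was configured with lets one loading path serve both bf16 and fp16 deployments. A minimal standalone sketch of the idea (shapes and values are illustrative):

```python
import numpy as np
import paddle

# Pick the runtime dtype once at startup; recent Paddle versions also accept "bfloat16" here.
paddle.set_default_dtype("float16")

# Checkpoint weights are commonly stored as float32 numpy arrays.
w_fp32 = np.random.rand(8, 8).astype("float32")

# Casting to the default dtype lets the same loading code serve fp16 and bf16 runs,
# so the per-weight .cast(param.dtype) calls in the old code are no longer needed.
w = paddle.to_tensor(w_fp32).cast(paddle.get_default_dtype())
print(w.dtype)  # paddle.float16 with the setting above
```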
```diff
@@ -844,10 +843,9 @@
                         )
                     )
                 else:
-                    self.transformer_block.linear_weights[idx].set_value(
-                        linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype)
-                    )
+                    self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor)
 
+                ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight).cast(paddle.get_default_dtype())
                 if self.use_weight_only:
                     ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
                         ffn1_weight_tensor, algo=self.quant_algo
@@ -864,11 +862,11 @@
                         paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8")
                     )
                 else:
-                    self.transformer_block.ffn1_weights[idx].set_value(
-                        ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype)
-                    )
+                    self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor)
 
-                ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)])
+                ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)]).cast(
+                    paddle.get_default_dtype()
+                )
                 if self.use_weight_only:
                     ffn2_quanted_weight_tensor, ffn2_weight_scale_tensor = weight_quantize(
                         ffn2_weight_tensor, algo=self.quant_algo
@@ -892,9 +890,7 @@
                         )
                     )
                 else:
-                    self.transformer_block.ffn2_weights[idx].set_value(
-                        ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype)
-                    )
+                    self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor)
 
             if "a8w8" in self.quant_type:
                 if self.shift_smooth_all_linears:
```
```diff
@@ -1009,7 +1005,7 @@
             )
 
         if self.config.cachekv_int8_type == "static":
-            cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_act_scales.json")
+            cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_scales.json")
            if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
                cache_scale_json_path = os.path.join(
                    self.quant_model_path, f"cachekv_act_scales_{self.config.tensor_parallel_rank}.json"
@@ -1018,7 +1014,8 @@
                 cache_scale_json_path,
                 cache_scale_map_dict,
                 num_of_layers=self.config.num_hidden_layers,
-                num_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
+                num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
             )
             for k, v in cache_scales_loader.scale.items():
                 for i_layer, weight_scale in enumerate(v):
```
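The hunk above stops conflating query heads with key/value heads when loading cache scales, which is what grouped-query attention (GQA) requires, and it connects to the closing review note about the GQA kernel fix. A small arithmetic sketch with made-up example numbers (not values from this PR):

```python
# Example GQA layout: 32 query heads sharing 8 key/value heads, split across 4 TP ranks.
num_attention_heads = 32
num_key_value_heads = 8
tensor_parallel_degree = 4

heads_per_rank = num_attention_heads // tensor_parallel_degree      # 8 query heads per rank
kv_heads_per_rank = num_key_value_heads // tensor_parallel_degree   # 2 KV heads per rank
group_size = num_attention_heads // num_key_value_heads             # 4 query heads per KV head

# Cache K/V scales are indexed by KV head, not by query head, so the scale loader
# needs both counts instead of assuming num_heads == num_key_value_heads.
assert heads_per_rank == kv_heads_per_rank * group_size
```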
```diff
@@ -1400,7 +1397,9 @@
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         if "lm_head.weight" in state_dict:
-            self.lm_head.weight.set_value(state_dict["lm_head.weight"])
+            self.lm_head.weight.set_value(
+                paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype)
+            )
         self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
```
```diff
@@ -34,6 +34,7 @@ inference-predict:
 inference-to-static:
   default:
     dtype: float16
+    max_length: 20
 
 
 inference-infer:
```
Review comment (on a file whose diff is not rendered above):
- "Fixes a GQA kernel computation bug."