[LLM Inference] Refactor BlockInferencePredictor #8879

Merged: 7 commits, Aug 12, 2024
6 changes: 3 additions & 3 deletions csrc/generation/encode_rotary_qk.cu
@@ -111,6 +111,7 @@ void LaunchRotaryQK(const paddle::Tensor& q,

auto cu_stream = q.stream();
dim3 grid(batch_size, head_num, seq_len * rotary_emb_dims);
dim3 grid_k(batch_size, kv_head_num, seq_len * rotary_emb_dims);
const int last_dim = dim_head / rotary_emb_dims;
auto getBlockSize = [](int dim) {
if (dim > 256) {
@@ -148,7 +149,6 @@ void LaunchRotaryQK(const paddle::Tensor& q,
head_num,
seq_len * rotary_emb_dims,
last_dim);
dim3 grid_k(batch_size, kv_head_num, seq_len * rotary_emb_dims);
RotaryKernel<<<grid_k, BlockSize, 0, cu_stream>>>(
k_data,
cos_emb,
@@ -172,15 +172,15 @@ void LaunchRotaryQK(const paddle::Tensor& q,
head_num,
seq_len * rotary_emb_dims,
last_dim);
NeoXRotaryKernel<<<grid, BlockSize, 0, cu_stream>>>(
NeoXRotaryKernel<<<grid_k, BlockSize, 0, cu_stream>>>(
Collaborator (author): Fix a GQA kernel computation bug.

k_data,
cos_emb,
sin_emb,
seq_lens.data<int>()/*sequence_lengths*/,
k_out_data,
rotary_emb_dims,
batch_size,
head_num,
kv_head_num,
seq_len * rotary_emb_dims,
last_dim);
}
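
Note: the fix above sizes the K rotary launch with grid_k (kv_head_num) instead of grid (head_num), because under grouped-query attention K/V carry fewer heads than Q. The NumPy sketch below is an illustration of that shape difference using one common NeoX-style rotary formulation; it is not the CUDA kernel itself.

```python
import numpy as np

def neox_rotary(x, cos, sin):
    """Apply a NeoX-style rotary embedding to x of shape [bsz, heads, seq, dim]."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

bsz, head_num, kv_head_num, seq_len, dim_head = 2, 8, 2, 16, 64
cos = np.random.rand(seq_len, dim_head // 2)
sin = np.random.rand(seq_len, dim_head // 2)

q = np.random.rand(bsz, head_num, seq_len, dim_head)
k = np.random.rand(bsz, kv_head_num, seq_len, dim_head)  # fewer KV heads under GQA

q_out = neox_rotary(q, cos, sin)  # launch sized by head_num
k_out = neox_rotary(k, cos, sin)  # launch must be sized by kv_head_num, not head_num
```
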
2 changes: 1 addition & 1 deletion csrc/generation/get_padding_offset_v2.cu
@@ -103,7 +103,7 @@ std::vector<paddle::DataType> GetPaddingOffsetV2InferDtype(const paddle::DataTyp
}

PD_BUILD_OP(get_padding_offset_v2)
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
Collaborator (author): Fix the misordered operator input names.

Contributor: Curious how this managed to run correctly before.

Collaborator (author): It only affected the registered name order; the tensor order itself was not wrong.

.Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"})
.SetKernelFn(PD_KERNEL(GetPaddingOffsetV2))
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape))
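
Note: a hypothetical Python illustration of the exchange above; it is not the actual Paddle custom-op binding. The `.Inputs(...)` list only attaches names to positional slots, so swapping "token_num" and "cum_offsets" mislabeled the slots without changing which tensor reached the kernel.

```python
def get_padding_offset_v2(input_ids, cum_offsets, token_num, seq_len):
    # Stand-in for the custom op: the underlying kernel consumes its
    # arguments by position, not by the registered names.
    return {"input_ids": input_ids, "cum_offsets": cum_offsets,
            "token_num": token_num, "seq_len": seq_len}

# Tensors are passed positionally, so the .Inputs(...) name list is metadata
# only; a swapped name order mislabels the slots but does not reroute tensors.
out = get_padding_offset_v2("ids", "offsets", "num_tokens", "seq_lens")
print(out["cum_offsets"])  # prints "offsets": position, not name, decided the binding
```
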
6 changes: 5 additions & 1 deletion llm/predict/export_model.py
@@ -61,7 +61,11 @@ def main():
},
)
predictor.model.config.save_pretrained(export_args.output_path)
predictor.model.generation_config.save_pretrained(export_args.output_path)
if predictor.generation_config is not None:
Collaborator (author): Fix a bug where generation_config.json was not saved correctly.

predictor.generation_config.save_pretrained(export_args.output_path)
else:
predictor.model.generation_config.save_pretrained(export_args.output_path)

predictor.tokenizer.save_pretrained(export_args.output_path)
generate_rank_mapping(os.path.join(export_args.output_path, "rank_mapping.csv"))
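
Note: a minimal sketch of the fallback applied above, using stand-in objects rather than the real predictor and model classes: a predictor-level generation_config, when present, takes priority over the model default, so the exported generation_config.json matches the runtime generation settings.

```python
from types import SimpleNamespace

def save_generation_config(predictor, output_path):
    config = predictor.generation_config
    if config is None:
        config = predictor.model.generation_config  # fall back to the model default
    config.save_pretrained(output_path)

# Stand-in objects for illustration only; save_pretrained just prints here.
model = SimpleNamespace(generation_config=SimpleNamespace(save_pretrained=print))
predictor = SimpleNamespace(generation_config=None, model=model)
save_generation_config(predictor, "./exported")  # prints "./exported"
```
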

410 changes: 162 additions & 248 deletions llm/predict/predictor.py

Large diffs are not rendered by default.

41 changes: 3 additions & 38 deletions llm/utils/utils.py
@@ -34,7 +34,6 @@
AutoTokenizer,
ChatGLMv2Tokenizer,
LlamaForCausalLMPipe,
PretrainedConfig,
Qwen2ForCausalLMPipe,
)
from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer
@@ -738,45 +737,11 @@ def get_model_max_position_embeddings(config: PretrainedConfig) -> Optional[int]
return None


def get_default_max_decoding_length(config: PretrainedConfig, default: int = 1024) -> int:
"""get the default max decoding length from config.

Args:
config (PretrainedConfig): the instance of PretrainedConfig
default (int): the default value of max decoding length

Returns:
int: the default max_length of decoding length
"""
max_position_embeddings = get_model_max_position_embeddings(config)
if max_position_embeddings is None:
return default
return max_position_embeddings // 4


def get_default_max_encoding_length(config: PretrainedConfig, default: int = 1024) -> int:
"""get the default max encoding length from config.

Args:
config (PretrainedConfig): the instance of PretrainedConfig
default (int): the default value of max encoding length

Returns:
int: the default max_length of encoding length
"""

max_position_embeddings = get_model_max_position_embeddings(config)
if max_position_embeddings is None:
return default
return max_position_embeddings // 4 * 3


def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Queue):
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

paddle.device.set_device("cpu")
paddle.disable_static()
outputs = []
output_tensor = tensor_queue.get(timeout=1)

@@ -793,7 +758,7 @@ def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Q
output_numpy = output_tensor[2 : bsz + 2].numpy()
output_numpy[output_numpy == -1] = 2
outputs.append(output_numpy)
if output_tensor[0, 0] == -1:
if int(output_tensor[0, 0]) == -1:
break
output = np.concatenate(outputs, axis=1).tolist()
seqs = tokenizer.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)
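
Note: a small sketch (with NumPy standing in for the Paddle tensor read from the queue) of why the stop-flag comparison above is wrapped in int(): indexing yields a 0-d tensor rather than a plain Python int, and casting first makes the termination check an ordinary scalar comparison.

```python
import numpy as np

# Stand-in for the tensor read from the inter-process queue; the stop flag
# lives at position [0, 0] and is set to -1 once generation ends.
output_tensor = np.full((4, 2), -1, dtype=np.int64)

flag = output_tensor[0, 0]          # a 0-d value, not a plain Python int
print(type(flag))                   # numpy.int64 here, a Tensor in Paddle

if int(output_tensor[0, 0]) == -1:  # cast first, then compare as a scalar
    print("generation finished; stop reading results")
```
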
26 changes: 20 additions & 6 deletions paddlenlp/experimental/model_utils.py
@@ -391,7 +391,12 @@

class CacheScaleLoader:
def __init__(
self, scale_json_file_path="cache_scales.json", key_map_dict=None, num_of_layers=None, num_heads=None
self,
scale_json_file_path="cache_scales.json",
key_map_dict=None,
num_of_layers=None,
num_heads=None,
num_key_value_heads=None,
):
with open(scale_json_file_path) as json_file:
self.scale_dict = json.load(json_file)
@@ -402,12 +407,21 @@
scale_type_out = "cache_k_out_scale"
else:
scale_type_out = "cache_v_out_scale"
self.scale[scale_type] = np.full([num_of_layers, num_heads], fill_value=-1.0)
self.scale[scale_type_out] = np.full([num_of_layers, num_heads], fill_value=-1.0)
self.scale[scale_type] = np.full([num_of_layers, num_key_value_heads], fill_value=-1.0)
self.scale[scale_type_out] = np.full([num_of_layers, num_key_value_heads], fill_value=-1.0)

for i in range(num_of_layers):
if key_template.replace("#", str(i)) in self.scale_dict.keys():
self.scale[scale_type][i, :] = [
127.0 / num for num in self.scale_dict[key_template.replace("#", str(i))]
if num_heads != num_key_value_heads:
self.scale[scale_type][i, :] = [
127.0 / self.scale_dict[key_template.replace("#", str(i))][j]
for j in range(0, num_heads, num_heads // num_key_value_heads)
]
else:
self.scale[scale_type][i, :] = [
127.0 / self.scale_dict[key_template.replace("#", str(i))][j]
for j in range(0, num_key_value_heads)
]
self.scale[scale_type_out][i, :] = [
1.0 / self.scale[scale_type][i, j] for j in range(0, num_key_value_heads)
]
self.scale[scale_type_out][i, :] = [1.0 / self.scale[scale_type][i, j] for j in range(num_heads)]
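
Note: a minimal sketch of the head-selection logic above, assuming the JSON file stores one scale per attention head. Under GQA only one representative scale per key/value group is kept, sampled with stride num_heads // num_key_value_heads; otherwise the scales are read per key/value head directly.

```python
import numpy as np

num_heads, num_key_value_heads = 8, 2
per_head_scales = [float(s + 1) for s in range(num_heads)]  # stand-in JSON values

if num_heads != num_key_value_heads:
    # GQA: keep one representative scale per KV group.
    stride = num_heads // num_key_value_heads
    cache_k_scale = np.array(
        [127.0 / per_head_scales[j] for j in range(0, num_heads, stride)]
    )
else:
    # MHA: one scale per head, read directly.
    cache_k_scale = np.array(
        [127.0 / per_head_scales[j] for j in range(num_key_value_heads)]
    )

cache_k_out_scale = 1.0 / cache_k_scale  # both arrays have num_key_value_heads entries
print(cache_k_scale.shape)  # (2,)
```
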
43 changes: 21 additions & 22 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -349,6 +349,7 @@
self.quant_type = config.quant_type

self.rope_theta = config.rope_theta
self.use_neox = True

self.use_weight_only = False
if config.quant_type == "weight_only_int8":
@@ -562,7 +563,7 @@
cache_v_out_scale_attrs=cache_v_out_scale_attrs,
epsilon=self.epsilon,
norm_type="rmsnorm",
use_neox_rotary_style=True,
use_neox_rotary_style=self.use_neox,
cachekv_int8_type=config.cachekv_int8_type,
rank_id=config.tensor_parallel_rank,
trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
@@ -683,7 +684,7 @@
from paddlenlp_ops import fused_get_rotary_embedding

new_rope = fused_get_rotary_embedding(
input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, True
input_ids, position_ids, self.head_dim_shape_tensor, position_offset, self.rope_theta, self.use_neox
)

with dy2st_nocheck_guard_context():
@@ -800,11 +801,9 @@
concated_ffn1_weight = np.concatenate(
[unfused_state_dict["mlp.gate_proj.weight"], unfused_state_dict["mlp.up_proj.weight"]], axis=-1
)
ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight)

qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight).cast(paddle.get_default_dtype())

Contributor: Why cast to the default dtype?

Collaborator (author): So the same code path can run in both bf16 and fp16.

if self.use_weight_only:
qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
qkv_weight_tensor = paddle.transpose(qkv_weight_tensor, perm=[1, 0])
qkv_quanted_weight_tensor, qkv_weight_scale_tensor = weight_quantize(
qkv_weight_tensor, algo=self.quant_algo
@@ -816,11 +815,11 @@
paddle.cast(paddle.to_tensor(concated_qkv_weight), "int8")
)
else:
self.transformer_block.qkv_weights[idx].set_value(
qkv_weight_tensor.cast(self.transformer_block.qkv_weights[idx].dtype)
)
self.transformer_block.qkv_weights[idx].set_value(qkv_weight_tensor)

linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)])
linear_weight_tensor = paddle.to_tensor(
state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)]
).cast(paddle.get_default_dtype())
Contributor: Same question as above.

Collaborator (author): Same answer as above.

if self.use_weight_only:
linear_quanted_weight_tensor, linear_weight_scale_tensor = weight_quantize(
linear_weight_tensor, algo=self.quant_algo
@@ -844,10 +843,9 @@
)
)
else:
self.transformer_block.linear_weights[idx].set_value(
linear_weight_tensor.cast(self.transformer_block.linear_weights[idx].dtype)
)
self.transformer_block.linear_weights[idx].set_value(linear_weight_tensor)

ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight).cast(paddle.get_default_dtype())

if self.use_weight_only:
ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
ffn1_weight_tensor, algo=self.quant_algo
@@ -864,11 +862,11 @@
paddle.cast(paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)), "int8")
)
else:
self.transformer_block.ffn1_weights[idx].set_value(
ffn1_weight_tensor.cast(self.transformer_block.ffn1_weights[idx].dtype)
)
self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight_tensor)

ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)])
ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)]).cast(
paddle.get_default_dtype()
)
if self.use_weight_only:
ffn2_quanted_weight_tensor, ffn2_weight_scale_tensor = weight_quantize(
ffn2_weight_tensor, algo=self.quant_algo
@@ -892,9 +890,7 @@
)
)
else:
self.transformer_block.ffn2_weights[idx].set_value(
ffn2_weight_tensor.cast(self.transformer_block.ffn2_weights[idx].dtype)
)
self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight_tensor)

if "a8w8" in self.quant_type:
if self.shift_smooth_all_linears:
@@ -1009,7 +1005,7 @@
)

if self.config.cachekv_int8_type == "static":
cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_act_scales.json")
cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_scales.json")
if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
cache_scale_json_path = os.path.join(
self.quant_model_path, f"cachekv_act_scales_{self.config.tensor_parallel_rank}.json"
@@ -1018,7 +1014,8 @@
cache_scale_json_path,
cache_scale_map_dict,
num_of_layers=self.config.num_hidden_layers,
num_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
)
for k, v in cache_scales_loader.scale.items():
for i_layer, weight_scale in enumerate(v):
@@ -1400,7 +1397,9 @@
@paddle.no_grad()
def set_state_dict(self, state_dict):
if "lm_head.weight" in state_dict:
self.lm_head.weight.set_value(state_dict["lm_head.weight"])
self.lm_head.weight.set_value(
paddle.to_tensor(state_dict["lm_head.weight"]).cast(self.lm_head.weight.dtype)
)
self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
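
Note: a hedged sketch of the pattern questioned in the review threads above. The checkpoint weight may be stored in fp32 while the predictor runs in fp16 or bf16, so tensors are cast to paddle.get_default_dtype() before set_value, letting one loading path serve both precisions.

```python
import numpy as np
import paddle

# Stand-in checkpoint weight; real checkpoints may arrive in a dtype that
# differs from the runtime precision.
checkpoint_weight = np.random.rand(8, 8).astype("float64")

layer = paddle.nn.Linear(8, 8)  # parameters live in the current default dtype

# Casting to the default dtype keeps set_value consistent whether the run is
# configured for float32, float16, or bfloat16.
weight_tensor = paddle.to_tensor(checkpoint_weight).cast(paddle.get_default_dtype())
layer.weight.set_value(weight_tensor)
```
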


4 changes: 2 additions & 2 deletions paddlenlp/quantization/quantization_config.py
@@ -64,9 +64,9 @@ def __init__(
raise ValueError(
f"weight_quantize_algo:{weight_quantize_algo} not in supported list ['weight_only_int8', 'weight_only_int4', 'llm.int8', 'a8w8', 'nf4', 'fp4']"
)
if quant_type is not None and quant_type not in ["weight_only_int8", "weight_only_int4", "a8w8"]:
if quant_type is not None and quant_type not in ["weight_only_int8", "weight_only_int4", "a8w8", "a8w8c8"]:
raise ValueError(
f"quant_type:{quant_type} not in supported list ['weight_only_int8', 'weight_only_int4', 'a8w8']"
f"quant_type:{quant_type} not in supported list ['weight_only_int8', 'weight_only_int4', 'a8w8', 'a8w8c8']"
)
self.weight_quantize_algo = weight_quantize_algo
self.quant_type = quant_type
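
Note: a hedged usage sketch of the change above, assuming the class exported by this module is QuantizationConfig and that its other arguments default to None: with "a8w8c8" added to the allowed values, requesting int8 activations, weights, and cache KV no longer raises the unsupported-quant_type ValueError.

```python
from paddlenlp.quantization.quantization_config import QuantizationConfig

# "a8w8c8" (int8 activations, weights, and cache KV) is now part of the
# accepted quant_type list, so this no longer raises a ValueError.
cfg = QuantizationConfig(quant_type="a8w8c8")
print(cfg.quant_type)
```
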
1 change: 1 addition & 0 deletions tests/fixtures/llm/finetune.yaml
@@ -55,6 +55,7 @@ inference-predict:
inference-to-static:
default:
dtype: float16
max_length: 20

inference-infer:
default:
1 change: 1 addition & 0 deletions tests/fixtures/llm/lora.yaml
@@ -101,6 +101,7 @@ inference-predict:
inference-to-static:
default:
dtype: float16
max_length: 20

inference-infer:
default:
1 change: 1 addition & 0 deletions tests/fixtures/llm/predictor.yaml
@@ -11,6 +11,7 @@ inference-predict:
inference-to-static:
default:
dtype: float16
max_length: 40

inference-infer:
default:
1 change: 1 addition & 0 deletions tests/fixtures/llm/prefix_tuning.yaml
@@ -58,6 +58,7 @@ inference-to-static:
default:
dtype: float16
export_precache: true
max_length: 20

inference-infer:
default:
4 changes: 4 additions & 0 deletions tests/fixtures/llm/pretrain.yaml
@@ -34,6 +34,7 @@ pretrain:
inference-predict:
default:
mode: dynamic
src_length: 512
max_length: 20
batch_size: 2
decode_strategy: greedy_search
@@ -43,12 +44,15 @@ inference-predict:
inference-to-static:
default:
dtype: float16
src_length: 512
max_length: 20

inference-infer:
default:
mode: static
dtype: float16
batch_size: 2
decode_strategy: greedy_search
src_length: 512
max_length: 20
chat_template: none
1 change: 1 addition & 0 deletions tests/fixtures/llm/ptq.yaml
@@ -34,6 +34,7 @@ inference-predict:
inference-to-static:
default:
dtype: float16
max_length: 20


inference-infer:
1 change: 1 addition & 0 deletions tests/fixtures/llm/vera.yaml
@@ -56,6 +56,7 @@ inference-predict:
inference-to-static:
default:
dtype: float16
max_length: 20

inference-infer:
default: