add qwen inference model in static graph
DanGuge committed Dec 20, 2023
commit 87caabd (1 parent: 17793d0)
Showing 1 changed file with 10 additions and 2 deletions.
llm/predictor.py: 10 additions & 2 deletions
@@ -871,7 +871,7 @@ def create_predictor(
                 )
                 model.eval()
             else:
-                raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")
+                raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt, qwen]")
             predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
         elif predictor_args.mode == "static":
             config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
@@ -915,8 +915,16 @@ def create_predictor(
                 cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape(
                     config, predictor_args.batch_size, predictor_args.total_max_length
                 )
+            elif "qwen" in config.architectures[0].lower():
+                from paddlenlp.experimental.transformers import (
+                    QWenForCausalLMInferenceModel,
+                )
+
+                cache_kvs_shape = QWenForCausalLMInferenceModel.get_cache_kvs_shape(
+                    config, predictor_args.batch_size, predictor_args.total_max_length
+                )
             else:
-                raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")
+                raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt, qwen]")
             predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
         else:
             raise ValueError("the `mode` should be one of [dynamic, static]")
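The new branch mirrors the existing GPT handling: the static path inspects config.architectures[0], imports the matching *ForCausalLMInferenceModel class from paddlenlp.experimental.transformers, and asks it for the key/value cache shape that the static graph pre-allocates. Below is a minimal standalone sketch of that dispatch; the helper name get_static_cache_kvs_shape and the AutoConfig import path are assumptions for illustration, not part of this commit, and the llama/chatglm/bloom branches are omitted for brevity.

from paddlenlp.transformers import AutoConfig


def get_static_cache_kvs_shape(model_name_or_path: str, batch_size: int, total_max_length: int):
    # Mirror of the architecture dispatch in create_predictor's static branch
    # (hypothetical helper; llama/chatglm/bloom branches omitted).
    config = AutoConfig.from_pretrained(model_name_or_path)
    arch = config.architectures[0].lower()

    if "gpt" in arch:
        from paddlenlp.experimental.transformers import GPTForCausalLMInferenceModel as model_cls
    elif "qwen" in arch:
        # The branch added by this commit: Qwen exposes the same get_cache_kvs_shape interface.
        from paddlenlp.experimental.transformers import QWenForCausalLMInferenceModel as model_cls
    else:
        raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt, qwen]")

    # Shape of the pre-allocated key/value cache tensors used by the static graph.
    return model_cls.get_cache_kvs_shape(config, batch_size, total_max_length)

Because the dispatch is a simple substring match on the architecture name, adding the qwen branch alongside the existing ones requires no changes elsewhere in create_predictor beyond updating the error message.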

