[Paddle Inference] Support miniGPT4's second part dy2st #6905

Merged: 11 commits, Sep 7, 2023
29 changes: 20 additions & 9 deletions llm/predictor.py
@@ -74,11 +74,16 @@ class PredictorArgument:
inference_model: bool = field(default=False, metadata={"help": "whether use InferenceModel to do generation"})
batch_size: int = field(default=1, metadata={"help": "The batch size of data."})
max_batch_size: int = field(default=None, metadata={"help": "The max batch size of data during serving."})
benchmark: bool = field(
default=False,
metadata={
"help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. "
},
benchmark: bool = (
field(
default=False,
metadata={
"help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. "
},
),
)
llm_for_img2txt: bool = field(
default=False, metadata={"help": "whether this llm model is used for img2txt, such as miniGPT4, blip2."}
)


@@ -571,13 +576,19 @@ def create_predictor(
# TODO(wj-Mcat): complete AutoInferenceModel & AutoPredictor
config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
if "llama" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
LlamaForCausalLMInferenceModel,
)
if predictor_args.llm_for_img2txt:
# we use llama for img2txt.
from paddlenlp.experimental.transformers import (
LlamaForminiGPT4InferenceModel as LlamaInferenceModel,
Contributor:

Suggested change:
-            LlamaForminiGPT4InferenceModel as LlamaInferenceModel,
+            LlamaForMiniGPT4InferenceModel as LlamaInferenceModel,
)
else:
from paddlenlp.experimental.transformers import (
LlamaForCausalLMInferenceModel as LlamaInferenceModel,
)

config.tensor_parallel_degree = tensor_parallel_degree
config.tensor_parallel_rank = tensor_parallel_rank
model = LlamaForCausalLMInferenceModel.from_pretrained(
model = LlamaInferenceModel.from_pretrained(
predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
)
model.eval()
33 changes: 31 additions & 2 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -85,6 +85,17 @@
model, output_path, skip_prune_program=True
) # Note(Zhengzekang): If we prune program it may cause some inference error.

@staticmethod
def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
batch_size = 1
seq_len = 1
if bos_token_id is None:
raise ValueError("`bos_token_id` should be defined when no " "`input_ids` are provided.")
if encoder_output is not None:
batch_size = encoder_output.shape[0]
seq_len = encoder_output.shape[1]
return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id
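A short worked example of the helper above; the shapes and bos_token_id value are placeholders, not taken from this PR.

import paddle

# Mirrors prepare_input_ids_for_generation: a [2, 32, 4096] encoder_output with
# bos_token_id = 1 yields a [2, 32] int64 tensor filled with 1, used as
# placeholder input_ids when generation starts from inputs_embeds only.
encoder_output = paddle.zeros([2, 32, 4096], dtype="float32")
bos_token_id = 1
fake_input_ids = paddle.ones([encoder_output.shape[0], encoder_output.shape[1]], dtype="int64") * bos_token_id
print(fake_input_ids.shape)  # [2, 32]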

@paddle.no_grad()
def generate(
self,
@@ -109,6 +120,7 @@
pre_ids=None,
stop_nums=None,
cache_kvs=[],
inputs_embeds=None,
**model_kwargs,
):

@@ -136,6 +148,7 @@
top_p=top_p,
cache_kvs=cache_kvs,
temperature=temperature,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
return ret
@@ -213,17 +226,32 @@

def sample(
self,
input_ids,
eos_token_id,
input_ids=None,
eos_token_id=None,
cache_kvs=[],
top_p=None,
temperature=None,
inputs_embeds=None,
**model_kwargs,
):
step_idx_ori = paddle.full(shape=[1], dtype="int64", fill_value=1)
batch_idx = paddle.full(shape=[1], dtype="int32", fill_value=-1)

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")

# generate fake input_ids from inputs_embeds.
if input_ids is None and inputs_embeds is not None:
input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds)
if inputs_embeds is not None:
batch, seq_len, hidden_dim = inputs_embeds.shape
inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
model_kwargs["inputs_embeds"] = inputs_embeds

Contributor:

This logic should be moved into the model's forward method rather than kept in generation_utils; see https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/llama/modeling.py#L1189 for reference.

There is currently no corresponding check under experimental/transformers/llama/modeling.py, so please move this part of the code there. Many thanks.

Contributor Author:

Done, please review.
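For reference, the pattern the reviewer points to looks roughly like the simplified sketch below; it is modeled on the referenced modeling.py and is illustrative only, not the exact code added in this PR.

# Simplified sketch: resolve input_ids vs. inputs_embeds inside the model's
# forward instead of inside generation_utils (illustrative only).
def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")
    # Derive the embeddings here when only token ids were provided.
    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # ... remaining transformer forward pass ...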


def _forward_(**args):
# cache_kvs is never empty because it is passed as a parameter to sample().
model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args)
return self(**model_inputs)

@@ -294,6 +322,7 @@
)
step_idx_ori += 1
encoder_output = outputs
# assigning it a value means we have entered the decoder phase.
model_kwargs["cache"] = 0

# decoder
120 changes: 119 additions & 1 deletion paddlenlp/experimental/transformers/llama/modeling.py
@@ -33,7 +33,7 @@
)
from paddlenlp.transformers.model_utils import register_base_model

__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel"]
__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel", "LlamaForminiGPT4InferenceModel"]


class FusedLlamaRMSNorm(nn.Layer):
@@ -165,6 +165,7 @@
return_dict=False,
**kwargs,
):
# kwargs["cache"] is used used to distinguish between encoder and decoder phase.
past_key_values = kwargs.get("cache", None)
is_decoder = past_key_values is not None

@@ -345,14 +346,19 @@
position_ids = kwargs.get("position_ids", None)
attention_mask = kwargs.get("attention_mask", None)
cache = kwargs.get("cache", None)
inputs_embeds = kwargs.get("inputs_embeds", None)

if cache is not None:
input_ids = tgt_ids
position_ids = tgt_pos
attention_mask = (tgt_generation_mask - 1) * 1e4
# set inputs_embeds to None in the decoder phase;
# in the forward function it will be derived from input_ids.
inputs_embeds = None

else:
attention_mask = (attention_mask - 1) * 1e4
model_inputs = {
"input_ids": input_ids,
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"attention_mask": attention_mask,
"cache_kvs": cache_kvs,
@@ -432,3 +438,115 @@
if "lm_head.weight" in state_dict:
self.lm_head.weight.set_value(state_dict["lm_head.weight"])
self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})


class LlamaForminiGPT4InferenceModel(LlamaForCausalLMInferenceModel):

"""
This class is almost identical to LlamaForCausalLMInferenceModel.
It is used only for miniGPT4's second part.
"""

# This function implements miniGPT4's second part and is used only for miniGPT4.
@paddle.no_grad()
def generate_text_with_image_features(

self,
image_features: paddle.Tensor,
first_input_ids: paddle.Tensor,
second_input_ids: paddle.Tensor,
attention_mask: paddle.Tensor,
position_ids=None,
penalty_score=None,
frequency_score=None,
presence_score=None,
min_length=None,
max_length=None,
temperature=None,
top_p=None,
eos_token_id=None,
seq_len_encoder=None,
seq_len_decoder=None,
step_idx=None,
stop_flags=None,
tgt_ids=None,
tgt_pos=None,
tgt_generation_mask=None,
pre_ids=None,
stop_nums=None,
cache_kvs=[],
inputs_embeds=None,
**generate_kwargs
) -> paddle.Tensor:

first_embeds = self.llama.embed_tokens(first_input_ids)
second_embeds = self.llama.embed_tokens(second_input_ids)
image_features = paddle.cast(image_features, dtype=first_embeds.dtype)
inputs_embeds = paddle.concat([first_embeds, image_features, second_embeds], axis=1)

outputs = self.generate(

inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
penalty_score=penalty_score,
frequency_score=frequency_score,
presence_score=presence_score,
min_length=min_length,
max_length=max_length,
temperature=temperature,
top_p=top_p,
eos_token_id=eos_token_id,
seq_len_encoder=seq_len_encoder,
seq_len_decoder=seq_len_decoder,
step_idx=step_idx,
stop_flags=stop_flags,
tgt_ids=tgt_ids,
tgt_pos=tgt_pos,
tgt_generation_mask=tgt_generation_mask,
pre_ids=pre_ids,
stop_nums=stop_nums,
cache_kvs=cache_kvs,
)
return outputs

# override the to_static function defined in generation_utils.py
def to_static(self, output_path: str, config: dict):
dtype = config.get("dtype", paddle.get_default_dtype())
cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None))
input_spec = [

paddle.static.InputSpec(
shape=[None, None, None], dtype="float32", name="image_features"
), # image_features
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="first_input_ids"), # first_input_ids
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="second_input_ids"), # second_input_ids
paddle.static.InputSpec(shape=[None, None], dtype=dtype, name="attention_mask"), # attention_mask
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"), # position_ids
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"), # penalty_score
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"), # frequency_score
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"), # presence_score
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"), # min_decode_length
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"), # max_decode_length
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"), # temperature
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"), # top_p
paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"), # eos_token_id
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"), # seq_len_encoder
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"), # seq_len_decoder
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"), # step_idx
paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"), # stop_flags
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"), # tgt_ids
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"), # tgt_pos
paddle.static.InputSpec(
shape=[None, 1, 1, None], dtype=dtype, name="tgt_generation_mask"
), # tgt_generation_mask
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"), # pre_ids
paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"), # stop_nums
[
paddle.static.InputSpec(
shape=shape,
dtype=dtype,
name="cache_kvs_{}".format(i),
)
for i, shape in enumerate(cache_kvs_shapes)
], # cache_kvs
]

model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec)
paddle.jit.save(model, output_path)
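To close the loop, a minimal export sketch follows; the checkpoint path and config values are assumptions, not taken from this PR, while the class name, from_pretrained arguments, and to_static signature come from the diff above.

# Minimal export sketch (hypothetical paths/values): load the miniGPT4
# second-part LLM and convert generate_text_with_image_features to a static
# graph for Paddle Inference.
from paddlenlp.experimental.transformers import LlamaForminiGPT4InferenceModel
from paddlenlp.transformers import AutoConfig

config = AutoConfig.from_pretrained("path/to/minigpt4-llama")
model = LlamaForminiGPT4InferenceModel.from_pretrained(
    "path/to/minigpt4-llama", config=config, dtype="float16"
)
model.eval()
model.to_static("./inference/minigpt4_second_part", {"dtype": "float16", "max_length": 1024})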
