[Paddle Inference] Support dy2st for miniGPT4's second part (#6905)
* support miniGPT4

* remove some useless code

* move some functions to modeling.py

* commit

* 1->self.config.bos_token_id

* remove useless comment

* restore

* move prepare_input_ids_for_generation to modeling

* LlamaForMiniGPT4InferenceModel

* use model_type
zhoutianzi666 authored Sep 7, 2023
1 parent 579d1f9 commit 294df07
Showing 3 changed files with 184 additions and 12 deletions.
26 changes: 17 additions & 9 deletions llm/predictor.py
@@ -74,11 +74,13 @@ class PredictorArgument:
inference_model: bool = field(default=False, metadata={"help": "whether use InferenceModel to do generation"})
batch_size: int = field(default=1, metadata={"help": "The batch size of data."})
max_batch_size: int = field(default=None, metadata={"help": "The max batch size of data during serving."})
benchmark: bool = field(
default=False,
metadata={
"help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. "
},
benchmark: bool = (
field(
default=False,
metadata={
"help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. "
},
),
)
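For context, a minimal sketch of how PredictorArgument fields such as benchmark are typically consumed; it assumes PaddleNLP's PdArgumentParser (which llm/predictor.py already uses) and is not part of this diff:

from paddlenlp.trainer import PdArgumentParser

# Parse CLI flags such as --batch_size and --benchmark into the dataclass above.
parser = PdArgumentParser((PredictorArgument,))
(predictor_args,) = parser.parse_args_into_dataclasses()
print(predictor_args.batch_size, predictor_args.benchmark)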


@@ -573,13 +575,19 @@ def create_predictor(
# TODO(wj-Mcat): complete AutoInferenceModel & AutoPredictor
config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
if "llama" in config.architectures[0].lower():
-        from paddlenlp.experimental.transformers import (
-            LlamaForCausalLMInferenceModel,
-        )
+        if model_args.model_type == "llama-img2txt":
+            # we use llama for img2txt.
+            from paddlenlp.experimental.transformers import (
+                LlamaForMiniGPT4InferenceModel as LlamaInferenceModel,
+            )
+        else:
+            from paddlenlp.experimental.transformers import (
+                LlamaForCausalLMInferenceModel as LlamaInferenceModel,
+            )

config.tensor_parallel_degree = tensor_parallel_degree
config.tensor_parallel_rank = tensor_parallel_rank
-    model = LlamaForCausalLMInferenceModel.from_pretrained(
+    model = LlamaInferenceModel.from_pretrained(
predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
)
model.eval()
24 changes: 22 additions & 2 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -85,6 +85,17 @@ def to_static(self, output_path: str, config: dict):
model, output_path, skip_prune_program=True
) # Note(Zhengzekang): If we prune the program, it may cause inference errors.

@staticmethod
def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
batch_size = 1
seq_len = 1
if bos_token_id is None:
raise ValueError("`bos_token_id` should be defined when no `input_ids` are provided.")
if encoder_output is not None:
batch_size = encoder_output.shape[0]
seq_len = encoder_output.shape[1]
return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id
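A quick illustration of the helper above with hypothetical shapes (the enclosing class is assumed here to be GenerationInferenceModel; adjust to wherever the staticmethod lives): a [2, 5, 768] encoder output yields a [2, 5] int64 tensor filled with the BOS token id.

import paddle

from paddlenlp.experimental.transformers.generation_utils import (
    GenerationInferenceModel,  # assumed class name hosting the staticmethod
)

encoder_output = paddle.zeros([2, 5, 768])  # e.g. projected image features
fake_input_ids = GenerationInferenceModel.prepare_input_ids_for_generation(
    bos_token_id=1, encoder_output=encoder_output
)
print(fake_input_ids.shape)  # [2, 5]; every entry equals bos_token_id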

@paddle.no_grad()
def generate(
self,
@@ -109,6 +120,7 @@ def generate(
pre_ids=None,
stop_nums=None,
cache_kvs=[],
inputs_embeds=None,
**model_kwargs,
):

@@ -136,6 +148,7 @@
top_p=top_p,
cache_kvs=cache_kvs,
temperature=temperature,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
return ret
@@ -215,17 +228,23 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e

def sample(
self,
-        input_ids,
-        eos_token_id,
+        input_ids=None,
+        eos_token_id=None,
cache_kvs=[],
top_p=None,
temperature=None,
inputs_embeds=None,
**model_kwargs,
):
step_idx_ori = paddle.full(shape=[1], dtype="int64", fill_value=1)
batch_idx = paddle.full(shape=[1], dtype="int32", fill_value=-1)

# Put inputs_embeds into model_kwargs, because the code below passes
# model_kwargs along as a whole instead of using inputs_embeds directly.
model_kwargs["inputs_embeds"] = inputs_embeds

def _forward_(**args):
# cache_kvs is never empty because it is passed as a parameter to sample().
model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args)
return self(**model_inputs)

@@ -297,6 +316,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
)
step_idx_ori += 1
encoder_output = outputs
# Giving model_kwargs["cache"] a value means we have entered the decoder phase.
model_kwargs["cache"] = 0

# decoder
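A small sketch of the phase convention used in sample() (a simplification, not code from this diff): whether "cache" is present in model_kwargs is what forward() later uses to tell the encoder and decoder phases apart.

model_kwargs = {}

# Encoder / context phase: no "cache" key has been set yet.
assert model_kwargs.get("cache", None) is None

# After the first forward pass, any value (even 0) marks the decoder phase.
model_kwargs["cache"] = 0
assert model_kwargs.get("cache", None) is not None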
146 changes: 145 additions & 1 deletion paddlenlp/experimental/transformers/llama/modeling.py
Expand Up @@ -33,7 +33,7 @@
)
from paddlenlp.transformers.model_utils import register_base_model

__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel"]
__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel", "LlamaForMiniGPT4InferenceModel"]


class FusedLlamaRMSNorm(nn.Layer):
@@ -149,6 +161,18 @@ def remove_padding(self, input_ids, seq_lens_this_time):
)
return ids_remove_padding, padding_offset, cum_offsets

# This function differs slightly from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py.
@staticmethod
def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
batch_size = 1
seq_len = 1
if bos_token_id is None:
raise ValueError("`bos_token_id` should be defined when no `input_ids` are provided.")
if encoder_output is not None:
batch_size = encoder_output.shape[0]
seq_len = encoder_output.shape[1]
return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id

def forward(
self,
input_ids=None,
@@ -165,9 +177,24 @@ def forward(
return_dict=False,
**kwargs,
):
# kwargs["cache"] is used used to distinguish between encoder and decoder phase.
past_key_values = kwargs.get("cache", None)
is_decoder = past_key_values is not None

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")

# Generate fake input_ids from inputs_embeds; this typically happens in
# img2txt multimodal models when this forward function is first entered.
if input_ids is None and inputs_embeds is not None:
input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds)
if inputs_embeds is not None:
batch, seq_len, hidden_dim = inputs_embeds.shape
# merge the batch and seq_len dimensions.
inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
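To make the inputs_embeds flattening in forward() concrete, a standalone sketch with hypothetical sizes; the fused inference path works on a 2-D [tokens, hidden_dim] layout (compare remove_padding above), so the batch and sequence dimensions are merged:

import paddle

inputs_embeds = paddle.randn([2, 5, 4096])  # [batch, seq_len, hidden_dim]
batch, seq_len, hidden_dim = inputs_embeds.shape
inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
print(inputs_embeds.shape)  # [10, 4096]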
@@ -345,14 +372,19 @@ def prepare_inputs_for_generation(
position_ids = kwargs.get("position_ids", None)
attention_mask = kwargs.get("attention_mask", None)
cache = kwargs.get("cache", None)
inputs_embeds = kwargs.get("inputs_embeds", None)
if cache is not None:
input_ids = tgt_ids
position_ids = tgt_pos
attention_mask = (tgt_generation_mask - 1) * 1e4
# Set inputs_embeds to None in the decoder phase; in the forward
# function, the embeddings are then computed from input_ids.
inputs_embeds = None
else:
attention_mask = (attention_mask - 1) * 1e4
model_inputs = {
"input_ids": input_ids,
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"attention_mask": attention_mask,
"cache_kvs": cache_kvs,
@@ -432,3 +464,115 @@ def set_state_dict(self, state_dict):
if "lm_head.weight" in state_dict:
self.lm_head.weight.set_value(state_dict["lm_head.weight"])
self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})


class LlamaForMiniGPT4InferenceModel(LlamaForCausalLMInferenceModel):
"""
This class is almost identical to LlamaForCausalLMInferenceModel.
It is used only for miniGPT4's second part.
"""

# This function implements miniGPT4's second part and is only used by miniGPT4.
@paddle.no_grad()
def generate_text_with_image_features(
self,
image_features: paddle.Tensor,
first_input_ids: paddle.Tensor,
second_input_ids: paddle.Tensor,
attention_mask: paddle.Tensor,
position_ids=None,
penalty_score=None,
frequency_score=None,
presence_score=None,
min_length=None,
max_length=None,
temperature=None,
top_p=None,
eos_token_id=None,
seq_len_encoder=None,
seq_len_decoder=None,
step_idx=None,
stop_flags=None,
tgt_ids=None,
tgt_pos=None,
tgt_generation_mask=None,
pre_ids=None,
stop_nums=None,
cache_kvs=[],
inputs_embeds=None,
**generate_kwargs
) -> paddle.Tensor:

first_embeds = self.llama.embed_tokens(first_input_ids)
second_embeds = self.llama.embed_tokens(second_input_ids)
image_features = paddle.cast(image_features, dtype=first_embeds.dtype)
inputs_embeds = paddle.concat([first_embeds, image_features, second_embeds], axis=1)

outputs = self.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
penalty_score=penalty_score,
frequency_score=frequency_score,
presence_score=presence_score,
min_length=min_length,
max_length=max_length,
temperature=temperature,
top_p=top_p,
eos_token_id=eos_token_id,
seq_len_encoder=seq_len_encoder,
seq_len_decoder=seq_len_decoder,
step_idx=step_idx,
stop_flags=stop_flags,
tgt_ids=tgt_ids,
tgt_pos=tgt_pos,
tgt_generation_mask=tgt_generation_mask,
pre_ids=pre_ids,
stop_nums=stop_nums,
cache_kvs=cache_kvs,
)
return outputs
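The prompt layout generate_text_with_image_features assumes, sketched with hypothetical shapes (for instance 32 visual query tokens, as in common miniGPT4 setups): the text prompt is split into the parts before and after the image, and the three embedding segments are concatenated along the sequence axis.

import paddle

hidden = 4096                                   # hypothetical hidden size
first_embeds = paddle.randn([1, 4, hidden])     # embeddings of the prompt before the image
image_features = paddle.randn([1, 32, hidden])  # projected visual tokens
second_embeds = paddle.randn([1, 7, hidden])    # embeddings of the prompt after the image
inputs_embeds = paddle.concat([first_embeds, image_features, second_embeds], axis=1)
print(inputs_embeds.shape)  # [1, 43, hidden]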

# Overrides the to_static function from generation_utils.py.
def to_static(self, output_path: str, config: dict):
dtype = config.get("dtype", paddle.get_default_dtype())
cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None))
input_spec = [
paddle.static.InputSpec(
shape=[None, None, None], dtype="float32", name="image_features"
), # image_features
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="first_input_ids"), # first_input_ids
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="second_input_ids"), # second_input_ids
paddle.static.InputSpec(shape=[None, None], dtype=dtype, name="attention_mask"), # attention_mask
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"), # position_ids
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"), # penalty_score
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"), # frequency_score
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"), # presence_score
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"), # min_decode_length
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"), # max_decode_length
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"), # temperature
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"), # top_p
paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"), # eos_token_id
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"), # seq_len_encoder
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"), # seq_len_decoder
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"), # step_idx
paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"), # stop_flags
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"), # tgt_ids
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"), # tgt_pos
paddle.static.InputSpec(
shape=[None, 1, 1, None], dtype=dtype, name="tgt_generation_mask"
), # tgt_generation_mask
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"), # pre_ids
paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"), # stop_nums
[
paddle.static.InputSpec(
shape=shape,
dtype=dtype,
name="cache_kvs_{}".format(i),
)
for i, shape in enumerate(cache_kvs_shapes)
], # cache_kvs
]

model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec)
paddle.jit.save(model, output_path)
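After export, the static program can be loaded back through the standard paddle.jit API; a minimal sketch (the path prefix is hypothetical and should match output_path above):

import paddle

model = paddle.jit.load("./minigpt4_second_part")  # hypothetical output_path
# The loaded callable expects the 23 inputs declared in input_spec above,
# beginning with image_features, first_input_ids and second_input_ids.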
