[Paddle Inference] Support miniGPT4's second part dy2st #6905

Merged: 11 commits, Sep 7, 2023
llm/predictor.py (17 additions, 9 deletions)
@@ -74,11 +74,13 @@ class PredictorArgument:
inference_model: bool = field(default=False, metadata={"help": "whether use InferenceModel to do generation"})
batch_size: int = field(default=1, metadata={"help": "The batch size of data."})
max_batch_size: int = field(default=None, metadata={"help": "The max batch size of data during serving."})
benchmark: bool = field(
default=False,
metadata={
"help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. "
},
benchmark: bool = (
field(
default=False,
metadata={
"help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. "
},
),
)


@@ -573,13 +575,19 @@ def create_predictor(
# TODO(wj-Mcat): complete AutoInferenceModel & AutoPredictor
config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
if "llama" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
LlamaForCausalLMInferenceModel,
)
if model_args.model_type == "llama-img2txt":
# we use llama for img2txt.
from paddlenlp.experimental.transformers import (
LlamaForMiniGPT4InferenceModel as LlamaInferenceModel,
)
else:
from paddlenlp.experimental.transformers import (
LlamaForCausalLMInferenceModel as LlamaInferenceModel,
)

config.tensor_parallel_degree = tensor_parallel_degree
config.tensor_parallel_rank = tensor_parallel_rank
model = LlamaForCausalLMInferenceModel.from_pretrained(
model = LlamaInferenceModel.from_pretrained(
predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
)
model.eval()
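For context, a hedged sketch of how this branch might be exercised from Python; the argument dataclasses and the checkpoint path below are assumptions based on this diff, not a verified entry point. Passing model_type="llama-img2txt" routes to LlamaForMiniGPT4InferenceModel, while the plain llama path keeps LlamaForCausalLMInferenceModel.

from predictor import ModelArgument, PredictorArgument, create_predictor  # assumed: llm/predictor.py on sys.path

predictor_args = PredictorArgument(
    model_name_or_path="./checkpoints/minigpt4-llama",  # hypothetical local checkpoint
    inference_model=True,
    dtype="float16",
)
model_args = ModelArgument(model_type="llama-img2txt")  # assumed field; selects LlamaForMiniGPT4InferenceModel
predictor = create_predictor(predictor_args, model_args)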
paddlenlp/experimental/transformers/generation_utils.py (22 additions, 2 deletions)
@@ -85,6 +85,17 @@
model, output_path, skip_prune_program=True
) # Note(Zhengzekang): If we prune the program, it may cause inference errors.

@staticmethod
def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
batch_size = 1
seq_len = 1
if bos_token_id is None:
raise ValueError("`bos_token_id` should be defined when no " "`input_ids` are provided.")
if encoder_output is not None:
batch_size = encoder_output.shape[0]
seq_len = encoder_output.shape[1]
return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id

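For illustration, a small standalone copy of the helper above with assumed shapes and BOS id, showing the fake input_ids it produces when generation starts from inputs_embeds rather than from token ids.

import paddle

def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
    # Standalone copy of the static method above, for illustration only.
    batch_size, seq_len = 1, 1
    if bos_token_id is None:
        raise ValueError("`bos_token_id` should be defined when no `input_ids` are provided.")
    if encoder_output is not None:
        batch_size, seq_len = encoder_output.shape[0], encoder_output.shape[1]
    return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id

inputs_embeds = paddle.zeros([2, 5, 4096], dtype="float32")  # assumed [batch, seq_len, hidden]
fake_ids = prepare_input_ids_for_generation(bos_token_id=1, encoder_output=inputs_embeds)
print(fake_ids.shape)  # [2, 5]; every entry equals the BOS id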

@paddle.no_grad()
def generate(
self,
Expand All @@ -109,6 +120,7 @@
pre_ids=None,
stop_nums=None,
cache_kvs=[],
inputs_embeds=None,
**model_kwargs,
):

@@ -136,6 +148,7 @@
top_p=top_p,
cache_kvs=cache_kvs,
temperature=temperature,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
return ret
@@ -215,17 +228,23 @@

def sample(
self,
input_ids,
eos_token_id,
input_ids=None,
eos_token_id=None,
cache_kvs=[],
top_p=None,
temperature=None,
inputs_embeds=None,
**model_kwargs,
):
step_idx_ori = paddle.full(shape=[1], dtype="int64", fill_value=1)
batch_idx = paddle.full(shape=[1], dtype="int32", fill_value=-1)

# Put inputs_embeds into model_kwargs, because the code below passes
# model_kwargs through as a whole instead of using inputs_embeds directly.
model_kwargs["inputs_embeds"] = inputs_embeds

def _forward_(**args):
# cache_kvs is never empty because it is passed as a parameter to sample().
model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args)
return self(**model_inputs)

@@ -297,6 +316,7 @@
)
step_idx_ori += 1
encoder_output = outputs
# Assigning a value here means we have entered the decoder phase.
model_kwargs["cache"] = 0

# decoder
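A minimal, simplified sketch (plain Python, not the real sampling loop) of the control flow these additions create: inputs_embeds travels inside model_kwargs, is consumed on the first (encoder) step, and model_kwargs["cache"] = 0 switches every later step to the decoder path, where embeddings are rebuilt from input_ids.

def run_generation(inputs_embeds, max_steps=3):
    # Simplified stand-in for sample(): track which phase each _forward_ call would take.
    model_kwargs = {"inputs_embeds": inputs_embeds, "cache": None}
    phases = []
    for _ in range(max_steps):
        phase = "decoder" if model_kwargs["cache"] is not None else "encoder"
        phases.append(phase)
        model_kwargs["cache"] = 0             # after the first step we are in the decoder phase
        model_kwargs["inputs_embeds"] = None  # decoder steps derive embeddings from input_ids
    return phases

print(run_generation(inputs_embeds="[text+image embeddings]"))  # ['encoder', 'decoder', 'decoder']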
paddlenlp/experimental/transformers/llama/modeling.py (145 additions, 1 deletion)
@@ -33,7 +33,7 @@
)
from paddlenlp.transformers.model_utils import register_base_model

__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel"]
__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel", "LlamaForMiniGPT4InferenceModel"]


class FusedLlamaRMSNorm(nn.Layer):
@@ -149,6 +149,18 @@
)
return ids_remove_padding, padding_offset, cum_offsets

# This function is a little different from prepare_input_ids_for_generation in paddlenlp/transformers/generation/utils.py
@staticmethod
def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
batch_size = 1
seq_len = 1
if bos_token_id is None:
raise ValueError("`bos_token_id` should be defined when no " "`input_ids` are provided.")
if encoder_output is not None:
batch_size = encoder_output.shape[0]
seq_len = encoder_output.shape[1]
return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id


def forward(
self,
input_ids=None,
Expand All @@ -165,9 +177,24 @@
return_dict=False,
**kwargs,
):
# kwargs["cache"] is used to distinguish between the encoder and decoder phases.
past_key_values = kwargs.get("cache", None)
is_decoder = past_key_values is not None

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")

# Generate fake input_ids from inputs_embeds.
# This usually happens in img2txt multimodal models on the first call to this forward function.
if input_ids is None and inputs_embeds is not None:
input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds)
if inputs_embeds is not None:
batch, seq_len, hidden_dim = inputs_embeds.shape

# Merge the batch and seq_len dimensions.
inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
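To make the embeddings path above concrete, a short standalone sketch with assumed shapes and BOS id: when only inputs_embeds is given, forward() fabricates a matching [batch, seq_len] input_ids filled with the BOS token and flattens the embeddings to [batch * seq_len, hidden], matching the flattened layout used by the rest of the inference path.

import paddle

bos_token_id = 1                                             # assumed BOS id
inputs_embeds = paddle.randn([2, 6, 4096], dtype="float32")  # assumed [batch, seq_len, hidden]

# Fake ids keep the downstream code that only inspects input_ids working.
batch, seq_len, hidden_dim = inputs_embeds.shape
fake_input_ids = paddle.ones([batch, seq_len], dtype="int64") * bos_token_id

# Flatten batch and sequence, mirroring the reshape in forward().
flat_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
print(fake_input_ids.shape, flat_embeds.shape)  # [2, 6] and [12, 4096]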
@@ -345,14 +372,19 @@
position_ids = kwargs.get("position_ids", None)
attention_mask = kwargs.get("attention_mask", None)
cache = kwargs.get("cache", None)
inputs_embeds = kwargs.get("inputs_embeds", None)

if cache is not None:
input_ids = tgt_ids
position_ids = tgt_pos
attention_mask = (tgt_generation_mask - 1) * 1e4
# Set inputs_embeds to None in the decoder phase;
# in the forward function it will then be derived from input_ids.
inputs_embeds = None
else:
attention_mask = (attention_mask - 1) * 1e4
model_inputs = {
"input_ids": input_ids,
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"attention_mask": attention_mask,
"cache_kvs": cache_kvs,
@@ -432,3 +464,115 @@
if "lm_head.weight" in state_dict:
self.lm_head.weight.set_value(state_dict["lm_head.weight"])
self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})


class LlamaForMiniGPT4InferenceModel(LlamaForCausalLMInferenceModel):

"""
This class is almost identical to LlamaForCausalLMInferenceModel.
It is used only for miniGPT4's second part.
"""

# This function implements miniGPT4's second part and is only used by miniGPT4.
@paddle.no_grad()
def generate_text_with_image_features(

self,
image_features: paddle.Tensor,
first_input_ids: paddle.Tensor,
second_input_ids: paddle.Tensor,
attention_mask: paddle.Tensor,
position_ids=None,
penalty_score=None,
frequency_score=None,
presence_score=None,
min_length=None,
max_length=None,
temperature=None,
top_p=None,
eos_token_id=None,
seq_len_encoder=None,
seq_len_decoder=None,
step_idx=None,
stop_flags=None,
tgt_ids=None,
tgt_pos=None,
tgt_generation_mask=None,
pre_ids=None,
stop_nums=None,
cache_kvs=[],
inputs_embeds=None,
**generate_kwargs
) -> paddle.Tensor:

first_embeds = self.llama.embed_tokens(first_input_ids)
second_embeds = self.llama.embed_tokens(second_input_ids)
image_features = paddle.cast(image_features, dtype=first_embeds.dtype)
inputs_embeds = paddle.concat([first_embeds, image_features, second_embeds], axis=1)

outputs = self.generate(

inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
penalty_score=penalty_score,
frequency_score=frequency_score,
presence_score=presence_score,
min_length=min_length,
max_length=max_length,
temperature=temperature,
top_p=top_p,
eos_token_id=eos_token_id,
seq_len_encoder=seq_len_encoder,
seq_len_decoder=seq_len_decoder,
step_idx=step_idx,
stop_flags=stop_flags,
tgt_ids=tgt_ids,
tgt_pos=tgt_pos,
tgt_generation_mask=tgt_generation_mask,
pre_ids=pre_ids,
stop_nums=stop_nums,
cache_kvs=cache_kvs,
)
return outputs

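To show where the three tensor inputs of generate_text_with_image_features come from, a hedged caller-side sketch; the vocab size, hidden size, token ids and prompt split are assumptions, since the PR only defines the model-side method. It reproduces the concatenation performed above: text embeddings for the prompt half before the image, then the projected image features from miniGPT4's first part, then the prompt half after the image.

import paddle
import paddle.nn as nn

vocab_size, hidden_size, num_image_tokens = 32000, 4096, 32  # assumed model dimensions
embed_tokens = nn.Embedding(vocab_size, hidden_size)         # stands in for self.llama.embed_tokens

first_input_ids = paddle.to_tensor([[1, 835, 29950]], dtype="int64")      # assumed ids: prompt before the image slot
second_input_ids = paddle.to_tensor([[829, 29902, 2277]], dtype="int64")  # assumed ids: prompt after the image slot
image_features = paddle.randn([1, num_image_tokens, hidden_size])         # output of miniGPT4's first (vision) part

first_embeds = embed_tokens(first_input_ids)    # [1, 3, 4096]
second_embeds = embed_tokens(second_input_ids)  # [1, 3, 4096]
image_features = paddle.cast(image_features, dtype=first_embeds.dtype)
inputs_embeds = paddle.concat([first_embeds, image_features, second_embeds], axis=1)
print(inputs_embeds.shape)  # [1, 38, 4096]: text embeddings with image features spliced in between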

# Overrides the to_static function in generation_utils.py.
def to_static(self, output_path: str, config: dict):
dtype = config.get("dtype", paddle.get_default_dtype())
cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None))
input_spec = [

paddle.static.InputSpec(
shape=[None, None, None], dtype="float32", name="image_features"
), # image_features
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="first_input_ids"), # first_input_ids
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="second_input_ids"), # second_input_ids
paddle.static.InputSpec(shape=[None, None], dtype=dtype, name="attention_mask"), # attention_mask
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"), # position_ids
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"), # penalty_score
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"), # frequency_score
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"), # presence_score
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"), # min_decode_length
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"), # max_decode_length
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"), # temperature
paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"), # top_p
paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"), # eos_token_id
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"), # seq_len_encoder
paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"), # seq_len_decoder
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"), # step_idx
paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"), # stop_flags
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"), # tgt_ids
paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"), # tgt_pos
paddle.static.InputSpec(
shape=[None, 1, 1, None], dtype=dtype, name="tgt_generation_mask"
), # tgt_generation_mask
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"), # pre_ids
paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"), # stop_nums
[
paddle.static.InputSpec(
shape=shape,
dtype=dtype,
name="cache_kvs_{}".format(i),
)
for i, shape in enumerate(cache_kvs_shapes)
], # cache_kvs
]

model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec)
paddle.jit.save(model, output_path)

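Finally, a hedged sketch of driving this to_static override for dy2st export; the checkpoint path, dtype and max_length are placeholders, and in practice this is wired up through the repo's export tooling rather than called by hand.

from paddlenlp.transformers import AutoConfig
from paddlenlp.experimental.transformers import LlamaForMiniGPT4InferenceModel

model_path = "./checkpoints/minigpt4-llama"  # hypothetical merged miniGPT4 llama weights
config = AutoConfig.from_pretrained(model_path)

model = LlamaForMiniGPT4InferenceModel.from_pretrained(model_path, config=config, dtype="float16")
model.eval()

# dy2st: trace generate_text_with_image_features with the InputSpec list above and
# save the static graph plus parameters for Paddle Inference.
model.to_static("./inference/minigpt4_second_part", {"dtype": "float16", "max_length": 1024})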