Skip to content

Commit

Permalink
Fix Generation expand_inputs_for_generation when corresponding param …
Browse files Browse the repository at this point in the history
…is None (#1656)
  • Loading branch information
FrostML authored Jan 28, 2022
1 parent c577cb4 commit 3e9ee57
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions paddlenlp/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,20 +369,23 @@ def expand_inputs_for_generation(input_ids,
model_kwargs["attention_mask"] = paddle.gather(attention_mask,
index)

if "token_type_ids" in model_kwargs:
if "token_type_ids" in model_kwargs and model_kwargs[
"token_type_ids"] is not None:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = paddle.gather(token_type_ids,
index)

if "position_ids" in model_kwargs:
if "position_ids" in model_kwargs and model_kwargs[
"position_ids"] is not None:
position_ids = model_kwargs["position_ids"]
model_kwargs["position_ids"] = paddle.gather(position_ids, index)

if "seq_len" in model_kwargs:
if "seq_len" in model_kwargs and model_kwargs["seq_len"] is not None:
seq_len = model_kwargs["seq_len"]
model_kwargs["seq_len"] = paddle.gather(seq_len, index)

if "encoder_output" in model_kwargs:
if "encoder_output" in model_kwargs and model_kwargs[
"encoder_output"] is not None:
encoder_output = model_kwargs["encoder_output"]
model_kwargs["encoder_output"] = paddle.gather(encoder_output,
index)
Expand Down

0 comments on commit 3e9ee57

Please sign in to comment.