From 3e9ee57bfe54f66a9ec28b36b87908c4084421ea Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Fri, 28 Jan 2022 15:17:44 +0800 Subject: [PATCH] Fix Generation expand_inputs_for_generation when corresponding param is None (#1656) --- paddlenlp/transformers/generation_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py index ee248bde90f2..1b96d2057edc 100644 --- a/paddlenlp/transformers/generation_utils.py +++ b/paddlenlp/transformers/generation_utils.py @@ -369,20 +369,23 @@ def expand_inputs_for_generation(input_ids, model_kwargs["attention_mask"] = paddle.gather(attention_mask, index) - if "token_type_ids" in model_kwargs: + if "token_type_ids" in model_kwargs and model_kwargs[ + "token_type_ids"] is not None: token_type_ids = model_kwargs["token_type_ids"] model_kwargs["token_type_ids"] = paddle.gather(token_type_ids, index) - if "position_ids" in model_kwargs: + if "position_ids" in model_kwargs and model_kwargs[ + "position_ids"] is not None: position_ids = model_kwargs["position_ids"] model_kwargs["position_ids"] = paddle.gather(position_ids, index) - if "seq_len" in model_kwargs: + if "seq_len" in model_kwargs and model_kwargs["seq_len"] is not None: seq_len = model_kwargs["seq_len"] model_kwargs["seq_len"] = paddle.gather(seq_len, index) - if "encoder_output" in model_kwargs: + if "encoder_output" in model_kwargs and model_kwargs[ + "encoder_output"] is not None: encoder_output = model_kwargs["encoder_output"] model_kwargs["encoder_output"] = paddle.gather(encoder_output, index)