diff --git a/paddlenlp/transformers/generation_utils.py b/paddlenlp/transformers/generation_utils.py
index ee248bde90f2..1b96d2057edc 100644
--- a/paddlenlp/transformers/generation_utils.py
+++ b/paddlenlp/transformers/generation_utils.py
@@ -369,20 +369,23 @@ def expand_inputs_for_generation(input_ids,
             model_kwargs["attention_mask"] = paddle.gather(attention_mask,
                                                            index)
 
-        if "token_type_ids" in model_kwargs:
+        if "token_type_ids" in model_kwargs and model_kwargs[
+                "token_type_ids"] is not None:
             token_type_ids = model_kwargs["token_type_ids"]
             model_kwargs["token_type_ids"] = paddle.gather(token_type_ids,
                                                            index)
 
-        if "position_ids" in model_kwargs:
+        if "position_ids" in model_kwargs and model_kwargs[
+                "position_ids"] is not None:
             position_ids = model_kwargs["position_ids"]
             model_kwargs["position_ids"] = paddle.gather(position_ids, index)
 
-        if "seq_len" in model_kwargs:
+        if "seq_len" in model_kwargs and model_kwargs["seq_len"] is not None:
             seq_len = model_kwargs["seq_len"]
             model_kwargs["seq_len"] = paddle.gather(seq_len, index)
 
-        if "encoder_output" in model_kwargs:
+        if "encoder_output" in model_kwargs and model_kwargs[
+                "encoder_output"] is not None:
             encoder_output = model_kwargs["encoder_output"]
             model_kwargs["encoder_output"] = paddle.gather(encoder_output,
                                                            index)
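
The four hunks above apply one and the same guard: expand a model_kwargs entry along the batch dimension only when the key is both present and not None. Callers may pass e.g. position_ids=None explicitly, so the key can exist while holding None, and paddle.gather(None, index) would raise. Below is a minimal standalone sketch of that pattern, using a hypothetical _expand_if_present helper that is not part of this patch:

import paddle


def _expand_if_present(model_kwargs, name, index):
    # Gather (repeat along the batch axis) only when the entry exists
    # AND is not None -- the guard introduced by this diff.
    if name in model_kwargs and model_kwargs[name] is not None:
        model_kwargs[name] = paddle.gather(model_kwargs[name], index)


# Usage sketch: expand each example expand_size times, as generate()
# does for num_return_sequences / beam search. index = [0, 0, 1, 1, 2, 2].
batch_size, expand_size = 3, 2
index = paddle.arange(batch_size).unsqueeze(-1).tile(
    [1, expand_size]).flatten()
model_kwargs = {
    "token_type_ids": paddle.zeros([batch_size, 5], dtype="int64"),
    "position_ids": None,  # key present but None: must be skipped
}
_expand_if_present(model_kwargs, "token_type_ids", index)  # -> shape [6, 5]
_expand_if_present(model_kwargs, "position_ids", index)  # no-op, stays None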