From b3f925da6a246148386068de575abafc48c6db4c Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Tue, 12 Dec 2023 11:07:06 +0000
Subject: [PATCH] push

---
 src/peft/peft_model.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 781a99fe18..e225421464 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -1138,11 +1138,15 @@ def generate(self, **kwargs):
     def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs):
         peft_config = self.active_peft_config
         model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
+
+        _uses_transformers_4_26 = True
+
         if peft_config.is_prompt_learning:
             if model_kwargs.get("attention_mask", None) is not None:
                 if packaging.version.parse(transformers.__version__) < packaging.version.parse("4.36.0"):
                     # TODO figure out why this workaround is necessary, see #1252 for context
                     size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
+                    _uses_transformers_4_26 = False
                 elif model_kwargs["past_key_values"] is None:
                     size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
                 else:
@@ -1174,6 +1178,10 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **
                     model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
                     model_kwargs["input_ids"] = None
 
+        if _uses_transformers_4_26:
+            # TODO: why?
+            model_kwargs = self.base_model_prepare_inputs_for_generation(**model_kwargs)
+
         return model_kwargs
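
Note (not part of the commit above): a minimal, hypothetical sketch of the pattern the patch introduces, assuming the intent is to record whether the pre-4.36.0 workaround branch ran and, if it did not, to pass the already-adjusted model_kwargs through the base model's prepare_inputs_for_generation a second time; the patch itself only marks the reason with "# TODO: why?". The standalone function name prepare_inputs_sketch and the base_prepare_inputs callable are illustrative stand-ins, not PEFT API, and the is_prompt_learning / attention_mask gating from the real method is collapsed here for brevity.

from packaging import version
import transformers


def prepare_inputs_sketch(base_prepare_inputs, model_kwargs):
    # Mirrors `_uses_transformers_4_26` in the patch: assume the second pass is
    # wanted unless the legacy (< 4.36.0) branch fires below.
    run_second_pass = True

    if version.parse(transformers.__version__) < version.parse("4.36.0"):
        # In the patch this corresponds to the pre-4.36 attention-mask
        # workaround branch; here we only record that the second pass is skipped.
        run_second_pass = False

    if run_second_pass:
        # Re-run the base model's prepare_inputs_for_generation on the kwargs
        # that have already been adjusted (the reason is left as a TODO in the patch).
        model_kwargs = base_prepare_inputs(**model_kwargs)

    return model_kwargs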