From d06fe115ee6151e3fc14d18fa1c7c803e1bc40a7 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 11 Dec 2023 15:54:15 +0100 Subject: [PATCH 01/10] Empty commit to check CI From 9801947f60a7f1f48f24e1d1b400012e56e4b6d7 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 12 Dec 2023 11:19:23 +0100 Subject: [PATCH 02/10] Fix issues with transformers 4.36.0 --- src/peft/peft_model.py | 14 +++++++++++--- src/peft/tuners/adaption_prompt/utils.py | 7 ++++++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 430fa044d9..781a99fe18 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -23,7 +23,9 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Union +import packaging.version import torch +import transformers from accelerate import dispatch_model, infer_auto_device_map from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules from accelerate.utils import get_balanced_memory @@ -1138,9 +1140,15 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, ** model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs) if peft_config.is_prompt_learning: if model_kwargs.get("attention_mask", None) is not None: - prefix_attention_mask = torch.ones( - model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens - ).to(model_kwargs["input_ids"].device) + if packaging.version.parse(transformers.__version__) < packaging.version.parse("4.36.0"): + # TODO figure out why this workaround is necessary, see #1252 for context + size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens + elif model_kwargs["past_key_values"] is None: + size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens + else: + size = model_kwargs["input_ids"].shape[0], model_kwargs["past_key_values"][0][0].shape[-2] + + prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device) model_kwargs["attention_mask"] = torch.cat( (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1 ) diff --git a/src/peft/tuners/adaption_prompt/utils.py b/src/peft/tuners/adaption_prompt/utils.py index 921982fbb7..0cbc95c1a1 100644 --- a/src/peft/tuners/adaption_prompt/utils.py +++ b/src/peft/tuners/adaption_prompt/utils.py @@ -73,7 +73,12 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor: seq_len = q_len if past_key_value is not None: - seq_len += past_key_value[0].shape[-2] + if isinstance(past_key_value, tuple): + # for transformers <= 4.35 + seq_len += past_key_value[0].shape[-2] + else: + # since transformers 4.36, this is a DynamicCache instance + seq_len += past_key_value.get_seq_length(model.layer_idx) cos, sin = model.rotary_emb(value_states, seq_len=seq_len) return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids) From fbaa646b882946fd5db16606f49282e462fc6ca0 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 12 Dec 2023 12:13:35 +0100 Subject: [PATCH 03/10] push (#5) --- src/peft/peft_model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 781a99fe18..e225421464 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -1138,11 +1138,15 @@ def generate(self, **kwargs): def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs): peft_config = self.active_peft_config model_kwargs = 
self.base_model_prepare_inputs_for_generation(*args, **kwargs) + + _uses_transformers_4_26 = True + if peft_config.is_prompt_learning: if model_kwargs.get("attention_mask", None) is not None: if packaging.version.parse(transformers.__version__) < packaging.version.parse("4.36.0"): # TODO figure out why this workaround is necessary, see #1252 for context size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens + _uses_transformers_4_26 = False elif model_kwargs["past_key_values"] is None: size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens else: @@ -1174,6 +1178,10 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, ** model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1) model_kwargs["input_ids"] = None + if _uses_transformers_4_26: + # TODO: why? + model_kwargs = self.base_model_prepare_inputs_for_generation(**model_kwargs) + return model_kwargs From 7745d0baa1fb118e9b4d0260ae2b1348ca8b7502 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 12 Dec 2023 12:56:28 +0100 Subject: [PATCH 04/10] Next attempt to fix the issue --- src/peft/peft_model.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index e225421464..aad57cf89b 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -1139,18 +1139,19 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, ** peft_config = self.active_peft_config model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs) - _uses_transformers_4_26 = True + # https://github.com/huggingface/transformers/pull/26681/ introduced new cache format + # for some architectures which requires a special fix for prompt tuning etc. + uses_transformers_4_26 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0") + transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"] + uses_cache = self.base_model.config.model_type in transformers_new_cache_archs if peft_config.is_prompt_learning: if model_kwargs.get("attention_mask", None) is not None: - if packaging.version.parse(transformers.__version__) < packaging.version.parse("4.36.0"): + if uses_transformers_4_26 and uses_cache and (model_kwargs["past_key_values"] is not None): # TODO figure out why this workaround is necessary, see #1252 for context - size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens - _uses_transformers_4_26 = False - elif model_kwargs["past_key_values"] is None: - size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens - else: size = model_kwargs["input_ids"].shape[0], model_kwargs["past_key_values"][0][0].shape[-2] + else: + size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device) model_kwargs["attention_mask"] = torch.cat( @@ -1178,10 +1179,6 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, ** model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1) model_kwargs["input_ids"] = None - if _uses_transformers_4_26: - # TODO: why? 
- model_kwargs = self.base_model_prepare_inputs_for_generation(**model_kwargs) - return model_kwargs From 634e75ea36c01e9269a1853c7ec7811e7c82107a Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 12 Dec 2023 13:28:57 +0100 Subject: [PATCH 05/10] Apply suggestions from code review Co-authored-by: Joao Gante --- src/peft/peft_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index aad57cf89b..9ea0272b77 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -1141,13 +1141,13 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, ** # https://github.com/huggingface/transformers/pull/26681/ introduced new cache format # for some architectures which requires a special fix for prompt tuning etc. - uses_transformers_4_26 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0") + uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0") transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"] uses_cache = self.base_model.config.model_type in transformers_new_cache_archs if peft_config.is_prompt_learning: if model_kwargs.get("attention_mask", None) is not None: - if uses_transformers_4_26 and uses_cache and (model_kwargs["past_key_values"] is not None): + if uses_transformers_4_36 and uses_cache and (model_kwargs["past_key_values"] is not None): # TODO figure out why this workaround is necessary, see #1252 for context size = model_kwargs["input_ids"].shape[0], model_kwargs["past_key_values"][0][0].shape[-2] else: From 599b4603aa51ae04a1607d05735f6b0005972f24 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 12 Dec 2023 13:42:41 +0100 Subject: [PATCH 06/10] Add comment about next transformers version --- src/peft/peft_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 9ea0272b77..360a5c5fb8 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -1141,6 +1141,7 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, ** # https://github.com/huggingface/transformers/pull/26681/ introduced new cache format # for some architectures which requires a special fix for prompt tuning etc. + # TODO: starting with transformers 4.37, all architectures should support caching. uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0") transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"] uses_cache = self.base_model.config.model_type in transformers_new_cache_archs From 4e3e87d615b61ce5052f76c188f92fb9f459b902 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 12 Dec 2023 14:20:37 +0100 Subject: [PATCH 07/10] Provisions for transformers 4.37 All model architectures should now support cache. --- src/peft/peft_model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 360a5c5fb8..d6282baf2e 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -1142,13 +1142,16 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, ** # https://github.com/huggingface/transformers/pull/26681/ introduced new cache format # for some architectures which requires a special fix for prompt tuning etc. # TODO: starting with transformers 4.37, all architectures should support caching. 
+ uses_transformers_4_37 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.37.0") uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0") transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"] - uses_cache = self.base_model.config.model_type in transformers_new_cache_archs + uses_cache = uses_transformers_4_37 or ( + uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs + ) if peft_config.is_prompt_learning: if model_kwargs.get("attention_mask", None) is not None: - if uses_transformers_4_36 and uses_cache and (model_kwargs["past_key_values"] is not None): + if uses_cache and (model_kwargs["past_key_values"] is not None): # TODO figure out why this workaround is necessary, see #1252 for context size = model_kwargs["input_ids"].shape[0], model_kwargs["past_key_values"][0][0].shape[-2] else: From 0bbf4e7a0a467fd60f505fa8fd99160b3fa60e23 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 12 Dec 2023 14:21:57 +0100 Subject: [PATCH 08/10] Skip flaky test on MacOS --- tests/test_adaption_prompt.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_adaption_prompt.py b/tests/test_adaption_prompt.py index 117c43a427..00fa15bb23 100644 --- a/tests/test_adaption_prompt.py +++ b/tests/test_adaption_prompt.py @@ -15,6 +15,7 @@ import importlib import os +import platform import tempfile import unittest from unittest import TestCase @@ -387,6 +388,10 @@ def test_add_and_set_while_disabled(self): def test_use_cache(self) -> None: """Test that AdaptionPrompt works when Llama config use_cache=True.""" + if platform.system() == "Darwin": + # TODO: check why this is, may have started with transformers 4.36.0 + self.skipTest("This test is flaky on macOS.") + input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) original = LlamaForCausalLM( LlamaConfig( From 75d07fa1f6652ad25f6677d53e8c4755bf07ab8d Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 12 Dec 2023 14:34:49 +0100 Subject: [PATCH 09/10] Skip another flaky test on MacOS --- tests/test_multitask_prompt_tuning.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_multitask_prompt_tuning.py b/tests/test_multitask_prompt_tuning.py index 9aa6b8d7d9..9cbc3f1dc1 100644 --- a/tests/test_multitask_prompt_tuning.py +++ b/tests/test_multitask_prompt_tuning.py @@ -15,6 +15,7 @@ import importlib import os +import platform import tempfile from unittest import TestCase @@ -220,6 +221,10 @@ def test_generate(self) -> None: def test_use_cache(self) -> None: """Test that MultiTaskPromptTuning works when Llama config use_cache=True.""" + if platform.system() == "Darwin": + # TODO: check why this is, may have started with transformers 4.36.0 + self.skipTest("This test is flaky on macOS.") + input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) task_ids = torch.LongTensor([1, 2]).to(self.torch_device) From d3c3be8c8d857773a91b54466904de5a6ba77636 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 12 Dec 2023 14:49:05 +0100 Subject: [PATCH 10/10] Check if seeding + eval() removes flakiness --- tests/test_adaption_prompt.py | 8 ++------ tests/test_multitask_prompt_tuning.py | 8 ++------ 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/test_adaption_prompt.py b/tests/test_adaption_prompt.py index 00fa15bb23..92bbd72017 100644 --- a/tests/test_adaption_prompt.py +++ b/tests/test_adaption_prompt.py @@ -15,7 +15,6 @@ import 
importlib import os -import platform import tempfile import unittest from unittest import TestCase @@ -388,10 +387,7 @@ def test_add_and_set_while_disabled(self): def test_use_cache(self) -> None: """Test that AdaptionPrompt works when Llama config use_cache=True.""" - if platform.system() == "Darwin": - # TODO: check why this is, may have started with transformers 4.36.0 - self.skipTest("This test is flaky on macOS.") - + torch.manual_seed(0) input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) original = LlamaForCausalLM( LlamaConfig( @@ -402,7 +398,7 @@ def test_use_cache(self) -> None: num_attention_heads=4, use_cache=False, ) - ) + ).eval() adapted = get_peft_model( original, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM") ) diff --git a/tests/test_multitask_prompt_tuning.py b/tests/test_multitask_prompt_tuning.py index 9cbc3f1dc1..be548aaa3e 100644 --- a/tests/test_multitask_prompt_tuning.py +++ b/tests/test_multitask_prompt_tuning.py @@ -15,7 +15,6 @@ import importlib import os -import platform import tempfile from unittest import TestCase @@ -221,14 +220,11 @@ def test_generate(self) -> None: def test_use_cache(self) -> None: """Test that MultiTaskPromptTuning works when Llama config use_cache=True.""" - if platform.system() == "Darwin": - # TODO: check why this is, may have started with transformers 4.36.0 - self.skipTest("This test is flaky on macOS.") - + torch.manual_seed(0) input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device) task_ids = torch.LongTensor([1, 2]).to(self.torch_device) - original = LlamaForCausalLM(self._create_test_llama_config()) + original = LlamaForCausalLM(self._create_test_llama_config()).eval() mpt = get_peft_model(original, self._create_multitask_prompt_tuning_config()) mpt = mpt.to(self.torch_device)
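
For reference, the version-gated cache handling that patches 02 through 07 converge on can be summarized outside the diff context. The following is a minimal standalone sketch, assuming torch, transformers, and packaging are installed; the helper names (uses_new_cache, prefix_attention_mask, cached_seq_len) are hypothetical illustrations and not part of the PEFT API.

    # Hypothetical helpers sketching the logic from the patch series above.
    # uses_new_cache / prefix_attention_mask mirror the branch added to
    # PeftModel.prepare_inputs_for_generation; cached_seq_len mirrors the
    # tuple-vs-DynamicCache branch in llama_compute_query_states.
    import packaging.version
    import torch
    import transformers


    def uses_new_cache(model_type: str) -> bool:
        # transformers 4.36 migrated only a few architectures to the new cache
        # format; from 4.37 on, all architectures are expected to support it.
        version = packaging.version.parse(transformers.__version__)
        if version >= packaging.version.parse("4.37.0"):
            return True
        new_cache_archs = ["llama", "mistral", "persimmon", "phi"]
        return version >= packaging.version.parse("4.36.0") and model_type in new_cache_archs


    def prefix_attention_mask(input_ids, past_key_values, num_virtual_tokens, model_type):
        # With the new cache format and a non-empty cache, the prefix mask must
        # cover the cached sequence length, not just the virtual prompt tokens.
        if uses_new_cache(model_type) and past_key_values is not None:
            size = input_ids.shape[0], past_key_values[0][0].shape[-2]
        else:
            size = input_ids.shape[0], num_virtual_tokens
        return torch.ones(size, device=input_ids.device)


    def cached_seq_len(past_key_value, q_len, layer_idx):
        # Tuple caches (transformers <= 4.35) expose the cached length via the
        # key tensor shape; DynamicCache (>= 4.36) exposes it via get_seq_length().
        seq_len = q_len
        if past_key_value is not None:
            if isinstance(past_key_value, tuple):
                seq_len += past_key_value[0].shape[-2]
            else:
                seq_len += past_key_value.get_seq_length(layer_idx)
        return seq_len

Under these assumptions, the sketch reflects the final state of the series: the version probe gates on 4.36 per architecture and unconditionally from 4.37, and the earlier intermediate workaround (re-calling base_model_prepare_inputs_for_generation) is dropped entirely in patch 04.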