diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index d6282baf2e..a1ce19a3dc 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -1141,11 +1141,11 @@ def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **
 
         # https://github.com/huggingface/transformers/pull/26681/ introduced new cache format
         # for some architectures which requires a special fix for prompt tuning etc.
-        # TODO: starting with transformers 4.37, all architectures should support caching.
-        uses_transformers_4_37 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.37.0")
+        # TODO: starting with transformers 4.38, all architectures should support caching.
+        uses_transformers_4_38 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.38.0")
         uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0")
         transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"]
-        uses_cache = uses_transformers_4_37 or (
+        uses_cache = uses_transformers_4_38 or (
             uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs
         )
 
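
For readers outside the PEFT codebase, below is a minimal standalone sketch of the version-gating pattern this hunk adjusts: it parses transformers.__version__ with packaging.version and falls back to a per-architecture allowlist on older releases. The helper name uses_new_cache_format and the example model_type value are illustrative assumptions, not part of the PR.

# A minimal sketch of the version gate in the hunk above, assuming only that
# `transformers` and `packaging` are installed. `uses_new_cache_format` is a
# hypothetical helper name introduced here for illustration.
import packaging.version

import transformers


def uses_new_cache_format(model_type: str) -> bool:
    tf_version = packaging.version.parse(transformers.__version__)
    # From transformers 4.38 onward, every architecture is expected to support
    # the new cache format introduced in transformers#26681.
    if tf_version >= packaging.version.parse("4.38.0"):
        return True
    # Between 4.36 and 4.38, only a handful of architectures had been migrated.
    new_cache_archs = ["llama", "mistral", "persimmon", "phi"]
    return tf_version >= packaging.version.parse("4.36.0") and model_type in new_cache_archs


print(uses_new_cache_format("llama"))  # True on transformers >= 4.36

Comparing parsed Version objects rather than raw strings is the key design choice here: string comparison would rank "4.9" above "4.38", while packaging.version orders releases correctly.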