
Issue with transformers 4.36 #1252

Merged

Conversation

@BenjaminBossan (Member) commented Dec 11, 2023

It seems that transformers==4.36 breaks some of our tests. Locally, they pass with 4.35 but fail with 4.36.

One offending line is this:

seq_len += past_key_value[0].shape[-2]

It seems that past_key_values used to be a tuple of tensors; now it is a DynamicCache, whose __getitem__ returns a tuple of tensors. How can we rewrite the PEFT code in a backwards-compatible fashion (i.e. so that it also works with older transformers versions)?

Another error is this:

        if causal_4d_mask is not None:
>           expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
E           RuntimeError: The size of tensor a (14) must match the size of tensor b (17) at non-singleton dimension 3

If there is no easy fix, LMK and we'll have to pin the transformers version for now.
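For reference, if pinning does become necessary, a minimal sketch of what that could look like (assuming the constraint lives in setup.py's install_requires; I haven't checked where PEFT actually declares it):

# Hypothetical pin until the cache changes are handled (illustrative only):
install_requires = [
    "transformers<4.36",
    # ... other dependencies unchanged
]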

ping @tomaarsen @younesbelkada

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@tomaarsen (Member)

cc: @gante as well. I don't have the bandwidth to look into this today

@BenjaminBossan changed the title from "Empty commit to check CI" to "Issue with transformers 4.36" on Dec 11, 2023
@BenjaminBossan (Member, Author)

LMK if there is no bandwidth for this issue; in that case we should pin the transformers version for now, otherwise CI will remain broken.

@gante (Member) commented Dec 11, 2023

@BenjaminBossan

uhmmm... the catch is that the function is missing one input, the layer index. If layer_idx held the layer index, the following replacement would work:

if past_key_value is not None:
    if isinstance(past_key_value, tuple):
        # transformers <= 4.35: per-layer tuple of (key, value) tensors
        seq_len += past_key_value[0].shape[-2]
    else:
        # transformers >= 4.36: Cache / DynamicCache instance
        seq_len += past_key_value.get_seq_length(layer_idx)

@gante (Member) commented Dec 11, 2023

For context, the layer index is now stored on the decoder layers and the attention layers themselves, see e.g. here
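As a minimal sketch of what that could mean for the snippet above (attn_module is a hypothetical name for the patched attention layer; the getattr fallback is an assumption for older transformers versions that don't set the attribute):

layer_idx = getattr(attn_module, "layer_idx", None)  # set by transformers >= 4.36
if layer_idx is not None and past_key_value is not None and not isinstance(past_key_value, tuple):
    # DynamicCache path: the layer index comes from the module itself
    seq_len += past_key_value.get_seq_length(layer_idx)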

@BenjaminBossan (Member, Author)

Thanks @gante! I made the following change:

    if past_key_value is not None:
        if isinstance(past_key_value, tuple):
            seq_len += past_key_value[0].shape[-2]
        else:
            # since transformers 4.36, this is a DynamicCache instance
            seq_len += past_key_value.get_seq_length(model.layer_idx)

This makes the first batch of tests pass 🎉

Unfortunately, the size mismatch in to_4d remains. For example:

        if causal_4d_mask is not None:
>           expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
E           RuntimeError: The size of tensor a (14) must match the size of tensor b (17) at non-singleton dimension 3

../../../anaconda3/envs/peft/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py:136: RuntimeError

@gante (Member) commented Dec 11, 2023

@BenjaminBossan the second one is trickier :D Would you be able to share a snippet so I can reproduce on my end?

@BenjaminBossan (Member, Author) commented Dec 11, 2023

When checking out peft, this test fails:

pytest tests/ -k test_disable_adapter_45_test_HuggingFaceM4_tiny_random_LlamaForCausalLM_prompt_encoder

Possibly this PEFT code is the issue, but I'm not sure:

peft/src/peft/peft_model.py, lines 1162 to 1167 (at e73967e):

if model_kwargs["past_key_values"] is None:
inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
prompts = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0], task_ids=task_ids)
prompts = prompts.to(inputs_embeds.dtype)
model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
model_kwargs["input_ids"] = None

edit: Maybe not, I couldn't spot a difference here between transformers 4.35 and 4.36

edit2: When I comment out this line:

self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation

the error disappears (but a later assertion that PEFT changes the result fails, as expected). However, when I jump into the output of self.prepare_inputs_for_generation, I can't spot any difference between 4.35 and 4.36.

@gante (Member) commented Dec 11, 2023

@BenjaminBossan I think the attention mask in peft_model.py L1140 is being incorrectly constructed 🤔

The exception pops up because the input has a length of 17 (13 from the cache, 4 from the new input_ids) and the attention mask has a length of 14 (10 from peft_config.num_virtual_tokens + 4 from input_ids). In the previous generation round, we have an input of length 13 (10 from self.get_prompt() + 3 from input_ids converted into input_embeds), so all shapes seem to be correct except for the ad hoc attention mask.
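For concreteness, a small sketch of that bookkeeping (the numbers are the ones from the failing test described above; the variable names are just illustrative):

num_virtual_tokens = 10      # peft_config.num_virtual_tokens
prompt_len = 3               # input_ids in the first generation round
input_ids_len = 4            # full input_ids in the next round (3 prompt tokens + 1 generated)

cache_len = num_virtual_tokens + prompt_len           # 13 entries already in the cache
model_seq_len = cache_len + input_ids_len             # 17: what the causal 4d mask covers
adhoc_mask_len = num_virtual_tokens + input_ids_len   # 14: what the PEFT prefix mask produces
# 17 != 14 at dimension 3 -> the RuntimeError in to_4d shown above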

If we replace

if model_kwargs.get("attention_mask", None) is not None:
  prefix_attention_mask = torch.ones(
      model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
  ).to(model_kwargs["input_ids"].device)
  model_kwargs["attention_mask"] = torch.cat(
      (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
  )

by

if model_kwargs.get("attention_mask", None) is not None:
  if model_kwargs["past_key_values"] is None:
      prefix_attention_mask = torch.ones(
          model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
      ).to(model_kwargs["input_ids"].device)
  else:
      prefix_attention_mask = torch.ones(
          model_kwargs["input_ids"].shape[0], model_kwargs["past_key_values"][0][0].shape[-2]
      ).to(model_kwargs["input_ids"].device)
  model_kwargs["attention_mask"] = torch.cat(
      (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
  )

then the tests pass because the attention mask is now of the expected shape. However, I'm not fully qualified to tell whether it is the right fix for PEFT :)
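One wrinkle in the snippet above is that past_key_values[0][0].shape[-2] leans on the legacy tuple layout (it happens to work with DynamicCache too via its __getitem__). A minimal sketch of a more explicit, version-agnostic variant of the else branch, assuming the hasattr check is an acceptable way to detect the new cache object:

past = model_kwargs["past_key_values"]
if hasattr(past, "get_seq_length"):
    # transformers >= 4.36: Cache / DynamicCache instance
    past_length = past.get_seq_length()
else:
    # older transformers: per-layer tuples of (key, value) tensors
    past_length = past[0][0].shape[-2]
prefix_attention_mask = torch.ones(
    model_kwargs["input_ids"].shape[0], past_length
).to(model_kwargs["input_ids"].device)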

@BenjaminBossan (Member, Author)

Thanks for digging into this. Your fix indeed makes the test pass -- alas, only for transformers 4.36 and not for 4.35, even though the shapes are the same for both. I'll try to dig deeper tomorrow.

However, I'm not fully qualified to tell whether it is the right fix for PEFT :)

That part of the code is also not very familiar to me :D

@Jaykumaran

AttributeError                            Traceback (most recent call last)
Cell In[9], line 6
      4 from datasets import load_dataset
      5 # from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel, prepare_model_for_kbit_training
----> 6 from transformers import (AutoModelForCausalLM, AutoTokenizer,
      7     BitsAndBytesConfig, HfArgumentParser,
      8     TrainingArguments, GenerationConfig, logging, pipeline)
      9 from trl import SFTTrainer

File /opt/conda/lib/python3.10/site-packages/transformers/__init__.py:26
     23 from typing import TYPE_CHECKING
     25 # Check the dependencies satisfy the minimal versions required.
---> 26 from . import dependency_versions_check

File /opt/conda/lib/python3.10/site-packages/transformers/dependency_versions_check.py:16
     15 from .dependency_versions_table import deps
---> 16 from .utils.versions import require_version, require_version_core

File /opt/conda/lib/python3.10/site-packages/transformers/utils/__init__.py:61
---> 61 from .hub import (
     62     CLOUDFRONT_DISTRIB_PREFIX,
    (...)
     91 )

File /opt/conda/lib/python3.10/site-packages/transformers/utils/hub.py:94
---> 94 PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", constants.HF_HUB_CACHE)
     95 PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
     96 TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

AttributeError: module 'huggingface_hub.constants' has no attribute 'HF_HUB_CACHE'

import os

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, HfArgumentParser,
TrainingArguments,GenerationConfig, logging, pipeline)
from trl import SFTTrainer

!pip install trl transformers accelerate git+https://github.com/huggingface/peft.git -Uqqq
!pip install datasets bitsandbytes einops wandb -Uqqq

@tomaarsen (Member)

@Jaykumaran That sounds like your huggingface_hub version is too low.
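If that is the case, upgrading should fix it, e.g. (exact minimum version not checked here):

!pip install -U huggingface_hub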

@BenjaminBossan marked this pull request as ready for review on December 12, 2023, 12:13
@BenjaminBossan (Member, Author)

@younesbelkada @tomaarsen @gante Please check if the fix/workaround is good. Ideally, I'd like not to hard-code the supported architectures; I'm not sure if transformers provides a way to check that instead.
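One possibility (an assumption on my part, not something transformers documents for this purpose) would be to feature-detect the new cache object at runtime, given the past_key_values a model returned, instead of keeping a list of architectures:

try:
    # available from transformers 4.36 onwards
    from transformers.cache_utils import Cache
except ImportError:
    Cache = None

# True only if this model actually returned the new cache object
uses_new_cache = Cache is not None and isinstance(past_key_values, Cache)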

@gante (Member) left a review comment

A few nits, otherwise LGTM 👍

(3 review comments on src/peft/peft_model.py, all outdated/resolved)
@younesbelkada (Contributor) left a review comment

Awesome investigation!

@BenjaminBossan merged commit ee6f6dc into huggingface:main on Dec 12, 2023
14 checks passed
@tomaarsen (Member)

Thanks @gante, @younesbelkada & @BenjaminBossan for looking into this!

@BenjaminBossan mentioned this pull request on Dec 12, 2023
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jan 11, 2024
See huggingface#1252 for more context.

The initial idea was for transformers 4.37 to add the new caching to all
architectures, but this was postponed to 4.38. The code needs to be
adapted for prompt tuning not to break when transformers 4.37 is
released.
@pacman100 (Contributor)

After spending 2 hours in the trenches of generation.utils.GenerationMixin, cache_utils.DynamicCache, modeling_llama.LlamaForCausalLM and peft.PeftModel, I've concluded that the fix in this PR is incorrect; the cause is the changed logic in prepare_inputs_for_generation of models like LlamaForCausalLM, introduced to make DynamicCache work.

@pacman100 (Contributor) commented Jan 12, 2024

The highlighted part below is the cause of the issue: the new logic removed it without maintaining the backward-compatible behavior of using only the last input id for the next generation step when past_key_value_length >= input_ids_length. I think assumption 3 in the new changes should be reconsidered.
[Screenshot (2024-01-12) of the relevant prepare_inputs_for_generation logic]
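As a rough paraphrase of the behavior change being discussed (illustrative only, not the actual transformers code):

def slice_input_ids_legacy(input_ids, past_length):
    # pre-4.36: once a cache exists, only the last token is fed to the model
    return input_ids[:, -1:] if past_length > 0 else input_ids

def slice_input_ids_new(input_ids, past_length):
    # 4.36-style: only tokens beyond the cached length are fed forward; when
    # past_key_value_length >= input_ids_length (the prompt-tuning case here),
    # nothing is stripped and the full input_ids are passed through again
    if 0 < past_length < input_ids.shape[1]:
        return input_ids[:, past_length:]
    return input_ids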

@gante (Member) commented Jan 12, 2024

After spending 2 hours in the trenches of generation.utils.GenerationMixin, cache_utils.DynamicCache, modeling_llama.LlamaForCausalLM and peft.PeftModel

@pacman100 I feel you 😢

I'm sorry for breaking the old default behavior, which was used here -- it was the only solution I could find to ensure all new generation methods worked correctly. The transformers codebase didn't have a case like this one (past length > inputs length AND we only want to use the latest token in the fwd pass), so I didn't consider it at all.

BenjaminBossan added a commit that referenced this pull request Jan 12, 2024
See #1252 and #1352 for more context.

The initial idea was for transformers 4.37 to add the new caching to all
architectures, but this was postponed to 4.38. The code needs to be
adapted for prompt tuning not to break when transformers 4.37 is
released.