Cache: new Cache format in decoder-only models #31421

Merged

Conversation

zucchini-nlp (Member):

What does this PR do?

This PR adds support for the dynamic cache (DynamicCache) to all decoder-only models (except XLNet). I am working on general support for the cache class in all models, but I found that encoder-decoder models are trickier and we might want to use a tuple of DynamicCache for those. That's why I am splitting the task into multiple PRs for easier review :)

All models were tested with the GenerationTesterMixin tests and passed on my end. For the special Bloom case, I made a tiny change so that the cache holds the same tensor shapes for keys/values (we need to be able to concatenate along shape[-2]). The other special case, GIT, seems to work according to the tests.

For encoder-decoder models I will first add support for one model and open a PR. If the format is okay, I'll open another PR covering the remaining generative models.
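For context, here is a minimal sketch of the cache flow this PR enables (illustrative only, not code from the PR): a legacy tuple cache is converted once with DynamicCache.from_legacy_cache, each decoder layer appends its new keys/values via update, and the result can be converted back to the tuple format for backward compatibility.

```python
# Illustrative sketch of the DynamicCache API used by the converted models.
import torch
from transformers import DynamicCache

batch, heads, head_dim, num_layers = 1, 4, 8, 2

# A legacy cache: one (key, value) tuple per layer, each of shape
# (batch, num_heads, seq_len, head_dim).
legacy = tuple(
    (torch.randn(batch, heads, 5, head_dim), torch.randn(batch, heads, 5, head_dim))
    for _ in range(num_layers)
)

cache = DynamicCache.from_legacy_cache(legacy)

# During decoding, each attention layer appends the keys/values of the new token
# and gets back the full sequences.
for layer_idx in range(num_layers):
    new_key = torch.randn(batch, heads, 1, head_dim)
    new_value = torch.randn(batch, heads, 1, head_dim)
    keys, values = cache.update(new_key, new_value, layer_idx)

print(cache.get_seq_length())        # 6 = 5 cached tokens + 1 new token
print(len(cache.to_legacy_cache()))  # 2, back in the per-layer tuple format
```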

zucchini-nlp requested a review from gante on June 14, 2024 at 15:04.
@@ -622,6 +622,7 @@ def _flash_attention_forward(
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

zucchini-nlp (Member, Author) commented on Jun 14, 2024:

These kinds of changes come from fix-copies and are not related to this PR at all, but let's leave them here since they are anyway related to code consistency in the library.

Comment on lines -172 to -176
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([key_layer[:, :, :cutoff, :], past_key_value[0], key_layer[:, :, -1:, :]], dim=2)
value_layer = torch.cat(
[value_layer[:, :, :cutoff, :], past_key_value[1], value_layer[:, :, -1:, :]], dim=2
zucchini-nlp (Member, Author):

For GIT I think we can make the assumption that we always get only new text tokens when past key/values are not None, and that was verified by the generation tests. In that case, we can use DynamicCache by cropping the image-related tokens from the keys/values, as I did below.
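A rough sketch of the generic pattern referred to here (illustrative only; the actual GIT diff in this PR also handles the image-token cropping described above). It assumes past_key_value is a Cache instance such as DynamicCache that already holds the image tokens and previously generated text, and that only new text tokens reach this layer:

```python
# Illustrative only, not the PR's actual GIT code.
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

if past_key_value is not None:
    # Cache.update appends this layer's new keys/values and returns the full
    # key/value sequences, replacing the manual torch.cat above.
    key_layer, value_layer = past_key_value.update(key_layer, value_layer, self.layer_idx)
```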


gante (Member) commented on Jun 14, 2024:

For encoder-decoder models I will first add support for one model and open a PR.

@zucchini-nlp regarding encoder-decoder models: let's first finish the PR for Whisper (#31166), and then copy the design all over the library! :D Your input in that PR would be appreciated :)

gante (Member) left a comment:

Two high-level comments to clear out before diving into details :)

Two review threads on src/transformers/generation/candidate_generator.py (outdated, resolved).
gante (Member) left a comment:

LGTM, thank you for working through these models 💪

request: let's activate slow tests! (I don't trust my review to capture all nuances from all architectures 😉 )

Review threads on src/transformers/models/codegen/modeling_codegen.py and src/transformers/models/gpt_neo/modeling_gpt_neo.py (outdated, resolved).
gante requested a review from ArthurZucker on June 17, 2024 at 16:19.
zucchini-nlp (Member, Author):

All slow tests are passing on my end

ArthurZucker (Collaborator) left a comment:

Short review: I want to better understand why we are shipping get_seq_length while we want to deprecate this API / it will make our life harder for any other cache class we want to support for these models (PS: I plan to support more 😉).

Let's use cache positions!

Otherwise, great work isolating the changes everywhere; it might even be good to add some # Copied from comments here and there!

Comment on lines 479 to 484
past_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_length = past_key_values.get_seq_length()
ArthurZucker (Collaborator):

This is super ugly IMO. I would really love for us to not add this code, which ended up being copy-pasted everywhere because we were lazy to update it when porting new models.

Let's use cache positions. And let's already deprecate the casting and the legacy cache, etc., to make sure we only keep this for one release!
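For reference, a minimal sketch of the cache-position pattern being requested here, following what the Llama-style models on main do (assumes inputs_embeds, past_key_values, position_ids, and cache_position are in scope, and torch is imported). generate passes cache_position directly, so the get_seq_length() call is only a fallback for direct forward calls:

```python
if cache_position is None:
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
    cache_position = torch.arange(
        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
    )
if position_ids is None:
    position_ids = cache_position.unsqueeze(0)
```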

zucchini-nlp (Member, Author):

I see. I wanted to add cache_position as part of StaticCache support, but it can also be added in this PR. I'm wondering about the deprecation cycle for the Cache class: AFAIK there was no deprecation warning about this before, so would we still have to keep the ugly hack and add a warning message?

I'm in favor of getting rid of the old cache entirely, if it doesn't break backward compatibility.

zucchini-nlp (Member, Author):

Ah, I see that Llama on main already had its input preparation modified to assume the input is always a Cache object. Okay, then it makes sense to get rid of the legacy cache in forward as well.

ArthurZucker (Collaborator):

Yeah, it was not done before, but IMO this PR is a good time to do it! So keep this code, but add a warning saying that we are automatically converting the cache from a tuple to the dynamic cache class (it should only be triggered outside of generate, because generate should already pass a cache class!).

Yep, it makes sense, but let's not be too brutal in case some people still use it! We give them one release before we remove it completely!

zucchini-nlp (Member, Author):

Okay, I added cache_position and deprecation warnings to all models in this PR. I'll add the same deprecation warning to all models that already support the cache class in another PR. This one is ready for review; tests are passing on my end!

ArthurZucker (Collaborator) left a comment:

Nice!

  • let's not warn when we are training (though I suspect a lot of people generate with a model that still has model.training set....)
  • let's add integration tests with compile
  • let's add the list of supported models + the performance boost to the compile doc

  Maybe for another PR, or for @ydshieh:
  • let's add these models to the list of models we run benchmarks for

Comment on lines +441 to +442
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key_layer.shape[-2]]
ArthurZucker (Collaborator):

might be compile compatible?

zucchini-nlp (Member, Author) commented on Jul 30, 2024:

Might be, but my main idea in this PR is to deprecate the old cache and check whether these changes enable compile compatibility for these models. So tests for compile and other things can go into a second PR.

You mean this list (#28981), right? Or do we also have it somewhere in the docs?

Review threads on src/transformers/models/git/modeling_git.py (resolved) and src/transformers/models/gpt_neo/modeling_gpt_neo.py (outdated, resolved).
@@ -64,6 +65,9 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTNeoXLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
ArthurZucker (Collaborator):

Cool, this is automatically tested, but manual testing would be nice to make sure we get correct outputs!

zucchini-nlp (Member, Author):

Do you mean adding slow integration tests for each model and cache type, or simply running some generations on my end to check that the outputs make sense?

ArthurZucker (Collaborator):

Up to you! I think we can trust our tests / potentially make sure that they were run for all these models (we have one that tests cache and compile!).
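One way to sanity-check outputs by hand (a sketch, not one of the library's slow tests): run greedy generation with the default dynamic cache and with the static cache implementation and compare the results. The checkpoint name below is just an example; any decoder-only model touched by this PR would do.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The quick brown fox", return_tensors="pt")

out_dynamic = model.generate(**inputs, max_new_tokens=10, do_sample=False)
out_static = model.generate(**inputs, max_new_tokens=10, do_sample=False, cache_implementation="static")

# Greedy generations should match regardless of the cache implementation.
assert out_dynamic.tolist() == out_static.tolist()
```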

Review threads on src/transformers/models/gptj/modeling_gptj.py and src/transformers/models/idefics/modeling_idefics.py (outdated, resolved).
zucchini-nlp (Member, Author):

I discovered just today that GPT-Neo and a few other models are just like Gemma2, i.e. they have a hybrid cache format. It will take me some time to figure out how to make things work while still supporting the old format.

zucchini-nlp (Member, Author):

@ArthurZucker the only model that is like Gemma2 is GPT-Neo. I started adding HybridCache support to it but realized it would bloat the code with if/elses to check which cache type we're currently using. So I propose reviewing and merging this PR to deprecate the tuple cache format.

For GPT-Neo I currently just toggled static-cache support to False and added a TODO for us. It might be quite breaking, and IMO it should be part of the compile-compatibility PR, not the deprecation PR.

gante (Member) commented on Aug 2, 2024:

@zucchini-nlp ping me when this gets merged -- with this PR merged, I can uniformize RoPE on even more models 🙌

ArthurZucker (Collaborator):

Cool, will do one last pass then!

ArthurZucker (Collaborator) left a comment:

🚀 huge work and refactoring!

Mostly a few comments on CodeGen that are valid for the other models:

  • in generate, only cache positions should be used
  • not creating the past cache is breaking, but we can break TBH
  • maybe for another PR, but let's move _supports_quantized_cache to cache_utils as a set!

@@ -303,6 +313,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["CodeGenBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
ArthurZucker (Collaborator):

BTW, we could also just add the class to a set like _support_quantized_cache = {""} in cache_utils; then we don't pollute the modeling file, and we can directly get all the classes that support quantized / static caches, etc.
-> better for auto-building the docs, and better in general to separate cache concerns from modeling-specific code
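A hypothetical sketch of what such a registry in cache_utils could look like (not an existing transformers API; the model names in the sets are illustrative):

```python
# Central feature registries in cache_utils instead of per-model class attributes.
MODELS_SUPPORTING_QUANTIZED_CACHE = {"llama", "codegen", "gpt_neox"}
MODELS_SUPPORTING_STATIC_CACHE = {"llama", "gpt_neox"}


def supports_quantized_cache(model_type: str) -> bool:
    """Look up feature support in one place, e.g. to auto-build the docs."""
    return model_type in MODELS_SUPPORTING_QUANTIZED_CACHE


def supports_static_cache(model_type: str) -> bool:
    return model_type in MODELS_SUPPORTING_STATIC_CACHE
```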

else:
past_length = past_key_values[0][0].size(-2)
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
ArthurZucker (Collaborator):

OK, my "bad" here: we are kind of making a breaking change, since before, past_key_values would always be created if it was None!

ArthurZucker (Collaborator):

Maybe the warning should be the only thing that checks not self.training!

zucchini-nlp (Member, Author) commented on Aug 5, 2024:

Yes, that can be done, though I'm not sure how often users want to use the cache when training.

EDIT: I just got what you meant; I'll raise the warning message only when not training, but create the cache from the legacy format in all cases!
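A minimal sketch of the behavior agreed on here (assuming the transformers logger and the Cache / DynamicCache classes are in scope): always convert a legacy tuple cache, but only emit the deprecation warning outside of training.

```python
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
    return_legacy_cache = True  # convert back with to_legacy_cache() before returning
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    if not self.training:
        logger.warning_once(
            "Passing `past_key_values` as a tuple is deprecated and will be removed in a "
            "future release. Please pass a `Cache` instance, e.g. `DynamicCache`, instead."
        )
```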

Review thread on src/transformers/models/codegen/modeling_codegen.py (outdated, resolved).
zucchini-nlp (Member, Author):

BTW, we could also just add the class to a set like _support_quantized_cache = {""} in cache_utils; then we don't pollute the modeling file, and we can directly get all the classes that support quantized / static caches, etc.

Sounds good, but IMO it should be another PR, because this one is finalizing the deprecation for the few models left. I'll make a PR for the auto cache mapping to see how it goes.

ArthurZucker (Collaborator):

Yep, another PR for sure!

zucchini-nlp (Member, Author):

Ran the tests and fixed things in a couple of places. Will merge later today/tomorrow.

BTW @gante, it seems like compile isn't working on current main after making the cache classes nn.Module; it fails in register_buffer for some reason. For example, tests/models/llama/test_modeling_llama.py::LlamaModelTest::test_generate_compile_fullgraph fails.

gante (Member) commented on Aug 6, 2024:

@zucchini-nlp will have a look, thank you for flagging 👍

zucchini-nlp merged commit a30c865 into huggingface:main on Aug 7, 2024. 23 checks passed.
jiminha added a commit to huggingface/optimum-habana referencing this pull request on Sep 25, 2024:

Transformers 4.45 has the dynamic cache updates removing self.bias, causing the failure. Until we investigate further and update the code based on the new transformers version, we are putting the bias back in GaudiGPTJAttention.init()

huggingface/transformers#31421