Cache: new Cache format in decoder-only models #31421
Conversation
@@ -622,6 +622,7 @@ def _flash_attention_forward(
    """
    Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
    first unpad the input, then computes the attention scores and pad the final attention scores.
These kinds of changes come from fix-copies and are not related to this PR at all, but let's leave them here since they're related to code consistency in the library anyway.
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([key_layer[:, :, :cutoff, :], past_key_value[0], key_layer[:, :, -1:, :]], dim=2)
value_layer = torch.cat(
    [value_layer[:, :, :cutoff, :], past_key_value[1], value_layer[:, :, -1:, :]], dim=2
)
For GIT I think we can make the assumption that we always get only new text tokens when past-kv is not None, and that was verified by the generation tests. In that case, we can use DynamicCache by cropping the image-related tokens from the keys/values, as I did below.
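A minimal sketch of what such cropping could look like with a DynamicCache (hypothetical helper and num_image_tokens parameter, not the exact code from this PR):

from transformers.cache_utils import DynamicCache

def crop_image_tokens(cache: DynamicCache, num_image_tokens: int) -> DynamicCache:
    """Drop the image-token positions from every cached layer (illustrative only)."""
    cropped = DynamicCache()
    for layer_idx in range(len(cache.key_cache)):
        key = cache.key_cache[layer_idx]
        value = cache.value_cache[layer_idx]
        # keep only the text-token part of the sequence axis (dim=-2)
        cropped.update(key[:, :, num_image_tokens:, :], value[:, :, num_image_tokens:, :], layer_idx)
    return cropped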
@zucchini-nlp regarding encoder-decoder models: let's first finish the PR for Whisper (#31166), and then copy the design all over the library! :D Your input in that PR would be appreciated :)
Two high-level comments to clear out before diving into details :)
LGTM, thank you for working through these models 💪
request: let's activate slow tests! (I don't trust my review to capture all nuances from all architectures 😉 )
All slow tests are passing on my end
Short review. I want to better understand why we are shipping get_seq_length while we want to deprecate this API / it will make our lives harder for any other cache class we want to support for these models (PS: I plan to support more 😉).
Let's use cache positions!
Otherwise great work isolating the changes everywhere; it might even be good to add some `# Copied from` comments here and there!
past_length = 0
if use_cache:
    use_legacy_cache = not isinstance(past_key_values, Cache)
    if use_legacy_cache:
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    past_length = past_key_values.get_seq_length()
This is super ugly IMO.
I would really love for us to just not add this code, which ended up being copy-pasted everywhere because we were lazy to update it when porting new models.
Let's use cache positions. And let's already deprecate the casting and legacy cache etc., to make sure we only do this for one revision!
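For reference, the cache-positions pattern being suggested looks roughly like this in Llama-style models (a sketch with an assumed helper name, not the exact diff of this PR):

from typing import Optional

import torch
from transformers.cache_utils import Cache

def build_cache_position(past_key_values: Optional[Cache], inputs_embeds: torch.Tensor) -> torch.Tensor:
    """Positions of the new tokens inside the full (past + new) sequence."""
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
    return torch.arange(
        past_seen_tokens,
        past_seen_tokens + inputs_embeds.shape[1],
        device=inputs_embeds.device,
    )

# generate() passes cache_position explicitly at every step, so the model never
# has to re-derive the past length from the cache contents; rotary embeddings
# and attention-mask slicing can then index with cache_position directly.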
I see, I wanted to add cache_position as part of StaticCache support, but it can be added in this PR as well. Wondering about the deprecation cycle for the Cache class: AFAIK there was no deprecation warning about it before, so we would still have to keep the ugly hack and add a warning message?
I'm pro totally getting rid of the old cache, if it doesn't break BC.
Ah, I see that Llama on main already had its preparation modified assuming that the input is always a Cache object. Okay, then it makes sense to get rid of legacy_cache in forward as well.
Yeah, it was not done before, but this PR IMO is a good time to do it! So keep this code, but add a warning saying we are automatically converting the cache from a tuple to a DynamicCache class (should only be triggered outside generate, because generate should already pass a cache class!).
Yep, it makes sense, but let's not be too brutal in case some people still use it! We give them one release until we totally remove it!
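A rough sketch of what that warning-plus-conversion could look like, wrapped in a hypothetical helper for illustration (wording and placement are not the final diff):

from typing import Optional, Tuple

from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import logging

logger = logging.get_logger(__name__)

def to_cache_with_warning(past_key_values, use_cache: bool) -> Tuple[Optional[Cache], bool]:
    """Convert a legacy tuple cache to DynamicCache, warning once (illustrative helper)."""
    return_legacy_cache = False
    if use_cache and past_key_values is not None and not isinstance(past_key_values, Cache):
        return_legacy_cache = True  # so the output can be converted back for BC
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        logger.warning_once(
            "Passing `past_key_values` as a tuple of tuples is deprecated; it was "
            "automatically converted to a `DynamicCache`. Please pass a `Cache` "
            "instance instead (generate() already does this for you)."
        )
    return past_key_values, return_legacy_cache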
Okay, added cache position and deprecation warnings to all models in this PR. I'll add the same deprecation warning to all models that already support the cache class in another PR. This one is ready for review; tests are passing on my end!
Nice!
- let's not warn when we are training (though I suspect a lot of people generate with a model that has model.training set...)
- let's add integration tests with compile
- let's add to the compile doc the list of supported models + performance boost
Maybe for another PR or for @ydshieh: let's add these models to the list of models we run benchmarks for
if attention_mask is not None:
    attention_mask = attention_mask[:, :, :, : key_layer.shape[-2]]
might be compile compatible?
Might be, but my main idea is to deprecate the old cache in this PR and check whether these changes enabled compile compatibility for those models. Tests for compile and other things can go in a second PR.
You mean this list (#28981), right? Or do we have it somewhere in the docs also?
@@ -64,6 +65,9 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
    _no_split_modules = ["GPTNeoXLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
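For illustration, with those flags set the different cache implementations can be requested through generate() roughly like this (a sketch; the checkpoint name is an arbitrary small GPT-NeoX model, and the quantized path assumes a quantization backend such as quanto is installed):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/pythia-70m"  # assumed small GPT-NeoX checkpoint, for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, my name is", return_tensors="pt")

# Default: DynamicCache, which grows along the sequence axis as tokens are generated.
out = model.generate(**inputs, max_new_tokens=10)

# Static cache: pre-allocated, fixed-shape key/value tensors (prerequisite for torch.compile).
out_static = model.generate(**inputs, max_new_tokens=10, cache_implementation="static")

# Quantized cache: keys/values stored in low precision to cut memory (needs the optional backend).
out_quant = model.generate(**inputs, max_new_tokens=10, cache_implementation="quantized")

print(tokenizer.decode(out[0], skip_special_tokens=True))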
Cool, this is automatically tested, but manual testing would be nice to make sure we have correct outputs!
You mean add slow integration tests for each model and cache type, or simply run some generations on my end to see that it makes sense?
Up to you! I think we can trust our tests / potentially make sure that they were run for all these models (we have one that tests cache and compile!)
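A minimal manual check along those lines might look roughly like this (a sketch, assuming the same small GPT-NeoX checkpoint as above; not the exact test in the suite):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/pythia-70m"  # assumed checkpoint, for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The cache refactor", return_tensors="pt")

# With a static cache the shapes are fixed, so the compiled graph is reused
# across decoding steps instead of recompiling every time.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
compiled = model.generate(**inputs, max_new_tokens=10, cache_implementation="static")

# Eager reference run for an output comparison (greedy decoding is deterministic).
eager_model = AutoModelForCausalLM.from_pretrained(model_id)
eager = eager_model.generate(**inputs, max_new_tokens=10, cache_implementation="static")
assert torch.equal(compiled, eager), "compiled and eager generations should match"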
Discovered just today that GPT-Neo and a few models are just like Gemma2, i.e. they have a hybrid cache format. It will take me some time to figure out how to make things work and still support the old format.
@ArthurZucker the only model that is like Gemma2 is GPT-Neo. I started working on adding HybridCache support for it but realized it would bloat up the code with if/elses to check which cache type we're using. So, I propose to review and merge this PR to deprecate cache-as-tuple. For GPT-Neo, I currently just toggled static-cache support to
@zucchini-nlp ping me when this gets merged -- with this PR merged, I can uniformize RoPE on even more models 🙌
Cool, will do one last pass then!
🚀 huge work and refactoring!
Mostly a few comments on the CodeGen model, also valid for other models:
- in generate, only cache positions should be used
- not creating the past is breaking, but we can break TBH
- for another PR maybe, but let's move the supports_quantized_cache to cache_utils with a set!
@@ -303,6 +313,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["CodeGenBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
BTW we could also just add the class to a set _supports_quantized_cache = {""} in cache_utils; then we don't pollute things here, and we can directly get all classes that support quantized / static etc.
-> better to auto-build the doc, and better in general to separate cache stuff from modeling-specific code.
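A rough sketch of that idea (hypothetical names and entries; this is not how the library is currently organized):

# cache_utils.py (hypothetical layout)
# One registry per cache feature, keyed by model_type, instead of per-model flags.
_SUPPORTS_QUANTIZED_CACHE = {"llama", "gpt_neox", "codegen"}   # illustrative entries
_SUPPORTS_STATIC_CACHE = {"llama", "gpt_neox"}

def supports_quantized_cache(model_type: str) -> bool:
    """Single source of truth that both generate() and the docs builder could query."""
    return model_type in _SUPPORTS_QUANTIZED_CACHE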
else:
    past_length = past_key_values[0][0].size(-2)
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
Ok, my "bad" here, we are kinda doing a breaking change, as before past_key_values would always be created if None!
Maybe the warning should be the only thing that checks not self.training!
Yes, can be done. Though I'm not sure how often users want to cache when training.
EDIT: I just got what you meant; will raise the warning message only if not training, but create the cache from legacy in all cases!
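Continuing the earlier conversion sketch, only the warning moves behind the training check while the conversion always happens (illustrative; assumes the same logger/Cache/DynamicCache imports as before and lives inside the model's forward):

if use_cache and not isinstance(past_key_values, Cache):
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    if not self.training:  # warn only at inference time, convert regardless
        logger.warning_once(
            "Tuple `past_key_values` is deprecated and was converted to `DynamicCache`; "
            "please pass a `Cache` instance instead."
        )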
Sounds good, but IMO it should be another PR because this one is finalizing deprecation for the few models left. I'll make a PR for auto-mapping-cache to see how it goes
Yep, another PR for sure!
Ran the tests and fixed things in a couple of places. Will merge later today/tomorrow. Btw @gante, seems like compile isn't working on current
@zucchini-nlp will have a look, thank you for flagging 👍
Transformers 4.45 has the DynamicCache updates removing self.bias, causing the failure. Until we investigate further and update the code based on the new transformers release, we are putting the bias back in GaudiGPTJAttention.__init__() huggingface/transformers#31421
What does this PR do?
This PR adds support for the dynamic cache to all decoder-only models (except for XLNet). I am working on general support for the cache class in all models, but I found that encoder-decoder models are trickier and we might want to make a tuple of DynamicCache for those. That's why I am splitting the task into multiple PRs for easier review :)
All models were tested with the GenerationTesterMixin tests and passed on my end. For the special Bloom case, I made a tiny change so that the cache holds the same tensor shapes in keys/values (we need to be able to concat at shape[-2]). Another special case, GIT, seems to work based on the tests.
For encoder-decoder models I will first try to add support for one model and make a PR. If the format is okay, I'll open another PR with all the remaining generative models.
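For readers unfamiliar with the two formats, the conversion between the legacy tuple cache and DynamicCache works like this (a small self-contained example with toy tensors):

import torch
from transformers.cache_utils import DynamicCache

# Legacy format: a tuple (one entry per layer) of (key, value) tensors with shape
# [batch, num_heads, seq_len, head_dim].
legacy = tuple(
    (torch.zeros(1, 4, 6, 8), torch.zeros(1, 4, 6, 8)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())      # 6

# New tokens are appended along the sequence axis (dim=-2), which is why the
# Bloom-specific change mentioned above keeps keys and values in matching shapes.
cache.update(torch.zeros(1, 4, 1, 8), torch.zeros(1, 4, 1, 8), layer_idx=0)
cache.update(torch.zeros(1, 4, 1, 8), torch.zeros(1, 4, 1, 8), layer_idx=1)
print(cache.get_seq_length())      # 7

# And back to the legacy format if older code still expects tuples.
legacy_again = cache.to_legacy_cache()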