
Cache: Bart and related architectures support Cache objects #28065

Closed
wants to merge 11 commits

Conversation

gante (Member) commented Dec 15, 2023

What does this PR do?

This PR applies the changes to Bart so it supports the new Cache objects. In other words, it is akin to #26681, but for encoder-decoder models.

⚠️ This is a giant PR that can't be separated due to our copy mechanism (🙃), but the review process doesn't need to be daunting. Here's my suggested review order and high-level rationale:

  1. Changes in cache_utils.py. I've introduced DynamicCacheWithCrossAttention, which expands DynamicCache [the cache object equivalent to the previous past_key_values input/output] with the ability to hold a cross-attention cache. This design was intentional: most LLMs (and now even multimodal models) tend to be decoder-only, so this separation keeps the cache class for decoder-only models simpler. It also enables us to be more strict -- I've caught an unintended cache deletion in Whisper thanks to the increased specificity! (A rough sketch of the idea follows this list.)
  2. Changes in modeling_bart.py. These changes are the equivalent of the modeling changes in Generate: New Cache abstraction and Attention Sinks support #26681, but for encoder-decoder models.
  3. Other changes, which can be reviewed more lightly. They are either related documentation fixes, minor corrections, propagation of Bart's changes through make fix-copies (plus a few manual changes like adding imports or updating docstrings), or test upgrades for the new DynamicCacheWithCrossAttention.
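
For reviewers who want a mental model before diving into cache_utils.py, here is a rough, hypothetical sketch of what such a subclass could look like. The class name, method names, and signatures below are assumptions for illustration, not the PR's actual API:

```python
from typing import List, Tuple

import torch
from transformers.cache_utils import DynamicCache


class CrossAttentionCacheSketch(DynamicCache):
    """DynamicCache (self-attention) plus a separate store for cross-attention K/V (sketch only)."""

    def __init__(self) -> None:
        super().__init__()
        # Cross-attention keys/values are derived from the encoder output once and reused at
        # every decoding step, so they never grow like the self-attention cache does.
        self.cross_attention_key_cache: List[torch.Tensor] = []
        self.cross_attention_value_cache: List[torch.Tensor] = []

    def update_cross_attention(
        self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Store the states only the first time a layer is seen; afterwards, reuse them.
        if len(self.cross_attention_key_cache) <= layer_idx:
            self.cross_attention_key_cache.append(key_states)
            self.cross_attention_value_cache.append(value_states)
        return self.cross_attention_key_cache[layer_idx], self.cross_attention_value_cache[layer_idx]
```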

The following tests were run locally (including FA2 and some pretty challenging tests) to ensure nothing was broken in the process:

  • RUN_SLOW=1 py.test tests/models/bart/test_modeling_bart.py -vv
  • RUN_SLOW=1 py.test tests/models/mbart/test_modeling_mbart.py -vv
  • RUN_SLOW=1 py.test tests/models/whisper/test_modeling_whisper.py -vv

👉 In any case, we should run the slow CI before merging!

Note on Whisper: same failures as in `main`.

[Screenshot: Whisper test results, 2023-12-15 at 14:59]

gante renamed the PR several times on Dec 15, 2023, settling on "Cache: Bart and related architectures support Cache objects".
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

gante (Member, Author) commented Dec 15, 2023

@amyeroberts this PR is not finalized, but I'd love to get an early review -- the failing tests are fixed by propagating the changes to models with the # Copied from statement. However, it's not a copy/paste job, so any changes you request could be painful to propagate to the remaining models 😬

The key parts to review now are labeled as 1 and 2 in the PR header 🤗

amyeroberts (Collaborator) left a comment


Impressive piece of work 🔥

I've only looked closely at the addition in cache_utils and the changes in BART. Just some nits and questions on my side for understanding, but overall I think the structure looks great! It would be good to get a second set of eyes from someone with more cache experience on this too.

`past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or
`config.use_cache=True`.

Two formats are allowed:

Is passing inputs in the legacy format discouraged? If both are allowed, then we should update the type hint to cover both; if the legacy format is deprecated, I'd reword this, as we don't want to encourage passing in the old format.
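
For reference, the two formats can be illustrated roughly as follows. This is a sketch covering decoder self-attention only; the legacy encoder-decoder tuples also carry cross-attention entries, which are omitted here for brevity:

```python
import torch
from transformers.cache_utils import DynamicCache

batch, heads, seq_len, head_dim = 1, 8, 5, 64

# Legacy format: a tuple with one entry per decoder layer, each holding (key, value)
# tensors of shape (batch, heads, seq_len, head_dim).
legacy = tuple(
    (
        torch.randn(batch, heads, seq_len, head_dim),
        torch.randn(batch, heads, seq_len, head_dim),
    )
    for _ in range(2)  # two decoder layers
)

# Cache-object format: the same data wrapped in a DynamicCache.
cache = DynamicCache.from_legacy_cache(legacy)
assert cache.to_legacy_cache()[0][0].shape == legacy[0][0].shape
```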

Comment on lines +200 to +202
is_cross_attention = key_value_states is not None

if is_cross_attention:

nit - this variable is only used once, on the immediately following line - the comment provides enough context

Suggested change
- is_cross_attention = key_value_states is not None
- if is_cross_attention:
+ if key_value_states is not None:


# Keep only the unprocessed tokens:
# 1 - If the length of the decoder_attention_mask exceeds the length of decoder_input_ids, then we are in a
# setting where some of the inputs are exclusivelly passed as part of the cache (e.g. when passing

ultranit

Suggested change
- # setting where some of the inputs are exclusivelly passed as part of the cache (e.g. when passing
+ # setting where some of the inputs are exclusively passed as part of the cache (e.g. when passing

# Keep only the unprocessed tokens:
# 1 - If the length of the decoder_attention_mask exceeds the length of decoder_input_ids, then we are in a
# setting where some of the inputs are exclusivelly passed as part of the cache (e.g. when passing
# input_embeds as input)

Just for my own understanding, am I right in thinking the reason they're exclusively part of the cache if I pass input_embeds is because any input_ids must have been generated?

# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# decoder_input_ids based on the past_length.
elif past_length < decoder_input_ids.shape[1]:
    decoder_input_ids = decoder_input_ids[:, past_length:]

And in this case we're removing tokens that have already been seen, i.e. that have been processed and are part of the cache?
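
As a toy illustration of that branch (shapes made up for the example):

```python
import torch

past_length = 4                                    # tokens already stored in the cache
decoder_input_ids = torch.arange(6).unsqueeze(0)   # shape (1, 6): 4 cached + 2 new tokens

# Branch 2: decoder_input_ids holds every token, so drop the already-processed prefix.
if past_length < decoder_input_ids.shape[1]:
    decoder_input_ids = decoder_input_ids[:, past_length:]

print(decoder_input_ids)  # tensor([[4, 5]]) -> only the 2 unprocessed tokens are fed to the model
```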

Comment on lines +2282 to +2303
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
    input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
    input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
    max_cache_length is not None
    and attention_mask is not None
    and cache_length + input_ids.shape[1] > max_cache_length
):
    attention_mask = attention_mask[:, -max_cache_length:]

By eye, this looks equivalent to the logic above, just with input_ids instead of decoder_ids -> can we abstract out the common logic here?
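
One possible shape for such a shared helper, as a hypothetical sketch (the function name and signature are invented here, not code from the PR); both prepare_inputs_for_generation paths could then call it with their respective input_ids / decoder_input_ids:

```python
import torch


def _trim_unprocessed_tokens(input_ids, attention_mask, past_length, cache_length, max_cache_length):
    """Keep only the tokens that are not yet in the cache; optionally crop the attention mask."""
    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
        # 1 - Some inputs live only in the cache (e.g. inputs_embeds were passed as input).
        input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
    elif past_length < input_ids.shape[1]:
        # 2 - input_ids holds all tokens: discard the already-processed prefix.
        input_ids = input_ids[:, past_length:]
    # 3 - Otherwise, assume input_ids only contains unprocessed tokens.

    # Crop the attention mask if we are about to exceed the maximum cache length.
    if (
        max_cache_length is not None
        and attention_mask is not None
        and cache_length + input_ids.shape[1] > max_cache_length
    ):
        attention_mask = attention_mask[:, -max_cache_length:]
    return input_ids, attention_mask
```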


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

gante (Member, Author) commented Jan 16, 2024

Mr bot, this is not stale (on hold while the static cache is being worked on, as they will likely have overlapping changes and the static cache is more important)

gante added the WIP label Jan 16, 2024
gante (Member, Author) commented Apr 18, 2024

Closing this PR; at this point it's easier to start from scratch.

gante closed this Apr 18, 2024
Labels: WIP
3 participants