
[Assistant Generation] Improve Encoder Decoder #26701

Merged: 13 commits merged into main from improve_assistant_generation_enc_dec on Oct 11, 2023

Conversation

@patrickvonplaten (Contributor) commented Oct 9, 2023

What does this PR do?

This PR speeds up assistant generation / speculative decoding for encoder-decoder models such as Distil-Whisper by roughly 20-30%.

Improvements:

  • If the assistant and the main model share the same encoder, allow the user to pass `assistant_encoder_outputs` so that the inputs are not encoded twice (gives a ~20% speed-up); see the sketch after this list.
  • In the small loop we don't need to allocate new attention-mask tensors on every iteration; the model builds the default mask itself when necessary (gives a ~3-4% speed-up).
  • The heuristic to increase / decrease the number of "look-ahead" tokens doesn't work well for Whisper; can we maybe allow the user to disable it, e.g. via a config attribute?
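
A rough usage sketch of the first point. The `assistant_encoder_outputs` keyword follows this PR's description; the checkpoint ids and the exact `generate()` call are illustrative assumptions rather than the released API:

```python
import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration

processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
# Distil-Whisper reuses the teacher's encoder, so the encoder outputs can be shared.
assistant = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v2")

# Dummy one-second input; replace with real audio.
inputs = processor(torch.zeros(16_000).numpy(), sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    # Encode the audio once with the shared encoder ...
    encoder_outputs = model.get_encoder()(inputs.input_features)

# ... and hand the result to both the main model and the assistant, so the input
# is not run through the encoder a second time (keyword assumed from this PR).
generated_ids = model.generate(
    encoder_outputs=encoder_outputs,
    assistant_model=assistant,
    assistant_encoder_outputs=encoder_outputs,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```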

@patrickvonplaten marked this pull request as draft October 9, 2023 17:53
if n_matches == int(assistant_model.max_assistant_tokens):
    assistant_model.max_assistant_tokens += 2.0
else:
    assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
patrickvonplaten (Contributor, Author):

The heuristic to increase / decrease the number of "look-ahead" tokens doesn't work well for Whisper; can we maybe allow the user to disable it, e.g. via a config attribute?

@@ -4391,19 +4395,16 @@ def assisted_decoding(
# `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
new_token_len = candidate_input_ids.shape[1] - prev_seq_len
assist_inputs = candidate_input_ids[:, -new_token_len:]
assist_attn = torch.ones_like(candidate_input_ids)
patrickvonplaten (Contributor, Author):

Do we really need this, @gante? Allocating new memory here every time leads to slowdowns that are not insignificant for Distil-Whisper.

gante (Member):

Yes, this makes sense to remove, since it is the default attention mask! 👍
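
A small self-contained check of that point; the GPT-2 checkpoint is used purely for illustration, and the claim only concerns the unpadded case:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Without padding, omitting `attention_mask` is equivalent to passing the explicit
# all-ones mask that `torch.ones_like(...)` would have allocated, because the model
# builds that default mask internally.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("speculative decoding", return_tensors="pt").input_ids

with torch.no_grad():
    logits_default = model(input_ids).logits
    logits_explicit = model(input_ids, attention_mask=torch.ones_like(input_ids)).logits

print(torch.allclose(logits_default, logits_explicit))  # expected: True
```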

@HuggingFaceDocBuilderDev commented Oct 9, 2023

The documentation is not available anymore as the PR was closed or merged.

@@ -4484,18 +4485,18 @@ def assisted_decoding(
# 2.2. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
if len(logits_processor) > 0:
-    for i in range(candidate_length):
+    for i in range(candidate_length + 1):
patrickvonplaten (Contributor, Author):

This was a bug previously: we forgot to apply the logits processors to the last logit here.
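
A toy shape sketch of the off-by-one, with made-up dimensions; the variable names mirror the diff above:

```python
import torch

# The main model returns logits for the `candidate_length` drafted tokens plus one
# extra position, so `new_logits` has `candidate_length + 1` steps and the logits
# processors must visit all of them.
batch_size, seq_len, vocab_size, candidate_length = 1, 20, 32, 5
outputs_logits = torch.randn(batch_size, seq_len, vocab_size)

new_logits = outputs_logits[:, -candidate_length - 1 :]  # same slicing as in the diff
assert new_logits.shape[1] == candidate_length + 1       # hence `range(candidate_length + 1)`
```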

gante (Member):

Good catch 👀

    selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
    selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
patrickvonplaten (Contributor, Author):

This is unnecessary here, as `new_logits` is already sliced to the last `candidate_length + 1` positions.

@patrickvonplaten marked this pull request as ready for review October 10, 2023 12:20
@gante (Member) left a comment:

Fantastic, thank you for the upgrades @patrickvonplaten 🔥

Only added two minor, optional nits.

@@ -227,6 +227,20 @@ class GenerationConfig(PushToHubMixin):
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.

> Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)

max_assistant_tokens (`int`, *optional*, defaults to 5):
@gante (Member) commented Oct 10, 2023:

Perhaps we can take the chance to give a better name to this variable: assistant_tokens or similar. max_assistant_tokens implies that the assistant will never cross this limit but, as we can see in max_assistant_tokens_schedule (which should also be renamed accordingly), that is not true :)

Poor original naming choice by me :D

patrickvonplaten (Contributor, Author):

Sounds good!

-    assistant_model.max_assistant_tokens = 5  # this value, which will be updated, persists across calls
+    if hasattr(assistant_model, "max_assistant_tokens"):
+        warnings.warn(
+            "Setting `max_assistant_tokens` via `assistant_model.max_assistant_tokens` is deprecated and will be removed in v5. Make sure to set `max_assistant_tokens` via the generation_config instead.",
gante (Member):

Perhaps we can deprecate this earlier (like in v4.37)?

I haven't seen users fiddling with this internal variable :)

patrickvonplaten (Contributor, Author):

Ok for me!
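
A minimal sketch of the configuration path the deprecation warning points to, assuming the renamed fields discussed in this thread (`num_assistant_tokens`, `num_assistant_tokens_schedule`) are the ones that ship:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(
    num_assistant_tokens=5,                     # renamed from `max_assistant_tokens`
    num_assistant_tokens_schedule="heuristic",  # the +2 / -1 schedule shown earlier
)

# At generation time the config is passed alongside the assistant, e.g.:
# model.generate(inputs, assistant_model=assistant, generation_config=generation_config)
```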


@patrickvonplaten changed the title from "[Assistant Generation] Improve enc dec" to "[Assistant Generation] Improve Encoder Decoder" on Oct 11, 2023
@patrickvonplaten (Contributor, Author) commented:

The failing Hub test seems to be flaky.

This PR is ready for a final review.

@@ -544,7 +544,11 @@ def forward(
inputs_embeds = self.embed_tokens(input) * self.embed_scale

if attention_mask is None:
    attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
patrickvonplaten (Contributor, Author):

This was a bug previously: the attention_mask length should equal the inputs_embeds length plus the past_key_values length.
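
A shape-only sketch of what the fix means, with made-up sizes; no model is involved:

```python
import torch

# With cached key/values, the default attention mask has to cover the cached
# positions as well, not just the new `inputs_embeds` positions.
batch_size, new_tokens, past_length, hidden_size = 2, 1, 7, 16
inputs_embeds = torch.randn(batch_size, new_tokens, hidden_size)

# Previous default: only as long as the new tokens.
old_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool)

# Fixed default: covers past_key_values length + new tokens.
new_mask = torch.ones(batch_size, past_length + new_tokens, dtype=torch.bool)

print(old_mask.shape, new_mask.shape)  # torch.Size([2, 1]) torch.Size([2, 8])
```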

@ArthurZucker (Collaborator) left a comment:

Very clean, thanks for improving the performance!

Comment on lines +240 to +241
- `"_heuristic_"`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2, else reduce by 1
ArthurZucker (Collaborator):

Wondering if the schedule parameters should be hard-coded, but this is fine for me.

@patrickvonplaten merged commit da69de1 into main on Oct 11, 2023
19 of 21 checks passed
@patrickvonplaten deleted the improve_assistant_generation_enc_dec branch October 11, 2023 13:52
helboukkouri pushed a commit to helboukkouri/transformers that referenced this pull request Oct 16, 2023
* [Assistant Generation] Improve enc dec

* save more

* Fix logit processor checks

* Clean

* make style

* fix deprecation

* fix generation test

* Apply suggestions from code review

* fix biogpt

* make style
blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 18, 2023