[Assistant Generation] Improve Encoder Decoder #26701
Conversation
src/transformers/generation/utils.py
    assistant_model.max_assistant_tokens += 2.0
else:
    assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
# if n_matches == int(assistant_model.max_assistant_tokens):
The heuristic to increase/decrease the number of "look-ahead" tokens doesn't work well for Whisper. Can we maybe allow the user to disable it somehow, e.g. via a config attribute?
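A rough sketch of what such a config-controlled switch could look like around the hunk above; the attribute name `num_assistant_tokens_schedule` and the `"constant"` value are assumptions at this point in the discussion, and `assistant_model` / `n_matches` come from the surrounding function:

```python
# Hypothetical sketch: only run the grow/shrink heuristic when the (assumed)
# schedule attribute asks for it; "constant" would keep the look-ahead fixed.
schedule = getattr(assistant_model.generation_config, "num_assistant_tokens_schedule", "heuristic")
if schedule == "heuristic":
    if n_matches == int(assistant_model.max_assistant_tokens):
        assistant_model.max_assistant_tokens += 2.0
    else:
        assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
```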
@@ -4391,19 +4395,16 @@ def assisted_decoding(
    # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
    new_token_len = candidate_input_ids.shape[1] - prev_seq_len
    assist_inputs = candidate_input_ids[:, -new_token_len:]
    assist_attn = torch.ones_like(candidate_input_ids)
Do we really need this, @gante? Allocating new memory here every time leads to slowdowns that are not insignificant for Distil-Whisper.
Yes, this makes sense to remove, since it is the default attention mask! 👍
@@ -4484,18 +4485,18 @@ def assisted_decoding(
    # 2.2. Process the new logits
    new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
    if len(logits_processor) > 0:
-       for i in range(candidate_length):
+       for i in range(candidate_length + 1):
That was a bug previously: we forgot to apply the logits processors to the last logit here.
Good catch 👀
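For reference, a simplified sketch of the corrected loop; `cur_len` (the sequence length before the candidate tokens) is assumed from the surrounding function:

```python
# new_logits covers candidate_length + 1 positions: one logit per candidate
# token plus the logit for the extra token predicted by the large model,
# so the processors must run over all of them.
if len(logits_processor) > 0:
    for i in range(candidate_length + 1):
        new_logits[:, i, :] = logits_processor(
            candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]
        )
```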
    selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
    selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
This is unnecessary here as `new_logits` is already sliced.
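In other words, since `new_logits` was already sliced to the last `candidate_length + 1` positions when it was created, the greedy branch can drop the second slice, roughly:

```python
# new_logits is already outputs.logits[:, -candidate_length - 1 :], so the
# extra slice is a no-op and the argmax can be taken directly.
selected_tokens = new_logits.argmax(dim=-1)
```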
Fantastic, thank you for the upgrades @patrickvonplaten 🔥
Only added two minor, optional nits.
@@ -227,6 +227,20 @@ class GenerationConfig(PushToHubMixin):
    decoder_start_token_id (`int`, *optional*):
        If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.

    > Generation parameters exclusive to [assistant generation](https://arxiv.org/abs/2211.17192)

    max_assistant_tokens (`int`, *optional*, defaults to 5):
Perhaps we can take the chance to give a better name to this variable: `assistant_tokens` or similar. `max_assistant_tokens` implies that the assistant will never cross this limit but, as we can see in `max_assistant_tokens_schedule` (which should also be renamed accordingly), that is not true :)
Poor original naming choice by me :D
Sounds good!
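With the rename, user-facing usage would look roughly like the sketch below; `model`, `assistant_model`, and `inputs` are placeholders, and the attribute names follow the docstring shown later in this thread:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(
    num_assistant_tokens=5,                     # initial number of look-ahead tokens
    num_assistant_tokens_schedule="heuristic",  # let the value grow/shrink across steps
)
outputs = model.generate(
    **inputs,
    assistant_model=assistant_model,
    generation_config=generation_config,
)
```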
src/transformers/generation/utils.py
assistant_model.max_assistant_tokens = 5  # this value, which will be updated, persists across calls
if hasattr(assistant_model, "max_assistant_tokens"):
    warnings.warn(
        "Setting `max_assistant_tokens` via `assistant_model.max_assistant_tokens` is deprecated and will be removed in v5. Make sure to set `max_assistant_tokens` via the generation_config instead.",
Perhaps we can deprecate this earlier (like in v4.37)?
I haven't seen users fiddling with this internal variable :)
Ok for me!
The failing Hub test seems to be flaky. This PR is ready for a final review.
@@ -544,7 +544,11 @@ def forward(
    inputs_embeds = self.embed_tokens(input) * self.embed_scale

    if attention_mask is None:
        attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
This was a bug previously: the `attention_mask` length should equal the `inputs_embeds` length plus the `past_key_values` length.
Very clean, thanks for improving the performance!
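A sketch of the corrected default mask, assuming a `past_key_values_length` value is already computed earlier in `forward` (as in most decoder implementations):

```python
import torch

# The default mask must cover the cached tokens as well as the new inputs,
# otherwise it is too short whenever past_key_values is used.
if attention_mask is None:
    attention_mask = torch.ones(
        (inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length),
        dtype=torch.bool,
        device=inputs_embeds.device,
    )
```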
- `"_heuristic_`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2 else | ||
reduce by 1 |
Wondering whether the schedule parameters should be hard-coded, but it's fine for me.
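If the step sizes were not hard-coded, the heuristic could be factored out along these lines (a sketch; the helper name and parameters are hypothetical):

```python
def update_num_assistant_tokens(
    num_assistant_tokens: float,
    n_matches: int,
    candidate_length: int,
    increase_by: float = 2.0,
    decrease_by: float = 1.0,
) -> float:
    """Grow the look-ahead when every speculative token was accepted, shrink it otherwise."""
    if n_matches == candidate_length:
        return num_assistant_tokens + increase_by
    return max(1.0, num_assistant_tokens - decrease_by)
```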
* [Assistant Generation] Improve enc dec
* save more
* Fix logit processor checks
* Clean
* make style
* fix deprecation
* fix generation test
* Apply suggestions from code review
* fix biogpt
* make style
What does this PR do?
This PR speeds up assistant generation / speculative decoding for encoder-decoder models such as Distil-Whisper by ~20-30%.
Improvements:
* Use `assistant_encoder_outputs` so that the inputs are not encoded twice (gives ~20% speed-up)
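A rough sketch of the idea behind `assistant_encoder_outputs` (the helper below is hypothetical; only the kwarg name comes from the PR): each encoder runs exactly once, and its outputs are reused for every speculation step instead of re-encoding the inputs each time.

```python
import torch

def precompute_encoder_outputs(model, assistant_model, input_features):
    # Run both encoders a single time; the results are reused for every
    # candidate-generation step, so the inputs are never encoded twice.
    with torch.no_grad():
        encoder_outputs = model.get_encoder()(input_features)
        assistant_encoder_outputs = assistant_model.get_encoder()(input_features)
    return encoder_outputs, assistant_encoder_outputs
```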