
In assisted decoding, pass model_kwargs to model's forward call (fix prepare_input_for_generation in all models) #25242

Conversation

sinking-point
Contributor

What does this PR do?

Previously, assisted decoding would ignore any additional kwargs that it didn't explicitly handle. This was inconsistent with other generation methods, which pass the model_kwargs through prepare_inputs_for_generation and forward the returned dict to the model's forward call.

The prepare_inputs_for_generation method needs to be amended in all models, as previously it kept only the last input ID when past_key_values was passed.

Fixes #25020
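The amended slicing can be sketched roughly as follows (a minimal, hedged illustration: plain Python lists and an integer `past_length` stand in for real tensors and the KV cache, and the names are not the exact transformers implementation):

```python
# Illustrative sketch only -- not the actual transformers code.

def prepare_inputs_for_generation(input_ids, past_length=0, **model_kwargs):
    if past_length > 0:
        if len(input_ids) > past_length:
            # Keep every position not yet covered by the cache. During
            # assisted decoding this can be several candidate tokens,
            # not just one.
            input_ids = input_ids[past_length:]
        else:
            # Cache already covers everything: keep the last token.
            input_ids = input_ids[-1:]
    return {"input_ids": input_ids, **model_kwargs}
```

With ordinary one-token-per-step decoding (`past_length == len(input_ids) - 1`) this reduces to the old keep-the-last-token behaviour, which is why the change is backward compatible.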

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante

@gante
Member

gante commented Aug 3, 2023

@sinking-point the PR has "WIP" in the title -- is it still under development, or is it ready to review?

@sinking-point
Contributor Author

Not ready yet. Still have to fix more models and see what's breaking the other test. I've deprioritised this somewhat as it's quite time consuming, but I'll keep chipping away at it whenever I can.

If you need this done quickly, you're welcome to help - lmk and I'll add you as a collaborator on my branch.

@gante
Member

gante commented Aug 3, 2023

Not urgent -- simply double-checking whether it was in need of a review or not :)

@sinking-point sinking-point force-pushed the sinking-point/assisted-decoding-model-kwargs-fix-all-models-25020 branch 9 times, most recently from bffb27b to bcad9c7 Compare September 6, 2023 10:49
@sinking-point sinking-point changed the title WIP In assisted decoding, pass model_kwargs to model's forward call (fix prepare_input_for_generation in all models) In assisted decoding, pass model_kwargs to model's forward call (fix prepare_input_for_generation in all models) Sep 6, 2023
@sinking-point
Contributor Author

@gante This should be ready for review now. Thanks in advance.

@sinking-point sinking-point force-pushed the sinking-point/assisted-decoding-model-kwargs-fix-all-models-25020 branch from bcad9c7 to a41bf7c Compare September 6, 2023 11:22
@gante left a comment (Member)

@sinking-point this is a great piece of work, thank you so much for working on this 🙏

For me, it's a green light on the logic of the PR 🟢 I've added a few comments to improve on the readability of the changes (and one performance-related comment), so that we can fulfill our role as maintainers on top of your work at super-human levels 🤗

Review threads (resolved):
  • src/transformers/generation/utils.py
  • src/transformers/models/bark/modeling_bark.py
@gante
Member

gante commented Sep 13, 2023

@ArthurZucker @LysandreJik This PR is a big one and touches a core piece of logic for all generative models, so I'm tagging two core maintainers.

Context

Advanced generation techniques (like assisted generation or medusa) may generate more than one token per model forward pass. The original implementation of assisted generation had a lot of custom code, as it breaks one of the assumptions in the models' prepare_inputs_for_generation -- that only one token is generated per forward pass.

Solution

@sinking-point has kindly put forward a proposal to get rid of the custom code in assisted generation. After iterating with me, the plan was to remove the assumption of one token per forward pass in prepare_inputs_for_generation -- the slicing therein should be based on how many tokens do not yet have corresponding past KV values, rather than simply taking the last set of inputs. This is the change this PR implements, along with the removal of some custom assisted generation code. Needless to say, it is fully backwards compatible 😉
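The resulting control flow can be illustrated with a toy sketch (all names hypothetical, not the transformers internals): model_kwargs travel through prepare_inputs_for_generation, and the slice it returns covers every token without a cache entry, so assisted decoding needs no special-case code.

```python
class ToyModel:
    """Stand-in for a real model; its forward just echoes what it received."""

    def prepare_inputs_for_generation(self, input_ids, past_length=0, **model_kwargs):
        # Keep every position not yet covered by the KV cache -- possibly
        # several tokens when an assistant model proposed candidates.
        if past_length > 0:
            if len(input_ids) > past_length:
                input_ids = input_ids[past_length:]
            else:
                input_ids = input_ids[-1:]
        return {"input_ids": input_ids, "past_length": past_length, **model_kwargs}

    def forward(self, **inputs):
        return inputs


def assisted_decoding_step(model, input_ids, candidate_ids, model_kwargs):
    # Sequence so far plus the assistant's speculative tokens.
    extended = input_ids + candidate_ids
    # All model_kwargs go through prepare_inputs_for_generation, and the
    # returned dict is forwarded wholesale -- nothing is silently dropped.
    model_inputs = model.prepare_inputs_for_generation(extended, **model_kwargs)
    return model.forward(**model_inputs)
```

Here a kwarg like token_type_ids reaches the forward call untouched, which is exactly the behaviour the pre-PR assisted decoding lacked.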

Postface

To reiterate: this PR gets the green light from me in terms of logic 🟢, and it is a big contribution by @sinking-point. This PR is also important to future-proof our generative techniques -- we will be ready for new types of multiple-token-per-forward-pass strategies as a result of this PR.

I'll be off the next few weeks, but I'm sure this PR will get a quick resolution 🤗

@sinking-point
Contributor Author

Thanks @gante , I'll take a look at your comments tomorrow 👍

@Rocketknight1
Member

Hi @sinking-point! Sorry for the delay - I'm taking over this PR from @gante because he's out on a well-deserved rest right now. Is everything ready for review, or are there any other issues you want to discuss with the team before we take a final look at it?

@sinking-point
Contributor Author

No worries @Rocketknight1 . Thanks for taking this on.

There's one discussion gante opened that I haven't resolved. Could you give your input on this? #25242 (comment)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@Rocketknight1
Member

@sinking-point Replied to the last remaining issue up above!

@sinking-point
Contributor Author

Thanks @Rocketknight1 , I'll take a look on Monday

@sinking-point sinking-point force-pushed the sinking-point/assisted-decoding-model-kwargs-fix-all-models-25020 branch from 72a343d to 8d1fbc2 Compare October 2, 2023 10:39
@sinking-point
Contributor Author

Some random tests started failing so I rebased onto main where they're fixed, but it looks like I have some more work to do now.

@Rocketknight1
Member

Ick, yeah. I'm not sure what's causing those test failures, but if you can't figure it out, let me know and I'll dive in!

@sinking-point
Contributor Author

Should be ready to merge if you're happy with it. Thanks!

@sinking-point
Contributor Author

Looks like doc tests passed @Rocketknight1 , so as you said let's make this a priority before any more models are added.

@Rocketknight1
Member

Understood! It's quite a big PR since it touches so many models, but I'll try to get an internal review in the next few days.

@sinking-point
Contributor Author

Could you block new generative model merges until this PR is merged?

Alternatively, could you require that new generative models' prepare_inputs_for_generation method follow this PR? That is, instead of assuming that past_key_values, when provided, covers all but the last position, calculate how many positions remain after past_key_values and keep those.
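A hedged sketch of what such a template for new models might look like (illustrative names only; lists and nested tuples stand in for tensors and the KV cache, and the position_ids handling mirrors the falcon-style update mentioned in this PR's commits; see the merged code for the authoritative pattern):

```python
# Illustrative template -- not the actual transformers code.

def prepare_inputs_for_generation(input_ids, past_key_values=None,
                                  attention_mask=None, **kwargs):
    if past_key_values is not None:
        # Cache length = number of positions already processed.
        past_length = len(past_key_values[0][0])
        if len(input_ids) > past_length:
            # May keep several tokens, not just the last one.
            input_ids = input_ids[past_length:]
        else:
            input_ids = input_ids[-1:]
    position_ids = None
    if attention_mask is not None:
        # Build position_ids only for the slice being fed to the model.
        total = sum(attention_mask)
        position_ids = list(range(total - len(input_ids), total))
    return {"input_ids": input_ids, "position_ids": position_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask, **kwargs}
```

The key invariant is that the slice length is derived from the cache rather than hard-coded to 1, so multi-token strategies like assisted generation work without model-specific patches.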

@sinking-point
Contributor Author

Hi @Rocketknight1 , any update on this?

@gante
Member

gante commented Oct 10, 2023

Hey @sinking-point 👋

I'm back from holidays and I'll be doing a quick final check now. Assuming the check comes out positive, we'll tag a core maintainer to greenlight the merge.

Our apologies for the slow process, it should be quick now 🤗

@gante left a comment (Member)

LGTM, thank you for iterating 🙏

@gante
Member

gante commented Oct 10, 2023

@sinking-point regarding the failing test: rebasing the PR should fix it, the bug was fixed last week :)

@gante
Member

gante commented Oct 10, 2023

ping @LysandreJik -- this PR should be ready to be merged after it is rebased. Please read this comment for context :)

Previously, assisted decoding would ignore any additional kwargs
that it doesn't explicitly handle. This was inconsistent with other
generation methods, which pass the model_kwargs through
prepare_inputs_for_generation and forward the returned dict to the
model's forward call.

The prepare_inputs_for_generation method needs to be amended in all
models, as previously it only kept the last input ID when a past_key_values
was passed.
@sinking-point sinking-point force-pushed the sinking-point/assisted-decoding-model-kwargs-fix-all-models-25020 branch from cbf75a3 to 8ce040d Compare October 11, 2023 09:41
@sinking-point
Contributor Author

Thanks @gante :)

@LysandreJik
Member

This seems OK to me, but I'd like to ask @patrickvonplaten for his opinion and eventual approval, given his experience maintaining this part of the code.

@patrickvonplaten left a comment (Contributor)

Very nice! The code is cleaned up and made more extensible; very much in favor of this change.

@patrickvonplaten patrickvonplaten merged commit dcc49d8 into huggingface:main Oct 11, 2023
21 checks passed
@sinking-point sinking-point deleted the sinking-point/assisted-decoding-model-kwargs-fix-all-models-25020 branch October 11, 2023 13:11
@sinking-point
Contributor Author

Amazing, thank you @LysandreJik and @patrickvonplaten

@gante
Member

gante commented Oct 11, 2023

And thank you @sinking-point for this big contribution 🔥 💛

helboukkouri pushed a commit to helboukkouri/transformers that referenced this pull request Oct 16, 2023
…prepare_input_for_generation in all models) (huggingface#25242)

* In assisted decoding, pass model_kwargs to model's forward call

Previously, assisted decoding would ignore any additional kwargs
that it doesn't explicitly handle. This was inconsistent with other
generation methods, which pass the model_kwargs through
prepare_inputs_for_generation and forward the returned dict to the
model's forward call.

The prepare_inputs_for_generation method needs to be amended in all
models, as previously it only kept the last input ID when a past_key_values
was passed.

* Improve variable names in _extend_attention_mask

* Refactor extending token_type_ids into a function

* Replace deepcopy with copy to optimize performance

* Update new persimmon model with llama changes for assisted generation

* Update new mistral model for assisted generation with prepare_inputs_for_generation

* Update position_ids creation in falcon prepare_inputs_for_generation to support assisted generation
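The "Replace deepcopy with copy" bullet above relies on a general Python distinction, sketched here with illustrative values (the real model_kwargs hold tensors, where the saving actually matters):

```python
import copy

kwargs = {"attention_mask": [1, 1, 1], "use_cache": True}

# copy.copy duplicates only the outer dict; nested values are shared,
# so it is cheap even when the values are large tensors.
shallow = copy.copy(kwargs)
assert shallow is not kwargs
assert shallow["attention_mask"] is kwargs["attention_mask"]

# copy.deepcopy also duplicates every nested value -- wasteful when the
# caller never mutates those values in place.
deep = copy.deepcopy(kwargs)
assert deep["attention_mask"] is not kwargs["attention_mask"]
```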
blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
…prepare_input_for_generation in all models) (huggingface#25242)
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 18, 2023
…prepare_input_for_generation in all models) (huggingface#25242)
Successfully merging this pull request may close these issues.

GenerationMixin: model_kwargs not passed to model in assisted decoding
6 participants