In assisted decoding, pass model_kwargs to model's forward call (fix prepare_input_for_generation in all models) #25242
Conversation
@sinking-point the PR has "WIP" in the title -- is it still under development, or is it ready to review?
Not ready yet. Still have to fix more models and see what's breaking the other test. I've deprioritised this somewhat as it's quite time consuming, but I'll keep chipping away at it whenever I can. If you need this done quickly, you're welcome to help - lmk and I'll add you as a collaborator on my branch.
Not urgent -- simply double-checking whether it was in need of a review or not :)
Force-pushed from bffb27b to bcad9c7.
@gante This should be ready for review now. Thanks in advance.
Force-pushed from bcad9c7 to a41bf7c.
@sinking-point this is a great piece of work, thank you so much for working on this 🙏
For me, it's a green light on the logic of the PR 🟢 I've added a few comments to improve on the readability of the changes (and one performance-related comment), so that we can fulfill our role as maintainers on top of your work at super-human levels 🤗
@ArthurZucker @LysandreJik This PR is a big one and touches a core piece of logic for all generative models, so I'm tagging 2 core maintainers.
Context: Advanced generation techniques (like assisted generation or medusa) may generate more than one token per model forward pass. The original implementation of assisted generation had a lot of custom code, as it breaks one of the assumptions baked into the models' prepare_inputs_for_generation.
Solution: @sinking-point has kindly put forward a proposal to get rid of the custom code in assisted generation. After iterating with me, the plan was to remove the assumption of one token per forward pass.
Postface: To reiterate: this PR gets the green light from me in terms of logic 🟢, and it is a big contribution by @sinking-point. This PR is also important to future-proof our generative techniques -- we will be ready for new types of multiple-token-per-forward-pass strategies as a result of this PR. I'll be off the next few weeks, but I'm sure this PR will get a quick resolution 🤗
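The assumption being discussed can be pictured with a minimal sketch (hypothetical helper names and toy flat-list inputs, not the actual transformers implementation): classic decoding feeds exactly one new token per forward pass once a KV cache exists, while assisted decoding appends several candidate tokens at once.

```python
# Minimal sketch (hypothetical, not the real transformers code) of the
# one-token-per-forward-pass assumption that assisted generation breaks.
# input_ids is a flat list of token IDs for a single sequence.

def prepare_inputs_old(input_ids, past_key_values=None):
    # Old behaviour: when a KV cache exists, keep ONLY the last input ID.
    # This silently assumes exactly one new token was added since the last pass.
    if past_key_values is not None:
        input_ids = input_ids[-1:]
    return {"input_ids": input_ids, "past_key_values": past_key_values}

def prepare_inputs_new(input_ids, past_key_values=None):
    # New behaviour: keep every token the cache has not seen yet, so assisted
    # decoding can feed several candidate tokens in a single forward pass.
    cache_length = len(past_key_values) if past_key_values is not None else 0
    return {"input_ids": input_ids[cache_length:], "past_key_values": past_key_values}

cached = [1, 2, 3]       # tokens already covered by the KV cache
candidates = [4, 5, 6]   # 3 draft tokens proposed by the assistant model
old = prepare_inputs_old(cached + candidates, past_key_values=cached)
new = prepare_inputs_new(cached + candidates, past_key_values=cached)
print(old["input_ids"])  # [6]        -> two candidate tokens are lost
print(new["input_ids"])  # [4, 5, 6]  -> all candidates reach forward()
```

Under the old slicing, the model would only ever score the final candidate, which is why assisted generation previously needed custom code around this hook.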
Thanks @gante, I'll take a look at your comments tomorrow 👍
Hi @sinking-point! Sorry for the delay - I'm taking over this PR from @gante because he's out on a well-deserved rest right now. Is everything ready for review, or are there any other issues you want to discuss with the team before we take a final look at it?
No worries @Rocketknight1. Thanks for taking this on. There's one discussion gante opened that I haven't resolved. Could you give your input on this? #25242 (comment)
@sinking-point Replied to the last remaining issue up above!
Thanks @Rocketknight1, I'll take a look on Monday
Force-pushed from 72a343d to 8d1fbc2.
Some random tests started failing, so I rebased onto main where they're fixed, but it looks like I have some more work to do now.
Ick, yeah. I'm not sure what's causing those test failures, but if you can't figure it out, let me know and I'll dive in!
Should be ready to merge if you're happy with it. Thanks!
Looks like the doc tests passed @Rocketknight1, so as you said, let's make this a priority before any more models are added.
Understood! It's quite a big PR since it touches so many models, but I'll try to get an internal review in the next few days.
Alternatively, could you require that new generative models' |
Hi @Rocketknight1, any update on this?
Hey @sinking-point 👋 I'm back from holidays and I'll be doing a quick final check now. Assuming the check comes out positive, we'll tag a core maintainer to greenlight the merge. Our apologies for the slow process, it should be quick now 🤗
LGTM, thank you for iterating 🙏
@sinking-point regarding the failing test: rebasing the PR should fix it, the bug was fixed last week :)
ping @LysandreJik -- this PR should be ready to be merged after it is rebased. Please read this comment for context :)
Force-pushed from cbf75a3 to 8ce040d.
Thanks @gante :)
This seems ok to me, but I'd like to ask @patrickvonplaten for his opinion and eventual approval, given his experience maintaining this part of the code.
Very nice! Code is cleaned up and made more extendable - very much in favor of this change.
Amazing, thank you @LysandreJik and @patrickvonplaten
And thank you @sinking-point for this big contribution 🔥 💛
…prepare_input_for_generation in all models) (huggingface#25242)
* In assisted decoding, pass model_kwargs to model's forward call. Previously, assisted decoding would ignore any additional kwargs that it doesn't explicitly handle. This was inconsistent with other generation methods, which pass the model_kwargs through prepare_inputs_for_generation and forward the returned dict to the model's forward call. The prepare_inputs_for_generation method needs to be amended in all models, as previously it only kept the last input ID when a past_key_values was passed.
* Improve variable names in _extend_attention_mask
* Refactor extending token_type_ids into a function
* Replace deepcopy with copy to optimize performance
* Update new persimmon model with llama changes for assisted generation
* Update new mistral model for assisted generation with prepare_inputs_for_generation
* Update position_ids creation in falcon prepare_inputs_for_generation to support assisted generation
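The mask-extension commits above can be illustrated with a small sketch (assumed helper names and toy flat-list tensors, not the library's exact helpers): when the assistant model proposes several candidate tokens, the attention mask and token_type_ids must grow by the same amount so every new position is attended to.

```python
# Illustrative sketch (assumed, not the exact transformers helpers): growing
# per-sequence attention_mask and token_type_ids lists to cover newly
# appended candidate tokens from the assistant model.

def extend_attention_mask(attention_mask, num_new_tokens):
    # Append one attended position (1) per new candidate token.
    return attention_mask + [1] * num_new_tokens

def extend_token_type_ids(token_type_ids, num_new_tokens):
    # Repeat the final segment id for the newly appended tokens.
    final_type = token_type_ids[-1]
    return token_type_ids + [final_type] * num_new_tokens

mask = extend_attention_mask([1, 1, 1], 3)    # -> [1, 1, 1, 1, 1, 1]
types = extend_token_type_ids([0, 0, 1], 3)   # -> [0, 0, 1, 1, 1, 1]
```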
What does this PR do?
Previously, assisted decoding would ignore any additional kwargs that it doesn't explicitly handle. This was inconsistent with other generation methods, which pass the model_kwargs through prepare_inputs_for_generation and forward the returned dict to the model's forward call.
The prepare_inputs_for_generation method needs to be amended in all models, as previously it only kept the last input ID when a past_key_values was passed.
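The kwargs-forwarding fix described above can be sketched as follows (hypothetical function names and toy inputs; the real generation loop in transformers is far more involved): routing assisted decoding through the same per-model hook the other generation methods use means custom kwargs such as token_type_ids are no longer dropped.

```python
# Hedged sketch (hypothetical names, not the actual transformers code) of the
# behavioural change: assisted decoding now builds forward() inputs via the
# same hook as the other generation methods instead of a hand-built subset.

def assisted_forward_inputs_old(input_ids, attention_mask=None, **model_kwargs):
    # Before: assisted decoding assembled forward() inputs by hand and
    # silently dropped any model_kwargs it didn't explicitly handle.
    return {"input_ids": input_ids, "attention_mask": attention_mask}

def prepare_inputs_for_generation(input_ids, **model_kwargs):
    # Stand-in for the per-model hook: returns a dict containing every
    # kwarg the model's forward() call needs.
    return {"input_ids": input_ids, **model_kwargs}

def assisted_forward_inputs_new(input_ids, **model_kwargs):
    # After: assisted decoding routes through the same hook, so extra kwargs
    # (e.g. token_type_ids) reach forward() just like in greedy/beam/sample.
    return prepare_inputs_for_generation(input_ids, **model_kwargs)

kwargs = {"attention_mask": [1, 1, 1], "token_type_ids": [0, 0, 0]}
old = assisted_forward_inputs_old([1, 2, 3], **kwargs)
new = assisted_forward_inputs_new([1, 2, 3], **kwargs)
print("token_type_ids" in old)  # False -> the kwarg was silently lost
print("token_type_ids" in new)  # True  -> the kwarg reaches forward()
```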
Fixes #25020
Who can review?
@gante