-
Notifications
You must be signed in to change notification settings - Fork 26.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generate: assisted_decoding
now accepts arbitrary candidate generators
#27750
Conversation
self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0) | ||
|
||
|
||
def _crop_past_key_values(model, past_key_values, maximum_length): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: these functions were moved here to avoid circular imports
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might be good to use the new cache format no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
soon! (we still need to maintain retrocompatibility)
@@ -888,6 +895,29 @@ def _reorder_cache(self, past_key_values, beam_idx): | |||
f" enable beam search for {self.__class__}" | |||
) | |||
|
|||
def _get_candidate_generator( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function will be expanded as we add more CandidateGenerator
:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so will the logic here be something like
- check params in
generation_config
(someif else
condition) - based on params, set
candidate_generator
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@apoorvumang exactly!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got it, so I'll write a PromptLookupCandidateGenerator
that implements CandidateGenerator
, and then wire it up in this function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the plan is to have similar checks to the ones we have for the supported logits processor I guess?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ArthurZucker precisely, we will have flags to control which candidate generation strategies we have in place. I suspect that, because some candidate generation strategies are so cheap (like the one proposed in #27722), assisted generation may become mainstream!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
One more place needs to change I think -
how do u suggest this should change to support prompt lookup decoding? @gante |
@apoorvumang I'd add an |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🔥 would maybe rename the file with candiate_generators
like we have logits_processor
but otherwise great!
@@ -888,6 +895,29 @@ def _reorder_cache(self, past_key_values, beam_idx): | |||
f" enable beam search for {self.__class__}" | |||
) | |||
|
|||
def _get_candidate_generator( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the plan is to have similar checks to the ones we have for the supported logits processor I guess?
self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0) | ||
|
||
|
||
def _crop_past_key_values(model, past_key_values, maximum_length): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
might be good to use the new cache format no?
77a8b67
to
a57367b
Compare
…ors (huggingface#27750) Co-authored-by: Arthur <[email protected]>
…ors (huggingface#27750) Co-authored-by: Arthur <[email protected]>
What does this PR do?
A common trend is starting to pop up: people are experimenting with new strategies to generate candidate sequences, to then run an assisted-generation-like strategy. A key example is the new technique in #27722, which is equal to
assisted_decoding
except for the candidate generation part. This technique in particular achieves nice speedups in some settings, and doesn't need an assistant model -- a model-free speedup!To facilitate experimentation and the addition of new candidate generation techniques, this PR abstracts the candidate generation part in
assisted_decoding
to a new class with a stable API. This was inspired in classes likeLogitsProcessor
orStoppingCriteria
-- components ofgenerate
that can easily be replaced. All these changes are backwards compatible! 🤗Suggested review order:
utils.py
, to see the shape ofassisted_decoding
under the abstracted APIcandidate.py
, to see the structure of the new base class (and the specific case of the original assisted generation)The following tests are passing:
RUN_SLOW=1 py.test tests/models/whisper/ -k speculative
py.test tests/ -k test_assisted
(which catches mixin and integration tests associated with assisted generation)Happy to add more tests if needed :)