Refactor generation sampling parameters (e.g. top k, temperature) into "Sampling" classes #5420
Conversation
This looks like a huge improvement from a code readability and extensibility perspective! My only concern is performance: this CI failure suggests that generation is slowed down. The failing test checks (very indirectly) how long a very small bart variant takes to `.generate` on small batches.
From an accuracy perspective, we have some slow integration tests to make sure generation quality doesn't regress. (These can be prefixed with `USE_CUDA=1` if you are on GPU / want them to run faster.) You should do one run of all the `@slow` tests using `RUN_SLOW=1 pytest tests/`. The ones most likely to break are:
- `RUN_SLOW=1 pytest tests/test_modeling_bart.py`
- `RUN_SLOW=1 pytest tests/test_modeling_t5.py`
- `RUN_SLOW=1 pytest tests/test_modeling_marian.py`
@sshleifer thanks for taking a look. The runs against the tests you mentioned (bart/t5/marian) passed when I gave them a kick. On performance: this approach should use the same amount of compute, since each enabled Sampler runs once per generation loop -- it is just moving code around, unless I missed something. Let me do a rebase and see if that CI failure goes away -- let me know if you have any other concerns!
```python
    batch_size=batch_size,
    num_beams=num_beams,
)
if sampler:
```
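The `if sampler:` branch quoted above sits inside the generation loop. A toy sketch of how such an optional hook might be applied to the next-token scores (the names and signatures here are hypothetical, not the PR's exact code):

```python
def generation_step(next_token_scores, generated_so_far, sampler=None):
    """One step of a toy generation loop: the optional sampler may
    transform the scores before a token is picked (hypothetical names,
    not the PR's exact code)."""
    scores = next_token_scores
    if sampler:
        # Each enabled sampler runs once per generation step.
        scores = sampler(generated_so_far, scores)
    # Greedy pick, just for illustration.
    return max(range(len(scores)), key=lambda i: scores[i])

# Example sampler: veto token 0 unconditionally.
def veto_token_0(past, scores):
    return [float("-inf") if i == 0 else s for i, s in enumerate(scores)]

picked = generation_step([2.0, 1.5, 0.1], generated_so_far=[4, 2],
                         sampler=veto_token_0)
# With the veto in place, token 1 wins instead of token 0.
```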
I think here we will not be able to keep backwards compatibility with beam_search + sampling, because `top_k_top_p_filtering` is applied after the beam scores are added. From a logical point of view, I do think it makes more sense to apply `top_k_top_p_filtering` after adding the beam scores. On the other hand, beam search sampling is not used that much and is definitely an edge case...
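To make the ordering question concrete, here is a small illustration (not the transformers implementation; `top_k_filter` below is a toy stand-in for `top_k_top_p_filtering`, and all numbers are made up). Filtering each beam's distribution before adding beam scores keeps k candidates alive per beam, while filtering after adding beam scores prunes jointly across beams:

```python
NEG_INF = float("-inf")

def top_k_filter(scores, k):
    """Keep the k highest entries and mask the rest to -inf (toy
    stand-in for transformers' top_k_top_p_filtering, whose real
    API operates on tensors and differs)."""
    threshold = sorted(scores, reverse=True)[k - 1]
    return [s if s >= threshold else NEG_INF for s in scores]

# Two beams over a vocab of 3: per-beam running scores and
# next-token log-probs (made-up numbers).
beam_scores = [0.0, -2.0]
log_probs = [[-0.1, -1.0, -3.0],   # beam 0
             [-0.2, -0.3, -2.5]]   # beam 1

# Ordering A: filter each beam's distribution first, then add beam scores.
ordering_a = []
for lp, bs in zip(log_probs, beam_scores):
    for s in top_k_filter(lp, k=2):
        ordering_a.append(s + bs)   # -inf + bs stays -inf

# Ordering B: add beam scores first, then filter jointly across beams.
combined = [s + bs for lp, bs in zip(log_probs, beam_scores) for s in lp]
ordering_b = top_k_filter(combined, k=2)

alive_a = sum(s != NEG_INF for s in ordering_a)  # 2 candidates per beam
alive_b = sum(s != NEG_INF for s in ordering_b)  # only beam 0 survives
```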
IIUC the proposal would be: get the raw logits, normalize, add the beam scores, and then sample from the transformed distribution? That makes sense to me; it seems like a design decision how these probability warps should interact with beam search. Is it covered in any literature?
Yeah, I would think the distribution should be transformed after the beam scores have been added, though I don't know of any literature here. I'm not too concerned about beam search + sampling, but I'm not sure whether we would also restrict "greedy" beam search this way for future use cases. @yjernite @srush - do you maybe have more insight here?
@turtlesoupy - thanks a lot for the PR! Cool design choice! I guess a method that adapts the
@patrickvonplaten I'm un-opinionated since my use cases weren't using beam search; the goal of this PR was to let me introduce my own sampler that enforces rules without having to fork the generate function. For beam search, one approach could be to apply the warp to (
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
#4164 has a full description of the intention here. Basically, to avoid exploding generate(...) with more arguments, I've added one generic Sampler parameter that allows for arbitrary transformations of the generation probability distribution conditioned on the past. This allows users to specify custom ways of sampling (e.g. insert a specific token after a previous one, etc.)
In the process, I've added some basic tests around these samplers; existing tests pass otherwise.
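The generic Sampler parameter described above could be sketched roughly like this, including the "insert a specific token after a previous one" use case; the class and method names here are illustrative guesses, not the PR's actual API:

```python
import math

class Sampler:
    """Hypothetical base class: warp the next-token scores,
    conditioned on the tokens generated so far (illustrative only)."""
    def warp(self, past_tokens, scores):
        raise NotImplementedError

class ForceFollowSampler(Sampler):
    """Whenever `trigger` was just generated, force `forced` to be
    the next token by masking everything else to -inf."""
    def __init__(self, trigger, forced):
        self.trigger = trigger
        self.forced = forced

    def warp(self, past_tokens, scores):
        if past_tokens and past_tokens[-1] == self.trigger:
            return [0.0 if i == self.forced else -math.inf
                    for i in range(len(scores))]
        return scores

# The last generated token is 7, so only token 3 stays reachable.
sampler = ForceFollowSampler(trigger=7, forced=3)
warped = sampler.warp(past_tokens=[5, 7], scores=[-0.5, -1.2, -0.1, -2.0])
```

Chaining several such objects, each transforming the distribution in turn, is what lets `generate(...)` stay free of per-strategy keyword arguments.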