
Adding support for prompt lookup decoding (variant of assisted generation) #27722

Closed

apoorvumang opened this issue Nov 27, 2023 · 24 comments

@apoorvumang (Contributor)

Feature request

A recently proposed method, prompt lookup decoding, replaces the draft model used in assisted generation with simple string matching against the prompt.

Code: https://github.com/apoorvumang/prompt-lookup-decoding
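
For readers skimming the thread, the core lookup is small enough to sketch inline. This is a rough reimplementation assuming input_ids of shape (1, seq_len); the notebook's actual find_candidate_pred_tokens differs in details:

```python
import torch

def find_candidate_pred_tokens(input_ids, max_matching_ngram_size=3, num_pred_tokens=10):
    """Look for the latest n-gram of the sequence earlier in the prompt; if found,
    return the tokens that followed it as draft candidates."""
    seq_len = input_ids.shape[1]
    for ngram_size in range(max_matching_ngram_size, 0, -1):
        ngram = input_ids[0, -ngram_size:].tolist()
        # Scan the prompt (most recent occurrence first) for an earlier copy of this n-gram
        for start in range(seq_len - ngram_size - 1, -1, -1):
            if input_ids[0, start:start + ngram_size].tolist() == ngram:
                end = start + ngram_size
                return input_ids[0, end:end + num_pred_tokens]
    # No match found: return no candidates and fall back to regular decoding
    return torch.empty(0, dtype=torch.long, device=input_ids.device)
```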

Motivation

  • The method gives significant speedups (2x-4x) on input-grounded tasks
  • Applicable to all decoder models, and it supports sampling
  • Easy to implement: we can modify assisted generation so that it also accepts a candidate-generating function as the assistant (rather than an LLM)

Your contribution

I have a not-so-well-written implementation here (Python notebook). I can contribute to making it better, but I will need help, since it's my first time contributing.

@apoorvumang (Contributor Author)

Tagging @gante, since you have recently worked on lookahead decoding.

@gante (Member) commented Nov 27, 2023

Hi @apoorvumang 👋

First of all, thank you for creating this clever strategy and for sharing it openly! It's simple and elegant, which makes it really great.

I've been thinking about it from a software point of view. The core functionality (find_candidate_pred_tokens) is simple, and it reuses the core of assisted_generation. I'm also seeing more techniques that speed up LLMs through the generation of candidate sequences. As such, here's my proposal:

  1. I'll open a PR today to refactor the contents of assisted_generation into a generalist decoding technique that accepts an arbitrary function to generate candidates. assisted_generation would become a variant of this function, as would your technique.
  2. In parallel, you can work on adding your technique on top of the generalist decoding-with-candidates technique:
    a. You'll have to define the controlling parameters of your technique in the GenerationConfig class, defaulting to None
    b. When the parameters above are non-None, your technique would get triggered in generate, following the same pattern as assisted generation
    c. Once the code is in a nearly ready state, we'll run some benchmarks over different tasks and share them on social media

Does it sound good to you? 🤗

(LMK if you'd like further pointers)
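
For concreteness, the generalist structure described above might look something like this. This is an illustrative sketch only; the actual interface landed in #27750 and differs in details:

```python
from abc import ABC, abstractmethod
import torch

class CandidateGenerator(ABC):
    """Anything that can propose candidate continuations to be verified by the main model."""

    @abstractmethod
    def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
        """Return candidate token ids for the main model to check in one forward pass."""

    def update_candidate_strategy(self, num_matches: int) -> None:
        """Optionally adapt internal parameters based on how many candidates were accepted."""

class AssistedCandidateGenerator(CandidateGenerator):
    """Candidates come from a smaller assistant model (classic assisted generation)."""

    def __init__(self, assistant_model, num_assistant_tokens=5):
        self.assistant_model = assistant_model
        self.num_assistant_tokens = num_assistant_tokens

    def get_candidates(self, input_ids):
        out = self.assistant_model.generate(
            input_ids, max_new_tokens=self.num_assistant_tokens, do_sample=False
        )
        # Keep only the newly drafted tokens (1-D, matching the prompt-lookup sketch)
        return out[0, input_ids.shape[1]:]

class PromptLookupCandidateGenerator(CandidateGenerator):
    """Candidates come from n-gram matching within the prompt (this issue's technique)."""

    def __init__(self, num_output_tokens=10, max_matching_ngram_size=3):
        self.num_output_tokens = num_output_tokens
        self.max_matching_ngram_size = max_matching_ngram_size

    def get_candidates(self, input_ids):
        # find_candidate_pred_tokens is sketched in the issue description above
        return find_candidate_pred_tokens(
            input_ids, self.max_matching_ngram_size, self.num_output_tokens
        )
```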

@apoorvumang (Contributor Author)

Sounds good! I will try to read up and implement it

Agree on the assisted_generation refactoring as well - maybe we could even have a user-provided assistant_function? (But that's a software decision I'm not qualified to make.)

@gante (Member) commented Nov 28, 2023

@apoorvumang #27750 has the new assisted_decoding structure. It is still subject to review, but you can now have a concrete idea of what I had in mind :)

Adding your technique on top of it should be straightforward!

@apoorvumang (Contributor Author)

Thanks @gante! Will look into the refactored code now. I think I should be able to get something running by tonight (IST).

@apoorvumang (Contributor Author)

I have made a working implementation here, based on #27750: https://github.com/apoorvumang/transformers/tree/prompt_lookup_decoding . Should I start a PR with it?

@apoorvumang (Contributor Author)

Also, if you can suggest any benchmarks/benchmarking code, I can help with that. I have access to an A100 40GB GPU and an M1 Max 32GB. @gante

@gante (Member) commented Nov 30, 2023

@apoorvumang Yes, open a PR! I can add a few suggestions even before #27750 is merged :)

My advice for benchmarks would be the following: users love it when a method works well with little to no hyperparameter tuning. At the moment, I see two hyperparameters -- prompt_lookup_num_tokens and prompt_lookup_max_matching_ngram. I'd run a few benchmarks over a few datasets, varying these hyperparameters, to find out whether we can:
a) get away with only one hyperparameter, OR
b) set an update heuristic that finds the best hyperparameters for the input at hand (through the update_candidate_strategy method)

If you find a way to make a) or b) work, the technique becomes more user-friendly, and thus more likely to be used correctly. For us transformers maintainers, having fewer flags is also great!

After we settle on a final implementation, I can validate the benchmarks on different devices (e.g. a T4, a 3090, ...). Given the simplicity of the technique, I suspect the results will be mostly hardware agnostic on GPU :)
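
To illustrate option b), a heuristic of this kind could live in update_candidate_strategy on the PromptLookupCandidateGenerator sketched above -- the numbers and rules below are hypothetical, not what was benchmarked or merged:

```python
def update_candidate_strategy(self, num_matches: int) -> None:
    """Hypothetical self-tuning rule for a prompt-lookup candidate generator."""
    # If the main model accepted every proposed token, propose more next time;
    # if it rejected most of them, propose fewer and save wasted verification work.
    if num_matches == self.num_output_tokens:
        self.num_output_tokens = min(self.num_output_tokens + 2, 20)
    elif num_matches < self.num_output_tokens // 2:
        self.num_output_tokens = max(self.num_output_tokens - 2, 2)
```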

@apoorvumang (Contributor Author)

Started PR here: https://github.com/huggingface/transformers/pull/27775/commits . Please do leave suggestions @gante

I will start some benchmarking on my side to find optimal hyperparameters (or update schedules). Maybe both of these can best be handled with a default value plus an update schedule, and if users really want to change the default values they can instantiate and provide a PromptLookupCandidateGenerator with new params.

Will get back once I start some tests. I will try some standard summarization and QA datasets, and maybe look for a code-editing sort of dataset.
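
For reference, the user-facing call being aimed at would look roughly like this. The checkpoint name is a placeholder and the flag name is the one proposed in the PR; the final API may differ:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")

document = "..."  # some long input the answer should be grounded in
inputs = tokenizer(document + "\n\nSummarize the text above.", return_tensors="pt").to(model.device)

# Passing the prompt-lookup flag would switch generate() to n-gram lookup in the prompt
# for drafting candidates; omitting it keeps regular decoding.
outputs = model.generate(**inputs, max_new_tokens=256, prompt_lookup_num_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```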

@apoorvumang (Contributor Author)

[Benchmark chart: greedy vs. sampling with PLD on a summarization task]

There is a significant difference between greedy and sampling when summarizing, but there are still gains. A proper analysis of the phenomenon would probably be a paper-worthy effort.

I will try to run a similar thing for code editing as well. If you think there's something else I could try, please let me know.

One question @gante: is the most popular method greedy or sampling? (I would assume greedy, since it's the default, but I know sampling gives better quality.) If I could optimize for only one of these, which one should be the 'default'?

@0xdevalias

If I could optimize for only one of these, which one should be the 'default'?

Naive question/input here, but assuming you can figure out the optimisations and they don't apply equally to both, would it be possible to have two settings for it? One when used with greedy and one when used with sampling? Even if that's handled automagically under the hood (and even if it's exposed to users, it would presumably be simpler than having to know the exact hyperparameters to tune).

@apoorvumang (Contributor Author)

Thanks! Yes, it can of course - _get_candidate_generator has access to generation_config, which can be passed on here to check for things like this.

Any other thoughts/ideas @0xdevalias ?
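
Concretely, the generator could be constructed with different defaults depending on the decoding mode. A hypothetical sketch (not the actual transformers internals; PromptLookupCandidateGenerator as sketched earlier in the thread):

```python
def _get_candidate_generator(generation_config):
    # Hypothetical: pick prompt-lookup defaults based on whether sampling is enabled.
    if getattr(generation_config, "prompt_lookup_num_tokens", None) is not None:
        if generation_config.do_sample:
            # Assumption: sampled verification accepts fewer drafted tokens, so draft shorter runs
            return PromptLookupCandidateGenerator(num_output_tokens=5, max_matching_ngram_size=2)
        return PromptLookupCandidateGenerator(num_output_tokens=10, max_matching_ngram_size=3)
    return None  # fall back to assisted generation or regular decoding
```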

@0xdevalias commented Dec 1, 2023

Thanks! Yes, it can of course

@apoorvumang Awesome :)

Any other thoughts/ideas?

@apoorvumang None at this stage; was more of a 'drive by random brain spark' type moment :)

@gante (Member) commented Dec 1, 2023

@apoorvumang @0xdevalias the preliminary results seem to indicate that there is no obvious parameterization 🤔 Let's wait and see the results for coding!

Regarding sampling vs greedy: greedy is the default for legacy reasons, but sampling is by far the most popular with chat LLMs :) Tasks like summarization, translation, or automatic speech recognition tend to use greedy decoding or beam search, though.

Finally, regarding default values: we'll have to default the values to None so that we can detect whether the user wants to use the technique or not. We have a few default values for legacy reasons, but defaults should be set at the model level (with generation_config.json). This does not prevent us, however, from suggesting values in the parameters' docstring 🤗
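
As a concrete picture of model-level defaults, a suggested value could live on the model's generation config and be saved alongside the checkpoint. This is illustrative only; the checkpoint name is a placeholder and the parameter name is the one proposed in this thread:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # placeholder model

# Store a suggested default on the model's generation config instead of hard-coding it
# in the library; it then travels with the checkpoint via generation_config.json.
model.generation_config.prompt_lookup_num_tokens = 10
model.generation_config.save_pretrained("my-checkpoint-dir")
```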

@apoorvumang (Contributor Author)

Here are results using mt-bench, 2nd-turn code prompts only.
[Benchmark chart]

@apoorvumang (Contributor Author)

All 80 samples from mt-bench, 2nd turn only.
[Benchmark chart]

@keyboardAnt commented Dec 7, 2023

All 80 samples from mt-bench, 2nd turn only. [Benchmark chart]

Hi @apoorvumang – Thanks for sharing your great work!

Two quick questions:

  1. What temperature did you use in "Sampling baseline" and "Sampling PLD"?
  2. How should we interpret the black-colored lines that go below 0? (What is their minimal tokens per second rate?)

@gante (Member) commented Dec 8, 2023

@keyboardAnt the error bars are usually the standard deviation of the measurement, which is a centered (and symmetric) moment -- they do not denote the minimum/maximum of a measurement, nor a range between percentiles.

As such, I'm reading it as a long-tailed distribution: some speedups are huge (e.g. 5x), while most are moderate (e.g. 1.5x).

@apoorvumang (Contributor Author)

Hi @keyboardAnt , thank you!

  1. Default temperature, so probably 1.0
  2. As @gante said, the black-coloured lines are the standard deviation, not the min or max. I didn't save the exact data for these, so I can't share that, but where the bars dip below 0, it's probably because of very high variance in speedups (1x to 10x).

Here's an example of this phenomenon, courtesy of ChatGPT:
[Illustrative chart: a long-tailed distribution where the mean minus one standard deviation falls below zero]
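
A quick numeric check of the same effect, with made-up speedup values for illustration: with a long right tail, mean minus standard deviation can fall below zero even though no individual measurement is ever negative.

```python
import statistics

# Made-up speedups: mostly moderate, with a couple of huge outliers (long right tail)
speedups = [1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.8, 2.0, 9.0, 12.0]

mean = statistics.mean(speedups)   # 3.29
std = statistics.stdev(speedups)   # ~3.87 (sample standard deviation)

# The lower error bar (mean - std) is ~ -0.58, below zero, even though
# every individual measurement is >= 1.1x.
print(mean - std, min(speedups))
```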

PS: Sorry for the delay in working on this PR - I will try to work on it this weekend

@keyboardAnt

@gante, @apoorvumang, yes. Because of the high variance, we had better look at the minimal tokens/sec rate as well. That would confirm the long tail is one-sided; otherwise, the data might be hiding a slowdown.

@apoorvumang (Contributor Author)

@keyboardAnt Could you please expand on what you mean? Do you mean we should look for configs with a good lower bound on tokens/sec rather than a good average?

@keyboardAnt commented Dec 10, 2023

@apoorvumang, my suggestion is to measure speedup. That is

speedup := (tokens per second with PLD) / (tokens per second without PLD)

where with-PLD and without-PLD share the same variables (e.g., prompt, target model, GPU device). We want to show that speedup >> 1 in most cases, and to rule out the possibility that speedup < 1 (i.e., a slowdown). The visualizations you shared do not rule out the possibility that speedup < 1.

We must measure speedup in varied configurations so we can better understand it. Each configuration differs in its prompt, target model, or (max_matching_ngram, num_token_output) hyperparameters. Visualizing the distribution of speedup and calculating its harmonic mean would help.
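
A paired measurement along these lines might look as follows. This is a sketch; the timing helper and the prompt_lookup_num_tokens flag are assumptions built on the PR discussed above, not existing benchmark code:

```python
import time
import statistics

def tokens_per_second(model, tokenizer, prompt, **generate_kwargs):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    start = time.perf_counter()
    out = model.generate(**inputs, max_new_tokens=256, **generate_kwargs)
    elapsed = time.perf_counter() - start
    return (out.shape[1] - inputs["input_ids"].shape[1]) / elapsed

def speedup(model, tokenizer, prompt):
    # Same prompt, model, and device for both runs; only the PLD flag differs.
    with_pld = tokens_per_second(model, tokenizer, prompt, prompt_lookup_num_tokens=10)
    without_pld = tokens_per_second(model, tokenizer, prompt)
    return with_pld / without_pld

# Across many prompts/configurations, inspect the whole distribution, not just a mean:
# speedups = [speedup(model, tokenizer, p) for p in prompts]
# print(min(speedups), statistics.harmonic_mean(speedups))
```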

github-actions bot commented Jan 4, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@keyboardAnt commented Jun 12, 2024

@apoorvumang, my suggestion is to measure speedup. That is

speedup := (tokens per second with PLD) / (tokens per second without PLD)

where with-PLD and without-PLD share the same variables (e.g., prompt, target model, GPU device). We want to show that speedup >> 1 in most cases, and to rule out the possibility that speedup < 1 (i.e., a slowdown). The visualizations you shared do not rule out the possibility that speedup < 1.

We must measure speedup in varied configurations so we can better understand it.

We recently released a preprint that also covers the question of slowdowns: https://arxiv.org/pdf/2405.14105

Our experiments show that slowdowns exist in practice (for example, if PLD is too slow or inaccurate). We also propose a novel algorithm for running PLD (or any other drafters) in parallel so that there are no slowdowns.
