Add different token sampling algorithms to decoder. #123
Conversation
bvrockwell commented Jun 12, 2024 (edited)
- Added sampling methods to the engine.py decoder (greedy, weighted, nucleus, topk); see the sketch below.
- Added the corresponding configuration options to the different launch methods.
- Rolled the JetStream submodule forward to main.
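For readers skimming the PR, here is a minimal sketch of what these four sampling modes typically look like in JAX. The function name and signature are illustrative only, not the exact code added to engine.py:

```python
import jax
import jax.numpy as jnp


def sample_tokens(logits, rng, algorithm, topk=1, nucleus_topp=0.0, temperature=1.0):
  """Samples token ids from [batch, vocab] logits; names are illustrative."""
  if algorithm == "greedy":
    # Deterministic: always take the highest-probability token.
    return jnp.argmax(logits, axis=-1)
  if algorithm == "weighted":
    # Sample from the full temperature-scaled distribution.
    return jax.random.categorical(rng, logits / temperature, axis=-1)
  if algorithm == "topk":
    # Mask everything below the k-th largest logit, then sample.
    topk_vals, _ = jax.lax.top_k(logits, topk)
    masked = jnp.where(logits < topk_vals[..., -1, None], -jnp.inf, logits)
    return jax.random.categorical(rng, masked / temperature, axis=-1)
  if algorithm == "nucleus":
    # Keep the smallest prefix of tokens whose cumulative probability
    # reaches nucleus_topp; mask out the tail, then sample.
    sorted_logits = jnp.sort(logits, axis=-1)[..., ::-1]
    probs = jax.nn.softmax(sorted_logits / temperature, axis=-1)
    cutoff_index = jnp.sum(jnp.cumsum(probs, axis=-1) < nucleus_topp,
                           axis=-1, keepdims=True)
    cutoff_logit = jnp.take_along_axis(sorted_logits, cutoff_index, axis=-1)
    masked = jnp.where(logits < cutoff_logit, -jnp.inf, logits)
    return jax.random.categorical(rng, masked / temperature, axis=-1)
  raise ValueError(f"Unknown sampling algorithm: {algorithm}")
```

A call like `sample_tokens(logits, jax.random.PRNGKey(0), "topk", topk=10, temperature=0.5)` mirrors one of the benchmark configurations reported below.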
Force-pushed from 134f77a to b692f77
Force-pushed from b692f77 to 8d672f4
Force-pushed from 32609da to 409d118
Thanks for adding the different token sampling methods! Can you start the PyTorch engine server and share the inference results with the different sampling algorithms?
jetstream_pt/sampling_utils.py (outdated)
We don't need to create a duplicate class if sampling_utils.py is the same as JetStream's. JetStream is one of the PyTorch engine's dependencies.
Ah ok, I wasn't sure whether to keep the pin at the previous JetStream commit or roll up to the most recent main. 13 files have changed since then (including the addition of sampling_utils.py, so I changed the import to point to jetstream.engine instead). I'll run the tests with the different sampling algorithms too, thanks!
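Concretely, the fix amounts to importing the helper from the JetStream dependency rather than keeping a local copy; the exact module path below is an assumption based on the comment above:

```python
# Before: a duplicated local copy.
# from jetstream_pt import sampling_utils

# After: reuse the helper shipped with the JetStream dependency
# (module path assumed from the discussion above).
from jetstream.engine import sampling_utils
```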
```diff
@@ -220,7 +222,14 @@ def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
     if len(logits.shape) == 2:
       logits = jnp.expand_dims(logits, 0)
     return (
-        jnp.argmax(logits[:, -1], axis=-1)
+        sampling_utils.sampling(
```
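Since the hunk is truncated, here is a hedged reconstruction of what the completed method plausibly looks like after this change; the RNG attribute, the config field names, and the trailing reshape/astype are assumptions inferred from the `return (` structure and the PR description, not verbatim code:

```python
def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
  if len(logits.shape) == 2:
    logits = jnp.expand_dims(logits, 0)
  return (
      sampling_utils.sampling(
          logits[:, -1],                # logits at the last decoded position
          self.rng,                     # assumed: RNG key held by the engine
          self.env.sampling_algorithm,  # assumed config field names
          self.env.topk,
          self.env.nucleus_topp,
          self.env.temperature,
      )
      .reshape(batch_size, -1)
      .astype(jnp.int32)
  )
```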
Can you add a unit test for the sampling?
Added tests to test_engine.py
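For reference, a minimal sketch of what such a test could look like; the import path and the `sampling(...)` signature are assumptions carried over from the discussion above, and the actual tests live in test_engine.py:

```python
import unittest

import jax
import jax.numpy as jnp

# Assumed import path; the helper is presumed to live in the
# JetStream dependency as discussed above.
from jetstream.engine.sampling_utils import sampling


class SamplingTest(unittest.TestCase):

  def test_greedy_matches_argmax(self):
    logits = jnp.array([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])
    rng = jax.random.PRNGKey(0)
    tokens = sampling(logits, rng, "greedy", 1, 0.0, 1.0)
    self.assertTrue(jnp.array_equal(tokens, jnp.argmax(logits, axis=-1)))

  def test_topk_stays_within_top_tokens(self):
    logits = jnp.array([[0.0, 5.0, 4.0, -2.0]])
    rng = jax.random.PRNGKey(0)
    # With k=2, only token ids 1 and 2 should ever be sampled.
    tokens = sampling(logits, rng, "topk", 2, 0.0, 1.0)
    self.assertIn(int(tokens[0]), (1, 2))


if __name__ == "__main__":
  unittest.main()
```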
Thanks for adding it, looks good to me!
Here are the inference results comparing main and the proposed changes:
call to benchmark
| Accelerator | Branch | Sampling | Settings | Successful requests |
| --- | --- | --- | --- | --- |
| v5e-1 | main | greedy | llama2-7b, quantized int8_per_channel, batch=1, max_cache_length=16 | 130 |
| v5e-1 | add-decoder-temperature | greedy | same settings | 130 |
| v5e-1 | add-decoder-temperature | weighted | temperature=10, same settings | 130 |
| v5e-1 | add-decoder-temperature | topk | topk=10, temperature=0.5, same settings | 130 |
| v5e-1 | add-decoder-temperature | nucleus | nucleus_topp=0.8, temperature=0.5, same settings | 130 |
| v5e-4 | main | greedy | same settings | 130 |
| v5e-4 | add-decoder-temperature | greedy | same settings | 130 |
| v5e-4 | add-decoder-temperature | weighted | temperature=10, same settings | 130 |