Add different token sampling algorithms to decoder. #123

Merged: 12 commits merged into main from add-decoder-temperature on Jun 14, 2024

Conversation

@bvrockwell (Collaborator) commented Jun 12, 2024

  • Added sampling methods (greedy, weighted, nucleus, topk) to the engine.py decoder; a rough sketch of these strategies follows below.
  • Added configurations to the different launch methods.
  • Rolled up the JetStream submodule to main.
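For readers skimming the thread, here is a minimal, self-contained JAX sketch of what the four strategies do. It is illustrative only: the function names, defaults, and exact masking details below are assumptions for this sketch, not the code added to engine.py or JetStream's sampling_utils.

```python
# Illustrative sketch of the four sampling strategies added in this PR.
# Not the merged implementation: names and defaults are assumptions.
import jax
import jax.numpy as jnp


def greedy(logits):
  # Deterministically pick the highest-probability token.
  return jnp.argmax(logits, axis=-1)


def weighted(logits, rng, temperature=1.0):
  # Sample from the full softmax distribution, scaled by temperature.
  return jax.random.categorical(rng, logits / temperature, axis=-1)


def topk(logits, rng, k=10, temperature=1.0):
  # Keep only the k largest logits, then sample among them.
  top_vals, top_idx = jax.lax.top_k(logits, k)
  choice = jax.random.categorical(rng, top_vals / temperature, axis=-1)
  return jnp.take_along_axis(top_idx, choice[..., None], axis=-1).squeeze(-1)


def nucleus(logits, rng, p=0.8, temperature=1.0):
  # Keep the smallest set of tokens whose cumulative probability reaches p,
  # then sample among them.
  sorted_desc = jnp.sort(logits, axis=-1)[..., ::-1]
  probs = jax.nn.softmax(sorted_desc / temperature, axis=-1)
  cutoff_index = jnp.sum(jnp.cumsum(probs, axis=-1) < p, axis=-1, keepdims=True)
  cutoff_logit = jnp.take_along_axis(sorted_desc, cutoff_index, axis=-1)
  masked = jnp.where(logits < cutoff_logit, -jnp.inf, logits)
  return jax.random.categorical(rng, masked / temperature, axis=-1)
```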

@bvrockwell force-pushed the add-decoder-temperature branch 3 times, most recently from 134f77a to b692f77, on June 12, 2024 05:40
@bvrockwell marked this pull request as ready for review on June 12, 2024 19:00
@FanhaiLu1 (Collaborator) left a comment:

Thanks for adding different token sampling! Can you start the PyTorch engine server and share the inference results with the different sampling algorithms?

Collaborator:

We don't need to create a duplicated class if sampling_utils.py is the same as JetStream's; JetStream is one of the PyTorch engine's dependencies.

@bvrockwell (Collaborator, Author):

Ah ok, I wasn't sure whether to keep the previously pinned JetStream commit or roll up to the most recent main. 13 files have changed since then (including the sampling_utils.py addition), so I changed the import to point to jetstream.engine instead. I'll run the tests with the different sampling algorithms too, thanks!
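For reference, the import change described here would look roughly like the line below; the jetstream.engine module path is taken from this comment, not verified against the repository.

```python
# Sketch of the dependency-based import described above: reuse JetStream's
# sampling_utils rather than a duplicated local copy (module path assumed).
from jetstream.engine import sampling_utils
```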

```diff
@@ -220,7 +222,14 @@ def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
     if len(logits.shape) == 2:
       logits = jnp.expand_dims(logits, 0)
     return (
-        jnp.argmax(logits[:, -1], axis=-1)
+        sampling_utils.sampling(
```
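The hunk above is truncated right after the new sampling_utils.sampling( call. Purely as an illustration of the intent, the filled-in method might read roughly as follows; the argument order and every self.* attribute are assumptions mirroring the flags used in the benchmarks below, not the code that was actually merged.

```python
# Hypothetical reconstruction of the truncated hunk above. The real argument
# list of sampling_utils.sampling is not visible in this diff, so the rng and
# config attributes here are assumed, not taken from the merged code.
def _sampling(self, logits, batch_size):
  if len(logits.shape) == 2:
    logits = jnp.expand_dims(logits, 0)
  return sampling_utils.sampling(
      logits[:, -1],                 # last-step logits, the same slice the old argmax used
      self.rng,                      # assumed: a PRNG key held by the engine
      self.env.sampling_algorithm,   # assumed: "greedy" | "weighted" | "topk" | "nucleus"
      self.env.topk,                 # assumed config knobs matching the benchmark flags
      self.env.nucleus_topp,
      self.env.temperature,
  ).reshape(batch_size, -1)          # assumed: return one token per batch entry
```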
Collaborator:

Can you add a unit test for the sampling?

@bvrockwell (Collaborator, Author):

Added tests to test_engine.py.
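The merged tests themselves are not shown in this thread. As a rough idea of the kind of checks involved, here is a minimal unittest sketch that exercises the illustrative greedy/topk helpers from the sketch in the PR description above (not the tests actually added to test_engine.py).

```python
import unittest

import jax
import jax.numpy as jnp

# Assumes the illustrative greedy/topk helpers from the sketch earlier in
# this PR are in scope; these are not the merged test_engine.py tests.


class SamplingSketchTest(unittest.TestCase):

  def test_greedy_matches_argmax(self):
    logits = jnp.array([[0.1, 2.0, -1.0, 0.5]])
    self.assertEqual(int(greedy(logits)[0]), int(jnp.argmax(logits, axis=-1)[0]))

  def test_topk_stays_inside_top_k(self):
    # With two dominant logits, top-k sampling with k=2 must pick index 0 or 1.
    logits = jnp.array([[5.0, 4.0, -10.0, -10.0]])
    rng = jax.random.PRNGKey(0)
    self.assertIn(int(topk(logits, rng, k=2)[0]), (0, 1))


if __name__ == "__main__":
  unittest.main()
```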

Collaborator:

Thanks for adding it, looks good to me!


@bvrockwell (Collaborator, Author) commented Jun 13, 2024

> Thanks for adding different token sampling! Can you start the PyTorch engine server and share the inference results with the different sampling algorithms?

Here are the inference results comparing main and the proposed changes. The server was launched with:

```bash
python run_server.py --size=7b --batch_size=1 --max_cache_length=16 \
  --quantize_weights=true --quantize_kv_cache=true --checkpoint_path=".../model/llama2" \
  --tokenizer_path="deps/JetStream/jetstream/tests/engine/third_party/llama2/tokenizer.model" \
  --model_name=llama-2 --sharding_config="default_shardings/llama.yaml" &> server.log &
```

The benchmark was invoked with:

```bash
python deps/JetStream/benchmarks/benchmark_serving.py \
  --tokenizer="deps/JetStream/jetstream/tests/engine/third_party/llama2/tokenizer.model" \
  --num-prompts 100 \
  --warmup-mode sampled
```

main: v5e-1 llama2-7b "greedy" quantized "int8_per_channel" batch=1 max_cache_length=16

Successful requests: 130
Benchmark duration: 23.104631 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 5.63 requests/s
Input token throughput: 382.26 tokens/s
Output token throughput: 28.57 tokens/s
Mean TTFT: 11684.67 ms
Median TTFT: 11878.46 ms
P99 TTFT: 22641.39 ms
Mean TPOT: 3531.82 ms
Median TPOT: 2214.95 ms
P99 TPOT: 18868.32 ms

add-decoder-temperature: "greedy" same settings

Successful requests: 130
Benchmark duration: 23.105529 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 5.63 requests/s
Input token throughput: 382.25 tokens/s
Output token throughput: 28.56 tokens/s
Mean TTFT: 11684.65 ms
Median TTFT: 11878.57 ms
P99 TTFT: 22642.09 ms
Mean TPOT: 3531.79 ms
Median TPOT: 2214.98 ms
P99 TPOT: 18869.06 ms

add-decoder-temperature: "weighted" temperature=10 same settings

Successful requests: 130
Benchmark duration: 23.113795 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 5.62 requests/s
Input token throughput: 382.11 tokens/s
Output token throughput: 28.55 tokens/s
Mean TTFT: 11689.17 ms
Median TTFT: 11883.04 ms
P99 TTFT: 22650.20 ms
Mean TPOT: 3533.19 ms
Median TPOT: 2215.81 ms
P99 TPOT: 18875.58 ms

add-decoder-temperature "topk" topk=10 temperature=0.5 same settings

Successful requests: 130
Benchmark duration: 26.119014 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 4.98 requests/s
Input token throughput: 338.14 tokens/s
Output token throughput: 25.27 tokens/s
Mean TTFT: 11691.81 ms
Median TTFT: 11885.63 ms
P99 TTFT: 22655.60 ms
Mean TPOT: 3533.97 ms
Median TPOT: 2216.27 ms
P99 TPOT: 18880.23 ms

add-decoder-temperature "nucleus" nucleus_topp=0.8 temperature=0.5 same settings

Successful requests: 130
Benchmark duration: 23.645459 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 5.50 requests/s
Input token throughput: 373.52 tokens/s
Output token throughput: 27.91 tokens/s
Mean TTFT: 11957.90 ms
Median TTFT: 12156.01 ms
P99 TTFT: 23170.30 ms
Mean TPOT: 3614.40 ms
Median TPOT: 2266.83 ms
P99 TPOT: 19309.17 ms

v5e-4

main "greedy"

Successful requests: 130
Benchmark duration: 7.190380 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 18.08 requests/s
Input token throughput: 1228.31 tokens/s
Output token throughput: 91.79 tokens/s
Mean TTFT: 3626.83 ms
Median TTFT: 3692.30 ms
P99 TTFT: 7034.99 ms
Mean TPOT: 1085.40 ms
Median TPOT: 680.39 ms
P99 TPOT: 5863.28 ms

add-decoder-temperature "greedy" same settings

Successful requests: 130
Benchmark duration: 9.167303 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 14.18 requests/s
Input token throughput: 963.42 tokens/s
Output token throughput: 72.00 tokens/s
Mean TTFT: 3612.90 ms
Median TTFT: 3649.38 ms
P99 TTFT: 7026.81 ms
Mean TPOT: 1092.39 ms
Median TPOT: 686.06 ms
P99 TPOT: 5855.20 ms

add-decoder-temperature "weighted" temperature=10 same settings

Successful requests: 130
Benchmark duration: 7.201671 s
Total input tokens: 8832
Total generated tokens: 660
Request throughput: 18.05 requests/s
Input token throughput: 1226.38 tokens/s
Output token throughput: 91.65 tokens/s
Mean TTFT: 3633.50 ms
Median TTFT: 3697.00 ms
P99 TTFT: 7045.76 ms
Mean TPOT: 1093.84 ms
Median TPOT: 689.30 ms
P99 TPOT: 5809.58 ms

@qihqi merged commit 97aaeae into main on Jun 14, 2024; 4 checks passed.
@bvrockwell deleted the add-decoder-temperature branch on June 14, 2024 00:44.