Skip to content

Commit

Permalink
Update benchmark script to easily test llama-3 (#83)
Browse the repository at this point in the history
* Update benchmark script to easily test llama-3

* fix lint

* Update benchmarks/README.md
  • Loading branch information
bhavya01 authored May 17, 2024
1 parent 01c5a03 commit e4952fb
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 15 deletions.
13 changes: 13 additions & 0 deletions benchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ python benchmark_serving.py \
```

### Run Benchmark for Llama 3

```
python benchmark_serving.py \
--tokenizer <llama3 tokenizer path> \
--num-prompts 10 \
--dataset sharegpt \
--dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--model llama-3
```

### Save request outputs in Benchmark

Please use `--save-request-outputs` flag to save predictions to a file.
Expand Down
27 changes: 14 additions & 13 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc
from jetstream.engine.token_utils import load_vocab
from jetstream.third_party.llama3 import llama3_tokenizer
import numpy as np
from tqdm.asyncio import tqdm # pytype: disable=pyi-error
import pandas
Expand Down Expand Up @@ -130,10 +131,13 @@ def to_dict(self):
}


def get_tokenizer(tokenizer_name: str) -> Any:
def get_tokenizer(model_id: str, tokenizer_name: str) -> Any:
  """Return a tokenizer or a tokenizer placeholder."""
if tokenizer_name == "test":
return "test"
elif model_id == "llama-3":
# Llama 3 uses a tiktoken tokenizer.
return llama3_tokenizer.Tokenizer(tokenizer_name)
else:
# Use JetStream tokenizer util. It's using the sentencepiece wrapper in
# seqio library.
Expand Down Expand Up @@ -195,18 +199,14 @@ def tokenize_dataset(
prompts = []
outputs = []
indices = []

prompt_token_ids = []
outputs_token_ids = []
for prompt, output, idx in dataset:
prompts.append(prompt)
outputs.append(output)
indices.append(idx)

prompt_token_ids = tokenizer.encode(
prompts
) # adjust this code based on tokenizer method
outputs_token_ids = tokenizer.encode(
outputs
) # adjust this code based on tokenizer method
prompt_token_ids.append(tokenizer.encode(prompt))
outputs_token_ids.append(tokenizer.encode(output))

tokenized_dataset = []
for i in range(n):
Expand Down Expand Up @@ -549,7 +549,7 @@ def main(args: argparse.Namespace):

api_url = f"{args.server}:{args.port}"

tokenizer = get_tokenizer(tokenizer_id)
tokenizer = get_tokenizer(model_id, tokenizer_id)
if tokenizer == "test" or args.dataset == "test":
input_requests = mock_requests(
args.total_mock_requests
Expand Down Expand Up @@ -680,9 +680,10 @@ def main(args: argparse.Namespace):
type=str,
default="no_model",
help=(
"Name of the model. (it's just used to label the benchmark, the model"
" config is defined in config_lib, and passed as the server config"
" flag when we run the JetStream server)"
"Name of the model like llama-2, llama-3, gemma. (it's just used to"
" label the benchmark, pick the tokenizer, the model config is"
" defined in config_lib, and passed as the server config flag when"
" we run the JetStream server)"
),
)
parser.add_argument(
Expand Down
4 changes: 2 additions & 2 deletions jetstream/third_party/llama3/llama3_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def encode(
self,
s: str,
*,
bos: bool,
eos: bool,
bos: bool = False,
eos: bool = False,
allowed_special: Union[Literal["all"], AbstractSet[str]] | None = None,
disallowed_special: Union[Literal["all"], Collection[str]] = (),
) -> List[int]:
Expand Down

0 comments on commit e4952fb

Please sign in to comment.