Align Tokenizer in JetStream (#40)
* Align Tokenizer in JetStream

* Update requirements with pytest dep

* Remove mix_decode unit test
JoeZijunZhou authored Apr 24, 2024
1 parent f6f9b06 commit a0df320
Showing 16 changed files with 136 additions and 143 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/unit_tests.yaml
@@ -47,9 +47,10 @@ jobs:
pip install pylint
pip install pyink
pip install -r requirements.txt
pip install -r benchmarks/requirements.in
- name: Typecheck the code with pytype
run: |
pytype --jobs auto --disable import-error --disable module-attr jetstream/
pytype --jobs auto --disable import-error --disable module-attr jetstream/ benchmarks/
- name: Analysing the code with pylint
run: |
pylint jetstream/ benchmarks/
8 changes: 4 additions & 4 deletions README.md
@@ -57,15 +57,15 @@ python -m jetstream.tools.load_tester
### Test core modules
```
# Test JetStream core orchestrator
python -m jetstream.core.orchestrator_test
python -m jetstream.tests.core.test_orchestrator
# Test JetStream core server library
python -m jetstream.core.server_test
python -m jetstream.tests.core.test_server
# Test mock JetStream engine implementation
python -m jetstream.engine.mock_engine_test
python -m jetstream.tests.engine.test_mock_engine
# Test mock JetStream token utils
python -m jetstream.engine.utils_test
python -m jetstream.tests.engine.test_utils
```
65 changes: 33 additions & 32 deletions benchmarks/benchmark_serving.py
@@ -65,15 +65,14 @@
import json
import random
import time
from typing import Any, AsyncGenerator, List, Optional
from typing import Any, AsyncGenerator, Optional

import grpc
from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc
from jetstream.engine.token_utils import load_vocab
import numpy as np
import tensorflow as tf
import tensorflow_text as tftxt
from tqdm.asyncio import tqdm
from tqdm.asyncio import tqdm # pytype: disable=pyi-error
from eval_accuracy import eval_accuracy


@@ -106,9 +105,9 @@ class InputRequest:

@dataclass
class RequestFuncOutput:
input_request: InputRequest = None
generated_token_list: list[str] = None
generated_text: str = None
input_request: Optional[InputRequest] = None
generated_token_list: list[str] = []
generated_text: str = ""
success: bool = False
latency: float = 0
ttft: float = 0
@@ -132,18 +131,16 @@ def get_tokenizer(tokenizer_name: str) -> Any:
if tokenizer_name == "test":
return "test"
else:
with tf.io.gfile.GFile(tokenizer_name, "rb") as model_fp:
sp_model = model_fp.read()
sp_tokenizer = tftxt.SentencepieceTokenizer(
model=sp_model, add_bos=True, add_eos=False, reverse=False
)
return sp_tokenizer
# Use JetStream tokenizer util. It's using the sentencepiece wrapper in
# seqio library.
vocab = load_vocab(tokenizer_name)
return vocab.tokenizer


def load_sharegpt_dataset(
dataset_path: str,
conversation_starter: str,
) -> List[tuple[str]]:
) -> list[tuple[Any, Any]]:
# Load the dataset.
with open(dataset_path, "r", encoding="utf-8") as f:
dataset = json.load(f)
@@ -166,7 +163,7 @@ def load_sharegpt_dataset(
return dataset


def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:
def load_openorca_dataset(dataset_path: str) -> list[tuple[Any, Any]]:
# Load the dataset.
with open(dataset_path, "r", encoding="utf-8") as f:
dataset = json.load(f)
@@ -179,9 +176,9 @@ def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:


def tokenize_dataset(
dataset: List[tuple[str]],
dataset: list[tuple[Any, Any, Any]],
tokenizer: Any,
) -> List[tuple[Any]]:
) -> list[tuple[str, Any, str, int, int, int]]:

n = len(dataset)

@@ -194,10 +191,10 @@ def tokenize_dataset(
outputs.append(output)
indices.append(idx)

prompt_token_ids = tokenizer.tokenize(
prompt_token_ids = tokenizer.encode(
prompts
) # adjust this code based on tokenizer method
outputs_token_ids = tokenizer.tokenize(
outputs_token_ids = tokenizer.encode(
outputs
) # adjust this code based on tokenizer method

@@ -218,8 +215,9 @@


def filter_dataset(
tokenized_dataset: List[tuple[Any]], max_output_length: Optional[int] = None
) -> List[InputRequest]:
tokenized_dataset: list[tuple[str, Any, str, int, int, int]],
max_output_length: Optional[int] = None,
) -> list[InputRequest]:
if max_output_length is None:
print("In InputRequest, pass in actual output_length for each sample")
else:
@@ -229,7 +227,7 @@
)

# Filter out too long sequences.
filtered_dataset: List[InputRequest] = []
filtered_dataset: list[InputRequest] = []
for (
prompt,
_,
@@ -258,12 +256,12 @@ def filter_dataset(


def sample_requests(
dataset: List[tuple[str]],
dataset: list[tuple[Any, Any]],
tokenizer: Any,
num_requests: int,
max_output_length: Optional[int] = None,
oversample_multiplier: float = 1.2,
) -> List[InputRequest]:
) -> list[InputRequest]:

# Original dataset size
n = len(dataset)
@@ -304,7 +302,7 @@


async def get_request(
input_requests: List[InputRequest],
input_requests: list[InputRequest],
request_rate: float,
) -> AsyncGenerator[InputRequest, None]:
input_requests = iter(input_requests)
@@ -321,8 +319,8 @@ async def get_request(


def calculate_metrics(
input_requests: List[InputRequest],
outputs: List[RequestFuncOutput],
input_requests: list[InputRequest],
outputs: list[RequestFuncOutput],
dur_s: float,
tokenizer: Any,
) -> BenchmarkMetrics:
@@ -374,16 +372,17 @@ async def grpc_async_request(
token_list = []
request_start_time = time.perf_counter()
response = stub.Decode(request)
async for token in response:
async for sample_list in response:
if ttft == 0:
ttft = time.perf_counter() - request_start_time
token_list.append(token.response[0])
token_list.extend(sample_list.response[0].token_ids)
latency = time.perf_counter() - request_start_time
return token_list, ttft, latency


async def send_request(
api_url: str,
tokenizer: Any,
input_request: InputRequest,
pbar: tqdm,
session_cache: str,
@@ -405,7 +404,8 @@ async def send_request(
output.ttft = ttft
output.latency = latency
output.generated_token_list = generated_token_list
output.generated_text = "".join(generated_token_list)
# generated_token_list is a list of token ids, decode it to generated_text.
output.generated_text = tokenizer.decode(generated_token_list)
output.success = True
if pbar:
pbar.update(1)
@@ -415,7 +415,7 @@
async def benchmark(
api_url: str,
tokenizer: Any,
input_requests: List[InputRequest],
input_requests: list[InputRequest],
request_rate: float,
disable_tqdm: bool,
session_cache: str,
@@ -433,6 +433,7 @@ async def benchmark(
asyncio.create_task(
send_request(
api_url=api_url,
tokenizer=tokenizer,
input_request=request,
pbar=pbar,
session_cache=session_cache,
@@ -442,7 +443,7 @@
)
outputs = await asyncio.gather(*tasks)

if not disable_tqdm:
if not disable_tqdm and pbar:
pbar.close()

benchmark_duration = time.perf_counter() - benchmark_start_time
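For reference, a minimal sketch of the tokenizer path this file now uses, assuming a local SentencePiece model file; `roundtrip` is an illustrative helper, not part of the diff:

```python
# Sketch: tensorflow_text's SentencepieceTokenizer is replaced by JetStream's
# load_vocab helper; prompts are encoded to ids up front, and generated ids are
# decoded back to text once a request completes.
from jetstream.engine.token_utils import load_vocab


def roundtrip(tokenizer_path: str, prompt: str) -> str:
  tokenizer = load_vocab(tokenizer_path).tokenizer  # seqio SentencePiece wrapper
  ids = tokenizer.encode(prompt)  # what tokenize_dataset now calls
  return tokenizer.decode(ids)  # what send_request now calls
```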
3 changes: 2 additions & 1 deletion benchmarks/requirements.in
@@ -1,3 +1,4 @@
nltk
evaluate
rouge-score
rouge-score
tqdm
11 changes: 8 additions & 3 deletions jetstream/core/orchestrator.py
@@ -127,7 +127,7 @@ class ActiveRequest:
# We keep prefill and decode information together in the same object so that
# there is less indirection about where this return channel is.
# The return channel returns a list of strings, one per sample for that query.
return_channel: async_multifuture.AsyncMultifuture[list[str]]
return_channel: async_multifuture.AsyncMultifuture[list[list[int]]]
# [num_samples,] which corresponds to whether each sample is complete for the
# requests.
complete: Optional[np.ndarray] = None
@@ -139,7 +139,7 @@ class ActiveRequest:
# Which generate step this was added at.
generate_timestep_added: Optional[int] = None

def enqueue_tokens(self, generated_tokens: list[str]):
def enqueue_tokens(self, generated_tokens: list[list[int]]):
"""Records information about the step.
Args:
@@ -662,4 +662,9 @@ async def Decode( # pylint: disable=invalid-overridden-method
# The DecodeResponse stream should consume all generated tokens in
# return_channel when complete signal is received. It should check if
# return_channel is empty to decide if it should exit the while loop.
yield jetstream_pb2.DecodeResponse(response=response)
repeated_token_ids = []
for token_ids in response:
repeated_token_ids.append(
jetstream_pb2.RepeatedTokenIds(token_ids=token_ids)
)
yield jetstream_pb2.DecodeResponse(response=repeated_token_ids)
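For reference, the new response assembly in isolation (a sketch mirroring the lines above; assumes the regenerated `jetstream_pb2` module is importable):

```python
from jetstream.core.proto import jetstream_pb2


def build_decode_response(
    response: list[list[int]],
) -> jetstream_pb2.DecodeResponse:
  # One RepeatedTokenIds message per sample; each carries that sample's
  # token ids for the current decode step.
  repeated_token_ids = [
      jetstream_pb2.RepeatedTokenIds(token_ids=token_ids)
      for token_ids in response
  ]
  return jetstream_pb2.DecodeResponse(response=repeated_token_ids)
```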
8 changes: 6 additions & 2 deletions jetstream/core/proto/jetstream.proto
@@ -37,6 +37,10 @@ message DecodeRequest {
int32 max_tokens = 4;
}
message DecodeResponse {
// List of responses, one per sample.
repeated string response = 1;
// List of responses, one per sample. The list size depends on text generation strategy the engine used.
repeated RepeatedTokenIds response = 1;
}
message RepeatedTokenIds {
// List of token ids, one list per sample. When speculative decoding is disabled, the list size should be 1; When speculative decoding is enabled, the list size should be >= 1.
repeated int32 token_ids = 1;
}
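A hedged sketch of a single-sample client against the updated streaming API; `decode_one` is a hypothetical helper and the request field values are placeholders:

```python
import grpc

from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc
from jetstream.engine.token_utils import load_vocab


async def decode_one(api_url: str, tokenizer_path: str, text: str) -> str:
  """Streams token ids for one request and detokenizes them at the end."""
  vocab = load_vocab(tokenizer_path)
  async with grpc.aio.insecure_channel(api_url) as channel:
    stub = jetstream_pb2_grpc.OrchestratorStub(channel)
    request = jetstream_pb2.DecodeRequest(
        session_cache="", additional_text=text, priority=1, max_tokens=64
    )
    token_ids: list[int] = []
    async for resp in stub.Decode(request):
      # One RepeatedTokenIds per sample; this client only reads sample 0.
      token_ids.extend(resp.response[0].token_ids)
  return vocab.tokenizer.decode(token_ids)
```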
10 changes: 6 additions & 4 deletions jetstream/core/proto/jetstream_pb2.py
@@ -28,7 +28,7 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"e\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x17\n\x0f\x61\x64\x64itional_text\x18\x02 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05""\n\x0e\x44\x65\x63odeResponse\x12\x10\n\x08response\x18\x01 \x03(\t2]\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x62\x06proto3'
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"e\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x17\n\x0f\x61\x64\x64itional_text\x18\x02 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05"E\n\x0e\x44\x65\x63odeResponse\x12\x33\n\x08response\x18\x01 \x03(\x0b\x32!.jetstream_proto.RepeatedTokenIds"%\n\x10RepeatedTokenIds\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x32]\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x62\x06proto3'
)

_globals = globals()
@@ -41,7 +41,9 @@
_globals["_DECODEREQUEST"]._serialized_start = 57
_globals["_DECODEREQUEST"]._serialized_end = 158
_globals["_DECODERESPONSE"]._serialized_start = 160
_globals["_DECODERESPONSE"]._serialized_end = 194
_globals["_ORCHESTRATOR"]._serialized_start = 196
_globals["_ORCHESTRATOR"]._serialized_end = 289
_globals["_DECODERESPONSE"]._serialized_end = 229
_globals["_REPEATEDTOKENIDS"]._serialized_start = 231
_globals["_REPEATEDTOKENIDS"]._serialized_end = 268
_globals["_ORCHESTRATOR"]._serialized_start = 270
_globals["_ORCHESTRATOR"]._serialized_end = 363
# @@protoc_insertion_point(module_scope)
8 changes: 5 additions & 3 deletions jetstream/engine/mock_utils.py
@@ -62,9 +62,7 @@ def _encode(self, s: str) -> Sequence[int]:

def _decode(self, ids: np.ndarray):
"""Converts a numpy array into a string."""
# 'We use array methods, not python iterables so we don't
# implement this method in the mock vocab.
raise NotImplementedError
return "".join([chr(r) for r in list(ids)])

def _encode_tf(self, s: str) -> np.ndarray:
"""Converts a string into a numpy array."""
@@ -78,6 +76,10 @@ def _decode_tf(self, ids: np.ndarray) -> List[str]:
results = np.split(ids, ids.shape[0])
return ["".join([chr(r) for r in list(line[0])]) for line in results]

def decode(self, ids: np.ndarray):
"""Converts a numpy array into a string."""
return self._decode(ids)

def encode_tf(self, s: str) -> np.ndarray:
"""Converts a string into a numpy array."""
return self._encode_tf(s)
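The mock vocabulary treats token ids as character code points, so decoding is just a `chr()` join. A tiny illustrative round trip (not repo code):

```python
import numpy as np


def mock_decode(ids: np.ndarray) -> str:
  # Mirrors the _decode/_decode_tf logic above: each id is a code point.
  return "".join(chr(int(r)) for r in list(ids))


print(mock_decode(np.array([104, 105])))  # -> "hi"
```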
34 changes: 5 additions & 29 deletions jetstream/engine/token_utils.py
@@ -28,21 +28,6 @@
from jetstream.engine import mock_utils


def mix_decode(vocab: Vocabulary, tok_id: int):
"""
The IdToPiece and decode results differ for 344 tokens in Llama2.
Use the decode function to generate the correct strings for these 344 tokens.
If IdToPiece returns a hex string (e.g., '<0x0A>') for a token within these
344, utilize IdToPiece to convert it into a string, likely with a space
placeholder (' ') for the corresponding tokens.
"""
p_token = vocab.tokenizer.IdToPiece(tok_id)
# SentencePiece escapes the whitespace with a meta symbol "▁" (U+2581)
p_token = p_token.replace("▁", " ")
d_token = vocab.tokenizer.decode([tok_id])
return p_token if p_token.lstrip() == d_token else d_token


def take_nearest_length(lengths: list[int], length: int) -> int:
"""Gets the nearest length to the right in a set of lengths."""
pos = bisect_left(lengths, length)
@@ -131,7 +116,7 @@ def process_result_tokens(
vocab: Vocabulary,
complete: np.ndarray,
debug: bool = False,
) -> Tuple[List[str], np.ndarray]:
) -> Tuple[List[List[int]], np.ndarray]:
"""Processes a result tokens into a list of strings, handling multiple
samples.
Expand All @@ -145,7 +130,7 @@ def process_result_tokens(
debug: Whether to log step by step detokenisation.
Returns:
sample_return: List of strings, one per sample.
sample_return: List of tok_id list, one list per sample.
complete: Updated complete.
"""
# tokens: [samples, speculations]
@@ -166,7 +151,7 @@
)
sample_return = []
for idx in range(samples):
string_so_far = ""
tok_id_so_far = []
if not complete[idx].item():
for spec_idx in range(speculations):
tok_id = slot_tokens[idx, spec_idx].item()
@@ -182,17 +167,8 @@
complete[idx] = True
break
else:
try:
# pytype: disable=attribute-error
token = mix_decode(vocab, tok_id)
except ValueError:
# This error only occurs when using tests where the vocab range is
# computed via addition and int->char is computed using chr(). Real
# models have vocab logits which are at max the size of the vocab.
logging.warning("%d exceeded vocab range", tok_id)
token = "<sampled_outside_vocab>"
string_so_far += token
sample_return.append(string_so_far)
tok_id_so_far.append(tok_id)
sample_return.append(tok_id_so_far)
if debug:
logging.info("Sampled return %s", str(sample_return))
return sample_return, complete
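With `mix_decode` removed, callers get per-sample token-id lists from `process_result_tokens` and detokenize them with the SentencePiece vocab, whose `decode` already handles the "▁" whitespace pieces the deleted helper worked around. A hedged caller-side sketch (`detokenize_samples` is illustrative, not part of the diff):

```python
from jetstream.engine.token_utils import load_vocab


def detokenize_samples(
    tokenizer_path: str, samples: list[list[int]]
) -> list[str]:
  vocab = load_vocab(tokenizer_path)  # seqio SentencePieceVocabulary
  return [vocab.tokenizer.decode(token_ids) for token_ids in samples]
```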