Update JetStream grpc proto to support I/O with text and token ids (#78)
* Update JetStream grpc proto to support I/O with text and token ids

* Update orchestrator and token utils to support text and token I/O

* Add and update unit tests

* Fix prometheus duplicate metrics issue

* add shortuuid dep

* Update docstring

* Add client tokenization mode

* Update client side I/O handling

* latest pylint fix
JoeZijunZhou authored May 14, 2024
1 parent 2f8924d commit 01c5a03
Showing 17 changed files with 632 additions and 245 deletions.
11 changes: 8 additions & 3 deletions benchmarks/benchmark_serving.py
@@ -388,10 +388,10 @@ async def grpc_async_request(
token_list = []
request_start_time = time.perf_counter()
response = stub.Decode(request)
async for sample_list in response:
async for resp in response:
if ttft == 0:
ttft = time.perf_counter() - request_start_time
token_list.extend(sample_list.response[0].token_ids)
token_list.extend(resp.stream_content.samples[0].token_ids)
latency = time.perf_counter() - request_start_time
return token_list, ttft, latency

@@ -405,9 +405,13 @@ async def send_request(
priority: int,
) -> RequestFuncOutput:
"""Send the request to JetStream server."""
# Tokenization on client side following MLPerf standard.
token_ids = tokenizer.encode(input_request.prompt)
request = jetstream_pb2.DecodeRequest(
session_cache=session_cache,
additional_text=input_request.prompt,
token_content=jetstream_pb2.DecodeRequest.TokenContent(
token_ids=token_ids
),
priority=priority,
max_tokens=input_request.output_len,
)
@@ -551,6 +555,7 @@ def main(args: argparse.Namespace):
args.total_mock_requests
) # e.g. [("AB", 2, "AB", 3)]
else:
dataset = []
if args.dataset == "openorca":
dataset = load_openorca_dataset_pkl()
elif args.dataset == "sharegpt":
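The benchmark change above switches the client to the new token-based request and the samples-based streaming response. Read on its own, the new call pattern looks roughly like the following sketch; the helper name, the server address argument, and the tokenizer object are illustrative assumptions, while the message fields mirror the diff.

```python
# Sketch (not part of the diff): client-side tokenization with the new
# DecodeRequest/DecodeResponse shapes. `send_one_request`, `server`, and
# `tokenizer` are illustrative assumptions.
import time

import grpc

from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc


async def send_one_request(server: str, tokenizer, prompt: str, max_tokens: int):
  async with grpc.aio.insecure_channel(server) as channel:
    stub = jetstream_pb2_grpc.OrchestratorStub(channel)
    # Tokenize on the client, then send token ids instead of raw text.
    request = jetstream_pb2.DecodeRequest(
        token_content=jetstream_pb2.DecodeRequest.TokenContent(
            token_ids=tokenizer.encode(prompt)
        ),
        max_tokens=max_tokens,
    )
    token_list, start = [], time.perf_counter()
    async for resp in stub.Decode(request):
      # Each streamed message carries one or more samples; sample 0 is used
      # here, matching the benchmark script above.
      token_list.extend(resp.stream_content.samples[0].token_ids)
    return token_list, time.perf_counter() - start
```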
181 changes: 146 additions & 35 deletions jetstream/core/orchestrator.py
@@ -85,17 +85,18 @@
import threading
import time
import traceback
from typing import Any, AsyncIterator, Optional, Union
from typing import Any, AsyncIterator, Optional, Tuple, Union, cast

import grpc
import jax
from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc
from jetstream.core.utils import async_multifuture
from jetstream.engine import engine_api

from jetstream.core.utils.return_sample import ReturnSample
from jetstream.engine import engine_api, tokenizer_api, token_utils
import numpy as np
import prometheus_client
import shortuuid

root = logging.getLogger()
root.setLevel(logging.DEBUG)
@@ -127,27 +128,28 @@ class ActiveRequest:
# We keep prefill and decode information together in the same object so that
# there is less indirection about where this return channel is.
# The return channel returns a list of strings, one per sample for that query.
return_channel: async_multifuture.AsyncMultifuture[list[list[int]]]
return_channel: async_multifuture.AsyncMultifuture[list[ReturnSample]]
# [num_samples,] which corresponds to whether each sample is complete for the
# requests.
complete: Optional[np.ndarray] = None
prefill_result: Any = None
#################### Information relevant for prefill ########################
history_path: Optional[str] = None
prefill_text: Optional[str] = None
prefill_content: Optional[str | list[int]] = None
################## Information relevant for detokenization ###################
# Which generate step this was added at.
generate_timestep_added: Optional[int] = None
is_client_side_tokenization: Optional[bool] = False

def enqueue_tokens(self, generated_tokens: list[list[int]]):
"""Records information about the step.
def enqueue_samples(self, generated_samples: list[ReturnSample]):
"""Adds the generated sample(s) to return channel for current step.
Args:
generated_tokens: One token to put into the return channel
generated_samples: The generated sample(s) for current step.
This should be called only from within the Drivers background thread.
"""
self.return_channel.add_result(generated_tokens)
self.return_channel.add_result(generated_samples)


class JetThread(threading.Thread):
@@ -247,7 +249,8 @@ def __init__(
# At first, a request is placed here in order to get prefilled.
self._prefill_backlog = queue.Queue()
self._prefill_backlog_size_metric = prometheus_client.Gauge(
"jetstream_prefill_backlog_size", "Size of prefill queue"
f"jetstream_prefill_backlog_size_{shortuuid.uuid()}",
"Size of prefill queue",
)

# Stage 2
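The gauge rename above is the "Fix prometheus duplicate metrics issue" item from the commit message: `prometheus_client` registers metric names in a process-global default registry, so constructing a second `Driver` in the same process (as unit tests do) would presumably fail if the gauge name were fixed. A standalone sketch of the collision and the suffix workaround, not JetStream code:

```python
# Standalone sketch of the duplicate-metric problem the shortuuid suffix
# works around; not JetStream code.
import prometheus_client
import shortuuid

prometheus_client.Gauge("jetstream_prefill_backlog_size", "Size of prefill queue")
try:
  # Registering the same metric name twice in the default registry fails.
  prometheus_client.Gauge("jetstream_prefill_backlog_size", "Size of prefill queue")
except ValueError as e:
  print(f"duplicate registration rejected: {e}")

# A per-instance suffix keeps each Driver's gauge name unique.
prometheus_client.Gauge(
    f"jetstream_prefill_backlog_size_{shortuuid.uuid()}", "Size of prefill queue"
)
```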
@@ -438,6 +441,33 @@ def _load_cache_history(self, path: str) -> Union[None, Any]:
else:
return None

def _process_prefill_content(
self,
request: ActiveRequest,
tokenizer: tokenizer_api.Tokenizer,
is_bos: bool,
max_prefill_length: int,
) -> Tuple[jax.Array | np.ndarray, int]:
content = request.prefill_content
if isinstance(content, str):
# If it's text input, tokenize and pad the input.
return tokenizer.encode(
content,
is_bos=is_bos,
max_prefill_length=max_prefill_length,
jax_padding=self._jax_padding,
)
else:
# If it's token input, pad the input.
return token_utils.pad_tokens(
content,
tokenizer.bos_id,
tokenizer.pad_id,
is_bos=is_bos,
max_prefill_length=max_prefill_length,
jax_padding=self._jax_padding,
)

def _prefill_thread(self, idx: int):
"""Thread which runs in the background performing prefills."""
logging.info("---------Spinning up prefill thread %d.---------", idx)
@@ -455,7 +485,6 @@ def _prefill_thread(self, idx: int):

if request is None:
break
# Tokenize, and introduce a leading dimension
is_bos = not bool(request.history_path)
logging.info(
"Prefilling on prefill engine %d : prefill queue size, %d,"
@@ -465,13 +494,11 @@
is_bos,
request.history_path,
)
padded_tokens, true_length = tokenizer.encode(
request.prefill_text,
is_bos=is_bos,
max_prefill_length=prefill_engine.max_prefill_length,
jax_padding=self._jax_padding,
# Tokenize and pad the text or token input.
padded_tokens, true_length = self._process_prefill_content(
request, tokenizer, is_bos, prefill_engine.max_prefill_length
)
# Compute new kv cache for the prefill_text.
# Compute new kv cache for the prefill_content.
prefill_result = prefill_engine.prefill(
params=prefill_params,
padded_tokens=padded_tokens,
@@ -497,6 +524,8 @@ def _transfer_thread(self, idx: int):
while self.live:
# The transfer thread can just sleep until it has work to do.
new_request = transfer_backlog.get(block=True)
if new_request is None:
break
target_idx = min(
self._generate_backlogs.items(), key=lambda q: q[1].qsize()
)[0]
@@ -665,15 +694,17 @@ def _detokenize_thread(self, idx: int):

for slot, request in my_live_requests.items():
if request is not None:
results, complete = tokenizer.decode(
results, complete = token_utils.process_result_tokens(
tokenizer=tokenizer,
slot=slot,
slot_max_length=request.max_tokens,
result_tokens=result_tokens,
is_client_side_tokenization=request.is_client_side_tokenization,
complete=request.complete,
)
request.complete = complete
# Return some tokens.
request.enqueue_tokens(results)
# Return some output samples.
request.enqueue_samples(results)
if request.complete.all():
request.return_channel.close()
# Place the slot back on the free queue.
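`ReturnSample` comes from `jetstream/core/utils/return_sample.py`, one of the 17 changed files but not shown in this excerpt. Based on how it is used in this diff (`sample.text`, `sample.token_ids`), it can be pictured roughly as the following dataclass; treat the exact definition as an assumption.

```python
# Rough sketch of the ReturnSample container, inferred from its usage in
# orchestrator.py; the real definition lives in return_sample.py.
import dataclasses


@dataclasses.dataclass
class ReturnSample:
  """Per-sample output for one decode step."""

  # Decoded text pieces for this step; a piece may be a raw byte token that
  # needs further steps before it forms a valid string.
  text: list[str]
  # Token ids generated for this sample at this step.
  token_ids: list[int]
```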
@@ -698,6 +729,21 @@ class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer):
def __init__(self, driver: Driver):
self._driver = driver

def _get_prefill_content(
self, request: jetstream_pb2.DecodeRequest
) -> Tuple[str | list[int], bool]:
which_content = request.WhichOneof("content")
content = getattr(request, which_content)
if which_content == "text_content":
return cast(jetstream_pb2.DecodeRequest.TextContent, content).text, False
else:
return (
list(
cast(jetstream_pb2.DecodeRequest.TokenContent, content).token_ids
),
True,
)
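`_get_prefill_content` relies on standard protobuf oneof semantics: `WhichOneof("content")` returns the name of whichever field was set, so the server can tell text input from token input. A quick illustration (the `max_tokens` values are arbitrary):

```python
# Standard protobuf oneof behavior, shown with the new DecodeRequest.
from jetstream.core.proto import jetstream_pb2

text_req = jetstream_pb2.DecodeRequest(
    text_content=jetstream_pb2.DecodeRequest.TextContent(text="hello"),
    max_tokens=16,
)
assert text_req.WhichOneof("content") == "text_content"

token_req = jetstream_pb2.DecodeRequest(
    token_content=jetstream_pb2.DecodeRequest.TokenContent(token_ids=[1, 2, 3]),
    max_tokens=16,
)
assert token_req.WhichOneof("content") == "token_content"
```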

async def Decode( # pylint: disable=invalid-overridden-method
self,
request: jetstream_pb2.DecodeRequest,
@@ -709,14 +755,19 @@ async def Decode( # pylint: disable=invalid-overridden-method
"LLM orchestrator is being used in offline test mode, and will not"
" respond to gRPC queries - only direct function calls."
)
is_client_side_tokenization = False
return_channel = async_multifuture.AsyncMultifuture()
if context:
context.add_done_callback(return_channel.cancel)
prefill_content, is_client_side_tokenization = self._get_prefill_content(
request
)
# Wrap request as an ActiveRequest.
active_request = ActiveRequest(
max_tokens=request.max_tokens,
history_path=request.session_cache,
prefill_text=request.additional_text,
prefill_content=prefill_content,
is_client_side_tokenization=is_client_side_tokenization,
return_channel=return_channel,
)
# The first stage is being prefilled, all other stages are handled
@@ -736,18 +787,78 @@ async def Decode( # pylint: disable=invalid-overridden-method
logging.info(
"Placed request on the prefill queue.",
)
async for response in active_request.return_channel:
# When an active request is created a queue is instantiated. New tokens
# are placed there during the decoding loop, we pop from that queue by
# using the .next method on the active request.
# Yielding allows for the response to be a streaming grpc call - which
# can be called via iterating over a for loop on the other side.
# The DecodeResponse stream should consume all generated tokens in
# return_channel when complete signal is received. It should check if
# return_channel is empty to decide if it should exit the while loop.
repeated_token_ids = []
for token_ids in response:
repeated_token_ids.append(
jetstream_pb2.RepeatedTokenIds(token_ids=token_ids)
# When an active request is created, a queue is instantiated. New tokens
# are placed there during the decoding loop; we pop from that queue by
# using the .next method on the active request.
# Yielding allows the response to be a streaming gRPC call, which can be
# consumed by iterating over a for loop on the client side.
# The DecodeResponse stream should consume all generated tokens in
# return_channel when the complete signal is received (AsyncMultifuture
# promises this).
if is_client_side_tokenization:
# If is_client_side_tokenization, the client should request with token
# ids, and the JetStream server will return token ids as response.
# The client should take care of tokenization and detokenization.
async for response in active_request.return_channel:
response = cast(list[ReturnSample], response)
samples = []
for sample in response:
samples.append(
jetstream_pb2.DecodeResponse.StreamContent.Sample(
token_ids=sample.token_ids,
)
)
yield jetstream_pb2.DecodeResponse(
stream_content=jetstream_pb2.DecodeResponse.StreamContent(
samples=samples
)
)
else:
# A response buffer is used to handle streaming detokenization with
# special characters (in some edge cases the SentencePiece tokenizer
# needs a complete byte sequence, rather than a single token, to decode
# a valid string).
buffered_response_list = []
async for response in active_request.return_channel:
response = cast(list[ReturnSample], response)
buffered = False
for item in response:
if item.text and token_utils.is_byte_token(item.text[-1]):
# If any sample ends in bytes, this means we might still need to
# decode more bytes to compose the string.
buffered_response_list.append(response)
buffered = True
break
if buffered:
continue
# Flush the buffered responses into each sample of the current response.
current_response_with_flushed_buffer = list(
zip(*buffered_response_list, response)
)
# Empty buffer: [[s0_cur], [s1_cur], ...]
# Has buffer:
# [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...]
current_response_with_flushed_buffer = cast(
list[list[ReturnSample]], current_response_with_flushed_buffer
)
# Reset buffer after flushed.
buffered_response_list = []
# Form correct sample(s) and return as StreamContent for this iteration.
samples = []
for sample in current_response_with_flushed_buffer:
text = []
token_ids = []
for resp in sample:
text.extend(resp.text)
token_ids.extend(resp.token_ids)
samples.append(
jetstream_pb2.DecodeResponse.StreamContent.Sample(
text=token_utils.text_tokens_to_str(text),
token_ids=token_ids,
)
)
yield jetstream_pb2.DecodeResponse(
stream_content=jetstream_pb2.DecodeResponse.StreamContent(
samples=samples
)
)
yield jetstream_pb2.DecodeResponse(response=repeated_token_ids)
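The buffer-and-flush branch above holds back a step whenever any sample's latest text piece is still a raw byte token, then merges the buffered steps with the current one via `zip(*buffered_response_list, response)`. The regrouping that the inline comments describe can be seen on plain strings standing in for `ReturnSample` objects:

```python
# Standalone illustration of the regrouping done by
# `zip(*buffered_response_list, response)`: per-step lists (one entry per
# sample) are transposed into per-sample tuples covering all buffered steps
# plus the current step. Plain strings stand in for ReturnSample here.
buffered_response_list = [
    ["s0_b0", "s1_b0"],  # earlier buffered step
    ["s0_b1", "s1_b1"],  # next buffered step
]
response = ["s0_cur", "s1_cur"]  # current step

merged = list(zip(*buffered_response_list, response))
print(merged)
# [('s0_b0', 's0_b1', 's0_cur'), ('s1_b0', 's1_b1', 's1_cur')]

# With an empty buffer the same expression yields one-element tuples:
print(list(zip(*[], response)))
# [('s0_cur',), ('s1_cur',)]
```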
49 changes: 40 additions & 9 deletions jetstream/core/proto/jetstream.proto
@@ -19,14 +19,13 @@ package jetstream_proto;
// TODO: Merge this with main JetStream core once we settle on an API.

service Orchestrator {
// Generate the next model tokens.
// Query the LLM to generate text or tokens.
rpc Decode(DecodeRequest) returns (stream DecodeResponse) {}
}

message DecodeRequest {
// Where to load any pre-existing kv cache from.
string session_cache = 1;
// New text from a user or tool.
string additional_text = 2;
int32 priority = 3;
// The maximum output length of a sequence. It's used in JetStream to control
// the output/decode length of a sequence. It would not be used in the engine.
@@ -35,12 +34,44 @@ message DecodeRequest {
// sequence; max_prefill_predict_length is the maximum length of the
// input/prefill of a sequence.
int32 max_tokens = 4;

message TextContent {
string text = 1;
}
message TokenContent {
repeated int32 token_ids = 1;
}

// The client can pass the input either as a string, in which case the server
// will tokenize it, or as token ids, in which case it is the client's
// responsibility to tokenize the input with the correct tokenizer.
oneof content {
TextContent text_content = 5;
TokenContent token_content = 6;
}
reserved 2;
// Next ID: 7
}

message DecodeResponse {
// List of responses, one per sample. The list size depends on text generation strategy the engine used.
repeated RepeatedTokenIds response = 1;
}
message RepeatedTokenIds {
// List of token ids, one list per sample. When speculative decoding is disabled, the list size should be 1; When speculative decoding is enabled, the list size should be >= 1.
repeated int32 token_ids = 1;
// InitialContent supports returning initial one-off response data from the
// stream. It's a placeholder for future features such as history cache.
message InitialContent {}
message StreamContent {
message Sample {
// The text string decoded from token id(s).
string text = 1;
// List of token ids for this sample. When speculative decoding is disabled, the list size should be 1; when speculative decoding is enabled, the list size should be >= 1.
repeated int32 token_ids = 2;
}
// Supports multiple samples in the StreamContent. The Sample list size depends on the text generation strategy used by the engine.
repeated Sample samples = 1;
}

oneof content {
InitialContent initial_content = 2;
StreamContent stream_content = 3;
}
reserved 1;
// Next ID: 4
}
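Taken together, the new oneof pairs give two symmetric paths: token ids in and token ids out (client-side tokenization, as in the benchmark above), or text in and text out with the server handling tokenization. A sketch of the text path, assuming the generated Python stubs; the server address and `max_tokens` value are illustrative:

```python
# Sketch of the server-side-tokenization path enabled by the new proto
# (text in, text out). The address and max_tokens are illustrative.
import grpc

from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc


async def generate_text(prompt: str) -> str:
  async with grpc.aio.insecure_channel("localhost:9000") as channel:
    stub = jetstream_pb2_grpc.OrchestratorStub(channel)
    request = jetstream_pb2.DecodeRequest(
        text_content=jetstream_pb2.DecodeRequest.TextContent(text=prompt),
        max_tokens=256,
    )
    pieces = []
    async for resp in stub.Decode(request):
      # Each streamed message carries the text decoded for that step; sample 0
      # is used here for a single-sample generation strategy.
      pieces.append(resp.stream_content.samples[0].text)
    return "".join(pieces)
```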