batched streaming #55

Merged
merged 21 commits on Apr 24, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -180,7 +180,7 @@ LitServe supports multiple advanced state-of-the-art features.
| Automatic schema validation | ✅ |
| Handle timeouts | ✅ |
| Handle disconnects | ✅ |
| Streaming | in progress... |
| Streaming | ✅ |

> [!NOTE]
> Our goal is not to jump on every hype train, but instead support features that scale
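For orientation, here is a minimal usage sketch of the mode this PR enables: a server started with `stream=True` and `max_batch_size > 1`, and a client reading the `/stream-predict` route exercised by the tests in this PR. The import name `litserve`, the `run(port=...)` call, and `MyBatchedStreamAPI` (sketched after the api.py changes below) are illustrative assumptions, not part of this diff.

```python
# Hypothetical serving script; MyBatchedStreamAPI is the illustrative LitAPI
# subclass sketched after the api.py changes below.
import litserve as ls  # assumed package import name

server = ls.LitServer(
    MyBatchedStreamAPI(),
    stream=True,          # enable streaming responses
    max_batch_size=4,     # batch up to 4 concurrent requests
    batch_timeout=0.05,   # wait up to 50 ms to fill a batch
)
server.run(port=8000)     # assumed run() signature
```

A client can then consume the chunks as they arrive, for example:

```python
import requests  # assumed client-side dependency

with requests.post(
    "http://127.0.0.1:8000/stream-predict",
    json={"prompt": "Hello"},
    stream=True,
) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)  # tokens appear as they are generated
```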
104 changes: 100 additions & 4 deletions src/litserve/api.py
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from abc import ABC, abstractmethod


def no_batch_unbatch_message(obj, data):
def no_batch_unbatch_message_no_stream(obj, data):
return f"""
You set `max_batch_size > 1`, but the default implementation for batch() and unbatch() only supports
PyTorch tensors or NumPy ndarrays, while we found {type(data)}.
@@ -30,7 +31,26 @@ def unbatch(self, output):
"""


def no_batch_unbatch_message_stream(obj, data):
return f"""
You set `max_batch_size > 1`, but the default implementation for batch() and unbatch() only supports
PyTorch tensors or NumPy ndarrays, while we found {type(data)}.
Please implement these two methods in {obj.__class__.__name__}.

Example:

def batch(self, inputs):
return np.stack(inputs)

def unbatch(self, output):
for out in output:
yield list(out)
"""


class LitAPI(ABC):
_stream: bool = False

@abstractmethod
def setup(self, devices):
"""Setup the model so it can be called in `predict`."""
@@ -53,7 +73,12 @@ def batch(self, inputs):
import numpy

return numpy.stack(inputs)
raise NotImplementedError(no_batch_unbatch_message(self, inputs))

if self.stream:
message = no_batch_unbatch_message_stream(self, inputs)
else:
message = no_batch_unbatch_message_no_stream(self, inputs)
raise NotImplementedError(message)

@abstractmethod
def predict(self, x):
@@ -63,8 +88,16 @@ def predict(self, x):
def unbatch(self, output):
"""Convert a batched output to a list of outputs."""
if hasattr(output, "__torch_function__") or output.__class__.__name__ == "ndarray":
return list(output)
raise NotImplementedError(no_batch_unbatch_message(self, output))
if self._stream:
yield from list(output)
else:
return list(output)

if self.stream:
message = no_batch_unbatch_message_stream(self, output)
else:
message = no_batch_unbatch_message_no_stream(self, output)
raise NotImplementedError(message)

@abstractmethod
def encode_response(self, output):
@@ -74,3 +107,66 @@ def encode_response(self, output):

"""
pass

@property
def stream(self):
return self._stream

@stream.setter
def stream(self, value):
self._stream = value

def sanitize(self, max_batch_size: int):
if (
self.stream
and max_batch_size > 1
and not all([
inspect.isgeneratorfunction(self.predict),
inspect.isgeneratorfunction(self.encode_response),
inspect.isgeneratorfunction(self.unbatch),
])
):
raise ValueError(
"""When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and
`lit_api.unbatch` must generate values using `yield`.

Example:

def predict(self, inputs):
...
for i in range(max_token_length):
yield prediction

def encode_response(self, outputs):
for output in outputs:
encoded_output = ...
yield encoded_output

def unbatch(self, outputs):
for output in outputs:
unbatched_output = ...
yield unbatched_output
"""
)

if self.stream and not all([
inspect.isgeneratorfunction(self.predict),
inspect.isgeneratorfunction(self.encode_response),
]):
raise ValueError(
"""When `stream=True` both `lit_api.predict` and
`lit_api.encode_response` must generate values using `yield`.

Example:

def predict(self, inputs):
...
for i in range(max_token_length):
yield prediction

def encode_response(self, outputs):
for output in outputs:
encoded_output = ...
yield encoded_output
"""
)
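To pass both checks in `sanitize`, a batched streaming `LitAPI` implements `predict`, `encode_response`, and `unbatch` as generators, plus a `batch` that accepts the list of decoded requests. A minimal sketch, mirroring the docstring examples above; the class name and the token source are illustrative, not part of this diff:

```python
import litserve as ls  # assumed package import name

class MyBatchedStreamAPI(ls.LitAPI):
    def setup(self, device):
        self.vocab = "LitServe is streaming output".split()

    def decode_request(self, request):
        return request["prompt"]

    def batch(self, inputs):
        # plain Python lists are fine; no tensor stacking needed here
        return inputs

    def predict(self, x):
        # yield one list per generation step, one entry per batched request
        for token in self.vocab:
            yield [token] * len(x)

    def unbatch(self, output):
        # pass each per-step batch through unchanged
        yield from output

    def encode_response(self, output):
        for step in output:
            yield [item.lower() for item in step]
```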
70 changes: 43 additions & 27 deletions src/litserve/server.py
@@ -153,11 +153,50 @@ def run_streaming_loop(lit_api, request_queue: Queue, request_buffer):
pipe_s.send((pickle.dumps(e), LitAPIStatus.ERROR))


def run_batched_streaming_loop(lit_api, request_queue: Queue, request_buffer, max_batch_size, batch_timeout):
while True:
batches = collate_requests(
lit_api,
request_queue,
request_buffer,
max_batch_size,
batch_timeout,
)
if not batches:
continue

inputs, pipes = zip(*batches)

try:
x = lit_api.batch(inputs)
y_iter = lit_api.predict(x)
unbatched_iter = lit_api.unbatch(y_iter)
y_enc_iter = lit_api.encode_response(unbatched_iter)

# y_enc_iter -> [[response-1, response-2], [response-1, response-2]]
for y_batch in y_enc_iter:
for y_enc, pipe_s in zip(y_batch, pipes):
with contextlib.suppress(BrokenPipeError):
pipe_s.send((y_enc, LitAPIStatus.OK))

for pipe_s in pipes:
pipe_s.send(("", LitAPIStatus.FINISH_STREAMING))
except Exception as e:
logging.exception(e)
err = pickle.dumps(e)
for pipe_s in pipes:
pipe_s.send((err, LitAPIStatus.ERROR))
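A plain-Python sketch (not LitServe code) of the fan-out described by the `y_enc_iter` comment above: every item yielded by the encode step is a list with one response fragment per client pipe, and each fragment is routed to its pipe before the next step is produced. The pipe names below are stand-ins for the real pipe connections:

```python
def predict(xs):
    # one list per generation step, one entry per batched request
    for step in range(3):
        yield [f"{x}-{step}" for x in xs]

def unbatch(y_iter):
    yield from y_iter

def encode_response(y_iter):
    for ys in y_iter:
        yield [y.upper() for y in ys]

pipes = ["pipe-a", "pipe-b"]  # stand-ins for the per-request pipe connections
for y_batch in encode_response(unbatch(predict(["a", "b"]))):
    for y_enc, pipe in zip(y_batch, pipes):
        print(pipe, "<-", y_enc)  # e.g. pipe-a <- A-0, pipe-b <- B-0, ...
```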


def inference_worker(lit_api, device, worker_id, request_queue, request_buffer, max_batch_size, batch_timeout, stream):
lit_api.setup(device=device)
if stream:
run_streaming_loop(lit_api, request_queue, request_buffer)
if max_batch_size > 1:
run_batched_streaming_loop(lit_api, request_queue, request_buffer, max_batch_size, batch_timeout)
else:
run_streaming_loop(lit_api, request_queue, request_buffer)
return

if max_batch_size > 1:
run_batched_loop(lit_api, request_queue, request_buffer, max_batch_size, batch_timeout)
else:
@@ -227,7 +266,7 @@ async def lifespan(app: FastAPI):


class LitServer:
# TODO: add support for accelerator="auto", devices="auto"
# TODO: add support for devices="auto"
def __init__(
self,
lit_api: LitAPI,
@@ -244,31 +283,8 @@ def __init__(
if max_batch_size <= 0:
raise ValueError("max_batch_size must be greater than 0")

if stream and max_batch_size > 1:
raise ValueError("streaming is not supported with automatic batching at this time.")

if stream and not all([
inspect.isgeneratorfunction(lit_api.predict),
inspect.isgeneratorfunction(lit_api.encode_response),
]):
raise ValueError(
"""When `stream=True` both `lit_api.predict` and
`lit_api.encode_response` must generate values using `yield`.

Example:

def predict(self, inputs):
...
for i in range(max_token_length):
yield prediction

def encode_response(self, outputs):
for output in outputs:
encoded_output = ...
yield encoded_output
"""
)

lit_api.stream = stream
lit_api.sanitize(max_batch_size)
self.app = FastAPI(lifespan=lifespan)
self.app.lit_api = lit_api
self.app.workers_per_device = workers_per_device
33 changes: 33 additions & 0 deletions tests/conftest.py
@@ -53,6 +53,34 @@ def encode_response(self, output: Generator) -> Generator:
yield out.lower()


class SimpleBatchedStreamAPI(LitAPI):
def setup(self, device) -> None:
self.sentence = "LitServe is streaming output"

def decode_request(self, request: Request) -> str:
return request["prompt"]

def batch(self, inputs):
return inputs

def predict(self, x) -> Generator:
n = len(x)
output = self.sentence.split()
responses = [x]
for out in output:
responses.append([out] * n)
yield from responses

def encode_response(self, output: Generator) -> Generator:
delay = 0.01 # delay for testing timeouts
for out in output:
time.sleep(delay)
yield [e.lower() for e in out]

def unbatch(self, output):
yield from output


@pytest.fixture()
def simple_litapi():
return SimpleLitAPI()
@@ -63,6 +91,11 @@ def simple_stream_api():
return SimpleStreamAPI()


@pytest.fixture()
def simple_batched_stream_api():
return SimpleBatchedStreamAPI()


@pytest.fixture()
def lit_server(simple_litapi):
return LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)
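To make the expected output of the batched streaming test below concrete, here is a sketch that drives `SimpleBatchedStreamAPI`'s generator chain directly, outside the server. It assumes the fixture class above is importable; the printed steps are what the server would stream to each client, prompt echoed first, then one lower-cased word per step:

```python
api = SimpleBatchedStreamAPI()
api.setup(device="cpu")

decoded = [api.decode_request({"prompt": "Hello"}),
           api.decode_request({"prompt": "World"})]
for step in api.encode_response(api.unbatch(api.predict(api.batch(decoded)))):
    print(step)
# ['hello', 'world']            <- the prompts are echoed back first
# ['litserve', 'litserve']
# ['is', 'is']
# ['streaming', 'streaming']
# ['output', 'output']
```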
76 changes: 75 additions & 1 deletion tests/test_lit_server.py
@@ -28,7 +28,13 @@
from unittest.mock import patch, MagicMock

from litserve.connector import _Connector
from litserve.server import inference_worker, run_single_loop, run_streaming_loop
from litserve.server import (
inference_worker,
run_single_loop,
run_streaming_loop,
LitAPIStatus,
run_batched_streaming_loop,
)
from litserve.server import LitServer

import pytest
@@ -146,6 +152,22 @@ async def test_stream(simple_stream_api):
assert resp2.text == expected_output2, "Server returns input prompt and generated output which didn't match."


@pytest.mark.asyncio()
async def test_batched_stream_server(simple_batched_stream_api):
server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30)
expected_output1 = "Hello LitServe is streaming output".lower().replace(" ", "")
expected_output2 = "World LitServe is streaming output".lower().replace(" ", "")

async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
resp1 = ac.post("/stream-predict", json={"prompt": "Hello"}, timeout=10)
resp2 = ac.post("/stream-predict", json={"prompt": "World"}, timeout=10)
resp1, resp2 = await asyncio.gather(resp1, resp2)
assert resp1.status_code == 200, "Check if server is running and the request format is valid."
assert resp2.status_code == 200, "Check if server is running and the request format is valid."
assert resp1.text == expected_output1, "Server returns input prompt and generated output which didn't match."
assert resp2.text == expected_output2, "Server returns input prompt and generated output which didn't match."


class FakeStreamPipe:
def __init__(self, num_streamed_outputs):
self.num_streamed_outputs = num_streamed_outputs
@@ -187,6 +209,58 @@ def fake_encode(output):
fake_stream_api.encode_response.assert_called_once()


class FakeBatchedStreamPipe:
def __init__(self, num_streamed_outputs):
self.num_streamed_outputs = num_streamed_outputs
self.count = 0

def send(self, args):
response, status = args
if status == LitAPIStatus.FINISH_STREAMING:
raise StopIteration("interrupt iteration")
if status == LitAPIStatus.ERROR and b"interrupt iteration" in response:
assert self.count == self.num_streamed_outputs, (
f"Loop count must have incremented for " f"{self.num_streamed_outputs} times."
)
raise StopIteration("finish streaming")

assert (
response == f"{self.count}"
), f"streaming loop generates number from 0 to 9 which is sent via Pipe. {args}"
self.count += 1


def test_batched_streaming_loop(loop_args):
num_streamed_outputs = 10

def fake_predict(inputs: list):
n = len(inputs)
for i in range(num_streamed_outputs):
yield [{"output": f"{i}"}] * n

def fake_encode(output_iter):
assert inspect.isgenerator(output_iter), "predict function must be a generator when `stream=True`"
for outputs in output_iter:
yield [output["output"] for output in outputs]

fake_stream_api = MagicMock()
fake_stream_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
fake_stream_api.batch = MagicMock(side_effect=lambda inputs: inputs)
fake_stream_api.predict = MagicMock(side_effect=fake_predict)
fake_stream_api.encode_response = MagicMock(side_effect=fake_encode)
fake_stream_api.unbatch = MagicMock(side_effect=lambda inputs: inputs)

_, requests_queue, request_buffer = loop_args
request_buffer = Manager().dict()
request_buffer[1] = {"prompt": "Hello"}, FakeBatchedStreamPipe(num_streamed_outputs)
request_buffer[2] = {"prompt": "World"}, FakeBatchedStreamPipe(num_streamed_outputs)

with pytest.raises(StopIteration, match="finish streaming"):
run_batched_streaming_loop(fake_stream_api, requests_queue, request_buffer, max_batch_size=2, batch_timeout=2)
fake_stream_api.predict.assert_called_once_with(("Hello", "World"))
fake_stream_api.encode_response.assert_called_once()


def test_litapi_with_stream(simple_litapi):
with pytest.raises(
ValueError,