Enable batch-unbatch by default (#220)
aniketmaurya authored Aug 26, 2024
1 parent 144476b commit a65fadf
Showing 2 changed files with 84 additions and 14 deletions.
17 changes: 3 additions & 14 deletions src/litserve/api.py
@@ -85,30 +85,19 @@ def batch(self, inputs):
 
             return numpy.stack(inputs)
 
-        if self.stream:
-            message = no_batch_unbatch_message_stream(self, inputs)
-        else:
-            message = no_batch_unbatch_message_no_stream(self, inputs)
-        raise NotImplementedError(message)
+        return inputs
 
     @abstractmethod
     def predict(self, x, **kwargs):
         """Run the model on the input and return or yield the output."""
         pass
 
     def _unbatch_no_stream(self, output):
-        if hasattr(output, "__torch_function__") or output.__class__.__name__ == "ndarray":
-            return list(output)
-        message = no_batch_unbatch_message_no_stream(self, output)
-        raise NotImplementedError(message)
+        return list(output)
 
     def _unbatch_stream(self, output_stream):
        for output in output_stream:
-            if hasattr(output, "__torch_function__") or output.__class__.__name__ == "ndarray":
-                yield list(output)
-            else:
-                message = no_batch_unbatch_message_no_stream(self, output)
-                raise NotImplementedError(message)
+            yield list(output)
 
     def unbatch(self, output):
         """Convert a batched output to a list of outputs."""
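With this change, the default LitAPI.batch still stacks torch tensors and numpy arrays but otherwise returns the decoded inputs as-is, and the default unbatch simply returns list(output); neither path raises NotImplementedError anymore. A minimal sketch of a batched API that relies on the new defaults (the class name and toy model are illustrative, not part of this commit):

import litserve as ls


class SquareAPI(ls.LitAPI):
    # No batch()/unbatch() overrides: after this commit, batch() falls
    # through to returning the list of decoded inputs, and unbatch()
    # converts the batched output back into a per-request list.
    def setup(self, device):
        self.model = lambda xs: [x * x for x in xs]

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        # x is the plain Python list produced by the default batch()
        return self.model(x)

    def encode_response(self, output):
        return {"output": output}

Before this commit, a class like this would have raised NotImplementedError as soon as it was served with max_batch_size greater than 1, since its inputs are neither tensors nor ndarrays.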
81 changes: 81 additions & 0 deletions tests/test_litapi.py
@@ -0,0 +1,81 @@
import numpy as np

import litserve as ls


class TestDefaultBatchedAPI(ls.LitAPI):
def setup(self, device) -> None:
self.model = lambda x: len(x)

def decode_request(self, request):
return request["input"]

def predict(self, x):
return self.model(x)

def encode_response(self, output):
return {"output": output}


class TestCustomBatchedAPI(TestDefaultBatchedAPI):
def batch(self, inputs):
return np.stack(inputs)

def unbatch(self, output):
return list(output)


class TestStreamAPI(ls.LitAPI):
def setup(self, device) -> None:
self.model = None

def decode_request(self, request):
return request["input"]

def predict(self, x):
# x is a list of integers
for i in range(4):
yield np.asarray(x) * i

def encode_response(self, output_stream):
for output in output_stream:
output = list(output)
yield [{"output": o} for o in output]


def test_default_batch_unbatch():
api = TestDefaultBatchedAPI()
api._sanitize(max_batch_size=4, spec=None)
inputs = [1, 2, 3, 4]
output = api.batch(inputs)
assert output == inputs, "Default batch should not change input"
assert api.unbatch(output) == inputs, "Default unbatch should not change input"


def test_custom_batch_unbatch():
api = TestCustomBatchedAPI()
api._sanitize(max_batch_size=4, spec=None)
inputs = [1, 2, 3, 4]
output = api.batch(inputs)
assert np.all(output == np.array(inputs)), "Custom batch stacks input as numpy array"
assert api.unbatch(output) == inputs, "Custom unbatch should unstack input as list"


def test_batch_unbatch_stream():
api = TestStreamAPI()
api._sanitize(max_batch_size=4, spec=None)
inputs = [1, 2, 3, 4]
output = api.batch(inputs)
output = api.predict(output)
output = api.unbatch(output)
output = api.encode_response(output)
first_resp = [o["output"] for o in next(output)]
expected_outputs = [[0, 0, 0, 0], [1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]]
assert first_resp == expected_outputs[0], "First response should be 0s"
count = 1
for out, expected_output in zip(output, expected_outputs[1:]):
resp = [o["output"] for o in out]
assert resp == expected_output
count += 1

assert count == 4, "Should have 4 responses"
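Assuming a checkout with this commit applied, the new tests run under plain pytest, for example: pytest tests/test_litapi.py. Note that test_batch_unbatch_stream drives the decode → batch → predict → unbatch → encode pipeline by hand, which is why it consumes the predict and encode_response generators directly.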
