Skip to content

Commit

Permalink
format encoded response (#85)
Browse files Browse the repository at this point in the history
* format encoded response

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
aniketmaurya and pre-commit-ci[bot] authored May 14, 2024
1 parent 7c658e2 commit 56b691c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
10 changes: 10 additions & 0 deletions src/litserve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
from abc import ABC, abstractmethod

from pydantic import BaseModel


def no_batch_unbatch_message_no_stream(obj, data):
return f"""
Expand Down Expand Up @@ -113,6 +116,13 @@ def encode_response(self, output):
"""
pass

def format_encoded_response(self, data):
    """Normalize an encoded response chunk into a wire-ready value.

    Plain dicts are serialized with ``json.dumps`` and pydantic models
    with ``model_dump_json``; anything else (e.g. an already-encoded
    string) is returned untouched.
    """
    if isinstance(data, dict):
        return json.dumps(data)
    return data.model_dump_json() if isinstance(data, BaseModel) else data

@property
def stream(self):
return self._stream
Expand Down
4 changes: 3 additions & 1 deletion src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def run_single_loop(lit_api, request_queue: Queue, request_buffer):
pipe_s.send((pickle.dumps(e), LitAPIStatus.ERROR))


def run_streaming_loop(lit_api, request_queue: Queue, request_buffer):
def run_streaming_loop(lit_api: LitAPI, request_queue: Queue, request_buffer):
while True:
try:
uid = request_queue.get(timeout=1.0)
Expand All @@ -171,6 +171,7 @@ def run_streaming_loop(lit_api, request_queue: Queue, request_buffer):
y_enc_gen = lit_api.encode_response(y_gen)
for y_enc in y_enc_gen:
with contextlib.suppress(BrokenPipeError):
y_enc = lit_api.format_encoded_response(y_enc)
pipe_s.send((y_enc, LitAPIStatus.OK))
with contextlib.suppress(BrokenPipeError):
pipe_s.send(("", LitAPIStatus.FINISH_STREAMING))
Expand Down Expand Up @@ -205,6 +206,7 @@ def run_batched_streaming_loop(lit_api, request_queue: Queue, request_buffer, ma
for y_batch in y_enc_iter:
for y_enc, pipe_s in zip(y_batch, pipes):
with contextlib.suppress(BrokenPipeError):
y_enc = lit_api.format_encoded_response(y_enc)
pipe_s.send((y_enc, LitAPIStatus.OK))

for pipe_s in pipes:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_lit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def fake_encode(output):
fake_stream_api.decode_request = MagicMock(side_effect=lambda x: x["prompt"])
fake_stream_api.predict = MagicMock(side_effect=fake_predict)
fake_stream_api.encode_response = MagicMock(side_effect=fake_encode)
fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x)

_, requests_queue, request_buffer = loop_args
request_buffer = Manager().dict()
Expand Down Expand Up @@ -249,6 +250,7 @@ def fake_encode(output_iter):
fake_stream_api.predict = MagicMock(side_effect=fake_predict)
fake_stream_api.encode_response = MagicMock(side_effect=fake_encode)
fake_stream_api.unbatch = MagicMock(side_effect=lambda inputs: inputs)
fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x)

_, requests_queue, request_buffer = loop_args
request_buffer = Manager().dict()
Expand Down

0 comments on commit 56b691c

Please sign in to comment.