bugfix: cover disabled request timeout scenario for collate_requests (#167)

* fix timeout

* fix batch case

* apply feedback

* fix

* update test

* reduce batch timeout

* update

* fix timeout
aniketmaurya authored Jul 11, 2024
1 parent 825d200 commit 0c70c20
Showing 2 changed files with 49 additions and 10 deletions.
12 changes: 7 additions & 5 deletions src/litserve/server.py
@@ -77,6 +77,7 @@ def collate_requests(
timed_out_uids = []
entered_at = time.monotonic()
end_time = entered_at + batch_timeout
apply_timeout = lit_api.request_timeout not in (-1, False)

while time.monotonic() < end_time and len(payloads) < max_batch_size:
remaining_time = end_time - time.monotonic()
@@ -85,10 +86,11 @@

try:
uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001))
-            if time.monotonic() - timestamp <= lit_api.request_timeout:
-                payloads.append((uid, timestamp, x_enc))
-            else:
+            if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout:
                 timed_out_uids.append(uid)
+            else:
+                payloads.append((uid, x_enc))

except Empty:
continue
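
The key behavioural change is the new `apply_timeout` guard: when `lit_api.request_timeout` is `-1` or `False`, queued requests are never discarded as timed out. Below is a minimal sketch of just that decision, using a standalone helper with hypothetical names rather than the real `collate_requests` internals.

```python
# Minimal sketch (not the real collate_requests) of the timeout decision after this
# patch: a request_timeout of -1 or False disables the per-request deadline entirely.
import time


def should_time_out(request_timeout, queued_at, now=None):
    """Return True only when timeouts are enabled and the request waited too long."""
    apply_timeout = request_timeout not in (-1, False)  # same membership check as the patch
    now = time.monotonic() if now is None else now
    return apply_timeout and (now - queued_at) > request_timeout


# The three configurations exercised by the updated test:
assert should_time_out(-1, queued_at=0.0, now=1e9) is False     # disabled via -1
assert should_time_out(False, queued_at=0.0, now=1e9) is False  # disabled via False
assert should_time_out(1.9, queued_at=0.0, now=10.0) is True    # enabled and expired
```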

@@ -169,7 +171,7 @@ def run_batched_loop(
if not batches:
continue
logger.debug(f"{len(batches)} batched requests received")
-        uids, _, inputs = zip(*batches)
+        uids, inputs = zip(*batches)
try:
contexts = [{}] * len(inputs)
if hasattr(lit_spec, "populate_context"):
@@ -278,7 +280,7 @@ def run_batched_streaming_loop(

if not batches:
continue
-        uids, _, inputs = zip(*batches)
+        uids, inputs = zip(*batches)
try:
contexts = [{}] * len(inputs)
if hasattr(lit_spec, "populate_context"):
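Because `collate_requests` now enqueues `(uid, x_enc)` pairs instead of `(uid, timestamp, x_enc)` triples, both `run_batched_loop` and `run_batched_streaming_loop` drop the placeholder from their unpacking, as the two hunks above show. A tiny illustration with made-up payloads:

```python
# Illustration of the unpacking change; the uid and payload values are made up.
batches = [("uid-1", {"input": 4.0}), ("uid-2", {"input": 5.0})]

uids, inputs = zip(*batches)  # previously: uids, _, inputs = zip(*batches)
assert uids == ("uid-1", "uid-2")
assert inputs == ({"input": 4.0}, {"input": 5.0})
```
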
47 changes: 42 additions & 5 deletions tests/test_simple.py
@@ -14,6 +14,7 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import pytest
from asgi_lifespan import LifespanManager
from fastapi import Request, Response
@@ -81,20 +82,28 @@ def decode_request(self, request: Request):
return request["input"]

def predict(self, x):
-        time.sleep(1)
+        time.sleep(2)
return self.model(x)

def encode_response(self, output) -> Response:
return {"output": output}


class SlowBatchAPI(SlowLitAPI):
def batch(self, inputs):
return np.asarray(inputs)

def unbatch(self, output):
return list(output)
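
For context, a standalone illustration of what the new `SlowBatchAPI` hooks do with a pair of queued inputs. The `x**2` model below is an assumption standing in for the toy model these tests use; only the `batch`/`unbatch` behaviour itself comes from the class above.

```python
# batch stacks decoded inputs into a numpy array; unbatch splits the batched output
# back into one item per request. The squaring model is a stand-in, not from the diff.
import numpy as np

decoded_inputs = [4.0, 5.0]           # one decoded value per queued request
batched = np.asarray(decoded_inputs)  # what SlowBatchAPI.batch returns
outputs = batched ** 2                # stand-in for the (slow) model call
per_request = list(outputs)           # what SlowBatchAPI.unbatch returns
assert per_request == [16.0, 25.0]
```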


@pytest.mark.asyncio()
async def test_timeout():
-    api = SlowLitAPI()  # takes 1 second for each prediction
-    server = LitServer(api, accelerator="cpu", devices=1, timeout=0.9)  # windows CI need more time to process queue
+    api = SlowLitAPI()  # takes 2 second for each prediction
+    server = LitServer(api, accelerator="cpu", devices=1, timeout=1.9)

async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
-        await asyncio.sleep(1)  # Give time to start inference workers
+        await asyncio.sleep(2)  # Give time to start inference workers
response1 = ac.post("/predict", json={"input": 4.0})
response2 = ac.post("/predict", json={"input": 5.0})
response1, response2 = await asyncio.gather(response1, response2)
@@ -103,15 +112,43 @@ async def test_timeout():
assert response1.status_code == 200, "First request should complete since it's popped from the request queue."
assert response2.status_code == 504, "Server takes longer than specified timeout and request should timeout"

# Batched Server
server = LitServer(SlowBatchAPI(), accelerator="cpu", timeout=1.9, max_batch_size=2, batch_timeout=0.01)
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
await asyncio.sleep(2) # Give time to start inference workers
response1 = ac.post("/predict", json={"input": 4.0})
response2 = ac.post("/predict", json={"input": 5.0})
response3 = ac.post("/predict", json={"input": 6.0})
response1, response2, response3 = await asyncio.gather(response1, response2, response3)
assert (
response1.status_code == 200
), "Batch: First request should complete since it's popped from the request queue."
assert (
response2.status_code == 200
), "Batch: Second request should complete since it's popped from the request queue."

assert response3.status_code == 504, "Batch: Third request was delayed and should fail"

server1 = LitServer(SlowLitAPI(), accelerator="cpu", devices=1, timeout=-1)
server2 = LitServer(SlowLitAPI(), accelerator="cpu", devices=1, timeout=False)
server3 = LitServer(SlowBatchAPI(), accelerator="cpu", devices=1, timeout=False, max_batch_size=2, batch_timeout=2)
server4 = LitServer(SlowBatchAPI(), accelerator="cpu", devices=1, timeout=-1, max_batch_size=2, batch_timeout=2)

-    with TestClient(server1.app) as client1, TestClient(server2.app) as client2:
+    with TestClient(server1.app) as client1, TestClient(server2.app) as client2, TestClient(
+        server3.app
+    ) as client3, TestClient(server4.app) as client4:
response1 = client1.post("/predict", json={"input": 4.0})
assert response1.status_code == 200, "Expected slow server to respond since timeout was disabled"

response2 = client2.post("/predict", json={"input": 4.0})
assert response2.status_code == 200, "Expected slow server to respond since timeout was disabled"

response3 = client3.post("/predict", json={"input": 4.0})
assert response3.status_code == 200, "Expected slow batch server to respond since timeout was disabled"

response4 = client4.post("/predict", json={"input": 4.0})
assert response4.status_code == 200, "Expected slow batch server to respond since timeout was disabled"


def test_concurrent_requests():
n_requests = 100
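Taken together, the new tests pin down the user-facing behaviour: constructing `LitServer` with `timeout=-1` or `timeout=False` disables the per-request timeout for both plain and batched servers, so slow predictions are served rather than rejected with HTTP 504. A hedged usage sketch along the lines of the test setup; the `litserve as ls` import alias and the `run()` call follow the library's usual pattern rather than anything in this diff:

```python
# Sketch: a deliberately slow API served with the request timeout disabled.
import time

import litserve as ls


class SlowSquareAPI(ls.LitAPI):
    def setup(self, device):
        self.model = lambda x: x**2

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        time.sleep(2)  # slower than any per-request timeout you would normally set
        return self.model(x)

    def encode_response(self, output):
        return {"output": output}


if __name__ == "__main__":
    # timeout=False (or timeout=-1) disables the per-request timeout entirely.
    server = ls.LitServer(SlowSquareAPI(), accelerator="cpu", devices=1, timeout=False)
    server.run(port=8000)
```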
