diff --git a/.github/workflows/ci-minimal-dependency-check.yml b/.github/workflows/ci-minimal-dependency-check.yml index d6e75471..5d8b531c 100644 --- a/.github/workflows/ci-minimal-dependency-check.yml +++ b/.github/workflows/ci-minimal-dependency-check.yml @@ -21,7 +21,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: "3.10" - name: Install LitServe run: | diff --git a/_requirements/test.txt b/_requirements/test.txt index c51c52e4..81031b3b 100644 --- a/_requirements/test.txt +++ b/_requirements/test.txt @@ -11,3 +11,4 @@ lightning >2.0.0 torch >2.0.0 transformers openai>=1.12.0 +numpy <2.0 diff --git a/src/litserve/server.py b/src/litserve/server.py index 3d5b7e49..9179d899 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -20,6 +20,7 @@ import pickle import shutil import sys +import threading import time import uuid from concurrent.futures import ThreadPoolExecutor @@ -43,6 +44,16 @@ from litserve.utils import LitAPIStatus, load_and_raise from collections import deque +mp.allow_connection_pickling() + +try: + import uvloop + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +except ImportError: + print("uvloop is not installed. Falling back to the default asyncio event loop.") + logger = logging.getLogger(__name__) # if defined, it will require clients to auth with X-API-Key in the header @@ -88,11 +99,11 @@ def collate_requests( break try: - uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001)) + response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001)) if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout: - timed_out_uids.append(uid) + timed_out_uids.append((response_queue_id, uid)) else: - payloads.append((uid, x_enc)) + payloads.append((response_queue_id, uid, x_enc)) except Empty: continue @@ -100,10 +111,10 @@ def collate_requests( return payloads, timed_out_uids -def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queue: Queue): +def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]): while True: try: - uid, timestamp, x_enc = request_queue.get(timeout=1.0) + response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) except (Empty, ValueError): continue @@ -115,7 +126,7 @@ def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, re "has been timed out. " "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)." 
) - response_queue.put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) continue try: context = {} @@ -136,7 +147,7 @@ def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, re lit_api.encode_response, y, ) - response_queue.put((uid, (y_enc, LitAPIStatus.OK))) + response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) except Exception as e: logger.exception( "LitAPI ran into an error while processing the request uid=%s.\n" @@ -144,14 +155,14 @@ def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, re uid, ) err_pkl = pickle.dumps(e) - response_queue.put((uid, (err_pkl, LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) def run_batched_loop( lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, - response_queue: Queue, + response_queues: List[Queue], max_batch_size: int, batch_timeout: float, ): @@ -163,18 +174,18 @@ def run_batched_loop( batch_timeout, ) - for uid in timed_out_uids: + for response_queue_id, uid in timed_out_uids: logger.error( f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and " "has been timed out. " "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)." ) - response_queue.put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) if not batches: continue logger.debug(f"{len(batches)} batched requests received") - uids, inputs = zip(*batches) + response_queue_ids, uids, inputs = zip(*batches) try: contexts = [{}] * len(inputs) if hasattr(lit_spec, "populate_context"): @@ -192,10 +203,10 @@ def run_batched_loop( x = lit_api.batch(x) y = _inject_context(contexts, lit_api.predict, x) outputs = lit_api.unbatch(y) - for y, uid, context in zip(outputs, uids, contexts): + for response_queue_id, y, uid, context in zip(response_queue_ids, outputs, uids, contexts): y_enc = _inject_context(context, lit_api.encode_response, y) - response_queue.put((uid, (y_enc, LitAPIStatus.OK))) + response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) except Exception as e: logger.exception( @@ -203,14 +214,14 @@ def run_batched_loop( "Please check the error trace for more details." ) err_pkl = pickle.dumps(e) - for uid in uids: - response_queue.put((uid, (err_pkl, LitAPIStatus.ERROR))) + for response_queue_id, uid in zip(response_queue_ids, uids): + response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) -def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queue: Queue): +def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]): while True: try: - uid, timestamp, x_enc = request_queue.get(timeout=1.0) + response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) logger.debug("uid=%s", uid) except (Empty, ValueError): continue @@ -223,7 +234,7 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, "has been timed out. " "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)." 
) - response_queue.put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) continue try: @@ -247,22 +258,22 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, ) for y_enc in y_enc_gen: y_enc = lit_api.format_encoded_response(y_enc) - response_queue.put((uid, (y_enc, LitAPIStatus.OK))) - response_queue.put((uid, ("", LitAPIStatus.FINISH_STREAMING))) + response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) + response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING))) except Exception as e: logger.exception( "LitAPI ran into an error while processing the streaming request uid=%s.\n" "Please check the error trace for more details.", uid, ) - response_queue.put((uid, (pickle.dumps(e), LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (pickle.dumps(e), LitAPIStatus.ERROR))) def run_batched_streaming_loop( lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, - response_queue: Queue, + response_queues: List[Queue], max_batch_size: int, batch_timeout: float, ): @@ -273,17 +284,17 @@ def run_batched_streaming_loop( max_batch_size, batch_timeout, ) - for uid in timed_out_uids: + for response_queue_id, uid in timed_out_uids: logger.error( f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and " "has been timed out. " "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)." ) - response_queue.put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) if not batches: continue - uids, inputs = zip(*batches) + response_queue_ids, uids, inputs = zip(*batches) try: contexts = [{}] * len(inputs) if hasattr(lit_spec, "populate_context"): @@ -305,12 +316,12 @@ def run_batched_streaming_loop( # y_enc_iter -> [[response-1, response-2], [response-1, response-2]] for y_batch in y_enc_iter: - for y_enc, uid in zip(y_batch, uids): + for response_queue_id, y_enc, uid in zip(response_queue_ids, y_batch, uids): y_enc = lit_api.format_encoded_response(y_enc) - response_queue.put((uid, (y_enc, LitAPIStatus.OK))) + response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) - for uid in uids: - response_queue.put((uid, ("", LitAPIStatus.FINISH_STREAMING))) + for response_queue_id, uid in zip(response_queue_ids, uids): + response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING))) except Exception as e: logger.exception( @@ -318,7 +329,7 @@ def run_batched_streaming_loop( "Please check the error trace for more details." ) err_pkl = pickle.dumps(e) - response_queue.put((uid, (err_pkl, LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) def inference_worker( @@ -327,7 +338,7 @@ def inference_worker( device: str, worker_id: int, request_queue: Queue, - response_queue: Queue, + response_queues: List[Queue], max_batch_size: int, batch_timeout: float, stream: bool, @@ -335,28 +346,29 @@ def inference_worker( ): lit_api.setup(device) lit_api.device = device + + print(f"Setup complete for worker {worker_id}.") + if workers_setup_status: workers_setup_status[worker_id] = True - message = f"Setup complete for worker {worker_id}." 
- print(message) - logger.info(message) + if lit_spec: logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec") if stream: if max_batch_size > 1: - run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queue, max_batch_size, batch_timeout) + run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout) else: - run_streaming_loop(lit_api, lit_spec, request_queue, response_queue) + run_streaming_loop(lit_api, lit_spec, request_queue, response_queues) return if max_batch_size > 1: - run_batched_loop(lit_api, lit_spec, request_queue, response_queue, max_batch_size, batch_timeout) + run_batched_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout) else: run_single_loop( lit_api, lit_spec, request_queue, - response_queue, + response_queues, ) @@ -391,11 +403,7 @@ async def response_queue_to_buffer( else: while True: - try: - uid, payload = await loop.run_in_executor(response_executor, response_queue.get) - except Empty: - await asyncio.sleep(0.0001) - continue + uid, payload = await loop.run_in_executor(response_executor, response_queue.get) event = buffer.pop(uid) buffer[uid] = payload event.set() @@ -434,6 +442,9 @@ def __init__( lit_api.request_timeout = timeout lit_api.sanitize(max_batch_size, spec=spec) self.app = FastAPI(lifespan=self.lifespan) + self.app.response_queue_id = None + self.response_queue_id = None + self.response_buffer = {} # gzip does not play nicely with streaming, see https://github.com/tiangolo/fastapi/discussions/8448 if not stream: self.app.add_middleware(GZipMiddleware, minimum_size=1000) @@ -472,51 +483,35 @@ def __init__( device_list = range(devices) self.devices = [self.device_identifiers(accelerator, device) for device in device_list] + self.workers = self.devices * self.workers_per_device self.setup_server() - @asynccontextmanager - async def lifespan(self, app: FastAPI): + def launch_inference_worker(self, num_uvicorn_servers: int): manager = mp.Manager() - self.request_queue = manager.Queue() - self.response_buffer = {} self.workers_setup_status = manager.dict() + self.request_queue = manager.Queue() - response_queues = [] - tasks: List[asyncio.Task] = [] - - def close_tasks(): - for task in tasks: - task.cancel() - response_executor.shutdown(wait=False, cancel_futures=True) - - try: - pickle.dumps(self.lit_api) - pickle.dumps(self.lit_spec) - - except (pickle.PickleError, AttributeError) as e: - logging.error( - "The LitAPI instance provided to LitServer cannot be moved to a worker because " - "it cannot be pickled. Please ensure all heavy-weight operations, like model " - "creation, are defined in LitAPI's setup method." 
- ) - raise e + self.response_queues = [] + for _ in range(num_uvicorn_servers): + response_queue = manager.Queue() + self.response_queues.append(response_queue) - loop = asyncio.get_running_loop() - response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) + for spec in self._specs: + # Objects of Server class are referenced (not copied) + logging.debug(f"shallow copy for Server is created for for spec {spec}") + server_copy = copy.copy(self) + del server_copy.app + try: + spec.setup(server_copy) + except Exception as e: + raise e process_list = [] - # NOTE: device: str | List[str], the latter in the case a model needs more than one device to run for worker_id, device in enumerate(self.devices * self.workers_per_device): if len(device) == 1: device = device[0] self.workers_setup_status[worker_id] = False - response_queue = manager.Queue() - response_queues.append(response_queue) - - future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor) - task = loop.create_task(future, name=f"Response-reader-{worker_id}") - tasks.append(task) ctx = mp.get_context("spawn") process = ctx.Process( @@ -527,35 +522,30 @@ def close_tasks(): device, worker_id, self.request_queue, - response_queue, + self.response_queues, self.max_batch_size, self.batch_timeout, self.stream, self.workers_setup_status, ), - daemon=True, ) process.start() - process_list.append((process, worker_id)) + process_list.append(process) + return manager, process_list - for spec in self._specs: - # Objects of Server class are referenced (not copied) - logging.debug(f"shallow copy for Server is created for for spec {spec}") - server_copy = copy.copy(self) - del server_copy.app - try: - spec.setup(server_copy) - except Exception as e: - close_tasks() - raise e + @asynccontextmanager + async def lifespan(self, app: FastAPI): + loop = asyncio.get_running_loop() + + response_queue = self.response_queues[app.response_queue_id] + response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) + future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor) + task = loop.create_task(future) yield - close_tasks() - for process, worker_id in process_list: - logging.info(f"terminating worker worker_id={worker_id}") - process.terminate() - manager.shutdown() + task.cancel() + logger.debug("Shutting down response queue to buffer task") def device_identifiers(self, accelerator, device): if isinstance(device, Sequence): @@ -603,10 +593,11 @@ async def health(request: Request) -> Response: return Response(content="not ready", status_code=503) async def predict(request: self.request_type, background_tasks: BackgroundTasks) -> self.response_type: + response_queue_id = self.app.response_queue_id uid = uuid.uuid4() event = asyncio.Event() self.response_buffer[uid] = event - logger.debug(f"Received request uid={uid}") + logger.info(f"Received request uid={uid}") payload = request if self.request_type == Request: @@ -617,7 +608,7 @@ async def predict(request: self.request_type, background_tasks: BackgroundTasks) else: payload = await request.json() - self.request_queue.put((uid, time.monotonic(), payload)) + self.request_queue.put_nowait((response_queue_id, uid, time.monotonic(), payload)) await event.wait() response, status = self.response_buffer.pop(uid) @@ -627,6 +618,7 @@ async def predict(request: self.request_type, background_tasks: BackgroundTasks) return response async def stream_predict(request: 
self.request_type, background_tasks: BackgroundTasks) -> self.response_type: + response_queue_id = self.app.response_queue_id uid = uuid.uuid4() event = asyncio.Event() q = deque() @@ -636,7 +628,7 @@ async def stream_predict(request: self.request_type, background_tasks: Backgroun payload = request if self.request_type == Request: payload = await request.json() - self.request_queue.put((uid, time.monotonic(), payload)) + self.request_queue.put((response_queue_id, uid, time.monotonic(), payload)) return StreamingResponse(self.data_streamer(q, data_available=event)) @@ -675,7 +667,15 @@ def generate_client_file(self): except Exception as e: print(f"Error copying file: {e}") - def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_client_file: bool = True, **kwargs): + def run( + self, + port: Union[str, int] = 8000, + num_uvicorn_servers: Optional[int] = None, + log_level: str = "info", + generate_client_file: bool = True, + uvicorn_worker_type: Optional[str] = None, + **kwargs, + ): if generate_client_file: self.generate_client_file() @@ -688,7 +688,49 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl if not (1024 <= port <= 65535): raise ValueError(port_msg) - uvicorn.run(host="0.0.0.0", port=port, app=self.app, log_level=log_level, **kwargs) + config = uvicorn.Config(app=self.app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) + sockets = [config.bind_socket()] + + if num_uvicorn_servers is None: + num_uvicorn_servers = len(self.workers) + + manager, litserve_workers = self.launch_inference_worker(num_uvicorn_servers) + + if sys.platform == "win32": + uvicorn_worker_type = "thread" + elif uvicorn_worker_type is None: + uvicorn_worker_type = "process" + + servers = self._start_server(port, num_uvicorn_servers, log_level, sockets, uvicorn_worker_type) + + for s in servers: + s.join() + + for w in litserve_workers: + w.terminate() + w.join() + manager.shutdown() + + def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_worker_type): + servers = [] + for response_queue_id in range(num_uvicorn_servers): + self.app.response_queue_id = response_queue_id + if self.lit_spec: + self.lit_spec.response_queue_id = response_queue_id + app = copy.copy(self.app) + + config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level) + server = uvicorn.Server(config=config) + if uvicorn_worker_type == "process": + ctx = mp.get_context("fork") + w = ctx.Process(target=server.run, args=(sockets,)) + elif uvicorn_worker_type == "thread": + w = threading.Thread(target=server.run, args=(sockets,)) + else: + raise ValueError("Invalid value for uvicorn_worker_type. 
Must be 'process' or 'thread'") + w.start() + servers.append(w) + return servers def setup_auth(self): if hasattr(self.lit_api, "authorize") and callable(self.lit_api.authorize): diff --git a/src/litserve/specs/openai.py b/src/litserve/specs/openai.py index daa7a98e..77dc79a9 100644 --- a/src/litserve/specs/openai.py +++ b/src/litserve/specs/openai.py @@ -317,8 +317,8 @@ async def options_chat_completions(self, request: Request): return Response(status_code=200) async def chat_completion(self, request: ChatCompletionRequest, background_tasks: BackgroundTasks): + response_queue_id = self.response_queue_id logger.debug("Received chat completion request %s", request) - uids = [uuid.uuid4() for _ in range(request.n)] self.queues = [] self.events = [] @@ -328,7 +328,7 @@ async def chat_completion(self, request: ChatCompletionRequest, background_tasks q = deque() event = asyncio.Event() self._server.response_buffer[uid] = (q, event) - self._server.request_queue.put((uid, time.monotonic(), request_el)) + self._server.request_queue.put((response_queue_id, uid, time.monotonic(), request_el)) self.queues.append(q) self.events.append(event) diff --git a/tests/conftest.py b/tests/conftest.py index ced27706..85b3a082 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import contextmanager import time import psutil from typing import Generator @@ -97,9 +98,23 @@ def simple_batched_stream_api(): return SimpleBatchedStreamAPI() +@contextmanager +def wrap_litserve_start(server: LitServer): + server.app.response_queue_id = 0 + if server.lit_spec: + server.lit_spec.response_queue_id = 0 + manager, processes = server.launch_inference_worker(num_uvicorn_servers=1) + yield server + for p in processes: + p.terminate() + manager.shutdown() + + @pytest.fixture() def lit_server(simple_litapi): - return LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10) + server = LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10) + with wrap_litserve_start(server) as s: + yield s @pytest.fixture() diff --git a/tests/test_auth.py b/tests/test_auth.py index 1f8bbc6d..1955c4fa 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -18,6 +18,7 @@ from litserve import LitAPI, LitServer import litserve.server +from tests.conftest import wrap_litserve_start class SimpleAuthedLitAPI(LitAPI): @@ -40,8 +41,7 @@ def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())): def test_authorized_custom(): server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: input = {"input": 4.0} response = client.post("/predict", headers={"Authorization": "Bearer 1234"}, json=input) assert response.status_code == 200 @@ -49,8 +49,7 @@ def test_authorized_custom(): def test_not_authorized_custom(): server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: input = {"input": 4.0} response = client.post("/predict", headers={"Authorization": "Bearer wrong"}, json=input) assert response.status_code == 401 @@ -74,7 +73,7 @@ def test_authorized_api_key(): 
litserve.server.LIT_SERVER_API_KEY = "abcd" server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: input = {"input": 4.0} response = client.post("/predict", headers={"X-API-Key": "abcd"}, json=input) assert response.status_code == 200 @@ -86,7 +85,7 @@ def test_not_authorized_api_key(): litserve.server.LIT_SERVER_API_KEY = "abcd" server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: input = {"input": 4.0} response = client.post("/predict", headers={"X-API-Key": "wrong"}, json=input) assert response.status_code == 401 diff --git a/tests/test_batch.py b/tests/test_batch.py index d25b3c9a..8956f55a 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -23,7 +23,7 @@ from httpx import AsyncClient from litserve import LitAPI, LitServer - +from tests.conftest import wrap_litserve_start import torch import torch.nn as nn @@ -86,10 +86,11 @@ async def test_batched(): api = SimpleLitAPI() server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=2, batch_timeout=4) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - response1 = ac.post("/predict", json={"input": 4.0}) - response2 = ac.post("/predict", json={"input": 5.0}) - response1, response2 = await asyncio.gather(response1, response2) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + response1 = ac.post("/predict", json={"input": 4.0}) + response2 = ac.post("/predict", json={"input": 5.0}) + response1, response2 = await asyncio.gather(response1, response2) assert response1.json() == {"output": 9.0} assert response2.json() == {"output": 11.0} @@ -99,11 +100,11 @@ async def test_batched(): async def test_unbatched(): api = SimpleLitAPI2() server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=1) - - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - response1 = ac.post("/predict", json={"input": 4.0}) - response2 = ac.post("/predict", json={"input": 5.0}) - response1, response2 = await asyncio.gather(response1, response2) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + response1 = ac.post("/predict", json={"input": 4.0}) + response2 = ac.post("/predict", json={"input": 5.0}) + response1, response2 = await asyncio.gather(response1, response2) assert response1.json() == {"output": 9.0} assert response2.json() == {"output": 11.0} @@ -127,8 +128,9 @@ def put(self, *args): def test_batched_loop(): requests_queue = Queue() - requests_queue.put(("uuid-1234", time.monotonic(), {"input": 4.0})) - requests_queue.put(("uuid-1235", time.monotonic(), {"input": 5.0})) + response_queue_id = 0 + requests_queue.put((response_queue_id, "uuid-1234", time.monotonic(), {"input": 4.0})) + requests_queue.put((response_queue_id, "uuid-1235", time.monotonic(), {"input": 5.0})) lit_api_mock = MagicMock() lit_api_mock.request_timeout = 2 diff --git a/tests/test_compression.py b/tests/test_compression.py index aa6cb948..0fb162fb 100644 --- a/tests/test_compression.py +++ 
b/tests/test_compression.py @@ -14,7 +14,7 @@ from fastapi import Request, Response from fastapi.testclient import TestClient - +from tests.conftest import wrap_litserve_start from litserve import LitAPI, LitServer # trivially compressible content @@ -38,8 +38,8 @@ def encode_response(self, output) -> Response: def test_compression(): server = LitServer(LargeOutputLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - # compressed - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + # compressed response = client.post("/predict", headers={"Accept-Encoding": "gzip"}, json={}) assert response.status_code == 200 assert response.headers["Content-Encoding"] == "gzip" @@ -47,8 +47,7 @@ def test_compression(): assert 0 < content_length < 100000 assert response.json() == test_output - # uncompressed - with TestClient(server.app) as client: + # uncompressed response = client.post("/predict", headers={"Accept-Encoding": ""}, json={}) assert response.status_code == 200 assert "Content-Encoding" not in response.headers diff --git a/tests/test_examples.py b/tests/test_examples.py index 6f2ded49..3b14793c 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,7 +1,7 @@ import pytest from asgi_lifespan import LifespanManager from httpx import AsyncClient - +from tests.conftest import wrap_litserve_start import litserve as ls @@ -9,24 +9,27 @@ async def test_simple_pytorch_api(): api = ls.examples.SimpleTorchAPI() server = ls.LitServer(api, accelerator="cpu") - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - response = await ac.post("/predict", json={"input": 4.0}) - assert response.json() == {"output": 9.0} + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + response = await ac.post("/predict", json={"input": 4.0}) + assert response.json() == {"output": 9.0} @pytest.mark.asyncio() async def test_simple_batched_api(): api = ls.examples.SimpleBatchedAPI() server = ls.LitServer(api, max_batch_size=4, batch_timeout=0.1) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - response = await ac.post("/predict", json={"input": 4.0}) - assert response.json() == {"output": 16.0} + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + response = await ac.post("/predict", json={"input": 4.0}) + assert response.json() == {"output": 16.0} @pytest.mark.asyncio() async def test_simple_api(): api = ls.examples.SimpleLitAPI() server = ls.LitServer(api) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - response = await ac.post("/predict", json={"input": 4.0}) - assert response.json() == {"output": 16.0} + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + response = await ac.post("/predict", json={"input": 4.0}) + assert response.json() == {"output": 16.0} diff --git a/tests/test_form.py b/tests/test_form.py index 56c49baf..e5026089 100644 --- a/tests/test_form.py +++ b/tests/test_form.py @@ -14,6 +14,7 @@ from fastapi import Request, Response from fastapi.testclient import TestClient +from tests.conftest import 
wrap_litserve_start from litserve import LitAPI, LitServer @@ -39,7 +40,7 @@ def test_multipart_form_data(tmp_path): SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length * 2) ) - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: file_path = f"{tmp_path}/big_file.txt" with open(file_path, "wb") as f: f.write(bytearray([1] * file_length)) @@ -56,7 +57,7 @@ def test_file_too_big(tmp_path): SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length / 2) ) - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: file_path = f"{tmp_path}/big_file.txt" with open(file_path, "wb") as f: f.write(bytearray([1] * file_length)) @@ -86,8 +87,7 @@ def encode_response(self, output) -> Response: def test_urlencoded_form_data(): server = LitServer(SimpleFormLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: file = {"input": "4.0"} response = client.post("/predict", data=file) assert response.json() == {"output": 16.0} diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index d119649a..d0f4719c 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -23,6 +23,7 @@ import torch.nn as nn from queue import Queue from httpx import AsyncClient +from tests.conftest import wrap_litserve_start from unittest.mock import patch, MagicMock import pytest @@ -75,8 +76,8 @@ def test_inference_worker(mock_single_loop, mock_batched_loop): @pytest.fixture() def loop_args(): requests_queue = Queue() - requests_queue.put(("uuid-123", time.monotonic(), 1)) # uid, timestamp, x_enc - requests_queue.put(("uuid-234", time.monotonic(), 2)) + requests_queue.put((0, "uuid-123", time.monotonic(), 1)) # response_queue_id, uid, timestamp, x_enc + requests_queue.put((1, "uuid-234", time.monotonic(), 2)) lit_api_mock = MagicMock() lit_api_mock.request_timeout = 1 @@ -92,10 +93,10 @@ def put(self, item): def test_single_loop(loop_args): lit_api_mock, requests_queue = loop_args lit_api_mock.unbatch.side_effect = None - response_queue = FakeResponseQueue() + response_queues = [FakeResponseQueue()] with pytest.raises(StopIteration, match="exit loop"): - run_single_loop(lit_api_mock, None, requests_queue, response_queue) + run_single_loop(lit_api_mock, None, requests_queue, response_queues) @pytest.mark.asyncio() @@ -104,14 +105,19 @@ async def test_stream(simple_stream_api): expected_output1 = "prompt=Hello generated_output=LitServe is streaming output".lower().replace(" ", "") expected_output2 = "prompt=World generated_output=LitServe is streaming output".lower().replace(" ", "") - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp1 = ac.post("/predict", json={"prompt": "Hello"}, timeout=10) - resp2 = ac.post("/predict", json={"prompt": "World"}, timeout=10) - resp1, resp2 = await asyncio.gather(resp1, resp2) - assert resp1.status_code == 200, "Check if server is running and the request format is valid." - assert resp1.text == expected_output1, "Server returns input prompt and generated output which didn't match." - assert resp2.status_code == 200, "Check if server is running and the request format is valid." 
- assert resp2.text == expected_output2, "Server returns input prompt and generated output which didn't match." + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp1 = ac.post("/predict", json={"prompt": "Hello"}, timeout=10) + resp2 = ac.post("/predict", json={"prompt": "World"}, timeout=10) + resp1, resp2 = await asyncio.gather(resp1, resp2) + assert resp1.status_code == 200, "Check if server is running and the request format is valid." + assert ( + resp1.text == expected_output1 + ), "Server returns input prompt and generated output which didn't match." + assert resp2.status_code == 200, "Check if server is running and the request format is valid." + assert ( + resp2.text == expected_output2 + ), "Server returns input prompt and generated output which didn't match." @pytest.mark.asyncio() @@ -120,14 +126,19 @@ async def test_batched_stream_server(simple_batched_stream_api): expected_output1 = "Hello LitServe is streaming output".lower().replace(" ", "") expected_output2 = "World LitServe is streaming output".lower().replace(" ", "") - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp1 = ac.post("/predict", json={"prompt": "Hello"}, timeout=10) - resp2 = ac.post("/predict", json={"prompt": "World"}, timeout=10) - resp1, resp2 = await asyncio.gather(resp1, resp2) - assert resp1.status_code == 200, "Check if server is running and the request format is valid." - assert resp2.status_code == 200, "Check if server is running and the request format is valid." - assert resp1.text == expected_output1, "Server returns input prompt and generated output which didn't match." - assert resp2.text == expected_output2, "Server returns input prompt and generated output which didn't match." + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp1 = ac.post("/predict", json={"prompt": "Hello"}, timeout=10) + resp2 = ac.post("/predict", json={"prompt": "World"}, timeout=10) + resp1, resp2 = await asyncio.gather(resp1, resp2) + assert resp1.status_code == 200, "Check if server is running and the request format is valid." + assert resp2.status_code == 200, "Check if server is running and the request format is valid." + assert ( + resp1.text == expected_output1 + ), "Server returns input prompt and generated output which didn't match." + assert ( + resp2.text == expected_output2 + ), "Server returns input prompt and generated output which didn't match." 
class FakeStreamResponseQueue: @@ -144,7 +155,7 @@ def put(self, item): self.count += 1 -def test_streaming_loop(loop_args): +def test_streaming_loop(): num_streamed_outputs = 10 def fake_predict(inputs: str): @@ -164,11 +175,11 @@ def fake_encode(output): fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x) requests_queue = Queue() - requests_queue.put(("UUID-1234", time.monotonic(), {"prompt": "Hello"})) - response_queue = FakeStreamResponseQueue(num_streamed_outputs) + requests_queue.put((0, "UUID-1234", time.monotonic(), {"prompt": "Hello"})) + response_queues = [FakeStreamResponseQueue(num_streamed_outputs)] with pytest.raises(StopIteration, match="exit loop"): - run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queue) + run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues) fake_stream_api.predict.assert_called_once_with("Hello") fake_stream_api.encode_response.assert_called_once() @@ -220,13 +231,13 @@ def fake_encode(output_iter): fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x) requests_queue = Queue() - requests_queue.put(("UUID-001", time.monotonic(), {"prompt": "Hello"})) - requests_queue.put(("UUID-002", time.monotonic(), {"prompt": "World"})) - response_queue = FakeBatchStreamResponseQueue(num_streamed_outputs) + requests_queue.put((0, "UUID-001", time.monotonic(), {"prompt": "Hello"})) + requests_queue.put((0, "UUID-002", time.monotonic(), {"prompt": "World"})) + response_queues = [FakeBatchStreamResponseQueue(num_streamed_outputs)] with pytest.raises(StopIteration, match="finish streaming"): run_batched_streaming_loop( - fake_stream_api, fake_stream_api, requests_queue, response_queue, max_batch_size=2, batch_timeout=2 + fake_stream_api, fake_stream_api, requests_queue, response_queues, max_batch_size=2, batch_timeout=2 ) fake_stream_api.predict.assert_called_once_with(["Hello", "World"]) fake_stream_api.encode_response.assert_called_once() @@ -324,10 +335,10 @@ def test_server_run(mock_uvicorn): server.run(port=65536) server.run(port=8000) - mock_uvicorn.run.assert_called() + mock_uvicorn.Config.assert_called() mock_uvicorn.reset_mock() server.run(port="8001") - mock_uvicorn.run.assert_called() + mock_uvicorn.Config.assert_called() class IndentityAPI(ls.examples.SimpleLitAPI): @@ -383,26 +394,29 @@ def dummy_load_and_raise(resp): # Test context injection with single loop api = IndentityAPI() server = LitServer(api) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/predict", json={"input": 5.0}, timeout=10) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/predict", json={"input": 5.0}, timeout=10) assert resp.json()["output"] == 5.0, "output from Identity server must be same as input" # Test context injection with batched loop server = LitServer(IndentityBatchedAPI(), max_batch_size=2, batch_timeout=0.01) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/predict", json={"input": 5.0}, timeout=10) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/predict", json={"input": 5.0}, timeout=10) assert resp.json()["output"] == 5.0, 
"output from Identity server must be same as input" # Test context injection with batched streaming loop server = LitServer(IndentityBatchedStreamingAPI(), max_batch_size=2, batch_timeout=0.01, stream=True) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/predict", json={"input": 5.0}, timeout=10) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/predict", json={"input": 5.0}, timeout=10) assert resp.json()["output"] == 5.0, "output from Identity server must be same as input" server = LitServer(PredictErrorAPI()) - with pytest.raises(TypeError, match=re.escape("predict() missing 1 required positional argument: 'y'")), TestClient( - server.app - ) as client: + with wrap_litserve_start(server) as server, pytest.raises( + TypeError, match=re.escape("predict() missing 1 required positional argument: 'y'") + ), TestClient(server.app) as client: client.post("/predict", json={"input": 5.0}, timeout=10) @@ -412,7 +426,7 @@ def test_custom_api_path(): server = LitServer(ls.examples.SimpleLitAPI(), api_path="/v1/custom_predict") url = server.api_path - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: response = client.post(url, json={"input": 4.0}) assert response.status_code == 200, "Server response should be 200 (OK)" @@ -424,7 +438,7 @@ def decode_request(self, request): def test_http_exception(): server = LitServer(TestHTTPExceptionAPI()) - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: response = client.post("/predict", json={"input": 4.0}) assert response.status_code == 501, "Server raises 501 error" assert response.text == '{"detail":"decode request is bad"}', "decode request is bad" diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py index d017347c..5d1335eb 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic.py @@ -14,7 +14,7 @@ from fastapi.testclient import TestClient from pydantic import BaseModel - +from tests.conftest import wrap_litserve_start from litserve import LitAPI, LitServer @@ -42,7 +42,6 @@ def encode_response(self, output: float) -> PredictResponse: def test_pydantic(): server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=5) - - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: response = client.post("/predict", json={"input": 4.0}) assert response.json() == {"output": 16.0} diff --git a/tests/test_simple.py b/tests/test_simple.py index a632b95e..6232042f 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -25,6 +25,8 @@ from litserve import LitAPI, LitServer +from tests.conftest import wrap_litserve_start + class SimpleLitAPI(LitAPI): def setup(self, device): @@ -40,10 +42,8 @@ def encode_response(self, output) -> Response: return {"output": output} -def test_simple(): - server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=5) - - with TestClient(server.app) as client: +def test_simple(lit_server): + with TestClient(lit_server.app) as client: response = client.post("/predict", json={"input": 4.0}) assert response.json() == {"output": 16.0} @@ -57,7 +57,7 @@ def setup(self, device): def test_workers_health(): server = LitServer(SlowSetupLitAPI(), accelerator="cpu", devices=1, 
timeout=5, workers_per_device=2) - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: response = client.get("/health") assert response.status_code == 503 assert response.text == "not ready" @@ -80,16 +80,13 @@ def make_load_request(server, outputs): outputs.append(response.json()) -def test_load(): - server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=25) - +def test_load(lit_server): from threading import Thread threads = [] - for _ in range(1): outputs = [] - t = Thread(target=make_load_request, args=(server, outputs)) + t = Thread(target=make_load_request, args=(lit_server, outputs)) t.start() threads.append((t, outputs)) @@ -127,46 +124,70 @@ async def test_timeout(): api = SlowLitAPI() # takes 2 second for each prediction server = LitServer(api, accelerator="cpu", devices=1, timeout=1.5) - # case 1: first request completes, second request times out in queue - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - await asyncio.sleep(2) # Give time to start inference workers - response1 = asyncio.create_task(ac.post("/predict", json={"input": 4.0})) - await asyncio.sleep(0.0001) - response2 = asyncio.create_task(ac.post("/predict", json={"input": 5.0})) - await asyncio.wait([response1, response2]) - assert ( - response1.result().status_code == 200 - ), "First request should complete since it's popped from the request queue." - assert ( - response2.result().status_code == 504 - ), "Server takes longer than specified timeout and request should timeout" + with wrap_litserve_start(server) as server: + # case 1: first request completes, second request times out in queue + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + await asyncio.sleep(2) # Give time to start inference workers + response1 = asyncio.create_task(ac.post("/predict", json={"input": 4.0})) + await asyncio.sleep(0.0001) + response2 = asyncio.create_task(ac.post("/predict", json={"input": 5.0})) + await asyncio.wait([response1, response2]) + assert ( + response1.result().status_code == 200 + ), "First request should complete since it's popped from the request queue." + assert ( + response2.result().status_code == 504 + ), "Server takes longer than specified timeout and request should timeout" # Case 2: first 2 requests finish as a batch and third request times out in queue - server = LitServer(SlowBatchAPI(), accelerator="cpu", timeout=1.5, max_batch_size=2, batch_timeout=0.01) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - await asyncio.sleep(2) # Give time to start inference workers - response1 = asyncio.create_task(ac.post("/predict", json={"input": 4.0})) - response2 = asyncio.create_task(ac.post("/predict", json={"input": 5.0})) - await asyncio.sleep(0.0001) - response3 = asyncio.create_task(ac.post("/predict", json={"input": 6.0})) - await asyncio.wait([response1, response2, response3]) - assert ( - response1.result().status_code == 200 - ), "Batch: First request should complete since it's popped from the request queue." - assert ( - response2.result().status_code == 200 - ), "Batch: Second request should complete since it's popped from the request queue." 
- - assert response3.result().status_code == 504, "Batch: Third request was delayed and should fail" + server = LitServer( + SlowBatchAPI(), + accelerator="cpu", + timeout=1.5, + max_batch_size=2, + batch_timeout=0.01, + ) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + await asyncio.sleep(2) # Give time to start inference workers + response1 = asyncio.create_task(ac.post("/predict", json={"input": 4.0})) + response2 = asyncio.create_task(ac.post("/predict", json={"input": 5.0})) + await asyncio.sleep(0.0001) + response3 = asyncio.create_task(ac.post("/predict", json={"input": 6.0})) + await asyncio.wait([response1, response2, response3]) + assert ( + response1.result().status_code == 200 + ), "Batch: First request should complete since it's popped from the request queue." + assert ( + response2.result().status_code == 200 + ), "Batch: Second request should complete since it's popped from the request queue." + + assert response3.result().status_code == 504, "Batch: Third request was delayed and should fail" server1 = LitServer(SlowLitAPI(), accelerator="cpu", devices=1, timeout=-1) server2 = LitServer(SlowLitAPI(), accelerator="cpu", devices=1, timeout=False) - server3 = LitServer(SlowBatchAPI(), accelerator="cpu", devices=1, timeout=False, max_batch_size=2, batch_timeout=2) - server4 = LitServer(SlowBatchAPI(), accelerator="cpu", devices=1, timeout=-1, max_batch_size=2, batch_timeout=2) - - with TestClient(server1.app) as client1, TestClient(server2.app) as client2, TestClient( - server3.app - ) as client3, TestClient(server4.app) as client4: + server3 = LitServer( + SlowBatchAPI(), + accelerator="cpu", + devices=1, + timeout=False, + max_batch_size=2, + batch_timeout=2, + ) + server4 = LitServer( + SlowBatchAPI(), + accelerator="cpu", + devices=1, + timeout=-1, + max_batch_size=2, + batch_timeout=2, + ) + + with wrap_litserve_start(server1) as server1, wrap_litserve_start(server2) as server2, wrap_litserve_start( + server3 + ) as server3, wrap_litserve_start(server4) as server4, TestClient(server1.app) as client1, TestClient( + server2.app + ) as client2, TestClient(server3.app) as client3, TestClient(server4.app) as client4: response1 = client1.post("/predict", json={"input": 4.0}) assert response1.status_code == 200, "Expected slow server to respond since timeout was disabled" @@ -180,10 +201,9 @@ async def test_timeout(): assert response4.status_code == 200, "Expected slow batch server to respond since timeout was disabled" -def test_concurrent_requests(): +def test_concurrent_requests(lit_server): n_requests = 100 - server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - with TestClient(server.app) as client, ThreadPoolExecutor(n_requests // 4 + 1) as executor: + with TestClient(lit_server.app) as client, ThreadPoolExecutor(n_requests // 4 + 1) as executor: responses = list(executor.map(lambda i: client.post("/predict", json={"input": i}), range(n_requests))) count = 0 diff --git a/tests/test_specs.py b/tests/test_specs.py index 5cfe3660..2e7281b5 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -24,6 +24,7 @@ OpenAIBatchingWithUsage, OpenAIWithUsageEncodeResponse, ) +from tests.conftest import wrap_litserve_start from litserve.specs.openai import OpenAISpec, ChatMessage import litserve as ls @@ -32,13 +33,14 @@ async def test_openai_spec(openai_request_data): spec = OpenAISpec() server = ls.LitServer(TestAPI(), spec=spec) - 
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) - assert resp.status_code == 200, "Status code should be 200" + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) + assert resp.status_code == 200, "Status code should be 200" - assert ( - resp.json()["choices"][0]["message"]["content"] == "This is a generated output" - ), "LitAPI predict response should match with the generated output" + assert ( + resp.json()["choices"][0]["message"]["content"] == "This is a generated output" + ), "LitAPI predict response should match with the generated output" # OpenAIWithUsage @@ -53,64 +55,66 @@ async def test_openai_spec(openai_request_data): ) async def test_openai_token_usage(api, batch_size, openai_request_data, openai_response_data): server = ls.LitServer(api, spec=ls.OpenAISpec(), max_batch_size=batch_size, batch_timeout=0.01) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) - assert resp.status_code == 200, "Status code should be 200" - result = resp.json() - content = result["choices"][0]["message"]["content"] - assert content == "10 + 6 is equal to 16.", "LitAPI predict response should match with the generated output" - assert result["usage"] == openai_response_data["usage"] - - # with streaming - openai_request_data["stream"] = True - resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) - assert resp.status_code == 200, "Status code should be 200" - assert result["usage"] == openai_response_data["usage"] + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) + assert resp.status_code == 200, "Status code should be 200" + result = resp.json() + content = result["choices"][0]["message"]["content"] + assert content == "10 + 6 is equal to 16.", "LitAPI predict response should match with the generated output" + assert result["usage"] == openai_response_data["usage"] + + # with streaming + openai_request_data["stream"] = True + resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) + assert resp.status_code == 200, "Status code should be 200" + assert result["usage"] == openai_response_data["usage"] @pytest.mark.asyncio() async def test_openai_spec_with_image(openai_request_data_with_image): - spec = OpenAISpec() - server = ls.LitServer(TestAPI(), spec=spec) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/v1/chat/completions", json=openai_request_data_with_image, timeout=10) - assert resp.status_code == 200, "Status code should be 200" + server = ls.LitServer(TestAPI(), spec=OpenAISpec()) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/chat/completions", json=openai_request_data_with_image, timeout=10) + assert resp.status_code == 200, "Status code 
should be 200" - assert ( - resp.json()["choices"][0]["message"]["content"] == "This is a generated output" - ), "LitAPI predict response should match with the generated output" + assert ( + resp.json()["choices"][0]["message"]["content"] == "This is a generated output" + ), "LitAPI predict response should match with the generated output" @pytest.mark.asyncio() async def test_override_encode(openai_request_data): - spec = OpenAISpec() - server = ls.LitServer(TestAPIWithCustomEncode(), spec=spec) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) - assert resp.status_code == 200, "Status code should be 200" + server = ls.LitServer(TestAPIWithCustomEncode(), spec=OpenAISpec()) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) + assert resp.status_code == 200, "Status code should be 200" - assert ( - resp.json()["choices"][0]["message"]["content"] == "This is a custom encoded output" - ), "LitAPI predict response should match with the generated output" + assert ( + resp.json()["choices"][0]["message"]["content"] == "This is a custom encoded output" + ), "LitAPI predict response should match with the generated output" @pytest.mark.asyncio() async def test_openai_spec_with_tools(openai_request_data_with_tools): spec = OpenAISpec() server = ls.LitServer(TestAPIWithToolCalls(), spec=spec) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/v1/chat/completions", json=openai_request_data_with_tools, timeout=10) - assert resp.status_code == 200, "Status code should be 200" - assert ( - resp.json()["choices"][0]["message"]["content"] == "" - ), "LitAPI predict response should match with the generated output" - assert resp.json()["choices"][0]["message"]["tool_calls"] == [ - { - "id": "call_1", - "type": "function", - "function": {"name": "function_1", "arguments": '{"arg_1": "arg_1_value"}'}, - } - ], "LitAPI predict response should match with the generated output" + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/chat/completions", json=openai_request_data_with_tools, timeout=10) + assert resp.status_code == 200, "Status code should be 200" + assert ( + resp.json()["choices"][0]["message"]["content"] == "" + ), "LitAPI predict response should match with the generated output" + assert resp.json()["choices"][0]["message"]["tool_calls"] == [ + { + "id": "call_1", + "type": "function", + "function": {"name": "function_1", "arguments": '{"arg_1": "arg_1_value"}'}, + } + ], "LitAPI predict response should match with the generated output" class IncorrectAPI1(ls.LitAPI): @@ -131,14 +135,13 @@ def encode_response(self, output): @pytest.mark.asyncio() async def test_openai_spec_validation(openai_request_data): - spec = OpenAISpec() - server = ls.LitServer(IncorrectAPI1(), spec=spec) - with pytest.raises(ValueError, match="predict is not a generator"): + server = ls.LitServer(IncorrectAPI1(), spec=OpenAISpec()) + with pytest.raises(ValueError, match="predict is not a generator"), wrap_litserve_start(server) as server: async with LifespanManager(server.app) 
as manager: await manager.shutdown() - server = ls.LitServer(IncorrectAPI2(), spec=spec) - with pytest.raises(ValueError, match="encode_response is not a generator"): + server = ls.LitServer(IncorrectAPI2(), spec=OpenAISpec()) + with pytest.raises(ValueError, match="encode_response is not a generator"), wrap_litserve_start(server) as server: async with LifespanManager(server.app) as manager: await manager.shutdown() @@ -159,11 +162,12 @@ async def test_oai_prepopulated_context(openai_request_data): openai_request_data["max_tokens"] = 3 spec = OpenAISpec() server = ls.LitServer(PrePopulatedAPI(), spec=spec) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) - assert ( - resp.json()["choices"][0]["message"]["content"] == "This is a" - ), "OpenAISpec must return only 3 tokens as specified using `max_tokens` parameter" + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) + assert ( + resp.json()["choices"][0]["message"]["content"] == "This is a" + ), "OpenAISpec must return only 3 tokens as specified using `max_tokens` parameter" class WrongLitAPI(ls.LitAPI): @@ -178,7 +182,8 @@ def predict(self, prompt): @pytest.mark.asyncio() async def test_fail_http(openai_request_data): server = ls.LitServer(WrongLitAPI(), spec=ls.OpenAISpec()) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - res = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) - assert res.status_code == 501, "Server raises 501 error" - assert res.text == '{"detail":"test LitAPI.predict error"}' + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + res = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) + assert res.status_code == 501, "Server raises 501 error" + assert res.text == '{"detail":"test LitAPI.predict error"}' diff --git a/tests/test_torch.py b/tests/test_torch.py index a32c9bef..7b1f3788 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -15,6 +15,7 @@ from fastapi.testclient import TestClient from litserve import LitAPI, LitServer +from tests.conftest import wrap_litserve_start import torch import torch.nn as nn @@ -51,8 +52,7 @@ def encode_response(self, output) -> Response: def test_torch(): server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=10) - - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: response = client.post("/predict", json={"input": 4.0}) assert response.json() == {"output": 9.0} @@ -60,7 +60,6 @@ def test_torch(): @pytest.mark.skipif(torch.cuda.device_count() == 0, reason="requires CUDA to be available") def test_torch_gpu(): server = LitServer(SimpleLitAPI(), accelerator="cuda", devices=1, timeout=10) - - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: response = client.post("/predict", json={"input": 4.0}) assert response.json() == {"output": 9.0}