From d348e7978c0d518ed09c8278c70cd5c4059a2265 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 14:09:54 +0000 Subject: [PATCH 01/45] threaded uvicorn run --- src/litserve/server.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 3d5b7e49..18219fa0 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -20,6 +20,7 @@ import pickle import shutil import sys +import threading import time import uuid from concurrent.futures import ThreadPoolExecutor @@ -472,6 +473,7 @@ def __init__( device_list = range(devices) self.devices = [self.device_identifiers(accelerator, device) for device in device_list] + self.workers = self.devices * self.workers_per_device self.setup_server() @asynccontextmanager @@ -688,7 +690,40 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl if not (1024 <= port <= 65535): raise ValueError(port_msg) - uvicorn.run(host="0.0.0.0", port=port, app=self.app, log_level=log_level, **kwargs) + uvicorn.run(host="0.0.0.0", port=port, app=self.app, workers=4, log_level=log_level, **kwargs) + + def runv2(self, port: Union[str, int] = 8000, log_level: str = "info", generate_client_file: bool = True, **kwargs): + if generate_client_file: + self.generate_client_file() + + port_msg = f"port must be a value from 1024 to 65535 but got {port}" + try: + port = int(port) + except ValueError: + raise ValueError(port_msg) + + if not (1024 <= port <= 65535): + raise ValueError(port_msg) + + config = uvicorn.Config(app=self.app, port=8000) + + sockets = [config.bind_socket()] + def run(server, config, sockets: list | None = None) -> None: + # self.config.setup_event_loop() + uvicorn.loops.uvloop.uvloop_setup(use_subprocess=False) + return asyncio.run(server.serve(sockets=sockets)) + + threads = [] + for i in range(4): + server = uvicorn.Server(config=config) + th = threading.Thread(target=server.run, args=(sockets,)) + th.start() + threads.append(th) + + for th in threads: + th.join() + + def setup_auth(self): if hasattr(self.lit_api, "authorize") and callable(self.lit_api.authorize): From ba63390d88ea55df6feff83129562e2a53c361d8 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 14:59:35 +0000 Subject: [PATCH 02/45] update --- src/litserve/server.py | 77 +++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 42 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 18219fa0..3647ffd5 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -43,6 +43,9 @@ from litserve.specs.base import LitSpec from litserve.utils import LitAPIStatus, load_and_raise from collections import deque +import uvloop + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) logger = logging.getLogger(__name__) @@ -333,6 +336,7 @@ def inference_worker( batch_timeout: float, stream: bool, workers_setup_status: Dict[str, bool] = None, + server_callback=None ): lit_api.setup(device) lit_api.device = device @@ -378,6 +382,7 @@ async def response_queue_to_buffer( stream: bool, response_executor: ThreadPoolExecutor, ): + print("running response consumer") loop = asyncio.get_running_loop() if stream: while True: @@ -476,50 +481,27 @@ def __init__( self.workers = self.devices * self.workers_per_device self.setup_server() - @asynccontextmanager - async def lifespan(self, app: FastAPI): + + + async def launch_inference_worker(self): + app = self.app manager = mp.Manager() + app.manager = 
manager self.request_queue = manager.Queue() - self.response_buffer = {} self.workers_setup_status = manager.dict() + self.response_buffer = {} + loop = asyncio.get_running_loop() + self.response_queues = [] - response_queues = [] - tasks: List[asyncio.Task] = [] - - def close_tasks(): - for task in tasks: - task.cancel() - response_executor.shutdown(wait=False, cancel_futures=True) - - try: - pickle.dumps(self.lit_api) - pickle.dumps(self.lit_spec) - - except (pickle.PickleError, AttributeError) as e: - logging.error( - "The LitAPI instance provided to LitServer cannot be moved to a worker because " - "it cannot be pickled. Please ensure all heavy-weight operations, like model " - "creation, are defined in LitAPI's setup method." - ) - raise e - - loop = asyncio.get_running_loop() - response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) - - process_list = [] - # NOTE: device: str | List[str], the latter in the case a model needs more than one device to run for worker_id, device in enumerate(self.devices * self.workers_per_device): if len(device) == 1: device = device[0] self.workers_setup_status[worker_id] = False response_queue = manager.Queue() - response_queues.append(response_queue) - - future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor) - task = loop.create_task(future, name=f"Response-reader-{worker_id}") - tasks.append(task) + self.response_queues.append(response_queue) + process_list = [] ctx = mp.get_context("spawn") process = ctx.Process( target=inference_worker, @@ -550,14 +532,22 @@ def close_tasks(): except Exception as e: close_tasks() raise e - + + @asynccontextmanager + async def lifespan(self, app: FastAPI): + loop = asyncio.get_running_loop() + tasks = [] + response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) + for i, response_queue in enumerate(self.response_queues): + future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor) + task = loop.create_task(future, name=f"Response-reader-{i}") + tasks.append(task) yield - close_tasks() - for process, worker_id in process_list: - logging.info(f"terminating worker worker_id={worker_id}") - process.terminate() - manager.shutdown() + for task in tasks: + task.cancel() + + def device_identifiers(self, accelerator, device): if isinstance(device, Sequence): @@ -619,7 +609,7 @@ async def predict(request: self.request_type, background_tasks: BackgroundTasks) else: payload = await request.json() - self.request_queue.put((uid, time.monotonic(), payload)) + self.request_queue.put_nowait((uid, time.monotonic(), payload)) await event.wait() response, status = self.response_buffer.pop(uid) @@ -690,7 +680,10 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl if not (1024 <= port <= 65535): raise ValueError(port_msg) - uvicorn.run(host="0.0.0.0", port=port, app=self.app, workers=4, log_level=log_level, **kwargs) + loop = asyncio.new_event_loop() + loop.run_until_complete(self.launch_inference_worker()) + + uvicorn.run(host="0.0.0.0", port=port, app=self.app, workers=1, log_level=log_level, **kwargs) def runv2(self, port: Union[str, int] = 8000, log_level: str = "info", generate_client_file: bool = True, **kwargs): if generate_client_file: @@ -705,7 +698,7 @@ def runv2(self, port: Union[str, int] = 8000, log_level: str = "info", generate_ if not (1024 <= port <= 65535): raise ValueError(port_msg) - config = 
uvicorn.Config(app=self.app, port=8000) + config = uvicorn.Config(app=self.app, port=port, log_level=log_level) sockets = [config.bind_socket()] def run(server, config, sockets: list | None = None) -> None: From 5bbd1abce467ba09d462435e8133991b6e43368c Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 16:32:02 +0000 Subject: [PATCH 03/45] update --- src/litserve/server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 3647ffd5..036af531 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -683,7 +683,10 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl loop = asyncio.new_event_loop() loop.run_until_complete(self.launch_inference_worker()) - uvicorn.run(host="0.0.0.0", port=port, app=self.app, workers=1, log_level=log_level, **kwargs) + config = uvicorn.Config(app=self.app, port=port, log_level=log_level) + sockets = [config.bind_socket()] + server = uvicorn.Server(config=config) + server.run(sockets=sockets) def runv2(self, port: Union[str, int] = 8000, log_level: str = "info", generate_client_file: bool = True, **kwargs): if generate_client_file: From 4592ecfb6e2b3f9438f8904cb698185d78b5a43c Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 16:36:17 +0000 Subject: [PATCH 04/45] update --- src/litserve/server.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 036af531..a6534aa6 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -45,6 +45,7 @@ from collections import deque import uvloop +mp.allow_connection_pickling() asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) logger = logging.getLogger(__name__) @@ -336,7 +337,8 @@ def inference_worker( batch_timeout: float, stream: bool, workers_setup_status: Dict[str, bool] = None, - server_callback=None + server_run=None, + sockets=None ): lit_api.setup(device) lit_api.device = device @@ -345,6 +347,9 @@ def inference_worker( message = f"Setup complete for worker {worker_id}." 
print(message) logger.info(message) + if server_run: + server_run(sockets=sockets) + if lit_spec: logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec") if stream: @@ -483,7 +488,7 @@ def __init__( - async def launch_inference_worker(self): + async def launch_inference_worker(self, config, sockets): app = self.app manager = mp.Manager() app.manager = manager @@ -501,6 +506,8 @@ async def launch_inference_worker(self): response_queue = manager.Queue() self.response_queues.append(response_queue) + server = uvicorn.Server(config=config) + process_list = [] ctx = mp.get_context("spawn") process = ctx.Process( @@ -516,6 +523,8 @@ async def launch_inference_worker(self): self.batch_timeout, self.stream, self.workers_setup_status, + server.run, + sockets ), daemon=True, ) @@ -681,11 +690,11 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl raise ValueError(port_msg) loop = asyncio.new_event_loop() - loop.run_until_complete(self.launch_inference_worker()) - + config = uvicorn.Config(app=self.app, port=port, log_level=log_level) sockets = [config.bind_socket()] server = uvicorn.Server(config=config) + loop.run_until_complete(self.launch_inference_worker(config, sockets)) server.run(sockets=sockets) def runv2(self, port: Union[str, int] = 8000, log_level: str = "info", generate_client_file: bool = True, **kwargs): From 2c2177bafe536ba6e9a8c5f06f9e787044aa4aec Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 16:40:30 +0000 Subject: [PATCH 05/45] update --- src/litserve/server.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index a6534aa6..0840387f 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -497,6 +497,7 @@ async def launch_inference_worker(self, config, sockets): self.response_buffer = {} loop = asyncio.get_running_loop() self.response_queues = [] + self.process_list = [] for worker_id, device in enumerate(self.devices * self.workers_per_device): if len(device) == 1: @@ -508,8 +509,8 @@ async def launch_inference_worker(self, config, sockets): server = uvicorn.Server(config=config) - process_list = [] - ctx = mp.get_context("spawn") + + ctx = mp.get_context("fork") process = ctx.Process( target=inference_worker, args=( @@ -529,7 +530,7 @@ async def launch_inference_worker(self, config, sockets): daemon=True, ) process.start() - process_list.append((process, worker_id)) + self.process_list.append((process, worker_id)) for spec in self._specs: # Objects of Server class are referenced (not copied) @@ -607,7 +608,7 @@ async def predict(request: self.request_type, background_tasks: BackgroundTasks) uid = uuid.uuid4() event = asyncio.Event() self.response_buffer[uid] = event - logger.debug(f"Received request uid={uid}") + logger.info(f"Received request uid={uid}") payload = request if self.request_type == Request: @@ -695,7 +696,8 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl sockets = [config.bind_socket()] server = uvicorn.Server(config=config) loop.run_until_complete(self.launch_inference_worker(config, sockets)) - server.run(sockets=sockets) + for p, worker_id in self.process_list: + p.join() def runv2(self, port: Union[str, int] = 8000, log_level: str = "info", generate_client_file: bool = True, **kwargs): if generate_client_file: From dec04bd6b60548ecd831d4a79a894face6a495dc Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 16:49:07 +0000 Subject: [PATCH 06/45] update --- 
src/litserve/server.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 0840387f..9d3b227c 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -402,6 +402,7 @@ async def response_queue_to_buffer( else: while True: + print("running response consumer") try: uid, payload = await loop.run_in_executor(response_executor, response_queue.get) except Empty: @@ -494,8 +495,7 @@ async def launch_inference_worker(self, config, sockets): app.manager = manager self.request_queue = manager.Queue() self.workers_setup_status = manager.dict() - self.response_buffer = {} - loop = asyncio.get_running_loop() + self.response_buffer = {} self.response_queues = [] self.process_list = [] @@ -540,7 +540,6 @@ async def launch_inference_worker(self, config, sockets): try: spec.setup(server_copy) except Exception as e: - close_tasks() raise e @asynccontextmanager @@ -552,6 +551,8 @@ async def lifespan(self, app: FastAPI): future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor) task = loop.create_task(future, name=f"Response-reader-{i}") tasks.append(task) + + print("All tasks started!") yield for task in tasks: From 16c04448ffbbe46f33b162256e233433e315bf9d Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 17:19:07 +0000 Subject: [PATCH 07/45] update --- src/litserve/server.py | 54 +++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 9d3b227c..b9a7ac21 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -337,8 +337,6 @@ def inference_worker( batch_timeout: float, stream: bool, workers_setup_status: Dict[str, bool] = None, - server_run=None, - sockets=None ): lit_api.setup(device) lit_api.device = device @@ -347,9 +345,14 @@ def inference_worker( message = f"Setup complete for worker {worker_id}." 
print(message) logger.info(message) - if server_run: - server_run(sockets=sockets) - + + config = workers_setup_status["config"] + sockets = workers_setup_status["sockets"] + lit_server = create_server(lit_api, lit_spec, config, sockets) + lit_server.response_queue = response_queue + lit_server.request_queue = request_queue + lit_server.workers_setup_status = workers_setup_status + if lit_spec: logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec") if stream: @@ -412,6 +415,13 @@ async def response_queue_to_buffer( buffer[uid] = payload event.set() +def create_server(lit_api, lit_spec, config, sockets): + lit_server = LitServer(lit_api=lit_api, spec=lit_spec) + config.app = lit_server.app + server = uvicorn.Server(config=config) + th = threading.Thread(target=server.run, args=(sockets,), daemon=True) + th.start() + return lit_server class LitServer: def __init__( @@ -428,6 +438,7 @@ def __init__( spec: Optional[LitSpec] = None, max_payload_size=None, ): + self.litserve_locals = locals() if batch_timeout > timeout and timeout not in (False, -1): raise ValueError("batch_timeout must be less than timeout") if max_batch_size <= 0: @@ -488,17 +499,17 @@ def __init__( self.setup_server() - - async def launch_inference_worker(self, config, sockets): - app = self.app + def launch_inference_worker(self, config, sockets): manager = mp.Manager() - app.manager = manager self.request_queue = manager.Queue() self.workers_setup_status = manager.dict() - self.response_buffer = {} self.response_queues = [] self.process_list = [] + config = uvicorn.Config(app=None, port=8000, log_level="info") + self.workers_setup_status["config"] = config + self.workers_setup_status["sockets"] = sockets + for worker_id, device in enumerate(self.devices * self.workers_per_device): if len(device) == 1: device = device[0] @@ -507,10 +518,7 @@ async def launch_inference_worker(self, config, sockets): response_queue = manager.Queue() self.response_queues.append(response_queue) - server = uvicorn.Server(config=config) - - - ctx = mp.get_context("fork") + ctx = mp.get_context("spawn") process = ctx.Process( target=inference_worker, args=( @@ -524,8 +532,6 @@ async def launch_inference_worker(self, config, sockets): self.batch_timeout, self.stream, self.workers_setup_status, - server.run, - sockets ), daemon=True, ) @@ -544,13 +550,14 @@ async def launch_inference_worker(self, config, sockets): @asynccontextmanager async def lifespan(self, app: FastAPI): + self.response_buffer = {} loop = asyncio.get_running_loop() tasks = [] response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) - for i, response_queue in enumerate(self.response_queues): - future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor) - task = loop.create_task(future, name=f"Response-reader-{i}") - tasks.append(task) + + future = response_queue_to_buffer(self.response_queue, self.response_buffer, self.stream, response_executor) + task = loop.create_task(future, name=f"Response-reader") + tasks.append(task) print("All tasks started!") yield @@ -559,7 +566,6 @@ async def lifespan(self, app: FastAPI): task.cancel() - def device_identifiers(self, accelerator, device): if isinstance(device, Sequence): return [f"{accelerator}:{el}" for el in device] @@ -691,12 +697,10 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl if not (1024 <= port <= 65535): raise ValueError(port_msg) - loop = asyncio.new_event_loop() - config = 
uvicorn.Config(app=self.app, port=port, log_level=log_level) sockets = [config.bind_socket()] - server = uvicorn.Server(config=config) - loop.run_until_complete(self.launch_inference_worker(config, sockets)) + + self.launch_inference_worker(config, sockets) for p, worker_id in self.process_list: p.join() From 7254b8852f41723f64c9bbe67444801a2ca9d7d7 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 18:36:23 +0000 Subject: [PATCH 08/45] update --- src/litserve/server.py | 50 ++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index b9a7ac21..483c6ee2 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -345,14 +345,16 @@ def inference_worker( message = f"Setup complete for worker {worker_id}." print(message) logger.info(message) - + config = workers_setup_status["config"] sockets = workers_setup_status["sockets"] - lit_server = create_server(lit_api, lit_spec, config, sockets) + lit_server, th = create_server(lit_api, lit_spec, config, sockets, ) lit_server.response_queue = response_queue lit_server.request_queue = request_queue lit_server.workers_setup_status = workers_setup_status - + lit_server.response_buffer = dict() + th.start() + if lit_spec: logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec") if stream: @@ -390,7 +392,6 @@ async def response_queue_to_buffer( stream: bool, response_executor: ThreadPoolExecutor, ): - print("running response consumer") loop = asyncio.get_running_loop() if stream: while True: @@ -405,23 +406,17 @@ async def response_queue_to_buffer( else: while True: - print("running response consumer") - try: - uid, payload = await loop.run_in_executor(response_executor, response_queue.get) - except Empty: - await asyncio.sleep(0.0001) - continue + uid, payload = await loop.run_in_executor(response_executor, response_queue.get) event = buffer.pop(uid) buffer[uid] = payload event.set() -def create_server(lit_api, lit_spec, config, sockets): - lit_server = LitServer(lit_api=lit_api, spec=lit_spec) +def create_server(lit_api, lit_spec, config, sockets, **kwargs): + lit_server = LitServer(lit_api=lit_api, spec=lit_spec, max_batch_size=8, batch_timeout=0.01) config.app = lit_server.app server = uvicorn.Server(config=config) th = threading.Thread(target=server.run, args=(sockets,), daemon=True) - th.start() - return lit_server + return lit_server, th class LitServer: def __init__( @@ -499,12 +494,12 @@ def __init__( self.setup_server() - def launch_inference_worker(self, config, sockets): - manager = mp.Manager() - self.request_queue = manager.Queue() + def launch_inference_worker(self, manager, config, sockets): + self.workers_setup_status = manager.dict() self.response_queues = [] self.process_list = [] + self.request_queues = [] config = uvicorn.Config(app=None, port=8000, log_level="info") self.workers_setup_status["config"] = config @@ -514,6 +509,9 @@ def launch_inference_worker(self, config, sockets): if len(device) == 1: device = device[0] + request_queue = manager.Queue() + self.request_queues.append(request_queue) + self.workers_setup_status[worker_id] = False response_queue = manager.Queue() self.response_queues.append(response_queue) @@ -526,7 +524,7 @@ def launch_inference_worker(self, config, sockets): self.lit_spec, device, worker_id, - self.request_queue, + request_queue, response_queue, self.max_batch_size, self.batch_timeout, @@ -550,20 +548,17 @@ def launch_inference_worker(self, config, sockets): 
@asynccontextmanager async def lifespan(self, app: FastAPI): - self.response_buffer = {} loop = asyncio.get_running_loop() - tasks = [] + # self.response_buffer = dict() + response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) - future = response_queue_to_buffer(self.response_queue, self.response_buffer, self.stream, response_executor) - task = loop.create_task(future, name=f"Response-reader") - tasks.append(task) + task = loop.create_task(future) print("All tasks started!") yield - for task in tasks: - task.cancel() + task.cancel() def device_identifiers(self, accelerator, device): @@ -688,6 +683,8 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl if generate_client_file: self.generate_client_file() + manager = mp.Manager() + port_msg = f"port must be a value from 1024 to 65535 but got {port}" try: port = int(port) @@ -700,10 +697,11 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl config = uvicorn.Config(app=self.app, port=port, log_level=log_level) sockets = [config.bind_socket()] - self.launch_inference_worker(config, sockets) + self.launch_inference_worker(manager, config, sockets) for p, worker_id in self.process_list: p.join() + def runv2(self, port: Union[str, int] = 8000, log_level: str = "info", generate_client_file: bool = True, **kwargs): if generate_client_file: self.generate_client_file() From b4c653eaf01b4df3ad2c2bb42462883d75b41833 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 20:39:26 +0000 Subject: [PATCH 09/45] clean --- src/litserve/server.py | 50 +++++++----------------------------------- 1 file changed, 8 insertions(+), 42 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 483c6ee2..3562163b 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -340,21 +340,23 @@ def inference_worker( ): lit_api.setup(device) lit_api.device = device - if workers_setup_status: - workers_setup_status[worker_id] = True + message = f"Setup complete for worker {worker_id}." 
print(message) logger.info(message) config = workers_setup_status["config"] sockets = workers_setup_status["sockets"] - lit_server, th = create_server(lit_api, lit_spec, config, sockets, ) + lit_server, th = create_server(lit_api, lit_spec, config, sockets, ) # inits a new FastAPI instance for uvicorn lit_server.response_queue = response_queue lit_server.request_queue = request_queue lit_server.workers_setup_status = workers_setup_status lit_server.response_buffer = dict() th.start() + if workers_setup_status: + workers_setup_status[worker_id] = True + if lit_spec: logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec") if stream: @@ -412,7 +414,7 @@ async def response_queue_to_buffer( event.set() def create_server(lit_api, lit_spec, config, sockets, **kwargs): - lit_server = LitServer(lit_api=lit_api, spec=lit_spec, max_batch_size=8, batch_timeout=0.01) + lit_server = LitServer(lit_api=lit_api, spec=lit_spec) config.app = lit_server.app server = uvicorn.Server(config=config) th = threading.Thread(target=server.run, args=(sockets,), daemon=True) @@ -495,7 +497,6 @@ def __init__( def launch_inference_worker(self, manager, config, sockets): - self.workers_setup_status = manager.dict() self.response_queues = [] self.process_list = [] @@ -509,10 +510,9 @@ def launch_inference_worker(self, manager, config, sockets): if len(device) == 1: device = device[0] + self.workers_setup_status[worker_id] = False request_queue = manager.Queue() self.request_queues.append(request_queue) - - self.workers_setup_status[worker_id] = False response_queue = manager.Queue() self.response_queues.append(response_queue) @@ -694,47 +694,13 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl if not (1024 <= port <= 65535): raise ValueError(port_msg) - config = uvicorn.Config(app=self.app, port=port, log_level=log_level) + config = uvicorn.Config(app=self.app, port=port, log_level=log_level, loop="uvloop") sockets = [config.bind_socket()] self.launch_inference_worker(manager, config, sockets) for p, worker_id in self.process_list: p.join() - - def runv2(self, port: Union[str, int] = 8000, log_level: str = "info", generate_client_file: bool = True, **kwargs): - if generate_client_file: - self.generate_client_file() - - port_msg = f"port must be a value from 1024 to 65535 but got {port}" - try: - port = int(port) - except ValueError: - raise ValueError(port_msg) - - if not (1024 <= port <= 65535): - raise ValueError(port_msg) - - config = uvicorn.Config(app=self.app, port=port, log_level=log_level) - - sockets = [config.bind_socket()] - def run(server, config, sockets: list | None = None) -> None: - # self.config.setup_event_loop() - uvicorn.loops.uvloop.uvloop_setup(use_subprocess=False) - return asyncio.run(server.serve(sockets=sockets)) - - threads = [] - for i in range(4): - server = uvicorn.Server(config=config) - th = threading.Thread(target=server.run, args=(sockets,)) - th.start() - threads.append(th) - - for th in threads: - th.join() - - - def setup_auth(self): if hasattr(self.lit_api, "authorize") and callable(self.lit_api.authorize): return self.lit_api.authorize From 57c0cb1ab00e039d5c1300060caee93dcf08dd53 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 20:41:38 +0000 Subject: [PATCH 10/45] update --- src/litserve/server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 3562163b..19e560b2 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -351,7 +351,6 
@@ def inference_worker( lit_server.response_queue = response_queue lit_server.request_queue = request_queue lit_server.workers_setup_status = workers_setup_status - lit_server.response_buffer = dict() th.start() if workers_setup_status: @@ -549,7 +548,7 @@ def launch_inference_worker(self, manager, config, sockets): @asynccontextmanager async def lifespan(self, app: FastAPI): loop = asyncio.get_running_loop() - # self.response_buffer = dict() + self.response_buffer = dict() response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) future = response_queue_to_buffer(self.response_queue, self.response_buffer, self.stream, response_executor) From 48a1cf042218ac73634dbbb0a8e22a6b1d07ff92 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 20:47:30 +0000 Subject: [PATCH 11/45] update --- src/litserve/server.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 19e560b2..71ae7ed7 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -338,12 +338,11 @@ def inference_worker( stream: bool, workers_setup_status: Dict[str, bool] = None, ): + lit_api.setup(device) lit_api.device = device - message = f"Setup complete for worker {worker_id}." - print(message) - logger.info(message) + print(f"Setup complete for worker {worker_id}.") config = workers_setup_status["config"] sockets = workers_setup_status["sockets"] @@ -554,7 +553,6 @@ async def lifespan(self, app: FastAPI): future = response_queue_to_buffer(self.response_queue, self.response_buffer, self.stream, response_executor) task = loop.create_task(future) - print("All tasks started!") yield task.cancel() From 058f4a7661ccccd74bde2b41b8a1adcb4fa248c9 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 20:48:00 +0000 Subject: [PATCH 12/45] update --- src/litserve/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/litserve/server.py b/src/litserve/server.py index 71ae7ed7..6a2c3bf6 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -556,6 +556,7 @@ async def lifespan(self, app: FastAPI): yield task.cancel() + logger.debug("Shutting down response queue to buffer task") def device_identifiers(self, accelerator, device): From 593c14f4a3725f453d6c8893dc9714fb2b7ade42 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 20:48:49 +0000 Subject: [PATCH 13/45] fixes --- src/litserve/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 6a2c3bf6..b177d90f 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -350,6 +350,7 @@ def inference_worker( lit_server.response_queue = response_queue lit_server.request_queue = request_queue lit_server.workers_setup_status = workers_setup_status + lit_server.response_buffer = dict() th.start() if workers_setup_status: @@ -547,7 +548,6 @@ def launch_inference_worker(self, manager, config, sockets): @asynccontextmanager async def lifespan(self, app: FastAPI): loop = asyncio.get_running_loop() - self.response_buffer = dict() response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) future = response_queue_to_buffer(self.response_queue, self.response_buffer, self.stream, response_executor) From c9a938c2c7993f59794c7e6779c90ba0c70f8117 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 22:31:53 +0000 Subject: [PATCH 14/45] update --- src/litserve/server.py | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index b177d90f..906e8a38 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -416,6 +416,8 @@ def create_server(lit_api, lit_spec, config, sockets, **kwargs): lit_server = LitServer(lit_api=lit_api, spec=lit_spec) config.app = lit_server.app server = uvicorn.Server(config=config) + # ctx = mp.get_context("fork") + # p = ctx.Process(target=server.run, args=(sockets,), daemon=True) th = threading.Thread(target=server.run, args=(sockets,), daemon=True) return lit_server, th @@ -530,7 +532,6 @@ def launch_inference_worker(self, manager, config, sockets): self.stream, self.workers_setup_status, ), - daemon=True, ) process.start() self.process_list.append((process, worker_id)) From 43361f8c3002b49eddd75bd28036dd5e0ba90d08 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 22:56:02 +0000 Subject: [PATCH 15/45] MP --- src/litserve/server.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 906e8a38..abea1cab 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -416,10 +416,10 @@ def create_server(lit_api, lit_spec, config, sockets, **kwargs): lit_server = LitServer(lit_api=lit_api, spec=lit_spec) config.app = lit_server.app server = uvicorn.Server(config=config) - # ctx = mp.get_context("fork") - # p = ctx.Process(target=server.run, args=(sockets,), daemon=True) - th = threading.Thread(target=server.run, args=(sockets,), daemon=True) - return lit_server, th + ctx = mp.get_context("fork") + w = ctx.Process(target=server.run, args=(sockets,), daemon=True) + # w = threading.Thread(target=server.run, args=(sockets,), daemon=True) + return lit_server, w class LitServer: def __init__( From 7a7a9572ea85102a8a5fb462697199a2609e5330 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 31 Jul 2024 23:36:51 +0000 Subject: [PATCH 16/45] update --- src/litserve/server.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index abea1cab..93c251da 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -81,17 +81,22 @@ def get_batch_from_uid(uids, lit_api, request_buffer): def collate_requests( lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float ) -> Tuple[List, List]: + apply_timeout = lit_api.request_timeout not in (-1, False) payloads = [] timed_out_uids = [] + + uid, timestamp, x_enc = request_queue.get(block=True) entered_at = time.monotonic() end_time = entered_at + batch_timeout - apply_timeout = lit_api.request_timeout not in (-1, False) + if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout: + timed_out_uids.append(uid) + else: + payloads.append((uid, x_enc)) while time.monotonic() < end_time and len(payloads) < max_batch_size: remaining_time = end_time - time.monotonic() if remaining_time <= 0: break - try: uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001)) if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout: @@ -416,9 +421,9 @@ def create_server(lit_api, lit_spec, config, sockets, **kwargs): lit_server = LitServer(lit_api=lit_api, spec=lit_spec) config.app = lit_server.app server = uvicorn.Server(config=config) - ctx = mp.get_context("fork") - w = ctx.Process(target=server.run, args=(sockets,), daemon=True) - # w = threading.Thread(target=server.run, args=(sockets,), daemon=True) + # ctx = 
mp.get_context("fork") + # w = ctx.Process(target=server.run, args=(sockets,), daemon=True) + w = threading.Thread(target=server.run, args=(sockets,), daemon=True) return lit_server, w class LitServer: From e0fb5c9e2a74f624acfe85b9de50abae7edc1203 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 00:10:23 +0000 Subject: [PATCH 17/45] update --- src/litserve/server.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 93c251da..d7a111ca 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -78,25 +78,21 @@ def get_batch_from_uid(uids, lit_api, request_buffer): return batches + def collate_requests( lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float ) -> Tuple[List, List]: - apply_timeout = lit_api.request_timeout not in (-1, False) payloads = [] timed_out_uids = [] - - uid, timestamp, x_enc = request_queue.get(block=True) entered_at = time.monotonic() end_time = entered_at + batch_timeout - if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout: - timed_out_uids.append(uid) - else: - payloads.append((uid, x_enc)) + apply_timeout = lit_api.request_timeout not in (-1, False) while time.monotonic() < end_time and len(payloads) < max_batch_size: remaining_time = end_time - time.monotonic() if remaining_time <= 0: break + try: uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001)) if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout: @@ -421,9 +417,9 @@ def create_server(lit_api, lit_spec, config, sockets, **kwargs): lit_server = LitServer(lit_api=lit_api, spec=lit_spec) config.app = lit_server.app server = uvicorn.Server(config=config) - # ctx = mp.get_context("fork") - # w = ctx.Process(target=server.run, args=(sockets,), daemon=True) - w = threading.Thread(target=server.run, args=(sockets,), daemon=True) + ctx = mp.get_context("fork") + w = ctx.Process(target=server.run, args=(sockets,), daemon=True) + # w = threading.Thread(target=server.run, args=(sockets,), daemon=True) return lit_server, w class LitServer: From e4b5c98f75b2626ba702f836296d1b65f2fc7189 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 10:26:10 +0000 Subject: [PATCH 18/45] format --- src/litserve/server.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index d7a111ca..0b2689a2 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -20,7 +20,6 @@ import pickle import shutil import sys -import threading import time import uuid from concurrent.futures import ThreadPoolExecutor @@ -78,7 +77,6 @@ def get_batch_from_uid(uids, lit_api, request_buffer): return batches - def collate_requests( lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float ) -> Tuple[List, List]: @@ -339,7 +337,6 @@ def inference_worker( stream: bool, workers_setup_status: Dict[str, bool] = None, ): - lit_api.setup(device) lit_api.device = device @@ -347,11 +344,16 @@ def inference_worker( config = workers_setup_status["config"] sockets = workers_setup_status["sockets"] - lit_server, th = create_server(lit_api, lit_spec, config, sockets, ) # inits a new FastAPI instance for uvicorn + lit_server, th = create_server( + lit_api, + lit_spec, + config, + sockets, + ) # inits a new FastAPI instance for uvicorn lit_server.response_queue = response_queue lit_server.request_queue = request_queue 
lit_server.workers_setup_status = workers_setup_status - lit_server.response_buffer = dict() + lit_server.response_buffer = {} th.start() if workers_setup_status: @@ -413,6 +415,7 @@ async def response_queue_to_buffer( buffer[uid] = payload event.set() + def create_server(lit_api, lit_spec, config, sockets, **kwargs): lit_server = LitServer(lit_api=lit_api, spec=lit_spec) config.app = lit_server.app @@ -422,6 +425,7 @@ def create_server(lit_api, lit_spec, config, sockets, **kwargs): # w = threading.Thread(target=server.run, args=(sockets,), daemon=True) return lit_server, w + class LitServer: def __init__( self, @@ -497,7 +501,6 @@ def __init__( self.workers = self.devices * self.workers_per_device self.setup_server() - def launch_inference_worker(self, manager, config, sockets): self.workers_setup_status = manager.dict() self.response_queues = [] @@ -546,7 +549,7 @@ def launch_inference_worker(self, manager, config, sockets): spec.setup(server_copy) except Exception as e: raise e - + @asynccontextmanager async def lifespan(self, app: FastAPI): loop = asyncio.get_running_loop() @@ -560,7 +563,6 @@ async def lifespan(self, app: FastAPI): task.cancel() logger.debug("Shutting down response queue to buffer task") - def device_identifiers(self, accelerator, device): if isinstance(device, Sequence): return [f"{accelerator}:{el}" for el in device] From cc5720adeb8bcc86f6cdee577a003932ffe8f27c Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 12:56:48 +0000 Subject: [PATCH 19/45] update --- src/litserve/server.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 0b2689a2..525ac47e 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -511,6 +511,16 @@ def launch_inference_worker(self, manager, config, sockets): self.workers_setup_status["config"] = config self.workers_setup_status["sockets"] = sockets + for spec in self._specs: + # Objects of Server class are referenced (not copied) + logging.debug(f"shallow copy for Server is created for for spec {spec}") + server_copy = copy.copy(self) + del server_copy.app + try: + spec.setup(server_copy) + except Exception as e: + raise e + for worker_id, device in enumerate(self.devices * self.workers_per_device): if len(device) == 1: device = device[0] @@ -540,16 +550,6 @@ def launch_inference_worker(self, manager, config, sockets): process.start() self.process_list.append((process, worker_id)) - for spec in self._specs: - # Objects of Server class are referenced (not copied) - logging.debug(f"shallow copy for Server is created for for spec {spec}") - server_copy = copy.copy(self) - del server_copy.app - try: - spec.setup(server_copy) - except Exception as e: - raise e - @asynccontextmanager async def lifespan(self, app: FastAPI): loop = asyncio.get_running_loop() From e38f300a6313c6a13d1c47dce3d4a52ca897792d Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 13:57:42 +0000 Subject: [PATCH 20/45] single queue --- src/litserve/server.py | 83 ++++++++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 31 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 525ac47e..71b767ec 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -92,11 +92,11 @@ def collate_requests( break try: - uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001)) + worker_id, uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001)) if apply_timeout and 
time.monotonic() - timestamp > lit_api.request_timeout: - timed_out_uids.append(uid) + timed_out_uids.append((worker_id, uid)) else: - payloads.append((uid, x_enc)) + payloads.append((worker_id, uid, x_enc)) except Empty: continue @@ -155,7 +155,7 @@ def run_batched_loop( lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, - response_queue: Queue, + response_queues: List[Queue], max_batch_size: int, batch_timeout: float, ): @@ -167,18 +167,18 @@ def run_batched_loop( batch_timeout, ) - for uid in timed_out_uids: + for worker_id, uid in timed_out_uids: logger.error( f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and " "has been timed out. " "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)." ) - response_queue.put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) + response_queues[worker_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) if not batches: continue logger.debug(f"{len(batches)} batched requests received") - uids, inputs = zip(*batches) + worker_ids, uids, inputs = zip(*batches) try: contexts = [{}] * len(inputs) if hasattr(lit_spec, "populate_context"): @@ -196,10 +196,10 @@ def run_batched_loop( x = lit_api.batch(x) y = _inject_context(contexts, lit_api.predict, x) outputs = lit_api.unbatch(y) - for y, uid, context in zip(outputs, uids, contexts): + for worker_id, y, uid, context in zip(worker_ids, outputs, uids, contexts): y_enc = _inject_context(context, lit_api.encode_response, y) - response_queue.put((uid, (y_enc, LitAPIStatus.OK))) + response_queues[worker_id].put((uid, (y_enc, LitAPIStatus.OK))) except Exception as e: logger.exception( @@ -207,8 +207,8 @@ def run_batched_loop( "Please check the error trace for more details." 
) err_pkl = pickle.dumps(e) - for uid in uids: - response_queue.put((uid, (err_pkl, LitAPIStatus.ERROR))) + for worker_id, uid in zip(worker_ids, uids): + response_queues[worker_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queue: Queue): @@ -344,17 +344,19 @@ def inference_worker( config = workers_setup_status["config"] sockets = workers_setup_status["sockets"] - lit_server, th = create_server( - lit_api, - lit_spec, - config, - sockets, - ) # inits a new FastAPI instance for uvicorn - lit_server.response_queue = response_queue - lit_server.request_queue = request_queue - lit_server.workers_setup_status = workers_setup_status - lit_server.response_buffer = {} - th.start() + # lit_server, th = create_server( + # lit_api, + # lit_spec, + # config, + # sockets, + # ) # inits a new FastAPI instance for uvicorn + # lit_server.response_queue = response_queue + # lit_server.request_queue = request_queue + # lit_server.workers_setup_status = workers_setup_status + # lit_server.response_buffer = {} + # th.start() + + lit_api.worker_id = worker_id if workers_setup_status: workers_setup_status[worker_id] = True @@ -460,6 +462,7 @@ def __init__( lit_api.request_timeout = timeout lit_api.sanitize(max_batch_size, spec=spec) self.app = FastAPI(lifespan=self.lifespan) + self.app.worker_id = None # gzip does not play nicely with streaming, see https://github.com/tiangolo/fastapi/discussions/8448 if not stream: self.app.add_middleware(GZipMiddleware, minimum_size=1000) @@ -505,7 +508,7 @@ def launch_inference_worker(self, manager, config, sockets): self.workers_setup_status = manager.dict() self.response_queues = [] self.process_list = [] - self.request_queues = [] + self.request_queue = manager.Queue() config = uvicorn.Config(app=None, port=8000, log_level="info") self.workers_setup_status["config"] = config @@ -521,15 +524,15 @@ def launch_inference_worker(self, manager, config, sockets): except Exception as e: raise e + for worker_id, device in enumerate(self.devices * self.workers_per_device): + response_queue = manager.Queue() + self.response_queues.append(response_queue) + for worker_id, device in enumerate(self.devices * self.workers_per_device): if len(device) == 1: device = device[0] self.workers_setup_status[worker_id] = False - request_queue = manager.Queue() - self.request_queues.append(request_queue) - response_queue = manager.Queue() - self.response_queues.append(response_queue) ctx = mp.get_context("spawn") process = ctx.Process( @@ -539,8 +542,8 @@ def launch_inference_worker(self, manager, config, sockets): self.lit_spec, device, worker_id, - request_queue, - response_queue, + self.request_queue, + self.response_queues, self.max_batch_size, self.batch_timeout, self.stream, @@ -553,9 +556,11 @@ def launch_inference_worker(self, manager, config, sockets): @asynccontextmanager async def lifespan(self, app: FastAPI): loop = asyncio.get_running_loop() + self.response_buffer = {} + response_queue = self.response_queues[app.worker_id] response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) - future = response_queue_to_buffer(self.response_queue, self.response_buffer, self.stream, response_executor) + future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor) task = loop.create_task(future) yield @@ -609,6 +614,7 @@ async def health(request: Request) -> Response: return Response(content="not ready", status_code=503) async def 
predict(request: self.request_type, background_tasks: BackgroundTasks) -> self.response_type: + worker_id = self.app.worker_id uid = uuid.uuid4() event = asyncio.Event() self.response_buffer[uid] = event @@ -623,7 +629,7 @@ async def predict(request: self.request_type, background_tasks: BackgroundTasks) else: payload = await request.json() - self.request_queue.put_nowait((uid, time.monotonic(), payload)) + self.request_queue.put_nowait((worker_id, uid, time.monotonic(), payload)) await event.wait() response, status = self.response_buffer.pop(uid) @@ -700,6 +706,21 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl sockets = [config.bind_socket()] self.launch_inference_worker(manager, config, sockets) + + servers = [] + for worker_id, device in enumerate(self.devices * self.workers_per_device): + self.app.worker_id = worker_id + app = copy.copy(self.app) + config = uvicorn.Config(app=app, port=port, log_level=log_level, loop="uvloop") + server = uvicorn.Server(config=config) + ctx = mp.get_context("fork") + w = ctx.Process(target=server.run, args=(sockets,)) + w.start() + servers.append(w) + + for s in servers: + s.join() + for p, worker_id in self.process_list: p.join() From 0ee56d8a77f4eb314aeece63e215ed6f7af5daab Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 15:25:00 +0000 Subject: [PATCH 21/45] update --- src/litserve/server.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 71b767ec..2a37ecda 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -462,7 +462,7 @@ def __init__( lit_api.request_timeout = timeout lit_api.sanitize(max_batch_size, spec=spec) self.app = FastAPI(lifespan=self.lifespan) - self.app.worker_id = None + self.app.response_queue_id = None # gzip does not play nicely with streaming, see https://github.com/tiangolo/fastapi/discussions/8448 if not stream: self.app.add_middleware(GZipMiddleware, minimum_size=1000) @@ -506,7 +506,6 @@ def __init__( def launch_inference_worker(self, manager, config, sockets): self.workers_setup_status = manager.dict() - self.response_queues = [] self.process_list = [] self.request_queue = manager.Queue() @@ -524,10 +523,6 @@ def launch_inference_worker(self, manager, config, sockets): except Exception as e: raise e - for worker_id, device in enumerate(self.devices * self.workers_per_device): - response_queue = manager.Queue() - self.response_queues.append(response_queue) - for worker_id, device in enumerate(self.devices * self.workers_per_device): if len(device) == 1: device = device[0] @@ -558,7 +553,7 @@ async def lifespan(self, app: FastAPI): loop = asyncio.get_running_loop() self.response_buffer = {} - response_queue = self.response_queues[app.worker_id] + response_queue = self.response_queues[app.response_queue_id] response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor) task = loop.create_task(future) @@ -614,7 +609,7 @@ async def health(request: Request) -> Response: return Response(content="not ready", status_code=503) async def predict(request: self.request_type, background_tasks: BackgroundTasks) -> self.response_type: - worker_id = self.app.worker_id + response_queue_id = self.app.response_queue_id uid = uuid.uuid4() event = asyncio.Event() self.response_buffer[uid] = event @@ -629,7 +624,7 @@ async def predict(request: 
self.request_type, background_tasks: BackgroundTasks) else: payload = await request.json() - self.request_queue.put_nowait((worker_id, uid, time.monotonic(), payload)) + self.request_queue.put_nowait((response_queue_id, uid, time.monotonic(), payload)) await event.wait() response, status = self.response_buffer.pop(uid) @@ -705,11 +700,19 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl config = uvicorn.Config(app=self.app, port=port, log_level=log_level, loop="uvloop") sockets = [config.bind_socket()] + devices = [self.devices * self.workers_per_device] + num_uvicorn_servers = 8 + + self.response_queues = [] + for response_queue_id in range(num_uvicorn_servers): + response_queue = manager.Queue() + self.response_queues.append(response_queue) + self.launch_inference_worker(manager, config, sockets) servers = [] - for worker_id, device in enumerate(self.devices * self.workers_per_device): - self.app.worker_id = worker_id + for response_queue_id in range(num_uvicorn_servers): + self.app.response_queue_id = response_queue_id app = copy.copy(self.app) config = uvicorn.Config(app=app, port=port, log_level=log_level, loop="uvloop") server = uvicorn.Server(config=config) From 9743ac8f1b273b1bf369a4433e80b3f6187aea9f Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 15:35:24 +0000 Subject: [PATCH 22/45] update --- src/litserve/server.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 2a37ecda..69e2a988 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -682,7 +682,14 @@ def generate_client_file(self): except Exception as e: print(f"Error copying file: {e}") - def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_client_file: bool = True, **kwargs): + def run( + self, + port: Union[str, int] = 8000, + num_uvicorn_servers: Optional[int] = None, + log_level: str = "info", + generate_client_file: bool = True, + **kwargs, + ): if generate_client_file: self.generate_client_file() @@ -700,8 +707,8 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl config = uvicorn.Config(app=self.app, port=port, log_level=log_level, loop="uvloop") sockets = [config.bind_socket()] - devices = [self.devices * self.workers_per_device] - num_uvicorn_servers = 8 + if num_uvicorn_servers is None: + num_uvicorn_servers = len(self.devices * self.workers_per_device) self.response_queues = [] for response_queue_id in range(num_uvicorn_servers): From a80523fdf785e09d37075b7eb372bbdf24599b8b Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 15:37:42 +0000 Subject: [PATCH 23/45] clean up --- src/litserve/server.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 69e2a988..4b8c61b6 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -342,20 +342,6 @@ def inference_worker( print(f"Setup complete for worker {worker_id}.") - config = workers_setup_status["config"] - sockets = workers_setup_status["sockets"] - # lit_server, th = create_server( - # lit_api, - # lit_spec, - # config, - # sockets, - # ) # inits a new FastAPI instance for uvicorn - # lit_server.response_queue = response_queue - # lit_server.request_queue = request_queue - # lit_server.workers_setup_status = workers_setup_status - # lit_server.response_buffer = {} - # th.start() - lit_api.worker_id = worker_id if workers_setup_status: @@ -418,16 +404,6 @@ 
async def response_queue_to_buffer( event.set() -def create_server(lit_api, lit_spec, config, sockets, **kwargs): - lit_server = LitServer(lit_api=lit_api, spec=lit_spec) - config.app = lit_server.app - server = uvicorn.Server(config=config) - ctx = mp.get_context("fork") - w = ctx.Process(target=server.run, args=(sockets,), daemon=True) - # w = threading.Thread(target=server.run, args=(sockets,), daemon=True) - return lit_server, w - - class LitServer: def __init__( self, From 2215fc16fc83bccf46585769132ece0ecc482397 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 15:47:13 +0000 Subject: [PATCH 24/45] update --- src/litserve/server.py | 71 +++++++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 4b8c61b6..cb3cb704 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -92,11 +92,11 @@ def collate_requests( break try: - worker_id, uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001)) + response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=min(remaining_time, 0.001)) if apply_timeout and time.monotonic() - timestamp > lit_api.request_timeout: - timed_out_uids.append((worker_id, uid)) + timed_out_uids.append((response_queue_id, uid)) else: - payloads.append((worker_id, uid, x_enc)) + payloads.append((response_queue_id, uid, x_enc)) except Empty: continue @@ -104,10 +104,10 @@ def collate_requests( return payloads, timed_out_uids -def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queue: Queue): +def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]): while True: try: - uid, timestamp, x_enc = request_queue.get(timeout=1.0) + response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) except (Empty, ValueError): continue @@ -119,7 +119,7 @@ def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, re "has been timed out. " "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)." ) - response_queue.put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) continue try: context = {} @@ -140,7 +140,7 @@ def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, re lit_api.encode_response, y, ) - response_queue.put((uid, (y_enc, LitAPIStatus.OK))) + response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) except Exception as e: logger.exception( "LitAPI ran into an error while processing the request uid=%s.\n" @@ -148,7 +148,7 @@ def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, re uid, ) err_pkl = pickle.dumps(e) - response_queue.put((uid, (err_pkl, LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) def run_batched_loop( @@ -167,18 +167,18 @@ def run_batched_loop( batch_timeout, ) - for worker_id, uid in timed_out_uids: + for response_queue_id, uid in timed_out_uids: logger.error( f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and " "has been timed out. " "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)." 
) - response_queues[worker_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) if not batches: continue logger.debug(f"{len(batches)} batched requests received") - worker_ids, uids, inputs = zip(*batches) + response_queue_ids, uids, inputs = zip(*batches) try: contexts = [{}] * len(inputs) if hasattr(lit_spec, "populate_context"): @@ -196,10 +196,10 @@ def run_batched_loop( x = lit_api.batch(x) y = _inject_context(contexts, lit_api.predict, x) outputs = lit_api.unbatch(y) - for worker_id, y, uid, context in zip(worker_ids, outputs, uids, contexts): + for response_queue_id, y, uid, context in zip(response_queue_ids, outputs, uids, contexts): y_enc = _inject_context(context, lit_api.encode_response, y) - response_queues[worker_id].put((uid, (y_enc, LitAPIStatus.OK))) + response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) except Exception as e: logger.exception( @@ -207,14 +207,14 @@ def run_batched_loop( "Please check the error trace for more details." ) err_pkl = pickle.dumps(e) - for worker_id, uid in zip(worker_ids, uids): - response_queues[worker_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) + for response_queue_id, uid in zip(response_queue_ids, uids): + response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) -def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queue: Queue): +def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]): while True: try: - uid, timestamp, x_enc = request_queue.get(timeout=1.0) + response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) logger.debug("uid=%s", uid) except (Empty, ValueError): continue @@ -227,7 +227,7 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, "has been timed out. " "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)." ) - response_queue.put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) continue try: @@ -251,22 +251,22 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, ) for y_enc in y_enc_gen: y_enc = lit_api.format_encoded_response(y_enc) - response_queue.put((uid, (y_enc, LitAPIStatus.OK))) - response_queue.put((uid, ("", LitAPIStatus.FINISH_STREAMING))) + response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) + response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING))) except Exception as e: logger.exception( "LitAPI ran into an error while processing the streaming request uid=%s.\n" "Please check the error trace for more details.", uid, ) - response_queue.put((uid, (pickle.dumps(e), LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (pickle.dumps(e), LitAPIStatus.ERROR))) def run_batched_streaming_loop( lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, - response_queue: Queue, + response_queues: List[Queue], max_batch_size: int, batch_timeout: float, ): @@ -277,13 +277,13 @@ def run_batched_streaming_loop( max_batch_size, batch_timeout, ) - for uid in timed_out_uids: + for response_queue_id, uid in timed_out_uids: logger.error( f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and " "has been timed out. 
" "You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)." ) - response_queue.put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) if not batches: continue @@ -311,10 +311,10 @@ def run_batched_streaming_loop( for y_batch in y_enc_iter: for y_enc, uid in zip(y_batch, uids): y_enc = lit_api.format_encoded_response(y_enc) - response_queue.put((uid, (y_enc, LitAPIStatus.OK))) + response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) for uid in uids: - response_queue.put((uid, ("", LitAPIStatus.FINISH_STREAMING))) + response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING))) except Exception as e: logger.exception( @@ -322,7 +322,7 @@ def run_batched_streaming_loop( "Please check the error trace for more details." ) err_pkl = pickle.dumps(e) - response_queue.put((uid, (err_pkl, LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) def inference_worker( @@ -331,7 +331,7 @@ def inference_worker( device: str, worker_id: int, request_queue: Queue, - response_queue: Queue, + response_queues: List[Queue], max_batch_size: int, batch_timeout: float, stream: bool, @@ -342,7 +342,7 @@ def inference_worker( print(f"Setup complete for worker {worker_id}.") - lit_api.worker_id = worker_id + # lit_api.worker_id = worker_id if workers_setup_status: workers_setup_status[worker_id] = True @@ -351,19 +351,19 @@ def inference_worker( logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec") if stream: if max_batch_size > 1: - run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queue, max_batch_size, batch_timeout) + run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout) else: - run_streaming_loop(lit_api, lit_spec, request_queue, response_queue) + run_streaming_loop(lit_api, lit_spec, request_queue, response_queues) return if max_batch_size > 1: - run_batched_loop(lit_api, lit_spec, request_queue, response_queue, max_batch_size, batch_timeout) + run_batched_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout) else: run_single_loop( lit_api, lit_spec, request_queue, - response_queue, + response_queues, ) @@ -419,7 +419,6 @@ def __init__( spec: Optional[LitSpec] = None, max_payload_size=None, ): - self.litserve_locals = locals() if batch_timeout > timeout and timeout not in (False, -1): raise ValueError("batch_timeout must be less than timeout") if max_batch_size <= 0: @@ -684,7 +683,7 @@ def run( sockets = [config.bind_socket()] if num_uvicorn_servers is None: - num_uvicorn_servers = len(self.devices * self.workers_per_device) + num_uvicorn_servers = len(self.workers) self.response_queues = [] for response_queue_id in range(num_uvicorn_servers): @@ -707,7 +706,7 @@ def run( for s in servers: s.join() - for p, worker_id in self.process_list: + for p, _ in self.process_list: p.join() def setup_auth(self): From 68dcc06a8513ee1b1bd2c595d9ef25a718b8d63f Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 15:55:14 +0000 Subject: [PATCH 25/45] update --- src/litserve/server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index cb3cb704..0e6d8b8d 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -609,6 +609,7 @@ async def predict(request: 
self.request_type, background_tasks: BackgroundTasks) return response async def stream_predict(request: self.request_type, background_tasks: BackgroundTasks) -> self.response_type: + response_queue_id = self.app.response_queue_id uid = uuid.uuid4() event = asyncio.Event() q = deque() @@ -618,7 +619,7 @@ async def stream_predict(request: self.request_type, background_tasks: Backgroun payload = request if self.request_type == Request: payload = await request.json() - self.request_queue.put((uid, time.monotonic(), payload)) + self.request_queue.put((response_queue_id, uid, time.monotonic(), payload)) return StreamingResponse(self.data_streamer(q, data_available=event)) From 8a9e790633f1e3cab4600551a477317d2b624fe8 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 16:22:07 +0000 Subject: [PATCH 26/45] update --- src/litserve/server.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 0e6d8b8d..1d3db9d6 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -479,15 +479,10 @@ def __init__( self.workers = self.devices * self.workers_per_device self.setup_server() - def launch_inference_worker(self, manager, config, sockets): + def launch_inference_worker(self, manager): self.workers_setup_status = manager.dict() - self.process_list = [] self.request_queue = manager.Queue() - config = uvicorn.Config(app=None, port=8000, log_level="info") - self.workers_setup_status["config"] = config - self.workers_setup_status["sockets"] = sockets - for spec in self._specs: # Objects of Server class are referenced (not copied) logging.debug(f"shallow copy for Server is created for for spec {spec}") @@ -498,6 +493,7 @@ def launch_inference_worker(self, manager, config, sockets): except Exception as e: raise e + process_list = [] for worker_id, device in enumerate(self.devices * self.workers_per_device): if len(device) == 1: device = device[0] @@ -521,7 +517,8 @@ def launch_inference_worker(self, manager, config, sockets): ), ) process.start() - self.process_list.append((process, worker_id)) + process_list.append(process) + return process_list @asynccontextmanager async def lifespan(self, app: FastAPI): @@ -691,8 +688,7 @@ def run( response_queue = manager.Queue() self.response_queues.append(response_queue) - self.launch_inference_worker(manager, config, sockets) - + litserve_workers = self.launch_inference_worker(manager) servers = [] for response_queue_id in range(num_uvicorn_servers): self.app.response_queue_id = response_queue_id @@ -707,8 +703,8 @@ def run( for s in servers: s.join() - for p, _ in self.process_list: - p.join() + for w in litserve_workers: + w.join() def setup_auth(self): if hasattr(self.lit_api, "authorize") and callable(self.lit_api.authorize): From cf6a4784cdaaddb44bb32920f3f31acbdab507d2 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 17:06:41 +0000 Subject: [PATCH 27/45] fix tests --- src/litserve/server.py | 36 +++++++------ tests/conftest.py | 15 +++++- tests/test_auth.py | 39 ++++++++------- tests/test_batch.py | 26 +++++----- tests/test_compression.py | 37 +++++++------- tests/test_examples.py | 23 +++++---- tests/test_form.py | 51 ++++++++++--------- tests/test_lit_server.py | 83 +++++++++++++++++------------- tests/test_simple.py | 103 +++++++++++++++++++------------------- 9 files changed, 230 insertions(+), 183 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 1d3db9d6..f0bcf27c 100644 --- a/src/litserve/server.py 
+++ b/src/litserve/server.py @@ -479,10 +479,16 @@ def __init__( self.workers = self.devices * self.workers_per_device self.setup_server() - def launch_inference_worker(self, manager): + def launch_inference_worker(self, num_uvicorn_servers: int): + manager = mp.Manager() self.workers_setup_status = manager.dict() self.request_queue = manager.Queue() + self.response_queues = [] + for _ in range(num_uvicorn_servers): + response_queue = manager.Queue() + self.response_queues.append(response_queue) + for spec in self._specs: # Objects of Server class are referenced (not copied) logging.debug(f"shallow copy for Server is created for for spec {spec}") @@ -518,7 +524,7 @@ def launch_inference_worker(self, manager): ) process.start() process_list.append(process) - return process_list + return manager, process_list @asynccontextmanager async def lifespan(self, app: FastAPI): @@ -666,8 +672,6 @@ def run( if generate_client_file: self.generate_client_file() - manager = mp.Manager() - port_msg = f"port must be a value from 1024 to 65535 but got {port}" try: port = int(port) @@ -683,12 +687,19 @@ def run( if num_uvicorn_servers is None: num_uvicorn_servers = len(self.workers) - self.response_queues = [] - for response_queue_id in range(num_uvicorn_servers): - response_queue = manager.Queue() - self.response_queues.append(response_queue) + manager, litserve_workers = self.launch_inference_worker(num_uvicorn_servers) + + servers = self._start_server(port, num_uvicorn_servers, log_level, sockets) - litserve_workers = self.launch_inference_worker(manager) + for s in servers: + s.join() + + for w in litserve_workers: + w.terminate() + w.join() + manager.shutdown() + + def _start_server(self, port, num_uvicorn_servers, log_level, sockets): servers = [] for response_queue_id in range(num_uvicorn_servers): self.app.response_queue_id = response_queue_id @@ -699,12 +710,7 @@ def run( w = ctx.Process(target=server.run, args=(sockets,)) w.start() servers.append(w) - - for s in servers: - s.join() - - for w in litserve_workers: - w.join() + return servers def setup_auth(self): if hasattr(self.lit_api, "authorize") and callable(self.lit_api.authorize): diff --git a/tests/conftest.py b/tests/conftest.py index ced27706..dd985796 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from contextlib import contextmanager import time import psutil from typing import Generator @@ -97,9 +98,21 @@ def simple_batched_stream_api(): return SimpleBatchedStreamAPI() +@contextmanager +def wrap_litserve_start(server): + server.app.response_queue_id = 0 + manager, processes = server.launch_inference_worker(1) + yield server + for p in processes: + p.terminate() + manager.shutdown() + + @pytest.fixture() def lit_server(simple_litapi): - return LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10) + server = LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10) + with wrap_litserve_start(server) as s: + yield s @pytest.fixture() diff --git a/tests/test_auth.py b/tests/test_auth.py index 1f8bbc6d..f6a71e1d 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -18,6 +18,7 @@ from litserve import LitAPI, LitServer import litserve.server +from tests.conftest import wrap_litserve_start class SimpleAuthedLitAPI(LitAPI): @@ -40,20 +41,20 @@ def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())): def test_authorized_custom(): server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - - with TestClient(server.app) as client: - input = {"input": 4.0} - response = client.post("/predict", headers={"Authorization": "Bearer 1234"}, json=input) - assert response.status_code == 200 + with wrap_litserve_start(server) as server: + with TestClient(server.app) as client: + input = {"input": 4.0} + response = client.post("/predict", headers={"Authorization": "Bearer 1234"}, json=input) + assert response.status_code == 200 def test_not_authorized_custom(): server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - - with TestClient(server.app) as client: - input = {"input": 4.0} - response = client.post("/predict", headers={"Authorization": "Bearer wrong"}, json=input) - assert response.status_code == 401 + with wrap_litserve_start(server) as server: + with TestClient(server.app) as client: + input = {"input": 4.0} + response = client.post("/predict", headers={"Authorization": "Bearer wrong"}, json=input) + assert response.status_code == 401 class SimpleLitAPI(LitAPI): @@ -74,10 +75,11 @@ def test_authorized_api_key(): litserve.server.LIT_SERVER_API_KEY = "abcd" server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - with TestClient(server.app) as client: - input = {"input": 4.0} - response = client.post("/predict", headers={"X-API-Key": "abcd"}, json=input) - assert response.status_code == 200 + with wrap_litserve_start(server) as server: + with TestClient(server.app) as client: + input = {"input": 4.0} + response = client.post("/predict", headers={"X-API-Key": "abcd"}, json=input) + assert response.status_code == 200 litserve.server.LIT_SERVER_API_KEY = None @@ -86,9 +88,10 @@ def test_not_authorized_api_key(): litserve.server.LIT_SERVER_API_KEY = "abcd" server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - with TestClient(server.app) as client: - input = {"input": 4.0} - response = client.post("/predict", headers={"X-API-Key": "wrong"}, json=input) - assert response.status_code == 401 + with wrap_litserve_start(server) as server: + with TestClient(server.app) as client: + input = {"input": 4.0} + response = client.post("/predict", headers={"X-API-Key": "wrong"}, json=input) + assert response.status_code == 401 litserve.server.LIT_SERVER_API_KEY = None diff --git a/tests/test_batch.py b/tests/test_batch.py 
index d25b3c9a..8956f55a 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -23,7 +23,7 @@ from httpx import AsyncClient from litserve import LitAPI, LitServer - +from tests.conftest import wrap_litserve_start import torch import torch.nn as nn @@ -86,10 +86,11 @@ async def test_batched(): api = SimpleLitAPI() server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=2, batch_timeout=4) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - response1 = ac.post("/predict", json={"input": 4.0}) - response2 = ac.post("/predict", json={"input": 5.0}) - response1, response2 = await asyncio.gather(response1, response2) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + response1 = ac.post("/predict", json={"input": 4.0}) + response2 = ac.post("/predict", json={"input": 5.0}) + response1, response2 = await asyncio.gather(response1, response2) assert response1.json() == {"output": 9.0} assert response2.json() == {"output": 11.0} @@ -99,11 +100,11 @@ async def test_batched(): async def test_unbatched(): api = SimpleLitAPI2() server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=1) - - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - response1 = ac.post("/predict", json={"input": 4.0}) - response2 = ac.post("/predict", json={"input": 5.0}) - response1, response2 = await asyncio.gather(response1, response2) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + response1 = ac.post("/predict", json={"input": 4.0}) + response2 = ac.post("/predict", json={"input": 5.0}) + response1, response2 = await asyncio.gather(response1, response2) assert response1.json() == {"output": 9.0} assert response2.json() == {"output": 11.0} @@ -127,8 +128,9 @@ def put(self, *args): def test_batched_loop(): requests_queue = Queue() - requests_queue.put(("uuid-1234", time.monotonic(), {"input": 4.0})) - requests_queue.put(("uuid-1235", time.monotonic(), {"input": 5.0})) + response_queue_id = 0 + requests_queue.put((response_queue_id, "uuid-1234", time.monotonic(), {"input": 4.0})) + requests_queue.put((response_queue_id, "uuid-1235", time.monotonic(), {"input": 5.0})) lit_api_mock = MagicMock() lit_api_mock.request_timeout = 2 diff --git a/tests/test_compression.py b/tests/test_compression.py index aa6cb948..3979d29d 100644 --- a/tests/test_compression.py +++ b/tests/test_compression.py @@ -14,7 +14,7 @@ from fastapi import Request, Response from fastapi.testclient import TestClient - +from tests.conftest import wrap_litserve_start from litserve import LitAPI, LitServer # trivially compressible content @@ -38,20 +38,21 @@ def encode_response(self, output) -> Response: def test_compression(): server = LitServer(LargeOutputLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - # compressed - with TestClient(server.app) as client: - response = client.post("/predict", headers={"Accept-Encoding": "gzip"}, json={}) - assert response.status_code == 200 - assert response.headers["Content-Encoding"] == "gzip" - content_length = int(response.headers["Content-Length"]) - assert 0 < content_length < 100000 - assert response.json() == test_output - - # uncompressed - with TestClient(server.app) as client: - response = 
client.post("/predict", headers={"Accept-Encoding": ""}, json={}) - assert response.status_code == 200 - assert "Content-Encoding" not in response.headers - content_length = int(response.headers["Content-Length"]) - assert content_length > 100000 - assert response.json() == test_output + with wrap_litserve_start(server) as server: + # compressed + with TestClient(server.app) as client: + response = client.post("/predict", headers={"Accept-Encoding": "gzip"}, json={}) + assert response.status_code == 200 + assert response.headers["Content-Encoding"] == "gzip" + content_length = int(response.headers["Content-Length"]) + assert 0 < content_length < 100000 + assert response.json() == test_output + + # uncompressed + with TestClient(server.app) as client: + response = client.post("/predict", headers={"Accept-Encoding": ""}, json={}) + assert response.status_code == 200 + assert "Content-Encoding" not in response.headers + content_length = int(response.headers["Content-Length"]) + assert content_length > 100000 + assert response.json() == test_output diff --git a/tests/test_examples.py b/tests/test_examples.py index 6f2ded49..3b14793c 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,7 +1,7 @@ import pytest from asgi_lifespan import LifespanManager from httpx import AsyncClient - +from tests.conftest import wrap_litserve_start import litserve as ls @@ -9,24 +9,27 @@ async def test_simple_pytorch_api(): api = ls.examples.SimpleTorchAPI() server = ls.LitServer(api, accelerator="cpu") - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - response = await ac.post("/predict", json={"input": 4.0}) - assert response.json() == {"output": 9.0} + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + response = await ac.post("/predict", json={"input": 4.0}) + assert response.json() == {"output": 9.0} @pytest.mark.asyncio() async def test_simple_batched_api(): api = ls.examples.SimpleBatchedAPI() server = ls.LitServer(api, max_batch_size=4, batch_timeout=0.1) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - response = await ac.post("/predict", json={"input": 4.0}) - assert response.json() == {"output": 16.0} + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + response = await ac.post("/predict", json={"input": 4.0}) + assert response.json() == {"output": 16.0} @pytest.mark.asyncio() async def test_simple_api(): api = ls.examples.SimpleLitAPI() server = ls.LitServer(api) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - response = await ac.post("/predict", json={"input": 4.0}) - assert response.json() == {"output": 16.0} + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + response = await ac.post("/predict", json={"input": 4.0}) + assert response.json() == {"output": 16.0} diff --git a/tests/test_form.py b/tests/test_form.py index 56c49baf..9eac31b1 100644 --- a/tests/test_form.py +++ b/tests/test_form.py @@ -14,6 +14,7 @@ from fastapi import Request, Response from fastapi.testclient import TestClient +from tests.conftest import wrap_litserve_start from litserve import LitAPI, 
LitServer @@ -39,14 +40,15 @@ def test_multipart_form_data(tmp_path): SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length * 2) ) - with TestClient(server.app) as client: - file_path = f"{tmp_path}/big_file.txt" - with open(file_path, "wb") as f: - f.write(bytearray([1] * file_length)) - with open(file_path, "rb") as f: - file = {"input": f} - response = client.post("/predict", files=file) - assert response.json() == {"output": file_length**2} + with wrap_litserve_start(server) as server: + with TestClient(server.app) as client: + file_path = f"{tmp_path}/big_file.txt" + with open(file_path, "wb") as f: + f.write(bytearray([1] * file_length)) + with open(file_path, "rb") as f: + file = {"input": f} + response = client.post("/predict", files=file) + assert response.json() == {"output": file_length**2} def test_file_too_big(tmp_path): @@ -56,18 +58,19 @@ def test_file_too_big(tmp_path): SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length / 2) ) - with TestClient(server.app) as client: - file_path = f"{tmp_path}/big_file.txt" - with open(file_path, "wb") as f: - f.write(bytearray([1] * file_length)) - with open(file_path, "rb") as f: - file = {"input": f} - response = client.post("/predict", files=file) - assert response.status_code == 413 + with wrap_litserve_start(server) as server: + with TestClient(server.app) as client: + file_path = f"{tmp_path}/big_file.txt" + with open(file_path, "wb") as f: + f.write(bytearray([1] * file_length)) + with open(file_path, "rb") as f: + file = {"input": f} + response = client.post("/predict", files=file) + assert response.status_code == 413 - # spoof content-length size - response = client.post("/predict", files=file, headers={"content-length": "1024"}) - assert response.status_code == 413 + # spoof content-length size + response = client.post("/predict", files=file, headers={"content-length": "1024"}) + assert response.status_code == 413 class SimpleFormLitAPI(LitAPI): @@ -86,8 +89,8 @@ def encode_response(self, output) -> Response: def test_urlencoded_form_data(): server = LitServer(SimpleFormLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - - with TestClient(server.app) as client: - file = {"input": "4.0"} - response = client.post("/predict", data=file) - assert response.json() == {"output": 16.0} + with wrap_litserve_start(server) as server: + with TestClient(server.app) as client: + file = {"input": "4.0"} + response = client.post("/predict", data=file) + assert response.json() == {"output": 16.0} diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index d119649a..dd019c6d 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -23,6 +23,7 @@ import torch.nn as nn from queue import Queue from httpx import AsyncClient +from tests.conftest import wrap_litserve_start from unittest.mock import patch, MagicMock import pytest @@ -75,8 +76,8 @@ def test_inference_worker(mock_single_loop, mock_batched_loop): @pytest.fixture() def loop_args(): requests_queue = Queue() - requests_queue.put(("uuid-123", time.monotonic(), 1)) # uid, timestamp, x_enc - requests_queue.put(("uuid-234", time.monotonic(), 2)) + requests_queue.put((0, "uuid-123", time.monotonic(), 1)) # response_queue_id, uid, timestamp, x_enc + requests_queue.put((1, "uuid-234", time.monotonic(), 2)) lit_api_mock = MagicMock() lit_api_mock.request_timeout = 1 @@ -103,15 +104,19 @@ async def test_stream(simple_stream_api): server = LitServer(simple_stream_api, 
stream=True, timeout=10) expected_output1 = "prompt=Hello generated_output=LitServe is streaming output".lower().replace(" ", "") expected_output2 = "prompt=World generated_output=LitServe is streaming output".lower().replace(" ", "") - - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp1 = ac.post("/predict", json={"prompt": "Hello"}, timeout=10) - resp2 = ac.post("/predict", json={"prompt": "World"}, timeout=10) - resp1, resp2 = await asyncio.gather(resp1, resp2) - assert resp1.status_code == 200, "Check if server is running and the request format is valid." - assert resp1.text == expected_output1, "Server returns input prompt and generated output which didn't match." - assert resp2.status_code == 200, "Check if server is running and the request format is valid." - assert resp2.text == expected_output2, "Server returns input prompt and generated output which didn't match." + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp1 = ac.post("/predict", json={"prompt": "Hello"}, timeout=10) + resp2 = ac.post("/predict", json={"prompt": "World"}, timeout=10) + resp1, resp2 = await asyncio.gather(resp1, resp2) + assert resp1.status_code == 200, "Check if server is running and the request format is valid." + assert ( + resp1.text == expected_output1 + ), "Server returns input prompt and generated output which didn't match." + assert resp2.status_code == 200, "Check if server is running and the request format is valid." + assert ( + resp2.text == expected_output2 + ), "Server returns input prompt and generated output which didn't match." @pytest.mark.asyncio() @@ -120,14 +125,19 @@ async def test_batched_stream_server(simple_batched_stream_api): expected_output1 = "Hello LitServe is streaming output".lower().replace(" ", "") expected_output2 = "World LitServe is streaming output".lower().replace(" ", "") - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp1 = ac.post("/predict", json={"prompt": "Hello"}, timeout=10) - resp2 = ac.post("/predict", json={"prompt": "World"}, timeout=10) - resp1, resp2 = await asyncio.gather(resp1, resp2) - assert resp1.status_code == 200, "Check if server is running and the request format is valid." - assert resp2.status_code == 200, "Check if server is running and the request format is valid." - assert resp1.text == expected_output1, "Server returns input prompt and generated output which didn't match." - assert resp2.text == expected_output2, "Server returns input prompt and generated output which didn't match." + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp1 = ac.post("/predict", json={"prompt": "Hello"}, timeout=10) + resp2 = ac.post("/predict", json={"prompt": "World"}, timeout=10) + resp1, resp2 = await asyncio.gather(resp1, resp2) + assert resp1.status_code == 200, "Check if server is running and the request format is valid." + assert resp2.status_code == 200, "Check if server is running and the request format is valid." + assert ( + resp1.text == expected_output1 + ), "Server returns input prompt and generated output which didn't match." + assert ( + resp2.text == expected_output2 + ), "Server returns input prompt and generated output which didn't match." 
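# [Editor's note] The streaming tests above exercise the routing scheme this patch
# series introduces: requests enter a single shared request_queue as
# (response_queue_id, uid, timestamp, payload) 4-tuples, and each inference worker
# writes its result to response_queues[response_queue_id], so the uvicorn server
# that accepted the request is the one that returns the answer. A minimal,
# self-contained sketch of that round trip (stand-in names, not the real LitServe
# internals):

import time
import uuid
from queue import Queue

request_queue = Queue()
response_queues = [Queue(), Queue()]  # one queue per uvicorn server process

def submit(response_queue_id: int, payload: dict) -> uuid.UUID:
    # producer side: the endpoint owned by server `response_queue_id`
    uid = uuid.uuid4()
    request_queue.put((response_queue_id, uid, time.monotonic(), payload))
    return uid

def worker_step() -> None:
    # consumer side: an inference worker routes the result back by queue id
    response_queue_id, uid, _timestamp, payload = request_queue.get()
    response_queues[response_queue_id].put((uid, (payload["input"] ** 2, "OK")))

uid = submit(1, {"input": 4.0})
worker_step()
assert response_queues[1].get() == (uid, (16.0, "OK"))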
class FakeStreamResponseQueue: @@ -324,10 +334,10 @@ def test_server_run(mock_uvicorn): server.run(port=65536) server.run(port=8000) - mock_uvicorn.run.assert_called() + mock_uvicorn.Config.assert_called() mock_uvicorn.reset_mock() server.run(port="8001") - mock_uvicorn.run.assert_called() + mock_uvicorn.Config.assert_called() class IndentityAPI(ls.examples.SimpleLitAPI): @@ -383,20 +393,23 @@ def dummy_load_and_raise(resp): # Test context injection with single loop api = IndentityAPI() server = LitServer(api) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/predict", json={"input": 5.0}, timeout=10) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/predict", json={"input": 5.0}, timeout=10) assert resp.json()["output"] == 5.0, "output from Identity server must be same as input" # Test context injection with batched loop server = LitServer(IndentityBatchedAPI(), max_batch_size=2, batch_timeout=0.01) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/predict", json={"input": 5.0}, timeout=10) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/predict", json={"input": 5.0}, timeout=10) assert resp.json()["output"] == 5.0, "output from Identity server must be same as input" # Test context injection with batched streaming loop server = LitServer(IndentityBatchedStreamingAPI(), max_batch_size=2, batch_timeout=0.01, stream=True) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/predict", json={"input": 5.0}, timeout=10) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/predict", json={"input": 5.0}, timeout=10) assert resp.json()["output"] == 5.0, "output from Identity server must be same as input" server = LitServer(PredictErrorAPI()) @@ -412,9 +425,10 @@ def test_custom_api_path(): server = LitServer(ls.examples.SimpleLitAPI(), api_path="/v1/custom_predict") url = server.api_path - with TestClient(server.app) as client: - response = client.post(url, json={"input": 4.0}) - assert response.status_code == 200, "Server response should be 200 (OK)" + with wrap_litserve_start(server) as server: + with TestClient(server.app) as client: + response = client.post(url, json={"input": 4.0}) + assert response.status_code == 200, "Server response should be 200 (OK)" class TestHTTPExceptionAPI(ls.examples.SimpleLitAPI): @@ -424,7 +438,8 @@ def decode_request(self, request): def test_http_exception(): server = LitServer(TestHTTPExceptionAPI()) - with TestClient(server.app) as client: - response = client.post("/predict", json={"input": 4.0}) - assert response.status_code == 501, "Server raises 501 error" - assert response.text == '{"detail":"decode request is bad"}', "decode request is bad" + with wrap_litserve_start(server) as server: + with TestClient(server.app) as client: + response = client.post("/predict", json={"input": 4.0}) + assert response.status_code == 501, "Server raises 501 error" + assert response.text == '{"detail":"decode request is bad"}', 
"decode request is bad" diff --git a/tests/test_simple.py b/tests/test_simple.py index a632b95e..49415c33 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -25,6 +25,8 @@ from litserve import LitAPI, LitServer +from tests.conftest import wrap_litserve_start + class SimpleLitAPI(LitAPI): def setup(self, device): @@ -40,10 +42,8 @@ def encode_response(self, output) -> Response: return {"output": output} -def test_simple(): - server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=5) - - with TestClient(server.app) as client: +def test_simple(lit_server): + with TestClient(lit_server.app) as client: response = client.post("/predict", json={"input": 4.0}) assert response.json() == {"output": 16.0} @@ -80,16 +80,13 @@ def make_load_request(server, outputs): outputs.append(response.json()) -def test_load(): - server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=25) - +def test_load(lit_server): from threading import Thread threads = [] - for _ in range(1): outputs = [] - t = Thread(target=make_load_request, args=(server, outputs)) + t = Thread(target=make_load_request, args=(lit_server, outputs)) t.start() threads.append((t, outputs)) @@ -127,63 +124,67 @@ async def test_timeout(): api = SlowLitAPI() # takes 2 second for each prediction server = LitServer(api, accelerator="cpu", devices=1, timeout=1.5) - # case 1: first request completes, second request times out in queue - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - await asyncio.sleep(2) # Give time to start inference workers - response1 = asyncio.create_task(ac.post("/predict", json={"input": 4.0})) - await asyncio.sleep(0.0001) - response2 = asyncio.create_task(ac.post("/predict", json={"input": 5.0})) - await asyncio.wait([response1, response2]) - assert ( - response1.result().status_code == 200 - ), "First request should complete since it's popped from the request queue." - assert ( - response2.result().status_code == 504 - ), "Server takes longer than specified timeout and request should timeout" + with wrap_litserve_start(server) as server: + # case 1: first request completes, second request times out in queue + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + await asyncio.sleep(2) # Give time to start inference workers + response1 = asyncio.create_task(ac.post("/predict", json={"input": 4.0})) + await asyncio.sleep(0.0001) + response2 = asyncio.create_task(ac.post("/predict", json={"input": 5.0})) + await asyncio.wait([response1, response2]) + assert ( + response1.result().status_code == 200 + ), "First request should complete since it's popped from the request queue." 
+ assert ( + response2.result().status_code == 504 + ), "Server takes longer than specified timeout and request should timeout" # Case 2: first 2 requests finish as a batch and third request times out in queue server = LitServer(SlowBatchAPI(), accelerator="cpu", timeout=1.5, max_batch_size=2, batch_timeout=0.01) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - await asyncio.sleep(2) # Give time to start inference workers - response1 = asyncio.create_task(ac.post("/predict", json={"input": 4.0})) - response2 = asyncio.create_task(ac.post("/predict", json={"input": 5.0})) - await asyncio.sleep(0.0001) - response3 = asyncio.create_task(ac.post("/predict", json={"input": 6.0})) - await asyncio.wait([response1, response2, response3]) - assert ( - response1.result().status_code == 200 - ), "Batch: First request should complete since it's popped from the request queue." - assert ( - response2.result().status_code == 200 - ), "Batch: Second request should complete since it's popped from the request queue." - - assert response3.result().status_code == 504, "Batch: Third request was delayed and should fail" + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + await asyncio.sleep(2) # Give time to start inference workers + response1 = asyncio.create_task(ac.post("/predict", json={"input": 4.0})) + response2 = asyncio.create_task(ac.post("/predict", json={"input": 5.0})) + await asyncio.sleep(0.0001) + response3 = asyncio.create_task(ac.post("/predict", json={"input": 6.0})) + await asyncio.wait([response1, response2, response3]) + assert ( + response1.result().status_code == 200 + ), "Batch: First request should complete since it's popped from the request queue." + assert ( + response2.result().status_code == 200 + ), "Batch: Second request should complete since it's popped from the request queue." 
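# [Editor's note] The `with wrap_litserve_start(server) as server` pattern used
# throughout these tests comes from the helper added to tests/conftest.py in this
# patch: after the refactor, the app's lifespan expects the response queues and
# inference-worker processes to already exist, so tests launch them before handing
# the app to TestClient or LifespanManager. A hedged usage sketch (assumes the
# conftest helper and the bundled SimpleLitAPI example):

import litserve as ls
from fastapi.testclient import TestClient
from tests.conftest import wrap_litserve_start

server = ls.LitServer(ls.examples.SimpleLitAPI(), accelerator="cpu", devices=1, timeout=10)
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
    assert client.post("/predict", json={"input": 4.0}).json() == {"output": 16.0}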
+ + assert response3.result().status_code == 504, "Batch: Third request was delayed and should fail" server1 = LitServer(SlowLitAPI(), accelerator="cpu", devices=1, timeout=-1) server2 = LitServer(SlowLitAPI(), accelerator="cpu", devices=1, timeout=False) server3 = LitServer(SlowBatchAPI(), accelerator="cpu", devices=1, timeout=False, max_batch_size=2, batch_timeout=2) server4 = LitServer(SlowBatchAPI(), accelerator="cpu", devices=1, timeout=-1, max_batch_size=2, batch_timeout=2) - with TestClient(server1.app) as client1, TestClient(server2.app) as client2, TestClient( - server3.app - ) as client3, TestClient(server4.app) as client4: - response1 = client1.post("/predict", json={"input": 4.0}) - assert response1.status_code == 200, "Expected slow server to respond since timeout was disabled" + with wrap_litserve_start(server1) as server1, wrap_litserve_start(server2) as server2, wrap_litserve_start( + server3 + ) as server3, wrap_litserve_start(server4) as server4: + with TestClient(server1.app) as client1, TestClient(server2.app) as client2, TestClient( + server3.app + ) as client3, TestClient(server4.app) as client4: + response1 = client1.post("/predict", json={"input": 4.0}) + assert response1.status_code == 200, "Expected slow server to respond since timeout was disabled" - response2 = client2.post("/predict", json={"input": 4.0}) - assert response2.status_code == 200, "Expected slow server to respond since timeout was disabled" + response2 = client2.post("/predict", json={"input": 4.0}) + assert response2.status_code == 200, "Expected slow server to respond since timeout was disabled" - response3 = client3.post("/predict", json={"input": 4.0}) - assert response3.status_code == 200, "Expected slow batch server to respond since timeout was disabled" + response3 = client3.post("/predict", json={"input": 4.0}) + assert response3.status_code == 200, "Expected slow batch server to respond since timeout was disabled" - response4 = client4.post("/predict", json={"input": 4.0}) - assert response4.status_code == 200, "Expected slow batch server to respond since timeout was disabled" + response4 = client4.post("/predict", json={"input": 4.0}) + assert response4.status_code == 200, "Expected slow batch server to respond since timeout was disabled" -def test_concurrent_requests(): +def test_concurrent_requests(lit_server): n_requests = 100 - server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - with TestClient(server.app) as client, ThreadPoolExecutor(n_requests // 4 + 1) as executor: + with TestClient(lit_server.app) as client, ThreadPoolExecutor(n_requests // 4 + 1) as executor: responses = list(executor.map(lambda i: client.post("/predict", json={"input": i}), range(n_requests))) count = 0 From c30815bc78cde1293229c22736a85f69031f67f8 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 19:03:27 +0000 Subject: [PATCH 28/45] fix tests --- src/litserve/server.py | 12 ++++++---- src/litserve/specs/openai.py | 8 ++++++- tests/test_lit_server.py | 43 ++++++++++++++++++------------------ 3 files changed, 36 insertions(+), 27 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index f0bcf27c..6634c7c7 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -287,7 +287,7 @@ def run_batched_streaming_loop( if not batches: continue - uids, inputs = zip(*batches) + response_queue_ids, uids, inputs = zip(*batches) try: contexts = [{}] * len(inputs) if hasattr(lit_spec, "populate_context"): @@ -309,11 +309,11 @@ def 
run_batched_streaming_loop( # y_enc_iter -> [[response-1, response-2], [response-1, response-2]] for y_batch in y_enc_iter: - for y_enc, uid in zip(y_batch, uids): + for response_queue_id, y_enc, uid in zip(response_queue_ids, y_batch, uids): y_enc = lit_api.format_encoded_response(y_enc) response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) - for uid in uids: + for response_queue_id, uid in zip(response_queue_ids, uids): response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING))) except Exception as e: @@ -438,6 +438,8 @@ def __init__( lit_api.sanitize(max_batch_size, spec=spec) self.app = FastAPI(lifespan=self.lifespan) self.app.response_queue_id = None + self.response_queue_id = None + self.response_buffer = {} # gzip does not play nicely with streaming, see https://github.com/tiangolo/fastapi/discussions/8448 if not stream: self.app.add_middleware(GZipMiddleware, minimum_size=1000) @@ -529,7 +531,6 @@ def launch_inference_worker(self, num_uvicorn_servers: int): @asynccontextmanager async def lifespan(self, app: FastAPI): loop = asyncio.get_running_loop() - self.response_buffer = {} response_queue = self.response_queues[app.response_queue_id] response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) @@ -703,7 +704,10 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets): servers = [] for response_queue_id in range(num_uvicorn_servers): self.app.response_queue_id = response_queue_id + if self.lit_spec: + self.lit_spec.response_queue_id = response_queue_id app = copy.copy(self.app) + config = uvicorn.Config(app=app, port=port, log_level=log_level, loop="uvloop") server = uvicorn.Server(config=config) ctx = mp.get_context("fork") diff --git a/src/litserve/specs/openai.py b/src/litserve/specs/openai.py index daa7a98e..bc14d01e 100644 --- a/src/litserve/specs/openai.py +++ b/src/litserve/specs/openai.py @@ -317,6 +317,12 @@ async def options_chat_completions(self, request: Request): return Response(status_code=200) async def chat_completion(self, request: ChatCompletionRequest, background_tasks: BackgroundTasks): + try: + response_queue_id = self.response_queue_id + except Exception as e: + print(e) + response_queue_id = None + print("response_queue_id", response_queue_id) logger.debug("Received chat completion request %s", request) uids = [uuid.uuid4() for _ in range(request.n)] @@ -328,7 +334,7 @@ async def chat_completion(self, request: ChatCompletionRequest, background_tasks q = deque() event = asyncio.Event() self._server.response_buffer[uid] = (q, event) - self._server.request_queue.put((uid, time.monotonic(), request_el)) + self._server.request_queue.put((response_queue_id, uid, time.monotonic(), request_el)) self.queues.append(q) self.events.append(event) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index dd019c6d..d0f4719c 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -93,10 +93,10 @@ def put(self, item): def test_single_loop(loop_args): lit_api_mock, requests_queue = loop_args lit_api_mock.unbatch.side_effect = None - response_queue = FakeResponseQueue() + response_queues = [FakeResponseQueue()] with pytest.raises(StopIteration, match="exit loop"): - run_single_loop(lit_api_mock, None, requests_queue, response_queue) + run_single_loop(lit_api_mock, None, requests_queue, response_queues) @pytest.mark.asyncio() @@ -104,6 +104,7 @@ async def test_stream(simple_stream_api): server = LitServer(simple_stream_api, stream=True, timeout=10) 
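# [Editor's note] Patch 28 above also threads response_queue_id through
# LitServer._start_server: each uvicorn server process gets a shallow copy of the
# FastAPI app carrying its own queue id (and, when a spec such as OpenAISpec is
# used, the spec records the same id). A rough sketch of that fan-out, with
# placeholder names where the real code uses uvicorn.Config/Server and
# multiprocessing:

import copy

class App:  # stand-in for the FastAPI app object
    response_queue_id = None

def start_servers(app: App, num_uvicorn_servers: int) -> list:
    apps = []
    for response_queue_id in range(num_uvicorn_servers):
        app.response_queue_id = response_queue_id
        apps.append(copy.copy(app))  # each server sees a frozen copy of the id
    return apps

apps = start_servers(App(), 3)
assert [a.response_queue_id for a in apps] == [0, 1, 2]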
expected_output1 = "prompt=Hello generated_output=LitServe is streaming output".lower().replace(" ", "") expected_output2 = "prompt=World generated_output=LitServe is streaming output".lower().replace(" ", "") + with wrap_litserve_start(server) as server: async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: resp1 = ac.post("/predict", json={"prompt": "Hello"}, timeout=10) @@ -154,7 +155,7 @@ def put(self, item): self.count += 1 -def test_streaming_loop(loop_args): +def test_streaming_loop(): num_streamed_outputs = 10 def fake_predict(inputs: str): @@ -174,11 +175,11 @@ def fake_encode(output): fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x) requests_queue = Queue() - requests_queue.put(("UUID-1234", time.monotonic(), {"prompt": "Hello"})) - response_queue = FakeStreamResponseQueue(num_streamed_outputs) + requests_queue.put((0, "UUID-1234", time.monotonic(), {"prompt": "Hello"})) + response_queues = [FakeStreamResponseQueue(num_streamed_outputs)] with pytest.raises(StopIteration, match="exit loop"): - run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queue) + run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues) fake_stream_api.predict.assert_called_once_with("Hello") fake_stream_api.encode_response.assert_called_once() @@ -230,13 +231,13 @@ def fake_encode(output_iter): fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x) requests_queue = Queue() - requests_queue.put(("UUID-001", time.monotonic(), {"prompt": "Hello"})) - requests_queue.put(("UUID-002", time.monotonic(), {"prompt": "World"})) - response_queue = FakeBatchStreamResponseQueue(num_streamed_outputs) + requests_queue.put((0, "UUID-001", time.monotonic(), {"prompt": "Hello"})) + requests_queue.put((0, "UUID-002", time.monotonic(), {"prompt": "World"})) + response_queues = [FakeBatchStreamResponseQueue(num_streamed_outputs)] with pytest.raises(StopIteration, match="finish streaming"): run_batched_streaming_loop( - fake_stream_api, fake_stream_api, requests_queue, response_queue, max_batch_size=2, batch_timeout=2 + fake_stream_api, fake_stream_api, requests_queue, response_queues, max_batch_size=2, batch_timeout=2 ) fake_stream_api.predict.assert_called_once_with(["Hello", "World"]) fake_stream_api.encode_response.assert_called_once() @@ -413,9 +414,9 @@ def dummy_load_and_raise(resp): assert resp.json()["output"] == 5.0, "output from Identity server must be same as input" server = LitServer(PredictErrorAPI()) - with pytest.raises(TypeError, match=re.escape("predict() missing 1 required positional argument: 'y'")), TestClient( - server.app - ) as client: + with wrap_litserve_start(server) as server, pytest.raises( + TypeError, match=re.escape("predict() missing 1 required positional argument: 'y'") + ), TestClient(server.app) as client: client.post("/predict", json={"input": 5.0}, timeout=10) @@ -425,10 +426,9 @@ def test_custom_api_path(): server = LitServer(ls.examples.SimpleLitAPI(), api_path="/v1/custom_predict") url = server.api_path - with wrap_litserve_start(server) as server: - with TestClient(server.app) as client: - response = client.post(url, json={"input": 4.0}) - assert response.status_code == 200, "Server response should be 200 (OK)" + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + response = client.post(url, json={"input": 4.0}) + assert response.status_code == 200, "Server response should be 200 (OK)" class 
TestHTTPExceptionAPI(ls.examples.SimpleLitAPI): @@ -438,8 +438,7 @@ def decode_request(self, request): def test_http_exception(): server = LitServer(TestHTTPExceptionAPI()) - with wrap_litserve_start(server) as server: - with TestClient(server.app) as client: - response = client.post("/predict", json={"input": 4.0}) - assert response.status_code == 501, "Server raises 501 error" - assert response.text == '{"detail":"decode request is bad"}', "decode request is bad" + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + response = client.post("/predict", json={"input": 4.0}) + assert response.status_code == 501, "Server raises 501 error" + assert response.text == '{"detail":"decode request is bad"}', "decode request is bad" From 06257afb82d281f40c5dbecb94d7b3b84a354ec6 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 19:08:29 +0000 Subject: [PATCH 29/45] fix --- tests/test_compression.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/tests/test_compression.py b/tests/test_compression.py index 3979d29d..0fb162fb 100644 --- a/tests/test_compression.py +++ b/tests/test_compression.py @@ -38,21 +38,19 @@ def encode_response(self, output) -> Response: def test_compression(): server = LitServer(LargeOutputLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - with wrap_litserve_start(server) as server: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: # compressed - with TestClient(server.app) as client: - response = client.post("/predict", headers={"Accept-Encoding": "gzip"}, json={}) - assert response.status_code == 200 - assert response.headers["Content-Encoding"] == "gzip" - content_length = int(response.headers["Content-Length"]) - assert 0 < content_length < 100000 - assert response.json() == test_output + response = client.post("/predict", headers={"Accept-Encoding": "gzip"}, json={}) + assert response.status_code == 200 + assert response.headers["Content-Encoding"] == "gzip" + content_length = int(response.headers["Content-Length"]) + assert 0 < content_length < 100000 + assert response.json() == test_output # uncompressed - with TestClient(server.app) as client: - response = client.post("/predict", headers={"Accept-Encoding": ""}, json={}) - assert response.status_code == 200 - assert "Content-Encoding" not in response.headers - content_length = int(response.headers["Content-Length"]) - assert content_length > 100000 - assert response.json() == test_output + response = client.post("/predict", headers={"Accept-Encoding": ""}, json={}) + assert response.status_code == 200 + assert "Content-Encoding" not in response.headers + content_length = int(response.headers["Content-Length"]) + assert content_length > 100000 + assert response.json() == test_output From 78de09e89fda67537ae62fb7f61b1cc461c39187 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 19:09:53 +0000 Subject: [PATCH 30/45] remove uvloop import --- src/litserve/server.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 6634c7c7..2ece54ba 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -42,10 +42,16 @@ from litserve.specs.base import LitSpec from litserve.utils import LitAPIStatus, load_and_raise from collections import deque -import uvloop mp.allow_connection_pickling() -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +try: + import uvloop + + 
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +except ImportError: + print("uvloop is not installed. Falling back to the default asyncio event loop.") logger = logging.getLogger(__name__) From ded829b999e5a348f26039f498b9de3268eb79a1 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 19:41:26 +0000 Subject: [PATCH 31/45] auto loop --- src/litserve/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 2ece54ba..88ba56b1 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -688,7 +688,7 @@ def run( if not (1024 <= port <= 65535): raise ValueError(port_msg) - config = uvicorn.Config(app=self.app, port=port, log_level=log_level, loop="uvloop") + config = uvicorn.Config(app=self.app, port=port, log_level=log_level) sockets = [config.bind_socket()] if num_uvicorn_servers is None: From ea9f2b446aef0b6c2d79260de18c1394c0de3ec2 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 19:55:57 +0000 Subject: [PATCH 32/45] format --- _requirements/test.txt | 1 + tests/test_auth.py | 36 ++++++------- tests/test_form.py | 49 ++++++++--------- tests/test_pydantic.py | 5 +- tests/test_specs.py | 120 +++++++++++++++++++++-------------------- 5 files changed, 105 insertions(+), 106 deletions(-) diff --git a/_requirements/test.txt b/_requirements/test.txt index b834241c..bade4d6e 100644 --- a/_requirements/test.txt +++ b/_requirements/test.txt @@ -11,3 +11,4 @@ lightning >2.0.0 torch >2.0.0 transformers openai>=1.12.0 +uvloop diff --git a/tests/test_auth.py b/tests/test_auth.py index f6a71e1d..1955c4fa 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -41,20 +41,18 @@ def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())): def test_authorized_custom(): server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - with wrap_litserve_start(server) as server: - with TestClient(server.app) as client: - input = {"input": 4.0} - response = client.post("/predict", headers={"Authorization": "Bearer 1234"}, json=input) - assert response.status_code == 200 + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + input = {"input": 4.0} + response = client.post("/predict", headers={"Authorization": "Bearer 1234"}, json=input) + assert response.status_code == 200 def test_not_authorized_custom(): server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - with wrap_litserve_start(server) as server: - with TestClient(server.app) as client: - input = {"input": 4.0} - response = client.post("/predict", headers={"Authorization": "Bearer wrong"}, json=input) - assert response.status_code == 401 + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + input = {"input": 4.0} + response = client.post("/predict", headers={"Authorization": "Bearer wrong"}, json=input) + assert response.status_code == 401 class SimpleLitAPI(LitAPI): @@ -75,11 +73,10 @@ def test_authorized_api_key(): litserve.server.LIT_SERVER_API_KEY = "abcd" server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - with wrap_litserve_start(server) as server: - with TestClient(server.app) as client: - input = {"input": 4.0} - response = client.post("/predict", headers={"X-API-Key": "abcd"}, json=input) - assert response.status_code == 200 + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + input = {"input": 4.0} + response = 
client.post("/predict", headers={"X-API-Key": "abcd"}, json=input) + assert response.status_code == 200 litserve.server.LIT_SERVER_API_KEY = None @@ -88,10 +85,9 @@ def test_not_authorized_api_key(): litserve.server.LIT_SERVER_API_KEY = "abcd" server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - with wrap_litserve_start(server) as server: - with TestClient(server.app) as client: - input = {"input": 4.0} - response = client.post("/predict", headers={"X-API-Key": "wrong"}, json=input) - assert response.status_code == 401 + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + input = {"input": 4.0} + response = client.post("/predict", headers={"X-API-Key": "wrong"}, json=input) + assert response.status_code == 401 litserve.server.LIT_SERVER_API_KEY = None diff --git a/tests/test_form.py b/tests/test_form.py index 9eac31b1..e5026089 100644 --- a/tests/test_form.py +++ b/tests/test_form.py @@ -40,15 +40,14 @@ def test_multipart_form_data(tmp_path): SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length * 2) ) - with wrap_litserve_start(server) as server: - with TestClient(server.app) as client: - file_path = f"{tmp_path}/big_file.txt" - with open(file_path, "wb") as f: - f.write(bytearray([1] * file_length)) - with open(file_path, "rb") as f: - file = {"input": f} - response = client.post("/predict", files=file) - assert response.json() == {"output": file_length**2} + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + file_path = f"{tmp_path}/big_file.txt" + with open(file_path, "wb") as f: + f.write(bytearray([1] * file_length)) + with open(file_path, "rb") as f: + file = {"input": f} + response = client.post("/predict", files=file) + assert response.json() == {"output": file_length**2} def test_file_too_big(tmp_path): @@ -58,19 +57,18 @@ def test_file_too_big(tmp_path): SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length / 2) ) - with wrap_litserve_start(server) as server: - with TestClient(server.app) as client: - file_path = f"{tmp_path}/big_file.txt" - with open(file_path, "wb") as f: - f.write(bytearray([1] * file_length)) - with open(file_path, "rb") as f: - file = {"input": f} - response = client.post("/predict", files=file) - assert response.status_code == 413 + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + file_path = f"{tmp_path}/big_file.txt" + with open(file_path, "wb") as f: + f.write(bytearray([1] * file_length)) + with open(file_path, "rb") as f: + file = {"input": f} + response = client.post("/predict", files=file) + assert response.status_code == 413 - # spoof content-length size - response = client.post("/predict", files=file, headers={"content-length": "1024"}) - assert response.status_code == 413 + # spoof content-length size + response = client.post("/predict", files=file, headers={"content-length": "1024"}) + assert response.status_code == 413 class SimpleFormLitAPI(LitAPI): @@ -89,8 +87,7 @@ def encode_response(self, output) -> Response: def test_urlencoded_form_data(): server = LitServer(SimpleFormLitAPI(), accelerator="cpu", devices=1, workers_per_device=1) - with wrap_litserve_start(server) as server: - with TestClient(server.app) as client: - file = {"input": "4.0"} - response = client.post("/predict", data=file) - assert response.json() == {"output": 16.0} + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + file = 
{"input": "4.0"} + response = client.post("/predict", data=file) + assert response.json() == {"output": 16.0} diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py index d017347c..5d1335eb 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic.py @@ -14,7 +14,7 @@ from fastapi.testclient import TestClient from pydantic import BaseModel - +from tests.conftest import wrap_litserve_start from litserve import LitAPI, LitServer @@ -42,7 +42,6 @@ def encode_response(self, output: float) -> PredictResponse: def test_pydantic(): server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=5) - - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: response = client.post("/predict", json={"input": 4.0}) assert response.json() == {"output": 16.0} diff --git a/tests/test_specs.py b/tests/test_specs.py index 5cfe3660..3fc49d2b 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -24,6 +24,7 @@ OpenAIBatchingWithUsage, OpenAIWithUsageEncodeResponse, ) +from tests.conftest import wrap_litserve_start from litserve.specs.openai import OpenAISpec, ChatMessage import litserve as ls @@ -32,13 +33,14 @@ async def test_openai_spec(openai_request_data): spec = OpenAISpec() server = ls.LitServer(TestAPI(), spec=spec) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) - assert resp.status_code == 200, "Status code should be 200" + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) + assert resp.status_code == 200, "Status code should be 200" - assert ( - resp.json()["choices"][0]["message"]["content"] == "This is a generated output" - ), "LitAPI predict response should match with the generated output" + assert ( + resp.json()["choices"][0]["message"]["content"] == "This is a generated output" + ), "LitAPI predict response should match with the generated output" # OpenAIWithUsage @@ -53,64 +55,66 @@ async def test_openai_spec(openai_request_data): ) async def test_openai_token_usage(api, batch_size, openai_request_data, openai_response_data): server = ls.LitServer(api, spec=ls.OpenAISpec(), max_batch_size=batch_size, batch_timeout=0.01) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) - assert resp.status_code == 200, "Status code should be 200" - result = resp.json() - content = result["choices"][0]["message"]["content"] - assert content == "10 + 6 is equal to 16.", "LitAPI predict response should match with the generated output" - assert result["usage"] == openai_response_data["usage"] - - # with streaming - openai_request_data["stream"] = True - resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) - assert resp.status_code == 200, "Status code should be 200" - assert result["usage"] == openai_response_data["usage"] + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) + assert resp.status_code == 200, "Status code 
should be 200" + result = resp.json() + content = result["choices"][0]["message"]["content"] + assert content == "10 + 6 is equal to 16.", "LitAPI predict response should match with the generated output" + assert result["usage"] == openai_response_data["usage"] + + # with streaming + openai_request_data["stream"] = True + resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) + assert resp.status_code == 200, "Status code should be 200" + assert result["usage"] == openai_response_data["usage"] @pytest.mark.asyncio() async def test_openai_spec_with_image(openai_request_data_with_image): - spec = OpenAISpec() - server = ls.LitServer(TestAPI(), spec=spec) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/v1/chat/completions", json=openai_request_data_with_image, timeout=10) - assert resp.status_code == 200, "Status code should be 200" + server = ls.LitServer(TestAPI(), spec=OpenAISpec()) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/chat/completions", json=openai_request_data_with_image, timeout=10) + assert resp.status_code == 200, "Status code should be 200" - assert ( - resp.json()["choices"][0]["message"]["content"] == "This is a generated output" - ), "LitAPI predict response should match with the generated output" + assert ( + resp.json()["choices"][0]["message"]["content"] == "This is a generated output" + ), "LitAPI predict response should match with the generated output" @pytest.mark.asyncio() async def test_override_encode(openai_request_data): - spec = OpenAISpec() - server = ls.LitServer(TestAPIWithCustomEncode(), spec=spec) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) - assert resp.status_code == 200, "Status code should be 200" + server = ls.LitServer(TestAPIWithCustomEncode(), spec=OpenAISpec()) + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) + assert resp.status_code == 200, "Status code should be 200" - assert ( - resp.json()["choices"][0]["message"]["content"] == "This is a custom encoded output" - ), "LitAPI predict response should match with the generated output" + assert ( + resp.json()["choices"][0]["message"]["content"] == "This is a custom encoded output" + ), "LitAPI predict response should match with the generated output" @pytest.mark.asyncio() async def test_openai_spec_with_tools(openai_request_data_with_tools): spec = OpenAISpec() server = ls.LitServer(TestAPIWithToolCalls(), spec=spec) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/v1/chat/completions", json=openai_request_data_with_tools, timeout=10) - assert resp.status_code == 200, "Status code should be 200" - assert ( - resp.json()["choices"][0]["message"]["content"] == "" - ), "LitAPI predict response should match with the generated output" - assert resp.json()["choices"][0]["message"]["tool_calls"] == [ - { - "id": "call_1", - "type": "function", - "function": {"name": "function_1", "arguments": '{"arg_1": 
"arg_1_value"}'}, - } - ], "LitAPI predict response should match with the generated output" + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/chat/completions", json=openai_request_data_with_tools, timeout=10) + assert resp.status_code == 200, "Status code should be 200" + assert ( + resp.json()["choices"][0]["message"]["content"] == "" + ), "LitAPI predict response should match with the generated output" + assert resp.json()["choices"][0]["message"]["tool_calls"] == [ + { + "id": "call_1", + "type": "function", + "function": {"name": "function_1", "arguments": '{"arg_1": "arg_1_value"}'}, + } + ], "LitAPI predict response should match with the generated output" class IncorrectAPI1(ls.LitAPI): @@ -159,11 +163,12 @@ async def test_oai_prepopulated_context(openai_request_data): openai_request_data["max_tokens"] = 3 spec = OpenAISpec() server = ls.LitServer(PrePopulatedAPI(), spec=spec) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) - assert ( - resp.json()["choices"][0]["message"]["content"] == "This is a" - ), "OpenAISpec must return only 3 tokens as specified using `max_tokens` parameter" + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) + assert ( + resp.json()["choices"][0]["message"]["content"] == "This is a" + ), "OpenAISpec must return only 3 tokens as specified using `max_tokens` parameter" class WrongLitAPI(ls.LitAPI): @@ -178,7 +183,8 @@ def predict(self, prompt): @pytest.mark.asyncio() async def test_fail_http(openai_request_data): server = ls.LitServer(WrongLitAPI(), spec=ls.OpenAISpec()) - async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: - res = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) - assert res.status_code == 501, "Server raises 501 error" - assert res.text == '{"detail":"test LitAPI.predict error"}' + with wrap_litserve_start(server) as server: + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + res = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10) + assert res.status_code == 501, "Server raises 501 error" + assert res.text == '{"detail":"test LitAPI.predict error"}' From 4eb8cbae89c902686a48bd535c0e72fa0e2c3b51 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 19:58:32 +0000 Subject: [PATCH 33/45] formatting --- tests/test_simple.py | 49 ++++++++++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/tests/test_simple.py b/tests/test_simple.py index 49415c33..290329e0 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -140,7 +140,13 @@ async def test_timeout(): ), "Server takes longer than specified timeout and request should timeout" # Case 2: first 2 requests finish as a batch and third request times out in queue - server = LitServer(SlowBatchAPI(), accelerator="cpu", timeout=1.5, max_batch_size=2, batch_timeout=0.01) + server = LitServer( + SlowBatchAPI(), + accelerator="cpu", + timeout=1.5, + max_batch_size=2, + batch_timeout=0.01, + ) 
with wrap_litserve_start(server) as server: async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: await asyncio.sleep(2) # Give time to start inference workers @@ -160,26 +166,39 @@ async def test_timeout(): server1 = LitServer(SlowLitAPI(), accelerator="cpu", devices=1, timeout=-1) server2 = LitServer(SlowLitAPI(), accelerator="cpu", devices=1, timeout=False) - server3 = LitServer(SlowBatchAPI(), accelerator="cpu", devices=1, timeout=False, max_batch_size=2, batch_timeout=2) - server4 = LitServer(SlowBatchAPI(), accelerator="cpu", devices=1, timeout=-1, max_batch_size=2, batch_timeout=2) + server3 = LitServer( + SlowBatchAPI(), + accelerator="cpu", + devices=1, + timeout=False, + max_batch_size=2, + batch_timeout=2, + ) + server4 = LitServer( + SlowBatchAPI(), + accelerator="cpu", + devices=1, + timeout=-1, + max_batch_size=2, + batch_timeout=2, + ) with wrap_litserve_start(server1) as server1, wrap_litserve_start(server2) as server2, wrap_litserve_start( server3 - ) as server3, wrap_litserve_start(server4) as server4: - with TestClient(server1.app) as client1, TestClient(server2.app) as client2, TestClient( - server3.app - ) as client3, TestClient(server4.app) as client4: - response1 = client1.post("/predict", json={"input": 4.0}) - assert response1.status_code == 200, "Expected slow server to respond since timeout was disabled" + ) as server3, wrap_litserve_start(server4) as server4, TestClient(server1.app) as client1, TestClient( + server2.app + ) as client2, TestClient(server3.app) as client3, TestClient(server4.app) as client4: + response1 = client1.post("/predict", json={"input": 4.0}) + assert response1.status_code == 200, "Expected slow server to respond since timeout was disabled" - response2 = client2.post("/predict", json={"input": 4.0}) - assert response2.status_code == 200, "Expected slow server to respond since timeout was disabled" + response2 = client2.post("/predict", json={"input": 4.0}) + assert response2.status_code == 200, "Expected slow server to respond since timeout was disabled" - response3 = client3.post("/predict", json={"input": 4.0}) - assert response3.status_code == 200, "Expected slow batch server to respond since timeout was disabled" + response3 = client3.post("/predict", json={"input": 4.0}) + assert response3.status_code == 200, "Expected slow batch server to respond since timeout was disabled" - response4 = client4.post("/predict", json={"input": 4.0}) - assert response4.status_code == 200, "Expected slow batch server to respond since timeout was disabled" + response4 = client4.post("/predict", json={"input": 4.0}) + assert response4.status_code == 200, "Expected slow batch server to respond since timeout was disabled" def test_concurrent_requests(lit_server): From 27f7bbaf096fc98c839371e40336244f1612485d Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 19:59:33 +0000 Subject: [PATCH 34/45] remove uvloop --- _requirements/test.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/_requirements/test.txt b/_requirements/test.txt index bade4d6e..b834241c 100644 --- a/_requirements/test.txt +++ b/_requirements/test.txt @@ -11,4 +11,3 @@ lightning >2.0.0 torch >2.0.0 transformers openai>=1.12.0 -uvloop From 41cb99b68c133892c25669cc46088fb8a27efab1 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 20:09:06 +0000 Subject: [PATCH 35/45] fix test --- src/litserve/specs/openai.py | 8 +------- tests/conftest.py | 6 ++++-- tests/test_specs.py | 9 ++++----- 3 files changed, 9 
insertions(+), 14 deletions(-) diff --git a/src/litserve/specs/openai.py b/src/litserve/specs/openai.py index bc14d01e..77dc79a9 100644 --- a/src/litserve/specs/openai.py +++ b/src/litserve/specs/openai.py @@ -317,14 +317,8 @@ async def options_chat_completions(self, request: Request): return Response(status_code=200) async def chat_completion(self, request: ChatCompletionRequest, background_tasks: BackgroundTasks): - try: - response_queue_id = self.response_queue_id - except Exception as e: - print(e) - response_queue_id = None - print("response_queue_id", response_queue_id) + response_queue_id = self.response_queue_id logger.debug("Received chat completion request %s", request) - uids = [uuid.uuid4() for _ in range(request.n)] self.queues = [] self.events = [] diff --git a/tests/conftest.py b/tests/conftest.py index dd985796..85b3a082 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -99,9 +99,11 @@ def simple_batched_stream_api(): @contextmanager -def wrap_litserve_start(server): +def wrap_litserve_start(server: LitServer): server.app.response_queue_id = 0 - manager, processes = server.launch_inference_worker(1) + if server.lit_spec: + server.lit_spec.response_queue_id = 0 + manager, processes = server.launch_inference_worker(num_uvicorn_servers=1) yield server for p in processes: p.terminate() diff --git a/tests/test_specs.py b/tests/test_specs.py index 3fc49d2b..2e7281b5 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -135,14 +135,13 @@ def encode_response(self, output): @pytest.mark.asyncio() async def test_openai_spec_validation(openai_request_data): - spec = OpenAISpec() - server = ls.LitServer(IncorrectAPI1(), spec=spec) - with pytest.raises(ValueError, match="predict is not a generator"): + server = ls.LitServer(IncorrectAPI1(), spec=OpenAISpec()) + with pytest.raises(ValueError, match="predict is not a generator"), wrap_litserve_start(server) as server: async with LifespanManager(server.app) as manager: await manager.shutdown() - server = ls.LitServer(IncorrectAPI2(), spec=spec) - with pytest.raises(ValueError, match="encode_response is not a generator"): + server = ls.LitServer(IncorrectAPI2(), spec=OpenAISpec()) + with pytest.raises(ValueError, match="encode_response is not a generator"), wrap_litserve_start(server) as server: async with LifespanManager(server.app) as manager: await manager.shutdown() From b86afa10fb2a64bf91f2c4712feb49ca8dc5d12e Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 20:14:12 +0000 Subject: [PATCH 36/45] wrap ls start --- tests/test_simple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_simple.py b/tests/test_simple.py index 290329e0..6232042f 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -57,7 +57,7 @@ def setup(self, device): def test_workers_health(): server = LitServer(SlowSetupLitAPI(), accelerator="cpu", devices=1, timeout=5, workers_per_device=2) - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: response = client.get("/health") assert response.status_code == 503 assert response.text == "not ready" From 1e2307e256daf27eb656783837aaecf2214b2eec Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 20:19:05 +0000 Subject: [PATCH 37/45] update --- tests/test_torch.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_torch.py b/tests/test_torch.py index a32c9bef..7b1f3788 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -15,6 
+15,7 @@ from fastapi.testclient import TestClient from litserve import LitAPI, LitServer +from tests.conftest import wrap_litserve_start import torch import torch.nn as nn @@ -51,8 +52,7 @@ def encode_response(self, output) -> Response: def test_torch(): server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=10) - - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: response = client.post("/predict", json={"input": 4.0}) assert response.json() == {"output": 9.0} @@ -60,7 +60,6 @@ def test_torch(): @pytest.mark.skipif(torch.cuda.device_count() == 0, reason="requires CUDA to be available") def test_torch_gpu(): server = LitServer(SimpleLitAPI(), accelerator="cuda", devices=1, timeout=10) - - with TestClient(server.app) as client: + with wrap_litserve_start(server) as server, TestClient(server.app) as client: response = client.post("/predict", json={"input": 4.0}) assert response.json() == {"output": 9.0} From 9f48b5ab984c2814ff6791a134d1f4bd55322b4b Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 20:28:36 +0000 Subject: [PATCH 38/45] fixes --- src/litserve/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 88ba56b1..9322cf57 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -714,7 +714,7 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets): self.lit_spec.response_queue_id = response_queue_id app = copy.copy(self.app) - config = uvicorn.Config(app=app, port=port, log_level=log_level, loop="uvloop") + config = uvicorn.Config(app=app, port=port, log_level=log_level) server = uvicorn.Server(config=config) ctx = mp.get_context("fork") w = ctx.Process(target=server.run, args=(sockets,)) From 950f4578c3f3b7b28196f9b3c11ff747713c71bd Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 20:31:24 +0000 Subject: [PATCH 39/45] fix host --- src/litserve/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 9322cf57..60ea8575 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -688,7 +688,7 @@ def run( if not (1024 <= port <= 65535): raise ValueError(port_msg) - config = uvicorn.Config(app=self.app, port=port, log_level=log_level) + config = uvicorn.Config(app=self.app, host="0.0.0.0", port=port, log_level=log_level) sockets = [config.bind_socket()] if num_uvicorn_servers is None: @@ -714,7 +714,7 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets): self.lit_spec.response_queue_id = response_queue_id app = copy.copy(self.app) - config = uvicorn.Config(app=app, port=port, log_level=log_level) + config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level) server = uvicorn.Server(config=config) ctx = mp.get_context("fork") w = ctx.Process(target=server.run, args=(sockets,)) From c462eced1f148536113b8d120a9d60d684676406 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 20:42:31 +0000 Subject: [PATCH 40/45] downgrade numpy --- _requirements/test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/_requirements/test.txt b/_requirements/test.txt index b834241c..e75bcda0 100644 --- a/_requirements/test.txt +++ b/_requirements/test.txt @@ -11,3 +11,4 @@ lightning >2.0.0 torch >2.0.0 transformers openai>=1.12.0 +numpy==1.26.4 From 3827ccc934255ecdd411c9e6cbb9b2343d426560 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 
Aug 2024 20:56:14 +0000 Subject: [PATCH 41/45] fix windows --- src/litserve/server.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 60ea8575..e5a948ca 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -20,6 +20,7 @@ import pickle import shutil import sys +import threading import time import uuid from concurrent.futures import ThreadPoolExecutor @@ -674,6 +675,7 @@ def run( num_uvicorn_servers: Optional[int] = None, log_level: str = "info", generate_client_file: bool = True, + uvicorn_worker_type: Optional[str] = None, **kwargs, ): if generate_client_file: @@ -688,7 +690,7 @@ def run( if not (1024 <= port <= 65535): raise ValueError(port_msg) - config = uvicorn.Config(app=self.app, host="0.0.0.0", port=port, log_level=log_level) + config = uvicorn.Config(app=self.app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) sockets = [config.bind_socket()] if num_uvicorn_servers is None: @@ -696,7 +698,12 @@ def run( manager, litserve_workers = self.launch_inference_worker(num_uvicorn_servers) - servers = self._start_server(port, num_uvicorn_servers, log_level, sockets) + if sys.platform == "win32": + uvicorn_worker_type = "thread" + elif uvicorn_worker_type is None: + uvicorn_worker_type = "process" + + servers = self._start_server(port, num_uvicorn_servers, log_level, sockets, uvicorn_worker_type) for s in servers: s.join() @@ -706,7 +713,7 @@ def run( w.join() manager.shutdown() - def _start_server(self, port, num_uvicorn_servers, log_level, sockets): + def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_worker_type): servers = [] for response_queue_id in range(num_uvicorn_servers): self.app.response_queue_id = response_queue_id @@ -716,8 +723,13 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets): config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level) server = uvicorn.Server(config=config) - ctx = mp.get_context("fork") - w = ctx.Process(target=server.run, args=(sockets,)) + if uvicorn_worker_type == "process": + ctx = mp.get_context("fork") + w = ctx.Process(target=server.run, args=(sockets,)) + elif uvicorn_worker_type == "thread": + w = threading.Thread(target=server.run, args=(sockets,)) + else: + raise ValueError("Invalid value for uvicorn_worker_type. 
Must be 'process' or 'thread'") w.start() servers.append(w) return servers From e6e04ff307aef8a0f8f5e54673ee45adb95c6be0 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Thu, 1 Aug 2024 22:08:38 +0100 Subject: [PATCH 42/45] Update test.txt Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- _requirements/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_requirements/test.txt b/_requirements/test.txt index e75bcda0..1cc8e437 100644 --- a/_requirements/test.txt +++ b/_requirements/test.txt @@ -11,4 +11,4 @@ lightning >2.0.0 torch >2.0.0 transformers openai>=1.12.0 -numpy==1.26.4 +numpy <2.0 From 264486f051fbeb9f985c1d6bfc45d861570ad783 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Fri, 2 Aug 2024 12:26:15 +0200 Subject: [PATCH 43/45] Update src/litserve/server.py --- src/litserve/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index e5a948ca..d72c8caa 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -349,7 +349,6 @@ def inference_worker( print(f"Setup complete for worker {worker_id}.") - # lit_api.worker_id = worker_id if workers_setup_status: workers_setup_status[worker_id] = True From 2941949d54631de3293f9c1d72921ba1f6df5e1b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Aug 2024 10:26:21 +0000 Subject: [PATCH 44/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litserve/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index d72c8caa..9179d899 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -349,7 +349,6 @@ def inference_worker( print(f"Setup complete for worker {worker_id}.") - if workers_setup_status: workers_setup_status[worker_id] = True From 7597b160932023c206f45dc86cc871922870851a Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Fri, 2 Aug 2024 10:32:57 +0000 Subject: [PATCH 45/45] upgrade Python --- .github/workflows/ci-minimal-dependency-check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-minimal-dependency-check.yml b/.github/workflows/ci-minimal-dependency-check.yml index d6e75471..5d8b531c 100644 --- a/.github/workflows/ci-minimal-dependency-check.yml +++ b/.github/workflows/ci-minimal-dependency-check.yml @@ -21,7 +21,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: "3.10" - name: Install LitServe run: |