diff --git a/src/litserve/__init__.py b/src/litserve/__init__.py
index b3cc773e..41f3efea 100644
--- a/src/litserve/__init__.py
+++ b/src/litserve/__init__.py
@@ -16,5 +16,6 @@
 from litserve.server import LitServer, Request, Response
 from litserve import test_examples
 from litserve.specs.openai import OpenAISpec
+from litserve.callbacks import Callback
 
-__all__ = ["LitAPI", "LitServer", "Request", "Response", "test_examples", "OpenAISpec"]
+__all__ = ["LitAPI", "LitServer", "Request", "Response", "test_examples", "OpenAISpec", "Callback"]
diff --git a/src/litserve/callbacks/__init__.py b/src/litserve/callbacks/__init__.py
new file mode 100644
index 00000000..6c67a5de
--- /dev/null
+++ b/src/litserve/callbacks/__init__.py
@@ -0,0 +1,3 @@
+from .base import Callback, CallbackRunner, EventTypes, NoopCallback
+
+__all__ = ["Callback", "CallbackRunner", "EventTypes", "NoopCallback"]
diff --git a/src/litserve/callbacks/base.py b/src/litserve/callbacks/base.py
new file mode 100644
index 00000000..df3fb368
--- /dev/null
+++ b/src/litserve/callbacks/base.py
@@ -0,0 +1,80 @@
+import dataclasses
+import logging
+from abc import ABC
+from typing import List, Union
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class EventTypes:
+    BEFORE_SETUP = "on_before_setup"
+    AFTER_SETUP = "on_after_setup"
+    BEFORE_DECODE_REQUEST = "on_before_decode_request"
+    AFTER_DECODE_REQUEST = "on_after_decode_request"
+    BEFORE_ENCODE_RESPONSE = "on_before_encode_response"
+    AFTER_ENCODE_RESPONSE = "on_after_encode_response"
+    BEFORE_PREDICT = "on_before_predict"
+    AFTER_PREDICT = "on_after_predict"
+    ON_SERVER_START = "on_server_start"
+    ON_SERVER_END = "on_server_end"
+
+
+class Callback(ABC):
+    def on_before_setup(self, *args, **kwargs):
+        """Called before setup is started."""
+
+    def on_after_setup(self, *args, **kwargs):
+        """Called after setup is completed."""
+
+    def on_before_decode_request(self, *args, **kwargs):
+        """Called before request decoding is started."""
+
+    def on_after_decode_request(self, *args, **kwargs):
+        """Called after request decoding is completed."""
+
+    def on_before_encode_response(self, *args, **kwargs):
+        """Called before response encoding is started."""
+
+    def on_after_encode_response(self, *args, **kwargs):
+        """Called after response encoding is completed."""
+
+    def on_before_predict(self, *args, **kwargs):
+        """Called before prediction is started."""
+
+    def on_after_predict(self, *args, **kwargs):
+        """Called after prediction is completed."""
+
+    def on_server_start(self, *args, **kwargs):
+        """Called before server starts."""
+
+    def on_server_end(self, *args, **kwargs):
+        """Called when server terminates."""
+
+
+class CallbackRunner:
+    def __init__(self, callbacks: Union[Callback, List[Callback]] = None):
+        self._callbacks = []
+        if callbacks:
+            self._add_callbacks(callbacks)
+
+    def _add_callbacks(self, callbacks: Union[Callback, List[Callback]]):
+        if not isinstance(callbacks, list):
+            callbacks = [callbacks]
+        for callback in callbacks:
+            if not isinstance(callback, Callback):
+                raise ValueError(f"Invalid callback type: {callback}")
+        self._callbacks.extend(callbacks)
+
+    def trigger_event(self, event_name, *args, **kwargs):
+        """Triggers an event, invoking all registered callbacks for that event."""
+        for callback in self._callbacks:
+            try:
+                getattr(callback, event_name)(*args, **kwargs)
+            except Exception:
+                # Handle exceptions to prevent one callback from disrupting others
+                logger.exception(f"Error in callback '{callback}' during event '{event_name}'")
+
+
+class NoopCallback(Callback):
+    """This callback does nothing."""
diff --git a/src/litserve/callbacks/defaults/__init__.py b/src/litserve/callbacks/defaults/__init__.py
new file mode 100644
index 00000000..f72f34d0
--- /dev/null
+++ b/src/litserve/callbacks/defaults/__init__.py
@@ -0,0 +1,3 @@
+from litserve.callbacks.defaults.metric_callback import PredictionTimeLogger
+
+__all__ = ["PredictionTimeLogger"]
diff --git a/src/litserve/callbacks/defaults/metric_callback.py b/src/litserve/callbacks/defaults/metric_callback.py
new file mode 100644
index 00000000..c7831d93
--- /dev/null
+++ b/src/litserve/callbacks/defaults/metric_callback.py
@@ -0,0 +1,21 @@
+import time
+import typing
+from logging import getLogger
+
+from ..base import Callback
+
+if typing.TYPE_CHECKING:
+    from litserve import LitAPI
+
+logger = getLogger(__name__)
+
+
+class PredictionTimeLogger(Callback):
+    def on_before_predict(self, lit_api: "LitAPI"):
+        t0 = time.perf_counter()
+        self._start_time = t0
+
+    def on_after_predict(self, lit_api: "LitAPI"):
+        t1 = time.perf_counter()
+        elapsed = t1 - self._start_time
+        print(f"Prediction took {elapsed:.2f} seconds", flush=True)
diff --git a/src/litserve/loops.py b/src/litserve/loops.py
index 1c73b276..f6537cd8 100644
--- a/src/litserve/loops.py
+++ b/src/litserve/loops.py
@@ -25,6 +25,7 @@
 from starlette.formparsers import MultiPartParser
 
 from litserve import LitAPI
+from litserve.callbacks import CallbackRunner, EventTypes
 from litserve.specs.base import LitSpec
 from litserve.utils import LitAPIStatus
 
@@ -93,7 +94,13 @@ def collate_requests(
     return payloads, timed_out_uids
 
 
-def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
+def run_single_loop(
+    lit_api: LitAPI,
+    lit_spec: LitSpec,
+    request_queue: Queue,
+    response_queues: List[Queue],
+    callback_runner: CallbackRunner,
+):
     while True:
         try:
             response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
@@ -114,21 +121,31 @@ def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, re
             context = {}
             if hasattr(lit_spec, "populate_context"):
                 lit_spec.populate_context(context, x_enc)
+
+            callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
             x = _inject_context(
                 context,
                 lit_api.decode_request,
                 x_enc,
             )
+            callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
+
+            callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
             y = _inject_context(
                 context,
                 lit_api.predict,
                 x,
             )
+            callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
+
+            callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
             y_enc = _inject_context(
                 context,
                 lit_api.encode_response,
                 y,
             )
+            callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
+
             response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
         except Exception as e:
             logger.exception(
@@ -147,6 +164,7 @@ def run_batched_loop(
     response_queues: List[Queue],
     max_batch_size: int,
     batch_timeout: float,
+    callback_runner: CallbackRunner,
 ):
     while True:
         batches, timed_out_uids = collate_requests(
@@ -174,6 +192,7 @@ def run_batched_loop(
                 for input, context in zip(inputs, contexts):
                     lit_spec.populate_context(context, input)
 
+            callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
             x = [
                 _inject_context(
                     context,
@@ -182,12 +201,24 @@ def run_batched_loop(
                 )
                 for input, context in zip(inputs, contexts)
             ]
+            callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
+
             x = lit_api.batch(x)
+
+            callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
             y = _inject_context(contexts, lit_api.predict, x)
+            callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
+
             outputs = lit_api.unbatch(y)
+
+            callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
+            y_enc_list = []
             for response_queue_id, y, uid, context in zip(response_queue_ids, outputs, uids, contexts):
                 y_enc = _inject_context(context, lit_api.encode_response, y)
+                y_enc_list.append((response_queue_id, uid, y_enc))
+            callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
+            for response_queue_id, uid, y_enc in y_enc_list:
                 response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
         except Exception as e:
@@ -200,7 +231,13 @@ def run_batched_loop(
             response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))
 
 
-def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
+def run_streaming_loop(
+    lit_api: LitAPI,
+    lit_spec: LitSpec,
+    request_queue: Queue,
+    response_queues: List[Queue],
+    callback_runner: CallbackRunner,
+):
     while True:
         try:
             response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
@@ -228,11 +265,15 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue,
                 lit_api.decode_request,
                 x_enc,
             )
+
+            callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
             y_gen = _inject_context(
                 context,
                 lit_api.predict,
                 x,
             )
+            callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
+
             y_enc_gen = _inject_context(
                 context,
                 lit_api.encode_response,
@@ -258,6 +299,7 @@ def run_batched_streaming_loop(
     response_queues: List[Queue],
     max_batch_size: int,
     batch_timeout: float,
+    callback_runner: CallbackRunner,
 ):
     while True:
         batches, timed_out_uids = collate_requests(
@@ -283,6 +325,7 @@ def run_batched_streaming_loop(
                 for input, context in zip(inputs, contexts):
                     lit_spec.populate_context(context, input)
 
+            callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
             x = [
                 _inject_context(
                     context,
@@ -291,10 +334,19 @@ def run_batched_streaming_loop(
                 )
                 for input, context in zip(inputs, contexts)
             ]
+            callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
+
             x = lit_api.batch(x)
+
+            callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
             y_iter = _inject_context(contexts, lit_api.predict, x)
+            callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
+
             unbatched_iter = lit_api.unbatch(y_iter)
+
+            callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
             y_enc_iter = _inject_context(contexts, lit_api.encode_response, unbatched_iter)
+            callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
 
             # y_enc_iter -> [[response-1, response-2], [response-1, response-2]]
             for y_batch in y_enc_iter:
@@ -324,10 +376,13 @@ def inference_worker(
     max_batch_size: int,
     batch_timeout: float,
     stream: bool,
-    workers_setup_status: Dict[str, bool] = None,
+    workers_setup_status: Dict[str, bool],
+    callback_runner: CallbackRunner,
 ):
+    callback_runner.trigger_event(EventTypes.BEFORE_SETUP, lit_api=lit_api)
     lit_api.setup(device)
     lit_api.device = device
+    callback_runner.trigger_event(EventTypes.AFTER_SETUP, lit_api=lit_api)
 
     print(f"Setup complete for worker {worker_id}.")
 
@@ -338,17 +393,16 @@ def inference_worker(
         logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec")
 
     if stream:
         if max_batch_size > 1:
-            run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
+            run_batched_streaming_loop(
+                lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout, callback_runner
+            )
         else:
-            run_streaming_loop(lit_api, lit_spec, request_queue, response_queues)
+            run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
         return
 
     if max_batch_size > 1:
-        run_batched_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
-    else:
-        run_single_loop(
-            lit_api,
-            lit_spec,
-            request_queue,
-            response_queues,
+        run_batched_loop(
+            lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout, callback_runner
         )
+    else:
+        run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
diff --git a/src/litserve/server.py b/src/litserve/server.py
index e5200dd6..289b720a 100644
--- a/src/litserve/server.py
+++ b/src/litserve/server.py
@@ -27,7 +27,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager
 from queue import Empty
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union, List
 
 import uvicorn
 from fastapi import Depends, FastAPI, HTTPException, Request, Response
@@ -37,6 +37,7 @@
 from starlette.middleware.gzip import GZipMiddleware
 
 from litserve import LitAPI
+from litserve.callbacks.base import CallbackRunner, Callback, EventTypes
 from litserve.connector import _Connector
 from litserve.loops import inference_worker
 from litserve.specs import OpenAISpec
@@ -113,6 +114,7 @@ def __init__(
         stream: bool = False,
         spec: Optional[LitSpec] = None,
         max_payload_size=None,
+        callbacks: Optional[Union[List[Callback], Callback]] = None,
         middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None,
     ):
         if batch_timeout > timeout and timeout not in (False, -1):
@@ -171,6 +173,7 @@ def __init__(
         self.stream = stream
         self.max_payload_size = max_payload_size
         self._connector = _Connector(accelerator=accelerator, devices=devices)
+        self._callback_runner = CallbackRunner(callbacks)
 
         specs = spec if spec is not None else []
         self._specs = specs if isinstance(specs, Sequence) else [specs]
@@ -234,6 +237,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
                     self.batch_timeout,
                     self.stream,
                     self.workers_setup_status,
+                    self._callback_runner,
                 ),
             )
             process.start()
@@ -258,6 +262,7 @@ async def lifespan(self, app: FastAPI):
 
         yield
 
+        self._callback_runner.trigger_event(EventTypes.ON_SERVER_END, litserver=self)
         task.cancel()
         logger.debug("Shutting down response queue to buffer task")
 
@@ -290,6 +295,7 @@ async def data_streamer(self, q: deque, data_available: asyncio.Event, send_stat
 
     def register_endpoints(self):
         """Register endpoint routes for the FastAPI app and setup middlewares."""
+        self._callback_runner.trigger_event(EventTypes.ON_SERVER_START, litserver=self)
         workers_ready = False
 
         @self.app.get("/", dependencies=[Depends(self.setup_auth())])
diff --git a/tests/test_batch.py b/tests/test_batch.py
index 1e9706db..5a1d1a1f 100644
--- a/tests/test_batch.py
+++ b/tests/test_batch.py
@@ -24,10 +24,13 @@
 from httpx import AsyncClient
 
 from litserve import LitAPI, LitServer
+from litserve.callbacks import CallbackRunner
 from litserve.loops import run_batched_loop, collate_requests
 from litserve.utils import wrap_litserve_start
 import litserve as ls
 
+NOOP_CB_RUNNER = CallbackRunner()
+
 
 class Linear(nn.Module):
     def __init__(self):
@@ -166,7 +169,13 @@ def test_batched_loop():
     with patch("pickle.dumps", side_effect=StopIteration("exit loop")), pytest.raises(StopIteration, match="exit loop"):
         run_batched_loop(
-            lit_api_mock, lit_api_mock, requests_queue, FakeResponseQueue(), max_batch_size=2, batch_timeout=4
+            lit_api_mock,
+            lit_api_mock,
+            requests_queue,
+            FakeResponseQueue(),
+            max_batch_size=2,
+            batch_timeout=4,
+            callback_runner=NOOP_CB_RUNNER,
         )
 
     lit_api_mock.batch.assert_called_once()
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
new file mode 100644
index 00000000..d6a12a02
--- /dev/null
+++ b/tests/test_callbacks.py
@@ -0,0 +1,43 @@
+import re
+
+import litserve as ls
+from fastapi.testclient import TestClient
+
+from litserve.callbacks import CallbackRunner, EventTypes
+from litserve.callbacks.defaults import PredictionTimeLogger
+from litserve.utils import wrap_litserve_start
+
+
+def test_callback_runner():
+    cb_runner = CallbackRunner()
+    assert cb_runner._callbacks == [], "Callbacks list must be empty"
+
+    cb = PredictionTimeLogger()
+    cb_runner._add_callbacks(cb)
+    assert cb_runner._callbacks == [cb], "Callback not added to runner"
+
+
+def test_callback(capfd):
+    lit_api = ls.test_examples.SimpleLitAPI()
+    server = ls.LitServer(lit_api, callbacks=[PredictionTimeLogger()])
+
+    with wrap_litserve_start(server) as server, TestClient(server.app) as client:
+        response = client.post("/predict", json={"input": 4.0})
+        assert response.json() == {"output": 16.0}
+
+    captured = capfd.readouterr()
+    pattern = r"Prediction took \d+\.\d{2} seconds"
+    assert re.search(pattern, captured.out), f"Expected pattern not found in output: {captured.out}"
+
+
+def test_metric_logger(capfd):
+    cb = PredictionTimeLogger()
+    cb_runner = CallbackRunner()
+    cb_runner._add_callbacks(cb)
+    assert cb_runner._callbacks == [cb], "Callback not added to runner"
+    cb_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=None)
+    cb_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=None)
+
+    captured = capfd.readouterr()
+    pattern = r"Prediction took \d+\.\d{2} seconds"
+    assert re.search(pattern, captured.out), f"Expected pattern not found in output: {captured.out}"
diff --git a/tests/test_loops.py b/tests/test_loops.py
index 2f032fb7..7c34c19a 100644
--- a/tests/test_loops.py
+++ b/tests/test_loops.py
@@ -22,6 +22,7 @@
 import pytest
 from fastapi import HTTPException
 
+from litserve.callbacks import CallbackRunner
 from litserve.loops import (
     run_single_loop,
     run_streaming_loop,
@@ -33,6 +34,8 @@
 from litserve.utils import LitAPIStatus
 import litserve as ls
 
+NOOP_CB_RUNNER = CallbackRunner()
+
 
 @pytest.fixture
 def loop_args():
@@ -57,7 +60,7 @@ def test_single_loop(loop_args):
     response_queues = [FakeResponseQueue()]
 
     with pytest.raises(StopIteration, match="exit loop"):
-        run_single_loop(lit_api_mock, None, requests_queue, response_queues)
+        run_single_loop(lit_api_mock, None, requests_queue, response_queues, callback_runner=NOOP_CB_RUNNER)
 
 
 class FakeStreamResponseQueue:
@@ -98,7 +101,9 @@ def fake_encode(output):
     response_queues = [FakeStreamResponseQueue(num_streamed_outputs)]
 
     with pytest.raises(StopIteration, match="exit loop"):
-        run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues)
+        run_streaming_loop(
+            fake_stream_api, fake_stream_api, requests_queue, response_queues, callback_runner=NOOP_CB_RUNNER
+        )
 
     fake_stream_api.predict.assert_called_once_with("Hello")
     fake_stream_api.encode_response.assert_called_once()
@@ -156,7 +161,13 @@ def fake_encode(output_iter):
     with pytest.raises(StopIteration, match="finish streaming"):
         run_batched_streaming_loop(
-            fake_stream_api, fake_stream_api, requests_queue, response_queues, max_batch_size=2, batch_timeout=2
+            fake_stream_api,
+            fake_stream_api,
+            requests_queue,
+            response_queues,
+            max_batch_size=2,
+            batch_timeout=2,
+            callback_runner=NOOP_CB_RUNNER,
         )
     fake_stream_api.predict.assert_called_once_with(["Hello", "World"])
     fake_stream_api.encode_response.assert_called_once()
@@ -165,10 +176,24 @@ def fake_encode(output_iter):
 @patch("litserve.loops.run_batched_loop")
 @patch("litserve.loops.run_single_loop")
 def test_inference_worker(mock_single_loop, mock_batched_loop):
-    inference_worker(*[MagicMock()] * 6, max_batch_size=2, batch_timeout=0, stream=False)
+    inference_worker(
+        *[MagicMock()] * 6,
+        max_batch_size=2,
+        batch_timeout=0,
+        stream=False,
+        workers_setup_status={},
+        callback_runner=NOOP_CB_RUNNER,
+    )
     mock_batched_loop.assert_called_once()
 
-    inference_worker(*[MagicMock()] * 6, max_batch_size=1, batch_timeout=0, stream=False)
+    inference_worker(
+        *[MagicMock()] * 6,
+        max_batch_size=1,
+        batch_timeout=0,
+        stream=False,
+        workers_setup_status={},
+        callback_runner=NOOP_CB_RUNNER,
+    )
     mock_single_loop.assert_called_once()
@@ -182,7 +207,9 @@ def test_run_single_loop():
     response_queues = [Queue()]
 
     # Run the loop in a separate thread to allow it to be stopped
-    loop_thread = threading.Thread(target=run_single_loop, args=(lit_api, None, request_queue, response_queues))
+    loop_thread = threading.Thread(
+        target=run_single_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER)
+    )
     loop_thread.start()
 
     # Allow some time for the loop to process
@@ -208,7 +235,9 @@ def test_run_single_loop_timeout(caplog):
     response_queues = [Queue()]
 
     # Run the loop in a separate thread to allow it to be stopped
-    loop_thread = threading.Thread(target=run_single_loop, args=(lit_api, None, request_queue, response_queues))
+    loop_thread = threading.Thread(
+        target=run_single_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER)
+    )
     loop_thread.start()
 
     request_queue.put((None, None, None, None))
@@ -231,7 +260,9 @@ def test_run_batched_loop():
     response_queues = [Queue()]
 
     # Run the loop in a separate thread to allow it to be stopped
-    loop_thread = threading.Thread(target=run_batched_loop, args=(lit_api, None, request_queue, response_queues, 2, 1))
+    loop_thread = threading.Thread(
+        target=run_batched_loop, args=(lit_api, None, request_queue, response_queues, 2, 1, NOOP_CB_RUNNER)
+    )
     loop_thread.start()
 
     # Allow some time for the loop to process
@@ -265,7 +296,7 @@ def test_run_batched_loop_timeout(caplog):
 
     # Run the loop in a separate thread to allow it to be stopped
     loop_thread = threading.Thread(
-        target=run_batched_loop, args=(lit_api, None, request_queue, response_queues, 2, 0.001)
+        target=run_batched_loop, args=(lit_api, None, request_queue, response_queues, 2, 0.001, NOOP_CB_RUNNER)
     )
     loop_thread.start()
 
@@ -293,7 +324,9 @@ def test_run_streaming_loop():
     response_queues = [Queue()]
 
     # Run the loop in a separate thread to allow it to be stopped
-    loop_thread = threading.Thread(target=run_streaming_loop, args=(lit_api, None, request_queue, response_queues))
+    loop_thread = threading.Thread(
+        target=run_streaming_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER)
+    )
     loop_thread.start()
 
     # Allow some time for the loop to process
@@ -319,7 +352,9 @@ def test_run_streaming_loop_timeout(caplog):
     response_queues = [Queue()]
 
     # Run the loop in a separate thread to allow it to be stopped
-    loop_thread = threading.Thread(target=run_streaming_loop, args=(lit_api, None, request_queue, response_queues))
+    loop_thread = threading.Thread(
+        target=run_streaming_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER)
+    )
     loop_thread.start()
 
     # Allow some time for the loop to process
@@ -352,7 +387,7 @@ def off_test_run_batched_streaming_loop(openai_request_data):
 
     # Run the loop in a separate thread to allow it to be stopped
     loop_thread = threading.Thread(
-        target=run_batched_streaming_loop, args=(lit_api, spec, request_queue, response_queues, 2, 0.1)
+        target=run_batched_streaming_loop, args=(lit_api, spec, request_queue, response_queues, 2, 0.1, NOOP_CB_RUNNER)
     )
     loop_thread.start()
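
Usage sketch (illustrative only, not part of this diff): a user-defined callback subclasses Callback, overrides any of the hooks declared in src/litserve/callbacks/base.py, and is handed to LitServer through the new callbacks argument. The RequestCounter class and the port number below are hypothetical, chosen only to show the wiring; the bundled PredictionTimeLogger from litserve.callbacks.defaults is used the same way.

    import litserve as ls
    from litserve import Callback


    class RequestCounter(Callback):
        """Hypothetical callback: counts predictions served by this worker."""

        def on_after_setup(self, lit_api):
            # The loop triggers hooks with the LitAPI instance as a keyword argument.
            self.count = 0

        def on_after_predict(self, lit_api):
            self.count += 1
            print(f"Predictions served: {self.count}", flush=True)


    if __name__ == "__main__":
        api = ls.test_examples.SimpleLitAPI()
        server = ls.LitServer(api, callbacks=[RequestCounter()])
        server.run(port=8000)

Because callbacks run inside the inference worker, any output they produce appears in the worker process, which is why the tests above capture it with capfd.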