Callback system v0 (#278)
* init

* connector

* update

* update

* fix tests

* cleanup

* update

* fix tests

* apply callbacks to all the loops

* update

* fix tests

* add test

* fix test

* rename callback

* fix docstrings

* refactor hook names

* handle None

* no root logger

* apply feedback
aniketmaurya committed Sep 18, 2024
1 parent cf78c85 commit 4d2ac63
Showing 10 changed files with 282 additions and 27 deletions.
3 changes: 2 additions & 1 deletion src/litserve/__init__.py
@@ -16,5 +16,6 @@
from litserve.server import LitServer, Request, Response
from litserve import test_examples
from litserve.specs.openai import OpenAISpec
from litserve.callbacks import Callback

-__all__ = ["LitAPI", "LitServer", "Request", "Response", "test_examples", "OpenAISpec"]
+__all__ = ["LitAPI", "LitServer", "Request", "Response", "test_examples", "OpenAISpec", "Callback"]
3 changes: 3 additions & 0 deletions src/litserve/callbacks/__init__.py
@@ -0,0 +1,3 @@
from .base import Callback, CallbackRunner, EventTypes, NoopCallback

__all__ = ["Callback", "CallbackRunner", "EventTypes", "NoopCallback"]
80 changes: 80 additions & 0 deletions src/litserve/callbacks/base.py
@@ -0,0 +1,80 @@
import dataclasses
import logging
from abc import ABC
from typing import List, Optional, Union

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class EventTypes:
    BEFORE_SETUP = "on_before_setup"
    AFTER_SETUP = "on_after_setup"
    BEFORE_DECODE_REQUEST = "on_before_decode_request"
    AFTER_DECODE_REQUEST = "on_after_decode_request"
    BEFORE_ENCODE_RESPONSE = "on_before_encode_response"
    AFTER_ENCODE_RESPONSE = "on_after_encode_response"
    BEFORE_PREDICT = "on_before_predict"
    AFTER_PREDICT = "on_after_predict"
    ON_SERVER_START = "on_server_start"
    ON_SERVER_END = "on_server_end"


class Callback(ABC):
    def on_before_setup(self, *args, **kwargs):
        """Called before setup is started."""

    def on_after_setup(self, *args, **kwargs):
        """Called after setup is completed."""

    def on_before_decode_request(self, *args, **kwargs):
        """Called before request decoding is started."""

    def on_after_decode_request(self, *args, **kwargs):
        """Called after request decoding is completed."""

    def on_before_encode_response(self, *args, **kwargs):
        """Called before response encoding is started."""

    def on_after_encode_response(self, *args, **kwargs):
        """Called after response encoding is completed."""

    def on_before_predict(self, *args, **kwargs):
        """Called before prediction is started."""

    def on_after_predict(self, *args, **kwargs):
        """Called after prediction is completed."""

    def on_server_start(self, *args, **kwargs):
        """Called before server starts."""

    def on_server_end(self, *args, **kwargs):
        """Called when server terminates."""


class CallbackRunner:
    def __init__(self, callbacks: Optional[Union[Callback, List[Callback]]] = None):
        self._callbacks = []
        if callbacks:
            self._add_callbacks(callbacks)

    def _add_callbacks(self, callbacks: Union[Callback, List[Callback]]):
        if not isinstance(callbacks, list):
            callbacks = [callbacks]
        for callback in callbacks:
            if not isinstance(callback, Callback):
                raise ValueError(f"Invalid callback type: {callback}")
        self._callbacks.extend(callbacks)

    def trigger_event(self, event_name, *args, **kwargs):
        """Triggers an event, invoking all registered callbacks for that event."""
        for callback in self._callbacks:
            try:
                getattr(callback, event_name)(*args, **kwargs)
            except Exception:
                # Handle exceptions to prevent one callback from disrupting others
                logger.exception(f"Error in callback '{callback}' during event '{event_name}'")


class NoopCallback(Callback):
    """This callback does nothing."""
3 changes: 3 additions & 0 deletions src/litserve/callbacks/defaults/__init__.py
@@ -0,0 +1,3 @@
from litserve.callbacks.defaults.metric_callback import PredictionTimeLogger

__all__ = ["PredictionTimeLogger"]
21 changes: 21 additions & 0 deletions src/litserve/callbacks/defaults/metric_callback.py
@@ -0,0 +1,21 @@
import time
import typing
from logging import getLogger

from ..base import Callback

if typing.TYPE_CHECKING:
from litserve import LitAPI

logger = getLogger(__name__)


class PredictionTimeLogger(Callback):
    def on_before_predict(self, lit_api: "LitAPI"):
        t0 = time.perf_counter()
        self._start_time = t0

    def on_after_predict(self, lit_api: "LitAPI"):
        t1 = time.perf_counter()
        elapsed = t1 - self._start_time
        print(f"Prediction took {elapsed:.2f} seconds", flush=True)
78 changes: 66 additions & 12 deletions src/litserve/loops.py
@@ -25,6 +25,7 @@
from starlette.formparsers import MultiPartParser

from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus

@@ -93,7 +94,13 @@ def collate_requests(
    return payloads, timed_out_uids


-def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
+def run_single_loop(
+    lit_api: LitAPI,
+    lit_spec: LitSpec,
+    request_queue: Queue,
+    response_queues: List[Queue],
+    callback_runner: CallbackRunner,
+):
    while True:
        try:
            response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
@@ -114,21 +121,31 @@ def run_single_loop
            context = {}
            if hasattr(lit_spec, "populate_context"):
                lit_spec.populate_context(context, x_enc)

            callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
            x = _inject_context(
                context,
                lit_api.decode_request,
                x_enc,
            )
            callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)

            callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
            y = _inject_context(
                context,
                lit_api.predict,
                x,
            )
            callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)

            callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
            y_enc = _inject_context(
                context,
                lit_api.encode_response,
                y,
            )
            callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)

            response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
        except Exception as e:
            logger.exception(
@@ -147,6 +164,7 @@ def run_batched_loop(
    response_queues: List[Queue],
    max_batch_size: int,
    batch_timeout: float,
    callback_runner: CallbackRunner,
):
    while True:
        batches, timed_out_uids = collate_requests(
@@ -174,6 +192,7 @@
                for input, context in zip(inputs, contexts):
                    lit_spec.populate_context(context, input)

            callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
            x = [
                _inject_context(
                    context,
@@ -182,12 +201,24 @@
                )
                for input, context in zip(inputs, contexts)
            ]
            callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)

            x = lit_api.batch(x)

            callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
            y = _inject_context(contexts, lit_api.predict, x)
            callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)

            outputs = lit_api.unbatch(y)

            callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
            y_enc_list = []
            for response_queue_id, y, uid, context in zip(response_queue_ids, outputs, uids, contexts):
                y_enc = _inject_context(context, lit_api.encode_response, y)
                y_enc_list.append((response_queue_id, uid, y_enc))
            callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)

            for response_queue_id, uid, y_enc in y_enc_list:
                response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))

        except Exception as e:
@@ -200,7 +231,13 @@ def run_batched_loop(
            response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))


-def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
+def run_streaming_loop(
+    lit_api: LitAPI,
+    lit_spec: LitSpec,
+    request_queue: Queue,
+    response_queues: List[Queue],
+    callback_runner: CallbackRunner,
+):
    while True:
        try:
            response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
@@ -228,11 +265,15 @@ def run_streaming_loop
                lit_api.decode_request,
                x_enc,
            )

            callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
            y_gen = _inject_context(
                context,
                lit_api.predict,
                x,
            )
            callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)

            y_enc_gen = _inject_context(
                context,
                lit_api.encode_response,
@@ -258,6 +299,7 @@ def run_batched_streaming_loop(
    response_queues: List[Queue],
    max_batch_size: int,
    batch_timeout: float,
    callback_runner: CallbackRunner,
):
    while True:
        batches, timed_out_uids = collate_requests(
@@ -283,6 +325,7 @@
                for input, context in zip(inputs, contexts):
                    lit_spec.populate_context(context, input)

            callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
            x = [
                _inject_context(
                    context,
@@ -291,10 +334,19 @@
                )
                for input, context in zip(inputs, contexts)
            ]
            callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)

            x = lit_api.batch(x)

            callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
            y_iter = _inject_context(contexts, lit_api.predict, x)
            callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)

            unbatched_iter = lit_api.unbatch(y_iter)

            callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
            y_enc_iter = _inject_context(contexts, lit_api.encode_response, unbatched_iter)
            callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)

            # y_enc_iter -> [[response-1, response-2], [response-1, response-2]]
            for y_batch in y_enc_iter:
@@ -324,10 +376,13 @@ def inference_worker(
    max_batch_size: int,
    batch_timeout: float,
    stream: bool,
-    workers_setup_status: Dict[str, bool] = None,
+    workers_setup_status: Dict[str, bool],
+    callback_runner: CallbackRunner,
):
    callback_runner.trigger_event(EventTypes.BEFORE_SETUP, lit_api=lit_api)
    lit_api.setup(device)
    lit_api.device = device
    callback_runner.trigger_event(EventTypes.AFTER_SETUP, lit_api=lit_api)

    print(f"Setup complete for worker {worker_id}.")

@@ -338,17 +393,16 @@
        logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec")
    if stream:
        if max_batch_size > 1:
-            run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
+            run_batched_streaming_loop(
+                lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout, callback_runner
+            )
        else:
-            run_streaming_loop(lit_api, lit_spec, request_queue, response_queues)
+            run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
        return

    if max_batch_size > 1:
-        run_batched_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
+        run_batched_loop(
+            lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout, callback_runner
+        )
    else:
-        run_single_loop(
-            lit_api,
-            lit_spec,
-            request_queue,
-            response_queues,
-        )
+        run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
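
As the loops above show, a non-streaming request fires hooks in a fixed order: BEFORE/AFTER_DECODE_REQUEST, then BEFORE/AFTER_PREDICT, then BEFORE/AFTER_ENCODE_RESPONSE, while run_streaming_loop triggers only the predict pair (and AFTER_PREDICT fires once the generator is created, not when it is exhausted). A callback can use that ordering to time individual stages; a hypothetical sketch (StageTimer is not part of this commit):

import time

from litserve.callbacks import Callback


class StageTimer(Callback):
    """Hypothetical callback that logs per-stage latency via the loop hooks."""

    def on_before_decode_request(self, lit_api=None, **kwargs):
        self._t0 = time.perf_counter()

    def on_after_decode_request(self, lit_api=None, **kwargs):
        print(f"decode_request took {time.perf_counter() - self._t0:.4f}s")

    def on_before_predict(self, lit_api=None, **kwargs):
        self._t0 = time.perf_counter()

    def on_after_predict(self, lit_api=None, **kwargs):
        print(f"predict took {time.perf_counter() - self._t0:.4f}s")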
8 changes: 7 additions & 1 deletion src/litserve/server.py
@@ -27,7 +27,7 @@
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from queue import Empty
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union, List

import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, Response
@@ -37,6 +37,7 @@
from starlette.middleware.gzip import GZipMiddleware

from litserve import LitAPI
from litserve.callbacks.base import CallbackRunner, Callback, EventTypes
from litserve.connector import _Connector
from litserve.loops import inference_worker
from litserve.specs import OpenAISpec
@@ -113,6 +114,7 @@ def __init__(
        stream: bool = False,
        spec: Optional[LitSpec] = None,
        max_payload_size=None,
        callbacks: Optional[Union[List[Callback], Callback]] = None,
        middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None,
    ):
        if batch_timeout > timeout and timeout not in (False, -1):
@@ -171,6 +173,7 @@ def __init__(
        self.stream = stream
        self.max_payload_size = max_payload_size
        self._connector = _Connector(accelerator=accelerator, devices=devices)
        self._callback_runner = CallbackRunner(callbacks)

        specs = spec if spec is not None else []
        self._specs = specs if isinstance(specs, Sequence) else [specs]
@@ -234,6 +237,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
                    self.batch_timeout,
                    self.stream,
                    self.workers_setup_status,
                    self._callback_runner,
                ),
            )
            process.start()
@@ -258,6 +262,7 @@ async def lifespan(self, app: FastAPI):

        yield

        self._callback_runner.trigger_event(EventTypes.ON_SERVER_END, litserver=self)
        task.cancel()
        logger.debug("Shutting down response queue to buffer task")

@@ -290,6 +295,7 @@ async def data_streamer

    def register_endpoints(self):
        """Register endpoint routes for the FastAPI app and setup middlewares."""
        self._callback_runner.trigger_event(EventTypes.ON_SERVER_START, litserver=self)
        workers_ready = False

        @self.app.get("/", dependencies=[Depends(self.setup_auth())])
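
Putting the server-level hooks together, an end-to-end sketch (LifecycleLogger is illustrative, not part of this commit; SimpleLitAPI is litserve's bundled test example):

import litserve as ls


class LifecycleLogger(ls.Callback):
    """Hypothetical callback observing server startup and shutdown."""

    def on_server_start(self, litserver=None, **kwargs):
        print("LitServer starting: endpoints are being registered")

    def on_server_end(self, litserver=None, **kwargs):
        print("LitServer shutting down")


if __name__ == "__main__":
    api = ls.test_examples.SimpleLitAPI()
    # CallbackRunner accepts a single Callback or a list of them.
    server = ls.LitServer(api, callbacks=LifecycleLogger())
    server.run(port=8000)

Per the diff, ON_SERVER_START is triggered from register_endpoints and ON_SERVER_END from the lifespan teardown, so both run in the main server process rather than in the inference workers.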