Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callback system v0 #278

Merged
merged 22 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/litserve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
from litserve.server import LitServer, Request, Response
from litserve import test_examples
from litserve.specs.openai import OpenAISpec
from litserve.callbacks import Callback

__all__ = ["LitAPI", "LitServer", "Request", "Response", "test_examples", "OpenAISpec"]
__all__ = ["LitAPI", "LitServer", "Request", "Response", "test_examples", "OpenAISpec", "Callback"]
3 changes: 3 additions & 0 deletions src/litserve/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base import Callback, CallbackRunner, EventTypes, NoopCallback

__all__ = ["Callback", "CallbackRunner", "EventTypes", "NoopCallback"]
76 changes: 76 additions & 0 deletions src/litserve/callbacks/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import dataclasses
import logging
from abc import ABC
from typing import List, Union


@dataclasses.dataclass
class EventTypes:
    """String identifiers for every callback hook.

    Each value is the name of the corresponding ``Callback`` method; the
    string is passed to ``CallbackRunner.trigger_event``, which looks the
    hook up by name on each registered callback.

    NOTE(review): these attributes carry no type annotations, so
    ``@dataclasses.dataclass`` registers no fields here — they behave as
    plain class-level string constants.
    """

    # LitAPI lifecycle hooks
    LITAPI_SETUP_START = "on_litapi_setup_start"
    LITAPI_SETUP_END = "on_litapi_setup_end"
    LITAPI_DECODE_REQUEST_START = "on_litapi_decode_request_start"
    LITAPI_DECODE_REQUEST_END = "on_litapi_decode_request_end"
    LITAPI_ENCODE_RESPONSE_START = "on_litapi_encode_response_start"
    LITAPI_ENCODE_RESPONSE_END = "on_litapi_encode_response_end"
    LITAPI_PREDICT_START = "on_litapi_predict_start"
    LITAPI_PREDICT_END = "on_litapi_predict_end"
    # LitServer lifecycle hooks
    SERVER_SETUP_START = "on_server_setup_start"
    SERVER_SETUP_END = "on_server_setup_end"


class Callback(ABC):
    """Base class for LitServe lifecycle hooks.

    Subclass and override any of the hooks below; each is a no-op by
    default, so implementations only define the events they care about.
    Hooks are invoked by name via ``CallbackRunner.trigger_event`` with
    event-specific keyword arguments (e.g. ``lit_api=`` for LitAPI events).
    """

    def on_litapi_predict_start(self, *args, **kwargs):
        """Called before LitAPI.predict() is called."""

    def on_litapi_predict_end(self, *args, **kwargs):
        """Called after LitAPI.predict() is called."""

    def on_litapi_decode_request_start(self, *args, **kwargs):
        """Called before LitAPI.decode_request() is called."""

    def on_litapi_decode_request_end(self, *args, **kwargs):
        """Called after LitAPI.decode_request() is called."""

    def on_litapi_encode_response_start(self, *args, **kwargs):
        """Called before LitAPI.encode_response() is called."""

    def on_litapi_encode_response_end(self, *args, **kwargs):
        """Called after LitAPI.encode_response() is called."""

    def on_litapi_setup_start(self, *args, **kwargs):
        """Called before LitAPI.setup() is called."""

    def on_litapi_setup_end(self, *args, **kwargs):
        """Called after LitAPI.setup() is called."""

    def on_server_setup_start(self, *args, **kwargs):
        """Called before LitServer.setup_server() is called."""

    def on_server_setup_end(self, *args, **kwargs):
        """Called after LitServer.setup_server() is called."""


class CallbackRunner:
    """Registers Callback objects and dispatches named events to them.

    Each event name is the name of a hook method on ``Callback`` (see
    ``EventTypes``); ``trigger_event`` looks the method up by name on every
    registered callback and calls it with the supplied arguments.
    """

    def __init__(self):
        # Registered callbacks, invoked in registration order.
        self._callbacks = []

    def add_callbacks(self, callbacks: Union["Callback", List["Callback"]]):
        """Register a single callback or a list of callbacks.

        ``None`` is accepted and ignored so callers can forward an optional
        ``callbacks`` argument unconditionally (LitServer's constructor
        defaults ``callbacks`` to ``None``); previously this appended
        ``None`` itself, making every subsequent event log an exception.
        """
        if callbacks is None:
            return
        if isinstance(callbacks, list):
            self._callbacks.extend(callbacks)
        else:
            self._callbacks.append(callbacks)

    def trigger_event(self, event_name, *args, **kwargs):
        """Triggers an event, invoking all registered callbacks for that event."""
        for callback in self._callbacks:
            try:
                getattr(callback, event_name)(*args, **kwargs)
            except Exception:
                # Handle exceptions to prevent one callback from disrupting others
                logging.exception(f"Error in callback '{callback}' during event '{event_name}'")


class NoopCallback(Callback):
    """A callback with no behavior; usable as a harmless placeholder."""
3 changes: 3 additions & 0 deletions src/litserve/callbacks/defaults/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from litserve.callbacks.defaults.metric_callback import PredictionTimeLogger

__all__ = ["PredictionTimeLogger"]
21 changes: 21 additions & 0 deletions src/litserve/callbacks/defaults/metric_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import time
import typing
from logging import getLogger

from ..base import Callback

if typing.TYPE_CHECKING:
from litserve import LitAPI

logger = getLogger(__name__)


class PredictionTimeLogger(Callback):
    """Callback that prints the wall-clock duration of each LitAPI.predict() call."""

    def on_litapi_predict_start(self, lit_api: "LitAPI"):
        # Record when predict() begins; perf_counter is monotonic, so the
        # difference is a valid elapsed time.
        self._start_time = time.perf_counter()

    def on_litapi_predict_end(self, lit_api: "LitAPI"):
        # Guard: the end event can fire without a matching start (e.g. the
        # callback was registered mid-request); previously this raised
        # AttributeError on the missing ``_start_time``.
        start = getattr(self, "_start_time", None)
        if start is None:
            return
        elapsed = time.perf_counter() - start
        print(f"Prediction took {elapsed:.2f} seconds", flush=True)
78 changes: 66 additions & 12 deletions src/litserve/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from starlette.formparsers import MultiPartParser

from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus

Expand Down Expand Up @@ -93,7 +94,13 @@ def collate_requests(
return payloads, timed_out_uids


def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
def run_single_loop(
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
callback_runner: CallbackRunner,
):
while True:
try:
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
Expand All @@ -114,21 +121,31 @@ def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, re
context = {}
if hasattr(lit_spec, "populate_context"):
lit_spec.populate_context(context, x_enc)

callback_runner.trigger_event(EventTypes.LITAPI_DECODE_REQUEST_START, lit_api=lit_api)
x = _inject_context(
context,
lit_api.decode_request,
x_enc,
)
callback_runner.trigger_event(EventTypes.LITAPI_DECODE_REQUEST_END, lit_api=lit_api)

callback_runner.trigger_event(EventTypes.LITAPI_PREDICT_START, lit_api=lit_api)
y = _inject_context(
context,
lit_api.predict,
x,
)
callback_runner.trigger_event(EventTypes.LITAPI_PREDICT_END, lit_api=lit_api)

callback_runner.trigger_event(EventTypes.LITAPI_ENCODE_RESPONSE_START, lit_api=lit_api)
y_enc = _inject_context(
context,
lit_api.encode_response,
y,
)
callback_runner.trigger_event(EventTypes.LITAPI_ENCODE_RESPONSE_END, lit_api=lit_api)

response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
except Exception as e:
logger.exception(
Expand All @@ -147,6 +164,7 @@ def run_batched_loop(
response_queues: List[Queue],
max_batch_size: int,
batch_timeout: float,
callback_runner: CallbackRunner,
):
while True:
batches, timed_out_uids = collate_requests(
Expand Down Expand Up @@ -174,6 +192,7 @@ def run_batched_loop(
for input, context in zip(inputs, contexts):
lit_spec.populate_context(context, input)

callback_runner.trigger_event(EventTypes.LITAPI_DECODE_REQUEST_START, lit_api=lit_api)
x = [
_inject_context(
context,
Expand All @@ -182,12 +201,24 @@ def run_batched_loop(
)
for input, context in zip(inputs, contexts)
]
callback_runner.trigger_event(EventTypes.LITAPI_DECODE_REQUEST_END, lit_api=lit_api)

x = lit_api.batch(x)

callback_runner.trigger_event(EventTypes.LITAPI_PREDICT_START, lit_api=lit_api)
y = _inject_context(contexts, lit_api.predict, x)
callback_runner.trigger_event(EventTypes.LITAPI_PREDICT_END, lit_api=lit_api)

outputs = lit_api.unbatch(y)

callback_runner.trigger_event(EventTypes.LITAPI_ENCODE_RESPONSE_START, lit_api=lit_api)
y_enc_list = []
for response_queue_id, y, uid, context in zip(response_queue_ids, outputs, uids, contexts):
y_enc = _inject_context(context, lit_api.encode_response, y)
y_enc_list.append((response_queue_id, uid, y_enc))
callback_runner.trigger_event(EventTypes.LITAPI_ENCODE_RESPONSE_END, lit_api=lit_api)

for response_queue_id, uid, y_enc in y_enc_list:
response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))

except Exception as e:
Expand All @@ -200,7 +231,13 @@ def run_batched_loop(
response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))


def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
def run_streaming_loop(
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
callback_runner: CallbackRunner,
):
while True:
try:
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
Expand Down Expand Up @@ -228,11 +265,15 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue,
lit_api.decode_request,
x_enc,
)

callback_runner.trigger_event(EventTypes.LITAPI_PREDICT_START, lit_api=lit_api)
y_gen = _inject_context(
context,
lit_api.predict,
x,
)
callback_runner.trigger_event(EventTypes.LITAPI_PREDICT_END, lit_api=lit_api)

y_enc_gen = _inject_context(
context,
lit_api.encode_response,
Expand All @@ -258,6 +299,7 @@ def run_batched_streaming_loop(
response_queues: List[Queue],
max_batch_size: int,
batch_timeout: float,
callback_runner: CallbackRunner,
):
while True:
batches, timed_out_uids = collate_requests(
Expand All @@ -283,6 +325,7 @@ def run_batched_streaming_loop(
for input, context in zip(inputs, contexts):
lit_spec.populate_context(context, input)

callback_runner.trigger_event(EventTypes.LITAPI_DECODE_REQUEST_START, lit_api=lit_api)
x = [
_inject_context(
context,
Expand All @@ -291,10 +334,19 @@ def run_batched_streaming_loop(
)
for input, context in zip(inputs, contexts)
]
callback_runner.trigger_event(EventTypes.LITAPI_DECODE_REQUEST_END, lit_api=lit_api)

x = lit_api.batch(x)

callback_runner.trigger_event(EventTypes.LITAPI_PREDICT_START, lit_api=lit_api)
y_iter = _inject_context(contexts, lit_api.predict, x)
callback_runner.trigger_event(EventTypes.LITAPI_PREDICT_END, lit_api=lit_api)

unbatched_iter = lit_api.unbatch(y_iter)

callback_runner.trigger_event(EventTypes.LITAPI_ENCODE_RESPONSE_START, lit_api=lit_api)
y_enc_iter = _inject_context(contexts, lit_api.encode_response, unbatched_iter)
callback_runner.trigger_event(EventTypes.LITAPI_ENCODE_RESPONSE_END, lit_api=lit_api)

# y_enc_iter -> [[response-1, response-2], [response-1, response-2]]
for y_batch in y_enc_iter:
Expand Down Expand Up @@ -324,10 +376,13 @@ def inference_worker(
max_batch_size: int,
batch_timeout: float,
stream: bool,
workers_setup_status: Dict[str, bool] = None,
workers_setup_status: Dict[str, bool],
callback_runner: CallbackRunner,
):
callback_runner.trigger_event(EventTypes.LITAPI_SETUP_START, lit_api=lit_api)
lit_api.setup(device)
lit_api.device = device
callback_runner.trigger_event(EventTypes.LITAPI_SETUP_END, lit_api=lit_api)

print(f"Setup complete for worker {worker_id}.")

Expand All @@ -338,17 +393,16 @@ def inference_worker(
logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec")
if stream:
if max_batch_size > 1:
run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
run_batched_streaming_loop(
lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout, callback_runner
)
else:
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues)
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
return

if max_batch_size > 1:
run_batched_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
else:
run_single_loop(
lit_api,
lit_spec,
request_queue,
response_queues,
run_batched_loop(
lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout, callback_runner
)
else:
run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
10 changes: 9 additions & 1 deletion src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from queue import Empty
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
from typing import Callable, Dict, Optional, Sequence, Tuple, Union, List

import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, Response
Expand All @@ -37,6 +37,7 @@
from starlette.middleware.gzip import GZipMiddleware

from litserve import LitAPI
from litserve.callbacks.base import CallbackRunner, Callback, EventTypes
from litserve.connector import _Connector
from litserve.loops import inference_worker
from litserve.specs import OpenAISpec
Expand Down Expand Up @@ -113,6 +114,7 @@ def __init__(
stream: bool = False,
spec: Optional[LitSpec] = None,
max_payload_size=None,
callbacks: Optional[Union[List[Callback], Callback]] = None,
middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None,
):
if batch_timeout > timeout and timeout not in (False, -1):
Expand Down Expand Up @@ -171,6 +173,8 @@ def __init__(
self.stream = stream
self.max_payload_size = max_payload_size
self._connector = _Connector(accelerator=accelerator, devices=devices)
self._callback_runner = CallbackRunner()
self._callback_runner.add_callbacks(callbacks)

specs = spec if spec is not None else []
self._specs = specs if isinstance(specs, Sequence) else [specs]
Expand Down Expand Up @@ -240,6 +244,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
self.batch_timeout,
self.stream,
self.workers_setup_status,
self._callback_runner,
),
)
process.start()
Expand Down Expand Up @@ -295,6 +300,7 @@ async def data_streamer(self, q: deque, data_available: asyncio.Event, send_stat
data_available.clear()

def setup_server(self):
self._callback_runner.trigger_event(EventTypes.SERVER_SETUP_START, litserver=self)
workers_ready = False

@self.app.get("/", dependencies=[Depends(self.setup_auth())])
Expand Down Expand Up @@ -380,6 +386,8 @@ async def stream_predict(request: self.request_type) -> self.response_type:
elif callable(middleware):
self.app.add_middleware(middleware)

self._callback_runner.trigger_event(EventTypes.SERVER_SETUP_END, litserver=self)

@staticmethod
def generate_client_file():
src_path = os.path.join(os.path.dirname(__file__), "python_client.py")
Expand Down
Loading
Loading