-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* init * connector * update * update * fix tests * cleanup * update * fix tests * apply callbacks to all the loops * update * fix tests * add test * fix test * rename callback * fix docstrings * refactor hook names * handle None * no root logger * apply feedback
- Loading branch information
1 parent
cf78c85
commit 4d2ac63
Showing
10 changed files
with
282 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .base import Callback, CallbackRunner, EventTypes, NoopCallback | ||
|
||
__all__ = ["Callback", "CallbackRunner", "EventTypes", "NoopCallback"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import dataclasses | ||
import logging | ||
from abc import ABC | ||
from typing import List, Union | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclasses.dataclass | ||
class EventTypes: | ||
BEFORE_SETUP = "on_before_setup" | ||
AFTER_SETUP = "on_after_setup" | ||
BEFORE_DECODE_REQUEST = "on_before_decode_request" | ||
AFTER_DECODE_REQUEST = "on_after_decode_request" | ||
BEFORE_ENCODE_RESPONSE = "on_before_encode_response" | ||
AFTER_ENCODE_RESPONSE = "on_after_encode_response" | ||
BEFORE_PREDICT = "on_before_predict" | ||
AFTER_PREDICT = "on_after_predict" | ||
ON_SERVER_START = "on_server_start" | ||
ON_SERVER_END = "on_serve_end" | ||
|
||
|
||
class Callback(ABC): | ||
def on_before_setup(self, *args, **kwargs): | ||
"""Called before setup is started.""" | ||
|
||
def on_after_setup(self, *args, **kwargs): | ||
"""Called after setup is completed.""" | ||
|
||
def on_before_decode_request(self, *args, **kwargs): | ||
"""Called before request decoding is started.""" | ||
|
||
def on_after_decode_request(self, *args, **kwargs): | ||
"""Called after request decoding is completed.""" | ||
|
||
def on_before_encode_response(self, *args, **kwargs): | ||
"""Called before response encoding is started.""" | ||
|
||
def on_after_encode_response(self, *args, **kwargs): | ||
"""Called after response encoding is completed.""" | ||
|
||
def on_before_predict(self, *args, **kwargs): | ||
"""Called before prediction is started.""" | ||
|
||
def on_after_predict(self, *args, **kwargs): | ||
"""Called after prediction is completed.""" | ||
|
||
def on_server_start(self, *args, **kwargs): | ||
"""Called before server starts.""" | ||
|
||
def on_server_end(self, *args, **kwargs): | ||
"""Called when server terminates.""" | ||
|
||
|
||
class CallbackRunner: | ||
def __init__(self, callbacks: Union[Callback, List[Callback]] = None): | ||
self._callbacks = [] | ||
if callbacks: | ||
self._add_callbacks(callbacks) | ||
|
||
def _add_callbacks(self, callbacks: Union[Callback, List[Callback]]): | ||
if not isinstance(callbacks, list): | ||
callbacks = [callbacks] | ||
for callback in callbacks: | ||
if not isinstance(callback, Callback): | ||
raise ValueError(f"Invalid callback type: {callback}") | ||
self._callbacks.extend(callbacks) | ||
|
||
def trigger_event(self, event_name, *args, **kwargs): | ||
"""Triggers an event, invoking all registered callbacks for that event.""" | ||
for callback in self._callbacks: | ||
try: | ||
getattr(callback, event_name)(*args, **kwargs) | ||
except Exception: | ||
# Handle exceptions to prevent one callback from disrupting others | ||
logger.exception(f"Error in callback '{callback}' during event '{event_name}'") | ||
|
||
|
||
class NoopCallback(Callback): | ||
"""This callback does nothing.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from litserve.callbacks.defaults.metric_callback import PredictionTimeLogger | ||
|
||
__all__ = ["PredictionTimeLogger"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import time | ||
import typing | ||
from logging import getLogger | ||
|
||
from ..base import Callback | ||
|
||
if typing.TYPE_CHECKING: | ||
from litserve import LitAPI | ||
|
||
logger = getLogger(__name__) | ||
|
||
|
||
class PredictionTimeLogger(Callback): | ||
def on_before_predict(self, lit_api: "LitAPI"): | ||
t0 = time.perf_counter() | ||
self._start_time = t0 | ||
|
||
def on_after_predict(self, lit_api: "LitAPI"): | ||
t1 = time.perf_counter() | ||
elapsed = t1 - self._start_time | ||
print(f"Prediction took {elapsed:.2f} seconds", flush=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.