
Feat: adds health check endpoint #182

Merged · 13 commits · Jul 29, 2024
src/litserve/server.py (19 additions, 0 deletions)
@@ -331,9 +331,12 @@ def inference_worker(
max_batch_size: int,
batch_timeout: float,
stream: bool,
workers_setup_status: Dict[str, bool] = None,
):
lit_api.setup(device)
lit_api.device = device
if workers_setup_status:
workers_setup_status[worker_id] = True
message = f"Setup complete for worker {worker_id}."
print(message)
logger.info(message)
@@ -476,6 +479,7 @@ async def lifespan(self, app: FastAPI):
manager = mp.Manager()
self.request_queue = manager.Queue()
self.response_buffer = {}
self.workers_setup_status = manager.dict()

response_queues = []
tasks: List[asyncio.Task] = []
@@ -506,6 +510,7 @@ def close_tasks():
if len(device) == 1:
device = device[0]

self.workers_setup_status[worker_id] = False
response_queue = manager.Queue()
response_queues.append(response_queue)

@@ -526,6 +531,7 @@ def close_tasks():
self.max_batch_size,
self.batch_timeout,
self.stream,
self.workers_setup_status,
),
daemon=True,
)
@@ -579,10 +585,23 @@ async def data_streamer(self, q: deque, data_available: asyncio.Event, send_stat
data_available.clear()

def setup_server(self):
workers_ready = False

@self.app.get("/", dependencies=[Depends(self.setup_auth())])
async def index(request: Request) -> Response:
return Response(content="litserve running")

@self.app.get("/health", dependencies=[Depends(self.setup_auth())])
async def health(request: Request) -> Response:
nonlocal workers_ready
if not workers_ready:
workers_ready = all(self.workers_setup_status.values())

if workers_ready:
return Response(content="ok", status_code=200)

return Response(content="not ready", status_code=503)

async def predict(request: self.request_type, background_tasks: BackgroundTasks) -> self.response_type:
uid = uuid.uuid4()
event = asyncio.Event()
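The readiness flag relies on `multiprocessing.Manager().dict()`: the manager hands each spawned worker a proxy to the same dictionary, so a flag flipped inside `inference_worker` is visible when `/health` reads it back in the server process. A standalone sketch of that mechanism (illustrative only, not code from this PR):

```python
import multiprocessing as mp
import time


def fake_setup(status, worker_id):
    # Stand-in for lit_api.setup(): do slow work, then flip the shared readiness flag.
    time.sleep(1)
    status[worker_id] = True


if __name__ == "__main__":
    manager = mp.Manager()
    status = manager.dict({0: False, 1: False})
    workers = [mp.Process(target=fake_setup, args=(status, i), daemon=True) for i in range(2)]
    for w in workers:
        w.start()
    print(all(status.values()))  # False while any worker is still "setting up"
    for w in workers:
        w.join()
    print(all(status.values()))  # True once every worker has reported ready
```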
tests/test_simple.py (25 additions, 0 deletions)
@@ -48,6 +48,31 @@ def test_simple():
assert response.json() == {"output": 16.0}


class SlowSetupLitAPI(SimpleLitAPI):
def setup(self, device):
self.model = lambda x: x**2
time.sleep(2)


def test_workers_health():
server = LitServer(SlowSetupLitAPI(), accelerator="cpu", devices=1, timeout=5, workers_per_device=2)

with TestClient(server.app) as client:
response = client.get("/health")
assert response.status_code == 503
assert response.text == "not ready"

time.sleep(1)
response = client.get("/health")
assert response.status_code == 503
assert response.text == "not ready"

time.sleep(3)
response = client.get("/health")
assert response.status_code == 200
assert response.text == "ok"


def make_load_request(server, outputs):
with TestClient(server.app) as client:
for _ in range(100):
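For completeness, this is how a deployment script might wait on the new endpoint outside of the test suite. The sketch below is not part of the PR; it assumes the server is running locally on port 8000 and that the `requests` package is installed:

```python
import time

import requests


def wait_until_ready(url: str = "http://localhost:8000/health", timeout: float = 30.0) -> bool:
    """Poll /health until every worker reports ready or the timeout elapses."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # 200 "ok" once all workers have finished setup(); 503 "not ready" before that.
            if requests.get(url, timeout=2).status_code == 200:
                return True
        except requests.ConnectionError:
            # The server process may not be accepting connections yet.
            pass
        time.sleep(0.5)
    return False


if __name__ == "__main__":
    print("server ready:", wait_until_ready())
```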