From 0ba40711528f8f15b0f79d6fb23e4f0e4ce22888 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Thu, 16 May 2024 22:05:19 +0100
Subject: [PATCH 1/2] add custom endpoint path

---
 README.md              | 10 ++++++++++
 src/litserve/server.py | 16 ++++++++++++++--
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 25469099..e3e91908 100644
--- a/README.md
+++ b/README.md
@@ -523,6 +523,16 @@ if __name__ == "__main__":
+<details>
+  <summary>Customize endpoint path</summary>
+
+&nbsp;
+
+By default, LitServe exposes the `/predict` endpoint for serving the model, but you can customize the endpoint
+path by passing `LitServer(..., endpoint_path="/api/CUSTOM_PATH")`.
+
+</details>
+
+
 # Contribute
 LitServe is a community project accepting contributions.
 Let's make the world's most advanced AI inference engine.
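For reviewers, a minimal usage sketch of the new argument. `SimpleLitAPI` below is a hypothetical stand-in for any `LitAPI` implementation, not code from this patch:

```python
import litserve as ls


class SimpleLitAPI(ls.LitAPI):
    def setup(self, device):
        # Load the "model" once per worker; a trivial squaring function here.
        self.model = lambda x: x**2

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        return self.model(x)

    def encode_response(self, output):
        return {"output": output}


if __name__ == "__main__":
    # With this patch, the model is served at POST /api/v1/predict instead of /predict.
    server = ls.LitServer(SimpleLitAPI(), endpoint_path="/api/v1/predict")
    server.run(port=8000)
```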
diff --git a/src/litserve/server.py b/src/litserve/server.py
index 1a1c5eba..f2fc9f69 100644
--- a/src/litserve/server.py
+++ b/src/litserve/server.py
@@ -338,6 +338,7 @@ def __init__(
         timeout: Union[float, bool] = 30,
         max_batch_size: int = 1,
         batch_timeout: float = 0.0,
+        endpoint_path: str = "/predict",
         stream: bool = False,
     ):
         if batch_timeout > timeout and timeout not in (False, -1):
@@ -345,6 +346,10 @@
         if max_batch_size <= 0:
             raise ValueError("max_batch_size must be greater than 0")
 
+        if not endpoint_path.startswith("/"):
+            raise ValueError("endpoint_path must start with '/', like '/predict' or '/api/predict'.")
+
+        self.endpoint_path = endpoint_path
         lit_api.stream = stream
         lit_api.sanitize(max_batch_size)
         self.app = FastAPI(lifespan=lifespan)
@@ -404,7 +409,6 @@ def setup_server(self):
         async def index(request: Request) -> Response:
             return Response(content="litserve running")
 
-        @self.app.post("/predict", dependencies=[Depends(setup_auth())])
         async def predict(request: self.request_type, background_tasks: BackgroundTasks) -> self.response_type:
             uid = uuid.uuid4()
             logger.debug(f"Received request uid={uid}")
@@ -452,7 +456,6 @@ async def data_reader():
             load_and_raise(response)
             return response
 
-        @self.app.post("/stream-predict", dependencies=[Depends(setup_auth())])
         async def stream_predict(request: self.request_type, background_tasks: BackgroundTasks) -> self.response_type:
             uid = uuid.uuid4()
             logger.debug(f"Received request uid={uid}")
@@ -510,6 +513,15 @@ async def data_streamer():
 
             return StreamingResponse(data_streamer())
 
+        stream = self.app.lit_api.stream
+        methods = ["POST"]
+        self.app.add_api_route(
+            self.endpoint_path,
+            stream_predict if stream else predict,
+            methods=methods,
+            dependencies=[Depends(setup_auth())],
+        )
+
     def generate_client_file(self):
         src_path = os.path.join(os.path.dirname(__file__), "python_client.py")
         dest_path = os.path.join(os.getcwd(), "client.py")
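The change above drops the two hard-coded `@self.app.post(...)` decorators and registers exactly one handler via `add_api_route`. A self-contained sketch of that FastAPI pattern, outside LitServe; `stream` and `endpoint_path` mirror the patch's names, and the toy handlers are assumptions:

```python
from fastapi import FastAPI

app = FastAPI()
stream = False                     # would come from LitServer(..., stream=...)
endpoint_path = "/api/v1/predict"  # would come from LitServer(..., endpoint_path=...)


async def predict(request: dict) -> dict:
    # Toy non-streaming handler.
    return {"output": request["input"]}


async def stream_predict(request: dict) -> dict:
    # Toy stand-in for the streaming handler.
    return {"output": f"streamed {request['input']}"}


# Register only the handler that matches the server's mode, at the configured path.
app.add_api_route(endpoint_path, stream_predict if stream else predict, methods=["POST"])
```

Because only one handler is ever attached, both streaming and non-streaming servers share the same default `/predict` path, which is why the tests in the second commit drop `/stream-predict`.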
From 18f74f5706d28253f233a1490260efa5fd3be532 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Thu, 16 May 2024 22:17:13 +0100
Subject: [PATCH 2/2] fix tests

---
 tests/e2e/test_e2e.py    |  2 +-
 tests/test_lit_server.py | 17 ++++++++++++-----
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py
index 9fccdf0b..748cedf1 100644
--- a/tests/e2e/test_e2e.py
+++ b/tests/e2e/test_e2e.py
@@ -56,7 +56,7 @@ def test_e2e_batched_streaming():
     )
     time.sleep(5)
 
-    resp = requests.post("http://127.0.0.1:8000/stream-predict", json={"input": 4.0}, headers=None, stream=True)
+    resp = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0}, headers=None, stream=True)
     assert resp.status_code == 200, f"Expected response to be 200 but got {resp.status_code}"
 
     outputs = []

diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py
index 09d681bf..0c4ca0d6 100644
--- a/tests/test_lit_server.py
+++ b/tests/test_lit_server.py
@@ -143,8 +143,8 @@ async def test_stream(simple_stream_api):
     expected_output2 = "prompt=World generated_output=LitServe is streaming output".lower().replace(" ", "")
 
     async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
-        resp1 = ac.post("/stream-predict", json={"prompt": "Hello"}, timeout=10)
-        resp2 = ac.post("/stream-predict", json={"prompt": "World"}, timeout=10)
+        resp1 = ac.post("/predict", json={"prompt": "Hello"}, timeout=10)
+        resp2 = ac.post("/predict", json={"prompt": "World"}, timeout=10)
         resp1, resp2 = await asyncio.gather(resp1, resp2)
         assert resp1.status_code == 200, "Check if server is running and the request format is valid."
         assert resp1.text == expected_output1, "Server returns input prompt and generated output which didn't match."
@@ -154,13 +154,20 @@
 
 @pytest.mark.asyncio()
 async def test_batched_stream_server(simple_batched_stream_api):
-    server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30)
+    server = LitServer(
+        simple_batched_stream_api,
+        stream=True,
+        max_batch_size=4,
+        batch_timeout=2,
+        timeout=30,
+        endpoint_path="/v1/stream/predict",
+    )
     expected_output1 = "Hello LitServe is streaming output".lower().replace(" ", "")
     expected_output2 = "World LitServe is streaming output".lower().replace(" ", "")
 
     async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
-        resp1 = ac.post("/stream-predict", json={"prompt": "Hello"}, timeout=10)
-        resp2 = ac.post("/stream-predict", json={"prompt": "World"}, timeout=10)
+        resp1 = ac.post("/v1/stream/predict", json={"prompt": "Hello"}, timeout=10)
+        resp2 = ac.post("/v1/stream/predict", json={"prompt": "World"}, timeout=10)
         resp1, resp2 = await asyncio.gather(resp1, resp2)
         assert resp1.status_code == 200, "Check if server is running and the request format is valid."
         assert resp2.status_code == 200, "Check if server is running and the request format is valid."
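After this commit, clients must target whatever path the server was constructed with; the old `/stream-predict` route no longer exists. A hedged client-side sketch matching the updated batched-stream test, assuming a server already running locally on port 8000 with `endpoint_path="/v1/stream/predict"`:

```python
import requests

# Stream a response from the custom path configured at server construction time.
resp = requests.post(
    "http://127.0.0.1:8000/v1/stream/predict",
    json={"prompt": "Hello"},
    stream=True,
)
assert resp.status_code == 200
for chunk in resp.iter_content(chunk_size=None):
    print(chunk.decode("utf-8"), end="")
```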