add custom endpoint path #91

Closed · wants to merge 3 commits
README.md: 10 additions, 0 deletions
@@ -523,6 +523,16 @@ if __name__ == "__main__":
 </details>
 
 
+<details>
+<summary>Customize endpoint path</summary>
+
+&nbsp;
+By default, LitServe exposes the `/predict` endpoint for serving the model, but you can customize the API
+endpoint path by passing `LitServer(..., endpoint_path="/api/CUSTOM_PATH")`.
+
+</details>
+
+
 # Contribute
 LitServe is a community project accepting contributions. Let's make the world's most advanced AI inference engine.
 
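To make the documented option concrete, here is a minimal usage sketch of what the README snippet above describes. It is not part of this PR: the `SquareLitAPI` class, the `/api/v1/square` path, and the port are illustrative, and the structure follows the standard LitAPI pattern shown elsewhere in the README.

# Hypothetical example, not part of this diff: serve a model at a custom path.
from litserve import LitAPI, LitServer


class SquareLitAPI(LitAPI):
    def setup(self, device):
        # Build (or load) the model once per worker.
        self.model = lambda x: x**2

    def decode_request(self, request):
        # Extract the input value from the JSON payload.
        return request["input"]

    def predict(self, x):
        return self.model(x)

    def encode_response(self, output):
        return {"output": output}


if __name__ == "__main__":
    # The path must start with "/", otherwise LitServer raises a ValueError.
    server = LitServer(SquareLitAPI(), endpoint_path="/api/v1/square")
    server.run(port=8000)

A client would then POST to http://127.0.0.1:8000/api/v1/square instead of the default /predict.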
src/litserve/server.py: 14 additions, 2 deletions
@@ -338,13 +338,18 @@ def __init__(
         timeout: Union[float, bool] = 30,
         max_batch_size: int = 1,
         batch_timeout: float = 0.0,
+        endpoint_path: str = "/predict",
         stream: bool = False,
     ):
         if batch_timeout > timeout and timeout not in (False, -1):
             raise ValueError("batch_timeout must be less than timeout")
         if max_batch_size <= 0:
             raise ValueError("max_batch_size must be greater than 0")
 
+        if not endpoint_path.startswith("/"):
+            raise ValueError("endpoint_path must start with '/', like '/predict' or '/api/predict'.")
+
+        self.endpoint_path = endpoint_path
         lit_api.stream = stream
         lit_api.sanitize(max_batch_size)
         self.app = FastAPI(lifespan=lifespan)
@@ -404,7 +409,6 @@ def setup_server(self):
         async def index(request: Request) -> Response:
             return Response(content="litserve running")
 
-        @self.app.post("/predict", dependencies=[Depends(setup_auth())])
         async def predict(request: self.request_type, background_tasks: BackgroundTasks) -> self.response_type:
             uid = uuid.uuid4()
             logger.debug(f"Received request uid={uid}")
@@ -452,7 +456,6 @@ async def data_reader():
             load_and_raise(response)
             return response
 
-        @self.app.post("/stream-predict", dependencies=[Depends(setup_auth())])
         async def stream_predict(request: self.request_type, background_tasks: BackgroundTasks) -> self.response_type:
             uid = uuid.uuid4()
             logger.debug(f"Received request uid={uid}")
@@ -510,6 +513,15 @@ async def data_streamer():
 
             return StreamingResponse(data_streamer())
 
+        stream = self.app.lit_api.stream
+        methods = ["POST"]
+        self.app.add_api_route(
+            self.endpoint_path,
+            stream_predict if stream else predict,
+            methods=methods,
+            dependencies=[Depends(setup_auth())],
+        )
+
     def generate_client_file(self):
         src_path = os.path.join(os.path.dirname(__file__), "python_client.py")
         dest_path = os.path.join(os.getcwd(), "client.py")
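With the hard-coded /predict and /stream-predict decorators removed, a single route is registered at endpoint_path: the handler is stream_predict when the server is constructed with stream=True and predict otherwise, and a path that does not start with "/" is rejected by the ValueError added in __init__. Below is a rough client-side sketch, not part of this diff, assuming a streaming server configured like the updated test (endpoint_path="/v1/stream/predict", stream=True) is already running locally on port 8000.

# Hypothetical client sketch, not part of this diff: stream from a custom path.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/stream/predict",  # the configured endpoint_path
    json={"prompt": "Hello"},
    stream=True,
    timeout=10,
)
resp.raise_for_status()

# Consume the streamed response chunk by chunk as it arrives.
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="")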
tests/e2e/test_e2e.py: 1 addition, 1 deletion
@@ -56,7 +56,7 @@ def test_e2e_batched_streaming():
     )
 
     time.sleep(5)
-    resp = requests.post("http://127.0.0.1:8000/stream-predict", json={"input": 4.0}, headers=None, stream=True)
+    resp = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0}, headers=None, stream=True)
     assert resp.status_code == 200, f"Expected response to be 200 but got {resp.status_code}"
 
     outputs = []
tests/test_lit_server.py: 12 additions, 5 deletions
@@ -143,8 +143,8 @@ async def test_stream(simple_stream_api):
     expected_output2 = "prompt=World generated_output=LitServe is streaming output".lower().replace(" ", "")
 
     async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
-        resp1 = ac.post("/stream-predict", json={"prompt": "Hello"}, timeout=10)
-        resp2 = ac.post("/stream-predict", json={"prompt": "World"}, timeout=10)
+        resp1 = ac.post("/predict", json={"prompt": "Hello"}, timeout=10)
+        resp2 = ac.post("/predict", json={"prompt": "World"}, timeout=10)
         resp1, resp2 = await asyncio.gather(resp1, resp2)
         assert resp1.status_code == 200, "Check if server is running and the request format is valid."
         assert resp1.text == expected_output1, "Server returns input prompt and generated output which didn't match."
@@ -154,13 +154,20 @@ async def test_stream(simple_stream_api):
 
 @pytest.mark.asyncio()
 async def test_batched_stream_server(simple_batched_stream_api):
-    server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30)
+    server = LitServer(
+        simple_batched_stream_api,
+        stream=True,
+        max_batch_size=4,
+        batch_timeout=2,
+        timeout=30,
+        endpoint_path="/v1/stream/predict",
+    )
     expected_output1 = "Hello LitServe is streaming output".lower().replace(" ", "")
     expected_output2 = "World LitServe is streaming output".lower().replace(" ", "")
 
     async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
-        resp1 = ac.post("/stream-predict", json={"prompt": "Hello"}, timeout=10)
-        resp2 = ac.post("/stream-predict", json={"prompt": "World"}, timeout=10)
+        resp1 = ac.post("/v1/stream/predict", json={"prompt": "Hello"}, timeout=10)
+        resp2 = ac.post("/v1/stream/predict", json={"prompt": "World"}, timeout=10)
         resp1, resp2 = await asyncio.gather(resp1, resp2)
         assert resp1.status_code == 200, "Check if server is running and the request format is valid."
         assert resp2.status_code == 200, "Check if server is running and the request format is valid."