Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add ability to customize authorization method #151

Merged
merged 1 commit into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,16 @@ LIT_SERVER_API_KEY=supersecretkey python main.py

Clients are expected to auth with the same API key set in the `X-API-Key` HTTP header.

Alternatively, implement a method named `authorize` in the LitAPI subclass to provide custom authentication:

```python
from fastapi import HTTPException, Depends
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
raise HTTPException(status_code=401, detail="Not authorized")
```

</details>

<details>
Expand Down
24 changes: 15 additions & 9 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,12 +329,6 @@ def api_key_auth(x_api_key: str = Depends(APIKeyHeader(name="X-API-Key"))):
)


def setup_auth():
if LIT_SERVER_API_KEY:
return api_key_auth
return no_auth


def cleanup(request_buffer, uid):
logger.debug("Cleaning up request uid=%s", uid)
with contextlib.suppress(KeyError):
Expand Down Expand Up @@ -562,7 +556,7 @@ def cleanup_request(self, request_buffer, uid):
request_buffer.pop(uid)

def setup_server(self):
@self.app.get("/", dependencies=[Depends(setup_auth())])
@self.app.get("/", dependencies=[Depends(self.setup_auth())])
async def index(request: Request) -> Response:
return Response(content="litserve running")

Expand Down Expand Up @@ -624,14 +618,19 @@ async def stream_predict(request: self.request_type, background_tasks: Backgroun
endpoint = self.api_path
methods = ["POST"]
self.app.add_api_route(
endpoint, stream_predict if stream else predict, methods=methods, dependencies=[Depends(setup_auth())]
endpoint,
stream_predict if stream else predict,
methods=methods,
dependencies=[Depends(self.setup_auth())],
)

for spec in self._specs:
spec: LitSpec
# TODO check that path is not clashing
for path, endpoint, methods in spec.endpoints:
self.app.add_api_route(path, endpoint=endpoint, methods=methods, dependencies=[Depends(setup_auth())])
self.app.add_api_route(
path, endpoint=endpoint, methods=methods, dependencies=[Depends(self.setup_auth())]
)

def generate_client_file(self):
src_path = os.path.join(os.path.dirname(__file__), "python_client.py")
Expand Down Expand Up @@ -661,3 +660,10 @@ def run(self, port: Union[str, int] = 8000, log_level: str = "info", generate_cl
raise ValueError(port_msg)

uvicorn.run(host="0.0.0.0", port=port, app=self.app, log_level=log_level, **kwargs)

def setup_auth(self):
if hasattr(self.lit_api, "authorize") and callable(self.lit_api.authorize):
return self.lit_api.authorize
if LIT_SERVER_API_KEY:
return api_key_auth
return no_auth
94 changes: 94 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from fastapi import Request, Response, HTTPException, Depends
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.testclient import TestClient

from litserve import LitAPI, LitServer
import litserve.server


class SimpleAuthedLitAPI(LitAPI):
def setup(self, device):
self.model = lambda x: x**2

def decode_request(self, request: Request):
return request["input"]

def predict(self, x):
return self.model(x)

def encode_response(self, output) -> Response:
return {"output": output}

def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
if auth.scheme != "Bearer" or auth.credentials != "1234":
raise HTTPException(status_code=401, detail="Bad token")


def test_authorized_custom():
server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
input = {"input": 4.0}
response = client.post("/predict", headers={"Authorization": "Bearer 1234"}, json=input)
assert response.status_code == 200


def test_not_authorized_custom():
server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
input = {"input": 4.0}
response = client.post("/predict", headers={"Authorization": "Bearer wrong"}, json=input)
assert response.status_code == 401


class SimpleLitAPI(LitAPI):
def setup(self, device):
self.model = lambda x: x**2

def decode_request(self, request: Request):
return request["input"]

def predict(self, x):
return self.model(x)

def encode_response(self, output) -> Response:
return {"output": output}


def test_authorized_api_key():
litserve.server.LIT_SERVER_API_KEY = "abcd"
server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
input = {"input": 4.0}
response = client.post("/predict", headers={"X-API-Key": "abcd"}, json=input)
assert response.status_code == 200

litserve.server.LIT_SERVER_API_KEY = None


def test_not_authorized_api_key():
litserve.server.LIT_SERVER_API_KEY = "abcd"
server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
input = {"input": 4.0}
response = client.post("/predict", headers={"X-API-Key": "wrong"}, json=input)
assert response.status_code == 401

litserve.server.LIT_SERVER_API_KEY = None
Loading