Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add excluded_paths to JWTMiddleware #226

Closed
wants to merge 10 commits into from
12 changes: 9 additions & 3 deletions piccolo_api/jwt_auth/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,15 @@ async def __call__(self, scope, receive, send):
"""
allow_unauthenticated = self.allow_unauthenticated

if scope["path"] in self.excluded_paths:
await self.asgi(scope, receive, send)
return
for excluded_path in self.excluded_paths:
if excluded_path.endswith("*"):
if excluded_path.startswith(excluded_path.rstrip("*")):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't quite right - because it's comparing excluded_paths with excluded_paths.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But

if scope['path'].startswith(excluded_path.rstrip('*')):

does nothing. If we have /foo/* we need check to root_path like this

if scope["root_path"].startswith(excluded_path.rstrip("/*")): 

to take some effect. Sorry if I don't understand well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if scope['path'].startswith(excluded_path.rstrip('*')):

work only if we specified excluded_paths=["*"] not excluded_paths=["/foo/*"]. Did you think so?

Copy link
Member

@dantownsend dantownsend Apr 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the confusion is to do with how FastAPI splits the path between path and root_path.

For example, if I have the an APIRouter mounted at /private, and an endpoint mounted to that router at /blog, then root_path is /private and path is /blog.

I think we need to combine them for it to work properly. So rather than just checking against path it's something like urllib.parse.urljoin(request.scope['root_path'], request.scope['path']).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This discussion gives some more context django/asgiref#229

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raw_path did the trick. If we use raw_path like this

if scope["raw_path"].decode("utf-8").startswith(excluded_path.rstrip("*")):

both excluded_paths=["*"] and excluded_paths=["/foo/*"] works.

await self.asgi(scope, receive, send)
return
else:
if scope["path"] == excluded_path:
await self.asgi(scope, receive, send)
return

headers = dict(scope["headers"])
token = self.get_token(headers)
Expand Down
28 changes: 22 additions & 6 deletions tests/jwt_auth/test_jwt_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,21 @@ def get(self, request: Request):
)


FASTAPI_APP = FastAPI(title="Test visible paths")
fastapi_app = FastAPI(title="Test excluded paths")


@fastapi_app.get("/")
def wildcard_route():
return "Wildcard route test"


ECHO_APP = Router([Route("/", EchoEndpoint)])
APP = JWTMiddleware(asgi=ECHO_APP, secret="SECRET")
APP_UNAUTH = JWTMiddleware(
asgi=ECHO_APP, secret="SECRET", allow_unauthenticated=True
)
APP_VISIBLE_PATHS = JWTMiddleware(
asgi=FASTAPI_APP, secret="SECRET", excluded_paths=["/docs"]
APP_EXCLUDED_PATHS = JWTMiddleware(
asgi=fastapi_app, secret="SECRET", excluded_paths=["/docs", "/*"]
)


Expand Down Expand Up @@ -206,12 +212,22 @@ def test_token_without_user_id(self):
{"user_id": None, "jwt_error": JWTError.user_not_found.value},
)

def test_visible_paths(self):
client = TestClient(APP_VISIBLE_PATHS)
def test_excluded_paths(self):
client = TestClient(APP_EXCLUDED_PATHS)

response = client.get("/docs")
self.assertEqual(response.status_code, 200)
self.assertIn(
b"<title>Test visible paths - Swagger UI</title>",
b"<title>Test excluded paths - Swagger UI</title>",
response.content,
)

def test_excluded_paths_wildcards(self):
client = TestClient(APP_EXCLUDED_PATHS)

response = client.get("/")
self.assertEqual(response.status_code, 200)
self.assertIn(
b"Wildcard route test",
response.content,
)