Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework noauth #62

Merged
merged 4 commits into from
Sep 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions plugins/auth/fps_auth/backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import Optional

import httpx
from fastapi import Depends, status
from fastapi.security.base import SecurityBase
from fastapi_users.authentication import BaseAuthentication, CookieAuthentication # type: ignore
from fastapi_users import FastAPIUsers, BaseUserManager # type: ignore
from starlette.requests import Request

from .models import (
User,
UserCreate,
UserUpdate,
UserDB,
)
from .config import get_auth_config
from .db import secret, get_user_db


NOAUTH_EMAIL = "[email protected]"
NOAUTH_USER = UserDB(
id="d4ded46b-a4df-4b51-8d83-ae19010272a7",
email=NOAUTH_EMAIL,
hashed_password="",
)


class NoAuth(SecurityBase):
def __call__(self):
return "noauth"


class NoAuthAuthentication(BaseAuthentication):
def __init__(self, user: UserDB, name: str = "noauth"):
super().__init__(name, logout=False)
self.user = user
self.scheme = NoAuth()

async def __call__(self, credentials, user_manager):
noauth_user = await user_manager.user_db.get_by_email(NOAUTH_EMAIL)
return noauth_user or self.user


noauth_authentication = NoAuthAuthentication(NOAUTH_USER)


class UserManager(BaseUserManager[UserCreate, UserDB]):
user_db_model = UserDB

async def on_after_register(self, user: UserDB, request: Optional[Request] = None):
user.initialized = True
for oauth_account in user.oauth_accounts:
if oauth_account.oauth_name == "github":
async with httpx.AsyncClient() as client:
r = (
await client.get(
f"https://api.github.com/user/{oauth_account.account_id}"
)
).json()
user.anonymous = False
user.username = r["login"]
user.name = r["name"]
user.color = None
user.avatar = r["avatar_url"]
await self.user_db.update(user)


class LoginCookieAuthentication(CookieAuthentication):
async def get_login_response(self, user, response, user_manager):
await super().get_login_response(user, response, user_manager)
# set user as logged in
user.logged_in = True
await user_manager.user_db.update(user)
# auto redirect
response.status_code = status.HTTP_302_FOUND
response.headers["Location"] = "/lab"

async def get_logout_response(self, user, response, user_manager):
await super().get_logout_response(user, response, user_manager)
# set user as logged out
user.logged_in = False
await user_manager.user_db.update(user)


cookie_authentication = LoginCookieAuthentication(
cookie_secure=get_auth_config().cookie_secure, secret=secret
)


def get_user_manager(user_db=Depends(get_user_db)):
yield UserManager(user_db)


users = FastAPIUsers(
get_user_manager,
[noauth_authentication, cookie_authentication],
User,
UserCreate,
UserUpdate,
UserDB,
)


async def get_enabled_backends(auth_config=Depends(get_auth_config)):
if auth_config.mode == "noauth":
return [noauth_authentication]
return [cookie_authentication]


current_user = users.current_user(get_enabled_backends=get_enabled_backends)
104 changes: 13 additions & 91 deletions plugins/auth/fps_auth/routes.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from uuid import uuid4
from typing import Optional

import httpx # type: ignore
from httpx_oauth.clients.github import GitHubOAuth2 # type: ignore
from fps.hooks import register_router # type: ignore
from fps.config import get_config, FPSConfig # type: ignore
from fastapi_users.authentication import CookieAuthentication # type: ignore
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi_users import FastAPIUsers, BaseUserManager # type: ignore
from starlette.requests import Request
from fastapi import APIRouter, Depends
from sqlalchemy.orm import sessionmaker # type: ignore

from .config import get_auth_config
from .db import get_user_db, user_db, secret, database, engine, UserTable
from .db import user_db, secret, database, engine, UserTable
from .backends import (
users,
current_user,
cookie_authentication,
NOAUTH_USER,
NOAUTH_EMAIL,
)
from .models import (
User,
UserCreate,
UserUpdate,
UserDB,
)

Expand All @@ -28,61 +28,6 @@
auth_config = get_auth_config()


class UserManager(BaseUserManager[UserCreate, UserDB]):
user_db_model = UserDB

async def on_after_register(self, user: UserDB, request: Optional[Request] = None):
user.initialized = True
for oauth_account in user.oauth_accounts:
print(oauth_account)
if oauth_account.oauth_name == "github":
r = httpx.get(
f"https://api.github.com/user/{oauth_account.account_id}"
).json()
user.anonymous = False
user.username = r["login"]
user.name = r["name"]
user.color = None
user.avatar = r["avatar_url"]
await self.user_db.update(user)


def get_user_manager(user_db=Depends(get_user_db)):
yield UserManager(user_db)


class LoginCookieAuthentication(CookieAuthentication):
async def get_login_response(self, user, response, user_manager):
await super().get_login_response(user, response, user_manager)
# set user as logged in
user.logged_in = True
await user_manager.user_db.update(user)
# auto redirect
response.status_code = status.HTTP_302_FOUND
response.headers["Location"] = "/lab"

async def get_logout_response(self, user, response, user_manager):
await super().get_logout_response(user, response, user_manager)
# set user as logged out
user.logged_in = False
await user_manager.user_db.update(user)


cookie_authentication = LoginCookieAuthentication(
cookie_secure=auth_config.cookie_secure, secret=secret
)

auth_backends = [cookie_authentication]

users = FastAPIUsers(
get_user_manager,
auth_backends,
User,
UserCreate,
UserUpdate,
UserDB,
)

github_oauth_client = GitHubOAuth2(
auth_config.client_id, auth_config.client_secret.get_secret_value()
)
Expand All @@ -99,7 +44,6 @@ async def get_logout_response(self, user, response, user_manager):

TOKEN_USER = None
USER_TOKEN = None
noauth_email = "[email protected]"


def set_user_token(user_token):
Expand Down Expand Up @@ -130,14 +74,9 @@ async def shutdown():


async def create_noauth_user():
user = await user_db.get_by_email(noauth_email)
if user is None:
user = UserDB(
id="d4ded46b-a4df-4b51-8d83-ae19010272a7",
email=noauth_email,
hashed_password="",
)
await user_db.create(user)
noauth_user = await user_db.get_by_email(NOAUTH_EMAIL)
if noauth_user is None:
await user_db.create(NOAUTH_USER)


async def create_token_user():
Expand All @@ -149,25 +88,8 @@ async def create_token_user():
await user_db.create(TOKEN_USER)


def current_user(optional: bool = False):
async def _(
auth_config=Depends(get_auth_config),
user: User = Depends(users.current_user(optional=True)),
user_db=Depends(get_user_db),
):
if auth_config.mode == "noauth":
return await user_db.get_by_email(noauth_email)
elif user is None and not optional:
# FIXME: could be 403
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
else:
return user

return _


@router.get("/auth/users")
async def get_users(user: User = Depends(current_user())):
async def get_users(user: User = Depends(current_user)):
users = session.query(UserTable).all()
return [user for user in users if user.logged_in]

Expand Down
14 changes: 7 additions & 7 deletions plugins/contents/fps_contents/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi import APIRouter, Depends
from starlette.requests import Request # type: ignore

from fps_auth.routes import current_user # type: ignore
from fps_auth.backends import current_user # type: ignore
from fps_auth.models import User # type: ignore

from .models import Checkpoint, Content, SaveContent
Expand All @@ -23,7 +23,7 @@
)
async def create_content(
request: Request,
user: User = Depends(current_user()),
user: User = Depends(current_user),
):
create_content = await request.json()
path = Path(create_content["path"])
Expand Down Expand Up @@ -53,13 +53,13 @@ async def create_content(
@router.get("/api/contents")
async def get_root_content(
content: int,
user: User = Depends(current_user()),
user: User = Depends(current_user),
):
return Content(**get_path_content(Path(""), bool(content)))


@router.get("/api/contents/{path:path}/checkpoints")
async def get_checkpoint(path, user: User = Depends(current_user())):
async def get_checkpoint(path, user: User = Depends(current_user)):
src_path = Path(path)
dst_path = (
Path(".ipynb_checkpoints") / f"{src_path.stem}-checkpoint{src_path.suffix}"
Expand All @@ -74,15 +74,15 @@ async def get_checkpoint(path, user: User = Depends(current_user())):
async def get_content(
path: str,
content: int,
user: User = Depends(current_user()),
user: User = Depends(current_user),
):
return Content(**get_path_content(Path(path), bool(content)))


@router.put("/api/contents/{path:path}")
async def save_content(
request: Request,
user: User = Depends(current_user()),
user: User = Depends(current_user),
):
save_content = SaveContent(**(await request.json()))
try:
Expand Down Expand Up @@ -110,7 +110,7 @@ async def save_content(
"/api/contents/{path:path}/checkpoints",
status_code=201,
)
async def create_checkpoint(path, user: User = Depends(current_user())):
async def create_checkpoint(path, user: User = Depends(current_user)):
src_path = Path(path)
dst_path = (
Path(".ipynb_checkpoints") / f"{src_path.stem}-checkpoint{src_path.suffix}"
Expand Down
16 changes: 8 additions & 8 deletions plugins/jupyterlab/fps_jupyterlab/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from fps.hooks import register_router # type: ignore

from fps_auth.db import get_user_db # type: ignore
from fps_auth.routes import ( # type: ignore
from fps_auth.backends import ( # type: ignore
current_user,
cookie_authentication,
LoginCookieAuthentication,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that you are implementing a kind of login endpoint for token authentication mode. IMO, this should also be a custom authentication backend that would check for a static token in query params and return a static user.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the mechanism for token authentication is that when you access the root endpoint with the token as a query parameter, the server will set a cookie so that you don't have to pass the token again in next accesses (this is the default behavior for jupyter-server). I'm not sure how we could implement that differently. Maybe you have ideas?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I think I see what you mean, we should check for this token in query parameters for all endpoints, and thus have a special authentication back-end for this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok, if the goal is to set a cookie, then what you did is indeed a good solution.

I thought the token was required to be passed on each request (like a REST API), in which case, a custom auth backend would have been more appropriate.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes but still, we should allow the token to be passed on any endpoint, not just the root endpoint. I tried playing with a TokenAuthentication class, but then I was faced with the problem of getting the token as a query parameter, and basically didn't know where to get it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The key is to use the security schemes of FastAPI: there is one for most common cases. In this case, APIKeyQuery seems a good candidate: name argument is the name of your query parameter to read the token from.

It's also important to set auto_error to False to prevent the scheme from raising a 403 when the value is not present: FastAPI Users needs this to try several backends before raising an error.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll keep that for a next PR, thanks.

Expand Down Expand Up @@ -91,7 +91,7 @@ async def get_root(

@router.get("/lab")
async def get_lab(
user: User = Depends(current_user()), jlab_config=Depends(get_jlab_config)
user: User = Depends(current_user), jlab_config=Depends(get_jlab_config)
):
return HTMLResponse(
get_index("default", jlab_config.collaborative, jlab_config.base_url)
Expand Down Expand Up @@ -128,7 +128,7 @@ async def get_listings():


@router.get("/lab/api/workspaces/{name}")
async def get_workspace_data(user: User = Depends(current_user(optional=True))):
async def get_workspace_data(user: User = Depends(current_user)):
if user:
return json.loads(user.workspace)
return {}
Expand All @@ -140,7 +140,7 @@ async def get_workspace_data(user: User = Depends(current_user(optional=True))):
)
async def set_workspace(
request: Request,
user: User = Depends(current_user()),
user: User = Depends(current_user),
user_db=Depends(get_user_db),
):
user.workspace = await request.body()
Expand All @@ -150,7 +150,7 @@ async def set_workspace(

@router.get("/lab/workspaces/{name}", response_class=HTMLResponse)
async def get_workspace(
name, user: User = Depends(current_user()), jlab_config=Depends(get_jlab_config)
name, user: User = Depends(current_user), jlab_config=Depends(get_jlab_config)
):
return get_index(name, jlab_config.collaborative, jlab_config.base_url)

Expand Down Expand Up @@ -202,7 +202,7 @@ async def get_setting(
name0,
name1,
name2,
user: User = Depends(current_user(optional=True)),
user: User = Depends(current_user),
):
with open(
prefix_dir / "share" / "jupyter" / "lab" / "static" / "package.json"
Expand Down Expand Up @@ -248,7 +248,7 @@ async def change_setting(
request: Request,
name0,
name1,
user: User = Depends(current_user()),
user: User = Depends(current_user),
user_db=Depends(get_user_db),
):
settings = json.loads(user.settings)
Expand All @@ -259,7 +259,7 @@ async def change_setting(


@router.get("/lab/api/settings")
async def get_settings(user: User = Depends(current_user(optional=True))):
async def get_settings(user: User = Depends(current_user)):
with open(
prefix_dir / "share" / "jupyter" / "lab" / "static" / "package.json"
) as f:
Expand Down
Loading