Skip to content

Commit

Permalink
Merge pull request #62 from davidbrochart/rework_noauth
Browse files Browse the repository at this point in the history
Rework noauth
  • Loading branch information
davidbrochart authored Sep 21, 2021
2 parents c846f48 + 9431b94 commit 6ea4e14
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 116 deletions.
110 changes: 110 additions & 0 deletions plugins/auth/fps_auth/backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import Optional

import httpx
from fastapi import Depends, status
from fastapi.security.base import SecurityBase
from fastapi_users.authentication import BaseAuthentication, CookieAuthentication # type: ignore
from fastapi_users import FastAPIUsers, BaseUserManager # type: ignore
from starlette.requests import Request

from .models import (
User,
UserCreate,
UserUpdate,
UserDB,
)
from .config import get_auth_config
from .db import secret, get_user_db


NOAUTH_EMAIL = "[email protected]"
NOAUTH_USER = UserDB(
id="d4ded46b-a4df-4b51-8d83-ae19010272a7",
email=NOAUTH_EMAIL,
hashed_password="",
)


class NoAuth(SecurityBase):
def __call__(self):
return "noauth"


class NoAuthAuthentication(BaseAuthentication):
def __init__(self, user: UserDB, name: str = "noauth"):
super().__init__(name, logout=False)
self.user = user
self.scheme = NoAuth()

async def __call__(self, credentials, user_manager):
noauth_user = await user_manager.user_db.get_by_email(NOAUTH_EMAIL)
return noauth_user or self.user


noauth_authentication = NoAuthAuthentication(NOAUTH_USER)


class UserManager(BaseUserManager[UserCreate, UserDB]):
user_db_model = UserDB

async def on_after_register(self, user: UserDB, request: Optional[Request] = None):
user.initialized = True
for oauth_account in user.oauth_accounts:
if oauth_account.oauth_name == "github":
async with httpx.AsyncClient() as client:
r = (
await client.get(
f"https://api.github.com/user/{oauth_account.account_id}"
)
).json()
user.anonymous = False
user.username = r["login"]
user.name = r["name"]
user.color = None
user.avatar = r["avatar_url"]
await self.user_db.update(user)


class LoginCookieAuthentication(CookieAuthentication):
async def get_login_response(self, user, response, user_manager):
await super().get_login_response(user, response, user_manager)
# set user as logged in
user.logged_in = True
await user_manager.user_db.update(user)
# auto redirect
response.status_code = status.HTTP_302_FOUND
response.headers["Location"] = "/lab"

async def get_logout_response(self, user, response, user_manager):
await super().get_logout_response(user, response, user_manager)
# set user as logged out
user.logged_in = False
await user_manager.user_db.update(user)


cookie_authentication = LoginCookieAuthentication(
cookie_secure=get_auth_config().cookie_secure, secret=secret
)


def get_user_manager(user_db=Depends(get_user_db)):
yield UserManager(user_db)


users = FastAPIUsers(
get_user_manager,
[noauth_authentication, cookie_authentication],
User,
UserCreate,
UserUpdate,
UserDB,
)


async def get_enabled_backends(auth_config=Depends(get_auth_config)):
if auth_config.mode == "noauth":
return [noauth_authentication]
return [cookie_authentication]


current_user = users.current_user(get_enabled_backends=get_enabled_backends)
104 changes: 13 additions & 91 deletions plugins/auth/fps_auth/routes.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from uuid import uuid4
from typing import Optional

import httpx # type: ignore
from httpx_oauth.clients.github import GitHubOAuth2 # type: ignore
from fps.hooks import register_router # type: ignore
from fps.config import get_config, FPSConfig # type: ignore
from fastapi_users.authentication import CookieAuthentication # type: ignore
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi_users import FastAPIUsers, BaseUserManager # type: ignore
from starlette.requests import Request
from fastapi import APIRouter, Depends
from sqlalchemy.orm import sessionmaker # type: ignore

from .config import get_auth_config
from .db import get_user_db, user_db, secret, database, engine, UserTable
from .db import user_db, secret, database, engine, UserTable
from .backends import (
users,
current_user,
cookie_authentication,
NOAUTH_USER,
NOAUTH_EMAIL,
)
from .models import (
User,
UserCreate,
UserUpdate,
UserDB,
)

Expand All @@ -28,61 +28,6 @@
auth_config = get_auth_config()


class UserManager(BaseUserManager[UserCreate, UserDB]):
user_db_model = UserDB

async def on_after_register(self, user: UserDB, request: Optional[Request] = None):
user.initialized = True
for oauth_account in user.oauth_accounts:
print(oauth_account)
if oauth_account.oauth_name == "github":
r = httpx.get(
f"https://api.github.com/user/{oauth_account.account_id}"
).json()
user.anonymous = False
user.username = r["login"]
user.name = r["name"]
user.color = None
user.avatar = r["avatar_url"]
await self.user_db.update(user)


def get_user_manager(user_db=Depends(get_user_db)):
yield UserManager(user_db)


class LoginCookieAuthentication(CookieAuthentication):
async def get_login_response(self, user, response, user_manager):
await super().get_login_response(user, response, user_manager)
# set user as logged in
user.logged_in = True
await user_manager.user_db.update(user)
# auto redirect
response.status_code = status.HTTP_302_FOUND
response.headers["Location"] = "/lab"

async def get_logout_response(self, user, response, user_manager):
await super().get_logout_response(user, response, user_manager)
# set user as logged out
user.logged_in = False
await user_manager.user_db.update(user)


cookie_authentication = LoginCookieAuthentication(
cookie_secure=auth_config.cookie_secure, secret=secret
)

auth_backends = [cookie_authentication]

users = FastAPIUsers(
get_user_manager,
auth_backends,
User,
UserCreate,
UserUpdate,
UserDB,
)

github_oauth_client = GitHubOAuth2(
auth_config.client_id, auth_config.client_secret.get_secret_value()
)
Expand All @@ -99,7 +44,6 @@ async def get_logout_response(self, user, response, user_manager):

TOKEN_USER = None
USER_TOKEN = None
noauth_email = "[email protected]"


def set_user_token(user_token):
Expand Down Expand Up @@ -130,14 +74,9 @@ async def shutdown():


async def create_noauth_user():
user = await user_db.get_by_email(noauth_email)
if user is None:
user = UserDB(
id="d4ded46b-a4df-4b51-8d83-ae19010272a7",
email=noauth_email,
hashed_password="",
)
await user_db.create(user)
noauth_user = await user_db.get_by_email(NOAUTH_EMAIL)
if noauth_user is None:
await user_db.create(NOAUTH_USER)


async def create_token_user():
Expand All @@ -149,25 +88,8 @@ async def create_token_user():
await user_db.create(TOKEN_USER)


def current_user(optional: bool = False):
async def _(
auth_config=Depends(get_auth_config),
user: User = Depends(users.current_user(optional=True)),
user_db=Depends(get_user_db),
):
if auth_config.mode == "noauth":
return await user_db.get_by_email(noauth_email)
elif user is None and not optional:
# FIXME: could be 403
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
else:
return user

return _


@router.get("/auth/users")
async def get_users(user: User = Depends(current_user())):
async def get_users(user: User = Depends(current_user)):
users = session.query(UserTable).all()
return [user for user in users if user.logged_in]

Expand Down
14 changes: 7 additions & 7 deletions plugins/contents/fps_contents/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi import APIRouter, Depends
from starlette.requests import Request # type: ignore

from fps_auth.routes import current_user # type: ignore
from fps_auth.backends import current_user # type: ignore
from fps_auth.models import User # type: ignore

from .models import Checkpoint, Content, SaveContent
Expand All @@ -23,7 +23,7 @@
)
async def create_content(
request: Request,
user: User = Depends(current_user()),
user: User = Depends(current_user),
):
create_content = await request.json()
path = Path(create_content["path"])
Expand Down Expand Up @@ -53,13 +53,13 @@ async def create_content(
@router.get("/api/contents")
async def get_root_content(
content: int,
user: User = Depends(current_user()),
user: User = Depends(current_user),
):
return Content(**get_path_content(Path(""), bool(content)))


@router.get("/api/contents/{path:path}/checkpoints")
async def get_checkpoint(path, user: User = Depends(current_user())):
async def get_checkpoint(path, user: User = Depends(current_user)):
src_path = Path(path)
dst_path = (
Path(".ipynb_checkpoints") / f"{src_path.stem}-checkpoint{src_path.suffix}"
Expand All @@ -74,15 +74,15 @@ async def get_checkpoint(path, user: User = Depends(current_user())):
async def get_content(
path: str,
content: int,
user: User = Depends(current_user()),
user: User = Depends(current_user),
):
return Content(**get_path_content(Path(path), bool(content)))


@router.put("/api/contents/{path:path}")
async def save_content(
request: Request,
user: User = Depends(current_user()),
user: User = Depends(current_user),
):
save_content = SaveContent(**(await request.json()))
try:
Expand Down Expand Up @@ -110,7 +110,7 @@ async def save_content(
"/api/contents/{path:path}/checkpoints",
status_code=201,
)
async def create_checkpoint(path, user: User = Depends(current_user())):
async def create_checkpoint(path, user: User = Depends(current_user)):
src_path = Path(path)
dst_path = (
Path(".ipynb_checkpoints") / f"{src_path.stem}-checkpoint{src_path.suffix}"
Expand Down
16 changes: 8 additions & 8 deletions plugins/jupyterlab/fps_jupyterlab/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from fps.hooks import register_router # type: ignore

from fps_auth.db import get_user_db # type: ignore
from fps_auth.routes import ( # type: ignore
from fps_auth.backends import ( # type: ignore
current_user,
cookie_authentication,
LoginCookieAuthentication,
Expand Down Expand Up @@ -91,7 +91,7 @@ async def get_root(

@router.get("/lab")
async def get_lab(
user: User = Depends(current_user()), jlab_config=Depends(get_jlab_config)
user: User = Depends(current_user), jlab_config=Depends(get_jlab_config)
):
return HTMLResponse(
get_index("default", jlab_config.collaborative, jlab_config.base_url)
Expand Down Expand Up @@ -128,7 +128,7 @@ async def get_listings():


@router.get("/lab/api/workspaces/{name}")
async def get_workspace_data(user: User = Depends(current_user(optional=True))):
async def get_workspace_data(user: User = Depends(current_user)):
if user:
return json.loads(user.workspace)
return {}
Expand All @@ -140,7 +140,7 @@ async def get_workspace_data(user: User = Depends(current_user(optional=True))):
)
async def set_workspace(
request: Request,
user: User = Depends(current_user()),
user: User = Depends(current_user),
user_db=Depends(get_user_db),
):
user.workspace = await request.body()
Expand All @@ -150,7 +150,7 @@ async def set_workspace(

@router.get("/lab/workspaces/{name}", response_class=HTMLResponse)
async def get_workspace(
name, user: User = Depends(current_user()), jlab_config=Depends(get_jlab_config)
name, user: User = Depends(current_user), jlab_config=Depends(get_jlab_config)
):
return get_index(name, jlab_config.collaborative, jlab_config.base_url)

Expand Down Expand Up @@ -202,7 +202,7 @@ async def get_setting(
name0,
name1,
name2,
user: User = Depends(current_user(optional=True)),
user: User = Depends(current_user),
):
with open(
prefix_dir / "share" / "jupyter" / "lab" / "static" / "package.json"
Expand Down Expand Up @@ -248,7 +248,7 @@ async def change_setting(
request: Request,
name0,
name1,
user: User = Depends(current_user()),
user: User = Depends(current_user),
user_db=Depends(get_user_db),
):
settings = json.loads(user.settings)
Expand All @@ -259,7 +259,7 @@ async def change_setting(


@router.get("/lab/api/settings")
async def get_settings(user: User = Depends(current_user(optional=True))):
async def get_settings(user: User = Depends(current_user)):
with open(
prefix_dir / "share" / "jupyter" / "lab" / "static" / "package.json"
) as f:
Expand Down
Loading

0 comments on commit 6ea4e14

Please sign in to comment.