-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #62 from davidbrochart/rework_noauth
Rework noauth
- Loading branch information
Showing
7 changed files
with
148 additions
and
116 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
from typing import Optional | ||
|
||
import httpx | ||
from fastapi import Depends, status | ||
from fastapi.security.base import SecurityBase | ||
from fastapi_users.authentication import BaseAuthentication, CookieAuthentication # type: ignore | ||
from fastapi_users import FastAPIUsers, BaseUserManager # type: ignore | ||
from starlette.requests import Request | ||
|
||
from .models import ( | ||
User, | ||
UserCreate, | ||
UserUpdate, | ||
UserDB, | ||
) | ||
from .config import get_auth_config | ||
from .db import secret, get_user_db | ||
|
||
|
||
NOAUTH_EMAIL = "[email protected]" | ||
NOAUTH_USER = UserDB( | ||
id="d4ded46b-a4df-4b51-8d83-ae19010272a7", | ||
email=NOAUTH_EMAIL, | ||
hashed_password="", | ||
) | ||
|
||
|
||
class NoAuth(SecurityBase): | ||
def __call__(self): | ||
return "noauth" | ||
|
||
|
||
class NoAuthAuthentication(BaseAuthentication): | ||
def __init__(self, user: UserDB, name: str = "noauth"): | ||
super().__init__(name, logout=False) | ||
self.user = user | ||
self.scheme = NoAuth() | ||
|
||
async def __call__(self, credentials, user_manager): | ||
noauth_user = await user_manager.user_db.get_by_email(NOAUTH_EMAIL) | ||
return noauth_user or self.user | ||
|
||
|
||
noauth_authentication = NoAuthAuthentication(NOAUTH_USER) | ||
|
||
|
||
class UserManager(BaseUserManager[UserCreate, UserDB]): | ||
user_db_model = UserDB | ||
|
||
async def on_after_register(self, user: UserDB, request: Optional[Request] = None): | ||
user.initialized = True | ||
for oauth_account in user.oauth_accounts: | ||
if oauth_account.oauth_name == "github": | ||
async with httpx.AsyncClient() as client: | ||
r = ( | ||
await client.get( | ||
f"https://api.github.com/user/{oauth_account.account_id}" | ||
) | ||
).json() | ||
user.anonymous = False | ||
user.username = r["login"] | ||
user.name = r["name"] | ||
user.color = None | ||
user.avatar = r["avatar_url"] | ||
await self.user_db.update(user) | ||
|
||
|
||
class LoginCookieAuthentication(CookieAuthentication): | ||
async def get_login_response(self, user, response, user_manager): | ||
await super().get_login_response(user, response, user_manager) | ||
# set user as logged in | ||
user.logged_in = True | ||
await user_manager.user_db.update(user) | ||
# auto redirect | ||
response.status_code = status.HTTP_302_FOUND | ||
response.headers["Location"] = "/lab" | ||
|
||
async def get_logout_response(self, user, response, user_manager): | ||
await super().get_logout_response(user, response, user_manager) | ||
# set user as logged out | ||
user.logged_in = False | ||
await user_manager.user_db.update(user) | ||
|
||
|
||
cookie_authentication = LoginCookieAuthentication( | ||
cookie_secure=get_auth_config().cookie_secure, secret=secret | ||
) | ||
|
||
|
||
def get_user_manager(user_db=Depends(get_user_db)): | ||
yield UserManager(user_db) | ||
|
||
|
||
users = FastAPIUsers( | ||
get_user_manager, | ||
[noauth_authentication, cookie_authentication], | ||
User, | ||
UserCreate, | ||
UserUpdate, | ||
UserDB, | ||
) | ||
|
||
|
||
async def get_enabled_backends(auth_config=Depends(get_auth_config)): | ||
if auth_config.mode == "noauth": | ||
return [noauth_authentication] | ||
return [cookie_authentication] | ||
|
||
|
||
current_user = users.current_user(get_enabled_backends=get_enabled_backends) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,22 @@ | ||
from uuid import uuid4 | ||
from typing import Optional | ||
|
||
import httpx # type: ignore | ||
from httpx_oauth.clients.github import GitHubOAuth2 # type: ignore | ||
from fps.hooks import register_router # type: ignore | ||
from fps.config import get_config, FPSConfig # type: ignore | ||
from fastapi_users.authentication import CookieAuthentication # type: ignore | ||
from fastapi import APIRouter, Depends, HTTPException, status | ||
from fastapi_users import FastAPIUsers, BaseUserManager # type: ignore | ||
from starlette.requests import Request | ||
from fastapi import APIRouter, Depends | ||
from sqlalchemy.orm import sessionmaker # type: ignore | ||
|
||
from .config import get_auth_config | ||
from .db import get_user_db, user_db, secret, database, engine, UserTable | ||
from .db import user_db, secret, database, engine, UserTable | ||
from .backends import ( | ||
users, | ||
current_user, | ||
cookie_authentication, | ||
NOAUTH_USER, | ||
NOAUTH_EMAIL, | ||
) | ||
from .models import ( | ||
User, | ||
UserCreate, | ||
UserUpdate, | ||
UserDB, | ||
) | ||
|
||
|
@@ -28,61 +28,6 @@ | |
auth_config = get_auth_config() | ||
|
||
|
||
class UserManager(BaseUserManager[UserCreate, UserDB]): | ||
user_db_model = UserDB | ||
|
||
async def on_after_register(self, user: UserDB, request: Optional[Request] = None): | ||
user.initialized = True | ||
for oauth_account in user.oauth_accounts: | ||
print(oauth_account) | ||
if oauth_account.oauth_name == "github": | ||
r = httpx.get( | ||
f"https://api.github.com/user/{oauth_account.account_id}" | ||
).json() | ||
user.anonymous = False | ||
user.username = r["login"] | ||
user.name = r["name"] | ||
user.color = None | ||
user.avatar = r["avatar_url"] | ||
await self.user_db.update(user) | ||
|
||
|
||
def get_user_manager(user_db=Depends(get_user_db)): | ||
yield UserManager(user_db) | ||
|
||
|
||
class LoginCookieAuthentication(CookieAuthentication): | ||
async def get_login_response(self, user, response, user_manager): | ||
await super().get_login_response(user, response, user_manager) | ||
# set user as logged in | ||
user.logged_in = True | ||
await user_manager.user_db.update(user) | ||
# auto redirect | ||
response.status_code = status.HTTP_302_FOUND | ||
response.headers["Location"] = "/lab" | ||
|
||
async def get_logout_response(self, user, response, user_manager): | ||
await super().get_logout_response(user, response, user_manager) | ||
# set user as logged out | ||
user.logged_in = False | ||
await user_manager.user_db.update(user) | ||
|
||
|
||
cookie_authentication = LoginCookieAuthentication( | ||
cookie_secure=auth_config.cookie_secure, secret=secret | ||
) | ||
|
||
auth_backends = [cookie_authentication] | ||
|
||
users = FastAPIUsers( | ||
get_user_manager, | ||
auth_backends, | ||
User, | ||
UserCreate, | ||
UserUpdate, | ||
UserDB, | ||
) | ||
|
||
github_oauth_client = GitHubOAuth2( | ||
auth_config.client_id, auth_config.client_secret.get_secret_value() | ||
) | ||
|
@@ -99,7 +44,6 @@ async def get_logout_response(self, user, response, user_manager): | |
|
||
TOKEN_USER = None | ||
USER_TOKEN = None | ||
noauth_email = "[email protected]" | ||
|
||
|
||
def set_user_token(user_token): | ||
|
@@ -130,14 +74,9 @@ async def shutdown(): | |
|
||
|
||
async def create_noauth_user(): | ||
user = await user_db.get_by_email(noauth_email) | ||
if user is None: | ||
user = UserDB( | ||
id="d4ded46b-a4df-4b51-8d83-ae19010272a7", | ||
email=noauth_email, | ||
hashed_password="", | ||
) | ||
await user_db.create(user) | ||
noauth_user = await user_db.get_by_email(NOAUTH_EMAIL) | ||
if noauth_user is None: | ||
await user_db.create(NOAUTH_USER) | ||
|
||
|
||
async def create_token_user(): | ||
|
@@ -149,25 +88,8 @@ async def create_token_user(): | |
await user_db.create(TOKEN_USER) | ||
|
||
|
||
def current_user(optional: bool = False): | ||
async def _( | ||
auth_config=Depends(get_auth_config), | ||
user: User = Depends(users.current_user(optional=True)), | ||
user_db=Depends(get_user_db), | ||
): | ||
if auth_config.mode == "noauth": | ||
return await user_db.get_by_email(noauth_email) | ||
elif user is None and not optional: | ||
# FIXME: could be 403 | ||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) | ||
else: | ||
return user | ||
|
||
return _ | ||
|
||
|
||
@router.get("/auth/users") | ||
async def get_users(user: User = Depends(current_user())): | ||
async def get_users(user: User = Depends(current_user)): | ||
users = session.query(UserTable).all() | ||
return [user for user in users if user.logged_in] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.