Skip to content

Commit

Permalink
Merge pull request #58 from davidbrochart/fastapiusers_8
Browse files Browse the repository at this point in the history
Use FastAPI-Users 8.0
  • Loading branch information
davidbrochart authored Sep 20, 2021
2 parents 112fa00 + 6dddc4a commit 2191e9b
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 97 deletions.
74 changes: 74 additions & 0 deletions plugins/auth/fps_auth/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import secrets
from pathlib import Path

from fastapi_users.db import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase # type: ignore
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTable # type: ignore
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base # type: ignore
from sqlalchemy import Boolean, String, Text, Column # type: ignore
import sqlalchemy # type: ignore
import databases # type: ignore
from fps.config import get_config # type: ignore

from .config import AuthConfig
from .models import (
UserDB,
)

auth_config = get_config(AuthConfig)

jupyter_dir = Path.home() / ".local" / "share" / "jupyter"
jupyter_dir.mkdir(parents=True, exist_ok=True)
secret_path = jupyter_dir / "jupyverse_secret"
userdb_path = jupyter_dir / "jupyverse_users.db"

if auth_config.clear_users:
if userdb_path.is_file():
userdb_path.unlink()
if secret_path.is_file():
secret_path.unlink()

if not secret_path.is_file():
with open(secret_path, "w") as f:
f.write(secrets.token_hex(32))

with open(secret_path) as f:
secret = f.read()


DATABASE_URL = f"sqlite:///{userdb_path}"

database = databases.Database(DATABASE_URL)

Base: DeclarativeMeta = declarative_base()


class UserTable(Base, SQLAlchemyBaseUserTable):
initialized = Column(Boolean, default=False, nullable=False)
anonymous = Column(Boolean, default=False, nullable=False)
name = Column(String(length=32), nullable=True)
username = Column(String(length=32), nullable=True)
color = Column(String(length=32), nullable=True)
avatar = Column(String(length=32), nullable=True)
logged_in = Column(Boolean, default=False, nullable=False)
workspace = Column(Text(), nullable=False)
settings = Column(Text(), nullable=False)


class OAuthAccount(SQLAlchemyBaseOAuthAccountTable, Base):
pass


engine = sqlalchemy.create_engine(
DATABASE_URL, connect_args={"check_same_thread": False}
)

Base.metadata.create_all(engine)

users = UserTable.__table__
oauth_accounts = OAuthAccount.__table__

user_db = SQLAlchemyUserDatabase(UserDB, database, users, oauth_accounts)


def get_user_db():
yield user_db
65 changes: 0 additions & 65 deletions plugins/auth/fps_auth/models.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
import secrets
from pathlib import Path
from typing import Optional

from pydantic import BaseModel
import databases # type: ignore
import sqlalchemy # type: ignore
from fastapi_users import models # type: ignore
from fastapi_users.db import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase # type: ignore
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTable # type: ignore
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base # type: ignore
from sqlalchemy import Boolean, String, Text, Column
from fps.config import Config # type: ignore

from .config import AuthConfig

auth_config = Config(AuthConfig)


class JupyterUser(BaseModel):
Expand Down Expand Up @@ -45,55 +32,3 @@ class UserUpdate(models.BaseUserUpdate, JupyterUser):

class UserDB(User, models.BaseUserDB):
pass


jupyter_dir = Path.home() / ".local" / "share" / "jupyter"
jupyter_dir.mkdir(parents=True, exist_ok=True)
secret_path = jupyter_dir / "jupyverse_secret"
userdb_path = jupyter_dir / "jupyverse_users.db"

if auth_config.clear_users:
if userdb_path.is_file():
userdb_path.unlink()
if secret_path.is_file():
secret_path.unlink()

if not secret_path.is_file():
with open(secret_path, "w") as f:
f.write(secrets.token_hex(32))

with open(secret_path) as f:
secret = f.read()

DATABASE_URL = f"sqlite:///{userdb_path}"

database = databases.Database(DATABASE_URL)

Base: DeclarativeMeta = declarative_base()


class UserTable(Base, SQLAlchemyBaseUserTable):
initialized = Column(Boolean, default=False, nullable=False)
anonymous = Column(Boolean, default=False, nullable=False)
name = Column(String(length=32), nullable=True)
username = Column(String(length=32), nullable=True)
color = Column(String(length=32), nullable=True)
avatar = Column(String(length=32), nullable=True)
logged_in = Column(Boolean, default=False, nullable=False)
workspace = Column(Text(), nullable=False)
settings = Column(Text(), nullable=False)


class OAuthAccount(SQLAlchemyBaseOAuthAccountTable, Base):
pass


engine = sqlalchemy.create_engine(
DATABASE_URL, connect_args={"check_same_thread": False}
)

Base.metadata.create_all(engine)

users = UserTable.__table__
oauth_accounts = OAuthAccount.__table__
user_db = SQLAlchemyUserDatabase(UserDB, database, users, oauth_accounts)
56 changes: 29 additions & 27 deletions plugins/auth/fps_auth/routes.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,23 @@
from uuid import uuid4
from typing import Optional

import httpx # type: ignore
from httpx_oauth.clients.github import GitHubOAuth2 # type: ignore
from fps.hooks import register_router # type: ignore
from fps.config import get_config, FPSConfig # type: ignore
from fastapi_users.authentication import CookieAuthentication # type: ignore
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi_users import FastAPIUsers # type: ignore
from fastapi_users import FastAPIUsers, BaseUserManager # type: ignore
from starlette.requests import Request
from sqlalchemy.orm import sessionmaker # type: ignore

from .config import get_auth_config
from .db import get_user_db, user_db, secret, database, engine, UserTable
from .models import (
user_db,
engine,
UserTable,
User,
UserCreate,
UserUpdate,
UserDB,
database,
secret,
)


Expand All @@ -31,6 +28,29 @@
auth_config = get_auth_config()


class UserManager(BaseUserManager[UserCreate, UserDB]):
user_db_model = UserDB

async def on_after_register(self, user: UserDB, request: Optional[Request] = None):
user.initialized = True
for oauth_account in user.oauth_accounts:
print(oauth_account)
if oauth_account.oauth_name == "github":
r = httpx.get(
f"https://api.github.com/user/{oauth_account.account_id}"
).json()
user.anonymous = False
user.username = r["login"]
user.name = r["name"]
user.color = None
user.avatar = r["avatar_url"]
await self.user_db.update(user)


def get_user_manager(user_db=Depends(get_user_db)):
yield UserManager(user_db)


class LoginCookieAuthentication(CookieAuthentication):
async def get_login_response(self, user, response):
await super().get_login_response(user, response)
Expand All @@ -55,7 +75,7 @@ async def get_logout_response(self, user, response):
auth_backends = [cookie_authentication]

users = FastAPIUsers(
user_db,
get_user_manager,
auth_backends,
User,
UserCreate,
Expand All @@ -68,29 +88,11 @@ async def get_logout_response(self, user, response):
)


async def on_after_register(user: UserDB, request):
user.initialized = True
await user_db.update(user)


async def on_after_github_register(user: UserDB, request: Request):
r = httpx.get(
f"https://api.github.com/user/{user.oauth_accounts[0].account_id}"
).json()
user.initialized = True
user.anonymous = False
user.username = r["login"]
user.name = r["name"]
user.color = None
user.avatar = r["avatar_url"]
await user_db.update(user)


github_oauth_router = users.get_oauth_router(
github_oauth_client, secret, after_register=on_after_github_register # type: ignore
github_oauth_client, secret # type: ignore
)
auth_router = users.get_auth_router(cookie_authentication)
user_register_router = users.get_register_router(on_after_register) # type: ignore
user_register_router = users.get_register_router() # type: ignore
users_router = users.get_users_router()

router = APIRouter()
Expand Down
3 changes: 1 addition & 2 deletions plugins/auth/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
packages=find_packages(),
install_requires=[
"fps",
"fastapi-users[sqlalchemy]>=7.0.0",
"httpx-oauth",
"fastapi-users[sqlalchemy,oauth]==8",
"aiosqlite",
],
entry_points={
Expand Down
3 changes: 2 additions & 1 deletion plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from starlette.requests import Request # type: ignore

from fps_auth.routes import cookie_authentication, current_user # type: ignore
from fps_auth.models import User, user_db # type: ignore
from fps_auth.models import User # type: ignore
from fps_auth.db import user_db # type: ignore
from fps_auth.config import get_auth_config # type: ignore

from .kernel_server.server import KernelServer # type: ignore
Expand Down
3 changes: 2 additions & 1 deletion plugins/terminals/fps_terminals/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from fastapi import APIRouter, WebSocket, Response, Depends, status

from fps_auth.routes import cookie_authentication, current_user # type: ignore
from fps_auth.models import User, user_db # type: ignore
from fps_auth.models import User # type: ignore
from fps_auth.db import user_db # type: ignore
from fps_auth.config import get_auth_config # type: ignore

from .models import Terminal
Expand Down
2 changes: 1 addition & 1 deletion plugins/yjs/fps_yjs/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import fastapi

from fps_auth.routes import cookie_authentication # type: ignore
from fps_auth.models import user_db # type: ignore
from fps_auth.db import user_db # type: ignore
from fps_auth.config import get_auth_config # type: ignore

router = APIRouter()
Expand Down

0 comments on commit 2191e9b

Please sign in to comment.