diff --git a/jupyverse_api/jupyverse_api/app/__init__.py b/jupyverse_api/jupyverse_api/app/__init__.py index 618bb677..75e01c84 100644 --- a/jupyverse_api/jupyverse_api/app/__init__.py +++ b/jupyverse_api/jupyverse_api/app/__init__.py @@ -11,6 +11,8 @@ class App: + """A wrapper around FastAPI that checks for endpoint path conflicts.""" + _app: FastAPI _router_paths: Dict[str, List[str]] diff --git a/jupyverse_api/jupyverse_api/auth/__init__.py b/jupyverse_api/jupyverse_api/auth/__init__.py index a4476d47..f8cd3d7b 100644 --- a/jupyverse_api/jupyverse_api/auth/__init__.py +++ b/jupyverse_api/jupyverse_api/auth/__init__.py @@ -1,32 +1,26 @@ +from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional, Tuple from jupyverse_api import Config -from pydantic import BaseModel +from .models import User # noqa -class User(BaseModel): - username: str = "" - name: str = "" - display_name: str = "" - initials: Optional[str] = None - color: Optional[str] = None - avatar_url: Optional[str] = None - workspace: str = "{}" - settings: str = "{}" - -class Auth: +class Auth(ABC): + @abstractmethod def current_user(self, permissions: Optional[Dict[str, List[str]]] = None) -> Callable: - raise RuntimeError("Not implemented") + ... + @abstractmethod async def update_user(self) -> Callable: - raise RuntimeError("Not implemented") + ... + @abstractmethod def websocket_auth( self, permissions: Optional[Dict[str, List[str]]] = None, ) -> Callable[[], Tuple[Any, Dict[str, List[str]]]]: - raise RuntimeError("Not implemented") + ... class AuthConfig(Config): diff --git a/plugins/noauth/fps_noauth/models.py b/jupyverse_api/jupyverse_api/auth/models.py similarity index 91% rename from plugins/noauth/fps_noauth/models.py rename to jupyverse_api/jupyverse_api/auth/models.py index 47f1b6e2..c6337013 100644 --- a/plugins/noauth/fps_noauth/models.py +++ b/jupyverse_api/jupyverse_api/auth/models.py @@ -1,10 +1,8 @@ from typing import Optional - from pydantic import BaseModel class User(BaseModel): - anonymous: bool = True username: str = "" name: str = "" display_name: str = "" diff --git a/jupyverse_api/jupyverse_api/contents/__init__.py b/jupyverse_api/jupyverse_api/contents/__init__.py index f8f84f99..d5256d4e 100644 --- a/jupyverse_api/jupyverse_api/contents/__init__.py +++ b/jupyverse_api/jupyverse_api/contents/__init__.py @@ -1,4 +1,5 @@ import asyncio +from abc import ABC, abstractmethod from pathlib import Path from typing import Dict, Union @@ -7,26 +8,31 @@ from .models import Content, SaveContent -class FileIdManager: +class FileIdManager(ABC): stop_watching_files: asyncio.Event stopped_watching_files: asyncio.Event + @abstractmethod async def get_path(self, file_id: str) -> str: - raise RuntimeError("Not implemented") + ... + @abstractmethod async def get_id(self, file_path: str) -> str: - raise RuntimeError("Not implemented") + ... -class Contents(Router): +class Contents(Router, ABC): @property + @abstractmethod def file_id_manager(self) -> FileIdManager: - raise RuntimeError("Not implemented") + ... + @abstractmethod async def read_content( self, path: Union[str, Path], get_content: bool, as_json: bool = False ) -> Content: - raise RuntimeError("Not implemented") + ... + @abstractmethod async def write_content(self, content: Union[SaveContent, Dict]) -> None: - raise RuntimeError("Not implemented") + ... diff --git a/jupyverse_api/jupyverse_api/kernels/__init__.py b/jupyverse_api/jupyverse_api/kernels/__init__.py index c63b2f1c..83dc050c 100644 --- a/jupyverse_api/jupyverse_api/kernels/__init__.py +++ b/jupyverse_api/jupyverse_api/kernels/__init__.py @@ -1,12 +1,14 @@ -from typing import Optional +from abc import ABC, abstractmethod from pathlib import Path +from typing import Optional from jupyverse_api import Router, Config -class Kernels(Router): +class Kernels(Router, ABC): + @abstractmethod async def watch_connection_files(self, path: Path) -> None: - raise RuntimeError("Not implemented") + ... class KernelsConfig(Config): diff --git a/jupyverse_api/jupyverse_api/lab/__init__.py b/jupyverse_api/jupyverse_api/lab/__init__.py index aed5bcc6..151dea2e 100644 --- a/jupyverse_api/jupyverse_api/lab/__init__.py +++ b/jupyverse_api/jupyverse_api/lab/__init__.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Dict, List, Tuple @@ -5,11 +6,13 @@ from jupyverse_api import Router -class Lab(Router): +class Lab(Router, ABC): + @abstractmethod def init_router( self, router: APIRouter, redirect_after_root: str ) -> Tuple[Path, List[Dict[str, Any]]]: - raise RuntimeError("Not implemented") + ... + @abstractmethod def get_federated_extensions(self, extensions_dir: Path) -> Tuple[List, List]: - raise RuntimeError("Not implemented") + ... diff --git a/plugins/auth/fps_auth/backends.py b/plugins/auth/fps_auth/backends.py index 7b58bf0b..548357b5 100644 --- a/plugins/auth/fps_auth/backends.py +++ b/plugins/auth/fps_auth/backends.py @@ -1,5 +1,6 @@ import logging import uuid +from dataclasses import dataclass from typing import Any, Dict, Generic, List, Optional, Tuple import httpx @@ -20,118 +21,137 @@ from fastapi_users.db import SQLAlchemyUserDatabase from httpx_oauth.clients.github import GitHubOAuth2 from jupyverse_api.exceptions import RedirectException +from jupyverse_api.frontend import FrontendConfig from starlette.requests import Request -from .db import Db, User +from .config import _AuthConfig +from .db import User from .models import UserCreate, UserRead logger = logging.getLogger("auth") -class Backend: - def __init__(self, auth_config, frontend_config): - self.auth_config = auth_config - self.frontend_config = frontend_config - self.db = db = Db(auth_config) - - class NoAuthStrategy(Strategy, Generic[models.UP, models.ID]): - async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] - ) -> Optional[models.UP]: - active_user = await user_manager.user_db.get_by_email(auth_config.global_email) - return active_user - - async def write_token(self, user: models.UP): - pass - - async def destroy_token(self, token: str, user: models.UP): - pass - - def get_noauth_strategy() -> NoAuthStrategy: - return NoAuthStrategy() - - self.noauth_authentication = AuthenticationBackend( - name="noauth", - transport=NoAuthTransport(), - get_strategy=get_noauth_strategy, - ) - self.cookie_authentication = AuthenticationBackend( - name="cookie", - transport=CookieTransport(cookie_secure=auth_config.cookie_secure), - get_strategy=self._get_jwt_strategy, - ) - self.github_cookie_authentication = AuthenticationBackend( - name="github", - transport=GitHubTransport(), - get_strategy=self._get_jwt_strategy, - ) - self.github_authentication = GitHubOAuth2(auth_config.client_id, auth_config.client_secret) - - class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): - async def on_after_register(self, user: User, request: Optional[Request] = None): - for oauth_account in user.oauth_accounts: - if oauth_account.oauth_name == "github": - async with httpx.AsyncClient() as client: - r = ( - await client.get( - f"https://api.github.com/user/{oauth_account.account_id}" - ) - ).json() - - await self.user_db.update( - user, - dict( - anonymous=False, - username=r["login"], - color=None, - avatar_url=r["avatar_url"], - is_active=True, - ), - ) - - self.UserManager = UserManager - - async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(db.get_user_db)): - yield UserManager(user_db) - - self.get_user_manager = get_user_manager - - self.fapi_users = FastAPIUsers[User, uuid.UUID]( - get_user_manager, - [ - self.noauth_authentication, - self.cookie_authentication, - self.github_cookie_authentication, - ], - ) - - async def update_user( - user: UserRead = Depends(self.current_user()), - user_db: SQLAlchemyUserDatabase = Depends(db.get_user_db), - ): - async def _(data: Dict[str, Any]) -> UserRead: - await user_db.update(user, data) - return user - - return _ - - self.update_user = update_user - - def _get_jwt_strategy(self) -> JWTStrategy: - return JWTStrategy(secret=self.db.secret, lifetime_seconds=None) - - def _get_enabled_backends(self): - if self.auth_config.mode == "noauth" and not self.frontend_config.collaborative: - res = [self.noauth_authentication, self.github_cookie_authentication] +@dataclass +class Res: + cookie_authentication: Any + current_user: Any + update_user: Any + fapi_users: Any + get_user_manager: Any + github_authentication: Any + github_cookie_authentication: Any + websocket_auth: Any + + +def get_backend(auth_config: _AuthConfig, frontend_config: FrontendConfig, db) -> Res: + class NoAuthTransport(Transport): + scheme = None # type: ignore + + async def get_login_response(self, token: str, response: Response): + pass + + async def get_logout_response(self, response: Response): + pass + + @staticmethod + def get_openapi_login_responses_success(): + pass + + @staticmethod + def get_openapi_logout_responses_success(): + pass + + class NoAuthStrategy(Strategy, Generic[models.UP, models.ID]): + async def read_token( + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] + ) -> Optional[models.UP]: + active_user = await user_manager.user_db.get_by_email(auth_config.global_email) + return active_user + + async def write_token(self, user: models.UP): + pass + + async def destroy_token(self, token: str, user: models.UP): + pass + + class GitHubTransport(CookieTransport): + async def get_login_response(self, token: str, response: Response): + await super().get_login_response(token, response) + response.status_code = status.HTTP_302_FOUND + response.headers["Location"] = "/lab" + + def get_noauth_strategy() -> NoAuthStrategy: + return NoAuthStrategy() + + def get_jwt_strategy() -> JWTStrategy: + return JWTStrategy(secret=db.secret, lifetime_seconds=None) + + noauth_authentication = AuthenticationBackend( + name="noauth", + transport=NoAuthTransport(), + get_strategy=get_noauth_strategy, + ) + + cookie_authentication = AuthenticationBackend( + name="cookie", + transport=CookieTransport(cookie_secure=auth_config.cookie_secure), + get_strategy=get_jwt_strategy, + ) + + github_cookie_authentication = AuthenticationBackend( + name="github", + transport=GitHubTransport(), + get_strategy=get_jwt_strategy, + ) + + github_authentication = GitHubOAuth2(auth_config.client_id, auth_config.client_secret) + + class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): + async def on_after_register(self, user: User, request: Optional[Request] = None): + for oauth_account in user.oauth_accounts: + if oauth_account.oauth_name == "github": + async with httpx.AsyncClient() as client: + r = ( + await client.get( + f"https://api.github.com/user/{oauth_account.account_id}" + ) + ).json() + + await self.user_db.update( + user, + dict( + anonymous=False, + username=r["login"], + color=None, + avatar_url=r["avatar_url"], + is_active=True, + ), + ) + + async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(db.get_user_db)): + yield UserManager(user_db) + + def get_enabled_backends(): + if auth_config.mode == "noauth" and not frontend_config.collaborative: + res = [noauth_authentication, github_cookie_authentication] else: - res = [self.cookie_authentication, self.github_cookie_authentication] + res = [cookie_authentication, github_cookie_authentication] return res - async def _create_guest(self, user_manager): + fapi_users = FastAPIUsers[User, uuid.UUID]( + get_user_manager, + [ + noauth_authentication, + cookie_authentication, + github_cookie_authentication, + ], + ) + + async def create_guest(user_manager): # workspace and settings are copied from global user # but this is a new user - global_user = await user_manager.get_by_email(self.auth_config.global_email) + global_user = await user_manager.get_by_email(auth_config.global_email) user_id = str(uuid.uuid4()) guest = dict( anonymous=True, @@ -144,18 +164,16 @@ async def _create_guest(self, user_manager): ) return await user_manager.create(UserCreate(**guest)) - def current_user(self, permissions: Optional[Dict[str, List[str]]] = None): + def current_user(permissions: Optional[Dict[str, List[str]]] = None): async def _( response: Response, token: Optional[str] = None, user: Optional[User] = Depends( - self.fapi_users.current_user( - optional=True, get_enabled_backends=self._get_enabled_backends - ) + fapi_users.current_user(optional=True, get_enabled_backends=get_enabled_backends) ), - user_manager: BaseUserManager[User, models.ID] = Depends(self.get_user_manager), + user_manager: BaseUserManager[User, models.ID] = Depends(get_user_manager), ): - if self.auth_config.mode == "user": + if auth_config.mode == "user": # "user" authentication: check authorization if user and permissions: for resource, actions in permissions.items(): @@ -165,41 +183,35 @@ async def _( break else: # "noauth" or "token" authentication - if self.frontend_config.collaborative: - if not user and self.auth_config.mode == "noauth": - user = await self._create_guest(user_manager) - await self.cookie_authentication.login( - self._get_jwt_strategy(), user, response - ) - - elif not user and self.auth_config.mode == "token": - global_user = await user_manager.get_by_email(self.auth_config.global_email) + if frontend_config.collaborative: + if not user and auth_config.mode == "noauth": + user = await create_guest(user_manager) + await cookie_authentication.login(get_jwt_strategy(), user, response) + + elif not user and auth_config.mode == "token": + global_user = await user_manager.get_by_email(auth_config.global_email) if global_user and global_user.username == token: - user = await self._create_guest(user_manager) - await self.cookie_authentication.login( - self._get_jwt_strategy(), user, response - ) + user = await create_guest(user_manager) + await cookie_authentication.login(get_jwt_strategy(), user, response) else: - if self.auth_config.mode == "token": - global_user = await user_manager.get_by_email(self.auth_config.global_email) + if auth_config.mode == "token": + global_user = await user_manager.get_by_email(auth_config.global_email) if global_user and global_user.username == token: user = global_user - await self.cookie_authentication.login( - self._get_jwt_strategy(), user, response - ) + await cookie_authentication.login(get_jwt_strategy(), user, response) if user: return user - elif self.auth_config.login_url: - raise RedirectException(self.auth_config.login_url) + elif auth_config.login_url: + raise RedirectException(auth_config.login_url) else: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return _ - def websocket_auth(self, permissions: Optional[Dict[str, List[str]]] = None): + def websocket_auth(permissions: Optional[Dict[str, List[str]]] = None): """ A function returning a dependency for the WebSocket connection. @@ -212,17 +224,17 @@ def websocket_auth(self, permissions: Optional[Dict[str, List[str]]] = None): async def _( websocket: WebSocket, - user_manager: BaseUserManager[models.UP, models.ID] = Depends(self.get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ) -> Optional[Tuple[WebSocket, Optional[Dict[str, List[str]]]]]: accept_websocket = False checked_permissions: Optional[Dict[str, List[str]]] = None - if self.auth_config.mode == "noauth": + if auth_config.mode == "noauth": accept_websocket = True elif "fastapiusersauth" in websocket._cookies: token = websocket._cookies["fastapiusersauth"] - user = await self._get_jwt_strategy().read_token(token, user_manager) + user = await get_jwt_strategy().read_token(token, user_manager) if user: - if self.auth_config.mode == "user": + if auth_config.mode == "user": # "user" authentication: check authorization if permissions is None: accept_websocket = True @@ -247,31 +259,23 @@ async def _( return _ - @property - def User(self): - return UserRead - - -class NoAuthTransport(Transport): - scheme = None # type: ignore - - async def get_login_response(self, token: str, response: Response): - pass - - async def get_logout_response(self, response: Response): - pass - - @staticmethod - def get_openapi_login_responses_success(): - pass - - @staticmethod - def get_openapi_logout_responses_success(): - pass + async def update_user( + user: UserRead = Depends(current_user()), + user_db: SQLAlchemyUserDatabase = Depends(db.get_user_db), + ): + async def _(data: Dict[str, Any]) -> UserRead: + await user_db.update(user, data) + return user + return _ -class GitHubTransport(CookieTransport): - async def get_login_response(self, token: str, response: Response): - await super().get_login_response(token, response) - response.status_code = status.HTTP_302_FOUND - response.headers["Location"] = "/lab" + return Res( + cookie_authentication=cookie_authentication, + current_user=current_user, + update_user=update_user, + fapi_users=fapi_users, + get_user_manager=get_user_manager, + github_authentication=github_authentication, + github_cookie_authentication=github_cookie_authentication, + websocket_auth=websocket_auth, + ) diff --git a/plugins/auth/fps_auth/db.py b/plugins/auth/fps_auth/db.py index 44189eb8..1e40c181 100644 --- a/plugins/auth/fps_auth/db.py +++ b/plugins/auth/fps_auth/db.py @@ -1,7 +1,8 @@ import logging import secrets +from dataclasses import dataclass from pathlib import Path -from typing import AsyncGenerator, List +from typing import Any, AsyncGenerator, List from fastapi import Depends from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID @@ -14,53 +15,10 @@ from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base # type: ignore from sqlalchemy.orm import relationship, sessionmaker # type: ignore +from .config import _AuthConfig -logger = logging.getLogger("auth") - - -class Db: - def __init__(self, auth_config): - jupyter_dir = Path.home() / ".local" / "share" / "jupyter" - jupyter_dir.mkdir(parents=True, exist_ok=True) - - name = "jupyverse" - if auth_config.test: - name += "_test" - - secret_path = jupyter_dir / f"{name}_secret" - userdb_path = jupyter_dir / f"{name}_users.db" - - if auth_config.clear_users: - if userdb_path.is_file(): - userdb_path.unlink() - if secret_path.is_file(): - secret_path.unlink() - - if not secret_path.is_file(): - secret_path.write_text(secrets.token_hex(32)) - - self.secret = secret_path.read_text() - - database_url = f"sqlite+aiosqlite:///{userdb_path}" - - engine = create_async_engine(database_url) - async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) - - async def get_async_session() -> AsyncGenerator[AsyncSession, None]: - async with async_session_maker() as session: - yield session - - async def get_user_db(session: AsyncSession = Depends(get_async_session)): - yield SQLAlchemyUserDatabase(session, User, OAuthAccount) - - async def create_db_and_tables(): - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - self.get_async_session = get_async_session - self.get_user_db = get_user_db - self.create_db_and_tables = create_db_and_tables +logger = logging.getLogger("auth") Base: DeclarativeMeta = declarative_base() @@ -82,3 +40,59 @@ class User(SQLAlchemyBaseUserTableUUID, Base): settings = Column(Text(), default="{}", nullable=False) permissions = Column(JSON, default={}, nullable=False) oauth_accounts: List[OAuthAccount] = relationship("OAuthAccount", lazy="joined") + + +@dataclass +class Res: + User: Any + async_session_maker: Any + create_db_and_tables: Any + get_async_session: Any + get_user_db: Any + secret: Any + + +def get_db(auth_config: _AuthConfig) -> Res: + jupyter_dir = Path.home() / ".local" / "share" / "jupyter" + jupyter_dir.mkdir(parents=True, exist_ok=True) + name = "jupyverse" + if auth_config.test: + name += "_test" + secret_path = jupyter_dir / f"{name}_secret" + userdb_path = jupyter_dir / f"{name}_users.db" + + if auth_config.clear_users: + if userdb_path.is_file(): + userdb_path.unlink() + if secret_path.is_file(): + secret_path.unlink() + + if not secret_path.is_file(): + secret_path.write_text(secrets.token_hex(32)) + + secret = secret_path.read_text() + + database_url = f"sqlite+aiosqlite:///{userdb_path}" + + engine = create_async_engine(database_url) + async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async def create_db_and_tables(): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + async def get_async_session() -> AsyncGenerator[AsyncSession, None]: + async with async_session_maker() as session: + yield session + + async def get_user_db(session: AsyncSession = Depends(get_async_session)): + yield SQLAlchemyUserDatabase(session, User, OAuthAccount) + + return Res( + User=User, + async_session_maker=async_session_maker, + create_db_and_tables=create_db_and_tables, + get_async_session=get_async_session, + get_user_db=get_user_db, + secret=secret, + ) diff --git a/plugins/auth/fps_auth/main.py b/plugins/auth/fps_auth/main.py index c651235a..6c42a956 100644 --- a/plugins/auth/fps_auth/main.py +++ b/plugins/auth/fps_auth/main.py @@ -7,7 +7,7 @@ from jupyverse_api.app import App from .config import _AuthConfig -from .routes import _Auth +from .routes import auth_factory logger = logging.getLogger("auth") @@ -26,7 +26,7 @@ async def start( app = await ctx.request_resource(App) frontend_config = await ctx.request_resource(FrontendConfig) - auth = _Auth(app, self.auth_config, frontend_config) + auth = auth_factory(app, self.auth_config, frontend_config) ctx.add_resource(auth, types=Auth) await auth.db.create_db_and_tables() diff --git a/plugins/auth/fps_auth/models.py b/plugins/auth/fps_auth/models.py index 4354d008..92295dcf 100644 --- a/plugins/auth/fps_auth/models.py +++ b/plugins/auth/fps_auth/models.py @@ -1,24 +1,13 @@ import uuid -from typing import Dict, List, Optional +from typing import Dict, List from fastapi_users import schemas -from pydantic import BaseModel +from jupyverse_api.auth import User -class Permissions(BaseModel): - permissions: Dict[str, List[str]] - - -class JupyterUser(Permissions): +class JupyterUser(User): anonymous: bool = True - username: str = "" - name: str = "" - display_name: str = "" - initials: Optional[str] = None - color: Optional[str] = None - avatar_url: Optional[str] = None - workspace: str = "{}" - settings: str = "{}" + permissions: Dict[str, List[str]] class UserRead(schemas.BaseUser[uuid.UUID], JupyterUser): diff --git a/plugins/auth/fps_auth/routes.py b/plugins/auth/fps_auth/routes.py index 869c336d..fb1bd8ca 100644 --- a/plugins/auth/fps_auth/routes.py +++ b/plugins/auth/fps_auth/routes.py @@ -1,7 +1,7 @@ import contextlib import json import logging -from typing import Dict, List +from typing import Any, Callable, Dict, List, Optional, Tuple from fastapi import APIRouter, Depends, Request from jupyverse_api import Router @@ -10,9 +10,9 @@ from jupyverse_api.frontend import FrontendConfig from sqlalchemy import select # type: ignore -from .backends import Backend +from .backends import get_backend from .config import _AuthConfig -from .db import User +from .db import get_db from .models import UserCreate, UserRead, UserUpdate @@ -20,116 +20,131 @@ logger = logging.getLogger("auth") -class _Auth(Backend, Auth, Router): - def __init__( - self, - app: App, - auth_config: _AuthConfig, - frontend_config: FrontendConfig, - ) -> None: - Router.__init__(self, app) - Backend.__init__(self, auth_config, frontend_config) - - self.auth_config = auth_config - self.backend = backend = Backend(auth_config, frontend_config) - - db = self.db - - router = APIRouter() - - get_async_session_context = contextlib.asynccontextmanager(db.get_async_session) - get_user_db_context = contextlib.asynccontextmanager(db.get_user_db) - get_user_manager_context = contextlib.asynccontextmanager(backend.get_user_manager) - - @contextlib.asynccontextmanager - async def _get_user_manager(): - async with get_async_session_context() as session: - async with get_user_db_context(session) as user_db: - async with get_user_manager_context(user_db) as user_manager: - yield user_manager - - async def create_user(**kwargs): - async with _get_user_manager() as user_manager: - await user_manager.create(UserCreate(**kwargs)) - - self.create_user = create_user - - async def _update_user(user, **kwargs): - async with _get_user_manager() as user_manager: - await user_manager.update(UserUpdate(**kwargs), user) - - self._update_user = _update_user - - async def get_user_by_email(user_email): - async with _get_user_manager() as user_manager: - return await user_manager.get_by_email(user_email) - - self.get_user_by_email = get_user_by_email - - @router.get("/auth/users") - async def get_users( - user: UserRead = Depends(self.current_user(permissions={"admin": ["read"]})), - ): - async with db.async_session_maker() as session: - statement = select(User) - users = (await session.execute(statement)).unique().all() - return [usr.User for usr in users if usr.User.is_active] - - @router.get("/api/me") - async def get_api_me( - request: Request, - user: UserRead = Depends(self.current_user()), - ): - checked_permissions: Dict[str, List[str]] = {} - permissions = json.loads( - dict(request.query_params).get("permissions", "{}").replace("'", '"') +def auth_factory( + app: App, + auth_config: _AuthConfig, + frontend_config: FrontendConfig, +): + db = get_db(auth_config) + backend = get_backend(auth_config, frontend_config, db) + + get_async_session_context = contextlib.asynccontextmanager(db.get_async_session) + get_user_db_context = contextlib.asynccontextmanager(db.get_user_db) + get_user_manager_context = contextlib.asynccontextmanager(backend.get_user_manager) + + @contextlib.asynccontextmanager + async def _get_user_manager(): + async with get_async_session_context() as session: + async with get_user_db_context(session) as user_db: + async with get_user_manager_context(user_db) as user_manager: + yield user_manager + + async def create_user(**kwargs): + async with _get_user_manager() as user_manager: + await user_manager.create(UserCreate(**kwargs)) + + async def update_user(user, **kwargs): + async with _get_user_manager() as user_manager: + await user_manager.update(UserUpdate(**kwargs), user) + + async def get_user_by_email(user_email): + async with _get_user_manager() as user_manager: + return await user_manager.get_by_email(user_email) + + class _Auth(Auth, Router): + def __init__(self) -> None: + super().__init__(app) + + self.db = db + + router = APIRouter() + + @router.get("/auth/users") + async def get_users( + user: UserRead = Depends(backend.current_user(permissions={"admin": ["read"]})), + ): + async with db.async_session_maker() as session: + statement = select(db.User) + users = (await session.execute(statement)).unique().all() + return [usr.User for usr in users if usr.User.is_active] + + @router.get("/api/me") + async def get_api_me( + request: Request, + user: UserRead = Depends(backend.current_user()), + ): + checked_permissions: Dict[str, List[str]] = {} + permissions = json.loads( + dict(request.query_params).get("permissions", "{}").replace("'", '"') + ) + if permissions: + user_permissions = user.permissions + for resource, actions in permissions.items(): + user_resource_permissions = user_permissions.get(resource) + if user_resource_permissions is None: + continue + allowed = checked_permissions[resource] = [] + for action in actions: + if action in user_resource_permissions: + allowed.append(action) + + keys = ["username", "name", "display_name", "initials", "avatar_url", "color"] + identity = {k: getattr(user, k) for k in keys} + return { + "identity": identity, + "permissions": checked_permissions, + } + + # redefine GET /me because we want our current_user dependency + # it is first defined in users_router and so it wins over the one in + # fapi_users.get_users_router + users_router = APIRouter() + + @users_router.get("/me") + async def get_me( + user: UserRead = Depends(backend.current_user(permissions={"admin": ["read"]})), + ): + return user + + users_router.include_router(backend.fapi_users.get_users_router(UserRead, UserUpdate)) + + # Cookie based auth login and logout + self.include_router( + backend.fapi_users.get_auth_router(backend.cookie_authentication), prefix="/auth" ) - if permissions: - user_permissions = user.permissions - for resource, actions in permissions.items(): - user_resource_permissions = user_permissions.get(resource) - if user_resource_permissions is None: - continue - allowed = checked_permissions[resource] = [] - for action in actions: - if action in user_resource_permissions: - allowed.append(action) - - keys = ["username", "name", "display_name", "initials", "avatar_url", "color"] - identity = {k: getattr(user, k) for k in keys} - return { - "identity": identity, - "permissions": checked_permissions, - } - - # redefine GET /me because we want our current_user dependency - # it is first defined in users_router and so it wins over the one in - # fapi_users.get_users_router - users_router = APIRouter() - - @users_router.get("/me") - async def get_me( - user: UserRead = Depends(self.current_user(permissions={"admin": ["read"]})), - ): - return user - - users_router.include_router(self.fapi_users.get_users_router(UserRead, UserUpdate)) - - # Cookie based auth login and logout - self.include_router( - self.fapi_users.get_auth_router(self.cookie_authentication), prefix="/auth" - ) - self.include_router( - self.fapi_users.get_register_router(UserRead, UserCreate), - prefix="/auth", - dependencies=[Depends(self.current_user(permissions={"admin": ["write"]}))], - ) - self.include_router(users_router, prefix="/auth/user") - # GitHub OAuth register router - self.include_router( - self.fapi_users.get_oauth_router( - self.github_authentication, self.github_cookie_authentication, db.secret - ), - prefix="/auth/github", - ) - self.include_router(router) + self.include_router( + backend.fapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + dependencies=[Depends(backend.current_user(permissions={"admin": ["write"]}))], + ) + self.include_router(users_router, prefix="/auth/user") + + # GitHub OAuth register router + self.include_router( + backend.fapi_users.get_oauth_router( + backend.github_authentication, backend.github_cookie_authentication, db.secret + ), + prefix="/auth/github", + ) + self.include_router(router) + + self.create_user = create_user + self.__update_user = update_user + self.get_user_by_email = get_user_by_email + + async def _update_user(self, user, **kwargs): + return await self.__update_user(user, **kwargs) + + def current_user(self, permissions: Optional[Dict[str, List[str]]] = None) -> Callable: + return backend.current_user(permissions) + + async def update_user(self, update_user=Depends(backend.update_user)) -> Callable: + return update_user + + def websocket_auth( + self, + permissions: Optional[Dict[str, List[str]]] = None, + ) -> Callable[[], Tuple[Any, Dict[str, List[str]]]]: + return backend.websocket_auth(permissions) + + return _Auth() diff --git a/plugins/auth_fief/fps_auth_fief/backend.py b/plugins/auth_fief/fps_auth_fief/backend.py index 1338d623..05c09871 100644 --- a/plugins/auth_fief/fps_auth_fief/backend.py +++ b/plugins/auth_fief/fps_auth_fief/backend.py @@ -4,9 +4,9 @@ from fastapi.security import APIKeyCookie from fief_client import FiefAccessTokenInfo, FiefAsync, FiefUserInfo from fief_client.integrations.fastapi import FiefAuth +from jupyverse_api.auth import User from .config import _AuthFiefConfig -from .models import UserRead class Backend: @@ -87,14 +87,10 @@ def current_user(permissions=None): async def _( user: FiefUserInfo = Depends(self.auth.current_user(permissions=permissions)), ): - return UserRead(**user["fields"]) + return User(**user["fields"]) return _ self.current_user = current_user self.update_user = update_user self.websocket_auth = websocket_auth - - @property - def User(self): - return UserRead diff --git a/plugins/auth_fief/fps_auth_fief/models.py b/plugins/auth_fief/fps_auth_fief/models.py deleted file mode 100644 index 89786c57..00000000 --- a/plugins/auth_fief/fps_auth_fief/models.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Dict, List, Optional - -from pydantic import BaseModel - - -class Permissions(BaseModel): - permissions: Dict[str, List[str]] - - -class UserRead(BaseModel): - username: str = "" - name: str = "" - display_name: str = "" - initials: Optional[str] = None - color: Optional[str] = None - avatar_url: Optional[str] = None - workspace: str = "{}" - settings: str = "{}" diff --git a/plugins/auth_fief/fps_auth_fief/routes.py b/plugins/auth_fief/fps_auth_fief/routes.py index 327cd7ed..d388c49d 100644 --- a/plugins/auth_fief/fps_auth_fief/routes.py +++ b/plugins/auth_fief/fps_auth_fief/routes.py @@ -6,11 +6,10 @@ from fief_client import FiefAccessTokenInfo from jupyverse_api import Router from jupyverse_api.app import App -from jupyverse_api.auth import Auth +from jupyverse_api.auth import Auth, User from .backend import Backend from .config import _AuthFiefConfig -from .models import UserRead class _AuthFief(Backend, Auth, Router): @@ -43,7 +42,7 @@ async def auth_callback(request: Request, response: Response, code: str = Query( @router.get("/api/me") async def get_api_me( request: Request, - user: UserRead = Depends(self.current_user()), + user: User = Depends(self.current_user()), access_token_info: FiefAccessTokenInfo = Depends(self.auth.authenticated()), ): checked_permissions: Dict[str, List[str]] = {} diff --git a/plugins/contents/fps_contents/main.py b/plugins/contents/fps_contents/main.py index c5b6163d..36dc5cfe 100644 --- a/plugins/contents/fps_contents/main.py +++ b/plugins/contents/fps_contents/main.py @@ -12,7 +12,7 @@ async def start( ctx: Context, ) -> None: app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) + auth = await ctx.request_resource(Auth) # type: ignore contents = _Contents(app, auth) ctx.add_resource(contents, types=Contents) diff --git a/plugins/jupyterlab/fps_jupyterlab/main.py b/plugins/jupyterlab/fps_jupyterlab/main.py index ef7bc7ab..ec2fd86f 100644 --- a/plugins/jupyterlab/fps_jupyterlab/main.py +++ b/plugins/jupyterlab/fps_jupyterlab/main.py @@ -19,9 +19,9 @@ async def start( ctx.add_resource(self.jupyterlab_config, types=JupyterLabConfig) app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) + auth = await ctx.request_resource(Auth) # type: ignore frontend_config = await ctx.request_resource(FrontendConfig) - lab = await ctx.request_resource(Lab) + lab = await ctx.request_resource(Lab) # type: ignore jupyterlab = _JupyterLab(app, self.jupyterlab_config, auth, frontend_config, lab) ctx.add_resource(jupyterlab, types=JupyterLab) diff --git a/plugins/kernels/fps_kernels/main.py b/plugins/kernels/fps_kernels/main.py index 7f0614fa..28aef7d7 100644 --- a/plugins/kernels/fps_kernels/main.py +++ b/plugins/kernels/fps_kernels/main.py @@ -3,6 +3,7 @@ import logging from collections.abc import AsyncGenerator from pathlib import Path +from typing import Optional from asphalt.core import Component, Context, context_teardown @@ -26,11 +27,11 @@ def __init__(self, **kwargs): async def start( self, ctx: Context, - ) -> AsyncGenerator[None, BaseException | None]: + ) -> AsyncGenerator[None, Optional[BaseException]]: ctx.add_resource(self.kernels_config, types=KernelsConfig) app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) + auth = await ctx.request_resource(Auth) # type: ignore frontend_config = await ctx.request_resource(FrontendConfig) yjs = await ctx.request_resource(Yjs) diff --git a/plugins/lab/fps_lab/main.py b/plugins/lab/fps_lab/main.py index 4e77aca1..c03a42e6 100644 --- a/plugins/lab/fps_lab/main.py +++ b/plugins/lab/fps_lab/main.py @@ -14,7 +14,7 @@ async def start( ctx: Context, ) -> None: app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) + auth = await ctx.request_resource(Auth) # type: ignore frontend_config = await ctx.request_resource(FrontendConfig) jupyterlab_config = ctx.get_resource(JupyterLabConfig) diff --git a/plugins/login/fps_login/routes.py b/plugins/login/fps_login/routes.py index f361e766..ba117254 100644 --- a/plugins/login/fps_login/routes.py +++ b/plugins/login/fps_login/routes.py @@ -10,7 +10,7 @@ class _AuthConfig(AuthConfig): - login_url: Optional[str] = None + login_url: Optional[str] class _Login(Login): diff --git a/plugins/nbconvert/fps_nbconvert/main.py b/plugins/nbconvert/fps_nbconvert/main.py index 650863eb..b5c8d98f 100644 --- a/plugins/nbconvert/fps_nbconvert/main.py +++ b/plugins/nbconvert/fps_nbconvert/main.py @@ -12,7 +12,7 @@ async def start( ctx: Context, ) -> None: app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) + auth = await ctx.request_resource(Auth) # type: ignore nbconvert = _Nbconvert(app, auth) ctx.add_resource(nbconvert, types=Nbconvert) diff --git a/plugins/noauth/fps_noauth/backends.py b/plugins/noauth/fps_noauth/backends.py index b522b9ce..d042eb57 100644 --- a/plugins/noauth/fps_noauth/backends.py +++ b/plugins/noauth/fps_noauth/backends.py @@ -1,18 +1,12 @@ from typing import Any, Dict, List, Optional, Tuple from fastapi import WebSocket -from jupyverse_api.auth import Auth - -from .models import User +from jupyverse_api.auth import Auth, User USER = User() class _NoAuth(Auth): - @property - def User(self): - return User - def current_user(self, *args, **kwargs): async def _(): return USER diff --git a/plugins/resource_usage/fps_resource_usage/main.py b/plugins/resource_usage/fps_resource_usage/main.py index 3c77b0ba..8eafa1bd 100644 --- a/plugins/resource_usage/fps_resource_usage/main.py +++ b/plugins/resource_usage/fps_resource_usage/main.py @@ -15,7 +15,7 @@ async def start( ctx: Context, ) -> None: app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) + auth = await ctx.request_resource(Auth) # type: ignore resource_usage = _ResourceUsage(app, auth, self.resource_usage_config) ctx.add_resource(resource_usage, types=ResourceUsage) diff --git a/plugins/retrolab/fps_retrolab/main.py b/plugins/retrolab/fps_retrolab/main.py index dbbc07f1..661751b1 100644 --- a/plugins/retrolab/fps_retrolab/main.py +++ b/plugins/retrolab/fps_retrolab/main.py @@ -14,9 +14,9 @@ async def start( ctx: Context, ) -> None: app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) + auth = await ctx.request_resource(Auth) # type: ignore frontend_config = await ctx.request_resource(FrontendConfig) - lab = await ctx.request_resource(Lab) + lab = await ctx.request_resource(Lab) # type: ignore retrolab = _RetroLab(app, auth, frontend_config, lab) ctx.add_resource(retrolab, types=RetroLab) diff --git a/plugins/terminals/fps_terminals/main.py b/plugins/terminals/fps_terminals/main.py index ae047739..1ec474aa 100644 --- a/plugins/terminals/fps_terminals/main.py +++ b/plugins/terminals/fps_terminals/main.py @@ -21,7 +21,7 @@ async def start( ctx: Context, ) -> None: app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) + auth = await ctx.request_resource(Auth) # type: ignore terminals = _Terminals(app, auth, _TerminalServer) ctx.add_resource(terminals, types=Terminals) diff --git a/plugins/yjs/fps_yjs/main.py b/plugins/yjs/fps_yjs/main.py index cfd9aa70..7cbd9685 100644 --- a/plugins/yjs/fps_yjs/main.py +++ b/plugins/yjs/fps_yjs/main.py @@ -1,5 +1,6 @@ from __future__ import annotations from collections.abc import AsyncGenerator +from typing import Optional from asphalt.core import Component, Context, context_teardown from jupyverse_api.app import App @@ -15,10 +16,10 @@ class YjsComponent(Component): async def start( self, ctx: Context, - ) -> AsyncGenerator[None, BaseException | None]: + ) -> AsyncGenerator[None, Optional[BaseException]]: app = await ctx.request_resource(App) - auth = await ctx.request_resource(Auth) - contents = await ctx.request_resource(Contents) + auth = await ctx.request_resource(Auth) # type: ignore + contents = await ctx.request_resource(Contents) # type: ignore yjs = _Yjs(app, auth, contents) ctx.add_resource(yjs, types=Yjs)