Skip to content

Commit

Permalink
Merge pull request #60 from davidbrochart/get_user_db
Browse files Browse the repository at this point in the history
Use get_user_db dependency
  • Loading branch information
davidbrochart committed Sep 21, 2021
2 parents dd1a176 + 988bd04 commit c846f48
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 15 deletions.
13 changes: 7 additions & 6 deletions plugins/auth/fps_auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,20 @@ def get_user_manager(user_db=Depends(get_user_db)):


class LoginCookieAuthentication(CookieAuthentication):
async def get_login_response(self, user, response):
await super().get_login_response(user, response)
async def get_login_response(self, user, response, user_manager):
await super().get_login_response(user, response, user_manager)
# set user as logged in
user.logged_in = True
await user_db.update(user)
await user_manager.user_db.update(user)
# auto redirect
response.status_code = status.HTTP_302_FOUND
response.headers["Location"] = "/lab"

async def get_logout_response(self, user, response):
await super().get_logout_response(user, response)
async def get_logout_response(self, user, response, user_manager):
await super().get_logout_response(user, response, user_manager)
# set user as logged out
user.logged_in = False
await user_db.update(user)
await user_manager.user_db.update(user)


cookie_authentication = LoginCookieAuthentication(
Expand Down Expand Up @@ -153,6 +153,7 @@ def current_user(optional: bool = False):
async def _(
auth_config=Depends(get_auth_config),
user: User = Depends(users.current_user(optional=True)),
user_db=Depends(get_user_db),
):
if auth_config.mode == "noauth":
return await user_db.get_by_email(noauth_email)
Expand Down
2 changes: 1 addition & 1 deletion plugins/auth/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
packages=find_packages(),
install_requires=[
"fps",
"fastapi-users[sqlalchemy,oauth]==8",
"fastapi-users[sqlalchemy,oauth]>=8.1.0",
],
entry_points={
"fps_router": ["fps-auth = fps_auth.routes"],
Expand Down
9 changes: 7 additions & 2 deletions plugins/jupyterlab/fps_jupyterlab/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
from starlette.requests import Request # type: ignore
from fps.hooks import register_router # type: ignore

from fps_auth.db import get_user_db # type: ignore
from fps_auth.routes import ( # type: ignore
current_user,
user_db,
cookie_authentication,
LoginCookieAuthentication,
get_user_manager,
)
from fps_auth.models import User # type: ignore
from fps_auth.config import get_auth_config # type: ignore
Expand Down Expand Up @@ -74,13 +75,15 @@ async def get_root(
token: Optional[UUID4] = None,
auth_config=Depends(get_auth_config),
jlab_config=Depends(get_jlab_config),
user_db=Depends(get_user_db),
user_manager=Depends(get_user_manager),
):
if token and auth_config.mode == "token":
user = await user_db.get(token)
if user:
await super(
LoginCookieAuthentication, cookie_authentication
).get_login_response(user, response)
).get_login_response(user, response, user_manager)
# auto redirect
response.status_code = status.HTTP_302_FOUND
response.headers["Location"] = jlab_config.base_url + "lab"
Expand Down Expand Up @@ -138,6 +141,7 @@ async def get_workspace_data(user: User = Depends(current_user(optional=True))):
async def set_workspace(
request: Request,
user: User = Depends(current_user()),
user_db=Depends(get_user_db),
):
user.workspace = await request.body()
await user_db.update(user)
Expand Down Expand Up @@ -245,6 +249,7 @@ async def change_setting(
name0,
name1,
user: User = Depends(current_user()),
user_db=Depends(get_user_db),
):
settings = json.loads(user.settings)
settings[f"{name0}:{name1}"] = await request.json()
Expand Down
8 changes: 6 additions & 2 deletions plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from fps_auth.routes import cookie_authentication, current_user # type: ignore
from fps_auth.models import User # type: ignore
from fps_auth.db import user_db # type: ignore
from fps_auth.db import get_user_db # type: ignore
from fps_auth.config import get_auth_config # type: ignore

from .kernel_server.server import KernelServer # type: ignore
Expand Down Expand Up @@ -167,7 +167,11 @@ async def restart_kernel(

@router.websocket("/api/kernels/{kernel_id}/channels")
async def kernel_channels(
websocket: WebSocket, kernel_id, session_id, auth_config=Depends(get_auth_config)
websocket: WebSocket,
kernel_id,
session_id,
auth_config=Depends(get_auth_config),
user_db=Depends(get_user_db),
):
accept_websocket = False
if auth_config.mode == "noauth":
Expand Down
7 changes: 5 additions & 2 deletions plugins/terminals/fps_terminals/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fps_auth.routes import cookie_authentication, current_user # type: ignore
from fps_auth.models import User # type: ignore
from fps_auth.db import user_db # type: ignore
from fps_auth.db import get_user_db # type: ignore
from fps_auth.config import get_auth_config # type: ignore

from .models import Terminal
Expand Down Expand Up @@ -51,7 +51,10 @@ async def delete_terminal(

@router.websocket("/terminals/websocket/{name}")
async def terminal_websocket(
websocket: WebSocket, name, auth_config=Depends(get_auth_config)
websocket: WebSocket,
name,
auth_config=Depends(get_auth_config),
user_db=Depends(get_user_db),
):
accept_websocket = False
if auth_config.mode == "noauth":
Expand Down
8 changes: 6 additions & 2 deletions plugins/yjs/fps_yjs/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import fastapi

from fps_auth.routes import cookie_authentication # type: ignore
from fps_auth.db import user_db # type: ignore
from fps_auth.db import get_user_db # type: ignore
from fps_auth.config import get_auth_config # type: ignore

router = APIRouter()
Expand All @@ -25,7 +25,11 @@ def get_path_param_names(path: str) -> Set[str]:

@router.websocket("/api/yjs/{type}:{path:path}")
async def websocket_endpoint(
websocket: WebSocket, type, path, auth_config=Depends(get_auth_config)
websocket: WebSocket,
type,
path,
auth_config=Depends(get_auth_config),
user_db=Depends(get_user_db),
):
accept_websocket = False
if auth_config.mode == "noauth":
Expand Down

0 comments on commit c846f48

Please sign in to comment.