Skip to content

Commit

Permalink
allow get_user to be async
Browse files Browse the repository at this point in the history
careful to deprecate overridden get_current_user without ignoring auth
  • Loading branch information
minrk committed Apr 22, 2022
1 parent 282b4f1 commit a99947a
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 14 deletions.
2 changes: 2 additions & 0 deletions jupyter_server/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class IdentityProvider(LoggingConfigurable):
"""
Interface for providing identity
_may_ be a coroutine.
Two principle methods:
- :meth:`~.IdentityProvider.get_user` returns a :class:`~.User` object
Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/auth/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def is_token_authenticated(cls, handler):
"""
if getattr(handler, "_user_id", None) is None:
# ensure get_user has been called, so we know if we're token-authenticated
handler.get_current_user()
handler.current_user
return getattr(handler, "_token_authenticated", False)

@classmethod
Expand Down
55 changes: 44 additions & 11 deletions jupyter_server/base/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Distributed under the terms of the Modified BSD License.
import datetime
import functools
import inspect
import ipaddress
import json
import mimetypes
Expand Down Expand Up @@ -134,7 +135,21 @@ def clear_login_cookie(self):
self.force_clear_cookie(self.cookie_name)

def get_current_user(self):
return self.identity_provider.get_user(self)
clsname = self.__class__.__name__
msg = (
f"Calling `{clsname}.get_current_user()` directly is deprecated in jupyter-server 2.0."
" Use `self.current_user` instead (works in all versions)."
)
if hasattr(self, "_current_jupyter_user"):
# backward-compat: return _current_jupyter_user
warnings.warn(
msg,
DeprecationWarning,
stacklevel=2,
)
return self._current_jupyter_user
# haven't called get_user in prepare, raise
raise RuntimeError(msg)

def skip_check_origin(self):
"""Ask my login_handler if I should skip the origin_check
Expand Down Expand Up @@ -164,7 +179,7 @@ def cookie_name(self):
@property
def logged_in(self):
"""Is a user currently logged in?"""
user = self.get_current_user()
user = self.current_user
return user and not user == "anonymous"

@property
Expand Down Expand Up @@ -543,9 +558,35 @@ def check_host(self):
)
return allow

def prepare(self):
async def prepare(self):
if not self.check_host():
raise web.HTTPError(403)

from jupyter_server.auth import IdentityProvider

if (
type(self.identity_provider) is IdentityProvider
and inspect.getmodule(self.get_current_user).name != "__name__"
):
# check for overridden get_current_user + default IdentityProvider
# deprecated way to override auth (e.g. JupyterHub < 3.0)
# allow deprecated, overridden get_current_user
warnings.warn(
"Overriding JupyterHandler.get_current_user is deprecated in jupyter-server 2.0."
" Use an IdentityProvider class.",
DeprecationWarning,
# stacklevel not useful here
)
user = self.get_current_user()
else:
user = self.identity_provider.get_user(self)
if inspect.isawaitable(user):
# IdentityProvider.get_user _may_ be async
user = await user

# self.current_user for tornado's @web.authenticated
# self._jupyter_user for backward-compat in deprecated get_current_user calls
self.current_user = self._jupyter_user = user
return super().prepare()

# ---------------------------------------------------------------
Expand Down Expand Up @@ -663,14 +704,6 @@ def write_error(self, status_code, **kwargs):
self.log.warning(reply["message"])
self.finish(json.dumps(reply))

def get_current_user(self):
"""Raise 403 on API handlers instead of redirecting to human login page"""
# preserve _user_cache so we don't raise more than once
if hasattr(self, "_user_cache"):
return self._user_cache
self._user_cache = user = super().get_current_user()
return user

def get_login_url(self):
# if get_login_url is invoked in an API handler,
# that means @web.authenticated is trying to trigger a redirect.
Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/base/zmqhandlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def pre_get(self):
the websocket finishes completing.
"""
# authenticate the request before opening the websocket
user = self.get_current_user()
user = self.current_user
if user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise web.HTTPError(403)
Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/gateway/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def authenticate(self):
the websocket finishes completing.
"""
# authenticate the request before opening the websocket
if self.get_current_user() is None:
if self.current_user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise web.HTTPError(403)

Expand Down

0 comments on commit a99947a

Please sign in to comment.