Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
trungleduc committed Mar 20, 2023
1 parent 1ba842f commit 709b44a
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 29 deletions.
154 changes: 125 additions & 29 deletions voila/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import threading
import webbrowser

from .voila_identity_provider import VoilaLoginHandler

try:
from urllib.parse import urljoin
from urllib.request import pathname2url
Expand All @@ -35,12 +37,29 @@
from jupyter_core.paths import jupyter_config_path, jupyter_path
from jupyter_server.base.handlers import FileFindHandler, path_regex
from jupyter_server.config_manager import recursive_update
from jupyter_server.services.config import ConfigManager
from jupyter_server.services.config.manager import ConfigManager
from jupyter_server.services.contents.largefilemanager import LargeFileManager
from jupyter_server.services.kernels.handlers import KernelHandler, ZMQChannelsHandler
from jupyter_server.utils import run_sync, url_path_join
from jupyter_server.services.kernels.handlers import KernelHandler
from jupyter_server.services.kernels.websocket import KernelWebsocketHandler
from jupyter_server.auth.authorizer import AllowAllAuthorizer, Authorizer
from jupyter_server.auth.identity import PasswordIdentityProvider
from jupyter_server import DEFAULT_TEMPLATE_PATH_LIST, DEFAULT_STATIC_FILES_PATH
from jupyter_server.services.kernels.connection.base import (
BaseKernelWebsocketConnection,
)
from jupyter_server.services.kernels.connection.channels import (
ZMQChannelsWebsocketConnection,
)
from jupyter_server.auth.identity import (
IdentityProvider,
)
from jupyter_server.utils import url_path_join
from jupyter_core.utils import run_sync

from jupyterlab_server.themes_handler import ThemesHandler
from traitlets import Bool, Callable, Dict, Integer, List, Unicode, default


from traitlets import Bool, Callable, Dict, Integer, List, Unicode, default, Type, Bytes
from traitlets.config.application import Application
from traitlets.config.loader import Config

Expand Down Expand Up @@ -272,6 +291,43 @@ def hook(req: tornado.web.RequestHandler,
),
)

cookie_secret = Bytes(
b"",
config=True,
help="""The random bytes used to secure cookies.
By default this is a new random number every time you start the server.
Set it to a value in a config file to enable logins to persist across server sessions.
Note: Cookie secrets should be kept private, do not share config files with
cookie_secret stored in plaintext (you can read the value from a file).
""",
)

@default("cookie_secret")
def _default_cookie_secret(self):
return os.urandom(32)

authorizer_class = Type(
default_value=AllowAllAuthorizer,
klass=Authorizer,
config=True,
help=_("The authorizer class to use."),
)

identity_provider_class = Type(
default_value=PasswordIdentityProvider,
klass=IdentityProvider,
config=True,
help=_("The identity provider class to use."),
)

kernel_websocket_connection_class = Type(
default_value=ZMQChannelsWebsocketConnection,
klass=BaseKernelWebsocketConnection,
config=True,
help=_("The kernel websocket connection class to use."),
)

@property
def display_url(self):
if self.custom_display_url:
Expand All @@ -282,13 +338,17 @@ def display_url(self):
ip = "%s" % socket.gethostname() if self.ip in ("", "0.0.0.0") else self.ip
url = self._url(ip)
# TODO: do we want to have the token?
# if self.token:
# # Don't log full token if it came from config
# token = self.token if self._token_generated else '...'
# url = (url_concat(url, {'token': token})
# + '\n or '
# + url_concat(self._url('127.0.0.1'), {'token': token}))
return url
if self.identity_provider.token:
# Don't log full token if it came from config
token = (
self.identity_provider.token
if self.identity_provider.token_generated
else "..."
)
query = f"?token={token}"
else:
query = ""
return f"{url}{query}"

@property
def connection_url(self):
Expand Down Expand Up @@ -405,6 +465,7 @@ def setup_template_dirs(self):
self.static_paths = collect_static_paths(
["voila", "nbconvert"], template_name
)
self.static_paths.append(DEFAULT_STATIC_FILES_PATH)
conf_paths = [os.path.join(d, "conf.json") for d in self.template_paths]
for p in conf_paths:
# see if config file exists
Expand All @@ -428,17 +489,8 @@ def setup_template_dirs(self):
if self.notebook_path and not os.path.exists(self.notebook_path):
raise ValueError("Notebook not found: %s" % self.notebook_path)

def _handle_signal_stop(self, sig, frame):
self.log.info("Handle signal %s." % sig)
self.ioloop.add_callback_from_signal(self.ioloop.stop)

def start(self):
self.connection_dir = tempfile.mkdtemp(
prefix="voila_", dir=self.connection_dir_root
)
self.log.info("Storing connection files in %s." % self.connection_dir)
self.log.info("Serving static files from %s." % self.static_root)

def init_settings(self) -> Dict:
"""Initialize settings for Voila application."""
# default server_url to base_url
self.server_url = self.server_url or self.base_url

Expand Down Expand Up @@ -486,28 +538,55 @@ def start(self):
extensions=["jinja2.ext.i18n"],
**jenv_opt,
)
server_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(DEFAULT_TEMPLATE_PATH_LIST),
extensions=["jinja2.ext.i18n"],
**jenv_opt,
)

nbui = gettext.translation(
"nbui", localedir=os.path.join(ROOT, "i18n"), fallback=True
)
env.install_gettext_translations(nbui, newstyle=False)
server_env.install_gettext_translations(nbui, newstyle=False)

identity_provider_kwargs = {
"parent": self,
"log": self.log,
"login_handler_class": VoilaLoginHandler,
}
self.identity_provider = self.identity_provider_class(
**identity_provider_kwargs
)

self.app = tornado.web.Application(
self.authorizer = self.authorizer_class(
parent=self, log=self.log, identity_provider=self.identity_provider
)

settings = dict(
base_url=self.base_url,
server_url=self.server_url or self.base_url,
kernel_manager=self.kernel_manager,
kernel_spec_manager=self.kernel_spec_manager,
allow_remote_access=True,
autoreload=self.autoreload,
voila_jinja2_env=env,
jinja2_env=env,
jinja2_env=server_env,
static_path="/",
server_root_dir="/",
contents_manager=self.contents_manager,
config_manager=self.config_manager,
cookie_secret=self.cookie_secret,
authorizer=self.authorizer,
identity_provider=self.identity_provider,
kernel_websocket_connection_class=self.kernel_websocket_connection_class,
login_url=url_path_join(self.base_url, "/login"),
)

self.app.settings.update(self.tornado_settings)
return settings

def init_handlers(self) -> List:
"""Initialize handlers for Voila application."""
handlers = []

handlers.extend(
Expand All @@ -522,7 +601,7 @@ def start(self):
url_path_join(
self.server_url, r"/api/kernels/%s/channels" % _kernel_id_regex
),
ZMQChannelsHandler,
KernelWebsocketHandler,
),
(
url_path_join(self.server_url, r"/voila/templates/(.*)"),
Expand All @@ -549,8 +628,8 @@ def start(self):
),
]
)

if preheat_kernel:
handlers.extend(self.identity_provider.get_handlers())
if self.voila_configuration.preheat_kernel:
handlers.append(
(
url_path_join(
Expand Down Expand Up @@ -621,13 +700,30 @@ def start(self):
),
]
)
return handlers

def start(self):
self.connection_dir = tempfile.mkdtemp(
prefix="voila_", dir=self.connection_dir_root
)
self.log.info("Storing connection files in %s." % self.connection_dir)
self.log.info("Serving static files from %s." % self.static_root)

settings = self.init_settings()

self.app = tornado.web.Application(**settings)
self.app.settings.update(self.tornado_settings)
handlers = self.init_handlers()
self.app.add_handlers(".*$", handlers)
self.listen()

def _handle_signal_stop(self, sig, frame):
self.log.info("Handle signal %s." % sig)
self.ioloop.add_callback_from_signal(self.ioloop.stop)

def stop(self):
shutil.rmtree(self.connection_dir)
run_sync(self.kernel_manager.shutdown_all())
run_sync(self.kernel_manager.shutdown_all)()

def random_ports(self, port, n):
"""Generate a list of n random ports near the given port.
Expand Down
21 changes: 21 additions & 0 deletions voila/voila_identity_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Any, Optional
from jupyter_server.auth.identity import IdentityProvider
from jupyter_server.auth.login import LoginFormHandler


class VoilaLoginHandler(LoginFormHandler):
def static_url(
self, path: str, include_host: Optional[bool] = None, **kwargs: Any
) -> str:
settings = {
"static_url_prefix": "voila/static/",
"static_path": None,
}
return settings.get("static_url_prefix", "/static/") + path


class VoilaIdentityProvider(IdentityProvider):
@property
def auth_enabled(self) -> bool:
"""Return whether any auth is enabled"""
return bool(self.token)

0 comments on commit 709b44a

Please sign in to comment.