Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update GatewayKernelManager to derive from AsyncMappingKernelManager #5966

Merged
merged 1 commit into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 40 additions & 53 deletions notebook/gateway/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import json

from socket import gaierror
from tornado import gen, web
from tornado import web
from tornado.escape import json_encode, json_decode, url_escape
from tornado.httpclient import HTTPClient, AsyncHTTPClient, HTTPError

from ..services.kernels.kernelmanager import MappingKernelManager
from ..services.kernels.kernelmanager import AsyncMappingKernelManager
from ..services.sessions.sessionmanager import SessionManager

from jupyter_client.kernelspec import KernelSpecManager
Expand Down Expand Up @@ -303,13 +303,12 @@ def load_connection_args(self, **kwargs):
return kwargs


@gen.coroutine
def gateway_request(endpoint, **kwargs):
async def gateway_request(endpoint, **kwargs):
"""Make an async request to kernel gateway endpoint, returns a response """
client = AsyncHTTPClient()
kwargs = GatewayClient.instance().load_connection_args(**kwargs)
try:
response = yield client.fetch(endpoint, **kwargs)
response = await client.fetch(endpoint, **kwargs)
# Trap a set of common exceptions so that we can inform the user that their Gateway url is incorrect
# or the server is not running.
# NOTE: We do this here since this handler is called during the Notebook's startup and subsequent refreshes
Expand All @@ -332,10 +331,10 @@ def gateway_request(endpoint, **kwargs):
"url is valid and the Gateway instance is running.".format(GatewayClient.instance().url)
) from e

raise gen.Return(response)
return response


class GatewayKernelManager(MappingKernelManager):
class GatewayKernelManager(AsyncMappingKernelManager):
"""Kernel manager that supports remote kernels hosted by Jupyter Kernel or Enterprise Gateway."""

# We'll maintain our own set of kernel ids
Expand Down Expand Up @@ -367,8 +366,7 @@ def _get_kernel_endpoint_url(self, kernel_id=None):

return self.base_endpoint

@gen.coroutine
def start_kernel(self, kernel_id=None, path=None, **kwargs):
async def start_kernel(self, kernel_id=None, path=None, **kwargs):
"""Start a kernel for a session and return its kernel_id.

Parameters
Expand Down Expand Up @@ -403,21 +401,20 @@ def start_kernel(self, kernel_id=None, path=None, **kwargs):

json_body = json_encode({'name': kernel_name, 'env': kernel_env})

response = yield gateway_request(kernel_url, method='POST', body=json_body)
response = await gateway_request(kernel_url, method='POST', body=json_body)
kernel = json_decode(response.body)
kernel_id = kernel['id']
self.log.info("Kernel started: %s" % kernel_id)
self.log.debug("Kernel args: %r" % kwargs)
else:
kernel = yield self.get_kernel(kernel_id)
kernel = await self.get_kernel(kernel_id)
kernel_id = kernel['id']
self.log.info("Using existing kernel: %s" % kernel_id)

self._kernels[kernel_id] = kernel
raise gen.Return(kernel_id)
return kernel_id

@gen.coroutine
def get_kernel(self, kernel_id=None, **kwargs):
async def get_kernel(self, kernel_id=None, **kwargs):
"""Get kernel for kernel_id.

Parameters
Expand All @@ -428,7 +425,7 @@ def get_kernel(self, kernel_id=None, **kwargs):
kernel_url = self._get_kernel_endpoint_url(kernel_id)
self.log.debug("Request kernel at: %s" % kernel_url)
try:
response = yield gateway_request(kernel_url, method='GET')
response = await gateway_request(kernel_url, method='GET')
except web.HTTPError as error:
if error.status_code == 404:
self.log.warn("Kernel not found at: %s" % kernel_url)
Expand All @@ -440,10 +437,9 @@ def get_kernel(self, kernel_id=None, **kwargs):
kernel = json_decode(response.body)
self._kernels[kernel_id] = kernel
self.log.debug("Kernel retrieved: %s" % kernel)
raise gen.Return(kernel)
return kernel

@gen.coroutine
def kernel_model(self, kernel_id):
async def kernel_model(self, kernel_id):
"""Return a dictionary of kernel information described in the
JSON standard model.

Expand All @@ -453,21 +449,19 @@ def kernel_model(self, kernel_id):
The uuid of the kernel.
"""
self.log.debug("RemoteKernelManager.kernel_model: %s", kernel_id)
model = yield self.get_kernel(kernel_id)
raise gen.Return(model)
model = await self.get_kernel(kernel_id)
return model

@gen.coroutine
def list_kernels(self, **kwargs):
async def list_kernels(self, **kwargs):
"""Get a list of kernels."""
kernel_url = self._get_kernel_endpoint_url()
self.log.debug("Request list kernels: %s", kernel_url)
response = yield gateway_request(kernel_url, method='GET')
response = await gateway_request(kernel_url, method='GET')
kernels = json_decode(response.body)
self._kernels = {x['id']: x for x in kernels}
raise gen.Return(kernels)
return kernels

@gen.coroutine
def shutdown_kernel(self, kernel_id, now=False, restart=False):
async def shutdown_kernel(self, kernel_id, now=False, restart=False):
"""Shutdown a kernel by its kernel uuid.

Parameters
Expand All @@ -481,12 +475,11 @@ def shutdown_kernel(self, kernel_id, now=False, restart=False):
"""
kernel_url = self._get_kernel_endpoint_url(kernel_id)
self.log.debug("Request shutdown kernel at: %s", kernel_url)
response = yield gateway_request(kernel_url, method='DELETE')
response = await gateway_request(kernel_url, method='DELETE')
self.log.debug("Shutdown kernel response: %d %s", response.code, response.reason)
self.remove_kernel(kernel_id)

@gen.coroutine
def restart_kernel(self, kernel_id, now=False, **kwargs):
async def restart_kernel(self, kernel_id, now=False, **kwargs):
"""Restart a kernel by its kernel uuid.

Parameters
Expand All @@ -496,11 +489,10 @@ def restart_kernel(self, kernel_id, now=False, **kwargs):
"""
kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/restart'
self.log.debug("Request restart kernel at: %s", kernel_url)
response = yield gateway_request(kernel_url, method='POST', body=json_encode({}))
response = await gateway_request(kernel_url, method='POST', body=json_encode({}))
self.log.debug("Restart kernel response: %d %s", response.code, response.reason)

@gen.coroutine
def interrupt_kernel(self, kernel_id, **kwargs):
async def interrupt_kernel(self, kernel_id, **kwargs):
"""Interrupt a kernel by its kernel uuid.

Parameters
Expand All @@ -510,7 +502,7 @@ def interrupt_kernel(self, kernel_id, **kwargs):
"""
kernel_url = self._get_kernel_endpoint_url(kernel_id) + '/interrupt'
self.log.debug("Request interrupt kernel at: %s", kernel_url)
response = yield gateway_request(kernel_url, method='POST', body=json_encode({}))
response = await gateway_request(kernel_url, method='POST', body=json_encode({}))
self.log.debug("Interrupt kernel response: %d %s", response.code, response.reason)

def shutdown_all(self, now=False):
Expand Down Expand Up @@ -565,9 +557,8 @@ def _get_kernelspecs_endpoint_url(self, kernel_name=None):

return self.base_endpoint

@gen.coroutine
def get_all_specs(self):
fetched_kspecs = yield self.list_kernel_specs()
async def get_all_specs(self):
fetched_kspecs = await self.list_kernel_specs()

# get the default kernel name and compare to that of this server.
# If different log a warning and reset the default. However, the
Expand All @@ -583,19 +574,17 @@ def get_all_specs(self):
km.default_kernel_name = remote_default_kernel_name

remote_kspecs = fetched_kspecs.get('kernelspecs')
raise gen.Return(remote_kspecs)
return remote_kspecs

@gen.coroutine
def list_kernel_specs(self):
async def list_kernel_specs(self):
"""Get a list of kernel specs."""
kernel_spec_url = self._get_kernelspecs_endpoint_url()
self.log.debug("Request list kernel specs at: %s", kernel_spec_url)
response = yield gateway_request(kernel_spec_url, method='GET')
response = await gateway_request(kernel_spec_url, method='GET')
kernel_specs = json_decode(response.body)
raise gen.Return(kernel_specs)
return kernel_specs

@gen.coroutine
def get_kernel_spec(self, kernel_name, **kwargs):
async def get_kernel_spec(self, kernel_name, **kwargs):
"""Get kernel spec for kernel_name.

Parameters
Expand All @@ -606,7 +595,7 @@ def get_kernel_spec(self, kernel_name, **kwargs):
kernel_spec_url = self._get_kernelspecs_endpoint_url(kernel_name=str(kernel_name))
self.log.debug("Request kernel spec at: %s" % kernel_spec_url)
try:
response = yield gateway_request(kernel_spec_url, method='GET')
response = await gateway_request(kernel_spec_url, method='GET')
except web.HTTPError as error:
if error.status_code == 404:
# Convert not found to KeyError since that's what the Notebook handler expects
Expand All @@ -620,10 +609,9 @@ def get_kernel_spec(self, kernel_name, **kwargs):
else:
kernel_spec = json_decode(response.body)

raise gen.Return(kernel_spec)
return kernel_spec

@gen.coroutine
def get_kernel_spec_resource(self, kernel_name, path):
async def get_kernel_spec_resource(self, kernel_name, path):
"""Get kernel spec for kernel_name.

Parameters
Expand All @@ -636,22 +624,21 @@ def get_kernel_spec_resource(self, kernel_name, path):
kernel_spec_resource_url = url_path_join(self.base_resource_endpoint, str(kernel_name), str(path))
self.log.debug("Request kernel spec resource '{}' at: {}".format(path, kernel_spec_resource_url))
try:
response = yield gateway_request(kernel_spec_resource_url, method='GET')
response = await gateway_request(kernel_spec_resource_url, method='GET')
except web.HTTPError as error:
if error.status_code == 404:
kernel_spec_resource = None
else:
raise
else:
kernel_spec_resource = response.body
raise gen.Return(kernel_spec_resource)
return kernel_spec_resource


class GatewaySessionManager(SessionManager):
kernel_manager = Instance('notebook.gateway.managers.GatewayKernelManager')

@gen.coroutine
def kernel_culled(self, kernel_id):
async def kernel_culled(self, kernel_id):
"""Checks if the kernel is still considered alive and returns true if its not found. """
kernel = yield self.kernel_manager.get_kernel(kernel_id)
raise gen.Return(kernel is None)
kernel = await self.kernel_manager.get_kernel(kernel_id)
return kernel is None
35 changes: 17 additions & 18 deletions notebook/tests/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def generate_model(name):
return model


@gen.coroutine
def mock_gateway_request(url, **kwargs):
async def mock_gateway_request(url, **kwargs):
method = 'GET'
if kwargs['method']:
method = kwargs['method']
Expand All @@ -51,17 +50,17 @@ def mock_gateway_request(url, **kwargs):
# Fetch all kernelspecs
if endpoint.endswith('/api/kernelspecs') and method == 'GET':
response_buf = StringIO(json.dumps(kernelspecs))
response = yield maybe_future(HTTPResponse(request, 200, buffer=response_buf))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 200, buffer=response_buf))
return response

# Fetch named kernelspec
if endpoint.rfind('/api/kernelspecs/') >= 0 and method == 'GET':
requested_kernelspec = endpoint.rpartition('/')[2]
kspecs = kernelspecs.get('kernelspecs')
if requested_kernelspec in kspecs:
response_buf = StringIO(json.dumps(kspecs.get(requested_kernelspec)))
response = yield maybe_future(HTTPResponse(request, 200, buffer=response_buf))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 200, buffer=response_buf))
return response
else:
raise HTTPError(404, message='Kernelspec does not exist: %s' % requested_kernelspec)

Expand All @@ -75,8 +74,8 @@ def mock_gateway_request(url, **kwargs):
model = generate_model(name)
running_kernels[model.get('id')] = model # Register model as a running kernel
response_buf = StringIO(json.dumps(model))
response = yield maybe_future(HTTPResponse(request, 201, buffer=response_buf))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 201, buffer=response_buf))
return response

# Fetch list of running kernels
if endpoint.endswith('/api/kernels') and method == 'GET':
Expand All @@ -85,24 +84,24 @@ def mock_gateway_request(url, **kwargs):
model = running_kernels.get(kernel_id)
kernels.append(model)
response_buf = StringIO(json.dumps(kernels))
response = yield maybe_future(HTTPResponse(request, 200, buffer=response_buf))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 200, buffer=response_buf))
return response

# Interrupt or restart existing kernel
if endpoint.rfind('/api/kernels/') >= 0 and method == 'POST':
requested_kernel_id, sep, action = endpoint.rpartition('/api/kernels/')[2].rpartition('/')

if action == 'interrupt':
if requested_kernel_id in running_kernels:
response = yield maybe_future(HTTPResponse(request, 204))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 204))
return response
else:
raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id)
elif action == 'restart':
if requested_kernel_id in running_kernels:
response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id)))
response = yield maybe_future(HTTPResponse(request, 204, buffer=response_buf))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 204, buffer=response_buf))
return response
else:
raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id)
else:
Expand All @@ -112,16 +111,16 @@ def mock_gateway_request(url, **kwargs):
if endpoint.rfind('/api/kernels/') >= 0 and method == 'DELETE':
requested_kernel_id = endpoint.rpartition('/')[2]
running_kernels.pop(requested_kernel_id) # Simulate shutdown by removing kernel from running set
response = yield maybe_future(HTTPResponse(request, 204))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 204))
return response

# Fetch existing kernel
if endpoint.rfind('/api/kernels/') >= 0 and method == 'GET':
requested_kernel_id = endpoint.rpartition('/')[2]
if requested_kernel_id in running_kernels:
response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id)))
response = yield maybe_future(HTTPResponse(request, 200, buffer=response_buf))
raise gen.Return(response)
response = await maybe_future(HTTPResponse(request, 200, buffer=response_buf))
return response
else:
raise HTTPError(404, message='Kernel does not exist: %s' % requested_kernel_id)

Expand Down