Skip to content

Commit

Permalink
Reconcile connection information (#879)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevin-bates authored Nov 15, 2022
1 parent 9f1c379 commit a21dd92
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 25 deletions.
86 changes: 63 additions & 23 deletions jupyter_client/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,21 +157,9 @@ def write_connection_file(
cfg["signature_scheme"] = signature_scheme
cfg["kernel_name"] = kernel_name

# Prevent over-writing a file that has already been written with the same
# info. This is to prevent a race condition where the process has
# already been launched but has not yet read the connection file.
if os.path.exists(fname):
with open(fname) as f:
try:
data = json.load(f)
if data == cfg:
return fname, cfg
except Exception:
pass

# Only ever write this file as user read/writeable
# This would otherwise introduce a vulnerability as a file has secrets
# which would let others execute arbitrarily code as you
# which would let others execute arbitrary code as you
with secure_write(fname) as f:
f.write(json.dumps(cfg, indent=2))

Expand Down Expand Up @@ -580,18 +568,70 @@ def load_connection_info(self, info: KernelConnectionInfo) -> None:
if "signature_scheme" in info:
self.session.signature_scheme = info["signature_scheme"]

def _force_connection_info(self, info: KernelConnectionInfo) -> None:
"""Unconditionally loads connection info from a dict containing connection info.
def _reconcile_connection_info(self, info: KernelConnectionInfo) -> None:
"""Reconciles the connection information returned from the Provisioner.
Overwrites connection info-based attributes, regardless of their current values
and writes this information to the connection file.
Because some provisioners (like derivations of LocalProvisioner) may have already
written the connection file, this method needs to ensure that, if the connection
file exists, its contents match what was returned by the provisioner. If the
file exists and its contents do not match, a ValueError is raised.
If the file does not exist, the connection information in 'info' is loaded into the
KernelManager and written to the file.
"""
# Reset current ports to 0 and indicate file has not been written to enable override
self._connection_file_written = False
for name in port_names:
setattr(self, name, 0)
self.load_connection_info(info)
self.write_connection_file()
# Prevent over-writing a file that has already been written with the same
# info. This is to prevent a race condition where the process has
# already been launched but has not yet read the connection file - as is
# the case with LocalProvisioners.
file_exists: bool = False
if os.path.exists(self.connection_file):
with open(self.connection_file) as f:
file_info = json.load(f)
# Prior to the following comparison, we need to adjust the value of "key" to
# be bytes, otherwise the comparison below will fail.
file_info["key"] = file_info["key"].encode()
if not self._equal_connections(info, file_info):
raise ValueError(
"Connection file already exists and does not match "
"the expected values returned from provisioner!"
)
file_exists = True

if not file_exists:
# Load the connection info and write out file. Note, this does not necessarily
# overwrite non-zero port values, so we'll validate afterward.
self.load_connection_info(info)
self.write_connection_file()

# Ensure what is in KernelManager is what we expect. This will catch issues if the file
# already existed, yet its contents differed from the KernelManager's (and provisioner).
km_info = self.get_connection_info()
if not self._equal_connections(info, km_info):
raise ValueError(
"KernelManager's connection information already exists and does not match "
"the expected values returned from provisioner!"
)

@staticmethod
def _equal_connections(conn1: KernelConnectionInfo, conn2: KernelConnectionInfo) -> bool:
    """Report whether two connection-info mappings agree on every pertinent field.

    Only the signing key, the ip, the transport, the signature scheme and the
    five port entries are compared; any other entries are ignored. Values are
    read with ``.get``, so a field absent from both mappings counts as a match.

    Returns True when no compared field differs, False otherwise.
    """
    compared_fields = (
        "key",
        "ip",
        "stdin_port",
        "iopub_port",
        "shell_port",
        "control_port",
        "hb_port",
        "transport",
        "signature_scheme",
    )

    # Stop at the first differing field, exactly like an early-return loop.
    return not any(conn1.get(field) != conn2.get(field) for field in compared_fields)

# --------------------------------------------------------------------------
# Creating connected sockets
Expand Down
5 changes: 3 additions & 2 deletions jupyter_client/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,9 @@ async def _async_launch_kernel(self, kernel_cmd: t.List[str], **kw: t.Any) -> No
assert self.provisioner is not None
connection_info = await self.provisioner.launch_kernel(kernel_cmd, **kw)
assert self.provisioner.has_process
# Provisioner provides the connection information. Load into kernel manager and write file.
self._force_connection_info(connection_info)
# Provisioner provides the connection information. Load into kernel manager
# and write the connection file, if not already done.
self._reconcile_connection_info(connection_info)

_launch_kernel = run_sync(_async_launch_kernel)

Expand Down
66 changes: 66 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import os
from tempfile import TemporaryDirectory

import pytest
from jupyter_core.application import JupyterApp
from jupyter_core.paths import jupyter_runtime_dir

from jupyter_client import connect
from jupyter_client import KernelClient
from jupyter_client import KernelManager
from jupyter_client.consoleapp import JupyterConsoleApp
from jupyter_client.session import Session

Expand Down Expand Up @@ -235,3 +237,67 @@ def test_mixin_cleanup_random_ports():
assert not os.path.exists(filename)
for name in dc._random_port_names:
assert getattr(dc, name) == 0


# Each case is a (file_exists, km_matches, expected_exception) tuple, in the
# same order as the argument names given to the parametrize decorator on
# test_reconcile_connection_info.
param_values = [
    (True, True, None),
    (True, False, ValueError),
    (False, True, None),
    (False, False, ValueError),
]


@pytest.mark.parametrize("file_exists, km_matches, expected_exception", param_values)
def test_reconcile_connection_info(file_exists, km_matches, expected_exception):
    """Exercise ``_reconcile_connection_info`` across the four file/manager combinations.

    * file exists, manager matches: the LocalProvisioner case, must succeed.
    * file exists, provisioner returns different values: ValueError about the file.
    * no file, manager has no values: the non-LocalProvisioner case, must succeed.
    * no file, manager holds values that differ from the provisioner's:
      ValueError about the KernelManager's connection information.
    """
    baseline = sample_info
    # Same as the baseline except for the signing key and two of the ports.
    conflicting = {
        **sample_info,
        "key": b"def456",
        "shell_port": baseline["shell_port"] + 42,
        "control_port": baseline["control_port"] + 42,
    }

    with TemporaryDirectory() as tmp_dir:
        conn_path = os.path.join(tmp_dir, "kernel.json")
        manager = KernelManager()
        manager.connection_file = conn_path

        if file_exists:
            _, on_disk = connect.write_connection_file(conn_path, **baseline)
            on_disk["key"] = on_disk["key"].encode()  # set 'key' back to bytes

            if km_matches:
                # File and manager already agree before reconciliation.
                provisioner_info = on_disk
                manager.load_connection_info(provisioner_info)
            else:
                # Manager has no values yet, but the provisioner disagrees with the file.
                provisioner_info = conflicting
        elif km_matches:
            # Neither a file nor manager values exist; reconciliation sets them.
            provisioner_info = baseline
        else:
            # No file, but the manager already holds values the provisioner contradicts.
            manager.load_connection_info(baseline)
            provisioner_info = conflicting

        if expected_exception is not None:
            with pytest.raises(expected_exception) as ex:
                manager._reconcile_connection_info(provisioner_info)

            if file_exists:
                assert "Connection file already exists" in str(ex.value)
            else:
                assert "KernelManager's connection information already exists" in str(ex.value)
            assert (
                manager._equal_connections(manager.get_connection_info(), provisioner_info)
                is False
            )
        else:
            manager._reconcile_connection_info(provisioner_info)
            km_info = manager.get_connection_info()
            assert manager._equal_connections(km_info, provisioner_info)

0 comments on commit a21dd92

Please sign in to comment.