diff --git a/daskernetes/core.py b/daskernetes/core.py index 19c2dfd40..8ed93bbd1 100644 --- a/daskernetes/core.py +++ b/daskernetes/core.py @@ -2,10 +2,15 @@ import logging import os import socket +import copy from urllib.parse import urlparse import uuid from weakref import finalize, ref +try: + import yaml +except ImportError: + yaml = False from tornado import gen from tornado.ioloop import IOLoop @@ -13,9 +18,12 @@ from distributed.deploy import LocalCluster from kubernetes import client, config +from daskernetes.objects import make_pod_from_dict + logger = logging.getLogger(__name__) + class KubeCluster(object): """ Launch a Dask cluster on Kubernetes @@ -30,10 +38,6 @@ class KubeCluster(object): namespace: str Namespace in which to launch the workers. Defaults to current namespace if available or "default" - worker_image: str - Docker image and tag - worker_labels: dict - Additional labels to add to pod n_workers: int Number of workers on initial launch. Use ``scale_up`` in the future threads_per_worker: int @@ -55,27 +59,29 @@ class KubeCluster(object): """ def __init__( self, + worker_pod_template, name=None, namespace=None, - worker_image='daskdev/dask:latest', - worker_labels=None, n_workers=0, - threads_per_worker=1, host='0.0.0.0', - port=8786, - env={}, + port=0, **kwargs, ): self.cluster = LocalCluster(ip=host or socket.gethostname(), scheduler_port=port, n_workers=0, **kwargs) + if name is None: + name = 'dask-%s-%s' % (getpass.getuser(), str(uuid.uuid4())[:10]) + + self.name = name + try: config.load_incluster_config() except config.ConfigException: config.load_kube_config() - self.api = client.CoreV1Api() + self.core_api = client.CoreV1Api() if namespace is None: namespace = _namespace_default() @@ -85,23 +91,31 @@ def __init__( self.namespace = namespace self.name = name - self.worker_image = worker_image - self.worker_labels = (worker_labels or {}).copy() - self.threads_per_worker = threads_per_worker - self.env = dict(env) + 
        self.worker_pod_template = copy.deepcopy(worker_pod_template)

         # Default labels that can't be overwritten
-        self.worker_labels['org.pydata.dask/cluster-name'] = name
-        self.worker_labels['app'] = 'dask'
-        self.worker_labels['component'] = 'dask-worker'
+        self.worker_pod_template.metadata.labels['dask.pydata.org/cluster-name'] = name
+        self.worker_pod_template.metadata.labels['app'] = 'dask'
+        self.worker_pod_template.metadata.labels['component'] = 'dask-worker'

-        finalize(self, cleanup_pods, self.namespace, self.worker_labels)
+        finalize(self, cleanup_pods, self.namespace, self.worker_pod_template.metadata.labels)

         self._cached_widget = None

         if n_workers:
             self.scale_up(n_workers)

+    @classmethod
+    def from_dict(cls, pod_spec, **kwargs):
+        return cls(make_pod_from_dict(pod_spec), **kwargs)
+
+    @classmethod
+    def from_yaml(cls, yaml_path, **kwargs):
+        if not yaml:
+            raise ImportError("PyYaml is required to use yaml functionality, please install it!")
+        with open(yaml_path) as f:
+            return cls.from_dict(yaml.safe_load(f), **kwargs)
+
     def _widget(self):
         """ Create IPython widget for display within a notebook """
         if self._cached_widget:
@@ -142,38 +156,25 @@ def scheduler_address(self):
         return self.scheduler.address

     def _make_pod(self):
-        return client.V1Pod(
-            metadata=client.V1ObjectMeta(
-                generate_name=self.name + '-',
-                labels=self.worker_labels
-            ),
-            spec=client.V1PodSpec(
-                restart_policy='Never',
-                containers=[
-                    client.V1Container(
-                        name='dask-worker',
-                        image=self.worker_image,
-                        args=[
-                            'dask-worker',
-                            self.scheduler_address,
-                            '--nthreads', str(self.threads_per_worker),
-                        ],
-                        env=[client.V1EnvVar(name=k, value=v)
-                             for k, v in self.env.items()],
-                    )
-                ]
-            )
+        pod = copy.deepcopy(self.worker_pod_template)
+        if pod.spec.containers[0].env is None:
+            pod.spec.containers[0].env = []
+        pod.spec.containers[0].env.append(
+            client.V1EnvVar(name='DASK_SCHEDULER_ADDRESS', value=self.scheduler_address)
         )
+        pod.metadata.generate_name = self.name + '-'
+        return pod
+
     def pods(self):
-        
return self.api.list_namespaced_pod( + return self.core_api.list_namespaced_pod( self.namespace, - label_selector=format_labels(self.worker_labels) + label_selector=format_labels(self.worker_pod_template.metadata.labels) ).items def logs(self, pod): - return self.api.read_namespaced_pod_log(pod.metadata.name, - pod.metadata.namespace) + return self.core_api.read_namespaced_pod_log(pod.metadata.name, + pod.metadata.namespace) def scale_up(self, n, **kwargs): """ @@ -181,8 +182,10 @@ def scale_up(self, n, **kwargs): """ pods = self.pods() - out = [self.api.create_namespaced_pod(self.namespace, self._make_pod()) - for _ in range(n - len(pods))] + out = [ + self.core_api.create_namespaced_pod(self.namespace, self._make_pod()) + for _ in range(n - len(pods)) + ] return out # fixme: wait for this to be ready before returning! @@ -208,7 +211,7 @@ def scale_down(self, workers): return for pod in to_delete: try: - self.api.delete_namespaced_pod( + self.core_api.delete_namespaced_pod( pod.metadata.name, self.namespace, client.V1DeleteOptions() @@ -227,7 +230,7 @@ def close(self): self.cluster.close() def __exit__(self, type, value, traceback): - cleanup_pods(self.namespace, self.worker_labels) + cleanup_pods(self.namespace, self.worker_pod_template.metadata.labels) self.cluster.__exit__(type, value, traceback) def __del__(self): @@ -242,9 +245,9 @@ def adapt(self): return Adaptive(self.scheduler, self) -def cleanup_pods(namespace, worker_labels): +def cleanup_pods(namespace, labels): api = client.CoreV1Api() - pods = api.list_namespaced_pod(namespace, label_selector=format_labels(worker_labels)) + pods = api.list_namespaced_pod(namespace, label_selector=format_labels(labels)) for pod in pods.items: try: api.delete_namespaced_pod(pod.metadata.name, namespace, diff --git a/daskernetes/objects.py b/daskernetes/objects.py new file mode 100644 index 000000000..62bed69d2 --- /dev/null +++ b/daskernetes/objects.py @@ -0,0 +1,164 @@ +""" +Convenience functions for creating pod 
templates. +""" +from kubernetes import client +from collections import namedtuple +import json + +try: + import yaml +except ImportError: + yaml = False + +# FIXME: ApiClient provides us serialize / deserialize methods, +# but unfortunately also starts a threadpool for no reason! This +# takes up resources, so we try to not make too many. +SERIALIZATION_API_CLIENT = client.ApiClient() + +def _set_k8s_attribute(obj, attribute, value): + """ + Set a specific value on a kubernetes object's attribute + + obj + an object from Kubernetes Python API client + attribute + Should be a Kubernetes API style attribute (with camelCase) + value + Can be anything (string, list, dict, k8s objects) that can be + accepted by the k8s python client + """ + current_value = None + attribute_name = None + # All k8s python client objects have an 'attribute_map' property + # which has as keys python style attribute names (api_client) + # and as values the kubernetes JSON API style attribute names + # (apiClient). We want to allow users to use the JSON API style attribute + # names only. + for python_attribute, json_attribute in obj.attribute_map.items(): + if json_attribute == attribute: + attribute_name = python_attribute + break + else: + raise ValueError('Attribute must be one of {}'.format(obj.attribute_map.values())) + + if hasattr(obj, attribute_name): + current_value = getattr(obj, attribute_name) + + if current_value is not None: + # This will ensure that current_value is something JSONable, + # so a dict, list, or scalar + current_value = SERIALIZATION_API_CLIENT.sanitize_for_serialization( + current_value + ) + + if isinstance(current_value, dict): + # Deep merge our dictionaries! 
+ setattr(obj, attribute_name, merge_dictionaries(current_value, value)) + elif isinstance(current_value, list): + # Just append lists + setattr(obj, attribute_name, current_value + value) + else: + # Replace everything else + setattr(obj, attribute_name, value) + +def merge_dictionaries(a, b, path=None, update=True): + """ + Merge two dictionaries recursively. + + From https://stackoverflow.com/a/25270947 + """ + if path is None: path = [] + for key in b: + if key in a: + if isinstance(a[key], dict) and isinstance(b[key], dict): + merge_dictionaries(a[key], b[key], path + [str(key)]) + elif a[key] == b[key]: + pass # same leaf value + elif isinstance(a[key], list) and isinstance(b[key], list): + for idx, val in enumerate(b[key]): + a[key][idx] = merge_dictionaries(a[key][idx], b[key][idx], path + [str(key), str(idx)], update=update) + elif update: + a[key] = b[key] + else: + raise Exception('Conflict at %s' % '.'.join(path + [str(key)])) + else: + a[key] = b[key] + return a + +def make_pod_spec( + image, + labels={}, + threads_per_worker=1, + env={}, + extra_container_config={}, + extra_pod_config={}, + memory_limit=None, + memory_request=None, + cpu_limit=None, + cpu_request=None, +): + """ + Create a pod template from various parameters passed in. 
+ """ + pod = client.V1Pod( + metadata=client.V1ObjectMeta( + labels=labels + ), + spec=client.V1PodSpec( + restart_policy='Never', + containers=[ + client.V1Container( + name='dask-worker', + image=image, + args=[ + 'dask-worker', + '$(DASK_SCHEDULER_ADDRESS)', + '--nthreads', str(threads_per_worker), + ], + env=[client.V1EnvVar(name=k, value=v) + for k, v in env.items()], + ) + ] + ) + ) + + + resources = client.V1ResourceRequirements(limits={}, requests={}) + + if cpu_request: + resources.requests['cpu'] = cpu_request + if memory_request: + resources.requests['memory'] = memory_request + + if cpu_limit: + resources.limits['cpu'] = cpu_limit + if memory_limit: + resources.limits['memory'] = memory_limit + + pod.spec.containers[0].resources = resources + + for key, value in extra_container_config.items(): + _set_k8s_attribute( + pod.spec.containers[0], + key, + value + ) + + for key, value in extra_pod_config.items(): + _set_k8s_attribute( + pod.spec, + key, + value + ) + return pod + + +_FakeResponse = namedtuple('_FakeResponse', ['data']) + +def make_pod_from_dict(dict_): + # FIXME: We can't use the 'deserialize' function since + # that expects a response object! + return SERIALIZATION_API_CLIENT.deserialize( + _FakeResponse(data=json.dumps(dict_)), + client.V1Pod + ) diff --git a/daskernetes/tests/conftest.py b/daskernetes/tests/conftest.py new file mode 100644 index 000000000..0df968954 --- /dev/null +++ b/daskernetes/tests/conftest.py @@ -0,0 +1,15 @@ +import pytest + +def pytest_addoption(parser): + parser.addoption( + "--worker-image", + help="Worker image to use for testing" + ) + +@pytest.fixture +def image_name(request): + worker_image = request.config.getoption('--worker-image') + if not worker_image: + pytest.fail("Need to pass --worker-image. 
Image must have same python setup as host!") + return + return worker_image diff --git a/daskernetes/tests/test_core.py b/daskernetes/tests/test_core.py index 31108c1b1..29e0b2b57 100644 --- a/daskernetes/tests/test_core.py +++ b/daskernetes/tests/test_core.py @@ -1,15 +1,18 @@ import getpass import os from time import sleep, time +import yaml +import tempfile import pytest from daskernetes import KubeCluster +from daskernetes.objects import make_pod_spec from dask.distributed import Client from distributed.utils_test import loop, inc -def test_basic(loop): - with KubeCluster(loop=loop) as cluster: +def test_basic(image_name, loop): + with KubeCluster(make_pod_spec(image_name), loop=loop) as cluster: cluster.scale_up(2) with Client(cluster) as client: future = client.submit(inc, 10) @@ -26,8 +29,8 @@ def test_basic(loop): assert all(client.has_what().values()) -def test_logs(loop): - with KubeCluster(loop=loop) as cluster: +def test_logs(image_name, loop): + with KubeCluster(make_pod_spec(image_name), loop=loop) as cluster: cluster.scale_up(2) start = time() @@ -40,9 +43,9 @@ def test_logs(loop): assert 'distributed.worker' in logs -def test_ipython_display(loop): +def test_ipython_display(image_name, loop): ipywidgets = pytest.importorskip('ipywidgets') - with KubeCluster(loop=loop) as cluster: + with KubeCluster(make_pod_spec(image_name), loop=loop) as cluster: cluster.scale_up(1) cluster._ipython_display_() box = cluster._cached_widget @@ -58,16 +61,16 @@ def test_ipython_display(loop): sleep(0.5) -def test_namespace(loop): - with KubeCluster(loop=loop) as cluster: +def test_namespace(image_name, loop): + with KubeCluster(make_pod_spec(image_name), loop=loop) as cluster: assert 'dask' in cluster.name assert getpass.getuser() in cluster.name - with KubeCluster(loop=loop, port=0) as cluster2: + with KubeCluster(make_pod_spec(image_name), loop=loop) as cluster2: assert cluster.name != cluster2.name -def test_adapt(loop): - with KubeCluster(loop=loop) as cluster: 
+def test_adapt(image_name, loop): + with KubeCluster(make_pod_spec(image_name), loop=loop) as cluster: cluster.adapt() with Client(cluster) as client: future = client.submit(inc, 10) @@ -80,11 +83,60 @@ def test_adapt(loop): assert time() < start + 10 -def test_env(loop): - with KubeCluster(loop=loop, env={'ABC': 'DEF'}) as cluster: +def test_env(image_name, loop): + with KubeCluster(make_pod_spec(image_name, env={'ABC': 'DEF'}), loop=loop) as cluster: cluster.scale_up(1) with Client(cluster) as client: while not cluster.scheduler.workers: sleep(0.1) env = client.run(lambda: dict(os.environ)) assert all(v['ABC'] == 'DEF' for v in env.values()) + +def test_pod_from_yaml(image_name, loop): + test_yaml = { + "kind": "Pod", + "metadata": { + "labels": { + "app": "dask", + "component": "dask-worker" + } + }, + "spec": { + "containers": [ + { + "args": [ + "dask-worker", + "$(DASK_SCHEDULER_ADDRESS)", + "--nthreads", + "1" + ], + "image": image_name, + "name": "dask-worker" + } + ] + } + } + + with tempfile.NamedTemporaryFile('w') as f: + yaml.safe_dump(test_yaml, f) + f.flush() + cluster = KubeCluster.from_yaml( + f.name, + loop=loop, + n_workers=0, + ) + + cluster.scale_up(2) + with Client(cluster) as client: + future = client.submit(inc, 10) + result = future.result() + assert result == 11 + + while len(cluster.scheduler.workers) < 2: + sleep(0.1) + + # Ensure that inter-worker communication works well + futures = client.map(inc, range(10)) + total = client.submit(sum, futures) + assert total.result() == sum(map(inc, range(10))) + assert all(client.has_what().values()) diff --git a/daskernetes/tests/test_objects.py b/daskernetes/tests/test_objects.py new file mode 100644 index 000000000..742a5a256 --- /dev/null +++ b/daskernetes/tests/test_objects.py @@ -0,0 +1,164 @@ +import pytest +from time import sleep +import yaml +import tempfile +from daskernetes import KubeCluster +from daskernetes.objects import make_pod_spec, make_pod_from_dict +from distributed.utils_test 
import loop, inc +from dask.distributed import Client + + +def test_extra_pod_config(image_name, loop): + """ + Test that our pod config merging process works fine + """ + cluster = KubeCluster( + make_pod_spec( + image_name, + extra_pod_config={ + 'automountServiceAccountToken': False + } + ), + loop=loop, + n_workers=0, + ) + + pod = cluster._make_pod() + + assert pod.spec.automount_service_account_token == False + +def test_extra_container_config(image_name, loop): + """ + Test that our container config merging process works fine + """ + cluster = KubeCluster( + make_pod_spec( + image_name, + extra_container_config={ + 'imagePullPolicy': 'IfNotReady', + 'securityContext': { + 'runAsUser': 0 + } + } + ), + loop=loop, + n_workers=0, + ) + + pod = cluster._make_pod() + + assert pod.spec.containers[0].image_pull_policy == 'IfNotReady' + assert pod.spec.containers[0].security_context == { + 'runAsUser': 0 + } + +def test_container_resources_config(image_name, loop): + """ + Test container resource requests / limits being set properly + """ + cluster = KubeCluster( + make_pod_spec( + image_name, + memory_request="1G", + memory_limit="2G", + cpu_limit="2" + ), + loop=loop, + n_workers=0, + ) + + pod = cluster._make_pod() + + assert pod.spec.containers[0].resources.requests['memory'] == '1G' + assert pod.spec.containers[0].resources.limits['memory'] == '2G' + assert pod.spec.containers[0].resources.limits['cpu'] == '2' + assert "cpu" not in pod.spec.containers[0].resources.requests + +def test_extra_container_config_merge(image_name, loop): + """ + Test that our container config merging process works recursively fine + """ + cluster = KubeCluster( + make_pod_spec( + image_name, + extra_container_config={ + "env": [ {"name": "BOO", "value": "FOO" } ], + "args": ["last-item"] + } + ), + loop=loop, + n_workers=0, + env={"TEST": "HI"}, + ) + + pod = cluster._make_pod() + + assert pod.spec.containers[0].env == [ + { "name": "TEST", "value": "HI"}, + { "name": "BOO", "value": 
"FOO"} + ] + + assert pod.spec.containers[0].args[-1] == "last-item" + + +def test_extra_container_config_merge(image_name, loop): + """ + Test that our container config merging process works recursively fine + """ + cluster = KubeCluster( + make_pod_spec( + image_name, + env={"TEST": "HI"}, + extra_container_config={ + "env": [ {"name": "BOO", "value": "FOO" } ], + "args": ["last-item"] + } + ), + loop=loop, + n_workers=0, + ) + + pod = cluster._make_pod() + + for e in [ + { "name": "TEST", "value": "HI"}, + { "name": "BOO", "value": "FOO"} + ]: + assert e in pod.spec.containers[0].env + + assert pod.spec.containers[0].args[-1] == "last-item" + + +def test_make_pod_from_dict(): + d = { + "kind": "Pod", + "metadata": { + "labels": { + "app": "dask", + "component": "dask-worker" + } + }, + "spec": { + "containers": [ + { + "args": [ + "dask-worker", + "$(DASK_SCHEDULER_ADDRESS)", + "--nthreads", + "1" + ], + "image": "image-name", + "name": "dask-worker", + "securityContext": {"capabilities": {"add": ["SYS_ADMIN"]}, + "privileged": True}, + } + ], + "restartPolicy": "Never", + } + } + + pod = make_pod_from_dict(d) + + assert pod.spec.restart_policy == 'Never' + assert pod.spec.containers[0].security_context.privileged + assert pod.spec.containers[0].security_context.capabilities.add == ['SYS_ADMIN']