From 4c1cd29c3525b2f6883bfbc52a03fd4982f45eee Mon Sep 17 00:00:00 2001
From: "Mads R. B. Kristensen" 
Date: Thu, 17 Jun 2021 14:56:46 +0200
Subject: [PATCH] Initial impl. of typed computations

---
 distributed/protocol/computation.py | 247 ++++++++++++++++++++++++++++
 distributed/protocol/serialize.py   |  31 +++-
 distributed/recreate_tasks.py       |  12 +-
 distributed/scheduler.py            |  15 +-
 distributed/tests/test_client.py    |  29 ++--
 distributed/tests/test_scheduler.py |  17 --
 distributed/worker.py               | 154 +++++------------
 7 files changed, 342 insertions(+), 163 deletions(-)
 create mode 100644 distributed/protocol/computation.py

diff --git a/distributed/protocol/computation.py b/distributed/protocol/computation.py
new file mode 100644
index 00000000000..ea3b7c566fd
--- /dev/null
+++ b/distributed/protocol/computation.py
@@ -0,0 +1,247 @@
+"""
+This module implements graph computations based on the specification in Dask[1]:
+> A computation may be one of the following:
+> - Any key present in the Dask graph like `'x'`
+> - Any other value like `1`, to be interpreted literally
+> - A task like `(inc, 'x')`
+> - A list of computations, like `[1, 'x', (inc, 'x')]`
+
+In order to support efficient and flexible task serialization, this module introduces
+classes for computations, tasks, data, functions, etc.
+
+Notable Classes
+---------------
+
+- `PickledObject` - An object that is serialized using `protocol.pickle`.
+  This object isn't a computation by itself; instead, users can build computations
+  containing it. It is automatically de-serialized by the Worker before execution.
+
+- `Computation` - A computation that the Worker can execute. The Scheduler sees
+  this as a black box.
+
+  - `Data(Computation)` - De-serialized data that does **not** contain any tasks or
+    `PickledObject`.
+
+  - `Task(Computation)` - A de-serialized task ready for execution, possibly nested.
+
+    - `PickledTask(Task)` - A task serialized using `protocol.pickle`, which can
+      contain `PickledObject`.
+
+Notable Functions
+-----------------
+
+- `typeset_dask_graph()` - Used to typeset a Dask graph, wrapping each computation in
+  either the `Data` or `Task` class. This should be done before communicating the
+  graph. Note that this replaces the old `tlz.valmap(dumps_task, dsk)` operation.
+
+[1]
+"""
+
+import threading
+import warnings
+from typing import Any, Callable, Dict, Iterable, Mapping, MutableMapping, Tuple
+
+import tlz
+
+from dask.core import istask
+from dask.utils import apply, format_bytes
+
+from ..utils import LRU
+from . 
import pickle + + +def identity(x, *args_ignored): + return x + + +def execute_task(task, *args_ignored): + """Evaluate a nested task + + >>> inc = lambda x: x + 1 + >>> execute_task((inc, 1)) + 2 + >>> execute_task((sum, [1, 2, (inc, 3)])) + 7 + """ + if istask(task): + func, args = task[0], task[1:] + return func(*map(execute_task, args)) + elif isinstance(task, list): + return list(map(execute_task, task)) + else: + return task + + +class PickledObject: + _value: bytes + + def __init__(self, value: bytes): + self._value = value + + def __reduce__(self): + return (type(self), (self._value,)) + + @property + def value(self): + return self._value + + @classmethod + def serialize(cls, obj) -> "PickledObject": + return cls(pickle.dumps(obj)) + + def deserialize(self): + return pickle.loads(self._value) + + +class PickledCallable(PickledObject): + cache_dumps: MutableMapping[int, bytes] = LRU(maxsize=100) + cache_loads: MutableMapping[int, Callable] = LRU(maxsize=100) + cache_max_sized_obj = 1_000_000 + cache_dumps_lock = threading.Lock() + + @classmethod + def dumps_function(cls, func: Callable) -> bytes: + """Dump a function to bytes, cache functions""" + + try: + with cls.cache_dumps_lock: + ret = cls.cache_dumps[func] + except KeyError: + ret = pickle.dumps(func) + if len(ret) <= cls.cache_max_sized_obj: + with cls.cache_dumps_lock: + cls.cache_dumps[func] = ret + except TypeError: # Unhashable function + ret = pickle.dumps(func) + return ret + + @classmethod + def loads_function(cls, dumped_func: bytes): + """Load a function from bytes, cache bytes""" + if len(dumped_func) > cls.cache_max_sized_obj: + return pickle.loads(dumped_func) + + try: + ret = cls.cache_loads[dumped_func] + except KeyError: + cls.cache_loads[dumped_func] = ret = pickle.loads(dumped_func) + return ret + + @classmethod + def serialize(cls, func: Callable) -> "PickledCallable": + if isinstance(func, cls): + return func + else: + return cls(cls.dumps_function(func)) + + def deserialize(self) -> Callable: + return self.loads_function(self._value) + + def __call__(self, *args, **kwargs): + return self.deserialize()(*args, **kwargs) + + +class Computation: + def __init__(self, value): + self._value = value + + @property + def value(self): + return self._value + + def get_computation(self) -> Tuple[Callable, Iterable, Mapping]: + assert False, self._value + + +class Data(Computation): + def get_computation(self) -> Tuple[Callable, Iterable, Mapping]: + return (identity, (self._value,), {}) + + +class Task(Computation): + def get_computation(self) -> Tuple[Callable, Iterable, Mapping]: + return (execute_task, (self._value,), {}) + + +class PickledTask(Task): + _size_warning_triggered: bool = False + _size_warning_limit: int = 1_000_000 + + def __init__(self, serialized_data: bytes): + self._value: bytes = serialized_data + + @classmethod + def serialize(cls, value) -> "PickledTask": + data = pickle.dumps(value) + ret = cls(data) + if not cls._size_warning_triggered and len(data) > cls._size_warning_limit: + cls._size_warning_triggered = True + s = str(value) + if len(s) > 70: + s = s[:50] + " ... 
" + s[-15:] + warnings.warn( + "Large object of size %s detected in task graph: \n" + " %s\n" + "Consider scattering large objects ahead of time\n" + "with client.scatter to reduce scheduler burden and \n" + "keep data on workers\n\n" + " future = client.submit(func, big_data) # bad\n\n" + " big_future = client.scatter(big_data) # good\n" + " future = client.submit(func, big_future) # good" + % (format_bytes(len(data)), s) + ) + return ret + + def deserialize(self): + def inner_deserialize(obj): + if isinstance(obj, list): + return [inner_deserialize(o) for o in obj] + elif istask(obj): + return tuple(inner_deserialize(o) for o in obj) + elif isinstance(obj, PickledObject): + return obj.deserialize() + else: + return obj + + return inner_deserialize(pickle.loads(self._value)) + + def get_task(self): + return Task(self.deserialize()) + + def get_computation(self) -> Tuple[Callable, Iterable, Mapping]: + return self.get_task().get_computation() + + +def typeset_computation(computation) -> Computation: + from .serialize import Serialize, Serialized + + if isinstance(computation, Computation): + return computation # Already a computation + + contain_tasks = [False] + + def serialize_callables(obj): + if isinstance(obj, list): + return [serialize_callables(o) for o in obj] + elif istask(obj): + contain_tasks[0] = True + if obj[0] is apply: + return (apply, PickledCallable.serialize(obj[1])) + tuple( + map(serialize_callables, obj[2:]) + ) + else: + return (PickledCallable.serialize(obj[0]),) + tuple( + map(serialize_callables, obj[1:]) + ) + else: + assert not isinstance(obj, (Serialize, Serialized)) + return obj + + computation = serialize_callables(computation) + if contain_tasks[0]: + return PickledTask.serialize(computation) + else: + return Data(Serialize(computation)) + + +def typeset_dask_graph(dsk: Mapping[str, Any]) -> Dict[str, Computation]: + return tlz.valmap(typeset_computation, dsk) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index c88d36a8995..f33ed4c5e06 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -12,6 +12,7 @@ from ..utils import ensure_bytes, has_keyword, typename from . import pickle from .compression import decompress, maybe_compress +from .computation import Data, PickledCallable, PickledObject, PickledTask from .utils import frame_split_size, msgpack_opts, pack_frames_prelude, unpack_frames lazy_registrations = {} @@ -118,13 +119,17 @@ def msgpack_decode_default(obj): if "__Set__" in obj: return set(obj["as-list"]) - if "__Serialized__" in obj: - # Notice, the data here is marked a Serialized rather than deserialized. This - # is because deserialization requires Pickle which the Scheduler cannot run - # because of security reasons. - # By marking it Serialized, the data is passed through to the workers that - # eventually will deserialize it. 
- return Serialized(*obj["data"]) + if "__PickledTask__" in obj: + return PickledTask(obj["value"]) + + if "__PickledCallable__" in obj: + return PickledCallable(obj["value"]) + + if "__PickledObject__" in obj: + return PickledObject(obj["value"]) + + if "__Data__" in obj: + return Data(obj["value"]) return obj @@ -148,6 +153,18 @@ def msgpack_encode_default(obj): if isinstance(obj, set): return {"__Set__": True, "as-list": list(obj)} + if isinstance(obj, PickledTask): + return {"__PickledTask__": True, "value": obj.value} + + if isinstance(obj, PickledCallable): + return {"__PickledCallable__": True, "value": obj.value} + + if isinstance(obj, PickledObject): + return {"__PickledObject__": True, "value": obj.value} + + if isinstance(obj, Data): + return {"__Data__": True, "value": obj.value} + return obj diff --git a/distributed/recreate_tasks.py b/distributed/recreate_tasks.py index ec596bc4614..d1e8758151e 100644 --- a/distributed/recreate_tasks.py +++ b/distributed/recreate_tasks.py @@ -2,10 +2,11 @@ from dask.utils import stringify +from distributed.protocol.computation import PickledTask + from .client import futures_of, wait from .utils import sync from .utils_comm import pack_data -from .worker import _deserialize logger = logging.getLogger(__name__) @@ -83,12 +84,9 @@ async def _get_raw_components_from_future(self, future): key = future.key spec = await self.scheduler.get_runspec(key=key) deps, task = spec["deps"], spec["task"] - if isinstance(task, dict): - function, args, kwargs = _deserialize(**task) - return (function, args, kwargs, deps) - else: - function, args, kwargs = _deserialize(task=task) - return (function, args, kwargs, deps) + assert isinstance(task, PickledTask) + function, args, kwargs = task.get_computation() + return (function, args, kwargs, deps) async def _prepare_raw_components(self, raw_components): """ diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a8571e8b627..f3e8a44ee27 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -58,6 +58,7 @@ from .multi_lock import MultiLockExtension from .node import ServerNode from .proctitle import setproctitle +from .protocol.computation import Computation from .publish import PublishExtension from .pubsub import PubSubSchedulerExtension from .queues import QueueExtension @@ -393,7 +394,7 @@ class WorkerState: .. attribute:: processing: {TaskState: cost} A dictionary of tasks that have been submitted to this worker. - Each task state is asssociated with the expected cost in seconds + Each task state is associated with the expected cost in seconds of running that task, summing both the task's expected computation time and the expected communication time of its result. @@ -408,7 +409,7 @@ class WorkerState: .. attribute:: executing: {TaskState: duration} A dictionary of tasks that are currently being run on this worker. - Each task state is asssociated with the duration in seconds which + Each task state is associated with the duration in seconds which the task has been running. .. 
attribute:: has_what: {TaskState} @@ -7379,11 +7380,15 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> if duration < 0: duration = state.get_task_duration(ts) + run_spec = ts._run_spec + assert run_spec is None or isinstance(run_spec, Computation) + msg: dict = { "op": "compute-task", "key": ts._key, "priority": ts._priority, "duration": duration, + "runspec": run_spec, } if ts._resource_restrictions: msg["resource_restrictions"] = ts._resource_restrictions @@ -7400,12 +7405,6 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> if state._validate: assert all(msg["who_has"].values()) - task = ts._run_spec - if type(task) is dict: - msg.update(task) - else: - msg["task"] = task - if ts._annotations: msg["annotations"] = ts._annotations return msg diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index ec47479e157..f908aba2701 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -66,7 +66,14 @@ Scheduler, ) from distributed.sizeof import sizeof -from distributed.utils import is_valid_xml, mp_context, sync, tmp_text, tmpfile +from distributed.utils import ( + import_term, + is_valid_xml, + mp_context, + sync, + tmp_text, + tmpfile, +) from distributed.utils_test import ( TaskStateMetadataPlugin, async_wait_for, @@ -1560,12 +1567,12 @@ def g(): @gen_cluster(client=True) async def test_upload_file_refresh_delayed(c, s, a, b): with save_sys_modules(): - for value in [123, 456]: - with tmp_text("myfile.py", "def f():\n return {}".format(value)) as fn: + for i, value in enumerate([123, 456]): + with tmp_text(f"myfile{i}.py", f"def f():\n return {value}") as fn: await c.upload_file(fn) sys.path.append(os.path.dirname(fn)) - from myfile import f + f = import_term(f"myfile{i}.f") b = delayed(f)() bb = c.compute(b, sync=False) @@ -4693,8 +4700,7 @@ async def test_recreate_error_delayed(c, s, a, b): error_f = await c._get_errored_future(f) function, args, kwargs = await c._get_components_from_future(error_f) assert f.status == "error" - assert function.__name__ == "div" - assert args == (1, 0) + assert args == ((div, 1, 0),) with pytest.raises(ZeroDivisionError): function(*args, **kwargs) @@ -4713,8 +4719,7 @@ async def test_recreate_error_futures(c, s, a, b): error_f = await c._get_errored_future(f) function, args, kwargs = await c._get_components_from_future(error_f) assert f.status == "error" - assert function.__name__ == "div" - assert args == (1, 0) + assert args == ((div, 1, 0),) with pytest.raises(ZeroDivisionError): function(*args, **kwargs) @@ -4801,8 +4806,7 @@ async def test_recreate_task_delayed(c, s, a, b): function, args, kwargs = await c._get_components_from_future(f) assert f.status == "finished" - assert function.__name__ == "sum" - assert args == ([1, 1],) + assert args == ((sum, [1, 1]),) assert function(*args, **kwargs) == 2 @@ -4819,8 +4823,7 @@ async def test_recreate_task_futures(c, s, a, b): function, args, kwargs = await c._get_components_from_future(f) assert f.status == "finished" - assert function.__name__ == "sum" - assert args == ([1, 1],) + assert args == ((sum, [1, 1]),) assert function(*args, **kwargs) == 2 @@ -5619,7 +5622,7 @@ async def test_warn_when_submitting_large_values(c, s, a, b): assert "2.00 MB" in text or "1.91 MiB" in text assert "large" in text assert "..." in text - assert "'000" in text + assert "... 
000" in text assert "000'" in text assert len(text) < 2000 diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index f1aeef606d2..083fac7c8f0 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -17,7 +17,6 @@ import dask from dask import delayed -from dask.utils import apply from distributed import Client, Nanny, Worker, fire_and_forget, wait from distributed.comm import Comm @@ -477,22 +476,6 @@ def test_dumps_function(): assert a != c -def test_dumps_task(): - d = dumps_task((inc, 1)) - assert set(d) == {"function", "args"} - - f = lambda x, y=2: x + y - d = dumps_task((apply, f, (1,), {"y": 10})) - assert cloudpickle.loads(d["function"])(1, 2) == 3 - assert cloudpickle.loads(d["args"]) == (1,) - assert cloudpickle.loads(d["kwargs"]) == {"y": 10} - - d = dumps_task((apply, f, (1,))) - assert cloudpickle.loads(d["function"])(1, 2) == 3 - assert cloudpickle.loads(d["args"]) == (1,) - assert set(d) == {"function", "args"} - - @gen_cluster() async def test_ready_remove_worker(s, a, b): s.update_graph( diff --git a/distributed/worker.py b/distributed/worker.py index 9cbb7d42979..2826c2a9e42 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -10,7 +10,7 @@ import threading import warnings import weakref -from collections import defaultdict, deque, namedtuple +from collections import defaultdict, deque from collections.abc import MutableMapping from contextlib import suppress from datetime import timedelta @@ -25,7 +25,7 @@ import dask from dask.core import istask from dask.system import CPU_COUNT -from dask.utils import apply, format_bytes, funcname +from dask.utils import format_bytes, funcname from . import comm, preloading, profile, system, utils from .batched import BatchedSend @@ -47,7 +47,9 @@ from .metrics import time from .node import ServerNode from .proctitle import setproctitle -from .protocol import pickle, to_serialize +from .protocol import pickle +from .protocol.computation import Computation, PickledTask, Task +from .protocol.serialize import to_serialize from .pubsub import PubSubWorkerExtension from .security import Security from .sizeof import safe_sizeof as sizeof @@ -56,7 +58,6 @@ from .utils import ( LRU, TimeoutError, - _maybe_complex, deprecated, get_ip, has_arg, @@ -100,8 +101,6 @@ dask.config.get("distributed.scheduler.default-data-size") ) -SerializedTask = namedtuple("SerializedTask", ["function", "args", "kwargs", "task"]) - class TaskState: """Holds volatile state relating to an individual Dask task @@ -155,15 +154,14 @@ class TaskState: Parameters ---------- key: str - runspec: SerializedTask - A named tuple containing the ``function``, ``args``, ``kwargs`` and - ``task`` associated with this `TaskState` instance. This defaults to + runspec: Computation + This defaults to ``None`` and can remain empty if it is a dependency that this worker will receive from another worker. 
""" - def __init__(self, key, runspec=None): + def __init__(self, key, runspec: Computation = None): assert key is not None self.key = key self.runspec = runspec @@ -416,7 +414,7 @@ def __init__( lifetime_restart=None, **kwargs, ): - self.tasks = dict() + self.tasks: Dict[str, TaskState] = dict() self.waiting_for_data_count = 0 self.has_what = defaultdict(set) self.pending_data_per_worker = defaultdict(deque) @@ -1555,10 +1553,7 @@ async def set_resources(self, **resources): def add_task( self, key, - function=None, - args=None, - kwargs=None, - task=no_value, + runspec: Computation = None, who_has=None, nbytes=None, priority=None, @@ -1569,7 +1564,6 @@ def add_task( **kwargs2, ): try: - runspec = SerializedTask(function, args, kwargs, task) if key in self.tasks: ts = self.tasks[key] ts.scheduler_holds_ref = True @@ -1592,9 +1586,7 @@ def add_task( self.transition(ts, "waiting", runspec=runspec) else: self.log.append((key, "new")) - self.tasks[key] = ts = TaskState( - key=key, runspec=SerializedTask(function, args, kwargs, task) - ) + self.tasks[key] = ts = TaskState(key=key, runspec=runspec) self.transition(ts, "waiting") # TODO: move transition of `ts` to end of `add_task` # This will require a chained recommendation transition system like @@ -2870,31 +2862,32 @@ def meets_resource_constraints(self, key): return True - async def _maybe_deserialize_task(self, ts): - if not isinstance(ts.runspec, SerializedTask): - return ts.runspec - try: - start = time() - # Offload deserializing large tasks - if sizeof(ts.runspec) > OFFLOAD_THRESHOLD: - function, args, kwargs = await offload(_deserialize, *ts.runspec) - else: - function, args, kwargs = _deserialize(*ts.runspec) - stop = time() + async def _maybe_deserialize_task(self, ts: TaskState) -> Computation: + if isinstance(ts.runspec, PickledTask): + try: + start = time() + # Offload deserializing large tasks + if sizeof(ts.runspec) > OFFLOAD_THRESHOLD: + runspec = await offload(_deserialize, ts.runspec) + else: + runspec = _deserialize(ts.runspec) + stop = time() - if stop - start > 0.010: - ts.startstops.append( - {"action": "deserialize", "start": start, "stop": stop} - ) - return function, args, kwargs - except Exception as e: - logger.warning("Could not deserialize task", exc_info=True) - emsg = error_message(e) - emsg["key"] = ts.key - emsg["op"] = "task-erred" - self.batched_stream.send(emsg) - self.log.append((ts.key, "deserialize-error")) - raise + if stop - start > 0.010: + ts.startstops.append( + {"action": "deserialize", "start": start, "stop": stop} + ) + return runspec + except Exception as e: + logger.warning("Could not deserialize task", exc_info=True) + emsg = error_message(e) + emsg["key"] = ts.key + emsg["op"] = "task-erred" + self.batched_stream.send(emsg) + self.log.append((ts.key, "deserialize-error")) + raise + else: + return ts.runspec async def ensure_computing(self): if self.paused: @@ -2962,7 +2955,7 @@ async def execute(self, key, report=False): assert not ts.waiting_for_data assert ts.state == "executing" - function, args, kwargs = ts.runspec + function, args, kwargs = ts.runspec.get_computation() start = time() data = {} @@ -3772,21 +3765,9 @@ def loads_function(bytes_object): return pickle.loads(bytes_object) -def _deserialize(function=None, args=None, kwargs=None, task=no_value): - """Deserialize task inputs and regularize to func, args, kwargs""" - if function is not None: - function = loads_function(function) - if args and isinstance(args, bytes): - args = pickle.loads(args) - if kwargs and isinstance(kwargs, 
bytes): - kwargs = pickle.loads(kwargs) - - if task is not no_value: - assert not function and not args and not kwargs - function = execute_task - args = (task,) - - return function, args or (), kwargs or {} +def _deserialize(runspec: PickledTask) -> Task: + """Deserialize computation""" + return runspec.get_task() def execute_task(task): @@ -3828,59 +3809,10 @@ def dumps_function(func): def dumps_task(task): - """Serialize a dask task + from .protocol.computation import typeset_computation - Returns a dict of bytestrings that can each be loaded with ``loads`` - - Examples - -------- - Either returns a task as a function, args, kwargs dict - - >>> from operator import add - >>> dumps_task((add, 1)) # doctest: +SKIP - {'function': b'\x80\x04\x95\x00\x8c\t_operator\x94\x8c\x03add\x94\x93\x94.' - 'args': b'\x80\x04\x95\x07\x00\x00\x00K\x01K\x02\x86\x94.'} - - Or as a single task blob if it can't easily decompose the result. This - happens either if the task is highly nested, or if it isn't a task at all - - >>> dumps_task(1) # doctest: +SKIP - {'task': b'\x80\x04\x95\x03\x00\x00\x00\x00\x00\x00\x00K\x01.'} - """ - if istask(task): - if task[0] is apply and not any(map(_maybe_complex, task[2:])): - d = {"function": dumps_function(task[1]), "args": warn_dumps(task[2])} - if len(task) == 4: - d["kwargs"] = warn_dumps(task[3]) - return d - elif not any(map(_maybe_complex, task[1:])): - return {"function": dumps_function(task[0]), "args": warn_dumps(task[1:])} - return to_serialize(task) - - -_warn_dumps_warned = [False] - - -def warn_dumps(obj, dumps=pickle.dumps, limit=1e6): - """Dump an object to bytes, warn if those bytes are large""" - b = dumps(obj, protocol=4) - if not _warn_dumps_warned[0] and len(b) > limit: - _warn_dumps_warned[0] = True - s = str(obj) - if len(s) > 70: - s = s[:50] + " ... " + s[-15:] - warnings.warn( - "Large object of size %s detected in task graph: \n" - " %s\n" - "Consider scattering large objects ahead of time\n" - "with client.scatter to reduce scheduler burden and \n" - "keep data on workers\n\n" - " future = client.submit(func, big_data) # bad\n\n" - " big_future = client.scatter(big_data) # good\n" - " future = client.submit(func, big_future) # good" - % (format_bytes(len(b)), s) - ) - return b + # TODO: replace all calls to dumps_task() with calls to typeset_computation() + return typeset_computation(task) def apply_function(
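
The sketch below illustrates how the API introduced by this patch is meant to be used, based only on the code added in distributed/protocol/computation.py and the worker changes above. It is not part of the patch itself; the example graph `dsk` and the variable names are invented for illustration.

# Sketch only: typeset a small graph on the client side and regularize a
# task on the worker side, mirroring the new dumps_task()/Worker.execute() path.
from operator import add

from distributed.protocol.computation import Data, PickledTask, typeset_dask_graph

dsk = {"x": 1, "y": (add, "x", 2)}   # a plain Dask graph (names invented)
typed = typeset_dask_graph(dsk)      # replaces the old tlz.valmap(dumps_task, dsk)

assert isinstance(typed["x"], Data)         # literal values are wrapped as Data
assert isinstance(typed["y"], PickledTask)  # tasks are pickled into PickledTask

# Any Computation regularizes to (function, args, kwargs); for a PickledTask
# the function is execute_task and args holds the deserialized task tuple.
function, args, kwargs = typed["y"].get_computation()
assert args == ((add, "x", 2),) and kwargs == {}
# The key "x" is still unresolved here; the worker substitutes dependency
# values (e.g. via pack_data) before calling function(*args, **kwargs).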