Initially impl. of typed computations
madsbk committed Jun 17, 2021
1 parent 42d631d commit 4c1cd29
Showing 7 changed files with 342 additions and 163 deletions.
247 changes: 247 additions & 0 deletions distributed/protocol/computation.py
@@ -0,0 +1,247 @@
"""
This module implements graph computations based on the specification in Dask[1]:
> A computation may be one of the following:
> - Any key present in the Dask graph like `'x'`
> - Any other value like `1`, to be interpreted literally
> - A task like `(inc, 'x')`
> - A list of computations, like `[1, 'x', (inc, 'x')]`
In order to support efficient and flexible task serialization, this module introduces
classes for computations, tasks, data, functions, etc.
Notable Classes
---------------
- `PickledObject` - An object that is serialized using `protocol.pickle`.
This object isn't a computation by itself; instead, users can build computations
containing it. It is automatically de-serialized by the Worker before execution.
- `Computation` - A computation that the Worker can execute. The Scheduler sees
this as a black box.
- `Data(Computation)` - De-serialized data that does **not** contain any tasks or
`PickledObject`.
- `Task(Computation)` - A de-serialized task ready for execution, possibly nested.
- `PickledTask(Task)` - A task serialized using `protocol.pickle`, which can contain
`PickledObject` instances.
Notable Functions
-----------------
- `typeset_dask_graph()` - Used to typeset a Dask graph, wrapping each computation in
either the `Data` or `Task` class. This should be done before communicating the graph.
Note that this replaces the old `tlz.valmap(dumps_task, dsk)` operation.
[1] <https://docs.dask.org/en/latest/spec.html>
"""

import threading
import warnings
from typing import Any, Callable, Dict, Iterable, Mapping, MutableMapping, Tuple

import tlz

from dask.core import istask
from dask.utils import apply, format_bytes

from ..utils import LRU
from . import pickle


def identity(x, *args_ignored):
return x


def execute_task(task, *args_ignored):
"""Evaluate a nested task
>>> inc = lambda x: x + 1
>>> execute_task((inc, 1))
2
>>> execute_task((sum, [1, 2, (inc, 3)]))
7
"""
if istask(task):
func, args = task[0], task[1:]
return func(*map(execute_task, args))
elif isinstance(task, list):
return list(map(execute_task, task))
else:
return task


class PickledObject:
_value: bytes

def __init__(self, value: bytes):
self._value = value

def __reduce__(self):
return (type(self), (self._value,))

@property
def value(self):
return self._value

@classmethod
def serialize(cls, obj) -> "PickledObject":
return cls(pickle.dumps(obj))

def deserialize(self):
return pickle.loads(self._value)


class PickledCallable(PickledObject):
cache_dumps: MutableMapping[Callable, bytes] = LRU(maxsize=100)
cache_loads: MutableMapping[bytes, Callable] = LRU(maxsize=100)
cache_max_sized_obj = 1_000_000
cache_dumps_lock = threading.Lock()

@classmethod
def dumps_function(cls, func: Callable) -> bytes:
"""Dump a function to bytes, cache functions"""

try:
with cls.cache_dumps_lock:
ret = cls.cache_dumps[func]
except KeyError:
ret = pickle.dumps(func)
if len(ret) <= cls.cache_max_sized_obj:
with cls.cache_dumps_lock:
cls.cache_dumps[func] = ret
except TypeError: # Unhashable function
ret = pickle.dumps(func)
return ret

@classmethod
def loads_function(cls, dumped_func: bytes):
"""Load a function from bytes, cache bytes"""
if len(dumped_func) > cls.cache_max_sized_obj:
return pickle.loads(dumped_func)

try:
ret = cls.cache_loads[dumped_func]
except KeyError:
cls.cache_loads[dumped_func] = ret = pickle.loads(dumped_func)
return ret

@classmethod
def serialize(cls, func: Callable) -> "PickledCallable":
if isinstance(func, cls):
return func
else:
return cls(cls.dumps_function(func))

def deserialize(self) -> Callable:
return self.loads_function(self._value)

def __call__(self, *args, **kwargs):
return self.deserialize()(*args, **kwargs)
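# Illustrative usage (not part of the commit): serializing a callable caches the
# pickled bytes, and the resulting PickledCallable can be called directly.
#
# >>> pc = PickledCallable.serialize(len)
# >>> pc(b"abc")
# 3
# >>> PickledCallable.serialize(pc) is pc  # already wrapped, returned unchanged
# True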


class Computation:
def __init__(self, value):
self._value = value

@property
def value(self):
return self._value

def get_computation(self) -> Tuple[Callable, Iterable, Mapping]:
assert False, self._value


class Data(Computation):
def get_computation(self) -> Tuple[Callable, Iterable, Mapping]:
return (identity, (self._value,), {})


class Task(Computation):
def get_computation(self) -> Tuple[Callable, Iterable, Mapping]:
return (execute_task, (self._value,), {})


class PickledTask(Task):
_size_warning_triggered: bool = False
_size_warning_limit: int = 1_000_000

def __init__(self, serialized_data: bytes):
self._value: bytes = serialized_data

@classmethod
def serialize(cls, value) -> "PickledTask":
data = pickle.dumps(value)
ret = cls(data)
if not cls._size_warning_triggered and len(data) > cls._size_warning_limit:
cls._size_warning_triggered = True
s = str(value)
if len(s) > 70:
s = s[:50] + " ... " + s[-15:]
warnings.warn(
"Large object of size %s detected in task graph: \n"
" %s\n"
"Consider scattering large objects ahead of time\n"
"with client.scatter to reduce scheduler burden and \n"
"keep data on workers\n\n"
" future = client.submit(func, big_data) # bad\n\n"
" big_future = client.scatter(big_data) # good\n"
" future = client.submit(func, big_future) # good"
% (format_bytes(len(data)), s)
)
return ret

def deserialize(self):
def inner_deserialize(obj):
if isinstance(obj, list):
return [inner_deserialize(o) for o in obj]
elif istask(obj):
return tuple(inner_deserialize(o) for o in obj)
elif isinstance(obj, PickledObject):
return obj.deserialize()
else:
return obj

return inner_deserialize(pickle.loads(self._value))

def get_task(self):
return Task(self.deserialize())

def get_computation(self) -> Tuple[Callable, Iterable, Mapping]:
return self.get_task().get_computation()


def typeset_computation(computation) -> Computation:
from .serialize import Serialize, Serialized

if isinstance(computation, Computation):
return computation # Already a computation

contain_tasks = [False]

def serialize_callables(obj):
if isinstance(obj, list):
return [serialize_callables(o) for o in obj]
elif istask(obj):
contain_tasks[0] = True
if obj[0] is apply:
return (apply, PickledCallable.serialize(obj[1])) + tuple(
map(serialize_callables, obj[2:])
)
else:
return (PickledCallable.serialize(obj[0]),) + tuple(
map(serialize_callables, obj[1:])
)
else:
assert not isinstance(obj, (Serialize, Serialized))
return obj

computation = serialize_callables(computation)
if contain_tasks[0]:
return PickledTask.serialize(computation)
else:
return Data(Serialize(computation))


def typeset_dask_graph(dsk: Mapping[str, Any]) -> Dict[str, Computation]:
return tlz.valmap(typeset_computation, dsk)
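
A small end-to-end sketch of the new module (illustrative only, not part of the
diff; `inc` is a stand-in function): literals are wrapped as `Data`, tasks as
`PickledTask`, and a worker-side consumer runs the `(function, args, kwargs)`
triple returned by `get_computation()`.

from distributed.protocol.computation import Data, PickledTask, typeset_dask_graph

def inc(x):
    return x + 1

dsk = {"x": 1, "y": (inc, "x")}
typed = typeset_dask_graph(dsk)

# Literals become Data (wrapping Serialize); tasks become PickledTask with their
# callables wrapped in PickledCallable before the whole tuple is pickled.
assert isinstance(typed["x"], Data)
assert isinstance(typed["y"], PickledTask)

# A task without key dependencies can be run directly from its computation triple;
# tasks like "y" first need the worker to substitute real values for keys ("x").
function, args, kwargs = typeset_dask_graph({"w": (inc, 41)})["w"].get_computation()
assert function(*args, **kwargs) == 42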
31 changes: 24 additions & 7 deletions distributed/protocol/serialize.py
@@ -12,6 +12,7 @@
from ..utils import ensure_bytes, has_keyword, typename
from . import pickle
from .compression import decompress, maybe_compress
from .computation import Data, PickledCallable, PickledObject, PickledTask
from .utils import frame_split_size, msgpack_opts, pack_frames_prelude, unpack_frames

lazy_registrations = {}
@@ -118,13 +119,17 @@ def msgpack_decode_default(obj):
if "__Set__" in obj:
return set(obj["as-list"])

if "__Serialized__" in obj:
# Notice, the data here is marked a Serialized rather than deserialized. This
# is because deserialization requires Pickle which the Scheduler cannot run
# because of security reasons.
# By marking it Serialized, the data is passed through to the workers that
# eventually will deserialize it.
return Serialized(*obj["data"])
if "__PickledTask__" in obj:
return PickledTask(obj["value"])

if "__PickledCallable__" in obj:
return PickledCallable(obj["value"])

if "__PickledObject__" in obj:
return PickledObject(obj["value"])

if "__Data__" in obj:
return Data(obj["value"])

return obj

@@ -148,6 +153,18 @@ def msgpack_encode_default(obj):
if isinstance(obj, set):
return {"__Set__": True, "as-list": list(obj)}

if isinstance(obj, PickledTask):
return {"__PickledTask__": True, "value": obj.value}

if isinstance(obj, PickledCallable):
return {"__PickledCallable__": True, "value": obj.value}

if isinstance(obj, PickledObject):
return {"__PickledObject__": True, "value": obj.value}

if isinstance(obj, Data):
return {"__Data__": True, "value": obj.value}

return obj
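
For context, an illustrative round-trip through the new hooks (not part of the
diff): the typed wrappers are encoded as small dicts tagged with a marker key and
reconstructed on the other side.

from distributed.protocol.computation import PickledCallable
from distributed.protocol.serialize import msgpack_decode_default, msgpack_encode_default

pc = PickledCallable.serialize(len)
encoded = msgpack_encode_default(pc)       # {"__PickledCallable__": True, "value": <pickled bytes>}
decoded = msgpack_decode_default(encoded)  # rebuilds a PickledCallable from the bytes
assert decoded(b"abc") == 3                # calling it deserializes and invokes len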


12 changes: 5 additions & 7 deletions distributed/recreate_tasks.py
@@ -2,10 +2,11 @@

from dask.utils import stringify

from distributed.protocol.computation import PickledTask

from .client import futures_of, wait
from .utils import sync
from .utils_comm import pack_data
from .worker import _deserialize

logger = logging.getLogger(__name__)

@@ -83,12 +84,9 @@ async def _get_raw_components_from_future(self, future):
key = future.key
spec = await self.scheduler.get_runspec(key=key)
deps, task = spec["deps"], spec["task"]
if isinstance(task, dict):
function, args, kwargs = _deserialize(**task)
return (function, args, kwargs, deps)
else:
function, args, kwargs = _deserialize(task=task)
return (function, args, kwargs, deps)
assert isinstance(task, PickledTask)
function, args, kwargs = task.get_computation()
return (function, args, kwargs, deps)

async def _prepare_raw_components(self, raw_components):
"""
15 changes: 7 additions & 8 deletions distributed/scheduler.py
@@ -58,6 +58,7 @@
from .multi_lock import MultiLockExtension
from .node import ServerNode
from .proctitle import setproctitle
from .protocol.computation import Computation
from .publish import PublishExtension
from .pubsub import PubSubSchedulerExtension
from .queues import QueueExtension
@@ -393,7 +394,7 @@ class WorkerState:
.. attribute:: processing: {TaskState: cost}
A dictionary of tasks that have been submitted to this worker.
Each task state is asssociated with the expected cost in seconds
Each task state is associated with the expected cost in seconds
of running that task, summing both the task's expected computation
time and the expected communication time of its result.
@@ -408,7 +409,7 @@ class WorkerState:
.. attribute:: executing: {TaskState: duration}
A dictionary of tasks that are currently being run on this worker.
Each task state is asssociated with the duration in seconds which
Each task state is associated with the duration in seconds which
the task has been running.
.. attribute:: has_what: {TaskState}
@@ -7379,11 +7380,15 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->
if duration < 0:
duration = state.get_task_duration(ts)

run_spec = ts._run_spec
assert run_spec is None or isinstance(run_spec, Computation)

msg: dict = {
"op": "compute-task",
"key": ts._key,
"priority": ts._priority,
"duration": duration,
"runspec": run_spec,
}
if ts._resource_restrictions:
msg["resource_restrictions"] = ts._resource_restrictions
@@ -7400,12 +7405,6 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->
if state._validate:
assert all(msg["who_has"].values())

task = ts._run_spec
if type(task) is dict:
msg.update(task)
else:
msg["task"] = task

if ts._annotations:
msg["annotations"] = ts._annotations
return msg
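
Illustratively (not part of the diff; field values are made up), the task message
now always carries the run spec under a single "runspec" field as a Computation
(or None), instead of merging a dict task into the message or sending it as "task":

from distributed.protocol.computation import PickledTask

def inc(x):
    return x + 1

msg = {
    "op": "compute-task",
    "key": "y",
    "priority": (0, 1, 0),   # made-up priority value
    "duration": 0.5,         # made-up expected duration in seconds
    "runspec": PickledTask.serialize((inc, "x")),
}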