Skip to content

Commit

Permalink
clean up the classes
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Jun 18, 2021
1 parent ad005d6 commit 4a3ab81
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 91 deletions.
90 changes: 51 additions & 39 deletions distributed/protocol/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
containing them. It is automatically de-serialized by the Worker before execution.
- `Computation` - A computation that the Worker can execute. The Scheduler sees
this as a black box.
this as a black box. A computation **cannot** contain pickled objects but it may
contain `Serialize` and/or `Serialized` objects, which will automatically be
de-serialized when arriving on the Worker.
- `Data(Computation)` - De-serialized data that does **not** contain any tasks or
`PickledObject`.
- `Task(Computation)` - A de-serialized task ready for execution, possible nested.
- `PickledTask(Task)` - A task serialized using `protocol.pickle` and can contain
`PickledObject`.
- `PickledComputation` - A computation that is serialized using `protocol.pickle`.
The class is derived from `Computation` but can contain pickled objects. It and
any pickled objects it contains will be de-serialized by the Worker before execution.
Notable Functions
-----------------
Expand Down Expand Up @@ -80,9 +79,15 @@ def __init__(self, value: bytes):
def __reduce__(self):
return (type(self), (self._value,))

@property
def value(self):
return self._value
@classmethod
def msgpack_decode(cls, state: Mapping):
return cls(state["value"])

def msgpack_encode(self) -> dict:
return {
f"__{type(self).__name__}__": True,
"value": self._value,
}

@classmethod
def serialize(cls, obj) -> "PickledObject":
Expand Down Expand Up @@ -141,38 +146,39 @@ def __call__(self, *args, **kwargs):


class Computation:
def __init__(self, value):
def __init__(self, value, is_a_task: bool):
self._value = value
self._is_a_task = is_a_task

@property
def value(self):
return self._value

def get_computation(self) -> Tuple[Callable, Iterable, Mapping]:
assert False, self._value


class Data(Computation):
def get_computation(self) -> Tuple[Callable, Iterable, Mapping]:
return (identity, (self._value,), {})

@classmethod
def msgpack_decode(cls, state: Mapping):
return cls(state["value"], state["is_a_task"])

def msgpack_encode(self) -> dict:
return {
f"__{type(self).__name__}__": True,
"value": self._value,
"is_a_task": self._is_a_task,
}

def get_func_and_args(self) -> Tuple[Callable, Iterable, Mapping]:
if self._is_a_task:
return (execute_task, (self._value,), {})
else:
return (identity, (self._value,), {})

class Task(Computation):
def get_computation(self) -> Tuple[Callable, Iterable, Mapping]:
return (execute_task, (self._value,), {})
def get_computation(self) -> "Computation":
return self


class PickledTask(Task):
class PickledComputation(Computation):
_size_warning_triggered: bool = False
_size_warning_limit: int = 1_000_000

def __init__(self, serialized_data: bytes):
self._value: bytes = serialized_data

@classmethod
def serialize(cls, value) -> "PickledTask":
def serialize(cls, value, is_a_task: bool):
data = pickle.dumps(value)
ret = cls(data)
ret = cls(data, is_a_task)
if not cls._size_warning_triggered and len(data) > cls._size_warning_limit:
cls._size_warning_triggered = True
s = str(value)
Expand Down Expand Up @@ -204,11 +210,11 @@ def inner_deserialize(obj):

return inner_deserialize(pickle.loads(self._value))

def get_task(self):
return Task(self.deserialize())
def get_computation(self) -> Computation:
return Computation(self.deserialize(), self._is_a_task)

def get_computation(self) -> Tuple[Callable, Iterable, Mapping]:
return self.get_task().get_computation()
def get_func_and_args(self) -> Tuple[Callable, Iterable, Mapping]:
return self.get_computation().get_func_and_args()


def typeset_computation(computation) -> Computation:
Expand All @@ -217,6 +223,7 @@ def typeset_computation(computation) -> Computation:
if isinstance(computation, Computation):
return computation # Already a computation

contain_pickled = [False]
contain_tasks = [False]

def serialize_callables(obj):
Expand All @@ -232,15 +239,20 @@ def serialize_callables(obj):
return (PickledCallable.serialize(obj[0]),) + tuple(
map(serialize_callables, obj[1:])
)
elif isinstance(obj, PickledObject):
contain_pickled[0] = True
return obj
else:
assert not isinstance(obj, (Serialize, Serialized))
assert not isinstance(obj, (Serialize, Serialized)), obj
return obj

computation = serialize_callables(computation)
if contain_tasks[0]:
return PickledTask.serialize(computation)
return PickledComputation.serialize(computation, is_a_task=True)
elif contain_pickled[0]:
return PickledComputation.serialize(computation, is_a_task=False)
else:
return Data(Serialize(computation))
return Computation(Serialize(computation), is_a_task=False)


def typeset_dask_graph(dsk: Mapping[str, Any]) -> Dict[str, Computation]:
Expand Down
68 changes: 24 additions & 44 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..utils import ensure_bytes, has_keyword, typename
from . import pickle
from .compression import decompress, maybe_compress
from .computation import Data, PickledCallable, PickledObject, PickledTask
from .computation import Computation, PickledCallable, PickledComputation, PickledObject
from .utils import frame_split_size, msgpack_opts, pack_frames_prelude, unpack_frames

lazy_registrations = {}
Expand Down Expand Up @@ -119,17 +119,17 @@ def msgpack_decode_default(obj):
if "__Set__" in obj:
return set(obj["as-list"])

if "__PickledTask__" in obj:
return PickledTask(obj["value"])

if "__PickledCallable__" in obj:
return PickledCallable(obj["value"])
return PickledCallable.msgpack_decode(obj)

if "__PickledObject__" in obj:
return PickledObject(obj["value"])
return PickledObject.msgpack_decode(obj)

if "__PickledComputation__" in obj:
return PickledComputation.msgpack_decode(obj)

if "__Data__" in obj:
return Data(obj["value"])
if "__Computation__" in obj:
return Computation.msgpack_decode(obj)

return obj

Expand All @@ -153,17 +153,8 @@ def msgpack_encode_default(obj):
if isinstance(obj, set):
return {"__Set__": True, "as-list": list(obj)}

if isinstance(obj, PickledTask):
return {"__PickledTask__": True, "value": obj.value}

if isinstance(obj, PickledCallable):
return {"__PickledCallable__": True, "value": obj.value}

if isinstance(obj, PickledObject):
return {"__PickledObject__": True, "value": obj.value}

if isinstance(obj, Data):
return {"__Data__": True, "value": obj.value}
if isinstance(obj, (PickledObject, Computation)):
return obj.msgpack_encode()

return obj

Expand Down Expand Up @@ -556,32 +547,21 @@ def nested_deserialize(x):
{'op': 'update', 'data': 123}
"""

def replace_inner(x):
if type(x) is dict:
x = x.copy()
for k, v in x.items():
typ = type(v)
if typ is dict or typ is list:
x[k] = replace_inner(v)
elif typ is Serialize:
x[k] = v.data
elif typ is Serialized:
x[k] = deserialize(v.header, v.frames)

elif type(x) is list:
x = list(x)
for k, v in enumerate(x):
typ = type(v)
if typ is dict or typ is list:
x[k] = replace_inner(v)
elif typ is Serialize:
x[k] = v.data
elif typ is Serialized:
x[k] = deserialize(v.header, v.frames)

typ = type(x)
if typ is dict:
return {k: nested_deserialize(v) for k, v in x.items()}
elif typ is list or typ is tuple:
return typ(nested_deserialize(o) for o in x)
elif typ is Serialize:
return x.data
elif typ is Serialized:
return deserialize(x.header, x.frames)
elif isinstance(x, Computation):
x = x.get_computation()
x._value = nested_deserialize(x._value)
return x
else:
return x

return replace_inner(x)


def serialize_bytelist(x, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions distributed/recreate_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dask.utils import stringify

from distributed.protocol.computation import PickledTask
from distributed.protocol.computation import PickledComputation

from .client import futures_of, wait
from .utils import sync
Expand Down Expand Up @@ -84,8 +84,8 @@ async def _get_raw_components_from_future(self, future):
key = future.key
spec = await self.scheduler.get_runspec(key=key)
deps, task = spec["deps"], spec["task"]
assert isinstance(task, PickledTask)
function, args, kwargs = task.get_computation()
assert isinstance(task, PickledComputation)
function, args, kwargs = task.get_func_and_args()
return (function, args, kwargs, deps)

async def _prepare_raw_components(self, raw_components):
Expand Down
11 changes: 6 additions & 5 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from .node import ServerNode
from .proctitle import setproctitle
from .protocol import pickle
from .protocol.computation import Computation, PickledTask, Task
from .protocol.computation import Computation, PickledComputation
from .protocol.serialize import to_serialize
from .pubsub import PubSubWorkerExtension
from .security import Security
Expand Down Expand Up @@ -2863,7 +2863,7 @@ def meets_resource_constraints(self, key):
return True

async def _maybe_deserialize_task(self, ts: TaskState) -> Computation:
if isinstance(ts.runspec, PickledTask):
if isinstance(ts.runspec, PickledComputation):
try:
start = time()
# Offload deserializing large tasks
Expand All @@ -2887,6 +2887,7 @@ async def _maybe_deserialize_task(self, ts: TaskState) -> Computation:
self.log.append((ts.key, "deserialize-error"))
raise
else:
assert isinstance(ts.runspec, Computation), ts.runspec
return ts.runspec

async def ensure_computing(self):
Expand Down Expand Up @@ -2955,7 +2956,7 @@ async def execute(self, key, report=False):
assert not ts.waiting_for_data
assert ts.state == "executing"

function, args, kwargs = ts.runspec.get_computation()
function, args, kwargs = ts.runspec.get_func_and_args()

start = time()
data = {}
Expand Down Expand Up @@ -3765,9 +3766,9 @@ def loads_function(bytes_object):
return pickle.loads(bytes_object)


def _deserialize(runspec: PickledTask) -> Task:
def _deserialize(runspec: PickledComputation) -> Computation:
"""Deserialize computation"""
return runspec.get_task()
return runspec.get_computation()


def execute_task(task):
Expand Down

0 comments on commit 4a3ab81

Please sign in to comment.