Use annotations machinery for workers/priority/... #4347

Closed
wants to merge 1 commit into from
138 changes: 19 additions & 119 deletions distributed/client.py
@@ -11,10 +11,9 @@
from functools import partial
import html
import inspect
import itertools
import json
import logging
from numbers import Number, Integral
from numbers import Number
import os
import sys
import uuid
@@ -2543,8 +2542,8 @@ def _graph_to_futures(
self,
dsk,
keys,
restrictions=None,
loose_restrictions=None,
workers=None,
allow_other_workers=None,
priority=None,
user_priority=0,
resources=None,
@@ -2553,27 +2552,9 @@
actors=None,
):
with self._refcount_lock:
if resources:
resources = self._expand_resources(
resources, all_keys=itertools.chain(dsk, keys)
)
resources = {stringify(k): v for k, v in resources.items()}

if retries:
retries = self._expand_retries(
retries, all_keys=itertools.chain(dsk, keys)
)

if actors is not None and actors is not True and actors is not False:
actors = list(self._expand_key(actors))

if restrictions:
restrictions = keymap(stringify, restrictions)
restrictions = valmap(list, restrictions)

if loose_restrictions is not None:
loose_restrictions = list(map(stringify, loose_restrictions))

keyset = set(keys)

# Make sure `dsk` is a high level graph
@@ -2582,8 +2563,19 @@

dsk = highlevelgraph_pack(dsk, self, keyset)

if isinstance(retries, Number) and retries > 0:
retries = {k: retries for k in dsk}
annotations = {}
if user_priority:
annotations["priority"] = user_priority
if workers:
if not isinstance(workers, (list, tuple, set)):
workers = [workers]
annotations["workers"] = workers
if retries:
annotations["retries"] = retries
if allow_other_workers:
annotations["allow_other_workers"] = allow_other_workers
if resources:
annotations["resources"] = resources

# Create futures before sending graph (helps avoid contention)
futures = {key: Future(key, self, inform=False) for key in keyset}
@@ -2592,14 +2584,10 @@
"op": "update-graph-hlg",
"hlg": dsk,
"keys": list(map(stringify, keys)),
"restrictions": restrictions or {},
"loose_restrictions": loose_restrictions,
"priority": priority,
"user_priority": user_priority,
"resources": resources,
"submitting_task": getattr(thread_state, "key", None),
"retries": retries,
"fifo_timeout": fifo_timeout,
"annotations": annotations,
"actors": actors,
}
)
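
For orientation, a rough sketch of the message this hunk now sends: the per-key restrictions/loose_restrictions/resources/retries fields are gone, replaced by a single flat annotations mapping carried alongside the packed high-level graph. The key names, worker address, and annotation values below are made up for illustration.

# Illustrative only: approximate shape of the "update-graph-hlg" message
# assembled above, for a hypothetical call with workers=..., retries=3,
# and user priority 10.
msg = {
    "op": "update-graph-hlg",
    "hlg": b"<packed HighLevelGraph bytes>",   # output of highlevelgraph_pack
    "keys": ["('x', 0)", "('x', 1)"],          # stringified output keys
    "priority": None,
    "submitting_task": None,
    "fifo_timeout": 0,
    "annotations": {
        "priority": 10,
        "workers": ["tcp://10.0.0.1:40000"],   # hypothetical address
        "retries": 3,
    },
    "actors": None,
}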
@@ -2849,10 +2837,6 @@ def compute(
else:
dsk2[name] = (func, keys) + extra_args

restrictions, loose_restrictions = self.get_restrictions(
collections, workers, allow_other_workers
)

if not isinstance(priority, Number):
priority = {k: p for c, p in priority.items() for k in self._expand_key(c)}

@@ -2870,8 +2854,8 @@
futures_dict = self._graph_to_futures(
dsk,
names,
restrictions,
loose_restrictions,
workers=workers,
allow_other_workers=allow_other_workers,
resources=resources,
retries=retries,
user_priority=priority,
@@ -3873,90 +3857,6 @@ def _expand_key(cls, k):
else:
yield stringify(kk)

@classmethod
def _expand_retries(cls, retries, all_keys):
"""
Expand the user-provided "retries" specification
to a {task key: Integral} dictionary.
"""
if retries and isinstance(retries, dict):
result = {
name: value
for key, value in retries.items()
for name in cls._expand_key(key)
}
elif isinstance(retries, Integral):
# Each task unit may potentially fail, allow retrying all of them
result = {name: retries for name in all_keys}
else:
raise TypeError(
"`retries` should be an integer or dict, got %r" % (type(retries))
)
return keymap(stringify, result)

def _expand_resources(cls, resources, all_keys):
"""
Expand the user-provided "resources" specification
to a {task key: {resource name: Number}} dictionary.
"""
# Resources can either be a single dict such as {'GPU': 2},
# indicating a requirement for all keys, or a nested dict
# such as {'x': {'GPU': 1}, 'y': {'SSD': 4}} indicating
# per-key requirements
if not isinstance(resources, dict):
raise TypeError("`resources` should be a dict, got %r" % (type(resources)))

per_key_reqs = {}
global_reqs = {}
all_keys = list(all_keys)
for k, v in resources.items():
if isinstance(v, dict):
# It's a per-key requirement
per_key_reqs.update((kk, v) for kk in cls._expand_key(k))
else:
# It's a global requirement
global_reqs.update((kk, {k: v}) for kk in all_keys)

if global_reqs and per_key_reqs:
raise ValueError(
"cannot have both per-key and all-key requirements "
"in resources dict %r" % (resources,)
)
return global_reqs or per_key_reqs

@classmethod
def get_restrictions(cls, collections, workers, allow_other_workers):
""" Get restrictions from inputs to compute/persist """
if isinstance(workers, (str, tuple, list)):
workers = {tuple(collections): workers}
if isinstance(workers, dict):
restrictions = {}
for colls, ws in workers.items():
if isinstance(ws, str):
ws = [ws]
if dask.is_dask_collection(colls):
keys = flatten(colls.__dask_keys__())
elif isinstance(colls, str):
keys = [colls]
else:
keys = list(
{k for c in flatten(colls) for k in flatten(c.__dask_keys__())}
)
restrictions.update({k: ws for k in keys})
else:
restrictions = {}

if allow_other_workers is True:
loose_restrictions = list(restrictions)
elif allow_other_workers:
loose_restrictions = list(
{k for c in flatten(allow_other_workers) for k in c.__dask_keys__()}
)
else:
loose_restrictions = []

return restrictions, loose_restrictions

@staticmethod
def collections_to_dsk(collections, *args, **kwargs):
return collections_to_dsk(collections, *args, **kwargs)
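
The bulk of the deletions above are the old per-key expansion helpers (_expand_retries, _expand_resources, get_restrictions). A minimal contrast of the two approaches, with made-up task keys:

# Old path (removed): user input was expanded into one entry per task key
# before being sent to the scheduler.
old_retries = {"('x', 0)": 3, "('x', 1)": 3, "('x', 2)": 3}

# New path: the raw value is carried once as a graph-level annotation and is
# applied to every layer when the scheduler unpacks the HighLevelGraph.
annotations = {"retries": 3}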
12 changes: 8 additions & 4 deletions distributed/protocol/highlevelgraph.py
@@ -120,21 +120,25 @@ def _materialized_layer_unpack(state, dsk, dependencies, annotations):
)


def highlevelgraph_unpack(dumped_hlg):
def highlevelgraph_unpack(dumped_hlg, annotations: dict):
# Notice, we set `use_list=False`, which makes msgpack convert lists to tuples
hlg = msgpack.loads(
dumped_hlg, object_hook=msgpack_decode_default, use_list=False, **msgpack_opts
)

dsk = {}
deps = {}
annotations = {}
out_annotations = {}
for layer in hlg["layers"]:
if annotations:
if layer["state"]["annotations"] is None:
layer["state"]["annotations"] = {}
layer["state"]["annotations"].update(annotations)
if layer["__module__"] is None: # Default implementation
unpack_func = _materialized_layer_unpack
else:
mod = import_allowed_module(layer["__module__"])
unpack_func = getattr(mod, layer["__name__"]).__dask_distributed_unpack__
unpack_func(layer["state"], dsk, deps, annotations)
unpack_func(layer["state"], dsk, deps, out_annotations)

return dsk, deps, annotations
return dsk, deps, out_annotations
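
A minimal sketch of the merge behaviour added to highlevelgraph_unpack: graph-level annotations passed in by the scheduler are written into each layer's stored annotations (overwriting layer-level values of the same name) before the layer is unpacked. The values below are illustrative.

# Illustrative only: how a layer's stored annotations combine with the
# graph-level annotations argument inside the loop above.
layer_state = {"annotations": None}        # as found in a packed layer
graph_annotations = {"retries": 3}         # passed in from update_graph_hlg

if graph_annotations:
    if layer_state["annotations"] is None:
        layer_state["annotations"] = {}
    layer_state["annotations"].update(graph_annotations)

assert layer_state["annotations"] == {"retries": 3}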
3 changes: 2 additions & 1 deletion distributed/scheduler.py
@@ -2279,9 +2279,10 @@ def update_graph_hlg(
user_priority=0,
actors=None,
fifo_timeout=0,
annotations=None,
):

dsk, dependencies, annotations = highlevelgraph_unpack(hlg)
dsk, dependencies, annotations = highlevelgraph_unpack(hlg, annotations)

# Remove any self-dependencies (happens on test_publish_bag() and others)
for k, v in dependencies.items():
9 changes: 9 additions & 0 deletions distributed/tests/test_client.py
@@ -6397,3 +6397,12 @@ async def test_annotations_loose_restrictions(c, s, a, b):
for ts in s.tasks.values()
]
)


@gen_cluster(client=True)
async def test_workers_collection_restriction(c, s, a, b):
da = pytest.importorskip("dask.array")

future = c.compute(da.arange(10), workers=a.address)
await future
assert a.data and not b.data
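
The new test drives this path through the public API inside the gen_cluster harness. For reference, a hedged usage sketch of the same behaviour against a standalone cluster; the scheduler and worker addresses are hypothetical.

import dask.array as da
from distributed import Client

client = Client("tcp://scheduler:8786")        # hypothetical address

x = da.arange(10)

# workers= / allow_other_workers= now travel as annotations on the packed
# HighLevelGraph rather than as per-key restrictions.
future = client.compute(x, workers="tcp://worker-a:40000",
                        allow_other_workers=True)
result = future.result()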