WIP co-assign related root-ish tasks #4899

Closed
176 changes: 157 additions & 19 deletions distributed/scheduler.py
@@ -950,6 +950,9 @@ class TaskGroup:
_start: double
_stop: double
_all_durations: object
_last_worker: WorkerState
_last_worker_tasks_left: int # TODO Py_ssize_t?
_last_worker_priority: tuple # TODO remove (debugging only)

def __init__(self, name: str):
self._name = name
@@ -964,6 +967,9 @@ def __init__(self, name: str):
self._start = 0.0
self._stop = 0.0
self._all_durations = defaultdict(float)
self._last_worker = None
self._last_worker_tasks_left = 0
self._last_worker_priority = ()

@property
def name(self):
@@ -1009,6 +1015,26 @@ def start(self):
def stop(self):
return self._stop

@property
def last_worker(self):
return self._last_worker

@property
def last_worker_tasks_left(self):
return self._last_worker_tasks_left

@last_worker_tasks_left.setter
def last_worker_tasks_left(self, n: int):
self._last_worker_tasks_left = n

@property
def last_worker_priority(self):
return self._last_worker_priority

@last_worker_priority.setter
def last_worker_priority(self, x: tuple):
self._last_worker_priority = x

@ccall
def add(self, o):
ts: TaskState = o
@@ -2337,14 +2363,20 @@ def decide_worker(self, ts: TaskState) -> WorkerState:
ts.state = "no-worker"
return ws

if ts._dependencies or valid_workers is not None:
if (
ts._dependencies
or valid_workers is not None
or ts._group._last_worker is not None
):
ws = decide_worker(
ts,
self._workers_dv.values(),
valid_workers,
partial(self.worker_objective, ts),
self._total_nthreads,
)
else:
# Fastpath when there are no related tasks or restrictions
worker_pool = self._idle or self._workers
worker_pool_dv = cast(dict, worker_pool)
wp_vals = worker_pool.values()
@@ -2366,6 +2398,15 @@ def decide_worker(self, ts: TaskState) -> WorkerState:
else: # dumb but fast in large case
ws = wp_vals[self._n_tasks % n_workers]

ts._group._last_worker = ws
group_tasks_per_thread = (
len(ts._group) / self._total_nthreads if self._total_nthreads > 0 else 0
)
ts._group._last_worker_tasks_left = (
math.floor(group_tasks_per_thread * ws._nthreads) - 1
)
ts._group._last_worker_priority = ts._priority

if self._validate:
assert ws is None or isinstance(ws, WorkerState), (
type(ws),
@@ -4671,6 +4712,9 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True):
recommendations[ts._key] = "released"
else: # pure data
recommendations[ts._key] = "forgotten"
if ts._group._last_worker is ws:
ts._group._last_worker = None
ts._group._last_worker_tasks_left = 0
ws._has_what.clear()

self.transitions(recommendations)
@@ -6244,8 +6288,9 @@ async def retire_workers(
logger.info("Retire workers %s", workers)

# Keys orphaned by retiring those workers
keys = {k for w in workers for k in w.has_what}
keys = {ts._key for ts in keys if ts._who_has.issubset(workers)}
tasks = {ts for w in workers for ts in w.has_what}
keys = {ts._key for ts in tasks if ts._who_has.issubset(workers)}
groups = {ts._group for ts in tasks}

if keys:
other_workers = set(parent._workers_dv.values()) - workers
@@ -6260,6 +6305,11 @@ async def retire_workers(
lock=False,
)

for group in groups:
if group._last_worker in workers:
group._last_worker = None
group._last_worker_tasks_left = 0

worker_keys = {ws._address: ws.identity() for ws in workers}
if close_workers:
await asyncio.gather(
@@ -7471,11 +7521,52 @@ def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
@cfunc
@exceptval(check=False)
def decide_worker(
ts: TaskState, all_workers, valid_workers: set, objective
ts: TaskState,
all_workers,
valid_workers: set,
objective,
total_nthreads: Py_ssize_t,
) -> WorkerState:
"""
r"""
Decide which worker should take task *ts*.

There are two modes: root(ish) tasks, and normal tasks.

Root(ish) tasks
~~~~~~~~~~~~~~~

Root(ish) tasks have no (or very few) dependencies and fan out widely:
they belong to TaskGroups that contain more tasks than there are workers.
We want neighboring root tasks to run on the same worker, since there's a
good chance those neighbors will be combined in a downstream operation:

i j
/ \ / \
e f g h
| | | |
a b c d
\ \ / /
X

In the above case, we want ``a`` and ``b`` to run on the same worker,
and ``c`` and ``d`` to run on the same worker, reducing future
data transfer. We can also ignore the location of ``X``, because
as a common dependency, it will eventually get transferred everywhere.
Member:

<3 the ascii art

Comment/question: Do we want to explain all of this here? Historically I haven't put the logic behind heuristics in the code. This is a subjective opinion, and far from universal, but I find that heavily commented/documented logic makes it harder to understand the code at a glance. I really like that the current decide_worker implementation fits in a terminal window. I think that single-line comments are cool, but that long multi-line comments would better be written as documentation.

Thoughts? If you are not in disagreement then I would encourage us to write up a small docpage or maybe a blogpost and then link to that external resource from the code.

Collaborator Author:

I was also planning on updating https://distributed.dask.org/en/latest/scheduling-policies.html#choosing-workers, probably with this same ascii art. So just linking to that page in the docstring seems appropriate.


Calculating this directly from the graph would be expensive, so instead
we use task priority as a proxy. We aim to send tasks that are close in priority
within a `TaskGroup` to the same worker. To do this efficiently, we rely
on the fact that `decide_worker` is generally called in priority order
for root tasks (because `Scheduler.update_graph` creates recommendations
in priority order), and track only the last worker used for a `TaskGroup`,
and how many more tasks can be assigned to it before picking a new one.

By colocating related root tasks, we set their downstream normal tasks
up for successful placement.

Normal tasks
~~~~~~~~~~~~

We choose the worker that has the data on which *ts* depends.

If several workers have dependencies then we choose the less-busy worker.
@@ -7488,36 +7579,83 @@ def decide_worker(
of bytes sent between workers. This is determined by calling the
*objective* function.
"""
ws: WorkerState = None
wws: WorkerState
dts: TaskState

group: TaskGroup = ts._group
ws: WorkerState = group._last_worker

if valid_workers is not None:
total_nthreads = sum(wws._nthreads for wws in valid_workers)
Member:

This walks through all workers for all tasks. We may not be able to do this.

Collaborator Author:

See below; I believe valid_workers is None is the common case? Agreed that this isn't ideal though. But if there are worker restrictions, ignoring them and just using self._total_nthreads could be wildly wrong (imagine 10 GPU workers and 100 CPU workers for a task group of 50 that needs to run on GPUs). Maybe there's a cheaper measurement?

Member:

OK, grand.

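A rough illustration of the concern above, plugging in the hypothetical numbers from the reply (10 single-threaded GPU workers, 100 single-threaded CPU workers, and a 50-task group restricted to the GPUs); the variables below are made up for this sketch:

# Hypothetical cluster sizes from the comment above (illustrative only).
gpu_threads = 10     # 10 GPU workers, 1 thread each -- the valid workers
cpu_threads = 100    # 100 CPU workers, 1 thread each
group_size = 50      # tasks in the GPU-restricted TaskGroup

# Using the cluster-wide thread count, the group no longer looks "wider than
# the cluster", so the co-assignment heuristic would not trigger at all:
group_size / (gpu_threads + cpu_threads)   # ~0.45 tasks per thread

# Summing nthreads over valid_workers instead gives the intended picture:
group_size / gpu_threads                   # 5.0 tasks per thread

With the cluster-wide count the heuristic would not even engage, which appears to be why the PR recomputes the sum over valid_workers despite the extra walk.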

group_tasks_per_thread = (len(group) / total_nthreads) if total_nthreads > 0 else 0
ignore_deps_while_picking: bool = False

# Try to schedule sibling root-like tasks on the same workers.
if (
ws is not None
and group._last_worker_priority is not None
# ^ `decide_worker` hasn't previously been called out of priority order
and group_tasks_per_thread > 1
and sum(map(len, group._dependencies)) < 5 # TODO what number
):
if group._last_worker_tasks_left > 0:
group._last_worker_tasks_left -= 1
if group._last_worker_priority < ts.priority and (
valid_workers is None or ws in valid_workers
):
group._last_worker_priority = ts.priority
return ws

# `decide_worker` called out of priority order, or the last used worker is not valid for this task.
# This is probably not actually a root-ish task; disable root-ish mode in the future.
group._last_worker = None
group._last_worker_tasks_left = 0
group._last_worker_priority = None

# Previous worker is fully assigned, so pick a new worker.
ignore_deps_while_picking = True

deps: set = ts._dependencies
dts: TaskState
candidates: set
assert all([dts._who_has for dts in deps])
if ts._actor:
candidates = set(all_workers)
if ignore_deps_while_picking:
candidates = valid_workers if valid_workers is not None else set(all_workers)
else:
candidates = {wws for dts in deps for wws in dts._who_has}
if valid_workers is None:
if not candidates:
if ts._actor:
candidates = set(all_workers)
else:
candidates &= valid_workers
if not candidates:
candidates = valid_workers
else:
candidates = {wws for dts in deps for wws in dts._who_has}
if valid_workers is None:
if not candidates:
if ts._loose_restrictions:
ws = decide_worker(ts, all_workers, None, objective)
return ws
candidates = set(all_workers)
else:
candidates &= valid_workers
if not candidates:
candidates = valid_workers
if not candidates:
if ts._loose_restrictions:
ws = decide_worker(
ts, all_workers, None, objective, total_nthreads
)
return ws

ncandidates: Py_ssize_t = len(candidates)
if ncandidates == 0:
pass
elif ncandidates == 1:
# NOTE: this is the ideal case: all the deps are already on the same worker.
for ws in candidates:
break
else:
ws = min(candidates, key=objective)

if group._last_worker_priority is not None:
group._last_worker = ws
group._last_worker_tasks_left = (
math.floor(group_tasks_per_thread * ws._nthreads) - 1
)
group._last_worker_priority = ts.priority
return ws
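As a quick reference for reviewers, here is a minimal sketch of the co-assignment bookkeeping described in the docstring above, using simplified stand-ins (a plain GroupState holder and worker objects with an nthreads attribute) rather than the scheduler's real TaskGroup/WorkerState classes; names and structure are illustrative only, not the PR's exact implementation:

import math

class GroupState:
    """Simplified stand-in for the new TaskGroup fields (illustrative only)."""

    def __init__(self, n_tasks):
        self.n_tasks = n_tasks                 # len(ts._group) in the real code
        self.last_worker = None
        self.last_worker_tasks_left = 0
        self.last_worker_priority = ()


def assign_rootish(group, task_priority, workers, total_nthreads, objective):
    """Reuse the group's last worker until its share of the group is used up."""
    ws = group.last_worker
    if (
        ws is not None
        and group.last_worker_tasks_left > 0
        and group.last_worker_priority < task_priority
    ):
        # Sibling task arriving in priority order: keep it on the same worker.
        group.last_worker_tasks_left -= 1
        group.last_worker_priority = task_priority
        return ws

    # Otherwise pick a fresh worker and compute how many more siblings it
    # should take: its fair share of the group, minus this task.
    ws = min(workers, key=objective)
    tasks_per_thread = group.n_tasks / total_nthreads if total_nthreads else 0
    group.last_worker = ws
    group.last_worker_tasks_left = math.floor(tasks_per_thread * ws.nthreads) - 1
    group.last_worker_priority = task_priority
    return ws

In the actual diff, this path is additionally gated on the group fanning out wider than the cluster (group_tasks_per_thread > 1) and having few dependencies, and it is permanently disabled for a group (last_worker_priority = None) if tasks ever arrive out of priority order.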


110 changes: 109 additions & 1 deletion distributed/tests/test_scheduler.py
@@ -17,7 +17,7 @@

import dask
from dask import delayed
from dask.utils import apply
from dask.utils import apply, stringify

from distributed import Client, Nanny, Worker, fire_and_forget, wait
from distributed.comm import Comm
@@ -126,6 +126,114 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c):
assert x.key in a.data or x.key in b.data


@gen_cluster(
client=True,
nthreads=[("127.0.0.1", 1)] * 3,
config={"distributed.scheduler.work-stealing": False},
)
async def test_decide_worker_select_candidate_holding_no_deps(client, s, a, b, c):
await client.submit(slowinc, 10, delay=0.1) # learn that slowinc is slow
root = await client.scatter(1)
assert sum(root.key in worker.data for worker in [a, b, c]) == 1

start = time()
tasks = client.map(slowinc, [root] * 6, delay=0.1, pure=False)
await wait(tasks)
elapsed = time() - start

assert elapsed <= 4
assert all(root.key in worker.data for worker in [a, b, c]), [
list(worker.data.keys()) for worker in [a, b, c]
]


@pytest.mark.parametrize("ndeps", [0, 1, 4])
@pytest.mark.parametrize(
"nthreads",
[
[("127.0.0.1", 1)] * 5,
[("127.0.0.1", 3), ("127.0.0.1", 2), ("127.0.0.1", 1)],
],
)
def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads):
@gen_cluster(
client=True,
nthreads=nthreads,
config={"distributed.scheduler.work-stealing": False},
)
async def test(c, s, *workers):
"""Ensure that related tasks end up on the same node"""
da = pytest.importorskip("dask.array")
np = pytest.importorskip("numpy")

if ndeps == 0:
x = da.random.random((100, 100), chunks=(10, 10))
else:

def random(**kwargs):
assert len(kwargs) == ndeps
return np.random.random((10, 10))

trivial_deps = {f"k{i}": delayed(object()) for i in range(ndeps)}

# TODO is there a simpler (non-blockwise) way to make this sort of graph?
x = da.blockwise(
random,
"yx",
new_axes={"y": (10,) * 10, "x": (10,) * 10},
dtype=float,
**trivial_deps,
)

xx, xsum = dask.persist(x, x.sum(axis=1, split_every=20))
await xsum

# Check that each chunk-row of the array is (mostly) stored on the same worker
primary_worker_key_fractions = []
secondary_worker_key_fractions = []
for i, keys in enumerate(x.__dask_keys__()):
# Iterate along rows of the array.
keys = set(stringify(k) for k in keys)

# No more than 2 workers should have any keys
assert sum(any(k in w.data for k in keys) for w in workers) <= 2

# What fraction of the keys for this row does each worker hold?
key_fractions = [
len(set(w.data).intersection(keys)) / len(keys) for w in workers
]
key_fractions.sort()
# Primary worker: holds the highest percentage of keys
# Secondary worker: holds the second highest percentage of keys
primary_worker_key_fractions.append(key_fractions[-1])
secondary_worker_key_fractions.append(key_fractions[-2])

# There may be one or two rows that were poorly split across workers,
# but the vast majority of rows should only be on one worker.
assert np.mean(primary_worker_key_fractions) >= 0.9
assert np.median(primary_worker_key_fractions) == 1.0
assert np.mean(secondary_worker_key_fractions) <= 0.1
assert np.median(secondary_worker_key_fractions) == 0.0

# Check that there were few transfers
unexpected_transfers = []
for worker in workers:
for log in worker.incoming_transfer_log:
keys = log["keys"]
# The root-ish tasks should never be transferred
assert not any(k.startswith("random") for k in keys), keys
# `object-` keys (the trivial deps of the root random tasks) should be transferred
if any(not k.startswith("object") for k in keys):
# But not many other things should be
unexpected_transfers.append(list(keys))

# A transfer at the very end to move aggregated results is fine (necessary with unbalanced workers in fact),
# but generally there should be very very few transfers.
assert len(unexpected_transfers) <= 2, unexpected_transfers

test()


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
async def test_move_data_over_break_restrictions(client, s, a, b, c):
[x] = await client.scatter([1], workers=b.address)