sketch of a decide_worker - doesn't work
obviously profoundly slow; want to figure out what logic works before we think about what to pre-compute.

note that I doubt this works at all without being able to oversaturate workers with a family. otherwise, with single-threaded workers, we'll just keep jumping along to a new worker for every task.
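
A toy model of that concern (illustration only; the worker names, `toy_decide`, and `family` are made up, not distributed's real scheduler state): with single-threaded workers and a saturation factor of 1.0, the worker running the first family member is immediately full, so every later sibling falls through to the least-busy-idle fallback and lands on a fresh worker; only oversaturation lets the family pile onto one worker.

def toy_decide(family, workers, nthreads=1, saturation=1.0):
    processing = {w: 0 for w in workers}  # running task count per worker
    placement = {}

    def full(w):
        return processing[w] >= nthreads * saturation

    for task in family:
        # cogroup-style candidates: workers already running a sibling, if not full
        siblings = [w for w in dict.fromkeys(placement.values()) if not full(w)]
        idle = [w for w in workers if not full(w)]
        if siblings:
            w = siblings[0]
        elif idle:
            w = min(idle, key=processing.get)  # fallback: least-busy idle worker
        else:
            break  # everything saturated: remaining siblings stay queued
        placement[task] = w
        processing[w] += 1
    return placement

print(toy_decide("abcd", ["w1", "w2", "w3", "w4"], saturation=1.0))
# {'a': 'w1', 'b': 'w2', 'c': 'w3', 'd': 'w4'}  -- no co-assignment at all
print(toy_decide("abcd", ["w1", "w2", "w3", "w4"], saturation=4.0))
# {'a': 'w1', 'b': 'w1', 'c': 'w1', 'd': 'w1'}  -- family stays together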
gjoseph92 committed Dec 6, 2022
1 parent ba846ad commit 7e73f12
Showing 2 changed files with 122 additions and 21 deletions.
distributed/scheduler.py (138 changes: 121 additions & 17 deletions)
@@ -2085,7 +2085,9 @@ def decide_worker_rootish_queuing_disabled(

        return ws

    def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None:
    def decide_worker_rootish_queuing_enabled(
        self, ts: TaskState
    ) -> WorkerState | None:
        """Pick a worker for a runnable root-ish task, if not all are busy.
        Picks the least-busy worker out of the ``idle`` workers (idle workers have fewer
@@ -2110,36 +2112,118 @@ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None:
"""
if self.validate:
# We don't `assert self.is_rootish(ts)` here, because that check is
# dependent on cluster size. It's possible a task looked root-ish when it
# was queued, but the cluster has since scaled up and it no longer does when
# coming out of the queue. If `is_rootish` changes to a static definition,
# then add that assertion here (and actually pass in the task).
assert self.is_rootish(ts)
assert not math.isinf(self.WORKER_SATURATION)

if not self.idle_task_count:
# All workers busy? Task gets/stays queued.
return None

# Just pick the least busy worker.
# NOTE: this will lead to worst-case scheduling with regards to co-assignment.
ws = min(
self.idle_task_count,
key=lambda ws: len(ws.processing) / ws.nthreads,
)
        ws: WorkerState | None = None
        if ts.cogroup:
            ws = self.decide_worker_from_cogroup(ts)

        if not ws:
            ws = min(
                self.idle_task_count,
                key=lambda ws: len(ws.processing) / ws.nthreads,
            )

        if self.validate:
            assert ws is not None
            assert not _worker_full(ws, self.WORKER_SATURATION), (
                ws,
                _task_slots_available(ws, self.WORKER_SATURATION),
            )
            assert ws in self.running, (ws, self.running)

        if self.validate and ws is not None:
            assert self.workers.get(ws.address) is ws
            assert ws in self.running, (ws, self.running)

        return ws

    def decide_worker_from_cogroup(self, ts: TaskState) -> WorkerState | None:
        assert (cogroup := ts.cogroup)

        candidates = self._candidates_from_cogroup(cogroup)
        if not candidates:
            return None

        ws, nbytes = max(candidates.items(), key=operator.itemgetter(1))
        return ws

    def _candidates_from_cogroup(
        self, cogroup: list[TaskState]
    ) -> dict[WorkerState, int]:
        assert cogroup
        candidates: defaultdict[WorkerState, int] = defaultdict(lambda: 0)

        # NOTE: because cogroups are disjoint, we _should_ avoid double-counting any tasks,
        # even though we traverse over many co-groups.
        if self.validate:
            seen_ids: set[int] = {id(cogroup)}

        self._update_candidates_for_cogroup(candidates, cogroup)
        if candidates:
            return candidates

        # No tasks in immediate group are scheduled or in memory.
        # Check dependents.
        dependents = list(_dependent_cogroups(cogroup))
        # TODO if the group contains a widely-shared task, `dependents` could be every cogroup.
        # We should maybe only look at dependents of the apex task? Or skip dependents of the roots?
        for dg in dependents:
            if self.validate:
                assert id(dg) not in seen_ids, dg
                seen_ids.add(id(dg))
            self._update_candidates_for_cogroup(candidates, dg)
        if candidates:
            return candidates

        # No dependents scheduled, check siblings.
        for dg in dependents:
            for sg in _dependency_cogroups(dg):
                if sg is not cogroup:
                    if self.validate:
                        assert id(sg) not in seen_ids, sg
                        seen_ids.add(id(sg))
                    self._update_candidates_for_cogroup(candidates, sg)
        if candidates:
            return candidates

        # No siblings, check dependencies.
        for dg in _dependency_cogroups(cogroup):
            if self.validate:
                assert id(dg) not in seen_ids, dg
                seen_ids.add(id(dg))
            self._update_candidates_for_cogroup(candidates, dg)

        return candidates

    def _update_candidates_for_cogroup(
        self, candidates: defaultdict[WorkerState, int], cogroup: list[TaskState]
    ) -> None:
        ws: WorkerState | None
        for ts in cogroup:
            # Family member in memory
            for ws in ts.who_has:
                if ws.status == Status.running and not _worker_full(
                    ws, self.WORKER_SATURATION
                ):
                    candidates[ws] += ts.get_nbytes()
            # Family member processing
            if (
                (ws := ts.processing_on)  # NOTE: exclusive with `ts.who_has`
                and ws.status == Status.running
                and not _worker_full(ws, self.WORKER_SATURATION)
            ):
                # NOTE: siblings processing on different workers is a rare case
                tg = ts.group
                nbytes_estimate = (
                    round(tg.nbytes_total / nmem)
                    if (nmem := tg.states["memory"])
                    else DEFAULT_DATA_SIZE
                )
                candidates[ws] += nbytes_estimate

    def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
        """Pick a worker for a runnable non-root task, considering dependencies and
        restrictions.
@@ -2226,7 +2310,7 @@ def transition_waiting_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
                if not (ws := self.decide_worker_rootish_queuing_disabled(ts)):
                    return {ts.key: "no-worker"}, {}, {}
            else:
                if not (ws := self.decide_worker_rootish_queuing_enabled()):
                if not (ws := self.decide_worker_rootish_queuing_enabled(ts)):
                    return {ts.key: "queued"}, {}, {}
        else:
            if not (ws := self.decide_worker_non_rootish(ts)):
@@ -2704,7 +2788,7 @@ def transition_queued_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
            assert not ts.actor, f"Actors can't be queued: {ts}"
            assert ts in self.queued

        if ws := self.decide_worker_rootish_queuing_enabled():
        if ws := self.decide_worker_rootish_queuing_enabled(ts):
            self.queued.discard(ts)
            worker_msgs = self._add_to_processing(ts, ws)
        # If no worker, task just stays `queued`
@@ -7870,6 +7954,26 @@ def _task_to_client_msgs(ts: TaskState) -> dict[str, list[dict[str, Any]]]:
    return {}


def _dependent_cogroups(cogroup: list[TaskState]) -> Iterator[list[TaskState]]:
    ids: set[int] = set()
    for ts in cogroup:
        # TODO don't check root-ish tasks
        for dts in ts.dependents:
            if (dcg := dts.cogroup) and dcg is not cogroup and id(dcg) not in ids:
                ids.add(id(dcg))
                yield dcg


def _dependency_cogroups(cogroup: list[TaskState]) -> Iterator[list[TaskState]]:
    ids: set[int] = set()
    for ts in cogroup:
        # TODO don't check root-ish tasks
        for dts in ts.dependencies:
            if (dcg := dts.cogroup) and dcg is not cogroup and id(dcg) not in ids:
                ids.add(id(dcg))
                yield dcg


def decide_worker(
    ts: TaskState,
    all_workers: Iterable[WorkerState],
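For reference, a standalone sketch of the scoring in `decide_worker_from_cogroup` / `_update_candidates_for_cogroup` above, using plain dicts and made-up worker names instead of TaskState/WorkerState and ignoring the running-status and saturation checks: finished co-group tasks contribute their measured nbytes, tasks still processing contribute their group's average in-memory size (or DEFAULT_DATA_SIZE if nothing from the group has finished), and the worker with the largest total wins.

from collections import defaultdict

DEFAULT_DATA_SIZE = 1_000_000  # stand-in value for the scheduler's default size estimate

def score_cogroup_workers(in_memory, processing, group_nbytes_total, group_n_in_memory):
    """in_memory: [(worker, nbytes)] for finished co-group tasks; processing: [worker]."""
    per_task_estimate = (
        round(group_nbytes_total / group_n_in_memory)
        if group_n_in_memory
        else DEFAULT_DATA_SIZE
    )
    candidates: defaultdict[str, int] = defaultdict(int)
    for worker, nbytes in in_memory:
        candidates[worker] += nbytes  # measured size of data already on that worker
    for worker in processing:
        candidates[worker] += per_task_estimate  # estimate for the in-flight output
    return max(candidates, key=candidates.get) if candidates else None

# Three 100 MB sibling outputs on "w1", one sibling still running on "w2" -> pick "w1".
print(score_cogroup_workers([("w1", 100_000_000)] * 3, ["w2"], 300_000_000, 3))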
distributed/tests/test_scheduler.py (5 changes: 1 addition & 4 deletions)
@@ -155,10 +155,7 @@ def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads):
    @gen_cluster(
        client=True,
        nthreads=nthreads,
        config={
            "distributed.scheduler.work-stealing": False,
            "distributed.scheduler.worker-saturation": float("inf"),
        },
        config={"distributed.scheduler.work-stealing": False},
    )
    async def test_decide_worker_coschedule_order_neighbors_(c, s, *workers):
        r"""
