Skip to content

Commit

Permalink
hack: allow oversaturation
Browse files Browse the repository at this point in the history
This probably isn't quite the right way to do it; it only works if root tasks really are in priority order (which I guess they are?).
  • Loading branch information
gjoseph92 committed Dec 6, 2022
1 parent 7e73f12 commit 796eaf5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
19 changes: 8 additions & 11 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2131,10 +2131,11 @@ def decide_worker_rootish_queuing_enabled(

if self.validate:
assert ws is not None
assert not _worker_full(ws, self.WORKER_SATURATION), (
ws,
_task_slots_available(ws, self.WORKER_SATURATION),
)
# Oversaturation is ok when scheduling by cogroup
# assert not _worker_full(ws, self.WORKER_SATURATION), (
# ws,
# _task_slots_available(ws, self.WORKER_SATURATION),
# )
assert ws in self.running, (ws, self.running)
assert self.workers.get(ws.address) is ws

Expand Down Expand Up @@ -2205,16 +2206,12 @@ def _update_candidates_for_cogroup(
for ts in cogroup:
# Family member in memory
for ws in ts.who_has:
if ws.status == Status.running and not _worker_full(
ws, self.WORKER_SATURATION
):
if ws.status == Status.running:
candidates[ws] += ts.get_nbytes()
# Family member processing
if (
(ws := ts.processing_on) # NOTE: exclusive with `ts.who_has`
and ws.status == Status.running
and not _worker_full(ws, self.WORKER_SATURATION)
):
ws := ts.processing_on
) and ws.status == Status.running: # NOTE: exclusive with `ts.who_has`
# NOTE: siblings processing on different workers is a rare case
tg = ts.group
nbytes_estimate = (
Expand Down
15 changes: 15 additions & 0 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,21 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c):
assert x.key in a.data or x.key in b.data


@gen_cluster(client=True, nthreads=[("", 1), ("", 1)])
async def test_decide_worker_coschedule_order_binary_op(c, s, a, b):
    """Check that an elementwise binary op over two inputs needs no transfers.

    Builds eight ``x-i`` and eight ``y-i`` delayed values, adds them pairwise,
    and computes the sums on a cluster of two single-threaded workers. The
    test passes only if neither worker recorded any incoming transfer.
    """
    xs = [delayed(i, name=f"x-{i}") for i in range(8)]
    ys = [delayed(i, name=f"y-{i}") for i in range(8)]
    zs = [x + y for x, y in zip(xs, ys)]

    futures = c.compute(zs)
    await c.gather(futures)

    # Same check for each worker, in the order (a, b); on failure the message
    # lists the keys of every logged incoming transfer.
    for worker in (a, b):
        incoming = worker.transfer_incoming_log
        assert not incoming, [entry["keys"] for entry in incoming]


@pytest.mark.parametrize("ndeps", [0, 1, 4])
@pytest.mark.parametrize(
"nthreads",
Expand Down

0 comments on commit 796eaf5

Please sign in to comment.