Speed up test_balance (#7008)
fjetter authored Sep 7, 2022
1 parent 3655f13 commit c7d5ba7
Showing 1 changed file with 52 additions and 54 deletions.
distributed/tests/test_steal.py: 106 changes (52 additions & 54 deletions)
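
The core of the speed-up: the old helper `func` slept for a fixed second in every task, while the new `block` helper parks each task on a distributed Event that the test releases once its assertions are done. A minimal standalone sketch of that pattern (the in-process cluster and the `demo_block` name are assumptions of this illustration, not part of the commit):

from distributed import Client, Event


def demo_block(x, event):
    # Hold the task in its worker thread until the caller releases it,
    # instead of sleeping for a fixed interval.
    event.wait()
    return x


if __name__ == "__main__":
    client = Client(processes=False)  # small in-process cluster for the demo
    ev = Event()
    fut = client.submit(demo_block, 1, event=ev)
    # ... scheduler state can be inspected here while the task is pinned ...
    ev.set()  # releases every task blocked on this event at once
    assert fut.result() == 1
    client.close()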
@@ -11,14 +11,13 @@
 from time import sleep

 import pytest
-from tlz import concat, sliding_window
+from tlz import sliding_window

 import dask
 from dask.utils import key_split

 from distributed import Event, Lock, Nanny, Worker, profile, wait, worker_client
 from distributed.compatibility import LINUX
-from distributed.config import config
 from distributed.core import Status
 from distributed.metrics import time
 from distributed.system import MEMORY_LIMIT
@@ -46,8 +45,7 @@
 teardown_module = nodebug_teardown_module


-@pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost")
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 2), ("127.0.0.2", 2)])
+@gen_cluster(client=True, nthreads=[("", 2), ("", 2)])
 async def test_work_stealing(c, s, a, b):
     [x] = await c._scatter([1], workers=a.address)
     futures = c.map(slowadd, range(50), [x] * 50)
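
A side effect of the hunk above: with empty host strings the workers bind to the default local interface, so the test no longer needs the Linux-only 127.0.0.2 alias or the skipif guard. A sketch of the new spelling in use (test_demo is hypothetical; gen_cluster is distributed's own test decorator from distributed.utils_test):

from distributed.utils_test import gen_cluster


@gen_cluster(client=True, nthreads=[("", 2), ("", 2)])
async def test_demo(c, s, a, b):
    # Two two-threaded workers on the default interface; no 127.0.0.2
    # alias is required, hence no LINUX-only skip.
    assert len(s.workers) == 2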
@@ -625,35 +623,37 @@ def slow2(x):
     assert any(future.key in w.state.tasks for w in rest)


-def func(x):
-    sleep(1)
-
-
-async def assert_balanced(inp, expected, c, s, *workers):
+async def assert_balanced(inp, expected, recompute_saturation, c, s, *workers):
     steal = s.extensions["stealing"]
     await steal.stop()
+    ev = Event()
+
+    def block(*args, event, **kwargs):
+        event.wait()

     counter = itertools.count()
-    tasks = list(concat(inp))
-    data_seq = itertools.count()

+    class Sizeof:
+        def __init__(self, nbytes):
+            self._nbytes = nbytes - 16
+
+        def __sizeof__(self) -> int:
+            return self._nbytes
+
     futures = []
     for w, ts in zip(workers, inp):
         for t in sorted(ts, reverse=True):
             if t:
-                [dat] = await c.scatter([next(data_seq)], workers=w.address)
-                ts = s.tasks[dat.key]
-                # Ensure scheduler state stays consistent
-                old_nbytes = ts.nbytes
-                ts.nbytes = int(s.bandwidth * t)
-                for ws in ts.who_has:
-                    ws.nbytes += ts.nbytes - old_nbytes
+                [dat] = await c.scatter(
+                    [Sizeof(int(t * s.bandwidth))], workers=w.address
+                )
             else:
                 dat = 123
             i = next(counter)
             f = c.submit(
-                func,
+                block,
                 dat,
+                event=ev,
                 key="%d-%d" % (int(t), i),
                 workers=w.address,
                 allow_other_workers=True,
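
The Sizeof helper introduced above replaces the direct mutation of TaskState.nbytes: scattering an object whose __sizeof__ reports a chosen size makes the scheduler record that size through its normal code path. This works because dask's size estimation falls back to sys.getsizeof, which on 64-bit CPython adds a 16-byte GC header on top of __sizeof__, hence the `- 16`. A sketch of the mechanism (the `payload` example is illustrative; assumes 64-bit CPython):

import sys

from dask.sizeof import sizeof  # dask's size-estimation dispatch


class Sizeof:
    def __init__(self, nbytes):
        # Compensate for the GC header so the reported total lands
        # exactly on the requested byte count.
        self._nbytes = nbytes - 16

    def __sizeof__(self) -> int:
        return self._nbytes


payload = Sizeof(10_000)
assert sys.getsizeof(payload) == 10_000
assert sizeof(payload) == 10_000  # the figure scatter() will report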
@@ -664,36 +664,36 @@ async def assert_balanced(inp, expected, c, s, *workers):

     while len([ts for ts in s.tasks.values() if ts.processing_on]) < len(futures):
         await asyncio.sleep(0.001)
-
-    for _ in range(10):
-        steal.balance()
-
-        while steal.in_flight:
-            await asyncio.sleep(0.001)
-
-        result = [
-            sorted(
-                (int(key_split(ts.key)) for ts in s.workers[w.address].processing),
-                reverse=True,
-            )
-            for w in workers
-        ]
-
-        result2 = sorted(result, reverse=True)
-        expected2 = sorted(expected, reverse=True)
-
-        if config.get("pdb-on-err"):
-            if result2 != expected2:
-                import pdb
-
-                pdb.set_trace()
+    if recompute_saturation:
+        for ws in s.workers.values():
+            s._reevaluate_occupancy_worker(ws)
+    try:
+        for _ in range(10):
+            steal.balance()
+
+            while steal.in_flight:
+                await asyncio.sleep(0.001)
+
+            result = [
+                sorted(
+                    (int(key_split(ts.key)) for ts in s.workers[w.address].processing),
+                    reverse=True,
+                )
+                for w in workers
+            ]
+
+            result2 = sorted(result, reverse=True)
+            expected2 = sorted(expected, reverse=True)

-        if result2 == expected2:
-            return
+            if result2 == expected2:
+                # Release the threadpools
+                return
+    finally:
+        await ev.set()
     raise Exception(f"Expected: {expected2}; got: {result2}")


 @pytest.mark.slow
+@pytest.mark.parametrize("recompute_saturation", [True, False])
 @pytest.mark.parametrize(
     "inp,expected",
     [
@@ -712,23 +712,21 @@ async def assert_balanced(inp, expected, c, s, *workers):
             [[0, 0], [0, 0], [0, 0], []],  # no one clearly saturated
             [[0, 0], [0, 0], [0], [0]],
         ),
+        # NOTE: There is a timing issue that workers may already start executing
+        # tasks before we call balance, i.e. the workers will reject the
+        # stealing request and we end up with a different end result.
+        # Particularly tests with many input tasks are more likely to fail since
+        # the test setup takes longer and allows the workers more time to
+        # schedule a task on the threadpool
         (
             [[4, 2, 2, 2, 2, 1, 1], [4, 2, 1, 1], [], [], []],
             [[4, 2, 2, 2, 2], [4, 2, 1], [1], [1], [1]],
         ),
-        pytest.param(
-            [[1, 1, 1, 1, 1, 1, 1], [1, 1], [1, 1], [1, 1], []],
-            [[1, 1, 1, 1, 1], [1, 1], [1, 1], [1, 1], [1, 1]],
-            # Can't mark as flaky as when it fails it does so every time for some reason
-            marks=pytest.mark.xfail(
-                reason="Some uncertainty based on executing stolen task"
-            ),
-        ),
     ],
 )
-def test_balance(inp, expected):
+def test_balance(inp, expected, recompute_saturation):
     async def test_balance_(*args, **kwargs):
-        await assert_balanced(inp, expected, *args, **kwargs)
+        await assert_balanced(inp, expected, recompute_saturation, *args, **kwargs)

     config = {
         "distributed.scheduler.default-task-durations": {str(i): 1 for i in range(10)}