From 1fcf794e40e4635ac506a9584499c5b3e7eed953 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 22 Jul 2021 19:31:17 +0200 Subject: [PATCH 1/6] Refactor worker state machine --- distributed/cfexecutor.py | 4 + distributed/diagnostics/plugin.py | 18 - .../diagnostics/tests/test_worker_plugin.py | 75 +- distributed/scheduler.py | 37 +- distributed/stealing.py | 12 +- distributed/tests/test_cancelled_state.py | 219 ++ distributed/tests/test_client.py | 16 +- distributed/tests/test_client_executor.py | 4 +- distributed/tests/test_failed_workers.py | 83 +- distributed/tests/test_nanny.py | 2 +- distributed/tests/test_steal.py | 32 +- distributed/tests/test_stress.py | 4 +- distributed/tests/test_worker.py | 329 ++- distributed/worker.py | 1920 +++++++++-------- distributed/worker_client.py | 9 +- setup.cfg | 2 - 16 files changed, 1707 insertions(+), 1059 deletions(-) create mode 100644 distributed/tests/test_cancelled_state.py diff --git a/distributed/cfexecutor.py b/distributed/cfexecutor.py index 8028a4bc7f..55d9ccc563 100644 --- a/distributed/cfexecutor.py +++ b/distributed/cfexecutor.py @@ -127,6 +127,10 @@ def map(self, fn, *iterables, **kwargs): raise TypeError("unexpected arguments to map(): %s" % sorted(kwargs)) fs = self._client.map(fn, *iterables, **self._kwargs) + if isinstance(fs, list): + # Below iterator relies on this being a generator to cancel + # remaining futures + fs = (val for val in fs) # Yield must be hidden in closure so that the tasks are submitted # before the first iterator value is required. diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index f9077afddd..a66f762809 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -157,24 +157,6 @@ def transition(self, key, start, finish, **kwargs): kwargs : More options passed when transitioning """ - def release_key(self, key, state, cause, reason, report): - """ - Called when the worker releases a task. - - Parameters - ---------- - key : string - state : string - State of the released task. - One of waiting, ready, executing, long-running, memory, error. - cause : string or None - Additional information on what triggered the release of the task. - reason : None - Not used. - report : bool - Whether the worker should report the released task to the scheduler. 
- """ - class NannyPlugin: """Interface to extend the Nanny diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index 7ae01f0922..cdbeccdeab 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -34,9 +34,6 @@ def transition(self, key, start, finish, **kwargs): {"key": key, "start": start, "finish": finish} ) - def release_key(self, key, state, cause, reason, report): - self.observed_notifications.append({"key": key, "state": state}) - @gen_cluster(client=True, nthreads=[]) async def test_create_with_client(c, s): @@ -107,11 +104,12 @@ async def test_create_on_construction(c, s, a, b): @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) async def test_normal_task_transitions_called(c, s, w): expected_notifications = [ - {"key": "task", "start": "new", "finish": "waiting"}, + {"key": "task", "start": "released", "finish": "waiting"}, {"key": "task", "start": "waiting", "finish": "ready"}, {"key": "task", "start": "ready", "finish": "executing"}, {"key": "task", "start": "executing", "finish": "memory"}, - {"key": "task", "state": "memory"}, + {"key": "task", "start": "memory", "finish": "released"}, + {"key": "task", "start": "released", "finish": "forgotten"}, ] plugin = MyPlugin(1, expected_notifications=expected_notifications) @@ -127,11 +125,12 @@ def failing(x): raise Exception() expected_notifications = [ - {"key": "task", "start": "new", "finish": "waiting"}, + {"key": "task", "start": "released", "finish": "waiting"}, {"key": "task", "start": "waiting", "finish": "ready"}, {"key": "task", "start": "ready", "finish": "executing"}, {"key": "task", "start": "executing", "finish": "error"}, - {"key": "task", "state": "error"}, + {"key": "task", "start": "error", "finish": "released"}, + {"key": "task", "start": "released", "finish": "forgotten"}, ] plugin = MyPlugin(1, expected_notifications=expected_notifications) @@ -147,11 +146,12 @@ def failing(x): ) async def test_superseding_task_transitions_called(c, s, w): expected_notifications = [ - {"key": "task", "start": "new", "finish": "waiting"}, + {"key": "task", "start": "released", "finish": "waiting"}, {"key": "task", "start": "waiting", "finish": "constrained"}, {"key": "task", "start": "constrained", "finish": "executing"}, {"key": "task", "start": "executing", "finish": "memory"}, - {"key": "task", "state": "memory"}, + {"key": "task", "start": "memory", "finish": "released"}, + {"key": "task", "start": "released", "finish": "forgotten"}, ] plugin = MyPlugin(1, expected_notifications=expected_notifications) @@ -166,16 +166,18 @@ async def test_dependent_tasks(c, s, w): dsk = {"dep": 1, "task": (inc, "dep")} expected_notifications = [ - {"key": "dep", "start": "new", "finish": "waiting"}, + {"key": "dep", "start": "released", "finish": "waiting"}, {"key": "dep", "start": "waiting", "finish": "ready"}, {"key": "dep", "start": "ready", "finish": "executing"}, {"key": "dep", "start": "executing", "finish": "memory"}, - {"key": "task", "start": "new", "finish": "waiting"}, + {"key": "task", "start": "released", "finish": "waiting"}, {"key": "task", "start": "waiting", "finish": "ready"}, {"key": "task", "start": "ready", "finish": "executing"}, {"key": "task", "start": "executing", "finish": "memory"}, - {"key": "dep", "state": "memory"}, - {"key": "task", "state": "memory"}, + {"key": "dep", "start": "memory", "finish": "released"}, + {"key": "task", "start": "memory", "finish": "released"}, + 
{"key": "task", "start": "released", "finish": "forgotten"}, + {"key": "dep", "start": "released", "finish": "forgotten"}, ] plugin = MyPlugin(1, expected_notifications=expected_notifications) @@ -219,3 +221,50 @@ class MyCustomPlugin(WorkerPlugin): await c.register_worker_plugin(MyCustomPlugin()) assert len(w.plugins) == 1 assert next(iter(w.plugins)).startswith("MyCustomPlugin-") + + +def test_release_key_deprecated(): + class ReleaseKeyDeprecated(WorkerPlugin): + def __init__(self): + self._called = False + + def release_key(self, key, state, cause, reason, report): + # Ensure that the handler still works + self._called = True + assert state == "memory" + assert key == "task" + + def teardown(self, worker): + assert self._called + return super().teardown(worker) + + @gen_cluster(client=True, nthreads=[("", 1)]) + async def test(c, s, a): + + await c.register_worker_plugin(ReleaseKeyDeprecated()) + fut = await c.submit(inc, 1, key="task") + assert fut == 2 + + with pytest.deprecated_call( + match="The `WorkerPlugin.release_key` hook is depreacted" + ): + test() + + +def test_assert_no_warning_no_overload(): + """Assert we do not receive a deprecation warning if we do not overload any + methods + """ + + class Dummy(WorkerPlugin): + pass + + @gen_cluster(client=True, nthreads=[("", 1)]) + async def test(c, s, a): + + await c.register_worker_plugin(Dummy()) + fut = await c.submit(inc, 1, key="task") + assert fut == 2 + + with pytest.warns(None): + test() diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8beafdbf88..667464dc3a 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2229,7 +2229,7 @@ def _transition(self, key, finish: str, *args, **kwargs): self._transition_counter += 1 recommendations, client_msgs, worker_msgs = a elif "released" not in start_finish: - assert not args and not kwargs + assert not args and not kwargs, start_finish a_recs: dict a_cmsgs: dict a_wmsgs: dict @@ -3048,7 +3048,11 @@ def transition_processing_released(self, key): w: str = _remove_from_processing(self, ts) if w: worker_msgs[w] = [ - {"op": "free-keys", "keys": [key], "reason": "Processing->Released"} + { + "op": "free-keys", + "keys": [key], + "reason": f"processing-released-{time()}", + } ] ts.state = "released" @@ -5398,7 +5402,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): self.log.append(("missing", key, errant_worker)) ts: TaskState = parent._tasks.get(key) - if ts is None or not ts._who_has: + if ts is None: return ws: WorkerState = parent._workers_dv.get(errant_worker) if ws is not None and ws in ts._who_has: @@ -5411,15 +5415,15 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): else: self.transitions({key: "forgotten"}) - def release_worker_data(self, comm=None, keys=None, worker=None): + def release_worker_data(self, comm=None, key=None, worker=None): parent: SchedulerState = cast(SchedulerState, self) + if worker not in parent._workers_dv: + return ws: WorkerState = parent._workers_dv[worker] - tasks: set = {parent._tasks[k] for k in keys if k in parent._tasks} - removed_tasks: set = tasks.intersection(ws._has_what) - ts: TaskState + ts = parent._tasks.get(key) recommendations: dict = {} - for ts in removed_tasks: + if ts and ts in ws._has_what: del ws._has_what[ts] ws._nbytes -= ts.get_nbytes() wh: set = ts._who_has @@ -6670,7 +6674,7 @@ def add_keys(self, comm=None, worker=None, keys=()): if worker not in parent._workers_dv: return "not found" ws: WorkerState = parent._workers_dv[worker] - 
superfluous_data = [] + redundant_replicas = [] for key in keys: ts: TaskState = parent._tasks.get(key) if ts is not None and ts._state == "memory": @@ -6679,14 +6683,15 @@ def add_keys(self, comm=None, worker=None, keys=()): ws._has_what[ts] = None ts._who_has.add(ws) else: - superfluous_data.append(key) - if superfluous_data: + redundant_replicas.append(key) + + if redundant_replicas: self.worker_send( worker, { - "op": "superfluous-data", - "keys": superfluous_data, - "reason": f"Add keys which are not in-memory {superfluous_data}", + "op": "remove-replicas", + "keys": redundant_replicas, + "stimulus_id": f"redundant-replicas-{time()}", }, ) @@ -7794,6 +7799,8 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> "key": ts._key, "priority": ts._priority, "duration": duration, + "stimulus_id": f"compute-task-{time()}", + "who_has": {}, } if ts._resource_restrictions: msg["resource_restrictions"] = ts._resource_restrictions @@ -7818,6 +7825,8 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> if ts._annotations: msg["annotations"] = ts._annotations + + assert "stimulus_id" in msg return msg diff --git a/distributed/stealing.py b/distributed/stealing.py index e3398b4c9a..46d7a8df6d 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -230,7 +230,15 @@ async def move_task_confirm(self, key=None, worker=None, state=None): return # Victim had already started execution, reverse stealing - if state in ("memory", "executing", "long-running", None): + if state in ( + "memory", + "executing", + "long-running", + "released", + "cancelled", + "resumed", + None, + ): self.log(("already-computing", key, victim.address, thief.address)) self.scheduler.check_idle_saturated(thief) self.scheduler.check_idle_saturated(victim) @@ -256,7 +264,7 @@ async def move_task_confirm(self, key=None, worker=None, state=None): await self.scheduler.remove_worker(thief.address) self.log(("confirm", key, victim.address, thief.address)) else: - raise ValueError("Unexpected task state: %s" % state) + raise ValueError(f"Unexpected task state: {ts}") except Exception as e: logger.exception(e) if LOG_PDB: diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py new file mode 100644 index 0000000000..b0906fc983 --- /dev/null +++ b/distributed/tests/test_cancelled_state.py @@ -0,0 +1,219 @@ +import asyncio +from unittest import mock + +import pytest + +from distributed import Nanny, wait +from distributed.core import CommClosedError +from distributed.utils_test import _LockedCommPool, gen_cluster, inc, slowinc + +pytestmark = pytest.mark.ci1 + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_abort_execution_release(c, s, w): + fut = c.submit(slowinc, 1, delay=0.5, key="f1") + + async def wait_for_exec(dask_worker): + while ( + fut.key not in dask_worker.tasks + or dask_worker.tasks[fut.key].state != "executing" + ): + await asyncio.sleep(0.01) + + await c.run(wait_for_exec) + + fut.release() + fut2 = c.submit(inc, 1, key="f2") + + async def observe(dask_worker): + cancelled = False + while ( + fut.key in dask_worker.tasks + and dask_worker.tasks[fut.key].state != "released" + ): + if dask_worker.tasks[fut.key].state == "cancelled": + cancelled = True + await asyncio.sleep(0.005) + return cancelled + + assert await c.run(observe) + await fut2 + del fut2 + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_abort_execution_reschedule(c, s, w): + fut = 
c.submit(slowinc, 1, delay=1) + + async def wait_for_exec(dask_worker): + while ( + fut.key not in dask_worker.tasks + or dask_worker.tasks[fut.key].state != "executing" + ): + await asyncio.sleep(0.01) + + await c.run(wait_for_exec) + + fut.release() + + async def observe(dask_worker): + while ( + fut.key in dask_worker.tasks + and dask_worker.tasks[fut.key].state != "released" + ): + if dask_worker.tasks[fut.key].state == "cancelled": + return + await asyncio.sleep(0.005) + + assert await c.run(observe) + fut = c.submit(slowinc, 1, delay=1) + await fut + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_abort_execution_add_as_dependency(c, s, w): + fut = c.submit(slowinc, 1, delay=1) + + async def wait_for_exec(dask_worker): + while ( + fut.key not in dask_worker.tasks + or dask_worker.tasks[fut.key].state != "executing" + ): + await asyncio.sleep(0.01) + + await c.run(wait_for_exec) + + fut.release() + + async def observe(dask_worker): + while ( + fut.key in dask_worker.tasks + and dask_worker.tasks[fut.key].state != "released" + ): + if dask_worker.tasks[fut.key].state == "cancelled": + return + await asyncio.sleep(0.005) + + assert await c.run(observe) + fut = c.submit(slowinc, 1, delay=1) + fut = c.submit(slowinc, fut, delay=1) + await fut + + +@gen_cluster(client=True, nthreads=[("", 1)] * 2, Worker=Nanny) +async def test_abort_execution_to_fetch(c, s, a, b): + fut = c.submit(slowinc, 1, delay=2, key="f1", workers=[a.worker_address]) + + async def wait_for_exec(dask_worker): + while ( + fut.key not in dask_worker.tasks + or dask_worker.tasks[fut.key].state != "executing" + ): + await asyncio.sleep(0.01) + + await c.run(wait_for_exec, workers=[a.worker_address]) + + fut.release() + + async def observe(dask_worker): + while ( + fut.key in dask_worker.tasks + and dask_worker.tasks[fut.key].state != "released" + ): + if dask_worker.tasks[fut.key].state == "cancelled": + return + await asyncio.sleep(0.005) + + assert await c.run(observe) + + # While the first worker is still trying to compute f1, we'll resubmit it to + # another worker with a smaller delay. The key is still the same + fut = c.submit(slowinc, 1, delay=0, key="f1", workers=[b.worker_address]) + # then, a must switch the execute to fetch. Instead of doing so, it will + # simply re-use the currently computing result. + fut = c.submit(slowinc, fut, delay=1, workers=[a.worker_address], key="f2") + await fut + + +@gen_cluster(client=True, nthreads=[("", 1)] * 2) +async def test_worker_find_missing(c, s, *workers): + + fut = c.submit(slowinc, 1, delay=0.5, workers=[workers[0].address]) + await wait(fut) + # We do not want to use proper API since the API usually ensures that the cluster is informed properly. 
We want to + del workers[0].data[fut.key] + del workers[0].tasks[fut.key] + + # Actually no worker has the data, the scheduler is supposed to reschedule + await c.submit(inc, fut, workers=[workers[1].address]) + + +@gen_cluster(client=True) +async def test_worker_stream_died_during_comm(c, s, a, b): + write_queue = asyncio.Queue() + write_event = asyncio.Event() + b.rpc = _LockedCommPool( + b.rpc, + write_queue=write_queue, + write_event=write_event, + ) + fut = c.submit(inc, 1, workers=[a.address], allow_other_workers=True) + await wait(fut) + # Actually no worker has the data, the scheduler is supposed to reschedule + res = c.submit(inc, fut, workers=[b.address]) + + await write_queue.get() + await a.close() + write_event.set() + + await res + assert any("receive-dep-failed" in msg for msg in b.log) + + +@gen_cluster(client=True) +async def test_flight_to_executing_via_cancelled_resumed(c, s, a, b): + import asyncio + + import distributed + + lock = asyncio.Lock() + await lock.acquire() + + async def wait_and_raise(*args, **kwargs): + async with lock: + raise CommClosedError() + + with mock.patch.object( + distributed.worker, + "get_data_from_worker", + side_effect=wait_and_raise, + ): + fut1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True) + + fut2 = c.submit(inc, fut1, workers=[b.address]) + + async def observe(dask_worker): + while ( + fut1.key not in dask_worker.tasks + or dask_worker.tasks[fut1.key].state != "flight" + ): + await asyncio.sleep(0) + + await c.run(observe, workers=[b.address]) + + # Close in scheduler to ensure we transition and reschedule task properly + await s.close_worker(worker=a.address) + + while fut1.key not in b.tasks or b.tasks[fut1.key].state != "resumed": + await asyncio.sleep(0) + + lock.release() + assert await fut2 == 3 + + b_story = b.story(fut1.key) + assert any("receive-dep-failed" in msg for msg in b_story) + assert any("missing-dep" in msg for msg in b_story) + if not any("cancelled" in msg for msg in b_story): + breakpoint() + assert any("resumed" in msg for msg in b_story) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index a65ab4494d..e3f0d6c953 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5149,7 +5149,8 @@ def f(): @gen_cluster(client=True) async def test_secede_balances(c, s, a, b): - count = threading.active_count() + """Ensure that tasks scheduled from a seceded thread can be scheduled + elsewhere""" def f(x): client = get_client() @@ -5158,15 +5159,16 @@ def f(x): total = client.submit(sum, futures).result() return total - futures = c.map(f, range(100)) + futures = c.map(f, range(10), workers=[a.address]) results = await c.gather(futures) + # We dispatch 10 tasks and every task generates 11 more tasks + # 10 * 11 + 10 + assert a.executed_count + b.executed_count == 120 + assert a.executed_count >= 10 + assert b.executed_count > 0 - assert a.executed_count + b.executed_count == 1100 - assert a.executed_count > 200 - assert b.executed_count > 200 - - assert results == [sum(map(inc, range(10)))] * 100 + assert results == [sum(map(inc, range(10)))] * 10 @gen_cluster(client=True) diff --git a/distributed/tests/test_client_executor.py b/distributed/tests/test_client_executor.py index 8a50f5f27d..fcfc9472ad 100644 --- a/distributed/tests/test_client_executor.py +++ b/distributed/tests/test_client_executor.py @@ -12,7 +12,6 @@ import pytest from tlz import take -from distributed.compatibility import MACOS from distributed.utils import CancelledError from 
distributed.utils_test import inc, slowadd, slowdec, slowinc, throws, varying @@ -127,7 +126,7 @@ def test_cancellation_as_completed(client): assert n_cancelled == 2 -@pytest.mark.flaky(condition=MACOS, reruns=10, reruns_delay=5) +@pytest.mark.slow() def test_map(client): with client.get_executor() as e: N = 10 @@ -156,6 +155,7 @@ def test_map(client): assert number_of_processing_tasks(client) > 0 # Garbage collect the iterator => remaining tasks are cancelled del it + time.sleep(0.1) assert number_of_processing_tasks(client) == 0 diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 83eb8ae8e5..8e5d01167d 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -407,9 +407,10 @@ def __sizeof__(self) -> int: return parse_bytes(dask.config.get("distributed.comm.offload")) + 1 -@pytest.mark.flaky(reruns=10, reruns_delay=5) @gen_cluster(client=True) async def test_worker_who_has_clears_after_failed_connection(c, s, a, b): + """This test is very sensitive to cluster state consistency. Timeouts often + indicate subtle deadlocks. Be mindful when marking flaky/repeat/etc.""" n = await Nanny(s.address, nthreads=2, loop=s.loop) while len(s.nthreads) < 3: @@ -526,82 +527,6 @@ async def test_worker_time_to_live(c, s, a, b): set(s.workers) == {b.address} -class SlowDeserialize: - def __init__(self, data, delay=0.1): - self.delay = delay - self.data = data - - def __getstate__(self): - return self.delay - - def __setstate__(self, state): - delay = state - import time - - time.sleep(delay) - return SlowDeserialize(delay) - - def __sizeof__(self) -> int: - # Ensure this is offloaded to avoid blocking loop - import dask - from dask.utils import parse_bytes - - return parse_bytes(dask.config.get("distributed.comm.offload")) + 1 - - -@gen_cluster(client=True) -async def test_handle_superfluous_data(c, s, a, b): - """ - See https://github.com/dask/distributed/pull/4784#discussion_r649210094 - """ - - def slow_deser(x, delay): - return SlowDeserialize(x, delay=delay) - - futA = c.submit( - slow_deser, 1, delay=1, workers=[a.address], key="A", allow_other_workers=True - ) - futB = c.submit(inc, 1, workers=[b.address], key="B") - await wait([futA, futB]) - - def reducer(*args): - return - - assert len(a.tasks) == 1 - assert futA.key in a.tasks - - assert len(b.tasks) == 1 - assert futB.key in b.tasks - - red = c.submit(reducer, [futA, futB], workers=[b.address], key="reducer") - - dep_key = futA.key - - # Wait for the connection to be established - while dep_key not in b.tasks or not b.tasks[dep_key].state == "flight": - await asyncio.sleep(0.001) - - # Wait for the connection to be returned to the pool. this signals that - # worker B is done with the communication and is about to deserialize the - # result - while a.address not in b.rpc.available and not b.rpc.available[a.address]: - await asyncio.sleep(0.001) - - assert b.tasks[dep_key].state == "flight" - # After the comm is finished and the deserialization starts, Worker B - # wouldn't notice that A dies. - await a.close() - # However, while B is busy deserializing a third worker might notice that A - # is dead and issues a handle-missing signal to the scheduler. Since at this - # point in time, A was the only worker with a verified replica, the - # scheduler reschedules the computation by transitioning it to released. 
The - # released transition has the side effect that it purges all data which is - # in memory which exposes us to a race condition on B if B also receives the - # signal to compute that task in the meantime. - s.handle_missing_data(key=dep_key, errant_worker=a.address) - await red - - @gen_cluster() async def test_forget_data_not_supposed_to_have(s, a, b): """ @@ -618,7 +543,9 @@ async def test_forget_data_not_supposed_to_have(s, a, b): ts = TaskState("key") ts.state = "flight" a.tasks["key"] = ts - a.transition_flight_memory(ts, value=123) + recommendations = {ts: ("memory", 123)} + a.transitions(recommendations, stimulus_id="test") + assert a.data while a.data: await asyncio.sleep(0.001) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index a5206e4f2e..16a16d8028 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -569,7 +569,7 @@ async def test_failure_during_worker_initialization(s): assert "Restarting worker" not in logs.getvalue() -@gen_cluster(client=True, Worker=Nanny, timeout=10000000) +@gen_cluster(client=True, Worker=Nanny) async def test_environ_plugin(c, s, a, b): from dask.distributed import Environ diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 00c6efb74f..3b4793523b 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -152,7 +152,7 @@ def do_nothing(x, y=None): xs = c.map(do_nothing, range(10), workers=workers[0].address) await wait(xs) - futures = c.map(do_nothing, range(1000), y=xs) + futures = c.map(do_nothing, range(100), y=xs) await wait(futures) @@ -362,10 +362,8 @@ async def test_steal_resource_restrictions(c, s, a): b = await Worker(s.address, loop=s.loop, nthreads=1, resources={"A": 4}) - start = time() while not b.tasks or len(a.tasks) == 101: await asyncio.sleep(0.01) - assert time() < start + 3 assert len(b.tasks) > 0 assert len(a.tasks) < 101 @@ -680,24 +678,30 @@ async def test_steal_twice(c, s, a, b): async def test_dont_steal_already_released(c, s, a, b): future = c.submit(slowinc, 1, delay=0.05, workers=a.address) key = future.key - await asyncio.sleep(0.05) - assert key in a.tasks + while key not in a.tasks: + await asyncio.sleep(0.05) + del future - await asyncio.sleep(0.05) + # In case the system is slow (e.g. 
network) ensure that nothing bad happens # if the key was already released - assert key not in a.tasks + while key in a.tasks and a.tasks[key].state != "released": + await asyncio.sleep(0.05) + a.steal_request(key) - assert a.batched_stream.buffer == [ - {"op": "steal-response", "key": key, "state": None} - ] + assert len(a.batched_stream.buffer) == 1 + msg = a.batched_stream.buffer[0] + assert msg["op"] == "steal-response" + assert msg["key"] == key + assert msg["state"] in [None, "released"] + with captured_logger( logging.getLogger("distributed.stealing"), level=logging.DEBUG ) as stealing_logs: - await asyncio.sleep(0.05) - - logs = stealing_logs.getvalue() - assert f"Key released between request and confirm: {key}" in logs + logs = stealing_logs.getvalue() + while f"Key released between request and confirm: {key}" not in logs: + await asyncio.sleep(0.05) + logs = stealing_logs.getvalue() @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index 3781e7c38e..c11353e1df 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -174,9 +174,11 @@ def vsum(*args): @pytest.mark.avoid_ci @pytest.mark.slow @pytest.mark.timeout(1100) # Override timeout from setup.cfg -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 80, timeout=1000) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 80) async def test_stress_communication(c, s, *workers): s.validate = False # very slow otherwise + for w in workers: + w.validate = False da = pytest.importorskip("dask.array") # Test consumes many file descriptors and can hang if the limit is too low resource = pytest.importorskip("resource") diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 6cee3388f2..aa4670608c 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -7,7 +7,6 @@ import traceback from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from concurrent.futures.process import BrokenProcessPool -from contextlib import suppress from numbers import Number from operator import add from time import sleep @@ -1775,8 +1774,8 @@ async def test_story_with_deps(c, s, a, b): Assert that the structure of the story does not change unintentionally and expected subfields are actually filled """ - futures = c.map(inc, range(10), workers=[a.address]) - res = c.submit(sum, futures, workers=[b.address]) + dep = c.submit(inc, 1, workers=[a.address]) + res = c.submit(inc, dep, workers=[b.address]) await res key = res.key @@ -1784,32 +1783,71 @@ async def test_story_with_deps(c, s, a, b): assert story == [] story = b.story(key) + pruned_story = [] + stimulus_ids = set() + # Story now includes randomized stimulus_ids and timestamps. 
+ for msg in story: + assert isinstance(msg, tuple), msg + assert isinstance(msg[-1], float), msg + assert msg[-1] > time() - 60, msg + pruned_msg = list(msg) + stimulus_ids.add(msg[-2]) + pruned_story.append(tuple(pruned_msg[:-2])) + + assert len(stimulus_ids) == 3 + stimulus_id = pruned_story[0][-1] + assert isinstance(stimulus_id, str) + assert stimulus_id.startswith("compute-task") + # This is a simple transition log expected_story = [ - (key, "new"), - (key, "new", "waiting"), - # First log is what needs to be fetched in total as determined in - # ensure_communicating + (key, "compute-task"), + (key, "released", "waiting", {}), + (key, "waiting", "ready", {}), + (key, "ready", "executing", {}), + (key, "put-in-memory"), + (key, "executing", "memory", {}), + ] + assert pruned_story == expected_story + + dep_story = dep.key + + story = b.story(dep_story) + pruned_story = [] + stimulus_ids = set() + for msg in story: + assert isinstance(msg, tuple), msg + assert isinstance(msg[-1], float), msg + assert msg[-1] > time() - 60, msg + pruned_msg = list(msg) + stimulus_ids.add(msg[-2]) + pruned_story.append(tuple(pruned_msg[:-2])) + + assert len(stimulus_ids) == 3 + stimulus_id = pruned_story[0][-1] + assert isinstance(stimulus_id, str) + expected_story = [ + (dep_story, "register-replica", "released"), + (dep_story, "released", "fetch", {}), ( "gather-dependencies", - key, - {fut.key for fut in futures}, + a.address, + {dep.key}, ), - # Second log may just be a subset of the above, see also - # Worker.select_keys_for_gather - # This case, it's all because Worker.target_message_size is sufficiently - # large + (dep_story, "fetch", "flight", {}), ( "request-dep", - key, a.address, - {fut.key for fut in futures}, + {dep.key}, ), - (key, "waiting", "ready"), - (key, "ready", "executing"), - (key, "executing", "memory"), - (key, "put-in-memory"), + ( + "receive-dep", + a.address, + {dep.key}, + ), + (dep_story, "put-in-memory"), + (dep_story, "flight", "memory", {res.key: "ready"}), ] - assert story == expected_story + assert pruned_story == expected_story @gen_cluster(client=True) @@ -2129,6 +2167,8 @@ def raise_exc(*args): await asyncio.sleep(0.01) expected_states = { + f.key: "released", + g.key: "released", res.key: "error", } @@ -2321,6 +2361,7 @@ def raise_exc(*args): assert_task_states_on_worker(expected_states_A, a) expected_states_B = { + f.key: "released", g.key: "memory", h.key: "memory", res.key: "error", @@ -2331,6 +2372,7 @@ def raise_exc(*args): g.release() expected_states_A = { + g.key: "released", h.key: "memory", } await asyncio.sleep(0.05) @@ -2338,6 +2380,7 @@ def raise_exc(*args): # B must not forget a task since all have a still valid dependent expected_states_B = { + f.key: "released", h.key: "memory", res.key: "error", } @@ -2348,6 +2391,8 @@ def raise_exc(*args): expected_states_A = {} assert_task_states_on_worker(expected_states_A, a) expected_states_B = { + f.key: "released", + h.key: "released", res.key: "error", } @@ -2382,7 +2427,7 @@ async def test_hold_on_to_replicas(c, s, *workers): assert s.tasks[f2.key].state == "released" await asyncio.sleep(0.01) - while len(workers[2].tasks) > 1: + while len(workers[2].data) > 1: await asyncio.sleep(0.01) @@ -2571,33 +2616,14 @@ def sink(a, b, *args): b.tasks[fut3.key].state = "fetch" event.set() - with captured_logger("distributed.worker", level=logging.DEBUG) as worker_logs: - - # FIXME: We currently have no reliable, safe way to release the task and - # its dependent without race conditions - - # Unfortunately res1 is 
deadlocking. IRL this is not always a problem - # since a commonly reported transition is Fetch->Memory, i.e. the task - # exists already in memory for whatever reason but a gather_dep was - # still runnign, e.g. the task was rescheduled on that worker and it was - # computed faster than fetched. - - with suppress(TimeoutError): - await asyncio.wait_for(res1, 0.1) - - assert await res2 == 5 + assert await res1 == 5 + assert await res2 == 5 - del res1, res2, fut, fut2 - fut3.release() + del res1, res2, fut, fut2 + fut3.release() - while a.tasks and b.tasks: - await asyncio.sleep(0.1) - - expected_msg = ( - "Exception occured while handling `gather_dep` response for " - ) - assert expected_msg in worker_logs.getvalue() - assert any("except-gather-dep-result" in msg for msg in b.story(fut3.key)) + while a.tasks and b.tasks: + await asyncio.sleep(0.1) @gen_cluster(client=True) @@ -2621,3 +2647,214 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b): s.handle_missing_data(key="f1", errant_worker=a.address) await fut2 + + +def _acquire_replica(scheduler, worker, future): + if not isinstance(future, list): + keys = [future.key] + else: + keys = [f.key for f in future] + + scheduler.stream_comms[worker.address].send( + { + "op": "acquire-replica", + "keys": keys, + "stimulus_id": time(), + "priorities": {key: scheduler.tasks[key].priority for key in keys}, + "who_has": { + key: {w.address for w in scheduler.tasks[key].who_has} for key in keys + }, + }, + ) + + +def _remove_replica(scheduler, worker, future): + if not isinstance(future, list): + keys = [future.key] + else: + keys = [f.key for f in future] + + scheduler.stream_comms[worker.address].send( + { + "op": "remove-replicas", + "keys": keys, + "stimulus_id": time(), + } + ) + + +@gen_cluster(client=True) +async def test_acquire_replica(c, s, a, b): + fut = c.submit(inc, 1, workers=[a.address]) + await fut + + _acquire_replica(s, b, fut) + + while not len(s.who_has[fut.key]) == 2: + await asyncio.sleep(0.005) + + for w in [a, b]: + assert fut.key in w.tasks + assert w.tasks[fut.key].state == "memory" + + fut.release() + + while b.tasks or a.tasks: + await asyncio.sleep(0.005) + + +@gen_cluster(client=True) +async def test_acquire_replica_same_channel(c, s, a, b): + fut = c.submit(inc, 1, workers=[a.address], key="f-replica") + futB = c.submit(inc, 2, workers=[a.address], key="f-B") + futC = c.submit(inc, futB, workers=[b.address], key="f-C") + await fut + + _acquire_replica(s, b, fut) + + await futC + while fut.key not in b.tasks: + await asyncio.sleep(0.005) + assert len(s.who_has[fut.key]) == 2 + + # Ensure that both the replica and an ordinary dependency pass through the + # same communication channel + + for f in [fut, futB]: + assert any(("request-dep" in msg for msg in b.story(f.key))) + assert any(("gather-dependencies" in msg for msg in b.story(f.key))) + assert any((f.key in msg["keys"] for msg in b.incoming_transfer_log)) + + +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) +async def test_acquire_replica_many(c, s, *workers): + futs = c.map(inc, range(10), workers=[workers[0].address]) + res = c.submit(sum, futs, workers=[workers[1].address]) + final = c.submit(slowinc, res, delay=0.5, workers=[workers[1].address]) + + await wait(futs) + + _acquire_replica(s, workers[2], futs) + + # Worker 2 should normally not even be involved if there was no replication + while not all( + f.key in workers[2].tasks and workers[2].tasks[f.key].state == "memory" + for f in futs + ): + await asyncio.sleep(0.01) + + assert 
all(ts.state == "memory" for ts in workers[2].tasks.values()) + + assert await final == sum(map(inc, range(10))) + 1 + # All workers have a replica + assert all(len(s.tasks[f.key].who_has) == 3 for f in futs) + del futs, res, final + + while any(w.tasks for w in workers): + await asyncio.sleep(0.001) + + +@gen_cluster(client=True) +async def test_remove_replica_simple(c, s, a, b): + futs = c.map(inc, range(10), workers=[a.address]) + await wait(futs) + _acquire_replica(s, b, futs) + + while not all(len(s.tasks[f.key].who_has) == 2 for f in futs): + await asyncio.sleep(0.01) + + _remove_replica(s, b, futs) + + while b.tasks: + await asyncio.sleep(0.01) + + # might take a moment for the reply to reach the scheduler + while not all(len(s.tasks[f.key].who_has) == 1 for f in futs): + await asyncio.sleep(0.01) + + +@gen_cluster(client=True) +async def test_remove_replica_while_computing(c, s, *workers): + futs = c.map(inc, range(10), workers=[workers[0].address]) + + # All interesting things will happen on that worker + w = workers[1] + intermediate = c.map(slowinc, futs, delay=0.1, workers=[w.address]) + + def reduce(*args, **kwargs): + import time + + time.sleep(0.5) + return + + final = c.submit(reduce, intermediate, workers=[w.address], key="final") + while final.key not in w.tasks: + await asyncio.sleep(0.001) + + while not all(fut.done() for fut in intermediate): + # The worker should reject all of these since they are required + _remove_replica(s, w, futs) + _remove_replica(s, w, intermediate) + await asyncio.sleep(0.001) + + await wait(intermediate) + + # Since intermediate is done, futs replicas may be removed. + # They might be already gone due to the above remove replica calls + _remove_replica(s, w, futs) + # the intermediate tasks should not be touched because they are still needed + # (the scheduler should not have made the above call but we should be safe + # regarless) + assert all(w.tasks[f.key].state == "memory" for f in intermediate) + + while any(w.tasks[f.key].state != "released" for f in futs if f.key in w.tasks): + await asyncio.sleep(0.001) + + # The scheduler actually gets notified about the removed replica + while not all(len(s.tasks[f.key].who_has) == 1 for f in futs): + await asyncio.sleep(0.001) + + await final + del final, intermediate, futs + + while any(w.tasks for w in workers): + await asyncio.sleep(0.001) + + +@gen_cluster(client=True, nthreads=[("", 1)] * 3) +async def test_who_has_consistent_remove_replica(c, s, *workers): + + a = workers[0] + other_workers = {w for w in workers if w != a} + f1 = c.submit(inc, 1, key="f1", workers=[w.address for w in other_workers]) + await wait(f1) + for w in other_workers: + _acquire_replica(s, w, f1) + + while not len(s.tasks[f1.key].who_has) == len(other_workers): + await asyncio.sleep(0) + + f2 = c.submit(inc, f1, workers=[a.address]) + + # Wait just until the moment the worker received the task and scheduled the + # task to be fetched, then remove the replica from the worker this one is + # trying to get the data from. 
Ensure this is handled gracefully and no + # suspicious counters are raised since this is expected behaviour when + # removing replicas + + while f1.key not in a.tasks or a.tasks[f1.key].state != "flight": + await asyncio.sleep(0) + + coming_from = None + for w in other_workers: + coming_from = w + if w.address == a.tasks[f1.key].coming_from: + break + + coming_from.handle_remove_replicas([f1.key], "test") + + await f2 + + assert ("missing-dep", f1.key) in a.story(f1.key) + assert a.tasks[f1.key].suspicious_count == 0 + assert s.tasks[f1.key].suspicious == 0 diff --git a/distributed/worker.py b/distributed/worker.py index 39e77a8bb9..8324df16e8 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -80,9 +80,15 @@ no_value = "--no-value-sentinel--" -IN_PLAY = ("waiting", "ready", "executing", "long-running") -PENDING = ("waiting", "ready", "constrained") -PROCESSING = ("waiting", "ready", "constrained", "executing", "long-running") +PROCESSING = ( + "waiting", + "ready", + "constrained", + "executing", + "long-running", + "cancelled", + "resumed", +) READY = ("ready", "constrained") @@ -99,6 +105,10 @@ SerializedTask = namedtuple("SerializedTask", ["function", "args", "kwargs", "task"]) +class InvalidTransition(Exception): + pass + + class TaskState: """Holds volatile state relating to an individual Dask task @@ -164,11 +174,11 @@ def __init__(self, key, runspec=None): self.dependents = set() self.duration = None self.priority = None - self.state = "new" + self.state = "released" self.who_has = set() self.coming_from = None self.waiting_for_data = set() - self.resource_restrictions = None + self.resource_restrictions = {} self.exception = None self.exception_text = "" self.traceback = None @@ -181,7 +191,8 @@ def __init__(self, key, runspec=None): self.metadata = {} self.nbytes = None self.annotations = None - self.scheduler_holds_ref = False + self.done = False + self._next = None def __repr__(self): return f"" @@ -190,6 +201,11 @@ def get_nbytes(self) -> int: nbytes = self.nbytes return nbytes if nbytes is not None else DEFAULT_DATA_SIZE + def is_protected(self) -> bool: + return self.state in PROCESSING or any( + dep_ts.state in PROCESSING for dep_ts in self.dependents + ) + class Worker(ServerNode): """Worker node in a Dask distributed cluster @@ -421,9 +437,8 @@ def __init__( self.nanny = nanny self._lock = threading.Lock() - self.data_needed = deque() # TODO: replace with heap? 
+ self.data_needed = list() - self.in_flight_tasks = 0 self.in_flight_workers = dict() self.total_out_connections = dask.config.get( "distributed.worker.connections.outgoing" @@ -449,7 +464,8 @@ def __init__( self.ready = list() self.constrained = deque() - self.executing_count = 0 + self._executing = set() + self._in_flight_tasks = set() self.executed_count = 0 self.long_running = set() @@ -462,33 +478,56 @@ def __init__( if validate is None: validate = dask.config.get("distributed.scheduler.validate") self.validate = validate - - self._transitions = { - # Basic state transitions - ("new", "waiting"): self.transition_new_waiting, - ("new", "fetch"): self.transition_new_fetch, - ("waiting", "ready"): self.transition_waiting_ready, - ("fetch", "flight"): self.transition_fetch_flight, - ("ready", "executing"): self.transition_ready_executing, - ("executing", "memory"): self.transition_executing_done, - ("flight", "memory"): self.transition_flight_memory, - ("flight", "fetch"): self.transition_flight_fetch, - # Shouldn't be a valid transition but happens nonetheless - ("ready", "memory"): self.transition_ready_memory, - # Scheduler intercession (re-assignment) - ("fetch", "waiting"): self.transition_fetch_waiting, - ("flight", "waiting"): self.transition_flight_waiting, - # Errors, long-running, constrained - ("waiting", "error"): self.transition_waiting_done, + self._transitions_table = { + ("cancelled", "resumed"): self.transition_cancelled_resumed, + ("cancelled", "fetch"): self.transition_cancelled_fetch, + ("cancelled", "released"): self.transition_cancelled_released, + ("cancelled", "waiting"): self.transition_cancelled_waiting, + ("cancelled", "forgotten"): self.transition_cancelled_forgotten, + ("cancelled", "memory"): self.transition_cancelled_memory, + ("cancelled", "error"): self.transition_generic_error, + ("resumed", "memory"): self._transition_to_memory_generic, + ("resumed", "error"): self.transition_generic_error, + ("resumed", "released"): self.transition_released_generic, + ("resumed", "waiting"): self.transition_rescheduled_next, + ("resumed", "fetch"): self.transition_rescheduled_next, ("constrained", "executing"): self.transition_constrained_executing, - ("executing", "error"): self.transition_executing_done, - ("executing", "rescheduled"): self.transition_executing_done, + ("constrained", "released"): self.transition_constrained_released, + ("error", "released"): self.transition_released_generic, + ("executing", "error"): self.transition_executing_error, ("executing", "long-running"): self.transition_executing_long_running, - ("long-running", "error"): self.transition_executing_done, - ("long-running", "memory"): self.transition_executing_done, - ("long-running", "rescheduled"): self.transition_executing_done, + ("executing", "memory"): self.transition_executing_memory, + ("executing", "released"): self.transition_executing_released, + ("executing", "rescheduled"): self.transition_executing_rescheduled, + ("fetch", "flight"): self.transition_fetch_flight, + ("fetch", "missing"): self.transition_fetch_missing, + ("fetch", "released"): self.transition_generic_released, + ("flight", "error"): self.transition_flight_error, + ("flight", "fetch"): self.transition_flight_fetch, + ("flight", "memory"): self.transition_flight_memory, + ("flight", "released"): self.transition_flight_released, + ("long-running", "error"): self.transition_long_running_error, + ("long-running", "memory"): self.transition_long_running_memory, + ("long-running", "rescheduled"): 
self.transition_executing_rescheduled, + ("long-running", "released"): self.transition_executing_released, + ("memory", "released"): self.transition_memory_released, + ("missing", "fetch"): self.transition_missing_fetch, + ("missing", "released"): self.transition_missing_released, + ("missing", "error"): self.transition_generic_error, + ("ready", "error"): self.transition_generic_error, + ("ready", "executing"): self.transition_ready_executing, + ("ready", "released"): self.transition_released_generic, + ("released", "error"): self.transition_generic_error, + ("released", "fetch"): self.transition_released_fetch, + ("released", "forgotten"): self.transition_released_forgotten, + ("released", "memory"): self.transition_released_memory, + ("released", "waiting"): self.transition_released_waiting, + ("waiting", "constrained"): self.transition_waiting_constrained, + ("waiting", "ready"): self.transition_waiting_ready, + ("waiting", "released"): self.transition_generic_released, } + self._transition_counter = 0 self.incoming_transfer_log = deque(maxlen=100000) self.incoming_count = 0 self.outgoing_transfer_log = deque(maxlen=100000) @@ -707,10 +746,11 @@ def __init__( stream_handlers = { "close": self.close, - "compute-task": self.add_task, "cancel-compute": self.cancel_compute, + "acquire-replica": self.handle_acquire_replica, + "compute-task": self.compute_task, "free-keys": self.handle_free_keys, - "superfluous-data": self.handle_superfluous_data, + "remove-replicas": self.handle_remove_replicas, "steal-request": self.steal_request, } @@ -736,6 +776,9 @@ def __init__( ) self.periodic_callbacks["keep-alive"] = pc + pc = PeriodicCallback(self.find_missing, 1000) + self.periodic_callbacks["find-missing"] = pc + self._address = contact_address self.memory_monitor_interval = parse_timedelta( @@ -819,6 +862,14 @@ def log_event(self, topic, msg): } ) + @property + def executing_count(self) -> int: + return len(self._executing) + + @property + def in_flight_tasks(self) -> int: + return len(self._in_flight_tasks) + @property def worker_address(self): """For API compatibility with Nanny""" @@ -977,7 +1028,6 @@ async def heartbeat(self): if self.heartbeat_active: logger.debug("Heartbeat skipped: channel busy") return - self.heartbeat_active = True logger.debug("Heartbeat: %s", self.address) try: @@ -1475,26 +1525,38 @@ async def get_data( # Local Execution # ################### - def update_data(self, comm=None, data=None, report=True, serializers=None): + def update_data( + self, comm=None, data=None, report=True, serializers=None, stimulus_id=None + ): + if stimulus_id is None: + stimulus_id = "update-data" + recommendations = {} + scheduler_messages = [] for key, value in data.items(): ts = self.tasks.get(key) if getattr(ts, "state", None) is not None: - self.transition(ts, "memory", value=value) + recommendations[ts] = ("memory", value) else: self.tasks[key] = ts = TaskState(key) - self.put_key_in_memory(ts, value) + recommendations, s_msgs = self._put_key_in_memory( + ts, value, stimulus_id=stimulus_id + ) + scheduler_messages.extend(s_msgs) ts.priority = None ts.duration = None - ts.scheduler_holds_ref = True self.log.append((key, "receive-from-scatter")) if report: - - self.log.append( - ("Notifying scheduler about in-memory in update-data", list(data)) + scheduler_messages.append( + { + "op": "add-keys", + "keys": list(data), + } ) - self.batched_stream.send({"op": "add-keys", "keys": list(data)}) + self.transitions(recommendations, stimulus_id=stimulus_id) + for msg in scheduler_messages: + 
self.batched_stream.send(msg) info = {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} return info @@ -1510,13 +1572,18 @@ def handle_free_keys(self, comm=None, keys=None, reason=None): upstream dependency. """ self.log.append(("free-keys", keys, reason)) + recommendations = {} for key in keys: ts = self.tasks.get(key) if ts is not None: - ts.scheduler_holds_ref = False - self.release_key(key, report=False, reason=reason) + if not ts.dependents: + recommendations[ts] = "forgotten" + else: + recommendations[ts] = "released" - def handle_superfluous_data(self, keys=(), reason=None): + self.transitions(recommendations, stimulus_id=reason) + + def handle_remove_replicas(self, keys, stimulus_id): """Stream handler notifying the worker that it might be holding unreferenced, superfluous data. This should not actually happen during ordinary operations and is only @@ -1534,13 +1601,17 @@ def handle_superfluous_data(self, keys=(), reason=None): For stronger guarantees, see handler free_keys """ - self.log.append(("Handle superfluous data", keys, reason)) + self.log.append(("remove-replica", keys, stimulus_id)) + recommendations = {} for key in list(keys): ts = self.tasks.get(key) - if ts and not ts.scheduler_holds_ref: - self.release_key(key, reason=f"delete data: {reason}", report=False) + if ts and not ts.is_protected(): + if not ts.dependents: + recommendations[ts] = "forgotten" + else: + recommendations[ts] = "released" + self.transitions(recommendations=recommendations, stimulus_id=stimulus_id) - logger.debug("Worker %s -- Deleted %d keys", self.name, len(keys)) return "OK" async def set_resources(self, **resources): @@ -1575,11 +1646,100 @@ def cancel_compute(self, key, reason): # scheduler side and therefore should not be assigned to a worker, # yet. 
assert not ts.dependents - self.release_key(key, reason=reason, report=False) + self.transition(ts, "released", stimulus_id=reason) + + def handle_acquire_replica( + self, comm=None, keys=None, priorities=None, who_has=None, stimulus_id=None + ): + recommendations = {} + scheduler_msgs = [] + for k in keys: + recs, s_msgs = self.register_acquire_internal( + k, + stimulus_id=stimulus_id, + priority=priorities[k], + ) + scheduler_msgs.extend(s_msgs) + recommendations.update(recs) + + self.update_who_has(who_has, stimulus_id=stimulus_id) + + for msg in scheduler_msgs: + self.batched_stream.send(msg) + self.transitions(recommendations, stimulus_id=stimulus_id) + + def register_acquire_internal(self, key, priority, stimulus_id): + if key in self.tasks: + logger.debug( + "Data task already known %s", + {"task": self.tasks[key], "stimulus_id": stimulus_id}, + ) + ts = self.tasks[key] + else: + self.tasks[key] = ts = TaskState(key) - def add_task( + self.log.append((key, "register-replica", ts.state, stimulus_id, time())) + ts.priority = ts.priority or priority + recommendations = {} + scheduler_msgs = [] + + if ts.state in ("released", "cancelled", "error"): + recommendations[ts] = "fetch" + + return recommendations, scheduler_msgs + + def transition_table_to_dot(self, filename="worker-transitions", format=None): + import graphviz + + from dask.dot import graphviz_to_file + + g = graphviz.Digraph( + graph_attr={ + "concentrate": "True", + }, + # node_attr=node_attr, + # edge_attr=edge_attr + ) + all_states = set() + for edge in self._transitions_table.keys(): + all_states.update(set(edge)) + + seen = set() + with g.subgraph(name="cluster_0") as c: + c.attr(style="filled", color="lightgrey") + c.node_attr.update(style="filled", color="white") + c.attr(label="executable") + for state in [ + "waiting", + "ready", + "executing", + "constrained", + "long-running", + ]: + c.node(state, label=state) + seen.add(state) + + with g.subgraph(name="cluster_1") as c: + for state in ["fetch", "flight", "missing"]: + c.attr(label="dependency") + c.node(state, label=state) + seen.add(state) + + # c.node("released", label="released", ports="n") + # seen.add('released') + + for state in all_states: + continue + # if state in seen: + g.node(state, label=state) + g.edges(self._transitions_table.keys()) + return graphviz_to_file(g, filename=filename, format=format) + + def compute_task( self, + *, key, + # FIXME: This will break protocol function=None, args=None, kwargs=None, @@ -1591,480 +1751,581 @@ def add_task( resource_restrictions=None, actor=False, annotations=None, - **kwargs2, + stimulus_id=None, ): - try: - runspec = SerializedTask(function, args, kwargs, task) - if key in self.tasks: - ts = self.tasks[key] - ts.scheduler_holds_ref = True - if ts.state == "memory": - assert key in self.data or key in self.actors - logger.debug( - "Asked to compute pre-existing result: %s: %s", key, ts.state - ) - self.send_task_state_to_scheduler(ts) - return - if ts.state in IN_PLAY: - return - if ts.state == "error": - ts.exception = None - ts.exception_text = "" - ts.traceback = None - ts.traceback_text = "" - else: - # This is a scheduler re-assignment - # Either `fetch` -> `waiting` or `flight` -> `waiting` - self.log.append((ts.key, "re-adding key, new TaskState")) - self.transition(ts, "waiting", runspec=runspec) - else: - self.log.append((key, "new")) - self.tasks[key] = ts = TaskState( - key=key, runspec=SerializedTask(function, args, kwargs, task) - ) - self.transition(ts, "waiting") - # TODO: move transition of 
`ts` to end of `add_task` - # This will require a chained recommendation transition system like - # the scheduler - - if priority is not None: - priority = tuple(priority) + (self.generation,) - self.generation -= 1 - - if actor: - self.actors[ts.key] = None - - ts.scheduler_holds_ref = True - ts.runspec = runspec - ts.priority = priority - ts.duration = duration - if resource_restrictions: - ts.resource_restrictions = resource_restrictions - ts.annotations = annotations - - who_has = who_has or {} - - for dependency, workers in who_has.items(): - assert workers - if dependency not in self.tasks: - # initial state is "new" - # this dependency does not already exist on worker - self.tasks[dependency] = dep_ts = TaskState(key=dependency) - - # link up to child / parents - ts.dependencies.add(dep_ts) - dep_ts.dependents.add(ts) - - # check to ensure task wasn't already executed and partially released - # # TODO: make this less bad - state = "fetch" if dependency not in self.data else "memory" - - # transition from new -> fetch handles adding dependency - # to waiting_for_data - discarded_self = False - if self.address in workers and state == "fetch": - discarded_self = True - workers = set(workers) - workers.discard(self.address) - who_has[dependency] = tuple(workers) - - self.transition(dep_ts, state, who_has=workers) - - self.log.append( - ( - dependency, - "new-dep", - dep_ts.state, - f"requested by {ts.key}", - discarded_self, - ) - ) + self.log.append((key, "compute-task", stimulus_id, time())) + if key in self.tasks: + logger.debug( + "Asked to compute an already known task %s", + {"task": self.tasks[key], "stimulus_id": stimulus_id}, + ) + ts = self.tasks[key] + else: + self.tasks[key] = ts = TaskState(key) + + ts.runspec = SerializedTask(function, args, kwargs, task) + + if priority is not None: + priority = tuple(priority) + (self.generation,) + self.generation -= 1 + + if actor: + self.actors[ts.key] = None + + ts.exception = None + ts.traceback = None + ts.exception_text = "" + ts.traceback_text = "" + ts.priority = priority + ts.duration = duration + if resource_restrictions: + ts.resource_restrictions = resource_restrictions + ts.annotations = annotations + + recommendations = {} + scheduler_msgs = [] + for dependency, _ in who_has.items(): + recs, s_msgs = self.register_acquire_internal( + key=dependency, + stimulus_id=stimulus_id, + priority=priority, + ) + recommendations.update(recs) + scheduler_msgs.extend(s_msgs) + dep_ts = self.tasks[dependency] - else: - # task was already present on worker - dep_ts = self.tasks[dependency] + # link up to child / parents + ts.dependencies.add(dep_ts) + dep_ts.dependents.add(ts) - # link up to child / parents - ts.dependencies.add(dep_ts) - dep_ts.dependents.add(ts) + if ts.state in ("ready", "executing", "waiting"): + pass + elif ts.state == "memory": + recommendations[ts] = "memory" + scheduler_msgs.append(self.get_task_state_for_scheduler(ts)) + elif ts.state in ("released", "fetch", "flight", "missing"): + recommendations[ts] = "waiting" + elif ts.state == "cancelled": + recommendations[ts] = "waiting" + else: + raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}") + + for msg in scheduler_msgs: + self.batched_stream.send(msg) + self.transitions(recommendations, stimulus_id=stimulus_id) + + # We received new info, that's great but not related to the compute-task + # instruction + self.update_who_has(who_has=who_has, stimulus_id=stimulus_id) + if nbytes is not None: + for key, value in nbytes.items(): + 
self.tasks[key].nbytes = value + + def transition_missing_fetch(self, ts, *, stimulus_id): + self._missing_dep_flight.discard(ts) + ts.state = "fetch" + heapq.heappush(self.data_needed, (ts.priority, ts.key)) + return {}, [] + + def transition_missing_released(self, ts, *, stimulus_id): + self._missing_dep_flight.discard(ts) + recommendations = self.release_key(ts.key, reason="missing->released") + assert ts.key in self.tasks + return recommendations, [] + + def transition_fetch_missing(self, ts, *, stimulus_id): + # handle_missing will append to self.data_needed if new workers + # are found + ts.state = "missing" + self._missing_dep_flight.add(ts) + return {}, [] + + def transition_released_fetch(self, ts, *, stimulus_id): + if self.validate: + assert ts.state == "released" - if dep_ts.state not in ("memory",): - ts.waiting_for_data.add(dep_ts.key) + for w in ts.who_has: + self.pending_data_per_worker[w].append(ts.key) + ts.state = "fetch" + heapq.heappush(self.data_needed, (ts.priority, ts.key)) + return {}, [] - self.update_who_has(who_has=who_has) - if nbytes is not None: - for key, value in nbytes.items(): - self.tasks[key].nbytes = value + def transition_released_generic(self, ts, *, stimulus_id): + recs = self.release_key(ts.key, reason=stimulus_id) + return recs, [] - if ts.waiting_for_data: - self.data_needed.append(ts.key) + def transition_released_waiting(self, ts, *, stimulus_id): + if self.validate: + assert ts.state == "released" + assert all(d.key in self.tasks for d in ts.dependencies) + + recommendations = {} + ts.waiting_for_data.clear() + for dep_ts in ts.dependencies: + if not dep_ts.state == "memory": + ts.waiting_for_data.add(dep_ts) + + if not ts.waiting_for_data: + if not ts.resource_restrictions: + recommendations[ts] = "ready" else: - self.transition(ts, "ready") - if self.validate: - for worker, keys in self.has_what.items(): - for k in keys: - assert worker in self.tasks[k].who_has - if who_has: - assert all(self.tasks[dep] in ts.dependencies for dep in who_has) - assert all(self.tasks[dep.key] for dep in ts.dependencies) - for dependency in ts.dependencies: - self.validate_task(dependency) - self.validate_task(ts) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb - - pdb.set_trace() - raise + recommendations[ts] = "constrained" + else: + self.waiting_for_data_count += 1 + ts.state = "waiting" + return recommendations, [] - def transition(self, ts, finish, **kwargs): - if ts is None: - return - start = ts.state - if start == finish: - return - func = self._transitions[start, finish] - self.log.append((ts.key, start, finish)) - state = func(ts, **kwargs) - if state and finish != state: - self.log.append((ts.key, start, finish, state)) - ts.state = state or finish + def transition_fetch_flight(self, ts, worker, *, stimulus_id): if self.validate: - self.validate_task(ts) - self._notify_plugins("transition", ts.key, start, state or finish, **kwargs) - - def transition_new_waiting(self, ts): - try: - if self.validate: - assert ts.state == "new" - assert ts.runspec is not None - assert not ts.who_has - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + assert ts.state == "fetch" + assert ts.who_has + assert ts.key not in self.data_needed - pdb.set_trace() - raise + ts.state = "flight" + ts.coming_from = worker + self._in_flight_tasks.add(ts) + return {}, [] - def transition_new_fetch(self, ts, who_has): - try: - if self.validate: - assert ts.state == "new" - assert ts.runspec is None - assert who_has + def 
transition_memory_released(self, ts, *, stimulus_id): + recs = self.release_key(ts.key, reason=stimulus_id) + s_msgs = [{"op": "release-worker-data", "key": ts.key}] + return recs, s_msgs - for dependent in ts.dependents: - dependent.waiting_for_data.add(ts.key) - - ts.who_has.update(who_has) - for w in who_has: - self.has_what[w].add(ts.key) - self.pending_data_per_worker[w].append(ts.key) + def transition_waiting_constrained(self, ts, *, stimulus_id): + if self.validate: + assert ts.state == "waiting" + assert not ts.waiting_for_data + assert all( + dep.key in self.data or dep.key in self.actors + for dep in ts.dependencies + ) + assert all(dep.state == "memory" for dep in ts.dependencies) + assert ts.key not in self.ready + ts.state = "constrained" + self.constrained.append(ts.key) + return {}, [] - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + def transition_long_running_rescheduled(self, ts, *, stimulus_id): + msgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}] + return {ts: "released"}, msgs - pdb.set_trace() - raise + def transition_executing_rescheduled(self, ts, *, stimulus_id): + for resource, quantity in ts.resource_restrictions.items(): + self.available_resources[resource] += quantity + msgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}] - def transition_fetch_waiting(self, ts, runspec): - """This is a rescheduling transition that occurs after a worker failure. - A task was available from another worker but that worker died and the - scheduler reassigned the task for computation here. - """ - try: - if self.validate: - assert ts.state == "fetch" - assert ts.runspec is None - assert runspec is not None + self._executing.discard(ts) + return {ts: "released"}, msgs - ts.runspec = runspec + def transition_waiting_ready(self, ts, *, stimulus_id): + if self.validate: + assert ts.state == "waiting" + assert not ts.waiting_for_data + assert all( + dep.key in self.data or dep.key in self.actors + for dep in ts.dependencies + ) + assert all(dep.state == "memory" for dep in ts.dependencies) + assert ts.key not in self.ready + ts.state = "ready" + heapq.heappush(self.ready, (ts.priority, ts.key)) - # remove any stale entries in `has_what` - for worker in self.has_what.keys(): - self.has_what[worker].discard(ts.key) + return {}, [] - # clear `who_has` of stale info - ts.who_has.clear() + def transition_generic_error( + self, ts, exception, traceback, exception_text, traceback_text, *, stimulus_id + ): + ts.exception = exception + ts.traceback = traceback + ts.exception_text = exception_text + ts.traceback_text = traceback_text + smsgs = [self.get_task_state_for_scheduler(ts)] + ts.state = "error" + return {}, smsgs + + def transition_long_running_error( + self, ts, exception, traceback, exception_text, traceback_text, *, stimulus_id + ): + return self.transition_generic_error( + ts, + exception, + traceback, + exception_text, + traceback_text, + stimulus_id=stimulus_id, + ) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + def transition_executing_error( + self, ts, exception, traceback, exception_text, traceback_text, *, stimulus_id + ): + for resource, quantity in ts.resource_restrictions.items(): + self.available_resources[resource] += quantity + self._executing.discard(ts) + return self.transition_generic_error( + ts, + exception, + traceback, + exception_text, + traceback_text, + stimulus_id=stimulus_id, + ) - pdb.set_trace() - raise + def transition_rescheduled_next(self, ts, *, stimulus_id): + 
next_state = ts._next + recs = self.release_key(ts.key, reason=stimulus_id) + if self.validate: + assert ts.state == "released" + recs[ts] = next_state + return recs, [] + + def transition_cancelled_fetch(self, ts, *, stimulus_id): + if ts.done: + return {ts: "released"}, [] + recommendations = {} + if ts._previous == "flight": + ts.state = ts._previous + else: + assert ts._previous == "executing" + recommendations[ts] = ("resumed", "fetch") + return recommendations, [] + + def transition_cancelled_resumed(self, ts, next, *, stimulus_id): + ts._next = next + ts.state = "resumed" + return {}, [] + + def transition_cancelled_waiting(self, ts, *, stimulus_id): + if ts.done: + return {ts: "released"}, [] + recommendations = {} + if ts._previous == "executing": + ts.state = ts._previous + else: + assert ts._previous == "flight" + recommendations[ts] = ("resumed", "waiting") + return recommendations, [] + + def transition_rescheduled_cancelled(self, ts): + ts.state = "cancelled" + return {}, [] + + def transition_cancelled_forgotten(self, ts, *, stimulus_id): + ts._next = "forgotten" + if not ts.done: + return {}, [] + return {ts: "released"}, [] + + def transition_cancelled_released(self, ts, *, stimulus_id): + if not ts.done: + ts._next = "released" + return {}, [] + next_state = ts._next + self._executing.discard(ts) + self._in_flight_tasks.discard(ts) - def transition_flight_waiting(self, ts, runspec): - """This is a rescheduling transition that occurs after - a worker failure. A task was in flight from another worker to this - worker when that worker died and the scheduler reassigned the task for - computation here. - """ - try: - if self.validate: - assert ts.state == "flight" - assert ts.runspec is None - assert runspec is not None + for resource, quantity in ts.resource_restrictions.items(): + self.available_resources[resource] += quantity + recommendations = self.release_key(ts.key, reason=stimulus_id) + recommendations[ts] = next_state or "released" + return recommendations, [] + + def transition_executing_released(self, ts, *, stimulus_id): + ts._previous = ts.state + ts.state = "cancelled" + ts.done = False + return {}, [] + + def transition_long_running_memory(self, ts, value=no_value, *, stimulus_id): + self.executed_count += 1 + return self._transition_to_memory_generic( + ts, value=value, stimulus_id=stimulus_id + ) - ts.runspec = runspec + def _transition_to_memory_generic(self, ts, value=no_value, *, stimulus_id): - # remove any stale entries in `has_what` - for worker in self.has_what.keys(): - self.has_what[worker].discard(ts.key) + if value is no_value and ts.key not in self.data: + raise RuntimeError( + f"Tried to transition task {ts} to `memory` without data available" + ) - # clear `who_has` of stale info - ts.who_has.clear() + if ts.resource_restrictions is not None: + for resource, quantity in ts.resource_restrictions.items(): + self.available_resources[resource] += quantity - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + self._executing.discard(ts) + self._in_flight_tasks.discard(ts) + ts.coming_from = None - pdb.set_trace() - raise + recommendations, s_msgs = self._put_key_in_memory( + ts, value, stimulus_id=stimulus_id + ) + s_msgs.append(self.get_task_state_for_scheduler(ts)) + return recommendations, s_msgs - def transition_fetch_flight(self, ts, worker=None): - try: - if self.validate: - assert ts.state == "fetch" - assert ts.dependents + def transition_executing_memory(self, ts, value=no_value, *, stimulus_id): + if self.validate: + 
assert ts.state == "executing" or ts.key in self.long_running + assert not ts.waiting_for_data + assert ts.key not in self.ready + + self._executing.discard(ts) + self.executed_count += 1 + return self._transition_to_memory_generic( + ts, value=value, stimulus_id=stimulus_id + ) - ts.coming_from = worker - self.in_flight_tasks += 1 - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + def transition_constrained_released(self, ts, *, stimulus_id): + recs = self.release_key(ts.key, reason=stimulus_id) + return recs, [] - pdb.set_trace() - raise + def transition_constrained_executing(self, ts, *, stimulus_id): + if self.validate: + assert not ts.waiting_for_data + assert ts.key not in self.data + assert ts.state in READY + assert ts.key not in self.ready + assert all( + dep.key in self.data or dep.key in self.actors + for dep in ts.dependencies + ) + for resource, quantity in ts.resource_restrictions.items(): + self.available_resources[resource] -= quantity + ts.state = "executing" + self._executing.add(ts) + self.loop.add_callback(self.execute, ts.key, stimulus_id=stimulus_id) + return {}, [] - def transition_flight_fetch(self, ts): - try: - if self.validate: - assert ts.state == "flight" + def transition_ready_executing(self, ts, *, stimulus_id): + if self.validate: + assert not ts.waiting_for_data + assert ts.key not in self.data + assert ts.state in READY + assert ts.key not in self.ready + assert all( + dep.key in self.data or dep.key in self.actors + for dep in ts.dependencies + ) + ts.state = "executing" + self._executing.add(ts) + self.loop.add_callback(self.execute, ts.key, stimulus_id=stimulus_id) + return {}, [] - self.in_flight_tasks -= 1 - ts.coming_from = None - ts.runspec = None + def transition_flight_fetch(self, ts, *, stimulus_id): + if self.validate: + assert ts.state == "flight" - if not ts.who_has: - if ts.key not in self._missing_dep_flight: - self._missing_dep_flight.add(ts.key) - logger.info("Task %s does not know who has", ts) - self.loop.add_callback(self.handle_missing_dep, ts) - for w in ts.who_has: - self.pending_data_per_worker[w].append(ts.key) - for dependent in ts.dependents: - dependent.waiting_for_data.add(ts.key) - if dependent.state == "waiting": - self.data_needed.append(dependent.key) + self._in_flight_tasks.discard(ts) + ts.coming_from = None - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + for w in ts.who_has: + self.pending_data_per_worker[w].append(ts.key) + ts.state = "fetch" + heapq.heappush(self.data_needed, (ts.priority, ts.key)) - pdb.set_trace() - raise + return {}, [] - def transition_flight_memory(self, ts, value=None): - try: - if self.validate: - assert ts.state == "flight" + def transition_flight_error( + self, ts, exception, traceback, exception_text, traceback_text, *, stimulus_id + ): + self._in_flight_tasks.discard(ts) + ts.coming_from = None + return self.transition_generic_error( + ts, + exception, + traceback, + exception_text, + traceback_text, + stimulus_id=stimulus_id, + ) - self.in_flight_tasks -= 1 - ts.coming_from = None - self.put_key_in_memory(ts, value) - for dependent in ts.dependents: - try: - dependent.waiting_for_data.remove(ts.key) - self.waiting_for_data_count -= 1 - except KeyError: - pass + def transition_flight_released(self, ts, *, stimulus_id): + ts._previous = "flight" + ts.state = "cancelled" + return {}, [] - self.log.append(("Notifying scheduler about in-memory", ts.key)) - self.batched_stream.send({"op": "add-keys", "keys": [ts.key]}) + def 
transition_cancelled_memory(self, ts, value, *, stimulus_id): + return {ts: ts._next}, [] - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + def transition_generic_released(self, ts, *, stimulus_id): + recs = self.release_key(ts.key, reason=stimulus_id) + return recs, [] - pdb.set_trace() - raise + def transition_executing_long_running(self, ts, compute_duration, *, stimulus_id): - def transition_waiting_ready(self, ts): - try: - if self.validate: - assert ts.state == "waiting" - assert not ts.waiting_for_data - assert all( - dep.key in self.data or dep.key in self.actors - for dep in ts.dependencies - ) - assert all(dep.state == "memory" for dep in ts.dependencies) - assert ts.key not in self.ready + if self.validate: + assert ts.state == "executing" + ts.state = "long-running" + self._executing.discard(ts) + self.long_running.add(ts.key) + scheduler_msgs = [ + { + "op": "long-running", + "key": ts.key, + "compute_duration": compute_duration, + } + ] - self.has_what[self.address].discard(ts.key) + self.io_loop.add_callback(self.ensure_computing) + return {}, scheduler_msgs - if ts.resource_restrictions is not None: - self.constrained.append(ts.key) - return "constrained" - else: - heapq.heappush(self.ready, (ts.priority, ts.key)) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + def transition_released_memory(self, ts, value, *, stimulus_id): + recommendations, scheduler_msgs = self._put_key_in_memory( + ts, value, stimulus_id=stimulus_id + ) + scheduler_msgs.append( + { + "op": "add-keys", + "keys": [ts.key], + } + ) + return recommendations, scheduler_msgs - pdb.set_trace() - raise + def transition_flight_memory(self, ts, value, *, stimulus_id): + if self.validate: + assert ts.state == "flight" - def transition_waiting_done(self, ts, value=None): - try: - if self.validate: - assert ts.state == "waiting" - assert ts.key not in self.ready + self._in_flight_tasks.discard(ts) + ts.coming_from = None + recommendations, scheduler_msgs = self._put_key_in_memory( + ts, value, stimulus_id=stimulus_id + ) + scheduler_msgs.append( + { + "op": "add-keys", + "keys": [ts.key], + } + ) + return recommendations, scheduler_msgs + + def _transition(self, ts, finish, *args, stimulus_id, **kwargs): + recommendations = {} + scheduler_msgs = [] + finish_state = finish + if isinstance(finish, tuple): + # the concatenated transition path might need to access the tuple + finish_state, *args = finish + + if ts is None or ts.state == finish_state: + return recommendations, scheduler_msgs + start = ts.state + start_finish = (start, finish_state) + func = self._transitions_table.get(start_finish) - self.waiting_for_data_count -= len(ts.waiting_for_data) - ts.waiting_for_data.clear() - if value is not None: - self.put_key_in_memory(ts, value) - self.send_task_state_to_scheduler(ts) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + default_state = "released" + if func is not None: + a: tuple = func(ts, *args, stimulus_id=stimulus_id, **kwargs) + self._transition_counter += 1 + recommendations, scheduler_msgs = a + self._notify_plugins("transition", ts.key, start, finish_state, **kwargs) - pdb.set_trace() - raise - - def transition_ready_executing(self, ts): - try: - if self.validate: - assert not ts.waiting_for_data - assert ts.key not in self.data - assert ts.state in READY - assert ts.key not in self.ready - assert all( - dep.key in self.data or dep.key in self.actors - for dep in ts.dependencies + elif default_state not in start_finish: + 
try:
+                a: tuple = self._transition(ts, default_state, stimulus_id=stimulus_id)
+                a_recs, a_smsgs = a
+
+                recommendations.update(a_recs)
+                scheduler_msgs.extend(a_smsgs)
+                v = a_recs.get(ts, finish)
+                v_args = []
+                v_state = v
+                if isinstance(v, tuple):
+                    v_state, *v_args = finish
+                b: tuple = self._transition(
+                    ts, v_state, *v_args, stimulus_id=stimulus_id
                 )
+                b_recs, b_smsgs = b
+                recommendations.update(b_recs)
+                scheduler_msgs.extend(b_smsgs)
+            except (InvalidTransition, KeyError):
+                raise InvalidTransition(
+                    "Impossible transition from %r to %r for %s"
+                    % (*start_finish, ts.key)
+                ) from None
 
-            self.executing_count += 1
-            self.loop.add_callback(self.execute, ts.key)
-        except Exception as e:
-            logger.exception(e)
-            if LOG_PDB:
-                import pdb
-
-                pdb.set_trace()
-            raise
-
-    def transition_ready_error(self, ts):
-        if self.validate:
-            assert ts.exception is not None
-            assert ts.traceback is not None
-            assert ts.exception_text
-            assert ts.traceback_text
-        self.send_task_state_to_scheduler(ts)
-
-    def transition_ready_memory(self, ts, value=no_value):
-        if value is not no_value:
-            self.put_key_in_memory(ts, value=value)
-        self.send_task_state_to_scheduler(ts)
-
-    def transition_constrained_executing(self, ts):
-        self.transition_ready_executing(ts)
-        for resource, quantity in ts.resource_restrictions.items():
-            self.available_resources[resource] -= quantity
-
-        if self.validate:
-            assert all(v >= 0 for v in self.available_resources.values())
+        else:
+            raise InvalidTransition(
+                "Impossible transition from %r to %r for %s" % (*start_finish, ts.key)
+            )
 
-    def transition_executing_done(self, ts, value=no_value, report=True):
-        try:
-            if self.validate:
-                assert ts.state == "executing" or ts.key in self.long_running
-                assert not ts.waiting_for_data
-                assert ts.key not in self.ready
+        self.log.append(
+            (
+                ts.key,
+                start,
+                ts.state,
+                {ts.key: new for ts, new in recommendations.items()},
+                stimulus_id,
+                time(),
+            )
+        )
+        return recommendations, scheduler_msgs
 
-            out = None
-            if ts.resource_restrictions is not None:
-                for resource, quantity in ts.resource_restrictions.items():
-                    self.available_resources[resource] += quantity
+    def _transitions(self, recommendations: dict, scheduler_msgs: list, stimulus_id):
 
-            if ts.state == "executing":
-                self.executing_count -= 1
-                self.executed_count += 1
-            elif ts.state == "long-running":
-                self.long_running.remove(ts.key)
+        recommendations = recommendations.copy()
+        tasks = set()
+        while recommendations:
+            ts, finish = recommendations.popitem()
+            tasks.add(ts)
+            new = self._transition(ts, finish, stimulus_id=stimulus_id)
+            new_recs, new_smsgs = new
+            scheduler_msgs.extend(new_smsgs)
 
-            if value is not no_value:
-                try:
-                    self.put_key_in_memory(ts, value, transition=False)
-                except Exception as e:
-                    logger.info("Failed to put key in memory", exc_info=True)
-                    msg = error_message(e)
-                    ts.exception = msg["exception"]
-                    ts.exception_text = msg["exception_text"]
-                    ts.traceback = msg["traceback"]
-                    ts.traceback_text = msg["traceback_text"]
-                    ts.state = "error"
-                    out = "error"
-                    for d in ts.dependents:
-                        d.waiting_for_data.add(ts.key)
-
-                if report and self.batched_stream and self.status == Status.running:
-                    self.send_task_state_to_scheduler(ts)
-                else:
-                    raise CommClosedError
+            recommendations.update(new_recs)
 
-            return out
+        if self.validate:
+            # Full state validation is too expensive
+            for ts in tasks:
+                self.validate_task(ts)
 
-        except OSError:
-            logger.info("Comm closed")
-        except Exception as e:
-            logger.exception(e)
-            if LOG_PDB:
-                import pdb
+    def 
transition(self, ts, finish: str, *, stimulus_id, **kwargs): + """Transition a key from its current state to the finish state - pdb.set_trace() - raise + Examples + -------- + >>> self.transition('x', 'waiting') + {'x': 'processing'} - def transition_executing_long_running(self, ts, compute_duration=None): - try: - if self.validate: - assert ts.state == "executing" + Returns + ------- + Dictionary of recommendations for future transitions - self.executing_count -= 1 - self.long_running.add(ts.key) - self.batched_stream.send( - { - "op": "long-running", - "key": ts.key, - "compute_duration": compute_duration, - } + See Also + -------- + Scheduler.transitions: transitive version of this function + """ + recommendations: dict + a: tuple = self._transition(ts, finish, stimulus_id=stimulus_id, **kwargs) + recommendations, s_msgs = a + for msg in s_msgs: + self.batched_stream.send(msg) + self.transitions(recommendations, stimulus_id=stimulus_id) + + def transitions(self, recommendations: dict, stimulus_id): + """Process transitions until none are left + + This includes feedback from previous transitions and continues until we + reach a steady state + """ + s_msgs = [] + self._transitions(recommendations, s_msgs, stimulus_id) + if not self.batched_stream.closed(): + for msg in s_msgs: + self.batched_stream.send(msg) + else: + logger.debug( + "BatchedSend closed while transitioning tasks. %s tasks not sent.", + len(s_msgs), ) - self.ensure_computing() - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb - - pdb.set_trace() - raise - - def maybe_transition_long_running(self, ts, compute_duration=None): + def maybe_transition_long_running(self, ts, stimulus_id, compute_duration=None): if ts.state == "executing": - self.transition(ts, "long-running", compute_duration=compute_duration) + self.transition( + ts, + "long-running", + compute_duration=compute_duration, + stimulus_id=stimulus_id, + ) + assert ts.state == "long-running" def stateof(self, key): ts = self.tasks[key] @@ -2090,123 +2351,68 @@ def story(self, *keys): ] def ensure_communicating(self): - changed = True - try: - while ( - changed - and self.data_needed - and len(self.in_flight_workers) < self.total_out_connections - ): - changed = False - logger.debug( - "Ensure communicating. Pending: %d. Connections: %d/%d", - len(self.data_needed), - len(self.in_flight_workers), - self.total_out_connections, - ) - - key = self.data_needed[0] + stimulus_id = f"ensure-communicating-{time()}" + skipped_worker_in_flight = list() - if key not in self.tasks: - self.data_needed.popleft() - changed = True - continue - - ts = self.tasks[key] - if ts.state != "waiting": - self.log.append((key, "communication pass")) - self.data_needed.popleft() - changed = True - continue + while self.data_needed and ( + len(self.in_flight_workers) < self.total_out_connections + or self.comm_nbytes < self.comm_threshold_bytes + ): + logger.debug( + "Ensure communicating. Pending: %d. 
Connections: %d/%d", + len(self.data_needed), + len(self.in_flight_workers), + self.total_out_connections, + ) - dependencies = ts.dependencies - if self.validate: - assert all(dep.key in self.tasks for dep in dependencies) - - dependencies_fetch = set() - dependencies_missing = set() - for dependency_ts in dependencies: - if dependency_ts.state == "fetch": - if not dependency_ts.who_has: - dependencies_missing.add(dependency_ts) - else: - dependencies_fetch.add(dependency_ts) + key = heapq.heappop(self.data_needed)[1] - del dependencies, dependency_ts + if key not in self.tasks: + continue - if dependencies_missing: - missing_deps2 = { - dep - for dep in dependencies_missing - if dep.key not in self._missing_dep_flight - } - for dep in missing_deps2: - self._missing_dep_flight.add(dep.key) - if missing_deps2: - logger.info( - "Can't find dependencies %s for key %s", - missing_deps2.copy(), - key, - ) - self.loop.add_callback(self.handle_missing_dep, *missing_deps2) - dependencies_fetch -= dependencies_missing + ts = self.tasks[key] + if ts.state != "fetch": + continue - self.log.append( - ("gather-dependencies", key, {d.key for d in dependencies_fetch}) - ) + if not ts.who_has: + self.transition(ts, "missing", stimulus_id=stimulus_id) + continue - in_flight = False + workers = [w for w in ts.who_has if w not in self.in_flight_workers] + if not workers: + skipped_worker_in_flight.append((ts.priority, ts.key)) + continue - while dependencies_fetch and ( - len(self.in_flight_workers) < self.total_out_connections - or self.comm_nbytes < self.comm_threshold_bytes - ): - to_gather_ts = dependencies_fetch.pop() - - workers = [ - w - for w in to_gather_ts.who_has - if w not in self.in_flight_workers - ] - if not workers: - in_flight = True - continue - host = get_address_host(self.address) - local = [w for w in workers if get_address_host(w) == host] - if local: - worker = random.choice(local) - else: - worker = random.choice(list(workers)) - to_gather, total_nbytes = self.select_keys_for_gather( - worker, to_gather_ts.key - ) - self.comm_nbytes += total_nbytes - self.in_flight_workers[worker] = to_gather - for d in to_gather: - dependencies_fetch.discard(self.tasks.get(d)) - self.transition(self.tasks[d], "flight", worker=worker) - assert not worker == self.address - self.loop.add_callback( - self.gather_dep, - worker=worker, - to_gather=to_gather, - total_nbytes=total_nbytes, - cause=ts, - ) - changed = True + host = get_address_host(self.address) + local = [w for w in workers if get_address_host(w) == host] + if local: + worker = random.choice(local) + else: + worker = random.choice(list(workers)) - if not dependencies_fetch and not in_flight: - self.data_needed.popleft() + to_gather, total_nbytes = self.select_keys_for_gather(worker, ts.key) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb + self.log.append( + ("gather-dependencies", worker, to_gather, "stimulus", time()) + ) - pdb.set_trace() - raise + self.comm_nbytes += total_nbytes + self.in_flight_workers[worker] = to_gather + recommendations = {self.tasks[d]: ("flight", worker) for d in to_gather} + self.transitions(recommendations=recommendations, stimulus_id=stimulus_id) + assert not worker == self.address + self.loop.add_callback( + self.gather_dep, + worker=worker, + to_gather=to_gather, + total_nbytes=total_nbytes, + stimulus_id=stimulus_id, + ) + else: + for el in skipped_worker_in_flight: + heapq.heappush(self.data_needed, el) - def send_task_state_to_scheduler(self, ts): + def 
get_task_state_for_scheduler(self, ts): if ts.key in self.data or self.actors.get(ts.key): typ = ts.type if ts.nbytes is None or typ is None: @@ -2247,44 +2453,49 @@ def send_task_state_to_scheduler(self, ts): else: logger.error("Key not ready to send to worker, %s: %s", ts.key, ts.state) return - if ts.startstops: d["startstops"] = ts.startstops - self.batched_stream.send(d) + return d - def put_key_in_memory(self, ts, value, transition=True): + def _put_key_in_memory(self, ts, value, stimulus_id): if ts.key in self.data: ts.state = "memory" - return - + return {}, [] + recommendations = {} + scheduler_messages = [] if ts.key in self.actors: self.actors[ts.key] = value else: start = time() - self.data[ts.key] = value - ts.state = "memory" + try: + self.data[ts.key] = value + except Exception as e: + msg = error_message(e) + ts.exception = msg["exception"] + ts.traceback = msg["traceback"] + recommendations[ts] = ("error", msg["exception"], msg["traceback"]) + return recommendations, [] stop = time() if stop - start > 0.020: ts.startstops.append( {"action": "disk-write", "start": start, "stop": stop} ) + ts.state = "memory" if ts.nbytes is None: ts.nbytes = sizeof(value) ts.type = type(value) for dep in ts.dependents: - try: - dep.waiting_for_data.remove(ts.key) + dep.waiting_for_data.discard(ts) + if not dep.waiting_for_data and dep.state == "waiting": self.waiting_for_data_count -= 1 - except KeyError: - pass - if not dep.waiting_for_data: - self.transition(dep, "ready") + recommendations[dep] = "ready" - self.log.append((ts.key, "put-in-memory")) + self.log.append((ts.key, "put-in-memory", stimulus_id, time())) + return recommendations, scheduler_messages def select_keys_for_gather(self, worker, dep): assert isinstance(dep, str) @@ -2319,7 +2530,7 @@ async def gather_dep( worker: str, to_gather: Iterable[str], total_nbytes: int, - cause: TaskState, + stimulus_id, ): """Gather dependencies for a task from a worker who has them @@ -2333,29 +2544,41 @@ async def gather_dep( as some dependencies may already be present on this worker. total_nbytes : int Total number of bytes for all the dependencies in to_gather combined - cause : TaskState - Task we want to gather dependencies for """ - - if self.validate: - self.validate_state() + cause = None if self.status != Status.running: return + with log_errors(): response = {} to_gather_keys = set() try: - if self.validate: - self.validate_state() + found_dependent_for_cause = False for dependency_key in to_gather: dependency_ts = self.tasks.get(dependency_key) if dependency_ts and dependency_ts.state == "flight": to_gather_keys.add(dependency_key) + if not found_dependent_for_cause: + cause = dependency_ts + # For diagnostics we want to attach the transfer to + # a single task. this task is typically the next to + # be executed but since we're fetching tasks for + # potentially many dependents, an exact match is not + # possible. 
If there are no dependents, this is a + # pure replica fetch + for dependent in dependency_ts.dependents: + cause = dependent + found_dependent_for_cause = True + break # Keep namespace clean since this func is long and has many # dep*, *ts* variables + + assert cause is not None del to_gather, dependency_key, dependency_ts - self.log.append(("request-dep", cause.key, worker, to_gather_keys)) + self.log.append( + ("request-dep", worker, to_gather_keys, stimulus_id, time()) + ) logger.debug( "Request %d keys for task %s from %s", len(to_gather_keys), @@ -2368,15 +2591,17 @@ async def gather_dep( self.rpc, to_gather_keys, worker, who=self.address ) stop = time() - if response["status"] == "busy": - self.log.append(("busy-gather", worker, to_gather_keys)) - for key in to_gather_keys: - ts = self.tasks.get(key) - if ts and ts.state == "flight": - self.transition(ts, "fetch") return + data = {k: v for k, v in response["data"].items() if k in self.tasks} + lost_keys = set(response["data"]) - set(data) + + if lost_keys: + self.log.append(("lost-during-gather", lost_keys, stimulus_id)) + + total_bytes = sum(self.tasks[key].get_nbytes() for key in data) + cause.startstops.append( { "action": "transfer", @@ -2385,12 +2610,6 @@ async def gather_dep( "source": worker, } ) - - total_bytes = sum( - self.tasks[key].get_nbytes() - for key in response["data"] - if key in self.tasks - ) duration = (stop - start) or 0.010 bandwidth = total_bytes / duration self.incoming_transfer_log.append( @@ -2399,11 +2618,7 @@ async def gather_dep( "stop": stop + self.scheduler_delay, "middle": (start + stop) / 2.0 + self.scheduler_delay, "duration": duration, - "keys": { - key: self.tasks[key].nbytes - for key in response["data"] - if key in self.tasks - }, + "keys": {key: self.tasks[key].nbytes for key in data}, "total": total_bytes, "bandwidth": bandwidth, "who": worker, @@ -2426,13 +2641,17 @@ async def gather_dep( self.counters["transfer-count"].add(len(response["data"])) self.incoming_count += 1 - self.log.append(("receive-dep", worker, list(response["data"]))) + self.log.append( + ("receive-dep", worker, set(response["data"]), stimulus_id, time()) + ) except OSError: logger.exception("Worker stream died during communication: %s", worker) has_what = self.has_what.pop(worker) self.pending_data_per_worker.pop(worker) - self.log.append(("receive-dep-failed", worker, has_what)) + self.log.append( + ("receive-dep-failed", worker, has_what, stimulus_id, time()) + ) for d in has_what: ts = self.tasks[d] ts.who_has.remove(worker) @@ -2449,176 +2668,112 @@ async def gather_dep( busy = response.get("status", "") == "busy" data = response.get("data", {}) - # FIXME: We should not handle keys which were skipped by this coro. to_gather_keys is only a subset - assert set(to_gather_keys).issubset( - set(self.in_flight_workers.get(worker)) - ) + recommendations = {} + + deps_to_iter = self.in_flight_workers.pop(worker) + + if busy: + self.log.append( + ("busy-gather", worker, to_gather_keys, stimulus_id, time()) + ) - for d in self.in_flight_workers.pop(worker): + for d in deps_to_iter: ts = self.tasks.get(d) - try: - if not busy and d in data: - self.transition(ts, "memory", value=data[d]) - elif ts is None or ts.state == "executing": - self.log.append(("already-executing", d)) - self.release_key(d, reason="already executing at gather") - elif ts.state == "flight" and not ts.dependents: - self.log.append(("flight no-dependents", d)) - self.release_key( - d, reason="In-flight task no longer has dependents." 
- ) - elif ( - not busy - and d not in data - and ts.dependents - and ts.state != "memory" - ): - ts.who_has.discard(worker) - self.has_what[worker].discard(ts.key) - self.log.append(("missing-dep", d)) - self.batched_stream.send( - { - "op": "missing-data", - "errant_worker": worker, - "key": d, - } - ) - self.transition(ts, "fetch") - elif ts.state not in ("ready", "memory"): - self.transition(ts, "fetch") - else: - logger.debug( - "Unexpected task state encountered for %r after gather_dep", - ts, - ) - except Exception as exc: - emsg = error_message(exc) - assert ts is not None, ts - self.log.append( - (ts.key, "except-gather-dep-result", emsg, time()) - ) - # FIXME: We currently cannot release this task and its - # dependent safely - logger.debug( - "Exception occured while handling `gather_dep` response for %r", - ts, - exc_info=True, + assert ts, (d, self.story(d)) + ts.done = True + if d in data: + recommendations[ts] = ("memory", data[d]) + elif not busy: + ts.who_has.discard(worker) + self.has_what[worker].discard(ts.key) + self.log.append(("missing-dep", d)) + self.batched_stream.send( + {"op": "missing-data", "errant_worker": worker, "key": d} ) - if self.validate: - self.validate_state() + if ts.state != "memory" and ts not in recommendations: + recommendations[ts] = "fetch" + del data, response + self.transitions( + recommendations=recommendations, stimulus_id=stimulus_id + ) self.ensure_computing() if not busy: self.repetitively_busy = 0 - self.ensure_communicating() else: # Exponential backoff to avoid hammering scheduler/worker self.repetitively_busy += 1 await asyncio.sleep(0.100 * 1.5 ** self.repetitively_busy) - await self.query_who_has(*to_gather_keys) - self.ensure_communicating() + await self.query_who_has(*to_gather_keys, stimulus_id=stimulus_id) - def bad_dep(self, dep): - exc = ValueError( - "Could not find dependent %s. Check worker logs" % str(dep.key) - ) - for ts in dep.dependents: - msg = error_message(exc) - ts.exception = msg["exception"] - ts.traceback = msg["traceback"] - ts.exception_text = msg["exception_text"] - ts.traceback_text = msg["traceback_text"] - self.transition(ts, "error") - self.release_key(dep.key, reason="bad dep") - - async def handle_missing_dep(self, *deps, **kwargs): - self.log.append(("handle-missing", deps)) - try: - deps = {dep for dep in deps if dep.dependents} - if not deps: - return + self.ensure_communicating() - for dep in deps: - if dep.suspicious_count > 5: - deps.remove(dep) - self.bad_dep(dep) - if not deps: - return + def transition_released_forgotten(self, ts, *, stimulus_id): + recommendations = {} + # Dependents _should_ be released by the scheduler before this + if self.validate: + assert not any(d.state != "forgotten" for d in ts.dependents) + for dep in ts.dependencies: + dep.dependents.discard(ts) + if dep.state == "released" and not dep.dependents: + recommendations[dep] = "forgotten" - for dep in deps: - logger.info( - "Dependent not found: %s %s . 
Asking scheduler", - dep.key, - dep.suspicious_count, - ) + # Mark state as forgotten in case it is still referenced anymore + ts.state = "forgotten" + self.tasks.pop(ts.key, None) + return recommendations, [] - who_has = await retry_operation( - self.scheduler.who_has, keys=list(dep.key for dep in deps) - ) - who_has = {k: v for k, v in who_has.items() if v} - self.update_who_has(who_has) - still_missing = set() - for dep in deps: - dep.suspicious_count += 1 + async def find_missing(self): + with log_errors(): + if not self._missing_dep_flight: + return + try: + if self.validate: + for ts in self._missing_dep_flight: + # If this was collected somewhere else we should've transitioned already, shouldn't we? maybe this is the place, let's see + assert not ts.who_has + + stimulus_id = f"find-missing-{time()}" + who_has = await retry_operation( + self.scheduler.who_has, + keys=[ts.key for ts in self._missing_dep_flight], + ) + who_has = {k: v for k, v in who_has.items() if v} + self.update_who_has(who_has, stimulus_id=stimulus_id) - if not who_has.get(dep.key): - logger.info( - "No workers found for %s", - dep.key, + if self._missing_dep_flight: + logger.debug( + "No new workers found for %s", self._missing_dep_flight ) - self.log.append((dep.key, "no workers found", dep.dependents)) - self.release_key(dep.key, reason="Handle missing no workers") - elif self.address in who_has and dep.state != "memory": - - still_missing.add(dep) - self.batched_stream.send( - { - "op": "release-worker-data", - "keys": [dep.key], - "worker": self.address, - } + recommendations = { + dep: "released" + for dep in self._missing_dep_flight + if dep.state == "missing" + } + self.transitions( + recommendations=recommendations, stimulus_id=stimulus_id ) - else: - logger.debug("New workers found for %s", dep.key) - self.log.append((dep.key, "new workers found")) - for dependent in dep.dependents: - if dep.key in dependent.waiting_for_data: - self.data_needed.append(dependent.key) - if still_missing: - logger.debug( - "Found self referencing who has response from scheduler for keys %s.\n" - "Trying again handle_missing", - deps, - ) - await self.handle_missing_dep(*deps) - except Exception: - logger.error("Handle missing dep failed, retrying", exc_info=True) - retries = kwargs.get("retries", 5) - self.log.append(("handle-missing-failed", retries, deps)) - if retries > 0: - await self.handle_missing_dep(*deps, retries=retries - 1) - else: - raise - finally: - try: - for dep in deps: - self._missing_dep_flight.remove(dep.key) - except KeyError: - pass - self.ensure_communicating() + finally: + # This is quite arbirary but the heartbeat has a scaling implemented + self.periodic_callbacks[ + "find-missing" + ].callback_time = self.periodic_callbacks["heartbeat"].callback_time + self.ensure_communicating() + self.ensure_computing() - async def query_who_has(self, *deps): + async def query_who_has(self, *deps, stimulus_id): with log_errors(): response = await retry_operation(self.scheduler.who_has, keys=deps) - self.update_who_has(response) + self.update_who_has(response, stimulus_id) return response - def update_who_has(self, who_has): + def update_who_has(self, who_has, stimulus_id): try: + recommendations = {} for dep, workers in who_has.items(): if not workers: continue @@ -2632,10 +2787,17 @@ def update_who_has(self, who_has): ) # Do not mutate the input dict. 
That's rude workers = set(workers) - {self.address} - self.tasks[dep].who_has.update(workers) + dep_ts = self.tasks[dep] + dep_ts.who_has.update(workers) + + if dep_ts.state == "missing": + recommendations[dep_ts] = "fetch" for worker in workers: self.has_what[worker].add(dep) + if dep_ts.state in ("fetch", "flight", "missing"): + self.pending_data_per_worker[worker].append(dep_ts.key) + self.transitions(recommendations=recommendations, stimulus_id=stimulus_id) except Exception as e: logger.exception(e) if LOG_PDB: @@ -2661,10 +2823,7 @@ def steal_request(self, key): # If task is marked as "constrained" we haven't yet assigned it an # `available_resources` to run on, that happens in # `transition_constrained_executing` - ts.scheduler_holds_ref = False - self.release_key(ts.key, reason="stolen") - if self.validate: - assert ts.key not in self.tasks + self.transition(ts, "forgotten", stimulus_id=f"steal-request-{time()}") def release_key( self, @@ -2673,16 +2832,15 @@ def release_key( reason: Optional[str] = None, report: bool = True, ): + recommendations = {} try: if self.validate: assert not isinstance(key, TaskState) - ts = self.tasks.get(key, None) - # If the scheduler holds a reference which is usually the - # case when it instructed the task to be computed here or if - # data was scattered we must not release it unless the - # scheduler allow us to. See also handle_delete_data and - if ts is None or ts.scheduler_holds_ref: - return + ts = self.tasks[key] + # needed for legacy notification support + state_before = ts.state + ts.state = "released" + logger.debug( "Release key %s", {"key": key, "cause": cause, "reason": reason} ) @@ -2705,39 +2863,30 @@ def release_key( if key in self.threads: del self.threads[key] - if ts.state == "executing": - self.executing_count -= 1 - if ts.resource_restrictions is not None: if ts.state == "executing": for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity for d in ts.dependencies: - d.dependents.discard(ts) + ts.waiting_for_data.discard(ts) + if not d.dependents and d.state in ("flight", "fetch", "missing"): + recommendations[d] = "released" - if not d.dependents and d.state in ("flight", "fetch"): - self.release_key(d.key, reason="Dependent released") - - if report: - # Inform the scheduler of keys which will have gone missing - # We are releasing them before they have completed - if ts.state in PROCESSING: - # This path is only hit with work stealing - msg = {"op": "release", "key": key, "cause": cause} - else: - # This path is only hit when calling release_key manually - msg = { - "op": "release-worker-data", - "keys": [key], - "worker": self.address, - } - self.batched_stream.send(msg) + ts.waiting_for_data.clear() + ts.nbytes = None + ts._previous = None + ts._next = None + ts.done = False - self._notify_plugins("release_key", key, ts.state, cause, reason, report) - del self.tasks[key] + self._executing.discard(ts) + self._in_flight_tasks.discard(ts) + self._notify_plugins( + "release_key", key, state_before, cause, reason, report + ) except CommClosedError: + # Batched stream send might raise if it was already closed pass except Exception as e: logger.exception(e) @@ -2747,6 +2896,8 @@ def release_key( pdb.set_trace() raise + return recommendations + ################ # Execute Task # ################ @@ -2853,7 +3004,7 @@ def meets_resource_constraints(self, key): return True - async def _maybe_deserialize_task(self, ts): + async def _maybe_deserialize_task(self, ts, stimulus_id): if not 
isinstance(ts.runspec, SerializedTask): return ts.runspec try: @@ -2870,15 +3021,24 @@ async def _maybe_deserialize_task(self, ts): {"action": "deserialize", "start": start, "stop": stop} ) return function, args, kwargs - except Exception: + except Exception as e: logger.error("Could not deserialize task", exc_info=True) self.log.append((ts.key, "deserialize-error")) + emsg = error_message(e) + emsg.pop("status") + self.transition( + ts, + "error", + **emsg, + stimulus_id=stimulus_id, + ) raise def ensure_computing(self): if self.paused: return try: + stimulus_id = f"ensure-computing-{time()}" while self.constrained and self.executing_count < self.nthreads: key = self.constrained[0] ts = self.tasks.get(key, None) @@ -2887,7 +3047,7 @@ def ensure_computing(self): continue if self.meets_resource_constraints(key): self.constrained.popleft() - self.transition(ts, "executing") + self.transition(ts, "executing", stimulus_id=stimulus_id) else: break while self.ready and self.executing_count < self.nthreads: @@ -2899,9 +3059,9 @@ def ensure_computing(self): # to release. If the task has "disappeared" just continue through the heap continue elif ts.key in self.data: - self.transition(ts, "memory") + self.transition(ts, "memory", stimulus_id=stimulus_id) elif ts.state in READY: - self.transition(ts, "executing") + self.transition(ts, "executing", stimulus_id=stimulus_id) except Exception as e: logger.exception(e) if LOG_PDB: @@ -2910,30 +3070,30 @@ def ensure_computing(self): pdb.set_trace() raise - async def execute(self, key): + async def execute(self, key, stimulus_id): if self.status in (Status.closing, Status.closed, Status.closing_gracefully): return - if key not in self.tasks: return - ts = self.tasks[key] - if ts.state != "executing": - # This might happen if keys are canceled - logger.debug( - "Trying to execute a task %s which is not in executing state anymore" - % ts - ) - return - try: + if ts.state == "cancelled": + # This might happen if keys are canceled + logger.debug( + "Trying to execute a task %s which is not in executing state anymore" + % ts + ) + ts.done = True + self.transition(ts, "released", stimulus_id=stimulus_id) + return + if self.validate: assert not ts.waiting_for_data assert ts.state == "executing" assert ts.runspec is not None - function, args, kwargs = await self._maybe_deserialize_task(ts) + function, args, kwargs = await self._maybe_deserialize_task(ts, stimulus_id) args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs) @@ -2979,40 +3139,29 @@ async def execute(self, key): finally: self.active_keys.discard(ts.key) - # We'll need to check again for the task state since it may have - # changed since the execution was kicked off. In particular, it may - # have been canceled and released already in which case we'll have - # to drop the result immediately - - if ts.key not in self.tasks: - logger.debug( - "Dropping result for %s since task has already been released." - % ts.key - ) - return - + key = ts.key + # Key *must* be still in tasks. 
Releasing it directly is forbidden
+            # without going through cancelled
+            ts = self.tasks.get(key)
+            assert ts, self.story(ts)
+            ts.done = True
+            result: dict
             result["key"] = ts.key
             value = result.pop("result", None)
             ts.startstops.append(
                 {"action": "compute", "start": result["start"], "stop": result["stop"]}
             )
             self.threads[ts.key] = result["thread"]
-
+            recommendations = {}
             if result["op"] == "task-finished":
                 ts.nbytes = result["nbytes"]
                 ts.type = result["type"]
-                self.transition(ts, "memory", value=value)
+                recommendations[ts] = ("memory", value)
                 if self.digests is not None:
                     self.digests["task-duration"].add(result["stop"] - result["start"])
             elif isinstance(result.pop("actual-exception"), Reschedule):
-                self.batched_stream.send({"op": "reschedule", "key": ts.key})
-                self.transition(ts, "rescheduled", report=False)
-                self.release_key(ts.key, report=False, reason="Reschedule")
+                recommendations[ts] = "rescheduled"
             else:
-                ts.exception = result["exception"]
-                ts.traceback = result["traceback"]
-                ts.exception_text = result["exception_text"]
-                ts.traceback_text = result["traceback_text"]
                 logger.warning(
                     "Compute Failed\n"
                     "Function: %s\n"
@@ -3024,7 +3173,15 @@ async def execute(self, key):
                     convert_kwargs_to_str(kwargs2, max_len=1000),
                     result["exception"].data,
                 )
-                self.transition(ts, "error")
+                recommendations[ts] = (
+                    "error",
+                    result["exception"],
+                    result["traceback"],
+                    result["exception_text"],
+                    result["traceback_text"],
+                )
+
+            self.transitions(recommendations, stimulus_id=stimulus_id)
 
             logger.debug("Send compute response to scheduler: %s, %s", ts.key, result)
 
@@ -3033,15 +3190,18 @@ async def execute(self, key):
                 assert not ts.waiting_for_data
 
         except Exception as exc:
+            assert ts
             logger.error(
                 "Exception during execution of task %s.", ts.key, exc_info=True
             )
             emsg = error_message(exc)
-            ts.exception = emsg["exception"]
-            ts.traceback = emsg["traceback"]
-            ts.exception_text = emsg["exception_text"]
-            ts.traceback_text = emsg["traceback_text"]
-            self.transition(ts, "error")
+            emsg.pop("status")
+            self.transition(
+                ts,
+                "error",
+                **emsg,
+                stimulus_id=stimulus_id,
+            )
         finally:
             self.ensure_computing()
             self.ensure_communicating()
@@ -3289,6 +3449,15 @@ def get_call_stack(self, comm=None, keys=None):
     def _notify_plugins(self, method_name, *args, **kwargs):
         for name, plugin in self.plugins.items():
             if hasattr(plugin, method_name):
+                if method_name == "release_key":
+                    warnings.warn(
+                        """
+The `WorkerPlugin.release_key` hook is deprecated and will be removed in a future version.
+A similar event can now be caught by filtering for a `finish=='released'` event in the `WorkerPlugin.transition` hook. 
+""", + DeprecationWarning, + ) + try: getattr(plugin, method_name)(*args, **kwargs) except Exception: @@ -3312,6 +3481,9 @@ def validate_task_executing(self, ts): assert ts.runspec is not None assert ts.key not in self.data assert not ts.waiting_for_data + assert all(ts.state == "memory" for ts in ts.dependencies), [ + self.story(t) for t in ts.dependencies if ts.state != "memory" + ] assert all( dep.key in self.data or dep.key in self.actors for dep in ts.dependencies ) @@ -3333,41 +3505,62 @@ def validate_task_waiting(self, ts): def validate_task_flight(self, ts): assert ts.key not in self.data + assert ts in self._in_flight_tasks assert not any(dep.key in self.ready for dep in ts.dependents) assert ts.coming_from assert ts.coming_from in self.in_flight_workers assert ts.key in self.in_flight_workers[ts.coming_from] def validate_task_fetch(self, ts): - assert ts.runspec is None assert ts.key not in self.data - assert self.address not in ts.who_has #!!!!!!!! - # FIXME This is currently not an invariant since upon comm failure we - # remove the erroneous worker from all who_has and correct the state - # upon the next ensure_communicate - - # if not ts.who_has: - # # If we do not know who_has for a fetch task, it must be logged in - # # the missing dep. There should be a handle_missing_dep running for - # # all of these keys - - # assert ts.key in self._missing_dep_flight, ( - # ts.key, - # self.story(ts), - # self._missing_dep_flight.copy(), - # self.in_flight_workers.copy(), - # ) - assert ts.dependents - + assert self.address not in ts.who_has for w in ts.who_has: assert ts.key in self.has_what[w] + def validate_task_missing(self, ts): + assert ts.key not in self.data + assert not ts.who_has + assert not any(ts.key in has_what for has_what in self.has_what.values()) + assert ts.key in self._missing_dep_flight + + def validate_task_cancelled(self, ts): + assert ts.key not in self.data + assert ts._previous + + def validate_task_resumed(self, ts): + assert ts.key not in self.data + assert ts._next + assert ts._previous + + def validate_task_released(self, ts): + assert ts.key not in self.data + assert not ts._next + assert not ts._previous + assert ts not in self._executing + assert ts not in self._in_flight_tasks + assert ts not in self._missing_dep_flight + assert ts not in self._missing_dep_flight + assert not ts.who_has + assert not any(ts.key in has_what for has_what in self.has_what.values()) + assert not ts.waiting_for_data + assert not ts.done + assert not ts.exception + assert not ts.traceback + def validate_task(self, ts): try: + if ts.key in self.tasks: + assert self.tasks[ts.key] == ts if ts.state == "memory": self.validate_task_memory(ts) elif ts.state == "waiting": self.validate_task_waiting(ts) + elif ts.state == "missing": + self.validate_task_missing(ts) + elif ts.state == "cancelled": + self.validate_task_cancelled(ts) + elif ts.state == "resumed": + self.validate_task_resumed(ts) elif ts.state == "ready": self.validate_task_ready(ts) elif ts.state == "executing": @@ -3376,6 +3569,8 @@ def validate_task(self, ts): self.validate_task_flight(ts) elif ts.state == "fetch": self.validate_task_fetch(ts) + elif ts.state == "released": + self.validate_task_released(ts) except Exception as e: logger.exception(e) if LOG_PDB: @@ -3388,6 +3583,8 @@ def validate_state(self): if self.status != Status.running: return try: + assert self.executing_count >= 0 + waiting_for_data_count = 0 for ts in self.tasks.values(): assert ts.state is not None # check that worker has task @@ -3402,19 
+3599,21 @@ def validate_state(self): # Might need better bookkeeping assert dep.state is not None assert ts in dep.dependents, ts - for key in ts.waiting_for_data: - ts_wait = self.tasks[key] + if ts.waiting_for_data: + waiting_for_data_count += 1 + for ts_wait in ts.waiting_for_data: + assert ts_wait.key in self.tasks assert ( - ts_wait.state == "flight" - or ts_wait.state == "fetch" + ts_wait.state + in ("ready", "executing", "flight", "fetch", "missing") or ts_wait.key in self._missing_dep_flight or ts_wait.who_has.issubset(self.in_flight_workers) - ) + ), (ts, ts_wait, self.story(ts), self.story(ts_wait)) if ts.state == "memory": assert isinstance(ts.nbytes, int) assert not ts.waiting_for_data assert ts.key in self.data or ts.key in self.actors - + assert self.waiting_for_data_count == waiting_for_data_count for worker, keys in self.has_what.items(): for k in keys: assert worker in self.tasks[k].who_has @@ -3650,6 +3849,7 @@ def secede(): worker.maybe_transition_long_running, worker.tasks[thread_state.key], compute_duration=duration, + stimulus_id=f"secede-{thread_state.key}-{time()}", ) diff --git a/distributed/worker_client.py b/distributed/worker_client.py index 989a3f8f0d..059d5dfaad 100644 --- a/distributed/worker_client.py +++ b/distributed/worker_client.py @@ -3,6 +3,8 @@ import dask +from distributed.metrics import time + from .threadpoolexecutor import rejoin, secede from .worker import get_client, get_worker, thread_state @@ -50,9 +52,14 @@ def worker_client(timeout=None, separate_thread=True): worker = get_worker() client = get_client(timeout=timeout) if separate_thread: + duration = time() - thread_state.start_time secede() # have this thread secede from the thread pool worker.loop.add_callback( - worker.transition, worker.tasks[thread_state.key], "long-running" + worker.transition, + worker.tasks[thread_state.key], + "long-running", + stimulus_id=f"worker-client-secede-{time()}", + compute_duration=duration, ) yield client diff --git a/setup.cfg b/setup.cfg index f77b02739c..3c07c5e90b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -59,6 +59,4 @@ markers = # the MacOS GitHub CI (although it's been reported to work on MacBooks). # The CI script modifies this config file on the fly on Linux. timeout_method = thread -# This should not be reduced; Windows CI has been observed to be occasionally -# exceptionally slow. timeout = 300 From 8f1847993c8c2781496420c1fa35b586f29d4b9b Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 2 Sep 2021 11:51:55 +0100 Subject: [PATCH 2/6] Code review --- distributed/cfexecutor.py | 4 - distributed/scheduler.py | 9 +- distributed/tests/test_cancelled_state.py | 171 ++----- distributed/tests/test_steal.py | 4 +- distributed/tests/test_stress.py | 1 - distributed/tests/test_worker.py | 53 +-- distributed/worker.py | 556 +++++++++------------- setup.cfg | 2 + 8 files changed, 309 insertions(+), 491 deletions(-) diff --git a/distributed/cfexecutor.py b/distributed/cfexecutor.py index 55d9ccc563..8028a4bc7f 100644 --- a/distributed/cfexecutor.py +++ b/distributed/cfexecutor.py @@ -127,10 +127,6 @@ def map(self, fn, *iterables, **kwargs): raise TypeError("unexpected arguments to map(): %s" % sorted(kwargs)) fs = self._client.map(fn, *iterables, **self._kwargs) - if isinstance(fs, list): - # Below iterator relies on this being a generator to cancel - # remaining futures - fs = (val for val in fs) # Yield must be hidden in closure so that the tasks are submitted # before the first iterator value is required. 
diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 667464dc3a..630da815fb 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5417,13 +5417,12 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): def release_worker_data(self, comm=None, key=None, worker=None): parent: SchedulerState = cast(SchedulerState, self) - if worker not in parent._workers_dv: + ws: WorkerState = parent._workers_dv.get(worker) + ts: TaskState = parent._tasks.get(key) + if not ws or not ts: return - ws: WorkerState = parent._workers_dv[worker] - ts: TaskState - ts = parent._tasks.get(key) recommendations: dict = {} - if ts and ts in ws._has_what: + if ts in ws._has_what: del ws._has_what[ts] ws._nbytes -= ts.get_nbytes() wh: set = ts._who_has diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index b0906fc983..83cc7a8420 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -1,152 +1,82 @@ import asyncio from unittest import mock -import pytest - -from distributed import Nanny, wait +import distributed +from distributed import Nanny from distributed.core import CommClosedError from distributed.utils_test import _LockedCommPool, gen_cluster, inc, slowinc -pytestmark = pytest.mark.ci1 - - -@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) -async def test_abort_execution_release(c, s, w): - fut = c.submit(slowinc, 1, delay=0.5, key="f1") - - async def wait_for_exec(dask_worker): - while ( - fut.key not in dask_worker.tasks - or dask_worker.tasks[fut.key].state != "executing" - ): - await asyncio.sleep(0.01) - - await c.run(wait_for_exec) - fut.release() - fut2 = c.submit(inc, 1, key="f2") +async def wait_for_state(key, state, dask_worker): + while key not in dask_worker.tasks or dask_worker.tasks[key].state != state: + await asyncio.sleep(0.005) - async def observe(dask_worker): - cancelled = False - while ( - fut.key in dask_worker.tasks - and dask_worker.tasks[fut.key].state != "released" - ): - if dask_worker.tasks[fut.key].state == "cancelled": - cancelled = True - await asyncio.sleep(0.005) - return cancelled - assert await c.run(observe) - await fut2 - del fut2 +async def wait_for_cancelled(key, dask_worker): + while key in dask_worker.tasks: + if dask_worker.tasks[key].state == "cancelled": + return + await asyncio.sleep(0.005) + assert False @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) -async def test_abort_execution_reschedule(c, s, w): +async def test_abort_execution_release(c, s, a): fut = c.submit(slowinc, 1, delay=1) - - async def wait_for_exec(dask_worker): - while ( - fut.key not in dask_worker.tasks - or dask_worker.tasks[fut.key].state != "executing" - ): - await asyncio.sleep(0.01) - - await c.run(wait_for_exec) - + await c.run(wait_for_state, fut.key, "executing") fut.release() + await c.run(wait_for_cancelled, fut.key) - async def observe(dask_worker): - while ( - fut.key in dask_worker.tasks - and dask_worker.tasks[fut.key].state != "released" - ): - if dask_worker.tasks[fut.key].state == "cancelled": - return - await asyncio.sleep(0.005) - assert await c.run(observe) +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_abort_execution_reschedule(c, s, a): fut = c.submit(slowinc, 1, delay=1) + await c.run(wait_for_state, fut.key, "executing") + fut.release() + await c.run(wait_for_cancelled, fut.key) + fut = c.submit(slowinc, 1, delay=0.1) await fut @gen_cluster(client=True, nthreads=[("", 
1)], Worker=Nanny) -async def test_abort_execution_add_as_dependency(c, s, w): +async def test_abort_execution_add_as_dependency(c, s, a): fut = c.submit(slowinc, 1, delay=1) - - async def wait_for_exec(dask_worker): - while ( - fut.key not in dask_worker.tasks - or dask_worker.tasks[fut.key].state != "executing" - ): - await asyncio.sleep(0.01) - - await c.run(wait_for_exec) - + await c.run(wait_for_state, fut.key, "executing") fut.release() + await c.run(wait_for_cancelled, fut.key) - async def observe(dask_worker): - while ( - fut.key in dask_worker.tasks - and dask_worker.tasks[fut.key].state != "released" - ): - if dask_worker.tasks[fut.key].state == "cancelled": - return - await asyncio.sleep(0.005) - - assert await c.run(observe) fut = c.submit(slowinc, 1, delay=1) fut = c.submit(slowinc, fut, delay=1) await fut -@gen_cluster(client=True, nthreads=[("", 1)] * 2, Worker=Nanny) +@gen_cluster(client=True, Worker=Nanny) async def test_abort_execution_to_fetch(c, s, a, b): fut = c.submit(slowinc, 1, delay=2, key="f1", workers=[a.worker_address]) - - async def wait_for_exec(dask_worker): - while ( - fut.key not in dask_worker.tasks - or dask_worker.tasks[fut.key].state != "executing" - ): - await asyncio.sleep(0.01) - - await c.run(wait_for_exec, workers=[a.worker_address]) - + await c.run(wait_for_state, fut.key, "executing", workers=[a.worker_address]) fut.release() - - async def observe(dask_worker): - while ( - fut.key in dask_worker.tasks - and dask_worker.tasks[fut.key].state != "released" - ): - if dask_worker.tasks[fut.key].state == "cancelled": - return - await asyncio.sleep(0.005) - - assert await c.run(observe) + await c.run(wait_for_cancelled, fut.key, workers=[a.worker_address]) # While the first worker is still trying to compute f1, we'll resubmit it to # another worker with a smaller delay. The key is still the same - fut = c.submit(slowinc, 1, delay=0, key="f1", workers=[b.worker_address]) + fut = c.submit(inc, 1, key="f1", workers=[b.worker_address]) # then, a must switch the execute to fetch. Instead of doing so, it will # simply re-use the currently computing result. - fut = c.submit(slowinc, fut, delay=1, workers=[a.worker_address], key="f2") + fut = c.submit(inc, fut, workers=[a.worker_address], key="f2") await fut -@gen_cluster(client=True, nthreads=[("", 1)] * 2) -async def test_worker_find_missing(c, s, *workers): - - fut = c.submit(slowinc, 1, delay=0.5, workers=[workers[0].address]) - await wait(fut) - # We do not want to use proper API since the API usually ensures that the cluster is informed properly. 
We want to - del workers[0].data[fut.key] - del workers[0].tasks[fut.key] +@gen_cluster(client=True) +async def test_worker_find_missing(c, s, a, b): + fut = c.submit(inc, 1, workers=[a.address]) + await fut + # We do not want to use proper API since it would ensure that the cluster is + # informed properly + del a.data[fut.key] + del a.tasks[fut.key] - # Actually no worker has the data, the scheduler is supposed to reschedule - await c.submit(inc, fut, workers=[workers[1].address]) + # Actually no worker has the data; the scheduler is supposed to reschedule + assert await c.submit(inc, fut, workers=[b.address]) == 3 @gen_cluster(client=True) @@ -159,8 +89,8 @@ async def test_worker_stream_died_during_comm(c, s, a, b): write_event=write_event, ) fut = c.submit(inc, 1, workers=[a.address], allow_other_workers=True) - await wait(fut) - # Actually no worker has the data, the scheduler is supposed to reschedule + await fut + # Actually no worker has the data; the scheduler is supposed to reschedule res = c.submit(inc, fut, workers=[b.address]) await write_queue.get() @@ -173,10 +103,6 @@ async def test_worker_stream_died_during_comm(c, s, a, b): @gen_cluster(client=True) async def test_flight_to_executing_via_cancelled_resumed(c, s, a, b): - import asyncio - - import distributed - lock = asyncio.Lock() await lock.acquire() @@ -190,23 +116,13 @@ async def wait_and_raise(*args, **kwargs): side_effect=wait_and_raise, ): fut1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True) - fut2 = c.submit(inc, fut1, workers=[b.address]) - async def observe(dask_worker): - while ( - fut1.key not in dask_worker.tasks - or dask_worker.tasks[fut1.key].state != "flight" - ): - await asyncio.sleep(0) - - await c.run(observe, workers=[b.address]) + await wait_for_state(fut1.key, "flight", b) # Close in scheduler to ensure we transition and reschedule task properly await s.close_worker(worker=a.address) - - while fut1.key not in b.tasks or b.tasks[fut1.key].state != "resumed": - await asyncio.sleep(0) + await wait_for_state(fut1.key, "resumed", b) lock.release() assert await fut2 == 3 @@ -214,6 +130,5 @@ async def observe(dask_worker): b_story = b.story(fut1.key) assert any("receive-dep-failed" in msg for msg in b_story) assert any("missing-dep" in msg for msg in b_story) - if not any("cancelled" in msg for msg in b_story): - breakpoint() + assert any("cancelled" in msg for msg in b_story) assert any("resumed" in msg for msg in b_story) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 3b4793523b..06fd0e6187 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -683,12 +683,10 @@ async def test_dont_steal_already_released(c, s, a, b): del future - # In case the system is slow (e.g. 
network) ensure that nothing bad happens - # if the key was already released while key in a.tasks and a.tasks[key].state != "released": await asyncio.sleep(0.05) - a.steal_request(key) + a.handle_steal_request(key) assert len(a.batched_stream.buffer) == 1 msg = a.batched_stream.buffer[0] assert msg["op"] == "steal-response" diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index c11353e1df..063281f071 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -173,7 +173,6 @@ def vsum(*args): @pytest.mark.avoid_ci @pytest.mark.slow -@pytest.mark.timeout(1100) # Override timeout from setup.cfg @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 80) async def test_stress_communication(c, s, *workers): s.validate = False # very slow otherwise diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index aa4670608c..aa14e2cc95 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2562,7 +2562,7 @@ def __call__(self, *args, **kwargs): await asyncio.sleep(0) ts = s.tasks[fut.key] - a.steal_request(fut.key) + a.handle_steal_request(fut.key) stealing_ext.scheduler.send_task_to_worker(b.address, ts) fut2 = c.submit(inc, fut, workers=[a.address]) @@ -2649,17 +2649,14 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b): await fut2 -def _acquire_replica(scheduler, worker, future): - if not isinstance(future, list): - keys = [future.key] - else: - keys = [f.key for f in future] +def _acquire_replicas(scheduler, worker, *futures): + keys = [f.key for f in futures] scheduler.stream_comms[worker.address].send( { - "op": "acquire-replica", + "op": "acquire-replicas", "keys": keys, - "stimulus_id": time(), + "stimulus_id": f"acquire-replicas-{time()}", "priorities": {key: scheduler.tasks[key].priority for key in keys}, "who_has": { key: {w.address for w in scheduler.tasks[key].who_has} for key in keys @@ -2668,33 +2665,30 @@ def _acquire_replica(scheduler, worker, future): ) -def _remove_replica(scheduler, worker, future): - if not isinstance(future, list): - keys = [future.key] - else: - keys = [f.key for f in future] +def _remove_replicas(scheduler, worker, *futures): + keys = [f.key for f in futures] scheduler.stream_comms[worker.address].send( { "op": "remove-replicas", "keys": keys, - "stimulus_id": time(), + "stimulus_id": f"remove-replicas-{time()}", } ) @gen_cluster(client=True) -async def test_acquire_replica(c, s, a, b): +async def test_acquire_replicas(c, s, a, b): fut = c.submit(inc, 1, workers=[a.address]) await fut - _acquire_replica(s, b, fut) + _acquire_replicas(s, b, fut) - while not len(s.who_has[fut.key]) == 2: + while len(s.who_has[fut.key]) != 2: await asyncio.sleep(0.005) - for w in [a, b]: - assert fut.key in w.tasks + for w in (a, b): + assert w.data[fut.key] == 2 assert w.tasks[fut.key].state == "memory" fut.release() @@ -2704,13 +2698,13 @@ async def test_acquire_replica(c, s, a, b): @gen_cluster(client=True) -async def test_acquire_replica_same_channel(c, s, a, b): +async def test_acquire_replicas_same_channel(c, s, a, b): fut = c.submit(inc, 1, workers=[a.address], key="f-replica") futB = c.submit(inc, 2, workers=[a.address], key="f-B") futC = c.submit(inc, futB, workers=[b.address], key="f-C") await fut - _acquire_replica(s, b, fut) + _acquire_replicas(s, b, fut) await futC while fut.key not in b.tasks: @@ -2727,14 +2721,14 @@ async def test_acquire_replica_same_channel(c, s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) -async 
def test_acquire_replica_many(c, s, *workers): +async def test_acquire_replicas_many(c, s, *workers): futs = c.map(inc, range(10), workers=[workers[0].address]) res = c.submit(sum, futs, workers=[workers[1].address]) final = c.submit(slowinc, res, delay=0.5, workers=[workers[1].address]) await wait(futs) - _acquire_replica(s, workers[2], futs) + _acquire_replicas(s, workers[2], *futs) # Worker 2 should normally not even be involved if there was no replication while not all( @@ -2758,12 +2752,12 @@ async def test_acquire_replica_many(c, s, *workers): async def test_remove_replica_simple(c, s, a, b): futs = c.map(inc, range(10), workers=[a.address]) await wait(futs) - _acquire_replica(s, b, futs) + _acquire_replicas(s, b, *futs) while not all(len(s.tasks[f.key].who_has) == 2 for f in futs): await asyncio.sleep(0.01) - _remove_replica(s, b, futs) + _remove_replicas(s, b, *futs) while b.tasks: await asyncio.sleep(0.01) @@ -2793,15 +2787,15 @@ def reduce(*args, **kwargs): while not all(fut.done() for fut in intermediate): # The worker should reject all of these since they are required - _remove_replica(s, w, futs) - _remove_replica(s, w, intermediate) + _remove_replicas(s, w, *futs) + _remove_replicas(s, w, *intermediate) await asyncio.sleep(0.001) await wait(intermediate) # Since intermediate is done, futs replicas may be removed. # They might be already gone due to the above remove replica calls - _remove_replica(s, w, futs) + _remove_replicas(s, w, *futs) # the intermediate tasks should not be touched because they are still needed # (the scheduler should not have made the above call but we should be safe # regarless) @@ -2823,13 +2817,12 @@ def reduce(*args, **kwargs): @gen_cluster(client=True, nthreads=[("", 1)] * 3) async def test_who_has_consistent_remove_replica(c, s, *workers): - a = workers[0] other_workers = {w for w in workers if w != a} f1 = c.submit(inc, 1, key="f1", workers=[w.address for w in other_workers]) await wait(f1) for w in other_workers: - _acquire_replica(s, w, f1) + _acquire_replicas(s, w, f1) while not len(s.tasks[f1.key].who_has) == len(other_workers): await asyncio.sleep(0) diff --git a/distributed/worker.py b/distributed/worker.py index 8324df16e8..2fe763f6f4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -80,7 +80,7 @@ no_value = "--no-value-sentinel--" -PROCESSING = ( +PROCESSING = { "waiting", "ready", "constrained", @@ -88,8 +88,8 @@ "long-running", "cancelled", "resumed", -) -READY = ("ready", "constrained") +} +READY = {"ready", "constrained"} DEFAULT_EXTENSIONS = [PubSubWorkerExtension] @@ -185,7 +185,7 @@ def __init__(self, key, runspec=None): self.traceback_text = "" self.type = None self.suspicious_count = 0 - self.startstops = list() + self.startstops = [] self.start_time = None self.stop_time = None self.metadata = {} @@ -430,16 +430,16 @@ def __init__( lifetime_restart=None, **kwargs, ): - self.tasks = dict() + self.tasks = {} self.waiting_for_data_count = 0 self.has_what = defaultdict(set) self.pending_data_per_worker = defaultdict(deque) self.nanny = nanny self._lock = threading.Lock() - self.data_needed = list() + self.data_needed = [] - self.in_flight_workers = dict() + self.in_flight_workers = {} self.total_out_connections = dask.config.get( "distributed.worker.connections.outgoing" ) @@ -450,10 +450,10 @@ def __init__( self.comm_nbytes = 0 self._missing_dep_flight = set() - self.threads = dict() + self.threads = {} self.active_threads_lock = threading.Lock() - self.active_threads = dict() + self.active_threads = {} 
self.active_keys = set() self.profile_keys = defaultdict(profile.create) self.profile_keys_history = deque(maxlen=3600) @@ -462,7 +462,7 @@ def __init__( self.generation = 0 - self.ready = list() + self.ready = [] self.constrained = deque() self._executing = set() self._in_flight_tasks = set() @@ -486,14 +486,14 @@ def __init__( ("cancelled", "forgotten"): self.transition_cancelled_forgotten, ("cancelled", "memory"): self.transition_cancelled_memory, ("cancelled", "error"): self.transition_generic_error, - ("resumed", "memory"): self._transition_to_memory_generic, + ("resumed", "memory"): self.transition_generic_memory, ("resumed", "error"): self.transition_generic_error, - ("resumed", "released"): self.transition_released_generic, + ("resumed", "released"): self.transition_generic_released, ("resumed", "waiting"): self.transition_rescheduled_next, ("resumed", "fetch"): self.transition_rescheduled_next, ("constrained", "executing"): self.transition_constrained_executing, ("constrained", "released"): self.transition_constrained_released, - ("error", "released"): self.transition_released_generic, + ("error", "released"): self.transition_generic_released, ("executing", "error"): self.transition_executing_error, ("executing", "long-running"): self.transition_executing_long_running, ("executing", "memory"): self.transition_executing_memory, @@ -506,7 +506,7 @@ def __init__( ("flight", "fetch"): self.transition_flight_fetch, ("flight", "memory"): self.transition_flight_memory, ("flight", "released"): self.transition_flight_released, - ("long-running", "error"): self.transition_long_running_error, + ("long-running", "error"): self.transition_generic_error, ("long-running", "memory"): self.transition_long_running_memory, ("long-running", "rescheduled"): self.transition_executing_rescheduled, ("long-running", "released"): self.transition_executing_released, @@ -516,7 +516,7 @@ def __init__( ("missing", "error"): self.transition_generic_error, ("ready", "error"): self.transition_generic_error, ("ready", "executing"): self.transition_ready_executing, - ("ready", "released"): self.transition_released_generic, + ("ready", "released"): self.transition_generic_released, ("released", "error"): self.transition_generic_error, ("released", "fetch"): self.transition_released_fetch, ("released", "forgotten"): self.transition_released_forgotten, @@ -615,7 +615,7 @@ def __init__( self.available_resources = (resources or {}).copy() self.death_timeout = parse_timedelta(death_timeout) - self.extensions = dict() + self.extensions = {} if silence_logs: silence_logging(level=silence_logs) @@ -668,7 +668,7 @@ def __init__( or sys.maxsize, ) else: - self.data = dict() + self.data = {} self.actors = {} self.loop = loop or IOLoop.current() @@ -699,7 +699,7 @@ def __init__( self.batched_stream = BatchedSend(interval="2ms", loop=self.loop) self.name = name self.scheduler_delay = 0 - self.stream_comms = dict() + self.stream_comms = {} self.heartbeat_active = False self._ipython_kernel = None @@ -746,12 +746,12 @@ def __init__( stream_handlers = { "close": self.close, - "cancel-compute": self.cancel_compute, - "acquire-replica": self.handle_acquire_replica, - "compute-task": self.compute_task, + "cancel-compute": self.handle_cancel_compute, + "acquire-replicas": self.handle_acquire_replicas, + "compute-task": self.handle_compute_task, "free-keys": self.handle_free_keys, "remove-replicas": self.handle_remove_replicas, - "steal-request": self.steal_request, + "steal-request": self.handle_steal_request, } super().__init__( @@ 
-1529,36 +1529,32 @@ def update_data( self, comm=None, data=None, report=True, serializers=None, stimulus_id=None ): if stimulus_id is None: - stimulus_id = "update-data" + stimulus_id = f"update-data-{time()}" recommendations = {} scheduler_messages = [] for key, value in data.items(): - ts = self.tasks.get(key) - if getattr(ts, "state", None) is not None: + try: + ts = self.tasks[key] recommendations[ts] = ("memory", value) - else: + except KeyError: self.tasks[key] = ts = TaskState(key) - recommendations, s_msgs = self._put_key_in_memory( + recs, smsgs = self._put_key_in_memory( ts, value, stimulus_id=stimulus_id ) - scheduler_messages.extend(s_msgs) + recommendations.update(recs) + scheduler_messages += smsgs ts.priority = None ts.duration = None self.log.append((key, "receive-from-scatter")) if report: - scheduler_messages.append( - { - "op": "add-keys", - "keys": list(data), - } - ) + scheduler_messages.append({"op": "add-keys", "keys": list(data)}) + self.transitions(recommendations, stimulus_id=stimulus_id) for msg in scheduler_messages: self.batched_stream.send(msg) - info = {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} - return info + return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} def handle_free_keys(self, comm=None, keys=None, reason=None): """ @@ -1575,23 +1571,21 @@ def handle_free_keys(self, comm=None, keys=None, reason=None): recommendations = {} for key in keys: ts = self.tasks.get(key) - if ts is not None: - if not ts.dependents: - recommendations[ts] = "forgotten" - else: - recommendations[ts] = "released" + if ts: + recommendations[ts] = "released" if ts.dependents else "forgotten" self.transitions(recommendations, stimulus_id=reason) def handle_remove_replicas(self, keys, stimulus_id): - """Stream handler notifying the worker that it might be holding unreferenced, superfluous data. + """Stream handler notifying the worker that it might be holding unreferenced, + superfluous data. - This should not actually happen during ordinary operations and is only - intended to correct any erroneous state. An example where this is - necessary is if a worker fetches data for a downstream task but that - task is released before the data arrives. - In this case, the scheduler will notify the worker that it may be - holding this unnecessary data, if the worker hasn't released the data itself, already. + This should not actually happen during ordinary operations and is only intended + to correct any erroneous state. An example where this is necessary is if a + worker fetches data for a downstream task but that task is released before the + data arrives. In this case, the scheduler will notify the worker that it may be + holding this unnecessary data, if the worker hasn't released the data itself, + already. 
This handler does not guarantee the task nor the data to be actually released but only asks the worker to release the data on a best effort @@ -1601,15 +1595,13 @@ def handle_remove_replicas(self, keys, stimulus_id): For stronger guarantees, see handler free_keys """ - self.log.append(("remove-replica", keys, stimulus_id)) + self.log.append(("remove-replicas", keys, stimulus_id)) recommendations = {} - for key in list(keys): + for key in keys: ts = self.tasks.get(key) if ts and not ts.is_protected(): - if not ts.dependents: - recommendations[ts] = "forgotten" - else: - recommendations[ts] = "released" + recommendations[ts] = "released" if ts.dependents else "forgotten" + self.transitions(recommendations=recommendations, stimulus_id=stimulus_id) return "OK" @@ -1632,7 +1624,7 @@ async def set_resources(self, **resources): # Task Management # ################### - def cancel_compute(self, key, reason): + def handle_cancel_compute(self, key, reason): """ Cancel a task on a best effort basis. This is only possible while a task is in state `waiting` or `ready`. @@ -1648,19 +1640,19 @@ def cancel_compute(self, key, reason): assert not ts.dependents self.transition(ts, "released", stimulus_id=reason) - def handle_acquire_replica( + def handle_acquire_replicas( self, comm=None, keys=None, priorities=None, who_has=None, stimulus_id=None ): recommendations = {} scheduler_msgs = [] for k in keys: - recs, s_msgs = self.register_acquire_internal( + recs, smsgs = self.register_acquire_internal( k, stimulus_id=stimulus_id, priority=priorities[k], ) - scheduler_msgs.extend(s_msgs) recommendations.update(recs) + scheduler_msgs += smsgs self.update_who_has(who_has, stimulus_id=stimulus_id) @@ -1669,20 +1661,19 @@ def handle_acquire_replica( self.transitions(recommendations, stimulus_id=stimulus_id) def register_acquire_internal(self, key, priority, stimulus_id): - if key in self.tasks: + try: + ts = self.tasks[key] logger.debug( - "Data task already known %s", - {"task": self.tasks[key], "stimulus_id": stimulus_id}, + "Data task already known %s", {"task": ts, "stimulus_id": stimulus_id} ) - ts = self.tasks[key] - else: + except KeyError: self.tasks[key] = ts = TaskState(key) self.log.append((key, "register-replica", ts.state, stimulus_id, time())) ts.priority = ts.priority or priority + recommendations = {} scheduler_msgs = [] - if ts.state in ("released", "cancelled", "error"): recommendations[ts] = "fetch" @@ -1709,13 +1700,13 @@ def transition_table_to_dot(self, filename="worker-transitions", format=None): c.attr(style="filled", color="lightgrey") c.node_attr.update(style="filled", color="white") c.attr(label="executable") - for state in [ + for state in ( "waiting", "ready", "executing", "constrained", "long-running", - ]: + ): c.node(state, label=state) seen.add(state) @@ -1725,17 +1716,10 @@ def transition_table_to_dot(self, filename="worker-transitions", format=None): c.node(state, label=state) seen.add(state) - # c.node("released", label="released", ports="n") - # seen.add('released') - - for state in all_states: - continue - # if state in seen: - g.node(state, label=state) g.edges(self._transitions_table.keys()) return graphviz_to_file(g, filename=filename, format=format) - def compute_task( + def handle_compute_task( self, *, key, @@ -1754,13 +1738,13 @@ def compute_task( stimulus_id=None, ): self.log.append((key, "compute-task", stimulus_id, time())) - if key in self.tasks: + try: + ts = self.tasks[key] logger.debug( "Asked to compute an already known task %s", - {"task": self.tasks[key], 
"stimulus_id": stimulus_id}, + {"task": ts, "stimulus_id": stimulus_id}, ) - ts = self.tasks[key] - else: + except KeyError: self.tasks[key] = ts = TaskState(key) ts.runspec = SerializedTask(function, args, kwargs, task) @@ -1784,30 +1768,29 @@ def compute_task( recommendations = {} scheduler_msgs = [] - for dependency, _ in who_has.items(): - recs, s_msgs = self.register_acquire_internal( + for dependency in who_has: + recs, smsgs = self.register_acquire_internal( key=dependency, stimulus_id=stimulus_id, priority=priority, ) recommendations.update(recs) - scheduler_msgs.extend(s_msgs) + scheduler_msgs += smsgs dep_ts = self.tasks[dependency] # link up to child / parents ts.dependencies.add(dep_ts) dep_ts.dependents.add(ts) - if ts.state in ("ready", "executing", "waiting"): + if ts.state in {"ready", "executing", "waiting"}: pass elif ts.state == "memory": recommendations[ts] = "memory" scheduler_msgs.append(self.get_task_state_for_scheduler(ts)) - elif ts.state in ("released", "fetch", "flight", "missing"): - recommendations[ts] = "waiting" - elif ts.state == "cancelled": + elif ts.state in {"released", "fetch", "flight", "missing", "cancelled"}: recommendations[ts] = "waiting" else: + # FIXME Either remove exception or handle resumed raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}") for msg in scheduler_msgs: @@ -1816,7 +1799,7 @@ def compute_task( # We received new info, that's great but not related to the compute-task # instruction - self.update_who_has(who_has=who_has, stimulus_id=stimulus_id) + self.update_who_has(who_has, stimulus_id=stimulus_id) if nbytes is not None: for key, value in nbytes.items(): self.tasks[key].nbytes = value @@ -1834,23 +1817,19 @@ def transition_missing_released(self, ts, *, stimulus_id): return recommendations, [] def transition_fetch_missing(self, ts, *, stimulus_id): - # handle_missing will append to self.data_needed if new workers - # are found + # handle_missing will append to self.data_needed if new workers are found ts.state = "missing" self._missing_dep_flight.add(ts) return {}, [] def transition_released_fetch(self, ts, *, stimulus_id): - if self.validate: - assert ts.state == "released" - for w in ts.who_has: self.pending_data_per_worker[w].append(ts.key) ts.state = "fetch" heapq.heappush(self.data_needed, (ts.priority, ts.key)) return {}, [] - def transition_released_generic(self, ts, *, stimulus_id): + def transition_generic_released(self, ts, *, stimulus_id): recs = self.release_key(ts.key, reason=stimulus_id) return recs, [] @@ -1865,13 +1844,13 @@ def transition_released_waiting(self, ts, *, stimulus_id): if not dep_ts.state == "memory": ts.waiting_for_data.add(dep_ts) - if not ts.waiting_for_data: - if not ts.resource_restrictions: - recommendations[ts] = "ready" - else: - recommendations[ts] = "constrained" - else: + if ts.waiting_for_data: self.waiting_for_data_count += 1 + elif ts.resource_restrictions: + recommendations[ts] = "constrained" + else: + recommendations[ts] = "ready" + ts.state = "waiting" return recommendations, [] @@ -1888,8 +1867,8 @@ def transition_fetch_flight(self, ts, worker, *, stimulus_id): def transition_memory_released(self, ts, *, stimulus_id): recs = self.release_key(ts.key, reason=stimulus_id) - s_msgs = [{"op": "release-worker-data", "key": ts.key}] - return recs, s_msgs + smsgs = [{"op": "release-worker-data", "key": ts.key}] + return recs, smsgs def transition_waiting_constrained(self, ts, *, stimulus_id): if self.validate: @@ -1906,27 +1885,28 @@ def 
transition_waiting_constrained(self, ts, *, stimulus_id): return {}, [] def transition_long_running_rescheduled(self, ts, *, stimulus_id): - msgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}] - return {ts: "released"}, msgs + recs = {ts: "released"} + smsgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}] + return recs, smsgs def transition_executing_rescheduled(self, ts, *, stimulus_id): for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity - msgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}] - self._executing.discard(ts) - return {ts: "released"}, msgs + + recs = {ts: "released"} + smsgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}] + return recs, smsgs def transition_waiting_ready(self, ts, *, stimulus_id): if self.validate: assert ts.state == "waiting" - assert not ts.waiting_for_data - assert all( - dep.key in self.data or dep.key in self.actors - for dep in ts.dependencies - ) - assert all(dep.state == "memory" for dep in ts.dependencies) assert ts.key not in self.ready + assert not ts.waiting_for_data + for dep in ts.dependencies: + assert dep.key in self.data or dep.key in self.actors + assert dep.state == "memory" + ts.state = "ready" heapq.heappush(self.ready, (ts.priority, ts.key)) @@ -1943,18 +1923,6 @@ def transition_generic_error( ts.state = "error" return {}, smsgs - def transition_long_running_error( - self, ts, exception, traceback, exception_text, traceback_text, *, stimulus_id - ): - return self.transition_generic_error( - ts, - exception, - traceback, - exception_text, - traceback_text, - stimulus_id=stimulus_id, - ) - def transition_executing_error( self, ts, exception, traceback, exception_text, traceback_text, *, stimulus_id ): @@ -1973,21 +1941,18 @@ def transition_executing_error( def transition_rescheduled_next(self, ts, *, stimulus_id): next_state = ts._next recs = self.release_key(ts.key, reason=stimulus_id) - if self.validate: - assert ts.state == "released" recs[ts] = next_state return recs, [] def transition_cancelled_fetch(self, ts, *, stimulus_id): if ts.done: return {ts: "released"}, [] - recommendations = {} - if ts._previous == "flight": + elif ts._previous == "flight": ts.state = ts._previous + return {}, [] else: assert ts._previous == "executing" - recommendations[ts] = ("resumed", "fetch") - return recommendations, [] + return {ts: ("resumed", "fetch")}, [] def transition_cancelled_resumed(self, ts, next, *, stimulus_id): ts._next = next @@ -1997,17 +1962,12 @@ def transition_cancelled_resumed(self, ts, next, *, stimulus_id): def transition_cancelled_waiting(self, ts, *, stimulus_id): if ts.done: return {ts: "released"}, [] - recommendations = {} - if ts._previous == "executing": + elif ts._previous == "executing": ts.state = ts._previous + return {}, [] else: assert ts._previous == "flight" - recommendations[ts] = ("resumed", "waiting") - return recommendations, [] - - def transition_rescheduled_cancelled(self, ts): - ts.state = "cancelled" - return {}, [] + return {ts: ("resumed", "waiting")}, [] def transition_cancelled_forgotten(self, ts, *, stimulus_id): ts._next = "forgotten" @@ -2031,18 +1991,16 @@ def transition_cancelled_released(self, ts, *, stimulus_id): def transition_executing_released(self, ts, *, stimulus_id): ts._previous = ts.state + # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 ts.state = "cancelled" ts.done = False return {}, [] def transition_long_running_memory(self, ts, 
value=no_value, *, stimulus_id): self.executed_count += 1 - return self._transition_to_memory_generic( - ts, value=value, stimulus_id=stimulus_id - ) - - def _transition_to_memory_generic(self, ts, value=no_value, *, stimulus_id): + return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id) + def transition_generic_memory(self, ts, value=no_value, *, stimulus_id): if value is no_value and ts.key not in self.data: raise RuntimeError( f"Tried to transition task {ts} to `memory` without data available" @@ -2056,11 +2014,9 @@ def _transition_to_memory_generic(self, ts, value=no_value, *, stimulus_id): self._in_flight_tasks.discard(ts) ts.coming_from = None - recommendations, s_msgs = self._put_key_in_memory( - ts, value, stimulus_id=stimulus_id - ) - s_msgs.append(self.get_task_state_for_scheduler(ts)) - return recommendations, s_msgs + recs, smsgs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) + smsgs.append(self.get_task_state_for_scheduler(ts)) + return recs, smsgs def transition_executing_memory(self, ts, value=no_value, *, stimulus_id): if self.validate: @@ -2070,9 +2026,7 @@ def transition_executing_memory(self, ts, value=no_value, *, stimulus_id): self._executing.discard(ts) self.executed_count += 1 - return self._transition_to_memory_generic( - ts, value=value, stimulus_id=stimulus_id - ) + return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id) def transition_constrained_released(self, ts, *, stimulus_id): recs = self.release_key(ts.key, reason=stimulus_id) @@ -2084,10 +2038,9 @@ def transition_constrained_executing(self, ts, *, stimulus_id): assert ts.key not in self.data assert ts.state in READY assert ts.key not in self.ready - assert all( - dep.key in self.data or dep.key in self.actors - for dep in ts.dependencies - ) + for dep in ts.dependencies: + assert dep.key in self.data or dep.key in self.actors + for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] -= quantity ts.state = "executing" @@ -2105,15 +2058,13 @@ def transition_ready_executing(self, ts, *, stimulus_id): dep.key in self.data or dep.key in self.actors for dep in ts.dependencies ) + ts.state = "executing" self._executing.add(ts) self.loop.add_callback(self.execute, ts.key, stimulus_id=stimulus_id) return {}, [] def transition_flight_fetch(self, ts, *, stimulus_id): - if self.validate: - assert ts.state == "flight" - self._in_flight_tasks.discard(ts) ts.coming_from = None @@ -2140,24 +2091,18 @@ def transition_flight_error( def transition_flight_released(self, ts, *, stimulus_id): ts._previous = "flight" + # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 ts.state = "cancelled" return {}, [] def transition_cancelled_memory(self, ts, value, *, stimulus_id): return {ts: ts._next}, [] - def transition_generic_released(self, ts, *, stimulus_id): - recs = self.release_key(ts.key, reason=stimulus_id) - return recs, [] - def transition_executing_long_running(self, ts, compute_duration, *, stimulus_id): - - if self.validate: - assert ts.state == "executing" ts.state = "long-running" self._executing.discard(ts) self.long_running.add(ts.key) - scheduler_msgs = [ + smsgs = [ { "op": "long-running", "key": ts.key, @@ -2166,85 +2111,74 @@ def transition_executing_long_running(self, ts, compute_duration, *, stimulus_id ] self.io_loop.add_callback(self.ensure_computing) - return {}, scheduler_msgs + return {}, smsgs def transition_released_memory(self, ts, value, *, stimulus_id): - recommendations, 
scheduler_msgs = self._put_key_in_memory( - ts, value, stimulus_id=stimulus_id - ) - scheduler_msgs.append( - { - "op": "add-keys", - "keys": [ts.key], - } - ) - return recommendations, scheduler_msgs + recs, smsgs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) + smsgs.append({"op": "add-keys", "keys": [ts.key]}) + return recs, smsgs def transition_flight_memory(self, ts, value, *, stimulus_id): - if self.validate: - assert ts.state == "flight" - self._in_flight_tasks.discard(ts) ts.coming_from = None - recommendations, scheduler_msgs = self._put_key_in_memory( - ts, value, stimulus_id=stimulus_id - ) - scheduler_msgs.append( - { - "op": "add-keys", - "keys": [ts.key], - } - ) - return recommendations, scheduler_msgs + recs, smsgs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) + smsgs.append({"op": "add-keys", "keys": [ts.key]}) + return recs, smsgs - def _transition(self, ts, finish, *args, stimulus_id, **kwargs): + def transition_released_forgotten(self, ts, *, stimulus_id): recommendations = {} - scheduler_msgs = [] - finish_state = finish + # Dependents _should_ be released by the scheduler before this + if self.validate: + assert not any(d.state != "forgotten" for d in ts.dependents) + for dep in ts.dependencies: + dep.dependents.discard(ts) + if dep.state == "released" and not dep.dependents: + recommendations[dep] = "forgotten" + + # Mark state as forgotten in case it is still referenced anymore + ts.state = "forgotten" + self.tasks.pop(ts.key, None) + return recommendations, [] + + def _transition(self, ts, finish, *args, stimulus_id, **kwargs): if isinstance(finish, tuple): # the concatenated transition path might need to access the tuple - finish_state, *args = finish + assert not args + finish, *args = finish + + if ts is None or ts.state == finish: + return {}, [] - if ts is None or ts.state == finish_state: - return recommendations, scheduler_msgs start = ts.state - start_finish = (start, finish_state) - func = self._transitions_table.get(start_finish) + func = self._transitions_table.get((start, finish)) - default_state = "released" if func is not None: - a: tuple = func(ts, *args, stimulus_id=stimulus_id, **kwargs) self._transition_counter += 1 - recommendations, scheduler_msgs = a - self._notify_plugins("transition", ts.key, start, finish_state, **kwargs) + recs, smsgs = func(ts, *args, stimulus_id=stimulus_id, **kwargs) + self._notify_plugins("transition", ts.key, start, finish, **kwargs) - elif default_state not in start_finish: + elif "released" not in (start, finish): + # start -> "released" -> finish try: - a: tuple = self._transition(ts, default_state, stimulus_id=stimulus_id) - a_recs, a_smsgs = a - - recommendations.update(a_recs) - scheduler_msgs.extend(a_smsgs) - v = a_recs.get(ts, finish) - v_args = [] - v_state = v + recs, smsgs = self._transition(ts, "released", stimulus_id=stimulus_id) + v = recs.get(ts, (finish, *args)) if isinstance(v, tuple): - v_state, *v_args = finish - b: tuple = self._transition( + v_state, *v_args = v + else: + v_state, v_args = v, () + b_recs, b_smsgs = self._transition( ts, v_state, *v_args, stimulus_id=stimulus_id ) - b_recs, b_smsgs = b - recommendations.update(b_recs) - scheduler_msgs.extend(b_smsgs) - except (InvalidTransition, KeyError): + recs.update(b_recs) + smsgs += b_smsgs + except InvalidTransition: raise InvalidTransition( - "Impossible transition from %r to %r for %s" - % (*start_finish, ts.key) + f"Impossible transition from {start} to {finish} for {ts.key}" ) from None else: raise 
InvalidTransition( - "Impossible transition from %r to %r for %s" % (*start_finish, ts.key) + f"Impossible transition from {start} to {finish} for {ts.key}" ) self.log.append( @@ -2252,30 +2186,12 @@ def _transition(self, ts, finish, *args, stimulus_id, **kwargs): ts.key, start, ts.state, - {ts.key: new for ts, new in recommendations.items()}, + {ts.key: new for ts, new in recs.items()}, stimulus_id, time(), ) ) - return recommendations, scheduler_msgs - - def _transitions(self, recommendations: dict, scheduler_msgs: list, stimulus_id): - - recommendations = recommendations.copy() - tasks = set() - while recommendations: - ts, finish = recommendations.popitem() - tasks.add(ts) - new = self._transition(ts, finish, stimulus_id=stimulus_id) - new_recs, new_smsgs = new - scheduler_msgs.extend(new_smsgs) - - recommendations.update(new_recs) - - if self.validate: - # Full state validatition is too expensive - for ts in tasks: - self.validate_task(ts) + return recs, smsgs def transition(self, ts, finish: str, *, stimulus_id, **kwargs): """Transition a key from its current state to the finish state @@ -2293,31 +2209,43 @@ def transition(self, ts, finish: str, *, stimulus_id, **kwargs): -------- Scheduler.transitions: transitive version of this function """ - recommendations: dict - a: tuple = self._transition(ts, finish, stimulus_id=stimulus_id, **kwargs) - recommendations, s_msgs = a - for msg in s_msgs: + recs, smsgs = self._transition(ts, finish, stimulus_id=stimulus_id, **kwargs) + for msg in smsgs: self.batched_stream.send(msg) - self.transitions(recommendations, stimulus_id=stimulus_id) + self.transitions(recs, stimulus_id=stimulus_id) - def transitions(self, recommendations: dict, stimulus_id): + def transitions(self, recommendations: dict, *, stimulus_id): """Process transitions until none are left This includes feedback from previous transitions and continues until we reach a steady state """ - s_msgs = [] - self._transitions(recommendations, s_msgs, stimulus_id) + smsgs = [] + + remaining_recs = recommendations.copy() + tasks = set() + while remaining_recs: + ts, finish = remaining_recs.popitem() + tasks.add(ts) + a_recs, a_smsgs = self._transition(ts, finish, stimulus_id=stimulus_id) + remaining_recs.update(a_recs) + smsgs += a_smsgs + + if self.validate: + # Full state validation is very expensive + for ts in tasks: + self.validate_task(ts) + if not self.batched_stream.closed(): - for msg in s_msgs: + for msg in smsgs: self.batched_stream.send(msg) else: logger.debug( - "BatchedSend closed while transitioning tasks. %s tasks not sent.", - len(s_msgs), + "BatchedSend closed while transitioning tasks. %d tasks not sent.", + len(smsgs), ) - def maybe_transition_long_running(self, ts, stimulus_id, compute_duration=None): + def maybe_transition_long_running(self, ts, *, stimulus_id, compute_duration=None): if ts.state == "executing": self.transition( ts, @@ -2352,25 +2280,26 @@ def story(self, *keys): def ensure_communicating(self): stimulus_id = f"ensure-communicating-{time()}" - skipped_worker_in_flight = list() + skipped_worker_in_flight = [] while self.data_needed and ( len(self.in_flight_workers) < self.total_out_connections or self.comm_nbytes < self.comm_threshold_bytes ): logger.debug( - "Ensure communicating. Pending: %d. Connections: %d/%d", + "Ensure communicating. Pending: %d. 
Connections: %d/%d", len(self.data_needed), len(self.in_flight_workers), self.total_out_connections, ) - key = heapq.heappop(self.data_needed)[1] + _, key = heapq.heappop(self.data_needed) - if key not in self.tasks: + try: + ts = self.tasks[key] + except KeyError: continue - ts = self.tasks[key] if ts.state != "fetch": continue @@ -2389,6 +2318,7 @@ def ensure_communicating(self): worker = random.choice(local) else: worker = random.choice(list(workers)) + assert worker != self.address to_gather, total_nbytes = self.select_keys_for_gather(worker, ts.key) @@ -2400,7 +2330,7 @@ def ensure_communicating(self): self.in_flight_workers[worker] = to_gather recommendations = {self.tasks[d]: ("flight", worker) for d in to_gather} self.transitions(recommendations=recommendations, stimulus_id=stimulus_id) - assert not worker == self.address + self.loop.add_callback( self.gather_dep, worker=worker, @@ -2408,9 +2338,9 @@ def ensure_communicating(self): total_nbytes=total_nbytes, stimulus_id=stimulus_id, ) - else: - for el in skipped_worker_in_flight: - heapq.heappush(self.data_needed, el) + + for el in skipped_worker_in_flight: + heapq.heappush(self.data_needed, el) def get_task_state_for_scheduler(self, ts): if ts.key in self.data or self.actors.get(ts.key): @@ -2452,20 +2382,20 @@ def get_task_state_for_scheduler(self, ts): } else: logger.error("Key not ready to send to worker, %s: %s", ts.key, ts.state) - return + return None if ts.startstops: d["startstops"] = ts.startstops return d - def _put_key_in_memory(self, ts, value, stimulus_id): + def _put_key_in_memory(self, ts, value, *, stimulus_id): if ts.key in self.data: ts.state = "memory" return {}, [] + recommendations = {} scheduler_messages = [] if ts.key in self.actors: self.actors[ts.key] = value - else: start = time() try: @@ -2530,6 +2460,7 @@ async def gather_dep( worker: str, to_gather: Iterable[str], total_nbytes: int, + *, stimulus_id, ): """Gather dependencies for a task from a worker who has them @@ -2595,7 +2526,7 @@ async def gather_dep( return data = {k: v for k, v in response["data"].items() if k in self.tasks} - lost_keys = set(response["data"]) - set(data) + lost_keys = response["data"].keys() - data.keys() if lost_keys: self.log.append(("lost-during-gather", lost_keys, stimulus_id)) @@ -2668,15 +2599,14 @@ async def gather_dep( busy = response.get("status", "") == "busy" data = response.get("data", {}) - recommendations = {} - - deps_to_iter = self.in_flight_workers.pop(worker) - if busy: self.log.append( ("busy-gather", worker, to_gather_keys, stimulus_id, time()) ) + recommendations = {} + deps_to_iter = self.in_flight_workers.pop(worker) + for d in deps_to_iter: ts = self.tasks.get(d) assert ts, (d, self.story(d)) @@ -2711,21 +2641,6 @@ async def gather_dep( self.ensure_communicating() - def transition_released_forgotten(self, ts, *, stimulus_id): - recommendations = {} - # Dependents _should_ be released by the scheduler before this - if self.validate: - assert not any(d.state != "forgotten" for d in ts.dependents) - for dep in ts.dependencies: - dep.dependents.discard(ts) - if dep.state == "released" and not dep.dependents: - recommendations[dep] = "forgotten" - - # Mark state as forgotten in case it is still referenced anymore - ts.state = "forgotten" - self.tasks.pop(ts.key, None) - return recommendations, [] - async def find_missing(self): with log_errors(): if not self._missing_dep_flight: @@ -2733,7 +2648,8 @@ async def find_missing(self): try: if self.validate: for ts in self._missing_dep_flight: - # If this was 
collected somewhere else we should've transitioned already, shouldn't we? maybe this is the place, let's see + # If this was collected somewhere else we should've transitioned + # already, shouldn't we? maybe this is the place, let's see assert not ts.who_has stimulus_id = f"find-missing-{time()}" @@ -2758,7 +2674,7 @@ async def find_missing(self): ) finally: - # This is quite arbirary but the heartbeat has a scaling implemented + # This is quite arbitrary but the heartbeat has scaling implemented self.periodic_callbacks[ "find-missing" ].callback_time = self.periodic_callbacks["heartbeat"].callback_time @@ -2767,11 +2683,11 @@ async def find_missing(self): async def query_who_has(self, *deps, stimulus_id): with log_errors(): - response = await retry_operation(self.scheduler.who_has, keys=deps) - self.update_who_has(response, stimulus_id) - return response + who_has = await retry_operation(self.scheduler.who_has, keys=deps) + self.update_who_has(who_has, stimulus_id=stimulus_id) + return who_has - def update_who_has(self, who_has, stimulus_id): + def update_who_has(self, who_has, *, stimulus_id): try: recommendations = {} for dep, workers in who_has.items(): @@ -2797,6 +2713,7 @@ def update_who_has(self, who_has, stimulus_id): self.has_what[worker].add(dep) if dep_ts.state in ("fetch", "flight", "missing"): self.pending_data_per_worker[worker].append(dep_ts.key) + self.transitions(recommendations=recommendations, stimulus_id=stimulus_id) except Exception as e: logger.exception(e) @@ -2806,20 +2723,17 @@ def update_who_has(self, who_has, stimulus_id): pdb.set_trace() raise - def steal_request(self, key): + def handle_steal_request(self, key): # There may be a race condition between stealing and releasing a task. # In this case the self.tasks is already cleared. The `None` will be # registered as `already-computing` on the other end ts = self.tasks.get(key) - if key in self.tasks: - state = ts.state - else: - state = None + state = ts.state if ts is not None else None response = {"op": "steal-response", "key": key, "state": state} self.batched_stream.send(response) - if state in ("ready", "waiting", "constrained"): + if state in {"ready", "waiting", "constrained"}: # If task is marked as "constrained" we haven't yet assigned it an # `available_resources` to run on, that happens in # `transition_constrained_executing` @@ -2870,7 +2784,7 @@ def release_key( for d in ts.dependencies: ts.waiting_for_data.discard(ts) - if not d.dependents and d.state in ("flight", "fetch", "missing"): + if not d.dependents and d.state in {"flight", "fetch", "missing"}: recommendations[d] = "released" ts.waiting_for_data.clear() @@ -3004,7 +2918,7 @@ def meets_resource_constraints(self, key): return True - async def _maybe_deserialize_task(self, ts, stimulus_id): + async def _maybe_deserialize_task(self, ts, *, stimulus_id): if not isinstance(ts.runspec, SerializedTask): return ts.runspec try: @@ -3054,9 +2968,10 @@ def ensure_computing(self): priority, key = heapq.heappop(self.ready) ts = self.tasks.get(key) if ts is None: - # It is possible for tasks to be released while still remaining on `ready` - # The scheduler might have re-routed to a new worker and told this worker - # to release. If the task has "disappeared" just continue through the heap + # It is possible for tasks to be released while still remaining on + # `ready` The scheduler might have re-routed to a new worker and + # told this worker to release. 
If the task has "disappeared" just + # continue through the heap continue elif ts.key in self.data: self.transition(ts, "memory", stimulus_id=stimulus_id) @@ -3070,7 +2985,7 @@ def ensure_computing(self): pdb.set_trace() raise - async def execute(self, key, stimulus_id): + async def execute(self, key, *, stimulus_id): if self.status in (Status.closing, Status.closed, Status.closing_gracefully): return if key not in self.tasks: @@ -3081,8 +2996,8 @@ async def execute(self, key, stimulus_id): if ts.state == "cancelled": # This might happen if keys are canceled logger.debug( - "Trying to execute a task %s which is not in executing state anymore" - % ts + "Trying to execute task %s which is not in executing state anymore", + ts, ) ts.done = True self.transition(ts, "released", stimulus_id=stimulus_id) @@ -3093,7 +3008,9 @@ async def execute(self, key, stimulus_id): assert ts.state == "executing" assert ts.runspec is not None - function, args, kwargs = await self._maybe_deserialize_task(ts, stimulus_id) + function, args, kwargs = await self._maybe_deserialize_task( + ts, stimulus_id=stimulus_id + ) args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs) @@ -3104,6 +3021,8 @@ async def execute(self, key, stimulus_id): assert executor in self.executors assert key == ts.key self.active_keys.add(ts.key) + + result: dict try: e = self.executors[executor] ts.start_time = time() @@ -3140,12 +3059,11 @@ async def execute(self, key, stimulus_id): self.active_keys.discard(ts.key) key = ts.key - # Key *must* be still in tasks. Releasing it direclty is forbidden + # key *must* be still in tasks. Releasing it direclty is forbidden # without going through cancelled ts = self.tasks.get(key) - assert ts, self.story(ts) + assert ts, self.story(key) ts.done = True - result: dict result["key"] = ts.key value = result.pop("result", None) ts.startstops.append( @@ -3451,10 +3369,10 @@ def _notify_plugins(self, method_name, *args, **kwargs): if hasattr(plugin, method_name): if method_name == "release_key": warnings.warn( - """ -The `WorkerPlugin.release_key` hook is depreacted and will be removed in a future version. -A similar event can now be caught by filtering for a `finish=='released'` event in the `WorkerPlugin.transition` hook. -""", + "The `WorkerPlugin.release_key` hook is depreacted and will be " + "removed in a future version. 
A similar event can now be " + "caught by filtering for a `finish=='released'` event in the " + "`WorkerPlugin.transition` hook.", DeprecationWarning, ) @@ -3462,7 +3380,7 @@ def _notify_plugins(self, method_name, *args, **kwargs): getattr(plugin, method_name)(*args, **kwargs) except Exception: logger.info( - "Plugin '%s' failed with exception" % name, exc_info=True + "Plugin '%s' failed with exception", name, exc_info=True ) ############## @@ -3481,12 +3399,9 @@ def validate_task_executing(self, ts): assert ts.runspec is not None assert ts.key not in self.data assert not ts.waiting_for_data - assert all(ts.state == "memory" for ts in ts.dependencies), [ - self.story(t) for t in ts.dependencies if ts.state != "memory" - ] - assert all( - dep.key in self.data or dep.key in self.actors for dep in ts.dependencies - ) + for dep in ts.dependencies: + assert dep.state == "memory", self.story(dep) + assert dep.key in self.data or dep.key in self.actors def validate_task_ready(self, ts): assert ts.key in pluck(1, self.ready) @@ -3590,7 +3505,8 @@ def validate_state(self): # check that worker has task for worker in ts.who_has: assert ts.key in self.has_what[worker] - # check that deps have a set state and that dependency<->dependent links are there + # check that deps have a set state and that dependency<->dependent links + # are there for dep in ts.dependencies: # self.tasks was just a dict of tasks # and this check was originally that the key was in `task_state` @@ -3605,7 +3521,7 @@ def validate_state(self): assert ts_wait.key in self.tasks assert ( ts_wait.state - in ("ready", "executing", "flight", "fetch", "missing") + in {"ready", "executing", "flight", "fetch", "missing"} or ts_wait.key in self._missing_dep_flight or ts_wait.who_has.issubset(self.in_flight_workers) ), (ts, ts_wait, self.story(ts), self.story(ts_wait)) diff --git a/setup.cfg b/setup.cfg index 3c07c5e90b..f77b02739c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -59,4 +59,6 @@ markers = # the MacOS GitHub CI (although it's been reported to work on MacBooks). # The CI script modifies this config file on the fly on Linux. timeout_method = thread +# This should not be reduced; Windows CI has been observed to be occasionally +# exceptionally slow. 
timeout = 300 From c9b31b62c274609649ea0d84e21330917eb202c8 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 9 Sep 2021 13:12:35 +0200 Subject: [PATCH 3/6] More polishing --- distributed/tests/test_cancelled_state.py | 25 ++++++------ distributed/worker.py | 47 +---------------------- 2 files changed, 14 insertions(+), 58 deletions(-) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 83cc7a8420..152ce0a853 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -2,7 +2,6 @@ from unittest import mock import distributed -from distributed import Nanny from distributed.core import CommClosedError from distributed.utils_test import _LockedCommPool, gen_cluster, inc, slowinc @@ -20,42 +19,42 @@ async def wait_for_cancelled(key, dask_worker): assert False -@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +@gen_cluster(client=True, nthreads=[("", 1)]) async def test_abort_execution_release(c, s, a): fut = c.submit(slowinc, 1, delay=1) - await c.run(wait_for_state, fut.key, "executing") + await wait_for_state(fut.key, "executing", a) fut.release() - await c.run(wait_for_cancelled, fut.key) + await wait_for_cancelled(fut.key, a) -@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +@gen_cluster(client=True, nthreads=[("", 1)]) async def test_abort_execution_reschedule(c, s, a): fut = c.submit(slowinc, 1, delay=1) - await c.run(wait_for_state, fut.key, "executing") + await wait_for_state(fut.key, "executing", a) fut.release() - await c.run(wait_for_cancelled, fut.key) + await wait_for_cancelled(fut.key, a) fut = c.submit(slowinc, 1, delay=0.1) await fut -@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +@gen_cluster(client=True, nthreads=[("", 1)]) async def test_abort_execution_add_as_dependency(c, s, a): fut = c.submit(slowinc, 1, delay=1) - await c.run(wait_for_state, fut.key, "executing") + await wait_for_state(fut.key, "executing", a) fut.release() - await c.run(wait_for_cancelled, fut.key) + await wait_for_cancelled(fut.key, a) fut = c.submit(slowinc, 1, delay=1) fut = c.submit(slowinc, fut, delay=1) await fut -@gen_cluster(client=True, Worker=Nanny) +@gen_cluster(client=True) async def test_abort_execution_to_fetch(c, s, a, b): fut = c.submit(slowinc, 1, delay=2, key="f1", workers=[a.worker_address]) - await c.run(wait_for_state, fut.key, "executing", workers=[a.worker_address]) + await wait_for_state(fut.key, "executing", a) fut.release() - await c.run(wait_for_cancelled, fut.key, workers=[a.worker_address]) + await wait_for_cancelled(fut.key, a) # While the first worker is still trying to compute f1, we'll resubmit it to # another worker with a smaller delay. 
The key is still the same diff --git a/distributed/worker.py b/distributed/worker.py index 2fe763f6f4..3dcfce9ffb 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1679,46 +1679,6 @@ def register_acquire_internal(self, key, priority, stimulus_id): return recommendations, scheduler_msgs - def transition_table_to_dot(self, filename="worker-transitions", format=None): - import graphviz - - from dask.dot import graphviz_to_file - - g = graphviz.Digraph( - graph_attr={ - "concentrate": "True", - }, - # node_attr=node_attr, - # edge_attr=edge_attr - ) - all_states = set() - for edge in self._transitions_table.keys(): - all_states.update(set(edge)) - - seen = set() - with g.subgraph(name="cluster_0") as c: - c.attr(style="filled", color="lightgrey") - c.node_attr.update(style="filled", color="white") - c.attr(label="executable") - for state in ( - "waiting", - "ready", - "executing", - "constrained", - "long-running", - ): - c.node(state, label=state) - seen.add(state) - - with g.subgraph(name="cluster_1") as c: - for state in ["fetch", "flight", "missing"]: - c.attr(label="dependency") - c.node(state, label=state) - seen.add(state) - - g.edges(self._transitions_table.keys()) - return graphviz_to_file(g, filename=filename, format=format) - def handle_compute_task( self, *, @@ -1782,7 +1742,7 @@ def handle_compute_task( ts.dependencies.add(dep_ts) dep_ts.dependents.add(ts) - if ts.state in {"ready", "executing", "waiting"}: + if ts.state in {"ready", "executing", "waiting", "resumed"}: pass elif ts.state == "memory": recommendations[ts] = "memory" @@ -1790,7 +1750,6 @@ def handle_compute_task( elif ts.state in {"released", "fetch", "flight", "missing", "cancelled"}: recommendations[ts] = "waiting" else: - # FIXME Either remove exception or handle resumed raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}") for msg in scheduler_msgs: @@ -2135,7 +2094,7 @@ def transition_released_forgotten(self, ts, *, stimulus_id): if dep.state == "released" and not dep.dependents: recommendations[dep] = "forgotten" - # Mark state as forgotten in case it is still referenced anymore + # Mark state as forgotten in case it is still referenced ts.state = "forgotten" self.tasks.pop(ts.key, None) return recommendations, [] @@ -2648,8 +2607,6 @@ async def find_missing(self): try: if self.validate: for ts in self._missing_dep_flight: - # If this was collected somewhere else we should've transitioned - # already, shouldn't we? 
maybe this is the place, let's see assert not ts.who_has stimulus_id = f"find-missing-{time()}" From 14d72408559e0e97ca2bb077ff40c6f5f0be6dbb Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 9 Sep 2021 14:48:15 +0200 Subject: [PATCH 4/6] PR comments --- distributed/stealing.py | 2 +- distributed/tests/test_client_executor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index 24f7ce7459..0297691f02 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -264,7 +264,7 @@ async def move_task_confirm(self, key=None, worker=None, state=None): await self.scheduler.remove_worker(thief.address) self.log(("confirm", key, victim.address, thief.address)) else: - raise ValueError(f"Unexpected task state: {ts}") + raise ValueError(f"Unexpected task state: {state}") except Exception as e: logger.exception(e) if LOG_PDB: diff --git a/distributed/tests/test_client_executor.py b/distributed/tests/test_client_executor.py index fe23f8c7d3..f4a90297e5 100644 --- a/distributed/tests/test_client_executor.py +++ b/distributed/tests/test_client_executor.py @@ -155,7 +155,7 @@ def test_map(client): assert number_of_processing_tasks(client) > 0 # Garbage collect the iterator => remaining tasks are cancelled del it - time.sleep(0.1) + time.sleep(0.5) assert number_of_processing_tasks(client) == 0 From 5623b1c06b10320a7382ec2a658c63c7a0c025d8 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 9 Sep 2021 14:54:11 +0200 Subject: [PATCH 5/6] PR comments --- distributed/tests/test_steal.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 82fe97be18..be7f1bc58b 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -703,10 +703,9 @@ async def test_dont_steal_already_released(c, s, a, b): with captured_logger( logging.getLogger("distributed.stealing"), level=logging.DEBUG ) as stealing_logs: - logs = stealing_logs.getvalue() - while f"Key released between request and confirm: {key}" not in logs: + msg = f"Key released between request and confirm: {key}" + while msg not in stealing_logs.getvalue(): await asyncio.sleep(0.05) - logs = stealing_logs.getvalue() @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) From 1dda0cfc27e5099f3664922747f7b4c0e34fb008 Mon Sep 17 00:00:00 2001 From: fjetter Date: Mon, 27 Sep 2021 12:02:08 +0200 Subject: [PATCH 6/6] Remove assert ts.dependents from validate task fetch --- distributed/worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/worker.py b/distributed/worker.py index 4ad36f9e94..4e240dfa58 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3406,7 +3406,6 @@ def validate_task_flight(self, ts): def validate_task_fetch(self, ts): assert ts.key not in self.data assert self.address not in ts.who_has - assert ts.dependents for w in ts.who_has: assert ts.key in self.has_what[w]
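
The transition machinery reworked throughout this series centres on one pattern: every transition function returns a pair of (recommendations, scheduler messages), and `Worker.transitions()` keeps applying recommendations until none remain before flushing the messages over the batched stream. A minimal, self-contained sketch of that loop, assuming invented names (`TRANSITIONS`, a plain-dict task record) rather than the actual `distributed.worker` classes:

    from time import time

    # (start, finish) -> callable, analogous to Worker._transitions_table
    TRANSITIONS = {}

    def transition_released_fetch(ts, *, stimulus_id):
        # A transition mutates the task and returns (recommendations, smsgs).
        ts["state"] = "fetch"
        return {}, []

    TRANSITIONS[("released", "fetch")] = transition_released_fetch

    def transitions(tasks, recommendations, *, stimulus_id):
        """Apply recommendations until a fixed point; collect scheduler messages."""
        smsgs = []
        remaining = dict(recommendations)
        while remaining:
            key, finish = remaining.popitem()
            ts = tasks[key]
            func = TRANSITIONS.get((ts["state"], finish))
            if func is None:
                raise RuntimeError(f"impossible transition {ts['state']} -> {finish}")
            recs, msgs = func(ts, stimulus_id=stimulus_id)
            remaining.update(recs)  # feedback from this transition
            smsgs += msgs
        return smsgs  # the real worker sends these on the batched stream

    tasks = {"x": {"state": "released"}}
    transitions(tasks, {"x": "fetch"}, stimulus_id=f"example-{time()}")
    assert tasks["x"]["state"] == "fetch"

This is only an illustration of the control flow; the real implementation also falls back to a start -> "released" -> finish path when no direct table entry exists, and validates each touched task when `self.validate` is set.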