Worker state machine refactor #5046

Merged · 8 commits · Sep 27, 2021 · Changes from 6 commits
18 changes: 0 additions & 18 deletions distributed/diagnostics/plugin.py
@@ -161,24 +161,6 @@ def transition(self, key, start, finish, **kwargs):
kwargs : More options passed when transitioning
"""

def release_key(self, key, state, cause, reason, report):
"""
Called when the worker releases a task.

Parameters
----------
key : string
state : string
State of the released task.
One of waiting, ready, executing, long-running, memory, error.
cause : string or None
Additional information on what triggered the release of the task.
reason : None
Not used.
report : bool
Whether the worker should report the released task to the scheduler.
"""


class NannyPlugin:
"""Interface to extend the Nanny
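The removed `release_key` hook is superseded by the existing `transition` hook: with the refactored worker state machine every key now passes through explicit "released" and "forgotten" states, as the updated test expectations below show. A minimal porting sketch, assuming a plugin that only needs to know when a key is released (the class name is illustrative, not part of the codebase):

```python
from distributed.diagnostics.plugin import WorkerPlugin


class ReleaseObserver(WorkerPlugin):
    """Illustrative replacement for a plugin that used `release_key`."""

    def __init__(self):
        self.released = []

    def transition(self, key, start, finish, **kwargs):
        # A transition into "released" now carries the information the
        # deprecated `release_key` hook used to deliver; `start` is the
        # prior state (e.g. "memory" or "error").
        if finish == "released":
            self.released.append((key, start))
```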
75 changes: 62 additions & 13 deletions distributed/diagnostics/tests/test_worker_plugin.py
@@ -34,9 +34,6 @@ def transition(self, key, start, finish, **kwargs):
{"key": key, "start": start, "finish": finish}
)

def release_key(self, key, state, cause, reason, report):
self.observed_notifications.append({"key": key, "state": state})


@gen_cluster(client=True, nthreads=[])
async def test_create_with_client(c, s):
@@ -107,11 +104,12 @@ async def test_create_on_construction(c, s, a, b):
@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_normal_task_transitions_called(c, s, w):
expected_notifications = [
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "state": "memory"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
@@ -127,11 +125,12 @@ def failing(x):
raise Exception()

expected_notifications = [
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "error"},
{"key": "task", "state": "error"},
{"key": "task", "start": "error", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
@@ -147,11 +146,12 @@ def failing(x):
)
async def test_superseding_task_transitions_called(c, s, w):
expected_notifications = [
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "constrained"},
{"key": "task", "start": "constrained", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "task", "state": "memory"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
@@ -166,16 +166,18 @@ async def test_dependent_tasks(c, s, w):
dsk = {"dep": 1, "task": (inc, "dep")}

expected_notifications = [
{"key": "dep", "start": "new", "finish": "waiting"},
{"key": "dep", "start": "released", "finish": "waiting"},
{"key": "dep", "start": "waiting", "finish": "ready"},
{"key": "dep", "start": "ready", "finish": "executing"},
{"key": "dep", "start": "executing", "finish": "memory"},
{"key": "task", "start": "new", "finish": "waiting"},
{"key": "task", "start": "released", "finish": "waiting"},
{"key": "task", "start": "waiting", "finish": "ready"},
{"key": "task", "start": "ready", "finish": "executing"},
{"key": "task", "start": "executing", "finish": "memory"},
{"key": "dep", "state": "memory"},
{"key": "task", "state": "memory"},
{"key": "dep", "start": "memory", "finish": "released"},
{"key": "task", "start": "memory", "finish": "released"},
{"key": "task", "start": "released", "finish": "forgotten"},
{"key": "dep", "start": "released", "finish": "forgotten"},
]

plugin = MyPlugin(1, expected_notifications=expected_notifications)
@@ -203,6 +205,53 @@ class MyCustomPlugin(WorkerPlugin):
assert next(iter(w.plugins)).startswith("MyCustomPlugin-")


def test_release_key_deprecated():
class ReleaseKeyDeprecated(WorkerPlugin):
def __init__(self):
self._called = False

def release_key(self, key, state, cause, reason, report):
# Ensure that the handler still works
self._called = True
assert state == "memory"
assert key == "task"

def teardown(self, worker):
assert self._called
return super().teardown(worker)

@gen_cluster(client=True, nthreads=[("", 1)])
async def test(c, s, a):

await c.register_worker_plugin(ReleaseKeyDeprecated())
fut = await c.submit(inc, 1, key="task")
assert fut == 2

with pytest.deprecated_call(
match="The `WorkerPlugin.release_key` hook is depreacted"
):
test()


def test_assert_no_warning_no_overload():
"""Assert we do not receive a deprecation warning if we do not overload any
methods
"""

class Dummy(WorkerPlugin):
pass

@gen_cluster(client=True, nthreads=[("", 1)])
async def test(c, s, a):

await c.register_worker_plugin(Dummy())
fut = await c.submit(inc, 1, key="task")
assert fut == 2

with pytest.warns(None):
test()


@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
async def test_WorkerPlugin_overwrite(c, s, w):
class MyCustomPlugin(WorkerPlugin):
34 changes: 19 additions & 15 deletions distributed/scheduler.py
@@ -3002,7 +3002,11 @@ def transition_processing_released(self, key):
w: str = _remove_from_processing(self, ts)
if w:
worker_msgs[w] = [
{"op": "free-keys", "keys": [key], "reason": "Processing->Released"}
{
"op": "free-keys",
"keys": [key],
"reason": f"processing-released-{time()}",
}
]

ts.state = "released"
@@ -5365,7 +5369,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
self.log.append(("missing", key, errant_worker))

ts: TaskState = parent._tasks.get(key)
if ts is None or not ts._who_has:
if ts is None:
return
ws: WorkerState = parent._workers_dv.get(errant_worker)
if ws is not None and ws in ts._who_has:
@@ -5378,17 +5382,14 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
else:
self.transitions({key: "forgotten"})

def release_worker_data(self, comm=None, keys=None, worker=None):
def release_worker_data(self, comm=None, key=None, worker=None):
parent: SchedulerState = cast(SchedulerState, self)
ws: WorkerState = parent._workers_dv.get(worker)
if not ws:
ts: TaskState = parent._tasks.get(key)
if not ws or not ts:
return
tasks: set = {parent._tasks[k] for k in keys if k in parent._tasks}
removed_tasks: set = tasks.intersection(ws._has_what)

ts: TaskState
recommendations: dict = {}
for ts in removed_tasks:
if ts in ws._has_what:
del ws._has_what[ts]
ws._nbytes -= ts.get_nbytes()
wh: set = ts._who_has
@@ -6707,7 +6708,7 @@ def add_keys(self, comm=None, worker=None, keys=()):
if worker not in parent._workers_dv:
return "not found"
ws: WorkerState = parent._workers_dv[worker]
superfluous_data = []
redundant_replicas = []
for key in keys:
ts: TaskState = parent._tasks.get(key)
if ts is not None and ts._state == "memory":
@@ -6716,14 +6717,15 @@ def add_keys(self, comm=None, worker=None, keys=()):
ws._has_what[ts] = None
ts._who_has.add(ws)
else:
superfluous_data.append(key)
if superfluous_data:
redundant_replicas.append(key)

if redundant_replicas:
self.worker_send(
worker,
{
"op": "superfluous-data",
"keys": superfluous_data,
"reason": f"Add keys which are not in-memory {superfluous_data}",
"op": "remove-replicas",
"keys": redundant_replicas,
"stimulus_id": f"redundant-replicas-{time()}",
},
)

@@ -7846,6 +7848,8 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->
"key": ts._key,
"priority": ts._priority,
"duration": duration,
"stimulus_id": f"compute-task-{time()}",
"who_has": {},
}
if ts._resource_restrictions:
msg["resource_restrictions"] = ts._resource_restrictions
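Two themes run through the scheduler changes above: messages directed at workers now carry a unique, timestamped identifier (the `reason`/`stimulus_id` fields built with `time()`), and the old "superfluous-data" op is renamed to "remove-replicas". A hedged sketch of the payload `add_keys` now emits; the values are illustrative, not taken from a real run:

```python
from time import time

# Illustrative message only; the real payload is assembled inside
# Scheduler.add_keys as shown in the diff above.
msg = {
    "op": "remove-replicas",
    # keys the worker reported holding although the scheduler no longer
    # tracks them as in-memory on that worker
    "keys": ["task-abc123"],
    # timestamped cause, so worker-side logs/stories can be traced back
    # to the scheduler decision that triggered them
    "stimulus_id": f"redundant-replicas-{time()}",
}
```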
10 changes: 9 additions & 1 deletion distributed/stealing.py
@@ -230,7 +230,15 @@ async def move_task_confirm(self, key=None, worker=None, state=None):
return

# Victim had already started execution, reverse stealing
if state in ("memory", "executing", "long-running", None):
if state in (
"memory",
"executing",
"long-running",
"released",
"cancelled",
"resumed",
None,
):
self.log(("already-computing", key, victim.address, thief.address))
self.scheduler.check_idle_saturated(thief)
self.scheduler.check_idle_saturated(victim)
133 changes: 133 additions & 0 deletions distributed/tests/test_cancelled_state.py
@@ -0,0 +1,133 @@
import asyncio
from unittest import mock

import distributed
from distributed.core import CommClosedError
from distributed.utils_test import _LockedCommPool, gen_cluster, inc, slowinc


async def wait_for_state(key, state, dask_worker):
while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
await asyncio.sleep(0.005)


async def wait_for_cancelled(key, dask_worker):
while key in dask_worker.tasks:
if dask_worker.tasks[key].state == "cancelled":
return
await asyncio.sleep(0.005)
assert False


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_abort_execution_release(c, s, a):
fut = c.submit(slowinc, 1, delay=1)
await wait_for_state(fut.key, "executing", a)
fut.release()
await wait_for_cancelled(fut.key, a)


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_abort_execution_reschedule(c, s, a):
fut = c.submit(slowinc, 1, delay=1)
await wait_for_state(fut.key, "executing", a)
fut.release()
await wait_for_cancelled(fut.key, a)
fut = c.submit(slowinc, 1, delay=0.1)
await fut


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_abort_execution_add_as_dependency(c, s, a):
fut = c.submit(slowinc, 1, delay=1)
await wait_for_state(fut.key, "executing", a)
fut.release()
await wait_for_cancelled(fut.key, a)

fut = c.submit(slowinc, 1, delay=1)
fut = c.submit(slowinc, fut, delay=1)
await fut


@gen_cluster(client=True)
async def test_abort_execution_to_fetch(c, s, a, b):
fut = c.submit(slowinc, 1, delay=2, key="f1", workers=[a.worker_address])
await wait_for_state(fut.key, "executing", a)
fut.release()
await wait_for_cancelled(fut.key, a)

# While the first worker is still trying to compute f1, we'll resubmit it to
# another worker with a smaller delay. The key is still the same
fut = c.submit(inc, 1, key="f1", workers=[b.worker_address])
# a would then have to switch the task from executing to fetching. Instead of
# doing so, it simply reuses the result it is already computing.
fut = c.submit(inc, fut, workers=[a.worker_address], key="f2")
await fut


@gen_cluster(client=True)
async def test_worker_find_missing(c, s, a, b):
fut = c.submit(inc, 1, workers=[a.address])
await fut
# We deliberately bypass the proper API, since it would keep the rest of the
# cluster informed of the deletion
del a.data[fut.key]
del a.tasks[fut.key]

# Actually no worker has the data; the scheduler is supposed to reschedule
assert await c.submit(inc, fut, workers=[b.address]) == 3


@gen_cluster(client=True)
async def test_worker_stream_died_during_comm(c, s, a, b):
write_queue = asyncio.Queue()
write_event = asyncio.Event()
b.rpc = _LockedCommPool(
b.rpc,
write_queue=write_queue,
write_event=write_event,
)
fut = c.submit(inc, 1, workers=[a.address], allow_other_workers=True)
await fut
# b has to fetch the dependency from a; the locked comm pool above lets us
# close a mid-transfer, after which the scheduler has to reschedule the work
res = c.submit(inc, fut, workers=[b.address])

await write_queue.get()
await a.close()
write_event.set()

await res
assert any("receive-dep-failed" in msg for msg in b.log)


@gen_cluster(client=True)
async def test_flight_to_executing_via_cancelled_resumed(c, s, a, b):
lock = asyncio.Lock()
await lock.acquire()

async def wait_and_raise(*args, **kwargs):
async with lock:
raise CommClosedError()

with mock.patch.object(
distributed.worker,
"get_data_from_worker",
side_effect=wait_and_raise,
):
fut1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True)
fut2 = c.submit(inc, fut1, workers=[b.address])

await wait_for_state(fut1.key, "flight", b)

# Close in scheduler to ensure we transition and reschedule task properly
await s.close_worker(worker=a.address)
await wait_for_state(fut1.key, "resumed", b)

lock.release()
assert await fut2 == 3

b_story = b.story(fut1.key)
assert any("receive-dep-failed" in msg for msg in b_story)
assert any("missing-dep" in msg for msg in b_story)
assert any("cancelled" in msg for msg in b_story)
assert any("resumed" in msg for msg in b_story)