Refactor all event handlers (#6410)
crusaderky authored Jun 7, 2022
1 parent c014e5b commit bd98e66
Showing 9 changed files with 477 additions and 307 deletions.
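Every file below follows the same theme: one-off worker methods such as handle_free_keys, handle_steal_request, and handle_remove_replicas are folded into a single handle_stimulus entry point that accepts typed event objects. A minimal sketch of that dispatch pattern, assuming simplified classes (the real ones live in distributed/worker_state_machine.py, use a different dispatch mechanism, and also log each event so that stimulus_story() can replay it):

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class StateMachineEvent:
    stimulus_id: str


@dataclass
class StealRequestEvent(StateMachineEvent):
    key: str


@dataclass
class FreeKeysEvent(StateMachineEvent):
    keys: list[str]


class Worker:
    def handle_stimulus(self, event: StateMachineEvent) -> None:
        # Single entry point: route each typed event to one handler method.
        # Illustrative only; not the real Worker implementation.
        handler = getattr(self, f"_handle_{type(event).__name__}")
        handler(event)

    def _handle_StealRequestEvent(self, ev: StealRequestEvent) -> None:
        print(f"steal request for {ev.key!r} ({ev.stimulus_id})")

    def _handle_FreeKeysEvent(self, ev: FreeKeysEvent) -> None:
        print(f"freeing {ev.keys!r} ({ev.stimulus_id})")


Worker().handle_stimulus(StealRequestEvent(key="x", stimulus_id="test"))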
54 changes: 22 additions & 32 deletions distributed/scheduler.py
@@ -1088,7 +1088,7 @@ def __init__(self, key: str, run_spec: object):
        self.has_lost_dependencies = False
        self.host_restrictions = None  # type: ignore
        self.worker_restrictions = None  # type: ignore
-        self.resource_restrictions = None  # type: ignore
+        self.resource_restrictions = {}
        self.loose_restrictions = False
        self.actor = False
        self.prefix = None  # type: ignore
@@ -2670,14 +2670,12 @@ def valid_workers(self, ts: TaskState) -> set:  # set[WorkerState] | None
        return s

    def consume_resources(self, ts: TaskState, ws: WorkerState):
-        if ts.resource_restrictions:
-            for r, required in ts.resource_restrictions.items():
-                ws.used_resources[r] += required
+        for r, required in ts.resource_restrictions.items():
+            ws.used_resources[r] += required

    def release_resources(self, ts: TaskState, ws: WorkerState):
-        if ts.resource_restrictions:
-            for r, required in ts.resource_restrictions.items():
-                ws.used_resources[r] -= required
+        for r, required in ts.resource_restrictions.items():
+            ws.used_resources[r] -= required

    def coerce_hostname(self, host):
        """
@@ -7092,29 +7090,28 @@ def adaptive_target(self, target_duration=None):
        to_close = self.workers_to_close()
        return len(self.workers) - len(to_close)

-    def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str):
+    def request_acquire_replicas(
+        self, addr: str, keys: Iterable[str], *, stimulus_id: str
+    ) -> None:
        """Asynchronously ask a worker to acquire a replica of the listed keys from
        other workers. This is a fire-and-forget operation which offers no feedback for
        success or failure, and is intended for housekeeping and not for computation.
        """
-        who_has = {}
-        for key in keys:
-            ts = self.tasks[key]
-            who_has[key] = {ws.address for ws in ts.who_has}
-
+        who_has = {key: [ws.address for ws in self.tasks[key].who_has] for key in keys}
        if self.validate:
            assert all(who_has.values())

        self.stream_comms[addr].send(
            {
                "op": "acquire-replicas",
-                "keys": keys,
                "who_has": who_has,
                "stimulus_id": stimulus_id,
            },
        )

-    def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str):
+    def request_remove_replicas(
+        self, addr: str, keys: list[str], *, stimulus_id: str
+    ) -> None:
        """Asynchronously ask a worker to discard its replica of the listed keys.
        This must never be used to destroy the last replica of a key. This is a
        fire-and-forget operation, intended for housekeeping and not for computation.
@@ -7125,15 +7122,14 @@ def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str):
        to re-add itself to who_has. If the worker agrees to discard the task, there is
        no feedback.
        """
-        ws: WorkerState = self.workers[addr]
-        validate = self.validate
+        ws = self.workers[addr]

        # The scheduler immediately forgets about the replica and suggests the worker to
        # drop it. The worker may refuse, at which point it will send back an add-keys
        # message to reinstate it.
        for key in keys:
-            ts: TaskState = self.tasks[key]
-            if validate:
+            ts = self.tasks[key]
+            if self.validate:
                # Do not destroy the last copy
                assert len(ts.who_has) > 1
            self.remove_replica(ts, ws)
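Both request_* methods are deliberately one-way: the scheduler adjusts its own bookkeeping immediately and relies on the worker's later add-keys message to resolve any disagreement. A hypothetical test-style usage sketch (the test name and the waiting strategy are invented for illustration, not part of this commit):

import asyncio

from distributed.metrics import time
from distributed.utils_test import gen_cluster


@gen_cluster(client=True)
async def test_replicate_then_trim(c, s, a, b):
    await c.scatter({"x": 1}, workers=[a.address])
    # Fire and forget: ask b to fetch a replica of "x", then wait for the
    # scheduler to learn about the new copy
    s.request_acquire_replicas(b.address, ["x"], stimulus_id=f"test-{time()}")
    while len(s.tasks["x"].who_has) < 2:
        await asyncio.sleep(0.01)
    # Two replicas now exist, so it is safe to ask a to drop one
    s.request_remove_replicas(a.address, ["x"], stimulus_id=f"test-{time()}")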
@@ -7314,22 +7310,16 @@ def _task_to_msg(
            dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies
        },
        "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies},
+        "run_spec": ts.run_spec,
+        "resource_restrictions": ts.resource_restrictions,
+        "actor": ts.actor,
+        "annotations": ts.annotations,
    }
    if state.validate:
        assert all(msg["who_has"].values())
-
-    if ts.resource_restrictions:
-        msg["resource_restrictions"] = ts.resource_restrictions
-    if ts.actor:
-        msg["actor"] = True
-
-    if isinstance(ts.run_spec, dict):
-        msg.update(ts.run_spec)
-    else:
-        msg["task"] = ts.run_spec
-
-    if ts.annotations:
-        msg["annotations"] = ts.annotations
+        if isinstance(msg["run_spec"], dict):
+            assert set(msg["run_spec"]).issubset({"function", "args", "kwargs"})
+            assert msg["run_spec"].get("function")

    return msg
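After this hunk the compute-task message always carries run_spec, resource_restrictions, actor, and annotations rather than adding them conditionally, so the worker-side handler can rely on every key being present. A hypothetical message (all values invented; only the keys mirror the code above, and the envelope fields marked as assumed are not shown in this hunk):

msg = {
    "op": "compute-task",  # assumed envelope field
    "key": "f2",  # assumed envelope field
    "who_has": {"f1": ["tcp://10.0.0.2:40001"]},
    "nbytes": {"f1": 28},
    "run_spec": {"function": b"...", "args": b"...", "kwargs": b"..."},
    "resource_restrictions": {},  # always a dict now, never None
    "actor": False,  # always present, even when False
    "annotations": {},
}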

6 changes: 4 additions & 2 deletions distributed/tests/test_failed_workers.py
@@ -24,7 +24,7 @@
    slowadd,
    slowinc,
)
-from distributed.worker_state_machine import TaskState
+from distributed.worker_state_machine import FreeKeysEvent, TaskState

pytestmark = pytest.mark.ci1

@@ -425,7 +425,9 @@ def sink(*args):
    # artificially, without notifying the scheduler.
    # This can only succeed if B handles the missing data properly by
    # removing A from the known sources of keys
-    a.handle_free_keys(keys=["f1"], stimulus_id="Am I evil?")  # Yes, I am!
+    a.handle_stimulus(
+        FreeKeysEvent(keys=["f1"], stimulus_id="Am I evil?")
+    )  # Yes, I am!
    result_fut = c.submit(sink, futures, workers=x.address)

    await result_fut
3 changes: 2 additions & 1 deletion distributed/tests/test_steal.py
@@ -29,6 +29,7 @@
    slowidentity,
    slowinc,
)
+from distributed.worker_state_machine import StealRequestEvent

pytestmark = pytest.mark.ci1

@@ -867,7 +868,7 @@ async def test_dont_steal_already_released(c, s, a, b):
    while key in a.tasks and a.tasks[key].state != "released":
        await asyncio.sleep(0.05)

-    a.handle_steal_request(key=key, stimulus_id="test")
+    a.handle_stimulus(StealRequestEvent(key=key, stimulus_id="test"))
    assert len(a.batched_stream.buffer) == 1
    msg = a.batched_stream.buffer[0]
    assert msg["op"] == "steal-response"
2 changes: 1 addition & 1 deletion distributed/tests/test_utils_test.py
@@ -656,7 +656,7 @@ async def test_log_invalid_transitions(c, s, a):
        await asyncio.sleep(0.01)
    ts = a.tasks[xkey]
    with pytest.raises(InvalidTransition):
-        a.transition(ts, "foo", stimulus_id="bar")
+        a._transition(ts, "foo", stimulus_id="bar")

    while not s.events["invalid-worker-transition"]:
        await asyncio.sleep(0.01)
145 changes: 89 additions & 56 deletions distributed/tests/test_worker.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
+import gc
import importlib
import logging
import os
@@ -72,6 +73,15 @@
    error_message,
    logger,
)
+from distributed.worker_state_machine import (
+    AcquireReplicasEvent,
+    ComputeTaskEvent,
+    ExecuteFailureEvent,
+    ExecuteSuccessEvent,
+    RemoveReplicasEvent,
+    SerializedTask,
+    StealRequestEvent,
+)

pytestmark = pytest.mark.ci1

@@ -1851,31 +1861,67 @@ async def test_story(c, s, w):

@gen_cluster(client=True, nthreads=[("", 1)])
async def test_stimulus_story(c, s, a):
+    # Test that substrings aren't matched by stimulus_story()
+    f = c.submit(inc, 1, key="f")
+    f1 = c.submit(lambda: "foo", key="f1")
+    f2 = c.submit(inc, f1, key="f2")  # This will fail
+    await wait([f, f1, f2])
+
+    story = a.stimulus_story("f1", "f2")
+    assert len(story) == 4
+
+    assert isinstance(story[0], ComputeTaskEvent)
+    assert story[0].key == "f1"
+    assert story[0].run_spec == SerializedTask(task=None)  # Not logged
+
+    assert isinstance(story[1], ExecuteSuccessEvent)
+    assert story[1].key == "f1"
+    assert story[1].value is None  # Not logged
+    assert story[1].handled >= story[0].handled
+
+    assert isinstance(story[2], ComputeTaskEvent)
+    assert story[2].key == "f2"
+    assert story[2].who_has == {"f1": (a.address,)}
+    assert story[2].run_spec == SerializedTask(task=None)  # Not logged
+    assert story[2].handled >= story[1].handled
+
+    assert isinstance(story[3], ExecuteFailureEvent)
+    assert story[3].key == "f2"
+    assert story[3].handled >= story[2].handled
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_worker_descopes_data(c, s, a):
+    """Test that data is released on the worker:
+    1. when it's the output of a successful task
+    2. when it's the input of a failed task
+    3. when it's a local variable in the frame of a failed task
+    4. when it's embedded in the exception of a failed task
+    """
+
    class C:
-        pass
+        instances = weakref.WeakSet()

-    f = c.submit(C, key="f")  # Test that substrings aren't matched by story()
-    f2 = c.submit(inc, 2, key="f2")
-    f3 = c.submit(inc, 3, key="f3")
-    await wait([f, f2, f3])
+        def __init__(self):
+            C.instances.add(self)

-    # Test that ExecuteSuccessEvent.value is not stored in the the event log
-    assert isinstance(a.data["f"], C)
-    ref = weakref.ref(a.data["f"])
-    del f
-    while "f" in a.data:
-        await asyncio.sleep(0.01)
-    with profile.lock:
-        assert ref() is None
+    def f(x):
+        y = C()
+        raise Exception(x, y)

+    f1 = c.submit(C, key="f1")
+    f2 = c.submit(f, f1, key="f2")
+    await wait([f2])
+
-    story = a.stimulus_story("f", "f2")
-    assert {ev.key for ev in story} == {"f", "f2"}
-    assert {ev.type for ev in story} == {C, int}
+    assert type(a.data["f1"]) is C

-    prev_handled = story[0].handled
-    for ev in story[1:]:
-        assert ev.handled >= prev_handled
-        prev_handled = ev.handled
+    del f1
+    del f2
+    while a.data:
+        await asyncio.sleep(0.01)
+    with profile.lock:
+        gc.collect()
+    assert not C.instances
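The new test_worker_descopes_data leans on a WeakSet class attribute to prove that every C instance was deallocated. The same pattern works standalone; a minimal sketch:

import gc
import weakref


class C:
    instances = weakref.WeakSet()

    def __init__(self):
        C.instances.add(self)


obj = C()
assert obj in C.instances  # weak references only; obj is not kept alive
del obj
gc.collect()  # collect any lingering reference cycles before asserting
assert not C.instances  # every instance has been garbage collected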


@gen_cluster(client=True)
@@ -2570,7 +2616,7 @@ def __call__(self, *args, **kwargs):
        await asyncio.sleep(0)

    ts = s.tasks[fut.key]
-    a.handle_steal_request(fut.key, stimulus_id="test")
+    a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test"))
    stealing_ext.scheduler.send_task_to_worker(b.address, ts)

    fut2 = c.submit(inc, fut, workers=[a.address])
@@ -2681,41 +2727,29 @@ async def test_acquire_replicas_many(c, s, *workers):
        await asyncio.sleep(0.001)


-@pytest.mark.slow
-@gen_cluster(client=True, Worker=Nanny)
-async def test_acquire_replicas_already_in_flight(c, s, *nannies):
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_acquire_replicas_already_in_flight(c, s, a):
    """Trying to acquire a replica that is already in flight is a no-op"""
+    async with BlockedGatherDep(s.address) as b:
+        x = c.submit(inc, 1, workers=[a.address], key="x")
+        y = c.submit(inc, x, workers=[b.address], key="y")
+        await b.in_gather_dep.wait()
+        assert b.tasks["x"].state == "flight"

-    class SlowToFly:
-        def __getstate__(self):
-            sleep(0.9)
-            return {}
-
-    a, b = s.workers
-    x = c.submit(SlowToFly, workers=[a], key="x")
-    await wait(x)
-    y = c.submit(lambda x: 123, x, workers=[b], key="y")
-    await asyncio.sleep(0.3)
-    s.request_acquire_replicas(b, [x.key], stimulus_id=f"test-{time()}")
-    assert await y == 123
+        b.handle_stimulus(
+            AcquireReplicasEvent(who_has={"x": a.address}, stimulus_id="test")
+        )
+        assert b.tasks["x"].state == "flight"
+        b.block_gather_dep.set()
+        assert await y == 3

-    story = await c.run(lambda dask_worker: dask_worker.story("x"))
-    assert_story(
-        story[b],
-        [
-            ("x", "ensure-task-exists", "released"),
-            ("x", "released", "fetch", "fetch", {}),
-            ("gather-dependencies", a, {"x"}),
-            ("x", "fetch", "flight", "flight", {}),
-            ("request-dep", a, {"x"}),
-            ("x", "ensure-task-exists", "flight"),
-            ("x", "flight", "fetch", "flight", {}),
-            ("receive-dep", a, {"x"}),
-            ("x", "put-in-memory"),
-            ("x", "flight", "memory", "memory", {"y": "ready"}),
-        ],
-        strict=True,
-    )
+        assert_story(
+            b.story("x"),
+            [
+                ("x", "fetch", "flight", "flight", {}),
+                ("x", "flight", "fetch", "flight", {}),
+            ],
+        )
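The rewrite above replaces the timing-based SlowToFly trick with BlockedGatherDep, a worker whose gather_dep parks on two asyncio events until the test releases it. A sketch of how such a helper can be written (an assumed simplification of distributed.utils_test.BlockedGatherDep):

import asyncio

from distributed import Worker


class BlockedGatherDep(Worker):
    """Worker whose gather_dep pauses until the test un-blocks it (sketch)."""

    def __init__(self, *args, **kwargs):
        self.in_gather_dep = asyncio.Event()
        self.block_gather_dep = asyncio.Event()
        super().__init__(*args, **kwargs)

    async def gather_dep(self, *args, **kwargs):
        self.in_gather_dep.set()  # signal the test that a transfer started
        await self.block_gather_dep.wait()  # hold here until the test says go
        return await super().gather_dep(*args, **kwargs)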


@gen_cluster(client=True)
@@ -2873,8 +2907,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
        if w.address == a.tasks[f1.key].coming_from:
            break

-    coming_from.handle_remove_replicas([f1.key], "test")
-
+    coming_from.handle_stimulus(RemoveReplicasEvent(keys=[f1.key], stimulus_id="test"))
    await f2

    assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
@@ -3343,7 +3376,7 @@ async def test_log_invalid_transitions(c, s, a):
        await asyncio.sleep(0.01)
    ts = a.tasks[xkey]
    with pytest.raises(InvalidTransition):
-        a.transition(ts, "foo", stimulus_id="bar")
+        a._transition(ts, "foo", stimulus_id="bar")

    while not s.events["invalid-worker-transition"]:
        await asyncio.sleep(0.01)