
Refactor all event handlers #6410

Merged 12 commits on Jun 7, 2022
54 changes: 22 additions & 32 deletions distributed/scheduler.py
@@ -1088,7 +1088,7 @@ def __init__(self, key: str, run_spec: object):
self.has_lost_dependencies = False
self.host_restrictions = None # type: ignore
self.worker_restrictions = None # type: ignore
self.resource_restrictions = None # type: ignore
self.resource_restrictions = {}
self.loose_restrictions = False
self.actor = False
self.prefix = None # type: ignore
@@ -2670,14 +2670,12 @@ def valid_workers(self, ts: TaskState) -> set:  # set[WorkerState] | None
return s

def consume_resources(self, ts: TaskState, ws: WorkerState):
if ts.resource_restrictions:
for r, required in ts.resource_restrictions.items():
ws.used_resources[r] += required
for r, required in ts.resource_restrictions.items():
ws.used_resources[r] += required

def release_resources(self, ts: TaskState, ws: WorkerState):
if ts.resource_restrictions:
for r, required in ts.resource_restrictions.items():
ws.used_resources[r] -= required
for r, required in ts.resource_restrictions.items():
ws.used_resources[r] -= required

def coerce_hostname(self, host):
"""
@@ -7092,29 +7090,28 @@ def adaptive_target(self, target_duration=None):
to_close = self.workers_to_close()
return len(self.workers) - len(to_close)

def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str):
def request_acquire_replicas(
self, addr: str, keys: Iterable[str], *, stimulus_id: str
) -> None:
"""Asynchronously ask a worker to acquire a replica of the listed keys from
other workers. This is a fire-and-forget operation which offers no feedback for
success or failure, and is intended for housekeeping and not for computation.
"""
who_has = {}
for key in keys:
ts = self.tasks[key]
who_has[key] = {ws.address for ws in ts.who_has}

who_has = {key: [ws.address for ws in self.tasks[key].who_has] for key in keys}
Collaborator:

Suggested change:
-    who_has = {key: [ws.address for ws in self.tasks[key].who_has] for key in keys}
+    who_has = {key: {ws.address for ws in self.tasks[key].who_has} for key in keys}

This was changed from a list to a set, was that intentional?

Collaborator (author):

From a set to a list, actually. Yes, it's intentional: a set doesn't offer any useful feature here, and a list is fractionally faster. It also makes the output consistent with Scheduler.get_who_has.
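
For context, a minimal runnable sketch of the two shapes under discussion (SimpleNamespace stands in for the real TaskState/WorkerState objects; all names are illustrative):

    from types import SimpleNamespace

    ws1 = SimpleNamespace(address="tcp://10.0.0.1:40001")
    ws2 = SimpleNamespace(address="tcp://10.0.0.2:40002")
    tasks = {"x": SimpleNamespace(who_has={ws1, ws2})}

    # List form (this PR): no per-element hashing, and the same shape that
    # Scheduler.get_who_has returns.
    who_has_list = {k: [ws.address for ws in tasks[k].who_has] for k in ["x"]}

    # Set form (the suggestion): deduplicates, but who_has already contains
    # unique WorkerState objects, so the extra hashing buys nothing here.
    who_has_set = {k: {ws.address for ws in tasks[k].who_has} for k in ["x"]}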

if self.validate:
assert all(who_has.values())

self.stream_comms[addr].send(
{
"op": "acquire-replicas",
"keys": keys,
"who_has": who_has,
"stimulus_id": stimulus_id,
},
)
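
Since the call offers no feedback, a caller that needs confirmation has to poll worker state instead. A hedged usage sketch (test name and key are illustrative, not from this PR):

    import asyncio

    from distributed.utils_test import gen_cluster, inc

    @gen_cluster(client=True)
    async def test_acquire_replicas_sketch(c, s, a, b):
        x = c.submit(inc, 1, key="x", workers=[a.address])
        await x
        s.request_acquire_replicas(b.address, ["x"], stimulus_id="sketch")
        while "x" not in b.data:  # fire-and-forget: no ack, so poll
            await asyncio.sleep(0.01)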

def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str):
def request_remove_replicas(
self, addr: str, keys: list[str], *, stimulus_id: str
Collaborator:

Suggested change:
-        self, addr: str, keys: list[str], *, stimulus_id: str
+        self, addr: str, keys: Iterable[str], *, stimulus_id: str

Collaborator (author):

No: the variable is sent as-is through the msgpack RPC, so only a list, set, or tuple is accepted.
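
A minimal illustration with vanilla msgpack-python of why a lazy Iterable would not survive the trip (distributed layers its own serialization hooks on top, so treat the set case as an assumption):

    import msgpack

    msgpack.packb(["x", "y"])  # ok: a list packs as a msgpack array
    msgpack.packb(("x", "y"))  # ok: a tuple packs the same way

    try:
        msgpack.packb(k for k in ["x", "y"])  # a generator is not packable
    except TypeError as exc:
        print(exc)  # "can not serialize 'generator' object"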

) -> None:
"""Asynchronously ask a worker to discard its replica of the listed keys.
This must never be used to destroy the last replica of a key. This is a
fire-and-forget operation, intended for housekeeping and not for computation.
@@ -7125,15 +7122,14 @@ def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str):
to re-add itself to who_has. If the worker agrees to discard the task, there is
no feedback.
"""
ws: WorkerState = self.workers[addr]
validate = self.validate
ws = self.workers[addr]

# The scheduler immediately forgets about the replica and suggests the worker to
# drop it. The worker may refuse, at which point it will send back an add-keys
# message to reinstate it.
for key in keys:
ts: TaskState = self.tasks[key]
if validate:
ts = self.tasks[key]
if self.validate:
# Do not destroy the last copy
assert len(ts.who_has) > 1
self.remove_replica(ts, ws)
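
Schematically, the worker-side half of this protocol might look as follows. This is a hypothetical sketch: on_remove_replicas, still_needs, and send_to_scheduler are illustrative names, not the actual Worker API.

    def on_remove_replicas(worker, keys, stimulus_id):
        rejected = []
        for key in keys:
            ts = worker.tasks.get(key)
            if ts is not None and worker.still_needs(ts):  # hypothetical check
                rejected.append(key)  # refuse: keep the replica
            else:
                worker.data.pop(key, None)  # agree: drop silently, no feedback
        if rejected:
            # Re-adds this worker to the scheduler's who_has for these keys
            worker.send_to_scheduler(
                {"op": "add-keys", "keys": rejected, "stimulus_id": stimulus_id}
            )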
@@ -7314,22 +7310,16 @@ def _task_to_msg(
dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies
},
"nbytes": {dts.key: dts.nbytes for dts in ts.dependencies},
"run_spec": ts.run_spec,
"resource_restrictions": ts.resource_restrictions,
"actor": ts.actor,
"annotations": ts.annotations,
}
if state.validate:
assert all(msg["who_has"].values())

if ts.resource_restrictions:
msg["resource_restrictions"] = ts.resource_restrictions
if ts.actor:
msg["actor"] = True

if isinstance(ts.run_spec, dict):
msg.update(ts.run_spec)
else:
msg["task"] = ts.run_spec

if ts.annotations:
msg["annotations"] = ts.annotations
if isinstance(msg["run_spec"], dict):
assert set(msg["run_spec"]).issubset({"function", "args", "kwargs"})
assert msg["run_spec"].get("function")

return msg

6 changes: 4 additions & 2 deletions distributed/tests/test_failed_workers.py
@@ -24,7 +24,7 @@
slowadd,
slowinc,
)
from distributed.worker_state_machine import TaskState
from distributed.worker_state_machine import FreeKeysEvent, TaskState

pytestmark = pytest.mark.ci1

Expand Down Expand Up @@ -425,7 +425,9 @@ def sink(*args):
# artificially, without notifying the scheduler.
# This can only succeed if B handles the missing data properly by
# removing A from the known sources of keys
a.handle_free_keys(keys=["f1"], stimulus_id="Am I evil?") # Yes, I am!
a.handle_stimulus(
FreeKeysEvent(keys=["f1"], stimulus_id="Am I evil?")
) # Yes, I am!
result_fut = c.submit(sink, futures, workers=x.address)

await result_fut
3 changes: 2 additions & 1 deletion distributed/tests/test_steal.py
@@ -29,6 +29,7 @@
slowidentity,
slowinc,
)
from distributed.worker_state_machine import StealRequestEvent

pytestmark = pytest.mark.ci1

@@ -867,7 +868,7 @@ async def test_dont_steal_already_released(c, s, a, b):
while key in a.tasks and a.tasks[key].state != "released":
await asyncio.sleep(0.05)

a.handle_steal_request(key=key, stimulus_id="test")
a.handle_stimulus(StealRequestEvent(key=key, stimulus_id="test"))
assert len(a.batched_stream.buffer) == 1
msg = a.batched_stream.buffer[0]
assert msg["op"] == "steal-response"
2 changes: 1 addition & 1 deletion distributed/tests/test_utils_test.py
@@ -656,7 +656,7 @@ async def test_log_invalid_transitions(c, s, a):
await asyncio.sleep(0.01)
ts = a.tasks[xkey]
with pytest.raises(InvalidTransition):
a.transition(ts, "foo", stimulus_id="bar")
a._transition(ts, "foo", stimulus_id="bar")

while not s.events["invalid-worker-transition"]:
await asyncio.sleep(0.01)
145 changes: 89 additions & 56 deletions distributed/tests/test_worker.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import gc
import importlib
import logging
import os
@@ -72,6 +73,15 @@
error_message,
logger,
)
from distributed.worker_state_machine import (
AcquireReplicasEvent,
ComputeTaskEvent,
ExecuteFailureEvent,
ExecuteSuccessEvent,
RemoveReplicasEvent,
SerializedTask,
StealRequestEvent,
)

pytestmark = pytest.mark.ci1

@@ -1851,31 +1861,67 @@ async def test_story(c, s, w):

@gen_cluster(client=True, nthreads=[("", 1)])
async def test_stimulus_story(c, s, a):
# Test that substrings aren't matched by stimulus_story()
f = c.submit(inc, 1, key="f")
f1 = c.submit(lambda: "foo", key="f1")
f2 = c.submit(inc, f1, key="f2") # This will fail
await wait([f, f1, f2])

story = a.stimulus_story("f1", "f2")
assert len(story) == 4

assert isinstance(story[0], ComputeTaskEvent)
assert story[0].key == "f1"
assert story[0].run_spec == SerializedTask(task=None) # Not logged

assert isinstance(story[1], ExecuteSuccessEvent)
assert story[1].key == "f1"
assert story[1].value is None # Not logged
assert story[1].handled >= story[0].handled

assert isinstance(story[2], ComputeTaskEvent)
assert story[2].key == "f2"
assert story[2].who_has == {"f1": (a.address,)}
assert story[2].run_spec == SerializedTask(task=None) # Not logged
assert story[2].handled >= story[1].handled

assert isinstance(story[3], ExecuteFailureEvent)
assert story[3].key == "f2"
assert story[3].handled >= story[2].handled


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_worker_descopes_data(c, s, a):
"""Test that data is released on the worker:
1. when it's the output of a successful task
2. when it's the input of a failed task
3. when it's a local variable in the frame of a failed task
4. when it's embedded in the exception of a failed task
"""

class C:
pass
instances = weakref.WeakSet()

f = c.submit(C, key="f") # Test that substrings aren't matched by story()
f2 = c.submit(inc, 2, key="f2")
f3 = c.submit(inc, 3, key="f3")
await wait([f, f2, f3])
def __init__(self):
C.instances.add(self)

# Test that ExecuteSuccessEvent.value is not stored in the event log
assert isinstance(a.data["f"], C)
ref = weakref.ref(a.data["f"])
del f
while "f" in a.data:
await asyncio.sleep(0.01)
with profile.lock:
assert ref() is None
def f(x):
y = C()
raise Exception(x, y)

f1 = c.submit(C, key="f1")
f2 = c.submit(f, f1, key="f2")
await wait([f2])

story = a.stimulus_story("f", "f2")
assert {ev.key for ev in story} == {"f", "f2"}
assert {ev.type for ev in story} == {C, int}
assert type(a.data["f1"]) is C

prev_handled = story[0].handled
for ev in story[1:]:
assert ev.handled >= prev_handled
prev_handled = ev.handled
del f1
del f2
while a.data:
await asyncio.sleep(0.01)
with profile.lock:
gc.collect()
assert not C.instances


@gen_cluster(client=True)
@@ -2570,7 +2616,7 @@ def __call__(self, *args, **kwargs):
await asyncio.sleep(0)

ts = s.tasks[fut.key]
a.handle_steal_request(fut.key, stimulus_id="test")
a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test"))
stealing_ext.scheduler.send_task_to_worker(b.address, ts)

fut2 = c.submit(inc, fut, workers=[a.address])
@@ -2681,41 +2727,29 @@ async def test_acquire_replicas_many(c, s, *workers):
await asyncio.sleep(0.001)


@pytest.mark.slow
@gen_cluster(client=True, Worker=Nanny)
async def test_acquire_replicas_already_in_flight(c, s, *nannies):
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_acquire_replicas_already_in_flight(c, s, a):
"""Trying to acquire a replica that is already in flight is a no-op"""
async with BlockedGatherDep(s.address) as b:
x = c.submit(inc, 1, workers=[a.address], key="x")
y = c.submit(inc, x, workers=[b.address], key="y")
await b.in_gather_dep.wait()
assert b.tasks["x"].state == "flight"

class SlowToFly:
def __getstate__(self):
sleep(0.9)
return {}

a, b = s.workers
x = c.submit(SlowToFly, workers=[a], key="x")
await wait(x)
y = c.submit(lambda x: 123, x, workers=[b], key="y")
await asyncio.sleep(0.3)
s.request_acquire_replicas(b, [x.key], stimulus_id=f"test-{time()}")
assert await y == 123
b.handle_stimulus(
AcquireReplicasEvent(who_has={"x": a.address}, stimulus_id="test")
)
assert b.tasks["x"].state == "flight"
b.block_gather_dep.set()
assert await y == 3

story = await c.run(lambda dask_worker: dask_worker.story("x"))
assert_story(
story[b],
[
("x", "ensure-task-exists", "released"),
("x", "released", "fetch", "fetch", {}),
("gather-dependencies", a, {"x"}),
("x", "fetch", "flight", "flight", {}),
("request-dep", a, {"x"}),
("x", "ensure-task-exists", "flight"),
("x", "flight", "fetch", "flight", {}),
("receive-dep", a, {"x"}),
("x", "put-in-memory"),
("x", "flight", "memory", "memory", {"y": "ready"}),
],
strict=True,
)
assert_story(
b.story("x"),
[
("x", "fetch", "flight", "flight", {}),
("x", "flight", "fetch", "flight", {}),
],
)


@gen_cluster(client=True)
@@ -2873,8 +2907,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
if w.address == a.tasks[f1.key].coming_from:
break

coming_from.handle_remove_replicas([f1.key], "test")

coming_from.handle_stimulus(RemoveReplicasEvent(keys=[f1.key], stimulus_id="test"))
await f2

assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
@@ -3343,7 +3376,7 @@ async def test_log_invalid_transitions(c, s, a):
await asyncio.sleep(0.01)
ts = a.tasks[xkey]
with pytest.raises(InvalidTransition):
a.transition(ts, "foo", stimulus_id="bar")
a._transition(ts, "foo", stimulus_id="bar")

while not s.events["invalid-worker-transition"]:
await asyncio.sleep(0.01)