Skip to content

Commit

Permalink
Refactor all event handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 22, 2022
1 parent 9bb999d commit d16cf26
Show file tree
Hide file tree
Showing 6 changed files with 426 additions and 232 deletions.
54 changes: 22 additions & 32 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ def __init__(self, key: str, run_spec: object):
self.has_lost_dependencies = False
self.host_restrictions = None # type: ignore
self.worker_restrictions = None # type: ignore
self.resource_restrictions = None # type: ignore
self.resource_restrictions = {}
self.loose_restrictions = False
self.actor = False
self.prefix = None # type: ignore
Expand Down Expand Up @@ -2670,14 +2670,12 @@ def valid_workers(self, ts: TaskState) -> set: # set[WorkerState] | None
return s

def consume_resources(self, ts: TaskState, ws: WorkerState):
if ts.resource_restrictions:
for r, required in ts.resource_restrictions.items():
ws.used_resources[r] += required
for r, required in ts.resource_restrictions.items():
ws.used_resources[r] += required

def release_resources(self, ts: TaskState, ws: WorkerState):
if ts.resource_restrictions:
for r, required in ts.resource_restrictions.items():
ws.used_resources[r] -= required
for r, required in ts.resource_restrictions.items():
ws.used_resources[r] -= required

def coerce_hostname(self, host):
"""
Expand Down Expand Up @@ -7060,29 +7058,28 @@ def adaptive_target(self, target_duration=None):
to_close = self.workers_to_close()
return len(self.workers) - len(to_close)

def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str):
def request_acquire_replicas(
self, addr: str, keys: Iterable[str], *, stimulus_id: str
) -> None:
"""Asynchronously ask a worker to acquire a replica of the listed keys from
other workers. This is a fire-and-forget operation which offers no feedback for
success or failure, and is intended for housekeeping and not for computation.
"""
who_has = {}
for key in keys:
ts = self.tasks[key]
who_has[key] = {ws.address for ws in ts.who_has}

who_has = {key: [ws.address for ws in self.tasks[key].who_has] for key in keys}
if self.validate:
assert all(who_has.values())

self.stream_comms[addr].send(
{
"op": "acquire-replicas",
"keys": keys,
"who_has": who_has,
"stimulus_id": stimulus_id,
},
)

def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str):
def request_remove_replicas(
self, addr: str, keys: list[str], *, stimulus_id: str
) -> None:
"""Asynchronously ask a worker to discard its replica of the listed keys.
This must never be used to destroy the last replica of a key. This is a
fire-and-forget operation, intended for housekeeping and not for computation.
Expand All @@ -7093,15 +7090,14 @@ def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str):
to re-add itself to who_has. If the worker agrees to discard the task, there is
no feedback.
"""
ws: WorkerState = self.workers[addr]
validate = self.validate
ws = self.workers[addr]

# The scheduler immediately forgets about the replica and suggests the worker to
# drop it. The worker may refuse, at which point it will send back an add-keys
# message to reinstate it.
for key in keys:
ts: TaskState = self.tasks[key]
if validate:
ts = self.tasks[key]
if self.validate:
# Do not destroy the last copy
assert len(ts.who_has) > 1
self.remove_replica(ts, ws)
Expand Down Expand Up @@ -7282,22 +7278,16 @@ def _task_to_msg(
dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies
},
"nbytes": {dts.key: dts.nbytes for dts in ts.dependencies},
"run_spec": ts.run_spec,
"resource_restrictions": ts.resource_restrictions,
"actor": ts.actor,
"annotations": ts.annotations,
}
if state.validate:
assert all(msg["who_has"].values())

if ts.resource_restrictions:
msg["resource_restrictions"] = ts.resource_restrictions
if ts.actor:
msg["actor"] = True

if isinstance(ts.run_spec, dict):
msg.update(ts.run_spec)
else:
msg["task"] = ts.run_spec

if ts.annotations:
msg["annotations"] = ts.annotations
if isinstance(msg["run_spec"], dict):
assert set(msg["run_spec"]).issubset({"function", "args", "kwargs"})
assert msg["run_spec"].get("function")

return msg

Expand Down
3 changes: 2 additions & 1 deletion distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
slowidentity,
slowinc,
)
from distributed.worker_state_machine import StealRequestEvent

pytestmark = pytest.mark.ci1

Expand Down Expand Up @@ -868,7 +869,7 @@ async def test_dont_steal_already_released(c, s, a, b):
while key in a.tasks and a.tasks[key].state != "released":
await asyncio.sleep(0.05)

a.handle_steal_request(key=key, stimulus_id="test")
a.handle_stimulus(StealRequestEvent(key=key, stimulus_id="test"))
assert len(a.batched_stream.buffer) == 1
msg = a.batched_stream.buffer[0]
assert msg["op"] == "steal-response"
Expand Down
91 changes: 68 additions & 23 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import gc
import importlib
import logging
import os
Expand Down Expand Up @@ -70,6 +71,14 @@
error_message,
logger,
)
from distributed.worker_state_machine import (
ComputeTaskEvent,
ExecuteFailureEvent,
ExecuteSuccessEvent,
RemoveReplicasEvent,
SerializedTask,
StealRequestEvent,
)

pytestmark = pytest.mark.ci1

Expand Down Expand Up @@ -1837,31 +1846,68 @@ async def test_story(c, s, w):

@gen_cluster(client=True, nthreads=[("", 1)])
async def test_stimulus_story(c, s, a):
    """stimulus_story() must return only the events of the requested keys,
    in the order in which they were handled, without the bulky payloads.
    """
    # Key "f" is a substring of the two requested keys and must not be matched
    f = c.submit(inc, 1, key="f")
    f1 = c.submit(lambda: "foo", key="f1")
    f2 = c.submit(inc, f1, key="f2")  # This will fail
    await wait([f, f1, f2])

    story = a.stimulus_story("f1", "f2")
    assert len(story) == 4
    compute_f1, success_f1, compute_f2, failure_f2 = story

    # f1 is computed successfully...
    assert isinstance(compute_f1, ComputeTaskEvent)
    assert compute_f1.key == "f1"
    assert compute_f1.run_spec == SerializedTask(task=None)  # Not logged

    assert isinstance(success_f1, ExecuteSuccessEvent)
    assert success_f1.key == "f1"
    assert success_f1.value is None  # Not logged
    assert compute_f1.handled <= success_f1.handled

    # ...then f2, which depends on f1, is computed and fails
    assert isinstance(compute_f2, ComputeTaskEvent)
    assert compute_f2.key == "f2"
    assert compute_f2.who_has == {"f1": (a.address,)}
    assert compute_f2.run_spec == SerializedTask(task=None)  # Not logged
    assert success_f1.handled <= compute_f2.handled

    assert isinstance(failure_f2, ExecuteFailureEvent)
    assert failure_f2.key == "f2"
    assert compute_f2.handled <= failure_f2.handled


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_worker_descopes_data(c, s, a):
"""Test that data is released on the worker:
1. when it's the output of a successful task
2. when it's the input of a failed task
3. when it's a local variable in the frame of a failed task
4. when it's embedded in the exception of a failed task
"""

class C:
pass
instances = weakref.WeakSet()

f = c.submit(C, key="f") # Test that substrings aren't matched by story()
f2 = c.submit(inc, 2, key="f2")
f3 = c.submit(inc, 3, key="f3")
await wait([f, f2, f3])
def __init__(self):
C.instances.add(self)

# Test that ExecuteSuccessEvent.value is not stored in the event log
assert isinstance(a.data["f"], C)
ref = weakref.ref(a.data["f"])
del f
while "f" in a.data:
await asyncio.sleep(0.01)
wait_profiler()
assert ref() is None
def f(x):
y = C()
raise Exception(x, y)

story = a.stimulus_story("f", "f2")
assert {ev.key for ev in story} == {"f", "f2"}
assert {ev.type for ev in story} == {C, int}
f1 = c.submit(C, key="f1")
f2 = c.submit(f, f1, key="f2")
await wait([f2])

prev_handled = story[0].handled
for ev in story[1:]:
assert ev.handled >= prev_handled
prev_handled = ev.handled
assert type(a.data["f1"]) is C

del f1
del f2
while a.data:
await asyncio.sleep(0.01)

gc.collect()
wait_profiler()
assert not C.instances


@gen_cluster(client=True)
Expand Down Expand Up @@ -2556,7 +2602,7 @@ def __call__(self, *args, **kwargs):
await asyncio.sleep(0)

ts = s.tasks[fut.key]
a.handle_steal_request(fut.key, stimulus_id="test")
a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test"))
stealing_ext.scheduler.send_task_to_worker(b.address, ts)

fut2 = c.submit(inc, fut, workers=[a.address])
Expand Down Expand Up @@ -2916,8 +2962,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
if w.address == a.tasks[f1.key].coming_from:
break

coming_from.handle_remove_replicas([f1.key], "test")

coming_from.handle_stimulus(RemoveReplicasEvent(keys=[f1.key], stimulus_id="test"))
await f2

assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
Expand Down
67 changes: 67 additions & 0 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@
from distributed.utils import recursive_to_dict
from distributed.utils_test import _LockedCommPool, assert_story, gen_cluster, inc
from distributed.worker_state_machine import (
ComputeTaskEvent,
ExecuteFailureEvent,
ExecuteSuccessEvent,
Instruction,
ReleaseWorkerDataMsg,
RescheduleEvent,
RescheduleMsg,
SendMessageToScheduler,
SerializedTask,
StateMachineEvent,
TaskState,
UniqueTaskHeap,
UpdateDataEvent,
merge_recs_instructions,
)

Expand Down Expand Up @@ -167,6 +170,70 @@ def test_event_to_dict():
assert ev3 == ev


def test_computetask_to_dict():
    """ComputeTaskEvent.run_spec may be very large, so it is blanked out in the log.

    Also check that the loggable copy survives a round trip through
    ``recursive_to_dict`` and ``StateMachineEvent.from_dict``.
    """
    event = ComputeTaskEvent(
        key="x",
        who_has={"y": ["w1"]},
        nbytes={"y": 123},
        priority=(0,),
        duration=123.45,
        # A plain dict is automatically converted to SerializedTask on init
        run_spec={"function": b"blob", "args": b"blob"},
        resource_restrictions={},
        actor=False,
        annotations={},
        stimulus_id="test",
    )
    full_spec = SerializedTask(function=b"blob", args=b"blob")
    blank_spec = SerializedTask(task=None)
    assert event.run_spec == full_spec

    logged = event.to_loggable(handled=11.22)
    assert logged.handled == 11.22
    assert logged.run_spec == blank_spec
    # The original event must be left untouched
    assert event.run_spec == full_spec

    dumped = recursive_to_dict(logged)
    expected = {
        "cls": "ComputeTaskEvent",
        "key": "x",
        "who_has": {"y": ["w1"]},
        "nbytes": {"y": 123},
        "priority": [0],
        "run_spec": [None, None, None, None],
        "duration": 123.45,
        "resource_restrictions": {},
        "actor": False,
        "annotations": {},
        "stimulus_id": "test",
        "handled": 11.22,
    }
    assert dumped == expected

    restored = StateMachineEvent.from_dict(dumped)
    assert isinstance(restored, ComputeTaskEvent)
    assert restored.run_spec == blank_spec
    # The list is automatically converted back to a tuple
    assert restored.priority == (0,)


def test_updatedata_to_dict():
    """UpdateDataEvent.data may be very large, so its values are blanked out in the log.

    Also check that the loggable copy survives a round trip through
    ``recursive_to_dict`` and ``StateMachineEvent.from_dict``.
    """
    event = UpdateDataEvent(
        data={"x": "foo", "y": "bar"},
        report=True,
        stimulus_id="test",
    )

    logged = event.to_loggable(handled=11.22)
    assert logged.handled == 11.22
    # Keys are retained; values are replaced with None
    assert logged.data == {"x": None, "y": None}

    dumped = recursive_to_dict(logged)
    expected = {
        "cls": "UpdateDataEvent",
        "data": {"x": None, "y": None},
        "report": True,
        "stimulus_id": "test",
        "handled": 11.22,
    }
    assert dumped == expected

    restored = StateMachineEvent.from_dict(dumped)
    assert isinstance(restored, UpdateDataEvent)
    assert restored.data == {"x": None, "y": None}


def test_executesuccess_to_dict():
"""The potentially very large ExecuteSuccessEvent.value is not stored in the log"""
ev = ExecuteSuccessEvent(
Expand Down
Loading

0 comments on commit d16cf26

Please sign in to comment.