diff --git a/distributed/collections.py b/distributed/collections.py
index 4aef7d555e..992b4582eb 100644
--- a/distributed/collections.py
+++ b/distributed/collections.py
@@ -1,12 +1,18 @@
 from __future__ import annotations

+import dataclasses
 import heapq
 import itertools
 import weakref
 from collections import OrderedDict, UserDict
 from collections.abc import Callable, Hashable, Iterator
-from typing import MutableSet  # TODO move to collections.abc (requires Python >=3.9)
-from typing import Any, TypeVar, cast
+from typing import (  # TODO move to collections.abc (requires Python >=3.9)
+    Any,
+    Container,
+    MutableSet,
+    TypeVar,
+    cast,
+)

 T = TypeVar("T", bound=Hashable)
@@ -199,3 +205,54 @@ def clear(self) -> None:
         self._data.clear()
         self._heap.clear()
         self._sorted = True
+
+
+# NOTE: only used in Scheduler; if work stealing is ever removed,
+# this could be moved to `scheduler.py`.
+@dataclasses.dataclass
+class Occupancy:
+    cpu: float
+    network: float
+
+    def __add__(self, other: Any) -> Occupancy:
+        if isinstance(other, type(self)):
+            return type(self)(self.cpu + other.cpu, self.network + other.network)
+        return NotImplemented
+
+    def __iadd__(self, other: Any) -> Occupancy:
+        if isinstance(other, type(self)):
+            self.cpu += other.cpu
+            self.network += other.network
+            return self
+        return NotImplemented
+
+    def __sub__(self, other: Any) -> Occupancy:
+        if isinstance(other, type(self)):
+            return type(self)(self.cpu - other.cpu, self.network - other.network)
+        return NotImplemented
+
+    def __isub__(self, other: Any) -> Occupancy:
+        if isinstance(other, type(self)):
+            self.cpu -= other.cpu
+            self.network -= other.network
+            return self
+        return NotImplemented
+
+    def __bool__(self) -> bool:
+        return self.cpu != 0 or self.network != 0
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, type(self)):
+            return self.cpu == other.cpu and self.network == other.network
+        return NotImplemented
+
+    def clear(self) -> None:
+        self.cpu = 0.0
+        self.network = 0.0
+
+    def _to_dict(self, *, exclude: Container[str] = ()) -> dict[str, float]:
+        return {"cpu": self.cpu, "network": self.network}
+
+    @property
+    def total(self) -> float:
+        return self.cpu + self.network
diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py
index 01ca786cef..d37c153a4e 100644
--- a/distributed/dashboard/components/scheduler.py
+++ b/distributed/dashboard/components/scheduler.py
@@ -163,7 +163,8 @@ def update(self):
         workers = self.scheduler.workers.values()

         y = list(range(len(workers)))
-        occupancy = [ws.occupancy for ws in workers]
+        # TODO split chart by cpu vs network
+        occupancy = [ws.occupancy.total for ws in workers]
         ms = [occ * 1000 for occ in occupancy]
         x = [occ / 500 for occ in occupancy]
         total = sum(occupancy)
diff --git a/distributed/http/templates/worker-table.html b/distributed/http/templates/worker-table.html
index 87512ee386..52eeb9c3f8 100644
--- a/distributed/http/templates/worker-table.html
+++ b/distributed/http/templates/worker-table.html
@@ -5,7 +5,8 @@
     Cores
     Memory
     Memory use
-    Occupancy
+    Network Occupancy
+    CPU Occupancy
     Processing
     In-memory
     Services
@@ -19,7 +20,8 @@
     {{ ws.nthreads }}
     {{ format_bytes(ws.memory_limit) if ws.memory_limit is not None else "" }}
-    {{ format_time(ws.occupancy) }}
+    {{ format_time(ws.occupancy.network) }}
+    {{ format_time(ws.occupancy.cpu) }}
     {{ len(ws.processing) }}
     {{ len(ws.has_what) }}
     {% if 'dashboard' in ws.services %}
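
The Occupancy class added above is a small value type: arithmetic is elementwise, truthiness means "either component is non-zero", and `total` collapses it back to the single number that the dashboard and several scheduling heuristics still consume. A quick illustration of those semantics (doctest-style; the repr is the dataclass default, not part of the patch):

>>> from distributed.collections import Occupancy
>>> Occupancy(cpu=1.0, network=0.5) + Occupancy(cpu=0.5, network=0.5)
Occupancy(cpu=1.5, network=1.0)
>>> Occupancy(2.0, 1.0) - Occupancy(0.5, 1.0)
Occupancy(cpu=1.5, network=0.0)
>>> bool(Occupancy(0.0, 0.0)), Occupancy(3.0, 1.5).total
(False, 4.5)
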
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index f5b73c83dd..568b19a55a 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -66,7 +66,7 @@
 from distributed._stories import scheduler_story
 from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
 from distributed.batched import BatchedSend
-from distributed.collections import HeapSet
+from distributed.collections import HeapSet, Occupancy
 from distributed.comm import (
     Comm,
     CommClosedError,
@@ -425,10 +425,10 @@ class WorkerState:
     #: (i.e. the tasks in this worker's :attr:`~WorkerState.has_what`).
     nbytes: int

-    #: The total expected runtime, in seconds, of all tasks currently processing on this
-    #: worker. This is the sum of all the costs in this worker's
+    #: The total expected cost, in seconds, of all tasks currently processing on this
+    #: worker. This is the sum of all the Occupancies in this worker's
     #: :attr:`~WorkerState.processing` dictionary.
-    occupancy: float
+    occupancy: Occupancy

     #: Worker memory unknown to the worker, in bytes, which has been there for more than
     #: 30 seconds. See :class:`MemoryState`.
@@ -456,12 +456,12 @@ class WorkerState:
     _has_what: dict[TaskState, None]

     #: A dictionary of tasks that have been submitted to this worker. Each task state is
-    #: associated with the expected cost in seconds of running that task, summing both
-    #: the task's expected computation time and the expected communication time of its
-    #: result.
+    #: associated with the expected cost in seconds of running that task, covering both
+    #: the task's expected computation time and the expected serial communication time of
+    #: its dependencies.
     #:
-    #: If a task is already executing on the worker and the excecution time is twice the
-    #: learned average TaskGroup duration, this will be set to twice the current
+    #: If a task is already executing on the worker and the execution time exceeds twice
+    #: the learned average TaskGroup duration, the `cpu` time will be set to twice the current
     #: executing time. If the task is unknown, the default task duration is used instead
     #: of the TaskGroup average.
     #:
@@ -471,13 +471,13 @@
     #:
     #: All the tasks here are in the "processing" state.
     #: This attribute is kept in sync with :attr:`TaskState.processing_on`.
-    processing: dict[TaskState, float]
+    processing: dict[TaskState, Occupancy]

     #: Running tasks that invoked :func:`distributed.secede`
     long_running: set[TaskState]

     #: A dictionary of tasks that are currently being run on this worker.
-    #: Each task state is asssociated with the duration in seconds which the task has
+    #: Each task state is associated with the duration in seconds which the task has
     #: been running.
     executing: dict[TaskState, float]

@@ -528,7 +528,7 @@ def __init__(
         self.status = status
         self._hash = hash(self.server_id)
         self.nbytes = 0
-        self.occupancy = 0
+        self.occupancy = Occupancy(0.0, 0.0)
         self._memory_unmanaged_old = 0
         self._memory_unmanaged_history = deque()
         self.metrics = {}
@@ -1312,7 +1312,7 @@ class SchedulerState:
     #: Workers that are fully utilized. May include non-running workers.
     saturated: set[WorkerState]
     total_nthreads: int
-    total_occupancy: float
+    total_occupancy: Occupancy

     #: Cluster-wide resources. {resource name: {worker address: amount}}
     resources: dict[str, dict[str, float]]
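
These attributes maintain a simple invariant, which the validate_state() changes further down assert explicitly: a worker's occupancy is the elementwise sum of its processing estimates (excluding seceded long-running tasks, whose entries are cleared), and total_occupancy is the sum of all workers' occupancies. A plain-Python sketch of that relationship, for orientation only and not part of the patch:

from distributed.collections import Occupancy

def expected_worker_occupancy(ws) -> Occupancy:
    # ws: a scheduler-side WorkerState. Elementwise sum of per-task estimates;
    # seceded tasks are skipped because handle_long_running clears their entry.
    total = Occupancy(0.0, 0.0)
    for ts, cost in ws.processing.items():
        if ts not in ws.long_running:
            total += cost
    return total
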
@@ -1426,7 +1426,7 @@ def __init__(
         self.task_prefixes = {}
         self.task_metadata = {}
         self.total_nthreads = 0
-        self.total_occupancy = 0.0
+        self.total_occupancy = Occupancy(0.0, 0.0)
         self.unknown_durations = {}
         self.queued = queued
         self.unrunnable = unrunnable
@@ -2000,9 +2000,9 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
             wp_vals = cast("Sequence[WorkerState]", worker_pool.values())
             n_workers: int = len(wp_vals)
             if n_workers < 20:  # smart but linear in small case
-                ws = min(wp_vals, key=operator.attrgetter("occupancy"))
+                ws = min(wp_vals, key=lambda ws: ws.occupancy.total)
                 assert ws
-                if ws.occupancy == 0:
+                if not ws.occupancy:
                     # special case to use round-robin; linear search
                     # for next worker with zero occupancy (or just
                     # land back where we started).
@@ -2011,7 +2011,7 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
                     i: int
                     for i in range(n_workers):
                         wp_i = wp_vals[(i + start) % n_workers]
-                        if wp_i.occupancy == 0:
+                        if not wp_i.occupancy:
                             ws = wp_i
                             break
             else:  # dumb but fast in large case
@@ -2843,19 +2843,21 @@ def _set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> None:
         if ts in ws.long_running:
             return

-        exec_time: float = ws.executing.get(ts, 0)
-        duration: float = self.get_task_duration(ts)
-        total_duration: float
-        if exec_time > 2 * duration:
-            total_duration = 2 * exec_time
+        exec_time = ws.executing.get(ts, 0.0)
+        cpu = self.get_task_duration(ts)
+        if exec_time > 2 * cpu:
+            cpu = 2 * exec_time
+            # FIXME this matches existing behavior but is clearly bizarre
+            # https://github.com/dask/distributed/issues/7003
+            network = 0.0
         else:
-            comm: float = self.get_comm_cost(ts, ws)
-            total_duration = duration + comm
+            network = self.get_comm_cost(ts, ws)

-        old = ws.processing.get(ts, 0)
-        ws.processing[ts] = total_duration
-        self.total_occupancy += total_duration - old
-        ws.occupancy += total_duration - old
+        old = ws.processing.get(ts, Occupancy(0, 0))
+        ws.processing[ts] = new = Occupancy(cpu, network)
+        delta = new - old
+        self.total_occupancy += delta
+        ws.occupancy += delta

     def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
         """Update the status of the idle and saturated state
@@ -2883,11 +2885,11 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
         if self.total_nthreads == 0 or ws.status == Status.closed:
             return
         if occ < 0:
-            occ = ws.occupancy
+            occ = ws.occupancy.total

         nc: int = ws.nthreads
         p: int = len(ws.processing)
-        avg: float = self.total_occupancy / self.total_nthreads
+        avg: float = self.total_occupancy.total / self.total_nthreads

         idle = self.idle
         saturated = self.saturated
@@ -3046,7 +3048,8 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple:
                 nbytes = dts.get_nbytes()
                 comm_bytes += nbytes

-        stack_time: float = ws.occupancy / ws.nthreads
+        # FIXME use `occupancy.cpu` https://github.com/dask/distributed/issues/7003
+        stack_time: float = ws.occupancy.total / ws.nthreads
         start_time: float = stack_time + comm_bytes / self.bandwidth

         if ts.actor:
@@ -3088,7 +3091,7 @@ def remove_all_replicas(self, ts: TaskState):
     def _reevaluate_occupancy_worker(self, ws: WorkerState):
         """See reevaluate_occupancy"""
         ts: TaskState
-        old = ws.occupancy
+        old = ws.occupancy.total
         for ts in ws.processing:
             self._set_duration_estimate(ts, ws)

@@ -3096,7 +3099,8 @@
         steal = self.extensions.get("stealing")
         if steal is None:
             return
-        if ws.occupancy > old * 1.3 or old > ws.occupancy * 1.3:
+        current = ws.occupancy.total
+        if current > old * 1.3 or old > current * 1.3:
             for ts in ws.processing:
                 steal.recalculate_cost(ts)
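
To make the new per-task estimate concrete: with a learned average duration of 0.5 s and 2 MiB of dependencies to fetch at the default 100 MB/s bandwidth (illustrative numbers, not taken from the patch), _set_duration_estimate now records roughly

    from distributed.collections import Occupancy

    cpu = 0.5                           # what get_task_duration(ts) might return
    network = 2 * 2**20 / 100e6         # what get_comm_cost(ts, ws) might return: ~0.021 s
    estimate = Occupancy(cpu, network)  # previously a single float of cpu + network

and only the delta against the previous estimate is folded into ws.occupancy and total_occupancy, so the aggregates are updated incrementally rather than recomputed from scratch.
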
@@ -4185,7 +4189,7 @@ def update_graph(

         dependencies = dependencies or {}

-        if self.total_occupancy > 1e-9 and self.computations:
+        if self.total_occupancy.total > 1e-9 and self.computations:
             # Still working on something. Assign new tasks to same computation
             computation = self.computations[-1]
         else:
@@ -4922,19 +4926,26 @@ def validate_state(self, allow_overlap: bool = False) -> None:
             }
             assert a == b, (a, b)

-        actual_total_occupancy = 0.0
+        actual_total_occupancy = Occupancy(0, 0)
         for worker, ws in self.workers.items():
             ws_processing_total = sum(
-                cost for ts, cost in ws.processing.items() if ts not in ws.long_running
+                (
+                    cost
+                    for ts, cost in ws.processing.items()
+                    if ts not in ws.long_running
+                ),
+                start=Occupancy(0, 0),
             )
-            assert abs(ws_processing_total - ws.occupancy) < 1e-8, (
+            delta = ws_processing_total - ws.occupancy
+            assert abs(delta.cpu) < 1e-8 and abs(delta.network) < 1e-8, (
                 worker,
                 ws_processing_total,
                 ws.occupancy,
             )
             actual_total_occupancy += ws.occupancy

-        assert abs(actual_total_occupancy - self.total_occupancy) < 1e-8, (
+        delta = actual_total_occupancy - self.total_occupancy
+        assert abs(delta.cpu) < 1e-8 and abs(delta.network) < 1e-8, (
             actual_total_occupancy,
             self.total_occupancy,
         )
@@ -5131,7 +5142,7 @@ def handle_long_running(
         if key not in self.tasks:
             logger.debug("Skipping long_running since key %s was already released", key)
             return
-        ts = self.tasks[key]
+        ts: TaskState = self.tasks[key]
         steal = self.extensions.get("stealing")
         if steal is not None:
             steal.remove_key_from_stealable(ts)
@@ -5155,7 +5166,7 @@ def handle_long_running(
         # idleness detection. Idle workers are typically targeted for
         # downscaling but we should not downscale workers with long running
         # tasks
-        ws.processing[ts] = 0
+        ws.processing[ts].clear()
         ws.long_running.add(ts)

         self.check_idle_saturated(ws)
@@ -7594,10 +7605,12 @@ def adaptive_target(self, target_duration=None):

         # CPU
         # TODO consider any user-specified default task durations for queued tasks
-        queued_occupancy = len(self.queued) * self.UNKNOWN_TASK_DURATION
+        queued_occupancy: float = len(self.queued) * self.UNKNOWN_TASK_DURATION
+        # TODO: threads per worker
+        # TODO don't include network occupancy?
         cpu = math.ceil(
-            (self.total_occupancy + queued_occupancy) / target_duration
-        )  # TODO: threads per worker
+            (self.total_occupancy.total + queued_occupancy) / target_duration
+        )

         # Avoid a few long tasks from asking for many cores
         tasks_ready = len(self.queued)
@@ -7744,7 +7757,7 @@ def _exit_processing_common(
     ws.long_running.discard(ts)
     if not ws.processing:
         state.total_occupancy -= ws.occupancy
-        ws.occupancy = 0
+        ws.occupancy.clear()
     else:
         state.total_occupancy -= duration
         ws.occupancy -= duration
diff --git a/distributed/stealing.py b/distributed/stealing.py
index f1539c886e..4b29df3f42 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -15,6 +15,7 @@
 import dask
 from dask.utils import parse_timedelta

+from distributed.collections import Occupancy
 from distributed.comm.addressing import get_address_host
 from distributed.core import CommClosedError, Status
 from distributed.diagnostics.plugin import SchedulerPlugin
@@ -57,8 +58,8 @@ class InFlightInfo(TypedDict):
     victim: WorkerState
     thief: WorkerState
-    victim_duration: float
-    thief_duration: float
+    victim_duration: Occupancy
+    thief_duration: Occupancy
     stimulus_id: str
@@ -79,7 +80,7 @@ class WorkStealing(SchedulerPlugin):
     # { task state: }
     in_flight: dict[TaskState, InFlightInfo]
     # { worker state: occupancy }
-    in_flight_occupancy: defaultdict[WorkerState, float]
+    in_flight_occupancy: defaultdict[WorkerState, Occupancy]
     _in_flight_event: asyncio.Event
     _request_counter: int
@@ -104,7 +105,7 @@ def __init__(self, scheduler: Scheduler):
         self.scheduler.events["stealing"] = deque(maxlen=100000)
         self.count = 0
         self.in_flight = {}
-        self.in_flight_occupancy = defaultdict(lambda: 0)
+        self.in_flight_occupancy = defaultdict(lambda: Occupancy(0, 0))
         self._in_flight_event = asyncio.Event()
         self._request_counter = 0
         self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
@@ -232,21 +233,23 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, None]:
             return None, None

         if not ts.dependencies:  # no dependencies fast path
-            return 0, 0
+            return 0.0, 0

         assert ts.processing_on
         ws = ts.processing_on
-        compute_time = ws.processing[ts]
+        assert ws
+        occupancy = ws.processing[ts]

-        if not compute_time:
+        if not occupancy:
             # occupancy/ws.processing[ts] is only allowed to be zero for
             # long running tasks which cannot be stolen
             assert ts in ws.long_running
             return None, None

         nbytes = ts.get_nbytes_deps()
-        transfer_time = nbytes / self.scheduler.bandwidth + LATENCY
-        cost_multiplier = transfer_time / compute_time
+        transfer_time: float = nbytes / self.scheduler.bandwidth + LATENCY
+        # FIXME don't use `occupancy.total` https://github.com/dask/distributed/issues/7003
+        cost_multiplier = transfer_time / occupancy.total

         level = int(round(log2(cost_multiplier) + 6))
         if level < 1:
@@ -280,9 +283,10 @@ def move_task_request(

            victim_duration = victim.processing[ts]

-            thief_duration = self.scheduler.get_task_duration(
-                ts
-            ) + self.scheduler.get_comm_cost(ts, thief)
+            thief_duration = Occupancy(
+                self.scheduler.get_task_duration(ts),
+                self.scheduler.get_comm_cost(ts, thief),
+            )

             self.scheduler.stream_comms[victim.address].send(
                 {"op": "steal-request", "key": key, "stimulus_id": stimulus_id}
@@ -374,7 +378,7 @@ async def move_task_confirm(
                 self.scheduler.total_occupancy -= duration
                 if not victim.processing:
                     self.scheduler.total_occupancy -= victim.occupancy
-                    victim.occupancy = 0
+                    victim.occupancy.clear()
                 thief.processing[ts] = d["thief_duration"]
                 thief.occupancy += d["thief_duration"]
                 self.scheduler.total_occupancy += d["thief_duration"]
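
steal_time_ratio above still reduces the per-task Occupancy to a single number (see its FIXME). A worked example with illustrative values, assuming the default 100 MB/s bandwidth and a ~10 ms LATENCY constant: 10 MiB of dependencies against a 0.5 s total estimate give

    from math import log2

    nbytes = 10 * 2**20
    transfer_time = nbytes / 100e6 + 0.010          # ~0.115 s
    cost_multiplier = transfer_time / 0.5           # transfer_time / occupancy.total
    level = int(round(log2(cost_multiplier) + 6))   # -> 4; lower levels are stolen first
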
@@ -401,27 +405,32 @@ def balance(self) -> None:
         start = time()

         def combined_occupancy(ws: WorkerState) -> float:
-            return ws.occupancy + self.in_flight_occupancy[ws]
+            return ws.occupancy.total + self.in_flight_occupancy[ws].total

         def maybe_move_task(
             level: int,
             ts: TaskState,
             victim: WorkerState,
             thief: WorkerState,
-            duration: float,
+            duration: Occupancy,
             cost_multiplier: float,
         ) -> None:
+            # TODO calculate separately for cpu vs network?
             occ_thief = combined_occupancy(thief)
             occ_victim = combined_occupancy(victim)

-            if occ_thief + cost_multiplier * duration <= occ_victim - duration / 2:
+            duration_total = duration.total
+            if (
+                occ_thief + cost_multiplier * duration_total
+                <= occ_victim - duration_total / 2
+            ):
                 self.move_task_request(ts, victim, thief)
                 log.append(
                     (
                         start,
                         level,
                         ts.key,
-                        duration,
+                        duration_total,
                         victim.address,
                         occ_victim,
                         thief.address,
diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index 3336fc8481..e140d05d86 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -1034,7 +1034,7 @@ def f(ev1, ev2, ev3, ev4):
     await ev1.wait()
     ts = a.state.tasks["x"]
     assert ts.state == "executing"
-    assert sum(ws.processing.values()) > 0
+    assert any(ws.processing.values())

     x.release()
     await wait_for_state("x", "cancelled", a)
@@ -1050,7 +1050,7 @@ def f(ev1, ev2, ev3, ev4):

     # Test that the scheduler receives a delayed {op: long-running}
     assert ws.processing
-    while sum(ws.processing.values()):
+    while any(ws.processing.values()):
         await asyncio.sleep(0.1)
     assert ws.processing
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 3121dcb6bd..848c74a527 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -70,6 +70,7 @@
     wait,
 )
 from distributed.cluster_dump import load_cluster_dump
+from distributed.collections import Occupancy
 from distributed.comm import CommClosedError
 from distributed.compatibility import LINUX, WINDOWS
 from distributed.core import Status
@@ -5109,8 +5110,11 @@ def long_running(lock, entered):
         await entered.wait()
         ts = s.tasks[f.key]
         ws = s.workers[a.address]
-        assert ws.occupancy == parse_timedelta(
-            dask.config.get("distributed.scheduler.unknown-task-duration")
+        assert ws.occupancy == Occupancy(
+            cpu=parse_timedelta(
+                dask.config.get("distributed.scheduler.unknown-task-duration"),
+            ),
+            network=0,
         )

         while ws.occupancy:
@@ -5118,12 +5122,12 @@ def long_running(lock, entered):
             await a.heartbeat()
         s._set_duration_estimate(ts, ws)
-        assert s.workers[a.address].occupancy == 0
-        assert s.total_occupancy == 0
-        assert ws.occupancy == 0
+        assert s.workers[a.address].occupancy == Occupancy(0, 0)
+        assert s.total_occupancy == Occupancy(0, 0)
+        assert ws.occupancy == Occupancy(0, 0)

         s._ongoing_background_tasks.call_soon(s.reevaluate_occupancy, 0)
-        assert s.workers[a.address].occupancy == 0
+        assert s.workers[a.address].occupancy == Occupancy(0, 0)

         await l.release()

         with (
@@ -5133,8 +5137,8 @@
         ):
             await f

-        assert s.total_occupancy == 0
-        assert ws.occupancy == 0
+        assert s.total_occupancy == Occupancy(0, 0)
+        assert ws.occupancy == Occupancy(0, 0)

         assert not ws.long_running
@@ -5176,14 +5180,16 @@ def long_running(lock, entered):
         if ordinary_task:
             # Should be exactly 0.5 but if for whatever reason this test runs slow,
             # some approximation may kick in increasing this number
-            assert s.total_occupancy >= 0.5
-            assert ws.occupancy >= 0.5
+            assert not s.total_occupancy.network
+            assert not ws.occupancy.network
+            assert s.total_occupancy.cpu >= 0.5
+            assert ws.occupancy.cpu >= 0.5

         await l2.release()
         await f2

         # In the end, everything should be reset
-        assert s.total_occupancy == 0
-        assert ws.occupancy == 0
+        assert not s.total_occupancy
+        assert not ws.occupancy
         assert not ws.long_running
diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py
index 066cf147a3..6df5a20613 100644
--- a/distributed/tests/test_collections.py
+++ b/distributed/tests/test_collections.py
@@ -7,7 +7,7 @@

 import pytest

-from distributed.collections import LRU, HeapSet
+from distributed.collections import LRU, HeapSet, Occupancy


 def test_lru():
@@ -339,3 +339,48 @@ def test_heapset_sort_duplicate():
     heap.add(c1)

     assert list(heap.sorted()) == [c1, c2]
+
+
+def test_occupancy():
+    ozero = Occupancy(0, 0)
+    assert not ozero
+    assert not ozero.total
+    assert ozero == ozero
+
+    o1_0 = Occupancy(1, 0)
+    assert o1_0
+    assert o1_0.total == 1
+    assert ozero != o1_0
+    assert o1_0 + ozero == o1_0
+
+    o0_1 = Occupancy(0, 1)
+    o1_1 = o0_1 + o1_0
+    assert o1_1.total == 2
+    assert o1_1 == o1_0 + o0_1
+
+    assert o1_1 - o0_1 == o1_0
+
+    mut = Occupancy(0, 0)
+    mut += ozero
+    assert not mut
+    mut += o1_0
+    assert mut == o1_0
+    mut += o1_0
+    assert mut == Occupancy(2, 0)
+
+    mut -= o0_1
+    assert mut == Occupancy(2, -1)
+    assert mut.total == 1
+
+    mut.clear()
+    assert mut == ozero
+
+    assert not mut == 1
+    with pytest.raises(TypeError):
+        mut + 1
+    with pytest.raises(TypeError):
+        mut - 1
+    with pytest.raises(TypeError):
+        mut += 1
+    with pytest.raises(TypeError):
+        mut -= 1
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 35fe85cdc7..9a768743e0 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -1423,12 +1423,13 @@ async def test_learn_occupancy(c, s, a, b):
         await asyncio.sleep(0.01)

     nproc = sum(ts.state == "processing" for ts in s.tasks.values())
-    assert nproc * 0.1 < s.total_occupancy < nproc * 0.4
+    assert not s.total_occupancy.network
+    assert nproc * 0.1 < s.total_occupancy.cpu < nproc * 0.4
     for w in [a, b]:
         ws = s.workers[w.address]
-        occ = ws.occupancy
+        assert not ws.occupancy.network
         proc = len(ws.processing)
-        assert proc * 0.1 < occ < proc * 0.4
+        assert proc * 0.1 < ws.occupancy.cpu < proc * 0.4


 @pytest.mark.slow
@@ -1440,7 +1441,8 @@ async def test_learn_occupancy_2(c, s, a, b):
         await asyncio.sleep(0.01)

     nproc = sum(ts.state == "processing" for ts in s.tasks.values())
-    assert nproc * 0.1 < s.total_occupancy < nproc * 0.4
+    assert not s.total_occupancy.network
+    assert nproc * 0.1 < s.total_occupancy.cpu < nproc * 0.4


 @gen_cluster(client=True)
@@ -1448,14 +1450,14 @@ async def test_occupancy_cleardown(c, s, a, b):
     s.validate = False

     # Inject excess values in s.occupancy
-    s.workers[a.address].occupancy = 2
-    s.total_occupancy += 2
+    s.workers[a.address].occupancy.cpu += 2
+    s.total_occupancy.cpu += 2
     futures = c.map(slowinc, range(100), delay=0.01)
     await wait(futures)

     # Verify that occupancy values have been zeroed out
-    assert abs(s.total_occupancy) < 0.01
-    assert all(ws.occupancy == 0 for ws in s.workers.values())
+    assert abs(s.total_occupancy.total) < 0.01
+    assert all(not ws.occupancy for ws in s.workers.values())


 @nodebug
@@ -1492,11 +1494,13 @@ async def test_learn_occupancy_multiple_workers(c, s, a, b):

     await wait(x)

-    assert not any(v == 0.5 for w in s.workers.values() for v in w.processing.values())
+    assert not any(
+        occ.cpu == 0.5 for w in s.workers.values() for occ in w.processing.values()
+    )


 @gen_cluster(client=True)
-async def test_include_communication_in_occupancy(c, s, a, b):
+async def test_occupancy_network(c, s, a, b):
     await c.submit(slowadd, 1, 2, delay=0)
     x = c.submit(operator.mul, b"0", int(s.bandwidth), workers=a.address)
     y = c.submit(operator.mul, b"1", int(s.bandwidth * 1.5), workers=b.address)
@@ -1507,7 +1511,9 @@ async def test_occupancy_network(c, s, a, b):

     ts = s.tasks[z.key]
     assert ts.processing_on == s.workers[b.address]
-    assert s.workers[b.address].processing[ts] > 1
+    occ = s.workers[b.address].processing[ts]
+    assert occ.network >= 1
+    assert occ.cpu > 0
     await wait(z)
     del z

diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 4a3751d6d9..cff30368f6 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -1129,7 +1129,7 @@ def block(x, event):

     del futs1

-    assert all(v == 0 for v in steal.in_flight_occupancy.values())
+    assert all(not v for v in steal.in_flight_occupancy.values())


 @gen_cluster(