diff --git a/distributed/collections.py b/distributed/collections.py
index 4aef7d555e..992b4582eb 100644
--- a/distributed/collections.py
+++ b/distributed/collections.py
@@ -1,12 +1,18 @@
from __future__ import annotations
+import dataclasses
import heapq
import itertools
import weakref
from collections import OrderedDict, UserDict
from collections.abc import Callable, Hashable, Iterator
-from typing import MutableSet # TODO move to collections.abc (requires Python >=3.9)
-from typing import Any, TypeVar, cast
+from typing import ( # TODO move to collections.abc (requires Python >=3.9)
+ Any,
+ Container,
+ MutableSet,
+ TypeVar,
+ cast,
+)
T = TypeVar("T", bound=Hashable)
@@ -199,3 +205,54 @@ def clear(self) -> None:
self._data.clear()
self._heap.clear()
self._sorted = True
+
+
+# NOTE: shared by `scheduler.py` and `stealing.py`; if work stealing is ever
+# removed, this could be moved to `scheduler.py`.
+@dataclasses.dataclass
+class Occupancy:
+    """Expected cost of work, in seconds, split into CPU and network time.
+
+    Supports ``+``, ``-``, ``+=`` and ``-=`` with other ``Occupancy`` instances
+    (component-wise). Any other operand yields ``NotImplemented``, which Python
+    turns into a ``TypeError``. Instances are mutable: ``+=``, ``-=`` and
+    :meth:`clear` modify the object in place.
+    """
+
+    #: Expected computation time, in seconds
+    cpu: float
+    #: Expected data transfer time, in seconds
+    network: float
+
+    def __add__(self, other: Any) -> Occupancy:
+        """Return a new instance holding the component-wise sum."""
+        if isinstance(other, type(self)):
+            return type(self)(self.cpu + other.cpu, self.network + other.network)
+        return NotImplemented
+
+    def __iadd__(self, other: Any) -> Occupancy:
+        """Add ``other`` to this instance in place and return ``self``."""
+        if isinstance(other, type(self)):
+            self.cpu += other.cpu
+            self.network += other.network
+            return self
+        return NotImplemented
+
+    def __sub__(self, other: Any) -> Occupancy:
+        """Return a new instance holding the component-wise difference."""
+        if isinstance(other, type(self)):
+            return type(self)(self.cpu - other.cpu, self.network - other.network)
+        return NotImplemented
+
+    def __isub__(self, other: Any) -> Occupancy:
+        """Subtract ``other`` from this instance in place and return ``self``."""
+        if isinstance(other, type(self)):
+            self.cpu -= other.cpu
+            self.network -= other.network
+            return self
+        return NotImplemented
+
+    def __bool__(self) -> bool:
+        """Return True unless both components are exactly zero.
+
+        This looks at each component, not at :attr:`total`: components of
+        opposite sign that cancel out still make the instance truthy.
+        """
+        return self.cpu != 0 or self.network != 0
+
+    def __eq__(self, other: Any) -> bool:
+        """Compare both components exactly; defer to ``other`` for foreign types."""
+        if isinstance(other, type(self)):
+            return self.cpu == other.cpu and self.network == other.network
+        return NotImplemented
+
+    def clear(self) -> None:
+        """Reset both components to zero, in place."""
+        self.cpu = 0.0
+        self.network = 0.0
+
+    def _to_dict(self, *, exclude: Container[str] = ()) -> dict[str, float]:
+        """Return the fields as a plain dictionary.
+
+        Parameters
+        ----------
+        exclude:
+            Field names to leave out of the result.
+        """
+        fields = (("cpu", self.cpu), ("network", self.network))
+        # Honour ``exclude`` instead of silently ignoring it
+        return {name: value for name, value in fields if name not in exclude}
+
+    @property
+    def total(self) -> float:
+        """Sum of the CPU and network components, in seconds."""
+        return self.cpu + self.network
diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py
index 01ca786cef..d37c153a4e 100644
--- a/distributed/dashboard/components/scheduler.py
+++ b/distributed/dashboard/components/scheduler.py
@@ -163,7 +163,8 @@ def update(self):
workers = self.scheduler.workers.values()
y = list(range(len(workers)))
- occupancy = [ws.occupancy for ws in workers]
+ # TODO split chart by cpu vs network
+ occupancy = [ws.occupancy.total for ws in workers]
ms = [occ * 1000 for occ in occupancy]
x = [occ / 500 for occ in occupancy]
total = sum(occupancy)
diff --git a/distributed/http/templates/worker-table.html b/distributed/http/templates/worker-table.html
index 87512ee386..52eeb9c3f8 100644
--- a/distributed/http/templates/worker-table.html
+++ b/distributed/http/templates/worker-table.html
@@ -5,7 +5,8 @@
Cores |
Memory |
Memory use |
- Occupancy |
+ Network Occupancy |
+ CPU Occupancy |
Processing |
In-memory |
Services |
@@ -19,7 +20,8 @@
{{ ws.nthreads }} |
{{ format_bytes(ws.memory_limit) if ws.memory_limit is not None else "" }} |
|
- {{ format_time(ws.occupancy) }} |
+ {{ format_time(ws.occupancy.network) }} |
+ {{ format_time(ws.occupancy.cpu) }} |
{{ len(ws.processing) }} |
{{ len(ws.has_what) }} |
{% if 'dashboard' in ws.services %}
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index f5b73c83dd..568b19a55a 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -66,7 +66,7 @@
from distributed._stories import scheduler_story
from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
from distributed.batched import BatchedSend
-from distributed.collections import HeapSet
+from distributed.collections import HeapSet, Occupancy
from distributed.comm import (
Comm,
CommClosedError,
@@ -425,10 +425,10 @@ class WorkerState:
#: (i.e. the tasks in this worker's :attr:`~WorkerState.has_what`).
nbytes: int
- #: The total expected runtime, in seconds, of all tasks currently processing on this
- #: worker. This is the sum of all the costs in this worker's
+ #: The total expected cost, in seconds, of all tasks currently processing on this
+ #: worker. This is the sum of all the Occupancies in this worker's
# :attr:`~WorkerState.processing` dictionary.
- occupancy: float
+ occupancy: Occupancy
#: Worker memory unknown to the worker, in bytes, which has been there for more than
#: 30 seconds. See :class:`MemoryState`.
@@ -456,12 +456,12 @@ class WorkerState:
_has_what: dict[TaskState, None]
#: A dictionary of tasks that have been submitted to this worker. Each task state is
- #: associated with the expected cost in seconds of running that task, summing both
- #: the task's expected computation time and the expected communication time of its
- #: result.
+ #: associated with the expected cost in seconds of running that task, split into
+ #: the task's expected computation time (`cpu`) and the expected serial
+ #: communication time of its dependencies (`network`).
#:
- #: If a task is already executing on the worker and the excecution time is twice the
- #: learned average TaskGroup duration, this will be set to twice the current
+ #: If a task is already executing on the worker and the execution time is twice the
+ #: learned average TaskGroup duration, the `cpu` time will be set to twice the current
#: executing time. If the task is unknown, the default task duration is used instead
#: of the TaskGroup average.
#:
@@ -471,13 +471,13 @@ class WorkerState:
#:
#: All the tasks here are in the "processing" state.
#: This attribute is kept in sync with :attr:`TaskState.processing_on`.
- processing: dict[TaskState, float]
+ processing: dict[TaskState, Occupancy]
#: Running tasks that invoked :func:`distributed.secede`
long_running: set[TaskState]
#: A dictionary of tasks that are currently being run on this worker.
- #: Each task state is asssociated with the duration in seconds which the task has
+ #: Each task state is associated with the duration in seconds which the task has
#: been running.
executing: dict[TaskState, float]
@@ -528,7 +528,7 @@ def __init__(
self.status = status
self._hash = hash(self.server_id)
self.nbytes = 0
- self.occupancy = 0
+ self.occupancy = Occupancy(0.0, 0.0)
self._memory_unmanaged_old = 0
self._memory_unmanaged_history = deque()
self.metrics = {}
@@ -1312,7 +1312,7 @@ class SchedulerState:
#: Workers that are fully utilized. May include non-running workers.
saturated: set[WorkerState]
total_nthreads: int
- total_occupancy: float
+ total_occupancy: Occupancy
#: Cluster-wide resources. {resource name: {worker address: amount}}
resources: dict[str, dict[str, float]]
@@ -1426,7 +1426,7 @@ def __init__(
self.task_prefixes = {}
self.task_metadata = {}
self.total_nthreads = 0
- self.total_occupancy = 0.0
+ self.total_occupancy = Occupancy(0.0, 0.0)
self.unknown_durations = {}
self.queued = queued
self.unrunnable = unrunnable
@@ -2000,9 +2000,9 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
wp_vals = cast("Sequence[WorkerState]", worker_pool.values())
n_workers: int = len(wp_vals)
if n_workers < 20: # smart but linear in small case
- ws = min(wp_vals, key=operator.attrgetter("occupancy"))
+ ws = min(wp_vals, key=lambda ws: ws.occupancy.total)
assert ws
- if ws.occupancy == 0:
+ if not ws.occupancy:
# special case to use round-robin; linear search
# for next worker with zero occupancy (or just
# land back where we started).
@@ -2011,7 +2011,7 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
i: int
for i in range(n_workers):
wp_i = wp_vals[(i + start) % n_workers]
- if wp_i.occupancy == 0:
+ if not wp_i.occupancy:
ws = wp_i
break
else: # dumb but fast in large case
@@ -2843,19 +2843,21 @@ def _set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> None:
if ts in ws.long_running:
return
- exec_time: float = ws.executing.get(ts, 0)
- duration: float = self.get_task_duration(ts)
- total_duration: float
- if exec_time > 2 * duration:
- total_duration = 2 * exec_time
+ exec_time = ws.executing.get(ts, 0.0)
+ cpu = self.get_task_duration(ts)
+ if exec_time > 2 * cpu:
+ cpu = 2 * exec_time
+ # FIXME this matches existing behavior but is clearly bizarre
+ # https://github.com/dask/distributed/issues/7003
+ network = 0.0
else:
- comm: float = self.get_comm_cost(ts, ws)
- total_duration = duration + comm
+ network = self.get_comm_cost(ts, ws)
- old = ws.processing.get(ts, 0)
- ws.processing[ts] = total_duration
- self.total_occupancy += total_duration - old
- ws.occupancy += total_duration - old
+ old = ws.processing.get(ts, Occupancy(0, 0))
+ ws.processing[ts] = new = Occupancy(cpu, network)
+ delta = new - old
+ self.total_occupancy += delta
+ ws.occupancy += delta
def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
"""Update the status of the idle and saturated state
@@ -2883,11 +2885,11 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
if self.total_nthreads == 0 or ws.status == Status.closed:
return
if occ < 0:
- occ = ws.occupancy
+ occ = ws.occupancy.total
nc: int = ws.nthreads
p: int = len(ws.processing)
- avg: float = self.total_occupancy / self.total_nthreads
+ avg: float = self.total_occupancy.total / self.total_nthreads
idle = self.idle
saturated = self.saturated
@@ -3046,7 +3048,8 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple:
nbytes = dts.get_nbytes()
comm_bytes += nbytes
- stack_time: float = ws.occupancy / ws.nthreads
+ # FIXME use `occupancy.cpu` https://github.com/dask/distributed/issues/7003
+ stack_time: float = ws.occupancy.total / ws.nthreads
start_time: float = stack_time + comm_bytes / self.bandwidth
if ts.actor:
@@ -3088,7 +3091,7 @@ def remove_all_replicas(self, ts: TaskState):
def _reevaluate_occupancy_worker(self, ws: WorkerState):
"""See reevaluate_occupancy"""
ts: TaskState
- old = ws.occupancy
+ old = ws.occupancy.total
for ts in ws.processing:
self._set_duration_estimate(ts, ws)
@@ -3096,7 +3099,8 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState):
steal = self.extensions.get("stealing")
if steal is None:
return
- if ws.occupancy > old * 1.3 or old > ws.occupancy * 1.3:
+ current = ws.occupancy.total
+ if current > old * 1.3 or old > current * 1.3:
for ts in ws.processing:
steal.recalculate_cost(ts)
@@ -4185,7 +4189,7 @@ def update_graph(
dependencies = dependencies or {}
- if self.total_occupancy > 1e-9 and self.computations:
+ if self.total_occupancy.total > 1e-9 and self.computations:
# Still working on something. Assign new tasks to same computation
computation = self.computations[-1]
else:
@@ -4922,19 +4926,26 @@ def validate_state(self, allow_overlap: bool = False) -> None:
}
assert a == b, (a, b)
- actual_total_occupancy = 0.0
+ actual_total_occupancy = Occupancy(0, 0)
for worker, ws in self.workers.items():
ws_processing_total = sum(
- cost for ts, cost in ws.processing.items() if ts not in ws.long_running
+ (
+ cost
+ for ts, cost in ws.processing.items()
+ if ts not in ws.long_running
+ ),
+ start=Occupancy(0, 0),
)
- assert abs(ws_processing_total - ws.occupancy) < 1e-8, (
+ delta = ws_processing_total - ws.occupancy
+ assert abs(delta.cpu) < 1e-8 and abs(delta.network) < 1e-8, (
worker,
ws_processing_total,
ws.occupancy,
)
actual_total_occupancy += ws.occupancy
- assert abs(actual_total_occupancy - self.total_occupancy) < 1e-8, (
+ delta = actual_total_occupancy - self.total_occupancy
+ assert abs(delta.cpu) < 1e-8 and abs(delta.network) < 1e-8, (
actual_total_occupancy,
self.total_occupancy,
)
@@ -5131,7 +5142,7 @@ def handle_long_running(
if key not in self.tasks:
logger.debug("Skipping long_running since key %s was already released", key)
return
- ts = self.tasks[key]
+ ts: TaskState = self.tasks[key]
steal = self.extensions.get("stealing")
if steal is not None:
steal.remove_key_from_stealable(ts)
@@ -5155,7 +5166,7 @@ def handle_long_running(
# idleness detection. Idle workers are typically targeted for
# downscaling but we should not downscale workers with long running
# tasks
- ws.processing[ts] = 0
+ ws.processing[ts].clear()
ws.long_running.add(ts)
self.check_idle_saturated(ws)
@@ -7594,10 +7605,12 @@ def adaptive_target(self, target_duration=None):
# CPU
# TODO consider any user-specified default task durations for queued tasks
- queued_occupancy = len(self.queued) * self.UNKNOWN_TASK_DURATION
+ queued_occupancy: float = len(self.queued) * self.UNKNOWN_TASK_DURATION
+ # TODO: threads per worker
+ # TODO don't include network occupancy?
cpu = math.ceil(
- (self.total_occupancy + queued_occupancy) / target_duration
- ) # TODO: threads per worker
+ (self.total_occupancy.total + queued_occupancy) / target_duration
+ )
# Avoid a few long tasks from asking for many cores
tasks_ready = len(self.queued)
@@ -7744,7 +7757,7 @@ def _exit_processing_common(
ws.long_running.discard(ts)
if not ws.processing:
state.total_occupancy -= ws.occupancy
- ws.occupancy = 0
+ ws.occupancy.clear()
else:
state.total_occupancy -= duration
ws.occupancy -= duration
diff --git a/distributed/stealing.py b/distributed/stealing.py
index f1539c886e..4b29df3f42 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -15,6 +15,7 @@
import dask
from dask.utils import parse_timedelta
+from distributed.collections import Occupancy
from distributed.comm.addressing import get_address_host
from distributed.core import CommClosedError, Status
from distributed.diagnostics.plugin import SchedulerPlugin
@@ -57,8 +58,8 @@
class InFlightInfo(TypedDict):
victim: WorkerState
thief: WorkerState
- victim_duration: float
- thief_duration: float
+ victim_duration: Occupancy
+ thief_duration: Occupancy
stimulus_id: str
@@ -79,7 +80,7 @@ class WorkStealing(SchedulerPlugin):
# { task state: }
in_flight: dict[TaskState, InFlightInfo]
# { worker state: occupancy }
- in_flight_occupancy: defaultdict[WorkerState, float]
+ in_flight_occupancy: defaultdict[WorkerState, Occupancy]
_in_flight_event: asyncio.Event
_request_counter: int
@@ -104,7 +105,7 @@ def __init__(self, scheduler: Scheduler):
self.scheduler.events["stealing"] = deque(maxlen=100000)
self.count = 0
self.in_flight = {}
- self.in_flight_occupancy = defaultdict(lambda: 0)
+ self.in_flight_occupancy = defaultdict(lambda: Occupancy(0, 0))
self._in_flight_event = asyncio.Event()
self._request_counter = 0
self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
@@ -232,21 +233,23 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non
return None, None
if not ts.dependencies: # no dependencies fast path
- return 0, 0
+ return 0.0, 0
assert ts.processing_on
ws = ts.processing_on
- compute_time = ws.processing[ts]
+ assert ws
+ occupancy = ws.processing[ts]
- if not compute_time:
+ if not occupancy:
# occupancy/ws.proccessing[ts] is only allowed to be zero for
# long running tasks which cannot be stolen
assert ts in ws.long_running
return None, None
nbytes = ts.get_nbytes_deps()
- transfer_time = nbytes / self.scheduler.bandwidth + LATENCY
- cost_multiplier = transfer_time / compute_time
+ transfer_time: float = nbytes / self.scheduler.bandwidth + LATENCY
+ # FIXME don't use `occupancy.total` https://github.com/dask/distributed/issues/7003
+ cost_multiplier = transfer_time / occupancy.total
level = int(round(log2(cost_multiplier) + 6))
if level < 1:
@@ -280,9 +283,10 @@ def move_task_request(
victim_duration = victim.processing[ts]
- thief_duration = self.scheduler.get_task_duration(
- ts
- ) + self.scheduler.get_comm_cost(ts, thief)
+ thief_duration = Occupancy(
+ self.scheduler.get_task_duration(ts),
+ self.scheduler.get_comm_cost(ts, thief),
+ )
self.scheduler.stream_comms[victim.address].send(
{"op": "steal-request", "key": key, "stimulus_id": stimulus_id}
@@ -374,7 +378,7 @@ async def move_task_confirm(
self.scheduler.total_occupancy -= duration
if not victim.processing:
self.scheduler.total_occupancy -= victim.occupancy
- victim.occupancy = 0
+ victim.occupancy.clear()
thief.processing[ts] = d["thief_duration"]
thief.occupancy += d["thief_duration"]
self.scheduler.total_occupancy += d["thief_duration"]
@@ -401,27 +405,32 @@ def balance(self) -> None:
start = time()
def combined_occupancy(ws: WorkerState) -> float:
- return ws.occupancy + self.in_flight_occupancy[ws]
+ return ws.occupancy.total + self.in_flight_occupancy[ws].total
def maybe_move_task(
level: int,
ts: TaskState,
victim: WorkerState,
thief: WorkerState,
- duration: float,
+ duration: Occupancy,
cost_multiplier: float,
) -> None:
+ # TODO calculate separately for cpu vs network?
occ_thief = combined_occupancy(thief)
occ_victim = combined_occupancy(victim)
- if occ_thief + cost_multiplier * duration <= occ_victim - duration / 2:
+ duration_total = duration.total
+ if (
+ occ_thief + cost_multiplier * duration_total
+ <= occ_victim - duration_total / 2
+ ):
self.move_task_request(ts, victim, thief)
log.append(
(
start,
level,
ts.key,
- duration,
+ duration_total,
victim.address,
occ_victim,
thief.address,
diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index 3336fc8481..e140d05d86 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -1034,7 +1034,7 @@ def f(ev1, ev2, ev3, ev4):
await ev1.wait()
ts = a.state.tasks["x"]
assert ts.state == "executing"
- assert sum(ws.processing.values()) > 0
+ assert any(ws.processing.values())
x.release()
await wait_for_state("x", "cancelled", a)
@@ -1050,7 +1050,7 @@ def f(ev1, ev2, ev3, ev4):
# Test that the scheduler receives a delayed {op: long-running}
assert ws.processing
- while sum(ws.processing.values()):
+ while any(ws.processing.values()):
await asyncio.sleep(0.1)
assert ws.processing
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 3121dcb6bd..848c74a527 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -70,6 +70,7 @@
wait,
)
from distributed.cluster_dump import load_cluster_dump
+from distributed.collections import Occupancy
from distributed.comm import CommClosedError
from distributed.compatibility import LINUX, WINDOWS
from distributed.core import Status
@@ -5109,8 +5110,11 @@ def long_running(lock, entered):
await entered.wait()
ts = s.tasks[f.key]
ws = s.workers[a.address]
- assert ws.occupancy == parse_timedelta(
- dask.config.get("distributed.scheduler.unknown-task-duration")
+ assert ws.occupancy == Occupancy(
+ cpu=parse_timedelta(
+ dask.config.get("distributed.scheduler.unknown-task-duration"),
+ ),
+ network=0,
)
while ws.occupancy:
@@ -5118,12 +5122,12 @@ def long_running(lock, entered):
await a.heartbeat()
s._set_duration_estimate(ts, ws)
- assert s.workers[a.address].occupancy == 0
- assert s.total_occupancy == 0
- assert ws.occupancy == 0
+ assert s.workers[a.address].occupancy == Occupancy(0, 0)
+ assert s.total_occupancy == Occupancy(0, 0)
+ assert ws.occupancy == Occupancy(0, 0)
s._ongoing_background_tasks.call_soon(s.reevaluate_occupancy, 0)
- assert s.workers[a.address].occupancy == 0
+ assert s.workers[a.address].occupancy == Occupancy(0, 0)
await l.release()
with (
@@ -5133,8 +5137,8 @@ def long_running(lock, entered):
):
await f
- assert s.total_occupancy == 0
- assert ws.occupancy == 0
+ assert s.total_occupancy == Occupancy(0, 0)
+ assert ws.occupancy == Occupancy(0, 0)
assert not ws.long_running
@@ -5176,14 +5180,16 @@ def long_running(lock, entered):
if ordinary_task:
# Should be exactly 0.5 but if for whatever reason this test runs slow,
# some approximation may kick in increasing this number
- assert s.total_occupancy >= 0.5
- assert ws.occupancy >= 0.5
+ assert not s.total_occupancy.network
+ assert not ws.occupancy.network
+ assert s.total_occupancy.cpu >= 0.5
+ assert ws.occupancy.cpu >= 0.5
await l2.release()
await f2
# In the end, everything should be reset
- assert s.total_occupancy == 0
- assert ws.occupancy == 0
+ assert not s.total_occupancy
+ assert not ws.occupancy
assert not ws.long_running
diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py
index 066cf147a3..6df5a20613 100644
--- a/distributed/tests/test_collections.py
+++ b/distributed/tests/test_collections.py
@@ -7,7 +7,7 @@
import pytest
-from distributed.collections import LRU, HeapSet
+from distributed.collections import LRU, HeapSet, Occupancy
def test_lru():
@@ -339,3 +339,48 @@ def test_heapset_sort_duplicate():
heap.add(c1)
assert list(heap.sorted()) == [c1, c2]
+
+
+def test_occupancy():
+    """Exercise arithmetic, truthiness, equality and ``clear()`` of ``Occupancy``."""
+    zero = Occupancy(0, 0)
+    cpu_only = Occupancy(1, 0)
+    net_only = Occupancy(0, 1)
+
+    # An all-zero occupancy is falsy and equal to itself
+    assert not zero
+    assert not zero.total
+    assert zero == zero
+
+    # A single non-zero component is enough to be truthy
+    assert cpu_only
+    assert cpu_only.total == 1
+    assert zero != cpu_only
+    assert cpu_only + zero == cpu_only
+
+    # Addition is component-wise and order-independent; subtraction undoes it
+    both = net_only + cpu_only
+    assert both.total == 2
+    assert both == cpu_only + net_only
+    assert both - net_only == cpu_only
+
+    # In-place operators update the running value
+    acc = Occupancy(0, 0)
+    acc += zero
+    assert not acc
+    acc += cpu_only
+    assert acc == cpu_only
+    acc += cpu_only
+    assert acc == Occupancy(2, 0)
+
+    # Components may go negative independently of each other
+    acc -= net_only
+    assert acc == Occupancy(2, -1)
+    assert acc.total == 1
+
+    acc.clear()
+    assert acc == zero
+
+    # Foreign operands: equality is False, arithmetic raises
+    assert not acc == 1
+    with pytest.raises(TypeError):
+        acc + 1
+    with pytest.raises(TypeError):
+        acc - 1
+    with pytest.raises(TypeError):
+        acc += 1
+    with pytest.raises(TypeError):
+        acc -= 1
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 35fe85cdc7..9a768743e0 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -1423,12 +1423,13 @@ async def test_learn_occupancy(c, s, a, b):
await asyncio.sleep(0.01)
nproc = sum(ts.state == "processing" for ts in s.tasks.values())
- assert nproc * 0.1 < s.total_occupancy < nproc * 0.4
+ assert not s.total_occupancy.network
+ assert nproc * 0.1 < s.total_occupancy.cpu < nproc * 0.4
for w in [a, b]:
ws = s.workers[w.address]
- occ = ws.occupancy
+ assert not ws.occupancy.network
proc = len(ws.processing)
- assert proc * 0.1 < occ < proc * 0.4
+ assert proc * 0.1 < ws.occupancy.cpu < proc * 0.4
@pytest.mark.slow
@@ -1440,7 +1441,8 @@ async def test_learn_occupancy_2(c, s, a, b):
await asyncio.sleep(0.01)
nproc = sum(ts.state == "processing" for ts in s.tasks.values())
- assert nproc * 0.1 < s.total_occupancy < nproc * 0.4
+ assert not s.total_occupancy.network
+ assert nproc * 0.1 < s.total_occupancy.cpu < nproc * 0.4
@gen_cluster(client=True)
@@ -1448,14 +1450,14 @@ async def test_occupancy_cleardown(c, s, a, b):
s.validate = False
# Inject excess values in s.occupancy
- s.workers[a.address].occupancy = 2
- s.total_occupancy += 2
+ s.workers[a.address].occupancy.cpu += 2
+ s.total_occupancy.cpu += 2
futures = c.map(slowinc, range(100), delay=0.01)
await wait(futures)
# Verify that occupancy values have been zeroed out
- assert abs(s.total_occupancy) < 0.01
- assert all(ws.occupancy == 0 for ws in s.workers.values())
+ assert abs(s.total_occupancy.total) < 0.01
+ assert all(not ws.occupancy for ws in s.workers.values())
@nodebug
@@ -1492,11 +1494,13 @@ async def test_learn_occupancy_multiple_workers(c, s, a, b):
await wait(x)
- assert not any(v == 0.5 for w in s.workers.values() for v in w.processing.values())
+ assert not any(
+ occ.cpu == 0.5 for w in s.workers.values() for occ in w.processing.values()
+ )
@gen_cluster(client=True)
-async def test_include_communication_in_occupancy(c, s, a, b):
+async def test_occupancy_network(c, s, a, b):
await c.submit(slowadd, 1, 2, delay=0)
x = c.submit(operator.mul, b"0", int(s.bandwidth), workers=a.address)
y = c.submit(operator.mul, b"1", int(s.bandwidth * 1.5), workers=b.address)
@@ -1507,7 +1511,9 @@ async def test_include_communication_in_occupancy(c, s, a, b):
ts = s.tasks[z.key]
assert ts.processing_on == s.workers[b.address]
- assert s.workers[b.address].processing[ts] > 1
+ occ = s.workers[b.address].processing[ts]
+ assert occ.network >= 1
+ assert occ.cpu > 0
await wait(z)
del z
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 4a3751d6d9..cff30368f6 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -1129,7 +1129,7 @@ def block(x, event):
del futs1
- assert all(v == 0 for v in steal.in_flight_occupancy.values())
+ assert all(not v for v in steal.in_flight_occupancy.values())
@gen_cluster(