crusaderky committed Jun 1, 2022
1 parent 74640b8 commit 23e6698
Showing 4 changed files with 66 additions and 35 deletions.
31 changes: 20 additions & 11 deletions distributed/collections.py
@@ -3,11 +3,11 @@
import heapq
import weakref
from collections import OrderedDict, UserDict
from collections.abc import Callable, Iterator
from collections.abc import Callable, Hashable, Iterator
from typing import MutableSet # TODO move to collections.abc (requires Python >=3.9)
from typing import Any, TypeVar, cast

T = TypeVar("T")
T = TypeVar("T", bound=Hashable)


# TODO change to UserDict[K, V] (requires Python >=3.9)
@@ -44,6 +44,7 @@ class HeapSet(MutableSet[T]):
_heap: list[tuple[Any, int, weakref.ref[T]]]

def __init__(self, *, key: Callable[[T], Any]):
# FIXME https://github.com/python/mypy/issues/708
self.key = key # type: ignore
self._data = set()
self._inc = 0
@@ -55,9 +56,6 @@ def __repr__(self) -> str:
def __contains__(self, value: object) -> bool:
return value in self._data

def __iter__(self) -> Iterator[T]:
return iter(self._data)

def __len__(self) -> int:
return len(self._data)

@@ -72,6 +70,8 @@ def add(self, value: T) -> None:

def discard(self, value: T) -> None:
self._data.discard(value)
if not self._data:
self._heap.clear()

def peek(self) -> T:
"""Get the smallest element without removing it"""
@@ -93,13 +93,22 @@ def pop(self) -> T:
self._data.remove(value)
return value

def sorted(self) -> list[T]:
"""Return a list containing all elements, from smallest to largest according to
the key and insertion order.
def __iter__(self) -> Iterator[T]:
"""Iterate over all elements. This is a O(n) operation which returns the
elements in pseudo-random order.
"""
return iter(self._data)

def sorted(self) -> Iterator[T]:
"""Iterate ofer all elements. This is a O(n*logn) operation which returns the
elements in order, from smallest to largest according to the key and insertion
order.
"""
out = []
for _, _, vref in sorted(self._heap):
value = vref()
if value in self._data:
out.append(value)
return out
yield value

def clear(self) -> None:
self._data.clear()
self._heap.clear()
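
For readers skimming the diff, here is a minimal usage sketch of HeapSet after this change (not part of the commit; the Item class is a hypothetical stand-in for TaskState). It shows that sorted() now yields lazily instead of building a list, and that discarding the last element also empties the internal heap of stale weakref entries:

from distributed.collections import HeapSet

class Item:
    # Hypothetical hashable, weakref-able element; TaskState plays this role in the worker
    def __init__(self, name: str, priority: int):
        self.name = name
        self.priority = priority

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Item) and other.name == self.name

heap = HeapSet(key=lambda item: item.priority)
a, b = Item("a", 2), Item("b", 1)
heap.add(a)
heap.add(b)

assert heap.peek() is b                               # smallest key; nothing is removed
assert [i.name for i in heap.sorted()] == ["b", "a"]  # lazy, O(n log n), key + insertion order
assert set(heap) == {a, b}                            # plain iteration: O(n), pseudo-random order

heap.discard(b)
heap.discard(a)        # removing the last element now clears the internal heap as well
assert not heap and not heap._heap
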
29 changes: 21 additions & 8 deletions distributed/tests/test_collections.py
@@ -51,14 +51,15 @@ def __eq__(self, other):
assert cz in heap
assert cw in heap

heap_list = heap.sorted()
heap_sorted = heap.sorted()
# iteration does not empty heap
assert len(heap) == 4
assert len(heap_list) == 4
assert heap_list[0] is cy
assert heap_list[1] is cx
assert heap_list[2] is cz
assert heap_list[3] is cw
assert next(heap_sorted) is cy
assert next(heap_sorted) is cx
assert next(heap_sorted) is cz
assert next(heap_sorted) is cw
with pytest.raises(StopIteration):
next(heap_sorted)

assert set(heap) == {cx, cy, cz, cw}

@@ -96,7 +97,7 @@ def __eq__(self, other):
heap.discard(cw)

assert len(heap) == 2
assert heap.sorted() == [cx, cz]
assert list(heap.sorted()) == [cx, cz]
# cy is at the top of heap._heap, but is skipped
assert heap.peek() is cx
assert heap.pop() is cx
@@ -108,4 +109,16 @@
heap.peek()
with pytest.raises(KeyError):
heap.pop()
assert heap.sorted() == []
assert list(heap.sorted()) == []

# Test clear()
heap.add(cx)
heap.clear()
assert not heap
heap.add(cx)
assert cx in heap
# Test discard last element
heap.discard(cx)
assert not heap
heap.add(cx)
assert cx in heap
33 changes: 18 additions & 15 deletions distributed/worker.py
@@ -7,6 +7,7 @@
import functools
import heapq
import logging
import operator
import os
import pathlib
import random
@@ -586,9 +587,9 @@ def __init__(
self.tasks = {}
self.waiting_for_data_count = 0
self.has_what = defaultdict(set)
self.data_needed = HeapSet(key=lambda ts: ts.priority)
self.data_needed = HeapSet(key=operator.attrgetter("priority"))
self.data_needed_per_worker = defaultdict(
lambda: HeapSet(key=lambda ts: ts.priority) # type: ignore
lambda: HeapSet(key=operator.attrgetter("priority"))
)
self.nanny = nanny
self._lock = threading.Lock()
@@ -1078,9 +1079,9 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
"status": self.status,
"ready": self.ready,
"constrained": self.constrained,
"data_needed": self.data_needed.sorted(),
"data_needed": list(self.data_needed.sorted()),
"data_needed_per_worker": {
w: v.sorted() for w, v in self.data_needed_per_worker.items()
w: list(v.sorted()) for w, v in self.data_needed_per_worker.items()
},
"long_running": self.long_running,
"executing_count": self.executing_count,
@@ -3032,16 +3033,16 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
local = [w for w in workers if get_address_host(w) == host]
worker = random.choice(local or workers)

to_gather, total_nbytes = self._select_keys_for_gather(worker, ts)
to_gather_tasks, total_nbytes = self._select_keys_for_gather(worker, ts)
to_gather_keys = {ts.key for ts in to_gather_tasks}

self.log.append(
("gather-dependencies", worker, to_gather, stimulus_id, time())
("gather-dependencies", worker, to_gather_keys, stimulus_id, time())
)

self.comm_nbytes += total_nbytes
self.in_flight_workers[worker] = to_gather
for d_key in to_gather:
d_ts = self.tasks[d_key]
self.in_flight_workers[worker] = to_gather_keys
for d_ts in to_gather_tasks:
if self.validate:
assert d_ts.state == "fetch"
assert d_ts not in recommendations
@@ -3055,7 +3056,7 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
instructions.append(
GatherDep(
worker=worker,
to_gather=to_gather,
to_gather=to_gather_keys,
total_nbytes=total_nbytes,
stimulus_id=stimulus_id,
)
@@ -3152,14 +3153,14 @@ def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs:

def _select_keys_for_gather(
self, worker: str, ts: TaskState
) -> tuple[set[str], int]:
) -> tuple[set[TaskState], int]:
"""``_ensure_communicating`` decided to fetch a single task from a worker,
following priority. In order to minimise overhead, request fetching other tasks
from the same worker within the message, following priority for the single
worker but ignoring higher priority tasks from other workers, up to
``target_message_size``.
"""
keys = {ts.key}
tss = {ts}
total_bytes = ts.get_nbytes()
tasks = self.data_needed_per_worker[worker]

@@ -3176,10 +3177,10 @@ def _select_keys_for_gather(
if other_worker != worker:
self.data_needed_per_worker[other_worker].remove(ts)

keys.add(ts.key)
tss.add(ts)
total_bytes += ts.get_nbytes()

return keys, total_bytes
return tss, total_bytes

@property
def total_comm_bytes(self):
@@ -4363,10 +4364,12 @@ def validate_state(self):

for ts in self.data_needed:
assert ts.state == "fetch"
assert self.tasks[ts.key] is ts
for worker, tss in self.data_needed_per_worker.items():
for ts in tss:
assert ts in self.data_needed
assert ts.state == "fetch"
assert self.tasks[ts.key] is ts
assert ts in self.data_needed
assert worker in ts.who_has

for ts in self.tasks.values():
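
Conceptually, this file now passes TaskState objects (rather than bare keys) through data_needed and _select_keys_for_gather, converting to keys only where GatherDep and in_flight_workers need them. The batching rule described in the _select_keys_for_gather docstring can be summarised by the following simplified, hypothetical sketch (it ignores task states, validation and recommendations, and the names are illustrative only):

def select_tasks_for_gather(first_ts, tasks_on_worker, target_message_size):
    # Batch further tasks available on the same worker behind ``first_ts``, in
    # priority order, stopping before the reply would exceed ``target_message_size``.
    selected = {first_ts}
    total_bytes = first_ts.get_nbytes()
    while tasks_on_worker:                    # HeapSet of TaskState, keyed by priority
        ts = tasks_on_worker.peek()           # cheapest-priority candidate, not yet removed
        if total_bytes + ts.get_nbytes() > target_message_size:
            break
        tasks_on_worker.pop()
        selected.add(ts)
        total_bytes += ts.get_nbytes()
    return selected, total_bytes
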
8 changes: 7 additions & 1 deletion distributed/worker_state_machine.py
@@ -174,7 +174,13 @@ def __repr__(self) -> str:
return f"<TaskState {self.key!r} {self.state}>"

def __eq__(self, other: object) -> bool:
return isinstance(other, TaskState) and other.key == self.key
if not isinstance(other, TaskState) or other.key != self.key:
return False
# When a task transitions to forgotten and exits Worker.tasks, it should be
# immediately dereferenced. If the same task is later recreated on the
# worker, we should not have to deal with its previous incarnation lingering.
assert other is self
return True

def __hash__(self) -> int:
return hash(self.key)
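
A small, hypothetical illustration of the strengthened equality check; it assumes a TaskState can be constructed from just its key and that assertions are enabled (i.e. not running under python -O):

from distributed.worker_state_machine import TaskState

a = TaskState(key="x")
assert a == a                      # same object: equal
assert a != TaskState(key="y")     # different key: not equal

b = TaskState(key="x")             # a second live incarnation of the same key
try:
    a == b
except AssertionError:
    print("caught: a forgotten TaskState for this key is still referenced somewhere")
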
