diff --git a/distributed/collections.py b/distributed/collections.py
new file mode 100644
index 00000000000..dbe4ed2268e
--- /dev/null
+++ b/distributed/collections.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from collections import OrderedDict, UserDict
+from typing import TypeVar, cast
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+class LRU(UserDict[K, V]):
+    """Limited size mapping, evicting the least recently looked-up key when full
+
+    Looking a key up with ``lru[key]`` marks it as most recently used. Assigning
+    to a key that is already present replaces its value in place and never
+    evicts another entry.
+    """
+
+    def __init__(self, maxsize: float) -> None:
+        super().__init__()
+        self.data = OrderedDict()
+        self.maxsize = maxsize
+
+    def __getitem__(self, key: K) -> V:
+        value = super().__getitem__(key)
+        cast(OrderedDict, self.data).move_to_end(key)
+        return value
+
+    def __setitem__(self, key: K, value: V) -> None:
+        # Overwriting an existing key does not grow the mapping, so only evict
+        # when inserting a new key into a full mapping.
+        if key not in self.data and len(self) >= self.maxsize:
+            cast(OrderedDict, self.data).popitem(last=False)
+        super().__setitem__(key, value)
diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py
new file mode 100644
index 00000000000..d82e6ddb26f
--- /dev/null
+++ b/distributed/tests/test_collections.py
@@ -0,0 +1,31 @@
+from distributed.collections import LRU
+
+
+def test_lru():
+    l = LRU(maxsize=3)
+    l["a"] = 1
+    l["b"] = 2
+    l["c"] = 3
+    assert list(l.keys()) == ["a", "b", "c"]
+
+    # Use "a" and ensure it becomes the most recently used item
+    l["a"]
+    assert list(l.keys()) == ["b", "c", "a"]
+
+    # Ensure maxsize is respected
+    l["d"] = 4
+    assert len(l) == 3
+    assert list(l.keys()) == ["c", "a", "d"]
+
+
+def test_lru_overwrite_existing_key():
+    l = LRU(maxsize=3)
+    l["a"] = 1
+    l["b"] = 2
+    l["c"] = 3
+
+    # Overwriting a key of a full mapping must not evict anything
+    l["b"] = 20
+    assert len(l) == 3
+    assert list(l.keys()) == ["a", "b", "c"]
+    assert l["b"] == 20
diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py
index bc1dea2e81c..79e75fd3b35 100644
--- a/distributed/tests/test_utils.py
+++ b/distributed/tests/test_utils.py
@@ -19,7 +19,6 @@
 from distributed.compatibility import MACOS, WINDOWS
 from distributed.metrics import time
 from distributed.utils import (
-    LRU,
     All,
     Log,
     Logs,
@@ -594,24 +593,6 @@ def test_parse_ports():
         parse_ports("100.5")
 
 
-def test_lru():
-
-    l = LRU(maxsize=3)
-    l["a"] = 1
-    l["b"] = 2
-    l["c"] = 3
-    assert list(l.keys()) == ["a", "b", "c"]
-
-    # Use "a" and ensure it becomes the most recently used item
-    l["a"]
-    assert list(l.keys()) == ["b", "c", "a"]
-
-    # Ensure maxsize is respected
-    l["d"] = 4
-    assert len(l) == 3
-    assert list(l.keys()) == ["c", "a", "d"]
-
-
 @gen_test()
 async def test_offload():
     assert (await offload(inc, 1)) == 2
diff --git a/distributed/utils.py b/distributed/utils.py
index ad6745a09fd..a9ae1381473 100644
--- a/distributed/utils.py
+++ b/distributed/utils.py
@@ -19,7 +19,7 @@
 import weakref
 import xml.etree.ElementTree
 from asyncio import TimeoutError
-from collections import OrderedDict, UserDict, deque
+from collections import deque
 from collections.abc import Callable, Collection, Container, KeysView, ValuesView
 from concurrent.futures import CancelledError, ThreadPoolExecutor  # noqa: F401
 from contextlib import contextmanager, suppress
@@ -1447,25 +1447,6 @@ async def __aexit__(self, exc_type, exc_value, traceback):
 empty_context = EmptyContext()
 
 
-class LRU(UserDict):
-    """Limited size mapping, evicting the least recently looked-up key when full"""
-
-    def __init__(self, maxsize):
-        super().__init__()
-        self.data = OrderedDict()
-        self.maxsize = maxsize
-
-    def __getitem__(self, key):
-        value = super().__getitem__(key)
-        self.data.move_to_end(key)
-        return value
-
-    def __setitem__(self, key, value):
-        if len(self) >= self.maxsize:
-            self.data.popitem(last=False)
-        super().__setitem__(key, value)
-
-
 def clean_dashboard_address(addrs: AnyType, default_listen_ip: str = "") -> list[dict]:
     """
     Examples
diff --git a/distributed/worker.py b/distributed/worker.py
index c146cb56fff..8c06183096d 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -50,6 +50,7 @@
 from distributed import comm, preloading, profile, utils
 from distributed._stories import worker_story
 from distributed.batched import BatchedSend
+from distributed.collections import LRU
 from distributed.comm import connect, get_address_host
 from distributed.comm.addressing import address_from_user_args, parse_address
 from distributed.comm.utils import OFFLOAD_THRESHOLD
@@ -79,7 +80,6 @@
 from distributed.threadpoolexecutor import ThreadPoolExecutor
 from distributed.threadpoolexecutor import secede as tpe_secede
 from distributed.utils import (
-    LRU,
     TimeoutError,
     _maybe_complex,
     get_ip,
@@ -4596,10 +4596,10 @@ async def _get_data():
 job_counter = [0]
 
 
-cache_loads = LRU(maxsize=100)
+cache_loads: LRU[bytes, Callable] = LRU(maxsize=100)
 
 
-def loads_function(bytes_object):
+def loads_function(bytes_object: bytes) -> Callable:
     """Load a function from bytes, cache bytes"""
     if len(bytes_object) < 100000:
         try:
@@ -4646,12 +4646,12 @@ def execute_task(task):
         return task
 
 
-cache_dumps = LRU(maxsize=100)
+cache_dumps: LRU[Callable, bytes] = LRU(maxsize=100)
 
 _cache_lock = threading.Lock()
 
 
-def dumps_function(func) -> bytes:
+def dumps_function(func: Callable) -> bytes:
     """Dump a function to bytes, cache functions"""
     try:
         with _cache_lock: