Skip to content

Commit

Permalink
Refactor: new module distributed.collections
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jun 1, 2022
1 parent 9ef92e7 commit 9280162
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 44 deletions.
26 changes: 26 additions & 0 deletions distributed/collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from collections import OrderedDict, UserDict
from typing import TypeVar, cast

K = TypeVar("K")
V = TypeVar("V")


class LRU(UserDict[K, V]):
    """Limited size mapping, evicting the least recently looked-up key when full.

    Entries are stored in an ``OrderedDict`` ordered from least to most
    recently looked up: a successful ``__getitem__`` moves the key to the end,
    and eviction removes the entry at the front.

    Parameters
    ----------
    maxsize
        Number of entries at which inserting a *new* key evicts the least
        recently looked-up one. Assigning to a key that is already present
        replaces its value in place and never evicts anything.
    """

    def __init__(self, maxsize: float):
        super().__init__()
        # Replace the plain dict created by UserDict with an ordered one so
        # that recency can be tracked through key position.
        self.data = OrderedDict()
        self.maxsize = maxsize

    def __getitem__(self, key: K) -> V:
        """Return the value for *key* and mark it as most recently used.

        Raises ``KeyError`` (without changing the ordering) if *key* is absent.
        """
        value = super().__getitem__(key)
        cast(OrderedDict, self.data).move_to_end(key)
        return value

    def __setitem__(self, key: K, value: V) -> None:
        """Store *value* under *key*, evicting the oldest entry if needed.

        Eviction only happens when *key* is new; overwriting an existing key
        does not grow the mapping, so dropping another entry would shrink the
        cache below ``maxsize`` for no reason.
        """
        if key not in self.data and len(self) >= self.maxsize:
            cast(OrderedDict, self.data).popitem(last=False)
        super().__setitem__(key, value)
18 changes: 18 additions & 0 deletions distributed/tests/test_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from distributed.collections import LRU


def test_lru():
    """LRU promotes looked-up keys and evicts the oldest key once full."""
    lru = LRU(maxsize=3)
    for key, value in (("a", 1), ("b", 2), ("c", 3)):
        lru[key] = value
    assert list(lru) == ["a", "b", "c"]

    # A lookup of "a" must promote it to the most recently used position
    lru["a"]
    assert list(lru) == ["b", "c", "a"]

    # Inserting a fourth key must keep the size at maxsize by dropping "b"
    lru["d"] = 4
    assert len(lru) == 3
    assert list(lru) == ["c", "a", "d"]
19 changes: 0 additions & 19 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from distributed.compatibility import MACOS, WINDOWS
from distributed.metrics import time
from distributed.utils import (
LRU,
All,
Log,
Logs,
Expand Down Expand Up @@ -594,24 +593,6 @@ def test_parse_ports():
parse_ports("100.5")


def test_lru():
    """Exercise lookup-based reordering and size-bounded eviction of LRU."""
    cache = LRU(maxsize=3)
    cache["a"], cache["b"], cache["c"] = 1, 2, 3
    expected_order = ["a", "b", "c"]
    assert list(cache.keys()) == expected_order

    # Reading "a" makes it the most recently used item
    cache["a"]
    expected_order = ["b", "c", "a"]
    assert list(cache.keys()) == expected_order

    # Adding one more key must not let the mapping exceed maxsize
    cache["d"] = 4
    expected_order = ["c", "a", "d"]
    assert len(cache) == 3
    assert list(cache.keys()) == expected_order


@gen_test()
async def test_offload():
assert (await offload(inc, 1)) == 2
Expand Down
21 changes: 1 addition & 20 deletions distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import weakref
import xml.etree.ElementTree
from asyncio import TimeoutError
from collections import OrderedDict, UserDict, deque
from collections import deque
from collections.abc import Callable, Collection, Container, KeysView, ValuesView
from concurrent.futures import CancelledError, ThreadPoolExecutor # noqa: F401
from contextlib import contextmanager, suppress
Expand Down Expand Up @@ -1447,25 +1447,6 @@ async def __aexit__(self, exc_type, exc_value, traceback):
empty_context = EmptyContext()


class LRU(UserDict):
    """Limited size mapping, evicting the least recently looked-up key when full.

    The backing store is an ``OrderedDict`` ordered from least to most
    recently looked up. ``maxsize`` is the number of entries at which
    inserting a *new* key evicts the entry at the front; assigning to an
    existing key only replaces its value.
    """

    def __init__(self, maxsize):
        super().__init__()
        # Swap UserDict's plain dict for an ordered one to track recency
        self.data = OrderedDict()
        self.maxsize = maxsize

    def __getitem__(self, key):
        """Return the value for *key* and mark it as most recently used."""
        value = super().__getitem__(key)
        self.data.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        """Store *value* under *key*, evicting the oldest entry if needed."""
        # Overwriting an existing key does not grow the mapping, so only
        # evict when a genuinely new key would push it past maxsize.
        if key not in self.data and len(self) >= self.maxsize:
            self.data.popitem(last=False)
        super().__setitem__(key, value)


def clean_dashboard_address(addrs: AnyType, default_listen_ip: str = "") -> list[dict]:
"""
Examples
Expand Down
10 changes: 5 additions & 5 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from distributed import comm, preloading, profile, utils
from distributed._stories import worker_story
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import connect, get_address_host
from distributed.comm.addressing import address_from_user_args, parse_address
from distributed.comm.utils import OFFLOAD_THRESHOLD
Expand Down Expand Up @@ -79,7 +80,6 @@
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
LRU,
TimeoutError,
_maybe_complex,
get_ip,
Expand Down Expand Up @@ -4596,10 +4596,10 @@ async def _get_data():
job_counter = [0]


cache_loads = LRU(maxsize=100)
cache_loads: LRU[bytes, Callable] = LRU(maxsize=100)


def loads_function(bytes_object):
def loads_function(bytes_object: bytes) -> Callable:
"""Load a function from bytes, cache bytes"""
if len(bytes_object) < 100000:
try:
Expand Down Expand Up @@ -4646,12 +4646,12 @@ def execute_task(task):
return task


cache_dumps = LRU(maxsize=100)
cache_dumps: LRU[Callable, bytes] = LRU(maxsize=100)

_cache_lock = threading.Lock()


def dumps_function(func) -> bytes:
def dumps_function(func: Callable) -> bytes:
"""Dump a function to bytes, cache functions"""
try:
with _cache_lock:
Expand Down

0 comments on commit 9280162

Please sign in to comment.