diff --git a/distributed/spill.py b/distributed/spill.py index 7cba8161a7..0ff9422371 100644 --- a/distributed/spill.py +++ b/distributed/spill.py @@ -2,7 +2,7 @@ import logging import time -from collections.abc import Mapping +from collections.abc import Mapping, MutableMapping from contextlib import contextmanager from functools import partial from typing import Any, Literal, NamedTuple, cast @@ -14,7 +14,9 @@ from distributed.sizeof import safe_sizeof logger = logging.getLogger(__name__) -has_zict_210 = parse_version(zict.__version__) > parse_version("2.0.0") +has_zict_210 = parse_version(zict.__version__) >= parse_version("2.1.0") +# At the moment of writing, zict 2.2.0 has not been released yet. Support git tip. +has_zict_220 = parse_version(zict.__version__) >= parse_version("2.2.0.dev2") class SpilledSize(NamedTuple): @@ -38,7 +40,7 @@ class SpillBuffer(zict.Buffer): the total size of the stored data exceeds the target. If max_spill is provided the key/value pairs won't be spilled once this threshold has been reached. - Paramaters + Parameters ---------- spill_directory: str Location on disk to write the spill files to @@ -63,14 +65,15 @@ def __init__( ): if max_spill is not False and not has_zict_210: - raise ValueError("zict > 2.0.0 required to set max_weight") + raise ValueError("zict >= 2.1.0 required to set max-spill") - super().__init__( - fast={}, - slow=Slow(spill_directory, max_spill), - n=target, - weight=_in_memory_weight, - ) + slow: MutableMapping[str, Any] = Slow(spill_directory, max_spill) + if has_zict_220: + # If a value is still in use somewhere on the worker since the last time it + # was unspilled, don't duplicate it + slow = zict.Cache(slow, zict.WeakValueMapping()) + + super().__init__(fast={}, slow=slow, n=target, weight=_in_memory_weight) self.last_logged = 0 self.min_log_interval = min_log_interval self.logged_pickle_errors = set() # keys logged with pickle error @@ -204,7 +207,8 @@ def spilled_total(self) -> SpilledSize: The two may differ substantially, e.g. if sizeof() is inaccurate or in case of compression. """ - return cast(Slow, self.slow).total_weight + slow = cast(zict.Cache, self.slow).data if has_zict_220 else self.slow + return cast(Slow, slow).total_weight def _in_memory_weight(key: str, value: Any) -> int: @@ -224,6 +228,7 @@ class HandledError(Exception): pass +# zict.Func[str, Any] requires zict >= 2.2.0 class Slow(zict.Func): max_weight: int | Literal[False] weight_by_key: dict[str, SpilledSize] diff --git a/distributed/tests/test_spill.py b/distributed/tests/test_spill.py index c30aa6cefc..01f2369159 100644 --- a/distributed/tests/test_spill.py +++ b/distributed/tests/test_spill.py @@ -1,20 +1,27 @@ from __future__ import annotations +import gc import logging import os import pytest -zict = pytest.importorskip("zict") -from packaging.version import parse as parse_version - from dask.sizeof import sizeof from distributed.compatibility import WINDOWS from distributed.protocol import serialize_bytelist -from distributed.spill import SpillBuffer +from distributed.spill import SpillBuffer, has_zict_210, has_zict_220 from distributed.utils_test import captured_logger +requires_zict_210 = pytest.mark.skipif( + not has_zict_210, + reason="requires zict version >= 2.1.0", +) +requires_zict_220 = pytest.mark.skipif( + not has_zict_220, + reason="requires zict version >= 2.2.0", +) + def psize(*objs) -> tuple[int, int]: return ( @@ -23,15 +30,30 @@ def psize(*objs) -> tuple[int, int]: ) +def assert_buf(buf: SpillBuffer, expect_fast: dict, expect_slow: dict) -> None: + # assertions on fast + assert dict(buf.fast) == expect_fast + assert buf.fast.weights == {k: sizeof(v) for k, v in expect_fast.items()} + assert buf.fast.total_weight == sum(sizeof(v) for v in expect_fast.values()) + for k, v in buf.fast.items(): + assert buf[k] is v + + # assertions on slow + assert set(buf.slow) == expect_slow.keys() + slow = buf.slow.data if has_zict_220 else buf.slow # type: ignore + assert slow.weight_by_key == {k: psize(v) for k, v in expect_slow.items()} + total_weight = psize(*expect_slow.values()) + assert slow.total_weight == total_weight + assert buf.spilled_total == total_weight + + def test_spillbuffer(tmpdir): buf = SpillBuffer(str(tmpdir), target=300) # Convenience aliases assert buf.memory is buf.fast assert buf.disk is buf.slow - assert not buf.slow.weight_by_key - assert buf.slow.total_weight == (0, 0) - assert buf.spilled_total == (0, 0) + assert_buf(buf, {}, {}) a, b, c, d = "a" * 100, "b" * 99, "c" * 98, "d" * 97 @@ -40,75 +62,48 @@ def test_spillbuffer(tmpdir): assert psize(a)[0] != psize(a)[1] buf["a"] = a - assert not buf.slow - assert buf.fast.weights == {"a": sizeof(a)} - assert buf.fast.total_weight == sizeof(a) - assert buf.slow.weight_by_key == {} - assert buf.slow.total_weight == (0, 0) + assert_buf(buf, {"a": a}, {}) assert buf["a"] == a buf["b"] = b - assert not buf.slow - assert not buf.slow.weight_by_key - assert buf.slow.total_weight == (0, 0) + assert_buf(buf, {"a": a, "b": b}, {}) buf["c"] = c - assert set(buf.slow) == {"a"} - assert buf.slow.weight_by_key == {"a": psize(a)} - assert buf.slow.total_weight == psize(a) + assert_buf(buf, {"b": b, "c": c}, {"a": a}) assert buf["a"] == a - assert set(buf.slow) == {"b"} - assert buf.slow.weight_by_key == {"b": psize(b)} - assert buf.slow.total_weight == psize(b) + assert_buf(buf, {"a": a, "c": c}, {"b": b}) buf["d"] = d - assert set(buf.slow) == {"b", "c"} - assert buf.slow.weight_by_key == {"b": psize(b), "c": psize(c)} - assert buf.slow.total_weight == psize(b, c) + assert_buf(buf, {"a": a, "d": d}, {"b": b, "c": c}) # Deleting an in-memory key does not automatically move spilled keys back to memory del buf["a"] - assert set(buf.slow) == {"b", "c"} - assert buf.slow.weight_by_key == {"b": psize(b), "c": psize(c)} - assert buf.slow.total_weight == psize(b, c) + assert_buf(buf, {"d": d}, {"b": b, "c": c}) with pytest.raises(KeyError): buf["a"] # Deleting a spilled key updates the metadata del buf["b"] - assert set(buf.slow) == {"c"} - assert buf.slow.weight_by_key == {"c": psize(c)} - assert buf.slow.total_weight == psize(c) + assert_buf(buf, {"d": d}, {"c": c}) with pytest.raises(KeyError): buf["b"] # Updating a spilled key moves it to the top of the LRU and to memory - buf["c"] = c * 2 - assert set(buf.slow) == {"d"} - assert buf.slow.weight_by_key == {"d": psize(d)} - assert buf.slow.total_weight == psize(d) + c2 = c * 2 + buf["c"] = c2 + assert_buf(buf, {"c": c2}, {"d": d}) # Single key is larger than target and goes directly into slow e = "e" * 500 buf["e"] = e - assert set(buf.slow) == {"d", "e"} - assert buf.slow.weight_by_key == {"d": psize(d), "e": psize(e)} - assert buf.slow.total_weight == psize(d, e) + assert_buf(buf, {"c": c2}, {"d": d, "e": e}) # Updating a spilled key with another larger than target updates slow directly d = "d" * 500 buf["d"] = d - assert set(buf.slow) == {"d", "e"} - assert buf.slow.weight_by_key == {"d": psize(d), "e": psize(e)} - assert buf.slow.total_weight == psize(d, e) - - -requires_zict_210 = pytest.mark.skipif( - parse_version(zict.__version__) <= parse_version("2.0.0"), - reason="requires zict version > 2.0.0", -) + assert_buf(buf, {"c": c2}, {"d": d, "e": e}) @requires_zict_210 @@ -120,32 +115,17 @@ def test_spillbuffer_maxlim(tmpdir): # size of a is bigger than target and is smaller than max_spill; # key should be in slow buf["a"] = a - assert not buf.fast - assert not buf.fast.weights - assert set(buf.slow) == {"a"} - assert buf.slow.weight_by_key == {"a": psize(a)} - assert buf.slow.total_weight == psize(a) + assert_buf(buf, {}, {"a": a}) assert buf["a"] == a # size of b is smaller than target key should be in fast buf["b"] = b - assert set(buf.fast) == {"b"} - assert buf.fast.weights == {"b": sizeof(b)} - assert buf["b"] == b - assert buf.fast.total_weight == sizeof(b) + assert_buf(buf, {"b": b}, {"a": a}) # size of c is smaller than target but b+c > target, c should stay in fast and b # move to slow since the max_spill limit has not been reached yet - buf["c"] = c - assert set(buf.fast) == {"c"} - assert buf.fast.weights == {"c": sizeof(c)} - assert buf["c"] == c - assert buf.fast.total_weight == sizeof(c) - - assert set(buf.slow) == {"a", "b"} - assert buf.slow.weight_by_key == {"a": psize(a), "b": psize(b)} - assert buf.slow.total_weight == psize(a, b) + assert_buf(buf, {"c": c}, {"a": a, "b": b}) # size of e < target but e+c > target, this will trigger movement of c to slow # but the max spill limit prevents it. Resulting in e remaining in fast @@ -154,15 +134,7 @@ def test_spillbuffer_maxlim(tmpdir): buf["e"] = e assert "disk reached capacity" in logs_e.getvalue() - - assert set(buf.fast) == {"c", "e"} - assert buf.fast.weights == {"c": sizeof(c), "e": sizeof(e)} - assert buf["e"] == e - assert buf.fast.total_weight == sizeof(c) + sizeof(e) - - assert set(buf.slow) == {"a", "b"} - assert buf.slow.weight_by_key == {"a": psize(a), "b": psize(b)} - assert buf.slow.total_weight == psize(a, b) + assert_buf(buf, {"c": c, "e": e}, {"a": a, "b": b}) # size of d > target, d should go to slow but slow reached the max_spill limit then # d will end up on fast with c (which can't be move to slow because it won't fit @@ -171,15 +143,7 @@ def test_spillbuffer_maxlim(tmpdir): buf["d"] = d assert "disk reached capacity" in logs_d.getvalue() - - assert set(buf.fast) == {"c", "d", "e"} - assert buf.fast.weights == {"c": sizeof(c), "d": sizeof(d), "e": sizeof(e)} - assert buf["d"] == d - assert buf.fast.total_weight == sizeof(c) + sizeof(d) + sizeof(e) - - assert set(buf.slow) == {"a", "b"} - assert buf.slow.weight_by_key == {"a": psize(a), "b": psize(b)} - assert buf.slow.total_weight == psize(a, b) + assert_buf(buf, {"c": c, "d": d, "e": e}, {"a": a, "b": b}) # Overwrite a key that was in slow, but the size of the new key is larger than # max_spill @@ -191,11 +155,7 @@ def test_spillbuffer_maxlim(tmpdir): buf["a"] = a_large assert "disk reached capacity" in logs_alarge.getvalue() - - assert set(buf.fast) == {"a", "d", "e"} - assert set(buf.slow) == {"b", "c"} - assert buf.fast.total_weight == sizeof(d) + sizeof(a_large) + sizeof(e) - assert buf.slow.total_weight == psize(b, c) + assert_buf(buf, {"a": a_large, "d": d, "e": e}, {"b": b, "c": c}) # Overwrite a key that was in fast, but the size of the new key is larger than # max_spill @@ -205,11 +165,7 @@ def test_spillbuffer_maxlim(tmpdir): buf["d"] = d_large assert "disk reached capacity" in logs_dlarge.getvalue() - - assert set(buf.fast) == {"a", "d", "e"} - assert set(buf.slow) == {"b", "c"} - assert buf.fast.total_weight == sizeof(a_large) + sizeof(d_large) + sizeof(e) - assert buf.slow.total_weight == psize(b, c) + assert_buf(buf, {"a": a_large, "d": d_large, "e": e}, {"b": b, "c": c}) class MyError(Exception): @@ -241,13 +197,12 @@ def test_spillbuffer_fail_to_serialize(tmpdir): # spill.py must remain silent because we're already logging in worker.py assert not logs_bad_key.getvalue() - assert not set(buf.fast) - assert not set(buf.slow) + assert_buf(buf, {}, {}) b = Bad(size=100) # this is small enough to fit in memory/fast buf["b"] = b - assert set(buf.fast) == {"b"} + assert_buf(buf, {"b": b}, {}) c = "c" * 100 with captured_logger(logging.getLogger("distributed.spill")) as logs_bad_key_mem: @@ -259,9 +214,7 @@ def test_spillbuffer_fail_to_serialize(tmpdir): logs_value = logs_bad_key_mem.getvalue() assert "Failed to pickle" in logs_value # from distributed.spill assert "Traceback" in logs_value # from distributed.spill - assert set(buf.fast) == {"b", "c"} - assert buf.fast.total_weight == sizeof(b) + sizeof(c) - assert not set(buf.slow) + assert_buf(buf, {"b": b, "c": c}, {}) @requires_zict_210 @@ -279,8 +232,7 @@ def test_spillbuffer_oserror(tmpdir): # let's have something in fast and something in slow buf["a"] = a buf["b"] = b - assert set(buf.fast) == {"b"} - assert set(buf.slow) == {"a"} + assert_buf(buf, {"b": b}, {"a": a}) # modify permissions of disk to be read only. # This causes writes to raise OSError, just like in case of disk full. @@ -291,15 +243,10 @@ def test_spillbuffer_oserror(tmpdir): buf["c"] = c assert "Spill to disk failed" in logs_oserror_slow.getvalue() - assert set(buf.fast) == {"b", "c"} - assert set(buf.slow) == {"a"} - - assert buf.slow.weight_by_key == {"a": psize(a)} - assert buf.fast.weights == {"b": sizeof(b), "c": sizeof(c)} + assert_buf(buf, {"b": b, "c": c}, {"a": a}) del buf["c"] - assert set(buf.fast) == {"b"} - assert set(buf.slow) == {"a"} + assert_buf(buf, {"b": b}, {"a": a}) # add key to fast which is smaller than target but when added it triggers spill, # which triggers OSError @@ -307,40 +254,26 @@ def test_spillbuffer_oserror(tmpdir): buf["d"] = d assert "Spill to disk failed" in logs_oserror_evict.getvalue() - assert set(buf.fast) == {"b", "d"} - assert set(buf.slow) == {"a"} - - assert buf.slow.weight_by_key == {"a": psize(a)} - assert buf.fast.weights == {"b": sizeof(b), "d": sizeof(d)} + assert_buf(buf, {"b": b, "d": d}, {"a": a}) @requires_zict_210 def test_spillbuffer_evict(tmpdir): buf = SpillBuffer(str(tmpdir), target=300, min_log_interval=0) - a_bad = Bad(size=100) + bad = Bad(size=100) a = "a" * 100 buf["a"] = a - - assert set(buf.fast) == {"a"} - assert not set(buf.slow) - assert buf.fast.weights == {"a": sizeof(a)} + assert_buf(buf, {"a": a}, {}) # successful eviction weight = buf.evict() assert weight == sizeof(a) + assert_buf(buf, {}, {"a": a}) - assert not buf.fast - assert set(buf.slow) == {"a"} - assert buf.slow.weight_by_key == {"a": psize(a)} - - buf["a_bad"] = a_bad - - assert set(buf.fast) == {"a_bad"} - assert buf.fast.weights == {"a_bad": sizeof(a_bad)} - assert set(buf.slow) == {"a"} - assert buf.slow.weight_by_key == {"a": psize(a)} + buf["bad"] = bad + assert_buf(buf, {"bad": bad}, {"a": a}) # unsuccessful eviction with captured_logger(logging.getLogger("distributed.spill")) as logs_evict_key: @@ -349,7 +282,63 @@ def test_spillbuffer_evict(tmpdir): assert "Failed to pickle" in logs_evict_key.getvalue() # bad keys stays in fast - assert set(buf.fast) == {"a_bad"} - assert buf.fast.weights == {"a_bad": sizeof(a_bad)} - assert set(buf.slow) == {"a"} - assert buf.slow.weight_by_key == {"a": psize(a)} + assert_buf(buf, {"bad": bad}, {"a": a}) + + +class SupportsWeakRef: + def __init__(self, n): + self.n = n + + def __sizeof__(self): + return self.n + + +class NoWeakRef: + __slots__ = ("n",) + + def __init__(self, n): + self.n = n + + def __sizeof__(self): + return self.n + + +@pytest.mark.parametrize( + "cls,expect_cached", + [ + (SupportsWeakRef, has_zict_220), + (NoWeakRef, False), + ], +) +@pytest.mark.parametrize("size", [60, 110]) +def test_weakref_cache(tmpdir, cls, expect_cached, size): + buf = SpillBuffer(str(tmpdir), target=100) + + # Run this test twice: + # - x is smaller than target and is evicted by y; + # - x is individually larger than target and it never touches fast + x = cls(size) + buf["x"] = x + if size < 100: + buf["y"] = cls(60) # spill x + assert "x" in buf.slow + + # Test that we update the weakref cache on setitem + assert (buf["x"] is x) == expect_cached + + id_x = id(x) + del x + gc.collect() # Only needed on pypy + + if size < 100: + buf["y"] + assert "x" in buf.slow + + x2 = buf["x"] + assert id(x2) != id_x + if size < 100: + buf["y"] + assert "x" in buf.slow + + # Test that we update the weakref cache on getitem + assert (buf["x"] is x2) == expect_cached diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 9fd29884ed..0aefe05bfb 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -16,7 +16,6 @@ import psutil import pytest -from packaging.version import parse as parse_version from tlz import first, pluck, sliding_window import dask @@ -43,6 +42,7 @@ from distributed.metrics import time from distributed.protocol import pickle from distributed.scheduler import Scheduler +from distributed.spill import has_zict_210 from distributed.utils import TimeoutError from distributed.utils_test import ( TaskStateMetadataPlugin, @@ -63,15 +63,9 @@ pytestmark = pytest.mark.ci1 -try: - import zict -except ImportError: - zict = None # type: ignore - -requires_zict = pytest.mark.skipif(not zict, reason="requires zict") requires_zict_210 = pytest.mark.skipif( - not zict or parse_version(zict.__version__) <= parse_version("2.0.0"), - reason="requires zict version > 2.0.0", + not has_zict_210, + reason="requires zict version >= 2.1.0", ) @@ -924,7 +918,6 @@ async def assert_basic_futures(c: Client) -> None: assert results == list(map(inc, range(10))) -@requires_zict @gen_cluster(client=True) async def test_fail_write_to_disk_target_1(c, s, a, b): """Test failure to spill triggered by key which is individually larger @@ -942,7 +935,6 @@ async def test_fail_write_to_disk_target_1(c, s, a, b): await assert_basic_futures(c) -@requires_zict @gen_cluster( client=True, nthreads=[("", 1)], @@ -965,10 +957,8 @@ async def test_fail_write_to_disk_target_2(c, s, a): y = c.submit(lambda: "y" * 256, key="y") await wait(y) - if parse_version(zict.__version__) <= parse_version("2.0.0"): - assert set(a.data.memory) == {"y"} - else: - assert set(a.data.memory) == {"x", "y"} + + assert set(a.data.memory) == {"x", "y"} if has_zict_210 else {"y"} assert not a.data.disk await assert_basic_futures(c) @@ -1187,7 +1177,6 @@ async def test_statistical_profiling_2(c, s, a, b): break -@requires_zict @gen_cluster( client=True, nthreads=[("", 1)], @@ -1277,7 +1266,6 @@ async def test_spill_constrained(c, s, w): assert set(w.data.disk) == {x.key} -@requires_zict @gen_cluster( nthreads=[("", 1)], client=True, @@ -1301,7 +1289,6 @@ async def test_spill_spill_threshold(c, s, a): assert await x == 1 -@requires_zict @pytest.mark.parametrize( "memory_target_fraction,managed,expect_spilled", [