Skip to content

Commit

Permalink
Prevent data duplication when unspilling
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Mar 14, 2022
1 parent 60ad756 commit 7843d4f
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 133 deletions.
4 changes: 3 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,6 @@ repos:
- numpy
- dask
- tornado
- zict
# TEMP revert before merging
# - zict
- git+https://github.com/crusaderky/zict@cache
4 changes: 3 additions & 1 deletion continuous_integration/environment-3.9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ dependencies:
- pip:
- git+https://github.com/dask/dask
- git+https://github.com/dask/s3fs
- git+https://github.com/dask/zict
# TEMP revert before merging
# - git+https://github.com/dask/zict
- git+https://github.com/crusaderky/zict@cache
# FIXME https://github.com/dask/distributed/issues/5345
# - git+https://github.com/intake/filesystem_spec
- git+https://github.com/joblib/joblib
Expand Down
23 changes: 14 additions & 9 deletions distributed/spill.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import time
from collections.abc import Mapping
from collections.abc import Mapping, MutableMapping
from contextlib import contextmanager
from functools import partial
from typing import Any, Literal, NamedTuple, cast
Expand All @@ -15,6 +15,8 @@

logger = logging.getLogger(__name__)
has_zict_210 = parse_version(zict.__version__) >= parse_version("2.1.0")
# At the moment of writing, zict 2.2.0 has not been released yet. Support git tip.
has_zict_220 = parse_version(zict.__version__) >= parse_version("2.2.0.dev2")


class SpilledSize(NamedTuple):
Expand All @@ -38,7 +40,7 @@ class SpillBuffer(zict.Buffer):
the total size of the stored data exceeds the target. If max_spill is provided the
key/value pairs won't be spilled once this threshold has been reached.
Paramaters
Parameters
----------
spill_directory: str
Location on disk to write the spill files to
Expand All @@ -65,12 +67,13 @@ def __init__(
if max_spill is not False and not has_zict_210:
raise ValueError("zict >= 2.1.0 required to set max-spill")

super().__init__(
fast={},
slow=Slow(spill_directory, max_spill),
n=target,
weight=_in_memory_weight,
)
slow: MutableMapping[str, Any] = Slow(spill_directory, max_spill)
if has_zict_220:
# If a value is still in use somewhere on the worker since the last time it
# was unspilled, don't duplicate it
slow = zict.Cache(slow, zict.WeakRefCache())

super().__init__(fast={}, slow=slow, n=target, weight=_in_memory_weight)
self.last_logged = 0
self.min_log_interval = min_log_interval
self.logged_pickle_errors = set() # keys logged with pickle error
Expand Down Expand Up @@ -204,7 +207,8 @@ def spilled_total(self) -> SpilledSize:
The two may differ substantially, e.g. if sizeof() is inaccurate or in case of
compression.
"""
return cast(Slow, self.slow).total_weight
slow = cast(zict.Cache, self.slow).data if has_zict_220 else self.slow
return cast(Slow, slow).total_weight


def _in_memory_weight(key: str, value: Any) -> int:
Expand All @@ -224,6 +228,7 @@ class HandledError(Exception):
pass


# zict.Func[str, Any] requires zict >= 2.2.0
class Slow(zict.Func):
max_weight: int | Literal[False]
weight_by_key: dict[str, SpilledSize]
Expand Down
Loading

0 comments on commit 7843d4f

Please sign in to comment.