From dbad63f1ae9e666e96c03d40eb6c24e4cc73fc4f Mon Sep 17 00:00:00 2001
From: Alan Du
Date: Fri, 20 Oct 2023 12:49:15 -0400
Subject: [PATCH] Add typehints for `fsspec.caching` and `fsspec.utils` (#1284)

---
 fsspec/caching.py | 227 ++++++++++++++++++++++++++++++----------------
 fsspec/spec.py    |   2 +-
 fsspec/utils.py   | 118 +++++++++++++++++-------
 setup.cfg         |   7 ++
 4 files changed, 243 insertions(+), 111 deletions(-)

diff --git a/fsspec/caching.py b/fsspec/caching.py
index 34365acb4..2812724fc 100644
--- a/fsspec/caching.py
+++ b/fsspec/caching.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import collections
 import functools
 import logging
@@ -5,10 +7,34 @@
 import os
 import threading
 import warnings
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    ClassVar,
+    Generic,
+    NamedTuple,
+    OrderedDict,
+    TypeVar,
+)
+
+if TYPE_CHECKING:
+    import mmap
+
+    from typing_extensions import ParamSpec
+
+    P = ParamSpec("P")
+else:
+    P = TypeVar("P")
+
+T = TypeVar("T")
+
 
 logger = logging.getLogger("fsspec")
 
+Fetcher = Callable[[int, int], bytes]  # Maps (start, end) to bytes
+
 
 class BaseCache:
     """Pass-though cache: doesn't keep anything, calls every time
@@ -26,14 +52,14 @@ class BaseCache:
         How big this file is
     """
 
-    name = "none"
+    name: ClassVar[str] = "none"
 
-    def __init__(self, blocksize, fetcher, size):
+    def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
         self.blocksize = blocksize
         self.fetcher = fetcher
         self.size = size
 
-    def _fetch(self, start, stop):
+    def _fetch(self, start: int | None, stop: int | None) -> bytes:
         if start is None:
             start = 0
         if stop is None:
@@ -54,13 +80,20 @@ class MMapCache(BaseCache):
 
     name = "mmap"
 
-    def __init__(self, blocksize, fetcher, size, location=None, blocks=None):
+    def __init__(
+        self,
+        blocksize: int,
+        fetcher: Fetcher,
+        size: int,
+        location: str | None = None,
+        blocks: set[int] | None = None,
+    ) -> None:
         super().__init__(blocksize, fetcher, size)
         self.blocks = set() if blocks is None else blocks
         self.location = location
         self.cache = self._makefile()
 
-    def _makefile(self):
+    def _makefile(self) -> mmap.mmap | bytearray:
         import mmap
         import tempfile
 
@@ -82,7 +115,7 @@ def _makefile(self):
 
         return mmap.mmap(fd.fileno(), self.size)
 
-    def _fetch(self, start, end):
+    def _fetch(self, start: int | None, end: int | None) -> bytes:
         logger.debug(f"MMap cache fetching {start}-{end}")
         if start is None:
             start = 0
@@ -105,13 +138,13 @@ def _fetch(self, start, end):
 
         return self.cache[start:end]
 
-    def __getstate__(self):
+    def __getstate__(self) -> dict[str, Any]:
         state = self.__dict__.copy()
         # Remove the unpicklable entries.
         del state["cache"]
         return state
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: dict[str, Any]) -> None:
         # Restore instance attributes
         self.__dict__.update(state)
         self.cache = self._makefile()
@@ -127,13 +160,13 @@ class ReadAheadCache(BaseCache):
 
     name = "readahead"
 
-    def __init__(self, blocksize, fetcher, size):
+    def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
         super().__init__(blocksize, fetcher, size)
         self.cache = b""
         self.start = 0
         self.end = 0
 
-    def _fetch(self, start, end):
+    def _fetch(self, start: int | None, end: int | None) -> bytes:
         if start is None:
             start = 0
         if end is None or end > self.size:
@@ -168,11 +201,11 @@ class FirstChunkCache(BaseCache):
 
     name = "first"
 
-    def __init__(self, blocksize, fetcher, size):
+    def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
         super().__init__(blocksize, fetcher, size)
-        self.cache = None
+        self.cache: bytes | None = None
 
-    def _fetch(self, start, end):
+    def _fetch(self, start: int | None, end: int | None) -> bytes:
         start = start or 0
         end = end or self.size
         if start < self.blocksize:
@@ -215,13 +248,15 @@ class BlockCache(BaseCache):
 
     name = "blockcache"
 
-    def __init__(self, blocksize, fetcher, size, maxblocks=32):
+    def __init__(
+        self, blocksize: int, fetcher: Fetcher, size: int, maxblocks: int = 32
+    ) -> None:
         super().__init__(blocksize, fetcher, size)
         self.nblocks = math.ceil(size / blocksize)
         self.maxblocks = maxblocks
         self._fetch_block_cached = functools.lru_cache(maxblocks)(self._fetch_block)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return (
             f"<BlockCache blocksize={self.blocksize}, "
             f"size={self.size}, nblocks={self.nblocks}>"
@@ -238,18 +273,18 @@ def cache_info(self):
         """
         return self._fetch_block_cached.cache_info()
 
-    def __getstate__(self):
+    def __getstate__(self) -> dict[str, Any]:
         state = self.__dict__
         del state["_fetch_block_cached"]
         return state
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: dict[str, Any]) -> None:
         self.__dict__.update(state)
         self._fetch_block_cached = functools.lru_cache(state["maxblocks"])(
             self._fetch_block
         )
 
-    def _fetch(self, start, end):
+    def _fetch(self, start: int | None, end: int | None) -> bytes:
         if start is None:
             start = 0
         if end is None:
@@ -272,7 +307,7 @@ def _fetch(self, start, end):
             end_block_number=end_block_number,
         )
 
-    def _fetch_block(self, block_number):
+    def _fetch_block(self, block_number: int) -> bytes:
         """
         Fetch the block of data for `block_number`.
         """
@@ -288,7 +323,9 @@ def _fetch_block(self, block_number):
         block_contents = super()._fetch(start, end)
         return block_contents
 
-    def _read_cache(self, start, end, start_block_number, end_block_number):
+    def _read_cache(
+        self, start: int, end: int, start_block_number: int, end_block_number: int
+    ) -> bytes:
         """
         Read from our block cache.
 
@@ -303,7 +340,7 @@ def _read_cache(self, start, end, start_block_number, end_block_number):
         end_pos = end % self.blocksize
 
         if start_block_number == end_block_number:
-            block = self._fetch_block_cached(start_block_number)
+            block: bytes = self._fetch_block_cached(start_block_number)
             return block[start_pos:end_pos]
 
         else:
@@ -336,16 +373,18 @@ class BytesCache(BaseCache):
     we are more than a blocksize ahead of it.
""" - name = "bytes" + name: ClassVar[str] = "bytes" - def __init__(self, blocksize, fetcher, size, trim=True): + def __init__( + self, blocksize: int, fetcher: Fetcher, size: int, trim: bool = True + ) -> None: super().__init__(blocksize, fetcher, size) self.cache = b"" - self.start = None - self.end = None + self.start: int | None = None + self.end: int | None = None self.trim = trim - def _fetch(self, start, end): + def _fetch(self, start: int | None, end: int | None) -> bytes: # TODO: only set start/end after fetch, in case it fails? # is this where retry logic might go? if start is None: @@ -378,23 +417,27 @@ def _fetch(self, start, end): # First read, or extending both before and after self.cache = self.fetcher(start, bend) self.start = start - elif start < self.start: - if self.end - end > self.blocksize: - self.cache = self.fetcher(start, bend) - self.start = start - else: - new = self.fetcher(start, self.start) - self.start = start - self.cache = new + self.cache - elif bend > self.end: - if self.end > self.size: - pass - elif end - self.end > self.blocksize: - self.cache = self.fetcher(start, bend) - self.start = start - else: - new = self.fetcher(self.end, bend) - self.cache = self.cache + new + else: + assert self.start is not None + assert self.end is not None + + if start < self.start: + if self.end is None or self.end - end > self.blocksize: + self.cache = self.fetcher(start, bend) + self.start = start + else: + new = self.fetcher(start, self.start) + self.start = start + self.cache = new + self.cache + elif self.end is not None and bend > self.end: + if self.end > self.size: + pass + elif end - self.end > self.blocksize: + self.cache = self.fetcher(start, bend) + self.start = start + else: + new = self.fetcher(self.end, bend) + self.cache = self.cache + new self.end = self.start + len(self.cache) offset = start - self.start @@ -406,23 +449,29 @@ def _fetch(self, start, end): self.cache = self.cache[self.blocksize * num :] return out - def __len__(self): + def __len__(self) -> int: return len(self.cache) class AllBytes(BaseCache): """Cache entire contents of the file""" - name = "all" + name: ClassVar[str] = "all" - def __init__(self, blocksize=None, fetcher=None, size=None, data=None): - super().__init__(blocksize, fetcher, size) + def __init__( + self, + blocksize: int | None = None, + fetcher: Fetcher | None = None, + size: int | None = None, + data: bytes | None = None, + ) -> None: + super().__init__(blocksize, fetcher, size) # type: ignore[arg-type] if data is None: data = self.fetcher(0, self.size) self.data = data - def _fetch(self, start, end): - return self.data[start:end] + def _fetch(self, start: int | None, stop: int | None) -> bytes: + return self.data[start:stop] class KnownPartsOfAFile(BaseCache): @@ -448,9 +497,17 @@ class KnownPartsOfAFile(BaseCache): begin outside a known byte-range. 
""" - name = "parts" - - def __init__(self, blocksize, fetcher, size, data={}, strict=True, **_): + name: ClassVar[str] = "parts" + + def __init__( + self, + blocksize: int, + fetcher: Fetcher, + size: int, + data: dict[tuple[int, int], bytes] = {}, + strict: bool = True, + **_: Any, + ): super().__init__(blocksize, fetcher, size) self.strict = strict @@ -472,7 +529,12 @@ def __init__(self, blocksize, fetcher, size, data={}, strict=True, **_): else: self.data = data - def _fetch(self, start, stop): + def _fetch(self, start: int | None, stop: int | None) -> bytes: + if start is None: + start = 0 + if stop is None: + stop = self.size + out = b"" for (loc0, loc1), data in self.data.items(): # If self.strict=False, use zero-padded data @@ -510,33 +572,37 @@ def _fetch(self, start, stop): return out + super()._fetch(start, stop) -class UpdatableLRU: +class UpdatableLRU(Generic[P, T]): """ Custom implementation of LRU cache that allows updating keys Used by BackgroudBlockCache """ - CacheInfo = collections.namedtuple( - "CacheInfo", ["hits", "misses", "maxsize", "currsize"] - ) + class CacheInfo(NamedTuple): + hits: int + misses: int + maxsize: int + currsize: int - def __init__(self, func, max_size=128): - self._cache = collections.OrderedDict() + def __init__(self, func: Callable[P, T], max_size: int = 128) -> None: + self._cache: OrderedDict[Any, T] = collections.OrderedDict() self._func = func self._max_size = max_size self._hits = 0 self._misses = 0 self._lock = threading.Lock() - def __call__(self, *args): + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + if kwargs: + raise TypeError(f"Got unexpected keyword argument {kwargs.keys()}") with self._lock: if args in self._cache: self._cache.move_to_end(args) self._hits += 1 return self._cache[args] - result = self._func(*args) + result = self._func(*args, **kwargs) with self._lock: self._cache[args] = result @@ -546,17 +612,17 @@ def __call__(self, *args): return result - def is_key_cached(self, *args): + def is_key_cached(self, *args: Any) -> bool: with self._lock: return args in self._cache - def add_key(self, result, *args): + def add_key(self, result: T, *args: Any) -> None: with self._lock: self._cache[args] = result if len(self._cache) > self._max_size: self._cache.popitem(last=False) - def cache_info(self): + def cache_info(self) -> UpdatableLRU.CacheInfo: with self._lock: return self.CacheInfo( maxsize=self._max_size, @@ -592,26 +658,28 @@ class BackgroundBlockCache(BaseCache): use for this cache is then ``blocksize * maxblocks``. """ - name = "background" + name: ClassVar[str] = "background" - def __init__(self, blocksize, fetcher, size, maxblocks=32): + def __init__( + self, blocksize: int, fetcher: Fetcher, size: int, maxblocks: int = 32 + ) -> None: super().__init__(blocksize, fetcher, size) self.nblocks = math.ceil(size / blocksize) self.maxblocks = maxblocks self._fetch_block_cached = UpdatableLRU(self._fetch_block, maxblocks) self._thread_executor = ThreadPoolExecutor(max_workers=1) - self._fetch_future_block_number = None - self._fetch_future = None + self._fetch_future_block_number: int | None = None + self._fetch_future: Future[bytes] | None = None self._fetch_future_lock = threading.Lock() - def __repr__(self): + def __repr__(self) -> str: return ( f"" ) - def cache_info(self): + def cache_info(self) -> UpdatableLRU.CacheInfo: """ The statistics on the block cache. 
@@ -622,7 +690,7 @@ def cache_info(self): """ return self._fetch_block_cached.cache_info() - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: state = self.__dict__ del state["_fetch_block_cached"] del state["_thread_executor"] @@ -631,7 +699,7 @@ def __getstate__(self): del state["_fetch_future_lock"] return state - def __setstate__(self, state): + def __setstate__(self, state) -> None: self.__dict__.update(state) self._fetch_block_cached = UpdatableLRU(self._fetch_block, state["maxblocks"]) self._thread_executor = ThreadPoolExecutor(max_workers=1) @@ -639,7 +707,7 @@ def __setstate__(self, state): self._fetch_future = None self._fetch_future_lock = threading.Lock() - def _fetch(self, start, end): + def _fetch(self, start: int | None, end: int | None) -> bytes: if start is None: start = 0 if end is None: @@ -656,6 +724,7 @@ def _fetch(self, start, end): with self._fetch_future_lock: # Background thread is running. Check we we can or must join it. if self._fetch_future is not None: + assert self._fetch_future_block_number is not None if self._fetch_future.done(): logger.info("BlockCache joined background fetch without waiting.") self._fetch_block_cached.add_key( @@ -714,7 +783,7 @@ def _fetch(self, start, end): end_block_number=end_block_number, ) - def _fetch_block(self, block_number, log_info="sync"): + def _fetch_block(self, block_number: int, log_info: str = "sync") -> bytes: """ Fetch the block of data for `block_number`. """ @@ -730,7 +799,9 @@ def _fetch_block(self, block_number, log_info="sync"): block_contents = super()._fetch(start, end) return block_contents - def _read_cache(self, start, end, start_block_number, end_block_number): + def _read_cache( + self, start: int, end: int, start_block_number: int, end_block_number: int + ) -> bytes: """ Read from our block cache. @@ -765,13 +836,13 @@ def _read_cache(self, start, end, start_block_number, end_block_number): return b"".join(out) -caches = { +caches: dict[str | None, type[BaseCache]] = { # one custom case None: BaseCache, } -def register_cache(cls, clobber=False): +def register_cache(cls: type[BaseCache], clobber: bool = False) -> None: """'Register' cache implementation. Parameters diff --git a/fsspec/spec.py b/fsspec/spec.py index 2af44f780..4ab3b7ee3 100644 --- a/fsspec/spec.py +++ b/fsspec/spec.py @@ -196,7 +196,7 @@ def _strip_protocol(cls, path): # use of root_marker to make minimum required path, e.g., "/" return path or cls.root_marker - def unstrip_protocol(self, name): + def unstrip_protocol(self, name: str) -> str: """Format FS-specific path to generic, including protocol""" protos = (self.protocol,) if isinstance(self.protocol, str) else self.protocol for protocol in protos: diff --git a/fsspec/utils.py b/fsspec/utils.py index 991ce8d71..c7e3e3ddc 100644 --- a/fsspec/utils.py +++ b/fsspec/utils.py @@ -11,12 +11,32 @@ from functools import partial from hashlib import md5 from importlib.metadata import version +from typing import ( + IO, + TYPE_CHECKING, + Any, + Callable, + Iterable, + Iterator, + Sequence, + TypeVar, +) from urllib.parse import urlsplit +if TYPE_CHECKING: + from typing_extensions import TypeGuard + + from fsspec.spec import AbstractFileSystem + + DEFAULT_BLOCK_SIZE = 5 * 2**20 +T = TypeVar("T") -def infer_storage_options(urlpath, inherit_storage_options=None): + +def infer_storage_options( + urlpath: str, inherit_storage_options: dict[str, Any] | None = None +) -> dict[str, Any]: """Infer storage options from URL path and merge it with existing storage options. 
@@ -68,7 +88,7 @@ def infer_storage_options(urlpath, inherit_storage_options=None): # for HTTP, we don't want to parse, as requests will anyway return {"protocol": protocol, "path": urlpath} - options = {"protocol": protocol, "path": path} + options: dict[str, Any] = {"protocol": protocol, "path": path} if parsed_path.netloc: # Parse `hostname` from netloc manually because `parsed_path.hostname` @@ -98,7 +118,9 @@ def infer_storage_options(urlpath, inherit_storage_options=None): return options -def update_storage_options(options, inherited=None): +def update_storage_options( + options: dict[str, Any], inherited: dict[str, Any] | None = None +) -> None: if not inherited: inherited = {} collisions = set(options) & set(inherited) @@ -116,7 +138,7 @@ def update_storage_options(options, inherited=None): compressions: dict[str, str] = {} -def infer_compression(filename): +def infer_compression(filename: str) -> str | None: """Infer compression, if available, from filename. Infer a named compression type, if registered and available, from filename @@ -126,9 +148,10 @@ def infer_compression(filename): extension = os.path.splitext(filename)[-1].strip(".").lower() if extension in compressions: return compressions[extension] + return None -def build_name_function(max_int): +def build_name_function(max_int: float) -> Callable[[int], str]: """Returns a function that receives a single integer and returns it as a string padded by enough zero characters to align with maximum possible integer @@ -151,13 +174,13 @@ def build_name_function(max_int): pad_length = int(math.ceil(math.log10(max_int))) - def name_function(i): + def name_function(i: int) -> str: return str(i).zfill(pad_length) return name_function -def seek_delimiter(file, delimiter, blocksize): +def seek_delimiter(file: IO[bytes], delimiter: bytes, blocksize: int) -> bool: r"""Seek current file to file start, file end, or byte after delimiter seq. Seeks file to next chunk delimiter, where chunks are defined on file start, @@ -186,7 +209,7 @@ def seek_delimiter(file, delimiter, blocksize): # Interface is for binary IO, with delimiter as bytes, but initialize last # with result of file.read to preserve compatibility with text IO. - last = None + last: bytes | None = None while True: current = file.read(blocksize) if not current: @@ -206,7 +229,13 @@ def seek_delimiter(file, delimiter, blocksize): last = full[-len(delimiter) :] -def read_block(f, offset, length, delimiter=None, split_before=False): +def read_block( + f: IO[bytes], + offset: int, + length: int | None, + delimiter: bytes | None = None, + split_before: bool = False, +) -> bytes: """Read a block of bytes from a file Parameters @@ -267,11 +296,14 @@ def read_block(f, offset, length, delimiter=None, split_before=False): length = end - start f.seek(offset) + + # TODO: allow length to be None and read to the end of the file? 
+ assert length is not None b = f.read(length) return b -def tokenize(*args, **kwargs): +def tokenize(*args: Any, **kwargs: Any) -> str: """Deterministic token (modified from dask.base) @@ -285,13 +317,14 @@ def tokenize(*args, **kwargs): if kwargs: args += (kwargs,) try: - return md5(str(args).encode()).hexdigest() + h = md5(str(args).encode()) except ValueError: # FIPS systems: https://github.com/fsspec/filesystem_spec/issues/380 - return md5(str(args).encode(), usedforsecurity=False).hexdigest() + h = md5(str(args).encode(), usedforsecurity=False) # type: ignore[call-arg] + return h.hexdigest() -def stringify_path(filepath): +def stringify_path(filepath: str | os.PathLike[str] | pathlib.Path) -> str: """Attempt to convert a path-like object to a string. Parameters @@ -322,16 +355,18 @@ def stringify_path(filepath): elif hasattr(filepath, "path"): return filepath.path else: - return filepath + return filepath # type: ignore[return-value] -def make_instance(cls, args, kwargs): +def make_instance( + cls: Callable[..., T], args: Sequence[Any], kwargs: dict[str, Any] +) -> T: inst = cls(*args, **kwargs) - inst._determine_worker() + inst._determine_worker() # type: ignore[attr-defined] return inst -def common_prefix(paths): +def common_prefix(paths: Iterable[str]) -> str: """For a list of paths, find the shortest prefix common to all""" parts = [p.split("/") for p in paths] lmax = min(len(p) for p in parts) @@ -344,7 +379,12 @@ def common_prefix(paths): return "/".join(parts[0][:i]) -def other_paths(paths, path2, exists=False, flatten=False): +def other_paths( + paths: list[str], + path2: str | list[str], + exists: bool = False, + flatten: bool = False, +) -> list[str]: """In bulk file operations, construct a new file tree from a list of files Parameters @@ -384,25 +424,25 @@ def other_paths(paths, path2, exists=False, flatten=False): return path2 -def is_exception(obj): +def is_exception(obj: Any) -> bool: return isinstance(obj, BaseException) -def isfilelike(f): +def isfilelike(f: Any) -> TypeGuard[IO[bytes]]: for attr in ["read", "close", "tell"]: if not hasattr(f, attr): return False return True -def get_protocol(url): +def get_protocol(url: str) -> str: parts = re.split(r"(\:\:|\://)", url, 1) if len(parts) > 1: return parts[0] return "file" -def can_be_local(path): +def can_be_local(path: str) -> bool: """Can the given URL be used with open_local?""" from fsspec import get_filesystem_class @@ -413,7 +453,7 @@ def can_be_local(path): return False -def get_package_version_without_import(name): +def get_package_version_without_import(name: str) -> str | None: """For given package name, try to find the version without importing it Import and package.__version__ is still the backup here, so an import @@ -439,7 +479,12 @@ def get_package_version_without_import(name): return None -def setup_logging(logger=None, logger_name=None, level="DEBUG", clear=True): +def setup_logging( + logger: logging.Logger | None = None, + logger_name: str | None = None, + level: str = "DEBUG", + clear: bool = True, +) -> logging.Logger: if logger is None and logger_name is None: raise ValueError("Provide either logger object or logger name") logger = logger or logging.getLogger(logger_name) @@ -455,20 +500,22 @@ def setup_logging(logger=None, logger_name=None, level="DEBUG", clear=True): return logger -def _unstrip_protocol(name, fs): +def _unstrip_protocol(name: str, fs: AbstractFileSystem) -> str: return fs.unstrip_protocol(name) -def mirror_from(origin_name, methods): +def mirror_from( + origin_name: str, 
methods: Iterable[str] +) -> Callable[[type[T]], type[T]]: """Mirror attributes and methods from the given origin_name attribute of the instance to the decorated class""" - def origin_getter(method, self): + def origin_getter(method: str, self: Any) -> Any: origin = getattr(self, origin_name) return getattr(origin, method) - def wrapper(cls): + def wrapper(cls: type[T]) -> type[T]: for method in methods: wrapped_method = partial(origin_getter, method) setattr(cls, method, property(wrapped_method)) @@ -478,11 +525,18 @@ def wrapper(cls): @contextlib.contextmanager -def nullcontext(obj): +def nullcontext(obj: T) -> Iterator[T]: yield obj -def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=True): +def merge_offset_ranges( + paths: list[str], + starts: list[int] | int, + ends: list[int] | int, + max_gap: int = 0, + max_block: int | None = None, + sort: bool = True, +) -> tuple[list[str], list[int], list[int]]: """Merge adjacent byte-offset ranges when the inter-range gap is <= `max_gap`, and when the merged byte range does not exceed `max_block` (if specified). By default, this function @@ -496,7 +550,7 @@ def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=Tru if not isinstance(starts, list): starts = [starts] * len(paths) if not isinstance(ends, list): - ends = [starts] * len(paths) + ends = [ends] * len(paths) if len(starts) != len(paths) or len(ends) != len(paths): raise ValueError @@ -545,7 +599,7 @@ def merge_offset_ranges(paths, starts, ends, max_gap=0, max_block=None, sort=Tru return paths, starts, ends -def file_size(filelike): +def file_size(filelike: IO[bytes]) -> int: """Find length of any open read-mode file-like""" pos = filelike.tell() try: diff --git a/setup.cfg b/setup.cfg index 8a8bdee72..42e7ad282 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,3 +52,10 @@ warn_unused_ignores = True # don't bother type-checking test_*.py or conftest.py files exclude = (test.*|conftest)\.py$ + +# TODO: turn on globally once we type enough of fsspec +[mypy-fsspec.caching] +check_untyped_defs = True + +[mypy-fsspec.utils] +check_untyped_defs = True
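
Editor's note (not part of the patch): with these annotations in place, downstream code can type custom cache implementations against the same `Fetcher` alias and `BaseCache` signature that the patch introduces, and register them through `register_cache`. The sketch below is purely illustrative; the `PassthroughCache` class, its `name`, and its trivial `_fetch` logic are hypothetical and do not exist in fsspec.

    # Hypothetical example built against the annotated API shown in the diff above.
    from __future__ import annotations

    from fsspec.caching import BaseCache, Fetcher, register_cache


    class PassthroughCache(BaseCache):
        """Illustrative cache that forwards every read straight to the fetcher."""

        name = "passthrough"  # satisfies the ClassVar[str] declared on BaseCache

        def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
            super().__init__(blocksize, fetcher, size)

        def _fetch(self, start: int | None, stop: int | None) -> bytes:
            # Clamp open-ended or oversized ranges before delegating to the fetcher.
            start = start or 0
            stop = self.size if stop is None else min(stop, self.size)
            if start >= self.size or start >= stop:
                return b""
            return self.fetcher(start, stop)


    register_cache(PassthroughCache)

Because `Fetcher` is simply `Callable[[int, int], bytes]`, any callable of that shape (for example, a function slicing an in-memory buffer) type-checks as the `fetcher` argument, which keeps sketches like this easy to exercise in isolation while the per-module mypy sections added to `setup.cfg` gate the typed modules until typing is enabled globally.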