diff --git a/lib/python/pyflyby/_file.py b/lib/python/pyflyby/_file.py index 51e4b2ba..1546cb0d 100644 --- a/lib/python/pyflyby/_file.py +++ b/lib/python/pyflyby/_file.py @@ -48,7 +48,7 @@ def __new__(cls, arg): def _from_filename(cls, filename: str): if not isinstance(filename, str): raise TypeError - filename = str(filename) + filename = os.path.abspath(filename) if not filename: raise UnsafeFilenameError("(empty string)") # we only allow filename with given character set @@ -58,7 +58,7 @@ def _from_filename(cls, filename: str): if re.search("(^|/)~", filename): raise UnsafeFilenameError(filename) self = object.__new__(cls) - self._filename = os.path.abspath(filename) + self._filename = filename return self def __str__(self): diff --git a/lib/python/pyflyby/_importdb.py b/lib/python/pyflyby/_importdb.py index 832cc803..d8d851b5 100644 --- a/lib/python/pyflyby/_importdb.py +++ b/lib/python/pyflyby/_importdb.py @@ -9,6 +9,8 @@ import re import sys +from pathlib import Path + from typing import Any, Dict, Tuple, Union, List from pyflyby._file import (Filename, UnsafeFilenameError, @@ -201,6 +203,8 @@ class ImportDB: mandatory_imports: ImportSet canonical_imports: ImportMap + _default_cache: Dict[Any, Any] = {} + def __new__(cls, *args): if len(args) != 1: raise TypeError @@ -213,7 +217,6 @@ def __new__(cls, *args): - _default_cache: Dict[Any, Any] = {} @classmethod def clear_default_cache(cls): @@ -224,24 +227,22 @@ def clear_default_cache(cls): cached results. Existing ImportDB instances are not affected by this call. """ - if cls._default_cache: - if logger.debug_enabled: - allpyfiles = set() - for tup in cls._default_cache: - if tup[0] != 2: - continue - for tup2 in tup[1:]: - for f in tup2: - assert isinstance(f, Filename) - if f.ext == '.py': - allpyfiles.add(f) - nfiles = len(allpyfiles) - logger.debug("ImportDB: Clearing default cache of %d files", - nfiles) - cls._default_cache.clear() + if logger.debug_enabled: + allpyfiles = set() + for tup in cls._default_cache: + if tup[0] != 2: + continue + for tup2 in tup[1:]: + for f in tup2: + assert isinstance(f, Filename) + if f.ext == ".py": + allpyfiles.add(f) + nfiles = len(allpyfiles) + logger.debug("ImportDB: Clearing default cache of %d files", nfiles) + cls._default_cache.clear() @classmethod - def get_default(cls, target_filename: Union[Filename, str]): + def get_default(cls, target_filename: Union[Filename, str], /): """ Return the default import library for the given target filename. @@ -258,7 +259,7 @@ def get_default(cls, target_filename: Union[Filename, str]): :rtype: `ImportDB` """ - # We're going to canonicalize target_filenames in a number of steps. + # We're going to canonicalize target_filename in a number of steps. # At each step, see if we've seen the input so far. We do the cache # checking incrementally since the steps involve syscalls. Since this # is going to potentially be executed inside the IPython interactive @@ -267,25 +268,55 @@ def get_default(cls, target_filename: Union[Filename, str]): # been touched, and if so, return new data. Check file timestamps at # most once every 60 seconds. cache_keys:List[Tuple[Any,...]] = [] - if target_filename: - if isinstance(target_filename, str): - target_filename = Filename(target_filename) - target_filename = target_filename or Filename(".") - assert isinstance(target_filename, Filename) + if target_filename is None: + target_filename = "." + + if isinstance(target_filename, Filename): + target_filename = str(target_filename) + + assert isinstance(target_filename, str), ( + target_filename, + type(target_filename), + ) + + target_path = Path(target_filename).resolve() + + parents: List[Path] + if not target_path.is_dir(): + parents = [target_path] + else: + parents = [] + + # filter safe parents + safe_parent = None + for p in parents + list(target_path.parents): + try: + safe_parent = Filename(str(p)) + break + except UnsafeFilenameError: + pass + if safe_parent is None: + raise ValueError("No know path are safe") + + target_dirname = safe_parent + if target_filename.startswith("/dev"): - target_filename = Filename(".") - target_dirname:Filename = target_filename - # TODO: with StatCache - while True: - cache_keys.append((1, - target_dirname, - os.getenv("PYFLYBY_PATH"), - os.getenv("PYFLYBY_KNOWN_IMPORTS_PATH"), - os.getenv("PYFLYBY_MANDATORY_IMPORTS_PATH"))) try: - return cls._default_cache[cache_keys[-1]] - except KeyError: + target_dirname = Filename(".") + except UnsafeFilenameError: pass + # TODO: with StatCache + while True: + key = ( + 1, + target_dirname, + os.getenv("PYFLYBY_PATH"), + os.getenv("PYFLYBY_KNOWN_IMPORTS_PATH"), + os.getenv("PYFLYBY_MANDATORY_IMPORTS_PATH"), + ) + cache_keys.append(key) + if key in cls._default_cache: + return cls._default_cache[key] if target_dirname.isdir: break target_dirname = target_dirname.dir