Skip to content

Commit

Permalink
Merge pull request #352 from Carreau/fix-unsafe-filenames
Browse files Browse the repository at this point in the history
Better handling of unsafe filenames
  • Loading branch information
Carreau authored Jul 29, 2024
2 parents 9b8498d + 3f912f9 commit c508269
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 52 deletions.
4 changes: 2 additions & 2 deletions lib/python/pyflyby/_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __new__(cls, arg):
def _from_filename(cls, filename: str):
if not isinstance(filename, str):
raise TypeError
filename = str(filename)
filename = os.path.abspath(filename)
if not filename:
raise UnsafeFilenameError("(empty string)")
# we only allow filename with given character set
Expand All @@ -58,7 +58,7 @@ def _from_filename(cls, filename: str):
if re.search("(^|/)~", filename):
raise UnsafeFilenameError(filename)
self = object.__new__(cls)
self._filename = os.path.abspath(filename)
self._filename = filename
return self

def __str__(self):
Expand Down
101 changes: 66 additions & 35 deletions lib/python/pyflyby/_importdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import re
import sys

from typing import Any, Dict, Tuple, Union, List
from pathlib import Path

from typing import Any, Dict, List, Tuple, Union

from pyflyby._file import (Filename, UnsafeFilenameError,
expand_py_files_from_args)
Expand Down Expand Up @@ -201,6 +203,8 @@ class ImportDB:
mandatory_imports: ImportSet
canonical_imports: ImportMap

_default_cache: Dict[Any, Any] = {}

def __new__(cls, *args):
if len(args) != 1:
raise TypeError
Expand All @@ -213,7 +217,6 @@ def __new__(cls, *args):



_default_cache: Dict[Any, Any] = {}

@classmethod
def clear_default_cache(cls):
Expand All @@ -224,24 +227,22 @@ def clear_default_cache(cls):
cached results. Existing ImportDB instances are not affected by this
call.
"""
if cls._default_cache:
if logger.debug_enabled:
allpyfiles = set()
for tup in cls._default_cache:
if tup[0] != 2:
continue
for tup2 in tup[1:]:
for f in tup2:
assert isinstance(f, Filename)
if f.ext == '.py':
allpyfiles.add(f)
nfiles = len(allpyfiles)
logger.debug("ImportDB: Clearing default cache of %d files",
nfiles)
cls._default_cache.clear()
if logger.debug_enabled:
allpyfiles = set()
for tup in cls._default_cache:
if tup[0] != 2:
continue
for tup2 in tup[1:]:
for f in tup2:
assert isinstance(f, Filename)
if f.ext == ".py":
allpyfiles.add(f)
nfiles = len(allpyfiles)
logger.debug("ImportDB: Clearing default cache of %d files", nfiles)
cls._default_cache.clear()

@classmethod
def get_default(cls, target_filename: Union[Filename, str]):
def get_default(cls, target_filename: Union[Filename, str], /):
"""
Return the default import library for the given target filename.
Expand All @@ -258,7 +259,7 @@ def get_default(cls, target_filename: Union[Filename, str]):
:rtype:
`ImportDB`
"""
# We're going to canonicalize target_filenames in a number of steps.
# We're going to canonicalize target_filename in a number of steps.
# At each step, see if we've seen the input so far. We do the cache
# checking incrementally since the steps involve syscalls. Since this
# is going to potentially be executed inside the IPython interactive
Expand All @@ -267,25 +268,55 @@ def get_default(cls, target_filename: Union[Filename, str]):
# been touched, and if so, return new data. Check file timestamps at
# most once every 60 seconds.
cache_keys:List[Tuple[Any,...]] = []
if target_filename:
if isinstance(target_filename, str):
target_filename = Filename(target_filename)
target_filename = target_filename or Filename(".")
assert isinstance(target_filename, Filename)
if target_filename is None:
target_filename = "."

if isinstance(target_filename, Filename):
target_filename = str(target_filename)

assert isinstance(target_filename, str), (
target_filename,
type(target_filename),
)

target_path = Path(target_filename).resolve()

parents: List[Path]
if not target_path.is_dir():
parents = [target_path]
else:
parents = []

# filter safe parents
safe_parent = None
for p in parents + list(target_path.parents):
try:
safe_parent = Filename(str(p))
break
except UnsafeFilenameError:
pass
if safe_parent is None:
raise ValueError("No know path are safe")

target_dirname = safe_parent

if target_filename.startswith("/dev"):
target_filename = Filename(".")
target_dirname:Filename = target_filename
# TODO: with StatCache
while True:
cache_keys.append((1,
target_dirname,
os.getenv("PYFLYBY_PATH"),
os.getenv("PYFLYBY_KNOWN_IMPORTS_PATH"),
os.getenv("PYFLYBY_MANDATORY_IMPORTS_PATH")))
try:
return cls._default_cache[cache_keys[-1]]
except KeyError:
target_dirname = Filename(".")
except UnsafeFilenameError:
pass
# TODO: with StatCache
while True:
key = (
1,
target_dirname,
os.getenv("PYFLYBY_PATH"),
os.getenv("PYFLYBY_KNOWN_IMPORTS_PATH"),
os.getenv("PYFLYBY_MANDATORY_IMPORTS_PATH"),
)
cache_keys.append(key)
if key in cls._default_cache:
return cls._default_cache[key]
if target_dirname.isdir:
break
target_dirname = target_dirname.dir
Expand Down
2 changes: 1 addition & 1 deletion lib/python/pyflyby/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ def from_text(cls, text, filename=None, startpos=None, flags=None,
def __construct_from_annotated_ast(cls, annotated_ast_nodes, text:FileText, flags):
# Constructor for internal use by _split_by_statement() or
# concatenate().
ast_node = AnnotatedModule(annotated_ast_nodes)
ast_node = AnnotatedModule(annotated_ast_nodes, type_ignores=[])
ast_node.text = text
ast_node.flags = flags
if not hasattr(ast_node, "source_flags"):
Expand Down
60 changes: 46 additions & 14 deletions tests/test_cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,26 +60,34 @@ def test_tidy_imports_quiet_1():

def test_tidy_imports_log_level_1():
with EnvVarCtx(PYFLYBY_LOG_LEVEL="WARNING"):
result = pipe([BIN_DIR+"/tidy-imports"], stdin="os, sys")
expected = dedent('''
result = pipe([BIN_DIR + "/tidy-imports"], stdin="os, sys")
expected = dedent(
"""
import os
import sys
os, sys
''').strip()
"""
).strip()
assert result == expected


def test_tidy_imports_filename_action_print_1():
with tempfile.NamedTemporaryFile(suffix=".py", mode='w+') as f:
f.write(dedent('''
with tempfile.NamedTemporaryFile(suffix=".py", mode="w+") as f:
f.write(
dedent(
"""
# hello
def foo():
foo() + os + sys
''').lstrip())
"""
).lstrip()
)
f.flush()
result = pipe([BIN_DIR+"/tidy-imports", f.name])
expected = dedent('''
result = pipe([BIN_DIR + "/tidy-imports", f.name])
expected = (
dedent(
"""
[PYFLYBY] {f.name}: added 'import os'
[PYFLYBY] {f.name}: added 'import sys'
# hello
Expand All @@ -88,10 +96,26 @@ def foo():
def foo():
foo() + os + sys
''').strip().format(f=f)
"""
)
.strip()
.format(f=f)
)
assert result == expected


def test_unsafe_cwd():
    """Run ``py`` from a working directory whose path contains ``#``.

    A nested directory ``foo#bar/foo#qux`` is created under a temporary
    directory and used as the cwd for the ``py`` command, with ``os`` fed
    on stdin.  The output must not mention "Unsafe" and must be exactly
    the auto-import notice for ``os``.
    """
    from pathlib import Path

    with tempfile.TemporaryDirectory() as tmp_root:
        # NOTE(review): '#' is presumably outside the character set that
        # Filename accepts -- confirm against pyflyby._file.
        hostile_cwd = Path(tmp_root).joinpath("foo#bar", "foo#qux")
        hostile_cwd.mkdir(parents=True)

        command = [BIN_DIR + "/py"]
        output = pipe(command, cwd=hostile_cwd, stdin="os")

        assert "Unsafe" not in output
        assert output == "[PYFLYBY] import os"


def test_tidy_imports_filename_action_replace_1():
with tempfile.NamedTemporaryFile(suffix=".py", delete=False, mode='w+') as f:
f.write(dedent('''
Expand Down Expand Up @@ -763,24 +787,32 @@ def test_tidy_imports_forward_references():
with tempfile.TemporaryDirectory() as temp_dir:
foo = os.path.join(temp_dir, "foo.py")
with open(foo, "w") as foo_fp:
foo_fp.write(dedent("""
foo_fp.write(
dedent(
"""
from __future__ import annotations
class A:
param1: str
param2: B
class B:
param1: str
""").lstrip())
"""
).lstrip()
)
foo_fp.flush()

dot_pyflyby = os.path.join(temp_dir, ".pyflyby")
with open(dot_pyflyby, "w") as dot_pyflyby_fp:
dot_pyflyby_fp.write(dedent("""
dot_pyflyby_fp.write(
dedent(
"""
from foo import A, B
""").lstrip())
"""
).lstrip()
)
dot_pyflyby_fp.flush()
with CwdCtx(temp_dir):
result = pipe(
Expand Down

0 comments on commit c508269

Please sign in to comment.