Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

requirement: Close temporary files before passing them to pip #551

Merged
merged 4 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions pip_audit/_dependency_source/requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def collect(self) -> Iterator[Dependency]:
Raises a `RequirementSourceError` on any errors.
"""

with ExitStack() as stack:
tmp_files = []
try:
# We need to handle process substitution inputs so we can invoke
# `pip-audit` like so:
#
Expand All @@ -97,20 +98,26 @@ def collect(self) -> Iterator[Dependency]:
# In order to get around this, we're going to copy each input into a
# a corresponding temporary file and then pass that set of files
# into `pip`.
tmp_files = []

# For each input file, copy it to one of our temporary files.
# Ensure we flush so our writes are visible to `pip`.
for filename in self._filenames:
tmp_file = stack.enter_context(NamedTemporaryFile(mode="w"))
# Deliberately pass `delete=False` so that our temporary file doesn't get
# automatically deleted on close. We need to close it so that `pip` can
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm referring to pip here instead of the requirements parser (where the actual error is at the moment) since we're about to remove the parser from this code path.

# use it however, we obviously want it to persist.
tmp_file = NamedTemporaryFile(mode="w", delete=False)
with filename.open("r") as f:
shutil.copyfileobj(f, tmp_file)
tmp_file.flush()
tmp_files.append(tmp_file)

# Close the file since it's going to get re-opened by `pip`
tmp_file.close()
tmp_files.append(tmp_file.name)

# Now pass the list of temporary filenames into the rest of our
# logic.
yield from self._collect_from_files([Path(f.name) for f in tmp_files])
yield from self._collect_from_files([Path(f) for f in tmp_files])
finally:
# Since we disabled automatically deletion for these temporary files, we need to
# manually delete them on the way out.
for f in tmp_files:
os.unlink(f)

def _collect_from_files(self, filenames: list[os.PathLike]) -> Iterator[Dependency]:
# Figure out whether we have a fully resolved set of dependencies.
Expand Down
35 changes: 34 additions & 1 deletion test/dependency_source/test_requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from email.message import EmailMessage
from pathlib import Path
from tempfile import NamedTemporaryFile

import pretend # type: ignore
import pytest
Expand All @@ -17,7 +18,7 @@
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import ResolvedDependency, SkippedDependency
from pip_audit._state import AuditState
from pip_audit._virtual_env import VirtualEnvError
from pip_audit._virtual_env import VirtualEnv, VirtualEnvError


def get_metadata_mock():
Expand Down Expand Up @@ -483,6 +484,38 @@ def test_requirement_source_no_deps_duplicate_dependencies(req_file):
list(source.collect())


def test_requirement_source_no_double_open(monkeypatch, req_file):
source = _init_requirement([(req_file(), "flask==2.0.1")])

# Intercept the calls to `NamedTemporaryFile` to get a handle on each file object.
tmp_files = []

def named_temp_file(*args, **kwargs):
tmp_file = NamedTemporaryFile(*args, **kwargs)
tmp_files.append(tmp_file)
return tmp_file

monkeypatch.setattr(
requirement,
"NamedTemporaryFile",
named_temp_file,
)

# Intercept the `VirtualEnv` constructor to check that all file handles are closed prior to
# the `pip` invocation.
#
# `pip` will open the file so we need to ensure that we've closed it.
def virtual_env(*args, **kwargs):
for tmp_file in tmp_files:
assert tmp_file.closed
return VirtualEnv(*args, **kwargs)

monkeypatch.setattr(requirement, "VirtualEnv", virtual_env)

specs = list(source.collect())
assert ResolvedDependency("Flask", Version("2.0.1")) in specs


def test_requirement_source_fix_explicit_subdep(monkeypatch, req_file):
logger = pretend.stub(warning=pretend.call_recorder(lambda s: None))
monkeypatch.setattr(requirement, "logger", logger)
Expand Down