Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

requirement, test: Correct --fix for subdependencies in requirements files #297

Merged
merged 17 commits into from
Jun 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ All versions prior to 0.0.9 are untracked.

## [Unreleased]

### Fixed

* Vulnerability fixing: the `--fix` flag now works for vulnerabilities found in
requirement subdependencies. A new line is now added to the requirements file
to explicitly pin the offending subdependency
([#297](https://github.com/trailofbits/pip-audit/pull/297))

## [2.3.3]

### Changed
Expand Down
140 changes: 100 additions & 40 deletions pip_audit/_dependency_source/requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from contextlib import ExitStack
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import IO, Iterator, List, Set, Union, cast
from typing import IO, Dict, Iterator, List, Set, Tuple, Union, cast

from packaging.requirements import Requirement
from packaging.specifiers import SpecifierSet
Expand Down Expand Up @@ -71,6 +71,7 @@ def __init__(
self._require_hashes = require_hashes
self._no_deps = no_deps
self.state = state
self._dep_cache: Dict[Path, Dict[Requirement, Set[Dependency]]] = {}

def collect(self) -> Iterator[Dependency]:
"""
Expand All @@ -83,44 +84,17 @@ def collect(self) -> Iterator[Dependency]:
try:
reqs = parse_requirements(filename=filename)
except PipError as pe:
raise RequirementSourceError("requirement parsing raised an error") from pe

# There are three cases where we skip dependency resolution:
#
# 1. The user has explicitly specified `--require-hashes`.
# 2. One or more parsed requirements has hashes specified, enabling
# hash checking for all requirements.
# 3. The user has explicitly specified `--no-deps`.
require_hashes = self._require_hashes or any(
isinstance(req, ParsedRequirement) and req.hashes for req in reqs.values()
)
skip_deps = require_hashes or self._no_deps
if skip_deps:
yield from self._collect_preresolved_deps(
iter(reqs.values()), require_hashes=require_hashes
)
continue

# Invoke the dependency resolver to turn requirements into dependencies
req_values: List[Requirement] = [Requirement(str(req)) for req in reqs.values()]
raise RequirementSourceError(
f"requirement parsing raised an error: {filename}"
) from pe
try:
for _, deps in self._resolver.resolve_all(iter(req_values)):
for dep in deps:
# Don't allow duplicate dependencies to be returned
if dep in collected:
continue

if dep.is_skipped(): # pragma: no cover
dep = cast(SkippedDependency, dep)
self.state.update_state(f"Skipping {dep.name}: {dep.skip_reason}")
else:
dep = cast(ResolvedDependency, dep)
self.state.update_state(f"Collecting {dep.name} ({dep.version})")

collected.add(dep)
yield dep
for _, dep in self._collect_cached_deps(filename, list(reqs.values())):
if dep in collected:
continue
collected.add(dep)
yield dep
except DependencyResolverError as dre:
raise RequirementSourceError("dependency resolver raised an error") from dre
raise RequirementSourceError from dre

def fix(self, fix_version: ResolvedFixVersion) -> None:
"""
Expand Down Expand Up @@ -164,15 +138,46 @@ def _fix_file(self, filename: Path, fix_version: ResolvedFixVersion) -> None:

# Now write out the new requirements file
with filename.open("w") as f:
fixed = False
for req in req_list:
if (
req.name == fix_version.dep.name
and req.specifier.contains(fix_version.dep.version)
and not req.specifier.contains(fix_version.version)
):
req.specifier = SpecifierSet(f"=={fix_version.version}")
fixed = True
assert req.marker is None or req.marker.evaluate()
f.write(str(req) + os.linesep)
print(str(req), file=f)

# The vulnerable dependency may not be explicitly listed in the requirements file if it
# is a subdependency of a requirement. In this case, we should explicitly add the fixed
# dependency into the requirements file.
#
# To know whether this is the case, we'll need to resolve dependencies if we haven't
# already in order to figure out whether this subdependency belongs to this file or
# another.
try:
if not fixed:
origin_reqs: Set[Requirement] = set()
for req, dep in self._collect_cached_deps(filename, list(reqs.values())):
if fix_version.dep == dep:
origin_reqs.add(req)
if origin_reqs:
logger.warning(
"added fixed subdependency explicitly to requirements file "
f"{filename}: {fix_version.dep.canonical_name}"
)
origin_reqs_formatted = ",".join(
[str(req) for req in sorted(list(origin_reqs), key=lambda x: x.name)]
)
print(
f" # pip-audit: subdependency fixed via {origin_reqs_formatted}",
file=f,
)
print(f"{fix_version.dep.canonical_name}=={fix_version.version}", file=f)
except DependencyResolverError as dre:
raise RequirementFixError from dre

def _recover_files(self, tmp_files: List[IO[str]]) -> None:
for (filename, tmp_file) in zip(self._filenames, tmp_files):
Expand All @@ -191,7 +196,7 @@ def _collect_preresolved_deps(
self,
reqs: Iterator[Union[ParsedRequirement, UnparsedRequirement]],
require_hashes: bool = False,
) -> Iterator[Dependency]:
) -> Iterator[Tuple[Requirement, Dependency]]:
"""
Collect pre-resolved (pinned) dependencies, optionally enforcing a
hash requirement policy.
Expand All @@ -210,10 +215,65 @@ def _collect_preresolved_deps(
if pinned_specifier is None:
raise RequirementSourceError(f"requirement {req.name} is not pinned: {str(req)}")

yield ResolvedDependency(
yield Requirement(str(req)), ResolvedDependency(
req.name, Version(pinned_specifier.group("version")), req.hashes
)

def _collect_cached_deps(
    self, filename: Path, reqs: List[Union[ParsedRequirement, UnparsedRequirement]]
) -> Iterator[Tuple[Requirement, Dependency]]:
    """
    Collect resolved dependencies for a given requirements file, retrieving them from the
    dependency cache if possible.

    Yields `(requirement, dependency)` pairs, where `requirement` is the entry that the
    dependency was collected for.

    On a cache hit for `filename`, the cached pairs are replayed and nothing is resolved
    again. On a cache miss, the dependencies are collected (either as pre-resolved pins or
    via the dependency resolver) and stored in the cache under `filename`. The cache entry
    is only written once this generator has been consumed to the end.

    Errors raised by `_collect_preresolved_deps` or by the resolver are not handled here
    and propagate to the caller.
    """
    # See if we already have cached dependencies for this file
    cached_deps_for_file = self._dep_cache.get(filename)
    if cached_deps_for_file is not None:
        for req, deps in cached_deps_for_file.items():
            for dep in deps:
                yield req, dep
        # Cache hit: stop here. Falling through would resolve the file a second time,
        # yield every pair again and overwrite the cache entry we just replayed.
        return

    new_cached_deps_for_file: Dict[Requirement, Set[Dependency]] = {}

    # There are three cases where we skip dependency resolution:
    #
    # 1. The user has explicitly specified `--require-hashes`.
    # 2. One or more parsed requirements has hashes specified, enabling
    #    hash checking for all requirements.
    # 3. The user has explicitly specified `--no-deps`.
    require_hashes = self._require_hashes or any(
        isinstance(req, ParsedRequirement) and req.hashes for req in reqs
    )
    skip_deps = require_hashes or self._no_deps
    if skip_deps:
        for req, dep in self._collect_preresolved_deps(
            iter(reqs), require_hashes=require_hashes
        ):
            new_cached_deps_for_file.setdefault(req, set()).add(dep)
            yield req, dep
    else:
        # Invoke the dependency resolver to turn requirements into dependencies
        req_values: List[Requirement] = [Requirement(str(req)) for req in reqs]
        for req, resolved_deps in self._resolver.resolve_all(iter(req_values)):
            for dep in resolved_deps:
                new_cached_deps_for_file.setdefault(req, set()).add(dep)

                # Report progress; the casts only narrow the type for the checker.
                if dep.is_skipped():  # pragma: no cover
                    dep = cast(SkippedDependency, dep)
                    self.state.update_state(f"Skipping {dep.name}: {dep.skip_reason}")
                else:
                    dep = cast(ResolvedDependency, dep)
                    self.state.update_state(f"Collecting {dep.name} ({dep.version})")

                yield req, dep

    # Cache the collected dependencies
    self._dep_cache[filename] = new_cached_deps_for_file


class RequirementSourceError(DependencySourceError):
"""A requirements-parsing specific `DependencySourceError`."""
Expand Down
Loading