diff --git a/CHANGELOG.md b/CHANGELOG.md index b9e5d8cb..b7e30e8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,13 @@ All versions prior to 0.0.9 are untracked. ## [Unreleased] +### Fixed + +* Vulnerability fixing: the `--fix` flag now works for vulnerabilities found in + requirement subdependencies. A new line is now added to the requirement file + to explicitly pin the offending subdependency + ([#297](https://github.com/trailofbits/pip-audit/pull/297)) + ## [2.3.3] ### Changed diff --git a/pip_audit/_dependency_source/requirement.py b/pip_audit/_dependency_source/requirement.py index 7210645e..999693b2 100644 --- a/pip_audit/_dependency_source/requirement.py +++ b/pip_audit/_dependency_source/requirement.py @@ -9,7 +9,7 @@ from contextlib import ExitStack from pathlib import Path from tempfile import NamedTemporaryFile -from typing import IO, Iterator, List, Set, Union, cast +from typing import IO, Dict, Iterator, List, Set, Tuple, Union, cast from packaging.requirements import Requirement from packaging.specifiers import SpecifierSet @@ -71,6 +71,7 @@ def __init__( self._require_hashes = require_hashes self._no_deps = no_deps self.state = state + self._dep_cache: Dict[Path, Dict[Requirement, Set[Dependency]]] = {} def collect(self) -> Iterator[Dependency]: """ @@ -83,44 +84,17 @@ def collect(self) -> Iterator[Dependency]: try: reqs = parse_requirements(filename=filename) except PipError as pe: - raise RequirementSourceError("requirement parsing raised an error") from pe - - # There are three cases where we skip dependency resolution: - # - # 1. The user has explicitly specified `--require-hashes`. - # 2. One or more parsed requirements has hashes specified, enabling - # hash checking for all requirements. - # 3. The user has explicitly specified `--no-deps`. 
- require_hashes = self._require_hashes or any( - isinstance(req, ParsedRequirement) and req.hashes for req in reqs.values() - ) - skip_deps = require_hashes or self._no_deps - if skip_deps: - yield from self._collect_preresolved_deps( - iter(reqs.values()), require_hashes=require_hashes - ) - continue - - # Invoke the dependency resolver to turn requirements into dependencies - req_values: List[Requirement] = [Requirement(str(req)) for req in reqs.values()] + raise RequirementSourceError( + f"requirement parsing raised an error: {filename}" + ) from pe try: - for _, deps in self._resolver.resolve_all(iter(req_values)): - for dep in deps: - # Don't allow duplicate dependencies to be returned - if dep in collected: - continue - - if dep.is_skipped(): # pragma: no cover - dep = cast(SkippedDependency, dep) - self.state.update_state(f"Skipping {dep.name}: {dep.skip_reason}") - else: - dep = cast(ResolvedDependency, dep) - self.state.update_state(f"Collecting {dep.name} ({dep.version})") - - collected.add(dep) - yield dep + for _, dep in self._collect_cached_deps(filename, list(reqs.values())): + if dep in collected: + continue + collected.add(dep) + yield dep except DependencyResolverError as dre: - raise RequirementSourceError("dependency resolver raised an error") from dre + raise RequirementSourceError from dre def fix(self, fix_version: ResolvedFixVersion) -> None: """ @@ -164,6 +138,7 @@ def _fix_file(self, filename: Path, fix_version: ResolvedFixVersion) -> None: # Now write out the new requirements file with filename.open("w") as f: + fixed = False for req in req_list: if ( req.name == fix_version.dep.name @@ -171,8 +146,38 @@ def _fix_file(self, filename: Path, fix_version: ResolvedFixVersion) -> None: and not req.specifier.contains(fix_version.version) ): req.specifier = SpecifierSet(f"=={fix_version.version}") + fixed = True assert req.marker is None or req.marker.evaluate() - f.write(str(req) + os.linesep) + print(str(req), file=f) + + # The vulnerable 
dependency may not be explicitly listed in the requirements file if it + # is a subdependency of a requirement. In this case, we should explicitly add the fixed + # dependency into the requirements file. + # + # To know whether this is the case, we'll need to resolve dependencies if we haven't + # already in order to figure out whether this subdependency belongs to this file or + # another. + try: + if not fixed: + origin_reqs: Set[Requirement] = set() + for req, dep in self._collect_cached_deps(filename, list(reqs.values())): + if fix_version.dep == dep: + origin_reqs.add(req) + if origin_reqs: + logger.warning( + "added fixed subdependency explicitly to requirements file " + f"({filename}): {fix_version.dep.canonical_name}" + ) + origin_reqs_formatted = ",".join( + [str(req) for req in sorted(list(origin_reqs), key=lambda x: x.name)] + ) + print( + f" # pip-audit: subdependency fixed via {origin_reqs_formatted}", + file=f, + ) + print(f"{fix_version.dep.canonical_name}=={fix_version.version}", file=f) + except DependencyResolverError as dre: + raise RequirementFixError from dre def _recover_files(self, tmp_files: List[IO[str]]) -> None: for (filename, tmp_file) in zip(self._filenames, tmp_files): @@ -191,7 +196,7 @@ def _collect_preresolved_deps( self, reqs: Iterator[Union[ParsedRequirement, UnparsedRequirement]], require_hashes: bool = False, - ) -> Iterator[Dependency]: + ) -> Iterator[Tuple[Requirement, Dependency]]: """ Collect pre-resolved (pinned) dependencies, optionally enforcing a hash requirement policy. 
@@ -210,10 +215,65 @@ def _collect_preresolved_deps( if pinned_specifier is None: raise RequirementSourceError(f"requirement {req.name} is not pinned: {str(req)}") - yield ResolvedDependency( + yield Requirement(str(req)), ResolvedDependency( req.name, Version(pinned_specifier.group("version")), req.hashes ) + def _collect_cached_deps( + self, filename: Path, reqs: List[Union[ParsedRequirement, UnparsedRequirement]] + ) -> Iterator[Tuple[Requirement, Dependency]]: + """ + Collect resolved dependencies for a given requirements file, retrieving them from the + dependency cache if possible. + """ + # See if we've already have cached dependencies for this file + cached_deps_for_file = self._dep_cache.get(filename, None) + if cached_deps_for_file is not None: + for req, deps in cached_deps_for_file.items(): + for dep in deps: + yield req, dep + return + + new_cached_deps_for_file: Dict[Requirement, Set[Dependency]] = dict() + + # There are three cases where we skip dependency resolution: + # + # 1. The user has explicitly specified `--require-hashes`. + # 2. One or more parsed requirements has hashes specified, enabling + # hash checking for all requirements. + # 3. The user has explicitly specified `--no-deps`. 
+ require_hashes = self._require_hashes or any( + isinstance(req, ParsedRequirement) and req.hashes for req in reqs + ) + skip_deps = require_hashes or self._no_deps + if skip_deps: + for req, dep in self._collect_preresolved_deps( + iter(reqs), require_hashes=require_hashes + ): + if req not in new_cached_deps_for_file: + new_cached_deps_for_file[req] = set() + new_cached_deps_for_file[req].add(dep) + yield req, dep + else: + # Invoke the dependency resolver to turn requirements into dependencies + req_values: List[Requirement] = [Requirement(str(req)) for req in reqs] + for req, resolved_deps in self._resolver.resolve_all(iter(req_values)): + for dep in resolved_deps: + if req not in new_cached_deps_for_file: + new_cached_deps_for_file[req] = set() + new_cached_deps_for_file[req].add(dep) + + if dep.is_skipped(): # pragma: no cover + dep = cast(SkippedDependency, dep) + self.state.update_state(f"Skipping {dep.name}: {dep.skip_reason}") + else: + dep = cast(ResolvedDependency, dep) + self.state.update_state(f"Collecting {dep.name} ({dep.version})") + + yield req, dep + + # Cache the collected dependencies + self._dep_cache[filename] = new_cached_deps_for_file + class RequirementSourceError(DependencySourceError): """A requirements-parsing specific `DependencySourceError`.""" diff --git a/test/dependency_source/test_requirement.py b/test/dependency_source/test_requirement.py index e367d22c..7dc33b25 100644 --- a/test/dependency_source/test_requirement.py +++ b/test/dependency_source/test_requirement.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import List +from typing import List, Optional import pretend # type: ignore import pytest @@ -141,8 +141,8 @@ def test_requirement_source_fix(req_file): def test_requirement_source_fix_multiple_files(req_file): _check_fixes( - ["flask==0.5", "requests==1.0\nflask==0.5"], - ["flask==1.0", "requests==1.0\nflask==1.0"], + ["flask==0.5", "requests==2.0\nflask==0.5"], + ["flask==1.0", 
"requests==2.0\nflask==1.0"], [req_file(), req_file()], [ ResolvedFixVersion( @@ -154,8 +154,8 @@ def test_requirement_source_fix_multiple_files(req_file): def test_requirement_source_fix_specifier_match(req_file): _check_fixes( - ["flask<1.0", "requests==1.0\nflask<=0.6"], - ["flask==1.0", "requests==1.0\nflask==1.0"], + ["flask<1.0", "requests==2.0\nflask<=0.6"], + ["flask==1.0", "requests==2.0\nflask==1.0"], [req_file(), req_file()], [ ResolvedFixVersion( @@ -170,8 +170,8 @@ def test_requirement_source_fix_specifier_no_match(req_file): # version. If the specifier matches both, we don't apply the fix since installing from the given # requirements file would already install the fixed version. _check_fixes( - ["flask>=0.5", "requests==1.0\nflask<2.0"], - ["flask>=0.5", "requests==1.0\nflask<2.0"], + ["flask>=0.5", "requests==2.0\nflask<2.0"], + ["flask>=0.5", "requests==2.0\nflask<2.0"], [req_file(), req_file()], [ ResolvedFixVersion( @@ -187,11 +187,11 @@ def test_requirement_source_fix_marker(req_file): _check_fixes( [ 'flask<1.0; python_version > "2.7"', - 'requests==1.0\nflask<=0.6; python_version <= "2.7"', + 'requests==2.0\nflask<=0.6; python_version <= "2.7"', ], [ 'flask==1.0; python_version > "2.7"', - "requests==1.0", + "requests==2.0", ], [req_file(), req_file()], [ @@ -207,9 +207,9 @@ def test_requirement_source_fix_comments(req_file): _check_fixes( [ "# comment here\nflask==0.5", - "requests==1.0\n# another comment\nflask==0.5", + "requests==2.0\n# another comment\nflask==0.5", ], - ["flask==1.0", "requests==1.0\nflask==1.0"], + ["flask==1.0", "requests==2.0\nflask==1.0"], [req_file(), req_file()], [ ResolvedFixVersion( @@ -225,7 +225,7 @@ def test_requirement_source_fix_parse_failure(monkeypatch, req_file): # If `pip-api` encounters multiple of the same package in the requirements file, it will throw a # parsing error - input_reqs = ["flask==0.5", "flask==0.5\nrequests==1.0\nflask==0.3"] + input_reqs = ["flask==0.5", 
"flask==0.5\nrequests==2.0\nflask==0.3"] req_paths = [req_file(), req_file()] # Populate the requirements files @@ -255,7 +255,7 @@ def test_requirement_source_fix_rollback_failure(monkeypatch, req_file): # If `pip-api` encounters multiple of the same package in the requirements file, it will throw a # parsing error - input_reqs = ["flask==0.5", "flask==0.5\nrequests==1.0\nflask==0.3"] + input_reqs = ["flask==0.5", "flask==0.5\nrequests==2.0\nflask==0.3"] req_paths = [req_file(), req_file()] # Populate the requirements files @@ -282,7 +282,7 @@ def mock_replace(*_args, **_kwargs): # We couldn't move the original requirements files back so we should expect a partially applied # fix. The first requirements file contains the fix, while the second one doesn't since we were # in the process of writing it out and didn't flush. - expected_reqs = ["flask==1.0", "flask==0.5\nrequests==1.0\nflask==0.3"] + expected_reqs = ["flask==1.0", "flask==0.5\nrequests==2.0\nflask==0.3"] for (expected_req, req_path) in zip(expected_reqs, req_paths): with open(req_path, "r") as f: assert expected_req == f.read().strip() @@ -328,7 +328,7 @@ def test_requirement_source_require_hashes_inferred(monkeypatch): monkeypatch.setattr( _parse_requirements, "_read_file", - lambda _: ["flask==2.0.1 --hash=sha256:flask-hash\nrequests==1.0"], + lambda _: ["flask==2.0.1 --hash=sha256:flask-hash\nrequests==2.0"], ) # If at least one requirement is hashed, this infers `require-hashes` @@ -384,3 +384,170 @@ def test_requirement_source_no_deps_unpinned(monkeypatch): # When dependency resolution is disabled, all requirements must be pinned. 
with pytest.raises(DependencySourceError): list(source.collect()) + + +def test_requirement_source_dep_caching(monkeypatch): + source = requirement.RequirementSource( + [Path("requirements.txt")], ResolveLibResolver(), no_deps=True + ) + + monkeypatch.setattr( + _parse_requirements, + "_read_file", + lambda _: ["flask==2.0.1"], + ) + + specs = list(source.collect()) + + class MockResolver(DependencyResolver): + def resolve(self, req: Requirement) -> List[Dependency]: + raise DependencyResolverError + + # Now run collect again and check that dependency resolution doesn't get repeated + source._resolver = MockResolver() + + cached_specs = list(source.collect()) + assert specs == cached_specs + + +def test_requirement_source_fix_explicit_subdep(monkeypatch, req_file): + logger = pretend.stub(warning=pretend.call_recorder(lambda s: None)) + monkeypatch.setattr(requirement, "logger", logger) + + # We're going to simulate the situation where a subdependency of `flask` has a vulnerability. + # In this case, we're choosing `jinja2`. + flask_deps = ResolveLibResolver().resolve(Requirement("flask==2.0.1")) + + # Firstly, get a handle on the `jinja2` dependency. The version cannot be hardcoded since it + # depends what versions are available on PyPI when dependency resolution runs. + jinja_dep: Optional[ResolvedDependency] = None + for dep in flask_deps: + if isinstance(dep, ResolvedDependency) and dep.canonical_name == "jinja2": + jinja_dep = dep + break + assert jinja_dep is not None + + # Check that the `jinja2` dependency is explicitly added to the requirements file with an + # associated comment. + _check_fixes( + ["flask==2.0.1"], + ["flask==2.0.1\n # pip-audit: subdependency fixed via flask==2.0.1\njinja2==4.0.0"], + [req_file()], + [ + ResolvedFixVersion( + dep=jinja_dep, + version=Version("4.0.0"), + ) + ], + ) + + # When explicitly listing a fixed subdependency, we issue a warning. 
+ assert len(logger.warning.calls) == 1 + + +def test_requirement_source_fix_explicit_subdep_multiple_reqs(monkeypatch, req_file): + # Recreate the vulnerable subdependency case. + flask_deps = ResolveLibResolver().resolve(Requirement("flask==2.0.1")) + jinja_dep: Optional[ResolvedDependency] = None + for dep in flask_deps: + if isinstance(dep, ResolvedDependency) and dep.canonical_name == "jinja2": + jinja_dep = dep + break + assert jinja_dep is not None + + # This time our requirements file also lists `django-jinja`, another requirement that depends on + # `jinja2`. We're expecting that the comment generated above the `jinja2` requirement that gets + # added into the file will list both `flask` and `django-jinja` as sources. + _check_fixes( + ["flask==2.0.1\ndjango-jinja==1.0"], + [ + "flask==2.0.1\ndjango-jinja==1.0\n" + " # pip-audit: subdependency fixed via django-jinja==1.0,flask==2.0.1\n" + "jinja2==4.0.0" + ], + [req_file()], + [ + ResolvedFixVersion( + dep=jinja_dep, + version=Version("4.0.0"), + ) + ], + ) + + +def test_requirement_source_fix_explicit_subdep_resolver_error(req_file): + # Pass the requirement source a resolver that automatically raises errors + class MockResolver(DependencyResolver): + def resolve(self, req: Requirement) -> List[Dependency]: + raise DependencyResolverError + + req_file_name = req_file() + with open(req_file_name, "w") as f: + f.write("flask==2.0.1") + + # Recreate the vulnerable subdependency case. + flask_deps = ResolveLibResolver().resolve(Requirement("flask==2.0.1")) + jinja_dep: Optional[ResolvedDependency] = None + for dep in flask_deps: + if isinstance(dep, ResolvedDependency) and dep.canonical_name == "jinja2": + jinja_dep = dep + break + assert jinja_dep is not None + + # When we try to fix a vulnerable subdependency, we need to resolve dependencies if that + # information isn't already cached. + # + # Test the case where we hit a resolver error. 
+ source = requirement.RequirementSource([req_file_name], MockResolver()) + with pytest.raises(DependencyFixError): + source.fix( + ResolvedFixVersion( + dep=jinja_dep, + version=Version("4.0.0"), + ) + ) + + +def test_requirement_source_fix_explicit_subdep_comment_removal(req_file): + # This test is documenting a weakness in the current fix implementation. + # + # When fixing a subdependency and explicitly adding it to the requirements file, we add a + # comment above the line to explain its presence since it's unusual to explicitly pin a + # subdependency like this. + # + # When we "fix" dependencies, we use `pip-api` to parse the requirements file and write it back + # out with the relevant line amended or added. One downside of this method is that `pip-api` + # filters out comments so applying fixes removes all comments in the file. + # See: https://github.com/di/pip-api/issues/120 + # + # Therefore, when we apply a subdependency fix, the automated comment will be removed + # by any subsequent fixes. + + # Recreate the vulnerable subdependency case. + flask_deps = ResolveLibResolver().resolve(Requirement("flask==2.0.1")) + jinja_dep: Optional[ResolvedDependency] = None + for dep in flask_deps: + if isinstance(dep, ResolvedDependency) and dep.canonical_name == "jinja2": + jinja_dep = dep + break + assert jinja_dep is not None + + # Now place a fix for the top-level `flask` requirement after the `jinja2` subdependency fix. + # + # When applying the `flask` fix, `pip-audit` reparses the requirements file, stripping out the + # comment and writes it back out with the fixed `flask` version. + _check_fixes( + ["flask==2.0.1"], + ["flask==3.0.0\njinja2==4.0.0"], + [req_file()], + [ + ResolvedFixVersion( + dep=jinja_dep, + version=Version("4.0.0"), + ), + ResolvedFixVersion( + dep=ResolvedDependency(name="flask", version=Version("2.0.1")), + version=Version("3.0.0"), + ), + ], + )