Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pip, _fix: Implement --fix for PipSource #212

Merged
merged 17 commits into from
Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ python -m pip_audit --help
usage: pip-audit [-h] [-V] [-l] [-r REQUIREMENTS] [-f FORMAT] [-s SERVICE]
[-d] [-S] [--desc [{on,off,auto}]] [--cache-dir CACHE_DIR]
[--progress-spinner {on,off}] [--timeout TIMEOUT]
[--path PATHS] [-v]
[--path PATHS] [-v] [--fix]

audit the Python environment for dependencies with known vulnerabilities

Expand Down Expand Up @@ -111,6 +111,8 @@ optional arguments:
-v, --verbose give more output; this setting overrides the
`PIP_AUDIT_LOGLEVEL` variable and is equivalent to
setting it to `debug` (default: False)
--fix automatically upgrade dependencies with known
vulnerabilities (default: False)
```
<!-- @end-pip-audit-help@ -->

Expand Down Expand Up @@ -216,6 +218,16 @@ Found 2 known vulnerabilities in 1 packages
]
```

Audit and attempt to automatically upgrade vulnerable dependencies:
```
$ pip-audit --fix
Found 2 known vulnerabilities in 1 packages and fixed 2 vulnerabilities in 1 packages
Name Version ID Fix Versions
----- ------- -------------- ------------
Flask 0.5 PYSEC-2019-179 1.0
Flask 0.5 PYSEC-2018-66 0.12.3
```

## Security Model

This section exists to describe the security assumptions you **can** and **must not**
Expand Down
34 changes: 32 additions & 2 deletions pip_audit/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
RequirementSource,
ResolveLibResolver,
)
from pip_audit._dependency_source.interface import DependencySourceError
from pip_audit._fix import ResolvedFixVersion, SkippedFixVersion, resolve_fix_versions
from pip_audit._format import ColumnsFormat, CycloneDxFormat, JsonFormat, VulnerabilityFormat
from pip_audit._service import OsvService, PyPIService, VulnerabilityService
from pip_audit._service.interface import ResolvedDependency, SkippedDependency
Expand Down Expand Up @@ -234,6 +236,11 @@ def audit() -> None:
help="give more output; this setting overrides the `PIP_AUDIT_LOGLEVEL` variable and is "
"equivalent to setting it to `debug`",
)
parser.add_argument(
"--fix",
action="store_true",
help="automatically upgrade dependencies with known vulnerabilities",
)

args = parser.parse_args()
if args.verbose:
Expand Down Expand Up @@ -280,11 +287,34 @@ def audit() -> None:
pkg_count += 1
vuln_count += len(vulns)

# If the `--fix` flag has been applied, find a set of suitable fix versions and upgrade the
# dependencies at the source
fixes = list()
fixed_pkg_count = 0
fixed_vuln_count = 0
if args.fix:
for fix_version in resolve_fix_versions(service, result):
if not fix_version.is_skipped():
fix_version = cast(ResolvedFixVersion, fix_version)
try:
source.fix(fix_version)
fixed_pkg_count += 1
fixed_vuln_count += len(result[fix_version.dep])
except DependencySourceError as dse:
fix_version = SkippedFixVersion(fix_version.dep, str(dse))
fixes.append(fix_version)
woodruffw marked this conversation as resolved.
Show resolved Hide resolved

# TODO(ww): Refine this: we should always output if our output format is an SBOM
# or other manifest format (like the default JSON format).
if vuln_count > 0:
print(f"Found {vuln_count} known vulnerabilities in {pkg_count} packages", file=sys.stderr)
summary_msg = f"Found {vuln_count} known vulnerabilities in {pkg_count} packages"
if args.fix:
summary_msg += (
f" and fixed {fixed_vuln_count} vulnerabilities in {fixed_pkg_count} packages"
)
print(summary_msg, file=sys.stderr)
print(formatter.format(result))
sys.exit(1)
if pkg_count != fixed_pkg_count:
sys.exit(1)
else:
print("No known vulnerabilities found", file=sys.stderr)
2 changes: 2 additions & 0 deletions pip_audit/_dependency_source/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from .interface import (
DependencyFixError,
DependencyResolver,
DependencyResolverError,
DependencySource,
Expand All @@ -13,6 +14,7 @@
from .resolvelib import ResolveLibResolver

__all__ = [
"DependencyFixError",
"DependencyResolver",
"DependencyResolverError",
"DependencySource",
Expand Down
20 changes: 20 additions & 0 deletions pip_audit/_dependency_source/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from packaging.requirements import Requirement

from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import Dependency


Expand All @@ -26,6 +27,13 @@ def collect(self) -> Iterator[Dependency]: # pragma: no cover
"""
raise NotImplementedError

@abstractmethod
def fix(self, fix_version: ResolvedFixVersion) -> None: # pragma: no cover
"""
Upgrade a dependency to the given fix version.
"""
raise NotImplementedError


class DependencySourceError(Exception):
"""
Expand All @@ -38,6 +46,18 @@ class DependencySourceError(Exception):
pass


class DependencyFixError(Exception):
"""
Raised when a `DependencySource` fails to perform a "fix" operation, i.e.
fails to upgrade a package to a different version.

Concrete implementations are expected to subclass this exception to provide
more context.
"""

pass


class DependencyResolver(ABC):
"""
Represents an abstract resolver of Python dependencies that takes a single
Expand Down
32 changes: 31 additions & 1 deletion pip_audit/_dependency_source/pip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
"""

import logging
import subprocess
import sys
from pathlib import Path
from typing import Iterator, Sequence

import pip_api
from packaging.version import InvalidVersion, Version

from pip_audit._dependency_source import DependencySource, DependencySourceError
from pip_audit._dependency_source import DependencyFixError, DependencySource, DependencySourceError
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import Dependency, ResolvedDependency, SkippedDependency
from pip_audit._state import AuditState

Expand Down Expand Up @@ -87,8 +90,35 @@ def collect(self) -> Iterator[Dependency]:
except Exception as e:
raise PipSourceError("failed to list installed distributions") from e

def fix(self, fix_version: ResolvedFixVersion) -> None:
"""
Fixes a dependency version in this `PipSource`.
"""
fix_cmd = [
sys.executable,
"-m",
"pip",
"install",
f"{fix_version.dep.canonical_name}=={fix_version.version}",
]
try:
subprocess.run(
fix_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
except subprocess.CalledProcessError as cpe:
raise PipFixError(
f"failed to upgrade dependency {fix_version.dep.name} to fix version "
f"{fix_version.version}"
) from cpe


class PipSourceError(DependencySourceError):
"""A `pip` specific `DependencySourceError`."""

pass


class PipFixError(DependencyFixError):
"""A `pip` specific `DependencyFixError`."""

pass
7 changes: 7 additions & 0 deletions pip_audit/_dependency_source/requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DependencySource,
DependencySourceError,
)
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import Dependency
from pip_audit._service.interface import ResolvedDependency, SkippedDependency
from pip_audit._state import AuditState
Expand Down Expand Up @@ -78,6 +79,12 @@ def collect(self) -> Iterator[Dependency]:
except DependencyResolverError as dre:
raise RequirementSourceError("dependency resolver raised an error") from dre

def fix(self, fix_version: ResolvedFixVersion) -> None: # pragma: no cover
"""
Fixes a dependency version for this `RequirementSource`.
"""
raise NotImplementedError


class RequirementSourceError(DependencySourceError):
"""A requirements-parsing specific `DependencySourceError`."""
Expand Down
111 changes: 111 additions & 0 deletions pip_audit/_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""
Functionality for resolving fixed versions of dependencies.
"""

from dataclasses import dataclass
from typing import Dict, Iterator, List, cast

from packaging.version import Version

from pip_audit._service import (
Dependency,
ResolvedDependency,
VulnerabilityResult,
VulnerabilityService,
)


@dataclass(frozen=True)
woodruffw marked this conversation as resolved.
Show resolved Hide resolved
class FixVersion:
"""
Represents an abstract dependency fix version.

This class cannot be constructed directly.
"""

dep: ResolvedDependency

def __init__(self, *_args, **_kwargs) -> None: # pragma: no cover
"""
A stub constructor that always fails.
"""
raise NotImplementedError

def is_skipped(self) -> bool:
"""
Check whether the `FixVersion` was unable to be resolved.
"""
return self.__class__ is SkippedFixVersion


@dataclass(frozen=True)
class ResolvedFixVersion(FixVersion):
"""
Represents a resolved fix version.
"""

version: Version


@dataclass(frozen=True)
class SkippedFixVersion(FixVersion):
"""
Represents a fix version that was unable to be resolved and therefore, skipped.
"""

skip_reason: str


def resolve_fix_versions(
service: VulnerabilityService, result: Dict[Dependency, List[VulnerabilityResult]]
) -> Iterator[FixVersion]:
"""
Resolves a mapping of dependencies to known vulnerabilities to a series of fix versions without
known vulnerabilties.
"""
for (dep, vulns) in result.items():
if dep.is_skipped():
continue
if not vulns:
continue
dep = cast(ResolvedDependency, dep)
try:
version = _resolve_fix_version(service, dep, vulns)
yield ResolvedFixVersion(dep, version)
except FixResolutionImpossible as fri:
yield SkippedFixVersion(dep, str(fri))


def _resolve_fix_version(
service: VulnerabilityService, dep: ResolvedDependency, vulns: List[VulnerabilityResult]
) -> Version:
# We need to upgrade to a fix version that satisfies all vulnerability results
#
# However, whenever we upgrade a dependency, we run the risk of introducing new vulnerabilities
# so we need to run this in a loop and continue polling the vulnerability service on each
# prospective resolved fix version
current_version = dep.version
current_vulns = vulns
while current_vulns:

def get_earliest_fix_version(d: ResolvedDependency, v: VulnerabilityResult) -> Version:
for fix_version in v.fix_versions:
if fix_version > current_version:
return fix_version
raise FixResolutionImpossible(
f"failed to fix dependency {dep.name} ({dep.version}), unable to find fix version "
f"for vulnerability {v.id}"
)

# We want to retrieve a version that potentially fixes all vulnerabilities
current_version = max([get_earliest_fix_version(dep, v) for v in current_vulns])
_, current_vulns = service.query(ResolvedDependency(dep.name, current_version))
return current_version


class FixResolutionImpossible(Exception):
"""
Raised when `resolve_fix_versions` fails to find a fix version without known vulnerabilities
"""

pass
3 changes: 3 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class Source(DependencySource):
def collect(self):
yield spec("1.0.1")

def fix(self, _) -> None:
raise NotImplementedError

return Source


Expand Down
35 changes: 35 additions & 0 deletions test/dependency_source/test_pip.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import subprocess
import sys
from dataclasses import dataclass
from typing import Dict, List

Expand All @@ -8,6 +10,7 @@
from packaging.version import Version

from pip_audit._dependency_source import pip
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service.interface import ResolvedDependency, SkippedDependency


Expand Down Expand Up @@ -82,3 +85,35 @@ def mock_installed_distributions(
in specs
)
assert ResolvedDependency(name="pip-api", version=Version("1.0")) in specs


def test_pip_source_fix(monkeypatch):
source = pip.PipSource()

fix_version = ResolvedFixVersion(
dep=ResolvedDependency(name="pip-api", version=Version("1.0")), version=Version("1.5")
)

def run_mock(args, **kwargs):
assert " ".join(args) == f"{sys.executable} -m pip install pip-api==1.5"

monkeypatch.setattr(subprocess, "run", run_mock)

source.fix(fix_version)


def test_pip_source_fix_failure(monkeypatch):
source = pip.PipSource()

fix_version = ResolvedFixVersion(
dep=ResolvedDependency(name="pip-api", version=Version("1.0")), version=Version("1.5")
)

def run_mock(args, **kwargs):
assert " ".join(args) == f"{sys.executable} -m pip install pip-api==1.5"
raise subprocess.CalledProcessError(-1, str())

monkeypatch.setattr(subprocess, "run", run_mock)

with pytest.raises(pip.PipFixError):
source.fix(fix_version)
Loading