
Improve dependency resolution performance by not clearing caches when backtracking #7950

Merged 4 commits on May 23, 2023
112 changes: 87 additions & 25 deletions src/poetry/mixology/version_solver.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
 
+import collections
 import functools
 import time
 
 from typing import TYPE_CHECKING
+from typing import Optional
+from typing import Tuple
 
 from poetry.core.packages.dependency import Dependency
 
@@ -28,6 +31,11 @@
 _conflict = object()
 
 
+DependencyCacheKey = Tuple[
+    str, Optional[str], Optional[str], Optional[str], Optional[str]
+]
+
+
 class DependencyCache:
     """
     A cache of the valid dependencies.
@@ -38,29 +46,40 @@ class DependencyCache:
     """
 
     def __init__(self, provider: Provider) -> None:
-        self.provider = provider
-        self.cache: dict[
-            tuple[str, str | None, str | None, str | None, str | None],
-            list[DependencyPackage],
-        ] = {}
-
-        self.search_for = functools.lru_cache(maxsize=128)(self._search_for)
+        self._provider = provider
+
+        # self._cache maps a package name to a stack of cached package lists,
+        # ordered by the decision level which added them to the cache. This is
+        # done so that when backtracking we can maintain cache entries from
+        # previous decision levels, while clearing cache entries from only the
+        # rolled back levels.
+        #
+        # In order to maintain the integrity of the cache, `clear_level()`
+        # needs to be called in descending order as decision levels are
+        # backtracked so that the correct items can be popped from the stack.
+        self._cache: dict[DependencyCacheKey, list[list[DependencyPackage]]] = (
+            collections.defaultdict(list)
+        )
+        self._cached_dependencies_by_level: dict[int, list[DependencyCacheKey]] = (
+            collections.defaultdict(list)
+        )
+
+        self._search_for_cached = functools.lru_cache(maxsize=128)(self._search_for)
 
-    def _search_for(self, dependency: Dependency) -> list[DependencyPackage]:
-        key = (
-            dependency.complete_name,
-            dependency.source_type,
-            dependency.source_url,
-            dependency.source_reference,
-            dependency.source_subdirectory,
-        )
-
-        packages = self.cache.get(key)
-
-        if packages:
+    def _search_for(
+        self,
+        dependency: Dependency,
+        key: DependencyCacheKey,
+    ) -> list[DependencyPackage]:
+        cache_entries = self._cache[key]
+        if cache_entries:
             packages = [
-                p for p in packages if dependency.constraint.allows(p.package.version)
+                p
+                for p in cache_entries[-1]
+                if dependency.constraint.allows(p.package.version)
             ]
+        else:
+            packages = None
 
         # provider.search_for() normally does not include pre-release packages
         # (unless requested), but will include them if there are no other
@@ -70,14 +89,35 @@ def _search_for(self, dependency: Dependency) -> list[DependencyPackage]:
         # nothing, we need to call provider.search_for() again as it may return
         # additional results this time.
         if not packages:
-            packages = self.provider.search_for(dependency)
+            packages = self._provider.search_for(dependency)
+
+        return packages
+
+    def search_for(
+        self,
+        dependency: Dependency,
+        decision_level: int,
+    ) -> list[DependencyPackage]:
+        key = (
+            dependency.complete_name,
+            dependency.source_type,
+            dependency.source_url,
+            dependency.source_reference,
+            dependency.source_subdirectory,
+        )
 
-        self.cache[key] = packages
+        packages = self._search_for_cached(dependency, key)
+        if not self._cache[key] or self._cache[key][-1] is not packages:
+            self._cache[key].append(packages)
+            self._cached_dependencies_by_level[decision_level].append(key)
 
         return packages
 
-    def clear(self) -> None:
-        self.cache.clear()
+    def clear_level(self, level: int) -> None:
+        if level in self._cached_dependencies_by_level:
+            self._search_for_cached.cache_clear()
+            for key in self._cached_dependencies_by_level.pop(level):
+                self._cache[key].pop()
 
 
 class VersionSolver:
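The comment block introduced above is the heart of this change: each cache key maps to a stack of package lists, one entry per decision level that wrote to it, so backtracking can pop just the rolled-back levels instead of discarding everything. A minimal self-contained sketch of that pattern (class and method names here are illustrative, not Poetry's actual API):

from __future__ import annotations

import collections


class LevelCache:
    """Toy per-decision-level cache: each key maps to a stack of values."""

    def __init__(self) -> None:
        self._cache: dict[str, list[list[str]]] = collections.defaultdict(list)
        self._keys_by_level: dict[int, list[str]] = collections.defaultdict(list)

    def put(self, key: str, value: list[str], level: int) -> None:
        # Push onto the key's stack and remember which level wrote the entry,
        # so backtracking can pop exactly what that level created.
        self._cache[key].append(value)
        self._keys_by_level[level].append(key)

    def get(self, key: str) -> list[str] | None:
        entries = self._cache[key]
        return entries[-1] if entries else None

    def clear_level(self, level: int) -> None:
        # Called in descending order while backtracking, so the top of each
        # stack is always the entry pushed by the level being cleared.
        for key in self._keys_by_level.pop(level, []):
            self._cache[key].pop()


cache = LevelCache()
cache.put("demo", ["demo 1.0", "demo 2.0"], level=0)
cache.put("demo", ["demo 1.0"], level=2)  # a deeper level narrows the list
assert cache.get("demo") == ["demo 1.0"]
cache.clear_level(2)  # backtracking undoes only level 2
assert cache.get("demo") == ["demo 1.0", "demo 2.0"]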
@@ -95,6 +135,9 @@ def __init__(self, root: ProjectPackage, provider: Provider) -> None:
         self._dependency_cache = DependencyCache(provider)
         self._incompatibilities: dict[str, list[Incompatibility]] = {}
         self._contradicted_incompatibilities: set[Incompatibility] = set()
+        self._contradicted_incompatibilities_by_level: dict[
+            int, set[Incompatibility]
+        ] = collections.defaultdict(set)
         self._solution = PartialSolution()
 
     @property
@@ -193,6 +236,9 @@ def _propagate_incompatibility(
                 # incompatibility is contradicted as well and there's nothing new we
                 # can deduce from it.
                 self._contradicted_incompatibilities.add(incompatibility)
+                self._contradicted_incompatibilities_by_level[
+                    self._solution.decision_level
+                ].add(incompatibility)
                 return None
             elif relation == SetRelation.OVERLAPPING:
                 # If more than one term is inconclusive, we can't deduce anything about
@@ -211,6 +257,9 @@
             return _conflict
 
         self._contradicted_incompatibilities.add(incompatibility)
+        self._contradicted_incompatibilities_by_level[
+            self._solution.decision_level
+        ].add(incompatibility)
 
         adverb = "not " if unsatisfied.is_positive() else ""
         self._log(f"derived: {adverb}{unsatisfied.dependency}")
@@ -304,9 +353,16 @@ def _resolve_conflict(self, incompatibility: Incompatibility) -> Incompatibility:
                 previous_satisfier_level < most_recent_satisfier.decision_level
                 or most_recent_satisfier.cause is None
             ):
+                for level in range(
+                    self._solution.decision_level, previous_satisfier_level, -1
+                ):
+                    if level in self._contradicted_incompatibilities_by_level:
+                        self._contradicted_incompatibilities.difference_update(
+                            self._contradicted_incompatibilities_by_level.pop(level),
+                        )
+                    self._dependency_cache.clear_level(level)
+
                 self._solution.backtrack(previous_satisfier_level)
-                self._contradicted_incompatibilities.clear()
-                self._dependency_cache.clear()
                 if new_incompatibility:
                     self._add_incompatibility(incompatibility)
 
@@ -404,7 +460,11 @@ def _get_min(dependency: Dependency) -> tuple[bool, int, int]:
             if locked:
                 return is_specific_marker, Preference.LOCKED, 1
 
-            num_packages = len(self._dependency_cache.search_for(dependency))
+            num_packages = len(
+                self._dependency_cache.search_for(
+                    dependency, self._solution.decision_level
+                )
+            )
 
             if num_packages < 2:
                 preference = Preference.NO_CHOICE
@@ -421,7 +481,9 @@ def _get_min(dependency: Dependency) -> tuple[bool, int, int]:
 
         locked = self._provider.get_locked(dependency)
         if locked is None:
-            packages = self._dependency_cache.search_for(dependency)
+            packages = self._dependency_cache.search_for(
+                dependency, self._solution.decision_level
+            )
             package = next(iter(packages), None)
 
         if package is None:
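One subtlety in clear_level() above: it also calls cache_clear() on the functools.lru_cache wrapper. Memoized _search_for() results are derived from whatever sat on top of a stack when they were computed, so popping entries without clearing the memo could serve stale lists. A sketch of the failure mode, using hypothetical names rather than Poetry's:

import functools

# Hypothetical per-level stack: level 0 pushed the full list, a deeper
# level pushed a narrowed one.
_stack: dict[str, list[list[str]]] = {
    "demo": [["demo 1.0", "demo 2.0"], ["demo 1.0"]],
}


@functools.lru_cache(maxsize=128)
def search(key: str) -> tuple[str, ...]:
    # Memoized lookup derived from the current top of the stack.
    return tuple(_stack[key][-1])


assert search("demo") == ("demo 1.0",)  # the memo now caches the deeper answer

_stack["demo"].pop()  # backtracking pops the deeper entry...
assert search("demo") == ("demo 1.0",)  # ...but the memo still serves it: stale

search.cache_clear()  # clear_level() clears the memo for exactly this reason
assert search("demo") == ("demo 1.0", "demo 2.0")  # fresh level-0 answer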
92 changes: 76 additions & 16 deletions tests/mixology/version_solver/test_dependency_cache.py
@@ -2,6 +2,7 @@
 
 from copy import deepcopy
 from typing import TYPE_CHECKING
+from unittest import mock
 
 from poetry.factory import Factory
 from poetry.mixology.version_solver import DependencyCache
@@ -29,20 +30,20 @@ def test_solver_dependency_cache_respects_source_type(
     add_to_repo(repo, "demo", "1.0.0")
 
     cache = DependencyCache(provider)
-    cache.search_for.cache_clear()
+    cache._search_for_cached.cache_clear()
 
     # ensure cache was never hit for both calls
-    cache.search_for(dependency_pypi)
-    cache.search_for(dependency_git)
-    assert not cache.search_for.cache_info().hits
+    cache.search_for(dependency_pypi, 0)
+    cache.search_for(dependency_git, 0)
+    assert not cache._search_for_cached.cache_info().hits
 
     # increase test coverage by searching for copies
     # (when searching for the exact same object, __eq__ is never called)
-    packages_pypi = cache.search_for(deepcopy(dependency_pypi))
-    packages_git = cache.search_for(deepcopy(dependency_git))
+    packages_pypi = cache.search_for(deepcopy(dependency_pypi), 0)
+    packages_git = cache.search_for(deepcopy(dependency_git), 0)
 
-    assert cache.search_for.cache_info().hits == 2
-    assert cache.search_for.cache_info().currsize == 2
+    assert cache._search_for_cached.cache_info().hits == 2
+    assert cache._search_for_cached.cache_info().currsize == 2
 
     assert len(packages_pypi) == len(packages_git) == 1
     assert packages_pypi != packages_git
@@ -60,6 +61,65 @@ def test_solver_dependency_cache_respects_source_type(
     assert package_git.package.source_resolved_reference == MOCK_DEFAULT_GIT_REVISION
 
 
+def test_solver_dependency_cache_pulls_from_prior_level_cache(
+    root: ProjectPackage, provider: Provider, repo: Repository
+) -> None:
+    dependency_pypi = Factory.create_dependency("demo", ">=0.1.0")
+    dependency_pypi_constrained = Factory.create_dependency("demo", ">=0.1.0,<2.0.0")
+    root.add_dependency(dependency_pypi)
+    root.add_dependency(dependency_pypi_constrained)
+    add_to_repo(repo, "demo", "1.0.0")
+
+    wrapped_provider = mock.Mock(wraps=provider)
+    cache = DependencyCache(wrapped_provider)
+    cache._search_for_cached.cache_clear()
+
+    # On the first call, provider.search_for() should be called and the cache
+    # populated.
+    cache.search_for(dependency_pypi, 0)
+    assert len(wrapped_provider.search_for.mock_calls) == 1
+    assert ("demo", None, None, None, None) in cache._cache
+    assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[0]
+    assert cache._search_for_cached.cache_info().hits == 0
+    assert cache._search_for_cached.cache_info().misses == 1
+
+    # On the second call, at level 1, the memoized lookup hits, so neither
+    # provider.search_for() nor the underlying _search_for() runs again, and
+    # the cache should remain the same.
+    cache.search_for(dependency_pypi, 1)
+    assert len(wrapped_provider.search_for.mock_calls) == 1
+    assert ("demo", None, None, None, None) in cache._cache
+    assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[0]
+    assert set(cache._cached_dependencies_by_level.keys()) == {0}
+    assert cache._search_for_cached.cache_info().hits == 1
+    assert cache._search_for_cached.cache_info().misses == 1
+
+    # The third call, at level 2 with an updated constraint for the `demo`
+    # package, should not call provider.search_for(), but should call
+    # cache._search_for_cached() and update the cache.
+    cache.search_for(dependency_pypi_constrained, 2)
+    assert len(wrapped_provider.search_for.mock_calls) == 1
+    assert ("demo", None, None, None, None) in cache._cache
+    assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[0]
+    assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[2]
+    assert set(cache._cached_dependencies_by_level.keys()) == {0, 2}
+    assert cache._search_for_cached.cache_info().hits == 1
+    assert cache._search_for_cached.cache_info().misses == 2
+
+    # Clearing the level 2 and level 1 caches should invalidate the lru_cache
+    # on cache._search_for_cached and wipe out the level 2 cache while
+    # preserving the level 0 cache.
+    cache.clear_level(2)
+    cache.clear_level(1)
+    cache.search_for(dependency_pypi, 0)
+    assert len(wrapped_provider.search_for.mock_calls) == 1
+    assert ("demo", None, None, None, None) in cache._cache
+    assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[0]
+    assert set(cache._cached_dependencies_by_level.keys()) == {0}
+    assert cache._search_for_cached.cache_info().hits == 0
+    assert cache._search_for_cached.cache_info().misses == 1
+
+
 def test_solver_dependency_cache_respects_subdirectories(
     root: ProjectPackage, provider: Provider, repo: Repository
 ) -> None:
@@ -84,20 +144,20 @@ def test_solver_dependency_cache_respects_subdirectories(
     root.add_dependency(dependency_one_copy)
 
     cache = DependencyCache(provider)
-    cache.search_for.cache_clear()
+    cache._search_for_cached.cache_clear()
 
     # ensure cache was never hit for both calls
-    cache.search_for(dependency_one)
-    cache.search_for(dependency_one_copy)
-    assert not cache.search_for.cache_info().hits
+    cache.search_for(dependency_one, 0)
+    cache.search_for(dependency_one_copy, 0)
+    assert not cache._search_for_cached.cache_info().hits
 
     # increase test coverage by searching for copies
     # (when searching for the exact same object, __eq__ is never called)
-    packages_one = cache.search_for(deepcopy(dependency_one))
-    packages_one_copy = cache.search_for(deepcopy(dependency_one_copy))
+    packages_one = cache.search_for(deepcopy(dependency_one), 0)
+    packages_one_copy = cache.search_for(deepcopy(dependency_one_copy), 0)
 
-    assert cache.search_for.cache_info().hits == 2
-    assert cache.search_for.cache_info().currsize == 2
+    assert cache._search_for_cached.cache_info().hits == 2
+    assert cache._search_for_cached.cache_info().currsize == 2
 
     assert len(packages_one) == len(packages_one_copy) == 1
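The hit and miss assertions in these tests rely on functools.lru_cache introspection: cache_info() reports hits, misses, and currsize, and cache_clear() resets them along with the stored results. A quick standalone illustration:

import functools


@functools.lru_cache(maxsize=128)
def square(n: int) -> int:
    return n * n


square(2)  # miss: computed and stored
square(2)  # hit: served from the cache
square(3)  # miss: computed and stored

info = square.cache_info()
assert (info.hits, info.misses, info.currsize) == (1, 2, 2)

square.cache_clear()
assert square.cache_info().hits == square.cache_info().misses == 0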