diff --git a/src/poetry/mixology/version_solver.py b/src/poetry/mixology/version_solver.py index 19de00afccc..81dde5d3369 100644 --- a/src/poetry/mixology/version_solver.py +++ b/src/poetry/mixology/version_solver.py @@ -39,15 +39,20 @@ class DependencyCache: """ def __init__(self, provider: Provider) -> None: - self.provider = provider - self.cache: dict[ - tuple[str, str | None, str | None, str | None, str | None], - list[DependencyPackage], - ] = {} + self._provider = provider + self._cache: dict[ + int, + dict[ + tuple[str, str | None, str | None, str | None, str | None], + list[DependencyPackage], + ], + ] = collections.defaultdict(dict) self.search_for = functools.lru_cache(maxsize=128)(self._search_for) - def _search_for(self, dependency: Dependency) -> list[DependencyPackage]: + def _search_for( + self, dependency: Dependency, level: int + ) -> list[DependencyPackage]: key = ( dependency.complete_name, dependency.source_type, @@ -56,12 +61,17 @@ def _search_for(self, dependency: Dependency) -> list[DependencyPackage]: dependency.source_subdirectory, ) - packages = self.cache.get(key) - - if packages: - packages = [ - p for p in packages if dependency.constraint.allows(p.package.version) - ] + for check_level in range(level, -1, -1): + packages = self._cache[check_level].get(key) + if packages is not None: + packages = [ + p + for p in packages + if dependency.constraint.allows(p.package.version) + ] + break + else: + packages = None # provider.search_for() normally does not include pre-release packages # (unless requested), but will include them if there are no other @@ -71,14 +81,14 @@ def _search_for(self, dependency: Dependency) -> list[DependencyPackage]: # nothing, we need to call provider.search_for() again as it may return # additional results this time. if not packages: - packages = self.provider.search_for(dependency) - - self.cache[key] = packages + packages = self._provider.search_for(dependency) + self._cache[level][key] = packages return packages - def clear(self) -> None: - self.cache.clear() + def clear_level(self, level: int) -> None: + self.search_for.cache_clear() + self._cache.pop(level, None) class VersionSolver: @@ -318,9 +328,9 @@ def _resolve_conflict(self, incompatibility: Incompatibility) -> Incompatibility self._solution.decision_level, previous_satisfier_level, -1 ): self._contradicted_incompatibilities.pop(level, None) + self._dependency_cache.clear_level(level) self._solution.backtrack(previous_satisfier_level) - self._dependency_cache.clear() if new_incompatibility: self._add_incompatibility(incompatibility) @@ -418,7 +428,11 @@ def _get_min(dependency: Dependency) -> tuple[bool, int, int]: if locked: return is_specific_marker, Preference.LOCKED, 1 - num_packages = len(self._dependency_cache.search_for(dependency)) + num_packages = len( + self._dependency_cache.search_for( + dependency, self._solution.decision_level + ) + ) if num_packages < 2: preference = Preference.NO_CHOICE @@ -435,7 +449,9 @@ def _get_min(dependency: Dependency) -> tuple[bool, int, int]: locked = self._provider.get_locked(dependency) if locked is None: - packages = self._dependency_cache.search_for(dependency) + packages = self._dependency_cache.search_for( + dependency, self._solution.decision_level + ) package = next(iter(packages), None) if package is None: diff --git a/tests/mixology/version_solver/test_dependency_cache.py b/tests/mixology/version_solver/test_dependency_cache.py index dffa7e535be..b3bd8721da4 100644 --- a/tests/mixology/version_solver/test_dependency_cache.py +++ b/tests/mixology/version_solver/test_dependency_cache.py @@ -2,6 +2,7 @@ from copy import deepcopy from typing import TYPE_CHECKING +from unittest import mock from poetry.factory import Factory from poetry.mixology.version_solver import DependencyCache @@ -32,14 +33,14 @@ def test_solver_dependency_cache_respects_source_type( cache.search_for.cache_clear() # ensure cache was never hit for both calls - cache.search_for(dependency_pypi) - cache.search_for(dependency_git) + cache.search_for(dependency_pypi, 0) + cache.search_for(dependency_git, 0) assert not cache.search_for.cache_info().hits # increase test coverage by searching for copies # (when searching for the exact same object, __eq__ is never called) - packages_pypi = cache.search_for(deepcopy(dependency_pypi)) - packages_git = cache.search_for(deepcopy(dependency_git)) + packages_pypi = cache.search_for(deepcopy(dependency_pypi), 0) + packages_git = cache.search_for(deepcopy(dependency_git), 0) assert cache.search_for.cache_info().hits == 2 assert cache.search_for.cache_info().currsize == 2 @@ -60,6 +61,44 @@ def test_solver_dependency_cache_respects_source_type( assert package_git.package.source_resolved_reference == MOCK_DEFAULT_GIT_REVISION +def test_solver_dependency_cache_pulls_from_prior_level_cache( + root: ProjectPackage, provider: Provider, repo: Repository +) -> None: + dependency_pypi = Factory.create_dependency("demo", ">=0.1.0") + root.add_dependency(dependency_pypi) + add_to_repo(repo, "demo", "1.0.0") + + wrapped_provider = mock.Mock(wraps=provider) + cache = DependencyCache(wrapped_provider) + cache.search_for.cache_clear() + + # On first call, provider.search_for() should be called and the level-0 + # cache populated. + cache.search_for(dependency_pypi, 0) + assert len(wrapped_provider.search_for.mock_calls) == 1 + assert ("demo", None, None, None, None) in cache._cache[0] + assert cache.search_for.cache_info().hits == 0 + assert cache.search_for.cache_info().misses == 1 + + # On second call at level 1, provider.search_for() should not be called + # again and the level-1 cache should be populated from the level-0 cache. + cache.search_for(dependency_pypi, 1) + assert len(wrapped_provider.search_for.mock_calls) == 1 + assert ("demo", None, None, None, None) in cache._cache[1] + assert cache._cache[0] == cache._cache[1] + assert cache.search_for.cache_info().hits == 0 + assert cache.search_for.cache_info().misses == 2 + + # Clearing the level 1 cache should invalidate the lru_cache on + # cache.search_for and wipe out the level 1 cache while preserving the + # level 0 cache. + cache.clear_level(1) + assert set(cache._cache.keys()) == {0} + assert ("demo", None, None, None, None) in cache._cache[0] + assert cache.search_for.cache_info().hits == 0 + assert cache.search_for.cache_info().misses == 0 + + def test_solver_dependency_cache_respects_subdirectories( root: ProjectPackage, provider: Provider, repo: Repository ) -> None: @@ -87,14 +126,14 @@ def test_solver_dependency_cache_respects_subdirectories( cache.search_for.cache_clear() # ensure cache was never hit for both calls - cache.search_for(dependency_one) - cache.search_for(dependency_one_copy) + cache.search_for(dependency_one, 0) + cache.search_for(dependency_one_copy, 0) assert not cache.search_for.cache_info().hits # increase test coverage by searching for copies # (when searching for the exact same object, __eq__ is never called) - packages_one = cache.search_for(deepcopy(dependency_one)) - packages_one_copy = cache.search_for(deepcopy(dependency_one_copy)) + packages_one = cache.search_for(deepcopy(dependency_one), 0) + packages_one_copy = cache.search_for(deepcopy(dependency_one_copy), 0) assert cache.search_for.cache_info().hits == 2 assert cache.search_for.cache_info().currsize == 2