Skip to content

Commit

Permalink
perf: improve DependencyCache lru_cache hit rate, avoid iterating thr…
Browse files Browse the repository at this point in the history
…ough lots of levels (python-poetry#7950)
  • Loading branch information
chriskuehl authored and radoering committed May 23, 2023
1 parent 0adc1c5 commit 9217975
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 60 deletions.
94 changes: 62 additions & 32 deletions src/poetry/mixology/version_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import time

from typing import TYPE_CHECKING
from typing import Optional
from typing import Tuple

from poetry.core.packages.dependency import Dependency

Expand All @@ -29,6 +31,11 @@
_conflict = object()


DependencyCacheKey = Tuple[
str, Optional[str], Optional[str], Optional[str], Optional[str]
]


class DependencyCache:
"""
A cache of the valid dependencies.
Expand All @@ -40,36 +47,37 @@ class DependencyCache:

def __init__(self, provider: Provider) -> None:
self._provider = provider
self._cache: dict[
int,
dict[
tuple[str, str | None, str | None, str | None, str | None],
list[DependencyPackage],
],
] = collections.defaultdict(dict)

self.search_for = functools.lru_cache(maxsize=128)(self._search_for)
# self._cache maps a package name to a stack of cached package lists,
# ordered by the decision level which added them to the cache. This is
# done so that when backtracking we can maintain cache entries from
# previous decision levels, while clearing cache entries from only the
# rolled back levels.
#
# In order to maintain the integrity of the cache, `clear_level()`
# needs to be called in descending order as decision levels are
# backtracked so that the correct items can be popped from the stack.
self._cache: dict[DependencyCacheKey, list[list[DependencyPackage]]] = (
collections.defaultdict(list)
)
self._cached_dependencies_by_level: dict[int, list[DependencyCacheKey]] = (
collections.defaultdict(list)
)

self._search_for_cached = functools.lru_cache(maxsize=128)(self._search_for)

def _search_for(
self, dependency: Dependency, level: int
self,
dependency: Dependency,
key: DependencyCacheKey,
) -> list[DependencyPackage]:
key = (
dependency.complete_name,
dependency.source_type,
dependency.source_url,
dependency.source_reference,
dependency.source_subdirectory,
)

for check_level in range(level, -1, -1):
packages = self._cache[check_level].get(key)
if packages is not None:
packages = [
p
for p in packages
if dependency.constraint.allows(p.package.version)
]
break
cache_entries = self._cache[key]
if cache_entries:
packages = [
p
for p in cache_entries[-1]
if dependency.constraint.allows(p.package.version)
]
else:
packages = None

Expand All @@ -83,12 +91,33 @@ def _search_for(
if not packages:
packages = self._provider.search_for(dependency)

self._cache[level][key] = packages
return packages

def search_for(
self,
dependency: Dependency,
decision_level: int,
) -> list[DependencyPackage]:
key = (
dependency.complete_name,
dependency.source_type,
dependency.source_url,
dependency.source_reference,
dependency.source_subdirectory,
)

packages = self._search_for_cached(dependency, key)
if not self._cache[key] or self._cache[key][-1] is not packages:
self._cache[key].append(packages)
self._cached_dependencies_by_level[decision_level].append(key)

return packages

def clear_level(self, level: int) -> None:
self.search_for.cache_clear()
self._cache.pop(level, None)
if level in self._cached_dependencies_by_level:
self._search_for_cached.cache_clear()
for key in self._cached_dependencies_by_level.pop(level):
self._cache[key].pop()


class VersionSolver:
Expand Down Expand Up @@ -327,9 +356,10 @@ def _resolve_conflict(self, incompatibility: Incompatibility) -> Incompatibility
for level in range(
self._solution.decision_level, previous_satisfier_level, -1
):
self._contradicted_incompatibilities.difference_update(
self._contradicted_incompatibilities_by_level.pop(level, set()),
)
if level in self._contradicted_incompatibilities_by_level:
self._contradicted_incompatibilities.difference_update(
self._contradicted_incompatibilities_by_level.pop(level),
)
self._dependency_cache.clear_level(level)

self._solution.backtrack(previous_satisfier_level)
Expand Down
77 changes: 49 additions & 28 deletions tests/mixology/version_solver/test_dependency_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,20 @@ def test_solver_dependency_cache_respects_source_type(
add_to_repo(repo, "demo", "1.0.0")

cache = DependencyCache(provider)
cache.search_for.cache_clear()
cache._search_for_cached.cache_clear()

# ensure cache was never hit for both calls
cache.search_for(dependency_pypi, 0)
cache.search_for(dependency_git, 0)
assert not cache.search_for.cache_info().hits
assert not cache._search_for_cached.cache_info().hits

# increase test coverage by searching for copies
# (when searching for the exact same object, __eq__ is never called)
packages_pypi = cache.search_for(deepcopy(dependency_pypi), 0)
packages_git = cache.search_for(deepcopy(dependency_git), 0)

assert cache.search_for.cache_info().hits == 2
assert cache.search_for.cache_info().currsize == 2
assert cache._search_for_cached.cache_info().hits == 2
assert cache._search_for_cached.cache_info().currsize == 2

assert len(packages_pypi) == len(packages_git) == 1
assert packages_pypi != packages_git
Expand All @@ -65,38 +65,59 @@ def test_solver_dependency_cache_pulls_from_prior_level_cache(
root: ProjectPackage, provider: Provider, repo: Repository
) -> None:
dependency_pypi = Factory.create_dependency("demo", ">=0.1.0")
dependency_pypi_constrained = Factory.create_dependency("demo", ">=0.1.0,<2.0.0")
root.add_dependency(dependency_pypi)
root.add_dependency(dependency_pypi_constrained)
add_to_repo(repo, "demo", "1.0.0")

wrapped_provider = mock.Mock(wraps=provider)
cache = DependencyCache(wrapped_provider)
cache.search_for.cache_clear()
cache._search_for_cached.cache_clear()

# On first call, provider.search_for() should be called and the level-0
# cache populated.
# On first call, provider.search_for() should be called and the cache
# populated.
cache.search_for(dependency_pypi, 0)
assert len(wrapped_provider.search_for.mock_calls) == 1
assert ("demo", None, None, None, None) in cache._cache[0]
assert cache.search_for.cache_info().hits == 0
assert cache.search_for.cache_info().misses == 1

# On second call at level 1, provider.search_for() should not be called
# again and the level-1 cache should be populated from the level-0 cache.
assert ("demo", None, None, None, None) in cache._cache
assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[0]
assert cache._search_for_cached.cache_info().hits == 0
assert cache._search_for_cached.cache_info().misses == 1

# On second call at level 1, neither provider.search_for() nor
# cache._search_for_cached() should have been called again, and the cache
# should remain the same.
cache.search_for(dependency_pypi, 1)
assert len(wrapped_provider.search_for.mock_calls) == 1
assert ("demo", None, None, None, None) in cache._cache[1]
assert cache._cache[0] == cache._cache[1]
assert cache.search_for.cache_info().hits == 0
assert cache.search_for.cache_info().misses == 2

# Clearing the level 1 cache should invalidate the lru_cache on
# cache.search_for and wipe out the level 1 cache while preserving the
assert ("demo", None, None, None, None) in cache._cache
assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[0]
assert set(cache._cached_dependencies_by_level.keys()) == {0}
assert cache._search_for_cached.cache_info().hits == 1
assert cache._search_for_cached.cache_info().misses == 1

# On third call at level 2 with an updated constraint for the `demo`
# package should not call provider.search_for(), but should call
# cache._search_for_cached() and update the cache.
cache.search_for(dependency_pypi_constrained, 2)
assert len(wrapped_provider.search_for.mock_calls) == 1
assert ("demo", None, None, None, None) in cache._cache
assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[0]
assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[2]
assert set(cache._cached_dependencies_by_level.keys()) == {0, 2}
assert cache._search_for_cached.cache_info().hits == 1
assert cache._search_for_cached.cache_info().misses == 2

# Clearing the level 2 and level 1 caches should invalidate the lru_cache
# on cache.search_for and wipe out the level 2 cache while preserving the
# level 0 cache.
cache.clear_level(2)
cache.clear_level(1)
assert set(cache._cache.keys()) == {0}
assert ("demo", None, None, None, None) in cache._cache[0]
assert cache.search_for.cache_info().hits == 0
assert cache.search_for.cache_info().misses == 0
cache.search_for(dependency_pypi, 0)
assert len(wrapped_provider.search_for.mock_calls) == 1
assert ("demo", None, None, None, None) in cache._cache
assert ("demo", None, None, None, None) in cache._cached_dependencies_by_level[0]
assert set(cache._cached_dependencies_by_level.keys()) == {0}
assert cache._search_for_cached.cache_info().hits == 0
assert cache._search_for_cached.cache_info().misses == 1


def test_solver_dependency_cache_respects_subdirectories(
Expand All @@ -123,20 +144,20 @@ def test_solver_dependency_cache_respects_subdirectories(
root.add_dependency(dependency_one_copy)

cache = DependencyCache(provider)
cache.search_for.cache_clear()
cache._search_for_cached.cache_clear()

# ensure cache was never hit for both calls
cache.search_for(dependency_one, 0)
cache.search_for(dependency_one_copy, 0)
assert not cache.search_for.cache_info().hits
assert not cache._search_for_cached.cache_info().hits

# increase test coverage by searching for copies
# (when searching for the exact same object, __eq__ is never called)
packages_one = cache.search_for(deepcopy(dependency_one), 0)
packages_one_copy = cache.search_for(deepcopy(dependency_one_copy), 0)

assert cache.search_for.cache_info().hits == 2
assert cache.search_for.cache_info().currsize == 2
assert cache._search_for_cached.cache_info().hits == 2
assert cache._search_for_cached.cache_info().currsize == 2

assert len(packages_one) == len(packages_one_copy) == 1

Expand Down

0 comments on commit 9217975

Please sign in to comment.