Skip to content

Commit

Permalink
Add Support for Multiple Categories per LockedDependency
Browse files Browse the repository at this point in the history
  • Loading branch information
srilman committed Dec 23, 2023
1 parent 3553ef7 commit 77807e1
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 15 deletions.
58 changes: 43 additions & 15 deletions conda_lock/lockfile/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import functools
import pathlib

from collections import defaultdict
from textwrap import dedent
from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set, Union
from typing import (
Collection,
DefaultDict,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Union,
)

import yaml

Expand Down Expand Up @@ -39,6 +48,23 @@ def _seperator_munge_get(
return d[key.replace("_", "-")]


def _truncate_main_category(
planned: Mapping[str, Union[List[LockedDependency], LockedDependency]],
) -> None:
"""
Given the package dependencies with their respective categories
for any package that is in the main category, remove all other associated categories
"""
# Packages in the main category are always installed
# so other categories are not necessary
for targets in planned.values():
if not isinstance(targets, list):
targets = [targets]
for target in targets:
if "main" in target.categories:
target.categories = {"main"}


def apply_categories(
requested: Dict[str, Dependency],
planned: Mapping[str, Union[List[LockedDependency], LockedDependency]],
Expand Down Expand Up @@ -112,27 +138,31 @@ def dep_name(manager: str, dep: str) -> str:

by_category[request.category].append(request.name)

# now, map each package to its root request preferring the ones earlier in the
# list
# now, map each package to every root request that requires it
categories = [*categories, *(k for k in by_category if k not in categories)]
root_requests = {}
root_requests: DefaultDict[str, List[str]] = defaultdict(list)
for category in categories:
for root in by_category.get(category, []):
for transitive_dep in dependents[root]:
if transitive_dep not in root_requests:
root_requests[transitive_dep] = root
root_requests[transitive_dep].append(root)
# include root requests themselves
for name in requested:
root_requests[name] = name
root_requests[name].append(name)

for dep, root in root_requests.items():
source = requested[root]
for dep, roots in root_requests.items():
# try a conda target first
targets = _seperator_munge_get(planned, dep)
if not isinstance(targets, list):
targets = [targets]
for target in targets:
target.categories = {source.category}

for root in roots:
source = requested[root]
for target in targets:
target.categories.add(source.category)

# For any dep that is part of the 'main' category
# we should remove all other categories
_truncate_main_category(planned)


def parse_conda_lock_file(path: pathlib.Path) -> Lockfile:
Expand Down Expand Up @@ -164,9 +194,7 @@ def write_conda_lock_file(
content.filter_virtual_packages_inplace()
with path.open("w") as f:
if include_help_text:
categories: Set[str] = functools.reduce(
set.union, (set(p.categories) for p in content.package), set()
)
categories: Set[str] = set().union(*(p.categories for p in content.package))

def write_section(text: str) -> None:
lines = dedent(text).split("\n")
Expand Down
50 changes: 50 additions & 0 deletions tests/test_conda_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_add_auth_to_line,
_add_auth_to_lockfile,
_extract_domain,
_solve_for_arch,
_strip_auth_from_line,
_strip_auth_from_lockfile,
create_lockfile_from_spec,
Expand Down Expand Up @@ -597,6 +598,7 @@ def test_choose_wheel() -> None:
platform="linux-64",
)
assert len(solution) == 1
assert solution["fastavro"].categories == {"main"}
assert solution["fastavro"].hash == HashModel(
sha256="a111a384a786b7f1fd6a8a8307da07ccf4d4c425084e2d61bae33ecfb60de405"
)
Expand Down Expand Up @@ -1782,6 +1784,54 @@ def test_aggregate_lock_specs_invalid_pip_repos():
aggregate_lock_specs([base_spec, spec_a, spec_a_b], platforms=[])


def test_solve_arch_multiple_categories():
_conda_exe = determine_conda_executable(None, mamba=False, micromamba=False)
vpr = default_virtual_package_repodata()
channels = [Channel.from_string("conda-forge")]

with vpr, tempfile.NamedTemporaryFile(dir=".") as tf:
spec = LockSpecification(
dependencies={
"linux-64": [
VersionedDependency(
name="python",
version="=3.10.9",
manager="conda",
category="main",
extras=[],
),
VersionedDependency(
name="pandas",
version="=1.5.3",
manager="conda",
category="test",
extras=[],
),
VersionedDependency(
name="pyarrow",
version="=9.0.0",
manager="conda",
category="dev",
extras=[],
),
],
},
channels=channels,
# NB: this file must exist for relative path resolution to work
# in create_lockfile_from_spec
sources=[Path(tf.name)],
virtual_package_repo=vpr,
)

locked_deps = _solve_for_arch(_conda_exe, spec, "linux-64", channels, [])
python_deps = [dep for dep in locked_deps if dep.name == "python"]
assert len(python_deps) == 1
assert python_deps[0].categories == {"main"}
numpy_deps = [dep for dep in locked_deps if dep.name == "numpy"]
assert len(numpy_deps) == 1
assert numpy_deps[0].categories == {"test", "dev"}


def _check_package_installed(package: str, prefix: str):
import glob

Expand Down

0 comments on commit 77807e1

Please sign in to comment.