diff --git a/conda_lock/conda_lock.py b/conda_lock/conda_lock.py index 14af2ac7..71400490 100644 --- a/conda_lock/conda_lock.py +++ b/conda_lock/conda_lock.py @@ -245,7 +245,7 @@ def make_lock_files( update: Optional[List[str]] = None, include_dev_dependencies: bool = True, filename_template: Optional[str] = None, - filter_categories: bool = True, + filter_categories: bool = False, extras: Optional[AbstractSet[str]] = None, check_input_hash: bool = False, metadata_choices: AbstractSet[MetadataOption] = frozenset(), @@ -656,12 +656,8 @@ def _solve_for_arch( """ if update_spec is None: update_spec = UpdateSpecification() - # filter requested and locked dependencies to the current platform - dependencies = [ - dep - for dep in spec.dependencies - if (not dep.selectors.platform) or platform in dep.selectors.platform - ] + + dependencies = spec.dependencies[platform] locked = [dep for dep in update_spec.locked if dep.platform == platform] requested_deps_by_name = { manager: {dep.name: dep for dep in dependencies if dep.manager == manager} diff --git a/conda_lock/conda_solver.py b/conda_lock/conda_solver.py index c5359e21..906e55d8 100644 --- a/conda_lock/conda_solver.py +++ b/conda_lock/conda_solver.py @@ -9,17 +9,7 @@ import time from contextlib import contextmanager -from typing import ( - Any, - Dict, - Iterable, - Iterator, - List, - MutableSequence, - Optional, - Sequence, - cast, -) +from typing import Dict, Iterable, Iterator, List, MutableSequence, Optional, Sequence from urllib.parse import urlsplit, urlunsplit import yaml diff --git a/conda_lock/models/lock_spec.py b/conda_lock/models/lock_spec.py index 345391ab..a6ce7b88 100644 --- a/conda_lock/models/lock_spec.py +++ b/conda_lock/models/lock_spec.py @@ -13,29 +13,12 @@ from conda_lock.virtual_package import FakeRepoData -class Selectors(StrictModel): - platform: Optional[List[str]] = None - - def __ior__(self, other: "Selectors") -> "Selectors": - if not isinstance(other, Selectors): - raise TypeError - if other.platform and self.platform: - for p in other.platform: - if p not in self.platform: - self.platform.append(p) - return self - - def for_platform(self, platform: str) -> bool: - return self.platform is None or platform in self.platform - - class _BaseDependency(StrictModel): name: str manager: Literal["conda", "pip"] = "conda" optional: bool = False category: str = "main" extras: List[str] = [] - selectors: Selectors = Selectors() class VersionedDependency(_BaseDependency): @@ -58,14 +41,17 @@ class Package(StrictModel): class LockSpecification(BaseModel): - dependencies: List[Dependency] + dependencies: Dict[str, List[Dependency]] # TODO: Should we store the auth info in here? channels: List[Channel] - platforms: List[str] sources: List[pathlib.Path] virtual_package_repo: Optional[FakeRepoData] = None allow_pypi_requests: bool = True + @property + def platforms(self) -> List[str]: + return list(self.dependencies.keys()) + def content_hash(self) -> Dict[str, str]: return { platform: self.content_hash_for_platform(platform) @@ -77,8 +63,9 @@ def content_hash_for_platform(self, platform: str) -> str: "channels": [c.json() for c in self.channels], "specs": [ p.dict() - for p in sorted(self.dependencies, key=lambda p: (p.manager, p.name)) - if p.selectors.for_platform(platform) + for p in sorted( + self.dependencies[platform], key=lambda p: (p.manager, p.name) + ) ], } if self.virtual_package_repo is not None: diff --git a/conda_lock/src_parser/__init__.py b/conda_lock/src_parser/__init__.py index d3cd0543..3c2f9169 100644 --- a/conda_lock/src_parser/__init__.py +++ b/conda_lock/src_parser/__init__.py @@ -12,7 +12,10 @@ parse_platforms_from_env_file, ) from conda_lock.src_parser.meta_yaml import parse_meta_yaml_file -from conda_lock.src_parser.pyproject_toml import parse_pyproject_toml +from conda_lock.src_parser.pyproject_toml import ( + parse_platforms_from_pyproject_toml, + parse_pyproject_toml, +) from conda_lock.virtual_package import FakeRepoData @@ -36,7 +39,7 @@ def _parse_platforms_from_srcs(src_files: List[pathlib.Path]) -> List[str]: if src_file.name == "meta.yaml": continue elif src_file.name == "pyproject.toml": - all_file_platforms.append(parse_pyproject_toml(src_file).platforms) + all_file_platforms.append(parse_platforms_from_pyproject_toml(src_file)) else: all_file_platforms.append(parse_platforms_from_env_file(src_file)) @@ -62,7 +65,7 @@ def _parse_source_files( if src_file.name == "meta.yaml": desired_envs.append(parse_meta_yaml_file(src_file, platforms)) elif src_file.name == "pyproject.toml": - desired_envs.append(parse_pyproject_toml(src_file)) + desired_envs.append(parse_pyproject_toml(src_file, platforms)) else: desired_envs.append(parse_environment_file(src_file, platforms)) return desired_envs @@ -85,24 +88,37 @@ def make_lock_spec( lock_specs = _parse_source_files(src_files, platforms) - lock_spec = aggregate_lock_specs(lock_specs) - lock_spec.virtual_package_repo = virtual_package_repo - lock_spec.platforms = platforms - lock_spec.channels = ( + aggregated_lock_spec = aggregate_lock_specs(lock_specs, platforms) + + # Use channel overrides if given, otherwise use the channels specified in the + # source files. + channels = ( [Channel.from_string(co) for co in channel_overrides] if channel_overrides - else lock_spec.channels + else aggregated_lock_spec.channels ) - if required_categories is not None: - + if required_categories is None: + dependencies = aggregated_lock_spec.dependencies + else: + # Filtering based on category (e.g. "main" or "dev") was requested. + # Thus we need to filter the specs based on the category. def dep_has_category(d: Dependency, categories: AbstractSet[str]) -> bool: return d.category in categories - lock_spec.dependencies = [ - d - for d in lock_spec.dependencies - if dep_has_category(d, categories=required_categories) - ] - - return lock_spec + dependencies = { + platform: [ + d + for d in dependencies + if dep_has_category(d, categories=required_categories) + ] + for platform, dependencies in aggregated_lock_spec.dependencies.items() + } + + return LockSpecification( + dependencies=dependencies, + channels=channels, + sources=aggregated_lock_spec.sources, + virtual_package_repo=virtual_package_repo, + allow_pypi_requests=aggregated_lock_spec.allow_pypi_requests, + ) diff --git a/conda_lock/src_parser/aggregation.py b/conda_lock/src_parser/aggregation.py index 2ea1604c..65ab4842 100644 --- a/conda_lock/src_parser/aggregation.py +++ b/conda_lock/src_parser/aggregation.py @@ -13,23 +13,31 @@ def aggregate_lock_specs( lock_specs: List[LockSpecification], + platforms: List[str], ) -> LockSpecification: - # unique dependencies - unique_deps: Dict[Tuple[str, str], Dependency] = {} - for dep in chain.from_iterable( - [lock_spec.dependencies for lock_spec in lock_specs] - ): - key = (dep.manager, dep.name) - if key in unique_deps: - # Override existing, but merge selectors - previous_selectors = unique_deps[key].selectors - previous_selectors |= dep.selectors - dep.selectors = previous_selectors - unique_deps[key] = dep - - dependencies = list(unique_deps.values()) + for lock_spec in lock_specs: + if set(lock_spec.platforms) != set(platforms): + raise ValueError( + f"Lock specifications must have the same platforms in order to be " + f"aggregated. Expected platforms are {set(platforms)}, but the lock " + f"specification from {[str(s) for s in lock_spec.sources]} has " + f"platforms {set(lock_spec.platforms)}." + ) + + dependencies: Dict[str, List[Dependency]] = {} + for platform in platforms: + # unique dependencies + unique_deps: Dict[Tuple[str, str], Dependency] = {} + for dep in chain.from_iterable( + lock_spec.dependencies.get(platform, []) for lock_spec in lock_specs + ): + key = (dep.manager, dep.name) + unique_deps[key] = dep + + dependencies[platform] = list(unique_deps.values()) + try: - channels = suffix_union(lock_spec.channels or [] for lock_spec in lock_specs) + channels = suffix_union(lock_spec.channels for lock_spec in lock_specs) except ValueError as e: raise ChannelAggregationError(*e.args) @@ -38,8 +46,7 @@ def aggregate_lock_specs( # Ensure channel are correctly ordered channels=channels, # uniquify metadata, preserving order - platforms=ordered_union(lock_spec.platforms or [] for lock_spec in lock_specs), - sources=ordered_union(lock_spec.sources or [] for lock_spec in lock_specs), + sources=ordered_union(lock_spec.sources for lock_spec in lock_specs), allow_pypi_requests=all( lock_spec.allow_pypi_requests for lock_spec in lock_specs ), diff --git a/conda_lock/src_parser/environment_yaml.py b/conda_lock/src_parser/environment_yaml.py index 5a853d6e..e1e221e7 100644 --- a/conda_lock/src_parser/environment_yaml.py +++ b/conda_lock/src_parser/environment_yaml.py @@ -7,7 +7,6 @@ import yaml from conda_lock.models.lock_spec import Dependency, LockSpecification -from conda_lock.src_parser.aggregation import aggregate_lock_specs from conda_lock.src_parser.conda_common import conda_spec_to_versioned_dep from conda_lock.src_parser.selectors import filter_platform_selectors @@ -27,10 +26,10 @@ def parse_conda_requirement(req: str) -> Tuple[str, str]: def _parse_environment_file_for_platform( - environment_file: pathlib.Path, content: str, + category: str, platform: str, -) -> LockSpecification: +) -> List[Dependency]: """ Parse dependencies from a conda environment specification for an assumed target platform. @@ -44,13 +43,7 @@ def _parse_environment_file_for_platform( """ filtered_content = "\n".join(filter_platform_selectors(content, platform=platform)) env_yaml_data = yaml.safe_load(filtered_content) - specs = env_yaml_data["dependencies"] - channels: List[str] = env_yaml_data.get("channels", []) - - # These extension fields are nonstandard - platforms: List[str] = env_yaml_data.get("platforms", []) - category: str = env_yaml_data.get("category") or "main" # Split out any sub spec sections from the dependencies mapping mapping_specs = [x for x in specs if not isinstance(x, str)] @@ -58,9 +51,7 @@ def _parse_environment_file_for_platform( dependencies: List[Dependency] = [] for spec in specs: - vdep = conda_spec_to_versioned_dep(spec, category) - vdep.selectors.platform = [platform] - dependencies.append(vdep) + dependencies.append(conda_spec_to_versioned_dep(spec, category)) for mapping_spec in mapping_specs: if "pip" in mapping_spec: @@ -88,12 +79,7 @@ def _parse_environment_file_for_platform( # ensure pip is in target env dependencies.append(parse_python_requirement("pip", manager="conda")) - return LockSpecification( - dependencies=dependencies, - channels=channels, # type: ignore - platforms=platforms, - sources=[environment_file], - ) + return dependencies def parse_platforms_from_env_file(environment_file: pathlib.Path) -> List[str]: @@ -127,26 +113,20 @@ def parse_environment_file( with environment_file.open("r") as fo: content = fo.read() + env_yaml_data = yaml.safe_load(content) + channels: List[str] = env_yaml_data.get("channels", []) + + # These extension fields are nonstandard + category: str = env_yaml_data.get("category") or "main" + # Parse with selectors for each target platform - spec = aggregate_lock_specs( - [ - _parse_environment_file_for_platform( - environment_file, - content, - platform, - ) - for platform in platforms - ] - ) + dep_map = { + platform: _parse_environment_file_for_platform(content, category, platform) + for platform in platforms + } - # Remove platform selectors if they apply to all targets - for dep in spec.dependencies: - if dep.selectors.platform == platforms: - dep.selectors.platform = None - - # Use the list of rendered platforms for the output spec only if - # there is a dependency that is not used on all platforms. - # This is unlike meta.yaml because environment-yaml files can contain an - # internal list of platforms, which should be used as long as it - spec.platforms = platforms - return spec + return LockSpecification( + dependencies=dep_map, + channels=channels, # type: ignore + sources=[environment_file], + ) diff --git a/conda_lock/src_parser/meta_yaml.py b/conda_lock/src_parser/meta_yaml.py index cc05e9a3..c85b4a9f 100644 --- a/conda_lock/src_parser/meta_yaml.py +++ b/conda_lock/src_parser/meta_yaml.py @@ -7,7 +7,7 @@ from conda_lock.common import get_in from conda_lock.models.lock_spec import Dependency, LockSpecification -from conda_lock.src_parser.aggregation import aggregate_lock_specs +from conda_lock.src_parser.conda_common import conda_spec_to_versioned_dep from conda_lock.src_parser.selectors import filter_platform_selectors @@ -94,31 +94,37 @@ def parse_meta_yaml_file( selectors other than platform. """ + if not meta_yaml_file.exists(): + raise FileNotFoundError(f"{meta_yaml_file} not found") + + with meta_yaml_file.open("r") as fo: + t = jinja2.Template(fo.read(), undefined=UndefinedNeverFail) + rendered = t.render() + meta_yaml_data = yaml.safe_load(rendered) + + channels = get_in(["extra", "channels"], meta_yaml_data, []) + # parse with selectors for each target platform - spec = aggregate_lock_specs( - [ - _parse_meta_yaml_file_for_platform(meta_yaml_file, platform) - for platform in platforms - ] - ) - # remove platform selectors if they apply to all targets - for dep in spec.dependencies: - if dep.selectors.platform == platforms: - dep.selectors.platform = None + dep_map = { + platform: _parse_meta_yaml_file_for_platform(meta_yaml_file, platform) + for platform in platforms + } - return spec + return LockSpecification( + dependencies=dep_map, + channels=channels, + sources=[meta_yaml_file], + ) def _parse_meta_yaml_file_for_platform( meta_yaml_file: pathlib.Path, platform: str, -) -> LockSpecification: +) -> List[Dependency]: """Parse a simple meta-yaml file for dependencies, assuming the target platform. * This does not support multi-output files and will ignore all lines with selectors other than platform """ - if not meta_yaml_file.exists(): - raise FileNotFoundError(f"{meta_yaml_file} not found") with meta_yaml_file.open("r") as fo: filtered_recipe = "\n".join( @@ -129,17 +135,13 @@ def _parse_meta_yaml_file_for_platform( meta_yaml_data = yaml.safe_load(rendered) - channels = get_in(["extra", "channels"], meta_yaml_data, []) dependencies: List[Dependency] = [] def add_spec(spec: str, category: str) -> None: if spec is None: return - from .conda_common import conda_spec_to_versioned_dep - dep = conda_spec_to_versioned_dep(spec, category) - dep.selectors.platform = [platform] dependencies.append(dep) def add_requirements_from_recipe_or_output(yaml_data: Dict[str, Any]) -> None: @@ -154,9 +156,4 @@ def add_requirements_from_recipe_or_output(yaml_data: Dict[str, Any]) -> None: for output in get_in(["outputs"], meta_yaml_data, []): add_requirements_from_recipe_or_output(output) - return LockSpecification( - dependencies=dependencies, - channels=channels, - platforms=[platform], - sources=[meta_yaml_file], - ) + return dependencies diff --git a/conda_lock/src_parser/pyproject_toml.py b/conda_lock/src_parser/pyproject_toml.py index 52d67a9f..21d6f3d7 100644 --- a/conda_lock/src_parser/pyproject_toml.py +++ b/conda_lock/src_parser/pyproject_toml.py @@ -85,6 +85,7 @@ def poetry_version_to_conda_version(version_string: Optional[str]) -> Optional[s def parse_poetry_pyproject_toml( path: pathlib.Path, + platforms: List[str], contents: Mapping[str, Any], ) -> LockSpecification: """ @@ -183,11 +184,14 @@ def parse_poetry_pyproject_toml( ) ) - return specification_with_dependencies(path, contents, dependencies) + return specification_with_dependencies(path, platforms, contents, dependencies) def specification_with_dependencies( - path: pathlib.Path, toml_contents: Mapping[str, Any], dependencies: List[Dependency] + path: pathlib.Path, + platforms: List[str], + toml_contents: Mapping[str, Any], + dependencies: List[Dependency], ) -> LockSpecification: force_pypi = set() for depname, depattrs in get_in( @@ -217,9 +221,8 @@ def specification_with_dependencies( dep.manager = "pip" return LockSpecification( - dependencies=dependencies, + dependencies={platform: dependencies for platform in platforms}, channels=get_in(["tool", "conda-lock", "channels"], toml_contents, []), - platforms=get_in(["tool", "conda-lock", "platforms"], toml_contents, []), sources=[path], allow_pypi_requests=get_in( ["tool", "conda-lock", "allow-pypi-requests"], toml_contents, True @@ -284,6 +287,7 @@ def parse_python_requirement( def parse_requirements_pyproject_toml( pyproject_toml_path: pathlib.Path, + platforms: List[str], contents: Mapping[str, Any], prefix: Sequence[str], main_tag: str, @@ -311,11 +315,14 @@ def parse_requirements_pyproject_toml( ) ) - return specification_with_dependencies(pyproject_toml_path, contents, dependencies) + return specification_with_dependencies( + pyproject_toml_path, platforms, contents, dependencies + ) def parse_pdm_pyproject_toml( path: pathlib.Path, + platforms: List[str], contents: Mapping[str, Any], ) -> LockSpecification: """ @@ -324,6 +331,7 @@ def parse_pdm_pyproject_toml( """ res = parse_requirements_pyproject_toml( path, + platforms, contents, prefix=("project",), main_tag="dependencies", @@ -342,13 +350,23 @@ def parse_pdm_pyproject_toml( ] ) - res.dependencies.extend(dev_reqs) + for dep_list in res.dependencies.values(): + dep_list.extend(dev_reqs) return res +def parse_platforms_from_pyproject_toml( + pyproject_toml: pathlib.Path, +) -> List[str]: + with pyproject_toml.open("rb") as fp: + contents = toml_load(fp) + return get_in(["tool", "conda-lock", "platforms"], contents, []) + + def parse_pyproject_toml( pyproject_toml: pathlib.Path, + platforms: List[str], ) -> LockSpecification: with pyproject_toml.open("rb") as fp: contents = toml_load(fp) @@ -397,4 +415,4 @@ def parse_pyproject_toml( "Could not detect build-system in pyproject.toml. Assuming poetry" ) - return parse(pyproject_toml, contents) + return parse(pyproject_toml, platforms, contents) diff --git a/tests/test_conda_lock.py b/tests/test_conda_lock.py index d595ced8..7b59a2b4 100644 --- a/tests/test_conda_lock.py +++ b/tests/test_conda_lock.py @@ -62,7 +62,7 @@ parse_conda_lock_file, ) from conda_lock.models.channel import Channel -from conda_lock.models.lock_spec import Selectors, VersionedDependency +from conda_lock.models.lock_spec import VersionedDependency from conda_lock.pypi_solver import parse_pip_requirement, solve_pypi from conda_lock.src_parser import ( DEFAULT_PLATFORMS, @@ -76,6 +76,7 @@ parse_platforms_from_env_file, ) from conda_lock.src_parser.pyproject_toml import ( + parse_platforms_from_pyproject_toml, parse_pyproject_toml, poetry_version_to_conda_version, ) @@ -337,7 +338,7 @@ def test_lock_poetry_ibis( def test_parse_environment_file(gdal_environment: Path): res = parse_environment_file(gdal_environment, DEFAULT_PLATFORMS) assert all( - x in res.dependencies + x in res.dependencies[plat] for x in [ VersionedDependency( name="python", @@ -350,14 +351,16 @@ def test_parse_environment_file(gdal_environment: Path): version="", ), ] + for plat in DEFAULT_PLATFORMS ) - assert ( + assert all( VersionedDependency( name="toolz", manager="pip", version="*", ) - in res.dependencies + in res.dependencies[plat] + for plat in DEFAULT_PLATFORMS ) assert all( Channel.from_string(x) in res.channels for x in ["conda-forge", "defaults"] @@ -366,16 +369,17 @@ def test_parse_environment_file(gdal_environment: Path): def test_parse_environment_file_with_pip(pip_environment: Path): res = parse_environment_file(pip_environment, DEFAULT_PLATFORMS) - assert [dep for dep in res.dependencies if dep.manager == "pip"] == [ - VersionedDependency( - name="requests-toolbelt", - manager="pip", - optional=False, - category="main", - extras=[], - version="=0.9.1", - ) - ] + for plat in DEFAULT_PLATFORMS: + assert [dep for dep in res.dependencies[plat] if dep.manager == "pip"] == [ + VersionedDependency( + name="requests-toolbelt", + manager="pip", + optional=False, + category="main", + extras=[], + version="=0.9.1", + ) + ] def test_parse_env_file_with_filters_no_args(filter_conda_environment: Path): @@ -385,32 +389,42 @@ def test_parse_env_file_with_filters_no_args(filter_conda_environment: Path): assert res.channels == [Channel.from_string("conda-forge")] assert all( - x in res.dependencies - for x in [ - VersionedDependency( - name="python", - manager="conda", - version="<3.11", + x in res.dependencies[plat] + for x, platforms in [ + ( + VersionedDependency( + name="python", + manager="conda", + version="<3.11", + ), + platforms, ), - VersionedDependency( - name="clang_osx-arm64", - manager="conda", - version="", - selectors=Selectors(platform=["osx-arm64"]), + ( + VersionedDependency( + name="clang_osx-arm64", + manager="conda", + version="", + ), + ["osx-arm64"], ), - VersionedDependency( - name="clang_osx-64", - manager="conda", - version="", - selectors=Selectors(platform=["osx-64"]), + ( + VersionedDependency( + name="clang_osx-64", + manager="conda", + version="", + ), + ["osx-64"], ), - VersionedDependency( - name="gcc_linux-64", - manager="conda", - version=">=6", - selectors=Selectors(platform=["linux-64"]), + ( + VersionedDependency( + name="gcc_linux-64", + manager="conda", + version=">=6", + ), + ["linux-64"], ), ] + for plat in platforms ) @@ -420,26 +434,34 @@ def test_parse_env_file_with_filters_defaults(filter_conda_environment: Path): assert res.channels == [Channel.from_string("conda-forge")] assert all( - x in res.dependencies - for x in [ - VersionedDependency( - name="python", - manager="conda", - version="<3.11", + x in res.dependencies[plat] + for x, platforms in [ + ( + VersionedDependency( + name="python", + manager="conda", + version="<3.11", + ), + DEFAULT_PLATFORMS, ), - VersionedDependency( - name="clang_osx-64", - manager="conda", - version="", - selectors=Selectors(platform=["osx-64"]), + ( + VersionedDependency( + name="clang_osx-64", + manager="conda", + version="", + ), + ["osx-64"], ), - VersionedDependency( - name="gcc_linux-64", - manager="conda", - version=">=6", - selectors=Selectors(platform=["linux-64"]), + ( + VersionedDependency( + name="gcc_linux-64", + manager="conda", + version=">=6", + ), + ["linux-64"], ), ] + for plat in platforms ) @@ -561,31 +583,25 @@ def test_parse_pip_requirement( def test_parse_meta_yaml_file(meta_yaml_environment: Path): - res = parse_meta_yaml_file(meta_yaml_environment, ["linux-64", "osx-64"]) - specs = {dep.name: dep for dep in res.dependencies} - assert all(x in specs for x in ["python", "numpy"]) - assert all( - dep.selectors - == Selectors( - platform=None - ) # Platform will be set to None if all dependencies are the same - for dep in specs.values() - ) - # Ensure that this dep specified by a python selector is ignored - assert "enum34" not in specs - # Ensure that this platform specific dep is included - assert "zlib" in specs - assert specs["pytest"].category == "dev" - assert specs["pytest"].optional is True + platforms = ["linux-64", "osx-64"] + res = parse_meta_yaml_file(meta_yaml_environment, platforms) + for plat in platforms: + specs = {dep.name: dep for dep in res.dependencies[plat]} + assert all(x in specs for x in ["python", "numpy"]) + # Ensure that this dep specified by a python selector is ignored + assert "enum34" not in specs + # Ensure that this platform specific dep is included + assert "zlib" in specs + assert specs["pytest"].category == "dev" + assert specs["pytest"].optional is True def test_parse_poetry(poetry_pyproject_toml: Path): - res = parse_pyproject_toml( - poetry_pyproject_toml, - ) + res = parse_pyproject_toml(poetry_pyproject_toml, ["linux-64"]) specs = { - dep.name: typing.cast(VersionedDependency, dep) for dep in res.dependencies + dep.name: typing.cast(VersionedDependency, dep) + for dep in res.dependencies["linux-64"] } assert specs["requests"].version == ">=2.13.0,<3.0.0" @@ -603,9 +619,8 @@ def test_parse_poetry(poetry_pyproject_toml: Path): def test_parse_poetry_no_pypi(poetry_pyproject_toml_no_pypi: Path): - res = parse_pyproject_toml( - poetry_pyproject_toml_no_pypi, - ) + platforms = parse_platforms_from_pyproject_toml(poetry_pyproject_toml_no_pypi) + res = parse_pyproject_toml(poetry_pyproject_toml_no_pypi, platforms) assert res.allow_pypi_requests is False @@ -652,39 +667,41 @@ def test_spec_poetry(poetry_pyproject_toml: Path): spec = make_lock_spec( src_files=[poetry_pyproject_toml], virtual_package_repo=virtual_package_repo ) - deps = {d.name for d in spec.dependencies} - assert "tomlkit" in deps - assert "pytest" in deps - assert "requests" in deps + for plat in spec.platforms: + deps = {d.name for d in spec.dependencies[plat]} + assert "tomlkit" in deps + assert "pytest" in deps + assert "requests" in deps spec = make_lock_spec( src_files=[poetry_pyproject_toml], virtual_package_repo=virtual_package_repo, required_categories={"main", "dev"}, ) - deps = {d.name for d in spec.dependencies} - assert "tomlkit" not in deps - assert "pytest" in deps - assert "requests" in deps + for plat in spec.platforms: + deps = {d.name for d in spec.dependencies[plat]} + assert "tomlkit" not in deps + assert "pytest" in deps + assert "requests" in deps spec = make_lock_spec( src_files=[poetry_pyproject_toml], virtual_package_repo=virtual_package_repo, required_categories={"main"}, ) - deps = {d.name for d in spec.dependencies} - assert "tomlkit" not in deps - assert "pytest" not in deps - assert "requests" in deps + for plat in spec.platforms: + deps = {d.name for d in spec.dependencies[plat]} + assert "tomlkit" not in deps + assert "pytest" not in deps + assert "requests" in deps def test_parse_flit(flit_pyproject_toml: Path): - res = parse_pyproject_toml( - flit_pyproject_toml, - ) + res = parse_pyproject_toml(flit_pyproject_toml, ["linux-64"]) specs = { - dep.name: typing.cast(VersionedDependency, dep) for dep in res.dependencies + dep.name: typing.cast(VersionedDependency, dep) + for dep in res.dependencies["linux-64"] } assert specs["requests"].version == ">=2.13.0" @@ -701,12 +718,11 @@ def test_parse_flit(flit_pyproject_toml: Path): def test_parse_pdm(pdm_pyproject_toml: Path): - res = parse_pyproject_toml( - pdm_pyproject_toml, - ) + res = parse_pyproject_toml(pdm_pyproject_toml, ["linux-64"]) specs = { - dep.name: typing.cast(VersionedDependency, dep) for dep in res.dependencies + dep.name: typing.cast(VersionedDependency, dep) + for dep in res.dependencies["linux-64"] } # Base dependencies @@ -1078,7 +1094,9 @@ def test_run_lock_with_local_package( virtual_package_repo=virtual_package_repo, ) assert not any( - p.manager == "pip" for p in lock_spec.dependencies + p.manager == "pip" + for platform in lock_spec.platforms + for p in lock_spec.dependencies[platform] ), "conda-lock ignores editable pip deps" @@ -1139,18 +1157,19 @@ def test_poetry_version_parsing_constraints( with vpr, capsys.disabled(): with tempfile.NamedTemporaryFile(dir=".") as tf: spec = LockSpecification( - dependencies=[ - VersionedDependency( - name=package, - version=poetry_version_to_conda_version(version) or "", - manager="conda", - optional=False, - category="main", - extras=[], - ) - ], + dependencies={ + "linux-64": [ + VersionedDependency( + name=package, + version=poetry_version_to_conda_version(version) or "", + manager="conda", + optional=False, + category="main", + extras=[], + ), + ], + }, channels=[Channel.from_string("conda-forge")], - platforms=["linux-64"], # NB: this file must exist for relative path resolution to work # in create_lockfile_from_spec sources=[Path(tf.name)], @@ -1195,76 +1214,33 @@ def _make_spec(name: str, constraint: str = "*"): ) -def _make_dependency_with_platforms( - name: str, platforms: typing.List[str], constraint: str = "*" -): - return VersionedDependency( - name=name, - version=constraint, - selectors=Selectors(platform=platforms), - ) - - def test_aggregate_lock_specs(): """Ensure that the way two specs combine when both specify channels is correct""" base_spec = LockSpecification( - dependencies=[_make_spec("python", "=3.7")], + dependencies={"linux-64": [_make_spec("python", "=3.7")]}, channels=[Channel.from_string("conda-forge")], - platforms=["linux-64"], sources=[Path("base-env.yml")], ) gpu_spec = LockSpecification( - dependencies=[_make_spec("pytorch")], + dependencies={"linux-64": [_make_spec("pytorch")]}, channels=[Channel.from_string("pytorch"), Channel.from_string("conda-forge")], - platforms=["linux-64"], sources=[Path("ml-stuff.yml")], ) # NB: content hash explicitly does not depend on the source file names - actual = aggregate_lock_specs([base_spec, gpu_spec]) + actual = aggregate_lock_specs([base_spec, gpu_spec], platforms=["linux-64"]) expected = LockSpecification( - dependencies=[ - _make_spec("python", "=3.7"), - _make_spec("pytorch"), - ], + dependencies={ + "linux-64": [ + _make_spec("python", "=3.7"), + _make_spec("pytorch"), + ] + }, channels=[ Channel.from_string("pytorch"), Channel.from_string("conda-forge"), ], - platforms=["linux-64"], - sources=[], - ) - assert actual.dict(exclude={"sources"}) == expected.dict(exclude={"sources"}) - assert actual.content_hash() == expected.content_hash() - - -def test_aggregate_lock_specs_multiple_platforms(): - """Ensure that plaforms are merged correctly""" - linux_spec = LockSpecification( - dependencies=[_make_dependency_with_platforms("python", ["linux-64"], "=3.7")], - channels=[Channel.from_string("conda-forge")], - platforms=["linux-64"], - sources=[Path("base-env.yml")], - ) - - osx_spec = LockSpecification( - dependencies=[_make_dependency_with_platforms("python", ["osx-64"], "=3.7")], - channels=[Channel.from_string("conda-forge")], - platforms=["osx-64"], - sources=[Path("base-env.yml")], - ) - - # NB: content hash explicitly does not depend on the source file names - actual = aggregate_lock_specs([linux_spec, osx_spec]) - expected = LockSpecification( - dependencies=[ - _make_dependency_with_platforms("python", ["linux-64", "osx-64"], "=3.7") - ], - channels=[ - Channel.from_string("conda-forge"), - ], - platforms=["linux-64", "osx-64"], sources=[], ) assert actual.dict(exclude={"sources"}) == expected.dict(exclude={"sources"}) @@ -1273,20 +1249,18 @@ def test_aggregate_lock_specs_multiple_platforms(): def test_aggregate_lock_specs_override_version(): base_spec = LockSpecification( - dependencies=[_make_spec("package", "=1.0")], + dependencies={"linux-64": [_make_spec("package", "=1.0")]}, channels=[Channel.from_string("conda-forge")], - platforms=["linux-64"], sources=[Path("base.yml")], ) override_spec = LockSpecification( - dependencies=[_make_spec("package", "=2.0")], + dependencies={"linux-64": [_make_spec("package", "=2.0")]}, channels=[Channel.from_string("internal"), Channel.from_string("conda-forge")], - platforms=["linux-64"], sources=[Path("override.yml")], ) - agg_spec = aggregate_lock_specs([base_spec, override_spec]) + agg_spec = aggregate_lock_specs([base_spec, override_spec], platforms=["linux-64"]) assert agg_spec.dependencies == override_spec.dependencies @@ -1294,9 +1268,8 @@ def test_aggregate_lock_specs_override_version(): def test_aggregate_lock_specs_invalid_channels(): """Ensure that aggregating specs from mismatched channel orderings raises an error.""" base_spec = LockSpecification( - dependencies=[], + dependencies={}, channels=[Channel.from_string("defaults")], - platforms=[], sources=[], ) @@ -1308,7 +1281,7 @@ def test_aggregate_lock_specs_invalid_channels(): ] } ) - agg_spec = aggregate_lock_specs([base_spec, add_conda_forge]) + agg_spec = aggregate_lock_specs([base_spec, add_conda_forge], platforms=[]) assert agg_spec.channels == add_conda_forge.channels # swap the order of the two channels, which is an error @@ -1322,7 +1295,9 @@ def test_aggregate_lock_specs_invalid_channels(): ) with pytest.raises(ChannelAggregationError): - agg_spec = aggregate_lock_specs([base_spec, add_conda_forge, flipped]) + agg_spec = aggregate_lock_specs( + [base_spec, add_conda_forge, flipped], platforms=[] + ) add_pytorch = base_spec.copy( update={ @@ -1333,7 +1308,9 @@ def test_aggregate_lock_specs_invalid_channels(): } ) with pytest.raises(ChannelAggregationError): - agg_spec = aggregate_lock_specs([base_spec, add_conda_forge, add_pytorch]) + agg_spec = aggregate_lock_specs( + [base_spec, add_conda_forge, add_pytorch], platforms=[] + ) @pytest.fixture(scope="session") @@ -1700,9 +1677,8 @@ def test_virtual_package_input_hash_stability(): vpr = virtual_package_repo_from_specification(vspec) spec = LockSpecification( - dependencies=[], + dependencies={"linux-64": []}, channels=[], - platforms=["linux-64"], sources=[], virtual_package_repo=vpr, ) @@ -1725,9 +1701,8 @@ def test_default_virtual_package_input_hash_stability(): } spec = LockSpecification( - dependencies=[], + dependencies={platform: [] for platform in expected.keys()}, channels=[], - platforms=list(expected.keys()), sources=[], virtual_package_repo=vpr, )