Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor LockSpecification as a Dictionary from Platforms to List of Deps #383

Merged
merged 10 commits into from
Mar 11, 2023
10 changes: 3 additions & 7 deletions conda_lock/conda_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def make_lock_files(
update: Optional[List[str]] = None,
include_dev_dependencies: bool = True,
filename_template: Optional[str] = None,
filter_categories: bool = True,
filter_categories: bool = False,
extras: Optional[AbstractSet[str]] = None,
check_input_hash: bool = False,
metadata_choices: AbstractSet[MetadataOption] = frozenset(),
Expand Down Expand Up @@ -656,12 +656,8 @@ def _solve_for_arch(
"""
if update_spec is None:
update_spec = UpdateSpecification()
# filter requested and locked dependencies to the current platform
dependencies = [
dep
for dep in spec.dependencies
if (not dep.selectors.platform) or platform in dep.selectors.platform
]

dependencies = spec.dependencies[platform]
locked = [dep for dep in update_spec.locked if dep.platform == platform]
requested_deps_by_name = {
manager: {dep.name: dep for dep in dependencies if dep.manager == manager}
Expand Down
12 changes: 1 addition & 11 deletions conda_lock/conda_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,7 @@
import time

from contextlib import contextmanager
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
MutableSequence,
Optional,
Sequence,
cast,
)
from typing import Dict, Iterable, Iterator, List, MutableSequence, Optional, Sequence
from urllib.parse import urlsplit, urlunsplit

import yaml
Expand Down
29 changes: 8 additions & 21 deletions conda_lock/models/lock_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,12 @@
from conda_lock.virtual_package import FakeRepoData


class Selectors(StrictModel):
platform: Optional[List[str]] = None

def __ior__(self, other: "Selectors") -> "Selectors":
if not isinstance(other, Selectors):
raise TypeError
if other.platform and self.platform:
for p in other.platform:
if p not in self.platform:
self.platform.append(p)
return self

def for_platform(self, platform: str) -> bool:
return self.platform is None or platform in self.platform


class _BaseDependency(StrictModel):
name: str
manager: Literal["conda", "pip"] = "conda"
optional: bool = False
category: str = "main"
extras: List[str] = []
selectors: Selectors = Selectors()


class VersionedDependency(_BaseDependency):
Expand All @@ -58,14 +41,17 @@ class Package(StrictModel):


class LockSpecification(BaseModel):
dependencies: List[Dependency]
dependencies: Dict[str, List[Dependency]]
# TODO: Should we store the auth info in here?
channels: List[Channel]
platforms: List[str]
sources: List[pathlib.Path]
virtual_package_repo: Optional[FakeRepoData] = None
allow_pypi_requests: bool = True

@property
def platforms(self) -> List[str]:
return list(self.dependencies.keys())

def content_hash(self) -> Dict[str, str]:
return {
platform: self.content_hash_for_platform(platform)
Expand All @@ -77,8 +63,9 @@ def content_hash_for_platform(self, platform: str) -> str:
"channels": [c.json() for c in self.channels],
"specs": [
p.dict()
for p in sorted(self.dependencies, key=lambda p: (p.manager, p.name))
if p.selectors.for_platform(platform)
for p in sorted(
self.dependencies[platform], key=lambda p: (p.manager, p.name)
)
],
}
if self.virtual_package_repo is not None:
Expand Down
50 changes: 33 additions & 17 deletions conda_lock/src_parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
parse_platforms_from_env_file,
)
from conda_lock.src_parser.meta_yaml import parse_meta_yaml_file
from conda_lock.src_parser.pyproject_toml import parse_pyproject_toml
from conda_lock.src_parser.pyproject_toml import (
parse_platforms_from_pyproject_toml,
parse_pyproject_toml,
)
from conda_lock.virtual_package import FakeRepoData


Expand All @@ -36,7 +39,7 @@ def _parse_platforms_from_srcs(src_files: List[pathlib.Path]) -> List[str]:
if src_file.name == "meta.yaml":
continue
elif src_file.name == "pyproject.toml":
all_file_platforms.append(parse_pyproject_toml(src_file).platforms)
all_file_platforms.append(parse_platforms_from_pyproject_toml(src_file))
else:
all_file_platforms.append(parse_platforms_from_env_file(src_file))

Expand All @@ -62,7 +65,7 @@ def _parse_source_files(
if src_file.name == "meta.yaml":
desired_envs.append(parse_meta_yaml_file(src_file, platforms))
elif src_file.name == "pyproject.toml":
desired_envs.append(parse_pyproject_toml(src_file))
desired_envs.append(parse_pyproject_toml(src_file, platforms))
else:
desired_envs.append(parse_environment_file(src_file, platforms))
return desired_envs
Expand All @@ -85,24 +88,37 @@ def make_lock_spec(

lock_specs = _parse_source_files(src_files, platforms)

lock_spec = aggregate_lock_specs(lock_specs)
lock_spec.virtual_package_repo = virtual_package_repo
lock_spec.platforms = platforms
lock_spec.channels = (
aggregated_lock_spec = aggregate_lock_specs(lock_specs, platforms)

# Use channel overrides if given, otherwise use the channels specified in the
# source files.
channels = (
[Channel.from_string(co) for co in channel_overrides]
if channel_overrides
else lock_spec.channels
else aggregated_lock_spec.channels
)

if required_categories is not None:

if required_categories is None:
dependencies = aggregated_lock_spec.dependencies
else:
# Filtering based on category (e.g. "main" or "dev") was requested.
# Thus we need to filter the specs based on the category.
def dep_has_category(d: Dependency, categories: AbstractSet[str]) -> bool:
return d.category in categories

lock_spec.dependencies = [
d
for d in lock_spec.dependencies
if dep_has_category(d, categories=required_categories)
]

return lock_spec
dependencies = {
platform: [
d
for d in dependencies
if dep_has_category(d, categories=required_categories)
]
for platform, dependencies in aggregated_lock_spec.dependencies.items()
}

return LockSpecification(
dependencies=dependencies,
channels=channels,
sources=aggregated_lock_spec.sources,
virtual_package_repo=virtual_package_repo,
allow_pypi_requests=aggregated_lock_spec.allow_pypi_requests,
)
41 changes: 24 additions & 17 deletions conda_lock/src_parser/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,31 @@

def aggregate_lock_specs(
lock_specs: List[LockSpecification],
platforms: List[str],
) -> LockSpecification:
# unique dependencies
unique_deps: Dict[Tuple[str, str], Dependency] = {}
for dep in chain.from_iterable(
[lock_spec.dependencies for lock_spec in lock_specs]
):
key = (dep.manager, dep.name)
if key in unique_deps:
# Override existing, but merge selectors
previous_selectors = unique_deps[key].selectors
previous_selectors |= dep.selectors
dep.selectors = previous_selectors
unique_deps[key] = dep

dependencies = list(unique_deps.values())
for lock_spec in lock_specs:
if set(lock_spec.platforms) != set(platforms):
raise ValueError(
f"Lock specifications must have the same platforms in order to be "
f"aggregated. Expected platforms are {set(platforms)}, but the lock "
f"specification from {[str(s) for s in lock_spec.sources]} has "
f"platforms {set(lock_spec.platforms)}."
)

dependencies: Dict[str, List[Dependency]] = {}
for platform in platforms:
# unique dependencies
unique_deps: Dict[Tuple[str, str], Dependency] = {}
for dep in chain.from_iterable(
lock_spec.dependencies.get(platform, []) for lock_spec in lock_specs
):
key = (dep.manager, dep.name)
unique_deps[key] = dep

dependencies[platform] = list(unique_deps.values())

try:
channels = suffix_union(lock_spec.channels or [] for lock_spec in lock_specs)
channels = suffix_union(lock_spec.channels for lock_spec in lock_specs)
except ValueError as e:
raise ChannelAggregationError(*e.args)

Expand All @@ -38,8 +46,7 @@ def aggregate_lock_specs(
# Ensure channel are correctly ordered
channels=channels,
# uniquify metadata, preserving order
platforms=ordered_union(lock_spec.platforms or [] for lock_spec in lock_specs),
sources=ordered_union(lock_spec.sources or [] for lock_spec in lock_specs),
sources=ordered_union(lock_spec.sources for lock_spec in lock_specs),
allow_pypi_requests=all(
lock_spec.allow_pypi_requests for lock_spec in lock_specs
),
Expand Down
58 changes: 19 additions & 39 deletions conda_lock/src_parser/environment_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import yaml

from conda_lock.models.lock_spec import Dependency, LockSpecification
from conda_lock.src_parser.aggregation import aggregate_lock_specs
from conda_lock.src_parser.conda_common import conda_spec_to_versioned_dep
from conda_lock.src_parser.selectors import filter_platform_selectors

Expand All @@ -27,10 +26,10 @@ def parse_conda_requirement(req: str) -> Tuple[str, str]:


def _parse_environment_file_for_platform(
environment_file: pathlib.Path,
content: str,
category: str,
platform: str,
) -> LockSpecification:
) -> List[Dependency]:
"""
Parse dependencies from a conda environment specification for an
assumed target platform.
Expand All @@ -44,23 +43,15 @@ def _parse_environment_file_for_platform(
"""
filtered_content = "\n".join(filter_platform_selectors(content, platform=platform))
env_yaml_data = yaml.safe_load(filtered_content)

specs = env_yaml_data["dependencies"]
channels: List[str] = env_yaml_data.get("channels", [])

# These extension fields are nonstandard
platforms: List[str] = env_yaml_data.get("platforms", [])
category: str = env_yaml_data.get("category") or "main"

# Split out any sub spec sections from the dependencies mapping
mapping_specs = [x for x in specs if not isinstance(x, str)]
specs = [x for x in specs if isinstance(x, str)]

dependencies: List[Dependency] = []
for spec in specs:
vdep = conda_spec_to_versioned_dep(spec, category)
vdep.selectors.platform = [platform]
dependencies.append(vdep)
dependencies.append(conda_spec_to_versioned_dep(spec, category))

for mapping_spec in mapping_specs:
if "pip" in mapping_spec:
Expand Down Expand Up @@ -88,12 +79,7 @@ def _parse_environment_file_for_platform(
# ensure pip is in target env
dependencies.append(parse_python_requirement("pip", manager="conda"))

return LockSpecification(
dependencies=dependencies,
channels=channels, # type: ignore
platforms=platforms,
sources=[environment_file],
)
return dependencies


def parse_platforms_from_env_file(environment_file: pathlib.Path) -> List[str]:
Expand Down Expand Up @@ -127,26 +113,20 @@ def parse_environment_file(
with environment_file.open("r") as fo:
content = fo.read()

env_yaml_data = yaml.safe_load(content)
channels: List[str] = env_yaml_data.get("channels", [])

# These extension fields are nonstandard
category: str = env_yaml_data.get("category") or "main"

# Parse with selectors for each target platform
spec = aggregate_lock_specs(
[
_parse_environment_file_for_platform(
environment_file,
content,
platform,
)
for platform in platforms
]
)
dep_map = {
platform: _parse_environment_file_for_platform(content, category, platform)
for platform in platforms
}

# Remove platform selectors if they apply to all targets
for dep in spec.dependencies:
if dep.selectors.platform == platforms:
dep.selectors.platform = None

# Use the list of rendered platforms for the output spec only if
# there is a dependency that is not used on all platforms.
# This is unlike meta.yaml because environment-yaml files can contain an
# internal list of platforms, which should be used as long as it
spec.platforms = platforms
return spec
return LockSpecification(
dependencies=dep_map,
channels=channels, # type: ignore
sources=[environment_file],
)
Loading