diff --git a/conda_lock/common.py b/conda_lock/common.py
index 7686a272..5d47ce69 100644
--- a/conda_lock/common.py
+++ b/conda_lock/common.py
@@ -19,10 +19,6 @@
 )
 
 
-if typing.TYPE_CHECKING:
-    # Not in the release version of typeshed yet
-    from _typeshed import SupportsRichComparisonT  # type: ignore
-
 T = TypeVar("T")
 
 
@@ -79,40 +75,6 @@ def ordered_union(collections: Iterable[Iterable[T]]) -> List[T]:
     return list({k: k for k in chain.from_iterable(collections)}.values())
 
 
-def suffix_union(
-    collections: Iterable[Sequence["SupportsRichComparisonT"]],
-) -> List["SupportsRichComparisonT"]:
-    """Generates the union of sequence ensuring that they have a common suffix.
-
-    This is used to unify channels.
-
-    >>> suffix_union([[1], [2, 1], [3, 2, 1], [2, 1], [1]])
-    [3, 2, 1]
-
-    >>> suffix_union([[1], [2, 1], [4, 1]])
-    Traceback (most recent call last)
-    ...
-    RuntimeError: [4, 1] is not a subset of [2, 1]
-
-    """
-    from genericpath import commonprefix
-
-    result: List["SupportsRichComparisonT"] = []
-    for seq in collections:
-        if seq:
-            rev_priority = list(reversed(seq))
-            prefix = commonprefix([result, rev_priority])  # type: ignore
-            if len(result) == 0:
-                result = rev_priority
-            elif prefix[: len(rev_priority)] != result[: len(rev_priority)]:
-                raise ValueError(
-                    f"{list(reversed(rev_priority))} is not a ordered subset of {list(reversed(result))}"
-                )
-            elif len(rev_priority) > len(result):
-                result = rev_priority
-    return list(reversed(result))
-
-
 def relative_path(source: pathlib.Path, target: pathlib.Path) -> str:
     """
     Get posix representation of the relative path from `source` to `target`.
diff --git a/conda_lock/models/channel.py b/conda_lock/models/channel.py
index 60874f3f..9434b1d9 100644
--- a/conda_lock/models/channel.py
+++ b/conda_lock/models/channel.py
@@ -95,9 +95,6 @@ class Channel(ZeroValRepr, BaseModel):
     url: str
     used_env_vars: FrozenSet[str] = Field(default=frozenset())
 
-    def __lt__(self, other: "Channel") -> bool:
-        return tuple(self.dict().values()) < tuple(other.dict().values())
-
     @classmethod
     def from_string(cls, value: str) -> "Channel":
         if "://" in value:
diff --git a/conda_lock/src_parser/aggregation.py b/conda_lock/src_parser/aggregation.py
index 65ab4842..3a035ec8 100644
--- a/conda_lock/src_parser/aggregation.py
+++ b/conda_lock/src_parser/aggregation.py
@@ -3,8 +3,9 @@
 from itertools import chain
 from typing import Dict, List, Tuple
 
-from conda_lock.common import ordered_union, suffix_union
+from conda_lock.common import ordered_union
 from conda_lock.errors import ChannelAggregationError
+from conda_lock.models.channel import Channel
 from conda_lock.models.lock_spec import Dependency, LockSpecification
 
 
@@ -37,7 +38,9 @@ def aggregate_lock_specs(
         dependencies[platform] = list(unique_deps.values())
 
     try:
-        channels = suffix_union(lock_spec.channels for lock_spec in lock_specs)
+        channels = unify_package_sources(
+            [lock_spec.channels for lock_spec in lock_specs]
+        )
     except ValueError as e:
         raise ChannelAggregationError(*e.args)
 
@@ -51,3 +54,40 @@ def aggregate_lock_specs(
             lock_spec.allow_pypi_requests for lock_spec in lock_specs
         ),
     )
+
+
+def unify_package_sources(collections: List[List[Channel]]) -> List[Channel]:
+    """Unify the package sources from multiple lock specs.
+
+    To be able to merge the lock specs, the package sources must be compatible between
+    them. This means that between any two lock specs, the package sources must be
+    identical or one must be an extension of the other.
+
+    This allows us to use a superset of all of the package source lists in the
+    aggregated lock spec.
+
+    The following is allowed:
+
+    > unify_package_sources([[channel_two, channel_one], [channel_one]])
+    [channel_two, channel_one]
+
+    Whilst the following will fail:
+
+    > unify_package_sources([[channel_two, channel_one], [channel_three, channel_one]])
+
+    In the failing example, it is not possible to predictably decide which channel
+    to search first, `channel_two` or `channel_three`, so we error in this case.
+    """
+    if not collections:
+        return []
+    result = max(collections, key=len)
+    for collection in collections:
+        if collection == []:
+            truncated_result = []
+        else:
+            truncated_result = result[-len(collection) :]
+        if collection != truncated_result:
+            raise ValueError(
+                f"{collection} is not an ordered subset at the end of {result}"
+            )
+    return result
diff --git a/tests/test_channel.py b/tests/test_channel.py
index 1e381dd9..c523fbc9 100644
--- a/tests/test_channel.py
+++ b/tests/test_channel.py
@@ -1,6 +1,9 @@
 import typing
 
-from conda_lock.models.channel import _detect_used_env_var, _env_var_normalize
+import pytest
+
+from conda_lock.models.channel import Channel, _detect_used_env_var, _env_var_normalize
+from conda_lock.src_parser.aggregation import unify_package_sources
 
 
 if typing.TYPE_CHECKING:
@@ -41,3 +44,45 @@ def test_url_auth_info(monkeypatch: "MonkeyPatch") -> None:
 
     replaced = y.conda_token_replaced_url()
     assert replaced == f"http://{user}:{passwd}@host/t//prefix/suffix"
+
+
+@pytest.mark.parametrize(
+    "collections,expected",
+    [
+        (
+            [
+                ["three", "two", "one"],
+                ["two", "one"],
+            ],
+            ["three", "two", "one"],
+        ),
+        (
+            [
+                ["three", "two", "one"],
+                ["two", "one"],
+                [],
+            ],
+            ["three", "two", "one"],
+        ),
+        (
+            [
+                ["three", "two", "one"],
+                ["three", "one", "two"],
+            ],
+            ValueError,
+        ),
+    ],
+)
+def test_unify_package_sources(
+    collections: typing.List[typing.List[str]],
+    expected: typing.Union[typing.List[str], typing.Type[Exception]],
+) -> None:
+    channel_collections = [
+        [Channel.from_string(name) for name in collection] for collection in collections
+    ]
+    if isinstance(expected, list):
+        result = unify_package_sources(channel_collections)
+        assert [channel.url for channel in result] == expected
+    else:
+        with pytest.raises(expected):
+            unify_package_sources(channel_collections)
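
Note on the aggregation rule introduced above: what follows is a minimal, self-contained sketch of the suffix-compatibility check that unify_package_sources performs. The helper name check_suffix_union is hypothetical and is not part of this change; it uses plain strings in place of Channel objects so it runs without conda-lock installed, and it mirrors the logic added to conda_lock/src_parser/aggregation.py (take the longest channel list as the candidate result and require every other list to match its trailing elements).

from typing import List, TypeVar

T = TypeVar("T")


def check_suffix_union(collections: List[List[T]]) -> List[T]:
    # Take the longest list as the candidate union of package sources.
    if not collections:
        return []
    result = max(collections, key=len)
    for collection in collections:
        # An empty list is compatible with anything; otherwise the list must
        # equal the trailing slice of the candidate, i.e. be a suffix of it.
        truncated = result[-len(collection):] if collection else []
        if collection != truncated:
            raise ValueError(f"{collection} is not an ordered subset at the end of {result}")
    return result


# Mirrors the parametrized cases in tests/test_channel.py:
print(check_suffix_union([["three", "two", "one"], ["two", "one"]]))      # ['three', 'two', 'one']
print(check_suffix_union([["three", "two", "one"], ["two", "one"], []]))  # ['three', 'two', 'one']
try:
    check_suffix_union([["three", "two", "one"], ["three", "one", "two"]])
except ValueError as exc:
    print(exc)  # second list is not an ordered subset at the end of the first

The effect of the rule is that every original spec's channel list survives as the lowest-priority tail of the merged list, so the relative priority of any two channels that appear together in some input is never reordered by aggregation.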