From 3b7132dfb08574c4d11e22352c444f225c85ec59 Mon Sep 17 00:00:00 2001 From: Ruben Arts Date: Wed, 8 Nov 2023 16:04:42 +0100 Subject: [PATCH] feat: add channel priority and channel-specific selectors to solver info (#394) --- crates/rattler-bin/src/commands/create.rs | 1 - .../src/sparse/mod.rs | 127 +++++------------- crates/rattler_solve/Cargo.toml | 2 +- crates/rattler_solve/benches/bench.rs | 2 +- crates/rattler_solve/src/resolvo/mod.rs | 68 +++++++++- crates/rattler_solve/tests/backends.rs | 111 ++++++++++++++- py-rattler/Cargo.lock | 4 +- py-rattler/rattler/repo_data/sparse.py | 2 - py-rattler/rattler/solver/solver.py | 5 - py-rattler/src/repo_data/sparse.rs | 16 +-- py-rattler/src/solver.rs | 2 - 11 files changed, 213 insertions(+), 127 deletions(-) diff --git a/crates/rattler-bin/src/commands/create.rs b/crates/rattler-bin/src/commands/create.rs index b7f6bc631..5492602a5 100644 --- a/crates/rattler-bin/src/commands/create.rs +++ b/crates/rattler-bin/src/commands/create.rs @@ -167,7 +167,6 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> { record.depends.push("pip".to_string()); } }), - true, ) })?; diff --git a/crates/rattler_repodata_gateway/src/sparse/mod.rs b/crates/rattler_repodata_gateway/src/sparse/mod.rs index f15820932..6a0680d4b 100644 --- a/crates/rattler_repodata_gateway/src/sparse/mod.rs +++ b/crates/rattler_repodata_gateway/src/sparse/mod.rs @@ -119,14 +119,10 @@ impl SparseRepoData { /// This will parse the records for the specified packages as well as all the packages these records /// depend on. /// - /// When strict_channel_priority is true, the channel where a package is found first will be - /// the only channel used for that package. Make it false to search in all channels for all packages. 
- /// pub fn load_records_recursive<'a>( repo_data: impl IntoIterator, package_names: impl IntoIterator, patch_function: Option, - strict_channel_priority: bool, ) -> io::Result>> { let repo_data: Vec<_> = repo_data.into_iter().collect(); @@ -141,13 +137,7 @@ impl SparseRepoData { // Iterate over the list of packages that still need to be processed. while let Some(next_package) = pending.pop_front() { - let mut found_in_channel = None; for (i, repo_data) in repo_data.iter().enumerate() { - // If package was found in other channel, skip this repodata - if found_in_channel.map_or(false, |c| c != &repo_data.channel.base_url) { - continue; - } - let repo_data_packages = repo_data.inner.borrow_repo_data(); let base_url = repo_data_packages .info @@ -173,10 +163,6 @@ impl SparseRepoData { )?; records.append(&mut conda_records); - if strict_channel_priority && !records.is_empty() { - found_in_channel = Some(&repo_data.channel.base_url); - } - // Iterate over all packages to find recursive dependencies. for record in records.iter() { for dependency in &record.package_record.depends { @@ -274,7 +260,6 @@ pub async fn load_repo_data_recursively( repo_data_paths: impl IntoIterator, impl AsRef)>, package_names: impl IntoIterator, patch_function: Option, - strict_channel_priority: bool, ) -> Result>, io::Error> { // Open the different files and memory map them to get access to their bytes. Do this in parallel. 
let lazy_repo_data = stream::iter(repo_data_paths) @@ -293,12 +278,7 @@ pub async fn load_repo_data_recursively( .try_collect::>() .await?; - SparseRepoData::load_records_recursive( - &lazy_repo_data, - package_names, - patch_function, - strict_channel_priority, - ) + SparseRepoData::load_records_recursive(&lazy_repo_data, package_names, patch_function) } fn deserialize_filename_and_raw_record<'d, D: Deserializer<'d>>( @@ -401,7 +381,6 @@ impl<'de> TryFrom<&'de str> for PackageFilename<'de> { #[cfg(test)] mod test { use super::{load_repo_data_recursively, PackageFilename}; - use itertools::Itertools; use rattler_conda_types::{Channel, ChannelConfig, PackageName, RepoData, RepoDataRecord}; use rstest::rstest; use std::path::{Path, PathBuf}; @@ -412,7 +391,6 @@ mod test { async fn load_sparse( package_names: impl IntoIterator>, - strict_channel_priority: bool, ) -> Vec> { load_repo_data_recursively( [ @@ -426,17 +404,11 @@ mod test { "linux-64", test_dir().join("channels/conda-forge/linux-64/repodata.json"), ), - ( - Channel::from_str("pytorch", &ChannelConfig::default()).unwrap(), - "linux-64", - test_dir().join("channels/pytorch/linux-64/repodata.json"), - ), ], package_names .into_iter() .map(|name| PackageName::try_from(name.as_ref()).unwrap()), None, - strict_channel_priority, ) .await .unwrap() @@ -444,13 +416,13 @@ mod test { #[tokio::test] async fn test_empty_sparse_load() { - let sparse_empty_data = load_sparse(Vec::::new(), false).await; - assert_eq!(sparse_empty_data, vec![vec![], vec![], vec![]]); + let sparse_empty_data = load_sparse(Vec::::new()).await; + assert_eq!(sparse_empty_data, vec![vec![], vec![]]); } #[tokio::test] async fn test_sparse_single() { - let sparse_empty_data = load_sparse(["_libgcc_mutex"], false).await; + let sparse_empty_data = load_sparse(["_libgcc_mutex"]).await; let total_records = sparse_empty_data .iter() .map(|repo| repo.len()) @@ -459,45 +431,9 @@ mod test { assert_eq!(total_records, 3); } - #[tokio::test] - async fn 
test_sparse_strict() { - // If we load pytorch-cpy from all channels (non-strict) we expect records from both - // conda-forge and the pytorch channels. - let sparse_data = load_sparse(["pytorch-cpu"], false).await; - let channels = sparse_data - .into_iter() - .flatten() - .filter(|record| record.package_record.name.as_normalized() == "pytorch-cpu") - .map(|record| record.channel) - .unique() - .collect_vec(); - assert_eq!( - channels, - vec![ - String::from("https://conda.anaconda.org/conda-forge/"), - String::from("https://conda.anaconda.org/pytorch/") - ] - ); - - // If we load pytorch-cpy from strict channels we expect records only from the first channel - // that contains the package. In this case we expect records only from conda-forge. - let strict_sparse_data = load_sparse(["pytorch-cpu"], true).await; - let channels = strict_sparse_data - .into_iter() - .flatten() - .filter(|record| record.package_record.name.as_normalized() == "pytorch-cpu") - .map(|record| record.channel) - .unique() - .collect_vec(); - assert_eq!( - channels, - vec![String::from("https://conda.anaconda.org/conda-forge/")] - ); - } - #[tokio::test] async fn test_parse_duplicate() { - let sparse_empty_data = load_sparse(["_libgcc_mutex", "_libgcc_mutex"], false).await; + let sparse_empty_data = load_sparse(["_libgcc_mutex", "_libgcc_mutex"]).await; let total_records = sparse_empty_data .iter() .map(|repo| repo.len()) @@ -509,7 +445,7 @@ mod test { #[tokio::test] async fn test_sparse_jupyterlab_detectron2() { - let sparse_empty_data = load_sparse(["jupyterlab", "detectron2"], true).await; + let sparse_empty_data = load_sparse(["jupyterlab", "detectron2"]).await; let total_records = sparse_empty_data .iter() @@ -521,33 +457,30 @@ mod test { #[tokio::test] async fn test_sparse_numpy_dev() { - let sparse_empty_data = load_sparse( - [ - "python", - "cython", - "compilers", - "openblas", - "nomkl", - "pytest", - "pytest-cov", - "pytest-xdist", - "hypothesis", - "mypy", - "typing_extensions", - 
"sphinx", - "numpydoc", - "ipython", - "scipy", - "pandas", - "matplotlib", - "pydata-sphinx-theme", - "pycodestyle", - "gitpython", - "cffi", - "pytz", - ], - false, - ) + let sparse_empty_data = load_sparse([ + "python", + "cython", + "compilers", + "openblas", + "nomkl", + "pytest", + "pytest-cov", + "pytest-xdist", + "hypothesis", + "mypy", + "typing_extensions", + "sphinx", + "numpydoc", + "ipython", + "scipy", + "pandas", + "matplotlib", + "pydata-sphinx-theme", + "pycodestyle", + "gitpython", + "cffi", + "pytz", + ]) .await; let total_records = sparse_empty_data diff --git a/crates/rattler_solve/Cargo.toml b/crates/rattler_solve/Cargo.toml index a72e229e2..e69d7ff8b 100644 --- a/crates/rattler_solve/Cargo.toml +++ b/crates/rattler_solve/Cargo.toml @@ -24,7 +24,7 @@ url = "2.4.1" hex = "0.4.3" tempfile = "3.8.0" rattler_libsolv_c = { version = "0.11.0", path = "../rattler_libsolv_c", optional = true } -resolvo = { version = "0.1.0", optional = true } +resolvo = { version = "0.2.0", optional = true } [dev-dependencies] rattler_repodata_gateway = { version = "0.11.0", path = "../rattler_repodata_gateway", default-features = false, features = ["sparse"] } diff --git a/crates/rattler_solve/benches/bench.rs b/crates/rattler_solve/benches/bench.rs index 98c08efa6..fbb23927b 100644 --- a/crates/rattler_solve/benches/bench.rs +++ b/crates/rattler_solve/benches/bench.rs @@ -52,7 +52,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) { let names = specs.iter().map(|s| s.name.clone().unwrap()); let available_packages = - SparseRepoData::load_records_recursive(&sparse_repo_datas, names, None, true).unwrap(); + SparseRepoData::load_records_recursive(&sparse_repo_datas, names, None).unwrap(); #[cfg(feature = "libsolv_c")] group.bench_function("libsolv_c", |b| { diff --git a/crates/rattler_solve/src/resolvo/mod.rs b/crates/rattler_solve/src/resolvo/mod.rs index c7eeaec83..c10908cfd 100644 --- a/crates/rattler_solve/src/resolvo/mod.rs +++ 
b/crates/rattler_solve/src/resolvo/mod.rs @@ -172,6 +172,7 @@ impl<'a> CondaDependencyProvider<'a> { favored_records: &'a [RepoDataRecord], locked_records: &'a [RepoDataRecord], virtual_packages: &'a [GenericVirtualPackage], + match_specs: &[MatchSpec], ) -> Self { let pool = Pool::default(); let mut records: HashMap = HashMap::default(); @@ -184,6 +185,15 @@ impl<'a> CondaDependencyProvider<'a> { records.entry(name).or_default().candidates.push(solvable); } + // TODO: Normalize these channel names to urls so we can compare them correctly. + let channel_specific_specs = match_specs + .iter() + .filter(|spec| spec.channel.is_some()) + .collect::>(); + + // Hashmap that maps the package name to the channel it was first found in. + let mut package_name_found_in_channel = HashMap::::new(); + // Add additional records for repo_datas in repodata { // Iterate over all records and dedup records that refer to the same package data but with @@ -237,6 +247,55 @@ impl<'a> CondaDependencyProvider<'a> { pool.intern_solvable(package_name, SolverPackageRecord::Record(record)); let candidates = records.entry(package_name).or_default(); candidates.candidates.push(solvable_id); + + // Add to excluded when package is not in the specified channel. + if !channel_specific_specs.is_empty() { + if let Some(spec) = channel_specific_specs.iter().find(|&&spec| { + spec.name + .as_ref() + .expect("expecting a name") + .as_normalized() + == record.package_record.name.as_normalized() + }) { + // Check if the spec has a channel, and compare it to the repodata channel + if let Some(spec_channel) = &spec.channel { + if !&record.channel.contains(spec_channel) { + tracing::debug!("Ignoring {} from {} because it was not requested from that channel.", &record.package_record.name.as_normalized(), &record.channel); + // Add the record to the excluded list because it is not in the requested channel. 
+ candidates.excluded.push(( + solvable_id, + pool.intern_string(format!( + "candidate not in requested channel: '{spec_channel}'" + )), + )); + continue; + } + } + } + + // Enforce channel priority + // This function makes the assumption that the records are given in order of the channels. + if let Some(first_channel) = package_name_found_in_channel + .get(&record.package_record.name.as_normalized().to_string()) + { + // Add the record to the excluded list when it is from a different channel. + if first_channel != &&record.channel { + tracing::debug!( + "Ignoring '{}' from '{}' because of strict channel priority.", + &record.package_record.name.as_normalized(), + &record.channel + ); + candidates.excluded.push((solvable_id, pool.intern_string(format!("due to strict channel priority not using this option from: '{first_channel}'", )))); + continue; + } + } else { + package_name_found_in_channel.insert( + record.package_record.name.as_normalized().to_string(), + &record.channel, + ); + } + candidates.hint_dependencies_available.push(solvable_id); } } @@ -245,7 +304,7 @@ impl<'a> CondaDependencyProvider<'a> { for favored_record in favored_records { let name = pool.intern_package_name(favored_record.package_record.name.as_normalized()); let solvable = pool.intern_solvable(name, SolverPackageRecord::Record(favored_record)); - let mut candidates = records.entry(name).or_default(); + let candidates = records.entry(name).or_default(); candidates.candidates.push(solvable); candidates.favored = Some(solvable); } @@ -253,7 +312,7 @@ impl<'a> CondaDependencyProvider<'a> { for locked_record in locked_records { let name = pool.intern_package_name(locked_record.package_record.name.as_normalized()); let solvable = pool.intern_solvable(name, SolverPackageRecord::Record(locked_record)); - let mut candidates = records.entry(name).or_default(); + let candidates = records.entry(name).or_default(); candidates.candidates.push(solvable); candidates.locked = Some(solvable); } @@ 
-347,14 +406,15 @@ impl super::SolverImpl for Solver { &task.locked_packages, &task.pinned_packages, &task.virtual_packages, + task.specs.clone().as_ref(), ); // Construct the requirements that the solver needs to satisfy. let root_requirements = task .specs - .into_iter() + .iter() .map(|spec| { - let (name, spec) = spec.into_nameless(); + let (name, spec) = spec.clone().into_nameless(); let name = name.expect("cannot use matchspec without a name"); let name_id = provider.pool.intern_package_name(name.as_normalized()); provider.pool.intern_version_set(name_id, spec.into()) diff --git a/crates/rattler_solve/tests/backends.rs b/crates/rattler_solve/tests/backends.rs index be97681df..dd12b5440 100644 --- a/crates/rattler_solve/tests/backends.rs +++ b/crates/rattler_solve/tests/backends.rs @@ -25,6 +25,14 @@ fn conda_json_path_noarch() -> String { ) } +fn pytorch_json_path() -> String { + format!( + "{}/{}", + env!("CARGO_MANIFEST_DIR"), + "../../test-data/channels/pytorch/linux-64/repodata.json" + ) +} + fn dummy_channel_json_path() -> String { format!( "{}/{}", @@ -110,7 +118,7 @@ fn solve_real_world(specs: Vec<&str>) -> Vec { let names = specs.iter().filter_map(|s| s.name.as_ref().cloned()); let available_packages = - SparseRepoData::load_records_recursive(sparse_repo_datas, names, None, true).unwrap(); + SparseRepoData::load_records_recursive(sparse_repo_datas, names, None).unwrap(); let solver_task = SolverTask { available_packages: &available_packages, @@ -161,6 +169,34 @@ fn read_real_world_repo_data() -> &'static Vec { &REPO_DATA } +fn read_pytorch_sparse_repo_data() -> &'static SparseRepoData { + static REPO_DATA: Lazy = Lazy::new(|| { + let pytorch = pytorch_json_path(); + SparseRepoData::new( + Channel::from_str("pytorch", &ChannelConfig::default()).unwrap(), + "pytorch".to_string(), + pytorch, + None, + ) + .unwrap() + }); + + &REPO_DATA +} + +fn read_conda_forge_sparse_repo_data() -> &'static SparseRepoData { + static REPO_DATA: Lazy = Lazy::new(|| { + 
let conda_forge = conda_json_path(); + SparseRepoData::new( + Channel::from_str("conda-forge", &ChannelConfig::default()).unwrap(), + "conda-forge".to_string(), + conda_forge, + None, + ) + .unwrap() + }); + &REPO_DATA +} macro_rules! solver_backend_tests { ($T:path) => { #[test] @@ -592,7 +628,7 @@ fn compare_solve(specs: Vec<&str>) { let names = specs.iter().filter_map(|s| s.name.as_ref().cloned()); let available_packages = - SparseRepoData::load_records_recursive(sparse_repo_datas, names, None, true).unwrap(); + SparseRepoData::load_records_recursive(sparse_repo_datas, names, None).unwrap(); let extract_pkgs = |records: Vec| { let mut pkgs = records @@ -699,3 +735,74 @@ fn compare_solve_quetz() { fn compare_solve_xtensor_xsimd() { compare_solve(vec!["xtensor", "xsimd"]); } + +fn solve_to_get_channel_of_spec( + spec_str: &str, + expected_channel: &str, + repo_data: Vec<&SparseRepoData>, +) { + let spec = MatchSpec::from_str(spec_str).unwrap(); + let specs = vec![spec.clone()]; + let names = specs.iter().filter_map(|s| s.name.as_ref().cloned()); + + let available_packages = + SparseRepoData::load_records_recursive(repo_data, names, None).unwrap(); + + let result = rattler_solve::resolvo::Solver + .solve(SolverTask { + available_packages: &available_packages, + specs: specs.clone(), + locked_packages: Default::default(), + pinned_packages: Default::default(), + virtual_packages: Default::default(), + }) + .unwrap(); + + let record = result.iter().find(|record| { + record.package_record.name.as_normalized() == spec.name.as_ref().unwrap().as_normalized() + }); + assert_eq!(record.unwrap().channel, expected_channel.to_string()); +} + +#[test] +fn channel_specific_requirement() { + let repodata = vec![ + read_conda_forge_sparse_repo_data(), + read_pytorch_sparse_repo_data(), + ]; + solve_to_get_channel_of_spec( + "conda-forge::pytorch-cpu", + "https://conda.anaconda.org/conda-forge/", + repodata.clone(), + ); + solve_to_get_channel_of_spec( + "pytorch::pytorch-cpu", + 
"https://conda.anaconda.org/pytorch/", + repodata, + ); +} + +#[test] +fn channel_order_strict() { + // Solve with conda-forge as the first channel + let repodata = vec![ + read_conda_forge_sparse_repo_data(), + read_pytorch_sparse_repo_data(), + ]; + solve_to_get_channel_of_spec( + "pytorch-cpu", + "https://conda.anaconda.org/conda-forge/", + repodata, + ); + + // Solve with pytorch as the first channel + let repodata = vec![ + read_pytorch_sparse_repo_data(), + read_conda_forge_sparse_repo_data(), + ]; + solve_to_get_channel_of_spec( + "pytorch-cpu", + "https://conda.anaconda.org/pytorch/", + repodata, + ); +} diff --git a/py-rattler/Cargo.lock b/py-rattler/Cargo.lock index 565ae0511..1e6493ae1 100644 --- a/py-rattler/Cargo.lock +++ b/py-rattler/Cargo.lock @@ -2240,9 +2240,9 @@ dependencies = [ [[package]] name = "resolvo" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dab30801b54723f1949c6453a35db09c89e2ce7e052dc63e715f32fb40e427c" +checksum = "554db165775d6858d17a9626c327b796d81db95db0a4cd6ca0efdfb7e7e3a264" dependencies = [ "bitvec", "elsa", diff --git a/py-rattler/rattler/repo_data/sparse.py b/py-rattler/rattler/repo_data/sparse.py index b7c40f83d..bc6432280 100644 --- a/py-rattler/rattler/repo_data/sparse.py +++ b/py-rattler/rattler/repo_data/sparse.py @@ -115,7 +115,6 @@ def subdir(self) -> str: def load_records_recursive( repo_data: List[SparseRepoData], package_names: List[PackageName], - strict_channel_priority: bool = True, ) -> List[List[RepoDataRecord]]: """ Given a set of [`SparseRepoData`]s load all the records @@ -143,7 +142,6 @@ def load_records_recursive( for list_of_records in PySparseRepoData.load_records_recursive( [r._sparse for r in repo_data], [p._name for p in package_names], - strict_channel_priority, ) ] diff --git a/py-rattler/rattler/solver/solver.py b/py-rattler/rattler/solver/solver.py index eacf23ef5..d443c5b28 100644 --- a/py-rattler/rattler/solver/solver.py +++ 
b/py-rattler/rattler/solver/solver.py @@ -14,7 +14,6 @@ def solve( locked_packages: Optional[List[RepoDataRecord]] = None, pinned_packages: Optional[List[RepoDataRecord]] = None, virtual_packages: Optional[List[GenericVirtualPackage]] = None, - strict_channel_priority: bool = True, ) -> List[RepoDataRecord]: """ Resolve the dependencies and return the `RepoDataRecord`s @@ -39,9 +38,6 @@ def solve( will always select that version no matter what even if that means other packages have to be downgraded. virtual_packages: A list of virtual packages considered active. - strict_channel_priority: (Default = True) When `True` the channel that the package - is first found in will be used as the only channel for that package. - When `False` it will search for every package in every channel. Returns: Resolved list of `RepoDataRecord`s. @@ -58,6 +54,5 @@ def solve( v_package._generic_virtual_package for v_package in virtual_packages or [] ], - strict_channel_priority, ) ] diff --git a/py-rattler/src/repo_data/sparse.rs b/py-rattler/src/repo_data/sparse.rs index 55afaa595..b8e5be46f 100644 --- a/py-rattler/src/repo_data/sparse.rs +++ b/py-rattler/src/repo_data/sparse.rs @@ -62,20 +62,16 @@ impl PySparseRepoData { py: Python<'_>, repo_data: Vec, package_names: Vec, - strict_channel_priority: bool, ) -> PyResult>> { py.allow_threads(move || { let repo_data = repo_data.iter().map(Into::into); let package_names = package_names.into_iter().map(Into::into); - Ok(SparseRepoData::load_records_recursive( - repo_data, - package_names, - None, - strict_channel_priority, - )? - .into_iter() - .map(|v| v.into_iter().map(Into::into).collect::>()) - .collect::>()) + Ok( + SparseRepoData::load_records_recursive(repo_data, package_names, None)? 
+ .into_iter() + .map(|v| v.into_iter().map(Into::into).collect::>()) + .collect::>(), + ) }) } } diff --git a/py-rattler/src/solver.rs b/py-rattler/src/solver.rs index 1697b5724..380b4d213 100644 --- a/py-rattler/src/solver.rs +++ b/py-rattler/src/solver.rs @@ -17,7 +17,6 @@ pub fn py_solve( locked_packages: Vec, pinned_packages: Vec, virtual_packages: Vec, - strict_channel_priority: bool, ) -> PyResult> { py.allow_threads(move || { let package_names = specs @@ -28,7 +27,6 @@ pub fn py_solve( available_packages.iter().map(Into::into), package_names, None, - strict_channel_priority, )?; let task = SolverTask {