Skip to content

Commit

Permalink
feat: add channel priority and channel-specific selectors to solver i…
Browse files Browse the repository at this point in the history
…nfo (#394)
  • Loading branch information
ruben-arts authored Nov 8, 2023
1 parent b66d753 commit 3b7132d
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 127 deletions.
1 change: 0 additions & 1 deletion crates/rattler-bin/src/commands/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> {
record.depends.push("pip".to_string());
}
}),
true,
)
})?;

Expand Down
127 changes: 30 additions & 97 deletions crates/rattler_repodata_gateway/src/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,10 @@ impl SparseRepoData {
/// This will parse the records for the specified packages as well as all the packages these records
/// depend on.
///
/// When strict_channel_priority is true, the channel where a package is found first will be
/// the only channel used for that package. Make it false to search in all channels for all packages.
///
pub fn load_records_recursive<'a>(
repo_data: impl IntoIterator<Item = &'a SparseRepoData>,
package_names: impl IntoIterator<Item = PackageName>,
patch_function: Option<fn(&mut PackageRecord)>,
strict_channel_priority: bool,
) -> io::Result<Vec<Vec<RepoDataRecord>>> {
let repo_data: Vec<_> = repo_data.into_iter().collect();

Expand All @@ -141,13 +137,7 @@ impl SparseRepoData {

// Iterate over the list of packages that still need to be processed.
while let Some(next_package) = pending.pop_front() {
let mut found_in_channel = None;
for (i, repo_data) in repo_data.iter().enumerate() {
// If package was found in other channel, skip this repodata
if found_in_channel.map_or(false, |c| c != &repo_data.channel.base_url) {
continue;
}

let repo_data_packages = repo_data.inner.borrow_repo_data();
let base_url = repo_data_packages
.info
Expand All @@ -173,10 +163,6 @@ impl SparseRepoData {
)?;
records.append(&mut conda_records);

if strict_channel_priority && !records.is_empty() {
found_in_channel = Some(&repo_data.channel.base_url);
}

// Iterate over all packages to find recursive dependencies.
for record in records.iter() {
for dependency in &record.package_record.depends {
Expand Down Expand Up @@ -274,7 +260,6 @@ pub async fn load_repo_data_recursively(
repo_data_paths: impl IntoIterator<Item = (Channel, impl Into<String>, impl AsRef<Path>)>,
package_names: impl IntoIterator<Item = PackageName>,
patch_function: Option<fn(&mut PackageRecord)>,
strict_channel_priority: bool,
) -> Result<Vec<Vec<RepoDataRecord>>, io::Error> {
// Open the different files and memory map them to get access to their bytes. Do this in parallel.
let lazy_repo_data = stream::iter(repo_data_paths)
Expand All @@ -293,12 +278,7 @@ pub async fn load_repo_data_recursively(
.try_collect::<Vec<_>>()
.await?;

SparseRepoData::load_records_recursive(
&lazy_repo_data,
package_names,
patch_function,
strict_channel_priority,
)
SparseRepoData::load_records_recursive(&lazy_repo_data, package_names, patch_function)
}

fn deserialize_filename_and_raw_record<'d, D: Deserializer<'d>>(
Expand Down Expand Up @@ -401,7 +381,6 @@ impl<'de> TryFrom<&'de str> for PackageFilename<'de> {
#[cfg(test)]
mod test {
use super::{load_repo_data_recursively, PackageFilename};
use itertools::Itertools;
use rattler_conda_types::{Channel, ChannelConfig, PackageName, RepoData, RepoDataRecord};
use rstest::rstest;
use std::path::{Path, PathBuf};
Expand All @@ -412,7 +391,6 @@ mod test {

async fn load_sparse(
package_names: impl IntoIterator<Item = impl AsRef<str>>,
strict_channel_priority: bool,
) -> Vec<Vec<RepoDataRecord>> {
load_repo_data_recursively(
[
Expand All @@ -426,31 +404,25 @@ mod test {
"linux-64",
test_dir().join("channels/conda-forge/linux-64/repodata.json"),
),
(
Channel::from_str("pytorch", &ChannelConfig::default()).unwrap(),
"linux-64",
test_dir().join("channels/pytorch/linux-64/repodata.json"),
),
],
package_names
.into_iter()
.map(|name| PackageName::try_from(name.as_ref()).unwrap()),
None,
strict_channel_priority,
)
.await
.unwrap()
}

#[tokio::test]
async fn test_empty_sparse_load() {
let sparse_empty_data = load_sparse(Vec::<String>::new(), false).await;
assert_eq!(sparse_empty_data, vec![vec![], vec![], vec![]]);
let sparse_empty_data = load_sparse(Vec::<String>::new()).await;
assert_eq!(sparse_empty_data, vec![vec![], vec![]]);
}

#[tokio::test]
async fn test_sparse_single() {
let sparse_empty_data = load_sparse(["_libgcc_mutex"], false).await;
let sparse_empty_data = load_sparse(["_libgcc_mutex"]).await;
let total_records = sparse_empty_data
.iter()
.map(|repo| repo.len())
Expand All @@ -459,45 +431,9 @@ mod test {
assert_eq!(total_records, 3);
}

#[tokio::test]
async fn test_sparse_strict() {
// If we load pytorch-cpu from all channels (non-strict) we expect records from both
// conda-forge and the pytorch channels.
let sparse_data = load_sparse(["pytorch-cpu"], false).await;
let channels = sparse_data
.into_iter()
.flatten()
.filter(|record| record.package_record.name.as_normalized() == "pytorch-cpu")
.map(|record| record.channel)
.unique()
.collect_vec();
assert_eq!(
channels,
vec![
String::from("https://conda.anaconda.org/conda-forge/"),
String::from("https://conda.anaconda.org/pytorch/")
]
);

// If we load pytorch-cpu from strict channels we expect records only from the first channel
// that contains the package. In this case we expect records only from conda-forge.
let strict_sparse_data = load_sparse(["pytorch-cpu"], true).await;
let channels = strict_sparse_data
.into_iter()
.flatten()
.filter(|record| record.package_record.name.as_normalized() == "pytorch-cpu")
.map(|record| record.channel)
.unique()
.collect_vec();
assert_eq!(
channels,
vec![String::from("https://conda.anaconda.org/conda-forge/")]
);
}

#[tokio::test]
async fn test_parse_duplicate() {
let sparse_empty_data = load_sparse(["_libgcc_mutex", "_libgcc_mutex"], false).await;
let sparse_empty_data = load_sparse(["_libgcc_mutex", "_libgcc_mutex"]).await;
let total_records = sparse_empty_data
.iter()
.map(|repo| repo.len())
Expand All @@ -509,7 +445,7 @@ mod test {

#[tokio::test]
async fn test_sparse_jupyterlab_detectron2() {
let sparse_empty_data = load_sparse(["jupyterlab", "detectron2"], true).await;
let sparse_empty_data = load_sparse(["jupyterlab", "detectron2"]).await;

let total_records = sparse_empty_data
.iter()
Expand All @@ -521,33 +457,30 @@ mod test {

#[tokio::test]
async fn test_sparse_numpy_dev() {
let sparse_empty_data = load_sparse(
[
"python",
"cython",
"compilers",
"openblas",
"nomkl",
"pytest",
"pytest-cov",
"pytest-xdist",
"hypothesis",
"mypy",
"typing_extensions",
"sphinx",
"numpydoc",
"ipython",
"scipy",
"pandas",
"matplotlib",
"pydata-sphinx-theme",
"pycodestyle",
"gitpython",
"cffi",
"pytz",
],
false,
)
let sparse_empty_data = load_sparse([
"python",
"cython",
"compilers",
"openblas",
"nomkl",
"pytest",
"pytest-cov",
"pytest-xdist",
"hypothesis",
"mypy",
"typing_extensions",
"sphinx",
"numpydoc",
"ipython",
"scipy",
"pandas",
"matplotlib",
"pydata-sphinx-theme",
"pycodestyle",
"gitpython",
"cffi",
"pytz",
])
.await;

let total_records = sparse_empty_data
Expand Down
2 changes: 1 addition & 1 deletion crates/rattler_solve/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ url = "2.4.1"
hex = "0.4.3"
tempfile = "3.8.0"
rattler_libsolv_c = { version = "0.11.0", path = "../rattler_libsolv_c", optional = true }
resolvo = { version = "0.1.0", optional = true }
resolvo = { version = "0.2.0", optional = true }

[dev-dependencies]
rattler_repodata_gateway = { version = "0.11.0", path = "../rattler_repodata_gateway", default-features = false, features = ["sparse"] }
Expand Down
2 changes: 1 addition & 1 deletion crates/rattler_solve/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn bench_solve_environment(c: &mut Criterion, specs: Vec<&str>) {

let names = specs.iter().map(|s| s.name.clone().unwrap());
let available_packages =
SparseRepoData::load_records_recursive(&sparse_repo_datas, names, None, true).unwrap();
SparseRepoData::load_records_recursive(&sparse_repo_datas, names, None).unwrap();

#[cfg(feature = "libsolv_c")]
group.bench_function("libsolv_c", |b| {
Expand Down
68 changes: 64 additions & 4 deletions crates/rattler_solve/src/resolvo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ impl<'a> CondaDependencyProvider<'a> {
favored_records: &'a [RepoDataRecord],
locked_records: &'a [RepoDataRecord],
virtual_packages: &'a [GenericVirtualPackage],
match_specs: &[MatchSpec],
) -> Self {
let pool = Pool::default();
let mut records: HashMap<NameId, Candidates> = HashMap::default();
Expand All @@ -184,6 +185,15 @@ impl<'a> CondaDependencyProvider<'a> {
records.entry(name).or_default().candidates.push(solvable);
}

// TODO: Normalize these channel names to urls so we can compare them correctly.
let channel_specific_specs = match_specs
.iter()
.filter(|spec| spec.channel.is_some())
.collect::<Vec<_>>();

// Hashmap that maps the package name to the channel it was first found in.
let mut package_name_found_in_channel = HashMap::<String, &String>::new();

// Add additional records
for repo_datas in repodata {
// Iterate over all records and dedup records that refer to the same package data but with
Expand Down Expand Up @@ -237,6 +247,55 @@ impl<'a> CondaDependencyProvider<'a> {
pool.intern_solvable(package_name, SolverPackageRecord::Record(record));
let candidates = records.entry(package_name).or_default();
candidates.candidates.push(solvable_id);

// Add to excluded when package is not in the specified channel.
if !channel_specific_specs.is_empty() {
if let Some(spec) = channel_specific_specs.iter().find(|&&spec| {
spec.name
.as_ref()
.expect("expecting a name")
.as_normalized()
== record.package_record.name.as_normalized()
}) {
// Check if the spec has a channel, and compare it to the repodata channel
if let Some(spec_channel) = &spec.channel {
if !&record.channel.contains(spec_channel) {
tracing::debug!("Ignoring {} from {} because it was not requested from that channel.", &record.package_record.name.as_normalized(), &record.channel);
// Add record to the excluded with reason of being in the non requested channel.
candidates.excluded.push((
solvable_id,
pool.intern_string(format!(
"candidate not in requested channel: '{spec_channel}'"
)),
));
continue;
}
}
}
}

// Enforce channel priority
// This function makes the assumption that the records are given in order of the channels.
if let Some(first_channel) = package_name_found_in_channel
.get(&record.package_record.name.as_normalized().to_string())
{
// Add the record to the excluded list when it is from a different channel.
if first_channel != &&record.channel {
tracing::debug!(
"Ignoring '{}' from '{}' because of strict channel priority.",
&record.package_record.name.as_normalized(),
&record.channel
);
candidates.excluded.push((solvable_id, pool.intern_string(format!("due to strict channel priority not using this option from: '{first_channel}'", ))));
continue;
}
} else {
package_name_found_in_channel.insert(
record.package_record.name.as_normalized().to_string(),
&record.channel,
);
}

candidates.hint_dependencies_available.push(solvable_id);
}
}
Expand All @@ -245,15 +304,15 @@ impl<'a> CondaDependencyProvider<'a> {
for favored_record in favored_records {
let name = pool.intern_package_name(favored_record.package_record.name.as_normalized());
let solvable = pool.intern_solvable(name, SolverPackageRecord::Record(favored_record));
let mut candidates = records.entry(name).or_default();
let candidates = records.entry(name).or_default();
candidates.candidates.push(solvable);
candidates.favored = Some(solvable);
}

for locked_record in locked_records {
let name = pool.intern_package_name(locked_record.package_record.name.as_normalized());
let solvable = pool.intern_solvable(name, SolverPackageRecord::Record(locked_record));
let mut candidates = records.entry(name).or_default();
let candidates = records.entry(name).or_default();
candidates.candidates.push(solvable);
candidates.locked = Some(solvable);
}
Expand Down Expand Up @@ -347,14 +406,15 @@ impl super::SolverImpl for Solver {
&task.locked_packages,
&task.pinned_packages,
&task.virtual_packages,
task.specs.clone().as_ref(),
);

// Construct the requirements that the solver needs to satisfy.
let root_requirements = task
.specs
.into_iter()
.iter()
.map(|spec| {
let (name, spec) = spec.into_nameless();
let (name, spec) = spec.clone().into_nameless();
let name = name.expect("cannot use matchspec without a name");
let name_id = provider.pool.intern_package_name(name.as_normalized());
provider.pool.intern_version_set(name_id, spec.into())
Expand Down
Loading

0 comments on commit 3b7132d

Please sign in to comment.