Skip to content

Commit

Permalink
fix: gateway recursive records (#930)
Browse files Browse the repository at this point in the history
Fix an issue where all records where returned by the gateway even when
they didnt match the root specs when using recursive mode.
  • Loading branch information
baszalmstra authored Nov 5, 2024
1 parent ccbfb88 commit e8cad8f
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 13 deletions.
4 changes: 2 additions & 2 deletions crates/rattler_repodata_gateway/src/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ mod test {
.unwrap();

let total_records_single_openssl: usize = records.iter().map(RepoData::len).sum();
assert_eq!(total_records_single_openssl, 4644);
assert_eq!(total_records_single_openssl, 4219);

// There should be only one record for the openssl package.
let openssl_records: Vec<&RepoDataRecord> = records
Expand Down Expand Up @@ -571,7 +571,7 @@ mod test {
// The total number of records should be greater than the number of records
// fetched when selecting the openssl with a direct url.
assert!(total_records > total_records_single_openssl);
assert_eq!(total_records, 4692);
assert_eq!(total_records, 4267);

let openssl_records: Vec<&RepoDataRecord> = records
.iter()
Expand Down
38 changes: 27 additions & 11 deletions crates/rattler_repodata_gateway/src/gateway/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ pub struct RepoDataQuery {
reporter: Option<Arc<dyn Reporter>>,
}

#[derive(Clone)]
enum SourceSpecs {
/// The record is required by the user.
Input(Vec<MatchSpec>),

/// The record is required by a dependency.
Transitive,
}

impl RepoDataQuery {
/// Constructs a new instance. This should not be called directly, use
/// [`Gateway::query`] instead.
Expand Down Expand Up @@ -108,10 +117,13 @@ impl RepoDataQuery {
direct_url_specs.push((spec.clone(), url, name));
} else if let Some(name) = &spec.name {
seen.insert(name.clone());
pending_package_specs
let pending = pending_package_specs
.entry(name.clone())
.or_insert_with(Vec::new)
.push(spec);
.or_insert_with(|| SourceSpecs::Input(vec![]));
let SourceSpecs::Input(input_specs) = pending else {
panic!("RootSpecs::Input was overwritten by RootSpecs::Transitive");
};
input_specs.push(spec);
}
}

Expand Down Expand Up @@ -176,7 +188,7 @@ impl RepoDataQuery {
}
}
// Push the direct url in the first subdir result for channel priority logic.
Ok((0, vec![spec], record))
Ok((0, SourceSpecs::Input(vec![spec]), record))
}
.boxed(),
);
Expand Down Expand Up @@ -228,17 +240,19 @@ impl RepoDataQuery {
// Extract the dependencies from the records and recursively add them to the
// list of package names that we need to fetch.
for record in records.iter() {
if !request_specs.iter().any(|spec| spec.matches(record)) {
// Do not recurse into records that do not match to root spec.
continue;
if let SourceSpecs::Input(request_specs) = &request_specs {
if !request_specs.iter().any(|spec| spec.matches(record)) {
// Do not recurse into records that do not match to root spec.
continue;
}
}
for dependency in &record.package_record.depends {
// Use only the name for transitive dependencies.
let dependency_name = PackageName::new_unchecked(
dependency.split_once(' ').unwrap_or((dependency, "")).0,
);
if seen.insert(dependency_name.clone()) {
pending_package_specs.insert(dependency_name.clone(), vec![dependency_name.into()]);
pending_package_specs.insert(dependency_name.clone(), SourceSpecs::Transitive);
}
}
}
Expand All @@ -249,9 +263,11 @@ impl RepoDataQuery {
let result = &mut result[result_idx];

for record in records.iter() {
if !self.recursive && !request_specs.iter().any(|spec| spec.matches(record)) {
// Do not return records that do not match to root spec.
continue;
if let SourceSpecs::Input(request_specs) = &request_specs {
if !request_specs.iter().any(|spec| spec.matches(record)) {
// Do not return records that do not match to input spec.
continue;
}
}
result.len += 1;
result.shards.push(Arc::new([record.clone()]));
Expand Down
10 changes: 10 additions & 0 deletions py-rattler/rattler/package/package_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,15 @@ def __eq__(self, other: object) -> bool:
>>> PackageName("test-abc") == PackageName("test-ABC")
True
>>> PackageName("test-abc") == "test-abc"
True
>>> PackageName("test-abc") == "not-test-abc"
False
>>>
```
"""
if isinstance(other, str):
return self._name == PyPackageName(other)

if not isinstance(other, PackageName):
return False

Expand All @@ -124,10 +129,15 @@ def __ne__(self, other: object) -> bool:
>>> PackageName("test-abc") != PackageName("abc-test")
True
>>> PackageName("test-abc") != "test-abc"
False
>>> PackageName("test-abc") != "not-test-abc"
True
>>>
```
"""
if isinstance(other, str):
return self._name != PyPackageName(other)

if not isinstance(other, PackageName):
return True

Expand Down
2 changes: 2 additions & 0 deletions py-rattler/rattler/repo_data/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,12 @@ def clear_repodata_cache(
Examples
--------
```python
>>> gateway = Gateway()
>>> gateway.clear_repodata_cache("conda-forge", ["linux-64"])
>>> gateway.clear_repodata_cache("robostack")
>>>
```
"""
self._gateway.clear_repodata_cache(
channel._channel if isinstance(channel, Channel) else Channel(channel)._channel,
Expand Down
13 changes: 13 additions & 0 deletions py-rattler/tests/unit/test_gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from rattler import Gateway, Channel


@pytest.mark.asyncio
async def test_single_record_in_recursive_query(gateway: Gateway, conda_forge_channel: Channel) -> None:
subdirs = await gateway.query(
[conda_forge_channel], ["linux-64", "noarch"], ["python 3.10.0 h543edf9_1_cpython"], recursive=True
)

python_records = [record for subdir in subdirs for record in subdir if record.name == "python"]
assert len(python_records) == 1

0 comments on commit e8cad8f

Please sign in to comment.