Skip to content

Commit

Permalink
Review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Nov 29, 2023
1 parent 0aab1d1 commit 8d97950
Showing 1 changed file with 53 additions and 15 deletions.
68 changes: 53 additions & 15 deletions arrow-schema/src/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,27 +100,38 @@ impl Fields {
.all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b))
}

/// Performs a depth-first scan of [`Fields`] filtering the [`FieldRef`] with no children
/// Returns a copy of this [`Fields`] containing only those [`FieldRef`] passing a predicate
///
/// Invokes `filter` with each leaf [`FieldRef`], i.e. one containing no children, and a
/// count of the number of previous calls to `filter` - i.e. the leaf's index.
/// Performs a depth-first scan of [`Fields`] invoking `filter` for each [`FieldRef`]
/// containing no child [`FieldRef`], a leaf field, along with a count of the number
/// of such leaves encountered so far. Only [`FieldRef`] for which `filter`
/// returned `true` will be included in the result.
///
/// Returns a new [`Fields`] comprising the [`FieldRef`] for which `filter` returned `true`
/// This can therefore be used to select a subset of fields from nested types
/// such as [`DataType::Struct`] or [`DataType::List`].
///
/// ```
/// # use arrow_schema::{DataType, Field, Fields};
/// let fields = Fields::from(vec![
/// Field::new("a", DataType::Int32, true),
/// Field::new("a", DataType::Int32, true), // Leaf 0
/// Field::new("b", DataType::Struct(Fields::from(vec![
/// Field::new("c", DataType::Float32, false),
/// Field::new("d", DataType::Float64, false),
/// Field::new("c", DataType::Float32, false), // Leaf 1
/// Field::new("d", DataType::Float64, false), // Leaf 2
/// Field::new("e", DataType::Struct(Fields::from(vec![
/// Field::new("f", DataType::Int32, false), // Leaf 3
/// Field::new("g", DataType::Float16, false), // Leaf 4
/// ])), true),
/// ])), false)
/// ]);
/// let filtered = fields.filter_leaves(|idx, _| idx == 0 || idx == 2);
/// let filtered = fields.filter_leaves(|idx, _| [0, 2, 3, 4].contains(&idx));
/// let expected = Fields::from(vec![
/// Field::new("a", DataType::Int32, true),
/// Field::new("b", DataType::Struct(Fields::from(vec![
/// Field::new("d", DataType::Float64, false),
/// Field::new("e", DataType::Struct(Fields::from(vec![
/// Field::new("f", DataType::Int32, false),
/// Field::new("g", DataType::Float16, false),
/// ])), true),
/// ])), false)
/// ]);
/// assert_eq!(filtered, expected);
Expand All @@ -132,9 +143,10 @@ impl Fields {
) -> Option<FieldRef> {
use DataType::*;

let (k, v) = match f.data_type() {
Dictionary(k, v) => (Some(k.clone()), v.as_ref()),
d => (None, d),
let v = match f.data_type() {
Dictionary(_, v) => v.as_ref(), // Key must be integer
RunEndEncoded(_, v) => v.data_type(), // Run-ends must be integer
d => d,
};
let d = match v {
List(child) => List(filter_field(child, filter)?),
Expand Down Expand Up @@ -167,9 +179,12 @@ impl Fields {
}
_ => return filter(f).then(|| f.clone()),
};
let d = match k {
Some(k) => Dictionary(k, Box::new(d)),
None => d,
let d = match f.data_type() {
Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)),
RunEndEncoded(v, f) => {
RunEndEncoded(v.clone(), Arc::new(f.as_ref().clone().with_data_type(d)))
}
_ => d,
};
Some(Arc::new(f.as_ref().clone().with_data_type(d)))
}
Expand Down Expand Up @@ -456,6 +471,14 @@ mod tests {
),
true,
),
Field::new(
"i",
DataType::RunEndEncoded(
Arc::new(Field::new("run_ends", DataType::Int32, false)),
Arc::new(Field::new("values", DataType::Struct(floats.clone()), true)),
),
false,
),
]);

let floats_a = DataType::Struct(vec![floats[0].clone()].into());
Expand All @@ -466,7 +489,7 @@ mod tests {
assert_eq!(r[1].data_type(), &floats_a);

let r = fields.filter_leaves(|_, f| f.name() == "a");
assert_eq!(r.len(), 4);
assert_eq!(r.len(), 5);
assert_eq!(r[0], fields[0]);
assert_eq!(r[1].data_type(), &floats_a);
assert_eq!(
Expand All @@ -477,6 +500,17 @@ mod tests {
r[3].as_ref(),
&Field::new_list("e", Field::new("floats", floats_a.clone(), true), true)
);
assert_eq!(
r[4].as_ref(),
&Field::new(
"i",
DataType::RunEndEncoded(
Arc::new(Field::new("run_ends", DataType::Int32, false)),
Arc::new(Field::new("values", floats_a.clone(), true)),
),
false,
)
);

let r = fields.filter_leaves(|_, f| f.name() == "floats");
assert_eq!(r.len(), 0);
Expand All @@ -497,5 +531,9 @@ mod tests {
let r = fields.filter_leaves(|idx, _| idx == 12);
assert_eq!(r.len(), 1);
assert_eq!(r[0].data_type(), &union);

let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15);
assert_eq!(r.len(), 1);
assert_eq!(r[0], fields[9]);
}
}

0 comments on commit 8d97950

Please sign in to comment.