Skip to content

Commit

Permalink
[FEAT] Make Binary Type Comparable (#1528)
Browse files Browse the repository at this point in the history
* Allows binary arrays to be compared to one another, enabling comparisons,
aggregations, grouping, sorting, and hashing

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
samster25 and Jay Chia authored Oct 25, 2023
1 parent 227b622 commit 15085f5
Show file tree
Hide file tree
Showing 14 changed files with 744 additions and 17 deletions.
7 changes: 7 additions & 0 deletions src/daft-core/src/array/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ impl From<(&str, Box<arrow2::array::Utf8Array<i64>>)> for Utf8Array {
}
}

/// Wraps a boxed arrow2 binary array in a `BinaryArray` whose field carries
/// the given name and the `DataType::Binary` dtype.
///
/// # Panics
///
/// Panics if `DataArray::new` returns an error for the supplied field/array
/// pair.
impl From<(&str, Box<arrow2::array::BinaryArray<i64>>)> for BinaryArray {
    fn from((name, array): (&str, Box<arrow2::array::BinaryArray<i64>>)) -> Self {
        let field = Field::new(name, DataType::Binary);
        DataArray::new(field.into(), array).unwrap()
    }
}

impl<T> From<(&str, &[T::Native])> for DataArray<T>
where
T: DaftNumericType,
Expand Down
84 changes: 83 additions & 1 deletion src/daft-core/src/array/ops/compare_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,89 @@ impl DaftCompareAggable for DataArray<Utf8Type> {
}
}

/// Reduces every group of row indices in `groups` to a single binary value by
/// repeatedly applying `op` to pairs of values from `data_array`.
///
/// * `data_array` - source of the binary values; the result reuses its field name.
/// * `op` - picks one of two byte slices (e.g. the smaller or the larger one).
/// * `groups` - one collection of row indices per output row.
///
/// When the source contains nulls, null rows are skipped and a group whose
/// rows are all null (or that has no rows) produces a null output. When the
/// source has no nulls, every group must contain at least one index.
///
/// # Panics
///
/// Panics on an empty group when `data_array` has no nulls.
fn grouped_cmp_binary<'a, F>(
    data_array: &'a BinaryArray,
    op: F,
    groups: &GroupIndices,
) -> DaftResult<BinaryArray>
where
    F: Fn(&'a [u8], &'a [u8]) -> &'a [u8],
{
    let values = data_array.as_arrow();
    let has_nulls = values.null_count() > 0;

    let reduced = if has_nulls {
        // Null-aware path: drop null rows first, then fold what is left.
        let per_group = groups.iter().map(|indices| {
            indices
                .iter()
                .filter_map(|i| {
                    let row = *i as usize;
                    if values.is_null(row) {
                        None
                    } else {
                        // SAFETY: relies on every index in `groups` being in
                        // bounds for `data_array`; this is not checked here.
                        Some(unsafe { values.value_unchecked(row) })
                    }
                })
                .reduce(|acc, next| op(acc, next))
        });
        arrow2::array::BinaryArray::<i64>::from_trusted_len_iter(per_group)
    } else {
        // Fast path: no validity checks are needed, every row holds a value.
        let per_group = groups.iter().map(|indices| {
            indices
                .iter()
                .map(|i| {
                    let row = *i as usize;
                    // SAFETY: relies on every index in `groups` being in
                    // bounds for `data_array`; this is not checked here.
                    unsafe { values.value_unchecked(row) }
                })
                .reduce(|acc, next| op(acc, next))
                .unwrap()
        });
        arrow2::array::BinaryArray::<i64>::from_trusted_len_values_iter(per_group)
    };

    Ok(DataArray::from((
        data_array.field.name.as_ref(),
        Box::new(reduced),
    )))
}

/// Min/max aggregations over binary arrays, both for the whole array and per
/// group of row indices.
impl DaftCompareAggable for DataArray<BinaryType> {
    type Output = DaftResult<DataArray<BinaryType>>;

    /// Returns a one-row array holding the result of arrow2's `min_binary`
    /// over this array, keeping the original field.
    fn min(&self) -> Self::Output {
        let source: &arrow2::array::BinaryArray<i64> = self.as_arrow();
        let smallest = arrow2::compute::aggregate::min_binary(source);
        let single_row = arrow2::array::BinaryArray::<i64>::from([smallest]);

        DataArray::new(self.field.clone(), Box::new(single_row))
    }

    /// Returns a one-row array holding the result of arrow2's `max_binary`
    /// over this array, keeping the original field.
    fn max(&self) -> Self::Output {
        let source: &arrow2::array::BinaryArray<i64> = self.as_arrow();
        let largest = arrow2::compute::aggregate::max_binary(source);
        let single_row = arrow2::array::BinaryArray::<i64>::from([largest]);

        DataArray::new(self.field.clone(), Box::new(single_row))
    }

    /// Computes the minimum value within each group of row indices.
    fn grouped_min(&self, groups: &GroupIndices) -> Self::Output {
        grouped_cmp_binary(self, |lhs, rhs| std::cmp::min(lhs, rhs), groups)
    }

    /// Computes the maximum value within each group of row indices.
    fn grouped_max(&self, groups: &GroupIndices) -> Self::Output {
        grouped_cmp_binary(self, |lhs, rhs| std::cmp::max(lhs, rhs), groups)
    }
}

fn grouped_cmp_bool(
data_array: &BooleanArray,
val_to_find: bool,
Expand Down Expand Up @@ -339,7 +422,6 @@ macro_rules! impl_todo_daft_comparable {
};
}

impl_todo_daft_comparable!(BinaryArray);
impl_todo_daft_comparable!(StructArray);
impl_todo_daft_comparable!(FixedSizeListArray);
impl_todo_daft_comparable!(ListArray);
Expand Down
Loading

0 comments on commit 15085f5

Please sign in to comment.