diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs index 7409f5c214b9..122d77683cef 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs @@ -15,24 +15,21 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::BooleanBufferBuilder; use arrow::array::BufferBuilder; use arrow::array::GenericBinaryArray; use arrow::array::GenericStringArray; use arrow::array::OffsetSizeTrait; use arrow::array::PrimitiveArray; use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray}; -use arrow::buffer::NullBuffer; use arrow::buffer::OffsetBuffer; use arrow::buffer::ScalarBuffer; -use arrow::datatypes::ArrowNativeType; use arrow::datatypes::ByteArrayType; use arrow::datatypes::DataType; use arrow::datatypes::GenericBinaryType; -use arrow::datatypes::GenericStringType; use datafusion_common::utils::proxy::VecAllocExt; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow_array::types::GenericStringType; use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; use std::sync::Arc; use std::vec; @@ -190,6 +187,12 @@ impl GroupColumn for PrimitiveGroupValueBuilder { } /// An implementation of [`GroupColumn`] for binary and utf8 types. +/// +/// Stores a collection of binary or utf8 group values in a single buffer +/// in a way that allows: +/// +/// 1. Efficient comparison of incoming rows to existing rows +/// 2. Efficient construction of the final output array pub struct ByteGroupValueBuilder where O: OffsetSizeTrait, @@ -201,8 +204,8 @@ where /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values /// are stored as a zero length string. offsets: Vec, - /// Null indexes in offsets, if `i` is in nulls, `offsets[i]` should be equals to `offsets[i+1]` - nulls: Vec, + /// Nulls + nulls: MaybeNullBufferBuilder, } impl ByteGroupValueBuilder @@ -214,7 +217,7 @@ where output_type, buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), offsets: vec![O::default()], - nulls: vec![], + nulls: MaybeNullBufferBuilder::new(), } } @@ -224,40 +227,33 @@ where { let arr = array.as_bytes::(); if arr.is_null(row) { - self.nulls.push(self.len()); + self.nulls.append(true); // nulls need a zero length in the offset buffer let offset = self.buffer.len(); - self.offsets.push(O::usize_as(offset)); - return; + } else { + self.nulls.append(false); + let value: &[u8] = arr.value(row).as_ref(); + self.buffer.append_slice(value); + self.offsets.push(O::usize_as(self.buffer.len())); } - - let value: &[u8] = arr.value(row).as_ref(); - self.buffer.append_slice(value); - self.offsets.push(O::usize_as(self.buffer.len())); } fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool where B: ByteArrayType, { - // Handle nulls - let is_lhs_null = self.nulls.iter().any(|null_idx| *null_idx == lhs_row); let arr = array.as_bytes::(); - if is_lhs_null { - return arr.is_null(rhs_row); - } else if arr.is_null(rhs_row) { - return false; - } + self.nulls.is_null(lhs_row) == arr.is_null(rhs_row) + && self.value(lhs_row) == (arr.value(rhs_row).as_ref() as &[u8]) + } - let arr = array.as_bytes::(); - let rhs_elem: &[u8] = arr.value(rhs_row).as_ref(); - let rhs_elem_len = arr.value_length(rhs_row).as_usize(); - debug_assert_eq!(rhs_elem_len, rhs_elem.len()); - let l = self.offsets[lhs_row].as_usize(); - let r = self.offsets[lhs_row + 1].as_usize(); - let existing_elem = unsafe { self.buffer.as_slice().get_unchecked(l..r) }; - rhs_elem == existing_elem + /// return the current value of the specified row irrespective of null + pub fn value(&self, row: usize) -> &[u8] { + let l = self.offsets[row].as_usize(); + let r = self.offsets[row + 1].as_usize(); + // Safety: the offsets are constructed correctly and never decrease + unsafe { self.buffer.as_slice().get_unchecked(l..r) } } } @@ -325,18 +321,7 @@ where nulls, } = *self; - let null_buffer = if nulls.is_empty() { - None - } else { - // Only make a `NullBuffer` if there was a null value - let num_values = offsets.len() - 1; - let mut bool_builder = BooleanBufferBuilder::new(num_values); - bool_builder.append_n(num_values, true); - nulls.into_iter().for_each(|null_index| { - bool_builder.set_bit(null_index, false); - }); - Some(NullBuffer::from(bool_builder.finish())) - }; + let null_buffer = nulls.build(); // SAFETY: the offsets were constructed correctly in `insert_if_new` -- // monotonically increasing, overflows were checked. @@ -353,9 +338,9 @@ where // SAFETY: // 1. the offsets were constructed safely // - // 2. we asserted the input arrays were all the correct type and - // thus since all the values that went in were valid (e.g. utf8) - // so are all the values that come out + // 2. the input arrays were all the correct type and thus since + // all the values that went in were valid (e.g. utf8) so are all + // the values that come out Arc::new(unsafe { GenericStringArray::new_unchecked(offsets, values, null_buffer) }) @@ -366,27 +351,7 @@ where fn take_n(&mut self, n: usize) -> ArrayRef { debug_assert!(self.len() >= n); - - let null_buffer = if self.nulls.is_empty() { - None - } else { - // Only make a `NullBuffer` if there was a null value - let mut bool_builder = BooleanBufferBuilder::new(n); - bool_builder.append_n(n, true); - - let mut new_nulls = vec![]; - self.nulls.iter().for_each(|null_index| { - if *null_index < n { - bool_builder.set_bit(*null_index, false); - } else { - new_nulls.push(null_index - n); - } - }); - - self.nulls = new_nulls; - Some(NullBuffer::from(bool_builder.finish())) - }; - + let null_buffer = self.nulls.take_n(n); let first_remaining_offset = O::as_usize(self.offsets[n]); // Given offests like [0, 2, 4, 5] and n = 1, we expect to get