diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs index 6f79c98a6c3a..28f1fc31995a 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -66,6 +66,7 @@ impl AggregateExpr for Median { fn create_accumulator(&self) -> Result> { Ok(Box::new(MedianAccumulator { data_type: self.data_type.clone(), + arrays: vec![], all_values: vec![], })) } @@ -108,16 +109,21 @@ impl PartialEq for Median { /// The median accumulator accumulates the raw input values /// as `ScalarValue`s /// -/// The intermediate state is represented as a List of those scalars +/// The intermediate state is represented as a List of scalar values updated by +/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values +/// in the final evaluation step so that we avoid expensive conversions and +/// allocations during `update_batch`. struct MedianAccumulator { data_type: DataType, + arrays: Vec, all_values: Vec, } impl Accumulator for MedianAccumulator { fn state(&self) -> Result> { - let state = - ScalarValue::new_list(Some(self.all_values.clone()), self.data_type.clone()); + let all_values = to_scalar_values(&self.arrays)?; + let state = ScalarValue::new_list(Some(all_values), self.data_type.clone()); + Ok(vec![state]) } @@ -125,12 +131,9 @@ impl Accumulator for MedianAccumulator { assert_eq!(values.len(), 1); let array = &values[0]; + // Defer conversions to scalar values to final evaluation. assert_eq!(array.data_type(), &self.data_type); - self.all_values.reserve(array.len()); - for index in 0..array.len() { - self.all_values - .push(ScalarValue::try_from_array(array, index)?); - } + self.arrays.push(array.clone()); Ok(()) } @@ -157,7 +160,14 @@ impl Accumulator for MedianAccumulator { } fn evaluate(&self) -> Result { - if !self.all_values.iter().any(|v| !v.is_null()) { + let batch_values = to_scalar_values(&self.arrays)?; + + if !self + .all_values + .iter() + .chain(batch_values.iter()) + .any(|v| !v.is_null()) + { return ScalarValue::try_from(&self.data_type); } @@ -166,6 +176,7 @@ impl Accumulator for MedianAccumulator { let array = ScalarValue::iter_to_array( self.all_values .iter() + .chain(batch_values.iter()) // ignore null values .filter(|v| !v.is_null()) .cloned(), @@ -214,13 +225,30 @@ impl Accumulator for MedianAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_vec(&self.all_values) + let arrays_size: usize = self.arrays.iter().map(|a| a.len()).sum(); + + std::mem::size_of_val(self) + + ScalarValue::size_of_vec(&self.all_values) + + arrays_size - std::mem::size_of_val(&self.all_values) + self.data_type.size() - std::mem::size_of_val(&self.data_type) } } +fn to_scalar_values(arrays: &[ArrayRef]) -> Result> { + let num_values: usize = arrays.iter().map(|a| a.len()).sum(); + let mut all_values = Vec::with_capacity(num_values); + + for array in arrays { + for index in 0..array.len() { + all_values.push(ScalarValue::try_from_array(&array, index)?); + } + } + + Ok(all_values) +} + /// Given a returns `array[indicies[indicie_index]]` as a `ScalarValue` fn scalar_at_index( array: &dyn Array,