diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index f18aeb0671c7..dab2348ba3de 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -99,30 +99,6 @@ impl DistinctSumAccumulator { data_type: data_type.clone(), }) } - - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - values.iter().for_each(|v| { - // If the value is NULL, it is not included in the final sum. - if !v.is_null() { - self.hash_values.insert(v.clone()); - } - }); - - Ok(()) - } - - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - states.iter().try_for_each(|state| match state { - ScalarValue::List(Some(values), _) => self.update(values.as_ref()), - _ => Err(DataFusionError::Internal(format!( - "Unexpected accumulator state {state:?}" - ))), - }) - } } impl Accumulator for DistinctSumAccumulator { @@ -147,10 +123,14 @@ impl Accumulator for DistinctSumAccumulator { return Ok(()); } - let scalar_values = (0..values[0].len()) - .map(|index| ScalarValue::try_from_array(&values[0], index)) - .collect::>>()?; - self.update(&scalar_values) + let arr = &values[0]; + (0..values[0].len()).try_for_each(|index| { + if !arr.is_null(index) { + let v = ScalarValue::try_from_array(arr, index)?; + self.hash_values.insert(v); + } + Ok(()) + }) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { @@ -158,12 +138,22 @@ impl Accumulator for DistinctSumAccumulator { return Ok(()); } - (0..states[0].len()).try_for_each(|index| { - let v = states - .iter() - .map(|array| ScalarValue::try_from_array(array, index)) - .collect::>>()?; - self.merge(&v) + let arr = &states[0]; + (0..arr.len()).try_for_each(|index| { + let scalar = ScalarValue::try_from_array(arr, index)?; + + if let ScalarValue::List(Some(scalar), _) = scalar { + scalar.iter().for_each(|scalar| { + if !ScalarValue::is_null(scalar) { + self.hash_values.insert(scalar.clone()); + } + }); + } else { + return Err(DataFusionError::Internal( + "Unexpected accumulator state".into(), + )); + } + Ok(()) }) }