Skip to content

Commit

Permalink
fix: Correct results for grouping sets when columns contain nulls (#1…
Browse files Browse the repository at this point in the history
…2571)

* Fix grouping sets behavior when data contains nulls

* PR suggestion comment

* Update new test case

* Add grouping_id to the logical plan

* Add doc comment next to INTERNAL_GROUPING_ID

* Fix unparsing of Aggregate with grouping sets

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
eejbyfeldt and alamb authored Oct 7, 2024
1 parent 134939a commit ef227f4
Show file tree
Hide file tree
Showing 11 changed files with 359 additions and 187 deletions.
17 changes: 17 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,26 @@ impl DataFrame {
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<DataFrame> {
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
let aggr_expr_len = aggr_expr.len();
let plan = LogicalPlanBuilder::from(self.plan)
.aggregate(group_expr, aggr_expr)?
.build()?;
let plan = if is_grouping_set {
let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len;
// For grouping sets we do a project to not expose the internal grouping id
let exprs = plan
.schema()
.columns()
.into_iter()
.enumerate()
.filter(|(idx, _)| *idx != grouping_id_pos)
.map(|(_, column)| Expr::Column(column))
.collect::<Vec<_>>();
LogicalPlanBuilder::from(plan).project(exprs)?.build()?
} else {
plan
};
Ok(DataFrame {
session_state: self.session_state,
plan,
Expand Down
14 changes: 2 additions & 12 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -692,10 +692,6 @@ impl DefaultPhysicalPlanner {
physical_input_schema.clone(),
)?);

// update group column indices based on partial aggregate plan evaluation
let final_group: Vec<Arc<dyn PhysicalExpr>> =
initial_aggr.output_group_expr();

let can_repartition = !groups.is_empty()
&& session_state.config().target_partitions() > 1
&& session_state.config().repartition_aggregations();
Expand All @@ -716,13 +712,7 @@ impl DefaultPhysicalPlanner {
AggregateMode::Final
};

let final_grouping_set = PhysicalGroupBy::new_single(
final_group
.iter()
.enumerate()
.map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone()))
.collect(),
);
let final_grouping_set = initial_aggr.group_expr().as_final();

Arc::new(AggregateExec::try_new(
next_partition_mode,
Expand Down Expand Up @@ -2345,7 +2335,7 @@ mod tests {
.expect("hash aggregate");
assert_eq!(
"sum(aggregate_test_100.c3)",
final_hash_agg.schema().field(2).name()
final_hash_agg.schema().field(3).name()
);
// we need access to the input to the partial aggregate so that other projects can
// implement serde
Expand Down
56 changes: 55 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use super::dml::CopyTo;
use super::DdlStatement;
Expand Down Expand Up @@ -2965,6 +2965,15 @@ impl Aggregate {
.into_iter()
.map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into()))
.collect::<Vec<_>>();
qualified_fields.push((
None,
Field::new(
Self::INTERNAL_GROUPING_ID,
Self::grouping_id_type(qualified_fields.len()),
false,
)
.into(),
));
}

qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?);
Expand Down Expand Up @@ -3016,9 +3025,19 @@ impl Aggregate {
})
}

fn is_grouping_set(&self) -> bool {
matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)])
}

/// Get the output expressions.
fn output_expressions(&self) -> Result<Vec<&Expr>> {
static INTERNAL_ID_EXPR: OnceLock<Expr> = OnceLock::new();
let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
if self.is_grouping_set() {
exprs.push(INTERNAL_ID_EXPR.get_or_init(|| {
Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID))
}));
}
exprs.extend(self.aggr_expr.iter());
debug_assert!(exprs.len() == self.schema.fields().len());
Ok(exprs)
Expand All @@ -3030,6 +3049,41 @@ impl Aggregate {
pub fn group_expr_len(&self) -> Result<usize> {
grouping_set_expr_count(&self.group_expr)
}

/// Returns the data type of the grouping id.
/// The grouping ID value is a bitmask where each set bit
/// indicates that the corresponding grouping expression is
/// null
pub fn grouping_id_type(group_exprs: usize) -> DataType {
if group_exprs <= 8 {
DataType::UInt8
} else if group_exprs <= 16 {
DataType::UInt16
} else if group_exprs <= 32 {
DataType::UInt32
} else {
DataType::UInt64
}
}

/// Internal column used when the aggregation is a grouping set.
///
/// This column contains a bitmask where each bit represents a grouping
/// expression. The least significant bit corresponds to the rightmost
/// grouping expression. A bit value of 0 indicates that the corresponding
/// column is included in the grouping set, while a value of 1 means it is excluded.
///
/// For example, for the grouping expressions CUBE(a, b), the grouping ID
/// column will have the following values:
/// 0b00: Both `a` and `b` are included
/// 0b01: `b` is excluded
/// 0b10: `a` is excluded
/// 0b11: Both `a` and `b` are excluded
///
/// This internal column is necessary because excluded columns are replaced
/// with `NULL` values. To handle these cases correctly, we must distinguish
/// between an actual `NULL` value in a column and a column being excluded from the set.
pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id";
}

// Manual implementation needed because of `schema` field. Comparison excludes this field.
Expand Down
12 changes: 11 additions & 1 deletion datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,17 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result
/// Count the number of distinct exprs in a list of group by expressions. If the
/// first element is a `GroupingSet` expression then it must be the only expr.
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
if group_expr.len() > 1 {
return plan_err!(
"Invalid group by expressions, GroupingSet must be the only expression"
);
}
// Groupings sets have an additional interal column for the grouping id
Ok(grouping_set.distinct_expr().len() + 1)
} else {
grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
}
}

/// The [power set] (or powerset) of a set S is the set of all subsets of S, \
Expand Down
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -373,7 +373,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -392,7 +392,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand Down
Loading

0 comments on commit ef227f4

Please sign in to comment.