Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove more clones #8

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
329 changes: 163 additions & 166 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! [`SingleDistinctToGroupBy`] replaces `AGG(DISTINCT ..)` with `AGG(..) GROUP BY ..`

use std::hash::BuildHasher;
use std::sync::Arc;

use crate::optimizer::ApplyOrder;
Expand Down Expand Up @@ -66,38 +67,33 @@ impl SingleDistinctToGroupBy {
}

/// Check whether all aggregate exprs are distinct on a single field.
fn is_single_distinct_agg(plan: &LogicalPlan) -> Result<bool> {
match plan {
LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => {
let mut fields_set = HashSet::new();
let mut aggregate_count = 0;
for expr in aggr_expr {
if let Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(fun),
distinct,
args,
filter,
order_by,
null_treatment: _,
}) = expr
{
if filter.is_some() || order_by.is_some() {
return Ok(false);
}
aggregate_count += 1;
if *distinct {
for e in args {
fields_set.insert(e.canonical_name());
}
} else if !matches!(fun, Sum | Min | Max) {
return Ok(false);
}
fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result<bool> {
let mut fields_set = HashSet::new();
let mut aggregate_count = 0;
for expr in aggr_expr {
if let Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(fun),
distinct,
args,
filter,
order_by,
null_treatment: _,
}) = expr
{
if filter.is_some() || order_by.is_some() {
return Ok(false);
}
aggregate_count += 1;
if *distinct {
for e in args {
fields_set.insert(e.canonical_name());
}
} else if !matches!(fun, Sum | Min | Max) {
return Ok(false);
}
Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
}
_ => Ok(false),
}
Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
}

/// Check if the first expr is [Expr::GroupingSet].
Expand Down Expand Up @@ -131,162 +127,163 @@ impl OptimizerRule for SingleDistinctToGroupBy {
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>, DataFusionError> {
match &plan {
match plan {
LogicalPlan::Aggregate(Aggregate {
input,
aggr_expr,
schema,
group_expr,
..
}) => {
if is_single_distinct_agg(&plan)? && !contains_grouping_set(group_expr) {
// alias all original group_by exprs
let (mut inner_group_exprs, out_group_expr_with_alias): (
Vec<Expr>,
Vec<(Expr, Option<String>)>,
) = group_expr
.iter()
.enumerate()
.map(|(i, group_expr)| {
if let Expr::Column(_) = group_expr {
// For Column expressions we can use existing expression as is.
(group_expr.clone(), (group_expr.clone(), None))
} else {
// For complex expressions, write it as an alias, to be able to refer
// to it from parent operators successfully.
// Consider plan below.
//
// Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\
// --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
// ----TableScan: test [a:UInt32, b:UInt32, c:UInt32]
//
// The first aggregate (from the bottom) refers to the `test.a` column.
// The second aggregate refers to the `group_alias_0` column, which is a valid field in the first aggregate.
// If we were to write the plan above as below, without the alias:
//
// Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\
// --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
// ----TableScan: test [a:UInt32, b:UInt32, c:UInt32]
//
// The second aggregate refers to the `test.a + Int32(1)` expression. However, its input does not have a `test.a` expression in it.
let alias_str = format!("group_alias_{i}");
let alias_expr = group_expr.clone().alias(&alias_str);
let (qualifier, field) = schema.qualified_field(i);
}) if is_single_distinct_agg(&aggr_expr)?
&& !contains_grouping_set(&group_expr) =>
{
let group_size = group_expr.len();
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to capture `group_expr.len()` before `group_expr` is moved into the iterator by `into_iter()`.

// alias all original group_by exprs
let (mut inner_group_exprs, out_group_expr_with_alias): (
Vec<Expr>,
Vec<(Expr, Option<String>)>,
) = group_expr
.into_iter()
.enumerate()
.map(|(i, group_expr)| {
if let Expr::Column(_) = group_expr {
// For Column expressions we can use existing expression as is.
(group_expr.clone(), (group_expr, None))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One clone is still necessary here: the column expression is used both as the inner group expr and as the outer group expr.

} else {
// For complex expressions, write it as an alias, to be able to refer
// to it from parent operators successfully.
// Consider plan below.
//
// Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\
// --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
// ----TableScan: test [a:UInt32, b:UInt32, c:UInt32]
//
// The first aggregate (from the bottom) refers to the `test.a` column.
// The second aggregate refers to the `group_alias_0` column, which is a valid field in the first aggregate.
// If we were to write the plan above as below, without the alias:
//
// Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\
// --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
// ----TableScan: test [a:UInt32, b:UInt32, c:UInt32]
//
// The second aggregate refers to the `test.a + Int32(1)` expression. However, its input does not have a `test.a` expression in it.
let alias_str = format!("group_alias_{i}");
let (qualifier, field) = schema.qualified_field(i);
(
group_expr.alias(alias_str.clone()),
(
alias_expr,
(
col(alias_str),
Some(qualified_name(qualifier, field.name())),
),
)
}
})
.unzip();

// and they can be referenced by the alias in the outer aggr plan
let outer_group_exprs = out_group_expr_with_alias
.iter()
.map(|(out_group_expr, _)| out_group_expr.clone())
.collect::<Vec<_>>();

// replace the distinct arg with alias
let mut index = 1;
let mut group_fields_set = HashSet::new();
let mut inner_aggr_exprs = vec![];
let outer_aggr_exprs = aggr_expr
.iter()
.map(|aggr_expr| match aggr_expr {
Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(fun),
args,
distinct,
..
}) => {
// is_single_distinct_agg ensure args.len=1
if *distinct
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old code had redundant conditionals, so I refactored them into a single `if ... else ...`.

&& group_fields_set.insert(args[0].display_name()?)
{
inner_group_exprs.push(
args[0].clone().alias(SINGLE_DISTINCT_ALIAS),
);
col(alias_str),
Some(qualified_name(qualifier, field.name())),
),
)
}
})
.unzip();

// replace the distinct arg with alias
let mut index = 1;
let mut distinct_aggr_exprs = HashSet::new();
let mut inner_aggr_exprs = vec![];
let outer_aggr_exprs = aggr_expr
.into_iter()
.map(|aggr_expr| match aggr_expr {
Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(fun),
mut args,
distinct,
..
}) => {
if distinct {
debug_assert_eq!(
args.len(),
1,
"DISTINCT aggregate should have exactly one argument"
);
let arg = args.swap_remove(0);

let expr_id = distinct_aggr_exprs.hasher().hash_one(&arg);
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use `HashSet<u64>` with manual hashing (via `BuildHasher::hash_one`) instead of `HashSet<&Expr>` to avoid borrow checker issues.

if distinct_aggr_exprs.insert(expr_id) {
inner_group_exprs
.push(arg.alias(SINGLE_DISTINCT_ALIAS));
}

Expr::AggregateFunction(AggregateFunction::new(
fun,
vec![col(SINGLE_DISTINCT_ALIAS)],
false, // intentional to remove distinct here
None,
None,
None,
))
// if the aggregate function is not distinct, we need to rewrite it like two phase aggregation
if !(*distinct) {
index += 1;
let alias_str = format!("alias{}", index);
inner_aggr_exprs.push(
Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
args.clone(),
false,
None,
None,
None,
))
.alias(&alias_str),
);
Ok(Expr::AggregateFunction(AggregateFunction::new(
} else {
index += 1;
let alias_str = format!("alias{}", index);
inner_aggr_exprs.push(
Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
vec![col(&alias_str)],
args,
false,
None,
None,
None,
)))
} else {
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
vec![col(SINGLE_DISTINCT_ALIAS)],
false, // intentional to remove distinct here
None,
None,
None,
)))
}
}
_ => Ok(aggr_expr.clone()),
})
.collect::<Result<Vec<_>>>()?;

// construct the inner AggrPlan
let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
input.clone(),
inner_group_exprs,
inner_aggr_exprs,
)?);

// so the aggregates are displayed in the same way even after the rewrite
// this optimizer has two kinds of alias:
// - group_by aggr
// - aggr expr
let group_size = group_expr.len();
let alias_expr: Vec<_> = out_group_expr_with_alias
.into_iter()
.map(|(group_expr, original_field)| {
if let Some(name) = original_field {
group_expr.alias(name)
} else {
group_expr
))
.alias(&alias_str),
);
Expr::AggregateFunction(AggregateFunction::new(
fun,
vec![col(&alias_str)],
false,
None,
None,
None,
))
}
})
.chain(outer_aggr_exprs.iter().enumerate().map(|(idx, expr)| {
}
_ => aggr_expr,
})
.collect::<Vec<_>>();

// construct the inner AggrPlan
let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
input,
inner_group_exprs,
inner_aggr_exprs,
)?);

let outer_group_exprs = out_group_expr_with_alias
.iter()
.map(|(expr, _)| expr.clone())
.collect();

// so the aggregates are displayed in the same way even after the rewrite
// this optimizer has two kinds of alias:
// - group_by aggr
// - aggr expr
let alias_expr: Vec<_> = out_group_expr_with_alias
.into_iter()
.map(|(group_expr, original_field)| {
if let Some(name) = original_field {
group_expr.alias(name)
} else {
group_expr
}
})
.chain(outer_aggr_exprs.iter().cloned().enumerate().map(
|(idx, expr)| {
let idx = idx + group_size;
let (qualifier, field) = schema.qualified_field(idx);
let name = qualified_name(qualifier, field.name());
expr.clone().alias(name)
}))
.collect();

let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new(
Arc::new(inner_agg),
outer_group_exprs,
outer_aggr_exprs,
)?);
Ok(Transformed::yes(project(outer_aggr, alias_expr)?))
} else {
Ok(Transformed::no(plan))
}
expr.alias(name)
},
))
.collect();

let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new(
Arc::new(inner_agg),
outer_group_exprs,
outer_aggr_exprs,
)?);
Ok(Transformed::yes(project(outer_aggr, alias_expr)?))
}
_ => Ok(Transformed::no(plan)),
}
Expand Down