From 1fa25ae5d50c5f34f17e77e9f635f854ef5e7642 Mon Sep 17 00:00:00 2001 From: wiedld Date: Sun, 31 Mar 2024 05:09:06 -0700 Subject: [PATCH] fix(9870): common expression elimination optimization, should always re-find the correct expression during re-write. (#9871) * test(9870): reproducer of error with jumping traversal patterns in common-expr-elimination traversals * refactor: remove the IdArray ordered idx, since the idx ordering does not always stay in sync with the updated TreeNode traversal * refactor: use the only reproducible key (expr_identifier) for expr_set, while keeping the (stack-popped) symbol used for alias. * refactor: encapsulate most of the logic within ExprSet, and delineate the expr_identifier from the alias symbol * test(9870): demonstrate that the sqllogictests are now passing --- datafusion/expr/src/logical_plan/plan.rs | 2 +- .../optimizer/src/common_subexpr_eliminate.rs | 441 ++++++------------ datafusion/sqllogictest/test_files/expr.slt | 63 +++ 3 files changed, 214 insertions(+), 292 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 05d7ac539458..a1dc90dda0ab 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2368,7 +2368,7 @@ impl DistinctOn { /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] // mark non_exhaustive to encourage use of try_new/new() #[non_exhaustive] pub struct Aggregate { diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 0c9064d0641f..25c25c63f0b7 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -17,6 +17,7 @@ //! Eliminate common sub-expression. 
+use std::collections::hash_map::Entry; use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; @@ -35,37 +36,75 @@ use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; use datafusion_expr::{col, Expr, ExprSchemable}; -/// A map from expression's identifier to tuple including -/// - the expression itself (cloned) -/// - counter -/// - DataType of this expression. -type ExprSet = HashMap; +/// Set of expressions generated by the [`ExprIdentifierVisitor`] +/// and consumed by the [`CommonSubexprRewriter`]. +#[derive(Default)] +struct ExprSet { + /// A map from expression's identifier (stringified expr) to tuple including: + /// - the expression itself (cloned) + /// - counter + /// - DataType of this expression. + /// - symbol used as the identifier in the alias. + map: HashMap, +} -/// An ordered map of Identifiers assigned by `ExprIdentifierVisitor` in an -/// initial expression walk. -/// -/// Used by `CommonSubexprRewriter`, which rewrites the expressions to remove -/// common subexpressions. -/// -/// Elements in this array are created on the walk down the expression tree -/// during `f_down`. Thus element 0 is the root of the expression tree. The -/// tuple contains: -/// - series_number. -/// - Incremented during `f_up`, start from 1. -/// - Thus, items with higher idx have the lower series_number. -/// - [`Identifier`] -/// - Identifier of the expression. If empty (`""`), expr should not be considered for common elimination. 
-/// -/// # Example -/// An expression like `(a + b)` would have the following `IdArray`: -/// ```text -/// [ -/// (3, "a + b"), -/// (2, "a"), -/// (1, "b") -/// ] -/// ``` -type IdArray = Vec<(usize, Identifier)>; +impl ExprSet { + fn expr_identifier(expr: &Expr) -> Identifier { + format!("{expr}") + } + + fn get(&self, key: &Identifier) -> Option<&(Expr, usize, DataType, Identifier)> { + self.map.get(key) + } + + fn entry( + &mut self, + key: Identifier, + ) -> Entry<'_, Identifier, (Expr, usize, DataType, Identifier)> { + self.map.entry(key) + } + + fn populate_expr_set( + &mut self, + expr: &[Expr], + input_schema: DFSchemaRef, + expr_mask: ExprMask, + ) -> Result<()> { + expr.iter().try_for_each(|e| { + self.expr_to_identifier(e, Arc::clone(&input_schema), expr_mask)?; + + Ok(()) + }) + } + + /// Go through an expression tree and generate identifier for every node in this tree. + fn expr_to_identifier( + &mut self, + expr: &Expr, + input_schema: DFSchemaRef, + expr_mask: ExprMask, + ) -> Result<()> { + expr.visit(&mut ExprIdentifierVisitor { + expr_set: self, + input_schema, + visit_stack: vec![], + node_count: 0, + expr_mask, + })?; + + Ok(()) + } +} + +impl From> for ExprSet { + fn from(entries: Vec<(Identifier, (Expr, usize, DataType, Identifier))>) -> Self { + let mut expr_set = Self::default(); + entries.into_iter().for_each(|(k, v)| { + expr_set.map.insert(k, v); + }); + expr_set + } +} /// Identifier for each subexpression. 
/// @@ -112,21 +151,16 @@ impl CommonSubexprEliminate { fn rewrite_exprs_list( &self, exprs_list: &[&[Expr]], - arrays_list: &[&[Vec<(usize, String)>]], expr_set: &ExprSet, affected_id: &mut BTreeSet, ) -> Result>> { exprs_list .iter() - .zip(arrays_list.iter()) - .map(|(exprs, arrays)| { + .map(|exprs| { exprs .iter() .cloned() - .zip(arrays.iter()) - .map(|(expr, id_array)| { - replace_common_expr(expr, id_array, expr_set, affected_id) - }) + .map(|expr| replace_common_expr(expr, expr_set, affected_id)) .collect::>>() }) .collect::>>() @@ -135,7 +169,6 @@ impl CommonSubexprEliminate { fn rewrite_expr( &self, exprs_list: &[&[Expr]], - arrays_list: &[&[Vec<(usize, String)>]], input: &LogicalPlan, expr_set: &ExprSet, config: &dyn OptimizerConfig, @@ -143,7 +176,7 @@ impl CommonSubexprEliminate { let mut affected_id = BTreeSet::::new(); let rewrite_exprs = - self.rewrite_exprs_list(exprs_list, arrays_list, expr_set, &mut affected_id)?; + self.rewrite_exprs_list(exprs_list, expr_set, &mut affected_id)?; let mut new_input = self .try_optimize(input, config)? @@ -161,8 +194,7 @@ impl CommonSubexprEliminate { config: &dyn OptimizerConfig, ) -> Result { let mut window_exprs = vec![]; - let mut arrays_per_window = vec![]; - let mut expr_set = ExprSet::new(); + let mut expr_set = ExprSet::default(); // Get all window expressions inside the consecutive window operators. // Consecutive window expressions may refer to same complex expression. 
@@ -181,30 +213,18 @@ impl CommonSubexprEliminate { plan = input.as_ref().clone(); let input_schema = Arc::clone(input.schema()); - let arrays = - to_arrays(&window_expr, input_schema, &mut expr_set, ExprMask::Normal)?; + expr_set.populate_expr_set(&window_expr, input_schema, ExprMask::Normal)?; window_exprs.push(window_expr); - arrays_per_window.push(arrays); } let mut window_exprs = window_exprs .iter() .map(|expr| expr.as_slice()) .collect::>(); - let arrays_per_window = arrays_per_window - .iter() - .map(|arrays| arrays.as_slice()) - .collect::>(); - assert_eq!(window_exprs.len(), arrays_per_window.len()); - let (mut new_expr, new_input) = self.rewrite_expr( - &window_exprs, - &arrays_per_window, - &plan, - &expr_set, - config, - )?; + let (mut new_expr, new_input) = + self.rewrite_expr(&window_exprs, &plan, &expr_set, config)?; assert_eq!(window_exprs.len(), new_expr.len()); // Construct consecutive window operator, with their corresponding new window expressions. @@ -241,46 +261,36 @@ impl CommonSubexprEliminate { input, .. } = aggregate; - let mut expr_set = ExprSet::new(); + let mut expr_set = ExprSet::default(); - // rewrite inputs + // build expr_set, with groupby and aggr let input_schema = Arc::clone(input.schema()); - let group_arrays = to_arrays( + expr_set.populate_expr_set( group_expr, Arc::clone(&input_schema), - &mut expr_set, ExprMask::Normal, )?; - let aggr_arrays = - to_arrays(aggr_expr, input_schema, &mut expr_set, ExprMask::Normal)?; + expr_set.populate_expr_set(aggr_expr, input_schema, ExprMask::Normal)?; - let (mut new_expr, new_input) = self.rewrite_expr( - &[group_expr, aggr_expr], - &[&group_arrays, &aggr_arrays], - input, - &expr_set, - config, - )?; + // rewrite inputs + let (mut new_expr, new_input) = + self.rewrite_expr(&[group_expr, aggr_expr], input, &expr_set, config)?; // note the reversed pop order. 
let new_aggr_expr = pop_expr(&mut new_expr)?; let new_group_expr = pop_expr(&mut new_expr)?; // create potential projection on top - let mut expr_set = ExprSet::new(); + let mut expr_set = ExprSet::default(); let new_input_schema = Arc::clone(new_input.schema()); - let aggr_arrays = to_arrays( + expr_set.populate_expr_set( &new_aggr_expr, new_input_schema.clone(), - &mut expr_set, ExprMask::NormalAndAggregates, )?; + let mut affected_id = BTreeSet::::new(); - let mut rewritten = self.rewrite_exprs_list( - &[&new_aggr_expr], - &[&aggr_arrays], - &expr_set, - &mut affected_id, - )?; + let mut rewritten = + self.rewrite_exprs_list(&[&new_aggr_expr], &expr_set, &mut affected_id)?; let rewritten = pop_expr(&mut rewritten)?; if affected_id.is_empty() { @@ -300,9 +310,9 @@ impl CommonSubexprEliminate { for id in affected_id { match expr_set.get(&id) { - Some((expr, _, _)) => { + Some((expr, _, _, symbol)) => { // todo: check `nullable` - agg_exprs.push(expr.clone().alias(&id)); + agg_exprs.push(expr.clone().alias(symbol.as_str())); } _ => { return internal_err!("expr_set invalid state"); @@ -320,9 +330,7 @@ impl CommonSubexprEliminate { agg_exprs.push(expr.alias(&name)); proj_exprs.push(Expr::Column(Column::from_name(name))); } else { - let id = ExprIdentifierVisitor::<'static>::expr_identifier( - &expr_rewritten, - ); + let id = ExprSet::expr_identifier(&expr_rewritten); let out_name = expr_rewritten.to_field(&new_input_schema)?.qualified_name(); agg_exprs.push(expr_rewritten.alias(&id)); @@ -356,13 +364,13 @@ impl CommonSubexprEliminate { let inputs = plan.inputs(); let input = inputs[0]; let input_schema = Arc::clone(input.schema()); - let mut expr_set = ExprSet::new(); + let mut expr_set = ExprSet::default(); // Visit expr list and build expr identifier to occuring count map (`expr_set`). 
- let arrays = to_arrays(&expr, input_schema, &mut expr_set, ExprMask::Normal)?; + expr_set.populate_expr_set(&expr, input_schema, ExprMask::Normal)?; let (mut new_expr, new_input) = - self.rewrite_expr(&[&expr], &[&arrays], input, &expr_set, config)?; + self.rewrite_expr(&[&expr], input, &expr_set, config)?; plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input]) } @@ -448,28 +456,6 @@ fn pop_expr(new_expr: &mut Vec>) -> Result> { .ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string())) } -fn to_arrays( - expr: &[Expr], - input_schema: DFSchemaRef, - expr_set: &mut ExprSet, - expr_mask: ExprMask, -) -> Result>> { - expr.iter() - .map(|e| { - let mut id_array = vec![]; - expr_to_identifier( - e, - expr_set, - &mut id_array, - Arc::clone(&input_schema), - expr_mask, - )?; - - Ok(id_array) - }) - .collect::>>() -} - /// Build the "intermediate" projection plan that evaluates the extracted common expressions. fn build_common_expr_project_plan( input: LogicalPlan, @@ -481,11 +467,11 @@ fn build_common_expr_project_plan( for id in affected_id { match expr_set.get(&id) { - Some((expr, _, data_type)) => { + Some((expr, _, data_type, symbol)) => { // todo: check `nullable` let field = DFField::new_unqualified(&id, data_type.clone(), true); fields_set.insert(field.name().to_owned()); - project_exprs.push(expr.clone().alias(&id)); + project_exprs.push(expr.clone().alias(symbol.as_str())); } _ => { return internal_err!("expr_set invalid state"); @@ -601,8 +587,6 @@ impl ExprMask { struct ExprIdentifierVisitor<'a> { // param expr_set: &'a mut ExprSet, - /// series number (usize) and identifier. - id_array: &'a mut IdArray, /// input schema for the node that we're optimizing, so we can determine the correct datatype /// for each subexpression input_schema: DFSchemaRef, @@ -610,8 +594,6 @@ struct ExprIdentifierVisitor<'a> { visit_stack: Vec, /// increased in fn_down, start from 0. node_count: usize, - /// increased in fn_up, start from 1. 
- series_number: usize, /// which expression should be skipped? expr_mask: ExprMask, } @@ -628,10 +610,6 @@ enum VisitRecord { } impl ExprIdentifierVisitor<'_> { - fn expr_identifier(expr: &Expr) -> Identifier { - format!("{expr}") - } - /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` /// before it. fn pop_enter_mark(&mut self) -> (usize, Identifier) { @@ -655,9 +633,6 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { type Node = Expr; fn f_down(&mut self, expr: &Expr) -> Result { - // put placeholder, sets the proper array length - self.id_array.push((0, "".to_string())); - // related to https://github.com/apache/arrow-datafusion/issues/8814 // If the expr contain volatile expression or is a short-circuit expression, skip it. if expr.short_circuits() || is_volatile_expression(expr)? { @@ -674,70 +649,38 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { } fn f_up(&mut self, expr: &Expr) -> Result { - self.series_number += 1; - - let (idx, sub_expr_identifier) = self.pop_enter_mark(); + let (_idx, sub_expr_identifier) = self.pop_enter_mark(); // skip exprs should not be recognize. 
if self.expr_mask.ignores(expr) { - let curr_expr_identifier = Self::expr_identifier(expr); + let curr_expr_identifier = ExprSet::expr_identifier(expr); self.visit_stack .push(VisitRecord::ExprItem(curr_expr_identifier)); - self.id_array[idx].0 = self.series_number; // leave Identifer as empty "", since will not use as common expr return Ok(TreeNodeRecursion::Continue); } - let mut desc = Self::expr_identifier(expr); - desc.push_str(&sub_expr_identifier); + let curr_expr_identifier = ExprSet::expr_identifier(expr); + let alias_symbol = format!("{curr_expr_identifier}{sub_expr_identifier}"); - self.id_array[idx] = (self.series_number, desc.clone()); - self.visit_stack.push(VisitRecord::ExprItem(desc.clone())); + self.visit_stack + .push(VisitRecord::ExprItem(alias_symbol.clone())); let data_type = expr.get_type(&self.input_schema)?; self.expr_set - .entry(desc) - .or_insert_with(|| (expr.clone(), 0, data_type)) + .entry(curr_expr_identifier) + .or_insert_with(|| (expr.clone(), 0, data_type, alias_symbol)) .1 += 1; Ok(TreeNodeRecursion::Continue) } } -/// Go through an expression tree and generate identifier for every node in this tree. -fn expr_to_identifier( - expr: &Expr, - expr_set: &mut ExprSet, - id_array: &mut Vec<(usize, Identifier)>, - input_schema: DFSchemaRef, - expr_mask: ExprMask, -) -> Result<()> { - expr.visit(&mut ExprIdentifierVisitor { - expr_set, - id_array, - input_schema, - visit_stack: vec![], - node_count: 0, - series_number: 0, - expr_mask, - })?; - - Ok(()) -} - /// Rewrite expression by replacing detected common sub-expression with /// the corresponding temporary column name. That column contains the /// evaluate result of replaced expression. struct CommonSubexprRewriter<'a> { expr_set: &'a ExprSet, - id_array: &'a IdArray, /// Which identifier is replaced. affected_id: &'a mut BTreeSet, - - /// the max series number we have rewritten. 
Other expression nodes - /// with smaller series number is already replaced and shouldn't - /// do anything with them. - max_series_number: usize, - /// current node's information's index in `id_array`. - curr_index: usize, } impl TreeNodeRewriter for CommonSubexprRewriter<'_> { @@ -751,80 +694,41 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } - let (series_number, curr_id) = &self.id_array[self.curr_index]; - - // halting conditions - if self.curr_index >= self.id_array.len() - || self.max_series_number > *series_number - { - return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); - } - - // skip `Expr`s without identifier (empty identifier). - if curr_id.is_empty() { - self.curr_index += 1; // incr idx for id_array, when not jumping - return Ok(Transformed::no(expr)); - } + let curr_id = &ExprSet::expr_identifier(&expr); // lookup previously visited expression match self.expr_set.get(curr_id) { - Some((_, counter, _)) => { + Some((_, counter, _, symbol)) => { // if has a commonly used (a.k.a. 1+ use) expr if *counter > 1 { self.affected_id.insert(curr_id.clone()); - // This expr tree is finished. - if self.curr_index >= self.id_array.len() { - return Ok(Transformed::new( - expr, - false, - TreeNodeRecursion::Jump, - )); - } - - // incr idx for id_array, when not jumping - self.curr_index += 1; - - // series_number was the inverse number ordering (when doing f_up) - self.max_series_number = *series_number; - // step index to skip all sub-node (which has smaller series number). - while self.curr_index < self.id_array.len() - && *series_number > self.id_array[self.curr_index].0 - { - self.curr_index += 1; - } - let expr_name = expr.display_name()?; // Alias this `Column` expr to it original "expr name", // `projection_push_down` optimizer use "expr name" to eliminate useless // projections. 
Ok(Transformed::new( - col(curr_id).alias(expr_name), + col(symbol).alias(expr_name), true, TreeNodeRecursion::Jump, )) } else { - self.curr_index += 1; Ok(Transformed::no(expr)) } } - _ => internal_err!("expr_set invalid state"), + None => Ok(Transformed::no(expr)), } } } fn replace_common_expr( expr: Expr, - id_array: &IdArray, expr_set: &ExprSet, affected_id: &mut BTreeSet, ) -> Result { expr.rewrite(&mut CommonSubexprRewriter { expr_set, - id_array, affected_id, - max_series_number: 0, - curr_index: 0, }) .data() } @@ -860,73 +764,6 @@ mod test { assert_eq!(expected, formatted_plan); } - #[test] - fn id_array_visitor() -> Result<()> { - let expr = ((sum(col("a") + lit(1))) - avg(col("c"))) * lit(2); - - let schema = Arc::new(DFSchema::new_with_metadata( - vec![ - DFField::new_unqualified("a", DataType::Int64, false), - DFField::new_unqualified("c", DataType::Int64, false), - ], - Default::default(), - )?); - - // skip aggregates - let mut id_array = vec![]; - expr_to_identifier( - &expr, - &mut HashMap::new(), - &mut id_array, - Arc::clone(&schema), - ExprMask::Normal, - )?; - - let expected = vec![ - (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"), - (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"), - (4, ""), - (3, "a + Int32(1)Int32(1)a"), - (1, ""), - (2, ""), - (6, ""), - (5, ""), - (8, "") - ] - .into_iter() - .map(|(number, id)| (number, id.into())) - .collect::>(); - assert_eq!(expected, id_array); - - // include aggregates - let mut id_array = vec![]; - expr_to_identifier( - &expr, - &mut HashMap::new(), - &mut id_array, - Arc::clone(&schema), - ExprMask::NormalAndAggregates, - )?; - - let expected = vec![ - (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), - (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), - (4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"), - (3, "a + 
Int32(1)Int32(1)a"), - (1, ""), - (2, ""), - (6, "AVG(c)c"), - (5, ""), - (8, "") - ] - .into_iter() - .map(|(number, id)| (number, id.into())) - .collect::>(); - assert_eq!(expected, id_array); - - Ok(()) - } - #[test] fn tpch_q1_simplified() -> Result<()> { // SQL: @@ -1171,24 +1008,28 @@ mod test { let table_scan = test_table_scan().unwrap(); let affected_id: BTreeSet = ["c+a".to_string(), "b+a".to_string()].into_iter().collect(); - let expr_set_1 = [ + let expr_set_1 = vec![ ( "c+a".to_string(), - (col("c") + col("a"), 1, DataType::UInt32), + (col("c") + col("a"), 1, DataType::UInt32, "c+a".to_string()), ), ( "b+a".to_string(), - (col("b") + col("a"), 1, DataType::UInt32), + (col("b") + col("a"), 1, DataType::UInt32, "b+a".to_string()), ), ] - .into_iter() - .collect(); - let expr_set_2 = [ - ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)), - ("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)), + .into(); + let expr_set_2 = vec![ + ( + "c+a".to_string(), + (col("c+a"), 1, DataType::UInt32, "c+a".to_string()), + ), + ( + "b+a".to_string(), + (col("b+a"), 1, DataType::UInt32, "b+a".to_string()), + ), ] - .into_iter() - .collect(); + .into(); let project = build_common_expr_project_plan(table_scan, affected_id.clone(), &expr_set_1) .unwrap(); @@ -1214,30 +1055,48 @@ mod test { ["test1.c+test1.a".to_string(), "test1.b+test1.a".to_string()] .into_iter() .collect(); - let expr_set_1 = [ + let expr_set_1 = vec![ ( "test1.c+test1.a".to_string(), - (col("test1.c") + col("test1.a"), 1, DataType::UInt32), + ( + col("test1.c") + col("test1.a"), + 1, + DataType::UInt32, + "test1.c+test1.a".to_string(), + ), ), ( "test1.b+test1.a".to_string(), - (col("test1.b") + col("test1.a"), 1, DataType::UInt32), + ( + col("test1.b") + col("test1.a"), + 1, + DataType::UInt32, + "test1.b+test1.a".to_string(), + ), ), ] - .into_iter() - .collect(); - let expr_set_2 = [ + .into(); + let expr_set_2 = vec![ ( "test1.c+test1.a".to_string(), - (col("test1.c+test1.a"), 1, 
DataType::UInt32), + ( + col("test1.c+test1.a"), + 1, + DataType::UInt32, + "test1.c+test1.a".to_string(), + ), ), ( "test1.b+test1.a".to_string(), - (col("test1.b+test1.a"), 1, DataType::UInt32), + ( + col("test1.b+test1.a"), + 1, + DataType::UInt32, + "test1.b+test1.a".to_string(), + ), ), ] - .into_iter() - .collect(); + .into(); let project = build_common_expr_project_plan(join, affected_id.clone(), &expr_set_1) .unwrap(); diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 75bcbc07755b..2e0cbf50cab9 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -2262,3 +2262,66 @@ query RRR rowsort select f64, case when f64 > 0 then 1.0 / f64 else null end, acos(case when f64 > 0 then 1.0 / f64 else null end) from doubles; ---- 10.1 0.09900990099 1.471623942989 + + +statement ok +CREATE TABLE t1( + time TIMESTAMP, + load1 DOUBLE, + load2 DOUBLE, + host VARCHAR +) AS VALUES + (to_timestamp_nanos(1527018806000000000), 1.1, 101, 'host1'), + (to_timestamp_nanos(1527018806000000000), 2.2, 202, 'host2'), + (to_timestamp_nanos(1527018806000000000), 3.3, 303, 'host3'), + (to_timestamp_nanos(1527018806000000000), 1.1, 101, NULL) +; + +# struct scalar function with columns +query ? 
+select struct(time,load1,load2,host) from t1; +---- +{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: host1} +{c0: 2018-05-22T19:53:26, c1: 2.2, c2: 202.0, c3: host2} +{c0: 2018-05-22T19:53:26, c1: 3.3, c2: 303.0, c3: host3} +{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: } + +# can have an aggregate function with an inner coalesce +query TR +select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +---- +host1 1.1 +host2 2.2 +host3 3.3 + +# can have an aggregate function with an inner CASE WHEN +query TR +select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +---- +host1 101 +host2 202 +host3 303 + +# can have 2 projections with aggr(short_circuited), with different short-circuited expr +query TRR +select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +---- +host1 1.1 101 +host2 2.2 202 +host3 3.3 303 + +# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. 
CASE WHEN) +query TRR +select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +---- +host1 1.1 101 +host2 2.2 202 +host3 3.3 303 + +# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. coalesce) +query TRR +select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +---- +host1 1.1 101 +host2 2.2 202 +host3 3.3 303