diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 750413e579..b0f95d971c 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use common_error::DaftResult; use daft_core::prelude::*; use daft_dsl::{ col, @@ -96,30 +97,7 @@ impl SQLPlanner { } fn plan_query(&mut self, query: &Query) -> SQLPlannerResult { - if let Some(with) = &query.with { - unsupported_sql_err!("WITH: {with}") - } - if !query.limit_by.is_empty() { - unsupported_sql_err!("LIMIT BY"); - } - if query.offset.is_some() { - unsupported_sql_err!("OFFSET"); - } - if query.fetch.is_some() { - unsupported_sql_err!("FETCH"); - } - if !query.locks.is_empty() { - unsupported_sql_err!("LOCKS"); - } - if let Some(for_clause) = &query.for_clause { - unsupported_sql_err!("{for_clause}"); - } - if query.settings.is_some() { - unsupported_sql_err!("SETTINGS"); - } - if let Some(format_clause) = &query.format_clause { - unsupported_sql_err!("{format_clause}"); - } + check_query_features(query)?; let selection = query.body.as_select().ok_or_else(|| { PlannerError::invalid_operation(format!( @@ -128,91 +106,7 @@ impl SQLPlanner { )) })?; - self.plan_select(selection)?; - - if let Some(order_by) = &query.order_by { - if order_by.interpolate.is_some() { - unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]"); - } - // TODO: if ordering by a column not in the projection, this will fail. - let (exprs, descending) = self.plan_order_by_exprs(order_by.exprs.as_slice())?; - let rel = self.relation_mut(); - rel.inner = rel.inner.sort(exprs, descending)?; - } - - if let Some(limit) = &query.limit { - let limit = self.plan_expr(limit)?; - if let Expr::Literal(LiteralValue::Int64(limit)) = limit.as_ref() { - let rel = self.relation_mut(); - rel.inner = rel.inner.limit(*limit, true)?; // TODO: Should this be eager or not? 
- } else { - invalid_operation_err!( - "LIMIT must be a constant integer, instead got: {limit}" - ); - } - } - - Ok(self.current_relation.clone().unwrap().inner) - } - - fn plan_order_by_exprs( - &self, - expr: &[sqlparser::ast::OrderByExpr], - ) -> SQLPlannerResult<(Vec, Vec)> { - let mut exprs = Vec::with_capacity(expr.len()); - let mut desc = Vec::with_capacity(expr.len()); - for order_by_expr in expr { - if order_by_expr.nulls_first.is_some() { - unsupported_sql_err!("NULLS FIRST"); - } - if order_by_expr.with_fill.is_some() { - unsupported_sql_err!("WITH FILL"); - } - let expr = self.plan_expr(&order_by_expr.expr)?; - desc.push(!order_by_expr.asc.unwrap_or(true)); - - exprs.push(expr); - } - Ok((exprs, desc)) - } - - fn plan_select(&mut self, selection: &sqlparser::ast::Select) -> SQLPlannerResult<()> { - if selection.top.is_some() { - unsupported_sql_err!("TOP"); - } - if selection.distinct.is_some() { - unsupported_sql_err!("DISTINCT"); - } - if selection.into.is_some() { - unsupported_sql_err!("INTO"); - } - if !selection.lateral_views.is_empty() { - unsupported_sql_err!("LATERAL"); - } - if selection.prewhere.is_some() { - unsupported_sql_err!("PREWHERE"); - } - if !selection.cluster_by.is_empty() { - unsupported_sql_err!("CLUSTER BY"); - } - if !selection.distribute_by.is_empty() { - unsupported_sql_err!("DISTRIBUTE BY"); - } - if !selection.sort_by.is_empty() { - unsupported_sql_err!("SORT BY"); - } - if selection.having.is_some() { - unsupported_sql_err!("HAVING"); - } - if !selection.named_window.is_empty() { - unsupported_sql_err!("WINDOW"); - } - if selection.qualify.is_some() { - unsupported_sql_err!("QUALIFY"); - } - if selection.connect_by.is_some() { - unsupported_sql_err!("CONNECT BY"); - } + check_select_features(selection)?; // FROM/JOIN let from = selection.clone().from; @@ -246,18 +140,23 @@ impl SQLPlanner { } } - let to_select = selection + // split the selection into the groupby expressions and the rest + let (groupby_selection, 
to_select) = selection .projection .iter() .map(|expr| self.select_item_to_expr(expr)) .collect::>>()? .into_iter() .flatten() - .collect::>(); + .partition::, _>(|expr| { + groupby_exprs + .iter() + .any(|e| expr.input_mapping() == e.input_mapping()) + }); if !groupby_exprs.is_empty() { let rel = self.relation_mut(); - rel.inner = rel.inner.aggregate(to_select, groupby_exprs)?; + rel.inner = rel.inner.aggregate(to_select, groupby_exprs.clone())?; } else if !to_select.is_empty() { let rel = self.relation_mut(); let has_aggs = to_select.iter().any(has_agg); @@ -268,7 +167,86 @@ impl SQLPlanner { } } - Ok(()) + if let Some(order_by) = &query.order_by { + if order_by.interpolate.is_some() { + unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]"); + } + // TODO: if ordering by a column not in the projection, this will fail. + let (exprs, descending) = self.plan_order_by_exprs(order_by.exprs.as_slice())?; + let rel = self.relation_mut(); + rel.inner = rel.inner.sort(exprs, descending)?; + } + + // Properly apply or remove the groupby columns from the selection + // This needs to be done after the orderby + // otherwise, the orderby will not be able to reference the grouping columns + // + // ex: SELECT sum(a) as sum_a, max(a) as max_a, b as c FROM table GROUP BY b + // + // The groupby columns are [b] + // the evaluation of sum(a) and max(a) are already handled by the earlier aggregate, + // so our projection is [sum_a, max_a, (b as c)] + // leaving us to handle (b as c) + // + // we filter for the columns in the schema that are not in the groupby keys, + // [sum_a, max_a, b] -> [sum_a, max_a] + // + // Then we add the groupby columns back in with the correct expressions + // this gives us the final projection: [sum_a, max_a, (b as c)] + if !groupby_exprs.is_empty() { + let rel = self.relation_mut(); + let schema = rel.inner.schema(); + + let groupby_keys = groupby_exprs + .iter() + .map(|e| Ok(e.to_field(&schema)?.name)) + .collect::>>()?; + + let selection_colums = 
schema + .exclude(groupby_keys.as_ref())? + .names() + .iter() + .map(|n| col(n.as_str())) + .chain(groupby_selection) + .collect(); + + rel.inner = rel.inner.select(selection_colums)?; + } + + if let Some(limit) = &query.limit { + let limit = self.plan_expr(limit)?; + if let Expr::Literal(LiteralValue::Int64(limit)) = limit.as_ref() { + let rel = self.relation_mut(); + rel.inner = rel.inner.limit(*limit, true)?; // TODO: Should this be eager or not? + } else { + invalid_operation_err!( + "LIMIT must be a constant integer, instead got: {limit}" + ); + } + } + + Ok(self.current_relation.clone().unwrap().inner) + } + + fn plan_order_by_exprs( + &self, + expr: &[sqlparser::ast::OrderByExpr], + ) -> SQLPlannerResult<(Vec, Vec)> { + let mut exprs = Vec::with_capacity(expr.len()); + let mut desc = Vec::with_capacity(expr.len()); + for order_by_expr in expr { + if order_by_expr.nulls_first.is_some() { + unsupported_sql_err!("NULLS FIRST"); + } + if order_by_expr.with_fill.is_some() { + unsupported_sql_err!("WITH FILL"); + } + let expr = self.plan_expr(&order_by_expr.expr)?; + desc.push(!order_by_expr.asc.unwrap_or(true)); + + exprs.push(expr); + } + Ok((exprs, desc)) } fn plan_from(&self, from: &TableWithJoins) -> SQLPlannerResult { @@ -952,6 +930,89 @@ impl SQLPlanner { } } +/// Checks if the SQL query is valid syntax and doesn't use unsupported features. +/// This function examines various clauses and options in the provided [sqlparser::ast::Query] +/// and returns an error if any unsupported features are encountered. 
+fn check_query_features(query: &sqlparser::ast::Query) -> SQLPlannerResult<()> { + if let Some(with) = &query.with { + unsupported_sql_err!("WITH: {with}") + } + if !query.limit_by.is_empty() { + unsupported_sql_err!("LIMIT BY"); + } + if query.offset.is_some() { + unsupported_sql_err!("OFFSET"); + } + if query.fetch.is_some() { + unsupported_sql_err!("FETCH"); + } + if !query.locks.is_empty() { + unsupported_sql_err!("LOCKS"); + } + if let Some(for_clause) = &query.for_clause { + unsupported_sql_err!("{for_clause}"); + } + if query.settings.is_some() { + unsupported_sql_err!("SETTINGS"); + } + if let Some(format_clause) = &query.format_clause { + unsupported_sql_err!("{format_clause}"); + } + Ok(()) +} + +/// Checks if the features used in the SQL SELECT statement are supported. +/// +/// This function examines various clauses and options in the provided [sqlparser::ast::Select] +/// and returns an error if any unsupported features are encountered. +/// +/// # Arguments +/// +/// * `selection` - A reference to the [sqlparser::ast::Select] to be checked. +/// +/// # Returns +/// +/// * `SQLPlannerResult<()>` - Ok(()) if all features are supported, or an error describing +/// the first unsupported feature encountered. 
+fn check_select_features(selection: &sqlparser::ast::Select) -> SQLPlannerResult<()> { + if selection.top.is_some() { + unsupported_sql_err!("TOP"); + } + if selection.distinct.is_some() { + unsupported_sql_err!("DISTINCT"); + } + if selection.into.is_some() { + unsupported_sql_err!("INTO"); + } + if !selection.lateral_views.is_empty() { + unsupported_sql_err!("LATERAL"); + } + if selection.prewhere.is_some() { + unsupported_sql_err!("PREWHERE"); + } + if !selection.cluster_by.is_empty() { + unsupported_sql_err!("CLUSTER BY"); + } + if !selection.distribute_by.is_empty() { + unsupported_sql_err!("DISTRIBUTE BY"); + } + if !selection.sort_by.is_empty() { + unsupported_sql_err!("SORT BY"); + } + if selection.having.is_some() { + unsupported_sql_err!("HAVING"); + } + if !selection.named_window.is_empty() { + unsupported_sql_err!("WINDOW"); + } + if selection.qualify.is_some() { + unsupported_sql_err!("QUALIFY"); + } + if selection.connect_by.is_some() { + unsupported_sql_err!("CONNECT BY"); + } + Ok(()) +} pub fn sql_expr>(s: S) -> SQLPlannerResult { let planner = SQLPlanner::default(); diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index 662e4d8b14..dd7ac1fc54 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -111,8 +111,26 @@ def test_sql_global_agg(): def test_sql_groupby_agg(): df = daft.from_pydict({"n": [1, 1, 2, 2], "v": [1, 2, 3, 4]}) catalog = SQLCatalog({"test": df}) - df = daft.sql("SELECT sum(v) FROM test GROUP BY n ORDER BY n", catalog=catalog) - assert df.collect().to_pydict() == {"n": [1, 2], "v": [3, 7]} + actual = daft.sql("SELECT sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog) + assert actual.collect().to_pydict() == {"sum": [3, 7]} + + # test with grouping column + actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog) + assert actual.collect().to_pydict() == {"n": [1, 2], "sum": [3, 7]} + + # test with multiple columns + actual = daft.sql("SELECT max(v) as max, 
sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog) + assert actual.collect().to_pydict() == {"max": [2, 4], "sum": [3, 7]} + + # test with aliased grouping key + actual = daft.sql("SELECT n as n_alias, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog) + assert actual.collect().to_pydict() == {"n_alias": [1, 2], "sum": [3, 7]} + + actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY -n", catalog=catalog) + assert actual.collect().to_pydict() == {"n": [2, 1], "sum": [7, 3]} + + actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY sum", catalog=catalog) + assert actual.collect().to_pydict() == {"n": [1, 2], "sum": [3, 7]} def test_sql_count_star():