Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG]: Sql groupby fix #2843

Merged
merged 15 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 174 additions & 113 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
error::{PlannerError, SQLPlannerResult},
invalid_operation_err, table_not_found_err, unsupported_sql_err,
};
use common_error::DaftResult;
use daft_core::prelude::*;
use daft_dsl::{
col,
Expand Down Expand Up @@ -96,30 +97,7 @@ impl SQLPlanner {
}

fn plan_query(&mut self, query: &Query) -> SQLPlannerResult<LogicalPlanBuilder> {
if let Some(with) = &query.with {
unsupported_sql_err!("WITH: {with}")
}
if !query.limit_by.is_empty() {
unsupported_sql_err!("LIMIT BY");
}
if query.offset.is_some() {
unsupported_sql_err!("OFFSET");
}
if query.fetch.is_some() {
unsupported_sql_err!("FETCH");
}
if !query.locks.is_empty() {
unsupported_sql_err!("LOCKS");
}
if let Some(for_clause) = &query.for_clause {
unsupported_sql_err!("{for_clause}");
}
if query.settings.is_some() {
unsupported_sql_err!("SETTINGS");
}
if let Some(format_clause) = &query.format_clause {
unsupported_sql_err!("{format_clause}");
}
check_query_features(query)?;

let selection = query.body.as_select().ok_or_else(|| {
PlannerError::invalid_operation(format!(
Expand All @@ -128,91 +106,7 @@ impl SQLPlanner {
))
})?;

self.plan_select(selection)?;

if let Some(order_by) = &query.order_by {
if order_by.interpolate.is_some() {
unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]");
}
// TODO: if ordering by a column not in the projection, this will fail.
let (exprs, descending) = self.plan_order_by_exprs(order_by.exprs.as_slice())?;
let rel = self.relation_mut();
rel.inner = rel.inner.sort(exprs, descending)?;
}

if let Some(limit) = &query.limit {
let limit = self.plan_expr(limit)?;
if let Expr::Literal(LiteralValue::Int64(limit)) = limit.as_ref() {
let rel = self.relation_mut();
rel.inner = rel.inner.limit(*limit, true)?; // TODO: Should this be eager or not?
} else {
invalid_operation_err!(
"LIMIT <n> must be a constant integer, instead got: {limit}"
);
}
}

Ok(self.current_relation.clone().unwrap().inner)
}

fn plan_order_by_exprs(
&self,
expr: &[sqlparser::ast::OrderByExpr],
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<bool>)> {
let mut exprs = Vec::with_capacity(expr.len());
let mut desc = Vec::with_capacity(expr.len());
for order_by_expr in expr {
if order_by_expr.nulls_first.is_some() {
unsupported_sql_err!("NULLS FIRST");
}
if order_by_expr.with_fill.is_some() {
unsupported_sql_err!("WITH FILL");
}
let expr = self.plan_expr(&order_by_expr.expr)?;
desc.push(!order_by_expr.asc.unwrap_or(true));

exprs.push(expr);
}
Ok((exprs, desc))
}

fn plan_select(&mut self, selection: &sqlparser::ast::Select) -> SQLPlannerResult<()> {
if selection.top.is_some() {
unsupported_sql_err!("TOP");
}
if selection.distinct.is_some() {
unsupported_sql_err!("DISTINCT");
}
if selection.into.is_some() {
unsupported_sql_err!("INTO");
}
if !selection.lateral_views.is_empty() {
unsupported_sql_err!("LATERAL");
}
if selection.prewhere.is_some() {
unsupported_sql_err!("PREWHERE");
}
if !selection.cluster_by.is_empty() {
unsupported_sql_err!("CLUSTER BY");
}
if !selection.distribute_by.is_empty() {
unsupported_sql_err!("DISTRIBUTE BY");
}
if !selection.sort_by.is_empty() {
unsupported_sql_err!("SORT BY");
}
if selection.having.is_some() {
unsupported_sql_err!("HAVING");
}
if !selection.named_window.is_empty() {
unsupported_sql_err!("WINDOW");
}
if selection.qualify.is_some() {
unsupported_sql_err!("QUALIFY");
}
if selection.connect_by.is_some() {
unsupported_sql_err!("CONNECT BY");
}
check_select_features(selection)?;

// FROM/JOIN
let from = selection.clone().from;
Expand Down Expand Up @@ -246,18 +140,23 @@ impl SQLPlanner {
}
}

let to_select = selection
// split the selection into the groupby expressions and the rest
let (groupby_selection, to_select) = selection
.projection
.iter()
.map(|expr| self.select_item_to_expr(expr))
.collect::<SQLPlannerResult<Vec<_>>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();
.partition::<Vec<_>, _>(|expr| {
groupby_exprs
.iter()
.any(|e| expr.input_mapping() == e.input_mapping())
});

if !groupby_exprs.is_empty() {
let rel = self.relation_mut();
rel.inner = rel.inner.aggregate(to_select, groupby_exprs)?;
rel.inner = rel.inner.aggregate(to_select, groupby_exprs.clone())?;
} else if !to_select.is_empty() {
let rel = self.relation_mut();
let has_aggs = to_select.iter().any(has_agg);
Expand All @@ -268,7 +167,86 @@ impl SQLPlanner {
}
}

Ok(())
if let Some(order_by) = &query.order_by {
if order_by.interpolate.is_some() {
unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]");
}
// TODO: if ordering by a column not in the projection, this will fail.
let (exprs, descending) = self.plan_order_by_exprs(order_by.exprs.as_slice())?;
let rel = self.relation_mut();
rel.inner = rel.inner.sort(exprs, descending)?;
}

// Properly apply or remove the groupby columns from the selection
// This needs to be done after the orderby
// otherwise, the orderby will not be able to reference the grouping columns
//
// ex: SELECT sum(a) as sum_a, max(a) as max_a, b as c FROM table GROUP BY b
//
// The groupby columns are [b]
// the evaluation of sum(a) and max(a) are already handled by the earlier aggregate,
// so our projection is [sum_a, max_a, (b as c)]
// leaving us to handle (b as c)
//
// we filter for the columns in the schema that are not in the groupby keys,
// [sum_a, max_a, b] -> [sum_a, max_a]
//
// Then we add the groupby columns back in with the correct expressions
// this gives us the final projection: [sum_a, max_a, (b as c)]
if !groupby_exprs.is_empty() {
let rel = self.relation_mut();
let schema = rel.inner.schema();

let groupby_keys = groupby_exprs
.iter()
.map(|e| Ok(e.to_field(&schema)?.name))
.collect::<DaftResult<Vec<_>>>()?;

let selection_colums = schema
.exclude(groupby_keys.as_ref())?
.names()
.iter()
.map(|n| col(n.as_str()))
.chain(groupby_selection)
.collect();

rel.inner = rel.inner.select(selection_colums)?;
}

if let Some(limit) = &query.limit {
let limit = self.plan_expr(limit)?;
if let Expr::Literal(LiteralValue::Int64(limit)) = limit.as_ref() {
let rel = self.relation_mut();
rel.inner = rel.inner.limit(*limit, true)?; // TODO: Should this be eager or not?
} else {
invalid_operation_err!(
"LIMIT <n> must be a constant integer, instead got: {limit}"
);
}
}

Ok(self.current_relation.clone().unwrap().inner)
}

fn plan_order_by_exprs(
&self,
expr: &[sqlparser::ast::OrderByExpr],
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<bool>)> {
let mut exprs = Vec::with_capacity(expr.len());
let mut desc = Vec::with_capacity(expr.len());
for order_by_expr in expr {
if order_by_expr.nulls_first.is_some() {
unsupported_sql_err!("NULLS FIRST");
}
if order_by_expr.with_fill.is_some() {
unsupported_sql_err!("WITH FILL");
}
let expr = self.plan_expr(&order_by_expr.expr)?;
desc.push(!order_by_expr.asc.unwrap_or(true));

exprs.push(expr);
}
Ok((exprs, desc))
}

fn plan_from(&self, from: &TableWithJoins) -> SQLPlannerResult<Relation> {
Expand Down Expand Up @@ -952,6 +930,89 @@ impl SQLPlanner {
}
}

/// Checks if the SQL query is valid syntax and doesn't use unsupported features.
/// /// This function examines various clauses and options in the provided [sqlparser::ast::Query]
/// and returns an error if any unsupported features are encountered.
fn check_query_features(query: &sqlparser::ast::Query) -> SQLPlannerResult<()> {
if let Some(with) = &query.with {
unsupported_sql_err!("WITH: {with}")
}
if !query.limit_by.is_empty() {
unsupported_sql_err!("LIMIT BY");
}
if query.offset.is_some() {
unsupported_sql_err!("OFFSET");
}
if query.fetch.is_some() {
unsupported_sql_err!("FETCH");
}
if !query.locks.is_empty() {
unsupported_sql_err!("LOCKS");
}
if let Some(for_clause) = &query.for_clause {
unsupported_sql_err!("{for_clause}");
}
if query.settings.is_some() {
unsupported_sql_err!("SETTINGS");
}
if let Some(format_clause) = &query.format_clause {
unsupported_sql_err!("{format_clause}");
}
Ok(())
}

/// Checks if the features used in the SQL SELECT statement are supported.
///
/// This function examines various clauses and options in the provided [sqlparser::ast::Select]
/// and returns an error if any unsupported features are encountered.
///
/// # Arguments
///
/// * `selection` - A reference to the [sqlparser::ast::Select] to be checked.
///
/// # Returns
///
/// * `SQLPlannerResult<()>` - Ok(()) if all features are supported, or an error describing
/// the first unsupported feature encountered.
fn check_select_features(selection: &sqlparser::ast::Select) -> SQLPlannerResult<()> {
if selection.top.is_some() {
unsupported_sql_err!("TOP");
}
if selection.distinct.is_some() {
unsupported_sql_err!("DISTINCT");
}
if selection.into.is_some() {
unsupported_sql_err!("INTO");
}
if !selection.lateral_views.is_empty() {
unsupported_sql_err!("LATERAL");
}
if selection.prewhere.is_some() {
unsupported_sql_err!("PREWHERE");
}
if !selection.cluster_by.is_empty() {
unsupported_sql_err!("CLUSTER BY");
}
if !selection.distribute_by.is_empty() {
unsupported_sql_err!("DISTRIBUTE BY");
}
if !selection.sort_by.is_empty() {
unsupported_sql_err!("SORT BY");
}
if selection.having.is_some() {
unsupported_sql_err!("HAVING");
}
if !selection.named_window.is_empty() {
unsupported_sql_err!("WINDOW");
}
if selection.qualify.is_some() {
unsupported_sql_err!("QUALIFY");
}
if selection.connect_by.is_some() {
unsupported_sql_err!("CONNECT BY");
}
Ok(())
}
pub fn sql_expr<S: AsRef<str>>(s: S) -> SQLPlannerResult<ExprRef> {
let planner = SQLPlanner::default();

Expand Down
16 changes: 14 additions & 2 deletions tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,20 @@ def test_sql_global_agg():
def test_sql_groupby_agg():
df = daft.from_pydict({"n": [1, 1, 2, 2], "v": [1, 2, 3, 4]})
catalog = SQLCatalog({"test": df})
df = daft.sql("SELECT sum(v) FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert df.collect().to_pydict() == {"n": [1, 2], "v": [3, 7]}
actual = daft.sql("SELECT sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert actual.collect().to_pydict() == {"sum": [3, 7]}

# test with grouping column
actual = daft.sql("SELECT n, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert actual.collect().to_pydict() == {"n": [1, 2], "sum": [3, 7]}

# test with multiple columns
actual = daft.sql("SELECT max(v) as max, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert actual.collect().to_pydict() == {"max": [2, 4], "sum": [3, 7]}

# test with aliased grouping key
actual = daft.sql("SELECT n as n_alias, sum(v) as sum FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert actual.collect().to_pydict() == {"n_alias": [1, 2], "sum": [3, 7]}
universalmind303 marked this conversation as resolved.
Show resolved Hide resolved


def test_sql_count_star():
Expand Down
Loading