Skip to content

Commit

Permalink
[FEAT]: Sql common table expressions (CTE's) (Eventual-Inc#3137)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 authored and sagiahrac committed Nov 4, 2024
1 parent 051793d commit 5d40e65
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 28 deletions.
3 changes: 2 additions & 1 deletion src/daft-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,10 @@ mod tests {
#[case::orderby_multi("select * from tbl1 order by i32 desc, f32 asc")]
#[case::whenthen("select case when i32 = 1 then 'a' else 'b' end from tbl1")]
#[case::globalagg("select max(i32) from tbl1")]
#[case::cte("with cte as (select * from tbl1) select * from cte")]
// Plans `query` with the fixture planner and fails the case, printing both the
// query text and the planner result, if planning returned an error.
fn test_compiles(mut planner: SQLPlanner, #[case] query: &str) -> SQLPlannerResult<()> {
    let plan = planner.plan_sql(query);
    // A single assertion is enough; `is_ok()` already yields a `bool`, so no borrow is needed.
    assert!(plan.is_ok(), "query: {query}\nerror: {plan:?}");

    Ok(())
}
Expand Down
159 changes: 132 additions & 27 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ use daft_plan::{LogicalPlanBuilder, LogicalPlanRef};
use sqlparser::{
ast::{
ArrayElemTypeDef, BinaryOperator, CastKind, DateTimeField, Distinct, ExactNumberInfo,
ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, Statement, StructField,
ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, SetExpr, Statement, StructField,
Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value,
WildcardAdditionalOptions,
WildcardAdditionalOptions, With,
},
dialect::GenericDialect,
parser::{Parser, ParserOptions},
Expand Down Expand Up @@ -66,14 +66,16 @@ pub struct SQLPlanner {
catalog: SQLCatalog,
current_relation: Option<Relation>,
table_map: HashMap<String, Relation>,
cte_map: HashMap<String, Relation>,
}

/// The default planner starts with a fresh catalog and empty planning state.
impl Default for SQLPlanner {
    /// Builds a planner with a new `SQLCatalog`, no current relation, and empty
    /// table and CTE maps.
    fn default() -> Self {
        Self {
            catalog: SQLCatalog::new(),
            // Each field is initialised exactly once; listing a field twice in a
            // struct expression is a compile error.
            current_relation: Default::default(),
            table_map: Default::default(),
            cte_map: Default::default(),
        }
    }
}
Expand All @@ -82,8 +84,7 @@ impl SQLPlanner {
pub fn new(context: SQLCatalog) -> Self {
Self {
catalog: context,
current_relation: None,
table_map: HashMap::new(),
..Default::default()
}
}

Expand All @@ -102,6 +103,69 @@ impl SQLPlanner {
/// Drops all per-query planning state: the registered CTEs, the table map,
/// and the current relation. The catalog is left untouched.
fn clear_context(&mut self) {
    self.cte_map.clear();
    self.table_map.clear();
    self.current_relation = None;
}

/// Resolves `name` to a relation visible in the current scope.
///
/// Lookup order is: the table map, then the CTE map, then the catalog. The
/// first hit wins, so an entry in the table map shadows a CTE of the same
/// name, and a CTE shadows a catalog table. Returns `None` when none of the
/// three knows the name.
fn get_table_from_current_scope(&self, name: &str) -> Option<Relation> {
    let in_scope = self
        .table_map
        .get(name)
        .or_else(|| self.cte_map.get(name));
    if let Some(found) = in_scope {
        return Some(found.clone());
    }

    // Fall back to the catalog, wrapping the catalog entry in a relation
    // that carries the requested name.
    let catalog_table = self.catalog.get_table(name)?;
    Some(Relation::new(catalog_table.into(), name.to_string()))
}

/// Stores `rel` in the CTE map under the relation's own name.
///
/// When `column_aliases` is non-empty, the relation's columns are first
/// renamed positionally to those aliases via a projection.
///
/// # Errors
///
/// Returns an invalid-operation error when the number of aliases differs from
/// the number of columns in the relation's schema, and propagates any error
/// from applying the renaming projection.
fn register_cte(
    &mut self,
    mut rel: Relation,
    column_aliases: &[Ident],
) -> SQLPlannerResult<()> {
    // No alias list: the relation is registered exactly as planned.
    if column_aliases.is_empty() {
        self.cte_map.insert(rel.get_name(), rel);
        return Ok(());
    }

    let schema = rel.schema();
    let column_names = schema.names();
    if column_names.len() != column_aliases.len() {
        invalid_operation_err!(
            "Column count mismatch: expected {} columns, found {}",
            column_aliases.len(),
            column_names.len()
        );
    }

    // Pair each existing column with the alias at the same position.
    let mut renamed = Vec::with_capacity(column_aliases.len());
    for (name, alias) in column_names.into_iter().zip(column_aliases) {
        renamed.push(col(name).alias(ident_to_str(alias)));
    }
    rel.inner = rel.inner.select(renamed)?;

    self.cte_map.insert(rel.get_name(), rel);
    Ok(())
}

/// Plans every CTE of a `WITH` clause, in declaration order, and registers
/// each one so later lookups can resolve it by name.
///
/// # Errors
///
/// Fails for `WITH RECURSIVE`, for a CTE carrying a `MATERIALIZED` marker, and
/// for a CTE carrying a `FROM` part. Errors from planning a CTE's query or
/// from registering it are propagated.
fn plan_ctes(&mut self, with: &With) -> SQLPlannerResult<()> {
    if with.recursive {
        unsupported_sql_err!("Recursive CTEs are not supported");
    }

    for cte in &with.cte_tables {
        // The MATERIALIZED check takes precedence over the FROM check.
        match (&cte.materialized, &cte.from) {
            (Some(_), _) => {
                unsupported_sql_err!("MATERIALIZED is not supported");
            }
            (None, Some(_)) => {
                invalid_operation_err!("FROM should only exist in recursive CTEs");
            }
            (None, None) => {}
        }

        let alias = &cte.alias;
        let cte_name = ident_to_str(&alias.name);
        let planned = self.plan_query(&cte.query)?;

        self.register_cte(Relation::new(planned, cte_name), alias.columns.as_slice())?;
    }
    Ok(())
}

pub fn plan_sql(&mut self, sql: &str) -> SQLPlannerResult<LogicalPlanRef> {
Expand Down Expand Up @@ -136,15 +200,24 @@ impl SQLPlanner {
fn plan_query(&mut self, query: &Query) -> SQLPlannerResult<LogicalPlanBuilder> {
check_query_features(query)?;

let selection = query.body.as_select().ok_or_else(|| {
PlannerError::invalid_operation(format!(
"Only SELECT queries are supported, got: '{}'",
query.body
))
})?;
let selection = match query.body.as_ref() {
SetExpr::Select(selection) => selection,
SetExpr::Query(_) => unsupported_sql_err!("Subqueries are not supported"),
SetExpr::SetOperation { .. } => {
unsupported_sql_err!("Set operations are not supported")
}
SetExpr::Values(..) => unsupported_sql_err!("VALUES are not supported"),
SetExpr::Insert(..) => unsupported_sql_err!("INSERT is not supported"),
SetExpr::Update(..) => unsupported_sql_err!("UPDATE is not supported"),
SetExpr::Table(..) => unsupported_sql_err!("TABLE is not supported"),
};

check_select_features(selection)?;

if let Some(with) = &query.with {
self.plan_ctes(with)?;
}

// FROM/JOIN
let from = selection.clone().from;
let rel = self.plan_from(&from)?;
Expand Down Expand Up @@ -480,7 +553,7 @@ impl SQLPlanner {
Ok(left_rel)
}

fn plan_relation(&self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult<Relation> {
fn plan_relation(&mut self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult<Relation> {
let (rel, alias) = match rel {
sqlparser::ast::TableFactor::Table {
name,
Expand All @@ -498,12 +571,48 @@ impl SQLPlanner {
..
} => {
let table_name = name.to_string();
let plan = self
.catalog
.get_table(&table_name)
.ok_or_else(|| PlannerError::table_not_found(table_name.clone()))?;
let plan_builder = LogicalPlanBuilder::new(plan, None);
(Relation::new(plan_builder, table_name), alias.clone())
let Some(rel) = self.get_table_from_current_scope(&table_name) else {
table_not_found_err!(table_name)
};
(rel, alias.clone())
}
sqlparser::ast::TableFactor::Derived {
lateral,
subquery,
alias: Some(alias),
} => {
if *lateral {
unsupported_sql_err!("LATERAL");
}
let subquery = self.plan_query(subquery)?;
let rel_name = ident_to_str(&alias.name);
let rel = Relation::new(subquery, rel_name);

(rel, Some(alias.clone()))
}
sqlparser::ast::TableFactor::TableFunction { .. } => {
unsupported_sql_err!("Unsupported table factor: TableFunction")
}
sqlparser::ast::TableFactor::Function { .. } => {
unsupported_sql_err!("Unsupported table factor: Function")
}
sqlparser::ast::TableFactor::UNNEST { .. } => {
unsupported_sql_err!("Unsupported table factor: UNNEST")
}
sqlparser::ast::TableFactor::JsonTable { .. } => {
unsupported_sql_err!("Unsupported table factor: JsonTable")
}
sqlparser::ast::TableFactor::NestedJoin { .. } => {
unsupported_sql_err!("Unsupported table factor: NestedJoin")
}
sqlparser::ast::TableFactor::Pivot { .. } => {
unsupported_sql_err!("Unsupported table factor: Pivot")
}
sqlparser::ast::TableFactor::Unpivot { .. } => {
unsupported_sql_err!("Unsupported table factor: Unpivot")
}
sqlparser::ast::TableFactor::MatchRecognize { .. } => {
unsupported_sql_err!("Unsupported table factor: MatchRecognize")
}
_ => unsupported_sql_err!("Unsupported table factor"),
};
Expand All @@ -520,8 +629,7 @@ impl SQLPlanner {

let root = idents.next().unwrap();
let root = ident_to_str(root);

let current_relation = match self.table_map.get(&root) {
let current_relation = match self.get_table_from_current_scope(&root) {
Some(rel) => rel,
None => {
return Err(PlannerError::TableNotFound {
Expand Down Expand Up @@ -626,7 +734,7 @@ impl SQLPlanner {
let Some(rel) = self.relation_opt() else {
table_not_found_err!(table_name);
};
let Some(table_rel) = self.table_map.get(&table_name) else {
let Some(table_rel) = self.get_table_from_current_scope(&table_name) else {
table_not_found_err!(table_name);
};
let right_schema = table_rel.inner.schema();
Expand Down Expand Up @@ -673,7 +781,7 @@ impl SQLPlanner {
Value::Null => LiteralValue::Null,
_ => {
return Err(PlannerError::invalid_operation(
"Only string, number, boolean and null literals are supported",
"Only string, number, boolean and null literals are supported. Instead found: `{value}`",
))
}
})
Expand All @@ -683,7 +791,7 @@ impl SQLPlanner {
if let sqlparser::ast::Expr::Value(v) = expr {
self.value_to_lit(v)
} else {
invalid_operation_err!("Only string, number, boolean and null literals are supported");
invalid_operation_err!("Only string, number, boolean and null literals are supported. Instead found: `{expr}`");
}
}
pub(crate) fn plan_expr(&self, expr: &sqlparser::ast::Expr) -> SQLPlannerResult<ExprRef> {
Expand Down Expand Up @@ -1373,9 +1481,6 @@ impl SQLPlanner {
/// This function examines various clauses and options in the provided [sqlparser::ast::Query]
/// and returns an error if any unsupported features are encountered.
fn check_query_features(query: &sqlparser::ast::Query) -> SQLPlannerResult<()> {
if let Some(with) = &query.with {
unsupported_sql_err!("WITH: {with}")
}
if !query.limit_by.is_empty() {
unsupported_sql_err!("LIMIT BY");
}
Expand Down
61 changes: 61 additions & 0 deletions tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import daft
from daft import col
from daft.exceptions import DaftCoreException
from daft.sql.sql import SQLCatalog
from tests.assets import TPCH_QUERIES
Expand Down Expand Up @@ -221,3 +222,63 @@ def test_sql_distinct():
actual = daft.sql("SELECT DISTINCT n FROM df").collect().to_pydict()
expected = df.distinct().collect().to_pydict()
assert actual == expected


def test_sql_cte():
    """Selecting everything through a single pass-through CTE returns the source frame's data."""
    # The SQL text refers to the frame by the name `df`, so the local keeps that name.
    df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": ["a", "b", "c"]})
    query = """
WITH cte1 AS (select * FROM df)
SELECT * FROM cte1
"""
    actual = daft.sql(query).collect().to_pydict()
    expected = df.collect().to_pydict()

    assert actual == expected


def test_sql_cte_column_aliases():
    """A CTE declared with a column list exposes its columns under those aliases, in order."""
    # The SQL text refers to the frame by the name `df`, so the local keeps that name.
    df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": ["a", "b", "c"]})
    query = """
WITH cte1 (cte_a, cte_b, cte_c) AS (select * FROM df)
SELECT * FROM cte1
"""
    actual = daft.sql(query).collect().to_pydict()

    # DataFrame-API equivalent: rename each column to its CTE alias.
    renamed = df.select(
        col("a").alias("cte_a"),
        col("b").alias("cte_b"),
        col("c").alias("cte_c"),
    )
    expected = renamed.collect().to_pydict()

    assert actual == expected


def test_sql_multiple_ctes():
    """Two CTEs declared in one WITH clause can be joined with USING in the main query."""
    # The SQL text refers to the frames by the names `df1` and `df2`, so the locals keep those names.
    df1 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": ["a", "b", "c"]})
    df2 = daft.from_pydict({"x": [1, 0, 3], "y": [True, None, False], "z": [1.0, 2.0, 3.0]})
    query = """
WITH
cte1 AS (select * FROM df1),
cte2 AS (select x as a, y, z FROM df2)
SELECT *
FROM cte1
JOIN cte2 USING (a)
"""
    actual = daft.sql(query).collect().to_pydict()

    # DataFrame-API equivalent: rename `x` to `a` on the right side, then join on `a`.
    right = df2.select(col("x").alias("a"), "y", "z")
    expected = df1.join(right, on="a").collect().to_pydict()

    assert actual == expected

0 comments on commit 5d40e65

Please sign in to comment.