Skip to content

Commit

Permalink
[FEAT]: sql list operations (#2856)
Browse files Browse the repository at this point in the history
depends on #2854

you can see the relevant diff
[here](universalmind303/Daft@sql-lists...universalmind303:Daft:sql-lists-2)
  • Loading branch information
universalmind303 authored Sep 20, 2024
1 parent 688150f commit 48a123a
Show file tree
Hide file tree
Showing 6 changed files with 404 additions and 32 deletions.
6 changes: 1 addition & 5 deletions src/daft-functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ pub mod float;
pub mod hash;
pub mod image;
pub mod list;
pub mod list_sort;
pub mod minhash;
pub mod numeric;
pub mod to_struct;
Expand All @@ -29,10 +28,7 @@ pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
parent
)?)?;
parent.add_function(wrap_pyfunction_bound!(hash::python::hash, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(
list_sort::python::list_sort,
parent
)?)?;

parent.add_function(wrap_pyfunction_bound!(minhash::python::minhash, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(numeric::cbrt::python::cbrt, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(
Expand Down
2 changes: 1 addition & 1 deletion src/daft-functions/src/list/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct ListCount {
mode: CountMode,
pub mode: CountMode,
}

#[typetag::serde]
Expand Down
24 changes: 14 additions & 10 deletions src/daft-functions/src/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,22 @@ mod max;
mod mean;
mod min;
mod slice;
mod sort;
mod sum;
pub use chunk::list_chunk as chunk;
pub use count::list_count as count;
pub use explode::explode;
pub use get::list_get as get;
pub use join::list_join as join;
pub use max::list_max as max;
pub use mean::list_mean as mean;
pub use min::list_min as min;

pub use chunk::{list_chunk as chunk, ListChunk};
pub use count::{list_count as count, ListCount};
pub use explode::{explode, Explode};
pub use get::{list_get as get, ListGet};
pub use join::{list_join as join, ListJoin};
pub use max::{list_max as max, ListMax};
pub use mean::{list_mean as mean, ListMean};
pub use min::{list_min as min, ListMin};
#[cfg(feature = "python")]
use pyo3::prelude::*;
pub use slice::list_slice as slice;
pub use sum::list_sum as sum;
pub use slice::{list_slice as slice, ListSlice};
pub use sort::{list_sort as sort, ListSort};
pub use sum::{list_sum as sum, ListSum};

#[cfg(feature = "python")]
pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
Expand All @@ -35,6 +38,7 @@ pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
parent.add_function(wrap_pyfunction_bound!(min::py_list_min, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(slice::py_list_slice, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(sum::py_list_sum, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(sort::py_list_sort, parent)?)?;

Ok(())
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ use common_error::{DaftError, DaftResult};
use daft_core::prelude::*;
use daft_dsl::{
functions::{ScalarFunction, ScalarUDF},
ExprRef,
lit, ExprRef,
};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
struct ListSortFunction {}
pub struct ListSort {}

#[typetag::serde]
impl ScalarUDF for ListSortFunction {
impl ScalarUDF for ListSort {
fn as_any(&self) -> &dyn std::any::Any {
self
}
Expand Down Expand Up @@ -51,18 +51,20 @@ impl ScalarUDF for ListSortFunction {
}
}

pub fn list_sort(input: ExprRef, desc: ExprRef) -> ExprRef {
ScalarFunction::new(ListSortFunction {}, vec![input, desc]).into()
pub fn list_sort(input: ExprRef, desc: Option<ExprRef>) -> ExprRef {
let desc = desc.unwrap_or_else(|| lit(false));
ScalarFunction::new(ListSort {}, vec![input, desc]).into()
}

#[cfg(feature = "python")]
pub mod python {
use daft_dsl::python::PyExpr;
use pyo3::{pyfunction, PyResult};
use {
daft_dsl::python::PyExpr,
pyo3::{pyfunction, PyResult},
};

#[pyfunction]
pub fn list_sort(expr: PyExpr, desc: PyExpr) -> PyResult<PyExpr> {
let expr = super::list_sort(expr.into(), desc.into());
Ok(expr.into())
}
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "list_sort")]
pub fn py_list_sort(expr: PyExpr, desc: PyExpr) -> PyResult<PyExpr> {
Ok(list_sort(expr.into(), Some(desc.into())).into())
}
256 changes: 253 additions & 3 deletions src/daft-sql/src/modules/list.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,261 @@
use daft_core::prelude::CountMode;
use daft_dsl::{lit, Expr, LiteralValue};

use super::SQLModule;
use crate::functions::SQLFunctions;
use crate::{
error::PlannerError,
functions::{SQLFunction, SQLFunctions},
unsupported_sql_err,
};

/// SQL module that exposes the list functions to the SQL planner.
pub struct SQLModuleList;

impl SQLModule for SQLModuleList {
fn register(_parent: &mut SQLFunctions) {
// use FunctionExpr::List as f;
/// Registers every list function under the name it is callable by in SQL.
/// Some implementations are registered under more than one name
/// (`explode`/`unnest`, `array_to_string`/`list_join`).
fn register(parent: &mut SQLFunctions) {
parent.add_fn("list_chunk", SQLListChunk);
parent.add_fn("list_count", SQLListCount);
parent.add_fn("explode", SQLExplode);
parent.add_fn("unnest", SQLExplode);
// This operation is commonly called `array_to_string` in other SQL dialects.
parent.add_fn("array_to_string", SQLListJoin);
// `list_join` is our own name for the same function, so register it as an alias.
parent.add_fn("list_join", SQLListJoin);
parent.add_fn("list_max", SQLListMax);
parent.add_fn("list_min", SQLListMin);
parent.add_fn("list_sum", SQLListSum);
parent.add_fn("list_mean", SQLListMean);
parent.add_fn("list_slice", SQLListSlice);
parent.add_fn("list_sort", SQLListSort);

// TODO
}
}

/// SQL binding for `list_chunk(expr, chunk_size)`.
pub struct SQLListChunk;

impl SQLFunction for SQLListChunk {
    /// Plans `list_chunk(expr, chunk_size)` into a list-chunk expression.
    ///
    /// `chunk_size` must be an `Int64` literal that is representable as a
    /// `usize`; any other expression, or a negative literal, yields an
    /// unsupported-SQL planner error. Any argument count other than two is
    /// rejected as well.
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        match inputs {
            [input, chunk_size] => {
                let input = planner.plan_function_arg(input)?;
                let chunk_size = planner
                    .plan_function_arg(chunk_size)
                    .and_then(|arg| match arg.as_ref() {
                        // A plain `as usize` cast would silently wrap a negative
                        // literal into an enormous chunk size, so convert with a
                        // checked conversion and report the bad value instead.
                        Expr::Literal(LiteralValue::Int64(n)) => match usize::try_from(*n) {
                            Ok(size) => Ok(size),
                            Err(_) => unsupported_sql_err!(
                                "Expected chunk size to be a non-negative number"
                            ),
                        },
                        _ => unsupported_sql_err!("Expected chunk size to be a number"),
                    })?;
                Ok(daft_functions::list::chunk(input, chunk_size))
            }
            _ => unsupported_sql_err!(
                "invalid arguments for list_chunk. Expected list_chunk(expr, chunk_size)"
            ),
        }
    }
}

pub struct SQLListCount;

impl SQLFunction for SQLListCount {
fn to_expr(
&self,
inputs: &[sqlparser::ast::FunctionArg],
planner: &crate::planner::SQLPlanner,
) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
match inputs {
[input] => {
let input = planner.plan_function_arg(input)?;
Ok(daft_functions::list::count(input, CountMode::Valid))
}
[input, count_mode] => {
let input = planner.plan_function_arg(input)?;
let mode =
planner
.plan_function_arg(count_mode)
.and_then(|arg| match arg.as_ref() {
Expr::Literal(LiteralValue::Utf8(s)) => {
s.parse().map_err(PlannerError::from)
}
_ => unsupported_sql_err!("Expected mode to be a string"),
})?;
Ok(daft_functions::list::count(input, mode))
}
_ => unsupported_sql_err!("invalid arguments for list_count. Expected either list_count(expr) or list_count(expr, mode)"),
}
}
}

/// SQL binding for the single-argument explode function (registered as
/// `explode` and `unnest`).
pub struct SQLExplode;

impl SQLFunction for SQLExplode {
    /// Plans the single argument and wraps it in an explode expression.
    /// Any other number of arguments is rejected.
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        let [list] = inputs else {
            unsupported_sql_err!("Expected 1 argument")
        };
        let list = planner.plan_function_arg(list)?;
        Ok(daft_functions::list::explode(list))
    }
}

/// SQL binding for `list_join(expr, separator)`.
pub struct SQLListJoin;

impl SQLFunction for SQLListJoin {
    /// Plans both arguments (list first, then separator) and builds the
    /// list-join expression. Exactly two arguments are required.
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        let [list, delimiter] = inputs else {
            unsupported_sql_err!(
                "invalid arguments for list_join. Expected list_join(expr, separator)"
            )
        };
        let list = planner.plan_function_arg(list)?;
        let delimiter = planner.plan_function_arg(delimiter)?;
        Ok(daft_functions::list::join(list, delimiter))
    }
}

/// SQL binding for `list_max(expr)`.
pub struct SQLListMax;

impl SQLFunction for SQLListMax {
    /// Plans the single argument and wraps it in a list-max expression.
    /// Any other number of arguments is rejected.
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        let [list] = inputs else {
            unsupported_sql_err!("invalid arguments for list_max. Expected list_max(expr)")
        };
        let list = planner.plan_function_arg(list)?;
        Ok(daft_functions::list::max(list))
    }
}

/// SQL binding for `list_mean(expr)`.
pub struct SQLListMean;

impl SQLFunction for SQLListMean {
    /// Plans the single argument and wraps it in a list-mean expression.
    /// Any other number of arguments is rejected.
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        let [list] = inputs else {
            unsupported_sql_err!("invalid arguments for list_mean. Expected list_mean(expr)")
        };
        let list = planner.plan_function_arg(list)?;
        Ok(daft_functions::list::mean(list))
    }
}

/// SQL binding for `list_min(expr)`.
pub struct SQLListMin;

impl SQLFunction for SQLListMin {
    /// Plans the single argument and wraps it in a list-min expression.
    /// Any other number of arguments is rejected.
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        let [list] = inputs else {
            unsupported_sql_err!("invalid arguments for list_min. Expected list_min(expr)")
        };
        let list = planner.plan_function_arg(list)?;
        Ok(daft_functions::list::min(list))
    }
}

/// SQL binding for `list_sum(expr)`.
pub struct SQLListSum;

impl SQLFunction for SQLListSum {
    /// Plans the single argument and wraps it in a list-sum expression.
    /// Any other number of arguments is rejected.
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        let [list] = inputs else {
            unsupported_sql_err!("invalid arguments for list_sum. Expected list_sum(expr)")
        };
        let list = planner.plan_function_arg(list)?;
        Ok(daft_functions::list::sum(list))
    }
}

/// SQL binding for `list_slice(expr, start, end)`.
pub struct SQLListSlice;

impl SQLFunction for SQLListSlice {
    /// Plans the three arguments in order (list, start, end) and builds the
    /// list-slice expression. Exactly three arguments are required.
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        let [list, lower, upper] = inputs else {
            unsupported_sql_err!(
                "invalid arguments for list_slice. Expected list_slice(expr, start, end)"
            )
        };
        let list = planner.plan_function_arg(list)?;
        let lower = planner.plan_function_arg(lower)?;
        let upper = planner.plan_function_arg(upper)?;
        Ok(daft_functions::list::slice(list, lower, upper))
    }
}

/// SQL binding for `list_sort(expr)` and `list_sort(expr, ASC|DESC)`.
pub struct SQLListSort;

impl SQLFunction for SQLListSort {
    /// Plans a `list_sort` call.
    ///
    /// With one argument no explicit order is passed on. With two, the second
    /// argument must be a bare identifier spelling `asc` or `desc`
    /// (case-insensitive), which becomes a boolean "descending" literal.
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        use sqlparser::ast::{
            Expr::Identifier as SQLIdent, FunctionArg::Unnamed, FunctionArgExpr::Expr as SQLExpr,
        };

        match inputs {
            [list] => {
                let list = planner.plan_function_arg(list)?;
                Ok(daft_functions::list::sort(list, None))
            }
            [list, direction] => {
                // The list argument is planned before the order keyword is checked.
                let list = planner.plan_function_arg(list)?;

                // The order is read from the raw AST: it must be an unnamed,
                // bare identifier rather than a planned expression.
                let Unnamed(SQLExpr(SQLIdent(ident))) = direction else {
                    unsupported_sql_err!("invalid order for list_sort")
                };
                let keyword = ident.value.to_lowercase();
                let descending = if keyword == "asc" {
                    false
                } else if keyword == "desc" {
                    true
                } else {
                    unsupported_sql_err!("invalid order for list_sort")
                };
                Ok(daft_functions::list::sort(list, Some(lit(descending))))
            }
            _ => unsupported_sql_err!(
                "invalid arguments for list_sort. Expected list_sort(expr, ASC|DESC)"
            ),
        }
    }
}
Loading

0 comments on commit 48a123a

Please sign in to comment.