diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 8c4853d1f5..ff090f642e 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1106,10 +1106,6 @@ class PyExpr: def __repr__(self) -> str: ... def __hash__(self) -> int: ... def __reduce__(self) -> tuple: ... - def is_nan(self) -> PyExpr: ... - def is_inf(self) -> PyExpr: ... - def not_nan(self) -> PyExpr: ... - def fill_nan(self, fill_value: PyExpr) -> PyExpr: ... def dt_date(self) -> PyExpr: ... def dt_day(self) -> PyExpr: ... def dt_hour(self) -> PyExpr: ... @@ -1255,6 +1251,14 @@ def image_encode(expr: PyExpr, image_format: ImageFormat) -> PyExpr: ... def image_resize(expr: PyExpr, w: int, h: int) -> PyExpr: ... def image_to_mode(expr: PyExpr, mode: ImageMode) -> PyExpr: ... +# --- +# expr.float namespace +# --- +def is_nan(expr: PyExpr) -> PyExpr: ... +def is_inf(expr: PyExpr) -> PyExpr: ... +def not_nan(expr: PyExpr) -> PyExpr: ... +def fill_nan(expr: PyExpr, fill_value: PyExpr) -> PyExpr: ... + # --- # expr.json namespace # --- diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 3dc46487ec..8466697dd6 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1362,7 +1362,7 @@ def is_nan(self) -> Expression: Returns: Expression: Boolean Expression indicating whether values are invalid. """ - return Expression._from_pyexpr(self._expr.is_nan()) + return Expression._from_pyexpr(native.is_nan(self._expr)) def is_inf(self) -> Expression: """Checks if values in the Expression are Infinity. @@ -1394,7 +1394,7 @@ def is_inf(self) -> Expression: Returns: Expression: Boolean Expression indicating whether values are Infinity. """ - return Expression._from_pyexpr(self._expr.is_inf()) + return Expression._from_pyexpr(native.is_inf(self._expr)) def not_nan(self) -> Expression: """Checks if values are not NaN (a special float value indicating not-a-number) @@ -1424,7 +1424,7 @@ def not_nan(self) -> Expression: Returns: Expression: Boolean Expression indicating whether values are not invalid. """ - return Expression._from_pyexpr(self._expr.not_nan()) + return Expression._from_pyexpr(native.not_nan(self._expr)) def fill_nan(self, fill_value: Expression) -> Expression: """Fills NaN values in the Expression with the provided fill_value @@ -1453,7 +1453,7 @@ def fill_nan(self, fill_value: Expression) -> Expression: """ fill_value = Expression._to_expression(fill_value) - expr = self._expr.fill_nan(fill_value._expr) + expr = native.fill_nan(self._expr, fill_value._expr) return Expression._from_pyexpr(expr) diff --git a/src/daft-dsl/src/functions/float/fill_nan.rs b/src/daft-dsl/src/functions/float/fill_nan.rs deleted file mode 100644 index c9417a0552..0000000000 --- a/src/daft-dsl/src/functions/float/fill_nan.rs +++ /dev/null @@ -1,46 +0,0 @@ -use daft_core::{prelude::*, utils::supertype::try_get_supertype}; - -use crate::ExprRef; - -use crate::functions::FunctionExpr; -use common_error::{DaftError, DaftResult}; - -use super::super::FunctionEvaluator; - -pub(super) struct FillNanEvaluator {} - -impl FunctionEvaluator for FillNanEvaluator { - fn fn_name(&self) -> &'static str { - "fill_nan" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data, fill_value] => match (data.to_field(schema), fill_value.to_field(schema)) { - (Ok(data_field), Ok(fill_value_field)) => { - match (&data_field.dtype.is_floating(), &fill_value_field.dtype.is_floating(), try_get_supertype(&data_field.dtype, &fill_value_field.dtype)) { - (true, true, Ok(dtype)) => Ok(Field::new(data_field.name, dtype)), - _ => Err(DaftError::TypeError(format!( - "Expects input to fill_nan to be float, but received {data_field} and {fill_value_field}", - ))), - } - } - (Err(e), _) | (_, Err(e)) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data, fill_value] => data.fill_nan(fill_value), - _ => Err(DaftError::ValueError(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/float/mod.rs b/src/daft-dsl/src/functions/float/mod.rs deleted file mode 100644 index 5ef47f004c..0000000000 --- a/src/daft-dsl/src/functions/float/mod.rs +++ /dev/null @@ -1,67 +0,0 @@ -mod fill_nan; -mod is_inf; -mod is_nan; -mod not_nan; - -use fill_nan::FillNanEvaluator; -use is_inf::IsInfEvaluator; -use is_nan::IsNanEvaluator; -use not_nan::NotNanEvaluator; -use serde::{Deserialize, Serialize}; - -use crate::{Expr, ExprRef}; - -use super::FunctionEvaluator; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub enum FloatExpr { - IsNan, - IsInf, - NotNan, - FillNan, -} - -impl FloatExpr { - #[inline] - pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use FloatExpr::*; - match self { - IsNan => &IsNanEvaluator {}, - IsInf => &IsInfEvaluator {}, - NotNan => &NotNanEvaluator {}, - FillNan => &FillNanEvaluator {}, - } - } -} - -pub fn is_nan(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Float(FloatExpr::IsNan), - inputs: vec![data], - } - .into() -} - -pub fn is_inf(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Float(FloatExpr::IsInf), - inputs: vec![data], - } - .into() -} - -pub fn not_nan(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Float(FloatExpr::NotNan), - inputs: vec![data], - } - .into() -} - -pub fn fill_nan(data: ExprRef, fill_value: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Float(FloatExpr::FillNan), - inputs: vec![data, fill_value], - } - .into() -} diff --git a/src/daft-dsl/src/functions/float/not_nan.rs b/src/daft-dsl/src/functions/float/not_nan.rs deleted file mode 100644 index c9464eeb8e..0000000000 --- a/src/daft-dsl/src/functions/float/not_nan.rs +++ /dev/null @@ -1,47 +0,0 @@ -use daft_core::prelude::*; - -use crate::ExprRef; - -use crate::functions::FunctionExpr; -use common_error::{DaftError, DaftResult}; - -use super::super::FunctionEvaluator; - -pub(super) struct NotNanEvaluator {} - -impl FunctionEvaluator for NotNanEvaluator { - fn fn_name(&self) -> &'static str { - "not_nan" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - // DataType::Float16 | - DataType::Float32 | DataType::Float64 => { - Ok(Field::new(data_field.name, DataType::Boolean)) - } - _ => Err(DaftError::TypeError(format!( - "Expects input to is_nan to be float, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.not_nan(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index 03d31cd49b..e92c0621f2 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -1,4 +1,3 @@ -pub mod float; pub mod list; pub mod map; pub mod numeric; @@ -15,7 +14,6 @@ use std::hash::Hash; use crate::{Expr, ExprRef, Operator}; -use self::float::FloatExpr; use self::list::ListExpr; use self::map::MapExpr; use self::numeric::NumericExpr; @@ -37,7 +35,6 @@ use python::PythonUDF; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum FunctionExpr { Numeric(NumericExpr), - Float(FloatExpr), Utf8(Utf8Expr), Temporal(TemporalExpr), List(ListExpr), @@ -65,7 +62,6 @@ impl FunctionExpr { use FunctionExpr::*; match self { Numeric(expr) => expr.get_evaluator(), - Float(expr) => expr.get_evaluator(), Utf8(expr) => expr.get_evaluator(), Temporal(expr) => expr.get_evaluator(), List(expr) => expr.get_evaluator(), diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 58d62d4f0d..bebad3d766 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -579,26 +579,6 @@ impl PyExpr { hasher.finish() } - pub fn is_nan(&self) -> PyResult { - use functions::float::is_nan; - Ok(is_nan(self.into()).into()) - } - - pub fn is_inf(&self) -> PyResult { - use functions::float::is_inf; - Ok(is_inf(self.into()).into()) - } - - pub fn not_nan(&self) -> PyResult { - use functions::float::not_nan; - Ok(not_nan(self.into()).into()) - } - - pub fn fill_nan(&self, fill_value: &Self) -> PyResult { - use functions::float::fill_nan; - Ok(fill_nan(self.into(), fill_value.expr.clone()).into()) - } - pub fn dt_date(&self) -> PyResult { use functions::temporal::date; Ok(date(self.into()).into()) diff --git a/src/daft-functions/src/float/fill_nan.rs b/src/daft-functions/src/float/fill_nan.rs new file mode 100644 index 0000000000..e79dd0a936 --- /dev/null +++ b/src/daft-functions/src/float/fill_nan.rs @@ -0,0 +1,70 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, + utils::supertype::try_get_supertype, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct FillNan {} + +#[typetag::serde] +impl ScalarUDF for FillNan { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "fill_nan" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data, fill_value] => match (data.to_field(schema), fill_value.to_field(schema)) { + (Ok(data_field), Ok(fill_value_field)) => { + match (&data_field.dtype.is_floating(), &fill_value_field.dtype.is_floating(), try_get_supertype(&data_field.dtype, &fill_value_field.dtype)) { + (true, true, Ok(dtype)) => Ok(Field::new(data_field.name, dtype)), + _ => Err(DaftError::TypeError(format!( + "Expects input for fill_nan to be float, but received {data_field} and {fill_value_field}", + ))), + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data, fill_value] => data.fill_nan(fill_value), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} + +pub fn fill_nan(input: ExprRef, fill_value: ExprRef) -> ExprRef { + ScalarFunction::new(FillNan {}, vec![input, fill_value]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "fill_nan")] +pub fn py_fill_nan(expr: PyExpr, fill_value: PyExpr) -> PyResult { + Ok(fill_nan(expr.into(), fill_value.into()).into()) +} diff --git a/src/daft-dsl/src/functions/float/is_inf.rs b/src/daft-functions/src/float/is_inf.rs similarity index 52% rename from src/daft-dsl/src/functions/float/is_inf.rs rename to src/daft-functions/src/float/is_inf.rs index 117f58c3a2..a46e221255 100644 --- a/src/daft-dsl/src/functions/float/is_inf.rs +++ b/src/daft-functions/src/float/is_inf.rs @@ -1,20 +1,27 @@ -use daft_core::prelude::*; - -use crate::ExprRef; - -use crate::functions::FunctionExpr; use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct IsInf {} -pub(super) struct IsInfEvaluator {} - -impl FunctionEvaluator for IsInfEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for IsInf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "is_inf" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data] => match data.to_field(schema) { Ok(data_field) => match &data_field.dtype { @@ -35,7 +42,7 @@ impl FunctionEvaluator for IsInfEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data] => data.is_inf(), _ => Err(DaftError::ValueError(format!( @@ -45,3 +52,19 @@ impl FunctionEvaluator for IsInfEvaluator { } } } + +pub fn is_inf(input: ExprRef) -> ExprRef { + ScalarFunction::new(IsInf {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "is_inf")] +pub fn py_is_inf(expr: PyExpr) -> PyResult { + Ok(is_inf(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/float/is_nan.rs b/src/daft-functions/src/float/is_nan.rs similarity index 52% rename from src/daft-dsl/src/functions/float/is_nan.rs rename to src/daft-functions/src/float/is_nan.rs index ec6e7a6737..365c09b80c 100644 --- a/src/daft-dsl/src/functions/float/is_nan.rs +++ b/src/daft-functions/src/float/is_nan.rs @@ -1,20 +1,27 @@ -use daft_core::prelude::*; - -use crate::ExprRef; - -use crate::functions::FunctionExpr; use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct IsNan {} -pub(super) struct IsNanEvaluator {} - -impl FunctionEvaluator for IsNanEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for IsNan { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "is_nan" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data] => match data.to_field(schema) { Ok(data_field) => match &data_field.dtype { @@ -35,7 +42,7 @@ impl FunctionEvaluator for IsNanEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data] => data.is_nan(), _ => Err(DaftError::ValueError(format!( @@ -45,3 +52,19 @@ impl FunctionEvaluator for IsNanEvaluator { } } } + +pub fn is_nan(input: ExprRef) -> ExprRef { + ScalarFunction::new(IsNan {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "is_nan")] +pub fn py_is_nan(expr: PyExpr) -> PyResult { + Ok(is_nan(expr.into()).into()) +} diff --git a/src/daft-functions/src/float/mod.rs b/src/daft-functions/src/float/mod.rs new file mode 100644 index 0000000000..82f348a736 --- /dev/null +++ b/src/daft-functions/src/float/mod.rs @@ -0,0 +1,22 @@ +mod fill_nan; +mod is_inf; +mod is_nan; +mod not_nan; + +pub use fill_nan::fill_nan; +pub use is_inf::is_inf; +pub use is_nan::is_nan; +pub use not_nan::not_nan; + +#[cfg(feature = "python")] +use pyo3::prelude::*; + +#[cfg(feature = "python")] +pub fn register_modules(parent: &Bound) -> PyResult<()> { + parent.add_function(wrap_pyfunction_bound!(fill_nan::py_fill_nan, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(is_inf::py_is_inf, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(is_nan::py_is_nan, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(not_nan::py_not_nan, parent)?)?; + + Ok(()) +} diff --git a/src/daft-functions/src/float/not_nan.rs b/src/daft-functions/src/float/not_nan.rs new file mode 100644 index 0000000000..87bca04011 --- /dev/null +++ b/src/daft-functions/src/float/not_nan.rs @@ -0,0 +1,70 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct NotNan {} + +#[typetag::serde] +impl ScalarUDF for NotNan { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "not_nan" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + // DataType::Float16 | + DataType::Float32 | DataType::Float64 => { + Ok(Field::new(data_field.name, DataType::Boolean)) + } + _ => Err(DaftError::TypeError(format!( + "Expects input to not_nan to be float, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.not_nan(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +pub fn not_nan(input: ExprRef) -> ExprRef { + ScalarFunction::new(NotNan {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "not_nan")] +pub fn py_not_nan(expr: PyExpr) -> PyResult { + Ok(not_nan(expr.into()).into()) +} diff --git a/src/daft-functions/src/lib.rs b/src/daft-functions/src/lib.rs index 94a886eba3..d55d17a5f8 100644 --- a/src/daft-functions/src/lib.rs +++ b/src/daft-functions/src/lib.rs @@ -1,6 +1,7 @@ #![feature(async_closure)] pub mod count_matches; pub mod distance; +pub mod float; pub mod hash; pub mod image; pub mod list_sort; @@ -48,6 +49,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction_bound!(uri::python::url_download, parent)?)?; parent.add_function(wrap_pyfunction_bound!(uri::python::url_upload, parent)?)?; image::register_modules(parent)?; + float::register_modules(parent)?; Ok(()) } diff --git a/src/daft-sql/src/modules/float.rs b/src/daft-sql/src/modules/float.rs index 035f993dee..3b1132ffe4 100644 --- a/src/daft-sql/src/modules/float.rs +++ b/src/daft-sql/src/modules/float.rs @@ -1,11 +1,82 @@ use super::SQLModule; use crate::functions::SQLFunctions; +use daft_dsl::ExprRef; +use daft_functions::float; +use sqlparser::ast::FunctionArg; + +use crate::{error::SQLPlannerResult, functions::SQLFunction, unsupported_sql_err}; pub struct SQLModuleFloat; impl SQLModule for SQLModuleFloat { - fn register(_parent: &mut SQLFunctions) { - // use FunctionExpr::Float as f; - // TODO + fn register(parent: &mut SQLFunctions) { + parent.add_fn("fill_nan", SQLFillNan {}); + parent.add_fn("is_inf", SQLIsInf {}); + parent.add_fn("is_nan", SQLIsNan {}); + parent.add_fn("not_nan", SQLNotNan {}); + } +} + +pub struct SQLFillNan {} + +impl SQLFunction for SQLFillNan { + fn to_expr( + &self, + inputs: &[FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, fill_value] => { + let input = planner.plan_function_arg(input)?; + let fill_value = planner.plan_function_arg(fill_value)?; + Ok(float::fill_nan(input, fill_value)) + } + _ => unsupported_sql_err!("Invalid arguments for 'fill_nan': '{inputs:?}'"), + } + } +} + +pub struct SQLIsInf {} + +impl SQLFunction for SQLIsInf { + fn to_expr( + &self, + inputs: &[FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input] => planner.plan_function_arg(input).map(float::is_inf), + _ => unsupported_sql_err!("Invalid arguments for 'is_inf': '{inputs:?}'"), + } + } +} + +pub struct SQLIsNan {} + +impl SQLFunction for SQLIsNan { + fn to_expr( + &self, + inputs: &[FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input] => planner.plan_function_arg(input).map(float::is_nan), + _ => unsupported_sql_err!("Invalid arguments for 'is_nan': '{inputs:?}'"), + } + } +} + +pub struct SQLNotNan {} + +impl SQLFunction for SQLNotNan { + fn to_expr( + &self, + inputs: &[FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input] => planner.plan_function_arg(input).map(float::not_nan), + _ => unsupported_sql_err!("Invalid arguments for 'not_nan': '{inputs:?}'"), + } } } diff --git a/tests/sql/test_float_exprs.py b/tests/sql/test_float_exprs.py new file mode 100644 index 0000000000..f4c7a49928 --- /dev/null +++ b/tests/sql/test_float_exprs.py @@ -0,0 +1,32 @@ +import numpy as np + +import daft +from daft.sql.sql import SQLCatalog + + +def test_floats(): + df = daft.from_pydict( + { + "nans": [1.0, 2.0, np.nan, 4.0], + "infs": [1.0, 2.0, np.inf, np.inf], + } + ) + catalog = SQLCatalog({"test": df}) + + sql = """ + SELECT + is_nan(nans) as is_nan, + is_inf(infs) as is_inf, + not_nan(nans) as not_nan, + fill_nan(nans, 0.0) as fill_nan + FROM test + """ + df = daft.sql(sql, catalog=catalog).collect() + expected = { + "is_nan": [False, False, True, False], + "is_inf": [False, False, True, True], + "not_nan": [True, True, False, True], + "fill_nan": [1.0, 2.0, 0.0, 4.0], + } + actual = df.to_pydict() + assert actual == expected