diff --git a/Cargo.lock b/Cargo.lock index 4815488c2b..8dc27f4c89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1819,6 +1819,7 @@ dependencies = [ "base64 0.22.1", "bytes", "common-error", + "common-hashable-float-wrapper", "common-io-config", "daft-core", "daft-dsl", diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 5a81e48b79..7c8f04a7fc 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1040,29 +1040,6 @@ class PySchema: class PyExpr: def alias(self, name: str) -> PyExpr: ... def cast(self, dtype: PyDataType) -> PyExpr: ... - def ceil(self) -> PyExpr: ... - def floor(self) -> PyExpr: ... - def sign(self) -> PyExpr: ... - def round(self, decimal: int) -> PyExpr: ... - def sqrt(self) -> PyExpr: ... - def sin(self) -> PyExpr: ... - def cos(self) -> PyExpr: ... - def tan(self) -> PyExpr: ... - def cot(self) -> PyExpr: ... - def arcsin(self) -> PyExpr: ... - def arccos(self) -> PyExpr: ... - def arctan(self) -> PyExpr: ... - def arctan2(self, other: PyExpr) -> PyExpr: ... - def arctanh(self) -> PyExpr: ... - def arccosh(self) -> PyExpr: ... - def arcsinh(self) -> PyExpr: ... - def degrees(self) -> PyExpr: ... - def radians(self) -> PyExpr: ... - def log2(self) -> PyExpr: ... - def log10(self) -> PyExpr: ... - def log(self, base: float) -> PyExpr: ... - def ln(self) -> PyExpr: ... - def exp(self) -> PyExpr: ... def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ... def count(self, mode: CountMode) -> PyExpr: ... def sum(self) -> PyExpr: ... @@ -1074,7 +1051,6 @@ class PyExpr: def any_value(self, ignore_nulls: bool) -> PyExpr: ... def agg_list(self) -> PyExpr: ... def agg_concat(self) -> PyExpr: ... - def __abs__(self) -> PyExpr: ... def __add__(self, other: PyExpr) -> PyExpr: ... def __sub__(self, other: PyExpr) -> PyExpr: ... def __mul__(self, other: PyExpr) -> PyExpr: ... @@ -1217,9 +1193,35 @@ def minhash( def sql(sql: str, catalog: PyCatalog, daft_planning_config: PyDaftPlanningConfig) -> LogicalPlanBuilder: ... def sql_expr(sql: str) -> PyExpr: ... def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ... -def cbrt(expr: PyExpr) -> PyExpr: ... def to_struct(inputs: list[PyExpr]) -> PyExpr: ... +# expr numeric ops +def abs(expr: PyExpr) -> PyExpr: ... +def cbrt(expr: PyExpr) -> PyExpr: ... +def ceil(expr: PyExpr) -> PyExpr: ... +def exp(expr: PyExpr) -> PyExpr: ... +def floor(expr: PyExpr) -> PyExpr: ... +def log2(expr: PyExpr) -> PyExpr: ... +def log10(expr: PyExpr) -> PyExpr: ... +def log(expr: PyExpr, base: float) -> PyExpr: ... +def ln(expr: PyExpr) -> PyExpr: ... +def round(expr: PyExpr, decimal: int) -> PyExpr: ... +def sign(expr: PyExpr) -> PyExpr: ... +def sqrt(expr: PyExpr) -> PyExpr: ... +def sin(expr: PyExpr) -> PyExpr: ... +def cos(expr: PyExpr) -> PyExpr: ... +def tan(expr: PyExpr) -> PyExpr: ... +def cot(expr: PyExpr) -> PyExpr: ... +def arcsin(expr: PyExpr) -> PyExpr: ... +def arccos(expr: PyExpr) -> PyExpr: ... +def arctan(expr: PyExpr) -> PyExpr: ... +def arctan2(expr: PyExpr, other: PyExpr) -> PyExpr: ... +def radians(expr: PyExpr) -> PyExpr: ... +def degrees(expr: PyExpr) -> PyExpr: ... +def arctanh(expr: PyExpr) -> PyExpr: ... +def arccosh(expr: PyExpr) -> PyExpr: ... +def arcsinh(expr: PyExpr) -> PyExpr: ... + # --- # expr.image namespace # --- diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index a2232d9a24..1ae7e90dac 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -328,7 +328,7 @@ def __abs__(self) -> Expression: def abs(self) -> Expression: """Absolute of a numeric expression (``expr.abs()``)""" - return Expression._from_pyexpr(abs(self._expr)) + return Expression._from_pyexpr(native.abs(self._expr)) def __add__(self, other: object) -> Expression: """Adds two numeric expressions or concatenates two string expressions (``e1 + e2``)""" @@ -577,17 +577,17 @@ def cast(self, dtype: DataType) -> Expression: def ceil(self) -> Expression: """The ceiling of a numeric expression (``expr.ceil()``)""" - expr = self._expr.ceil() + expr = native.ceil(self._expr) return Expression._from_pyexpr(expr) def floor(self) -> Expression: """The floor of a numeric expression (``expr.floor()``)""" - expr = self._expr.floor() + expr = native.floor(self._expr) return Expression._from_pyexpr(expr) def sign(self) -> Expression: """The sign of a numeric expression (``expr.sign()``)""" - expr = self._expr.sign() + expr = native.sign(self._expr) return Expression._from_pyexpr(expr) def round(self, decimals: int = 0) -> Expression: @@ -597,12 +597,12 @@ def round(self, decimals: int = 0) -> Expression: decimals: number of decimal places to round to. Defaults to 0. """ assert isinstance(decimals, int) - expr = self._expr.round(decimals) + expr = native.round(self._expr, decimals) return Expression._from_pyexpr(expr) def sqrt(self) -> Expression: """The square root of a numeric expression (``expr.sqrt()``)""" - expr = self._expr.sqrt() + expr = native.sqrt(self._expr) return Expression._from_pyexpr(expr) def cbrt(self) -> Expression: @@ -611,37 +611,37 @@ def cbrt(self) -> Expression: def sin(self) -> Expression: """The elementwise sine of a numeric expression (``expr.sin()``)""" - expr = self._expr.sin() + expr = native.sin(self._expr) return Expression._from_pyexpr(expr) def cos(self) -> Expression: """The elementwise cosine of a numeric expression (``expr.cos()``)""" - expr = self._expr.cos() + expr = native.cos(self._expr) return Expression._from_pyexpr(expr) def tan(self) -> Expression: """The elementwise tangent of a numeric expression (``expr.tan()``)""" - expr = self._expr.tan() + expr = native.tan(self._expr) return Expression._from_pyexpr(expr) def cot(self) -> Expression: """The elementwise cotangent of a numeric expression (``expr.cot()``)""" - expr = self._expr.cot() + expr = native.cot(self._expr) return Expression._from_pyexpr(expr) def arcsin(self) -> Expression: """The elementwise arc sine of a numeric expression (``expr.arcsin()``)""" - expr = self._expr.arcsin() + expr = native.arcsin(self._expr) return Expression._from_pyexpr(expr) def arccos(self) -> Expression: """The elementwise arc cosine of a numeric expression (``expr.arccos()``)""" - expr = self._expr.arccos() + expr = native.arccos(self._expr) return Expression._from_pyexpr(expr) def arctan(self) -> Expression: """The elementwise arc tangent of a numeric expression (``expr.arctan()``)""" - expr = self._expr.arctan() + expr = native.arctan(self._expr) return Expression._from_pyexpr(expr) def arctan2(self, other: Expression) -> Expression: @@ -652,41 +652,41 @@ def arctan2(self, other: Expression) -> Expression: * ``y >= 0``: ``(pi/2, pi]`` * ``y < 0``: ``(-pi, -pi/2)``""" expr = Expression._to_expression(other) - return Expression._from_pyexpr(self._expr.arctan2(expr._expr)) + return Expression._from_pyexpr(native.arctan2(self._expr, expr._expr)) def arctanh(self) -> Expression: """The elementwise inverse hyperbolic tangent of a numeric expression (``expr.arctanh()``)""" - expr = self._expr.arctanh() + expr = native.arctanh(self._expr) return Expression._from_pyexpr(expr) def arccosh(self) -> Expression: """The elementwise inverse hyperbolic cosine of a numeric expression (``expr.arccosh()``)""" - expr = self._expr.arccosh() + expr = native.arccosh(self._expr) return Expression._from_pyexpr(expr) def arcsinh(self) -> Expression: """The elementwise inverse hyperbolic sine of a numeric expression (``expr.arcsinh()``)""" - expr = self._expr.arcsinh() + expr = native.arcsinh(self._expr) return Expression._from_pyexpr(expr) def radians(self) -> Expression: """The elementwise radians of a numeric expression (``expr.radians()``)""" - expr = self._expr.radians() + expr = native.radians(self._expr) return Expression._from_pyexpr(expr) def degrees(self) -> Expression: """The elementwise degrees of a numeric expression (``expr.degrees()``)""" - expr = self._expr.degrees() + expr = native.degrees(self._expr) return Expression._from_pyexpr(expr) def log2(self) -> Expression: """The elementwise log base 2 of a numeric expression (``expr.log2()``)""" - expr = self._expr.log2() + expr = native.log2(self._expr) return Expression._from_pyexpr(expr) def log10(self) -> Expression: """The elementwise log base 10 of a numeric expression (``expr.log10()``)""" - expr = self._expr.log10() + expr = native.log10(self._expr) return Expression._from_pyexpr(expr) def log(self, base: float = math.e) -> Expression: # type: ignore @@ -695,17 +695,17 @@ def log(self, base: float = math.e) -> Expression: # type: ignore base: The base of the logarithm. Defaults to e. """ assert isinstance(base, (int, float)), f"base must be an int or float, but {type(base)} was provided." - expr = self._expr.log(float(base)) + expr = native.log(self._expr, float(base)) return Expression._from_pyexpr(expr) def ln(self) -> Expression: """The elementwise natural log of a numeric expression (``expr.ln()``)""" - expr = self._expr.ln() + expr = native.ln(self._expr) return Expression._from_pyexpr(expr) def exp(self) -> Expression: """The e^self of a numeric expression (``expr.exp()``)""" - expr = self._expr.exp() + expr = native.exp(self._expr) return Expression._from_pyexpr(expr) def bitwise_and(self, other: Expression) -> Expression: diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index 4ff2375b68..0386d7c54c 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -1,5 +1,4 @@ pub mod map; -pub mod numeric; pub mod partitioning; pub mod scalar; pub mod sketch; @@ -17,8 +16,8 @@ pub use scalar::*; use serde::{Deserialize, Serialize}; use self::{ - map::MapExpr, numeric::NumericExpr, partitioning::PartitioningExpr, sketch::SketchExpr, - struct_::StructExpr, utf8::Utf8Expr, + map::MapExpr, partitioning::PartitioningExpr, sketch::SketchExpr, struct_::StructExpr, + utf8::Utf8Expr, }; use crate::{Expr, ExprRef, Operator}; @@ -27,7 +26,6 @@ use python::PythonUDF; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum FunctionExpr { - Numeric(NumericExpr), Utf8(Utf8Expr), Map(MapExpr), Sketch(SketchExpr), @@ -52,7 +50,6 @@ impl FunctionExpr { fn get_evaluator(&self) -> &dyn FunctionEvaluator { use FunctionExpr::*; match self { - Numeric(expr) => expr.get_evaluator(), Utf8(expr) => expr.get_evaluator(), Map(expr) => expr.get_evaluator(), Sketch(expr) => expr.get_evaluator(), diff --git a/src/daft-dsl/src/functions/numeric/abs.rs b/src/daft-dsl/src/functions/numeric/abs.rs deleted file mode 100644 index af8566960d..0000000000 --- a/src/daft-dsl/src/functions/numeric/abs.rs +++ /dev/null @@ -1,40 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct AbsEvaluator {} - -impl FunctionEvaluator for AbsEvaluator { - fn fn_name(&self) -> &'static str { - "abs" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let field = inputs.first().unwrap().to_field(schema)?; - if !field.dtype.is_numeric() { - return Err(DaftError::TypeError(format!( - "Expected input to abs to be numeric, got {}", - field.dtype - ))); - } - Ok(field) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().abs() - } -} diff --git a/src/daft-dsl/src/functions/numeric/ceil.rs b/src/daft-dsl/src/functions/numeric/ceil.rs deleted file mode 100644 index 735ab91bc3..0000000000 --- a/src/daft-dsl/src/functions/numeric/ceil.rs +++ /dev/null @@ -1,40 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct CeilEvaluator {} - -impl FunctionEvaluator for CeilEvaluator { - fn fn_name(&self) -> &'static str { - "ceil" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let field = inputs.first().unwrap().to_field(schema)?; - if !field.dtype.is_numeric() { - return Err(DaftError::TypeError(format!( - "Expected input to ceil to be numeric, got {}", - field.dtype - ))); - } - Ok(field) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().ceil() - } -} diff --git a/src/daft-dsl/src/functions/numeric/exp.rs b/src/daft-dsl/src/functions/numeric/exp.rs deleted file mode 100644 index bde9b90f6f..0000000000 --- a/src/daft-dsl/src/functions/numeric/exp.rs +++ /dev/null @@ -1,46 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use crate::{ - functions::{FunctionEvaluator, FunctionExpr}, - ExprRef, -}; - -pub(super) struct ExpEvaluator {} - -impl FunctionEvaluator for ExpEvaluator { - fn fn_name(&self) -> &'static str { - "exp" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - }; - let field = inputs.first().unwrap().to_field(schema)?; - let dtype = match field.dtype { - DataType::Float32 => DataType::Float32, - dt if dt.is_numeric() => DataType::Float64, - _ => { - return Err(DaftError::TypeError(format!( - "Expected input to compute exp to be numeric, got {}", - field.dtype - ))) - } - }; - Ok(Field::new(field.name, dtype)) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().exp() - } -} diff --git a/src/daft-dsl/src/functions/numeric/floor.rs b/src/daft-dsl/src/functions/numeric/floor.rs deleted file mode 100644 index a76de5fda9..0000000000 --- a/src/daft-dsl/src/functions/numeric/floor.rs +++ /dev/null @@ -1,40 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct FloorEvaluator {} - -impl FunctionEvaluator for FloorEvaluator { - fn fn_name(&self) -> &'static str { - "floor" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let field = inputs.first().unwrap().to_field(schema)?; - if !field.dtype.is_numeric() { - return Err(DaftError::TypeError(format!( - "Expected input to floor to be numeric, got {}", - field.dtype - ))); - } - Ok(field) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().floor() - } -} diff --git a/src/daft-dsl/src/functions/numeric/log.rs b/src/daft-dsl/src/functions/numeric/log.rs deleted file mode 100644 index 9c6105c449..0000000000 --- a/src/daft-dsl/src/functions/numeric/log.rs +++ /dev/null @@ -1,68 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, NumericExpr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) enum LogFunction { - Log2, - Log10, - Log, - Ln, -} -pub(super) struct LogEvaluator(pub LogFunction); - -impl FunctionEvaluator for LogEvaluator { - fn fn_name(&self) -> &'static str { - match self.0 { - LogFunction::Log2 => "log2", - LogFunction::Log10 => "log10", - LogFunction::Log => "log", - LogFunction::Ln => "ln", - } - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let field = inputs.first().unwrap().to_field(schema)?; - let dtype = match field.dtype { - DataType::Float32 => DataType::Float32, - dt if dt.is_numeric() => DataType::Float64, - _ => { - return Err(DaftError::TypeError(format!( - "Expected input to log to be numeric, got {}", - field.dtype - ))) - } - }; - Ok(Field::new(field.name, dtype)) - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let input = inputs.first().unwrap(); - match self.0 { - LogFunction::Log2 => input.log2(), - LogFunction::Log10 => input.log10(), - LogFunction::Log => { - let base = match expr { - FunctionExpr::Numeric(NumericExpr::Log(value)) => value, - _ => panic!("Expected Log Expr, got {expr}"), - }; - - input.log(base.0) - } - LogFunction::Ln => input.ln(), - } - } -} diff --git a/src/daft-dsl/src/functions/numeric/mod.rs b/src/daft-dsl/src/functions/numeric/mod.rs deleted file mode 100644 index 98ab6d3f32..0000000000 --- a/src/daft-dsl/src/functions/numeric/mod.rs +++ /dev/null @@ -1,283 +0,0 @@ -mod abs; -mod ceil; -mod exp; -mod floor; -mod log; -mod round; -mod sign; -mod sqrt; -mod trigonometry; - -use std::hash::Hash; - -use abs::AbsEvaluator; -use ceil::CeilEvaluator; -use common_hashable_float_wrapper::FloatWrapper; -use floor::FloorEvaluator; -use log::LogEvaluator; -use round::RoundEvaluator; -use serde::{Deserialize, Serialize}; -use sign::SignEvaluator; -use sqrt::SqrtEvaluator; -use trigonometry::Atan2Evaluator; - -use super::FunctionEvaluator; -use crate::{ - functions::numeric::{ - exp::ExpEvaluator, - trigonometry::{TrigonometricFunction, TrigonometryEvaluator}, - }, - Expr, ExprRef, -}; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub enum NumericExpr { - Abs, - Ceil, - Floor, - Sign, - Round(i32), - Sqrt, - Sin, - Cos, - Tan, - Cot, - ArcSin, - ArcCos, - ArcTan, - ArcTan2, - Radians, - Degrees, - Log2, - Log10, - Log(FloatWrapper), - Ln, - Exp, - ArcTanh, - ArcCosh, - ArcSinh, -} - -impl NumericExpr { - #[inline] - pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - match self { - NumericExpr::Abs => &AbsEvaluator {}, - NumericExpr::Ceil => &CeilEvaluator {}, - NumericExpr::Floor => &FloorEvaluator {}, - NumericExpr::Sign => &SignEvaluator {}, - NumericExpr::Round(_) => &RoundEvaluator {}, - NumericExpr::Sqrt => &SqrtEvaluator {}, - NumericExpr::Sin => &TrigonometryEvaluator(TrigonometricFunction::Sin), - NumericExpr::Cos => &TrigonometryEvaluator(TrigonometricFunction::Cos), - NumericExpr::Tan => &TrigonometryEvaluator(TrigonometricFunction::Tan), - NumericExpr::Cot => &TrigonometryEvaluator(TrigonometricFunction::Cot), - NumericExpr::ArcSin => &TrigonometryEvaluator(TrigonometricFunction::ArcSin), - NumericExpr::ArcCos => &TrigonometryEvaluator(TrigonometricFunction::ArcCos), - NumericExpr::ArcTan => &TrigonometryEvaluator(TrigonometricFunction::ArcTan), - NumericExpr::ArcTan2 => &Atan2Evaluator {}, - NumericExpr::Radians => &TrigonometryEvaluator(TrigonometricFunction::Radians), - NumericExpr::Degrees => &TrigonometryEvaluator(TrigonometricFunction::Degrees), - NumericExpr::Log2 => &LogEvaluator(log::LogFunction::Log2), - NumericExpr::Log10 => &LogEvaluator(log::LogFunction::Log10), - NumericExpr::Log(_) => &LogEvaluator(log::LogFunction::Log), - NumericExpr::Ln => &LogEvaluator(log::LogFunction::Ln), - NumericExpr::Exp => &ExpEvaluator {}, - NumericExpr::ArcTanh => &TrigonometryEvaluator(TrigonometricFunction::ArcTanh), - NumericExpr::ArcCosh => &TrigonometryEvaluator(TrigonometricFunction::ArcCosh), - NumericExpr::ArcSinh => &TrigonometryEvaluator(TrigonometricFunction::ArcSinh), - } - } -} - -pub fn abs(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Abs), - inputs: vec![input], - } - .into() -} - -pub fn ceil(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Ceil), - inputs: vec![input], - } - .into() -} - -pub fn floor(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Floor), - inputs: vec![input], - } - .into() -} - -pub fn sign(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Sign), - inputs: vec![input], - } - .into() -} - -pub fn round(input: ExprRef, decimal: i32) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Round(decimal)), - inputs: vec![input], - } - .into() -} - -pub fn sqrt(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Sqrt), - inputs: vec![input], - } - .into() -} - -pub fn sin(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Sin), - inputs: vec![input], - } - .into() -} - -pub fn cos(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Cos), - inputs: vec![input], - } - .into() -} - -pub fn tan(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Tan), - inputs: vec![input], - } - .into() -} - -pub fn cot(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Cot), - inputs: vec![input], - } - .into() -} - -pub fn arcsin(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcSin), - inputs: vec![input], - } - .into() -} - -pub fn arccos(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcCos), - inputs: vec![input], - } - .into() -} - -pub fn arctan(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcTan), - inputs: vec![input], - } - .into() -} - -pub fn arctan2(input: ExprRef, other: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcTan2), - inputs: vec![input, other], - } - .into() -} - -pub fn radians(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Radians), - inputs: vec![input], - } - .into() -} - -pub fn degrees(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Degrees), - inputs: vec![input], - } - .into() -} - -pub fn arctanh(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcTanh), - inputs: vec![input], - } - .into() -} - -pub fn arccosh(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcCosh), - inputs: vec![input], - } - .into() -} - -pub fn arcsinh(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::ArcSinh), - inputs: vec![input], - } - .into() -} - -pub fn log2(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Log2), - inputs: vec![input], - } - .into() -} - -pub fn log10(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Log10), - inputs: vec![input], - } - .into() -} - -pub fn log(input: ExprRef, base: f64) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Log(FloatWrapper(base))), - inputs: vec![input], - } - .into() -} - -pub fn ln(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Ln), - inputs: vec![input], - } - .into() -} - -pub fn exp(input: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Numeric(NumericExpr::Exp), - inputs: vec![input], - } - .into() -} diff --git a/src/daft-dsl/src/functions/numeric/round.rs b/src/daft-dsl/src/functions/numeric/round.rs deleted file mode 100644 index aee8fa29c6..0000000000 --- a/src/daft-dsl/src/functions/numeric/round.rs +++ /dev/null @@ -1,44 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, NumericExpr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct RoundEvaluator {} - -impl FunctionEvaluator for RoundEvaluator { - fn fn_name(&self) -> &'static str { - "round" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let field = inputs.first().unwrap().to_field(schema)?; - if !field.dtype.is_numeric() { - return Err(DaftError::TypeError(format!( - "Expected input to round to be numeric, got {}", - field.dtype - ))); - } - Ok(field) - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let decimal = match expr { - FunctionExpr::Numeric(NumericExpr::Round(index)) => index, - _ => panic!("Expected Round Expr, got {expr}"), - }; - inputs.first().unwrap().round(*decimal) - } -} diff --git a/src/daft-dsl/src/functions/numeric/sign.rs b/src/daft-dsl/src/functions/numeric/sign.rs deleted file mode 100644 index 20dafba799..0000000000 --- a/src/daft-dsl/src/functions/numeric/sign.rs +++ /dev/null @@ -1,40 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct SignEvaluator {} - -impl FunctionEvaluator for SignEvaluator { - fn fn_name(&self) -> &'static str { - "sign" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - let field = inputs.first().unwrap().to_field(schema)?; - if !field.dtype.is_numeric() { - return Err(DaftError::TypeError(format!( - "Expected input to sign to be numeric, got {}", - field.dtype - ))); - } - Ok(field) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().sign() - } -} diff --git a/src/daft-dsl/src/functions/numeric/sqrt.rs b/src/daft-dsl/src/functions/numeric/sqrt.rs deleted file mode 100644 index 248ce1de7e..0000000000 --- a/src/daft-dsl/src/functions/numeric/sqrt.rs +++ /dev/null @@ -1,37 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct SqrtEvaluator {} - -impl FunctionEvaluator for SqrtEvaluator { - fn fn_name(&self) -> &'static str { - "sqrt" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [first] => { - let field = first.to_field(schema)?; - let dtype = field.dtype.to_floating_representation()?; - Ok(Field::new(field.name, dtype)) - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [first] => first.sqrt(), - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/numeric/trigonometry.rs b/src/daft-dsl/src/functions/numeric/trigonometry.rs deleted file mode 100644 index 9779e802d4..0000000000 --- a/src/daft-dsl/src/functions/numeric/trigonometry.rs +++ /dev/null @@ -1,87 +0,0 @@ -use common_error::{DaftError, DaftResult}; -pub use daft_core::array::ops::trigonometry::TrigonometricFunction; -use daft_core::prelude::*; - -use crate::{ - functions::{FunctionEvaluator, FunctionExpr}, - ExprRef, -}; - -pub(super) struct TrigonometryEvaluator(pub TrigonometricFunction); - -impl FunctionEvaluator for TrigonometryEvaluator { - fn fn_name(&self) -> &'static str { - self.0.fn_name() - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - }; - let field = inputs.first().unwrap().to_field(schema)?; - let dtype = match field.dtype { - DataType::Float32 => DataType::Float32, - dt if dt.is_numeric() => DataType::Float64, - _ => { - return Err(DaftError::TypeError(format!( - "Expected input to trigonometry to be numeric, got {}", - field.dtype - ))) - } - }; - Ok(Field::new(field.name, dtype)) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 1 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().trigonometry(&self.0) - } -} - -pub(super) struct Atan2Evaluator {} - -impl FunctionEvaluator for Atan2Evaluator { - fn fn_name(&self) -> &'static str { - "atan2" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - if inputs.len() != 2 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", - inputs.len() - ))); - } - let field1 = inputs.first().unwrap().to_field(schema)?; - let field2 = inputs.get(1).unwrap().to_field(schema)?; - let dtype = match (field1.dtype, field2.dtype) { - (DataType::Float32, DataType::Float32) => DataType::Float32, - (dt1, dt2) if dt1.is_numeric() && dt2.is_numeric() => DataType::Float64, - (dt1, dt2) => { - return Err(DaftError::TypeError(format!( - "Expected inputs to atan2 to be numeric, got {} and {}", - dt1, dt2 - ))) - } - }; - Ok(Field::new(field1.name, dtype)) - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - if inputs.len() != 2 { - return Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", - inputs.len() - ))); - } - inputs.first().unwrap().atan2(inputs.get(1).unwrap()) - } -} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 7c5a1d7930..af56dc68d8 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -1,5 +1,4 @@ #![allow(non_snake_case)] - use std::{ collections::{hash_map::DefaultHasher, HashMap}, hash::{Hash, Hasher}, @@ -22,7 +21,7 @@ use pyo3::{ }; use serde::{Deserialize, Serialize}; -use crate::{functions, Expr, ExprRef, LiteralValue}; +use crate::{Expr, ExprRef, LiteralValue}; #[pyfunction] pub fn col(name: &str) -> PyResult { @@ -274,126 +273,6 @@ impl PyExpr { Ok(self.expr.clone().cast(&dtype.into()).into()) } - pub fn ceil(&self) -> PyResult { - use functions::numeric::ceil; - Ok(ceil(self.into()).into()) - } - - pub fn floor(&self) -> PyResult { - use functions::numeric::floor; - Ok(floor(self.into()).into()) - } - - pub fn sign(&self) -> PyResult { - use functions::numeric::sign; - Ok(sign(self.into()).into()) - } - - pub fn round(&self, decimal: i32) -> PyResult { - use functions::numeric::round; - if decimal < 0 { - return Err(PyValueError::new_err(format!( - "decimal can not be negative: {decimal}" - ))); - } - Ok(round(self.into(), decimal).into()) - } - - pub fn sqrt(&self) -> PyResult { - use functions::numeric::sqrt; - Ok(sqrt(self.into()).into()) - } - - pub fn sin(&self) -> PyResult { - use functions::numeric::sin; - Ok(sin(self.into()).into()) - } - - pub fn cos(&self) -> PyResult { - use functions::numeric::cos; - Ok(cos(self.into()).into()) - } - - pub fn tan(&self) -> PyResult { - use functions::numeric::tan; - Ok(tan(self.into()).into()) - } - - pub fn cot(&self) -> PyResult { - use functions::numeric::cot; - Ok(cot(self.into()).into()) - } - - pub fn arcsin(&self) -> PyResult { - use functions::numeric::arcsin; - Ok(arcsin(self.into()).into()) - } - - pub fn arccos(&self) -> PyResult { - use functions::numeric::arccos; - Ok(arccos(self.into()).into()) - } - - pub fn arctan(&self) -> PyResult { - use functions::numeric::arctan; - Ok(arctan(self.into()).into()) - } - - pub fn arctan2(&self, other: &Self) -> PyResult { - use functions::numeric::arctan2; - Ok(arctan2(self.into(), other.expr.clone()).into()) - } - - pub fn radians(&self) -> PyResult { - use functions::numeric::radians; - Ok(radians(self.into()).into()) - } - - pub fn degrees(&self) -> PyResult { - use functions::numeric::degrees; - Ok(degrees(self.into()).into()) - } - - pub fn arctanh(&self) -> PyResult { - use functions::numeric::arctanh; - Ok(arctanh(self.into()).into()) - } - - pub fn arccosh(&self) -> PyResult { - use functions::numeric::arccosh; - Ok(arccosh(self.into()).into()) - } - - pub fn arcsinh(&self) -> PyResult { - use functions::numeric::arcsinh; - Ok(arcsinh(self.into()).into()) - } - - pub fn log2(&self) -> PyResult { - use functions::numeric::log2; - Ok(log2(self.into()).into()) - } - - pub fn log10(&self) -> PyResult { - use functions::numeric::log10; - Ok(log10(self.into()).into()) - } - - pub fn log(&self, base: f64) -> PyResult { - use functions::numeric::log; - Ok(log(self.into(), base).into()) - } - - pub fn ln(&self) -> PyResult { - use functions::numeric::ln; - Ok(ln(self.into()).into()) - } - - pub fn exp(&self) -> PyResult { - use functions::numeric::exp; - Ok(exp(self.into()).into()) - } - pub fn if_else(&self, if_true: &Self, if_false: &Self) -> PyResult { Ok(self .expr @@ -460,10 +339,6 @@ impl PyExpr { Ok(self.expr.clone().agg_concat().into()) } - pub fn __abs__(&self) -> PyResult { - use functions::numeric::abs; - Ok(abs(self.into()).into()) - } pub fn __add__(&self, other: &Self) -> PyResult { Ok(crate::binary_op(crate::Operator::Plus, self.into(), other.expr.clone()).into()) } diff --git a/src/daft-functions/Cargo.toml b/src/daft-functions/Cargo.toml index b965e9f417..9be2a86dc2 100644 --- a/src/daft-functions/Cargo.toml +++ b/src/daft-functions/Cargo.toml @@ -2,6 +2,7 @@ arrow2 = {workspace = true} base64 = {workspace = true} common-error = {path = "../common/error", default-features = false} +common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"} common-io-config = {path = "../common/io-config", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} diff --git a/src/daft-functions/src/lib.rs b/src/daft-functions/src/lib.rs index 0976f17c21..0a8486864e 100644 --- a/src/daft-functions/src/lib.rs +++ b/src/daft-functions/src/lib.rs @@ -31,7 +31,6 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction_bound!(hash::python::hash, parent)?)?; parent.add_function(wrap_pyfunction_bound!(minhash::python::minhash, parent)?)?; - parent.add_function(wrap_pyfunction_bound!(numeric::cbrt::python::cbrt, parent)?)?; parent.add_function(wrap_pyfunction_bound!( to_struct::python::to_struct, parent @@ -46,6 +45,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { )?)?; parent.add_function(wrap_pyfunction_bound!(uri::python::url_download, parent)?)?; parent.add_function(wrap_pyfunction_bound!(uri::python::url_upload, parent)?)?; + numeric::register_modules(parent)?; image::register_modules(parent)?; float::register_modules(parent)?; temporal::register_modules(parent)?; diff --git a/src/daft-functions/src/numeric/abs.rs b/src/daft-functions/src/numeric/abs.rs new file mode 100644 index 0000000000..f054950e0f --- /dev/null +++ b/src/daft-functions/src/numeric/abs.rs @@ -0,0 +1,48 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +use super::{evaluate_single_numeric, to_field_single_numeric}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Abs {} + +#[typetag::serde] +impl ScalarUDF for Abs { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "abs" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + to_field_single_numeric(self, inputs, schema) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::abs) + } +} + +pub fn abs(input: ExprRef) -> ExprRef { + ScalarFunction::new(Abs {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "abs")] +pub fn py_abs(expr: PyExpr) -> PyResult { + Ok(abs(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/cbrt.rs b/src/daft-functions/src/numeric/cbrt.rs index 7f2e635689..c9b4e9286f 100644 --- a/src/daft-functions/src/numeric/cbrt.rs +++ b/src/daft-functions/src/numeric/cbrt.rs @@ -1,13 +1,17 @@ -use common_error::{DaftError, DaftResult}; +use common_error::DaftResult; use daft_core::prelude::*; -use daft_dsl::{functions::ScalarUDF, ExprRef}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -struct CbrtFunction; +pub struct Cbrt; +use super::{evaluate_single_numeric, to_field_single_floating}; #[typetag::serde] -impl ScalarUDF for CbrtFunction { +impl ScalarUDF for Cbrt { fn as_any(&self) -> &dyn std::any::Any { self } @@ -17,41 +21,26 @@ impl ScalarUDF for CbrtFunction { } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { - match inputs { - [input] => { - let field = input.to_field(schema)?; - let dtype = field.dtype.to_floating_representation()?; - Ok(Field::new(field.name, dtype)) - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } + to_field_single_floating(self, inputs, schema) } fn evaluate(&self, inputs: &[Series]) -> DaftResult { - match inputs { - [input] => input.cbrt(), - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } + evaluate_single_numeric(inputs, Series::cbrt) } } -#[cfg(feature = "python")] -pub mod python { - use daft_dsl::{functions::ScalarFunction, python::PyExpr, ExprRef}; - use pyo3::{pyfunction, PyResult}; - - use super::CbrtFunction; +pub fn cbrt(input: ExprRef) -> ExprRef { + ScalarFunction::new(Cbrt {}, vec![input]).into() +} - #[pyfunction] - pub fn cbrt(expr: PyExpr) -> PyResult { - let scalar_function = ScalarFunction::new(CbrtFunction, vec![expr.into()]); - let expr = ExprRef::from(scalar_function); - Ok(expr.into()) - } +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "cbrt")] +pub fn py_cbrt(expr: PyExpr) -> PyResult { + Ok(cbrt(expr.into()).into()) } diff --git a/src/daft-functions/src/numeric/ceil.rs b/src/daft-functions/src/numeric/ceil.rs new file mode 100644 index 0000000000..26c37bec6b --- /dev/null +++ b/src/daft-functions/src/numeric/ceil.rs @@ -0,0 +1,49 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Ceil {} + +#[typetag::serde] +impl ScalarUDF for Ceil { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "ceil" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + to_field_single_numeric(self, inputs, schema) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::ceil) + } +} + +pub fn ceil(input: ExprRef) -> ExprRef { + ScalarFunction::new(Ceil {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +use super::{evaluate_single_numeric, to_field_single_numeric}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "ceil")] +pub fn py_ceil(expr: PyExpr) -> PyResult { + Ok(ceil(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/exp.rs b/src/daft-functions/src/numeric/exp.rs new file mode 100644 index 0000000000..abde081b46 --- /dev/null +++ b/src/daft-functions/src/numeric/exp.rs @@ -0,0 +1,67 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +use super::evaluate_single_numeric; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Exp {} + +#[typetag::serde] +impl ScalarUDF for Exp { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "exp" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + let field = inputs.first().unwrap().to_field(schema)?; + let dtype = match field.dtype { + DataType::Float32 => DataType::Float32, + dt if dt.is_numeric() => DataType::Float64, + _ => { + return Err(DaftError::TypeError(format!( + "Expected input to compute exp to be numeric, got {}", + field.dtype + ))) + } + }; + Ok(Field::new(field.name, dtype)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::exp) + } +} + +pub fn exp(input: ExprRef) -> ExprRef { + ScalarFunction::new(Exp {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "exp")] +pub fn py_exp(expr: PyExpr) -> PyResult { + Ok(exp(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/floor.rs b/src/daft-functions/src/numeric/floor.rs new file mode 100644 index 0000000000..36ec365e0f --- /dev/null +++ b/src/daft-functions/src/numeric/floor.rs @@ -0,0 +1,50 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Floor {} + +#[typetag::serde] +impl ScalarUDF for Floor { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "floor" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + to_field_single_numeric(self, inputs, schema) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::floor) + } +} + +pub fn floor(input: ExprRef) -> ExprRef { + ScalarFunction::new(Floor {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +use super::{evaluate_single_numeric, to_field_single_numeric}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "floor")] +pub fn py_floor(expr: PyExpr) -> PyResult { + Ok(floor(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/log.rs b/src/daft-functions/src/numeric/log.rs new file mode 100644 index 0000000000..7aecb2de56 --- /dev/null +++ b/src/daft-functions/src/numeric/log.rs @@ -0,0 +1,142 @@ +use common_error::{DaftError, DaftResult}; +use common_hashable_float_wrapper::FloatWrapper; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +// super annoying, but using an enum with typetag::serde doesn't work with bincode because it uses Deserializer::deserialize_identifier +macro_rules! log { + ($name:ident, $variant:ident) => { + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] + pub struct $variant; + + #[typetag::serde] + impl ScalarUDF for $variant { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + stringify!($name) + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + let field = inputs.first().unwrap().to_field(schema)?; + let dtype = match field.dtype { + DataType::Float32 => DataType::Float32, + dt if dt.is_numeric() => DataType::Float64, + _ => { + return Err(DaftError::TypeError(format!( + "Expected input to log to be numeric, got {}", + field.dtype + ))) + } + }; + Ok(Field::new(field.name, dtype)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::$name) + } + } + + pub fn $name(input: ExprRef) -> ExprRef { + ScalarFunction::new($variant, vec![input]).into() + } + }; +} + +log!(log2, Log2); +log!(log10, Log10); +log!(ln, Ln); + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Log(FloatWrapper); + +#[typetag::serde] +impl ScalarUDF for Log { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "log" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + let field = inputs.first().unwrap().to_field(schema)?; + let dtype = match field.dtype { + DataType::Float32 => DataType::Float32, + dt if dt.is_numeric() => DataType::Float64, + _ => { + return Err(DaftError::TypeError(format!( + "Expected input to log to be numeric, got {}", + field.dtype + ))) + } + }; + Ok(Field::new(field.name, dtype)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, |x| x.log(self.0 .0)) + } +} + +pub fn log(input: ExprRef, base: f64) -> ExprRef { + ScalarFunction::new(Log(FloatWrapper(base)), vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +use super::evaluate_single_numeric; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "log2")] +pub fn py_log2(expr: PyExpr) -> PyResult { + Ok(log2(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "log10")] +pub fn py_log10(expr: PyExpr) -> PyResult { + Ok(log10(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "log")] +pub fn py_log(expr: PyExpr, base: f64) -> PyResult { + Ok(log(expr.into(), base).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "ln")] +pub fn py_ln(expr: PyExpr) -> PyResult { + Ok(ln(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/mod.rs b/src/daft-functions/src/numeric/mod.rs index f78822dd52..28d0e50cad 100644 --- a/src/daft-functions/src/numeric/mod.rs +++ b/src/daft-functions/src/numeric/mod.rs @@ -1 +1,104 @@ +pub mod abs; pub mod cbrt; +pub mod ceil; +pub mod exp; +pub mod floor; +pub mod log; +pub mod round; +pub mod sign; +pub mod sqrt; +pub mod trigonometry; + +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{functions::ScalarUDF, ExprRef}; +#[cfg(feature = "python")] +use pyo3::prelude::*; + +#[cfg(feature = "python")] +pub fn register_modules(parent: &Bound) -> PyResult<()> { + parent.add_function(wrap_pyfunction_bound!(abs::py_abs, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(cbrt::py_cbrt, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(ceil::py_ceil, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(exp::py_exp, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(floor::py_floor, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(log::py_log2, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(log::py_log10, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(log::py_log, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(log::py_ln, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(round::py_round, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(sign::py_sign, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(sqrt::py_sqrt, parent)?)?; + + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_sin, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_cos, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_tan, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_cot, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arcsin, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arccos, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arctan, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_radians, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_degrees, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arctanh, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arccosh, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arcsinh, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(trigonometry::py_arctan2, parent)?)?; + + Ok(()) +} + +fn to_field_single_numeric( + f: &dyn ScalarUDF, + inputs: &[ExprRef], + schema: &Schema, +) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + let field = inputs.first().unwrap().to_field(schema)?; + if !field.dtype.is_numeric() { + return Err(DaftError::TypeError(format!( + "Expected input to {} to be numeric, got {}", + f.name(), + field.dtype + ))); + } + Ok(field) +} + +fn to_field_single_floating( + f: &dyn ScalarUDF, + inputs: &[ExprRef], + schema: &Schema, +) -> DaftResult { + match inputs { + [first] => { + let field = first.to_field(schema)?; + let dtype = field.dtype.to_floating_representation()?; + Ok(Field::new(field.name, dtype)) + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg for {}, got {}", + f.name(), + inputs.len() + ))), + } +} +fn evaluate_single_numeric DaftResult>( + inputs: &[Series], + func: F, +) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + func(inputs.first().unwrap()) +} diff --git a/src/daft-functions/src/numeric/round.rs b/src/daft-functions/src/numeric/round.rs new file mode 100644 index 0000000000..395b0ee696 --- /dev/null +++ b/src/daft-functions/src/numeric/round.rs @@ -0,0 +1,59 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Round { + decimal: i32, +} + +#[typetag::serde] +impl ScalarUDF for Round { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "round" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + to_field_single_numeric(self, inputs, schema) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, |s| s.round(self.decimal)) + } +} + +pub fn round(input: ExprRef, decimal: i32) -> ExprRef { + ScalarFunction::new(Round { decimal }, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +use super::{evaluate_single_numeric, to_field_single_numeric}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "round")] +pub fn py_round(expr: PyExpr, decimal: i32) -> PyResult { + use pyo3::exceptions::PyValueError; + + if decimal < 0 { + return Err(PyValueError::new_err(format!( + "decimal can not be negative: {decimal}" + ))); + } + Ok(round(expr.into(), decimal).into()) +} diff --git a/src/daft-functions/src/numeric/sign.rs b/src/daft-functions/src/numeric/sign.rs new file mode 100644 index 0000000000..a58b7f294d --- /dev/null +++ b/src/daft-functions/src/numeric/sign.rs @@ -0,0 +1,50 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Sign {} + +#[typetag::serde] +impl ScalarUDF for Sign { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "sign" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + to_field_single_numeric(self, inputs, schema) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::sign) + } +} + +pub fn sign(input: ExprRef) -> ExprRef { + ScalarFunction::new(Sign {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +use super::{evaluate_single_numeric, to_field_single_numeric}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "sign")] +pub fn py_sign(expr: PyExpr) -> PyResult { + Ok(sign(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/sqrt.rs b/src/daft-functions/src/numeric/sqrt.rs new file mode 100644 index 0000000000..11766e4f17 --- /dev/null +++ b/src/daft-functions/src/numeric/sqrt.rs @@ -0,0 +1,51 @@ +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +use super::{evaluate_single_numeric, to_field_single_floating}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Sqrt {} + +#[typetag::serde] +impl ScalarUDF for Sqrt { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "sqrt" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + to_field_single_floating(self, inputs, schema) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, Series::sqrt) + } +} + +pub fn sqrt(input: ExprRef) -> ExprRef { + ScalarFunction::new(Sqrt {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "sqrt")] +pub fn py_sqrt(expr: PyExpr) -> PyResult { + Ok(sqrt(expr.into()).into()) +} diff --git a/src/daft-functions/src/numeric/trigonometry.rs b/src/daft-functions/src/numeric/trigonometry.rs new file mode 100644 index 0000000000..9a47875596 --- /dev/null +++ b/src/daft-functions/src/numeric/trigonometry.rs @@ -0,0 +1,222 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + array::ops::trigonometry::TrigonometricFunction, + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +use super::evaluate_single_numeric; + +// super annoying, but using an enum with typetag::serde doesn't work with bincode because it uses Deserializer::deserialize_identifier +macro_rules! trigonometry { + ($name:ident, $variant:ident) => { + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] + pub struct $variant; + + #[typetag::serde] + impl ScalarUDF for $variant { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + TrigonometricFunction::$variant.fn_name() + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + let field = inputs.first().unwrap().to_field(schema)?; + let dtype = match field.dtype { + DataType::Float32 => DataType::Float32, + dt if dt.is_numeric() => DataType::Float64, + _ => { + return Err(DaftError::TypeError(format!( + "Expected input to trigonometry to be numeric, got {}", + field.dtype + ))) + } + }; + Ok(Field::new(field.name, dtype)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + evaluate_single_numeric(inputs, |s| { + s.trigonometry(&TrigonometricFunction::$variant) + }) + } + } + + pub fn $name(input: ExprRef) -> ExprRef { + ScalarFunction::new($variant, vec![input]).into() + } + }; +} + +trigonometry!(sin, Sin); +trigonometry!(cos, Cos); +trigonometry!(tan, Tan); +trigonometry!(cot, Cot); +trigonometry!(arcsin, ArcSin); +trigonometry!(arccos, ArcCos); +trigonometry!(arctan, ArcTan); +trigonometry!(radians, Radians); +trigonometry!(degrees, Degrees); +trigonometry!(arctanh, ArcTanh); +trigonometry!(arccosh, ArcCosh); +trigonometry!(arcsinh, ArcSinh); + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Atan2 {} + +#[typetag::serde] +impl ScalarUDF for Atan2 { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "atan2" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + if inputs.len() != 2 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))); + } + let field1 = inputs.first().unwrap().to_field(schema)?; + let field2 = inputs.get(1).unwrap().to_field(schema)?; + let dtype = match (field1.dtype, field2.dtype) { + (DataType::Float32, DataType::Float32) => DataType::Float32, + (dt1, dt2) if dt1.is_numeric() && dt2.is_numeric() => DataType::Float64, + (dt1, dt2) => { + return Err(DaftError::TypeError(format!( + "Expected inputs to atan2 to be numeric, got {} and {}", + dt1, dt2 + ))) + } + }; + Ok(Field::new(field1.name, dtype)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [x, y] => x.atan2(y), + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} + +pub fn atan2(x: ExprRef, y: ExprRef) -> ExprRef { + ScalarFunction::new(Atan2 {}, vec![x, y]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "sin")] +pub fn py_sin(expr: PyExpr) -> PyResult { + Ok(sin(expr.into()).into()) +} +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "cos")] +pub fn py_cos(expr: PyExpr) -> PyResult { + Ok(cos(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "tan")] +pub fn py_tan(expr: PyExpr) -> PyResult { + Ok(tan(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "cot")] +pub fn py_cot(expr: PyExpr) -> PyResult { + Ok(cot(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arcsin")] +pub fn py_arcsin(expr: PyExpr) -> PyResult { + Ok(arcsin(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arccos")] +pub fn py_arccos(expr: PyExpr) -> PyResult { + Ok(arccos(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arctan")] +pub fn py_arctan(expr: PyExpr) -> PyResult { + Ok(arctan(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "radians")] +pub fn py_radians(expr: PyExpr) -> PyResult { + Ok(radians(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "degrees")] +pub fn py_degrees(expr: PyExpr) -> PyResult { + Ok(degrees(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arctanh")] +pub fn py_arctanh(expr: PyExpr) -> PyResult { + Ok(arctanh(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arccosh")] +pub fn py_arccosh(expr: PyExpr) -> PyResult { + Ok(arccosh(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arcsinh")] +pub fn py_arcsinh(expr: PyExpr) -> PyResult { + Ok(arcsinh(expr.into()).into()) +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "arctan2")] +pub fn py_arctan2(x: PyExpr, y: PyExpr) -> PyResult { + Ok(atan2(x.into(), y.into()).into()) +} diff --git a/src/daft-sql/src/modules/numeric.rs b/src/daft-sql/src/modules/numeric.rs index 078878faef..197d958860 100644 --- a/src/daft-sql/src/modules/numeric.rs +++ b/src/daft-sql/src/modules/numeric.rs @@ -1,6 +1,17 @@ -use daft_dsl::{ - functions::{self, numeric::NumericExpr}, - ExprRef, LiteralValue, +use daft_dsl::{ExprRef, LiteralValue}; +use daft_functions::numeric::{ + abs::abs, + ceil::ceil, + exp::exp, + floor::floor, + log::{ln, log, log10, log2}, + round::round, + sign::sign, + sqrt::sqrt, + trigonometry::{ + arccos, arccosh, arcsin, arcsinh, arctan, arctanh, atan2, cos, cot, degrees, radians, sin, + tan, + }, }; use super::SQLModule; @@ -13,38 +24,62 @@ use crate::{ pub struct SQLModuleNumeric; -/// SQLModule for FunctionExpr::Numeric impl SQLModule for SQLModuleNumeric { fn register(parent: &mut SQLFunctions) { - use NumericExpr::*; - parent.add_fn("abs", Abs); - parent.add_fn("ceil", Ceil); - parent.add_fn("floor", Floor); - parent.add_fn("sign", Sign); - parent.add_fn("round", Round(0)); - parent.add_fn("sqrt", Sqrt); - parent.add_fn("sin", Sin); - parent.add_fn("cos", Cos); - parent.add_fn("tan", Tan); - parent.add_fn("cot", Cot); - parent.add_fn("asin", ArcSin); - parent.add_fn("acos", ArcCos); - parent.add_fn("atan", ArcTan); - parent.add_fn("atan2", ArcTan2); - parent.add_fn("radians", Radians); - parent.add_fn("degrees", Degrees); - parent.add_fn("log2", Log2); - parent.add_fn("log10", Log10); - // parent.add("log", f(Log(FloatWrapper(0.0)))); - parent.add_fn("ln", Ln); - parent.add_fn("exp", Exp); - parent.add_fn("atanh", ArcTanh); - parent.add_fn("acosh", ArcCosh); - parent.add_fn("asinh", ArcSinh); + parent.add_fn("abs", SQLNumericExpr::Abs); + parent.add_fn("ceil", SQLNumericExpr::Ceil); + parent.add_fn("floor", SQLNumericExpr::Floor); + parent.add_fn("sign", SQLNumericExpr::Sign); + parent.add_fn("round", SQLNumericExpr::Round); + parent.add_fn("sqrt", SQLNumericExpr::Sqrt); + parent.add_fn("sin", SQLNumericExpr::Sin); + parent.add_fn("cos", SQLNumericExpr::Cos); + parent.add_fn("tan", SQLNumericExpr::Tan); + parent.add_fn("cot", SQLNumericExpr::Cot); + parent.add_fn("asin", SQLNumericExpr::ArcSin); + parent.add_fn("acos", SQLNumericExpr::ArcCos); + parent.add_fn("atan", SQLNumericExpr::ArcTan); + parent.add_fn("atan2", SQLNumericExpr::ArcTan2); + parent.add_fn("radians", SQLNumericExpr::Radians); + parent.add_fn("degrees", SQLNumericExpr::Degrees); + parent.add_fn("log2", SQLNumericExpr::Log2); + parent.add_fn("log10", SQLNumericExpr::Log10); + parent.add_fn("log", SQLNumericExpr::Log); + parent.add_fn("ln", SQLNumericExpr::Ln); + parent.add_fn("exp", SQLNumericExpr::Exp); + parent.add_fn("atanh", SQLNumericExpr::ArcTanh); + parent.add_fn("acosh", SQLNumericExpr::ArcCosh); + parent.add_fn("asinh", SQLNumericExpr::ArcSinh); } } +enum SQLNumericExpr { + Abs, + Ceil, + Exp, + Floor, + Round, + Sign, + Sqrt, + Sin, + Cos, + Tan, + Cot, + ArcSin, + ArcCos, + ArcTan, + ArcTan2, + Radians, + Degrees, + Log, + Log2, + Log10, + Ln, + ArcTanh, + ArcCosh, + ArcSinh, +} -impl SQLFunction for NumericExpr { +impl SQLFunction for SQLNumericExpr { fn to_expr( &self, inputs: &[sqlparser::ast::FunctionArg], @@ -54,27 +89,26 @@ impl SQLFunction for NumericExpr { to_expr(self, inputs.as_slice()) } } -fn to_expr(expr: &NumericExpr, args: &[ExprRef]) -> SQLPlannerResult { - use functions::numeric::*; - use NumericExpr::*; + +fn to_expr(expr: &SQLNumericExpr, args: &[ExprRef]) -> SQLPlannerResult { match expr { - Abs => { + SQLNumericExpr::Abs => { ensure!(args.len() == 1, "abs takes exactly one argument"); Ok(abs(args[0].clone())) } - Ceil => { + SQLNumericExpr::Ceil => { ensure!(args.len() == 1, "ceil takes exactly one argument"); Ok(ceil(args[0].clone())) } - Floor => { + SQLNumericExpr::Floor => { ensure!(args.len() == 1, "floor takes exactly one argument"); Ok(floor(args[0].clone())) } - Sign => { + SQLNumericExpr::Sign => { ensure!(args.len() == 1, "sign takes exactly one argument"); Ok(sign(args[0].clone())) } - Round(_) => { + SQLNumericExpr::Round => { ensure!(args.len() == 2, "round takes exactly two arguments"); let precision = match args[1].as_ref().as_literal() { Some(LiteralValue::Int32(i)) => *i, @@ -84,63 +118,63 @@ fn to_expr(expr: &NumericExpr, args: &[ExprRef]) -> SQLPlannerResult { }; Ok(round(args[0].clone(), precision)) } - Sqrt => { + SQLNumericExpr::Sqrt => { ensure!(args.len() == 1, "sqrt takes exactly one argument"); Ok(sqrt(args[0].clone())) } - Sin => { + SQLNumericExpr::Sin => { ensure!(args.len() == 1, "sin takes exactly one argument"); Ok(sin(args[0].clone())) } - Cos => { + SQLNumericExpr::Cos => { ensure!(args.len() == 1, "cos takes exactly one argument"); Ok(cos(args[0].clone())) } - Tan => { + SQLNumericExpr::Tan => { ensure!(args.len() == 1, "tan takes exactly one argument"); Ok(tan(args[0].clone())) } - Cot => { + SQLNumericExpr::Cot => { ensure!(args.len() == 1, "cot takes exactly one argument"); Ok(cot(args[0].clone())) } - ArcSin => { + SQLNumericExpr::ArcSin => { ensure!(args.len() == 1, "asin takes exactly one argument"); Ok(arcsin(args[0].clone())) } - ArcCos => { + SQLNumericExpr::ArcCos => { ensure!(args.len() == 1, "acos takes exactly one argument"); Ok(arccos(args[0].clone())) } - ArcTan => { + SQLNumericExpr::ArcTan => { ensure!(args.len() == 1, "atan takes exactly one argument"); Ok(arctan(args[0].clone())) } - ArcTan2 => { + SQLNumericExpr::ArcTan2 => { ensure!(args.len() == 2, "atan2 takes exactly two arguments"); - Ok(arctan2(args[0].clone(), args[1].clone())) + Ok(atan2(args[0].clone(), args[1].clone())) } - Degrees => { + SQLNumericExpr::Degrees => { ensure!(args.len() == 1, "degrees takes exactly one argument"); Ok(degrees(args[0].clone())) } - Radians => { + SQLNumericExpr::Radians => { ensure!(args.len() == 1, "radians takes exactly one argument"); Ok(radians(args[0].clone())) } - Log2 => { + SQLNumericExpr::Log2 => { ensure!(args.len() == 1, "log2 takes exactly one argument"); Ok(log2(args[0].clone())) } - Log10 => { + SQLNumericExpr::Log10 => { ensure!(args.len() == 1, "log10 takes exactly one argument"); Ok(log10(args[0].clone())) } - Ln => { + SQLNumericExpr::Ln => { ensure!(args.len() == 1, "ln takes exactly one argument"); Ok(ln(args[0].clone())) } - Log(_) => { + SQLNumericExpr::Log => { ensure!(args.len() == 2, "log takes exactly two arguments"); let base = args[1] .as_literal() @@ -158,19 +192,19 @@ fn to_expr(expr: &NumericExpr, args: &[ExprRef]) -> SQLPlannerResult { Ok(log(args[0].clone(), base)) } - Exp => { + SQLNumericExpr::Exp => { ensure!(args.len() == 1, "exp takes exactly one argument"); Ok(exp(args[0].clone())) } - ArcTanh => { + SQLNumericExpr::ArcTanh => { ensure!(args.len() == 1, "atanh takes exactly one argument"); Ok(arctanh(args[0].clone())) } - ArcCosh => { + SQLNumericExpr::ArcCosh => { ensure!(args.len() == 1, "acosh takes exactly one argument"); Ok(arccosh(args[0].clone())) } - ArcSinh => { + SQLNumericExpr::ArcSinh => { ensure!(args.len() == 1, "asinh takes exactly one argument"); Ok(arcsinh(args[0].clone())) } diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index ff92b38b85..cf82f72743 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -4,12 +4,10 @@ use common_error::DaftResult; use daft_core::prelude::*; use daft_dsl::{ col, - functions::{ - numeric::{ceil, floor}, - utf8::{ilike, like}, - }, + functions::utf8::{ilike, like}, has_agg, lit, literals_to_series, null_lit, Expr, ExprRef, LiteralValue, Operator, }; +use daft_functions::numeric::{ceil::ceil, floor::floor}; use daft_plan::{LogicalPlanBuilder, LogicalPlanRef}; use sqlparser::{ ast::{