diff --git a/daft/daft.pyi b/daft/daft.pyi index 5bb906f67d..e0f77e5db1 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -891,6 +891,9 @@ class PyExpr: def floor(self) -> PyExpr: ... def sign(self) -> PyExpr: ... def round(self, decimal: int) -> PyExpr: ... + def sin(self) -> PyExpr: ... + def cos(self) -> PyExpr: ... + def tan(self) -> PyExpr: ... def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ... def count(self, mode: CountMode) -> PyExpr: ... def sum(self) -> PyExpr: ... @@ -1026,6 +1029,9 @@ class PySeries: def floor(self) -> PySeries: ... def sign(self) -> PySeries: ... def round(self, decimal: int) -> PySeries: ... + def sin(self) -> PySeries: ... + def cos(self) -> PySeries: ... + def tan(self) -> PySeries: ... @staticmethod def concat(series: list[PySeries]) -> PySeries: ... def __len__(self) -> int: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index fa377b4a00..c0c2843fdd 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -357,6 +357,21 @@ def round(self, decimals: int = 0) -> Expression: expr = self._expr.round(decimals) return Expression._from_pyexpr(expr) + def sin(self) -> Expression: + """The elementwise sine of a numeric expression (``expr.sin()``)""" + expr = self._expr.sin() + return Expression._from_pyexpr(expr) + + def cos(self) -> Expression: + """The elementwise cosine of a numeric expression (``expr.cos()``)""" + expr = self._expr.cos() + return Expression._from_pyexpr(expr) + + def tan(self) -> Expression: + """The elementwise tangent of a numeric expression (``expr.tan()``)""" + expr = self._expr.tan() + return Expression._from_pyexpr(expr) + def count(self, mode: CountMode = CountMode.Valid) -> Expression: """Counts the number of values in the expression. diff --git a/daft/series.py b/daft/series.py index b04b73f0d3..797267ca68 100644 --- a/daft/series.py +++ b/daft/series.py @@ -367,6 +367,18 @@ def sign(self) -> Series: def round(self, decimal: int) -> Series: return Series._from_pyseries(self._series.round(decimal)) + def sin(self) -> Series: + """The elementwise sine of a numeric series.""" + return Series._from_pyseries(self._series.sin()) + + def cos(self) -> Series: + """The elementwise cosine of a numeric series.""" + return Series._from_pyseries(self._series.cos()) + + def tan(self) -> Series: + """The elementwise tangent of a numeric series.""" + return Series._from_pyseries(self._series.tan()) + def __add__(self, other: object) -> Series: if not isinstance(other, Series): raise TypeError(f"expected another Series but got {type(other)}") diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index d5cc75dd91..ba7a0fc62e 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -50,6 +50,9 @@ Numeric Expression.floor Expression.sign Expression.round + Expression.sin + Expression.cos + Expression.tan .. _api-comparison-expression: diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index ef5e88c628..b2bd1f05fa 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -40,6 +40,7 @@ mod struct_; mod sum; mod take; pub(crate) mod tensor; +pub mod trigonometry; mod truncate; mod utf8; diff --git a/src/daft-core/src/array/ops/trigonometry.rs b/src/daft-core/src/array/ops/trigonometry.rs new file mode 100644 index 0000000000..5e2772447d --- /dev/null +++ b/src/daft-core/src/array/ops/trigonometry.rs @@ -0,0 +1,40 @@ +use num_traits::Float; +use serde::{Deserialize, Serialize}; + +use common_error::DaftResult; + +use crate::array::DataArray; +use crate::datatypes::DaftFloatType; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub enum TrigonometricFunction { + Sin, + Cos, + Tan, +} + +impl TrigonometricFunction { + pub fn fn_name(&self) -> &'static str { + use TrigonometricFunction::*; + match self { + Sin => "sin", + Cos => "cos", + Tan => "tan", + } + } +} + +impl DataArray +where + T: DaftFloatType, + T::Native: Float, +{ + pub fn trigonometry(&self, func: &TrigonometricFunction) -> DaftResult { + use TrigonometricFunction::*; + match func { + Sin => self.apply(|v| v.sin()), + Cos => self.apply(|v| v.cos()), + Tan => self.apply(|v| v.tan()), + } + } +} diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 495f160905..153cf60d51 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -18,6 +18,7 @@ use crate::{ use super::datatype::PyDataType; use crate::array::ops::as_arrow::AsArrow; +use crate::array::ops::trigonometry::TrigonometricFunction; #[pyclass] #[derive(Clone)] @@ -129,6 +130,27 @@ impl PySeries { Ok(self.series.round(decimal)?.into()) } + pub fn sin(&self) -> PyResult { + Ok(self + .series + .trigonometry(&TrigonometricFunction::Sin)? + .into()) + } + + pub fn cos(&self) -> PyResult { + Ok(self + .series + .trigonometry(&TrigonometricFunction::Cos)? + .into()) + } + + pub fn tan(&self) -> PyResult { + Ok(self + .series + .trigonometry(&TrigonometricFunction::Tan)? + .into()) + } + pub fn take(&self, idx: &Self) -> PyResult { Ok(self.series.take(&idx.series)?.into()) } diff --git a/src/daft-core/src/series/ops/mod.rs b/src/daft-core/src/series/ops/mod.rs index 5793c4ed46..17db504780 100644 --- a/src/daft-core/src/series/ops/mod.rs +++ b/src/daft-core/src/series/ops/mod.rs @@ -33,6 +33,7 @@ pub mod sign; pub mod sort; pub mod struct_; pub mod take; +mod trigonometry; pub mod utf8; fn match_types_on_series(l: &Series, r: &Series) -> DaftResult<(Series, Series)> { diff --git a/src/daft-core/src/series/ops/trigonometry.rs b/src/daft-core/src/series/ops/trigonometry.rs new file mode 100644 index 0000000000..c37bb088f6 --- /dev/null +++ b/src/daft-core/src/series/ops/trigonometry.rs @@ -0,0 +1,30 @@ +use crate::array::ops::trigonometry::TrigonometricFunction; +use crate::datatypes::DataType; +use crate::series::Series; +use crate::IntoSeries; +use common_error::DaftError; +use common_error::DaftResult; + +impl Series { + pub fn trigonometry(&self, trig_function: &TrigonometricFunction) -> DaftResult { + match self.data_type() { + DataType::Float32 => { + let ca = self.f32().unwrap(); + Ok(ca.trigonometry(trig_function)?.into_series()) + } + DataType::Float64 => { + let ca = self.f64().unwrap(); + Ok(ca.trigonometry(trig_function)?.into_series()) + } + dt if dt.is_numeric() => { + let s = self.cast(&DataType::Float64)?; + let ca = s.f64().unwrap(); + Ok(ca.trigonometry(trig_function)?.into_series()) + } + dt => Err(DaftError::TypeError(format!( + "Expected input to trigonometry to be numeric, got {}", + dt + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/numeric/mod.rs b/src/daft-dsl/src/functions/numeric/mod.rs index 4168c79b7a..49fdb2276c 100644 --- a/src/daft-dsl/src/functions/numeric/mod.rs +++ b/src/daft-dsl/src/functions/numeric/mod.rs @@ -3,6 +3,7 @@ mod ceil; mod floor; mod round; mod sign; +mod trigonometry; use abs::AbsEvaluator; use ceil::CeilEvaluator; @@ -12,6 +13,7 @@ use sign::SignEvaluator; use serde::{Deserialize, Serialize}; +use crate::functions::numeric::trigonometry::{TrigonometricFunction, TrigonometryEvaluator}; use crate::Expr; use super::FunctionEvaluator; @@ -23,6 +25,9 @@ pub enum NumericExpr { Floor, Sign, Round(i32), + Sin, + Cos, + Tan, } impl NumericExpr { @@ -35,6 +40,9 @@ impl NumericExpr { Floor => &FloorEvaluator {}, Sign => &SignEvaluator {}, Round(_) => &RoundEvaluator {}, + Sin => &TrigonometryEvaluator(TrigonometricFunction::Sin), + Cos => &TrigonometryEvaluator(TrigonometricFunction::Cos), + Tan => &TrigonometryEvaluator(TrigonometricFunction::Tan), } } } @@ -73,3 +81,24 @@ pub fn round(input: &Expr, decimal: i32) -> Expr { inputs: vec![input.clone()], } } + +pub fn sin(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Numeric(NumericExpr::Sin), + inputs: vec![input.clone()], + } +} + +pub fn cos(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Numeric(NumericExpr::Cos), + inputs: vec![input.clone()], + } +} + +pub fn tan(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Numeric(NumericExpr::Tan), + inputs: vec![input.clone()], + } +} diff --git a/src/daft-dsl/src/functions/numeric/trigonometry.rs b/src/daft-dsl/src/functions/numeric/trigonometry.rs new file mode 100644 index 0000000000..de77684b35 --- /dev/null +++ b/src/daft-dsl/src/functions/numeric/trigonometry.rs @@ -0,0 +1,47 @@ +use common_error::{DaftError, DaftResult}; +pub use daft_core::array::ops::trigonometry::TrigonometricFunction; +use daft_core::datatypes::Field; +use daft_core::schema::Schema; +use daft_core::{DataType, Series}; + +use crate::functions::FunctionEvaluator; +use crate::Expr; + +pub(super) struct TrigonometryEvaluator(pub TrigonometricFunction); + +impl FunctionEvaluator for TrigonometryEvaluator { + fn fn_name(&self) -> &'static str { + self.0.fn_name() + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + let field = inputs.first().unwrap().to_field(schema)?; + let dtype = match field.dtype { + DataType::Float32 => DataType::Float32, + dt if dt.is_numeric() => DataType::Float64, + _ => { + return Err(DaftError::TypeError(format!( + "Expected input to trigonometry to be numeric, got {}", + field.dtype + ))) + } + }; + Ok(Field::new(field.name, dtype)) + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + inputs.first().unwrap().trigonometry(&self.0) + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index ca22c7f5a9..58fce904e3 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -225,6 +225,21 @@ impl PyExpr { Ok(round(&self.expr, decimal).into()) } + pub fn sin(&self) -> PyResult { + use functions::numeric::sin; + Ok(sin(&self.expr).into()) + } + + pub fn cos(&self) -> PyResult { + use functions::numeric::cos; + Ok(cos(&self.expr).into()) + } + + pub fn tan(&self) -> PyResult { + use functions::numeric::tan; + Ok(tan(&self.expr).into()) + } + pub fn if_else(&self, if_true: &Self, if_false: &Self) -> PyResult { Ok(self.expr.if_else(&if_true.expr, &if_false.expr).into()) } diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 2e7c6f7847..d82ebef8e2 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -123,6 +123,23 @@ def test_repr_functions_round() -> None: assert repr_out == repr(copied) +@pytest.mark.parametrize( + "fun", + [ + "sin", + "cos", + "tan", + ], +) +def test_repr_functions_trigonometry(fun: str) -> None: + a = col("a") + y = getattr(a, fun)() + repr_out = repr(y) + assert repr_out == f"{fun}(col(a))" + copied = copy.deepcopy(y) + assert repr_out == repr(copied) + + def test_repr_functions_day() -> None: a = col("a") y = a.dt.day() diff --git a/tests/expressions/typing/test_arithmetic.py b/tests/expressions/typing/test_arithmetic.py index 97afcce7d1..5c4666e3df 100644 --- a/tests/expressions/typing/test_arithmetic.py +++ b/tests/expressions/typing/test_arithmetic.py @@ -112,3 +112,21 @@ def test_round(unary_data_fixture): run_kernel=lambda: arg.round(0), resolvable=is_numeric(arg.datatype()), ) + + +@pytest.mark.parametrize( + "fun", + [ + "sin", + "cos", + "tan", + ], +) +def test_trigonometry(fun: str, unary_data_fixture): + arg = unary_data_fixture + assert_typing_resolve_vs_runtime_behavior( + data=(unary_data_fixture,), + expr=getattr(col(arg.name()), fun)(), + run_kernel=lambda: getattr(arg, fun)(), + resolvable=is_numeric(arg.datatype()), + ) diff --git a/tests/table/test_eval.py b/tests/table/test_eval.py index d6dea6f478..f5986c8f40 100644 --- a/tests/table/test_eval.py +++ b/tests/table/test_eval.py @@ -4,6 +4,7 @@ import math import operator as ops +import numpy as np import pyarrow as pa import pytest @@ -239,6 +240,29 @@ def test_table_sign_bad_input() -> None: table.eval_expression_list([col("a").sign()]) +@pytest.mark.parametrize( + "fun", + [ + "sin", + "cos", + "tan", + ], +) +def test_table_numeric_trigonometry(fun: str) -> None: + table = MicroPartition.from_pydict({"a": [0.0, math.pi, math.pi / 2, math.nan]}) + s = table.to_pandas()["a"] + np_result = getattr(np, fun)(s) + + trigonometry_table = table.eval_expression_list([getattr(col("a"), fun)()]) + assert ( + all( + x == y or (math.isnan(x) and math.isnan(y)) + for x, y in zip(trigonometry_table.get_column("a").to_pylist(), np_result.to_list()) + ) + is True + ) + + def test_table_numeric_round() -> None: from decimal import ROUND_HALF_UP, Decimal