Skip to content

Commit

Permalink
[FEAT]: sin/cos/tan expression implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Apr 13, 2024
1 parent 146037a commit c6d0c95
Show file tree
Hide file tree
Showing 15 changed files with 280 additions and 0 deletions.
6 changes: 6 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,9 @@ class PyExpr:
def floor(self) -> PyExpr: ...
def sign(self) -> PyExpr: ...
def round(self, decimal: int) -> PyExpr: ...
def sin(self) -> PyExpr: ...
def cos(self) -> PyExpr: ...
def tan(self) -> PyExpr: ...
def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ...
def count(self, mode: CountMode) -> PyExpr: ...
def sum(self) -> PyExpr: ...
Expand Down Expand Up @@ -1026,6 +1029,9 @@ class PySeries:
def floor(self) -> PySeries: ...
def sign(self) -> PySeries: ...
def round(self, decimal: int) -> PySeries: ...
def sin(self) -> PySeries: ...
def cos(self) -> PySeries: ...
def tan(self) -> PySeries: ...
@staticmethod
def concat(series: list[PySeries]) -> PySeries: ...
def __len__(self) -> int: ...
Expand Down
15 changes: 15 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,21 @@ def round(self, decimals: int = 0) -> Expression:
expr = self._expr.round(decimals)
return Expression._from_pyexpr(expr)

def sin(self) -> Expression:
"""The elementwise sine of a numeric expression (``expr.sin()``)"""
expr = self._expr.sin()
return Expression._from_pyexpr(expr)

def cos(self) -> Expression:
"""The elementwise cosine of a numeric expression (``expr.cos()``)"""
expr = self._expr.cos()
return Expression._from_pyexpr(expr)

def tan(self) -> Expression:
"""The elementwise tangent of a numeric expression (``expr.tan()``)"""
expr = self._expr.tan()
return Expression._from_pyexpr(expr)

def count(self, mode: CountMode = CountMode.Valid) -> Expression:
"""Counts the number of values in the expression.
Expand Down
12 changes: 12 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,18 @@ def sign(self) -> Series:
def round(self, decimal: int) -> Series:
return Series._from_pyseries(self._series.round(decimal))

def sin(self) -> Series:
"""The elementwise sine of a numeric series."""
return Series._from_pyseries(self._series.sin())

def cos(self) -> Series:
"""The elementwise cosine of a numeric series."""
return Series._from_pyseries(self._series.cos())

def tan(self) -> Series:
"""The elementwise tangent of a numeric series."""
return Series._from_pyseries(self._series.tan())

def __add__(self, other: object) -> Series:
if not isinstance(other, Series):
raise TypeError(f"expected another Series but got {type(other)}")
Expand Down
3 changes: 3 additions & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ Numeric
Expression.floor
Expression.sign
Expression.round
Expression.sin
Expression.cos
Expression.tan

.. _api-comparison-expression:

Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ mod struct_;
mod sum;
mod take;
pub(crate) mod tensor;
pub mod trigonometry;
mod truncate;
mod utf8;

Expand Down
40 changes: 40 additions & 0 deletions src/daft-core/src/array/ops/trigonometry.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use num_traits::Float;
use serde::{Deserialize, Serialize};

use common_error::DaftResult;

use crate::array::DataArray;
use crate::datatypes::DaftFloatType;

#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum TrigonometricFunction {
Sin,
Cos,
Tan,
}

impl TrigonometricFunction {
pub fn fn_name(&self) -> &'static str {
use TrigonometricFunction::*;
match self {
Sin => "sin",
Cos => "cos",
Tan => "tan",
}
}
}

impl<T> DataArray<T>
where
T: DaftFloatType,
T::Native: Float,
{
pub fn trigonometry(&self, func: &TrigonometricFunction) -> DaftResult<Self> {
use TrigonometricFunction::*;
match func {
Sin => self.apply(|v| v.sin()),
Cos => self.apply(|v| v.cos()),
Tan => self.apply(|v| v.tan()),
}
}
}
22 changes: 22 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::{

use super::datatype::PyDataType;
use crate::array::ops::as_arrow::AsArrow;
use crate::array::ops::trigonometry::TrigonometricFunction;

#[pyclass]
#[derive(Clone)]
Expand Down Expand Up @@ -129,6 +130,27 @@ impl PySeries {
Ok(self.series.round(decimal)?.into())
}

pub fn sin(&self) -> PyResult<Self> {
Ok(self
.series
.trigonometry(&TrigonometricFunction::Sin)?
.into())
}

pub fn cos(&self) -> PyResult<Self> {
Ok(self
.series
.trigonometry(&TrigonometricFunction::Cos)?
.into())
}

pub fn tan(&self) -> PyResult<Self> {
Ok(self
.series
.trigonometry(&TrigonometricFunction::Tan)?
.into())
}

pub fn take(&self, idx: &Self) -> PyResult<Self> {
Ok(self.series.take(&idx.series)?.into())
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub mod sign;
pub mod sort;
pub mod struct_;
pub mod take;
mod trigonometry;
pub mod utf8;

fn match_types_on_series(l: &Series, r: &Series) -> DaftResult<(Series, Series)> {
Expand Down
30 changes: 30 additions & 0 deletions src/daft-core/src/series/ops/trigonometry.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use crate::array::ops::trigonometry::TrigonometricFunction;
use crate::datatypes::DataType;
use crate::series::Series;
use crate::IntoSeries;
use common_error::DaftError;
use common_error::DaftResult;

impl Series {
pub fn trigonometry(&self, trig_function: &TrigonometricFunction) -> DaftResult<Self> {
match self.data_type() {
DataType::Float32 => {
let ca = self.f32().unwrap();
Ok(ca.trigonometry(trig_function)?.into_series())
}
DataType::Float64 => {
let ca = self.f64().unwrap();
Ok(ca.trigonometry(trig_function)?.into_series())
}
dt if dt.is_numeric() => {
let s = self.cast(&DataType::Float64)?;
let ca = s.f64().unwrap();
Ok(ca.trigonometry(trig_function)?.into_series())
}
dt => Err(DaftError::TypeError(format!(
"Expected input to trigonometry to be numeric, got {}",
dt
))),
}
}
}
29 changes: 29 additions & 0 deletions src/daft-dsl/src/functions/numeric/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod ceil;
mod floor;
mod round;
mod sign;
mod trigonometry;

use abs::AbsEvaluator;
use ceil::CeilEvaluator;
Expand All @@ -12,6 +13,7 @@ use sign::SignEvaluator;

use serde::{Deserialize, Serialize};

use crate::functions::numeric::trigonometry::{TrigonometricFunction, TrigonometryEvaluator};
use crate::Expr;

use super::FunctionEvaluator;
Expand All @@ -23,6 +25,9 @@ pub enum NumericExpr {
Floor,
Sign,
Round(i32),
Sin,
Cos,
Tan,
}

impl NumericExpr {
Expand All @@ -35,6 +40,9 @@ impl NumericExpr {
Floor => &FloorEvaluator {},
Sign => &SignEvaluator {},
Round(_) => &RoundEvaluator {},
Sin => &TrigonometryEvaluator(TrigonometricFunction::Sin),
Cos => &TrigonometryEvaluator(TrigonometricFunction::Cos),
Tan => &TrigonometryEvaluator(TrigonometricFunction::Tan),
}
}
}
Expand Down Expand Up @@ -73,3 +81,24 @@ pub fn round(input: &Expr, decimal: i32) -> Expr {
inputs: vec![input.clone()],
}
}

pub fn sin(input: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::Numeric(NumericExpr::Sin),
inputs: vec![input.clone()],
}
}

pub fn cos(input: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::Numeric(NumericExpr::Cos),
inputs: vec![input.clone()],
}
}

pub fn tan(input: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::Numeric(NumericExpr::Tan),
inputs: vec![input.clone()],
}
}
47 changes: 47 additions & 0 deletions src/daft-dsl/src/functions/numeric/trigonometry.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use common_error::{DaftError, DaftResult};
pub use daft_core::array::ops::trigonometry::TrigonometricFunction;
use daft_core::datatypes::Field;
use daft_core::schema::Schema;
use daft_core::{DataType, Series};

use crate::functions::FunctionEvaluator;
use crate::Expr;

pub(super) struct TrigonometryEvaluator(pub TrigonometricFunction);

impl FunctionEvaluator for TrigonometryEvaluator {
fn fn_name(&self) -> &'static str {
self.0.fn_name()
}

fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
if inputs.len() != 1 {
return Err(DaftError::SchemaMismatch(format!(
"Expected 1 input arg, got {}",
inputs.len()
)));
};
let field = inputs.first().unwrap().to_field(schema)?;
let dtype = match field.dtype {
DataType::Float32 => DataType::Float32,
dt if dt.is_numeric() => DataType::Float64,
_ => {
return Err(DaftError::TypeError(format!(
"Expected input to trigonometry to be numeric, got {}",
field.dtype
)))
}
};
Ok(Field::new(field.name, dtype))
}

fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
if inputs.len() != 1 {
return Err(DaftError::SchemaMismatch(format!(
"Expected 1 input arg, got {}",
inputs.len()
)));
}
inputs.first().unwrap().trigonometry(&self.0)
}
}
15 changes: 15 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,21 @@ impl PyExpr {
Ok(round(&self.expr, decimal).into())
}

pub fn sin(&self) -> PyResult<Self> {
use functions::numeric::sin;
Ok(sin(&self.expr).into())
}

pub fn cos(&self) -> PyResult<Self> {
use functions::numeric::cos;
Ok(cos(&self.expr).into())
}

pub fn tan(&self) -> PyResult<Self> {
use functions::numeric::tan;
Ok(tan(&self.expr).into())
}

pub fn if_else(&self, if_true: &Self, if_false: &Self) -> PyResult<Self> {
Ok(self.expr.if_else(&if_true.expr, &if_false.expr).into())
}
Expand Down
17 changes: 17 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ def test_repr_functions_round() -> None:
assert repr_out == repr(copied)


@pytest.mark.parametrize(
"fun",
[
"sin",
"cos",
"tan",
],
)
def test_repr_functions_trigonometry(fun: str) -> None:
a = col("a")
y = getattr(a, fun)()
repr_out = repr(y)
assert repr_out == f"{fun}(col(a))"
copied = copy.deepcopy(y)
assert repr_out == repr(copied)


def test_repr_functions_day() -> None:
a = col("a")
y = a.dt.day()
Expand Down
18 changes: 18 additions & 0 deletions tests/expressions/typing/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,21 @@ def test_round(unary_data_fixture):
run_kernel=lambda: arg.round(0),
resolvable=is_numeric(arg.datatype()),
)


@pytest.mark.parametrize(
"fun",
[
"sin",
"cos",
"tan",
],
)
def test_trigonometry(fun: str, unary_data_fixture):
arg = unary_data_fixture
assert_typing_resolve_vs_runtime_behavior(
data=(unary_data_fixture,),
expr=getattr(col(arg.name()), fun)(),
run_kernel=lambda: getattr(arg, fun)(),
resolvable=is_numeric(arg.datatype()),
)
Loading

0 comments on commit c6d0c95

Please sign in to comment.