Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT]: sql float operations #2834

Merged
merged 12 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1106,10 +1106,6 @@ class PyExpr:
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
def __reduce__(self) -> tuple: ...
def is_nan(self) -> PyExpr: ...
def is_inf(self) -> PyExpr: ...
def not_nan(self) -> PyExpr: ...
def fill_nan(self, fill_value: PyExpr) -> PyExpr: ...
def dt_date(self) -> PyExpr: ...
def dt_day(self) -> PyExpr: ...
def dt_hour(self) -> PyExpr: ...
Expand Down Expand Up @@ -1255,6 +1251,14 @@ def image_encode(expr: PyExpr, image_format: ImageFormat) -> PyExpr: ...
def image_resize(expr: PyExpr, w: int, h: int) -> PyExpr: ...
def image_to_mode(expr: PyExpr, mode: ImageMode) -> PyExpr: ...

# ---
# expr.float namespace
# ---
def is_nan(expr: PyExpr) -> PyExpr: ...
def is_inf(expr: PyExpr) -> PyExpr: ...
def not_nan(expr: PyExpr) -> PyExpr: ...
def fill_nan(expr: PyExpr, fill_value: PyExpr) -> PyExpr: ...

# ---
# expr.json namespace
# ---
Expand Down
8 changes: 4 additions & 4 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,7 @@ def is_nan(self) -> Expression:
Returns:
Expression: Boolean Expression indicating whether values are invalid.
"""
return Expression._from_pyexpr(self._expr.is_nan())
return Expression._from_pyexpr(native.is_nan(self._expr))

def is_inf(self) -> Expression:
"""Checks if values in the Expression are Infinity.
Expand Down Expand Up @@ -1394,7 +1394,7 @@ def is_inf(self) -> Expression:
Returns:
Expression: Boolean Expression indicating whether values are Infinity.
"""
return Expression._from_pyexpr(self._expr.is_inf())
return Expression._from_pyexpr(native.is_inf(self._expr))

def not_nan(self) -> Expression:
"""Checks if values are not NaN (a special float value indicating not-a-number)
Expand Down Expand Up @@ -1424,7 +1424,7 @@ def not_nan(self) -> Expression:
Returns:
Expression: Boolean Expression indicating whether values are not invalid.
"""
return Expression._from_pyexpr(self._expr.not_nan())
return Expression._from_pyexpr(native.not_nan(self._expr))

def fill_nan(self, fill_value: Expression) -> Expression:
"""Fills NaN values in the Expression with the provided fill_value
Expand Down Expand Up @@ -1453,7 +1453,7 @@ def fill_nan(self, fill_value: Expression) -> Expression:
"""

fill_value = Expression._to_expression(fill_value)
expr = self._expr.fill_nan(fill_value._expr)
expr = native.fill_nan(self._expr, fill_value._expr)
return Expression._from_pyexpr(expr)


Expand Down
46 changes: 0 additions & 46 deletions src/daft-dsl/src/functions/float/fill_nan.rs

This file was deleted.

67 changes: 0 additions & 67 deletions src/daft-dsl/src/functions/float/mod.rs

This file was deleted.

47 changes: 0 additions & 47 deletions src/daft-dsl/src/functions/float/not_nan.rs

This file was deleted.

4 changes: 0 additions & 4 deletions src/daft-dsl/src/functions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
pub mod float;
pub mod list;
pub mod map;
pub mod numeric;
Expand All @@ -15,7 +14,6 @@ use std::hash::Hash;

use crate::{Expr, ExprRef, Operator};

use self::float::FloatExpr;
use self::list::ListExpr;
use self::map::MapExpr;
use self::numeric::NumericExpr;
Expand All @@ -37,7 +35,6 @@ use python::PythonUDF;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum FunctionExpr {
Numeric(NumericExpr),
Float(FloatExpr),
Utf8(Utf8Expr),
Temporal(TemporalExpr),
List(ListExpr),
Expand Down Expand Up @@ -65,7 +62,6 @@ impl FunctionExpr {
use FunctionExpr::*;
match self {
Numeric(expr) => expr.get_evaluator(),
Float(expr) => expr.get_evaluator(),
Utf8(expr) => expr.get_evaluator(),
Temporal(expr) => expr.get_evaluator(),
List(expr) => expr.get_evaluator(),
Expand Down
20 changes: 0 additions & 20 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,26 +579,6 @@ impl PyExpr {
hasher.finish()
}

pub fn is_nan(&self) -> PyResult<Self> {
use functions::float::is_nan;
Ok(is_nan(self.into()).into())
}

pub fn is_inf(&self) -> PyResult<Self> {
use functions::float::is_inf;
Ok(is_inf(self.into()).into())
}

pub fn not_nan(&self) -> PyResult<Self> {
use functions::float::not_nan;
Ok(not_nan(self.into()).into())
}

pub fn fill_nan(&self, fill_value: &Self) -> PyResult<Self> {
use functions::float::fill_nan;
Ok(fill_nan(self.into(), fill_value.expr.clone()).into())
}

pub fn dt_date(&self) -> PyResult<Self> {
use functions::temporal::date;
Ok(date(self.into()).into())
Expand Down
70 changes: 70 additions & 0 deletions src/daft-functions/src/float/fill_nan.rs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we not have a scalar_udf macro for this? If we do, maybe we could utilize that to avoid all of this boilerplate...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly for the other functions inside of daft-functions/src/float as well.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to come up with a macro for this, but I think the consensus was that it wasn't intuitive & could make things harder to debug later on. So I just stuck with copy and paste.

Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use common_error::{DaftError, DaftResult};
use daft_core::{
prelude::{Field, Schema},
series::Series,
utils::supertype::try_get_supertype,
};
use daft_dsl::{
functions::{ScalarFunction, ScalarUDF},
ExprRef,
};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct FillNan {}

#[typetag::serde]
impl ScalarUDF for FillNan {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn name(&self) -> &'static str {
"fill_nan"
}

fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
match inputs {
[data, fill_value] => match (data.to_field(schema), fill_value.to_field(schema)) {
(Ok(data_field), Ok(fill_value_field)) => {
match (&data_field.dtype.is_floating(), &fill_value_field.dtype.is_floating(), try_get_supertype(&data_field.dtype, &fill_value_field.dtype)) {
(true, true, Ok(dtype)) => Ok(Field::new(data_field.name, dtype)),
_ => Err(DaftError::TypeError(format!(
"Expects input for fill_nan to be float, but received {data_field} and {fill_value_field}",
))),
}
}
(Err(e), _) | (_, Err(e)) => Err(e),
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 2 input args, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
match inputs {
[data, fill_value] => data.fill_nan(fill_value),
_ => Err(DaftError::ValueError(format!(
"Expected 2 input args, got {}",
inputs.len()
))),
}
}
}

pub fn fill_nan(input: ExprRef, fill_value: ExprRef) -> ExprRef {
ScalarFunction::new(FillNan {}, vec![input, fill_value]).into()
}

#[cfg(feature = "python")]
use {
daft_dsl::python::PyExpr,
pyo3::{pyfunction, PyResult},
};
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "fill_nan")]
pub fn py_fill_nan(expr: PyExpr, fill_value: PyExpr) -> PyResult<PyExpr> {
Ok(fill_nan(expr.into(), fill_value.into()).into())
}
Loading
Loading