Skip to content

Commit

Permalink
[FEAT] fill_null expression (#2089)
Browse files Browse the repository at this point in the history
Closes #1904
  • Loading branch information
colin-ho authored Apr 11, 2024
1 parent 93fc6ca commit acb8203
Show file tree
Hide file tree
Showing 16 changed files with 200 additions and 8 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ class PyExpr:
def __ne__(self, other: PyExpr) -> PyExpr: ... # type: ignore[override]
def is_null(self) -> PyExpr: ...
def not_null(self) -> PyExpr: ...
def fill_null(self, fill_value: PyExpr) -> PyExpr: ...
def is_in(self, other: PyExpr) -> PyExpr: ...
def name(self) -> str: ...
def to_field(self, schema: PySchema) -> PyField: ...
Expand Down Expand Up @@ -1068,6 +1069,7 @@ class PySeries:
def if_else(self, other: PySeries, predicate: PySeries) -> PySeries: ...
def is_null(self) -> PySeries: ...
def not_null(self) -> PySeries: ...
def fill_null(self, fill_value: PySeries) -> PySeries: ...
def murmur3_32(self) -> PySeries: ...
def to_str_values(self) -> PySeries: ...
def _debug_bincode_serialize(self) -> bytes: ...
Expand Down
27 changes: 27 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,33 @@ def not_null(self) -> Expression:
expr = self._expr.not_null()
return Expression._from_pyexpr(expr)

def fill_null(self, fill_value: Any) -> Expression:
    """Fills null values in the Expression with the provided fill_value

    Example:
        >>> df = daft.from_pydict({"data": [1, None, 3]})
        >>> df = df.select(df["data"].fill_null(2))
        >>> df.collect()
        ╭───────╮
        │ data  │
        │ ---   │
        │ Int64 │
        ╞═══════╡
        │ 1     │
        ├╌╌╌╌╌╌╌┤
        │ 2     │
        ├╌╌╌╌╌╌╌┤
        │ 3     │
        ╰───────╯

    Args:
        fill_value: Replacement for null entries. The argument is passed through
            ``Expression._to_expression`` before use, so values that are not
            already an Expression (such as the literal ``2`` in the example
            above) are accepted as well.

    Returns:
        Expression: Expression with null values filled with the provided fill_value
    """
    # Normalise the argument first so the underlying PyExpr only ever receives
    # another PyExpr.
    fill_expr = Expression._to_expression(fill_value)
    filled = self._expr.fill_null(fill_expr._expr)
    return Expression._from_pyexpr(filled)

def is_in(self, other: Any) -> Expression:
"""Checks if values in the Expression are in the provided list
Expand Down
6 changes: 6 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,12 @@ def not_null(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.not_null())

def fill_null(self, fill_value: object) -> Series:
    """Return a new Series built by calling ``fill_null`` on the wrapped series.

    Args:
        fill_value: Must be another ``Series``; its wrapped series is handed
            to the underlying ``fill_null`` call.

    Raises:
        ValueError: If ``fill_value`` is not a ``Series``.
    """
    if isinstance(fill_value, Series):
        # Both wrappers must hold an underlying series before delegating.
        assert self._series is not None
        assert fill_value._series is not None
        filled = self._series.fill_null(fill_value._series)
        return Series._from_pyseries(filled)
    raise ValueError(f"expected another Series but got {type(fill_value)}")

def _to_str_values(self) -> Series:
    """Wrap the result of the underlying series' ``to_str_values()`` in a new Series."""
    str_values = self._series.to_str_values()
    return Series._from_pyseries(str_values)

Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,10 @@ impl PySeries {
Ok(self.series.not_null()?.into())
}

/// Python-facing wrapper around `Series::fill_null`.
///
/// Calls `fill_null` on the wrapped series with `fill_value`'s wrapped series,
/// propagates any error with `?`, and converts the result back into `Self`.
pub fn fill_null(&self, fill_value: &Self) -> PyResult<Self> {
    let filled = self.series.fill_null(&fill_value.series)?;
    Ok(filled.into())
}

pub fn _debug_bincode_serialize(&self, py: Python) -> PyResult<PyObject> {
let values = bincode::serialize(&self.series).unwrap();
Ok(PyBytes::new(py, &values).to_object(py))
Expand Down
5 changes: 5 additions & 0 deletions src/daft-core/src/series/ops/null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,9 @@ impl Series {
/// Delegates to the inner series implementation's `not_null` and returns its
/// result unchanged.
pub fn not_null(&self) -> DaftResult<Series> {
self.inner.not_null()
}

/// Fills the null entries of this series using `fill_value`.
///
/// Computes the `not_null` mask of `self` and passes it as the predicate to
/// `if_else`, with `self` as the receiver and `fill_value` as the alternative.
/// Errors from either step are returned to the caller.
pub fn fill_null(&self, fill_value: &Series) -> DaftResult<Series> {
    self.not_null()
        .and_then(|is_valid| self.if_else(fill_value, &is_valid))
}
}
37 changes: 31 additions & 6 deletions src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub enum Expr {
Not(ExprRef),
IsNull(ExprRef),
NotNull(ExprRef),
FillNull(ExprRef, ExprRef),
IsIn(ExprRef, ExprRef),
Literal(lit::LiteralValue),
IfElse {
Expand Down Expand Up @@ -279,6 +280,10 @@ impl Expr {
Expr::NotNull(self.clone().into())
}

pub fn fill_null(&self, fill_value: &Self) -> Self {
Expr::FillNull(self.clone().into(), fill_value.clone().into())
}

/// Builds an `Expr::IsIn` node from a clone of `self` and a clone of `items`.
pub fn is_in(&self, items: &Self) -> Self {
Expr::IsIn(self.clone().into(), items.clone().into())
}
Expand Down Expand Up @@ -342,6 +347,11 @@ impl Expr {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.not_null()"))
}
FillNull(expr, fill_value) => {
let child_id = expr.semantic_id(schema);
let fill_value_id = fill_value.semantic_id(schema);
FieldID::new(format!("{child_id}.fill_null({fill_value_id})"))
}
IsIn(expr, items) => {
let child_id = expr.semantic_id(schema);
let items_id = items.semantic_id(schema);
Expand Down Expand Up @@ -400,6 +410,7 @@ impl Expr {
} => {
vec![predicate.clone(), if_true.clone(), if_false.clone()]
}
FillNull(expr, fill_value) => vec![expr.clone(), fill_value.clone()],
}
}

Expand All @@ -421,6 +432,16 @@ impl Expr {
}
IsNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)),
NotNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)),
FillNull(expr, fill_value) => {
let expr_field = expr.to_field(schema)?;
let fill_value_field = fill_value.to_field(schema)?;
match try_get_supertype(&expr_field.dtype, &fill_value_field.dtype) {
Ok(supertype) => Ok(Field::new(expr_field.name.as_str(), supertype)),
Err(_) => Err(DaftError::TypeError(format!(
"Expected expr and fill_value arguments for fill_null to be castable to the same supertype, but received {expr_field} and {fill_value_field}",
)))
}
}
IsIn(left, right) => {
let left_field = left.to_field(schema)?;
let right_field = right.to_field(schema)?;
Expand Down Expand Up @@ -510,6 +531,7 @@ impl Expr {
Not(expr) => expr.name(),
IsNull(expr) => expr.name(),
NotNull(expr) => expr.name(),
FillNull(expr, ..) => expr.name(),
IsIn(expr, ..) => expr.name(),
Literal(..) => Ok("literal"),
Function { func, inputs } => match func {
Expand Down Expand Up @@ -598,12 +620,14 @@ impl Expr {
write!(buffer, " END")
}
// TODO: Implement SQL translations for these expressions if possible
Expr::Agg(..) | Expr::Cast(..) | Expr::IsIn(..) | Expr::Function { .. } => {
Err(io::Error::new(
io::ErrorKind::Other,
"Unsupported expression for SQL translation",
))
}
Expr::Agg(..)
| Expr::Cast(..)
| Expr::IsIn(..)
| Expr::Function { .. }
| Expr::FillNull(..) => Err(io::Error::new(
io::ErrorKind::Other,
"Unsupported expression for SQL translation",
)),
}
}

Expand Down Expand Up @@ -638,6 +662,7 @@ impl Display for Expr {
Not(expr) => write!(f, "not({expr})"),
IsNull(expr) => write!(f, "is_null({expr})"),
NotNull(expr) => write!(f, "not_null({expr})"),
FillNull(expr, fill_value) => write!(f, "fill_null({expr}, {fill_value})"),
IsIn(expr, items) => write!(f, "{expr} in {items}"),
Literal(val) => write!(f, "lit({val})"),
Function { func, inputs } => function_display(f, func, inputs),
Expand Down
1 change: 1 addition & 0 deletions src/daft-dsl/src/optimization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub fn requires_computation(e: &Expr) -> bool {
| Expr::Not(..)
| Expr::IsNull(..)
| Expr::NotNull(..)
| Expr::FillNull(..)
| Expr::IsIn { .. }
| Expr::IfElse { .. } => true,
}
Expand Down
4 changes: 4 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ impl PyExpr {
Ok(self.expr.not_null().into())
}

/// Python-facing wrapper: builds the `fill_null` expression from the wrapped
/// expressions of `self` and `fill_value`, then converts it back into `Self`.
pub fn fill_null(&self, fill_value: &Self) -> PyResult<Self> {
    let filled = self.expr.fill_null(&fill_value.expr);
    Ok(filled.into())
}

/// Python-facing wrapper: calls `is_in` on the wrapped expression with
/// `other`'s wrapped expression and converts the result into `Self`.
pub fn is_in(&self, other: &Self) -> PyResult<Self> {
Ok(self.expr.is_in(&other.expr).into())
}
Expand Down
5 changes: 5 additions & 0 deletions src/daft-dsl/src/treenode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ impl TreeNode for Expr {
}
BinaryOp { op: _, left, right } => vec![left.as_ref(), right.as_ref()],
IsIn(expr, items) => vec![expr.as_ref(), items.as_ref()],
FillNull(expr, fill_value) => vec![expr.as_ref(), fill_value.as_ref()],
Column(_) | Literal(_) => vec![],
Function { func: _, inputs } => inputs.iter().collect::<Vec<_>>(),
IfElse {
Expand Down Expand Up @@ -83,6 +84,10 @@ impl TreeNode for Expr {
Not(expr) => Not(transform(expr.as_ref().clone())?.into()),
IsNull(expr) => IsNull(transform(expr.as_ref().clone())?.into()),
NotNull(expr) => NotNull(transform(expr.as_ref().clone())?.into()),
FillNull(expr, fill_value) => FillNull(
transform(expr.as_ref().clone())?.into(),
transform(fill_value.as_ref().clone())?.into(),
),
IsIn(expr, items) => IsIn(
transform(expr.as_ref().clone())?.into(),
transform(items.as_ref().clone())?.into(),
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fn check_for_agg(expr: &Expr) -> bool {
Alias(e, _) | Cast(e, _) | Not(e) | IsNull(e) | NotNull(e) => check_for_agg(e),
BinaryOp { left, right, .. } => check_for_agg(left) || check_for_agg(right),
Function { inputs, .. } => inputs.iter().any(check_for_agg),
IsIn(l, r) => check_for_agg(l) || check_for_agg(r),
IsIn(l, r) | FillNull(l, r) => check_for_agg(l) || check_for_agg(r),
IfElse {
if_true,
if_false,
Expand Down
16 changes: 16 additions & 0 deletions src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,22 @@ fn replace_column_with_semantic_id(
|_| e,
)
}
Expr::FillNull(child, fill_value) => {
let child =
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema);
let fill_value = replace_column_with_semantic_id(
fill_value.clone(),
subexprs_to_replace,
schema,
);
if child.is_no() && fill_value.is_no() {
Transformed::No(e)
} else {
Transformed::Yes(
Expr::FillNull(child.unwrap().clone(), fill_value.unwrap().clone()).into(),
)
}
}
Expr::IsIn(child, items) => {
let child =
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema);
Expand Down
11 changes: 11 additions & 0 deletions src/daft-plan/src/physical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,17 @@ impl Project {
)?;
Ok(Expr::NotNull(newchild.into()))
}
Expr::FillNull(child, fill_value) => {
let newchild = Self::translate_clustering_spec_expr(
child.as_ref(),
old_colname_to_new_colname,
)?;
let newfill = Self::translate_clustering_spec_expr(
fill_value.as_ref(),
old_colname_to_new_colname,
)?;
Ok(Expr::FillNull(newchild.into(), newfill.into()))
}
Expr::IsIn(child, items) => {
let newchild = Self::translate_clustering_spec_expr(
child.as_ref(),
Expand Down
4 changes: 4 additions & 0 deletions src/daft-table/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,10 @@ impl Table {
Not(child) => !(self.eval_expression(child)?),
IsNull(child) => self.eval_expression(child)?.is_null(),
NotNull(child) => self.eval_expression(child)?.not_null(),
FillNull(child, fill_value) => {
let fill_value = self.eval_expression(fill_value)?;
self.eval_expression(child)?.fill_null(&fill_value)
}
IsIn(child, items) => self
.eval_expression(child)?
.is_in(&self.eval_expression(items)?),
Expand Down
15 changes: 14 additions & 1 deletion tests/expressions/typing/test_null.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from daft.expressions import col
from tests.expressions.typing.conftest import assert_typing_resolve_vs_runtime_behavior
from tests.expressions.typing.conftest import (
assert_typing_resolve_vs_runtime_behavior,
has_supertype,
)


def test_is_null(unary_data_fixture):
Expand All @@ -22,3 +25,13 @@ def test_not_null(unary_data_fixture):
run_kernel=lambda: arg.not_null(),
resolvable=True,
)


def test_fill_null(binary_data_fixture):
    """Check that fill_null type resolution agrees with the runtime kernel.

    The expression is expected to resolve exactly when the two fixture series
    have a common supertype.
    """
    lhs, rhs = binary_data_fixture
    fill_expr = col(lhs.name()).fill_null(col(rhs.name()))
    expect_resolvable = has_supertype(lhs.datatype(), rhs.datatype())

    def run_kernel():
        return lhs.fill_null(rhs)

    assert_typing_resolve_vs_runtime_behavior(
        data=binary_data_fixture,
        expr=fill_expr,
        run_kernel=run_kernel,
        resolvable=expect_resolvable,
    )
32 changes: 32 additions & 0 deletions tests/series/test_fill_null.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

import pyarrow as pa
import pytest

from daft.series import Series


@pytest.mark.parametrize(
    "input,fill_value,expected",
    [
        # No broadcast
        ([1, 2, None], [3, 3, 3], [1, 2, 3]),
        # Broadcast input
        ([None], [3, 3, 3], [3, 3, 3]),
        # Broadcast fill_value
        ([1, 2, None], [3], [1, 2, 3]),
        # Empty
        ([], [], []),
    ],
)
def test_series_fill_null(input, fill_value, expected) -> None:
    """fill_null on int64 Series, covering equal lengths, broadcasting on either side, and empty input."""
    int64 = pa.int64()
    series = Series.from_arrow(pa.array(input, int64))
    replacement = Series.from_arrow(pa.array(fill_value, int64))
    assert series.fill_null(replacement).to_pylist() == expected


def test_series_fill_null_bad_input() -> None:
    """Passing something other than a Series as fill_value raises ValueError."""
    series = Series.from_arrow(pa.array([1, 2, 3], pa.int64()))
    not_a_series = [1, 2, 3]
    with pytest.raises(ValueError, match="expected another Series but got"):
        series.fill_null(not_a_series)
37 changes: 37 additions & 0 deletions tests/table/test_fill_null.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

import datetime

import pytest

from daft.expressions.expressions import col
from daft.table.micropartition import MicroPartition


@pytest.mark.parametrize(
    "input,fill_value,expected",
    [
        pytest.param([None, None, None], "a", ["a", "a", "a"], id="NullColumn"),
        pytest.param([True, False, None], False, [True, False, False], id="BoolColumn"),
        pytest.param(["a", "b", None], "b", ["a", "b", "b"], id="StringColumn"),
        pytest.param([b"a", None, b"c"], b"b", [b"a", b"b", b"c"], id="BinaryColumn"),
        pytest.param([-1, None, 3], 0, [-1, 0, 3], id="IntColumn"),
        pytest.param([-1.0, None, 3.0], 0.0, [-1.0, 0.0, 3.0], id="FloatColumn"),
        # Fixed dates only: calling date.today() separately for the input and the
        # expected value makes the test data depend on the wall clock.
        pytest.param(
            [datetime.date(2021, 1, 1), None, datetime.date(2023, 1, 1)],
            datetime.date(2022, 1, 1),
            [datetime.date(2021, 1, 1), datetime.date(2022, 1, 1), datetime.date(2023, 1, 1)],
            id="DateColumn",
        ),
        pytest.param(
            [datetime.datetime(2022, 1, 1), None, datetime.datetime(2023, 1, 1)],
            datetime.datetime(2022, 1, 1),
            [datetime.datetime(2022, 1, 1), datetime.datetime(2022, 1, 1), datetime.datetime(2023, 1, 1)],
            id="TimestampColumn",
        ),
    ],
)
def test_table_expr_fill_null(input, fill_value, expected) -> None:
    """Evaluate col("input").fill_null(fill_value) on a MicroPartition and compare the column to ``expected``."""
    daft_table = MicroPartition.from_pydict({"input": input})
    daft_table = daft_table.eval_expression_list([col("input").fill_null(fill_value)])
    pydict = daft_table.to_pydict()

    assert pydict["input"] == expected

0 comments on commit acb8203

Please sign in to comment.