Skip to content

Commit

Permalink
[FEAT] Support zero_lit in Rust/Python side
Browse files Browse the repository at this point in the history
This commit also fixes a bug when converting struct literal to series
  • Loading branch information
advancedxy committed Oct 29, 2024
1 parent 5228930 commit ced8c4b
Show file tree
Hide file tree
Showing 8 changed files with 297 additions and 32 deletions.
3 changes: 2 additions & 1 deletion daft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def refresh_logger() -> None:
from daft.dataframe import DataFrame
from daft.logical.schema import Schema
from daft.datatype import DataType, TimeUnit
from daft.expressions import Expression, col, lit, interval
from daft.expressions import Expression, col, lit, interval, zero_lit
from daft.io import (
DataCatalogTable,
DataCatalogType,
Expand Down Expand Up @@ -120,6 +120,7 @@ def refresh_logger() -> None:
"ImageMode",
"ImageFormat",
"lit",
"zero_lit",
"Series",
"TimeUnit",
"register_viz_hook",
Expand Down
2 changes: 2 additions & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,7 @@ class PyExpr:
def eq(expr1: PyExpr, expr2: PyExpr) -> bool: ...
def col(name: str) -> PyExpr: ...
def lit(item: Any) -> PyExpr: ...
def zero_lit(dt: PyDataType) -> PyExpr: ...
def date_lit(item: int) -> PyExpr: ...
def time_lit(item: int, tu: PyTimeUnit) -> PyExpr: ...
def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ...
Expand Down Expand Up @@ -1733,6 +1734,7 @@ class LogicalPlanBuilder:
join_suffix: str | None = None,
) -> LogicalPlanBuilder: ...
def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ...
def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: ...
def table_write(
self,
Expand Down
4 changes: 2 additions & 2 deletions daft/expressions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations

from .expressions import Expression, ExpressionsProjection, col, lit, interval
from .expressions import Expression, ExpressionsProjection, col, lit, interval, zero_lit

__all__ = ["Expression", "ExpressionsProjection", "col", "lit", "interval"]
__all__ = ["Expression", "ExpressionsProjection", "col", "lit", "interval", "zero_lit"]
34 changes: 34 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from daft.daft import tokenize_encode as _tokenize_encode
from daft.daft import url_download as _url_download
from daft.daft import utf8_count_matches as _utf8_count_matches
from daft.daft import zero_lit as _zero_lit
from daft.datatype import DataType, TimeUnit
from daft.dependencies import pa
from daft.expressions.testing import expr_structurally_equal
Expand Down Expand Up @@ -133,6 +134,39 @@ def lit(value: object) -> Expression:
return Expression._from_pyexpr(lit_value)


def zero_lit(dt: DataType) -> Expression:
"""Creates a literal Expression representing a zero value of corresponding data type
Example:
>>> import daft
>>> from daft import DataType
>>> df = daft.from_pydict({"x": [1, 2, 3]})
>>> df = df.with_column("y", daft.zero_lit(DataType.int32()))
>>> df.show()
╭───────┬───────╮
│ x ┆ y │
│ --- ┆ --- │
│ Int64 ┆ Int32 │
╞═══════╪═══════╡
│ 1 ┆ 0 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2 ┆ 0 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3 ┆ 0 │
╰───────┴───────╯
<BLANKLINE>
(Showing first 3 of 3 rows)
Args:
dt: data type of the zero value
Returns:
Expression: representing the zero value of the data type
"""
zero = _zero_lit(dt._dtype)
return Expression._from_pyexpr(zero)


def col(name: str) -> Expression:
"""Creates an Expression referring to the column with the provided name.
Expand Down
3 changes: 2 additions & 1 deletion src/daft-dsl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub use expr::{
binary_op, col, has_agg, has_stateful_udf, is_partition_compatible, AggExpr,
ApproxPercentileParams, Expr, ExprRef, Operator, SketchType,
};
pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue};
pub use lit::{lit, literal_value, literals_to_series, null_lit, zero_lit, Literal, LiteralValue};
#[cfg(feature = "python")]
use pyo3::prelude::*;
pub use resolve_expr::{
Expand All @@ -39,6 +39,7 @@ pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
parent.add_function(wrap_pyfunction_bound!(python::interval_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::decimal_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::series_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::zero_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::stateless_udf, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::stateful_udf, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(
Expand Down
Loading

0 comments on commit ced8c4b

Please sign in to comment.