Skip to content

Commit

Permalink
[FEAT] is_in expression (#1811)
Browse files Browse the repository at this point in the history
Closes #993 

The `is_in` expression checks whether the values of a series are
contained in a given list of items, and produces a series of boolean
values as the results of this membership test.

Changes:
- Added a Literal Series so that Series can be passed into the
expression
- Added `is_in` expression and kernel
- Added tests
  • Loading branch information
colin-ho authored Jan 25, 2024
1 parent 446a669 commit 21cb2b5
Show file tree
Hide file tree
Showing 27 changed files with 599 additions and 33 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ class PyExpr:
def __ne__(self, other: PyExpr) -> PyExpr: ... # type: ignore[override]
def is_null(self) -> PyExpr: ...
def not_null(self) -> PyExpr: ...
def is_in(self, other: PyExpr) -> PyExpr: ...
def name(self) -> str: ...
def to_field(self, schema: PySchema) -> PyField: ...
def __repr__(self) -> str: ...
Expand Down Expand Up @@ -879,6 +880,7 @@ def col(name: str) -> PyExpr: ...
def lit(item: Any) -> PyExpr: ...
def date_lit(item: int) -> PyExpr: ...
def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ...
def series_lit(item: PySeries) -> PyExpr: ...
def udf(func: Callable, expressions: list[PyExpr], return_dtype: PyDataType) -> PyExpr: ...

class PySeries:
Expand Down
25 changes: 23 additions & 2 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import builtins
import sys
from datetime import date, datetime
from typing import TYPE_CHECKING, Callable, Iterable, Iterator, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, TypeVar, overload

import pyarrow as pa

Expand All @@ -13,11 +13,13 @@
from daft.daft import col as _col
from daft.daft import date_lit as _date_lit
from daft.daft import lit as _lit
from daft.daft import series_lit as _series_lit
from daft.daft import timestamp_lit as _timestamp_lit
from daft.daft import udf as _udf
from daft.datatype import DataType, TimeUnit
from daft.expressions.testing import expr_structurally_equal
from daft.logical.schema import Field, Schema
from daft.series import Series, item_to_series

if sys.version_info < (3, 8):
from typing_extensions import Literal
Expand Down Expand Up @@ -51,6 +53,8 @@ def lit(value: object) -> Expression:
# pyo3 date (PyDate) is not available when running in abi3 mode, workaround
epoch_time = value - date(1970, 1, 1)
lit_value = _date_lit(epoch_time.days)
elif isinstance(value, Series):
lit_value = _series_lit(value._series)
else:
lit_value = _lit(value)
return Expression._from_pyexpr(lit_value)
Expand Down Expand Up @@ -392,6 +396,24 @@ def not_null(self) -> Expression:
expr = self._expr.not_null()
return Expression._from_pyexpr(expr)

def is_in(self, other: Any) -> Expression:
    """Tests each value of this Expression for membership in a collection of items.

    Example:
        >>> # [1, 2, 3] -> [True, False, True]
        >>> col("x").is_in([1, 3])

    Args:
        other: Either an Expression, or any container that ``item_to_series`` can
            convert into a Series (for example a Python list).

    Returns:
        Expression: Boolean Expression indicating whether values are in the provided list
    """
    if isinstance(other, Expression):
        items_expr = other
    else:
        # Wrap raw containers in a Series first so they can be passed down as an expression.
        items_expr = Expression._to_expression(item_to_series("items", other))

    return Expression._from_pyexpr(self._expr.is_in(items_expr._expr))

def name(self) -> builtins.str:
return self._expr.name()

Expand Down Expand Up @@ -469,7 +491,6 @@ def download(
Expression: a Binary expression which is the bytes contents of the URL, or None if an error occurred during download
"""
if use_native_downloader:

raise_on_error = False
if on_error == "raise":
raise_on_error = True
Expand Down
18 changes: 17 additions & 1 deletion daft/series.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TypeVar
from typing import Any, TypeVar

import pyarrow as pa

Expand Down Expand Up @@ -515,6 +515,22 @@ def _debug_bincode_deserialize(cls, b: bytes) -> Series:
return Series._from_pyseries(PySeries._debug_bincode_deserialize(b))


def item_to_series(name: str, item: Any) -> Series:
    """Converts a supported container into a Series.

    The checks run in this order: Python list, numpy array (only when numpy is
    available), Series, pyarrow Array/ChunkedArray, pandas Series (only when pandas
    is available).

    Args:
        name: Name passed to the Series constructor. It is not applied when ``item``
            is already a Series, which is returned unchanged.
        item: The data to convert.

    Returns:
        Series: the converted (or passed-through) Series.

    Raises:
        ValueError: if ``item`` is none of the supported types.
    """
    if isinstance(item, list):
        return Series.from_pylist(item, name)
    if _NUMPY_AVAILABLE and isinstance(item, np.ndarray):
        return Series.from_numpy(item, name)
    if isinstance(item, Series):
        # Already a Series: hand it back as-is.
        return item
    if isinstance(item, (pa.Array, pa.ChunkedArray)):
        return Series.from_arrow(item, name)
    if _PANDAS_AVAILABLE and isinstance(item, pd.Series):
        return Series.from_pandas(item, name)
    raise ValueError(f"Creating a Series from data of type {type(item)} not implemented")


SomeSeriesNamespace = TypeVar("SomeSeriesNamespace", bound="SeriesNamespace")


Expand Down
15 changes: 2 additions & 13 deletions daft/table/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from daft.datatype import DataType, TimeUnit
from daft.expressions import Expression, ExpressionsProjection
from daft.logical.schema import Schema
from daft.series import Series
from daft.series import Series, item_to_series

_NUMPY_AVAILABLE = True
try:
Expand Down Expand Up @@ -148,18 +148,7 @@ def from_pandas(pd_df: pd.DataFrame) -> Table:
def from_pydict(data: dict) -> Table:
    """Builds a Table from a mapping of column name to column data.

    Args:
        data: Mapping whose keys are column names and whose values are anything
            ``item_to_series`` accepts.

    Returns:
        Table: a Table with one column per key of ``data``.

    Raises:
        ValueError: propagated from ``item_to_series`` when a value has an
            unsupported type.
    """
    # The per-type dispatch (list / numpy / Series / arrow / pandas) lives in
    # item_to_series; repeating it inline here only produced a value that was
    # immediately overwritten by the helper's result.
    series_dict = {key: item_to_series(key, value)._series for key, value in data.items()}
    return Table._from_pytable(_PyTable.from_pylist_series(series_dict))

Expand Down
11 changes: 11 additions & 0 deletions daft/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,17 @@ def map_operator_arrow_semantics_bool(
]


def python_list_membership_check(
    left_pylist: list,
    right_pylist: list,
) -> list:
    """Reports, for every element of ``left_pylist``, whether it occurs in ``right_pylist``.

    A set built from ``right_pylist`` is tried first. If hashing raises ``TypeError``
    (while building the set or while probing it), the whole check is redone by
    scanning ``right_pylist`` directly.

    Returns:
        list: one bool per element of ``left_pylist``, in the same order.
    """

    def _membership_in(container) -> list:
        return [candidate in container for candidate in left_pylist]

    try:
        return _membership_in(set(right_pylist))
    except TypeError:
        # Unhashable values on either side: fall back to a linear scan of the list.
        return _membership_in(right_pylist)


def map_operator_arrow_semantics(
operator: Callable[[Any, Any], Any],
left_pylist: list,
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dyn-clone = "1.0.16"
fnv = "1.0.7"
html-escape = {workspace = true}
indexmap = {workspace = true, features = ["serde"]}
itertools = {workspace = true}
lazy_static = {workspace = true}
log = {workspace = true}
mur3 = "0.1.0"
Expand Down
16 changes: 15 additions & 1 deletion src/daft-core/src/array/from_iter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::datatypes::{BinaryArray, DaftNumericType, Field, Utf8Array};
use crate::datatypes::{BinaryArray, BooleanArray, DaftNumericType, Field, Utf8Array};

use super::DataArray;

Expand Down Expand Up @@ -41,3 +41,17 @@ impl BinaryArray {
.unwrap()
}
}

impl BooleanArray {
    /// Builds a named boolean array from an iterator of optional values.
    ///
    /// The iterator must also implement arrow2's `TrustedLen`, because the values are
    /// collected with `from_trusted_len_iter`. `None` items become null entries.
    ///
    /// # Panics
    ///
    /// Panics if `DataArray::new` returns an error for the constructed field and array.
    pub fn from_iter(
        name: &str,
        iter: impl Iterator<Item = Option<bool>> + arrow2::trusted_len::TrustedLen,
    ) -> Self {
        let values = arrow2::array::BooleanArray::from_trusted_len_iter(iter);
        let field = Field::new(name, crate::DataType::Boolean);
        DataArray::new(field.into(), Box::new(values)).unwrap()
    }
}
93 changes: 93 additions & 0 deletions src/daft-core/src/array/ops/is_in.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use crate::{
array::DataArray,
datatypes::{
BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, Float32Array, Float64Array,
NullArray, Utf8Array,
},
DataType,
};

use super::as_arrow::AsArrow;
use super::{full::FullNull, DaftIsIn};
use crate::utils::hashable_float_wrapper::FloatWrapper;
use common_error::DaftResult;
use std::collections::{BTreeSet, HashSet};

/// Expands to a membership test of `$self` against `$rhs`.
///
/// The non-null values of `$rhs` are collected into a `HashSet`; every value of `$self`
/// is then mapped to whether that set contains it. Null entries of `$rhs` are dropped
/// before the set is built, and null entries of `$self` stay null in the result.
///
/// Evaluates to `Ok(BooleanArray)` carrying the name of `$self`.
macro_rules! collect_to_set_and_check_membership {
    ($self:expr, $rhs:expr) => {{
        // `flatten` skips the `None` (null) items and unwraps the rest.
        let set = $rhs.as_arrow().iter().flatten().collect::<HashSet<_>>();
        let result = $self
            .as_arrow()
            .iter()
            .map(|option| option.map(|value| set.contains(&value)));
        Ok(BooleanArray::from_iter($self.name(), result))
    }};
}

/// Membership test for integer-typed arrays.
///
/// Delegates to `collect_to_set_and_check_membership!`, which needs the native value
/// type to be hashable and comparable for equality.
impl<T> DaftIsIn<&DataArray<T>> for DataArray<T>
where
    T: DaftIntegerType,
    <T as DaftNumericType>::Native: Ord + std::hash::Hash + std::cmp::Eq,
{
    type Output = DaftResult<BooleanArray>;

    /// Returns, per element of `self`, whether it appears among the values of `rhs`.
    fn is_in(&self, rhs: &DataArray<T>) -> Self::Output {
        collect_to_set_and_check_membership!(self, rhs)
    }
}

/// Generates a `DaftIsIn` impl for a floating-point array type.
///
/// `$arr` is the array type and `$T` its native float type. Raw floats are not `Ord`,
/// so every value is wrapped in `FloatWrapper` before going into a `BTreeSet`.
/// Null entries of `rhs` are left out of the set; null entries of `self` stay null in
/// the result.
// NOTE(review): how NaN values compare is decided by `FloatWrapper`'s ordering, which
// is defined elsewhere — confirm there before relying on NaN membership.
macro_rules! impl_is_in_floating_array {
    ($arr:ident, $T:ident) => {
        impl DaftIsIn<&$arr> for $arr {
            type Output = DaftResult<BooleanArray>;

            fn is_in(&self, rhs: &$arr) -> Self::Output {
                let set = rhs
                    .as_arrow()
                    .iter()
                    .filter_map(|item| item.map(|value| FloatWrapper(*value)))
                    .collect::<BTreeSet<FloatWrapper<$T>>>();
                let result = self
                    .as_arrow()
                    .iter()
                    .map(|option| option.map(|value| set.contains(&FloatWrapper(*value))));
                Ok(BooleanArray::from_iter(self.name(), result))
            }
        }
    };
}
impl_is_in_floating_array!(Float32Array, f32);
impl_is_in_floating_array!(Float64Array, f64);

/// Generates a `DaftIsIn` impl for a concrete array type whose values can go straight
/// into a `HashSet`, by delegating to `collect_to_set_and_check_membership!`.
macro_rules! impl_is_in_non_numeric_array {
    ($array_type:ident) => {
        impl DaftIsIn<&$array_type> for $array_type {
            type Output = DaftResult<BooleanArray>;

            fn is_in(&self, rhs: &$array_type) -> Self::Output {
                collect_to_set_and_check_membership!(self, rhs)
            }
        }
    };
}
impl_is_in_non_numeric_array!(BooleanArray);
impl_is_in_non_numeric_array!(Utf8Array);
impl_is_in_non_numeric_array!(BinaryArray);

impl DaftIsIn<&NullArray> for NullArray {
    type Output = DaftResult<BooleanArray>;

    /// Membership test between two null arrays.
    ///
    /// `_rhs` is never inspected: the result is always a fully-null boolean array with
    /// the name and length of `self`.
    fn is_in(&self, _rhs: &NullArray) -> Self::Output {
        let all_null = BooleanArray::full_null(self.name(), &DataType::Boolean, self.len());
        Ok(all_null)
    }
}
6 changes: 6 additions & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub(crate) mod groups;
mod hash;
mod if_else;
pub(crate) mod image;
mod is_in;
mod len;
mod list;
mod list_agg;
Expand Down Expand Up @@ -78,6 +79,11 @@ pub trait DaftLogical<Rhs> {
fn xor(&self, rhs: Rhs) -> Self::Output;
}

/// Membership test of `self` against a right-hand side of candidate values.
///
/// `Rhs` is the type holding the candidates; `Output` is chosen by each implementation.
pub trait DaftIsIn<Rhs> {
type Output;
/// Checks the values of `self` for membership in `rhs`.
fn is_in(&self, rhs: Rhs) -> Self::Output;
}

pub trait DaftIsNull {
type Output;
fn is_null(&self) -> Self::Output;
Expand Down
7 changes: 7 additions & 0 deletions src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ impl DataType {
))
})
}
pub fn membership_op(
&self,
other: &Self,
) -> DaftResult<(DataType, Option<DataType>, DataType)> {
// membership checks (is_in) use equality checks, so we can use the same logic as comparison ops.
self.comparison_op(other)
}
}

impl Add for &DataType {
Expand Down
11 changes: 10 additions & 1 deletion src/daft-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{
};

use crate::{
array::ops::{from_arrow::FromArrow, full::FullNull},
array::ops::{from_arrow::FromArrow, full::FullNull, DaftCompare},
datatypes::{DataType, Field, FieldRef},
utils::display_table::make_comfy_table,
with_match_daft_types,
Expand All @@ -26,6 +26,15 @@ pub struct Series {
pub inner: Arc<dyn SeriesLike>,
}

/// Whole-series equality built on the element-wise `equal` comparison.
///
/// Two series are equal only when `equal` succeeds and every resulting entry is
/// `Some(true)`. A null comparison result counts as "not equal", and any error from
/// `equal` also yields `false`. An empty comparison result is vacuously equal.
// NOTE(review): only the output of `equal` is consulted here; whether names, dtypes
// or lengths are checked depends on what `equal` itself enforces — confirm there.
impl PartialEq for Series {
    fn eq(&self, other: &Self) -> bool {
        self.equal(other)
            .map(|mask| mask.into_iter().all(|matched| matched == Some(true)))
            .unwrap_or(false)
    }
}

impl Series {
pub fn to_arrow(&self) -> Box<dyn arrow2::array::Array> {
self.inner.to_arrow()
Expand Down
Loading

0 comments on commit 21cb2b5

Please sign in to comment.