Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] is_in expression #1811

Merged
merged 11 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ class PyExpr:
def __ne__(self, other: PyExpr) -> PyExpr: ... # type: ignore[override]
def is_null(self) -> PyExpr: ...
def not_null(self) -> PyExpr: ...
def is_in(self, other: PyExpr) -> PyExpr: ...
def name(self) -> str: ...
def to_field(self, schema: PySchema) -> PyField: ...
def __repr__(self) -> str: ...
Expand Down Expand Up @@ -873,6 +874,7 @@ def col(name: str) -> PyExpr: ...
def lit(item: Any) -> PyExpr: ...
def date_lit(item: int) -> PyExpr: ...
def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ...
def series_lit(item: PySeries) -> PyExpr: ...
def udf(func: Callable, expressions: list[PyExpr], return_dtype: PyDataType) -> PyExpr: ...

class PySeries:
Expand Down
25 changes: 23 additions & 2 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import builtins
import sys
from datetime import date, datetime
from typing import TYPE_CHECKING, Callable, Iterable, Iterator, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, TypeVar, overload

import pyarrow as pa

Expand All @@ -13,11 +13,13 @@
from daft.daft import col as _col
from daft.daft import date_lit as _date_lit
from daft.daft import lit as _lit
from daft.daft import series_lit as _series_lit
from daft.daft import timestamp_lit as _timestamp_lit
from daft.daft import udf as _udf
from daft.datatype import DataType, TimeUnit
from daft.expressions.testing import expr_structurally_equal
from daft.logical.schema import Field, Schema
from daft.series import Series, item_to_series

if sys.version_info < (3, 8):
from typing_extensions import Literal
Expand Down Expand Up @@ -51,6 +53,8 @@ def lit(value: object) -> Expression:
# pyo3 date (PyDate) is not available when running in abi3 mode, workaround
epoch_time = value - date(1970, 1, 1)
lit_value = _date_lit(epoch_time.days)
elif isinstance(value, Series):
lit_value = _series_lit(value._series)
else:
lit_value = _lit(value)
return Expression._from_pyexpr(lit_value)
Expand Down Expand Up @@ -387,6 +391,24 @@ def not_null(self) -> Expression:
expr = self._expr.not_null()
return Expression._from_pyexpr(expr)

def is_in(self, other: Any) -> Expression:
    """Tests whether each value of this Expression is contained in ``other``

    ``other`` may already be an Expression; anything else is first turned into a
    Series named ``"items"`` via ``item_to_series`` and wrapped as an Expression.

    Example:
        >>> # [1, 2, 3] -> [True, False, True]
        >>> col("x").is_in([1, 3])

    Args:
        other: an Expression, or a container accepted by ``item_to_series``

    Returns:
        Expression: Boolean Expression indicating whether values are in the provided list
    """
    if isinstance(other, Expression):
        items_expr = other
    else:
        # Raw containers are materialized as a Series before being lifted to an Expression.
        items_expr = Expression._to_expression(item_to_series("items", other))

    return Expression._from_pyexpr(self._expr.is_in(items_expr._expr))

def name(self) -> builtins.str:
return self._expr.name()

Expand Down Expand Up @@ -464,7 +486,6 @@ def download(
Expression: a Binary expression which is the bytes contents of the URL, or None if an error occurred during download
"""
if use_native_downloader:

raise_on_error = False
if on_error == "raise":
raise_on_error = True
Expand Down
18 changes: 17 additions & 1 deletion daft/series.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TypeVar
from typing import Any, TypeVar

import pyarrow as pa

Expand Down Expand Up @@ -512,6 +512,22 @@ def _debug_bincode_deserialize(cls, b: bytes) -> Series:
return Series._from_pyseries(PySeries._debug_bincode_deserialize(b))


def item_to_series(name: str, item: Any) -> Series:
    """Converts a supported container into a Series.

    Args:
        name: name passed to the Series constructor for list, numpy, arrow and
            pandas inputs. An ``item`` that is already a Series is returned
            unchanged, so ``name`` is not applied to it.
        item: a Python list, numpy ndarray, Series, pyarrow Array/ChunkedArray
            or pandas Series.

    Returns:
        Series: the converted (or passed-through) Series.

    Raises:
        ValueError: if ``item`` is of none of the supported types.
    """
    if isinstance(item, list):
        return Series.from_pylist(item, name)
    # numpy/pandas are optional; only inspect their types when the import succeeded.
    if _NUMPY_AVAILABLE and isinstance(item, np.ndarray):
        return Series.from_numpy(item, name)
    if isinstance(item, Series):
        return item
    if isinstance(item, (pa.Array, pa.ChunkedArray)):
        return Series.from_arrow(item, name)
    if _PANDAS_AVAILABLE and isinstance(item, pd.Series):
        return Series.from_pandas(item, name)
    raise ValueError(f"Creating a Series from data of type {type(item)} not implemented")


SomeSeriesNamespace = TypeVar("SomeSeriesNamespace", bound="SeriesNamespace")


Expand Down
15 changes: 2 additions & 13 deletions daft/table/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from daft.datatype import DataType, TimeUnit
from daft.expressions import Expression, ExpressionsProjection
from daft.logical.schema import Schema
from daft.series import Series
from daft.series import Series, item_to_series

_NUMPY_AVAILABLE = True
try:
Expand Down Expand Up @@ -148,18 +148,7 @@ def from_pandas(pd_df: pd.DataFrame) -> Table:
def from_pydict(data: dict) -> Table:
series_dict = dict()
for k, v in data.items():
if isinstance(v, list):
series = Series.from_pylist(v, name=k)
elif _NUMPY_AVAILABLE and isinstance(v, np.ndarray):
series = Series.from_numpy(v, name=k)
elif isinstance(v, Series):
series = v
elif isinstance(v, (pa.Array, pa.ChunkedArray)):
series = Series.from_arrow(v, name=k)
elif _PANDAS_AVAILABLE and isinstance(v, pd.Series):
series = Series.from_pandas(v, name=k)
else:
raise ValueError(f"Creating a Series from data of type {type(v)} not implemented")
series = item_to_series(k, v)
series_dict[k] = series._series
return Table._from_pytable(_PyTable.from_pylist_series(series_dict))

Expand Down
11 changes: 11 additions & 0 deletions daft/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,17 @@
]


def python_list_membership_check(
    left_pylist: list,
    right_pylist: list,
) -> list:
    """Reports, for every element of ``left_pylist``, whether it occurs in ``right_pylist``.

    A set built from ``right_pylist`` is tried first. If building the set or
    probing it raises ``TypeError`` (for example because an element is
    unhashable), the whole check is redone against the list itself.

    Returns:
        list: one bool per element of ``left_pylist``, in order.
    """

    def _scan(container) -> list:
        return [value in container for value in left_pylist]

    try:
        return _scan(set(right_pylist))
    except TypeError:
        # Unhashable values on either side: fall back to list containment.
        return _scan(right_pylist)

Check warning on line 98 in daft/utils.py

View check run for this annotation

Codecov / codecov/patch

daft/utils.py#L97-L98

Added lines #L97 - L98 were not covered by tests


def map_operator_arrow_semantics(
operator: Callable[[Any, Any], Any],
left_pylist: list,
Expand Down
16 changes: 15 additions & 1 deletion src/daft-core/src/array/from_iter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::datatypes::{BinaryArray, DaftNumericType, Field, Utf8Array};
use crate::datatypes::{BinaryArray, BooleanArray, DaftNumericType, Field, Utf8Array};

use super::DataArray;

Expand Down Expand Up @@ -41,3 +41,17 @@ impl BinaryArray {
.unwrap()
}
}

impl BooleanArray {
    /// Builds a boolean array called `name` from a trusted-length iterator of optional bools.
    ///
    /// # Panics
    /// Panics if `DataArray::new` returns an error for the constructed field and array.
    pub fn from_iter(
        name: &str,
        iter: impl Iterator<Item = Option<bool>> + arrow2::trusted_len::TrustedLen,
    ) -> Self {
        let field = Field::new(name, crate::DataType::Boolean);
        let values = arrow2::array::BooleanArray::from_trusted_len_iter(iter);
        DataArray::new(field.into(), Box::new(values)).unwrap()
    }
}
93 changes: 93 additions & 0 deletions src/daft-core/src/array/ops/is_in.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use crate::{
array::DataArray,
datatypes::{
BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, Float32Array, Float64Array,
NullArray, Utf8Array,
},
DataType,
};

use super::as_arrow::AsArrow;
use super::{full::FullNull, DaftIsIn};
use crate::utils::orderable_float_wrapper::FloatWrapper;
use common_error::DaftResult;
use std::collections::{BTreeSet, HashSet};

/// Expands to a hash-set based membership check of `$self` against `$rhs`.
///
/// Both arguments must offer `as_arrow().iter()` yielding `Option<_>` items, and `$self`
/// must offer `name()`. Null entries of `$rhs` are dropped before the set is built; a null
/// entry of `$self` stays null in the output, every other entry becomes `Some(true/false)`.
/// The expansion evaluates to `Ok(BooleanArray)` named after `$self`.
macro_rules! collect_to_set_and_check_membership {
    ($self:expr, $rhs:expr) => {{
        // `flatten` skips the `None` (null) candidates.
        let set = $rhs.as_arrow().iter().flatten().collect::<HashSet<_>>();
        let result = $self
            .as_arrow()
            .iter()
            .map(|option| option.map(|value| set.contains(&value)));
        Ok(BooleanArray::from_iter($self.name(), result))
    }};
}

/// Membership check for integer-typed arrays, delegating to
/// `collect_to_set_and_check_membership!`.
impl<T> DaftIsIn<&DataArray<T>> for DataArray<T>
where
    T: DaftIntegerType,
    <T as DaftNumericType>::Native: Ord + std::hash::Hash + std::cmp::Eq,
{
    type Output = DaftResult<BooleanArray>;

    /// Returns a boolean array telling, per value of `self`, whether it occurs in `rhs`.
    fn is_in(&self, rhs: &DataArray<T>) -> Self::Output {
        collect_to_set_and_check_membership!(self, rhs)
    }
}

/// Implements `DaftIsIn<&$arr>` for the floating-point array type `$arr`, whose native
/// element type is `$T`.
///
/// Floats are wrapped in `FloatWrapper` so they can be stored in a `BTreeSet`. Null entries
/// of `rhs` are skipped when the set is built; a null entry of `self` stays null in the
/// output, every other entry becomes `Some(true/false)`.
macro_rules! impl_is_in_floating_array {
    ($arr:ident, $T:ident) => {
        impl DaftIsIn<&$arr> for $arr {
            type Output = DaftResult<BooleanArray>;

            fn is_in(&self, rhs: &$arr) -> Self::Output {
                let set = rhs
                    .as_arrow()
                    .iter()
                    .flatten()
                    .map(|value| FloatWrapper(*value))
                    .collect::<BTreeSet<FloatWrapper<$T>>>();
                let result = self
                    .as_arrow()
                    .iter()
                    .map(|option| option.map(|value| set.contains(&FloatWrapper(*value))));
                Ok(BooleanArray::from_iter(self.name(), result))
            }
        }
    };
}
impl_is_in_floating_array!(Float32Array, f32);
impl_is_in_floating_array!(Float64Array, f64);

/// Implements `DaftIsIn<&$array_type>` for `$array_type` by delegating to
/// `collect_to_set_and_check_membership!`.
macro_rules! impl_is_in_non_numeric_array {
    ($array_type:ident) => {
        impl DaftIsIn<&$array_type> for $array_type {
            type Output = DaftResult<BooleanArray>;

            fn is_in(&self, rhs: &$array_type) -> Self::Output {
                collect_to_set_and_check_membership!(self, rhs)
            }
        }
    };
}
impl_is_in_non_numeric_array!(BooleanArray);
impl_is_in_non_numeric_array!(Utf8Array);
impl_is_in_non_numeric_array!(BinaryArray);

impl DaftIsIn<&NullArray> for NullArray {
    type Output = DaftResult<BooleanArray>;

    /// Membership between two null arrays: `_rhs` is ignored and the result is an all-null
    /// boolean array with the name and length of `self`.
    fn is_in(&self, _rhs: &NullArray) -> Self::Output {
        let all_null = BooleanArray::full_null(self.name(), &DataType::Boolean, self.len());
        Ok(all_null)
    }
}
6 changes: 6 additions & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub(crate) mod groups;
mod hash;
mod if_else;
pub(crate) mod image;
mod is_in;
mod len;
mod list;
mod list_agg;
Expand Down Expand Up @@ -78,6 +79,11 @@ pub trait DaftLogical<Rhs> {
fn xor(&self, rhs: Rhs) -> Self::Output;
}

/// Membership test of the values of `self` against a collection of candidate values.
///
/// `Rhs` is the type holding the candidates; each implementation picks its own `Output`.
pub trait DaftIsIn<Rhs> {
/// Result type produced by `is_in`.
type Output;
/// Checks the values of `self` for membership in `rhs`.
fn is_in(&self, rhs: Rhs) -> Self::Output;
}

pub trait DaftIsNull {
type Output;
fn is_null(&self) -> Self::Output;
Expand Down
7 changes: 7 additions & 0 deletions src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ impl DataType {
))
})
}
/// Resolves the data types used when checking membership (`is_in`) of `self` values in
/// `other` values.
///
/// This is a thin alias for `comparison_op`; the returned tuple and any error come straight
/// from that method. `Series::is_in` binds the tuple as
/// `(output_type, intermediate, comp_type)`.
pub fn membership_op(
&self,
other: &Self,
) -> DaftResult<(DataType, Option<DataType>, DataType)> {
// membership checks (is_in) use equality checks, so we can use the same logic as comparison ops.
self.comparison_op(other)
}
}

impl Add for &DataType {
Expand Down
11 changes: 10 additions & 1 deletion src/daft-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{
};

use crate::{
array::ops::{from_arrow::FromArrow, full::FullNull},
array::ops::{from_arrow::FromArrow, full::FullNull, DaftCompare},
datatypes::{DataType, Field, FieldRef},
utils::display_table::make_comfy_table,
with_match_daft_types,
Expand All @@ -26,6 +26,15 @@ pub struct Series {
pub inner: Arc<dyn SeriesLike>,
}

impl PartialEq for Series {
    /// Two series are equal when they have the same length and the element-wise `equal`
    /// comparison succeeds with every position reporting `true`.
    ///
    /// A position whose comparison result is null counts as "not equal", and so does any
    /// error returned by `equal`.
    fn eq(&self, other: &Self) -> bool {
        // Series of different lengths can never be equal. Checking up front keeps the
        // verdict independent of how the element-wise comparison treats mismatched lengths.
        if self.len() != other.len() {
            return false;
        }

        match self.equal(other) {
            Ok(mask) => mask.into_iter().all(|is_equal| is_equal.unwrap_or(false)),
            Err(_) => false,
        }
    }
}

impl Series {
pub fn to_arrow(&self) -> Box<dyn arrow2::array::Array> {
self.inner.to_arrow()
Expand Down
53 changes: 53 additions & 0 deletions src/daft-core/src/series/ops/is_in.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use common_error::DaftResult;

use crate::{
array::ops::DaftIsIn, datatypes::BooleanArray, with_match_comparable_daft_types, DataType,
IntoSeries, Series,
};

#[cfg(feature = "python")]
use crate::series::ops::py_membership_op_utilfn;

/// Fallback result for `is_in`: a boolean series called `name` holding `size` `false` values.
fn default(name: &str, size: usize) -> DaftResult<Series> {
    let all_false = vec![false; size];
    let array = BooleanArray::from((name, all_false.as_slice()));
    Ok(array.into_series())
}

impl Series {
    /// Tests each value of this series for membership in `items`, returning a boolean series.
    ///
    /// An all-`false` series with the name and length of `self` is returned when `items` is
    /// empty, or when `membership_op` returns an error for the two data types. Otherwise
    /// both sides are cast to the types chosen by `membership_op` and the check is dispatched
    /// to the matching array-level `DaftIsIn` implementation. With the `python` feature,
    /// a `Python` comparison type is handled by `py_membership_op_utilfn` on the uncast inputs.
    ///
    /// # Errors
    /// Propagates failures from casting, downcasting and the underlying membership check.
    ///
    /// # Panics
    /// Panics via `unreachable!()` if `membership_op` reports a non-boolean output type.
    pub fn is_in(&self, items: &Self) -> DaftResult<Series> {
        if items.is_empty() {
            return default(self.name(), self.len());
        }

        let resolved = self.data_type().membership_op(items.data_type());
        let (output_type, intermediate, comp_type) = match resolved {
            Ok(types) => types,
            Err(_) => return default(self.name(), self.len()),
        };

        // Bring both sides to the shared intermediate type when one is required.
        let (lhs, rhs) = match intermediate {
            Some(ref it) => (self.cast(it)?, items.cast(it)?),
            None => (self.clone(), items.clone()),
        };

        if !matches!(output_type, DataType::Boolean) {
            unreachable!()
        }

        match comp_type {
            #[cfg(feature = "python")]
            DataType::Python => Ok(py_membership_op_utilfn(self, items)?
                .downcast::<BooleanArray>()?
                .clone()
                .into_series()),
            _ => with_match_comparable_daft_types!(comp_type, |$T| {
                let casted_lhs = lhs.cast(&comp_type)?;
                let casted_rhs = rhs.cast(&comp_type)?;
                let lhs = casted_lhs.downcast::<<$T as DaftDataType>::ArrayType>()?;
                let rhs = casted_rhs.downcast::<<$T as DaftDataType>::ArrayType>()?;

                Ok(lhs.is_in(rhs)?.into_series())
            }),
        }
    }
}
Loading
Loading