Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Enable Comparison between timestamp / dates #1689

Merged
merged 4 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,17 @@
return cls._from_pydatatype(PyDataType.date())

@classmethod
def timestamp(cls, timeunit: TimeUnit, timezone: str | None = None) -> DataType:
def timestamp(cls, timeunit: TimeUnit | str, timezone: str | None = None) -> DataType:
"""Timestamp DataType."""
if isinstance(timeunit, str):
timeunit = TimeUnit.from_str(timeunit)
return cls._from_pydatatype(PyDataType.timestamp(timeunit._timeunit, timezone))

@classmethod
def duration(cls, timeunit: TimeUnit) -> DataType:
def duration(cls, timeunit: TimeUnit | str) -> DataType:
"""Duration DataType."""
if isinstance(timeunit, str):
timeunit = TimeUnit.from_str(timeunit)

Check warning on line 207 in daft/datatype.py

View check run for this annotation

Codecov / codecov/patch

daft/datatype.py#L207

Added line #L207 was not covered by tests
return cls._from_pydatatype(PyDataType.duration(timeunit._timeunit))

@classmethod
Expand Down
42 changes: 26 additions & 16 deletions src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::ops::{Add, Div, Mul, Rem, Sub};

use common_error::{DaftError, DaftResult};

use crate::impl_binary_trait_by_reference;
use crate::{impl_binary_trait_by_reference, utils::supertype::try_get_supertype};

use super::DataType;

Expand All @@ -24,27 +24,37 @@ impl DataType {
))
})
}
pub fn comparison_op(&self, other: &Self) -> DaftResult<(DataType, DataType)> {
pub fn comparison_op(
&self,
other: &Self,
) -> DaftResult<(DataType, Option<DataType>, DataType)> {
// Whether a comparison op is supported between the two types.
// Returns:
// - the output type,
// - an optional intermediate type
// - the type at which the comparison should be performed.
use DataType::*;
match (self, other) {
(s, o) if s == o => Ok(s.to_physical()),
(s, o) if s.is_physical() && o.is_physical() => {
try_physical_supertype(s, o).map_err(|_| ())
}
// To maintain existing behaviour. TODO: cleanup
(Date, o) | (o, Date) if o.is_physical() && o.clone() != Boolean => {
try_physical_supertype(&Date.to_physical(), o).map_err(|_| ())
let evaluator = || {
use DataType::*;
match (self, other) {
(s, o) if s == o => Ok((Boolean, None, s.to_physical())),
(s, o) if s.is_physical() && o.is_physical() => {
Ok((Boolean, None, try_physical_supertype(s, o)?))
}
(Timestamp(..) | Date, Timestamp(..) | Date) => {
let intermediate_type = try_get_supertype(self, other)?;
let pt = intermediate_type.to_physical();
Ok((Boolean, Some(intermediate_type), pt))
}
_ => Err(DaftError::TypeError(format!(
"Cannot perform comparison on types: {}, {}",
self, other
))),
}
_ => Err(()),
}
.map(|comp_type| (Boolean, comp_type))
.map_err(|()| {
};

evaluator().map_err(|err| {
DaftError::TypeError(format!(
"Cannot perform comparison on types: {}, {}",
"Cannot perform comparison on types: {}, {}\nDetails:\n{err}",
self, other
))
})
Expand Down
13 changes: 10 additions & 3 deletions src/daft-core/src/series/array_impl/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,24 @@ macro_rules! physical_logic_op {

macro_rules! physical_compare_op {
($self:expr, $rhs:expr, $op:ident, $pyop:expr) => {{
let (output_type, comp_type) = ($self.data_type().comparison_op($rhs.data_type()))?;
let (output_type, intermediate, comp_type) =
($self.data_type().comparison_op($rhs.data_type()))?;
let lhs = $self.into_series();
let (lhs, rhs) = if let Some(ref it) = intermediate {
(lhs.cast(it)?, $rhs.cast(it)?)
} else {
(lhs, $rhs.clone())
};

use DataType::*;
if let Boolean = output_type {
match comp_type {
#[cfg(feature = "python")]
Python => py_binary_op_bool!(lhs, $rhs, $pyop)
Python => py_binary_op_bool!(lhs, rhs, $pyop)
.downcast::<BooleanArray>()
.cloned(),
_ => with_match_comparable_daft_types!(comp_type, |$T| {
cast_downcast_op!(lhs, $rhs, &comp_type, <$T as DaftDataType>::ArrayType, $op)
cast_downcast_op!(lhs, rhs, &comp_type, <$T as DaftDataType>::ArrayType, $op)
}),
}
} else {
Expand Down
12 changes: 6 additions & 6 deletions src/daft-core/src/utils/supertype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,20 +168,20 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
(Duration(_), Date) | (Date, Duration(_)) => Some(Date),
(Duration(lu), Duration(ru)) => Some(Duration(get_time_units(lu, ru))),

// None and Some("") timezones
// Some() timezones that are non equal
// we cast from more precision to higher precision as that always fits with occasional loss of precision
(Timestamp(tu_l, tz_l), Timestamp(tu_r, tz_r))
if (tz_l.is_none() || tz_l.as_deref() == Some(""))
&& (tz_r.is_none() || tz_r.as_deref() == Some("")) =>
(Timestamp(tu_l, Some(tz_l)), Timestamp(tu_r, Some(tz_r)))
if !tz_l.is_empty()
&& !tz_r.is_empty() && tz_l != tz_r =>
{
let tu = get_time_units(tu_l, tu_r);
Some(Timestamp(tu, None))
Some(Timestamp(tu, Some("UTC".to_string())))
}
// None and Some("<tz>") timezones
// we cast from more precision to higher precision as that always fits with occasional loss of precision
(Timestamp(tu_l, tz_l), Timestamp(tu_r, tz_r)) if
// both are none
tz_l.is_none() && tz_r.is_some()
tz_l.is_none() && tz_r.is_none()
// both have the same time zone
|| (tz_l.is_some() && (tz_l == tz_r)) => {
let tu = get_time_units(tu_l, tu_r);
Expand Down
2 changes: 1 addition & 1 deletion src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ impl Expr {
| Operator::NotEq
| Operator::LtEq
| Operator::GtEq => {
let (result_type, _comp_type) =
let (result_type, _intermediate, _comp_type) =
left_field.dtype.comparison_op(&right_field.dtype)?;
Ok(Field::new(left_field.name.as_str(), result_type))
}
Expand Down
62 changes: 56 additions & 6 deletions tests/expressions/typing/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import itertools
import sys

import pytz

if sys.version_info < (3, 8):
pass
else:
Expand Down Expand Up @@ -33,17 +35,65 @@
(DataType.bool(), pa.array([True, False, None], type=pa.bool_())),
(DataType.null(), pa.array([None, None, None], type=pa.null())),
(DataType.binary(), pa.array([b"1", b"2", None], type=pa.binary())),
(DataType.date(), pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32())),
# TODO(jay): Some of the fixtures are broken/become very complicated when testing against timestamps
# (
# DataType.timestamp(TimeUnit.ms()),
# pa.array([datetime.datetime(2021, 1, 1), datetime.datetime(2021, 1, 2), None], type=pa.timestamp("ms")),
# ),
]

ALL_DATATYPES_BINARY_PAIRS = list(itertools.product(ALL_DTYPES, repeat=2))


# Temporal fixtures as (daft DataType, pyarrow array) pairs: one date column,
# then timestamps for every time unit, first timezone-naive and then once per
# named timezone.
ALL_TEMPORAL_DTYPES = [
    (
        DataType.date(),
        pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32()),
    ),
    *[
        (
            DataType.timestamp(unit),
            pa.array(
                [datetime.datetime(2021, 1, 1), datetime.datetime(2021, 1, 2), None],
                type=pa.timestamp(unit),
            ),
        )
        for unit in ["ns", "us", "ms"]
    ],
    # Timezone is the outer loop so that all units of one timezone stay adjacent.
    *[
        (
            DataType.timestamp(unit, tz_name),
            pa.array(
                [
                    datetime.datetime(2021, 1, 1).astimezone(pytz.timezone(tz_name)),
                    datetime.datetime(2021, 1, 2).astimezone(pytz.timezone(tz_name)),
                    None,
                ],
                type=pa.timestamp(unit, tz_name),
            ),
        )
        for tz_name in ["US/Eastern", "Africa/Accra"]
        for unit in ["ns", "us", "ms"]
    ],
]

# Make the temporal fixtures part of the general dtype list as well.
ALL_DTYPES.extend(ALL_TEMPORAL_DTYPES)

# Every ordered pair of temporal fixtures, except timestamp/timestamp pairs in
# which exactly one side carries a timezone. The `.tz` lookups are only reached
# once both sides are known to be timestamps.
ALL_TEMPORAL_DATATYPES_BINARY_PAIRS = [
    (lhs, rhs)
    for lhs, rhs in itertools.product(ALL_TEMPORAL_DTYPES, repeat=2)
    if not pa.types.is_timestamp(lhs[1].type)
    or not pa.types.is_timestamp(rhs[1].type)
    or (lhs[1].type.tz is None) == (rhs[1].type.tz is None)
]

# Append the temporal pairs to the general list of binary pairs.
ALL_DATATYPES_BINARY_PAIRS.extend(ALL_TEMPORAL_DATATYPES_BINARY_PAIRS)


@pytest.fixture(
scope="module",
params=ALL_DATATYPES_BINARY_PAIRS,
Expand Down
50 changes: 50 additions & 0 deletions tests/series/test_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import itertools
import operator
from datetime import date, datetime

import pyarrow as pa
import pytest
import pytz

from daft import DataType, Series

Expand Down Expand Up @@ -682,3 +684,51 @@ def test_logicalops_pyobjects(op, expected, expected_self) -> None:
assert op(custom_falses, values).datatype() == DataType.bool()
assert op(custom_falses, values).to_pylist() == expected
assert op(custom_falses, custom_falses).to_pylist() == expected_self


@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2))
def test_compare_timestamps_no_tz(tu1, tu2):
    """Timezone-naive timestamps of the same instant are equal for every pair of time units."""
    naive = Series.from_pylist([datetime(2022, 1, 1)])
    lhs = naive.cast(DataType.timestamp(tu1))
    rhs = naive.cast(DataType.timestamp(tu2))
    result = lhs == rhs
    assert result.to_pylist() == [True]


def test_compare_timestamps_no_tz_date():
    """A timezone-naive midnight timestamp compares equal to the same calendar date."""
    timestamps = Series.from_pylist([datetime(2022, 1, 1)])
    dates = Series.from_pylist([date(2022, 1, 1)])
    # Compare the timestamp series against the date series (not against itself)
    # so that the timestamp/date comparison path is actually exercised.
    assert (timestamps == dates).to_pylist() == [True]


def test_compare_timestamps_one_tz():
    """Comparing a timezone-naive timestamp with a timezone-aware one is rejected."""
    naive = Series.from_pylist([datetime(2022, 1, 1)])
    aware = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)])
    with pytest.raises(ValueError, match="Cannot perform comparison on types"):
        outcome = naive == aware
        assert outcome.to_pylist() == [True]


def test_compare_timestamps_and_int():
    """Comparing a timestamp series with an integer series is rejected."""
    timestamps = Series.from_pylist([datetime(2022, 1, 1)])
    integers = Series.from_pylist([5])
    with pytest.raises(ValueError, match="Cannot perform comparison on types"):
        outcome = timestamps == integers
        assert outcome.to_pylist() == [True]


def test_compare_timestamps_tz_date():
    """A UTC midnight timestamp compares equal to the same calendar date."""
    timestamps = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)])
    dates = Series.from_pylist([date(2022, 1, 1)])
    # Compare the timezone-aware timestamps against the dates (not against
    # themselves) so that the timestamp/date comparison path is exercised.
    assert (timestamps == dates).to_pylist() == [True]


@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2))
def test_compare_timestamps_same_tz(tu1, tu2):
    """UTC timestamps of the same instant are equal for every pair of time units."""
    lhs = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)])
    lhs = lhs.cast(DataType.timestamp(tu1, "UTC"))
    rhs = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)])
    rhs = rhs.cast(DataType.timestamp(tu2, "UTC"))
    result = lhs == rhs
    assert result.to_pylist() == [True]


@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2))
def test_compare_timestamps_diff_tz(tu1, tu2):
    """The same instant in two different timezones is equal for every pair of time units."""
    utc = datetime(2022, 1, 1, tzinfo=pytz.utc)
    eastern = utc.astimezone(pytz.timezone("US/Eastern"))
    lhs = Series.from_pylist([utc]).cast(DataType.timestamp(tu1, "UTC"))
    # The right-hand side uses the second parametrized unit, so mixed-unit
    # comparisons across timezones are covered as well.
    rhs = Series.from_pylist([eastern]).cast(DataType.timestamp(tu2, "US/Eastern"))
    assert (lhs == rhs).to_pylist() == [True]
Loading