diff --git a/src/daft-core/src/array/ops/between.rs b/src/daft-core/src/array/ops/between.rs index 09e1914372..d40dbc571e 100644 --- a/src/daft-core/src/array/ops/between.rs +++ b/src/daft-core/src/array/ops/between.rs @@ -3,12 +3,12 @@ use common_error::{DaftError, DaftResult}; use super::{DaftBetween, DaftCompare, DaftLogical}; use crate::{ array::DataArray, - datatypes::{BooleanArray, DaftNumericType}, + datatypes::{BooleanArray, DaftPrimitiveType}, }; impl DaftBetween<&Self, &Self> for DataArray where - T: DaftNumericType, + T: DaftPrimitiveType, { type Output = DaftResult; diff --git a/src/daft-core/src/datatypes/matching.rs b/src/daft-core/src/datatypes/matching.rs index ce6e662b37..74332a28a6 100644 --- a/src/daft-core/src/datatypes/matching.rs +++ b/src/daft-core/src/datatypes/matching.rs @@ -213,6 +213,31 @@ macro_rules! with_match_numeric_daft_types {( } })} +#[macro_export] +macro_rules! with_match_primitive_daft_types {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use $crate::datatypes::DataType::*; + use $crate::datatypes::*; + + match $key_type { + Int8 => __with_ty__! { Int8Type }, + Int16 => __with_ty__! { Int16Type }, + Int32 => __with_ty__! { Int32Type }, + Int64 => __with_ty__! { Int64Type }, + UInt8 => __with_ty__! { UInt8Type }, + UInt16 => __with_ty__! { UInt16Type }, + UInt32 => __with_ty__! { UInt32Type }, + UInt64 => __with_ty__! { UInt64Type }, + // Float16 => __with_ty__! { Float16Type }, + Float32 => __with_ty__! { Float32Type }, + Float64 => __with_ty__! { Float64Type }, + Decimal128(..) => __with_ty__! { Decimal128Type }, + _ => panic!("{:?} not implemented", $key_type) + } +})} + #[macro_export] macro_rules! with_match_integer_daft_types {( $key_type:expr, | $_:tt $T:ident | $($body:tt)* diff --git a/src/daft-core/src/series/ops/between.rs b/src/daft-core/src/series/ops/between.rs index f6247eab15..1799533a44 100644 --- a/src/daft-core/src/series/ops/between.rs +++ b/src/daft-core/src/series/ops/between.rs @@ -6,7 +6,7 @@ use crate::{ array::ops::DaftBetween, datatypes::{BooleanArray, DataType, InferDataType}, series::{IntoSeries, Series}, - with_match_numeric_daft_types, + with_match_primitive_daft_types, }; impl Series { @@ -31,8 +31,8 @@ impl Series { .clone() .into_series()), DataType::Null => Ok(Self::full_null(self.name(), &DataType::Boolean, self.len())), - ref v if v.is_numeric() => { - with_match_numeric_daft_types!(comp_type, |$T| { + ref v if v.is_primitive() => { + with_match_primitive_daft_types!(comp_type, |$T| { let casted_value = it_value.cast(&comp_type)?; let casted_lower = it_lower.cast(&comp_type)?; let casted_upper = it_upper.cast(&comp_type)?; diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index e7836c616b..3ef50d6b05 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -392,6 +392,26 @@ impl DataType { } } + #[inline] + pub fn is_primitive(&self) -> bool { + match self { + Self::Int8 + | Self::Int16 + | Self::Int32 + | Self::Int64 + | Self::UInt8 + | Self::UInt16 + | Self::UInt32 + | Self::UInt64 + // DataType::Float16 + | Self::Float32 + | Self::Float64 + | Self::Decimal128(..) => true, + Self::Extension(_, inner, _) => inner.is_primitive(), + _ => false + } + } + #[inline] pub fn assert_is_numeric(&self) -> DaftResult<()> { if self.is_numeric() { diff --git a/tests/table/test_between.py b/tests/table/test_between.py index 8848c0693b..75fc074672 100644 --- a/tests/table/test_between.py +++ b/tests/table/test_between.py @@ -1,4 +1,5 @@ import datetime +from decimal import Decimal import pytest @@ -12,6 +13,27 @@ pytest.param([1, 2, 3, 4], 1, 2, [True, True, False, False], id="IntIntInt"), pytest.param([1, 2, 3, 4], 1.0, 2.0, [True, True, False, False], id="IntFloatFloat"), pytest.param([1, 2, 3, 4], 1, 2.0, [True, True, False, False], id="IntIntFloat"), + pytest.param( + [Decimal("1.0"), Decimal("2.0"), Decimal("3.0"), Decimal("4.0")], + Decimal("1.0"), + Decimal("2.0"), + [True, True, False, False], + id="DecimalDecimalDecimal", + ), + pytest.param( + [Decimal("1.0"), Decimal("2.0"), Decimal("3.0"), Decimal("4.0")], + 1, + Decimal("2.0"), + [True, True, False, False], + id="DecimalIntDecimal", + ), + pytest.param( + [Decimal("1.0"), Decimal("2.0"), Decimal("3.0"), Decimal("4.0")], + 1.0, + 2.0, + [True, True, False, False], + id="DecimalFloatFloat", + ), pytest.param([1.0, 2.0, 3.0, 4.0], 1.0, 2.0, [True, True, False, False], id="FloatFloatFloat"), pytest.param([1.0, 2.0, 3.0, 4.0], 1, 2, [True, True, False, False], id="FloatIntInt"), pytest.param([1.0, 2.0, 3.0, 4.0], 1, 2.0, [True, True, False, False], id="FloatIntFloat"), @@ -43,6 +65,20 @@ def test_table_expr_between_scalars(value, lower, upper, expected) -> None: pytest.param( [1, 2, 3, 4], [1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [True, True, False, False], id="IntFloatFloat" ), + pytest.param( + [Decimal("1.0"), Decimal("2.0"), Decimal("3.0"), Decimal("4.0")], + [Decimal("1.0"), Decimal("1.0"), Decimal("1.0"), Decimal("1.0")], + [Decimal("2.0"), Decimal("2.0"), Decimal("2.0"), Decimal("2.0")], + [True, True, False, False], + id="DecimalDecimalDecimal", + ), + pytest.param( + [Decimal("1.0"), Decimal("2.0"), Decimal("3.0"), Decimal("4.0")], + [1, 1, 1, 1], + [2.0, 2.0, 2.0, 2.0], + [True, True, False, False], + id="DecimalIntFloat", + ), pytest.param([1, 2, 3, 4], [1, 1, 1, 1], [2.0, 2.0, 2.0, 2.0], [True, True, False, False], id="IntIntFloat"), pytest.param( [None, None, None, None],