Skip to content

Commit

Permalink
[FEAT] Add floor division (#3064)
Browse files Browse the repository at this point in the history
Close #2418
  • Loading branch information
ConeyLiu authored Oct 26, 2024
1 parent 5b450fb commit 915467b
Show file tree
Hide file tree
Showing 13 changed files with 265 additions and 19 deletions.
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1339,6 +1339,7 @@ class PySeries:
def __ne__(self, other: PySeries) -> PySeries: ... # type: ignore[override]
def __rshift__(self, other: PySeries) -> PySeries: ...
def __lshift__(self, other: PySeries) -> PySeries: ...
def __floordiv__(self, other: PySeries) -> PySeries: ...
def take(self, idx: PySeries) -> PySeries: ...
def slice(self, start: int, end: int) -> PySeries: ...
def filter(self, mask: PySeries) -> PySeries: ...
Expand Down
10 changes: 10 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,16 @@ def __invert__(self) -> Expression:
expr = self._expr.__invert__()
return Expression._from_pyexpr(expr)

def __floordiv__(self, other: Expression) -> Expression:
"""Floor divides two numeric expressions (``e1 / e2``)"""
expr = Expression._to_expression(other)
return Expression._from_pyexpr(self._expr // expr._expr)

def __rfloordiv__(self, other: object) -> Expression:
"""Reverse floor divides two numeric expressions (``e2 / e1``)"""
expr = Expression._to_expression(other)
return Expression._from_pyexpr(expr._expr // self._expr)

def alias(self, name: builtins.str) -> Expression:
"""Gives the expression a new name, which is its column's name in the DataFrame schema and the name
by which subsequent expressions can refer to the results of this expression.
Expand Down
6 changes: 6 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,12 @@ def __xor__(self, other: object) -> Series:
assert self._series is not None and other._series is not None
return Series._from_pyseries(self._series ^ other._series)

def __floordiv__(self, other: object) -> Series:
if not isinstance(other, Series):
raise TypeError(f"expected another Series but got {type(other)}")
assert self._series is not None and other._series is not None
return Series._from_pyseries(self._series // other._series)

def count(self, mode: CountMode = CountMode.Valid) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.count(mode))
Expand Down
71 changes: 56 additions & 15 deletions src/daft-core/src/array/ops/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use common_error::{DaftError, DaftResult};
use super::{as_arrow::AsArrow, full::FullNull};
use crate::{
array::{DataArray, FixedSizeListArray},
datatypes::{DaftNumericType, DataType, Field, Float64Array, Int64Array, Utf8Array},
datatypes::{DaftNumericType, DataType, Field, Utf8Array},
kernels::utf8::add_utf8_arrays,
series::Series,
};
Expand Down Expand Up @@ -108,20 +108,6 @@ where
}
}

impl Div for &Float64Array {
type Output = DaftResult<Float64Array>;
fn div(self, rhs: Self) -> Self::Output {
arithmetic_helper(self, rhs, basic::div, |l, r| l / r)
}
}

impl Div for &Int64Array {
type Output = DaftResult<Int64Array>;
fn div(self, rhs: Self) -> Self::Output {
arithmetic_helper(self, rhs, basic::div, |l, r| l / r)
}
}

pub fn binary_with_nulls<T, F>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveArray<T>,
Expand Down Expand Up @@ -195,6 +181,61 @@ where
}
}

fn div_with_nulls<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
where
T: arrow2::types::NativeType + Div<Output = T>,
{
binary_with_nulls(lhs, rhs, |a, b| a / b)
}

impl<T> Div for &DataArray<T>
where
T: DaftNumericType,
T::Native: basic::NativeArithmetics,
{
type Output = DaftResult<DataArray<T>>;
fn div(self, rhs: Self) -> Self::Output {
if rhs.data().null_count() == 0 {
arithmetic_helper(self, rhs, basic::div, |l, r| l / r)
} else {
match (self.len(), rhs.len()) {
(a, b) if a == b => Ok(DataArray::from((
self.name(),
Box::new(div_with_nulls(self.as_arrow(), rhs.as_arrow())),
))),
// broadcast right path
(_, 1) => {
let opt_rhs = rhs.get(0);
match opt_rhs {
None => Ok(DataArray::full_null(
self.name(),
self.data_type(),
self.len(),
)),
Some(rhs) => self.apply(|lhs| lhs / rhs),
}
}
(1, _) => {
let opt_lhs = self.get(0);
Ok(match opt_lhs {
None => DataArray::full_null(rhs.name(), rhs.data_type(), rhs.len()),
Some(lhs) => {
let values_iter = rhs.as_arrow().iter().map(|v| v.map(|v| lhs / *v));
let arrow_array = unsafe {
PrimitiveArray::from_trusted_len_iter_unchecked(values_iter)
};
DataArray::from((self.name(), Box::new(arrow_array)))
}
})
}
(a, b) => Err(DaftError::ValueError(format!(
"Cannot apply operation on arrays of different lengths: {a} vs {b}"
))),
}
}
}
}

fn fixed_sized_list_arithmetic_helper<Kernel>(
lhs: &FixedSizeListArray,
rhs: &FixedSizeListArray,
Expand Down
11 changes: 11 additions & 0 deletions src/daft-core/src/datatypes/infer_datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,17 @@ impl<'a> InferDataType<'a> {
// membership checks (is_in) use equality checks, so we can use the same logic as comparison ops.
self.comparison_op(other)
}

pub fn floor_div(&self, other: &Self) -> DaftResult<DataType> {
try_numeric_supertype(self.0, other.0).or(match (self.0, other.0) {
#[cfg(feature = "python")]
(DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python),
_ => Err(DaftError::TypeError(format!(
"Cannot perform floor divide on types: {}, {}",
self, other
))),
})
}
}

impl<'a> Add for InferDataType<'a> {
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ impl PySeries {
Ok(self.series.shift_right(&other.series)?.into())
}

pub fn __floordiv__(&self, other: &Self) -> PyResult<Self> {
Ok(self.series.floor_div(&other.series)?.into())
}

pub fn ceil(&self) -> PyResult<Self> {
Ok(self.series.ceil()?.into())
}
Expand Down
51 changes: 49 additions & 2 deletions src/daft-core/src/series/ops/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
array::prelude::*,
datatypes::{InferDataType, Utf8Array},
series::{utils::cast::cast_downcast_op, IntoSeries, Series},
with_match_numeric_daft_types,
with_match_integer_daft_types, with_match_numeric_daft_types,
};

macro_rules! impl_arithmetic_ref_for_series {
Expand Down Expand Up @@ -308,6 +308,29 @@ impl Rem for &Series {
}
}
}

impl Series {
pub fn floor_div(&self, rhs: &Self) -> DaftResult<Self> {
let output_type = InferDataType::from(self.data_type())
.floor_div(&InferDataType::from(rhs.data_type()))?;
let lhs = self;
match &output_type {
#[cfg(feature = "python")]
DataType::Python => run_python_binary_operator_fn(lhs, rhs, "floordiv"),
output_type if output_type.is_integer() => {
with_match_integer_daft_types!(output_type, |$T| {
Ok(cast_downcast_op!(lhs, rhs, output_type, <$T as DaftDataType>::ArrayType, div)?.into_series())
})
}
output_type if output_type.is_numeric() => {
let div_floor = lhs.div(rhs)?.floor()?;
div_floor.cast(output_type)
}
_ => arithmetic_op_not_implemented!(self, "floor_div", rhs, output_type),
}
}
}

enum FixedSizeBinaryOp {
Add,
Sub,
Expand Down Expand Up @@ -383,7 +406,7 @@ mod tests {

use crate::{
array::ops::full::FullNull,
datatypes::{DataType, Float64Array, Int64Array, Utf8Array},
datatypes::{DataType, Float32Array, Float64Array, Int32Array, Int64Array, Utf8Array},
series::IntoSeries,
};

Expand Down Expand Up @@ -430,6 +453,30 @@ mod tests {
Ok(())
}
#[test]
fn floor_div_int_and_int() -> DaftResult<()> {
let a = Int32Array::from(("a", vec![1, 2, 3]));
let b = Int64Array::from(("b", vec![1, 2, 3]));
let c = a.into_series().floor_div(&(b.into_series()));
assert_eq!(*c?.data_type(), DataType::Int64);
Ok(())
}
#[test]
fn floor_div_int_and_float() -> DaftResult<()> {
let a = Int64Array::from(("a", vec![1, 2, 3]));
let b = Float64Array::from(("b", vec![1., 2., 3.]));
let c = a.into_series().floor_div(&(b.into_series()));
assert_eq!(*c?.data_type(), DataType::Float64);
Ok(())
}
#[test]
fn floor_div_float_and_float() -> DaftResult<()> {
let a = Float32Array::from(("b", vec![1., 2., 3.]));
let b = Float64Array::from(("b", vec![1., 2., 3.]));
let c = a.into_series().floor_div(&(b.into_series()));
assert_eq!(*c?.data_type(), DataType::Float64);
Ok(())
}
#[test]
fn rem_int_and_float() -> DaftResult<()> {
let a = Int64Array::from(("a", vec![1, 2, 3]));
let b = Float64Array::from(("b", vec![1., 2., 3.]));
Expand Down
4 changes: 3 additions & 1 deletion src/daft-dsl/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,9 @@ impl Expr {
Ok(Field::new(left_field.name.as_str(), result_type))
}
Operator::FloorDivide => {
unimplemented!()
let result_type = (InferDataType::from(&left_field.dtype)
.floor_div(&InferDataType::from(&right_field.dtype)))?;
Ok(Field::new(left_field.name.as_str(), result_type))
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,7 @@ impl SQLPlanner {
BinaryOperator::NotEq => Ok(Operator::NotEq),
BinaryOperator::And => Ok(Operator::And),
BinaryOperator::Or => Ok(Operator::Or),
BinaryOperator::DuckIntegerDivide => Ok(Operator::FloorDivide),
other => unsupported_sql_err!("Unsupported operator: '{other}'"),
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/daft-table/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ impl Table {
Plus => lhs + rhs,
Minus => lhs - rhs,
TrueDivide => lhs / rhs,
FloorDivide => lhs.floor_div(&rhs),
Multiply => lhs * rhs,
Modulus => lhs % rhs,
Lt => Ok(lhs.lt(&rhs)?.into_series()),
Expand All @@ -543,7 +544,6 @@ impl Table {
Xor => lhs.xor(&rhs),
ShiftLeft => lhs.shift_left(&rhs),
ShiftRight => lhs.shift_right(&rhs),
_ => panic!("{op:?} not supported"),
}
}
Function { func, inputs } => {
Expand Down
28 changes: 28 additions & 0 deletions tests/series/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def test_arithmetic_numbers_array(l_dtype, r_dtype) -> None:
assert div.name() == left.name()
assert div.to_pylist() == [1.0, 0.5, 3.0, None, None, None]

floor_div = left // right
assert floor_div.name() == left.name()
assert floor_div.to_pylist() == [1, 0, 3, None, None, None]

# mod = (l % r)
# assert mod.name() == l.name()
# assert mod.to_pylist() == [0, 2, 0, None, None, None]
Expand Down Expand Up @@ -67,6 +71,10 @@ def test_arithmetic_numbers_left_scalar(l_dtype, r_dtype) -> None:
assert div.name() == left.name()
assert div.to_pylist() == [1.0, 0.25, 1.0, 0.2, None, None]

floor_div = left // right
assert floor_div.name() == left.name()
assert floor_div.to_pylist() == [1, 0, 1, 0, None, None]

mod = left % right
assert mod.name() == left.name()
assert mod.to_pylist() == [0, 1, 0, 1, None, None]
Expand Down Expand Up @@ -97,6 +105,10 @@ def test_arithmetic_numbers_right_scalar(l_dtype, r_dtype) -> None:
assert div.name() == left.name()
assert div.to_pylist() == [1.0, 2.0, 3.0, None, 5.0, None]

floor_div = left // right
assert floor_div.name() == left.name()
assert floor_div.to_pylist() == [1, 2, 3, None, 5, None]

mod = left % right
assert mod.name() == left.name()
assert mod.to_pylist() == [0, 0, 0, None, 0, None]
Expand Down Expand Up @@ -127,6 +139,10 @@ def test_arithmetic_numbers_null_scalar(l_dtype, r_dtype) -> None:
assert div.name() == left.name()
assert div.to_pylist() == [None, None, None, None, None, None]

floor_div = left / right
assert floor_div.name() == left.name()
assert floor_div.to_pylist() == [None, None, None, None, None, None]

mod = left % right
assert mod.name() == left.name()
assert mod.to_pylist() == [None, None, None, None, None, None]
Expand Down Expand Up @@ -207,6 +223,9 @@ def test_comparisons_bad_right_value() -> None:
with pytest.raises(TypeError, match="another Series"):
left / right

with pytest.raises(TypeError, match="another Series"):
left // right

with pytest.raises(TypeError, match="another Series"):
left * right

Expand All @@ -233,6 +252,9 @@ def test_arithmetic_numbers_array_mismatch_length() -> None:
with pytest.raises(ValueError, match="different lengths"):
left / right

with pytest.raises(ValueError, match="different lengths"):
left // right

with pytest.raises(ValueError, match="different lengths"):
left % right

Expand Down Expand Up @@ -263,6 +285,11 @@ def __mod__(self, other):
other = 5
return 5 % other

def __floordiv__(self, other):
if isinstance(other, FakeFive):
other = 5
return 5 // other


@pytest.mark.parametrize(
["op", "expected_datatype", "expected", "expected_self"],
Expand All @@ -272,6 +299,7 @@ def __mod__(self, other):
(operator.mul, DataType.int64(), [10, None, None], [25, 25, None]),
(operator.truediv, DataType.float64(), [2.5, None, None], [1.0, 1.0, None]),
(operator.mod, DataType.int64(), [1, None, None], [0, 0, None]),
(operator.floordiv, DataType.int64(), [2, None, None], [1.0, 1.0, None]),
],
)
def test_arithmetic_pyobjects(op, expected_datatype, expected, expected_self) -> None:
Expand Down
Loading

0 comments on commit 915467b

Please sign in to comment.