[PERF] Iceberg Truncate Transform (#1783)
* Implements the Iceberg truncate transform for integer, decimal, and string columns, following the spec here:
  https://iceberg.apache.org/spec/#truncate-transform-details
samster25 authored Jan 13, 2024
1 parent 1870c84 commit 447ea5a
Showing 14 changed files with 239 additions and 8 deletions.
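
The truncate transform maps each value to the lower bound of its width-W bucket (for strings, a length-W prefix), so nearby values share a partition. As a rough guide to the semantics implemented below, here is a minimal Python sketch of the spec's rules (not Daft's API; the helper names are illustrative):

from decimal import Decimal

# Illustrative sketch of the Iceberg spec's truncate rules (not Daft code).
def truncate_int(v: int, w: int) -> int:
    # Spec: v - (v % w), where the remainder must be non-negative.
    # Python's % already yields a non-negative remainder for positive w.
    return v - (v % w)

def truncate_decimal(v: Decimal, w: int) -> Decimal:
    # Truncation applies to the unscaled value, so the scale is preserved.
    exponent = v.as_tuple().exponent
    unscaled = int(v.scaleb(-exponent))
    return Decimal(truncate_int(unscaled, w)).scaleb(exponent)

def truncate_str(v: str, w: int) -> str:
    # Strings are truncated to at most w characters.
    return v[:w]

assert [truncate_int(v, 10) for v in (1, 11, -1, -11)] == [0, 10, -10, -20]
assert truncate_decimal(Decimal("12.34"), 10) == Decimal("12.30")
assert truncate_str("abcdefg", 5) == "abcde"

These are exactly the cases exercised by the new tests at the bottom of this diff.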
3 changes: 3 additions & 0 deletions daft/daft.pyi
@@ -609,6 +609,8 @@ class PartitionTransform:
    def hour() -> PartitionTransform: ...
    @staticmethod
    def iceberg_bucket(n: int) -> PartitionTransform: ...
    @staticmethod
    def iceberg_truncate(w: int) -> PartitionTransform: ...

class Pushdowns:
"""
@@ -933,6 +935,7 @@ class PySeries:
    def partitioning_months(self) -> PySeries: ...
    def partitioning_years(self) -> PySeries: ...
    def partitioning_iceberg_bucket(self, n: int) -> PySeries: ...
    def partitioning_iceberg_truncate(self, w: int) -> PySeries: ...
    def list_lengths(self) -> PySeries: ...
    def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
    def image_decode(self) -> PySeries: ...
4 changes: 4 additions & 0 deletions daft/iceberg/iceberg_scan.py
@@ -49,6 +49,7 @@ def _iceberg_partition_field_to_daft_partition_field(
        HourTransform,
        IdentityTransform,
        MonthTransform,
        TruncateTransform,
        YearTransform,
    )

@@ -66,6 +67,9 @@ def _iceberg_partition_field_to_daft_partition_field(
    elif isinstance(transform, BucketTransform):
        n = transform.num_buckets
        tfm = PartitionTransform.iceberg_bucket(n)
    elif isinstance(transform, TruncateTransform):
        w = transform.width
        tfm = PartitionTransform.iceberg_truncate(w)
    else:
        warnings.warn(f"{transform} not implemented, Please make an issue!")
    return make_partition_field(result_field, daft_field, transform=tfm)
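
With this branch, PyIceberg's TruncateTransform no longer falls through to the not-implemented warning below it; the transform's width is forwarded to Daft's PartitionTransform.iceberg_truncate, giving truncate-partitioned Iceberg tables a partition field that the scan planner can use for pruning.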
3 changes: 3 additions & 0 deletions daft/series.py
@@ -605,6 +605,9 @@ def years(self) -> Series:
    def iceberg_bucket(self, n: int) -> Series:
        return Series._from_pyseries(self._series.partitioning_iceberg_bucket(n))

    def iceberg_truncate(self, w: int) -> Series:
        return Series._from_pyseries(self._series.partitioning_iceberg_truncate(w))


class SeriesListNamespace(SeriesNamespace):
    def lengths(self) -> Series:
2 changes: 1 addition & 1 deletion src/daft-core/Cargo.toml
@@ -1,5 +1,5 @@
[dependencies]
-arrow2 = {workspace = true, features = ["chrono-tz", "compute_take", "compute_cast", "compute_aggregate", "compute_if_then_else", "compute_sort", "compute_filter", "compute_temporal", "compute_comparison", "compute_arithmetics", "compute_concatenate", "io_ipc"]}
+arrow2 = {workspace = true, features = ["chrono-tz", "compute_take", "compute_cast", "compute_aggregate", "compute_if_then_else", "compute_sort", "compute_filter", "compute_temporal", "compute_comparison", "compute_arithmetics", "compute_concatenate", "compute_substring", "io_ipc"]}
base64 = "0.21.5"
bincode = {workspace = true}
chrono = {workspace = true}
1 change: 1 addition & 0 deletions src/daft-core/src/array/ops/mod.rs
@@ -34,6 +34,7 @@ mod struct_;
mod sum;
mod take;
pub(crate) mod tensor;
mod truncate;
mod utf8;

pub use sort::{build_multi_array_bicompare, build_multi_array_compare};
70 changes: 70 additions & 0 deletions src/daft-core/src/array/ops/truncate.rs
@@ -0,0 +1,70 @@
use std::ops::Rem;

use common_error::DaftResult;
use num_traits::ToPrimitive;

use crate::{
    array::DataArray,
    datatypes::{
        logical::Decimal128Array, DaftNumericType, Int16Type, Int32Type, Int64Type, Int8Type,
        UInt16Type, UInt32Type, UInt64Type, UInt8Type, Utf8Array,
    },
};

use super::as_arrow::AsArrow;

macro_rules! impl_int_truncate {
    ($DT:ty) => {
        impl DataArray<$DT> {
            pub fn iceberg_truncate(&self, w: i64) -> DaftResult<DataArray<$DT>> {
                let as_arrowed = self.as_arrow();

                let trun_value = as_arrowed.into_iter().map(|v| {
                    v.map(|v| {
                        let i = v.to_i64().unwrap();
                        let t = (i - (((i.rem(w)) + w).rem(w)));
                        t as <$DT as DaftNumericType>::Native
                    })
                });
                let array = Box::new(arrow2::array::PrimitiveArray::from_iter(trun_value));
                Ok(<DataArray<$DT>>::from((self.name(), array)))
            }
        }
    };
}

impl_int_truncate!(Int8Type);
impl_int_truncate!(Int16Type);
impl_int_truncate!(Int32Type);
impl_int_truncate!(Int64Type);

impl_int_truncate!(UInt8Type);
impl_int_truncate!(UInt16Type);
impl_int_truncate!(UInt32Type);
impl_int_truncate!(UInt64Type);

impl Decimal128Array {
    pub fn iceberg_truncate(&self, w: i64) -> DaftResult<Decimal128Array> {
        let as_arrow = self.as_arrow();
        let trun_value = as_arrow.into_iter().map(|v| {
            v.map(|i| {
                let w = w as i128;
                let remainder = ((i.rem(w)) + w).rem(w);
                i - remainder
            })
        });
        let array = Box::new(arrow2::array::PrimitiveArray::from_iter(trun_value));
        Ok(Decimal128Array::new(
            self.field.clone(),
            DataArray::from((self.name(), array)),
        ))
    }
}

impl Utf8Array {
    pub fn iceberg_truncate(&self, w: i64) -> DaftResult<Utf8Array> {
        let as_arrow = self.as_arrow();
        let substring = arrow2::compute::substring::utf8_substring(as_arrow, 0, &Some(w));
        Ok(Utf8Array::from((self.name(), Box::new(substring))))
    }
}
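
A note on the double rem above: Rust's % truncates toward zero, so i % w is negative for negative i, while the spec requires a non-negative remainder. ((i % w) + w) % w maps the remainder into [0, w): for i = -11, w = 10, we get -11 % 10 = -1, then (-1 + 10) % 10 = 9, and -11 - 9 = -20. The string path reuses arrow2's utf8_substring to take the leading w characters, which is why compute_substring was added to Cargo.toml above.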
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
@@ -308,6 +308,10 @@ impl PySeries {
        Ok(self.series.partitioning_iceberg_bucket(n)?.into())
    }

    pub fn partitioning_iceberg_truncate(&self, w: i64) -> PyResult<Self> {
        Ok(self.series.partitioning_iceberg_truncate(w)?.into())
    }

    pub fn murmur3_32(&self) -> PyResult<Self> {
        Ok(self.series.murmur3_32()?.into_series().into())
    }
19 changes: 19 additions & 0 deletions src/daft-core/src/series/ops/partitioning.rs
@@ -2,6 +2,7 @@ use crate::array::ops::as_arrow::AsArrow;
use crate::datatypes::logical::TimestampArray;
use crate::datatypes::{Int32Array, Int64Array, TimeUnit};
use crate::series::array_impl::IntoSeries;
use crate::with_match_integer_daft_types;
use crate::{datatypes::DataType, series::Series};
use common_error::{DaftError, DaftResult};

@@ -104,4 +105,22 @@ impl Series {
        let array = Box::new(arrow2::array::Int32Array::from_iter(buckets));
        Ok(Int32Array::from((self.name(), array)).into_series())
    }

    pub fn partitioning_iceberg_truncate(&self, w: i64) -> DaftResult<Self> {
        assert!(w > 0, "Expected w to be positive, got {w}");
        match self.data_type() {
            i if i.is_integer() => {
                with_match_integer_daft_types!(i, |$T| {
                    let downcasted = self.downcast::<<$T as DaftDataType>::ArrayType>()?;
                    Ok(downcasted.iceberg_truncate(w)?.into_series())
                })
            }
            DataType::Decimal128(..) => Ok(self.decimal128()?.iceberg_truncate(w)?.into_series()),
            DataType::Utf8 => Ok(self.utf8()?.iceberg_truncate(w)?.into_series()),
            _ => Err(DaftError::ComputeError(format!(
                "Can only run partitioning_iceberg_truncate() operation on integers, decimal and string, got {}",
                self.data_type()
            ))),
        }
    }
}
51 changes: 50 additions & 1 deletion src/daft-dsl/src/functions/partitioning/evaluators.rs
@@ -97,7 +97,7 @@ impl FunctionEvaluator for IcebergBucketEvaluator {
                func: FunctionExpr::Partitioning(PartitioningExpr::IcebergBucket(n)),
                inputs: _,
            } => n,
-            _ => panic!("Expected Url Download Expr, got {expr}"),
+            _ => panic!("Expected PartitioningExpr::IcebergBucket Expr, got {expr}"),
        };

        match inputs {
@@ -109,3 +109,52 @@
        }
    }
}

pub(super) struct IcebergTruncateEvaluator {}

impl FunctionEvaluator for IcebergTruncateEvaluator {
    fn fn_name(&self) -> &'static str {
        "partitioning_iceberg_truncate"
    }

    fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
        match inputs {
            [input] => match input.to_field(schema) {
                Ok(field) => match &field.dtype {
                    DataType::Decimal128(_, _)
                    | DataType::Utf8 => Ok(field.clone()),
                    v if v.is_integer() => Ok(field.clone()),
                    _ => Err(DaftError::TypeError(format!(
                        "Expected input to IcebergTruncate to be an Integer, Utf8 or Decimal, got {}",
                        field.dtype
                    ))),
                },
                Err(e) => Err(e),
            },
            _ => Err(DaftError::SchemaMismatch(format!(
                "Expected 1 input arg, got {}",
                inputs.len()
            ))),
        }
    }

    fn evaluate(&self, inputs: &[Series], expr: &Expr) -> DaftResult<Series> {
        use crate::functions::FunctionExpr;

        let w = match expr {
            Expr::Function {
                func: FunctionExpr::Partitioning(PartitioningExpr::IcebergTruncate(w)),
                inputs: _,
            } => w,
            _ => panic!("Expected PartitioningExpr::IcebergTruncate Expr, got {expr}"),
        };

        match inputs {
            [input] => input.partitioning_iceberg_truncate(*w),
            _ => Err(DaftError::ValueError(format!(
                "Expected 1 input arg, got {}",
                inputs.len()
            ))),
        }
    }
}
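
IcebergTruncateEvaluator mirrors IcebergBucketEvaluator above: to_field validates at planning time that the input is an integer, Utf8, or Decimal128 column and passes the field through unchanged (truncation preserves the input type), while evaluate recovers the width w from the expression node and delegates to Series::partitioning_iceberg_truncate.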
12 changes: 11 additions & 1 deletion src/daft-dsl/src/functions/partitioning/mod.rs
@@ -4,7 +4,8 @@ use serde::{Deserialize, Serialize};

use crate::{
    functions::partitioning::evaluators::{
        DaysEvaluator, HoursEvaluator, IcebergBucketEvaluator, IcebergTruncateEvaluator,
        MonthsEvaluator, YearsEvaluator,
    },
    Expr,
};
@@ -18,6 +19,7 @@ pub enum PartitioningExpr {
    Days,
    Hours,
    IcebergBucket(i32),
    IcebergTruncate(i64),
}

impl PartitioningExpr {
@@ -30,6 +32,7 @@ impl PartitioningExpr {
            Days => &DaysEvaluator {},
            Hours => &HoursEvaluator {},
            IcebergBucket(..) => &IcebergBucketEvaluator {},
            IcebergTruncate(..) => &IcebergTruncateEvaluator {},
        }
    }
}
@@ -68,3 +71,10 @@ pub fn iceberg_bucket(input: Expr, n: i32) -> Expr {
        inputs: vec![input],
    }
}

pub fn iceberg_truncate(input: Expr, w: i64) -> Expr {
    Expr::Function {
        func: super::FunctionExpr::Partitioning(PartitioningExpr::IcebergTruncate(w)),
        inputs: vec![input],
    }
}
8 changes: 5 additions & 3 deletions src/daft-scan/src/expr_rewriter.rs
@@ -41,6 +41,10 @@ fn apply_partitioning_expr(expr: Expr, pfield: &PartitionField) -> Option<Expr>
            expr.cast(&pfield.source_field.as_ref().unwrap().dtype),
            n as i32,
        )),
        Some(IcebergTruncate(w)) => Some(partitioning::iceberg_truncate(
            expr.cast(&pfield.source_field.as_ref().unwrap().dtype),
            w as i64,
        )),
        _ => None,
    }
}
@@ -114,16 +118,14 @@ pub fn rewrite_predicate_for_partitioning(
            Expr::BinaryOp {
                op,
                ref left, ref right } if matches!(op, Lt | LtEq | Gt | GtEq)=> {
-                use PartitionTransform::*;
-
                let relaxed_op = match op {
                    Lt | LtEq => LtEq,
                    Gt | GtEq => GtEq,
                    _ => unreachable!("this branch only supports Lt | LtEq | Gt | GtEq")
                };

                if let Expr::Column(col_name) = left.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) {
-                    if let Some(tfm) = pfield.transform && tfm.supports_comparison() && matches!(tfm, Year | Month | Hour | Day) && let Some(new_expr) = apply_partitioning_expr(right.as_ref().clone(), pfield) {
+                    if let Some(tfm) = pfield.transform && tfm.supports_comparison() && let Some(new_expr) = apply_partitioning_expr(right.as_ref().clone(), pfield) {
                        return Ok(Transformed::Yes(Expr::BinaryOp { op: relaxed_op, left: col(pfield.field.name.as_str()).into(), right: new_expr.into() }));
                    }
                    Ok(Transformed::No(expr))
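
Dropping the matches!(tfm, Year | Month | Hour | Day) allowlist means any transform whose supports_comparison() returns true (which should now include IcebergTruncate, since truncation preserves ordering) can have range predicates rewritten onto the partition column. For example, with a truncate(10) partition on x, the filter x < 25 can be relaxed to truncate(x) <= truncate(25), i.e. <= 20: every row satisfying x < 25 lives in a partition whose truncated value is at most 20.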
5 changes: 5 additions & 0 deletions src/daft-scan/src/python.rs
@@ -395,6 +395,11 @@ pub mod pylib {
            Ok(Self(crate::PartitionTransform::IcebergBucket(n)))
        }

        #[staticmethod]
        pub fn iceberg_truncate(n: u64) -> PyResult<Self> {
            Ok(Self(crate::PartitionTransform::IcebergTruncate(n)))
        }

        pub fn __repr__(&self) -> PyResult<String> {
            Ok(format!("{}", self.0))
        }
2 changes: 0 additions & 2 deletions tests/integration/iceberg/test_partition_pruning.py
@@ -158,7 +158,6 @@ def test_daft_iceberg_table_predicate_pushdown_on_letter(predicate, table, limit
    if limit:
        df = df.limit(limit)
    df.collect()
-
    daft_pandas = df.to_pandas()
    iceberg_pandas = tab.scan().to_arrow().to_pandas()
    iceberg_pandas = iceberg_pandas[predicate(iceberg_pandas["letter"])]
@@ -199,7 +198,6 @@ def test_daft_iceberg_table_predicate_pushdown_on_number(predicate, table, limit
    if limit:
        df = df.limit(limit)
    df.collect()
-
    daft_pandas = df.to_pandas()
    iceberg_pandas = tab.scan().to_arrow().to_pandas()
    iceberg_pandas = iceberg_pandas[predicate(iceberg_pandas["number"])]
63 changes: 63 additions & 0 deletions tests/series/test_partitioning.py
@@ -134,3 +134,66 @@ def test_iceberg_bucketing(input, n):
            assert seen[v] == b
        else:
            seen[v] = b


def test_iceberg_truncate_decimal():
    data = ["12.34", "12.30", "12.29", "0.05", "-0.05"]
    data = [Decimal(v) for v in data] + [None]
    expected = ["12.30", "12.30", "12.20", "0.00", "-0.10"]
    expected = [Decimal(v) for v in expected] + [None]

    s = Series.from_pylist(data)
    trunc = s.partitioning.iceberg_truncate(10)
    assert trunc.datatype() == s.datatype()
    assert trunc.to_pylist() == expected


@pytest.mark.parametrize(
    "dtype",
    [
        DataType.int8(),
        DataType.int16(),
        DataType.int32(),
        DataType.int64(),
    ],
)
def test_iceberg_truncate_signed_int(dtype):
    data = [0, 1, 5, 9, 10, 11, -1, -5, -10, -11, None]
    expected = [0, 0, 0, 0, 10, 10, -10, -10, -10, -20, None]

    s = Series.from_pylist(data).cast(dtype)
    trunc = s.partitioning.iceberg_truncate(10)
    assert trunc.datatype() == s.datatype()
    assert trunc.to_pylist() == expected


@pytest.mark.parametrize(
    "dtype",
    [
        DataType.uint8(),
        DataType.uint16(),
        DataType.uint32(),
        DataType.uint64(),
        DataType.int8(),
        DataType.int16(),
        DataType.int32(),
        DataType.int64(),
    ],
)
def test_iceberg_truncate_all_int(dtype):
    data = [0, 1, 5, 9, 10, 11, None]
    expected = [0, 0, 0, 0, 10, 10, None]

    s = Series.from_pylist(data).cast(dtype)
    trunc = s.partitioning.iceberg_truncate(10)
    assert trunc.datatype() == s.datatype()
    assert trunc.to_pylist() == expected


def test_iceberg_truncate_str():
    data = ["abcdefg", "abc", "abcde", None]
    expected = ["abcde", "abc", "abcde", None]
    s = Series.from_pylist(data)
    trunc = s.partitioning.iceberg_truncate(5)
    assert trunc.datatype() == s.datatype()
    assert trunc.to_pylist() == expected
