From 95612ba210f4ebeeb36352577259bcfbefe539fa Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Fri, 12 Jan 2024 17:12:38 -0800 Subject: [PATCH] enable trunc transform and remove bug --- daft/daft.pyi | 2 + daft/iceberg/iceberg_scan.py | 4 ++ .../src/functions/partitioning/evaluators.rs | 49 +++++++++++++++++++ .../src/functions/partitioning/mod.rs | 12 ++++- src/daft-scan/src/expr_rewriter.rs | 8 +-- src/daft-scan/src/python.rs | 5 ++ .../iceberg/test_partition_pruning.py | 2 - 7 files changed, 76 insertions(+), 6 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index b22a50d86c..f7ae41f19f 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -609,6 +609,8 @@ class PartitionTransform: def hour() -> PartitionTransform: ... @staticmethod def iceberg_bucket(n: int) -> PartitionTransform: ... + @staticmethod + def iceberg_truncate(w: int) -> PartitionTransform: ... class Pushdowns: """ diff --git a/daft/iceberg/iceberg_scan.py b/daft/iceberg/iceberg_scan.py index 3ee6259776..bb43d16392 100644 --- a/daft/iceberg/iceberg_scan.py +++ b/daft/iceberg/iceberg_scan.py @@ -49,6 +49,7 @@ def _iceberg_partition_field_to_daft_partition_field( HourTransform, IdentityTransform, MonthTransform, + TruncateTransform, YearTransform, ) @@ -66,6 +67,9 @@ def _iceberg_partition_field_to_daft_partition_field( elif isinstance(transform, BucketTransform): n = transform.num_buckets tfm = PartitionTransform.iceberg_bucket(n) + elif isinstance(transform, TruncateTransform): + w = transform.width + tfm = PartitionTransform.iceberg_truncate(w) else: warnings.warn(f"{transform} not implemented, Please make an issue!") return make_partition_field(result_field, daft_field, transform=tfm) diff --git a/src/daft-dsl/src/functions/partitioning/evaluators.rs b/src/daft-dsl/src/functions/partitioning/evaluators.rs index f88ddfed97..4442e3cabc 100644 --- a/src/daft-dsl/src/functions/partitioning/evaluators.rs +++ b/src/daft-dsl/src/functions/partitioning/evaluators.rs @@ -109,3 +109,52 @@ impl FunctionEvaluator for IcebergBucketEvaluator { } } } + +pub(super) struct IcebergTruncateEvaluator {} + +impl FunctionEvaluator for IcebergTruncateEvaluator { + fn fn_name(&self) -> &'static str { + "partitioning_iceberg_truncate" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [input] => match input.to_field(schema) { + Ok(field) => match &field.dtype { + DataType::Decimal128(_, _) + | DataType::Utf8 => Ok(field.clone()), + v if v.is_integer() => Ok(field.clone()), + _ => Err(DaftError::TypeError(format!( + "Expected input to IcebergTruncate to be an Integer, Utf8 or Decimal, got {}", + field.dtype + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], expr: &Expr) -> DaftResult { + use crate::functions::FunctionExpr; + + let w = match expr { + Expr::Function { + func: FunctionExpr::Partitioning(PartitioningExpr::IcebergTruncate(w)), + inputs: _, + } => w, + _ => panic!("Expected PartitioningExpr::IcebergTruncate Expr, got {expr}"), + }; + + match inputs { + [input] => input.partitioning_iceberg_truncate(*w), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/partitioning/mod.rs b/src/daft-dsl/src/functions/partitioning/mod.rs index f3c4d5be74..74f9dc5ecc 100644 --- a/src/daft-dsl/src/functions/partitioning/mod.rs +++ b/src/daft-dsl/src/functions/partitioning/mod.rs @@ -4,7 +4,8 @@ use serde::{Deserialize, Serialize}; use crate::{ functions::partitioning::evaluators::{ - DaysEvaluator, HoursEvaluator, IcebergBucketEvaluator, MonthsEvaluator, YearsEvaluator, + DaysEvaluator, HoursEvaluator, IcebergBucketEvaluator, IcebergTruncateEvaluator, + MonthsEvaluator, YearsEvaluator, }, Expr, }; @@ -18,6 +19,7 @@ pub enum PartitioningExpr { Days, Hours, IcebergBucket(i32), + IcebergTruncate(i64), } impl PartitioningExpr { @@ -30,6 +32,7 @@ impl PartitioningExpr { Days => &DaysEvaluator {}, Hours => &HoursEvaluator {}, IcebergBucket(..) => &IcebergBucketEvaluator {}, + IcebergTruncate(..) => &IcebergTruncateEvaluator {}, } } } @@ -68,3 +71,10 @@ pub fn iceberg_bucket(input: Expr, n: i32) -> Expr { inputs: vec![input], } } + +pub fn iceberg_truncate(input: Expr, w: i64) -> Expr { + Expr::Function { + func: super::FunctionExpr::Partitioning(PartitioningExpr::IcebergTruncate(w)), + inputs: vec![input], + } +} diff --git a/src/daft-scan/src/expr_rewriter.rs b/src/daft-scan/src/expr_rewriter.rs index 1fc74e99aa..84eb82f61f 100644 --- a/src/daft-scan/src/expr_rewriter.rs +++ b/src/daft-scan/src/expr_rewriter.rs @@ -41,6 +41,10 @@ fn apply_partitioning_expr(expr: Expr, pfield: &PartitionField) -> Option expr.cast(&pfield.source_field.as_ref().unwrap().dtype), n as i32, )), + Some(IcebergTruncate(w)) => Some(partitioning::iceberg_truncate( + expr.cast(&pfield.source_field.as_ref().unwrap().dtype), + w as i64, + )), _ => None, } } @@ -114,8 +118,6 @@ pub fn rewrite_predicate_for_partitioning( Expr::BinaryOp { op, ref left, ref right } if matches!(op, Lt | LtEq | Gt | GtEq)=> { - use PartitionTransform::*; - let relaxed_op = match op { Lt | LtEq => LtEq, Gt | GtEq => GtEq, @@ -123,7 +125,7 @@ pub fn rewrite_predicate_for_partitioning( }; if let Expr::Column(col_name) = left.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) { - if let Some(tfm) = pfield.transform && tfm.supports_comparison() && matches!(tfm, Year | Month | Hour | Day) && let Some(new_expr) = apply_partitioning_expr(right.as_ref().clone(), pfield) { + if let Some(tfm) = pfield.transform && tfm.supports_comparison() && let Some(new_expr) = apply_partitioning_expr(right.as_ref().clone(), pfield) { return Ok(Transformed::Yes(Expr::BinaryOp { op: relaxed_op, left: col(pfield.field.name.as_str()).into(), right: new_expr.into() })); } Ok(Transformed::No(expr)) diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index 73b1b0cf11..031cb901dd 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -395,6 +395,11 @@ pub mod pylib { Ok(Self(crate::PartitionTransform::IcebergBucket(n))) } + #[staticmethod] + pub fn iceberg_truncate(n: u64) -> PyResult { + Ok(Self(crate::PartitionTransform::IcebergTruncate(n))) + } + pub fn __repr__(&self) -> PyResult { Ok(format!("{}", self.0)) } diff --git a/tests/integration/iceberg/test_partition_pruning.py b/tests/integration/iceberg/test_partition_pruning.py index 22a63de860..d2239ea91e 100644 --- a/tests/integration/iceberg/test_partition_pruning.py +++ b/tests/integration/iceberg/test_partition_pruning.py @@ -158,7 +158,6 @@ def test_daft_iceberg_table_predicate_pushdown_on_letter(predicate, table, limit if limit: df = df.limit(limit) df.collect() - daft_pandas = df.to_pandas() iceberg_pandas = tab.scan().to_arrow().to_pandas() iceberg_pandas = iceberg_pandas[predicate(iceberg_pandas["letter"])] @@ -199,7 +198,6 @@ def test_daft_iceberg_table_predicate_pushdown_on_number(predicate, table, limit if limit: df = df.limit(limit) df.collect() - daft_pandas = df.to_pandas() iceberg_pandas = tab.scan().to_arrow().to_pandas() iceberg_pandas = iceberg_pandas[predicate(iceberg_pandas["number"])]