Skip to content

Commit

Permalink
enable trunc transform and remove bug
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Jan 13, 2024
1 parent a41ef22 commit 95612ba
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 6 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,8 @@ class PartitionTransform:
def hour() -> PartitionTransform: ...
# Builds the Iceberg bucket partition transform; `n` is the bucket count
# (the Iceberg scan code passes `BucketTransform.num_buckets` here).
@staticmethod
def iceberg_bucket(n: int) -> PartitionTransform: ...
# Builds the Iceberg truncate partition transform; `w` is the truncation width
# (the Iceberg scan code passes `TruncateTransform.width` here).
@staticmethod
def iceberg_truncate(w: int) -> PartitionTransform: ...

class Pushdowns:
"""
Expand Down
4 changes: 4 additions & 0 deletions daft/iceberg/iceberg_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def _iceberg_partition_field_to_daft_partition_field(
HourTransform,
IdentityTransform,
MonthTransform,
TruncateTransform,
YearTransform,
)

Expand All @@ -66,6 +67,9 @@ def _iceberg_partition_field_to_daft_partition_field(
elif isinstance(transform, BucketTransform):
n = transform.num_buckets
tfm = PartitionTransform.iceberg_bucket(n)
elif isinstance(transform, TruncateTransform):
w = transform.width
tfm = PartitionTransform.iceberg_truncate(w)

Check warning on line 72 in daft/iceberg/iceberg_scan.py

View check run for this annotation

Codecov / codecov/patch

daft/iceberg/iceberg_scan.py#L70-L72

Added lines #L70 - L72 were not covered by tests
else:
warnings.warn(f"{transform} not implemented, Please make an issue!")
return make_partition_field(result_field, daft_field, transform=tfm)
Expand Down
49 changes: 49 additions & 0 deletions src/daft-dsl/src/functions/partitioning/evaluators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,52 @@ impl FunctionEvaluator for IcebergBucketEvaluator {
}
}
}

/// Evaluator for the `PartitioningExpr::IcebergTruncate` function expression.
pub(super) struct IcebergTruncateEvaluator {}

impl FunctionEvaluator for IcebergTruncateEvaluator {
    /// Returns the name this function is registered under.
    fn fn_name(&self) -> &'static str {
        "partitioning_iceberg_truncate"
    }

    /// Resolves the output field for a truncate over `inputs`.
    ///
    /// The output field is the input field unchanged (same name and dtype).
    ///
    /// # Errors
    ///
    /// * `DaftError::SchemaMismatch` when the number of inputs is not exactly one.
    /// * `DaftError::TypeError` when the input dtype is not an integer, `Utf8`
    ///   or `Decimal128`.
    /// * Any error produced while resolving the input expression's own field.
    fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
        // Arity is validated before the input expression is resolved.
        let [input] = inputs else {
            return Err(DaftError::SchemaMismatch(format!(
                "Expected 1 input arg, got {}",
                inputs.len()
            )));
        };

        let field = input.to_field(schema)?;

        let truncatable = matches!(&field.dtype, DataType::Decimal128(_, _) | DataType::Utf8)
            || field.dtype.is_integer();
        if !truncatable {
            return Err(DaftError::TypeError(format!(
                "Expected input to IcebergTruncate to be an Integer, Utf8 or Decimal, got {}",
                field.dtype
            )));
        }

        Ok(field)
    }

    /// Applies the truncate transform to the single input series.
    ///
    /// The width is read from the `IcebergTruncate` variant carried by `expr`.
    ///
    /// # Panics
    ///
    /// Panics when `expr` is not a `PartitioningExpr::IcebergTruncate` function
    /// expression; the dispatcher is expected to only route that variant here.
    ///
    /// # Errors
    ///
    /// Returns `DaftError::ValueError` when the number of inputs is not exactly one.
    fn evaluate(&self, inputs: &[Series], expr: &Expr) -> DaftResult<Series> {
        use crate::functions::FunctionExpr;

        // The expression is inspected first, so a mismatched expression panics
        // even when the input count is also wrong.
        let Expr::Function {
            func: FunctionExpr::Partitioning(PartitioningExpr::IcebergTruncate(width)),
            ..
        } = expr
        else {
            panic!("Expected PartitioningExpr::IcebergTruncate Expr, got {expr}");
        };

        if let [series] = inputs {
            series.partitioning_iceberg_truncate(*width)
        } else {
            Err(DaftError::ValueError(format!(
                "Expected 1 input arg, got {}",
                inputs.len()
            )))
        }
    }
}
12 changes: 11 additions & 1 deletion src/daft-dsl/src/functions/partitioning/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use serde::{Deserialize, Serialize};

use crate::{
functions::partitioning::evaluators::{
DaysEvaluator, HoursEvaluator, IcebergBucketEvaluator, MonthsEvaluator, YearsEvaluator,
DaysEvaluator, HoursEvaluator, IcebergBucketEvaluator, IcebergTruncateEvaluator,
MonthsEvaluator, YearsEvaluator,
},
Expr,
};
Expand All @@ -18,6 +19,7 @@ pub enum PartitioningExpr {
Days,
Hours,
IcebergBucket(i32),
IcebergTruncate(i64),
}

impl PartitioningExpr {
Expand All @@ -30,6 +32,7 @@ impl PartitioningExpr {
Days => &DaysEvaluator {},
Hours => &HoursEvaluator {},
IcebergBucket(..) => &IcebergBucketEvaluator {},
IcebergTruncate(..) => &IcebergTruncateEvaluator {},
}
}
}
Expand Down Expand Up @@ -68,3 +71,10 @@ pub fn iceberg_bucket(input: Expr, n: i32) -> Expr {
inputs: vec![input],
}
}

/// Wraps `input` in an Iceberg truncate partitioning function expression
/// carrying the width `w`.
pub fn iceberg_truncate(input: Expr, w: i64) -> Expr {
    let func = super::FunctionExpr::Partitioning(PartitioningExpr::IcebergTruncate(w));
    let inputs = vec![input];
    Expr::Function { func, inputs }
}
8 changes: 5 additions & 3 deletions src/daft-scan/src/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ fn apply_partitioning_expr(expr: Expr, pfield: &PartitionField) -> Option<Expr>
expr.cast(&pfield.source_field.as_ref().unwrap().dtype),
n as i32,
)),
Some(IcebergTruncate(w)) => Some(partitioning::iceberg_truncate(
expr.cast(&pfield.source_field.as_ref().unwrap().dtype),
w as i64,
)),
_ => None,
}
}
Expand Down Expand Up @@ -114,16 +118,14 @@ pub fn rewrite_predicate_for_partitioning(
Expr::BinaryOp {
op,
ref left, ref right } if matches!(op, Lt | LtEq | Gt | GtEq)=> {
use PartitionTransform::*;

let relaxed_op = match op {
Lt | LtEq => LtEq,
Gt | GtEq => GtEq,
_ => unreachable!("this branch only supports Lt | LtEq | Gt | GtEq")
};

if let Expr::Column(col_name) = left.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) {
if let Some(tfm) = pfield.transform && tfm.supports_comparison() && matches!(tfm, Year | Month | Hour | Day) && let Some(new_expr) = apply_partitioning_expr(right.as_ref().clone(), pfield) {
if let Some(tfm) = pfield.transform && tfm.supports_comparison() && let Some(new_expr) = apply_partitioning_expr(right.as_ref().clone(), pfield) {
return Ok(Transformed::Yes(Expr::BinaryOp { op: relaxed_op, left: col(pfield.field.name.as_str()).into(), right: new_expr.into() }));
}
Ok(Transformed::No(expr))
Expand Down
5 changes: 5 additions & 0 deletions src/daft-scan/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,11 @@ pub mod pylib {
Ok(Self(crate::PartitionTransform::IcebergBucket(n)))
}

// Python-facing constructor for the Iceberg truncate partition transform;
// `n` is forwarded unchanged into `PartitionTransform::IcebergTruncate`.
// NOTE(review): the Python stub (daft/daft.pyi) names this argument `w`
// while it is `n` here; confirm that keyword callers are not affected.
#[staticmethod]
pub fn iceberg_truncate(n: u64) -> PyResult<Self> {
    let transform = crate::PartitionTransform::IcebergTruncate(n);
    Ok(Self(transform))
}

pub fn __repr__(&self) -> PyResult<String> {
Ok(format!("{}", self.0))
}
Expand Down
2 changes: 0 additions & 2 deletions tests/integration/iceberg/test_partition_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def test_daft_iceberg_table_predicate_pushdown_on_letter(predicate, table, limit
if limit:
df = df.limit(limit)
df.collect()

daft_pandas = df.to_pandas()
iceberg_pandas = tab.scan().to_arrow().to_pandas()
iceberg_pandas = iceberg_pandas[predicate(iceberg_pandas["letter"])]
Expand Down Expand Up @@ -199,7 +198,6 @@ def test_daft_iceberg_table_predicate_pushdown_on_number(predicate, table, limit
if limit:
df = df.limit(limit)
df.collect()

daft_pandas = df.to_pandas()
iceberg_pandas = tab.scan().to_arrow().to_pandas()
iceberg_pandas = iceberg_pandas[predicate(iceberg_pandas["number"])]
Expand Down

0 comments on commit 95612ba

Please sign in to comment.