From e8bd953678114e2668e1a042bd62f16209a08a2e Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Tue, 27 Aug 2024 21:29:54 +0300 Subject: [PATCH 01/13] adding main function and tests --- .../datafusion/src/physical_plan/scan.rs | 245 +++++++++++++++++- crates/integrations/datafusion/src/table.rs | 24 +- .../tests/integration_datafusion_test.rs | 36 +++ 3 files changed, 299 insertions(+), 6 deletions(-) diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs index c50b32efb..24e20ad7b 100644 --- a/crates/integrations/datafusion/src/physical_plan/scan.rs +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -23,12 +23,17 @@ use datafusion::arrow::array::RecordBatch; use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; use datafusion::error::Result as DFResult; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::logical_expr::{BinaryExpr, Operator}; use datafusion::physical_expr::EquivalenceProperties; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ DisplayAs, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, }; +use datafusion::prelude::Expr; +use datafusion::scalar::ScalarValue; use futures::{Stream, TryStreamExt}; +use iceberg::expr::{Predicate, Reference}; +use iceberg::spec::Datum; use iceberg::table::Table; use crate::to_datafusion_error; @@ -44,17 +49,19 @@ pub(crate) struct IcebergTableScan { /// Stores certain, often expensive to compute, /// plan properties used in query optimization. plan_properties: PlanProperties, + predicates: Option, } impl IcebergTableScan { /// Creates a new [`IcebergTableScan`] object. - pub(crate) fn new(table: Table, schema: ArrowSchemaRef) -> Self { + pub(crate) fn new(table: Table, schema: ArrowSchemaRef, filters: &[Expr]) -> Self { let plan_properties = Self::compute_properties(schema.clone()); - + let predicates = convert_filters_to_predicate(filters); Self { table, schema, plan_properties, + predicates, } } @@ -100,7 +107,7 @@ impl ExecutionPlan for IcebergTableScan { _partition: usize, _context: Arc, ) -> DFResult { - let fut = get_batch_stream(self.table.clone()); + let fut = get_batch_stream(self.table.clone(), self.predicates.clone()); let stream = futures::stream::once(fut).try_flatten(); Ok(Box::pin(RecordBatchStreamAdapter::new( @@ -127,8 +134,13 @@ impl DisplayAs for IcebergTableScan { /// and then converts it into a stream of Arrow [`RecordBatch`]es. async fn get_batch_stream( table: Table, + predicates: Option, ) -> DFResult> + Send>>> { - let table_scan = table.scan().build().map_err(to_datafusion_error)?; + let mut scan_builder = table.scan(); + if let Some(pred) = predicates { + scan_builder = scan_builder.with_filter(pred); + } + let table_scan = scan_builder.build().map_err(to_datafusion_error)?; let stream = table_scan .to_arrow() @@ -138,3 +150,228 @@ async fn get_batch_stream( Ok(Box::pin(stream)) } + +/// convert DataFusion filters ([`Expr`]) to an iceberg [`Predicate`] +/// if none of the filters could be converted, return `None` +/// if the conversion was successful, return the converted predicates combined with an AND operator +fn convert_filters_to_predicate(filters: &[Expr]) -> Option { + filters + .iter() + .filter_map(expr_to_predicate) + .reduce(Predicate::and) +} + +/// Converts a DataFusion [`Expr`] to an Iceberg [`Predicate`]. +/// +/// This function handles the conversion of certain DataFusion expression types +/// to their corresponding Iceberg predicates. It supports the following cases: +/// +/// 1. Simple binary expressions (e.g., "column < value") +/// 2. Compound AND expressions (e.g., "x < 1 AND y > 10") +/// 3. Compound OR expressions (e.g., "x < 1 OR y > 10") +/// +/// For AND expressions, if one part of the expression can't be converted, +/// the function will still return a predicate for the part that can be converted. +/// For OR expressions, if any part can't be converted, the entire expression +/// will fail to convert. +/// +/// # Arguments +/// +/// * `expr` - A reference to a DataFusion [`Expr`] to be converted. +/// +/// # Returns +/// +/// * `Some(Predicate)` if the expression could be successfully converted. +/// * `None` if the expression couldn't be converted to an Iceberg predicate. +/// +fn expr_to_predicate(expr: &Expr) -> Option { + match expr { + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + match (left.as_ref(), op, right.as_ref()) { + // first option: x < 1 + (Expr::Column(col), op, Expr::Literal(lit)) => { + let reference = Reference::new(col.name.clone()); + let datum = scalar_value_to_datum(lit)?; + Some(binary_op_to_predicate(reference, op, datum)) + } + // second option (inner AND): x < 1 AND y > 10 + // if its an AND expression and one predicate fails, we can still go with the other one + (left_expr, Operator::And, right_expr) => { + let left_pred = expr_to_predicate(&left_expr.clone()); + let right_pred = expr_to_predicate(&right_expr.clone()); + match (left_pred, right_pred) { + (Some(left), Some(right)) => Some(Predicate::and(left, right)), + (Some(left), None) => Some(left), + (None, Some(right)) => Some(right), + (None, None) => None, + } + } + // third option (inner OR): x < 1 OR y > 10 + // if one is unsuported, we need to fail the predicate + (Expr::BinaryExpr(left_expr), Operator::Or, Expr::BinaryExpr(right_expr)) => { + let left_pred = expr_to_predicate(&Expr::BinaryExpr(left_expr.clone()))?; + let right_pred = expr_to_predicate(&Expr::BinaryExpr(right_expr.clone()))?; + Some(Predicate::or(left_pred, right_pred)) + } + _ => None, + } + } + _ => None, + } +} + +/// convert the data fusion Exp to an iceberg [`Predicate`] +fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) -> Predicate { + match op { + Operator::Eq => reference.equal_to(datum), + Operator::NotEq => reference.not_equal_to(datum), + Operator::Lt => reference.less_than(datum), + Operator::LtEq => reference.less_than_or_equal_to(datum), + Operator::Gt => reference.greater_than(datum), + Operator::GtEq => reference.greater_than_or_equal_to(datum), + _ => Predicate::AlwaysTrue, + } +} +/// convert a DataFusion scalar value to an iceberg [`Datum`] +fn scalar_value_to_datum(value: &ScalarValue) -> Option { + match value { + ScalarValue::Int8(Some(v)) => Some(Datum::long(*v as i64)), + ScalarValue::Int16(Some(v)) => Some(Datum::long(*v as i64)), + ScalarValue::Int32(Some(v)) => Some(Datum::long(*v as i64)), + ScalarValue::Int64(Some(v)) => Some(Datum::long(*v)), + ScalarValue::Float32(Some(v)) => Some(Datum::double(*v as f64)), + ScalarValue::Float64(Some(v)) => Some(Datum::double(*v)), + ScalarValue::Utf8(Some(v)) => Some(Datum::string(v.clone())), + // Add more cases as needed + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::DFSchema; + use datafusion::prelude::SessionContext; + + fn create_test_schema() -> DFSchema { + let arrow_schema = Schema::new(vec![ + Field::new("foo", DataType::Int32, false), + Field::new("bar", DataType::Utf8, false), + ]); + DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap() + } + fn create_test_schema_b() -> DFSchema { + let arrow_schema = Schema::new(vec![ + Field::new("xxx", DataType::Int32, false), + Field::new("yyy", DataType::Utf8, false), + Field::new("zzz", DataType::Int32, false), + ]); + DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap() + } + + #[test] + fn test_predicate_conversion_with_single_condition() { + let sql = "foo > 1"; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]).unwrap(); + assert_eq!( + predicate, + Reference::new("foo").greater_than(Datum::long(1)) + ); + } + + #[test] + fn test_predicate_conversion_with_multiple_conditions() { + let sql = "foo > 1 and bar = 'test'"; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]).unwrap(); + let inner_predicate = Predicate::and( + Reference::new("foo").greater_than(Datum::long(1)), + Reference::new("bar").equal_to(Datum::string("test")), + ); + assert_eq!(predicate, inner_predicate); + } + + #[test] + fn test_predicate_conversion_with_multiple_binary_expr() { + let sql = "(foo > 1 and bar = 'test') or foo < 0 "; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]).unwrap(); + let inner_predicate = Predicate::and( + Reference::new("foo").greater_than(Datum::long(1)), + Reference::new("bar").equal_to(Datum::string("test")), + ); + let expected_predicate = Predicate::or( + inner_predicate, + Reference::new("foo").less_than(Datum::long(0)), + ); + assert_eq!(predicate, expected_predicate); + } + + #[test] + fn test_predicate_conversion_with_unsupported_condition_not() { + let sql = "xxx > 1 and yyy is not null and zzz < 0 "; + let df_schema = create_test_schema_b(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]).unwrap(); + let expected_predicate = Predicate::and( + Reference::new("xxx").greater_than(Datum::long(1)), + Reference::new("zzz").less_than(Datum::long(0)), + ); + assert_eq!(predicate, expected_predicate); + } + + #[test] + fn test_predicate_conversion_with_unsupported_condition_and() { + let sql = "(xxx > 1 and yyy in ('test', 'test2')) and zzz < 0 "; + let df_schema = create_test_schema_b(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]).unwrap(); + let expected_predicate = Predicate::and( + Reference::new("xxx").greater_than(Datum::long(1)), + Reference::new("zzz").less_than(Datum::long(0)), + ); + assert_eq!(predicate, expected_predicate); + } + + #[test] + fn test_predicate_conversion_with_unsupported_condition_or() { + let sql = "(foo > 1 and bar in ('test', 'test2')) or foo < 0 "; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]).unwrap(); + let expected_predicate = Predicate::or( + Reference::new("foo").greater_than(Datum::long(1)), + Reference::new("foo").less_than(Datum::long(0)), + ); + assert_eq!(predicate, expected_predicate); + } + + #[test] + fn test_predicate_conversion_with_unsupported_expr() { + let sql = "(xxx > 1 or yyy in ('test', 'test2')) and zzz < 0 "; + let df_schema = create_test_schema_b(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]).unwrap(); + let expected_predicate = Reference::new("zzz").less_than(Datum::long(0)); + assert_eq!(predicate, expected_predicate); + } +} diff --git a/crates/integrations/datafusion/src/table.rs b/crates/integrations/datafusion/src/table.rs index 7ff7b2211..939236630 100644 --- a/crates/integrations/datafusion/src/table.rs +++ b/crates/integrations/datafusion/src/table.rs @@ -18,19 +18,19 @@ use std::any::Any; use std::sync::Arc; +use crate::physical_plan::scan::IcebergTableScan; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; use datafusion::catalog::Session; use datafusion::datasource::{TableProvider, TableType}; use datafusion::error::Result as DFResult; use datafusion::logical_expr::Expr; +use datafusion::logical_expr::{BinaryExpr, TableProviderFilterPushDown}; use datafusion::physical_plan::ExecutionPlan; use iceberg::arrow::schema_to_arrow_schema; use iceberg::table::Table; use iceberg::{Catalog, NamespaceIdent, Result, TableIdent}; -use crate::physical_plan::scan::IcebergTableScan; - /// Represents a [`TableProvider`] for the Iceberg [`Catalog`], /// managing access to a [`Table`]. pub(crate) struct IcebergTableProvider { @@ -82,6 +82,26 @@ impl TableProvider for IcebergTableProvider { Ok(Arc::new(IcebergTableScan::new( self.table.clone(), self.schema.clone(), + _filters, ))) } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> std::result::Result, datafusion::error::DataFusionError> + { + let filter_support = filters + .iter() + .map(|e| { + if let Expr::BinaryExpr(BinaryExpr { .. }) = e { + TableProviderFilterPushDown::Inexact + } else { + TableProviderFilterPushDown::Unsupported + } + }) + .collect::>(); + + Ok(filter_support) + } } diff --git a/crates/integrations/datafusion/tests/integration_datafusion_test.rs b/crates/integrations/datafusion/tests/integration_datafusion_test.rs index 9e62930fd..afb383904 100644 --- a/crates/integrations/datafusion/tests/integration_datafusion_test.rs +++ b/crates/integrations/datafusion/tests/integration_datafusion_test.rs @@ -147,3 +147,39 @@ async fn test_provider_list_schema_names() -> Result<()> { .all(|item| result.contains(&item.to_string()))); Ok(()) } +#[tokio::test] +async fn test_table_scan() -> Result<()> { + let iceberg_catalog = get_iceberg_catalog(); + let namespace = NamespaceIdent::new("test_provider_list_table_names".to_string()); + set_test_namespace(&iceberg_catalog, &namespace).await?; + let creation = set_table_creation(temp_path(), "my_table")?; + let new_table = iceberg_catalog.create_table(&namespace, creation).await?; + let client = Arc::new(iceberg_catalog); + let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?); + let ctx = SessionContext::new(); + + ctx.register_catalog("catalog", catalog); + let df = ctx + .sql("select * from catalog.test_provider_list_table_names.my_table where (foo > 1 and bar = 'test') or foo < 0 ") + .await + .unwrap(); + + let compute_result = df.collect().await; + if let Ok(df) = compute_result { + println!("==> compute_result OK: {:?}", df); + } else { + println!( + "==> compute_result ERROR: {:?}", + compute_result.err().unwrap() + ); + } + let provider = ctx.catalog("catalog").unwrap(); + let schema = provider.schema("test_provider_list_table_names").unwrap(); + + let expected = vec!["my_table"]; + let result = schema.table_names(); + + assert_eq!(result, expected); + + Ok(()) +} From 9a54ef972f5d2d7dae71f260aabed923a5f9770c Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Wed, 28 Aug 2024 09:45:48 +0300 Subject: [PATCH 02/13] adding tests, removing integration test for now --- .../datafusion/src/physical_plan/scan.rs | 17 +++++---- .../tests/integration_datafusion_test.rs | 36 ------------------- 2 files changed, 8 insertions(+), 45 deletions(-) diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs index 24e20ad7b..16fe4aa77 100644 --- a/crates/integrations/datafusion/src/physical_plan/scan.rs +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -161,10 +161,9 @@ fn convert_filters_to_predicate(filters: &[Expr]) -> Option { .reduce(Predicate::and) } -/// Converts a DataFusion [`Expr`] to an Iceberg [`Predicate`]. +/// Recuresivly converting DataFusion filters ( in a [`Expr`]) to an Iceberg [`Predicate`]. /// -/// This function handles the conversion of certain DataFusion expression types -/// to their corresponding Iceberg predicates. It supports the following cases: +/// This function currently handles the conversion of DataFusion expression of the following types: /// /// 1. Simple binary expressions (e.g., "column < value") /// 2. Compound AND expressions (e.g., "x < 1 AND y > 10") @@ -188,13 +187,13 @@ fn expr_to_predicate(expr: &Expr) -> Option { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) => { match (left.as_ref(), op, right.as_ref()) { - // first option: x < 1 + // First option arm (simple case), e.g. x < 1 (Expr::Column(col), op, Expr::Literal(lit)) => { let reference = Reference::new(col.name.clone()); let datum = scalar_value_to_datum(lit)?; Some(binary_op_to_predicate(reference, op, datum)) } - // second option (inner AND): x < 1 AND y > 10 + // Second option arm (inner AND), e.g. x < 1 AND y > 10 // if its an AND expression and one predicate fails, we can still go with the other one (left_expr, Operator::And, right_expr) => { let left_pred = expr_to_predicate(&left_expr.clone()); @@ -206,8 +205,8 @@ fn expr_to_predicate(expr: &Expr) -> Option { (None, None) => None, } } - // third option (inner OR): x < 1 OR y > 10 - // if one is unsuported, we need to fail the predicate + // Third option arm (inner OR), e.g. x < 1 OR y > 10 + // if one is unsuported, we fail the predicate (Expr::BinaryExpr(left_expr), Operator::Or, Expr::BinaryExpr(right_expr)) => { let left_pred = expr_to_predicate(&Expr::BinaryExpr(left_expr.clone()))?; let right_pred = expr_to_predicate(&Expr::BinaryExpr(right_expr.clone()))?; @@ -292,11 +291,11 @@ mod tests { .parse_sql_expr(sql, &df_schema) .unwrap(); let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let inner_predicate = Predicate::and( + let expected_predicate = Predicate::and( Reference::new("foo").greater_than(Datum::long(1)), Reference::new("bar").equal_to(Datum::string("test")), ); - assert_eq!(predicate, inner_predicate); + assert_eq!(predicate, expected_predicate); } #[test] diff --git a/crates/integrations/datafusion/tests/integration_datafusion_test.rs b/crates/integrations/datafusion/tests/integration_datafusion_test.rs index afb383904..9e62930fd 100644 --- a/crates/integrations/datafusion/tests/integration_datafusion_test.rs +++ b/crates/integrations/datafusion/tests/integration_datafusion_test.rs @@ -147,39 +147,3 @@ async fn test_provider_list_schema_names() -> Result<()> { .all(|item| result.contains(&item.to_string()))); Ok(()) } -#[tokio::test] -async fn test_table_scan() -> Result<()> { - let iceberg_catalog = get_iceberg_catalog(); - let namespace = NamespaceIdent::new("test_provider_list_table_names".to_string()); - set_test_namespace(&iceberg_catalog, &namespace).await?; - let creation = set_table_creation(temp_path(), "my_table")?; - let new_table = iceberg_catalog.create_table(&namespace, creation).await?; - let client = Arc::new(iceberg_catalog); - let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?); - let ctx = SessionContext::new(); - - ctx.register_catalog("catalog", catalog); - let df = ctx - .sql("select * from catalog.test_provider_list_table_names.my_table where (foo > 1 and bar = 'test') or foo < 0 ") - .await - .unwrap(); - - let compute_result = df.collect().await; - if let Ok(df) = compute_result { - println!("==> compute_result OK: {:?}", df); - } else { - println!( - "==> compute_result ERROR: {:?}", - compute_result.err().unwrap() - ); - } - let provider = ctx.catalog("catalog").unwrap(); - let schema = provider.schema("test_provider_list_table_names").unwrap(); - - let expected = vec!["my_table"]; - let result = schema.table_names(); - - assert_eq!(result, expected); - - Ok(()) -} From afc9c13b8da05ec764d183dc3a355952ab0c5ae6 Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Wed, 28 Aug 2024 11:01:33 +0300 Subject: [PATCH 03/13] fixing typos and lints --- crates/integrations/datafusion/src/physical_plan/scan.rs | 6 +++--- crates/integrations/datafusion/src/table.rs | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs index 16fe4aa77..1c572d6fc 100644 --- a/crates/integrations/datafusion/src/physical_plan/scan.rs +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -182,7 +182,6 @@ fn convert_filters_to_predicate(filters: &[Expr]) -> Option { /// /// * `Some(Predicate)` if the expression could be successfully converted. /// * `None` if the expression couldn't be converted to an Iceberg predicate. -/// fn expr_to_predicate(expr: &Expr) -> Option { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) => { @@ -206,7 +205,7 @@ fn expr_to_predicate(expr: &Expr) -> Option { } } // Third option arm (inner OR), e.g. x < 1 OR y > 10 - // if one is unsuported, we fail the predicate + // if one is unsupported, we fail the predicate (Expr::BinaryExpr(left_expr), Operator::Or, Expr::BinaryExpr(right_expr)) => { let left_pred = expr_to_predicate(&Expr::BinaryExpr(left_expr.clone()))?; let right_pred = expr_to_predicate(&Expr::BinaryExpr(right_expr.clone()))?; @@ -248,11 +247,12 @@ fn scalar_value_to_datum(value: &ScalarValue) -> Option { #[cfg(test)] mod tests { - use super::*; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::DFSchema; use datafusion::prelude::SessionContext; + use super::*; + fn create_test_schema() -> DFSchema { let arrow_schema = Schema::new(vec![ Field::new("foo", DataType::Int32, false), diff --git a/crates/integrations/datafusion/src/table.rs b/crates/integrations/datafusion/src/table.rs index 939236630..98386efe9 100644 --- a/crates/integrations/datafusion/src/table.rs +++ b/crates/integrations/datafusion/src/table.rs @@ -18,19 +18,19 @@ use std::any::Any; use std::sync::Arc; -use crate::physical_plan::scan::IcebergTableScan; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; use datafusion::catalog::Session; use datafusion::datasource::{TableProvider, TableType}; use datafusion::error::Result as DFResult; -use datafusion::logical_expr::Expr; -use datafusion::logical_expr::{BinaryExpr, TableProviderFilterPushDown}; +use datafusion::logical_expr::{BinaryExpr, Expr, TableProviderFilterPushDown}; use datafusion::physical_plan::ExecutionPlan; use iceberg::arrow::schema_to_arrow_schema; use iceberg::table::Table; use iceberg::{Catalog, NamespaceIdent, Result, TableIdent}; +use crate::physical_plan::scan::IcebergTableScan; + /// Represents a [`TableProvider`] for the Iceberg [`Catalog`], /// managing access to a [`Table`]. pub(crate) struct IcebergTableProvider { From 9f88a5a0cda9cb483bcf5c695e423678c3e9b2a8 Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Thu, 29 Aug 2024 08:42:01 +0300 Subject: [PATCH 04/13] fixing typing issue --- .../datafusion/src/physical_plan/scan.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs index 1c572d6fc..e9a569bb8 100644 --- a/crates/integrations/datafusion/src/physical_plan/scan.rs +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -233,15 +233,19 @@ fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) -> /// convert a DataFusion scalar value to an iceberg [`Datum`] fn scalar_value_to_datum(value: &ScalarValue) -> Option { match value { - ScalarValue::Int8(Some(v)) => Some(Datum::long(*v as i64)), - ScalarValue::Int16(Some(v)) => Some(Datum::long(*v as i64)), - ScalarValue::Int32(Some(v)) => Some(Datum::long(*v as i64)), + ScalarValue::Int8(Some(v)) => Some(Datum::int(*v)), + ScalarValue::Int16(Some(v)) => Some(Datum::int(*v)), + ScalarValue::Int32(Some(v)) => Some(Datum::int(*v)), ScalarValue::Int64(Some(v)) => Some(Datum::long(*v)), ScalarValue::Float32(Some(v)) => Some(Datum::double(*v as f64)), ScalarValue::Float64(Some(v)) => Some(Datum::double(*v)), ScalarValue::Utf8(Some(v)) => Some(Datum::string(v.clone())), + ScalarValue::LargeUtf8(Some(v)) => Some(Datum::string(v.clone())), // Add more cases as needed - _ => None, + _ => { + println!("unsupported scalar value: {:?}", value); + None + } } } From e473471af080430e7d796dbe9e2f7dde0895fd73 Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Sun, 1 Sep 2024 12:16:55 +0300 Subject: [PATCH 05/13] - added support in schmema to convert Date32 to correct arrow type - refactored scan to use new predicate converter as visitor and seperated it to a new mod - added support for simple predicates with column cast expressions - added testing, mostly around date functions --- crates/iceberg/src/arrow/schema.rs | 7 +- .../datafusion/src/physical_plan/mod.rs | 1 + .../src/physical_plan/predicate_converter.rs | 137 ++++++++++++ .../datafusion/src/physical_plan/scan.rs | 196 ++++++++---------- crates/integrations/datafusion/src/table.rs | 13 +- 5 files changed, 240 insertions(+), 114 deletions(-) create mode 100644 crates/integrations/datafusion/src/physical_plan/predicate_converter.rs diff --git a/crates/iceberg/src/arrow/schema.rs b/crates/iceberg/src/arrow/schema.rs index a41243756..1e48c4a3c 100644 --- a/crates/iceberg/src/arrow/schema.rs +++ b/crates/iceberg/src/arrow/schema.rs @@ -24,8 +24,8 @@ use arrow_array::types::{ validate_decimal_precision_and_scale, Decimal128Type, TimestampMicrosecondType, }; use arrow_array::{ - BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, Int64Array, - PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray, + BooleanArray, Date32Array, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, + Int64Array, PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray, }; use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit}; use bitvec::macros::internal::funty::Fundamental; @@ -634,6 +634,9 @@ pub(crate) fn get_arrow_datum(datum: &Datum) -> Result { Ok(Box::new(StringArray::new_scalar(value.as_str()))) } + (PrimitiveType::Date, PrimitiveLiteral::Int(value)) => { + Ok(Box::new(Date32Array::new_scalar(*value))) + } (PrimitiveType::Timestamp, PrimitiveLiteral::Long(value)) => { Ok(Box::new(TimestampMicrosecondArray::new_scalar(*value))) } diff --git a/crates/integrations/datafusion/src/physical_plan/mod.rs b/crates/integrations/datafusion/src/physical_plan/mod.rs index 5ae586a0a..af2e68f2e 100644 --- a/crates/integrations/datafusion/src/physical_plan/mod.rs +++ b/crates/integrations/datafusion/src/physical_plan/mod.rs @@ -16,3 +16,4 @@ // under the License. pub(crate) mod scan; +pub(crate) mod predicate_converter; diff --git a/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs b/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs new file mode 100644 index 000000000..ea5f42206 --- /dev/null +++ b/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs @@ -0,0 +1,137 @@ +use datafusion::logical_expr::{BinaryExpr, Cast, Operator}; +use datafusion::{arrow::datatypes::DataType, logical_expr::Expr, scalar::ScalarValue}; +use iceberg::expr::{Predicate, Reference}; +use iceberg::spec::Datum; +#[derive(Default)] +pub struct PredicateConverter; + +impl PredicateConverter { + /// Convert a list of DataFusion expressions to an iceberg predicate. + pub fn visit_many(&self, exprs: &[Expr]) -> Option { + exprs + .iter() + .filter_map(|expr| self.visit(expr)) + .reduce(Predicate::and) + } + + /// Convert a single DataFusion expression to an iceberg predicate. + /// currently only supports binary (simple) expressions + pub fn visit(&self, expr: &Expr) -> Option { + match expr { + Expr::BinaryExpr(binary) => self.visit_binary_expr(binary), + _ => None, + } + } + + /// Convert a binary expression to an iceberg predicate. + /// + /// currently supports: + /// - column, basic op, and literal, e.g. `a = 1` + /// - column and casted literal, e.g. `a = cast(1 as bigint)` + /// - binary conditional (and, or), e.g. `a = 1 and b = 2` + fn visit_binary_expr(&self, binary: &BinaryExpr) -> Option { + match (&*binary.left, &binary.op, &*binary.right) { + // column, op, literal + (Expr::Column(col), op, Expr::Literal(lit)) => self.visit_column_literal(col, op, lit), + // column, op, casted literal + (Expr::Column(col), op, Expr::Cast(Cast { expr, data_type })) => { + self.visit_column_cast(col, op, expr, data_type) + } + // binary conditional (and, or) + (left, op, right) if matches!(op, Operator::And | Operator::Or) => { + self.visit_binary_conditional(left, op, right) + } + _ => None, + } + } + + /// Convert a column and casted literal to an iceberg predicate. + /// The purpose of this function is to handle the common case in which there is a filter based on a casted literal. + /// These kinds of expressions are often not pushed down by query engines though its an important case to handle + /// for iceberg scan pushdown. + fn visit_column_cast( + &self, + col: &datafusion::common::Column, + op: &Operator, + expr: &Expr, + data_type: &DataType, + ) -> Option { + if let (Expr::Literal(ScalarValue::Utf8(lit)), DataType::Date32) = (expr, data_type) { + let reference = Reference::new(col.name.clone()); + let datum = lit + .clone() + .and_then(|date_str| Datum::date_from_str(date_str).ok())?; + return Some(binary_op_to_predicate(reference, op, datum)); + } + None + } + + /// Convert a binary conditional expression, i.e., (and, or), to an iceberg predicate. + /// + /// When processing an AND expression: + /// - if both expressions are valid predicates then an AND predicate is returned + /// - if either expression is None then the valid one is returned + /// + /// When processing an OR expression: + /// - only if both expressions are valid predicates then an OR predicate is returned + fn visit_binary_conditional( + &self, + left: &Expr, + op: &Operator, + right: &Expr, + ) -> Option { + let preds: Vec = vec![self.visit(left), self.visit(right)] + .into_iter() + .flatten() + .collect(); + match (op, preds.len()) { + (Operator::And, 1) => preds.first().cloned(), + (Operator::And, 2) => Some(Predicate::and(preds[0].clone(), preds[1].clone())), + (Operator::Or, 2) => Some(Predicate::or(preds[0].clone(), preds[1].clone())), + _ => None, + } + } + + /// Convert a simple expression based on column and literal (x > 1) to an iceberg predicate. + fn visit_column_literal( + &self, + col: &datafusion::common::Column, + op: &Operator, + lit: &ScalarValue, + ) -> Option { + let reference = Reference::new(col.name.clone()); + let datum = scalar_value_to_datum(lit)?; + Some(binary_op_to_predicate(reference, op, datum)) + } +} + +const MILLIS_PER_DAY: i64 = 24 * 60 * 60 * 1000; +/// Convert a scalar value to an iceberg datum. +fn scalar_value_to_datum(value: &ScalarValue) -> Option { + match value { + ScalarValue::Int8(Some(v)) => Some(Datum::int(*v as i32)), + ScalarValue::Int16(Some(v)) => Some(Datum::int(*v as i32)), + ScalarValue::Int32(Some(v)) => Some(Datum::int(*v)), + ScalarValue::Int64(Some(v)) => Some(Datum::long(*v)), + ScalarValue::Float32(Some(v)) => Some(Datum::double(*v as f64)), + ScalarValue::Float64(Some(v)) => Some(Datum::double(*v)), + ScalarValue::Utf8(Some(v)) => Some(Datum::string(v.clone())), + ScalarValue::LargeUtf8(Some(v)) => Some(Datum::string(v.clone())), + ScalarValue::Date32(Some(v)) => Some(Datum::date(*v)), + ScalarValue::Date64(Some(v)) => Some(Datum::date((*v / MILLIS_PER_DAY) as i32)), + _ => None, + } +} + +/// convert the data fusion Exp to an iceberg [`Predicate`] +fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) -> Predicate { + match op { + Operator::Eq => reference.equal_to(datum), + Operator::NotEq => reference.not_equal_to(datum), + Operator::Lt => reference.less_than(datum), + Operator::LtEq => reference.less_than_or_equal_to(datum), + Operator::Gt => reference.greater_than(datum), + Operator::GtEq => reference.greater_than_or_equal_to(datum), + _ => Predicate::AlwaysTrue, + } +} diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs index e9a569bb8..27f8b5672 100644 --- a/crates/integrations/datafusion/src/physical_plan/scan.rs +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -15,26 +15,23 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::pin::Pin; -use std::sync::Arc; - +use super::predicate_converter::PredicateConverter; use datafusion::arrow::array::RecordBatch; use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; use datafusion::error::Result as DFResult; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; -use datafusion::logical_expr::{BinaryExpr, Operator}; use datafusion::physical_expr::EquivalenceProperties; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ DisplayAs, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, }; use datafusion::prelude::Expr; -use datafusion::scalar::ScalarValue; use futures::{Stream, TryStreamExt}; -use iceberg::expr::{Predicate, Reference}; -use iceberg::spec::Datum; +use iceberg::expr::Predicate; use iceberg::table::Table; +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; use crate::to_datafusion_error; @@ -141,7 +138,6 @@ async fn get_batch_stream( scan_builder = scan_builder.with_filter(pred); } let table_scan = scan_builder.build().map_err(to_datafusion_error)?; - let stream = table_scan .to_arrow() .await @@ -151,102 +147,11 @@ async fn get_batch_stream( Ok(Box::pin(stream)) } -/// convert DataFusion filters ([`Expr`]) to an iceberg [`Predicate`] -/// if none of the filters could be converted, return `None` -/// if the conversion was successful, return the converted predicates combined with an AND operator +/// Converts DataFusion filters ([`Expr`]) to an iceberg [`Predicate`]. +/// If none of the filters could be converted, return `None` which adds no predicates to the scan operation. +/// If the conversion was successful, return the converted predicates combined with an AND operator. fn convert_filters_to_predicate(filters: &[Expr]) -> Option { - filters - .iter() - .filter_map(expr_to_predicate) - .reduce(Predicate::and) -} - -/// Recuresivly converting DataFusion filters ( in a [`Expr`]) to an Iceberg [`Predicate`]. -/// -/// This function currently handles the conversion of DataFusion expression of the following types: -/// -/// 1. Simple binary expressions (e.g., "column < value") -/// 2. Compound AND expressions (e.g., "x < 1 AND y > 10") -/// 3. Compound OR expressions (e.g., "x < 1 OR y > 10") -/// -/// For AND expressions, if one part of the expression can't be converted, -/// the function will still return a predicate for the part that can be converted. -/// For OR expressions, if any part can't be converted, the entire expression -/// will fail to convert. -/// -/// # Arguments -/// -/// * `expr` - A reference to a DataFusion [`Expr`] to be converted. -/// -/// # Returns -/// -/// * `Some(Predicate)` if the expression could be successfully converted. -/// * `None` if the expression couldn't be converted to an Iceberg predicate. -fn expr_to_predicate(expr: &Expr) -> Option { - match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - match (left.as_ref(), op, right.as_ref()) { - // First option arm (simple case), e.g. x < 1 - (Expr::Column(col), op, Expr::Literal(lit)) => { - let reference = Reference::new(col.name.clone()); - let datum = scalar_value_to_datum(lit)?; - Some(binary_op_to_predicate(reference, op, datum)) - } - // Second option arm (inner AND), e.g. x < 1 AND y > 10 - // if its an AND expression and one predicate fails, we can still go with the other one - (left_expr, Operator::And, right_expr) => { - let left_pred = expr_to_predicate(&left_expr.clone()); - let right_pred = expr_to_predicate(&right_expr.clone()); - match (left_pred, right_pred) { - (Some(left), Some(right)) => Some(Predicate::and(left, right)), - (Some(left), None) => Some(left), - (None, Some(right)) => Some(right), - (None, None) => None, - } - } - // Third option arm (inner OR), e.g. x < 1 OR y > 10 - // if one is unsupported, we fail the predicate - (Expr::BinaryExpr(left_expr), Operator::Or, Expr::BinaryExpr(right_expr)) => { - let left_pred = expr_to_predicate(&Expr::BinaryExpr(left_expr.clone()))?; - let right_pred = expr_to_predicate(&Expr::BinaryExpr(right_expr.clone()))?; - Some(Predicate::or(left_pred, right_pred)) - } - _ => None, - } - } - _ => None, - } -} - -/// convert the data fusion Exp to an iceberg [`Predicate`] -fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) -> Predicate { - match op { - Operator::Eq => reference.equal_to(datum), - Operator::NotEq => reference.not_equal_to(datum), - Operator::Lt => reference.less_than(datum), - Operator::LtEq => reference.less_than_or_equal_to(datum), - Operator::Gt => reference.greater_than(datum), - Operator::GtEq => reference.greater_than_or_equal_to(datum), - _ => Predicate::AlwaysTrue, - } -} -/// convert a DataFusion scalar value to an iceberg [`Datum`] -fn scalar_value_to_datum(value: &ScalarValue) -> Option { - match value { - ScalarValue::Int8(Some(v)) => Some(Datum::int(*v)), - ScalarValue::Int16(Some(v)) => Some(Datum::int(*v)), - ScalarValue::Int32(Some(v)) => Some(Datum::int(*v)), - ScalarValue::Int64(Some(v)) => Some(Datum::long(*v)), - ScalarValue::Float32(Some(v)) => Some(Datum::double(*v as f64)), - ScalarValue::Float64(Some(v)) => Some(Datum::double(*v)), - ScalarValue::Utf8(Some(v)) => Some(Datum::string(v.clone())), - ScalarValue::LargeUtf8(Some(v)) => Some(Datum::string(v.clone())), - // Add more cases as needed - _ => { - println!("unsupported scalar value: {:?}", value); - None - } - } + PredicateConverter.visit_many(filters) } #[cfg(test)] @@ -254,6 +159,8 @@ mod tests { use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::DFSchema; use datafusion::prelude::SessionContext; + use iceberg::expr::Reference; + use iceberg::spec::Datum; use super::*; @@ -266,6 +173,7 @@ mod tests { } fn create_test_schema_b() -> DFSchema { let arrow_schema = Schema::new(vec![ + Field::new("dt", DataType::Date32, false), Field::new("xxx", DataType::Int32, false), Field::new("yyy", DataType::Utf8, false), Field::new("zzz", DataType::Int32, false), @@ -377,4 +285,84 @@ mod tests { let expected_predicate = Reference::new("zzz").less_than(Datum::long(0)); assert_eq!(predicate, expected_predicate); } + #[test] + fn test_predicate_conversion_with_unsupported_condition() { + let sql = "yyy is not null"; + let df_schema = create_test_schema_b(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]); + assert_eq!(predicate, None); + } + #[test] + fn test_predicate_conversion_with_unsupported_condition_2() { + let sql = "yyy is not null and xxx > 1"; + let df_schema = create_test_schema_b(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]).unwrap(); + let expected_predicate = Reference::new("xxx").greater_than(Datum::long(1)); + assert_eq!(predicate, expected_predicate); + } + #[test] + fn test_predicate_conversion_with_date() { + let sql = "dt > date '2024-02-29' and xxx = 1"; + let df_schema = create_test_schema_b(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]).unwrap(); + let expected_predicate = Predicate::and( + Reference::new("dt").greater_than(Datum::date_from_ymd(2024, 2, 29).unwrap()), + Reference::new("xxx").equal_to(Datum::long(1)), + ); + assert_eq!(predicate, expected_predicate); + } + #[test] + fn test_predicate_conversion_with_date_or() { + let sql = "dt > date '2024-02-29' or xxx = 1"; + let df_schema = create_test_schema_b(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]).unwrap(); + let expected_predicate = Predicate::or( + Reference::new("dt").greater_than(Datum::date_from_ymd(2024, 2, 29).unwrap()), + Reference::new("xxx").equal_to(Datum::long(1)), + ); + assert_eq!(predicate, expected_predicate); + } + #[test] + fn test_predicate_conversion_with_unsupported_date() { + let sql = "dt > date '2024-02-29-08'"; + let df_schema = create_test_schema_b(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]); + assert_eq!(predicate, None); + } + #[test] + fn test_predicate_conversion_with_unsupported_date_or() { + let sql = "dt > date '2024-02-29-08' or xxx = 1"; + let df_schema = create_test_schema_b(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]); + assert_eq!(predicate, None); + } + #[test] + fn test_predicate_conversion_with_unsupported_date_and() { + let sql = "dt > date '2024-02-29-08' and xxx = 1"; + let df_schema = create_test_schema_b(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let predicate = convert_filters_to_predicate(&[expr]).unwrap(); + let expected_predicate = Reference::new("xxx").equal_to(Datum::long(1)); + assert_eq!(predicate, expected_predicate); + } } diff --git a/crates/integrations/datafusion/src/table.rs b/crates/integrations/datafusion/src/table.rs index 98386efe9..14da61543 100644 --- a/crates/integrations/datafusion/src/table.rs +++ b/crates/integrations/datafusion/src/table.rs @@ -76,13 +76,13 @@ impl TableProvider for IcebergTableProvider { &self, _state: &dyn Session, _projection: Option<&Vec>, - _filters: &[Expr], + filters: &[Expr], _limit: Option, ) -> DFResult> { Ok(Arc::new(IcebergTableScan::new( self.table.clone(), self.schema.clone(), - _filters, + filters, ))) } @@ -93,12 +93,9 @@ impl TableProvider for IcebergTableProvider { { let filter_support = filters .iter() - .map(|e| { - if let Expr::BinaryExpr(BinaryExpr { .. }) = e { - TableProviderFilterPushDown::Inexact - } else { - TableProviderFilterPushDown::Unsupported - } + .map(|e| match e { + Expr::BinaryExpr(BinaryExpr { .. }) => TableProviderFilterPushDown::Inexact, + _ => TableProviderFilterPushDown::Unsupported, }) .collect::>(); From 9d2112d48d3ffe9867568ad4e29d7ae7e812f44d Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Sun, 1 Sep 2024 12:31:05 +0300 Subject: [PATCH 06/13] fixing format and lic --- .../datafusion/src/physical_plan/mod.rs | 2 +- .../src/physical_plan/predicate_converter.rs | 22 +++++++++++++++++-- .../datafusion/src/physical_plan/scan.rs | 9 ++++---- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/crates/integrations/datafusion/src/physical_plan/mod.rs b/crates/integrations/datafusion/src/physical_plan/mod.rs index af2e68f2e..87a11e282 100644 --- a/crates/integrations/datafusion/src/physical_plan/mod.rs +++ b/crates/integrations/datafusion/src/physical_plan/mod.rs @@ -15,5 +15,5 @@ // specific language governing permissions and limitations // under the License. -pub(crate) mod scan; pub(crate) mod predicate_converter; +pub(crate) mod scan; diff --git a/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs b/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs index ea5f42206..e24f5a986 100644 --- a/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs +++ b/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs @@ -1,5 +1,23 @@ -use datafusion::logical_expr::{BinaryExpr, Cast, Operator}; -use datafusion::{arrow::datatypes::DataType, logical_expr::Expr, scalar::ScalarValue}; +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::arrow::datatypes::DataType; +use datafusion::logical_expr::{BinaryExpr, Cast, Expr, Operator}; +use datafusion::scalar::ScalarValue; use iceberg::expr::{Predicate, Reference}; use iceberg::spec::Datum; #[derive(Default)] diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs index 27f8b5672..f91841f35 100644 --- a/crates/integrations/datafusion/src/physical_plan/scan.rs +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. -use super::predicate_converter::PredicateConverter; +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; + use datafusion::arrow::array::RecordBatch; use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; use datafusion::error::Result as DFResult; @@ -29,10 +32,8 @@ use datafusion::prelude::Expr; use futures::{Stream, TryStreamExt}; use iceberg::expr::Predicate; use iceberg::table::Table; -use std::any::Any; -use std::pin::Pin; -use std::sync::Arc; +use super::predicate_converter::PredicateConverter; use crate::to_datafusion_error; /// Manages the scanning process of an Iceberg [`Table`], encapsulating the From f042ddcab8c9392a7ed21ccfb32698627b861857 Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Sun, 1 Sep 2024 16:14:46 +0300 Subject: [PATCH 07/13] reducing number of tests (17 -> 7) --- .../datafusion/src/physical_plan/scan.rs | 148 ++++-------------- 1 file changed, 28 insertions(+), 120 deletions(-) diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs index f91841f35..8ec9c3a87 100644 --- a/crates/integrations/datafusion/src/physical_plan/scan.rs +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -29,7 +29,7 @@ use datafusion::physical_plan::{ DisplayAs, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, }; use datafusion::prelude::Expr; -use futures::{Stream, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; use iceberg::expr::Predicate; use iceberg::table::Table; @@ -144,7 +144,6 @@ async fn get_batch_stream( .await .map_err(to_datafusion_error)? .map_err(to_datafusion_error); - Ok(Box::pin(stream)) } @@ -195,175 +194,84 @@ mod tests { Reference::new("foo").greater_than(Datum::long(1)) ); } - #[test] - fn test_predicate_conversion_with_multiple_conditions() { - let sql = "foo > 1 and bar = 'test'"; + fn test_predicate_conversion_with_single_unsupported_condition() { + let sql = "foo is null"; let df_schema = create_test_schema(); let expr = SessionContext::new() .parse_sql_expr(sql, &df_schema) .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let expected_predicate = Predicate::and( - Reference::new("foo").greater_than(Datum::long(1)), - Reference::new("bar").equal_to(Datum::string("test")), - ); - assert_eq!(predicate, expected_predicate); + let predicate = convert_filters_to_predicate(&[expr]); + assert_eq!(predicate, None); } #[test] - fn test_predicate_conversion_with_multiple_binary_expr() { - let sql = "(foo > 1 and bar = 'test') or foo < 0 "; + fn test_predicate_conversion_with_and_condition() { + let sql = "foo > 1 and bar = 'test'"; let df_schema = create_test_schema(); let expr = SessionContext::new() .parse_sql_expr(sql, &df_schema) .unwrap(); let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let inner_predicate = Predicate::and( + let expected_predicate = Predicate::and( Reference::new("foo").greater_than(Datum::long(1)), Reference::new("bar").equal_to(Datum::string("test")), ); - let expected_predicate = Predicate::or( - inner_predicate, - Reference::new("foo").less_than(Datum::long(0)), - ); assert_eq!(predicate, expected_predicate); } #[test] - fn test_predicate_conversion_with_unsupported_condition_not() { - let sql = "xxx > 1 and yyy is not null and zzz < 0 "; - let df_schema = create_test_schema_b(); + fn test_predicate_conversion_with_and_condition_unsupported() { + let sql = "foo > 1 and bar is not null"; + let df_schema = create_test_schema(); let expr = SessionContext::new() .parse_sql_expr(sql, &df_schema) .unwrap(); let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let expected_predicate = Predicate::and( - Reference::new("xxx").greater_than(Datum::long(1)), - Reference::new("zzz").less_than(Datum::long(0)), - ); + let expected_predicate = Reference::new("foo").greater_than(Datum::long(1)); assert_eq!(predicate, expected_predicate); } #[test] - fn test_predicate_conversion_with_unsupported_condition_and() { - let sql = "(xxx > 1 and yyy in ('test', 'test2')) and zzz < 0 "; - let df_schema = create_test_schema_b(); + fn test_predicate_conversion_with_or_condition_unsupported() { + let sql = "foo > 1 or bar is not null"; + let df_schema = create_test_schema(); let expr = SessionContext::new() .parse_sql_expr(sql, &df_schema) .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let expected_predicate = Predicate::and( - Reference::new("xxx").greater_than(Datum::long(1)), - Reference::new("zzz").less_than(Datum::long(0)), - ); + let predicate = convert_filters_to_predicate(&[expr]); + let expected_predicate = None; assert_eq!(predicate, expected_predicate); } #[test] - fn test_predicate_conversion_with_unsupported_condition_or() { - let sql = "(foo > 1 and bar in ('test', 'test2')) or foo < 0 "; + fn test_predicate_conversion_with_complex_binary_expr() { + let sql = "(foo > 1 and bar = 'test') or foo < 0 "; let df_schema = create_test_schema(); let expr = SessionContext::new() .parse_sql_expr(sql, &df_schema) .unwrap(); let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let expected_predicate = Predicate::or( + let inner_predicate = Predicate::and( Reference::new("foo").greater_than(Datum::long(1)), - Reference::new("foo").less_than(Datum::long(0)), - ); - assert_eq!(predicate, expected_predicate); - } - - #[test] - fn test_predicate_conversion_with_unsupported_expr() { - let sql = "(xxx > 1 or yyy in ('test', 'test2')) and zzz < 0 "; - let df_schema = create_test_schema_b(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let expected_predicate = Reference::new("zzz").less_than(Datum::long(0)); - assert_eq!(predicate, expected_predicate); - } - #[test] - fn test_predicate_conversion_with_unsupported_condition() { - let sql = "yyy is not null"; - let df_schema = create_test_schema_b(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]); - assert_eq!(predicate, None); - } - #[test] - fn test_predicate_conversion_with_unsupported_condition_2() { - let sql = "yyy is not null and xxx > 1"; - let df_schema = create_test_schema_b(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let expected_predicate = Reference::new("xxx").greater_than(Datum::long(1)); - assert_eq!(predicate, expected_predicate); - } - #[test] - fn test_predicate_conversion_with_date() { - let sql = "dt > date '2024-02-29' and xxx = 1"; - let df_schema = create_test_schema_b(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let expected_predicate = Predicate::and( - Reference::new("dt").greater_than(Datum::date_from_ymd(2024, 2, 29).unwrap()), - Reference::new("xxx").equal_to(Datum::long(1)), + Reference::new("bar").equal_to(Datum::string("test")), ); - assert_eq!(predicate, expected_predicate); - } - #[test] - fn test_predicate_conversion_with_date_or() { - let sql = "dt > date '2024-02-29' or xxx = 1"; - let df_schema = create_test_schema_b(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]).unwrap(); let expected_predicate = Predicate::or( - Reference::new("dt").greater_than(Datum::date_from_ymd(2024, 2, 29).unwrap()), - Reference::new("xxx").equal_to(Datum::long(1)), + inner_predicate, + Reference::new("foo").less_than(Datum::long(0)), ); assert_eq!(predicate, expected_predicate); } + #[test] - fn test_predicate_conversion_with_unsupported_date() { - let sql = "dt > date '2024-02-29-08'"; - let df_schema = create_test_schema_b(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]); - assert_eq!(predicate, None); - } - #[test] - fn test_predicate_conversion_with_unsupported_date_or() { - let sql = "dt > date '2024-02-29-08' or xxx = 1"; - let df_schema = create_test_schema_b(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]); - assert_eq!(predicate, None); - } - #[test] - fn test_predicate_conversion_with_unsupported_date_and() { - let sql = "dt > date '2024-02-29-08' and xxx = 1"; - let df_schema = create_test_schema_b(); + fn test_predicate_conversion_with_complex_binary_expr_unsupported() { + let sql = "(foo > 1 or bar in ('test', 'test2')) and foo < 0 "; + let df_schema = create_test_schema(); let expr = SessionContext::new() .parse_sql_expr(sql, &df_schema) .unwrap(); let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let expected_predicate = Reference::new("xxx").equal_to(Datum::long(1)); + let expected_predicate = Reference::new("foo").less_than(Datum::long(0)); assert_eq!(predicate, expected_predicate); } } From d9f7e3fa940a1e4a122261ffe2c87d0a51e5eb9c Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Sun, 1 Sep 2024 16:49:42 +0300 Subject: [PATCH 08/13] fix formats --- .../integrations/datafusion/src/physical_plan/scan.rs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs index 8ec9c3a87..26ec89b34 100644 --- a/crates/integrations/datafusion/src/physical_plan/scan.rs +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -29,7 +29,7 @@ use datafusion::physical_plan::{ DisplayAs, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, }; use datafusion::prelude::Expr; -use futures::{Stream, StreamExt, TryStreamExt}; +use futures::{Stream, TryStreamExt}; use iceberg::expr::Predicate; use iceberg::table::Table; @@ -171,15 +171,6 @@ mod tests { ]); DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap() } - fn create_test_schema_b() -> DFSchema { - let arrow_schema = Schema::new(vec![ - Field::new("dt", DataType::Date32, false), - Field::new("xxx", DataType::Int32, false), - Field::new("yyy", DataType::Utf8, false), - Field::new("zzz", DataType::Int32, false), - ]); - DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap() - } #[test] fn test_predicate_conversion_with_single_condition() { From cbbf3a6fc35c3b29034141e030c898edecb0edd2 Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Mon, 2 Sep 2024 11:34:28 +0300 Subject: [PATCH 09/13] fix naming --- .../datafusion/src/physical_plan/predicate_converter.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs b/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs index e24f5a986..56ae4aa18 100644 --- a/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs +++ b/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs @@ -102,7 +102,8 @@ impl PredicateConverter { .into_iter() .flatten() .collect(); - match (op, preds.len()) { + let num_valid_preds = preds.len(); + match (op, num_valid_preds) { (Operator::And, 1) => preds.first().cloned(), (Operator::And, 2) => Some(Predicate::and(preds[0].clone(), preds[1].clone())), (Operator::Or, 2) => Some(Predicate::or(preds[0].clone(), preds[1].clone())), From e864fd148898fe68a68ccb18b9145c67add83fad Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Mon, 2 Sep 2024 22:17:04 +0300 Subject: [PATCH 10/13] refactoring to use TreeNodeVisitor --- .../src/physical_plan/expr_to_predicate.rs | 312 ++++++++++++++++++ .../datafusion/src/physical_plan/mod.rs | 2 +- .../src/physical_plan/predicate_converter.rs | 156 --------- .../datafusion/src/physical_plan/scan.rs | 128 +------ 4 files changed, 326 insertions(+), 272 deletions(-) create mode 100644 crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs delete mode 100644 crates/integrations/datafusion/src/physical_plan/predicate_converter.rs diff --git a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs new file mode 100644 index 000000000..da6cc77ef --- /dev/null +++ b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; +use datafusion::common::Column; +use datafusion::error::DataFusionError; +use datafusion::logical_expr::{Expr, Operator}; +use datafusion::scalar::ScalarValue; +use iceberg::expr::{Predicate, Reference}; +use iceberg::spec::Datum; +use std::collections::VecDeque; + +pub struct ExprToPredicateVisitor { + stack: VecDeque>, +} +impl ExprToPredicateVisitor { + /// Create a new predicate conversion visitor. + pub fn new() -> Self { + Self { + stack: VecDeque::new(), + } + } + /// Get the predicate from the stack. + pub fn get_predicate(&self) -> Option { + self.stack + .iter() + .filter_map(|opt| opt.clone()) + .reduce(Predicate::and) + } + + /// Convert a column expression to an iceberg predicate. + fn convert_column_expr( + &self, + col: &Column, + op: &Operator, + lit: &ScalarValue, + ) -> Option { + let reference = Reference::new(col.name.clone()); + let datum = scalar_value_to_datum(lit)?; + Some(binary_op_to_predicate(reference, op, datum)) + } + + /// Convert a compound expression to an iceberg predicate. + /// + /// The strategy is to support the following cases: + /// - if its an AND expression then the result will be the valid predicates, whether there are 2 or just 1 + /// - if its an OR expression then a predicate will be returned only if there are 2 valid predicates on both sides + fn convert_compound_expr(&self, valid_preds: &[Predicate], op: &Operator) -> Option { + let valid_preds_count = valid_preds.len(); + match (op, valid_preds_count) { + (Operator::And, 1) => valid_preds.first().cloned(), + (Operator::And, 2) => Some(Predicate::and( + valid_preds[0].clone(), + valid_preds[1].clone(), + )), + (Operator::Or, 2) => Some(Predicate::or( + valid_preds[0].clone(), + valid_preds[1].clone(), + )), + _ => None, + } + } +} + +// Implement TreeNodeVisitor for ExprToPredicateVisitor +impl<'n> TreeNodeVisitor<'n> for ExprToPredicateVisitor { + type Node = Expr; + + fn f_down(&mut self, _node: &'n Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) + } + + fn f_up(&mut self, node: &'n Self::Node) -> Result { + if let Expr::BinaryExpr(binary) = node { + match (&*binary.left, &binary.op, &*binary.right) { + (Expr::Column(col), op, Expr::Literal(lit)) => { + let col_pred = self.convert_column_expr(col, op, lit); + self.stack.push_back(col_pred); + } + (_left, op, _right) if matches!(op, Operator::And | Operator::Or) => { + let right_pred = self.stack.pop_back().flatten(); + let left_pred = self.stack.pop_back().flatten(); + let valid_preds = [left_pred, right_pred] + .into_iter() + .flatten() + .collect::>(); + let compound_pred = self.convert_compound_expr(&valid_preds, op); + self.stack.push_back(compound_pred); + } + _ => {} + } + }; + Ok(TreeNodeRecursion::Continue) + } +} + +const MILLIS_PER_DAY: i64 = 24 * 60 * 60 * 1000; +/// Convert a scalar value to an iceberg datum. +fn scalar_value_to_datum(value: &ScalarValue) -> Option { + match value { + ScalarValue::Int8(Some(v)) => Some(Datum::int(*v as i32)), + ScalarValue::Int16(Some(v)) => Some(Datum::int(*v as i32)), + ScalarValue::Int32(Some(v)) => Some(Datum::int(*v)), + ScalarValue::Int64(Some(v)) => Some(Datum::long(*v)), + ScalarValue::Float32(Some(v)) => Some(Datum::double(*v as f64)), + ScalarValue::Float64(Some(v)) => Some(Datum::double(*v)), + ScalarValue::Utf8(Some(v)) => Some(Datum::string(v.clone())), + ScalarValue::LargeUtf8(Some(v)) => Some(Datum::string(v.clone())), + ScalarValue::Date32(Some(v)) => Some(Datum::date(*v)), + ScalarValue::Date64(Some(v)) => Some(Datum::date((*v / MILLIS_PER_DAY) as i32)), + _ => None, + } +} + +/// convert the data fusion Exp to an iceberg [`Predicate`] +fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) -> Predicate { + match op { + Operator::Eq => reference.equal_to(datum), + Operator::NotEq => reference.not_equal_to(datum), + Operator::Lt => reference.less_than(datum), + Operator::LtEq => reference.less_than_or_equal_to(datum), + Operator::Gt => reference.greater_than(datum), + Operator::GtEq => reference.greater_than_or_equal_to(datum), + _ => Predicate::AlwaysTrue, + } +} + +#[cfg(test)] +mod tests { + use std::collections::VecDeque; + + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::tree_node::TreeNode; + use datafusion::common::DFSchema; + use datafusion::prelude::SessionContext; + use iceberg::expr::{Predicate, Reference}; + use iceberg::spec::Datum; + + use super::ExprToPredicateVisitor; + + fn create_test_schema() -> DFSchema { + let arrow_schema = Schema::new(vec![ + Field::new("foo", DataType::Int32, false), + Field::new("bar", DataType::Utf8, false), + ]); + DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap() + } + + #[test] + fn test_predicate_conversion_with_single_condition() { + let sql = "foo > 1"; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let mut visitor = ExprToPredicateVisitor::new(); + expr.visit(&mut visitor).unwrap(); + let predicate = visitor.get_predicate().unwrap(); + assert_eq!( + predicate, + Reference::new("foo").greater_than(Datum::long(1)) + ); + } + #[test] + fn test_predicate_conversion_with_single_unsupported_condition() { + let sql = "foo is null"; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let mut visitor = ExprToPredicateVisitor::new(); + expr.visit(&mut visitor).unwrap(); + let predicate = visitor.get_predicate(); + assert_eq!(predicate, None); + } + #[test] + fn test_predicate_conversion_with_and_condition() { + let sql = "foo > 1 and bar = 'test'"; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let mut visitor = ExprToPredicateVisitor::new(); + expr.visit(&mut visitor).unwrap(); + let predicate = visitor.get_predicate().unwrap(); + let expected_predicate = Predicate::and( + Reference::new("foo").greater_than(Datum::long(1)), + Reference::new("bar").equal_to(Datum::string("test")), + ); + assert_eq!(predicate, expected_predicate); + } + + #[test] + fn test_predicate_conversion_with_and_condition_unsupported() { + let sql = "foo > 1 and bar is not null"; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let mut visitor = ExprToPredicateVisitor::new(); + expr.visit(&mut visitor).unwrap(); + let predicate = visitor.get_predicate().unwrap(); + let expected_predicate = Reference::new("foo").greater_than(Datum::long(1)); + assert_eq!(predicate, expected_predicate); + } + #[test] + fn test_predicate_conversion_with_and_condition_both_unsupported() { + let sql = "foo in (1, 2, 3) and bar is not null"; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let mut visitor = ExprToPredicateVisitor::new(); + expr.visit(&mut visitor).unwrap(); + let predicate = visitor.get_predicate(); + let expected_predicate = None; + assert_eq!(predicate, expected_predicate); + } + + #[test] + fn test_predicate_conversion_with_or_condition_unsupported() { + let sql = "foo > 1 or bar is not null"; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let mut visitor = ExprToPredicateVisitor::new(); + expr.visit(&mut visitor).unwrap(); + let predicate = visitor.get_predicate(); + let expected_predicate = None; + assert_eq!(predicate, expected_predicate); + } + + #[test] + fn test_predicate_conversion_with_complex_binary_expr() { + let sql = "(foo > 1 and bar = 'test') or foo < 0 "; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let mut visitor = ExprToPredicateVisitor::new(); + expr.visit(&mut visitor).unwrap(); + let predicate = visitor.get_predicate().unwrap(); + let inner_predicate = Predicate::and( + Reference::new("foo").greater_than(Datum::long(1)), + Reference::new("bar").equal_to(Datum::string("test")), + ); + let expected_predicate = Predicate::or( + inner_predicate, + Reference::new("foo").less_than(Datum::long(0)), + ); + assert_eq!(predicate, expected_predicate); + } + + #[test] + fn test_predicate_conversion_with_complex_binary_expr_unsupported() { + let sql = "(foo > 1 or bar in ('test', 'test2')) and foo < 0 "; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let mut visitor = ExprToPredicateVisitor::new(); + expr.visit(&mut visitor).unwrap(); + let predicate = visitor.get_predicate().unwrap(); + let expected_predicate = Reference::new("foo").less_than(Datum::long(0)); + assert_eq!(predicate, expected_predicate); + } + + #[test] + // test the get result method + fn test_get_result_multiple() { + let predicates = vec![ + Some(Reference::new("foo").greater_than(Datum::long(1))), + None, + Some(Reference::new("bar").equal_to(Datum::string("test"))), + ]; + let stack = VecDeque::from(predicates); + let visitor = ExprToPredicateVisitor { stack }; + assert_eq!( + visitor.get_predicate(), + Some(Predicate::and( + Reference::new("foo").greater_than(Datum::long(1)), + Reference::new("bar").equal_to(Datum::string("test")), + )) + ); + } + + #[test] + fn test_get_result_single() { + let predicates = vec![Some(Reference::new("foo").greater_than(Datum::long(1)))]; + let stack = VecDeque::from(predicates); + let visitor = ExprToPredicateVisitor { stack }; + assert_eq!( + visitor.get_predicate(), + Some(Reference::new("foo").greater_than(Datum::long(1))) + ); + } +} diff --git a/crates/integrations/datafusion/src/physical_plan/mod.rs b/crates/integrations/datafusion/src/physical_plan/mod.rs index 87a11e282..2fab109d7 100644 --- a/crates/integrations/datafusion/src/physical_plan/mod.rs +++ b/crates/integrations/datafusion/src/physical_plan/mod.rs @@ -15,5 +15,5 @@ // specific language governing permissions and limitations // under the License. -pub(crate) mod predicate_converter; +pub(crate) mod expr_to_predicate; pub(crate) mod scan; diff --git a/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs b/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs deleted file mode 100644 index 56ae4aa18..000000000 --- a/crates/integrations/datafusion/src/physical_plan/predicate_converter.rs +++ /dev/null @@ -1,156 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::arrow::datatypes::DataType; -use datafusion::logical_expr::{BinaryExpr, Cast, Expr, Operator}; -use datafusion::scalar::ScalarValue; -use iceberg::expr::{Predicate, Reference}; -use iceberg::spec::Datum; -#[derive(Default)] -pub struct PredicateConverter; - -impl PredicateConverter { - /// Convert a list of DataFusion expressions to an iceberg predicate. - pub fn visit_many(&self, exprs: &[Expr]) -> Option { - exprs - .iter() - .filter_map(|expr| self.visit(expr)) - .reduce(Predicate::and) - } - - /// Convert a single DataFusion expression to an iceberg predicate. - /// currently only supports binary (simple) expressions - pub fn visit(&self, expr: &Expr) -> Option { - match expr { - Expr::BinaryExpr(binary) => self.visit_binary_expr(binary), - _ => None, - } - } - - /// Convert a binary expression to an iceberg predicate. - /// - /// currently supports: - /// - column, basic op, and literal, e.g. `a = 1` - /// - column and casted literal, e.g. `a = cast(1 as bigint)` - /// - binary conditional (and, or), e.g. `a = 1 and b = 2` - fn visit_binary_expr(&self, binary: &BinaryExpr) -> Option { - match (&*binary.left, &binary.op, &*binary.right) { - // column, op, literal - (Expr::Column(col), op, Expr::Literal(lit)) => self.visit_column_literal(col, op, lit), - // column, op, casted literal - (Expr::Column(col), op, Expr::Cast(Cast { expr, data_type })) => { - self.visit_column_cast(col, op, expr, data_type) - } - // binary conditional (and, or) - (left, op, right) if matches!(op, Operator::And | Operator::Or) => { - self.visit_binary_conditional(left, op, right) - } - _ => None, - } - } - - /// Convert a column and casted literal to an iceberg predicate. - /// The purpose of this function is to handle the common case in which there is a filter based on a casted literal. - /// These kinds of expressions are often not pushed down by query engines though its an important case to handle - /// for iceberg scan pushdown. - fn visit_column_cast( - &self, - col: &datafusion::common::Column, - op: &Operator, - expr: &Expr, - data_type: &DataType, - ) -> Option { - if let (Expr::Literal(ScalarValue::Utf8(lit)), DataType::Date32) = (expr, data_type) { - let reference = Reference::new(col.name.clone()); - let datum = lit - .clone() - .and_then(|date_str| Datum::date_from_str(date_str).ok())?; - return Some(binary_op_to_predicate(reference, op, datum)); - } - None - } - - /// Convert a binary conditional expression, i.e., (and, or), to an iceberg predicate. - /// - /// When processing an AND expression: - /// - if both expressions are valid predicates then an AND predicate is returned - /// - if either expression is None then the valid one is returned - /// - /// When processing an OR expression: - /// - only if both expressions are valid predicates then an OR predicate is returned - fn visit_binary_conditional( - &self, - left: &Expr, - op: &Operator, - right: &Expr, - ) -> Option { - let preds: Vec = vec![self.visit(left), self.visit(right)] - .into_iter() - .flatten() - .collect(); - let num_valid_preds = preds.len(); - match (op, num_valid_preds) { - (Operator::And, 1) => preds.first().cloned(), - (Operator::And, 2) => Some(Predicate::and(preds[0].clone(), preds[1].clone())), - (Operator::Or, 2) => Some(Predicate::or(preds[0].clone(), preds[1].clone())), - _ => None, - } - } - - /// Convert a simple expression based on column and literal (x > 1) to an iceberg predicate. - fn visit_column_literal( - &self, - col: &datafusion::common::Column, - op: &Operator, - lit: &ScalarValue, - ) -> Option { - let reference = Reference::new(col.name.clone()); - let datum = scalar_value_to_datum(lit)?; - Some(binary_op_to_predicate(reference, op, datum)) - } -} - -const MILLIS_PER_DAY: i64 = 24 * 60 * 60 * 1000; -/// Convert a scalar value to an iceberg datum. -fn scalar_value_to_datum(value: &ScalarValue) -> Option { - match value { - ScalarValue::Int8(Some(v)) => Some(Datum::int(*v as i32)), - ScalarValue::Int16(Some(v)) => Some(Datum::int(*v as i32)), - ScalarValue::Int32(Some(v)) => Some(Datum::int(*v)), - ScalarValue::Int64(Some(v)) => Some(Datum::long(*v)), - ScalarValue::Float32(Some(v)) => Some(Datum::double(*v as f64)), - ScalarValue::Float64(Some(v)) => Some(Datum::double(*v)), - ScalarValue::Utf8(Some(v)) => Some(Datum::string(v.clone())), - ScalarValue::LargeUtf8(Some(v)) => Some(Datum::string(v.clone())), - ScalarValue::Date32(Some(v)) => Some(Datum::date(*v)), - ScalarValue::Date64(Some(v)) => Some(Datum::date((*v / MILLIS_PER_DAY) as i32)), - _ => None, - } -} - -/// convert the data fusion Exp to an iceberg [`Predicate`] -fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) -> Predicate { - match op { - Operator::Eq => reference.equal_to(datum), - Operator::NotEq => reference.not_equal_to(datum), - Operator::Lt => reference.less_than(datum), - Operator::LtEq => reference.less_than_or_equal_to(datum), - Operator::Gt => reference.greater_than(datum), - Operator::GtEq => reference.greater_than_or_equal_to(datum), - _ => Predicate::AlwaysTrue, - } -} diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs index 26ec89b34..2aa8092e3 100644 --- a/crates/integrations/datafusion/src/physical_plan/scan.rs +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use datafusion::arrow::array::RecordBatch; use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; +use datafusion::common::tree_node::TreeNode; use datafusion::error::Result as DFResult; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_expr::EquivalenceProperties; @@ -33,7 +34,7 @@ use futures::{Stream, TryStreamExt}; use iceberg::expr::Predicate; use iceberg::table::Table; -use super::predicate_converter::PredicateConverter; +use crate::physical_plan::expr_to_predicate::ExprToPredicateVisitor; use crate::to_datafusion_error; /// Manages the scanning process of an Iceberg [`Table`], encapsulating the @@ -151,118 +152,15 @@ async fn get_batch_stream( /// If none of the filters could be converted, return `None` which adds no predicates to the scan operation. /// If the conversion was successful, return the converted predicates combined with an AND operator. fn convert_filters_to_predicate(filters: &[Expr]) -> Option { - PredicateConverter.visit_many(filters) -} - -#[cfg(test)] -mod tests { - use datafusion::arrow::datatypes::{DataType, Field, Schema}; - use datafusion::common::DFSchema; - use datafusion::prelude::SessionContext; - use iceberg::expr::Reference; - use iceberg::spec::Datum; - - use super::*; - - fn create_test_schema() -> DFSchema { - let arrow_schema = Schema::new(vec![ - Field::new("foo", DataType::Int32, false), - Field::new("bar", DataType::Utf8, false), - ]); - DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap() - } - - #[test] - fn test_predicate_conversion_with_single_condition() { - let sql = "foo > 1"; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - assert_eq!( - predicate, - Reference::new("foo").greater_than(Datum::long(1)) - ); - } - #[test] - fn test_predicate_conversion_with_single_unsupported_condition() { - let sql = "foo is null"; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]); - assert_eq!(predicate, None); - } - - #[test] - fn test_predicate_conversion_with_and_condition() { - let sql = "foo > 1 and bar = 'test'"; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let expected_predicate = Predicate::and( - Reference::new("foo").greater_than(Datum::long(1)), - Reference::new("bar").equal_to(Datum::string("test")), - ); - assert_eq!(predicate, expected_predicate); - } - - #[test] - fn test_predicate_conversion_with_and_condition_unsupported() { - let sql = "foo > 1 and bar is not null"; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let expected_predicate = Reference::new("foo").greater_than(Datum::long(1)); - assert_eq!(predicate, expected_predicate); - } - - #[test] - fn test_predicate_conversion_with_or_condition_unsupported() { - let sql = "foo > 1 or bar is not null"; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]); - let expected_predicate = None; - assert_eq!(predicate, expected_predicate); - } - - #[test] - fn test_predicate_conversion_with_complex_binary_expr() { - let sql = "(foo > 1 and bar = 'test') or foo < 0 "; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let inner_predicate = Predicate::and( - Reference::new("foo").greater_than(Datum::long(1)), - Reference::new("bar").equal_to(Datum::string("test")), - ); - let expected_predicate = Predicate::or( - inner_predicate, - Reference::new("foo").less_than(Datum::long(0)), - ); - assert_eq!(predicate, expected_predicate); - } - - #[test] - fn test_predicate_conversion_with_complex_binary_expr_unsupported() { - let sql = "(foo > 1 or bar in ('test', 'test2')) and foo < 0 "; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let predicate = convert_filters_to_predicate(&[expr]).unwrap(); - let expected_predicate = Reference::new("foo").less_than(Datum::long(0)); - assert_eq!(predicate, expected_predicate); - } + filters + .iter() + .filter_map(|expr| { + let mut visitor = ExprToPredicateVisitor::new(); + if expr.visit(&mut visitor).is_ok() { + visitor.get_predicate() + } else { + None + } + }) + .reduce(Predicate::and) } From bb41f70bb40a4a6ae25ade63b788f16855ff9a86 Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Mon, 2 Sep 2024 22:21:49 +0300 Subject: [PATCH 11/13] fixing fmt --- .../datafusion/src/physical_plan/expr_to_predicate.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs index da6cc77ef..092cf5bcf 100644 --- a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs +++ b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::collections::VecDeque; + use datafusion::common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; use datafusion::common::Column; use datafusion::error::DataFusionError; @@ -22,7 +24,6 @@ use datafusion::logical_expr::{Expr, Operator}; use datafusion::scalar::ScalarValue; use iceberg::expr::{Predicate, Reference}; use iceberg::spec::Datum; -use std::collections::VecDeque; pub struct ExprToPredicateVisitor { stack: VecDeque>, From 865047659d5204e885c73562adb19fe578ae8b64 Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Wed, 4 Sep 2024 07:43:53 +0300 Subject: [PATCH 12/13] small refactor --- .../datafusion/src/physical_plan/expr_to_predicate.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs index 092cf5bcf..86d71fffc 100644 --- a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs +++ b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs @@ -88,18 +88,17 @@ impl<'n> TreeNodeVisitor<'n> for ExprToPredicateVisitor { fn f_up(&mut self, node: &'n Self::Node) -> Result { if let Expr::BinaryExpr(binary) = node { match (&*binary.left, &binary.op, &*binary.right) { + // process simple expressions (involving a column, operator and literal) (Expr::Column(col), op, Expr::Literal(lit)) => { let col_pred = self.convert_column_expr(col, op, lit); self.stack.push_back(col_pred); } + // process compound expressions (involving AND or OR and children) (_left, op, _right) if matches!(op, Operator::And | Operator::Or) => { let right_pred = self.stack.pop_back().flatten(); let left_pred = self.stack.pop_back().flatten(); - let valid_preds = [left_pred, right_pred] - .into_iter() - .flatten() - .collect::>(); - let compound_pred = self.convert_compound_expr(&valid_preds, op); + let children: Vec<_> = [left_pred, right_pred].into_iter().flatten().collect(); + let compound_pred = self.convert_compound_expr(&children, op); self.stack.push_back(compound_pred); } _ => {} From 21a38f5546452041649719f73a5b16e8dbf97d01 Mon Sep 17 00:00:00 2001 From: Alon Agmon Date: Sun, 8 Sep 2024 12:50:58 +0300 Subject: [PATCH 13/13] adding swapped op and fixing CR comments --- .../src/physical_plan/expr_to_predicate.rs | 39 +++++++++++++++---- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs index 86d71fffc..110e4f7e4 100644 --- a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs +++ b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs @@ -81,29 +81,36 @@ impl ExprToPredicateVisitor { impl<'n> TreeNodeVisitor<'n> for ExprToPredicateVisitor { type Node = Expr; - fn f_down(&mut self, _node: &'n Self::Node) -> Result { + fn f_down(&mut self, _node: &'n Expr) -> Result { Ok(TreeNodeRecursion::Continue) } - fn f_up(&mut self, node: &'n Self::Node) -> Result { - if let Expr::BinaryExpr(binary) = node { + fn f_up(&mut self, expr: &'n Expr) -> Result { + if let Expr::BinaryExpr(binary) = expr { match (&*binary.left, &binary.op, &*binary.right) { - // process simple expressions (involving a column, operator and literal) + // process simple binary expressions, e.g. col > 1 (Expr::Column(col), op, Expr::Literal(lit)) => { let col_pred = self.convert_column_expr(col, op, lit); self.stack.push_back(col_pred); } - // process compound expressions (involving AND or OR and children) - (_left, op, _right) if matches!(op, Operator::And | Operator::Or) => { + // // process reversed binary expressions, e.g. 1 < col + (Expr::Literal(lit), op, Expr::Column(col)) => { + let col_pred = op + .swap() + .and_then(|negated_op| self.convert_column_expr(col, &negated_op, lit)); + self.stack.push_back(col_pred); + } + // process compound expressions (involving logical operators. e.g., AND or OR and children) + (_left, op, _right) if op.is_logic_operator() => { let right_pred = self.stack.pop_back().flatten(); let left_pred = self.stack.pop_back().flatten(); let children: Vec<_> = [left_pred, right_pred].into_iter().flatten().collect(); let compound_pred = self.convert_compound_expr(&children, op); self.stack.push_back(compound_pred); } - _ => {} + _ => return Ok(TreeNodeRecursion::Continue), } - }; + } Ok(TreeNodeRecursion::Continue) } } @@ -187,6 +194,22 @@ mod tests { let predicate = visitor.get_predicate(); assert_eq!(predicate, None); } + + #[test] + fn test_predicate_conversion_with_single_condition_rev() { + let sql = "1 < foo"; + let df_schema = create_test_schema(); + let expr = SessionContext::new() + .parse_sql_expr(sql, &df_schema) + .unwrap(); + let mut visitor = ExprToPredicateVisitor::new(); + expr.visit(&mut visitor).unwrap(); + let predicate = visitor.get_predicate().unwrap(); + assert_eq!( + predicate, + Reference::new("foo").greater_than(Datum::long(1)) + ); + } #[test] fn test_predicate_conversion_with_and_condition() { let sql = "foo > 1 and bar = 'test'";