Skip to content

Commit

Permalink
chore(pruning): fix data type and column expression for null and row …
Browse files Browse the repository at this point in the history
…counts

chore: fix pruning_predicate in slt tests
  • Loading branch information
appletreeisyellow committed Feb 16, 2024
1 parent 23b03fa commit 9006d1e
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 28 deletions.
73 changes: 49 additions & 24 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -944,19 +944,39 @@ impl<'a> PruningExpressionBuilder<'a> {
.max_column_expr(&self.column, &self.column_expr, self.field)
}

/// Note that this function intentionally overwrites the column expression to [`phys_expr::Column`].
/// i.e. expressions like [`phys_expr::CastExpr`] or [`phys_expr::TryCastExpr`] will be overwritten.
///
/// This is to avoid cases like `cast(x_null_count)` or `try_cast(x_null_count)`.
fn null_count_column_expr(&mut self) -> Result<Arc<dyn PhysicalExpr>> {
// overwrite to [`phys_expr::Column`]
let column_expr = Arc::new(self.column.clone()) as _;

// null_count is DataType::UInt64, which is different from the column's data type (i.e. self.field)
let null_count_field = &Field::new(self.field.name(), DataType::UInt64, true);

self.required_columns.null_count_column_expr(
&self.column,
&self.column_expr,
self.field,
&column_expr,
&null_count_field,
)
}

/// Note that this function intentionally overwrites the column expression to [`phys_expr::Column`].
/// i.e. expressions like [`phys_expr::CastExpr`] or [`phys_expr::TryCastExpr`] will be overwritten.
///
/// This is to avoid cases like `cast(x_row_count)` or `try_cast(x_row_count)`.
fn row_count_column_expr(&mut self) -> Result<Arc<dyn PhysicalExpr>> {
// overwrite to [`phys_expr::Column`]
let column_expr = Arc::new(self.column.clone()) as _;

// row_count is DataType::UInt64, which is different from the column's data type (i.e. self.field)
let row_count_field = &Field::new(self.field.name(), DataType::UInt64, true);

self.required_columns.row_count_column_expr(
&self.column,
&self.column_expr,
self.field,
&column_expr,
row_count_field,
)
}
}
Expand Down Expand Up @@ -1396,10 +1416,15 @@ fn build_statistics_expr(
/// ELSE x_min <= 10 AND 10 <= x_max
/// END
/// ````
///
/// If the column is known to be all nulls, then the expression
/// `x_null_count = x_row_count` will be true, which will cause the
/// case expression to return false. Therefore, prune out the container.
fn wrap_case_expr(
statistics_expr: Arc<dyn PhysicalExpr>,
expr_builder: &mut PruningExpressionBuilder,
) -> Result<Arc<dyn PhysicalExpr>> {
// x_null_count = x_row_count
let when_null_count_eq_row_count = Arc::new(phys_expr::BinaryExpr::new(
expr_builder.null_count_column_expr()?,
Operator::Eq,
Expand Down Expand Up @@ -2244,7 +2269,7 @@ mod tests {
)
);
// c1 < 1 should add c1_null_count
let c1_null_count_field = Field::new("c1_null_count", DataType::Int32, false);
let c1_null_count_field = Field::new("c1_null_count", DataType::UInt64, false);
assert_eq!(
required_columns.columns[1],
(
Expand All @@ -2254,7 +2279,7 @@ mod tests {
)
);
// c1 < 1 should add c1_row_count
let c1_row_count_field = Field::new("c1_row_count", DataType::Int32, false);
let c1_row_count_field = Field::new("c1_row_count", DataType::UInt64, false);
assert_eq!(
required_columns.columns[2],
(
Expand Down Expand Up @@ -2283,7 +2308,7 @@ mod tests {
)
);
// c2 = 2 should add c2_null_count
let c2_null_count_field = Field::new("c2_null_count", DataType::Int32, false);
let c2_null_count_field = Field::new("c2_null_count", DataType::UInt64, false);
assert_eq!(
required_columns.columns[5],
(
Expand All @@ -2293,7 +2318,7 @@ mod tests {
)
);
// c2 = 2 should add c2_row_count
let c2_row_count_field = Field::new("c2_row_count", DataType::Int32, false);
let c2_row_count_field = Field::new("c2_row_count", DataType::UInt64, false);
assert_eq!(
required_columns.columns[6],
(
Expand Down Expand Up @@ -2466,12 +2491,15 @@ mod tests {
Ok(())
}

// TODO chunchun: add test for two different columns
// e.g. c1 = 3 and c2 = 4
// cast(c1) = 3 and cast(c2) = 4

#[test]
fn row_group_predicate_cast() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr =
"CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
let expected_expr = "CASE \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \
END";

Expand All @@ -2487,9 +2515,8 @@ mod tests {
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);

let expected_expr =
"CASE \
WHEN TRY_CAST(c1_null_count@1 AS Int64) = TRY_CAST(c1_row_count@2 AS Int64) THEN false \
let expected_expr = "CASE \
WHEN c1_null_count@1 = c1_row_count@2 THEN false \
ELSE TRY_CAST(c1_max@0 AS Int64) > 1 \
END";

Expand Down Expand Up @@ -2523,17 +2550,16 @@ mod tests {
],
false,
));
let expected_expr =
"CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
let expected_expr = "CASE \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \
END \
OR CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) \
END \
OR CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64) \
END";
let predicate_expr =
Expand All @@ -2549,17 +2575,16 @@ mod tests {
],
true,
));
let expected_expr =
"CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
let expected_expr = "CASE \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64) \
END \
AND CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64) \
END \
AND CASE \
WHEN CAST(c1_null_count@2 AS Int64) = CAST(c1_row_count@3 AS Int64) THEN false \
WHEN c1_null_count@2 = c1_row_count@3 THEN false \
ELSE CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64) \
END";
let predicate_expr =
Expand Down
8 changes: 4 additions & 4 deletions datafusion/sqllogictest/test_files/repartition_scan.slt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Filter: parquet_table.column1 != Int32(42)
physical_plan
CoalesceBatchesExec: target_batch_size=8192
--FilterExec: column1@0 != 42
----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)]
----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)]

# disable round robin repartitioning
statement ok
Expand All @@ -77,7 +77,7 @@ Filter: parquet_table.column1 != Int32(42)
physical_plan
CoalesceBatchesExec: target_batch_size=8192
--FilterExec: column1@0 != 42
----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)]
----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)]

# enable round robin repartitioning again
statement ok
Expand All @@ -102,7 +102,7 @@ SortPreservingMergeExec: [column1@0 ASC NULLS LAST]
--SortExec: expr=[column1@0 ASC NULLS LAST]
----CoalesceBatchesExec: target_batch_size=8192
------FilterExec: column1@0 != 42
--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..205], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:205..405, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..210], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:210..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)]
--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..205], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:205..405, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..210], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:210..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)]


## Read the files as though they are ordered
Expand Down Expand Up @@ -138,7 +138,7 @@ physical_plan
SortPreservingMergeExec: [column1@0 ASC NULLS LAST]
--CoalesceBatchesExec: target_batch_size=8192
----FilterExec: column1@0 != 42
------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..207], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:207..414], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:202..405]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)]
------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..207], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:207..414], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:202..405]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)]

# Cleanup
statement ok
Expand Down

0 comments on commit 9006d1e

Please sign in to comment.