test: add ut for DatafusionArrowPredicate
v0y4g3r committed Jun 28, 2023
1 parent ace7b8b commit 7ebb884
Showing 1 changed file with 88 additions and 0 deletions.
src/storage/src/sst/pruning.rs (+88 −0)
@@ -282,6 +282,16 @@ impl ArrowPredicate for PlainTimestampRowFilter

#[cfg(test)]
mod tests {
use arrow_array::ArrayRef;
use datafusion_common::ToDFSchema;
use datafusion_expr::Operator;
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use datatypes::arrow_array::StringArray;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::value::timestamp_to_scalar_value;
use parquet::arrow::arrow_to_parquet_schema;

use super::*;

fn check_unit_lossy(range_unit: TimeUnit, col_unit: TimeUnit, expect: bool) {
@@ -317,4 +327,82 @@ mod tests {
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Microsecond, true);
check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Nanosecond, false);
}

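    /// Builds a physical expression from `expr` against `schema`, wraps it in a
    /// `DatafusionArrowPredicate`, evaluates it on a record batch built from
    /// `columns`, and asserts that the resulting boolean mask equals `expected`.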
fn check_arrow_predicate(
schema: Schema,
expr: datafusion_expr::Expr,
columns: Vec<ArrayRef>,
expected: Vec<Option<bool>>,
) {
let arrow_schema = schema.arrow_schema();
let df_schema = arrow_schema.clone().to_dfschema().unwrap();
let physical_expr = create_physical_expr(
&expr,
&df_schema,
arrow_schema.as_ref(),
&ExecutionProps::default(),
)
.unwrap();
let parquet_schema = arrow_to_parquet_schema(arrow_schema).unwrap();
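        // The projection mask selects both root columns (ts and name) of the test schema.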
let mut predicate = DatafusionArrowPredicate {
physical_expr,
projection_mask: ProjectionMask::roots(&parquet_schema, vec![0, 1]),
};

let batch = arrow_array::RecordBatch::try_new(arrow_schema.clone(), columns).unwrap();

let res = predicate.evaluate(batch).unwrap();
assert_eq!(expected, res.iter().collect::<Vec<_>>());
}

#[test]
fn test_datafusion_predicate() {
let schema = Schema::new(vec![
ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_datatype(TimeUnit::Nanosecond),
false,
),
ColumnSchema::new("name", ConcreteDataType::string_datatype(), true),
]);

let expr = datafusion_expr::and(
datafusion_expr::binary_expr(
datafusion_expr::col("ts"),
Operator::GtEq,
datafusion_expr::lit(timestamp_to_scalar_value(TimeUnit::Nanosecond, Some(10))),
),
datafusion_expr::binary_expr(
datafusion_expr::col("name"),
Operator::Lt,
datafusion_expr::lit("Bob"),
),
);

let ts_arr = Arc::new(TimestampNanosecondArray::from(vec![9, 11])) as Arc<_>;
let name_arr = Arc::new(StringArray::from(vec![Some("Alice"), Some("Charlie")])) as Arc<_>;

let columns = vec![ts_arr, name_arr];
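        // `ts >= 10 AND name < "Bob"`: row 0 fails the timestamp bound and
        // row 1 fails the name bound, so both rows are filtered out.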
check_arrow_predicate(
schema.clone(),
expr,
columns.clone(),
vec![Some(false), Some(false)],
);

let expr = datafusion_expr::and(
datafusion_expr::binary_expr(
datafusion_expr::col("ts"),
Operator::Lt,
datafusion_expr::lit(timestamp_to_scalar_value(TimeUnit::Nanosecond, Some(10))),
),
datafusion_expr::binary_expr(
datafusion_expr::col("name"),
Operator::Lt,
datafusion_expr::lit("Bob"),
),
);

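        // `ts < 10 AND name < "Bob"`: only row 0 (ts = 9, "Alice") satisfies
        // both conditions.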
check_arrow_predicate(schema, expr, columns, vec![Some(true), Some(false)]);
}
}
