From 559d1f73a2819b775b9c8757cff08522ec216f14 Mon Sep 17 00:00:00 2001 From: "Lei, HUANG" <6406592+v0y4g3r@users.noreply.github.com> Date: Wed, 28 Jun 2023 20:14:37 +0800 Subject: [PATCH] feat: push all possible filters down to parquet exec (#1839) * feat: push all possible filters down to parquet exec * fix: project * test: add ut for DatafusionArrowPredicate * fix: according to CR comments --- Cargo.lock | 2 + src/datatypes/src/value.rs | 2 +- src/query/src/tests/time_range_filter_test.rs | 2 +- src/storage/Cargo.toml | 2 + src/storage/src/chunk.rs | 15 +- src/storage/src/compaction/writer.rs | 147 +++++-- src/storage/src/error.rs | 7 + src/storage/src/region/tests/flush.rs | 7 +- src/storage/src/sst.rs | 1 + src/storage/src/sst/parquet.rs | 274 ++---------- src/storage/src/sst/pruning.rs | 408 ++++++++++++++++++ src/table/src/predicate.rs | 133 ++++-- 12 files changed, 658 insertions(+), 342 deletions(-) create mode 100644 src/storage/src/sst/pruning.rs diff --git a/Cargo.lock b/Cargo.lock index a7ae3e45c097..afa26a81abe4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9150,8 +9150,10 @@ dependencies = [ "common-test-util", "common-time", "criterion 0.3.6", + "datafusion", "datafusion-common", "datafusion-expr", + "datafusion-physical-expr", "datatypes", "futures", "futures-util", diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index d436c3934c3e..d087d78a271b 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -294,7 +294,7 @@ fn new_item_field(data_type: ArrowDataType) -> Field { Field::new("item", data_type, false) } -fn timestamp_to_scalar_value(unit: TimeUnit, val: Option) -> ScalarValue { +pub fn timestamp_to_scalar_value(unit: TimeUnit, val: Option) -> ScalarValue { match unit { TimeUnit::Second => ScalarValue::TimestampSecond(val, None), TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(val, None), diff --git a/src/query/src/tests/time_range_filter_test.rs b/src/query/src/tests/time_range_filter_test.rs index 49d32bb712bb..b8941933866f 100644 --- a/src/query/src/tests/time_range_filter_test.rs +++ b/src/query/src/tests/time_range_filter_test.rs @@ -140,7 +140,7 @@ impl TimeRangeTester { let _ = exec_selection(self.engine.clone(), sql).await; let filters = self.table.get_filters().await; - let range = TimeRangePredicateBuilder::new("ts", &filters).build(); + let range = TimeRangePredicateBuilder::new("ts", TimeUnit::Millisecond, &filters).build(); assert_eq!(expect, range); } } diff --git a/src/storage/Cargo.toml b/src/storage/Cargo.toml index 2aae3f549adb..47814c45a74e 100644 --- a/src/storage/Cargo.toml +++ b/src/storage/Cargo.toml @@ -23,6 +23,8 @@ common-time = { path = "../common/time" } datatypes = { path = "../datatypes" } datafusion-common.workspace = true datafusion-expr.workspace = true +datafusion-physical-expr.workspace = true +datafusion.workspace = true futures.workspace = true futures-util.workspace = true itertools.workspace = true diff --git a/src/storage/src/chunk.rs b/src/storage/src/chunk.rs index a4370efdf272..f6318bd5ac15 100644 --- a/src/storage/src/chunk.rs +++ b/src/storage/src/chunk.rs @@ -226,10 +226,16 @@ impl ChunkReaderBuilder { reader_builder = reader_builder.push_batch_iter(iter); } + let predicate = Predicate::try_new( + self.filters.clone(), + self.schema.store_schema().schema().clone(), + ) + .context(error::BuildPredicateSnafu)?; + let read_opts = ReadOptions { batch_size: self.iter_ctx.batch_size, projected_schema: schema.clone(), - predicate: Predicate::new(self.filters.clone()), + predicate, 
            time_range: *time_range,
        };
        for file in &self.files_to_read {
@@ -270,7 +276,12 @@ impl ChunkReaderBuilder {
     /// Build time range predicate from schema and filters.
     pub fn build_time_range_predicate(&self) -> TimestampRange {
         let Some(ts_col) = self.schema.user_schema().timestamp_column() else { return TimestampRange::min_to_max() };
-        TimeRangePredicateBuilder::new(&ts_col.name, &self.filters).build()
+        let unit = ts_col
+            .data_type
+            .as_timestamp()
+            .expect("Timestamp column must have timestamp-compatible type")
+            .unit();
+        TimeRangePredicateBuilder::new(&ts_col.name, unit, &self.filters).build()
     }
 
     /// Check if SST file's time range matches predicate.
diff --git a/src/storage/src/compaction/writer.rs b/src/storage/src/compaction/writer.rs
index 0f9bf766a42a..fcf6f3bf4278 100644
--- a/src/storage/src/compaction/writer.rs
+++ b/src/storage/src/compaction/writer.rs
@@ -13,8 +13,9 @@
 // limitations under the License.
 
 use common_query::logical_plan::{DfExpr, Expr};
-use datafusion_common::ScalarValue;
-use datafusion_expr::{BinaryExpr, Operator};
+use common_time::timestamp::TimeUnit;
+use datafusion_expr::Operator;
+use datatypes::value::timestamp_to_scalar_value;
 
 use crate::chunk::{ChunkReaderBuilder, ChunkReaderImpl};
 use crate::error;
@@ -31,53 +32,84 @@ pub(crate) async fn build_sst_reader(
 ) -> error::Result<ChunkReaderImpl> {
     // TODO(hl): Schemas in different SSTs may differ, thus we should infer
     // timestamp column name from Parquet metadata.
-    let ts_col_name = schema
-        .user_schema()
-        .timestamp_column()
-        .unwrap()
-        .name
-        .clone();
+
+    // safety: Region schema's timestamp column must be present
+    let ts_col = schema.user_schema().timestamp_column().unwrap();
+    let ts_col_unit = ts_col.data_type.as_timestamp().unwrap().unit();
+    let ts_col_name = ts_col.name.clone();
 
     ChunkReaderBuilder::new(schema, sst_layer)
         .pick_ssts(files)
-        .filters(vec![build_time_range_filter(
-            lower_sec_inclusive,
-            upper_sec_exclusive,
-            &ts_col_name,
-        )])
+        .filters(
+            build_time_range_filter(
+                lower_sec_inclusive,
+                upper_sec_exclusive,
+                &ts_col_name,
+                ts_col_unit,
+            )
+            .into_iter()
+            .collect(),
+        )
         .build()
         .await
 }
 
-fn build_time_range_filter(low_sec: i64, high_sec: i64, ts_col_name: &str) -> Expr {
-    let ts_col = Box::new(DfExpr::Column(datafusion_common::Column::from_name(
-        ts_col_name,
-    )));
-    let lower_bound_expr = Box::new(DfExpr::Literal(ScalarValue::TimestampSecond(
-        Some(low_sec),
-        None,
-    )));
-
-    let upper_bound_expr = Box::new(DfExpr::Literal(ScalarValue::TimestampSecond(
-        Some(high_sec),
-        None,
-    )));
-
-    let expr = DfExpr::BinaryExpr(BinaryExpr {
-        left: Box::new(DfExpr::BinaryExpr(BinaryExpr {
-            left: ts_col.clone(),
-            op: Operator::GtEq,
-            right: lower_bound_expr,
-        })),
-        op: Operator::And,
-        right: Box::new(DfExpr::BinaryExpr(BinaryExpr {
-            left: ts_col,
-            op: Operator::Lt,
-            right: upper_bound_expr,
-        })),
-    });
-
-    Expr::from(expr)
+/// Builds a time range filter expr from the lower (inclusive) and upper (exclusive) bounds.
+/// Returns `None` if the converted time range overflows.
+fn build_time_range_filter(
+    low_sec: i64,
+    high_sec: i64,
+    ts_col_name: &str,
+    ts_col_unit: TimeUnit,
+) -> Option<Expr> {
+    debug_assert!(low_sec <= high_sec);
+    let ts_col = DfExpr::Column(datafusion_common::Column::from_name(ts_col_name));
+
+    // Converting seconds to any other unit won't lose precision,
+    // so only overflow needs to be handled here.
+ let low_ts = common_time::Timestamp::new_second(low_sec) + .convert_to(ts_col_unit) + .map(|ts| ts.value()); + let high_ts = common_time::Timestamp::new_second(high_sec) + .convert_to(ts_col_unit) + .map(|ts| ts.value()); + + let expr = match (low_ts, high_ts) { + (Some(low), Some(high)) => { + let lower_bound_expr = + DfExpr::Literal(timestamp_to_scalar_value(ts_col_unit, Some(low))); + let upper_bound_expr = + DfExpr::Literal(timestamp_to_scalar_value(ts_col_unit, Some(high))); + Some(datafusion_expr::and( + datafusion_expr::binary_expr(ts_col.clone(), Operator::GtEq, lower_bound_expr), + datafusion_expr::binary_expr(ts_col, Operator::Lt, upper_bound_expr), + )) + } + + (Some(low), None) => { + let lower_bound_expr = + datafusion_expr::lit(timestamp_to_scalar_value(ts_col_unit, Some(low))); + Some(datafusion_expr::binary_expr( + ts_col, + Operator::GtEq, + lower_bound_expr, + )) + } + + (None, Some(high)) => { + let upper_bound_expr = + datafusion_expr::lit(timestamp_to_scalar_value(ts_col_unit, Some(high))); + Some(datafusion_expr::binary_expr( + ts_col, + Operator::Lt, + upper_bound_expr, + )) + } + + (None, None) => None, + }; + + expr.map(Expr::from) } #[cfg(test)] @@ -490,4 +522,35 @@ mod tests { assert_eq!(timestamps_in_outputs, timestamps_in_inputs); } + + #[test] + fn test_build_time_range_filter() { + assert!(build_time_range_filter(i64::MIN, i64::MAX, "ts", TimeUnit::Nanosecond).is_none()); + + assert_eq!( + Expr::from(datafusion_expr::binary_expr( + datafusion_expr::col("ts"), + Operator::Lt, + datafusion_expr::lit(timestamp_to_scalar_value( + TimeUnit::Nanosecond, + Some(TimeUnit::Second.factor() as i64 / TimeUnit::Nanosecond.factor() as i64) + )) + )), + build_time_range_filter(i64::MIN, 1, "ts", TimeUnit::Nanosecond).unwrap() + ); + + assert_eq!( + Expr::from(datafusion_expr::binary_expr( + datafusion_expr::col("ts"), + Operator::GtEq, + datafusion_expr::lit(timestamp_to_scalar_value( + TimeUnit::Nanosecond, + Some( + 2 * TimeUnit::Second.factor() as i64 / TimeUnit::Nanosecond.factor() as i64 + ) + )) + )), + build_time_range_filter(2, i64::MAX, "ts", TimeUnit::Nanosecond).unwrap() + ); + } } diff --git a/src/storage/src/error.rs b/src/storage/src/error.rs index 6f8c6977d5ba..c483022fe106 100644 --- a/src/storage/src/error.rs +++ b/src/storage/src/error.rs @@ -522,6 +522,12 @@ pub enum Error { source: ArrowError, location: Location, }, + + #[snafu(display("Failed to build scan predicate, source: {}", source))] + BuildPredicate { + source: table::error::Error, + location: Location, + }, } pub type Result = std::result::Result; @@ -621,6 +627,7 @@ impl ErrorExt for Error { TtlCalculation { source, .. } => source.status_code(), ConvertColumnsToRows { .. } | SortArrays { .. } => StatusCode::Unexpected, + BuildPredicate { source, .. 
} => source.status_code(), } } diff --git a/src/storage/src/region/tests/flush.rs b/src/storage/src/region/tests/flush.rs index 29a060abd722..606ab8bc76fe 100644 --- a/src/storage/src/region/tests/flush.rs +++ b/src/storage/src/region/tests/flush.rs @@ -21,7 +21,9 @@ use arrow::compute::SortOptions; use common_query::prelude::Expr; use common_recordbatch::OrderOption; use common_test_util::temp_dir::create_temp_dir; +use common_time::timestamp::TimeUnit; use datafusion_common::Column; +use datatypes::value::timestamp_to_scalar_value; use log_store::raft_engine::log_store::RaftEngineLogStore; use store_api::storage::{FlushContext, FlushReason, OpenOptions, Region, ScanRequest}; @@ -404,7 +406,10 @@ async fn test_flush_and_query_empty() { filters: vec![Expr::from(datafusion_expr::binary_expr( DfExpr::Column(Column::from("timestamp")), datafusion_expr::Operator::GtEq, - datafusion_expr::lit(20000), + datafusion_expr::lit(timestamp_to_scalar_value( + TimeUnit::Millisecond, + Some(20000), + )), ))], output_ordering: Some(vec![OrderOption { name: "timestamp".to_string(), diff --git a/src/storage/src/sst.rs b/src/storage/src/sst.rs index 49aefe717cd3..dff3c40e5a59 100644 --- a/src/storage/src/sst.rs +++ b/src/storage/src/sst.rs @@ -13,6 +13,7 @@ // limitations under the License. pub(crate) mod parquet; +mod pruning; mod stream_writer; use std::collections::HashMap; diff --git a/src/storage/src/sst/parquet.rs b/src/storage/src/sst/parquet.rs index e2dd03544148..1f906ecbcec4 100644 --- a/src/storage/src/sst/parquet.rs +++ b/src/storage/src/sst/parquet.rs @@ -18,12 +18,6 @@ use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; -use arrow::datatypes::DataType; -use arrow_array::types::Int64Type; -use arrow_array::{ - Array, PrimitiveArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, -}; use async_compat::CompatExt; use async_stream::try_stream; use async_trait::async_trait; @@ -31,19 +25,16 @@ use common_telemetry::{debug, error}; use common_time::range::TimestampRange; use common_time::timestamp::TimeUnit; use common_time::Timestamp; -use datatypes::arrow::array::BooleanArray; -use datatypes::arrow::error::ArrowError; use datatypes::arrow::record_batch::RecordBatch; use datatypes::prelude::ConcreteDataType; use futures_util::{Stream, StreamExt, TryStreamExt}; use object_store::ObjectStore; -use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter}; use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; use parquet::basic::{Compression, Encoding, ZstdLevel}; use parquet::file::metadata::KeyValue; use parquet::file::properties::WriterProperties; use parquet::format::FileMetaData; -use parquet::schema::types::{ColumnPath, SchemaDescriptor}; +use parquet::schema::types::ColumnPath; use snafu::{OptionExt, ResultExt}; use store_api::storage::consts::SEQUENCE_COLUMN_NAME; use table::predicate::Predicate; @@ -54,6 +45,7 @@ use crate::read::{Batch, BatchReader}; use crate::schema::compat::ReadAdapter; use crate::schema::{ProjectedSchemaRef, StoreSchema}; use crate::sst; +use crate::sst::pruning::build_row_filter; use crate::sst::stream_writer::BufferedWriter; use crate::sst::{FileHandle, Source, SstInfo}; @@ -277,10 +269,7 @@ impl ParquetReader { let pruned_row_groups = self .predicate - .prune_row_groups( - store_schema.schema().clone(), - builder.metadata().row_groups(), - ) + .prune_row_groups(builder.metadata().row_groups()) .into_iter() .enumerate() .filter_map(|(idx, valid)| if valid { Some(idx) } else { 
None }) @@ -288,15 +277,18 @@ impl ParquetReader { let parquet_schema_desc = builder.metadata().file_metadata().schema_descr_ptr(); - let projection = ProjectionMask::roots(&parquet_schema_desc, adapter.fields_to_read()); + let projection_mask = ProjectionMask::roots(&parquet_schema_desc, adapter.fields_to_read()); let mut builder = builder - .with_projection(projection) + .with_projection(projection_mask.clone()) .with_row_groups(pruned_row_groups); - // if time range row filter is present, we can push down the filter to reduce rows to scan. - if let Some(row_filter) = - build_time_range_row_filter(self.time_range, &store_schema, &parquet_schema_desc) - { + if let Some(row_filter) = build_row_filter( + self.time_range, + &self.predicate, + &store_schema, + &parquet_schema_desc, + projection_mask, + ) { builder = builder.with_row_filter(row_filter); } @@ -314,198 +306,6 @@ impl ParquetReader { } } -/// Builds time range row filter. -fn build_time_range_row_filter( - time_range: TimestampRange, - store_schema: &Arc, - schema_desc: &SchemaDescriptor, -) -> Option { - let ts_col_idx = store_schema.timestamp_index(); - - let ts_col = store_schema.columns().get(ts_col_idx)?; - - let ts_col_unit = match &ts_col.desc.data_type { - ConcreteDataType::Int64(_) => TimeUnit::Millisecond, - ConcreteDataType::Timestamp(ts_type) => ts_type.unit(), - _ => unreachable!(), - }; - - let projection = ProjectionMask::roots(schema_desc, vec![ts_col_idx]); - - // checks if converting time range unit into ts col unit will result into rounding error. - if time_unit_lossy(&time_range, ts_col_unit) { - let filter = RowFilter::new(vec![Box::new(PlainTimestampRowFilter::new( - time_range, projection, - ))]); - return Some(filter); - } - - // If any of the conversion overflows, we cannot use arrow's computation method, instead - // we resort to plain filter that compares timestamp with given range, less efficient, - // but simpler. - // TODO(hl): If the range is gt_eq/lt, we also use PlainTimestampRowFilter, but these cases - // can also use arrow's gt_eq_scalar/lt_scalar methods. - let row_filter = if let (Some(lower), Some(upper)) = ( - time_range - .start() - .and_then(|s| s.convert_to(ts_col_unit)) - .map(|t| t.value()), - time_range - .end() - .and_then(|s| s.convert_to(ts_col_unit)) - .map(|t| t.value()), - ) { - Box::new(FastTimestampRowFilter::new(projection, lower, upper)) as _ - } else { - Box::new(PlainTimestampRowFilter::new(time_range, projection)) as _ - }; - let filter = RowFilter::new(vec![row_filter]); - Some(filter) -} - -fn time_unit_lossy(range: &TimestampRange, ts_col_unit: TimeUnit) -> bool { - range - .start() - .map(|start| start.unit().factor() < ts_col_unit.factor()) - .unwrap_or(false) - || range - .end() - .map(|end| end.unit().factor() < ts_col_unit.factor()) - .unwrap_or(false) -} - -/// `FastTimestampRowFilter` is used to filter rows within given timestamp range when reading -/// row groups from parquet files, while avoids fetching all columns from SSTs file. -struct FastTimestampRowFilter { - lower_bound: i64, - upper_bound: i64, - projection: ProjectionMask, -} - -impl FastTimestampRowFilter { - fn new(projection: ProjectionMask, lower_bound: i64, upper_bound: i64) -> Self { - Self { - lower_bound, - upper_bound, - projection, - } - } -} - -impl ArrowPredicate for FastTimestampRowFilter { - fn projection(&self) -> &ProjectionMask { - &self.projection - } - - /// Selects the rows matching given time range. 
- fn evaluate(&mut self, batch: RecordBatch) -> std::result::Result { - // the projection has only timestamp column, so we can safely take the first column in batch. - let ts_col = batch.column(0); - - macro_rules! downcast_and_compute { - ($typ: ty) => { - { - let ts_col = ts_col - .as_any() - .downcast_ref::<$typ>() - .unwrap(); // safety: we've checked the data type of timestamp column. - let left = arrow::compute::gt_eq_scalar(ts_col, self.lower_bound)?; - let right = arrow::compute::lt_scalar(ts_col, self.upper_bound)?; - arrow::compute::and(&left, &right) - } - }; - } - - match ts_col.data_type() { - DataType::Timestamp(unit, _) => match unit { - arrow::datatypes::TimeUnit::Second => { - downcast_and_compute!(TimestampSecondArray) - } - arrow::datatypes::TimeUnit::Millisecond => { - downcast_and_compute!(TimestampMillisecondArray) - } - arrow::datatypes::TimeUnit::Microsecond => { - downcast_and_compute!(TimestampMicrosecondArray) - } - arrow::datatypes::TimeUnit::Nanosecond => { - downcast_and_compute!(TimestampNanosecondArray) - } - }, - DataType::Int64 => downcast_and_compute!(PrimitiveArray), - _ => { - unreachable!() - } - } - } -} - -/// [PlainTimestampRowFilter] iterates each element in timestamp column, build a [Timestamp] struct -/// and checks if given time range contains the timestamp. -struct PlainTimestampRowFilter { - time_range: TimestampRange, - projection: ProjectionMask, -} - -impl PlainTimestampRowFilter { - fn new(time_range: TimestampRange, projection: ProjectionMask) -> Self { - Self { - time_range, - projection, - } - } -} - -impl ArrowPredicate for PlainTimestampRowFilter { - fn projection(&self) -> &ProjectionMask { - &self.projection - } - - fn evaluate(&mut self, batch: RecordBatch) -> std::result::Result { - // the projection has only timestamp column, so we can safely take the first column in batch. - let ts_col = batch.column(0); - - macro_rules! downcast_and_compute { - ($array_ty: ty, $unit: ident) => {{ - let ts_col = ts_col - .as_any() - .downcast_ref::<$array_ty>() - .unwrap(); // safety: we've checked the data type of timestamp column. 
- Ok(BooleanArray::from_iter(ts_col.iter().map(|ts| { - ts.map(|val| { - Timestamp::new(val, TimeUnit::$unit) - }).map(|ts| { - self.time_range.contains(&ts) - }) - }))) - - }}; - } - - match ts_col.data_type() { - DataType::Timestamp(unit, _) => match unit { - arrow::datatypes::TimeUnit::Second => { - downcast_and_compute!(TimestampSecondArray, Second) - } - arrow::datatypes::TimeUnit::Millisecond => { - downcast_and_compute!(TimestampMillisecondArray, Millisecond) - } - arrow::datatypes::TimeUnit::Microsecond => { - downcast_and_compute!(TimestampMicrosecondArray, Microsecond) - } - arrow::datatypes::TimeUnit::Nanosecond => { - downcast_and_compute!(TimestampNanosecondArray, Nanosecond) - } - }, - DataType::Int64 => { - downcast_and_compute!(PrimitiveArray, Millisecond) - } - _ => { - unreachable!() - } - } - } -} - pub type SendableChunkStream = Pin> + Send>>; pub struct ChunkStream { @@ -740,11 +540,12 @@ mod tests { let operator = create_object_store(dir.path().to_str().unwrap()); let projected_schema = Arc::new(ProjectedSchema::new(schema, Some(vec![1])).unwrap()); + let user_schema = projected_schema.projected_user_schema().clone(); let reader = ParquetReader::new( sst_file_handle, operator, projected_schema, - Predicate::empty(), + Predicate::empty(user_schema), TimestampRange::min_to_max(), ); @@ -826,11 +627,12 @@ mod tests { let operator = create_object_store(dir.path().to_str().unwrap()); let projected_schema = Arc::new(ProjectedSchema::new(schema, Some(vec![1])).unwrap()); + let user_schema = projected_schema.projected_user_schema().clone(); let reader = ParquetReader::new( file_handle, operator, projected_schema, - Predicate::empty(), + Predicate::empty(user_schema), TimestampRange::min_to_max(), ); @@ -854,8 +656,14 @@ mod tests { range: TimestampRange, expect: Vec, ) { - let reader = - ParquetReader::new(file_handle, object_store, schema, Predicate::empty(), range); + let store_schema = schema.schema_to_read().clone(); + let reader = ParquetReader::new( + file_handle, + object_store, + schema, + Predicate::empty(store_schema.schema().clone()), + range, + ); let mut stream = reader.chunk_stream().await.unwrap(); let result = stream.next_batch().await; @@ -981,16 +789,6 @@ mod tests { .await; } - fn check_unit_lossy(range_unit: TimeUnit, col_unit: TimeUnit, expect: bool) { - assert_eq!( - expect, - time_unit_lossy( - &TimestampRange::with_unit(0, 1, range_unit).unwrap(), - col_unit - ) - ) - } - #[tokio::test] async fn test_write_empty_file() { common_telemetry::init_default_ut_logging(); @@ -1014,28 +812,4 @@ mod tests { // The file should not exist when no row has been written. 
assert!(!object_store.is_exist(sst_file_name).await.unwrap()); } - - #[test] - fn test_time_unit_lossy() { - // converting a range with unit second to millisecond will not cause rounding error - check_unit_lossy(TimeUnit::Second, TimeUnit::Second, false); - check_unit_lossy(TimeUnit::Second, TimeUnit::Millisecond, false); - check_unit_lossy(TimeUnit::Second, TimeUnit::Microsecond, false); - check_unit_lossy(TimeUnit::Second, TimeUnit::Nanosecond, false); - - check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Second, true); - check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Millisecond, false); - check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Microsecond, false); - check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Nanosecond, false); - - check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Second, true); - check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Millisecond, true); - check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Microsecond, false); - check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Nanosecond, false); - - check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Second, true); - check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Millisecond, true); - check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Microsecond, true); - check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Nanosecond, false); - } } diff --git a/src/storage/src/sst/pruning.rs b/src/storage/src/sst/pruning.rs new file mode 100644 index 000000000000..499d04ebde37 --- /dev/null +++ b/src/storage/src/sst/pruning.rs @@ -0,0 +1,408 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arrow::array::{ + PrimitiveArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, +}; +use arrow::datatypes::{DataType, Int64Type}; +use arrow::error::ArrowError; +use arrow_array::{Array, BooleanArray, RecordBatch}; +use common_time::range::TimestampRange; +use common_time::timestamp::TimeUnit; +use common_time::Timestamp; +use datafusion::physical_plan::PhysicalExpr; +use datatypes::prelude::ConcreteDataType; +use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter}; +use parquet::arrow::ProjectionMask; +use parquet::schema::types::SchemaDescriptor; +use table::predicate::Predicate; + +use crate::error; +use crate::schema::StoreSchema; + +/// Builds row filters according to predicates. 
+pub(crate) fn build_row_filter(
+    time_range: TimestampRange,
+    predicate: &Predicate,
+    store_schema: &Arc<StoreSchema>,
+    schema_desc: &SchemaDescriptor,
+    projection_mask: ProjectionMask,
+) -> Option<RowFilter> {
+    let ts_col_idx = store_schema.timestamp_index();
+    let ts_col = store_schema.columns().get(ts_col_idx)?;
+    let ts_col_unit = match &ts_col.desc.data_type {
+        ConcreteDataType::Int64(_) => TimeUnit::Millisecond,
+        ConcreteDataType::Timestamp(ts_type) => ts_type.unit(),
+        _ => unreachable!(),
+    };
+
+    let ts_col_projection = ProjectionMask::roots(schema_desc, vec![ts_col_idx]);
+
+    // Checks if converting the time range unit into the ts col unit will result in a rounding error.
+    if time_unit_lossy(&time_range, ts_col_unit) {
+        let filter = RowFilter::new(vec![Box::new(PlainTimestampRowFilter::new(
+            time_range,
+            ts_col_projection,
+        ))]);
+        return Some(filter);
+    }
+
+    // If any of the conversions overflows, we cannot use arrow's computation method; instead
+    // we resort to a plain filter that compares timestamps with the given range, which is less
+    // efficient but simpler.
+    // TODO(hl): If the range is gt_eq/lt, we also use PlainTimestampRowFilter, but these cases
+    // can also use arrow's gt_eq_scalar/lt_scalar methods.
+    let time_range_row_filter = if let (Some(lower), Some(upper)) = (
+        time_range
+            .start()
+            .and_then(|s| s.convert_to(ts_col_unit))
+            .map(|t| t.value()),
+        time_range
+            .end()
+            .and_then(|s| s.convert_to(ts_col_unit))
+            .map(|t| t.value()),
+    ) {
+        Box::new(FastTimestampRowFilter::new(ts_col_projection, lower, upper)) as _
+    } else {
+        Box::new(PlainTimestampRowFilter::new(time_range, ts_col_projection)) as _
+    };
+    let mut predicates = vec![time_range_row_filter];
+    if let Ok(datafusion_filters) = predicate_to_row_filter(predicate, projection_mask) {
+        predicates.extend(datafusion_filters);
+    }
+    let filter = RowFilter::new(predicates);
+    Some(filter)
+}
+
+fn predicate_to_row_filter(
+    predicate: &Predicate,
+    projection_mask: ProjectionMask,
+) -> error::Result<Vec<Box<dyn ArrowPredicate>>> {
+    let mut datafusion_predicates = Vec::with_capacity(predicate.exprs().len());
+    for expr in predicate.exprs() {
+        datafusion_predicates.push(Box::new(DatafusionArrowPredicate {
+            projection_mask: projection_mask.clone(),
+            physical_expr: expr.clone(),
+        }) as _);
+    }
+
+    Ok(datafusion_predicates)
+}
+
+#[derive(Debug)]
+struct DatafusionArrowPredicate {
+    projection_mask: ProjectionMask,
+    physical_expr: Arc<dyn PhysicalExpr>,
+}
+
+impl ArrowPredicate for DatafusionArrowPredicate {
+    fn projection(&self) -> &ProjectionMask {
+        &self.projection_mask
+    }
+
+    fn evaluate(&mut self, batch: RecordBatch) -> Result<BooleanArray, ArrowError> {
+        match self
+            .physical_expr
+            .evaluate(&batch)
+            .map(|v| v.into_array(batch.num_rows()))
+        {
+            Ok(array) => {
+                let bool_arr = array
+                    .as_any()
+                    .downcast_ref::<BooleanArray>()
+                    .ok_or(ArrowError::CastError(
+                        "Physical expr evaluated res is not a boolean array".to_string(),
+                    ))?
+                    .clone();
+                Ok(bool_arr)
+            }
+            Err(e) => Err(ArrowError::ComputeError(format!(
+                "Error evaluating filter predicate: {e:?}"
+            ))),
+        }
+    }
+}
+
+fn time_unit_lossy(range: &TimestampRange, ts_col_unit: TimeUnit) -> bool {
+    range
+        .start()
+        .map(|start| start.unit().factor() < ts_col_unit.factor())
+        .unwrap_or(false)
+        || range
+            .end()
+            .map(|end| end.unit().factor() < ts_col_unit.factor())
+            .unwrap_or(false)
+}
+
+/// `FastTimestampRowFilter` is used to filter rows within a given timestamp range when reading
+/// row groups from parquet files, while avoiding the cost of fetching all columns from the SST file.
+struct FastTimestampRowFilter { + lower_bound: i64, + upper_bound: i64, + projection: ProjectionMask, +} + +impl FastTimestampRowFilter { + fn new(projection: ProjectionMask, lower_bound: i64, upper_bound: i64) -> Self { + Self { + lower_bound, + upper_bound, + projection, + } + } +} + +impl ArrowPredicate for FastTimestampRowFilter { + fn projection(&self) -> &ProjectionMask { + &self.projection + } + + /// Selects the rows matching given time range. + fn evaluate(&mut self, batch: RecordBatch) -> Result { + // the projection has only timestamp column, so we can safely take the first column in batch. + let ts_col = batch.column(0); + + macro_rules! downcast_and_compute { + ($typ: ty) => { + { + let ts_col = ts_col + .as_any() + .downcast_ref::<$typ>() + .unwrap(); // safety: we've checked the data type of timestamp column. + let left = arrow::compute::gt_eq_scalar(ts_col, self.lower_bound)?; + let right = arrow::compute::lt_scalar(ts_col, self.upper_bound)?; + arrow::compute::and(&left, &right) + } + }; + } + + match ts_col.data_type() { + DataType::Timestamp(unit, _) => match unit { + arrow::datatypes::TimeUnit::Second => { + downcast_and_compute!(TimestampSecondArray) + } + arrow::datatypes::TimeUnit::Millisecond => { + downcast_and_compute!(TimestampMillisecondArray) + } + arrow::datatypes::TimeUnit::Microsecond => { + downcast_and_compute!(TimestampMicrosecondArray) + } + arrow::datatypes::TimeUnit::Nanosecond => { + downcast_and_compute!(TimestampNanosecondArray) + } + }, + DataType::Int64 => downcast_and_compute!(PrimitiveArray), + _ => { + unreachable!() + } + } + } +} + +/// [PlainTimestampRowFilter] iterates each element in timestamp column, build a [Timestamp] struct +/// and checks if given time range contains the timestamp. +struct PlainTimestampRowFilter { + time_range: TimestampRange, + projection: ProjectionMask, +} + +impl PlainTimestampRowFilter { + fn new(time_range: TimestampRange, projection: ProjectionMask) -> Self { + Self { + time_range, + projection, + } + } +} + +impl ArrowPredicate for PlainTimestampRowFilter { + fn projection(&self) -> &ProjectionMask { + &self.projection + } + + fn evaluate(&mut self, batch: RecordBatch) -> Result { + // the projection has only timestamp column, so we can safely take the first column in batch. + let ts_col = batch.column(0); + + macro_rules! downcast_and_compute { + ($array_ty: ty, $unit: ident) => {{ + let ts_col = ts_col + .as_any() + .downcast_ref::<$array_ty>() + .unwrap(); // safety: we've checked the data type of timestamp column. 
+ Ok(BooleanArray::from_iter(ts_col.iter().map(|ts| { + ts.map(|val| { + Timestamp::new(val, TimeUnit::$unit) + }).map(|ts| { + self.time_range.contains(&ts) + }) + }))) + + }}; + } + + match ts_col.data_type() { + DataType::Timestamp(unit, _) => match unit { + arrow::datatypes::TimeUnit::Second => { + downcast_and_compute!(TimestampSecondArray, Second) + } + arrow::datatypes::TimeUnit::Millisecond => { + downcast_and_compute!(TimestampMillisecondArray, Millisecond) + } + arrow::datatypes::TimeUnit::Microsecond => { + downcast_and_compute!(TimestampMicrosecondArray, Microsecond) + } + arrow::datatypes::TimeUnit::Nanosecond => { + downcast_and_compute!(TimestampNanosecondArray, Nanosecond) + } + }, + DataType::Int64 => { + downcast_and_compute!(PrimitiveArray, Millisecond) + } + _ => { + unreachable!() + } + } + } +} + +#[cfg(test)] +mod tests { + use arrow_array::ArrayRef; + use datafusion_common::ToDFSchema; + use datafusion_expr::Operator; + use datafusion_physical_expr::create_physical_expr; + use datafusion_physical_expr::execution_props::ExecutionProps; + use datatypes::arrow_array::StringArray; + use datatypes::schema::{ColumnSchema, Schema}; + use datatypes::value::timestamp_to_scalar_value; + use parquet::arrow::arrow_to_parquet_schema; + + use super::*; + + fn check_unit_lossy(range_unit: TimeUnit, col_unit: TimeUnit, expect: bool) { + assert_eq!( + expect, + time_unit_lossy( + &TimestampRange::with_unit(0, 1, range_unit).unwrap(), + col_unit + ) + ) + } + + #[test] + fn test_time_unit_lossy() { + // converting a range with unit second to millisecond will not cause rounding error + check_unit_lossy(TimeUnit::Second, TimeUnit::Second, false); + check_unit_lossy(TimeUnit::Second, TimeUnit::Millisecond, false); + check_unit_lossy(TimeUnit::Second, TimeUnit::Microsecond, false); + check_unit_lossy(TimeUnit::Second, TimeUnit::Nanosecond, false); + + check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Second, true); + check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Millisecond, false); + check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Microsecond, false); + check_unit_lossy(TimeUnit::Millisecond, TimeUnit::Nanosecond, false); + + check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Second, true); + check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Millisecond, true); + check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Microsecond, false); + check_unit_lossy(TimeUnit::Microsecond, TimeUnit::Nanosecond, false); + + check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Second, true); + check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Millisecond, true); + check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Microsecond, true); + check_unit_lossy(TimeUnit::Nanosecond, TimeUnit::Nanosecond, false); + } + + fn check_arrow_predicate( + schema: Schema, + expr: datafusion_expr::Expr, + columns: Vec, + expected: Vec>, + ) { + let arrow_schema = schema.arrow_schema(); + let df_schema = arrow_schema.clone().to_dfschema().unwrap(); + let physical_expr = create_physical_expr( + &expr, + &df_schema, + arrow_schema.as_ref(), + &ExecutionProps::default(), + ) + .unwrap(); + let parquet_schema = arrow_to_parquet_schema(arrow_schema).unwrap(); + let mut predicate = DatafusionArrowPredicate { + physical_expr, + projection_mask: ProjectionMask::roots(&parquet_schema, vec![0, 1]), + }; + + let batch = arrow_array::RecordBatch::try_new(arrow_schema.clone(), columns).unwrap(); + + let res = predicate.evaluate(batch).unwrap(); + assert_eq!(expected, res.iter().collect::>()); + } + + #[test] + fn 
test_datafusion_predicate() {
+        let schema = Schema::new(vec![
+            ColumnSchema::new(
+                "ts",
+                ConcreteDataType::timestamp_datatype(TimeUnit::Nanosecond),
+                false,
+            ),
+            ColumnSchema::new("name", ConcreteDataType::string_datatype(), true),
+        ]);
+
+        let expr = datafusion_expr::and(
+            datafusion_expr::binary_expr(
+                datafusion_expr::col("ts"),
+                Operator::GtEq,
+                datafusion_expr::lit(timestamp_to_scalar_value(TimeUnit::Nanosecond, Some(10))),
+            ),
+            datafusion_expr::binary_expr(
+                datafusion_expr::col("name"),
+                Operator::Lt,
+                datafusion_expr::lit("Bob"),
+            ),
+        );
+
+        let ts_arr = Arc::new(TimestampNanosecondArray::from(vec![9, 11])) as Arc<_>;
+        let name_arr = Arc::new(StringArray::from(vec![Some("Alice"), Some("Charlie")])) as Arc<_>;
+
+        let columns = vec![ts_arr, name_arr];
+        check_arrow_predicate(
+            schema.clone(),
+            expr,
+            columns.clone(),
+            vec![Some(false), Some(false)],
+        );
+
+        let expr = datafusion_expr::and(
+            datafusion_expr::binary_expr(
+                datafusion_expr::col("ts"),
+                Operator::Lt,
+                datafusion_expr::lit(timestamp_to_scalar_value(TimeUnit::Nanosecond, Some(10))),
+            ),
+            datafusion_expr::binary_expr(
+                datafusion_expr::col("name"),
+                Operator::Lt,
+                datafusion_expr::lit("Bob"),
+            ),
+        );
+
+        check_arrow_predicate(schema, expr, columns, vec![Some(true), Some(false)]);
+    }
+}
diff --git a/src/table/src/predicate.rs b/src/table/src/predicate.rs
index 88ce88f8a52c..92e1552bbd98 100644
--- a/src/table/src/predicate.rs
+++ b/src/table/src/predicate.rs
@@ -12,66 +12,94 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::sync::Arc;
+
 use common_query::logical_plan::{DfExpr, Expr};
 use common_telemetry::{error, warn};
 use common_time::range::TimestampRange;
+use common_time::timestamp::TimeUnit;
 use common_time::Timestamp;
 use datafusion::parquet::file::metadata::RowGroupMetaData;
 use datafusion::physical_optimizer::pruning::PruningPredicate;
 use datafusion_common::ToDFSchema;
 use datafusion_expr::expr::InList;
 use datafusion_expr::{Between, BinaryExpr, Operator};
-use datafusion_physical_expr::create_physical_expr;
 use datafusion_physical_expr::execution_props::ExecutionProps;
+use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
 use datatypes::schema::SchemaRef;
 use datatypes::value::scalar_value_to_timestamp;
+use snafu::ResultExt;
 
+use crate::error;
 use crate::predicate::stats::RowGroupPruningStatistics;
 
 mod stats;
 
-#[derive(Default, Clone)]
+#[derive(Clone)]
 pub struct Predicate {
-    exprs: Vec<Expr>,
+    /// The schema of underlying storage.
+    schema: SchemaRef,
+    /// Physical expressions of this predicate.
+    exprs: Vec<Arc<dyn PhysicalExpr>>,
 }
 
 impl Predicate {
-    pub fn new(exprs: Vec<Expr>) -> Self {
-        Self { exprs }
+    /// Creates a new `Predicate` by converting logical exprs to physical exprs that can be
+    /// evaluated against record batches.
+    /// Returns an error if the conversion fails.
+    pub fn try_new(exprs: Vec<Expr>, schema: SchemaRef) -> error::Result<Self> {
+        let arrow_schema = schema.arrow_schema();
+        let df_schema = arrow_schema
+            .clone()
+            .to_dfschema_ref()
+            .context(error::DatafusionSnafu)?;
+
+        // TODO(hl): `execution_props` provides variables required by evaluation.
+        // We may reuse the `execution_props` from `SessionState` once we support
+        // registering variables.
+        let execution_props = &ExecutionProps::new();
+
+        let physical_exprs = exprs
+            .iter()
+            .map(|expr| {
+                create_physical_expr(
+                    expr.df_expr(),
+                    df_schema.as_ref(),
+                    arrow_schema.as_ref(),
+                    execution_props,
+                )
+            })
+            .collect::<Result<Vec<_>, _>>()
+            .context(error::DatafusionSnafu)?;
+
+        Ok(Self {
+            schema,
+            exprs: physical_exprs,
+        })
     }
 
-    pub fn empty() -> Self {
-        Self { exprs: vec![] }
+    #[inline]
+    pub fn exprs(&self) -> &[Arc<dyn PhysicalExpr>] {
+        &self.exprs
     }
 
-    pub fn prune_row_groups(
-        &self,
-        schema: SchemaRef,
-        row_groups: &[RowGroupMetaData],
-    ) -> Vec<bool> {
-        let mut res = vec![true; row_groups.len()];
-        let arrow_schema = (*schema.arrow_schema()).clone();
-        let df_schema = arrow_schema.clone().to_dfschema_ref();
-        let df_schema = match df_schema {
-            Ok(x) => x,
-            Err(e) => {
-                warn!("Failed to create Datafusion schema when trying to prune row groups, error: {e}");
-                return res;
-            }
-        };
+    /// Builds an empty predicate from given schema.
+    pub fn empty(schema: SchemaRef) -> Self {
+        Self {
+            schema,
+            exprs: vec![],
+        }
+    }
 
-        let execution_props = &ExecutionProps::new();
+    /// Evaluates the predicate against row group metadata.
+    /// Returns a vector of boolean values, among which `false` means the row group can be skipped.
+    pub fn prune_row_groups(&self, row_groups: &[RowGroupMetaData]) -> Vec<bool> {
+        let mut res = vec![true; row_groups.len()];
+        let arrow_schema = self.schema.arrow_schema();
         for expr in &self.exprs {
-            match create_physical_expr(
-                expr.df_expr(),
-                df_schema.as_ref(),
-                arrow_schema.as_ref(),
-                execution_props,
-            )
-            .and_then(|expr| PruningPredicate::try_new(expr, arrow_schema.clone()))
-            {
+            match PruningPredicate::try_new(expr.clone(), arrow_schema.clone()) {
                 Ok(p) => {
-                    let stat = RowGroupPruningStatistics::new(row_groups, &schema);
+                    let stat = RowGroupPruningStatistics::new(row_groups, &self.schema);
                     match p.prune(&stat) {
                         Ok(r) => {
                             for (curr_val, res) in r.into_iter().zip(res.iter_mut()) {
@@ -94,15 +122,19 @@ impl Predicate {
 
 // tests for `TimeRangePredicateBuilder` locates in src/query/tests/time_range_filter_test.rs
 // since it requires query engine to convert sql to filters.
+/// `TimeRangePredicateBuilder` extracts time range from logical exprs to facilitate fast
+/// time range pruning.
pub struct TimeRangePredicateBuilder<'a> { ts_col_name: &'a str, + ts_col_unit: TimeUnit, filters: &'a [Expr], } impl<'a> TimeRangePredicateBuilder<'a> { - pub fn new(ts_col_name: &'a str, filters: &'a [Expr]) -> Self { + pub fn new(ts_col_name: &'a str, ts_col_unit: TimeUnit, filters: &'a [Expr]) -> Self { Self { ts_col_name, + ts_col_unit, filters, } } @@ -149,18 +181,23 @@ impl<'a> TimeRangePredicateBuilder<'a> { match op { Operator::Eq => self .get_timestamp_filter(left, right) + .and_then(|ts| ts.convert_to(self.ts_col_unit)) .map(TimestampRange::single), Operator::Lt => self .get_timestamp_filter(left, right) + .and_then(|ts| ts.convert_to_ceil(self.ts_col_unit)) .map(|ts| TimestampRange::until_end(ts, false)), Operator::LtEq => self .get_timestamp_filter(left, right) + .and_then(|ts| ts.convert_to_ceil(self.ts_col_unit)) .map(|ts| TimestampRange::until_end(ts, true)), Operator::Gt => self .get_timestamp_filter(left, right) + .and_then(|ts| ts.convert_to(self.ts_col_unit)) .map(TimestampRange::from_start), Operator::GtEq => self .get_timestamp_filter(left, right) + .and_then(|ts| ts.convert_to(self.ts_col_unit)) .map(TimestampRange::from_start), Operator::And => { // instead of return none when failed to extract time range from left/right, we unwrap the none into @@ -231,8 +268,10 @@ impl<'a> TimeRangePredicateBuilder<'a> { match (low, high) { (DfExpr::Literal(low), DfExpr::Literal(high)) => { - let low_opt = scalar_value_to_timestamp(low); - let high_opt = scalar_value_to_timestamp(high); + let low_opt = + scalar_value_to_timestamp(low).and_then(|ts| ts.convert_to(self.ts_col_unit)); + let high_opt = scalar_value_to_timestamp(high) + .and_then(|ts| ts.convert_to_ceil(self.ts_col_unit)); Some(TimestampRange::new_inclusive(low_opt, high_opt)) } _ => None, @@ -329,10 +368,15 @@ mod tests { (path, schema) } - async fn assert_prune(array_cnt: usize, predicate: Predicate, expect: Vec) { + async fn assert_prune( + array_cnt: usize, + filters: Vec, + expect: Vec, + ) { let dir = create_temp_dir("prune_parquet"); let (path, schema) = gen_test_parquet_file(&dir, array_cnt).await; let schema = Arc::new(datatypes::schema::Schema::try_from(schema).unwrap()); + let arrow_predicate = Predicate::try_new(filters, schema.clone()).unwrap(); let builder = ParquetRecordBatchStreamBuilder::new( tokio::fs::OpenOptions::new() .read(true) @@ -344,23 +388,23 @@ mod tests { .unwrap(); let metadata = builder.metadata().clone(); let row_groups = metadata.row_groups(); - let res = predicate.prune_row_groups(schema, row_groups); + let res = arrow_predicate.prune_row_groups(row_groups); assert_eq!(expect, res); } - fn gen_predicate(max_val: i32, op: Operator) -> Predicate { - Predicate::new(vec![common_query::logical_plan::Expr::from( - Expr::BinaryExpr(BinaryExpr { + fn gen_predicate(max_val: i32, op: Operator) -> Vec { + vec![common_query::logical_plan::Expr::from(Expr::BinaryExpr( + BinaryExpr { left: Box::new(Expr::Column(Column::from_name("cnt"))), op, right: Box::new(Expr::Literal(ScalarValue::Int32(Some(max_val)))), - }), - )]) + }, + ))] } #[tokio::test] async fn test_prune_empty() { - assert_prune(3, Predicate::empty(), vec![true]).await; + assert_prune(3, vec![], vec![true]).await; } #[tokio::test] @@ -424,7 +468,6 @@ mod tests { let e = Expr::Column(Column::from_name("cnt")) .gt(30.lit()) .or(Expr::Column(Column::from_name("cnt")).lt(20.lit())); - let p = Predicate::new(vec![e.into()]); - assert_prune(40, p, vec![true, true, false, true]).await; + assert_prune(40, vec![e.into()], vec![true, true, false, 
true]).await; } }
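For context, a minimal usage sketch of the `Predicate` API added in src/table/src/predicate.rs above (not part of the patch; the helper function and its caller-provided `filters`, `schema` and `row_groups` inputs are hypothetical, while `Predicate::try_new` and `prune_row_groups` are the signatures introduced by this change):

    use common_query::logical_plan::Expr;
    use datafusion::parquet::file::metadata::RowGroupMetaData;
    use datatypes::schema::SchemaRef;
    use table::predicate::Predicate;

    /// Hypothetical helper: convert logical filters once, then prune parquet row groups with them.
    fn prune_with_filters(
        filters: Vec<Expr>,
        schema: SchemaRef,
        row_groups: &[RowGroupMetaData],
    ) -> Vec<bool> {
        // Logical exprs are converted to physical exprs up front, so a conversion failure
        // surfaces here instead of being silently ignored during pruning.
        let predicate = Predicate::try_new(filters, schema).expect("failed to build predicate");
        // `false` marks a row group whose statistics show it cannot match the predicate.
        predicate.prune_row_groups(row_groups)
    }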