diff --git a/crates/polars-stream/src/nodes/parquet_source/init.rs b/crates/polars-stream/src/nodes/parquet_source/init.rs index 661ea4b84825..6703acaf47a0 100644 --- a/crates/polars-stream/src/nodes/parquet_source/init.rs +++ b/crates/polars-stream/src/nodes/parquet_source/init.rs @@ -5,6 +5,7 @@ use futures::stream::FuturesUnordered; use futures::StreamExt; use polars_core::frame::DataFrame; use polars_error::PolarsResult; +use polars_io::prelude::ParallelStrategy; use super::row_group_data_fetch::RowGroupDataFetcher; use super::row_group_decode::RowGroupDecoder; @@ -264,11 +265,11 @@ impl ParquetSourceNode { /// Creates a `RowGroupDecoder` that turns `RowGroupData` into DataFrames. /// This must be called AFTER the following have been initialized: - /// * `self.projected_arrow_fields` + /// * `self.projected_arrow_schema` /// * `self.physical_predicate` pub(super) fn init_row_group_decoder(&self) -> RowGroupDecoder { assert!( - !self.projected_arrow_fields.is_empty() + !self.projected_arrow_schema.is_empty() || self.file_options.with_columns.as_deref() == Some(&[]) ); assert_eq!(self.predicate.is_some(), self.physical_predicate.is_some()); @@ -280,24 +281,33 @@ impl ParquetSourceNode { .map(|x| x[0].get_statistics().column_stats().len()) .unwrap_or(0); let include_file_paths = self.file_options.include_file_paths.clone(); - let projected_arrow_fields = self.projected_arrow_fields.clone(); + let projected_arrow_schema = self.projected_arrow_schema.clone(); let row_index = self.file_options.row_index.clone(); let physical_predicate = self.physical_predicate.clone(); let ideal_morsel_size = get_ideal_morsel_size(); + let min_values_per_thread = self.config.min_values_per_thread; + + let use_prefiltered = physical_predicate.is_some() + && matches!( + self.options.parallel, + ParallelStrategy::Auto | ParallelStrategy::Prefiltered + ); RowGroupDecoder { scan_sources, hive_partitions, hive_partitions_width, include_file_paths, - projected_arrow_fields, + projected_arrow_schema, row_index, physical_predicate, + use_prefiltered, ideal_morsel_size, + min_values_per_thread, } } - pub(super) fn init_projected_arrow_fields(&mut self) { + pub(super) fn init_projected_arrow_schema(&mut self) { let reader_schema = self .file_info .reader_schema @@ -307,20 +317,25 @@ impl ParquetSourceNode { .unwrap_left() .clone(); - self.projected_arrow_fields = + self.projected_arrow_schema = if let Some(columns) = self.file_options.with_columns.as_deref() { - columns - .iter() - .map(|x| reader_schema.get(x).unwrap().clone()) - .collect() + Arc::new( + columns + .iter() + .map(|x| { + let (_, k, v) = reader_schema.get_full(x).unwrap(); + (k.clone(), v.clone()) + }) + .collect(), + ) } else { - reader_schema.iter_values().cloned().collect() + reader_schema.clone() }; if self.verbose { eprintln!( "[ParquetSource]: {} columns to be projected from {} files", - self.projected_arrow_fields.len(), + self.projected_arrow_schema.len(), self.scan_sources.len(), ); } diff --git a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs index b5a2453cfe1c..58fe76681ad9 100644 --- a/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs +++ b/crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs @@ -17,7 +17,7 @@ use crate::utils::task_handles_ext; impl ParquetSourceNode { /// Constructs the task that fetches file metadata. - /// Note: This must be called AFTER `self.projected_arrow_fields` has been initialized. 
+ /// Note: This must be called AFTER `self.projected_arrow_schema` has been initialized. #[allow(clippy::type_complexity)] pub(super) fn init_metadata_fetcher( &mut self, @@ -35,10 +35,10 @@ impl ParquetSourceNode { let io_runtime = polars_io::pl_async::get_runtime(); assert!( - !self.projected_arrow_fields.is_empty() + !self.projected_arrow_schema.is_empty() || self.file_options.with_columns.as_deref() == Some(&[]) ); - let projected_arrow_fields = self.projected_arrow_fields.clone(); + let projected_arrow_schema = self.projected_arrow_schema.clone(); let (normalized_slice_oneshot_tx, normalized_slice_oneshot_rx) = tokio::sync::oneshot::channel(); @@ -115,7 +115,7 @@ impl ParquetSourceNode { move |handle: task_handles_ext::AbortOnDropHandle< PolarsResult<(usize, Arc, MemSlice)>, >| { - let projected_arrow_fields = projected_arrow_fields.clone(); + let projected_arrow_schema = projected_arrow_schema.clone(); let first_metadata = first_metadata.clone(); // Run on CPU runtime - metadata deserialization is expensive, especially // for very wide tables. @@ -132,7 +132,7 @@ impl ParquetSourceNode { }; ensure_metadata_has_projected_fields( - projected_arrow_fields.as_ref(), + projected_arrow_schema.as_ref(), &metadata, )?; diff --git a/crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs b/crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs index 7c848b07b750..24184fd12b10 100644 --- a/crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs +++ b/crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs @@ -1,4 +1,4 @@ -use polars_core::prelude::{DataType, PlHashMap}; +use polars_core::prelude::{ArrowSchema, DataType, PlHashMap}; use polars_error::{polars_bail, PolarsResult}; use polars_io::prelude::FileMetadata; use polars_io::utils::byte_source::{ByteSource, DynByteSource}; @@ -124,7 +124,7 @@ pub(super) async fn read_parquet_metadata_bytes( /// Ensures that a parquet file has all the necessary columns for a projection with the correct /// dtype. There are no ordering requirements and extra columns are permitted. 
pub(super) fn ensure_metadata_has_projected_fields( - projected_fields: &[polars_core::prelude::ArrowField], + projected_fields: &ArrowSchema, metadata: &FileMetadata, ) -> PolarsResult<()> { let schema = polars_parquet::arrow::read::infer_schema(metadata)?; @@ -138,7 +138,7 @@ pub(super) fn ensure_metadata_has_projected_fields( }) .collect::>(); - for field in projected_fields { + for field in projected_fields.iter_values() { let Some(dtype) = schema.remove(&field.name) else { polars_bail!(SchemaMismatch: "did not find column: {}", field.name) }; diff --git a/crates/polars-stream/src/nodes/parquet_source/mod.rs b/crates/polars-stream/src/nodes/parquet_source/mod.rs index 10df7ef0e3bf..dfede52f13cd 100644 --- a/crates/polars-stream/src/nodes/parquet_source/mod.rs +++ b/crates/polars-stream/src/nodes/parquet_source/mod.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use mem_prefetch_funcs::get_memory_prefetch_func; use polars_core::config; use polars_core::frame::DataFrame; +use polars_core::prelude::ArrowSchema; use polars_error::PolarsResult; use polars_expr::prelude::{phys_expr_to_io_expr, PhysicalExpr}; use polars_io::cloud::CloudOptions; @@ -47,7 +48,7 @@ pub struct ParquetSourceNode { config: Config, verbose: bool, physical_predicate: Option>, - projected_arrow_fields: Arc<[polars_core::prelude::ArrowField]>, + projected_arrow_schema: Arc, byte_source_builder: DynByteSourceBuilder, memory_prefetch_func: fn(&[u8]) -> (), // This permit blocks execution until the first morsel is requested. @@ -67,6 +68,9 @@ struct Config { metadata_decode_ahead_size: usize, /// Number of row groups to pre-fetch concurrently, this can be across files row_group_prefetch_size: usize, + /// Minimum number of values for a parallel spawned task to process to amortize + /// parallelism overhead. + min_values_per_thread: usize, } #[allow(clippy::too_many_arguments)] @@ -106,10 +110,11 @@ impl ParquetSourceNode { metadata_prefetch_size: 0, metadata_decode_ahead_size: 0, row_group_prefetch_size: 0, + min_values_per_thread: 0, }, verbose, physical_predicate: None, - projected_arrow_fields: Arc::new([]), + projected_arrow_schema: Arc::new(ArrowSchema::default()), byte_source_builder, memory_prefetch_func, @@ -134,11 +139,17 @@ impl ComputeNode for ParquetSourceNode { (metadata_prefetch_size / 2).min(1 + num_pipelines).max(1); let row_group_prefetch_size = polars_core::config::get_rg_prefetch_size(); + // This can be set to 1 to force column-per-thread parallelism, e.g. for bug reproduction. 
+ let min_values_per_thread = std::env::var("POLARS_MIN_VALUES_PER_THREAD") + .map(|x| x.parse::().expect("integer").max(1)) + .unwrap_or(16_777_216); + Config { num_pipelines, metadata_prefetch_size, metadata_decode_ahead_size, row_group_prefetch_size, + min_values_per_thread, } }; @@ -146,7 +157,7 @@ impl ComputeNode for ParquetSourceNode { eprintln!("[ParquetSource]: {:?}", &self.config); } - self.init_projected_arrow_fields(); + self.init_projected_arrow_schema(); self.physical_predicate = self.predicate.clone().map(phys_expr_to_io_expr); let (raw_morsel_receivers, morsel_stream_task_handle) = self.init_raw_morsel_stream(); diff --git a/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs b/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs index 376562c92fb2..4d707c83b8cb 100644 --- a/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs +++ b/crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs @@ -23,7 +23,7 @@ use crate::utils::task_handles_ext; /// Represents byte-data that can be transformed into a DataFrame after some computation. pub(super) struct RowGroupData { - pub(super) byte_source: FetchedBytes, + pub(super) fetched_bytes: FetchedBytes, pub(super) path_index: usize, pub(super) row_offset: usize, pub(super) slice: Option<(usize, usize)>, @@ -167,7 +167,7 @@ impl RowGroupDataFetcher { // Push calculation of byte ranges to a task to run in parallel, as it can be // expensive for very wide tables and projections. let handle = async_executor::spawn(TaskPriority::Low, async move { - let byte_source = if let DynByteSource::MemSlice(mem_slice) = + let fetched_bytes = if let DynByteSource::MemSlice(mem_slice) = current_byte_source.as_ref() { // Skip byte range calculation for `no_prefetch`. 
@@ -251,7 +251,7 @@ impl RowGroupDataFetcher { }; PolarsResult::Ok(RowGroupData { - byte_source, + fetched_bytes, path_index: current_path_index, row_offset: current_row_offset, slice, diff --git a/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs b/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs index b3249e60057c..668bf7cee340 100644 --- a/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs +++ b/crates/polars-stream/src/nodes/parquet_source/row_group_decode.rs @@ -1,7 +1,9 @@ use std::sync::Arc; use polars_core::frame::DataFrame; -use polars_core::prelude::{ChunkFull, IdxCa, StringChunked}; +use polars_core::prelude::{ + ArrowField, ArrowSchema, BooleanChunked, ChunkFull, IdxCa, StringChunked, +}; use polars_core::series::{IntoSeries, IsSorted, Series}; use polars_error::{polars_bail, PolarsResult}; use polars_io::predicates::PhysicalIoExpr; @@ -21,21 +23,34 @@ pub(super) struct RowGroupDecoder { pub(super) hive_partitions: Option>>, pub(super) hive_partitions_width: usize, pub(super) include_file_paths: Option, - pub(super) projected_arrow_fields: Arc<[polars_core::prelude::ArrowField]>, + pub(super) projected_arrow_schema: Arc, pub(super) row_index: Option, pub(super) physical_predicate: Option>, + pub(super) use_prefiltered: bool, pub(super) ideal_morsel_size: usize, + pub(super) min_values_per_thread: usize, } impl RowGroupDecoder { pub(super) async fn row_group_data_to_df( &self, row_group_data: RowGroupData, + ) -> PolarsResult> { + if self.use_prefiltered { + self.row_group_data_to_df_prefiltered(row_group_data).await + } else { + self.row_group_data_to_df_impl(row_group_data).await + } + } + + async fn row_group_data_to_df_impl( + &self, + row_group_data: RowGroupData, ) -> PolarsResult> { let row_group_data = Arc::new(row_group_data); let out_width = self.row_index.is_some() as usize - + self.projected_arrow_fields.len() + + self.projected_arrow_schema.len() + self.hive_partitions_width + self.include_file_paths.is_some() as usize; @@ -52,91 +67,103 @@ impl RowGroupDecoder { .map(|(offset, len)| offset..offset + len) .unwrap_or(0..row_group_data.row_group_metadata.num_rows()); - let projected_arrow_fields = &self.projected_arrow_fields; - let projected_arrow_fields = projected_arrow_fields.clone(); + assert!(slice_range.end <= row_group_data.row_group_metadata.num_rows()); - let row_group_data_2 = row_group_data.clone(); - let slice_range_2 = slice_range.clone(); + self.decode_all_columns( + &mut out_columns, + &row_group_data, + Some(polars_parquet::read::Filter::Range(slice_range.clone())), + ) + .await?; - // Minimum number of values to amortize the overhead of spawning tasks. - // This value is arbitrarily chosen. - const VALUES_PER_THREAD: usize = 16_777_216; - let n_rows = row_group_data.row_group_metadata.num_rows(); - let cols_per_task = 1 + VALUES_PER_THREAD / n_rows; + let projection_height = if self.projected_arrow_schema.is_empty() { + slice_range.len() + } else { + debug_assert!(out_columns.len() > self.row_index.is_some() as usize); + out_columns.last().unwrap().len() + }; - let decode_fut_iter = (0..self.projected_arrow_fields.len()) - .step_by(cols_per_task) - .map(move |offset| { - let row_group_data = row_group_data_2.clone(); - let slice_range = slice_range_2.clone(); - let projected_arrow_fields = projected_arrow_fields.clone(); + if let Some(s) = self.materialize_row_index(row_group_data.as_ref(), slice_range)? 
{ + out_columns[0] = s; + } - async move { - (offset - ..offset - .saturating_add(cols_per_task) - .min(projected_arrow_fields.len())) - .map(|i| { - let arrow_field = projected_arrow_fields[i].clone(); - - let columns_to_deserialize = row_group_data - .row_group_metadata - .columns_under_root_iter(&arrow_field.name) - .map(|col_md| { - let byte_range = col_md.byte_range(); - - ( - col_md, - row_group_data.byte_source.get_range( - byte_range.start as usize..byte_range.end as usize, - ), - ) - }) - .collect::>(); - - assert!( - slice_range.end <= row_group_data.row_group_metadata.num_rows() - ); - - let array = polars_io::prelude::_internal::to_deserializer( - columns_to_deserialize, - arrow_field.clone(), - Some(polars_parquet::read::Filter::Range(slice_range.clone())), - )?; - - let series = Series::try_from((&arrow_field, array))?; - - // TODO: Also load in the metadata. - - PolarsResult::Ok(series) - }) - .collect::>>() - } - }); + let shared_file_state = row_group_data + .shared_file_state + .get_or_init(|| self.shared_file_state_init_func(&row_group_data)) + .await; - if decode_fut_iter.len() > 1 { - for handle in decode_fut_iter.map(|fut| { - async_executor::AbortOnDropHandle::new(async_executor::spawn( - TaskPriority::Low, - fut, - )) - }) { - out_columns.extend(handle.await?); + assert_eq!(shared_file_state.path_index, row_group_data.path_index); + + for s in &shared_file_state.hive_series { + debug_assert!(s.len() >= projection_height); + out_columns.push(s.slice(0, projection_height)); + } + + if let Some(file_path_series) = &shared_file_state.file_path_series { + debug_assert!(file_path_series.len() >= projection_height); + out_columns.push(file_path_series.slice(0, projection_height)); + } + + let df = unsafe { DataFrame::new_no_checks(out_columns) }; + + let df = if let Some(predicate) = self.physical_predicate.as_deref() { + let mask = predicate.evaluate_io(&df)?; + let mask = mask.bool().unwrap(); + + unsafe { + DataFrame::new_no_checks( + filter_cols(df.take_columns(), mask, self.min_values_per_thread).await?, + ) } } else { - for fut in decode_fut_iter { - out_columns.extend(fut.await?); - } - } + df + }; - let projection_height = if self.projected_arrow_fields.is_empty() { - slice_range.len() + assert_eq!(df.width(), out_width); // `out_width` should have been calculated correctly + + Ok(self.split_to_morsels(df)) + } + + async fn shared_file_state_init_func(&self, row_group_data: &RowGroupData) -> SharedFileState { + let path_index = row_group_data.path_index; + + let hive_series = if let Some(hp) = self.hive_partitions.as_deref() { + let mut v = hp[path_index].materialize_partition_columns(); + for s in v.iter_mut() { + *s = s.new_from_index(0, row_group_data.file_max_row_group_height); + } + v } else { - debug_assert!(out_columns.len() > self.row_index.is_some() as usize); - out_columns.last().unwrap().len() + vec![] }; + let file_path_series = self.include_file_paths.clone().map(|file_path_col| { + StringChunked::full( + file_path_col, + self.scan_sources + .get(path_index) + .unwrap() + .to_include_path_name(), + row_group_data.file_max_row_group_height, + ) + .into_series() + }); + + SharedFileState { + path_index, + hive_series, + file_path_series, + } + } + + fn materialize_row_index( + &self, + row_group_data: &RowGroupData, + slice_range: core::ops::Range, + ) -> PolarsResult> { if let Some(RowIndex { name, offset }) = self.row_index.as_ref() { + let projection_height = row_group_data.row_group_metadata.num_rows(); + let Some(offset) = (|| { let offset = offset 
.checked_add((row_group_data.row_offset + slice_range.start) as IdxSize)?; @@ -161,102 +188,89 @@ impl RowGroupDecoder { ); ca.set_sorted_flag(IsSorted::Ascending); - out_columns[0] = ca.into_series(); + Ok(Some(ca.into_series())) + } else { + Ok(None) } + } - let shared_file_state = row_group_data - .shared_file_state - .get_or_init(|| async { - let path_index = row_group_data.path_index; + /// Potentially parallelizes based on number of rows & columns. Decoded columns are appended to + /// `out_vec`. + async fn decode_all_columns( + &self, + out_vec: &mut Vec, + row_group_data: &Arc, + filter: Option, + ) -> PolarsResult<()> { + let projected_arrow_schema = &self.projected_arrow_schema; + + let Some((cols_per_thread, remainder)) = calc_cols_per_thread( + row_group_data.row_group_metadata.num_rows(), + projected_arrow_schema.len(), + self.min_values_per_thread, + ) else { + // Single-threaded + for s in projected_arrow_schema + .iter_values() + .map(|arrow_field| decode_column(arrow_field, row_group_data, filter.clone())) + { + out_vec.push(s?) + } - let hive_series = if let Some(hp) = self.hive_partitions.as_deref() { - let mut v = hp[path_index].materialize_partition_columns(); - for s in v.iter_mut() { - *s = s.new_from_index(0, row_group_data.file_max_row_group_height); - } - v - } else { - vec![] - }; - - let file_path_series = self.include_file_paths.clone().map(|file_path_col| { - StringChunked::full( - file_path_col, - self.scan_sources - .get(path_index) - .unwrap() - .to_include_path_name(), - row_group_data.file_max_row_group_height, - ) - .into_series() - }); - - SharedFileState { - path_index, - hive_series, - file_path_series, - } - }) - .await; + return Ok(()); + }; - assert_eq!(shared_file_state.path_index, row_group_data.path_index); + let projected_arrow_schema = projected_arrow_schema.clone(); + let row_group_data_2 = row_group_data.clone(); - for s in &shared_file_state.hive_series { - debug_assert!(s.len() >= projection_height); - out_columns.push(s.slice(0, projection_height)); - } + let task_handles = { + let projected_arrow_schema = projected_arrow_schema.clone(); + let filter = filter.clone(); + + (remainder..projected_arrow_schema.len()) + .step_by(cols_per_thread) + .map(move |offset| { + let row_group_data = row_group_data_2.clone(); + let projected_arrow_schema = projected_arrow_schema.clone(); + let filter = filter.clone(); + + async move { + // This is exact as we have already taken out the remainder. + (offset..offset + cols_per_thread) + .map(|i| { + let (_, arrow_field) = + projected_arrow_schema.get_at_index(i).unwrap(); + + decode_column(arrow_field, &row_group_data, filter.clone()) + }) + .collect::>>() + } + }) + .map(|fut| { + async_executor::AbortOnDropHandle::new(async_executor::spawn( + TaskPriority::Low, + fut, + )) + }) + .collect::>() + }; - if let Some(file_path_series) = &shared_file_state.file_path_series { - debug_assert!(file_path_series.len() >= projection_height); - out_columns.push(file_path_series.slice(0, projection_height)); + for out in projected_arrow_schema + .iter_values() + .take(remainder) + .map(|arrow_field| decode_column(arrow_field, row_group_data, filter.clone())) + { + out_vec.push(out?); } - let df = unsafe { DataFrame::new_no_checks(out_columns) }; - - // Re-calculate: A slice may have been applied. 
- let cols_per_task = 1 + VALUES_PER_THREAD / df.height(); - - let df = if let Some(predicate) = self.physical_predicate.as_deref() { - let mask = predicate.evaluate_io(&df)?; - let mask = mask.bool().unwrap(); - - if cols_per_task <= df.width() { - df._filter_seq(mask)? - } else { - let mask = mask.clone(); - let cols = Arc::new(df.take_columns()); - let mut out_cols = Vec::with_capacity(cols.len()); - - for handle in (0..cols.len()) - .step_by(cols_per_task) - .map(move |offset| { - let cols = cols.clone(); - let mask = mask.clone(); - async move { - cols[offset..offset.saturating_add(cols_per_task).min(cols.len())] - .iter() - .map(|s| s.filter(&mask)) - .collect::>>() - } - }) - .map(|fut| { - async_executor::AbortOnDropHandle::new(async_executor::spawn( - TaskPriority::Low, - fut, - )) - }) - { - out_cols.extend(handle.await?); - } - - unsafe { DataFrame::new_no_checks(out_cols) } - } - } else { - df - }; + for handle in task_handles { + out_vec.extend(handle.await?); + } - assert_eq!(df.width(), out_width); + Ok(()) + } + fn split_to_morsels(&self, df: DataFrame) -> Vec { let n_morsels = if df.height() > 3 * self.ideal_morsel_size / 2 { // num_rows > (1.5 * ideal_morsel_size) (df.height() / self.ideal_morsel_size).max(2) @@ -265,18 +279,130 @@ impl RowGroupDecoder { } as u64; if n_morsels == 1 { - return Ok(vec![df]); + return vec![df]; } let rows_per_morsel = 1 + df.height() / n_morsels as usize; - let out = (0..i64::try_from(df.height()).unwrap()) + (0..i64::try_from(df.height()).unwrap()) .step_by(rows_per_morsel) .map(|offset| df.slice(offset, rows_per_morsel)) - .collect::>(); + .collect::>() + } +} + +fn decode_column( + arrow_field: &ArrowField, + row_group_data: &RowGroupData, + filter: Option, +) -> PolarsResult { + let columns_to_deserialize = row_group_data + .row_group_metadata + .columns_under_root_iter(&arrow_field.name) + .map(|col_md| { + let byte_range = col_md.byte_range(); + + ( + col_md, + row_group_data + .fetched_bytes + .get_range(byte_range.start as usize..byte_range.end as usize), + ) + }) + .collect::>(); + + let array = polars_io::prelude::_internal::to_deserializer( + columns_to_deserialize, + arrow_field.clone(), + filter, + )?; + + let series = Series::try_from((arrow_field, array))?; + + // TODO: Also load in the metadata. + + Ok(series) +} + +/// # Safety +/// All series in `cols` have the same length. +async unsafe fn filter_cols( + mut cols: Vec, + mask: &BooleanChunked, + min_values_per_thread: usize, +) -> PolarsResult> { + if cols.is_empty() { + return Ok(cols); + } + + let Some((cols_per_thread, remainder)) = + calc_cols_per_thread(cols[0].len(), cols.len(), min_values_per_thread) + else { + for s in cols.iter_mut() { + *s = s.filter(mask)?; + } + + return Ok(cols); + }; + + let mut out_vec = Vec::with_capacity(cols.len()); + let cols = Arc::new(cols); + let mask = mask.clone(); + + let task_handles = { + let cols = &cols; + let mask = &mask; + + (remainder..cols.len()) + .step_by(cols_per_thread) + .map(move |offset| { + let cols = cols.clone(); + let mask = mask.clone(); + async move { + (offset..offset + cols_per_thread) + .map(|i| cols[i].filter(&mask)) + .collect::>>() + } + }) + .map(|fut| { + async_executor::AbortOnDropHandle::new(async_executor::spawn( + TaskPriority::Low, + fut, + )) + }) + .collect::>() + }; + + for out in cols.iter().take(remainder).map(|s| s.filter(&mask)) { + out_vec.push(out?); + } - Ok(out) + for handle in task_handles { + out_vec.extend(handle.await?) 
} + + Ok(out_vec) +} + +/// Returns `Some((n_cols_per_thread, n_remainder))` if at least 2 tasks with >= `min_values_per_thread` can be created. +fn calc_cols_per_thread( + n_rows_per_col: usize, + n_cols: usize, + min_values_per_thread: usize, +) -> Option<(usize, usize)> { + let cols_per_thread = 1 + min_values_per_thread / n_rows_per_col.max(1); + + let cols_per_thread = if n_rows_per_col >= min_values_per_thread { + 1 + } else { + cols_per_thread + }; + + // At least 2 fully saturated tasks according to floordiv. + let parallel = n_cols / cols_per_thread >= 2; + let remainder = n_cols % cols_per_thread; + + parallel.then_some((cols_per_thread, remainder)) } /// State shared across row groups for a single file. @@ -285,3 +411,37 @@ pub(super) struct SharedFileState { hive_series: Vec, file_path_series: Option, } + +/// +/// Pre-filtered +/// + +impl RowGroupDecoder { + async fn row_group_data_to_df_prefiltered( + &self, + row_group_data: RowGroupData, + ) -> PolarsResult> { + // TODO: actually prefilter + self.row_group_data_to_df_impl(row_group_data).await + } +} + +mod tests { + #[test] + fn test_calc_cols_per_thread() { + use super::calc_cols_per_thread; + + let n_rows = 3; + let n_cols = 11; + let min_vals = 5; + assert_eq!(calc_cols_per_thread(n_rows, n_cols, min_vals), Some((2, 1))); + + let n_rows = 6; + let n_cols = 11; + let min_vals = 5; + assert_eq!(calc_cols_per_thread(n_rows, n_cols, min_vals), Some((1, 0))); + + calc_cols_per_thread(0, 1, 1); + calc_cols_per_thread(1, 0, 1); + } +}
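
The hard-coded VALUES_PER_THREAD constant (16_777_216) removed above becomes Config::min_values_per_thread, overridable through POLARS_MIN_VALUES_PER_THREAD and clamped to at least 1. A minimal standalone reading of that same pattern, assuming only std:

// Same parsing pattern as in ComputeNode::initialize: an unset variable falls
// back to 16_777_216 (2^24) values per task; an explicit value is clamped to >= 1.
fn min_values_per_thread_from_env() -> usize {
    std::env::var("POLARS_MIN_VALUES_PER_THREAD")
        .map(|x| x.parse::<usize>().expect("integer").max(1))
        .unwrap_or(16_777_216)
}

fn main() {
    // Setting POLARS_MIN_VALUES_PER_THREAD=1 forces column-per-thread parallelism.
    println!("min_values_per_thread = {}", min_values_per_thread_from_env());
}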
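
The core of the refactor is the pair decode_all_columns / filter_cols: both call calc_cols_per_thread to get (cols_per_thread, remainder), handle the first `remainder` columns inline, and spawn one low-priority task per chunk of `cols_per_thread` columns. The sketch below restates calc_cols_per_thread from the diff and mimics the chunked-filter shape with scoped std threads and Vec<u32> columns; filter_one, filter_cols_chunked, and the driver values are illustrative stand-ins, not the PR's async_executor/Series API.

// calc_cols_per_thread restated from row_group_decode.rs.
fn calc_cols_per_thread(
    n_rows_per_col: usize,
    n_cols: usize,
    min_values_per_thread: usize,
) -> Option<(usize, usize)> {
    let cols_per_thread = 1 + min_values_per_thread / n_rows_per_col.max(1);

    // A single column already meets the minimum, so give each task one column.
    let cols_per_thread = if n_rows_per_col >= min_values_per_thread {
        1
    } else {
        cols_per_thread
    };

    // Only parallelize if at least 2 fully saturated tasks can be formed.
    let parallel = n_cols / cols_per_thread >= 2;
    let remainder = n_cols % cols_per_thread;

    parallel.then_some((cols_per_thread, remainder))
}

// Stand-in for Series::filter: keep the values where the mask is true.
fn filter_one(col: &[u32], mask: &[bool]) -> Vec<u32> {
    col.iter()
        .zip(mask)
        .filter_map(|(v, keep)| keep.then_some(*v))
        .collect()
}

// Illustrative shape of filter_cols: remainder inline, chunks in parallel.
fn filter_cols_chunked(cols: Vec<Vec<u32>>, mask: &[bool], min_values: usize) -> Vec<Vec<u32>> {
    let n_rows = cols.first().map_or(0, |c| c.len());

    let Some((cols_per_thread, remainder)) = calc_cols_per_thread(n_rows, cols.len(), min_values)
    else {
        // Too little work to amortize spawning: filter on the current thread.
        return cols.iter().map(|c| filter_one(c, mask)).collect();
    };

    let mut out = Vec::with_capacity(cols.len());
    std::thread::scope(|scope| {
        // One task per chunk of `cols_per_thread` columns past the remainder.
        let handles: Vec<_> = cols[remainder..]
            .chunks(cols_per_thread)
            .map(|chunk| {
                scope.spawn(move || chunk.iter().map(|c| filter_one(c, mask)).collect::<Vec<_>>())
            })
            .collect();

        // The first `remainder` columns are filtered inline while the tasks run.
        out.extend(cols[..remainder].iter().map(|c| filter_one(c, mask)));

        for handle in handles {
            out.extend(handle.join().unwrap());
        }
    });
    out
}

fn main() {
    // Mirrors the unit test in the diff: 3 rows, 11 columns, min 5 values per
    // task -> chunks of 2 columns with 1 leftover column handled inline.
    assert_eq!(calc_cols_per_thread(3, 11, 5), Some((2, 1)));
    assert_eq!(calc_cols_per_thread(6, 11, 5), Some((1, 0)));

    let cols = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
    let mask = [true, false, true];
    // min_values = 1 forces the parallel path, as POLARS_MIN_VALUES_PER_THREAD=1 would.
    assert_eq!(
        filter_cols_chunked(cols, &mask, 1),
        vec![vec![1, 3], vec![4, 6], vec![7, 9]]
    );
    println!("ok");
}

Column order is preserved because the spawned chunks start at index `remainder` and their results are appended after the inline ones, matching the order in filter_cols above.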
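
split_to_morsels keeps the pre-existing morsel-splitting arithmetic, now factored out of row_group_data_to_df. A standalone restatement of that arithmetic on plain offsets; the 100_000-row ideal morsel size used in the driver is an example value, not the value returned by get_ideal_morsel_size.

// How many morsels a decoded row group of `height` rows is cut into.
fn morsel_offsets(height: usize, ideal_morsel_size: usize) -> Vec<(usize, usize)> {
    // Only split when the row group is more than 1.5x the ideal morsel size.
    let n_morsels = if height > 3 * ideal_morsel_size / 2 {
        (height / ideal_morsel_size).max(2)
    } else {
        1
    };

    if n_morsels == 1 {
        return vec![(0, height)];
    }

    let rows_per_morsel = 1 + height / n_morsels;
    (0..height)
        .step_by(rows_per_morsel)
        .map(|offset| (offset, rows_per_morsel.min(height - offset)))
        .collect()
}

fn main() {
    // With an ideal morsel size of 100_000 rows, a 250_000-row group becomes
    // two morsels of 125_001 and 124_999 rows.
    println!("{:?}", morsel_offsets(250_000, 100_000));
    // A 120_000-row group stays a single morsel (below the 1.5x threshold).
    println!("{:?}", morsel_offsets(120_000, 100_000));
}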