refactor(rust): Refactor code into functions in new parquet source (#…
nameexhaustion authored Sep 12, 2024
1 parent a66532d commit 470d8c4
Showing 6 changed files with 378 additions and 192 deletions.
39 changes: 27 additions & 12 deletions crates/polars-stream/src/nodes/parquet_source/init.rs
@@ -5,6 +5,7 @@ use futures::stream::FuturesUnordered;
 use futures::StreamExt;
 use polars_core::frame::DataFrame;
 use polars_error::PolarsResult;
+use polars_io::prelude::ParallelStrategy;
 
 use super::row_group_data_fetch::RowGroupDataFetcher;
 use super::row_group_decode::RowGroupDecoder;
@@ -264,11 +265,11 @@ impl ParquetSourceNode {
 
     /// Creates a `RowGroupDecoder` that turns `RowGroupData` into DataFrames.
     /// This must be called AFTER the following have been initialized:
-    /// * `self.projected_arrow_fields`
+    /// * `self.projected_arrow_schema`
     /// * `self.physical_predicate`
     pub(super) fn init_row_group_decoder(&self) -> RowGroupDecoder {
         assert!(
-            !self.projected_arrow_fields.is_empty()
+            !self.projected_arrow_schema.is_empty()
                 || self.file_options.with_columns.as_deref() == Some(&[])
         );
         assert_eq!(self.predicate.is_some(), self.physical_predicate.is_some());
@@ -280,24 +281,33 @@ impl ParquetSourceNode {
             .map(|x| x[0].get_statistics().column_stats().len())
             .unwrap_or(0);
         let include_file_paths = self.file_options.include_file_paths.clone();
-        let projected_arrow_fields = self.projected_arrow_fields.clone();
+        let projected_arrow_schema = self.projected_arrow_schema.clone();
         let row_index = self.file_options.row_index.clone();
         let physical_predicate = self.physical_predicate.clone();
         let ideal_morsel_size = get_ideal_morsel_size();
+        let min_values_per_thread = self.config.min_values_per_thread;
 
+        let use_prefiltered = physical_predicate.is_some()
+            && matches!(
+                self.options.parallel,
+                ParallelStrategy::Auto | ParallelStrategy::Prefiltered
+            );
+
         RowGroupDecoder {
             scan_sources,
             hive_partitions,
             hive_partitions_width,
             include_file_paths,
-            projected_arrow_fields,
+            projected_arrow_schema,
             row_index,
             physical_predicate,
+            use_prefiltered,
             ideal_morsel_size,
+            min_values_per_thread,
         }
     }
 
-    pub(super) fn init_projected_arrow_fields(&mut self) {
+    pub(super) fn init_projected_arrow_schema(&mut self) {
         let reader_schema = self
             .file_info
             .reader_schema
@@ -307,20 +317,25 @@ impl ParquetSourceNode {
             .unwrap_left()
             .clone();
 
-        self.projected_arrow_fields =
+        self.projected_arrow_schema =
             if let Some(columns) = self.file_options.with_columns.as_deref() {
-                columns
-                    .iter()
-                    .map(|x| reader_schema.get(x).unwrap().clone())
-                    .collect()
+                Arc::new(
+                    columns
+                        .iter()
+                        .map(|x| {
+                            let (_, k, v) = reader_schema.get_full(x).unwrap();
+                            (k.clone(), v.clone())
+                        })
+                        .collect(),
+                )
             } else {
-                reader_schema.iter_values().cloned().collect()
+                reader_schema.clone()
             };
 
         if self.verbose {
             eprintln!(
                 "[ParquetSource]: {} columns to be projected from {} files",
-                self.projected_arrow_fields.len(),
+                self.projected_arrow_schema.len(),
                 self.scan_sources.len(),
             );
         }
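The notable addition in `init_row_group_decoder` above is the `use_prefiltered` flag: row groups are pre-filtered only when a physical predicate is present and the scan's `ParallelStrategy` is `Auto` or `Prefiltered`. Below is a minimal sketch of that gating logic; the `ParallelStrategy` enum here is a simplified stand-in for the one in `polars_io`, not its actual definition.

```rust
/// Simplified stand-in for `polars_io::prelude::ParallelStrategy` (assumed variants).
#[derive(Clone, Copy, Debug)]
enum ParallelStrategy {
    Auto,
    Columns,
    RowGroups,
    Prefiltered,
}

/// Pre-filtering needs a predicate to filter with, and must not override an
/// explicitly requested column- or row-group-based strategy.
fn use_prefiltered(has_predicate: bool, parallel: ParallelStrategy) -> bool {
    has_predicate && matches!(parallel, ParallelStrategy::Auto | ParallelStrategy::Prefiltered)
}

fn main() {
    assert!(use_prefiltered(true, ParallelStrategy::Auto));
    assert!(!use_prefiltered(false, ParallelStrategy::Prefiltered)); // no predicate: nothing to prefilter
    assert!(!use_prefiltered(true, ParallelStrategy::Columns)); // user forced a different strategy
    assert!(!use_prefiltered(true, ParallelStrategy::RowGroups));
}
```

`init_projected_arrow_schema` likewise now builds an `Arc<ArrowSchema>` keyed by column name instead of an `Arc<[ArrowField]>`, which is what lets the no-projection branch reuse `reader_schema.clone()` rather than copying every field.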
10 changes: 5 additions & 5 deletions crates/polars-stream/src/nodes/parquet_source/metadata_fetch.rs
@@ -17,7 +17,7 @@ use crate::utils::task_handles_ext;
 
 impl ParquetSourceNode {
     /// Constructs the task that fetches file metadata.
-    /// Note: This must be called AFTER `self.projected_arrow_fields` has been initialized.
+    /// Note: This must be called AFTER `self.projected_arrow_schema` has been initialized.
     #[allow(clippy::type_complexity)]
     pub(super) fn init_metadata_fetcher(
         &mut self,
@@ -35,10 +35,10 @@ impl ParquetSourceNode {
         let io_runtime = polars_io::pl_async::get_runtime();
 
         assert!(
-            !self.projected_arrow_fields.is_empty()
+            !self.projected_arrow_schema.is_empty()
                 || self.file_options.with_columns.as_deref() == Some(&[])
         );
-        let projected_arrow_fields = self.projected_arrow_fields.clone();
+        let projected_arrow_schema = self.projected_arrow_schema.clone();
 
         let (normalized_slice_oneshot_tx, normalized_slice_oneshot_rx) =
             tokio::sync::oneshot::channel();
@@ -115,7 +115,7 @@ impl ParquetSourceNode {
             move |handle: task_handles_ext::AbortOnDropHandle<
                 PolarsResult<(usize, Arc<DynByteSource>, MemSlice)>,
             >| {
-                let projected_arrow_fields = projected_arrow_fields.clone();
+                let projected_arrow_schema = projected_arrow_schema.clone();
                 let first_metadata = first_metadata.clone();
                 // Run on CPU runtime - metadata deserialization is expensive, especially
                 // for very wide tables.
@@ -132,7 +132,7 @@ impl ParquetSourceNode {
                 };
 
                 ensure_metadata_has_projected_fields(
-                    projected_arrow_fields.as_ref(),
+                    projected_arrow_schema.as_ref(),
                     &metadata,
                 )?;
 
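The closure above clones `projected_arrow_schema` (an `Arc<ArrowSchema>`) once per spawned task, so metadata validation can run on the CPU runtime without borrowing from the node. A minimal sketch of this clone-into-task pattern, using plain `std::thread` and a `BTreeMap` as stand-ins for the async executor and `ArrowSchema`:

```rust
use std::collections::BTreeMap;
use std::sync::Arc;
use std::thread;

fn main() {
    // Stand-in for Arc<ArrowSchema>: an immutable, cheaply shareable column map.
    let projected_schema: Arc<BTreeMap<String, &'static str>> = Arc::new(BTreeMap::from([
        ("a".to_string(), "i64"),
        ("b".to_string(), "f64"),
    ]));

    let handles: Vec<_> = (0..4)
        .map(|i| {
            // Cloning the Arc bumps a refcount; the schema itself is never copied.
            let schema = Arc::clone(&projected_schema);
            thread::spawn(move || {
                println!("task {i}: {} projected columns", schema.len());
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }
}
```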
crates/polars-stream/src/nodes/parquet_source/metadata_utils.rs
@@ -1,4 +1,4 @@
-use polars_core::prelude::{DataType, PlHashMap};
+use polars_core::prelude::{ArrowSchema, DataType, PlHashMap};
 use polars_error::{polars_bail, PolarsResult};
 use polars_io::prelude::FileMetadata;
 use polars_io::utils::byte_source::{ByteSource, DynByteSource};
@@ -124,7 +124,7 @@ pub(super) async fn read_parquet_metadata_bytes(
 /// Ensures that a parquet file has all the necessary columns for a projection with the correct
 /// dtype. There are no ordering requirements and extra columns are permitted.
 pub(super) fn ensure_metadata_has_projected_fields(
-    projected_fields: &[polars_core::prelude::ArrowField],
+    projected_fields: &ArrowSchema,
     metadata: &FileMetadata,
 ) -> PolarsResult<()> {
     let schema = polars_parquet::arrow::read::infer_schema(metadata)?;
@@ -138,7 +138,7 @@ pub(super) fn ensure_metadata_has_projected_fields(
         })
         .collect::<PlHashMap<PlSmallStr, DataType>>();
 
-    for field in projected_fields {
+    for field in projected_fields.iter_values() {
         let Some(dtype) = schema.remove(&field.name) else {
             polars_bail!(SchemaMismatch: "did not find column: {}", field.name)
         };
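With the signature change above, `ensure_metadata_has_projected_fields` receives the projected `ArrowSchema` directly. Its contract is containment, not equality: every projected column must exist in the file with a matching dtype, while extra columns and arbitrary ordering are permitted. A self-contained sketch of that check, with plain strings standing in for Arrow dtypes:

```rust
use std::collections::HashMap;

/// Check that `file_schema` contains every projected column with a matching
/// dtype. Extra columns are permitted and ordering is irrelevant.
fn ensure_has_projected_fields(
    projected: &[(&str, &str)],   // (name, dtype) pairs of the projection
    file_schema: &[(&str, &str)], // (name, dtype) pairs inferred from file metadata
) -> Result<(), String> {
    let mut schema: HashMap<&str, &str> = file_schema.iter().copied().collect();

    for (name, dtype) in projected {
        let Some(found) = schema.remove(name) else {
            return Err(format!("did not find column: {name}"));
        };
        if found != *dtype {
            return Err(format!(
                "data type mismatch for column {name}: expected {dtype}, found {found}"
            ));
        }
    }

    Ok(())
}

fn main() {
    let file = [("a", "i64"), ("b", "f64"), ("c", "str")];
    // A reordered subset of the file's columns passes.
    assert!(ensure_has_projected_fields(&[("b", "f64"), ("a", "i64")], &file).is_ok());
    // A missing column fails.
    assert!(ensure_has_projected_fields(&[("d", "i64")], &file).is_err());
}
```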
17 changes: 14 additions & 3 deletions crates/polars-stream/src/nodes/parquet_source/mod.rs
@@ -4,6 +4,7 @@ use std::sync::Arc;
 use mem_prefetch_funcs::get_memory_prefetch_func;
 use polars_core::config;
 use polars_core::frame::DataFrame;
+use polars_core::prelude::ArrowSchema;
 use polars_error::PolarsResult;
 use polars_expr::prelude::{phys_expr_to_io_expr, PhysicalExpr};
 use polars_io::cloud::CloudOptions;
@@ -47,7 +48,7 @@ pub struct ParquetSourceNode {
     config: Config,
     verbose: bool,
     physical_predicate: Option<Arc<dyn PhysicalIoExpr>>,
-    projected_arrow_fields: Arc<[polars_core::prelude::ArrowField]>,
+    projected_arrow_schema: Arc<ArrowSchema>,
     byte_source_builder: DynByteSourceBuilder,
    memory_prefetch_func: fn(&[u8]) -> (),
    // This permit blocks execution until the first morsel is requested.
@@ -67,6 +68,9 @@ struct Config {
     metadata_decode_ahead_size: usize,
     /// Number of row groups to pre-fetch concurrently, this can be across files
     row_group_prefetch_size: usize,
+    /// Minimum number of values for a parallel spawned task to process to amortize
+    /// parallelism overhead.
+    min_values_per_thread: usize,
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -106,10 +110,11 @@ impl ParquetSourceNode {
                 metadata_prefetch_size: 0,
                 metadata_decode_ahead_size: 0,
                 row_group_prefetch_size: 0,
+                min_values_per_thread: 0,
             },
             verbose,
             physical_predicate: None,
-            projected_arrow_fields: Arc::new([]),
+            projected_arrow_schema: Arc::new(ArrowSchema::default()),
             byte_source_builder,
             memory_prefetch_func,
 
@@ -134,19 +139,25 @@ impl ComputeNode for ParquetSourceNode {
                 (metadata_prefetch_size / 2).min(1 + num_pipelines).max(1);
             let row_group_prefetch_size = polars_core::config::get_rg_prefetch_size();
 
+            // This can be set to 1 to force column-per-thread parallelism, e.g. for bug reproduction.
+            let min_values_per_thread = std::env::var("POLARS_MIN_VALUES_PER_THREAD")
+                .map(|x| x.parse::<usize>().expect("integer").max(1))
+                .unwrap_or(16_777_216);
+
             Config {
                 num_pipelines,
                 metadata_prefetch_size,
                 metadata_decode_ahead_size,
                 row_group_prefetch_size,
+                min_values_per_thread,
             }
         };
 
         if self.verbose {
             eprintln!("[ParquetSource]: {:?}", &self.config);
         }
 
-        self.init_projected_arrow_fields();
+        self.init_projected_arrow_schema();
         self.physical_predicate = self.predicate.clone().map(phys_expr_to_io_expr);
 
         let (raw_morsel_receivers, morsel_stream_task_handle) = self.init_raw_morsel_stream();
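The new `min_values_per_thread` knob defaults to 16,777,216 (2^24) values and can be overridden with the `POLARS_MIN_VALUES_PER_THREAD` environment variable, clamped to at least 1; per the comment in the diff, setting it to 1 forces column-per-thread parallelism, which is useful for reproducing bugs. A standalone sketch of the same parsing logic:

```rust
use std::env;

/// Read POLARS_MIN_VALUES_PER_THREAD, defaulting to 16_777_216 (2^24) values.
/// Clamped to >= 1, as in the diff above.
fn min_values_per_thread() -> usize {
    env::var("POLARS_MIN_VALUES_PER_THREAD")
        .map(|x| x.parse::<usize>().expect("integer").max(1))
        .unwrap_or(16_777_216)
}

fn main() {
    println!("min values per thread: {}", min_values_per_thread());
}
```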
crates/polars-stream/src/nodes/parquet_source/row_group_data_fetch.rs
Expand Up @@ -23,7 +23,7 @@ use crate::utils::task_handles_ext;

/// Represents byte-data that can be transformed into a DataFrame after some computation.
pub(super) struct RowGroupData {
pub(super) byte_source: FetchedBytes,
pub(super) fetched_bytes: FetchedBytes,
pub(super) path_index: usize,
pub(super) row_offset: usize,
pub(super) slice: Option<(usize, usize)>,
@@ -167,7 +167,7 @@ impl RowGroupDataFetcher {
             // Push calculation of byte ranges to a task to run in parallel, as it can be
             // expensive for very wide tables and projections.
             let handle = async_executor::spawn(TaskPriority::Low, async move {
-                let byte_source = if let DynByteSource::MemSlice(mem_slice) =
+                let fetched_bytes = if let DynByteSource::MemSlice(mem_slice) =
                     current_byte_source.as_ref()
                 {
                     // Skip byte range calculation for `no_prefetch`.
@@ -251,7 +251,7 @@ impl RowGroupDataFetcher {
                 };
 
                 PolarsResult::Ok(RowGroupData {
-                    byte_source,
+                    fetched_bytes,
                     path_index: current_path_index,
                     row_offset: current_row_offset,
                     slice,
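The rename from `byte_source` to `fetched_bytes` aligns the field name with its type: a `DynByteSource` is a handle that bytes can still be fetched from, whereas `FetchedBytes` holds data already resident in memory. A minimal sketch of that distinction, with simplified stand-in types rather than the actual polars definitions:

```rust
/// Stand-in for DynByteSource: something bytes can still be fetched from.
enum ByteSource {
    InMemory(Vec<u8>),
    File(String), // a path; materializing it requires I/O
}

/// Stand-in for FetchedBytes: bytes already resident in memory.
struct FetchedBytes(Vec<u8>);

fn fetch(source: &ByteSource) -> std::io::Result<FetchedBytes> {
    match source {
        // Already in memory: no byte-range calculation or I/O needed.
        ByteSource::InMemory(bytes) => Ok(FetchedBytes(bytes.clone())),
        // Otherwise the bytes must actually be read.
        ByteSource::File(path) => Ok(FetchedBytes(std::fs::read(path)?)),
    }
}

fn main() {
    let _on_disk = ByteSource::File("data.parquet".to_string()); // would require I/O
    let fetched_bytes = fetch(&ByteSource::InMemory(vec![1, 2, 3])).unwrap();
    assert_eq!(fetched_bytes.0.len(), 3);
}
```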