diff --git a/daft/daft.pyi b/daft/daft.pyi
index a905afbae0..7e6f63d9d9 100644
--- a/daft/daft.pyi
+++ b/daft/daft.pyi
@@ -375,6 +375,7 @@ def read_parquet_bulk(
     num_rows: int | None = None,
     row_groups: list[list[int]] | None = None,
     io_config: IOConfig | None = None,
+    num_parallel_tasks: int | None = 128,
     multithreaded_io: bool | None = None,
     coerce_int96_timestamp_unit: PyTimeUnit | None = None,
 ): ...
@@ -400,6 +401,7 @@ def read_parquet_into_pyarrow_bulk(
     num_rows: int | None = None,
     row_groups: list[list[int]] | None = None,
     io_config: IOConfig | None = None,
+    num_parallel_tasks: int | None = 128,
     multithreaded_io: bool | None = None,
     coerce_int96_timestamp_unit: PyTimeUnit | None = None,
 ): ...
diff --git a/daft/table/table.py b/daft/table/table.py
index 5893d63b39..3a464a3058 100644
--- a/daft/table/table.py
+++ b/daft/table/table.py
@@ -396,6 +396,7 @@ def read_parquet_bulk(
     num_rows: int | None = None,
     row_groups_per_path: list[list[int]] | None = None,
     io_config: IOConfig | None = None,
+    num_parallel_tasks: int | None = 128,
     multithreaded_io: bool | None = None,
     coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns(),
 ) -> list[Table]:
@@ -406,6 +407,7 @@ def read_parquet_bulk(
         num_rows=num_rows,
         row_groups=row_groups_per_path,
         io_config=io_config,
+        num_parallel_tasks=num_parallel_tasks,
         multithreaded_io=multithreaded_io,
         coerce_int96_timestamp_unit=coerce_int96_timestamp_unit._timeunit,
     )
@@ -486,6 +488,7 @@ def read_parquet_into_pyarrow_bulk(
     num_rows: int | None = None,
     row_groups_per_path: list[list[int]] | None = None,
     io_config: IOConfig | None = None,
+    num_parallel_tasks: int | None = 128,
     multithreaded_io: bool | None = None,
     coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns(),
 ) -> list[pa.Table]:
@@ -496,6 +499,7 @@ def read_parquet_into_pyarrow_bulk(
         num_rows=num_rows,
         row_groups=row_groups_per_path,
         io_config=io_config,
+        num_parallel_tasks=num_parallel_tasks,
         multithreaded_io=multithreaded_io,
         coerce_int96_timestamp_unit=coerce_int96_timestamp_unit._timeunit,
     )
diff --git a/src/daft-parquet/src/python.rs b/src/daft-parquet/src/python.rs
index f0491b872a..e5630cdc65 100644
--- a/src/daft-parquet/src/python.rs
+++ b/src/daft-parquet/src/python.rs
@@ -118,6 +118,7 @@ pub mod pylib {
         num_rows: Option<usize>,
         row_groups: Option<Vec<Vec<i64>>>,
         io_config: Option<IOConfig>,
+        num_parallel_tasks: Option<i64>,
         multithreaded_io: Option<bool>,
         coerce_int96_timestamp_unit: Option<PyTimeUnit>,
     ) -> PyResult<Vec<PyTable>> {
@@ -137,6 +138,7 @@ pub mod pylib {
             num_rows,
             row_groups,
             io_client,
+            num_parallel_tasks.unwrap_or(128) as usize,
             multithreaded_io.unwrap_or(true),
             &schema_infer_options,
         )?
@@ -156,6 +158,7 @@ pub mod pylib {
         num_rows: Option<usize>,
         row_groups: Option<Vec<Vec<i64>>>,
         io_config: Option<IOConfig>,
+        num_parallel_tasks: Option<i64>,
         multithreaded_io: Option<bool>,
         coerce_int96_timestamp_unit: Option<PyTimeUnit>,
     ) -> PyResult<Vec<PyObject>> {
@@ -175,6 +178,7 @@ pub mod pylib {
             num_rows,
             row_groups,
             io_client,
+            num_parallel_tasks.unwrap_or(128) as usize,
             multithreaded_io.unwrap_or(true),
             &schema_infer_options,
         )
diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs
index 249abaff36..e2253983f8 100644
--- a/src/daft-parquet/src/read.rs
+++ b/src/daft-parquet/src/read.rs
@@ -9,7 +9,7 @@ use daft_core::{
 };
 use daft_io::{get_runtime, IOClient};
 use daft_table::Table;
-use futures::future::{join_all, try_join_all};
+use futures::{future::join_all, StreamExt, TryStreamExt};
 use snafu::ResultExt;
 
 use crate::{file::ParquetReaderBuilder, JoinSnafu};
@@ -321,6 +321,7 @@ pub fn read_parquet_bulk(
     num_rows: Option<usize>,
     row_groups: Option<Vec<Vec<i64>>>,
     io_client: Arc<IOClient>,
+    num_parallel_tasks: usize,
     multithreaded_io: bool,
     schema_infer_options: &ParquetSchemaInferenceOptions,
 ) -> DaftResult<Vec<Table>> {
@@ -338,7 +339,7 @@ pub fn read_parquet_bulk(
     }
     let tables = runtime_handle
         .block_on(async move {
-            try_join_all(uris.iter().enumerate().map(|(i, uri)| {
+            let task_stream = futures::stream::iter(uris.iter().enumerate().map(|(i, uri)| {
                 let uri = uri.to_string();
                 let owned_columns = owned_columns.clone();
                 let owned_row_group = match &row_groups {
@@ -352,22 +353,31 @@ pub fn read_parquet_bulk(
                     let columns = owned_columns
                         .as_ref()
                         .map(|s| s.iter().map(AsRef::as_ref).collect::<Vec<_>>());
-                    read_parquet_single(
-                        &uri,
-                        columns.as_deref(),
-                        start_offset,
-                        num_rows,
-                        owned_row_group.as_deref(),
-                        io_client,
-                        &schema_infer_options,
-                    )
-                    .await
+                    Ok((
+                        i,
+                        read_parquet_single(
+                            &uri,
+                            columns.as_deref(),
+                            start_offset,
+                            num_rows,
+                            owned_row_group.as_deref(),
+                            io_client,
+                            &schema_infer_options,
+                        )
+                        .await?,
+                    ))
                 })
-            }))
-            .await
+            }));
+            task_stream
+                .buffer_unordered(num_parallel_tasks)
+                .try_collect::<Vec<_>>()
+                .await
         })
         .context(JoinSnafu { path: "UNKNOWN" })?;
-    tables.into_iter().collect::<DaftResult<Vec<_>>>()
+
+    let mut collected = tables.into_iter().collect::<DaftResult<Vec<_>>>()?;
+    collected.sort_by_key(|(idx, _)| *idx);
+    Ok(collected.into_iter().map(|(_, v)| v).collect())
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -378,6 +388,7 @@ pub fn read_parquet_into_pyarrow_bulk(
     num_rows: Option<usize>,
     row_groups: Option<Vec<Vec<i64>>>,
     io_client: Arc<IOClient>,
+    num_parallel_tasks: usize,
     multithreaded_io: bool,
     schema_infer_options: &ParquetSchemaInferenceOptions,
 ) -> DaftResult<Vec<ParquetPyarrowChunk>> {
@@ -395,7 +406,7 @@ pub fn read_parquet_into_pyarrow_bulk(
     }
     let tables = runtime_handle
         .block_on(async move {
-            try_join_all(uris.iter().enumerate().map(|(i, uri)| {
+            futures::stream::iter(uris.iter().enumerate().map(|(i, uri)| {
                 let uri = uri.to_string();
                 let owned_columns = owned_columns.clone();
                 let owned_row_group = match &row_groups {
@@ -409,22 +420,29 @@ pub fn read_parquet_into_pyarrow_bulk(
                     let columns = owned_columns
                         .as_ref()
                         .map(|s| s.iter().map(AsRef::as_ref).collect::<Vec<_>>());
-                    read_parquet_single_into_arrow(
-                        &uri,
-                        columns.as_deref(),
-                        start_offset,
-                        num_rows,
-                        owned_row_group.as_deref(),
-                        io_client,
-                        &schema_infer_options,
-                    )
-                    .await
+                    Ok((
+                        i,
+                        read_parquet_single_into_arrow(
+                            &uri,
+                            columns.as_deref(),
+                            start_offset,
+                            num_rows,
+                            owned_row_group.as_deref(),
+                            io_client,
+                            &schema_infer_options,
+                        )
+                        .await?,
+                    ))
                 })
             }))
+            .buffer_unordered(num_parallel_tasks)
+            .try_collect::<Vec<_>>()
             .await
         })
         .context(JoinSnafu { path: "UNKNOWN" })?;
-    tables.into_iter().collect::<DaftResult<Vec<_>>>()
+    let mut collected = tables.into_iter().collect::<DaftResult<Vec<_>>>()?;
+    collected.sort_by_key(|(idx, _)| *idx);
+    Ok(collected.into_iter().map(|(_, v)| v).collect())
 }
 
 pub fn read_parquet_schema(
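
The core change above swaps try_join_all, which starts every read at once, for futures::stream::buffer_unordered, which caps the number of in-flight reads at num_parallel_tasks (defaulting to 128) and then re-sorts the completed results back into input order. Below is a minimal, self-contained sketch of that pattern, assuming only the futures and tokio crates; fake_read and read_bulk are hypothetical stand-ins for read_parquet_single and read_parquet_bulk, not code from this PR.

use futures::{StreamExt, TryStreamExt};

// Hypothetical stand-in for read_parquet_single.
async fn fake_read(uri: &str) -> Result<String, std::io::Error> {
    Ok(format!("table from {uri}"))
}

async fn read_bulk(
    uris: Vec<String>,
    num_parallel_tasks: usize,
) -> Result<Vec<String>, std::io::Error> {
    // Build a stream of futures, tagging each result with its input index,
    // since completion order will be arbitrary.
    let tasks = futures::stream::iter(uris.iter().enumerate().map(|(i, uri)| async move {
        Ok::<_, std::io::Error>((i, fake_read(uri).await?))
    }));
    let mut collected: Vec<(usize, String)> = tasks
        .buffer_unordered(num_parallel_tasks) // at most this many reads in flight
        .try_collect()
        .await?;
    // buffer_unordered yields in completion order; restore input order.
    collected.sort_by_key(|(idx, _)| *idx);
    Ok(collected.into_iter().map(|(_, table)| table).collect())
}

fn main() {
    let uris: Vec<String> = (0..5).map(|n| format!("file{n}.parquet")).collect();
    let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
    let tables = rt.block_on(read_bulk(uris, 2)).unwrap();
    assert_eq!(tables[0], "table from file0.parquet");
}

Because buffer_unordered yields results as they complete rather than in submission order, the index tag plus the final sort_by_key is what preserves the one-to-one mapping between input URIs and output tables, matching what the diff does in read.rs.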