From f72df5474a5250d3d79a5aa39f7673130405707e Mon Sep 17 00:00:00 2001
From: Jay Chia
Date: Wed, 11 Oct 2023 09:59:38 -0700
Subject: [PATCH 1/2] Pass-through multithreaded_io flag in read_parquet

---
 daft/daft.pyi                                |  5 +++++
 daft/execution/execution_step.py             |  1 +
 daft/io/_parquet.py                          | 17 ++++++++++++-----
 daft/table/table_io.py                       |  2 ++
 src/common/io-config/src/python.rs           |  2 +-
 src/daft-io/src/lib.rs                       | 15 +++++++++++++--
 src/daft-io/src/s3_like.rs                   |  5 +++--
 src/daft-plan/src/source_info/file_format.rs | 13 ++++++++++---
 8 files changed, 47 insertions(+), 13 deletions(-)

diff --git a/daft/daft.pyi b/daft/daft.pyi
index d8c843080b..786d00e374 100644
--- a/daft/daft.pyi
+++ b/daft/daft.pyi
@@ -178,6 +178,11 @@ class ParquetSourceConfig:
     """
     Configuration of a Parquet data source.
     """

+    # Whether to use a multithreaded tokio runtime for processing I/O
+    multithreaded_io: bool
+
+    def __init__(self, multithreaded_io: bool): ...
+
 class CsvSourceConfig:
     """
     Configuration of a CSV data source.

diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py
index 896c7f3e8b..37052b39c6 100644
--- a/daft/execution/execution_step.py
+++ b/daft/execution/execution_step.py
@@ -401,6 +401,7 @@ def _handle_tabular_files_scan(
                 schema=self.schema,
                 storage_config=self.storage_config,
                 read_options=read_options,
+                multithreaded_io=format_config.multithreaded_io,
             )
             for fp in filepaths
         ]

diff --git a/daft/io/_parquet.py b/daft/io/_parquet.py
index 1e6558dd9b..0228cc1d0a 100644
--- a/daft/io/_parquet.py
+++ b/daft/io/_parquet.py
@@ -1,12 +1,14 @@
 # isort: dont-add-import: from __future__ import annotations

-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union

 import fsspec

+from daft import context
 from daft.api_annotations import PublicAPI
 from daft.daft import (
     FileFormatConfig,
+    IOConfig,
     NativeStorageConfig,
     ParquetSourceConfig,
     PythonStorageConfig,
@@ -16,9 +18,6 @@
 from daft.datatype import DataType
 from daft.io.common import _get_tabular_files_scan

-if TYPE_CHECKING:
-    from daft.io import IOConfig
-

 @PublicAPI
 def read_parquet(
@@ -53,7 +52,15 @@ def read_parquet(
     if isinstance(path, list) and len(path) == 0:
         raise ValueError("Cannot read DataFrame from empty list of Parquet filepaths")

-    file_format_config = FileFormatConfig.from_parquet_config(ParquetSourceConfig())
+    # If running on Ray, we want to limit the amount of concurrency and requests being made,
+    # because each Ray worker process receives its own pool of thread workers and connections.
+    multithreaded_io = not context.get_context().is_ray_runner
+
+    file_format_config = FileFormatConfig.from_parquet_config(
+        ParquetSourceConfig(
+            multithreaded_io=multithreaded_io,
+        )
+    )
     if use_native_downloader:
         storage_config = StorageConfig.native(NativeStorageConfig(io_config))
     else:
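
What the flag buys downstream: on the Rust side the boolean ends up choosing between a current-thread tokio runtime (one per Ray worker process, since Ray already provides process-level parallelism) and one shared, bounded multi-thread runtime for local execution. A minimal sketch of that selection, assuming an illustrative entry point (`get_runtime` is an assumed name here, not a confirmed daft-io signature):

use std::sync::Arc;
use tokio::runtime::{Builder, Runtime};

// Sketch: pick the runtime flavor from the multithreaded_io flag.
fn get_runtime(multithreaded_io: bool) -> std::io::Result<Arc<Runtime>> {
    let mut builder = if multithreaded_io {
        let mut b = Builder::new_multi_thread();
        // Cap the pool at 8 workers so heavy local reads don't oversubscribe the host.
        b.worker_threads(8.min(std::thread::available_parallelism()?.get()));
        b
    } else {
        // One thread per process: Ray workers supply the parallelism instead.
        Builder::new_current_thread()
    };
    // enable_all() switches on the I/O and timer drivers the object-store clients need.
    Ok(Arc::new(builder.enable_all().build()?))
}
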
diff --git a/daft/table/table_io.py b/daft/table/table_io.py
index 4cb0276c34..b12ffb9feb 100644
--- a/daft/table/table_io.py
+++ b/daft/table/table_io.py
@@ -106,6 +106,7 @@ def read_parquet(
     storage_config: StorageConfig | None = None,
     read_options: TableReadOptions = TableReadOptions(),
     parquet_options: TableParseParquetOptions = TableParseParquetOptions(),
+    multithreaded_io: bool | None = None,
 ) -> Table:
     """Reads a Table from a Parquet file

@@ -130,6 +131,7 @@
             num_rows=read_options.num_rows,
             io_config=config.io_config,
             coerce_int96_timestamp_unit=parquet_options.coerce_int96_timestamp_unit,
+            multithreaded_io=multithreaded_io,
         )
         return _cast_table_to_schema(tbl, read_options=read_options, schema=schema)

diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs
index 943693e410..91aed62b28 100644
--- a/src/common/io-config/src/python.rs
+++ b/src/common/io-config/src/python.rs
@@ -211,7 +211,7 @@ impl S3Config {
         Ok(self.config.access_key.clone())
     }

-    /// AWS max connections
+    /// AWS max connections per IO thread
     #[getter]
     pub fn max_connections(&self) -> PyResult<u32> {
         Ok(self.config.max_connections)

diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs
index 417313f047..09032eafab 100644
--- a/src/daft-io/src/lib.rs
+++ b/src/daft-io/src/lib.rs
@@ -17,6 +17,7 @@ pub use common_io_config::{AzureConfig, IOConfig, S3Config};
 pub use object_io::GetResult;
 #[cfg(feature = "python")]
 pub use python::register_modules;
+use tokio::runtime::RuntimeFlavor;

 use std::{borrow::Cow, collections::HashMap, hash::Hash, ops::Range, sync::Arc};

@@ -261,16 +262,17 @@ pub fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> {
 type CacheKey = (bool, Arc<IOConfig>);
 lazy_static! {
     static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
+    static ref THREADED_RUNTIME_NUM_WORKER_THREADS: usize = 8.min(*NUM_CPUS);
     static ref THREADED_RUNTIME: tokio::sync::RwLock<(Arc<tokio::runtime::Runtime>, usize)> =
         tokio::sync::RwLock::new((
             Arc::new(
                 tokio::runtime::Builder::new_multi_thread()
-                    .worker_threads(8.min(*NUM_CPUS))
+                    .worker_threads(*THREADED_RUNTIME_NUM_WORKER_THREADS)
                     .enable_all()
                     .build()
                     .unwrap()
             ),
-            8.min(*NUM_CPUS)
+            *THREADED_RUNTIME_NUM_WORKER_THREADS,
         ));
     static ref CLIENT_CACHE: tokio::sync::RwLock<HashMap<CacheKey, Arc<IOClient>>> =
         tokio::sync::RwLock::new(HashMap::new());
@@ -332,6 +334,15 @@ pub fn set_io_pool_num_threads(num_threads: usize) -> bool {
     true
 }

+pub fn get_io_pool_num_threads() -> Option<usize> {
+    tokio::runtime::Handle::try_current().map_or(None, |handle| match handle.runtime_flavor() {
+        RuntimeFlavor::CurrentThread => Some(1),
+        RuntimeFlavor::MultiThread => Some(THREADED_RUNTIME.blocking_read().1),
+        // RuntimeFlavor is #[non_exhaustive], so we default to 1 here to be conservative
+        _ => Some(1),
+    })
+}
+
 pub fn _url_download(
     array: &Utf8Array,
     max_connections: usize,
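
For reference, the flavor probe added above is plain tokio: Handle::try_current() fails when no runtime is entered, and runtime_flavor() distinguishes the current-thread from the multi-thread scheduler. A self-contained sketch of that behavior, independent of daft's statics (the pool size is a stand-in parameter):

use tokio::runtime::{Builder, Handle, RuntimeFlavor};

// Report how many I/O worker threads the ambient runtime implies:
// one for a current-thread runtime, the configured pool size otherwise.
fn ambient_io_threads(configured_pool_size: usize) -> Option<usize> {
    Handle::try_current().ok().map(|handle| match handle.runtime_flavor() {
        RuntimeFlavor::CurrentThread => 1,
        RuntimeFlavor::MultiThread => configured_pool_size,
        // RuntimeFlavor is #[non_exhaustive]; be conservative for unknown flavors.
        _ => 1,
    })
}

fn main() {
    // Outside any runtime there is no handle at all.
    assert_eq!(ambient_io_threads(8), None);

    let rt = Builder::new_current_thread().enable_all().build().unwrap();
    rt.block_on(async {
        // Inside a current-thread runtime, the answer is always 1.
        assert_eq!(ambient_io_threads(8), Some(1));
    });
}
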
diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs
index 3a549d2afb..490764cd98 100644
--- a/src/daft-io/src/s3_like.rs
+++ b/src/daft-io/src/s3_like.rs
@@ -9,7 +9,7 @@ use s3::operation::list_objects_v2::ListObjectsV2Error;
 use tokio::sync::{OwnedSemaphorePermit, SemaphorePermit};

 use crate::object_io::{FileMetadata, FileType, LSResult};
-use crate::{InvalidArgumentSnafu, SourceType};
+use crate::{get_io_pool_num_threads, InvalidArgumentSnafu, SourceType};
 use aws_config::SdkConfig;
 use aws_credential_types::cache::ProvideCachedCredentials;
 use aws_credential_types::provider::error::CredentialsError;
@@ -311,7 +311,8 @@ async fn build_client(config: &S3Config) -> super::Result<S3LikeSource> {
     Ok(S3LikeSource {
         region_to_client_map: tokio::sync::RwLock::new(client_map),
         connection_pool_sema: Arc::new(tokio::sync::Semaphore::new(
-            config.max_connections as usize,
+            (config.max_connections as usize)
+                * get_io_pool_num_threads().expect("Should be running in tokio pool"),
         )),
         s3_config: config.clone(),
         default_region,

diff --git a/src/daft-plan/src/source_info/file_format.rs b/src/daft-plan/src/source_info/file_format.rs
index fdcb483e61..1e4988f4d5 100644
--- a/src/daft-plan/src/source_info/file_format.rs
+++ b/src/daft-plan/src/source_info/file_format.rs
@@ -52,15 +52,22 @@ impl FileFormatConfig {
 /// Configuration for a Parquet data source.
 #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
 #[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
-pub struct ParquetSourceConfig;
+pub struct ParquetSourceConfig {
+    multithreaded_io: bool,
+}

 #[cfg(feature = "python")]
 #[pymethods]
 impl ParquetSourceConfig {
     /// Create a config for a Parquet data source.
     #[new]
-    fn new() -> Self {
-        Self {}
+    fn new(multithreaded_io: bool) -> Self {
+        Self { multithreaded_io }
+    }
+
+    #[getter]
+    fn multithreaded_io(&self) -> PyResult<bool> {
+        Ok(self.multithreaded_io)
     }
 }
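
With max_connections now interpreted per I/O thread, the semaphore above is sized as a product of the per-thread budget and the pool size. A small, self-contained sketch of that budgeting with illustrative numbers (64 connections per thread on an 8-thread pool; daft's real permit holder is S3LikeSource):

use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    // Per-thread budget times pool size gives the global permit count:
    // 64 connections/thread * 8 I/O threads = 512 permits.
    let max_connections_per_thread: usize = 64;
    let io_pool_threads: usize = 8;
    let sema = Arc::new(Semaphore::new(max_connections_per_thread * io_pool_threads));

    // Each in-flight S3 request would hold one permit for its duration.
    let permit = sema.clone().acquire_owned().await.unwrap();
    assert_eq!(sema.available_permits(), 511);
    drop(permit); // returning the permit frees a connection slot
}
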
From fdb084964062266ee20666f8ebdac85ed97330c9 Mon Sep 17 00:00:00 2001
From: Jay Chia
Date: Wed, 11 Oct 2023 10:58:32 -0700
Subject: [PATCH 2/2] Make function async and use a non-blocking read

---
 src/daft-io/src/lib.rs     | 19 ++++++++++++-------
 src/daft-io/src/s3_like.rs |  4 +++-
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs
index 09032eafab..3f14834fb3 100644
--- a/src/daft-io/src/lib.rs
+++ b/src/daft-io/src/lib.rs
@@ -334,13 +334,18 @@ pub fn set_io_pool_num_threads(num_threads: usize) -> bool {
     true
 }

-pub fn get_io_pool_num_threads() -> Option<usize> {
-    tokio::runtime::Handle::try_current().map_or(None, |handle| match handle.runtime_flavor() {
-        RuntimeFlavor::CurrentThread => Some(1),
-        RuntimeFlavor::MultiThread => Some(THREADED_RUNTIME.blocking_read().1),
-        // RuntimeFlavor is #[non_exhaustive], so we default to 1 here to be conservative
-        _ => Some(1),
-    })
+pub async fn get_io_pool_num_threads() -> Option<usize> {
+    match tokio::runtime::Handle::try_current() {
+        Ok(handle) => {
+            match handle.runtime_flavor() {
+                RuntimeFlavor::CurrentThread => Some(1),
+                RuntimeFlavor::MultiThread => Some(THREADED_RUNTIME.read().await.1),
+                // RuntimeFlavor is #[non_exhaustive], so we default to 1 here to be conservative
+                _ => Some(1),
+            }
+        }
+        Err(_) => None,
+    }
 }

 pub fn _url_download(
diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs
index 490764cd98..46b9831712 100644
--- a/src/daft-io/src/s3_like.rs
+++ b/src/daft-io/src/s3_like.rs
@@ -312,7 +312,9 @@ async fn build_client(config: &S3Config) -> super::Result<S3LikeSource> {
         region_to_client_map: tokio::sync::RwLock::new(client_map),
         connection_pool_sema: Arc::new(tokio::sync::Semaphore::new(
             (config.max_connections as usize)
-                * get_io_pool_num_threads().expect("Should be running in tokio pool"),
+                * get_io_pool_num_threads()
+                    .await
+                    .expect("Should be running in tokio pool"),
         )),
         s3_config: config.clone(),
         default_region,
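
The motivation for this second commit: tokio's RwLock::blocking_read panics when invoked from inside an async execution context, which is exactly where build_client runs, so the function has to take the async path. A minimal reproduction of the hazard, using only tokio primitives:

use tokio::sync::RwLock;

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let lock = RwLock::new(42usize);

    // Fine: the async read yields to the scheduler instead of blocking.
    let value = *lock.read().await;
    assert_eq!(value, 42);

    // Would panic: blocking_read() asserts it is NOT inside an async
    // runtime, to prevent stalling or deadlocking the executor.
    // let _ = lock.blocking_read();
}

Switching to read().await lets the lock acquisition suspend the task rather than block an executor thread, at the cost of making every caller async.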