[PERF] Pass-through multithreaded_io flag in read_parquet #1484

Merged 2 commits on Oct 11, 2023
5 changes: 5 additions & 0 deletions daft/daft.pyi
@@ -178,6 +178,11 @@ class ParquetSourceConfig:
Configuration of a Parquet data source.
"""

# Whether or not to use a multithreaded tokio runtime for processing I/O
multithreaded_io: bool

def __init__(self, multithreaded_io: bool): ...

class CsvSourceConfig:
"""
Configuration of a CSV data source.
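As a minimal usage sketch (not part of the diff), the new flag is set when the Parquet source config is constructed and then wrapped into a FileFormatConfig, mirroring what read_parquet does further down:

    from daft.daft import FileFormatConfig, ParquetSourceConfig

    # Build the Parquet source config with the new flag, then wrap it for the scan.
    parquet_config = ParquetSourceConfig(multithreaded_io=True)
    file_format_config = FileFormatConfig.from_parquet_config(parquet_config)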
1 change: 1 addition & 0 deletions daft/execution/execution_step.py
@@ -401,6 +401,7 @@ def _handle_tabular_files_scan(
schema=self.schema,
storage_config=self.storage_config,
read_options=read_options,
multithreaded_io=format_config.multithreaded_io,
)
for fp in filepaths
]
17 changes: 12 additions & 5 deletions daft/io/_parquet.py
@@ -1,12 +1,14 @@
# isort: dont-add-import: from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union

import fsspec

from daft import context
from daft.api_annotations import PublicAPI
from daft.daft import (
FileFormatConfig,
IOConfig,
NativeStorageConfig,
ParquetSourceConfig,
PythonStorageConfig,
@@ -16,9 +18,6 @@
from daft.datatype import DataType
from daft.io.common import _get_tabular_files_scan

if TYPE_CHECKING:
from daft.io import IOConfig


@PublicAPI
def read_parquet(
@@ -53,7 +52,15 @@ def read_parquet(
if isinstance(path, list) and len(path) == 0:
raise ValueError("Cannot read DataFrame from an empty list of Parquet filepaths")

file_format_config = FileFormatConfig.from_parquet_config(ParquetSourceConfig())
# If running on Ray, we want to limit the amount of concurrency and requests being made.
# This is because each Ray worker process receives its own pool of thread workers and connections
multithreaded_io = not context.get_context().is_ray_runner

file_format_config = FileFormatConfig.from_parquet_config(
ParquetSourceConfig(
multithreaded_io=multithreaded_io,
)
)
if use_native_downloader:
storage_config = StorageConfig.native(NativeStorageConfig(io_config))
else:
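For end users nothing changes at the call site; the flag is derived from the runner automatically. A hedged sketch (the bucket path is a placeholder):

    import daft

    # On the local runner this read uses the multithreaded tokio I/O runtime;
    # on the Ray runner each worker process keeps its I/O pool single-threaded.
    df = daft.read_parquet("s3://my-bucket/data.parquet")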
2 changes: 2 additions & 0 deletions daft/table/table_io.py
@@ -106,6 +106,7 @@ def read_parquet(
storage_config: StorageConfig | None = None,
read_options: TableReadOptions = TableReadOptions(),
parquet_options: TableParseParquetOptions = TableParseParquetOptions(),
multithreaded_io: bool | None = None,
) -> Table:
"""Reads a Table from a Parquet file

@@ -130,6 +131,7 @@
num_rows=read_options.num_rows,
io_config=config.io_config,
coerce_int96_timestamp_unit=parquet_options.coerce_int96_timestamp_unit,
multithreaded_io=multithreaded_io,
)
return _cast_table_to_schema(tbl, read_options=read_options, schema=schema)

2 changes: 1 addition & 1 deletion src/common/io-config/src/python.rs
@@ -211,7 +211,7 @@ impl S3Config {
Ok(self.config.access_key.clone())
}

/// AWS max connections
/// AWS max connections per IO thread
#[getter]
pub fn max_connections(&self) -> PyResult<u32> {
Ok(self.config.max_connections)
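Since max_connections is now interpreted per I/O thread, the effective connection budget scales with the I/O pool size. A small sketch of setting it from Python, assuming the existing S3Config/IOConfig constructors (not shown in this diff):

    from daft.io import IOConfig, S3Config

    # With e.g. 8 I/O worker threads, max_connections=8 now allows up to
    # 8 * 8 = 64 concurrent S3 connections in total (illustrative numbers).
    io_config = IOConfig(s3=S3Config(max_connections=8))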
20 changes: 18 additions & 2 deletions src/daft-io/src/lib.rs
@@ -17,6 +17,7 @@ pub use common_io_config::{AzureConfig, IOConfig, S3Config};
pub use object_io::GetResult;
#[cfg(feature = "python")]
pub use python::register_modules;
use tokio::runtime::RuntimeFlavor;

use std::{borrow::Cow, collections::HashMap, hash::Hash, ops::Range, sync::Arc};

@@ -261,16 +262,17 @@ pub fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> {
type CacheKey = (bool, Arc<IOConfig>);
lazy_static! {
static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
static ref THREADED_RUNTIME_NUM_WORKER_THREADS: usize = 8.min(*NUM_CPUS);
static ref THREADED_RUNTIME: tokio::sync::RwLock<(Arc<tokio::runtime::Runtime>, usize)> =
tokio::sync::RwLock::new((
Arc::new(
tokio::runtime::Builder::new_multi_thread()
.worker_threads(8.min(*NUM_CPUS))
.worker_threads(*THREADED_RUNTIME_NUM_WORKER_THREADS)
.enable_all()
.build()
.unwrap()
),
8.min(*NUM_CPUS)
*THREADED_RUNTIME_NUM_WORKER_THREADS,
));
static ref CLIENT_CACHE: tokio::sync::RwLock<HashMap<CacheKey, Arc<IOClient>>> =
tokio::sync::RwLock::new(HashMap::new());
@@ -332,6 +334,20 @@ pub fn set_io_pool_num_threads(num_threads: usize) -> bool {
true
}

pub async fn get_io_pool_num_threads() -> Option<usize> {
match tokio::runtime::Handle::try_current() {
Ok(handle) => {
match handle.runtime_flavor() {
RuntimeFlavor::CurrentThread => Some(1),
RuntimeFlavor::MultiThread => Some(THREADED_RUNTIME.read().await.1),
// RuntimeFlavor is #[non_exhaustive], so we default to 1 here to be conservative
_ => Some(1),
}
}
Err(_) => None,
}
}

pub fn _url_download(
array: &Utf8Array,
max_connections: usize,
7 changes: 5 additions & 2 deletions src/daft-io/src/s3_like.rs
@@ -9,7 +9,7 @@ use s3::operation::list_objects_v2::ListObjectsV2Error;
use tokio::sync::{OwnedSemaphorePermit, SemaphorePermit};

use crate::object_io::{FileMetadata, FileType, LSResult};
use crate::{InvalidArgumentSnafu, SourceType};
use crate::{get_io_pool_num_threads, InvalidArgumentSnafu, SourceType};
use aws_config::SdkConfig;
use aws_credential_types::cache::ProvideCachedCredentials;
use aws_credential_types::provider::error::CredentialsError;
@@ -311,7 +311,10 @@ async fn build_client(config: &S3Config) -> super::Result<S3LikeSource> {
Ok(S3LikeSource {
region_to_client_map: tokio::sync::RwLock::new(client_map),
connection_pool_sema: Arc::new(tokio::sync::Semaphore::new(
config.max_connections as usize,
(config.max_connections as usize)
* get_io_pool_num_threads()
.await
.expect("Should be running in tokio pool"),
)),
s3_config: config.clone(),
default_region,
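To make the new semaphore sizing concrete, a worked example with illustrative numbers (the thread count mirrors THREADED_RUNTIME_NUM_WORKER_THREADS above; these are not guaranteed defaults):

    import os

    num_io_threads = min(8, os.cpu_count() or 1)  # 8.min(*NUM_CPUS) on the Rust side
    max_connections_per_thread = 8                # S3Config.max_connections
    total_permits = max_connections_per_thread * num_io_threads
    # On a 16-core machine: 8 * 8 = 64 permits, versus a flat 8 before this change.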
13 changes: 10 additions & 3 deletions src/daft-plan/src/source_info/file_format.rs
@@ -52,15 +52,22 @@ impl FileFormatConfig {
/// Configuration for a Parquet data source.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
pub struct ParquetSourceConfig;
pub struct ParquetSourceConfig {
multithreaded_io: bool,
}

#[cfg(feature = "python")]
#[pymethods]
impl ParquetSourceConfig {
/// Create a config for a Parquet data source.
#[new]
fn new() -> Self {
Self {}
fn new(multithreaded_io: bool) -> Self {
Self { multithreaded_io }
}

#[getter]
fn multithreaded_io(&self) -> PyResult<bool> {
Ok(self.multithreaded_io)
}
}

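The new field round-trips through the Python binding; a minimal check using the constructor and getter added above:

    from daft.daft import ParquetSourceConfig

    cfg = ParquetSourceConfig(multithreaded_io=False)
    assert cfg.multithreaded_io is False  # read back through the #[getter]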