diff --git a/src/common/io-config/src/s3.rs b/src/common/io-config/src/s3.rs
index 11a75a1224..e18a78d45d 100644
--- a/src/common/io-config/src/s3.rs
+++ b/src/common/io-config/src/s3.rs
@@ -32,8 +32,8 @@ impl Default for S3Config {
             access_key: None,
             max_connections_per_io_thread: 8,
             retry_initial_backoff_ms: 1000,
-            connect_timeout_ms: 10_000,
-            read_timeout_ms: 10_000,
+            connect_timeout_ms: 30_000,
+            read_timeout_ms: 30_000,
             // AWS EMR actually does 100 tries by default for AIMD retries
             // (See [Advanced AIMD retry settings]: https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-spark-emrfs-retry.html)
             num_tries: 25,
diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs
index c839d82e8b..70e3f7b1fd 100644
--- a/src/daft-io/src/lib.rs
+++ b/src/daft-io/src/lib.rs
@@ -99,6 +99,9 @@ pub enum Error {

     #[snafu(display("Error joining spawned task: {}", source), context(false))]
     JoinError { source: tokio::task::JoinError },
+
+    #[snafu(display("Cached error: {}", source))]
+    CachedError { source: Arc<Error> },
 }

 impl From<Error> for DaftError {
diff --git a/src/daft-parquet/src/lib.rs b/src/daft-parquet/src/lib.rs
index 637ab2cf91..a23af268f8 100644
--- a/src/daft-parquet/src/lib.rs
+++ b/src/daft-parquet/src/lib.rs
@@ -1,5 +1,6 @@
 #![feature(async_closure)]
 #![feature(let_chains)]
+#![feature(result_flattening)]

 use common_error::DaftError;
 use snafu::Snafu;
diff --git a/src/daft-parquet/src/read_planner.rs b/src/daft-parquet/src/read_planner.rs
index 76e557fef6..391e7320d5 100644
--- a/src/daft-parquet/src/read_planner.rs
+++ b/src/daft-parquet/src/read_planner.rs
@@ -3,7 +3,7 @@ use std::{fmt::Display, ops::Range, sync::Arc};
 use bytes::Bytes;
 use common_error::DaftResult;
 use daft_io::{IOClient, IOStatsRef};
-use futures::StreamExt;
+use futures::{StreamExt, TryStreamExt};
 use tokio::task::JoinHandle;

 type RangeList = Vec<Range<usize>>;
@@ -86,7 +86,8 @@ impl ReadPlanPass for SplitLargeRequestPass {

 enum RangeCacheState {
     InFlight(JoinHandle<std::result::Result<Bytes, daft_io::Error>>),
-    Ready(Bytes),
+    // Ready-state stores either the fetched bytes, or a shared error if the fetch failed.
+    Ready(std::result::Result<Bytes, Arc<daft_io::Error>>),
 }

 struct RangeCacheEntry {
@@ -99,16 +100,25 @@ impl RangeCacheEntry {
     async fn get_or_wait(&self, range: Range<usize>) -> std::result::Result<Bytes, daft_io::Error> {
         {
             let mut _guard = self.state.lock().await;
-            match &mut (*_guard) {
+            match &mut *_guard {
                 RangeCacheState::InFlight(f) => {
                     // TODO(sammy): thread in url for join error
                     let v = f
                         .await
-                        .map_err(|e| daft_io::Error::JoinError { source: e })??;
-                    *_guard = RangeCacheState::Ready(v.clone());
-                    Ok(v.slice(range))
+                        .map_err(|e| daft_io::Error::JoinError { source: e })
+                        .flatten()
+                        .map_err(Arc::new);
+                    let sliced = v
+                        .as_ref()
+                        .map(|b| b.slice(range))
+                        .map_err(|e| daft_io::Error::CachedError { source: e.clone() });
+                    *_guard = RangeCacheState::Ready(v);
+                    sliced
                 }
-                RangeCacheState::Ready(v) => Ok(v.slice(range)),
+                RangeCacheState::Ready(v) => v
+                    .as_ref()
+                    .map(|b| b.slice(range))
+                    .map_err(|e| daft_io::Error::CachedError { source: e.clone() }),
             }
         }
     }
@@ -247,7 +257,8 @@ impl RangesContainer {
         assert_eq!(current_pos, range.end);

         let bytes_iter = tokio_stream::iter(needed_entries.into_iter().zip(ranges_to_slice))
-            .then(|(e, r)| async move { e.get_or_wait(r).await });
+            .then(|(e, r)| async move { e.get_or_wait(r).await })
+            .inspect_err(|e| panic!("Reading a range of Parquet bytes failed: {}", e));

         let stream_reader = tokio_util::io::StreamReader::new(bytes_iter);
         let convert = async_compat::Compat::new(stream_reader);
diff --git a/tests/integration/io/parquet/test_reads_public_data.py b/tests/integration/io/parquet/test_reads_public_data.py
index e05257112b..54fe4fa055 100644
--- a/tests/integration/io/parquet/test_reads_public_data.py
+++ b/tests/integration/io/parquet/test_reads_public_data.py
@@ -377,3 +377,45 @@ def test_row_groups_selection_into_pyarrow_bulk(public_storage_io_config, multit
     for i, t in enumerate(rest):
         assert len(t) == 10
         assert first[i * 10 : (i + 1) * 10] == t
+
+
+@pytest.mark.integration()
+@pytest.mark.parametrize(
+    "multithreaded_io",
+    [False, True],
+)
+def test_connect_timeout(multithreaded_io):
+    url = "s3://daft-public-data/test_fixtures/parquet-dev/mvp.parquet"
+    connect_timeout_config = daft.io.IOConfig(
+        s3=daft.io.S3Config(
+            # NOTE: no keys or endpoints specified for an AWS public s3 bucket
+            region_name="us-west-2",
+            anonymous=True,
+            connect_timeout_ms=1,
+            num_tries=3,
+        )
+    )
+
+    with pytest.raises(ValueError, match="HTTP connect timeout"):
+        MicroPartition.read_parquet(url, io_config=connect_timeout_config, multithreaded_io=multithreaded_io).to_arrow()
+
+
+@pytest.mark.integration()
+@pytest.mark.parametrize(
+    "multithreaded_io",
+    [False, True],
+)
+def test_read_timeout(multithreaded_io):
+    url = "s3://daft-public-data/test_fixtures/parquet-dev/mvp.parquet"
+    read_timeout_config = daft.io.IOConfig(
+        s3=daft.io.S3Config(
+            # NOTE: no keys or endpoints specified for an AWS public s3 bucket
+            region_name="us-west-2",
+            anonymous=True,
+            read_timeout_ms=1,
+            num_tries=3,
+        )
+    )
+
+    with pytest.raises(ValueError, match="HTTP read timeout"):
+        MicroPartition.read_parquet(url, io_config=read_timeout_config, multithreaded_io=multithreaded_io).to_arrow()
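
Note: the core idea in the `read_planner.rs` change is to cache the *outcome* of an in-flight fetch rather than only its bytes, so a failure is replayed to every reader instead of being lost (or re-polling a `JoinHandle` that has already completed). Below is a minimal, self-contained sketch of that pattern. It is illustrative only: `Entry`, `State`, and the `String` stand-in for `daft_io::Error` are hypothetical names, and it stays on stable Rust by matching on the join result explicitly instead of using the nightly `result_flattening` feature:

```rust
use std::sync::Arc;

use tokio::{sync::Mutex, task::JoinHandle};

// Hypothetical stand-in for `daft_io::Error`, which is not `Clone`.
type FetchError = String;

enum State {
    // A download task has been spawned but not yet awaited.
    InFlight(JoinHandle<Result<Vec<u8>, FetchError>>),
    // The task finished: either the bytes, or a shared, cloneable error.
    Ready(Result<Vec<u8>, Arc<FetchError>>),
}

struct Entry {
    state: Mutex<State>,
}

impl Entry {
    async fn get_or_wait(&self) -> Result<Vec<u8>, Arc<FetchError>> {
        let mut guard = self.state.lock().await;
        match &mut *guard {
            State::InFlight(handle) => {
                // Collapse the join error and the fetch error into one Result,
                // then wrap the error in `Arc` so it can be cloned later.
                // (Stable-Rust equivalent of the `.flatten()` in the diff.)
                let result = match handle.await {
                    Ok(fetch_result) => fetch_result.map_err(Arc::new),
                    Err(join_err) => Err(Arc::new(format!("join error: {join_err}"))),
                };
                // Cache the outcome -- success *and* failure -- so the handle
                // is never polled again after completion.
                *guard = State::Ready(result.clone());
                result
            }
            // Replay the cached outcome to every subsequent caller.
            State::Ready(result) => result.clone(),
        }
    }
}

#[tokio::main]
async fn main() {
    let entry = Entry {
        state: Mutex::new(State::InFlight(tokio::spawn(async {
            Err::<Vec<u8>, FetchError>("HTTP read timeout".to_string())
        }))),
    };
    // Both reads observe the same cached error; the fetch ran only once.
    assert!(entry.get_or_wait().await.is_err());
    assert!(entry.get_or_wait().await.is_err());
}
```

The `Arc` wrapper is what makes the failure branch cacheable: `daft_io::Error` itself is not `Clone`, but `Arc<Error>` is, so each later caller gets a cheap clone of the same error, surfaced in the real diff through the new `CachedError` variant.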