diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 03ddac2f3b..d83a486e65 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -1,5 +1,5 @@ #![feature(async_closure)] - +#![feature(let_chains)] mod azure_blob; mod google_cloud; mod http; diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs index 8de8940d7c..d793c129c8 100644 --- a/src/daft-io/src/object_io.rs +++ b/src/daft-io/src/object_io.rs @@ -117,20 +117,27 @@ pub(crate) async fn recursive_iter( uri: &str, ) -> super::Result>> { let (to_rtn_tx, mut to_rtn_rx) = tokio::sync::mpsc::channel(16 * 1024); - fn add_to_channel(source: Arc, tx: Sender, dir: String) { + fn add_to_channel( + source: Arc, + tx: Sender>, + dir: String, + ) { tokio::spawn(async move { - let mut s = source.iter_dir(&dir, None, None).await.unwrap(); + let s = source.iter_dir(&dir, None, None).await; + let mut s = match s { + Ok(s) => s, + Err(e) => { + tx.send(Err(e)).await.unwrap(); + return; + } + }; let tx = &tx; while let Some(tr) = s.next().await { - let tr = tr.unwrap(); - match tr.filetype { - FileType::File => tx.send(tr).await.unwrap(), - FileType::Directory => { - let dirpath = tr.filepath.clone(); - tx.send(tr).await.unwrap(); - add_to_channel(source.clone(), tx.clone(), dirpath) - } - }; + let tr = tr; + if let Ok(ref tr) = tr && matches!(tr.filetype, FileType::Directory) { + add_to_channel(source.clone(), tx.clone(), tr.filepath.clone()) + } + tx.send(tr).await.unwrap(); } }); } @@ -139,7 +146,7 @@ pub(crate) async fn recursive_iter( let to_rtn_stream = stream! { while let Some(v) = to_rtn_rx.recv().await { - yield Ok(v) + yield v } }; diff --git a/tests/integration/io/test_list_files_s3_minio.py b/tests/integration/io/test_list_files_s3_minio.py index c5a902d5b2..7a1744c307 100644 --- a/tests/integration/io/test_list_files_s3_minio.py +++ b/tests/integration/io/test_list_files_s3_minio.py @@ -68,3 +68,18 @@ def test_single_file_directory_listing(minio_io_config, recursive): s3fs_result = s3fs_recursive_list(fs, path=f"s3://{bucket_name}/c/cc/ccc") assert len(daft_ls_result) == 1 compare_s3_result(daft_ls_result, s3fs_result) + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "recursive", + [False, True], +) +def test_missing_file_path(minio_io_config, recursive): + bucket_name = "bucket" + with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs: + files = ["a", "b/bb", "c/cc/ccc"] + for name in files: + fs.write_bytes(f"s3://{bucket_name}/{name}", b"") + with pytest.raises(FileNotFoundError, match=f"s3://{bucket_name}/c/cc/ddd"): + daft_ls_result = io_list(f"s3://{bucket_name}/c/cc/ddd", io_config=minio_io_config, recursive=recursive)