[BUG] Fixes to S3 Native Lister with correct Error propagation (#1401)
* Fixes an infinite loop in S3 listing
* Corrects error propagation in the recursive file lister
* Returns directories in the recursive file lister
* Fixes a bug when an exact file path is passed in
* Adds tests covering the above cases
* Adds an s3fs recursive-lister test utility
samster25 authored Sep 21, 2023
1 parent eccd030 commit 2bb2aa4
Showing 5 changed files with 92 additions and 17 deletions.
src/daft-io/src/lib.rs (1 addition & 1 deletion)
@@ -1,5 +1,5 @@
 #![feature(async_closure)]
-
+#![feature(let_chains)]
 mod azure_blob;
 mod google_cloud;
 mod http;
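The new `let_chains` gate exists for the rewritten lister loop in object_io.rs below, which chains a pattern match and a boolean test in a single `if`. A minimal, nightly-only sketch of that syntax, where the `FileType` enum is a stand-in for Daft's own:

// Nightly-only sketch of the `let_chains` syntax the gate above enables;
// `FileType` here is a local stand-in, not Daft's type.
#![feature(let_chains)]

#[derive(Debug)]
enum FileType {
    File,
    Directory,
}

fn main() {
    let tr: Result<FileType, String> = Ok(FileType::Directory);
    // One chained condition instead of a nested `if let` plus `if`:
    if let Ok(ref t) = tr && matches!(t, FileType::Directory) {
        println!("would recurse into {t:?}");
    }
    let _ = FileType::File; // keep the sketch free of dead-code warnings
}

Without the feature, the same logic needs a nested `if let` inside the loop body.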
src/daft-io/src/object_io.rs (19 additions & 8 deletions)
@@ -117,16 +117,27 @@ pub(crate) async fn recursive_iter(
     uri: &str,
 ) -> super::Result<BoxStream<super::Result<FileMetadata>>> {
     let (to_rtn_tx, mut to_rtn_rx) = tokio::sync::mpsc::channel(16 * 1024);
-    fn add_to_channel(source: Arc<dyn ObjectSource>, tx: Sender<FileMetadata>, dir: String) {
+    fn add_to_channel(
+        source: Arc<dyn ObjectSource>,
+        tx: Sender<super::Result<FileMetadata>>,
+        dir: String,
+    ) {
         tokio::spawn(async move {
-            let mut s = source.iter_dir(&dir, None, None).await.unwrap();
+            let s = source.iter_dir(&dir, None, None).await;
+            let mut s = match s {
+                Ok(s) => s,
+                Err(e) => {
+                    tx.send(Err(e)).await.unwrap();
+                    return;
+                }
+            };
             let tx = &tx;
             while let Some(tr) = s.next().await {
-                let tr = tr.unwrap();
-                match tr.filetype {
-                    FileType::File => tx.send(tr).await.unwrap(),
-                    FileType::Directory => add_to_channel(source.clone(), tx.clone(), tr.filepath),
-                };
+                let tr = tr;
+                if let Ok(ref tr) = tr && matches!(tr.filetype, FileType::Directory) {
+                    add_to_channel(source.clone(), tx.clone(), tr.filepath.clone())
+                }
+                tx.send(tr).await.unwrap();
             }
         });
     }
@@ -135,7 +146,7 @@ pub(crate) async fn recursive_iter(
 
     let to_rtn_stream = stream! {
         while let Some(v) = to_rtn_rx.recv().await {
-            yield Ok(v)
+            yield v
         }
     };
 
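To see the shape of the fix in isolation, here is a self-contained sketch of the pattern the hunk adopts, with a hypothetical in-memory `Entry`/`list_dir` standing in for Daft's `FileMetadata` and `ObjectSource::iter_dir`: every directory level runs as a spawned task that forwards `Result` values over an mpsc channel, so a failure at any depth reaches the consuming stream instead of panicking inside a detached task.

// Self-contained sketch (assumed names, not Daft's API) of channel-based
// error propagation in a recursive lister.
use tokio::sync::mpsc::{channel, Sender};

#[derive(Debug)]
enum Entry {
    File(String),
    Dir(String),
}

// Hypothetical lister; fails on one subtree to exercise error propagation.
fn list_dir(dir: &str) -> Result<Vec<Entry>, String> {
    match dir {
        "root" => Ok(vec![
            Entry::File("root/a".into()),
            Entry::Dir("root/sub".into()),
        ]),
        "root/sub" => Err(format!("permission denied listing {dir}")),
        _ => Ok(vec![]),
    }
}

fn add_to_channel(tx: Sender<Result<Entry, String>>, dir: String) {
    tokio::spawn(async move {
        let entries = match list_dir(&dir) {
            Ok(entries) => entries,
            Err(e) => {
                // Send the failure to the consumer instead of unwrapping here.
                let _ = tx.send(Err(e)).await;
                return;
            }
        };
        for entry in entries {
            if let Entry::Dir(path) = &entry {
                // Recurse by spawning a sibling task with a cloned sender,
                // while still emitting the directory itself below.
                add_to_channel(tx.clone(), path.clone());
            }
            let _ = tx.send(Ok(entry)).await;
        }
        // This task's sender drops here; the channel closes once all clones do.
    });
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = channel(16);
    add_to_channel(tx, "root".to_string());
    while let Some(item) = rx.recv().await {
        match item {
            Ok(entry) => println!("listed: {entry:?}"),
            Err(e) => println!("error surfaced to consumer: {e}"),
        }
    }
}

Keeping `add_to_channel` a plain (non-async) function that only spawns tasks makes the recursion legal without boxed futures, and the channel closes once every cloned sender is dropped, so the consuming stream terminates naturally.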
src/daft-io/src/s3_like.rs (9 additions & 7 deletions)
@@ -690,7 +690,6 @@ impl ObjectSource for S3LikeSource {
     ) -> super::Result<LSResult> {
         let parsed = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?;
         let delimiter = delimiter.unwrap_or("/");
-        log::warn!("{:?}", parsed);
 
         let bucket = match parsed.host_str() {
             Some(s) => Ok(s),
@@ -701,22 +700,25 @@ impl ObjectSource for S3LikeSource {
         }?;
         let key = parsed.path();
 
-        let key = key
-            .strip_prefix('/')
-            .map(|k| k.strip_suffix('/').unwrap_or(k));
-        let key = key.unwrap_or("");
-
+        let key = key.strip_prefix('/').unwrap_or("");
+        let key = if key.is_empty() {
+            "".to_string()
+        } else {
+            let key = key.strip_suffix('/').unwrap_or(key);
+            format!("{key}/")
+        };
+        // assume it's a directory first
         let lsr = {
             let permit = self
                 .connection_pool_sema
                 .acquire()
                 .await
                 .context(UnableToGrabSemaphoreSnafu)?;
 
             self._list_impl(
                 permit,
                 bucket,
-                key,
+                &key,
                 delimiter.into(),
                 continuation_token.map(String::from),
                 &self.default_region,
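The key handling reads well in isolation; here is a standalone sketch of the normalization this hunk introduces, assuming (as in `ls`) that the input is the path component of a parsed s3:// URL, which always starts with '/' when present:

// Standalone sketch (not Daft's API) of the directory-first key normalization.
fn normalize_key(path: &str) -> String {
    let key = path.strip_prefix('/').unwrap_or("");
    if key.is_empty() {
        // Bucket root: list with an empty prefix.
        "".to_string()
    } else {
        // Collapse "a/b" and "a/b/" into a single directory-style "a/b/".
        let key = key.strip_suffix('/').unwrap_or(key);
        format!("{key}/")
    }
}

fn main() {
    assert_eq!(normalize_key(""), "");            // s3://bucket
    assert_eq!(normalize_key("/"), "");           // s3://bucket/
    assert_eq!(normalize_key("/data"), "data/");  // tried as a directory
    assert_eq!(normalize_key("/data/"), "data/");
    assert_eq!(normalize_key("/c/cc/ccc"), "c/cc/ccc/"); // exact file, still directory-first
    println!("all normalization cases hold");
}

Per the added comment, the directory-style key is tried first even for exact-file inputs; the commit message's exact-file fix implies a fallback to a plain-key lookup when that listing comes back empty, though the fallback is outside this hunk.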
tests/integration/io/conftest.py (2 additions)
@@ -91,6 +91,8 @@ def minio_create_bucket(
         password=minio_io_config.s3.access_key,
         client_kwargs={"endpoint_url": minio_io_config.s3.endpoint_url},
     )
+    if fs.exists(bucket_name):
+        fs.rm(bucket_name, recursive=True)
     fs.mkdir(bucket_name)
     try:
         yield fs
tests/integration/io/test_list_files_s3_minio.py (61 additions & 1 deletion)
@@ -9,10 +9,24 @@
 
 def compare_s3_result(daft_ls_result: list, s3fs_result: list):
     daft_files = [(f["path"], f["type"].lower()) for f in daft_ls_result]
-    s3fs_files = [(f"s3://{f['Key']}", f["type"]) for f in s3fs_result]
+    s3fs_files = [(f"s3://{f['name']}", f["type"]) for f in s3fs_result]
     assert sorted(daft_files) == sorted(s3fs_files)
 
 
+def s3fs_recursive_list(fs, path) -> list:
+    all_results = []
+    curr_level_result = fs.ls(path, detail=True)
+    for item in curr_level_result:
+        if item["type"] == "directory":
+            new_path = f's3://{item["name"]}'
+            all_results.extend(s3fs_recursive_list(fs, new_path))
+            item["name"] += "/"
+            all_results.append(item)
+        else:
+            all_results.append(item)
+    return all_results
+
+
 @pytest.mark.integration()
 def test_flat_directory_listing(minio_io_config):
     bucket_name = "bucket"
@@ -23,3 +37,49 @@ def test_flat_directory_listing(minio_io_config):
     daft_ls_result = io_list(f"s3://{bucket_name}", io_config=minio_io_config)
     s3fs_result = fs.ls(f"s3://{bucket_name}", detail=True)
     compare_s3_result(daft_ls_result, s3fs_result)
+
+
+@pytest.mark.integration()
+def test_recursive_directory_listing(minio_io_config):
+    bucket_name = "bucket"
+    with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs:
+        files = ["a", "b/bb", "c/cc/ccc"]
+        for name in files:
+            fs.write_bytes(f"s3://{bucket_name}/{name}", b"")
+        daft_ls_result = io_list(f"s3://{bucket_name}/", io_config=minio_io_config, recursive=True)
+        fs.invalidate_cache()
+        s3fs_result = s3fs_recursive_list(fs, path=f"s3://{bucket_name}")
+        compare_s3_result(daft_ls_result, s3fs_result)
+
+
+@pytest.mark.integration()
+@pytest.mark.parametrize(
+    "recursive",
+    [False, True],
+)
+def test_single_file_directory_listing(minio_io_config, recursive):
+    bucket_name = "bucket"
+    with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs:
+        files = ["a", "b/bb", "c/cc/ccc"]
+        for name in files:
+            fs.write_bytes(f"s3://{bucket_name}/{name}", b"")
+        daft_ls_result = io_list(f"s3://{bucket_name}/c/cc/ccc", io_config=minio_io_config, recursive=recursive)
+        fs.invalidate_cache()
+        s3fs_result = s3fs_recursive_list(fs, path=f"s3://{bucket_name}/c/cc/ccc")
+        assert len(daft_ls_result) == 1
+        compare_s3_result(daft_ls_result, s3fs_result)
+
+
+@pytest.mark.integration()
+@pytest.mark.parametrize(
+    "recursive",
+    [False, True],
+)
+def test_missing_file_path(minio_io_config, recursive):
+    bucket_name = "bucket"
+    with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs:
+        files = ["a", "b/bb", "c/cc/ccc"]
+        for name in files:
+            fs.write_bytes(f"s3://{bucket_name}/{name}", b"")
+        with pytest.raises(FileNotFoundError, match=f"s3://{bucket_name}/c/cc/ddd"):
+            daft_ls_result = io_list(f"s3://{bucket_name}/c/cc/ddd", io_config=minio_io_config, recursive=recursive)
