Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Native S3 Lister, support trailing slashes and fix panics when connection is dropped for tokio #1404

Merged
merged 2 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/daft-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ pub enum Error {
#[snafu(display("Unhandled Error for path: {}\nDetails:\n{}", path, msg))]
Unhandled { path: String, msg: String },

#[snafu(
display("Error sending data over a tokio channel: {}", source),
context(false)
)]
UnableToSendDataOverChannel { source: DynError },

#[snafu(display("Error joining spawned task: {}", source), context(false))]
JoinError { source: tokio::task::JoinError },
}
Expand Down
12 changes: 8 additions & 4 deletions src/daft-io/src/object_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,22 @@ pub(crate) async fn recursive_iter(
let mut s = match s {
Ok(s) => s,
Err(e) => {
tx.send(Err(e)).await.unwrap();
return;
tx.send(Err(e)).await.map_err(|se| {
super::Error::UnableToSendDataOverChannel { source: se.into() }
})?;
return super::Result::<_, super::Error>::Ok(());
}
};
let tx = &tx;
while let Some(tr) = s.next().await {
let tr = tr;
if let Ok(ref tr) = tr && matches!(tr.filetype, FileType::Directory) {
add_to_channel(source.clone(), tx.clone(), tr.filepath.clone())
}
tx.send(tr).await.unwrap();
tx.send(tr)
.await
.map_err(|e| super::Error::UnableToSendDataOverChannel { source: e.into() })?;
}
super::Result::Ok(())
});
}

Expand Down
18 changes: 9 additions & 9 deletions src/daft-io/src/s3_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -699,14 +699,15 @@ impl ObjectSource for S3LikeSource {
}),
}?;
let key = parsed.path();

let key = key.strip_prefix('/').unwrap_or("");
let key = key
.trim_start_matches(delimiter)
.trim_end_matches(delimiter);
let key = if key.is_empty() {
"".to_string()
} else {
let key = key.strip_suffix('/').unwrap_or(key);
format!("{key}/")
format!("{key}{delimiter}")
};

// assume its a directory first
let lsr = {
let permit = self
Expand All @@ -725,26 +726,25 @@ impl ObjectSource for S3LikeSource {
)
.await?
};
if lsr.files.is_empty() && key.contains('/') {
if lsr.files.is_empty() && key.contains(delimiter) {
let permit = self
.connection_pool_sema
.acquire()
.await
.context(UnableToGrabSemaphoreSnafu)?;
// Might be a File
let split = key.rsplit_once('/');
let (new_key, _) = split.unwrap();
let key = key.trim_end_matches(delimiter);
let mut lsr = self
._list_impl(
permit,
bucket,
new_key,
key,
delimiter.into(),
continuation_token.map(String::from),
&self.default_region,
)
.await?;
let target_path = format!("s3://{bucket}/{new_key}");
let target_path = format!("s3://{bucket}/{key}");
lsr.files.retain(|f| f.filepath == target_path);

if lsr.files.is_empty() {
Expand Down
18 changes: 18 additions & 0 deletions tests/integration/io/test_list_files_s3_minio.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,24 @@ def test_single_file_directory_listing(minio_io_config, recursive):
compare_s3_result(daft_ls_result, s3fs_result)


@pytest.mark.integration()
@pytest.mark.parametrize(
    "recursive",
    [False, True],
)
def test_single_file_directory_listing_trailing(minio_io_config, recursive):
    """Listing a directory path with redundant trailing slashes should still
    resolve to that directory's contents and agree with s3fs's listing."""
    bucket = "bucket"
    with minio_create_bucket(minio_io_config, bucket_name=bucket) as fs:
        # Seed objects at several depths; only "c/cc/ccc" lives under the
        # directory we list below, so exactly one entry is expected.
        for key in ("a", "b/bb", "c/cc/ccc"):
            fs.write_bytes(f"s3://{bucket}/{key}", b"")
        target = f"s3://{bucket}/c/cc///"
        daft_ls_result = io_list(target, io_config=minio_io_config, recursive=recursive)
        # Drop s3fs's cached view so its listing reflects the writes above.
        fs.invalidate_cache()
        s3fs_result = s3fs_recursive_list(fs, path=target)
        assert len(daft_ls_result) == 1
        compare_s3_result(daft_ls_result, s3fs_result)


@pytest.mark.integration()
@pytest.mark.parametrize(
"recursive",
Expand Down
Loading