[BUG] Fixes to S3 Native Lister with correct Error propagation #1401

Merged · 3 commits · Sep 21, 2023

src/daft-io/src/lib.rs (2 changes: 1 addition & 1 deletion)

@@ -1,5 +1,5 @@
 #![feature(async_closure)]
-
+#![feature(let_chains)]
 mod azure_blob;
 mod google_cloud;
 mod http;
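
Presumably the `let_chains` feature gate is added for the `if let ... && matches!(...)` pattern introduced in object_io.rs below, which was nightly-only syntax at the time. A minimal standalone sketch of that pattern (illustrative values, not Daft code):

```rust
// Nightly-only sketch of a let-chain: an `if let` binding combined with an
// extra boolean condition in a single `if`, as the new recursive_iter code does.
#![feature(let_chains)]

fn main() {
    let value: Result<i32, String> = Ok(4);
    // Without let_chains this would need a nested `if` or a match guard.
    if let Ok(n) = value && n % 2 == 0 {
        println!("got an even value: {n}");
    }
}
```
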
src/daft-io/src/object_io.rs (27 changes: 19 additions & 8 deletions)

@@ -117,16 +117,27 @@ pub(crate) async fn recursive_iter(
     uri: &str,
 ) -> super::Result<BoxStream<super::Result<FileMetadata>>> {
     let (to_rtn_tx, mut to_rtn_rx) = tokio::sync::mpsc::channel(16 * 1024);
-    fn add_to_channel(source: Arc<dyn ObjectSource>, tx: Sender<FileMetadata>, dir: String) {
+    fn add_to_channel(
+        source: Arc<dyn ObjectSource>,
+        tx: Sender<super::Result<FileMetadata>>,
+        dir: String,
+    ) {
         tokio::spawn(async move {
-            let mut s = source.iter_dir(&dir, None, None).await.unwrap();
+            let s = source.iter_dir(&dir, None, None).await;
+            let mut s = match s {
+                Ok(s) => s,
+                Err(e) => {
+                    tx.send(Err(e)).await.unwrap();
+                    return;
+                }
+            };
             let tx = &tx;
             while let Some(tr) = s.next().await {
-                let tr = tr.unwrap();
-                match tr.filetype {
-                    FileType::File => tx.send(tr).await.unwrap(),
-                    FileType::Directory => add_to_channel(source.clone(), tx.clone(), tr.filepath),
-                };
+                let tr = tr;

Contributor: This line appears to be redundant now.

+                if let Ok(ref tr) = tr && matches!(tr.filetype, FileType::Directory) {
+                    add_to_channel(source.clone(), tx.clone(), tr.filepath.clone())
+                }
+                tx.send(tr).await.unwrap();
             }
         });
     }
@@ -135,7 +146,7 @@ pub(crate) async fn recursive_iter(

     let to_rtn_stream = stream! {
         while let Some(v) = to_rtn_rx.recv().await {
-            yield Ok(v)
+            yield v
         }
     };

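The object_io.rs change above is the heart of the error-propagation fix: instead of calling .unwrap() inside the spawned task, where a panic would just kill that task and the listing error would never reach the caller, the task now sends the Result itself through the channel and the stream yields it to the consumer of recursive_iter. A self-contained sketch of that pattern, with plain String items and std::io::Error standing in for Daft's FileMetadata and super::Result (the helper names are made up, and it assumes the tokio crate):

```rust
// Sketch: propagate errors through a tokio mpsc channel instead of
// unwrapping inside the spawned producer tasks.
use tokio::sync::mpsc;

// Stand-in for ObjectSource::iter_dir; fails for one directory.
async fn list_dir(dir: &str) -> Result<Vec<String>, std::io::Error> {
    if dir == "bad/" {
        Err(std::io::Error::new(std::io::ErrorKind::Other, "listing failed"))
    } else {
        Ok(vec![format!("{dir}file1"), format!("{dir}file2")])
    }
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<Result<String, std::io::Error>>(16);

    for dir in ["good/", "bad/"] {
        let tx = tx.clone();
        tokio::spawn(async move {
            match list_dir(dir).await {
                // Forward every entry as Ok(..); the receiver may already be
                // gone, so ignore send failures rather than unwrapping.
                Ok(entries) => {
                    for entry in entries {
                        let _ = tx.send(Ok(entry)).await;
                    }
                }
                // Forward the error itself instead of panicking the task.
                Err(e) => {
                    let _ = tx.send(Err(e)).await;
                }
            }
        });
    }
    // Drop the original sender so the channel closes once the tasks finish.
    drop(tx);

    while let Some(item) = rx.recv().await {
        match item {
            Ok(path) => println!("found {path}"),
            Err(e) => eprintln!("listing error: {e}"),
        }
    }
}
```
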
src/daft-io/src/s3_like.rs (16 changes: 9 additions & 7 deletions)

@@ -690,7 +690,6 @@ impl ObjectSource for S3LikeSource {
     ) -> super::Result<LSResult> {
         let parsed = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?;
         let delimiter = delimiter.unwrap_or("/");
-        log::warn!("{:?}", parsed);

         let bucket = match parsed.host_str() {
             Some(s) => Ok(s),

@@ -701,22 +700,25 @@ impl ObjectSource for S3LikeSource {
         }?;
         let key = parsed.path();
 
-        let key = key
-            .strip_prefix('/')
-            .map(|k| k.strip_suffix('/').unwrap_or(k));
-        let key = key.unwrap_or("");
-
+        let key = key.strip_prefix('/').unwrap_or("");

Contributor: Hmm, this seems like it will coerce the key to an empty string if it doesn't have a "/" prefix (i.e. if it isn't a base URL); for example, it appears that "foo/bar" would be coerced to "". Do we have a guarantee that all provided URLs will contain base URLs when URL-parsed? Shouldn't this be .unwrap_or(key) just to be safe?

Member Author: The UrlParse library leaves the prefix when you parse the URL, so "foo/bar" is hostname="foo", key="/bar". So if there's no prefix, it means the key is an empty string.
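
A quick standalone check of that behaviour with the url crate (the same url::Url::parse call the method uses); the URL here is illustrative:

```rust
// Sketch: url::Url::parse keeps the leading '/' on the path, so the bucket
// ends up in host_str() and the key is path() minus its leading slash.
fn main() {
    let parsed = url::Url::parse("s3://foo/bar/baz").unwrap();
    assert_eq!(parsed.host_str(), Some("foo"));
    assert_eq!(parsed.path(), "/bar/baz");

    // Stripping the leading '/' recovers the object key.
    let key = parsed.path().strip_prefix('/').unwrap_or("");
    assert_eq!(key, "bar/baz");
    println!("bucket = foo, key = {key}");
}
```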

+        let key = if key.is_empty() {
+            "".to_string()
+        } else {
+            let key = key.strip_suffix('/').unwrap_or(key);
+            format!("{key}/")
+        };

Contributor: Seems like this might be more simply expressed as:

    let key = key.strip_prefix('/').unwrap_or(key);
    let key = if !key.ends_with('/') { format!("{key}/") } else { key };

Member Author: What I wanted to take care of here is to drop trailing slashes but ensure that there's at least one. Looks like the right call was actually trim_** in https://github.com/Eventual-Inc/Daft/pull/1404/files
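
To make the behaviour under discussion concrete, here is a self-contained sketch of the normalization the new code performs; the normalize_key name is invented for illustration, and the trim-based rewrite mentioned above lives in #1404 and is not reproduced here:

```rust
// Sketch mirroring the diff above: drop the leading '/', collapse a bare
// bucket to "", and add a trailing '/' to non-empty keys (dropping a single
// existing trailing slash first).
fn normalize_key(path: &str) -> String {
    let key = path.strip_prefix('/').unwrap_or("");
    if key.is_empty() {
        String::new()
    } else {
        let key = key.strip_suffix('/').unwrap_or(key);
        format!("{key}/")
    }
}

fn main() {
    assert_eq!(normalize_key(""), "");
    assert_eq!(normalize_key("/"), "");
    assert_eq!(normalize_key("/foo/bar"), "foo/bar/");
    assert_eq!(normalize_key("/foo/bar/"), "foo/bar/");
    // The coercion discussed above: no leading '/' collapses to "", which is
    // safe here because url::Url::parse yields a leading '/' for these URLs.
    assert_eq!(normalize_key("foo/bar"), "");
    println!("all assertions hold");
}
```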

+        // assume its a directory first
         let lsr = {
             let permit = self
                 .connection_pool_sema
                 .acquire()
                 .await
                 .context(UnableToGrabSemaphoreSnafu)?;
 
             self._list_impl(
                 permit,
                 bucket,
-                key,
+                &key,
                 delimiter.into(),
                 continuation_token.map(String::from),
                 &self.default_region,

tests/integration/io/conftest.py (2 changes: 2 additions & 0 deletions)

@@ -91,6 +91,8 @@ def minio_create_bucket(
         password=minio_io_config.s3.access_key,
         client_kwargs={"endpoint_url": minio_io_config.s3.endpoint_url},
     )
+    if fs.exists(bucket_name):
+        fs.rm(bucket_name, recursive=True)
     fs.mkdir(bucket_name)
     try:
         yield fs

tests/integration/io/test_list_files_s3_minio.py (62 changes: 61 additions & 1 deletion)

@@ -9,10 +9,24 @@

 def compare_s3_result(daft_ls_result: list, s3fs_result: list):
     daft_files = [(f["path"], f["type"].lower()) for f in daft_ls_result]
-    s3fs_files = [(f"s3://{f['Key']}", f["type"]) for f in s3fs_result]
+    s3fs_files = [(f"s3://{f['name']}", f["type"]) for f in s3fs_result]
     assert sorted(daft_files) == sorted(s3fs_files)
 
 
+def s3fs_recursive_list(fs, path) -> list:
+    all_results = []
+    curr_level_result = fs.ls(path, detail=True)
+    for item in curr_level_result:
+        if item["type"] == "directory":
+            new_path = f's3://{item["name"]}'
+            all_results.extend(s3fs_recursive_list(fs, new_path))
+            item["name"] += "/"
+            all_results.append(item)
+        else:
+            all_results.append(item)
+    return all_results
+
+
 @pytest.mark.integration()
 def test_flat_directory_listing(minio_io_config):
     bucket_name = "bucket"
@@ -23,3 +37,49 @@ def test_flat_directory_listing(minio_io_config):
         daft_ls_result = io_list(f"s3://{bucket_name}", io_config=minio_io_config)
         s3fs_result = fs.ls(f"s3://{bucket_name}", detail=True)
         compare_s3_result(daft_ls_result, s3fs_result)


+@pytest.mark.integration()
+def test_recursive_directory_listing(minio_io_config):
+    bucket_name = "bucket"
+    with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs:
+        files = ["a", "b/bb", "c/cc/ccc"]
+        for name in files:
+            fs.write_bytes(f"s3://{bucket_name}/{name}", b"")
+        daft_ls_result = io_list(f"s3://{bucket_name}/", io_config=minio_io_config, recursive=True)
+        fs.invalidate_cache()
+        s3fs_result = s3fs_recursive_list(fs, path=f"s3://{bucket_name}")
+        compare_s3_result(daft_ls_result, s3fs_result)
+
+
+@pytest.mark.integration()
+@pytest.mark.parametrize(
+    "recursive",
+    [False, True],
+)
+def test_single_file_directory_listing(minio_io_config, recursive):
+    bucket_name = "bucket"
+    with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs:
+        files = ["a", "b/bb", "c/cc/ccc"]
+        for name in files:
+            fs.write_bytes(f"s3://{bucket_name}/{name}", b"")
+        daft_ls_result = io_list(f"s3://{bucket_name}/c/cc/ccc", io_config=minio_io_config, recursive=recursive)
+        fs.invalidate_cache()
+        s3fs_result = s3fs_recursive_list(fs, path=f"s3://{bucket_name}/c/cc/ccc")
+        assert len(daft_ls_result) == 1
+        compare_s3_result(daft_ls_result, s3fs_result)
+
+
+@pytest.mark.integration()
+@pytest.mark.parametrize(
+    "recursive",
+    [False, True],
+)
+def test_missing_file_path(minio_io_config, recursive):
+    bucket_name = "bucket"
+    with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs:
+        files = ["a", "b/bb", "c/cc/ccc"]
+        for name in files:
+            fs.write_bytes(f"s3://{bucket_name}/{name}", b"")
+        with pytest.raises(FileNotFoundError, match=f"s3://{bucket_name}/c/cc/ddd"):
+            daft_ls_result = io_list(f"s3://{bucket_name}/c/cc/ddd", io_config=minio_io_config, recursive=recursive)