diff --git a/src/daft-io/src/azure_blob.rs b/src/daft-io/src/azure_blob.rs index e63d77e6c1..cf4752fe97 100644 --- a/src/daft-io/src/azure_blob.rs +++ b/src/daft-io/src/azure_blob.rs @@ -14,6 +14,9 @@ use crate::{ }; use common_io_config::AzureConfig; +const AZURE_DELIMITER: &str = "/"; +const DEFAULT_GLOB_FANOUT_LIMIT: usize = 1024; + #[derive(Debug, Snafu)] enum Error { // Input errors. @@ -170,7 +173,7 @@ impl AzureBlobSource { protocol: &str, container_name: &str, prefix: &str, - delimiter: &str, + posix: bool, ) -> BoxStream> { let container_client = self.blob_client.container_client(container_name); @@ -178,7 +181,6 @@ impl AzureBlobSource { let protocol = protocol.to_string(); let container_name = container_name.to_string(); let prefix = prefix.to_string(); - let delimiter = delimiter.to_string(); // Blob stores expose listing by prefix and delimiter, // but this is not the exact same as a unix-like LS behaviour @@ -186,7 +188,10 @@ impl AzureBlobSource { // To use prefix listing as LS, we need to ensure the path given is exactly a directory or a file, not a prefix. // It turns out Azure list_blobs("path/") will match both a file at "path" and a folder at "path/", which is exactly what we need. - let prefix_with_delimiter = format!("{}{delimiter}", prefix.trim_end_matches(&delimiter)); + let prefix_with_delimiter = format!( + "{}{AZURE_DELIMITER}", + prefix.trim_end_matches(&AZURE_DELIMITER) + ); let full_path = format!("{}://{}{}", protocol, container_name, prefix); let full_path_with_trailing_delimiter = format!( "{}://{}{}", @@ -199,7 +204,7 @@ impl AzureBlobSource { &protocol, &container_name, &prefix_with_delimiter, - &delimiter, + &posix, ) .await; @@ -241,15 +246,15 @@ impl AzureBlobSource { // To check whether the prefix actually exists, check whether it exists as a result one directory above. // (Azure does not return marker files for empty directories.) let upper_dir = prefix // "/upper/blah/" - .trim_end_matches(&delimiter) // "/upper/blah" - .trim_end_matches(|c: char| c.to_string() != delimiter); // "/upper/" + .trim_end_matches(&AZURE_DELIMITER) // "/upper/blah" + .trim_end_matches(|c: char| c.to_string() != AZURE_DELIMITER); // "/upper/" let upper_results_stream = self._list_directory_delimiter_stream( &container_client, &protocol, &container_name, upper_dir, - &delimiter, + &posix, ).await; // At this point, we have a stream of Result. @@ -296,7 +301,7 @@ impl AzureBlobSource { protocol: &str, container_name: &str, prefix: &str, - delimiter: &str, + posix: &bool, ) -> BoxStream> { // Calls Azure list_blobs with the prefix // and returns the result flattened and standardized into FileMetadata. @@ -307,11 +312,14 @@ impl AzureBlobSource { let prefix = prefix.to_string(); // Paginated response stream from Azure API. - let responses_stream = container_client - .list_blobs() - .delimiter(delimiter.to_string()) - .prefix(prefix.clone()) - .into_stream(); + let mut responses_stream = container_client.list_blobs().prefix(prefix.clone()); + + // Setting delimiter will trigger "directory-mode" which is a posix-like ls for the current directory + if *posix { + responses_stream = responses_stream.delimiter(AZURE_DELIMITER.to_string()); + } + + let responses_stream = responses_stream.into_stream(); // Map each page of results to a page of standardized FileMetadata. 
responses_stream @@ -373,6 +381,10 @@ impl AzureBlobSource { #[async_trait] impl ObjectSource for AzureBlobSource { + fn delimiter(&self) -> &'static str { + AZURE_DELIMITER + } + async fn get(&self, uri: &str, range: Option>) -> super::Result { let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; let container = match parsed.host_str() { @@ -427,20 +439,28 @@ impl ObjectSource for AzureBlobSource { Ok(metadata.blob.properties.content_length as usize) } + async fn glob( + self: Arc, + glob_path: &str, + fanout_limit: Option, + page_size: Option, + ) -> super::Result>> { + use crate::object_store_glob::glob; + + // Ensure fanout_limit is not None to prevent runaway concurrency + let fanout_limit = fanout_limit.or(Some(DEFAULT_GLOB_FANOUT_LIMIT)); + + glob(self, glob_path, fanout_limit, page_size.or(Some(1000))).await + } + async fn iter_dir( &self, uri: &str, - delimiter: &str, posix: bool, _page_size: Option, - _limit: Option, ) -> super::Result>> { let uri = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; - if !posix { - todo!("Prefix-listing is not yet implemented for Azure"); - } - // path can be root (buckets) or path prefix within a bucket. let container = { // "Container" is Azure's name for Bucket. @@ -469,7 +489,7 @@ impl ObjectSource for AzureBlobSource { Some(container_name) => { let prefix = uri.path(); Ok(self - .list_directory_stream(protocol, container_name, prefix, delimiter) + .list_directory_stream(protocol, container_name, prefix, posix) .await) } } @@ -478,10 +498,9 @@ impl ObjectSource for AzureBlobSource { async fn ls( &self, path: &str, - delimiter: &str, posix: bool, continuation_token: Option<&str>, - page_size: Option, + _page_size: Option, ) -> super::Result { // It looks like the azure rust library API // does not currently allow using the continuation token: @@ -496,7 +515,7 @@ impl ObjectSource for AzureBlobSource { }?; let files = self - .iter_dir(path, delimiter, posix, page_size, None) + .iter_dir(path, posix, None) .await? .try_collect::>() .await?; diff --git a/src/daft-io/src/glob.rs b/src/daft-io/src/glob.rs deleted file mode 100644 index 823af44d5d..0000000000 --- a/src/daft-io/src/glob.rs +++ /dev/null @@ -1,196 +0,0 @@ -use itertools::Itertools; -use std::{collections::HashSet, sync::Arc}; - -use globset::GlobMatcher; -use lazy_static::lazy_static; - -lazy_static! 
{ - /// Check if a given char is considered a special glob character - /// NOTE: we use the `globset` crate which defines the following glob behavior: - /// https://docs.rs/globset/latest/globset/index.html#syntax - static ref GLOB_SPECIAL_CHARACTERS: HashSet = HashSet::from(['*', '?', '{', '}', '[', ']']); -} - -const SCHEME_SUFFIX_LEN: usize = "://".len(); - -#[derive(Clone)] -pub(crate) struct GlobState { - // Current path in dirtree and glob_fragments - pub current_path: String, - pub current_fragment_idx: usize, - - // How large of a fanout this level of iteration is currently experiencing - pub current_fanout: usize, - - // Whether we have encountered wildcards yet in the process of parsing - pub wildcard_mode: bool, - - // Carry along expensive data as Arcs to avoid recomputation - pub glob_fragments: Arc>, - pub full_glob_matcher: Arc, - pub fanout_limit: usize, - pub page_size: Option, -} - -impl GlobState { - pub fn current_glob_fragment(&self) -> &GlobFragment { - &self.glob_fragments[self.current_fragment_idx] - } - - pub fn advance(self, path: String, idx: usize, fanout_factor: usize) -> Self { - GlobState { - current_path: path, - current_fragment_idx: idx, - current_fanout: self.current_fanout * fanout_factor, - ..self.clone() - } - } - - pub fn with_wildcard_mode(self) -> Self { - GlobState { - wildcard_mode: true, - ..self - } - } -} - -#[derive(Debug, Clone)] -pub(crate) struct GlobFragment { - data: String, - escaped_data: String, - first_wildcard_idx: Option, -} - -impl GlobFragment { - pub fn new(data: &str) -> Self { - let first_wildcard_idx = match data { - "" => None, - data if GLOB_SPECIAL_CHARACTERS.contains(&data.chars().nth(0).unwrap()) => Some(0), - _ => { - // Detect any special characters that are not preceded by an escape \ - let mut idx = None; - for (i, window) in data - .chars() - .collect::>() - .as_slice() - .windows(2) - .enumerate() - { - let &[c1, c2] = window else { - unreachable!("Window contains 2 elements") - }; - if (c1 != '\\') && GLOB_SPECIAL_CHARACTERS.contains(&c2) { - idx = Some(i + 1); - break; - } - } - idx - } - }; - - // Sanitize `data`: removing '\' and converting '\\' to '\' - let mut escaped_data = String::new(); - let mut ptr = 0; - while ptr < data.len() { - let remaining = &data[ptr..]; - match remaining.find(r"\\") { - Some(backslash_idx) => { - escaped_data.push_str(&remaining[..backslash_idx].replace('\\', "")); - escaped_data.extend(std::iter::once('\\')); - ptr += backslash_idx + 2; - } - None => { - escaped_data.push_str(&remaining.replace('\\', "")); - break; - } - } - } - - GlobFragment { - data: data.to_string(), - first_wildcard_idx, - escaped_data, - } - } - - /// Checks if this GlobFragment has any special characters - pub fn has_special_character(&self) -> bool { - self.first_wildcard_idx.is_some() - } - - /// Joins a slice of GlobFragments together with a separator - pub fn join(fragments: &[GlobFragment], sep: &str) -> Self { - GlobFragment::new( - fragments - .iter() - .map(|frag: &GlobFragment| frag.data.as_str()) - .join(sep) - .as_str(), - ) - } - - /// Returns the fragment as a string with the backslash (\) escapes applied - /// 1. \\ is cleaned up to just \ - /// 2. 
\ followed by anything else is just ignored - pub fn escaped_str(&self) -> &str { - self.escaped_data.as_str() - } - - /// Returns the GlobFragment as a raw unescaped string, suitable for use by the globset crate - pub fn raw_str(&self) -> &str { - self.data.as_str() - } -} - -/// Parses a glob URL string into "fragments" -/// Fragments are the glob URL string but: -/// 1. Split by delimiter ("/") -/// 2. Non-wildcard fragments are joined and coalesced by delimiter -/// 3. The first fragment is prefixed by "{scheme}://" -pub(crate) fn to_glob_fragments(glob_str: &str) -> super::Result> { - let delimiter = "/"; - - // NOTE: We only use the URL parse library to get the scheme, because it will escape some of our glob special characters - // such as ? and {} - let glob_url = url::Url::parse(glob_str).map_err(|e| super::Error::InvalidUrl { - path: glob_str.to_string(), - source: e, - })?; - let url_scheme = glob_url.scheme(); - - // Parse glob fragments: split by delimiter and join any non-special fragments - let mut coalesced_fragments = vec![]; - let mut nonspecial_fragments_so_far = vec![]; - for fragment in glob_str[url_scheme.len() + SCHEME_SUFFIX_LEN..] - .split(delimiter) - .map(GlobFragment::new) - { - match fragment { - fragment if fragment.data.is_empty() => (), - fragment if fragment.has_special_character() => { - if !nonspecial_fragments_so_far.is_empty() { - coalesced_fragments.push(GlobFragment::join( - nonspecial_fragments_so_far.drain(..).as_slice(), - delimiter, - )); - } - coalesced_fragments.push(fragment); - } - _ => { - nonspecial_fragments_so_far.push(fragment); - } - } - } - if !nonspecial_fragments_so_far.is_empty() { - coalesced_fragments.push(GlobFragment::join( - nonspecial_fragments_so_far.drain(..).as_slice(), - delimiter, - )); - } - - // Ensure that the first fragment has the scheme prefixed - coalesced_fragments[0] = - GlobFragment::new((format!("{url_scheme}://") + coalesced_fragments[0].raw_str()).as_str()); - - Ok(coalesced_fragments) -} diff --git a/src/daft-io/src/google_cloud.rs b/src/daft-io/src/google_cloud.rs index 15a612caa2..bc190512dd 100644 --- a/src/daft-io/src/google_cloud.rs +++ b/src/daft-io/src/google_cloud.rs @@ -1,6 +1,7 @@ use std::ops::Range; use std::sync::Arc; +use futures::stream::BoxStream; use futures::StreamExt; use futures::TryStreamExt; use google_cloud_storage::client::ClientConfig; @@ -23,6 +24,10 @@ use crate::s3_like; use crate::GetResult; use common_io_config::GCSConfig; +const GCS_DELIMITER: &str = "/"; +const GCS_SCHEME: &str = "gs"; +const DEFAULT_GLOB_FANOUT_LIMIT: usize = 1024; + #[derive(Debug, Snafu)] enum Error { #[snafu(display("Unable to open {}: {}", path, source))] @@ -117,7 +122,7 @@ fn parse_uri(uri: &url::Url) -> super::Result<(&str, &str)> { }), }?; let key = uri.path(); - let key = key.strip_prefix('/').unwrap_or(key); + let key = key.strip_prefix(GCS_DELIMITER).unwrap_or(key); Ok((bucket, key)) } @@ -195,7 +200,7 @@ impl GCSClientWrapper { client: &Client, bucket: &str, key: &str, - delimiter: &str, + delimiter: Option<&str>, continuation_token: Option<&str>, page_size: Option, ) -> super::Result { @@ -205,7 +210,7 @@ impl GCSClientWrapper { end_offset: None, start_offset: None, page_token: continuation_token.map(|s| s.to_string()), - delimiter: Some(delimiter.to_string()), // returns results in "directory mode" + delimiter: delimiter.map(|d| d.to_string()), // returns results in "directory mode" if delimiter is provided max_results: page_size, include_trailing_delimiter: Some(false), // This will not populate 
"directories" in the response's .item[] projection: None, @@ -215,17 +220,17 @@ impl GCSClientWrapper { .list_objects(&req) .await .context(UnableToListObjectsSnafu { - path: format!("gs://{}/{}", bucket, key), + path: format!("{GCS_SCHEME}://{}/{}", bucket, key), })?; let response_items = ls_response.items.unwrap_or_default(); let response_prefixes = ls_response.prefixes.unwrap_or_default(); let files = response_items.iter().map(|obj| FileMetadata { - filepath: format!("gs://{}/{}", bucket, obj.name), + filepath: format!("{GCS_SCHEME}://{}/{}", bucket, obj.name), size: Some(obj.size as u64), filetype: FileType::File, }); let dirs = response_prefixes.iter().map(|pref| FileMetadata { - filepath: format!("gs://{}/{}", bucket, pref), + filepath: format!("{GCS_SCHEME}://{}/{}", bucket, pref), size: None, filetype: FileType::Directory, }); @@ -238,7 +243,6 @@ impl GCSClientWrapper { async fn ls( &self, path: &str, - delimiter: &str, posix: bool, continuation_token: Option<&str>, page_size: Option, @@ -247,55 +251,68 @@ impl GCSClientWrapper { let (bucket, key) = parse_uri(&uri)?; match self { GCSClientWrapper::Native(client) => { - if !posix { - todo!("Prefix-listing is not yet implemented for GCS"); - } - - // Attempt to forcefully ls the key as a directory (by ensuring a "/" suffix) - let forced_directory_key = - format!("{}{delimiter}", key.trim_end_matches(delimiter)); - let forced_directory_ls_result = self - ._ls_impl( - client, - bucket, - forced_directory_key.as_str(), - delimiter, - continuation_token, - page_size, - ) - .await?; - - // If no items were obtained, then this is actually a file and we perform a second ls to obtain just the file's - // details as the one-and-only-one entry - if forced_directory_ls_result.files.is_empty() { - let file_result = self + if posix { + // Attempt to forcefully ls the key as a directory (by ensuring a "/" suffix) + let forced_directory_key = if key.is_empty() { + "".to_string() + } else { + format!("{}{GCS_DELIMITER}", key.trim_end_matches(GCS_DELIMITER)) + }; + let forced_directory_ls_result = self ._ls_impl( client, bucket, - key, - delimiter, + forced_directory_key.as_str(), + Some(GCS_DELIMITER), continuation_token, page_size, ) .await?; - // Not dir and not file, so it is missing - if file_result.files.is_empty() { - return Err(Error::NotFound { - path: path.to_string(), + // If no items were obtained, then this is actually a file and we perform a second ls to obtain just the file's + // details as the one-and-only-one entry + if forced_directory_ls_result.files.is_empty() { + let mut file_result = self + ._ls_impl( + client, + bucket, + key, + Some(GCS_DELIMITER), + continuation_token, + page_size, + ) + .await?; + + // Only retain exact matches (since the API does prefix lists by default) + let target_path = format!("{GCS_SCHEME}://{bucket}/{key}"); + file_result.files.retain(|fm| fm.filepath == target_path); + + // Not dir and not file, so it is missing + if file_result.files.is_empty() { + return Err(Error::NotFound { + path: path.to_string(), + } + .into()); } - .into()); - } - Ok(file_result) + Ok(file_result) + } else { + Ok(forced_directory_ls_result) + } } else { - Ok(forced_directory_ls_result) + self._ls_impl( + client, + bucket, + key, + None, // Force a prefix-listing + continuation_token, + page_size, + ) + .await } } GCSClientWrapper::S3Compat(client) => { - client - .ls(path, delimiter, posix, continuation_token, page_size) - .await + client.ls(path, posix, continuation_token, page_size).await } } } @@ -345,6 +362,10 @@ impl 
GCSSource { #[async_trait] impl ObjectSource for GCSSource { + fn delimiter(&self) -> &'static str { + GCS_DELIMITER + } + async fn get(&self, uri: &str, range: Option>) -> super::Result { self.client.get(uri, range).await } @@ -353,16 +374,29 @@ impl ObjectSource for GCSSource { self.client.get_size(uri).await } + async fn glob( + self: Arc, + glob_path: &str, + fanout_limit: Option, + page_size: Option, + ) -> super::Result>> { + use crate::object_store_glob::glob; + + // Ensure fanout_limit is not None to prevent runaway concurrency + let fanout_limit = fanout_limit.or(Some(DEFAULT_GLOB_FANOUT_LIMIT)); + + glob(self, glob_path, fanout_limit, page_size.or(Some(1000))).await + } + async fn ls( &self, path: &str, - delimiter: &str, posix: bool, continuation_token: Option<&str>, page_size: Option, ) -> super::Result { self.client - .ls(path, delimiter, posix, continuation_token, page_size) + .ls(path, posix, continuation_token, page_size) .await } } diff --git a/src/daft-io/src/http.rs b/src/daft-io/src/http.rs index d47008d89e..9bf80f9f6a 100644 --- a/src/daft-io/src/http.rs +++ b/src/daft-io/src/http.rs @@ -1,7 +1,7 @@ use std::{num::ParseIntError, ops::Range, string::FromUtf8Error, sync::Arc}; use async_trait::async_trait; -use futures::{StreamExt, TryStreamExt}; +use futures::{stream::BoxStream, StreamExt, TryStreamExt}; use lazy_static::lazy_static; use regex::Regex; @@ -13,6 +13,8 @@ use crate::object_io::{FileMetadata, FileType, LSResult}; use super::object_io::{GetResult, ObjectSource}; +const HTTP_DELIMITER: &str = "/"; + lazy_static! { // Taken from: https://stackoverflow.com/a/15926317/3821154 static ref HTML_A_TAG_HREF_RE: Regex = @@ -84,7 +86,7 @@ fn _get_file_metadata_from_html(path: &str, text: &str) -> super::Result super::Result (), }; - let filetype = if matched_url.ends_with('/') { + let filetype = if matched_url.ends_with(HTTP_DELIMITER) { FileType::Directory } else { FileType::File @@ -167,6 +169,10 @@ impl HttpSource { #[async_trait] impl ObjectSource for HttpSource { + fn delimiter(&self) -> &'static str { + HTTP_DELIMITER + } + async fn get(&self, uri: &str, range: Option>) -> super::Result { let request = self.client.get(uri); let request = match range { @@ -222,16 +228,30 @@ impl ObjectSource for HttpSource { } } + async fn glob( + self: Arc, + glob_path: &str, + _fanout_limit: Option, + _page_size: Option, + ) -> super::Result>> { + use crate::object_store_glob::glob; + + // Ensure fanout_limit is None because HTTP ObjectSource does not support prefix listing + let fanout_limit = None; + let page_size = None; + + glob(self, glob_path, fanout_limit, page_size).await + } + async fn ls( &self, path: &str, - _delimiter: &str, posix: bool, _continuation_token: Option<&str>, _page_size: Option, ) -> super::Result { if !posix { - todo!("Prefix-listing is not implemented for HTTP listing"); + unimplemented!("Prefix-listing is not implemented for HTTP listing"); } let request = self.client.get(path); @@ -245,8 +265,8 @@ impl ObjectSource for HttpSource { // Reconstruct the actual path of the request, which may have been redirected via a 301 // This is important because downstream URL joining logic relies on proper trailing-slashes/index.html let path = response.url().to_string(); - let path = if path.ends_with('/') { - format!("{}/", path.trim_end_matches('/')) + let path = if path.ends_with(HTTP_DELIMITER) { + format!("{}/", path.trim_end_matches(HTTP_DELIMITER)) } else { path }; diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 73cbed2f13..417313f047 
100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -1,11 +1,11 @@ #![feature(async_closure)] #![feature(let_chains)] mod azure_blob; -mod glob; mod google_cloud; mod http; mod local; mod object_io; +mod object_store_glob; mod s3_like; use azure_blob::AzureBlobSource; use google_cloud::GCSSource; diff --git a/src/daft-io/src/local.rs b/src/daft-io/src/local.rs index 3a039ed623..c3d617f5bb 100644 --- a/src/daft-io/src/local.rs +++ b/src/daft-io/src/local.rs @@ -16,6 +16,9 @@ use snafu::{ResultExt, Snafu}; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncSeekExt}; use url::ParseError; + +const PLATFORM_FS_DELIMITER: &str = std::path::MAIN_SEPARATOR_STR; + pub(crate) struct LocalSource {} #[derive(Debug, Snafu)] @@ -104,6 +107,10 @@ pub struct LocalFile { #[async_trait] impl ObjectSource for LocalSource { + fn delimiter(&self) -> &'static str { + PLATFORM_FS_DELIMITER + } + async fn get(&self, uri: &str, range: Option>) -> super::Result { const LOCAL_PROTOCOL: &str = "file://"; if let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) { @@ -129,17 +136,29 @@ impl ObjectSource for LocalSource { Ok(meta.len() as usize) } + async fn glob( + self: Arc, + glob_path: &str, + _fanout_limit: Option, + _page_size: Option, + ) -> super::Result>> { + use crate::object_store_glob::glob; + + // Ensure fanout_limit is None because Local ObjectSource does not support prefix listing + let fanout_limit = None; + let page_size = None; + + glob(self, glob_path, fanout_limit, page_size).await + } + async fn ls( &self, path: &str, - delimiter: &str, posix: bool, _continuation_token: Option<&str>, - page_size: Option, + _page_size: Option, ) -> super::Result { - let s = self - .iter_dir(path, delimiter, posix, page_size, None) - .await?; + let s = self.iter_dir(path, posix, None).await?; let files = s.try_collect::>().await?; Ok(LSResult { files, @@ -150,13 +169,11 @@ impl ObjectSource for LocalSource { async fn iter_dir( &self, uri: &str, - _delimiter: &str, posix: bool, _page_size: Option, - _limit: Option, ) -> super::Result>> { if !posix { - todo!("Prefix-listing is not implemented for local"); + unimplemented!("Prefix-listing is not implemented for local."); } const LOCAL_PROTOCOL: &str = "file://"; @@ -201,7 +218,7 @@ impl ObjectSource for LocalSource { "{}{}{}", LOCAL_PROTOCOL, entry.path().to_string_lossy(), - if meta.is_dir() { "/" } else { "" } + if meta.is_dir() { self.delimiter() } else { "" } ), size: Some(meta.len()), filetype: meta.file_type().try_into().with_context(|_| { @@ -334,7 +351,7 @@ mod tests { let dir_path = format!("file://{}", dir.path().to_string_lossy()); let client = LocalSource::get_client().await?; - let ls_result = client.ls(dir_path.as_ref(), "/", true, None, None).await?; + let ls_result = client.ls(dir_path.as_ref(), true, None, None).await?; let mut files = ls_result.files.clone(); // Ensure stable sort ordering of file paths before comparing with expected payload. 
files.sort_by(|a, b| a.filepath.cmp(&b.filepath)); diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs index fa2776ebcb..d7504b27ab 100644 --- a/src/daft-io/src/object_io.rs +++ b/src/daft-io/src/object_io.rs @@ -6,18 +6,10 @@ use bytes::Bytes; use common_error::DaftError; use futures::stream::{BoxStream, Stream}; use futures::StreamExt; -use globset::GlobBuilder; -use tokio::sync::mpsc::Sender; -use tokio::sync::OwnedSemaphorePermit; -use crate::glob::GlobState; -use crate::{ - glob::{to_glob_fragments, GlobFragment}, - local::{collect_file, LocalFile}, -}; +use tokio::sync::OwnedSemaphorePermit; -/// Default limit before we fallback onto parallel prefix list streams -static DEFAULT_FANOUT_LIMIT: usize = 1024; +use crate::local::{collect_file, LocalFile}; pub enum GetResult { File(LocalFile), @@ -102,16 +94,25 @@ use async_stream::stream; #[async_trait] pub(crate) trait ObjectSource: Sync + Send { + /// Returns the delimiter for the platform (S3 vs GCS vs Azure vs local-unix vs Windows etc) + fn delimiter(&self) -> &'static str; + async fn get(&self, uri: &str, range: Option>) -> super::Result; async fn get_range(&self, uri: &str, range: Range) -> super::Result { self.get(uri, Some(range)).await } async fn get_size(&self, uri: &str) -> super::Result; + async fn glob( + self: Arc, + glob_path: &str, + fanout_limit: Option, + page_size: Option, + ) -> super::Result>>; + async fn ls( &self, path: &str, - delimiter: &str, posix: bool, continuation_token: Option<&str>, page_size: Option, @@ -120,22 +121,19 @@ pub(crate) trait ObjectSource: Sync + Send { async fn iter_dir( &self, uri: &str, - delimiter: &str, posix: bool, page_size: Option, - _limit: Option, ) -> super::Result>> { let uri = uri.to_string(); - let delimiter = delimiter.to_string(); let s = stream! { - let lsr = self.ls(&uri, delimiter.as_str(), posix, None, page_size).await?; + let lsr = self.ls(&uri, posix, None, page_size).await?; for fm in lsr.files { yield Ok(fm); } let mut continuation_token = lsr.continuation_token.clone(); while continuation_token.is_some() { - let lsr = self.ls(&uri, delimiter.as_str(), posix, continuation_token.as_deref(), page_size).await?; + let lsr = self.ls(&uri, posix, continuation_token.as_deref(), page_size).await?; continuation_token = lsr.continuation_token.clone(); for fm in lsr.files { yield Ok(fm); @@ -145,383 +143,3 @@ pub(crate) trait ObjectSource: Sync + Send { Ok(s.boxed()) } } - -/// Helper method to iterate on a directory with the following behavior -/// -/// * First attempts to non-recursively list all Files and Directories under the current `uri` -/// * If during iteration we detect the number of Directories being returned exceeds `max_dirs`, we -/// fall back onto a prefix list of all Files with the current `uri` as the prefix -/// -/// Returns a tuple `(file_metadata_stream: BoxStream<...>, dir_count: usize)` where the second element -/// indicates the number of Directory entries contained within the stream -async fn ls_with_prefix_fallback( - source: Arc, - uri: &str, - delimiter: &str, - max_dirs: usize, - page_size: Option, -) -> (BoxStream<'static, super::Result>, usize) { - // Prefix list function that only returns Files - fn prefix_ls( - source: Arc, - path: String, - delimiter: String, - page_size: Option, - ) -> BoxStream<'static, super::Result> { - stream! 
{ - match source.iter_dir(&path, delimiter.as_str(), false, page_size, None).await { - Ok(mut result_stream) => { - while let Some(fm) = result_stream.next().await { - match fm { - Ok(fm) => { - if matches!(fm.filetype, FileType::File) - { - yield Ok(fm) - } - } - Err(e) => yield Err(e), - } - } - }, - Err(e) => yield Err(e), - } - } - .boxed() - } - - // Buffer results in memory as we go along - let mut results_buffer = vec![]; - let mut fm_stream = source - .iter_dir(uri, delimiter, true, page_size, None) - .await - .unwrap_or_else(|e| futures::stream::iter([Err(e)]).boxed()); - - // Iterate and collect results into the `results_buffer`, but terminate early if too many directories are found - let mut dir_count_so_far = 0; - while let Some(fm) = fm_stream.next().await { - if let Ok(fm) = &fm { - if matches!(fm.filetype, FileType::Directory) { - dir_count_so_far += 1; - // STOP EARLY!! - // If the number of directory results are more than `max_dirs`, we terminate the function early, - // throw away our results buffer and return a stream of FileType::File files using `prefix_ls` instead - if dir_count_so_far > max_dirs { - return ( - prefix_ls( - source.clone(), - uri.to_string(), - delimiter.to_string(), - page_size, - ), - 0, - ); - } - } - } - results_buffer.push(fm); - } - - // No early termination: we unwrap the results in our results buffer and yield data as a stream - let s = futures::stream::iter(results_buffer); - (s.boxed(), dir_count_so_far) -} - -pub(crate) async fn glob( - source: Arc, - glob: &str, - fanout_limit: Option, - page_size: Option, -) -> super::Result>> { - // If no special characters, we fall back to ls behavior - let full_fragment = GlobFragment::new(glob); - if !full_fragment.has_special_character() { - let glob = full_fragment.escaped_str().to_string(); - return Ok(stream! { - let mut results = source.iter_dir(glob.as_str(), "/", true, page_size, None).await?; - while let Some(val) = results.next().await { - match &val { - // Ignore non-File results - Ok(fm) if !matches!(fm.filetype, FileType::File) => continue, - _ => yield val, - } - } - } - .boxed()); - } - - // If user specifies a trailing / then we understand it as an attempt to list the folder(s) matched - // and append a trailing * fragment - let glob = if glob.ends_with('/') { - glob.to_string() + "*" - } else { - glob.to_string() - }; - let glob = glob.as_str(); - - let fanout_limit = fanout_limit.unwrap_or(DEFAULT_FANOUT_LIMIT); - let glob_fragments = to_glob_fragments(glob)?; - let full_glob_matcher = GlobBuilder::new(glob) - .literal_separator(true) - .backslash_escape(true) - .build() - .map_err(|err| super::Error::InvalidArgument { - msg: format!("Cannot parse provided glob {glob}: {err}"), - })? - .compile_matcher(); - - // Channel to send results back to caller. Note that all results must have FileType::File. - let (to_rtn_tx, mut to_rtn_rx) = tokio::sync::mpsc::channel(16 * 1024); - - /// Dispatches a task to visit the specified `path` (a concrete path on the filesystem to either a File or Directory). - /// Based on the current glob_fragment being processed (accessible via `glob_fragments[i]`) this task will: - /// 1. Perform work to retrieve Files/Directories at (`path` + `glob_fragments[i]`) - /// 2. Return results to the provided `result_tx` channel based on the provided glob, if appropriate - /// 3. 
Dispatch additional tasks via `.visit()` to continue visiting them, if appropriate - fn visit( - result_tx: Sender>, - source: Arc, - state: GlobState, - ) { - tokio::spawn(async move { - log::debug!(target: "glob", "Visiting '{}' with glob_fragments: {:?}", &state.current_path, &state.glob_fragments); - let current_fragment = state.current_glob_fragment(); - - // BASE CASE: current_fragment is a ** - // We perform a recursive ls and filter on the results for only FileType::File results that match the full glob - if current_fragment.escaped_str() == "**" { - let (mut results, stream_dir_count) = ls_with_prefix_fallback( - source.clone(), - &state.current_path, - "/", - state.fanout_limit / state.current_fanout, - state.page_size, - ) - .await; - - while let Some(val) = results.next().await { - match val { - Ok(fm) => { - match fm.filetype { - // Recursively visit each sub-directory - FileType::Directory => { - visit( - result_tx.clone(), - source.clone(), - // Do not increment `current_fragment_idx` so as to keep visiting the "**" fragmemt - state.clone().advance( - fm.filepath.clone(), - state.current_fragment_idx, - stream_dir_count, - ), - ); - } - // Return any Files that match - FileType::File - if state.full_glob_matcher.is_match(fm.filepath.as_str()) => - { - result_tx.send(Ok(fm)).await.expect("Internal multithreading channel is broken: results may be incorrect"); - } - _ => (), - } - } - // Silence NotFound errors when in wildcard "search" mode - Err(super::Error::NotFound { .. }) if state.wildcard_mode => (), - Err(e) => result_tx.send(Err(e)).await.expect( - "Internal multithreading channel is broken: results may be incorrect", - ), - } - } - // BASE CASE: current fragment is the last fragment in `glob_fragments` - } else if state.current_fragment_idx == state.glob_fragments.len() - 1 { - // Last fragment contains a wildcard: we list the last level and match against the full glob - if current_fragment.has_special_character() { - let mut results = source - .iter_dir(&state.current_path, "/", true, state.page_size, None) - .await - .unwrap_or_else(|e| futures::stream::iter([Err(e)]).boxed()); - - while let Some(val) = results.next().await { - match val { - Ok(fm) => { - if matches!(fm.filetype, FileType::File) - && state.full_glob_matcher.is_match(fm.filepath.as_str()) - { - result_tx.send(Ok(fm)).await.expect("Internal multithreading channel is broken: results may be incorrect"); - } - } - // Silence NotFound errors when in wildcard "search" mode - Err(super::Error::NotFound { .. }) if state.wildcard_mode => (), - Err(e) => result_tx.send(Err(e)).await.expect( - "Internal multithreading channel is broken: results may be incorrect", - ), - } - } - // Last fragment does not contain wildcard: we return it if the full path exists and is a FileType::File - } else { - let full_dir_path = state.current_path.clone() + current_fragment.escaped_str(); - let single_file_ls = source - .ls(full_dir_path.as_str(), "/", true, None, state.page_size) - .await; - match single_file_ls { - Ok(mut single_file_ls) => { - if single_file_ls.files.len() == 1 - && matches!(single_file_ls.files[0].filetype, FileType::File) - { - let fm = single_file_ls.files.drain(..).next().unwrap(); - result_tx.send(Ok(fm)).await.expect("Internal multithreading channel is broken: results may be incorrect"); - } - } - // Silence NotFound errors when in wildcard "search" mode - Err(super::Error::NotFound { .. 
}) if state.wildcard_mode => (), - Err(e) => result_tx.send(Err(e)).await.expect( - "Internal multithreading channel is broken: results may be incorrect", - ), - }; - } - - // RECURSIVE CASE: current_fragment contains a special character (e.g. *) - } else if current_fragment.has_special_character() { - let partial_glob_matcher = GlobBuilder::new( - GlobFragment::join( - &state.glob_fragments[..state.current_fragment_idx + 1], - "/", - ) - .raw_str(), - ) - .literal_separator(true) - .build() - .expect("Cannot parse glob") - .compile_matcher(); - - let (mut results, stream_dir_count) = ls_with_prefix_fallback( - source.clone(), - &state.current_path, - "/", - state.fanout_limit / state.current_fanout, - state.page_size, - ) - .await; - - while let Some(val) = results.next().await { - match val { - Ok(fm) => match fm.filetype { - FileType::Directory - if partial_glob_matcher - .is_match(fm.filepath.as_str().trim_end_matches('/')) => - { - visit( - result_tx.clone(), - source.clone(), - state - .clone() - .advance( - fm.filepath, - state.current_fragment_idx + 1, - stream_dir_count, - ) - .with_wildcard_mode(), - ); - } - FileType::File - if state.full_glob_matcher.is_match(fm.filepath.as_str()) => - { - result_tx.send(Ok(fm)).await.expect("Internal multithreading channel is broken: results may be incorrect"); - } - _ => (), - }, - // Always silence NotFound since we are in wildcard "search" mode here by definition - Err(super::Error::NotFound { .. }) => (), - Err(e) => result_tx.send(Err(e)).await.expect( - "Internal multithreading channel is broken: results may be incorrect", - ), - } - } - - // RECURSIVE CASE: current_fragment contains no special characters, and is a path to a specific File or Directory - } else { - let full_dir_path = state.current_path.clone() + current_fragment.escaped_str(); - visit( - result_tx.clone(), - source.clone(), - state - .clone() - .advance(full_dir_path, state.current_fragment_idx + 1, 1), - ); - } - }); - } - - visit( - to_rtn_tx, - source.clone(), - GlobState { - current_path: "".to_string(), - current_fragment_idx: 0, - glob_fragments: Arc::new(glob_fragments), - full_glob_matcher: Arc::new(full_glob_matcher), - wildcard_mode: false, - current_fanout: 1, - fanout_limit, - page_size, - }, - ); - - let to_rtn_stream = stream! 
{ - while let Some(v) = to_rtn_rx.recv().await { - yield v - } - }; - - Ok(to_rtn_stream.boxed()) -} - -pub(crate) async fn recursive_iter( - source: Arc, - uri: &str, -) -> super::Result>> { - log::debug!(target: "recursive_iter", "starting recursive_iter: with top level of: {uri}"); - let (to_rtn_tx, mut to_rtn_rx) = tokio::sync::mpsc::channel(16 * 1024); - fn add_to_channel( - source: Arc, - tx: Sender>, - dir: String, - ) { - log::debug!(target: "recursive_iter", "recursive_iter: spawning task to list: {dir}"); - let source = source.clone(); - tokio::spawn(async move { - let s = source.iter_dir(&dir, "/", true, Some(1000), None).await; - log::debug!(target: "recursive_iter", "started listing task for {dir}"); - let mut s = match s { - Ok(s) => s, - Err(e) => { - log::debug!(target: "recursive_iter", "Error occurred when listing {dir}\nerror:\n{e}"); - tx.send(Err(e)).await.map_err(|se| { - super::Error::UnableToSendDataOverChannel { source: se.into() } - })?; - return super::Result::<_, super::Error>::Ok(()); - } - }; - let tx = &tx; - while let Some(tr) = s.next().await { - if let Ok(ref tr) = tr && matches!(tr.filetype, FileType::Directory) { - add_to_channel(source.clone(), tx.clone(), tr.filepath.clone()) - } - tx.send(tr) - .await - .map_err(|e| super::Error::UnableToSendDataOverChannel { source: e.into() })?; - } - log::debug!(target: "recursive_iter", "completed listing task for {dir}"); - super::Result::Ok(()) - }); - } - - add_to_channel(source, to_rtn_tx, uri.to_string()); - - let to_rtn_stream = stream! { - while let Some(v) = to_rtn_rx.recv().await { - yield v - } - }; - - Ok(to_rtn_stream.boxed()) -} diff --git a/src/daft-io/src/object_store_glob.rs b/src/daft-io/src/object_store_glob.rs new file mode 100644 index 0000000000..5ffb6aec41 --- /dev/null +++ b/src/daft-io/src/object_store_glob.rs @@ -0,0 +1,562 @@ +use async_stream::stream; +use futures::stream::{BoxStream, StreamExt}; +use itertools::Itertools; +use std::{collections::HashSet, sync::Arc}; +use tokio::sync::mpsc::Sender; + +use globset::{GlobBuilder, GlobMatcher}; +use lazy_static::lazy_static; + +use crate::object_io::{FileMetadata, FileType, ObjectSource}; + +lazy_static! 
{ + /// Check if a given char is considered a special glob character + /// NOTE: we use the `globset` crate which defines the following glob behavior: + /// https://docs.rs/globset/latest/globset/index.html#syntax + static ref GLOB_SPECIAL_CHARACTERS: HashSet = HashSet::from(['*', '?', '{', '}', '[', ']']); +} + +const SCHEME_SUFFIX_LEN: usize = "://".len(); + +#[derive(Clone)] +pub(crate) struct GlobState { + // Current path in dirtree and glob_fragments + pub current_path: String, + pub current_fragment_idx: usize, + + // How large of a fanout this level of iteration is currently experiencing + pub current_fanout: usize, + + // Whether we have encountered wildcards yet in the process of parsing + pub wildcard_mode: bool, + + // Carry along expensive data as Arcs to avoid recomputation + pub glob_fragments: Arc>, + pub full_glob_matcher: Arc, + pub fanout_limit: Option, + pub page_size: Option, +} + +impl GlobState { + pub fn current_glob_fragment(&self) -> &GlobFragment { + &self.glob_fragments[self.current_fragment_idx] + } + + pub fn advance(self, path: String, idx: usize, fanout_factor: usize) -> Self { + GlobState { + current_path: path, + current_fragment_idx: idx, + current_fanout: self.current_fanout * fanout_factor, + ..self.clone() + } + } + + pub fn with_wildcard_mode(self) -> Self { + GlobState { + wildcard_mode: true, + ..self + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct GlobFragment { + data: String, + escaped_data: String, + first_wildcard_idx: Option, +} + +impl GlobFragment { + pub fn new(data: &str) -> Self { + let first_wildcard_idx = match data { + "" => None, + data if GLOB_SPECIAL_CHARACTERS.contains(&data.chars().nth(0).unwrap()) => Some(0), + _ => { + // Detect any special characters that are not preceded by an escape \ + let mut idx = None; + for (i, window) in data + .chars() + .collect::>() + .as_slice() + .windows(2) + .enumerate() + { + let &[c1, c2] = window else { + unreachable!("Window contains 2 elements") + }; + if (c1 != '\\') && GLOB_SPECIAL_CHARACTERS.contains(&c2) { + idx = Some(i + 1); + break; + } + } + idx + } + }; + + // Sanitize `data`: removing '\' and converting '\\' to '\' + let mut escaped_data = String::new(); + let mut ptr = 0; + while ptr < data.len() { + let remaining = &data[ptr..]; + match remaining.find(r"\\") { + Some(backslash_idx) => { + escaped_data.push_str(&remaining[..backslash_idx].replace('\\', "")); + escaped_data.extend(std::iter::once('\\')); + ptr += backslash_idx + 2; + } + None => { + escaped_data.push_str(&remaining.replace('\\', "")); + break; + } + } + } + + GlobFragment { + data: data.to_string(), + first_wildcard_idx, + escaped_data, + } + } + + /// Checks if this GlobFragment has any special characters + pub fn has_special_character(&self) -> bool { + self.first_wildcard_idx.is_some() + } + + /// Joins a slice of GlobFragments together with a separator + pub fn join(fragments: &[GlobFragment], sep: &str) -> Self { + GlobFragment::new( + fragments + .iter() + .map(|frag: &GlobFragment| frag.data.as_str()) + .join(sep) + .as_str(), + ) + } + + /// Returns the fragment as a string with the backslash (\) escapes applied + /// 1. \\ is cleaned up to just \ + /// 2. 
\ followed by anything else is just ignored + pub fn escaped_str(&self) -> &str { + self.escaped_data.as_str() + } + + /// Returns the GlobFragment as a raw unescaped string, suitable for use by the globset crate + pub fn raw_str(&self) -> &str { + self.data.as_str() + } +} + +/// Parses a glob URL string into "fragments" +/// Fragments are the glob URL string but: +/// 1. Split by delimiter ("/") +/// 2. Non-wildcard fragments are joined and coalesced by delimiter +/// 3. The first fragment is prefixed by "{scheme}://" +/// 4. Preserves any leading delimiters +pub(crate) fn to_glob_fragments( + glob_str: &str, + delimiter: &str, +) -> super::Result> { + // NOTE: We only use the URL parse library to get the scheme, because it will escape some of our glob special characters + // such as ? and {} + let glob_url = url::Url::parse(glob_str).map_err(|e| super::Error::InvalidUrl { + path: glob_str.to_string(), + source: e, + })?; + let url_scheme = glob_url.scheme(); + + let glob_str_after_scheme = &glob_str[url_scheme.len() + SCHEME_SUFFIX_LEN..]; + + // NOTE: Leading delimiter may be important for absolute paths on local directory, and is considered + // part of the first fragment + let leading_delimiter = if glob_str_after_scheme.starts_with(delimiter) { + delimiter + } else { + "" + }; + + // Parse glob fragments: split by delimiter and join any non-special fragments + let mut coalesced_fragments = vec![]; + let mut nonspecial_fragments_so_far = vec![]; + for fragment in glob_str_after_scheme + .split(delimiter) + .map(GlobFragment::new) + { + match fragment { + fragment if fragment.data.is_empty() => (), + fragment if fragment.has_special_character() => { + if !nonspecial_fragments_so_far.is_empty() { + coalesced_fragments.push(GlobFragment::join( + nonspecial_fragments_so_far.drain(..).as_slice(), + delimiter, + )); + } + coalesced_fragments.push(fragment); + } + _ => { + nonspecial_fragments_so_far.push(fragment); + } + } + } + if !nonspecial_fragments_so_far.is_empty() { + coalesced_fragments.push(GlobFragment::join( + nonspecial_fragments_so_far.drain(..).as_slice(), + delimiter, + )); + } + + // Ensure that the first fragment has the scheme and leading delimiter (if requested) prefixed + coalesced_fragments[0] = GlobFragment::new( + (format!("{url_scheme}://") + leading_delimiter + coalesced_fragments[0].raw_str()) + .as_str(), + ); + + Ok(coalesced_fragments) +} + +/// Helper method to iterate on a directory with the following behavior +/// +/// * First attempts to non-recursively list all Files and Directories under the current `uri` +/// * If during iteration we detect the number of Directories being returned exceeds `max_dirs`, we +/// fall back onto a prefix list of all Files with the current `uri` as the prefix +/// +/// Returns a tuple `(file_metadata_stream: BoxStream<...>, dir_count: usize)` where the second element +/// indicates the number of Directory entries contained within the stream +async fn ls_with_prefix_fallback( + source: Arc, + uri: &str, + max_dirs: Option, + page_size: Option, +) -> (BoxStream<'static, super::Result>, usize) { + // Prefix list function that only returns Files + fn prefix_ls( + source: Arc, + path: String, + page_size: Option, + ) -> BoxStream<'static, super::Result> { + stream! 
{ + match source.iter_dir(&path, false, page_size).await { + Ok(mut result_stream) => { + while let Some(result) = result_stream.next().await { + match result { + Ok(fm) => { + if matches!(fm.filetype, FileType::File) + { + yield Ok(fm) + } + } + Err(e) => yield Err(e), + } + } + }, + Err(e) => yield Err(e), + } + } + .boxed() + } + + // Buffer results in memory as we go along + let mut results_buffer = vec![]; + + let mut fm_stream = source + .iter_dir(uri, true, page_size) + .await + .unwrap_or_else(|e| futures::stream::iter([Err(e)]).boxed()); + + // Iterate and collect results into the `results_buffer`, but terminate early if too many directories are found + let mut dir_count_so_far = 0; + while let Some(fm) = fm_stream.next().await { + if let Ok(fm) = &fm { + if matches!(fm.filetype, FileType::Directory) { + dir_count_so_far += 1; + // STOP EARLY!! + // If the number of directory results are more than `max_dirs`, we terminate the function early, + // throw away our results buffer and return a stream of FileType::File files using `prefix_ls` instead + if max_dirs + .map(|max_dirs| dir_count_so_far > max_dirs) + .unwrap_or(false) + { + return (prefix_ls(source.clone(), uri.to_string(), page_size), 0); + } + } + } + results_buffer.push(fm); + } + + // No early termination: we unwrap the results in our results buffer and yield data as a stream + let s = futures::stream::iter(results_buffer); + (s.boxed(), dir_count_so_far) +} + +/// Globs an ObjectSource for Files +/// +/// Uses the `globset` crate for matching, and thus supports all the syntax enabled by that crate. +/// See: https://docs.rs/globset/latest/globset/#syntax +/// +/// Arguments: +/// * source: the ObjectSource to use for file listing +/// * glob: the string to glob +/// * fanout_limit: number of directories at which to fallback onto prefix listing, or None to never fall back. +/// A reasonable number here for a remote object store is something like 1024, which saturates the number of +/// parallel connections (usually defaulting to 64). +/// * page_size: control the returned results page size, or None to use the ObjectSource's defaults. Usually only used for testing +/// but may yield some performance improvements depending on the workload. +pub(crate) async fn glob( + source: Arc, + glob: &str, + fanout_limit: Option, + page_size: Option, +) -> super::Result>> { + let delimiter = source.delimiter(); + + // If no special characters, we fall back to ls behavior + let full_fragment = GlobFragment::new(glob); + if !full_fragment.has_special_character() { + let glob = full_fragment.escaped_str().to_string(); + return Ok(stream! { + let mut results = source.iter_dir(glob.as_str(), true, page_size).await?; + while let Some(result) = results.next().await { + match result { + Ok(fm) => { + if matches!(fm.filetype, FileType::File) { + yield Ok(fm) + } + }, + Err(e) => yield Err(e), + } + } + } + .boxed()); + } + + // If user specifies a trailing / then we understand it as an attempt to list the folder(s) matched + // and append a trailing * fragment + let glob = if glob.ends_with(source.delimiter()) { + glob.to_string() + "*" + } else { + glob.to_string() + }; + let glob = glob.as_str(); + + let glob_fragments = to_glob_fragments(glob, delimiter)?; + let full_glob_matcher = GlobBuilder::new(glob) + .literal_separator(true) + .backslash_escape(true) + .build() + .map_err(|err| super::Error::InvalidArgument { + msg: format!("Cannot parse provided glob {glob}: {err}"), + })? 
+ .compile_matcher(); + + // Channel to send results back to caller. Note that all results must have FileType::File. + let (to_rtn_tx, mut to_rtn_rx) = tokio::sync::mpsc::channel(16 * 1024); + + /// Dispatches a task to visit the specified `path` (a concrete path on the filesystem to either a File or Directory). + /// Based on the current glob_fragment being processed (accessible via `glob_fragments[i]`) this task will: + /// 1. Perform work to retrieve Files/Directories at (`path` + `glob_fragments[i]`) + /// 2. Return results to the provided `result_tx` channel based on the provided glob, if appropriate + /// 3. Dispatch additional tasks via `.visit()` to continue visiting them, if appropriate + fn visit( + result_tx: Sender>, + source: Arc, + state: GlobState, + ) { + tokio::spawn(async move { + log::debug!( + target: "glob", + "Visiting '{}' with glob_fragments: {:?}", + &state.current_path, &state.glob_fragments + ); + let current_fragment = state.current_glob_fragment(); + + // BASE CASE: current_fragment is a ** + // We perform a recursive ls and filter on the results for only FileType::File results that match the full glob + if current_fragment.escaped_str() == "**" { + let (mut results, stream_dir_count) = ls_with_prefix_fallback( + source.clone(), + &state.current_path, + state + .fanout_limit + .map(|fanout_limit| fanout_limit / state.current_fanout), + state.page_size, + ) + .await; + + while let Some(val) = results.next().await { + match val { + Ok(fm) => { + match fm.filetype { + // Recursively visit each sub-directory + FileType::Directory => { + visit( + result_tx.clone(), + source.clone(), + // Do not increment `current_fragment_idx` so as to keep visiting the "**" fragmemt + state.clone().advance( + fm.filepath.clone(), + state.current_fragment_idx, + stream_dir_count, + ), + ); + } + // Return any Files that match + FileType::File + if state.full_glob_matcher.is_match(fm.filepath.as_str()) => + { + result_tx.send(Ok(fm)).await.expect("Internal multithreading channel is broken: results may be incorrect"); + } + _ => (), + } + } + // Silence NotFound errors when in wildcard "search" mode + Err(super::Error::NotFound { .. }) if state.wildcard_mode => (), + Err(e) => result_tx.send(Err(e)).await.expect( + "Internal multithreading channel is broken: results may be incorrect", + ), + } + } + // BASE CASE: current fragment is the last fragment in `glob_fragments` + } else if state.current_fragment_idx == state.glob_fragments.len() - 1 { + // Last fragment contains a wildcard: we list the last level and match against the full glob + if current_fragment.has_special_character() { + let mut results = source + .iter_dir(&state.current_path, true, state.page_size) + .await + .unwrap_or_else(|e| futures::stream::iter([Err(e)]).boxed()); + + while let Some(result) = results.next().await { + match result { + Ok(fm) => { + if matches!(fm.filetype, FileType::File) + && state.full_glob_matcher.is_match(fm.filepath.as_str()) + { + result_tx.send(Ok(fm)).await.expect("Internal multithreading channel is broken: results may be incorrect"); + } + } + // Silence NotFound errors when in wildcard "search" mode + Err(super::Error::NotFound { .. 
}) if state.wildcard_mode => (), + Err(e) => result_tx.send(Err(e)).await.expect( + "Internal multithreading channel is broken: results may be incorrect", + ), + } + } + // Last fragment does not contain wildcard: we return it if the full path exists and is a FileType::File + } else { + let full_dir_path = state.current_path.clone() + current_fragment.escaped_str(); + let single_file_ls = source + .ls(full_dir_path.as_str(), true, None, state.page_size) + .await; + match single_file_ls { + Ok(mut single_file_ls) => { + if single_file_ls.files.len() == 1 + && matches!(single_file_ls.files[0].filetype, FileType::File) + { + let fm = single_file_ls.files.drain(..).next().unwrap(); + result_tx.send(Ok(fm)).await.expect("Internal multithreading channel is broken: results may be incorrect"); + } + } + // Silence NotFound errors when in wildcard "search" mode + Err(super::Error::NotFound { .. }) if state.wildcard_mode => (), + Err(e) => result_tx.send(Err(e)).await.expect( + "Internal multithreading channel is broken: results may be incorrect", + ), + }; + } + + // RECURSIVE CASE: current_fragment contains a special character (e.g. *) + } else if current_fragment.has_special_character() { + let partial_glob_matcher = GlobBuilder::new( + GlobFragment::join( + &state.glob_fragments[..state.current_fragment_idx + 1], + source.delimiter(), + ) + .raw_str(), + ) + .literal_separator(true) + .build() + .expect("Cannot parse glob") + .compile_matcher(); + + let (mut results, stream_dir_count) = ls_with_prefix_fallback( + source.clone(), + &state.current_path, + state + .fanout_limit + .map(|fanout_limit| fanout_limit / state.current_fanout), + state.page_size, + ) + .await; + + while let Some(val) = results.next().await { + match val { + Ok(fm) => match fm.filetype { + FileType::Directory + if partial_glob_matcher.is_match( + fm.filepath.as_str().trim_end_matches(source.delimiter()), + ) => + { + visit( + result_tx.clone(), + source.clone(), + state + .clone() + .advance( + fm.filepath, + state.current_fragment_idx + 1, + stream_dir_count, + ) + .with_wildcard_mode(), + ); + } + FileType::File + if state.full_glob_matcher.is_match(fm.filepath.as_str()) => + { + result_tx.send(Ok(fm)).await.expect("Internal multithreading channel is broken: results may be incorrect"); + } + _ => (), + }, + // Always silence NotFound since we are in wildcard "search" mode here by definition + Err(super::Error::NotFound { .. }) => (), + Err(e) => result_tx.send(Err(e)).await.expect( + "Internal multithreading channel is broken: results may be incorrect", + ), + } + } + + // RECURSIVE CASE: current_fragment contains no special characters, and is a path to a specific File or Directory + } else { + let full_dir_path = state.current_path.clone() + current_fragment.escaped_str(); + visit( + result_tx.clone(), + source.clone(), + state + .clone() + .advance(full_dir_path, state.current_fragment_idx + 1, 1), + ); + } + }); + } + + visit( + to_rtn_tx, + source.clone(), + GlobState { + current_path: "".to_string(), + current_fragment_idx: 0, + glob_fragments: Arc::new(glob_fragments), + full_glob_matcher: Arc::new(full_glob_matcher), + wildcard_mode: false, + current_fanout: 1, + fanout_limit, + page_size, + }, + ); + + let to_rtn_stream = stream! 
{ + while let Some(v) = to_rtn_rx.recv().await { + yield v + } + }; + + Ok(to_rtn_stream.boxed()) +} diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs index c4a666bd54..f22d182a19 100644 --- a/src/daft-io/src/python.rs +++ b/src/daft-io/src/python.rs @@ -2,11 +2,7 @@ pub use common_io_config::python::{AzureConfig, GCSConfig, IOConfig}; pub use py::register_modules; mod py { - use crate::{ - get_io_client, get_runtime, - object_io::{glob, recursive_iter}, - parse_url, - }; + use crate::{get_io_client, get_runtime, parse_url}; use common_error::DaftResult; use futures::TryStreamExt; use pyo3::{ @@ -35,7 +31,8 @@ mod py { runtime_handle.block_on(async move { let source = io_client.get_source(&scheme).await?; - let files = glob(source, path.as_ref(), fanout_limit, page_size) + let files = source + .glob(path.as_ref(), fanout_limit, page_size) .await? .try_collect() .await?; @@ -54,53 +51,6 @@ mod py { Ok(PyList::new(py, to_rtn)) } - #[pyfunction] - fn io_list( - py: Python, - path: String, - recursive: Option, - multithreaded_io: Option, - io_config: Option, - ) -> PyResult<&PyList> { - let lsr: DaftResult> = py.allow_threads(|| { - let io_client = get_io_client( - multithreaded_io.unwrap_or(true), - io_config.unwrap_or_default().config.into(), - )?; - let (scheme, path) = parse_url(&path)?; - let runtime_handle = get_runtime(true)?; - let _rt_guard = runtime_handle.enter(); - - runtime_handle.block_on(async move { - let source = io_client.get_source(&scheme).await?; - let files = if recursive.is_some_and(|r| r) { - recursive_iter(source, &path) - .await? - .try_collect::>() - .await? - } else { - source - .iter_dir(&path, "/", true, None, None) - .await? - .try_collect::>() - .await? - }; - - Ok(files) - }) - }); - let lsr = lsr?; - let mut to_rtn = vec![]; - for file in lsr { - let dict = PyDict::new(py); - dict.set_item("type", format!("{:?}", file.filetype))?; - dict.set_item("path", file.filepath)?; - dict.set_item("size", file.size)?; - to_rtn.push(dict); - } - Ok(PyList::new(py, to_rtn)) - } - #[pyfunction] fn set_io_pool_num_threads(num_threads: i64) -> PyResult { Ok(crate::set_io_pool_num_threads(num_threads as usize)) @@ -108,7 +58,6 @@ mod py { pub fn register_modules(py: Python, parent: &PyModule) -> PyResult<()> { common_io_config::python::register_modules(py, parent)?; - parent.add_function(wrap_pyfunction!(io_list, parent)?)?; parent.add_function(wrap_pyfunction!(io_glob, parent)?)?; parent.add_function(wrap_pyfunction!(set_io_pool_num_threads, parent)?)?; diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 5004abc87b..fea3e1629e 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -2,6 +2,7 @@ use async_trait::async_trait; use aws_config::retry::RetryMode; use aws_config::timeout::TimeoutConfig; use aws_smithy_async::rt::sleep::TokioSleep; +use futures::stream::BoxStream; use reqwest::StatusCode; use s3::operation::head_object::HeadObjectError; use s3::operation::list_objects_v2::ListObjectsV2Error; @@ -33,6 +34,9 @@ use std::ops::Range; use std::string::FromUtf8Error; use std::sync::Arc; use std::time::Duration; + +const S3_DELIMITER: &str = "/"; +const DEFAULT_GLOB_FANOUT_LIMIT: usize = 1024; pub(crate) struct S3LikeSource { region_to_client_map: tokio::sync::RwLock>>, connection_pool_sema: Arc, @@ -362,7 +366,7 @@ impl S3LikeSource { }), }?; let key = parsed.path(); - if let Some(key) = key.strip_prefix('/') { + if let Some(key) = key.strip_prefix(S3_DELIMITER) { log::debug!("S3 get parsed uri: {uri} into Bucket: 
{bucket}, Key: {key}"); let request = self .get_s3_client(region) @@ -474,7 +478,7 @@ impl S3LikeSource { }), }?; let key = parsed.path(); - if let Some(key) = key.strip_prefix('/') { + if let Some(key) = key.strip_prefix(S3_DELIMITER) { log::debug!("S3 head parsed uri: {uri} into Bucket: {bucket}, Key: {key}"); let request = self .get_s3_client(region) @@ -688,6 +692,10 @@ impl S3LikeSource { #[async_trait] impl ObjectSource for S3LikeSource { + fn delimiter(&self) -> &'static str { + S3_DELIMITER + } + async fn get(&self, uri: &str, range: Option>) -> super::Result { let permit = self .connection_pool_sema @@ -708,10 +716,23 @@ impl ObjectSource for S3LikeSource { self._head_impl(permit, uri, &self.default_region).await } + async fn glob( + self: Arc, + glob_path: &str, + fanout_limit: Option, + page_size: Option, + ) -> super::Result>> { + use crate::object_store_glob::glob; + + // Ensure fanout_limit is not None to prevent runaway concurrency + let fanout_limit = fanout_limit.or(Some(DEFAULT_GLOB_FANOUT_LIMIT)); + + glob(self, glob_path, fanout_limit, page_size.or(Some(1000))).await + } + async fn ls( &self, path: &str, - delimiter: &str, posix: bool, continuation_token: Option<&str>, page_size: Option, @@ -726,7 +747,7 @@ impl ObjectSource for S3LikeSource { source: ParseError::EmptyHost, }), }?; - let key = parsed.path().trim_start_matches(delimiter); + let key = parsed.path().trim_start_matches(S3_DELIMITER); if posix { // Perform a directory-based list of entries in the next level @@ -734,7 +755,7 @@ impl ObjectSource for S3LikeSource { let key = if key.is_empty() { "".to_string() } else { - format!("{}{delimiter}", key.trim_end_matches(delimiter)) + format!("{}{S3_DELIMITER}", key.trim_end_matches(S3_DELIMITER)) }; let lsr = { let permit = self @@ -748,28 +769,28 @@ impl ObjectSource for S3LikeSource { scheme, bucket, &key, - Some(delimiter.into()), + Some(S3_DELIMITER.into()), continuation_token.map(String::from), &self.default_region, page_size, ) .await? 
}; - if lsr.files.is_empty() && key.contains(delimiter) { + if lsr.files.is_empty() && key.contains(S3_DELIMITER) { let permit = self .connection_pool_sema .acquire() .await .context(UnableToGrabSemaphoreSnafu)?; // Might be a File - let key = key.trim_end_matches(delimiter); + let key = key.trim_end_matches(S3_DELIMITER); let mut lsr = self ._list_impl( permit, scheme, bucket, key, - Some(delimiter.into()), + Some(S3_DELIMITER.into()), continuation_token.map(String::from), &self.default_region, page_size, @@ -879,7 +900,7 @@ mod tests { }; let client = S3LikeSource::get_client(&config).await?; - client.ls(file_path, "/", true, None, None).await?; + client.ls(file_path, true, None, None).await?; Ok(()) } diff --git a/tests/integration/io/benchmarks/test_benchmark_glob.py b/tests/integration/io/benchmarks/test_benchmark_glob.py index 09a1d12fd4..19659aaf80 100644 --- a/tests/integration/io/benchmarks/test_benchmark_glob.py +++ b/tests/integration/io/benchmarks/test_benchmark_glob.py @@ -5,7 +5,7 @@ import pytest import s3fs -from daft.daft import io_glob, io_list +from daft.daft import io_glob from ..conftest import minio_create_bucket @@ -219,10 +219,3 @@ def test_benchmark_glob_daft(benchmark, setup_bucket, minio_io_config, fanout_li ) ) assert len(results) == setup_bucket - - -@pytest.mark.benchmark(group="glob") -@pytest.mark.integration() -def test_benchmark_io_list_recursive_daft(benchmark, setup_bucket, minio_io_config): - results = benchmark(lambda: io_list(f"s3://{BUCKET}/", io_config=minio_io_config, recursive=True)) - assert len([r for r in results if r["type"] == "File"]) == setup_bucket diff --git a/tests/integration/io/test_list_files_gcs.py b/tests/integration/io/test_list_files_gcs.py index 1053e25fa1..5640887e4d 100644 --- a/tests/integration/io/test_list_files_gcs.py +++ b/tests/integration/io/test_list_files_gcs.py @@ -3,7 +3,7 @@ import gcsfs import pytest -from daft.daft import GCSConfig, IOConfig, io_list +from daft.daft import GCSConfig, IOConfig, io_glob BUCKET = "daft-public-data-gs" DEFAULT_GCS_CONFIG = GCSConfig(project_id=None, anonymous=None) @@ -28,11 +28,8 @@ def compare_gcs_result(daft_ls_result: list, fsspec_result: list): daft_files = [(f["path"], f["type"].lower()) for f in daft_ls_result] gcsfs_files = [(f"gs://{f['name']}", f["type"]) for f in fsspec_result] - # Perform necessary post-processing of fsspec results to match expected behavior from Daft: - # NOTE: gcsfs sometimes does not return the trailing / for directories, so we have to ensure it - gcsfs_files = [ - (f"{path.rstrip('/')}/", type_) if type_ == "directory" else (path, type_) for path, type_ in gcsfs_files - ] + # Remove all directories: our glob utilities don't return dirs + gcsfs_files = [(path, type_) for path, type_ in gcsfs_files if type_ == "file"] assert len(daft_files) == len(gcsfs_files) assert sorted(daft_files) == sorted(gcsfs_files) @@ -51,10 +48,12 @@ def compare_gcs_result(daft_ls_result: list, fsspec_result: list): ], ) @pytest.mark.parametrize("recursive", [False, True]) +@pytest.mark.parametrize("fanout_limit", [None, 1]) @pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) -def test_gs_flat_directory_listing(path, recursive, gcs_config): +def test_gs_flat_directory_listing(path, recursive, gcs_config, fanout_limit): fs = gcsfs.GCSFileSystem() - daft_ls_result = io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) + glob_path = path.rstrip("/") + "/**" if recursive else path + daft_ls_result = io_glob(glob_path, 
io_config=IOConfig(gcs=gcs_config), fanout_limit=fanout_limit) fsspec_result = gcsfs_recursive_list(fs, path) if recursive else fs.ls(path, detail=True) compare_gcs_result(daft_ls_result, fsspec_result) @@ -65,18 +64,14 @@ def test_gs_flat_directory_listing(path, recursive, gcs_config): def test_gs_single_file_listing(recursive, gcs_config): path = f"gs://{BUCKET}/test_ls/file.txt" fs = gcsfs.GCSFileSystem() - daft_ls_result = io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) + daft_ls_result = io_glob(path, io_config=IOConfig(gcs=gcs_config)) fsspec_result = gcsfs_recursive_list(fs, path) if recursive else fs.ls(path, detail=True) compare_gcs_result(daft_ls_result, fsspec_result) @pytest.mark.integration() -@pytest.mark.parametrize("recursive", [False, True]) @pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) -def test_gs_notfound(recursive, gcs_config): - path = f"gs://{BUCKET}/test_ls/MISSING" - fs = gcsfs.GCSFileSystem() - with pytest.raises(FileNotFoundError): - fs.ls(path, detail=True) +def test_gs_notfound(gcs_config): + path = f"gs://{BUCKET}/test_" with pytest.raises(FileNotFoundError, match=path): - io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) + io_glob(path, io_config=IOConfig(gcs=gcs_config)) diff --git a/tests/integration/io/test_list_files_http.py b/tests/integration/io/test_list_files_http.py index 2471133c3f..e40a8a80e5 100644 --- a/tests/integration/io/test_list_files_http.py +++ b/tests/integration/io/test_list_files_http.py @@ -5,13 +5,17 @@ import pytest from fsspec.implementations.http import HTTPFileSystem -from daft.daft import io_list +from daft.daft import io_glob from tests.integration.io.conftest import mount_data_nginx def compare_http_result(daft_ls_result: list, fsspec_result: list): daft_files = [(f["path"], f["type"].lower(), f["size"]) for f in daft_ls_result] httpfs_files = [(f["name"], f["type"], f["size"]) for f in fsspec_result] + + # io_glob doesn't return directory entries + httpfs_files = [(p, t, s) for p, t, s in httpfs_files if t == "file"] + assert len(daft_files) == len(httpfs_files) assert sorted(daft_files) == sorted(httpfs_files) @@ -47,14 +51,14 @@ def test_http_flat_directory_listing(path, nginx_http_url): http_path = f"{nginx_http_url}{path}" fs = HTTPFileSystem() fsspec_result = fs.ls(http_path, detail=True) - daft_ls_result = io_list(http_path) + daft_ls_result = io_glob(http_path) compare_http_result(daft_ls_result, fsspec_result) @pytest.mark.integration() -def test_gs_single_file_listing(nginx_http_url): +def test_http_single_file_listing(nginx_http_url): path = f"{nginx_http_url}/test_ls/file.txt" - daft_ls_result = io_list(path) + daft_ls_result = io_glob(path) # NOTE: FSSpec will return size 0 list for this case, but we want to return 1 element to be # consistent with behavior of our other file listing utilities @@ -73,22 +77,15 @@ def test_http_notfound(nginx_http_url): fs.ls(path, detail=True) with pytest.raises(FileNotFoundError, match=path): - io_list(path) + io_glob(path) @pytest.mark.integration() -@pytest.mark.parametrize( - "path", - [ - f"", - f"/", - ], -) -def test_http_flat_directory_listing_recursive(path, nginx_http_url): - http_path = f"{nginx_http_url}/{path}" +def test_http_flat_directory_listing_recursive(nginx_http_url): + http_path = f"{nginx_http_url}/**" fs = HTTPFileSystem() - fsspec_result = list(fs.glob(http_path.rstrip("/") + "/**", detail=True).values()) - daft_ls_result = io_list(http_path, recursive=True) + fsspec_result = 
list(fs.glob(http_path, detail=True).values()) + daft_ls_result = io_glob(http_path) compare_http_result(daft_ls_result, fsspec_result) @@ -107,7 +104,7 @@ def test_http_listing_absolute_urls(nginx_config, tmpdir): with mount_data_nginx(nginx_config, tmpdir): http_path = f"{nginx_http_url}/index.html" - daft_ls_result = io_list(http_path, recursive=False) + daft_ls_result = io_glob(http_path) # NOTE: Cannot use fsspec here because they do not correctly find the links # fsspec_result = fs.ls(http_path, detail=True) @@ -115,7 +112,6 @@ def test_http_listing_absolute_urls(nginx_config, tmpdir): assert daft_ls_result == [ {"type": "File", "path": f"{nginx_http_url}/other.html", "size": None}, - {"type": "Directory", "path": f"{nginx_http_url}/dir/", "size": None}, ] @@ -134,7 +130,7 @@ def test_http_listing_absolute_base_urls(nginx_config, tmpdir): with mount_data_nginx(nginx_config, tmpdir): http_path = f"{nginx_http_url}/index.html" - daft_ls_result = io_list(http_path, recursive=False) + daft_ls_result = io_glob(http_path) # NOTE: Cannot use fsspec here because they do not correctly find the links # fsspec_result = fs.ls(http_path, detail=True) @@ -142,5 +138,4 @@ def test_http_listing_absolute_base_urls(nginx_config, tmpdir): assert daft_ls_result == [ {"type": "File", "path": f"{nginx_http_url}/other.html", "size": None}, - {"type": "Directory", "path": f"{nginx_http_url}/dir/", "size": None}, ] diff --git a/tests/integration/io/test_list_files_s3_minio.py b/tests/integration/io/test_list_files_s3_minio.py index 3f0fc9a715..1d67768998 100644 --- a/tests/integration/io/test_list_files_s3_minio.py +++ b/tests/integration/io/test_list_files_s3_minio.py @@ -2,7 +2,7 @@ import pytest -from daft.daft import io_glob, io_list +from daft.daft import io_glob from .conftest import minio_create_bucket @@ -10,6 +10,10 @@ def compare_s3_result(daft_ls_result: list, s3fs_result: list): daft_files = [(f["path"], f["type"].lower()) for f in daft_ls_result] s3fs_files = [(f"s3://{f['name']}", f["type"]) for f in s3fs_result] + + # io_glob does not return directories + s3fs_files = [(p, t) for p, t in s3fs_files if t == "file"] + assert sorted(daft_files) == sorted(s3fs_files) @@ -298,7 +302,7 @@ def test_flat_directory_listing(minio_io_config): files = ["a", "b", "c"] for name in files: fs.touch(f"{bucket_name}/{name}") - daft_ls_result = io_list(f"s3://{bucket_name}", io_config=minio_io_config) + daft_ls_result = io_glob(f"s3://{bucket_name}", io_config=minio_io_config) s3fs_result = fs.ls(f"s3://{bucket_name}", detail=True) compare_s3_result(daft_ls_result, s3fs_result) @@ -310,24 +314,20 @@ def test_recursive_directory_listing(minio_io_config): files = ["a", "b/bb", "c/cc/ccc"] for name in files: fs.write_bytes(f"s3://{bucket_name}/{name}", b"") - daft_ls_result = io_list(f"s3://{bucket_name}/", io_config=minio_io_config, recursive=True) + daft_ls_result = io_glob(f"s3://{bucket_name}/**", io_config=minio_io_config) fs.invalidate_cache() s3fs_result = s3fs_recursive_list(fs, path=f"s3://{bucket_name}") compare_s3_result(daft_ls_result, s3fs_result) @pytest.mark.integration() -@pytest.mark.parametrize( - "recursive", - [False, True], -) -def test_single_file_directory_listing(minio_io_config, recursive): +def test_single_file_directory_listing(minio_io_config): bucket_name = "bucket" with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs: files = ["a", "b/bb", "c/cc/ccc"] for name in files: fs.write_bytes(f"s3://{bucket_name}/{name}", b"") - daft_ls_result = 
io_list(f"s3://{bucket_name}/c/cc/ccc", io_config=minio_io_config, recursive=recursive) + daft_ls_result = io_glob(f"s3://{bucket_name}/c/cc/ccc", io_config=minio_io_config) fs.invalidate_cache() s3fs_result = s3fs_recursive_list(fs, path=f"s3://{bucket_name}/c/cc/ccc") assert len(daft_ls_result) == 1 @@ -335,17 +335,13 @@ def test_single_file_directory_listing(minio_io_config, recursive): @pytest.mark.integration() -@pytest.mark.parametrize( - "recursive", - [False, True], -) -def test_single_file_directory_listing_trailing(minio_io_config, recursive): +def test_single_file_directory_listing_trailing(minio_io_config): bucket_name = "bucket" with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs: files = ["a", "b/bb", "c/cc/ccc"] for name in files: fs.write_bytes(f"s3://{bucket_name}/{name}", b"") - daft_ls_result = io_list(f"s3://{bucket_name}/c/cc///", io_config=minio_io_config, recursive=recursive) + daft_ls_result = io_glob(f"s3://{bucket_name}/c/cc///", io_config=minio_io_config) fs.invalidate_cache() s3fs_result = s3fs_recursive_list(fs, path=f"s3://{bucket_name}/c/cc///") assert len(daft_ls_result) == 1 @@ -361,7 +357,8 @@ def test_missing_file_path(minio_io_config, recursive): bucket_name = "bucket" with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs: files = ["a", "b/bb", "c/cc/ccc"] + path = f"s3://{bucket_name}/c/cc/ddd/**" if recursive else f"s3://{bucket_name}/c/cc/ddd" for name in files: fs.write_bytes(f"s3://{bucket_name}/{name}", b"") with pytest.raises(FileNotFoundError, match=f"s3://{bucket_name}/c/cc/ddd"): - daft_ls_result = io_list(f"s3://{bucket_name}/c/cc/ddd", io_config=minio_io_config, recursive=recursive) + daft_ls_result = io_glob(path, io_config=minio_io_config) diff --git a/tests/io/__init__.py b/tests/io/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/io/test_list_files_local.py b/tests/io/test_list_files_local.py similarity index 73% rename from tests/integration/io/test_list_files_local.py rename to tests/io/test_list_files_local.py index dfd016038b..fc7cd5259c 100644 --- a/tests/integration/io/test_list_files_local.py +++ b/tests/io/test_list_files_local.py @@ -3,7 +3,7 @@ import pytest from fsspec.implementations.local import LocalFileSystem -from daft.daft import io_list +from daft.daft import io_glob def local_recursive_list(fs, path) -> list: @@ -23,10 +23,13 @@ def local_recursive_list(fs, path) -> list: def compare_local_result(daft_ls_result: list, fs_result: list): daft_files = [(f["path"], f["type"].lower()) for f in daft_ls_result] fs_files = [(f'file://{f["name"]}', f["type"]) for f in fs_result] + + # io_glob does not return directories + fs_files = [(p, t) for p, t in fs_files if t == "file"] + assert sorted(daft_files) == sorted(fs_files) -@pytest.mark.integration() @pytest.mark.parametrize("include_protocol", [False, True]) def test_flat_directory_listing(tmp_path, include_protocol): d = tmp_path / "dir" @@ -35,16 +38,15 @@ def test_flat_directory_listing(tmp_path, include_protocol): for name in files: p = d / name p.touch() - d = str(d) + d = str(d) + "/" if include_protocol: d = "file://" + d - daft_ls_result = io_list(d) + daft_ls_result = io_glob(d) fs = LocalFileSystem() fs_result = fs.ls(d, detail=True) compare_local_result(daft_ls_result, fs_result) -@pytest.mark.integration() @pytest.mark.parametrize("include_protocol", [False, True]) def test_recursive_directory_listing(tmp_path, include_protocol): d = tmp_path / "dir" @@ -58,22 +60,16 @@ def 
test_recursive_directory_listing(tmp_path, include_protocol): p.mkdir() p /= segments[-1] p.touch() - d = str(d) if include_protocol: - d = "file://" + d - daft_ls_result = io_list(d, recursive=True) + d = "file://" + str(d) + daft_ls_result = io_glob(str(d) + "/**") fs = LocalFileSystem() fs_result = local_recursive_list(fs, d) compare_local_result(daft_ls_result, fs_result) -@pytest.mark.integration() @pytest.mark.parametrize("include_protocol", [False, True]) -@pytest.mark.parametrize( - "recursive", - [False, True], -) -def test_single_file_directory_listing(tmp_path, include_protocol, recursive): +def test_single_file_directory_listing(tmp_path, include_protocol): d = tmp_path / "dir" d.mkdir() files = ["a", "b/bb", "c/cc/ccc"] @@ -88,13 +84,40 @@ def test_single_file_directory_listing(tmp_path, include_protocol, recursive): p = f"{d}/c/cc/ccc" if include_protocol: p = "file://" + p - daft_ls_result = io_list(p, recursive=recursive) + + daft_ls_result = io_glob(p) fs_result = [{"name": f"{d}/c/cc/ccc", "type": "file"}] assert len(daft_ls_result) == 1 compare_local_result(daft_ls_result, fs_result) -@pytest.mark.integration() +@pytest.mark.parametrize("include_protocol", [False, True]) +def test_wildcard_listing(tmp_path, include_protocol): + d = tmp_path / "dir" + d.mkdir() + files = ["a/x.txt", "b/y.txt", "c/z.txt"] + for name in files: + p = d + segments = name.split("/") + for intermediate_dir in segments[:-1]: + p /= intermediate_dir + p.mkdir() + p /= segments[-1] + p.touch() + p = f"{d}/*/*.txt" + if include_protocol: + p = "file://" + p + + daft_ls_result = io_glob(p) + fs_result = [ + {"name": f"{d}/a/x.txt", "type": "file"}, + {"name": f"{d}/b/y.txt", "type": "file"}, + {"name": f"{d}/c/z.txt", "type": "file"}, + ] + assert len(daft_ls_result) == 3 + compare_local_result(daft_ls_result, fs_result) + + @pytest.mark.parametrize("include_protocol", [False, True]) def test_missing_file_path(tmp_path, include_protocol): d = tmp_path / "dir" @@ -112,4 +135,4 @@ def test_missing_file_path(tmp_path, include_protocol): if include_protocol: p = "file://" + p with pytest.raises(FileNotFoundError, match=f"File: {d}/c/cc/ddd not found"): - daft_ls_result = io_list(p, recursive=True) + io_glob(p) diff --git a/tests/integration/io/test_url_download_local.py b/tests/io/test_url_download_local.py similarity index 100% rename from tests/integration/io/test_url_download_local.py rename to tests/io/test_url_download_local.py
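
A minimal migration sketch for callers of the removed io_list, based only on the io_glob call sites exercised in the updated tests above. The bucket and prefix paths are illustrative, IOConfig() default construction is assumed, and only the keyword arguments the tests actually pass (io_config, fanout_limit) are shown; this is a sketch, not the authoritative API surface.

    from daft.daft import IOConfig, io_glob

    io_config = IOConfig()  # assumed default-constructible; the tests pass explicit GCS/S3 configs

    # Previously: io_list("s3://bucket/prefix", io_config=io_config)
    files = io_glob("s3://bucket/prefix", io_config=io_config)

    # Previously: io_list("s3://bucket/prefix", io_config=io_config, recursive=True)
    # Recursive listing is now expressed with an explicit "/**" suffix.
    files_recursive = io_glob("s3://bucket/prefix/**", io_config=io_config)

    # Wildcards are matched directly; fanout_limit caps concurrent prefix listings
    # (the Rust sources fall back to DEFAULT_GLOB_FANOUT_LIMIT = 1024 when it is None).
    matches = io_glob("s3://bucket/*/data/*.parquet", io_config=io_config, fanout_limit=128)

    # Each result is a dict with "type", "path" and "size"; unlike io_list,
    # directory entries are no longer returned.
    assert all(f["type"] == "File" for f in files_recursive)

io_config is optional for unauthenticated sources (the HTTP and local tests call io_glob(path) with no config), and the per-source glob implementations also default page_size to 1000 when it is left as None.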