diff --git a/Cargo.lock b/Cargo.lock index dbefdef313..214c1790bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1773,6 +1773,7 @@ dependencies = [ "bytes", "common-error", "common-io-config", + "common-py-serde", "futures", "globset", "google-cloud-storage", @@ -1789,6 +1790,7 @@ dependencies = [ "rand 0.8.5", "regex", "reqwest", + "serde", "snafu", "tempfile", "tokio", @@ -1948,6 +1950,7 @@ dependencies = [ "daft-core", "daft-dsl", "daft-functions", + "daft-io", "daft-scan", "daft-table", "indexmap 2.3.0", @@ -1998,6 +2001,7 @@ dependencies = [ "common-py-serde", "daft-core", "daft-dsl", + "daft-io", "daft-plan", "daft-scan", "pyo3", diff --git a/daft/daft.pyi b/daft/daft.pyi index fe84b6599d..8a023e61cf 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -448,7 +448,9 @@ class HTTPConfig: I/O configuration for accessing HTTP systems """ - user_agent: str | None + bearer_token: str | None + + def __init__(self, bearer_token: str | None = None): ... class S3Config: """ diff --git a/docs/source/user_guide/integrations.rst b/docs/source/user_guide/integrations.rst index de296676f5..3390b68483 100644 --- a/docs/source/user_guide/integrations.rst +++ b/docs/source/user_guide/integrations.rst @@ -11,3 +11,4 @@ Integrations integrations/microsoft-azure integrations/aws integrations/sql + integrations/huggingface diff --git a/docs/source/user_guide/integrations/huggingface.rst b/docs/source/user_guide/integrations/huggingface.rst new file mode 100644 index 0000000000..547f5ed856 --- /dev/null +++ b/docs/source/user_guide/integrations/huggingface.rst @@ -0,0 +1,64 @@ +Huggingface Datasets +=========== + +Daft is able to read datasets directly from Huggingface via the ``hf://`` protocol. + +Since huggingface will `automatically convert `_ all public datasets to parquet format, +we can read these datasets using the ``read_parquet`` method. + +.. NOTE:: + This is limited to either public datasets, or PRO/ENTERPRISE datasets. + +For other file formats, you will need to manually specify the path or glob pattern to the files you want to read, similar to how you would read from a local file system. + + +Reading Public Datasets +----------------------- + +.. code:: python + + import daft + + df = daft.read_parquet("hf://username/dataset_name") + +This will read the entire dataset into a daft DataFrame. + +Not only can you read entire datasets, but you can also read individual files from a dataset. + +.. code:: python + + import daft + + df = daft.read_parquet("hf://username/dataset_name/file_name.parquet") + # or a csv file + df = daft.read_csv("hf://username/dataset_name/file_name.csv") + + # or a glob pattern + df = daft.read_parquet("hf://username/dataset_name/**/*.parquet") + + +Authorization +------------- + +For authenticated datasets: + +.. code:: python + + from daft.io import IOConfig, HTTPConfig + + io_config = IoConfig(http=HTTPConfig(bearer_token="your_token")) + df = daft.read_parquet("hf://username/dataset_name", io_config=io_config) + + +It's important to note that this will not work with standard tier private datasets. +Huggingface does not auto convert private datasets to parquet format, so you will need to specify the path to the files you want to read. + +.. code:: python + + df = daft.read_parquet("hf://username/my_private_dataset", io_config=io_config) # Errors + +to get around this, you can read all files using a glob pattern *(assuming they are in parquet format)* + +.. code:: python + + df = daft.read_parquet("hf://username/my_private_dataset/**/*.parquet", io_config=io_config) # Works diff --git a/src/common/io-config/src/http.rs b/src/common/io-config/src/http.rs index c619c0b7d8..7fb3f38eeb 100644 --- a/src/common/io-config/src/http.rs +++ b/src/common/io-config/src/http.rs @@ -4,22 +4,40 @@ use std::fmt::Formatter; use serde::Deserialize; use serde::Serialize; +use crate::ObfuscatedString; + #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct HTTPConfig { pub user_agent: String, + pub bearer_token: Option, } impl Default for HTTPConfig { fn default() -> Self { HTTPConfig { user_agent: "daft/0.0.1".to_string(), // NOTE: Ideally we grab the version of Daft, but that requires a dependency on daft-core + bearer_token: None, + } + } +} + +impl HTTPConfig { + pub fn new>(bearer_token: Option) -> Self { + HTTPConfig { + bearer_token: bearer_token.map(|t| t.into()), + ..Default::default() } } } impl HTTPConfig { pub fn multiline_display(&self) -> Vec { - vec![format!("user_agent = {}", self.user_agent)] + let mut v = vec![format!("user_agent = {}", self.user_agent)]; + if let Some(bearer_token) = &self.bearer_token { + v.push(format!("bearer_token = {}", bearer_token)); + } + + v } } @@ -30,6 +48,17 @@ impl Display for HTTPConfig { "HTTPConfig user_agent: {}", self.user_agent, - ) + )?; + + if let Some(bearer_token) = &self.bearer_token { + write!( + f, + " + bearer_token: {}", + bearer_token + ) + } else { + Ok(()) + } } } diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index 8332a55b73..a6d32d3cb0 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -134,9 +134,10 @@ pub struct IOConfig { /// /// Args: /// user_agent (str, optional): The value for the user-agent header, defaults to "daft/{__version__}" if not provided +/// bearer_token (str, optional): Bearer token to use for authentication. This will be used as the value for the `Authorization` header. such as "Authorization: Bearer xxx" /// /// Example: -/// >>> io_config = IOConfig(http=HTTPConfig(user_agent="my_application/0.0.1")) +/// >>> io_config = IOConfig(http=HTTPConfig(user_agent="my_application/0.0.1", bearer_token="xxx")) /// >>> daft.read_parquet("http://some-path", io_config=io_config) #[derive(Clone, Default)] #[pyclass] @@ -901,6 +902,20 @@ impl From for IOConfig { } } +#[pymethods] +impl HTTPConfig { + #[new] + pub fn new(bearer_token: Option) -> Self { + HTTPConfig { + config: crate::HTTPConfig::new(bearer_token), + } + } + + pub fn __repr__(&self) -> PyResult { + Ok(format!("{}", self.config)) + } +} + pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { parent.add_class::()?; parent.add_class::()?; diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index 062d1cc446..a749189eec 100644 --- a/src/daft-io/Cargo.toml +++ b/src/daft-io/Cargo.toml @@ -15,6 +15,7 @@ azure_storage_blobs = {version = "0.17.0", features = ["enable_reqwest"], defaul bytes = {workspace = true} common-error = {path = "../common/error", default-features = false} common-io-config = {path = "../common/io-config", default-features = false} +common-py-serde = {path = "../common/py-serde", default-features = false} futures = {workspace = true} globset = "0.4" google-cloud-storage = {version = "0.15.0", default-features = false, features = ["default-tls", "auth"]} @@ -29,6 +30,7 @@ openssl-sys = {version = "0.9.102", features = ["vendored"]} pyo3 = {workspace = true, optional = true} rand = "0.8.5" regex = {version = "1.10.4"} +serde = {workspace = true} snafu = {workspace = true} tokio = {workspace = true} tokio-stream = {workspace = true} @@ -36,7 +38,7 @@ url = {workspace = true} [dependencies.reqwest] default-features = false -features = ["stream", "native-tls"] +features = ["stream", "native-tls", "json"] version = "0.11.18" [dev-dependencies] @@ -44,7 +46,12 @@ md5 = "0.7.0" tempfile = "3.8.1" [features] -python = ["dep:pyo3", "common-error/python", "common-io-config/python"] +python = [ + "dep:pyo3", + "common-error/python", + "common-io-config/python", + "common-py-serde/python" +] [package] edition = {workspace = true} diff --git a/src/daft-io/src/azure_blob.rs b/src/daft-io/src/azure_blob.rs index 117e7e81c0..0a12dc704c 100644 --- a/src/daft-io/src/azure_blob.rs +++ b/src/daft-io/src/azure_blob.rs @@ -15,7 +15,7 @@ use crate::{ object_io::{FileMetadata, FileType, LSResult, ObjectSource}, stats::IOStatsRef, stream_utils::io_stats_on_bytestream, - GetResult, + FileFormat, GetResult, }; use common_io_config::AzureConfig; @@ -577,6 +577,7 @@ impl ObjectSource for AzureBlobSource { page_size: Option, limit: Option, io_stats: Option, + _file_format: Option, ) -> super::Result>> { use crate::object_store_glob::glob; diff --git a/src/daft-io/src/file_format.rs b/src/daft-io/src/file_format.rs new file mode 100644 index 0000000000..6aff086cc4 --- /dev/null +++ b/src/daft-io/src/file_format.rs @@ -0,0 +1,58 @@ +use std::str::FromStr; + +use common_error::{DaftError, DaftResult}; +use common_py_serde::impl_bincode_py_state_serialization; +#[cfg(feature = "python")] +use pyo3::prelude::*; + +use serde::{Deserialize, Serialize}; + +/// Format of a file, e.g. Parquet, CSV, JSON. +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Copy)] +#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] +pub enum FileFormat { + Parquet, + Csv, + Json, + Database, + Python, +} + +#[cfg(feature = "python")] +#[pymethods] +impl FileFormat { + fn ext(&self) -> &'static str { + match self { + Self::Parquet => "parquet", + Self::Csv => "csv", + Self::Json => "json", + Self::Database => "db", + Self::Python => "py", + } + } +} + +impl FromStr for FileFormat { + type Err = DaftError; + + fn from_str(file_format: &str) -> DaftResult { + use FileFormat::*; + + if file_format.trim().eq_ignore_ascii_case("parquet") { + Ok(Parquet) + } else if file_format.trim().eq_ignore_ascii_case("csv") { + Ok(Csv) + } else if file_format.trim().eq_ignore_ascii_case("json") { + Ok(Json) + } else if file_format.trim().eq_ignore_ascii_case("database") { + Ok(Database) + } else { + Err(DaftError::TypeError(format!( + "FileFormat {} not supported!", + file_format + ))) + } + } +} + +impl_bincode_py_state_serialization!(FileFormat); diff --git a/src/daft-io/src/google_cloud.rs b/src/daft-io/src/google_cloud.rs index ff533cbe41..e9d9b2f9d1 100644 --- a/src/daft-io/src/google_cloud.rs +++ b/src/daft-io/src/google_cloud.rs @@ -23,6 +23,7 @@ use crate::object_io::LSResult; use crate::object_io::ObjectSource; use crate::stats::IOStatsRef; use crate::stream_utils::io_stats_on_bytestream; +use crate::FileFormat; use crate::GetResult; use common_io_config::GCSConfig; @@ -436,6 +437,7 @@ impl ObjectSource for GCSSource { page_size: Option, limit: Option, io_stats: Option, + _file_format: Option, ) -> super::Result>> { use crate::object_store_glob::glob; diff --git a/src/daft-io/src/http.rs b/src/daft-io/src/http.rs index 1779a2ed97..aa6deb050f 100644 --- a/src/daft-io/src/http.rs +++ b/src/daft-io/src/http.rs @@ -15,6 +15,7 @@ use crate::{ object_io::{FileMetadata, FileType, LSResult}, stats::IOStatsRef, stream_utils::io_stats_on_bytestream, + FileFormat, }; use super::object_io::{GetResult, ObjectSource}; @@ -140,7 +141,7 @@ fn _get_file_metadata_from_html(path: &str, text: &str) -> super::Result for super::Error { @@ -276,6 +277,7 @@ impl ObjectSource for HttpSource { _page_size: Option, limit: Option, io_stats: Option, + _file_format: Option, ) -> super::Result>> { use crate::object_store_glob::glob; diff --git a/src/daft-io/src/huggingface.rs b/src/daft-io/src/huggingface.rs new file mode 100644 index 0000000000..9e89327c80 --- /dev/null +++ b/src/daft-io/src/huggingface.rs @@ -0,0 +1,633 @@ +use std::{ + collections::HashMap, num::ParseIntError, ops::Range, str::FromStr, string::FromUtf8Error, + sync::Arc, +}; + +use async_trait::async_trait; +use common_io_config::HTTPConfig; +use futures::{ + stream::{self, BoxStream}, + StreamExt, TryStreamExt, +}; + +use hyper::header; +use reqwest::{ + header::{CONTENT_LENGTH, RANGE}, + Client, +}; +use snafu::{IntoError, ResultExt, Snafu}; + +use crate::{ + http::HttpSource, + object_io::{FileMetadata, FileType, LSResult}, + stats::IOStatsRef, + stream_utils::io_stats_on_bytestream, + FileFormat, +}; +use serde::{Deserialize, Serialize}; + +use super::object_io::{GetResult, ObjectSource}; + +#[derive(Debug, Snafu)] +enum Error { + #[snafu(display("Unable to connect to {}: {}", path, source))] + UnableToConnect { + path: String, + source: reqwest::Error, + }, + + #[snafu(display("Unable to open {}: {}", path, source))] + UnableToOpenFile { + path: String, + source: reqwest::Error, + }, + + #[snafu(display("Unable to determine size of {}", path))] + UnableToDetermineSize { path: String }, + + #[snafu(display("Unable to read data from {}: {}", path, source))] + UnableToReadBytes { + path: String, + source: reqwest::Error, + }, + + #[snafu(display("Unable to create Http Client {}", source))] + UnableToCreateClient { source: reqwest::Error }, + + #[snafu(display( + "Unable to parse data as Utf8 while reading header for file: {path}. {source}" + ))] + UnableToParseUtf8Header { path: String, source: FromUtf8Error }, + + #[snafu(display( + "Unable to parse data as Integer while reading header for file: {path}. {source}" + ))] + UnableToParseInteger { path: String, source: ParseIntError }, + + #[snafu(display("Unable to create HTTP header: {source}"))] + UnableToCreateHeader { source: header::InvalidHeaderValue }, + #[snafu(display("Invalid path: {}", path))] + InvalidPath { path: String }, + + #[snafu(display(r#" +Implicit Parquet conversion not supported for private datasets. +You can use glob patterns, or request a specific file to access your dataset instead. +Example: + instead of `hf://datasets/username/dataset_name`, use `hf://datasets/username/dataset_name/file_name.parquet` + or `hf://datasets/username/dataset_name/*.parquet +"#))] + PrivateDataset, + #[snafu(display("Unauthorized access to dataset, please check your credentials."))] + Unauthorized, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "snake_case")] +enum ItemType { + File, + Directory, +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "snake_case")] +struct Item { + r#type: ItemType, + oid: String, + size: u64, + path: String, +} + +#[derive(Debug, PartialEq)] +struct HFPathParts { + bucket: String, + repository: String, + revision: String, + path: String, +} +impl FromStr for HFPathParts { + type Err = Error; + /// Extracts path components from a hugging face path: + /// `hf:// [datasets | spaces] / {username} / {reponame} @ {revision} / {path from root}` + fn from_str(uri: &str) -> Result { + // hf:// [datasets] / {username} / {reponame} @ {revision} / {path from root} + // !> + if !uri.starts_with("hf://") { + return Err(Error::InvalidPath { + path: uri.to_string(), + }); + } + (|| { + let uri = &uri[5..]; + + // [datasets] / {username} / {reponame} @ {revision} / {path from root} + // ^--------^ !> + let (bucket, uri) = uri.split_once('/')?; + // {username} / {reponame} @ {revision} / {path from root} + // ^--------^ !> + let (username, uri) = uri.split_once('/')?; + // {reponame} @ {revision} / {path from root} + // ^--------^ !> + let (repository, uri) = if let Some((repo, uri)) = uri.split_once('/') { + (repo, uri) + } else { + return Some(HFPathParts { + bucket: bucket.to_string(), + repository: format!("{}/{}", username, uri), + revision: "main".to_string(), + path: "".to_string(), + }); + }; + + // {revision} / {path from root} + // ^--------^ !> + let (repository, revision) = if let Some((repo, rev)) = repository.split_once('@') { + (repo, rev.to_string()) + } else { + (repository, "main".to_string()) + }; + + // {username}/{reponame} + let repository = format!("{}/{}", username, repository); + // {path from root} + // ^--------------^ + let path = uri.to_string().trim_end_matches('/').to_string(); + + Some(HFPathParts { + bucket: bucket.to_string(), + repository, + revision, + path, + }) + })() + .ok_or_else(|| Error::InvalidPath { + path: uri.to_string(), + }) + } +} + +impl std::fmt::Display for HFPathParts { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "hf://{BUCKET}/{REPOSITORY}/{PATH}", + BUCKET = self.bucket, + REPOSITORY = self.repository, + PATH = self.path + ) + } +} + +impl HFPathParts { + fn get_file_uri(&self) -> String { + format!( + "https://huggingface.co/{BUCKET}/{REPOSITORY}/resolve/{REVISION}/{PATH}", + BUCKET = self.bucket, + REPOSITORY = self.repository, + REVISION = self.revision, + PATH = self.path + ) + } + + fn get_api_uri(&self) -> String { + // "https://huggingface.co/api/ [datasets] / {username} / {reponame} / tree / {revision} / {path from root}" + format!( + "https://huggingface.co/api/{BUCKET}/{REPOSITORY}/tree/{REVISION}/{PATH}", + BUCKET = self.bucket, + REPOSITORY = self.repository, + REVISION = self.revision, + PATH = self.path + ) + } + + fn get_parquet_api_uri(&self) -> String { + format!( + "https://huggingface.co/api/{BUCKET}/{REPOSITORY}/parquet", + BUCKET = self.bucket, + REPOSITORY = self.repository, + ) + } +} + +pub(crate) struct HFSource { + http_source: HttpSource, +} + +impl From for HFSource { + fn from(http_source: HttpSource) -> Self { + Self { http_source } + } +} + +impl From for super::Error { + fn from(error: Error) -> Self { + use Error::*; + match error { + UnableToOpenFile { path, source } => match source.status().map(|v| v.as_u16()) { + Some(404) | Some(410) => super::Error::NotFound { + path, + source: source.into(), + }, + None | Some(_) => super::Error::UnableToOpenFile { + path, + source: source.into(), + }, + }, + UnableToDetermineSize { path } => super::Error::UnableToDetermineSize { path }, + _ => super::Error::Generic { + store: super::SourceType::Http, + source: error.into(), + }, + } + } +} + +impl HFSource { + pub async fn get_client(config: &HTTPConfig) -> super::Result> { + let mut default_headers = header::HeaderMap::new(); + default_headers.append( + "user-agent", + header::HeaderValue::from_str(config.user_agent.as_str()) + .context(UnableToCreateHeaderSnafu)?, + ); + + if let Some(token) = &config.bearer_token { + default_headers.append( + "Authorization", + header::HeaderValue::from_str(&format!("Bearer {}", token.as_string())) + .context(UnableToCreateHeaderSnafu)?, + ); + } + + Ok(HFSource { + http_source: HttpSource { + client: reqwest::ClientBuilder::default() + .pool_max_idle_per_host(70) + .default_headers(default_headers) + .build() + .context(UnableToCreateClientSnafu)?, + }, + } + .into()) + } +} + +#[async_trait] +impl ObjectSource for HFSource { + async fn get( + &self, + uri: &str, + range: Option>, + io_stats: Option, + ) -> super::Result { + let path_parts = uri.parse::()?; + let uri = &path_parts.get_file_uri(); + let request = self.http_source.client.get(uri); + let request = match range { + None => request, + Some(range) => request.header( + RANGE, + format!("bytes={}-{}", range.start, range.end.saturating_sub(1)), + ), + }; + + let response = request + .send() + .await + .context(UnableToConnectSnafu:: { path: uri.into() })?; + + let response = response.error_for_status().map_err(|e| { + if let Some(401) = e.status().map(|s| s.as_u16()) { + Error::Unauthorized + } else { + Error::UnableToOpenFile { + path: uri.clone(), + source: e, + } + } + })?; + + if let Some(is) = io_stats.as_ref() { + is.mark_get_requests(1) + } + let size_bytes = response.content_length().map(|s| s as usize); + let stream = response.bytes_stream(); + let owned_string = uri.to_owned(); + let stream = stream.map_err(move |e| { + UnableToReadBytesSnafu:: { + path: owned_string.clone(), + } + .into_error(e) + .into() + }); + Ok(GetResult::Stream( + io_stats_on_bytestream(stream, io_stats), + size_bytes, + None, + None, + )) + } + + async fn put( + &self, + _uri: &str, + _data: bytes::Bytes, + _io_stats: Option, + ) -> super::Result<()> { + todo!("PUTs to HTTP URLs are not yet supported! Please file an issue."); + } + + async fn get_size(&self, uri: &str, io_stats: Option) -> super::Result { + let path_parts = uri.parse::()?; + let uri = &path_parts.get_file_uri(); + + let request = self.http_source.client.head(uri); + let response = request + .send() + .await + .context(UnableToConnectSnafu:: { path: uri.into() })?; + let response = response.error_for_status().map_err(|e| { + if let Some(401) = e.status().map(|s| s.as_u16()) { + Error::Unauthorized + } else { + Error::UnableToOpenFile { + path: uri.clone(), + source: e, + } + } + })?; + + if let Some(is) = io_stats.as_ref() { + is.mark_head_requests(1) + } + + let headers = response.headers(); + match headers.get(CONTENT_LENGTH) { + Some(v) => { + let size_bytes = String::from_utf8(v.as_bytes().to_vec()).with_context(|_| { + UnableToParseUtf8HeaderSnafu:: { path: uri.into() } + })?; + + Ok(size_bytes + .parse() + .with_context(|_| UnableToParseIntegerSnafu:: { path: uri.into() })?) + } + None => Err(Error::UnableToDetermineSize { path: uri.into() }.into()), + } + } + + async fn glob( + self: Arc, + glob_path: &str, + _fanout_limit: Option, + _page_size: Option, + limit: Option, + io_stats: Option, + file_format: Option, + ) -> super::Result>> { + use crate::object_store_glob::glob; + + // Huggingface has a special API for parquet files + // So we'll try to use that API to get the parquet files + // This allows us compatibility with datasets that are not natively uploaded as parquet, such as image datasets + + // We only want to use this api for datasets, not specific files + // such as + // hf://datasets/user/repo + // but not + // hf://datasets/user/repo/file.parquet + if let Some(FileFormat::Parquet) = file_format { + let res = + try_parquet_api(glob_path, limit, io_stats.clone(), &self.http_source.client).await; + match res { + Ok(Some(stream)) => return Ok(stream), + Err(e) => return Err(e.into()), + Ok(None) => { + // INTENTIONALLY EMPTY + // If we can't determine if the dataset is private, we'll fall back to the default globbing + } + } + } + + glob(self, glob_path, None, None, limit, io_stats).await + } + + async fn ls( + &self, + path: &str, + posix: bool, + _continuation_token: Option<&str>, + _page_size: Option, + io_stats: Option, + ) -> super::Result { + if !posix { + unimplemented!("Prefix-listing is not implemented for HTTP listing"); + } + let path_parts = path.parse::()?; + + let api_uri = path_parts.get_api_uri(); + + let request = self.http_source.client.get(api_uri.clone()); + let response = request + .send() + .await + .context(UnableToConnectSnafu:: { + path: api_uri.clone(), + })?; + + let response = response.error_for_status().map_err(|e| { + if let Some(401) = e.status().map(|s| s.as_u16()) { + Error::Unauthorized + } else { + Error::UnableToOpenFile { + path: api_uri.clone(), + source: e, + } + } + })?; + + if let Some(is) = io_stats.as_ref() { + is.mark_list_requests(1) + } + let response = response + .json::>() + .await + .context(UnableToReadBytesSnafu { + path: api_uri.clone(), + })?; + + let files = response + .into_iter() + .map(|item| { + let filepath = HFPathParts { + bucket: path_parts.bucket.clone(), + repository: path_parts.repository.clone(), + revision: path_parts.revision.clone(), + path: item.path, + }; + + let size = match item.size { + 0 => None, + size => Some(size), + }; + let filepath = filepath.to_string(); + + let filetype = match item.r#type { + ItemType::File => FileType::File, + ItemType::Directory => FileType::Directory, + }; + + FileMetadata { + filepath, + size, + filetype, + } + }) + .collect(); + Ok(LSResult { + files, + continuation_token: None, + }) + } +} + +async fn try_parquet_api( + glob_path: &str, + limit: Option, + io_stats: Option, + client: &Client, +) -> Result>>, Error> { + let hf_glob_path = glob_path.parse::()?; + if hf_glob_path.path.is_empty() { + let api_path = hf_glob_path.get_parquet_api_uri(); + + let response = client + .get(api_path.clone()) + .send() + .await + .with_context(|_| UnableToOpenFileSnafu { + path: api_path.to_string(), + })?; + if response.status() == 400 { + if let Some(error_message) = response + .headers() + .get("x-error-message") + .and_then(|v| v.to_str().ok()) + { + const PRIVATE_DATASET_ERROR: &str = + "Private datasets are only supported for PRO users and Enterprise Hub organizations."; + if error_message.ends_with(PRIVATE_DATASET_ERROR) { + return Err(Error::PrivateDataset); + } + } else { + return Err(Error::Unauthorized); + } + } + let response = response + .error_for_status() + .with_context(|_| UnableToOpenFileSnafu { + path: api_path.to_string(), + })?; + + if let Some(is) = io_stats.as_ref() { + is.mark_list_requests(1) + } + + // {: {: [, ...]}} + type DatasetResponse = HashMap>>; + let body = response + .json::() + .await + .context(UnableToReadBytesSnafu { + path: api_path.clone(), + })?; + + let files = body + .into_values() + .flat_map(|splits| splits.into_values()) + .flatten() + .map(|uri| { + Ok(FileMetadata { + filepath: uri, + size: None, + filetype: FileType::File, + }) + }); + + return Ok(Some( + stream::iter(files).take(limit.unwrap_or(16 * 1024)).boxed(), + )); + } else { + Ok(None) + } +} + +#[cfg(test)] +mod tests { + use common_error::DaftResult; + + use crate::huggingface::HFPathParts; + + #[test] + fn test_parse_hf_parts() -> DaftResult<()> { + let uri = "hf://datasets/wikimedia/wikipedia/20231101.ab/*.parquet"; + let parts = uri.parse::().unwrap(); + let expected = HFPathParts { + bucket: "datasets".to_string(), + repository: "wikimedia/wikipedia".to_string(), + revision: "main".to_string(), + path: "20231101.ab/*.parquet".to_string(), + }; + + assert_eq!(parts, expected); + + Ok(()) + } + + #[test] + fn test_parse_hf_parts_with_revision() -> DaftResult<()> { + let uri = "hf://datasets/wikimedia/wikipedia@dev/20231101.ab/*.parquet"; + let parts = uri.parse::().unwrap(); + let expected = HFPathParts { + bucket: "datasets".to_string(), + repository: "wikimedia/wikipedia".to_string(), + revision: "dev".to_string(), + path: "20231101.ab/*.parquet".to_string(), + }; + + assert_eq!(parts, expected); + + Ok(()) + } + + #[test] + fn test_parse_hf_parts_with_exact_path() -> DaftResult<()> { + let uri = "hf://datasets/user/repo@dev/config/my_file.parquet"; + let parts = uri.parse::().unwrap(); + let expected = HFPathParts { + bucket: "datasets".to_string(), + repository: "user/repo".to_string(), + revision: "dev".to_string(), + path: "config/my_file.parquet".to_string(), + }; + + assert_eq!(parts, expected); + + Ok(()) + } + + #[test] + fn test_parse_hf_parts_with_wildcard() -> DaftResult<()> { + let uri = "hf://datasets/wikimedia/wikipedia/**/*.parquet"; + let parts = uri.parse::().unwrap(); + let expected = HFPathParts { + bucket: "datasets".to_string(), + repository: "wikimedia/wikipedia".to_string(), + revision: "main".to_string(), + path: "**/*.parquet".to_string(), + }; + + assert_eq!(parts, expected); + + Ok(()) + } +} diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 957086f72d..a52b5825fd 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -5,6 +5,7 @@ mod azure_blob; mod google_cloud; mod http; +mod huggingface; mod local; mod object_io; mod object_store_glob; @@ -13,9 +14,12 @@ mod stats; mod stream_utils; use azure_blob::AzureBlobSource; use google_cloud::GCSSource; +use huggingface::HFSource; use lazy_static::lazy_static; +mod file_format; #[cfg(feature = "python")] pub mod python; +pub use file_format::FileFormat; pub use common_io_config::{AzureConfig, IOConfig, S3Config}; pub use object_io::FileMetadata; @@ -210,6 +214,9 @@ impl IOClient { SourceType::GCS => { GCSSource::get_client(&self.config.gcs).await? as Arc } + SourceType::HF => { + HFSource::get_client(&self.config.http).await? as Arc + } }; if w_handle.get(source_type).is_none() { @@ -225,11 +232,19 @@ impl IOClient { page_size: Option, limit: Option, io_stats: Option>, + file_format: Option, ) -> Result>> { let (scheme, _) = parse_url(input.as_str())?; let source = self.get_source(&scheme).await?; let files = source - .glob(input.as_str(), fanout_limit, page_size, limit, io_stats) + .glob( + input.as_str(), + fanout_limit, + page_size, + limit, + io_stats, + file_format, + ) .await?; Ok(files) } @@ -338,6 +353,7 @@ pub enum SourceType { S3, AzureBlob, GCS, + HF, } impl std::fmt::Display for SourceType { @@ -348,6 +364,7 @@ impl std::fmt::Display for SourceType { SourceType::S3 => write!(f, "s3"), SourceType::AzureBlob => write!(f, "AzureBlob"), SourceType::GCS => write!(f, "gcs"), + SourceType::HF => write!(f, "hf"), } } } @@ -386,6 +403,7 @@ pub fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> { "s3" | "s3a" => Ok((SourceType::S3, fixed_input)), "az" | "abfs" | "abfss" => Ok((SourceType::AzureBlob, fixed_input)), "gcs" | "gs" => Ok((SourceType::GCS, fixed_input)), + "hf" => Ok((SourceType::HF, fixed_input)), #[cfg(target_env = "msvc")] _ if scheme.len() == 1 && ("a" <= scheme.as_str() && (scheme.as_str() <= "z")) => { Ok((SourceType::File, Cow::Owned(format!("file://{input}")))) diff --git a/src/daft-io/src/local.rs b/src/daft-io/src/local.rs index d792ae7cc1..d468faa798 100644 --- a/src/daft-io/src/local.rs +++ b/src/daft-io/src/local.rs @@ -4,6 +4,7 @@ use std::path::PathBuf; use crate::object_io::{self, FileMetadata, LSResult}; use crate::stats::IOStatsRef; +use crate::FileFormat; use super::object_io::{GetResult, ObjectSource}; use super::Result; @@ -196,6 +197,7 @@ impl ObjectSource for LocalSource { _page_size: Option, limit: Option, io_stats: Option, + _file_format: Option, ) -> super::Result>> { use crate::object_store_glob::glob; diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs index 295c10cfab..7a51ceb4b7 100644 --- a/src/daft-io/src/object_io.rs +++ b/src/daft-io/src/object_io.rs @@ -12,6 +12,7 @@ use tokio::sync::OwnedSemaphorePermit; use crate::local::{collect_file, LocalFile}; use crate::stats::IOStatsRef; +use crate::FileFormat; pub struct StreamingRetryParams { source: Arc, @@ -195,6 +196,7 @@ pub(crate) trait ObjectSource: Sync + Send { page_size: Option, limit: Option, io_stats: Option, + file_format: Option, ) -> super::Result>>; async fn ls( diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs index 3ce4b59ba3..b893002bd2 100644 --- a/src/daft-io/src/python.rs +++ b/src/daft-io/src/python.rs @@ -42,6 +42,7 @@ mod py { page_size, limit, Some(io_stats_handle), + None, ) .await? .try_collect() diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index aea9283cc1..ef8758b0f8 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -13,7 +13,7 @@ use tokio::sync::{OwnedSemaphorePermit, SemaphorePermit}; use crate::object_io::{FileMetadata, FileType, LSResult}; use crate::stats::IOStatsRef; use crate::stream_utils::io_stats_on_bytestream; -use crate::{get_io_pool_num_threads, InvalidArgumentSnafu, SourceType}; +use crate::{get_io_pool_num_threads, FileFormat, InvalidArgumentSnafu, SourceType}; use aws_config::SdkConfig; use aws_credential_types::cache::{ CredentialsCache, ProvideCachedCredentials, SharedCredentialsCache, @@ -1063,6 +1063,7 @@ impl ObjectSource for S3LikeSource { page_size: Option, limit: Option, io_stats: Option, + _file_format: Option, ) -> super::Result>> { use crate::object_store_glob::glob; diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index 8ed445b5f1..6458ef4363 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -22,6 +22,7 @@ common-resource-request = {path = "../common/resource-request", default-features common-treenode = {path = "../common/treenode", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} +daft-io = {path = "../daft-io", default-features = false} daft-scan = {path = "../daft-scan", default-features = false} daft-table = {path = "../daft-table", default-features = false} indexmap = {workspace = true} @@ -46,6 +47,7 @@ python = [ "common-resource-request/python", "daft-core/python", "daft-dsl/python", + "daft-io/python", "daft-functions/python", "daft-table/python", "daft-scan/python" diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index c0ea1d18ee..cd719e7a50 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -22,7 +22,8 @@ use daft_core::{ schema::{Schema, SchemaRef}, }; use daft_dsl::{col, ExprRef}; -use daft_scan::{file_format::FileFormat, PhysicalScanInfo, Pushdowns, ScanOperatorRef}; +use daft_io::FileFormat; +use daft_scan::{PhysicalScanInfo, Pushdowns, ScanOperatorRef}; #[cfg(feature = "python")] use { diff --git a/src/daft-plan/src/lib.rs b/src/daft-plan/src/lib.rs index 34d270f431..633165bcaa 100644 --- a/src/daft-plan/src/lib.rs +++ b/src/daft-plan/src/lib.rs @@ -21,7 +21,7 @@ mod treenode; pub use builder::{LogicalPlanBuilder, PyLogicalPlanBuilder}; pub use daft_core::join::{JoinStrategy, JoinType}; -use daft_scan::file_format::FileFormat; +use daft_io::FileFormat; pub use logical_plan::{LogicalPlan, LogicalPlanRef}; pub use partitioning::ClusteringSpec; pub use physical_plan::{PhysicalPlan, PhysicalPlanRef}; diff --git a/src/daft-plan/src/logical_ops/sink.rs b/src/daft-plan/src/logical_ops/sink.rs index c5207df333..ab42f99ada 100644 --- a/src/daft-plan/src/logical_ops/sink.rs +++ b/src/daft-plan/src/logical_ops/sink.rs @@ -39,7 +39,7 @@ impl Sink { Arc::new(SinkInfo::OutputFileInfo(OutputFileInfo { root_dir: root_dir.clone(), - file_format: file_format.clone(), + file_format: *file_format, partition_cols: resolved_partition_cols, compression: compression.clone(), io_config: io_config.clone(), diff --git a/src/daft-scan/src/file_format.rs b/src/daft-scan/src/file_format.rs index 90ccb708b6..c0daa61ca7 100644 --- a/src/daft-scan/src/file_format.rs +++ b/src/daft-scan/src/file_format.rs @@ -1,8 +1,8 @@ -use common_error::{DaftError, DaftResult}; use daft_core::datatypes::{Field, TimeUnit}; +use daft_io::FileFormat; use serde::{Deserialize, Serialize}; use std::hash::Hash; -use std::{collections::BTreeMap, str::FromStr, sync::Arc}; +use std::{collections::BTreeMap, sync::Arc}; use common_py_serde::impl_bincode_py_state_serialization; @@ -13,56 +13,6 @@ use { pyo3::{pyclass, pyclass::CompareOp, pymethods, IntoPy, PyObject, PyResult, Python}, }; -/// Format of a file, e.g. Parquet, CSV, JSON. -#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] -#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] -pub enum FileFormat { - Parquet, - Csv, - Json, - Database, - Python, -} - -#[cfg(feature = "python")] -#[pymethods] -impl FileFormat { - fn ext(&self) -> &'static str { - match self { - Self::Parquet => "parquet", - Self::Csv => "csv", - Self::Json => "json", - Self::Database => "db", - Self::Python => "py", - } - } -} - -impl FromStr for FileFormat { - type Err = DaftError; - - fn from_str(file_format: &str) -> DaftResult { - use FileFormat::*; - - if file_format.trim().eq_ignore_ascii_case("parquet") { - Ok(Parquet) - } else if file_format.trim().eq_ignore_ascii_case("csv") { - Ok(Csv) - } else if file_format.trim().eq_ignore_ascii_case("json") { - Ok(Json) - } else if file_format.trim().eq_ignore_ascii_case("database") { - Ok(Database) - } else { - Err(DaftError::TypeError(format!( - "FileFormat {} not supported!", - file_format - ))) - } - } -} - -impl_bincode_py_state_serialization!(FileFormat); - impl From<&FileFormatConfig> for FileFormat { fn from(file_format_config: &FileFormatConfig) -> Self { match file_format_config { @@ -90,6 +40,10 @@ pub enum FileFormatConfig { } impl FileFormatConfig { + pub fn file_format(&self) -> FileFormat { + self.into() + } + pub fn var_name(&self) -> &'static str { use FileFormatConfig::*; diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index e16d443fb8..3797fca360 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -3,7 +3,7 @@ use std::{sync::Arc, vec}; use common_error::{DaftError, DaftResult}; use daft_core::schema::SchemaRef; use daft_csv::CsvParseOptions; -use daft_io::{parse_url, FileMetadata, IOClient, IOStatsContext, IOStatsRef}; +use daft_io::{parse_url, FileFormat, FileMetadata, IOClient, IOStatsContext, IOStatsRef}; use daft_parquet::read::ParquetSchemaInferenceOptions; use futures::{stream::BoxStream, StreamExt, TryStreamExt}; use snafu::Snafu; @@ -65,6 +65,7 @@ fn run_glob( io_client: Arc, runtime: Arc, io_stats: Option, + file_format: FileFormat, ) -> DaftResult { let (_, parsed_glob_path) = parse_url(glob_path)?; // Construct a static-lifetime BoxStream returning the FileMetadata @@ -72,7 +73,7 @@ fn run_glob( let runtime_handle = runtime.handle(); let boxstream = runtime_handle.block_on(async move { io_client - .glob(glob_input, None, None, limit, io_stats) + .glob(glob_input, None, None, limit, io_stats, Some(file_format)) .await })?; @@ -90,6 +91,7 @@ fn run_glob_parallel( io_client: Arc, runtime: Arc, io_stats: Option, + file_format: FileFormat, ) -> DaftResult>> { let num_parallel_tasks = 64; @@ -102,7 +104,7 @@ fn run_glob_parallel( runtime.spawn(async move { let stream = io_client - .glob(glob_input, None, None, None, io_stats) + .glob(glob_input, None, None, None, io_stats, Some(file_format)) .await?; let results = stream.collect::>().await; Result::<_, daft_io::Error>::Ok(futures::stream::iter(results)) @@ -137,6 +139,8 @@ impl GlobScanOperator { Some(path) => Ok(path), }?; + let file_format = file_format_config.file_format(); + let (io_runtime, io_client) = storage_config.get_io_client_and_runtime()?; let io_stats = IOStatsContext::new(format!( "GlobScanOperator::try_new schema inference for {first_glob_path}" @@ -147,6 +151,7 @@ impl GlobScanOperator { io_client.clone(), io_runtime.clone(), Some(io_stats.clone()), + file_format, )?; let FileMetadata { filepath: first_filepath, @@ -285,12 +290,14 @@ impl ScanOperator for GlobScanOperator { "GlobScanOperator::to_scan_tasks for {:#?}", self.glob_paths )); + let file_format = self.file_format_config.file_format(); let files = run_glob_parallel( self.glob_paths.clone(), io_client.clone(), io_runtime.clone(), Some(io_stats.clone()), + file_format, )?; let file_format_config = self.file_format_config.clone(); diff --git a/src/daft-scheduler/Cargo.toml b/src/daft-scheduler/Cargo.toml index 8ba0fafe9f..df17f64cee 100644 --- a/src/daft-scheduler/Cargo.toml +++ b/src/daft-scheduler/Cargo.toml @@ -6,6 +6,7 @@ common-io-config = {path = "../common/io-config", default-features = false} common-py-serde = {path = "../common/py-serde", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} +daft-io = {path = "../daft-io", default-features = false} daft-plan = {path = "../daft-plan", default-features = false} daft-scan = {path = "../daft-scan", default-features = false} pyo3 = {workspace = true, optional = true} @@ -23,6 +24,7 @@ python = [ "common-daft-config/python", "common-py-serde/python", "daft-core/python", + "daft-io/python", "daft-plan/python", "daft-dsl/python" ] diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index 4aaa842e6e..be442f4429 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -13,8 +13,9 @@ use { daft_core::schema::SchemaRef, daft_dsl::python::PyExpr, daft_dsl::Expr, + daft_io::FileFormat, daft_plan::{OutputFileInfo, PyLogicalPlanBuilder}, - daft_scan::{file_format::FileFormat, python::pylib::PyScanTask}, + daft_scan::python::pylib::PyScanTask, pyo3::{pyclass, pymethods, PyObject, PyRef, PyRefMut, PyResult, Python}, std::collections::HashMap, }; @@ -154,7 +155,7 @@ fn tabular_write( .getattr(pyo3::intern!(py, "write_file"))? .call1(( upstream_iter, - file_format.clone(), + *file_format, PySchema::from(schema.clone()), root_dir, compression.clone(),