[FEAT]: huggingface integration (#2701)
Added a few public sample files to my personal Huggingface account:

```py
df = daft.read_csv("hf://datasets/universalmind303/daft-docs/iris.csv")
```
universalmind303 authored Aug 22, 2024
1 parent a72321c commit 7e9208e
Showing 25 changed files with 879 additions and 70 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 3 additions & 1 deletion daft/daft.pyi
@@ -448,7 +448,9 @@ class HTTPConfig:
I/O configuration for accessing HTTP systems
"""

user_agent: str | None
bearer_token: str | None

def __init__(self, bearer_token: str | None = None): ...

class S3Config:
"""
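For orientation, here is a minimal sketch of how the new stub surface is used from Python (assuming `HTTPConfig` is re-exported through `daft.io`, as the docs added below do):

```py
from daft.io import IOConfig, HTTPConfig

# bearer_token is optional; when set it is sent as `Authorization: Bearer <token>`
io_config = IOConfig(http=HTTPConfig(bearer_token="hf_xxx"))
```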
1 change: 1 addition & 0 deletions docs/source/user_guide/integrations.rst
@@ -11,3 +11,4 @@ Integrations
integrations/microsoft-azure
integrations/aws
integrations/sql
integrations/huggingface
64 changes: 64 additions & 0 deletions docs/source/user_guide/integrations/huggingface.rst
@@ -0,0 +1,64 @@
Huggingface Datasets
====================

Daft is able to read datasets directly from Huggingface via the ``hf://`` protocol.

Since Huggingface `automatically converts <https://huggingface.co/docs/dataset-viewer/en/parquet>`_ all public datasets to Parquet format,
we can read these datasets using the ``read_parquet`` method.

.. NOTE::
   This auto-conversion is limited to public datasets and to private datasets on PRO/ENTERPRISE accounts.

For other file formats, you will need to manually specify the path or glob pattern to the files you want to read, similar to how you would read from a local file system.


Reading Public Datasets
-----------------------

.. code:: python

    import daft

    df = daft.read_parquet("hf://username/dataset_name")

This will read the entire dataset into a Daft DataFrame.

Not only can you read entire datasets, but you can also read individual files from a dataset.

.. code:: python

    import daft

    df = daft.read_parquet("hf://username/dataset_name/file_name.parquet")
    # or a csv file
    df = daft.read_csv("hf://username/dataset_name/file_name.csv")
    # or a glob pattern
    df = daft.read_parquet("hf://username/dataset_name/**/*.parquet")

Authorization
-------------

For authenticated datasets:

.. code:: python

    from daft.io import IOConfig, HTTPConfig

    io_config = IOConfig(http=HTTPConfig(bearer_token="your_token"))
    df = daft.read_parquet("hf://username/dataset_name", io_config=io_config)

Note that this will not work with standard-tier private datasets:
Huggingface does not auto-convert private datasets to Parquet format, so you will need to specify the path to the files you want to read.

.. code:: python

    df = daft.read_parquet("hf://username/my_private_dataset", io_config=io_config)  # Errors

To get around this, you can read all files using a glob pattern *(assuming they are in parquet format)*:

.. code:: python

    df = daft.read_parquet("hf://username/my_private_dataset/**/*.parquet", io_config=io_config)  # Works
33 changes: 31 additions & 2 deletions src/common/io-config/src/http.rs
@@ -4,22 +4,40 @@ use std::fmt::Formatter;
use serde::Deserialize;
use serde::Serialize;

use crate::ObfuscatedString;

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct HTTPConfig {
pub user_agent: String,
pub bearer_token: Option<ObfuscatedString>,
}

impl Default for HTTPConfig {
fn default() -> Self {
HTTPConfig {
user_agent: "daft/0.0.1".to_string(), // NOTE: Ideally we grab the version of Daft, but that requires a dependency on daft-core
bearer_token: None,
}
}
}

impl HTTPConfig {
pub fn new<S: Into<ObfuscatedString>>(bearer_token: Option<S>) -> Self {
HTTPConfig {
bearer_token: bearer_token.map(|t| t.into()),
..Default::default()
}
}
}

impl HTTPConfig {
pub fn multiline_display(&self) -> Vec<String> {
let mut v = vec![format!("user_agent = {}", self.user_agent)];
if let Some(bearer_token) = &self.bearer_token {
v.push(format!("bearer_token = {}", bearer_token));
}

v
}
}

@@ -30,6 +48,17 @@ impl Display for HTTPConfig {
"HTTPConfig
user_agent: {}",
self.user_agent,
)?;

if let Some(bearer_token) = &self.bearer_token {
write!(
f,
"
bearer_token: {}",
bearer_token
)
} else {
Ok(())
}
}
}
17 changes: 16 additions & 1 deletion src/common/io-config/src/python.rs
@@ -134,9 +134,10 @@ pub struct IOConfig {
///
/// Args:
/// user_agent (str, optional): The value for the user-agent header, defaults to "daft/{__version__}" if not provided
/// bearer_token (str, optional): Bearer token to use for authentication. This will be sent as the value of the `Authorization` header, e.g. "Authorization: Bearer xxx"
///
/// Example:
/// >>> io_config = IOConfig(http=HTTPConfig(user_agent="my_application/0.0.1", bearer_token="xxx"))
/// >>> daft.read_parquet("http://some-path", io_config=io_config)
#[derive(Clone, Default)]
#[pyclass]
@@ -901,6 +902,20 @@ impl From<config::IOConfig> for IOConfig {
}
}

#[pymethods]
impl HTTPConfig {
#[new]
pub fn new(bearer_token: Option<String>) -> Self {
HTTPConfig {
config: crate::HTTPConfig::new(bearer_token),
}
}

pub fn __repr__(&self) -> PyResult<String> {
Ok(format!("{}", self.config))
}
}

pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_class::<AzureConfig>()?;
parent.add_class::<GCSConfig>()?;
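Since `__repr__` delegates to the Rust `Display` impl, the Python-side rendering looks roughly like the following (a sketch; the exact token line depends on how `ObfuscatedString` displays itself, which redacts rather than echoes the secret):

```py
from daft.io import HTTPConfig

config = HTTPConfig(bearer_token="hf_xxx")
print(repr(config))
# HTTPConfig
#     user_agent: daft/0.0.1
#     bearer_token: <obfuscated>   # token is not printed verbatim
```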
11 changes: 9 additions & 2 deletions src/daft-io/Cargo.toml
@@ -15,6 +15,7 @@ azure_storage_blobs = {version = "0.17.0", features = ["enable_reqwest"], defaul
bytes = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-io-config = {path = "../common/io-config", default-features = false}
common-py-serde = {path = "../common/py-serde", default-features = false}
futures = {workspace = true}
globset = "0.4"
google-cloud-storage = {version = "0.15.0", default-features = false, features = ["default-tls", "auth"]}
@@ -29,22 +30,28 @@ openssl-sys = {version = "0.9.102", features = ["vendored"]}
pyo3 = {workspace = true, optional = true}
rand = "0.8.5"
regex = {version = "1.10.4"}
serde = {workspace = true}
snafu = {workspace = true}
tokio = {workspace = true}
tokio-stream = {workspace = true}
url = {workspace = true}

[dependencies.reqwest]
default-features = false
features = ["stream", "native-tls"]
features = ["stream", "native-tls", "json"]
version = "0.11.18"

[dev-dependencies]
md5 = "0.7.0"
tempfile = "3.8.1"

[features]
python = ["dep:pyo3", "common-error/python", "common-io-config/python"]
python = [
"dep:pyo3",
"common-error/python",
"common-io-config/python",
"common-py-serde/python"
]

[package]
edition = {workspace = true}
3 changes: 2 additions & 1 deletion src/daft-io/src/azure_blob.rs
@@ -15,7 +15,7 @@ use crate::{
object_io::{FileMetadata, FileType, LSResult, ObjectSource},
stats::IOStatsRef,
stream_utils::io_stats_on_bytestream,
GetResult,
FileFormat, GetResult,
};
use common_io_config::AzureConfig;

@@ -577,6 +577,7 @@ impl ObjectSource for AzureBlobSource {
page_size: Option<i32>,
limit: Option<usize>,
io_stats: Option<IOStatsRef>,
_file_format: Option<FileFormat>,
) -> super::Result<BoxStream<'static, super::Result<FileMetadata>>> {
use crate::object_store_glob::glob;

58 changes: 58 additions & 0 deletions src/daft-io/src/file_format.rs
@@ -0,0 +1,58 @@
use std::str::FromStr;

use common_error::{DaftError, DaftResult};
use common_py_serde::impl_bincode_py_state_serialization;
#[cfg(feature = "python")]
use pyo3::prelude::*;

use serde::{Deserialize, Serialize};

/// Format of a file, e.g. Parquet, CSV, JSON.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Copy)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
pub enum FileFormat {
Parquet,
Csv,
Json,
Database,
Python,
}

#[cfg(feature = "python")]
#[pymethods]
impl FileFormat {
fn ext(&self) -> &'static str {
match self {
Self::Parquet => "parquet",
Self::Csv => "csv",
Self::Json => "json",
Self::Database => "db",
Self::Python => "py",
}
}
}

impl FromStr for FileFormat {
type Err = DaftError;

fn from_str(file_format: &str) -> DaftResult<Self> {
use FileFormat::*;

if file_format.trim().eq_ignore_ascii_case("parquet") {
Ok(Parquet)
} else if file_format.trim().eq_ignore_ascii_case("csv") {
Ok(Csv)
} else if file_format.trim().eq_ignore_ascii_case("json") {
Ok(Json)
} else if file_format.trim().eq_ignore_ascii_case("database") {
Ok(Database)
} else {
Err(DaftError::TypeError(format!(
"FileFormat {} not supported!",
file_format
)))
}
}
}

impl_bincode_py_state_serialization!(FileFormat);
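The `ext()` helper above is exposed to Python via `#[pymethods]`; assuming `FileFormat` is registered in the `daft.daft` module (the `pyclass(module = "daft.daft")` attribute suggests so, though the registration itself is not shown in this diff), usage would look like:

```py
from daft.daft import FileFormat  # assumption: registered alongside the other classes

# Map an enum variant to its conventional file extension
assert FileFormat.Parquet.ext() == "parquet"
assert FileFormat.Csv.ext() == "csv"
```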
2 changes: 2 additions & 0 deletions src/daft-io/src/google_cloud.rs
@@ -23,6 +23,7 @@ use crate::object_io::LSResult;
use crate::object_io::ObjectSource;
use crate::stats::IOStatsRef;
use crate::stream_utils::io_stats_on_bytestream;
use crate::FileFormat;
use crate::GetResult;
use common_io_config::GCSConfig;

@@ -436,6 +437,7 @@ impl ObjectSource for GCSSource {
page_size: Option<i32>,
limit: Option<usize>,
io_stats: Option<IOStatsRef>,
_file_format: Option<FileFormat>,
) -> super::Result<BoxStream<'static, super::Result<FileMetadata>>> {
use crate::object_store_glob::glob;

4 changes: 3 additions & 1 deletion src/daft-io/src/http.rs
@@ -15,6 +15,7 @@ use crate::{
object_io::{FileMetadata, FileType, LSResult},
stats::IOStatsRef,
stream_utils::io_stats_on_bytestream,
FileFormat,
};

use super::object_io::{GetResult, ObjectSource};
@@ -140,7 +141,7 @@ fn _get_file_metadata_from_html(path: &str, text: &str) -> super::Result<Vec<Fil
}

pub(crate) struct HttpSource {
pub(crate) client: reqwest::Client,
}

impl From<Error> for super::Error {
@@ -276,6 +277,7 @@ impl ObjectSource for HttpSource {
_page_size: Option<i32>,
limit: Option<usize>,
io_stats: Option<IOStatsRef>,
_file_format: Option<FileFormat>,
) -> super::Result<BoxStream<'static, super::Result<FileMetadata>>> {
use crate::object_store_glob::glob;
