Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add S3Config.from_env functionality #2137

Merged
merged 12 commits into from
Apr 16, 2024
7 changes: 7 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,11 @@ class S3Config:
"""Replaces values if provided, returning a new S3Config"""
...

@staticmethod
def from_env() -> S3Config:
"""Creates an S3Config, retrieving credentials and configurations from the current environtment"""
...

class AzureConfig:
"""
I/O configuration for accessing Azure Blob Storage.
Expand Down Expand Up @@ -530,6 +535,8 @@ class IOConfig:
"""
Recreate an IOConfig from a JSON string.
"""
...

def replace(
self, s3: S3Config | None = None, azure: AzureConfig | None = None, gcs: GCSConfig | None = None
) -> IOConfig:
Expand Down
15 changes: 15 additions & 0 deletions src/common/io-config/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,21 @@ impl S3Config {
}
}

/// Creates an S3Config from the current environment, auto-discovering variables such as
/// credentials, regions and more.
#[staticmethod]
pub fn from_env(py: Python) -> PyResult<Self> {
let io_config_from_env_func = py
.import(pyo3::intern!(py, "daft"))?
.getattr(pyo3::intern!(py, "daft"))?
.getattr(pyo3::intern!(py, "s3_config_from_env"))?;
io_config_from_env_func.call0().map(|pyany| {
pyany
.extract()
.expect("s3_config_from_env function must return S3Config")
})
}

pub fn __repr__(&self) -> PyResult<String> {
Ok(format!("{}", self.config))
}
Expand Down
17 changes: 15 additions & 2 deletions src/daft-io/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub use common_io_config::python::{AzureConfig, GCSConfig, IOConfig};
pub use py::register_modules;

mod py {
use crate::{get_io_client, get_runtime, parse_url, stats::IOStatsContext};
use crate::{get_io_client, get_runtime, parse_url, s3_like, stats::IOStatsContext};
use common_error::DaftResult;
use futures::TryStreamExt;
use pyo3::{
Expand Down Expand Up @@ -66,11 +66,24 @@ mod py {
Ok(crate::set_io_pool_num_threads(num_threads as usize))
}

/// Creates an S3Config from the current environment, auto-discovering variables such as
/// credentials, regions and more.
#[pyfunction]
fn s3_config_from_env(py: Python) -> PyResult<common_io_config::python::S3Config> {
let s3_config: DaftResult<common_io_config::S3Config> = py.allow_threads(|| {
let runtime = get_runtime(false)?;
let runtime_handle = runtime.handle();
let _rt_guard = runtime_handle.enter();
runtime_handle.block_on(async { Ok(s3_like::s3_config_from_env().await?) })
});
Ok(common_io_config::python::S3Config { config: s3_config? })
}

pub fn register_modules(py: Python, parent: &PyModule) -> PyResult<()> {
common_io_config::python::register_modules(py, parent)?;
parent.add_function(wrap_pyfunction!(io_glob, parent)?)?;
parent.add_function(wrap_pyfunction!(set_io_pool_num_threads, parent)?)?;

parent.add_function(wrap_pyfunction!(s3_config_from_env, parent)?)?;
Ok(())
}
}
40 changes: 38 additions & 2 deletions src/daft-io/src/s3_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,33 @@ impl From<Error> for super::Error {
}
}

/// Retrieves an S3Config from the environment by leveraging the AWS SDK's credentials chain
pub(crate) async fn s3_config_from_env() -> super::Result<S3Config> {
let default_s3_config = S3Config::default();
let (anonymous, s3_conf) = build_s3_conf(&default_s3_config, None).await?;
let creds = s3_conf
.credentials_cache()
.provide_cached_credentials()
.await
.with_context(|_| UnableToLoadCredentialsSnafu {})?;
let key_id = Some(creds.access_key_id().to_string());
let access_key = Some(creds.secret_access_key().to_string());
let session_token = creds.session_token().map(|t| t.to_string());
let region_name = s3_conf.region().map(|r| r.to_string());
Ok(S3Config {
// Do not perform auto-discovery of endpoint_url. This is possible, but requires quite a bit
// of work that our current implementation of `build_s3_conf` does not yet do. See smithy-rs code:
// https://github.com/smithy-lang/smithy-rs/blob/94ecd38c2518583042796b2b45c37947237e31dd/aws/rust-runtime/aws-config/src/lib.rs#L824-L849
endpoint_url: None,
region_name,
key_id,
session_token,
access_key,
anonymous,
..default_s3_config
})
}

/// Helper to parse S3 URLs, returning (scheme, bucket, key)
fn parse_url(uri: &str) -> super::Result<(String, String, String)> {
let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?;
Expand Down Expand Up @@ -247,10 +274,10 @@ fn handle_https_client_settings(
Ok(builder)
}

async fn build_s3_client(
async fn build_s3_conf(
config: &S3Config,
credentials_cache: Option<SharedCredentialsCache>,
) -> super::Result<(bool, s3::Client)> {
) -> super::Result<(bool, s3::Config)> {
const DEFAULT_REGION: Region = Region::from_static("us-east-1");

let mut anonymous = config.anonymous;
Expand Down Expand Up @@ -405,6 +432,15 @@ async fn build_s3_client(
} else {
s3_conf
};

Ok((anonymous, s3_conf))
}

async fn build_s3_client(
config: &S3Config,
credentials_cache: Option<SharedCredentialsCache>,
) -> super::Result<(bool, s3::Client)> {
let (anonymous, s3_conf) = build_s3_conf(config, credentials_cache).await?;
Ok((anonymous, s3::Client::from_conf(s3_conf)))
}

Expand Down
Loading