diff --git a/daft/daft.pyi b/daft/daft.pyi index a33dc7cf38..2c44699ecc 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -472,6 +472,11 @@ class S3Config: """Replaces values if provided, returning a new S3Config""" ... + @staticmethod + def from_env() -> S3Config: + """Creates an S3Config, retrieving credentials and configurations from the current environtment""" + ... + class AzureConfig: """ I/O configuration for accessing Azure Blob Storage. @@ -530,6 +535,8 @@ class IOConfig: """ Recreate an IOConfig from a JSON string. """ + ... + def replace( self, s3: S3Config | None = None, azure: AzureConfig | None = None, gcs: GCSConfig | None = None ) -> IOConfig: diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index ecfeb19c25..a24cf2d443 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -263,6 +263,21 @@ impl S3Config { } } + /// Creates an S3Config from the current environment, auto-discovering variables such as + /// credentials, regions and more. + #[staticmethod] + pub fn from_env(py: Python) -> PyResult { + let io_config_from_env_func = py + .import(pyo3::intern!(py, "daft"))? + .getattr(pyo3::intern!(py, "daft"))? + .getattr(pyo3::intern!(py, "s3_config_from_env"))?; + io_config_from_env_func.call0().map(|pyany| { + pyany + .extract() + .expect("s3_config_from_env function must return S3Config") + }) + } + pub fn __repr__(&self) -> PyResult { Ok(format!("{}", self.config)) } diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs index b85140b4b4..3ce4b59ba3 100644 --- a/src/daft-io/src/python.rs +++ b/src/daft-io/src/python.rs @@ -2,7 +2,7 @@ pub use common_io_config::python::{AzureConfig, GCSConfig, IOConfig}; pub use py::register_modules; mod py { - use crate::{get_io_client, get_runtime, parse_url, stats::IOStatsContext}; + use crate::{get_io_client, get_runtime, parse_url, s3_like, stats::IOStatsContext}; use common_error::DaftResult; use futures::TryStreamExt; use pyo3::{ @@ -66,11 +66,24 @@ mod py { Ok(crate::set_io_pool_num_threads(num_threads as usize)) } + /// Creates an S3Config from the current environment, auto-discovering variables such as + /// credentials, regions and more. + #[pyfunction] + fn s3_config_from_env(py: Python) -> PyResult { + let s3_config: DaftResult = py.allow_threads(|| { + let runtime = get_runtime(false)?; + let runtime_handle = runtime.handle(); + let _rt_guard = runtime_handle.enter(); + runtime_handle.block_on(async { Ok(s3_like::s3_config_from_env().await?) }) + }); + Ok(common_io_config::python::S3Config { config: s3_config? }) + } + pub fn register_modules(py: Python, parent: &PyModule) -> PyResult<()> { common_io_config::python::register_modules(py, parent)?; parent.add_function(wrap_pyfunction!(io_glob, parent)?)?; parent.add_function(wrap_pyfunction!(set_io_pool_num_threads, parent)?)?; - + parent.add_function(wrap_pyfunction!(s3_config_from_env, parent)?)?; Ok(()) } } diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 1a143f9040..0321547e4b 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -196,6 +196,33 @@ impl From for super::Error { } } +/// Retrieves an S3Config from the environment by leveraging the AWS SDK's credentials chain +pub(crate) async fn s3_config_from_env() -> super::Result { + let default_s3_config = S3Config::default(); + let (anonymous, s3_conf) = build_s3_conf(&default_s3_config, None).await?; + let creds = s3_conf + .credentials_cache() + .provide_cached_credentials() + .await + .with_context(|_| UnableToLoadCredentialsSnafu {})?; + let key_id = Some(creds.access_key_id().to_string()); + let access_key = Some(creds.secret_access_key().to_string()); + let session_token = creds.session_token().map(|t| t.to_string()); + let region_name = s3_conf.region().map(|r| r.to_string()); + Ok(S3Config { + // Do not perform auto-discovery of endpoint_url. This is possible, but requires quite a bit + // of work that our current implementation of `build_s3_conf` does not yet do. See smithy-rs code: + // https://github.com/smithy-lang/smithy-rs/blob/94ecd38c2518583042796b2b45c37947237e31dd/aws/rust-runtime/aws-config/src/lib.rs#L824-L849 + endpoint_url: None, + region_name, + key_id, + session_token, + access_key, + anonymous, + ..default_s3_config + }) +} + /// Helper to parse S3 URLs, returning (scheme, bucket, key) fn parse_url(uri: &str) -> super::Result<(String, String, String)> { let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; @@ -247,10 +274,10 @@ fn handle_https_client_settings( Ok(builder) } -async fn build_s3_client( +async fn build_s3_conf( config: &S3Config, credentials_cache: Option, -) -> super::Result<(bool, s3::Client)> { +) -> super::Result<(bool, s3::Config)> { const DEFAULT_REGION: Region = Region::from_static("us-east-1"); let mut anonymous = config.anonymous; @@ -405,6 +432,15 @@ async fn build_s3_client( } else { s3_conf }; + + Ok((anonymous, s3_conf)) +} + +async fn build_s3_client( + config: &S3Config, + credentials_cache: Option, +) -> super::Result<(bool, s3::Client)> { + let (anonymous, s3_conf) = build_s3_conf(config, credentials_cache).await?; Ok((anonymous, s3::Client::from_conf(s3_conf))) }