diff --git a/daft/daft.pyi b/daft/daft.pyi index bc502e26e7..a9723d934a 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -428,6 +428,7 @@ class S3Config: check_hostname_ssl: bool requester_pays: bool | None force_virtual_addressing: bool | None + profile_name: str | None def __init__( self, @@ -448,6 +449,7 @@ class S3Config: check_hostname_ssl: bool | None = None, requester_pays: bool | None = None, force_virtual_addressing: bool | None = None, + profile_name: str | None = None, ): ... def replace( self, @@ -468,6 +470,7 @@ class S3Config: check_hostname_ssl: bool | None = None, requester_pays: bool | None = None, force_virtual_addressing: bool | None = None, + profile_name: str | None = None, ) -> S3Config: """Replaces values if provided, returning a new S3Config""" ... diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index a24cf2d443..ebc13a96b6 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -26,6 +26,7 @@ use crate::config; /// check_hostname_ssl: Whether or not to verify the hostname when verifying ssl certificates, this was the legacy behavior for openssl, defaults to True /// requester_pays: Whether or not the authenticated user will assume transfer costs, which is required by some providers of bulk data, defaults to False /// force_virtual_addressing: Force S3 client to use virtual addressing in all cases. 
If False, virtual addressing will only be used if `endpoint_url` is empty, defaults to False +/// profile_name: Name of the AWS profile to load, defaults to None, in which case the `AWS_PROFILE` environment variable is checked before falling back to the `default` profile /// /// Example: /// >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx")) @@ -188,6 +189,7 @@ impl S3Config { check_hostname_ssl: Option, requester_pays: Option, force_virtual_addressing: Option, + profile_name: Option, ) -> Self { let def = crate::S3Config::default(); S3Config { @@ -212,6 +214,7 @@ impl S3Config { requester_pays: requester_pays.unwrap_or(def.requester_pays), force_virtual_addressing: force_virtual_addressing .unwrap_or(def.force_virtual_addressing), + profile_name: profile_name.or(def.profile_name), }, } } @@ -236,6 +239,7 @@ impl S3Config { check_hostname_ssl: Option, requester_pays: Option, force_virtual_addressing: Option, + profile_name: Option, ) -> Self { S3Config { config: crate::S3Config { @@ -259,6 +263,7 @@ impl S3Config { requester_pays: requester_pays.unwrap_or(self.config.requester_pays), force_virtual_addressing: force_virtual_addressing .unwrap_or(self.config.force_virtual_addressing), + profile_name: profile_name.or_else(|| self.config.profile_name.clone()), }, } } @@ -383,6 +388,12 @@ impl S3Config { pub fn force_virtual_addressing(&self) -> PyResult> { Ok(Some(self.config.force_virtual_addressing)) } + + /// AWS profile name + #[getter] + pub fn profile_name(&self) -> PyResult> { + Ok(self.config.profile_name.clone()) + } } #[pymethods] diff --git a/src/common/io-config/src/s3.rs b/src/common/io-config/src/s3.rs index 0f66d0285b..f7efdf3607 100644 --- a/src/common/io-config/src/s3.rs +++ b/src/common/io-config/src/s3.rs @@ -23,6 +23,7 @@ pub struct S3Config { pub check_hostname_ssl: bool, pub requester_pays: bool, pub force_virtual_addressing: bool, + pub profile_name: Option, } impl S3Config { @@ -66,6 +67,9 @@ impl S3Config { "Force Virtual Addressing = {}", 
self.force_virtual_addressing )); + if let Some(name) = &self.profile_name { + res.push(format!("Profile Name = {}", name)); + } res } } @@ -92,6 +96,7 @@ impl Default for S3Config { check_hostname_ssl: true, requester_pays: false, force_virtual_addressing: false, + profile_name: None, } } } @@ -117,7 +122,8 @@ impl Display for S3Config { verify_ssl: {}, check_hostname_ssl: {} requester_pays: {} - force_virtual_addressing: {}", + force_virtual_addressing: {} + profile_name: {:?}", self.region_name, self.endpoint_url, self.key_id, @@ -134,7 +140,8 @@ impl Display for S3Config { self.verify_ssl, self.check_hostname_ssl, self.requester_pays, - self.force_virtual_addressing + self.force_virtual_addressing, + self.profile_name ) } } diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 7eec342526..e582ee6493 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -375,6 +375,12 @@ async fn build_s3_conf( builder.build() } else { let loader = aws_config::from_env(); + let loader = if let Some(profile_name) = &config.profile_name { + loader.profile_name(profile_name) + } else { + loader + }; + // Set region now to avoid imds let loader = if let Some(region) = &config.region_name { loader.region(Region::new(region.to_owned())) @@ -396,7 +402,6 @@ async fn build_s3_conf( None => builder, Some(endpoint) => builder.endpoint_url(endpoint), }; - let builder = if config.endpoint_url.is_some() && !config.force_virtual_addressing { builder.force_path_style(true) } else {