Skip to content

Commit

Permalink
[FEAT] AWS Profile override in S3Config (#2243)
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 authored May 7, 2024
1 parent 0103397 commit adbdceb
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 3 deletions.
3 changes: 3 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ class S3Config:
check_hostname_ssl: bool
requester_pays: bool | None
force_virtual_addressing: bool | None
profile_name: str | None

def __init__(
self,
Expand All @@ -448,6 +449,7 @@ class S3Config:
check_hostname_ssl: bool | None = None,
requester_pays: bool | None = None,
force_virtual_addressing: bool | None = None,
profile_name: str | None = None,
): ...
def replace(
self,
Expand All @@ -468,6 +470,7 @@ class S3Config:
check_hostname_ssl: bool | None = None,
requester_pays: bool | None = None,
force_virtual_addressing: bool | None = None,
profile_name: str | None = None,
) -> S3Config:
"""Replaces values if provided, returning a new S3Config"""
...
Expand Down
11 changes: 11 additions & 0 deletions src/common/io-config/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::config;
/// check_hostname_ssl: Whether or not to verify the hostname when verifying ssl certificates, this was the legacy behavior for openssl, defaults to True
/// requester_pays: Whether or not the authenticated user will assume transfer costs, which is required by some providers of bulk data, defaults to False
/// force_virtual_addressing: Force S3 client to use virtual addressing in all cases. If False, virtual addressing will only be used if `endpoint_url` is empty, defaults to False
/// profile_name: Name of AWS_PROFILE to load, defaults to None which will then check the Environment Variable `AWS_PROFILE` then fall back to `default`
///
/// Example:
/// >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx"))
Expand Down Expand Up @@ -188,6 +189,7 @@ impl S3Config {
check_hostname_ssl: Option<bool>,
requester_pays: Option<bool>,
force_virtual_addressing: Option<bool>,
profile_name: Option<String>,
) -> Self {
let def = crate::S3Config::default();
S3Config {
Expand All @@ -212,6 +214,7 @@ impl S3Config {
requester_pays: requester_pays.unwrap_or(def.requester_pays),
force_virtual_addressing: force_virtual_addressing
.unwrap_or(def.force_virtual_addressing),
profile_name: profile_name.or(def.profile_name),
},
}
}
Expand All @@ -236,6 +239,7 @@ impl S3Config {
check_hostname_ssl: Option<bool>,
requester_pays: Option<bool>,
force_virtual_addressing: Option<bool>,
profile_name: Option<String>,
) -> Self {
S3Config {
config: crate::S3Config {
Expand All @@ -259,6 +263,7 @@ impl S3Config {
requester_pays: requester_pays.unwrap_or(self.config.requester_pays),
force_virtual_addressing: force_virtual_addressing
.unwrap_or(self.config.force_virtual_addressing),
profile_name: profile_name.or_else(|| self.config.profile_name.clone()),
},
}
}
Expand Down Expand Up @@ -383,6 +388,12 @@ impl S3Config {
pub fn force_virtual_addressing(&self) -> PyResult<Option<bool>> {
Ok(Some(self.config.force_virtual_addressing))
}

/// AWS profile name
#[getter]
pub fn profile_name(&self) -> PyResult<Option<String>> {
Ok(self.config.profile_name.clone())
}
}

#[pymethods]
Expand Down
11 changes: 9 additions & 2 deletions src/common/io-config/src/s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub struct S3Config {
pub check_hostname_ssl: bool,
pub requester_pays: bool,
pub force_virtual_addressing: bool,
pub profile_name: Option<String>,
}

impl S3Config {
Expand Down Expand Up @@ -66,6 +67,9 @@ impl S3Config {
"Force Virtual Addressing = {}",
self.force_virtual_addressing
));
if let Some(name) = &self.profile_name {
res.push(format!("Profile Name = {}", name));
}
res
}
}
Expand All @@ -92,6 +96,7 @@ impl Default for S3Config {
check_hostname_ssl: true,
requester_pays: false,
force_virtual_addressing: false,
profile_name: None,
}
}
}
Expand All @@ -117,7 +122,8 @@ impl Display for S3Config {
verify_ssl: {},
check_hostname_ssl: {}
requester_pays: {}
force_virtual_addressing: {}",
force_virtual_addressing: {}
profile_name: {:?}",
self.region_name,
self.endpoint_url,
self.key_id,
Expand All @@ -134,7 +140,8 @@ impl Display for S3Config {
self.verify_ssl,
self.check_hostname_ssl,
self.requester_pays,
self.force_virtual_addressing
self.force_virtual_addressing,
self.profile_name
)
}
}
7 changes: 6 additions & 1 deletion src/daft-io/src/s3_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,12 @@ async fn build_s3_conf(
builder.build()
} else {
let loader = aws_config::from_env();
let loader = if let Some(profile_name) = &config.profile_name {
loader.profile_name(profile_name)
} else {
loader
};

// Set region now to avoid imds
let loader = if let Some(region) = &config.region_name {
loader.region(Region::new(region.to_owned()))
Expand All @@ -396,7 +402,6 @@ async fn build_s3_conf(
None => builder,
Some(endpoint) => builder.endpoint_url(endpoint),
};

let builder = if config.endpoint_url.is_some() && !config.force_virtual_addressing {
builder.force_path_style(true)
} else {
Expand Down

0 comments on commit adbdceb

Please sign in to comment.