Skip to content

Commit

Permalink
[PERF] Use region from system and leverage cached credentials when ma…
Browse files Browse the repository at this point in the history
…king new clients (#1490)

* Fixes 2 issues: 
* We always set to default region even when the system can provide us
creds
* We reran the credential chain even though we can leverage a cache
  • Loading branch information
samster25 authored Oct 13, 2023
1 parent 84fcc7f commit fac11a4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/daft-io/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![feature(async_closure)]
#![feature(let_chains)]

mod azure_blob;
mod google_cloud;
mod http;
Expand Down
39 changes: 32 additions & 7 deletions src/daft-io/src/s3_like.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use async_trait::async_trait;
use aws_config::meta::credentials::CredentialsProviderChain;
use aws_config::retry::RetryMode;
use aws_config::timeout::TimeoutConfig;
use aws_smithy_async::rt::sleep::TokioSleep;
Expand All @@ -11,7 +12,7 @@ use tokio::sync::{OwnedSemaphorePermit, SemaphorePermit};
use crate::object_io::{FileMetadata, FileType, LSResult};
use crate::{get_io_pool_num_threads, InvalidArgumentSnafu, SourceType};
use aws_config::SdkConfig;
use aws_credential_types::cache::ProvideCachedCredentials;
use aws_credential_types::cache::{ProvideCachedCredentials, SharedCredentialsCache};
use aws_credential_types::provider::error::CredentialsError;
use aws_sig_auth::signer::SigningRequirements;
use common_io_config::S3Config;
Expand Down Expand Up @@ -211,7 +212,10 @@ fn handle_https_client_settings(
Ok(builder)
}

async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)> {
async fn build_s3_client(
config: &S3Config,
credentials_cache: Option<SharedCredentialsCache>,
) -> super::Result<(bool, s3::Client)> {
const DEFAULT_REGION: Region = Region::from_static("us-east-1");

let mut anonymous = config.anonymous;
Expand All @@ -228,8 +232,6 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)>
};
let builder = if let Some(region) = &config.region_name {
builder.region(Region::new(region.to_owned()))
} else if conf.region().is_none() && config.region_name.is_none() {
builder.region(DEFAULT_REGION)
} else {
builder
};
Expand Down Expand Up @@ -268,7 +270,19 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)>
.build();
let builder = builder.timeout_config(timeout_config);

let builder = if config.access_key.is_some() && config.key_id.is_some() {
let cached_creds = if let Some(credentials_cache) = credentials_cache {
let creds = credentials_cache.provide_cached_credentials().await;
creds.ok()
} else {
None
};

let builder = if let Some(cached_creds) = cached_creds {
let provider = CredentialsProviderChain::first_try("different_region_cache", cached_creds)
.or_default_provider()
.await;
builder.credentials_provider(provider)
} else if config.access_key.is_some() && config.key_id.is_some() {
let creds = Credentials::from_keys(
config.key_id.clone().unwrap(),
config.access_key.clone().unwrap(),
Expand All @@ -283,6 +297,7 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)>
builder
};

let builder_copy = builder.clone();
let s3_conf = builder.build();
if !config.anonymous {
use CredentialsError::*;
Expand All @@ -300,11 +315,16 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)>
}.with_context(|_| UnableToLoadCredentialsSnafu {})?;
};

let s3_conf = if s3_conf.region().is_none() {
builder_copy.region(DEFAULT_REGION).build()
} else {
s3_conf
};
Ok((anonymous, s3::Client::from_conf(s3_conf)))
}

async fn build_client(config: &S3Config) -> super::Result<S3LikeSource> {
let (anonymous, client) = build_s3_client(config).await?;
let (anonymous, client) = build_s3_client(config, None).await?;
let mut client_map = HashMap::new();
let default_region = client.conf().region().unwrap().clone();
client_map.insert(default_region.clone(), client.into());
Expand Down Expand Up @@ -343,7 +363,12 @@ impl S3LikeSource {

let mut new_config = self.s3_config.clone();
new_config.region_name = Some(region.to_string());
let (_, new_client) = build_s3_client(&new_config).await?;

let creds_cache = w_handle
.get(&self.default_region)
.map(|current_client| current_client.conf().credentials_cache());

let (_, new_client) = build_s3_client(&new_config, creds_cache).await?;

if w_handle.get(region).is_none() {
w_handle.insert(region.clone(), new_client.into());
Expand Down

0 comments on commit fac11a4

Please sign in to comment.