Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] NAC rate limits requests #2632

Merged
merged 3 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 145 additions & 6 deletions rust/storage/src/admissioncontrolleds3.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
config::StorageConfig,
config::{CountBasedPolicyConfig, RateLimitingConfig, StorageConfig},
s3::{S3GetError, S3PutError, S3Storage, StorageConfigError},
stream::ByteStreamItem,
};
Expand All @@ -10,6 +10,7 @@ use futures::{future::Shared, FutureExt, Stream};
use parking_lot::Mutex;
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
use thiserror::Error;
use tokio::sync::{Semaphore, SemaphorePermit};
use tracing::{Instrument, Span};

/// Wrapper over s3 storage that provides proxy features such as
Expand Down Expand Up @@ -37,6 +38,7 @@ pub struct AdmissionControlledS3Storage {
>,
>,
>,
rate_limiter: Arc<RateLimitPolicy>,
}

#[derive(Error, Debug, Clone)]
Expand All @@ -54,10 +56,19 @@ impl ChromaError for AdmissionControlledS3StorageError {
}

impl AdmissionControlledS3Storage {
pub fn new(storage: S3Storage) -> Self {
pub fn new_with_default_policy(storage: S3Storage) -> Self {
Self {
storage,
outstanding_requests: Arc::new(Mutex::new(HashMap::new())),
rate_limiter: Arc::new(RateLimitPolicy::CountBasedPolicy(CountBasedPolicy::new(15))),
}
}

/// Construct an admission-controlled S3 storage that throttles requests
/// according to the supplied rate-limiting `policy`.
pub fn new(storage: S3Storage, policy: RateLimitPolicy) -> Self {
    let rate_limiter = Arc::new(policy);
    let outstanding_requests = Arc::new(Mutex::new(HashMap::new()));
    Self {
        storage,
        outstanding_requests,
        rate_limiter,
    }
}

Expand Down Expand Up @@ -106,12 +117,17 @@ impl AdmissionControlledS3Storage {
&self,
key: String,
) -> Result<Arc<Vec<u8>>, AdmissionControlledS3StorageError> {
// If there is a duplicate request and the original request finishes
// before we look it up in the map below then we will end up with another
// request to S3. We rely on synchronization on the cache
// by the upstream consumer to make sure that this works correctly.
let future_to_await;
let is_dupe: bool;
{
let mut requests = self.outstanding_requests.lock();
let maybe_inflight = requests.get(&key).map(|fut| fut.clone());
future_to_await = match maybe_inflight {
Some(fut) => fut,
(future_to_await, is_dupe) = match maybe_inflight {
Some(fut) => (fut, true),
None => {
let get_storage_future = AdmissionControlledS3Storage::read_from_storage(
self.storage.clone(),
Expand All @@ -120,16 +136,25 @@ impl AdmissionControlledS3Storage {
.boxed()
.shared();
requests.insert(key.clone(), get_storage_future.clone());
get_storage_future
(get_storage_future, false)
}
};
}

// Acquire permit.
let permit: SemaphorePermit<'_>;
if is_dupe {
permit = self.rate_limiter.enter().await;
}

let res = future_to_await.await;
{
let mut requests = self.outstanding_requests.lock();
requests.remove(&key);
}

res
// Permit gets dropped here since it is RAII.
}

pub async fn put_file(&self, key: &str, path: &str) -> Result<(), S3PutError> {
Expand All @@ -149,11 +174,125 @@ impl Configurable<StorageConfig> for AdmissionControlledS3Storage {
let s3_storage =
S3Storage::try_from_config(&StorageConfig::S3(nacconfig.s3_config.clone()))
.await?;
return Ok(Self::new(s3_storage));
let policy =
RateLimitPolicy::try_from_config(&nacconfig.rate_limiting_policy).await?;
return Ok(Self::new(s3_storage, policy));
}
_ => {
return Err(Box::new(StorageConfigError::InvalidStorageConfig));
}
}
}
}

// Prefer enum dispatch over dyn since there could
sanketkedia marked this conversation as resolved.
Show resolved Hide resolved
// only be a handful of these policies.
#[derive(Debug)]
enum RateLimitPolicy {
sanketkedia marked this conversation as resolved.
Show resolved Hide resolved
CountBasedPolicy(CountBasedPolicy),
}

impl RateLimitPolicy {
async fn enter(&self) -> SemaphorePermit<'_> {
match self {
RateLimitPolicy::CountBasedPolicy(policy) => {
return policy.acquire().await;
}
}
}
}

#[derive(Debug)]
struct CountBasedPolicy {
max_allowed_outstanding: usize,
remaining_tokens: Semaphore,
}

impl CountBasedPolicy {
fn new(max_allowed_outstanding: usize) -> Self {
Self {
max_allowed_outstanding,
remaining_tokens: Semaphore::new(max_allowed_outstanding),
}
}
async fn acquire(&self) -> SemaphorePermit<'_> {
let token_res = self.remaining_tokens.acquire().await;
match token_res {
Ok(token) => {
return token;
}
Err(e) => panic!("AcquireToken Failed {}", e),
}
}
}

#[async_trait]
impl Configurable<RateLimitingConfig> for RateLimitPolicy {
async fn try_from_config(config: &RateLimitingConfig) -> Result<Self, Box<dyn ChromaError>> {
match &config {
RateLimitingConfig::CountBasedPolicy(count_policy) => {
return Ok(RateLimitPolicy::CountBasedPolicy(CountBasedPolicy::new(
count_policy.max_concurrent_requests,
)));
}
}
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use crate::{admissioncontrolleds3::AdmissionControlledS3Storage, s3::S3Storage};

// Builds an S3 client pointed at a local minio instance
// (http://127.0.0.1:9000) using well-known dev credentials.
// Integration-test only; never use these settings in production.
fn get_s3_client() -> aws_sdk_s3::Client {
// Set up credentials assuming minio is running locally
let cred = aws_sdk_s3::config::Credentials::new(
"minio",
"minio123",
None,
None,
"loaded-from-env",
);

// Set up s3 client
// force_path_style(true) is required for minio, which does not serve
// virtual-hosted-style bucket URLs.
let config = aws_sdk_s3::config::Builder::new()
.endpoint_url("http://127.0.0.1:9000".to_string())
.credentials_provider(cred)
.behavior_version_latest()
.region(aws_sdk_s3::config::Region::new("us-east-1"))
.force_path_style(true)
.build();

aws_sdk_s3::Client::from_conf(config)
}

// Round-trip test: put bytes through the admission-controlled wrapper
// (default rate-limiting policy) and read them back. Only compiled when
// the CHROMA_KUBERNETES_INTEGRATION cfg flag is set, since it needs a
// live minio endpoint.
#[tokio::test]
#[cfg(CHROMA_KUBERNETES_INTEGRATION)]
async fn test_put_get_key() {
let client = get_s3_client();

let storage = S3Storage {
bucket: "test".to_string(),
client,
upload_part_size_bytes: 1024 * 1024 * 8,
};
storage.create_bucket().await.unwrap();
let admission_controlled_storage =
AdmissionControlledS3Storage::new_with_default_policy(storage);

let test_data = "test data";
admission_controlled_storage
.put_bytes("test", test_data.as_bytes().to_vec())
.await
.unwrap();

let buf = admission_controlled_storage
.get("test".to_string())
.await
.unwrap();

// Arc::unwrap_or_clone avoids copying the buffer when this is the
// only reference to it.
let buf = String::from_utf8(Arc::unwrap_or_clone(buf)).unwrap();
assert_eq!(buf, test_data);
}
}
11 changes: 11 additions & 0 deletions rust/storage/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,15 @@ pub struct LocalStorageConfig {
/// Configuration for the admission-controlled (rate-limited) S3 storage.
#[derive(Deserialize, Debug, Clone)]
pub struct AdmissionControlledS3StorageConfig {
/// Connection settings for the underlying S3 storage.
pub s3_config: S3StorageConfig,
/// Policy governing how requests to S3 are rate limited.
pub rate_limiting_policy: RateLimitingConfig,
}

/// Configuration for a count-based rate-limiting policy.
#[derive(Deserialize, Debug, Clone)]
pub struct CountBasedPolicyConfig {
/// Maximum number of S3 requests allowed to be outstanding at once.
pub max_concurrent_requests: usize,
}

/// Selects which rate-limiting policy to use. Deserialized with serde's
/// default externally-tagged enum representation, e.g. in YAML:
/// `CountBasedPolicy: { max_concurrent_requests: 15 }`.
#[derive(Deserialize, Debug, Clone)]
pub enum RateLimitingConfig {
CountBasedPolicy(CountBasedPolicyConfig),
}
8 changes: 4 additions & 4 deletions rust/storage/src/s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ use tracing::Span;

#[derive(Clone)]
pub struct S3Storage {
bucket: String,
client: aws_sdk_s3::Client,
upload_part_size_bytes: usize,
pub(super) bucket: String,
pub(super) client: aws_sdk_s3::Client,
pub(super) upload_part_size_bytes: usize,
}

#[derive(Error, Debug)]
Expand Down Expand Up @@ -84,7 +84,7 @@ impl S3Storage {
};
}

async fn create_bucket(&self) -> Result<(), String> {
pub(super) async fn create_bucket(&self) -> Result<(), String> {
// Creates a public bucket with default settings in the region.
// This should only be used for testing and in production
// the bucket should be provisioned ahead of time.
Expand Down
6 changes: 6 additions & 0 deletions rust/worker/chroma_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ query_service:
connect_timeout_ms: 5000
request_timeout_ms: 30000 # 1 minute
upload_part_size_bytes: 536870912 # 512MiB
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "logservice.chroma"
Expand Down Expand Up @@ -84,6 +87,9 @@ compaction_service:
connect_timeout_ms: 5000
request_timeout_ms: 60000 # 1 minute
upload_part_size_bytes: 536870912 # 512MiB
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "logservice.chroma"
Expand Down
21 changes: 21 additions & 0 deletions rust/worker/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down Expand Up @@ -225,6 +228,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down Expand Up @@ -305,6 +311,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down Expand Up @@ -359,6 +368,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down Expand Up @@ -457,6 +469,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down Expand Up @@ -511,6 +526,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down Expand Up @@ -607,6 +625,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down
Loading