From a886630a52febb8fe811570aae37d2c0b9f0395e Mon Sep 17 00:00:00 2001 From: Sanket Kedia Date: Mon, 19 Aug 2024 17:41:47 -0700 Subject: [PATCH] [ENH] NAC rate limits requests (#2632) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Introduced count based rate limiting policy where a maximum of X number of outstanding are permitted. X is configurable. ## Test plan - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes None --- rust/storage/src/admissioncontrolleds3.rs | 151 +++++++++++++++++++++- rust/storage/src/config.rs | 11 ++ rust/storage/src/s3.rs | 8 +- rust/worker/chroma_config.yaml | 6 + rust/worker/src/config.rs | 21 +++ 5 files changed, 187 insertions(+), 10 deletions(-) diff --git a/rust/storage/src/admissioncontrolleds3.rs b/rust/storage/src/admissioncontrolleds3.rs index 5a81fcee689..9716b0a446f 100644 --- a/rust/storage/src/admissioncontrolleds3.rs +++ b/rust/storage/src/admissioncontrolleds3.rs @@ -1,5 +1,5 @@ use crate::{ - config::StorageConfig, + config::{CountBasedPolicyConfig, RateLimitingConfig, StorageConfig}, s3::{S3GetError, S3PutError, S3Storage, StorageConfigError}, stream::ByteStreamItem, }; @@ -10,6 +10,7 @@ use futures::{future::Shared, FutureExt, Stream}; use parking_lot::Mutex; use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc}; use thiserror::Error; +use tokio::sync::{Semaphore, SemaphorePermit}; use tracing::{Instrument, Span}; /// Wrapper over s3 storage that provides proxy features such as @@ -37,6 +38,7 @@ pub struct AdmissionControlledS3Storage { >, >, >, + rate_limiter: Arc, } #[derive(Error, Debug, Clone)] @@ -54,10 +56,19 @@ impl ChromaError for AdmissionControlledS3StorageError { } impl AdmissionControlledS3Storage { - pub fn new(storage: S3Storage) -> Self { + pub fn new_with_default_policy(storage: S3Storage) -> Self { Self { storage, outstanding_requests: Arc::new(Mutex::new(HashMap::new())), + rate_limiter: Arc::new(RateLimitPolicy::CountBasedPolicy(CountBasedPolicy::new(15))), + } + } + + pub fn new(storage: S3Storage, policy: RateLimitPolicy) -> Self { + Self { + storage, + outstanding_requests: Arc::new(Mutex::new(HashMap::new())), + rate_limiter: Arc::new(policy), } } @@ -106,12 +117,17 @@ impl AdmissionControlledS3Storage { &self, key: String, ) -> Result>, AdmissionControlledS3StorageError> { + // If there is a duplicate request and the original request finishes + // before we look it up in the map below then we will end up with another + // request to S3. We rely on synchronization on the cache + // by the upstream consumer to make sure that this works correctly. let future_to_await; + let is_dupe: bool; { let mut requests = self.outstanding_requests.lock(); let maybe_inflight = requests.get(&key).map(|fut| fut.clone()); - future_to_await = match maybe_inflight { - Some(fut) => fut, + (future_to_await, is_dupe) = match maybe_inflight { + Some(fut) => (fut, true), None => { let get_storage_future = AdmissionControlledS3Storage::read_from_storage( self.storage.clone(), @@ -120,16 +136,25 @@ impl AdmissionControlledS3Storage { .boxed() .shared(); requests.insert(key.clone(), get_storage_future.clone()); - get_storage_future + (get_storage_future, false) } }; } + + // Acquire permit. + let permit: SemaphorePermit<'_>; + if is_dupe { + permit = self.rate_limiter.enter().await; + } + let res = future_to_await.await; { let mut requests = self.outstanding_requests.lock(); requests.remove(&key); } + res + // Permit gets dropped here since it is RAII. } pub async fn put_file(&self, key: &str, path: &str) -> Result<(), S3PutError> { @@ -149,7 +174,9 @@ impl Configurable for AdmissionControlledS3Storage { let s3_storage = S3Storage::try_from_config(&StorageConfig::S3(nacconfig.s3_config.clone())) .await?; - return Ok(Self::new(s3_storage)); + let policy = + RateLimitPolicy::try_from_config(&nacconfig.rate_limiting_policy).await?; + return Ok(Self::new(s3_storage, policy)); } _ => { return Err(Box::new(StorageConfigError::InvalidStorageConfig)); @@ -157,3 +184,115 @@ impl Configurable for AdmissionControlledS3Storage { } } } + +// Prefer enum dispatch over dyn since there could +// only be a handful of these policies. +#[derive(Debug)] +enum RateLimitPolicy { + CountBasedPolicy(CountBasedPolicy), +} + +impl RateLimitPolicy { + async fn enter(&self) -> SemaphorePermit<'_> { + match self { + RateLimitPolicy::CountBasedPolicy(policy) => { + return policy.acquire().await; + } + } + } +} + +#[derive(Debug)] +struct CountBasedPolicy { + max_allowed_outstanding: usize, + remaining_tokens: Semaphore, +} + +impl CountBasedPolicy { + fn new(max_allowed_outstanding: usize) -> Self { + Self { + max_allowed_outstanding, + remaining_tokens: Semaphore::new(max_allowed_outstanding), + } + } + async fn acquire(&self) -> SemaphorePermit<'_> { + let token_res = self.remaining_tokens.acquire().await; + match token_res { + Ok(token) => { + return token; + } + Err(e) => panic!("AcquireToken Failed {}", e), + } + } +} + +#[async_trait] +impl Configurable for RateLimitPolicy { + async fn try_from_config(config: &RateLimitingConfig) -> Result> { + match &config { + RateLimitingConfig::CountBasedPolicy(count_policy) => { + return Ok(RateLimitPolicy::CountBasedPolicy(CountBasedPolicy::new( + count_policy.max_concurrent_requests, + ))); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::{admissioncontrolleds3::AdmissionControlledS3Storage, s3::S3Storage}; + + fn get_s3_client() -> aws_sdk_s3::Client { + // Set up credentials assuming minio is running locally + let cred = aws_sdk_s3::config::Credentials::new( + "minio", + "minio123", + None, + None, + "loaded-from-env", + ); + + // Set up s3 client + let config = aws_sdk_s3::config::Builder::new() + .endpoint_url("http://127.0.0.1:9000".to_string()) + .credentials_provider(cred) + .behavior_version_latest() + .region(aws_sdk_s3::config::Region::new("us-east-1")) + .force_path_style(true) + .build(); + + aws_sdk_s3::Client::from_conf(config) + } + + #[tokio::test] + #[cfg(CHROMA_KUBERNETES_INTEGRATION)] + async fn test_put_get_key() { + let client = get_s3_client(); + + let storage = S3Storage { + bucket: "test".to_string(), + client, + upload_part_size_bytes: 1024 * 1024 * 8, + }; + storage.create_bucket().await.unwrap(); + let admission_controlled_storage = + AdmissionControlledS3Storage::new_with_default_policy(storage); + + let test_data = "test data"; + admission_controlled_storage + .put_bytes("test", test_data.as_bytes().to_vec()) + .await + .unwrap(); + + let buf = admission_controlled_storage + .get("test".to_string()) + .await + .unwrap(); + + let buf = String::from_utf8(Arc::unwrap_or_clone(buf)).unwrap(); + assert_eq!(buf, test_data); + } +} diff --git a/rust/storage/src/config.rs b/rust/storage/src/config.rs index 21c7250f7c0..d87d02397bd 100644 --- a/rust/storage/src/config.rs +++ b/rust/storage/src/config.rs @@ -49,4 +49,15 @@ pub struct LocalStorageConfig { #[derive(Deserialize, Debug, Clone)] pub struct AdmissionControlledS3StorageConfig { pub s3_config: S3StorageConfig, + pub rate_limiting_policy: RateLimitingConfig, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct CountBasedPolicyConfig { + pub max_concurrent_requests: usize, +} + +#[derive(Deserialize, Debug, Clone)] +pub enum RateLimitingConfig { + CountBasedPolicy(CountBasedPolicyConfig), } diff --git a/rust/storage/src/s3.rs b/rust/storage/src/s3.rs index 3a302ac47ce..3d7d7ec5fff 100644 --- a/rust/storage/src/s3.rs +++ b/rust/storage/src/s3.rs @@ -40,9 +40,9 @@ use tracing::Span; #[derive(Clone)] pub struct S3Storage { - bucket: String, - client: aws_sdk_s3::Client, - upload_part_size_bytes: usize, + pub(super) bucket: String, + pub(super) client: aws_sdk_s3::Client, + pub(super) upload_part_size_bytes: usize, } #[derive(Error, Debug)] @@ -84,7 +84,7 @@ impl S3Storage { }; } - async fn create_bucket(&self) -> Result<(), String> { + pub(super) async fn create_bucket(&self) -> Result<(), String> { // Creates a public bucket with default settings in the region. // This should only be used for testing and in production // the bucket should be provisioned ahead of time. diff --git a/rust/worker/chroma_config.yaml b/rust/worker/chroma_config.yaml index 8b1ef98138c..eeffc25b33c 100644 --- a/rust/worker/chroma_config.yaml +++ b/rust/worker/chroma_config.yaml @@ -30,6 +30,9 @@ query_service: connect_timeout_ms: 5000 request_timeout_ms: 30000 # 1 minute upload_part_size_bytes: 536870912 # 512MiB + rate_limiting_policy: + CountBasedPolicy: + max_concurrent_requests: 15 log: Grpc: host: "logservice.chroma" @@ -84,6 +87,9 @@ compaction_service: connect_timeout_ms: 5000 request_timeout_ms: 60000 # 1 minute upload_part_size_bytes: 536870912 # 512MiB + rate_limiting_policy: + CountBasedPolicy: + max_concurrent_requests: 15 log: Grpc: host: "logservice.chroma" diff --git a/rust/worker/src/config.rs b/rust/worker/src/config.rs index 2d064ec1c52..135c0bf1ad1 100644 --- a/rust/worker/src/config.rs +++ b/rust/worker/src/config.rs @@ -171,6 +171,9 @@ mod tests { connect_timeout_ms: 5000 request_timeout_ms: 1000 upload_part_size_bytes: 8388608 + rate_limiting_policy: + CountBasedPolicy: + max_concurrent_requests: 15 log: Grpc: host: "localhost" @@ -225,6 +228,9 @@ mod tests { connect_timeout_ms: 5000 request_timeout_ms: 1000 upload_part_size_bytes: 8388608 + rate_limiting_policy: + CountBasedPolicy: + max_concurrent_requests: 15 log: Grpc: host: "localhost" @@ -305,6 +311,9 @@ mod tests { connect_timeout_ms: 5000 request_timeout_ms: 1000 upload_part_size_bytes: 8388608 + rate_limiting_policy: + CountBasedPolicy: + max_concurrent_requests: 15 log: Grpc: host: "localhost" @@ -359,6 +368,9 @@ mod tests { connect_timeout_ms: 5000 request_timeout_ms: 1000 upload_part_size_bytes: 8388608 + rate_limiting_policy: + CountBasedPolicy: + max_concurrent_requests: 15 log: Grpc: host: "localhost" @@ -457,6 +469,9 @@ mod tests { connect_timeout_ms: 5000 request_timeout_ms: 1000 upload_part_size_bytes: 8388608 + rate_limiting_policy: + CountBasedPolicy: + max_concurrent_requests: 15 log: Grpc: host: "localhost" @@ -511,6 +526,9 @@ mod tests { connect_timeout_ms: 5000 request_timeout_ms: 1000 upload_part_size_bytes: 8388608 + rate_limiting_policy: + CountBasedPolicy: + max_concurrent_requests: 15 log: Grpc: host: "localhost" @@ -607,6 +625,9 @@ mod tests { connect_timeout_ms: 5000 request_timeout_ms: 1000 upload_part_size_bytes: 8388608 + rate_limiting_policy: + CountBasedPolicy: + max_concurrent_requests: 15 log: Grpc: host: "localhost"