Skip to content

Commit

Permalink
[ENH] NAC rate limits requests (#2632)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - Introduced a count-based rate-limiting policy under which at most X outstanding requests are permitted at any time. X is configurable.

## Test plan
- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
None
  • Loading branch information
sanketkedia committed Aug 20, 2024
1 parent 2b098f8 commit a886630
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 10 deletions.
151 changes: 145 additions & 6 deletions rust/storage/src/admissioncontrolleds3.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
config::StorageConfig,
config::{CountBasedPolicyConfig, RateLimitingConfig, StorageConfig},
s3::{S3GetError, S3PutError, S3Storage, StorageConfigError},
stream::ByteStreamItem,
};
Expand All @@ -10,6 +10,7 @@ use futures::{future::Shared, FutureExt, Stream};
use parking_lot::Mutex;
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
use thiserror::Error;
use tokio::sync::{Semaphore, SemaphorePermit};
use tracing::{Instrument, Span};

/// Wrapper over s3 storage that provides proxy features such as
Expand Down Expand Up @@ -37,6 +38,7 @@ pub struct AdmissionControlledS3Storage {
>,
>,
>,
rate_limiter: Arc<RateLimitPolicy>,
}

#[derive(Error, Debug, Clone)]
Expand All @@ -54,10 +56,19 @@ impl ChromaError for AdmissionControlledS3StorageError {
}

impl AdmissionControlledS3Storage {
pub fn new(storage: S3Storage) -> Self {
pub fn new_with_default_policy(storage: S3Storage) -> Self {
Self {
storage,
outstanding_requests: Arc::new(Mutex::new(HashMap::new())),
rate_limiter: Arc::new(RateLimitPolicy::CountBasedPolicy(CountBasedPolicy::new(15))),
}
}

pub fn new(storage: S3Storage, policy: RateLimitPolicy) -> Self {
Self {
storage,
outstanding_requests: Arc::new(Mutex::new(HashMap::new())),
rate_limiter: Arc::new(policy),
}
}

Expand Down Expand Up @@ -106,12 +117,17 @@ impl AdmissionControlledS3Storage {
&self,
key: String,
) -> Result<Arc<Vec<u8>>, AdmissionControlledS3StorageError> {
// If there is a duplicate request and the original request finishes
// before we look it up in the map below then we will end up with another
// request to S3. We rely on synchronization on the cache
// by the upstream consumer to make sure that this works correctly.
let future_to_await;
let is_dupe: bool;
{
let mut requests = self.outstanding_requests.lock();
let maybe_inflight = requests.get(&key).map(|fut| fut.clone());
future_to_await = match maybe_inflight {
Some(fut) => fut,
(future_to_await, is_dupe) = match maybe_inflight {
Some(fut) => (fut, true),
None => {
let get_storage_future = AdmissionControlledS3Storage::read_from_storage(
self.storage.clone(),
Expand All @@ -120,16 +136,25 @@ impl AdmissionControlledS3Storage {
.boxed()
.shared();
requests.insert(key.clone(), get_storage_future.clone());
get_storage_future
(get_storage_future, false)
}
};
}

// Acquire permit.
let permit: SemaphorePermit<'_>;
if is_dupe {
permit = self.rate_limiter.enter().await;
}

let res = future_to_await.await;
{
let mut requests = self.outstanding_requests.lock();
requests.remove(&key);
}

res
// Permit gets dropped here since it is RAII.
}

pub async fn put_file(&self, key: &str, path: &str) -> Result<(), S3PutError> {
Expand All @@ -149,11 +174,125 @@ impl Configurable<StorageConfig> for AdmissionControlledS3Storage {
let s3_storage =
S3Storage::try_from_config(&StorageConfig::S3(nacconfig.s3_config.clone()))
.await?;
return Ok(Self::new(s3_storage));
let policy =
RateLimitPolicy::try_from_config(&nacconfig.rate_limiting_policy).await?;
return Ok(Self::new(s3_storage, policy));
}
_ => {
return Err(Box::new(StorageConfigError::InvalidStorageConfig));
}
}
}
}

/// Admission-control strategies that can gate requests to storage.
///
/// This is an enum (static dispatch) rather than a `dyn` trait object because
/// only a handful of policies are expected to exist.
#[derive(Debug)]
enum RateLimitPolicy {
    /// Gate requests with a fixed pool of permits.
    CountBasedPolicy(CountBasedPolicy),
}

impl RateLimitPolicy {
    /// Waits until the active policy hands out a permit, then returns it.
    ///
    /// The returned permit borrows from `self`; callers hold on to it for as
    /// long as the request should count against the limit.
    async fn enter(&self) -> SemaphorePermit<'_> {
        match self {
            Self::CountBasedPolicy(count_based) => count_based.acquire().await,
        }
    }
}

/// Policy that hands out permits from a semaphore sized at construction time.
#[derive(Debug)]
struct CountBasedPolicy {
    /// The permit count this policy was constructed with.
    max_allowed_outstanding: usize,
    /// Semaphore created with `max_allowed_outstanding` permits.
    remaining_tokens: Semaphore,
}

impl CountBasedPolicy {
    /// Creates a policy whose semaphore starts with `max_allowed_outstanding`
    /// permits.
    fn new(max_allowed_outstanding: usize) -> Self {
        let remaining_tokens = Semaphore::new(max_allowed_outstanding);
        Self {
            max_allowed_outstanding,
            remaining_tokens,
        }
    }

    /// Waits for a permit from the semaphore and returns it.
    ///
    /// # Panics
    /// Panics with `AcquireToken Failed <error>` if the semaphore reports an
    /// acquire error.
    async fn acquire(&self) -> SemaphorePermit<'_> {
        self.remaining_tokens
            .acquire()
            .await
            .unwrap_or_else(|e| panic!("AcquireToken Failed {}", e))
    }
}

#[async_trait]
impl Configurable<RateLimitingConfig> for RateLimitPolicy {
async fn try_from_config(config: &RateLimitingConfig) -> Result<Self, Box<dyn ChromaError>> {
match &config {
RateLimitingConfig::CountBasedPolicy(count_policy) => {
return Ok(RateLimitPolicy::CountBasedPolicy(CountBasedPolicy::new(
count_policy.max_concurrent_requests,
)));
}
}
}
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{admissioncontrolleds3::AdmissionControlledS3Storage, s3::S3Storage};

    /// Builds an S3 client pointed at a MinIO instance that is assumed to be
    /// running locally on port 9000 with the static test credentials below.
    fn get_s3_client() -> aws_sdk_s3::Client {
        let credentials = aws_sdk_s3::config::Credentials::new(
            "minio",
            "minio123",
            None,
            None,
            "loaded-from-env",
        );

        let s3_config = aws_sdk_s3::config::Builder::new()
            .endpoint_url("http://127.0.0.1:9000".to_string())
            .credentials_provider(credentials)
            .behavior_version_latest()
            .region(aws_sdk_s3::config::Region::new("us-east-1"))
            .force_path_style(true)
            .build();

        aws_sdk_s3::Client::from_conf(s3_config)
    }

    /// Round-trips a small payload through the admission-controlled wrapper:
    /// write it with `put_bytes`, read it back with `get`, compare.
    #[tokio::test]
    #[cfg(CHROMA_KUBERNETES_INTEGRATION)]
    async fn test_put_get_key() {
        let storage = S3Storage {
            bucket: "test".to_string(),
            client: get_s3_client(),
            upload_part_size_bytes: 1024 * 1024 * 8,
        };
        storage.create_bucket().await.unwrap();
        let nac_storage = AdmissionControlledS3Storage::new_with_default_policy(storage);

        let expected = "test data";
        nac_storage
            .put_bytes("test", expected.as_bytes().to_vec())
            .await
            .unwrap();

        let fetched = nac_storage.get("test".to_string()).await.unwrap();
        let fetched = String::from_utf8(Arc::unwrap_or_clone(fetched)).unwrap();
        assert_eq!(fetched, expected);
    }
}
11 changes: 11 additions & 0 deletions rust/storage/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,15 @@ pub struct LocalStorageConfig {
/// Configuration for the admission-controlled S3 storage wrapper: the
/// settings of the wrapped S3 storage plus the rate-limiting policy to apply.
#[derive(Deserialize, Debug, Clone)]
pub struct AdmissionControlledS3StorageConfig {
/// Settings used to construct the wrapped S3 storage.
pub s3_config: S3StorageConfig,
/// Which rate-limiting policy to build, and its parameters.
pub rate_limiting_policy: RateLimitingConfig,
}

/// Parameters of the count-based rate-limiting policy.
#[derive(Deserialize, Debug, Clone)]
pub struct CountBasedPolicyConfig {
/// Number of permits the policy's semaphore is created with.
pub max_concurrent_requests: usize,
}

/// Selects the rate-limiting policy and carries its parameters.
///
/// In the worker YAML this appears as a `CountBasedPolicy:` key nested under
/// `rate_limiting_policy:`.
#[derive(Deserialize, Debug, Clone)]
pub enum RateLimitingConfig {
/// Count-based policy; see `CountBasedPolicyConfig` for its parameters.
CountBasedPolicy(CountBasedPolicyConfig),
}
8 changes: 4 additions & 4 deletions rust/storage/src/s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ use tracing::Span;

#[derive(Clone)]
pub struct S3Storage {
bucket: String,
client: aws_sdk_s3::Client,
upload_part_size_bytes: usize,
pub(super) bucket: String,
pub(super) client: aws_sdk_s3::Client,
pub(super) upload_part_size_bytes: usize,
}

#[derive(Error, Debug)]
Expand Down Expand Up @@ -84,7 +84,7 @@ impl S3Storage {
};
}

async fn create_bucket(&self) -> Result<(), String> {
pub(super) async fn create_bucket(&self) -> Result<(), String> {
// Creates a public bucket with default settings in the region.
// This should only be used for testing and in production
// the bucket should be provisioned ahead of time.
Expand Down
6 changes: 6 additions & 0 deletions rust/worker/chroma_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ query_service:
connect_timeout_ms: 5000
request_timeout_ms: 30000 # 30 seconds
upload_part_size_bytes: 536870912 # 512MiB
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "logservice.chroma"
Expand Down Expand Up @@ -84,6 +87,9 @@ compaction_service:
connect_timeout_ms: 5000
request_timeout_ms: 60000 # 1 minute
upload_part_size_bytes: 536870912 # 512MiB
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "logservice.chroma"
Expand Down
21 changes: 21 additions & 0 deletions rust/worker/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down Expand Up @@ -225,6 +228,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down Expand Up @@ -305,6 +311,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down Expand Up @@ -359,6 +368,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down Expand Up @@ -457,6 +469,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down Expand Up @@ -511,6 +526,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down Expand Up @@ -607,6 +625,9 @@ mod tests {
connect_timeout_ms: 5000
request_timeout_ms: 1000
upload_part_size_bytes: 8388608
rate_limiting_policy:
CountBasedPolicy:
max_concurrent_requests: 15
log:
Grpc:
host: "localhost"
Expand Down

0 comments on commit a886630

Please sign in to comment.