Skip to content

Commit

Permalink
Rate limit for NAC
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Aug 4, 2024
1 parent 648e0f8 commit 197103a
Show file tree
Hide file tree
Showing 14 changed files with 212 additions and 43 deletions.
18 changes: 12 additions & 6 deletions rust/blockstore/src/arrow/block/delta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ mod test {
let path = tmp_dir.path().to_str().unwrap();
let storage = Storage::Local(LocalStorage::new(path));
let cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let block_manager = BlockManager::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -274,7 +275,8 @@ mod test {
let path = tmp_dir.path().to_str().unwrap();
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let block_manager = BlockManager::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -337,7 +339,8 @@ mod test {
let path = tmp_dir.path().to_str().unwrap();
let storage = Storage::Local(LocalStorage::new(path));
let cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let block_manager = BlockManager::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -380,7 +383,8 @@ mod test {
let path = tmp_dir.path().to_str().unwrap();
let storage = Storage::Local(LocalStorage::new(path));
let cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let block_manager = BlockManager::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -421,7 +425,8 @@ mod test {
let path = tmp_dir.path().to_str().unwrap();
let storage = Storage::Local(LocalStorage::new(path));
let cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let block_manager = BlockManager::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -490,7 +495,8 @@ mod test {
let path = tmp_dir.path().to_str().unwrap();
let storage = Storage::Local(LocalStorage::new(path));
let cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let block_manager = BlockManager::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down
45 changes: 30 additions & 15 deletions rust/blockstore/src/arrow/blockfile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -633,7 +634,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -700,7 +702,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -817,7 +820,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -859,7 +863,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -975,7 +980,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -1019,7 +1025,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -1058,7 +1065,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -1094,7 +1102,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -1142,7 +1151,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -1178,7 +1188,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -1232,7 +1243,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -1266,7 +1278,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -1331,7 +1344,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down Expand Up @@ -1372,7 +1386,8 @@ mod tests {
let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
let block_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let sparse_index_cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down
3 changes: 2 additions & 1 deletion rust/blockstore/src/arrow/concurrency_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ mod tests {
let sparse_index_cache = Cache::new(&CacheConfig::Lru(LruConfig {
capacity: SPARSE_INDEX_CACHE_CAPACITY,
}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let blockfile_provider = ArrowBlockfileProvider::new(
storage,
TEST_MAX_BLOCK_SIZE_BYTES,
Expand Down
3 changes: 2 additions & 1 deletion rust/index/src/hnsw_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,8 @@ mod tests {

let storage = Storage::Local(LocalStorage::new(storage_dir.to_str().unwrap()));
let cache = Cache::new(&CacheConfig::Unbounded(UnboundedCacheConfig {}));
let network_admission_control = NetworkAdmissionControl::new(storage.clone());
let network_admission_control =
NetworkAdmissionControl::new_with_default_policy(storage.clone());
let provider =
HnswIndexProvider::new(storage, hnsw_tmp_path, cache, network_admission_control);
let segment = Segment {
Expand Down
10 changes: 10 additions & 0 deletions rust/storage/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,13 @@ pub struct S3StorageConfig {
pub struct LocalStorageConfig {
pub root: String,
}

#[derive(Deserialize, Debug)]
pub struct CountBasedPolicyConfig {
pub max_concurrent_requests: usize,
}

#[derive(Deserialize, Debug)]
pub enum StorageAdmissionConfig {
CountBasedPolicy(CountBasedPolicyConfig),
}
95 changes: 92 additions & 3 deletions rust/storage/src/network_admission_control.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
use crate::config::{CountBasedPolicyConfig, StorageAdmissionConfig};

use super::{GetError, Storage};
use async_trait::async_trait;
use chroma_config::Configurable;
use chroma_error::{ChromaError, ErrorCodes};
use futures::{future::Shared, FutureExt, StreamExt};
use futures::{future::Shared, FutureExt, StreamExt, TryFutureExt};
use parking_lot::Mutex;
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
use std::{collections::HashMap, future::Future, marker::PhantomData, pin::Pin, sync::Arc};
use thiserror::Error;
use tokio::sync::{Semaphore, SemaphorePermit};
use tracing::{Instrument, Span};

#[derive(Clone)]
Expand All @@ -25,6 +30,7 @@ pub struct NetworkAdmissionControl {
>,
>,
>,
rate_limiter: Arc<RateLimitPolicy>,
}

#[derive(Error, Debug, Clone)]
Expand All @@ -48,10 +54,18 @@ impl ChromaError for NetworkAdmissionControlError {
}

impl NetworkAdmissionControl {
pub fn new(storage: Storage) -> Self {
pub fn new_with_default_policy(storage: Storage) -> Self {
Self {
storage,
outstanding_requests: Arc::new(Mutex::new(HashMap::new())),
rate_limiter: Arc::new(RateLimitPolicy::CountBasedPolicy(CountBasedPolicy::new(15))),
}
}
pub fn new(storage: Storage, policy: RateLimitPolicy) -> Self {
Self {
storage,
outstanding_requests: Arc::new(Mutex::new(HashMap::new())),
rate_limiter: Arc::new(policy),
}
}

Expand Down Expand Up @@ -108,6 +122,22 @@ impl NetworkAdmissionControl {
}
}

async fn enter(&self) -> SemaphorePermit<'_> {
match &*self.rate_limiter {
RateLimitPolicy::CountBasedPolicy(policy) => {
return policy.acquire().await;
}
}
}

async fn exit(&self, permit: SemaphorePermit<'_>) {
match &*self.rate_limiter {
RateLimitPolicy::CountBasedPolicy(policy) => {
policy.drop(permit).await;
}
}
}

pub async fn get<F, R>(
&self,
key: String,
Expand All @@ -117,6 +147,8 @@ impl NetworkAdmissionControl {
R: Future<Output = Result<(), Box<NetworkAdmissionControlError>>> + Send + 'static,
F: (FnOnce(Vec<u8>) -> R) + Send + 'static,
{
// Wait for permit.
let permit = self.enter().await;
let future_to_await;
{
let mut requests = self.outstanding_requests.lock();
Expand All @@ -141,6 +173,63 @@ impl NetworkAdmissionControl {
let mut requests = self.outstanding_requests.lock();
requests.remove(&key);
}
// Release permit.
self.exit(permit).await;
res
}
}

// Prefer enum dispatch over dyn since there could
// only be a handful of these policies.
#[derive(Debug)]
enum RateLimitPolicy {
CountBasedPolicy(CountBasedPolicy),
}

#[derive(Debug)]
struct CountBasedPolicy {
max_allowed_outstanding: usize,
remaining_tokens: Semaphore,
}

impl CountBasedPolicy {
fn new(max_allowed_outstanding: usize) -> Self {
Self {
max_allowed_outstanding,
remaining_tokens: Semaphore::new(max_allowed_outstanding),
}
}
async fn acquire(&self) -> SemaphorePermit<'_> {
let token_res = self.remaining_tokens.acquire().await;
match token_res {
Ok(token) => {
return token;
}
Err(e) => panic!("AcquireToken Failed {}", e),
}
}
async fn drop(&self, permit: SemaphorePermit<'_>) {
drop(permit);
}
}

pub async fn from_config(
config: &StorageAdmissionConfig,
storage: Storage,
) -> Result<NetworkAdmissionControl, Box<dyn ChromaError>> {
match &config {
StorageAdmissionConfig::CountBasedPolicy(policy) => Ok(NetworkAdmissionControl::new(
storage,
RateLimitPolicy::CountBasedPolicy(CountBasedPolicy::try_from_config(policy).await?),
)),
}
}

#[async_trait]
impl Configurable<CountBasedPolicyConfig> for CountBasedPolicy {
async fn try_from_config(
config: &CountBasedPolicyConfig,
) -> Result<Self, Box<dyn ChromaError>> {
Ok(Self::new(config.max_concurrent_requests))
}
}
Loading

0 comments on commit 197103a

Please sign in to comment.