Skip to content

Commit

Permalink
[ENH] Construct and pass NAC (#2630)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
- NAC is another enum type of the storage abstraction that gets constructed from the config on startup. Introduced config as well in this PR.

## Test plan
- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
None
  • Loading branch information
sanketkedia committed Aug 20, 2024
1 parent 91e9111 commit d48f4fd
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 113 deletions.
4 changes: 2 additions & 2 deletions rust/blockstore/src/arrow/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ impl BlockManager {
// TODO: NAC register/deregister/validation goes here.
async {
let key = format!("block/{}", id);
let stream = self.storage.get(&key).instrument(
let stream = self.storage.get_stream(&key).instrument(
tracing::trace_span!(parent: Span::current(), "BlockManager storage get"),
).await;
match stream {
Expand Down Expand Up @@ -341,7 +341,7 @@ impl SparseIndexManager {
tracing::info!("Cache miss - fetching sparse index from storage");
let key = format!("sparse_index/{}", id);
tracing::debug!("Reading sparse index from storage with key: {}", key);
let stream = self.storage.get(&key).await;
let stream = self.storage.get_stream(&key).await;
let mut buf: Vec<u8> = Vec::new();
match stream {
Ok(mut bytes) => {
Expand Down
2 changes: 1 addition & 1 deletion rust/index/src/hnsw_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl HnswIndexProvider {
for file in FILES.iter() {
let key = self.format_key(source_id, file);
tracing::info!("Loading hnsw index file: {}", key);
let stream = self.storage.get(&key).await;
let stream = self.storage.get_stream(&key).await;
let reader = match stream {
Ok(reader) => reader,
Err(e) => {
Expand Down
106 changes: 59 additions & 47 deletions rust/storage/src/admissioncontrolleds3.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use super::GetError;
use crate::s3::{S3GetError, S3Storage};
use crate::{
config::StorageConfig,
s3::{S3GetError, S3PutError, S3Storage, StorageConfigError},
stream::ByteStreamItem,
};
use async_trait::async_trait;
use chroma_config::Configurable;
use chroma_error::{ChromaError, ErrorCodes};
use futures::{future::Shared, FutureExt, StreamExt};
use futures::{future::Shared, FutureExt, Stream};
use parking_lot::Mutex;
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
use thiserror::Error;
Expand Down Expand Up @@ -56,57 +61,39 @@ impl AdmissionControlledS3Storage {
}
}

// TODO: Remove this once the upstream consumers switch to non-streaming APIs.
pub async fn get_stream(
&self,
key: &str,
) -> Result<
Box<dyn Stream<Item = ByteStreamItem> + Unpin + Send>,
AdmissionControlledS3StorageError,
> {
match self
.storage
.get_stream(key)
.instrument(tracing::trace_span!(parent: Span::current(), "Storage get"))
.await
{
Ok(res) => Ok(res),
Err(e) => {
tracing::error!("Error reading from storage: {}", e);
return Err(AdmissionControlledS3StorageError::S3GetError(e));
}
}
}

async fn read_from_storage(
storage: S3Storage,
key: String,
) -> Result<Arc<Vec<u8>>, AdmissionControlledS3StorageError> {
let stream = storage
let bytes_res = storage
.get(&key)
.instrument(tracing::trace_span!(parent: Span::current(), "Storage get"))
.await;
match stream {
Ok(mut bytes) => {
let read_block_span =
tracing::trace_span!(parent: Span::current(), "Read bytes to end");
let buf = read_block_span
.in_scope(|| async {
let mut buf: Vec<u8> = Vec::new();
while let Some(res) = bytes.next().await {
match res {
Ok(chunk) => {
buf.extend(chunk);
}
Err(err) => {
tracing::error!("Error reading from storage: {}", err);
match err {
GetError::S3Error(e) => {
return Err(
AdmissionControlledS3StorageError::S3GetError(e),
);
}
GetError::NoSuchKey(e) => {
return Err(
AdmissionControlledS3StorageError::S3GetError(
S3GetError::NoSuchKey(e),
),
);
}
GetError::LocalError(_) => unreachable!(),
}
}
}
}
tracing::info!("Read {:?} bytes from s3", buf.len());
Ok(Some(buf))
})
.await?;
match buf {
Some(buf) => Ok(Arc::new(buf)),
None => {
// Buffer is empty. Nothing interesting to do.
Ok(Arc::new(vec![]))
}
}
match bytes_res {
Ok(bytes) => {
return Ok(bytes);
}
Err(e) => {
tracing::error!("Error reading from storage: {}", e);
Expand Down Expand Up @@ -144,4 +131,29 @@ impl AdmissionControlledS3Storage {
}
res
}

pub async fn put_file(&self, key: &str, path: &str) -> Result<(), S3PutError> {
self.storage.put_file(key, path).await
}

pub async fn put_bytes(&self, key: &str, bytes: Vec<u8>) -> Result<(), S3PutError> {
self.storage.put_bytes(key, bytes).await
}
}

#[async_trait]
impl Configurable<StorageConfig> for AdmissionControlledS3Storage {
async fn try_from_config(config: &StorageConfig) -> Result<Self, Box<dyn ChromaError>> {
match &config {
StorageConfig::AdmissionControlledS3(nacconfig) => {
let s3_storage =
S3Storage::try_from_config(&StorageConfig::S3(nacconfig.s3_config.clone()))
.await?;
return Ok(Self::new(s3_storage));
}
_ => {
return Err(Box::new(StorageConfigError::InvalidStorageConfig));
}
}
}
}
13 changes: 10 additions & 3 deletions rust/storage/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@ pub enum StorageConfig {
S3(S3StorageConfig),
#[serde(alias = "local")]
Local(LocalStorageConfig),
#[serde(alias = "admissioncontrolleds3")]
AdmissionControlledS3(AdmissionControlledS3StorageConfig),
}

#[derive(Deserialize, PartialEq, Debug)]
#[derive(Deserialize, PartialEq, Debug, Clone)]
pub enum S3CredentialsConfig {
Minio,
AWS,
}

#[derive(Deserialize, Debug)]
#[derive(Deserialize, Debug, Clone)]
/// The configuration for the s3 storage type
/// # Fields
/// - bucket: The name of the bucket to use.
Expand All @@ -33,7 +35,7 @@ pub struct S3StorageConfig {
pub upload_part_size_bytes: usize,
}

#[derive(Deserialize, Debug)]
#[derive(Deserialize, Debug, Clone)]
/// The configuration for the local storage type
/// # Fields
/// - root: The root directory to use for storage.
Expand All @@ -43,3 +45,8 @@ pub struct S3StorageConfig {
pub struct LocalStorageConfig {
pub root: String,
}

#[derive(Deserialize, Debug, Clone)]
pub struct AdmissionControlledS3StorageConfig {
pub s3_config: S3StorageConfig,
}
68 changes: 65 additions & 3 deletions rust/storage/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::sync::Arc;

use self::config::StorageConfig;
use self::s3::S3GetError;
use self::stream::ByteStreamItem;
use admissioncontrolleds3::AdmissionControlledS3StorageError;
use chroma_config::Configurable;
use chroma_error::{ChromaError, ErrorCodes};

Expand All @@ -16,6 +19,7 @@ use thiserror::Error;
pub enum Storage {
S3(s3::S3Storage),
Local(local::LocalStorage),
AdmissionControlledS3(admissioncontrolleds3::AdmissionControlledS3Storage),
}

#[derive(Error, Debug, Clone)]
Expand Down Expand Up @@ -56,13 +60,48 @@ impl ChromaError for PutError {
}

impl Storage {
pub async fn get(
pub async fn get(&self, key: &str) -> Result<Arc<Vec<u8>>, GetError> {
match self {
Storage::S3(s3) => {
let res = s3.get(key).await;
match res {
Ok(res) => Ok(res),
Err(e) => match e {
S3GetError::NoSuchKey(_) => Err(GetError::NoSuchKey(key.to_string())),
_ => Err(GetError::S3Error(e)),
},
}
}
Storage::Local(local) => {
let res = local.get(key).await;
match res {
Ok(res) => Ok(res),
Err(e) => Err(GetError::LocalError(e)),
}
}
Storage::AdmissionControlledS3(admission_controlled_storage) => {
let res = admission_controlled_storage.get(key.to_string()).await;
match res {
Ok(res) => Ok(res),
Err(e) => match e {
AdmissionControlledS3StorageError::S3GetError(e) => match e {
S3GetError::NoSuchKey(_) => Err(GetError::NoSuchKey(key.to_string())),
_ => Err(GetError::S3Error(e)),
},
},
}
}
}
}

// TODO: Remove this once the upstream switches to consume non-streaming.
pub async fn get_stream(
&self,
key: &str,
) -> Result<Box<dyn Stream<Item = ByteStreamItem> + Unpin + Send>, GetError> {
match self {
Storage::S3(s3) => {
let res = s3.get(key).await;
let res = s3.get_stream(key).await;
match res {
Ok(res) => Ok(res),
Err(e) => match e {
Expand All @@ -72,12 +111,24 @@ impl Storage {
}
}
Storage::Local(local) => {
let res = local.get(key).await;
let res = local.get_stream(key).await;
match res {
Ok(res) => Ok(res),
Err(e) => Err(GetError::LocalError(e)),
}
}
Storage::AdmissionControlledS3(admission_controlled_storage) => {
let res = admission_controlled_storage.get_stream(key).await;
match res {
Ok(res) => Ok(res),
Err(e) => match e {
AdmissionControlledS3StorageError::S3GetError(e) => match e {
S3GetError::NoSuchKey(_) => Err(GetError::NoSuchKey(key.to_string())),
_ => Err(GetError::S3Error(e)),
},
},
}
}
}
}

Expand All @@ -91,6 +142,10 @@ impl Storage {
.put_file(key, path)
.await
.map_err(|e| PutError::LocalError(e)),
Storage::AdmissionControlledS3(as3) => as3
.put_file(key, path)
.await
.map_err(|e| PutError::S3Error(e)),
}
}

Expand All @@ -104,6 +159,10 @@ impl Storage {
.put_bytes(key, &bytes)
.await
.map_err(|e| PutError::LocalError(e)),
Storage::AdmissionControlledS3(as3) => as3
.put_bytes(key, bytes)
.await
.map_err(|e| PutError::S3Error(e)),
}
}
}
Expand All @@ -114,5 +173,8 @@ pub async fn from_config(config: &StorageConfig) -> Result<Storage, Box<dyn Chro
StorageConfig::Local(_) => Ok(Storage::Local(
local::LocalStorage::try_from_config(config).await?,
)),
StorageConfig::AdmissionControlledS3(_) => Ok(Storage::AdmissionControlledS3(
admissioncontrolleds3::AdmissionControlledS3Storage::try_from_config(config).await?,
)),
}
}
47 changes: 46 additions & 1 deletion rust/storage/src/local.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use super::stream::ByteStream;
use super::stream::ByteStreamItem;
use super::{config::StorageConfig, s3::StorageConfigError};
use crate::GetError;
use async_trait::async_trait;
use chroma_config::Configurable;
use chroma_error::ChromaError;
use futures::Stream;
use futures::StreamExt;
use std::sync::Arc;
use tracing::Instrument;
use tracing::Span;

#[derive(Clone)]
pub struct LocalStorage {
Expand All @@ -19,7 +24,47 @@ impl LocalStorage {
};
}

pub async fn get(
pub async fn get(&self, key: &str) -> Result<Arc<Vec<u8>>, String> {
let mut stream = self
.get_stream(&key)
.instrument(tracing::trace_span!(parent: Span::current(), "Local Storage get"))
.await?;
let read_block_span =
tracing::trace_span!(parent: Span::current(), "Local storage read bytes to end");
let buf = read_block_span
.in_scope(|| async {
let mut buf: Vec<u8> = Vec::new();
while let Some(res) = stream.next().await {
match res {
Ok(chunk) => {
buf.extend(chunk);
}
Err(err) => {
tracing::error!("Error reading from storage: {}", err);
match err {
GetError::LocalError(e) => {
return Err(e);
}
_ => unreachable!(),
}
}
}
}
tracing::info!("Read {:?} bytes from local storage", buf.len());
Ok(Some(buf))
})
.await?;

match buf {
Some(buf) => Ok(Arc::new(buf)),
None => {
// Buffer is empty. Nothing interesting to do.
Ok(Arc::new(vec![]))
}
}
}

pub(super) async fn get_stream(
&self,
key: &str,
) -> Result<Box<dyn Stream<Item = ByteStreamItem> + Unpin + Send>, String> {
Expand Down
Loading

0 comments on commit d48f4fd

Please sign in to comment.