Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Construct and pass NAC #2630

Merged
merged 5 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rust/blockstore/src/arrow/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ impl BlockManager {
// TODO: NAC register/deregister/validation goes here.
async {
let key = format!("block/{}", id);
let stream = self.storage.get(&key).instrument(
let stream = self.storage.get_stream(&key).instrument(
tracing::trace_span!(parent: Span::current(), "BlockManager storage get"),
).await;
match stream {
Expand Down Expand Up @@ -341,7 +341,7 @@ impl SparseIndexManager {
tracing::info!("Cache miss - fetching sparse index from storage");
let key = format!("sparse_index/{}", id);
tracing::debug!("Reading sparse index from storage with key: {}", key);
let stream = self.storage.get(&key).await;
let stream = self.storage.get_stream(&key).await;
let mut buf: Vec<u8> = Vec::new();
match stream {
Ok(mut bytes) => {
Expand Down
2 changes: 1 addition & 1 deletion rust/index/src/hnsw_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl HnswIndexProvider {
for file in FILES.iter() {
let key = self.format_key(source_id, file);
tracing::info!("Loading hnsw index file: {}", key);
let stream = self.storage.get(&key).await;
let stream = self.storage.get_stream(&key).await;
let reader = match stream {
Ok(reader) => reader,
Err(e) => {
Expand Down
106 changes: 59 additions & 47 deletions rust/storage/src/admissioncontrolleds3.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use super::GetError;
use crate::s3::{S3GetError, S3Storage};
use crate::{
config::StorageConfig,
s3::{S3GetError, S3PutError, S3Storage, StorageConfigError},
stream::ByteStreamItem,
};
use async_trait::async_trait;
use chroma_config::Configurable;
use chroma_error::{ChromaError, ErrorCodes};
use futures::{future::Shared, FutureExt, StreamExt};
use futures::{future::Shared, FutureExt, Stream};
use parking_lot::Mutex;
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
use thiserror::Error;
Expand Down Expand Up @@ -56,57 +61,39 @@ impl AdmissionControlledS3Storage {
}
}

// TODO: Remove this once the upstream consumers switch to non-streaming APIs.
sanketkedia marked this conversation as resolved.
Show resolved Hide resolved
pub async fn get_stream(
&self,
key: &str,
) -> Result<
Box<dyn Stream<Item = ByteStreamItem> + Unpin + Send>,
AdmissionControlledS3StorageError,
> {
match self
.storage
.get_stream(key)
.instrument(tracing::trace_span!(parent: Span::current(), "Storage get"))
.await
{
Ok(res) => Ok(res),
Err(e) => {
tracing::error!("Error reading from storage: {}", e);
return Err(AdmissionControlledS3StorageError::S3GetError(e));
}
}
}

async fn read_from_storage(
storage: S3Storage,
key: String,
) -> Result<Arc<Vec<u8>>, AdmissionControlledS3StorageError> {
let stream = storage
let bytes_res = storage
.get(&key)
.instrument(tracing::trace_span!(parent: Span::current(), "Storage get"))
.await;
match stream {
Ok(mut bytes) => {
let read_block_span =
tracing::trace_span!(parent: Span::current(), "Read bytes to end");
let buf = read_block_span
.in_scope(|| async {
let mut buf: Vec<u8> = Vec::new();
while let Some(res) = bytes.next().await {
match res {
Ok(chunk) => {
buf.extend(chunk);
}
Err(err) => {
tracing::error!("Error reading from storage: {}", err);
match err {
GetError::S3Error(e) => {
return Err(
AdmissionControlledS3StorageError::S3GetError(e),
);
}
GetError::NoSuchKey(e) => {
return Err(
AdmissionControlledS3StorageError::S3GetError(
S3GetError::NoSuchKey(e),
),
);
}
GetError::LocalError(_) => unreachable!(),
}
}
}
}
tracing::info!("Read {:?} bytes from s3", buf.len());
Ok(Some(buf))
})
.await?;
match buf {
Some(buf) => Ok(Arc::new(buf)),
None => {
// Buffer is empty. Nothing interesting to do.
Ok(Arc::new(vec![]))
}
}
match bytes_res {
Ok(bytes) => {
return Ok(bytes);
}
Err(e) => {
tracing::error!("Error reading from storage: {}", e);
Expand Down Expand Up @@ -144,4 +131,29 @@ impl AdmissionControlledS3Storage {
}
res
}

pub async fn put_file(&self, key: &str, path: &str) -> Result<(), S3PutError> {
self.storage.put_file(key, path).await
}

pub async fn put_bytes(&self, key: &str, bytes: Vec<u8>) -> Result<(), S3PutError> {
self.storage.put_bytes(key, bytes).await
}
}

#[async_trait]
impl Configurable<StorageConfig> for AdmissionControlledS3Storage {
async fn try_from_config(config: &StorageConfig) -> Result<Self, Box<dyn ChromaError>> {
match &config {
StorageConfig::AdmissionControlledS3(nacconfig) => {
let s3_storage =
S3Storage::try_from_config(&StorageConfig::S3(nacconfig.s3_config.clone()))
.await?;
return Ok(Self::new(s3_storage));
}
_ => {
return Err(Box::new(StorageConfigError::InvalidStorageConfig));
}
}
}
}
13 changes: 10 additions & 3 deletions rust/storage/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@ pub enum StorageConfig {
S3(S3StorageConfig),
#[serde(alias = "local")]
Local(LocalStorageConfig),
#[serde(alias = "admissioncontrolleds3")]
AdmissionControlledS3(AdmissionControlledS3StorageConfig),
}

#[derive(Deserialize, PartialEq, Debug)]
#[derive(Deserialize, PartialEq, Debug, Clone)]
pub enum S3CredentialsConfig {
Minio,
AWS,
}

#[derive(Deserialize, Debug)]
#[derive(Deserialize, Debug, Clone)]
/// The configuration for the s3 storage type
/// # Fields
/// - bucket: The name of the bucket to use.
Expand All @@ -33,7 +35,7 @@ pub struct S3StorageConfig {
pub upload_part_size_bytes: usize,
}

#[derive(Deserialize, Debug)]
#[derive(Deserialize, Debug, Clone)]
/// The configuration for the local storage type
/// # Fields
/// - root: The root directory to use for storage.
Expand All @@ -43,3 +45,8 @@ pub struct S3StorageConfig {
pub struct LocalStorageConfig {
pub root: String,
}

#[derive(Deserialize, Debug, Clone)]
pub struct AdmissionControlledS3StorageConfig {
pub s3_config: S3StorageConfig,
}
68 changes: 65 additions & 3 deletions rust/storage/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::sync::Arc;

use self::config::StorageConfig;
use self::s3::S3GetError;
use self::stream::ByteStreamItem;
use admissioncontrolleds3::AdmissionControlledS3StorageError;
use chroma_config::Configurable;
use chroma_error::{ChromaError, ErrorCodes};

Expand All @@ -16,6 +19,7 @@ use thiserror::Error;
pub enum Storage {
S3(s3::S3Storage),
Local(local::LocalStorage),
AdmissionControlledS3(admissioncontrolleds3::AdmissionControlledS3Storage),
}

#[derive(Error, Debug, Clone)]
Expand Down Expand Up @@ -56,13 +60,48 @@ impl ChromaError for PutError {
}

impl Storage {
pub async fn get(
pub async fn get(&self, key: &str) -> Result<Arc<Vec<u8>>, GetError> {
match self {
Storage::S3(s3) => {
let res = s3.get(key).await;
match res {
Ok(res) => Ok(res),
Err(e) => match e {
S3GetError::NoSuchKey(_) => Err(GetError::NoSuchKey(key.to_string())),
_ => Err(GetError::S3Error(e)),
},
}
}
Storage::Local(local) => {
let res = local.get(key).await;
match res {
Ok(res) => Ok(res),
Err(e) => Err(GetError::LocalError(e)),
}
}
Storage::AdmissionControlledS3(admission_controlled_storage) => {
let res = admission_controlled_storage.get(key.to_string()).await;
match res {
Ok(res) => Ok(res),
Err(e) => match e {
AdmissionControlledS3StorageError::S3GetError(e) => match e {
S3GetError::NoSuchKey(_) => Err(GetError::NoSuchKey(key.to_string())),
_ => Err(GetError::S3Error(e)),
},
},
}
}
}
}

// TODO: Remove this once the upstream switches to consume non-streaming.
pub async fn get_stream(
&self,
key: &str,
) -> Result<Box<dyn Stream<Item = ByteStreamItem> + Unpin + Send>, GetError> {
match self {
Storage::S3(s3) => {
let res = s3.get(key).await;
let res = s3.get_stream(key).await;
match res {
Ok(res) => Ok(res),
Err(e) => match e {
Expand All @@ -72,12 +111,24 @@ impl Storage {
}
}
Storage::Local(local) => {
let res = local.get(key).await;
let res = local.get_stream(key).await;
match res {
Ok(res) => Ok(res),
Err(e) => Err(GetError::LocalError(e)),
}
}
Storage::AdmissionControlledS3(admission_controlled_storage) => {
let res = admission_controlled_storage.get_stream(key).await;
match res {
Ok(res) => Ok(res),
Err(e) => match e {
AdmissionControlledS3StorageError::S3GetError(e) => match e {
S3GetError::NoSuchKey(_) => Err(GetError::NoSuchKey(key.to_string())),
_ => Err(GetError::S3Error(e)),
},
},
}
}
}
}

Expand All @@ -91,6 +142,10 @@ impl Storage {
.put_file(key, path)
.await
.map_err(|e| PutError::LocalError(e)),
Storage::AdmissionControlledS3(as3) => as3
.put_file(key, path)
.await
.map_err(|e| PutError::S3Error(e)),
}
}

Expand All @@ -104,6 +159,10 @@ impl Storage {
.put_bytes(key, &bytes)
.await
.map_err(|e| PutError::LocalError(e)),
Storage::AdmissionControlledS3(as3) => as3
.put_bytes(key, bytes)
.await
.map_err(|e| PutError::S3Error(e)),
}
}
}
Expand All @@ -114,5 +173,8 @@ pub async fn from_config(config: &StorageConfig) -> Result<Storage, Box<dyn Chro
StorageConfig::Local(_) => Ok(Storage::Local(
local::LocalStorage::try_from_config(config).await?,
)),
StorageConfig::AdmissionControlledS3(_) => Ok(Storage::AdmissionControlledS3(
admissioncontrolleds3::AdmissionControlledS3Storage::try_from_config(config).await?,
)),
}
}
47 changes: 46 additions & 1 deletion rust/storage/src/local.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use super::stream::ByteStream;
use super::stream::ByteStreamItem;
use super::{config::StorageConfig, s3::StorageConfigError};
use crate::GetError;
use async_trait::async_trait;
use chroma_config::Configurable;
use chroma_error::ChromaError;
use futures::Stream;
use futures::StreamExt;
use std::sync::Arc;
use tracing::Instrument;
use tracing::Span;

#[derive(Clone)]
pub struct LocalStorage {
Expand All @@ -19,7 +24,47 @@ impl LocalStorage {
};
}

pub async fn get(
pub async fn get(&self, key: &str) -> Result<Arc<Vec<u8>>, String> {
let mut stream = self
.get_stream(&key)
.instrument(tracing::trace_span!(parent: Span::current(), "Local Storage get"))
.await?;
let read_block_span =
tracing::trace_span!(parent: Span::current(), "Local storage read bytes to end");
let buf = read_block_span
.in_scope(|| async {
let mut buf: Vec<u8> = Vec::new();
sanketkedia marked this conversation as resolved.
Show resolved Hide resolved
while let Some(res) = stream.next().await {
match res {
Ok(chunk) => {
buf.extend(chunk);
}
Err(err) => {
tracing::error!("Error reading from storage: {}", err);
match err {
GetError::LocalError(e) => {
return Err(e);
}
_ => unreachable!(),
}
}
}
}
tracing::info!("Read {:?} bytes from local storage", buf.len());
Ok(Some(buf))
})
.await?;

match buf {
Some(buf) => Ok(Arc::new(buf)),
None => {
// Buffer is empty. Nothing interesting to do.
Ok(Arc::new(vec![]))
}
}
}

pub(super) async fn get_stream(
&self,
key: &str,
) -> Result<Box<dyn Stream<Item = ByteStreamItem> + Unpin + Send>, String> {
Expand Down
Loading