Skip to content

Commit

Permalink
consume NAC
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Aug 3, 2024
1 parent ac57655 commit 648e0f8
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 115 deletions.
90 changes: 33 additions & 57 deletions rust/blockstore/src/arrow/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ use async_trait::async_trait;
use chroma_cache::cache::Cache;
use chroma_config::Configurable;
use chroma_error::{ChromaError, ErrorCodes};
use chroma_storage::{network_admission_control::NetworkAdmissionControl, Storage};
use chroma_storage::{
network_admission_control::{NetworkAdmissionControl, NetworkAdmissionControlError},
Storage,
};
use core::panic;
use futures::StreamExt;
use thiserror::Error;
Expand Down Expand Up @@ -233,66 +236,39 @@ impl BlockManager {
match block {
Some(block) => Some(block.clone()),
None => {
// TODO: NAC register/deregister/validation goes here.
async {
let key = format!("block/{}", id);
let stream = self.storage.get(&key).instrument(
tracing::trace_span!(parent: Span::current(), "BlockManager storage get"),
).await;
match stream {
Ok(mut bytes) => {
let read_block_span = tracing::trace_span!(parent: Span::current(), "BlockManager read bytes to end");
let buf = read_block_span.in_scope(|| async {
let mut buf: Vec<u8> = Vec::new();
while let Some(res) = bytes.next().await {
match res {
Ok(chunk) => {
buf.extend(chunk);
}
Err(e) => {
tracing::error!("Error reading block from storage: {}", e);
return None;
}
}
}
Some(buf)
}
).await;
let buf = match buf {
Some(buf) => {
buf
}
None => {
return None;
}
};
tracing::info!("Read {:?} bytes from s3", buf.len());
let deserialization_span = tracing::trace_span!(parent: Span::current(), "BlockManager deserialize block");
let block = deserialization_span.in_scope(|| Block::from_bytes(&buf, *id));
match block {
Ok(block) => {
self.block_cache.insert(*id, block.clone());
Some(block)
}
Err(e) => {
// TODO: Return an error to callsite instead of None.
tracing::error!(
"Error converting bytes to Block {:?}/{:?}",
key,
e
);
None
}
}
},
let key = format!("block/{}", id);
        // Cloning the cache is cheap since it is an Arc type.
let cache_clone = self.block_cache.clone();
let id_copy = id.clone();
let cb = move |buf: Vec<u8>| async move {
let deserialization_span = tracing::trace_span!(parent: Span::current(), "BlockManager deserialize block");
let block = deserialization_span.in_scope(|| Block::from_bytes(&buf, id_copy));
match block {
Ok(block) => {
cache_clone.insert(id_copy, block.clone());
}
Err(e) => {
tracing::error!("Error reading block from storage: {}", e);
None
tracing::error!("Error converting bytes to Block {:?}", e);
return Err(Box::new(
NetworkAdmissionControlError::DeserializationError,
));
}
}
Ok(())
};
match self.network_admission_control.get(key, cb).await {
Ok(()) => {}
Err(e) => {
// TODO: Return error here.
tracing::error!(
"Error getting block from the network admission control {}",
e
);
return None;
}
}
.instrument(tracing::trace_span!(parent: Span::current(), "BlockManager get cold"))
.await
// Cache must be populated now.
self.block_cache.get(id)
}
}
}
Expand Down
87 changes: 29 additions & 58 deletions rust/index/src/hnsw_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use chroma_cache::cache::Cache;
use chroma_config::Configurable;
use chroma_error::ChromaError;
use chroma_error::ErrorCodes;
use chroma_storage::network_admission_control::NetworkAdmissionControl;
use chroma_storage::network_admission_control::{
NetworkAdmissionControl, NetworkAdmissionControlError,
};
use chroma_storage::stream::ByteStreamItem;
use chroma_storage::Storage;
use chroma_types::Segment;
Expand Down Expand Up @@ -162,78 +164,49 @@ impl HnswIndexProvider {
source_id: &Uuid,
index_storage_path: &Path,
) -> Result<(), Box<HnswIndexProviderFileError>> {
// Fetch the files from storage and put them in the index storage path
// Fetch the files from storage and put them in the index storage path.
// TODO: Fetch multiple chunks in parallel from S3.
for file in FILES.iter() {
let key = self.format_key(source_id, file);
tracing::info!("Loading hnsw index file: {}", key);
let stream = self.storage.get(&key).await;
let reader = match stream {
Ok(reader) => reader,
Err(e) => {
tracing::error!("Failed to load hnsw index file from storage: {}", e);
return Err(Box::new(HnswIndexProviderFileError::StorageGetError(e)));
}
};

let file_path = index_storage_path.join(file);
// For now, we never evict from the cache, so if the index is being loaded, the file does not exist
let file_handle = tokio::fs::File::create(&file_path).await;
let file_handle = match file_handle {
let mut file_handle = match file_handle {
Ok(file) => file,
Err(e) => {
tracing::error!("Failed to create file: {}", e);
return Err(Box::new(HnswIndexProviderFileError::IOError(e)));
}
};
let total_bytes_written = self
.copy_stream_to_local_file(reader, file_handle)
.instrument(tracing::info_span!(parent: Span::current(), "hnsw provider file read", file = file))
.await?;
tracing::info!(
"Copied {} bytes from storage key: {} to file: {}",
total_bytes_written,
key,
file_path.to_str().unwrap()
);
        // bytes is an AsyncBufRead, so we fill and consume it into a file
tracing::info!("Loaded hnsw index file: {}", file);
}
Ok(())
}

async fn copy_stream_to_local_file(
&self,
stream: Box<dyn stream::Stream<Item = ByteStreamItem> + Unpin + Send>,
file_handle: tokio::fs::File,
) -> Result<u64, Box<HnswIndexProviderFileError>> {
let mut total_bytes_written = 0;
let mut file_handle = file_handle;
let mut stream = stream;
while let Some(res) = stream.next().await {
let chunk = match res {
Ok(chunk) => chunk,
Err(e) => {
return Err(Box::new(HnswIndexProviderFileError::StorageGetError(e)));
let cb = move |buf: Vec<u8>| async move {
let res = file_handle.write_all(&buf).await;
match res {
Ok(_) => {}
Err(e) => {
tracing::error!("Failed to copy file: {}", e);
return Err(Box::new(NetworkAdmissionControlError::IOError));
}
}
};

let res = file_handle.write_all(&chunk).await;
match res {
Ok(_) => {
total_bytes_written += chunk.len() as u64;
match file_handle.flush().await {
Ok(_) => {}
Err(e) => {
tracing::error!("Failed to flush file: {}", e);
return Err(Box::new(NetworkAdmissionControlError::IOError));
}
}
Ok(())
};
match self.network_admission_control.get(key, cb).await {
Ok(_) => {}
Err(e) => {
tracing::error!("Failed to copy file: {}", e);
return Err(Box::new(HnswIndexProviderFileError::IOError(e)));
return Err(Box::new(HnswIndexProviderFileError::NACError(*e)));
}
}
tracing::info!("Loaded hnsw index file: {}", file);
}
match file_handle.flush().await {
Ok(_) => Ok(total_bytes_written),
Err(e) => {
return Err(Box::new(HnswIndexProviderFileError::IOError(e)));
}
}
Ok(())
}

pub async fn open(
Expand Down Expand Up @@ -492,10 +465,8 @@ impl ChromaError for HnswIndexProviderFlushError {
pub enum HnswIndexProviderFileError {
#[error("IO Error")]
IOError(#[from] std::io::Error),
#[error("Storage Get Error")]
StorageGetError(#[from] chroma_storage::GetError),
#[error("Storage Put Error")]
StoragePutError(#[from] chroma_storage::PutError),
#[error("NAC Error")]
NACError(#[from] NetworkAdmissionControlError),
}

#[cfg(test)]
Expand Down

0 comments on commit 648e0f8

Please sign in to comment.