Skip to content

Commit

Permalink
error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Mar 20, 2024
1 parent b28bcda commit 5c8f1c3
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 65 deletions.
2 changes: 1 addition & 1 deletion rust/worker/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ pub(crate) enum ErrorCodes {
DataLoss = 15,
}

pub(crate) trait ChromaError: Error {
pub(crate) trait ChromaError: Error + Send {
fn code(&self) -> ErrorCodes;
}
28 changes: 19 additions & 9 deletions rust/worker/src/execution/orchestration/hnsw.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use super::super::operator::{wrap, TaskMessage};
use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator, PullLogsOutput};
use crate::errors::ChromaError;
use crate::sysdb::sysdb::SysDb;
use crate::system::System;
use crate::types::VectorQueryResult;
use crate::{
log::log::Log,
system::{Component, Handler, Receiver},
};
use async_trait::async_trait;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use num_bigint::BigInt;
use std::fmt::Debug;
use uuid::Uuid;

/** The state of the orchestrator.
Expand Down Expand Up @@ -50,8 +52,10 @@ pub(crate) struct HnswQueryOrchestrator {
log: Box<dyn Log>,
sysdb: Box<dyn SysDb>,
dispatcher: Box<dyn Receiver<TaskMessage>>,
// Result container. TODO: This should be VectorQueryResult
result_channel: Option<tokio::sync::oneshot::Sender<String>>,
// Result channel
result_channel: Option<
tokio::sync::oneshot::Sender<Result<Vec<Vec<VectorQueryResult>>, Box<dyn ChromaError>>>,
>,
}

impl HnswQueryOrchestrator {
Expand Down Expand Up @@ -122,8 +126,7 @@ impl HnswQueryOrchestrator {
/// # Note
/// Use this over spawning the component directly. This method will start the component and
/// wait for it to finish before returning the result.
/// RESUME POINT: RETURN THE CORRECT TYPE HERE
pub(crate) async fn run(mut self) -> String {
pub(crate) async fn run(mut self) -> Result<Vec<Vec<VectorQueryResult>>, Box<dyn ChromaError>> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.result_channel = Some(tx);
let mut handle = self.system.clone().start_component(self);
Expand Down Expand Up @@ -156,15 +159,22 @@ impl Handler<PullLogsOutput> for HnswQueryOrchestrator {
ctx: &crate::system::ComponentContext<HnswQueryOrchestrator>,
) {
self.state = ExecutionState::Dedupe;

// TODO: implement the remaining state transitions and operators
// This is an example of the final state transition and result

match self.result_channel.take() {
Some(tx) => {
let _ = tx.send("done".to_string());
let _ = tx.send(Ok(vec![vec![VectorQueryResult {
id: "abc".to_string(),
seq_id: BigInt::from(0),
distance: 0.0,
vector: Some(vec![0.0, 0.0, 0.0]),
}]]));
}
None => {
// Log an error
}
}
// TODO: implement the remaining state transitions and operators
// The query orchestrator kills itself in the last state
}
}
99 changes: 44 additions & 55 deletions rust/worker/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,77 +178,66 @@ impl chroma_proto::vector_reader_server::VectorReader for WorkerServer {
}
};

match self.system {
let result = match self.system {
Some(ref system) => {
let orchestrator = HnswQueryOrchestrator::new(
// TODO: Should not have to clone query vectors here
system.clone(),
query_vectors,
query_vectors.clone(),
request.k,
request.include_embeddings,
segment_uuid,
self.log.clone(),
self.sysdb.clone(),
dispatcher.clone(),
);
let result: String = orchestrator.run().await;
orchestrator.run().await
}
None => {
return Err(Status::internal("No system found"));
}
}

// for proto_query_vector in request.vectors {
// let (query_vector, encoding) = match proto_query_vector.try_into() {
// Ok((vector, encoding)) => (vector, encoding),
// Err(e) => {
// return Err(Status::internal(format!("Error converting vector: {}", e)));
// }
// };

// let results = match segment_manager
// .query_vector(
// &segment_uuid,
// &query_vector,
// request.k as usize,
// request.include_embeddings,
// )
// .await
// {
// Ok(results) => results,
// Err(e) => {
// return Err(Status::internal(format!("Error querying segment: {}", e)));
// }
// };
};

// let mut proto_results = Vec::new();
// for query_result in results {
// let proto_result = chroma_proto::VectorQueryResult {
// id: query_result.id,
// seq_id: query_result.seq_id.to_bytes_le().1,
// distance: query_result.distance,
// vector: match query_result.vector {
// Some(vector) => {
// match (vector, ScalarEncoding::FLOAT32, query_vector.len()).try_into() {
// Ok(proto_vector) => Some(proto_vector),
// Err(e) => {
// return Err(Status::internal(format!(
// "Error converting vector: {}",
// e
// )));
// }
// }
// }
// None => None,
// },
// };
// proto_results.push(proto_result);
// }
let result = match result {
Ok(result) => result,
Err(e) => {
return Err(Status::internal(format!(
"Error running orchestrator: {}",
e
)));
}
};

// let vector_query_results = chroma_proto::VectorQueryResults {
// results: proto_results,
// };
// proto_results_for_all.push(vector_query_results);
// }
for result_set in result {
let mut proto_results = Vec::new();
for query_result in result_set {
let proto_result = chroma_proto::VectorQueryResult {
id: query_result.id,
seq_id: query_result.seq_id.to_bytes_le().1,
distance: query_result.distance,
vector: match query_result.vector {
Some(vector) => {
match (vector, ScalarEncoding::FLOAT32, query_vectors[0].len())
.try_into()
{
Ok(proto_vector) => Some(proto_vector),
Err(e) => {
return Err(Status::internal(format!(
"Error converting vector: {}",
e
)));
}
}
}
None => None,
},
};
proto_results.push(proto_result);
}
proto_results_for_all.push(chroma_proto::VectorQueryResults {
results: proto_results,
});
}

let resp = chroma_proto::QueryVectorsResponse {
results: proto_results_for_all,
Expand Down

0 comments on commit 5c8f1c3

Please sign in to comment.