From fe9e0f3aaca9582352130b462f391a3047841a6c Mon Sep 17 00:00:00 2001 From: hammadb Date: Wed, 20 Mar 2024 08:09:59 -0700 Subject: [PATCH] Add error handling --- rust/worker/src/errors.rs | 1 - rust/worker/src/execution/dispatcher.rs | 13 +++++++--- rust/worker/src/execution/operator.rs | 21 +++++++++------ .../src/execution/operators/pull_log.rs | 16 ++++++++---- .../src/execution/orchestration/hnsw.rs | 26 +++++++++++++------ 5 files changed, 51 insertions(+), 26 deletions(-) diff --git a/rust/worker/src/errors.rs b/rust/worker/src/errors.rs index 968dbaeeb17..086b938f265 100644 --- a/rust/worker/src/errors.rs +++ b/rust/worker/src/errors.rs @@ -1,7 +1,6 @@ // Defines 17 standard error codes based on the error codes defined in the // gRPC spec. https://grpc.github.io/grpc/core/md_doc_statuscodes.html // Custom errors can use these codes in order to allow for generic handling - use std::error::Error; #[derive(PartialEq, Debug)] diff --git a/rust/worker/src/execution/dispatcher.rs b/rust/worker/src/execution/dispatcher.rs index 8a25e0b26fe..b1668b1c60a 100644 --- a/rust/worker/src/execution/dispatcher.rs +++ b/rust/worker/src/execution/dispatcher.rs @@ -210,13 +210,14 @@ mod tests { struct MockOperator {} #[async_trait] impl Operator for MockOperator { - async fn run(&self, input: &f32) -> String { + type Error = (); + async fn run(&self, input: &f32) -> Result { // sleep to simulate work tokio::time::sleep(tokio::time::Duration::from_millis( MOCK_OPERATOR_SLEEP_DURATION_MS, )) .await; - input.to_string() + Ok(input.to_string()) } } @@ -244,8 +245,12 @@ mod tests { } } #[async_trait] - impl Handler for MockDispatchUser { - async fn handle(&mut self, message: String, ctx: &ComponentContext) { + impl Handler> for MockDispatchUser { + async fn handle( + &mut self, + message: Result, + ctx: &ComponentContext, + ) { self.counter.fetch_add(1, Ordering::SeqCst); let curr_count = self.counter.load(Ordering::SeqCst); // Cancel self diff --git a/rust/worker/src/execution/operator.rs b/rust/worker/src/execution/operator.rs index 85baa7d8c7d..935c01eb16e 100644 --- a/rust/worker/src/execution/operator.rs +++ b/rust/worker/src/execution/operator.rs @@ -10,20 +10,23 @@ where I: Send + Sync, O: Send + Sync, { - async fn run(&self, input: &I) -> O; + type Error; + // It would have been nice to do this with a default trait for result + // but that's not stable in rust yet. + async fn run(&self, input: &I) -> Result; } /// A task is a wrapper around an operator and its input. /// It is a description of a function to be run. #[derive(Debug)] -struct Task +struct Task where Input: Send + Sync + Debug, Output: Send + Sync + Debug, { - operator: Box>, + operator: Box>, input: Input, - reply_channel: Box>, + reply_channel: Box>>, } /// A message type used by the dispatcher to send tasks to worker threads. @@ -40,8 +43,9 @@ pub(crate) trait TaskWrapper: Send + Debug { /// erase the I, O types from the Task struct so that tasks can be /// stored in a homogenous queue regardless of their input and output types. #[async_trait] -impl TaskWrapper for Task +impl TaskWrapper for Task where + Error: Debug, Input: Send + Sync + Debug, Output: Send + Sync + Debug, { @@ -53,12 +57,13 @@ where } /// Wrap an operator and its input into a task message. -pub(super) fn wrap( - operator: Box>, +pub(super) fn wrap( + operator: Box>, input: Input, - reply_channel: Box>, + reply_channel: Box>>, ) -> TaskMessage where + Error: Debug + 'static, Input: Send + Sync + Debug + 'static, Output: Send + Sync + Debug + 'static, { diff --git a/rust/worker/src/execution/operators/pull_log.rs b/rust/worker/src/execution/operators/pull_log.rs index 9e38f919505..7fb150fd34c 100644 --- a/rust/worker/src/execution/operators/pull_log.rs +++ b/rust/worker/src/execution/operators/pull_log.rs @@ -1,4 +1,8 @@ -use crate::{execution::operator::Operator, log::log::Log, types::EmbeddingRecord}; +use crate::{ + execution::operator::Operator, + log::log::{Log, PullLogsError}, + types::EmbeddingRecord, +}; use async_trait::async_trait; use uuid::Uuid; @@ -66,9 +70,12 @@ impl PullLogsOutput { } } +pub type PullLogsResult = Result; + #[async_trait] impl Operator for PullLogsOperator { - async fn run(&self, input: &PullLogsInput) -> PullLogsOutput { + type Error = PullLogsError; + async fn run(&self, input: &PullLogsInput) -> PullLogsResult { // We expect the log to be cheaply cloneable, we need to clone it since we need // a mutable reference to it. Not necessarily the best, but it works for our needs. let mut client_clone = self.client.clone(); @@ -79,8 +86,7 @@ impl Operator for PullLogsOperator { input.batch_size, None, ) - .await - .unwrap(); - PullLogsOutput::new(logs) + .await?; + Ok(PullLogsOutput::new(logs)) } } diff --git a/rust/worker/src/execution/orchestration/hnsw.rs b/rust/worker/src/execution/orchestration/hnsw.rs index d25862bc9e1..35c4134c940 100644 --- a/rust/worker/src/execution/orchestration/hnsw.rs +++ b/rust/worker/src/execution/orchestration/hnsw.rs @@ -1,6 +1,8 @@ use super::super::operator::{wrap, TaskMessage}; use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator, PullLogsOutput}; use crate::errors::ChromaError; +use crate::execution::operators::pull_log::PullLogsResult; +use crate::log::log::PullLogsError; use crate::sysdb::sysdb::SysDb; use crate::system::System; use crate::types::VectorQueryResult; @@ -102,7 +104,7 @@ impl HnswQueryOrchestrator { } } - async fn pull_logs(&mut self, self_address: Box>) { + async fn pull_logs(&mut self, self_address: Box>) { self.state = ExecutionState::PullLogs; let operator = PullLogsOperator::new(self.log.clone()); let collection_id = match self.get_collection_id_for_segment_id(self.segment_id).await { @@ -152,10 +154,10 @@ impl Component for HnswQueryOrchestrator { // ============== Handlers ============== #[async_trait] -impl Handler for HnswQueryOrchestrator { +impl Handler for HnswQueryOrchestrator { async fn handle( &mut self, - message: PullLogsOutput, + message: PullLogsResult, ctx: &crate::system::ComponentContext, ) { self.state = ExecutionState::Dedupe; @@ -163,17 +165,25 @@ impl Handler for HnswQueryOrchestrator { // TODO: implement the remaining state transitions and operators // This is an example of the final state transition and result - match self.result_channel.take() { - Some(tx) => { - let _ = tx.send(Ok(vec![vec![VectorQueryResult { + let result_channel = match self.result_channel.take() { + Some(tx) => tx, + None => { + // Log an error + return; + } + }; + + match message { + Ok(logs) => { + let _ = result_channel.send(Ok(vec![vec![VectorQueryResult { id: "abc".to_string(), seq_id: BigInt::from(0), distance: 0.0, vector: Some(vec![0.0, 0.0, 0.0]), }]])); } - None => { - // Log an error + Err(e) => { + let _ = result_channel.send(Err(Box::new(e))); } } }