Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] make hnsw query orchestrator use BF operator #1927

Merged
merged 2 commits into from
Mar 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 67 additions & 14 deletions rust/worker/src/execution/orchestration/hnsw.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
use super::super::operator::{wrap, TaskMessage};
use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator, PullLogsOutput};
use crate::distance;
use crate::distance::DistanceFunction;
use crate::errors::ChromaError;
use crate::execution::operators::brute_force_knn::{
BruteForceKnnOperator, BruteForceKnnOperatorInput, BruteForceKnnOperatorOutput,
BruteForceKnnOperatorResult,
};
use crate::execution::operators::pull_log::PullLogsResult;
use crate::sysdb::sysdb::SysDb;
use crate::system::System;
Expand All @@ -12,7 +18,6 @@ use crate::{
use async_trait::async_trait;
use num_bigint::BigInt;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;

Expand Down Expand Up @@ -118,7 +123,7 @@ impl HnswQueryOrchestrator {
let end_timestamp = SystemTime::now().duration_since(UNIX_EPOCH);
let end_timestamp = match end_timestamp {
// TODO: change protobuf definition to use u64 instead of i64
Ok(end_timestamp) => end_timestamp.as_nanos() as i64,
Ok(end_timestamp) => end_timestamp.as_secs() as i64,
Err(e) => {
// Log an error and reply + return
return;
Expand Down Expand Up @@ -173,8 +178,45 @@ impl Handler<PullLogsResult> for HnswQueryOrchestrator {
self.state = ExecutionState::Dedupe;

// TODO: implement the remaining state transitions and operators
// This is an example of the final state transition and result
// TODO: remove this cloning and data shuffling once the chunk abstraction lands
let mut dataset = Vec::new();
match message {
Ok(logs) => {
for log in logs.logs().iter() {
// TODO: only adds have embeddings, so unwrap is fine for now
dataset.push(log.embedding.clone().unwrap());
}
let bf_input = BruteForceKnnOperatorInput {
data: dataset,
query: self.query_vectors[0].clone(),
k: self.k as usize,
distance_metric: DistanceFunction::Euclidean,
};
let operator = Box::new(BruteForceKnnOperator {});
let task = wrap(operator, bf_input, ctx.sender.as_receiver());
match self.dispatcher.send(task).await {
Ok(_) => (),
Err(e) => {
// TODO: log an error and reply to caller
}
}
}
Err(e) => {
// Log an error
return;
}
}
}
}

#[async_trait]
impl Handler<BruteForceKnnOperatorResult> for HnswQueryOrchestrator {
async fn handle(
&mut self,
message: BruteForceKnnOperatorResult,
ctx: &crate::system::ComponentContext<HnswQueryOrchestrator>,
) {
// This is an example of the final state transition and result
let result_channel = match self.result_channel.take() {
Some(tx) => tx,
None => {
Expand All @@ -184,18 +226,29 @@ impl Handler<PullLogsResult> for HnswQueryOrchestrator {
};

match message {
Ok(logs) => {
// TODO: remove this after debugging
println!("Received logs: {:?}", logs);
let _ = result_channel.send(Ok(vec![vec![VectorQueryResult {
id: "abc".to_string(),
seq_id: BigInt::from(0),
distance: 0.0,
vector: Some(vec![0.0, 0.0, 0.0]),
}]]));
Ok(output) => {
let mut result = Vec::new();
let mut query_results = Vec::new();
for (index, distance) in output.indices.iter().zip(output.distances.iter()) {
let query_result = VectorQueryResult {
id: index.to_string(),
seq_id: BigInt::from(0),
distance: *distance,
vector: None,
};
query_results.push(query_result);
}
result.push(query_results);

match result_channel.send(Ok(result)) {
Ok(_) => (),
Err(e) => {
// Log an error
}
}
}
Err(e) => {
let _ = result_channel.send(Err(Box::new(e)));
Err(_) => {
// Log an error
}
}
}
Expand Down
Loading