Skip to content

Commit

Permalink
[ENH] Add query-service server (chroma-core#1899)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
- Improvements & Bug fixes
  - Add a new entrypoint for query-service. Delete the old worker entrypoint, which did not have read/write decoupling.
  - Make the dispatcher configurable
  - Wrap the HNSW orchestrator logic in `run()` so the server is unaware of it
  - Give the server struct the resources it needs
- New functionality
  - Add dynamic creation for log
  - Add dynamic creation for sysdb
  - Make the server respond to queries by using the orchestrator


## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
None
  • Loading branch information
HammadB authored and atroyn committed Apr 3, 2024
1 parent fb48996 commit b113e0b
Show file tree
Hide file tree
Showing 20 changed files with 289 additions and 65 deletions.
6 changes: 3 additions & 3 deletions rust/worker/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ version = "0.1.0"
edition = "2021"

[[bin]]
name = "worker"
path = "src/bin/worker.rs"
name = "query_service"
path = "src/bin/query_service.rs"

[dependencies]
tonic = "0.10"
Expand Down Expand Up @@ -46,4 +46,4 @@ proptest-state-machine = "0.1.0"

[build-dependencies]
tonic-build = "0.10"
cc = "1.0"
cc = "1.0"
4 changes: 4 additions & 0 deletions rust/worker/chroma_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,7 @@ worker:
Grpc:
host: "logservice.chroma"
port: 50052
dispatcher:
num_worker_threads: 4
dispatcher_queue_size: 100
worker_queue_size: 100
6 changes: 6 additions & 0 deletions rust/worker/src/bin/query_service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
use worker::query_service_entrypoint;

/// Binary entrypoint for the query service.
///
/// Runs on the runtime set up by `#[tokio::main]` and awaits the
/// `query_service_entrypoint` future exported by the `worker` crate.
#[tokio::main]
async fn main() {
    let service = query_service_entrypoint();
    service.await;
}
6 changes: 0 additions & 6 deletions rust/worker/src/bin/worker.rs

This file was deleted.

17 changes: 17 additions & 0 deletions rust/worker/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ pub(crate) struct WorkerConfig {
pub(crate) segment_manager: crate::segment::config::SegmentManagerConfig,
pub(crate) storage: crate::storage::config::StorageConfig,
pub(crate) log: crate::log::config::LogConfig,
pub(crate) dispatcher: crate::execution::config::DispatcherConfig,
}

/// # Description
Expand Down Expand Up @@ -165,6 +166,10 @@ mod tests {
Grpc:
host: "localhost"
port: 50052
dispatcher:
num_worker_threads: 4
dispatcher_queue_size: 100
worker_queue_size: 100
"#,
);
let config = RootConfig::load();
Expand Down Expand Up @@ -213,6 +218,10 @@ mod tests {
Grpc:
host: "localhost"
port: 50052
dispatcher:
num_worker_threads: 4
dispatcher_queue_size: 100
worker_queue_size: 100
"#,
);
Expand Down Expand Up @@ -277,6 +286,10 @@ mod tests {
Grpc:
host: "localhost"
port: 50052
dispatcher:
num_worker_threads: 4
dispatcher_queue_size: 100
worker_queue_size: 100
"#,
);
let config = RootConfig::load();
Expand Down Expand Up @@ -321,6 +334,10 @@ mod tests {
Grpc:
host: "localhost"
port: 50052
dispatcher:
num_worker_threads: 4
dispatcher_queue_size: 100
worker_queue_size: 100
"#,
);
let config = RootConfig::load();
Expand Down
2 changes: 1 addition & 1 deletion rust/worker/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ pub(crate) enum ErrorCodes {
DataLoss = 15,
}

pub(crate) trait ChromaError: Error {
pub(crate) trait ChromaError: Error + Send {
fn code(&self) -> ErrorCodes;
}
8 changes: 8 additions & 0 deletions rust/worker/src/execution/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use serde::Deserialize;

/// Configuration for the task dispatcher, deserialized from the `dispatcher`
/// section of the worker config.
/// # Members
/// - num_worker_threads: The number of worker threads the dispatcher spawns
/// - dispatcher_queue_size: The size of the dispatcher component's message queue
/// - worker_queue_size: The size of each worker thread component's message queue
#[derive(Deserialize)]
pub(crate) struct DispatcherConfig {
// Passed to `Dispatcher::new` as `n_worker_threads`.
pub(crate) num_worker_threads: usize,
// Passed to `Dispatcher::new` as `queue_size`.
pub(crate) dispatcher_queue_size: usize,
// Passed to `Dispatcher::new` as `worker_queue_size`.
pub(crate) worker_queue_size: usize,
}
45 changes: 31 additions & 14 deletions rust/worker/src/execution/dispatcher.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use super::{operator::TaskMessage, worker_thread::WorkerThread};
use crate::system::{Component, ComponentContext, Handler, Receiver, System};
use crate::{
config::{Configurable, WorkerConfig},
errors::ChromaError,
system::{Component, ComponentContext, Handler, Receiver, System},
};
use async_trait::async_trait;
use std::fmt::Debug;

Expand Down Expand Up @@ -46,21 +50,27 @@ use std::fmt::Debug;
coarser work-stealing, and other optimizations.
*/
#[derive(Debug)]
struct Dispatcher {
pub(crate) struct Dispatcher {
task_queue: Vec<TaskMessage>,
waiters: Vec<TaskRequestMessage>,
n_worker_threads: usize,
queue_size: usize,
worker_queue_size: usize,
}

impl Dispatcher {
/// Create a new dispatcher
/// # Parameters
/// - n_worker_threads: The number of worker threads to use
pub fn new(n_worker_threads: usize) -> Self {
/// - queue_size: The size of the component's message queue
/// - worker_queue_size: The size of each worker component's message queue
pub fn new(n_worker_threads: usize, queue_size: usize, worker_queue_size: usize) -> Self {
Dispatcher {
task_queue: Vec::new(),
waiters: Vec::new(),
n_worker_threads,
queue_size,
worker_queue_size,
}
}

Expand All @@ -74,7 +84,7 @@ impl Dispatcher {
self_receiver: Box<dyn Receiver<TaskRequestMessage>>,
) {
for _ in 0..self.n_worker_threads {
let worker = WorkerThread::new(self_receiver.clone());
let worker = WorkerThread::new(self_receiver.clone(), self.worker_queue_size);
system.start_component(worker);
}
}
Expand Down Expand Up @@ -118,6 +128,17 @@ impl Dispatcher {
}
}

#[async_trait]
impl Configurable for Dispatcher {
    /// Build a dispatcher from the `dispatcher` section of the worker config.
    /// # Parameters
    /// - worker_config: The worker configuration to read the dispatcher settings from
    /// # Returns
    /// Always `Ok` with a dispatcher constructed via `Dispatcher::new`.
    async fn try_from_config(worker_config: &WorkerConfig) -> Result<Self, Box<dyn ChromaError>> {
        let settings = &worker_config.dispatcher;
        let dispatcher = Dispatcher::new(
            settings.num_worker_threads,
            settings.dispatcher_queue_size,
            settings.worker_queue_size,
        );
        Ok(dispatcher)
    }
}

/// A message that a worker thread sends to the dispatcher to request a task
/// # Members
/// - reply_to: The receiver to send the task to, this is the worker thread
Expand All @@ -141,7 +162,7 @@ impl TaskRequestMessage {
#[async_trait]
impl Component for Dispatcher {
fn queue_size(&self) -> usize {
1000 // TODO: make configurable
self.queue_size
}

async fn on_start(&mut self, ctx: &ComponentContext<Self>) {
Expand All @@ -166,19 +187,15 @@ impl Handler<TaskRequestMessage> for Dispatcher {

#[cfg(test)]
mod tests {
use std::{
env::current_dir,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};

use super::*;
use crate::{
execution::operator::{wrap, Operator},
system::System,
};
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};

// Create a component that will schedule DISPATCH_COUNT invocations of the MockOperator
// on an interval of DISPATCH_FREQUENCY_MS.
Expand Down Expand Up @@ -249,7 +266,7 @@ mod tests {
#[tokio::test]
async fn test_dispatcher() {
let mut system = System::new();
let dispatcher = Dispatcher::new(THREAD_COUNT);
let dispatcher = Dispatcher::new(THREAD_COUNT, 1000, 1000);
let dispatcher_handle = system.start_component(dispatcher);
let counter = Arc::new(AtomicUsize::new(0));
let dispatch_user = MockDispatchUser {
Expand Down
7 changes: 4 additions & 3 deletions rust/worker/src/execution/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod dispatcher;
mod operator;
pub(crate) mod config;
pub(crate) mod dispatcher;
pub(crate) mod operator;
mod operators;
mod orchestration;
pub(crate) mod orchestration;
mod worker_thread;
4 changes: 2 additions & 2 deletions rust/worker/src/execution/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ where
}

/// A message type used by the dispatcher to send tasks to worker threads.
pub(super) type TaskMessage = Box<dyn TaskWrapper>;
pub(crate) type TaskMessage = Box<dyn TaskWrapper>;

/// A task wrapper is a trait that can be used to run a task. We use it to
/// erase the I, O types from the Task struct so that tasks with different
/// input and output types can be sent as a single message type.
#[async_trait]
pub(super) trait TaskWrapper: Send + Debug {
pub(crate) trait TaskWrapper: Send + Debug {
async fn run(&self);
}

Expand Down
48 changes: 45 additions & 3 deletions rust/worker/src/execution/orchestration/hnsw.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use super::super::operator::{wrap, TaskMessage};
use super::super::operators::pull_log::{PullLogsInput, PullLogsOperator, PullLogsOutput};
use crate::errors::ChromaError;
use crate::sysdb::sysdb::SysDb;
use crate::system::System;
use crate::types::VectorQueryResult;
use crate::{
log::log::Log,
system::{Component, Handler, Receiver},
};
use async_trait::async_trait;
use std::fmt::{self, Debug, Formatter};
use num_bigint::BigInt;
use std::fmt::Debug;
use uuid::Uuid;

/** The state of the orchestrator.
Expand Down Expand Up @@ -35,8 +39,10 @@ enum ExecutionState {
}

#[derive(Debug)]
struct HnswQueryOrchestrator {
pub(crate) struct HnswQueryOrchestrator {
state: ExecutionState,
// Component Execution
system: System,
// Query state
query_vectors: Vec<Vec<f32>>,
k: i32,
Expand All @@ -46,10 +52,15 @@ struct HnswQueryOrchestrator {
log: Box<dyn Log>,
sysdb: Box<dyn SysDb>,
dispatcher: Box<dyn Receiver<TaskMessage>>,
// Result channel
result_channel: Option<
tokio::sync::oneshot::Sender<Result<Vec<Vec<VectorQueryResult>>, Box<dyn ChromaError>>>,
>,
}

impl HnswQueryOrchestrator {
pub fn new(
pub(crate) fn new(
system: System,
query_vectors: Vec<Vec<f32>>,
k: i32,
include_embeddings: bool,
Expand All @@ -60,13 +71,15 @@ impl HnswQueryOrchestrator {
) -> Self {
HnswQueryOrchestrator {
state: ExecutionState::Pending,
system,
query_vectors,
k,
include_embeddings,
segment_id,
log,
sysdb,
dispatcher,
result_channel: None,
}
}

Expand Down Expand Up @@ -108,6 +121,19 @@ impl HnswQueryOrchestrator {
}
}
}

/// Start the orchestrator as a component and wait for its result.
/// # Note
/// Prefer this over spawning the component directly: it installs the result
/// channel, starts the component, waits for a result to arrive and then stops
/// the component before returning.
/// # Panics
/// Panics if the result sender is dropped before a result has been sent,
/// because the receive error is unwrapped.
pub(crate) async fn run(mut self) -> Result<Vec<Vec<VectorQueryResult>>, Box<dyn ChromaError>> {
    let (result_sender, result_receiver) = tokio::sync::oneshot::channel();
    self.result_channel = Some(result_sender);
    // The system handle is cloned first because starting the component consumes `self`.
    let mut component_handle = self.system.clone().start_component(self);
    let outcome = result_receiver.await;
    component_handle.stop();
    outcome.unwrap()
}
}

// ============== Component Implementation ==============
Expand All @@ -133,6 +159,22 @@ impl Handler<PullLogsOutput> for HnswQueryOrchestrator {
ctx: &crate::system::ComponentContext<HnswQueryOrchestrator>,
) {
self.state = ExecutionState::Dedupe;

// TODO: implement the remaining state transitions and operators
// This is an example of the final state transition and result

match self.result_channel.take() {
Some(tx) => {
let _ = tx.send(Ok(vec![vec![VectorQueryResult {
id: "abc".to_string(),
seq_id: BigInt::from(0),
distance: 0.0,
vector: Some(vec![0.0, 0.0, 0.0]),
}]]));
}
None => {
// Log an error
}
}
}
}
2 changes: 2 additions & 0 deletions rust/worker/src/execution/orchestration/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
mod hnsw;

pub(crate) use hnsw::*;
13 changes: 10 additions & 3 deletions rust/worker/src/execution/worker_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,18 @@ use std::fmt::{Debug, Formatter, Result};
/// - The actor loop will block until work is available
pub(super) struct WorkerThread {
dispatcher: Box<dyn Receiver<TaskRequestMessage>>,
queue_size: usize,
}

impl WorkerThread {
pub(super) fn new(dispatcher: Box<dyn Receiver<TaskRequestMessage>>) -> Self {
WorkerThread { dispatcher }
pub(super) fn new(
dispatcher: Box<dyn Receiver<TaskRequestMessage>>,
queue_size: usize,
) -> WorkerThread {
WorkerThread {
dispatcher,
queue_size,
}
}
}

Expand All @@ -26,7 +33,7 @@ impl Debug for WorkerThread {
#[async_trait]
impl Component for WorkerThread {
fn queue_size(&self) -> usize {
1000 // TODO: make configurable
self.queue_size
}

fn runtime() -> ComponentRuntime {
Expand Down
Loading

0 comments on commit b113e0b

Please sign in to comment.