Skip to content

Commit

Permalink
Review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Aug 1, 2024
1 parent bf84e8b commit e087224
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 32 deletions.
126 changes: 98 additions & 28 deletions rust/worker/src/execution/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ impl Dispatcher {
/// - task: The task to enqueue
async fn enqueue_task(&mut self, task: TaskMessage) {
match task.get_type() {
OperatorType::IoOperatorType => {
OperatorType::IoOperator => {
tokio::spawn(async move {
task.run().await;
});
}
OperatorType::OtherType => {
OperatorType::Other => {
// If a worker is waiting for a task, send it to the worker in FIFO order
// Otherwise, add it to the task queue
match self.waiters.pop() {
Expand All @@ -113,7 +113,7 @@ impl Dispatcher {
{
Ok(_) => {}
Err(e) => {
println!("Error sending task to worker: {:?}", e);
tracing::error!("Error sending task to worker: {:?}", e);
}
},
None => {
Expand Down Expand Up @@ -294,7 +294,70 @@ mod tests {
}

fn get_type(&self) -> OperatorType {
OperatorType::IoOperatorType
OperatorType::IoOperator
}
}

#[derive(Debug)]
struct MockIoDispatchUser {
pub dispatcher: ComponentHandle<Dispatcher>,
counter: Arc<AtomicUsize>, // We expect to recieve DISPATCH_COUNT messages
sent_tasks: Arc<Mutex<HashSet<Uuid>>>,
received_tasks: Arc<Mutex<HashSet<Uuid>>>,
}
#[async_trait]
impl Component for MockIoDispatchUser {
fn get_name() -> &'static str {
"Mock Io dispatcher"
}

fn queue_size(&self) -> usize {
1000
}

async fn on_start(&mut self, ctx: &ComponentContext<Self>) {
// dispatch a new task every DISPATCH_FREQUENCY_MS for DISPATCH_COUNT times
let duration = std::time::Duration::from_millis(DISPATCH_FREQUENCY_MS);
ctx.scheduler
.schedule_interval((), duration, Some(DISPATCH_COUNT), ctx);
}
}
#[async_trait]
impl Handler<TaskResult<String, ()>> for MockIoDispatchUser {
type Result = ();

async fn handle(
&mut self,
_message: TaskResult<String, ()>,
ctx: &ComponentContext<MockIoDispatchUser>,
) {
self.counter.fetch_add(1, Ordering::SeqCst);
let curr_count = self.counter.load(Ordering::SeqCst);
// Cancel self
if curr_count == DISPATCH_COUNT {
ctx.cancellation_token.cancel();
}
self.received_tasks.lock().insert(_message.id());
}
}

#[async_trait]
impl Handler<()> for MockIoDispatchUser {
type Result = ();

async fn handle(&mut self, _message: (), ctx: &ComponentContext<MockIoDispatchUser>) {
let rng = rand::thread_rng();
// Generate a random filename for writing and reading.
let filename = rng
.sample_iter(&Alphanumeric)
.take(5)
.map(char::from)
.collect();
println!("Scheduling mock io operator with filename {}", filename);
let task = wrap(Box::new(MockIoOperator {}), filename, ctx.receiver());
let task_id = task.id();
self.sent_tasks.lock().insert(task_id);
let res = self.dispatcher.send(task, None).await;
}
}

Expand Down Expand Up @@ -346,37 +409,44 @@ mod tests {
type Result = ();

async fn handle(&mut self, _message: (), ctx: &ComponentContext<MockDispatchUser>) {
// Randomly choose between IO task and other task.
let should_io;
let mut filename = String::from("dummy");
{
let mut rng = rand::thread_rng();
should_io = rng.gen_bool(1.0 / 2.0);
// Generate a random filename for writing and reading.
if should_io {
filename = rng
.sample_iter(&Alphanumeric)
.take(5)
.map(char::from)
.collect();
}
}
let task;
if should_io {
println!("Scheduling mock io operator with filename {}", filename);
task = wrap(Box::new(MockIoOperator {}), filename, ctx.receiver());
} else {
println!("Scheduling mock cpu operator with input {}", 42.0);
task = wrap(Box::new(MockOperator {}), 42.0, ctx.receiver());
}
println!("Scheduling mock cpu operator with input {}", 42.0);
let task = wrap(Box::new(MockOperator {}), 42.0, ctx.receiver());
let task_id = task.id();
self.sent_tasks.lock().insert(task_id);
let res = self.dispatcher.send(task, None).await;
}
}

#[tokio::test]
async fn test_dispatcher() {
async fn test_dispatcher_io_tasks() {
let system = System::new();
let dispatcher = Dispatcher::new(THREAD_COUNT, 1000, 1000);
let dispatcher_handle = system.start_component(dispatcher);
let counter = Arc::new(AtomicUsize::new(0));
let sent_tasks = Arc::new(Mutex::new(HashSet::new()));
let received_tasks = Arc::new(Mutex::new(HashSet::new()));
let dispatch_user = MockIoDispatchUser {
dispatcher: dispatcher_handle,
counter: counter.clone(),
sent_tasks: sent_tasks.clone(),
received_tasks: received_tasks.clone(),
};
let mut dispatch_user_handle = system.start_component(dispatch_user);
// yield to allow the component to process the messages
tokio::task::yield_now().await;
// Join on the dispatch user, since it will kill itself after DISPATCH_COUNT messages
dispatch_user_handle.join().await;
// We should have received DISPATCH_COUNT messages
assert_eq!(counter.load(Ordering::SeqCst), DISPATCH_COUNT);
// The sent tasks should be equal to the received tasks
assert_eq!(*sent_tasks.lock(), *received_tasks.lock());
// The length of the sent/recieved tasks should be equal to the number of dispatched tasks
assert_eq!(sent_tasks.lock().len(), DISPATCH_COUNT);
assert_eq!(received_tasks.lock().len(), DISPATCH_COUNT);
}

#[tokio::test]
async fn test_dispatcher_non_io_tasks() {
let system = System::new();
let dispatcher = Dispatcher::new(THREAD_COUNT, 1000, 1000);
let dispatcher_handle = system.start_component(dispatcher);
Expand Down
6 changes: 3 additions & 3 deletions rust/worker/src/execution/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use thiserror::Error;
use uuid::Uuid;

pub(crate) enum OperatorType {
IoOperatorType,
OtherType,
IoOperator,
Other,
}

/// An operator takes a generic input and returns a generic output.
Expand All @@ -28,7 +28,7 @@ where
async fn run(&self, input: &I) -> Result<O, Self::Error>;
fn get_name(&self) -> &'static str;
fn get_type(&self) -> OperatorType {
OperatorType::OtherType
OperatorType::Other
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,6 @@ impl Operator<RecordSegmentPrefetchIoInput, RecordSegmentPrefetchIoOutput>
}

fn get_type(&self) -> OperatorType {
OperatorType::IoOperatorType
OperatorType::IoOperator
}
}

0 comments on commit e087224

Please sign in to comment.