Skip to content

Commit

Permalink
[BUG] fix consumable join handle panicking after clone (#2448)
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb authored Jul 3, 2024
1 parent 28e6f5f commit 1d5999c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
7 changes: 6 additions & 1 deletion rust/worker/src/system/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::scheduler::Scheduler;
use super::ComponentContext;
use super::ComponentRuntime;
use super::ComponentSender;
use super::ConsumableJoinHandle;
use super::Message;
use super::{executor::ComponentExecutor, Component, ComponentHandle, Handler, StreamHandler};
use futures::Stream;
Expand Down Expand Up @@ -52,7 +53,11 @@ impl System {
trace_span!(parent: Span::current(), "component spawn", "name" = C::get_name());
let task_future = async move { executor.run(rx).await };
let join_handle = tokio::spawn(task_future.instrument(child_span));
return ComponentHandle::new(cancel_token, Some(join_handle), sender);
return ComponentHandle::new(
cancel_token,
Some(ConsumableJoinHandle::new(join_handle)),
sender,
);
}
ComponentRuntime::Dedicated => {
println!("Spawning on dedicated thread");
Expand Down
20 changes: 10 additions & 10 deletions rust/worker/src/system/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,22 +82,20 @@ where

/// A thin wrapper over a join handle that will panic if it is consumed more than once.
#[derive(Debug, Clone)]
struct ConsumableJoinHandle {
handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub(super) struct ConsumableJoinHandle {
handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}

impl ConsumableJoinHandle {
fn new(handle: tokio::task::JoinHandle<()>) -> Self {
pub(super) fn new(handle: tokio::task::JoinHandle<()>) -> Self {
ConsumableJoinHandle {
handle: Some(Arc::new(handle)),
handle: Arc::new(Mutex::new(Some(handle))),
}
}

async fn consume(&mut self) -> Result<(), JoinError> {
match self.handle.take() {
match self.handle.lock().take() {
Some(handle) => {
let handle = Arc::into_inner(handle)
.expect("there should be no other strong references to the join handle");
handle.await?;
Ok(())
}
Expand Down Expand Up @@ -199,13 +197,13 @@ impl<C: Component> ComponentHandle<C> {
// Components with a dedicated runtime do not have a join handle
// and instead use a one shot channel to signal completion
// TODO: implement this
join_handle: Option<tokio::task::JoinHandle<()>>,
join_handle: Option<ConsumableJoinHandle>,
sender: ComponentSender<C>,
) -> Self {
ComponentHandle {
cancellation_token: cancellation_token,
state: Arc::new(Mutex::new(ComponentState::Running)),
join_handle: join_handle.map(|handle| ConsumableJoinHandle::new(handle)),
join_handle,
sender: sender,
}
}
Expand Down Expand Up @@ -372,8 +370,10 @@ mod tests {
async fn join_handle_panics_if_consumed_twice() {
let handle = tokio::spawn(async {});
let mut handle = ConsumableJoinHandle::new(handle);
// Should be able to clone the handle
let mut cloned = handle.clone();

handle.consume().await.unwrap();
cloned.consume().await.unwrap();
// Expected to panic
handle.consume().await.unwrap();
}
Expand Down

0 comments on commit 1d5999c

Please sign in to comment.