Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHORE] Cancel tasks spawned on compute runtime #3128

Merged
merged 3 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 41 additions & 22 deletions src/common/runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{
use common_error::{DaftError, DaftResult};
use futures::FutureExt;
use lazy_static::lazy_static;
use tokio::{runtime::RuntimeFlavor, task::JoinHandle};
use tokio::{runtime::RuntimeFlavor, task::JoinError};

lazy_static! {
static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
Expand Down Expand Up @@ -41,7 +41,6 @@ impl Runtime {
Arc::new(Self { runtime, pool_type })
}

// TODO: figure out a way to cancel the Future if this output is dropped.
async fn execute_task<F>(future: F, pool_type: PoolType) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
Expand Down Expand Up @@ -81,36 +80,25 @@ impl Runtime {
rx.recv().expect("Spawned task transmitter dropped")
}

/// Spawn a task on the runtime and await on it.
/// You should use this when you are spawning compute or IO tasks from the Executor.
pub async fn await_on<F>(&self, future: F) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = oneshot::channel();
let pool_type = self.pool_type;
let _join_handle = self.spawn(async move {
let task_output = Self::execute_task(future, pool_type).await;
if tx.send(task_output).is_err() {
log::warn!("Spawned task output ignored: receiver dropped");
}
});
rx.await.expect("Spawned task transmitter dropped")
}

/// Blocks current thread to compute future. Can not be called in tokio runtime context
///
pub fn block_on_current_thread<F: Future>(&self, future: F) -> F::Output {
self.runtime.block_on(future)
}

pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
// Spawn a task on the runtime
pub fn spawn<F>(
&self,
future: F,
) -> impl Future<Output = Result<F::Output, JoinError>> + Send + 'static
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.runtime.spawn(future)
// Spawn it on a joinset on the runtime, such that if the future gets dropped, the task is cancelled
let mut joinset = tokio::task::JoinSet::new();
joinset.spawn_on(future, self.runtime.handle());
async move { joinset.join_next().await.expect("just spawned task") }.boxed()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we shouldn't have to box this right? We should avoid the heap allocation if we can here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think it would be cleaner to make a struct that holds the joinset and then impl the future trait on that struct.

Similarly to what influx does for the their Job struct
https://github.com/metrico/influxdb_iox/blob/ab17bbc9efbb8568ea5a95ccb9d4bbddd33fc9ea/executor/src/lib.rs#L88

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah good catch no box needed since there's only 1 branch

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Encapsulated the fut in a RuntimeTask struct

colin-ho marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down Expand Up @@ -183,3 +171,34 @@ pub fn get_io_pool_num_threads() -> Option<usize> {
Err(_) => None,
}
}

mod tests {

#[tokio::test]
async fn test_spawned_task_cancelled_when_dropped() {
use super::*;

let runtime = get_compute_runtime();
let ptr = Arc::new(AtomicUsize::new(0));
let ptr_clone = ptr.clone();

// Spawn a task that just does work in a loop
// The task should own a reference to the Arc, so the strong count should be 2
let task = async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
ptr_clone.fetch_add(1, Ordering::SeqCst);
}
};
let fut = runtime.spawn(task);
assert!(Arc::strong_count(&ptr) == 2);

// Drop the future, which should cancel the task
drop(fut);

// Wait for a while so that the task can be aborted
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
// The strong count should be 1 now
assert!(Arc::strong_count(&ptr) == 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ use common_display::tree::TreeDisplay;
use common_error::DaftResult;
use common_runtime::get_compute_runtime;
use daft_micropartition::MicroPartition;
use snafu::ResultExt;
use tracing::{info_span, instrument};

use super::buffer::OperatorBuffer;
use crate::{
channel::{create_channel, PipelineChannel, Receiver, Sender},
pipeline::{PipelineNode, PipelineResultType},
runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext},
ExecutionRuntimeHandle, NUM_CPUS,
ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS,
};

pub(crate) trait DynIntermediateOpState: Send + Sync {
Expand Down Expand Up @@ -119,7 +120,7 @@ impl IntermediateNode {
let fut = async move {
rt_context.in_span(&span, || op.execute(idx, &morsel, &state_wrapper))
};
let result = compute_runtime.await_on(fut).await??;
let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??;
match result {
IntermediateOperatorResult::NeedMoreInput(Some(mp)) => {
let _ = sender.send(mp.into()).await;
Expand Down
10 changes: 6 additions & 4 deletions src/daft-local-execution/src/sinks/blocking_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ use common_display::tree::TreeDisplay;
use common_error::DaftResult;
use common_runtime::get_compute_runtime;
use daft_micropartition::MicroPartition;
use snafu::ResultExt;
use tracing::info_span;

use crate::{
channel::PipelineChannel,
pipeline::{PipelineNode, PipelineResultType},
runtime_stats::RuntimeStatsContext,
ExecutionRuntimeHandle,
ExecutionRuntimeHandle, JoinSnafu,
};
pub enum BlockingSinkStatus {
NeedMoreInput,
Expand Down Expand Up @@ -102,19 +103,20 @@ impl PipelineNode for BlockingSinkNode {
let mut guard = op.lock().await;
rt_context.in_span(&span, || guard.sink(val.as_data()))
};
let result = compute_runtime.await_on(fut).await??;
let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??;
if matches!(result, BlockingSinkStatus::Finished) {
break;
}
}
let finalized_result = compute_runtime
.await_on(async move {
.spawn(async move {
let mut guard = op.lock().await;
rt_context.in_span(&info_span!("BlockingSinkNode::finalize"), || {
guard.finalize()
})
})
.await??;
.await
.context(JoinSnafu)??;
if let Some(part) = finalized_result {
let _ = destination_sender.send(part).await;
}
Expand Down
7 changes: 4 additions & 3 deletions src/daft-local-execution/src/sinks/streaming_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl StreamingSinkNode {
let fut = async move {
rt_context.in_span(&span, || op.execute(idx, &morsel, state_wrapper.as_ref()))
};
let result = compute_runtime.await_on(fut).await??;
let result = compute_runtime.spawn(fut).await.context(JoinSnafu)??;
match result {
StreamingSinkOutput::NeedMoreInput(mp) => {
if let Some(mp) = mp {
Expand Down Expand Up @@ -281,12 +281,13 @@ impl PipelineNode for StreamingSinkNode {

let compute_runtime = get_compute_runtime();
let finalized_result = compute_runtime
.await_on(async move {
.spawn(async move {
runtime_stats.in_span(&info_span!("StreamingSinkNode::finalize"), || {
op.finalize(finished_states)
})
})
.await??;
.await
.context(JoinSnafu)??;
if let Some(res) = finalized_result {
let _ = destination_sender.send(res.into()).await;
}
Expand Down
5 changes: 3 additions & 2 deletions src/daft-local-execution/src/sources/scan_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
.get_io_client_and_runtime()?;
let scan_tasks = scan_tasks.to_vec();
runtime
.await_on(async move {
.spawn(async move {

Check warning on line 142 in src/daft-local-execution/src/sources/scan_task.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sources/scan_task.rs#L142

Added line #L142 was not covered by tests
let mut delete_map = scan_tasks
.iter()
.flat_map(|st| st.sources.iter().map(|s| s.get_path().to_string()))
Expand Down Expand Up @@ -183,7 +183,8 @@
}
Ok(Some(delete_map))
})
.await?
.await
.context(JoinSnafu)?

Check warning on line 187 in src/daft-local-execution/src/sources/scan_task.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sources/scan_task.rs#L186-L187

Added lines #L186 - L187 were not covered by tests
}

async fn stream_scan_task(
Expand Down
Loading