metadata() for global metadata access
Access to metadata has never been easier: import `restate_core::metadata` and call `metadata()`.
AhmedSoliman committed Feb 22, 2024
1 parent a9b2edf commit 6fc290a
Showing 4 changed files with 48 additions and 17 deletions.
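A sketch of the intended usage, based on the changes below (not part of the commit): the node installs the handle once at startup via `try_set_global_metadata`, and any task spawned through the `TaskCenter` can then call `restate_core::metadata()`. The boot helper, task kind, and task name are illustrative assumptions, not code from this commit.

use restate_core::metadata::Metadata;
use restate_core::{metadata, TaskCenter, TaskKind};

// Hypothetical startup helper, not from this commit.
fn boot(tc: &TaskCenter, handle: Metadata) -> anyhow::Result<()> {
    // Install the global handle exactly once; a second call would return false.
    assert!(tc.try_set_global_metadata(handle));

    // Any task spawned through the TaskCenter can now use metadata() directly
    // instead of having a Metadata argument threaded through it.
    tc.spawn(TaskKind::SystemBoot, "inspect-nodes-config", None, async {
        let _nodes_config = metadata().nodes_config();
        anyhow::Ok(())
    })?;
    Ok(())
}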
2 changes: 1 addition & 1 deletion crates/core/src/lib.rs
@@ -10,8 +10,8 @@
 
 pub mod metadata;
 
-mod task_center_types;
 mod task_center;
+mod task_center_types;
 
 pub use task_center::*;
 pub use task_center_types::*;
52 changes: 41 additions & 11 deletions crates/core/src/task_center.rs
@@ -13,7 +13,7 @@ use restate_types::identifiers::PartitionId;
 use std::collections::HashMap;
 use std::panic::AssertUnwindSafe;
 use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Mutex, OnceLock};
 use std::time::{Duration, Instant};
 
 use futures::Future;
@@ -22,6 +22,7 @@ use tokio::task_local;
 use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
 use tracing::{debug, error, info, instrument, warn};
 
+use crate::metadata::Metadata;
 use crate::{TaskId, TaskKind};
 
 static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(0);
@@ -43,6 +44,7 @@ impl TaskCenterFactory {
 shutdown_requested: AtomicBool::new(false),
 current_exit_code: AtomicI32::new(0),
 tasks: Mutex::new(HashMap::new()),
+global_metadata: OnceLock::new(),
 }),
 }
 }
@@ -100,6 +102,12 @@ impl TaskCenter {
 info!("** Shutdown completed in {:?}", start.elapsed());
 }
 
+/// Attempt to set the global metadata handle. This should be called once
+/// at the startup of the node.
+pub fn try_set_global_metadata(&self, metadata: Metadata) -> bool {
+self.inner.global_metadata.set(metadata).is_ok()
+}
+
 #[track_caller]
 fn spawn_inner<F>(
 &self,
@@ -124,14 +132,20 @@ impl TaskCenter {
 });
 
 inner.tasks.lock().unwrap().insert(id, Arc::clone(&task));
+let metadata = inner.global_metadata.get().cloned();
 
 let mut handle_mut = task.join_handle.lock().unwrap();
 
 let task_cloned = Arc::clone(&task);
-let join_handle =
-inner
-.runtime
-.spawn(wrapper(self.clone(), id, kind, task_cloned, cancel, future));
+let join_handle = inner.runtime.spawn(wrapper(
+self.clone(),
+id,
+kind,
+task_cloned,
+cancel,
+metadata,
+future,
+));
 *handle_mut = Some(join_handle);
 drop(handle_mut);
 
@@ -450,6 +464,7 @@ struct TaskCenterInner {
 shutdown_requested: AtomicBool,
 current_exit_code: AtomicI32,
 tasks: Mutex<HashMap<TaskId, Arc<Task>>>,
+global_metadata: OnceLock<Metadata>,
 }
 
 pub struct Task {
@@ -475,6 +490,9 @@ task_local! {
 
 // Current task center
 static CURRENT_TASK_CENTER: TaskCenter;
+
+// Metadata handle
+static METADATA: Option<Metadata>;
 }
 
 /// This wrapper function runs in a newly-spawned task. It initializes the
@@ -485,6 +503,7 @@ async fn wrapper<F>(
 kind: TaskKind,
 task: Arc<Task>,
 cancel_token: CancellationToken,
+metadata: Option<Metadata>,
 future: F,
 ) where
 F: Future<Output = anyhow::Result<()>> + Send + 'static,
@@ -496,12 +515,15 @@ async fn wrapper<F>(
 task_center.clone(),
 CANCEL_TOKEN.scope(
 cancel_token,
-CURRENT_TASK.scope(task, {
-// We use AssertUnwindSafe here so that the wrapped function
-// doesn't need to be UnwindSafe. We should not do anything after
-// unwinding that'd risk us being in unwind-unsafe behavior.
-AssertUnwindSafe(future).catch_unwind()
-}),
+CURRENT_TASK.scope(
+task,
+METADATA.scope(metadata, {
+// We use AssertUnwindSafe here so that the wrapped function
+// doesn't need to be UnwindSafe. We should not do anything after
+// unwinding that'd risk us being in unwind-unsafe behavior.
+AssertUnwindSafe(future).catch_unwind()
+}),
+),
 ),
 )
 .await;
Expand All @@ -520,6 +542,14 @@ pub fn current_task_id() -> Option<TaskId> {
CURRENT_TASK.try_with(|ct| ct.id).ok()
}

/// Access to global metadata handle. This available in task-center tasks only!
pub fn metadata() -> Metadata {
METADATA
.try_with(|m| m.clone())
.expect("metadata() called outside task-center scope")
.expect("metadata() called before global metadata was set")
}

/// The current partition Id associated to the running task-center task.
pub fn current_task_partition_id() -> Option<PartitionId> {
CURRENT_TASK.try_with(|ct| ct.partition_id).ok().flatten()
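The wrapper changes above thread the handle through a tokio task_local, which is why `metadata()` only resolves inside tasks spawned through the task center. A minimal, self-contained sketch of that mechanism (simplified types and names, not the restate code):

use tokio::task_local;

task_local! {
    // Stand-in for the METADATA task-local in task_center.rs.
    static CURRENT_VALUE: Option<String>;
}

// Mirrors the shape of metadata(): try_with fails when called outside a scope.
fn current_value() -> Option<String> {
    CURRENT_VALUE.try_with(|v| v.clone()).ok().flatten()
}

#[tokio::main]
async fn main() {
    // Inside the scope, the value is visible to the task.
    CURRENT_VALUE
        .scope(Some("node-1".to_owned()), async {
            assert_eq!(current_value().as_deref(), Some("node-1"));
        })
        .await;

    // Outside any scope, the accessor sees nothing (the real metadata() panics here).
    assert_eq!(current_value(), None);
}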
4 changes: 3 additions & 1 deletion crates/node/src/lib.rs
@@ -133,6 +133,8 @@ impl Node {
 let tc = task_center();
 let metadata_writer = self.metadata_manager.writer();
 let metadata = self.metadata_manager.metadata();
+let is_set = tc.try_set_global_metadata(metadata.clone());
+debug_assert!(is_set);
 
 // Start metadata manager
 tc.spawn(
@@ -253,7 +255,7 @@ impl Node {
 TaskKind::SystemBoot,
 "worker-init",
 None,
-worker_role.start(metadata),
+worker_role.start(),
 )?;
 }
 
7 changes: 3 additions & 4 deletions crates/node/src/roles/worker.rs
@@ -15,9 +15,8 @@ use tonic::transport::Channel;
 use tracing::debug;
 use tracing::subscriber::NoSubscriber;
 
-use restate_core::metadata::Metadata;
-use restate_core::task_center;
 use restate_core::TaskKind;
+use restate_core::{metadata, task_center};
 use restate_network::utils::create_grpc_channel_from_network_address;
 use restate_node_services::cluster_ctrl::cluster_ctrl_svc_client::ClusterCtrlSvcClient;
 use restate_node_services::cluster_ctrl::AttachmentRequest;
@@ -118,11 +117,11 @@ impl WorkerRole {
 Some(self.worker.subscription_controller_handle())
 }
 
-pub async fn start(self, metadata: Metadata) -> anyhow::Result<()> {
+pub async fn start(self) -> anyhow::Result<()> {
 // todo: only run subscriptions on node 0 once being distributed
 let subscription_controller = Some(self.worker.subscription_controller_handle());
 
-let admin_address = metadata
+let admin_address = metadata()
 .nodes_config()
 .get_admin_node()
 .expect("at least one admin node")
