Skip to content

Commit

Permalink
Refactor: Implement custom async Mutex using Oneshot channel
Browse files Browse the repository at this point in the history
This commit introduces a custom implementation of an asynchronous Mutex
using the `AsyncRuntime::Oneshot` functions, tailored specifically for
Openraft's limited use of asynchronous locks. This custom mutex replaces
the previously used Tokio mutex.

- **Refactor of `RaftInner::tx_shutdown`:** The `tx_shutdown` member of
  `RaftInner` has been changed to use a standard (synchronous) Mutex
  instead of an asynchronous one. This change is made because the lock on
  `tx_shutdown` is never held across an `.await` point, making the
  asynchronous capabilities unnecessary.

- **OneshotSender `Drop` Implementation:** It is now documented that the
  `OneshotSender` should implement the `Drop` trait to ensure that when
  a sender is dropped, the receiver is notified and yields an error.
  This behavior is crucial for maintaining robust error handling in
  asynchronous communication patterns.
  • Loading branch information
drmingdrmer committed Jul 29, 2024
1 parent 6c7527f commit 9a5574f
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 11 deletions.
1 change: 1 addition & 0 deletions openraft/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub mod metrics;
pub mod network;
pub mod raft;
pub mod storage;
pub mod sync;
pub mod testing;
pub mod type_config;

Expand Down
6 changes: 3 additions & 3 deletions openraft/src/raft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ pub use message::InstallSnapshotResponse;
pub use message::SnapshotResponse;
pub use message::VoteRequest;
pub use message::VoteResponse;
use tokio::sync::Mutex;
use tracing::trace_span;
use tracing::Instrument;
use tracing::Level;
Expand Down Expand Up @@ -78,6 +77,7 @@ pub use crate::raft::runtime_config_handle::RuntimeConfigHandle;
use crate::raft::trigger::Trigger;
use crate::storage::RaftLogStorage;
use crate::storage::RaftStateMachine;
use crate::sync::Mutex;
use crate::type_config::alias::JoinErrorOf;
use crate::type_config::alias::ResponderOf;
use crate::type_config::alias::ResponderReceiverOf;
Expand Down Expand Up @@ -318,7 +318,7 @@ where C: RaftTypeConfig
rx_metrics,
rx_data_metrics,
rx_server_metrics,
tx_shutdown: Mutex::new(Some(tx_shutdown)),
tx_shutdown: std::sync::Mutex::new(Some(tx_shutdown)),
core_state: Mutex::new(CoreState::Running(core_handle)),

snapshot: Mutex::new(None),
Expand Down Expand Up @@ -919,7 +919,7 @@ where C: RaftTypeConfig
///
/// It sends a shutdown signal and waits until `RaftCore` returns.
pub async fn shutdown(&self) -> Result<(), JoinErrorOf<C>> {
if let Some(tx) = self.inner.tx_shutdown.lock().await.take() {
if let Some(tx) = self.inner.tx_shutdown.lock().unwrap().take() {
// A failure to send means the RaftCore is already shutdown. Continue to check the task
// return value.
let send_res = tx.send(());
Expand Down
9 changes: 4 additions & 5 deletions openraft/src/raft/raft_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::fmt::Debug;
use std::future::Future;
use std::sync::Arc;

use tokio::sync::Mutex;
use tracing::Level;

use crate::async_runtime::MpscUnboundedSender;
Expand All @@ -16,6 +15,7 @@ use crate::error::RaftError;
use crate::metrics::RaftDataMetrics;
use crate::metrics::RaftServerMetrics;
use crate::raft::core_state::CoreState;
use crate::sync::Mutex;
use crate::type_config::alias::MpscUnboundedSenderOf;
use crate::type_config::alias::OneshotReceiverOf;
use crate::type_config::alias::OneshotSenderOf;
Expand All @@ -40,13 +40,12 @@ where C: RaftTypeConfig
pub(in crate::raft) rx_data_metrics: WatchReceiverOf<C, RaftDataMetrics<C>>,
pub(in crate::raft) rx_server_metrics: WatchReceiverOf<C, RaftServerMetrics<C>>,

// TODO(xp): it does not need to be a async mutex.
#[allow(clippy::type_complexity)]
pub(in crate::raft) tx_shutdown: Mutex<Option<OneshotSenderOf<C, ()>>>,
pub(in crate::raft) core_state: Mutex<CoreState<C>>,
pub(in crate::raft) tx_shutdown: std::sync::Mutex<Option<OneshotSenderOf<C, ()>>>,
pub(in crate::raft) core_state: Mutex<C, CoreState<C>>,

/// The ongoing snapshot transmission.
pub(in crate::raft) snapshot: Mutex<Option<crate::network::snapshot_transport::Streaming<C>>>,
pub(in crate::raft) snapshot: Mutex<C, Option<crate::network::snapshot_transport::Streaming<C>>>,
}

impl<C> RaftInner<C>
Expand Down
6 changes: 3 additions & 3 deletions openraft/src/replication/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use request::Replicate;
use response::ReplicationResult;
pub(crate) use response::Response;
use tokio::select;
use tokio::sync::Mutex;
use tracing_futures::Instrument;

use crate::async_runtime::MpscUnboundedReceiver;
Expand Down Expand Up @@ -50,6 +49,7 @@ use crate::replication::request_id::RequestId;
use crate::storage::RaftLogReader;
use crate::storage::RaftLogStorage;
use crate::storage::Snapshot;
use crate::sync::Mutex;
use crate::type_config::alias::InstantOf;
use crate::type_config::alias::JoinHandleOf;
use crate::type_config::alias::LogIdOf;
Expand Down Expand Up @@ -114,7 +114,7 @@ where
/// Another `RaftNetwork` specific for snapshot replication.
///
/// Snapshot transmitting is a long running task, and is processed in a separate task.
snapshot_network: Arc<Mutex<N::Network>>,
snapshot_network: Arc<Mutex<C, N::Network>>,

/// The current snapshot replication state.
///
Expand Down Expand Up @@ -754,7 +754,7 @@ where

async fn send_snapshot(
request_id: RequestId,
network: Arc<Mutex<N::Network>>,
network: Arc<Mutex<C, N::Network>>,
vote: Vote<C::NodeId>,
snapshot: Snapshot<C>,
option: RPCOption,
Expand Down
5 changes: 5 additions & 0 deletions openraft/src/sync/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pub(crate) mod mutex;

pub(crate) use mutex::Mutex;
#[allow(unused_imports)]
pub(crate) use mutex::MutexGuard;
168 changes: 168 additions & 0 deletions openraft/src/sync/mutex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
use std::cell::UnsafeCell;
use std::ops::Deref;
use std::ops::DerefMut;

use crate::type_config::alias::OneshotReceiverOf;
use crate::type_config::alias::OneshotSenderOf;
use crate::type_config::TypeConfigExt;
use crate::RaftTypeConfig;

/// A minimal async mutex built on top of oneshot channels.
///
/// Every `lock()` call leaves a oneshot receiver behind for the next caller to wait on,
/// so waiting tasks line up one behind another.
///
/// Openraft only takes async locks outside of performance critical paths,
/// so this simple implementation is good enough.
///
/// A oneshot channel is already something an AsyncRuntime implementation has to provide,
/// hence an application never needs to supply a Mutex of its own.
pub(crate) struct Mutex<C: RaftTypeConfig, T> {
    /// Receiver half of the oneshot channel created by the most recent `lock()` call.
    ///
    /// It is `None` until `lock()` is called for the first time. The matching sender is stored in
    /// the `MutexGuard`; when that guard is dropped, the task that took this receiver is
    /// notified.
    lock_holder: std::sync::Mutex<Option<OneshotReceiverOf<C, ()>>>,

    /// The data guarded by this mutex; it is only reached through a `MutexGuard` or
    /// `into_inner()`.
    value: UnsafeCell<T>,
}

impl<C, T> Mutex<C, T>
where C: RaftTypeConfig
{
/// Creates a mutex that owns `value` and has no registered lock holder yet.
pub(crate) fn new(value: T) -> Self {
Self {
// `None`: no `lock()` call has stored a receiver yet, so the first caller does not wait.
lock_holder: std::sync::Mutex::new(None),
value: UnsafeCell::new(value),
}
}

/// Waits for the previous `lock()` caller (if there is one) and returns a guard for the value.
///
/// The internal `std::sync::Mutex` is held only while the receiver is swapped; it is released
/// before the `.await`.
///
/// # Panics
///
/// Panics if the internal `std::sync::Mutex` is poisoned.
///
/// NOTE(review): `tx` is created before the `.await`, and the result of `rx.await` is ignored.
/// If this future is dropped while it is still waiting, `tx` is dropped as well, which
/// presumably notifies the next waiter even though an earlier guard may still be alive.
/// Confirm that callers never cancel a pending `lock()`.
pub(crate) async fn lock(&self) -> MutexGuard<'_, C, T> {
// Every lock() call puts a oneshot receiver into the holder
// and takes out the existing one.
// If the existing one is Some(rx),
// it means an earlier lock() call registered itself before this one.
// In this case, the current task should wait for that caller's guard to be dropped.
//
// This approach forms a queue in which every task waits for the previous one.

let (tx, rx) = C::oneshot();
let current_rx = {
// The std guard `l` lives only inside this block, i.e. it is gone before the `.await` below.
let mut l = self.lock_holder.lock().unwrap();
l.replace(rx)
};

if let Some(rx) = current_rx {
// The outcome is ignored on purpose: any completion of `rx`, including an error,
// is treated as "the previous caller is done".
let _ = rx.await;
}

MutexGuard { guard: tx, lock: self }
}

/// Consumes the mutex and returns the protected value.
#[allow(dead_code)]
pub(crate) fn into_inner(self) -> T {
self.value.into_inner()
}
}

/// Handle returned by `Mutex::lock()`; the protected value is accessed through it.
pub(crate) struct MutexGuard<'a, C: RaftTypeConfig, T> {
    /// Never read. It is kept so that dropping the guard also drops this sender, which is what
    /// notifies the next waiting task.
    #[allow(dead_code)]
    guard: OneshotSenderOf<C, ()>,

    /// The mutex this guard was obtained from.
    lock: &'a Mutex<C, T>,
}

impl<C: RaftTypeConfig, T> Deref for MutexGuard<'_, C, T> {
    type Target = T;

    /// Gives shared access to the value protected by the mutex.
    fn deref(&self) -> &T {
        let value_ptr = self.lock.value.get();

        // SAFETY: the pointer comes from the `UnsafeCell` of the mutex borrowed by this guard, so
        // it stays valid for as long as the guard lives. Soundness relies on `Mutex::lock()`
        // handing out at most one live guard at a time.
        unsafe { &*value_ptr }
    }
}

impl<C: RaftTypeConfig, T> DerefMut for MutexGuard<'_, C, T> {
    /// Gives exclusive access to the value protected by the mutex.
    fn deref_mut(&mut self) -> &mut T {
        let value_ptr = self.lock.value.get();

        // SAFETY: the pointer comes from the `UnsafeCell` of the mutex borrowed by this guard, so
        // it stays valid for as long as the guard lives. Exclusivity relies on `Mutex::lock()`
        // handing out at most one live guard at a time, and on `&mut self` for this guard.
        unsafe { &mut *value_ptr }
    }
}

/// Sending a `Mutex` to another thread sends the `T` stored in it, so `T` has to be `Send`.
unsafe impl<C, T> Send for Mutex<C, T>
where
    C: RaftTypeConfig,
    T: Send,
{
}

/// Sharing a `&Mutex` between threads lets any of them obtain the `T` through `Mutex::lock()`,
/// which is why `T: Send` is required here.
unsafe impl<C, T> Sync for Mutex<C, T>
where
    C: RaftTypeConfig,
    T: Send,
{
}

/// A `MutexGuard` has to be `Sync` so that it can be held across an `.await` point.
unsafe impl<C, T> Sync for MutexGuard<'_, C, T>
where
    C: RaftTypeConfig,
    T: Send + Sync,
{
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::engine::testing::UTConfig;

    /// Compile-time checks of the marker traits of `Mutex`, `MutexGuard` and the `lock()` future.
    #[test]
    fn bounds() {
        fn check_send<T: Send>() {}
        fn check_unpin<T: Unpin>() {}
        // This has to take a value, since the async fn's return type is unnameable.
        fn check_send_sync_val<T: Send + Sync>(_t: T) {}
        fn check_send_sync<T: Send + Sync>() {}

        check_send::<MutexGuard<'_, UTConfig, u32>>();
        check_unpin::<Mutex<UTConfig, u32>>();
        check_send_sync::<Mutex<UTConfig, u32>>();

        let mutex = Mutex::<UTConfig, _>::new(1);
        check_send_sync_val(mutex.lock());
    }

    /// Several tasks add `0..n` to a shared counter under the lock; the final value must equal
    /// the sum that a correctly serialized execution produces.
    #[test]
    fn test_mutex() {
        let mutex = Arc::new(Mutex::<UTConfig, u64>::new(0));

        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(8)
            .enable_all()
            .build()
            .expect("Failed building the Runtime");

        let big_prime_num = 1_000_000_009;
        let n = 100_000;
        let n_task = 10;
        let mut tasks = vec![];

        for _i in 0..n_task {
            let mutex = mutex.clone();
            let h = rt.spawn(async move {
                for k in 0..n {
                    {
                        let mut guard = mutex.lock().await;
                        *guard = (*guard + k) % big_prime_num;
                    }
                }
            });

            tasks.push(h);
        }

        let got = rt.block_on(async {
            for t in tasks {
                // Do not swallow the join result: a worker that panicked or was cancelled has to
                // fail the test directly instead of only showing up as a wrong sum.
                t.await.expect("worker task failed");
            }
            *mutex.lock().await
        });

        println!("got: {}", got);
        // Each task contributes 0 + 1 + ... + (n - 1) = n * (n - 1) / 2.
        assert_eq!(got, n_task * n * (n - 1) / 2 % big_prime_num);
    }
}
2 changes: 2 additions & 0 deletions openraft/src/type_config/async_runtime/oneshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ pub trait Oneshot {
where T: OptionalSend;
}

/// This `Sender` must implement `Drop` to notify the [`Oneshot::Receiver`] that the sending end has
/// been dropped, causing the receiver to return a [`Oneshot::ReceiverError`].
pub trait OneshotSender<T>: OptionalSend + OptionalSync + Sized
where T: OptionalSend
{
Expand Down

0 comments on commit 9a5574f

Please sign in to comment.