Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Implement Custom Async Mutex Using Oneshot Channel #1208

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions openraft/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub mod metrics;
pub mod network;
pub mod raft;
pub mod storage;
pub mod sync;
pub mod testing;
pub mod type_config;

Expand Down
2 changes: 1 addition & 1 deletion openraft/src/raft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ pub use message::InstallSnapshotResponse;
pub use message::SnapshotResponse;
pub use message::VoteRequest;
pub use message::VoteResponse;
use tokio::sync::Mutex;
use tracing::trace_span;
use tracing::Instrument;
use tracing::Level;
Expand Down Expand Up @@ -78,6 +77,7 @@ pub use crate::raft::runtime_config_handle::RuntimeConfigHandle;
use crate::raft::trigger::Trigger;
use crate::storage::RaftLogStorage;
use crate::storage::RaftStateMachine;
use crate::sync::Mutex;
use crate::type_config::alias::JoinErrorOf;
use crate::type_config::alias::ResponderOf;
use crate::type_config::alias::ResponderReceiverOf;
Expand Down
4 changes: 2 additions & 2 deletions openraft/src/raft/raft_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::fmt::Debug;
use std::future::Future;
use std::sync::Arc;

use tokio::sync::Mutex;
use tracing::Level;

use crate::async_runtime::watch::WatchReceiver;
Expand All @@ -18,6 +17,7 @@ use crate::error::RaftError;
use crate::metrics::RaftDataMetrics;
use crate::metrics::RaftServerMetrics;
use crate::raft::core_state::CoreState;
use crate::sync::Mutex;
use crate::type_config::alias::AsyncRuntimeOf;
use crate::type_config::alias::MpscUnboundedSenderOf;
use crate::type_config::alias::OneshotReceiverOf;
Expand Down Expand Up @@ -48,7 +48,7 @@ where C: RaftTypeConfig
pub(in crate::raft) core_state: std::sync::Mutex<CoreState<C>>,

/// The ongoing snapshot transmission.
pub(in crate::raft) snapshot: Mutex<Option<crate::network::snapshot_transport::Streaming<C>>>,
pub(in crate::raft) snapshot: Mutex<C, Option<crate::network::snapshot_transport::Streaming<C>>>,
}

impl<C> RaftInner<C>
Expand Down
6 changes: 3 additions & 3 deletions openraft/src/replication/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use request::Replicate;
use response::ReplicationResult;
pub(crate) use response::Response;
use tokio::select;
use tokio::sync::Mutex;
use tracing_futures::Instrument;

use crate::async_runtime::MpscUnboundedReceiver;
Expand Down Expand Up @@ -50,6 +49,7 @@ use crate::replication::request_id::RequestId;
use crate::storage::RaftLogReader;
use crate::storage::RaftLogStorage;
use crate::storage::Snapshot;
use crate::sync::Mutex;
use crate::type_config::alias::InstantOf;
use crate::type_config::alias::JoinHandleOf;
use crate::type_config::alias::LogIdOf;
Expand Down Expand Up @@ -114,7 +114,7 @@ where
/// Another `RaftNetwork` specific for snapshot replication.
///
/// Snapshot transmitting is a long running task, and is processed in a separate task.
snapshot_network: Arc<Mutex<N::Network>>,
snapshot_network: Arc<Mutex<C, N::Network>>,

/// The current snapshot replication state.
///
Expand Down Expand Up @@ -754,7 +754,7 @@ where

async fn send_snapshot(
request_id: RequestId,
network: Arc<Mutex<N::Network>>,
network: Arc<Mutex<C, N::Network>>,
vote: Vote<C::NodeId>,
snapshot: Snapshot<C>,
option: RPCOption,
Expand Down
5 changes: 5 additions & 0 deletions openraft/src/sync/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pub(crate) mod mutex;

pub(crate) use mutex::Mutex;
#[allow(unused_imports)]
pub(crate) use mutex::MutexGuard;
168 changes: 168 additions & 0 deletions openraft/src/sync/mutex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
use std::cell::UnsafeCell;
use std::ops::Deref;
use std::ops::DerefMut;

use crate::type_config::alias::OneshotReceiverOf;
use crate::type_config::alias::OneshotSenderOf;
use crate::type_config::TypeConfigExt;
use crate::RaftTypeConfig;

/// A simple async mutex implementation that uses oneshot channels to notify the next waiting task.
///
/// Openraft use async mutex in non-performance critical path,
/// so it's ok to use this simple implementation.
///
/// Since oneshot channel is already required by AsyncRuntime implementation,
/// there is no need for the application to implement Mutex.
pub(crate) struct Mutex<C, T>
where C: RaftTypeConfig
{
/// The current lock holder.
///
/// When the acquired `MutexGuard` is dropped, it will notify the next waiting task via this
/// oneshot channel.
lock_holder: std::sync::Mutex<Option<OneshotReceiverOf<C, ()>>>,

/// The value protected by the mutex.
value: UnsafeCell<T>,
}

impl<C, T> Mutex<C, T>
where C: RaftTypeConfig
{
pub(crate) fn new(value: T) -> Self {
Self {
lock_holder: std::sync::Mutex::new(None),
value: UnsafeCell::new(value),
}
}

pub(crate) async fn lock(&self) -> MutexGuard<'_, C, T> {
// Every lock() call puts a oneshot receiver into the holder
// and takes out the existing one.
// If the existing one is Some(rx),
// it means the lock is already held by another task.
// In this case, the current task should wait for the lock to be released.
//
// Such approach forms a queue in which every task waits for the previous one.

let (tx, rx) = C::oneshot();
let current_rx = {
let mut l = self.lock_holder.lock().unwrap();
l.replace(rx)
};

if let Some(rx) = current_rx {
let _ = rx.await;
}

MutexGuard { guard: tx, lock: self }
}

#[allow(dead_code)]
pub(crate) fn into_inner(self) -> T {
self.value.into_inner()
}
}

/// The guard of the mutex.
pub(crate) struct MutexGuard<'a, C, T>
where C: RaftTypeConfig
{
/// This is only used to trigger `Drop` to notify the next waiting task.
#[allow(dead_code)]
guard: OneshotSenderOf<C, ()>,
lock: &'a Mutex<C, T>,
}

impl<'a, C, T> Deref for MutexGuard<'a, C, T>
where C: RaftTypeConfig
{
type Target = T;

fn deref(&self) -> &Self::Target {
unsafe { &*self.lock.value.get() }
}
}

impl<'a, C, T> DerefMut for MutexGuard<'a, C, T>
where C: RaftTypeConfig
{
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.lock.value.get() }
}
}

/// T must be `Send` to make Mutex `Send`
unsafe impl<C: RaftTypeConfig, T> Send for Mutex<C, T> where T: Send {}

/// To allow multiple threads to access T through a `&Mutex`, T must be `Send`,
/// because the caller acquires the ownership through `Mutex::lock()`.
unsafe impl<C: RaftTypeConfig, T> Sync for Mutex<C, T> where T: Send {}

/// MutexGuard needs to be Sync to across `.await` point.
unsafe impl<C: RaftTypeConfig, T> Sync for MutexGuard<'_, C, T> where T: Send + Sync {}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use super::*;
use crate::engine::testing::UTConfig;

#[test]
fn bounds() {
fn check_send<T: Send>() {}
fn check_unpin<T: Unpin>() {}
// This has to take a value, since the async fn's return type is unnameable.
fn check_send_sync_val<T: Send + Sync>(_t: T) {}
fn check_send_sync<T: Send + Sync>() {}

check_send::<MutexGuard<'_, UTConfig, u32>>();
check_unpin::<Mutex<UTConfig, u32>>();
check_send_sync::<Mutex<UTConfig, u32>>();

let mutex = Mutex::<UTConfig, _>::new(1);
check_send_sync_val(mutex.lock());
}

#[test]
fn test_mutex() {
let mutex = Arc::new(Mutex::<UTConfig, u64>::new(0));

let rt = tokio::runtime::Builder::new_multi_thread()
.worker_threads(8)
.enable_all()
.build()
.expect("Failed building the Runtime");

let big_prime_num = 1_000_000_009;
let n = 100_000;
let n_task = 10;
let mut tasks = vec![];

for _i in 0..n_task {
let mutex = mutex.clone();
let h = rt.spawn(async move {
for k in 0..n {
{
let mut guard = mutex.lock().await;
*guard = (*guard + k) % big_prime_num;
}
}
});

tasks.push(h);
}

let got = rt.block_on(async {
for t in tasks {
let _ = t.await;
}
*mutex.lock().await
});

println!("got: {}", got);
assert_eq!(got, n_task * n * (n - 1) / 2 % big_prime_num);
}
}
2 changes: 2 additions & 0 deletions openraft/src/type_config/async_runtime/oneshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ pub trait Oneshot {
where T: OptionalSend;
}

/// This `Sender` must implement `Drop` to notify the [`Oneshot::Receiver`] that the sending end has
/// been dropped, causing the receiver to return a [`Oneshot::ReceiverError`].
pub trait OneshotSender<T>: OptionalSend + OptionalSync + Sized
where T: OptionalSend
{
Expand Down
Loading