diff --git a/openraft/src/error.rs b/openraft/src/error.rs
index 92857b279..196b3b91b 100644
--- a/openraft/src/error.rs
+++ b/openraft/src/error.rs
@@ -224,8 +224,8 @@ where
     #[error(transparent)]
     HigherVote(#[from] HigherVote),

-    #[error("Replication is closed")]
-    Closed,
+    #[error(transparent)]
+    Closed(#[from] ReplicationClosed),

     // TODO(xp): two sub type: StorageError / TransportError
     // TODO(xp): a sub error for just send_append_entries()
@@ -236,6 +236,11 @@ where
     RPCError(#[from] RPCError>),
 }

+/// Error occurs when replication is closed.
+#[derive(Debug, thiserror::Error)]
+#[error("Replication is closed by RaftCore")]
+pub(crate) struct ReplicationClosed {}
+
 /// Error occurs when invoking a remote raft API.
 #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
 #[cfg_attr(
@@ -247,6 +252,11 @@ pub enum RPCError {
     #[error(transparent)]
     Timeout(#[from] Timeout),

+    /// The node is temporarily unreachable; the sender should back off before retrying.
+    #[error(transparent)]
+    Unreachable(#[from] Unreachable),
+
+    /// Failed to send the RPC request; the sender should retry immediately.
     #[error(transparent)]
     Network(#[from] NetworkError),

@@ -265,6 +275,7 @@ where
     where E: TryAsRef> {
         match self {
             RPCError::Timeout(_) => None,
+            RPCError::Unreachable(_) => None,
             RPCError::Network(_) => None,
             RPCError::RemoteError(remote_err) => remote_err.source.forward_to_leader(),
         }
@@ -331,6 +342,27 @@ impl NetworkError {
     }
 }

+/// Error that indicates a node is unreachable; the sender should not immediately retry
+/// sending anything to it.
+///
+/// It is similar to [`NetworkError`], but indicates that the sender should back off:
+/// when a [`NetworkError`] is returned, Openraft retries immediately.
+#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
+#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize), serde(bound = ""))]
+#[error("Unreachable node: {source}")]
+pub struct Unreachable {
+    #[from]
+    source: AnyError,
+}
+
+impl Unreachable {
+    pub fn new<E: Error + 'static>(e: &E) -> Self {
+        Self {
+            source: AnyError::new(e),
+        }
+    }
+}
+
 #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
 #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize), serde(bound = ""))]
 #[error("timeout after {timeout:?} when {action} {id}->{target}")]
diff --git a/openraft/src/membership/stored_membership.rs b/openraft/src/membership/stored_membership.rs
index 7a3f4071c..868aad05f 100644
--- a/openraft/src/membership/stored_membership.rs
+++ b/openraft/src/membership/stored_membership.rs
@@ -60,6 +60,6 @@ where
     NID: NodeId,
 {
     fn summary(&self) -> String {
-        format!("{{log_id:{:?} membership:{}}}", self.log_id, self.membership.summary())
+        format!("{{log_id:{}, {}}}", self.log_id.summary(), self.membership.summary())
     }
 }
diff --git a/openraft/src/metrics/raft_metrics.rs b/openraft/src/metrics/raft_metrics.rs
index 2d28861dd..69e7bf527 100644
--- a/openraft/src/metrics/raft_metrics.rs
+++ b/openraft/src/metrics/raft_metrics.rs
@@ -62,13 +62,14 @@ where
     NID: NodeId,
     N: Node,
 {
+    // TODO: make this more readable
     fn summary(&self) -> String {
         format!("Metrics{{id:{},{:?}, term:{}, last_log:{:?}, last_applied:{:?}, leader:{:?}, membership:{}, snapshot:{:?}, replication:{{{}}}",
             self.id,
             self.state,
             self.current_term,
             self.last_log_index,
-            self.last_applied,
+            self.last_applied.summary(),
             self.current_leader,
             self.membership_config.summary(),
             self.snapshot,
diff --git a/openraft/src/network.rs b/openraft/src/network.rs
index b07f65145..b103baa63 100644
--- a/openraft/src/network.rs
+++ b/openraft/src/network.rs
@@ -1,6 +1,7 @@
 //! The Raft network interface.

 use std::fmt::Formatter;
+use std::time::Duration;

 use async_trait::async_trait;

@@ -63,6 +64,41 @@ where C: RaftTypeConfig
         &mut self,
         rpc: VoteRequest,
     ) -> Result, RPCError>>;
+
+    /// Build a backoff instance if the target node is temporarily (or permanently) unreachable.
+    ///
+    /// When an [`Unreachable`](`crate::error::Unreachable`) error is returned from the `Network`
+    /// methods, Openraft does not retry connecting to the node immediately. Instead, it sleeps
+    /// for a while and then retries. The duration of the sleep is determined by the backoff
+    /// instance.
+    ///
+    /// The backoff is an infinite iterator that yields the i-th sleep interval before the i-th
+    /// retry. The returned instance will be dropped if a successful RPC is made.
+    ///
+    /// By default it returns a constant backoff of 500 ms.
+    fn backoff(&self) -> Backoff {
+        Backoff::new(std::iter::repeat(Duration::from_millis(500)))
+    }
+}
+
+/// A backoff instance that is an infinite iterator of durations to sleep before the next retry,
+/// when an [`Unreachable`](`crate::error::Unreachable`) error occurs.
+pub struct Backoff {
+    inner: Box<dyn Iterator<Item = Duration> + Send + 'static>,
+}
+
+impl Backoff {
+    pub fn new(iter: impl Iterator<Item = Duration> + Send + 'static) -> Self {
+        Self { inner: Box::new(iter) }
+    }
+}
+
+impl Iterator for Backoff {
+    type Item = Duration;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.inner.next()
+    }
 }

 /// A trait defining the interface for a Raft network factory to create connections between cluster
diff --git a/openraft/src/replication/mod.rs b/openraft/src/replication/mod.rs
index 17db270a8..a061d7c2e 100644
--- a/openraft/src/replication/mod.rs
+++ b/openraft/src/replication/mod.rs
@@ -14,21 +14,25 @@
 use tokio::io::AsyncRead;
 use tokio::io::AsyncReadExt;
 use tokio::io::AsyncSeek;
 use tokio::io::AsyncSeekExt;
+use tokio::select;
 use tokio::sync::mpsc;
 use tokio::sync::oneshot;
 use tokio::task::JoinHandle;
 use tokio::time::sleep;
 use tokio::time::timeout;
 use tokio::time::Duration;
+use tokio::time::Instant;
 use tracing_futures::Instrument;
 use crate::config::Config;
 use crate::error::HigherVote;
 use crate::error::RPCError;
+use crate::error::ReplicationClosed;
 use crate::error::ReplicationError;
 use crate::error::Timeout;
 use crate::log_id::LogIdOptionExt;
 use crate::log_id_range::LogIdRange;
+use crate::network::Backoff;
 use crate::raft::AppendEntriesRequest;
 use crate::raft::AppendEntriesResponse;
 use crate::raft::InstallSnapshotRequest;
@@ -59,7 +63,7 @@ where
     S: AsyncRead + AsyncSeek + Send + Unpin + 'static,
 {
     /// The spawn handle the `ReplicationCore` task.
-    pub(crate) join_handle: JoinHandle<()>,
+    pub(crate) join_handle: JoinHandle<Result<(), ReplicationClosed>>,

     /// The channel used for communicating with the replication task.
     pub(crate) tx_repl: mpsc::UnboundedSender>,
@@ -93,6 +97,10 @@ where
     /// The `RaftNetwork` interface.
     network: N::Network,

+    /// The backoff policy if an [`Unreachable`] error is returned.
+    /// It will be reset to `None` when a successful response is received.
+    backoff: Option<Backoff>,
+
     /// The `RaftLogReader` of a `RaftStorage` interface.
     log_reader: LS::LogReader,
@@ -146,6 +154,7 @@ where
             target,
             session_id,
             network,
+            backoff: None,
             log_reader,
             config,
             committed,
@@ -161,7 +170,7 @@ where
     }

     #[tracing::instrument(level="debug", skip(self), fields(session=%self.session_id, target=display(self.target), cluster=%self.config.cluster_name))]
-    async fn main(mut self) {
+    async fn main(mut self) -> Result<(), ReplicationClosed> {
         loop {
             let action = std::mem::replace(&mut self.next_action, None);

@@ -179,13 +188,16 @@ where
             };

             match res {
-                Ok(_x) => {}
+                Ok(_) => {
+                    // reset backoff
+                    self.backoff = None;
+                }
                 Err(err) => {
                     tracing::warn!(error=%err, "error replication to target={}", self.target);

                     match err {
-                        ReplicationError::Closed => {
-                            return;
+                        ReplicationError::Closed(closed) => {
+                            return Err(closed);
                         }
                         ReplicationError::HigherVote(h) => {
                             let _ = self.tx_raft_core.send(RaftMsg::HigherVote {
@@ -193,14 +205,14 @@ where
                                 higher: h.higher,
                                 vote: self.session_id.vote,
                             });
-                            return;
+                            return Ok(());
                         }
                         ReplicationError::StorageError(err) => {
                             tracing::error!(error=%err, "error replication to target={}", self.target);

                             // TODO: report this error
                             let _ = self.tx_raft_core.send(RaftMsg::ReplicationFatal);
-                            return;
+                            return Ok(());
                         }
                         ReplicationError::RPCError(err) => {
                             tracing::error!(err = display(&err), "RPCError");
@@ -210,24 +222,29 @@ where
                                 result: Err(err.to_string()),
                                 session_id: self.session_id,
                             });
+
+                            // If there is an [`Unreachable`] error, back off for a period of time.
+                            // The backoff is reset when a successful RPC is sent.
+                            if let RPCError::Unreachable(_unreachable) = err {
+                                if self.backoff.is_none() {
+                                    self.backoff = Some(self.network.backoff());
+                                }
+                            }
                         }
                     };
                 }
             };

-            let res = self.drain_events().await;
-            match res {
-                Ok(_x) => {}
-                Err(err) => match err {
-                    ReplicationError::Closed => {
-                        return;
-                    }
+            if let Some(b) = &mut self.backoff {
+                let duration = b.next().unwrap_or_else(|| {
+                    tracing::warn!("backoff exhausted, using default");
+                    Duration::from_millis(500)
+                });

-                    _ => {
-                        unreachable!("no other error expected but: {:?}", err);
-                    }
-                },
+                self.backoff_drain_events(Instant::now() + duration).await?;
             }
+
+            self.drain_events().await?;
         }
     }

@@ -367,14 +384,49 @@ where
         }
     }

+    /// Drain all events in the channel in backoff mode, i.e., an un-retryable error occurred and
+    /// nothing should be sent out before the backoff interval has expired.
+    ///
+    /// During the backoff period no RPCs should be sent, but events should still be received so
+    /// that, if the channel is closed, it can quit at once.
+    #[tracing::instrument(level = "trace", skip(self))]
+    pub async fn backoff_drain_events(&mut self, until: Instant) -> Result<(), ReplicationClosed> {
+        let d = until - Instant::now();
+        tracing::warn!(
+            interval = debug(d),
+            "{} backoff mode: drain events without processing them",
+            func_name!()
+        );
+
+        loop {
+            let sleep_duration = until - Instant::now();
+            let sleep = sleep(sleep_duration);
+
+            let recv = self.rx_repl.recv();
+
+            tracing::debug!("backoff timeout: {:?}", sleep_duration);
+
+            select! {
+                _ = sleep => {
+                    tracing::debug!("backoff timeout");
+                    return Ok(());
+                }
+                recv_res = recv => {
+                    let event = recv_res.ok_or(ReplicationClosed {})?;
+                    self.process_event(event);
+                }
+            }
+        }
+    }
+
     /// Receive and process events from RaftCore, until `next_action` is filled.
     ///
     /// It blocks until at least one event is received.
     #[tracing::instrument(level = "trace", skip_all)]
-    pub async fn drain_events(&mut self) -> Result<(), ReplicationError> {
+    pub async fn drain_events(&mut self) -> Result<(), ReplicationClosed> {
         tracing::debug!("drain_events");

-        let event = self.rx_repl.recv().await.ok_or(ReplicationError::Closed)?;
+        let event = self.rx_repl.recv().await.ok_or(ReplicationClosed {})?;
         self.process_event(event);

         self.try_drain_events().await?;
@@ -399,10 +451,13 @@ where
     }

     #[tracing::instrument(level = "trace", skip(self))]
-    pub async fn try_drain_events(&mut self) -> Result<(), ReplicationError> {
+    pub async fn try_drain_events(&mut self) -> Result<(), ReplicationClosed> {
         tracing::debug!("try_drain_raft_rx");

-        while self.next_action.is_none() {
+        // Just drain all events in the channel.
+        // There should not be more than one `Replicate::Data` event in the channel.
+        // Looping just collects all commit events and heartbeat events.
+        loop {
             let maybe_res = self.rx_repl.recv().now_or_never();

             let recv_res = match maybe_res {
@@ -413,12 +468,10 @@ where
                 Some(x) => x,
             };

-            let event = recv_res.ok_or(ReplicationError::Closed)?;
+            let event = recv_res.ok_or(ReplicationClosed {})?;

             self.process_event(event);
         }
-
-        Ok(())
     }

     #[tracing::instrument(level = "trace", skip_all)]
@@ -445,7 +498,9 @@ where
                 //- If self.next_action is not None, next_action will serve as a heartbeat.
             }
             Replicate::Data(d) => {
-                debug_assert!(self.next_action.is_none(),);
+                // TODO: Currently there is at most 1 in-flight data action. But in future, RaftCore may send the
+                // next data action without waiting for the previous one to finish.
+                debug_assert!(self.next_action.is_none(), "there cannot be two data actions in flight");
                 self.next_action = Some(d);
             }
         }
diff --git a/tests/tests/fixtures/mod.rs b/tests/tests/fixtures/mod.rs
index 847b4f859..23089e61c 100644
--- a/tests/tests/fixtures/mod.rs
+++ b/tests/tests/fixtures/mod.rs
@@ -5,6 +5,7 @@
 #[cfg(feature = "bt")] use std::backtrace::Backtrace;
 use std::collections::BTreeMap;
 use std::collections::BTreeSet;
+use std::collections::HashMap;
 use std::collections::HashSet;
 use std::env;
 use std::panic::PanicInfo;
@@ -28,6 +29,7 @@ use openraft::error::NetworkError;
 use openraft::error::RPCError;
 use openraft::error::RaftError;
 use openraft::error::RemoteError;
+use openraft::error::Unreachable;
 use openraft::metrics::Wait;
 use openraft::raft::AppendEntriesRequest;
 use openraft::raft::AppendEntriesResponse;
@@ -144,12 +146,25 @@ pub struct TypedRaftRouter {
     #[allow(clippy::type_complexity)]
     routing_table: Arc>>,

-    /// Nodes which are isolated can neither send nor receive frames.
+    /// Nodes which are isolated can neither send nor receive frames; RPCs to them return a `NetworkError`.
     isolated_nodes: Arc<Mutex<HashSet<MemNodeId>>>,

+    /// Nodes to which an RPC is sent return an `Unreachable` error.
+    unreachable_nodes: Arc<Mutex<HashSet<MemNodeId>>>,
+
     /// To emulate network delay for sending, in milliseconds.
     /// 0 means no delay.
     send_delay: Arc<AtomicU64>,
+
+    /// Count of RPCs sent.
+    rpc_count: Arc<Mutex<HashMap<RPCType, u64>>>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum RPCType {
+    AppendEntries,
+    InstallSnapshot,
+    Vote,
 }

 /// Default `RaftRouter` for memstore.
@@ -181,7 +196,9 @@ impl Builder {
             config: self.config,
             routing_table: Default::default(),
             isolated_nodes: Default::default(),
+            unreachable_nodes: Default::default(),
             send_delay: Arc::new(AtomicU64::new(send_delay)),
+            rpc_count: Default::default(),
         }
     }
 }
@@ -192,7 +209,9 @@ impl Clone for TypedRaftRouter {
             config: self.config.clone(),
             routing_table: self.routing_table.clone(),
             isolated_nodes: self.isolated_nodes.clone(),
+            unreachable_nodes: self.unreachable_nodes.clone(),
             send_delay: self.send_delay.clone(),
+            rpc_count: self.rpc_count.clone(),
         }
     }
 }
@@ -223,6 +242,16 @@ impl TypedRaftRouter {
         tokio::time::sleep(timeout).await;
     }

+    fn count_rpc(&self, rpc_type: RPCType) {
+        let mut rpc_count = self.rpc_count.lock().unwrap();
+        let count = rpc_count.entry(rpc_type).or_insert(0);
+        *count += 1;
+    }
+
+    pub fn get_rpc_count(&self) -> HashMap<RPCType, u64> {
+        self.rpc_count.lock().unwrap().clone()
+    }
+
     /// Create a cluster: 0 is the initial leader, others are voters and learners
     ///
     /// NOTE: it create a single node cluster first, then change it to a multi-voter cluster.
@@ -315,7 +344,6 @@ impl TypedRaftRouter {
     }

     #[tracing::instrument(level = "debug", skip_all)]
-
     pub async fn new_raft_node_with_sto(&mut self, id: MemNodeId, log_store: MemLogStore, sm: MemStateMachine) {
         let node = Raft::new(id, self.config.clone(), self.clone(), log_store.clone(), sm.clone()).await.unwrap();
         let mut rt = self.routing_table.lock().unwrap();
@@ -356,6 +384,17 @@ impl TypedRaftRouter {
         self.isolated_nodes.lock().unwrap().insert(id);
     }

+    /// Set to `true` to return [`Unreachable`](`openraft::error::Unreachable`) when sending an RPC
+    /// to a node.
+    pub fn set_unreachable(&self, id: MemNodeId, unreachable: bool) {
+        let mut u = self.unreachable_nodes.lock().unwrap();
+        if unreachable {
+            u.insert(id);
+        } else {
+            u.remove(&id);
+        }
+    }
+
     /// Get a payload of the latest metrics from each node in the cluster.
     #[allow(clippy::significant_drop_in_scrutinee)]
     pub fn latest_metrics(&self) -> Vec> {
@@ -882,7 +921,7 @@ impl TypedRaftRouter {
     }

     #[tracing::instrument(level = "debug", skip(self))]
-    pub fn check_reachable(&self, id: MemNodeId, target: MemNodeId) -> Result<(), NetworkError> {
+    pub fn check_network_error(&self, id: MemNodeId, target: MemNodeId) -> Result<(), NetworkError> {
         let isolated = self.isolated_nodes.lock().unwrap();

         if isolated.contains(&target) || isolated.contains(&id) {
@@ -892,6 +931,18 @@ impl TypedRaftRouter {

         Ok(())
     }
+
+    #[tracing::instrument(level = "debug", skip(self))]
+    pub fn check_unreachable(&self, id: MemNodeId, target: MemNodeId) -> Result<(), Unreachable> {
+        let unreachable = self.unreachable_nodes.lock().unwrap();
+
+        if unreachable.contains(&target) || unreachable.contains(&id) {
+            let err = Unreachable::new(&AnyError::error(format!("unreachable:{} -> {}", id, target)));
+            return Err(err);
+        }
+
+        Ok(())
+    }
 }

 #[async_trait]
@@ -919,7 +970,10 @@ impl RaftNetwork for RaftRouterNetwork {
         rpc: AppendEntriesRequest,
     ) -> Result, RPCError>> {
         tracing::debug!("append_entries to id={} {}", self.target, rpc.summary());
-        self.owner.check_reachable(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
+        self.owner.count_rpc(RPCType::AppendEntries);
+
+        self.owner.check_network_error(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
+        self.owner.check_unreachable(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
         self.owner.rand_send_delay().await;

         let node = self.owner.get_raft_handle(&self.target)?;
@@ -928,6 +982,7 @@ impl RaftNetwork for RaftRouterNetwork {
         tracing::debug!("append_entries: recv resp from id={} {:?}", self.target, resp);

         let resp = resp.map_err(|e| RemoteError::new(self.target, e))?;
+
         Ok(resp)
     }

@@ -937,13 +992,17 @@ impl RaftNetwork for RaftRouterNetwork {
         rpc: InstallSnapshotRequest,
     ) -> Result, RPCError>> {
-        self.owner.check_reachable(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
+        self.owner.count_rpc(RPCType::InstallSnapshot);
+
+        self.owner.check_network_error(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
+        self.owner.check_unreachable(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
         self.owner.rand_send_delay().await;

         let node = self.owner.get_raft_handle(&self.target)?;

         let resp = node.install_snapshot(rpc).await;
         let resp = resp.map_err(|e| RemoteError::new(self.target, e))?;
+
         Ok(resp)
     }

@@ -952,13 +1011,17 @@ impl RaftNetwork for RaftRouterNetwork {
         &mut self,
         rpc: VoteRequest,
     ) -> Result, RPCError>> {
-        self.owner.check_reachable(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
+        self.owner.count_rpc(RPCType::Vote);
+
+        self.owner.check_network_error(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
+        self.owner.check_unreachable(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
         self.owner.rand_send_delay().await;

         let node = self.owner.get_raft_handle(&self.target)?;

         let resp = node.vote(rpc).await;
         let resp = resp.map_err(|e| RemoteError::new(self.target, e))?;
+
         Ok(resp)
     }
 }
diff --git a/tests/tests/replication/main.rs b/tests/tests/replication/main.rs
new file mode 100644
index 000000000..c7b12bee5
--- /dev/null
+++ b/tests/tests/replication/main.rs
@@ -0,0 +1,8 @@
+#![cfg_attr(feature = "bt", feature(error_generic_member_access))]
+#![cfg_attr(feature = "bt", feature(provide_any))]
+
+#[macro_use]
+#[path = "../fixtures/mod.rs"]
+mod fixtures;
+
+mod t50_append_entries_backoff;
diff --git a/tests/tests/replication/t50_append_entries_backoff.rs b/tests/tests/replication/t50_append_entries_backoff.rs
new file mode 100644
index 000000000..bdb66407b
--- /dev/null
+++ b/tests/tests/replication/t50_append_entries_backoff.rs
@@ -0,0 +1,63 @@
+use std::sync::Arc;
+use std::time::Duration;
+
+use anyhow::Result;
+use maplit::btreeset;
+use openraft::Config;
+
+use crate::fixtures::init_default_ut_tracing;
+use crate::fixtures::RPCType;
+use crate::fixtures::RaftRouter;
+
+/// Append-entries should back off when an `Unreachable` error is returned.
+#[async_entry::test(worker_threads = 4, init = "init_default_ut_tracing()", tracing_span = "debug")]
+async fn append_entries_backoff() -> Result<()> {
+    let config = Arc::new(
+        Config {
+            heartbeat_interval: 5_000,
+            election_timeout_min: 10_000,
+            election_timeout_max: 10_001,
+            ..Default::default()
+        }
+        .validate()?,
+    );
+
+    let mut router = RaftRouter::new(config.clone());
+
+    tracing::info!("--- initializing cluster");
+    let mut log_index = router.new_cluster(btreeset! {0,1,2}, btreeset! {}).await?;
+
+    let counts0 = router.get_rpc_count();
+    let n = 10u64;
+
+    tracing::info!("--- set node 2 to unreachable, and write 10 entries");
+    {
+        router.set_unreachable(2, true);
+
+        router.client_request_many(0, "0", n as usize).await?;
+        log_index += n;
+
+        router.wait(&0, timeout()).log(Some(log_index), format!("{} writes", n)).await?;
+    }
+
+    let counts1 = router.get_rpc_count();
+
+    let c0 = *counts0.get(&RPCType::AppendEntries).unwrap_or(&0);
+    let c1 = *counts1.get(&RPCType::AppendEntries).unwrap_or(&0);
+
+    // dbg!(counts0);
+    // dbg!(counts1);
+
+    // Without backoff, the leader would send about 40 append-entries RPCs:
+    // 20 for appending log entries and 20 for updating the committed index.
+    assert!(
+        n < c1 - c0 && c1 - c0 < n * 4,
+        "append-entries should back off when an `Unreachable` error is returned"
+    );
+
+    Ok(())
+}
+
+fn timeout() -> Option<Duration> {
+    Some(Duration::from_millis(1_000))
+}
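
For reference, a minimal sketch of how an application-side network implementation could use the two pieces added above: mapping a connection failure to `Unreachable` so that RaftCore backs off instead of retrying in a tight loop, and supplying an exponential policy from `backoff()` instead of the default constant 500 ms. The helper names (`to_unreachable`, `exponential_backoff`) and the `openraft::network::Backoff` import path are illustrative assumptions, not part of this change.

use std::time::Duration;

use openraft::error::Unreachable;
use openraft::network::Backoff;

// Hypothetical helper: wrap a low-level I/O error as `Unreachable`, the same way the
// test fixture above wraps an `AnyError`, so the leader sleeps before retrying this target.
fn to_unreachable(e: &std::io::Error) -> Unreachable {
    Unreachable::new(e)
}

// Hypothetical `RaftNetwork::backoff()` policy: exponential backoff starting at 100 ms
// and capped at 5 s. `Backoff::new` accepts any infinite `Iterator<Item = Duration>`.
fn exponential_backoff() -> Backoff {
    let delays = (0u32..).map(|i| {
        let ms = 100u64.saturating_mul(2u64.saturating_pow(i)).min(5_000);
        Duration::from_millis(ms)
    });
    Backoff::new(delays)
}

With such a policy in place, an unreachable follower is polled on the backoff schedule rather than hammered on every loop iteration, which is the behavior the t50_append_entries_backoff test above asserts by counting append-entries RPCs.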