Skip to content

Commit

Permalink
Testing helpers for networking
Browse files Browse the repository at this point in the history
Small refactoring to make it easier to test networking code and to run TestCenter in tests
  • Loading branch information
AhmedSoliman committed Sep 9, 2024
1 parent 7222cd4 commit fbf0413
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 44 deletions.
26 changes: 24 additions & 2 deletions crates/core/src/network/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use crate::Metadata;
/// The primary owner of a connection is the running reactor, all other components
/// should hold a Weak<Connection> if caching access to a certain connection is
/// needed.
pub(crate) struct Connection {
pub struct Connection {
/// Connection identifier, randomly generated on this end of the connection.
pub(crate) cid: u64,
pub(crate) peer: GenerationalNodeId,
Expand All @@ -57,7 +57,7 @@ pub(crate) struct Connection {
}

impl Connection {
pub fn new(
pub(crate) fn new(
peer: GenerationalNodeId,
protocol_version: ProtocolVersion,
sender: mpsc::Sender<Message>,
Expand All @@ -71,6 +71,15 @@ impl Connection {
}
}

#[cfg(any(test, feature = "test-util"))]
pub fn new_fake(
peer: GenerationalNodeId,
protocol_version: ProtocolVersion,
sender: mpsc::Sender<Message>,
) -> Arc<Self> {
Arc::new(Self::new(peer, protocol_version, sender))
}

/// The current negotiated protocol version of the connection
pub fn protocol_version(&self) -> ProtocolVersion {
self.protocol_version
Expand Down Expand Up @@ -208,6 +217,19 @@ pub(crate) struct HeaderMetadataVersions {
versions: EnumMap<MetadataKind, Option<Version>>,
}

impl Default for HeaderMetadataVersions {
// Used primarily in tests
fn default() -> Self {
let versions = enum_map! {
MetadataKind::NodesConfiguration => Some(Version::MIN),
MetadataKind::Schema => None,
MetadataKind::Logs => None,
MetadataKind::PartitionTable => None,
};
Self { versions }
}
}

impl HeaderMetadataVersions {
pub fn from_metadata(metadata: &Metadata) -> Self {
let versions = enum_map! {
Expand Down
15 changes: 8 additions & 7 deletions crates/core/src/network/connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,21 @@
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use std::collections::{hash_map, HashMap};
use std::sync::{Arc, Mutex, Weak};
use std::time::Instant;

use enum_map::EnumMap;
use futures::stream::BoxStream;
use futures::{Stream, StreamExt};
use rand::seq::SliceRandom;
use restate_types::config::NetworkingOptions;
use restate_types::net::codec::try_unwrap_binary_message;
use std::collections::{hash_map, HashMap};
use std::sync::{Arc, Mutex, Weak};
use std::time::Instant;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tonic::transport::Channel;
use tracing::{debug, info, trace, warn, Instrument, Span};

use restate_types::config::NetworkingOptions;
use restate_types::net::codec::MessageBodyExt;
use restate_types::net::metadata::MetadataKind;
use restate_types::net::AdvertisedAddress;
use restate_types::nodes_config::NodesConfiguration;
Expand Down Expand Up @@ -577,7 +578,7 @@ where
break;
}

match try_unwrap_binary_message(body, connection.protocol_version) {
match body.try_as_binary_body(connection.protocol_version) {
Ok(msg) => {
if let Err(e) = router
.call(
Expand Down Expand Up @@ -626,7 +627,7 @@ where
};
if let Some(body) = msg.body {
// we ignore non-deserializable messages (serde errors, or control signals in drain)
if let Ok(msg) = try_unwrap_binary_message(body, protocol_version) {
if let Ok(msg) = body.try_as_binary_body(protocol_version) {
drain_counter += 1;
if let Err(e) = router
.call(
Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub mod protobuf;
pub mod rpc_router;
mod types;

pub use connection::ConnectionSender;
pub use connection::{Connection, ConnectionSender};
pub use connection_manager::ConnectionManager;
pub use error::*;
pub use message_router::*;
Expand Down
10 changes: 9 additions & 1 deletion crates/core/src/network/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ impl<M> Incoming<M> {
in_response_to,
}
}

#[cfg(any(test, feature = "test-util"))]
pub fn for_testing(connection: &Arc<Connection>, body: M, in_response_to: Option<u64>) -> Self {
let peer = connection.peer;
let connection = Arc::downgrade(connection);
let msg_id = generate_msg_id();
Self::from_parts(peer, body, connection, msg_id, in_response_to)
}
}

impl<M> Incoming<M> {
Expand Down Expand Up @@ -300,7 +308,7 @@ impl<M: Targeted + WireEncode> Outgoing<M> {
return Err(NetworkSendError::new(self, NetworkError::ConnectionClosed));
}
};
let versions = with_metadata(HeaderMetadataVersions::from_metadata);
let versions = with_metadata(HeaderMetadataVersions::from_metadata).unwrap_or_default();
Ok((connection, versions, self))
}
}
18 changes: 10 additions & 8 deletions crates/core/src/task_center.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ impl TaskCenterBuilder {
self
}

#[cfg(any(test, feature = "test-util"))]
pub fn default_for_tests() -> Self {
Self::default()
.ingress_runtime_handle(tokio::runtime::Handle::current())
.default_runtime_handle(tokio::runtime::Handle::current())
.pause_time(true)
}

pub fn build(mut self) -> Result<TaskCenter, TaskCenterBuildError> {
let options = self.options.unwrap_or_default();
if self.default_runtime_handle.is_none() {
Expand Down Expand Up @@ -994,17 +1002,11 @@ pub fn metadata() -> Metadata {
}

#[track_caller]
pub fn with_metadata<F, R>(f: F) -> R
pub fn with_metadata<F, R>(f: F) -> Option<R>
where
F: FnOnce(&Metadata) -> R,
{
CONTEXT.with(|ctx| {
let metadata = ctx
.metadata
.as_ref()
.expect("metadata() called before global metadata was set");
f(metadata)
})
CONTEXT.with(|ctx| ctx.metadata.as_ref().map(f))
}

/// Access to this node id. This is available in task-center tasks only!
Expand Down
27 changes: 14 additions & 13 deletions crates/core/src/test_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,15 @@ use std::str::FromStr;
use std::sync::Arc;

use tokio::sync::{mpsc, RwLock};
use tracing::info;

use crate::metadata_store::{MetadataStoreClient, Precondition};
use crate::network::{
Handler, Incoming, MessageHandler, MessageRouter, MessageRouterBuilder, NetworkError,
NetworkSendError, NetworkSender, Outgoing,
};
use crate::{
cancellation_watcher, metadata, spawn_metadata_manager, MetadataBuilder, ShutdownError, TaskId,
};
use crate::{Metadata, MetadataManager, MetadataWriter};
use crate::{TaskCenter, TaskCenterBuilder};
use restate_types::cluster_controller::{ReplicationStrategy, SchedulingPlan};
use restate_types::logs::metadata::{bootstrap_logs_metadata, ProviderKind};
use restate_types::metadata_store::keys::{
BIFROST_CONFIG_KEY, NODES_CONFIG_KEY, PARTITION_TABLE_KEY, SCHEDULING_PLAN_KEY,
};
use restate_types::net::codec::{
serialize_message, try_unwrap_binary_message, Targeted, WireDecode, WireEncode,
serialize_message, MessageBodyExt, Targeted, WireDecode, WireEncode,
};
use restate_types::net::metadata::MetadataKind;
use restate_types::net::AdvertisedAddress;
Expand All @@ -39,7 +30,17 @@ use restate_types::nodes_config::{LogServerConfig, NodeConfig, NodesConfiguratio
use restate_types::partition_table::PartitionTable;
use restate_types::protobuf::node::{Header, Message};
use restate_types::{GenerationalNodeId, Version};
use tracing::info;

use crate::metadata_store::{MetadataStoreClient, Precondition};
use crate::network::{
Handler, Incoming, MessageHandler, MessageRouter, MessageRouterBuilder, NetworkError,
NetworkSendError, NetworkSender, Outgoing,
};
use crate::{
cancellation_watcher, metadata, spawn_metadata_manager, MetadataBuilder, ShutdownError, TaskId,
};
use crate::{Metadata, MetadataManager, MetadataWriter};
use crate::{TaskCenter, TaskCenterBuilder};

#[derive(Clone)]
pub struct MockNetworkSender {
Expand Down Expand Up @@ -145,7 +146,7 @@ impl NetworkReceiver {
let header = msg.header.expect("header must be set");
let msg = Incoming::from_parts(
peer,
try_unwrap_binary_message(body, CURRENT_PROTOCOL_VERSION)?,
body.try_as_binary_body(CURRENT_PROTOCOL_VERSION)?,
std::sync::Weak::new(),
header.msg_id,
header.in_response_to,
Expand Down
40 changes: 28 additions & 12 deletions crates/types/src/net/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,18 +172,6 @@ pub fn serialize_message<M: WireEncode + Targeted>(
}))
}

pub fn try_unwrap_binary_message(
msg: message::Body,
_protocol_version: ProtocolVersion,
) -> Result<BinaryMessage, CodecError> {
let message::Body::Encoded(binary) = msg else {
return Err(CodecError::Decode(
"Cannot deserialize message, message is not of type BinaryMessage".into(),
));
};
Ok(binary)
}

/// Helper function for default encoding of values.
pub fn encode_default<T: Serialize, B: BufMut>(
value: T,
Expand Down Expand Up @@ -213,3 +201,31 @@ pub fn decode_default<T: DeserializeOwned, B: Buf>(
}
}
}

pub trait MessageBodyExt {
fn try_as_binary_body(
self,
protocol_version: ProtocolVersion,
) -> Result<BinaryMessage, CodecError>;

fn try_decode<T: WireDecode>(self, protocol_version: ProtocolVersion) -> Result<T, CodecError>;
}

impl MessageBodyExt for crate::protobuf::node::message::Body {
fn try_as_binary_body(
self,
_protocol_version: ProtocolVersion,
) -> Result<BinaryMessage, CodecError> {
let message::Body::Encoded(binary) = self else {
return Err(CodecError::Decode(
"Cannot deserialize message, message is not of type BinaryMessage".into(),
));
};
Ok(binary)
}

fn try_decode<T: WireDecode>(self, protocol_version: ProtocolVersion) -> Result<T, CodecError> {
let mut binary_message = self.try_as_binary_body(protocol_version)?;
<T as WireDecode>::decode(&mut binary_message.payload, protocol_version)
}
}

0 comments on commit fbf0413

Please sign in to comment.