From 8ecd5f4905b13ce56bccde68a26b69c38b8a8a58 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Tue, 4 May 2021 22:37:35 +0200 Subject: [PATCH 01/23] misc/multistream-select: Implement simultaneous open extension From the multistream-select 1.0 simultaneous open protocol extension specification: > In order to support direct connections through NATs with hole punching, we need to account for simultaneous open. In such cases, there is no single initiator and responder, but instead both peers act as initiators. This breaks protocol negotiation in multistream-select, which assumes a single initator. > This draft proposes a simple extension to the multistream protocol negotiation in order to select a single initator when both peers are acting as such. See https://github.com/libp2p/specs/pull/196/ for details. This commit implements the above specification, available via `Version::V1SimOpen`. --- misc/multistream-select/Cargo.toml | 1 + misc/multistream-select/src/dialer_select.rs | 331 +++++++++++++++++- misc/multistream-select/src/lib.rs | 1 + .../multistream-select/src/listener_select.rs | 33 +- misc/multistream-select/src/protocol.rs | 58 ++- misc/multistream-select/src/tests.rs | 54 ++- 6 files changed, 444 insertions(+), 34 deletions(-) diff --git a/misc/multistream-select/Cargo.toml b/misc/multistream-select/Cargo.toml index 7c86be9e6ac..ef689d5e7ee 100644 --- a/misc/multistream-select/Cargo.toml +++ b/misc/multistream-select/Cargo.toml @@ -14,6 +14,7 @@ bytes = "1" futures = "0.3" log = "0.4" pin-project = "1.0.0" +rand = "0.7" smallvec = "1.6.1" unsigned-varint = "0.7" diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index 482d31a56b3..6d80412e87b 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -21,10 +21,11 @@ //! Protocol negotiation strategies for the peer acting as the dialer. use crate::{Negotiated, NegotiationError, Version}; -use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine}; +use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine, SIM_OPEN_ID}; use futures::{future::Either, prelude::*}; -use std::{convert::TryFrom as _, iter, mem, pin::Pin, task::{Context, Poll}}; +use std::{cmp::Ordering, convert::TryFrom as _, iter, mem, pin::Pin, task::{Context, Poll}}; + /// Returns a `Future` that negotiates a protocol on the given I/O stream /// for a peer acting as the _dialer_ (or _initiator_). @@ -56,11 +57,18 @@ where I::Item: AsRef<[u8]> { let iter = protocols.into_iter(); - // We choose between the "serial" and "parallel" strategies based on the number of protocols. - if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) { - Either::Left(dialer_select_proto_serial(inner, iter, version)) - } else { - Either::Right(dialer_select_proto_parallel(inner, iter, version)) + match version { + Version::V1 | Version::V1Lazy => { + // We choose between the "serial" and "parallel" strategies based on the number of protocols. 
+ if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) { + Either::Left(dialer_select_proto_serial(inner, iter, version)) + } else { + Either::Right(dialer_select_proto_parallel(inner, iter, version)) + } + }, + Version::V1SimOpen => { + Either::Left(dialer_select_proto_serial(inner, iter, version)) + } } } @@ -145,7 +153,16 @@ where R: AsyncRead + AsyncWrite, N: AsRef<[u8]> { - SendHeader { io: MessageIO, }, + SendHeader { io: MessageIO }, + + // Simultaneous open protocol extension + SendSimOpen { io: MessageIO, protocol: Option }, + FlushSimOpen { io: MessageIO, protocol: N }, + AwaitSimOpen { io: MessageIO, protocol: N }, + SimOpenPhase { selection: SimOpenPhase, protocol: N }, + Responder { responder: crate::ListenerSelectFuture }, + + // Standard multistream-select protocol SendProtocol { io: MessageIO, protocol: N }, FlushProtocol { io: MessageIO, protocol: N }, AwaitProtocol { io: MessageIO, protocol: N }, @@ -158,7 +175,8 @@ where // It also makes the implementation considerably easier to write. R: AsyncRead + AsyncWrite + Unpin, I: Iterator, - I::Item: AsRef<[u8]> + // TODO: Clone needed to embed ListenerSelectFuture. Still needed? + I::Item: AsRef<[u8]> + Clone { type Output = Result<(I::Item, Negotiated), NegotiationError>; @@ -181,11 +199,123 @@ where return Poll::Ready(Err(From::from(err))); } - let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; + match this.version { + Version::V1 | Version::V1Lazy => { + let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; + + // The dialer always sends the header and the first protocol + // proposal in one go for efficiency. + *this.state = SeqState::SendProtocol { io, protocol }; + } + Version::V1SimOpen => { + *this.state = SeqState::SendSimOpen { io, protocol: None }; + } + } + } + + SeqState::SendSimOpen { mut io, protocol } => { + match Pin::new(&mut io).poll_ready(cx)? { + Poll::Ready(()) => {}, + Poll::Pending => { + *this.state = SeqState::SendSimOpen { io, protocol }; + return Poll::Pending + }, + } + + match protocol { + None => { + let msg = Message::Protocol(SIM_OPEN_ID); + if let Err(err) = Pin::new(&mut io).start_send(msg) { + return Poll::Ready(Err(From::from(err))); + } + + let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; + *this.state = SeqState::SendSimOpen { io, protocol: Some(protocol) }; + } + Some(protocol) => { + let p = Protocol::try_from(protocol.as_ref())?; + if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) { + return Poll::Ready(Err(From::from(err))); + } + log::debug!("Dialer: Proposed protocol: {}", p); + + *this.state = SeqState::FlushSimOpen { io, protocol } + } + } + } + + SeqState::FlushSimOpen { mut io, protocol } => { + match Pin::new(&mut io).poll_flush(cx)? { + Poll::Ready(()) => { + *this.state = SeqState::AwaitSimOpen { io, protocol } + }, + Poll::Pending => { + *this.state = SeqState::FlushSimOpen { io, protocol }; + return Poll::Pending + }, + } + } + + SeqState::AwaitSimOpen { mut io, protocol } => { + let msg = match Pin::new(&mut io).poll_next(cx)? { + Poll::Ready(Some(msg)) => msg, + Poll::Pending => { + *this.state = SeqState::AwaitSimOpen { io, protocol }; + return Poll::Pending + } + // Treat EOF error as [`NegotiationError::Failed`], not as + // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O + // stream as a permissible way to "gracefully" fail a negotiation. 
+ Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), + }; + + match msg { + Message::Header(v) if v == HeaderLine::from(*this.version) => { + *this.state = SeqState::AwaitSimOpen { io, protocol }; + } + Message::Protocol(p) if p == SIM_OPEN_ID => { + let selection = SimOpenPhase { + state: SimOpenState::SendNonce{ io }, + }; + *this.state = SeqState::SimOpenPhase { selection, protocol }; + } + Message::NotAvailable => { + *this.state = SeqState::AwaitProtocol { io, protocol } + } + _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())) + } + } + + SeqState::SimOpenPhase { mut selection, protocol } => { + let (io, selection_res) = match Pin::new(&mut selection).poll(cx)? { + Poll::Ready((io, res)) => (io, res), + Poll::Pending => { + *this.state = SeqState::SimOpenPhase { selection, protocol }; + return Poll::Pending + } + }; + + match selection_res { + SimOpenRole::Initiator => { + *this.state = SeqState::SendProtocol { io, protocol }; + } + SimOpenRole::Responder => { + let protocols: Vec<_> = this.protocols.collect(); + *this.state = SeqState::Responder { + responder: crate::listener_select::listener_select_proto_no_header(io, std::iter::once(protocol).chain(protocols.into_iter())), + } + }, + } + } - // The dialer always sends the header and the first protocol - // proposal in one go for efficiency. - *this.state = SeqState::SendProtocol { io, protocol }; + SeqState::Responder { mut responder } => { + match Pin::new(&mut responder ).poll(cx) { + Poll::Ready(res) => return Poll::Ready(res), + Poll::Pending => { + *this.state = SeqState::Responder { responder }; + return Poll::Pending + } + } } SeqState::SendProtocol { mut io, protocol } => { @@ -207,7 +337,7 @@ where *this.state = SeqState::FlushProtocol { io, protocol } } else { match this.version { - Version::V1 => *this.state = SeqState::FlushProtocol { io, protocol }, + Version::V1 | Version::V1SimOpen => *this.state = SeqState::FlushProtocol { io, protocol }, // This is the only effect that `V1Lazy` has compared to `V1`: // Optimistically settling on the only protocol that // the dialer supports for this negotiation. Notably, @@ -224,7 +354,9 @@ where SeqState::FlushProtocol { mut io, protocol } => { match Pin::new(&mut io).poll_flush(cx)? 
{ - Poll::Ready(()) => *this.state = SeqState::AwaitProtocol { io, protocol }, + Poll::Ready(()) => { + *this.state = SeqState::AwaitProtocol { io, protocol } + } , Poll::Pending => { *this.state = SeqState::FlushProtocol { io, protocol }; return Poll::Pending @@ -245,10 +377,17 @@ where Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), }; + match msg { Message::Header(v) if v == HeaderLine::from(*this.version) => { *this.state = SeqState::AwaitProtocol { io, protocol }; } + Message::Protocol(p) if p == SIM_OPEN_ID => { + let selection = SimOpenPhase { + state: SimOpenState::SendNonce{ io }, + }; + *this.state = SeqState::SimOpenPhase { selection, protocol }; + } Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => { log::debug!("Dialer: Received confirmation for protocol: {}", p); let io = Negotiated::completed(io.into_inner()); @@ -270,6 +409,168 @@ where } } +struct SimOpenPhase { + state: SimOpenState, +} + +enum SimOpenState { + SendNonce { io: MessageIO }, + FlushNonce { io: MessageIO, local_nonce: u64 }, + ReadNonce { io: MessageIO, local_nonce: u64 }, + SendRole { io: MessageIO, local_role: SimOpenRole }, + FlushRole { io: MessageIO, local_role: SimOpenRole }, + ReadRole { io: MessageIO, local_role: SimOpenRole }, + Done, +} + +enum SimOpenRole { + Initiator, + Responder, +} + +impl Future for SimOpenPhase +where + // The Unpin bound here is required because we produce a `Negotiated` as the output. + // It also makes the implementation considerably easier to write. + R: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result<(MessageIO, SimOpenRole), NegotiationError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + + loop { + match mem::replace(&mut self.state, SimOpenState::Done) { + SimOpenState::SendNonce { mut io } => { + match Pin::new(&mut io).poll_ready(cx)? { + Poll::Ready(()) => {}, + Poll::Pending => { + self.state = SimOpenState::SendNonce { io }; + return Poll::Pending + }, + } + + let local_nonce = rand::random(); + let msg = Message::Select(local_nonce); + if let Err(err) = Pin::new(&mut io).start_send(msg) { + return Poll::Ready(Err(From::from(err))); + } + + self.state = SimOpenState::FlushNonce { + io, + local_nonce, + }; + }, + SimOpenState::FlushNonce { mut io, local_nonce } => { + match Pin::new(&mut io).poll_flush(cx)? { + Poll::Ready(()) => self.state = SimOpenState::ReadNonce { + io, + local_nonce, + }, + Poll::Pending => { + self.state =SimOpenState::FlushNonce { io, local_nonce }; + return Poll::Pending + }, + } + }, + SimOpenState::ReadNonce { mut io, local_nonce } => { + let msg = match Pin::new(&mut io).poll_next(cx)? { + Poll::Ready(Some(msg)) => msg, + Poll::Pending => { + self.state = SimOpenState::ReadNonce { io, local_nonce }; + return Poll::Pending + } + // Treat EOF error as [`NegotiationError::Failed`], not as + // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O + // stream as a permissible way to "gracefully" fail a negotiation. + Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), + }; + + match msg { + // TODO: Document that this might still be the protocol send by the remote with + // the sim open ID. + Message::Protocol(_) => { + self.state = SimOpenState::ReadNonce { io, local_nonce }; + } + Message::Select(remote_nonce) => { + match local_nonce.cmp(&remote_nonce) { + Ordering::Equal => { + // Start over. 
+ self.state = SimOpenState::SendNonce { io }; + }, + Ordering::Greater => { + self.state = SimOpenState::SendRole { + io, + local_role: SimOpenRole::Initiator, + }; + }, + Ordering::Less => { + self.state = SimOpenState::SendRole { + io, + local_role: SimOpenRole::Responder, + }; + } + } + } + _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), + } + }, + SimOpenState::SendRole { mut io, local_role } => { + match Pin::new(&mut io).poll_ready(cx)? { + Poll::Ready(()) => {}, + Poll::Pending => { + self.state = SimOpenState::SendRole { io, local_role }; + return Poll::Pending + }, + } + + let msg = match local_role { + SimOpenRole::Initiator => Message::Initiator, + SimOpenRole::Responder => Message::Responder, + }; + + if let Err(err) = Pin::new(&mut io).start_send(msg) { + return Poll::Ready(Err(From::from(err))); + } + + self.state = SimOpenState::FlushRole { io, local_role }; + }, + SimOpenState::FlushRole { mut io, local_role } => { + match Pin::new(&mut io).poll_flush(cx)? { + Poll::Ready(()) => self.state = SimOpenState::ReadRole { io, local_role }, + Poll::Pending => { + self.state =SimOpenState::FlushRole { io, local_role }; + return Poll::Pending + }, + } + }, + SimOpenState::ReadRole { mut io, local_role } => { + let remote_msg = match Pin::new(&mut io).poll_next(cx)? { + Poll::Ready(Some(msg)) => msg, + Poll::Pending => { + self.state = SimOpenState::ReadRole { io, local_role }; + return Poll::Pending + } + // Treat EOF error as [`NegotiationError::Failed`], not as + // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O + // stream as a permissible way to "gracefully" fail a negotiation. + Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), + }; + + let result = match local_role { + SimOpenRole::Initiator if remote_msg == Message::Responder => Ok((io, local_role)), + SimOpenRole::Responder if remote_msg == Message::Initiator => Ok((io, local_role)), + + _ => Err(ProtocolError::InvalidMessage.into()) + }; + + return Poll::Ready(result) + }, + SimOpenState::Done => panic!("SimOpenPhase::poll called after completion") + } + } + } +} + /// A `Future` returned by [`dialer_select_proto_parallel`] which negotiates /// a protocol selectively by considering all supported protocols of the remote /// "in parallel". 
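For reference, the tie-break implemented by `SimOpenPhase` above reduces to a comparison of the two random nonces: each peer sends `select:<nonce>`, the peer with the numerically larger nonce proceeds as the multistream-select initiator, the peer with the smaller nonce becomes the responder, and equal nonces trigger a fresh exchange. The self-contained sketch below restates that rule outside the state machine; `TieBreak` and `local_role_for` are illustrative names only and are not part of this patch.

use std::cmp::Ordering;

/// Outcome of comparing the local and the remote simultaneous-open nonce.
#[derive(Debug, PartialEq)]
enum TieBreak {
    /// The local peer won and acts as the multistream-select initiator.
    Initiator,
    /// The local peer lost and answers the negotiation as the responder.
    Responder,
    /// Both peers drew the same nonce; a new nonce must be exchanged.
    Retry,
}

/// Mirrors the `Ordering`-based selection in `SimOpenState::ReadNonce`.
fn local_role_for(local_nonce: u64, remote_nonce: u64) -> TieBreak {
    match local_nonce.cmp(&remote_nonce) {
        Ordering::Greater => TieBreak::Initiator,
        Ordering::Less => TieBreak::Responder,
        Ordering::Equal => TieBreak::Retry,
    }
}

fn main() {
    assert_eq!(local_role_for(7, 3), TieBreak::Initiator);
    assert_eq!(local_role_for(3, 7), TieBreak::Responder);
    assert_eq!(local_role_for(5, 5), TieBreak::Retry);
}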
diff --git a/misc/multistream-select/src/lib.rs b/misc/multistream-select/src/lib.rs index 087b2a2cb21..29bc37914b0 100644 --- a/misc/multistream-select/src/lib.rs +++ b/misc/multistream-select/src/lib.rs @@ -137,6 +137,7 @@ pub enum Version { /// [1]: https://github.com/multiformats/go-multistream/issues/20 /// [2]: https://github.com/libp2p/rust-libp2p/pull/1212 V1Lazy, + V1SimOpen, // Draft: https://github.com/libp2p/specs/pull/95 // V2, } diff --git a/misc/multistream-select/src/listener_select.rs b/misc/multistream-select/src/listener_select.rs index 70463fa1cfe..0fd30f22417 100644 --- a/misc/multistream-select/src/listener_select.rs +++ b/misc/multistream-select/src/listener_select.rs @@ -39,6 +39,35 @@ pub fn listener_select_proto( inner: R, protocols: I, ) -> ListenerSelectFuture +where + R: AsyncRead + AsyncWrite, + I: IntoIterator, + I::Item: AsRef<[u8]> +{ + listener_select_proto_with_state(State::RecvHeader { + io: MessageIO::new(inner) + }, protocols) +} + +pub(crate) fn listener_select_proto_no_header( + io: MessageIO, + protocols: I, +) -> ListenerSelectFuture +where + R: AsyncRead + AsyncWrite, + I: IntoIterator, + I::Item: AsRef<[u8]> +{ + listener_select_proto_with_state( + State::RecvMessage { io }, + protocols, + ) +} + +fn listener_select_proto_with_state( + state: State, + protocols: I, +) -> ListenerSelectFuture where R: AsyncRead + AsyncWrite, I: IntoIterator, @@ -55,9 +84,7 @@ where }); ListenerSelectFuture { protocols: SmallVec::from_iter(protocols), - state: State::RecvHeader { - io: MessageIO::new(inner) - }, + state, last_sent_na: false, } } diff --git a/misc/multistream-select/src/protocol.rs b/misc/multistream-select/src/protocol.rs index 1d056de75ec..f9a9f4b49c1 100644 --- a/misc/multistream-select/src/protocol.rs +++ b/misc/multistream-select/src/protocol.rs @@ -30,7 +30,7 @@ use crate::length_delimited::{LengthDelimited, LengthDelimitedReader}; use bytes::{Bytes, BytesMut, BufMut}; use futures::{prelude::*, io::IoSlice, ready}; -use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, task::{Context, Poll}}; +use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, str::FromStr, task::{Context, Poll}}; use unsigned_varint as uvi; /// The maximum number of supported protocols that can be processed. @@ -42,6 +42,19 @@ const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n"; const MSG_PROTOCOL_NA: &[u8] = b"na\n"; /// The encoded form of a multistream-select 'ls' message. const MSG_LS: &[u8] = b"ls\n"; +/// The encoded form of a 'select:' message of the multistream-select +/// simultaneous open protocol extension. +const MSG_SELECT: &[u8] = b"select:"; +/// The encoded form of a 'initiator' message of the multistream-select +/// simultaneous open protocol extension. +const MSG_INITIATOR: &[u8] = b"initiator\n"; +/// The encoded form of a 'responder' message of the multistream-select +/// simultaneous open protocol extension. +const MSG_RESPONDER: &[u8] = b"responder\n"; + +/// The identifier of the multistream-select simultaneous open protocol +/// extension. +pub(crate) const SIM_OPEN_ID: Protocol = Protocol(Bytes::from_static(b"/libp2p/simultaneous-connect")); /// The multistream-select header lines preceeding negotiation. 
/// @@ -55,7 +68,7 @@ pub enum HeaderLine { impl From for HeaderLine { fn from(v: Version) -> HeaderLine { match v { - Version::V1 | Version::V1Lazy => HeaderLine::V1, + Version::V1 | Version::V1Lazy | Version::V1SimOpen => HeaderLine::V1, } } } @@ -113,6 +126,9 @@ pub enum Message { Protocols(Vec), /// A message signaling that a requested protocol is not available. NotAvailable, + Select(u64), + Initiator, + Responder, } impl Message { @@ -154,6 +170,22 @@ impl Message { dest.put(MSG_PROTOCOL_NA); Ok(()) } + Message::Select(nonce) => { + dest.put(MSG_SELECT); + dest.put(nonce.to_string().as_ref()); + dest.put_u8(b'\n'); + Ok(()) + } + Message::Initiator => { + dest.reserve(MSG_INITIATOR.len()); + dest.put(MSG_INITIATOR); + Ok(()) + } + Message::Responder => { + dest.reserve(MSG_RESPONDER.len()); + dest.put(MSG_RESPONDER); + Ok(()) + } } } @@ -171,6 +203,26 @@ impl Message { return Ok(Message::ListProtocols) } + if msg.len() > MSG_SELECT.len() + 1 /* \n */ + && msg[.. MSG_SELECT.len()] == *MSG_SELECT + && msg.last() == Some(&b'\n') + { + if let Some(nonce) = std::str::from_utf8(&msg[MSG_SELECT.len() .. msg.len() -1]) + .ok() + .and_then(|s| u64::from_str(s).ok()) + { + return Ok(Message::Select(nonce)) + } + } + + if msg == MSG_INITIATOR { + return Ok(Message::Initiator) + } + + if msg == MSG_RESPONDER { + return Ok(Message::Responder) + } + // If it starts with a `/`, ends with a line feed without any // other line feeds in-between, it must be a protocol name. if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') && @@ -238,7 +290,7 @@ impl MessageIO { MessageReader { inner: self.inner.into_reader() } } - /// Drops the [`MessageIO`] resource, yielding the underlying I/O stream. + /// Draops the [`MessageIO`] resource, yielding the underlying I/O stream. 
/// /// # Panics /// diff --git a/misc/multistream-select/src/tests.rs b/misc/multistream-select/src/tests.rs index f03d1b1ff75..956c30aa5da 100644 --- a/misc/multistream-select/src/tests.rs +++ b/misc/multistream-select/src/tests.rs @@ -35,22 +35,21 @@ fn select_proto_basic() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener_addr = listener.local_addr().unwrap(); - let server = async_std::task::spawn(async move { + let server = async move { let connec = listener.accept().await.unwrap().0; let protos = vec![b"/proto1", b"/proto2"]; let (proto, mut io) = listener_select_proto(connec, protos).await.unwrap(); assert_eq!(proto, b"/proto2"); - let mut out = vec![0; 32]; - let n = io.read(&mut out).await.unwrap(); - out.truncate(n); + let mut out = vec![0; 4]; + io.read_exact(&mut out).await.unwrap(); assert_eq!(out, b"ping"); io.write_all(b"pong").await.unwrap(); io.flush().await.unwrap(); - }); + }; - let client = async_std::task::spawn(async move { + let client = async move { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let protos = vec![b"/proto3", b"/proto2"]; let (proto, mut io) = dialer_select_proto(connec, protos.into_iter(), version) @@ -60,18 +59,17 @@ fn select_proto_basic() { io.write_all(b"ping").await.unwrap(); io.flush().await.unwrap(); - let mut out = vec![0; 32]; - let n = io.read(&mut out).await.unwrap(); - out.truncate(n); + let mut out = vec![0; 4]; + io.read_exact(&mut out).await.unwrap(); assert_eq!(out, b"pong"); - }); + }; - server.await; - client.await; + futures::future::join(server, client).await; } async_std::task::block_on(run(Version::V1)); async_std::task::block_on(run(Version::V1Lazy)); + async_std::task::block_on(run(Version::V1SimOpen)); } /// Tests the expected behaviour of failed negotiations. 
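As a companion to the protocol.rs hunk above and the tests that follow: on the wire the new simultaneous-open messages are length-prefixed frames like every other multistream-select message, with the nonce of a `Message::Select` carried as a decimal string between the literal prefix `select:` and a trailing newline, and the role confirmations encoded as `initiator\n` / `responder\n`. The round-trip below is a free-standing sketch of that framing only; `encode_select` and `decode_select` are illustrative helpers and not part of this patch.

/// Encodes a nonce the way `Message::Select` does: the literal prefix
/// "select:", the nonce in decimal, and a trailing '\n'.
fn encode_select(nonce: u64) -> Vec<u8> {
    let mut out = b"select:".to_vec();
    out.extend_from_slice(nonce.to_string().as_bytes());
    out.push(b'\n');
    out
}

/// Parses a "select:<nonce>\n" line back into its nonce, if well-formed.
fn decode_select(msg: &[u8]) -> Option<u64> {
    let rest = msg.strip_prefix(b"select:")?;
    let rest = rest.strip_suffix(b"\n")?;
    std::str::from_utf8(rest).ok()?.parse().ok()
}

fn main() {
    let bytes = encode_select(42);
    assert_eq!(bytes, b"select:42\n".to_vec());
    assert_eq!(decode_select(&bytes), Some(42));
    assert_eq!(decode_select(b"select:not-a-number\n"), None);
}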
@@ -165,7 +163,7 @@ fn negotiation_failed() { for (listen_protos, dial_protos) in protos { for dial_payload in payloads.clone() { - for &version in &[Version::V1, Version::V1Lazy] { + for &version in &[Version::V1, Version::V1Lazy, Version::V1SimOpen] { async_std::task::block_on(run(Test { version, listen_protos: listen_protos.clone(), @@ -237,4 +235,34 @@ fn select_proto_serial() { async_std::task::block_on(run(Version::V1)); async_std::task::block_on(run(Version::V1Lazy)); + async_std::task::block_on(run(Version::V1SimOpen)); +} + +#[test] +fn simultaneous_open() { + async fn run(version: Version) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener_addr = listener.local_addr().unwrap(); + + let server = async move { + let connec = listener.accept().await.unwrap().0; + let protos = vec![b"/proto1", b"/proto2"]; + let (proto, io) = dialer_select_proto_serial(connec, protos, version).await.unwrap(); + assert_eq!(proto, b"/proto2"); + io.complete().await.unwrap(); + }; + + let client = async move { + let connec = TcpStream::connect(&listener_addr).await.unwrap(); + let protos = vec![b"/proto3", b"/proto2"]; + let (proto, io) = dialer_select_proto_serial(connec, protos.into_iter(), version) + .await.unwrap(); + assert_eq!(proto, b"/proto2"); + io.complete().await.unwrap(); + }; + + futures::future::join(server, client).await; + } + + futures::executor::block_on(run(Version::V1SimOpen)); } From 2cfaec991746f190587b8735a9839e1b98bbc203 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 30 Jun 2021 16:34:15 +0200 Subject: [PATCH 02/23] core/: Integrate Simultaneous Open extension --- core/src/transport.rs | 6 +- core/src/transport/upgrade.rs | 124 ++++++++--------- core/src/upgrade.rs | 5 +- core/src/upgrade/apply.rs | 138 ++++++++++++++++++- core/tests/transport_upgrade.rs | 5 +- core/tests/util.rs | 2 +- examples/chat-tokio.rs | 2 +- examples/ipfs-private.rs | 2 +- misc/multistream-select/src/dialer_select.rs | 16 ++- misc/multistream-select/src/lib.rs | 4 +- misc/multistream-select/src/tests.rs | 12 +- misc/multistream-select/tests/transport.rs | 5 +- muxers/mplex/benches/split_send_size.rs | 4 +- protocols/gossipsub/src/lib.rs | 2 +- protocols/gossipsub/tests/smoke.rs | 2 +- protocols/identify/src/identify.rs | 2 +- protocols/kad/src/behaviour/test.rs | 2 +- protocols/ping/tests/ping.rs | 2 +- protocols/relay/examples/relay.rs | 2 +- protocols/relay/src/lib.rs | 2 +- protocols/relay/tests/lib.rs | 6 +- protocols/request-response/tests/ping.rs | 2 +- src/lib.rs | 2 +- swarm/src/lib.rs | 2 +- transports/noise/src/lib.rs | 2 +- transports/noise/tests/smoke.rs | 2 +- 26 files changed, 243 insertions(+), 112 deletions(-) diff --git a/core/src/transport.rs b/core/src/transport.rs index f6e70c44628..94ca830e869 100644 --- a/core/src/transport.rs +++ b/core/src/transport.rs @@ -196,12 +196,14 @@ pub trait Transport { /// Begins a series of protocol upgrades via an /// [`upgrade::Builder`](upgrade::Builder). - fn upgrade(self, version: upgrade::Version) -> upgrade::Builder + // + // TODO: Method still needed now that `upgrade` takes `self` only? + fn upgrade(self) -> upgrade::Builder where Self: Sized, Self::Error: 'static { - upgrade::Builder::new(self, version) + upgrade::Builder::new(self) } } diff --git a/core/src/transport/upgrade.rs b/core/src/transport/upgrade.rs index b2cb7b46804..57ec27a5fdb 100644 --- a/core/src/transport/upgrade.rs +++ b/core/src/transport/upgrade.rs @@ -20,7 +20,7 @@ //! Configuration of transport protocol upgrades. 
-pub use crate::upgrade::Version; +pub use crate::upgrade::{Version, SimOpenRole}; use crate::{ ConnectedPoint, @@ -36,6 +36,7 @@ use crate::{ muxing::{StreamMuxer, StreamMuxerBox}, upgrade::{ self, + AuthenticationUpgradeApply, OutboundUpgrade, InboundUpgrade, apply_inbound, @@ -46,7 +47,7 @@ use crate::{ }, PeerId }; -use futures::{prelude::*, ready}; +use futures::{prelude::*, ready, future::Either}; use multiaddr::Multiaddr; use std::{ error::Error, @@ -80,7 +81,6 @@ use std::{ #[derive(Clone)] pub struct Builder { inner: T, - version: upgrade::Version, } impl Builder @@ -89,8 +89,8 @@ where T::Error: 'static, { /// Creates a `Builder` over the given (base) `Transport`. - pub fn new(inner: T, version: upgrade::Version) -> Builder { - Builder { inner, version } + pub fn new(inner: T) -> Builder { + Builder { inner } } /// Upgrades the transport to perform authentication of the remote. @@ -115,12 +115,26 @@ where U: OutboundUpgrade, Output = (PeerId, D), Error = E> + Clone, E: Error + 'static, { - let version = self.version; + self.authenticate_with_version(upgrade, upgrade::Version::default()) + } + + /// Same as [`Builder::authenticate`] with the option to choose the [`upgrade::Version`] used to + /// upgrade the connection. + pub fn authenticate_with_version(self, upgrade: U, version: upgrade::Version) -> Authenticated< + AndThen Authenticate + Clone> + > where + T: Transport, + C: AsyncRead + AsyncWrite + Unpin, + D: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade, Output = (PeerId, D), Error = E>, + U: OutboundUpgrade, Output = (PeerId, D), Error = E> + Clone, + E: Error + 'static, + { Authenticated(Builder::new(self.inner.and_then(move |conn, endpoint| { Authenticate { - inner: upgrade::apply(conn, upgrade, endpoint, version) + inner: upgrade::apply_authentication(conn, upgrade, endpoint, version) } - }), version)) + }))) } } @@ -135,7 +149,7 @@ where U: InboundUpgrade> + OutboundUpgrade> { #[pin] - inner: EitherUpgrade + inner: AuthenticationUpgradeApply } impl Future for Authenticate @@ -146,7 +160,7 @@ where Error = >>::Error > { - type Output = as Future>::Output; + type Output = as Future>::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -207,16 +221,18 @@ where /// /// * I/O upgrade: `C -> D`. /// * Transport output: `(PeerId, C) -> (PeerId, D)`. + // + // TODO: Do we need an `apply` with a version? pub fn apply(self, upgrade: U) -> Authenticated> where - T: Transport, + T: Transport, C: AsyncRead + AsyncWrite + Unpin, D: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = D, Error = E>, U: OutboundUpgrade, Output = D, Error = E> + Clone, E: Error + 'static, { - Authenticated(Builder::new(Upgrade::new(self.0.inner, upgrade), self.0.version)) + Authenticated(Builder::new(Upgrade::new(self.0.inner, upgrade))) } /// Upgrades the transport with a (sub)stream multiplexer. @@ -230,50 +246,32 @@ where /// * I/O upgrade: `C -> M`. /// * Transport output: `(PeerId, C) -> (PeerId, M)`. 
pub fn multiplex(self, upgrade: U) -> Multiplexed< - AndThen Multiplex + Clone> + AndThen Multiplex + Clone> > where - T: Transport, + T: Transport, C: AsyncRead + AsyncWrite + Unpin, M: StreamMuxer, U: InboundUpgrade, Output = M, Error = E>, U: OutboundUpgrade, Output = M, Error = E> + Clone, E: Error + 'static, { - let version = self.0.version; - Multiplexed(self.0.inner.and_then(move |(i, c), endpoint| { - let upgrade = upgrade::apply(c, upgrade, endpoint, version); + Multiplexed(self.0.inner.and_then(move |((i, c), r), endpoint| { + let upgrade = match r { + SimOpenRole::Initiator => { + // TODO: Offer version that allows choosing the Version. + Either::Left(upgrade::apply_outbound(c, upgrade, upgrade::Version::default())) + }, + SimOpenRole::Responder => { + Either::Right(upgrade::apply_inbound(c, upgrade)) + + } + }; Multiplex { peer_id: Some(i), upgrade } })) } - /// Like [`Authenticated::multiplex`] but accepts a function which returns the upgrade. - /// - /// The supplied function is applied to [`PeerId`] and [`ConnectedPoint`] - /// and returns an upgrade which receives the I/O resource `C` and must - /// produce a [`StreamMuxer`] `M`. The transport must already be authenticated. - /// This ends the (regular) transport upgrade process. - /// - /// ## Transitions - /// - /// * I/O upgrade: `C -> M`. - /// * Transport output: `(PeerId, C) -> (PeerId, M)`. - pub fn multiplex_ext(self, up: F) -> Multiplexed< - AndThen Multiplex + Clone> - > where - T: Transport, - C: AsyncRead + AsyncWrite + Unpin, - M: StreamMuxer, - U: InboundUpgrade, Output = M, Error = E>, - U: OutboundUpgrade, Output = M, Error = E> + Clone, - E: Error + 'static, - F: for<'a> FnOnce(&'a PeerId, &'a ConnectedPoint) -> U + Clone - { - let version = self.0.version; - Multiplexed(self.0.inner.and_then(move |(peer_id, c), endpoint| { - let upgrade = upgrade::apply(c, up(&peer_id, &endpoint), endpoint, version); - Multiplex { peer_id: Some(peer_id), upgrade } - })) - } + + // TODO: Add changelog entry that multiplex_ext is removed. } /// A authenticated and multiplexed transport, obtained from @@ -341,7 +339,7 @@ where } /// An inbound or outbound upgrade. -type EitherUpgrade = future::Either, OutboundUpgradeApply>; +type EitherUpgrade = future::Either, InboundUpgradeApply>; /// A custom upgrade on an [`Authenticated`] transport. 
/// @@ -357,14 +355,14 @@ impl Upgrade { impl Transport for Upgrade where - T: Transport, + T: Transport, T::Error: 'static, C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = D, Error = E>, U: OutboundUpgrade, Output = D, Error = E> + Clone, E: Error + 'static { - type Output = (PeerId, D); + type Output = ((PeerId, D), SimOpenRole); type Error = TransportUpgradeError; type Listener = ListenerStream; type ListenerUpgrade = ListenerUpgradeFuture; @@ -435,17 +433,17 @@ where C: AsyncRead + AsyncWrite + Unpin, { future: Pin>, - upgrade: future::Either, (Option, OutboundUpgradeApply)> + upgrade: future::Either, (Option<(PeerId, SimOpenRole)>, OutboundUpgradeApply)> } impl Future for DialUpgradeFuture where - F: TryFuture, + F: TryFuture, C: AsyncRead + AsyncWrite + Unpin, U: OutboundUpgrade, Output = D>, U::Error: Error { - type Output = Result<(PeerId, D), TransportUpgradeError>; + type Output = Result<((PeerId, D), SimOpenRole), TransportUpgradeError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // We use a `this` variable because the compiler can't mutably borrow multiple times @@ -455,20 +453,20 @@ where loop { this.upgrade = match this.upgrade { future::Either::Left(ref mut up) => { - let (i, c) = match ready!(TryFuture::try_poll(this.future.as_mut(), cx).map_err(TransportUpgradeError::Transport)) { + let ((i, c), r) = match ready!(TryFuture::try_poll(this.future.as_mut(), cx).map_err(TransportUpgradeError::Transport)) { Ok(v) => v, Err(err) => return Poll::Ready(Err(err)), }; let u = up.take().expect("DialUpgradeFuture is constructed with Either::Left(Some)."); - future::Either::Right((Some(i), apply_outbound(c, u, upgrade::Version::V1))) + future::Either::Right((Some((i, r)), apply_outbound(c, u, upgrade::Version::V1))) } future::Either::Right((ref mut i, ref mut up)) => { let d = match ready!(Future::poll(Pin::new(up), cx).map_err(TransportUpgradeError::Upgrade)) { Ok(d) => d, Err(err) => return Poll::Ready(Err(err)), }; - let i = i.take().expect("DialUpgradeFuture polled after completion."); - return Poll::Ready(Ok((i, d))) + let (i, r) = i.take().expect("DialUpgradeFuture polled after completion."); + return Poll::Ready(Ok(((i, d), r))) } } } @@ -491,7 +489,7 @@ pub struct ListenerStream { impl Stream for ListenerStream where S: TryStream, Error = E>, - F: TryFuture, + F: TryFuture, C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = D> + Clone { @@ -528,17 +526,17 @@ where U: InboundUpgrade> { future: Pin>, - upgrade: future::Either, (Option, InboundUpgradeApply)> + upgrade: future::Either, (Option<(PeerId, SimOpenRole)>, InboundUpgradeApply)> } impl Future for ListenerUpgradeFuture where - F: TryFuture, + F: TryFuture, C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = D>, U::Error: Error { - type Output = Result<(PeerId, D), TransportUpgradeError>; + type Output = Result<((PeerId, D), SimOpenRole), TransportUpgradeError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // We use a `this` variable because the compiler can't mutably borrow multiple times @@ -548,20 +546,20 @@ where loop { this.upgrade = match this.upgrade { future::Either::Left(ref mut up) => { - let (i, c) = match ready!(TryFuture::try_poll(this.future.as_mut(), cx).map_err(TransportUpgradeError::Transport)) { + let ((i, c), r) = match ready!(TryFuture::try_poll(this.future.as_mut(), cx).map_err(TransportUpgradeError::Transport)) { Ok(v) => v, Err(err) => return Poll::Ready(Err(err)) }; let u = up.take().expect("ListenerUpgradeFuture is 
constructed with Either::Left(Some)."); - future::Either::Right((Some(i), apply_inbound(c, u))) + future::Either::Right((Some((i, r)), apply_inbound(c, u))) } future::Either::Right((ref mut i, ref mut up)) => { let d = match ready!(TryFuture::try_poll(Pin::new(up), cx).map_err(TransportUpgradeError::Upgrade)) { Ok(v) => v, Err(err) => return Poll::Ready(Err(err)) }; - let i = i.take().expect("ListenerUpgradeFuture polled after completion."); - return Poll::Ready(Ok((i, d))) + let (i, r) = i.take().expect("ListenerUpgradeFuture polled after completion."); + return Poll::Ready(Ok(((i, d), r))) } } } diff --git a/core/src/upgrade.rs b/core/src/upgrade.rs index 9798ae6c27a..460f076a4b4 100644 --- a/core/src/upgrade.rs +++ b/core/src/upgrade.rs @@ -70,9 +70,9 @@ mod transfer; use futures::future::Future; pub use crate::Negotiated; -pub use multistream_select::{Version, NegotiatedComplete, NegotiationError, ProtocolError}; +pub use multistream_select::{Version, NegotiatedComplete, NegotiationError, ProtocolError, SimOpenRole}; pub use self::{ - apply::{apply, apply_inbound, apply_outbound, InboundUpgradeApply, OutboundUpgradeApply}, + apply::{apply, apply_authentication, apply_inbound, apply_outbound, InboundUpgradeApply, OutboundUpgradeApply, AuthenticationUpgradeApply}, denied::DeniedUpgrade, either::EitherUpgrade, error::UpgradeError, @@ -221,4 +221,3 @@ pub trait OutboundUpgradeExt: OutboundUpgrade { } impl> OutboundUpgradeExt for U {} - diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index eaf25e884b3..92712da9303 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -20,12 +20,12 @@ use crate::{ConnectedPoint, Negotiated}; use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError, ProtocolName}; -use futures::{future::Either, prelude::*}; +use futures::{future::{Either, TryFutureExt, MapOk}, prelude::*}; use log::debug; use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture}; use std::{iter, mem, pin::Pin, task::Context, task::Poll}; -pub use multistream_select::Version; +pub use multistream_select::{Version, SimOpenRole, NegotiationError}; /// Applies an upgrade to the inbound and outbound direction of a connection or substream. pub fn apply(conn: C, up: U, cp: ConnectedPoint, v: Version) @@ -41,6 +41,136 @@ where } } + +/// Applies an authentication upgrade to the inbound or outbound direction of a connection or substream. +// +// TODO: This is specific to authentication upgrades, given that it can handle simultaneous open. +// Should this be moved to transport.rs? +pub fn apply_authentication(conn: C, up: U, cp: ConnectedPoint, v: Version) + -> AuthenticationUpgradeApply +where + C: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade> + OutboundUpgrade>, +{ + let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); + + AuthenticationUpgradeApply { + inner: AuthenticationUpgradeApplyState::Init{ + future: match cp { + ConnectedPoint::Dialer { .. } => Either::Left( + multistream_select::dialer_select_proto(conn, iter, v), + ), + ConnectedPoint::Listener { .. } => Either::Right( + multistream_select::listener_select_proto(conn, iter) + .map_ok(add_responder as fn (_) -> _), + ), + }, + upgrade: up, + }, + } +} + +// TODO: This is a hack to get a fn pointer. Can we do better? 
+fn add_responder(input: (P, C)) -> (P, C, SimOpenRole) { + (input.0, input.1, SimOpenRole::Responder) +} + +pub struct AuthenticationUpgradeApply +where + C: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade> + OutboundUpgrade>, +{ + inner: AuthenticationUpgradeApplyState +} + +impl Unpin for AuthenticationUpgradeApply +where + C: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade> + OutboundUpgrade>, +{ +} + +enum AuthenticationUpgradeApplyState +where + C: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade> + OutboundUpgrade>, +{ + Init { + future: Either< + multistream_select::DialerSelectFuture::IntoIter>>, + MapOk< + ListenerSelectFuture>, + fn((NameWrap, Negotiated)) -> (NameWrap, Negotiated, SimOpenRole) + >, + >, + upgrade: U, + }, + Upgrade { + role: SimOpenRole, + future: Either< + Pin>>::Future>>, + Pin>>::Future>>, + >, + }, + Undefined +} + +impl Future for AuthenticationUpgradeApply +where + C: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade> + OutboundUpgrade, + Output = >>::Output, + Error = >>::Error + > +{ + type Output = Result< + (>>::Output, SimOpenRole), + UpgradeError<>>::Error>, + >; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + match mem::replace(&mut self.inner, AuthenticationUpgradeApplyState::Undefined) { + AuthenticationUpgradeApplyState::Init { mut future, upgrade } => { + let (info, io, role) = match Future::poll(Pin::new(&mut future), cx)? { + Poll::Ready(x) => x, + Poll::Pending => { + self.inner = AuthenticationUpgradeApplyState::Init { future, upgrade }; + return Poll::Pending + } + }; + let fut = match role { + SimOpenRole::Initiator => Either::Left(Box::pin(upgrade.upgrade_outbound(io, info.0))), + SimOpenRole::Responder => Either::Right(Box::pin(upgrade.upgrade_inbound(io, info.0))), + }; + self.inner = AuthenticationUpgradeApplyState::Upgrade { + future: fut, + role, + }; + } + AuthenticationUpgradeApplyState::Upgrade { mut future, role } => { + match Future::poll(Pin::new(&mut future), cx) { + Poll::Pending => { + self.inner = AuthenticationUpgradeApplyState::Upgrade { future, role }; + return Poll::Pending + } + Poll::Ready(Ok(x)) => { + debug!("Successfully applied negotiated protocol"); + return Poll::Ready(Ok((x, role))) + } + Poll::Ready(Err(e)) => { + debug!("Failed to apply negotiated protocol"); + return Poll::Ready(Err(UpgradeError::Apply(e))) + } + } + } + AuthenticationUpgradeApplyState::Undefined => + panic!("AuthenticationUpgradeApplyState::poll called after completion") + } + } + } +} + /// Tries to perform an upgrade on an inbound connection or substream. pub fn apply_inbound(conn: C, up: U) -> InboundUpgradeApply where @@ -185,7 +315,8 @@ where loop { match mem::replace(&mut self.inner, OutboundUpgradeApplyState::Undefined) { OutboundUpgradeApplyState::Init { mut future, upgrade } => { - let (info, connection) = match Future::poll(Pin::new(&mut future), cx)? { + // TODO: Don't ignore the SimOpenRole here. Instead add assert!. + let (info, connection, _) = match Future::poll(Pin::new(&mut future), cx)? 
{ Poll::Ready(x) => x, Poll::Pending => { self.inner = OutboundUpgradeApplyState::Init { future, upgrade }; @@ -230,4 +361,3 @@ impl AsRef<[u8]> for NameWrap { self.0.protocol_name() } } - diff --git a/core/tests/transport_upgrade.rs b/core/tests/transport_upgrade.rs index eecace3e46f..162120ac50e 100644 --- a/core/tests/transport_upgrade.rs +++ b/core/tests/transport_upgrade.rs @@ -83,7 +83,7 @@ fn upgrade_pipeline() { let listener_id = listener_keys.public().into_peer_id(); let listener_noise_keys = noise::Keypair::::new().into_authentic(&listener_keys).unwrap(); let listener_transport = MemoryTransport::default() - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(noise::NoiseConfig::xx(listener_noise_keys).into_authenticated()) .apply(HelloUpgrade {}) .apply(HelloUpgrade {}) @@ -99,7 +99,7 @@ fn upgrade_pipeline() { let dialer_id = dialer_keys.public().into_peer_id(); let dialer_noise_keys = noise::Keypair::::new().into_authentic(&dialer_keys).unwrap(); let dialer_transport = MemoryTransport::default() - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(noise::NoiseConfig::xx(dialer_noise_keys).into_authenticated()) .apply(HelloUpgrade {}) .apply(HelloUpgrade {}) @@ -136,4 +136,3 @@ fn upgrade_pipeline() { async_std::task::spawn(server); async_std::task::block_on(client); } - diff --git a/core/tests/util.rs b/core/tests/util.rs index c20a2c59305..0eff3e270bc 100644 --- a/core/tests/util.rs +++ b/core/tests/util.rs @@ -32,7 +32,7 @@ pub fn test_network(cfg: NetworkConfig) -> TestNetwork { let local_public_key = local_key.public(); let noise_keys = noise::Keypair::::new().into_authentic(&local_key).unwrap(); let transport: TestTransport = tcp::TcpConfig::new() - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(mplex::MplexConfig::new()) .boxed(); diff --git a/examples/chat-tokio.rs b/examples/chat-tokio.rs index 6fc28c198dd..27cc8214ff6 100644 --- a/examples/chat-tokio.rs +++ b/examples/chat-tokio.rs @@ -73,7 +73,7 @@ async fn main() -> Result<(), Box> { // Create a tokio-based TCP transport use noise for authenticated // encryption and Mplex for multiplexing of substreams on a TCP stream. let transport = TokioTcpConfig::new().nodelay(true) - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(mplex::MplexConfig::new()) .boxed(); diff --git a/examples/ipfs-private.rs b/examples/ipfs-private.rs index ce0d875f336..be9dd76efab 100644 --- a/examples/ipfs-private.rs +++ b/examples/ipfs-private.rs @@ -78,7 +78,7 @@ pub fn build_transport( None => EitherTransport::Right(base_transport), }; maybe_encrypted - .upgrade(Version::V1) + .upgrade() .authenticate(noise_config) .multiplex(yamux_config) .timeout(Duration::from_secs(20)) diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index b8b95ce5e02..b2ad163fab2 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -169,7 +169,7 @@ where // TODO: Clone needed to embed ListenerSelectFuture. Still needed? 
I::Item: AsRef<[u8]> + Clone { - type Output = Result<(I::Item, Negotiated), NegotiationError>; + type Output = Result<(I::Item, Negotiated, SimOpenRole), NegotiationError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -301,7 +301,7 @@ where SeqState::Responder { mut responder } => { match Pin::new(&mut responder ).poll(cx) { - Poll::Ready(res) => return Poll::Ready(res), + Poll::Ready(res) => return Poll::Ready(res.map(|(p, io)| (p, io, SimOpenRole::Responder))), Poll::Pending => { *this.state = SeqState::Responder { responder }; return Poll::Pending @@ -337,7 +337,7 @@ where log::debug!("Dialer: Expecting proposed protocol: {}", p); let hl = HeaderLine::from(Version::V1Lazy); let io = Negotiated::expecting(io.into_reader(), p, Some(hl)); - return Poll::Ready(Ok((protocol, io))) + return Poll::Ready(Ok((protocol, io, SimOpenRole::Initiator))) } } } @@ -382,7 +382,7 @@ where Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => { log::debug!("Dialer: Received confirmation for protocol: {}", p); let io = Negotiated::completed(io.into_inner()); - return Poll::Ready(Ok((protocol, io))); + return Poll::Ready(Ok((protocol, io, SimOpenRole::Initiator))); } Message::NotAvailable => { log::debug!("Dialer: Received rejection of protocol: {}", @@ -414,7 +414,8 @@ enum SimOpenState { Done, } -enum SimOpenRole { +// TODO: Rename this to `Role` in general? +pub enum SimOpenRole { Initiator, Responder, } @@ -589,7 +590,8 @@ where I: Iterator, I::Item: AsRef<[u8]> { - type Output = Result<(I::Item, Negotiated), NegotiationError>; + // TODO: Is it a hack that DialerSelectPar returns the simopenrole? + type Output = Result<(I::Item, Negotiated, SimOpenRole), NegotiationError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -687,7 +689,7 @@ where log::debug!("Dialer: Expecting proposed protocol: {}", p); let io = Negotiated::expecting(io.into_reader(), p, None); - return Poll::Ready(Ok((protocol, io))) + return Poll::Ready(Ok((protocol, io, SimOpenRole::Initiator))) } ParState::Done => panic!("ParState::poll called after completion") diff --git a/misc/multistream-select/src/lib.rs b/misc/multistream-select/src/lib.rs index 29bc37914b0..4f073f8bb43 100644 --- a/misc/multistream-select/src/lib.rs +++ b/misc/multistream-select/src/lib.rs @@ -79,7 +79,7 @@ //! let socket = TcpStream::connect("127.0.0.1:10333").await.unwrap(); //! //! let protos = vec![b"/echo/1.0.0", b"/echo/2.5.0"]; -//! let (protocol, _io) = dialer_select_proto(socket, protos, Version::V1).await.unwrap(); +//! let (protocol, _io, _role) = dialer_select_proto(socket, protos, Version::V1).await.unwrap(); //! //! println!("Negotiated protocol: {:?}", protocol); //! // You can now use `_io` to communicate with the remote. @@ -96,7 +96,7 @@ mod tests; pub use self::negotiated::{Negotiated, NegotiatedComplete, NegotiationError}; pub use self::protocol::ProtocolError; -pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture}; +pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture, SimOpenRole}; pub use self::listener_select::{listener_select_proto, ListenerSelectFuture}; /// Supported multistream-select versions. 
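The crate-level example above already reflects the new three-element return value. Sketched below, under the assumption that a caller negotiates with `Version::V1SimOpen` and uses the `SimOpenRole` re-export added in this patch, is how the returned role might be consumed; the `negotiate` function and its address parameter are hypothetical.

use async_std::net::TcpStream;
use multistream_select::{dialer_select_proto, SimOpenRole, Version};

async fn negotiate(addr: &str) -> Result<(), Box<dyn std::error::Error>> {
    let socket = TcpStream::connect(addr).await?;

    // Both peers may dial at the same time; the returned role reports which
    // side of the negotiation this peer ended up driving.
    let protos = vec![b"/echo/1.0.0"];
    let (protocol, _io, role) =
        dialer_select_proto(socket, protos, Version::V1SimOpen).await?;

    match role {
        // This peer won the nonce tie-break (or the remote was a plain
        // listener) and acted as the multistream-select initiator.
        SimOpenRole::Initiator => println!("initiated {:?}", protocol),
        // The remote also dialed and won the tie-break, so this peer answered
        // the negotiation as a responder.
        SimOpenRole::Responder => println!("responded with {:?}", protocol),
    }
    Ok(())
}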
diff --git a/misc/multistream-select/src/tests.rs b/misc/multistream-select/src/tests.rs index 956c30aa5da..7952af53333 100644 --- a/misc/multistream-select/src/tests.rs +++ b/misc/multistream-select/src/tests.rs @@ -52,7 +52,7 @@ fn select_proto_basic() { let client = async move { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let protos = vec![b"/proto3", b"/proto2"]; - let (proto, mut io) = dialer_select_proto(connec, protos.into_iter(), version) + let (proto, mut io, _) = dialer_select_proto(connec, protos.into_iter(), version) .await.unwrap(); assert_eq!(proto, b"/proto2"); @@ -103,7 +103,7 @@ fn negotiation_failed() { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let mut io = match dialer_select_proto(connec, dial_protos.into_iter(), version).await { Err(NegotiationError::Failed) => return, - Ok((_, io)) => io, + Ok((_, io, _)) => io, Err(_) => panic!() }; // The dialer may write a payload that is even sent before it @@ -192,7 +192,7 @@ fn select_proto_parallel() { let client = async_std::task::spawn(async move { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let protos = vec![b"/proto3", b"/proto2"]; - let (proto, io) = dialer_select_proto_parallel(connec, protos.into_iter(), version) + let (proto, io, _) = dialer_select_proto_parallel(connec, protos.into_iter(), version) .await.unwrap(); assert_eq!(proto, b"/proto2"); io.complete().await.unwrap(); @@ -223,7 +223,7 @@ fn select_proto_serial() { let client = async_std::task::spawn(async move { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let protos = vec![b"/proto3", b"/proto2"]; - let (proto, io) = dialer_select_proto_serial(connec, protos.into_iter(), version) + let (proto, io, _) = dialer_select_proto_serial(connec, protos.into_iter(), version) .await.unwrap(); assert_eq!(proto, b"/proto2"); io.complete().await.unwrap(); @@ -247,7 +247,7 @@ fn simultaneous_open() { let server = async move { let connec = listener.accept().await.unwrap().0; let protos = vec![b"/proto1", b"/proto2"]; - let (proto, io) = dialer_select_proto_serial(connec, protos, version).await.unwrap(); + let (proto, io, _) = dialer_select_proto_serial(connec, protos, version).await.unwrap(); assert_eq!(proto, b"/proto2"); io.complete().await.unwrap(); }; @@ -255,7 +255,7 @@ fn simultaneous_open() { let client = async move { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let protos = vec![b"/proto3", b"/proto2"]; - let (proto, io) = dialer_select_proto_serial(connec, protos.into_iter(), version) + let (proto, io, _) = dialer_select_proto_serial(connec, protos.into_iter(), version) .await.unwrap(); assert_eq!(proto, b"/proto2"); io.complete().await.unwrap(); diff --git a/misc/multistream-select/tests/transport.rs b/misc/multistream-select/tests/transport.rs index 63c6ed90f93..7a3b8323679 100644 --- a/misc/multistream-select/tests/transport.rs +++ b/misc/multistream-select/tests/transport.rs @@ -40,11 +40,12 @@ use std::{io, task::{Context, Poll}}; type TestTransport = transport::Boxed<(PeerId, StreamMuxerBox)>; type TestNetwork = Network; -fn mk_transport(up: upgrade::Version) -> (PeerId, TestTransport) { +// TODO: Fix _up +fn mk_transport(_up: upgrade::Version) -> (PeerId, TestTransport) { let keys = identity::Keypair::generate_ed25519(); let id = keys.public().into_peer_id(); (id, MemoryTransport::default() - .upgrade(up) + .upgrade() .authenticate(PlainText2Config { local_public_key: keys.public() }) .multiplex(MplexConfig::default()) .boxed()) diff --git 
a/muxers/mplex/benches/split_send_size.rs b/muxers/mplex/benches/split_send_size.rs index c31703cb927..7613fa8e918 100644 --- a/muxers/mplex/benches/split_send_size.rs +++ b/muxers/mplex/benches/split_send_size.rs @@ -148,7 +148,7 @@ fn tcp_transport(split_send_size: usize) -> BenchTransport { mplex.set_split_send_size(split_send_size); libp2p_tcp::TcpConfig::new().nodelay(true) - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(PlainText2Config { local_public_key }) .multiplex(mplex) .timeout(Duration::from_secs(5)) @@ -163,7 +163,7 @@ fn mem_transport(split_send_size: usize) -> BenchTransport { mplex.set_split_send_size(split_send_size); transport::MemoryTransport::default() - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(PlainText2Config { local_public_key }) .multiplex(mplex) .timeout(Duration::from_secs(5)) diff --git a/protocols/gossipsub/src/lib.rs b/protocols/gossipsub/src/lib.rs index ddba0f69a1e..b757688db1d 100644 --- a/protocols/gossipsub/src/lib.rs +++ b/protocols/gossipsub/src/lib.rs @@ -87,7 +87,7 @@ //! // This is test transport (memory). //! let noise_keys = libp2p_noise::Keypair::::new().into_authentic(&local_key).unwrap(); //! let transport = MemoryTransport::default() -//! .upgrade(libp2p_core::upgrade::Version::V1) +//! .upgrade() //! .authenticate(libp2p_noise::NoiseConfig::xx(noise_keys).into_authenticated()) //! .multiplex(libp2p_mplex::MplexConfig::new()) //! .boxed(); diff --git a/protocols/gossipsub/tests/smoke.rs b/protocols/gossipsub/tests/smoke.rs index cb7c23d1747..aacaf5a5de7 100644 --- a/protocols/gossipsub/tests/smoke.rs +++ b/protocols/gossipsub/tests/smoke.rs @@ -143,7 +143,7 @@ fn build_node() -> (Multiaddr, Swarm) { let public_key = key.public(); let transport = MemoryTransport::default() - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(PlainText2Config { local_public_key: public_key.clone(), }) diff --git a/protocols/identify/src/identify.rs b/protocols/identify/src/identify.rs index 8d8eb044486..95a77653c4c 100644 --- a/protocols/identify/src/identify.rs +++ b/protocols/identify/src/identify.rs @@ -464,7 +464,7 @@ mod tests { let pubkey = id_keys.public(); let transport = TcpConfig::new() .nodelay(true) - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(MplexConfig::new()) .boxed(); diff --git a/protocols/kad/src/behaviour/test.rs b/protocols/kad/src/behaviour/test.rs index 6b4516ad9ee..ed05d6ca0ad 100644 --- a/protocols/kad/src/behaviour/test.rs +++ b/protocols/kad/src/behaviour/test.rs @@ -59,7 +59,7 @@ fn build_node_with_config(cfg: KademliaConfig) -> (Multiaddr, TestSwarm) { let local_public_key = local_key.public(); let noise_keys = noise::Keypair::::new().into_authentic(&local_key).unwrap(); let transport = MemoryTransport::default() - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(yamux::YamuxConfig::default()) .boxed(); diff --git a/protocols/ping/tests/ping.rs b/protocols/ping/tests/ping.rs index ede6d49bfa7..83abb1f07d4 100644 --- a/protocols/ping/tests/ping.rs +++ b/protocols/ping/tests/ping.rs @@ -193,7 +193,7 @@ fn mk_transport(muxer: MuxerChoice) -> ( let noise_keys = noise::Keypair::::new().into_authentic(&id_keys).unwrap(); (peer_id, TcpConfig::new() .nodelay(true) - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(match muxer { MuxerChoice::Yamux => diff --git 
a/protocols/relay/examples/relay.rs b/protocols/relay/examples/relay.rs index 79a3762d691..16c536fdd10 100644 --- a/protocols/relay/examples/relay.rs +++ b/protocols/relay/examples/relay.rs @@ -95,7 +95,7 @@ fn main() -> Result<(), Box> { }; let transport = relay_wrapped_transport - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(plaintext) .multiplex(libp2p_yamux::YamuxConfig::default()) .boxed(); diff --git a/protocols/relay/src/lib.rs b/protocols/relay/src/lib.rs index 7f81860534d..5944798aac1 100644 --- a/protocols/relay/src/lib.rs +++ b/protocols/relay/src/lib.rs @@ -46,7 +46,7 @@ //! ); //! //! let transport = relay_transport -//! .upgrade(upgrade::Version::V1) +//! .upgrade() //! .authenticate(plaintext) //! .multiplex(YamuxConfig::default()) //! .boxed(); diff --git a/protocols/relay/tests/lib.rs b/protocols/relay/tests/lib.rs index 40890fdb024..ca62f4c3903 100644 --- a/protocols/relay/tests/lib.rs +++ b/protocols/relay/tests/lib.rs @@ -1228,7 +1228,7 @@ fn build_swarm(reachability: Reachability, relay_mode: RelayMode) -> Swarm Swarm { libp2p_relay::new_transport_and_behaviour(RelayConfig::default(), transport); let transport = transport - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(plaintext) .multiplex(libp2p_yamux::YamuxConfig::default()) .boxed(); @@ -1288,7 +1288,7 @@ fn build_keep_alive_only_swarm() -> Swarm { let transport = MemoryTransport::default(); let transport = transport - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(plaintext) .multiplex(libp2p_yamux::YamuxConfig::default()) .boxed(); diff --git a/protocols/request-response/tests/ping.rs b/protocols/request-response/tests/ping.rs index 734a6729f9e..5f6d80cb18f 100644 --- a/protocols/request-response/tests/ping.rs +++ b/protocols/request-response/tests/ping.rs @@ -366,7 +366,7 @@ fn mk_transport() -> (PeerId, transport::Boxed<(PeerId, StreamMuxerBox)>) { let noise_keys = Keypair::::new().into_authentic(&id_keys).unwrap(); (peer_id, TcpConfig::new() .nodelay(true) - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(libp2p_yamux::YamuxConfig::default()) .boxed()) diff --git a/src/lib.rs b/src/lib.rs index e675b40e7f0..0795d8f1a80 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -175,7 +175,7 @@ pub async fn development_transport(keypair: identity::Keypair) .expect("Signing libp2p-noise static DH keypair failed."); Ok(transport - .upgrade(core::upgrade::Version::V1) + .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(core::upgrade::SelectUpgrade::new(yamux::YamuxConfig::default(), mplex::MplexConfig::default())) .timeout(std::time::Duration::from_secs(20)) diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index 28ea83ee589..c824a980d7a 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -1155,7 +1155,7 @@ mod tests { let pubkey = id_keys.public(); let noise_keys = noise::Keypair::::new().into_authentic(&id_keys).unwrap(); let transport = transport::MemoryTransport::default() - .upgrade(upgrade::Version::V1) + .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(libp2p_mplex::MplexConfig::new()) .boxed(); diff --git a/transports/noise/src/lib.rs b/transports/noise/src/lib.rs index 6b09f636cd1..5df42564149 100644 --- a/transports/noise/src/lib.rs +++ b/transports/noise/src/lib.rs @@ -47,7 +47,7 @@ //! let id_keys = identity::Keypair::generate_ed25519(); //! let dh_keys = Keypair::::new().into_authentic(&id_keys).unwrap(); //! 
let noise = NoiseConfig::xx(dh_keys).into_authenticated(); -//! let builder = TcpConfig::new().upgrade(upgrade::Version::V1).authenticate(noise); +//! let builder = TcpConfig::new().upgrade().authenticate(noise); //! // let transport = builder.multiplex(...); //! # } //! ``` diff --git a/transports/noise/tests/smoke.rs b/transports/noise/tests/smoke.rs index 4a4c81b5eb8..829ebc6bf08 100644 --- a/transports/noise/tests/smoke.rs +++ b/transports/noise/tests/smoke.rs @@ -36,7 +36,7 @@ fn core_upgrade_compat() { let id_keys = identity::Keypair::generate_ed25519(); let dh_keys = Keypair::::new().into_authentic(&id_keys).unwrap(); let noise = NoiseConfig::xx(dh_keys).into_authenticated(); - let _ = TcpConfig::new().upgrade(upgrade::Version::V1).authenticate(noise); + let _ = TcpConfig::new().upgrade().authenticate(noise); } #[test] From 11542d1902659420e92d691e45e2b7fd2c3d5687 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Sun, 4 Jul 2021 19:39:33 +0200 Subject: [PATCH 03/23] core/src/transport/upgrade: Make DialFuture aware of SimOpenRole --- core/src/transport/upgrade.rs | 36 ++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/core/src/transport/upgrade.rs b/core/src/transport/upgrade.rs index 57ec27a5fdb..ca45468c75d 100644 --- a/core/src/transport/upgrade.rs +++ b/core/src/transport/upgrade.rs @@ -429,21 +429,29 @@ where /// The [`Transport::Dial`] future of an [`Upgrade`]d transport. pub struct DialUpgradeFuture where - U: OutboundUpgrade>, + U: InboundUpgrade> + OutboundUpgrade, + Output = >>::Output, + Error = >>::Error + >, C: AsyncRead + AsyncWrite + Unpin, { future: Pin>, - upgrade: future::Either, (Option<(PeerId, SimOpenRole)>, OutboundUpgradeApply)> + upgrade: future::Either< + Option, + (Option<(PeerId, SimOpenRole)>, Either, InboundUpgradeApply>), + > } -impl Future for DialUpgradeFuture +impl Future for DialUpgradeFuture where F: TryFuture, C: AsyncRead + AsyncWrite + Unpin, - U: OutboundUpgrade, Output = D>, - U::Error: Error + U: InboundUpgrade> + OutboundUpgrade, + Output = >>::Output, + Error = >>::Error + >, { - type Output = Result<((PeerId, D), SimOpenRole), TransportUpgradeError>; + type Output = Result<((PeerId, >>::Output), SimOpenRole), TransportUpgradeError>>::Error>>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // We use a `this` variable because the compiler can't mutably borrow multiple times @@ -457,8 +465,15 @@ where Ok(v) => v, Err(err) => return Poll::Ready(Err(err)), }; - let u = up.take().expect("DialUpgradeFuture is constructed with Either::Left(Some)."); - future::Either::Right((Some((i, r)), apply_outbound(c, u, upgrade::Version::V1))) + let upgrade = up.take().map(|u| match r { + SimOpenRole::Initiator => { + Either::Left(apply_outbound(c, u, upgrade::Version::V1)) + }, + SimOpenRole::Responder => { + Either::Right(apply_inbound(c, u)) + } + }).take().expect("DialUpgradeFuture is constructed with Either::Left(Some)."); + future::Either::Right((Some((i, r)), upgrade)) } future::Either::Right((ref mut i, ref mut up)) => { let d = match ready!(Future::poll(Pin::new(up), cx).map_err(TransportUpgradeError::Upgrade)) { @@ -475,7 +490,10 @@ where impl Unpin for DialUpgradeFuture where - U: OutboundUpgrade>, + U: InboundUpgrade> + OutboundUpgrade, + Output = >>::Output, + Error = >>::Error + >, C: AsyncRead + AsyncWrite + Unpin, { } From 618ccd593d804a6675638692c554fd0ea0552db1 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Sun, 4 Jul 2021 20:01:12 +0200 Subject: [PATCH 04/23] core/src/transport: 
Use Transport::and_then for Authenticated::apply --- core/src/transport/upgrade.rs | 66 +++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/core/src/transport/upgrade.rs b/core/src/transport/upgrade.rs index ca45468c75d..f01de0fa597 100644 --- a/core/src/transport/upgrade.rs +++ b/core/src/transport/upgrade.rs @@ -106,7 +106,7 @@ where /// * I/O upgrade: `C -> (PeerId, D)`. /// * Transport output: `C -> (PeerId, D)` pub fn authenticate(self, upgrade: U) -> Authenticated< - AndThen Authenticate + Clone> + AndThen AuthenticationUpgradeApply + Clone> > where T: Transport, C: AsyncRead + AsyncWrite + Unpin, @@ -121,7 +121,7 @@ where /// Same as [`Builder::authenticate`] with the option to choose the [`upgrade::Version`] used to /// upgrade the connection. pub fn authenticate_with_version(self, upgrade: U, version: upgrade::Version) -> Authenticated< - AndThen Authenticate + Clone> + AndThen AuthenticationUpgradeApply + Clone> > where T: Transport, C: AsyncRead + AsyncWrite + Unpin, @@ -131,40 +131,42 @@ where E: Error + 'static, { Authenticated(Builder::new(self.inner.and_then(move |conn, endpoint| { - Authenticate { - inner: upgrade::apply_authentication(conn, upgrade, endpoint, version) - } + upgrade::apply_authentication(conn, upgrade, endpoint, version) }))) } } -/// An upgrade that authenticates the remote peer, typically -/// in the context of negotiating a secure channel. +/// An upgrade that negotiates a (sub)stream multiplexer on +/// top of an authenticated transport. /// -/// Configured through [`Builder::authenticate`]. +/// Configured through [`Authenticated::multiplex`]. #[pin_project::pin_project] -pub struct Authenticate +pub struct Multiplex where C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + OutboundUpgrade> + U: InboundUpgrade> + OutboundUpgrade>, { + peer_id: Option, #[pin] - inner: AuthenticationUpgradeApply + upgrade: EitherUpgrade, } -impl Future for Authenticate +impl Future for Multiplex where C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + OutboundUpgrade, - Output = >>::Output, - Error = >>::Error - > + U: InboundUpgrade, Output = M, Error = E>, + U: OutboundUpgrade, Output = M, Error = E> { - type Output = as Future>::Output; + type Output = Result<(PeerId, M), UpgradeError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - Future::poll(this.inner, cx) + let m = match ready!(Future::poll(this.upgrade, cx)) { + Ok(m) => m, + Err(err) => return Poll::Ready(Err(err)), + }; + let i = this.peer_id.take().expect("Multiplex future polled after completion."); + Poll::Ready(Ok((i, m))) } } @@ -173,23 +175,23 @@ where /// /// Configured through [`Authenticated::multiplex`]. 
#[pin_project::pin_project] -pub struct Multiplex +pub struct AuthenticatedUpgrade where C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade> + OutboundUpgrade>, { - peer_id: Option, + peer_id_and_role: Option<(PeerId, SimOpenRole)>, #[pin] upgrade: EitherUpgrade, } -impl Future for Multiplex +impl Future for AuthenticatedUpgrade where C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = M, Error = E>, U: OutboundUpgrade, Output = M, Error = E> { - type Output = Result<(PeerId, M), UpgradeError>; + type Output = Result<((PeerId, M), SimOpenRole), UpgradeError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -197,8 +199,8 @@ where Ok(m) => m, Err(err) => return Poll::Ready(Err(err)), }; - let i = this.peer_id.take().expect("Multiplex future polled after completion."); - Poll::Ready(Ok((i, m))) + let (peer_id, role) = this.peer_id_and_role.take().expect("AuthenticatedUpgrade future polled after completion."); + Poll::Ready(Ok(((peer_id, m), role))) } } @@ -223,7 +225,7 @@ where /// * Transport output: `(PeerId, C) -> (PeerId, D)`. // // TODO: Do we need an `apply` with a version? - pub fn apply(self, upgrade: U) -> Authenticated> + pub fn apply(self, upgrade: U) -> Authenticated AuthenticatedUpgrade + Clone>> where T: Transport, C: AsyncRead + AsyncWrite + Unpin, @@ -232,7 +234,19 @@ where U: OutboundUpgrade, Output = D, Error = E> + Clone, E: Error + 'static, { - Authenticated(Builder::new(Upgrade::new(self.0.inner, upgrade))) + Authenticated(Builder::new(self.0.inner.and_then(move |((i, c), r), endpoint| { + let upgrade = match r { + SimOpenRole::Initiator => { + // TODO: Offer version that allows choosing the Version. + Either::Left(upgrade::apply_outbound(c, upgrade, upgrade::Version::default())) + }, + SimOpenRole::Responder => { + Either::Right(upgrade::apply_inbound(c, upgrade)) + + } + }; + AuthenticatedUpgrade { peer_id_and_role: Some((i, r)), upgrade } + }))) } /// Upgrades the transport with a (sub)stream multiplexer. From 11341738df5f2fcd5db0d889e0bde2309755b2e2 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Sun, 4 Jul 2021 20:48:17 +0200 Subject: [PATCH 05/23] core/: Clean type structure --- core/src/transport.rs | 1 - core/src/transport/upgrade.rs | 378 ++++---------------------------- core/src/upgrade/apply.rs | 263 +++++++++++----------- core/tests/transport_upgrade.rs | 2 +- core/tests/util.rs | 1 - 5 files changed, 177 insertions(+), 468 deletions(-) diff --git a/core/src/transport.rs b/core/src/transport.rs index 94ca830e869..5711c06044e 100644 --- a/core/src/transport.rs +++ b/core/src/transport.rs @@ -46,7 +46,6 @@ pub use self::boxed::Boxed; pub use self::choice::OrTransport; pub use self::memory::MemoryTransport; pub use self::optional::OptionalTransport; -pub use self::upgrade::Upgrade; /// A transport provides connection-oriented communication between two peers /// through ordered streams of data (i.e. connections). 
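The `Multiplex` and `AuthenticatedUpgrade` futures above share one pattern: keep the value to hand back (the peer ID, or the peer ID plus negotiation role) in an `Option` beside the pinned upgrade future, and pair it with the upgrade's output once that future resolves. A minimal standalone sketch of the pattern, with stand-in types rather than the crate's upgrade machinery (the `F: Unpin` bound is only there to avoid pin projection in the example):

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

/// Sketch only: pairs user data (e.g. a peer ID and negotiation role) with the
/// output of an inner future.
struct WithUserData<D, F> {
    user_data: Option<D>,
    inner: F,
}

impl<D: Unpin, F: Future + Unpin> Future for WithUserData<D, F> {
    type Output = (D, F::Output);

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &mut *self;
        // Drive the inner upgrade first; the user data is only taken on completion,
        // hence the `expect` if the future is polled again afterwards.
        let out = match Pin::new(&mut this.inner).poll(cx) {
            Poll::Ready(out) => out,
            Poll::Pending => return Poll::Pending,
        };
        let data = this.user_data.take().expect("polled after completion");
        Poll::Ready((data, out))
    }
}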
diff --git a/core/src/transport/upgrade.rs b/core/src/transport/upgrade.rs index f01de0fa597..d3bfa73397d 100644 --- a/core/src/transport/upgrade.rs +++ b/core/src/transport/upgrade.rs @@ -28,7 +28,6 @@ use crate::{ transport::{ Transport, TransportError, - ListenerEvent, and_then::AndThen, boxed::boxed, timeout::TransportTimeout, @@ -36,14 +35,12 @@ use crate::{ muxing::{StreamMuxer, StreamMuxerBox}, upgrade::{ self, - AuthenticationUpgradeApply, OutboundUpgrade, InboundUpgrade, - apply_inbound, - apply_outbound, UpgradeError, OutboundUpgradeApply, - InboundUpgradeApply + InboundUpgradeApply, + AuthenticationUpgradeApply, }, PeerId }; @@ -51,7 +48,6 @@ use futures::{prelude::*, ready, future::Either}; use multiaddr::Multiaddr; use std::{ error::Error, - fmt, pin::Pin, task::{Context, Poll}, time::Duration @@ -136,74 +132,6 @@ where } } -/// An upgrade that negotiates a (sub)stream multiplexer on -/// top of an authenticated transport. -/// -/// Configured through [`Authenticated::multiplex`]. -#[pin_project::pin_project] -pub struct Multiplex -where - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + OutboundUpgrade>, -{ - peer_id: Option, - #[pin] - upgrade: EitherUpgrade, -} - -impl Future for Multiplex -where - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade, Output = M, Error = E>, - U: OutboundUpgrade, Output = M, Error = E> -{ - type Output = Result<(PeerId, M), UpgradeError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let m = match ready!(Future::poll(this.upgrade, cx)) { - Ok(m) => m, - Err(err) => return Poll::Ready(Err(err)), - }; - let i = this.peer_id.take().expect("Multiplex future polled after completion."); - Poll::Ready(Ok((i, m))) - } -} - -/// An upgrade that negotiates a (sub)stream multiplexer on -/// top of an authenticated transport. -/// -/// Configured through [`Authenticated::multiplex`]. -#[pin_project::pin_project] -pub struct AuthenticatedUpgrade -where - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + OutboundUpgrade>, -{ - peer_id_and_role: Option<(PeerId, SimOpenRole)>, - #[pin] - upgrade: EitherUpgrade, -} - -impl Future for AuthenticatedUpgrade -where - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade, Output = M, Error = E>, - U: OutboundUpgrade, Output = M, Error = E> -{ - type Output = Result<((PeerId, M), SimOpenRole), UpgradeError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let m = match ready!(Future::poll(this.upgrade, cx)) { - Ok(m) => m, - Err(err) => return Poll::Ready(Err(err)), - }; - let (peer_id, role) = this.peer_id_and_role.take().expect("AuthenticatedUpgrade future polled after completion."); - Poll::Ready(Ok(((peer_id, m), role))) - } -} - /// An transport with peer authentication, obtained from [`Builder::authenticate`]. #[derive(Clone)] pub struct Authenticated(Builder); @@ -225,16 +153,18 @@ where /// * Transport output: `(PeerId, C) -> (PeerId, D)`. // // TODO: Do we need an `apply` with a version? - pub fn apply(self, upgrade: U) -> Authenticated AuthenticatedUpgrade + Clone>> + // + // TODO: Rename to `and_then`. 
+ pub fn apply(self, upgrade: U) -> Authenticated UpgradeAuthenticated + Clone>> where - T: Transport, + T: Transport, C: AsyncRead + AsyncWrite + Unpin, D: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = D, Error = E>, U: OutboundUpgrade, Output = D, Error = E> + Clone, E: Error + 'static, { - Authenticated(Builder::new(self.0.inner.and_then(move |((i, c), r), endpoint| { + Authenticated(Builder::new(self.0.inner.and_then(move |((i, r), c), _endpoint| { let upgrade = match r { SimOpenRole::Initiator => { // TODO: Offer version that allows choosing the Version. @@ -245,7 +175,7 @@ where } }; - AuthenticatedUpgrade { peer_id_and_role: Some((i, r)), upgrade } + UpgradeAuthenticated { user_data: Some((i, r)), upgrade } }))) } @@ -260,16 +190,16 @@ where /// * I/O upgrade: `C -> M`. /// * Transport output: `(PeerId, C) -> (PeerId, M)`. pub fn multiplex(self, upgrade: U) -> Multiplexed< - AndThen Multiplex + Clone> + AndThen UpgradeAuthenticated + Clone> > where - T: Transport, + T: Transport, C: AsyncRead + AsyncWrite + Unpin, M: StreamMuxer, U: InboundUpgrade, Output = M, Error = E>, U: OutboundUpgrade, Output = M, Error = E> + Clone, E: Error + 'static, { - Multiplexed(self.0.inner.and_then(move |((i, c), r), endpoint| { + Multiplexed(self.0.inner.and_then(move |((i, r), c), _endpoint| { let upgrade = match r { SimOpenRole::Initiator => { // TODO: Offer version that allows choosing the Version. @@ -280,7 +210,7 @@ where } }; - Multiplex { peer_id: Some(i), upgrade } + UpgradeAuthenticated { user_data: Some(i), upgrade } })) } @@ -288,6 +218,40 @@ where // TODO: Add changelog entry that multiplex_ext is removed. } +/// An upgrade that negotiates a (sub)stream multiplexer on +/// top of an authenticated transport. +/// +/// Configured through [`Authenticated::multiplex`]. +#[pin_project::pin_project] +pub struct UpgradeAuthenticated +where + C: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade> + OutboundUpgrade>, +{ + user_data: Option, + #[pin] + upgrade: EitherUpgrade, +} + +impl Future for UpgradeAuthenticated +where + C: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade, Output = M, Error = E>, + U: OutboundUpgrade, Output = M, Error = E> +{ + type Output = Result<(D, M), UpgradeError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let m = match ready!(Future::poll(this.upgrade, cx)) { + Ok(m) => m, + Err(err) => return Poll::Ready(Err(err)), + }; + let user_data = this.user_data.take().expect("UpgradeAuthenticated future polled after completion."); + Poll::Ready(Ok((user_data, m))) + } +} + /// A authenticated and multiplexed transport, obtained from /// [`Authenticated::multiplex`]. #[derive(Clone)] @@ -354,253 +318,3 @@ where /// An inbound or outbound upgrade. type EitherUpgrade = future::Either, InboundUpgradeApply>; - -/// A custom upgrade on an [`Authenticated`] transport. 
-/// -/// See [`Transport::upgrade`] -#[derive(Debug, Copy, Clone)] -pub struct Upgrade { inner: T, upgrade: U } - -impl Upgrade { - pub fn new(inner: T, upgrade: U) -> Self { - Upgrade { inner, upgrade } - } -} - -impl Transport for Upgrade -where - T: Transport, - T::Error: 'static, - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade, Output = D, Error = E>, - U: OutboundUpgrade, Output = D, Error = E> + Clone, - E: Error + 'static -{ - type Output = ((PeerId, D), SimOpenRole); - type Error = TransportUpgradeError; - type Listener = ListenerStream; - type ListenerUpgrade = ListenerUpgradeFuture; - type Dial = DialUpgradeFuture; - - fn dial(self, addr: Multiaddr) -> Result> { - let future = self.inner.dial(addr) - .map_err(|err| err.map(TransportUpgradeError::Transport))?; - Ok(DialUpgradeFuture { - future: Box::pin(future), - upgrade: future::Either::Left(Some(self.upgrade)) - }) - } - - fn listen_on(self, addr: Multiaddr) -> Result> { - let stream = self.inner.listen_on(addr) - .map_err(|err| err.map(TransportUpgradeError::Transport))?; - Ok(ListenerStream { - stream: Box::pin(stream), - upgrade: self.upgrade - }) - } - - fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { - self.inner.address_translation(server, observed) - } -} - -/// Errors produced by a transport upgrade. -#[derive(Debug)] -pub enum TransportUpgradeError { - /// Error in the transport. - Transport(T), - /// Error while upgrading to a protocol. - Upgrade(UpgradeError), -} - -impl fmt::Display for TransportUpgradeError -where - T: fmt::Display, - U: fmt::Display, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TransportUpgradeError::Transport(e) => write!(f, "Transport error: {}", e), - TransportUpgradeError::Upgrade(e) => write!(f, "Upgrade error: {}", e), - } - } -} - -impl Error for TransportUpgradeError -where - T: Error + 'static, - U: Error + 'static, -{ - fn source(&self) -> Option<&(dyn Error + 'static)> { - match self { - TransportUpgradeError::Transport(e) => Some(e), - TransportUpgradeError::Upgrade(e) => Some(e), - } - } -} - -/// The [`Transport::Dial`] future of an [`Upgrade`]d transport. -pub struct DialUpgradeFuture -where - U: InboundUpgrade> + OutboundUpgrade, - Output = >>::Output, - Error = >>::Error - >, - C: AsyncRead + AsyncWrite + Unpin, -{ - future: Pin>, - upgrade: future::Either< - Option, - (Option<(PeerId, SimOpenRole)>, Either, InboundUpgradeApply>), - > -} - -impl Future for DialUpgradeFuture -where - F: TryFuture, - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + OutboundUpgrade, - Output = >>::Output, - Error = >>::Error - >, -{ - type Output = Result<((PeerId, >>::Output), SimOpenRole), TransportUpgradeError>>::Error>>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - // We use a `this` variable because the compiler can't mutably borrow multiple times - // accross a `Deref`. 
- let this = &mut *self; - - loop { - this.upgrade = match this.upgrade { - future::Either::Left(ref mut up) => { - let ((i, c), r) = match ready!(TryFuture::try_poll(this.future.as_mut(), cx).map_err(TransportUpgradeError::Transport)) { - Ok(v) => v, - Err(err) => return Poll::Ready(Err(err)), - }; - let upgrade = up.take().map(|u| match r { - SimOpenRole::Initiator => { - Either::Left(apply_outbound(c, u, upgrade::Version::V1)) - }, - SimOpenRole::Responder => { - Either::Right(apply_inbound(c, u)) - } - }).take().expect("DialUpgradeFuture is constructed with Either::Left(Some)."); - future::Either::Right((Some((i, r)), upgrade)) - } - future::Either::Right((ref mut i, ref mut up)) => { - let d = match ready!(Future::poll(Pin::new(up), cx).map_err(TransportUpgradeError::Upgrade)) { - Ok(d) => d, - Err(err) => return Poll::Ready(Err(err)), - }; - let (i, r) = i.take().expect("DialUpgradeFuture polled after completion."); - return Poll::Ready(Ok(((i, d), r))) - } - } - } - } -} - -impl Unpin for DialUpgradeFuture -where - U: InboundUpgrade> + OutboundUpgrade, - Output = >>::Output, - Error = >>::Error - >, - C: AsyncRead + AsyncWrite + Unpin, -{ -} - -/// The [`Transport::Listener`] stream of an [`Upgrade`]d transport. -pub struct ListenerStream { - stream: Pin>, - upgrade: U -} - -impl Stream for ListenerStream -where - S: TryStream, Error = E>, - F: TryFuture, - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade, Output = D> + Clone -{ - type Item = Result, TransportUpgradeError>, TransportUpgradeError>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match ready!(TryStream::try_poll_next(self.stream.as_mut(), cx)) { - Some(Ok(event)) => { - let event = event - .map(move |future| { - ListenerUpgradeFuture { - future: Box::pin(future), - upgrade: future::Either::Left(Some(self.upgrade.clone())) - } - }) - .map_err(TransportUpgradeError::Transport); - Poll::Ready(Some(Ok(event))) - } - Some(Err(err)) => { - Poll::Ready(Some(Err(TransportUpgradeError::Transport(err)))) - } - None => Poll::Ready(None) - } - } -} - -impl Unpin for ListenerStream { -} - -/// The [`Transport::ListenerUpgrade`] future of an [`Upgrade`]d transport. -pub struct ListenerUpgradeFuture -where - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> -{ - future: Pin>, - upgrade: future::Either, (Option<(PeerId, SimOpenRole)>, InboundUpgradeApply)> -} - -impl Future for ListenerUpgradeFuture -where - F: TryFuture, - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade, Output = D>, - U::Error: Error -{ - type Output = Result<((PeerId, D), SimOpenRole), TransportUpgradeError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - // We use a `this` variable because the compiler can't mutably borrow multiple times - // accross a `Deref`. 
- let this = &mut *self; - - loop { - this.upgrade = match this.upgrade { - future::Either::Left(ref mut up) => { - let ((i, c), r) = match ready!(TryFuture::try_poll(this.future.as_mut(), cx).map_err(TransportUpgradeError::Transport)) { - Ok(v) => v, - Err(err) => return Poll::Ready(Err(err)) - }; - let u = up.take().expect("ListenerUpgradeFuture is constructed with Either::Left(Some)."); - future::Either::Right((Some((i, r)), apply_inbound(c, u))) - } - future::Either::Right((ref mut i, ref mut up)) => { - let d = match ready!(TryFuture::try_poll(Pin::new(up), cx).map_err(TransportUpgradeError::Upgrade)) { - Ok(v) => v, - Err(err) => return Poll::Ready(Err(err)) - }; - let (i, r) = i.take().expect("ListenerUpgradeFuture polled after completion."); - return Poll::Ready(Ok(((i, d), r))) - } - } - } - } -} - -impl Unpin for ListenerUpgradeFuture -where - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> -{ -} diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index 92712da9303..76e726a554b 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{ConnectedPoint, Negotiated}; +use crate::{ConnectedPoint, Negotiated, PeerId}; use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError, ProtocolName}; use futures::{future::{Either, TryFutureExt, MapOk}, prelude::*}; use log::debug; @@ -41,136 +41,6 @@ where } } - -/// Applies an authentication upgrade to the inbound or outbound direction of a connection or substream. -// -// TODO: This is specific to authentication upgrades, given that it can handle simultaneous open. -// Should this be moved to transport.rs? -pub fn apply_authentication(conn: C, up: U, cp: ConnectedPoint, v: Version) - -> AuthenticationUpgradeApply -where - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + OutboundUpgrade>, -{ - let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); - - AuthenticationUpgradeApply { - inner: AuthenticationUpgradeApplyState::Init{ - future: match cp { - ConnectedPoint::Dialer { .. } => Either::Left( - multistream_select::dialer_select_proto(conn, iter, v), - ), - ConnectedPoint::Listener { .. } => Either::Right( - multistream_select::listener_select_proto(conn, iter) - .map_ok(add_responder as fn (_) -> _), - ), - }, - upgrade: up, - }, - } -} - -// TODO: This is a hack to get a fn pointer. Can we do better? 
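The removed `// TODO: This is a hack to get a fn pointer` comment refers to the `map_ok(add_responder as fn(_) -> _)` cast: the mapped future has to be written out as a concrete type inside the state enum, a closure's anonymous type cannot be named there, but a `fn` pointer type can. A standalone illustration of that constraint, with placeholder types instead of the crate's (the crate abbreviates the cast as `fn(_) -> _`):

use futures::future::{self, MapOk, TryFutureExt};

// The mapped future must be nameable so it can be stored in an enum variant;
// that is only possible with a `fn` pointer, not with a closure type.
type Listener = future::Ready<Result<(&'static str, u8), ()>>;
type Tagged = MapOk<Listener, fn((&'static str, u8)) -> (&'static str, u8, bool)>;

fn add_responder(input: (&'static str, u8)) -> (&'static str, u8, bool) {
    (input.0, input.1, true)
}

fn tag(listener: Listener) -> Tagged {
    // The explicit cast turns the zero-sized fn item into the fn pointer named above.
    listener.map_ok(add_responder as fn((&'static str, u8)) -> (&'static str, u8, bool))
}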
-fn add_responder(input: (P, C)) -> (P, C, SimOpenRole) { - (input.0, input.1, SimOpenRole::Responder) -} - -pub struct AuthenticationUpgradeApply -where - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + OutboundUpgrade>, -{ - inner: AuthenticationUpgradeApplyState -} - -impl Unpin for AuthenticationUpgradeApply -where - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + OutboundUpgrade>, -{ -} - -enum AuthenticationUpgradeApplyState -where - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + OutboundUpgrade>, -{ - Init { - future: Either< - multistream_select::DialerSelectFuture::IntoIter>>, - MapOk< - ListenerSelectFuture>, - fn((NameWrap, Negotiated)) -> (NameWrap, Negotiated, SimOpenRole) - >, - >, - upgrade: U, - }, - Upgrade { - role: SimOpenRole, - future: Either< - Pin>>::Future>>, - Pin>>::Future>>, - >, - }, - Undefined -} - -impl Future for AuthenticationUpgradeApply -where - C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + OutboundUpgrade, - Output = >>::Output, - Error = >>::Error - > -{ - type Output = Result< - (>>::Output, SimOpenRole), - UpgradeError<>>::Error>, - >; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - match mem::replace(&mut self.inner, AuthenticationUpgradeApplyState::Undefined) { - AuthenticationUpgradeApplyState::Init { mut future, upgrade } => { - let (info, io, role) = match Future::poll(Pin::new(&mut future), cx)? { - Poll::Ready(x) => x, - Poll::Pending => { - self.inner = AuthenticationUpgradeApplyState::Init { future, upgrade }; - return Poll::Pending - } - }; - let fut = match role { - SimOpenRole::Initiator => Either::Left(Box::pin(upgrade.upgrade_outbound(io, info.0))), - SimOpenRole::Responder => Either::Right(Box::pin(upgrade.upgrade_inbound(io, info.0))), - }; - self.inner = AuthenticationUpgradeApplyState::Upgrade { - future: fut, - role, - }; - } - AuthenticationUpgradeApplyState::Upgrade { mut future, role } => { - match Future::poll(Pin::new(&mut future), cx) { - Poll::Pending => { - self.inner = AuthenticationUpgradeApplyState::Upgrade { future, role }; - return Poll::Pending - } - Poll::Ready(Ok(x)) => { - debug!("Successfully applied negotiated protocol"); - return Poll::Ready(Ok((x, role))) - } - Poll::Ready(Err(e)) => { - debug!("Failed to apply negotiated protocol"); - return Poll::Ready(Err(UpgradeError::Apply(e))) - } - } - } - AuthenticationUpgradeApplyState::Undefined => - panic!("AuthenticationUpgradeApplyState::poll called after completion") - } - } - } -} - /// Tries to perform an upgrade on an inbound connection or substream. pub fn apply_inbound(conn: C, up: U) -> InboundUpgradeApply where @@ -350,11 +220,138 @@ where } } -type NameWrapIter = iter::Map::Item) -> NameWrap<::Item>>; +/// Applies an authentication upgrade to the inbound or outbound direction of a connection or substream. +/// +// TODO: Document that this is like `apply` but with simultaneous open. 
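As the TODO above says, `apply_authentication` behaves like `apply` except that the negotiation may be a simultaneous open, so its result carries the role the local peer ended up with. The essential control flow is: negotiate, drive the security upgrade on the side matching that role, and return the role so that later upgrades can reuse it. A self-contained sketch with simplified stand-in traits, not the crate's `InboundUpgrade`/`OutboundUpgrade`:

#[derive(Clone, Copy, Debug)]
enum Role {
    Initiator,
    Responder,
}

trait SecurityUpgrade<C> {
    type Output;
    fn upgrade_outbound(self, conn: C) -> Self::Output;
    fn upgrade_inbound(self, conn: C) -> Self::Output;
}

/// After negotiation the role is fixed; the upgrade is applied on the matching
/// side and the role is handed back so that e.g. the muxer upgrade can reuse it
/// instead of negotiating initiator/responder a second time.
fn apply_after_negotiation<U, C>(upgrade: U, conn: C, role: Role) -> (Role, U::Output)
where
    U: SecurityUpgrade<C>,
{
    let out = match role {
        Role::Initiator => upgrade.upgrade_outbound(conn),
        Role::Responder => upgrade.upgrade_inbound(conn),
    };
    (role, out)
}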
+pub fn apply_authentication(conn: C, up: U, cp: ConnectedPoint, v: Version) + -> AuthenticationUpgradeApply +where + C: AsyncRead + AsyncWrite + Unpin, + D: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade, Output = (PeerId, D)>, + U: OutboundUpgrade, Output = (PeerId, D), Error = >>::Error> + Clone, + +{ + fn add_responder(input: (P, C)) -> (P, C, SimOpenRole) { + (input.0, input.1, SimOpenRole::Responder) + } + + let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); + + AuthenticationUpgradeApply { + inner: AuthenticationUpgradeApplyState::Init{ + future: match cp { + ConnectedPoint::Dialer { .. } => Either::Left( + multistream_select::dialer_select_proto(conn, iter, v), + ), + ConnectedPoint::Listener { .. } => Either::Right( + multistream_select::listener_select_proto(conn, iter) + .map_ok(add_responder as fn (_) -> _), + ), + }, + upgrade: up, + }, + } +} + +pub struct AuthenticationUpgradeApply +where + U: InboundUpgrade> + OutboundUpgrade>, +{ + inner: AuthenticationUpgradeApplyState +} + +impl Unpin for AuthenticationUpgradeApply +where + C: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade> + OutboundUpgrade>, +{ +} + +enum AuthenticationUpgradeApplyState +where + U: InboundUpgrade> + OutboundUpgrade>, +{ + Init { + future: Either< + multistream_select::DialerSelectFuture::IntoIter>>, + MapOk< + ListenerSelectFuture>, + fn((NameWrap, Negotiated)) -> (NameWrap, Negotiated, SimOpenRole) + >, + >, + upgrade: U, + }, + Upgrade { + role: SimOpenRole, + future: Either< + Pin>>::Future>>, + Pin>>::Future>>, + >, + }, + Undefined +} + +impl Future for AuthenticationUpgradeApply +where + C: AsyncRead + AsyncWrite + Unpin, + D: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade, Output = (PeerId, D)>, + U: OutboundUpgrade, Output = (PeerId, D), Error = >>::Error> + Clone, +{ + type Output = Result< + ((PeerId, SimOpenRole), D), + UpgradeError<>>::Error>, + >; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + match mem::replace(&mut self.inner, AuthenticationUpgradeApplyState::Undefined) { + AuthenticationUpgradeApplyState::Init { mut future, upgrade } => { + let (info, io, role) = match Future::poll(Pin::new(&mut future), cx)? { + Poll::Ready(x) => x, + Poll::Pending => { + self.inner = AuthenticationUpgradeApplyState::Init { future, upgrade }; + return Poll::Pending + } + }; + let fut = match role { + SimOpenRole::Initiator => Either::Left(Box::pin(upgrade.upgrade_outbound(io, info.0))), + SimOpenRole::Responder => Either::Right(Box::pin(upgrade.upgrade_inbound(io, info.0))), + }; + self.inner = AuthenticationUpgradeApplyState::Upgrade { + future: fut, + role, + }; + } + AuthenticationUpgradeApplyState::Upgrade { mut future, role } => { + match Future::poll(Pin::new(&mut future), cx) { + Poll::Pending => { + self.inner = AuthenticationUpgradeApplyState::Upgrade { future, role }; + return Poll::Pending + } + Poll::Ready(Ok((peer_id, d))) => { + debug!("Successfully applied negotiated protocol"); + return Poll::Ready(Ok(((peer_id, role), d))) + } + Poll::Ready(Err(e)) => { + debug!("Failed to apply negotiated protocol"); + return Poll::Ready(Err(UpgradeError::Apply(e))) + } + } + } + AuthenticationUpgradeApplyState::Undefined => + panic!("AuthenticationUpgradeApplyState::poll called after completion") + } + } + } +} + +pub type NameWrapIter = iter::Map::Item) -> NameWrap<::Item>>; /// Wrapper type to expose an `AsRef<[u8]>` impl for all types implementing `ProtocolName`. 
#[derive(Clone)] -struct NameWrap(N); +pub struct NameWrap(N); impl AsRef<[u8]> for NameWrap { fn as_ref(&self) -> &[u8] { diff --git a/core/tests/transport_upgrade.rs b/core/tests/transport_upgrade.rs index 162120ac50e..c925d9b7958 100644 --- a/core/tests/transport_upgrade.rs +++ b/core/tests/transport_upgrade.rs @@ -23,7 +23,7 @@ mod util; use futures::prelude::*; use libp2p_core::identity; use libp2p_core::transport::{Transport, MemoryTransport}; -use libp2p_core::upgrade::{self, UpgradeInfo, InboundUpgrade, OutboundUpgrade}; +use libp2p_core::upgrade::{UpgradeInfo, InboundUpgrade, OutboundUpgrade}; use libp2p_mplex::MplexConfig; use libp2p_noise as noise; use multiaddr::{Multiaddr, Protocol}; diff --git a/core/tests/util.rs b/core/tests/util.rs index 0eff3e270bc..e85e29a9347 100644 --- a/core/tests/util.rs +++ b/core/tests/util.rs @@ -16,7 +16,6 @@ use libp2p_core::{ muxing::{StreamMuxer, StreamMuxerBox}, network::{Network, NetworkConfig}, transport, - upgrade, }; use libp2p_mplex as mplex; use libp2p_noise as noise; From dab4ceb4b57635a49b4155fc1fb8e4f8d624a5ed Mon Sep 17 00:00:00 2001 From: Max Inden Date: Sun, 4 Jul 2021 21:55:40 +0200 Subject: [PATCH 06/23] core/: Enforce upgrade version at compile time --- core/src/transport/upgrade.rs | 52 +++++++++++----- core/src/upgrade.rs | 5 +- core/src/upgrade/apply.rs | 79 ++++++++++++++++++++++-- examples/ipfs-private.rs | 2 +- protocols/gossipsub/tests/smoke.rs | 4 +- protocols/identify/src/identify.rs | 9 ++- protocols/kad/src/behaviour/test.rs | 1 - protocols/relay/examples/relay.rs | 1 - protocols/relay/tests/lib.rs | 2 +- protocols/request-response/tests/ping.rs | 6 +- swarm/src/lib.rs | 3 +- 11 files changed, 126 insertions(+), 38 deletions(-) diff --git a/core/src/transport/upgrade.rs b/core/src/transport/upgrade.rs index d3bfa73397d..d92dd409cac 100644 --- a/core/src/transport/upgrade.rs +++ b/core/src/transport/upgrade.rs @@ -20,7 +20,6 @@ //! Configuration of transport protocol upgrades. -pub use crate::upgrade::{Version, SimOpenRole}; use crate::{ ConnectedPoint, @@ -35,6 +34,9 @@ use crate::{ muxing::{StreamMuxer, StreamMuxerBox}, upgrade::{ self, + SimOpenRole, + Version, + AuthenticationVersion, OutboundUpgrade, InboundUpgrade, UpgradeError, @@ -111,12 +113,12 @@ where U: OutboundUpgrade, Output = (PeerId, D), Error = E> + Clone, E: Error + 'static, { - self.authenticate_with_version(upgrade, upgrade::Version::default()) + self.authenticate_with_version(upgrade, AuthenticationVersion::default()) } - /// Same as [`Builder::authenticate`] with the option to choose the [`upgrade::Version`] used to - /// upgrade the connection. - pub fn authenticate_with_version(self, upgrade: U, version: upgrade::Version) -> Authenticated< + /// Same as [`Builder::authenticate`] with the option to choose the + /// [`AuthenticationVersion`] used to upgrade the connection. + pub fn authenticate_with_version(self, upgrade: U, version: AuthenticationVersion) -> Authenticated< AndThen AuthenticationUpgradeApply + Clone> > where T: Transport, @@ -151,11 +153,21 @@ where /// /// * I/O upgrade: `C -> D`. /// * Transport output: `(PeerId, C) -> (PeerId, D)`. - // - // TODO: Do we need an `apply` with a version? - // - // TODO: Rename to `and_then`. 
pub fn apply(self, upgrade: U) -> Authenticated UpgradeAuthenticated + Clone>> + where + T: Transport, + C: AsyncRead + AsyncWrite + Unpin, + D: AsyncRead + AsyncWrite + Unpin, + U: InboundUpgrade, Output = D, Error = E>, + U: OutboundUpgrade, Output = D, Error = E> + Clone, + E: Error + 'static, + { + self.apply_with_version(upgrade, Version::default()) + } + + /// Same as [`Authenticated::apply`] with the option to choose the + /// [`Version`] used to upgrade the connection. + pub fn apply_with_version(self, upgrade: U, version: Version) -> Authenticated UpgradeAuthenticated + Clone>> where T: Transport, C: AsyncRead + AsyncWrite + Unpin, @@ -167,8 +179,7 @@ where Authenticated(Builder::new(self.0.inner.and_then(move |((i, r), c), _endpoint| { let upgrade = match r { SimOpenRole::Initiator => { - // TODO: Offer version that allows choosing the Version. - Either::Left(upgrade::apply_outbound(c, upgrade, upgrade::Version::default())) + Either::Left(upgrade::apply_outbound(c, upgrade, version)) }, SimOpenRole::Responder => { Either::Right(upgrade::apply_inbound(c, upgrade)) @@ -198,12 +209,26 @@ where U: InboundUpgrade, Output = M, Error = E>, U: OutboundUpgrade, Output = M, Error = E> + Clone, E: Error + 'static, + { + self.multiplex_with_version(upgrade, Version::default()) + } + + /// Same as [`Authenticated::multiplex`] with the option to choose the + /// [`Version`] used to upgrade the connection. + pub fn multiplex_with_version(self, upgrade: U, version: Version) -> Multiplexed< + AndThen UpgradeAuthenticated + Clone> + > where + T: Transport, + C: AsyncRead + AsyncWrite + Unpin, + M: StreamMuxer, + U: InboundUpgrade, Output = M, Error = E>, + U: OutboundUpgrade, Output = M, Error = E> + Clone, + E: Error + 'static, { Multiplexed(self.0.inner.and_then(move |((i, r), c), _endpoint| { let upgrade = match r { SimOpenRole::Initiator => { - // TODO: Offer version that allows choosing the Version. - Either::Left(upgrade::apply_outbound(c, upgrade, upgrade::Version::default())) + Either::Left(upgrade::apply_outbound(c, upgrade, version)) }, SimOpenRole::Responder => { Either::Right(upgrade::apply_inbound(c, upgrade)) @@ -214,7 +239,6 @@ where })) } - // TODO: Add changelog entry that multiplex_ext is removed. } diff --git a/core/src/upgrade.rs b/core/src/upgrade.rs index 460f076a4b4..dde0fd8f16c 100644 --- a/core/src/upgrade.rs +++ b/core/src/upgrade.rs @@ -70,9 +70,10 @@ mod transfer; use futures::future::Future; pub use crate::Negotiated; -pub use multistream_select::{Version, NegotiatedComplete, NegotiationError, ProtocolError, SimOpenRole}; +pub use multistream_select::{NegotiatedComplete, NegotiationError, ProtocolError, SimOpenRole}; pub use self::{ - apply::{apply, apply_authentication, apply_inbound, apply_outbound, InboundUpgradeApply, OutboundUpgradeApply, AuthenticationUpgradeApply}, + // TODO: Break. 
+ apply::{apply, apply_authentication, apply_inbound, apply_outbound, InboundUpgradeApply, OutboundUpgradeApply, AuthenticationUpgradeApply, Version, AuthenticationVersion}, denied::DeniedUpgrade, either::EitherUpgrade, error::UpgradeError, diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index 76e726a554b..386a78ac36f 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -25,9 +25,36 @@ use log::debug; use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture}; use std::{iter, mem, pin::Pin, task::Context, task::Poll}; -pub use multistream_select::{Version, SimOpenRole, NegotiationError}; +pub use multistream_select::{SimOpenRole, NegotiationError}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Version { + V1, + V1Lazy, +} + +impl From for multistream_select::Version { + fn from(v: Version) -> Self { + match v { + Version::V1 => multistream_select::Version::V1, + Version::V1Lazy => multistream_select::Version::V1Lazy, + } + } +} + +impl Default for Version { + fn default() -> Self { + match multistream_select::Version::default() { + multistream_select::Version::V1 => Version::V1, + multistream_select::Version::V1Lazy => Version::V1Lazy, + multistream_select::Version::V1SimOpen => unreachable!("see `v1_sim_open_is_not_default`"), + } + } +} /// Applies an upgrade to the inbound and outbound direction of a connection or substream. +// +// TODO: Link to apply_authentication for authentication protocols. pub fn apply(conn: C, up: U, cp: ConnectedPoint, v: Version) -> Either, OutboundUpgradeApply> where @@ -61,7 +88,7 @@ where U: OutboundUpgrade> { let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); - let future = multistream_select::dialer_select_proto(conn, iter, v); + let future = multistream_select::dialer_select_proto(conn, iter, v.into()); OutboundUpgradeApply { inner: OutboundUpgradeApplyState::Init { future, upgrade: up } } @@ -220,10 +247,37 @@ where } } -/// Applies an authentication upgrade to the inbound or outbound direction of a connection or substream. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum AuthenticationVersion { + V1, + V1Lazy, + V1SimOpen +} + +impl Default for AuthenticationVersion { + fn default() -> Self { + match multistream_select::Version::default() { + multistream_select::Version::V1 => AuthenticationVersion::V1, + multistream_select::Version::V1Lazy => AuthenticationVersion::V1Lazy, + multistream_select::Version::V1SimOpen => AuthenticationVersion::V1SimOpen, + } + } +} + +impl From for multistream_select::Version { + fn from(v: AuthenticationVersion) -> Self { + match v { + AuthenticationVersion::V1 => multistream_select::Version::V1, + AuthenticationVersion::V1Lazy => multistream_select::Version::V1Lazy, + AuthenticationVersion::V1SimOpen => multistream_select::Version::V1SimOpen, + } + } +} + +/// Applies an authentication upgrade to the inbound or outbound direction of a connection. /// -// TODO: Document that this is like `apply` but with simultaneous open. -pub fn apply_authentication(conn: C, up: U, cp: ConnectedPoint, v: Version) +/// Note: This is like [`apply`] with additional support for +pub fn apply_authentication(conn: C, up: U, cp: ConnectedPoint, v: AuthenticationVersion) -> AuthenticationUpgradeApply where C: AsyncRead + AsyncWrite + Unpin, @@ -242,7 +296,7 @@ where inner: AuthenticationUpgradeApplyState::Init{ future: match cp { ConnectedPoint::Dialer { .. 
} => Either::Left( - multistream_select::dialer_select_proto(conn, iter, v), + multistream_select::dialer_select_proto(conn, iter, v.into()), ), ConnectedPoint::Listener { .. } => Either::Right( multistream_select::listener_select_proto(conn, iter) @@ -358,3 +412,16 @@ impl AsRef<[u8]> for NameWrap { self.0.protocol_name() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn v1_sim_open_is_not_default() { + assert_ne!( + multistream_select::Version::default(), + multistream_select::Version::V1SimOpen, + ); + } +} diff --git a/examples/ipfs-private.rs b/examples/ipfs-private.rs index 101ad372114..6dc5bba6cab 100644 --- a/examples/ipfs-private.rs +++ b/examples/ipfs-private.rs @@ -35,7 +35,7 @@ use async_std::{io, task}; use futures::{future, prelude::*}; use libp2p::{ core::{ - either::EitherTransport, muxing::StreamMuxerBox, transport, transport::upgrade::Version, + either::EitherTransport, muxing::StreamMuxerBox, transport, }, gossipsub::{self, Gossipsub, GossipsubConfigBuilder, GossipsubEvent, MessageAuthenticity}, identify::{Identify, IdentifyConfig, IdentifyEvent}, diff --git a/protocols/gossipsub/tests/smoke.rs b/protocols/gossipsub/tests/smoke.rs index 0dd649135fc..a4b9cc7c8fe 100644 --- a/protocols/gossipsub/tests/smoke.rs +++ b/protocols/gossipsub/tests/smoke.rs @@ -30,7 +30,7 @@ use std::{ use futures::StreamExt; use libp2p_core::{ - identity, multiaddr::Protocol, transport::MemoryTransport, upgrade, Multiaddr, Transport, + identity, multiaddr::Protocol, transport::MemoryTransport, Multiaddr, Transport, }; use libp2p_gossipsub::{ Gossipsub, GossipsubConfigBuilder, GossipsubEvent, IdentTopic as Topic, MessageAuthenticity, @@ -55,7 +55,7 @@ impl Future for Graph { Poll::Ready(Some(_)) => {} Poll::Ready(None) => panic!("unexpected None when polling nodes"), Poll::Pending => break, - } + } } } diff --git a/protocols/identify/src/identify.rs b/protocols/identify/src/identify.rs index 8adfdc05a15..eb6c61798c8 100644 --- a/protocols/identify/src/identify.rs +++ b/protocols/identify/src/identify.rs @@ -451,7 +451,6 @@ mod tests { muxing::StreamMuxerBox, transport, Transport, - upgrade }; use libp2p_noise as noise; use libp2p_tcp::TcpConfig; @@ -517,8 +516,8 @@ mod tests { pin_mut!(swarm2_fut); match future::select(swarm1_fut, swarm2_fut).await.factor_second().0 { - future::Either::Left(SwarmEvent::Behaviour(IdentifyEvent::Received { - info, + future::Either::Left(SwarmEvent::Behaviour(IdentifyEvent::Received { + info, .. })) => { assert_eq!(info.public_key, pubkey2); @@ -528,8 +527,8 @@ mod tests { assert!(info.listen_addrs.is_empty()); return; } - future::Either::Right(SwarmEvent::Behaviour(IdentifyEvent::Received { - info, + future::Either::Right(SwarmEvent::Behaviour(IdentifyEvent::Received { + info, .. 
})) => { assert_eq!(info.public_key, pubkey1); diff --git a/protocols/kad/src/behaviour/test.rs b/protocols/kad/src/behaviour/test.rs index 19458c0602f..b410f90e9f8 100644 --- a/protocols/kad/src/behaviour/test.rs +++ b/protocols/kad/src/behaviour/test.rs @@ -38,7 +38,6 @@ use libp2p_core::{ identity, transport::MemoryTransport, multiaddr::{Protocol, Multiaddr, multiaddr}, - upgrade, multihash::{Code, Multihash, MultihashDigest}, }; use libp2p_noise as noise; diff --git a/protocols/relay/examples/relay.rs b/protocols/relay/examples/relay.rs index 1a9f6333d5a..53670177a27 100644 --- a/protocols/relay/examples/relay.rs +++ b/protocols/relay/examples/relay.rs @@ -53,7 +53,6 @@ use futures::executor::block_on; use futures::stream::StreamExt; -use libp2p::core::upgrade; use libp2p::ping::{Ping, PingConfig, PingEvent}; use libp2p::plaintext; use libp2p::relay::{Relay, RelayConfig}; diff --git a/protocols/relay/tests/lib.rs b/protocols/relay/tests/lib.rs index 1dc057eb7fe..aece1611eb9 100644 --- a/protocols/relay/tests/lib.rs +++ b/protocols/relay/tests/lib.rs @@ -29,7 +29,7 @@ use libp2p_core::either::EitherTransport; use libp2p_core::multiaddr::{Multiaddr, Protocol}; use libp2p_core::transport::{MemoryTransport, Transport, TransportError}; use libp2p_core::upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade}; -use libp2p_core::{identity, upgrade, PeerId}; +use libp2p_core::{identity, PeerId}; use libp2p_identify::{Identify, IdentifyConfig, IdentifyEvent, IdentifyInfo}; use libp2p_kad::{GetClosestPeersOk, Kademlia, KademliaEvent, QueryResult}; use libp2p_ping::{Ping, PingConfig, PingEvent}; diff --git a/protocols/request-response/tests/ping.rs b/protocols/request-response/tests/ping.rs index cc1f7897510..5f5dd2e41ff 100644 --- a/protocols/request-response/tests/ping.rs +++ b/protocols/request-response/tests/ping.rs @@ -27,7 +27,7 @@ use libp2p_core::{ identity, muxing::StreamMuxerBox, transport::{self, Transport}, - upgrade::{self, read_one, write_one} + upgrade::{read_one, write_one} }; use libp2p_noise::{NoiseConfig, X25519Spec, Keypair}; use libp2p_request_response::*; @@ -207,8 +207,8 @@ fn emits_inbound_connection_closed_failure() { loop { match swarm1.select_next_some().await { - SwarmEvent::Behaviour(RequestResponseEvent::InboundFailure { - error: InboundFailure::ConnectionClosed, + SwarmEvent::Behaviour(RequestResponseEvent::InboundFailure { + error: InboundFailure::ConnectionClosed, .. 
}) => break, SwarmEvent::Behaviour(e) => panic!("Peer1: Unexpected event: {:?}", e), diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index 405c2776efd..0fb43c5c681 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -838,7 +838,7 @@ where TBehaviour: NetworkBehaviour, THandler: IntoProtocolsHandler + Send + 'static, TInEvent: Send + 'static, TOutEvent: Send + 'static, - THandler::Handler: + THandler::Handler: ProtocolsHandler, THandleErr: error::Error + Send + 'static, { @@ -1128,7 +1128,6 @@ mod tests { use futures::{future, executor}; use libp2p_core::{ identity, - upgrade, multiaddr, transport }; From 227363f7e584bf8d2e1022f2738c9ccefedd38f0 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 7 Jul 2021 18:48:28 +0200 Subject: [PATCH 07/23] misc/multistream-select: Document V1SimOpen --- misc/multistream-select/src/lib.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/misc/multistream-select/src/lib.rs b/misc/multistream-select/src/lib.rs index 4f073f8bb43..5fa75c6476c 100644 --- a/misc/multistream-select/src/lib.rs +++ b/misc/multistream-select/src/lib.rs @@ -137,6 +137,17 @@ pub enum Version { /// [1]: https://github.com/multiformats/go-multistream/issues/20 /// [2]: https://github.com/libp2p/rust-libp2p/pull/1212 V1Lazy, + /// A variant of version 1 that selects a single initiator when both peers are acting as such, + /// in other words when both peers simultaneously open a connection. + /// + /// This multistream-select variant is specified in [1]. + /// + /// Note: [`V1SimOpen`] should only be used (a) on transports that allow simultaneously opened + /// connections, e.g. TCP with socket reuse and (2) during the first negotiation on the + /// connection, most likely the secure channel protocol negotiation. In all other cases one + /// should use [`V1`] or [`V1Lazy`]. 
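A hedged usage sketch for the note above: the simultaneous-open variant is only selected for the very first negotiation on a connection that both sides may have dialed, that is, for the authentication upgrade. `authenticate_with_version` and `AuthenticationVersion::V1SimOpen` are the builder API introduced earlier in this series; `port_reuse` is assumed to be the TCP option that makes simultaneous dials possible, and the remaining values follow the repository's tests.

let id_keys = identity::Keypair::generate_ed25519();
let noise_keys = noise::Keypair::<noise::X25519Spec>::new()
    .into_authentic(&id_keys)
    .unwrap();

let transport = TcpConfig::new()
    .nodelay(true)
    .port_reuse(true) // assumption: local port reuse is what enables simultaneous dials
    .upgrade()
    .authenticate_with_version(
        noise::NoiseConfig::xx(noise_keys).into_authenticated(),
        AuthenticationVersion::V1SimOpen,
    )
    .multiplex(yamux::YamuxConfig::default())
    .boxed();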
+ /// + /// [1]: https://github.com/libp2p/specs/blob/master/connections/simopen.md V1SimOpen, // Draft: https://github.com/libp2p/specs/pull/95 // V2, From 224141e6afe3930ac1d80a67866026cf32f2adf5 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 7 Jul 2021 18:50:25 +0200 Subject: [PATCH 08/23] *: Rename V1SimOpen to V1SimultaneousOpen --- core/src/upgrade/apply.rs | 10 +++++----- misc/multistream-select/src/dialer_select.rs | 6 +++--- misc/multistream-select/src/lib.rs | 4 ++-- misc/multistream-select/src/protocol.rs | 2 +- misc/multistream-select/src/tests.rs | 8 ++++---- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index 386a78ac36f..77c31e6bf5f 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -47,7 +47,7 @@ impl Default for Version { match multistream_select::Version::default() { multistream_select::Version::V1 => Version::V1, multistream_select::Version::V1Lazy => Version::V1Lazy, - multistream_select::Version::V1SimOpen => unreachable!("see `v1_sim_open_is_not_default`"), + multistream_select::Version::V1SimultaneousOpen => unreachable!("see `v1_sim_open_is_not_default`"), } } } @@ -251,7 +251,7 @@ where pub enum AuthenticationVersion { V1, V1Lazy, - V1SimOpen + V1SimultaneousOpen } impl Default for AuthenticationVersion { @@ -259,7 +259,7 @@ impl Default for AuthenticationVersion { match multistream_select::Version::default() { multistream_select::Version::V1 => AuthenticationVersion::V1, multistream_select::Version::V1Lazy => AuthenticationVersion::V1Lazy, - multistream_select::Version::V1SimOpen => AuthenticationVersion::V1SimOpen, + multistream_select::Version::V1SimultaneousOpen => AuthenticationVersion::V1SimultaneousOpen, } } } @@ -269,7 +269,7 @@ impl From for multistream_select::Version { match v { AuthenticationVersion::V1 => multistream_select::Version::V1, AuthenticationVersion::V1Lazy => multistream_select::Version::V1Lazy, - AuthenticationVersion::V1SimOpen => multistream_select::Version::V1SimOpen, + AuthenticationVersion::V1SimultaneousOpen => multistream_select::Version::V1SimultaneousOpen, } } } @@ -421,7 +421,7 @@ mod tests { fn v1_sim_open_is_not_default() { assert_ne!( multistream_select::Version::default(), - multistream_select::Version::V1SimOpen, + multistream_select::Version::V1SimultaneousOpen, ); } } diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index b2ad163fab2..f010fe5bf19 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -66,7 +66,7 @@ where Either::Right(dialer_select_proto_parallel(inner, iter, version)) } }, - Version::V1SimOpen => { + Version::V1SimultaneousOpen => { Either::Left(dialer_select_proto_serial(inner, iter, version)) } } @@ -198,7 +198,7 @@ where // proposal in one go for efficiency. 
*this.state = SeqState::SendProtocol { io, protocol }; } - Version::V1SimOpen => { + Version::V1SimultaneousOpen => { *this.state = SeqState::SendSimOpen { io, protocol: None }; } } @@ -328,7 +328,7 @@ where *this.state = SeqState::FlushProtocol { io, protocol } } else { match this.version { - Version::V1 | Version::V1SimOpen => *this.state = SeqState::FlushProtocol { io, protocol }, + Version::V1 | Version::V1SimultaneousOpen => *this.state = SeqState::FlushProtocol { io, protocol }, // This is the only effect that `V1Lazy` has compared to `V1`: // Optimistically settling on the only protocol that // the dialer supports for this negotiation. Notably, diff --git a/misc/multistream-select/src/lib.rs b/misc/multistream-select/src/lib.rs index 5fa75c6476c..21f986959ad 100644 --- a/misc/multistream-select/src/lib.rs +++ b/misc/multistream-select/src/lib.rs @@ -142,13 +142,13 @@ pub enum Version { /// /// This multistream-select variant is specified in [1]. /// - /// Note: [`V1SimOpen`] should only be used (a) on transports that allow simultaneously opened + /// Note: [`V1SimultaneousOpen`] should only be used (a) on transports that allow simultaneously opened /// connections, e.g. TCP with socket reuse and (2) during the first negotiation on the /// connection, most likely the secure channel protocol negotiation. In all other cases one /// should use [`V1`] or [`V1Lazy`]. /// /// [1]: https://github.com/libp2p/specs/blob/master/connections/simopen.md - V1SimOpen, + V1SimultaneousOpen, // Draft: https://github.com/libp2p/specs/pull/95 // V2, } diff --git a/misc/multistream-select/src/protocol.rs b/misc/multistream-select/src/protocol.rs index f9a9f4b49c1..23305eff3b6 100644 --- a/misc/multistream-select/src/protocol.rs +++ b/misc/multistream-select/src/protocol.rs @@ -68,7 +68,7 @@ pub enum HeaderLine { impl From for HeaderLine { fn from(v: Version) -> HeaderLine { match v { - Version::V1 | Version::V1Lazy | Version::V1SimOpen => HeaderLine::V1, + Version::V1 | Version::V1Lazy | Version::V1SimultaneousOpen => HeaderLine::V1, } } } diff --git a/misc/multistream-select/src/tests.rs b/misc/multistream-select/src/tests.rs index 7952af53333..5bbbde1be0e 100644 --- a/misc/multistream-select/src/tests.rs +++ b/misc/multistream-select/src/tests.rs @@ -69,7 +69,7 @@ fn select_proto_basic() { async_std::task::block_on(run(Version::V1)); async_std::task::block_on(run(Version::V1Lazy)); - async_std::task::block_on(run(Version::V1SimOpen)); + async_std::task::block_on(run(Version::V1SimultaneousOpen)); } /// Tests the expected behaviour of failed negotiations. 
@@ -163,7 +163,7 @@ fn negotiation_failed() { for (listen_protos, dial_protos) in protos { for dial_payload in payloads.clone() { - for &version in &[Version::V1, Version::V1Lazy, Version::V1SimOpen] { + for &version in &[Version::V1, Version::V1Lazy, Version::V1SimultaneousOpen] { async_std::task::block_on(run(Test { version, listen_protos: listen_protos.clone(), @@ -235,7 +235,7 @@ fn select_proto_serial() { async_std::task::block_on(run(Version::V1)); async_std::task::block_on(run(Version::V1Lazy)); - async_std::task::block_on(run(Version::V1SimOpen)); + async_std::task::block_on(run(Version::V1SimultaneousOpen)); } #[test] @@ -264,5 +264,5 @@ fn simultaneous_open() { futures::future::join(server, client).await; } - futures::executor::block_on(run(Version::V1SimOpen)); + futures::executor::block_on(run(Version::V1SimultaneousOpen)); } From cadb0654cd632b8bdcf6e7c201fdd5cfcf9fb8c9 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 7 Jul 2021 19:03:50 +0200 Subject: [PATCH 09/23] misc/multistream-select: Document SimOpenRole --- misc/multistream-select/src/dialer_select.rs | 29 +++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index f010fe5bf19..69427a7b083 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -32,8 +32,9 @@ use std::{cmp::Ordering, convert::TryFrom as _, iter, mem, pin::Pin, task::{Cont /// /// This function is given an I/O stream and a list of protocols and returns a /// computation that performs the protocol negotiation with the remote. The -/// returned `Future` resolves with the name of the negotiated protocol and -/// a [`Negotiated`] I/O stream. +/// returned `Future` resolves with the name of the negotiated protocol, a +/// [`Negotiated`] I/O stream and the [`Role`] of the peer on the connection +/// going forward. /// /// The chosen message flow for protocol negotiation depends on the numbers of /// supported protocols given. That is, this function delegates to serial or @@ -72,7 +73,7 @@ where } } -/// Future, returned by `dialer_select_proto`, which selects a protocol and dialer +/// Future, returned by `dialer_select_proto`, which selects a protocol and /// either trying protocols in-order, or by requesting all protocols supported /// by the remote upfront, from which the first protocol found in the dialer's /// list of protocols is selected. @@ -414,7 +415,12 @@ enum SimOpenState { Done, } -// TODO: Rename this to `Role` in general? +/// Role of the local node after protocol negotiation. +/// +/// Always equals [`Initiator`] unless [`Version::V1SimultaneousOpen`] is used +/// in which case node may end up in either role after negotiation. +/// +/// See [`Version::V1SimultaneousOpen`] for details. pub enum SimOpenRole { Initiator, Responder, @@ -478,8 +484,18 @@ where }; match msg { - // TODO: Document that this might still be the protocol send by the remote with - // the sim open ID. + // As an optimization, the simultaneous open + // multistream-select variant sends both the + // simultaneous open ID (`/libp2p/simultaneous-connect`) + // and a protocol before flushing. In the case where the + // remote acts as a listener already, it can accept or + // decline the attached protocol within the same + // round-trip. + // + // In this particular situation, the remote acts as a + // dialer and uses the simultaneous open variant. 
Given + // that nonces need to be exchanged first, the attached + // protocol by the remote needs to be ignored. Message::Protocol(_) => { self.state = SimOpenState::ReadNonce { io, local_nonce }; } @@ -590,7 +606,6 @@ where I: Iterator, I::Item: AsRef<[u8]> { - // TODO: Is it a hack that DialerSelectPar returns the simopenrole? type Output = Result<(I::Item, Negotiated, SimOpenRole), NegotiationError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { From bd76946f886feee738121bae2457e69be97a349a Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 7 Jul 2021 19:05:44 +0200 Subject: [PATCH 10/23] *: Rename SimOpenRole to Role --- core/src/transport/upgrade.rs | 26 +++++++------- core/src/upgrade.rs | 2 +- core/src/upgrade/apply.rs | 18 +++++----- misc/multistream-select/src/dialer_select.rs | 38 ++++++++++---------- misc/multistream-select/src/lib.rs | 2 +- 5 files changed, 43 insertions(+), 43 deletions(-) diff --git a/core/src/transport/upgrade.rs b/core/src/transport/upgrade.rs index d92dd409cac..0b215018452 100644 --- a/core/src/transport/upgrade.rs +++ b/core/src/transport/upgrade.rs @@ -34,7 +34,7 @@ use crate::{ muxing::{StreamMuxer, StreamMuxerBox}, upgrade::{ self, - SimOpenRole, + Role, Version, AuthenticationVersion, OutboundUpgrade, @@ -153,9 +153,9 @@ where /// /// * I/O upgrade: `C -> D`. /// * Transport output: `(PeerId, C) -> (PeerId, D)`. - pub fn apply(self, upgrade: U) -> Authenticated UpgradeAuthenticated + Clone>> + pub fn apply(self, upgrade: U) -> Authenticated UpgradeAuthenticated + Clone>> where - T: Transport, + T: Transport, C: AsyncRead + AsyncWrite + Unpin, D: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = D, Error = E>, @@ -167,9 +167,9 @@ where /// Same as [`Authenticated::apply`] with the option to choose the /// [`Version`] used to upgrade the connection. - pub fn apply_with_version(self, upgrade: U, version: Version) -> Authenticated UpgradeAuthenticated + Clone>> + pub fn apply_with_version(self, upgrade: U, version: Version) -> Authenticated UpgradeAuthenticated + Clone>> where - T: Transport, + T: Transport, C: AsyncRead + AsyncWrite + Unpin, D: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = D, Error = E>, @@ -178,10 +178,10 @@ where { Authenticated(Builder::new(self.0.inner.and_then(move |((i, r), c), _endpoint| { let upgrade = match r { - SimOpenRole::Initiator => { + Role::Initiator => { Either::Left(upgrade::apply_outbound(c, upgrade, version)) }, - SimOpenRole::Responder => { + Role::Responder => { Either::Right(upgrade::apply_inbound(c, upgrade)) } @@ -201,9 +201,9 @@ where /// * I/O upgrade: `C -> M`. /// * Transport output: `(PeerId, C) -> (PeerId, M)`. pub fn multiplex(self, upgrade: U) -> Multiplexed< - AndThen UpgradeAuthenticated + Clone> + AndThen UpgradeAuthenticated + Clone> > where - T: Transport, + T: Transport, C: AsyncRead + AsyncWrite + Unpin, M: StreamMuxer, U: InboundUpgrade, Output = M, Error = E>, @@ -216,9 +216,9 @@ where /// Same as [`Authenticated::multiplex`] with the option to choose the /// [`Version`] used to upgrade the connection. 
pub fn multiplex_with_version(self, upgrade: U, version: Version) -> Multiplexed< - AndThen UpgradeAuthenticated + Clone> + AndThen UpgradeAuthenticated + Clone> > where - T: Transport, + T: Transport, C: AsyncRead + AsyncWrite + Unpin, M: StreamMuxer, U: InboundUpgrade, Output = M, Error = E>, @@ -227,10 +227,10 @@ where { Multiplexed(self.0.inner.and_then(move |((i, r), c), _endpoint| { let upgrade = match r { - SimOpenRole::Initiator => { + Role::Initiator => { Either::Left(upgrade::apply_outbound(c, upgrade, version)) }, - SimOpenRole::Responder => { + Role::Responder => { Either::Right(upgrade::apply_inbound(c, upgrade)) } diff --git a/core/src/upgrade.rs b/core/src/upgrade.rs index 4f87e007d72..927b54c71c0 100644 --- a/core/src/upgrade.rs +++ b/core/src/upgrade.rs @@ -70,7 +70,7 @@ mod transfer; use futures::future::Future; pub use crate::Negotiated; -pub use multistream_select::{NegotiatedComplete, NegotiationError, ProtocolError, SimOpenRole}; +pub use multistream_select::{NegotiatedComplete, NegotiationError, ProtocolError, Role}; pub use self::{ // TODO: Break. apply::{apply, apply_authentication, apply_inbound, apply_outbound, InboundUpgradeApply, OutboundUpgradeApply, AuthenticationUpgradeApply, Version, AuthenticationVersion}, diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index 77c31e6bf5f..7d1fa536c19 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -25,7 +25,7 @@ use log::debug; use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture}; use std::{iter, mem, pin::Pin, task::Context, task::Poll}; -pub use multistream_select::{SimOpenRole, NegotiationError}; +pub use multistream_select::{Role, NegotiationError}; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Version { @@ -212,7 +212,7 @@ where loop { match mem::replace(&mut self.inner, OutboundUpgradeApplyState::Undefined) { OutboundUpgradeApplyState::Init { mut future, upgrade } => { - // TODO: Don't ignore the SimOpenRole here. Instead add assert!. + // TODO: Don't ignore the Role here. Instead add assert!. let (info, connection, _) = match Future::poll(Pin::new(&mut future), cx)? 
{ Poll::Ready(x) => x, Poll::Pending => { @@ -286,8 +286,8 @@ where U: OutboundUpgrade, Output = (PeerId, D), Error = >>::Error> + Clone, { - fn add_responder(input: (P, C)) -> (P, C, SimOpenRole) { - (input.0, input.1, SimOpenRole::Responder) + fn add_responder(input: (P, C)) -> (P, C, Role) { + (input.0, input.1, Role::Responder) } let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); @@ -331,13 +331,13 @@ where multistream_select::DialerSelectFuture::IntoIter>>, MapOk< ListenerSelectFuture>, - fn((NameWrap, Negotiated)) -> (NameWrap, Negotiated, SimOpenRole) + fn((NameWrap, Negotiated)) -> (NameWrap, Negotiated, Role) >, >, upgrade: U, }, Upgrade { - role: SimOpenRole, + role: Role, future: Either< Pin>>::Future>>, Pin>>::Future>>, @@ -354,7 +354,7 @@ where U: OutboundUpgrade, Output = (PeerId, D), Error = >>::Error> + Clone, { type Output = Result< - ((PeerId, SimOpenRole), D), + ((PeerId, Role), D), UpgradeError<>>::Error>, >; @@ -370,8 +370,8 @@ where } }; let fut = match role { - SimOpenRole::Initiator => Either::Left(Box::pin(upgrade.upgrade_outbound(io, info.0))), - SimOpenRole::Responder => Either::Right(Box::pin(upgrade.upgrade_inbound(io, info.0))), + Role::Initiator => Either::Left(Box::pin(upgrade.upgrade_outbound(io, info.0))), + Role::Responder => Either::Right(Box::pin(upgrade.upgrade_inbound(io, info.0))), }; self.inner = AuthenticationUpgradeApplyState::Upgrade { future: fut, diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index 69427a7b083..10e6115fb9c 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -170,7 +170,7 @@ where // TODO: Clone needed to embed ListenerSelectFuture. Still needed? 
I::Item: AsRef<[u8]> + Clone { - type Output = Result<(I::Item, Negotiated, SimOpenRole), NegotiationError>; + type Output = Result<(I::Item, Negotiated, Role), NegotiationError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -288,10 +288,10 @@ where }; match selection_res { - SimOpenRole::Initiator => { + Role::Initiator => { *this.state = SeqState::SendProtocol { io, protocol }; } - SimOpenRole::Responder => { + Role::Responder => { let protocols: Vec<_> = this.protocols.collect(); *this.state = SeqState::Responder { responder: crate::listener_select::listener_select_proto_no_header(io, std::iter::once(protocol).chain(protocols.into_iter())), @@ -302,7 +302,7 @@ where SeqState::Responder { mut responder } => { match Pin::new(&mut responder ).poll(cx) { - Poll::Ready(res) => return Poll::Ready(res.map(|(p, io)| (p, io, SimOpenRole::Responder))), + Poll::Ready(res) => return Poll::Ready(res.map(|(p, io)| (p, io, Role::Responder))), Poll::Pending => { *this.state = SeqState::Responder { responder }; return Poll::Pending @@ -338,7 +338,7 @@ where log::debug!("Dialer: Expecting proposed protocol: {}", p); let hl = HeaderLine::from(Version::V1Lazy); let io = Negotiated::expecting(io.into_reader(), p, Some(hl)); - return Poll::Ready(Ok((protocol, io, SimOpenRole::Initiator))) + return Poll::Ready(Ok((protocol, io, Role::Initiator))) } } } @@ -383,7 +383,7 @@ where Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => { log::debug!("Dialer: Received confirmation for protocol: {}", p); let io = Negotiated::completed(io.into_inner()); - return Poll::Ready(Ok((protocol, io, SimOpenRole::Initiator))); + return Poll::Ready(Ok((protocol, io, Role::Initiator))); } Message::NotAvailable => { log::debug!("Dialer: Received rejection of protocol: {}", @@ -409,9 +409,9 @@ enum SimOpenState { SendNonce { io: MessageIO }, FlushNonce { io: MessageIO, local_nonce: u64 }, ReadNonce { io: MessageIO, local_nonce: u64 }, - SendRole { io: MessageIO, local_role: SimOpenRole }, - FlushRole { io: MessageIO, local_role: SimOpenRole }, - ReadRole { io: MessageIO, local_role: SimOpenRole }, + SendRole { io: MessageIO, local_role: Role }, + FlushRole { io: MessageIO, local_role: Role }, + ReadRole { io: MessageIO, local_role: Role }, Done, } @@ -421,7 +421,7 @@ enum SimOpenState { /// in which case node may end up in either role after negotiation. /// /// See [`Version::V1SimultaneousOpen`] for details. -pub enum SimOpenRole { +pub enum Role { Initiator, Responder, } @@ -432,7 +432,7 @@ where // It also makes the implementation considerably easier to write. 
    R: AsyncRead + AsyncWrite + Unpin,
 {
-    type Output = Result<(MessageIO, SimOpenRole), NegotiationError>;
+    type Output = Result<(MessageIO, Role), NegotiationError>;
 
     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
 
@@ -508,13 +508,13 @@ where
                         Ordering::Greater => {
                             self.state = SimOpenState::SendRole {
                                 io,
-                                local_role: SimOpenRole::Initiator,
+                                local_role: Role::Initiator,
                             };
                         },
                         Ordering::Less => {
                             self.state = SimOpenState::SendRole {
                                 io,
-                                local_role: SimOpenRole::Responder,
+                                local_role: Role::Responder,
                             };
                         }
                     }
@@ -532,8 +532,8 @@ where
                     }
 
                     let msg = match local_role {
-                        SimOpenRole::Initiator => Message::Initiator,
-                        SimOpenRole::Responder => Message::Responder,
+                        Role::Initiator => Message::Initiator,
+                        Role::Responder => Message::Responder,
                     };
 
                     if let Err(err) = Pin::new(&mut io).start_send(msg) {
@@ -565,8 +565,8 @@ where
                     };
 
                     let result = match local_role {
-                        SimOpenRole::Initiator if remote_msg == Message::Responder => Ok((io, local_role)),
-                        SimOpenRole::Responder if remote_msg == Message::Initiator => Ok((io, local_role)),
+                        Role::Initiator if remote_msg == Message::Responder => Ok((io, local_role)),
+                        Role::Responder if remote_msg == Message::Initiator => Ok((io, local_role)),
                         _ => Err(ProtocolError::InvalidMessage.into())
                     };
 
@@ -606,7 +606,7 @@ where
     I: Iterator,
     I::Item: AsRef<[u8]>
 {
-    type Output = Result<(I::Item, Negotiated, SimOpenRole), NegotiationError>;
+    type Output = Result<(I::Item, Negotiated, Role), NegotiationError>;
 
     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
         let this = self.project();
@@ -704,7 +704,7 @@ where
                     log::debug!("Dialer: Expecting proposed protocol: {}", p);
                     let io = Negotiated::expecting(io.into_reader(), p, None);
 
-                    return Poll::Ready(Ok((protocol, io, SimOpenRole::Initiator)))
+                    return Poll::Ready(Ok((protocol, io, Role::Initiator)))
                 }
 
                 ParState::Done => panic!("ParState::poll called after completion")
diff --git a/misc/multistream-select/src/lib.rs b/misc/multistream-select/src/lib.rs
index 21f986959ad..75942ea3b1a 100644
--- a/misc/multistream-select/src/lib.rs
+++ b/misc/multistream-select/src/lib.rs
@@ -96,7 +96,7 @@ mod tests;
 
 pub use self::negotiated::{Negotiated, NegotiatedComplete, NegotiationError};
 pub use self::protocol::ProtocolError;
-pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture, SimOpenRole};
+pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture, Role};
 pub use self::listener_select::{listener_select_proto, ListenerSelectFuture};
 
 /// Supported multistream-select versions.

From ade2aff3a761e553c26c4add641e2a50471cb726 Mon Sep 17 00:00:00 2001
From: Max Inden
Date: Wed, 7 Jul 2021 19:09:17 +0200
Subject: [PATCH 11/23] misc/multistream-select: Document responder role process

---
 misc/multistream-select/src/dialer_select.rs   | 1 -
 misc/multistream-select/src/listener_select.rs | 2 ++
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs
index 10e6115fb9c..e0747b1fbbe 100644
--- a/misc/multistream-select/src/dialer_select.rs
+++ b/misc/multistream-select/src/dialer_select.rs
@@ -167,7 +167,6 @@ where
     // It also makes the implementation considerably easier to write.
     R: AsyncRead + AsyncWrite + Unpin,
     I: Iterator,
-    // TODO: Clone needed to embed ListenerSelectFuture. Still needed?
    I::Item: AsRef<[u8]> + Clone
 {
     type Output = Result<(I::Item, Negotiated, Role), NegotiationError>;
diff --git a/misc/multistream-select/src/listener_select.rs b/misc/multistream-select/src/listener_select.rs
index 01347b04bef..9bf7ef6726d 100644
--- a/misc/multistream-select/src/listener_select.rs
+++ b/misc/multistream-select/src/listener_select.rs
@@ -49,6 +49,8 @@ where
     }, protocols)
 }
 
+/// Used when selected as a [`crate::Role::Responder`] during [`Version::V1SimultaneousOpen`]
+/// negotiation.
 pub(crate) fn listener_select_proto_no_header(
     io: MessageIO,
     protocols: I,

From b66b4ebc707947b2386aa24c25d8d7e16bb591c0 Mon Sep 17 00:00:00 2001
From: Max Inden
Date: Wed, 7 Jul 2021 19:13:50 +0200
Subject: [PATCH 12/23] misc/multistream-select: Bump version and add changelog entry

---
 core/Cargo.toml                      |  2 +-
 misc/multistream-select/CHANGELOG.md | 14 ++++++++++++++
 misc/multistream-select/Cargo.toml   |  2 +-
 3 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/core/Cargo.toml b/core/Cargo.toml
index dbc7584477e..ccd71fa1485 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -22,7 +22,7 @@ libsecp256k1 = { version = "0.5.0", optional = true }
 log = "0.4"
 multiaddr = { version = "0.12.0" }
 multihash = { version = "0.13", default-features = false, features = ["std", "multihash-impl", "identity", "sha2"] }
-multistream-select = { version = "0.10", path = "../misc/multistream-select" }
+multistream-select = { version = "0.11", path = "../misc/multistream-select" }
 parking_lot = "0.11.0"
 pin-project = "1.0.0"
 prost = "0.7"
diff --git a/misc/multistream-select/CHANGELOG.md b/misc/multistream-select/CHANGELOG.md
index f7cf58417c9..ebf509f7c82 100644
--- a/misc/multistream-select/CHANGELOG.md
+++ b/misc/multistream-select/CHANGELOG.md
@@ -1,3 +1,17 @@
+# 0.11.0 [unreleased]
+
+- Add support for [simultaneous open extension] via `Version::V1SimultaneousOpen`.
+
+  The [`Role`] enum returned by the `dialer_select_proto` `Future` can be ignored unless
+  `Version::V1SimultaneousOpen` is used.
+
+  This is one important component of the greater effort to support hole punching in rust-libp2p.
+
+  See [PR 2066].
+
+[simultaneous open extension]: https://github.com/libp2p/specs/blob/master/connections/simopen.md
+[PR 2066]: https://github.com/libp2p/rust-libp2p/pull/2066
+
 # 0.10.3 [2021-03-17]
 
 - Update dependencies.
diff --git a/misc/multistream-select/Cargo.toml b/misc/multistream-select/Cargo.toml
index ef689d5e7ee..e66741af07b 100644
--- a/misc/multistream-select/Cargo.toml
+++ b/misc/multistream-select/Cargo.toml
@@ -1,7 +1,7 @@
 [package]
 name = "multistream-select"
 description = "Multistream-select negotiation protocol for libp2p"
-version = "0.10.3"
+version = "0.11.0"
 authors = ["Parity Technologies "]
 license = "MIT"
 repository = "https://github.com/libp2p/rust-libp2p"

From 16da533c3b8e67b6dd682f0c4ccc5408f37e2319 Mon Sep 17 00:00:00 2001
From: Max Inden
Date: Wed, 7 Jul 2021 19:37:51 +0200
Subject: [PATCH 13/23] core/CHANGELOG: Add entry

---
 core/CHANGELOG.md             | 36 +++++++++++++++++++++++++++++++++++
 core/src/transport/upgrade.rs |  2 --
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md
index 46177efeeea..4556928881e 100644
--- a/core/CHANGELOG.md
+++ b/core/CHANGELOG.md
@@ -12,7 +12,43 @@
   Introduce `upgrade::read_length_prefixed` and `upgrade::write_length_prefixed`.
   See [PR 2111](https://github.com/libp2p/rust-libp2p/pull/2111).
+- Add support for the multistream-select [simultaneous open extension] to assign the _initiator_
+  and _responder_ roles during authentication protocol negotiation on simultaneously opened connections.
+
+  This is one important component of the greater effort to support hole punching in rust-libp2p.
+
+  - `Transport::upgrade` no longer takes a multistream-select `Version`. Instead the
+    multistream-select `Version`s `V1`, `V1Lazy` and `V1SimultaneousOpen` can be selected when
+    setting the authentication upgrade via `Builder::authenticate_with_version` and the
+    multistream-select `Version`s `V1` and `V1Lazy` can be selected when setting the multiplexing
+    upgrade via `Builder::multiplex_with_version`.
+
+    Users merely wanting to maintain the status quo should use the following call chain depending
+    on which `Version` they previously used:
+
+    - `Version::V1`
+
+      ```rust
+      my_transport.upgrade()
+          .authenticate(my_authentication)
+          .multiplex(my_multiplexer)
+      ```
+    - `Version::V1Lazy`
+
+      ```rust
+      my_transport.upgrade()
+          .authenticate_with_version(my_authentication, Version::V1Lazy)
+          .multiplex_with_version(my_multiplexer, Version::V1Lazy)
+      ```
+
+  - `Builder::multiplex_ext` is removed in favor of the new simultaneous open workflow. Please reach
+    out in case you depend on `Builder::multiplex_ext`.
+
+  See [PR 2066].
+
 [PR 2090]: https://github.com/libp2p/rust-libp2p/pull/2090
+[simultaneous open extension]: https://github.com/libp2p/specs/blob/master/connections/simopen.md
+[PR 2066]: https://github.com/libp2p/rust-libp2p/pull/2066
 
 # 0.28.3 [2021-04-26]
 
diff --git a/core/src/transport/upgrade.rs b/core/src/transport/upgrade.rs
index 0b215018452..50bd15f53e9 100644
--- a/core/src/transport/upgrade.rs
+++ b/core/src/transport/upgrade.rs
@@ -238,8 +238,6 @@ where
             UpgradeAuthenticated { user_data: Some(i), upgrade }
         }))
     }
-
-    // TODO: Add changelog entry that multiplex_ext is removed.
 }
 
 /// An upgrade that negotiates a (sub)stream multiplexer on

From 4a9fef37740305ebc98cc2930ee4db44b57cd4fc Mon Sep 17 00:00:00 2001
From: Max Inden
Date: Wed, 7 Jul 2021 19:46:53 +0200
Subject: [PATCH 14/23] core/src/upgrade: Assert Initiator when not using SimOpen

---
 core/src/transport.rs     |  2 --
 core/src/upgrade.rs       |  6 ++++--
 core/src/upgrade/apply.rs | 17 ++++++++++++-----
 3 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/core/src/transport.rs b/core/src/transport.rs
index 5711c06044e..0d33b88a494 100644
--- a/core/src/transport.rs
+++ b/core/src/transport.rs
@@ -195,8 +195,6 @@ pub trait Transport {
 
     /// Begins a series of protocol upgrades via an
     /// [`upgrade::Builder`](upgrade::Builder).
-    //
-    // TODO: Method still needed now that `upgrade` takes `self` only?
     fn upgrade(self) -> upgrade::Builder
     where
         Self: Sized,
diff --git a/core/src/upgrade.rs b/core/src/upgrade.rs
index 927b54c71c0..bb637c63db6 100644
--- a/core/src/upgrade.rs
+++ b/core/src/upgrade.rs
@@ -72,8 +72,10 @@ use futures::future::Future;
 pub use crate::Negotiated;
 pub use multistream_select::{NegotiatedComplete, NegotiationError, ProtocolError, Role};
 pub use self::{
-    // TODO: Break.
- apply::{apply, apply_authentication, apply_inbound, apply_outbound, InboundUpgradeApply, OutboundUpgradeApply, AuthenticationUpgradeApply, Version, AuthenticationVersion}, + apply::{ + apply, apply_authentication, apply_inbound, apply_outbound, InboundUpgradeApply, + OutboundUpgradeApply, AuthenticationUpgradeApply, Version, AuthenticationVersion, + }, denied::DeniedUpgrade, either::EitherUpgrade, error::UpgradeError, diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index 7d1fa536c19..3698eeed25f 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -53,8 +53,9 @@ impl Default for Version { } /// Applies an upgrade to the inbound and outbound direction of a connection or substream. -// -// TODO: Link to apply_authentication for authentication protocols. +/// +/// Note: Use [`apply_authentication`] when negotiating an authentication protocol on top of a +/// transport allowing simultaneously opened connections. pub fn apply(conn: C, up: U, cp: ConnectedPoint, v: Version) -> Either, OutboundUpgradeApply> where @@ -212,14 +213,18 @@ where loop { match mem::replace(&mut self.inner, OutboundUpgradeApplyState::Undefined) { OutboundUpgradeApplyState::Init { mut future, upgrade } => { - // TODO: Don't ignore the Role here. Instead add assert!. - let (info, connection, _) = match Future::poll(Pin::new(&mut future), cx)? { + let (info, connection, role) = match Future::poll(Pin::new(&mut future), cx)? { Poll::Ready(x) => x, Poll::Pending => { self.inner = OutboundUpgradeApplyState::Init { future, upgrade }; return Poll::Pending } }; + assert_eq!( + role, Role::Initiator, + "Expect negotiation not using `Version::V1SimultaneousOpen` to either return \ + as `Initiator` or fail.", + ); self.inner = OutboundUpgradeApplyState::Upgrade { future: Box::pin(upgrade.upgrade_outbound(connection, info.0)) }; @@ -276,7 +281,9 @@ impl From for multistream_select::Version { /// Applies an authentication upgrade to the inbound or outbound direction of a connection. /// -/// Note: This is like [`apply`] with additional support for +/// Note: This is like [`apply`] with additional support for transports allowing simultaneously +/// opened connections. Unless run on such transport and used to negotiate the authentication +/// protocol you likely want to use [`apply`] instead of [`apply_authentication`]. pub fn apply_authentication(conn: C, up: U, cp: ConnectedPoint, v: AuthenticationVersion) -> AuthenticationUpgradeApply where From 41c884cd50fadbbf8624ab2db8b9ffb8a705227a Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 7 Jul 2021 19:49:57 +0200 Subject: [PATCH 15/23] core/upgrade/apply: Document different versions --- core/src/upgrade/apply.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index 3698eeed25f..97dddc07663 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -27,9 +27,14 @@ use std::{iter, mem, pin::Pin, task::Context, task::Poll}; pub use multistream_select::{Role, NegotiationError}; +/// Wrapper around multistream-select `Version`. +/// +/// See [`multistream_select::Version`] for details. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Version { + /// See [`multistream_select::Version::V1`]. V1, + /// See [`multistream_select::Version::V1Lazy`]. V1Lazy, } @@ -252,10 +257,16 @@ where } } +/// Wrapper around multistream-select `Version`. +/// +/// See [`multistream_select::Version`] for details. 
#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum AuthenticationVersion { + /// See [`multistream_select::Version::V1`]. V1, + /// See [`multistream_select::Version::V1Lazy`]. V1Lazy, + /// See [`multistream_select::Version::V1SimultaneousOpen`]. V1SimultaneousOpen } From cf2e3ef66c68c2955696826c517e7e36e3493ab1 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 7 Jul 2021 19:54:37 +0200 Subject: [PATCH 16/23] misc/multistream-select: Derive Eq for Role --- misc/multistream-select/src/dialer_select.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index e0747b1fbbe..0ce178d6225 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -420,6 +420,7 @@ enum SimOpenState { /// in which case node may end up in either role after negotiation. /// /// See [`Version::V1SimultaneousOpen`] for details. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Role { Initiator, Responder, From a4264fcff3a12fb2cd71ea45e4f21b77817587d5 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 7 Jul 2021 19:59:32 +0200 Subject: [PATCH 17/23] *: Fix documentation links --- misc/multistream-select/src/dialer_select.rs | 4 ++-- misc/multistream-select/src/lib.rs | 8 ++++---- transports/noise/src/lib.rs | 5 +++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index 0ce178d6225..72b4126751c 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -416,8 +416,8 @@ enum SimOpenState { /// Role of the local node after protocol negotiation. /// -/// Always equals [`Initiator`] unless [`Version::V1SimultaneousOpen`] is used -/// in which case node may end up in either role after negotiation. +/// Always equals [`Role::Initiator`] unless [`Version::V1SimultaneousOpen`] is +/// used in which case node may end up in either role after negotiation. /// /// See [`Version::V1SimultaneousOpen`] for details. #[derive(Clone, Copy, Debug, PartialEq, Eq)] diff --git a/misc/multistream-select/src/lib.rs b/misc/multistream-select/src/lib.rs index 75942ea3b1a..8fa88baf4ac 100644 --- a/misc/multistream-select/src/lib.rs +++ b/misc/multistream-select/src/lib.rs @@ -142,10 +142,10 @@ pub enum Version { /// /// This multistream-select variant is specified in [1]. /// - /// Note: [`V1SimultaneousOpen`] should only be used (a) on transports that allow simultaneously opened - /// connections, e.g. TCP with socket reuse and (2) during the first negotiation on the - /// connection, most likely the secure channel protocol negotiation. In all other cases one - /// should use [`V1`] or [`V1Lazy`]. + /// Note: [`Version::V1SimultaneousOpen`] should only be used (a) on transports that allow + /// simultaneously opened connections, e.g. TCP with socket reuse and (2) during the first + /// negotiation on the connection, most likely the secure channel protocol negotiation. In all + /// other cases one should use [`Version::V1`] or [`Version::V1Lazy`]. /// /// [1]: https://github.com/libp2p/specs/blob/master/connections/simopen.md V1SimultaneousOpen, diff --git a/transports/noise/src/lib.rs b/transports/noise/src/lib.rs index 5df42564149..b236e7258ab 100644 --- a/transports/noise/src/lib.rs +++ b/transports/noise/src/lib.rs @@ -317,8 +317,9 @@ where /// See [`NoiseConfig::into_authenticated`]. 
/// /// On success, the upgrade yields the [`PeerId`] obtained from the -/// `RemoteIdentity`. The output of this upgrade is thus directly suitable -/// for creating an [`authenticated`](libp2p_core::transport::upgrade::Authenticate) +/// `RemoteIdentity`. The output of this upgrade is thus directly suitable for +/// creating an +/// [`authenticated`](libp2p_core::transport::upgrade::Builder::authenticate) /// transport for use with a [`Network`](libp2p_core::Network). #[derive(Clone)] pub struct NoiseAuthenticated { From 94793bc98425a3574ab78120cf67fe7ad69cb116 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 7 Jul 2021 20:06:24 +0200 Subject: [PATCH 18/23] misc/multistream-select: Fix doc link --- misc/multistream-select/src/listener_select.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/misc/multistream-select/src/listener_select.rs b/misc/multistream-select/src/listener_select.rs index 9bf7ef6726d..28139235ed4 100644 --- a/misc/multistream-select/src/listener_select.rs +++ b/misc/multistream-select/src/listener_select.rs @@ -49,8 +49,8 @@ where }, protocols) } -/// Used when selected as a [`crate::Role::Responder`] during [`Version::V1SimultaneousOpen`] -/// negotiation. +/// Used when selected as a [`crate::Role::Responder`] during [`crate::dialer_select_proto`] +/// negotiation with [`crate::Version::V1SimultaneousOpen`] pub(crate) fn listener_select_proto_no_header( io: MessageIO, protocols: I, From deb5e10277104f746ac5aee7b92fb686cbe5703b Mon Sep 17 00:00:00 2001 From: Max Inden Date: Wed, 7 Jul 2021 20:49:07 +0200 Subject: [PATCH 19/23] src/lib: Call upgrade without Version --- examples/chat-tokio.rs | 1 - src/lib.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/chat-tokio.rs b/examples/chat-tokio.rs index 053d990eb98..d030fb0ef07 100644 --- a/examples/chat-tokio.rs +++ b/examples/chat-tokio.rs @@ -42,7 +42,6 @@ use libp2p::{ NetworkBehaviour, PeerId, Transport, - core::upgrade, identity, floodsub::{self, Floodsub, FloodsubEvent}, mdns::{Mdns, MdnsEvent}, diff --git a/src/lib.rs b/src/lib.rs index 0795d8f1a80..21e1334a28d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -210,7 +210,7 @@ pub fn tokio_development_transport(keypair: identity::Keypair) .expect("Signing libp2p-noise static DH keypair failed."); Ok(transport - .upgrade(core::upgrade::Version::V1) + .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(core::upgrade::SelectUpgrade::new(yamux::YamuxConfig::default(), mplex::MplexConfig::default())) .timeout(std::time::Duration::from_secs(20)) From d49cedd1ec2abb5f3a4b472e8722b78d54f2c4fd Mon Sep 17 00:00:00 2001 From: Max Inden Date: Fri, 13 Aug 2021 21:58:02 +0200 Subject: [PATCH 20/23] *: Format with rustfmt --- core/benches/peer_id.rs | 10 +- core/build.rs | 2 +- core/src/connection.rs | 40 +- core/src/connection/error.rs | 40 +- core/src/connection/handler.rs | 35 +- core/src/connection/listeners.rs | 155 ++- core/src/connection/manager.rs | 147 ++- core/src/connection/manager/task.rs | 148 ++- core/src/connection/pool.rs | 502 +++---- core/src/connection/substream.rs | 55 +- core/src/either.rs | 227 ++-- core/src/identity.rs | 68 +- core/src/identity/ed25519.rs | 47 +- core/src/identity/error.rs | 25 +- core/src/identity/rsa.rs | 46 +- core/src/identity/secp256k1.rs | 25 +- core/src/lib.rs | 10 +- core/src/muxing.rs | 216 ++- core/src/muxing/singleton.rs | 56 +- core/src/network.rs | 357 ++--- core/src/network/event.rs | 196 +-- core/src/network/peer.rs | 296 ++--- 
core/src/peer_id.rs | 17 +- core/src/transport.rs | 62 +- core/src/transport/and_then.rs | 58 +- core/src/transport/boxed.rs | 19 +- core/src/transport/choice.rs | 18 +- core/src/transport/dummy.rs | 30 +- core/src/transport/map.rs | 45 +- core/src/transport/map_err.rs | 20 +- core/src/transport/memory.rs | 152 ++- core/src/transport/timeout.rs | 38 +- core/src/transport/upgrade.rs | 157 ++- core/src/upgrade.rs | 24 +- core/src/upgrade/apply.rs | 186 ++- core/src/upgrade/either.rs | 47 +- core/src/upgrade/error.rs | 9 +- core/src/upgrade/from_fn.rs | 5 +- core/src/upgrade/map.rs | 47 +- core/src/upgrade/optional.rs | 6 +- core/src/upgrade/select.rs | 24 +- core/src/upgrade/transfer.rs | 65 +- core/tests/connection_limits.rs | 82 +- core/tests/network_dial_error.rs | 95 +- core/tests/transport_upgrade.rs | 16 +- core/tests/util.rs | 50 +- examples/chat-tokio.rs | 29 +- examples/chat.rs | 33 +- examples/distributed-key-value-store.rs | 56 +- examples/gossipsub-chat.rs | 2 +- examples/ipfs-kad.rs | 27 +- examples/ipfs-private.rs | 4 +- examples/mdns-passive-discovery.rs | 6 +- examples/ping.rs | 2 +- misc/multistream-select/src/dialer_select.rs | 321 +++-- .../src/length_delimited.rs | 104 +- misc/multistream-select/src/lib.rs | 6 +- .../multistream-select/src/listener_select.rs | 162 ++- misc/multistream-select/src/negotiated.rs | 140 +- misc/multistream-select/src/protocol.rs | 112 +- misc/multistream-select/src/tests.rs | 54 +- misc/multistream-select/tests/transport.rs | 94 +- misc/peer-id-generator/src/main.rs | 22 +- muxers/mplex/benches/split_send_size.rs | 53 +- muxers/mplex/src/codec.rs | 213 ++- muxers/mplex/src/io.rs | 437 +++--- muxers/mplex/src/lib.rs | 76 +- muxers/mplex/tests/async_write.rs | 34 +- muxers/mplex/tests/two_peers.rs | 74 +- muxers/yamux/src/lib.rs | 134 +- protocols/floodsub/build.rs | 3 +- protocols/floodsub/src/layer.rs | 242 ++-- protocols/floodsub/src/lib.rs | 2 +- protocols/floodsub/src/protocol.rs | 71 +- protocols/gossipsub/src/behaviour.rs | 10 +- protocols/gossipsub/src/protocol.rs | 2 +- protocols/gossipsub/tests/smoke.rs | 10 +- protocols/identify/build.rs | 3 +- protocols/identify/src/handler.rs | 103 +- protocols/identify/src/identify.rs | 171 +-- protocols/identify/src/lib.rs | 1 - protocols/identify/src/protocol.rs | 85 +- protocols/kad/build.rs | 3 +- protocols/kad/src/addresses.rs | 10 +- protocols/kad/src/behaviour.rs | 1127 +++++++++------- protocols/kad/src/behaviour/test.rs | 1173 +++++++++-------- protocols/kad/src/handler.rs | 401 +++--- protocols/kad/src/jobs.rs | 48 +- protocols/kad/src/kbucket.rs | 162 ++- protocols/kad/src/kbucket/bucket.rs | 204 +-- protocols/kad/src/kbucket/entry.rs | 62 +- protocols/kad/src/kbucket/key.rs | 44 +- protocols/kad/src/lib.rs | 55 +- protocols/kad/src/protocol.rs | 175 +-- protocols/kad/src/query.rs | 93 +- protocols/kad/src/query/peers.rs | 2 +- protocols/kad/src/query/peers/closest.rs | 219 +-- .../kad/src/query/peers/closest/disjoint.rs | 304 +++-- protocols/kad/src/query/peers/fixed.rs | 46 +- protocols/kad/src/record.rs | 16 +- protocols/kad/src/record/store.rs | 3 +- protocols/kad/src/record/store/memory.rs | 47 +- protocols/mdns/src/behaviour.rs | 6 +- protocols/ping/src/handler.rs | 92 +- protocols/ping/src/lib.rs | 14 +- protocols/ping/src/protocol.rs | 15 +- protocols/ping/tests/ping.rs | 106 +- protocols/relay/build.rs | 2 +- protocols/relay/examples/relay.rs | 2 +- protocols/relay/src/behaviour.rs | 2 +- .../relay/src/protocol/incoming_dst_req.rs | 16 +- 
.../relay/src/protocol/incoming_relay_req.rs | 5 +- .../relay/src/protocol/outgoing_dst_req.rs | 13 +- .../relay/src/protocol/outgoing_relay_req.rs | 14 +- protocols/relay/src/transport.rs | 2 +- protocols/relay/tests/lib.rs | 78 +- protocols/request-response/src/codec.rs | 31 +- protocols/request-response/src/handler.rs | 164 +-- .../request-response/src/handler/protocol.rs | 30 +- protocols/request-response/src/lib.rs | 342 +++-- protocols/request-response/src/throttled.rs | 436 +++--- .../request-response/src/throttled/codec.rs | 123 +- protocols/request-response/tests/ping.rs | 163 ++- src/bandwidth.rs | 82 +- src/lib.rs | 132 +- src/transport_ext.rs | 2 +- src/tutorial.rs | 4 +- swarm-derive/src/lib.rs | 325 +++-- swarm-derive/tests/test.rs | 68 +- swarm/src/behaviour.rs | 134 +- swarm/src/lib.rs | 692 +++++----- swarm/src/protocols_handler.rs | 73 +- swarm/src/protocols_handler/dummy.rs | 44 +- swarm/src/protocols_handler/map_in.rs | 31 +- swarm/src/protocols_handler/map_out.rs | 48 +- swarm/src/protocols_handler/multi.rs | 267 ++-- swarm/src/protocols_handler/node_handler.rs | 151 ++- swarm/src/protocols_handler/one_shot.rs | 64 +- swarm/src/protocols_handler/select.rs | 266 ++-- swarm/src/registry.rs | 123 +- swarm/src/test.rs | 56 +- swarm/src/toggle.rs | 137 +- transports/deflate/src/lib.rs | 73 +- transports/deflate/tests/test.rs | 28 +- transports/dns/src/lib.rs | 260 ++-- transports/noise/build.rs | 2 +- transports/noise/src/error.rs | 1 - transports/noise/src/io.rs | 35 +- transports/noise/src/io/framed.rs | 114 +- transports/noise/src/io/handshake.rs | 89 +- transports/noise/src/lib.rs | 148 ++- transports/noise/src/protocol.rs | 37 +- transports/noise/src/protocol/x25519.rs | 60 +- transports/noise/src/protocol/x25519_spec.rs | 44 +- transports/noise/tests/smoke.rs | 119 +- transports/plaintext/build.rs | 3 +- transports/plaintext/src/error.rs | 16 +- transports/plaintext/src/handshake.rs | 59 +- transports/plaintext/src/lib.rs | 52 +- transports/plaintext/tests/smoke.rs | 61 +- transports/pnet/src/crypt_writer.rs | 4 +- transports/pnet/src/lib.rs | 5 +- transports/tcp/src/lib.rs | 151 ++- transports/tcp/src/provider.rs | 9 +- transports/tcp/src/provider/async_io.rs | 24 +- transports/tcp/src/provider/tokio.rs | 94 +- transports/uds/src/lib.rs | 43 +- transports/wasm-ext/src/lib.rs | 66 +- transports/websocket/src/error.rs | 8 +- transports/websocket/src/framed.rs | 284 ++-- transports/websocket/src/lib.rs | 66 +- transports/websocket/src/tls.rs | 34 +- 172 files changed, 10171 insertions(+), 7240 deletions(-) diff --git a/core/benches/peer_id.rs b/core/benches/peer_id.rs index 5dfb0d7c132..9a6935113ec 100644 --- a/core/benches/peer_id.rs +++ b/core/benches/peer_id.rs @@ -35,9 +35,7 @@ fn from_bytes(c: &mut Criterion) { } fn clone(c: &mut Criterion) { - let peer_id = identity::Keypair::generate_ed25519() - .public() - .to_peer_id(); + let peer_id = identity::Keypair::generate_ed25519().public().to_peer_id(); c.bench_function("clone", |b| { b.iter(|| { @@ -48,11 +46,7 @@ fn clone(c: &mut Criterion) { fn sort_vec(c: &mut Criterion) { let peer_ids: Vec<_> = (0..100) - .map(|_| { - identity::Keypair::generate_ed25519() - .public() - .to_peer_id() - }) + .map(|_| identity::Keypair::generate_ed25519().public().to_peer_id()) .collect(); c.bench_function("sort_vec", |b| { diff --git a/core/build.rs b/core/build.rs index c08517dee58..9692abd9c81 100644 --- a/core/build.rs +++ b/core/build.rs @@ -19,5 +19,5 @@ // DEALINGS IN THE SOFTWARE. 
fn main() { - prost_build::compile_protos(&["src/keys.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/keys.proto"], &["src"]).unwrap(); } diff --git a/core/src/connection.rs b/core/src/connection.rs index 50b44b86ccd..335e2046c2d 100644 --- a/core/src/connection.rs +++ b/core/src/connection.rs @@ -28,16 +28,16 @@ pub(crate) mod pool; pub use error::{ConnectionError, PendingConnectionError}; pub use handler::{ConnectionHandler, ConnectionHandlerEvent, IntoConnectionHandler}; -pub use listeners::{ListenerId, ListenersStream, ListenersEvent}; +pub use listeners::{ListenerId, ListenersEvent, ListenersStream}; pub use manager::ConnectionId; -pub use substream::{Substream, SubstreamEndpoint, Close}; +pub use pool::{ConnectionCounters, ConnectionLimits}; pub use pool::{EstablishedConnection, EstablishedConnectionIter, PendingConnection}; -pub use pool::{ConnectionLimits, ConnectionCounters}; +pub use substream::{Close, Substream, SubstreamEndpoint}; use crate::muxing::StreamMuxer; use crate::{Multiaddr, PeerId}; -use std::{error::Error, fmt, pin::Pin, task::Context, task::Poll}; use std::hash::Hash; +use std::{error::Error, fmt, pin::Pin, task::Context, task::Poll}; use substream::{Muxing, SubstreamEvent}; /// The endpoint roles associated with a peer-to-peer communication channel. @@ -55,7 +55,7 @@ impl std::ops::Not for Endpoint { fn not(self) -> Self::Output { match self { Endpoint::Dialer => Endpoint::Listener, - Endpoint::Listener => Endpoint::Dialer + Endpoint::Listener => Endpoint::Dialer, } } } @@ -86,7 +86,7 @@ pub enum ConnectedPoint { local_addr: Multiaddr, /// Stack of protocols used to send back data to the remote. send_back_addr: Multiaddr, - } + }, } impl From<&'_ ConnectedPoint> for Endpoint { @@ -106,7 +106,7 @@ impl ConnectedPoint { pub fn to_endpoint(&self) -> Endpoint { match self { ConnectedPoint::Dialer { .. } => Endpoint::Dialer, - ConnectedPoint::Listener { .. } => Endpoint::Listener + ConnectedPoint::Listener { .. } => Endpoint::Listener, } } @@ -114,7 +114,7 @@ impl ConnectedPoint { pub fn is_dialer(&self) -> bool { match self { ConnectedPoint::Dialer { .. } => true, - ConnectedPoint::Listener { .. } => false + ConnectedPoint::Listener { .. } => false, } } @@ -122,7 +122,7 @@ impl ConnectedPoint { pub fn is_listener(&self) -> bool { match self { ConnectedPoint::Dialer { .. } => false, - ConnectedPoint::Listener { .. } => true + ConnectedPoint::Listener { .. } => true, } } @@ -237,9 +237,10 @@ where /// Polls the connection for events produced by the associated handler /// as a result of I/O activity on the substream multiplexer. - pub fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) - -> Poll, ConnectionError>> - { + pub fn poll( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, ConnectionError>> { loop { let mut io_pending = false; @@ -247,10 +248,13 @@ where // of new substreams. 
match self.muxing.poll(cx) { Poll::Pending => io_pending = true, - Poll::Ready(Ok(SubstreamEvent::InboundSubstream { substream })) => { - self.handler.inject_substream(substream, SubstreamEndpoint::Listener) - } - Poll::Ready(Ok(SubstreamEvent::OutboundSubstream { user_data, substream })) => { + Poll::Ready(Ok(SubstreamEvent::InboundSubstream { substream })) => self + .handler + .inject_substream(substream, SubstreamEndpoint::Listener), + Poll::Ready(Ok(SubstreamEvent::OutboundSubstream { + user_data, + substream, + })) => { let endpoint = SubstreamEndpoint::Dialer(user_data); self.handler.inject_substream(substream, endpoint) } @@ -265,7 +269,7 @@ where match self.handler.poll(cx) { Poll::Pending => { if io_pending { - return Poll::Pending // Nothing to do + return Poll::Pending; // Nothing to do } } Poll::Ready(Ok(ConnectionHandlerEvent::OutboundSubstreamRequest(user_data))) => { @@ -310,7 +314,7 @@ impl<'a> OutgoingInfo<'a> { /// Builds a `ConnectedPoint` corresponding to the outgoing connection. pub fn to_connected_point(&self) -> ConnectedPoint { ConnectedPoint::Dialer { - address: self.address.clone() + address: self.address.clone(), } } } diff --git a/core/src/connection/error.rs b/core/src/connection/error.rs index 1836965e43e..66da0670c98 100644 --- a/core/src/connection/error.rs +++ b/core/src/connection/error.rs @@ -20,7 +20,7 @@ use crate::connection::ConnectionLimit; use crate::transport::TransportError; -use std::{io, fmt}; +use std::{fmt, io}; /// Errors that can occur in the context of an established `Connection`. #[derive(Debug)] @@ -33,23 +33,19 @@ pub enum ConnectionError { Handler(THandlerErr), } -impl fmt::Display -for ConnectionError +impl fmt::Display for ConnectionError where THandlerErr: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ConnectionError::IO(err) => - write!(f, "Connection error: I/O error: {}", err), - ConnectionError::Handler(err) => - write!(f, "Connection error: Handler error: {}", err), + ConnectionError::IO(err) => write!(f, "Connection error: I/O error: {}", err), + ConnectionError::Handler(err) => write!(f, "Connection error: Handler error: {}", err), } } } -impl std::error::Error -for ConnectionError +impl std::error::Error for ConnectionError where THandlerErr: std::error::Error + 'static, { @@ -80,29 +76,29 @@ pub enum PendingConnectionError { IO(io::Error), } -impl fmt::Display -for PendingConnectionError +impl fmt::Display for PendingConnectionError where TTransErr: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - PendingConnectionError::IO(err) => - write!(f, "Pending connection: I/O error: {}", err), - PendingConnectionError::Transport(err) => - write!(f, "Pending connection: Transport error: {}", err), - PendingConnectionError::InvalidPeerId => - write!(f, "Pending connection: Invalid peer ID."), - PendingConnectionError::ConnectionLimit(l) => - write!(f, "Connection error: Connection limit: {}.", l), + PendingConnectionError::IO(err) => write!(f, "Pending connection: I/O error: {}", err), + PendingConnectionError::Transport(err) => { + write!(f, "Pending connection: Transport error: {}", err) + } + PendingConnectionError::InvalidPeerId => { + write!(f, "Pending connection: Invalid peer ID.") + } + PendingConnectionError::ConnectionLimit(l) => { + write!(f, "Connection error: Connection limit: {}.", l) + } } } } -impl std::error::Error -for PendingConnectionError +impl std::error::Error for PendingConnectionError where - TTransErr: std::error::Error + 
'static + TTransErr: std::error::Error + 'static, { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { diff --git a/core/src/connection/handler.rs b/core/src/connection/handler.rs index 0f1c2f6bcd8..011dcc2b61e 100644 --- a/core/src/connection/handler.rs +++ b/core/src/connection/handler.rs @@ -18,9 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use super::{Connected, SubstreamEndpoint}; use crate::Multiaddr; use std::{fmt::Debug, task::Context, task::Poll}; -use super::{Connected, SubstreamEndpoint}; /// The interface of a connection handler. /// @@ -53,7 +53,11 @@ pub trait ConnectionHandler { /// Implementations are allowed to panic in the case of dialing if the `user_data` in /// `endpoint` doesn't correspond to what was returned earlier when polling, or is used /// multiple times. - fn inject_substream(&mut self, substream: Self::Substream, endpoint: SubstreamEndpoint); + fn inject_substream( + &mut self, + substream: Self::Substream, + endpoint: SubstreamEndpoint, + ); /// Notifies the handler of an event. fn inject_event(&mut self, event: Self::InEvent); @@ -64,8 +68,10 @@ pub trait ConnectionHandler { /// Polls the handler for events. /// /// Returning an error will close the connection to the remote. - fn poll(&mut self, cx: &mut Context<'_>) - -> Poll, Self::Error>>; + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>; } /// Prototype for a `ConnectionHandler`. @@ -82,7 +88,7 @@ pub trait IntoConnectionHandler { impl IntoConnectionHandler for T where - T: ConnectionHandler + T: ConnectionHandler, { type Handler = Self; @@ -91,9 +97,12 @@ where } } -pub(crate) type THandlerInEvent = <::Handler as ConnectionHandler>::InEvent; -pub(crate) type THandlerOutEvent = <::Handler as ConnectionHandler>::OutEvent; -pub(crate) type THandlerError = <::Handler as ConnectionHandler>::Error; +pub(crate) type THandlerInEvent = + <::Handler as ConnectionHandler>::InEvent; +pub(crate) type THandlerOutEvent = + <::Handler as ConnectionHandler>::OutEvent; +pub(crate) type THandlerError = + <::Handler as ConnectionHandler>::Error; /// Event produced by a handler. #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -109,24 +118,26 @@ pub enum ConnectionHandlerEvent { impl ConnectionHandlerEvent { /// If this is `OutboundSubstreamRequest`, maps the content to something else. pub fn map_outbound_open_info(self, map: F) -> ConnectionHandlerEvent - where F: FnOnce(TOutboundOpenInfo) -> I + where + F: FnOnce(TOutboundOpenInfo) -> I, { match self { ConnectionHandlerEvent::OutboundSubstreamRequest(val) => { ConnectionHandlerEvent::OutboundSubstreamRequest(map(val)) - }, + } ConnectionHandlerEvent::Custom(val) => ConnectionHandlerEvent::Custom(val), } } /// If this is `Custom`, maps the content to something else. pub fn map_custom(self, map: F) -> ConnectionHandlerEvent - where F: FnOnce(TCustom) -> I + where + F: FnOnce(TCustom) -> I, { match self { ConnectionHandlerEvent::OutboundSubstreamRequest(val) => { ConnectionHandlerEvent::OutboundSubstreamRequest(val) - }, + } ConnectionHandlerEvent::Custom(val) => ConnectionHandlerEvent::Custom(map(val)), } } diff --git a/core/src/connection/listeners.rs b/core/src/connection/listeners.rs index 02982d87393..8659c98cbe4 100644 --- a/core/src/connection/listeners.rs +++ b/core/src/connection/listeners.rs @@ -20,7 +20,10 @@ //! Manage listening on multiple multiaddresses at once. 
-use crate::{Multiaddr, Transport, transport::{TransportError, ListenerEvent}}; +use crate::{ + transport::{ListenerEvent, TransportError}, + Multiaddr, Transport, +}; use futures::{prelude::*, task::Context, task::Poll}; use log::debug; use smallvec::SmallVec; @@ -86,7 +89,7 @@ where /// can be resized, the only way is to use a `Pin>`. listeners: VecDeque>>>, /// The next listener ID to assign. - next_id: ListenerId + next_id: ListenerId, } /// The ID of a single listener. @@ -109,7 +112,7 @@ where #[pin] listener: TTrans::Listener, /// Addresses it is listening on. - addresses: SmallVec<[Multiaddr; 4]> + addresses: SmallVec<[Multiaddr; 4]>, } /// Event that can happen on the `ListenersStream`. @@ -122,14 +125,14 @@ where /// The listener that is listening on the new address. listener_id: ListenerId, /// The new address that is being listened on. - listen_addr: Multiaddr + listen_addr: Multiaddr, }, /// An address is no longer being listened on. AddressExpired { /// The listener that is no longer listening on the address. listener_id: ListenerId, /// The new address that is being listened on. - listen_addr: Multiaddr + listen_addr: Multiaddr, }, /// A connection is incoming on one of the listeners. Incoming { @@ -161,7 +164,7 @@ where listener_id: ListenerId, /// The error value. error: TTrans::Error, - } + }, } impl ListenersStream @@ -173,7 +176,7 @@ where ListenersStream { transport, listeners: VecDeque::new(), - next_id: ListenerId(1) + next_id: ListenerId(1), } } @@ -183,14 +186,17 @@ where ListenersStream { transport, listeners: VecDeque::with_capacity(capacity), - next_id: ListenerId(1) + next_id: ListenerId(1), } } /// Start listening on a multiaddress. /// /// Returns an error if the transport doesn't support the given multiaddress. - pub fn listen_on(&mut self, addr: Multiaddr) -> Result> + pub fn listen_on( + &mut self, + addr: Multiaddr, + ) -> Result> where TTrans: Clone, { @@ -198,7 +204,7 @@ where self.listeners.push_back(Box::pin(Listener { id: self.next_id, listener, - addresses: SmallVec::new() + addresses: SmallVec::new(), })); let id = self.next_id; self.next_id = ListenerId(self.next_id.0 + 1); @@ -237,17 +243,23 @@ where Poll::Pending => { self.listeners.push_front(listener); remaining -= 1; - if remaining == 0 { break } + if remaining == 0 { + break; + } } - Poll::Ready(Some(Ok(ListenerEvent::Upgrade { upgrade, local_addr, remote_addr }))) => { + Poll::Ready(Some(Ok(ListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + }))) => { let id = *listener_project.id; self.listeners.push_front(listener); return Poll::Ready(ListenersEvent::Incoming { listener_id: id, upgrade, local_addr, - send_back_addr: remote_addr - }) + send_back_addr: remote_addr, + }); } Poll::Ready(Some(Ok(ListenerEvent::NewAddress(a)))) => { if listener_project.addresses.contains(&a) { @@ -260,8 +272,8 @@ where self.listeners.push_front(listener); return Poll::Ready(ListenersEvent::NewAddress { listener_id: id, - listen_addr: a - }) + listen_addr: a, + }); } Poll::Ready(Some(Ok(ListenerEvent::AddressExpired(a)))) => { listener_project.addresses.retain(|x| x != &a); @@ -269,8 +281,8 @@ where self.listeners.push_front(listener); return Poll::Ready(ListenersEvent::AddressExpired { listener_id: id, - listen_addr: a - }) + listen_addr: a, + }); } Poll::Ready(Some(Ok(ListenerEvent::Error(error)))) => { let id = *listener_project.id; @@ -278,7 +290,7 @@ where return Poll::Ready(ListenersEvent::Error { listener_id: id, error, - }) + }); } Poll::Ready(None) => { return 
Poll::Ready(ListenersEvent::Closed { @@ -313,11 +325,7 @@ where } } -impl Unpin for ListenersStream -where - TTrans: Transport, -{ -} +impl Unpin for ListenersStream where TTrans: Transport {} impl fmt::Debug for ListenersStream where @@ -338,22 +346,36 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - ListenersEvent::NewAddress { listener_id, listen_addr } => f + ListenersEvent::NewAddress { + listener_id, + listen_addr, + } => f .debug_struct("ListenersEvent::NewAddress") .field("listener_id", listener_id) .field("listen_addr", listen_addr) .finish(), - ListenersEvent::AddressExpired { listener_id, listen_addr } => f + ListenersEvent::AddressExpired { + listener_id, + listen_addr, + } => f .debug_struct("ListenersEvent::AddressExpired") .field("listener_id", listener_id) .field("listen_addr", listen_addr) .finish(), - ListenersEvent::Incoming { listener_id, local_addr, .. } => f + ListenersEvent::Incoming { + listener_id, + local_addr, + .. + } => f .debug_struct("ListenersEvent::Incoming") .field("listener_id", listener_id) .field("local_addr", local_addr) .finish(), - ListenersEvent::Closed { listener_id, addresses, reason } => f + ListenersEvent::Closed { + listener_id, + addresses, + reason, + } => f .debug_struct("ListenersEvent::Closed") .field("listener_id", listener_id) .field("addresses", addresses) @@ -363,7 +385,7 @@ where .debug_struct("ListenersEvent::Error") .field("listener_id", listener_id) .field("error", error) - .finish() + .finish(), } } } @@ -396,11 +418,15 @@ mod tests { }); match listeners.next().await.unwrap() { - ListenersEvent::Incoming { local_addr, send_back_addr, .. } => { + ListenersEvent::Incoming { + local_addr, + send_back_addr, + .. + } => { assert_eq!(local_addr, address); assert!(send_back_addr != address); - }, - _ => panic!() + } + _ => panic!(), } }); } @@ -415,21 +441,43 @@ mod tests { impl transport::Transport for DummyTrans { type Output = (); type Error = std::io::Error; - type Listener = Pin, std::io::Error>>>>; + type Listener = Pin< + Box< + dyn Stream< + Item = Result< + ListenerEvent, + std::io::Error, + >, + >, + >, + >; type ListenerUpgrade = Pin>>>; type Dial = Pin>>>; - fn listen_on(self, _: Multiaddr) -> Result> { + fn listen_on( + self, + _: Multiaddr, + ) -> Result> { Ok(Box::pin(stream::unfold((), |()| async move { - Some((Ok(ListenerEvent::Error(std::io::Error::from(std::io::ErrorKind::Other))), ())) + Some(( + Ok(ListenerEvent::Error(std::io::Error::from( + std::io::ErrorKind::Other, + ))), + (), + )) }))) } - fn dial(self, _: Multiaddr) -> Result> { + fn dial( + self, + _: Multiaddr, + ) -> Result> { panic!() } - fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { None } + fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { + None + } } async_std::task::block_on(async move { @@ -439,8 +487,8 @@ mod tests { for _ in 0..10 { match listeners.next().await.unwrap() { - ListenersEvent::Error { .. } => {}, - _ => panic!() + ListenersEvent::Error { .. 
} => {} + _ => panic!(), } } }); @@ -455,21 +503,38 @@ mod tests { impl transport::Transport for DummyTrans { type Output = (); type Error = std::io::Error; - type Listener = Pin, std::io::Error>>>>; + type Listener = Pin< + Box< + dyn Stream< + Item = Result< + ListenerEvent, + std::io::Error, + >, + >, + >, + >; type ListenerUpgrade = Pin>>>; type Dial = Pin>>>; - fn listen_on(self, _: Multiaddr) -> Result> { + fn listen_on( + self, + _: Multiaddr, + ) -> Result> { Ok(Box::pin(stream::unfold((), |()| async move { Some((Err(std::io::Error::from(std::io::ErrorKind::Other)), ())) }))) } - fn dial(self, _: Multiaddr) -> Result> { + fn dial( + self, + _: Multiaddr, + ) -> Result> { panic!() } - fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { None } + fn address_translation(&self, _: &Multiaddr, _: &Multiaddr) -> Option { + None + } } async_std::task::block_on(async move { @@ -478,8 +543,8 @@ mod tests { listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); match listeners.next().await.unwrap() { - ListenersEvent::Closed { .. } => {}, - _ => panic!() + ListenersEvent::Closed { .. } => {} + _ => panic!(), } }); } diff --git a/core/src/connection/manager.rs b/core/src/connection/manager.rs index b450f0d602f..1d7acb92e69 100644 --- a/core/src/connection/manager.rs +++ b/core/src/connection/manager.rs @@ -18,39 +18,20 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{ - Executor, - muxing::StreamMuxer, +use super::{ + handler::{THandlerError, THandlerInEvent, THandlerOutEvent}, + Connected, ConnectedPoint, Connection, ConnectionError, ConnectionHandler, + IntoConnectionHandler, PendingConnectionError, Substream, }; +use crate::{muxing::StreamMuxer, Executor}; use fnv::FnvHashMap; -use futures::{ - prelude::*, - channel::mpsc, - stream::FuturesUnordered -}; +use futures::{channel::mpsc, prelude::*, stream::FuturesUnordered}; use std::{ collections::hash_map, - error, - fmt, - mem, + error, fmt, mem, pin::Pin, task::{Context, Poll}, }; -use super::{ - Connected, - ConnectedPoint, - Connection, - ConnectionError, - ConnectionHandler, - IntoConnectionHandler, - PendingConnectionError, - Substream, - handler::{ - THandlerInEvent, - THandlerOutEvent, - THandlerError, - }, -}; use task::{Task, TaskId}; mod task; @@ -123,11 +104,10 @@ pub struct Manager { events_tx: mpsc::Sender>, /// Receiver for events reported from managed tasks. - events_rx: mpsc::Receiver> + events_rx: mpsc::Receiver>, } -impl fmt::Debug for Manager -{ +impl fmt::Debug for Manager { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_map() .entries(self.tasks.iter().map(|(id, task)| (id, &task.state))) @@ -196,7 +176,7 @@ pub enum Event<'a, H: IntoConnectionHandler, TE> { /// What happened. error: PendingConnectionError, /// The handler that was supposed to handle the failed connection. - handler: H + handler: H, }, /// An established connection has been closed. @@ -225,7 +205,7 @@ pub enum Event<'a, H: IntoConnectionHandler, TE> { /// The entry associated with the connection that produced the event. entry: EstablishedEntry<'a, THandlerInEvent>, /// The produced event. - event: THandlerOutEvent + event: THandlerOutEvent, }, /// A connection to a node has changed its address. 
@@ -250,7 +230,7 @@ impl Manager { executor: config.executor, local_spawns: FuturesUnordered::new(), events_tx: tx, - events_rx: rx + events_rx: rx, } } @@ -265,18 +245,28 @@ impl Manager { M::OutboundSubstream: Send + 'static, F: Future> + Send + 'static, H: IntoConnectionHandler + Send + 'static, - H::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + H::Handler: ConnectionHandler> + Send + 'static, ::OutboundOpenInfo: Send + 'static, { let task_id = self.next_task_id; self.next_task_id.0 += 1; let (tx, rx) = mpsc::channel(self.task_command_buffer_size); - self.tasks.insert(task_id, TaskInfo { sender: tx, state: TaskState::Pending }); - - let task = Box::pin(Task::pending(task_id, self.events_tx.clone(), rx, future, handler)); + self.tasks.insert( + task_id, + TaskInfo { + sender: tx, + state: TaskState::Pending, + }, + ); + + let task = Box::pin(Task::pending( + task_id, + self.events_tx.clone(), + rx, + future, + handler, + )); if let Some(executor) = &mut self.executor { executor.exec(task); } else { @@ -290,9 +280,7 @@ impl Manager { pub fn add(&mut self, conn: Connection, info: Connected) -> ConnectionId where H: IntoConnectionHandler + Send + 'static, - H::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + H::Handler: ConnectionHandler> + Send + 'static, ::OutboundOpenInfo: Send + 'static, TE: error::Error + Send + 'static, M: StreamMuxer + Send + Sync + 'static, @@ -302,9 +290,13 @@ impl Manager { self.next_task_id.0 += 1; let (tx, rx) = mpsc::channel(self.task_command_buffer_size); - self.tasks.insert(task_id, TaskInfo { - sender: tx, state: TaskState::Established(info) - }); + self.tasks.insert( + task_id, + TaskInfo { + sender: tx, + state: TaskState::Established(info), + }, + ); let task: Pin>>, _, _, _>>> = Box::pin(Task::established(task_id, self.events_tx.clone(), rx, conn)); @@ -329,7 +321,13 @@ impl Manager { /// Checks whether an established connection with the given ID is currently managed. pub fn is_established(&self, id: &ConnectionId) -> bool { - matches!(self.tasks.get(&id.0), Some(TaskInfo { state: TaskState::Established(..), .. })) + matches!( + self.tasks.get(&id.0), + Some(TaskInfo { + state: TaskState::Established(..), + .. + }) + ) } /// Polls the manager for events relating to the managed connections. 
@@ -341,8 +339,9 @@ impl Manager { let event = loop { match self.events_rx.poll_next_unpin(cx) { Poll::Ready(Some(event)) => { - if self.tasks.contains_key(event.id()) { // (1) - break event + if self.tasks.contains_key(event.id()) { + // (1) + break event; } } Poll::Pending => return Poll::Pending, @@ -352,12 +351,12 @@ impl Manager { if let hash_map::Entry::Occupied(mut task) = self.tasks.entry(*event.id()) { Poll::Ready(match event { - task::Event::Notify { id: _, event } => - Event::ConnectionEvent { - entry: EstablishedEntry { task }, - event - }, - task::Event::Established { id: _, info } => { // (2) + task::Event::Notify { id: _, event } => Event::ConnectionEvent { + entry: EstablishedEntry { task }, + event, + }, + task::Event::Established { id: _, info } => { + // (2) task.get_mut().state = TaskState::Established(info); // (3) Event::ConnectionEstablished { entry: EstablishedEntry { task }, @@ -389,11 +388,14 @@ impl Manager { let id = ConnectionId(id); let task = task.remove(); match task.state { - TaskState::Established(connected) => - Event::ConnectionClosed { id, connected, error }, + TaskState::Established(connected) => Event::ConnectionClosed { + id, + connected, + error, + }, TaskState::Pending => unreachable!( "`Event::Closed` implies (2) occurred on that task and thus (3)." - ), + ), } } }) @@ -407,14 +409,14 @@ impl Manager { #[derive(Debug)] pub enum Entry<'a, I> { Pending(PendingEntry<'a, I>), - Established(EstablishedEntry<'a, I>) + Established(EstablishedEntry<'a, I>), } impl<'a, I> Entry<'a, I> { fn new(task: hash_map::OccupiedEntry<'a, TaskId, TaskInfo>) -> Self { match &task.get().state { TaskState::Pending => Entry::Pending(PendingEntry { task }), - TaskState::Established(_) => Entry::Established(EstablishedEntry { task }) + TaskState::Established(_) => Entry::Established(EstablishedEntry { task }), } } } @@ -442,10 +444,13 @@ impl<'a, I> EstablishedEntry<'a, I> { /// > the connection handler not being ready at this time. pub fn notify_handler(&mut self, event: I) -> Result<(), I> { let cmd = task::Command::NotifyHandler(event); // (*) - self.task.get_mut().sender.try_send(cmd) + self.task + .get_mut() + .sender + .try_send(cmd) .map_err(|e| match e.into_inner() { task::Command::NotifyHandler(event) => event, - _ => panic!("Unexpected command. Expected `NotifyHandler`") // see (*) + _ => panic!("Unexpected command. Expected `NotifyHandler`"), // see (*) }) } @@ -455,7 +460,7 @@ impl<'a, I> EstablishedEntry<'a, I> { /// /// Returns `Err(())` if the background task associated with the connection /// is terminating and the connection is about to close. - pub fn poll_ready_notify_handler(&mut self, cx: &mut Context<'_>) -> Poll> { + pub fn poll_ready_notify_handler(&mut self, cx: &mut Context<'_>) -> Poll> { self.task.get_mut().sender.poll_ready(cx).map_err(|_| ()) } @@ -469,9 +474,15 @@ impl<'a, I> EstablishedEntry<'a, I> { pub fn start_close(mut self) { // Clone the sender so that we are guaranteed to have // capacity for the close command (every sender gets a slot). 
- match self.task.get_mut().sender.clone().try_send(task::Command::Close) { - Ok(()) => {}, - Err(e) => assert!(e.is_disconnected(), "No capacity for close command.") + match self + .task + .get_mut() + .sender + .clone() + .try_send(task::Command::Close) + { + Ok(()) => {} + Err(e) => assert!(e.is_disconnected(), "No capacity for close command."), } } @@ -479,7 +490,7 @@ impl<'a, I> EstablishedEntry<'a, I> { pub fn connected(&self) -> &Connected { match &self.task.get().state { TaskState::Established(c) => c, - TaskState::Pending => unreachable!("By Entry::new()") + TaskState::Pending => unreachable!("By Entry::new()"), } } @@ -490,7 +501,7 @@ impl<'a, I> EstablishedEntry<'a, I> { pub fn remove(self) -> Connected { match self.task.remove().state { TaskState::Established(c) => c, - TaskState::Pending => unreachable!("By Entry::new()") + TaskState::Pending => unreachable!("By Entry::new()"), } } @@ -504,7 +515,7 @@ impl<'a, I> EstablishedEntry<'a, I> { /// (i.e. pending). #[derive(Debug)] pub struct PendingEntry<'a, I> { - task: hash_map::OccupiedEntry<'a, TaskId, TaskInfo> + task: hash_map::OccupiedEntry<'a, TaskId, TaskInfo>, } impl<'a, I> PendingEntry<'a, I> { @@ -514,7 +525,7 @@ impl<'a, I> PendingEntry<'a, I> { } /// Aborts the pending connection attempt. - pub fn abort(self) { + pub fn abort(self) { self.task.remove(); } } diff --git a/core/src/connection/manager/task.rs b/core/src/connection/manager/task.rs index a7bdbd3cbbd..db8fb43adb6 100644 --- a/core/src/connection/manager/task.rs +++ b/core/src/connection/manager/task.rs @@ -18,29 +18,19 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use super::ConnectResult; use crate::{ - Multiaddr, - muxing::StreamMuxer, connection::{ self, - Close, - Connected, - Connection, - ConnectionError, - ConnectionHandler, - IntoConnectionHandler, - PendingConnectionError, - Substream, - handler::{ - THandlerInEvent, - THandlerOutEvent, - THandlerError, - }, + handler::{THandlerError, THandlerInEvent, THandlerOutEvent}, + Close, Connected, Connection, ConnectionError, ConnectionHandler, IntoConnectionHandler, + PendingConnectionError, Substream, }, + muxing::StreamMuxer, + Multiaddr, }; -use futures::{prelude::*, channel::mpsc, stream}; +use futures::{channel::mpsc, prelude::*, stream}; use std::{pin::Pin, task::Context, task::Poll}; -use super::ConnectResult; /// Identifier of a [`Task`] in a [`Manager`](super::Manager). #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] @@ -62,16 +52,26 @@ pub enum Event { /// A connection to a node has succeeded. Established { id: TaskId, info: Connected }, /// A pending connection failed. - Failed { id: TaskId, error: PendingConnectionError, handler: H }, + Failed { + id: TaskId, + error: PendingConnectionError, + handler: H, + }, /// A node we are connected to has changed its address. AddressChange { id: TaskId, new_address: Multiaddr }, /// Notify the manager of an event from the connection. - Notify { id: TaskId, event: THandlerOutEvent }, + Notify { + id: TaskId, + event: THandlerOutEvent, + }, /// A connection closed, possibly due to an error. /// /// If `error` is `None`, the connection has completed /// an active orderly close. - Closed { id: TaskId, error: Option>> } + Closed { + id: TaskId, + error: Option>>, + }, } impl Event { @@ -91,7 +91,7 @@ pub struct Task where M: StreamMuxer, H: IntoConnectionHandler, - H::Handler: ConnectionHandler> + H::Handler: ConnectionHandler>, { /// The ID of this task. 
id: TaskId, @@ -110,7 +110,7 @@ impl Task where M: StreamMuxer, H: IntoConnectionHandler, - H::Handler: ConnectionHandler> + H::Handler: ConnectionHandler>, { /// Create a new task to connect and handle some node. pub fn pending( @@ -118,7 +118,7 @@ where events: mpsc::Sender>, commands: mpsc::Receiver>>, future: F, - handler: H + handler: H, ) -> Self { Task { id, @@ -136,13 +136,16 @@ where id: TaskId, events: mpsc::Sender>, commands: mpsc::Receiver>>, - connection: Connection + connection: Connection, ) -> Self { Task { id, events, commands: commands.fuse(), - state: State::Established { connection, event: None }, + state: State::Established { + connection, + event: None, + }, } } } @@ -152,7 +155,7 @@ enum State where M: StreamMuxer, H: IntoConnectionHandler, - H::Handler: ConnectionHandler> + H::Handler: ConnectionHandler>, { /// The connection is being negotiated. Pending { @@ -180,14 +183,14 @@ where Terminating(Event), /// The task has finished. - Done + Done, } impl Unpin for Task where M: StreamMuxer, H: IntoConnectionHandler, - H::Handler: ConnectionHandler> + H::Handler: ConnectionHandler>, { } @@ -196,9 +199,7 @@ where M: StreamMuxer, F: Future>, H: IntoConnectionHandler, - H::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + H::Handler: ConnectionHandler> + Send + 'static, { type Output = (); @@ -211,33 +212,33 @@ where 'poll: loop { match std::mem::replace(&mut this.state, State::Done) { - State::Pending { mut future, handler } => { + State::Pending { + mut future, + handler, + } => { // Check whether the task is still registered with a `Manager` // by polling the commands channel. match this.commands.poll_next_unpin(cx) { - Poll::Pending => {}, + Poll::Pending => {} Poll::Ready(None) => { // The manager has dropped the task; abort. - return Poll::Ready(()) + return Poll::Ready(()); + } + Poll::Ready(Some(_)) => { + panic!("Task received command while the connection is pending.") } - Poll::Ready(Some(_)) => panic!( - "Task received command while the connection is pending." - ) } // Check if the connection succeeded. match future.poll_unpin(cx) { Poll::Ready(Ok((info, muxer))) => { this.state = State::Established { - connection: Connection::new( - muxer, - handler.into_handler(&info), - ), - event: Some(Event::Established { id, info }) + connection: Connection::new(muxer, handler.into_handler(&info)), + event: Some(Event::Established { id, info }), } } Poll::Pending => { this.state = State::Pending { future, handler }; - return Poll::Pending + return Poll::Pending; } Poll::Ready(Err(error)) => { // Don't accept any further commands and terminate the @@ -249,23 +250,27 @@ where } } - State::Established { mut connection, event } => { + State::Established { + mut connection, + event, + } => { // Check for commands from the `Manager`. loop { match this.commands.poll_next_unpin(cx) { Poll::Pending => break, - Poll::Ready(Some(Command::NotifyHandler(event))) => - connection.inject_event(event), + Poll::Ready(Some(Command::NotifyHandler(event))) => { + connection.inject_event(event) + } Poll::Ready(Some(Command::Close)) => { // Don't accept any further commands. this.commands.get_mut().close(); // Discard the event, if any, and start a graceful close. this.state = State::Closing(connection.close()); - continue 'poll + continue 'poll; } Poll::Ready(None) => { // The manager has dropped the task or disappeared; abort. - return Poll::Ready(()) + return Poll::Ready(()); } } } @@ -274,44 +279,56 @@ where // Send the event to the manager. 
match this.events.poll_ready(cx) { Poll::Pending => { - this.state = State::Established { connection, event: Some(event) }; - return Poll::Pending + this.state = State::Established { + connection, + event: Some(event), + }; + return Poll::Pending; } Poll::Ready(result) => { if result.is_ok() { if let Ok(()) = this.events.start_send(event) { - this.state = State::Established { connection, event: None }; - continue 'poll + this.state = State::Established { + connection, + event: None, + }; + continue 'poll; } } // The manager is no longer reachable; abort. - return Poll::Ready(()) + return Poll::Ready(()); } } } else { // Poll the connection for new events. match Connection::poll(Pin::new(&mut connection), cx) { Poll::Pending => { - this.state = State::Established { connection, event: None }; - return Poll::Pending + this.state = State::Established { + connection, + event: None, + }; + return Poll::Pending; } Poll::Ready(Ok(connection::Event::Handler(event))) => { this.state = State::Established { connection, - event: Some(Event::Notify { id, event }) + event: Some(Event::Notify { id, event }), }; } Poll::Ready(Ok(connection::Event::AddressChange(new_address))) => { this.state = State::Established { connection, - event: Some(Event::AddressChange { id, new_address }) + event: Some(Event::AddressChange { id, new_address }), }; } Poll::Ready(Err(error)) => { // Don't accept any further commands. this.commands.get_mut().close(); // Terminate the task with the error, dropping the connection. - let event = Event::Closed { id, error: Some(error) }; + let event = Event::Closed { + id, + error: Some(error), + }; this.state = State::Terminating(event); } } @@ -322,19 +339,22 @@ where // Try to gracefully close the connection. match closing.poll_unpin(cx) { Poll::Ready(Ok(())) => { - let event = Event::Closed { id: this.id, error: None }; + let event = Event::Closed { + id: this.id, + error: None, + }; this.state = State::Terminating(event); } Poll::Ready(Err(e)) => { let event = Event::Closed { id: this.id, - error: Some(ConnectionError::IO(e)) + error: Some(ConnectionError::IO(e)), }; this.state = State::Terminating(event); } Poll::Pending => { this.state = State::Closing(closing); - return Poll::Pending + return Poll::Pending; } } } @@ -344,18 +364,18 @@ where match this.events.poll_ready(cx) { Poll::Pending => { self.state = State::Terminating(event); - return Poll::Pending + return Poll::Pending; } Poll::Ready(result) => { if result.is_ok() { let _ = this.events.start_send(event); } - return Poll::Ready(()) + return Poll::Ready(()); } } } - State::Done => panic!("`Task::poll()` called after completion.") + State::Done => panic!("`Task::poll()` called after completion."), } } } diff --git a/core/src/connection/pool.rs b/core/src/connection/pool.rs index 263c36a88a8..9925dd526c0 100644 --- a/core/src/connection/pool.rs +++ b/core/src/connection/pool.rs @@ -19,29 +19,15 @@ // DEALINGS IN THE SOFTWARE. 
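Before the pool.rs changes: the task.rs hunks above reformat a small state machine. As a reading aid only (not part of the patch), the states a connection `Task` moves through can be summarised as:

// Illustrative only; the real `State` enum in task.rs carries the dial future,
// the connection and any buffered event as payloads.
enum TaskLifecycle {
    Pending,     // dial/upgrade future still running; receiving a command here is a bug
    Established, // connection live; Manager commands and handler events are exchanged
    Closing,     // Command::Close received; driving the graceful close future
    Terminating, // delivering the final Event::Closed to the Manager
    Done,        // polled to completion; polling again panics by design
}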
use crate::{ - ConnectedPoint, - PeerId, connection::{ self, - Connected, - Connection, - ConnectionId, - ConnectionLimit, - ConnectionError, - ConnectionHandler, - IncomingInfo, - IntoConnectionHandler, - OutgoingInfo, - Substream, - PendingConnectionError, - handler::{ - THandlerInEvent, - THandlerOutEvent, - THandlerError, - }, + handler::{THandlerError, THandlerInEvent, THandlerOutEvent}, manager::{self, Manager, ManagerConfig}, + Connected, Connection, ConnectionError, ConnectionHandler, ConnectionId, ConnectionLimit, + IncomingInfo, IntoConnectionHandler, OutgoingInfo, PendingConnectionError, Substream, }, muxing::StreamMuxer, + ConnectedPoint, PeerId, }; use either::Either; use fnv::FnvHashMap; @@ -76,9 +62,7 @@ pub struct Pool { disconnected: Vec, } -impl fmt::Debug -for Pool -{ +impl fmt::Debug for Pool { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { f.debug_struct("Pool") .field("counters", &self.counters) @@ -86,8 +70,7 @@ for Pool } } -impl Unpin -for Pool {} +impl Unpin for Pool {} /// Event that can happen on the `Pool`. pub enum PoolEvent<'a, THandler: IntoConnectionHandler, TTransErr> { @@ -157,56 +140,60 @@ pub enum PoolEvent<'a, THandler: IntoConnectionHandler, TTransErr> { }, } -impl<'a, THandler: IntoConnectionHandler, TTransErr> fmt::Debug for PoolEvent<'a, THandler, TTransErr> +impl<'a, THandler: IntoConnectionHandler, TTransErr> fmt::Debug + for PoolEvent<'a, THandler, TTransErr> where TTransErr: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match *self { - PoolEvent::ConnectionEstablished { ref connection, .. } => { - f.debug_tuple("PoolEvent::ConnectionEstablished") - .field(connection) - .finish() - }, - PoolEvent::ConnectionClosed { ref id, ref connected, ref error, .. } => { - f.debug_struct("PoolEvent::ConnectionClosed") - .field("id", id) - .field("connected", connected) - .field("error", error) - .finish() - }, - PoolEvent::PendingConnectionError { ref id, ref error, .. } => { - f.debug_struct("PoolEvent::PendingConnectionError") - .field("id", id) - .field("error", error) - .finish() - }, - PoolEvent::ConnectionEvent { ref connection, ref event } => { - f.debug_struct("PoolEvent::ConnectionEvent") - .field("peer", &connection.peer_id()) - .field("event", event) - .finish() - }, - PoolEvent::AddressChange { ref connection, ref new_endpoint, ref old_endpoint } => { - f.debug_struct("PoolEvent::AddressChange") - .field("peer", &connection.peer_id()) - .field("new_endpoint", new_endpoint) - .field("old_endpoint", old_endpoint) - .finish() - }, + PoolEvent::ConnectionEstablished { ref connection, .. } => f + .debug_tuple("PoolEvent::ConnectionEstablished") + .field(connection) + .finish(), + PoolEvent::ConnectionClosed { + ref id, + ref connected, + ref error, + .. + } => f + .debug_struct("PoolEvent::ConnectionClosed") + .field("id", id) + .field("connected", connected) + .field("error", error) + .finish(), + PoolEvent::PendingConnectionError { + ref id, ref error, .. 
+ } => f + .debug_struct("PoolEvent::PendingConnectionError") + .field("id", id) + .field("error", error) + .finish(), + PoolEvent::ConnectionEvent { + ref connection, + ref event, + } => f + .debug_struct("PoolEvent::ConnectionEvent") + .field("peer", &connection.peer_id()) + .field("event", event) + .finish(), + PoolEvent::AddressChange { + ref connection, + ref new_endpoint, + ref old_endpoint, + } => f + .debug_struct("PoolEvent::AddressChange") + .field("peer", &connection.peer_id()) + .field("new_endpoint", new_endpoint) + .field("old_endpoint", old_endpoint) + .finish(), } } } -impl - Pool -{ +impl Pool { /// Creates a new empty `Pool`. - pub fn new( - local_id: PeerId, - manager_config: ManagerConfig, - limits: ConnectionLimits - ) -> Self { + pub fn new(local_id: PeerId, manager_config: ManagerConfig, limits: ConnectionLimits) -> Self { Pool { local_id, counters: ConnectionCounters::new(limits), @@ -234,13 +221,11 @@ impl info: IncomingInfo<'_>, ) -> Result where - TFut: Future< - Output = Result<(PeerId, TMuxer), PendingConnectionError> - > + Send + 'static, + TFut: Future>> + + Send + + 'static, THandler: IntoConnectionHandler + Send + 'static, - THandler::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + THandler::Handler: ConnectionHandler> + Send + 'static, ::OutboundOpenInfo: Send + 'static, TTransErr: error::Error + Send + 'static, TMuxer: StreamMuxer + Send + Sync + 'static, @@ -263,13 +248,11 @@ impl info: OutgoingInfo<'_>, ) -> Result where - TFut: Future< - Output = Result<(PeerId, TMuxer), PendingConnectionError> - > + Send + 'static, + TFut: Future>> + + Send + + 'static, THandler: IntoConnectionHandler + Send + 'static, - THandler::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + THandler::Handler: ConnectionHandler> + Send + 'static, ::OutboundOpenInfo: Send + 'static, TTransErr: error::Error + Send + 'static, TMuxer: StreamMuxer + Send + Sync + 'static, @@ -290,13 +273,11 @@ impl peer: Option, ) -> ConnectionId where - TFut: Future< - Output = Result<(PeerId, TMuxer), PendingConnectionError> - > + Send + 'static, + TFut: Future>> + + Send + + 'static, THandler: IntoConnectionHandler + Send + 'static, - THandler::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + THandler::Handler: ConnectionHandler> + Send + 'static, ::OutboundOpenInfo: Send + 'static, TTransErr: error::Error + Send + 'static, TMuxer: StreamMuxer + Send + Sync + 'static, @@ -313,12 +294,12 @@ impl move |(peer_id, muxer)| { if let Some(peer) = expected_peer { if peer != peer_id { - return future::err(PendingConnectionError::InvalidPeerId) + return future::err(PendingConnectionError::InvalidPeerId); } } if local_id == peer_id { - return future::err(PendingConnectionError::InvalidPeerId) + return future::err(PendingConnectionError::InvalidPeerId); } let connected = Connected { peer_id, endpoint }; @@ -337,73 +318,80 @@ impl /// Returns the assigned connection ID on success. An error is returned /// if the configured maximum number of established connections for the /// connected peer has been reached. 
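As the doc comment above says, adding an established connection is subject to the configured limits. The comparison itself lives in `ConnectionCounters::check` further down in this file; a minimal sketch of it, with the `ConnectionLimit` error replaced by a plain tuple:

// Sketch only; the real function returns `ConnectionLimit { limit, current }`.
fn check(current: u32, limit: Option<u32>) -> Result<(), (u32, u32)> {
    if let Some(limit) = limit {
        if current >= limit {
            return Err((limit, current));
        }
    }
    Ok(())
}

// An unset limit never rejects, a reached one does:
// check(5, None) == Ok(()), while check(5, Some(5)) == Err((5, 5)).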
- pub fn add(&mut self, c: Connection, i: Connected) - -> Result + pub fn add( + &mut self, + c: Connection, + i: Connected, + ) -> Result where THandler: IntoConnectionHandler + Send + 'static, - THandler::Handler: ConnectionHandler< - Substream = connection::Substream, - > + Send + 'static, + THandler::Handler: + ConnectionHandler> + Send + 'static, ::OutboundOpenInfo: Send + 'static, TTransErr: error::Error + Send + 'static, TMuxer: StreamMuxer + Send + Sync + 'static, TMuxer::OutboundSubstream: Send + 'static, { self.counters.check_max_established(&i.endpoint)?; - self.counters.check_max_established_per_peer(self.num_peer_established(&i.peer_id))?; + self.counters + .check_max_established_per_peer(self.num_peer_established(&i.peer_id))?; let id = self.manager.add(c, i.clone()); self.counters.inc_established(&i.endpoint); - self.established.entry(i.peer_id).or_default().insert(id, i.endpoint); + self.established + .entry(i.peer_id) + .or_default() + .insert(id, i.endpoint); Ok(id) } /// Gets an entry representing a connection in the pool. /// /// Returns `None` if the pool has no connection with the given ID. - pub fn get(&mut self, id: ConnectionId) - -> Option>> - { + pub fn get( + &mut self, + id: ConnectionId, + ) -> Option>> { match self.manager.entry(id) { - Some(manager::Entry::Established(entry)) => - Some(PoolConnection::Established(EstablishedConnection { - entry - })), - Some(manager::Entry::Pending(entry)) => + Some(manager::Entry::Established(entry)) => { + Some(PoolConnection::Established(EstablishedConnection { entry })) + } + Some(manager::Entry::Pending(entry)) => { Some(PoolConnection::Pending(PendingConnection { entry, pending: &mut self.pending, counters: &mut self.counters, - })), - None => None + })) + } + None => None, } } /// Gets an established connection from the pool by ID. - pub fn get_established(&mut self, id: ConnectionId) - -> Option>> - { + pub fn get_established( + &mut self, + id: ConnectionId, + ) -> Option>> { match self.get(id) { Some(PoolConnection::Established(c)) => Some(c), - _ => None + _ => None, } } /// Gets a pending outgoing connection by ID. - pub fn get_outgoing(&mut self, id: ConnectionId) - -> Option>> - { + pub fn get_outgoing( + &mut self, + id: ConnectionId, + ) -> Option>> { match self.pending.get(&id) { - Some((ConnectedPoint::Dialer { .. }, _peer)) => - match self.manager.entry(id) { - Some(manager::Entry::Pending(entry)) => - Some(PendingConnection { - entry, - pending: &mut self.pending, - counters: &mut self.counters, - }), - _ => unreachable!("by consistency of `self.pending` with `self.manager`") - } - _ => None + Some((ConnectedPoint::Dialer { .. }, _peer)) => match self.manager.entry(id) { + Some(manager::Entry::Pending(entry)) => Some(PendingConnection { + entry, + pending: &mut self.pending, + counters: &mut self.counters, + }), + _ => unreachable!("by consistency of `self.pending` with `self.manager`"), + }, + _ => None, } } @@ -437,7 +425,9 @@ impl if let Some(manager::Entry::Established(e)) = self.manager.entry(id) { let connected = e.remove(); self.disconnected.push(Disconnected { - id, connected, num_established + id, + connected, + num_established, }); num_established += 1; } @@ -468,14 +458,13 @@ impl } /// Returns an iterator over all established connections of `peer`. 
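Both the counting in `Pool::add` above and the per-peer iteration below lean on the same bookkeeping: a map from peer ID to that peer's established connections. A simplified, self-contained sketch of that shape, using a plain `HashMap` and stand-in types instead of `FnvHashMap`, `PeerId`, `ConnectionId` and `ConnectedPoint`:

use std::collections::HashMap;

type PeerId = [u8; 32];
type ConnectionId = usize;
type Endpoint = String;

struct EstablishedBook {
    by_peer: HashMap<PeerId, HashMap<ConnectionId, Endpoint>>,
}

impl EstablishedBook {
    /// Number of established connections to `peer`, cf. `num_peer_established`.
    fn num_peer_established(&self, peer: &PeerId) -> u32 {
        self.by_peer.get(peer).map_or(0, |conns| conns.len() as u32)
    }
}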
- pub fn iter_peer_established<'a>(&'a mut self, peer: &PeerId) - -> EstablishedConnectionIter<'a, - impl Iterator, - THandler, - TTransErr, - > + pub fn iter_peer_established<'a>( + &'a mut self, + peer: &PeerId, + ) -> EstablishedConnectionIter<'a, impl Iterator, THandler, TTransErr> { - let ids = self.iter_peer_established_info(peer) + let ids = self + .iter_peer_established_info(peer) .map(|(id, _endpoint)| *id) .collect::>() .into_iter(); @@ -486,45 +475,50 @@ impl /// Returns an iterator for information on all pending incoming connections. pub fn iter_pending_incoming(&self) -> impl Iterator> { self.iter_pending_info() - .filter_map(|(_, ref endpoint, _)| { - match endpoint { - ConnectedPoint::Listener { local_addr, send_back_addr } => { - Some(IncomingInfo { local_addr, send_back_addr }) - }, - ConnectedPoint::Dialer { .. } => None, - } + .filter_map(|(_, ref endpoint, _)| match endpoint { + ConnectedPoint::Listener { + local_addr, + send_back_addr, + } => Some(IncomingInfo { + local_addr, + send_back_addr, + }), + ConnectedPoint::Dialer { .. } => None, }) } /// Returns an iterator for information on all pending outgoing connections. pub fn iter_pending_outgoing(&self) -> impl Iterator> { self.iter_pending_info() - .filter_map(|(_, ref endpoint, ref peer_id)| { - match endpoint { - ConnectedPoint::Listener { .. } => None, - ConnectedPoint::Dialer { address } => - Some(OutgoingInfo { address, peer_id: peer_id.as_ref() }), - } + .filter_map(|(_, ref endpoint, ref peer_id)| match endpoint { + ConnectedPoint::Listener { .. } => None, + ConnectedPoint::Dialer { address } => Some(OutgoingInfo { + address, + peer_id: peer_id.as_ref(), + }), }) } /// Returns an iterator over all connection IDs and associated endpoints /// of established connections to `peer` known to the pool. - pub fn iter_peer_established_info(&self, peer: &PeerId) - -> impl Iterator + fmt::Debug + '_ - { + pub fn iter_peer_established_info( + &self, + peer: &PeerId, + ) -> impl Iterator + fmt::Debug + '_ { match self.established.get(peer) { Some(conns) => Either::Left(conns.iter()), - None => Either::Right(std::iter::empty()) + None => Either::Right(std::iter::empty()), } } /// Returns an iterator over all pending connection IDs together /// with associated endpoints and expected peer IDs in the pool. - pub fn iter_pending_info(&self) - -> impl Iterator)> + '_ - { - self.pending.iter().map(|(id, (endpoint, info))| (id, endpoint, info)) + pub fn iter_pending_info( + &self, + ) -> impl Iterator)> + '_ { + self.pending + .iter() + .map(|(id, (endpoint, info))| (id, endpoint, info)) } /// Returns an iterator over all connected peers, i.e. those that have @@ -537,9 +531,10 @@ impl /// /// > **Note**: We use a regular `poll` method instead of implementing `Stream`, /// > because we want the `Pool` to stay borrowed if necessary. - pub fn poll<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll< - PoolEvent<'a, THandler, TTransErr> - > { + pub fn poll<'a>( + &'a mut self, + cx: &mut Context<'_>, + ) -> Poll> { // Drain events resulting from forced disconnections. // // Note: The `Disconnected` entries in `self.disconnected` @@ -548,15 +543,18 @@ impl // events in an order that properly counts down `num_established`. // See also `Pool::disconnect`. 
if let Some(Disconnected { - id, connected, num_established - }) = self.disconnected.pop() { + id, + connected, + num_established, + }) = self.disconnected.pop() + { return Poll::Ready(PoolEvent::ConnectionClosed { id, connected, num_established, error: None, pool: self, - }) + }); } // Poll the connection `Manager`. @@ -576,11 +574,15 @@ impl error, handler: Some(handler), peer, - pool: self - }) + pool: self, + }); } - }, - manager::Event::ConnectionClosed { id, connected, error } => { + } + manager::Event::ConnectionClosed { + id, + connected, + error, + } => { let num_established = if let Some(conns) = self.established.get_mut(&connected.peer_id) { if let Some(endpoint) = conns.remove(&id) { @@ -594,8 +596,12 @@ impl self.established.remove(&connected.peer_id); } return Poll::Ready(PoolEvent::ConnectionClosed { - id, connected, error, num_established, pool: self - }) + id, + connected, + error, + num_established, + pool: self, + }); } manager::Event::ConnectionEstablished { entry } => { let id = entry.id(); @@ -611,12 +617,13 @@ impl error: PendingConnectionError::ConnectionLimit(e), handler: None, peer, - pool: self - }) + pool: self, + }); } // Check per-peer established connection limit. - let current = num_peer_established(&self.established, &entry.connected().peer_id); + let current = + num_peer_established(&self.established, &entry.connected().peer_id); if let Err(e) = self.counters.check_max_established_per_peer(current) { let connected = entry.remove(); return Poll::Ready(PoolEvent::PendingConnectionError { @@ -625,8 +632,8 @@ impl error: PendingConnectionError::ConnectionLimit(e), handler: None, peer, - pool: self - }) + pool: self, + }); } // Peer ID checks must already have happened. See `add_pending`. @@ -644,54 +651,62 @@ impl // Add the connection to the pool. 
let peer = entry.connected().peer_id; let conns = self.established.entry(peer).or_default(); - let num_established = NonZeroU32::new(u32::try_from(conns.len() + 1).unwrap()) - .expect("n + 1 is always non-zero; qed"); + let num_established = + NonZeroU32::new(u32::try_from(conns.len() + 1).unwrap()) + .expect("n + 1 is always non-zero; qed"); self.counters.inc_established(&endpoint); conns.insert(id, endpoint); match self.get(id) { - Some(PoolConnection::Established(connection)) => + Some(PoolConnection::Established(connection)) => { return Poll::Ready(PoolEvent::ConnectionEstablished { - connection, num_established - }), - _ => unreachable!("since `entry` is an `EstablishedEntry`.") + connection, + num_established, + }) + } + _ => unreachable!("since `entry` is an `EstablishedEntry`."), } } - }, + } manager::Event::ConnectionEvent { entry, event } => { let id = entry.id(); match self.get(id) { - Some(PoolConnection::Established(connection)) => - return Poll::Ready(PoolEvent::ConnectionEvent { - connection, - event, - }), - _ => unreachable!("since `entry` is an `EstablishedEntry`.") + Some(PoolConnection::Established(connection)) => { + return Poll::Ready(PoolEvent::ConnectionEvent { connection, event }) + } + _ => unreachable!("since `entry` is an `EstablishedEntry`."), } - }, - manager::Event::AddressChange { entry, new_endpoint, old_endpoint } => { + } + manager::Event::AddressChange { + entry, + new_endpoint, + old_endpoint, + } => { let id = entry.id(); match self.established.get_mut(&entry.connected().peer_id) { - Some(list) => *list.get_mut(&id) - .expect("state inconsistency: entry is `EstablishedEntry` but absent \ - from `established`") = new_endpoint.clone(), - None => unreachable!("since `entry` is an `EstablishedEntry`.") + Some(list) => { + *list.get_mut(&id).expect( + "state inconsistency: entry is `EstablishedEntry` but absent \ + from `established`", + ) = new_endpoint.clone() + } + None => unreachable!("since `entry` is an `EstablishedEntry`."), }; match self.get(id) { - Some(PoolConnection::Established(connection)) => + Some(PoolConnection::Established(connection)) => { return Poll::Ready(PoolEvent::AddressChange { connection, new_endpoint, old_endpoint, - }), - _ => unreachable!("since `entry` is an `EstablishedEntry`.") + }) + } + _ => unreachable!("since `entry` is an `EstablishedEntry`."), } - }, + } } } } - } /// A connection in a [`Pool`]. @@ -707,9 +722,7 @@ pub struct PendingConnection<'a, TInEvent> { counters: &'a mut ConnectionCounters, } -impl - PendingConnection<'_, TInEvent> -{ +impl PendingConnection<'_, TInEvent> { /// Returns the local connection ID. pub fn id(&self) -> ConnectionId { self.entry.id() @@ -717,17 +730,29 @@ impl /// Returns the (expected) identity of the remote peer, if known. pub fn peer_id(&self) -> &Option { - &self.pending.get(&self.entry.id()).expect("`entry` is a pending entry").1 + &self + .pending + .get(&self.entry.id()) + .expect("`entry` is a pending entry") + .1 } /// Returns information about this endpoint of the connection. pub fn endpoint(&self) -> &ConnectedPoint { - &self.pending.get(&self.entry.id()).expect("`entry` is a pending entry").0 + &self + .pending + .get(&self.entry.id()) + .expect("`entry` is a pending entry") + .0 } /// Aborts the connection attempt, closing the connection. 
pub fn abort(self) { - let endpoint = self.pending.remove(&self.entry.id()).expect("`entry` is a pending entry").0; + let endpoint = self + .pending + .remove(&self.entry.id()) + .expect("`entry` is a pending entry") + .0; self.counters.dec_pending(&endpoint); self.entry.abort(); } @@ -738,8 +763,7 @@ pub struct EstablishedConnection<'a, TInEvent> { entry: manager::EstablishedEntry<'a, TInEvent>, } -impl fmt::Debug -for EstablishedConnection<'_, TInEvent> +impl fmt::Debug for EstablishedConnection<'_, TInEvent> where TInEvent: fmt::Debug, { @@ -790,7 +814,7 @@ impl EstablishedConnection<'_, TInEvent> { /// /// Returns `Err(())` if the background task associated with the connection /// is terminating and the connection is about to close. - pub fn poll_ready_notify_handler(&mut self, cx: &mut Context<'_>) -> Poll> { + pub fn poll_ready_notify_handler(&mut self, cx: &mut Context<'_>) -> Poll> { self.entry.poll_ready_notify_handler(cx) } @@ -811,21 +835,22 @@ pub struct EstablishedConnectionIter<'a, I, THandler: IntoConnectionHandler, TTr // Note: Ideally this would be an implementation of `Iterator`, but that // requires GATs (cf. https://github.com/rust-lang/rust/issues/44265) and // a different definition of `Iterator`. -impl<'a, I, THandler: IntoConnectionHandler, TTransErr> EstablishedConnectionIter<'a, I, THandler, TTransErr> +impl<'a, I, THandler: IntoConnectionHandler, TTransErr> + EstablishedConnectionIter<'a, I, THandler, TTransErr> where - I: Iterator + I: Iterator, { /// Obtains the next connection, if any. #[allow(clippy::should_implement_trait)] - pub fn next(&mut self) -> Option>> - { + pub fn next(&mut self) -> Option>> { while let Some(id) = self.ids.next() { - if self.pool.manager.is_established(&id) { // (*) + if self.pool.manager.is_established(&id) { + // (*) match self.pool.manager.entry(id) { Some(manager::Entry::Established(entry)) => { return Some(EstablishedConnection { entry }) } - _ => panic!("Established entry not found in manager.") // see (*) + _ => panic!("Established entry not found in manager."), // see (*) } } } @@ -838,17 +863,18 @@ where } /// Returns the first connection, if any, consuming the iterator. - pub fn into_first<'b>(mut self) - -> Option>> - where 'a: 'b + pub fn into_first<'b>(mut self) -> Option>> + where + 'a: 'b, { while let Some(id) = self.ids.next() { - if self.pool.manager.is_established(&id) { // (*) + if self.pool.manager.is_established(&id) { + // (*) match self.pool.manager.entry(id) { Some(manager::Entry::Established(entry)) => { return Some(EstablishedConnection { entry }) } - _ => panic!("Established entry not found in manager.") // see (*) + _ => panic!("Established entry not found in manager."), // see (*) } } } @@ -924,29 +950,45 @@ impl ConnectionCounters { fn inc_pending(&mut self, endpoint: &ConnectedPoint) { match endpoint { - ConnectedPoint::Dialer { .. } => { self.pending_outgoing += 1; } - ConnectedPoint::Listener { .. } => { self.pending_incoming += 1; } + ConnectedPoint::Dialer { .. } => { + self.pending_outgoing += 1; + } + ConnectedPoint::Listener { .. } => { + self.pending_incoming += 1; + } } } fn dec_pending(&mut self, endpoint: &ConnectedPoint) { match endpoint { - ConnectedPoint::Dialer { .. } => { self.pending_outgoing -= 1; } - ConnectedPoint::Listener { .. } => { self.pending_incoming -= 1; } + ConnectedPoint::Dialer { .. } => { + self.pending_outgoing -= 1; + } + ConnectedPoint::Listener { .. 
} => { + self.pending_incoming -= 1; + } } } fn inc_established(&mut self, endpoint: &ConnectedPoint) { match endpoint { - ConnectedPoint::Dialer { .. } => { self.established_outgoing += 1; } - ConnectedPoint::Listener { .. } => { self.established_incoming += 1; } + ConnectedPoint::Dialer { .. } => { + self.established_outgoing += 1; + } + ConnectedPoint::Listener { .. } => { + self.established_incoming += 1; + } } } fn dec_established(&mut self, endpoint: &ConnectedPoint) { match endpoint { - ConnectedPoint::Dialer { .. } => { self.established_outgoing -= 1; } - ConnectedPoint::Listener { .. } => { self.established_incoming -= 1; } + ConnectedPoint::Dialer { .. } => { + self.established_outgoing -= 1; + } + ConnectedPoint::Listener { .. } => { + self.established_incoming -= 1; + } } } @@ -958,18 +1000,19 @@ impl ConnectionCounters { Self::check(self.pending_incoming, self.limits.max_pending_incoming) } - fn check_max_established(&self, endpoint: &ConnectedPoint) - -> Result<(), ConnectionLimit> - { + fn check_max_established(&self, endpoint: &ConnectedPoint) -> Result<(), ConnectionLimit> { // Check total connection limit. Self::check(self.num_established(), self.limits.max_established_total)?; // Check incoming/outgoing connection limits match endpoint { - ConnectedPoint::Dialer { .. } => - Self::check(self.established_outgoing, self.limits.max_established_outgoing), - ConnectedPoint::Listener { .. } => { - Self::check(self.established_incoming, self.limits.max_established_incoming) - } + ConnectedPoint::Dialer { .. } => Self::check( + self.established_outgoing, + self.limits.max_established_outgoing, + ), + ConnectedPoint::Listener { .. } => Self::check( + self.established_incoming, + self.limits.max_established_incoming, + ), } } @@ -980,22 +1023,21 @@ impl ConnectionCounters { fn check(current: u32, limit: Option) -> Result<(), ConnectionLimit> { if let Some(limit) = limit { if current >= limit { - return Err(ConnectionLimit { limit, current }) + return Err(ConnectionLimit { limit, current }); } } Ok(()) } - } /// Counts the number of established connections to the given peer. fn num_peer_established( established: &FnvHashMap>, - peer: &PeerId + peer: &PeerId, ) -> u32 { - established.get(peer).map_or(0, |conns| - u32::try_from(conns.len()) - .expect("Unexpectedly large number of connections for a peer.")) + established.get(peer).map_or(0, |conns| { + u32::try_from(conns.len()).expect("Unexpectedly large number of connections for a peer.") + }) } /// The configurable connection limits. diff --git a/core/src/connection/substream.rs b/core/src/connection/substream.rs index ac537b488e9..399b09b9f0a 100644 --- a/core/src/connection/substream.rs +++ b/core/src/connection/substream.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::muxing::{StreamMuxer, StreamMuxerEvent, SubstreamRef, substream_from_ref}; +use crate::muxing::{substream_from_ref, StreamMuxer, StreamMuxerEvent, SubstreamRef}; use futures::prelude::*; use multiaddr::Multiaddr; use smallvec::SmallVec; @@ -135,7 +135,9 @@ where #[must_use] pub fn close(mut self) -> (Close, Vec) { let substreams = self.cancel_outgoing(); - let close = Close { muxer: self.inner.clone() }; + let close = Close { + muxer: self.inner.clone(), + }; (close, substreams) } @@ -150,17 +152,19 @@ where } /// Provides an API similar to `Future`. 
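The `poll` below is what the connection code drives to learn about substreams. A rough fragment (not a complete program, and not part of the patch) showing how a caller would match the three `SubstreamEvent` variants; `muxing` is assumed to be a `Muxing<TMuxer, TUserData>` value and `cx` a `&mut Context<'_>`:

match muxing.poll(cx) {
    Poll::Ready(Ok(SubstreamEvent::InboundSubstream { substream })) => {
        // a remote-initiated substream arrived; hand it to the connection handler
    }
    Poll::Ready(Ok(SubstreamEvent::OutboundSubstream { user_data, substream })) => {
        // a previously requested outbound substream is now open
    }
    Poll::Ready(Ok(SubstreamEvent::AddressChange(new_address))) => {
        // the muxer reported a new remote address for this connection
    }
    Poll::Ready(Err(e)) => { /* connection-level I/O error */ }
    Poll::Pending => { /* nothing ready yet */ }
}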
- pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll, IoError>> { + pub fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, IoError>> { // Polling inbound substream. match self.inner.poll_event(cx) { Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(substream))) => { let substream = substream_from_ref(self.inner.clone(), substream); - return Poll::Ready(Ok(SubstreamEvent::InboundSubstream { - substream, - })); + return Poll::Ready(Ok(SubstreamEvent::InboundSubstream { substream })); + } + Poll::Ready(Ok(StreamMuxerEvent::AddressChange(addr))) => { + return Poll::Ready(Ok(SubstreamEvent::AddressChange(addr))) } - Poll::Ready(Ok(StreamMuxerEvent::AddressChange(addr))) => - return Poll::Ready(Ok(SubstreamEvent::AddressChange(addr))), Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), Poll::Pending => {} } @@ -238,8 +242,7 @@ where TMuxer: StreamMuxer, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - f.debug_struct("Close") - .finish() + f.debug_struct("Close").finish() } } @@ -251,22 +254,22 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - SubstreamEvent::InboundSubstream { substream } => { - f.debug_struct("SubstreamEvent::OutboundClosed") - .field("substream", substream) - .finish() - }, - SubstreamEvent::OutboundSubstream { user_data, substream } => { - f.debug_struct("SubstreamEvent::OutboundSubstream") - .field("user_data", user_data) - .field("substream", substream) - .finish() - }, - SubstreamEvent::AddressChange(address) => { - f.debug_struct("SubstreamEvent::AddressChange") - .field("address", address) - .finish() - }, + SubstreamEvent::InboundSubstream { substream } => f + .debug_struct("SubstreamEvent::OutboundClosed") + .field("substream", substream) + .finish(), + SubstreamEvent::OutboundSubstream { + user_data, + substream, + } => f + .debug_struct("SubstreamEvent::OutboundSubstream") + .field("user_data", user_data) + .field("substream", substream) + .finish(), + SubstreamEvent::AddressChange(address) => f + .debug_struct("SubstreamEvent::AddressChange") + .field("address", address) + .finish(), } } } diff --git a/core/src/either.rs b/core/src/either.rs index 4d991936121..66a11589f7a 100644 --- a/core/src/either.rs +++ b/core/src/either.rs @@ -20,29 +20,31 @@ use crate::{ muxing::{StreamMuxer, StreamMuxerEvent}, - ProtocolName, - transport::{Transport, ListenerEvent, TransportError}, - Multiaddr + transport::{ListenerEvent, Transport, TransportError}, + Multiaddr, ProtocolName, +}; +use futures::{ + io::{IoSlice, IoSliceMut}, + prelude::*, }; -use futures::{prelude::*, io::{IoSlice, IoSliceMut}}; use pin_project::pin_project; -use std::{fmt, io::{Error as IoError}, pin::Pin, task::Context, task::Poll}; +use std::{fmt, io::Error as IoError, pin::Pin, task::Context, task::Poll}; #[derive(Debug, Copy, Clone)] pub enum EitherError { A(A), - B(B) + B(B), } impl fmt::Display for EitherError where A: fmt::Display, - B: fmt::Display + B: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { EitherError::A(a) => a.fmt(f), - EitherError::B(b) => b.fmt(f) + EitherError::B(b) => b.fmt(f), } } } @@ -50,12 +52,12 @@ where impl std::error::Error for EitherError where A: std::error::Error, - B: std::error::Error + B: std::error::Error, { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { EitherError::A(a) => a.source(), - EitherError::B(b) => b.source() + EitherError::B(b) => b.source(), } } } @@ -74,16 +76,22 @@ where A: AsyncRead, B: AsyncRead, { - fn 
poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { match self.project() { EitherOutputProj::First(a) => AsyncRead::poll_read(a, cx, buf), EitherOutputProj::Second(b) => AsyncRead::poll_read(b, cx, buf), } } - fn poll_read_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>]) - -> Poll> - { + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { match self.project() { EitherOutputProj::First(a) => AsyncRead::poll_read_vectored(a, cx, bufs), EitherOutputProj::Second(b) => AsyncRead::poll_read_vectored(b, cx, bufs), @@ -96,16 +104,22 @@ where A: AsyncWrite, B: AsyncWrite, { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { match self.project() { EitherOutputProj::First(a) => AsyncWrite::poll_write(a, cx, buf), EitherOutputProj::Second(b) => AsyncWrite::poll_write(b, cx, buf), } } - fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) - -> Poll> - { + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { match self.project() { EitherOutputProj::First(a) => AsyncWrite::poll_write_vectored(a, cx, bufs), EitherOutputProj::Second(b) => AsyncWrite::poll_write_vectored(b, cx, bufs), @@ -136,10 +150,12 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { - EitherOutputProj::First(a) => TryStream::try_poll_next(a, cx) - .map(|v| v.map(|r| r.map_err(EitherError::A))), - EitherOutputProj::Second(b) => TryStream::try_poll_next(b, cx) - .map(|v| v.map(|r| r.map_err(EitherError::B))), + EitherOutputProj::First(a) => { + TryStream::try_poll_next(a, cx).map(|v| v.map(|r| r.map_err(EitherError::A))) + } + EitherOutputProj::Second(b) => { + TryStream::try_poll_next(b, cx).map(|v| v.map(|r| r.map_err(EitherError::B))) + } } } } @@ -189,23 +205,24 @@ where type OutboundSubstream = EitherOutbound; type Error = IoError; - fn poll_event(&self, cx: &mut Context<'_>) -> Poll, Self::Error>> { + fn poll_event( + &self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { match self { EitherOutput::First(inner) => inner.poll_event(cx).map(|result| { - result.map_err(|e| e.into()).map(|event| { - match event { - StreamMuxerEvent::AddressChange(addr) => StreamMuxerEvent::AddressChange(addr), - StreamMuxerEvent::InboundSubstream(substream) => - StreamMuxerEvent::InboundSubstream(EitherOutput::First(substream)) + result.map_err(|e| e.into()).map(|event| match event { + StreamMuxerEvent::AddressChange(addr) => StreamMuxerEvent::AddressChange(addr), + StreamMuxerEvent::InboundSubstream(substream) => { + StreamMuxerEvent::InboundSubstream(EitherOutput::First(substream)) } }) }), EitherOutput::Second(inner) => inner.poll_event(cx).map(|result| { - result.map_err(|e| e.into()).map(|event| { - match event { - StreamMuxerEvent::AddressChange(addr) => StreamMuxerEvent::AddressChange(addr), - StreamMuxerEvent::InboundSubstream(substream) => - StreamMuxerEvent::InboundSubstream(EitherOutput::Second(substream)) + result.map_err(|e| e.into()).map(|event| match event { + StreamMuxerEvent::AddressChange(addr) => StreamMuxerEvent::AddressChange(addr), + StreamMuxerEvent::InboundSubstream(substream) => { + 
StreamMuxerEvent::InboundSubstream(EitherOutput::Second(substream)) } }) }), @@ -219,96 +236,112 @@ where } } - fn poll_outbound(&self, cx: &mut Context<'_>, substream: &mut Self::OutboundSubstream) -> Poll> { + fn poll_outbound( + &self, + cx: &mut Context<'_>, + substream: &mut Self::OutboundSubstream, + ) -> Poll> { match (self, substream) { - (EitherOutput::First(ref inner), EitherOutbound::A(ref mut substream)) => { - inner.poll_outbound(cx, substream).map(|p| p.map(EitherOutput::First)).map_err(|e| e.into()) - }, - (EitherOutput::Second(ref inner), EitherOutbound::B(ref mut substream)) => { - inner.poll_outbound(cx, substream).map(|p| p.map(EitherOutput::Second)).map_err(|e| e.into()) - }, - _ => panic!("Wrong API usage") + (EitherOutput::First(ref inner), EitherOutbound::A(ref mut substream)) => inner + .poll_outbound(cx, substream) + .map(|p| p.map(EitherOutput::First)) + .map_err(|e| e.into()), + (EitherOutput::Second(ref inner), EitherOutbound::B(ref mut substream)) => inner + .poll_outbound(cx, substream) + .map(|p| p.map(EitherOutput::Second)) + .map_err(|e| e.into()), + _ => panic!("Wrong API usage"), } } fn destroy_outbound(&self, substream: Self::OutboundSubstream) { match self { - EitherOutput::First(inner) => { - match substream { - EitherOutbound::A(substream) => inner.destroy_outbound(substream), - _ => panic!("Wrong API usage") - } + EitherOutput::First(inner) => match substream { + EitherOutbound::A(substream) => inner.destroy_outbound(substream), + _ => panic!("Wrong API usage"), }, - EitherOutput::Second(inner) => { - match substream { - EitherOutbound::B(substream) => inner.destroy_outbound(substream), - _ => panic!("Wrong API usage") - } + EitherOutput::Second(inner) => match substream { + EitherOutbound::B(substream) => inner.destroy_outbound(substream), + _ => panic!("Wrong API usage"), }, } } - fn read_substream(&self, cx: &mut Context<'_>, sub: &mut Self::Substream, buf: &mut [u8]) -> Poll> { + fn read_substream( + &self, + cx: &mut Context<'_>, + sub: &mut Self::Substream, + buf: &mut [u8], + ) -> Poll> { match (self, sub) { (EitherOutput::First(ref inner), EitherOutput::First(ref mut sub)) => { inner.read_substream(cx, sub, buf).map_err(|e| e.into()) - }, + } (EitherOutput::Second(ref inner), EitherOutput::Second(ref mut sub)) => { inner.read_substream(cx, sub, buf).map_err(|e| e.into()) - }, - _ => panic!("Wrong API usage") + } + _ => panic!("Wrong API usage"), } } - fn write_substream(&self, cx: &mut Context<'_>, sub: &mut Self::Substream, buf: &[u8]) -> Poll> { + fn write_substream( + &self, + cx: &mut Context<'_>, + sub: &mut Self::Substream, + buf: &[u8], + ) -> Poll> { match (self, sub) { (EitherOutput::First(ref inner), EitherOutput::First(ref mut sub)) => { inner.write_substream(cx, sub, buf).map_err(|e| e.into()) - }, + } (EitherOutput::Second(ref inner), EitherOutput::Second(ref mut sub)) => { inner.write_substream(cx, sub, buf).map_err(|e| e.into()) - }, - _ => panic!("Wrong API usage") + } + _ => panic!("Wrong API usage"), } } - fn flush_substream(&self, cx: &mut Context<'_>, sub: &mut Self::Substream) -> Poll> { + fn flush_substream( + &self, + cx: &mut Context<'_>, + sub: &mut Self::Substream, + ) -> Poll> { match (self, sub) { (EitherOutput::First(ref inner), EitherOutput::First(ref mut sub)) => { inner.flush_substream(cx, sub).map_err(|e| e.into()) - }, + } (EitherOutput::Second(ref inner), EitherOutput::Second(ref mut sub)) => { inner.flush_substream(cx, sub).map_err(|e| e.into()) - }, - _ => panic!("Wrong API usage") + } + _ => panic!("Wrong 
API usage"), } } - fn shutdown_substream(&self, cx: &mut Context<'_>, sub: &mut Self::Substream) -> Poll> { + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + sub: &mut Self::Substream, + ) -> Poll> { match (self, sub) { (EitherOutput::First(ref inner), EitherOutput::First(ref mut sub)) => { inner.shutdown_substream(cx, sub).map_err(|e| e.into()) - }, + } (EitherOutput::Second(ref inner), EitherOutput::Second(ref mut sub)) => { inner.shutdown_substream(cx, sub).map_err(|e| e.into()) - }, - _ => panic!("Wrong API usage") + } + _ => panic!("Wrong API usage"), } } fn destroy_substream(&self, substream: Self::Substream) { match self { - EitherOutput::First(inner) => { - match substream { - EitherOutput::First(substream) => inner.destroy_substream(substream), - _ => panic!("Wrong API usage") - } + EitherOutput::First(inner) => match substream { + EitherOutput::First(substream) => inner.destroy_substream(substream), + _ => panic!("Wrong API usage"), }, - EitherOutput::Second(inner) => { - match substream { - EitherOutput::Second(substream) => inner.destroy_substream(substream), - _ => panic!("Wrong API usage") - } + EitherOutput::Second(inner) => match substream { + EitherOutput::Second(substream) => inner.destroy_substream(substream), + _ => panic!("Wrong API usage"), }, } } @@ -344,25 +377,33 @@ pub enum EitherListenStream { Second(#[pin] B), } -impl Stream for EitherListenStream +impl Stream + for EitherListenStream where AStream: TryStream, Error = AError>, BStream: TryStream, Error = BError>, { - type Item = Result, EitherError>, EitherError>; + type Item = Result< + ListenerEvent, EitherError>, + EitherError, + >; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { EitherListenStreamProj::First(a) => match TryStream::try_poll_next(a, cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le.map(EitherFuture::First).map_err(EitherError::A)))), + Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le + .map(EitherFuture::First) + .map_err(EitherError::A)))), Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::A(err)))), }, EitherListenStreamProj::Second(a) => match TryStream::try_poll_next(a, cx) { Poll::Pending => Poll::Pending, Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le.map(EitherFuture::Second).map_err(EitherError::B)))), + Poll::Ready(Some(Ok(le))) => Poll::Ready(Some(Ok(le + .map(EitherFuture::Second) + .map_err(EitherError::B)))), Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::B(err)))), }, } @@ -388,9 +429,11 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project() { EitherFutureProj::First(a) => TryFuture::try_poll(a, cx) - .map_ok(EitherOutput::First).map_err(EitherError::A), + .map_ok(EitherOutput::First) + .map_err(EitherError::A), EitherFutureProj::Second(a) => TryFuture::try_poll(a, cx) - .map_ok(EitherOutput::Second).map_err(EitherError::B), + .map_ok(EitherOutput::Second) + .map_err(EitherError::B), } } } @@ -398,7 +441,10 @@ where #[pin_project(project = EitherFuture2Proj)] #[derive(Debug, Copy, Clone)] #[must_use = "futures do nothing unless polled"] -pub enum EitherFuture2 { A(#[pin] A), B(#[pin] B) } +pub enum EitherFuture2 { + A(#[pin] A), + B(#[pin] B), +} impl Future for EitherFuture2 where @@ -410,21 +456,26 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project() { EitherFuture2Proj::A(a) => 
TryFuture::try_poll(a, cx) - .map_ok(EitherOutput::First).map_err(EitherError::A), + .map_ok(EitherOutput::First) + .map_err(EitherError::A), EitherFuture2Proj::B(a) => TryFuture::try_poll(a, cx) - .map_ok(EitherOutput::Second).map_err(EitherError::B), + .map_ok(EitherOutput::Second) + .map_err(EitherError::B), } } } #[derive(Debug, Clone)] -pub enum EitherName { A(A), B(B) } +pub enum EitherName { + A(A), + B(B), +} impl ProtocolName for EitherName { fn protocol_name(&self) -> &[u8] { match self { EitherName::A(a) => a.protocol_name(), - EitherName::B(b) => b.protocol_name() + EitherName::B(b) => b.protocol_name(), } } } diff --git a/core/src/identity.rs b/core/src/identity.rs index 6ed29424045..8c3c83db16e 100644 --- a/core/src/identity.rs +++ b/core/src/identity.rs @@ -41,7 +41,7 @@ pub mod secp256k1; pub mod error; use self::error::*; -use crate::{PeerId, keys_proto}; +use crate::{keys_proto, PeerId}; /// Identity keypair of a node. /// @@ -69,7 +69,7 @@ pub enum Keypair { Rsa(rsa::Keypair), /// A Secp256k1 keypair. #[cfg(feature = "secp256k1")] - Secp256k1(secp256k1::Keypair) + Secp256k1(secp256k1::Keypair), } impl Keypair { @@ -112,7 +112,7 @@ impl Keypair { #[cfg(not(target_arch = "wasm32"))] Rsa(ref pair) => pair.sign(msg), #[cfg(feature = "secp256k1")] - Secp256k1(ref pair) => pair.secret().sign(msg) + Secp256k1(ref pair) => pair.secret().sign(msg), } } @@ -154,7 +154,6 @@ impl Keypair { Ok(pk.encode_to_vec()) } - /// Decode a private key from a protobuf structure and parse it as a [`Keypair`]. pub fn from_protobuf_encoding(bytes: &[u8]) -> Result { use prost::Message; @@ -163,19 +162,20 @@ impl Keypair { .map_err(|e| DecodingError::new("Protobuf").source(e)) .map(zeroize::Zeroizing::new)?; - let key_type = keys_proto::KeyType::from_i32(private_key.r#type) - .ok_or_else(|| DecodingError::new(format!("unknown key type: {}", private_key.r#type)))?; + let key_type = keys_proto::KeyType::from_i32(private_key.r#type).ok_or_else(|| { + DecodingError::new(format!("unknown key type: {}", private_key.r#type)) + })?; match key_type { keys_proto::KeyType::Ed25519 => { ed25519::Keypair::decode(&mut private_key.data).map(Keypair::Ed25519) - }, - keys_proto::KeyType::Rsa => { - Err(DecodingError::new("Decoding RSA key from Protobuf is unsupported.")) - }, - keys_proto::KeyType::Secp256k1 => { - Err(DecodingError::new("Decoding Secp256k1 key from Protobuf is unsupported.")) - }, + } + keys_proto::KeyType::Rsa => Err(DecodingError::new( + "Decoding RSA key from Protobuf is unsupported.", + )), + keys_proto::KeyType::Secp256k1 => Err(DecodingError::new( + "Decoding Secp256k1 key from Protobuf is unsupported.", + )), } } } @@ -197,7 +197,7 @@ pub enum PublicKey { Rsa(rsa::PublicKey), #[cfg(feature = "secp256k1")] /// A public Secp256k1 key. 
- Secp256k1(secp256k1::PublicKey) + Secp256k1(secp256k1::PublicKey), } impl PublicKey { @@ -212,7 +212,7 @@ impl PublicKey { #[cfg(not(target_arch = "wasm32"))] Rsa(pk) => pk.verify(msg, sig), #[cfg(feature = "secp256k1")] - Secp256k1(pk) => pk.verify(msg, sig) + Secp256k1(pk) => pk.verify(msg, sig), } } @@ -222,27 +222,26 @@ impl PublicKey { use prost::Message; let public_key = match self { - PublicKey::Ed25519(key) => - keys_proto::PublicKey { - r#type: keys_proto::KeyType::Ed25519 as i32, - data: key.encode().to_vec() - }, + PublicKey::Ed25519(key) => keys_proto::PublicKey { + r#type: keys_proto::KeyType::Ed25519 as i32, + data: key.encode().to_vec(), + }, #[cfg(not(target_arch = "wasm32"))] - PublicKey::Rsa(key) => - keys_proto::PublicKey { - r#type: keys_proto::KeyType::Rsa as i32, - data: key.encode_x509() - }, + PublicKey::Rsa(key) => keys_proto::PublicKey { + r#type: keys_proto::KeyType::Rsa as i32, + data: key.encode_x509(), + }, #[cfg(feature = "secp256k1")] - PublicKey::Secp256k1(key) => - keys_proto::PublicKey { - r#type: keys_proto::KeyType::Secp256k1 as i32, - data: key.encode().to_vec() - } + PublicKey::Secp256k1(key) => keys_proto::PublicKey { + r#type: keys_proto::KeyType::Secp256k1 as i32, + data: key.encode().to_vec(), + }, }; let mut buf = Vec::with_capacity(public_key.encoded_len()); - public_key.encode(&mut buf).expect("Vec provides capacity as needed"); + public_key + .encode(&mut buf) + .expect("Vec provides capacity as needed"); buf } @@ -261,7 +260,7 @@ impl PublicKey { match key_type { keys_proto::KeyType::Ed25519 => { ed25519::PublicKey::decode(&pubkey.data).map(PublicKey::Ed25519) - }, + } #[cfg(not(target_arch = "wasm32"))] keys_proto::KeyType::Rsa => { rsa::PublicKey::decode_x509(&pubkey.data).map(PublicKey::Rsa) @@ -270,7 +269,7 @@ impl PublicKey { keys_proto::KeyType::Rsa => { log::debug!("support for RSA was disabled at compile-time"); Err(DecodingError::new("Unsupported")) - }, + } #[cfg(feature = "secp256k1")] keys_proto::KeyType::Secp256k1 => { secp256k1::PublicKey::decode(&pubkey.data).map(PublicKey::Secp256k1) @@ -311,7 +310,8 @@ mod tests { fn keypair_from_protobuf_encoding() { // E.g. retrieved from an IPFS config file. let base_64_encoded = "CAESQL6vdKQuznQosTrW7FWI9At+XX7EBf0BnZLhb6w+N+XSQSdfInl6c7U4NuxXJlhKcRBlBw9d0tj2dfBIVf6mcPA="; - let expected_peer_id = PeerId::from_str("12D3KooWEChVMMMzV8acJ53mJHrw1pQ27UAGkCxWXLJutbeUMvVu").unwrap(); + let expected_peer_id = + PeerId::from_str("12D3KooWEChVMMMzV8acJ53mJHrw1pQ27UAGkCxWXLJutbeUMvVu").unwrap(); let encoded = base64::decode(base_64_encoded).unwrap(); diff --git a/core/src/identity/ed25519.rs b/core/src/identity/ed25519.rs index f606a82b19b..5782ac788cb 100644 --- a/core/src/identity/ed25519.rs +++ b/core/src/identity/ed25519.rs @@ -20,12 +20,12 @@ //! Ed25519 keys. +use super::error::DecodingError; +use core::fmt; use ed25519_dalek::{self as ed25519, Signer as _, Verifier as _}; use rand::RngCore; use std::convert::TryFrom; -use super::error::DecodingError; use zeroize::Zeroize; -use core::fmt; /// An Ed25519 keypair. pub struct Keypair(ed25519::Keypair); @@ -49,7 +49,10 @@ impl Keypair { /// Note that this binary format is the same as `ed25519_dalek`'s and `ed25519_zebra`'s. 
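A short sketch (not part of the patch) of the `encode`/`decode` round trip this note refers to; it mirrors the `ed25519_keypair_encode_decode` property test further down and assumes the `libp2p_core::identity::ed25519` path:

use libp2p_core::identity::ed25519::Keypair;

fn main() {
    let kp1 = Keypair::generate();
    let mut bytes = kp1.encode(); // 64 bytes: secret key followed by public key
    let kp2 = Keypair::decode(&mut bytes).expect("valid keypair bytes");
    assert_eq!(kp1.public(), kp2.public());
    assert!(bytes.iter().all(|b| *b == 0)); // decode zeroizes its input
}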
pub fn decode(kp: &mut [u8]) -> Result { ed25519::Keypair::from_bytes(kp) - .map(|k| { kp.zeroize(); Keypair(k) }) + .map(|k| { + kp.zeroize(); + Keypair(k) + }) .map_err(|e| DecodingError::new("Ed25519 keypair").source(e)) } @@ -72,7 +75,9 @@ impl Keypair { impl fmt::Debug for Keypair { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Keypair").field("public", &self.0.public).finish() + f.debug_struct("Keypair") + .field("public", &self.0.public) + .finish() } } @@ -80,7 +85,8 @@ impl Clone for Keypair { fn clone(&self) -> Keypair { let mut sk_bytes = self.0.secret.to_bytes(); let secret = SecretKey::from_bytes(&mut sk_bytes) - .expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k").0; + .expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k") + .0; let public = ed25519::PublicKey::from_bytes(&self.0.public.to_bytes()) .expect("ed25519::PublicKey::from_bytes(to_bytes(k)) != k"); Keypair(ed25519::Keypair { secret, public }) @@ -99,7 +105,10 @@ impl From for Keypair { fn from(sk: SecretKey) -> Keypair { let secret: ed25519::ExpandedSecretKey = (&sk.0).into(); let public = ed25519::PublicKey::from(&secret); - Keypair(ed25519::Keypair { secret: sk.0, public }) + Keypair(ed25519::Keypair { + secret: sk.0, + public, + }) } } @@ -120,7 +129,9 @@ impl fmt::Debug for PublicKey { impl PublicKey { /// Verify the Ed25519 signature on a message using the public key. pub fn verify(&self, msg: &[u8], sig: &[u8]) -> bool { - ed25519::Signature::try_from(sig).and_then(|s| self.0.verify(msg, &s)).is_ok() + ed25519::Signature::try_from(sig) + .and_then(|s| self.0.verify(msg, &s)) + .is_ok() } /// Encode the public key into a byte array in compressed form, i.e. @@ -150,8 +161,7 @@ impl AsRef<[u8]> for SecretKey { impl Clone for SecretKey { fn clone(&self) -> SecretKey { let mut sk_bytes = self.0.to_bytes(); - Self::from_bytes(&mut sk_bytes) - .expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k") + Self::from_bytes(&mut sk_bytes).expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k") } } @@ -166,8 +176,11 @@ impl SecretKey { pub fn generate() -> SecretKey { let mut bytes = [0u8; 32]; rand::thread_rng().fill_bytes(&mut bytes); - SecretKey(ed25519::SecretKey::from_bytes(&bytes) - .expect("this returns `Err` only if the length is wrong; the length is correct; qed")) + SecretKey( + ed25519::SecretKey::from_bytes(&bytes).expect( + "this returns `Err` only if the length is wrong; the length is correct; qed", + ), + ) } /// Create an Ed25519 secret key from a byte slice, zeroing the input on success. 
@@ -188,9 +201,7 @@ mod tests { use quickcheck::*; fn eq_keypairs(kp1: &Keypair, kp2: &Keypair) -> bool { - kp1.public() == kp2.public() - && - kp1.0.secret.as_bytes() == kp2.0.secret.as_bytes() + kp1.public() == kp2.public() && kp1.0.secret.as_bytes() == kp2.0.secret.as_bytes() } #[test] @@ -199,9 +210,7 @@ mod tests { let kp1 = Keypair::generate(); let mut kp1_enc = kp1.encode(); let kp2 = Keypair::decode(&mut kp1_enc).unwrap(); - eq_keypairs(&kp1, &kp2) - && - kp1_enc.iter().all(|b| *b == 0) + eq_keypairs(&kp1, &kp2) && kp1_enc.iter().all(|b| *b == 0) } QuickCheck::new().tests(10).quickcheck(prop as fn() -> _); } @@ -212,9 +221,7 @@ mod tests { let kp1 = Keypair::generate(); let mut sk = kp1.0.secret.to_bytes(); let kp2 = Keypair::from(SecretKey::from_bytes(&mut sk).unwrap()); - eq_keypairs(&kp1, &kp2) - && - sk == [0u8; 32] + eq_keypairs(&kp1, &kp2) && sk == [0u8; 32] } QuickCheck::new().tests(10).quickcheck(prop as fn() -> _); } diff --git a/core/src/identity/error.rs b/core/src/identity/error.rs index 8fd1b1b9be9..76f41278d5d 100644 --- a/core/src/identity/error.rs +++ b/core/src/identity/error.rs @@ -27,16 +27,22 @@ use std::fmt; #[derive(Debug)] pub struct DecodingError { msg: String, - source: Option> + source: Option>, } impl DecodingError { pub(crate) fn new(msg: S) -> Self { - Self { msg: msg.to_string(), source: None } + Self { + msg: msg.to_string(), + source: None, + } } pub(crate) fn source(self, source: impl Error + Send + Sync + 'static) -> Self { - Self { source: Some(Box::new(source)), .. self } + Self { + source: Some(Box::new(source)), + ..self + } } } @@ -56,17 +62,23 @@ impl Error for DecodingError { #[derive(Debug)] pub struct SigningError { msg: String, - source: Option> + source: Option>, } /// An error during encoding of key material. impl SigningError { pub(crate) fn new(msg: S) -> Self { - Self { msg: msg.to_string(), source: None } + Self { + msg: msg.to_string(), + source: None, + } } pub(crate) fn source(self, source: impl Error + Send + Sync + 'static) -> Self { - Self { source: Some(Box::new(source)), .. self } + Self { + source: Some(Box::new(source)), + ..self + } } } @@ -81,4 +93,3 @@ impl Error for SigningError { self.source.as_ref().map(|s| &**s as &dyn Error) } } - diff --git a/core/src/identity/rsa.rs b/core/src/identity/rsa.rs index ffbfb975ff0..752bb156764 100644 --- a/core/src/identity/rsa.rs +++ b/core/src/identity/rsa.rs @@ -20,12 +20,12 @@ //! RSA keys. 
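// Illustrative sketch, not part of the patch: signing with an RSA keypair and
// round-tripping the public key through the X.509 SubjectPublicKeyInfo encoding
// implemented in this module. The keypair `kp` is assumed to already exist
// (e.g. decoded from a PKCS#8 document).
fn rsa_sketch(kp: &Keypair, msg: &[u8]) -> Result<(), SigningError> {
    let signature = kp.sign(msg)?;
    let spki = kp.public().encode_x509();      // DER-encoded SubjectPublicKeyInfo
    let pk = PublicKey::decode_x509(&spki).expect("roundtrip of a freshly encoded key");
    assert!(pk.verify(msg, &signature));
    Ok(())
}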
-use asn1_der::typed::{DerEncodable, DerDecodable, DerTypeView, Sequence}; -use asn1_der::{DerObject, Asn1DerError, Asn1DerErrorVariant, Sink, VecBacking}; use super::error::*; +use asn1_der::typed::{DerDecodable, DerEncodable, DerTypeView, Sequence}; +use asn1_der::{Asn1DerError, Asn1DerErrorVariant, DerObject, Sink, VecBacking}; use ring::rand::SystemRandom; -use ring::signature::{self, RsaKeyPair, RSA_PKCS1_SHA256, RSA_PKCS1_2048_8192_SHA256}; use ring::signature::KeyPair; +use ring::signature::{self, RsaKeyPair, RSA_PKCS1_2048_8192_SHA256, RSA_PKCS1_SHA256}; use std::{fmt, sync::Arc}; use zeroize::Zeroize; @@ -56,7 +56,7 @@ impl Keypair { let rng = SystemRandom::new(); match self.0.sign(&RSA_PKCS1_SHA256, &rng, &data, &mut signature) { Ok(()) => Ok(signature), - Err(e) => Err(SigningError::new("RSA").source(e)) + Err(e) => Err(SigningError::new("RSA").source(e)), } } } @@ -89,12 +89,14 @@ impl PublicKey { let spki = Asn1SubjectPublicKeyInfo { algorithmIdentifier: Asn1RsaEncryption { algorithm: Asn1OidRsaEncryption, - parameters: () + parameters: (), }, - subjectPublicKey: Asn1SubjectPublicKey(self.clone()) + subjectPublicKey: Asn1SubjectPublicKey(self.clone()), }; let mut buf = Vec::new(); - let buf = spki.encode(&mut buf).map(|_| buf) + let buf = spki + .encode(&mut buf) + .map(|_| buf) .expect("RSA X.509 public key encoding failed."); buf } @@ -127,7 +129,7 @@ impl fmt::Debug for PublicKey { /// A raw ASN1 OID. #[derive(Copy, Clone)] struct Asn1RawOid<'a> { - object: DerObject<'a> + object: DerObject<'a>, } impl<'a> Asn1RawOid<'a> { @@ -179,7 +181,7 @@ impl Asn1OidRsaEncryption { /// /// [RFC-3279]: https://tools.ietf.org/html/rfc3279#section-2.3.1 /// [RFC-5280]: https://tools.ietf.org/html/rfc5280#section-4.1 - const OID: [u8;9] = [ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 ]; + const OID: [u8; 9] = [0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01]; } impl DerEncodable for Asn1OidRsaEncryption { @@ -194,7 +196,7 @@ impl DerDecodable<'_> for Asn1OidRsaEncryption { oid if oid == Self::OID => Ok(Self), _ => Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData( "DER object is not the 'rsaEncryption' identifier.", - ))) + ))), } } } @@ -202,7 +204,7 @@ impl DerDecodable<'_> for Asn1OidRsaEncryption { /// The ASN.1 AlgorithmIdentifier for "rsaEncryption". 
struct Asn1RsaEncryption { algorithm: Asn1OidRsaEncryption, - parameters: () + parameters: (), } impl DerEncodable for Asn1RsaEncryption { @@ -211,7 +213,9 @@ impl DerEncodable for Asn1RsaEncryption { let algorithm = self.algorithm.der_object(VecBacking(&mut algorithm_buf))?; let mut parameters_buf = Vec::new(); - let parameters = self.parameters.der_object(VecBacking(&mut parameters_buf))?; + let parameters = self + .parameters + .der_object(VecBacking(&mut parameters_buf))?; Sequence::write(&[algorithm, parameters], sink) } @@ -221,7 +225,7 @@ impl DerDecodable<'_> for Asn1RsaEncryption { fn load(object: DerObject<'_>) -> Result { let seq: Sequence = Sequence::load(object)?; - Ok(Self{ + Ok(Self { algorithm: seq.get_as(0)?, parameters: seq.get_as(1)?, }) @@ -248,9 +252,9 @@ impl DerEncodable for Asn1SubjectPublicKey { impl DerDecodable<'_> for Asn1SubjectPublicKey { fn load(object: DerObject<'_>) -> Result { if object.tag() != 3 { - return Err(Asn1DerError::new( - Asn1DerErrorVariant::InvalidData("DER object tag is not the bit string tag."), - )); + return Err(Asn1DerError::new(Asn1DerErrorVariant::InvalidData( + "DER object tag is not the bit string tag.", + ))); } let pk_der: Vec = object.value().into_iter().skip(1).cloned().collect(); @@ -264,13 +268,15 @@ impl DerDecodable<'_> for Asn1SubjectPublicKey { #[allow(non_snake_case)] struct Asn1SubjectPublicKeyInfo { algorithmIdentifier: Asn1RsaEncryption, - subjectPublicKey: Asn1SubjectPublicKey + subjectPublicKey: Asn1SubjectPublicKey, } impl DerEncodable for Asn1SubjectPublicKeyInfo { fn encode(&self, sink: &mut S) -> Result<(), Asn1DerError> { let mut identifier_buf = Vec::new(); - let identifier = self.algorithmIdentifier.der_object(VecBacking(&mut identifier_buf))?; + let identifier = self + .algorithmIdentifier + .der_object(VecBacking(&mut identifier_buf))?; let mut key_buf = Vec::new(); let key = self.subjectPublicKey.der_object(VecBacking(&mut key_buf))?; @@ -340,6 +346,8 @@ mod tests { fn prop(SomeKeypair(kp): SomeKeypair, msg: Vec) -> Result { kp.sign(&msg).map(|s| kp.public().verify(&msg, &s)) } - QuickCheck::new().tests(10).quickcheck(prop as fn(_,_) -> _); + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _) -> _); } } diff --git a/core/src/identity/secp256k1.rs b/core/src/identity/secp256k1.rs index be887064131..2c3aaf89a51 100644 --- a/core/src/identity/secp256k1.rs +++ b/core/src/identity/secp256k1.rs @@ -20,18 +20,18 @@ //! Secp256k1 keys. +use super::error::{DecodingError, SigningError}; use asn1_der::typed::{DerDecodable, Sequence}; -use sha2::{Digest as ShaDigestTrait, Sha256}; +use core::fmt; use libsecp256k1::{Message, Signature}; -use super::error::{DecodingError, SigningError}; +use sha2::{Digest as ShaDigestTrait, Sha256}; use zeroize::Zeroize; -use core::fmt; /// A Secp256k1 keypair. 
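// Illustrative sketch, not part of the patch: producing a DER-encoded ECDSA
// signature over the SHA-256 digest of a message, using the `sign_hash` method
// defined further below on `SecretKey`.
fn secp256k1_sign_sketch(sk: &SecretKey, msg: &[u8]) -> Result<Vec<u8>, SigningError> {
    let digest = Sha256::digest(msg); // 32-byte digest, as `sign_hash` expects
    sk.sign_hash(digest.as_ref())
}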
#[derive(Clone)] pub struct Keypair { secret: SecretKey, - public: PublicKey + public: PublicKey, } impl Keypair { @@ -53,7 +53,9 @@ impl Keypair { impl fmt::Debug for Keypair { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Keypair").field("public", &self.public).finish() + f.debug_struct("Keypair") + .field("public", &self.public) + .finish() } } @@ -110,10 +112,11 @@ impl SecretKey { let der_obj = der.as_mut(); let obj: Sequence = DerDecodable::decode(der_obj) .map_err(|e| DecodingError::new("Secp256k1 DER ECPrivateKey").source(e))?; - let sk_obj = obj.get(1) + let sk_obj = obj + .get(1) .map_err(|e| DecodingError::new("Not enough elements in DER").source(e))?; - let mut sk_bytes: Vec = asn1_der::typed::DerDecodable::load(sk_obj) - .map_err(DecodingError::new)?; + let mut sk_bytes: Vec = + asn1_der::typed::DerDecodable::load(sk_obj).map_err(DecodingError::new)?; let sk = SecretKey::from_bytes(&mut sk_bytes)?; sk_bytes.zeroize(); der_obj.zeroize(); @@ -138,7 +141,11 @@ impl SecretKey { pub fn sign_hash(&self, msg: &[u8]) -> Result, SigningError> { let m = Message::parse_slice(msg) .map_err(|_| SigningError::new("failed to parse secp256k1 digest"))?; - Ok(libsecp256k1::sign(&m, &self.0).0.serialize_der().as_ref().into()) + Ok(libsecp256k1::sign(&m, &self.0) + .0 + .serialize_der() + .as_ref() + .into()) } } diff --git a/core/src/lib.rs b/core/src/lib.rs index 844fd2a23bc..60727c52062 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -54,16 +54,16 @@ pub mod network; pub mod transport; pub mod upgrade; +pub use connection::{Connected, ConnectedPoint, Endpoint}; +pub use identity::PublicKey; pub use multiaddr::Multiaddr; pub use multihash; pub use muxing::StreamMuxer; +pub use network::Network; pub use peer_id::PeerId; -pub use identity::PublicKey; -pub use transport::Transport; pub use translation::address_translation; -pub use upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo, UpgradeError, ProtocolName}; -pub use connection::{Connected, Endpoint, ConnectedPoint}; -pub use network::Network; +pub use transport::Transport; +pub use upgrade::{InboundUpgrade, OutboundUpgrade, ProtocolName, UpgradeError, UpgradeInfo}; use std::{future::Future, pin::Pin}; diff --git a/core/src/muxing.rs b/core/src/muxing.rs index c8ae456fac2..12beb51d9dd 100644 --- a/core/src/muxing.rs +++ b/core/src/muxing.rs @@ -55,7 +55,12 @@ use fnv::FnvHashMap; use futures::{future, prelude::*, task::Context, task::Poll}; use multiaddr::Multiaddr; use parking_lot::Mutex; -use std::{io, ops::Deref, fmt, pin::Pin, sync::atomic::{AtomicUsize, Ordering}}; +use std::{ + fmt, io, + ops::Deref, + pin::Pin, + sync::atomic::{AtomicUsize, Ordering}, +}; pub use self::singleton::SingletonMuxer; @@ -95,7 +100,10 @@ pub trait StreamMuxer { /// work, such as processing incoming packets and polling timers. /// /// An error can be generated if the connection has been closed. - fn poll_event(&self, cx: &mut Context<'_>) -> Poll, Self::Error>>; + fn poll_event( + &self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>; /// Opens a new outgoing substream, and produces the equivalent to a future that will be /// resolved when it becomes available. @@ -113,8 +121,11 @@ pub trait StreamMuxer { /// /// May panic or produce an undefined result if an earlier polling of the same substream /// returned `Ready` or `Err`. 
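// Illustrative sketch, not part of the patch: driving the low-level outbound
// substream API by hand with `future::poll_fn`. The reference-counted helpers
// further below (`outbound_from_ref_and_wrap` and friends) follow the same
// open / poll / destroy sequence.
async fn open_outbound_sketch<M: StreamMuxer>(muxer: &M) -> Result<M::Substream, M::Error> {
    let mut attempt = muxer.open_outbound();
    let outcome = future::poll_fn(|cx| muxer.poll_outbound(cx, &mut attempt)).await;
    muxer.destroy_outbound(attempt); // always destroy the finished attempt
    outcome
}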
- fn poll_outbound(&self, cx: &mut Context<'_>, s: &mut Self::OutboundSubstream) - -> Poll>; + fn poll_outbound( + &self, + cx: &mut Context<'_>, + s: &mut Self::OutboundSubstream, + ) -> Poll>; /// Destroys an outbound substream future. Use this after the outbound substream has finished, /// or if you want to interrupt it. @@ -131,8 +142,12 @@ pub trait StreamMuxer { /// /// An error can be generated if the connection has been closed, or if a protocol misbehaviour /// happened. - fn read_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream, buf: &mut [u8]) - -> Poll>; + fn read_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + buf: &mut [u8], + ) -> Poll>; /// Write data to a substream. The behaviour is the same as `futures::AsyncWrite::poll_write`. /// @@ -145,8 +160,12 @@ pub trait StreamMuxer { /// /// It is incorrect to call this method on a substream if you called `shutdown_substream` on /// this substream earlier. - fn write_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream, buf: &[u8]) - -> Poll>; + fn write_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + buf: &[u8], + ) -> Poll>; /// Flushes a substream. The behaviour is the same as `futures::AsyncWrite::poll_flush`. /// @@ -158,8 +177,11 @@ pub trait StreamMuxer { /// call this method may be notified. /// /// > **Note**: This method may be implemented as a call to `flush_all`. - fn flush_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream) - -> Poll>; + fn flush_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll>; /// Attempts to shut down the writing side of a substream. The behaviour is similar to /// `AsyncWrite::poll_close`. @@ -172,8 +194,11 @@ pub trait StreamMuxer { /// /// An error can be generated if the connection has been closed, or if a protocol misbehaviour /// happened. - fn shutdown_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream) - -> Poll>; + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll>; /// Destroys a substream. fn destroy_substream(&self, s: Self::Substream); @@ -246,14 +271,12 @@ where P::Target: StreamMuxer, { let muxer2 = muxer.clone(); - future::poll_fn(move |cx| muxer.poll_event(cx)) - .map_ok(|event| { - match event { - StreamMuxerEvent::InboundSubstream(substream) => - StreamMuxerEvent::InboundSubstream(substream_from_ref(muxer2, substream)), - StreamMuxerEvent::AddressChange(addr) => StreamMuxerEvent::AddressChange(addr), - } - }) + future::poll_fn(move |cx| muxer.poll_event(cx)).map_ok(|event| match event { + StreamMuxerEvent::InboundSubstream(substream) => { + StreamMuxerEvent::InboundSubstream(substream_from_ref(muxer2, substream)) + } + StreamMuxerEvent::AddressChange(addr) => StreamMuxerEvent::AddressChange(addr), + }) } /// Same as `outbound_from_ref`, but wraps the output in an object that @@ -336,7 +359,8 @@ where // We use a `this` because the compiler isn't smart enough to allow mutably borrowing // multiple different fields from the `Pin` at the same time. 
let this = &mut *self; - this.muxer.poll_outbound(cx, this.outbound.as_mut().expect("outbound was empty")) + this.muxer + .poll_outbound(cx, this.outbound.as_mut().expect("outbound was empty")) } } @@ -408,7 +432,11 @@ where P: Deref, P::Target: StreamMuxer, { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { // We use a `this` because the compiler isn't smart enough to allow mutably borrowing // multiple different fields from the `Pin` at the same time. let this = &mut *self; @@ -423,7 +451,11 @@ where P: Deref, P::Target: StreamMuxer, { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { // We use a `this` because the compiler isn't smart enough to allow mutably borrowing // multiple different fields from the `Pin` at the same time. let this = &mut *self; @@ -440,20 +472,16 @@ where let s = this.substream.as_mut().expect("substream was empty"); loop { match this.shutdown_state { - ShutdownState::Shutdown => { - match this.muxer.shutdown_substream(cx, s) { - Poll::Ready(Ok(())) => this.shutdown_state = ShutdownState::Flush, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), - Poll::Pending => return Poll::Pending, - } - } - ShutdownState::Flush => { - match this.muxer.flush_substream(cx, s) { - Poll::Ready(Ok(())) => this.shutdown_state = ShutdownState::Done, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), - Poll::Pending => return Poll::Pending, - } - } + ShutdownState::Shutdown => match this.muxer.shutdown_substream(cx, s) { + Poll::Ready(Ok(())) => this.shutdown_state = ShutdownState::Flush, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), + Poll::Pending => return Poll::Pending, + }, + ShutdownState::Flush => match this.muxer.flush_substream(cx, s) { + Poll::Ready(Ok(())) => this.shutdown_state = ShutdownState::Done, + Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), + Poll::Pending => return Poll::Pending, + }, ShutdownState::Done => { return Poll::Ready(Ok(())); } @@ -477,13 +505,18 @@ where P::Target: StreamMuxer, { fn drop(&mut self) { - self.muxer.destroy_substream(self.substream.take().expect("substream was empty")) + self.muxer + .destroy_substream(self.substream.take().expect("substream was empty")) } } /// Abstract `StreamMuxer`. 
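// Illustrative sketch, not part of the patch: erasing the concrete muxer type
// behind `StreamMuxerBox` so that different transports can share one muxer
// type. The `StreamMuxerBox::new` constructor and its exact bounds are an
// assumption here, not shown in this hunk.
fn boxed_muxer_sketch<M>(muxer: M) -> StreamMuxerBox
where
    M: StreamMuxer + Send + Sync + 'static,
    M::Substream: Send,
    M::OutboundSubstream: Send,
{
    StreamMuxerBox::new(muxer)
}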
pub struct StreamMuxerBox { - inner: Box + Send + Sync>, + inner: Box< + dyn StreamMuxer + + Send + + Sync, + >, } impl StreamMuxerBox { @@ -514,7 +547,10 @@ impl StreamMuxer for StreamMuxerBox { type Error = io::Error; #[inline] - fn poll_event(&self, cx: &mut Context<'_>) -> Poll, Self::Error>> { + fn poll_event( + &self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { self.inner.poll_event(cx) } @@ -524,7 +560,11 @@ impl StreamMuxer for StreamMuxerBox { } #[inline] - fn poll_outbound(&self, cx: &mut Context<'_>, s: &mut Self::OutboundSubstream) -> Poll> { + fn poll_outbound( + &self, + cx: &mut Context<'_>, + s: &mut Self::OutboundSubstream, + ) -> Poll> { self.inner.poll_outbound(cx, s) } @@ -534,22 +574,40 @@ impl StreamMuxer for StreamMuxerBox { } #[inline] - fn read_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream, buf: &mut [u8]) -> Poll> { + fn read_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + buf: &mut [u8], + ) -> Poll> { self.inner.read_substream(cx, s, buf) } #[inline] - fn write_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream, buf: &[u8]) -> Poll> { + fn write_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + buf: &[u8], + ) -> Poll> { self.inner.write_substream(cx, s, buf) } #[inline] - fn flush_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream) -> Poll> { + fn flush_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { self.inner.flush_substream(cx, s) } #[inline] - fn shutdown_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream) -> Poll> { + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { self.inner.shutdown_substream(cx, s) } @@ -569,7 +627,10 @@ impl StreamMuxer for StreamMuxerBox { } } -struct Wrap where T: StreamMuxer { +struct Wrap +where + T: StreamMuxer, +{ inner: T, substreams: Mutex>, next_substream: AtomicUsize, @@ -586,11 +647,15 @@ where type Error = io::Error; #[inline] - fn poll_event(&self, cx: &mut Context<'_>) -> Poll, Self::Error>> { + fn poll_event( + &self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { let substream = match self.inner.poll_event(cx) { Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(StreamMuxerEvent::AddressChange(a))) => - return Poll::Ready(Ok(StreamMuxerEvent::AddressChange(a))), + Poll::Ready(Ok(StreamMuxerEvent::AddressChange(a))) => { + return Poll::Ready(Ok(StreamMuxerEvent::AddressChange(a))) + } Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(s))) => s, Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), }; @@ -615,7 +680,10 @@ where substream: &mut Self::OutboundSubstream, ) -> Poll> { let mut list = self.outbound.lock(); - let substream = match self.inner.poll_outbound(cx, list.get_mut(substream).unwrap()) { + let substream = match self + .inner + .poll_outbound(cx, list.get_mut(substream).unwrap()) + { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(s)) => s, Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())), @@ -628,37 +696,65 @@ where #[inline] fn destroy_outbound(&self, substream: Self::OutboundSubstream) { let mut list = self.outbound.lock(); - self.inner.destroy_outbound(list.remove(&substream).unwrap()) + self.inner + .destroy_outbound(list.remove(&substream).unwrap()) } #[inline] - fn read_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream, buf: &mut [u8]) -> Poll> { + fn read_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + buf: &mut [u8], + ) -> Poll> 
{ let mut list = self.substreams.lock(); - self.inner.read_substream(cx, list.get_mut(s).unwrap(), buf).map_err(|e| e.into()) + self.inner + .read_substream(cx, list.get_mut(s).unwrap(), buf) + .map_err(|e| e.into()) } #[inline] - fn write_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream, buf: &[u8]) -> Poll> { + fn write_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + buf: &[u8], + ) -> Poll> { let mut list = self.substreams.lock(); - self.inner.write_substream(cx, list.get_mut(s).unwrap(), buf).map_err(|e| e.into()) + self.inner + .write_substream(cx, list.get_mut(s).unwrap(), buf) + .map_err(|e| e.into()) } #[inline] - fn flush_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream) -> Poll> { + fn flush_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { let mut list = self.substreams.lock(); - self.inner.flush_substream(cx, list.get_mut(s).unwrap()).map_err(|e| e.into()) + self.inner + .flush_substream(cx, list.get_mut(s).unwrap()) + .map_err(|e| e.into()) } #[inline] - fn shutdown_substream(&self, cx: &mut Context<'_>, s: &mut Self::Substream) -> Poll> { + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { let mut list = self.substreams.lock(); - self.inner.shutdown_substream(cx, list.get_mut(s).unwrap()).map_err(|e| e.into()) + self.inner + .shutdown_substream(cx, list.get_mut(s).unwrap()) + .map_err(|e| e.into()) } #[inline] fn destroy_substream(&self, substream: Self::Substream) { let mut list = self.substreams.lock(); - self.inner.destroy_substream(list.remove(&substream).unwrap()) + self.inner + .destroy_substream(list.remove(&substream).unwrap()) } #[inline] diff --git a/core/src/muxing/singleton.rs b/core/src/muxing/singleton.rs index 47701f07139..749e9cd673e 100644 --- a/core/src/muxing/singleton.rs +++ b/core/src/muxing/singleton.rs @@ -18,11 +18,20 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{connection::Endpoint, muxing::{StreamMuxer, StreamMuxerEvent}}; +use crate::{ + connection::Endpoint, + muxing::{StreamMuxer, StreamMuxerEvent}, +}; use futures::prelude::*; use parking_lot::Mutex; -use std::{io, pin::Pin, sync::atomic::{AtomicBool, Ordering}, task::Context, task::Poll}; +use std::{ + io, + pin::Pin, + sync::atomic::{AtomicBool, Ordering}, + task::Context, + task::Poll, +}; /// Implementation of `StreamMuxer` that allows only one substream on top of a connection, /// yielding the connection itself. 
@@ -65,7 +74,10 @@ where type OutboundSubstream = OutboundSubstream; type Error = io::Error; - fn poll_event(&self, _: &mut Context<'_>) -> Poll, io::Error>> { + fn poll_event( + &self, + _: &mut Context<'_>, + ) -> Poll, io::Error>> { match self.endpoint { Endpoint::Dialer => return Poll::Pending, Endpoint::Listener => {} @@ -82,7 +94,11 @@ where OutboundSubstream {} } - fn poll_outbound(&self, _: &mut Context<'_>, _: &mut Self::OutboundSubstream) -> Poll> { + fn poll_outbound( + &self, + _: &mut Context<'_>, + _: &mut Self::OutboundSubstream, + ) -> Poll> { match self.endpoint { Endpoint::Listener => return Poll::Pending, Endpoint::Dialer => {} @@ -95,27 +111,43 @@ where } } - fn destroy_outbound(&self, _: Self::OutboundSubstream) { - } + fn destroy_outbound(&self, _: Self::OutboundSubstream) {} - fn read_substream(&self, cx: &mut Context<'_>, _: &mut Self::Substream, buf: &mut [u8]) -> Poll> { + fn read_substream( + &self, + cx: &mut Context<'_>, + _: &mut Self::Substream, + buf: &mut [u8], + ) -> Poll> { AsyncRead::poll_read(Pin::new(&mut *self.inner.lock()), cx, buf) } - fn write_substream(&self, cx: &mut Context<'_>, _: &mut Self::Substream, buf: &[u8]) -> Poll> { + fn write_substream( + &self, + cx: &mut Context<'_>, + _: &mut Self::Substream, + buf: &[u8], + ) -> Poll> { AsyncWrite::poll_write(Pin::new(&mut *self.inner.lock()), cx, buf) } - fn flush_substream(&self, cx: &mut Context<'_>, _: &mut Self::Substream) -> Poll> { + fn flush_substream( + &self, + cx: &mut Context<'_>, + _: &mut Self::Substream, + ) -> Poll> { AsyncWrite::poll_flush(Pin::new(&mut *self.inner.lock()), cx) } - fn shutdown_substream(&self, cx: &mut Context<'_>, _: &mut Self::Substream) -> Poll> { + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + _: &mut Self::Substream, + ) -> Poll> { AsyncWrite::poll_close(Pin::new(&mut *self.inner.lock()), cx) } - fn destroy_substream(&self, _: Self::Substream) { - } + fn destroy_substream(&self, _: Self::Substream) {} fn close(&self, cx: &mut Context<'_>) -> Poll> { // The `StreamMuxer` trait requires that `close()` implies `flush_all()`. 
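// Illustrative sketch, not part of the patch: `close` is a poll-style method,
// so callers typically drive it to completion with `future::poll_fn`. Per the
// note above, completing it also flushes buffered substream data.
async fn close_muxer_sketch<M: StreamMuxer>(muxer: &M) -> Result<(), M::Error> {
    futures::future::poll_fn(|cx| muxer.close(cx)).await
}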
diff --git a/core/src/network.rs b/core/src/network.rs index c069171c7f3..784c1e01ca7 100644 --- a/core/src/network.rs +++ b/core/src/network.rs @@ -21,45 +21,30 @@ mod event; pub mod peer; -pub use crate::connection::{ConnectionLimits, ConnectionCounters}; -pub use event::{NetworkEvent, IncomingConnection}; +pub use crate::connection::{ConnectionCounters, ConnectionLimits}; +pub use event::{IncomingConnection, NetworkEvent}; pub use peer::Peer; use crate::{ - ConnectedPoint, - Executor, - Multiaddr, - PeerId, connection::{ - ConnectionId, - ConnectionLimit, - ConnectionHandler, - IntoConnectionHandler, - IncomingInfo, - OutgoingInfo, - ListenersEvent, - ListenerId, - ListenersStream, - PendingConnectionError, - Substream, - handler::{ - THandlerInEvent, - THandlerOutEvent, - }, + handler::{THandlerInEvent, THandlerOutEvent}, manager::ManagerConfig, pool::{Pool, PoolEvent}, + ConnectionHandler, ConnectionId, ConnectionLimit, IncomingInfo, IntoConnectionHandler, + ListenerId, ListenersEvent, ListenersStream, OutgoingInfo, PendingConnectionError, + Substream, }, muxing::StreamMuxer, transport::{Transport, TransportError}, + ConnectedPoint, Executor, Multiaddr, PeerId, }; -use fnv::{FnvHashMap}; -use futures::{prelude::*, future}; +use fnv::FnvHashMap; +use futures::{future, prelude::*}; use smallvec::SmallVec; use std::{ collections::hash_map, convert::TryFrom as _, - error, - fmt, + error, fmt, num::NonZeroUsize, pin::Pin, task::{Context, Poll}, @@ -95,8 +80,7 @@ where dialing: FnvHashMap>, } -impl fmt::Debug for - Network +impl fmt::Debug for Network where TTrans: fmt::Debug + Transport, THandler: fmt::Debug + ConnectionHandler, @@ -111,16 +95,14 @@ where } } -impl Unpin for - Network +impl Unpin for Network where TTrans: Transport, THandler: IntoConnectionHandler, { } -impl - Network +impl Network where TTrans: Transport, THandler: IntoConnectionHandler, @@ -131,8 +113,7 @@ where } } -impl - Network +impl Network where TTrans: Transport + Clone, TMuxer: StreamMuxer, @@ -142,11 +123,7 @@ where THandler::Handler: ConnectionHandler> + Send, { /// Creates a new node events stream. - pub fn new( - transport: TTrans, - local_peer_id: PeerId, - config: NetworkConfig, - ) -> Self { + pub fn new(transport: TTrans, local_peer_id: PeerId, config: NetworkConfig) -> Self { Network { local_peer_id, listeners: ListenersStream::new(transport), @@ -161,7 +138,10 @@ where } /// Start listening on the given multiaddress. - pub fn listen_on(&mut self, addr: Multiaddr) -> Result> { + pub fn listen_on( + &mut self, + addr: Multiaddr, + ) -> Result> { self.listeners.listen_on(addr) } @@ -189,14 +169,14 @@ where /// other than the peer who reported the `observed_addr`. /// /// The translation is transport-specific. See [`Transport::address_translation`]. - pub fn address_translation<'a>(&'a self, observed_addr: &'a Multiaddr) - -> Vec + pub fn address_translation<'a>(&'a self, observed_addr: &'a Multiaddr) -> Vec where TMuxer: 'a, THandler: 'a, { let transport = self.listeners.transport(); - let mut addrs: Vec<_> = self.listen_addrs() + let mut addrs: Vec<_> = self + .listen_addrs() .filter_map(move |server| transport.address_translation(server, observed_addr)) .collect(); @@ -218,8 +198,11 @@ where /// The given `handler` will be used to create the /// [`Connection`](crate::connection::Connection) upon success and the /// connection ID is returned. 
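// Illustrative call-shape sketch, not part of the patch; `network`, `addr` and
// `handler` are assumed to be in scope. If `addr` ends in a `/p2p/...`
// protocol, only the peer with that id is accepted on the resulting connection;
// otherwise any peer id the remote identifies as is accepted:
//
//     let connection_id = network.dial(&addr, handler)?;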
- pub fn dial(&mut self, address: &Multiaddr, handler: THandler) - -> Result + pub fn dial( + &mut self, + address: &Multiaddr, + handler: THandler, + ) -> Result where TTrans: Transport, TTrans::Error: Send + 'static, @@ -238,21 +221,29 @@ where address: address.clone(), handler, remaining: Vec::new(), - }) + }); } } // The address does not specify an expected peer, so just try to dial it as-is, // accepting any peer ID that the remote identifies as. - let info = OutgoingInfo { address, peer_id: None }; + let info = OutgoingInfo { + address, + peer_id: None, + }; match self.transport().clone().dial(address.clone()) { Ok(f) => { - let f = f.map_err(|err| PendingConnectionError::Transport(TransportError::Other(err))); - self.pool.add_outgoing(f, handler, info).map_err(DialError::ConnectionLimit) + let f = + f.map_err(|err| PendingConnectionError::Transport(TransportError::Other(err))); + self.pool + .add_outgoing(f, handler, info) + .map_err(DialError::ConnectionLimit) } Err(err) => { let f = future::err(PendingConnectionError::Transport(err)); - self.pool.add_outgoing(f, handler, info).map_err(DialError::ConnectionLimit) + self.pool + .add_outgoing(f, handler, info) + .map_err(DialError::ConnectionLimit) } } } @@ -274,14 +265,13 @@ where /// Returns the list of addresses we're currently dialing without knowing the `PeerId` of. pub fn unknown_dials(&self) -> impl Iterator { - self.pool.iter_pending_outgoing() - .filter_map(|info| { - if info.peer_id.is_none() { - Some(info.address) - } else { - None - } - }) + self.pool.iter_pending_outgoing().filter_map(|info| { + if info.peer_id.is_none() { + Some(info.address) + } else { + None + } + }) } /// Returns a list of all connected peers, i.e. peers to whom the `Network` @@ -313,9 +303,7 @@ where } /// Obtains a view of a [`Peer`] with the given ID in the network. - pub fn peer(&mut self, peer_id: PeerId) - -> Peer<'_, TTrans, THandler> - { + pub fn peer(&mut self, peer_id: PeerId) -> Peer<'_, TTrans, THandler> { Peer::new(self, peer_id) } @@ -336,8 +324,9 @@ where TTrans::Error: Send + 'static, TTrans::ListenerUpgrade: Send + 'static, { - let upgrade = connection.upgrade.map_err(|err| - PendingConnectionError::Transport(TransportError::Other(err))); + let upgrade = connection + .upgrade + .map_err(|err| PendingConnectionError::Transport(TransportError::Other(err))); let info = IncomingInfo { local_addr: &connection.local_addr, send_back_addr: &connection.send_back_addr, @@ -346,7 +335,12 @@ where } /// Provides an API similar to `Stream`, except that it cannot error. 
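// Illustrative call-shape sketch, not part of the patch; `network` is assumed
// to be in scope. `poll` is typically driven from a caller's own poll function
// (it borrows the `Network` mutably), matching on the returned `NetworkEvent`:
//
//     match network.poll(cx) {
//         Poll::Pending => return Poll::Pending,
//         Poll::Ready(NetworkEvent::IncomingConnection { connection, .. }) => { /* accept or drop */ }
//         Poll::Ready(NetworkEvent::ConnectionEstablished { connection, .. }) => { /* ... */ }
//         Poll::Ready(_) => { /* listener, dialing and connection events */ }
//     }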
- pub fn poll<'a>(&'a mut self, cx: &mut Context<'_>) -> Poll, THandlerOutEvent, THandler>> + pub fn poll<'a>( + &'a mut self, + cx: &mut Context<'_>, + ) -> Poll< + NetworkEvent<'a, TTrans, THandlerInEvent, THandlerOutEvent, THandler>, + > where TTrans: Transport, TTrans::Error: Send + 'static, @@ -364,7 +358,7 @@ where listener_id, upgrade, local_addr, - send_back_addr + send_back_addr, }) => { return Poll::Ready(NetworkEvent::IncomingConnection { listener_id, @@ -372,17 +366,37 @@ where upgrade, local_addr, send_back_addr, - } + }, }) } - Poll::Ready(ListenersEvent::NewAddress { listener_id, listen_addr }) => { - return Poll::Ready(NetworkEvent::NewListenerAddress { listener_id, listen_addr }) + Poll::Ready(ListenersEvent::NewAddress { + listener_id, + listen_addr, + }) => { + return Poll::Ready(NetworkEvent::NewListenerAddress { + listener_id, + listen_addr, + }) } - Poll::Ready(ListenersEvent::AddressExpired { listener_id, listen_addr }) => { - return Poll::Ready(NetworkEvent::ExpiredListenerAddress { listener_id, listen_addr }) + Poll::Ready(ListenersEvent::AddressExpired { + listener_id, + listen_addr, + }) => { + return Poll::Ready(NetworkEvent::ExpiredListenerAddress { + listener_id, + listen_addr, + }) } - Poll::Ready(ListenersEvent::Closed { listener_id, addresses, reason }) => { - return Poll::Ready(NetworkEvent::ListenerClosed { listener_id, addresses, reason }) + Poll::Ready(ListenersEvent::Closed { + listener_id, + addresses, + reason, + }) => { + return Poll::Ready(NetworkEvent::ListenerClosed { + listener_id, + addresses, + reason, + }) } Poll::Ready(ListenersEvent::Error { listener_id, error }) => { return Poll::Ready(NetworkEvent::ListenerError { listener_id, error }) @@ -392,7 +406,10 @@ where // Poll the known peers. let event = match self.pool.poll(cx) { Poll::Pending => return Poll::Pending, - Poll::Ready(PoolEvent::ConnectionEstablished { connection, num_established }) => { + Poll::Ready(PoolEvent::ConnectionEstablished { + connection, + num_established, + }) => { if let hash_map::Entry::Occupied(mut e) = self.dialing.entry(connection.peer_id()) { e.get_mut().retain(|s| s.current.0 != connection.id()); if e.get().is_empty() { @@ -405,7 +422,14 @@ where num_established, } } - Poll::Ready(PoolEvent::PendingConnectionError { id, endpoint, error, handler, pool, .. }) => { + Poll::Ready(PoolEvent::PendingConnectionError { + id, + endpoint, + error, + handler, + pool, + .. + }) => { let dialing = &mut self.dialing; let (next, event) = on_connection_failed(dialing, id, endpoint, error, handler); if let Some(dial) = next { @@ -416,35 +440,37 @@ where } event } - Poll::Ready(PoolEvent::ConnectionClosed { id, connected, error, num_established, .. }) => { - NetworkEvent::ConnectionClosed { - id, - connected, - num_established, - error, - } - } + Poll::Ready(PoolEvent::ConnectionClosed { + id, + connected, + error, + num_established, + .. 
+ }) => NetworkEvent::ConnectionClosed { + id, + connected, + num_established, + error, + }, Poll::Ready(PoolEvent::ConnectionEvent { connection, event }) => { - NetworkEvent::ConnectionEvent { - connection, - event, - } - } - Poll::Ready(PoolEvent::AddressChange { connection, new_endpoint, old_endpoint }) => { - NetworkEvent::AddressChange { - connection, - new_endpoint, - old_endpoint, - } + NetworkEvent::ConnectionEvent { connection, event } } + Poll::Ready(PoolEvent::AddressChange { + connection, + new_endpoint, + old_endpoint, + }) => NetworkEvent::AddressChange { + connection, + new_endpoint, + old_endpoint, + }, }; Poll::Ready(event) } /// Initiates a connection attempt to a known peer. - fn dial_peer(&mut self, opts: DialingOpts) - -> Result + fn dial_peer(&mut self, opts: DialingOpts) -> Result where TTrans: Transport, TTrans::Dial: Send + 'static, @@ -452,7 +478,12 @@ where TMuxer: Send + Sync + 'static, TMuxer::OutboundSubstream: Send, { - dial_peer_impl(self.transport().clone(), &mut self.pool, &mut self.dialing, opts) + dial_peer_impl( + self.transport().clone(), + &mut self.pool, + &mut self.dialing, + opts, + ) } } @@ -470,15 +501,13 @@ fn dial_peer_impl( transport: TTrans, pool: &mut Pool, dialing: &mut FnvHashMap>, - opts: DialingOpts + opts: DialingOpts, ) -> Result where THandler: IntoConnectionHandler + Send + 'static, ::Error: error::Error + Send + 'static, ::OutboundOpenInfo: Send + 'static, - THandler::Handler: ConnectionHandler< - Substream = Substream, - > + Send + 'static, + THandler::Handler: ConnectionHandler> + Send + 'static, TTrans: Transport, TTrans::Dial: Send + 'static, TTrans::Error: error::Error + Send + 'static, @@ -493,23 +522,32 @@ where let result = match transport.dial(addr.clone()) { Ok(fut) => { let fut = fut.map_err(|e| PendingConnectionError::Transport(TransportError::Other(e))); - let info = OutgoingInfo { address: &addr, peer_id: Some(&opts.peer) }; - pool.add_outgoing(fut, opts.handler, info).map_err(DialError::ConnectionLimit) - }, + let info = OutgoingInfo { + address: &addr, + peer_id: Some(&opts.peer), + }; + pool.add_outgoing(fut, opts.handler, info) + .map_err(DialError::ConnectionLimit) + } Err(err) => { let fut = future::err(PendingConnectionError::Transport(err)); - let info = OutgoingInfo { address: &addr, peer_id: Some(&opts.peer) }; - pool.add_outgoing(fut, opts.handler, info).map_err(DialError::ConnectionLimit) - }, + let info = OutgoingInfo { + address: &addr, + peer_id: Some(&opts.peer), + }; + pool.add_outgoing(fut, opts.handler, info) + .map_err(DialError::ConnectionLimit) + } }; if let Ok(id) = &result { - dialing.entry(opts.peer).or_default().push( - peer::DialingState { + dialing + .entry(opts.peer) + .or_default() + .push(peer::DialingState { current: (*id, addr), remaining: opts.remaining, - }, - ); + }); } result @@ -526,22 +564,24 @@ fn on_connection_failed<'a, TTrans, THandler>( endpoint: ConnectedPoint, error: PendingConnectionError, handler: Option, -) -> (Option>, NetworkEvent<'a, TTrans, THandlerInEvent, THandlerOutEvent, THandler>) +) -> ( + Option>, + NetworkEvent<'a, TTrans, THandlerInEvent, THandlerOutEvent, THandler>, +) where TTrans: Transport, THandler: IntoConnectionHandler, { // Check if the failed connection is associated with a dialing attempt. 
- let dialing_failed = dialing.iter_mut() - .find_map(|(peer, attempts)| { - if let Some(pos) = attempts.iter().position(|s| s.current.0 == id) { - let attempt = attempts.remove(pos); - let last = attempts.is_empty(); - Some((*peer, attempt, last)) - } else { - None - } - }); + let dialing_failed = dialing.iter_mut().find_map(|(peer, attempts)| { + if let Some(pos) = attempts.iter().position(|s| s.current.0 == id) { + let attempt = attempts.remove(pos); + let last = attempts.is_empty(); + Some((*peer, attempt, last)) + } else { + None + } + }); if let Some((peer_id, mut attempt, last)) = dialing_failed { if last { @@ -551,47 +591,56 @@ where let num_remain = u32::try_from(attempt.remaining.len()).unwrap(); let failed_addr = attempt.current.1.clone(); - let (opts, attempts_remaining) = - if num_remain > 0 { - if let Some(handler) = handler { - let next_attempt = attempt.remaining.remove(0); - let opts = DialingOpts { - peer: peer_id, - handler, - address: next_attempt, - remaining: attempt.remaining - }; - (Some(opts), num_remain) - } else { - // The error is "fatal" for the dialing attempt, since - // the handler was already consumed. All potential - // remaining connection attempts are thus void. - (None, 0) - } + let (opts, attempts_remaining) = if num_remain > 0 { + if let Some(handler) = handler { + let next_attempt = attempt.remaining.remove(0); + let opts = DialingOpts { + peer: peer_id, + handler, + address: next_attempt, + remaining: attempt.remaining, + }; + (Some(opts), num_remain) } else { + // The error is "fatal" for the dialing attempt, since + // the handler was already consumed. All potential + // remaining connection attempts are thus void. (None, 0) - }; + } + } else { + (None, 0) + }; - (opts, NetworkEvent::DialError { - attempts_remaining, - peer_id, - multiaddr: failed_addr, - error, - }) + ( + opts, + NetworkEvent::DialError { + attempts_remaining, + peer_id, + multiaddr: failed_addr, + error, + }, + ) } else { // A pending incoming connection or outgoing connection to an unknown peer failed. match endpoint { - ConnectedPoint::Dialer { address } => - (None, NetworkEvent::UnknownPeerDialError { + ConnectedPoint::Dialer { address } => ( + None, + NetworkEvent::UnknownPeerDialError { multiaddr: address, error, - }), - ConnectedPoint::Listener { local_addr, send_back_addr } => - (None, NetworkEvent::IncomingConnectionError { + }, + ), + ConnectedPoint::Listener { + local_addr, + send_back_addr, + } => ( + None, + NetworkEvent::IncomingConnectionError { local_addr, send_back_addr, - error - }) + error, + }, + ), } } } @@ -644,7 +693,7 @@ impl NetworkConfig { /// only if no executor has already been configured. pub fn or_else_with_executor(mut self, f: F) -> Self where - F: FnOnce() -> Option> + F: FnOnce() -> Option>, { self.manager_config.executor = self.manager_config.executor.or_else(f); self @@ -693,7 +742,7 @@ impl NetworkConfig { fn p2p_addr(peer: PeerId, addr: Multiaddr) -> Result { if let Some(multiaddr::Protocol::P2p(hash)) = addr.iter().last() { if &hash != peer.as_ref() { - return Err(addr) + return Err(addr); } Ok(addr) } else { @@ -718,7 +767,7 @@ mod tests { struct Dummy; impl Executor for Dummy { - fn exec(&self, _: Pin + Send>>) { } + fn exec(&self, _: Pin + Send>>) {} } #[test] diff --git a/core/src/network/event.rs b/core/src/network/event.rs index 8154bd2087d..7b4158265d9 100644 --- a/core/src/network/event.rs +++ b/core/src/network/event.rs @@ -21,20 +21,12 @@ //! Network events and associated information. 
use crate::{ - Multiaddr, connection::{ - ConnectionId, - ConnectedPoint, - ConnectionError, - ConnectionHandler, - Connected, - EstablishedConnection, - IntoConnectionHandler, - ListenerId, - PendingConnectionError, + Connected, ConnectedPoint, ConnectionError, ConnectionHandler, ConnectionId, + EstablishedConnection, IntoConnectionHandler, ListenerId, PendingConnectionError, }, transport::Transport, - PeerId + Multiaddr, PeerId, }; use std::{fmt, num::NonZeroU32}; @@ -60,7 +52,7 @@ where /// The listener that errored. listener_id: ListenerId, /// The listener error. - error: TTrans::Error + error: TTrans::Error, }, /// One of the listeners is now listening on an additional address. @@ -68,7 +60,7 @@ where /// The listener that is listening on the new address. listener_id: ListenerId, /// The new address the listener is now also listening on. - listen_addr: Multiaddr + listen_addr: Multiaddr, }, /// One of the listeners is no longer listening on some address. @@ -76,7 +68,7 @@ where /// The listener that is no longer listening on some address. listener_id: ListenerId, /// The expired address. - listen_addr: Multiaddr + listen_addr: Multiaddr, }, /// A new connection arrived on a listener. @@ -177,8 +169,8 @@ where }, } -impl fmt::Debug for - NetworkEvent<'_, TTrans, TInEvent, TOutEvent, THandler> +impl fmt::Debug + for NetworkEvent<'_, TTrans, TInEvent, TOutEvent, THandler> where TInEvent: fmt::Debug, TOutEvent: fmt::Debug, @@ -189,83 +181,101 @@ where { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - NetworkEvent::NewListenerAddress { listener_id, listen_addr } => { - f.debug_struct("NewListenerAddress") - .field("listener_id", listener_id) - .field("listen_addr", listen_addr) - .finish() - } - NetworkEvent::ExpiredListenerAddress { listener_id, listen_addr } => { - f.debug_struct("ExpiredListenerAddress") - .field("listener_id", listener_id) - .field("listen_addr", listen_addr) - .finish() - } - NetworkEvent::ListenerClosed { listener_id, addresses, reason } => { - f.debug_struct("ListenerClosed") - .field("listener_id", listener_id) - .field("addresses", addresses) - .field("reason", reason) - .finish() - } - NetworkEvent::ListenerError { listener_id, error } => { - f.debug_struct("ListenerError") - .field("listener_id", listener_id) - .field("error", error) - .finish() - } - NetworkEvent::IncomingConnection { connection, .. } => { - f.debug_struct("IncomingConnection") - .field("local_addr", &connection.local_addr) - .field("send_back_addr", &connection.send_back_addr) - .finish() - } - NetworkEvent::IncomingConnectionError { local_addr, send_back_addr, error } => { - f.debug_struct("IncomingConnectionError") - .field("local_addr", local_addr) - .field("send_back_addr", send_back_addr) - .field("error", error) - .finish() - } - NetworkEvent::ConnectionEstablished { connection, .. } => { - f.debug_struct("ConnectionEstablished") - .field("connection", connection) - .finish() - } - NetworkEvent::ConnectionClosed { id, connected, error, .. } => { - f.debug_struct("ConnectionClosed") - .field("id", id) - .field("connected", connected) - .field("error", error) - .finish() - } - NetworkEvent::DialError { attempts_remaining, peer_id, multiaddr, error } => { - f.debug_struct("DialError") - .field("attempts_remaining", attempts_remaining) - .field("peer_id", peer_id) - .field("multiaddr", multiaddr) - .field("error", error) - .finish() - } - NetworkEvent::UnknownPeerDialError { multiaddr, error, .. 
} => { - f.debug_struct("UnknownPeerDialError") - .field("multiaddr", multiaddr) - .field("error", error) - .finish() - } - NetworkEvent::ConnectionEvent { connection, event } => { - f.debug_struct("ConnectionEvent") - .field("connection", connection) - .field("event", event) - .finish() - } - NetworkEvent::AddressChange { connection, new_endpoint, old_endpoint } => { - f.debug_struct("AddressChange") - .field("connection", connection) - .field("new_endpoint", new_endpoint) - .field("old_endpoint", old_endpoint) - .finish() - } + NetworkEvent::NewListenerAddress { + listener_id, + listen_addr, + } => f + .debug_struct("NewListenerAddress") + .field("listener_id", listener_id) + .field("listen_addr", listen_addr) + .finish(), + NetworkEvent::ExpiredListenerAddress { + listener_id, + listen_addr, + } => f + .debug_struct("ExpiredListenerAddress") + .field("listener_id", listener_id) + .field("listen_addr", listen_addr) + .finish(), + NetworkEvent::ListenerClosed { + listener_id, + addresses, + reason, + } => f + .debug_struct("ListenerClosed") + .field("listener_id", listener_id) + .field("addresses", addresses) + .field("reason", reason) + .finish(), + NetworkEvent::ListenerError { listener_id, error } => f + .debug_struct("ListenerError") + .field("listener_id", listener_id) + .field("error", error) + .finish(), + NetworkEvent::IncomingConnection { connection, .. } => f + .debug_struct("IncomingConnection") + .field("local_addr", &connection.local_addr) + .field("send_back_addr", &connection.send_back_addr) + .finish(), + NetworkEvent::IncomingConnectionError { + local_addr, + send_back_addr, + error, + } => f + .debug_struct("IncomingConnectionError") + .field("local_addr", local_addr) + .field("send_back_addr", send_back_addr) + .field("error", error) + .finish(), + NetworkEvent::ConnectionEstablished { connection, .. } => f + .debug_struct("ConnectionEstablished") + .field("connection", connection) + .finish(), + NetworkEvent::ConnectionClosed { + id, + connected, + error, + .. + } => f + .debug_struct("ConnectionClosed") + .field("id", id) + .field("connected", connected) + .field("error", error) + .finish(), + NetworkEvent::DialError { + attempts_remaining, + peer_id, + multiaddr, + error, + } => f + .debug_struct("DialError") + .field("attempts_remaining", attempts_remaining) + .field("peer_id", peer_id) + .field("multiaddr", multiaddr) + .field("error", error) + .finish(), + NetworkEvent::UnknownPeerDialError { + multiaddr, error, .. + } => f + .debug_struct("UnknownPeerDialError") + .field("multiaddr", multiaddr) + .field("error", error) + .finish(), + NetworkEvent::ConnectionEvent { connection, event } => f + .debug_struct("ConnectionEvent") + .field("connection", connection) + .field("event", event) + .finish(), + NetworkEvent::AddressChange { + connection, + new_endpoint, + old_endpoint, + } => f + .debug_struct("AddressChange") + .field("connection", connection) + .field("new_endpoint", new_endpoint) + .field("old_endpoint", old_endpoint) + .finish(), } } } diff --git a/core/src/network/peer.rs b/core/src/network/peer.rs index 88c96aa0983..ca1b9be7502 100644 --- a/core/src/network/peer.rs +++ b/core/src/network/peer.rs @@ -18,35 +18,18 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
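// Illustrative call-shape sketch, not part of the patch; `network` and
// `peer_id` are assumed to be in scope. `Network::peer` returns one of the
// four views defined in this file:
//
//     match network.peer(peer_id) {
//         Peer::Connected(mut peer) => { /* e.g. peer.connections() or peer.disconnect() */ }
//         Peer::Dialing(mut peer) => { /* e.g. peer.attempts() */ }
//         Peer::Disconnected(_) => { /* no connection and no dialing attempt */ }
//         Peer::Local => { /* the local peer itself */ }
//     }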
+use super::{DialError, DialingOpts, Network}; use crate::{ - Multiaddr, - Transport, - StreamMuxer, connection::{ - Connected, - ConnectedPoint, - ConnectionHandler, - Connection, - ConnectionId, - ConnectionLimit, - EstablishedConnection, - EstablishedConnectionIter, - IntoConnectionHandler, - PendingConnection, - Substream, - handler::THandlerInEvent, - pool::Pool, + handler::THandlerInEvent, pool::Pool, Connected, ConnectedPoint, Connection, + ConnectionHandler, ConnectionId, ConnectionLimit, EstablishedConnection, + EstablishedConnectionIter, IntoConnectionHandler, PendingConnection, Substream, }, - PeerId + Multiaddr, PeerId, StreamMuxer, Transport, }; use fnv::FnvHashMap; use smallvec::SmallVec; -use std::{ - collections::hash_map, - error, - fmt, -}; -use super::{Network, DialingOpts, DialError}; +use std::{collections::hash_map, error, fmt}; /// The possible representations of a peer in a [`Network`], as /// seen by the local node. @@ -57,7 +40,7 @@ use super::{Network, DialingOpts, DialError}; pub enum Peer<'a, TTrans, THandler> where TTrans: Transport, - THandler: IntoConnectionHandler + THandler: IntoConnectionHandler, { /// At least one established connection exists to the peer. Connected(ConnectedPeer<'a, TTrans, THandler>), @@ -76,53 +59,33 @@ where Local, } -impl<'a, TTrans, THandler> fmt::Debug for - Peer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> fmt::Debug for Peer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - Peer::Connected(p) => { - f.debug_struct("Connected") - .field("peer", &p) - .finish() - } - Peer::Dialing(p) => { - f.debug_struct("Dialing") - .field("peer", &p) - .finish() - } - Peer::Disconnected(p) => { - f.debug_struct("Disconnected") - .field("peer", &p) - .finish() - } - Peer::Local => { - f.debug_struct("Local") - .finish() - } + Peer::Connected(p) => f.debug_struct("Connected").field("peer", &p).finish(), + Peer::Dialing(p) => f.debug_struct("Dialing").field("peer", &p).finish(), + Peer::Disconnected(p) => f.debug_struct("Disconnected").field("peer", &p).finish(), + Peer::Local => f.debug_struct("Local").finish(), } } } -impl<'a, TTrans, THandler> - Peer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> Peer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, { - pub(super) fn new( - network: &'a mut Network, - peer_id: PeerId - ) -> Self { + pub(super) fn new(network: &'a mut Network, peer_id: PeerId) -> Self { if peer_id == network.local_peer_id { return Peer::Local; } if network.pool.is_connected(&peer_id) { - return Self::connected(network, peer_id) + return Self::connected(network, peer_id); } if network.dialing.get_mut(&peer_id).is_some() { @@ -132,31 +95,20 @@ where Self::disconnected(network, peer_id) } - - fn disconnected( - network: &'a mut Network, - peer_id: PeerId - ) -> Self { + fn disconnected(network: &'a mut Network, peer_id: PeerId) -> Self { Peer::Disconnected(DisconnectedPeer { network, peer_id }) } - fn connected( - network: &'a mut Network, - peer_id: PeerId - ) -> Self { + fn connected(network: &'a mut Network, peer_id: PeerId) -> Self { Peer::Connected(ConnectedPeer { network, peer_id }) } - fn dialing( - network: &'a mut Network, - peer_id: PeerId - ) -> Self { + fn dialing(network: &'a mut Network, peer_id: PeerId) -> Self { Peer::Dialing(DialingPeer { network, peer_id }) } } -impl<'a, TTrans, TMuxer, THandler> - Peer<'a, TTrans, THandler> +impl<'a, TTrans, TMuxer, THandler> 
Peer<'a, TTrans, THandler> where TTrans: Transport + Clone, TTrans::Error: Send + 'static, @@ -176,7 +128,7 @@ where Peer::Connected(..) => true, Peer::Dialing(peer) => peer.is_connected(), Peer::Disconnected(..) => false, - Peer::Local => false + Peer::Local => false, } } @@ -188,7 +140,7 @@ where Peer::Dialing(_) => true, Peer::Connected(peer) => peer.is_dialing(), Peer::Disconnected(..) => false, - Peer::Local => false + Peer::Local => false, } } @@ -206,11 +158,12 @@ where /// `remaining` addresses are tried in order in subsequent connection /// attempts in the context of the same dialing attempt, if the connection /// attempt to the first address fails. - pub fn dial(self, address: Multiaddr, remaining: I, handler: THandler) - -> Result< - (ConnectionId, DialingPeer<'a, TTrans, THandler>), - DialError - > + pub fn dial( + self, + address: Multiaddr, + remaining: I, + handler: THandler, + ) -> Result<(ConnectionId, DialingPeer<'a, TTrans, THandler>), DialError> where I: IntoIterator, { @@ -218,9 +171,12 @@ where Peer::Connected(p) => (p.peer_id, p.network), Peer::Dialing(p) => (p.peer_id, p.network), Peer::Disconnected(p) => (p.peer_id, p.network), - Peer::Local => return Err(DialError::ConnectionLimit(ConnectionLimit { - current: 0, limit: 0 - })) + Peer::Local => { + return Err(DialError::ConnectionLimit(ConnectionLimit { + current: 0, + limit: 0, + })) + } }; let id = network.dial_peer(DialingOpts { @@ -236,9 +192,7 @@ where /// Converts the peer into a `ConnectedPeer`, if an established connection exists. /// /// Succeeds if the there is at least one established connection to the peer. - pub fn into_connected(self) -> Option< - ConnectedPeer<'a, TTrans, THandler> - > { + pub fn into_connected(self) -> Option> { match self { Peer::Connected(peer) => Some(peer), Peer::Dialing(peer) => peer.into_connected(), @@ -250,22 +204,18 @@ where /// Converts the peer into a `DialingPeer`, if a dialing attempt exists. /// /// Succeeds if the there is at least one pending outgoing connection to the peer. - pub fn into_dialing(self) -> Option< - DialingPeer<'a, TTrans, THandler> - > { + pub fn into_dialing(self) -> Option> { match self { Peer::Dialing(peer) => Some(peer), Peer::Connected(peer) => peer.into_dialing(), Peer::Disconnected(..) => None, - Peer::Local => None + Peer::Local => None, } } /// Converts the peer into a `DisconnectedPeer`, if neither an established connection /// nor a dialing attempt exists. - pub fn into_disconnected(self) -> Option< - DisconnectedPeer<'a, TTrans, THandler> - > { + pub fn into_disconnected(self) -> Option> { match self { Peer::Disconnected(peer) => Some(peer), _ => None, @@ -285,8 +235,7 @@ where peer_id: PeerId, } -impl<'a, TTrans, THandler> - ConnectedPeer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> ConnectedPeer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, @@ -301,9 +250,10 @@ where } /// Obtains an established connection to the peer by ID. - pub fn connection(&mut self, id: ConnectionId) - -> Option>> - { + pub fn connection( + &mut self, + id: ConnectionId, + ) -> Option>> { self.network.pool.get_established(id) } @@ -321,47 +271,43 @@ where /// Converts this peer into a [`DialingPeer`], if there is an ongoing /// dialing attempt, `None` otherwise. 
- pub fn into_dialing(self) -> Option< - DialingPeer<'a, TTrans, THandler> - > { + pub fn into_dialing(self) -> Option> { if self.network.dialing.contains_key(&self.peer_id) { - Some(DialingPeer { network: self.network, peer_id: self.peer_id }) + Some(DialingPeer { + network: self.network, + peer_id: self.peer_id, + }) } else { None } } /// Gets an iterator over all established connections to the peer. - pub fn connections(&mut self) -> - EstablishedConnectionIter< - impl Iterator, - THandler, - TTrans::Error, - > + pub fn connections( + &mut self, + ) -> EstablishedConnectionIter, THandler, TTrans::Error> { self.network.pool.iter_peer_established(&self.peer_id) } /// Obtains some established connection to the peer. - pub fn some_connection(&mut self) - -> EstablishedConnection> - { + pub fn some_connection(&mut self) -> EstablishedConnection> { self.connections() .into_first() .expect("By `Peer::new` and the definition of `ConnectedPeer`.") } /// Disconnects from the peer, closing all connections. - pub fn disconnect(self) - -> DisconnectedPeer<'a, TTrans, THandler> - { + pub fn disconnect(self) -> DisconnectedPeer<'a, TTrans, THandler> { self.network.disconnect(&self.peer_id); - DisconnectedPeer { network: self.network, peer_id: self.peer_id } + DisconnectedPeer { + network: self.network, + peer_id: self.peer_id, + } } } -impl<'a, TTrans, THandler> fmt::Debug for - ConnectedPeer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> fmt::Debug for ConnectedPeer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, @@ -369,7 +315,10 @@ where fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { f.debug_struct("ConnectedPeer") .field("peer_id", &self.peer_id) - .field("established", &self.network.pool.iter_peer_established_info(&self.peer_id)) + .field( + "established", + &self.network.pool.iter_peer_established_info(&self.peer_id), + ) .field("attempts", &self.network.dialing.get(&self.peer_id)) .finish() } @@ -387,8 +336,7 @@ where peer_id: PeerId, } -impl<'a, TTrans, THandler> - DialingPeer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> DialingPeer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, @@ -404,11 +352,12 @@ where /// Disconnects from this peer, closing all established connections and /// aborting all dialing attempts. - pub fn disconnect(self) - -> DisconnectedPeer<'a, TTrans, THandler> - { + pub fn disconnect(self) -> DisconnectedPeer<'a, TTrans, THandler> { self.network.disconnect(&self.peer_id); - DisconnectedPeer { network: self.network, peer_id: self.peer_id } + DisconnectedPeer { + network: self.network, + peer_id: self.peer_id, + } } /// Checks whether there is an established connection to the peer. @@ -419,11 +368,12 @@ where } /// Converts the peer into a `ConnectedPeer`, if an established connection exists. - pub fn into_connected(self) - -> Option> - { + pub fn into_connected(self) -> Option> { if self.is_connected() { - Some(ConnectedPeer { peer_id: self.peer_id, network: self.network }) + Some(ConnectedPeer { + peer_id: self.peer_id, + network: self.network, + }) } else { None } @@ -431,13 +381,18 @@ where /// Obtains a dialing attempt to the peer by connection ID of /// the current connection attempt. 
- pub fn attempt(&mut self, id: ConnectionId) - -> Option>> - { + pub fn attempt( + &mut self, + id: ConnectionId, + ) -> Option>> { if let hash_map::Entry::Occupied(attempts) = self.network.dialing.entry(self.peer_id) { if let Some(pos) = attempts.get().iter().position(|s| s.current.0 == id) { if let Some(inner) = self.network.pool.get_outgoing(id) { - return Some(DialingAttempt { pos, inner, attempts }) + return Some(DialingAttempt { + pos, + inner, + attempts, + }); } } } @@ -445,25 +400,25 @@ where } /// Gets an iterator over all dialing (i.e. pending outgoing) connections to the peer. - pub fn attempts(&mut self) -> DialingAttemptIter<'_, THandler, TTrans::Error> - { - DialingAttemptIter::new(&self.peer_id, &mut self.network.pool, &mut self.network.dialing) + pub fn attempts(&mut self) -> DialingAttemptIter<'_, THandler, TTrans::Error> { + DialingAttemptIter::new( + &self.peer_id, + &mut self.network.pool, + &mut self.network.dialing, + ) } /// Obtains some dialing connection to the peer. /// /// At least one dialing connection is guaranteed to exist on a `DialingPeer`. - pub fn some_attempt(&mut self) - -> DialingAttempt<'_, THandlerInEvent> - { + pub fn some_attempt(&mut self) -> DialingAttempt<'_, THandlerInEvent> { self.attempts() .into_first() .expect("By `Peer::new` and the definition of `DialingPeer`.") } } -impl<'a, TTrans, THandler> fmt::Debug for - DialingPeer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> fmt::Debug for DialingPeer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, @@ -471,7 +426,10 @@ where fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { f.debug_struct("DialingPeer") .field("peer_id", &self.peer_id) - .field("established", &self.network.pool.iter_peer_established_info(&self.peer_id)) + .field( + "established", + &self.network.pool.iter_peer_established_info(&self.peer_id), + ) .field("attempts", &self.network.dialing.get(&self.peer_id)) .finish() } @@ -489,8 +447,7 @@ where network: &'a mut Network, } -impl<'a, TTrans, THandler> fmt::Debug for - DisconnectedPeer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> fmt::Debug for DisconnectedPeer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, @@ -502,8 +459,7 @@ where } } -impl<'a, TTrans, THandler> - DisconnectedPeer<'a, TTrans, THandler> +impl<'a, TTrans, THandler> DisconnectedPeer<'a, TTrans, THandler> where TTrans: Transport, THandler: IntoConnectionHandler, @@ -529,10 +485,8 @@ where self, connected: Connected, connection: Connection, - ) -> Result< - ConnectedPeer<'a, TTrans, THandler>, - ConnectionLimit - > where + ) -> Result, ConnectionLimit> + where THandler: Send + 'static, TTrans::Error: Send + 'static, THandler::Handler: ConnectionHandler> + Send, @@ -542,10 +496,15 @@ where TMuxer::OutboundSubstream: Send, { if connected.peer_id != self.peer_id { - panic!("Invalid peer ID given: {:?}. Expected: {:?}", connected.peer_id, self.peer_id) + panic!( + "Invalid peer ID given: {:?}. Expected: {:?}", + connected.peer_id, self.peer_id + ) } - self.network.pool.add(connection, connected) + self.network + .pool + .add(connection, connected) .map(move |_id| ConnectedPeer { network: self.network, peer_id: self.peer_id, @@ -575,9 +534,7 @@ pub struct DialingAttempt<'a, TInEvent> { pos: usize, } -impl<'a, TInEvent> - DialingAttempt<'a, TInEvent> -{ +impl<'a, TInEvent> DialingAttempt<'a, TInEvent> { /// Returns the ID of the current connection attempt. 
pub fn id(&self) -> ConnectionId { self.inner.id() @@ -592,7 +549,7 @@ impl<'a, TInEvent> pub fn address(&self) -> &Multiaddr { match self.inner.endpoint() { ConnectedPoint::Dialer { address } => address, - ConnectedPoint::Listener { .. } => unreachable!("by definition of a `DialingAttempt`.") + ConnectedPoint::Listener { .. } => unreachable!("by definition of a `DialingAttempt`."), } } @@ -640,16 +597,20 @@ pub struct DialingAttemptIter<'a, THandler: IntoConnectionHandler, TTransErr> { // Note: Ideally this would be an implementation of `Iterator`, but that // requires GATs (cf. https://github.com/rust-lang/rust/issues/44265) and // a different definition of `Iterator`. -impl<'a, THandler: IntoConnectionHandler, TTransErr> - DialingAttemptIter<'a, THandler, TTransErr> -{ +impl<'a, THandler: IntoConnectionHandler, TTransErr> DialingAttemptIter<'a, THandler, TTransErr> { fn new( peer_id: &'a PeerId, pool: &'a mut Pool, dialing: &'a mut FnvHashMap>, ) -> Self { let end = dialing.get(peer_id).map_or(0, |conns| conns.len()); - Self { pos: 0, end, pool, dialing, peer_id } + Self { + pos: 0, + end, + pool, + dialing, + peer_id, + } } /// Obtains the next dialing connection, if any. @@ -658,22 +619,29 @@ impl<'a, THandler: IntoConnectionHandler, TTransErr> // If the number of elements reduced, the current `DialingAttempt` has been // aborted and iteration needs to continue from the previous position to // account for the removed element. - let end = self.dialing.get(self.peer_id).map_or(0, |conns| conns.len()); + let end = self + .dialing + .get(self.peer_id) + .map_or(0, |conns| conns.len()); if self.end > end { self.end = end; self.pos -= 1; } if self.pos == self.end { - return None + return None; } if let hash_map::Entry::Occupied(attempts) = self.dialing.entry(*self.peer_id) { let id = attempts.get()[self.pos].current.0; if let Some(inner) = self.pool.get_outgoing(id) { - let conn = DialingAttempt { pos: self.pos, inner, attempts }; + let conn = DialingAttempt { + pos: self.pos, + inner, + attempts, + }; self.pos += 1; - return Some(conn) + return Some(conn); } } @@ -681,18 +649,22 @@ impl<'a, THandler: IntoConnectionHandler, TTransErr> } /// Returns the first connection, if any, consuming the iterator. 
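// Illustrative sketch only (not part of this patch): `DialingAttemptIter` is
// driven manually because a std `Iterator` impl would require GATs, as noted
// above. `dialing_peer` is an assumed `DialingPeer` binding.
let mut attempts = dialing_peer.attempts();
while let Some(attempt) = attempts.next() {
    log::debug!("Attempt {:?} towards {}.", attempt.id(), attempt.address());
}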
- pub fn into_first<'b>(self) - -> Option>> - where 'a: 'b + pub fn into_first<'b>(self) -> Option>> + where + 'a: 'b, { if self.pos == self.end { - return None + return None; } if let hash_map::Entry::Occupied(attempts) = self.dialing.entry(*self.peer_id) { let id = attempts.get()[self.pos].current.0; if let Some(inner) = self.pool.get_outgoing(id) { - return Some(DialingAttempt { pos: self.pos, inner, attempts }) + return Some(DialingAttempt { + pos: self.pos, + inner, + attempts, + }); } } diff --git a/core/src/peer_id.rs b/core/src/peer_id.rs index 37d46038243..5a9ae8b0341 100644 --- a/core/src/peer_id.rs +++ b/core/src/peer_id.rs @@ -38,9 +38,7 @@ pub struct PeerId { impl fmt::Debug for PeerId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("PeerId") - .field(&self.to_base58()) - .finish() + f.debug_tuple("PeerId").field(&self.to_base58()).finish() } } @@ -80,9 +78,10 @@ impl PeerId { pub fn from_multihash(multihash: Multihash) -> Result { match Code::try_from(multihash.code()) { Ok(Code::Sha2_256) => Ok(PeerId { multihash }), - Ok(Code::Identity) if multihash.digest().len() <= MAX_INLINE_KEY_LENGTH - => Ok(PeerId { multihash }), - _ => Err(multihash) + Ok(Code::Identity) if multihash.digest().len() <= MAX_INLINE_KEY_LENGTH => { + Ok(PeerId { multihash }) + } + _ => Err(multihash), } } @@ -93,7 +92,7 @@ impl PeerId { let peer_id = rand::thread_rng().gen::<[u8; 32]>(); PeerId { multihash: Multihash::wrap(Code::Identity.into(), &peer_id) - .expect("The digest size is never too large") + .expect("The digest size is never too large"), } } @@ -185,7 +184,7 @@ impl FromStr for PeerId { #[cfg(test)] mod tests { - use crate::{PeerId, identity}; + use crate::{identity, PeerId}; #[test] fn peer_id_is_public_key() { @@ -210,7 +209,7 @@ mod tests { #[test] fn random_peer_id_is_valid() { - for _ in 0 .. 5000 { + for _ in 0..5000 { let peer_id = PeerId::random(); assert_eq!(peer_id, PeerId::from_bytes(&peer_id.to_bytes()).unwrap()); } diff --git a/core/src/transport.rs b/core/src/transport.rs index 0d33b88a494..ade8daf6a6c 100644 --- a/core/src/transport.rs +++ b/core/src/transport.rs @@ -93,7 +93,9 @@ pub trait Transport { /// /// If this stream produces an error, it is considered fatal and the listener is killed. It /// is possible to report non-fatal errors by producing a [`ListenerEvent::Error`]. - type Listener: Stream, Self::Error>>; + type Listener: Stream< + Item = Result, Self::Error>, + >; /// A pending [`Output`](Transport::Output) for an inbound connection, /// obtained from the [`Listener`](Transport::Listener) stream. 
@@ -148,7 +150,7 @@ pub trait Transport { fn map(self, f: F) -> map::Map where Self: Sized, - F: FnOnce(Self::Output, ConnectedPoint) -> O + Clone + F: FnOnce(Self::Output, ConnectedPoint) -> O + Clone, { map::Map::new(self, f) } @@ -157,7 +159,7 @@ pub trait Transport { fn map_err(self, f: F) -> map_err::MapErr where Self: Sized, - F: FnOnce(Self::Error) -> E + Clone + F: FnOnce(Self::Error) -> E + Clone, { map_err::MapErr::new(self, f) } @@ -171,7 +173,7 @@ pub trait Transport { where Self: Sized, U: Transport, - ::Error: 'static + ::Error: 'static, { OrTransport::new(self, other) } @@ -188,7 +190,7 @@ pub trait Transport { Self: Sized, C: FnOnce(Self::Output, ConnectedPoint) -> F + Clone, F: TryFuture, - ::Error: Error + 'static + ::Error: Error + 'static, { and_then::AndThen::new(self, f) } @@ -198,7 +200,7 @@ pub trait Transport { fn upgrade(self) -> upgrade::Builder where Self: Sized, - Self::Error: 'static + Self::Error: 'static, { upgrade::Builder::new(self) } @@ -221,7 +223,7 @@ pub enum ListenerEvent { /// The local address which produced this upgrade. local_addr: Multiaddr, /// The remote address which produced this upgrade. - remote_addr: Multiaddr + remote_addr: Multiaddr, }, /// A [`Multiaddr`] is no longer used for listening. AddressExpired(Multiaddr), @@ -238,9 +240,15 @@ impl ListenerEvent { /// based the the function's result. pub fn map(self, f: impl FnOnce(TUpgr) -> U) -> ListenerEvent { match self { - ListenerEvent::Upgrade { upgrade, local_addr, remote_addr } => { - ListenerEvent::Upgrade { upgrade: f(upgrade), local_addr, remote_addr } - } + ListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + } => ListenerEvent::Upgrade { + upgrade: f(upgrade), + local_addr, + remote_addr, + }, ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a), ListenerEvent::Error(e) => ListenerEvent::Error(e), @@ -252,8 +260,15 @@ impl ListenerEvent { /// function's result. pub fn map_err(self, f: impl FnOnce(TErr) -> U) -> ListenerEvent { match self { - ListenerEvent::Upgrade { upgrade, local_addr, remote_addr } => - ListenerEvent::Upgrade { upgrade, local_addr, remote_addr }, + ListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + } => ListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + }, ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), ListenerEvent::AddressExpired(a) => ListenerEvent::AddressExpired(a), ListenerEvent::Error(e) => ListenerEvent::Error(f(e)), @@ -262,7 +277,7 @@ impl ListenerEvent { /// Returns `true` if this is an `Upgrade` listener event. pub fn is_upgrade(&self) -> bool { - matches!(self, ListenerEvent::Upgrade {..}) + matches!(self, ListenerEvent::Upgrade { .. }) } /// Try to turn this listener event into upgrade parts. @@ -270,7 +285,12 @@ impl ListenerEvent { /// Returns `None` if the event is not actually an upgrade, /// otherwise the upgrade and the remote address. pub fn into_upgrade(self) -> Option<(TUpgr, Multiaddr)> { - if let ListenerEvent::Upgrade { upgrade, remote_addr, .. } = self { + if let ListenerEvent::Upgrade { + upgrade, + remote_addr, + .. + } = self + { Some((upgrade, remote_addr)) } else { None @@ -346,25 +366,31 @@ impl TransportError { /// Applies a function to the the error in [`TransportError::Other`]. 
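// Illustrative sketch only (not part of this patch): draining a listener
// stream and keeping just the inbound upgrades via
// `ListenerEvent::into_upgrade`, as reformatted above. `listener` is assumed
// to be the stream returned by `Transport::listen_on`.
while let Some(event) = listener.next().await {
    if let Some((upgrade, remote_addr)) = event?.into_upgrade() {
        let connection = upgrade.await?;
        // ... proceed with `connection` negotiated from `remote_addr`.
    }
}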
pub fn map(self, map: impl FnOnce(TErr) -> TNewErr) -> TransportError { match self { - TransportError::MultiaddrNotSupported(addr) => TransportError::MultiaddrNotSupported(addr), + TransportError::MultiaddrNotSupported(addr) => { + TransportError::MultiaddrNotSupported(addr) + } TransportError::Other(err) => TransportError::Other(map(err)), } } } impl fmt::Display for TransportError -where TErr: fmt::Display, +where + TErr: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - TransportError::MultiaddrNotSupported(addr) => write!(f, "Multiaddr is not supported: {}", addr), + TransportError::MultiaddrNotSupported(addr) => { + write!(f, "Multiaddr is not supported: {}", addr) + } TransportError::Other(err) => write!(f, "{}", err), } } } impl Error for TransportError -where TErr: Error + 'static, +where + TErr: Error + 'static, { fn source(&self) -> Option<&(dyn Error + 'static)> { match self { diff --git a/core/src/transport/and_then.rs b/core/src/transport/and_then.rs index 22018729a07..51f5d88c2b6 100644 --- a/core/src/transport/and_then.rs +++ b/core/src/transport/and_then.rs @@ -19,9 +19,9 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - ConnectedPoint, either::EitherError, - transport::{Transport, TransportError, ListenerEvent} + transport::{ListenerEvent, Transport, TransportError}, + ConnectedPoint, }; use futures::{future::Either, prelude::*}; use multiaddr::Multiaddr; @@ -29,7 +29,10 @@ use std::{error, marker::PhantomPinned, pin::Pin, task::Context, task::Poll}; /// See the `Transport::and_then` method. #[derive(Debug, Clone)] -pub struct AndThen { transport: T, fun: C } +pub struct AndThen { + transport: T, + fun: C, +} impl AndThen { pub(crate) fn new(transport: T, fun: C) -> Self { @@ -51,17 +54,26 @@ where type Dial = AndThenFuture; fn listen_on(self, addr: Multiaddr) -> Result> { - let listener = self.transport.listen_on(addr).map_err(|err| err.map(EitherError::A))?; + let listener = self + .transport + .listen_on(addr) + .map_err(|err| err.map(EitherError::A))?; // Try to negotiate the protocol. // Note that failing to negotiate a protocol will never produce a future with an error. // Instead the `stream` will produce `Ok(Err(...))`. // `stream` can only produce an `Err` if `listening_stream` produces an `Err`. 
- let stream = AndThenStream { stream: listener, fun: self.fun }; + let stream = AndThenStream { + stream: listener, + fun: self.fun, + }; Ok(stream) } fn dial(self, addr: Multiaddr) -> Result> { - let dialed_fut = self.transport.dial(addr.clone()).map_err(|err| err.map(EitherError::A))?; + let dialed_fut = self + .transport + .dial(addr.clone()) + .map_err(|err| err.map(EitherError::A))?; let future = AndThenFuture { inner: Either::Left(Box::pin(dialed_fut)), args: Some((self.fun, ConnectedPoint::Dialer { address: addr })), @@ -83,19 +95,23 @@ where pub struct AndThenStream { #[pin] stream: TListener, - fun: TMap + fun: TMap, } -impl Stream for AndThenStream +impl Stream + for AndThenStream where TListener: TryStream, Error = TTransErr>, TListUpgr: TryFuture, TMap: FnOnce(TTransOut, ConnectedPoint) -> TMapOut + Clone, - TMapOut: TryFuture + TMapOut: TryFuture, { type Item = Result< - ListenerEvent, EitherError>, - EitherError + ListenerEvent< + AndThenFuture, + EitherError, + >, + EitherError, >; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -103,10 +119,14 @@ where match TryStream::try_poll_next(this.stream, cx) { Poll::Ready(Some(Ok(event))) => { let event = match event { - ListenerEvent::Upgrade { upgrade, local_addr, remote_addr } => { + ListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + } => { let point = ConnectedPoint::Listener { local_addr: local_addr.clone(), - send_back_addr: remote_addr.clone() + send_back_addr: remote_addr.clone(), }; ListenerEvent::Upgrade { upgrade: AndThenFuture { @@ -115,7 +135,7 @@ where marker: PhantomPinned, }, local_addr, - remote_addr + remote_addr, } } ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), @@ -127,7 +147,7 @@ where } Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(EitherError::A(err)))), Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending + Poll::Pending => Poll::Pending, } } } @@ -159,7 +179,10 @@ where Poll::Ready(Err(err)) => return Poll::Ready(Err(EitherError::A(err))), Poll::Pending => return Poll::Pending, }; - let (f, a) = self.args.take().expect("AndThenFuture has already finished."); + let (f, a) = self + .args + .take() + .expect("AndThenFuture has already finished."); f(item, a) } Either::Right(future) => { @@ -176,5 +199,4 @@ where } } -impl Unpin for AndThenFuture { -} +impl Unpin for AndThenFuture {} diff --git a/core/src/transport/boxed.rs b/core/src/transport/boxed.rs index 5322b517dbe..001a0c9fdf3 100644 --- a/core/src/transport/boxed.rs +++ b/core/src/transport/boxed.rs @@ -45,7 +45,8 @@ pub struct Boxed { } type Dial = Pin> + Send>>; -type Listener = Pin, io::Error>>> + Send>>; +type Listener = + Pin, io::Error>>> + Send>>; type ListenerUpgrade = Pin> + Send>>; trait Abstract { @@ -64,12 +65,16 @@ where { fn listen_on(&self, addr: Multiaddr) -> Result, TransportError> { let listener = Transport::listen_on(self.clone(), addr).map_err(|e| e.map(box_err))?; - let fut = listener.map_ok(|event| - event.map(|upgrade| { - let up = upgrade.map_err(box_err); - Box::pin(up) as ListenerUpgrade - }).map_err(box_err) - ).map_err(box_err); + let fut = listener + .map_ok(|event| { + event + .map(|upgrade| { + let up = upgrade.map_err(box_err); + Box::pin(up) as ListenerUpgrade + }) + .map_err(box_err) + }) + .map_err(box_err); Ok(Box::pin(fut)) } diff --git a/core/src/transport/choice.rs b/core/src/transport/choice.rs index 3488b06884d..e9545617f09 100644 --- a/core/src/transport/choice.rs +++ b/core/src/transport/choice.rs @@ -18,7 +18,7 @@ // FROM, OUT OF 
OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::either::{EitherListenStream, EitherOutput, EitherError, EitherFuture}; +use crate::either::{EitherError, EitherFuture, EitherListenStream, EitherOutput}; use crate::transport::{Transport, TransportError}; use multiaddr::Multiaddr; @@ -47,13 +47,17 @@ where let addr = match self.0.listen_on(addr) { Ok(listener) => return Ok(EitherListenStream::First(listener)), Err(TransportError::MultiaddrNotSupported(addr)) => addr, - Err(TransportError::Other(err)) => return Err(TransportError::Other(EitherError::A(err))), + Err(TransportError::Other(err)) => { + return Err(TransportError::Other(EitherError::A(err))) + } }; let addr = match self.1.listen_on(addr) { Ok(listener) => return Ok(EitherListenStream::Second(listener)), Err(TransportError::MultiaddrNotSupported(addr)) => addr, - Err(TransportError::Other(err)) => return Err(TransportError::Other(EitherError::B(err))), + Err(TransportError::Other(err)) => { + return Err(TransportError::Other(EitherError::B(err))) + } }; Err(TransportError::MultiaddrNotSupported(addr)) @@ -63,13 +67,17 @@ where let addr = match self.0.dial(addr) { Ok(connec) => return Ok(EitherFuture::First(connec)), Err(TransportError::MultiaddrNotSupported(addr)) => addr, - Err(TransportError::Other(err)) => return Err(TransportError::Other(EitherError::A(err))), + Err(TransportError::Other(err)) => { + return Err(TransportError::Other(EitherError::A(err))) + } }; let addr = match self.1.dial(addr) { Ok(connec) => return Ok(EitherFuture::Second(connec)), Err(TransportError::MultiaddrNotSupported(addr)) => addr, - Err(TransportError::Other(err)) => return Err(TransportError::Other(EitherError::B(err))), + Err(TransportError::Other(err)) => { + return Err(TransportError::Other(EitherError::B(err))) + } }; Err(TransportError::MultiaddrNotSupported(addr)) diff --git a/core/src/transport/dummy.rs b/core/src/transport/dummy.rs index 5839a6a5928..a4eaa14901d 100644 --- a/core/src/transport/dummy.rs +++ b/core/src/transport/dummy.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
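// Illustrative sketch only (not part of this patch): `Transport::or_transport`
// builds the `OrTransport` shown above; the second transport is tried whenever
// the first rejects an address with `MultiaddrNotSupported`. Two in-memory
// transports are paired here purely for illustration.
let transport = MemoryTransport::default().or_transport(MemoryTransport::default());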
-use crate::transport::{Transport, TransportError, ListenerEvent}; +use crate::transport::{ListenerEvent, Transport, TransportError}; use crate::Multiaddr; use futures::{prelude::*, task::Context, task::Poll}; use std::{fmt, io, marker::PhantomData, pin::Pin}; @@ -56,7 +56,9 @@ impl Clone for DummyTransport { impl Transport for DummyTransport { type Output = TOut; type Error = io::Error; - type Listener = futures::stream::Pending, Self::Error>>; + type Listener = futures::stream::Pending< + Result, Self::Error>, + >; type ListenerUpgrade = futures::future::Pending>; type Dial = futures::future::Pending>; @@ -83,29 +85,29 @@ impl fmt::Debug for DummyStream { } impl AsyncRead for DummyStream { - fn poll_read(self: Pin<&mut Self>, _: &mut Context<'_>, _: &mut [u8]) - -> Poll> - { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &mut [u8], + ) -> Poll> { Poll::Ready(Err(io::ErrorKind::Other.into())) } } impl AsyncWrite for DummyStream { - fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) - -> Poll> - { + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &[u8], + ) -> Poll> { Poll::Ready(Err(io::ErrorKind::Other.into())) } - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) - -> Poll> - { + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(Err(io::ErrorKind::Other.into())) } - fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) - -> Poll> - { + fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(Err(io::ErrorKind::Other.into())) } } diff --git a/core/src/transport/map.rs b/core/src/transport/map.rs index 0305af6626d..4493507c1d9 100644 --- a/core/src/transport/map.rs +++ b/core/src/transport/map.rs @@ -19,8 +19,8 @@ // DEALINGS IN THE SOFTWARE. use crate::{ + transport::{ListenerEvent, Transport, TransportError}, ConnectedPoint, - transport::{Transport, TransportError, ListenerEvent} }; use futures::prelude::*; use multiaddr::Multiaddr; @@ -28,7 +28,10 @@ use std::{pin::Pin, task::Context, task::Poll}; /// See `Transport::map`. #[derive(Debug, Copy, Clone)] -pub struct Map { transport: T, fun: F } +pub struct Map { + transport: T, + fun: F, +} impl Map { pub(crate) fn new(transport: T, fun: F) -> Self { @@ -39,7 +42,7 @@ impl Map { impl Transport for Map where T: Transport, - F: FnOnce(T::Output, ConnectedPoint) -> D + Clone + F: FnOnce(T::Output, ConnectedPoint) -> D + Clone, { type Output = D; type Error = T::Error; @@ -49,13 +52,19 @@ where fn listen_on(self, addr: Multiaddr) -> Result> { let stream = self.transport.listen_on(addr)?; - Ok(MapStream { stream, fun: self.fun }) + Ok(MapStream { + stream, + fun: self.fun, + }) } fn dial(self, addr: Multiaddr) -> Result> { let future = self.transport.dial(addr.clone())?; let p = ConnectedPoint::Dialer { address: addr }; - Ok(MapFuture { inner: future, args: Some((self.fun, p)) }) + Ok(MapFuture { + inner: future, + args: Some((self.fun, p)), + }) } fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { @@ -68,13 +77,17 @@ where /// Maps a function over every stream item. 
#[pin_project::pin_project] #[derive(Clone, Debug)] -pub struct MapStream { #[pin] stream: T, fun: F } +pub struct MapStream { + #[pin] + stream: T, + fun: F, +} impl Stream for MapStream where T: TryStream, Error = E>, X: TryFuture, - F: FnOnce(A, ConnectedPoint) -> B + Clone + F: FnOnce(A, ConnectedPoint) -> B + Clone, { type Item = Result, E>, E>; @@ -83,18 +96,22 @@ where match TryStream::try_poll_next(this.stream, cx) { Poll::Ready(Some(Ok(event))) => { let event = match event { - ListenerEvent::Upgrade { upgrade, local_addr, remote_addr } => { + ListenerEvent::Upgrade { + upgrade, + local_addr, + remote_addr, + } => { let point = ConnectedPoint::Listener { local_addr: local_addr.clone(), - send_back_addr: remote_addr.clone() + send_back_addr: remote_addr.clone(), }; ListenerEvent::Upgrade { upgrade: MapFuture { inner: upgrade, - args: Some((this.fun.clone(), point)) + args: Some((this.fun.clone(), point)), }, local_addr, - remote_addr + remote_addr, } } ListenerEvent::NewAddress(a) => ListenerEvent::NewAddress(a), @@ -105,7 +122,7 @@ where } Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending + Poll::Pending => Poll::Pending, } } } @@ -118,13 +135,13 @@ where pub struct MapFuture { #[pin] inner: T, - args: Option<(F, ConnectedPoint)> + args: Option<(F, ConnectedPoint)>, } impl Future for MapFuture where T: TryFuture, - F: FnOnce(A, ConnectedPoint) -> B + F: FnOnce(A, ConnectedPoint) -> B, { type Output = Result; diff --git a/core/src/transport/map_err.rs b/core/src/transport/map_err.rs index c0be6485204..df26214435a 100644 --- a/core/src/transport/map_err.rs +++ b/core/src/transport/map_err.rs @@ -18,7 +18,7 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::transport::{Transport, TransportError, ListenerEvent}; +use crate::transport::{ListenerEvent, Transport, TransportError}; use futures::prelude::*; use multiaddr::Multiaddr; use std::{error, pin::Pin, task::Context, task::Poll}; @@ -53,14 +53,17 @@ where let map = self.map; match self.transport.listen_on(addr) { Ok(stream) => Ok(MapErrListener { inner: stream, map }), - Err(err) => Err(err.map(map)) + Err(err) => Err(err.map(map)), } } fn dial(self, addr: Multiaddr) -> Result> { let map = self.map; match self.transport.dial(addr) { - Ok(future) => Ok(MapErrDial { inner: future, map: Some(map) }), + Ok(future) => Ok(MapErrDial { + inner: future, + map: Some(map), + }), Err(err) => Err(err.map(map)), } } @@ -92,11 +95,9 @@ where Poll::Ready(Some(Ok(event))) => { let map = &*this.map; let event = event - .map(move |value| { - MapErrListenerUpgrade { - inner: value, - map: Some(map.clone()) - } + .map(move |value| MapErrListenerUpgrade { + inner: value, + map: Some(map.clone()), }) .map_err(|err| (map.clone())(err)); Poll::Ready(Some(Ok(event))) @@ -117,7 +118,8 @@ pub struct MapErrListenerUpgrade { } impl Future for MapErrListenerUpgrade -where T: Transport, +where + T: Transport, F: FnOnce(T::Error) -> TErr, { type Output = Result; diff --git a/core/src/transport/memory.rs b/core/src/transport/memory.rs index 043dcee06b7..3b4706c9adb 100644 --- a/core/src/transport/memory.rs +++ b/core/src/transport/memory.rs @@ -18,11 +18,20 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
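// Illustrative sketch only (not part of this patch): `Transport::map` and
// `Transport::map_err` compose output and error transformations over a base
// transport via the `Map`/`MapErr` adapters reformatted above.
let transport = MemoryTransport::default()
    .map(|output, _connected_point| output)
    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));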
-use crate::{Transport, transport::{TransportError, ListenerEvent}}; +use crate::{ + transport::{ListenerEvent, TransportError}, + Transport, +}; use fnv::FnvHashMap; -use futures::{future::{self, Ready}, prelude::*, channel::mpsc, task::Context, task::Poll}; +use futures::{ + channel::mpsc, + future::{self, Ready}, + prelude::*, + task::Context, + task::Poll, +}; use lazy_static::lazy_static; -use multiaddr::{Protocol, Multiaddr}; +use multiaddr::{Multiaddr, Protocol}; use parking_lot::Mutex; use rw_stream_sink::RwStreamSink; use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin}; @@ -66,7 +75,7 @@ impl Hub { let (tx, rx) = mpsc::channel(2); match hub.entry(port) { Entry::Occupied(_) => return None, - Entry::Vacant(e) => e.insert(tx) + Entry::Vacant(e) => e.insert(tx), }; Some((rx, port)) @@ -103,7 +112,8 @@ impl DialFuture { fn new(port: NonZeroU64) -> Option { let sender = HUB.get(&port)?; - let (_dial_port_channel, dial_port) = HUB.register_port(0) + let (_dial_port_channel, dial_port) = HUB + .register_port(0) .expect("there to be some random unoccupied port."); let (a_tx, a_rx) = mpsc::channel(4096); @@ -129,14 +139,15 @@ impl Future for DialFuture { type Output = Result>, MemoryTransportError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.sender.poll_ready(cx) { Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)), } - let channel_to_send = self.channel_to_send.take() + let channel_to_send = self + .channel_to_send + .take() .expect("Future should not be polled again once complete"); let dial_port = self.dial_port; match self.sender.start_send((channel_to_send, dial_port)) { @@ -144,8 +155,10 @@ impl Future for DialFuture { Ok(()) => {} } - Poll::Ready(Ok(self.channel_to_return.take() - .expect("Future should not be polled again once complete"))) + Poll::Ready(Ok(self + .channel_to_return + .take() + .expect("Future should not be polled again once complete"))) } } @@ -172,7 +185,7 @@ impl Transport for MemoryTransport { port, addr: Protocol::Memory(port.get()).into(), receiver: rx, - tell_listen_addr: true + tell_listen_addr: true, }; Ok(listener) @@ -226,16 +239,19 @@ pub struct Listener { /// Receives incoming connections. receiver: ChannelReceiver, /// Generate `ListenerEvent::NewAddress` to inform about our listen address. 
- tell_listen_addr: bool + tell_listen_addr: bool, } impl Stream for Listener { - type Item = Result>, MemoryTransportError>>, MemoryTransportError>, MemoryTransportError>; + type Item = Result< + ListenerEvent>, MemoryTransportError>>, MemoryTransportError>, + MemoryTransportError, + >; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.tell_listen_addr { self.tell_listen_addr = false; - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone())))) + return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone())))); } let (channel, dial_port) = match Stream::poll_next(Pin::new(&mut self.receiver), cx) { @@ -247,7 +263,7 @@ impl Stream for Listener { let event = ListenerEvent::Upgrade { upgrade: future::ready(Ok(channel)), local_addr: self.addr.clone(), - remote_addr: Protocol::Memory(dial_port.get()).into() + remote_addr: Protocol::Memory(dial_port.get()).into(), }; Poll::Ready(Some(Ok(event))) @@ -267,9 +283,9 @@ fn parse_memory_addr(a: &Multiaddr) -> Result { match protocols.next() { Some(Protocol::Memory(port)) => match protocols.next() { None | Some(Protocol::P2p(_)) => Ok(port), - _ => Err(()) - } - _ => Err(()) + _ => Err(()), + }, + _ => Err(()), } } @@ -294,8 +310,7 @@ pub struct Chan> { dial_port: Option, } -impl Unpin for Chan { -} +impl Unpin for Chan {} impl Stream for Chan { type Item = Result; @@ -313,12 +328,15 @@ impl Sink for Chan { type Error = io::Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.outgoing.poll_ready(cx) + self.outgoing + .poll_ready(cx) .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into())) } fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { - self.outgoing.start_send(item).map_err(|_| io::ErrorKind::BrokenPipe.into()) + self.outgoing + .start_send(item) + .map_err(|_| io::ErrorKind::BrokenPipe.into()) } fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { @@ -354,30 +372,59 @@ mod tests { assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5)); assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(())); assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0)); - assert_eq!(parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()), Err(())); - assert_eq!(parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()), Err(())); - assert_eq!(parse_memory_addr(&"/memory/1234567890".parse().unwrap()), Ok(1_234_567_890)); + assert_eq!( + parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()), + Err(()) + ); + assert_eq!( + parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()), + Err(()) + ); + assert_eq!( + parse_memory_addr(&"/memory/1234567890".parse().unwrap()), + Ok(1_234_567_890) + ); } #[test] fn listening_twice() { let transport = MemoryTransport::default(); - assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok()); - assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok()); - let _listener = transport.listen_on("/memory/1639174018481".parse().unwrap()).unwrap(); - assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_err()); - assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_err()); + assert!(transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .is_ok()); + assert!(transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .is_ok()); + let _listener = transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .unwrap(); + assert!(transport + 
.listen_on("/memory/1639174018481".parse().unwrap()) + .is_err()); + assert!(transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .is_err()); drop(_listener); - assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok()); - assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok()); + assert!(transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .is_ok()); + assert!(transport + .listen_on("/memory/1639174018481".parse().unwrap()) + .is_ok()); } #[test] fn port_not_in_use() { let transport = MemoryTransport::default(); - assert!(transport.dial("/memory/810172461024613".parse().unwrap()).is_err()); - let _listener = transport.listen_on("/memory/810172461024613".parse().unwrap()).unwrap(); - assert!(transport.dial("/memory/810172461024613".parse().unwrap()).is_ok()); + assert!(transport + .dial("/memory/810172461024613".parse().unwrap()) + .is_err()); + let _listener = transport + .listen_on("/memory/810172461024613".parse().unwrap()) + .unwrap(); + assert!(transport + .dial("/memory/810172461024613".parse().unwrap()) + .is_ok()); } #[test] @@ -395,9 +442,11 @@ mod tests { let listener = async move { let listener = t1.listen_on(t1_addr.clone()).unwrap(); - let upgrade = listener.filter_map(|ev| futures::future::ready( - ListenerEvent::into_upgrade(ev.unwrap()) - )).next().await.unwrap(); + let upgrade = listener + .filter_map(|ev| futures::future::ready(ListenerEvent::into_upgrade(ev.unwrap()))) + .next() + .await + .unwrap(); let mut socket = upgrade.0.await.unwrap(); @@ -422,16 +471,14 @@ mod tests { #[test] fn dialer_address_unequal_to_listener_address() { - let listener_addr: Multiaddr = Protocol::Memory( - rand::random::().saturating_add(1), - ).into(); + let listener_addr: Multiaddr = + Protocol::Memory(rand::random::().saturating_add(1)).into(); let listener_addr_cloned = listener_addr.clone(); let listener_transport = MemoryTransport::default(); let listener = async move { - let mut listener = listener_transport.listen_on(listener_addr.clone()) - .unwrap(); + let mut listener = listener_transport.listen_on(listener_addr.clone()).unwrap(); while let Some(ev) = listener.next().await { if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() { assert!( @@ -444,7 +491,8 @@ mod tests { }; let dialer = async move { - MemoryTransport::default().dial(listener_addr_cloned) + MemoryTransport::default() + .dial(listener_addr_cloned) .unwrap() .await .unwrap(); @@ -458,21 +506,18 @@ mod tests { let (terminate, should_terminate) = futures::channel::oneshot::channel(); let (terminated, is_terminated) = futures::channel::oneshot::channel(); - let listener_addr: Multiaddr = Protocol::Memory( - rand::random::().saturating_add(1), - ).into(); + let listener_addr: Multiaddr = + Protocol::Memory(rand::random::().saturating_add(1)).into(); let listener_addr_cloned = listener_addr.clone(); let listener_transport = MemoryTransport::default(); let listener = async move { - let mut listener = listener_transport.listen_on(listener_addr.clone()) - .unwrap(); + let mut listener = listener_transport.listen_on(listener_addr.clone()).unwrap(); while let Some(ev) = listener.next().await { if let ListenerEvent::Upgrade { remote_addr, .. 
} = ev.unwrap() { - let dialer_port = NonZeroU64::new( - parse_memory_addr(&remote_addr).unwrap(), - ).unwrap(); + let dialer_port = + NonZeroU64::new(parse_memory_addr(&remote_addr).unwrap()).unwrap(); assert!( HUB.get(&dialer_port).is_some(), @@ -493,7 +538,8 @@ mod tests { }; let dialer = async move { - let _chan = MemoryTransport::default().dial(listener_addr_cloned) + let _chan = MemoryTransport::default() + .dial(listener_addr_cloned) .unwrap() .await .unwrap(); diff --git a/core/src/transport/timeout.rs b/core/src/transport/timeout.rs index d55d007df08..8084dcb7521 100644 --- a/core/src/transport/timeout.rs +++ b/core/src/transport/timeout.rs @@ -24,7 +24,10 @@ //! underlying `Transport`. // TODO: add example -use crate::{Multiaddr, Transport, transport::{TransportError, ListenerEvent}}; +use crate::{ + transport::{ListenerEvent, TransportError}, + Multiaddr, Transport, +}; use futures::prelude::*; use futures_timer::Delay; use std::{error, fmt, io, pin::Pin, task::Context, task::Poll, time::Duration}; @@ -82,7 +85,9 @@ where type Dial = Timeout; fn listen_on(self, addr: Multiaddr) -> Result> { - let listener = self.inner.listen_on(addr) + let listener = self + .inner + .listen_on(addr) .map_err(|err| err.map(TransportTimeoutError::Other))?; let listener = TimeoutListener { @@ -94,7 +99,9 @@ where } fn dial(self, addr: Multiaddr) -> Result> { - let dial = self.inner.dial(addr) + let dial = self + .inner + .dial(addr) .map_err(|err| err.map(TransportTimeoutError::Other))?; Ok(Timeout { inner: dial, @@ -120,13 +127,16 @@ impl Stream for TimeoutListener where InnerStream: TryStream, Error = E>, { - type Item = Result, TransportTimeoutError>, TransportTimeoutError>; + type Item = + Result, TransportTimeoutError>, TransportTimeoutError>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); let poll_out = match TryStream::try_poll_next(this.inner, cx) { - Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(TransportTimeoutError::Other(err)))), + Poll::Ready(Some(Err(err))) => { + return Poll::Ready(Some(Err(TransportTimeoutError::Other(err)))) + } Poll::Ready(Some(Ok(v))) => v, Poll::Ready(None) => return Poll::Ready(None), Poll::Pending => return Poll::Pending, @@ -134,11 +144,9 @@ where let timeout = *this.timeout; let event = poll_out - .map(move |inner_fut| { - Timeout { - inner: inner_fut, - timer: Delay::new(timeout), - } + .map(move |inner_fut| Timeout { + inner: inner_fut, + timer: Delay::new(timeout), }) .map_err(TransportTimeoutError::Other); @@ -173,14 +181,14 @@ where let mut this = self.project(); match TryFuture::try_poll(this.inner, cx) { - Poll::Pending => {}, + Poll::Pending => {} Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)), Poll::Ready(Err(err)) => return Poll::Ready(Err(TransportTimeoutError::Other(err))), } match Pin::new(&mut this.timer).poll(cx) { Poll::Pending => Poll::Pending, - Poll::Ready(()) => Poll::Ready(Err(TransportTimeoutError::Timeout)) + Poll::Ready(()) => Poll::Ready(Err(TransportTimeoutError::Timeout)), } } } @@ -197,7 +205,8 @@ pub enum TransportTimeoutError { } impl fmt::Display for TransportTimeoutError -where TErr: fmt::Display, +where + TErr: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -209,7 +218,8 @@ where TErr: fmt::Display, } impl error::Error for TransportTimeoutError -where TErr: error::Error + 'static, +where + TErr: error::Error + 'static, { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { diff --git 
a/core/src/transport/upgrade.rs b/core/src/transport/upgrade.rs index 50bd15f53e9..4a0d50392ff 100644 --- a/core/src/transport/upgrade.rs +++ b/core/src/transport/upgrade.rs @@ -20,39 +20,24 @@ //! Configuration of transport protocol upgrades. - use crate::{ - ConnectedPoint, - Negotiated, + muxing::{StreamMuxer, StreamMuxerBox}, transport::{ - Transport, - TransportError, - and_then::AndThen, - boxed::boxed, - timeout::TransportTimeout, + and_then::AndThen, boxed::boxed, timeout::TransportTimeout, Transport, TransportError, }, - muxing::{StreamMuxer, StreamMuxerBox}, upgrade::{ - self, - Role, - Version, - AuthenticationVersion, - OutboundUpgrade, - InboundUpgrade, - UpgradeError, - OutboundUpgradeApply, - InboundUpgradeApply, - AuthenticationUpgradeApply, + self, AuthenticationUpgradeApply, AuthenticationVersion, InboundUpgrade, + InboundUpgradeApply, OutboundUpgrade, OutboundUpgradeApply, Role, UpgradeError, Version, }, - PeerId + ConnectedPoint, Negotiated, PeerId, }; -use futures::{prelude::*, ready, future::Either}; +use futures::{future::Either, prelude::*, ready}; use multiaddr::Multiaddr; use std::{ error::Error, pin::Pin, task::{Context, Poll}, - time::Duration + time::Duration, }; /// A `Builder` facilitates upgrading of a [`Transport`] for use with @@ -103,9 +88,13 @@ where /// /// * I/O upgrade: `C -> (PeerId, D)`. /// * Transport output: `C -> (PeerId, D)` - pub fn authenticate(self, upgrade: U) -> Authenticated< - AndThen AuthenticationUpgradeApply + Clone> - > where + pub fn authenticate( + self, + upgrade: U, + ) -> Authenticated< + AndThen AuthenticationUpgradeApply + Clone>, + > + where T: Transport, C: AsyncRead + AsyncWrite + Unpin, D: AsyncRead + AsyncWrite + Unpin, @@ -118,9 +107,14 @@ where /// Same as [`Builder::authenticate`] with the option to choose the /// [`AuthenticationVersion`] used to upgrade the connection. - pub fn authenticate_with_version(self, upgrade: U, version: AuthenticationVersion) -> Authenticated< - AndThen AuthenticationUpgradeApply + Clone> - > where + pub fn authenticate_with_version( + self, + upgrade: U, + version: AuthenticationVersion, + ) -> Authenticated< + AndThen AuthenticationUpgradeApply + Clone>, + > + where T: Transport, C: AsyncRead + AsyncWrite + Unpin, D: AsyncRead + AsyncWrite + Unpin, @@ -129,7 +123,7 @@ where E: Error + 'static, { Authenticated(Builder::new(self.inner.and_then(move |conn, endpoint| { - upgrade::apply_authentication(conn, upgrade, endpoint, version) + upgrade::apply_authentication(conn, upgrade, endpoint, version) }))) } } @@ -141,7 +135,7 @@ pub struct Authenticated(Builder); impl Authenticated where T: Transport, - T::Error: 'static + T::Error: 'static, { /// Applies an arbitrary upgrade. /// @@ -153,7 +147,19 @@ where /// /// * I/O upgrade: `C -> D`. /// * Transport output: `(PeerId, C) -> (PeerId, D)`. - pub fn apply(self, upgrade: U) -> Authenticated UpgradeAuthenticated + Clone>> + pub fn apply( + self, + upgrade: U, + ) -> Authenticated< + AndThen< + T, + impl FnOnce( + ((PeerId, Role), C), + ConnectedPoint, + ) -> UpgradeAuthenticated + + Clone, + >, + > where T: Transport, C: AsyncRead + AsyncWrite + Unpin, @@ -167,7 +173,20 @@ where /// Same as [`Authenticated::apply`] with the option to choose the /// [`Version`] used to upgrade the connection. 
- pub fn apply_with_version(self, upgrade: U, version: Version) -> Authenticated UpgradeAuthenticated + Clone>> + pub fn apply_with_version( + self, + upgrade: U, + version: Version, + ) -> Authenticated< + AndThen< + T, + impl FnOnce( + ((PeerId, Role), C), + ConnectedPoint, + ) -> UpgradeAuthenticated + + Clone, + >, + > where T: Transport, C: AsyncRead + AsyncWrite + Unpin, @@ -176,18 +195,18 @@ where U: OutboundUpgrade, Output = D, Error = E> + Clone, E: Error + 'static, { - Authenticated(Builder::new(self.0.inner.and_then(move |((i, r), c), _endpoint| { - let upgrade = match r { - Role::Initiator => { - Either::Left(upgrade::apply_outbound(c, upgrade, version)) - }, - Role::Responder => { - Either::Right(upgrade::apply_inbound(c, upgrade)) - + Authenticated(Builder::new(self.0.inner.and_then( + move |((i, r), c), _endpoint| { + let upgrade = match r { + Role::Initiator => Either::Left(upgrade::apply_outbound(c, upgrade, version)), + Role::Responder => Either::Right(upgrade::apply_inbound(c, upgrade)), + }; + UpgradeAuthenticated { + user_data: Some((i, r)), + upgrade, } - }; - UpgradeAuthenticated { user_data: Some((i, r)), upgrade } - }))) + }, + ))) } /// Upgrades the transport with a (sub)stream multiplexer. @@ -200,9 +219,17 @@ where /// /// * I/O upgrade: `C -> M`. /// * Transport output: `(PeerId, C) -> (PeerId, M)`. - pub fn multiplex(self, upgrade: U) -> Multiplexed< - AndThen UpgradeAuthenticated + Clone> - > where + pub fn multiplex( + self, + upgrade: U, + ) -> Multiplexed< + AndThen< + T, + impl FnOnce(((PeerId, Role), C), ConnectedPoint) -> UpgradeAuthenticated + + Clone, + >, + > + where T: Transport, C: AsyncRead + AsyncWrite + Unpin, M: StreamMuxer, @@ -215,9 +242,18 @@ where /// Same as [`Authenticated::multiplex`] with the option to choose the /// [`Version`] used to upgrade the connection. 
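// Illustrative sketch only (not part of this patch): a typical builder chain
// with an explicit negotiation version. `base_transport`, `auth_config` (an
// authentication upgrade yielding `(PeerId, _)`) and `muxer_config` are
// assumed bindings, and the version-less `Transport::upgrade()` entry point
// shown earlier in this patch is assumed.
let transport = base_transport
    .upgrade()
    .authenticate_with_version(auth_config, AuthenticationVersion::V1SimultaneousOpen)
    .multiplex(muxer_config)
    .boxed();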
- pub fn multiplex_with_version(self, upgrade: U, version: Version) -> Multiplexed< - AndThen UpgradeAuthenticated + Clone> - > where + pub fn multiplex_with_version( + self, + upgrade: U, + version: Version, + ) -> Multiplexed< + AndThen< + T, + impl FnOnce(((PeerId, Role), C), ConnectedPoint) -> UpgradeAuthenticated + + Clone, + >, + > + where T: Transport, C: AsyncRead + AsyncWrite + Unpin, M: StreamMuxer, @@ -227,15 +263,13 @@ where { Multiplexed(self.0.inner.and_then(move |((i, r), c), _endpoint| { let upgrade = match r { - Role::Initiator => { - Either::Left(upgrade::apply_outbound(c, upgrade, version)) - }, - Role::Responder => { - Either::Right(upgrade::apply_inbound(c, upgrade)) - - } + Role::Initiator => Either::Left(upgrade::apply_outbound(c, upgrade, version)), + Role::Responder => Either::Right(upgrade::apply_inbound(c, upgrade)), }; - UpgradeAuthenticated { user_data: Some(i), upgrade } + UpgradeAuthenticated { + user_data: Some(i), + upgrade, + } })) } } @@ -259,7 +293,7 @@ impl Future for UpgradeAuthenticated where C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = M, Error = E>, - U: OutboundUpgrade, Output = M, Error = E> + U: OutboundUpgrade, Output = M, Error = E>, { type Output = Result<(D, M), UpgradeError>; @@ -269,7 +303,10 @@ where Ok(m) => m, Err(err) => return Poll::Ready(Err(err)), }; - let user_data = this.user_data.take().expect("UpgradeAuthenticated future polled after completion."); + let user_data = this + .user_data + .take() + .expect("UpgradeAuthenticated future polled after completion."); Poll::Ready(Ok((user_data, m))) } } @@ -291,7 +328,7 @@ impl Multiplexed { T::Error: Send + Sync, M: StreamMuxer + Send + Sync + 'static, M::Substream: Send + 'static, - M::OutboundSubstream: Send + 'static + M::OutboundSubstream: Send + 'static, { boxed(self.map(|(i, m), _| (i, StreamMuxerBox::new(m)))) } diff --git a/core/src/upgrade.rs b/core/src/upgrade.rs index bb637c63db6..bdc53153074 100644 --- a/core/src/upgrade.rs +++ b/core/src/upgrade.rs @@ -69,24 +69,24 @@ mod transfer; use futures::future::Future; -pub use crate::Negotiated; -pub use multistream_select::{NegotiatedComplete, NegotiationError, ProtocolError, Role}; +#[allow(deprecated)] +pub use self::transfer::ReadOneError; pub use self::{ apply::{ - apply, apply_authentication, apply_inbound, apply_outbound, InboundUpgradeApply, - OutboundUpgradeApply, AuthenticationUpgradeApply, Version, AuthenticationVersion, + apply, apply_authentication, apply_inbound, apply_outbound, AuthenticationUpgradeApply, + AuthenticationVersion, InboundUpgradeApply, OutboundUpgradeApply, Version, }, denied::DeniedUpgrade, either::EitherUpgrade, error::UpgradeError, from_fn::{from_fn, FromFnUpgrade}, - map::{MapInboundUpgrade, MapOutboundUpgrade, MapInboundUpgradeErr, MapOutboundUpgradeErr}, + map::{MapInboundUpgrade, MapInboundUpgradeErr, MapOutboundUpgrade, MapOutboundUpgradeErr}, optional::OptionalUpgrade, select::SelectUpgrade, - transfer::{write_length_prefixed, write_varint, read_length_prefixed, read_varint}, + transfer::{read_length_prefixed, read_varint, write_length_prefixed, write_varint}, }; -#[allow(deprecated)] -pub use self::transfer::ReadOneError; +pub use crate::Negotiated; +pub use multistream_select::{NegotiatedComplete, NegotiationError, ProtocolError, Role}; /// Types serving as protocol names. 
/// @@ -170,7 +170,7 @@ pub trait InboundUpgradeExt: InboundUpgrade { fn map_inbound(self, f: F) -> MapInboundUpgrade where Self: Sized, - F: FnOnce(Self::Output) -> T + F: FnOnce(Self::Output) -> T, { MapInboundUpgrade::new(self, f) } @@ -179,7 +179,7 @@ pub trait InboundUpgradeExt: InboundUpgrade { fn map_inbound_err(self, f: F) -> MapInboundUpgradeErr where Self: Sized, - F: FnOnce(Self::Error) -> T + F: FnOnce(Self::Error) -> T, { MapInboundUpgradeErr::new(self, f) } @@ -210,7 +210,7 @@ pub trait OutboundUpgradeExt: OutboundUpgrade { fn map_outbound(self, f: F) -> MapOutboundUpgrade where Self: Sized, - F: FnOnce(Self::Output) -> T + F: FnOnce(Self::Output) -> T, { MapOutboundUpgrade::new(self, f) } @@ -219,7 +219,7 @@ pub trait OutboundUpgradeExt: OutboundUpgrade { fn map_outbound_err(self, f: F) -> MapOutboundUpgradeErr where Self: Sized, - F: FnOnce(Self::Error) -> T + F: FnOnce(Self::Error) -> T, { MapOutboundUpgradeErr::new(self, f) } diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index 97dddc07663..b63450468ae 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -18,14 +18,17 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use crate::upgrade::{InboundUpgrade, OutboundUpgrade, ProtocolName, UpgradeError}; use crate::{ConnectedPoint, Negotiated, PeerId}; -use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError, ProtocolName}; -use futures::{future::{Either, TryFutureExt, MapOk}, prelude::*}; +use futures::{ + future::{Either, MapOk, TryFutureExt}, + prelude::*, +}; use log::debug; use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture}; use std::{iter, mem, pin::Pin, task::Context, task::Poll}; -pub use multistream_select::{Role, NegotiationError}; +pub use multistream_select::{NegotiationError, Role}; /// Wrapper around multistream-select `Version`. /// @@ -52,7 +55,9 @@ impl Default for Version { match multistream_select::Version::default() { multistream_select::Version::V1 => Version::V1, multistream_select::Version::V1Lazy => Version::V1Lazy, - multistream_select::Version::V1SimultaneousOpen => unreachable!("see `v1_sim_open_is_not_default`"), + multistream_select::Version::V1SimultaneousOpen => { + unreachable!("see `v1_sim_open_is_not_default`") + } } } } @@ -61,8 +66,12 @@ impl Default for Version { /// /// Note: Use [`apply_authentication`] when negotiating an authentication protocol on top of a /// transport allowing simultaneously opened connections. 
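// Illustrative sketch only (not part of this patch): the role-aware entry
// point for simultaneous open mentioned in the note above. `conn`,
// `auth_upgrade` and `connected_point` are assumed bindings; the returned
// future yields the authenticated peer together with the `Role` decided
// during negotiation.
let ((peer_id, role), io) = upgrade::apply_authentication(
    conn,
    auth_upgrade,
    connected_point,
    AuthenticationVersion::V1SimultaneousOpen,
)
.await?;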
-pub fn apply(conn: C, up: U, cp: ConnectedPoint, v: Version) - -> Either, OutboundUpgradeApply> +pub fn apply( + conn: C, + up: U, + cp: ConnectedPoint, + v: Version, +) -> Either, OutboundUpgradeApply> where C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade> + OutboundUpgrade>, @@ -80,10 +89,16 @@ where C: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade>, { - let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); + let iter = up + .protocol_info() + .into_iter() + .map(NameWrap as fn(_) -> NameWrap<_>); let future = multistream_select::listener_select_proto(conn, iter); InboundUpgradeApply { - inner: InboundUpgradeApplyState::Init { future, upgrade: up } + inner: InboundUpgradeApplyState::Init { + future, + upgrade: up, + }, } } @@ -91,12 +106,18 @@ where pub fn apply_outbound(conn: C, up: U, v: Version) -> OutboundUpgradeApply where C: AsyncRead + AsyncWrite + Unpin, - U: OutboundUpgrade> + U: OutboundUpgrade>, { - let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); + let iter = up + .protocol_info() + .into_iter() + .map(NameWrap as fn(_) -> NameWrap<_>); let future = multistream_select::dialer_select_proto(conn, iter, v.into()); OutboundUpgradeApply { - inner: OutboundUpgradeApplyState::Init { future, upgrade: up } + inner: OutboundUpgradeApplyState::Init { + future, + upgrade: up, + }, } } @@ -104,9 +125,9 @@ where pub struct InboundUpgradeApply where C: AsyncRead + AsyncWrite + Unpin, - U: InboundUpgrade> + U: InboundUpgrade>, { - inner: InboundUpgradeApplyState + inner: InboundUpgradeApplyState, } enum InboundUpgradeApplyState @@ -119,9 +140,9 @@ where upgrade: U, }, Upgrade { - future: Pin> + future: Pin>, }, - Undefined + Undefined, } impl Unpin for InboundUpgradeApply @@ -141,36 +162,40 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { match mem::replace(&mut self.inner, InboundUpgradeApplyState::Undefined) { - InboundUpgradeApplyState::Init { mut future, upgrade } => { + InboundUpgradeApplyState::Init { + mut future, + upgrade, + } => { let (info, io) = match Future::poll(Pin::new(&mut future), cx)? 
{ Poll::Ready(x) => x, Poll::Pending => { self.inner = InboundUpgradeApplyState::Init { future, upgrade }; - return Poll::Pending + return Poll::Pending; } }; self.inner = InboundUpgradeApplyState::Upgrade { - future: Box::pin(upgrade.upgrade_inbound(io, info.0)) + future: Box::pin(upgrade.upgrade_inbound(io, info.0)), }; } InboundUpgradeApplyState::Upgrade { mut future } => { match Future::poll(Pin::new(&mut future), cx) { Poll::Pending => { self.inner = InboundUpgradeApplyState::Upgrade { future }; - return Poll::Pending + return Poll::Pending; } Poll::Ready(Ok(x)) => { debug!("Successfully applied negotiated protocol"); - return Poll::Ready(Ok(x)) + return Poll::Ready(Ok(x)); } Poll::Ready(Err(e)) => { debug!("Failed to apply negotiated protocol"); - return Poll::Ready(Err(UpgradeError::Apply(e))) + return Poll::Ready(Err(UpgradeError::Apply(e))); } } } - InboundUpgradeApplyState::Undefined => + InboundUpgradeApplyState::Undefined => { panic!("InboundUpgradeApplyState::poll called after completion") + } } } } @@ -180,24 +205,24 @@ where pub struct OutboundUpgradeApply where C: AsyncRead + AsyncWrite + Unpin, - U: OutboundUpgrade> + U: OutboundUpgrade>, { - inner: OutboundUpgradeApplyState + inner: OutboundUpgradeApplyState, } enum OutboundUpgradeApplyState where C: AsyncRead + AsyncWrite + Unpin, - U: OutboundUpgrade> + U: OutboundUpgrade>, { Init { future: DialerSelectFuture::IntoIter>>, - upgrade: U + upgrade: U, }, Upgrade { - future: Pin> + future: Pin>, }, - Undefined + Undefined, } impl Unpin for OutboundUpgradeApply @@ -217,12 +242,15 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { match mem::replace(&mut self.inner, OutboundUpgradeApplyState::Undefined) { - OutboundUpgradeApplyState::Init { mut future, upgrade } => { + OutboundUpgradeApplyState::Init { + mut future, + upgrade, + } => { let (info, connection, role) = match Future::poll(Pin::new(&mut future), cx)? { Poll::Ready(x) => x, Poll::Pending => { self.inner = OutboundUpgradeApplyState::Init { future, upgrade }; - return Poll::Pending + return Poll::Pending; } }; assert_eq!( @@ -231,18 +259,18 @@ where as `Initiator` or fail.", ); self.inner = OutboundUpgradeApplyState::Upgrade { - future: Box::pin(upgrade.upgrade_outbound(connection, info.0)) + future: Box::pin(upgrade.upgrade_outbound(connection, info.0)), }; } OutboundUpgradeApplyState::Upgrade { mut future } => { match Future::poll(Pin::new(&mut future), cx) { Poll::Pending => { self.inner = OutboundUpgradeApplyState::Upgrade { future }; - return Poll::Pending + return Poll::Pending; } Poll::Ready(Ok(x)) => { debug!("Successfully applied negotiated protocol"); - return Poll::Ready(Ok(x)) + return Poll::Ready(Ok(x)); } Poll::Ready(Err(e)) => { debug!("Failed to apply negotiated protocol"); @@ -250,8 +278,9 @@ where } } } - OutboundUpgradeApplyState::Undefined => + OutboundUpgradeApplyState::Undefined => { panic!("OutboundUpgradeApplyState::poll called after completion") + } } } } @@ -267,7 +296,7 @@ pub enum AuthenticationVersion { /// See [`multistream_select::Version::V1Lazy`]. V1Lazy, /// See [`multistream_select::Version::V1SimultaneousOpen`]. 
- V1SimultaneousOpen + V1SimultaneousOpen, } impl Default for AuthenticationVersion { @@ -275,7 +304,9 @@ impl Default for AuthenticationVersion { match multistream_select::Version::default() { multistream_select::Version::V1 => AuthenticationVersion::V1, multistream_select::Version::V1Lazy => AuthenticationVersion::V1Lazy, - multistream_select::Version::V1SimultaneousOpen => AuthenticationVersion::V1SimultaneousOpen, + multistream_select::Version::V1SimultaneousOpen => { + AuthenticationVersion::V1SimultaneousOpen + } } } } @@ -285,7 +316,9 @@ impl From for multistream_select::Version { match v { AuthenticationVersion::V1 => multistream_select::Version::V1, AuthenticationVersion::V1Lazy => multistream_select::Version::V1Lazy, - AuthenticationVersion::V1SimultaneousOpen => multistream_select::Version::V1SimultaneousOpen, + AuthenticationVersion::V1SimultaneousOpen => { + multistream_select::Version::V1SimultaneousOpen + } } } } @@ -295,30 +328,40 @@ impl From for multistream_select::Version { /// Note: This is like [`apply`] with additional support for transports allowing simultaneously /// opened connections. Unless run on such transport and used to negotiate the authentication /// protocol you likely want to use [`apply`] instead of [`apply_authentication`]. -pub fn apply_authentication(conn: C, up: U, cp: ConnectedPoint, v: AuthenticationVersion) - -> AuthenticationUpgradeApply +pub fn apply_authentication( + conn: C, + up: U, + cp: ConnectedPoint, + v: AuthenticationVersion, +) -> AuthenticationUpgradeApply where C: AsyncRead + AsyncWrite + Unpin, D: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = (PeerId, D)>, - U: OutboundUpgrade, Output = (PeerId, D), Error = >>::Error> + Clone, - + U: OutboundUpgrade< + Negotiated, + Output = (PeerId, D), + Error = >>::Error, + > + Clone, { fn add_responder(input: (P, C)) -> (P, C, Role) { (input.0, input.1, Role::Responder) } - let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); + let iter = up + .protocol_info() + .into_iter() + .map(NameWrap as fn(_) -> NameWrap<_>); AuthenticationUpgradeApply { - inner: AuthenticationUpgradeApplyState::Init{ + inner: AuthenticationUpgradeApplyState::Init { future: match cp { ConnectedPoint::Dialer { .. } => Either::Left( multistream_select::dialer_select_proto(conn, iter, v.into()), ), ConnectedPoint::Listener { .. 
} => Either::Right( multistream_select::listener_select_proto(conn, iter) - .map_ok(add_responder as fn (_) -> _), + .map_ok(add_responder as fn(_) -> _), ), }, upgrade: up, @@ -330,7 +373,7 @@ pub struct AuthenticationUpgradeApply where U: InboundUpgrade> + OutboundUpgrade>, { - inner: AuthenticationUpgradeApplyState + inner: AuthenticationUpgradeApplyState, } impl Unpin for AuthenticationUpgradeApply @@ -346,10 +389,13 @@ where { Init { future: Either< - multistream_select::DialerSelectFuture::IntoIter>>, + multistream_select::DialerSelectFuture< + C, + NameWrapIter<::IntoIter>, + >, MapOk< ListenerSelectFuture>, - fn((NameWrap, Negotiated)) -> (NameWrap, Negotiated, Role) + fn((NameWrap, Negotiated)) -> (NameWrap, Negotiated, Role), >, >, upgrade: U, @@ -361,7 +407,7 @@ where Pin>>::Future>>, >, }, - Undefined + Undefined, } impl Future for AuthenticationUpgradeApply @@ -369,57 +415,65 @@ where C: AsyncRead + AsyncWrite + Unpin, D: AsyncRead + AsyncWrite + Unpin, U: InboundUpgrade, Output = (PeerId, D)>, - U: OutboundUpgrade, Output = (PeerId, D), Error = >>::Error> + Clone, + U: OutboundUpgrade< + Negotiated, + Output = (PeerId, D), + Error = >>::Error, + > + Clone, { - type Output = Result< - ((PeerId, Role), D), - UpgradeError<>>::Error>, - >; + type Output = + Result<((PeerId, Role), D), UpgradeError<>>::Error>>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { match mem::replace(&mut self.inner, AuthenticationUpgradeApplyState::Undefined) { - AuthenticationUpgradeApplyState::Init { mut future, upgrade } => { + AuthenticationUpgradeApplyState::Init { + mut future, + upgrade, + } => { let (info, io, role) = match Future::poll(Pin::new(&mut future), cx)? { Poll::Ready(x) => x, Poll::Pending => { self.inner = AuthenticationUpgradeApplyState::Init { future, upgrade }; - return Poll::Pending + return Poll::Pending; } }; let fut = match role { - Role::Initiator => Either::Left(Box::pin(upgrade.upgrade_outbound(io, info.0))), - Role::Responder => Either::Right(Box::pin(upgrade.upgrade_inbound(io, info.0))), - }; - self.inner = AuthenticationUpgradeApplyState::Upgrade { - future: fut, - role, + Role::Initiator => { + Either::Left(Box::pin(upgrade.upgrade_outbound(io, info.0))) + } + Role::Responder => { + Either::Right(Box::pin(upgrade.upgrade_inbound(io, info.0))) + } }; + self.inner = AuthenticationUpgradeApplyState::Upgrade { future: fut, role }; } AuthenticationUpgradeApplyState::Upgrade { mut future, role } => { match Future::poll(Pin::new(&mut future), cx) { Poll::Pending => { self.inner = AuthenticationUpgradeApplyState::Upgrade { future, role }; - return Poll::Pending + return Poll::Pending; } Poll::Ready(Ok((peer_id, d))) => { debug!("Successfully applied negotiated protocol"); - return Poll::Ready(Ok(((peer_id, role), d))) + return Poll::Ready(Ok(((peer_id, role), d))); } Poll::Ready(Err(e)) => { debug!("Failed to apply negotiated protocol"); - return Poll::Ready(Err(UpgradeError::Apply(e))) + return Poll::Ready(Err(UpgradeError::Apply(e))); } } } - AuthenticationUpgradeApplyState::Undefined => + AuthenticationUpgradeApplyState::Undefined => { panic!("AuthenticationUpgradeApplyState::poll called after completion") + } } } } } -pub type NameWrapIter = iter::Map::Item) -> NameWrap<::Item>>; +pub type NameWrapIter = + iter::Map::Item) -> NameWrap<::Item>>; /// Wrapper type to expose an `AsRef<[u8]>` impl for all types implementing `ProtocolName`. 
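// An illustrative usage sketch for `apply_authentication` above (not part of
// this patch): the variable names and the choice of a Noise handshake as the
// authentication upgrade are assumptions. In contrast to `apply`, the returned
// future also reports the role this peer ended up with after a potential
// simultaneous open, so the caller can drive all subsequent upgrades in the
// right direction:
//
//     let upgrade = noise::NoiseConfig::xx(noise_keys).into_authenticated();
//     let ((peer_id, role), io) = apply_authentication(
//         connection,
//         upgrade,
//         connected_point,
//         AuthenticationVersion::V1SimultaneousOpen,
//     )
//     .await?;
//     match role {
//         Role::Initiator => { /* negotiate later upgrades as the dialer */ }
//         Role::Responder => { /* negotiate later upgrades as the listener */ }
//     }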
#[derive(Clone)] diff --git a/core/src/upgrade/either.rs b/core/src/upgrade/either.rs index 28db987ccd7..8b5c7f71422 100644 --- a/core/src/upgrade/either.rs +++ b/core/src/upgrade/either.rs @@ -19,29 +19,32 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - either::{EitherOutput, EitherError, EitherFuture2, EitherName}, - upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo} + either::{EitherError, EitherFuture2, EitherName, EitherOutput}, + upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}, }; /// A type to represent two possible upgrade types (inbound or outbound). #[derive(Debug, Clone)] -pub enum EitherUpgrade { A(A), B(B) } +pub enum EitherUpgrade { + A(A), + B(B), +} impl UpgradeInfo for EitherUpgrade where A: UpgradeInfo, - B: UpgradeInfo + B: UpgradeInfo, { type Info = EitherName; type InfoIter = EitherIter< ::IntoIter, - ::IntoIter + ::IntoIter, >; fn protocol_info(&self) -> Self::InfoIter { match self { EitherUpgrade::A(a) => EitherIter::A(a.protocol_info().into_iter()), - EitherUpgrade::B(b) => EitherIter::B(b.protocol_info().into_iter()) + EitherUpgrade::B(b) => EitherIter::B(b.protocol_info().into_iter()), } } } @@ -57,9 +60,13 @@ where fn upgrade_inbound(self, sock: C, info: Self::Info) -> Self::Future { match (self, info) { - (EitherUpgrade::A(a), EitherName::A(info)) => EitherFuture2::A(a.upgrade_inbound(sock, info)), - (EitherUpgrade::B(b), EitherName::B(info)) => EitherFuture2::B(b.upgrade_inbound(sock, info)), - _ => panic!("Invalid invocation of EitherUpgrade::upgrade_inbound") + (EitherUpgrade::A(a), EitherName::A(info)) => { + EitherFuture2::A(a.upgrade_inbound(sock, info)) + } + (EitherUpgrade::B(b), EitherName::B(info)) => { + EitherFuture2::B(b.upgrade_inbound(sock, info)) + } + _ => panic!("Invalid invocation of EitherUpgrade::upgrade_inbound"), } } } @@ -75,36 +82,42 @@ where fn upgrade_outbound(self, sock: C, info: Self::Info) -> Self::Future { match (self, info) { - (EitherUpgrade::A(a), EitherName::A(info)) => EitherFuture2::A(a.upgrade_outbound(sock, info)), - (EitherUpgrade::B(b), EitherName::B(info)) => EitherFuture2::B(b.upgrade_outbound(sock, info)), - _ => panic!("Invalid invocation of EitherUpgrade::upgrade_outbound") + (EitherUpgrade::A(a), EitherName::A(info)) => { + EitherFuture2::A(a.upgrade_outbound(sock, info)) + } + (EitherUpgrade::B(b), EitherName::B(info)) => { + EitherFuture2::B(b.upgrade_outbound(sock, info)) + } + _ => panic!("Invalid invocation of EitherUpgrade::upgrade_outbound"), } } } /// A type to represent two possible `Iterator` types. 
#[derive(Debug, Clone)] -pub enum EitherIter { A(A), B(B) } +pub enum EitherIter { + A(A), + B(B), +} impl Iterator for EitherIter where A: Iterator, - B: Iterator + B: Iterator, { type Item = EitherName; fn next(&mut self) -> Option { match self { EitherIter::A(a) => a.next().map(EitherName::A), - EitherIter::B(b) => b.next().map(EitherName::B) + EitherIter::B(b) => b.next().map(EitherName::B), } } fn size_hint(&self) -> (usize, Option) { match self { EitherIter::A(a) => a.size_hint(), - EitherIter::B(b) => b.size_hint() + EitherIter::B(b) => b.size_hint(), } } } - diff --git a/core/src/upgrade/error.rs b/core/src/upgrade/error.rs index de0ecadbd51..2bbe95ecf2a 100644 --- a/core/src/upgrade/error.rs +++ b/core/src/upgrade/error.rs @@ -33,7 +33,7 @@ pub enum UpgradeError { impl UpgradeError { pub fn map_err(self, f: F) -> UpgradeError where - F: FnOnce(E) -> T + F: FnOnce(E) -> T, { match self { UpgradeError::Select(e) => UpgradeError::Select(e), @@ -43,7 +43,7 @@ impl UpgradeError { pub fn into_err(self) -> UpgradeError where - T: From + T: From, { self.map_err(Into::into) } @@ -51,7 +51,7 @@ impl UpgradeError { impl fmt::Display for UpgradeError where - E: fmt::Display + E: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -63,7 +63,7 @@ where impl std::error::Error for UpgradeError where - E: std::error::Error + 'static + E: std::error::Error + 'static, { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { @@ -78,4 +78,3 @@ impl From for UpgradeError { UpgradeError::Select(e) } } - diff --git a/core/src/upgrade/from_fn.rs b/core/src/upgrade/from_fn.rs index 0c8947e5b30..97bbc2eb292 100644 --- a/core/src/upgrade/from_fn.rs +++ b/core/src/upgrade/from_fn.rs @@ -18,7 +18,10 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{Endpoint, upgrade::{InboundUpgrade, OutboundUpgrade, ProtocolName, UpgradeInfo}}; +use crate::{ + upgrade::{InboundUpgrade, OutboundUpgrade, ProtocolName, UpgradeInfo}, + Endpoint, +}; use futures::prelude::*; use std::iter; diff --git a/core/src/upgrade/map.rs b/core/src/upgrade/map.rs index 2f5ca31e207..c5fe34f44b5 100644 --- a/core/src/upgrade/map.rs +++ b/core/src/upgrade/map.rs @@ -24,7 +24,10 @@ use std::{pin::Pin, task::Context, task::Poll}; /// Wraps around an upgrade and applies a closure to the output. #[derive(Debug, Clone)] -pub struct MapInboundUpgrade { upgrade: U, fun: F } +pub struct MapInboundUpgrade { + upgrade: U, + fun: F, +} impl MapInboundUpgrade { pub fn new(upgrade: U, fun: F) -> Self { @@ -34,7 +37,7 @@ impl MapInboundUpgrade { impl UpgradeInfo for MapInboundUpgrade where - U: UpgradeInfo + U: UpgradeInfo, { type Info = U::Info; type InfoIter = U::InfoIter; @@ -47,7 +50,7 @@ where impl InboundUpgrade for MapInboundUpgrade where U: InboundUpgrade, - F: FnOnce(U::Output) -> T + F: FnOnce(U::Output) -> T, { type Output = T; type Error = U::Error; @@ -56,7 +59,7 @@ where fn upgrade_inbound(self, sock: C, info: Self::Info) -> Self::Future { MapFuture { inner: self.upgrade.upgrade_inbound(sock, info), - map: Some(self.fun) + map: Some(self.fun), } } } @@ -76,7 +79,10 @@ where /// Wraps around an upgrade and applies a closure to the output. 
#[derive(Debug, Clone)] -pub struct MapOutboundUpgrade { upgrade: U, fun: F } +pub struct MapOutboundUpgrade { + upgrade: U, + fun: F, +} impl MapOutboundUpgrade { pub fn new(upgrade: U, fun: F) -> Self { @@ -86,7 +92,7 @@ impl MapOutboundUpgrade { impl UpgradeInfo for MapOutboundUpgrade where - U: UpgradeInfo + U: UpgradeInfo, { type Info = U::Info; type InfoIter = U::InfoIter; @@ -112,7 +118,7 @@ where impl OutboundUpgrade for MapOutboundUpgrade where U: OutboundUpgrade, - F: FnOnce(U::Output) -> T + F: FnOnce(U::Output) -> T, { type Output = T; type Error = U::Error; @@ -121,14 +127,17 @@ where fn upgrade_outbound(self, sock: C, info: Self::Info) -> Self::Future { MapFuture { inner: self.upgrade.upgrade_outbound(sock, info), - map: Some(self.fun) + map: Some(self.fun), } } } /// Wraps around an upgrade and applies a closure to the error. #[derive(Debug, Clone)] -pub struct MapInboundUpgradeErr { upgrade: U, fun: F } +pub struct MapInboundUpgradeErr { + upgrade: U, + fun: F, +} impl MapInboundUpgradeErr { pub fn new(upgrade: U, fun: F) -> Self { @@ -138,7 +147,7 @@ impl MapInboundUpgradeErr { impl UpgradeInfo for MapInboundUpgradeErr where - U: UpgradeInfo + U: UpgradeInfo, { type Info = U::Info; type InfoIter = U::InfoIter; @@ -151,7 +160,7 @@ where impl InboundUpgrade for MapInboundUpgradeErr where U: InboundUpgrade, - F: FnOnce(U::Error) -> T + F: FnOnce(U::Error) -> T, { type Output = U::Output; type Error = T; @@ -160,7 +169,7 @@ where fn upgrade_inbound(self, sock: C, info: Self::Info) -> Self::Future { MapErrFuture { fut: self.upgrade.upgrade_inbound(sock, info), - fun: Some(self.fun) + fun: Some(self.fun), } } } @@ -180,7 +189,10 @@ where /// Wraps around an upgrade and applies a closure to the error. #[derive(Debug, Clone)] -pub struct MapOutboundUpgradeErr { upgrade: U, fun: F } +pub struct MapOutboundUpgradeErr { + upgrade: U, + fun: F, +} impl MapOutboundUpgradeErr { pub fn new(upgrade: U, fun: F) -> Self { @@ -190,7 +202,7 @@ impl MapOutboundUpgradeErr { impl UpgradeInfo for MapOutboundUpgradeErr where - U: UpgradeInfo + U: UpgradeInfo, { type Info = U::Info; type InfoIter = U::InfoIter; @@ -203,7 +215,7 @@ where impl OutboundUpgrade for MapOutboundUpgradeErr where U: OutboundUpgrade, - F: FnOnce(U::Error) -> T + F: FnOnce(U::Error) -> T, { type Output = U::Output; type Error = T; @@ -212,14 +224,14 @@ where fn upgrade_outbound(self, sock: C, info: Self::Info) -> Self::Future { MapErrFuture { fut: self.upgrade.upgrade_outbound(sock, info), - fun: Some(self.fun) + fun: Some(self.fun), } } } impl InboundUpgrade for MapOutboundUpgradeErr where - U: InboundUpgrade + U: InboundUpgrade, { type Output = U::Output; type Error = U::Error; @@ -283,4 +295,3 @@ where } } } - diff --git a/core/src/upgrade/optional.rs b/core/src/upgrade/optional.rs index 02dc3c48f78..c661a4f0170 100644 --- a/core/src/upgrade/optional.rs +++ b/core/src/upgrade/optional.rs @@ -112,8 +112,4 @@ where } } -impl ExactSizeIterator for Iter -where - T: ExactSizeIterator -{ -} +impl ExactSizeIterator for Iter where T: ExactSizeIterator {} diff --git a/core/src/upgrade/select.rs b/core/src/upgrade/select.rs index 8fa4c5b8a7a..d1a8cabca2f 100644 --- a/core/src/upgrade/select.rs +++ b/core/src/upgrade/select.rs @@ -19,8 +19,8 @@ // DEALINGS IN THE SOFTWARE. 
use crate::{ - either::{EitherOutput, EitherError, EitherFuture2, EitherName}, - upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo} + either::{EitherError, EitherFuture2, EitherName, EitherOutput}, + upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}, }; /// Upgrade that combines two upgrades into one. Supports all the protocols supported by either @@ -42,16 +42,19 @@ impl SelectUpgrade { impl UpgradeInfo for SelectUpgrade where A: UpgradeInfo, - B: UpgradeInfo + B: UpgradeInfo, { type Info = EitherName; type InfoIter = InfoIterChain< ::IntoIter, - ::IntoIter + ::IntoIter, >; fn protocol_info(&self) -> Self::InfoIter { - InfoIterChain(self.0.protocol_info().into_iter(), self.1.protocol_info().into_iter()) + InfoIterChain( + self.0.protocol_info().into_iter(), + self.1.protocol_info().into_iter(), + ) } } @@ -67,7 +70,7 @@ where fn upgrade_inbound(self, sock: C, info: Self::Info) -> Self::Future { match info { EitherName::A(info) => EitherFuture2::A(self.0.upgrade_inbound(sock, info)), - EitherName::B(info) => EitherFuture2::B(self.1.upgrade_inbound(sock, info)) + EitherName::B(info) => EitherFuture2::B(self.1.upgrade_inbound(sock, info)), } } } @@ -84,7 +87,7 @@ where fn upgrade_outbound(self, sock: C, info: Self::Info) -> Self::Future { match info { EitherName::A(info) => EitherFuture2::A(self.0.upgrade_outbound(sock, info)), - EitherName::B(info) => EitherFuture2::B(self.1.upgrade_outbound(sock, info)) + EitherName::B(info) => EitherFuture2::B(self.1.upgrade_outbound(sock, info)), } } } @@ -96,16 +99,16 @@ pub struct InfoIterChain(A, B); impl Iterator for InfoIterChain where A: Iterator, - B: Iterator + B: Iterator, { type Item = EitherName; fn next(&mut self) -> Option { if let Some(info) = self.0.next() { - return Some(EitherName::A(info)) + return Some(EitherName::A(info)); } if let Some(info) = self.1.next() { - return Some(EitherName::B(info)) + return Some(EitherName::B(info)); } None } @@ -117,4 +120,3 @@ where (min1.saturating_add(min2), max) } } - diff --git a/core/src/upgrade/transfer.rs b/core/src/upgrade/transfer.rs index 500ece523c5..fd8127758f1 100644 --- a/core/src/upgrade/transfer.rs +++ b/core/src/upgrade/transfer.rs @@ -29,9 +29,10 @@ use std::{error, fmt, io}; /// /// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is /// > compatible with what [`read_length_prefixed`] expects. -pub async fn write_length_prefixed(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef<[u8]>) - -> Result<(), io::Error> -{ +pub async fn write_length_prefixed( + socket: &mut (impl AsyncWrite + Unpin), + data: impl AsRef<[u8]>, +) -> Result<(), io::Error> { write_varint(socket, data.as_ref().len()).await?; socket.write_all(data.as_ref()).await?; socket.flush().await?; @@ -44,11 +45,15 @@ pub async fn write_length_prefixed(socket: &mut (impl AsyncWrite + Unpin), data: /// > **Note**: Prepends a variable-length prefix indicate the length of the message. This is /// > compatible with what `read_one` expects. /// -#[deprecated(since = "0.29.0", note = "Use `write_length_prefixed` instead. You will need to manually close the stream using `socket.close().await`.")] +#[deprecated( + since = "0.29.0", + note = "Use `write_length_prefixed` instead. You will need to manually close the stream using `socket.close().await`." 
+)] #[allow(dead_code)] -pub async fn write_one(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef<[u8]>) - -> Result<(), io::Error> -{ +pub async fn write_one( + socket: &mut (impl AsyncWrite + Unpin), + data: impl AsRef<[u8]>, +) -> Result<(), io::Error> { write_varint(socket, data.as_ref().len()).await?; socket.write_all(data.as_ref()).await?; socket.close().await?; @@ -61,9 +66,10 @@ pub async fn write_one(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef< /// > compatible with what `read_one` expects. #[deprecated(since = "0.29.0", note = "Use `write_length_prefixed` instead.")] #[allow(dead_code)] -pub async fn write_with_len_prefix(socket: &mut (impl AsyncWrite + Unpin), data: impl AsRef<[u8]>) - -> Result<(), io::Error> -{ +pub async fn write_with_len_prefix( + socket: &mut (impl AsyncWrite + Unpin), + data: impl AsRef<[u8]>, +) -> Result<(), io::Error> { write_varint(socket, data.as_ref().len()).await?; socket.write_all(data.as_ref()).await?; socket.flush().await?; @@ -73,9 +79,10 @@ pub async fn write_with_len_prefix(socket: &mut (impl AsyncWrite + Unpin), data: /// Writes a variable-length integer to the `socket`. /// /// > **Note**: Does **NOT** flush the socket. -pub async fn write_varint(socket: &mut (impl AsyncWrite + Unpin), len: usize) - -> Result<(), io::Error> -{ +pub async fn write_varint( + socket: &mut (impl AsyncWrite + Unpin), + len: usize, +) -> Result<(), io::Error> { let mut len_data = unsigned_varint::encode::usize_buffer(); let encoded_len = unsigned_varint::encode::usize(len, &mut len_data).len(); socket.write_all(&len_data[..encoded_len]).await?; @@ -95,7 +102,7 @@ pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result { // Reaching EOF before finishing to read the length is an error, unless the EOF is // at the very beginning of the substream, in which case we assume that the data is @@ -116,7 +123,7 @@ pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result { return Err(io::Error::new( io::ErrorKind::InvalidData, - "overflow in variable-length integer" + "overflow in variable-length integer", )); } // TODO: why do we have a `__Nonexhaustive` variant in the error? I don't know how to process it @@ -134,11 +141,19 @@ pub async fn read_varint(socket: &mut (impl AsyncRead + Unpin)) -> Result **Note**: Assumes that a variable-length prefix indicates the length of the message. This is /// > compatible with what [`write_length_prefixed`] does. -pub async fn read_length_prefixed(socket: &mut (impl AsyncRead + Unpin), max_size: usize) -> io::Result> -{ +pub async fn read_length_prefixed( + socket: &mut (impl AsyncRead + Unpin), + max_size: usize, +) -> io::Result> { let len = read_varint(socket).await?; if len > max_size { - return Err(io::Error::new(io::ErrorKind::InvalidData, format!("Received data size ({} bytes) exceeds maximum ({} bytes)", len, max_size))) + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "Received data size ({} bytes) exceeds maximum ({} bytes)", + len, max_size + ), + )); } let mut buf = vec![0; len]; @@ -157,9 +172,10 @@ pub async fn read_length_prefixed(socket: &mut (impl AsyncRead + Unpin), max_siz /// > compatible with what `write_one` does. 
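// A hypothetical round trip with the length-prefix helpers above, assuming an
// in-memory `futures::io::Cursor` as the I/O stream and an async context:
//
//     let mut buf = futures::io::Cursor::new(Vec::new());
//     write_length_prefixed(&mut buf, b"hello").await?;
//     buf.set_position(0);
//     let msg = read_length_prefixed(&mut buf, 1024).await?;
//     assert_eq!(msg, b"hello");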
#[deprecated(since = "0.29.0", note = "Use `read_length_prefixed` instead.")] #[allow(dead_code, deprecated)] -pub async fn read_one(socket: &mut (impl AsyncRead + Unpin), max_size: usize) - -> Result, ReadOneError> -{ +pub async fn read_one( + socket: &mut (impl AsyncRead + Unpin), + max_size: usize, +) -> Result, ReadOneError> { let len = read_varint(socket).await?; if len > max_size { return Err(ReadOneError::TooLarge { @@ -175,7 +191,10 @@ pub async fn read_one(socket: &mut (impl AsyncRead + Unpin), max_size: usize) /// Error while reading one message. #[derive(Debug)] -#[deprecated(since = "0.29.0", note = "Use `read_length_prefixed` instead of `read_one` to avoid depending on this type.")] +#[deprecated( + since = "0.29.0", + note = "Use `read_length_prefixed` instead of `read_one` to avoid depending on this type." +)] pub enum ReadOneError { /// Error on the socket. Io(std::io::Error), @@ -239,7 +258,7 @@ mod tests { } // TODO: rewrite these tests -/* + /* #[test] fn read_one_works() { let original_data = (0..rand::random::() % 10_000) diff --git a/core/tests/connection_limits.rs b/core/tests/connection_limits.rs index 178eacbd192..65e61c4b3c4 100644 --- a/core/tests/connection_limits.rs +++ b/core/tests/connection_limits.rs @@ -20,16 +20,16 @@ mod util; -use futures::{ready, future::poll_fn}; +use futures::{future::poll_fn, ready}; use libp2p_core::multiaddr::{multiaddr, Multiaddr}; use libp2p_core::{ - PeerId, connection::PendingConnectionError, - network::{NetworkEvent, NetworkConfig, ConnectionLimits, DialError}, + network::{ConnectionLimits, DialError, NetworkConfig, NetworkEvent}, + PeerId, }; use rand::Rng; use std::task::Poll; -use util::{TestHandler, test_network}; +use util::{test_network, TestHandler}; #[test] fn max_outgoing() { @@ -40,14 +40,16 @@ fn max_outgoing() { let mut network = test_network(cfg); let target = PeerId::random(); - for _ in 0 .. outgoing_limit { - network.peer(target.clone()) + for _ in 0..outgoing_limit { + network + .peer(target.clone()) .dial(Multiaddr::empty(), Vec::new(), TestHandler()) .ok() .expect("Unexpected connection limit."); } - match network.peer(target.clone()) + match network + .peer(target.clone()) .dial(Multiaddr::empty(), Vec::new(), TestHandler()) .expect_err("Unexpected dialing success.") { @@ -60,10 +62,14 @@ fn max_outgoing() { let info = network.info(); assert_eq!(info.num_peers(), 0); - assert_eq!(info.connection_counters().num_pending_outgoing(), outgoing_limit); + assert_eq!( + info.connection_counters().num_pending_outgoing(), + outgoing_limit + ); // Abort all dialing attempts. - let mut peer = network.peer(target.clone()) + let mut peer = network + .peer(target.clone()) .into_dialing() .expect("Unexpected peer state"); @@ -72,7 +78,10 @@ fn max_outgoing() { attempt.abort(); } - assert_eq!(network.info().connection_counters().num_pending_outgoing(), 0); + assert_eq!( + network.info().connection_counters().num_pending_outgoing(), + 0 + ); } #[test] @@ -87,35 +96,34 @@ fn max_established_incoming() { let mut network1 = test_network(config(limit)); let mut network2 = test_network(config(limit)); - let listen_addr = multiaddr![Ip4(std::net::Ipv4Addr::new(127,0,0,1)), Tcp(0u16)]; + let listen_addr = multiaddr![Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1)), Tcp(0u16)]; let _ = network1.listen_on(listen_addr.clone()).unwrap(); let (addr_sender, addr_receiver) = futures::channel::oneshot::channel(); let mut addr_sender = Some(addr_sender); // Spawn the listener. 
- let listener = async_std::task::spawn(poll_fn(move |cx| { - loop { - match ready!(network1.poll(cx)) { - NetworkEvent::NewListenerAddress { listen_addr, .. } => { - addr_sender.take().unwrap().send(listen_addr).unwrap(); - } - NetworkEvent::IncomingConnection { connection, .. } => { - network1.accept(connection, TestHandler()).unwrap(); - } - NetworkEvent::ConnectionEstablished { .. } => {} - NetworkEvent::IncomingConnectionError { - error: PendingConnectionError::ConnectionLimit(err), .. - } => { - assert_eq!(err.limit, limit); - assert_eq!(err.limit, err.current); - let info = network1.info(); - let counters = info.connection_counters(); - assert_eq!(counters.num_established_incoming(), limit); - assert_eq!(counters.num_established(), limit); - return Poll::Ready(()) - } - e => panic!("Unexpected network event: {:?}", e) + let listener = async_std::task::spawn(poll_fn(move |cx| loop { + match ready!(network1.poll(cx)) { + NetworkEvent::NewListenerAddress { listen_addr, .. } => { + addr_sender.take().unwrap().send(listen_addr).unwrap(); + } + NetworkEvent::IncomingConnection { connection, .. } => { + network1.accept(connection, TestHandler()).unwrap(); } + NetworkEvent::ConnectionEstablished { .. } => {} + NetworkEvent::IncomingConnectionError { + error: PendingConnectionError::ConnectionLimit(err), + .. + } => { + assert_eq!(err.limit, limit); + assert_eq!(err.limit, err.current); + let info = network1.info(); + let counters = info.connection_counters(); + assert_eq!(counters.num_established_incoming(), limit); + assert_eq!(counters.num_established(), limit); + return Poll::Ready(()); + } + e => panic!("Unexpected network event: {:?}", e), } })); @@ -152,15 +160,15 @@ fn max_established_incoming() { let counters = info.connection_counters(); assert_eq!(counters.num_established_outgoing(), limit); assert_eq!(counters.num_established(), limit); - return Poll::Ready(()) + return Poll::Ready(()); } - e => panic!("Unexpected network event: {:?}", e) + e => panic!("Unexpected network event: {:?}", e), } } - }).await + }) + .await }); // Wait for the listener to complete. async_std::task::block_on(listener); } - diff --git a/core/tests/network_dial_error.rs b/core/tests/network_dial_error.rs index 2edb133ffc9..224d7950eac 100644 --- a/core/tests/network_dial_error.rs +++ b/core/tests/network_dial_error.rs @@ -23,14 +23,14 @@ mod util; use futures::prelude::*; use libp2p_core::multiaddr::multiaddr; use libp2p_core::{ - PeerId, connection::PendingConnectionError, multiaddr::Protocol, - network::{NetworkEvent, NetworkConfig}, + network::{NetworkConfig, NetworkEvent}, + PeerId, }; use rand::seq::SliceRandom; use std::{io, task::Poll}; -use util::{TestHandler, test_network}; +use util::{test_network, TestHandler}; #[test] fn deny_incoming_connec() { @@ -39,16 +39,16 @@ fn deny_incoming_connec() { let mut swarm1 = test_network(NetworkConfig::default()); let mut swarm2 = test_network(NetworkConfig::default()); - swarm1.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()).unwrap(); + swarm1 + .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .unwrap(); - let address = async_std::task::block_on(future::poll_fn(|cx| { - match swarm1.poll(cx) { - Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) => { - Poll::Ready(listen_addr) - } - Poll::Pending => Poll::Pending, - _ => panic!("Was expecting the listen address to be reported"), + let address = async_std::task::block_on(future::poll_fn(|cx| match swarm1.poll(cx) { + Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. 
}) => { + Poll::Ready(listen_addr) } + Poll::Pending => Poll::Pending, + _ => panic!("Was expecting the listen address to be reported"), })); swarm2 @@ -68,23 +68,26 @@ fn deny_incoming_connec() { attempts_remaining: 0, peer_id, multiaddr, - error: PendingConnectionError::Transport(_) + error: PendingConnectionError::Transport(_), }) => { assert_eq!(&peer_id, swarm1.local_peer_id()); - assert_eq!(multiaddr, address.clone().with(Protocol::P2p(peer_id.into()))); + assert_eq!( + multiaddr, + address.clone().with(Protocol::P2p(peer_id.into())) + ); return Poll::Ready(Ok(())); - }, + } Poll::Ready(_) => unreachable!(), Poll::Pending => (), } Poll::Pending - })).unwrap(); + })) + .unwrap(); } #[test] fn dial_self() { - // Check whether dialing ourselves correctly fails. // // Dialing the same address we're listening should result in three events: @@ -96,16 +99,16 @@ fn dial_self() { // The last two can happen in any order. let mut swarm = test_network(NetworkConfig::default()); - swarm.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()).unwrap(); + swarm + .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .unwrap(); - let local_address = async_std::task::block_on(future::poll_fn(|cx| { - match swarm.poll(cx) { - Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) => { - Poll::Ready(listen_addr) - } - Poll::Pending => Poll::Pending, - _ => panic!("Was expecting the listen address to be reported"), + let local_address = async_std::task::block_on(future::poll_fn(|cx| match swarm.poll(cx) { + Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) => { + Poll::Ready(listen_addr) } + Poll::Pending => Poll::Pending, + _ => panic!("Was expecting the listen address to be reported"), })); swarm.dial(&local_address, TestHandler()).unwrap(); @@ -124,30 +127,29 @@ fn dial_self() { assert_eq!(multiaddr, local_address); got_dial_err = true; if got_inc_err { - return Poll::Ready(Ok(())) + return Poll::Ready(Ok(())); } - }, - Poll::Ready(NetworkEvent::IncomingConnectionError { - local_addr, .. - }) => { + } + Poll::Ready(NetworkEvent::IncomingConnectionError { local_addr, .. }) => { assert!(!got_inc_err); assert_eq!(local_addr, local_address); got_inc_err = true; if got_dial_err { - return Poll::Ready(Ok(())) + return Poll::Ready(Ok(())); } - }, + } Poll::Ready(NetworkEvent::IncomingConnection { connection, .. }) => { assert_eq!(&connection.local_addr, &local_address); swarm.accept(connection, TestHandler()).unwrap(); - }, + } Poll::Ready(ev) => { panic!("Unexpected event: {:?}", ev) } Poll::Pending => break Poll::Pending, } } - })).unwrap(); + })) + .unwrap(); } #[test] @@ -168,23 +170,19 @@ fn multiple_addresses_err() { let mut swarm = test_network(NetworkConfig::default()); let mut addresses = Vec::new(); - for _ in 0 .. 3 { - addresses.push(multiaddr![ - Ip4([0, 0, 0, 0]), - Tcp(rand::random::()) - ]); + for _ in 0..3 { + addresses.push(multiaddr![Ip4([0, 0, 0, 0]), Tcp(rand::random::())]); } - for _ in 0 .. 
5 { - addresses.push(multiaddr![ - Udp(rand::random::()) - ]); + for _ in 0..5 { + addresses.push(multiaddr![Udp(rand::random::())]); } addresses.shuffle(&mut rand::thread_rng()); let first = addresses[0].clone(); let rest = (&addresses[1..]).iter().cloned(); - swarm.peer(target.clone()) + swarm + .peer(target.clone()) .dial(first, rest, TestHandler()) .unwrap(); @@ -195,10 +193,12 @@ fn multiple_addresses_err() { attempts_remaining, peer_id, multiaddr, - error: PendingConnectionError::Transport(_) + error: PendingConnectionError::Transport(_), }) => { assert_eq!(peer_id, target); - let expected = addresses.remove(0).with(Protocol::P2p(target.clone().into())); + let expected = addresses + .remove(0) + .with(Protocol::P2p(target.clone().into())); assert_eq!(multiaddr, expected); if addresses.is_empty() { assert_eq!(attempts_remaining, 0); @@ -206,10 +206,11 @@ fn multiple_addresses_err() { } else { assert_eq!(attempts_remaining, addresses.len() as u32); } - }, + } Poll::Ready(_) => unreachable!(), Poll::Pending => break Poll::Pending, } } - })).unwrap(); + })) + .unwrap(); } diff --git a/core/tests/transport_upgrade.rs b/core/tests/transport_upgrade.rs index b6286832303..42712463e89 100644 --- a/core/tests/transport_upgrade.rs +++ b/core/tests/transport_upgrade.rs @@ -22,8 +22,8 @@ mod util; use futures::prelude::*; use libp2p_core::identity; -use libp2p_core::transport::{Transport, MemoryTransport}; -use libp2p_core::upgrade::{UpgradeInfo, InboundUpgrade, OutboundUpgrade}; +use libp2p_core::transport::{MemoryTransport, Transport}; +use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; use libp2p_mplex::MplexConfig; use libp2p_noise as noise; use multiaddr::{Multiaddr, Protocol}; @@ -44,7 +44,7 @@ impl UpgradeInfo for HelloUpgrade { impl InboundUpgrade for HelloUpgrade where - C: AsyncRead + AsyncWrite + Send + Unpin + 'static + C: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Output = C; type Error = io::Error; @@ -81,7 +81,9 @@ where fn upgrade_pipeline() { let listener_keys = identity::Keypair::generate_ed25519(); let listener_id = listener_keys.public().to_peer_id(); - let listener_noise_keys = noise::Keypair::::new().into_authentic(&listener_keys).unwrap(); + let listener_noise_keys = noise::Keypair::::new() + .into_authentic(&listener_keys) + .unwrap(); let listener_transport = MemoryTransport::default() .upgrade() .authenticate(noise::NoiseConfig::xx(listener_noise_keys).into_authenticated()) @@ -97,7 +99,9 @@ fn upgrade_pipeline() { let dialer_keys = identity::Keypair::generate_ed25519(); let dialer_id = dialer_keys.public().to_peer_id(); - let dialer_noise_keys = noise::Keypair::::new().into_authentic(&dialer_keys).unwrap(); + let dialer_noise_keys = noise::Keypair::::new() + .into_authentic(&dialer_keys) + .unwrap(); let dialer_transport = MemoryTransport::default() .upgrade() .authenticate(noise::NoiseConfig::xx(dialer_noise_keys).into_authenticated()) @@ -121,7 +125,7 @@ fn upgrade_pipeline() { let (upgrade, _remote_addr) = match listener.next().await.unwrap().unwrap().into_upgrade() { Some(u) => u, - None => continue + None => continue, }; let (peer, _mplex) = upgrade.await.unwrap(); assert_eq!(peer, dialer_id); diff --git a/core/tests/util.rs b/core/tests/util.rs index 42c18c6a060..62bfccc20de 100644 --- a/core/tests/util.rs +++ b/core/tests/util.rs @@ -1,21 +1,12 @@ - #![allow(dead_code)] use futures::prelude::*; use libp2p_core::{ - Multiaddr, - PeerId, - Transport, - connection::{ - ConnectionHandler, - ConnectionHandlerEvent, - Substream, - 
SubstreamEndpoint, - }, + connection::{ConnectionHandler, ConnectionHandlerEvent, Substream, SubstreamEndpoint}, identity, muxing::{StreamMuxer, StreamMuxerBox}, network::{Network, NetworkConfig}, - transport, + transport, Multiaddr, PeerId, Transport, }; use libp2p_mplex as mplex; use libp2p_noise as noise; @@ -29,7 +20,9 @@ type TestTransport = transport::Boxed<(PeerId, StreamMuxerBox)>; pub fn test_network(cfg: NetworkConfig) -> TestNetwork { let local_key = identity::Keypair::generate_ed25519(); let local_public_key = local_key.public(); - let noise_keys = noise::Keypair::::new().into_authentic(&local_key).unwrap(); + let noise_keys = noise::Keypair::::new() + .into_authentic(&local_key) + .unwrap(); let transport: TestTransport = tcp::TcpConfig::new() .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) @@ -48,17 +41,21 @@ impl ConnectionHandler for TestHandler { type Substream = Substream; type OutboundOpenInfo = (); - fn inject_substream(&mut self, _: Self::Substream, _: SubstreamEndpoint) - {} + fn inject_substream( + &mut self, + _: Self::Substream, + _: SubstreamEndpoint, + ) { + } - fn inject_event(&mut self, _: Self::InEvent) - {} + fn inject_event(&mut self, _: Self::InEvent) {} - fn inject_address_change(&mut self, _: &Multiaddr) - {} + fn inject_address_change(&mut self, _: &Multiaddr) {} - fn poll(&mut self, _: &mut Context<'_>) - -> Poll, Self::Error>> + fn poll( + &mut self, + _: &mut Context<'_>, + ) -> Poll, Self::Error>> { Poll::Pending } @@ -71,7 +68,7 @@ pub struct CloseMuxer { impl CloseMuxer { pub fn new(m: M) -> CloseMuxer { CloseMuxer { - state: CloseMuxerState::Close(m) + state: CloseMuxerState::Close(m), } } } @@ -84,7 +81,7 @@ pub enum CloseMuxerState { impl Future for CloseMuxer where M: StreamMuxer, - M::Error: From + M::Error: From, { type Output = Result; @@ -94,15 +91,14 @@ where CloseMuxerState::Close(muxer) => { if !muxer.close(cx)?.is_ready() { self.state = CloseMuxerState::Close(muxer); - return Poll::Pending + return Poll::Pending; } - return Poll::Ready(Ok(muxer)) + return Poll::Ready(Ok(muxer)); } - CloseMuxerState::Done => panic!() + CloseMuxerState::Done => panic!(), } } } } -impl Unpin for CloseMuxer { -} +impl Unpin for CloseMuxer {} diff --git a/examples/chat-tokio.rs b/examples/chat-tokio.rs index df5492cf461..9cb8070b6d7 100644 --- a/examples/chat-tokio.rs +++ b/examples/chat-tokio.rs @@ -38,18 +38,18 @@ use futures::StreamExt; use libp2p::{ - Multiaddr, - NetworkBehaviour, - PeerId, - Transport, - identity, floodsub::{self, Floodsub, FloodsubEvent}, + identity, mdns::{Mdns, MdnsEvent}, mplex, noise, swarm::{NetworkBehaviourEventProcess, SwarmBuilder, SwarmEvent}, // `TokioTcpConfig` is available through the `tcp-tokio` feature. tcp::TokioTcpConfig, + Multiaddr, + NetworkBehaviour, + PeerId, + Transport, }; use std::error::Error; use tokio::io::{self, AsyncBufReadExt}; @@ -71,7 +71,8 @@ async fn main() -> Result<(), Box> { // Create a tokio-based TCP transport use noise for authenticated // encryption and Mplex for multiplexing of substreams on a TCP stream. - let transport = TokioTcpConfig::new().nodelay(true) + let transport = TokioTcpConfig::new() + .nodelay(true) .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) .multiplex(mplex::MplexConfig::new()) @@ -94,7 +95,11 @@ async fn main() -> Result<(), Box> { // Called when `floodsub` produces an event. 
fn inject_event(&mut self, message: FloodsubEvent) { if let FloodsubEvent::Message(message) = message { - println!("Received: '{:?}' from {:?}", String::from_utf8_lossy(&message.data), message.source); + println!( + "Received: '{:?}' from {:?}", + String::from_utf8_lossy(&message.data), + message.source + ); } } } @@ -103,16 +108,18 @@ async fn main() -> Result<(), Box> { // Called when `mdns` produces an event. fn inject_event(&mut self, event: MdnsEvent) { match event { - MdnsEvent::Discovered(list) => + MdnsEvent::Discovered(list) => { for (peer, _) in list { self.floodsub.add_node_to_partial_view(peer); } - MdnsEvent::Expired(list) => + } + MdnsEvent::Expired(list) => { for (peer, _) in list { if !self.mdns.has_node(&peer) { self.floodsub.remove_node_from_partial_view(&peer); } } + } } } } @@ -130,7 +137,9 @@ async fn main() -> Result<(), Box> { SwarmBuilder::new(transport, behaviour, peer_id) // We want the connection background tasks to be spawned // onto the tokio runtime. - .executor(Box::new(|fut| { tokio::spawn(fut); })) + .executor(Box::new(|fut| { + tokio::spawn(fut); + })) .build() }; diff --git a/examples/chat.rs b/examples/chat.rs index fddec5e5e9b..18ef72b96b8 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -31,7 +31,7 @@ //! # If they don't automatically connect //! //! If the nodes don't automatically connect, take note of the listening addresses of the first -//! instance and start the second with one of the addresses as the first argument. In the first +//! instance and start the second with one of the addresses as the first argument. In the first //! terminal window, run: //! //! ```sh @@ -52,16 +52,16 @@ use async_std::{io, task}; use futures::{future, prelude::*}; use libp2p::{ - Multiaddr, - PeerId, - Swarm, - NetworkBehaviour, - identity, floodsub::{self, Floodsub, FloodsubEvent}, + identity, mdns::{Mdns, MdnsConfig, MdnsEvent}, - swarm::{NetworkBehaviourEventProcess, SwarmEvent} + swarm::{NetworkBehaviourEventProcess, SwarmEvent}, + Multiaddr, NetworkBehaviour, PeerId, Swarm, +}; +use std::{ + error::Error, + task::{Context, Poll}, }; -use std::{error::Error, task::{Context, Poll}}; #[async_std::main] async fn main() -> Result<(), Box> { @@ -97,7 +97,11 @@ async fn main() -> Result<(), Box> { // Called when `floodsub` produces an event. fn inject_event(&mut self, message: FloodsubEvent) { if let FloodsubEvent::Message(message) = message { - println!("Received: '{:?}' from {:?}", String::from_utf8_lossy(&message.data), message.source); + println!( + "Received: '{:?}' from {:?}", + String::from_utf8_lossy(&message.data), + message.source + ); } } } @@ -106,16 +110,18 @@ async fn main() -> Result<(), Box> { // Called when `mdns` produces an event. fn inject_event(&mut self, event: MdnsEvent) { match event { - MdnsEvent::Discovered(list) => + MdnsEvent::Discovered(list) => { for (peer, _) in list { self.floodsub.add_node_to_partial_view(peer); } - MdnsEvent::Expired(list) => + } + MdnsEvent::Expired(list) => { for (peer, _) in list { if !self.mdns.has_node(&peer) { self.floodsub.remove_node_from_partial_view(&peer); } } + } } } } @@ -150,11 +156,12 @@ async fn main() -> Result<(), Box> { task::block_on(future::poll_fn(move |cx: &mut Context<'_>| { loop { match stdin.try_poll_next_unpin(cx)? 
{ - Poll::Ready(Some(line)) => swarm.behaviour_mut() + Poll::Ready(Some(line)) => swarm + .behaviour_mut() .floodsub .publish(floodsub_topic.clone(), line.as_bytes()), Poll::Ready(None) => panic!("Stdin closed"), - Poll::Pending => break + Poll::Pending => break, } } loop { diff --git a/examples/distributed-key-value-store.rs b/examples/distributed-key-value-store.rs index 9ab5b7206d7..2e5fa5a8531 100644 --- a/examples/distributed-key-value-store.rs +++ b/examples/distributed-key-value-store.rs @@ -44,26 +44,19 @@ use async_std::{io, task}; use futures::prelude::*; use libp2p::kad::record::store::MemoryStore; use libp2p::kad::{ - AddProviderOk, - Kademlia, - KademliaEvent, - PeerRecord, - PutRecordOk, - QueryResult, - Quorum, - Record, - record::Key, + record::Key, AddProviderOk, Kademlia, KademliaEvent, PeerRecord, PutRecordOk, QueryResult, + Quorum, Record, }; use libp2p::{ - NetworkBehaviour, - PeerId, - Swarm, - development_transport, - identity, + development_transport, identity, mdns::{Mdns, MdnsConfig, MdnsEvent}, - swarm::{NetworkBehaviourEventProcess, SwarmEvent} + swarm::{NetworkBehaviourEventProcess, SwarmEvent}, + NetworkBehaviour, PeerId, Swarm, +}; +use std::{ + error::Error, + task::{Context, Poll}, }; -use std::{error::Error, task::{Context, Poll}}; #[async_std::main] async fn main() -> Result<(), Box> { @@ -80,7 +73,7 @@ async fn main() -> Result<(), Box> { #[derive(NetworkBehaviour)] struct MyBehaviour { kademlia: Kademlia, - mdns: Mdns + mdns: Mdns, } impl NetworkBehaviourEventProcess for MyBehaviour { @@ -112,7 +105,11 @@ async fn main() -> Result<(), Box> { eprintln!("Failed to get providers: {:?}", err); } QueryResult::GetRecord(Ok(ok)) => { - for PeerRecord { record: Record { key, value, .. }, ..} in ok.records { + for PeerRecord { + record: Record { key, value, .. }, + .. + } in ok.records + { println!( "Got record {:?} {:?}", std::str::from_utf8(key.as_ref()).unwrap(), @@ -133,7 +130,8 @@ async fn main() -> Result<(), Box> { eprintln!("Failed to put record: {:?}", err); } QueryResult::StartProviding(Ok(AddProviderOk { key })) => { - println!("Successfully put provider record {:?}", + println!( + "Successfully put provider record {:?}", std::str::from_utf8(key.as_ref()).unwrap() ); } @@ -141,7 +139,7 @@ async fn main() -> Result<(), Box> { eprintln!("Failed to put provider record: {:?}", err); } _ => {} - } + }, _ => {} } } @@ -167,9 +165,11 @@ async fn main() -> Result<(), Box> { task::block_on(future::poll_fn(move |cx: &mut Context<'_>| { loop { match stdin.try_poll_next_unpin(cx)? 
{ - Poll::Ready(Some(line)) => handle_input_line(&mut swarm.behaviour_mut().kademlia, line), + Poll::Ready(Some(line)) => { + handle_input_line(&mut swarm.behaviour_mut().kademlia, line) + } Poll::Ready(None) => panic!("Stdin closed"), - Poll::Pending => break + Poll::Pending => break, } } loop { @@ -209,7 +209,7 @@ fn handle_input_line(kademlia: &mut Kademlia, line: String) { Some(key) => Key::new(&key), None => { eprintln!("Expected key"); - return + return; } } }; @@ -240,8 +240,10 @@ fn handle_input_line(kademlia: &mut Kademlia, line: String) { publisher: None, expires: None, }; - kademlia.put_record(record, Quorum::One).expect("Failed to store record locally."); - }, + kademlia + .put_record(record, Quorum::One) + .expect("Failed to store record locally."); + } Some("PUT_PROVIDER") => { let key = { match args.next() { @@ -253,7 +255,9 @@ fn handle_input_line(kademlia: &mut Kademlia, line: String) { } }; - kademlia.start_providing(key).expect("Failed to start providing key"); + kademlia + .start_providing(key) + .expect("Failed to start providing key"); } _ => { eprintln!("expected GET, GET_PROVIDERS, PUT or PUT_PROVIDER"); diff --git a/examples/gossipsub-chat.rs b/examples/gossipsub-chat.rs index f56fe708075..bbf1190f8c3 100644 --- a/examples/gossipsub-chat.rs +++ b/examples/gossipsub-chat.rs @@ -28,7 +28,7 @@ //! chat members and everyone will receive all messages. //! //! In order to get the nodes to connect, take note of the listening addresses of the first -//! instance and start the second with one of the addresses as the first argument. In the first +//! instance and start the second with one of the addresses as the first argument. In the first //! terminal window, run: //! //! ```sh diff --git a/examples/ipfs-kad.rs b/examples/ipfs-kad.rs index c1e7e5c66d2..b3e6b211b46 100644 --- a/examples/ipfs-kad.rs +++ b/examples/ipfs-kad.rs @@ -25,28 +25,20 @@ use async_std::task; use futures::StreamExt; +use libp2p::kad::record::store::MemoryStore; +use libp2p::kad::{GetClosestPeersError, Kademlia, KademliaConfig, KademliaEvent, QueryResult}; use libp2p::{ - Multiaddr, + development_transport, identity, swarm::{Swarm, SwarmEvent}, - PeerId, - identity, - development_transport -}; -use libp2p::kad::{ - Kademlia, - KademliaConfig, - KademliaEvent, - GetClosestPeersError, - QueryResult, + Multiaddr, PeerId, }; -use libp2p::kad::record::store::MemoryStore; use std::{env, error::Error, str::FromStr, time::Duration}; const BOOTNODES: [&'static str; 4] = [ "QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN", "QmQCU2EcMqAqQPR2i9bChDtGNJchTbq5TbXJJ16u19uLTa", "QmbLHAnMoJPWSCR5Zhtx6BHJX9KiKNN6tpvbUcqanj75Nb", - "QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt" + "QmcZf59bWwK5XFi76CZX8cbJ4BhTzzA3gU1ZjYZcYW3dwt", ]; #[async_std::main] @@ -96,9 +88,10 @@ async fn main() -> Result<(), Box> { if let SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { result: QueryResult::GetClosestPeers(result), .. - }) = event { + }) = event + { match result { - Ok(ok) => + Ok(ok) => { if !ok.peers.is_empty() { println!("Query finished with closest peers: {:#?}", ok.peers) } else { @@ -106,7 +99,8 @@ async fn main() -> Result<(), Box> { // should always be at least 1 reachable peer. println!("Query finished with no closest peers.") } - Err(GetClosestPeersError::Timeout { peers, .. }) => + } + Err(GetClosestPeersError::Timeout { peers, .. 
}) => { if !peers.is_empty() { println!("Query timed out with closest peers: {:#?}", peers) } else { @@ -114,6 +108,7 @@ async fn main() -> Result<(), Box> { // should always be at least 1 reachable peer. println!("Query timed out with no closest peers."); } + } }; break; diff --git a/examples/ipfs-private.rs b/examples/ipfs-private.rs index 529e0482b64..2d7422392a2 100644 --- a/examples/ipfs-private.rs +++ b/examples/ipfs-private.rs @@ -34,9 +34,7 @@ use async_std::{io, task}; use futures::{future, prelude::*}; use libp2p::{ - core::{ - either::EitherTransport, muxing::StreamMuxerBox, transport, - }, + core::{either::EitherTransport, muxing::StreamMuxerBox, transport}, gossipsub::{self, Gossipsub, GossipsubConfigBuilder, GossipsubEvent, MessageAuthenticity}, identify::{Identify, IdentifyConfig, IdentifyEvent}, identity, diff --git a/examples/mdns-passive-discovery.rs b/examples/mdns-passive-discovery.rs index bce18dea1ee..a63ec7d5afe 100644 --- a/examples/mdns-passive-discovery.rs +++ b/examples/mdns-passive-discovery.rs @@ -20,10 +20,10 @@ use futures::StreamExt; use libp2p::{ - identity, - mdns::{Mdns, MdnsConfig, MdnsEvent}, + identity, + mdns::{Mdns, MdnsConfig, MdnsEvent}, swarm::{Swarm, SwarmEvent}, - PeerId + PeerId, }; use std::error::Error; diff --git a/examples/ping.rs b/examples/ping.rs index 151e9a5b5dd..f38b4fc4011 100644 --- a/examples/ping.rs +++ b/examples/ping.rs @@ -79,7 +79,7 @@ fn main() -> Result<(), Box> { block_on(future::poll_fn(move |cx| loop { match swarm.poll_next_unpin(cx) { Poll::Ready(Some(event)) => match event { - SwarmEvent::NewListenAddr{ address, .. } => println!("Listening on {:?}", address), + SwarmEvent::NewListenAddr { address, .. } => println!("Listening on {:?}", address), SwarmEvent::Behaviour(event) => println!("{:?}", event), _ => {} }, diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index 72b4126751c..5e8b14de63a 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -20,12 +20,17 @@ //! Protocol negotiation strategies for the peer acting as the dialer. +use crate::protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError, SIM_OPEN_ID}; use crate::{Negotiated, NegotiationError, Version}; -use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine, SIM_OPEN_ID}; use futures::{future::Either, prelude::*}; -use std::{cmp::Ordering, convert::TryFrom as _, iter, mem, pin::Pin, task::{Context, Poll}}; - +use std::{ + cmp::Ordering, + convert::TryFrom as _, + iter, mem, + pin::Pin, + task::{Context, Poll}, +}; /// Returns a `Future` that negotiates a protocol on the given I/O stream /// for a peer acting as the _dialer_ (or _initiator_). 
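// A minimal sketch of driving the negotiation below with the simultaneous-open
// version. Assumptions (not part of this patch): `socket` is any
// `AsyncRead + AsyncWrite + Unpin` stream, "/example/1.0.0" is a made-up
// protocol name, and `Role` is available alongside `dialer_select_proto`;
// imports are as in this module.
async fn negotiate_example<S>(socket: S) -> Result<(), NegotiationError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let protocols = std::iter::once(&b"/example/1.0.0"[..]);
    // With `Version::V1SimultaneousOpen` the future additionally yields the
    // locally selected role, since both peers may have started out as dialers.
    let (protocol, _io, role) =
        dialer_select_proto(socket, protocols, Version::V1SimultaneousOpen).await?;
    match role {
        Role::Initiator => log::debug!(
            "Negotiated {} as initiator",
            String::from_utf8_lossy(protocol)
        ),
        Role::Responder => log::debug!(
            "Negotiated {} as responder",
            String::from_utf8_lossy(protocol)
        ),
    }
    Ok(())
}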
@@ -50,12 +55,12 @@ use std::{cmp::Ordering, convert::TryFrom as _, iter, mem, pin::Pin, task::{Cont pub fn dialer_select_proto( inner: R, protocols: I, - version: Version + version: Version, ) -> DialerSelectFuture where R: AsyncRead + AsyncWrite, I: IntoIterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { let iter = protocols.into_iter(); match version { @@ -66,7 +71,7 @@ where } else { Either::Right(dialer_select_proto_parallel(inner, iter, version)) } - }, + } Version::V1SimultaneousOpen => { Either::Left(dialer_select_proto_serial(inner, iter, version)) } @@ -88,12 +93,12 @@ pub type DialerSelectFuture = Either, DialerSelectPa pub(crate) fn dialer_select_proto_serial( inner: R, protocols: I, - version: Version + version: Version, ) -> DialerSelectSeq where R: AsyncRead + AsyncWrite, I: IntoIterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { let protocols = protocols.into_iter().peekable(); DialerSelectSeq { @@ -101,7 +106,7 @@ where protocols, state: SeqState::SendHeader { io: MessageIO::new(inner), - } + }, } } @@ -117,20 +122,20 @@ where pub(crate) fn dialer_select_proto_parallel( inner: R, protocols: I, - version: Version + version: Version, ) -> DialerSelectPar where R: AsyncRead + AsyncWrite, I: IntoIterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { let protocols = protocols.into_iter(); DialerSelectPar { version, protocols, state: ParState::SendHeader { - io: MessageIO::new(inner) - } + io: MessageIO::new(inner), + }, } } @@ -145,20 +150,45 @@ pub struct DialerSelectSeq { } enum SeqState { - SendHeader { io: MessageIO }, + SendHeader { + io: MessageIO, + }, // Simultaneous open protocol extension - SendSimOpen { io: MessageIO, protocol: Option }, - FlushSimOpen { io: MessageIO, protocol: N }, - AwaitSimOpen { io: MessageIO, protocol: N }, - SimOpenPhase { selection: SimOpenPhase, protocol: N }, - Responder { responder: crate::ListenerSelectFuture }, + SendSimOpen { + io: MessageIO, + protocol: Option, + }, + FlushSimOpen { + io: MessageIO, + protocol: N, + }, + AwaitSimOpen { + io: MessageIO, + protocol: N, + }, + SimOpenPhase { + selection: SimOpenPhase, + protocol: N, + }, + Responder { + responder: crate::ListenerSelectFuture, + }, // Standard multistream-select protocol - SendProtocol { io: MessageIO, protocol: N }, - FlushProtocol { io: MessageIO, protocol: N }, - AwaitProtocol { io: MessageIO, protocol: N }, - Done + SendProtocol { + io: MessageIO, + protocol: N, + }, + FlushProtocol { + io: MessageIO, + protocol: N, + }, + AwaitProtocol { + io: MessageIO, + protocol: N, + }, + Done, } impl Future for DialerSelectSeq @@ -167,7 +197,7 @@ where // It also makes the implementation considerably easier to write. R: AsyncRead + AsyncWrite + Unpin, I: Iterator, - I::Item: AsRef<[u8]> + Clone + I::Item: AsRef<[u8]> + Clone, { type Output = Result<(I::Item, Negotiated, Role), NegotiationError>; @@ -178,11 +208,11 @@ where match mem::replace(this.state, SeqState::Done) { SeqState::SendHeader { mut io } => { match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { *this.state = SeqState::SendHeader { io }; - return Poll::Pending - }, + return Poll::Pending; + } } let h = HeaderLine::from(*this.version); @@ -206,11 +236,11 @@ where SeqState::SendSimOpen { mut io, protocol } => { match Pin::new(&mut io).poll_ready(cx)? 
{ - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { *this.state = SeqState::SendSimOpen { io, protocol }; - return Poll::Pending - }, + return Poll::Pending; + } } match protocol { @@ -221,11 +251,16 @@ where } let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; - *this.state = SeqState::SendSimOpen { io, protocol: Some(protocol) }; + *this.state = SeqState::SendSimOpen { + io, + protocol: Some(protocol), + }; } Some(protocol) => { let p = Protocol::try_from(protocol.as_ref())?; - if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) { + if let Err(err) = + Pin::new(&mut io).start_send(Message::Protocol(p.clone())) + { return Poll::Ready(Err(From::from(err))); } log::debug!("Dialer: Proposed protocol: {}", p); @@ -237,13 +272,11 @@ where SeqState::FlushSimOpen { mut io, protocol } => { match Pin::new(&mut io).poll_flush(cx)? { - Poll::Ready(()) => { - *this.state = SeqState::AwaitSimOpen { io, protocol } - }, + Poll::Ready(()) => *this.state = SeqState::AwaitSimOpen { io, protocol }, Poll::Pending => { *this.state = SeqState::FlushSimOpen { io, protocol }; - return Poll::Pending - }, + return Poll::Pending; + } } } @@ -252,7 +285,7 @@ where Poll::Ready(Some(msg)) => msg, Poll::Pending => { *this.state = SeqState::AwaitSimOpen { io, protocol }; - return Poll::Pending + return Poll::Pending; } // Treat EOF error as [`NegotiationError::Failed`], not as // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O @@ -266,23 +299,32 @@ where } Message::Protocol(p) if p == SIM_OPEN_ID => { let selection = SimOpenPhase { - state: SimOpenState::SendNonce{ io }, + state: SimOpenState::SendNonce { io }, + }; + *this.state = SeqState::SimOpenPhase { + selection, + protocol, }; - *this.state = SeqState::SimOpenPhase { selection, protocol }; } Message::NotAvailable => { *this.state = SeqState::AwaitProtocol { io, protocol } } - _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())) + _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), } } - SeqState::SimOpenPhase { mut selection, protocol } => { + SeqState::SimOpenPhase { + mut selection, + protocol, + } => { let (io, selection_res) = match Pin::new(&mut selection).poll(cx)? 
{ Poll::Ready((io, res)) => (io, res), Poll::Pending => { - *this.state = SeqState::SimOpenPhase { selection, protocol }; - return Poll::Pending + *this.state = SeqState::SimOpenPhase { + selection, + protocol, + }; + return Poll::Pending; } }; @@ -293,29 +335,32 @@ where Role::Responder => { let protocols: Vec<_> = this.protocols.collect(); *this.state = SeqState::Responder { - responder: crate::listener_select::listener_select_proto_no_header(io, std::iter::once(protocol).chain(protocols.into_iter())), + responder: crate::listener_select::listener_select_proto_no_header( + io, + std::iter::once(protocol).chain(protocols.into_iter()), + ), } - }, + } } } - SeqState::Responder { mut responder } => { - match Pin::new(&mut responder ).poll(cx) { - Poll::Ready(res) => return Poll::Ready(res.map(|(p, io)| (p, io, Role::Responder))), - Poll::Pending => { - *this.state = SeqState::Responder { responder }; - return Poll::Pending - } + SeqState::Responder { mut responder } => match Pin::new(&mut responder).poll(cx) { + Poll::Ready(res) => { + return Poll::Ready(res.map(|(p, io)| (p, io, Role::Responder))) } - } + Poll::Pending => { + *this.state = SeqState::Responder { responder }; + return Poll::Pending; + } + }, SeqState::SendProtocol { mut io, protocol } => { match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { *this.state = SeqState::SendProtocol { io, protocol }; - return Poll::Pending - }, + return Poll::Pending; + } } let p = Protocol::try_from(protocol.as_ref())?; @@ -328,7 +373,9 @@ where *this.state = SeqState::FlushProtocol { io, protocol } } else { match this.version { - Version::V1 | Version::V1SimultaneousOpen => *this.state = SeqState::FlushProtocol { io, protocol }, + Version::V1 | Version::V1SimultaneousOpen => { + *this.state = SeqState::FlushProtocol { io, protocol } + } // This is the only effect that `V1Lazy` has compared to `V1`: // Optimistically settling on the only protocol that // the dialer supports for this negotiation. Notably, @@ -337,7 +384,7 @@ where log::debug!("Dialer: Expecting proposed protocol: {}", p); let hl = HeaderLine::from(Version::V1Lazy); let io = Negotiated::expecting(io.into_reader(), p, Some(hl)); - return Poll::Ready(Ok((protocol, io, Role::Initiator))) + return Poll::Ready(Ok((protocol, io, Role::Initiator))); } } } @@ -345,13 +392,11 @@ where SeqState::FlushProtocol { mut io, protocol } => { match Pin::new(&mut io).poll_flush(cx)? 
{ - Poll::Ready(()) => { - *this.state = SeqState::AwaitProtocol { io, protocol } - } , + Poll::Ready(()) => *this.state = SeqState::AwaitProtocol { io, protocol }, Poll::Pending => { *this.state = SeqState::FlushProtocol { io, protocol }; - return Poll::Pending - }, + return Poll::Pending; + } } } @@ -360,7 +405,7 @@ where Poll::Ready(Some(msg)) => msg, Poll::Pending => { *this.state = SeqState::AwaitProtocol { io, protocol }; - return Poll::Pending + return Poll::Pending; } // Treat EOF error as [`NegotiationError::Failed`], not as // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O @@ -368,16 +413,18 @@ where Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), }; - match msg { Message::Header(v) if v == HeaderLine::from(*this.version) => { *this.state = SeqState::AwaitProtocol { io, protocol }; } Message::Protocol(p) if p == SIM_OPEN_ID => { let selection = SimOpenPhase { - state: SimOpenState::SendNonce{ io }, + state: SimOpenState::SendNonce { io }, + }; + *this.state = SeqState::SimOpenPhase { + selection, + protocol, }; - *this.state = SeqState::SimOpenPhase { selection, protocol }; } Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => { log::debug!("Dialer: Received confirmation for protocol: {}", p); @@ -385,16 +432,18 @@ where return Poll::Ready(Ok((protocol, io, Role::Initiator))); } Message::NotAvailable => { - log::debug!("Dialer: Received rejection of protocol: {}", - String::from_utf8_lossy(protocol.as_ref())); + log::debug!( + "Dialer: Received rejection of protocol: {}", + String::from_utf8_lossy(protocol.as_ref()) + ); let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; *this.state = SeqState::SendProtocol { io, protocol } } - _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())) + _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), } } - SeqState::Done => panic!("SeqState::poll called after completion") + SeqState::Done => panic!("SeqState::poll called after completion"), } } } @@ -435,16 +484,15 @@ where type Output = Result<(MessageIO, Role), NegotiationError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { match mem::replace(&mut self.state, SimOpenState::Done) { SimOpenState::SendNonce { mut io } => { match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { self.state = SimOpenState::SendNonce { io }; - return Poll::Pending - }, + return Poll::Pending; + } } let local_nonce = rand::random(); @@ -453,29 +501,27 @@ where return Poll::Ready(Err(From::from(err))); } - self.state = SimOpenState::FlushNonce { - io, - local_nonce, - }; - }, - SimOpenState::FlushNonce { mut io, local_nonce } => { - match Pin::new(&mut io).poll_flush(cx)? { - Poll::Ready(()) => self.state = SimOpenState::ReadNonce { - io, - local_nonce, - }, - Poll::Pending => { - self.state =SimOpenState::FlushNonce { io, local_nonce }; - return Poll::Pending - }, + self.state = SimOpenState::FlushNonce { io, local_nonce }; + } + SimOpenState::FlushNonce { + mut io, + local_nonce, + } => match Pin::new(&mut io).poll_flush(cx)? { + Poll::Ready(()) => self.state = SimOpenState::ReadNonce { io, local_nonce }, + Poll::Pending => { + self.state = SimOpenState::FlushNonce { io, local_nonce }; + return Poll::Pending; } }, - SimOpenState::ReadNonce { mut io, local_nonce } => { + SimOpenState::ReadNonce { + mut io, + local_nonce, + } => { let msg = match Pin::new(&mut io).poll_next(cx)? 
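// The `ReadNonce` handling just below is the heart of the simultaneous-open
// extension: both sides have announced the extension, each now sends a random
// nonce, and comparing the two values breaks the tie. Distilled (the concrete
// integer type of the nonce is an assumption here):
//
//     match local_nonce.cmp(&remote_nonce) {
//         // Both peers drew the same nonce: send a fresh one and retry.
//         Ordering::Equal => { /* start over with SendNonce */ }
//         // The larger nonce continues as the initiator ...
//         Ordering::Greater => { /* proceed as Role::Initiator */ }
//         // ... and the smaller nonce continues as the responder.
//         Ordering::Less => { /* proceed as Role::Responder */ }
//     }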
{ Poll::Ready(Some(msg)) => msg, Poll::Pending => { self.state = SimOpenState::ReadNonce { io, local_nonce }; - return Poll::Pending + return Poll::Pending; } // Treat EOF error as [`NegotiationError::Failed`], not as // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O @@ -499,18 +545,18 @@ where Message::Protocol(_) => { self.state = SimOpenState::ReadNonce { io, local_nonce }; } - Message::Select(remote_nonce) => { + Message::Select(remote_nonce) => { match local_nonce.cmp(&remote_nonce) { Ordering::Equal => { // Start over. self.state = SimOpenState::SendNonce { io }; - }, + } Ordering::Greater => { self.state = SimOpenState::SendRole { io, local_role: Role::Initiator, }; - }, + } Ordering::Less => { self.state = SimOpenState::SendRole { io, @@ -521,14 +567,14 @@ where } _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), } - }, + } SimOpenState::SendRole { mut io, local_role } => { match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { self.state = SimOpenState::SendRole { io, local_role }; - return Poll::Pending - }, + return Poll::Pending; + } } let msg = match local_role { @@ -541,22 +587,22 @@ where } self.state = SimOpenState::FlushRole { io, local_role }; - }, + } SimOpenState::FlushRole { mut io, local_role } => { match Pin::new(&mut io).poll_flush(cx)? { Poll::Ready(()) => self.state = SimOpenState::ReadRole { io, local_role }, Poll::Pending => { - self.state =SimOpenState::FlushRole { io, local_role }; - return Poll::Pending - }, + self.state = SimOpenState::FlushRole { io, local_role }; + return Poll::Pending; + } } - }, + } SimOpenState::ReadRole { mut io, local_role } => { let remote_msg = match Pin::new(&mut io).poll_next(cx)? { Poll::Ready(Some(msg)) => msg, Poll::Pending => { self.state = SimOpenState::ReadRole { io, local_role }; - return Poll::Pending + return Poll::Pending; } // Treat EOF error as [`NegotiationError::Failed`], not as // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O @@ -568,12 +614,12 @@ where Role::Initiator if remote_msg == Message::Responder => Ok((io, local_role)), Role::Responder if remote_msg == Message::Initiator => Ok((io, local_role)), - _ => Err(ProtocolError::InvalidMessage.into()) + _ => Err(ProtocolError::InvalidMessage.into()), }; - return Poll::Ready(result) - }, - SimOpenState::Done => panic!("SimOpenPhase::poll called after completion") + return Poll::Ready(result); + } + SimOpenState::Done => panic!("SimOpenPhase::poll called after completion"), } } } @@ -595,7 +641,7 @@ enum ParState { Flush { io: MessageIO }, RecvProtocols { io: MessageIO }, SendProtocol { io: MessageIO, protocol: N }, - Done + Done, } impl Future for DialerSelectPar @@ -604,7 +650,7 @@ where // It also makes the implementation considerably easier to write. R: AsyncRead + AsyncWrite + Unpin, I: Iterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { type Output = Result<(I::Item, Negotiated, Role), NegotiationError>; @@ -615,11 +661,11 @@ where match mem::replace(this.state, ParState::Done) { ParState::SendHeader { mut io } => { match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { *this.state = ParState::SendHeader { io }; - return Poll::Pending - }, + return Poll::Pending; + } } let msg = Message::Header(HeaderLine::from(*this.version)); @@ -632,11 +678,11 @@ where ParState::SendProtocolsRequest { mut io } => { match Pin::new(&mut io).poll_ready(cx)? 
{ - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { *this.state = ParState::SendProtocolsRequest { io }; - return Poll::Pending - }, + return Poll::Pending; + } } if let Err(err) = Pin::new(&mut io).start_send(Message::ListProtocols) { @@ -647,22 +693,20 @@ where *this.state = ParState::Flush { io } } - ParState::Flush { mut io } => { - match Pin::new(&mut io).poll_flush(cx)? { - Poll::Ready(()) => *this.state = ParState::RecvProtocols { io }, - Poll::Pending => { - *this.state = ParState::Flush { io }; - return Poll::Pending - }, + ParState::Flush { mut io } => match Pin::new(&mut io).poll_flush(cx)? { + Poll::Ready(()) => *this.state = ParState::RecvProtocols { io }, + Poll::Pending => { + *this.state = ParState::Flush { io }; + return Poll::Pending; } - } + }, ParState::RecvProtocols { mut io } => { let msg = match Pin::new(&mut io).poll_next(cx)? { Poll::Ready(Some(msg)) => msg, Poll::Pending => { *this.state = ParState::RecvProtocols { io }; - return Poll::Pending + return Poll::Pending; } // Treat EOF error as [`NegotiationError::Failed`], not as // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O @@ -675,12 +719,15 @@ where *this.state = ParState::RecvProtocols { io } } Message::Protocols(supported) => { - let protocol = this.protocols.by_ref() - .find(|p| supported.iter().any(|s| - s.as_ref() == p.as_ref())) + let protocol = this + .protocols + .by_ref() + .find(|p| supported.iter().any(|s| s.as_ref() == p.as_ref())) .ok_or(NegotiationError::Failed)?; - log::debug!("Dialer: Found supported protocol: {}", - String::from_utf8_lossy(protocol.as_ref())); + log::debug!( + "Dialer: Found supported protocol: {}", + String::from_utf8_lossy(protocol.as_ref()) + ); *this.state = ParState::SendProtocol { io, protocol }; } _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), @@ -689,11 +736,11 @@ where ParState::SendProtocol { mut io, protocol } => { match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, + Poll::Ready(()) => {} Poll::Pending => { *this.state = ParState::SendProtocol { io, protocol }; - return Poll::Pending - }, + return Poll::Pending; + } } let p = Protocol::try_from(protocol.as_ref())?; @@ -704,10 +751,10 @@ where log::debug!("Dialer: Expecting proposed protocol: {}", p); let io = Negotiated::expecting(io.into_reader(), p, None); - return Poll::Ready(Ok((protocol, io, Role::Initiator))) + return Poll::Ready(Ok((protocol, io, Role::Initiator))); } - ParState::Done => panic!("ParState::poll called after completion") + ParState::Done => panic!("ParState::poll called after completion"), } } } diff --git a/misc/multistream-select/src/length_delimited.rs b/misc/multistream-select/src/length_delimited.rs index 593c915ac2b..abb622eed30 100644 --- a/misc/multistream-select/src/length_delimited.rs +++ b/misc/multistream-select/src/length_delimited.rs @@ -18,9 +18,15 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
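For orientation: the simultaneous-open phase implemented above settles which peer ends up as the multistream-select initiator by having both sides send a random 64-bit nonce and comparing the two values. The following is a minimal, self-contained sketch of just that decision rule; the free-standing `Role` enum and the helper name `select_role` are illustrative only and not part of the patch (the patch performs the comparison inline in `SimOpenPhase`).

use std::cmp::Ordering;

#[derive(Debug, PartialEq)]
enum Role { Initiator, Responder }

// Given the locally generated nonce and the nonce received from the remote,
// decide who proceeds as initiator. `None` corresponds to the "start over"
// branch above: equal nonces force a fresh nonce exchange.
fn select_role(local_nonce: u64, remote_nonce: u64) -> Option<Role> {
    match local_nonce.cmp(&remote_nonce) {
        Ordering::Equal => None,
        Ordering::Greater => Some(Role::Initiator),
        Ordering::Less => Some(Role::Responder),
    }
}

#[test]
fn larger_nonce_becomes_initiator() {
    assert_eq!(select_role(7, 3), Some(Role::Initiator));
    assert_eq!(select_role(3, 7), Some(Role::Responder));
    assert_eq!(select_role(5, 5), None);
}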
-use bytes::{Bytes, BytesMut, Buf as _, BufMut as _}; -use futures::{prelude::*, io::IoSlice}; -use std::{convert::TryFrom as _, io, pin::Pin, task::{Poll, Context}, u16}; +use bytes::{Buf as _, BufMut as _, Bytes, BytesMut}; +use futures::{io::IoSlice, prelude::*}; +use std::{ + convert::TryFrom as _, + io, + pin::Pin, + task::{Context, Poll}, + u16, +}; const MAX_LEN_BYTES: u16 = 2; const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1; @@ -50,7 +56,10 @@ pub struct LengthDelimited { #[derive(Debug, Copy, Clone, PartialEq, Eq)] enum ReadState { /// We are currently reading the length of the next frame of data. - ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize }, + ReadLength { + buf: [u8; MAX_LEN_BYTES as usize], + pos: usize, + }, /// We are currently reading the frame of data itself. ReadData { len: u16, pos: usize }, } @@ -59,7 +68,7 @@ impl Default for ReadState { fn default() -> Self { ReadState::ReadLength { buf: [0; MAX_LEN_BYTES as usize], - pos: 0 + pos: 0, } } } @@ -106,10 +115,12 @@ impl LengthDelimited { /// /// After this method returns `Poll::Ready`, the write buffer of frames /// submitted to the `Sink` is guaranteed to be empty. - pub fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context<'_>) - -> Poll> + pub fn poll_write_buffer( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> where - R: AsyncWrite + R: AsyncWrite, { let mut this = self.project(); @@ -119,7 +130,8 @@ impl LengthDelimited { Poll::Ready(Ok(0)) => { return Poll::Ready(Err(io::Error::new( io::ErrorKind::WriteZero, - "Failed to write buffered frame."))) + "Failed to write buffered frame.", + ))) } Poll::Ready(Ok(n)) => this.write_buffer.advance(n), Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), @@ -132,7 +144,7 @@ impl LengthDelimited { impl Stream for LengthDelimited where - R: AsyncRead + R: AsyncRead, { type Item = Result; @@ -142,7 +154,7 @@ where loop { match this.read_state { ReadState::ReadLength { buf, pos } => { - match this.inner.as_mut().poll_read(cx, &mut buf[*pos .. *pos + 1]) { + match this.inner.as_mut().poll_read(cx, &mut buf[*pos..*pos + 1]) { Poll::Ready(Ok(0)) => { if *pos == 0 { return Poll::Ready(None); @@ -160,11 +172,10 @@ where if (buf[*pos - 1] & 0x80) == 0 { // MSB is not set, indicating the end of the length prefix. - let (len, _) = unsigned_varint::decode::u16(buf) - .map_err(|e| { - log::debug!("invalid length prefix: {}", e); - io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix") - })?; + let (len, _) = unsigned_varint::decode::u16(buf).map_err(|e| { + log::debug!("invalid length prefix: {}", e); + io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix") + })?; if len >= 1 { *this.read_state = ReadState::ReadData { len, pos: 0 }; @@ -179,12 +190,19 @@ where // See the module documentation about the max frame len. 
return Poll::Ready(Some(Err(io::Error::new( io::ErrorKind::InvalidData, - "Maximum frame length exceeded")))); + "Maximum frame length exceeded", + )))); } } ReadState::ReadData { len, pos } => { - match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) { - Poll::Ready(Ok(0)) => return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))), + match this + .inner + .as_mut() + .poll_read(cx, &mut this.read_buffer[*pos..]) + { + Poll::Ready(Ok(0)) => { + return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))) + } Poll::Ready(Ok(n)) => *pos += n, Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))), @@ -214,7 +232,7 @@ where // implied to be roughly 2 * MAX_FRAME_SIZE. if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize { match self.as_mut().poll_write_buffer(cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, } @@ -233,7 +251,8 @@ where _ => { return Err(io::Error::new( io::ErrorKind::InvalidData, - "Maximum frame size exceeded.")) + "Maximum frame size exceeded.", + )) } }; @@ -249,7 +268,7 @@ where fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // Write all buffered frame data to the underlying I/O stream. match LengthDelimited::poll_write_buffer(self.as_mut(), cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, } @@ -264,7 +283,7 @@ where fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // Write all buffered frame data to the underlying I/O stream. match LengthDelimited::poll_write_buffer(self.as_mut(), cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, } @@ -283,7 +302,7 @@ where #[derive(Debug)] pub struct LengthDelimitedReader { #[pin] - inner: LengthDelimited + inner: LengthDelimited, } impl LengthDelimitedReader { @@ -306,7 +325,7 @@ impl LengthDelimitedReader { impl Stream for LengthDelimitedReader where - R: AsyncRead + R: AsyncRead, { type Item = Result; @@ -317,17 +336,19 @@ where impl AsyncWrite for LengthDelimitedReader where - R: AsyncWrite + R: AsyncWrite, { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) - -> Poll> - { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { // `this` here designates the `LengthDelimited`. let mut this = self.project().inner; // We need to flush any data previously written with the `LengthDelimited`. match LengthDelimited::poll_write_buffer(this.as_mut(), cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, } @@ -344,15 +365,17 @@ where self.project().inner.poll_close(cx) } - fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) - -> Poll> - { + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { // `this` here designates the `LengthDelimited`. let mut this = self.project().inner; // We need to flush any data previously written with the `LengthDelimited`. 
match LengthDelimited::poll_write_buffer(this.as_mut(), cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), Poll::Pending => return Poll::Pending, } @@ -366,7 +389,7 @@ where mod tests { use crate::length_delimited::LengthDelimited; use async_std::net::{TcpListener, TcpStream}; - use futures::{prelude::*, io::Cursor}; + use futures::{io::Cursor, prelude::*}; use quickcheck::*; use std::io::ErrorKind; @@ -394,9 +417,7 @@ mod tests { let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8]; data.extend(frame.clone().into_iter()); let mut framed = LengthDelimited::new(Cursor::new(data)); - let recved = futures::executor::block_on(async move { - framed.next().await - }).unwrap(); + let recved = futures::executor::block_on(async move { framed.next().await }).unwrap(); assert_eq!(recved.unwrap(), frame); } @@ -405,9 +426,7 @@ mod tests { let mut data = vec![0x81, 0x81, 0x1]; data.extend((0..16513).map(|_| 0)); let mut framed = LengthDelimited::new(Cursor::new(data)); - let recved = futures::executor::block_on(async move { - framed.next().await.unwrap() - }); + let recved = futures::executor::block_on(async move { framed.next().await.unwrap() }); if let Err(io_err) = recved { assert_eq!(io_err.kind(), ErrorKind::InvalidData) @@ -479,7 +498,8 @@ mod tests { let expected_frames = frames.clone(); let server = async_std::task::spawn(async move { let socket = listener.accept().await.unwrap().0; - let mut connec = rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket)); + let mut connec = + rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket)); let mut buf = vec![0u8; 0]; for expected in expected_frames { diff --git a/misc/multistream-select/src/lib.rs b/misc/multistream-select/src/lib.rs index 8fa88baf4ac..1239b016939 100644 --- a/misc/multistream-select/src/lib.rs +++ b/misc/multistream-select/src/lib.rs @@ -94,10 +94,10 @@ mod negotiated; mod protocol; mod tests; -pub use self::negotiated::{Negotiated, NegotiatedComplete, NegotiationError}; -pub use self::protocol::ProtocolError; pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture, Role}; pub use self::listener_select::{listener_select_proto, ListenerSelectFuture}; +pub use self::negotiated::{Negotiated, NegotiatedComplete, NegotiationError}; +pub use self::protocol::ProtocolError; /// Supported multistream-select versions. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -157,4 +157,4 @@ impl Default for Version { fn default() -> Self { Version::V1 } -} \ No newline at end of file +} diff --git a/misc/multistream-select/src/listener_select.rs b/misc/multistream-select/src/listener_select.rs index 28139235ed4..15c06fb323a 100644 --- a/misc/multistream-select/src/listener_select.rs +++ b/misc/multistream-select/src/listener_select.rs @@ -21,12 +21,18 @@ //! Protocol negotiation strategies for the peer acting as the listener //! in a multistream-select protocol negotiation. 
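Every multistream-select message in this patch travels inside a frame written by `LengthDelimited`: an unsigned-varint length prefix of at most two bytes followed by the frame body, capped at 2^14 - 1 bytes. Below is a minimal sketch of that framing, reusing the same `bytes` and `unsigned_varint` crates; the helper name `encode_frame` is illustrative only.

use bytes::{BufMut, BytesMut};
use unsigned_varint::encode;

// Mirrors the limit above: with a two-byte length prefix the body can be
// at most (1 << 14) - 1 bytes long.
const MAX_FRAME_SIZE: usize = (1 << 14) - 1;

// Prepend the varint length prefix that `LengthDelimited` writes before
// every frame body.
fn encode_frame(body: &[u8], dst: &mut BytesMut) -> std::io::Result<()> {
    if body.len() > MAX_FRAME_SIZE {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "Maximum frame size exceeded.",
        ));
    }
    let mut len_buf = encode::usize_buffer();
    let len_bytes = encode::usize(body.len(), &mut len_buf);
    dst.reserve(len_bytes.len() + body.len());
    dst.put(len_bytes);
    dst.put(body);
    Ok(())
}

#[test]
fn short_frame_has_one_length_byte() {
    let mut dst = BytesMut::new();
    encode_frame(b"/proto1\n", &mut dst).unwrap();
    assert_eq!(
        &dst[..],
        &[8, b'/', b'p', b'r', b'o', b't', b'o', b'1', b'\n'][..]
    );
}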
+use crate::protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError}; use crate::{Negotiated, NegotiationError}; -use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine}; use futures::prelude::*; use smallvec::SmallVec; -use std::{convert::TryFrom as _, iter::FromIterator, mem, pin::Pin, task::{Context, Poll}}; +use std::{ + convert::TryFrom as _, + iter::FromIterator, + mem, + pin::Pin, + task::{Context, Poll}, +}; /// Returns a `Future` that negotiates a protocol on the given I/O stream /// for a peer acting as the _listener_ (or _responder_). @@ -35,18 +41,18 @@ use std::{convert::TryFrom as _, iter::FromIterator, mem, pin::Pin, task::{Conte /// computation that performs the protocol negotiation with the remote. The /// returned `Future` resolves with the name of the negotiated protocol and /// a [`Negotiated`] I/O stream. -pub fn listener_select_proto( - inner: R, - protocols: I, -) -> ListenerSelectFuture +pub fn listener_select_proto(inner: R, protocols: I) -> ListenerSelectFuture where R: AsyncRead + AsyncWrite, I: IntoIterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { - listener_select_proto_with_state(State::RecvHeader { - io: MessageIO::new(inner) - }, protocols) + listener_select_proto_with_state( + State::RecvHeader { + io: MessageIO::new(inner), + }, + protocols, + ) } /// Used when selected as a [`crate::Role::Responder`] during [`crate::dialer_select_proto`] @@ -58,12 +64,9 @@ pub(crate) fn listener_select_proto_no_header( where R: AsyncRead + AsyncWrite, I: IntoIterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { - listener_select_proto_with_state( - State::RecvMessage { io }, - protocols, - ) + listener_select_proto_with_state(State::RecvMessage { io }, protocols) } fn listener_select_proto_with_state( @@ -73,14 +76,18 @@ fn listener_select_proto_with_state( where R: AsyncRead + AsyncWrite, I: IntoIterator, - I::Item: AsRef<[u8]> + I::Item: AsRef<[u8]>, { - let protocols = protocols.into_iter().filter_map(|n| - match Protocol::try_from(n.as_ref()) { + let protocols = protocols + .into_iter() + .filter_map(|n| match Protocol::try_from(n.as_ref()) { Ok(p) => Some((n, p)), Err(e) => { - log::warn!("Listener: Ignoring invalid protocol: {} due to {}", - String::from_utf8_lossy(n.as_ref()), e); + log::warn!( + "Listener: Ignoring invalid protocol: {} due to {}", + String::from_utf8_lossy(n.as_ref()), + e + ); None } }); @@ -109,19 +116,25 @@ pub struct ListenerSelectFuture { } enum State { - RecvHeader { io: MessageIO }, - SendHeader { io: MessageIO }, - RecvMessage { io: MessageIO }, + RecvHeader { + io: MessageIO, + }, + SendHeader { + io: MessageIO, + }, + RecvMessage { + io: MessageIO, + }, SendMessage { io: MessageIO, message: Message, - protocol: Option + protocol: Option, }, Flush { io: MessageIO, - protocol: Option + protocol: Option, }, - Done + Done, } impl Future for ListenerSelectFuture @@ -129,7 +142,7 @@ where // The Unpin bound here is required because we produce a `Negotiated` as the output. // It also makes the implementation considerably easier to write. 
R: AsyncRead + AsyncWrite + Unpin, - N: AsRef<[u8]> + Clone + N: AsRef<[u8]> + Clone, { type Output = Result<(N, Negotiated), NegotiationError>; @@ -140,14 +153,12 @@ where match mem::replace(this.state, State::Done) { State::RecvHeader { mut io } => { match io.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(Message::Header(h)))) => { - match h { - HeaderLine::V1 => *this.state = State::SendHeader { io } - } - } + Poll::Ready(Some(Ok(Message::Header(h)))) => match h { + HeaderLine::V1 => *this.state = State::SendHeader { io }, + }, Poll::Ready(Some(Ok(_))) => { return Poll::Ready(Err(ProtocolError::InvalidMessage.into())) - }, + } Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))), // Treat EOF error as [`NegotiationError::Failed`], not as // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O @@ -155,7 +166,7 @@ where Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), Poll::Pending => { *this.state = State::RecvHeader { io }; - return Poll::Pending + return Poll::Pending; } } } @@ -164,9 +175,9 @@ where match Pin::new(&mut io).poll_ready(cx) { Poll::Pending => { *this.state = State::SendHeader { io }; - return Poll::Pending - }, - Poll::Ready(Ok(())) => {}, + return Poll::Pending; + } + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), } @@ -204,28 +215,37 @@ where // the dialer also raises `NegotiationError::Failed` when finally // reading the `N/A` response. if let ProtocolError::InvalidMessage = &err { - log::trace!("Listener: Negotiation failed with invalid \ - message after protocol rejection."); - return Poll::Ready(Err(NegotiationError::Failed)) + log::trace!( + "Listener: Negotiation failed with invalid \ + message after protocol rejection." + ); + return Poll::Ready(Err(NegotiationError::Failed)); } if let ProtocolError::IoError(e) = &err { if e.kind() == std::io::ErrorKind::UnexpectedEof { - log::trace!("Listener: Negotiation failed with EOF \ - after protocol rejection."); - return Poll::Ready(Err(NegotiationError::Failed)) + log::trace!( + "Listener: Negotiation failed with EOF \ + after protocol rejection." 
+ ); + return Poll::Ready(Err(NegotiationError::Failed)); } } } - return Poll::Ready(Err(From::from(err))) + return Poll::Ready(Err(From::from(err))); } }; match msg { Message::ListProtocols => { - let supported = this.protocols.iter().map(|(_,p)| p).cloned().collect(); + let supported = + this.protocols.iter().map(|(_, p)| p).cloned().collect(); let message = Message::Protocols(supported); - *this.state = State::SendMessage { io, message, protocol: None } + *this.state = State::SendMessage { + io, + message, + protocol: None, + } } Message::Protocol(p) => { let protocol = this.protocols.iter().find_map(|(name, proto)| { @@ -240,28 +260,42 @@ where log::debug!("Listener: confirming protocol: {}", p); Message::Protocol(p.clone()) } else { - log::debug!("Listener: rejecting protocol: {}", - String::from_utf8_lossy(p.as_ref())); + log::debug!( + "Listener: rejecting protocol: {}", + String::from_utf8_lossy(p.as_ref()) + ); Message::NotAvailable }; - *this.state = State::SendMessage { io, message, protocol }; + *this.state = State::SendMessage { + io, + message, + protocol, + }; } - _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())) + _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), } } - State::SendMessage { mut io, message, protocol } => { + State::SendMessage { + mut io, + message, + protocol, + } => { match Pin::new(&mut io).poll_ready(cx) { Poll::Pending => { - *this.state = State::SendMessage { io, message, protocol }; - return Poll::Pending - }, - Poll::Ready(Ok(())) => {}, + *this.state = State::SendMessage { + io, + message, + protocol, + }; + return Poll::Pending; + } + Poll::Ready(Ok(())) => {} Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), } - if let Message::NotAvailable = &message { + if let Message::NotAvailable = &message { *this.last_sent_na = true; } else { *this.last_sent_na = false; @@ -278,26 +312,28 @@ where match Pin::new(&mut io).poll_flush(cx) { Poll::Pending => { *this.state = State::Flush { io, protocol }; - return Poll::Pending - }, + return Poll::Pending; + } Poll::Ready(Ok(())) => { // If a protocol has been selected, finish negotiation. // Otherwise expect to receive another message. match protocol { Some(protocol) => { - log::debug!("Listener: sent confirmed protocol: {}", - String::from_utf8_lossy(protocol.as_ref())); + log::debug!( + "Listener: sent confirmed protocol: {}", + String::from_utf8_lossy(protocol.as_ref()) + ); let io = Negotiated::completed(io.into_inner()); - return Poll::Ready(Ok((protocol, io))) + return Poll::Ready(Ok((protocol, io))); } - None => *this.state = State::RecvMessage { io } + None => *this.state = State::RecvMessage { io }, } } Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), } } - State::Done => panic!("State::poll called after completion") + State::Done => panic!("State::poll called after completion"), } } } diff --git a/misc/multistream-select/src/negotiated.rs b/misc/multistream-select/src/negotiated.rs index e80d579f2b4..2f78daf0376 100644 --- a/misc/multistream-select/src/negotiated.rs +++ b/misc/multistream-select/src/negotiated.rs @@ -18,11 +18,20 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
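The listener logic above answers each dialer message according to a small set of rules: `ls` is answered with the full protocol list, a supported protocol name is echoed back as confirmation, and anything else is rejected with `na` so the dialer can propose its next alternative. A minimal sketch of that reply rule for a single proposal follows; `reply_to_proposal` and the plain string types are illustrative stand-ins for the `Protocol`/`Message` types used by the patch.

// Decide the listener's reply to one proposed protocol name, mirroring the
// `Message::Protocol` branch of `ListenerSelectFuture` above.
fn reply_to_proposal<'a>(
    supported: &[&'a str],
    proposed: &str,
) -> Result<&'a str, &'static str> {
    supported
        .iter()
        .copied()
        .find(|p| *p == proposed)
        // Unknown protocols are answered with "na"; negotiation then
        // continues with the dialer's next proposal, if any.
        .ok_or("na")
}

#[test]
fn known_protocols_are_confirmed() {
    assert_eq!(reply_to_proposal(&["/proto1", "/proto2"], "/proto2"), Ok("/proto2"));
    assert_eq!(reply_to_proposal(&["/proto1"], "/other"), Err("na"));
}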
-use crate::protocol::{Protocol, MessageReader, Message, ProtocolError, HeaderLine}; +use crate::protocol::{HeaderLine, Message, MessageReader, Protocol, ProtocolError}; -use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready}; +use futures::{ + io::{IoSlice, IoSliceMut}, + prelude::*, + ready, +}; use pin_project::pin_project; -use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}}; +use std::{ + error::Error, + fmt, io, mem, + pin::Pin, + task::{Context, Poll}, +}; /// An I/O stream that has settled on an (application-layer) protocol to use. /// @@ -39,7 +48,7 @@ use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}}; #[derive(Debug)] pub struct Negotiated { #[pin] - state: State + state: State, } /// A `Future` that waits on the completion of protocol negotiation. @@ -57,12 +66,15 @@ where type Output = Result, NegotiationError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut io = self.inner.take().expect("NegotiatedFuture called after completion."); + let mut io = self + .inner + .take() + .expect("NegotiatedFuture called after completion."); match Negotiated::poll(Pin::new(&mut io), cx) { Poll::Pending => { self.inner = Some(io); Poll::Pending - }, + } Poll::Ready(Ok(())) => Poll::Ready(Ok(io)), Poll::Ready(Err(err)) => { self.inner = Some(io); @@ -75,7 +87,9 @@ where impl Negotiated { /// Creates a `Negotiated` in state [`State::Completed`]. pub(crate) fn completed(io: TInner) -> Self { - Negotiated { state: State::Completed { io } } + Negotiated { + state: State::Completed { io }, + } } /// Creates a `Negotiated` in state [`State::Expecting`] that is still @@ -83,25 +97,31 @@ impl Negotiated { pub(crate) fn expecting( io: MessageReader, protocol: Protocol, - header: Option + header: Option, ) -> Self { - Negotiated { state: State::Expecting { io, protocol, header } } + Negotiated { + state: State::Expecting { + io, + protocol, + header, + }, + } } /// Polls the `Negotiated` for completion. fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> where - TInner: AsyncRead + AsyncWrite + Unpin + TInner: AsyncRead + AsyncWrite + Unpin, { // Flush any pending negotiation data. match self.as_mut().poll_flush(cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Pending => return Poll::Pending, Poll::Ready(Err(e)) => { // If the remote closed the stream, it is important to still // continue reading the data that was sent, if any. if e.kind() != io::ErrorKind::WriteZero { - return Poll::Ready(Err(e.into())) + return Poll::Ready(Err(e.into())); } } } @@ -109,36 +129,52 @@ impl Negotiated { let mut this = self.project(); if let StateProj::Completed { .. } = this.state.as_mut().project() { - return Poll::Ready(Ok(())); + return Poll::Ready(Ok(())); } // Read outstanding protocol negotiation messages. loop { match mem::replace(&mut *this.state, State::Invalid) { - State::Expecting { mut io, header, protocol } => { + State::Expecting { + mut io, + header, + protocol, + } => { let msg = match Pin::new(&mut io).poll_next(cx)? 
{ Poll::Ready(Some(msg)) => msg, Poll::Pending => { - *this.state = State::Expecting { io, header, protocol }; - return Poll::Pending - }, + *this.state = State::Expecting { + io, + header, + protocol, + }; + return Poll::Pending; + } Poll::Ready(None) => { return Poll::Ready(Err(ProtocolError::IoError( - io::ErrorKind::UnexpectedEof.into()).into())); + io::ErrorKind::UnexpectedEof.into(), + ) + .into())); } }; if let Message::Header(h) = &msg { if Some(h) == header.as_ref() { - *this.state = State::Expecting { io, protocol, header: None }; - continue + *this.state = State::Expecting { + io, + protocol, + header: None, + }; + continue; } } if let Message::Protocol(p) = &msg { if p.as_ref() == protocol.as_ref() { log::debug!("Negotiated: Received confirmation for protocol: {}", p); - *this.state = State::Completed { io: io.into_inner() }; + *this.state = State::Completed { + io: io.into_inner(), + }; return Poll::Ready(Ok(())); } } @@ -146,7 +182,7 @@ impl Negotiated { return Poll::Ready(Err(NegotiationError::Failed)); } - _ => panic!("Negotiated: Invalid state") + _ => panic!("Negotiated: Invalid state"), } } } @@ -178,7 +214,10 @@ enum State { /// In this state, a protocol has been agreed upon and I/O /// on the underlying stream can commence. - Completed { #[pin] io: R }, + Completed { + #[pin] + io: R, + }, /// Temporary state while moving the `io` resource from /// `Expecting` to `Completed`. @@ -187,11 +226,13 @@ enum State { impl AsyncRead for Negotiated where - TInner: AsyncRead + AsyncWrite + Unpin + TInner: AsyncRead + AsyncWrite + Unpin, { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) - -> Poll> - { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { loop { if let StateProj::Completed { io } = self.as_mut().project().state.project() { // If protocol negotiation is complete, commence with reading. @@ -201,7 +242,7 @@ where // Poll the `Negotiated`, driving protocol negotiation to completion, // including flushing of any remaining data. match self.as_mut().poll(cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), } @@ -217,19 +258,21 @@ where } }*/ - fn poll_read_vectored(mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>]) - -> Poll> - { + fn poll_read_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { loop { if let StateProj::Completed { io } = self.as_mut().project().state.project() { // If protocol negotiation is complete, commence with reading. - return io.poll_read_vectored(cx, bufs) + return io.poll_read_vectored(cx, bufs); } // Poll the `Negotiated`, driving protocol negotiation to completion, // including flushing of any remaining data. match self.as_mut().poll(cx) { - Poll::Ready(Ok(())) => {}, + Poll::Ready(Ok(())) => {} Poll::Pending => return Poll::Pending, Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), } @@ -239,9 +282,13 @@ where impl AsyncWrite for Negotiated where - TInner: AsyncWrite + AsyncRead + Unpin + TInner: AsyncWrite + AsyncRead + Unpin, { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { match self.project().state.project() { StateProj::Completed { io } => io.poll_write(cx, buf), StateProj::Expecting { io, .. 
} => io.poll_write(cx, buf), @@ -261,7 +308,10 @@ where // Ensure all data has been flushed and expected negotiation messages // have been received. ready!(self.as_mut().poll(cx).map_err(Into::::into)?); - ready!(self.as_mut().poll_flush(cx).map_err(Into::::into)?); + ready!(self + .as_mut() + .poll_flush(cx) + .map_err(Into::::into)?); // Continue with the shutdown of the underlying I/O stream. match self.project().state.project() { @@ -271,9 +321,11 @@ where } } - fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) - -> Poll> - { + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { match self.project().state.project() { StateProj::Completed { io } => io.poll_write_vectored(cx, bufs), StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs), @@ -307,7 +359,7 @@ impl From for NegotiationError { impl From for io::Error { fn from(err: NegotiationError) -> io::Error { if let NegotiationError::ProtocolError(e) = err { - return e.into() + return e.into(); } io::Error::new(io::ErrorKind::Other, err) } @@ -325,10 +377,10 @@ impl Error for NegotiationError { impl fmt::Display for NegotiationError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - NegotiationError::ProtocolError(p) => - fmt.write_fmt(format_args!("Protocol error: {}", p)), - NegotiationError::Failed => - fmt.write_str("Protocol negotiation failed.") + NegotiationError::ProtocolError(p) => { + fmt.write_fmt(format_args!("Protocol error: {}", p)) + } + NegotiationError::Failed => fmt.write_str("Protocol negotiation failed."), } } } diff --git a/misc/multistream-select/src/protocol.rs b/misc/multistream-select/src/protocol.rs index 4ea7114541a..920d78919f2 100644 --- a/misc/multistream-select/src/protocol.rs +++ b/misc/multistream-select/src/protocol.rs @@ -25,12 +25,19 @@ //! `Stream` and `Sink` implementations of `MessageIO` and //! `MessageReader`. -use crate::Version; use crate::length_delimited::{LengthDelimited, LengthDelimitedReader}; +use crate::Version; -use bytes::{Bytes, BytesMut, BufMut}; -use futures::{prelude::*, io::IoSlice, ready}; -use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, str::FromStr, task::{Context, Poll}}; +use bytes::{BufMut, Bytes, BytesMut}; +use futures::{io::IoSlice, prelude::*, ready}; +use std::{ + convert::TryFrom, + error::Error, + fmt, io, + pin::Pin, + str::FromStr, + task::{Context, Poll}, +}; use unsigned_varint as uvi; /// The maximum number of supported protocols that can be processed. @@ -54,7 +61,8 @@ const MSG_RESPONDER: &[u8] = b"responder\n"; /// The identifier of the multistream-select simultaneous open protocol /// extension. -pub(crate) const SIM_OPEN_ID: Protocol = Protocol(Bytes::from_static(b"/libp2p/simultaneous-connect")); +pub(crate) const SIM_OPEN_ID: Protocol = + Protocol(Bytes::from_static(b"/libp2p/simultaneous-connect")); /// The multistream-select header lines preceeding negotiation. /// @@ -88,7 +96,7 @@ impl TryFrom for Protocol { fn try_from(value: Bytes) -> Result { if !value.as_ref().starts_with(b"/") { - return Err(ProtocolError::InvalidProtocol) + return Err(ProtocolError::InvalidProtocol); } Ok(Protocol(value)) } @@ -192,7 +200,7 @@ impl Message { /// Decodes a `Message` from its byte representation. 
pub fn decode(mut msg: Bytes) -> Result { if msg == MSG_MULTISTREAM_1_0 { - return Ok(Message::Header(HeaderLine::V1)) + return Ok(Message::Header(HeaderLine::V1)); } if msg == MSG_PROTOCOL_NA { @@ -200,33 +208,34 @@ impl Message { } if msg == MSG_LS { - return Ok(Message::ListProtocols) + return Ok(Message::ListProtocols); } if msg.len() > MSG_SELECT.len() + 1 /* \n */ && msg[.. MSG_SELECT.len()] == *MSG_SELECT && msg.last() == Some(&b'\n') { - if let Some(nonce) = std::str::from_utf8(&msg[MSG_SELECT.len() .. msg.len() -1]) + if let Some(nonce) = std::str::from_utf8(&msg[MSG_SELECT.len()..msg.len() - 1]) .ok() .and_then(|s| u64::from_str(s).ok()) { - return Ok(Message::Select(nonce)) + return Ok(Message::Select(nonce)); } } if msg == MSG_INITIATOR { - return Ok(Message::Initiator) + return Ok(Message::Initiator); } if msg == MSG_RESPONDER { - return Ok(Message::Responder) + return Ok(Message::Responder); } // If it starts with a `/`, ends with a line feed without any // other line feeds in-between, it must be a protocol name. - if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') && - !msg[.. msg.len() - 1].contains(&b'\n') + if msg.get(0) == Some(&b'/') + && msg.last() == Some(&b'\n') + && !msg[..msg.len() - 1].contains(&b'\n') { let p = Protocol::try_from(msg.split_to(msg.len() - 1))?; return Ok(Message::Protocol(p)); @@ -239,24 +248,24 @@ impl Message { loop { // A well-formed message must be terminated with a newline. if remaining == [b'\n'] { - break + break; } else if protocols.len() == MAX_PROTOCOLS { - return Err(ProtocolError::TooManyProtocols) + return Err(ProtocolError::TooManyProtocols); } // Decode the length of the next protocol name and check that // it ends with a line feed. let (len, tail) = uvi::decode::usize(remaining)?; if len == 0 || len > tail.len() || tail[len - 1] != b'\n' { - return Err(ProtocolError::InvalidMessage) + return Err(ProtocolError::InvalidMessage); } // Parse the protocol name. - let p = Protocol::try_from(Bytes::copy_from_slice(&tail[.. len - 1]))?; + let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?; protocols.push(p); // Skip ahead to the next protocol. - remaining = &tail[len ..]; + remaining = &tail[len..]; } Ok(Message::Protocols(protocols)) @@ -274,9 +283,11 @@ impl MessageIO { /// Constructs a new `MessageIO` resource wrapping the given I/O stream. pub fn new(inner: R) -> MessageIO where - R: AsyncRead + AsyncWrite + R: AsyncRead + AsyncWrite, { - Self { inner: LengthDelimited::new(inner) } + Self { + inner: LengthDelimited::new(inner), + } } /// Converts the [`MessageIO`] into a [`MessageReader`], dropping the @@ -287,7 +298,9 @@ impl MessageIO { /// received but no more messages are written, allowing the writing of /// follow-up protocol data to commence. pub fn into_reader(self) -> MessageReader { - MessageReader { inner: self.inner.into_reader() } + MessageReader { + inner: self.inner.into_reader(), + } } /// Draops the [`MessageIO`] resource, yielding the underlying I/O stream. 
@@ -317,7 +330,10 @@ where fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { let mut buf = BytesMut::new(); item.encode(&mut buf)?; - self.project().inner.start_send(buf.freeze()).map_err(From::from) + self.project() + .inner + .start_send(buf.freeze()) + .map_err(From::from) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -331,7 +347,7 @@ where impl Stream for MessageIO where - R: AsyncRead + R: AsyncRead, { type Item = Result; @@ -351,7 +367,7 @@ where #[derive(Debug)] pub struct MessageReader { #[pin] - inner: LengthDelimitedReader + inner: LengthDelimitedReader, } impl MessageReader { @@ -373,7 +389,7 @@ impl MessageReader { impl Stream for MessageReader where - R: AsyncRead + R: AsyncRead, { type Item = Result; @@ -384,9 +400,13 @@ where impl AsyncWrite for MessageReader where - TInner: AsyncWrite + TInner: AsyncWrite, { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { self.project().inner.poll_write(cx, buf) } @@ -398,12 +418,19 @@ where self.project().inner.poll_close(cx) } - fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll> { + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { self.project().inner.poll_write_vectored(cx, bufs) } } -fn poll_stream(stream: Pin<&mut S>, cx: &mut Context<'_>) -> Poll>> +fn poll_stream( + stream: Pin<&mut S>, + cx: &mut Context<'_>, +) -> Poll>> where S: Stream>, { @@ -413,7 +440,7 @@ where Err(err) => return Poll::Ready(Some(Err(err))), } } else { - return Poll::Ready(None) + return Poll::Ready(None); }; log::trace!("Received message: {:?}", msg); @@ -446,7 +473,7 @@ impl From for ProtocolError { impl From for io::Error { fn from(err: ProtocolError) -> Self { if let ProtocolError::IoError(e) = err { - return e + return e; } io::ErrorKind::InvalidData.into() } @@ -470,14 +497,10 @@ impl Error for ProtocolError { impl fmt::Display for ProtocolError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - ProtocolError::IoError(e) => - write!(fmt, "I/O error: {}", e), - ProtocolError::InvalidMessage => - write!(fmt, "Received an invalid message."), - ProtocolError::InvalidProtocol => - write!(fmt, "A protocol (name) is invalid."), - ProtocolError::TooManyProtocols => - write!(fmt, "Too many protocols received.") + ProtocolError::IoError(e) => write!(fmt, "I/O error: {}", e), + ProtocolError::InvalidMessage => write!(fmt, "Received an invalid message."), + ProtocolError::InvalidProtocol => write!(fmt, "A protocol (name) is invalid."), + ProtocolError::TooManyProtocols => write!(fmt, "Too many protocols received."), } } } @@ -486,8 +509,8 @@ impl fmt::Display for ProtocolError { mod tests { use super::*; use quickcheck::*; - use rand::Rng; use rand::distributions::Alphanumeric; + use rand::Rng; use std::iter; impl Arbitrary for Protocol { @@ -509,7 +532,7 @@ mod tests { 2 => Message::ListProtocols, 3 => Message::Protocol(Protocol::arbitrary(g)), 4 => Message::Protocols(Vec::arbitrary(g)), - _ => panic!() + _ => panic!(), } } } @@ -518,10 +541,11 @@ mod tests { fn encode_decode_message() { fn prop(msg: Message) { let mut buf = BytesMut::new(); - msg.encode(&mut buf).expect(&format!("Encoding message failed: {:?}", msg)); + msg.encode(&mut buf) + .expect(&format!("Encoding message failed: {:?}", msg)); match Message::decode(buf.freeze()) { 
Ok(m) => assert_eq!(m, msg), - Err(e) => panic!("Decoding failed: {:?}", e) + Err(e) => panic!("Decoding failed: {:?}", e), } } quickcheck(prop as fn(_)) diff --git a/misc/multistream-select/src/tests.rs b/misc/multistream-select/src/tests.rs index 5bbbde1be0e..763301943d7 100644 --- a/misc/multistream-select/src/tests.rs +++ b/misc/multistream-select/src/tests.rs @@ -22,9 +22,9 @@ #![cfg(test)] -use crate::{Version, NegotiationError}; use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial}; use crate::{dialer_select_proto, listener_select_proto}; +use crate::{NegotiationError, Version}; use async_std::net::{TcpListener, TcpStream}; use futures::prelude::*; @@ -53,7 +53,8 @@ fn select_proto_basic() { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let protos = vec![b"/proto3", b"/proto2"]; let (proto, mut io, _) = dialer_select_proto(connec, protos.into_iter(), version) - .await.unwrap(); + .await + .unwrap(); assert_eq!(proto, b"/proto2"); io.write_all(b"ping").await.unwrap(); @@ -77,12 +78,14 @@ fn select_proto_basic() { fn negotiation_failed() { let _ = env_logger::try_init(); - async fn run(Test { - version, - listen_protos, - dial_protos, - dial_payload - }: Test) { + async fn run( + Test { + version, + listen_protos, + dial_protos, + dial_payload, + }: Test, + ) { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener_addr = listener.local_addr().unwrap(); @@ -91,10 +94,12 @@ fn negotiation_failed() { let io = match listener_select_proto(connec, listen_protos).await { Ok((_, io)) => io, Err(NegotiationError::Failed) => return, - Err(NegotiationError::ProtocolError(e)) => panic!("Unexpected protocol error {}", e), + Err(NegotiationError::ProtocolError(e)) => { + panic!("Unexpected protocol error {}", e) + } }; match io.complete().await { - Err(NegotiationError::Failed) => {}, + Err(NegotiationError::Failed) => {} _ => panic!(), } }); @@ -104,14 +109,14 @@ fn negotiation_failed() { let mut io = match dialer_select_proto(connec, dial_protos.into_iter(), version).await { Err(NegotiationError::Failed) => return, Ok((_, io, _)) => io, - Err(_) => panic!() + Err(_) => panic!(), }; // The dialer may write a payload that is even sent before it // got confirmation of the last proposed protocol, when `V1Lazy` // is used. io.write_all(&dial_payload).await.unwrap(); match io.complete().await { - Err(NegotiationError::Failed) => {}, + Err(NegotiationError::Failed) => {} _ => panic!(), } }); @@ -133,10 +138,10 @@ fn negotiation_failed() { // // The choices here cover the main distinction between a single // and multiple protocols. - let protos = vec!{ + let protos = vec![ (vec!["/proto1"], vec!["/proto2"]), (vec!["/proto1", "/proto2"], vec!["/proto3", "/proto4"]), - }; + ]; // The payloads that the dialer sends after "successful" negotiation, // which may be sent even before the dialer got protocol confirmation @@ -145,7 +150,7 @@ fn negotiation_failed() { // The choices here cover the specific situations that can arise with // `V1Lazy` and which must nevertheless behave identically to `V1` w.r.t. // the outcome of the negotiation. - let payloads = vec!{ + let payloads = vec![ // No payload, in which case all versions should behave identically // in any case, i.e. the baseline test. vec![], @@ -153,13 +158,13 @@ fn negotiation_failed() { // `1` as a message length and encounters an invalid message (the // second `1`). The listener is nevertheless expected to fail // negotiation normally, just like with `V1`. 
- vec![1,1], + vec![1, 1], // With this payload and `V1Lazy`, the listener interprets the first // `42` as a message length and encounters unexpected EOF trying to // read a message of that length. The listener is nevertheless expected // to fail negotiation normally, just like with `V1` - vec![42,1], - }; + vec![42, 1], + ]; for (listen_protos, dial_protos) in protos { for dial_payload in payloads.clone() { @@ -193,7 +198,8 @@ fn select_proto_parallel() { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let protos = vec![b"/proto3", b"/proto2"]; let (proto, io, _) = dialer_select_proto_parallel(connec, protos.into_iter(), version) - .await.unwrap(); + .await + .unwrap(); assert_eq!(proto, b"/proto2"); io.complete().await.unwrap(); }); @@ -224,7 +230,8 @@ fn select_proto_serial() { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let protos = vec![b"/proto3", b"/proto2"]; let (proto, io, _) = dialer_select_proto_serial(connec, protos.into_iter(), version) - .await.unwrap(); + .await + .unwrap(); assert_eq!(proto, b"/proto2"); io.complete().await.unwrap(); }); @@ -247,7 +254,9 @@ fn simultaneous_open() { let server = async move { let connec = listener.accept().await.unwrap().0; let protos = vec![b"/proto1", b"/proto2"]; - let (proto, io, _) = dialer_select_proto_serial(connec, protos, version).await.unwrap(); + let (proto, io, _) = dialer_select_proto_serial(connec, protos, version) + .await + .unwrap(); assert_eq!(proto, b"/proto2"); io.complete().await.unwrap(); }; @@ -256,7 +265,8 @@ fn simultaneous_open() { let connec = TcpStream::connect(&listener_addr).await.unwrap(); let protos = vec![b"/proto3", b"/proto2"]; let (proto, io, _) = dialer_select_proto_serial(connec, protos.into_iter(), version) - .await.unwrap(); + .await + .unwrap(); assert_eq!(proto, b"/proto2"); io.complete().await.unwrap(); }; diff --git a/misc/multistream-select/tests/transport.rs b/misc/multistream-select/tests/transport.rs index fd632b6cc9f..38661a2f33c 100644 --- a/misc/multistream-select/tests/transport.rs +++ b/misc/multistream-select/tests/transport.rs @@ -18,24 +18,23 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
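Beyond the existing header, `ls`, `na` and protocol messages, the extension adds three message kinds to `Message::decode` above: the `/libp2p/simultaneous-connect` proposal, a tie-breaking select message carrying a decimal nonce, and the final `initiator`/`responder` announcements. The sketch below shows the shape of the nonce parsing only; `parse_select_nonce` is an illustrative helper, and the prefix argument stands in for the `MSG_SELECT` constant defined earlier in `protocol.rs`, whose literal value is not shown in this excerpt.

// Shape of the tie-breaker parsing in `Message::decode`: the message must
// start with the select prefix, end with a newline, and carry a decimal
// u64 nonce in between.
fn parse_select_nonce(msg: &[u8], prefix: &[u8]) -> Option<u64> {
    if msg.len() > prefix.len() + 1 && msg.starts_with(prefix) && msg.ends_with(b"\n") {
        std::str::from_utf8(&msg[prefix.len()..msg.len() - 1])
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
    } else {
        None
    }
}

#[test]
fn nonce_round_trip() {
    // "select: " is a placeholder prefix for this sketch only; the real
    // constant is MSG_SELECT, defined earlier in protocol.rs.
    assert_eq!(parse_select_nonce(b"select: 42\n", b"select: "), Some(42));
    assert_eq!(parse_select_nonce(b"initiator\n", b"select: "), None);
}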
+use futures::{channel::oneshot, prelude::*, ready}; use libp2p_core::{ connection::{ConnectionHandler, ConnectionHandlerEvent, Substream, SubstreamEndpoint}, identity, - muxing::StreamMuxerBox, - upgrade, multiaddr::Protocol, - Multiaddr, - Network, - network::{NetworkEvent, NetworkConfig}, - PeerId, - Transport, - transport::{self, MemoryTransport} + muxing::StreamMuxerBox, + network::{NetworkConfig, NetworkEvent}, + transport::{self, MemoryTransport}, + upgrade, Multiaddr, Network, PeerId, Transport, }; use libp2p_mplex::MplexConfig; use libp2p_plaintext::PlainText2Config; -use futures::{channel::oneshot, ready, prelude::*}; use rand::random; -use std::{io, task::{Context, Poll}}; +use std::{ + io, + task::{Context, Poll}, +}; type TestTransport = transport::Boxed<(PeerId, StreamMuxerBox)>; type TestNetwork = Network; @@ -44,11 +43,16 @@ type TestNetwork = Network; fn mk_transport(_up: upgrade::Version) -> (PeerId, TestTransport) { let keys = identity::Keypair::generate_ed25519(); let id = keys.public().to_peer_id(); - (id, MemoryTransport::default() - .upgrade() - .authenticate(PlainText2Config { local_public_key: keys.public() }) - .multiplex(MplexConfig::default()) - .boxed()) + ( + id, + MemoryTransport::default() + .upgrade() + .authenticate(PlainText2Config { + local_public_key: keys.public(), + }) + .multiplex(MplexConfig::default()) + .boxed(), + ) } /// Tests the transport upgrade process with all supported @@ -64,7 +68,8 @@ fn transport_upgrade() { let listen_addr = Multiaddr::from(Protocol::Memory(random::())); let mut dialer = TestNetwork::new(dialer_transport, dialer_id, NetworkConfig::default()); - let mut listener = TestNetwork::new(listener_transport, listener_id, NetworkConfig::default()); + let mut listener = + TestNetwork::new(listener_transport, listener_id, NetworkConfig::default()); listener.listen_on(listen_addr).unwrap(); let (addr_sender, addr_receiver) = oneshot::channel(); @@ -72,33 +77,26 @@ fn transport_upgrade() { let client = async move { let addr = addr_receiver.await.unwrap(); dialer.dial(&addr, TestHandler()).unwrap(); - futures::future::poll_fn(move |cx| { - loop { - match ready!(dialer.poll(cx)) { - NetworkEvent::ConnectionEstablished { .. } => { - return Poll::Ready(()) - } - _ => {} - } + futures::future::poll_fn(move |cx| loop { + match ready!(dialer.poll(cx)) { + NetworkEvent::ConnectionEstablished { .. } => return Poll::Ready(()), + _ => {} } - }).await + }) + .await }; let mut addr_sender = Some(addr_sender); - let server = futures::future::poll_fn(move |cx| { - loop { - match ready!(listener.poll(cx)) { - NetworkEvent::NewListenerAddress { listen_addr, .. } => { - addr_sender.take().unwrap().send(listen_addr).unwrap(); - } - NetworkEvent::IncomingConnection { connection, .. } => { - listener.accept(connection, TestHandler()).unwrap(); - } - NetworkEvent::ConnectionEstablished { .. } => { - return Poll::Ready(()) - } - _ => {} + let server = futures::future::poll_fn(move |cx| loop { + match ready!(listener.poll(cx)) { + NetworkEvent::NewListenerAddress { listen_addr, .. } => { + addr_sender.take().unwrap().send(listen_addr).unwrap(); + } + NetworkEvent::IncomingConnection { connection, .. } => { + listener.accept(connection, TestHandler()).unwrap(); } + NetworkEvent::ConnectionEstablished { .. 
} => return Poll::Ready(()), + _ => {} } }); @@ -118,17 +116,21 @@ impl ConnectionHandler for TestHandler { type Substream = Substream; type OutboundOpenInfo = (); - fn inject_substream(&mut self, _: Self::Substream, _: SubstreamEndpoint) - {} + fn inject_substream( + &mut self, + _: Self::Substream, + _: SubstreamEndpoint, + ) { + } - fn inject_event(&mut self, _: Self::InEvent) - {} + fn inject_event(&mut self, _: Self::InEvent) {} - fn inject_address_change(&mut self, _: &Multiaddr) - {} + fn inject_address_change(&mut self, _: &Multiaddr) {} - fn poll(&mut self, _: &mut Context<'_>) - -> Poll, Self::Error>> + fn poll( + &mut self, + _: &mut Context<'_>, + ) -> Poll, Self::Error>> { Poll::Pending } diff --git a/misc/peer-id-generator/src/main.rs b/misc/peer-id-generator/src/main.rs index 6ac7af7e358..45239317396 100644 --- a/misc/peer-id-generator/src/main.rs +++ b/misc/peer-id-generator/src/main.rs @@ -26,22 +26,26 @@ fn main() { // bytes 0x1220, meaning that only some characters are valid. const ALLOWED_FIRST_BYTE: &'static [u8] = b"NPQRSTUVWXYZ"; - let prefix = - match env::args().nth(1) { - Some(prefix) => prefix, - None => { - println!( + let prefix = match env::args().nth(1) { + Some(prefix) => prefix, + None => { + println!( "Usage: {} \n\n\ Generates a peer id that starts with the chosen prefix using a secp256k1 public \ key.\n\n\ Prefix must be a sequence of characters in the base58 \ alphabet, and must start with one of the following: {}", - env::current_exe().unwrap().file_name().unwrap().to_str().unwrap(), + env::current_exe() + .unwrap() + .file_name() + .unwrap() + .to_str() + .unwrap(), str::from_utf8(ALLOWED_FIRST_BYTE).unwrap() ); - return; - } - }; + return; + } + }; // The base58 alphabet is not necessarily obvious. const ALPHABET: &'static [u8] = b"123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"; diff --git a/muxers/mplex/benches/split_send_size.rs b/muxers/mplex/benches/split_send_size.rs index 7613fa8e918..66592c4febf 100644 --- a/muxers/mplex/benches/split_send_size.rs +++ b/muxers/mplex/benches/split_send_size.rs @@ -24,9 +24,12 @@ use async_std::task; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; use futures::channel::oneshot; -use futures::prelude::*; use futures::future::poll_fn; -use libp2p_core::{PeerId, Transport, StreamMuxer, identity, upgrade, transport, muxing, multiaddr::multiaddr, Multiaddr}; +use futures::prelude::*; +use libp2p_core::{ + identity, multiaddr::multiaddr, muxing, transport, upgrade, Multiaddr, PeerId, StreamMuxer, + Transport, +}; use libp2p_mplex as mplex; use libp2p_plaintext::PlainText2Config; use std::time::Duration; @@ -51,14 +54,13 @@ fn prepare(c: &mut Criterion) { let payload: Vec = vec![1; 1024 * 1024 * 1]; let mut tcp = c.benchmark_group("tcp"); - let tcp_addr = multiaddr![Ip4(std::net::Ipv4Addr::new(127,0,0,1)), Tcp(0u16)]; + let tcp_addr = multiaddr![Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1)), Tcp(0u16)]; for &size in BENCH_SIZES.iter() { tcp.throughput(Throughput::Bytes(payload.len() as u64)); let trans = tcp_transport(size); - tcp.bench_function( - format!("{}", size), - |b| b.iter(|| run(black_box(&trans), black_box(&payload), black_box(&tcp_addr))) - ); + tcp.bench_function(format!("{}", size), |b| { + b.iter(|| run(black_box(&trans), black_box(&payload), black_box(&tcp_addr))) + }); } tcp.finish(); @@ -67,15 +69,13 @@ fn prepare(c: &mut Criterion) { for &size in BENCH_SIZES.iter() { mem.throughput(Throughput::Bytes(payload.len() as u64)); let trans = mem_transport(size); 
- mem.bench_function( - format!("{}", size), - |b| b.iter(|| run(black_box(&trans), black_box(&payload), black_box(&mem_addr))) - ); + mem.bench_function(format!("{}", size), |b| { + b.iter(|| run(black_box(&trans), black_box(&payload), black_box(&mem_addr))) + }); } mem.finish(); } - /// Transfers the given payload between two nodes using the given transport. fn run(transport: &BenchTransport, payload: &Vec, listen_addr: &Multiaddr) { let mut listener = transport.clone().listen_on(listen_addr.clone()).unwrap(); @@ -101,18 +101,20 @@ fn run(transport: &BenchTransport, payload: &Vec, listen_addr: &Multiaddr) { let end = off + std::cmp::min(buf.len() - off, 8 * 1024); let n = poll_fn(|cx| { conn.read_substream(cx, &mut s, &mut buf[off..end]) - }).await.unwrap(); + }) + .await + .unwrap(); off += n; if off == buf.len() { - return + return; } } } Ok(_) => panic!("Unexpected muxer event"), - Err(e) => panic!("Unexpected error: {:?}", e) + Err(e) => panic!("Unexpected error: {:?}", e), } } - _ => panic!("Unexpected listener event") + _ => panic!("Unexpected listener event"), } } }); @@ -122,16 +124,20 @@ fn run(transport: &BenchTransport, payload: &Vec, listen_addr: &Multiaddr) { let addr = addr_receiver.await.unwrap(); let (_peer, conn) = transport.clone().dial(addr).unwrap().await.unwrap(); let mut handle = conn.open_outbound(); - let mut stream = poll_fn(|cx| conn.poll_outbound(cx, &mut handle)).await.unwrap(); + let mut stream = poll_fn(|cx| conn.poll_outbound(cx, &mut handle)) + .await + .unwrap(); let mut off = 0; loop { - let n = poll_fn(|cx| { - conn.write_substream(cx, &mut stream, &payload[off..]) - }).await.unwrap(); + let n = poll_fn(|cx| conn.write_substream(cx, &mut stream, &payload[off..])) + .await + .unwrap(); off += n; if off == payload.len() { - poll_fn(|cx| conn.flush_substream(cx, &mut stream)).await.unwrap(); - return + poll_fn(|cx| conn.flush_substream(cx, &mut stream)) + .await + .unwrap(); + return; } } }); @@ -147,7 +153,8 @@ fn tcp_transport(split_send_size: usize) -> BenchTransport { let mut mplex = mplex::MplexConfig::default(); mplex.set_split_send_size(split_send_size); - libp2p_tcp::TcpConfig::new().nodelay(true) + libp2p_tcp::TcpConfig::new() + .nodelay(true) .upgrade() .authenticate(PlainText2Config { local_public_key }) .multiplex(mplex) diff --git a/muxers/mplex/src/codec.rs b/muxers/mplex/src/codec.rs index f56bb146ad0..3867cd27d8d 100644 --- a/muxers/mplex/src/codec.rs +++ b/muxers/mplex/src/codec.rs @@ -18,10 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use bytes::{BufMut, Bytes, BytesMut}; use asynchronous_codec::{Decoder, Encoder}; +use bytes::{BufMut, Bytes, BytesMut}; use libp2p_core::Endpoint; -use std::{fmt, hash::{Hash, Hasher}, io, mem}; +use std::{ + fmt, + hash::{Hash, Hasher}, + io, mem, +}; use unsigned_varint::{codec, encode}; // Maximum size for a packet: 1MB as per the spec. @@ -82,18 +86,27 @@ pub struct RemoteStreamId { impl LocalStreamId { pub fn dialer(num: u64) -> Self { - Self { num, role: Endpoint::Dialer } + Self { + num, + role: Endpoint::Dialer, + } } #[cfg(test)] pub fn listener(num: u64) -> Self { - Self { num, role: Endpoint::Listener } + Self { + num, + role: Endpoint::Listener, + } } pub fn next(self) -> Self { Self { - num: self.num.checked_add(1).expect("Mplex substream ID overflowed"), - .. 
self + num: self + .num + .checked_add(1) + .expect("Mplex substream ID overflowed"), + ..self } } @@ -108,11 +121,17 @@ impl LocalStreamId { impl RemoteStreamId { fn dialer(num: u64) -> Self { - Self { num, role: Endpoint::Dialer } + Self { + num, + role: Endpoint::Dialer, + } } fn listener(num: u64) -> Self { - Self { num, role: Endpoint::Listener } + Self { + num, + role: Endpoint::Listener, + } } /// Converts this `RemoteStreamId` into the corresponding `LocalStreamId` @@ -174,31 +193,28 @@ impl Decoder for Codec { fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { loop { match mem::replace(&mut self.decoder_state, CodecDecodeState::Poisoned) { - CodecDecodeState::Begin => { - match self.varint_decoder.decode(src)? { - Some(header) => { - self.decoder_state = CodecDecodeState::HasHeader(header); - }, - None => { - self.decoder_state = CodecDecodeState::Begin; - return Ok(None); - }, + CodecDecodeState::Begin => match self.varint_decoder.decode(src)? { + Some(header) => { + self.decoder_state = CodecDecodeState::HasHeader(header); + } + None => { + self.decoder_state = CodecDecodeState::Begin; + return Ok(None); } }, - CodecDecodeState::HasHeader(header) => { - match self.varint_decoder.decode(src)? { - Some(len) => { - if len as usize > MAX_FRAME_SIZE { - let msg = format!("Mplex frame length {} exceeds maximum", len); - return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); - } - - self.decoder_state = CodecDecodeState::HasHeaderAndLen(header, len as usize); - }, - None => { - self.decoder_state = CodecDecodeState::HasHeader(header); - return Ok(None); - }, + CodecDecodeState::HasHeader(header) => match self.varint_decoder.decode(src)? { + Some(len) => { + if len as usize > MAX_FRAME_SIZE { + let msg = format!("Mplex frame length {} exceeds maximum", len); + return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); + } + + self.decoder_state = + CodecDecodeState::HasHeaderAndLen(header, len as usize); + } + None => { + self.decoder_state = CodecDecodeState::HasHeader(header); + return Ok(None); } }, CodecDecodeState::HasHeaderAndLen(header, len) => { @@ -212,25 +228,44 @@ impl Decoder for Codec { let buf = src.split_to(len); let num = (header >> 3) as u64; let out = match header & 7 { - 0 => Frame::Open { stream_id: RemoteStreamId::dialer(num) }, - 1 => Frame::Data { stream_id: RemoteStreamId::listener(num), data: buf.freeze() }, - 2 => Frame::Data { stream_id: RemoteStreamId::dialer(num), data: buf.freeze() }, - 3 => Frame::Close { stream_id: RemoteStreamId::listener(num) }, - 4 => Frame::Close { stream_id: RemoteStreamId::dialer(num) }, - 5 => Frame::Reset { stream_id: RemoteStreamId::listener(num) }, - 6 => Frame::Reset { stream_id: RemoteStreamId::dialer(num) }, + 0 => Frame::Open { + stream_id: RemoteStreamId::dialer(num), + }, + 1 => Frame::Data { + stream_id: RemoteStreamId::listener(num), + data: buf.freeze(), + }, + 2 => Frame::Data { + stream_id: RemoteStreamId::dialer(num), + data: buf.freeze(), + }, + 3 => Frame::Close { + stream_id: RemoteStreamId::listener(num), + }, + 4 => Frame::Close { + stream_id: RemoteStreamId::dialer(num), + }, + 5 => Frame::Reset { + stream_id: RemoteStreamId::listener(num), + }, + 6 => Frame::Reset { + stream_id: RemoteStreamId::dialer(num), + }, _ => { let msg = format!("Invalid mplex header value 0x{:x}", header); return Err(io::Error::new(io::ErrorKind::InvalidData, msg)); - }, + } }; self.decoder_state = CodecDecodeState::Begin; return Ok(Some(out)); - }, + } CodecDecodeState::Poisoned => { - return 
Err(io::Error::new(io::ErrorKind::InvalidData, "Mplex codec poisoned")); + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Mplex codec poisoned", + )); } } } @@ -243,27 +278,51 @@ impl Encoder for Codec { fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { let (header, data) = match item { - Frame::Open { stream_id } => { - (stream_id.num << 3, Bytes::new()) - }, - Frame::Data { stream_id: LocalStreamId { num, role: Endpoint::Listener }, data } => { - (num << 3 | 1, data) - }, - Frame::Data { stream_id: LocalStreamId { num, role: Endpoint::Dialer }, data } => { - (num << 3 | 2, data) - }, - Frame::Close { stream_id: LocalStreamId { num, role: Endpoint::Listener } } => { - (num << 3 | 3, Bytes::new()) - }, - Frame::Close { stream_id: LocalStreamId { num, role: Endpoint::Dialer } } => { - (num << 3 | 4, Bytes::new()) - }, - Frame::Reset { stream_id: LocalStreamId { num, role: Endpoint::Listener } } => { - (num << 3 | 5, Bytes::new()) - }, - Frame::Reset { stream_id: LocalStreamId { num, role: Endpoint::Dialer } } => { - (num << 3 | 6, Bytes::new()) - }, + Frame::Open { stream_id } => (stream_id.num << 3, Bytes::new()), + Frame::Data { + stream_id: + LocalStreamId { + num, + role: Endpoint::Listener, + }, + data, + } => (num << 3 | 1, data), + Frame::Data { + stream_id: + LocalStreamId { + num, + role: Endpoint::Dialer, + }, + data, + } => (num << 3 | 2, data), + Frame::Close { + stream_id: + LocalStreamId { + num, + role: Endpoint::Listener, + }, + } => (num << 3 | 3, Bytes::new()), + Frame::Close { + stream_id: + LocalStreamId { + num, + role: Endpoint::Dialer, + }, + } => (num << 3 | 4, Bytes::new()), + Frame::Reset { + stream_id: + LocalStreamId { + num, + role: Endpoint::Listener, + }, + } => (num << 3 | 5, Bytes::new()), + Frame::Reset { + stream_id: + LocalStreamId { + num, + role: Endpoint::Dialer, + }, + } => (num << 3 | 6, Bytes::new()), }; let mut header_buf = encode::u64_buffer(); @@ -274,7 +333,10 @@ impl Encoder for Codec { let data_len_bytes = encode::usize(data_len, &mut data_buf); if data_len > MAX_FRAME_SIZE { - return Err(io::Error::new(io::ErrorKind::InvalidData, "data size exceed maximum")); + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "data size exceed maximum", + )); } dst.reserve(header_bytes.len() + data_len_bytes.len() + data_len); @@ -294,15 +356,21 @@ mod tests { let mut enc = Codec::new(); let role = Endpoint::Dialer; let data = Bytes::from(&[123u8; MAX_FRAME_SIZE + 1][..]); - let bad_msg = Frame::Data { stream_id: LocalStreamId { num: 123, role }, data }; + let bad_msg = Frame::Data { + stream_id: LocalStreamId { num: 123, role }, + data, + }; let mut out = BytesMut::new(); match enc.encode(bad_msg, &mut out) { Err(e) => assert_eq!(e.to_string(), "data size exceed maximum"), - _ => panic!("Can't send a message bigger than MAX_FRAME_SIZE") + _ => panic!("Can't send a message bigger than MAX_FRAME_SIZE"), } let data = Bytes::from(&[123u8; MAX_FRAME_SIZE][..]); - let ok_msg = Frame::Data { stream_id: LocalStreamId { num: 123, role }, data }; + let ok_msg = Frame::Data { + stream_id: LocalStreamId { num: 123, role }, + data, + }; assert!(enc.encode(ok_msg, &mut out).is_ok()); } @@ -311,19 +379,24 @@ mod tests { // Create new codec object for encoding and decoding our frame. let mut codec = Codec::new(); // Create a u64 stream ID. 
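        // (A value above `u32::MAX` exercises the unsigned-varint encoding of
        // headers whose stream number needs more than 32 bits.)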
- let id: u64 = u32::MAX as u64 + 1 ; - let stream_id = LocalStreamId { num: id, role: Endpoint::Dialer }; + let id: u64 = u32::MAX as u64 + 1; + let stream_id = LocalStreamId { + num: id, + role: Endpoint::Dialer, + }; // Open a new frame with that stream ID. let original_frame = Frame::Open { stream_id }; // Encode that frame. let mut enc_frame = BytesMut::new(); - codec.encode(original_frame, &mut enc_frame) + codec + .encode(original_frame, &mut enc_frame) .expect("Encoding to succeed."); // Decode encoded frame and extract stream ID. - let dec_string_id = codec.decode(&mut enc_frame) + let dec_string_id = codec + .decode(&mut enc_frame) .expect("Decoding to succeed.") .map(|f| f.remote_id()) .unwrap(); diff --git a/muxers/mplex/src/io.rs b/muxers/mplex/src/io.rs index e4e49935b8d..80da197a965 100644 --- a/muxers/mplex/src/io.rs +++ b/muxers/mplex/src/io.rs @@ -18,20 +18,24 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use bytes::Bytes; -use crate::{MplexConfig, MaxBufferBehaviour}; use crate::codec::{Codec, Frame, LocalStreamId, RemoteStreamId}; -use log::{debug, trace}; -use futures::{prelude::*, ready, stream::Fuse}; -use futures::task::{AtomicWaker, ArcWake, waker_ref, WakerRef}; +use crate::{MaxBufferBehaviour, MplexConfig}; use asynchronous_codec::Framed; +use bytes::Bytes; +use futures::task::{waker_ref, ArcWake, AtomicWaker, WakerRef}; +use futures::{prelude::*, ready, stream::Fuse}; +use log::{debug, trace}; use nohash_hasher::{IntMap, IntSet}; use parking_lot::Mutex; use smallvec::SmallVec; use std::collections::VecDeque; -use std::{cmp, fmt, io, mem, sync::Arc, task::{Context, Poll, Waker}}; +use std::{ + cmp, fmt, io, mem, + sync::Arc, + task::{Context, Poll, Waker}, +}; -pub use std::io::{Result, Error, ErrorKind}; +pub use std::io::{Error, ErrorKind, Result}; /// A connection identifier. /// @@ -109,7 +113,7 @@ enum Status { impl Multiplexed where - C: AsyncRead + AsyncWrite + Unpin + C: AsyncRead + AsyncWrite + Unpin, { /// Creates a new multiplexed I/O stream. pub fn new(io: C, config: MplexConfig) -> Self { @@ -134,8 +138,8 @@ where pending: Mutex::new(Default::default()), }), notifier_open: NotifierOpen { - pending: Default::default() - } + pending: Default::default(), + }, } } @@ -223,14 +227,14 @@ where // from the respective substreams. if num_buffered == self.config.max_buffer_len { cx.waker().clone().wake(); - return Poll::Pending + return Poll::Pending; } // Wait for the next inbound `Open` frame. match ready!(self.poll_read_frame(cx, None))? { Frame::Open { stream_id } => { if let Some(id) = self.on_open(stream_id)? { - return Poll::Ready(Ok(id)) + return Poll::Ready(Ok(id)); } } Frame::Data { stream_id, data } => { @@ -240,9 +244,7 @@ where Frame::Close { stream_id } => { self.on_close(stream_id.into_local()); } - Frame::Reset { stream_id } => { - self.on_reset(stream_id.into_local()) - } + Frame::Reset { stream_id } => self.on_reset(stream_id.into_local()), } } } @@ -253,10 +255,12 @@ where // Check the stream limits. if self.substreams.len() >= self.config.max_substreams { - debug!("{}: Maximum number of substreams reached ({})", - self.id, self.config.max_substreams); + debug!( + "{}: Maximum number of substreams reached ({})", + self.id, self.config.max_substreams + ); self.notifier_open.register(cx.waker()); - return Poll::Pending + return Poll::Pending; } // Send the `Open` frame. 
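// Aside (not part of the patch): the codec hunks above pin down the mplex frame
// layout. Every frame starts with a varint header that packs the stream number
// and the frame type as `header = num << 3 | flag`, where flag 0 = Open,
// 1/2 = Data, 3/4 = Close and 5/6 = Reset, the two variants of each pair
// distinguishing which side opened the stream (see the `Encoder`/`Decoder`
// match arms above). The `Open` frame sent here therefore has a header of just
// `num << 3`. Minimal stand-alone sketch of that packing:

fn main() {
    let num: u64 = 5;
    let open_header = num << 3; // `Open` frame for stream 5 (flag 0)
    let data_header = num << 3 | 2; // `Data` frame on a dialer-opened stream
    assert_eq!(open_header, 40);
    assert_eq!(data_header >> 3, 5); // stream number recovered on decode
    assert_eq!(data_header & 7, 2); // frame type recovered on decode
}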
@@ -267,11 +271,18 @@ where let frame = Frame::Open { stream_id }; match self.io.start_send_unpin(frame) { Ok(()) => { - self.substreams.insert(stream_id, SubstreamState::Open { - buf: Default::default() - }); - debug!("{}: New outbound substream: {} (total {})", - self.id, stream_id, self.substreams.len()); + self.substreams.insert( + stream_id, + SubstreamState::Open { + buf: Default::default(), + }, + ); + debug!( + "{}: New outbound substream: {} (total {})", + self.id, + stream_id, + self.substreams.len() + ); // The flush is delayed and the `Open` frame may be sent // together with other frames in the same transport packet. self.pending_flush_open.insert(stream_id); @@ -279,8 +290,8 @@ where } Err(e) => Poll::Ready(self.on_error(e)), } - }, - Err(e) => Poll::Ready(self.on_error(e)) + } + Err(e) => Poll::Ready(self.on_error(e)), } } @@ -310,7 +321,7 @@ where // Check if the underlying stream is ok. match self.status { Status::Closed | Status::Err(_) => return, - Status::Open => {}, + Status::Open => {} } // If there is still a task waker interested in reading from that @@ -321,7 +332,7 @@ where // Remove the substream, scheduling pending frames as necessary. match self.substreams.remove(&id) { - None => {}, + None => {} Some(state) => { // If we fell below the substream limit, notify tasks that had // interest in opening an outbound substream earlier. @@ -336,17 +347,19 @@ where SubstreamState::Reset { .. } => {} SubstreamState::RecvClosed { .. } => { if self.check_max_pending_frames().is_err() { - return + return; } trace!("{}: Pending close for stream {}", self.id, id); - self.pending_frames.push_front(Frame::Close { stream_id: id }); + self.pending_frames + .push_front(Frame::Close { stream_id: id }); } SubstreamState::Open { .. } => { if self.check_max_pending_frames().is_err() { - return + return; } trace!("{}: Pending reset for stream {}", self.id, id); - self.pending_frames.push_front(Frame::Reset { stream_id: id }); + self.pending_frames + .push_front(Frame::Reset { stream_id: id }); } } } @@ -354,17 +367,22 @@ where } /// Writes data to a substream. - pub fn poll_write_stream(&mut self, cx: &mut Context<'_>, id: LocalStreamId, buf: &[u8]) - -> Poll> - { + pub fn poll_write_stream( + &mut self, + cx: &mut Context<'_>, + id: LocalStreamId, + buf: &[u8], + ) -> Poll> { self.guard_open()?; // Check if the stream is open for writing. match self.substreams.get(&id) { - None | Some(SubstreamState::Reset { .. }) => - return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())), - Some(SubstreamState::SendClosed { .. }) | Some(SubstreamState::Closed { .. }) => - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())), + None | Some(SubstreamState::Reset { .. }) => { + return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } + Some(SubstreamState::SendClosed { .. }) | Some(SubstreamState::Closed { .. }) => { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) + } Some(SubstreamState::Open { .. }) | Some(SubstreamState::RecvClosed { .. }) => { // Substream is writeable. Continue. } @@ -375,8 +393,11 @@ where // Send the data frame. ready!(self.poll_send_frame(cx, || { - let data = Bytes::copy_from_slice(&buf[.. frame_len]); - Frame::Data { stream_id: id, data } + let data = Bytes::copy_from_slice(&buf[..frame_len]); + Frame::Data { + stream_id: id, + data, + } }))?; Poll::Ready(Ok(frame_len)) @@ -396,9 +417,11 @@ where /// and under consideration of the number of already used substreams, /// thereby waking the task that last called `poll_next_stream`, if any. 
/// Inbound substreams received in excess of that limit are immediately reset. - pub fn poll_read_stream(&mut self, cx: &mut Context<'_>, id: LocalStreamId) - -> Poll>> - { + pub fn poll_read_stream( + &mut self, + cx: &mut Context<'_>, + id: LocalStreamId, + ) -> Poll>> { self.guard_open()?; // Try to read from the buffer first. @@ -411,7 +434,7 @@ where ArcWake::wake_by_ref(&self.notifier_read); } let data = buf.remove(0); - return Poll::Ready(Ok(Some(data))) + return Poll::Ready(Ok(Some(data))); } // If the stream buffer "spilled" onto the heap, free that memory. buf.shrink_to_fit(); @@ -426,7 +449,7 @@ where // a chance to read from the other substream(s). if num_buffered == self.config.max_buffer_len { cx.waker().clone().wake(); - return Poll::Pending + return Poll::Pending; } // Check if the targeted substream (if any) reached EOF. @@ -436,14 +459,14 @@ where // remote, as the `StreamMuxer::read_substream` contract only // permits errors on "terminal" conditions, e.g. if the connection // has been closed or on protocol misbehaviour. - return Poll::Ready(Ok(None)) + return Poll::Ready(Ok(None)); } // Read the next frame. match ready!(self.poll_read_frame(cx, Some(id)))? { Frame::Data { data, stream_id } if stream_id.into_local() == id => { return Poll::Ready(Ok(Some(data))) - }, + } Frame::Data { stream_id, data } => { // The data frame is for a different stream than the one // currently being polled, so it needs to be buffered and @@ -454,7 +477,12 @@ where frame @ Frame::Open { .. } => { if let Some(id) = self.on_open(frame.remote_id())? { self.open_buffer.push_front(id); - trace!("{}: Buffered new inbound stream {} (total: {})", self.id, id, self.open_buffer.len()); + trace!( + "{}: Buffered new inbound stream {} (total: {})", + self.id, + id, + self.open_buffer.len() + ); self.notifier_read.wake_next_stream(); } } @@ -462,14 +490,14 @@ where let stream_id = stream_id.into_local(); self.on_close(stream_id); if id == stream_id { - return Poll::Ready(Ok(None)) + return Poll::Ready(Ok(None)); } } Frame::Reset { stream_id } => { let stream_id = stream_id.into_local(); self.on_reset(stream_id); if id == stream_id { - return Poll::Ready(Ok(None)) + return Poll::Ready(Ok(None)); } } } @@ -481,9 +509,11 @@ where /// > **Note**: This is equivalent to `poll_flush()`, i.e. to flushing /// > all substreams, except that this operation returns an error if /// > the underlying I/O stream is already closed. - pub fn poll_flush_stream(&mut self, cx: &mut Context<'_>, id: LocalStreamId) - -> Poll> - { + pub fn poll_flush_stream( + &mut self, + cx: &mut Context<'_>, + id: LocalStreamId, + ) -> Poll> { self.guard_open()?; ready!(self.poll_flush(cx))?; @@ -495,15 +525,18 @@ where /// Closes a stream for writing. /// /// > **Note**: As opposed to `poll_close()`, a flush it not implied. 
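    ///
    /// The `Close` frame is only handed to the underlying sink here; callers
    /// that need it delivered promptly should follow up with
    /// `poll_flush_stream`.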
- pub fn poll_close_stream(&mut self, cx: &mut Context<'_>, id: LocalStreamId) - -> Poll> - { + pub fn poll_close_stream( + &mut self, + cx: &mut Context<'_>, + id: LocalStreamId, + ) -> Poll> { self.guard_open()?; match self.substreams.remove(&id) { None => Poll::Ready(Ok(())), Some(SubstreamState::SendClosed { buf }) => { - self.substreams.insert(id, SubstreamState::SendClosed { buf }); + self.substreams + .insert(id, SubstreamState::SendClosed { buf }); Poll::Ready(Ok(())) } Some(SubstreamState::Closed { buf }) => { @@ -515,18 +548,26 @@ where Poll::Ready(Ok(())) } Some(SubstreamState::Open { buf }) => { - if self.poll_send_frame(cx, || Frame::Close { stream_id: id })?.is_pending() { + if self + .poll_send_frame(cx, || Frame::Close { stream_id: id })? + .is_pending() + { self.substreams.insert(id, SubstreamState::Open { buf }); Poll::Pending } else { debug!("{}: Closed substream {} (half-close)", self.id, id); - self.substreams.insert(id, SubstreamState::SendClosed { buf }); + self.substreams + .insert(id, SubstreamState::SendClosed { buf }); Poll::Ready(Ok(())) } } Some(SubstreamState::RecvClosed { buf }) => { - if self.poll_send_frame(cx, || Frame::Close { stream_id: id })?.is_pending() { - self.substreams.insert(id, SubstreamState::RecvClosed { buf }); + if self + .poll_send_frame(cx, || Frame::Close { stream_id: id })? + .is_pending() + { + self.substreams + .insert(id, SubstreamState::RecvClosed { buf }); Poll::Pending } else { debug!("{}: Closed substream {}", self.id, id); @@ -541,10 +582,9 @@ where /// /// The frame is only constructed if the underlying sink is ready to /// send another frame. - fn poll_send_frame(&mut self, cx: &mut Context<'_>, frame: F) - -> Poll> + fn poll_send_frame(&mut self, cx: &mut Context<'_>, frame: F) -> Poll> where - F: FnOnce() -> Frame + F: FnOnce() -> Frame, { let waker = NotifierWrite::register(&self.notifier_write, cx.waker()); match ready!(self.io.poll_ready_unpin(&mut Context::from_waker(&waker))) { @@ -553,10 +593,10 @@ where trace!("{}: Sending {:?}", self.id, frame); match self.io.start_send_unpin(frame) { Ok(()) => Poll::Ready(Ok(())), - Err(e) => Poll::Ready(self.on_error(e)) + Err(e) => Poll::Ready(self.on_error(e)), } - }, - Err(e) => Poll::Ready(self.on_error(e)) + } + Err(e) => Poll::Ready(self.on_error(e)), } } @@ -566,12 +606,14 @@ where /// the current task is interested and wants to be woken up for, /// in case new frames can be read. `None` means interest in /// frames for any substream. - fn poll_read_frame(&mut self, cx: &mut Context<'_>, stream_id: Option) - -> Poll>> - { + fn poll_read_frame( + &mut self, + cx: &mut Context<'_>, + stream_id: Option, + ) -> Poll>> { // Try to send pending frames, if there are any, without blocking, if let Poll::Ready(Err(e)) = self.send_pending_frames(cx) { - return Poll::Ready(Err(e)) + return Poll::Ready(Err(e)); } // Perform any pending flush before reading. @@ -593,13 +635,19 @@ where if !self.notifier_read.wake_read_stream(*blocked_id) { // No task dedicated to the blocked stream woken, so schedule // this task again to have a chance at progress. - trace!("{}: No task to read from blocked stream. Waking current task.", self.id); + trace!( + "{}: No task to read from blocked stream. Waking current task.", + self.id + ); cx.waker().clone().wake(); } else if let Some(id) = stream_id { // We woke some other task, but are still interested in // reading `Data` frames from the current stream when unblocked. 
- debug_assert!(blocked_id != &id, "Unexpected attempt at reading a new \ - frame from a substream with a full buffer."); + debug_assert!( + blocked_id != &id, + "Unexpected attempt at reading a new \ + frame from a substream with a full buffer." + ); let _ = NotifierRead::register_read_stream(&self.notifier_read, cx.waker(), id); } else { // We woke some other task but are still interested in @@ -607,13 +655,13 @@ where let _ = NotifierRead::register_next_stream(&self.notifier_read, cx.waker()); } - return Poll::Pending + return Poll::Pending; } // Try to read another frame from the underlying I/O stream. let waker = match stream_id { Some(id) => NotifierRead::register_read_stream(&self.notifier_read, cx.waker(), id), - None => NotifierRead::register_next_stream(&self.notifier_read, cx.waker()) + None => NotifierRead::register_next_stream(&self.notifier_read, cx.waker()), }; match ready!(self.io.poll_next_unpin(&mut Context::from_waker(&waker))) { Some(Ok(frame)) => { @@ -621,7 +669,7 @@ where Poll::Ready(Ok(frame)) } Some(Err(e)) => Poll::Ready(self.on_error(e)), - None => Poll::Ready(self.on_error(io::ErrorKind::UnexpectedEof.into())) + None => Poll::Ready(self.on_error(io::ErrorKind::UnexpectedEof.into())), } } @@ -630,27 +678,41 @@ where let id = id.into_local(); if self.substreams.contains_key(&id) { - debug!("{}: Received unexpected `Open` frame for open substream {}", self.id, id); - return self.on_error(io::Error::new(io::ErrorKind::Other, - "Protocol error: Received `Open` frame for open substream.")) + debug!( + "{}: Received unexpected `Open` frame for open substream {}", + self.id, id + ); + return self.on_error(io::Error::new( + io::ErrorKind::Other, + "Protocol error: Received `Open` frame for open substream.", + )); } if self.substreams.len() >= self.config.max_substreams { - debug!("{}: Maximum number of substreams exceeded: {}", - self.id, self.config.max_substreams); + debug!( + "{}: Maximum number of substreams exceeded: {}", + self.id, self.config.max_substreams + ); self.check_max_pending_frames()?; debug!("{}: Pending reset for new stream {}", self.id, id); - self.pending_frames.push_front(Frame::Reset { - stream_id: id - }); - return Ok(None) + self.pending_frames + .push_front(Frame::Reset { stream_id: id }); + return Ok(None); } - self.substreams.insert(id, SubstreamState::Open { - buf: Default::default() - }); + self.substreams.insert( + id, + SubstreamState::Open { + buf: Default::default(), + }, + ); - debug!("{}: New inbound substream: {} (total {})", self.id, id, self.substreams.len()); + debug!( + "{}: New inbound substream: {} (total {})", + self.id, + id, + self.substreams.len() + ); Ok(Some(id)) } @@ -660,15 +722,22 @@ where if let Some(state) = self.substreams.remove(&id) { match state { SubstreamState::Closed { .. } => { - trace!("{}: Ignoring reset for mutually closed substream {}.", self.id, id); + trace!( + "{}: Ignoring reset for mutually closed substream {}.", + self.id, + id + ); } SubstreamState::Reset { .. 
} => { - trace!("{}: Ignoring redundant reset for already reset substream {}", - self.id, id); + trace!( + "{}: Ignoring redundant reset for already reset substream {}", + self.id, + id + ); } - SubstreamState::RecvClosed { buf } | - SubstreamState::SendClosed { buf } | - SubstreamState::Open { buf } => { + SubstreamState::RecvClosed { buf } + | SubstreamState::SendClosed { buf } + | SubstreamState::Open { buf } => { debug!("{}: Substream {} reset by remote.", self.id, id); self.substreams.insert(id, SubstreamState::Reset { buf }); // Notify tasks interested in reading from that stream, @@ -677,8 +746,11 @@ where } } } else { - trace!("{}: Ignoring `Reset` for unknown substream {}. Possibly dropped earlier.", - self.id, id); + trace!( + "{}: Ignoring `Reset` for unknown substream {}. Possibly dropped earlier.", + self.id, + id + ); } } @@ -687,33 +759,45 @@ where if let Some(state) = self.substreams.remove(&id) { match state { SubstreamState::RecvClosed { .. } | SubstreamState::Closed { .. } => { - debug!("{}: Ignoring `Close` frame for closed substream {}", - self.id, id); + debug!( + "{}: Ignoring `Close` frame for closed substream {}", + self.id, id + ); self.substreams.insert(id, state); - }, + } SubstreamState::Reset { buf } => { - debug!("{}: Ignoring `Close` frame for already reset substream {}", - self.id, id); + debug!( + "{}: Ignoring `Close` frame for already reset substream {}", + self.id, id + ); self.substreams.insert(id, SubstreamState::Reset { buf }); } SubstreamState::SendClosed { buf } => { - debug!("{}: Substream {} closed by remote (SendClosed -> Closed).", - self.id, id); + debug!( + "{}: Substream {} closed by remote (SendClosed -> Closed).", + self.id, id + ); self.substreams.insert(id, SubstreamState::Closed { buf }); // Notify tasks interested in reading, so they may read the EOF. self.notifier_read.wake_read_stream(id); - }, + } SubstreamState::Open { buf } => { - debug!("{}: Substream {} closed by remote (Open -> RecvClosed)", - self.id, id); - self.substreams.insert(id, SubstreamState::RecvClosed { buf }); + debug!( + "{}: Substream {} closed by remote (Open -> RecvClosed)", + self.id, id + ); + self.substreams + .insert(id, SubstreamState::RecvClosed { buf }); // Notify tasks interested in reading, so they may read the EOF. self.notifier_read.wake_read_stream(id); - }, + } } } else { - trace!("{}: Ignoring `Close` for unknown substream {}. Possibly dropped earlier.", - self.id, id); + trace!( + "{}: Ignoring `Close` for unknown substream {}. Possibly dropped earlier.", + self.id, + id + ); } } @@ -735,11 +819,9 @@ where /// Sends pending frames, without flushing. 
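    ///
    /// Frames are taken from the back of the pending queue; if the underlying
    /// sink is not ready for one of them, that frame is put back and
    /// `Poll::Pending` is returned, so sending resumes on the next wakeup.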
fn send_pending_frames(&mut self, cx: &mut Context<'_>) -> Poll> { while let Some(frame) = self.pending_frames.pop_back() { - if self.poll_send_frame(cx, || { - frame.clone() - })?.is_pending() { + if self.poll_send_frame(cx, || frame.clone())?.is_pending() { self.pending_frames.push_back(frame); - return Poll::Pending + return Poll::Pending; } } @@ -750,7 +832,7 @@ where fn on_error(&mut self, e: io::Error) -> io::Result { debug!("{}: Multiplexed connection failed: {:?}", self.id, e); self.status = Status::Err(io::Error::new(e.kind(), e.to_string())); - self.pending_frames = Default::default(); + self.pending_frames = Default::default(); self.substreams = Default::default(); self.open_buffer = Default::default(); Err(e) @@ -762,7 +844,7 @@ where match &self.status { Status::Closed => Err(io::Error::new(io::ErrorKind::Other, "Connection is closed")), Status::Err(e) => Err(io::Error::new(e.kind(), e.to_string())), - Status::Open => Ok(()) + Status::Open => Ok(()), } } @@ -770,8 +852,10 @@ where /// has not been reached. fn check_max_pending_frames(&mut self) -> io::Result<()> { if self.pending_frames.len() >= self.config.max_substreams + EXTRA_PENDING_FRAMES { - return self.on_error(io::Error::new(io::ErrorKind::Other, - "Too many pending frames.")); + return self.on_error(io::Error::new( + io::ErrorKind::Other, + "Too many pending frames.", + )); } Ok(()) } @@ -789,19 +873,35 @@ where let state = if let Some(state) = self.substreams.get_mut(&id) { state } else { - trace!("{}: Dropping data {:?} for unknown substream {}", self.id, data, id); - return Ok(()) + trace!( + "{}: Dropping data {:?} for unknown substream {}", + self.id, + data, + id + ); + return Ok(()); }; let buf = if let Some(buf) = state.recv_buf_open() { buf } else { - trace!("{}: Dropping data {:?} for closed or reset substream {}", self.id, data, id); - return Ok(()) + trace!( + "{}: Dropping data {:?} for closed or reset substream {}", + self.id, + data, + id + ); + return Ok(()); }; debug_assert!(buf.len() <= self.config.max_buffer_len); - trace!("{}: Buffering {:?} for stream {} (total: {})", self.id, data, id, buf.len() + 1); + trace!( + "{}: Buffering {:?} for stream {} (total: {})", + self.id, + data, + id, + buf.len() + 1 + ); buf.push(data); self.notifier_read.wake_read_stream(id); if buf.len() > self.config.max_buffer_len { @@ -812,9 +912,8 @@ where self.check_max_pending_frames()?; self.substreams.insert(id, SubstreamState::Reset { buf }); debug!("{}: Pending reset for stream {}", self.id, id); - self.pending_frames.push_front(Frame::Reset { - stream_id: id - }); + self.pending_frames + .push_front(Frame::Reset { stream_id: id }); } MaxBufferBehaviour::Block => { self.blocking_stream = Some(id); @@ -845,7 +944,7 @@ enum SubstreamState { Closed { buf: RecvBuf }, /// The stream has been reset by the local or remote peer but has /// not yet been dropped and may still have buffered frames to read. - Reset { buf: RecvBuf } + Reset { buf: RecvBuf }, } impl SubstreamState { @@ -889,9 +988,11 @@ impl NotifierRead { /// The returned waker should be passed to an I/O read operation /// that schedules a wakeup, if the operation is pending. 
#[must_use] - fn register_read_stream<'a>(self: &'a Arc, waker: &Waker, id: LocalStreamId) - -> WakerRef<'a> - { + fn register_read_stream<'a>( + self: &'a Arc, + waker: &Waker, + id: LocalStreamId, + ) -> WakerRef<'a> { let mut pending = self.read_stream.lock(); pending.insert(id, waker.clone()); waker_ref(self) @@ -914,7 +1015,7 @@ impl NotifierRead { if let Some(waker) = pending.remove(&id) { waker.wake(); - return true + return true; } false @@ -999,21 +1100,23 @@ const EXTRA_PENDING_FRAMES: usize = 1000; #[cfg(test)] mod tests { + use super::*; use async_std::task; + use asynchronous_codec::{Decoder, Encoder}; use bytes::BytesMut; use futures::prelude::*; - use asynchronous_codec::{Decoder, Encoder}; use quickcheck::*; use rand::prelude::*; use std::collections::HashSet; use std::num::NonZeroU8; use std::ops::DerefMut; use std::pin::Pin; - use super::*; impl Arbitrary for MaxBufferBehaviour { fn arbitrary(g: &mut G) -> MaxBufferBehaviour { - *[MaxBufferBehaviour::Block, MaxBufferBehaviour::ResetStream].choose(g).unwrap() + *[MaxBufferBehaviour::Block, MaxBufferBehaviour::ResetStream] + .choose(g) + .unwrap() } } @@ -1042,10 +1145,10 @@ mod tests { fn poll_read( mut self: Pin<&mut Self>, _: &mut Context<'_>, - buf: &mut [u8] + buf: &mut [u8], ) -> Poll> { if self.eof { - return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())) + return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())); } let n = std::cmp::min(buf.len(), self.r_buf.len()); let data = self.r_buf.split_to(n); @@ -1062,23 +1165,17 @@ mod tests { fn poll_write( mut self: Pin<&mut Self>, _: &mut Context<'_>, - buf: &[u8] + buf: &[u8], ) -> Poll> { self.w_buf.extend_from_slice(buf); Poll::Ready(Ok(buf.len())) } - fn poll_flush( - self: Pin<&mut Self>, - _: &mut Context<'_> - ) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn poll_close( - self: Pin<&mut Self>, - _: &mut Context<'_> - ) -> Poll> { + fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } @@ -1092,25 +1189,37 @@ mod tests { let mut codec = Codec::new(); // Open the maximum number of inbound streams. - for i in 0 .. cfg.max_substreams { + for i in 0..cfg.max_substreams { let stream_id = LocalStreamId::dialer(i as u64); - codec.encode(Frame::Open { stream_id }, &mut r_buf).unwrap(); + codec.encode(Frame::Open { stream_id }, &mut r_buf).unwrap(); } // Send more data on stream 0 than the buffer permits. let stream_id = LocalStreamId::dialer(0); let data = Bytes::from("Hello world"); - for _ in 0 .. cfg.max_buffer_len + overflow.get() as usize { - codec.encode(Frame::Data { stream_id, data: data.clone() }, &mut r_buf).unwrap(); + for _ in 0..cfg.max_buffer_len + overflow.get() as usize { + codec + .encode( + Frame::Data { + stream_id, + data: data.clone(), + }, + &mut r_buf, + ) + .unwrap(); } // Setup the multiplexed connection. - let conn = Connection { r_buf, w_buf: BytesMut::new(), eof: false }; + let conn = Connection { + r_buf, + w_buf: BytesMut::new(), + eof: false, + }; let mut m = Multiplexed::new(conn, cfg.clone()); task::block_on(future::poll_fn(move |cx| { // Receive all inbound streams. - for i in 0 .. cfg.max_substreams { + for i in 0..cfg.max_substreams { match m.poll_next_stream(cx) { Poll::Pending => panic!("Expected new inbound stream."), Poll::Ready(Err(e)) => panic!("{:?}", e), @@ -1161,7 +1270,7 @@ mod tests { } MaxBufferBehaviour::Block => { assert!(m.poll_next_stream(cx).is_pending()); - for i in 1 .. 
cfg.max_substreams { + for i in 1..cfg.max_substreams { let id = LocalStreamId::listener(i as u64); assert!(m.poll_read_stream(cx, id).is_pending()); } @@ -1169,12 +1278,12 @@ mod tests { } // Drain the buffer by reading from the stream. - for _ in 0 .. cfg.max_buffer_len + 1 { + for _ in 0..cfg.max_buffer_len + 1 { match m.poll_read_stream(cx, id) { Poll::Ready(Ok(Some(bytes))) => { assert_eq!(bytes, data); } - x => panic!("Unexpected: {:?}", x) + x => panic!("Unexpected: {:?}", x), } } @@ -1185,8 +1294,8 @@ mod tests { MaxBufferBehaviour::ResetStream => { // Expect to read EOF match m.poll_read_stream(cx, id) { - Poll::Ready(Ok(None)) => {}, - poll => panic!("Unexpected: {:?}", poll) + Poll::Ready(Ok(None)) => {} + poll => panic!("Unexpected: {:?}", poll), } } MaxBufferBehaviour::Block => { @@ -1194,7 +1303,7 @@ mod tests { match m.poll_read_stream(cx, id) { Poll::Ready(Ok(Some(bytes))) => assert_eq!(bytes, data), Poll::Pending => assert_eq!(overflow.get(), 1), - poll => panic!("Unexpected: {:?}", poll) + poll => panic!("Unexpected: {:?}", poll), } } } @@ -1203,7 +1312,7 @@ mod tests { })); } - quickcheck(prop as fn(_,_)) + quickcheck(prop as fn(_, _)) } #[test] @@ -1217,7 +1326,7 @@ mod tests { let conn = Connection { r_buf: BytesMut::new(), w_buf: BytesMut::new(), - eof: false + eof: false, }; let mut m = Multiplexed::new(conn, cfg.clone()); @@ -1225,7 +1334,7 @@ mod tests { let mut opened = HashSet::new(); task::block_on(future::poll_fn(move |cx| { // Open a number of streams. - for _ in 0 .. num_streams { + for _ in 0..num_streams { let id = ready!(m.poll_open_stream(cx)).unwrap(); assert!(opened.insert(id)); assert!(m.poll_read_stream(cx, id).is_pending()); @@ -1238,7 +1347,7 @@ mod tests { // should be closed due to the failed connection. assert!(opened.iter().all(|id| match m.poll_read_stream(cx, *id) { Poll::Ready(Err(e)) => e.kind() == io::ErrorKind::UnexpectedEof, - _ => false + _ => false, })); assert!(m.substreams.is_empty()); @@ -1247,6 +1356,6 @@ mod tests { })) } - quickcheck(prop as fn(_,_)) + quickcheck(prop as fn(_, _)) } } diff --git a/muxers/mplex/src/lib.rs b/muxers/mplex/src/lib.rs index 653c3310ab4..05e7571cf87 100644 --- a/muxers/mplex/src/lib.rs +++ b/muxers/mplex/src/lib.rs @@ -22,18 +22,18 @@ mod codec; mod config; mod io; -pub use config::{MplexConfig, MaxBufferBehaviour}; +pub use config::{MaxBufferBehaviour, MplexConfig}; -use codec::LocalStreamId; -use std::{cmp, iter, task::Context, task::Poll}; use bytes::Bytes; +use codec::LocalStreamId; +use futures::{future, prelude::*, ready}; use libp2p_core::{ - StreamMuxer, muxing::StreamMuxerEvent, upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}, + StreamMuxer, }; use parking_lot::Mutex; -use futures::{prelude::*, future, ready}; +use std::{cmp, iter, task::Context, task::Poll}; impl UpgradeInfo for MplexConfig { type Info = &'static [u8]; @@ -69,7 +69,7 @@ where fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future { future::ready(Ok(Multiplex { - io: Mutex::new(io::Multiplexed::new(socket, self)) + io: Mutex::new(io::Multiplexed::new(socket, self)), })) } } @@ -79,20 +79,21 @@ where /// This implementation isn't capable of detecting when the underlying socket changes its address, /// and no [`StreamMuxerEvent::AddressChange`] event is ever emitted. 
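///
/// Substream limits and buffering behaviour (see [`MaxBufferBehaviour`]) are
/// taken from the [`MplexConfig`] this `Multiplex` was created with.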
pub struct Multiplex { - io: Mutex> + io: Mutex>, } impl StreamMuxer for Multiplex where - C: AsyncRead + AsyncWrite + Unpin + C: AsyncRead + AsyncWrite + Unpin, { type Substream = Substream; type OutboundSubstream = OutboundSubstream; type Error = io::Error; - fn poll_event(&self, cx: &mut Context<'_>) - -> Poll>> - { + fn poll_event( + &self, + cx: &mut Context<'_>, + ) -> Poll>> { let stream_id = ready!(self.io.lock().poll_next_stream(cx))?; let stream = Substream::new(stream_id); Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(stream))) @@ -102,9 +103,11 @@ where OutboundSubstream {} } - fn poll_outbound(&self, cx: &mut Context<'_>, _: &mut Self::OutboundSubstream) - -> Poll> - { + fn poll_outbound( + &self, + cx: &mut Context<'_>, + _: &mut Self::OutboundSubstream, + ) -> Poll> { let stream_id = ready!(self.io.lock().poll_open_stream(cx))?; Poll::Ready(Ok(Substream::new(stream_id))) } @@ -113,9 +116,12 @@ where // Nothing to do, since `open_outbound` creates no new local state. } - fn read_substream(&self, cx: &mut Context<'_>, substream: &mut Self::Substream, buf: &mut [u8]) - -> Poll> - { + fn read_substream( + &self, + cx: &mut Context<'_>, + substream: &mut Self::Substream, + buf: &mut [u8], + ) -> Poll> { loop { // Try to read from the current (i.e. last received) frame. if !substream.current_data.is_empty() { @@ -126,27 +132,36 @@ where // Read the next data frame from the multiplexed stream. match ready!(self.io.lock().poll_read_stream(cx, substream.id))? { - Some(data) => { substream.current_data = data; } - None => { return Poll::Ready(Ok(0)) } + Some(data) => { + substream.current_data = data; + } + None => return Poll::Ready(Ok(0)), } } } - fn write_substream(&self, cx: &mut Context<'_>, substream: &mut Self::Substream, buf: &[u8]) - -> Poll> - { + fn write_substream( + &self, + cx: &mut Context<'_>, + substream: &mut Self::Substream, + buf: &[u8], + ) -> Poll> { self.io.lock().poll_write_stream(cx, substream.id, buf) } - fn flush_substream(&self, cx: &mut Context<'_>, substream: &mut Self::Substream) - -> Poll> - { + fn flush_substream( + &self, + cx: &mut Context<'_>, + substream: &mut Self::Substream, + ) -> Poll> { self.io.lock().poll_flush_stream(cx, substream.id) } - fn shutdown_substream(&self, cx: &mut Context<'_>, substream: &mut Self::Substream) - -> Poll> - { + fn shutdown_substream( + &self, + cx: &mut Context<'_>, + substream: &mut Self::Substream, + ) -> Poll> { self.io.lock().poll_close_stream(cx, substream.id) } @@ -176,6 +191,9 @@ pub struct Substream { impl Substream { fn new(id: LocalStreamId) -> Self { - Self { id, current_data: Bytes::new() } + Self { + id, + current_data: Bytes::new(), + } } } diff --git a/muxers/mplex/tests/async_write.rs b/muxers/mplex/tests/async_write.rs index 1414db14847..d4a1df7c4c5 100644 --- a/muxers/mplex/tests/async_write.rs +++ b/muxers/mplex/tests/async_write.rs @@ -18,9 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
+use futures::{channel::oneshot, prelude::*}; use libp2p_core::{muxing, upgrade, Transport}; use libp2p_tcp::TcpConfig; -use futures::{prelude::*, channel::oneshot}; use std::sync::Arc; #[test] @@ -32,14 +32,16 @@ fn async_write() { let bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let transport = TcpConfig::new().and_then(move |c, e| - upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let transport = TcpConfig::new() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let mut listener = transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener.next().await + let addr = listener + .next() + .await .expect("some event") .expect("no error") .into_new_address() @@ -48,12 +50,19 @@ fn async_write() { tx.send(addr).unwrap(); let client = listener - .next().await + .next() + .await .unwrap() .unwrap() - .into_upgrade().unwrap().0.await.unwrap(); + .into_upgrade() + .unwrap() + .0 + .await + .unwrap(); - let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)).await.unwrap(); + let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)) + .await + .unwrap(); let mut buf = Vec::new(); outbound.read_to_end(&mut buf).await.unwrap(); @@ -62,13 +71,16 @@ fn async_write() { async_std::task::block_on(async { let mplex = libp2p_mplex::MplexConfig::new(); - let transport = TcpConfig::new().and_then(move |c, e| - upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let transport = TcpConfig::new() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let client = Arc::new(transport.dial(rx.await.unwrap()).unwrap().await.unwrap()); let mut inbound = loop { - if let Some(s) = muxing::event_from_ref_and_wrap(client.clone()).await.unwrap() - .into_inbound_substream() { + if let Some(s) = muxing::event_from_ref_and_wrap(client.clone()) + .await + .unwrap() + .into_inbound_substream() + { break s; } }; diff --git a/muxers/mplex/tests/two_peers.rs b/muxers/mplex/tests/two_peers.rs index 54b939a548a..eb0526f4044 100644 --- a/muxers/mplex/tests/two_peers.rs +++ b/muxers/mplex/tests/two_peers.rs @@ -18,9 +18,9 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
+use futures::{channel::oneshot, prelude::*}; use libp2p_core::{muxing, upgrade, Transport}; use libp2p_tcp::TcpConfig; -use futures::{channel::oneshot, prelude::*}; use std::sync::Arc; #[test] @@ -32,14 +32,16 @@ fn client_to_server_outbound() { let bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let transport = TcpConfig::new().and_then(move |c, e| - upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let transport = TcpConfig::new() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let mut listener = transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener.next().await + let addr = listener + .next() + .await .expect("some event") .expect("no error") .into_new_address() @@ -48,12 +50,19 @@ fn client_to_server_outbound() { tx.send(addr).unwrap(); let client = listener - .next().await + .next() + .await .unwrap() .unwrap() - .into_upgrade().unwrap().0.await.unwrap(); + .into_upgrade() + .unwrap() + .0 + .await + .unwrap(); - let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)).await.unwrap(); + let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)) + .await + .unwrap(); let mut buf = Vec::new(); outbound.read_to_end(&mut buf).await.unwrap(); @@ -62,13 +71,16 @@ fn client_to_server_outbound() { async_std::task::block_on(async { let mplex = libp2p_mplex::MplexConfig::new(); - let transport = TcpConfig::new().and_then(move |c, e| - upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let transport = TcpConfig::new() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let client = Arc::new(transport.dial(rx.await.unwrap()).unwrap().await.unwrap()); let mut inbound = loop { - if let Some(s) = muxing::event_from_ref_and_wrap(client.clone()).await.unwrap() - .into_inbound_substream() { + if let Some(s) = muxing::event_from_ref_and_wrap(client.clone()) + .await + .unwrap() + .into_inbound_substream() + { break s; } }; @@ -88,14 +100,16 @@ fn client_to_server_inbound() { let bg_thread = async_std::task::spawn(async move { let mplex = libp2p_mplex::MplexConfig::new(); - let transport = TcpConfig::new().and_then(move |c, e| - upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let transport = TcpConfig::new() + .and_then(move |c, e| upgrade::apply(c, mplex, e, upgrade::Version::V1)); let mut listener = transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener.next().await + let addr = listener + .next() + .await .expect("some event") .expect("no error") .into_new_address() @@ -103,15 +117,25 @@ fn client_to_server_inbound() { tx.send(addr).unwrap(); - let client = Arc::new(listener - .next().await - .unwrap() - .unwrap() - .into_upgrade().unwrap().0.await.unwrap()); + let client = Arc::new( + listener + .next() + .await + .unwrap() + .unwrap() + .into_upgrade() + .unwrap() + .0 + .await + .unwrap(), + ); let mut inbound = loop { - if let Some(s) = muxing::event_from_ref_and_wrap(client.clone()).await.unwrap() - .into_inbound_substream() { + if let Some(s) = muxing::event_from_ref_and_wrap(client.clone()) + .await + .unwrap() + .into_inbound_substream() + { break s; } }; @@ -123,11 +147,13 @@ fn client_to_server_inbound() { async_std::task::block_on(async { let mplex = libp2p_mplex::MplexConfig::new(); - let transport = TcpConfig::new().and_then(move |c, e| - upgrade::apply(c, mplex, e, upgrade::Version::V1)); + let transport = TcpConfig::new() + .and_then(move |c, e| upgrade::apply(c, 
mplex, e, upgrade::Version::V1)); let client = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); - let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)).await.unwrap(); + let mut outbound = muxing::outbound_from_ref_and_wrap(Arc::new(client)) + .await + .unwrap(); outbound.write_all(b"hello world").await.unwrap(); outbound.close().await.unwrap(); diff --git a/muxers/yamux/src/lib.rs b/muxers/yamux/src/lib.rs index eb47ad8d0ce..941e0fefd8e 100644 --- a/muxers/yamux/src/lib.rs +++ b/muxers/yamux/src/lib.rs @@ -21,11 +21,20 @@ //! Implements the Yamux multiplexing protocol for libp2p, see also the //! [specification](https://github.com/hashicorp/yamux/blob/master/spec.md). -use futures::{future, prelude::*, ready, stream::{BoxStream, LocalBoxStream}}; +use futures::{ + future, + prelude::*, + ready, + stream::{BoxStream, LocalBoxStream}, +}; use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent}; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; use parking_lot::Mutex; -use std::{fmt, io, iter, pin::Pin, task::{Context, Poll}}; +use std::{ + fmt, io, iter, + pin::Pin, + task::{Context, Poll}, +}; use thiserror::Error; /// A Yamux connection. @@ -50,7 +59,7 @@ pub struct OpenSubstreamToken(()); impl Yamux> where - C: AsyncRead + AsyncWrite + Send + Unpin + 'static + C: AsyncRead + AsyncWrite + Send + Unpin + 'static, { /// Create a new Yamux connection. fn new(io: C, cfg: yamux::Config, mode: yamux::Mode) -> Self { @@ -59,7 +68,7 @@ where let inner = Inner { incoming: Incoming { stream: yamux::into_stream(conn).err_into().boxed(), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, }, control: ctrl, }; @@ -69,7 +78,7 @@ where impl Yamux> where - C: AsyncRead + AsyncWrite + Unpin + 'static + C: AsyncRead + AsyncWrite + Unpin + 'static, { /// Create a new Yamux connection (which is ![`Send`]). fn local(io: C, cfg: yamux::Config, mode: yamux::Mode) -> Self { @@ -78,7 +87,7 @@ where let inner = Inner { incoming: LocalIncoming { stream: yamux::into_stream(conn).err_into().boxed_local(), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, }, control: ctrl, }; @@ -91,20 +100,21 @@ pub type YamuxResult = Result; /// > **Note**: This implementation never emits [`StreamMuxerEvent::AddressChange`] events. 
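// Aside (not part of the patch): wiring yamux into a transport follows the same
// pattern as the mplex tests above. Sketch only; it assumes `YamuxConfig`
// implements `Default`, and uses `libp2p_tcp`/`libp2p_core` as in those tests.

#[allow(dead_code)]
fn yamux_transport_sketch() {
    use libp2p_core::{upgrade, Transport};
    use libp2p_tcp::TcpConfig;

    let _transport = TcpConfig::new().and_then(move |conn, endpoint| {
        upgrade::apply(
            conn,
            libp2p_yamux::YamuxConfig::default(),
            endpoint,
            upgrade::Version::V1,
        )
    });
}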
impl StreamMuxer for Yamux where - S: Stream> + Unpin + S: Stream> + Unpin, { type Substream = yamux::Stream; type OutboundSubstream = OpenSubstreamToken; type Error = YamuxError; - fn poll_event(&self, c: &mut Context<'_>) - -> Poll>> - { + fn poll_event( + &self, + c: &mut Context<'_>, + ) -> Poll>> { let mut inner = self.0.lock(); match ready!(inner.incoming.poll_next_unpin(c)) { Some(Ok(s)) => Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(s))), Some(Err(e)) => Poll::Ready(Err(e)), - None => Poll::Ready(Err(yamux::ConnectionError::Closed.into())) + None => Poll::Ready(Err(yamux::ConnectionError::Closed.into())), } } @@ -112,53 +122,71 @@ where OpenSubstreamToken(()) } - fn poll_outbound(&self, c: &mut Context<'_>, _: &mut OpenSubstreamToken) - -> Poll> - { + fn poll_outbound( + &self, + c: &mut Context<'_>, + _: &mut OpenSubstreamToken, + ) -> Poll> { let mut inner = self.0.lock(); - Pin::new(&mut inner.control).poll_open_stream(c).map_err(YamuxError) + Pin::new(&mut inner.control) + .poll_open_stream(c) + .map_err(YamuxError) } fn destroy_outbound(&self, _: Self::OutboundSubstream) { self.0.lock().control.abort_open_stream() } - fn read_substream(&self, c: &mut Context<'_>, s: &mut Self::Substream, b: &mut [u8]) - -> Poll> - { - Pin::new(s).poll_read(c, b).map_err(|e| YamuxError(e.into())) - } - - fn write_substream(&self, c: &mut Context<'_>, s: &mut Self::Substream, b: &[u8]) - -> Poll> - { - Pin::new(s).poll_write(c, b).map_err(|e| YamuxError(e.into())) - } - - fn flush_substream(&self, c: &mut Context<'_>, s: &mut Self::Substream) - -> Poll> - { + fn read_substream( + &self, + c: &mut Context<'_>, + s: &mut Self::Substream, + b: &mut [u8], + ) -> Poll> { + Pin::new(s) + .poll_read(c, b) + .map_err(|e| YamuxError(e.into())) + } + + fn write_substream( + &self, + c: &mut Context<'_>, + s: &mut Self::Substream, + b: &[u8], + ) -> Poll> { + Pin::new(s) + .poll_write(c, b) + .map_err(|e| YamuxError(e.into())) + } + + fn flush_substream( + &self, + c: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { Pin::new(s).poll_flush(c).map_err(|e| YamuxError(e.into())) } - fn shutdown_substream(&self, c: &mut Context<'_>, s: &mut Self::Substream) - -> Poll> - { + fn shutdown_substream( + &self, + c: &mut Context<'_>, + s: &mut Self::Substream, + ) -> Poll> { Pin::new(s).poll_close(c).map_err(|e| YamuxError(e.into())) } - fn destroy_substream(&self, _: Self::Substream) { } + fn destroy_substream(&self, _: Self::Substream) {} fn close(&self, c: &mut Context<'_>) -> Poll> { let mut inner = self.0.lock(); if let std::task::Poll::Ready(x) = Pin::new(&mut inner.control).poll_close(c) { - return Poll::Ready(x.map_err(YamuxError)) + return Poll::Ready(x.map_err(YamuxError)); } while let std::task::Poll::Ready(x) = inner.incoming.poll_next_unpin(c) { match x { - Some(Ok(_)) => {} // drop inbound stream + Some(Ok(_)) => {} // drop inbound stream Some(Err(e)) => return Poll::Ready(Err(e)), - None => return Poll::Ready(Ok(())) + None => return Poll::Ready(Ok(())), } } Poll::Pending @@ -173,7 +201,7 @@ where #[derive(Clone)] pub struct YamuxConfig { inner: yamux::Config, - mode: Option + mode: Option, } /// The window update mode determines when window updates are @@ -299,7 +327,7 @@ impl UpgradeInfo for YamuxLocalConfig { impl InboundUpgrade for YamuxConfig where - C: AsyncRead + AsyncWrite + Send + Unpin + 'static + C: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Output = Yamux>; type Error = io::Error; @@ -313,7 +341,7 @@ where impl InboundUpgrade for YamuxLocalConfig where - C: 
AsyncRead + AsyncWrite + Unpin + 'static + C: AsyncRead + AsyncWrite + Unpin + 'static, { type Output = Yamux>; type Error = io::Error; @@ -328,7 +356,7 @@ where impl OutboundUpgrade for YamuxConfig where - C: AsyncRead + AsyncWrite + Send + Unpin + 'static + C: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Output = Yamux>; type Error = io::Error; @@ -342,7 +370,7 @@ where impl OutboundUpgrade for YamuxLocalConfig where - C: AsyncRead + AsyncWrite + Unpin + 'static + C: AsyncRead + AsyncWrite + Unpin + 'static, { type Output = Yamux>; type Error = io::Error; @@ -364,7 +392,7 @@ impl From for io::Error { fn from(err: YamuxError) -> Self { match err.0 { yamux::ConnectionError::Io(e) => e, - e => io::Error::new(io::ErrorKind::Other, e) + e => io::Error::new(io::ErrorKind::Other, e), } } } @@ -372,7 +400,7 @@ impl From for io::Error { /// The [`futures::stream::Stream`] of incoming substreams. pub struct Incoming { stream: BoxStream<'static, Result>, - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } impl fmt::Debug for Incoming { @@ -384,7 +412,7 @@ impl fmt::Debug for Incoming { /// The [`futures::stream::Stream`] of incoming substreams (`!Send`). pub struct LocalIncoming { stream: LocalBoxStream<'static, Result>, - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } impl fmt::Debug for LocalIncoming { @@ -396,7 +424,10 @@ impl fmt::Debug for LocalIncoming { impl Stream for Incoming { type Item = Result; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> std::task::Poll> { self.stream.as_mut().poll_next_unpin(cx) } @@ -405,13 +436,15 @@ impl Stream for Incoming { } } -impl Unpin for Incoming { -} +impl Unpin for Incoming {} impl Stream for LocalIncoming { type Item = Result; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll> { + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> std::task::Poll> { self.stream.as_mut().poll_next_unpin(cx) } @@ -420,5 +453,4 @@ impl Stream for LocalIncoming { } } -impl Unpin for LocalIncoming { -} +impl Unpin for LocalIncoming {} diff --git a/protocols/floodsub/build.rs b/protocols/floodsub/build.rs index 3de5b750ca2..a3de99880dc 100644 --- a/protocols/floodsub/build.rs +++ b/protocols/floodsub/build.rs @@ -19,6 +19,5 @@ // DEALINGS IN THE SOFTWARE. fn main() { - prost_build::compile_protos(&["src/rpc.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/rpc.proto"], &["src"]).unwrap(); } - diff --git a/protocols/floodsub/src/layer.rs b/protocols/floodsub/src/layer.rs index b2916fa0605..eb5a7cb30b2 100644 --- a/protocols/floodsub/src/layer.rs +++ b/protocols/floodsub/src/layer.rs @@ -18,26 +18,24 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
-use crate::protocol::{FloodsubProtocol, FloodsubMessage, FloodsubRpc, FloodsubSubscription, FloodsubSubscriptionAction}; +use crate::protocol::{ + FloodsubMessage, FloodsubProtocol, FloodsubRpc, FloodsubSubscription, + FloodsubSubscriptionAction, +}; use crate::topic::Topic; use crate::FloodsubConfig; use cuckoofilter::{CuckooError, CuckooFilter}; use fnv::FnvHashSet; -use libp2p_core::{PeerId, connection::ConnectionId}; +use libp2p_core::{connection::ConnectionId, PeerId}; use libp2p_swarm::{ - NetworkBehaviour, - NetworkBehaviourAction, - PollParameters, - ProtocolsHandler, - OneShotHandler, - NotifyHandler, - DialPeerCondition, + DialPeerCondition, NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, OneShotHandler, + PollParameters, ProtocolsHandler, }; use log::warn; use smallvec::SmallVec; -use std::{collections::VecDeque, iter}; use std::collections::hash_map::{DefaultHasher, HashMap}; use std::task::{Context, Poll}; +use std::{collections::VecDeque, iter}; /// Network behaviour that handles the floodsub protocol. pub struct Floodsub { @@ -87,23 +85,25 @@ impl Floodsub { // Send our topics to this node if we're already connected to it. if self.connected_peers.contains_key(&peer_id) { for topic in self.subscribed_topics.iter().cloned() { - self.events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id, - handler: NotifyHandler::Any, - event: FloodsubRpc { - messages: Vec::new(), - subscriptions: vec![FloodsubSubscription { - topic, - action: FloodsubSubscriptionAction::Subscribe, - }], - }, - }); + self.events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id, + handler: NotifyHandler::Any, + event: FloodsubRpc { + messages: Vec::new(), + subscriptions: vec![FloodsubSubscription { + topic, + action: FloodsubSubscriptionAction::Subscribe, + }], + }, + }); } } if self.target_peers.insert(peer_id) { self.events.push_back(NetworkBehaviourAction::DialPeer { - peer_id, condition: DialPeerCondition::Disconnected + peer_id, + condition: DialPeerCondition::Disconnected, }); } } @@ -123,17 +123,18 @@ impl Floodsub { } for peer in self.connected_peers.keys() { - self.events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: *peer, - handler: NotifyHandler::Any, - event: FloodsubRpc { - messages: Vec::new(), - subscriptions: vec![FloodsubSubscription { - topic: topic.clone(), - action: FloodsubSubscriptionAction::Subscribe, - }], - }, - }); + self.events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: *peer, + handler: NotifyHandler::Any, + event: FloodsubRpc { + messages: Vec::new(), + subscriptions: vec![FloodsubSubscription { + topic: topic.clone(), + action: FloodsubSubscriptionAction::Subscribe, + }], + }, + }); } self.subscribed_topics.push(topic); @@ -148,23 +149,24 @@ impl Floodsub { pub fn unsubscribe(&mut self, topic: Topic) -> bool { let pos = match self.subscribed_topics.iter().position(|t| *t == topic) { Some(pos) => pos, - None => return false + None => return false, }; self.subscribed_topics.remove(pos); for peer in self.connected_peers.keys() { - self.events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: *peer, - handler: NotifyHandler::Any, - event: FloodsubRpc { - messages: Vec::new(), - subscriptions: vec![FloodsubSubscription { - topic: topic.clone(), - action: FloodsubSubscriptionAction::Unsubscribe, - }], - }, - }); + self.events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: *peer, + handler: NotifyHandler::Any, + event: FloodsubRpc { + messages: Vec::new(), + subscriptions: 
vec![FloodsubSubscription { + topic: topic.clone(), + action: FloodsubSubscriptionAction::Unsubscribe, + }], + }, + }); } true @@ -184,16 +186,29 @@ impl Floodsub { /// /// /// > **Note**: Doesn't do anything if we're not subscribed to any of the topics. - pub fn publish_many(&mut self, topic: impl IntoIterator>, data: impl Into>) { + pub fn publish_many( + &mut self, + topic: impl IntoIterator>, + data: impl Into>, + ) { self.publish_many_inner(topic, data, true) } /// Publishes a message with multiple topics to the network, even if we're not subscribed to any of the topics. - pub fn publish_many_any(&mut self, topic: impl IntoIterator>, data: impl Into>) { + pub fn publish_many_any( + &mut self, + topic: impl IntoIterator>, + data: impl Into>, + ) { self.publish_many_inner(topic, data, false) } - fn publish_many_inner(&mut self, topic: impl IntoIterator>, data: impl Into>, check_self_subscriptions: bool) { + fn publish_many_inner( + &mut self, + topic: impl IntoIterator>, + data: impl Into>, + check_self_subscriptions: bool, + ) { let message = FloodsubMessage { source: self.config.local_peer_id, data: data.into(), @@ -204,39 +219,48 @@ impl Floodsub { topics: topic.into_iter().map(Into::into).collect(), }; - let self_subscribed = self.subscribed_topics.iter().any(|t| message.topics.iter().any(|u| t == u)); + let self_subscribed = self + .subscribed_topics + .iter() + .any(|t| message.topics.iter().any(|u| t == u)); if self_subscribed { if let Err(e @ CuckooError::NotEnoughSpace) = self.received.add(&message) { warn!( "Message was added to 'received' Cuckoofilter but some \ - other message was removed as a consequence: {}", e, + other message was removed as a consequence: {}", + e, ); } if self.config.subscribe_local_messages { - self.events.push_back( - NetworkBehaviourAction::GenerateEvent(FloodsubEvent::Message(message.clone()))); + self.events.push_back(NetworkBehaviourAction::GenerateEvent( + FloodsubEvent::Message(message.clone()), + )); } } // Don't publish the message if we have to check subscriptions // and we're not subscribed ourselves to any of the topics. if check_self_subscriptions && !self_subscribed { - return + return; } // Send to peers we know are subscribed to the topic. for (peer_id, sub_topic) in self.connected_peers.iter() { - if !sub_topic.iter().any(|t| message.topics.iter().any(|u| t == u)) { + if !sub_topic + .iter() + .any(|t| message.topics.iter().any(|u| t == u)) + { continue; } - self.events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: *peer_id, - handler: NotifyHandler::Any, - event: FloodsubRpc { - subscriptions: Vec::new(), - messages: vec![message.clone()], - } - }); + self.events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: *peer_id, + handler: NotifyHandler::Any, + event: FloodsubRpc { + subscriptions: Vec::new(), + messages: vec![message.clone()], + }, + }); } } } @@ -253,17 +277,18 @@ impl NetworkBehaviour for Floodsub { // We need to send our subscriptions to the newly-connected node. 
if self.target_peers.contains(id) { for topic in self.subscribed_topics.iter().cloned() { - self.events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: *id, - handler: NotifyHandler::Any, - event: FloodsubRpc { - messages: Vec::new(), - subscriptions: vec![FloodsubSubscription { - topic, - action: FloodsubSubscriptionAction::Subscribe, - }], - }, - }); + self.events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: *id, + handler: NotifyHandler::Any, + event: FloodsubRpc { + messages: Vec::new(), + subscriptions: vec![FloodsubSubscription { + topic, + action: FloodsubSubscriptionAction::Subscribe, + }], + }, + }); } } @@ -279,7 +304,7 @@ impl NetworkBehaviour for Floodsub { if self.target_peers.contains(id) { self.events.push_back(NetworkBehaviourAction::DialPeer { peer_id: *id, - condition: DialPeerCondition::Disconnected + condition: DialPeerCondition::Disconnected, }); } } @@ -306,19 +331,26 @@ impl NetworkBehaviour for Floodsub { if !remote_peer_topics.contains(&subscription.topic) { remote_peer_topics.push(subscription.topic.clone()); } - self.events.push_back(NetworkBehaviourAction::GenerateEvent(FloodsubEvent::Subscribed { - peer_id: propagation_source, - topic: subscription.topic, - })); + self.events.push_back(NetworkBehaviourAction::GenerateEvent( + FloodsubEvent::Subscribed { + peer_id: propagation_source, + topic: subscription.topic, + }, + )); } FloodsubSubscriptionAction::Unsubscribe => { - if let Some(pos) = remote_peer_topics.iter().position(|t| t == &subscription.topic ) { + if let Some(pos) = remote_peer_topics + .iter() + .position(|t| t == &subscription.topic) + { remote_peer_topics.remove(pos); } - self.events.push_back(NetworkBehaviourAction::GenerateEvent(FloodsubEvent::Unsubscribed { - peer_id: propagation_source, - topic: subscription.topic, - })); + self.events.push_back(NetworkBehaviourAction::GenerateEvent( + FloodsubEvent::Unsubscribed { + peer_id: propagation_source, + topic: subscription.topic, + }, + )); } } } @@ -330,20 +362,27 @@ impl NetworkBehaviour for Floodsub { // Use `self.received` to skip the messages that we have already received in the past. // Note that this can result in false positives. match self.received.test_and_add(&message) { - Ok(true) => {}, // Message was added. + Ok(true) => {} // Message was added. Ok(false) => continue, // Message already existed. - Err(e @ CuckooError::NotEnoughSpace) => { // Message added, but some other removed. + Err(e @ CuckooError::NotEnoughSpace) => { + // Message added, but some other removed. warn!( "Message was added to 'received' Cuckoofilter but some \ - other message was removed as a consequence: {}", e, + other message was removed as a consequence: {}", + e, ); } } // Add the message to be dispatched to the user. - if self.subscribed_topics.iter().any(|t| message.topics.iter().any(|u| t == u)) { + if self + .subscribed_topics + .iter() + .any(|t| message.topics.iter().any(|u| t == u)) + { let event = FloodsubEvent::Message(message.clone()); - self.events.push_back(NetworkBehaviourAction::GenerateEvent(event)); + self.events + .push_back(NetworkBehaviourAction::GenerateEvent(event)); } // Propagate the message to everyone else who is subscribed to any of the topics. 
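// Aside (not part of the patch): the hunks above and below touch Floodsub's
// subscribe/publish paths. A minimal usage sketch; `Floodsub::new` and
// `Topic::new` are assumed to have their usual signatures (peer id, topic name).

#[allow(dead_code)]
fn floodsub_sketch(local_peer_id: libp2p_core::PeerId) {
    use libp2p_floodsub::{Floodsub, Topic};

    let mut floodsub = Floodsub::new(local_peer_id);
    let topic = Topic::new("chat");

    // Queues a Subscribe announcement for all connected peers.
    floodsub.subscribe(topic.clone());

    // Queues the message for every connected peer subscribed to "chat".
    floodsub.publish_many(vec![topic], b"hello".to_vec());
}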
@@ -352,27 +391,34 @@ impl NetworkBehaviour for Floodsub { continue; } - if !subscr_topics.iter().any(|t| message.topics.iter().any(|u| t == u)) { + if !subscr_topics + .iter() + .any(|t| message.topics.iter().any(|u| t == u)) + { continue; } if let Some(pos) = rpcs_to_dispatch.iter().position(|(p, _)| p == peer_id) { rpcs_to_dispatch[pos].1.messages.push(message.clone()); } else { - rpcs_to_dispatch.push((*peer_id, FloodsubRpc { - subscriptions: Vec::new(), - messages: vec![message.clone()], - })); + rpcs_to_dispatch.push(( + *peer_id, + FloodsubRpc { + subscriptions: Vec::new(), + messages: vec![message.clone()], + }, + )); } } } for (peer_id, rpc) in rpcs_to_dispatch { - self.events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id, - handler: NotifyHandler::Any, - event: rpc, - }); + self.events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id, + handler: NotifyHandler::Any, + event: rpc, + }); } } diff --git a/protocols/floodsub/src/lib.rs b/protocols/floodsub/src/lib.rs index 8e7014bedaa..5481ddbf43d 100644 --- a/protocols/floodsub/src/lib.rs +++ b/protocols/floodsub/src/lib.rs @@ -50,7 +50,7 @@ impl FloodsubConfig { pub fn new(local_peer_id: PeerId) -> Self { Self { local_peer_id, - subscribe_local_messages: false + subscribe_local_messages: false, } } } diff --git a/protocols/floodsub/src/protocol.rs b/protocols/floodsub/src/protocol.rs index 1b942549b22..df694b2e06d 100644 --- a/protocols/floodsub/src/protocol.rs +++ b/protocols/floodsub/src/protocol.rs @@ -20,10 +20,13 @@ use crate::rpc_proto; use crate::topic::Topic; -use libp2p_core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo, PeerId, upgrade}; +use futures::{ + io::{AsyncRead, AsyncWrite}, + AsyncWriteExt, Future, +}; +use libp2p_core::{upgrade, InboundUpgrade, OutboundUpgrade, PeerId, UpgradeInfo}; use prost::Message; use std::{error, fmt, io, iter, pin::Pin}; -use futures::{Future, io::{AsyncRead, AsyncWrite}, AsyncWriteExt}; /// Implementation of `ConnectionUpgrade` for the floodsub protocol. 
#[derive(Debug, Clone, Default)] @@ -61,21 +64,18 @@ where let mut messages = Vec::with_capacity(rpc.publish.len()); for publish in rpc.publish.into_iter() { messages.push(FloodsubMessage { - source: PeerId::from_bytes(&publish.from.unwrap_or_default()).map_err(|_| { - FloodsubDecodeError::InvalidPeerId - })?, + source: PeerId::from_bytes(&publish.from.unwrap_or_default()) + .map_err(|_| FloodsubDecodeError::InvalidPeerId)?, data: publish.data.unwrap_or_default(), sequence_number: publish.seqno.unwrap_or_default(), - topics: publish.topic_ids - .into_iter() - .map(Topic::new) - .collect(), + topics: publish.topic_ids.into_iter().map(Topic::new).collect(), }); } Ok(FloodsubRpc { messages, - subscriptions: rpc.subscriptions + subscriptions: rpc + .subscriptions .into_iter() .map(|sub| FloodsubSubscription { action: if Some(true) == sub.subscribe { @@ -117,12 +117,15 @@ impl From for FloodsubDecodeError { impl fmt::Display for FloodsubDecodeError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { - FloodsubDecodeError::ReadError(ref err) => - write!(f, "Error while reading from socket: {}", err), - FloodsubDecodeError::ProtobufError(ref err) => - write!(f, "Error while decoding protobuf: {}", err), - FloodsubDecodeError::InvalidPeerId => - write!(f, "Error while decoding PeerId from message"), + FloodsubDecodeError::ReadError(ref err) => { + write!(f, "Error while reading from socket: {}", err) + } + FloodsubDecodeError::ProtobufError(ref err) => { + write!(f, "Error while decoding protobuf: {}", err) + } + FloodsubDecodeError::InvalidPeerId => { + write!(f, "Error while decoding PeerId from message") + } } } } @@ -179,32 +182,30 @@ impl FloodsubRpc { /// Turns this `FloodsubRpc` into a message that can be sent to a substream. fn into_bytes(self) -> Vec { let rpc = rpc_proto::Rpc { - publish: self.messages.into_iter() - .map(|msg| { - rpc_proto::Message { - from: Some(msg.source.to_bytes()), - data: Some(msg.data), - seqno: Some(msg.sequence_number), - topic_ids: msg.topics - .into_iter() - .map(|topic| topic.into()) - .collect() - } + publish: self + .messages + .into_iter() + .map(|msg| rpc_proto::Message { + from: Some(msg.source.to_bytes()), + data: Some(msg.data), + seqno: Some(msg.sequence_number), + topic_ids: msg.topics.into_iter().map(|topic| topic.into()).collect(), }) .collect(), - subscriptions: self.subscriptions.into_iter() - .map(|topic| { - rpc_proto::rpc::SubOpts { - subscribe: Some(topic.action == FloodsubSubscriptionAction::Subscribe), - topic_id: Some(topic.topic.into()) - } + subscriptions: self + .subscriptions + .into_iter() + .map(|topic| rpc_proto::rpc::SubOpts { + subscribe: Some(topic.action == FloodsubSubscriptionAction::Subscribe), + topic_id: Some(topic.topic.into()), }) - .collect() + .collect(), }; let mut buf = Vec::with_capacity(rpc.encoded_len()); - rpc.encode(&mut buf).expect("Vec provides capacity as needed"); + rpc.encode(&mut buf) + .expect("Vec provides capacity as needed"); buf } } diff --git a/protocols/gossipsub/src/behaviour.rs b/protocols/gossipsub/src/behaviour.rs index 8a9c1b9efe0..a4f4ec24bfc 100644 --- a/protocols/gossipsub/src/behaviour.rs +++ b/protocols/gossipsub/src/behaviour.rs @@ -3198,9 +3198,13 @@ where NetworkBehaviourAction::ReportObservedAddr { address, score } => { NetworkBehaviourAction::ReportObservedAddr { address, score } } - NetworkBehaviourAction::CloseConnection { peer_id, connection } => { - NetworkBehaviourAction::CloseConnection { peer_id, connection } - } + NetworkBehaviourAction::CloseConnection { 
+ peer_id, + connection, + } => NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + }, }); } diff --git a/protocols/gossipsub/src/protocol.rs b/protocols/gossipsub/src/protocol.rs index 19293f58d7e..199d210452a 100644 --- a/protocols/gossipsub/src/protocol.rs +++ b/protocols/gossipsub/src/protocol.rs @@ -27,12 +27,12 @@ use crate::types::{ GossipsubControlAction, GossipsubRpc, GossipsubSubscription, GossipsubSubscriptionAction, MessageId, PeerInfo, PeerKind, RawGossipsubMessage, }; +use asynchronous_codec::{Decoder, Encoder, Framed}; use byteorder::{BigEndian, ByteOrder}; use bytes::Bytes; use bytes::BytesMut; use futures::future; use futures::prelude::*; -use asynchronous_codec::{Decoder, Encoder, Framed}; use libp2p_core::{ identity::PublicKey, InboundUpgrade, OutboundUpgrade, PeerId, ProtocolName, UpgradeInfo, }; diff --git a/protocols/gossipsub/tests/smoke.rs b/protocols/gossipsub/tests/smoke.rs index b8df7ccddc0..f1bd056c6f9 100644 --- a/protocols/gossipsub/tests/smoke.rs +++ b/protocols/gossipsub/tests/smoke.rs @@ -51,7 +51,9 @@ impl Future for Graph { for (addr, node) in &mut self.nodes { loop { match node.poll_next_unpin(cx) { - Poll::Ready(Some(SwarmEvent::Behaviour(event))) => return Poll::Ready((addr.clone(), event)), + Poll::Ready(Some(SwarmEvent::Behaviour(event))) => { + return Poll::Ready((addr.clone(), event)) + } Poll::Ready(Some(_)) => {} Poll::Ready(None) => panic!("unexpected None when polling nodes"), Poll::Pending => break, @@ -226,7 +228,11 @@ fn multi_hop_propagation() { graph = graph.drain_poll(); // Publish a single message. - graph.nodes[0].1.behaviour_mut().publish(topic, vec![1, 2, 3]).unwrap(); + graph.nodes[0] + .1 + .behaviour_mut() + .publish(topic, vec![1, 2, 3]) + .unwrap(); // Wait for all nodes to receive the published message. let mut received_msgs = 0; diff --git a/protocols/identify/build.rs b/protocols/identify/build.rs index 1b0feff6a40..56c7b20121a 100644 --- a/protocols/identify/build.rs +++ b/protocols/identify/build.rs @@ -19,6 +19,5 @@ // DEALINGS IN THE SOFTWARE. fn main() { - prost_build::compile_protos(&["src/structs.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/structs.proto"], &["src"]).unwrap(); } - diff --git a/protocols/identify/src/handler.rs b/protocols/identify/src/handler.rs index 11c239cdfab..f0d05f79dc9 100644 --- a/protocols/identify/src/handler.rs +++ b/protocols/identify/src/handler.rs @@ -19,32 +19,16 @@ // DEALINGS IN THE SOFTWARE. use crate::protocol::{ - IdentifyProtocol, - IdentifyPushProtocol, - IdentifyInfo, - InboundPush, - OutboundPush, - ReplySubstream + IdentifyInfo, IdentifyProtocol, IdentifyPushProtocol, InboundPush, OutboundPush, ReplySubstream, }; use futures::prelude::*; -use libp2p_core::either::{ - EitherError, - EitherOutput, -}; +use libp2p_core::either::{EitherError, EitherOutput}; use libp2p_core::upgrade::{ - EitherUpgrade, - InboundUpgrade, - OutboundUpgrade, - SelectUpgrade, - UpgradeError, + EitherUpgrade, InboundUpgrade, OutboundUpgrade, SelectUpgrade, UpgradeError, }; use libp2p_swarm::{ - NegotiatedSubstream, - KeepAlive, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr + KeepAlive, NegotiatedSubstream, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, SubstreamProtocol, }; use smallvec::SmallVec; use std::{io, pin::Pin, task::Context, task::Poll, time::Duration}; @@ -57,12 +41,14 @@ use wasm_timer::Delay; /// permitting the underlying connection to be closed. 
pub struct IdentifyHandler { /// Pending events to yield. - events: SmallVec<[ProtocolsHandlerEvent< + events: SmallVec< + [ProtocolsHandlerEvent< EitherUpgrade>, (), IdentifyHandlerEvent, io::Error, - >; 4]>, + >; 4], + >, /// Future that fires when we need to identify the node again. next_id: Delay, @@ -114,28 +100,23 @@ impl ProtocolsHandler for IdentifyHandler { fn listen_protocol(&self) -> SubstreamProtocol { SubstreamProtocol::new( - SelectUpgrade::new( - IdentifyProtocol, - IdentifyPushProtocol::inbound(), - ), ()) + SelectUpgrade::new(IdentifyProtocol, IdentifyPushProtocol::inbound()), + (), + ) } fn inject_fully_negotiated_inbound( &mut self, output: >::Output, - _: Self::InboundOpenInfo + _: Self::InboundOpenInfo, ) { match output { - EitherOutput::First(substream) => { - self.events.push( - ProtocolsHandlerEvent::Custom( - IdentifyHandlerEvent::Identify(substream))) - } - EitherOutput::Second(info) => { - self.events.push( - ProtocolsHandlerEvent::Custom( - IdentifyHandlerEvent::Identified(info))) - } + EitherOutput::First(substream) => self.events.push(ProtocolsHandlerEvent::Custom( + IdentifyHandlerEvent::Identify(substream), + )), + EitherOutput::Second(info) => self.events.push(ProtocolsHandlerEvent::Custom( + IdentifyHandlerEvent::Identified(info), + )), } } @@ -146,39 +127,42 @@ impl ProtocolsHandler for IdentifyHandler { ) { match output { EitherOutput::First(remote_info) => { - self.events.push( - ProtocolsHandlerEvent::Custom( - IdentifyHandlerEvent::Identified(remote_info))); + self.events.push(ProtocolsHandlerEvent::Custom( + IdentifyHandlerEvent::Identified(remote_info), + )); self.keep_alive = KeepAlive::No; } - EitherOutput::Second(()) => self.events.push( - ProtocolsHandlerEvent::Custom(IdentifyHandlerEvent::IdentificationPushed)) + EitherOutput::Second(()) => self.events.push(ProtocolsHandlerEvent::Custom( + IdentifyHandlerEvent::IdentificationPushed, + )), } } fn inject_event(&mut self, IdentifyPush(push): Self::InEvent) { - self.events.push(ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new( - EitherUpgrade::B( - IdentifyPushProtocol::outbound(push)), ()) - }); + self.events + .push(ProtocolsHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new( + EitherUpgrade::B(IdentifyPushProtocol::outbound(push)), + (), + ), + }); } fn inject_dial_upgrade_error( &mut self, _info: Self::OutboundOpenInfo, err: ProtocolsHandlerUpgrErr< - >::Error - > + >::Error, + >, ) { let err = err.map_upgrade_err(|e| match e { UpgradeError::Select(e) => UpgradeError::Select(e), UpgradeError::Apply(EitherError::A(ioe)) => UpgradeError::Apply(ioe), UpgradeError::Apply(EitherError::B(ioe)) => UpgradeError::Apply(ioe), }); - self.events.push( - ProtocolsHandlerEvent::Custom( - IdentifyHandlerEvent::IdentificationError(err))); + self.events.push(ProtocolsHandlerEvent::Custom( + IdentifyHandlerEvent::IdentificationError(err), + )); self.keep_alive = KeepAlive::No; self.next_id.reset(self.interval); } @@ -187,7 +171,10 @@ impl ProtocolsHandler for IdentifyHandler { self.keep_alive } - fn poll(&mut self, cx: &mut Context<'_>) -> Poll< + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< ProtocolsHandlerEvent< Self::OutboundProtocol, Self::OutboundOpenInfo, @@ -196,9 +183,7 @@ impl ProtocolsHandler for IdentifyHandler { >, > { if !self.events.is_empty() { - return Poll::Ready( - self.events.remove(0), - ); + return Poll::Ready(self.events.remove(0)); } // Poll the future that fires when we need to identify the node again. 
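The `poll` change above keeps the handler's existing shape: events queued by the `inject_*` callbacks are returned first, and only afterwards is the re-identification timer polled. A small self-contained sketch of that queue-draining shape follows, with illustrative names rather than the real `ProtocolsHandler` types.

use smallvec::SmallVec;
use std::task::Poll;

/// Stand-in for the handler's pending-event buffer: events queued by the
/// `inject_*` callbacks are yielded from `poll` before any new work is done.
struct PendingEvents<E> {
    events: SmallVec<[E; 4]>,
}

impl<E> PendingEvents<E> {
    fn new() -> Self {
        Self { events: SmallVec::new() }
    }

    fn push(&mut self, ev: E) {
        self.events.push(ev);
    }

    /// Same shape as the `poll` hunk above: drain queued events in FIFO
    /// order first, otherwise there is nothing to report yet.
    fn poll(&mut self) -> Poll<E> {
        if !self.events.is_empty() {
            return Poll::Ready(self.events.remove(0));
        }
        Poll::Pending
    }
}

fn main() {
    let mut pending = PendingEvents::new();
    pending.push("identify");
    pending.push("identification-pushed");
    assert_eq!(pending.poll(), Poll::Ready("identify"));
    assert_eq!(pending.poll(), Poll::Ready("identification-pushed"));
    assert!(pending.poll().is_pending());
}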
@@ -207,11 +192,11 @@ impl ProtocolsHandler for IdentifyHandler { Poll::Ready(Ok(())) => { self.next_id.reset(self.interval); let ev = ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(EitherUpgrade::A(IdentifyProtocol), ()) + protocol: SubstreamProtocol::new(EitherUpgrade::A(IdentifyProtocol), ()), }; Poll::Ready(ev) } - Poll::Ready(Err(err)) => Poll::Ready(ProtocolsHandlerEvent::Close(err)) + Poll::Ready(Err(err)) => Poll::Ready(ProtocolsHandlerEvent::Close(err)), } } } diff --git a/protocols/identify/src/identify.rs b/protocols/identify/src/identify.rs index 4e557f37f4f..0a339457333 100644 --- a/protocols/identify/src/identify.rs +++ b/protocols/identify/src/identify.rs @@ -22,26 +22,16 @@ use crate::handler::{IdentifyHandler, IdentifyHandlerEvent, IdentifyPush}; use crate::protocol::{IdentifyInfo, ReplySubstream}; use futures::prelude::*; use libp2p_core::{ - ConnectedPoint, - Multiaddr, - PeerId, - PublicKey, connection::{ConnectionId, ListenerId}, - upgrade::UpgradeError + upgrade::UpgradeError, + ConnectedPoint, Multiaddr, PeerId, PublicKey, }; use libp2p_swarm::{ - AddressScore, - DialPeerCondition, - NegotiatedSubstream, - NetworkBehaviour, - NetworkBehaviourAction, - NotifyHandler, - PollParameters, - ProtocolsHandler, - ProtocolsHandlerUpgrErr + AddressScore, DialPeerCondition, NegotiatedSubstream, NetworkBehaviour, NetworkBehaviourAction, + NotifyHandler, PollParameters, ProtocolsHandler, ProtocolsHandlerUpgrErr, }; use std::{ - collections::{HashSet, HashMap, VecDeque}, + collections::{HashMap, HashSet, VecDeque}, io, pin::Pin, task::Context, @@ -74,13 +64,13 @@ enum Reply { Queued { peer: PeerId, io: ReplySubstream, - observed: Multiaddr + observed: Multiaddr, }, /// The reply is being sent. Sending { peer: PeerId, io: Pin> + Send>>, - } + }, } /// Configuration for the [`Identify`] [`NetworkBehaviour`]. @@ -178,14 +168,14 @@ impl Identify { /// Initiates an active push of the local peer information to the given peers. pub fn push(&mut self, peers: I) where - I: IntoIterator + I: IntoIterator, { for p in peers { if self.pending_push.insert(p) { if !self.connected.contains_key(&p) { self.events.push_back(NetworkBehaviourAction::DialPeer { peer_id: p, - condition: DialPeerCondition::Disconnected + condition: DialPeerCondition::Disconnected, }); } } @@ -201,16 +191,29 @@ impl NetworkBehaviour for Identify { IdentifyHandler::new(self.config.initial_delay, self.config.interval) } - fn inject_connection_established(&mut self, peer_id: &PeerId, conn: &ConnectionId, endpoint: &ConnectedPoint) { + fn inject_connection_established( + &mut self, + peer_id: &PeerId, + conn: &ConnectionId, + endpoint: &ConnectedPoint, + ) { let addr = match endpoint { ConnectedPoint::Dialer { address } => address.clone(), ConnectedPoint::Listener { send_back_addr, .. 
} => send_back_addr.clone(), }; - self.connected.entry(*peer_id).or_default().insert(*conn, addr); + self.connected + .entry(*peer_id) + .or_default() + .insert(*conn, addr); } - fn inject_connection_closed(&mut self, peer_id: &PeerId, conn: &ConnectionId, _: &ConnectedPoint) { + fn inject_connection_closed( + &mut self, + peer_id: &PeerId, + conn: &ConnectionId, + _: &ConnectedPoint, + ) { if let Some(addrs) = self.connected.get_mut(peer_id) { addrs.remove(conn); } @@ -248,41 +251,39 @@ impl NetworkBehaviour for Identify { match event { IdentifyHandlerEvent::Identified(info) => { let observed = info.observed_addr.clone(); - self.events.push_back( - NetworkBehaviourAction::GenerateEvent( - IdentifyEvent::Received { - peer_id, - info, - })); - self.events.push_back( - NetworkBehaviourAction::ReportObservedAddr { + self.events.push_back(NetworkBehaviourAction::GenerateEvent( + IdentifyEvent::Received { peer_id, info }, + )); + self.events + .push_back(NetworkBehaviourAction::ReportObservedAddr { address: observed, score: AddressScore::Finite(1), }); } IdentifyHandlerEvent::IdentificationPushed => { - self.events.push_back( - NetworkBehaviourAction::GenerateEvent( - IdentifyEvent::Pushed { - peer_id, - })); + self.events.push_back(NetworkBehaviourAction::GenerateEvent( + IdentifyEvent::Pushed { peer_id }, + )); } IdentifyHandlerEvent::Identify(sender) => { - let observed = self.connected.get(&peer_id) + let observed = self + .connected + .get(&peer_id) .and_then(|addrs| addrs.get(&connection)) - .expect("`inject_event` is only called with an established connection \ - and `inject_connection_established` ensures there is an entry; qed"); - self.pending_replies.push_back( - Reply::Queued { - peer: peer_id, - io: sender, - observed: observed.clone() - }); + .expect( + "`inject_event` is only called with an established connection \ + and `inject_connection_established` ensures there is an entry; qed", + ); + self.pending_replies.push_back(Reply::Queued { + peer: peer_id, + io: sender, + observed: observed.clone(), + }); } IdentifyHandlerEvent::IdentificationError(error) => { - self.events.push_back( - NetworkBehaviourAction::GenerateEvent( - IdentifyEvent::Error { peer_id, error })); + self.events.push_back(NetworkBehaviourAction::GenerateEvent( + IdentifyEvent::Error { peer_id, error }, + )); } } } @@ -332,7 +333,7 @@ impl NetworkBehaviour for Identify { peer_id, event: push, handler: NotifyHandler::Any, - }) + }); } // Check for pending replies to send. 
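For orientation, the `inject_connection_established`/`inject_connection_closed` hunks above only reflow the per-connection address bookkeeping. A rough, self-contained sketch of that bookkeeping is given below, with plain `std` types standing in for `PeerId`, `ConnectionId` and `Multiaddr`; dropping the peer entry once it is empty is an illustrative extra, not something this patch adds.

use std::collections::HashMap;

// Illustrative stand-ins for libp2p's `PeerId`, `ConnectionId` and `Multiaddr`.
type PeerId = &'static str;
type ConnectionId = u32;
type Multiaddr = String;

#[derive(Default)]
struct AddressBook {
    connected: HashMap<PeerId, HashMap<ConnectionId, Multiaddr>>,
}

impl AddressBook {
    /// Roughly the shape of `inject_connection_established`: remember the
    /// remote address of every live connection, keyed by peer and connection.
    fn established(&mut self, peer: PeerId, conn: ConnectionId, addr: Multiaddr) {
        self.connected.entry(peer).or_default().insert(conn, addr);
    }

    /// Roughly the shape of `inject_connection_closed`: forget a single
    /// connection, and (illustratively) drop the peer once none remain.
    fn closed(&mut self, peer: PeerId, conn: ConnectionId) {
        if let Some(addrs) = self.connected.get_mut(peer) {
            addrs.remove(&conn);
            if addrs.is_empty() {
                self.connected.remove(peer);
            }
        }
    }
}

fn main() {
    let mut book = AddressBook::default();
    book.established("peer-a", 0, "/ip4/127.0.0.1/tcp/4001".to_string());
    book.established("peer-a", 1, "/ip4/127.0.0.1/tcp/4002".to_string());
    book.closed("peer-a", 0);
    assert_eq!(book.connected["peer-a"].len(), 1);
}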
@@ -360,12 +361,12 @@ impl NetworkBehaviour for Identify { Poll::Ready(Ok(())) => { let event = IdentifyEvent::Sent { peer_id: peer }; return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); - }, + } Poll::Pending => { self.pending_replies.push_back(Reply::Sending { peer, io }); if sending == to_send { // All remaining futures are NotReady - break + break; } else { reply = self.pending_replies.pop_front(); } @@ -373,13 +374,15 @@ impl NetworkBehaviour for Identify { Poll::Ready(Err(err)) => { let event = IdentifyEvent::Error { peer_id: peer, - error: ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err)) + error: ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply( + err, + )), }; return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); - }, + } } } - None => unreachable!() + None => unreachable!(), } } } @@ -438,21 +441,20 @@ fn listen_addrs(params: &impl PollParameters) -> Vec { mod tests { use super::*; use futures::pin_mut; - use libp2p_core::{ - identity, - PeerId, - muxing::StreamMuxerBox, - transport, - Transport, - }; + use libp2p_core::{identity, muxing::StreamMuxerBox, transport, PeerId, Transport}; + use libp2p_mplex::MplexConfig; use libp2p_noise as noise; - use libp2p_tcp::TcpConfig; use libp2p_swarm::{Swarm, SwarmEvent}; - use libp2p_mplex::MplexConfig; + use libp2p_tcp::TcpConfig; - fn transport() -> (identity::PublicKey, transport::Boxed<(PeerId, StreamMuxerBox)>) { + fn transport() -> ( + identity::PublicKey, + transport::Boxed<(PeerId, StreamMuxerBox)>, + ) { let id_keys = identity::Keypair::generate_ed25519(); - let noise_keys = noise::Keypair::::new().into_authentic(&id_keys).unwrap(); + let noise_keys = noise::Keypair::::new() + .into_authentic(&id_keys) + .unwrap(); let pubkey = id_keys.public(); let transport = TcpConfig::new() .nodelay(true) @@ -469,7 +471,8 @@ mod tests { let (pubkey, transport) = transport(); let protocol = Identify::new( IdentifyConfig::new("a".to_string(), pubkey.clone()) - .with_agent_version("b".to_string())); + .with_agent_version("b".to_string()), + ); let swarm = Swarm::new(transport, protocol, pubkey.to_peer_id()); (swarm, pubkey) }; @@ -478,12 +481,15 @@ mod tests { let (pubkey, transport) = transport(); let protocol = Identify::new( IdentifyConfig::new("c".to_string(), pubkey.clone()) - .with_agent_version("d".to_string())); + .with_agent_version("d".to_string()), + ); let swarm = Swarm::new(transport, protocol, pubkey.to_peer_id()); (swarm, pubkey) }; - swarm1.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()).unwrap(); + swarm1 + .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .unwrap(); let listen_addr = async_std::task::block_on(async { loop { @@ -508,7 +514,11 @@ mod tests { let swarm2_fut = swarm2.select_next_some(); pin_mut!(swarm2_fut); - match future::select(swarm1_fut, swarm2_fut).await.factor_second().0 { + match future::select(swarm1_fut, swarm2_fut) + .await + .factor_second() + .0 + { future::Either::Left(SwarmEvent::Behaviour(IdentifyEvent::Received { info, .. @@ -546,7 +556,8 @@ mod tests { let protocol = Identify::new( IdentifyConfig::new("a".to_string(), pubkey.clone()) // Delay identification requests so we can test the push protocol. 
- .with_initial_delay(Duration::from_secs(u32::MAX as u64))); + .with_initial_delay(Duration::from_secs(u32::MAX as u64)), + ); let swarm = Swarm::new(transport, protocol, pubkey.to_peer_id()); (swarm, pubkey) }; @@ -557,7 +568,8 @@ mod tests { IdentifyConfig::new("a".to_string(), pubkey.clone()) .with_agent_version("b".to_string()) // Delay identification requests so we can test the push protocol. - .with_initial_delay(Duration::from_secs(u32::MAX as u64))); + .with_initial_delay(Duration::from_secs(u32::MAX as u64)), + ); let swarm = Swarm::new(transport, protocol, pubkey.to_peer_id()); (swarm, pubkey) }; @@ -585,10 +597,15 @@ mod tests { { pin_mut!(swarm1_fut); pin_mut!(swarm2_fut); - match future::select(swarm1_fut, swarm2_fut).await.factor_second().0 { - future::Either::Left(SwarmEvent::Behaviour( - IdentifyEvent::Received { info, .. } - )) => { + match future::select(swarm1_fut, swarm2_fut) + .await + .factor_second() + .0 + { + future::Either::Left(SwarmEvent::Behaviour(IdentifyEvent::Received { + info, + .. + })) => { assert_eq!(info.public_key, pubkey2); assert_eq!(info.protocol_version, "a"); assert_eq!(info.agent_version, "b"); @@ -600,11 +617,13 @@ mod tests { // Once a connection is established, we can initiate an // active push below. } - _ => { continue } + _ => continue, } } - swarm2.behaviour_mut().push(std::iter::once(pubkey1.to_peer_id())); + swarm2 + .behaviour_mut() + .push(std::iter::once(pubkey1.to_peer_id())); } }) } diff --git a/protocols/identify/src/lib.rs b/protocols/identify/src/lib.rs index 48c0c651428..99456ed7001 100644 --- a/protocols/identify/src/lib.rs +++ b/protocols/identify/src/lib.rs @@ -47,4 +47,3 @@ mod protocol; mod structs_proto { include!(concat!(env!("OUT_DIR"), "/structs.rs")); } - diff --git a/protocols/identify/src/protocol.rs b/protocols/identify/src/protocol.rs index fafa5a37855..9604e660e9f 100644 --- a/protocols/identify/src/protocol.rs +++ b/protocols/identify/src/protocol.rs @@ -21,9 +21,8 @@ use crate::structs_proto; use futures::prelude::*; use libp2p_core::{ - Multiaddr, - PublicKey, - upgrade::{self, InboundUpgrade, OutboundUpgrade, UpgradeInfo} + upgrade::{self, InboundUpgrade, OutboundUpgrade, UpgradeInfo}, + Multiaddr, PublicKey, }; use log::{debug, trace}; use prost::Message; @@ -84,7 +83,7 @@ impl fmt::Debug for ReplySubstream { impl ReplySubstream where - T: AsyncWrite + Unpin + T: AsyncWrite + Unpin, { /// Sends back the requested information on the substream. 
/// @@ -158,17 +157,18 @@ where type Future = Pin> + Send>>; fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future { - send(socket, self.0.0).boxed() + send(socket, self.0 .0).boxed() } } async fn send(mut io: T, info: IdentifyInfo) -> io::Result<()> where - T: AsyncWrite + Unpin + T: AsyncWrite + Unpin, { trace!("Sending: {:?}", info); - let listen_addrs = info.listen_addrs + let listen_addrs = info + .listen_addrs .into_iter() .map(|addr| addr.to_vec()) .collect(); @@ -181,11 +181,13 @@ where public_key: Some(pubkey_bytes), listen_addrs, observed_addr: Some(info.observed_addr.to_vec()), - protocols: info.protocols + protocols: info.protocols, }; let mut bytes = Vec::with_capacity(message.encoded_len()); - message.encode(&mut bytes).expect("Vec provides capacity as needed"); + message + .encode(&mut bytes) + .expect("Vec provides capacity as needed"); upgrade::write_length_prefixed(&mut io, bytes).await?; io.close().await?; @@ -195,7 +197,7 @@ where async fn recv(mut socket: T) -> io::Result where - T: AsyncRead + AsyncWrite + Unpin + T: AsyncRead + AsyncWrite + Unpin, { socket.close().await?; @@ -207,7 +209,7 @@ where Ok(v) => v, Err(err) => { debug!("Invalid message: {:?}", err); - return Err(err) + return Err(err); } }; @@ -255,14 +257,14 @@ fn parse_proto_msg(msg: impl AsRef<[u8]>) -> Result { #[cfg(test)] mod tests { - use libp2p_tcp::TcpConfig; - use futures::{prelude::*, channel::oneshot}; + use super::*; + use futures::{channel::oneshot, prelude::*}; use libp2p_core::{ identity, + upgrade::{self, apply_inbound, apply_outbound}, Transport, - upgrade::{self, apply_outbound, apply_inbound} }; - use super::*; + use libp2p_tcp::TcpConfig; #[test] fn correct_transfer() { @@ -280,7 +282,9 @@ mod tests { .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let addr = listener.next().await + let addr = listener + .next() + .await .expect("some event") .expect("no error") .into_new_address() @@ -288,14 +292,20 @@ mod tests { tx.send(addr).unwrap(); let socket = listener - .next().await.unwrap().unwrap() - .into_upgrade().unwrap() - .0.await.unwrap(); + .next() + .await + .unwrap() + .unwrap() + .into_upgrade() + .unwrap() + .0 + .await + .unwrap(); let sender = apply_inbound(socket, IdentifyProtocol).await.unwrap(); - sender.send( - IdentifyInfo { + sender + .send(IdentifyInfo { public_key: send_pubkey, protocol_version: "proto_version".to_owned(), agent_version: "agent_version".to_owned(), @@ -305,27 +315,36 @@ mod tests { ], protocols: vec!["proto1".to_string(), "proto2".to_string()], observed_addr: "/ip4/100.101.102.103/tcp/5000".parse().unwrap(), - }, - ).await.unwrap(); + }) + .await + .unwrap(); }); async_std::task::block_on(async move { let transport = TcpConfig::new(); let socket = transport.dial(rx.await.unwrap()).unwrap().await.unwrap(); - let info = apply_outbound( - socket, - IdentifyProtocol, - upgrade::Version::V1 - ).await.unwrap(); - assert_eq!(info.observed_addr, "/ip4/100.101.102.103/tcp/5000".parse().unwrap()); + let info = apply_outbound(socket, IdentifyProtocol, upgrade::Version::V1) + .await + .unwrap(); + assert_eq!( + info.observed_addr, + "/ip4/100.101.102.103/tcp/5000".parse().unwrap() + ); assert_eq!(info.public_key, recv_pubkey); assert_eq!(info.protocol_version, "proto_version"); assert_eq!(info.agent_version, "agent_version"); - assert_eq!(info.listen_addrs, - &["/ip4/80.81.82.83/tcp/500".parse().unwrap(), - "/ip6/::1/udp/1000".parse().unwrap()]); - assert_eq!(info.protocols, &["proto1".to_string(), "proto2".to_string()]); + 
assert_eq!( + info.listen_addrs, + &[ + "/ip4/80.81.82.83/tcp/500".parse().unwrap(), + "/ip6/::1/udp/1000".parse().unwrap() + ] + ); + assert_eq!( + info.protocols, + &["proto1".to_string(), "proto2".to_string()] + ); bg_task.await; }); diff --git a/protocols/kad/build.rs b/protocols/kad/build.rs index abae8bdd169..f05e9e03190 100644 --- a/protocols/kad/build.rs +++ b/protocols/kad/build.rs @@ -19,6 +19,5 @@ // DEALINGS IN THE SOFTWARE. fn main() { - prost_build::compile_protos(&["src/dht.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/dht.proto"], &["src"]).unwrap(); } - diff --git a/protocols/kad/src/addresses.rs b/protocols/kad/src/addresses.rs index b0106a6f83d..f5bdd4d0fbc 100644 --- a/protocols/kad/src/addresses.rs +++ b/protocols/kad/src/addresses.rs @@ -65,9 +65,9 @@ impl Addresses { /// /// An address should only be removed if is determined to be invalid or /// otherwise unreachable. - pub fn remove(&mut self, addr: &Multiaddr) -> Result<(),()> { + pub fn remove(&mut self, addr: &Multiaddr) -> Result<(), ()> { if self.addrs.len() == 1 { - return Err(()) + return Err(()); } if let Some(pos) = self.addrs.iter().position(|a| a == addr) { @@ -100,7 +100,7 @@ impl Addresses { pub fn replace(&mut self, old: &Multiaddr, new: &Multiaddr) -> bool { if let Some(a) = self.addrs.iter_mut().find(|a| *a == old) { *a = new.clone(); - return true + return true; } false @@ -109,8 +109,6 @@ impl Addresses { impl fmt::Debug for Addresses { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list() - .entries(self.addrs.iter()) - .finish() + f.debug_list().entries(self.addrs.iter()).finish() } } diff --git a/protocols/kad/src/behaviour.rs b/protocols/kad/src/behaviour.rs index b0a2f3ee8db..4d7e3207334 100644 --- a/protocols/kad/src/behaviour.rs +++ b/protocols/kad/src/behaviour.rs @@ -22,37 +22,37 @@ mod test; -use crate::K_VALUE; use crate::addresses::Addresses; use crate::handler::{ - KademliaHandlerProto, - KademliaHandlerConfig, + KademliaHandlerConfig, KademliaHandlerEvent, KademliaHandlerIn, KademliaHandlerProto, KademliaRequestId, - KademliaHandlerEvent, - KademliaHandlerIn }; use crate::jobs::*; use crate::kbucket::{self, Distance, KBucketsTable, NodeStatus}; -use crate::protocol::{KademliaProtocolConfig, KadConnectionType, KadPeer}; -use crate::query::{Query, QueryId, QueryPool, QueryConfig, QueryPoolState}; -use crate::record::{self, store::{self, RecordStore}, Record, ProviderRecord}; +use crate::protocol::{KadConnectionType, KadPeer, KademliaProtocolConfig}; +use crate::query::{Query, QueryConfig, QueryId, QueryPool, QueryPoolState}; +use crate::record::{ + self, + store::{self, RecordStore}, + ProviderRecord, Record, +}; +use crate::K_VALUE; use fnv::{FnvHashMap, FnvHashSet}; -use libp2p_core::{ConnectedPoint, Multiaddr, PeerId, connection::{ConnectionId, ListenerId}}; +use libp2p_core::{ + connection::{ConnectionId, ListenerId}, + ConnectedPoint, Multiaddr, PeerId, +}; use libp2p_swarm::{ - DialPeerCondition, - NetworkBehaviour, - NetworkBehaviourAction, - NotifyHandler, - PollParameters, + DialPeerCondition, NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, PollParameters, }; -use log::{info, debug, warn}; +use log::{debug, info, warn}; use smallvec::SmallVec; -use std::{borrow::Cow, error, time::Duration}; -use std::collections::{HashSet, VecDeque, BTreeMap}; +use std::collections::{BTreeMap, HashSet, VecDeque}; use std::fmt; use std::num::NonZeroUsize; use std::task::{Context, Poll}; use std::vec; +use std::{borrow::Cow, error, 
time::Duration}; use wasm_timer::Instant; pub use crate::query::QueryStats; @@ -356,7 +356,7 @@ impl KademliaConfig { impl Kademlia where - for<'a> TStore: RecordStore<'a> + for<'a> TStore: RecordStore<'a>, { /// Creates a new `Kademlia` network behaviour with a default configuration. pub fn new(id: PeerId, store: TStore) -> Self { @@ -375,12 +375,14 @@ where let put_record_job = config .record_replication_interval .or(config.record_publication_interval) - .map(|interval| PutRecordJob::new( - id, - interval, - config.record_publication_interval, - config.record_ttl, - )); + .map(|interval| { + PutRecordJob::new( + id, + interval, + config.record_publication_interval, + config.record_ttl, + ) + }); let add_provider_job = config .provider_publication_interval @@ -406,42 +408,46 @@ where /// Gets an iterator over immutable references to all running queries. pub fn iter_queries(&self) -> impl Iterator> { - self.queries.iter().filter_map(|query| + self.queries.iter().filter_map(|query| { if !query.is_finished() { Some(QueryRef { query }) } else { None - }) + } + }) } /// Gets an iterator over mutable references to all running queries. pub fn iter_queries_mut(&mut self) -> impl Iterator> { - self.queries.iter_mut().filter_map(|query| + self.queries.iter_mut().filter_map(|query| { if !query.is_finished() { Some(QueryMut { query }) } else { None - }) + } + }) } /// Gets an immutable reference to a running query, if it exists. pub fn query(&self, id: &QueryId) -> Option> { - self.queries.get(id).and_then(|query| + self.queries.get(id).and_then(|query| { if !query.is_finished() { Some(QueryRef { query }) } else { None - }) + } + }) } /// Gets a mutable reference to a running query, if it exists. pub fn query_mut<'a>(&'a mut self, id: &QueryId) -> Option> { - self.queries.get_mut(id).and_then(|query| + self.queries.get_mut(id).and_then(|query| { if !query.is_finished() { Some(QueryMut { query }) } else { None - }) + } + }) } /// Adds a known listen address of a peer participating in the DHT to the @@ -466,18 +472,20 @@ where match self.kbuckets.entry(&key) { kbucket::Entry::Present(mut entry, _) => { if entry.value().insert(address) { - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::RoutingUpdated { - peer: *peer, - is_new_peer: false, - addresses: entry.value().clone(), - old_peer: None, - bucket_range: self.kbuckets - .bucket(&key) - .map(|b| b.range()) - .expect("Not kbucket::Entry::SelfEntry."), - } - )) + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::RoutingUpdated { + peer: *peer, + is_new_peer: false, + addresses: entry.value().clone(), + old_peer: None, + bucket_range: self + .kbuckets + .bucket(&key) + .map(|b| b.range()) + .expect("Not kbucket::Entry::SelfEntry."), + }, + )) } RoutingUpdate::Success } @@ -487,41 +495,43 @@ where } kbucket::Entry::Absent(entry) => { let addresses = Addresses::new(address); - let status = - if self.connected_peers.contains(peer) { - NodeStatus::Connected - } else { - NodeStatus::Disconnected - }; + let status = if self.connected_peers.contains(peer) { + NodeStatus::Connected + } else { + NodeStatus::Disconnected + }; match entry.insert(addresses.clone(), status) { kbucket::InsertResult::Inserted => { - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::RoutingUpdated { - peer: *peer, - is_new_peer: true, - addresses, - old_peer: None, - bucket_range: self.kbuckets - .bucket(&key) - .map(|b| b.range()) - .expect("Not kbucket::Entry::SelfEntry."), - 
} - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::RoutingUpdated { + peer: *peer, + is_new_peer: true, + addresses, + old_peer: None, + bucket_range: self + .kbuckets + .bucket(&key) + .map(|b| b.range()) + .expect("Not kbucket::Entry::SelfEntry."), + }, + )); RoutingUpdate::Success - }, + } kbucket::InsertResult::Full => { debug!("Bucket full. Peer not added to routing table: {}", peer); RoutingUpdate::Failed - }, + } kbucket::InsertResult::Pending { disconnected } => { - self.queued_events.push_back(NetworkBehaviourAction::DialPeer { - peer_id: disconnected.into_preimage(), - condition: DialPeerCondition::Disconnected - }); + self.queued_events + .push_back(NetworkBehaviourAction::DialPeer { + peer_id: disconnected.into_preimage(), + condition: DialPeerCondition::Disconnected, + }); RoutingUpdate::Pending - }, + } } - }, + } kbucket::Entry::SelfEntry => RoutingUpdate::Failed, } } @@ -536,9 +546,11 @@ where /// /// If the given peer or address is not in the routing table, /// this is a no-op. - pub fn remove_address(&mut self, peer: &PeerId, address: &Multiaddr) - -> Option, Addresses>> - { + pub fn remove_address( + &mut self, + peer: &PeerId, + address: &Multiaddr, + ) -> Option, Addresses>> { let key = kbucket::Key::from(*peer); match self.kbuckets.entry(&key) { kbucket::Entry::Present(mut entry, _) => { @@ -555,9 +567,7 @@ where None } } - kbucket::Entry::Absent(..) | kbucket::Entry::SelfEntry => { - None - } + kbucket::Entry::Absent(..) | kbucket::Entry::SelfEntry => None, } } @@ -565,37 +575,34 @@ where /// /// Returns `None` if the peer was not in the routing table, /// not even pending insertion. - pub fn remove_peer(&mut self, peer: &PeerId) - -> Option, Addresses>> - { + pub fn remove_peer( + &mut self, + peer: &PeerId, + ) -> Option, Addresses>> { let key = kbucket::Key::from(*peer); match self.kbuckets.entry(&key) { - kbucket::Entry::Present(entry, _) => { - Some(entry.remove()) - } - kbucket::Entry::Pending(entry, _) => { - Some(entry.remove()) - } - kbucket::Entry::Absent(..) | kbucket::Entry::SelfEntry => { - None - } + kbucket::Entry::Present(entry, _) => Some(entry.remove()), + kbucket::Entry::Pending(entry, _) => Some(entry.remove()), + kbucket::Entry::Absent(..) | kbucket::Entry::SelfEntry => None, } } /// Returns an iterator over all non-empty buckets in the routing table. - pub fn kbuckets(&mut self) - -> impl Iterator, Addresses>> - { + pub fn kbuckets( + &mut self, + ) -> impl Iterator, Addresses>> { self.kbuckets.iter().filter(|b| !b.is_empty()) } /// Returns the k-bucket for the distance to the given key. /// /// Returns `None` if the given key refers to the local key. - pub fn kbucket(&mut self, key: K) - -> Option, Addresses>> + pub fn kbucket( + &mut self, + key: K, + ) -> Option, Addresses>> where - K: Into> + Clone + K: Into> + Clone, { self.kbuckets.bucket(&key.into()) } @@ -606,9 +613,11 @@ where /// [`KademliaEvent::OutboundQueryCompleted{QueryResult::GetClosestPeers}`]. 
pub fn get_closest_peers(&mut self, key: K) -> QueryId where - K: Into> + Into> + Clone + K: Into> + Into> + Clone, { - let info = QueryInfo::GetClosestPeers { key: key.clone().into() }; + let info = QueryInfo::GetClosestPeers { + key: key.clone().into(), + }; let target: kbucket::Key = key.into(); let peers = self.kbuckets.closest_keys(&target); let inner = QueryInner::new(info); @@ -627,7 +636,10 @@ where if record.is_expired(Instant::now()) { self.store.remove(key) } else { - records.push(PeerRecord{ peer: None, record: record.into_owned()}); + records.push(PeerRecord { + peer: None, + record: record.into_owned(), + }); } } @@ -669,11 +681,16 @@ where /// does not update the record's expiration in local storage, thus a given record /// with an explicit expiration will always expire at that instant and until then /// is subject to regular (re-)replication and (re-)publication. - pub fn put_record(&mut self, mut record: Record, quorum: Quorum) -> Result { + pub fn put_record( + &mut self, + mut record: Record, + quorum: Quorum, + ) -> Result { record.publisher = Some(*self.kbuckets.local_key().preimage()); self.store.put(record.clone())?; - record.expires = record.expires.or_else(|| - self.record_ttl.map(|ttl| Instant::now() + ttl)); + record.expires = record + .expires + .or_else(|| self.record_ttl.map(|ttl| Instant::now() + ttl)); let quorum = quorum.eval(self.queries.config().replication_factor); let target = kbucket::Key::new(record.key.clone()); let peers = self.kbuckets.closest_keys(&target); @@ -682,7 +699,7 @@ where context, record, quorum, - phase: PutRecordPhase::GetClosestPeers + phase: PutRecordPhase::GetClosestPeers, }; let inner = QueryInner::new(info); Ok(self.queries.add_iter_closest(target.clone(), peers, inner)) @@ -710,7 +727,7 @@ where /// > caching or for other reasons. pub fn put_record_to(&mut self, mut record: Record, peers: I, quorum: Quorum) -> QueryId where - I: ExactSizeIterator + I: ExactSizeIterator, { let quorum = if peers.len() > 0 { quorum.eval(NonZeroUsize::new(peers.len()).expect("> 0")) @@ -720,8 +737,9 @@ where // introducing a new kind of error. 
NonZeroUsize::new(1).expect("1 > 0") }; - record.expires = record.expires.or_else(|| - self.record_ttl.map(|ttl| Instant::now() + ttl)); + record.expires = record + .expires + .or_else(|| self.record_ttl.map(|ttl| Instant::now() + ttl)); let context = PutRecordContext::Custom; let info = QueryInfo::PutRecord { context, @@ -729,8 +747,8 @@ where quorum, phase: PutRecordPhase::PutRecord { success: Vec::new(), - get_closest_peers_stats: QueryStats::empty() - } + get_closest_peers_stats: QueryStats::empty(), + }, }; let inner = QueryInner::new(info); self.queries.add_fixed(peers, inner) @@ -781,7 +799,7 @@ where let local_key = self.kbuckets.local_key().clone(); let info = QueryInfo::Bootstrap { peer: *local_key.preimage(), - remaining: None + remaining: None, }; let peers = self.kbuckets.closest_keys(&local_key).collect::>(); if peers.is_empty() { @@ -822,7 +840,8 @@ where let record = ProviderRecord::new( key.clone(), *self.kbuckets.local_key().preimage(), - local_addrs); + local_addrs, + ); self.store.add_provider(record)?; let target = kbucket::Key::new(key.clone()); let peers = self.kbuckets.closest_keys(&target); @@ -830,7 +849,7 @@ where let info = QueryInfo::AddProvider { context, key, - phase: AddProviderPhase::GetClosestPeers + phase: AddProviderPhase::GetClosestPeers, }; let inner = QueryInner::new(info); let id = self.queries.add_iter_closest(target.clone(), peers, inner); @@ -842,7 +861,8 @@ where /// This is a local operation. The local node will still be considered as a /// provider for the key by other nodes until these provider records expire. pub fn stop_providing(&mut self, key: &record::Key) { - self.store.remove_provider(key, self.kbuckets.local_key().preimage()); + self.store + .remove_provider(key, self.kbuckets.local_key().preimage()); } /// Performs a lookup for providers of a value to the given key. @@ -863,15 +883,19 @@ where /// Processes discovered peers from a successful request in an iterative `Query`. fn discovered<'a, I>(&'a mut self, query_id: &QueryId, source: &PeerId, peers: I) where - I: Iterator + Clone + I: Iterator + Clone, { let local_id = self.kbuckets.local_key().preimage(); let others_iter = peers.filter(|p| &p.node_id != local_id); if let Some(query) = self.queries.get_mut(query_id) { log::trace!("Request to {:?} in query {:?} succeeded.", source, query_id); for peer in others_iter.clone() { - log::trace!("Peer {:?} reported by {:?} in query {:?}.", - peer, source, query_id); + log::trace!( + "Peer {:?} reported by {:?} in query {:?}.", + peer, + source, + query_id + ); let addrs = peer.multiaddrs.iter().cloned().collect(); query.inner.addresses.insert(peer.node_id, addrs); } @@ -882,7 +906,11 @@ where /// Finds the closest peers to a `target` in the context of a request by /// the `source` peer, such that the `source` peer is never included in the /// result. 
- fn find_closest(&mut self, target: &kbucket::Key, source: &PeerId) -> Vec { + fn find_closest( + &mut self, + target: &kbucket::Key, + source: &PeerId, + ) -> Vec { if target == self.kbuckets.local_key() { Vec::new() } else { @@ -900,9 +928,10 @@ where let kbuckets = &mut self.kbuckets; let connected = &mut self.connected_peers; let local_addrs = &self.local_addrs; - self.store.providers(key) + self.store + .providers(key) .into_iter() - .filter_map(move |p| + .filter_map(move |p| { if &p.provider != source { let node_id = p.provider; let multiaddrs = p.addresses; @@ -922,21 +951,23 @@ where Some(local_addrs.iter().cloned().collect::>()) } else { let key = kbucket::Key::from(node_id); - kbuckets.entry(&key).view().map(|e| e.node.value.clone().into_vec()) + kbuckets + .entry(&key) + .view() + .map(|e| e.node.value.clone().into_vec()) } } else { Some(multiaddrs) } - .map(|multiaddrs| { - KadPeer { - node_id, - multiaddrs, - connection_ty, - } + .map(|multiaddrs| KadPeer { + node_id, + multiaddrs, + connection_ty, }) } else { None - }) + } + }) .take(self.queries.config().replication_factor.get()) .collect() } @@ -946,7 +977,7 @@ where let info = QueryInfo::AddProvider { context, key: key.clone(), - phase: AddProviderPhase::GetClosestPeers + phase: AddProviderPhase::GetClosestPeers, }; let target = kbucket::Key::new(key); let peers = self.kbuckets.closest_keys(&target); @@ -960,14 +991,22 @@ where let target = kbucket::Key::new(record.key.clone()); let peers = self.kbuckets.closest_keys(&target); let info = QueryInfo::PutRecord { - record, quorum, context, phase: PutRecordPhase::GetClosestPeers + record, + quorum, + context, + phase: PutRecordPhase::GetClosestPeers, }; let inner = QueryInner::new(info); self.queries.add_iter_closest(target.clone(), peers, inner); } /// Updates the routing table with a new connection status and address of a peer. - fn connection_updated(&mut self, peer: PeerId, address: Option, new_status: NodeStatus) { + fn connection_updated( + &mut self, + peer: PeerId, + address: Option, + new_status: NodeStatus, + ) { let key = kbucket::Key::from(peer); match self.kbuckets.entry(&key) { kbucket::Entry::Present(mut entry, old_status) => { @@ -976,21 +1015,23 @@ where } if let Some(address) = address { if entry.value().insert(address) { - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::RoutingUpdated { - peer, - is_new_peer: false, - addresses: entry.value().clone(), - old_peer: None, - bucket_range: self.kbuckets - .bucket(&key) - .map(|b| b.range()) - .expect("Not kbucket::Entry::SelfEntry."), - } - )) + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::RoutingUpdated { + peer, + is_new_peer: false, + addresses: entry.value().clone(), + old_peer: None, + bucket_range: self + .kbuckets + .bucket(&key) + .map(|b| b.range()) + .expect("Not kbucket::Entry::SelfEntry."), + }, + )) } } - }, + } kbucket::Entry::Pending(mut entry, old_status) => { if let Some(address) = address { @@ -999,23 +1040,25 @@ where if old_status != new_status { entry.update(new_status); } - }, + } kbucket::Entry::Absent(entry) => { // Only connected nodes with a known address are newly inserted. 
if new_status != NodeStatus::Connected { - return + return; } match (address, self.kbucket_inserts) { (None, _) => { - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::UnroutablePeer { peer } - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::UnroutablePeer { peer }, + )); } (Some(a), KademliaBucketInserts::Manual) => { - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::RoutablePeer { peer, address: a } - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::RoutablePeer { peer, address: a }, + )); } (Some(a), KademliaBucketInserts::OnConnected) => { let addresses = Addresses::new(a); @@ -1026,26 +1069,31 @@ where is_new_peer: true, addresses, old_peer: None, - bucket_range: self.kbuckets + bucket_range: self + .kbuckets .bucket(&key) .map(|b| b.range()) .expect("Not kbucket::Entry::SelfEntry."), }; - self.queued_events.push_back( - NetworkBehaviourAction::GenerateEvent(event)); - }, + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent(event)); + } kbucket::InsertResult::Full => { debug!("Bucket full. Peer not added to routing table: {}", peer); let address = addresses.first().clone(); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::RoutablePeer { peer, address } - )); - }, + self.queued_events.push_back( + NetworkBehaviourAction::GenerateEvent( + KademliaEvent::RoutablePeer { peer, address }, + ), + ); + } kbucket::InsertResult::Pending { disconnected } => { let address = addresses.first().clone(); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::PendingRoutablePeer { peer, address } - )); + self.queued_events.push_back( + NetworkBehaviourAction::GenerateEvent( + KademliaEvent::PendingRoutablePeer { peer, address }, + ), + ); // `disconnected` might already be in the process of re-connecting. // In other words `disconnected` might have already re-connected but @@ -1054,24 +1102,27 @@ where // // Only try dialing peer if not currently connected. if !self.connected_peers.contains(disconnected.preimage()) { - self.queued_events.push_back(NetworkBehaviourAction::DialPeer { - peer_id: disconnected.into_preimage(), - condition: DialPeerCondition::Disconnected - }) + self.queued_events + .push_back(NetworkBehaviourAction::DialPeer { + peer_id: disconnected.into_preimage(), + condition: DialPeerCondition::Disconnected, + }) } - }, + } } } } - }, + } _ => {} } } /// Handles a finished (i.e. successful) query. - fn query_finished(&mut self, q: Query, params: &mut impl PollParameters) - -> Option - { + fn query_finished( + &mut self, + q: Query, + params: &mut impl PollParameters, + ) -> Option { let query_id = q.id(); log::trace!("Query {:?} finished.", query_id); let result = q.into_result(); @@ -1084,7 +1135,8 @@ where // a bucket refresh should be performed for every bucket farther away than // the first non-empty bucket (which are most likely no more than the last // few, i.e. farthest, buckets). - self.kbuckets.iter() + self.kbuckets + .iter() .skip_while(|b| b.is_empty()) .skip(1) // Skip the bucket with the closest neighbour. .map(|b| { @@ -1102,7 +1154,7 @@ where // Pr(bucket-252) = 1 - (15/16)^16 ~= 0.64 // ... let mut target = kbucket::Key::from(PeerId::random()); - for _ in 0 .. 
16 { + for _ in 0..16 { let d = local_key.distance(&target); if b.contains(&d) { break; @@ -1110,7 +1162,9 @@ where target = kbucket::Key::from(PeerId::random()); } target - }).collect::>().into_iter() + }) + .collect::>() + .into_iter() }); let num_remaining = remaining.len() as u32; @@ -1118,48 +1172,49 @@ where if let Some(target) = remaining.next() { let info = QueryInfo::Bootstrap { peer: target.clone().into_preimage(), - remaining: Some(remaining) + remaining: Some(remaining), }; let peers = self.kbuckets.closest_keys(&target); let inner = QueryInner::new(info); - self.queries.continue_iter_closest(query_id, target.clone(), peers, inner); + self.queries + .continue_iter_closest(query_id, target.clone(), peers, inner); } Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::Bootstrap(Ok(BootstrapOk { peer, num_remaining })) + result: QueryResult::Bootstrap(Ok(BootstrapOk { + peer, + num_remaining, + })), }) } - QueryInfo::GetClosestPeers { key, .. } => { - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: result.stats, - result: QueryResult::GetClosestPeers(Ok( - GetClosestPeersOk { key, peers: result.peers.collect() } - )) - }) - } + QueryInfo::GetClosestPeers { key, .. } => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: result.stats, + result: QueryResult::GetClosestPeers(Ok(GetClosestPeersOk { + key, + peers: result.peers.collect(), + })), + }), QueryInfo::GetProviders { key, providers } => { Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::GetProviders(Ok( - GetProvidersOk { - key, - providers, - closest_peers: result.peers.collect() - } - )) + result: QueryResult::GetProviders(Ok(GetProvidersOk { + key, + providers, + closest_peers: result.peers.collect(), + })), }) } QueryInfo::AddProvider { context, key, - phase: AddProviderPhase::GetClosestPeers + phase: AddProviderPhase::GetClosestPeers, } => { let provider_id = *params.local_peer_id(); let external_addresses = params.external_addresses().map(|r| r.addr).collect(); @@ -1169,8 +1224,8 @@ where phase: AddProviderPhase::AddProvider { provider_id, external_addresses, - get_closest_peers_stats: result.stats - } + get_closest_peers_stats: result.stats, + }, }); self.queries.continue_fixed(query_id, result.peers, inner); None @@ -1179,28 +1234,32 @@ where QueryInfo::AddProvider { context, key, - phase: AddProviderPhase::AddProvider { get_closest_peers_stats, .. } - } => { - match context { - AddProviderContext::Publish => { - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: get_closest_peers_stats.merge(result.stats), - result: QueryResult::StartProviding(Ok(AddProviderOk { key })) - }) - } - AddProviderContext::Republish => { - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: get_closest_peers_stats.merge(result.stats), - result: QueryResult::RepublishProvider(Ok(AddProviderOk { key })) - }) - } - } - } + phase: + AddProviderPhase::AddProvider { + get_closest_peers_stats, + .. 
+ }, + } => match context { + AddProviderContext::Publish => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: get_closest_peers_stats.merge(result.stats), + result: QueryResult::StartProviding(Ok(AddProviderOk { key })), + }), + AddProviderContext::Republish => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: get_closest_peers_stats.merge(result.stats), + result: QueryResult::RepublishProvider(Ok(AddProviderOk { key })), + }), + }, - QueryInfo::GetRecord { key, records, quorum, cache_candidates } => { - let results = if records.len() >= quorum.get() { // [not empty] + QueryInfo::GetRecord { + key, + records, + quorum, + cache_candidates, + } => { + let results = if records.len() >= quorum.get() { + // [not empty] if quorum.get() == 1 && !cache_candidates.is_empty() { // Cache the record at the closest node(s) to the key that // did not return the record. @@ -1213,25 +1272,33 @@ where quorum, phase: PutRecordPhase::PutRecord { success: vec![], - get_closest_peers_stats: QueryStats::empty() - } + get_closest_peers_stats: QueryStats::empty(), + }, }; let inner = QueryInner::new(info); - self.queries.add_fixed(cache_candidates.values().copied(), inner); + self.queries + .add_fixed(cache_candidates.values().copied(), inner); } - Ok(GetRecordOk { records, cache_candidates }) + Ok(GetRecordOk { + records, + cache_candidates, + }) } else if records.is_empty() { Err(GetRecordError::NotFound { key, - closest_peers: result.peers.collect() + closest_peers: result.peers.collect(), }) } else { - Err(GetRecordError::QuorumFailed { key, records, quorum }) + Err(GetRecordError::QuorumFailed { + key, + records, + quorum, + }) }; Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::GetRecord(results) + result: QueryResult::GetRecord(results), }) } @@ -1239,7 +1306,7 @@ where context, record, quorum, - phase: PutRecordPhase::GetClosestPeers + phase: PutRecordPhase::GetClosestPeers, } => { let info = QueryInfo::PutRecord { context, @@ -1247,8 +1314,8 @@ where quorum, phase: PutRecordPhase::PutRecord { success: vec![], - get_closest_peers_stats: result.stats - } + get_closest_peers_stats: result.stats, + }, }; let inner = QueryInner::new(info); self.queries.continue_fixed(query_id, result.peers, inner); @@ -1259,28 +1326,36 @@ where context, record, quorum, - phase: PutRecordPhase::PutRecord { success, get_closest_peers_stats } + phase: + PutRecordPhase::PutRecord { + success, + get_closest_peers_stats, + }, } => { let mk_result = |key: record::Key| { if success.len() >= quorum.get() { Ok(PutRecordOk { key }) } else { - Err(PutRecordError::QuorumFailed { key, quorum, success }) + Err(PutRecordError::QuorumFailed { + key, + quorum, + success, + }) } }; match context { - PutRecordContext::Publish | PutRecordContext::Custom => - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: get_closest_peers_stats.merge(result.stats), - result: QueryResult::PutRecord(mk_result(record.key)) - }), - PutRecordContext::Republish => + PutRecordContext::Publish | PutRecordContext::Custom => { Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: get_closest_peers_stats.merge(result.stats), - result: QueryResult::RepublishRecord(mk_result(record.key)) - }), + result: QueryResult::PutRecord(mk_result(record.key)), + }) + } + PutRecordContext::Republish => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: get_closest_peers_stats.merge(result.stats), + result: 
QueryResult::RepublishRecord(mk_result(record.key)), + }), PutRecordContext::Replicate => { debug!("Record replicated: {:?}", record.key); None @@ -1300,7 +1375,10 @@ where log::trace!("Query {:?} timed out.", query_id); let result = query.into_result(); match result.inner.info { - QueryInfo::Bootstrap { peer, mut remaining } => { + QueryInfo::Bootstrap { + peer, + mut remaining, + } => { let num_remaining = remaining.as_ref().map(|r| r.len().saturating_sub(1) as u32); if let Some(mut remaining) = remaining.take() { @@ -1308,78 +1386,74 @@ where if let Some(target) = remaining.next() { let info = QueryInfo::Bootstrap { peer: target.clone().into_preimage(), - remaining: Some(remaining) + remaining: Some(remaining), }; let peers = self.kbuckets.closest_keys(&target); let inner = QueryInner::new(info); - self.queries.continue_iter_closest(query_id, target.clone(), peers, inner); + self.queries + .continue_iter_closest(query_id, target.clone(), peers, inner); } } Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::Bootstrap(Err( - BootstrapError::Timeout { peer, num_remaining } - )) + result: QueryResult::Bootstrap(Err(BootstrapError::Timeout { + peer, + num_remaining, + })), }) } - QueryInfo::AddProvider { context, key, .. } => - Some(match context { - AddProviderContext::Publish => - KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: result.stats, - result: QueryResult::StartProviding(Err( - AddProviderError::Timeout { key } - )) - }, - AddProviderContext::Republish => - KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: result.stats, - result: QueryResult::RepublishProvider(Err( - AddProviderError::Timeout { key } - )) - } - }), - - QueryInfo::GetClosestPeers { key } => { - Some(KademliaEvent::OutboundQueryCompleted { + QueryInfo::AddProvider { context, key, .. } => Some(match context { + AddProviderContext::Publish => KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::GetClosestPeers(Err( - GetClosestPeersError::Timeout { - key, - peers: result.peers.collect() - } - )) - }) - }, + result: QueryResult::StartProviding(Err(AddProviderError::Timeout { key })), + }, + AddProviderContext::Republish => KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: result.stats, + result: QueryResult::RepublishProvider(Err(AddProviderError::Timeout { key })), + }, + }), + + QueryInfo::GetClosestPeers { key } => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: result.stats, + result: QueryResult::GetClosestPeers(Err(GetClosestPeersError::Timeout { + key, + peers: result.peers.collect(), + })), + }), - QueryInfo::PutRecord { record, quorum, context, phase } => { + QueryInfo::PutRecord { + record, + quorum, + context, + phase, + } => { let err = Err(PutRecordError::Timeout { key: record.key, quorum, success: match phase { PutRecordPhase::GetClosestPeers => vec![], PutRecordPhase::PutRecord { ref success, .. 
} => success.clone(), - } + }, }); match context { - PutRecordContext::Publish | PutRecordContext::Custom => + PutRecordContext::Publish | PutRecordContext::Custom => { Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::PutRecord(err) - }), - PutRecordContext::Republish => - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: result.stats, - result: QueryResult::RepublishRecord(err) - }), + result: QueryResult::PutRecord(err), + }) + } + PutRecordContext::Republish => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: result.stats, + result: QueryResult::RepublishRecord(err), + }), PutRecordContext::Replicate => match phase { PutRecordPhase::GetClosestPeers => { warn!("Locating closest peers for replication failed: {:?}", err); @@ -1389,7 +1463,7 @@ where debug!("Replicating record failed: {:?}", err); None } - } + }, PutRecordContext::Cache => match phase { PutRecordPhase::GetClosestPeers => { // Caching a record at the closest peer to a key that did not return @@ -1401,32 +1475,37 @@ where debug!("Caching record failed: {:?}", err); None } - } + }, } } - QueryInfo::GetRecord { key, records, quorum, .. } => - Some(KademliaEvent::OutboundQueryCompleted { - id: query_id, - stats: result.stats, - result: QueryResult::GetRecord(Err( - GetRecordError::Timeout { key, records, quorum }, - )) - }), + QueryInfo::GetRecord { + key, + records, + quorum, + .. + } => Some(KademliaEvent::OutboundQueryCompleted { + id: query_id, + stats: result.stats, + result: QueryResult::GetRecord(Err(GetRecordError::Timeout { + key, + records, + quorum, + })), + }), - QueryInfo::GetProviders { key, providers } => + QueryInfo::GetProviders { key, providers } => { Some(KademliaEvent::OutboundQueryCompleted { id: query_id, stats: result.stats, - result: QueryResult::GetProviders(Err( - GetProvidersError::Timeout { - key, - providers, - closest_peers: result.peers.collect() - } - )) + result: QueryResult::GetProviders(Err(GetProvidersError::Timeout { + key, + providers, + closest_peers: result.peers.collect(), + })), }) } + } } /// Processes a record received from a peer. @@ -1435,22 +1514,23 @@ where source: PeerId, connection: ConnectionId, request_id: KademliaRequestId, - mut record: Record + mut record: Record, ) { if record.publisher.as_ref() == Some(self.kbuckets.local_key().preimage()) { // If the (alleged) publisher is the local node, do nothing. The record of // the original publisher should never change as a result of replication // and the publisher is always assumed to have the "right" value. - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: source, - handler: NotifyHandler::One(connection), - event: KademliaHandlerIn::PutRecordRes { - key: record.key, - value: record.value, - request_id, - }, - }); - return + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: source, + handler: NotifyHandler::One(connection), + event: KademliaHandlerIn::PutRecordRes { + key: record.key, + value: record.value, + request_id, + }, + }); + return; } let now = Instant::now(); @@ -1463,7 +1543,9 @@ where let num_between = self.kbuckets.count_nodes_between(&target); let k = self.queries.config().replication_factor.get(); let num_beyond_k = (usize::max(k, num_between) - k) as u32; - let expiration = self.record_ttl.map(|ttl| now + exp_decrease(ttl, num_beyond_k)); + let expiration = self + .record_ttl + .map(|ttl| now + exp_decrease(ttl, num_beyond_k)); // The smaller TTL prevails. 
Only if neither TTL is set is the record // stored "forever". record.expires = record.expires.or(expiration).min(expiration); @@ -1491,16 +1573,21 @@ where // requirement to send back the value in the response, although this // is a waste of resources. match self.store.put(record.clone()) { - Ok(()) => debug!("Record stored: {:?}; {} bytes", record.key, record.value.len()), + Ok(()) => debug!( + "Record stored: {:?}; {} bytes", + record.key, + record.value.len() + ), Err(e) => { info!("Record not stored: {:?}", e); - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: source, - handler: NotifyHandler::One(connection), - event: KademliaHandlerIn::Reset(request_id) - }); + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: source, + handler: NotifyHandler::One(connection), + event: KademliaHandlerIn::Reset(request_id), + }); - return + return; } } } @@ -1512,15 +1599,16 @@ where // closest nodes to the target. In addition returning // [`KademliaHandlerIn::PutRecordRes`] does not reveal any internal // information to a possibly malicious remote node. - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: source, - handler: NotifyHandler::One(connection), - event: KademliaHandlerIn::PutRecordRes { - key: record.key, - value: record.value, - request_id, - }, - }) + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: source, + handler: NotifyHandler::One(connection), + event: KademliaHandlerIn::PutRecordRes { + key: record.key, + value: record.value, + request_id, + }, + }) } /// Processes a provider record received from a peer. @@ -1593,14 +1681,19 @@ where fn inject_connected(&mut self, peer: &PeerId) { // Queue events for sending pending RPCs to the connected peer. // There can be only one pending RPC for a particular peer and query per definition. - for (peer_id, event) in self.queries.iter_mut().filter_map(|q| - q.inner.pending_rpcs.iter() + for (peer_id, event) in self.queries.iter_mut().filter_map(|q| { + q.inner + .pending_rpcs + .iter() .position(|(p, _)| p == peer) - .map(|p| q.inner.pending_rpcs.remove(p))) - { - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id, event, handler: NotifyHandler::Any - }); + .map(|p| q.inner.pending_rpcs.remove(p)) + }) { + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id, + event, + handler: NotifyHandler::Any, + }); } self.connected_peers.insert(*peer); @@ -1611,14 +1704,17 @@ where peer: &PeerId, _: &ConnectionId, old: &ConnectedPoint, - new: &ConnectedPoint + new: &ConnectedPoint, ) { let (old, new) = (old.get_remote_address(), new.get_remote_address()); // Update routing table. if let Some(addrs) = self.kbuckets.entry(&kbucket::Key::from(*peer)).value() { if addrs.replace(old, new) { - debug!("Address '{}' replaced with '{}' for peer '{}'.", old, new, peer); + debug!( + "Address '{}' replaced with '{}' for peer '{}'.", + old, new, peer + ); } else { debug!( "Address '{}' not replaced with '{}' for peer '{}' as old address wasn't \ @@ -1663,7 +1759,7 @@ where &mut self, peer_id: Option<&PeerId>, addr: &Multiaddr, - err: &dyn error::Error + err: &dyn error::Error, ) { if let Some(peer_id) = peer_id { let key = kbucket::Key::from(*peer_id); @@ -1675,8 +1771,10 @@ where // of the error is not possible (and also not truly desirable or ergonomic). // The error passed in should rather be a dedicated enum. 
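
The TTL handling in `record_received` above shortens a record's lifetime the farther this node is from the k closest peers to the key, and then lets the smaller of the publisher-supplied and locally computed expiry win. A std-only sketch of that rule; the halving-per-node shape assumed for `exp_decrease` is only an illustration, since this patch shows just its call site:

    use std::time::{Duration, Instant};

    // Assumed shape of the helper called above: halve the TTL once for every
    // node beyond the k closest to the target key.
    fn exp_decrease(ttl: Duration, exp: u32) -> Duration {
        Duration::from_secs(ttl.as_secs().checked_shr(exp).unwrap_or(0))
    }

    fn main() {
        let now = Instant::now();
        let record_ttl = Some(Duration::from_secs(36 * 60 * 60)); // hypothetical node config
        let num_beyond_k = 2u32;

        let expiration = record_ttl.map(|ttl| now + exp_decrease(ttl, num_beyond_k));
        // The publisher asked for a longer lifetime than this node will keep.
        let publisher_expires = Some(now + Duration::from_secs(48 * 60 * 60));

        // The smaller TTL prevails; only if neither side sets one is the
        // record kept "forever" (i.e. `expires` stays `None`).
        let expires = publisher_expires.or(expiration).min(expiration);
        assert_eq!(expires, expiration);
    }
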
if addrs.remove(addr).is_ok() { - debug!("Address '{}' removed from peer '{}' due to error: {}.", - addr, peer_id, err); + debug!( + "Address '{}' removed from peer '{}' due to error: {}.", + addr, peer_id, err + ); } else { // Despite apparently having no reachable address (any longer), // the peer is kept in the routing table with the last address to avoid @@ -1687,8 +1785,10 @@ where // into the same bucket. This is handled transparently by the // `KBucketsTable` and takes effect through `KBucketsTable::take_applied_pending` // within `Kademlia::poll`. - debug!("Last remaining address '{}' of peer '{}' is unreachable: {}.", - addr, peer_id, err) + debug!( + "Last remaining address '{}' of peer '{}' is unreachable: {}.", + addr, peer_id, err + ) } } @@ -1718,7 +1818,7 @@ where &mut self, source: PeerId, connection: ConnectionId, - event: KademliaHandlerEvent + event: KademliaHandlerEvent, ) { match event { KademliaHandlerEvent::ProtocolConfirmed { endpoint } => { @@ -1737,20 +1837,24 @@ where KademliaHandlerEvent::FindNodeReq { key, request_id } => { let closer_peers = self.find_closest(&kbucket::Key::new(key), &source); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::InboundRequestServed{ request: InboundRequest::FindNode { - num_closer_peers: closer_peers.len(), - }} - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::InboundRequestServed { + request: InboundRequest::FindNode { + num_closer_peers: closer_peers.len(), + }, + }, + )); - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: source, - handler: NotifyHandler::One(connection), - event: KademliaHandlerIn::FindNodeRes { - closer_peers, - request_id, - }, - }); + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: source, + handler: NotifyHandler::One(connection), + event: KademliaHandlerIn::FindNodeRes { + closer_peers, + request_id, + }, + }); } KademliaHandlerEvent::FindNodeRes { @@ -1764,22 +1868,26 @@ where let provider_peers = self.provider_peers(&key, &source); let closer_peers = self.find_closest(&kbucket::Key::new(key), &source); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::InboundRequestServed{ request: InboundRequest::GetProvider { - num_closer_peers: closer_peers.len(), - num_provider_peers: provider_peers.len(), - }} - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::InboundRequestServed { + request: InboundRequest::GetProvider { + num_closer_peers: closer_peers.len(), + num_provider_peers: provider_peers.len(), + }, + }, + )); - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: source, - handler: NotifyHandler::One(connection), - event: KademliaHandlerIn::GetProvidersRes { - closer_peers, - provider_peers, - request_id, - }, - }); + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: source, + handler: NotifyHandler::One(connection), + event: KademliaHandlerIn::GetProvidersRes { + closer_peers, + provider_peers, + request_id, + }, + }); } KademliaHandlerEvent::GetProvidersRes { @@ -1790,9 +1898,7 @@ where let peers = closer_peers.iter().chain(provider_peers.iter()); self.discovered(&user_data, &source, peers); if let Some(query) = self.queries.get_mut(&user_data) { - if let QueryInfo::GetProviders { - providers, .. - } = &mut query.inner.info { + if let QueryInfo::GetProviders { providers, .. 
} = &mut query.inner.info { for peer in provider_peers { providers.insert(peer.node_id); } @@ -1801,8 +1907,12 @@ where } KademliaHandlerEvent::QueryError { user_data, error } => { - log::debug!("Request to {:?} in query {:?} failed with {:?}", - source, user_data, error); + log::debug!( + "Request to {:?} in query {:?} failed with {:?}", + source, + user_data, + error + ); // If the query to which the error relates is still active, // signal the failure w.r.t. `source`. if let Some(query) = self.queries.get_mut(&user_data) { @@ -1813,14 +1923,17 @@ where KademliaHandlerEvent::AddProvider { key, provider } => { // Only accept a provider record from a legitimate peer. if provider.node_id != source { - return + return; } self.provider_received(key, provider); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::InboundRequestServed{ request: InboundRequest::AddProvider {} } - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::InboundRequestServed { + request: InboundRequest::AddProvider {}, + }, + )); } KademliaHandlerEvent::GetRecord { key, request_id } => { @@ -1833,28 +1946,32 @@ where } else { Some(record.into_owned()) } - }, - None => None + } + None => None, }; let closer_peers = self.find_closest(&kbucket::Key::new(key), &source); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::InboundRequestServed{ request: InboundRequest::GetRecord { - num_closer_peers: closer_peers.len(), - present_locally: record.is_some(), - }} - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::InboundRequestServed { + request: InboundRequest::GetRecord { + num_closer_peers: closer_peers.len(), + present_locally: record.is_some(), + }, + }, + )); - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: source, - handler: NotifyHandler::One(connection), - event: KademliaHandlerIn::GetRecordRes { - record, - closer_peers, - request_id, - }, - }); + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: source, + handler: NotifyHandler::One(connection), + event: KademliaHandlerIn::GetRecordRes { + record, + closer_peers, + request_id, + }, + }); } KademliaHandlerEvent::GetRecordRes { @@ -1864,17 +1981,25 @@ where } => { if let Some(query) = self.queries.get_mut(&user_data) { if let QueryInfo::GetRecord { - key, records, quorum, cache_candidates - } = &mut query.inner.info { + key, + records, + quorum, + cache_candidates, + } = &mut query.inner.info + { if let Some(record) = record { - records.push(PeerRecord{ peer: Some(source), record }); + records.push(PeerRecord { + peer: Some(source), + record, + }); let quorum = quorum.get(); if records.len() >= quorum { // Desired quorum reached. The query may finish. See // [`Query::try_finish`] for details. - let peers = records.iter() - .filter_map(|PeerRecord{ peer, .. }| peer.as_ref()) + let peers = records + .iter() + .filter_map(|PeerRecord { peer, .. }| peer.as_ref()) .cloned() .collect::>(); let finished = query.try_finish(peers.iter()); @@ -1882,7 +2007,10 @@ where debug!( "GetRecord query ({:?}) reached quorum ({}/{}) with \ response from peer {} but could not yet finish.", - user_data, peers.len(), quorum, source, + user_data, + peers.len(), + quorum, + source, ); } } @@ -1896,7 +2024,8 @@ where if cache_candidates.len() > max_peers as usize { // TODO: `pop_last()` would be nice once stabilised. // See https://github.com/rust-lang/rust/issues/62924. 
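
The TODO directly above works around `BTreeMap::pop_last` not being stable yet: the candidate map is kept bounded by removing the entry with the largest key, i.e. the cache candidate farthest from the target. A self-contained sketch of that eviction pattern, with plain integers and strings standing in for the real distance and peer types:

    use std::collections::BTreeMap;

    fn insert_capped(map: &mut BTreeMap<u32, &'static str>, key: u32, value: &'static str, cap: usize) {
        map.insert(key, value);
        if map.len() > cap {
            // Equivalent of `pop_last()`: drop the entry with the largest key,
            // i.e. the candidate farthest from the target.
            let last = *map.keys().next_back().expect("len > 0");
            map.remove(&last);
        }
    }

    fn main() {
        let mut candidates = BTreeMap::new();
        for (distance, peer) in [(7, "a"), (3, "b"), (9, "c"), (1, "d")] {
            insert_capped(&mut candidates, distance, peer, 2);
        }
        // Only the two closest candidates remain.
        assert_eq!(candidates.keys().copied().collect::<Vec<_>>(), vec![1, 3]);
    }
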
- let last = *cache_candidates.keys().next_back().expect("len > 0"); + let last = + *cache_candidates.keys().next_back().expect("len > 0"); cache_candidates.remove(&last); } } @@ -1907,25 +2036,26 @@ where self.discovered(&user_data, &source, closer_peers.iter()); } - KademliaHandlerEvent::PutRecord { - record, - request_id - } => { + KademliaHandlerEvent::PutRecord { record, request_id } => { self.record_received(source, connection, request_id, record); - self.queued_events.push_back(NetworkBehaviourAction::GenerateEvent( - KademliaEvent::InboundRequestServed{ request: InboundRequest::PutRecord {} } - )); + self.queued_events + .push_back(NetworkBehaviourAction::GenerateEvent( + KademliaEvent::InboundRequestServed { + request: InboundRequest::PutRecord {}, + }, + )); } - KademliaHandlerEvent::PutRecordRes { - user_data, .. - } => { + KademliaHandlerEvent::PutRecordRes { user_data, .. } => { if let Some(query) = self.queries.get_mut(&user_data) { query.on_success(&source, vec![]); if let QueryInfo::PutRecord { - phase: PutRecordPhase::PutRecord { success, .. }, quorum, .. - } = &mut query.inner.info { + phase: PutRecordPhase::PutRecord { success, .. }, + quorum, + .. + } = &mut query.inner.info + { success.push(source); let quorum = quorum.get(); @@ -1936,7 +2066,10 @@ where debug!( "PutRecord query ({:?}) reached quorum ({}/{}) with response \ from peer {} but could not yet finish.", - user_data, peers.len(), quorum, source, + user_data, + peers.len(), + quorum, + source, ); } } @@ -1960,12 +2093,11 @@ where } } - fn poll(&mut self, cx: &mut Context<'_>, parameters: &mut impl PollParameters) -> Poll< - NetworkBehaviourAction< - KademliaHandlerIn, - Self::OutEvent, - >, - > { + fn poll( + &mut self, + cx: &mut Context<'_>, + parameters: &mut impl PollParameters, + ) -> Poll, Self::OutEvent>> { let now = Instant::now(); // Calculate the available capacity for queries triggered by background jobs. @@ -1974,11 +2106,11 @@ where // Run the periodic provider announcement job. if let Some(mut job) = self.add_provider_job.take() { let num = usize::min(JOBS_MAX_NEW_QUERIES, jobs_query_capacity); - for _ in 0 .. num { + for _ in 0..num { if let Poll::Ready(r) = job.poll(cx, &mut self.store, now) { self.start_add_provider(r.key, AddProviderContext::Republish) } else { - break + break; } } jobs_query_capacity -= num; @@ -1988,16 +2120,17 @@ where // Run the periodic record replication / publication job. if let Some(mut job) = self.put_record_job.take() { let num = usize::min(JOBS_MAX_NEW_QUERIES, jobs_query_capacity); - for _ in 0 .. 
num { + for _ in 0..num { if let Poll::Ready(r) = job.poll(cx, &mut self.store, now) { - let context = if r.publisher.as_ref() == Some(self.kbuckets.local_key().preimage()) { - PutRecordContext::Republish - } else { - PutRecordContext::Replicate - }; + let context = + if r.publisher.as_ref() == Some(self.kbuckets.local_key().preimage()) { + PutRecordContext::Republish + } else { + PutRecordContext::Replicate + }; self.start_put_record(r, Quorum::All, context) } else { - break + break; } } self.put_record_job = Some(job); @@ -2013,7 +2146,8 @@ where if let Some(entry) = self.kbuckets.take_applied_pending() { let kbucket::Node { key, value } = entry.inserted; let event = KademliaEvent::RoutingUpdated { - bucket_range: self.kbuckets + bucket_range: self + .kbuckets .bucket(&key) .map(|b| b.range()) .expect("Self to never be applied from pending."), @@ -2022,7 +2156,7 @@ where addresses: value, old_peer: entry.evicted.map(|n| n.key.into_preimage()), }; - return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)) + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); } // Look for a finished query. @@ -2030,12 +2164,12 @@ where match self.queries.poll(now) { QueryPoolState::Finished(q) => { if let Some(event) = self.query_finished(q, parameters) { - return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)) + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); } } QueryPoolState::Timeout(q) => { if let Some(event) = self.query_timeout(q) { - return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)) + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)); } } QueryPoolState::Waiting(Some((query, peer_id))) => { @@ -2048,18 +2182,24 @@ where if let QueryInfo::AddProvider { phase: AddProviderPhase::AddProvider { .. }, .. - } = &query.inner.info { + } = &query.inner.info + { query.on_success(&peer_id, vec![]) } if self.connected_peers.contains(&peer_id) { - self.queued_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id, event, handler: NotifyHandler::Any - }); + self.queued_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id, + event, + handler: NotifyHandler::Any, + }); } else if &peer_id != self.kbuckets.local_key().preimage() { query.inner.pending_rpcs.push((peer_id, event)); - self.queued_events.push_back(NetworkBehaviourAction::DialPeer { - peer_id, condition: DialPeerCondition::Disconnected - }); + self.queued_events + .push_back(NetworkBehaviourAction::DialPeer { + peer_id, + condition: DialPeerCondition::Disconnected, + }); } } QueryPoolState::Waiting(None) | QueryPoolState::Idle => break, @@ -2070,7 +2210,7 @@ where // If no new events have been queued either, signal `NotReady` to // be polled again later. if self.queued_events.is_empty() { - return Poll::Pending + return Poll::Pending; } } } @@ -2084,7 +2224,7 @@ pub enum Quorum { One, Majority, All, - N(NonZeroUsize) + N(NonZeroUsize), } impl Quorum { @@ -2094,7 +2234,7 @@ impl Quorum { Quorum::One => NonZeroUsize::new(1).expect("1 != 0"), Quorum::Majority => NonZeroUsize::new(total.get() / 2 + 1).expect("n + 1 != 0"), Quorum::All => total, - Quorum::N(n) => NonZeroUsize::min(total, *n) + Quorum::N(n) => NonZeroUsize::min(total, *n), } } } @@ -2122,9 +2262,7 @@ pub enum KademliaEvent { // Note on the difference between 'request' and 'query': A request is a // single request-response style exchange with a single remote peer. A query // is made of multiple requests across multiple remote peers. 
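
`Quorum::eval`, reformatted a little further up, maps the user-facing quorum setting to the concrete number of peers required, given the total number of peers available to the query. A standalone usage sketch (the enum and `eval` are copied from this hunk so the example runs without the crate):

    use std::num::NonZeroUsize;

    #[derive(Debug, Clone, Copy)]
    enum Quorum {
        One,
        Majority,
        All,
        N(NonZeroUsize),
    }

    impl Quorum {
        // Same mapping as in the hunk above.
        fn eval(&self, total: NonZeroUsize) -> NonZeroUsize {
            match self {
                Quorum::One => NonZeroUsize::new(1).expect("1 != 0"),
                Quorum::Majority => NonZeroUsize::new(total.get() / 2 + 1).expect("n + 1 != 0"),
                Quorum::All => total,
                Quorum::N(n) => NonZeroUsize::min(total, *n),
            }
        }
    }

    fn main() {
        let total = NonZeroUsize::new(20).unwrap();
        assert_eq!(Quorum::One.eval(total).get(), 1);
        assert_eq!(Quorum::Majority.eval(total).get(), 11);
        assert_eq!(Quorum::All.eval(total).get(), 20);
        let n = NonZeroUsize::new(50).unwrap();
        assert_eq!(Quorum::N(n).eval(total).get(), 20); // capped at the total
    }
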
- InboundRequestServed { - request: InboundRequest, - }, + InboundRequestServed { request: InboundRequest }, /// An outbound query has produced a result. OutboundQueryCompleted { @@ -2133,7 +2271,7 @@ pub enum KademliaEvent { /// The result of the query. result: QueryResult, /// Execution statistics from the query. - stats: QueryStats + stats: QueryStats, }, /// The routing table has been updated with a new peer and / or @@ -2158,9 +2296,7 @@ pub enum KademliaEvent { /// /// If the peer is to be added to the routing table, a known /// listen address for the peer must be provided via [`Kademlia::add_address`]. - UnroutablePeer { - peer: PeerId - }, + UnroutablePeer { peer: PeerId }, /// A connection to a peer has been established for whom a listen address /// is known but the peer has not been added to the routing table either @@ -2173,10 +2309,7 @@ pub enum KademliaEvent { /// /// See [`Kademlia::kbucket`] for insight into the contents of /// the k-bucket of `peer`. - RoutablePeer { - peer: PeerId, - address: Multiaddr, - }, + RoutablePeer { peer: PeerId, address: Multiaddr }, /// A connection to a peer has been established for whom a listen address /// is known but the peer is only pending insertion into the routing table @@ -2189,19 +2322,14 @@ pub enum KademliaEvent { /// /// See [`Kademlia::kbucket`] for insight into the contents of /// the k-bucket of `peer`. - PendingRoutablePeer { - peer: PeerId, - address: Multiaddr, - } + PendingRoutablePeer { peer: PeerId, address: Multiaddr }, } /// Information about a received and handled inbound request. #[derive(Debug)] pub enum InboundRequest { /// Request for the list of nodes whose IDs are the closest to `key`. - FindNode { - num_closer_peers: usize, - }, + FindNode { num_closer_peers: usize }, /// Same as `FindNode`, but should also return the entries of the local /// providers list for this key. GetProvider { @@ -2278,18 +2406,18 @@ pub struct GetRecordOk { pub enum GetRecordError { NotFound { key: record::Key, - closest_peers: Vec + closest_peers: Vec, }, QuorumFailed { key: record::Key, records: Vec, - quorum: NonZeroUsize + quorum: NonZeroUsize, }, Timeout { key: record::Key, records: Vec, - quorum: NonZeroUsize - } + quorum: NonZeroUsize, + }, } impl GetRecordError { @@ -2319,7 +2447,7 @@ pub type PutRecordResult = Result; /// The successful result of [`Kademlia::put_record`]. #[derive(Debug, Clone)] pub struct PutRecordOk { - pub key: record::Key + pub key: record::Key, } /// The error result of [`Kademlia::put_record`]. @@ -2329,13 +2457,13 @@ pub enum PutRecordError { key: record::Key, /// [`PeerId`]s of the peers the record was successfully stored on. success: Vec, - quorum: NonZeroUsize + quorum: NonZeroUsize, }, Timeout { key: record::Key, /// [`PeerId`]s of the peers the record was successfully stored on. success: Vec, - quorum: NonZeroUsize + quorum: NonZeroUsize, }, } @@ -2374,7 +2502,7 @@ pub enum BootstrapError { Timeout { peer: PeerId, num_remaining: Option, - } + }, } /// The result of [`Kademlia::get_closest_peers`]. @@ -2384,16 +2512,13 @@ pub type GetClosestPeersResult = Result #[derive(Debug, Clone)] pub struct GetClosestPeersOk { pub key: Vec, - pub peers: Vec + pub peers: Vec, } /// The error result of [`Kademlia::get_closest_peers`]. 
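
The result and error types reformatted above are what callers match on once a `get_record` query completes. A compact, std-only sketch of that control flow, using simplified stand-ins for `GetRecordOk`/`GetRecordError` (the real types carry peer records, keys, quorums and cache candidates rather than plain counts):

    enum GetRecordError {
        NotFound { closest_peers: usize },
        QuorumFailed { got: usize, quorum: usize },
        Timeout { got: usize },
    }

    fn describe(result: Result<usize, GetRecordError>) -> String {
        match result {
            Ok(n) => format!("query succeeded with {} record(s)", n),
            Err(GetRecordError::NotFound { closest_peers }) => {
                format!("record not found; {} closest peers responded", closest_peers)
            }
            Err(GetRecordError::QuorumFailed { got, quorum }) => {
                format!("only {}/{} peers returned the record", got, quorum)
            }
            Err(GetRecordError::Timeout { got }) => {
                format!("query timed out after collecting {} record(s)", got)
            }
        }
    }

    fn main() {
        for outcome in [
            Ok(3),
            Err(GetRecordError::NotFound { closest_peers: 2 }),
            Err(GetRecordError::QuorumFailed { got: 1, quorum: 2 }),
            Err(GetRecordError::Timeout { got: 0 }),
        ] {
            println!("{}", describe(outcome));
        }
    }
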
#[derive(Debug, Clone)] pub enum GetClosestPeersError { - Timeout { - key: Vec, - peers: Vec - } + Timeout { key: Vec, peers: Vec }, } impl GetClosestPeersError { @@ -2421,7 +2546,7 @@ pub type GetProvidersResult = Result; pub struct GetProvidersOk { pub key: record::Key, pub providers: HashSet, - pub closest_peers: Vec + pub closest_peers: Vec, } /// The error result of [`Kademlia::get_providers`]. @@ -2430,8 +2555,8 @@ pub enum GetProvidersError { Timeout { key: record::Key, providers: HashSet, - closest_peers: Vec - } + closest_peers: Vec, + }, } impl GetProvidersError { @@ -2464,9 +2589,7 @@ pub struct AddProviderOk { #[derive(Debug)] pub enum AddProviderError { /// The query timed out. - Timeout { - key: record::Key, - }, + Timeout { key: record::Key }, } impl AddProviderError { @@ -2492,8 +2615,8 @@ impl From, Addresses>> for KadPeer { multiaddrs: e.node.value.into_vec(), connection_ty: match e.status { NodeStatus::Connected => KadConnectionType::Connected, - NodeStatus::Disconnected => KadConnectionType::NotConnected - } + NodeStatus::Disconnected => KadConnectionType::NotConnected, + }, } } } @@ -2510,7 +2633,7 @@ struct QueryInner { /// /// A request is pending if the targeted peer is not currently connected /// and these requests are sent as soon as a connection to the peer is established. - pending_rpcs: SmallVec<[(PeerId, KademliaHandlerIn); K_VALUE.get()]> + pending_rpcs: SmallVec<[(PeerId, KademliaHandlerIn); K_VALUE.get()]>, } impl QueryInner { @@ -2518,7 +2641,7 @@ impl QueryInner { QueryInner { info, addresses: Default::default(), - pending_rpcs: SmallVec::default() + pending_rpcs: SmallVec::default(), } } } @@ -2567,7 +2690,7 @@ pub enum QueryInfo { /// This is `None` if the initial self-lookup has not /// yet completed and `Some` with an exhausted iterator /// if bootstrapping is complete. - remaining: Option>> + remaining: Option>>, }, /// A query initiated by [`Kademlia::get_closest_peers`]. @@ -2639,16 +2762,18 @@ impl QueryInfo { key: key.to_vec(), user_data: query_id, }, - AddProviderPhase::AddProvider { provider_id, external_addresses, .. } => { - KademliaHandlerIn::AddProvider { - key: key.clone(), - provider: crate::protocol::KadPeer { - node_id: *provider_id, - multiaddrs: external_addresses.clone(), - connection_ty: crate::protocol::KadConnectionType::Connected, - } - } - } + AddProviderPhase::AddProvider { + provider_id, + external_addresses, + .. + } => KademliaHandlerIn::AddProvider { + key: key.clone(), + provider: crate::protocol::KadPeer { + node_id: *provider_id, + multiaddrs: external_addresses.clone(), + connection_ty: crate::protocol::KadConnectionType::Connected, + }, + }, }, QueryInfo::GetRecord { key, .. } => KademliaHandlerIn::GetRecord { key: key.clone(), @@ -2661,9 +2786,9 @@ impl QueryInfo { }, PutRecordPhase::PutRecord { .. 
} => KademliaHandlerIn::PutRecord { record: record.clone(), - user_data: query_id - } - } + user_data: query_id, + }, + }, } } } diff --git a/protocols/kad/src/behaviour/test.rs b/protocols/kad/src/behaviour/test.rs index 97b020a88c8..8cbe29a0e83 100644 --- a/protocols/kad/src/behaviour/test.rs +++ b/protocols/kad/src/behaviour/test.rs @@ -22,30 +22,30 @@ use super::*; -use crate::K_VALUE; use crate::kbucket::Distance; -use crate::record::{Key, store::MemoryStore}; -use futures::{ - prelude::*, - executor::block_on, - future::poll_fn, -}; +use crate::record::{store::MemoryStore, Key}; +use crate::K_VALUE; +use futures::{executor::block_on, future::poll_fn, prelude::*}; use futures_timer::Delay; use libp2p_core::{ connection::{ConnectedPoint, ConnectionId}, - PeerId, - Transport, identity, - transport::MemoryTransport, - multiaddr::{Protocol, Multiaddr, multiaddr}, + multiaddr::{multiaddr, Multiaddr, Protocol}, multihash::{Code, Multihash, MultihashDigest}, + transport::MemoryTransport, + PeerId, Transport, }; use libp2p_noise as noise; use libp2p_swarm::{Swarm, SwarmEvent}; use libp2p_yamux as yamux; use quickcheck::*; -use rand::{Rng, random, thread_rng, rngs::StdRng, SeedableRng}; -use std::{collections::{HashSet, HashMap}, time::Duration, num::NonZeroUsize, u64}; +use rand::{random, rngs::StdRng, thread_rng, Rng, SeedableRng}; +use std::{ + collections::{HashMap, HashSet}, + num::NonZeroUsize, + time::Duration, + u64, +}; type TestSwarm = Swarm>; @@ -56,7 +56,9 @@ fn build_node() -> (Multiaddr, TestSwarm) { fn build_node_with_config(cfg: KademliaConfig) -> (Multiaddr, TestSwarm) { let local_key = identity::Keypair::generate_ed25519(); let local_public_key = local_key.public(); - let noise_keys = noise::Keypair::::new().into_authentic(&local_key).unwrap(); + let noise_keys = noise::Keypair::::new() + .into_authentic(&local_key) + .unwrap(); let transport = MemoryTransport::default() .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) @@ -82,25 +84,33 @@ fn build_nodes(num: usize) -> Vec<(Multiaddr, TestSwarm)> { /// Builds swarms, each listening on a port. Does *not* connect the nodes together. 
fn build_nodes_with_config(num: usize, cfg: KademliaConfig) -> Vec<(Multiaddr, TestSwarm)> { - (0..num).map(|_| build_node_with_config(cfg.clone())).collect() + (0..num) + .map(|_| build_node_with_config(cfg.clone())) + .collect() } fn build_connected_nodes(total: usize, step: usize) -> Vec<(Multiaddr, TestSwarm)> { build_connected_nodes_with_config(total, step, Default::default()) } -fn build_connected_nodes_with_config(total: usize, step: usize, cfg: KademliaConfig) - -> Vec<(Multiaddr, TestSwarm)> -{ +fn build_connected_nodes_with_config( + total: usize, + step: usize, + cfg: KademliaConfig, +) -> Vec<(Multiaddr, TestSwarm)> { let mut swarms = build_nodes_with_config(total, cfg); - let swarm_ids: Vec<_> = swarms.iter() + let swarm_ids: Vec<_> = swarms + .iter() .map(|(addr, swarm)| (addr.clone(), *swarm.local_peer_id())) .collect(); let mut i = 0; for (j, (addr, peer_id)) in swarm_ids.iter().enumerate().skip(1) { if i < swarm_ids.len() { - swarms[i].1.behaviour_mut().add_address(peer_id, addr.clone()); + swarms[i] + .1 + .behaviour_mut() + .add_address(peer_id, addr.clone()); } if j % step == 0 { i += step; @@ -110,11 +120,13 @@ fn build_connected_nodes_with_config(total: usize, step: usize, cfg: KademliaCon swarms } -fn build_fully_connected_nodes_with_config(total: usize, cfg: KademliaConfig) - -> Vec<(Multiaddr, TestSwarm)> -{ +fn build_fully_connected_nodes_with_config( + total: usize, + cfg: KademliaConfig, +) -> Vec<(Multiaddr, TestSwarm)> { let mut swarms = build_nodes_with_config(total, cfg); - let swarm_addr_and_peer_id: Vec<_> = swarms.iter() + let swarm_addr_and_peer_id: Vec<_> = swarms + .iter() .map(|(addr, swarm)| (addr.clone(), *swarm.local_peer_id())) .collect(); @@ -159,18 +171,12 @@ fn bootstrap() { cfg.disjoint_query_paths(true); } - let mut swarms = build_connected_nodes_with_config( - num_total, - num_group, - cfg, - ).into_iter() + let mut swarms = build_connected_nodes_with_config(num_total, num_group, cfg) + .into_iter() .map(|(_a, s)| s) .collect::>(); - let swarm_ids: Vec<_> = swarms.iter() - .map(Swarm::local_peer_id) - .cloned() - .collect(); + let swarm_ids: Vec<_> = swarms.iter().map(Swarm::local_peer_id).cloned().collect(); let qid = swarms[0].behaviour_mut().bootstrap().unwrap(); @@ -179,46 +185,49 @@ fn bootstrap() { let mut first = true; // Run test - block_on( - poll_fn(move |ctx| { - for (i, swarm) in swarms.iter_mut().enumerate() { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::Bootstrap(Ok(ok)), .. - }))) => { - assert_eq!(id, qid); - assert_eq!(i, 0); - if first { - // Bootstrapping must start with a self-lookup. - assert_eq!(ok.peer, swarm_ids[0]); - } - first = false; - if ok.num_remaining == 0 { - assert_eq!( - swarm.behaviour_mut().queries.size(), 0, - "Expect no remaining queries when `num_remaining` is zero.", - ); - let mut known = HashSet::new(); - for b in swarm.behaviour_mut().kbuckets.iter() { - for e in b.iter() { - known.insert(e.node.key.preimage().clone()); - } + block_on(poll_fn(move |ctx| { + for (i, swarm) in swarms.iter_mut().enumerate() { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::Bootstrap(Ok(ok)), + .. + }, + ))) => { + assert_eq!(id, qid); + assert_eq!(i, 0); + if first { + // Bootstrapping must start with a self-lookup. 
+ assert_eq!(ok.peer, swarm_ids[0]); + } + first = false; + if ok.num_remaining == 0 { + assert_eq!( + swarm.behaviour_mut().queries.size(), + 0, + "Expect no remaining queries when `num_remaining` is zero.", + ); + let mut known = HashSet::new(); + for b in swarm.behaviour_mut().kbuckets.iter() { + for e in b.iter() { + known.insert(e.node.key.preimage().clone()); } - assert_eq!(expected_known, known); - return Poll::Ready(()) } + assert_eq!(expected_known, known); + return Poll::Ready(()); } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } - Poll::Pending - }) - ) + } + Poll::Pending + })) } QuickCheck::new().tests(10).quickcheck(prop as fn(_) -> _) @@ -227,7 +236,8 @@ fn bootstrap() { #[test] fn query_iter() { fn distances(key: &kbucket::Key, peers: Vec) -> Vec { - peers.into_iter() + peers + .into_iter() .map(kbucket::Key::from) .map(|k| k.distance(key)) .collect() @@ -235,7 +245,8 @@ fn query_iter() { fn run(rng: &mut impl Rng) { let num_total = rng.gen_range(2, 20); - let mut swarms = build_connected_nodes(num_total, 1).into_iter() + let mut swarms = build_connected_nodes(num_total, 1) + .into_iter() .map(|(_a, s)| s) .collect::>(); let swarm_ids: Vec<_> = swarms.iter().map(Swarm::local_peer_id).cloned().collect(); @@ -250,10 +261,10 @@ fn query_iter() { Some(q) => match q.info() { QueryInfo::GetClosestPeers { key } => { assert_eq!(&key[..], search_target.to_bytes().as_slice()) - }, - i => panic!("Unexpected query info: {:?}", i) - } - None => panic!("Query not found: {:?}", qid) + } + i => panic!("Unexpected query info: {:?}", i), + }, + None => panic!("Query not found: {:?}", qid), } // Set up expectations. @@ -263,37 +274,39 @@ fn query_iter() { expected_distances.sort(); // Run test - block_on( - poll_fn(move |ctx| { - for (i, swarm) in swarms.iter_mut().enumerate() { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::GetClosestPeers(Ok(ok)), .. - }))) => { - assert_eq!(id, qid); - assert_eq!(&ok.key[..], search_target.to_bytes().as_slice()); - assert_eq!(swarm_ids[i], expected_swarm_id); - assert_eq!(swarm.behaviour_mut().queries.size(), 0); - assert!(expected_peer_ids.iter().all(|p| ok.peers.contains(p))); - let key = kbucket::Key::new(ok.key); - assert_eq!(expected_distances, distances(&key, ok.peers)); - return Poll::Ready(()); - } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + block_on(poll_fn(move |ctx| { + for (i, swarm) in swarms.iter_mut().enumerate() { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::GetClosestPeers(Ok(ok)), + .. + }, + ))) => { + assert_eq!(id, qid); + assert_eq!(&ok.key[..], search_target.to_bytes().as_slice()); + assert_eq!(swarm_ids[i], expected_swarm_id); + assert_eq!(swarm.behaviour_mut().queries.size(), 0); + assert!(expected_peer_ids.iter().all(|p| ok.peers.contains(p))); + let key = kbucket::Key::new(ok.key); + assert_eq!(expected_distances, distances(&key, ok.peers)); + return Poll::Ready(()); } + // Ignore any other event. 
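
All of these tests drive their swarms the same way: a `poll_fn` future polls every swarm in a loop, returns `Poll::Ready` once the expected event has been observed, and otherwise returns `Poll::Pending` to be woken again; `block_on` then runs that future to completion. A minimal standalone version of the pattern, using the same `futures` crate as the tests but a plain counter in place of swarm events (the real tests rely on the swarms to register wake-ups, so they do not need the explicit `wake_by_ref` used here):

    use futures::{executor::block_on, future::poll_fn};
    use std::task::Poll;

    fn main() {
        let mut remaining = 3u32;
        block_on(poll_fn(move |cx| {
            // Do a bounded amount of work per wake-up, then either finish or
            // ask to be polled again, mirroring the loops in this test module.
            remaining -= 1;
            if remaining == 0 {
                Poll::Ready(())
            } else {
                cx.waker().wake_by_ref(); // re-schedule ourselves
                Poll::Pending
            }
        }));
    }
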
+ Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } - Poll::Pending - }) - ) + } + Poll::Pending + })) } let mut rng = thread_rng(); - for _ in 0 .. 10 { + for _ in 0..10 { run(&mut rng) } } @@ -303,42 +316,46 @@ fn unresponsive_not_returned_direct() { // Build one node. It contains fake addresses to non-existing nodes. We ask it to find a // random peer. We make sure that no fake address is returned. - let mut swarms = build_nodes(1).into_iter() + let mut swarms = build_nodes(1) + .into_iter() .map(|(_a, s)| s) .collect::>(); // Add fake addresses. - for _ in 0 .. 10 { - swarms[0].behaviour_mut().add_address(&PeerId::random(), Protocol::Udp(10u16).into()); + for _ in 0..10 { + swarms[0] + .behaviour_mut() + .add_address(&PeerId::random(), Protocol::Udp(10u16).into()); } // Ask first to search a random value. let search_target = PeerId::random(); swarms[0].behaviour_mut().get_closest_peers(search_target); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - result: QueryResult::GetClosestPeers(Ok(ok)), .. - }))) => { - assert_eq!(&ok.key[..], search_target.to_bytes().as_slice()); - assert_eq!(ok.peers.len(), 0); - return Poll::Ready(()); - } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + result: QueryResult::GetClosestPeers(Ok(ok)), + .. + }, + ))) => { + assert_eq!(&ok.key[..], search_target.to_bytes().as_slice()); + assert_eq!(ok.peers.len(), 0); + return Poll::Ready(()); } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } + } - Poll::Pending - }) - ) + Poll::Pending + })) } #[test] @@ -350,96 +367,120 @@ fn unresponsive_not_returned_indirect() { let mut swarms = build_nodes(2); // Add fake addresses to first. - for _ in 0 .. 10 { - swarms[0].1.behaviour_mut().add_address(&PeerId::random(), multiaddr![Udp(10u16)]); + for _ in 0..10 { + swarms[0] + .1 + .behaviour_mut() + .add_address(&PeerId::random(), multiaddr![Udp(10u16)]); } // Connect second to first. let first_peer_id = *swarms[0].1.local_peer_id(); let first_address = swarms[0].0.clone(); - swarms[1].1.behaviour_mut().add_address(&first_peer_id, first_address); + swarms[1] + .1 + .behaviour_mut() + .add_address(&first_peer_id, first_address); // Drop the swarm addresses. - let mut swarms = swarms.into_iter().map(|(_addr, swarm)| swarm).collect::>(); + let mut swarms = swarms + .into_iter() + .map(|(_addr, swarm)| swarm) + .collect::>(); // Ask second to search a random value. let search_target = PeerId::random(); swarms[1].behaviour_mut().get_closest_peers(search_target); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - result: QueryResult::GetClosestPeers(Ok(ok)), .. - }))) => { - assert_eq!(&ok.key[..], search_target.to_bytes().as_slice()); - assert_eq!(ok.peers.len(), 1); - assert_eq!(ok.peers[0], first_peer_id); - return Poll::Ready(()); - } - // Ignore any other event. 
- Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + result: QueryResult::GetClosestPeers(Ok(ok)), + .. + }, + ))) => { + assert_eq!(&ok.key[..], search_target.to_bytes().as_slice()); + assert_eq!(ok.peers.len(), 1); + assert_eq!(ok.peers[0], first_peer_id); + return Poll::Ready(()); } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } + } - Poll::Pending - }) - ) + Poll::Pending + })) } #[test] fn get_record_not_found() { let mut swarms = build_nodes(3); - let swarm_ids: Vec<_> = swarms.iter() + let swarm_ids: Vec<_> = swarms + .iter() .map(|(_addr, swarm)| *swarm.local_peer_id()) .collect(); let (second, third) = (swarms[1].0.clone(), swarms[2].0.clone()); - swarms[0].1.behaviour_mut().add_address(&swarm_ids[1], second); - swarms[1].1.behaviour_mut().add_address(&swarm_ids[2], third); + swarms[0] + .1 + .behaviour_mut() + .add_address(&swarm_ids[1], second); + swarms[1] + .1 + .behaviour_mut() + .add_address(&swarm_ids[2], third); // Drop the swarm addresses. - let mut swarms = swarms.into_iter().map(|(_addr, swarm)| swarm).collect::>(); + let mut swarms = swarms + .into_iter() + .map(|(_addr, swarm)| swarm) + .collect::>(); let target_key = record::Key::from(random_multihash()); - let qid = swarms[0].behaviour_mut().get_record(&target_key, Quorum::One); + let qid = swarms[0] + .behaviour_mut() + .get_record(&target_key, Quorum::One); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::GetRecord(Err(e)), .. - }))) => { - assert_eq!(id, qid); - if let GetRecordError::NotFound { key, closest_peers, } = e { - assert_eq!(key, target_key); - assert_eq!(closest_peers.len(), 2); - assert!(closest_peers.contains(&swarm_ids[1])); - assert!(closest_peers.contains(&swarm_ids[2])); - return Poll::Ready(()); - } else { - panic!("Unexpected error result: {:?}", e); - } + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::GetRecord(Err(e)), + .. + }, + ))) => { + assert_eq!(id, qid); + if let GetRecordError::NotFound { key, closest_peers } = e { + assert_eq!(key, target_key); + assert_eq!(closest_peers.len(), 2); + assert!(closest_peers.contains(&swarm_ids[1])); + assert!(closest_peers.contains(&swarm_ids[2])); + return Poll::Ready(()); + } else { + panic!("Unexpected error result: {:?}", e); } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, } + // Ignore any other event. 
+ Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } + } - Poll::Pending - }) - ) + Poll::Pending + })) } /// A node joining a fully connected network via three (ALPHA_VALUE) bootnodes @@ -449,7 +490,8 @@ fn get_record_not_found() { fn put_record() { fn prop(records: Vec, seed: Seed) { let mut rng = StdRng::from_seed(seed.0); - let replication_factor = NonZeroUsize::new(rng.gen_range(1, (K_VALUE.get() / 2) + 1)).unwrap(); + let replication_factor = + NonZeroUsize::new(rng.gen_range(1, (K_VALUE.get() / 2) + 1)).unwrap(); // At least 4 nodes, 1 under test + 3 bootnodes. let num_total = usize::max(4, replication_factor.get() * 2); @@ -460,10 +502,8 @@ fn put_record() { } let mut swarms = { - let mut fully_connected_swarms = build_fully_connected_nodes_with_config( - num_total - 1, - config.clone(), - ); + let mut fully_connected_swarms = + build_fully_connected_nodes_with_config(num_total - 1, config.clone()); let mut single_swarm = build_node_with_config(config); // Connect `single_swarm` to three bootnodes. @@ -478,10 +518,14 @@ fn put_record() { swarms.append(&mut fully_connected_swarms); // Drop the swarm addresses. - swarms.into_iter().map(|(_addr, swarm)| swarm).collect::>() + swarms + .into_iter() + .map(|(_addr, swarm)| swarm) + .collect::>() }; - let records = records.into_iter() + let records = records + .into_iter() .take(num_total) .map(|mut r| { // We don't want records to expire prematurely, as they would @@ -490,12 +534,15 @@ fn put_record() { r.expires = r.expires.map(|t| t + Duration::from_secs(60)); (r.key.clone(), r) }) - .collect::>(); + .collect::>(); // Initiate put_record queries. let mut qids = HashSet::new(); for r in records.values() { - let qid = swarms[0].behaviour_mut().put_record(r.clone(), Quorum::All).unwrap(); + let qid = swarms[0] + .behaviour_mut() + .put_record(r.clone(), Quorum::All) + .unwrap(); match swarms[0].behaviour_mut().query(&qid) { Some(q) => match q.info() { QueryInfo::PutRecord { phase, record, .. } => { @@ -504,10 +551,10 @@ fn put_record() { assert_eq!(record.value, r.value); assert!(record.expires.is_some()); qids.insert(qid); - }, - i => panic!("Unexpected query info: {:?}", i) - } - None => panic!("Query not found: {:?}", qid) + } + i => panic!("Unexpected query info: {:?}", i), + }, + None => panic!("Query not found: {:?}", qid), } } @@ -516,118 +563,136 @@ fn put_record() { // The accumulated results for one round of publishing. let mut results = Vec::new(); - block_on( - poll_fn(move |ctx| loop { - // Poll all swarms until they are "Pending". - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::PutRecord(res), stats - }))) | - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::RepublishRecord(res), stats - }))) => { - assert!(qids.is_empty() || qids.remove(&id)); - assert!(stats.duration().is_some()); - assert!(stats.num_successes() >= replication_factor.get() as u32); - assert!(stats.num_requests() >= stats.num_successes()); - assert_eq!(stats.num_failures(), 0); - match res { - Err(e) => panic!("{:?}", e), - Ok(ok) => { - assert!(records.contains_key(&ok.key)); - let record = swarm.behaviour_mut().store.get(&ok.key).unwrap(); - results.push(record.into_owned()); - } + block_on(poll_fn(move |ctx| loop { + // Poll all swarms until they are "Pending". 
+ for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::PutRecord(res), + stats, + }, + ))) + | Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::RepublishRecord(res), + stats, + }, + ))) => { + assert!(qids.is_empty() || qids.remove(&id)); + assert!(stats.duration().is_some()); + assert!(stats.num_successes() >= replication_factor.get() as u32); + assert!(stats.num_requests() >= stats.num_successes()); + assert_eq!(stats.num_failures(), 0); + match res { + Err(e) => panic!("{:?}", e), + Ok(ok) => { + assert!(records.contains_key(&ok.key)); + let record = swarm.behaviour_mut().store.get(&ok.key).unwrap(); + results.push(record.into_owned()); } } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } + } - // All swarms are Pending and not enough results have been collected - // so far, thus wait to be polled again for further progress. - if results.len() != records.len() { - return Poll::Pending - } + // All swarms are Pending and not enough results have been collected + // so far, thus wait to be polled again for further progress. + if results.len() != records.len() { + return Poll::Pending; + } - // Consume the results, checking that each record was replicated - // correctly to the closest peers to the key. - while let Some(r) = results.pop() { - let expected = records.get(&r.key).unwrap(); - - assert_eq!(r.key, expected.key); - assert_eq!(r.value, expected.value); - assert_eq!(r.expires, expected.expires); - assert_eq!(r.publisher, Some(*swarms[0].local_peer_id())); - - let key = kbucket::Key::new(r.key.clone()); - let mut expected = swarms.iter() - .skip(1) - .map(Swarm::local_peer_id) - .cloned() - .collect::>(); - expected.sort_by(|id1, id2| - kbucket::Key::from(*id1).distance(&key).cmp( - &kbucket::Key::from(*id2).distance(&key))); - - let expected = expected - .into_iter() - .take(replication_factor.get()) - .collect::>(); - - let actual = swarms.iter() - .skip(1) - .filter_map(|swarm| - if swarm.behaviour().store.get(key.preimage()).is_some() { - Some(*swarm.local_peer_id()) - } else { - None - }) - .collect::>(); - - assert_eq!(actual.len(), replication_factor.get()); - - let actual_not_expected = actual.difference(&expected) - .collect::>(); - assert!( - actual_not_expected.is_empty(), - "Did not expect records to be stored on nodes {:?}.", - actual_not_expected, - ); - - let expected_not_actual = expected.difference(&actual) - .collect::>(); - assert!(expected_not_actual.is_empty(), - "Expected record to be stored on nodes {:?}.", - expected_not_actual, - ); - } + // Consume the results, checking that each record was replicated + // correctly to the closest peers to the key. 
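
The loop that follows checks that each record ended up on exactly the `replication_factor` peers whose kbucket keys are closest to the record key. A tiny, std-only sketch of the expectation being built there, sorting by XOR distance and keeping the k closest, with plain u64 identifiers standing in for kbucket keys:

    fn closest(ids: &[u64], key: u64, k: usize) -> Vec<u64> {
        let mut ids = ids.to_vec();
        // Kademlia's notion of distance is the XOR of the two keys.
        ids.sort_by_key(|id| *id ^ key);
        ids.truncate(k);
        ids
    }

    fn main() {
        let peers = [0b1010, 0b0111, 0b1000, 0b0001];
        // For key 0b1001, the two closest peers by XOR distance are 0b1000 and 0b1010.
        assert_eq!(closest(&peers, 0b1001, 2), vec![0b1000, 0b1010]);
    }
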
+ while let Some(r) = results.pop() { + let expected = records.get(&r.key).unwrap(); + + assert_eq!(r.key, expected.key); + assert_eq!(r.value, expected.value); + assert_eq!(r.expires, expected.expires); + assert_eq!(r.publisher, Some(*swarms[0].local_peer_id())); + + let key = kbucket::Key::new(r.key.clone()); + let mut expected = swarms + .iter() + .skip(1) + .map(Swarm::local_peer_id) + .cloned() + .collect::>(); + expected.sort_by(|id1, id2| { + kbucket::Key::from(*id1) + .distance(&key) + .cmp(&kbucket::Key::from(*id2).distance(&key)) + }); + + let expected = expected + .into_iter() + .take(replication_factor.get()) + .collect::>(); + + let actual = swarms + .iter() + .skip(1) + .filter_map(|swarm| { + if swarm.behaviour().store.get(key.preimage()).is_some() { + Some(*swarm.local_peer_id()) + } else { + None + } + }) + .collect::>(); - if republished { - assert_eq!(swarms[0].behaviour_mut().store.records().count(), records.len()); - assert_eq!(swarms[0].behaviour_mut().queries.size(), 0); - for k in records.keys() { - swarms[0].behaviour_mut().store.remove(&k); - } - assert_eq!(swarms[0].behaviour_mut().store.records().count(), 0); - // All records have been republished, thus the test is complete. - return Poll::Ready(()); + assert_eq!(actual.len(), replication_factor.get()); + + let actual_not_expected = actual.difference(&expected).collect::>(); + assert!( + actual_not_expected.is_empty(), + "Did not expect records to be stored on nodes {:?}.", + actual_not_expected, + ); + + let expected_not_actual = expected.difference(&actual).collect::>(); + assert!( + expected_not_actual.is_empty(), + "Expected record to be stored on nodes {:?}.", + expected_not_actual, + ); + } + + if republished { + assert_eq!( + swarms[0].behaviour_mut().store.records().count(), + records.len() + ); + assert_eq!(swarms[0].behaviour_mut().queries.size(), 0); + for k in records.keys() { + swarms[0].behaviour_mut().store.remove(&k); } + assert_eq!(swarms[0].behaviour_mut().store.records().count(), 0); + // All records have been republished, thus the test is complete. + return Poll::Ready(()); + } - // Tell the replication job to republish asap. - swarms[0].behaviour_mut().put_record_job.as_mut().unwrap().asap(true); - republished = true; - }) - ) + // Tell the replication job to republish asap. + swarms[0] + .behaviour_mut() + .put_record_job + .as_mut() + .unwrap() + .asap(true); + republished = true; + })) } - QuickCheck::new().tests(3).quickcheck(prop as fn(_,_) -> _) + QuickCheck::new().tests(3).quickcheck(prop as fn(_, _) -> _) } #[test] @@ -636,95 +701,109 @@ fn get_record() { // Let first peer know of second peer and second peer know of third peer. for i in 0..2 { - let (peer_id, address) = (Swarm::local_peer_id(&swarms[i+1].1).clone(), swarms[i+1].0.clone()); + let (peer_id, address) = ( + Swarm::local_peer_id(&swarms[i + 1].1).clone(), + swarms[i + 1].0.clone(), + ); swarms[i].1.behaviour_mut().add_address(&peer_id, address); } // Drop the swarm addresses. 
- let mut swarms = swarms.into_iter().map(|(_addr, swarm)| swarm).collect::>(); + let mut swarms = swarms + .into_iter() + .map(|(_addr, swarm)| swarm) + .collect::>(); - let record = Record::new(random_multihash(), vec![4,5,6]); + let record = Record::new(random_multihash(), vec![4, 5, 6]); let expected_cache_candidate = *Swarm::local_peer_id(&swarms[1]); swarms[2].behaviour_mut().store.put(record.clone()).unwrap(); - let qid = swarms[0].behaviour_mut().get_record(&record.key, Quorum::One); + let qid = swarms[0] + .behaviour_mut() + .get_record(&record.key, Quorum::One); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { id, - result: QueryResult::GetRecord(Ok(GetRecordOk { - records, cache_candidates - })), + result: + QueryResult::GetRecord(Ok(GetRecordOk { + records, + cache_candidates, + })), .. - }))) => { - assert_eq!(id, qid); - assert_eq!(records.len(), 1); - assert_eq!(records.first().unwrap().record, record); - assert_eq!(cache_candidates.len(), 1); - assert_eq!(cache_candidates.values().next(), Some(&expected_cache_candidate)); - return Poll::Ready(()); - } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + }, + ))) => { + assert_eq!(id, qid); + assert_eq!(records.len(), 1); + assert_eq!(records.first().unwrap().record, record); + assert_eq!(cache_candidates.len(), 1); + assert_eq!( + cache_candidates.values().next(), + Some(&expected_cache_candidate) + ); + return Poll::Ready(()); } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } + } - Poll::Pending - }) - ) + Poll::Pending + })) } #[test] fn get_record_many() { // TODO: Randomise let num_nodes = 12; - let mut swarms = build_connected_nodes(num_nodes, 3).into_iter() + let mut swarms = build_connected_nodes(num_nodes, 3) + .into_iter() .map(|(_addr, swarm)| swarm) .collect::>(); let num_results = 10; - let record = Record::new(random_multihash(), vec![4,5,6]); + let record = Record::new(random_multihash(), vec![4, 5, 6]); - for i in 0 .. num_nodes { + for i in 0..num_nodes { swarms[i].behaviour_mut().store.put(record.clone()).unwrap(); } let quorum = Quorum::N(NonZeroUsize::new(num_results).unwrap()); let qid = swarms[0].behaviour_mut().get_record(&record.key, quorum); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { id, result: QueryResult::GetRecord(Ok(GetRecordOk { records, .. })), .. - }))) => { - assert_eq!(id, qid); - assert!(records.len() >= num_results); - assert!(records.into_iter().all(|r| r.record == record)); - return Poll::Ready(()); - } - // Ignore any other event. 
- Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + }, + ))) => { + assert_eq!(id, qid); + assert!(records.len() >= num_results); + assert!(records.into_iter().all(|r| r.record == record)); + return Poll::Ready(()); } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } - Poll::Pending - }) - ) + } + Poll::Pending + })) } /// A node joining a fully connected network via three (ALPHA_VALUE) bootnodes @@ -734,7 +813,8 @@ fn get_record_many() { fn add_provider() { fn prop(keys: Vec, seed: Seed) { let mut rng = StdRng::from_seed(seed.0); - let replication_factor = NonZeroUsize::new(rng.gen_range(1, (K_VALUE.get() / 2) + 1)).unwrap(); + let replication_factor = + NonZeroUsize::new(rng.gen_range(1, (K_VALUE.get() / 2) + 1)).unwrap(); // At least 4 nodes, 1 under test + 3 bootnodes. let num_total = usize::max(4, replication_factor.get() * 2); @@ -745,10 +825,8 @@ fn add_provider() { } let mut swarms = { - let mut fully_connected_swarms = build_fully_connected_nodes_with_config( - num_total - 1, - config.clone(), - ); + let mut fully_connected_swarms = + build_fully_connected_nodes_with_config(num_total - 1, config.clone()); let mut single_swarm = build_node_with_config(config); // Connect `single_swarm` to three bootnodes. @@ -763,7 +841,10 @@ fn add_provider() { swarms.append(&mut fully_connected_swarms); // Drop addresses before returning. - swarms.into_iter().map(|(_addr, swarm)| swarm).collect::>() + swarms + .into_iter() + .map(|(_addr, swarm)| swarm) + .collect::>() }; let keys: HashSet<_> = keys.into_iter().take(num_total).collect(); @@ -777,113 +858,136 @@ fn add_provider() { // Initiate the first round of publishing. let mut qids = HashSet::new(); for k in &keys { - let qid = swarms[0].behaviour_mut().start_providing(k.clone()).unwrap(); + let qid = swarms[0] + .behaviour_mut() + .start_providing(k.clone()) + .unwrap(); qids.insert(qid); } - block_on( - poll_fn(move |ctx| loop { - // Poll all swarms until they are "Pending". - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::StartProviding(res), .. - }))) | - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - id, result: QueryResult::RepublishProvider(res), .. - }))) => { - assert!(qids.is_empty() || qids.remove(&id)); - match res { - Err(e) => panic!("{:?}", e), - Ok(ok) => { - assert!(keys.contains(&ok.key)); - results.push(ok.key); - } + block_on(poll_fn(move |ctx| loop { + // Poll all swarms until they are "Pending". + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::StartProviding(res), + .. + }, + ))) + | Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::RepublishProvider(res), + .. + }, + ))) => { + assert!(qids.is_empty() || qids.remove(&id)); + match res { + Err(e) => panic!("{:?}", e), + Ok(ok) => { + assert!(keys.contains(&ok.key)); + results.push(ok.key); } } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, } + // Ignore any other event. 
+ Poll::Ready(Some(_)) => (), + e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Pending => break, } } + } - if results.len() == keys.len() { - // All requests have been sent for one round of publishing. - published = true - } + if results.len() == keys.len() { + // All requests have been sent for one round of publishing. + published = true + } - if !published { - // Still waiting for all requests to be sent for one round - // of publishing. - return Poll::Pending - } + if !published { + // Still waiting for all requests to be sent for one round + // of publishing. + return Poll::Pending; + } - // A round of publishing is complete. Consume the results, checking that - // each key was published to the `replication_factor` closest peers. - while let Some(key) = results.pop() { - // Collect the nodes that have a provider record for `key`. - let actual = swarms.iter().skip(1) - .filter_map(|swarm| - if swarm.behaviour().store.providers(&key).len() == 1 { - Some(Swarm::local_peer_id(&swarm).clone()) - } else { - None - }) - .collect::>(); - - if actual.len() != replication_factor.get() { - // Still waiting for some nodes to process the request. - results.push(key); - return Poll::Pending - } + // A round of publishing is complete. Consume the results, checking that + // each key was published to the `replication_factor` closest peers. + while let Some(key) = results.pop() { + // Collect the nodes that have a provider record for `key`. + let actual = swarms + .iter() + .skip(1) + .filter_map(|swarm| { + if swarm.behaviour().store.providers(&key).len() == 1 { + Some(Swarm::local_peer_id(&swarm).clone()) + } else { + None + } + }) + .collect::>(); - let mut expected = swarms.iter() - .skip(1) - .map(Swarm::local_peer_id) - .cloned() - .collect::>(); - let kbucket_key = kbucket::Key::new(key); - expected.sort_by(|id1, id2| - kbucket::Key::from(*id1).distance(&kbucket_key).cmp( - &kbucket::Key::from(*id2).distance(&kbucket_key))); - - let expected = expected - .into_iter() - .take(replication_factor.get()) - .collect::>(); - - assert_eq!(actual, expected); + if actual.len() != replication_factor.get() { + // Still waiting for some nodes to process the request. + results.push(key); + return Poll::Pending; } - // One round of publishing is complete. - assert!(results.is_empty()); - for swarm in &swarms { - assert_eq!(swarm.behaviour().queries.size(), 0); - } + let mut expected = swarms + .iter() + .skip(1) + .map(Swarm::local_peer_id) + .cloned() + .collect::>(); + let kbucket_key = kbucket::Key::new(key); + expected.sort_by(|id1, id2| { + kbucket::Key::from(*id1) + .distance(&kbucket_key) + .cmp(&kbucket::Key::from(*id2).distance(&kbucket_key)) + }); + + let expected = expected + .into_iter() + .take(replication_factor.get()) + .collect::>(); + + assert_eq!(actual, expected); + } - if republished { - assert_eq!(swarms[0].behaviour_mut().store.provided().count(), keys.len()); - for k in &keys { - swarms[0].behaviour_mut().stop_providing(&k); - } - assert_eq!(swarms[0].behaviour_mut().store.provided().count(), 0); - // All records have been republished, thus the test is complete. - return Poll::Ready(()); + // One round of publishing is complete. 
+ assert!(results.is_empty()); + for swarm in &swarms { + assert_eq!(swarm.behaviour().queries.size(), 0); + } + + if republished { + assert_eq!( + swarms[0].behaviour_mut().store.provided().count(), + keys.len() + ); + for k in &keys { + swarms[0].behaviour_mut().stop_providing(&k); } + assert_eq!(swarms[0].behaviour_mut().store.provided().count(), 0); + // All records have been republished, thus the test is complete. + return Poll::Ready(()); + } - // Initiate the second round of publishing by telling the - // periodic provider job to run asap. - swarms[0].behaviour_mut().add_provider_job.as_mut().unwrap().asap(); - published = false; - republished = true; - }) - ) + // Initiate the second round of publishing by telling the + // periodic provider job to run asap. + swarms[0] + .behaviour_mut() + .add_provider_job + .as_mut() + .unwrap() + .asap(); + published = false; + republished = true; + })) } - QuickCheck::new().tests(3).quickcheck(prop as fn(_,_)) + QuickCheck::new().tests(3).quickcheck(prop as fn(_, _)) } /// User code should be able to start queries beyond the internal @@ -893,33 +997,32 @@ fn add_provider() { fn exceed_jobs_max_queries() { let (_addr, mut swarm) = build_node(); let num = JOBS_MAX_QUERIES + 1; - for _ in 0 .. num { + for _ in 0..num { swarm.behaviour_mut().get_closest_peers(PeerId::random()); } assert_eq!(swarm.behaviour_mut().queries.size(), num); - block_on( - poll_fn(move |ctx| { - for _ in 0 .. num { - // There are no other nodes, so the queries finish instantly. - loop { - if let Poll::Ready(Some(e)) = swarm.poll_next_unpin(ctx) { - match e { - SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { - result: QueryResult::GetClosestPeers(Ok(r)), .. - }) => break assert!(r.peers.is_empty()), - SwarmEvent::Behaviour(e) => panic!("Unexpected event: {:?}", e), - _ => {} - } - } else { - panic!("Expected event") + block_on(poll_fn(move |ctx| { + for _ in 0..num { + // There are no other nodes, so the queries finish instantly. + loop { + if let Poll::Ready(Some(e)) = swarm.poll_next_unpin(ctx) { + match e { + SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted { + result: QueryResult::GetClosestPeers(Ok(r)), + .. + }) => break assert!(r.peers.is_empty()), + SwarmEvent::Behaviour(e) => panic!("Unexpected event: {:?}", e), + _ => {} } + } else { + panic!("Expected event") } } - Poll::Ready(()) - }) - ) + } + Poll::Ready(()) + })) } #[test] @@ -952,11 +1055,22 @@ fn disjoint_query_does_not_finish_before_all_paths_did() { // Make `bob` and `trudy` aware of their version of the record searched by // `alice`. bob.1.behaviour_mut().store.put(record_bob.clone()).unwrap(); - trudy.1.behaviour_mut().store.put(record_trudy.clone()).unwrap(); + trudy + .1 + .behaviour_mut() + .store + .put(record_trudy.clone()) + .unwrap(); // Make `trudy` and `bob` known to `alice`. - alice.1.behaviour_mut().add_address(&trudy.1.local_peer_id(), trudy.0.clone()); - alice.1.behaviour_mut().add_address(&bob.1.local_peer_id(), bob.0.clone()); + alice + .1 + .behaviour_mut() + .add_address(&trudy.1.local_peer_id(), trudy.0.clone()); + alice + .1 + .behaviour_mut() + .add_address(&bob.1.local_peer_id(), bob.0.clone()); // Drop the swarm addresses. let (mut alice, mut bob, mut trudy) = (alice.1, bob.1, trudy.1); @@ -970,45 +1084,48 @@ fn disjoint_query_does_not_finish_before_all_paths_did() { // Poll only `alice` and `trudy` expecting `alice` not yet to return a query // result as it is not able to connect to `bob` just yet. 
- block_on( - poll_fn(|ctx| { - for (i, swarm) in [&mut alice, &mut trudy].iter_mut().enumerate() { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted{ + block_on(poll_fn(|ctx| { + for (i, swarm) in [&mut alice, &mut trudy].iter_mut().enumerate() { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { result: QueryResult::GetRecord(result), - .. - }))) => { - if i != 0 { - panic!("Expected `QueryResult` from Alice.") - } + .. + }, + ))) => { + if i != 0 { + panic!("Expected `QueryResult` from Alice.") + } - match result { - Ok(_) => panic!( - "Expected query not to finish until all \ + match result { + Ok(_) => panic!( + "Expected query not to finish until all \ disjoint paths have been explored.", - ), - Err(e) => panic!("{:?}", e), - } + ), + Err(e) => panic!("{:?}", e), } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - Poll::Ready(None) => panic!("Expected Kademlia behaviour not to finish."), - Poll::Pending => break, } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + Poll::Ready(None) => panic!("Expected Kademlia behaviour not to finish."), + Poll::Pending => break, } } + } - // Make sure not to wait until connections to `bob` time out. - before_timeout.poll_unpin(ctx) - }) - ); + // Make sure not to wait until connections to `bob` time out. + before_timeout.poll_unpin(ctx) + })); // Make sure `alice` has exactly one query with `trudy`'s record only. assert_eq!(1, alice.behaviour().queries.iter().count()); - alice.behaviour().queries.iter().for_each(|q| { - match &q.inner.info { - QueryInfo::GetRecord{ records, .. } => { + alice + .behaviour() + .queries + .iter() + .for_each(|q| match &q.inner.info { + QueryInfo::GetRecord { records, .. } => { assert_eq!( *records, vec![PeerRecord { @@ -1016,44 +1133,41 @@ fn disjoint_query_does_not_finish_before_all_paths_did() { record: record_trudy.clone(), }], ); - }, + } i @ _ => panic!("Unexpected query info: {:?}", i), - } - }); + }); // Poll `alice` and `bob` expecting `alice` to return a successful query // result as it is now able to explore the second disjoint path. - let records = block_on( - poll_fn(|ctx| { - for (i, swarm) in [&mut alice, &mut bob].iter_mut().enumerate() { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::OutboundQueryCompleted{ + let records = block_on(poll_fn(|ctx| { + for (i, swarm) in [&mut alice, &mut bob].iter_mut().enumerate() { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(SwarmEvent::Behaviour( + KademliaEvent::OutboundQueryCompleted { result: QueryResult::GetRecord(result), .. - }))) => { - if i != 0 { - panic!("Expected `QueryResult` from Alice.") - } + }, + ))) => { + if i != 0 { + panic!("Expected `QueryResult` from Alice.") + } - match result { - Ok(ok) => return Poll::Ready(ok.records), - Err(e) => unreachable!("{:?}", e), - } + match result { + Ok(ok) => return Poll::Ready(ok.records), + Err(e) => unreachable!("{:?}", e), } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - Poll::Ready(None) => panic!( - "Expected Kademlia behaviour not to finish.", - ), - Poll::Pending => break, } + // Ignore any other event. 
+ Poll::Ready(Some(_)) => (), + Poll::Ready(None) => panic!("Expected Kademlia behaviour not to finish.",), + Poll::Pending => break, } } + } - Poll::Pending - }) - ); + Poll::Pending + })); assert_eq!(2, records.len()); assert!(records.contains(&PeerRecord { @@ -1075,25 +1189,31 @@ fn manual_bucket_inserts() { // 1 -> 2 -> [3 -> ...] let mut swarms = build_connected_nodes_with_config(3, 1, cfg); // The peers and their addresses for which we expect `RoutablePeer` events. - let mut expected = swarms.iter().skip(2) + let mut expected = swarms + .iter() + .skip(2) .map(|(a, s)| { let pid = *Swarm::local_peer_id(s); let addr = a.clone().with(Protocol::P2p(pid.into())); (addr, pid) }) - .collect::>(); + .collect::>(); // We collect the peers for which a `RoutablePeer` event // was received in here to check at the end of the test // that none of them was inserted into a bucket. let mut routable = Vec::new(); // Start an iterative query from the first peer. - swarms[0].1.behaviour_mut().get_closest_peers(PeerId::random()); + swarms[0] + .1 + .behaviour_mut() + .get_closest_peers(PeerId::random()); block_on(poll_fn(move |ctx| { for (_, swarm) in swarms.iter_mut() { loop { match swarm.poll_next_unpin(ctx) { Poll::Ready(Some(SwarmEvent::Behaviour(KademliaEvent::RoutablePeer { - peer, address + peer, + address, }))) => { assert_eq!(peer, expected.remove(&address).expect("Missing address")); routable.push(peer); @@ -1102,11 +1222,11 @@ fn manual_bucket_inserts() { let bucket = swarm.behaviour_mut().kbucket(*peer).unwrap(); assert!(bucket.iter().all(|e| e.node.key.preimage() != peer)); } - return Poll::Ready(()) + return Poll::Ready(()); } } - Poll::Ready(..) => {}, - Poll::Pending => break + Poll::Ready(..) => {} + Poll::Pending => break, } } } @@ -1123,19 +1243,14 @@ fn network_behaviour_inject_address_change() { let old_address: Multiaddr = Protocol::Memory(1).into(); let new_address: Multiaddr = Protocol::Memory(2).into(); - let mut kademlia = Kademlia::new( - local_peer_id.clone(), - MemoryStore::new(local_peer_id), - ); + let mut kademlia = Kademlia::new(local_peer_id.clone(), MemoryStore::new(local_peer_id)); - let endpoint = ConnectedPoint::Dialer { address: old_address.clone() }; + let endpoint = ConnectedPoint::Dialer { + address: old_address.clone(), + }; // Mimick a connection being established. 
- kademlia.inject_connection_established( - &remote_peer_id, - &connection_id, - &endpoint, - ); + kademlia.inject_connection_established(&remote_peer_id, &connection_id, &endpoint); kademlia.inject_connected(&remote_peer_id); // At this point the remote is not yet known to support the @@ -1148,7 +1263,7 @@ fn network_behaviour_inject_address_change() { kademlia.inject_event( remote_peer_id.clone(), connection_id.clone(), - KademliaHandlerEvent::ProtocolConfirmed { endpoint } + KademliaHandlerEvent::ProtocolConfirmed { endpoint }, ); assert_eq!( @@ -1159,8 +1274,12 @@ fn network_behaviour_inject_address_change() { kademlia.inject_address_change( &remote_peer_id, &connection_id, - &ConnectedPoint::Dialer { address: old_address.clone() }, - &ConnectedPoint::Dialer { address: new_address.clone() }, + &ConnectedPoint::Dialer { + address: old_address.clone(), + }, + &ConnectedPoint::Dialer { + address: new_address.clone(), + }, ); assert_eq!( diff --git a/protocols/kad/src/handler.rs b/protocols/kad/src/handler.rs index 70b8fdd955a..3c955bb428a 100644 --- a/protocols/kad/src/handler.rs +++ b/protocols/kad/src/handler.rs @@ -24,23 +24,19 @@ use crate::protocol::{ }; use crate::record::{self, Record}; use futures::prelude::*; -use libp2p_swarm::{ - IntoProtocolsHandler, - KeepAlive, - NegotiatedSubstream, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr -}; use libp2p_core::{ - ConnectedPoint, - PeerId, either::EitherOutput, - upgrade::{self, InboundUpgrade, OutboundUpgrade} + upgrade::{self, InboundUpgrade, OutboundUpgrade}, + ConnectedPoint, PeerId, +}; +use libp2p_swarm::{ + IntoProtocolsHandler, KeepAlive, NegotiatedSubstream, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, SubstreamProtocol, }; use log::trace; -use std::{error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll, time::Duration}; +use std::{ + error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll, time::Duration, +}; use wasm_timer::Instant; /// A prototype from which [`KademliaHandler`]s can be constructed. @@ -51,7 +47,10 @@ pub struct KademliaHandlerProto { impl KademliaHandlerProto { pub fn new(config: KademliaHandlerConfig) -> Self { - KademliaHandlerProto { config, _type: PhantomData } + KademliaHandlerProto { + config, + _type: PhantomData, + } } } @@ -151,7 +150,11 @@ enum SubstreamState { /// Waiting for the user to send a `KademliaHandlerIn` event containing the response. InWaitingUser(UniqueConnecId, KadInStreamSink), /// Waiting to send an answer back to the remote. - InPendingSend(UniqueConnecId, KadInStreamSink, KadResponseMsg), + InPendingSend( + UniqueConnecId, + KadInStreamSink, + KadResponseMsg, + ), /// Waiting to flush an answer back to the remote. InPendingFlush(UniqueConnecId, KadInStreamSink), /// The substream is being closed. @@ -164,23 +167,28 @@ impl SubstreamState { /// If the substream is not ready to be closed, returns it back. 
fn try_close(&mut self, cx: &mut Context<'_>) -> Poll<()> { match self { - SubstreamState::OutPendingOpen(_, _) - | SubstreamState::OutReportError(_, _) => Poll::Ready(()), + SubstreamState::OutPendingOpen(_, _) | SubstreamState::OutReportError(_, _) => { + Poll::Ready(()) + } SubstreamState::OutPendingSend(ref mut stream, _, _) | SubstreamState::OutPendingFlush(ref mut stream, _) | SubstreamState::OutWaitingAnswer(ref mut stream, _) - | SubstreamState::OutClosing(ref mut stream) => match Sink::poll_close(Pin::new(stream), cx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => Poll::Pending, - }, + | SubstreamState::OutClosing(ref mut stream) => { + match Sink::poll_close(Pin::new(stream), cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } SubstreamState::InWaitingMessage(_, ref mut stream) | SubstreamState::InWaitingUser(_, ref mut stream) | SubstreamState::InPendingSend(_, ref mut stream, _) | SubstreamState::InPendingFlush(_, ref mut stream) - | SubstreamState::InClosing(ref mut stream) => match Sink::poll_close(Pin::new(stream), cx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => Poll::Pending, - }, + | SubstreamState::InClosing(ref mut stream) => { + match Sink::poll_close(Pin::new(stream), cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } } } } @@ -282,7 +290,7 @@ pub enum KademliaHandlerEvent { value: Vec, /// The user data passed to the `PutValue`. user_data: TUserData, - } + }, } /// Error that can happen when requesting an RPC query. @@ -301,13 +309,16 @@ impl fmt::Display for KademliaHandlerQueryErr { match self { KademliaHandlerQueryErr::Upgrade(err) => { write!(f, "Error while performing Kademlia query: {}", err) - }, + } KademliaHandlerQueryErr::UnexpectedMessage => { - write!(f, "Remote answered our Kademlia RPC query with the wrong message type") - }, + write!( + f, + "Remote answered our Kademlia RPC query with the wrong message type" + ) + } KademliaHandlerQueryErr::Io(err) => { write!(f, "I/O error during a Kademlia RPC query: {}", err) - }, + } } } } @@ -424,7 +435,7 @@ pub enum KademliaHandlerIn { value: Vec, /// Identifier of the request that was made by the remote. request_id: KademliaRequestId, - } + }, } /// Unique identifier for a request. Must be passed back in order to answer a request from @@ -470,7 +481,8 @@ where fn listen_protocol(&self) -> SubstreamProtocol { if self.config.allow_listening { - SubstreamProtocol::new(self.config.protocol_config.clone(), ()).map_upgrade(upgrade::EitherUpgrade::A) + SubstreamProtocol::new(self.config.protocol_config.clone(), ()) + .map_upgrade(upgrade::EitherUpgrade::A) } else { SubstreamProtocol::new(upgrade::EitherUpgrade::B(upgrade::DeniedUpgrade), ()) } @@ -481,7 +493,8 @@ where protocol: >::Output, (msg, user_data): Self::OutboundOpenInfo, ) { - self.substreams.push(SubstreamState::OutPendingSend(protocol, msg, user_data)); + self.substreams + .push(SubstreamState::OutPendingSend(protocol, msg, user_data)); if let ProtocolStatus::Unconfirmed = self.protocol_status { // Upon the first successfully negotiated substream, we know that the // remote is configured with the same protocol name and we want @@ -493,7 +506,7 @@ where fn inject_fully_negotiated_inbound( &mut self, protocol: >::Output, - (): Self::InboundOpenInfo + (): Self::InboundOpenInfo, ) { // If `self.allow_listening` is false, then we produced a `DeniedUpgrade` and `protocol` // is a `Void`. 
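(Editorial note, not part of the patch.) The reformatted `try_close` above collapses the `Result` returned by `Sink::poll_close` into a plain `Poll<()>`, treating a close error the same as a successful close. A minimal, self-contained sketch of that pattern, using only `futures` and `std` — the free function and its name are ours, purely illustrative:

use futures::Sink;
use std::pin::Pin;
use std::task::{Context, Poll};

/// Polls a substream's sink half for closure, treating both `Ok` and `Err`
/// as "closed" -- the same collapse `try_close` performs for every variant.
fn poll_close_ignoring_errors<S, T>(substream: &mut S, cx: &mut Context<'_>) -> Poll<()>
where
    S: Sink<T> + Unpin,
{
    match Sink::poll_close(Pin::new(substream), cx) {
        Poll::Ready(_) => Poll::Ready(()), // closed, successfully or with an error
        Poll::Pending => Poll::Pending,    // not ready yet; caller keeps the substream
    }
}
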
@@ -505,7 +518,8 @@ where debug_assert!(self.config.allow_listening); let connec_unique_id = self.next_connec_unique_id; self.next_connec_unique_id.0 += 1; - self.substreams.push(SubstreamState::InWaitingMessage(connec_unique_id, protocol)); + self.substreams + .push(SubstreamState::InWaitingMessage(connec_unique_id, protocol)); if let ProtocolStatus::Unconfirmed = self.protocol_status { // Upon the first successfully negotiated substream, we know that the // remote is configured with the same protocol name and we want @@ -518,8 +532,9 @@ where match message { KademliaHandlerIn::Reset(request_id) => { let pos = self.substreams.iter().position(|state| match state { - SubstreamState::InWaitingUser(conn_id, _) => - conn_id == &request_id.connec_unique_id, + SubstreamState::InWaitingUser(conn_id, _) => { + conn_id == &request_id.connec_unique_id + } _ => false, }); if let Some(pos) = pos { @@ -531,15 +546,17 @@ where } KademliaHandlerIn::FindNodeReq { key, user_data } => { let msg = KadRequestMsg::FindNode { key }; - self.substreams.push(SubstreamState::OutPendingOpen(msg, Some(user_data))); + self.substreams + .push(SubstreamState::OutPendingOpen(msg, Some(user_data))); } KademliaHandlerIn::FindNodeRes { closer_peers, request_id, } => { let pos = self.substreams.iter().position(|state| match state { - SubstreamState::InWaitingUser(ref conn_id, _) => - conn_id == &request_id.connec_unique_id, + SubstreamState::InWaitingUser(ref conn_id, _) => { + conn_id == &request_id.connec_unique_id + } _ => false, }); @@ -549,9 +566,7 @@ where _ => unreachable!(), }; - let msg = KadResponseMsg::FindNode { - closer_peers, - }; + let msg = KadResponseMsg::FindNode { closer_peers }; self.substreams .push(SubstreamState::InPendingSend(conn_id, substream, msg)); } @@ -591,12 +606,13 @@ where } KademliaHandlerIn::AddProvider { key, provider } => { let msg = KadRequestMsg::AddProvider { key, provider }; - self.substreams.push(SubstreamState::OutPendingOpen(msg, None)); + self.substreams + .push(SubstreamState::OutPendingOpen(msg, None)); } KademliaHandlerIn::GetRecord { key, user_data } => { let msg = KadRequestMsg::GetValue { key }; - self.substreams.push(SubstreamState::OutPendingOpen(msg, Some(user_data))); - + self.substreams + .push(SubstreamState::OutPendingOpen(msg, Some(user_data))); } KademliaHandlerIn::PutRecord { record, user_data } => { let msg = KadRequestMsg::PutValue { record }; @@ -609,8 +625,9 @@ where request_id, } => { let pos = self.substreams.iter().position(|state| match state { - SubstreamState::InWaitingUser(ref conn_id, _) - => conn_id == &request_id.connec_unique_id, + SubstreamState::InWaitingUser(ref conn_id, _) => { + conn_id == &request_id.connec_unique_id + } _ => false, }); @@ -636,9 +653,9 @@ where let pos = self.substreams.iter().position(|state| match state { SubstreamState::InWaitingUser(ref conn_id, _) if conn_id == &request_id.connec_unique_id => - { - true - } + { + true + } _ => false, }); @@ -648,10 +665,7 @@ where _ => unreachable!(), }; - let msg = KadResponseMsg::PutValue { - key, - value, - }; + let msg = KadResponseMsg::PutValue { key, value }; self.substreams .push(SubstreamState::InPendingSend(conn_id, substream, msg)); } @@ -680,7 +694,12 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll< - ProtocolsHandlerEvent, + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, > { if self.substreams.is_empty() { return Poll::Pending; @@ -690,8 +709,9 @@ where self.protocol_status = ProtocolStatus::Reported; return 
Poll::Ready(ProtocolsHandlerEvent::Custom( KademliaHandlerEvent::ProtocolConfirmed { - endpoint: self.endpoint.clone() - })) + endpoint: self.endpoint.clone(), + }, + )); } // We remove each element from `substreams` one by one and add them back. @@ -706,7 +726,8 @@ where } (None, Some(event), _) => { if self.substreams.is_empty() { - self.keep_alive = KeepAlive::Until(Instant::now() + self.config.idle_timeout); + self.keep_alive = + KeepAlive::Until(Instant::now() + self.config.idle_timeout); } return Poll::Ready(event); } @@ -765,36 +786,35 @@ fn advance_substream( >, >, bool, -) -{ +) { match state { SubstreamState::OutPendingOpen(msg, user_data) => { let ev = ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(upgrade, (msg, user_data)) + protocol: SubstreamProtocol::new(upgrade, (msg, user_data)), }; (None, Some(ev), false) } SubstreamState::OutPendingSend(mut substream, msg, user_data) => { match Sink::poll_ready(Pin::new(&mut substream), cx) { - Poll::Ready(Ok(())) => { - match Sink::start_send(Pin::new(&mut substream), msg) { - Ok(()) => ( - Some(SubstreamState::OutPendingFlush(substream, user_data)), - None, - true, - ), - Err(error) => { - let event = if let Some(user_data) = user_data { - Some(ProtocolsHandlerEvent::Custom(KademliaHandlerEvent::QueryError { + Poll::Ready(Ok(())) => match Sink::start_send(Pin::new(&mut substream), msg) { + Ok(()) => ( + Some(SubstreamState::OutPendingFlush(substream, user_data)), + None, + true, + ), + Err(error) => { + let event = if let Some(user_data) = user_data { + Some(ProtocolsHandlerEvent::Custom( + KademliaHandlerEvent::QueryError { error: KademliaHandlerQueryErr::Io(error), - user_data - })) - } else { - None - }; - - (None, event, false) - } + user_data, + }, + )) + } else { + None + }; + + (None, event, false) } }, Poll::Pending => ( @@ -804,10 +824,12 @@ fn advance_substream( ), Poll::Ready(Err(error)) => { let event = if let Some(user_data) = user_data { - Some(ProtocolsHandlerEvent::Custom(KademliaHandlerEvent::QueryError { - error: KademliaHandlerQueryErr::Io(error), - user_data - })) + Some(ProtocolsHandlerEvent::Custom( + KademliaHandlerEvent::QueryError { + error: KademliaHandlerQueryErr::Io(error), + user_data, + }, + )) } else { None }; @@ -836,10 +858,12 @@ fn advance_substream( ), Poll::Ready(Err(error)) => { let event = if let Some(user_data) = user_data { - Some(ProtocolsHandlerEvent::Custom(KademliaHandlerEvent::QueryError { - error: KademliaHandlerQueryErr::Io(error), - user_data, - })) + Some(ProtocolsHandlerEvent::Custom( + KademliaHandlerEvent::QueryError { + error: KademliaHandlerQueryErr::Io(error), + user_data, + }, + )) } else { None }; @@ -848,110 +872,121 @@ fn advance_substream( } } } - SubstreamState::OutWaitingAnswer(mut substream, user_data) => match Stream::poll_next(Pin::new(&mut substream), cx) { - Poll::Ready(Some(Ok(msg))) => { - let new_state = SubstreamState::OutClosing(substream); - let event = process_kad_response(msg, user_data); - ( - Some(new_state), - Some(ProtocolsHandlerEvent::Custom(event)), - true, - ) - } - Poll::Pending => ( - Some(SubstreamState::OutWaitingAnswer(substream, user_data)), - None, - false, - ), - Poll::Ready(Some(Err(error))) => { - let event = KademliaHandlerEvent::QueryError { - error: KademliaHandlerQueryErr::Io(error), - user_data, - }; - (None, Some(ProtocolsHandlerEvent::Custom(event)), false) - } - Poll::Ready(None) => { - let event = KademliaHandlerEvent::QueryError { - error: 
KademliaHandlerQueryErr::Io(io::ErrorKind::UnexpectedEof.into()), - user_data, - }; - (None, Some(ProtocolsHandlerEvent::Custom(event)), false) + SubstreamState::OutWaitingAnswer(mut substream, user_data) => { + match Stream::poll_next(Pin::new(&mut substream), cx) { + Poll::Ready(Some(Ok(msg))) => { + let new_state = SubstreamState::OutClosing(substream); + let event = process_kad_response(msg, user_data); + ( + Some(new_state), + Some(ProtocolsHandlerEvent::Custom(event)), + true, + ) + } + Poll::Pending => ( + Some(SubstreamState::OutWaitingAnswer(substream, user_data)), + None, + false, + ), + Poll::Ready(Some(Err(error))) => { + let event = KademliaHandlerEvent::QueryError { + error: KademliaHandlerQueryErr::Io(error), + user_data, + }; + (None, Some(ProtocolsHandlerEvent::Custom(event)), false) + } + Poll::Ready(None) => { + let event = KademliaHandlerEvent::QueryError { + error: KademliaHandlerQueryErr::Io(io::ErrorKind::UnexpectedEof.into()), + user_data, + }; + (None, Some(ProtocolsHandlerEvent::Custom(event)), false) + } } - }, + } SubstreamState::OutReportError(error, user_data) => { let event = KademliaHandlerEvent::QueryError { error, user_data }; (None, Some(ProtocolsHandlerEvent::Custom(event)), false) } - SubstreamState::OutClosing(mut stream) => match Sink::poll_close(Pin::new(&mut stream), cx) { + SubstreamState::OutClosing(mut stream) => match Sink::poll_close(Pin::new(&mut stream), cx) + { Poll::Ready(Ok(())) => (None, None, false), Poll::Pending => (Some(SubstreamState::OutClosing(stream)), None, false), Poll::Ready(Err(_)) => (None, None, false), }, - SubstreamState::InWaitingMessage(id, mut substream) => match Stream::poll_next(Pin::new(&mut substream), cx) { - Poll::Ready(Some(Ok(msg))) => { - if let Ok(ev) = process_kad_request(msg, id) { - ( - Some(SubstreamState::InWaitingUser(id, substream)), - Some(ProtocolsHandlerEvent::Custom(ev)), - false, - ) - } else { - (Some(SubstreamState::InClosing(substream)), None, true) + SubstreamState::InWaitingMessage(id, mut substream) => { + match Stream::poll_next(Pin::new(&mut substream), cx) { + Poll::Ready(Some(Ok(msg))) => { + if let Ok(ev) = process_kad_request(msg, id) { + ( + Some(SubstreamState::InWaitingUser(id, substream)), + Some(ProtocolsHandlerEvent::Custom(ev)), + false, + ) + } else { + (Some(SubstreamState::InClosing(substream)), None, true) + } + } + Poll::Pending => ( + Some(SubstreamState::InWaitingMessage(id, substream)), + None, + false, + ), + Poll::Ready(None) => { + trace!("Inbound substream: EOF"); + (None, None, false) + } + Poll::Ready(Some(Err(e))) => { + trace!("Inbound substream error: {:?}", e); + (None, None, false) } } - Poll::Pending => ( - Some(SubstreamState::InWaitingMessage(id, substream)), - None, - false, - ), - Poll::Ready(None) => { - trace!("Inbound substream: EOF"); - (None, None, false) - } - Poll::Ready(Some(Err(e))) => { - trace!("Inbound substream error: {:?}", e); - (None, None, false) - }, - }, + } SubstreamState::InWaitingUser(id, substream) => ( Some(SubstreamState::InWaitingUser(id, substream)), None, false, ), - SubstreamState::InPendingSend(id, mut substream, msg) => match Sink::poll_ready(Pin::new(&mut substream), cx) { - Poll::Ready(Ok(())) => match Sink::start_send(Pin::new(&mut substream), msg) { - Ok(()) => ( - Some(SubstreamState::InPendingFlush(id, substream)), + SubstreamState::InPendingSend(id, mut substream, msg) => { + match Sink::poll_ready(Pin::new(&mut substream), cx) { + Poll::Ready(Ok(())) => match Sink::start_send(Pin::new(&mut substream), msg) { + Ok(()) 
=> ( + Some(SubstreamState::InPendingFlush(id, substream)), + None, + true, + ), + Err(_) => (None, None, false), + }, + Poll::Pending => ( + Some(SubstreamState::InPendingSend(id, substream, msg)), + None, + false, + ), + Poll::Ready(Err(_)) => (None, None, false), + } + } + SubstreamState::InPendingFlush(id, mut substream) => { + match Sink::poll_flush(Pin::new(&mut substream), cx) { + Poll::Ready(Ok(())) => ( + Some(SubstreamState::InWaitingMessage(id, substream)), None, true, ), - Err(_) => (None, None, false), - }, - Poll::Pending => ( - Some(SubstreamState::InPendingSend(id, substream, msg)), - None, - false, - ), - Poll::Ready(Err(_)) => (None, None, false), + Poll::Pending => ( + Some(SubstreamState::InPendingFlush(id, substream)), + None, + false, + ), + Poll::Ready(Err(_)) => (None, None, false), + } + } + SubstreamState::InClosing(mut stream) => { + match Sink::poll_close(Pin::new(&mut stream), cx) { + Poll::Ready(Ok(())) => (None, None, false), + Poll::Pending => (Some(SubstreamState::InClosing(stream)), None, false), + Poll::Ready(Err(_)) => (None, None, false), + } } - SubstreamState::InPendingFlush(id, mut substream) => match Sink::poll_flush(Pin::new(&mut substream), cx) { - Poll::Ready(Ok(())) => ( - Some(SubstreamState::InWaitingMessage(id, substream)), - None, - true, - ), - Poll::Pending => ( - Some(SubstreamState::InPendingFlush(id, substream)), - None, - false, - ), - Poll::Ready(Err(_)) => (None, None, false), - }, - SubstreamState::InClosing(mut stream) => match Sink::poll_close(Pin::new(&mut stream), cx) { - Poll::Ready(Ok(())) => (None, None, false), - Poll::Pending => (Some(SubstreamState::InClosing(stream)), None, false), - Poll::Ready(Err(_)) => (None, None, false), - }, } } @@ -987,7 +1022,7 @@ fn process_kad_request( KadRequestMsg::PutValue { record } => Ok(KademliaHandlerEvent::PutRecord { record, request_id: KademliaRequestId { connec_unique_id }, - }) + }), } } @@ -1005,11 +1040,9 @@ fn process_kad_response( user_data, } } - KadResponseMsg::FindNode { closer_peers } => { - KademliaHandlerEvent::FindNodeRes { - closer_peers, - user_data, - } + KadResponseMsg::FindNode { closer_peers } => KademliaHandlerEvent::FindNodeRes { + closer_peers, + user_data, }, KadResponseMsg::GetProviders { closer_peers, @@ -1027,12 +1060,10 @@ fn process_kad_response( closer_peers, user_data, }, - KadResponseMsg::PutValue { key, value, .. } => { - KademliaHandlerEvent::PutRecordRes { - key, - value, - user_data, - } - } + KadResponseMsg::PutValue { key, value, .. } => KademliaHandlerEvent::PutRecordRes { + key, + value, + user_data, + }, } } diff --git a/protocols/kad/src/jobs.rs b/protocols/kad/src/jobs.rs index 8737f9ad8b9..402a797a52d 100644 --- a/protocols/kad/src/jobs.rs +++ b/protocols/kad/src/jobs.rs @@ -61,15 +61,15 @@ //! > to the size of all stored records. As a job runs, the records are moved //! > out of the job to the consumer, where they can be dropped after being sent. 
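(Editorial note, not part of the patch.) The module documentation above describes how records are moved out of a background job into the consumer, where they can be dropped once sent. A hedged sketch of that consumption pattern, assuming the module's own `PutRecordJob`, `RecordStore` and `Instant` types are in scope — the driver function and its name are ours:

fn drain_put_record_job<T>(job: &mut PutRecordJob, store: &mut T, cx: &mut Context<'_>)
where
    for<'a> T: RecordStore<'a>,
{
    let now = Instant::now();
    // Each ready record is moved out of the job into the consumer and can be
    // dropped as soon as the corresponding replication requests have been sent.
    while let Poll::Ready(record) = job.poll(cx, store, now) {
        drop(record);
    }
}
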
-use crate::record::{self, Record, ProviderRecord, store::RecordStore}; -use libp2p_core::PeerId; +use crate::record::{self, store::RecordStore, ProviderRecord, Record}; use futures::prelude::*; +use libp2p_core::PeerId; use std::collections::HashSet; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; use std::vec; -use wasm_timer::{Instant, Delay}; +use wasm_timer::{Delay, Instant}; /// The maximum number of queries towards which background jobs /// are allowed to start new queries on an invocation of @@ -110,7 +110,7 @@ impl PeriodicJob { fn is_ready(&mut self, cx: &mut Context<'_>, now: Instant) -> bool { if let PeriodicJobState::Waiting(delay, deadline) = &mut self.state { if now >= *deadline || !Future::poll(Pin::new(delay), cx).is_pending() { - return true + return true; } } false @@ -121,7 +121,7 @@ impl PeriodicJob { #[derive(Debug)] enum PeriodicJobState { Running(T), - Waiting(Delay, Instant) + Waiting(Delay, Instant), } ////////////////////////////////////////////////////////////////////////////// @@ -158,8 +158,8 @@ impl PutRecordJob { skipped: HashSet::new(), inner: PeriodicJob { interval: replicate_interval, - state: PeriodicJobState::Waiting(delay, deadline) - } + state: PeriodicJobState::Waiting(delay, deadline), + }, } } @@ -192,11 +192,12 @@ impl PutRecordJob { /// to be run. pub fn poll(&mut self, cx: &mut Context<'_>, store: &mut T, now: Instant) -> Poll where - for<'a> T: RecordStore<'a> + for<'a> T: RecordStore<'a>, { if self.inner.is_ready(cx, now) { let publish = self.next_publish.map_or(false, |t_pub| now >= t_pub); - let records = store.records() + let records = store + .records() .filter_map(|r| { let is_publisher = r.publisher.as_ref() == Some(&self.local_id); if self.skipped.contains(&r.key) || (!publish && is_publisher) { @@ -204,8 +205,9 @@ impl PutRecordJob { } else { let mut record = r.into_owned(); if publish && is_publisher { - record.expires = record.expires.or_else(|| - self.record_ttl.map(|ttl| now + ttl)); + record.expires = record + .expires + .or_else(|| self.record_ttl.map(|ttl| now + ttl)); } Some(record) } @@ -228,7 +230,7 @@ impl PutRecordJob { if r.is_expired(now) { store.remove(&r.key) } else { - return Poll::Ready(r) + return Poll::Ready(r); } } @@ -248,7 +250,7 @@ impl PutRecordJob { /// Periodic job for replicating provider records. pub struct AddProviderJob { - inner: PeriodicJob> + inner: PeriodicJob>, } impl AddProviderJob { @@ -261,8 +263,8 @@ impl AddProviderJob { state: { let deadline = now + interval; PeriodicJobState::Waiting(Delay::new_at(deadline), deadline) - } - } + }, + }, } } @@ -284,12 +286,18 @@ impl AddProviderJob { /// Must be called in the context of a task. When `NotReady` is returned, /// the current task is registered to be notified when the job is ready /// to be run. 
- pub fn poll(&mut self, cx: &mut Context<'_>, store: &mut T, now: Instant) -> Poll + pub fn poll( + &mut self, + cx: &mut Context<'_>, + store: &mut T, + now: Instant, + ) -> Poll where - for<'a> T: RecordStore<'a> + for<'a> T: RecordStore<'a>, { if self.inner.is_ready(cx, now) { - let records = store.provided() + let records = store + .provided() .map(|r| r.into_owned()) .collect::>() .into_iter(); @@ -301,7 +309,7 @@ impl AddProviderJob { if r.is_expired(now) { store.remove_provider(&r.key, &r.provider) } else { - return Poll::Ready(r) + return Poll::Ready(r); } } @@ -317,11 +325,11 @@ impl AddProviderJob { #[cfg(test)] mod tests { + use super::*; use crate::record::store::MemoryStore; use futures::{executor::block_on, future::poll_fn}; use quickcheck::*; use rand::Rng; - use super::*; fn rand_put_record_job() -> PutRecordJob { let mut rng = rand::thread_rng(); diff --git a/protocols/kad/src/kbucket.rs b/protocols/kad/src/kbucket.rs index ff00b0d7ed0..111407789d8 100644 --- a/protocols/kad/src/kbucket.rs +++ b/protocols/kad/src/kbucket.rs @@ -91,7 +91,7 @@ pub struct KBucketsTable { buckets: Vec>, /// The list of evicted entries that have been replaced with pending /// entries since the last call to [`KBucketsTable::take_applied_pending`]. - applied_pending: VecDeque> + applied_pending: VecDeque>, } /// A (type-safe) index into a `KBucketsTable`, i.e. a non-negative integer in the @@ -132,7 +132,7 @@ impl BucketIndex { fn rand_distance(&self, rng: &mut impl rand::Rng) -> Distance { let mut bytes = [0u8; 32]; let quot = self.0 / 8; - for i in 0 .. quot { + for i in 0..quot { bytes[31 - i] = rng.gen(); } let rem = (self.0 % 8) as u32; @@ -146,7 +146,7 @@ impl BucketIndex { impl KBucketsTable where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { /// Creates a new, empty Kademlia routing table with entries partitioned /// into buckets as per the Kademlia protocol. @@ -157,8 +157,10 @@ where pub fn new(local_key: TKey, pending_timeout: Duration) -> Self { KBucketsTable { local_key, - buckets: (0 .. NUM_BUCKETS).map(|_| KBucket::new(pending_timeout)).collect(), - applied_pending: VecDeque::new() + buckets: (0..NUM_BUCKETS) + .map(|_| KBucket::new(pending_timeout)) + .collect(), + applied_pending: VecDeque::new(), } } @@ -194,7 +196,7 @@ where } KBucketRef { index: BucketIndex(i), - bucket: b + bucket: b, } }) } @@ -236,10 +238,9 @@ where /// Returns an iterator over the keys closest to `target`, ordered by /// increasing distance. - pub fn closest_keys<'a, T>(&'a mut self, target: &'a T) - -> impl Iterator + 'a + pub fn closest_keys<'a, T>(&'a mut self, target: &'a T) -> impl Iterator + 'a where - T: Clone + AsRef + T: Clone + AsRef, { let distance = self.local_key.as_ref().distance(target); ClosestIter { @@ -248,18 +249,20 @@ where table: self, buckets_iter: ClosestBucketsIter::new(distance), fmap: |b: &KBucket| -> ArrayVec<_> { - b.iter().map(|(n,_)| n.key.clone()).collect() - } + b.iter().map(|(n, _)| n.key.clone()).collect() + }, } } /// Returns an iterator over the nodes closest to the `target` key, ordered by /// increasing distance. 
- pub fn closest<'a, T>(&'a mut self, target: &'a T) - -> impl Iterator> + 'a + pub fn closest<'a, T>( + &'a mut self, + target: &'a T, + ) -> impl Iterator> + 'a where T: Clone + AsRef, - TVal: Clone + TVal: Clone, { let distance = self.local_key.as_ref().distance(target); ClosestIter { @@ -268,11 +271,13 @@ where table: self, buckets_iter: ClosestBucketsIter::new(distance), fmap: |b: &KBucket<_, TVal>| -> ArrayVec<_> { - b.iter().map(|(n, status)| EntryView { - node: n.clone(), - status - }).collect() - } + b.iter() + .map(|(n, status)| EntryView { + node: n.clone(), + status, + }) + .collect() + }, } } @@ -283,14 +288,15 @@ where /// calculated by backtracking from the target towards the local key. pub fn count_nodes_between(&mut self, target: &T) -> usize where - T: AsRef + T: AsRef, { let local_key = self.local_key.clone(); let distance = target.as_ref().distance(&local_key); let mut iter = ClosestBucketsIter::new(distance).take_while(|i| i.get() != 0); if let Some(i) = iter.next() { - let num_first = self.buckets[i.get()].iter() - .filter(|(n,_)| n.key.as_ref().distance(&local_key) <= distance) + let num_first = self.buckets[i.get()] + .iter() + .filter(|(n, _)| n.key.as_ref().distance(&local_key) <= distance) .count(); let num_rest: usize = iter.map(|i| self.buckets[i.get()].num_entries()).sum(); num_first + num_rest @@ -317,7 +323,7 @@ struct ClosestIter<'a, TTarget, TKey, TVal, TMap, TOut> { iter: Option>, /// The projection function / mapping applied on each bucket as /// it is encountered, producing the next `iter`ator. - fmap: TMap + fmap: TMap, } /// An iterator over the bucket indices, in the order determined by the `Distance` of @@ -327,7 +333,7 @@ struct ClosestBucketsIter { /// The distance to the `local_key`. distance: Distance, /// The current state of the iterator. - state: ClosestBucketsIterState + state: ClosestBucketsIterState, } /// Operating states of a `ClosestBucketsIter`. @@ -348,34 +354,36 @@ enum ClosestBucketsIterState { /// `255` is reached, the iterator transitions to state `Done`. ZoomOut(BucketIndex), /// The iterator is in this state once it has visited all buckets. - Done + Done, } impl ClosestBucketsIter { fn new(distance: Distance) -> Self { let state = match BucketIndex::new(&distance) { Some(i) => ClosestBucketsIterState::Start(i), - None => ClosestBucketsIterState::Start(BucketIndex(0)) + None => ClosestBucketsIterState::Start(BucketIndex(0)), }; Self { distance, state } } fn next_in(&self, i: BucketIndex) -> Option { - (0 .. i.get()).rev().find_map(|i| + (0..i.get()).rev().find_map(|i| { if self.distance.0.bit(i) { Some(BucketIndex(i)) } else { None - }) + } + }) } fn next_out(&self, i: BucketIndex) -> Option { - (i.get() + 1 .. 
NUM_BUCKETS).find_map(|i| + (i.get() + 1..NUM_BUCKETS).find_map(|i| { if !self.distance.0.bit(i) { Some(BucketIndex(i)) } else { None - }) + } + }) } } @@ -388,7 +396,7 @@ impl Iterator for ClosestBucketsIter { self.state = ClosestBucketsIterState::ZoomIn(i); Some(i) } - ClosestBucketsIterState::ZoomIn(i) => + ClosestBucketsIterState::ZoomIn(i) => { if let Some(i) = self.next_in(i) { self.state = ClosestBucketsIterState::ZoomIn(i); Some(i) @@ -397,7 +405,8 @@ impl Iterator for ClosestBucketsIter { self.state = ClosestBucketsIterState::ZoomOut(i); Some(i) } - ClosestBucketsIterState::ZoomOut(i) => + } + ClosestBucketsIterState::ZoomOut(i) => { if let Some(i) = self.next_out(i) { self.state = ClosestBucketsIterState::ZoomOut(i); Some(i) @@ -405,19 +414,19 @@ impl Iterator for ClosestBucketsIter { self.state = ClosestBucketsIterState::Done; None } - ClosestBucketsIterState::Done => None + } + ClosestBucketsIterState::Done => None, } } } -impl Iterator -for ClosestIter<'_, TTarget, TKey, TVal, TMap, TOut> +impl Iterator for ClosestIter<'_, TTarget, TKey, TVal, TMap, TOut> where TTarget: AsRef, TKey: Clone + AsRef, TVal: Clone, TMap: Fn(&KBucket) -> ArrayVec<[TOut; K_VALUE.get()]>, - TOut: AsRef + TOut: AsRef, { type Item = TOut; @@ -426,8 +435,8 @@ where match &mut self.iter { Some(iter) => match iter.next() { Some(k) => return Some(k), - None => self.iter = None - } + None => self.iter = None, + }, None => { if let Some(i) = self.buckets_iter.next() { let bucket = &mut self.table.buckets[i.get()]; @@ -435,12 +444,15 @@ where self.table.applied_pending.push_back(applied) } let mut v = (self.fmap)(bucket); - v.sort_by(|a, b| - self.target.as_ref().distance(a.as_ref()) - .cmp(&self.target.as_ref().distance(b.as_ref()))); + v.sort_by(|a, b| { + self.target + .as_ref() + .distance(a.as_ref()) + .cmp(&self.target.as_ref().distance(b.as_ref())) + }); self.iter = Some(v.into_iter()); } else { - return None + return None; } } } @@ -451,13 +463,13 @@ where /// A reference to a bucket in a [`KBucketsTable`]. pub struct KBucketRef<'a, TKey, TVal> { index: BucketIndex, - bucket: &'a mut KBucket + bucket: &'a mut KBucket, } impl<'a, TKey, TVal> KBucketRef<'a, TKey, TVal> where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { /// Returns the minimum inclusive and maximum inclusive [`Distance`] for /// this bucket. @@ -497,14 +509,12 @@ where /// Returns an iterator over the entries in the bucket. pub fn iter(&'a self) -> impl Iterator> { - self.bucket.iter().map(move |(n, status)| { - EntryRefView { - node: NodeRefView { - key: &n.key, - value: &n.value - }, - status - } + self.bucket.iter().map(move |(n, status)| EntryRefView { + node: NodeRefView { + key: &n.key, + value: &n.value, + }, + status, }) } } @@ -528,14 +538,17 @@ mod tests { let ix = BucketIndex(i); let num = g.gen_range(0, usize::min(K_VALUE.get(), num_total) + 1); num_total -= num; - for _ in 0 .. 
num { + for _ in 0..num { let distance = ix.rand_distance(g); let key = local_key.for_distance(distance); - let node = Node { key: key.clone(), value: () }; + let node = Node { + key: key.clone(), + value: (), + }; let status = NodeStatus::arbitrary(g); match b.insert(node, status) { InsertResult::Inserted => {} - _ => panic!() + _ => panic!(), } } } @@ -607,7 +620,7 @@ mod tests { if let Entry::Absent(entry) = table.entry(&other_id) { match entry.insert((), NodeStatus::Connected) { InsertResult::Inserted => (), - _ => panic!() + _ => panic!(), } } else { panic!() @@ -634,7 +647,9 @@ mod tests { let mut table = KBucketsTable::<_, ()>::new(local_key, Duration::from_secs(5)); let mut count = 0; loop { - if count == 100 { break; } + if count == 100 { + break; + } let key = Key::from(PeerId::random()); if let Entry::Absent(e) = table.entry(&key) { match e.insert((), NodeStatus::Connected) { @@ -646,12 +661,13 @@ mod tests { } } - let mut expected_keys: Vec<_> = table.buckets + let mut expected_keys: Vec<_> = table + .buckets .iter() - .flat_map(|t| t.iter().map(|(n,_)| n.key.clone())) + .flat_map(|t| t.iter().map(|(n, _)| n.key.clone())) .collect(); - for _ in 0 .. 10 { + for _ in 0..10 { let target_key = Key::from(PeerId::random()); let keys = table.closest_keys(&target_key).collect::>(); // The list of keys is expected to match the result of a full-table scan. @@ -675,18 +691,24 @@ mod tests { match e.insert((), NodeStatus::Connected) { InsertResult::Pending { disconnected } => { expected_applied = AppliedPending { - inserted: Node { key: key.clone(), value: () }, - evicted: Some(Node { key: disconnected, value: () }) + inserted: Node { + key: key.clone(), + value: (), + }, + evicted: Some(Node { + key: disconnected, + value: (), + }), }; full_bucket_index = BucketIndex::new(&key.distance(&local_key)); - break - }, - _ => panic!() + break; + } + _ => panic!(), } } else { panic!() } - }, + } _ => continue, } } else { @@ -701,12 +723,12 @@ mod tests { match table.entry(&expected_applied.inserted.key) { Entry::Present(_, NodeStatus::Connected) => {} - x => panic!("Unexpected entry: {:?}", x) + x => panic!("Unexpected entry: {:?}", x), } match table.entry(&expected_applied.evicted.as_ref().unwrap().key) { Entry::Absent(_) => {} - x => panic!("Unexpected entry: {:?}", x) + x => panic!("Unexpected entry: {:?}", x), } assert_eq!(Some(expected_applied), table.take_applied_pending()); @@ -734,6 +756,8 @@ mod tests { }) } - QuickCheck::new().tests(10).quickcheck(prop as fn(_,_) -> _) + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _) -> _) } } diff --git a/protocols/kad/src/kbucket/bucket.rs b/protocols/kad/src/kbucket/bucket.rs index e9729917e8f..b9d34519d5d 100644 --- a/protocols/kad/src/kbucket/bucket.rs +++ b/protocols/kad/src/kbucket/bucket.rs @@ -25,8 +25,8 @@ //! > buckets in a `KBucketsTable` and hence is enforced by the public API //! > of the `KBucketsTable` and in particular the public `Entry` API. -pub use crate::K_VALUE; use super::*; +pub use crate::K_VALUE; /// A `PendingNode` is a `Node` that is pending insertion into a `KBucket`. #[derive(Debug, Clone)] @@ -51,7 +51,7 @@ pub enum NodeStatus { /// The node is considered connected. Connected, /// The node is considered disconnected. - Disconnected + Disconnected, } impl PendingNode { @@ -125,29 +125,29 @@ pub struct KBucket { /// The timeout window before a new pending node is eligible for insertion, /// if the least-recently connected node is not updated as being connected /// in the meantime. 
- pending_timeout: Duration + pending_timeout: Duration, } /// The result of inserting an entry into a bucket. #[must_use] #[derive(Debug, Clone, PartialEq, Eq)] pub enum InsertResult { - /// The entry has been successfully inserted. - Inserted, - /// The entry is pending insertion because the relevant bucket is currently full. - /// The entry is inserted after a timeout elapsed, if the status of the - /// least-recently connected (and currently disconnected) node in the bucket - /// is not updated before the timeout expires. - Pending { - /// The key of the least-recently connected entry that is currently considered - /// disconnected and whose corresponding peer should be checked for connectivity - /// in order to prevent it from being evicted. If connectivity to the peer is - /// re-established, the corresponding entry should be updated with - /// [`NodeStatus::Connected`]. - disconnected: TKey - }, - /// The entry was not inserted because the relevant bucket is full. - Full + /// The entry has been successfully inserted. + Inserted, + /// The entry is pending insertion because the relevant bucket is currently full. + /// The entry is inserted after a timeout elapsed, if the status of the + /// least-recently connected (and currently disconnected) node in the bucket + /// is not updated before the timeout expires. + Pending { + /// The key of the least-recently connected entry that is currently considered + /// disconnected and whose corresponding peer should be checked for connectivity + /// in order to prevent it from being evicted. If connectivity to the peer is + /// re-established, the corresponding entry should be updated with + /// [`NodeStatus::Connected`]. + disconnected: TKey, + }, + /// The entry was not inserted because the relevant bucket is full. + Full, } /// The result of applying a pending node to a bucket, possibly @@ -158,13 +158,13 @@ pub struct AppliedPending { pub inserted: Node, /// The node that has been evicted from the bucket to make room for the /// pending node, if any. - pub evicted: Option> + pub evicted: Option>, } impl KBucket where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { /// Creates a new `KBucket` with the given timeout for pending entries. pub fn new(pending_timeout: Duration) -> Self { @@ -189,7 +189,8 @@ where /// Returns a reference to the pending node of the bucket, if there is any /// with a matching key. pub fn as_pending(&self, key: &TKey) -> Option<&PendingNode> { - self.pending().filter(|p| p.node.key.as_ref() == key.as_ref()) + self.pending() + .filter(|p| p.node.key.as_ref() == key.as_ref()) } /// Returns a reference to a node in the bucket. @@ -199,7 +200,10 @@ where /// Returns an iterator over the nodes in the bucket, together with their status. pub fn iter(&self) -> impl Iterator, NodeStatus)> { - self.nodes.iter().enumerate().map(move |(p, n)| (n, self.status(Position(p)))) + self.nodes + .iter() + .enumerate() + .map(move |(p, n)| (n, self.status(Position(p)))) } /// Inserts the pending node into the bucket, if its timeout has elapsed, @@ -214,21 +218,20 @@ where if self.nodes.is_full() { if self.status(Position(0)) == NodeStatus::Connected { // The bucket is full with connected nodes. Drop the pending node. - return None + return None; } debug_assert!(self.first_connected_pos.map_or(true, |p| p > 0)); // (*) - // The pending node will be inserted. + // The pending node will be inserted. 
let inserted = pending.node.clone(); // A connected pending node goes at the end of the list for // the connected peers, removing the least-recently connected. if pending.status == NodeStatus::Connected { let evicted = Some(self.nodes.remove(0)); - self.first_connected_pos = self.first_connected_pos - .map_or_else( - | | Some(self.nodes.len()), - |p| p.checked_sub(1)); + self.first_connected_pos = self + .first_connected_pos + .map_or_else(|| Some(self.nodes.len()), |p| p.checked_sub(1)); self.nodes.push(pending.node); - return Some(AppliedPending { inserted, evicted }) + return Some(AppliedPending { inserted, evicted }); } // A disconnected pending node goes at the end of the list // for the disconnected peers. @@ -236,21 +239,25 @@ where let insert_pos = p.checked_sub(1).expect("by (*)"); let evicted = Some(self.nodes.remove(0)); self.nodes.insert(insert_pos, pending.node); - return Some(AppliedPending { inserted, evicted }) + return Some(AppliedPending { inserted, evicted }); } else { // All nodes are disconnected. Insert the new node as the most // recently disconnected, removing the least-recently disconnected. let evicted = Some(self.nodes.remove(0)); self.nodes.push(pending.node); - return Some(AppliedPending { inserted, evicted }) + return Some(AppliedPending { inserted, evicted }); } } else { // There is room in the bucket, so just insert the pending node. let inserted = pending.node.clone(); match self.insert(pending.node, pending.status) { - InsertResult::Inserted => - return Some(AppliedPending { inserted, evicted: None }), - _ => unreachable!("Bucket is not full.") + InsertResult::Inserted => { + return Some(AppliedPending { + inserted, + evicted: None, + }) + } + _ => unreachable!("Bucket is not full."), } } } else { @@ -289,8 +296,8 @@ where } // Reinsert the node with the desired status. match self.insert(node, status) { - InsertResult::Inserted => {}, - _ => unreachable!("The node is removed before being (re)inserted.") + InsertResult::Inserted => {} + _ => unreachable!("The node is removed before being (re)inserted."), } } } @@ -317,7 +324,7 @@ where NodeStatus::Connected => { if self.nodes.is_full() { if self.first_connected_pos == Some(0) || self.pending.is_some() { - return InsertResult::Full + return InsertResult::Full; } else { self.pending = Some(PendingNode { node, @@ -325,8 +332,8 @@ where replace: Instant::now() + self.pending_timeout, }); return InsertResult::Pending { - disconnected: self.nodes[0].key.clone() - } + disconnected: self.nodes[0].key.clone(), + }; } } let pos = self.nodes.len(); @@ -336,7 +343,7 @@ where } NodeStatus::Disconnected => { if self.nodes.is_full() { - return InsertResult::Full + return InsertResult::Full; } if let Some(ref mut p) = self.first_connected_pos { self.nodes.insert(*p, node); @@ -357,17 +364,19 @@ where let node = self.nodes.remove(pos.0); // Adjust `first_connected_pos` accordingly. match status { - NodeStatus::Connected => + NodeStatus::Connected => { if self.first_connected_pos.map_or(false, |p| p == pos.0) { if pos.0 == self.nodes.len() { // It was the last connected node. self.first_connected_pos = None } } - NodeStatus::Disconnected => + } + NodeStatus::Disconnected => { if let Some(ref mut p) = self.first_connected_pos { *p -= 1; } + } } Some((node, status, pos)) } else { @@ -406,7 +415,10 @@ where /// Gets the position of an node in the bucket. 
pub fn position(&self, key: &TKey) -> Option { - self.nodes.iter().position(|p| p.key.as_ref() == key.as_ref()).map(Position) + self.nodes + .iter() + .position(|p| p.key.as_ref() == key.as_ref()) + .map(Position) } /// Gets a mutable reference to the node identified by the given key. @@ -414,30 +426,35 @@ where /// Returns `None` if the given key does not refer to a node in the /// bucket. pub fn get_mut(&mut self, key: &TKey) -> Option<&mut Node> { - self.nodes.iter_mut().find(move |p| p.key.as_ref() == key.as_ref()) + self.nodes + .iter_mut() + .find(move |p| p.key.as_ref() == key.as_ref()) } } #[cfg(test)] mod tests { + use super::*; use libp2p_core::PeerId; + use quickcheck::*; use rand::Rng; use std::collections::VecDeque; - use super::*; - use quickcheck::*; impl Arbitrary for KBucket, ()> { fn arbitrary(g: &mut G) -> KBucket, ()> { let timeout = Duration::from_secs(g.gen_range(1, g.size() as u64)); let mut bucket = KBucket::, ()>::new(timeout); let num_nodes = g.gen_range(1, K_VALUE.get() + 1); - for _ in 0 .. num_nodes { + for _ in 0..num_nodes { let key = Key::from(PeerId::random()); - let node = Node { key: key.clone(), value: () }; + let node = Node { + key: key.clone(), + value: (), + }; let status = NodeStatus::arbitrary(g); match bucket.insert(node, status) { InsertResult::Inserted => {} - _ => panic!() + _ => panic!(), } } bucket @@ -463,7 +480,7 @@ mod tests { // Fill a bucket with random nodes with the given status. fn fill_bucket(bucket: &mut KBucket, ()>, status: NodeStatus) { let num_entries_start = bucket.num_entries(); - for i in 0 .. K_VALUE.get() - num_entries_start { + for i in 0..K_VALUE.get() - num_entries_start { let key = Key::from(PeerId::random()); let node = Node { key, value: () }; assert_eq!(InsertResult::Inserted, bucket.insert(node, status)); @@ -483,13 +500,16 @@ mod tests { // Fill the bucket, thereby populating the expected lists in insertion order. for status in status { let key = Key::from(PeerId::random()); - let node = Node { key: key.clone(), value: () }; + let node = Node { + key: key.clone(), + value: (), + }; let full = bucket.num_entries() == K_VALUE.get(); match bucket.insert(node, status) { InsertResult::Inserted => { let vec = match status { NodeStatus::Connected => &mut connected, - NodeStatus::Disconnected => &mut disconnected + NodeStatus::Disconnected => &mut disconnected, }; if full { vec.pop_front(); @@ -501,21 +521,20 @@ mod tests { } // Get all nodes from the bucket, together with their status. - let mut nodes = bucket.iter() + let mut nodes = bucket + .iter() .map(|(n, s)| (s, n.key.clone())) .collect::>(); // Split the list of nodes at the first connected node. - let first_connected_pos = nodes.iter().position(|(s,_)| *s == NodeStatus::Connected); + let first_connected_pos = nodes.iter().position(|(s, _)| *s == NodeStatus::Connected); assert_eq!(bucket.first_connected_pos, first_connected_pos); let tail = first_connected_pos.map_or(Vec::new(), |p| nodes.split_off(p)); // All nodes before the first connected node must be disconnected and // in insertion order. Similarly, all remaining nodes must be connected // and in insertion order. 
- nodes == Vec::from(disconnected) - && - tail == Vec::from(connected) + nodes == Vec::from(disconnected) && tail == Vec::from(connected) } quickcheck(prop as fn(_) -> _); @@ -532,12 +551,12 @@ mod tests { let key = Key::from(PeerId::random()); let node = Node { key, value: () }; match bucket.insert(node, NodeStatus::Disconnected) { - InsertResult::Full => {}, - x => panic!("{:?}", x) + InsertResult::Full => {} + x => panic!("{:?}", x), } // One-by-one fill the bucket with connected nodes, replacing the disconnected ones. - for i in 0 .. K_VALUE.get() { + for i in 0..K_VALUE.get() { let (first, first_status) = bucket.iter().next().unwrap(); let first_disconnected = first.clone(); assert_eq!(first_status, NodeStatus::Disconnected); @@ -545,17 +564,21 @@ mod tests { // Add a connected node, which is expected to be pending, scheduled to // replace the first (i.e. least-recently connected) node. let key = Key::from(PeerId::random()); - let node = Node { key: key.clone(), value: () }; + let node = Node { + key: key.clone(), + value: (), + }; match bucket.insert(node.clone(), NodeStatus::Connected) { - InsertResult::Pending { disconnected } => - assert_eq!(disconnected, first_disconnected.key), - x => panic!("{:?}", x) + InsertResult::Pending { disconnected } => { + assert_eq!(disconnected, first_disconnected.key) + } + x => panic!("{:?}", x), } // Trying to insert another connected node fails. match bucket.insert(node.clone(), NodeStatus::Connected) { - InsertResult::Full => {}, - x => panic!("{:?}", x) + InsertResult::Full => {} + x => panic!("{:?}", x), } assert!(bucket.pending().is_some()); @@ -564,10 +587,13 @@ mod tests { let pending = bucket.pending_mut().expect("No pending node."); pending.set_ready_at(Instant::now() - Duration::from_secs(1)); let result = bucket.apply_pending(); - assert_eq!(result, Some(AppliedPending { - inserted: node.clone(), - evicted: Some(first_disconnected) - })); + assert_eq!( + result, + Some(AppliedPending { + inserted: node.clone(), + evicted: Some(first_disconnected) + }) + ); assert_eq!(Some((&node, NodeStatus::Connected)), bucket.iter().last()); assert!(bucket.pending().is_none()); assert_eq!(Some(K_VALUE.get() - (i + 1)), bucket.first_connected_pos); @@ -580,8 +606,8 @@ mod tests { let key = Key::from(PeerId::random()); let node = Node { key, value: () }; match bucket.insert(node, NodeStatus::Connected) { - InsertResult::Full => {}, - x => panic!("{:?}", x) + InsertResult::Full => {} + x => panic!("{:?}", x), } } @@ -594,7 +620,10 @@ mod tests { // Add a connected pending node. let key = Key::from(PeerId::random()); - let node = Node { key: key.clone(), value: () }; + let node = Node { + key: key.clone(), + value: (), + }; if let InsertResult::Pending { disconnected } = bucket.insert(node, NodeStatus::Connected) { assert_eq!(&disconnected, &first_disconnected.key); } else { @@ -607,16 +636,21 @@ mod tests { // The pending node has been discarded. assert!(bucket.pending().is_none()); - assert!(bucket.iter().all(|(n,_)| &n.key != &key)); + assert!(bucket.iter().all(|(n, _)| &n.key != &key)); // The initially disconnected node is now the most-recently connected. 
- assert_eq!(Some((&first_disconnected, NodeStatus::Connected)), bucket.iter().last()); - assert_eq!(bucket.position(&first_disconnected.key).map(|p| p.0), bucket.first_connected_pos); + assert_eq!( + Some((&first_disconnected, NodeStatus::Connected)), + bucket.iter().last() + ); + assert_eq!( + bucket.position(&first_disconnected.key).map(|p| p.0), + bucket.first_connected_pos + ); assert_eq!(1, bucket.num_connected()); assert_eq!(K_VALUE.get() - 1, bucket.num_disconnected()); } - #[test] fn bucket_update() { fn prop(mut bucket: KBucket, ()>, pos: Position, status: NodeStatus) -> bool { @@ -627,7 +661,10 @@ mod tests { let key = bucket.nodes[pos].key.clone(); // Record the (ordered) list of status of all nodes in the bucket. - let mut expected = bucket.iter().map(|(n,s)| (n.key.clone(), s)).collect::>(); + let mut expected = bucket + .iter() + .map(|(n, s)| (n.key.clone(), s)) + .collect::>(); // Update the node in the bucket. bucket.update(&key, status); @@ -636,14 +673,17 @@ mod tests { // preserving the status and relative order of all other nodes. let expected_pos = match status { NodeStatus::Connected => num_nodes - 1, - NodeStatus::Disconnected => bucket.first_connected_pos.unwrap_or(num_nodes) - 1 + NodeStatus::Disconnected => bucket.first_connected_pos.unwrap_or(num_nodes) - 1, }; expected.remove(pos); expected.insert(expected_pos, (key.clone(), status)); - let actual = bucket.iter().map(|(n,s)| (n.key.clone(), s)).collect::>(); + let actual = bucket + .iter() + .map(|(n, s)| (n.key.clone(), s)) + .collect::>(); expected == actual } - quickcheck(prop as fn(_,_,_) -> _); + quickcheck(prop as fn(_, _, _) -> _); } } diff --git a/protocols/kad/src/kbucket/entry.rs b/protocols/kad/src/kbucket/entry.rs index e72140cec73..3447146007b 100644 --- a/protocols/kad/src/kbucket/entry.rs +++ b/protocols/kad/src/kbucket/entry.rs @@ -21,7 +21,7 @@ //! The `Entry` API for quering and modifying the entries of a `KBucketsTable` //! representing the nodes participating in the Kademlia DHT. -pub use super::bucket::{Node, NodeStatus, InsertResult, AppliedPending, K_VALUE}; +pub use super::bucket::{AppliedPending, InsertResult, Node, NodeStatus, K_VALUE}; pub use super::key::*; use super::*; @@ -31,27 +31,27 @@ pub struct EntryRefView<'a, TPeerId, TVal> { /// The node represented by the entry. pub node: NodeRefView<'a, TPeerId, TVal>, /// The status of the node identified by the key. - pub status: NodeStatus + pub status: NodeStatus, } /// An immutable by-reference view of a `Node`. pub struct NodeRefView<'a, TKey, TVal> { pub key: &'a TKey, - pub value: &'a TVal + pub value: &'a TVal, } impl EntryRefView<'_, TKey, TVal> { pub fn to_owned(&self) -> EntryView where TKey: Clone, - TVal: Clone + TVal: Clone, { EntryView { node: Node { key: self.node.key.clone(), - value: self.node.value.clone() + value: self.node.value.clone(), }, - status: self.status + status: self.status, } } } @@ -63,7 +63,7 @@ pub struct EntryView { /// The node represented by the entry. pub node: Node, /// The status of the node. - pub status: NodeStatus + pub status: NodeStatus, } impl, TVal> AsRef for EntryView { @@ -96,7 +96,7 @@ struct EntryRef<'a, TKey, TVal> { impl<'a, TKey, TVal> Entry<'a, TKey, TVal> where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { /// Creates a new `Entry` for a `Key`, encapsulating access to a bucket. 
pub(super) fn new(bucket: &'a mut KBucket, key: &'a TKey) -> Self { @@ -120,18 +120,18 @@ where Entry::Present(entry, status) => Some(EntryRefView { node: NodeRefView { key: entry.0.key, - value: entry.value() + value: entry.value(), }, - status: *status + status: *status, }), Entry::Pending(entry, status) => Some(EntryRefView { node: NodeRefView { key: entry.0.key, - value: entry.value() + value: entry.value(), }, - status: *status + status: *status, }), - _ => None + _ => None, } } @@ -170,7 +170,7 @@ pub struct PresentEntry<'a, TKey, TVal>(EntryRef<'a, TKey, TVal>); impl<'a, TKey, TVal> PresentEntry<'a, TKey, TVal> where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { fn new(bucket: &'a mut KBucket, key: &'a TKey) -> Self { PresentEntry(EntryRef { bucket, key }) @@ -183,7 +183,9 @@ where /// Returns the value associated with the key. pub fn value(&mut self) -> &mut TVal { - &mut self.0.bucket + &mut self + .0 + .bucket .get_mut(self.0.key) .expect("We can only build a PresentEntry if the entry is in the bucket; QED") .value @@ -196,7 +198,9 @@ where /// Removes the entry from the bucket. pub fn remove(self) -> EntryView { - let (node, status, _pos) = self.0.bucket + let (node, status, _pos) = self + .0 + .bucket .remove(&self.0.key) .expect("We can only build a PresentEntry if the entry is in the bucket; QED"); EntryView { node, status } @@ -210,7 +214,7 @@ pub struct PendingEntry<'a, TKey, TVal>(EntryRef<'a, TKey, TVal>); impl<'a, TKey, TVal> PendingEntry<'a, TKey, TVal> where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { fn new(bucket: &'a mut KBucket, key: &'a TKey) -> Self { PendingEntry(EntryRef { bucket, key }) @@ -223,7 +227,8 @@ where /// Returns the value associated with the key. pub fn value(&mut self) -> &mut TVal { - self.0.bucket + self.0 + .bucket .pending_mut() .expect("We can only build a ConnectedPendingEntry if the entry is pending; QED") .value_mut() @@ -237,10 +242,10 @@ where /// Removes the pending entry from the bucket. pub fn remove(self) -> EntryView { - let pending = self.0.bucket - .remove_pending() - .expect("We can only build a PendingEntry if the entry is pending insertion - into the bucket; QED"); + let pending = self.0.bucket.remove_pending().expect( + "We can only build a PendingEntry if the entry is pending insertion + into the bucket; QED", + ); let status = pending.status(); let node = pending.into_node(); EntryView { node, status } @@ -254,7 +259,7 @@ pub struct AbsentEntry<'a, TKey, TVal>(EntryRef<'a, TKey, TVal>); impl<'a, TKey, TVal> AbsentEntry<'a, TKey, TVal> where TKey: Clone + AsRef, - TVal: Clone + TVal: Clone, { fn new(bucket: &'a mut KBucket, key: &'a TKey) -> Self { AbsentEntry(EntryRef { bucket, key }) @@ -267,9 +272,12 @@ where /// Attempts to insert the entry into a bucket. pub fn insert(self, value: TVal, status: NodeStatus) -> InsertResult { - self.0.bucket.insert(Node { - key: self.0.key.clone(), - value - }, status) + self.0.bucket.insert( + Node { + key: self.0.key.clone(), + value, + }, + status, + ) } } diff --git a/protocols/kad/src/kbucket/key.rs b/protocols/kad/src/kbucket/key.rs index 38eb825ae66..ca3444da636 100644 --- a/protocols/kad/src/kbucket/key.rs +++ b/protocols/kad/src/kbucket/key.rs @@ -18,13 +18,13 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
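// Illustrative sketch (not from the patch): the bucket.rs and entry.rs hunks
// above are formatting-only; the invariant they preserve is that a bucket
// keeps its disconnected nodes first and its connected nodes last, each
// sub-list in insertion order with the most recent at the end, and with
// `first_connected_pos` marking the boundary between the two. A simplified
// stand-alone model of that ordering (capacity and pending-node handling
// omitted):

#[derive(Clone, Copy, PartialEq, Debug)]
enum Status {
    Connected,
    Disconnected,
}

struct Bucket<T> {
    // Disconnected prefix followed by connected suffix.
    nodes: Vec<(T, Status)>,
    // Index of the first connected node, if any.
    first_connected_pos: Option<usize>,
}

impl<T> Bucket<T> {
    fn new() -> Self {
        Bucket { nodes: Vec::new(), first_connected_pos: None }
    }

    // Insert a node as the most recent of its status class.
    fn insert(&mut self, node: T, status: Status) {
        match status {
            Status::Connected => {
                // Connected nodes are appended at the very end.
                if self.first_connected_pos.is_none() {
                    self.first_connected_pos = Some(self.nodes.len());
                }
                self.nodes.push((node, status));
            }
            Status::Disconnected => {
                // Disconnected nodes go just before the first connected one.
                let pos = self.first_connected_pos.unwrap_or(self.nodes.len());
                self.nodes.insert(pos, (node, status));
                if let Some(p) = self.first_connected_pos.as_mut() {
                    *p += 1;
                }
            }
        }
    }
}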
-use uint::*; -use libp2p_core::{PeerId, multihash::Multihash}; +use crate::record; +use libp2p_core::{multihash::Multihash, PeerId}; +use sha2::digest::generic_array::{typenum::U32, GenericArray}; use sha2::{Digest, Sha256}; -use sha2::digest::generic_array::{GenericArray, typenum::U32}; use std::borrow::Borrow; use std::hash::{Hash, Hasher}; -use crate::record; +use uint::*; construct_uint! { /// 256-bit unsigned integer. @@ -52,7 +52,7 @@ impl Key { /// [`Key::into_preimage`]. pub fn new(preimage: T) -> Key where - T: Borrow<[u8]> + T: Borrow<[u8]>, { let bytes = KeyBytes::new(preimage.borrow()); Key { preimage, bytes } @@ -71,7 +71,7 @@ impl Key { /// Computes the distance of the keys according to the XOR metric. pub fn distance(&self, other: &U) -> Distance where - U: AsRef + U: AsRef, { self.bytes.distance(other) } @@ -93,22 +93,16 @@ impl Into for Key { } impl From for Key { - fn from(m: Multihash) -> Self { - let bytes = KeyBytes(Sha256::digest(&m.to_bytes())); - Key { - preimage: m, - bytes - } - } + fn from(m: Multihash) -> Self { + let bytes = KeyBytes(Sha256::digest(&m.to_bytes())); + Key { preimage: m, bytes } + } } impl From for Key { fn from(p: PeerId) -> Self { - let bytes = KeyBytes(Sha256::digest(&p.to_bytes())); - Key { - preimage: p, - bytes - } + let bytes = KeyBytes(Sha256::digest(&p.to_bytes())); + Key { preimage: p, bytes } } } @@ -153,7 +147,7 @@ impl KeyBytes { /// value through a random oracle. pub fn new(value: T) -> Self where - T: Borrow<[u8]> + T: Borrow<[u8]>, { KeyBytes(Sha256::digest(value.borrow())) } @@ -161,7 +155,7 @@ impl KeyBytes { /// Computes the distance of the keys according to the XOR metric. pub fn distance(&self, other: &U) -> Distance where - U: AsRef + U: AsRef, { let a = U256::from(self.0.as_slice()); let b = U256::from(other.as_ref().0.as_slice()); @@ -201,8 +195,8 @@ impl Distance { #[cfg(test)] mod tests { use super::*; - use quickcheck::*; use libp2p_core::multihash::Code; + use quickcheck::*; use rand::Rng; impl Arbitrary for Key { @@ -231,7 +225,7 @@ mod tests { fn prop(a: Key, b: Key) -> bool { a.distance(&b) == b.distance(&a) } - quickcheck(prop as fn(_,_) -> _) + quickcheck(prop as fn(_, _) -> _) } #[test] @@ -246,18 +240,18 @@ mod tests { TestResult::from_bool(a.distance(&c) <= Distance(ab_plus_bc)) } } - quickcheck(prop as fn(_,_,_) -> _) + quickcheck(prop as fn(_, _, _) -> _) } #[test] fn unidirectionality() { fn prop(a: Key, b: Key) -> bool { let d = a.distance(&b); - (0 .. 
100).all(|_| { + (0..100).all(|_| { let c = Key::from(PeerId::random()); a.distance(&c) != d || b == c }) } - quickcheck(prop as fn(_,_) -> _) + quickcheck(prop as fn(_, _) -> _) } } diff --git a/protocols/kad/src/lib.rs b/protocols/kad/src/lib.rs index 30819ec0056..0fbeb61587d 100644 --- a/protocols/kad/src/lib.rs +++ b/protocols/kad/src/lib.rs @@ -40,56 +40,19 @@ mod dht_proto { pub use addresses::Addresses; pub use behaviour::{ - Kademlia, - KademliaBucketInserts, - KademliaConfig, - KademliaCaching, - KademliaEvent, - Quorum + AddProviderContext, AddProviderError, AddProviderOk, AddProviderPhase, AddProviderResult, + BootstrapError, BootstrapOk, BootstrapResult, GetClosestPeersError, GetClosestPeersOk, + GetClosestPeersResult, GetProvidersError, GetProvidersOk, GetProvidersResult, GetRecordError, + GetRecordOk, GetRecordResult, InboundRequest, PeerRecord, PutRecordContext, PutRecordError, + PutRecordOk, PutRecordPhase, PutRecordResult, QueryInfo, QueryMut, QueryRef, QueryResult, + QueryStats, }; pub use behaviour::{ - InboundRequest, - - QueryRef, - QueryMut, - - QueryResult, - QueryInfo, - QueryStats, - - PeerRecord, - - BootstrapResult, - BootstrapOk, - BootstrapError, - - GetRecordResult, - GetRecordOk, - GetRecordError, - - PutRecordPhase, - PutRecordContext, - PutRecordResult, - PutRecordOk, - PutRecordError, - - GetClosestPeersResult, - GetClosestPeersOk, - GetClosestPeersError, - - AddProviderPhase, - AddProviderContext, - AddProviderResult, - AddProviderOk, - AddProviderError, - - GetProvidersResult, - GetProvidersOk, - GetProvidersError, + Kademlia, KademliaBucketInserts, KademliaCaching, KademliaConfig, KademliaEvent, Quorum, }; -pub use query::QueryId; pub use protocol::KadConnectionType; -pub use record::{store, Record, ProviderRecord}; +pub use query::QueryId; +pub use record::{store, ProviderRecord, Record}; use std::num::NonZeroUsize; diff --git a/protocols/kad/src/protocol.rs b/protocols/kad/src/protocol.rs index 393074e2932..0f883649b05 100644 --- a/protocols/kad/src/protocol.rs +++ b/protocols/kad/src/protocol.rs @@ -26,14 +26,14 @@ //! to poll the underlying transport for incoming messages, and the `Sink` component //! is used to send messages to remote peers. 
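// Illustrative sketch (not from the patch): the key.rs hunks above only
// reformat the XOR distance metric used by the routing table. Conceptually a
// key is the SHA-256 digest of its preimage, and the distance between two
// keys is their bitwise XOR read as a 256-bit big-endian integer; the
// quickcheck properties above (symmetry, triangle inequality) hold for
// exactly this construction. A minimal version over plain byte arrays
// instead of the crate's `KeyBytes`/`U256` types:

use sha2::{Digest, Sha256};

fn key_bytes(preimage: &[u8]) -> [u8; 32] {
    let mut out = [0u8; 32];
    out.copy_from_slice(Sha256::digest(preimage).as_slice());
    out
}

// XOR distance; comparing the resulting big-endian arrays lexicographically
// is equivalent to comparing the corresponding 256-bit integers.
fn distance(a: &[u8; 32], b: &[u8; 32]) -> [u8; 32] {
    let mut d = [0u8; 32];
    for i in 0..32 {
        d[i] = a[i] ^ b[i];
    }
    d
}

fn main() {
    let (a, b) = (key_bytes(b"peer-a"), key_bytes(b"peer-b"));
    // Symmetry, as asserted by the `symmetry` property above.
    assert_eq!(distance(&a, &b), distance(&b, &a));
}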
-use bytes::BytesMut; -use codec::UviBytes; use crate::dht_proto as proto; use crate::record::{self, Record}; -use futures::prelude::*; use asynchronous_codec::Framed; -use libp2p_core::{Multiaddr, PeerId}; +use bytes::BytesMut; +use codec::UviBytes; +use futures::prelude::*; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; +use libp2p_core::{Multiaddr, PeerId}; use prost::Message; use std::{borrow::Cow, convert::TryFrom, time::Duration}; use std::{io, iter}; @@ -101,8 +101,7 @@ impl TryFrom for KadPeer { fn try_from(peer: proto::message::Peer) -> Result { // TODO: this is in fact a CID; not sure if this should be handled in `from_bytes` or // as a special case here - let node_id = PeerId::from_bytes(&peer.id) - .map_err(|_| invalid_data("invalid peer id"))?; + let node_id = PeerId::from_bytes(&peer.id).map_err(|_| invalid_data("invalid peer id"))?; let mut addrs = Vec::with_capacity(peer.addrs.len()); for addr in peer.addrs.into_iter() { @@ -118,7 +117,7 @@ impl TryFrom for KadPeer { Ok(KadPeer { node_id, multiaddrs: addrs, - connection_ty + connection_ty, }) } } @@ -131,7 +130,7 @@ impl From for proto::message::Peer { connection: { let ct: proto::message::ConnectionType = peer.connection_ty.into(); ct as i32 - } + }, } } } @@ -202,13 +201,15 @@ where .with::<_, _, fn(_) -> _, _>(|response| { let proto_struct = resp_msg_to_proto(response); let mut buf = Vec::with_capacity(proto_struct.encoded_len()); - proto_struct.encode(&mut buf).expect("Vec provides capacity as needed"); + proto_struct + .encode(&mut buf) + .expect("Vec provides capacity as needed"); future::ready(Ok(io::Cursor::new(buf))) }) .and_then::<_, fn(_) -> _>(|bytes| { let request = match proto::Message::decode(bytes) { Ok(r) => r, - Err(err) => return future::ready(Err(err.into())) + Err(err) => return future::ready(Err(err.into())), }; future::ready(proto_to_req_msg(request)) }), @@ -234,13 +235,15 @@ where .with::<_, _, fn(_) -> _, _>(|request| { let proto_struct = req_msg_to_proto(request); let mut buf = Vec::with_capacity(proto_struct.encoded_len()); - proto_struct.encode(&mut buf).expect("Vec provides capacity as needed"); + proto_struct + .encode(&mut buf) + .expect("Vec provides capacity as needed"); future::ready(Ok(io::Cursor::new(buf))) }) .and_then::<_, fn(_) -> _>(|bytes| { let response = match proto::Message::decode(bytes) { Ok(r) => r, - Err(err) => return future::ready(Err(err.into())) + Err(err) => return future::ready(Err(err.into())), }; future::ready(proto_to_resp_msg(response)) }), @@ -301,9 +304,7 @@ pub enum KadRequestMsg { }, /// Request to put a value into the dht records. - PutValue { - record: Record, - } + PutValue { record: Record }, } /// Response that we can send to a peer or that we received from a peer. @@ -348,38 +349,38 @@ fn req_msg_to_proto(kad_msg: KadRequestMsg) -> proto::Message { match kad_msg { KadRequestMsg::Ping => proto::Message { r#type: proto::message::MessageType::Ping as i32, - .. proto::Message::default() + ..proto::Message::default() }, KadRequestMsg::FindNode { key } => proto::Message { r#type: proto::message::MessageType::FindNode as i32, key, cluster_level_raw: 10, - .. proto::Message::default() + ..proto::Message::default() }, KadRequestMsg::GetProviders { key } => proto::Message { r#type: proto::message::MessageType::GetProviders as i32, key: key.to_vec(), cluster_level_raw: 10, - .. 
proto::Message::default() + ..proto::Message::default() }, KadRequestMsg::AddProvider { key, provider } => proto::Message { r#type: proto::message::MessageType::AddProvider as i32, cluster_level_raw: 10, key: key.to_vec(), provider_peers: vec![provider.into()], - .. proto::Message::default() + ..proto::Message::default() }, KadRequestMsg::GetValue { key } => proto::Message { r#type: proto::message::MessageType::GetValue as i32, cluster_level_raw: 10, key: key.to_vec(), - .. proto::Message::default() + ..proto::Message::default() }, KadRequestMsg::PutValue { record } => proto::Message { r#type: proto::message::MessageType::PutValue as i32, record: Some(record_to_proto(record)), - .. proto::Message::default() - } + ..proto::Message::default() + }, } } @@ -388,27 +389,33 @@ fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> proto::Message { match kad_msg { KadResponseMsg::Pong => proto::Message { r#type: proto::message::MessageType::Ping as i32, - .. proto::Message::default() + ..proto::Message::default() }, KadResponseMsg::FindNode { closer_peers } => proto::Message { r#type: proto::message::MessageType::FindNode as i32, cluster_level_raw: 9, closer_peers: closer_peers.into_iter().map(KadPeer::into).collect(), - .. proto::Message::default() + ..proto::Message::default() }, - KadResponseMsg::GetProviders { closer_peers, provider_peers } => proto::Message { + KadResponseMsg::GetProviders { + closer_peers, + provider_peers, + } => proto::Message { r#type: proto::message::MessageType::GetProviders as i32, cluster_level_raw: 9, closer_peers: closer_peers.into_iter().map(KadPeer::into).collect(), provider_peers: provider_peers.into_iter().map(KadPeer::into).collect(), - .. proto::Message::default() + ..proto::Message::default() }, - KadResponseMsg::GetValue { record, closer_peers } => proto::Message { + KadResponseMsg::GetValue { + record, + closer_peers, + } => proto::Message { r#type: proto::message::MessageType::GetValue as i32, cluster_level_raw: 9, closer_peers: closer_peers.into_iter().map(KadPeer::into).collect(), record: record.map(record_to_proto), - .. proto::Message::default() + ..proto::Message::default() }, KadResponseMsg::PutValue { key, value } => proto::Message { r#type: proto::message::MessageType::PutValue as i32, @@ -416,10 +423,10 @@ fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> proto::Message { record: Some(proto::Record { key: key.to_vec(), value, - .. proto::Record::default() + ..proto::Record::default() }), - .. 
proto::Message::default() - } + ..proto::Message::default() + }, } } @@ -436,20 +443,19 @@ fn proto_to_req_msg(message: proto::Message) -> Result let record = record_from_proto(message.record.unwrap_or_default())?; Ok(KadRequestMsg::PutValue { record }) } - proto::message::MessageType::GetValue => { - Ok(KadRequestMsg::GetValue { key: record::Key::from(message.key) }) - } - proto::message::MessageType::FindNode => { - Ok(KadRequestMsg::FindNode { key: message.key }) - } - proto::message::MessageType::GetProviders => { - Ok(KadRequestMsg::GetProviders { key: record::Key::from(message.key)}) - } + proto::message::MessageType::GetValue => Ok(KadRequestMsg::GetValue { + key: record::Key::from(message.key), + }), + proto::message::MessageType::FindNode => Ok(KadRequestMsg::FindNode { key: message.key }), + proto::message::MessageType::GetProviders => Ok(KadRequestMsg::GetProviders { + key: record::Key::from(message.key), + }), proto::message::MessageType::AddProvider => { // TODO: for now we don't parse the peer properly, so it is possible that we get // parsing errors for peers even when they are valid; we ignore these // errors for now, but ultimately we should just error altogether - let provider = message.provider_peers + let provider = message + .provider_peers .into_iter() .find_map(|peer| KadPeer::try_from(peer).ok()); @@ -473,22 +479,28 @@ fn proto_to_resp_msg(message: proto::Message) -> Result Ok(KadResponseMsg::Pong), proto::message::MessageType::GetValue => { - let record = - if let Some(r) = message.record { - Some(record_from_proto(r)?) - } else { - None - }; + let record = if let Some(r) = message.record { + Some(record_from_proto(r)?) + } else { + None + }; - let closer_peers = message.closer_peers.into_iter() + let closer_peers = message + .closer_peers + .into_iter() .filter_map(|peer| KadPeer::try_from(peer).ok()) .collect(); - Ok(KadResponseMsg::GetValue { record, closer_peers }) + Ok(KadResponseMsg::GetValue { + record, + closer_peers, + }) } proto::message::MessageType::FindNode => { - let closer_peers = message.closer_peers.into_iter() + let closer_peers = message + .closer_peers + .into_iter() .filter_map(|peer| KadPeer::try_from(peer).ok()) .collect(); @@ -496,11 +508,15 @@ fn proto_to_resp_msg(message: proto::Message) -> Result { - let closer_peers = message.closer_peers.into_iter() + let closer_peers = message + .closer_peers + .into_iter() .filter_map(|peer| KadPeer::try_from(peer).ok()) .collect(); - let provider_peers = message.provider_peers.into_iter() + let provider_peers = message + .provider_peers + .into_iter() .filter_map(|peer| KadPeer::try_from(peer).ok()) .collect(); @@ -512,18 +528,19 @@ fn proto_to_resp_msg(message: proto::Message) -> Result { let key = record::Key::from(message.key); - let rec = message.record.ok_or_else(|| { - invalid_data("received PutValue message with no record") - })?; + let rec = message + .record + .ok_or_else(|| invalid_data("received PutValue message with no record"))?; Ok(KadResponseMsg::PutValue { key, - value: rec.value + value: rec.value, }) } - proto::message::MessageType::AddProvider => + proto::message::MessageType::AddProvider => { Err(invalid_data("received an unexpected AddProvider message")) + } } } @@ -531,23 +548,26 @@ fn record_from_proto(record: proto::Record) -> Result { let key = record::Key::from(record.key); let value = record.value; - let publisher = - if !record.publisher.is_empty() { - PeerId::from_bytes(&record.publisher) - .map(Some) - .map_err(|_| invalid_data("Invalid publisher peer ID."))? 
- } else { - None - }; - - let expires = - if record.ttl > 0 { - Some(Instant::now() + Duration::from_secs(record.ttl as u64)) - } else { - None - }; - - Ok(Record { key, value, publisher, expires }) + let publisher = if !record.publisher.is_empty() { + PeerId::from_bytes(&record.publisher) + .map(Some) + .map_err(|_| invalid_data("Invalid publisher peer ID."))? + } else { + None + }; + + let expires = if record.ttl > 0 { + Some(Instant::now() + Duration::from_secs(record.ttl as u64)) + } else { + None + }; + + Ok(Record { + key, + value, + publisher, + expires, + }) } fn record_to_proto(record: Record) -> proto::Record { @@ -555,7 +575,8 @@ fn record_to_proto(record: Record) -> proto::Record { key: record.key.to_vec(), value: record.value, publisher: record.publisher.map(|id| id.to_bytes()).unwrap_or_default(), - ttl: record.expires + ttl: record + .expires .map(|t| { let now = Instant::now(); if t > now { @@ -565,14 +586,14 @@ fn record_to_proto(record: Record) -> proto::Record { } }) .unwrap_or(0), - time_received: String::new() + time_received: String::new(), } } /// Creates an `io::Error` with `io::ErrorKind::InvalidData`. fn invalid_data(e: E) -> io::Error where - E: Into> + E: Into>, { io::Error::new(io::ErrorKind::InvalidData, e) } diff --git a/protocols/kad/src/query.rs b/protocols/kad/src/query.rs index 0b19425b7fe..6fcf90df79f 100644 --- a/protocols/kad/src/query.rs +++ b/protocols/kad/src/query.rs @@ -20,16 +20,18 @@ mod peers; -use peers::PeersIterState; -use peers::closest::{ClosestPeersIterConfig, ClosestPeersIter, disjoint::ClosestDisjointPeersIter}; +use peers::closest::{ + disjoint::ClosestDisjointPeersIter, ClosestPeersIter, ClosestPeersIterConfig, +}; use peers::fixed::FixedPeersIter; +use peers::PeersIterState; -use crate::{ALPHA_VALUE, K_VALUE}; use crate::kbucket::{Key, KeyBytes}; +use crate::{ALPHA_VALUE, K_VALUE}; use either::Either; use fnv::FnvHashMap; use libp2p_core::PeerId; -use std::{time::Duration, num::NonZeroUsize}; +use std::{num::NonZeroUsize, time::Duration}; use wasm_timer::Instant; /// A `QueryPool` provides an aggregate state machine for driving `Query`s to completion. @@ -53,7 +55,7 @@ pub enum QueryPoolState<'a, TInner> { /// A query has finished. Finished(Query), /// A query has timed out. - Timeout(Query) + Timeout(Query), } impl QueryPool { @@ -62,7 +64,7 @@ impl QueryPool { QueryPool { next_id: 0, config, - queries: Default::default() + queries: Default::default(), } } @@ -89,7 +91,7 @@ impl QueryPool { /// Adds a query to the pool that contacts a fixed set of peers. pub fn add_fixed(&mut self, peers: I, inner: TInner) -> QueryId where - I: IntoIterator + I: IntoIterator, { let id = self.next_query_id(); self.continue_fixed(id, peers, inner); @@ -101,7 +103,7 @@ impl QueryPool { /// earlier. 
pub fn continue_fixed(&mut self, id: QueryId, peers: I, inner: TInner) where - I: IntoIterator + I: IntoIterator, { assert!(!self.queries.contains_key(&id)); let parallelism = self.config.replication_factor; @@ -114,7 +116,7 @@ impl QueryPool { pub fn add_iter_closest(&mut self, target: T, peers: I, inner: TInner) -> QueryId where T: Into + Clone, - I: IntoIterator> + I: IntoIterator>, { let id = self.next_query_id(); self.continue_iter_closest(id, target, peers, inner); @@ -125,18 +127,18 @@ impl QueryPool { pub fn continue_iter_closest(&mut self, id: QueryId, target: T, peers: I, inner: TInner) where T: Into + Clone, - I: IntoIterator> + I: IntoIterator>, { let cfg = ClosestPeersIterConfig { num_results: self.config.replication_factor, parallelism: self.config.parallelism, - .. ClosestPeersIterConfig::default() + ..ClosestPeersIterConfig::default() }; let peer_iter = if self.config.disjoint_query_paths { - QueryPeerIter::ClosestDisjoint( - ClosestDisjointPeersIter::with_config(cfg, target, peers), - ) + QueryPeerIter::ClosestDisjoint(ClosestDisjointPeersIter::with_config( + cfg, target, peers, + )) } else { QueryPeerIter::Closest(ClosestPeersIter::with_config(cfg, target, peers)) }; @@ -172,18 +174,18 @@ impl QueryPool { match query.next(now) { PeersIterState::Finished => { finished = Some(query_id); - break + break; } PeersIterState::Waiting(Some(peer_id)) => { let peer = peer_id.into_owned(); waiting = Some((query_id, peer)); - break + break; } PeersIterState::Waiting(None) | PeersIterState::WaitingAtCapacity => { let elapsed = now - query.stats.start.unwrap_or(now); if elapsed >= self.config.timeout { timeout = Some(query_id); - break + break; } } } @@ -191,19 +193,19 @@ impl QueryPool { if let Some((query_id, peer_id)) = waiting { let query = self.queries.get_mut(&query_id).expect("s.a."); - return QueryPoolState::Waiting(Some((query, peer_id))) + return QueryPoolState::Waiting(Some((query, peer_id))); } if let Some(query_id) = finished { let mut query = self.queries.remove(&query_id).expect("s.a."); query.stats.end = Some(now); - return QueryPoolState::Finished(query) + return QueryPoolState::Finished(query); } if let Some(query_id) = timeout { let mut query = self.queries.remove(&query_id).expect("s.a."); query.stats.end = Some(now); - return QueryPoolState::Timeout(query) + return QueryPoolState::Timeout(query); } if self.queries.is_empty() { @@ -269,13 +271,18 @@ pub struct Query { enum QueryPeerIter { Closest(ClosestPeersIter), ClosestDisjoint(ClosestDisjointPeersIter), - Fixed(FixedPeersIter) + Fixed(FixedPeersIter), } impl Query { /// Creates a new query without starting it. fn new(id: QueryId, peer_iter: QueryPeerIter, inner: TInner) -> Self { - Query { id, inner, peer_iter, stats: QueryStats::empty() } + Query { + id, + inner, + peer_iter, + stats: QueryStats::empty(), + } } /// Gets the unique ID of the query. @@ -293,7 +300,7 @@ impl Query { let updated = match &mut self.peer_iter { QueryPeerIter::Closest(iter) => iter.on_failure(peer), QueryPeerIter::ClosestDisjoint(iter) => iter.on_failure(peer), - QueryPeerIter::Fixed(iter) => iter.on_failure(peer) + QueryPeerIter::Fixed(iter) => iter.on_failure(peer), }; if updated { self.stats.failure += 1; @@ -305,12 +312,12 @@ impl Query { /// the query, if applicable. 
pub fn on_success(&mut self, peer: &PeerId, new_peers: I) where - I: IntoIterator + I: IntoIterator, { let updated = match &mut self.peer_iter { QueryPeerIter::Closest(iter) => iter.on_success(peer, new_peers), QueryPeerIter::ClosestDisjoint(iter) => iter.on_success(peer, new_peers), - QueryPeerIter::Fixed(iter) => iter.on_success(peer) + QueryPeerIter::Fixed(iter) => iter.on_success(peer), }; if updated { self.stats.success += 1; @@ -322,7 +329,7 @@ impl Query { match &self.peer_iter { QueryPeerIter::Closest(iter) => iter.is_waiting(peer), QueryPeerIter::ClosestDisjoint(iter) => iter.is_waiting(peer), - QueryPeerIter::Fixed(iter) => iter.is_waiting(peer) + QueryPeerIter::Fixed(iter) => iter.is_waiting(peer), } } @@ -331,7 +338,7 @@ impl Query { let state = match &mut self.peer_iter { QueryPeerIter::Closest(iter) => iter.next(now), QueryPeerIter::ClosestDisjoint(iter) => iter.next(now), - QueryPeerIter::Fixed(iter) => iter.next() + QueryPeerIter::Fixed(iter) => iter.next(), }; if let PeersIterState::Waiting(Some(_)) = state { @@ -360,12 +367,18 @@ impl Query { /// [`QueryPoolState::Finished`]. pub fn try_finish<'a, I>(&mut self, peers: I) -> bool where - I: IntoIterator + I: IntoIterator, { match &mut self.peer_iter { - QueryPeerIter::Closest(iter) => { iter.finish(); true }, + QueryPeerIter::Closest(iter) => { + iter.finish(); + true + } QueryPeerIter::ClosestDisjoint(iter) => iter.finish_paths(peers), - QueryPeerIter::Fixed(iter) => { iter.finish(); true } + QueryPeerIter::Fixed(iter) => { + iter.finish(); + true + } } } @@ -377,7 +390,7 @@ impl Query { match &mut self.peer_iter { QueryPeerIter::Closest(iter) => iter.finish(), QueryPeerIter::ClosestDisjoint(iter) => iter.finish(), - QueryPeerIter::Fixed(iter) => iter.finish() + QueryPeerIter::Fixed(iter) => iter.finish(), } } @@ -389,7 +402,7 @@ impl Query { match &self.peer_iter { QueryPeerIter::Closest(iter) => iter.is_finished(), QueryPeerIter::ClosestDisjoint(iter) => iter.is_finished(), - QueryPeerIter::Fixed(iter) => iter.is_finished() + QueryPeerIter::Fixed(iter) => iter.is_finished(), } } @@ -398,9 +411,13 @@ impl Query { let peers = match self.peer_iter { QueryPeerIter::Closest(iter) => Either::Left(Either::Left(iter.into_result())), QueryPeerIter::ClosestDisjoint(iter) => Either::Left(Either::Right(iter.into_result())), - QueryPeerIter::Fixed(iter) => Either::Right(iter.into_result()) + QueryPeerIter::Fixed(iter) => Either::Right(iter.into_result()), }; - QueryResult { peers, inner: self.inner, stats: self.stats } + QueryResult { + peers, + inner: self.inner, + stats: self.stats, + } } } @@ -411,7 +428,7 @@ pub struct QueryResult { /// The successfully contacted peers. pub peers: TPeers, /// The collected query statistics. - pub stats: QueryStats + pub stats: QueryStats, } /// Execution statistics of a query. @@ -421,7 +438,7 @@ pub struct QueryStats { success: u32, failure: u32, start: Option, - end: Option + end: Option, } impl QueryStats { @@ -490,9 +507,9 @@ impl QueryStats { failure: self.failure + other.failure, start: match (self.start, other.start) { (Some(a), Some(b)) => Some(std::cmp::min(a, b)), - (a, b) => a.or(b) + (a, b) => a.or(b), }, - end: std::cmp::max(self.end, other.end) + end: std::cmp::max(self.end, other.end), } } } diff --git a/protocols/kad/src/query/peers.rs b/protocols/kad/src/query/peers.rs index 964068aa25a..7a177a494cf 100644 --- a/protocols/kad/src/query/peers.rs +++ b/protocols/kad/src/query/peers.rs @@ -63,5 +63,5 @@ pub enum PeersIterState<'a> { WaitingAtCapacity, /// The iterator finished. 
- Finished + Finished, } diff --git a/protocols/kad/src/query/peers/closest.rs b/protocols/kad/src/query/peers/closest.rs index 702335c50f8..684c109b934 100644 --- a/protocols/kad/src/query/peers/closest.rs +++ b/protocols/kad/src/query/peers/closest.rs @@ -20,11 +20,11 @@ use super::*; -use crate::{K_VALUE, ALPHA_VALUE}; -use crate::kbucket::{Key, KeyBytes, Distance}; +use crate::kbucket::{Distance, Key, KeyBytes}; +use crate::{ALPHA_VALUE, K_VALUE}; use libp2p_core::PeerId; -use std::{time::Duration, iter::FromIterator, num::NonZeroUsize}; use std::collections::btree_map::{BTreeMap, Entry}; +use std::{iter::FromIterator, num::NonZeroUsize, time::Duration}; use wasm_timer::Instant; pub mod disjoint; @@ -88,16 +88,24 @@ impl ClosestPeersIter { /// Creates a new iterator with a default configuration. pub fn new(target: KeyBytes, known_closest_peers: I) -> Self where - I: IntoIterator> + I: IntoIterator>, { - Self::with_config(ClosestPeersIterConfig::default(), target, known_closest_peers) + Self::with_config( + ClosestPeersIterConfig::default(), + target, + known_closest_peers, + ) } /// Creates a new iterator with the given configuration. - pub fn with_config(config: ClosestPeersIterConfig, target: T, known_closest_peers: I) -> Self + pub fn with_config( + config: ClosestPeersIterConfig, + target: T, + known_closest_peers: I, + ) -> Self where I: IntoIterator>, - T: Into + T: Into, { let target = target.into(); @@ -110,17 +118,18 @@ impl ClosestPeersIter { let state = PeerState::NotContacted; (distance, Peer { key, state }) }) - .take(K_VALUE.into())); + .take(K_VALUE.into()), + ); // The iterator initially makes progress by iterating towards the target. - let state = State::Iterating { no_progress : 0 }; + let state = State::Iterating { no_progress: 0 }; ClosestPeersIter { config, target, state, closest_peers, - num_waiting: 0 + num_waiting: 0, } } @@ -142,10 +151,10 @@ impl ClosestPeersIter { /// calling this function has no effect and `false` is returned. pub fn on_success(&mut self, peer: &PeerId, closer_peers: I) -> bool where - I: IntoIterator + I: IntoIterator, { if let State::Finished = self.state { - return false + return false; } let key = Key::from(*peer); @@ -163,10 +172,8 @@ impl ClosestPeersIter { PeerState::Unresponsive => { e.get_mut().state = PeerState::Succeeded; } - PeerState::NotContacted - | PeerState::Failed - | PeerState::Succeeded => return false - } + PeerState::NotContacted | PeerState::Failed | PeerState::Succeeded => return false, + }, } let num_closest = self.closest_peers.len(); @@ -176,7 +183,10 @@ impl ClosestPeersIter { for peer in closer_peers { let key = peer.into(); let distance = self.target.distance(&key); - let peer = Peer { key, state: PeerState::NotContacted }; + let peer = Peer { + key, + state: PeerState::NotContacted, + }; self.closest_peers.entry(distance).or_insert(peer); // The iterator makes progress if the new peer is either closer to the target // than any peer seen so far (i.e. is the first entry), or the iterator did @@ -195,13 +205,14 @@ impl ClosestPeersIter { State::Iterating { no_progress } } } - State::Stalled => + State::Stalled => { if progress { State::Iterating { no_progress: 0 } } else { State::Stalled } - State::Finished => State::Finished + } + State::Finished => State::Finished, }; true @@ -219,7 +230,7 @@ impl ClosestPeersIter { /// calling this function has no effect and `false` is returned. 
pub fn on_failure(&mut self, peer: &PeerId) -> bool { if let State::Finished = self.state { - return false + return false; } let key = Key::from(*peer); @@ -233,13 +244,9 @@ impl ClosestPeersIter { self.num_waiting -= 1; e.get_mut().state = PeerState::Failed } - PeerState::Unresponsive => { - e.get_mut().state = PeerState::Failed - } - PeerState::NotContacted - | PeerState::Failed - | PeerState::Succeeded => return false - } + PeerState::Unresponsive => e.get_mut().state = PeerState::Failed, + PeerState::NotContacted | PeerState::Failed | PeerState::Succeeded => return false, + }, } true @@ -248,10 +255,11 @@ impl ClosestPeersIter { /// Returns the list of peers for which the iterator is currently waiting /// for results. pub fn waiting(&self) -> impl Iterator { - self.closest_peers.values().filter_map(|peer| - match peer.state { + self.closest_peers + .values() + .filter_map(|peer| match peer.state { PeerState::Waiting(..) => Some(peer.key.preimage()), - _ => None + _ => None, }) } @@ -269,7 +277,7 @@ impl ClosestPeersIter { /// Advances the state of the iterator, potentially getting a new peer to contact. pub fn next(&mut self, now: Instant) -> PeersIterState<'_> { if let State::Finished = self.state { - return PeersIterState::Finished + return PeersIterState::Finished; } // Count the number of peers that returned a result. If there is a @@ -292,13 +300,11 @@ impl ClosestPeersIter { debug_assert!(self.num_waiting > 0); self.num_waiting -= 1; peer.state = PeerState::Unresponsive - } - else if at_capacity { + } else if at_capacity { // The iterator is still waiting for a result from a peer and is // at capacity w.r.t. the maximum number of peers being waited on. - return PeersIterState::WaitingAtCapacity - } - else { + return PeersIterState::WaitingAtCapacity; + } else { // The iterator is still waiting for a result from a peer and the // `result_counter` did not yet reach `num_results`. Therefore // the iterator is not yet done, regardless of already successful @@ -307,26 +313,28 @@ impl ClosestPeersIter { } } - PeerState::Succeeded => + PeerState::Succeeded => { if let Some(ref mut cnt) = result_counter { *cnt += 1; // If `num_results` successful results have been delivered for the // closest peers, the iterator is done. if *cnt >= self.config.num_results.get() { self.state = State::Finished; - return PeersIterState::Finished + return PeersIterState::Finished; } } + } - PeerState::NotContacted => + PeerState::NotContacted => { if !at_capacity { let timeout = now + self.config.peer_timeout; peer.state = PeerState::Waiting(timeout); self.num_waiting += 1; - return PeersIterState::Waiting(Some(Cow::Borrowed(peer.key.preimage()))) + return PeersIterState::Waiting(Some(Cow::Borrowed(peer.key.preimage()))); } else { - return PeersIterState::WaitingAtCapacity + return PeersIterState::WaitingAtCapacity; } + } PeerState::Unresponsive | PeerState::Failed => { // Skip over unresponsive or failed peers. @@ -379,11 +387,12 @@ impl ClosestPeersIter { /// k closest nodes it has not already queried". fn at_capacity(&self) -> bool { match self.state { - State::Stalled => self.num_waiting >= usize::max( - self.config.num_results.get(), self.config.parallelism.get() - ), + State::Stalled => { + self.num_waiting + >= usize::max(self.config.num_results.get(), self.config.parallelism.get()) + } State::Iterating { .. 
} => self.num_waiting >= self.config.parallelism.get(), - State::Finished => true + State::Finished => true, } } } @@ -425,14 +434,14 @@ enum State { /// from the closest peers (not counting those that failed or are unresponsive) /// or because the iterator ran out of peers that have not yet delivered /// results (or failed). - Finished + Finished, } /// Representation of a peer in the context of a iterator. #[derive(Debug, Clone)] struct Peer { key: Key, - state: PeerState + state: PeerState, } /// The state of a single `Peer`. @@ -466,19 +475,29 @@ enum PeerState { #[cfg(test)] mod tests { use super::*; - use libp2p_core::{PeerId, multihash::{Code, Multihash}}; + use libp2p_core::{ + multihash::{Code, Multihash}, + PeerId, + }; use quickcheck::*; - use rand::{Rng, rngs::StdRng, SeedableRng}; + use rand::{rngs::StdRng, Rng, SeedableRng}; use std::{iter, time::Duration}; fn random_peers(n: usize, g: &mut R) -> Vec { - (0 .. n).map(|_| PeerId::from_multihash( - Multihash::wrap(Code::Sha2_256.into(), &g.gen::<[u8; 32]>()).unwrap() - ).unwrap()).collect() + (0..n) + .map(|_| { + PeerId::from_multihash( + Multihash::wrap(Code::Sha2_256.into(), &g.gen::<[u8; 32]>()).unwrap(), + ) + .unwrap() + }) + .collect() } fn sorted>(target: &T, peers: &Vec>) -> bool { - peers.windows(2).all(|w| w[0].distance(&target) < w[1].distance(&target)) + peers + .windows(2) + .all(|w| w[0].distance(&target) < w[1].distance(&target)) } impl Arbitrary for ClosestPeersIter { @@ -510,26 +529,32 @@ mod tests { fn prop(iter: ClosestPeersIter) { let target = iter.target.clone(); - let (keys, states): (Vec<_>, Vec<_>) = iter.closest_peers + let (keys, states): (Vec<_>, Vec<_>) = iter + .closest_peers .values() .map(|e| (e.key.clone(), &e.state)) .unzip(); - let none_contacted = states - .iter() - .all(|s| match s { - PeerState::NotContacted => true, - _ => false - }); - - assert!(none_contacted, - "Unexpected peer state in new iterator."); - assert!(sorted(&target, &keys), - "Closest peers in new iterator not sorted by distance to target."); - assert_eq!(iter.num_waiting(), 0, - "Unexpected peers in progress in new iterator."); - assert_eq!(iter.into_result().count(), 0, - "Unexpected closest peers in new iterator"); + let none_contacted = states.iter().all(|s| match s { + PeerState::NotContacted => true, + _ => false, + }); + + assert!(none_contacted, "Unexpected peer state in new iterator."); + assert!( + sorted(&target, &keys), + "Closest peers in new iterator not sorted by distance to target." + ); + assert_eq!( + iter.num_waiting(), + 0, + "Unexpected peers in progress in new iterator." + ); + assert_eq!( + iter.into_result().count(), + 0, + "Unexpected closest peers in new iterator" + ); } QuickCheck::new().tests(10).quickcheck(prop as fn(_) -> _) @@ -541,7 +566,8 @@ mod tests { let now = Instant::now(); let mut rng = StdRng::from_seed(seed.0); - let mut expected = iter.closest_peers + let mut expected = iter + .closest_peers .values() .map(|e| e.key.clone()) .collect::>(); @@ -559,8 +585,7 @@ mod tests { // Split off the next up to `parallelism` expected peers. 
else if expected.len() < max_parallelism { remaining = Vec::new(); - } - else { + } else { remaining = expected.split_off(max_parallelism); } @@ -570,7 +595,9 @@ mod tests { PeersIterState::Finished => break 'finished, PeersIterState::Waiting(Some(p)) => assert_eq!(&*p, k.preimage()), PeersIterState::Waiting(None) => panic!("Expected another peer."), - PeersIterState::WaitingAtCapacity => panic!("Unexpectedly reached capacity.") + PeersIterState::WaitingAtCapacity => { + panic!("Unexpectedly reached capacity.") + } } } let num_waiting = iter.num_waiting(); @@ -611,7 +638,7 @@ mod tests { // of results. let all_contacted = iter.closest_peers.values().all(|e| match e.state { PeerState::NotContacted | PeerState::Waiting { .. } => false, - _ => true + _ => true, }); let target = iter.target.clone(); @@ -634,7 +661,9 @@ mod tests { } } - QuickCheck::new().tests(10).quickcheck(prop as fn(_, _) -> _) + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _) -> _) } #[test] @@ -648,7 +677,7 @@ mod tests { // A first peer reports a "closer" peer. let peer1 = match iter.next(now) { PeersIterState::Waiting(Some(p)) => p.into_owned(), - _ => panic!("No peer.") + _ => panic!("No peer."), }; iter.on_success(&peer1, closer.clone()); // Duplicate result from te same peer. @@ -665,25 +694,38 @@ mod tests { }; // The "closer" peer must only be in the iterator once. - let n = iter.closest_peers.values().filter(|e| e.key.preimage() == &closer[0]).count(); + let n = iter + .closest_peers + .values() + .filter(|e| e.key.preimage() == &closer[0]) + .count(); assert_eq!(n, 1); true } - QuickCheck::new().tests(10).quickcheck(prop as fn(_, _) -> _) + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _) -> _) } #[test] fn timeout() { fn prop(mut iter: ClosestPeersIter) -> bool { let mut now = Instant::now(); - let peer = iter.closest_peers.values().next().unwrap().key.clone().into_preimage(); + let peer = iter + .closest_peers + .values() + .next() + .unwrap() + .key + .clone() + .into_preimage(); // Poll the iterator for the first peer to be in progress. match iter.next(now) { PeersIterState::Waiting(Some(id)) => assert_eq!(&*id, &peer), - _ => panic!() + _ => panic!(), } // Artificially advance the clock. @@ -692,10 +734,13 @@ mod tests { // Advancing the iterator again should mark the first peer as unresponsive. let _ = iter.next(now); match &iter.closest_peers.values().next().unwrap() { - Peer { key, state: PeerState::Unresponsive } => { + Peer { + key, + state: PeerState::Unresponsive, + } => { assert_eq!(key.preimage(), &peer); - }, - Peer { state, .. } => panic!("Unexpected peer state: {:?}", state) + } + Peer { state, .. 
} => panic!("Unexpected peer state: {:?}", state), } let finished = iter.is_finished(); @@ -727,7 +772,7 @@ mod tests { PeersIterState::Waiting(Some(p)) => { let peer = p.clone().into_owned(); iter.on_failure(&peer); - }, + } _ => panic!("Expected iterator to yield another peer to query."), } } @@ -751,10 +796,8 @@ mod tests { ) } - iter.num_waiting = usize::max( - iter.config.parallelism.get(), - iter.config.num_results.get(), - ); + iter.num_waiting = + usize::max(iter.config.parallelism.get(), iter.config.num_results.get()); assert!( iter.at_capacity(), "Iterator should be at capacity if `max(parallelism, num_results)` requests are \ diff --git a/protocols/kad/src/query/peers/closest/disjoint.rs b/protocols/kad/src/query/peers/closest/disjoint.rs index b295355634b..01506ff6f7b 100644 --- a/protocols/kad/src/query/peers/closest/disjoint.rs +++ b/protocols/kad/src/query/peers/closest/disjoint.rs @@ -72,7 +72,10 @@ impl ClosestDisjointPeersIter { I: IntoIterator>, T: Into + Clone, { - let peers = known_closest_peers.into_iter().take(K_VALUE.get()).collect::>(); + let peers = known_closest_peers + .into_iter() + .take(K_VALUE.get()) + .collect::>(); let iters = (0..config.parallelism.get()) // NOTE: All [`ClosestPeersIter`] share the same set of peers at // initialization. The [`ClosestDisjointPeersIter.contacted_peers`] @@ -88,7 +91,9 @@ impl ClosestDisjointPeersIter { config, target: target.into(), iters, - iter_order: (0..iters_len).map(IteratorIndex as fn(usize) -> IteratorIndex).cycle(), + iter_order: (0..iters_len) + .map(IteratorIndex as fn(usize) -> IteratorIndex) + .cycle(), contacted_peers: HashMap::new(), } } @@ -106,7 +111,11 @@ impl ClosestDisjointPeersIter { pub fn on_failure(&mut self, peer: &PeerId) -> bool { let mut updated = false; - if let Some(PeerState{ initiated_by, response }) = self.contacted_peers.get_mut(peer) { + if let Some(PeerState { + initiated_by, + response, + }) = self.contacted_peers.get_mut(peer) + { updated = self.iters[*initiated_by].on_failure(peer); if updated { @@ -148,7 +157,11 @@ impl ClosestDisjointPeersIter { { let mut updated = false; - if let Some(PeerState{ initiated_by, response }) = self.contacted_peers.get_mut(peer) { + if let Some(PeerState { + initiated_by, + response, + }) = self.contacted_peers.get_mut(peer) + { // Pass the new `closer_peers` to the iterator that first yielded // the peer. updated = self.iters[*initiated_by].on_success(peer, closer_peers); @@ -185,7 +198,7 @@ impl ClosestDisjointPeersIter { let mut state = None; // Ensure querying each iterator at most once. - for _ in 0 .. self.iters.len() { + for _ in 0..self.iters.len() { let i = self.iter_order.next().expect("Cycle never ends."); let iter = &mut self.iters[i]; @@ -198,7 +211,7 @@ impl ClosestDisjointPeersIter { // [`ClosestPeersIter`] yielded a peer. Thus this state is // unreachable. unreachable!(); - }, + } Some(PeersIterState::Waiting(None)) => {} Some(PeersIterState::WaitingAtCapacity) => { // At least one ClosestPeersIter is no longer at capacity, thus the @@ -210,14 +223,13 @@ impl ClosestDisjointPeersIter { unreachable!(); } None => state = Some(PeersIterState::Waiting(None)), - }; break; } PeersIterState::Waiting(Some(peer)) => { match self.contacted_peers.get_mut(&*peer) { - Some(PeerState{ response, .. }) => { + Some(PeerState { response, .. }) => { // Another iterator already contacted this peer. 
let peer = peer.into_owned(); @@ -225,27 +237,27 @@ impl ClosestDisjointPeersIter { // The iterator will be notified later whether the given node // was successfully contacted or not. See // [`ClosestDisjointPeersIter::on_success`] for details. - ResponseState::Waiting => {}, + ResponseState::Waiting => {} ResponseState::Succeeded => { // Given that iterator was not the first to contact the peer // it will not be made aware of the closer peers discovered // to uphold the S/Kademlia disjoint paths guarantee. See // [`ClosestDisjointPeersIter::on_success`] for details. iter.on_success(&peer, std::iter::empty()); - }, + } ResponseState::Failed => { iter.on_failure(&peer); - }, + } } - }, + } None => { // The iterator is the first to contact this peer. - self.contacted_peers.insert( - peer.clone().into_owned(), - PeerState::new(i), - ); - return PeersIterState::Waiting(Some(Cow::Owned(peer.into_owned()))); - }, + self.contacted_peers + .insert(peer.clone().into_owned(), PeerState::new(i)); + return PeersIterState::Waiting(Some(Cow::Owned( + peer.into_owned(), + ))); + } } } PeersIterState::WaitingAtCapacity => { @@ -255,13 +267,13 @@ impl ClosestDisjointPeersIter { // [`ClosestPeersIter`] yielded a peer. Thus this state is // unreachable. unreachable!(); - }, + } Some(PeersIterState::Waiting(None)) => {} Some(PeersIterState::WaitingAtCapacity) => {} Some(PeersIterState::Finished) => { // `state` is never set to `Finished`. unreachable!(); - }, + } None => state = Some(PeersIterState::WaitingAtCapacity), }; @@ -280,10 +292,10 @@ impl ClosestDisjointPeersIter { /// See [`crate::query::Query::try_finish`] for details. pub fn finish_paths<'a, I>(&mut self, peers: I) -> bool where - I: IntoIterator + I: IntoIterator, { for peer in peers { - if let Some(PeerState{ initiated_by, .. }) = self.contacted_peers.get_mut(peer) { + if let Some(PeerState { initiated_by, .. }) = self.contacted_peers.get_mut(peer) { self.iters[*initiated_by].finish(); } } @@ -312,7 +324,9 @@ impl ClosestDisjointPeersIter { /// differentiate benign from faulty paths it as well returns faulty /// peers and thus overall returns more than `num_results` peers. pub fn into_result(self) -> impl Iterator { - let result_per_path= self.iters.into_iter() + let result_per_path = self + .iters + .into_iter() .map(|iter| iter.into_result().map(Key::from)); ResultIter::new(self.target, result_per_path).map(Key::into_preimage) @@ -370,7 +384,8 @@ enum ResponseState { // // Note: This operates under the assumption that `I` is ordered. #[derive(Clone, Debug)] -struct ResultIter where +struct ResultIter +where I: Iterator>, { target: KeyBytes, @@ -379,7 +394,7 @@ struct ResultIter where impl>> ResultIter { fn new(target: KeyBytes, iters: impl Iterator) -> Self { - ResultIter{ + ResultIter { target, iters: iters.map(Iterator::peekable).collect(), } @@ -392,36 +407,34 @@ impl>> Iterator for ResultIter { fn next(&mut self) -> Option { let target = &self.target; - self.iters.iter_mut() + self.iters + .iter_mut() // Find the iterator with the next closest peer. - .fold( - Option::<&mut Peekable<_>>::None, - |iter_a, iter_b| { - let iter_a = match iter_a { - Some(iter_a) => iter_a, - None => return Some(iter_b), - }; - - match (iter_a.peek(), iter_b.peek()) { - (Some(next_a), Some(next_b)) => { - if next_a == next_b { - // Remove from one for deduplication. 
- iter_b.next(); - return Some(iter_a) - } + .fold(Option::<&mut Peekable<_>>::None, |iter_a, iter_b| { + let iter_a = match iter_a { + Some(iter_a) => iter_a, + None => return Some(iter_b), + }; + + match (iter_a.peek(), iter_b.peek()) { + (Some(next_a), Some(next_b)) => { + if next_a == next_b { + // Remove from one for deduplication. + iter_b.next(); + return Some(iter_a); + } - if target.distance(next_a) < target.distance(next_b) { - Some(iter_a) - } else { - Some(iter_b) - } - }, - (Some(_), None) => Some(iter_a), - (None, Some(_)) => Some(iter_b), - (None, None) => None, + if target.distance(next_a) < target.distance(next_b) { + Some(iter_a) + } else { + Some(iter_b) + } } - }, - ) + (Some(_), None) => Some(iter_a), + (None, Some(_)) => Some(iter_b), + (None, None) => None, + } + }) // Pop off the next closest peer from that iterator. .and_then(Iterator::next) } @@ -434,7 +447,7 @@ mod tests { use crate::K_VALUE; use libp2p_core::multihash::{Code, Multihash}; use quickcheck::*; - use rand::{Rng, seq::SliceRandom}; + use rand::{seq::SliceRandom, Rng}; use std::collections::HashSet; use std::iter; @@ -442,22 +455,18 @@ mod tests { fn arbitrary(g: &mut G) -> Self { let target = Target::arbitrary(g).0; let num_closest_iters = g.gen_range(0, 20 + 1); - let peers = random_peers( - g.gen_range(0, 20 * num_closest_iters + 1), - g, - ); + let peers = random_peers(g.gen_range(0, 20 * num_closest_iters + 1), g); let iters: Vec<_> = (0..num_closest_iters) .map(|_| { let num_peers = g.gen_range(0, 20 + 1); - let mut peers = peers.choose_multiple(g, num_peers) + let mut peers = peers + .choose_multiple(g, num_peers) .cloned() .map(Key::from) .collect::>(); - peers.sort_unstable_by(|a, b| { - target.distance(a).cmp(&target.distance(b)) - }); + peers.sort_unstable_by(|a, b| target.distance(a).cmp(&target.distance(b))); peers.into_iter() }) @@ -467,7 +476,8 @@ mod tests { } fn shrink(&self) -> Box> { - let peers = self.iters + let peers = self + .iters .clone() .into_iter() .flatten() @@ -475,7 +485,9 @@ mod tests { .into_iter() .collect::>(); - let iters = self.iters.clone() + let iters = self + .iters + .clone() .into_iter() .map(|iter| iter.collect::>()) .collect(); @@ -503,14 +515,18 @@ mod tests { // The peer that should not be included. let peer = self.peers.pop()?; - let iters = self.iters.clone().into_iter() + let iters = self + .iters + .clone() + .into_iter() .filter_map(|mut iter| { iter.retain(|p| p != &peer); if iter.is_empty() { return None; } Some(iter.into_iter()) - }).collect::>(); + }) + .collect::>(); Some(ResultIter::new(self.target.clone(), iters.into_iter())) } @@ -526,16 +542,22 @@ mod tests { } fn random_peers(n: usize, g: &mut R) -> Vec { - (0 .. 
n).map(|_| PeerId::from_multihash( - Multihash::wrap(Code::Sha2_256.into(), &g.gen::<[u8; 32]>()).unwrap() - ).unwrap()).collect() + (0..n) + .map(|_| { + PeerId::from_multihash( + Multihash::wrap(Code::Sha2_256.into(), &g.gen::<[u8; 32]>()).unwrap(), + ) + .unwrap() + }) + .collect() } #[test] fn result_iter_returns_deduplicated_ordered_peer_id_stream() { fn prop(result_iter: ResultIter>>) { let expected = { - let mut deduplicated = result_iter.clone() + let mut deduplicated = result_iter + .clone() .iters .into_iter() .flatten() @@ -545,7 +567,10 @@ mod tests { .collect::>(); deduplicated.sort_unstable_by(|a, b| { - result_iter.target.distance(a).cmp(&result_iter.target.distance(b)) + result_iter + .target + .distance(a) + .cmp(&result_iter.target.distance(b)) }); deduplicated @@ -560,7 +585,7 @@ mod tests { #[derive(Debug, Clone)] struct Parallelism(NonZeroUsize); - impl Arbitrary for Parallelism{ + impl Arbitrary for Parallelism { fn arbitrary(g: &mut G) -> Self { Parallelism(NonZeroUsize::new(g.gen_range(1, 10)).unwrap()) } @@ -569,7 +594,7 @@ mod tests { #[derive(Debug, Clone)] struct NumResults(NonZeroUsize); - impl Arbitrary for NumResults{ + impl Arbitrary for NumResults { fn arbitrary(g: &mut G) -> Self { NumResults(NonZeroUsize::new(g.gen_range(1, K_VALUE.get())).unwrap()) } @@ -604,13 +629,12 @@ mod tests { let now = Instant::now(); let target: KeyBytes = Key::from(PeerId::random()).into(); - let mut pool = [0; 12].iter() + let mut pool = [0; 12] + .iter() .map(|_| Key::from(PeerId::random())) .collect::>(); - pool.sort_unstable_by(|a, b| { - target.distance(a).cmp(&target.distance(b)) - }); + pool.sort_unstable_by(|a, b| target.distance(a).cmp(&target.distance(b))); let known_closest_peers = pool.split_off(pool.len() - 3); @@ -637,10 +661,7 @@ mod tests { } } - assert_eq!( - PeersIterState::WaitingAtCapacity, - peers_iter.next(now), - ); + assert_eq!(PeersIterState::WaitingAtCapacity, peers_iter.next(now),); let response_2 = pool.split_off(pool.len() - 3); let response_3 = pool.split_off(pool.len() - 3); @@ -651,7 +672,10 @@ mod tests { // Response from malicious peer 1. peers_iter.on_success( known_closest_peers[0].preimage(), - malicious_response_1.clone().into_iter().map(|k| k.preimage().clone()), + malicious_response_1 + .clone() + .into_iter() + .map(|k| k.preimage().clone()), ); // Response from peer 2. @@ -676,7 +700,7 @@ mod tests { } else { panic!("Expected iterator to return peer to query."); } - }; + } // Expect a peer from each disjoint path. assert!(next_to_query.contains(malicious_response_1[0].preimage())); @@ -696,10 +720,7 @@ mod tests { } } - assert_eq!( - PeersIterState::Finished, - peers_iter.next(now), - ); + assert_eq!(PeersIterState::Finished, peers_iter.next(now),); let final_peers: Vec<_> = peers_iter.into_result().collect(); @@ -715,7 +736,9 @@ mod tests { impl std::fmt::Debug for Graph { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - fmt.debug_list().entries(self.0.iter().map(|(id, _)| id)).finish() + fmt.debug_list() + .entries(self.0.iter().map(|(id, _)| id)) + .finish() } } @@ -727,22 +750,24 @@ mod tests { .collect::>(); // Make each peer aware of its direct neighborhood. 
- let mut peers = peer_ids.clone().into_iter() + let mut peers = peer_ids + .clone() + .into_iter() .map(|(peer_id, key)| { - peer_ids.sort_unstable_by(|(_, a), (_, b)| { - key.distance(a).cmp(&key.distance(b)) - }); + peer_ids + .sort_unstable_by(|(_, a), (_, b)| key.distance(a).cmp(&key.distance(b))); assert_eq!(peer_id, peer_ids[0].0); - let known_peers = peer_ids.iter() + let known_peers = peer_ids + .iter() // Skip itself. .skip(1) .take(K_VALUE.get()) .cloned() .collect::>(); - (peer_id, Peer{ known_peers }) + (peer_id, Peer { known_peers }) }) .collect::>(); @@ -751,7 +776,8 @@ mod tests { peer_ids.shuffle(g); let num_peers = g.gen_range(K_VALUE.get(), peer_ids.len() + 1); - let mut random_peer_ids = peer_ids.choose_multiple(g, num_peers) + let mut random_peer_ids = peer_ids + .choose_multiple(g, num_peers) // Make sure not to include itself. .filter(|(id, _)| peer_id != id) .cloned() @@ -760,7 +786,10 @@ mod tests { peer.known_peers.append(&mut random_peer_ids); peer.known_peers = std::mem::replace(&mut peer.known_peers, vec![]) // Deduplicate peer ids. - .into_iter().collect::>().into_iter().collect(); + .into_iter() + .collect::>() + .into_iter() + .collect(); } Graph(peers) @@ -769,21 +798,22 @@ mod tests { impl Graph { fn get_closest_peer(&self, target: &KeyBytes) -> PeerId { - self.0.iter() + self.0 + .iter() .map(|(peer_id, _)| (target.distance(&Key::from(*peer_id)), peer_id)) - .fold(None, |acc, (distance_b, peer_id_b)| { - match acc { - None => Some((distance_b, peer_id_b)), - Some((distance_a, peer_id_a)) => if distance_a < distance_b { + .fold(None, |acc, (distance_b, peer_id_b)| match acc { + None => Some((distance_b, peer_id_b)), + Some((distance_a, peer_id_a)) => { + if distance_a < distance_b { Some((distance_a, peer_id_a)) } else { Some((distance_b, peer_id_b)) } } - }) .expect("Graph to have at least one peer.") - .1.clone() + .1 + .clone() } } @@ -794,11 +824,15 @@ mod tests { impl Peer { fn get_closest_peers(&mut self, target: &KeyBytes) -> Vec { - self.known_peers.sort_unstable_by(|(_, a), (_, b)| { - target.distance(a).cmp(&target.distance(b)) - }); + self.known_peers + .sort_unstable_by(|(_, a), (_, b)| target.distance(a).cmp(&target.distance(b))); - self.known_peers.iter().take(K_VALUE.get()).map(|(id, _)| id).cloned().collect() + self.known_peers + .iter() + .take(K_VALUE.get()) + .map(|(id, _)| id) + .cloned() + .collect() } } @@ -846,15 +880,16 @@ mod tests { let target: KeyBytes = target.0; let closest_peer = graph.get_closest_peer(&target); - let mut known_closest_peers = graph.0.iter() + let mut known_closest_peers = graph + .0 + .iter() .take(K_VALUE.get()) .map(|(key, _peers)| Key::from(*key)) .collect::>(); - known_closest_peers.sort_unstable_by(|a, b| { - target.distance(a).cmp(&target.distance(b)) - }); + known_closest_peers + .sort_unstable_by(|a, b| target.distance(a).cmp(&target.distance(b))); - let cfg = ClosestPeersIterConfig{ + let cfg = ClosestPeersIterConfig { parallelism: parallelism.0, num_results: num_results.0, ..ClosestPeersIterConfig::default() @@ -923,25 +958,32 @@ mod tests { match iter.next(now) { PeersIterState::Waiting(Some(peer_id)) => { let peer_id = peer_id.clone().into_owned(); - let closest_peers = graph.0.get_mut(&peer_id) + let closest_peers = graph + .0 + .get_mut(&peer_id) .unwrap() .get_closest_peers(&target); iter.on_success(&peer_id, closest_peers); - } , - PeersIterState::WaitingAtCapacity | PeersIterState::Waiting(None) => - panic!("There is never more than one request in flight."), + } + 
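// The fold in get_closest_peer above keeps the entry with the smallest distance to
// the target. Since distinct keys have distinct XOR distances to a fixed target,
// this is equivalent to Iterator::min_by_key; a minimal sketch with plain integers
// standing in for distances (names are illustrative only):
fn closest(peers: &[(u32, &'static str)]) -> Option<&'static str> {
    peers
        .iter()
        .min_by_key(|(distance, _)| *distance)
        .map(|(_, id)| *id)
}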
PeersIterState::WaitingAtCapacity | PeersIterState::Waiting(None) => { + panic!("There is never more than one request in flight.") + } PeersIterState::Finished => break, } } - let mut result = iter.into_result().into_iter().map(Key::from).collect::>(); - result.sort_unstable_by(|a, b| { - target.distance(a).cmp(&target.distance(b)) - }); + let mut result = iter + .into_result() + .into_iter() + .map(Key::from) + .collect::>(); + result.sort_unstable_by(|a, b| target.distance(a).cmp(&target.distance(b))); result.into_iter().map(|k| k.into_preimage()).collect() } - QuickCheck::new().tests(10).quickcheck(prop as fn(_, _, _, _) -> _) + QuickCheck::new() + .tests(10) + .quickcheck(prop as fn(_, _, _, _) -> _) } #[test] @@ -957,16 +999,22 @@ mod tests { // Expect peer to be marked as succeeded. assert!(iter.on_success(&peer, iter::empty())); - assert_eq!(iter.contacted_peers.get(&peer), Some(&PeerState { - initiated_by: IteratorIndex(0), - response: ResponseState::Succeeded, - })); + assert_eq!( + iter.contacted_peers.get(&peer), + Some(&PeerState { + initiated_by: IteratorIndex(0), + response: ResponseState::Succeeded, + }) + ); // Expect peer to stay marked as succeeded. assert!(!iter.on_failure(&peer)); - assert_eq!(iter.contacted_peers.get(&peer), Some(&PeerState { - initiated_by: IteratorIndex(0), - response: ResponseState::Succeeded, - })); + assert_eq!( + iter.contacted_peers.get(&peer), + Some(&PeerState { + initiated_by: IteratorIndex(0), + response: ResponseState::Succeeded, + }) + ); } } diff --git a/protocols/kad/src/query/peers/fixed.rs b/protocols/kad/src/query/peers/fixed.rs index b816ea9ce0f..e4be4094eb1 100644 --- a/protocols/kad/src/query/peers/fixed.rs +++ b/protocols/kad/src/query/peers/fixed.rs @@ -22,7 +22,7 @@ use super::*; use fnv::FnvHashMap; use libp2p_core::PeerId; -use std::{vec, collections::hash_map::Entry, num::NonZeroUsize}; +use std::{collections::hash_map::Entry, num::NonZeroUsize, vec}; /// A peer iterator for a fixed set of peers. 
pub struct FixedPeersIter { @@ -42,7 +42,7 @@ pub struct FixedPeersIter { #[derive(Debug, PartialEq, Eq)] enum State { Waiting { num_waiting: usize }, - Finished + Finished, } #[derive(Copy, Clone, PartialEq, Eq)] @@ -60,7 +60,7 @@ enum PeerState { impl FixedPeersIter { pub fn new(peers: I, parallelism: NonZeroUsize) -> Self where - I: IntoIterator + I: IntoIterator, { let peers = peers.into_iter().collect::>(); @@ -87,7 +87,7 @@ impl FixedPeersIter { if let Some(state @ PeerState::Waiting) = self.peers.get_mut(peer) { *state = PeerState::Succeeded; *num_waiting -= 1; - return true + return true; } } false @@ -108,7 +108,7 @@ impl FixedPeersIter { if let Some(state @ PeerState::Waiting) = self.peers.get_mut(peer) { *state = PeerState::Failed; *num_waiting -= 1; - return true + return true; } } false @@ -134,24 +134,26 @@ impl FixedPeersIter { State::Finished => PeersIterState::Finished, State::Waiting { num_waiting } => { if *num_waiting >= self.parallelism.get() { - return PeersIterState::WaitingAtCapacity + return PeersIterState::WaitingAtCapacity; } loop { match self.iter.next() { - None => if *num_waiting == 0 { - self.state = State::Finished; - return PeersIterState::Finished - } else { - return PeersIterState::Waiting(None) + None => { + if *num_waiting == 0 { + self.state = State::Finished; + return PeersIterState::Finished; + } else { + return PeersIterState::Waiting(None); + } } Some(p) => match self.peers.entry(p) { Entry::Occupied(_) => {} // skip duplicates Entry::Vacant(e) => { *num_waiting += 1; e.insert(PeerState::Waiting); - return PeersIterState::Waiting(Some(Cow::Owned(p))) + return PeersIterState::Waiting(Some(Cow::Owned(p))); } - } + }, } } } @@ -159,13 +161,13 @@ impl FixedPeersIter { } pub fn into_result(self) -> impl Iterator { - self.peers.into_iter() - .filter_map(|(p, s)| - if let PeerState::Succeeded = s { - Some(p) - } else { - None - }) + self.peers.into_iter().filter_map(|(p, s)| { + if let PeerState::Succeeded = s { + Some(p) + } else { + None + } + }) } } @@ -184,12 +186,12 @@ mod test { PeersIterState::Waiting(Some(peer)) => { let peer = peer.into_owned(); iter.on_failure(&peer); - }, + } _ => panic!("Expected iterator to yield peer."), } match iter.next() { - PeersIterState::Waiting(Some(_)) => {}, + PeersIterState::Waiting(Some(_)) => {} PeersIterState::WaitingAtCapacity => panic!( "Expected iterator to return another peer given that the \ previous `on_failure` call should have allowed another peer \ diff --git a/protocols/kad/src/record.rs b/protocols/kad/src/record.rs index 5a15fdd1034..8f1c585d1b8 100644 --- a/protocols/kad/src/record.rs +++ b/protocols/kad/src/record.rs @@ -23,7 +23,7 @@ pub mod store; use bytes::Bytes; -use libp2p_core::{PeerId, Multiaddr, multihash::Multihash}; +use libp2p_core::{multihash::Multihash, Multiaddr, PeerId}; use std::borrow::Borrow; use std::hash::{Hash, Hasher}; use wasm_timer::Instant; @@ -85,7 +85,7 @@ impl Record { /// Creates a new record for insertion into the DHT. pub fn new(key: K, value: Vec) -> Self where - K: Into + K: Into, { Record { key: key.into(), @@ -116,7 +116,7 @@ pub struct ProviderRecord { /// The expiration time as measured by a local, monotonic clock. pub expires: Option, /// The known addresses that the provider may be listening on. - pub addresses: Vec + pub addresses: Vec, } impl Hash for ProviderRecord { @@ -138,7 +138,7 @@ impl ProviderRecord { /// Creates a new provider record for insertion into a `RecordStore`. 
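// FixedPeersIter above caps the number of in-flight requests with the num_waiting
// counter: next() hands out peers only while num_waiting < parallelism, and
// on_success/on_failure free a slot again. A stripped-down version of that
// bookkeeping, with all names and types invented for the example:
use std::num::NonZeroUsize;

struct CappedIter<I: Iterator> {
    iter: I,
    parallelism: NonZeroUsize,
    num_waiting: usize,
}

impl<I: Iterator> CappedIter<I> {
    // Hand out the next item, or None while at capacity or exhausted.
    fn next_to_query(&mut self) -> Option<I::Item> {
        if self.num_waiting >= self.parallelism.get() {
            return None; // at capacity: wait for an in-flight request to finish
        }
        let item = self.iter.next()?;
        self.num_waiting += 1;
        Some(item)
    }

    // Called when an in-flight request succeeds or fails; frees one slot.
    fn on_result(&mut self) {
        self.num_waiting = self.num_waiting.saturating_sub(1);
    }
}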
pub fn new(key: K, provider: PeerId, addresses: Vec) -> Self where - K: Into + K: Into, { ProviderRecord { key: key.into(), @@ -157,8 +157,8 @@ impl ProviderRecord { #[cfg(test)] mod tests { use super::*; - use quickcheck::*; use libp2p_core::multihash::Code; + use quickcheck::*; use rand::Rng; use std::time::Duration; @@ -174,7 +174,11 @@ mod tests { Record { key: Key::arbitrary(g), value: Vec::arbitrary(g), - publisher: if g.gen() { Some(PeerId::random()) } else { None }, + publisher: if g.gen() { + Some(PeerId::random()) + } else { + None + }, expires: if g.gen() { Some(Instant::now() + Duration::from_secs(g.gen_range(0, 60))) } else { diff --git a/protocols/kad/src/record/store.rs b/protocols/kad/src/record/store.rs index 82402ed3c18..9347afedd7c 100644 --- a/protocols/kad/src/record/store.rs +++ b/protocols/kad/src/record/store.rs @@ -22,8 +22,8 @@ mod memory; pub use memory::{MemoryStore, MemoryStoreConfig}; -use crate::K_VALUE; use super::*; +use crate::K_VALUE; use std::borrow::Cow; /// The result of an operation on a `RecordStore`. @@ -92,4 +92,3 @@ pub trait RecordStore<'a> { /// Removes a provider record from the store. fn remove_provider(&'a mut self, k: &Key, p: &PeerId); } - diff --git a/protocols/kad/src/record/store/memory.rs b/protocols/kad/src/record/store/memory.rs index d74f32bdfbf..c6a006b6cd5 100644 --- a/protocols/kad/src/record/store/memory.rs +++ b/protocols/kad/src/record/store/memory.rs @@ -90,21 +90,19 @@ impl MemoryStore { /// Retains the records satisfying a predicate. pub fn retain(&mut self, f: F) where - F: FnMut(&Key, &mut Record) -> bool + F: FnMut(&Key, &mut Record) -> bool, { self.records.retain(f); } } impl<'a> RecordStore<'a> for MemoryStore { - type RecordsIter = iter::Map< - hash_map::Values<'a, Key, Record>, - fn(&'a Record) -> Cow<'a, Record> - >; + type RecordsIter = + iter::Map, fn(&'a Record) -> Cow<'a, Record>>; type ProvidedIter = iter::Map< hash_set::Iter<'a, ProviderRecord>, - fn(&'a ProviderRecord) -> Cow<'a, ProviderRecord> + fn(&'a ProviderRecord) -> Cow<'a, ProviderRecord>, >; fn get(&'a self, k: &Key) -> Option> { @@ -113,7 +111,7 @@ impl<'a> RecordStore<'a> for MemoryStore { fn put(&'a mut self, r: Record) -> Result<()> { if r.value.len() >= self.config.max_value_bytes { - return Err(Error::ValueTooLarge) + return Err(Error::ValueTooLarge); } let num_records = self.records.len(); @@ -124,7 +122,7 @@ impl<'a> RecordStore<'a> for MemoryStore { } hash_map::Entry::Vacant(e) => { if num_records >= self.config.max_records { - return Err(Error::MaxRecords) + return Err(Error::MaxRecords); } e.insert(r); } @@ -146,14 +144,15 @@ impl<'a> RecordStore<'a> for MemoryStore { // Obtain the entry let providers = match self.providers.entry(record.key.clone()) { - e@hash_map::Entry::Occupied(_) => e, - e@hash_map::Entry::Vacant(_) => { + e @ hash_map::Entry::Occupied(_) => e, + e @ hash_map::Entry::Vacant(_) => { if self.config.max_provided_keys == num_keys { - return Err(Error::MaxProvidedKeys) + return Err(Error::MaxProvidedKeys); } e } - }.or_insert_with(Default::default); + } + .or_insert_with(Default::default); if let Some(i) = providers.iter().position(|p| p.provider == record.provider) { // In-place update of an existing provider record. 
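// MemoryStore::put above rejects oversized values up front and uses the Entry API so
// the max_records check only applies when a new key would actually grow the map;
// overwriting an existing key is always allowed. The same pattern in isolation, with
// simplified types and an invented error string:
use std::collections::{hash_map::Entry, HashMap};

fn put_bounded(
    map: &mut HashMap<String, Vec<u8>>,
    max_records: usize,
    key: String,
    value: Vec<u8>,
) -> Result<(), &'static str> {
    let num_records = map.len();
    match map.entry(key) {
        // Overwriting never grows the map, so no limit check is needed.
        Entry::Occupied(mut e) => {
            e.insert(value);
        }
        // A vacant entry would grow the map; check the limit first.
        Entry::Vacant(e) => {
            if num_records >= max_records {
                return Err("max records reached");
            }
            e.insert(value);
        }
    }
    Ok(())
}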
@@ -178,8 +177,7 @@ impl<'a> RecordStore<'a> for MemoryStore { self.provided.remove(&p); } } - } - else if providers.len() < self.config.max_providers_per_key { + } else if providers.len() < self.config.max_providers_per_key { // The distance of the new provider to the key is larger than // the distance of any existing provider, but there is still room. if local_key.preimage() == &record.provider { @@ -192,7 +190,9 @@ impl<'a> RecordStore<'a> for MemoryStore { } fn providers(&'a self, key: &Key) -> Vec { - self.providers.get(key).map_or_else(Vec::new, |ps| ps.clone().into_vec()) + self.providers + .get(key) + .map_or_else(Vec::new, |ps| ps.clone().into_vec()) } fn provided(&'a self) -> Self::ProvidedIter { @@ -225,8 +225,7 @@ mod tests { } fn distance(r: &ProviderRecord) -> kbucket::Distance { - kbucket::Key::new(r.key.clone()) - .distance(&kbucket::Key::from(r.provider)) + kbucket::Key::new(r.key.clone()).distance(&kbucket::Key::from(r.provider)) } #[test] @@ -259,9 +258,10 @@ mod tests { let mut store = MemoryStore::new(PeerId::random()); let key = Key::from(random_multihash()); - let mut records = providers.into_iter().map(|p| { - ProviderRecord::new(key.clone(), p.into_preimage(), Vec::new()) - }).collect::>(); + let mut records = providers + .into_iter() + .map(|p| ProviderRecord::new(key.clone(), p.into_preimage(), Vec::new())) + .collect::>(); for r in &records { assert!(store.add_provider(r.clone()).is_ok()); @@ -283,7 +283,10 @@ mod tests { let key = random_multihash(); let rec = ProviderRecord::new(key, id.clone(), Vec::new()); assert!(store.add_provider(rec.clone()).is_ok()); - assert_eq!(vec![Cow::Borrowed(&rec)], store.provided().collect::>()); + assert_eq!( + vec![Cow::Borrowed(&rec)], + store.provided().collect::>() + ); store.remove_provider(&rec.key, &id); assert_eq!(store.provided().count(), 0); } @@ -304,7 +307,7 @@ mod tests { #[test] fn max_provided_keys() { let mut store = MemoryStore::new(PeerId::random()); - for _ in 0 .. store.config.max_provided_keys { + for _ in 0..store.config.max_provided_keys { let key = random_multihash(); let prv = PeerId::random(); let rec = ProviderRecord::new(key, prv, Vec::new()); diff --git a/protocols/mdns/src/behaviour.rs b/protocols/mdns/src/behaviour.rs index 7348227c4bb..2a170e2d839 100644 --- a/protocols/mdns/src/behaviour.rs +++ b/protocols/mdns/src/behaviour.rs @@ -18,16 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::IPV4_MDNS_MULTICAST_ADDRESS; use crate::dns::{build_query, build_query_response, build_service_discovery_response}; use crate::query::MdnsPacket; +use crate::IPV4_MDNS_MULTICAST_ADDRESS; use async_io::{Async, Timer}; use futures::prelude::*; use if_watch::{IfEvent, IfWatcher}; use libp2p_core::connection::ListenerId; -use libp2p_core::{ - address_translation, multiaddr::Protocol, Multiaddr, PeerId, -}; +use libp2p_core::{address_translation, multiaddr::Protocol, Multiaddr, PeerId}; use libp2p_swarm::{ protocols_handler::DummyProtocolsHandler, NetworkBehaviour, NetworkBehaviourAction, PollParameters, ProtocolsHandler, diff --git a/protocols/ping/src/handler.rs b/protocols/ping/src/handler.rs index ebfc5a0b1a5..1c4233e2b22 100644 --- a/protocols/ping/src/handler.rs +++ b/protocols/ping/src/handler.rs @@ -19,28 +19,23 @@ // DEALINGS IN THE SOFTWARE. 
use crate::protocol; -use futures::prelude::*; use futures::future::BoxFuture; -use libp2p_core::{UpgradeError, upgrade::NegotiationError}; +use futures::prelude::*; +use libp2p_core::{upgrade::NegotiationError, UpgradeError}; use libp2p_swarm::{ - KeepAlive, - NegotiatedSubstream, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerUpgrErr, - ProtocolsHandlerEvent + KeepAlive, NegotiatedSubstream, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, SubstreamProtocol, }; +use std::collections::VecDeque; use std::{ error::Error, - io, - fmt, + fmt, io, num::NonZeroU32, task::{Context, Poll}, - time::Duration + time::Duration, }; -use std::collections::VecDeque; -use wasm_timer::Delay; use void::Void; +use wasm_timer::Delay; /// The configuration for outbound pings. #[derive(Clone, Debug)] @@ -82,7 +77,7 @@ impl PingConfig { timeout: Duration::from_secs(20), interval: Duration::from_secs(15), max_failures: NonZeroU32::new(1).expect("1 != 0"), - keep_alive: false + keep_alive: false, } } @@ -144,7 +139,9 @@ pub enum PingFailure { /// The peer does not support the ping protocol. Unsupported, /// The ping failed for reasons other than a timeout. - Other { error: Box } + Other { + error: Box, + }, } impl fmt::Display for PingFailure { @@ -190,7 +187,7 @@ pub struct PingHandler { /// next inbound ping to be answered. inbound: Option, /// Tracks the state of our handler. - state: State + state: State, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -200,7 +197,7 @@ enum State { /// Whether or not we've reported the missing support yet. /// /// This is used to avoid repeated events being emitted for a specific connection. - reported: bool + reported: bool, }, /// We are actively pinging the other peer. Active, @@ -252,11 +249,9 @@ impl ProtocolsHandler for PingHandler { ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { debug_assert_eq!(self.state, State::Active); - self.state = State::Inactive { - reported: false - }; + self.state = State::Inactive { reported: false }; return; - }, + } // Note: This timeout only covers protocol negotiation. ProtocolsHandlerUpgrErr::Timeout => PingFailure::Timeout, e => PingFailure::Other { error: Box::new(e) }, @@ -273,22 +268,25 @@ impl ProtocolsHandler for PingHandler { } } - fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { match self.state { State::Inactive { reported: true } => { - return Poll::Pending // nothing to do on this connection - }, + return Poll::Pending; // nothing to do on this connection + } State::Inactive { reported: false } => { self.state = State::Inactive { reported: true }; return Poll::Ready(ProtocolsHandlerEvent::Custom(Err(PingFailure::Unsupported))); - }, + } State::Active => {} } // Respond to inbound pings. if let Some(fut) = self.inbound.as_mut() { match fut.poll_unpin(cx) { - Poll::Pending => {}, + Poll::Pending => {} Poll::Ready(Err(e)) => { log::debug!("Inbound ping error: {:?}", e); self.inbound = None; @@ -296,7 +294,7 @@ impl ProtocolsHandler for PingHandler { Poll::Ready(Ok(stream)) => { // A ping from a remote peer has been answered, wait for the next. 
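// The ping handler above reports missing protocol support at most once per
// connection via State::Inactive { reported }. The same "emit once, then stay
// silent" idea in isolation (enum and function names are invented for the sketch):
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Support {
    Inactive { reported: bool },
    Active,
}

fn poll_support(state: &mut Support) -> Option<&'static str> {
    match *state {
        Support::Inactive { reported: true } => None, // already reported, stay quiet
        Support::Inactive { reported: false } => {
            *state = Support::Inactive { reported: true };
            Some("remote does not support the protocol") // emitted exactly once
        }
        Support::Active => None,
    }
}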
self.inbound = Some(protocol::recv_ping(stream).boxed()); - return Poll::Ready(ProtocolsHandlerEvent::Custom(Ok(PingSuccess::Pong))) + return Poll::Ready(ProtocolsHandlerEvent::Custom(Ok(PingSuccess::Pong))); } } } @@ -318,10 +316,10 @@ impl ProtocolsHandler for PingHandler { if self.failures > 1 || self.config.max_failures.get() > 1 { if self.failures >= self.config.max_failures.get() { log::debug!("Too many failures ({}). Closing connection.", self.failures); - return Poll::Ready(ProtocolsHandlerEvent::Close(error)) + return Poll::Ready(ProtocolsHandlerEvent::Close(error)); } - return Poll::Ready(ProtocolsHandlerEvent::Custom(Err(error))) + return Poll::Ready(ProtocolsHandlerEvent::Custom(Err(error))); } } @@ -333,50 +331,48 @@ impl ProtocolsHandler for PingHandler { self.pending_errors.push_front(PingFailure::Timeout); } else { self.outbound = Some(PingState::Ping(ping)); - break + break; } - }, + } Poll::Ready(Ok((stream, rtt))) => { self.failures = 0; self.timer.reset(self.config.interval); self.outbound = Some(PingState::Idle(stream)); - return Poll::Ready( - ProtocolsHandlerEvent::Custom( - Ok(PingSuccess::Ping { rtt }))) + return Poll::Ready(ProtocolsHandlerEvent::Custom(Ok(PingSuccess::Ping { + rtt, + }))); } Poll::Ready(Err(e)) => { - self.pending_errors.push_front(PingFailure::Other { - error: Box::new(e) - }); + self.pending_errors + .push_front(PingFailure::Other { error: Box::new(e) }); } }, Some(PingState::Idle(stream)) => match self.timer.poll_unpin(cx) { Poll::Pending => { self.outbound = Some(PingState::Idle(stream)); - break - }, + break; + } Poll::Ready(Ok(())) => { self.timer.reset(self.config.timeout); self.outbound = Some(PingState::Ping(protocol::send_ping(stream).boxed())); - }, + } Poll::Ready(Err(e)) => { - return Poll::Ready(ProtocolsHandlerEvent::Close( - PingFailure::Other { - error: Box::new(e) - })) + return Poll::Ready(ProtocolsHandlerEvent::Close(PingFailure::Other { + error: Box::new(e), + })) } - } + }, Some(PingState::OpenStream) => { self.outbound = Some(PingState::OpenStream); - break + break; } None => { self.outbound = Some(PingState::OpenStream); let protocol = SubstreamProtocol::new(protocol::Ping, ()) .with_timeout(self.config.timeout); return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol - }) + protocol, + }); } } } diff --git a/protocols/ping/src/lib.rs b/protocols/ping/src/lib.rs index cd9cc227c9d..d4e3828f430 100644 --- a/protocols/ping/src/lib.rs +++ b/protocols/ping/src/lib.rs @@ -40,13 +40,13 @@ //! [`Swarm`]: libp2p_swarm::Swarm //! 
[`Transport`]: libp2p_core::Transport -pub mod protocol; pub mod handler; +pub mod protocol; -pub use handler::{PingConfig, PingResult, PingSuccess, PingFailure}; use handler::PingHandler; +pub use handler::{PingConfig, PingFailure, PingResult, PingSuccess}; -use libp2p_core::{PeerId, connection::ConnectionId}; +use libp2p_core::{connection::ConnectionId, PeerId}; use libp2p_swarm::{NetworkBehaviour, NetworkBehaviourAction, PollParameters}; use std::{collections::VecDeque, task::Context, task::Poll}; use void::Void; @@ -99,9 +99,11 @@ impl NetworkBehaviour for Ping { self.events.push_front(PingEvent { peer, result }) } - fn poll(&mut self, _: &mut Context<'_>, _: &mut impl PollParameters) - -> Poll> - { + fn poll( + &mut self, + _: &mut Context<'_>, + _: &mut impl PollParameters, + ) -> Poll> { if let Some(e) = self.events.pop_back() { Poll::Ready(NetworkBehaviourAction::GenerateEvent(e)) } else { diff --git a/protocols/ping/src/protocol.rs b/protocols/ping/src/protocol.rs index aa63833f651..a3138568777 100644 --- a/protocols/ping/src/protocol.rs +++ b/protocols/ping/src/protocol.rs @@ -82,7 +82,7 @@ impl OutboundUpgrade for Ping { /// Sends a ping and waits for the pong. pub async fn send_ping(mut stream: S) -> io::Result<(S, Duration)> where - S: AsyncRead + AsyncWrite + Unpin + S: AsyncRead + AsyncWrite + Unpin, { let payload: [u8; PING_SIZE] = thread_rng().sample(distributions::Standard); log::debug!("Preparing ping payload {:?}", payload); @@ -95,14 +95,17 @@ where if recv_payload == payload { Ok((stream, started.elapsed())) } else { - Err(io::Error::new(io::ErrorKind::InvalidData, "Ping payload mismatch")) + Err(io::Error::new( + io::ErrorKind::InvalidData, + "Ping payload mismatch", + )) } } /// Waits for a ping and sends a pong. pub async fn recv_ping(mut stream: S) -> io::Result where - S: AsyncRead + AsyncWrite + Unpin + S: AsyncRead + AsyncWrite + Unpin, { let mut payload = [0u8; PING_SIZE]; log::debug!("Waiting for ping ..."); @@ -118,11 +121,7 @@ mod tests { use super::*; use libp2p_core::{ multiaddr::multiaddr, - transport::{ - Transport, - ListenerEvent, - memory::MemoryTransport - } + transport::{memory::MemoryTransport, ListenerEvent, Transport}, }; use rand::{thread_rng, Rng}; use std::time::Duration; diff --git a/protocols/ping/tests/ping.rs b/protocols/ping/tests/ping.rs index 8c394757f9c..779bbbc0069 100644 --- a/protocols/ping/tests/ping.rs +++ b/protocols/ping/tests/ping.rs @@ -20,13 +20,12 @@ //! Integration tests for the `Ping` network behaviour. +use futures::{channel::mpsc, prelude::*}; use libp2p_core::{ - Multiaddr, - PeerId, identity, muxing::StreamMuxerBox, transport::{self, Transport}, - upgrade + upgrade, Multiaddr, PeerId, }; use libp2p_mplex as mplex; use libp2p_noise as noise; @@ -34,7 +33,6 @@ use libp2p_ping::*; use libp2p_swarm::{DummyBehaviour, KeepAlive, Swarm, SwarmEvent}; use libp2p_tcp::TcpConfig; use libp2p_yamux as yamux; -use futures::{prelude::*, channel::mpsc}; use quickcheck::*; use rand::prelude::*; use std::{num::NonZeroU8, time::Duration}; @@ -65,13 +63,18 @@ fn ping_pong() { loop { match swarm1.select_next_some().await { SwarmEvent::NewListenAddr { address, .. } => tx.send(address).await.unwrap(), - SwarmEvent::Behaviour(PingEvent { peer, result: Ok(PingSuccess::Ping { rtt }) }) => { + SwarmEvent::Behaviour(PingEvent { + peer, + result: Ok(PingSuccess::Ping { rtt }), + }) => { count1 -= 1; if count1 == 0 { - return (pid1.clone(), peer, rtt) + return (pid1.clone(), peer, rtt); } - }, - SwarmEvent::Behaviour(PingEvent { result: Err(e), .. 
}) => panic!("Ping failure: {:?}", e), + } + SwarmEvent::Behaviour(PingEvent { result: Err(e), .. }) => { + panic!("Ping failure: {:?}", e) + } _ => {} } } @@ -85,17 +88,16 @@ fn ping_pong() { match swarm2.select_next_some().await { SwarmEvent::Behaviour(PingEvent { peer, - result: Ok(PingSuccess::Ping { rtt }) + result: Ok(PingSuccess::Ping { rtt }), }) => { count2 -= 1; if count2 == 0 { - return (pid2.clone(), peer, rtt) + return (pid2.clone(), peer, rtt); } - }, - SwarmEvent::Behaviour(PingEvent { - result: Err(e), - .. - }) => panic!("Ping failure: {:?}", e), + } + SwarmEvent::Behaviour(PingEvent { result: Err(e), .. }) => { + panic!("Ping failure: {:?}", e) + } _ => {} } } @@ -107,7 +109,7 @@ fn ping_pong() { assert!(rtt < Duration::from_millis(50)); } - QuickCheck::new().tests(10).quickcheck(prop as fn(_,_)) + QuickCheck::new().tests(10).quickcheck(prop as fn(_, _)) } /// Tests that the connection is closed upon a configurable @@ -139,18 +141,15 @@ fn max_failures() { match swarm1.select_next_some().await { SwarmEvent::NewListenAddr { address, .. } => tx.send(address).await.unwrap(), SwarmEvent::Behaviour(PingEvent { - result: Ok(PingSuccess::Ping { .. }), .. + result: Ok(PingSuccess::Ping { .. }), + .. }) => { count1 = 0; // there may be an occasional success } - SwarmEvent::Behaviour(PingEvent { - result: Err(_), .. - }) => { + SwarmEvent::Behaviour(PingEvent { result: Err(_), .. }) => { count1 += 1; } - SwarmEvent::ConnectionClosed { .. } => { - return count1 - } + SwarmEvent::ConnectionClosed { .. } => return count1, _ => {} } } @@ -164,18 +163,15 @@ fn max_failures() { loop { match swarm2.select_next_some().await { SwarmEvent::Behaviour(PingEvent { - result: Ok(PingSuccess::Ping { .. }), .. + result: Ok(PingSuccess::Ping { .. }), + .. }) => { count2 = 0; // there may be an occasional success } - SwarmEvent::Behaviour(PingEvent { - result: Err(_), .. - }) => { + SwarmEvent::Behaviour(PingEvent { result: Err(_), .. }) => { count2 += 1; } - SwarmEvent::ConnectionClosed { .. } => { - return count2 - } + SwarmEvent::ConnectionClosed { .. } => return count2, _ => {} } } @@ -186,16 +182,24 @@ fn max_failures() { assert_eq!(u8::max(count1, count2), max_failures.get() - 1); } - QuickCheck::new().tests(10).quickcheck(prop as fn(_,_)) + QuickCheck::new().tests(10).quickcheck(prop as fn(_, _)) } #[test] fn unsupported_doesnt_fail() { let (peer1_id, trans) = mk_transport(MuxerChoice::Mplex); - let mut swarm1 = Swarm::new(trans, DummyBehaviour::with_keep_alive(KeepAlive::Yes), peer1_id.clone()); + let mut swarm1 = Swarm::new( + trans, + DummyBehaviour::with_keep_alive(KeepAlive::Yes), + peer1_id.clone(), + ); let (peer2_id, trans) = mk_transport(MuxerChoice::Mplex); - let mut swarm2 = Swarm::new(trans, Ping::new(PingConfig::new().with_keep_alive(true)), peer2_id.clone()); + let mut swarm2 = Swarm::new( + trans, + Ping::new(PingConfig::new().with_keep_alive(true)), + peer2_id.clone(), + ); let (mut tx, mut rx) = mpsc::channel::(1); @@ -217,7 +221,8 @@ fn unsupported_doesnt_fail() { loop { match swarm2.select_next_some().await { SwarmEvent::Behaviour(PingEvent { - result: Err(PingFailure::Unsupported), .. + result: Err(PingFailure::Unsupported), + .. 
}) => { swarm2.disconnect_peer_id(peer1_id).unwrap(); } @@ -235,25 +240,24 @@ fn unsupported_doesnt_fail() { result.expect("node with ping should not fail connection due to unsupported protocol"); } - -fn mk_transport(muxer: MuxerChoice) -> ( - PeerId, - transport::Boxed<(PeerId, StreamMuxerBox)> -) { +fn mk_transport(muxer: MuxerChoice) -> (PeerId, transport::Boxed<(PeerId, StreamMuxerBox)>) { let id_keys = identity::Keypair::generate_ed25519(); let peer_id = id_keys.public().to_peer_id(); - let noise_keys = noise::Keypair::::new().into_authentic(&id_keys).unwrap(); - (peer_id, TcpConfig::new() - .nodelay(true) - .upgrade() - .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) - .multiplex(match muxer { - MuxerChoice::Yamux => - upgrade::EitherUpgrade::A(yamux::YamuxConfig::default()), - MuxerChoice::Mplex => - upgrade::EitherUpgrade::B(mplex::MplexConfig::default()), - }) - .boxed()) + let noise_keys = noise::Keypair::::new() + .into_authentic(&id_keys) + .unwrap(); + ( + peer_id, + TcpConfig::new() + .nodelay(true) + .upgrade() + .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) + .multiplex(match muxer { + MuxerChoice::Yamux => upgrade::EitherUpgrade::A(yamux::YamuxConfig::default()), + MuxerChoice::Mplex => upgrade::EitherUpgrade::B(mplex::MplexConfig::default()), + }) + .boxed(), + ) } #[derive(Debug, Copy, Clone)] diff --git a/protocols/relay/build.rs b/protocols/relay/build.rs index cd7bd3deef6..c3a7d4bd823 100644 --- a/protocols/relay/build.rs +++ b/protocols/relay/build.rs @@ -19,5 +19,5 @@ // DEALINGS IN THE SOFTWARE. fn main() { - prost_build::compile_protos(&["src/message.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/message.proto"], &["src"]).unwrap(); } diff --git a/protocols/relay/examples/relay.rs b/protocols/relay/examples/relay.rs index 3299aca8b3c..486b2d901a4 100644 --- a/protocols/relay/examples/relay.rs +++ b/protocols/relay/examples/relay.rs @@ -60,13 +60,13 @@ use futures::executor::block_on; use futures::stream::StreamExt; use libp2p::dns::DnsConfig; +use libp2p::identity::{self, ed25519}; use libp2p::ping::{Ping, PingConfig, PingEvent}; use libp2p::plaintext; use libp2p::relay::{Relay, RelayConfig}; use libp2p::swarm::SwarmEvent; use libp2p::tcp::TcpConfig; use libp2p::Transport; -use libp2p::identity::{self, ed25519}; use libp2p::{NetworkBehaviour, PeerId, Swarm}; use std::error::Error; use std::task::{Context, Poll}; diff --git a/protocols/relay/src/behaviour.rs b/protocols/relay/src/behaviour.rs index 78e9e5d8d66..9b17eca2c51 100644 --- a/protocols/relay/src/behaviour.rs +++ b/protocols/relay/src/behaviour.rs @@ -303,7 +303,7 @@ impl NetworkBehaviour for Relay { fn inject_dial_failure(&mut self, peer_id: &PeerId) { if let Entry::Occupied(o) = self.listeners.entry(*peer_id) { - if matches!(o.get(), RelayListener::Connecting{ .. }) { + if matches!(o.get(), RelayListener::Connecting { .. }) { // By removing the entry, the channel to the listener is dropped and thus the // listener is notified that dialing the relay failed. 
o.remove_entry(); diff --git a/protocols/relay/src/protocol/incoming_dst_req.rs b/protocols/relay/src/protocol/incoming_dst_req.rs index b3b0ded9de6..d68a15121f5 100644 --- a/protocols/relay/src/protocol/incoming_dst_req.rs +++ b/protocols/relay/src/protocol/incoming_dst_req.rs @@ -23,8 +23,8 @@ use crate::protocol::Peer; use asynchronous_codec::{Framed, FramedParts}; use bytes::BytesMut; -use futures::{future::BoxFuture, prelude::*}; use futures::channel::oneshot; +use futures::{future::BoxFuture, prelude::*}; use libp2p_core::{Multiaddr, PeerId}; use libp2p_swarm::NegotiatedSubstream; use prost::Message; @@ -47,8 +47,7 @@ pub struct IncomingDstReq { src: Peer, } -impl IncomingDstReq -{ +impl IncomingDstReq { /// Creates a `IncomingDstReq`. pub(crate) fn new(stream: Framed, src: Peer) -> Self { IncomingDstReq { @@ -73,7 +72,10 @@ impl IncomingDstReq /// stream then points to the source (as retreived with `src_id()` and `src_addrs()`). pub fn accept( self, - ) -> BoxFuture<'static, Result<(PeerId, super::Connection, oneshot::Receiver<()>), IncomingDstReqError>> { + ) -> BoxFuture< + 'static, + Result<(PeerId, super::Connection, oneshot::Receiver<()>), IncomingDstReqError>, + > { let IncomingDstReq { mut stream, src } = self; let msg = CircuitRelay { r#type: Some(circuit_relay::Type::Status.into()), @@ -101,7 +103,11 @@ impl IncomingDstReq let (tx, rx) = oneshot::channel(); - Ok((src.peer_id, super::Connection::new(read_buffer.freeze(), io, tx), rx)) + Ok(( + src.peer_id, + super::Connection::new(read_buffer.freeze(), io, tx), + rx, + )) } .boxed() } diff --git a/protocols/relay/src/protocol/incoming_relay_req.rs b/protocols/relay/src/protocol/incoming_relay_req.rs index 6f585db2854..948a2281f5b 100644 --- a/protocols/relay/src/protocol/incoming_relay_req.rs +++ b/protocols/relay/src/protocol/incoming_relay_req.rs @@ -23,7 +23,7 @@ use crate::message_proto::{circuit_relay, circuit_relay::Status, CircuitRelay}; use crate::protocol::Peer; use asynchronous_codec::{Framed, FramedParts}; -use bytes::{BytesMut, Bytes}; +use bytes::{Bytes, BytesMut}; use futures::channel::oneshot; use futures::future::BoxFuture; use futures::prelude::*; @@ -50,8 +50,7 @@ pub struct IncomingRelayReq { _notifier: oneshot::Sender<()>, } -impl IncomingRelayReq -{ +impl IncomingRelayReq { /// Creates a [`IncomingRelayReq`] as well as a Future that resolves once the /// [`IncomingRelayReq`] is dropped. pub(crate) fn new( diff --git a/protocols/relay/src/protocol/outgoing_dst_req.rs b/protocols/relay/src/protocol/outgoing_dst_req.rs index 7cffb1a1d96..181e31ef4a8 100644 --- a/protocols/relay/src/protocol/outgoing_dst_req.rs +++ b/protocols/relay/src/protocol/outgoing_dst_req.rs @@ -27,7 +27,7 @@ use futures::prelude::*; use libp2p_core::{upgrade, Multiaddr, PeerId}; use libp2p_swarm::NegotiatedSubstream; use prost::Message; -use std::{fmt, error, iter}; +use std::{error, fmt, iter}; use unsigned_varint::codec::UviBytes; /// Ask the remote to become a destination. 
The upgrade succeeds if the remote accepts, and fails @@ -96,14 +96,9 @@ impl upgrade::OutboundUpgrade for OutgoingDstReq { async move { substream.send(std::io::Cursor::new(self.message)).await?; - let msg = - substream - .next() - .await - .ok_or_else(|| OutgoingDstReqError::Io(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "", - )))??; + let msg = substream.next().await.ok_or_else(|| { + OutgoingDstReqError::Io(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "")) + })??; let msg = std::io::Cursor::new(msg); let CircuitRelay { diff --git a/protocols/relay/src/protocol/outgoing_relay_req.rs b/protocols/relay/src/protocol/outgoing_relay_req.rs index a34f10eba26..a9d450b04d7 100644 --- a/protocols/relay/src/protocol/outgoing_relay_req.rs +++ b/protocols/relay/src/protocol/outgoing_relay_req.rs @@ -103,14 +103,12 @@ impl upgrade::OutboundUpgrade for OutgoingRelayReq { async move { substream.send(std::io::Cursor::new(encoded)).await?; - let msg = - substream - .next() - .await - .ok_or_else(|| OutgoingRelayReqError::Io(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "", - )))??; + let msg = substream.next().await.ok_or_else(|| { + OutgoingRelayReqError::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "", + )) + })??; let msg = std::io::Cursor::new(msg); let CircuitRelay { diff --git a/protocols/relay/src/transport.rs b/protocols/relay/src/transport.rs index b3729c0582f..b410faf0af4 100644 --- a/protocols/relay/src/transport.rs +++ b/protocols/relay/src/transport.rs @@ -402,7 +402,7 @@ impl Stream for RelayListener { stream, src_peer_id, relay_addr, - relay_peer_id: _ + relay_peer_id: _, })) => { return Poll::Ready(Some(Ok(ListenerEvent::Upgrade { upgrade: RelayedListenerUpgrade::Relayed(Some(stream)), diff --git a/protocols/relay/tests/lib.rs b/protocols/relay/tests/lib.rs index 70d6d098669..82b62a6a3c4 100644 --- a/protocols/relay/tests/lib.rs +++ b/protocols/relay/tests/lib.rs @@ -35,7 +35,10 @@ use libp2p_ping::{Ping, PingConfig, PingEvent}; use libp2p_plaintext::PlainText2Config; use libp2p_relay::{Relay, RelayConfig}; use libp2p_swarm::protocols_handler::KeepAlive; -use libp2p_swarm::{DummyBehaviour, NetworkBehaviour, NetworkBehaviourAction, NetworkBehaviourEventProcess, PollParameters, Swarm, SwarmEvent}; +use libp2p_swarm::{ + DummyBehaviour, NetworkBehaviour, NetworkBehaviourAction, NetworkBehaviourEventProcess, + PollParameters, Swarm, SwarmEvent, +}; use std::task::{Context, Poll}; use std::time::Duration; use void::Void; @@ -388,9 +391,9 @@ fn src_try_connect_to_offline_dst() { loop { match src_swarm.select_next_some().await { - SwarmEvent::UnreachableAddr { address, peer_id, .. } - if address == dst_addr_via_relay => - { + SwarmEvent::UnreachableAddr { + address, peer_id, .. + } if address == dst_addr_via_relay => { assert_eq!(peer_id, dst_peer_id); break; } @@ -445,9 +448,9 @@ fn src_try_connect_to_unsupported_dst() { loop { match src_swarm.select_next_some().await { - SwarmEvent::UnreachableAddr { address, peer_id, .. } - if address == dst_addr_via_relay => - { + SwarmEvent::UnreachableAddr { + address, peer_id, .. + } if address == dst_addr_via_relay => { assert_eq!(peer_id, dst_peer_id); break; } @@ -495,10 +498,11 @@ fn src_try_connect_to_offline_dst_via_offline_relay() { // Source Node fail to reach Destination Node due to failure reaching Relay. match src_swarm.select_next_some().await { - SwarmEvent::UnreachableAddr { address, peer_id, .. 
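// Both relay upgrades above map a stream that ends before the expected answer into
// an UnexpectedEof error via ok_or_else. The same helper in isolation (the function
// name is invented for the example):
use futures::prelude::*;
use std::io;

async fn next_or_eof<S, T>(stream: &mut S) -> io::Result<T>
where
    S: Stream<Item = T> + Unpin,
{
    stream
        .next()
        .await
        .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "stream closed early"))
}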
} - if address == dst_addr_via_relay => { - assert_eq!(peer_id, dst_peer_id); - } + SwarmEvent::UnreachableAddr { + address, peer_id, .. + } if address == dst_addr_via_relay => { + assert_eq!(peer_id, dst_peer_id); + } e => panic!("{:?}", e), } }); @@ -582,11 +586,13 @@ fn firewalled_src_discover_firewalled_dst_via_kad_and_connect_to_dst_via_routabl let query_id = dst_swarm.behaviour_mut().kad.bootstrap().unwrap(); loop { match dst_swarm.select_next_some().await { - SwarmEvent::Behaviour(CombinedEvent::Kad(KademliaEvent::OutboundQueryCompleted { - id, - result: QueryResult::Bootstrap(Ok(_)), - .. - })) if query_id == id => { + SwarmEvent::Behaviour(CombinedEvent::Kad( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::Bootstrap(Ok(_)), + .. + }, + )) if query_id == id => { if dst_swarm.behaviour_mut().kad.iter_queries().count() == 0 { break; } @@ -660,11 +666,13 @@ fn firewalled_src_discover_firewalled_dst_via_kad_and_connect_to_dst_via_routabl SwarmEvent::Dialing(peer_id) if peer_id == relay_peer_id || peer_id == dst_peer_id => {} SwarmEvent::Behaviour(CombinedEvent::Ping(_)) => {} - SwarmEvent::Behaviour(CombinedEvent::Kad(KademliaEvent::OutboundQueryCompleted { - id, - result: QueryResult::GetClosestPeers(Ok(GetClosestPeersOk { .. })), - .. - })) if id == query_id => { + SwarmEvent::Behaviour(CombinedEvent::Kad( + KademliaEvent::OutboundQueryCompleted { + id, + result: QueryResult::GetClosestPeers(Ok(GetClosestPeersOk { .. })), + .. + }, + )) if id == query_id => { tries += 1; if tries > 300 { panic!("Too many retries."); @@ -929,8 +937,12 @@ fn yield_incoming_connection_through_correct_listener() { relay_3_swarm.listen_on(relay_3_addr.clone()).unwrap(); spawn_swarm_on_pool(&pool, relay_3_swarm); - let dst_listener_via_relay_1 = dst_swarm.listen_on(relay_1_addr_incl_circuit.clone()).unwrap(); - let dst_listener_via_relay_2 = dst_swarm.listen_on(relay_2_addr_incl_circuit.clone()).unwrap(); + let dst_listener_via_relay_1 = dst_swarm + .listen_on(relay_1_addr_incl_circuit.clone()) + .unwrap(); + let dst_listener_via_relay_2 = dst_swarm + .listen_on(relay_2_addr_incl_circuit.clone()) + .unwrap(); // Listen on own address in order for relay 3 to be able to connect to destination node. let dst_listener = dst_swarm.listen_on(dst_addr.clone()).unwrap(); @@ -952,11 +964,15 @@ fn yield_incoming_connection_through_correct_listener() { SwarmEvent::NewListenAddr { address, listener_id, - } if listener_id == dst_listener_via_relay_2 => assert_eq!(address, relay_2_addr_incl_circuit), + } if listener_id == dst_listener_via_relay_2 => { + assert_eq!(address, relay_2_addr_incl_circuit) + } SwarmEvent::NewListenAddr { address, listener_id, - } if listener_id == dst_listener_via_relay_1 => assert_eq!(address, relay_1_addr_incl_circuit), + } if listener_id == dst_listener_via_relay_1 => { + assert_eq!(address, relay_1_addr_incl_circuit) + } SwarmEvent::NewListenAddr { address, listener_id, @@ -1077,7 +1093,11 @@ fn yield_incoming_connection_through_correct_listener() { pool.run_until(async { loop { match dst_swarm.select_next_some().await { - SwarmEvent::NewListenAddr { address, .. } if address == Protocol::P2pCircuit.into() => break, + SwarmEvent::NewListenAddr { address, .. } + if address == Protocol::P2pCircuit.into() => + { + break + } SwarmEvent::Behaviour(CombinedEvent::Ping(_)) => {} SwarmEvent::Behaviour(CombinedEvent::Kad(KademliaEvent::RoutingUpdated { .. 
@@ -1325,7 +1345,11 @@ fn build_keep_alive_only_swarm() -> Swarm { .multiplex(libp2p_yamux::YamuxConfig::default()) .boxed(); - Swarm::new(transport, DummyBehaviour::with_keep_alive(KeepAlive::Yes), local_peer_id) + Swarm::new( + transport, + DummyBehaviour::with_keep_alive(KeepAlive::Yes), + local_peer_id, + ) } fn spawn_swarm_on_pool(pool: &LocalPool, mut swarm: Swarm) { diff --git a/protocols/request-response/src/codec.rs b/protocols/request-response/src/codec.rs index bbb708081dc..5345d200843 100644 --- a/protocols/request-response/src/codec.rs +++ b/protocols/request-response/src/codec.rs @@ -38,30 +38,43 @@ pub trait RequestResponseCodec { /// Reads a request from the given I/O stream according to the /// negotiated protocol. - async fn read_request(&mut self, protocol: &Self::Protocol, io: &mut T) - -> io::Result + async fn read_request( + &mut self, + protocol: &Self::Protocol, + io: &mut T, + ) -> io::Result where T: AsyncRead + Unpin + Send; /// Reads a response from the given I/O stream according to the /// negotiated protocol. - async fn read_response(&mut self, protocol: &Self::Protocol, io: &mut T) - -> io::Result + async fn read_response( + &mut self, + protocol: &Self::Protocol, + io: &mut T, + ) -> io::Result where T: AsyncRead + Unpin + Send; /// Writes a request to the given I/O stream according to the /// negotiated protocol. - async fn write_request(&mut self, protocol: &Self::Protocol, io: &mut T, req: Self::Request) - -> io::Result<()> + async fn write_request( + &mut self, + protocol: &Self::Protocol, + io: &mut T, + req: Self::Request, + ) -> io::Result<()> where T: AsyncWrite + Unpin + Send; /// Writes a response to the given I/O stream according to the /// negotiated protocol. - async fn write_response(&mut self, protocol: &Self::Protocol, io: &mut T, res: Self::Response) - -> io::Result<()> + async fn write_response( + &mut self, + protocol: &Self::Protocol, + io: &mut T, + res: Self::Response, + ) -> io::Result<()> where T: AsyncWrite + Unpin + Send; } - diff --git a/protocols/request-response/src/handler.rs b/protocols/request-response/src/handler.rs index ddb9f042dd4..ee2550df183 100644 --- a/protocols/request-response/src/handler.rs +++ b/protocols/request-response/src/handler.rs @@ -20,37 +20,29 @@ mod protocol; -use crate::{EMPTY_QUEUE_SHRINK_THRESHOLD, RequestId}; use crate::codec::RequestResponseCodec; +use crate::{RequestId, EMPTY_QUEUE_SHRINK_THRESHOLD}; -pub use protocol::{RequestProtocol, ResponseProtocol, ProtocolSupport}; +pub use protocol::{ProtocolSupport, RequestProtocol, ResponseProtocol}; -use futures::{ - channel::oneshot, - future::BoxFuture, - prelude::*, - stream::FuturesUnordered -}; -use libp2p_core::{ - upgrade::{UpgradeError, NegotiationError}, -}; +use futures::{channel::oneshot, future::BoxFuture, prelude::*, stream::FuturesUnordered}; +use libp2p_core::upgrade::{NegotiationError, UpgradeError}; use libp2p_swarm::{ - SubstreamProtocol, protocols_handler::{ - KeepAlive, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr, - } + KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, + }, + SubstreamProtocol, }; use smallvec::SmallVec; use std::{ collections::VecDeque, - fmt, - io, - sync::{atomic::{AtomicU64, Ordering}, Arc}, + fmt, io, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + task::{Context, Poll}, time::Duration, - task::{Context, Poll} }; use wasm_timer::Instant; @@ -79,12 +71,19 @@ where /// Outbound upgrades waiting to be emitted as an `OutboundSubstreamRequest`. 
outbound: VecDeque>, /// Inbound upgrades waiting for the incoming request. - inbound: FuturesUnordered), - oneshot::Canceled - >>>, - inbound_request_id: Arc + inbound: FuturesUnordered< + BoxFuture< + 'static, + Result< + ( + (RequestId, TCodec::Request), + oneshot::Sender, + ), + oneshot::Canceled, + >, + >, + >, + inbound_request_id: Arc, } impl RequestResponseHandler @@ -96,7 +95,7 @@ where codec: TCodec, keep_alive_timeout: Duration, substream_timeout: Duration, - inbound_request_id: Arc + inbound_request_id: Arc, ) -> Self { Self { inbound_protocols, @@ -108,7 +107,7 @@ where inbound: FuturesUnordered::new(), pending_events: VecDeque::new(), pending_error: None, - inbound_request_id + inbound_request_id, } } } @@ -117,18 +116,18 @@ where #[doc(hidden)] pub enum RequestResponseHandlerEvent where - TCodec: RequestResponseCodec + TCodec: RequestResponseCodec, { /// A request has been received. Request { request_id: RequestId, request: TCodec::Request, - sender: oneshot::Sender + sender: oneshot::Sender, }, /// A response has been received. Response { request_id: RequestId, - response: TCodec::Response + response: TCodec::Response, }, /// A response to an inbound request has been sent. ResponseSent(RequestId), @@ -150,28 +149,43 @@ where impl fmt::Debug for RequestResponseHandlerEvent { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - RequestResponseHandlerEvent::Request { request_id, request: _, sender: _ } => f.debug_struct("RequestResponseHandlerEvent::Request") + RequestResponseHandlerEvent::Request { + request_id, + request: _, + sender: _, + } => f + .debug_struct("RequestResponseHandlerEvent::Request") .field("request_id", request_id) .finish(), - RequestResponseHandlerEvent::Response { request_id, response: _ } => f.debug_struct("RequestResponseHandlerEvent::Response") + RequestResponseHandlerEvent::Response { + request_id, + response: _, + } => f + .debug_struct("RequestResponseHandlerEvent::Response") .field("request_id", request_id) .finish(), - RequestResponseHandlerEvent::ResponseSent(request_id) => f.debug_tuple("RequestResponseHandlerEvent::ResponseSent") + RequestResponseHandlerEvent::ResponseSent(request_id) => f + .debug_tuple("RequestResponseHandlerEvent::ResponseSent") .field(request_id) .finish(), - RequestResponseHandlerEvent::ResponseOmission(request_id) => f.debug_tuple("RequestResponseHandlerEvent::ResponseOmission") + RequestResponseHandlerEvent::ResponseOmission(request_id) => f + .debug_tuple("RequestResponseHandlerEvent::ResponseOmission") .field(request_id) .finish(), - RequestResponseHandlerEvent::OutboundTimeout(request_id) => f.debug_tuple("RequestResponseHandlerEvent::OutboundTimeout") + RequestResponseHandlerEvent::OutboundTimeout(request_id) => f + .debug_tuple("RequestResponseHandlerEvent::OutboundTimeout") .field(request_id) .finish(), - RequestResponseHandlerEvent::OutboundUnsupportedProtocols(request_id) => f.debug_tuple("RequestResponseHandlerEvent::OutboundUnsupportedProtocols") + RequestResponseHandlerEvent::OutboundUnsupportedProtocols(request_id) => f + .debug_tuple("RequestResponseHandlerEvent::OutboundUnsupportedProtocols") .field(request_id) .finish(), - RequestResponseHandlerEvent::InboundTimeout(request_id) => f.debug_tuple("RequestResponseHandlerEvent::InboundTimeout") + RequestResponseHandlerEvent::InboundTimeout(request_id) => f + .debug_tuple("RequestResponseHandlerEvent::InboundTimeout") .field(request_id) .finish(), - RequestResponseHandlerEvent::InboundUnsupportedProtocols(request_id) => 
f.debug_tuple("RequestResponseHandlerEvent::InboundUnsupportedProtocols") + RequestResponseHandlerEvent::InboundUnsupportedProtocols(request_id) => f + .debug_tuple("RequestResponseHandlerEvent::InboundUnsupportedProtocols") .field(request_id) .finish(), } @@ -212,28 +226,25 @@ where codec: self.codec.clone(), request_sender: rq_send, response_receiver: rs_recv, - request_id + request_id, }; // The handler waits for the request to come in. It then emits // `RequestResponseHandlerEvent::Request` together with a // `ResponseChannel`. - self.inbound.push(rq_recv.map_ok(move |rq| (rq, rs_send)).boxed()); + self.inbound + .push(rq_recv.map_ok(move |rq| (rq, rs_send)).boxed()); SubstreamProtocol::new(proto, request_id).with_timeout(self.substream_timeout) } - fn inject_fully_negotiated_inbound( - &mut self, - sent: bool, - request_id: RequestId - ) { + fn inject_fully_negotiated_inbound(&mut self, sent: bool, request_id: RequestId) { if sent { - self.pending_events.push_back( - RequestResponseHandlerEvent::ResponseSent(request_id)) + self.pending_events + .push_back(RequestResponseHandlerEvent::ResponseSent(request_id)) } else { - self.pending_events.push_back( - RequestResponseHandlerEvent::ResponseOmission(request_id)) + self.pending_events + .push_back(RequestResponseHandlerEvent::ResponseOmission(request_id)) } } @@ -242,9 +253,10 @@ where response: TCodec::Response, request_id: RequestId, ) { - self.pending_events.push_back( - RequestResponseHandlerEvent::Response { - request_id, response + self.pending_events + .push_back(RequestResponseHandlerEvent::Response { + request_id, + response, }); } @@ -260,8 +272,8 @@ where ) { match error { ProtocolsHandlerUpgrErr::Timeout => { - self.pending_events.push_back( - RequestResponseHandlerEvent::OutboundTimeout(info)); + self.pending_events + .push_back(RequestResponseHandlerEvent::OutboundTimeout(info)); } ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { // The remote merely doesn't support the protocol(s) we requested. @@ -270,7 +282,8 @@ where // An event is reported to permit user code to react to the fact that // the remote peer does not support the requested protocol(s). self.pending_events.push_back( - RequestResponseHandlerEvent::OutboundUnsupportedProtocols(info)); + RequestResponseHandlerEvent::OutboundUnsupportedProtocols(info), + ); } _ => { // Anything else is considered a fatal error or misbehaviour of @@ -283,12 +296,12 @@ where fn inject_listen_upgrade_error( &mut self, info: RequestId, - error: ProtocolsHandlerUpgrErr + error: ProtocolsHandlerUpgrErr, ) { match error { - ProtocolsHandlerUpgrErr::Timeout => { - self.pending_events.push_back(RequestResponseHandlerEvent::InboundTimeout(info)) - } + ProtocolsHandlerUpgrErr::Timeout => self + .pending_events + .push_back(RequestResponseHandlerEvent::InboundTimeout(info)), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { // The local peer merely doesn't support the protocol(s) requested. // This is no reason to close the connection, which may @@ -296,7 +309,8 @@ where // An event is reported to permit user code to react to the fact that // the local peer does not support the requested protocol(s). 
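// listen_protocol above wires two oneshot channels into the inbound upgrade and
// parks the request receiver in a FuturesUnordered until the request arrives (or the
// sender is dropped, which surfaces as oneshot::Canceled and is treated as a
// non-fatal inbound error). The channel/queue mechanics in isolation, with the
// payload simplified to a u32 and all names invented for the sketch:
use futures::{channel::oneshot, prelude::*, stream::FuturesUnordered};

fn inbound_queue_demo() {
    let mut inbound: FuturesUnordered<oneshot::Receiver<u32>> = FuturesUnordered::new();

    let (request_tx, request_rx) = oneshot::channel();
    inbound.push(request_rx);

    // The upgrade side would call request_tx.send(..) once it has read the request.
    request_tx.send(42).unwrap();

    futures::executor::block_on(async {
        assert!(matches!(inbound.next().await, Some(Ok(42))));
    });
}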
self.pending_events.push_back( - RequestResponseHandlerEvent::InboundUnsupportedProtocols(info)); + RequestResponseHandlerEvent::InboundUnsupportedProtocols(info), + ); } _ => { // Anything else is considered a fatal error or misbehaviour of @@ -313,18 +327,17 @@ where fn poll( &mut self, cx: &mut Context<'_>, - ) -> Poll< - ProtocolsHandlerEvent, RequestId, Self::OutEvent, Self::Error>, - > { + ) -> Poll, RequestId, Self::OutEvent, Self::Error>> + { // Check for a pending (fatal) error. if let Some(err) = self.pending_error.take() { // The handler will not be polled again by the `Swarm`. - return Poll::Ready(ProtocolsHandlerEvent::Close(err)) + return Poll::Ready(ProtocolsHandlerEvent::Close(err)); } // Drain pending events. if let Some(event) = self.pending_events.pop_front() { - return Poll::Ready(ProtocolsHandlerEvent::Custom(event)) + return Poll::Ready(ProtocolsHandlerEvent::Custom(event)); } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { self.pending_events.shrink_to_fit(); } @@ -337,8 +350,11 @@ where self.keep_alive = KeepAlive::Yes; return Poll::Ready(ProtocolsHandlerEvent::Custom( RequestResponseHandlerEvent::Request { - request_id: id, request: rq, sender: rs_sender - })) + request_id: id, + request: rq, + sender: rs_sender, + }, + )); } Err(oneshot::Canceled) => { // The inbound upgrade has errored or timed out reading @@ -351,12 +367,10 @@ where // Emit outbound requests. if let Some(request) = self.outbound.pop_front() { let info = request.request_id; - return Poll::Ready( - ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(request, info) - .with_timeout(self.substream_timeout) - }, - ) + return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(request, info) + .with_timeout(self.substream_timeout), + }); } debug_assert!(self.outbound.is_empty()); diff --git a/protocols/request-response/src/handler/protocol.rs b/protocols/request-response/src/handler/protocol.rs index cede827df27..dda4ee00d2d 100644 --- a/protocols/request-response/src/handler/protocol.rs +++ b/protocols/request-response/src/handler/protocol.rs @@ -23,8 +23,8 @@ //! receives a request and sends a response, whereas the //! outbound upgrade send a request and receives a response. -use crate::RequestId; use crate::codec::RequestResponseCodec; +use crate::RequestId; use futures::{channel::oneshot, future::BoxFuture, prelude::*}; use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo}; @@ -40,7 +40,7 @@ pub enum ProtocolSupport { /// The protocol is only supported for outbound requests. Outbound, /// The protocol is supported for inbound and outbound requests. 
- Full + Full, } impl ProtocolSupport { @@ -67,19 +67,18 @@ impl ProtocolSupport { #[derive(Debug)] pub struct ResponseProtocol where - TCodec: RequestResponseCodec + TCodec: RequestResponseCodec, { pub(crate) codec: TCodec, pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>, pub(crate) request_sender: oneshot::Sender<(RequestId, TCodec::Request)>, pub(crate) response_receiver: oneshot::Receiver, - pub(crate) request_id: RequestId - + pub(crate) request_id: RequestId, } impl UpgradeInfo for ResponseProtocol where - TCodec: RequestResponseCodec + TCodec: RequestResponseCodec, { type Info = TCodec::Protocol; type InfoIter = smallvec::IntoIter<[Self::Info; 2]>; @@ -97,7 +96,11 @@ where type Error = io::Error; type Future = BoxFuture<'static, Result>; - fn upgrade_inbound(mut self, mut io: NegotiatedSubstream, protocol: Self::Info) -> Self::Future { + fn upgrade_inbound( + mut self, + mut io: NegotiatedSubstream, + protocol: Self::Info, + ) -> Self::Future { async move { let read = self.codec.read_request(&protocol, &mut io); let request = read.await?; @@ -129,7 +132,7 @@ where /// Sends a request and receives a response. pub struct RequestProtocol where - TCodec: RequestResponseCodec + TCodec: RequestResponseCodec, { pub(crate) codec: TCodec, pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>, @@ -150,7 +153,7 @@ where impl UpgradeInfo for RequestProtocol where - TCodec: RequestResponseCodec + TCodec: RequestResponseCodec, { type Info = TCodec::Protocol; type InfoIter = smallvec::IntoIter<[Self::Info; 2]>; @@ -168,7 +171,11 @@ where type Error = io::Error; type Future = BoxFuture<'static, Result>; - fn upgrade_outbound(mut self, mut io: NegotiatedSubstream, protocol: Self::Info) -> Self::Future { + fn upgrade_outbound( + mut self, + mut io: NegotiatedSubstream, + protocol: Self::Info, + ) -> Self::Future { async move { let write = self.codec.write_request(&protocol, &mut io, self.request); write.await?; @@ -176,6 +183,7 @@ where let read = self.codec.read_response(&protocol, &mut io); let response = read.await?; Ok(response) - }.boxed() + } + .boxed() } } diff --git a/protocols/request-response/src/lib.rs b/protocols/request-response/src/lib.rs index 7e5fd58c5c1..a2277e4c8df 100644 --- a/protocols/request-response/src/lib.rs +++ b/protocols/request-response/src/lib.rs @@ -60,38 +60,23 @@ pub mod codec; pub mod handler; pub mod throttled; -pub use codec::{RequestResponseCodec, ProtocolName}; +pub use codec::{ProtocolName, RequestResponseCodec}; pub use handler::ProtocolSupport; pub use throttled::Throttled; -use futures::{ - channel::oneshot, -}; -use handler::{ - RequestProtocol, - RequestResponseHandler, - RequestResponseHandlerEvent, -}; -use libp2p_core::{ - ConnectedPoint, - Multiaddr, - PeerId, - connection::ConnectionId, -}; +use futures::channel::oneshot; +use handler::{RequestProtocol, RequestResponseHandler, RequestResponseHandlerEvent}; +use libp2p_core::{connection::ConnectionId, ConnectedPoint, Multiaddr, PeerId}; use libp2p_swarm::{ - DialPeerCondition, - NetworkBehaviour, - NetworkBehaviourAction, - NotifyHandler, - PollParameters, + DialPeerCondition, NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, PollParameters, }; use smallvec::SmallVec; use std::{ collections::{HashMap, HashSet, VecDeque}, fmt, - time::Duration, sync::{atomic::AtomicU64, Arc}, - task::{Context, Poll} + task::{Context, Poll}, + time::Duration, }; /// An inbound request or response. @@ -117,7 +102,7 @@ pub enum RequestResponseMessage /// The peer who sent the message. 
peer: PeerId, /// The incoming message. - message: RequestResponseMessage + message: RequestResponseMessage, }, /// An outbound request failed. OutboundFailure { @@ -186,8 +171,12 @@ impl fmt::Display for OutboundFailure { match self { OutboundFailure::DialFailure => write!(f, "Failed to dial the requested peer"), OutboundFailure::Timeout => write!(f, "Timeout while waiting for a response"), - OutboundFailure::ConnectionClosed => write!(f, "Connection was closed before a response was received"), - OutboundFailure::UnsupportedProtocols => write!(f, "The remote supports none of the requested protocols") + OutboundFailure::ConnectionClosed => { + write!(f, "Connection was closed before a response was received") + } + OutboundFailure::UnsupportedProtocols => { + write!(f, "The remote supports none of the requested protocols") + } } } } @@ -217,10 +206,20 @@ pub enum InboundFailure { impl fmt::Display for InboundFailure { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - InboundFailure::Timeout => write!(f, "Timeout while receiving request or sending response"), - InboundFailure::ConnectionClosed => write!(f, "Connection was closed before a response could be sent"), - InboundFailure::UnsupportedProtocols => write!(f, "The local peer supports none of the protocols requested by the remote"), - InboundFailure::ResponseOmission => write!(f, "The response channel was dropped without sending a response to the remote") + InboundFailure::Timeout => { + write!(f, "Timeout while receiving request or sending response") + } + InboundFailure::ConnectionClosed => { + write!(f, "Connection was closed before a response could be sent") + } + InboundFailure::UnsupportedProtocols => write!( + f, + "The local peer supports none of the protocols requested by the remote" + ), + InboundFailure::ResponseOmission => write!( + f, + "The response channel was dropped without sending a response to the remote" + ), } } } @@ -322,7 +321,9 @@ where pending_events: VecDeque< NetworkBehaviourAction< RequestProtocol, - RequestResponseEvent>>, + RequestResponseEvent, + >, + >, /// The currently connected peers, their pending outbound and inbound responses and their known, /// reachable addresses, if any. connected: HashMap>, @@ -341,7 +342,7 @@ where /// protocols, codec and configuration. pub fn new(codec: TCodec, protocols: I, cfg: RequestResponseConfig) -> Self where - I: IntoIterator + I: IntoIterator, { let mut inbound_protocols = SmallVec::new(); let mut outbound_protocols = SmallVec::new(); @@ -375,7 +376,7 @@ where where I: IntoIterator, TCodec: Send, - TCodec::Protocol: Sync + TCodec::Protocol: Sync, { Throttled::new(c, protos, cfg) } @@ -402,11 +403,15 @@ where }; if let Some(request) = self.try_send_request(peer, request) { - self.pending_events.push_back(NetworkBehaviourAction::DialPeer { - peer_id: *peer, - condition: DialPeerCondition::Disconnected, - }); - self.pending_outbound_requests.entry(*peer).or_default().push(request); + self.pending_events + .push_back(NetworkBehaviourAction::DialPeer { + peer_id: *peer, + condition: DialPeerCondition::Disconnected, + }); + self.pending_outbound_requests + .entry(*peer) + .or_default() + .push(request); } request_id @@ -423,9 +428,11 @@ where /// /// The provided `ResponseChannel` is obtained from an inbound /// [`RequestResponseMessage::Request`]. 
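// A hedged sketch of the intended round trip (`PingCodec` and `Pong` are the
// types from the integration tests further below): the `ResponseChannel` arrives
// inside a `RequestResponseMessage::Request` event and is handed back to
// `send_response`; an `Err` returns the response if the channel is already gone.
//
//     use futures::StreamExt;
//     use libp2p_swarm::{Swarm, SwarmEvent};
//
//     async fn answer_one_request(swarm: &mut Swarm<RequestResponse<PingCodec>>) {
//         if let SwarmEvent::Behaviour(RequestResponseEvent::Message {
//             message: RequestResponseMessage::Request { channel, .. },
//             ..
//         }) = swarm.select_next_some().await
//         {
//             let _ = swarm.behaviour_mut().send_response(channel, Pong(b"pong".to_vec()));
//         }
//     }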
- pub fn send_response(&mut self, ch: ResponseChannel, rs: TCodec::Response) - -> Result<(), TCodec::Response> - { + pub fn send_response( + &mut self, + ch: ResponseChannel, + rs: TCodec::Response, + ) -> Result<(), TCodec::Response> { ch.sender.send(rs) } @@ -464,12 +471,19 @@ where /// pending, i.e. waiting for a response. pub fn is_pending_outbound(&self, peer: &PeerId, request_id: &RequestId) -> bool { // Check if request is already sent on established connection. - let est_conn = self.connected.get(peer) - .map(|cs| cs.iter().any(|c| c.pending_inbound_responses.contains(request_id))) + let est_conn = self + .connected + .get(peer) + .map(|cs| { + cs.iter() + .any(|c| c.pending_inbound_responses.contains(request_id)) + }) .unwrap_or(false); // Check if request is still pending to be sent. - let pen_conn = self.pending_outbound_requests.get(peer) - .map(|rps| rps.iter().any(|rp| {rp.request_id == *request_id})) + let pen_conn = self + .pending_outbound_requests + .get(peer) + .map(|rps| rps.iter().any(|rp| rp.request_id == *request_id)) .unwrap_or(false); est_conn || pen_conn @@ -479,8 +493,12 @@ where /// [`PeerId`] is still pending, i.e. waiting for a response by the local /// node through [`RequestResponse::send_response`]. pub fn is_pending_inbound(&self, peer: &PeerId, request_id: &RequestId) -> bool { - self.connected.get(peer) - .map(|cs| cs.iter().any(|c| c.pending_outbound_responses.contains(request_id))) + self.connected + .get(peer) + .map(|cs| { + cs.iter() + .any(|c| c.pending_outbound_responses.contains(request_id)) + }) .unwrap_or(false) } @@ -494,21 +512,24 @@ where /// Tries to send a request by queueing an appropriate event to be /// emitted to the `Swarm`. If the peer is not currently connected, /// the given request is return unchanged. - fn try_send_request(&mut self, peer: &PeerId, request: RequestProtocol) - -> Option> - { + fn try_send_request( + &mut self, + peer: &PeerId, + request: RequestProtocol, + ) -> Option> { if let Some(connections) = self.connected.get_mut(peer) { if connections.is_empty() { - return Some(request) + return Some(request); } let ix = (request.request_id.0 as usize) % connections.len(); let conn = &mut connections[ix]; conn.pending_inbound_responses.insert(request.request_id); - self.pending_events.push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: *peer, - handler: NotifyHandler::One(conn.id), - event: request - }); + self.pending_events + .push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: *peer, + handler: NotifyHandler::One(conn.id), + event: request, + }); None } else { Some(request) @@ -554,9 +575,9 @@ where peer: &PeerId, connection: ConnectionId, ) -> Option<&mut Connection> { - self.connected.get_mut(peer).and_then(|connections| { - connections.iter_mut().find(|c| c.id == connection) - }) + self.connected + .get_mut(peer) + .and_then(|connections| connections.iter_mut().find(|c| c.id == connection)) } } @@ -573,7 +594,7 @@ where self.codec.clone(), self.config.connection_keep_alive, self.config.request_timeout, - self.next_inbound_id.clone() + self.next_inbound_id.clone(), ) } @@ -597,21 +618,35 @@ where } } - fn inject_connection_established(&mut self, peer: &PeerId, conn: &ConnectionId, endpoint: &ConnectedPoint) { + fn inject_connection_established( + &mut self, + peer: &PeerId, + conn: &ConnectionId, + endpoint: &ConnectedPoint, + ) { let address = match endpoint { ConnectedPoint::Dialer { address } => Some(address.clone()), - ConnectedPoint::Listener { .. } => None + ConnectedPoint::Listener { .. 
} => None, }; - self.connected.entry(*peer) + self.connected + .entry(*peer) .or_default() .push(Connection::new(*conn, address)); } - fn inject_connection_closed(&mut self, peer_id: &PeerId, conn: &ConnectionId, _: &ConnectedPoint) { - let connections = self.connected.get_mut(peer_id) + fn inject_connection_closed( + &mut self, + peer_id: &PeerId, + conn: &ConnectionId, + _: &ConnectedPoint, + ) { + let connections = self + .connected + .get_mut(peer_id) .expect("Expected some established connection to peer before closing."); - let connection = connections.iter() + let connection = connections + .iter() .position(|c| &c.id == conn) .map(|p: usize| connections.remove(p)) .expect("Expected connection to be established before closing."); @@ -621,24 +656,25 @@ where } for request_id in connection.pending_outbound_responses { - self.pending_events.push_back(NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::InboundFailure { - peer: *peer_id, - request_id, - error: InboundFailure::ConnectionClosed - } - )); - + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::InboundFailure { + peer: *peer_id, + request_id, + error: InboundFailure::ConnectionClosed, + }, + )); } for request_id in connection.pending_inbound_responses { - self.pending_events.push_back(NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::OutboundFailure { - peer: *peer_id, - request_id, - error: OutboundFailure::ConnectionClosed - } - )); + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::OutboundFailure { + peer: *peer_id, + request_id, + error: OutboundFailure::ConnectionClosed, + }, + )); } } @@ -655,13 +691,14 @@ where // another, concurrent dialing attempt ongoing. if let Some(pending) = self.pending_outbound_requests.remove(peer) { for request in pending { - self.pending_events.push_back(NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::OutboundFailure { - peer: *peer, - request_id: request.request_id, - error: OutboundFailure::DialFailure - } - )); + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::OutboundFailure { + peer: *peer, + request_id: request.request_id, + error: OutboundFailure::DialFailure, + }, + )); } } } @@ -673,49 +710,74 @@ where event: RequestResponseHandlerEvent, ) { match event { - RequestResponseHandlerEvent::Response { request_id, response } => { + RequestResponseHandlerEvent::Response { + request_id, + response, + } => { let removed = self.remove_pending_inbound_response(&peer, connection, &request_id); debug_assert!( removed, "Expect request_id to be pending before receiving response.", ); - let message = RequestResponseMessage::Response { request_id, response }; - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::Message { peer, message })); + let message = RequestResponseMessage::Response { + request_id, + response, + }; + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::Message { peer, message }, + )); } - RequestResponseHandlerEvent::Request { request_id, request, sender } => { - let channel = ResponseChannel { request_id, peer, sender }; - let message = RequestResponseMessage::Request { request_id, request, channel }; - self.pending_events.push_back(NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::Message { peer, message } - )); + RequestResponseHandlerEvent::Request { + request_id, + request, + sender, + } => { + let 
channel = ResponseChannel { + request_id, + peer, + sender, + }; + let message = RequestResponseMessage::Request { + request_id, + request, + channel, + }; + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::Message { peer, message }, + )); match self.get_connection_mut(&peer, connection) { Some(connection) => { let inserted = connection.pending_outbound_responses.insert(request_id); debug_assert!(inserted, "Expect id of new request to be unknown."); - }, + } // Connection closed after `RequestResponseEvent::Request` has been emitted. None => { - self.pending_events.push_back(NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::InboundFailure { - peer, - request_id, - error: InboundFailure::ConnectionClosed - } - )); + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::InboundFailure { + peer, + request_id, + error: InboundFailure::ConnectionClosed, + }, + )); } } } RequestResponseHandlerEvent::ResponseSent(request_id) => { let removed = self.remove_pending_outbound_response(&peer, connection, request_id); - debug_assert!(removed, "Expect request_id to be pending before response is sent."); + debug_assert!( + removed, + "Expect request_id to be pending before response is sent." + ); - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::ResponseSent { peer, request_id })); + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::ResponseSent { peer, request_id }, + )); } RequestResponseHandlerEvent::ResponseOmission(request_id) => { let removed = self.remove_pending_outbound_response(&peer, connection, request_id); @@ -724,25 +786,30 @@ where "Expect request_id to be pending before response is omitted.", ); - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::InboundFailure { peer, request_id, - error: InboundFailure::ResponseOmission - })); + error: InboundFailure::ResponseOmission, + }, + )); } RequestResponseHandlerEvent::OutboundTimeout(request_id) => { let removed = self.remove_pending_inbound_response(&peer, connection, &request_id); - debug_assert!(removed, "Expect request_id to be pending before request times out."); + debug_assert!( + removed, + "Expect request_id to be pending before request times out." + ); - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::OutboundFailure { peer, request_id, error: OutboundFailure::Timeout, - })); + }, + )); } RequestResponseHandlerEvent::InboundTimeout(request_id) => { // Note: `RequestResponseHandlerEvent::InboundTimeout` is emitted both for timing @@ -751,13 +818,14 @@ where // not assert the request_id to be present before removing it. 
self.remove_pending_outbound_response(&peer, connection, request_id); - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::InboundFailure { peer, request_id, error: InboundFailure::Timeout, - })); + }, + )); } RequestResponseHandlerEvent::OutboundUnsupportedProtocols(request_id) => { let removed = self.remove_pending_inbound_response(&peer, connection, &request_id); @@ -766,35 +834,41 @@ where "Expect request_id to be pending before failing to connect.", ); - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::OutboundFailure { peer, request_id, error: OutboundFailure::UnsupportedProtocols, - })); + }, + )); } RequestResponseHandlerEvent::InboundUnsupportedProtocols(request_id) => { // Note: No need to call `self.remove_pending_outbound_response`, // `RequestResponseHandlerEvent::Request` was never emitted for this request and // thus request was never added to `pending_outbound_responses`. - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( + self.pending_events + .push_back(NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::InboundFailure { peer, request_id, error: InboundFailure::UnsupportedProtocols, - })); + }, + )); } } } - fn poll(&mut self, _: &mut Context<'_>, _: &mut impl PollParameters) - -> Poll, + _: &mut impl PollParameters, + ) -> Poll< + NetworkBehaviourAction< RequestProtocol, - RequestResponseEvent - >> - { + RequestResponseEvent, + >, + > { if let Some(ev) = self.pending_events.pop_front() { return Poll::Ready(ev); } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD { @@ -821,7 +895,7 @@ struct Connection { pending_outbound_responses: HashSet, /// Pending inbound responses for previously sent requests on this /// connection. 
- pending_inbound_responses: HashSet + pending_inbound_responses: HashSet, } impl Connection { diff --git a/protocols/request-response/src/throttled.rs b/protocols/request-response/src/throttled.rs index 611331f4f52..c882f41b211 100644 --- a/protocols/request-response/src/throttled.rs +++ b/protocols/request-response/src/throttled.rs @@ -36,22 +36,20 @@ mod codec; -use codec::{Codec, Message, ProtocolWrapper, Type}; +use super::{ + ProtocolSupport, RequestId, RequestResponse, RequestResponseCodec, RequestResponseConfig, + RequestResponseEvent, RequestResponseMessage, +}; use crate::handler::{RequestProtocol, RequestResponseHandler, RequestResponseHandlerEvent}; +use codec::{Codec, Message, ProtocolWrapper, Type}; use futures::ready; -use libp2p_core::{ConnectedPoint, connection::ConnectionId, Multiaddr, PeerId}; +use libp2p_core::{connection::ConnectionId, ConnectedPoint, Multiaddr, PeerId}; use libp2p_swarm::{NetworkBehaviour, NetworkBehaviourAction, PollParameters}; use lru::LruCache; -use std::{collections::{HashMap, HashSet, VecDeque}, task::{Context, Poll}}; use std::{cmp::max, num::NonZeroU16}; -use super::{ - ProtocolSupport, - RequestId, - RequestResponse, - RequestResponseCodec, - RequestResponseConfig, - RequestResponseEvent, - RequestResponseMessage, +use std::{ + collections::{HashMap, HashSet, VecDeque}, + task::{Context, Poll}, }; pub type ResponseChannel = super::ResponseChannel>; @@ -60,7 +58,7 @@ pub type ResponseChannel = super::ResponseChannel>; pub struct Throttled where C: RequestResponseCodec + Send, - C::Protocol: Sync + C::Protocol: Sync, { /// A random id used for logging. id: u32, @@ -77,7 +75,7 @@ where /// Pending events to report in `Throttled::poll`. events: VecDeque>>, /// The current credit ID. - next_grant_id: u64 + next_grant_id: u64, } /// Information about a credit grant that is sent to remote peers. @@ -89,7 +87,7 @@ struct Grant { request: RequestId, /// The credit given in this grant, i.e. the number of additional /// requests the remote is allowed to send. - credit: u16 + credit: u16, } /// Max. number of inbound requests that can be received. @@ -99,7 +97,7 @@ struct Limit { max_recv: NonZeroU16, /// The next receive limit which becomes active after /// the current limit has been reached. - next_max: NonZeroU16 + next_max: NonZeroU16, } impl Limit { @@ -111,7 +109,7 @@ impl Limit { // sender so we must not use `max` right away. Limit { max_recv: NonZeroU16::new(1).expect("1 > 0"), - next_max: max + next_max: max, } } @@ -191,7 +189,7 @@ impl PeerInfo { limit: recv_limit, remaining: 1, sent: HashSet::new(), - } + }, } } @@ -210,16 +208,18 @@ impl PeerInfo { impl Throttled where C: RequestResponseCodec + Send + Clone, - C::Protocol: Sync + C::Protocol: Sync, { /// Create a new throttled request-response behaviour. pub fn new(c: C, protos: I, cfg: RequestResponseConfig) -> Self where I: IntoIterator, C: Send, - C::Protocol: Sync + C::Protocol: Sync, { - let protos = protos.into_iter().map(|(p, ps)| (ProtocolWrapper::new(b"/t/1", p), ps)); + let protos = protos + .into_iter() + .map(|(p, ps)| (ProtocolWrapper::new(b"/t/1", p), ps)); Throttled::from(RequestResponse::new(Codec::new(c, 8192), protos, cfg)) } @@ -233,7 +233,7 @@ where default_limit: Limit::new(NonZeroU16::new(1).expect("1 > 0")), limit_overrides: HashMap::new(), events: VecDeque::new(), - next_grant_id: 0 + next_grant_id: 0, } } @@ -262,7 +262,10 @@ where /// Has the limit of outbound requests been reached for the given peer? 
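// A hedged sketch of how a caller is expected to respect the send budget:
// `send_request` hands the request back as `Err` once the remote's credit is
// exhausted, and sending may resume after observing `Event::ResumeSending` for
// that peer (as exercised by the throttled ping test further below). `PingCodec`
// and `Ping` are the test types; `try_send` is purely illustrative.
//
//     use libp2p_core::PeerId;
//     use libp2p_request_response::{RequestId, Throttled};
//
//     fn try_send(b: &mut Throttled<PingCodec>, peer: &PeerId, req: Ping) -> Option<RequestId> {
//         if !b.can_send(peer) {
//             return None; // wait for `Event::ResumeSending(*peer)` before retrying
//         }
//         b.send_request(peer, req).ok()
//     }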
pub fn can_send(&mut self, p: &PeerId) -> bool { - self.peer_info.get(p).map(|i| i.send_budget.remaining > 0).unwrap_or(true) + self.peer_info + .get(p) + .map(|i| i.send_budget.remaining > 0) + .unwrap_or(true) } /// Send a request to a peer. @@ -273,22 +276,30 @@ where pub fn send_request(&mut self, p: &PeerId, req: C::Request) -> Result { let connected = &mut self.peer_info; let disconnected = &mut self.offline_peer_info; - let remaining = - if let Some(info) = connected.get_mut(p).or_else(|| disconnected.get_mut(p)) { - if info.send_budget.remaining == 0 { - log::trace!("{:08x}: no more budget to send another request to {}", self.id, p); - return Err(req) - } - info.send_budget.remaining -= 1; - info.send_budget.remaining - } else { - let limit = self.limit_overrides.get(p).copied().unwrap_or(self.default_limit); - let mut info = PeerInfo::new(limit); - info.send_budget.remaining -= 1; - let remaining = info.send_budget.remaining; - self.offline_peer_info.put(*p, info); - remaining - }; + let remaining = if let Some(info) = connected.get_mut(p).or_else(|| disconnected.get_mut(p)) + { + if info.send_budget.remaining == 0 { + log::trace!( + "{:08x}: no more budget to send another request to {}", + self.id, + p + ); + return Err(req); + } + info.send_budget.remaining -= 1; + info.send_budget.remaining + } else { + let limit = self + .limit_overrides + .get(p) + .copied() + .unwrap_or(self.default_limit); + let mut info = PeerInfo::new(limit); + info.send_budget.remaining -= 1; + let remaining = info.send_budget.remaining; + self.offline_peer_info.put(*p, info); + remaining + }; let rid = self.behaviour.send_request(p, Message::request(req)); @@ -305,12 +316,20 @@ where /// Answer an inbound request with a response. /// /// See [`RequestResponse::send_response`] for details. - pub fn send_response(&mut self, ch: ResponseChannel, res: C::Response) - -> Result<(), C::Response> - { - log::trace!("{:08x}: sending response {} to peer {}", self.id, ch.request_id(), &ch.peer); + pub fn send_response( + &mut self, + ch: ResponseChannel, + res: C::Response, + ) -> Result<(), C::Response> { + log::trace!( + "{:08x}: sending response {} to peer {}", + self.id, + ch.request_id(), + &ch.peer + ); if let Some(info) = self.peer_info.get_mut(&ch.peer) { - if info.recv_budget.remaining == 0 { // need to send more credit to the remote peer + if info.recv_budget.remaining == 0 { + // need to send more credit to the remote peer let crd = info.recv_budget.limit.switch(); info.recv_budget.remaining = info.recv_budget.limit.max_recv.get(); self.send_credit(&ch.peer, crd); @@ -350,7 +369,6 @@ where self.behaviour.is_pending_outbound(p, r) } - /// Is the remote waiting for the local node to respond to the given /// request? /// @@ -365,8 +383,18 @@ where let cid = self.next_grant_id; self.next_grant_id += 1; let rid = self.behaviour.send_request(p, Message::credit(credit, cid)); - log::trace!("{:08x}: sending {} credit as grant {} to {}", self.id, credit, cid, p); - let grant = Grant { id: cid, request: rid, credit }; + log::trace!( + "{:08x}: sending {} credit as grant {} to {}", + self.id, + credit, + cid, + p + ); + let grant = Grant { + id: cid, + request: rid, + credit, + }; info.recv_budget.grant = Some(grant); info.recv_budget.sent.insert(rid); } @@ -383,13 +411,13 @@ pub enum Event { /// When previously reaching the send limit of a peer, /// this event is eventually emitted when sending is /// allowed to resume. 
- ResumeSending(PeerId) + ResumeSending(PeerId), } impl NetworkBehaviour for Throttled where C: RequestResponseCodec + Send + Clone + 'static, - C::Protocol: Sync + C::Protocol: Sync, { type ProtocolsHandler = RequestResponseHandler>; type OutEvent = Event>; @@ -402,7 +430,12 @@ where self.behaviour.addresses_of_peer(p) } - fn inject_connection_established(&mut self, p: &PeerId, id: &ConnectionId, end: &ConnectedPoint) { + fn inject_connection_established( + &mut self, + p: &PeerId, + id: &ConnectionId, + end: &ConnectedPoint, + ) { self.behaviour.inject_connection_established(p, id, end) } @@ -433,7 +466,11 @@ where self.send_credit(p, recv_budget - 1); } } else { - let limit = self.limit_overrides.get(p).copied().unwrap_or(self.default_limit); + let limit = self + .limit_overrides + .get(p) + .copied() + .unwrap_or(self.default_limit); self.peer_info.insert(*p, PeerInfo::new(limit)); } } @@ -451,142 +488,183 @@ where self.behaviour.inject_dial_failure(p) } - fn inject_event(&mut self, p: PeerId, i: ConnectionId, e: RequestResponseHandlerEvent>) { + fn inject_event( + &mut self, + p: PeerId, + i: ConnectionId, + e: RequestResponseHandlerEvent>, + ) { self.behaviour.inject_event(p, i, e) } - fn poll(&mut self, cx: &mut Context<'_>, params: &mut impl PollParameters) - -> Poll>, Self::OutEvent>> - { + fn poll( + &mut self, + cx: &mut Context<'_>, + params: &mut impl PollParameters, + ) -> Poll>, Self::OutEvent>> { loop { if let Some(ev) = self.events.pop_front() { - return Poll::Ready(NetworkBehaviourAction::GenerateEvent(ev)) + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(ev)); } else if self.events.capacity() > super::EMPTY_QUEUE_SHRINK_THRESHOLD { self.events.shrink_to_fit() } let event = match ready!(self.behaviour.poll(cx, params)) { - | NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::Message { peer, message }) => { + NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::Message { + peer, + message, + }) => { let message = match message { - | RequestResponseMessage::Response { request_id, response } => - match &response.header().typ { - | Some(Type::Ack) => { - if let Some(info) = self.peer_info.get_mut(&peer) { - if let Some(id) = info.recv_budget.grant.as_ref().map(|c| c.id) { - if Some(id) == response.header().ident { - log::trace!("{:08x}: received ack {} from {}", self.id, id, peer); - info.recv_budget.grant = None; - } + RequestResponseMessage::Response { + request_id, + response, + } => match &response.header().typ { + Some(Type::Ack) => { + if let Some(info) = self.peer_info.get_mut(&peer) { + if let Some(id) = info.recv_budget.grant.as_ref().map(|c| c.id) + { + if Some(id) == response.header().ident { + log::trace!( + "{:08x}: received ack {} from {}", + self.id, + id, + peer + ); + info.recv_budget.grant = None; } - info.recv_budget.sent.remove(&request_id); } - continue + info.recv_budget.sent.remove(&request_id); + } + continue; + } + Some(Type::Response) => { + log::trace!( + "{:08x}: received response {} from {}", + self.id, + request_id, + peer + ); + if let Some(rs) = response.into_parts().1 { + RequestResponseMessage::Response { + request_id, + response: rs, + } + } else { + log::error! { "{:08x}: missing data for response {} from peer {}", + self.id, + request_id, + peer + } + continue; } - | Some(Type::Response) => { - log::trace!("{:08x}: received response {} from {}", self.id, request_id, peer); - if let Some(rs) = response.into_parts().1 { - RequestResponseMessage::Response { request_id, response: rs } + } + ty => { + log::trace! 
{ + "{:08x}: unknown message type: {:?} from {}; expected response or credit", + self.id, + ty, + peer + }; + continue; + } + }, + RequestResponseMessage::Request { + request_id, + request, + channel, + } => match &request.header().typ { + Some(Type::Credit) => { + if let Some(info) = self.peer_info.get_mut(&peer) { + let id = if let Some(n) = request.header().ident { + n } else { - log::error! { "{:08x}: missing data for response {} from peer {}", + log::warn! { "{:08x}: missing credit id in message from {}", self.id, - request_id, peer } - continue - } - } - | ty => { - log::trace! { - "{:08x}: unknown message type: {:?} from {}; expected response or credit", + continue; + }; + let credit = request.header().credit.unwrap_or(0); + log::trace! { "{:08x}: received {} additional credit {} from {}", self.id, - ty, + credit, + id, peer }; - continue - } - } - | RequestResponseMessage::Request { request_id, request, channel } => - match &request.header().typ { - | Some(Type::Credit) => { - if let Some(info) = self.peer_info.get_mut(&peer) { - let id = if let Some(n) = request.header().ident { - n - } else { - log::warn! { "{:08x}: missing credit id in message from {}", + if info.send_budget.grant < Some(id) { + if info.send_budget.remaining == 0 && credit > 0 { + log::trace!( + "{:08x}: sending to peer {} can resume", self.id, peer - } - continue - }; - let credit = request.header().credit.unwrap_or(0); - log::trace! { "{:08x}: received {} additional credit {} from {}", - self.id, - credit, - id, - peer - }; - if info.send_budget.grant < Some(id) { - if info.send_budget.remaining == 0 && credit > 0 { - log::trace!("{:08x}: sending to peer {} can resume", self.id, peer); - self.events.push_back(Event::ResumeSending(peer)) - } - info.send_budget.remaining += credit; - info.send_budget.grant = Some(id); + ); + self.events.push_back(Event::ResumeSending(peer)) } - // Note: Failing to send a response to a credit grant is - // handled along with other inbound failures further below. - let _ = self.behaviour.send_response(channel, Message::ack(id)); - info.send_budget.received.insert(request_id); + info.send_budget.remaining += credit; + info.send_budget.grant = Some(id); } - continue + // Note: Failing to send a response to a credit grant is + // handled along with other inbound failures further below. + let _ = self.behaviour.send_response(channel, Message::ack(id)); + info.send_budget.received.insert(request_id); } - | Some(Type::Request) => { - if let Some(info) = self.peer_info.get_mut(&peer) { - log::trace! { "{:08x}: received request {} (recv. budget = {})", - self.id, - request_id, - info.recv_budget.remaining - }; - if info.recv_budget.remaining == 0 { - log::debug!("{:08x}: peer {} exceeds its budget", self.id, peer); - self.events.push_back(Event::TooManyInboundRequests(peer)); - continue - } - info.recv_budget.remaining -= 1; - // We consider a request as proof that our credit grant has - // reached the peer. Usually, an ACK has already been - // received. - info.recv_budget.grant = None; - } - if let Some(rq) = request.into_parts().1 { - RequestResponseMessage::Request { request_id, request: rq, channel } - } else { - log::error! { "{:08x}: missing data for request {} from peer {}", + continue; + } + Some(Type::Request) => { + if let Some(info) = self.peer_info.get_mut(&peer) { + log::trace! { "{:08x}: received request {} (recv. 
budget = {})", + self.id, + request_id, + info.recv_budget.remaining + }; + if info.recv_budget.remaining == 0 { + log::debug!( + "{:08x}: peer {} exceeds its budget", self.id, - request_id, peer - } - continue + ); + self.events.push_back(Event::TooManyInboundRequests(peer)); + continue; } + info.recv_budget.remaining -= 1; + // We consider a request as proof that our credit grant has + // reached the peer. Usually, an ACK has already been + // received. + info.recv_budget.grant = None; } - | ty => { - log::trace! { - "{:08x}: unknown message type: {:?} from {}; expected request or ack", + if let Some(rq) = request.into_parts().1 { + RequestResponseMessage::Request { + request_id, + request: rq, + channel, + } + } else { + log::error! { "{:08x}: missing data for request {} from peer {}", self.id, - ty, + request_id, peer - }; - continue + } + continue; } } + ty => { + log::trace! { + "{:08x}: unknown message type: {:?} from {}; expected request or ack", + self.id, + ty, + peer + }; + continue; + } + }, }; let event = RequestResponseEvent::Message { peer, message }; NetworkBehaviourAction::GenerateEvent(Event::Event(event)) } - | NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::OutboundFailure { + NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::OutboundFailure { peer, request_id, - error + error, }) => { if let Some(info) = self.peer_info.get_mut(&peer) { if let Some(grant) = info.recv_budget.grant.as_mut() { @@ -606,16 +684,20 @@ where // If the outbound failure was for a credit message, don't report it on // the public API and retry the sending. if info.recv_budget.sent.remove(&request_id) { - continue + continue; } } - let event = RequestResponseEvent::OutboundFailure { peer, request_id, error }; + let event = RequestResponseEvent::OutboundFailure { + peer, + request_id, + error, + }; NetworkBehaviourAction::GenerateEvent(Event::Event(event)) } - | NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::InboundFailure { + NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::InboundFailure { peer, request_id, - error + error, }) => { // If the inbound failure occurred in the context of responding to a // credit grant, don't report it on the public API. @@ -625,15 +707,19 @@ where "{:08}: failed to acknowledge credit grant from {}: {:?}", self.id, peer, error }; - continue + continue; } } - let event = RequestResponseEvent::InboundFailure { peer, request_id, error }; + let event = RequestResponseEvent::InboundFailure { + peer, + request_id, + error, + }; NetworkBehaviourAction::GenerateEvent(Event::Event(event)) } - | NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::ResponseSent { + NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::ResponseSent { peer, - request_id + request_id, }) => { // If this event is for an ACK response that was sent for // the last received credit grant, skip it. 
@@ -644,25 +730,41 @@ where self.id, info.send_budget.grant, } - continue + continue; } } NetworkBehaviourAction::GenerateEvent(Event::Event( - RequestResponseEvent::ResponseSent { peer, request_id })) + RequestResponseEvent::ResponseSent { peer, request_id }, + )) + } + NetworkBehaviourAction::DialAddress { address } => { + NetworkBehaviourAction::DialAddress { address } + } + NetworkBehaviourAction::DialPeer { peer_id, condition } => { + NetworkBehaviourAction::DialPeer { peer_id, condition } + } + NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event, + } => NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event, + }, + NetworkBehaviourAction::ReportObservedAddr { address, score } => { + NetworkBehaviourAction::ReportObservedAddr { address, score } } - | NetworkBehaviourAction::DialAddress { address } => - NetworkBehaviourAction::DialAddress { address }, - | NetworkBehaviourAction::DialPeer { peer_id, condition } => - NetworkBehaviourAction::DialPeer { peer_id, condition }, - | NetworkBehaviourAction::NotifyHandler { peer_id, handler, event } => - NetworkBehaviourAction::NotifyHandler { peer_id, handler, event }, - | NetworkBehaviourAction::ReportObservedAddr { address, score } => - NetworkBehaviourAction::ReportObservedAddr { address, score }, - | NetworkBehaviourAction::CloseConnection { peer_id, connection } => - NetworkBehaviourAction::CloseConnection { peer_id, connection } + NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + } => NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + }, }; - return Poll::Ready(event) + return Poll::Ready(event); } } } diff --git a/protocols/request-response/src/throttled/codec.rs b/protocols/request-response/src/throttled/codec.rs index 580fdd3da85..f82c4ae3961 100644 --- a/protocols/request-response/src/throttled/codec.rs +++ b/protocols/request-response/src/throttled/codec.rs @@ -18,13 +18,13 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use super::RequestResponseCodec; use async_trait::async_trait; use bytes::{Bytes, BytesMut}; use futures::prelude::*; use libp2p_core::ProtocolName; -use minicbor::{Encode, Decode}; +use minicbor::{Decode, Encode}; use std::io; -use super::RequestResponseCodec; use unsigned_varint::{aio, io::ReadError}; /// A protocol header. @@ -32,27 +32,34 @@ use unsigned_varint::{aio, io::ReadError}; #[cbor(map)] pub struct Header { /// The type of message. - #[n(0)] pub typ: Option, + #[n(0)] + pub typ: Option, /// The number of additional requests the remote is willing to receive. - #[n(1)] pub credit: Option, + #[n(1)] + pub credit: Option, /// An identifier used for sending credit grants. - #[n(2)] pub ident: Option + #[n(2)] + pub ident: Option, } /// A protocol message type. #[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)] pub enum Type { - #[n(0)] Request, - #[n(1)] Response, - #[n(2)] Credit, - #[n(3)] Ack + #[n(0)] + Request, + #[n(1)] + Response, + #[n(2)] + Credit, + #[n(3)] + Ack, } /// A protocol message consisting of header and data. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Message { header: Header, - data: Option + data: Option, } impl Message { @@ -63,26 +70,40 @@ impl Message { /// Create a request message. pub fn request(data: T) -> Self { - let mut m = Message::new(Header { typ: Some(Type::Request), .. Header::default() }); + let mut m = Message::new(Header { + typ: Some(Type::Request), + ..Header::default() + }); m.data = Some(data); m } /// Create a response message. 
pub fn response(data: T) -> Self { - let mut m = Message::new(Header { typ: Some(Type::Response), .. Header::default() }); + let mut m = Message::new(Header { + typ: Some(Type::Response), + ..Header::default() + }); m.data = Some(data); m } /// Create a credit grant. pub fn credit(credit: u16, ident: u64) -> Self { - Message::new(Header { typ: Some(Type::Credit), credit: Some(credit), ident: Some(ident) }) + Message::new(Header { + typ: Some(Type::Credit), + credit: Some(credit), + ident: Some(ident), + }) } /// Create an acknowledge message. pub fn ack(ident: u64) -> Self { - Message::new(Header { typ: Some(Type::Ack), credit: None, ident: Some(ident) }) + Message::new(Header { + typ: Some(Type::Ack), + credit: None, + ident: Some(ident), + }) } /// Access the message header. @@ -130,28 +151,34 @@ pub struct Codec { /// Encoding/decoding buffer. buffer: Vec, /// Max. header length. - max_header_len: u32 + max_header_len: u32, } impl Codec { /// Create a codec by wrapping an existing one. pub fn new(c: C, max_header_len: u32) -> Self { - Codec { inner: c, buffer: Vec::new(), max_header_len } + Codec { + inner: c, + buffer: Vec::new(), + max_header_len, + } } /// Read and decode a request header. async fn read_header(&mut self, io: &mut T) -> io::Result where T: AsyncRead + Unpin + Send, - H: for<'a> minicbor::Decode<'a> + H: for<'a> minicbor::Decode<'a>, { - let header_len = aio::read_u32(&mut *io).await - .map_err(|e| match e { - ReadError::Io(e) => e, - other => io::Error::new(io::ErrorKind::Other, other) - })?; + let header_len = aio::read_u32(&mut *io).await.map_err(|e| match e { + ReadError::Io(e) => e, + other => io::Error::new(io::ErrorKind::Other, other), + })?; if header_len > self.max_header_len { - return Err(io::Error::new(io::ErrorKind::InvalidData, "header too large to read")) + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "header too large to read", + )); } self.buffer.resize(u32_to_usize(header_len), 0u8); io.read_exact(&mut self.buffer).await?; @@ -162,12 +189,16 @@ impl Codec { async fn write_header(&mut self, hdr: &H, io: &mut T) -> io::Result<()> where T: AsyncWrite + Unpin + Send, - H: minicbor::Encode + H: minicbor::Encode, { self.buffer.clear(); - minicbor::encode(hdr, &mut self.buffer).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + minicbor::encode(hdr, &mut self.buffer) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; if self.buffer.len() > u32_to_usize(self.max_header_len) { - return Err(io::Error::new(io::ErrorKind::InvalidData, "header too large to write")) + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "header too large to write", + )); } let mut b = unsigned_varint::encode::u32_buffer(); let header_len = unsigned_varint::encode::u32(self.buffer.len() as u32, &mut b); @@ -180,7 +211,7 @@ impl Codec { impl RequestResponseCodec for Codec where C: RequestResponseCodec + Send, - C::Protocol: Sync + C::Protocol: Sync, { type Protocol = ProtocolWrapper; type Request = Message; @@ -188,7 +219,7 @@ where async fn read_request(&mut self, p: &Self::Protocol, io: &mut T) -> io::Result where - T: AsyncRead + Unpin + Send + T: AsyncRead + Unpin + Send, { let mut msg = Message::new(self.read_header(io).await?); match msg.header.typ { @@ -198,15 +229,22 @@ where } Some(Type::Credit) => Ok(msg), Some(Type::Response) | Some(Type::Ack) | None => { - log::debug!("unexpected {:?} when expecting request or credit grant", msg.header.typ); + log::debug!( + "unexpected {:?} when expecting request or credit grant", + msg.header.typ + ); 
Err(io::ErrorKind::InvalidData.into()) } } } - async fn read_response(&mut self, p: &Self::Protocol, io: &mut T) -> io::Result + async fn read_response( + &mut self, + p: &Self::Protocol, + io: &mut T, + ) -> io::Result where - T: AsyncRead + Unpin + Send + T: AsyncRead + Unpin + Send, { let mut msg = Message::new(self.read_header(io).await?); match msg.header.typ { @@ -216,15 +254,23 @@ where } Some(Type::Ack) => Ok(msg), Some(Type::Request) | Some(Type::Credit) | None => { - log::debug!("unexpected {:?} when expecting response or ack", msg.header.typ); + log::debug!( + "unexpected {:?} when expecting response or ack", + msg.header.typ + ); Err(io::ErrorKind::InvalidData.into()) } } } - async fn write_request(&mut self, p: &Self::Protocol, io: &mut T, r: Self::Request) -> io::Result<()> + async fn write_request( + &mut self, + p: &Self::Protocol, + io: &mut T, + r: Self::Request, + ) -> io::Result<()> where - T: AsyncWrite + Unpin + Send + T: AsyncWrite + Unpin + Send, { self.write_header(&r.header, io).await?; if let Some(data) = r.data { @@ -233,9 +279,14 @@ where Ok(()) } - async fn write_response(&mut self, p: &Self::Protocol, io: &mut T, r: Self::Response) -> io::Result<()> + async fn write_response( + &mut self, + p: &Self::Protocol, + io: &mut T, + r: Self::Response, + ) -> io::Result<()> where - T: AsyncWrite + Unpin + Send + T: AsyncWrite + Unpin + Send, { self.write_header(&r.header, io).await?; if let Some(data) = r.data { diff --git a/protocols/request-response/tests/ping.rs b/protocols/request-response/tests/ping.rs index 8e877f8f6dd..a88611e64b4 100644 --- a/protocols/request-response/tests/ping.rs +++ b/protocols/request-response/tests/ping.rs @@ -21,22 +21,21 @@ //! Integration tests for the `RequestResponse` network behaviour. use async_trait::async_trait; +use futures::{channel::mpsc, executor::LocalPool, prelude::*, task::SpawnExt, AsyncWriteExt}; use libp2p_core::{ - Multiaddr, - PeerId, identity, muxing::StreamMuxerBox, transport::{self, Transport}, - upgrade::{read_length_prefixed, write_length_prefixed} + upgrade::{read_length_prefixed, write_length_prefixed}, + Multiaddr, PeerId, }; -use libp2p_noise::{NoiseConfig, X25519Spec, Keypair}; +use libp2p_noise::{Keypair, NoiseConfig, X25519Spec}; use libp2p_request_response::*; use libp2p_swarm::{Swarm, SwarmEvent}; use libp2p_tcp::TcpConfig; -use futures::{channel::mpsc, executor::LocalPool, prelude::*, task::SpawnExt, AsyncWriteExt}; use rand::{self, Rng}; -use std::{io, iter}; use std::{collections::HashSet, num::NonZeroU16}; +use std::{io, iter}; #[test] fn is_response_outbound() { @@ -50,24 +49,30 @@ fn is_response_outbound() { let ping_proto1 = RequestResponse::new(PingCodec(), protocols, cfg); let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id); - let request_id1 = swarm1.behaviour_mut().send_request(&offline_peer, ping.clone()); + let request_id1 = swarm1 + .behaviour_mut() + .send_request(&offline_peer, ping.clone()); match futures::executor::block_on(swarm1.select_next_some()) { - SwarmEvent::Behaviour(RequestResponseEvent::OutboundFailure{ + SwarmEvent::Behaviour(RequestResponseEvent::OutboundFailure { peer, request_id: req_id, - error: _error + error: _error, }) => { assert_eq!(&offline_peer, &peer); assert_eq!(req_id, request_id1); - }, + } e => panic!("Peer: Unexpected event: {:?}", e), } let request_id2 = swarm1.behaviour_mut().send_request(&offline_peer, ping); - assert!(!swarm1.behaviour().is_pending_outbound(&offline_peer, &request_id1)); - 
assert!(swarm1.behaviour().is_pending_outbound(&offline_peer, &request_id2)); + assert!(!swarm1 + .behaviour() + .is_pending_outbound(&offline_peer, &request_id1)); + assert!(swarm1 + .behaviour() + .is_pending_outbound(&offline_peer, &request_id2)); } /// Exercises a simple ping protocol. @@ -98,18 +103,22 @@ fn ping_protocol() { let peer1 = async move { loop { match swarm1.select_next_some().await { - SwarmEvent::NewListenAddr { address, .. }=> tx.send(address).await.unwrap(), + SwarmEvent::NewListenAddr { address, .. } => tx.send(address).await.unwrap(), SwarmEvent::Behaviour(RequestResponseEvent::Message { peer, - message: RequestResponseMessage::Request { request, channel, .. } + message: + RequestResponseMessage::Request { + request, channel, .. + }, }) => { assert_eq!(&request, &expected_ping); assert_eq!(&peer, &peer2_id); - swarm1.behaviour_mut().send_response(channel, pong.clone()).unwrap(); - }, - SwarmEvent::Behaviour(RequestResponseEvent::ResponseSent { - peer, .. - }) => { + swarm1 + .behaviour_mut() + .send_response(channel, pong.clone()) + .unwrap(); + } + SwarmEvent::Behaviour(RequestResponseEvent::ResponseSent { peer, .. }) => { assert_eq!(&peer, &peer2_id); } SwarmEvent::Behaviour(e) => panic!("Peer1: Unexpected event: {:?}", e), @@ -131,20 +140,23 @@ fn ping_protocol() { match swarm2.select_next_some().await { SwarmEvent::Behaviour(RequestResponseEvent::Message { peer, - message: RequestResponseMessage::Response { request_id, response } + message: + RequestResponseMessage::Response { + request_id, + response, + }, }) => { count += 1; assert_eq!(&response, &expected_pong); assert_eq!(&peer, &peer1_id); assert_eq!(req_id, request_id); if count >= num_pings { - return + return; } else { req_id = swarm2.behaviour_mut().send_request(&peer1_id, ping.clone()); } - } - SwarmEvent::Behaviour(e) =>panic!("Peer2: Unexpected event: {:?}", e), + SwarmEvent::Behaviour(e) => panic!("Peer2: Unexpected event: {:?}", e), _ => {} } } @@ -273,7 +285,7 @@ fn emits_inbound_connection_closed_if_channel_is_dropped() { let error = match event { RequestResponseEvent::OutboundFailure { error, .. } => error, - e => panic!("unexpected event from peer 2: {:?}", e) + e => panic!("unexpected event from peer 2: {:?}", e), }; assert_eq!(error, OutboundFailure::ConnectionClosed); @@ -306,24 +318,34 @@ fn ping_protocol_throttled() { let limit1: u16 = rand::thread_rng().gen_range(1, 10); let limit2: u16 = rand::thread_rng().gen_range(1, 10); - swarm1.behaviour_mut().set_receive_limit(NonZeroU16::new(limit1).unwrap()); - swarm2.behaviour_mut().set_receive_limit(NonZeroU16::new(limit2).unwrap()); + swarm1 + .behaviour_mut() + .set_receive_limit(NonZeroU16::new(limit1).unwrap()); + swarm2 + .behaviour_mut() + .set_receive_limit(NonZeroU16::new(limit2).unwrap()); let peer1 = async move { - for i in 1 .. { + for i in 1.. { match swarm1.select_next_some().await { SwarmEvent::NewListenAddr { address, .. } => tx.send(address).await.unwrap(), SwarmEvent::Behaviour(throttled::Event::Event(RequestResponseEvent::Message { peer, - message: RequestResponseMessage::Request { request, channel, .. }, + message: + RequestResponseMessage::Request { + request, channel, .. + }, })) => { assert_eq!(&request, &expected_ping); assert_eq!(&peer, &peer2_id); - swarm1.behaviour_mut().send_response(channel, pong.clone()).unwrap(); - }, - SwarmEvent::Behaviour(throttled::Event::Event(RequestResponseEvent::ResponseSent { - peer, .. 
- })) => { + swarm1 + .behaviour_mut() + .send_response(channel, pong.clone()) + .unwrap(); + } + SwarmEvent::Behaviour(throttled::Event::Event( + RequestResponseEvent::ResponseSent { peer, .. }, + )) => { assert_eq!(&peer, &peer2_id); } SwarmEvent::Behaviour(e) => panic!("Peer1: Unexpected event: {:?}", e), @@ -331,7 +353,9 @@ fn ping_protocol_throttled() { } if i % 31 == 0 { let lim = rand::thread_rng().gen_range(1, 17); - swarm1.behaviour_mut().override_receive_limit(&peer2_id, NonZeroU16::new(lim).unwrap()); + swarm1 + .behaviour_mut() + .override_receive_limit(&peer2_id, NonZeroU16::new(lim).unwrap()); } } }; @@ -348,7 +372,11 @@ fn ping_protocol_throttled() { loop { if !blocked { - while let Some(id) = swarm2.behaviour_mut().send_request(&peer1_id, ping.clone()).ok() { + while let Some(id) = swarm2 + .behaviour_mut() + .send_request(&peer1_id, ping.clone()) + .ok() + { req_ids.insert(id); } blocked = true; @@ -358,19 +386,23 @@ fn ping_protocol_throttled() { assert_eq!(peer, peer1_id); blocked = false } - SwarmEvent::Behaviour(throttled::Event::Event(RequestResponseEvent::Message { + SwarmEvent::Behaviour(throttled::Event::Event(RequestResponseEvent::Message { peer, - message: RequestResponseMessage::Response { request_id, response } + message: + RequestResponseMessage::Response { + request_id, + response, + }, })) => { count += 1; assert_eq!(&response, &expected_pong); assert_eq!(&peer, &peer1_id); assert!(req_ids.remove(&request_id)); if count >= num_pings { - break + break; } } - SwarmEvent::Behaviour(e) =>panic!("Peer2: Unexpected event: {:?}", e), + SwarmEvent::Behaviour(e) => panic!("Peer2: Unexpected event: {:?}", e), _ => {} } } @@ -384,13 +416,18 @@ fn ping_protocol_throttled() { fn mk_transport() -> (PeerId, transport::Boxed<(PeerId, StreamMuxerBox)>) { let id_keys = identity::Keypair::generate_ed25519(); let peer_id = id_keys.public().to_peer_id(); - let noise_keys = Keypair::::new().into_authentic(&id_keys).unwrap(); - (peer_id, TcpConfig::new() - .nodelay(true) - .upgrade() - .authenticate(NoiseConfig::xx(noise_keys).into_authenticated()) - .multiplex(libp2p_yamux::YamuxConfig::default()) - .boxed()) + let noise_keys = Keypair::::new() + .into_authentic(&id_keys) + .unwrap(); + ( + peer_id, + TcpConfig::new() + .nodelay(true) + .upgrade() + .authenticate(NoiseConfig::xx(noise_keys).into_authenticated()) + .multiplex(libp2p_yamux::YamuxConfig::default()) + .boxed(), + ) } // Simple Ping-Pong Protocol @@ -416,38 +453,40 @@ impl RequestResponseCodec for PingCodec { type Request = Ping; type Response = Pong; - async fn read_request(&mut self, _: &PingProtocol, io: &mut T) - -> io::Result + async fn read_request(&mut self, _: &PingProtocol, io: &mut T) -> io::Result where - T: AsyncRead + Unpin + Send + T: AsyncRead + Unpin + Send, { let vec = read_length_prefixed(io, 1024).await?; if vec.is_empty() { - return Err(io::ErrorKind::UnexpectedEof.into()) + return Err(io::ErrorKind::UnexpectedEof.into()); } Ok(Ping(vec)) } - async fn read_response(&mut self, _: &PingProtocol, io: &mut T) - -> io::Result + async fn read_response(&mut self, _: &PingProtocol, io: &mut T) -> io::Result where - T: AsyncRead + Unpin + Send + T: AsyncRead + Unpin + Send, { let vec = read_length_prefixed(io, 1024).await?; if vec.is_empty() { - return Err(io::ErrorKind::UnexpectedEof.into()) + return Err(io::ErrorKind::UnexpectedEof.into()); } Ok(Pong(vec)) } - async fn write_request(&mut self, _: &PingProtocol, io: &mut T, Ping(data): Ping) - -> io::Result<()> + async fn write_request( + &mut self, + _: 
&PingProtocol, + io: &mut T, + Ping(data): Ping, + ) -> io::Result<()> where - T: AsyncWrite + Unpin + Send + T: AsyncWrite + Unpin + Send, { write_length_prefixed(io, data).await?; io.close().await?; @@ -455,10 +494,14 @@ impl RequestResponseCodec for PingCodec { Ok(()) } - async fn write_response(&mut self, _: &PingProtocol, io: &mut T, Pong(data): Pong) - -> io::Result<()> + async fn write_response( + &mut self, + _: &PingProtocol, + io: &mut T, + Pong(data): Pong, + ) -> io::Result<()> where - T: AsyncWrite + Unpin + Send + T: AsyncWrite + Unpin + Send, { write_length_prefixed(io, data).await?; io.close().await?; diff --git a/src/bandwidth.rs b/src/bandwidth.rs index 87b66653cfc..a341b4dfbab 100644 --- a/src/bandwidth.rs +++ b/src/bandwidth.rs @@ -18,12 +18,26 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{Multiaddr, core::{Transport, transport::{ListenerEvent, TransportError}}}; +use crate::{ + core::{ + transport::{ListenerEvent, TransportError}, + Transport, + }, + Multiaddr, +}; use atomic::Atomic; -use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready}; +use futures::{ + io::{IoSlice, IoSliceMut}, + prelude::*, + ready, +}; use std::{ - convert::TryFrom as _, io, pin::Pin, sync::{atomic::Ordering, Arc}, task::{Context, Poll} + convert::TryFrom as _, + io, + pin::Pin, + sync::{atomic::Ordering, Arc}, + task::{Context, Poll}, }; /// Wraps around a `Transport` and counts the number of bytes that go through all the opened @@ -91,19 +105,18 @@ pub struct BandwidthListener { impl Stream for BandwidthListener where - TInner: TryStream, Error = TErr> + TInner: TryStream, Error = TErr>, { type Item = Result, TErr>, TErr>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - let event = - if let Some(event) = ready!(this.inner.try_poll_next(cx)?) { - event - } else { - return Poll::Ready(None) - }; + let event = if let Some(event) = ready!(this.inner.try_poll_next(cx)?) 
{ + event + } else { + return Poll::Ready(None); + }; let event = event.map({ let sinks = this.sinks.clone(); @@ -129,7 +142,10 @@ impl Future for BandwidthFuture { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let inner = ready!(this.inner.try_poll(cx)?); - let logged = BandwidthConnecLogging { inner, sinks: this.sinks.clone() }; + let logged = BandwidthConnecLogging { + inner, + sinks: this.sinks.clone(), + }; Poll::Ready(Ok(logged)) } } @@ -169,33 +185,61 @@ pub struct BandwidthConnecLogging { } impl AsyncRead for BandwidthConnecLogging { - fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { let this = self.project(); let num_bytes = ready!(this.inner.poll_read(cx, buf))?; - this.sinks.inbound.fetch_add(u64::try_from(num_bytes).unwrap_or(u64::max_value()), Ordering::Relaxed); + this.sinks.inbound.fetch_add( + u64::try_from(num_bytes).unwrap_or(u64::max_value()), + Ordering::Relaxed, + ); Poll::Ready(Ok(num_bytes)) } - fn poll_read_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>]) -> Poll> { + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { let this = self.project(); let num_bytes = ready!(this.inner.poll_read_vectored(cx, bufs))?; - this.sinks.inbound.fetch_add(u64::try_from(num_bytes).unwrap_or(u64::max_value()), Ordering::Relaxed); + this.sinks.inbound.fetch_add( + u64::try_from(num_bytes).unwrap_or(u64::max_value()), + Ordering::Relaxed, + ); Poll::Ready(Ok(num_bytes)) } } impl AsyncWrite for BandwidthConnecLogging { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { let this = self.project(); let num_bytes = ready!(this.inner.poll_write(cx, buf))?; - this.sinks.outbound.fetch_add(u64::try_from(num_bytes).unwrap_or(u64::max_value()), Ordering::Relaxed); + this.sinks.outbound.fetch_add( + u64::try_from(num_bytes).unwrap_or(u64::max_value()), + Ordering::Relaxed, + ); Poll::Ready(Ok(num_bytes)) } - fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll> { + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { let this = self.project(); let num_bytes = ready!(this.inner.poll_write_vectored(cx, bufs))?; - this.sinks.outbound.fetch_add(u64::try_from(num_bytes).unwrap_or(u64::max_value()), Ordering::Relaxed); + this.sinks.outbound.fetch_add( + u64::try_from(num_bytes).unwrap_or(u64::max_value()), + Ordering::Relaxed, + ); Poll::Ready(Ok(num_bytes)) } diff --git a/src/lib.rs b/src/lib.rs index 21e1334a28d..6b6819706b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,9 +36,9 @@ pub use bytes; pub use futures; #[doc(inline)] -pub use multiaddr; -#[doc(inline)] pub use libp2p_core::multihash; +#[doc(inline)] +pub use multiaddr; #[doc(inline)] pub use libp2p_core as core; @@ -48,18 +48,13 @@ pub use libp2p_core as core; #[doc(inline)] pub use libp2p_deflate as deflate; #[cfg(any(feature = "dns-async-std", feature = "dns-tokio"))] -#[cfg_attr(docsrs, doc(cfg(any(feature = "dns-async-std", feature = "dns-tokio"))))] +#[cfg_attr( + docsrs, + doc(cfg(any(feature = "dns-async-std", feature = "dns-tokio"))) +)] #[cfg(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")))] #[doc(inline)] 
pub use libp2p_dns as dns; -#[cfg(feature = "identify")] -#[cfg_attr(docsrs, doc(cfg(feature = "identify")))] -#[doc(inline)] -pub use libp2p_identify as identify; -#[cfg(feature = "kad")] -#[cfg_attr(docsrs, doc(cfg(feature = "kad")))] -#[doc(inline)] -pub use libp2p_kad as kad; #[cfg(feature = "floodsub")] #[cfg_attr(docsrs, doc(cfg(feature = "floodsub")))] #[doc(inline)] @@ -68,15 +63,23 @@ pub use libp2p_floodsub as floodsub; #[cfg_attr(docsrs, doc(cfg(feature = "gossipsub")))] #[doc(inline)] pub use libp2p_gossipsub as gossipsub; -#[cfg(feature = "mplex")] -#[cfg_attr(docsrs, doc(cfg(feature = "mplex")))] +#[cfg(feature = "identify")] +#[cfg_attr(docsrs, doc(cfg(feature = "identify")))] #[doc(inline)] -pub use libp2p_mplex as mplex; +pub use libp2p_identify as identify; +#[cfg(feature = "kad")] +#[cfg_attr(docsrs, doc(cfg(feature = "kad")))] +#[doc(inline)] +pub use libp2p_kad as kad; #[cfg(feature = "mdns")] #[cfg_attr(docsrs, doc(cfg(feature = "mdns")))] #[cfg(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")))] #[doc(inline)] pub use libp2p_mdns as mdns; +#[cfg(feature = "mplex")] +#[cfg_attr(docsrs, doc(cfg(feature = "mplex")))] +#[doc(inline)] +pub use libp2p_mplex as mplex; #[cfg(feature = "noise")] #[cfg_attr(docsrs, doc(cfg(feature = "noise")))] #[doc(inline)] @@ -89,6 +92,18 @@ pub use libp2p_ping as ping; #[cfg_attr(docsrs, doc(cfg(feature = "plaintext")))] #[doc(inline)] pub use libp2p_plaintext as plaintext; +#[cfg(feature = "pnet")] +#[cfg_attr(docsrs, doc(cfg(feature = "pnet")))] +#[doc(inline)] +pub use libp2p_pnet as pnet; +#[cfg(feature = "relay")] +#[cfg_attr(docsrs, doc(cfg(feature = "relay")))] +#[doc(inline)] +pub use libp2p_relay as relay; +#[cfg(feature = "request-response")] +#[cfg_attr(docsrs, doc(cfg(feature = "request-response")))] +#[doc(inline)] +pub use libp2p_request_response as request_response; #[doc(inline)] pub use libp2p_swarm as swarm; #[cfg(any(feature = "tcp-async-io", feature = "tcp-tokio"))] @@ -113,18 +128,6 @@ pub use libp2p_websocket as websocket; #[cfg_attr(docsrs, doc(cfg(feature = "yamux")))] #[doc(inline)] pub use libp2p_yamux as yamux; -#[cfg(feature = "pnet")] -#[cfg_attr(docsrs, doc(cfg(feature = "pnet")))] -#[doc(inline)] -pub use libp2p_pnet as pnet; -#[cfg(feature = "relay")] -#[cfg_attr(docsrs, doc(cfg(feature = "relay")))] -#[doc(inline)] -pub use libp2p_relay as relay; -#[cfg(feature = "request-response")] -#[cfg_attr(docsrs, doc(cfg(feature = "request-response")))] -#[doc(inline)] -pub use libp2p_request_response as request_response; mod transport_ext; @@ -136,16 +139,15 @@ pub mod tutorial; pub use self::core::{ identity, - PeerId, - Transport, transport::TransportError, - upgrade::{InboundUpgrade, InboundUpgradeExt, OutboundUpgrade, OutboundUpgradeExt} + upgrade::{InboundUpgrade, InboundUpgradeExt, OutboundUpgrade, OutboundUpgradeExt}, + PeerId, Transport, }; -pub use libp2p_swarm_derive::NetworkBehaviour; -pub use self::multiaddr::{Multiaddr, multiaddr as build_multiaddr}; +pub use self::multiaddr::{multiaddr as build_multiaddr, Multiaddr}; pub use self::simple::SimpleProtocol; pub use self::swarm::Swarm; pub use self::transport_ext::TransportExt; +pub use libp2p_swarm_derive::NetworkBehaviour; /// Builds a `Transport` based on TCP/IP that supports the most commonly-used features of libp2p: /// @@ -158,11 +160,30 @@ pub use self::transport_ext::TransportExt; /// /// > **Note**: This `Transport` is not suitable for production usage, as its implementation /// > reserves the right to support 
additional protocols or remove deprecated protocols. -#[cfg(all(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), feature = "tcp-async-io", feature = "dns-async-std", feature = "websocket", feature = "noise", feature = "mplex", feature = "yamux"))] -#[cfg_attr(docsrs, doc(cfg(all(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), feature = "tcp-async-io", feature = "dns-async-std", feature = "websocket", feature = "noise", feature = "mplex", feature = "yamux"))))] -pub async fn development_transport(keypair: identity::Keypair) - -> std::io::Result> -{ +#[cfg(all( + not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), + feature = "tcp-async-io", + feature = "dns-async-std", + feature = "websocket", + feature = "noise", + feature = "mplex", + feature = "yamux" +))] +#[cfg_attr( + docsrs, + doc(cfg(all( + not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), + feature = "tcp-async-io", + feature = "dns-async-std", + feature = "websocket", + feature = "noise", + feature = "mplex", + feature = "yamux" + ))) +)] +pub async fn development_transport( + keypair: identity::Keypair, +) -> std::io::Result> { let transport = { let tcp = tcp::TcpConfig::new().nodelay(true); let dns_tcp = dns::DnsConfig::system(tcp).await?; @@ -177,7 +198,10 @@ pub async fn development_transport(keypair: identity::Keypair) Ok(transport .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) - .multiplex(core::upgrade::SelectUpgrade::new(yamux::YamuxConfig::default(), mplex::MplexConfig::default())) + .multiplex(core::upgrade::SelectUpgrade::new( + yamux::YamuxConfig::default(), + mplex::MplexConfig::default(), + )) .timeout(std::time::Duration::from_secs(20)) .boxed()) } @@ -193,11 +217,30 @@ pub async fn development_transport(keypair: identity::Keypair) /// /// > **Note**: This `Transport` is not suitable for production usage, as its implementation /// > reserves the right to support additional protocols or remove deprecated protocols. 
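For orientation, a minimal, hypothetical sketch of how an application might consume the `development_transport` builder reformatted above; the `ping` behaviour, the `futures` executor choice, and the listen address are illustrative assumptions and not part of this patch:

```rust
use futures::executor::block_on;
use libp2p::{identity, ping, swarm::Swarm, PeerId};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    block_on(async {
        // Ephemeral identity for this node.
        let keypair = identity::Keypair::generate_ed25519();
        let peer_id = PeerId::from(keypair.public());

        // TCP/DNS/WebSocket transport, Noise-authenticated and
        // Yamux/Mplex-multiplexed, as assembled by `development_transport`.
        let transport = libp2p::development_transport(keypair).await?;

        // Any `NetworkBehaviour` works here; `Ping` is purely illustrative
        // and requires the corresponding cargo feature.
        let behaviour = ping::Ping::new(ping::PingConfig::new());

        let mut swarm = Swarm::new(transport, behaviour, peer_id);
        swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
        Ok(())
    })
}
```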
-#[cfg(all(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), feature = "tcp-tokio", feature = "dns-tokio", feature = "websocket", feature = "noise", feature = "mplex", feature = "yamux"))] -#[cfg_attr(docsrs, doc(cfg(all(not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), feature = "tcp-tokio", feature = "dns-tokio", feature = "websocket", feature = "noise", feature = "mplex", feature = "yamux"))))] -pub fn tokio_development_transport(keypair: identity::Keypair) - -> std::io::Result> -{ +#[cfg(all( + not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), + feature = "tcp-tokio", + feature = "dns-tokio", + feature = "websocket", + feature = "noise", + feature = "mplex", + feature = "yamux" +))] +#[cfg_attr( + docsrs, + doc(cfg(all( + not(any(target_os = "emscripten", target_os = "wasi", target_os = "unknown")), + feature = "tcp-tokio", + feature = "dns-tokio", + feature = "websocket", + feature = "noise", + feature = "mplex", + feature = "yamux" + ))) +)] +pub fn tokio_development_transport( + keypair: identity::Keypair, +) -> std::io::Result> { let transport = { let tcp = tcp::TokioTcpConfig::new().nodelay(true); let dns_tcp = dns::TokioDnsConfig::system(tcp)?; @@ -212,7 +255,10 @@ pub fn tokio_development_transport(keypair: identity::Keypair) Ok(transport .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) - .multiplex(core::upgrade::SelectUpgrade::new(yamux::YamuxConfig::default(), mplex::MplexConfig::default())) + .multiplex(core::upgrade::SelectUpgrade::new( + yamux::YamuxConfig::default(), + mplex::MplexConfig::default(), + )) .timeout(std::time::Duration::from_secs(20)) .boxed()) } diff --git a/src/transport_ext.rs b/src/transport_ext.rs index de77007b9c4..fa8926c8380 100644 --- a/src/transport_ext.rs +++ b/src/transport_ext.rs @@ -33,7 +33,7 @@ pub trait TransportExt: Transport { /// of bytes transferred through the sockets. fn with_bandwidth_logging(self) -> (BandwidthLogging, Arc) where - Self: Sized + Self: Sized, { BandwidthLogging::new(self) } diff --git a/src/tutorial.rs b/src/tutorial.rs index 9d88bf54c9d..ddaa71350bc 100644 --- a/src/tutorial.rs +++ b/src/tutorial.rs @@ -349,8 +349,8 @@ //! //! Note: The [`Multiaddr`] at the end being one of the [`Multiaddr`] printed //! earlier in terminal window one. -//! Both peers have to be in the same network with which the address is associated. -//! In our case any printed addresses can be used, as both peers run on the same +//! Both peers have to be in the same network with which the address is associated. +//! In our case any printed addresses can be used, as both peers run on the same //! device. //! //! The two nodes will establish a connection and send each other ping and pong diff --git a/swarm-derive/src/lib.rs b/swarm-derive/src/lib.rs index a5cdf4900ca..3f92e549fa9 100644 --- a/swarm-derive/src/lib.rs +++ b/swarm-derive/src/lib.rs @@ -20,9 +20,9 @@ #![recursion_limit = "256"] -use quote::quote; use proc_macro::TokenStream; -use syn::{parse_macro_input, DeriveInput, Data, DataStruct, Ident}; +use quote::quote; +use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Ident}; /// Generates a delegating `NetworkBehaviour` implementation for the struct this is used for. See /// the trait documentation for better description. 
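The derive internals in the hunks below generate the delegation boilerplate; as a rough, consumer-side sketch (assumed, not part of this patch, mirroring the patterns exercised in `swarm-derive/tests/test.rs`), the macro and its `#[behaviour(...)]` attributes are used like this:

```rust
use libp2p::NetworkBehaviour;
use libp2p::{identify, ping};

#[derive(Debug)]
enum MyEvent {
    Ping(ping::PingEvent),
    Identify(identify::IdentifyEvent),
}

impl From<ping::PingEvent> for MyEvent {
    fn from(event: ping::PingEvent) -> Self {
        MyEvent::Ping(event)
    }
}

impl From<identify::IdentifyEvent> for MyEvent {
    fn from(event: identify::IdentifyEvent) -> Self {
        MyEvent::Identify(event)
    }
}

// `out_event` selects the custom event type; `event_process = false` skips the
// `NetworkBehaviourEventProcess` requirement and bubbles events up to the `Swarm`.
#[derive(NetworkBehaviour)]
#[behaviour(out_event = "MyEvent", event_process = false)]
struct MyBehaviour {
    ping: ping::Ping,
    identify: identify::Identify,
}
```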
@@ -45,27 +45,27 @@ fn build(ast: &DeriveInput) -> TokenStream { fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { let name = &ast.ident; let (_, ty_generics, where_clause) = ast.generics.split_for_impl(); - let multiaddr = quote!{::libp2p::core::Multiaddr}; - let trait_to_impl = quote!{::libp2p::swarm::NetworkBehaviour}; - let net_behv_event_proc = quote!{::libp2p::swarm::NetworkBehaviourEventProcess}; - let either_ident = quote!{::libp2p::core::either::EitherOutput}; - let network_behaviour_action = quote!{::libp2p::swarm::NetworkBehaviourAction}; - let into_protocols_handler = quote!{::libp2p::swarm::IntoProtocolsHandler}; - let protocols_handler = quote!{::libp2p::swarm::ProtocolsHandler}; - let into_proto_select_ident = quote!{::libp2p::swarm::IntoProtocolsHandlerSelect}; - let peer_id = quote!{::libp2p::core::PeerId}; - let connection_id = quote!{::libp2p::core::connection::ConnectionId}; - let connected_point = quote!{::libp2p::core::ConnectedPoint}; - let listener_id = quote!{::libp2p::core::connection::ListenerId}; - - let poll_parameters = quote!{::libp2p::swarm::PollParameters}; + let multiaddr = quote! {::libp2p::core::Multiaddr}; + let trait_to_impl = quote! {::libp2p::swarm::NetworkBehaviour}; + let net_behv_event_proc = quote! {::libp2p::swarm::NetworkBehaviourEventProcess}; + let either_ident = quote! {::libp2p::core::either::EitherOutput}; + let network_behaviour_action = quote! {::libp2p::swarm::NetworkBehaviourAction}; + let into_protocols_handler = quote! {::libp2p::swarm::IntoProtocolsHandler}; + let protocols_handler = quote! {::libp2p::swarm::ProtocolsHandler}; + let into_proto_select_ident = quote! {::libp2p::swarm::IntoProtocolsHandlerSelect}; + let peer_id = quote! {::libp2p::core::PeerId}; + let connection_id = quote! {::libp2p::core::connection::ConnectionId}; + let connected_point = quote! {::libp2p::core::ConnectedPoint}; + let listener_id = quote! {::libp2p::core::connection::ListenerId}; + + let poll_parameters = quote! {::libp2p::swarm::PollParameters}; // Build the generics. let impl_generics = { let tp = ast.generics.type_params(); let lf = ast.generics.lifetimes(); let cst = ast.generics.const_params(); - quote!{<#(#lf,)* #(#tp,)* #(#cst,)*>} + quote! {<#(#lf,)* #(#tp,)* #(#cst,)*>} }; // Whether or not we require the `NetworkBehaviourEventProcess` trait to be implemented. @@ -75,12 +75,14 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { for meta_items in ast.attrs.iter().filter_map(get_meta_items) { for meta_item in meta_items { match meta_item { - syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) if m.path.is_ident("event_process") => { + syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) + if m.path.is_ident("event_process") => + { if let syn::Lit::Bool(ref b) = m.lit { event_process = b.value } } - _ => () + _ => (), } } } @@ -92,17 +94,19 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { // If we find a `#[behaviour(out_event = "Foo")]` attribute on the struct, we set `Foo` as // the out event. Otherwise we use `()`. let out_event = { - let mut out = quote!{()}; + let mut out = quote! 
{()}; for meta_items in ast.attrs.iter().filter_map(get_meta_items) { for meta_item in meta_items { match meta_item { - syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) if m.path.is_ident("out_event") => { + syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) + if m.path.is_ident("out_event") => + { if let syn::Lit::Str(ref s) = m.lit { let ident: syn::Type = syn::parse_str(&s.value()).unwrap(); - out = quote!{#ident}; + out = quote! {#ident}; } } - _ => () + _ => (), } } } @@ -111,70 +115,84 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { // Build the `where ...` clause of the trait implementation. let where_clause = { - let additional = data_struct.fields.iter() + let additional = data_struct + .fields + .iter() .filter(|x| !is_ignored(x)) .flat_map(|field| { let ty = &field.ty; vec![ - quote!{#ty: #trait_to_impl}, + quote! {#ty: #trait_to_impl}, if event_process { - quote!{Self: #net_behv_event_proc<<#ty as #trait_to_impl>::OutEvent>} + quote! {Self: #net_behv_event_proc<<#ty as #trait_to_impl>::OutEvent>} } else { - quote!{#out_event: From< <#ty as #trait_to_impl>::OutEvent >} - } + quote! {#out_event: From< <#ty as #trait_to_impl>::OutEvent >} + }, ] }) .collect::>(); if let Some(where_clause) = where_clause { if where_clause.predicates.trailing_punct() { - Some(quote!{#where_clause #(#additional),*}) + Some(quote! {#where_clause #(#additional),*}) } else { - Some(quote!{#where_clause, #(#additional),*}) + Some(quote! {#where_clause, #(#additional),*}) } } else { - Some(quote!{where #(#additional),*}) + Some(quote! {where #(#additional),*}) } }; // Build the list of statements to put in the body of `addresses_of_peer()`. let addresses_of_peer_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ out.extend(self.#i.addresses_of_peer(peer_id)); }, - None => quote!{ out.extend(self.#field_n.addresses_of_peer(peer_id)); }, + Some(match field.ident { + Some(ref i) => quote! { out.extend(self.#i.addresses_of_peer(peer_id)); }, + None => quote! { out.extend(self.#field_n.addresses_of_peer(peer_id)); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_connected()`. let inject_connected_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_connected(peer_id); }, - None => quote!{ self.#field_n.inject_connected(peer_id); }, + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_connected(peer_id); }, + None => quote! { self.#field_n.inject_connected(peer_id); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_disconnected()`. 
let inject_disconnected_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_disconnected(peer_id); }, - None => quote!{ self.#field_n.inject_disconnected(peer_id); }, + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_disconnected(peer_id); }, + None => quote! { self.#field_n.inject_disconnected(peer_id); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_connection_established()`. @@ -217,8 +235,9 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { }; // Build the list of statements to put in the body of `inject_addr_reach_failure()`. - let inject_addr_reach_failure_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { + let inject_addr_reach_failure_stmts = + { + data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { if is_ignored(&field) { return None; } @@ -228,116 +247,148 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { None => quote!{ self.#field_n.inject_addr_reach_failure(peer_id, addr, error); }, }) }) - }; + }; // Build the list of statements to put in the body of `inject_dial_failure()`. let inject_dial_failure_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_dial_failure(peer_id); }, - None => quote!{ self.#field_n.inject_dial_failure(peer_id); }, + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_dial_failure(peer_id); }, + None => quote! { self.#field_n.inject_dial_failure(peer_id); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_new_listener()`. let inject_new_listener_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_new_listener(id); }, - None => quote!{ self.#field_n.inject_new_listener(id); }, + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_new_listener(id); }, + None => quote! { self.#field_n.inject_new_listener(id); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_new_listen_addr()`. let inject_new_listen_addr_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_new_listen_addr(id, addr); }, - None => quote!{ self.#field_n.inject_new_listen_addr(id, addr); }, + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_new_listen_addr(id, addr); }, + None => quote! 
{ self.#field_n.inject_new_listen_addr(id, addr); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_expired_listen_addr()`. let inject_expired_listen_addr_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_expired_listen_addr(id, addr); }, - None => quote!{ self.#field_n.inject_expired_listen_addr(id, addr); }, + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_expired_listen_addr(id, addr); }, + None => quote! { self.#field_n.inject_expired_listen_addr(id, addr); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_new_external_addr()`. let inject_new_external_addr_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_new_external_addr(addr); }, - None => quote!{ self.#field_n.inject_new_external_addr(addr); }, + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_new_external_addr(addr); }, + None => quote! { self.#field_n.inject_new_external_addr(addr); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_expired_external_addr()`. let inject_expired_external_addr_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None; - } + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } - Some(match field.ident { - Some(ref i) => quote!{ self.#i.inject_expired_external_addr(addr); }, - None => quote!{ self.#field_n.inject_expired_external_addr(addr); }, + Some(match field.ident { + Some(ref i) => quote! { self.#i.inject_expired_external_addr(addr); }, + None => quote! { self.#field_n.inject_expired_external_addr(addr); }, + }) }) - }) }; // Build the list of statements to put in the body of `inject_listener_error()`. let inject_listener_error_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None - } - Some(match field.ident { - Some(ref i) => quote!(self.#i.inject_listener_error(id, err);), - None => quote!(self.#field_n.inject_listener_error(id, err);) + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } + Some(match field.ident { + Some(ref i) => quote!(self.#i.inject_listener_error(id, err);), + None => quote!(self.#field_n.inject_listener_error(id, err);), + }) }) - }) }; // Build the list of statements to put in the body of `inject_listener_closed()`. 
let inject_listener_closed_stmts = { - data_struct.fields.iter().enumerate().filter_map(move |(field_n, field)| { - if is_ignored(&field) { - return None - } - Some(match field.ident { - Some(ref i) => quote!(self.#i.inject_listener_closed(id, reason);), - None => quote!(self.#field_n.inject_listener_closed(id, reason);) + data_struct + .fields + .iter() + .enumerate() + .filter_map(move |(field_n, field)| { + if is_ignored(&field) { + return None; + } + Some(match field.ident { + Some(ref i) => quote!(self.#i.inject_listener_closed(id, reason);), + None => quote!(self.#field_n.inject_listener_closed(id, reason);), + }) }) - }) }; // Build the list of variants to put in the body of `inject_event()`. @@ -369,13 +420,13 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { continue; } let ty = &field.ty; - let field_info = quote!{ <#ty as #trait_to_impl>::ProtocolsHandler }; + let field_info = quote! { <#ty as #trait_to_impl>::ProtocolsHandler }; match ph_ty { - Some(ev) => ph_ty = Some(quote!{ #into_proto_select_ident<#ev, #field_info> }), + Some(ev) => ph_ty = Some(quote! { #into_proto_select_ident<#ev, #field_info> }), ref mut ev @ None => *ev = Some(field_info), } } - ph_ty.unwrap_or(quote!{()}) // TODO: `!` instead + ph_ty.unwrap_or(quote! {()}) // TODO: `!` instead }; // The content of `new_handler()`. @@ -389,8 +440,8 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { } let field_name = match field.ident { - Some(ref i) => quote!{ self.#i }, - None => quote!{ self.#field_n }, + Some(ref i) => quote! { self.#i }, + None => quote! { self.#field_n }, }; let builder = quote! { @@ -398,29 +449,33 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { }; match out_handler { - Some(h) => out_handler = Some(quote!{ #into_protocols_handler::select(#h, #builder) }), + Some(h) => { + out_handler = Some(quote! { #into_protocols_handler::select(#h, #builder) }) + } ref mut h @ None => *h = Some(builder), } } - out_handler.unwrap_or(quote!{()}) // TODO: incorrect + out_handler.unwrap_or(quote! {()}) // TODO: incorrect }; // The method to use to poll. // If we find a `#[behaviour(poll_method = "poll")]` attribute on the struct, we call // `self.poll()` at the end of the polling. let poll_method = { - let mut poll_method = quote!{std::task::Poll::Pending}; + let mut poll_method = quote! {std::task::Poll::Pending}; for meta_items in ast.attrs.iter().filter_map(get_meta_items) { for meta_item in meta_items { match meta_item { - syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) if m.path.is_ident("poll_method") => { + syn::NestedMeta::Meta(syn::Meta::NameValue(ref m)) + if m.path.is_ident("poll_method") => + { if let syn::Lit::Str(ref s) = m.lit { let ident: Ident = syn::parse_str(&s.value()).unwrap(); - poll_method = quote!{#name::#ident(self, cx, poll_params)}; + poll_method = quote! {#name::#ident(self, cx, poll_params)}; } } - _ => () + _ => (), } } } @@ -489,7 +544,7 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { }); // Now the magic happens. - let final_quote = quote!{ + let final_quote = quote! 
{ impl #impl_generics #trait_to_impl for #name #ty_generics #where_clause { @@ -609,7 +664,7 @@ fn is_ignored(field: &syn::Field) -> bool { syn::NestedMeta::Meta(syn::Meta::Path(ref m)) if m.is_ident("ignore") => { return true; } - _ => () + _ => (), } } } diff --git a/swarm-derive/tests/test.rs b/swarm-derive/tests/test.rs index e1913b7eab9..78a9ed985f9 100644 --- a/swarm-derive/tests/test.rs +++ b/swarm-derive/tests/test.rs @@ -43,8 +43,7 @@ fn one_field() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } #[allow(dead_code)] @@ -63,13 +62,11 @@ fn two_fields() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) { - } + fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } #[allow(dead_code)] @@ -91,18 +88,15 @@ fn three_fields() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) { - } + fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::kad::KademliaEvent) { - } + fn inject_event(&mut self, _: libp2p::kad::KademliaEvent) {} } #[allow(dead_code)] @@ -123,13 +117,11 @@ fn three_fields_non_last_ignored() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::kad::KademliaEvent) { - } + fn inject_event(&mut self, _: libp2p::kad::KademliaEvent) {} } #[allow(dead_code)] @@ -149,17 +141,21 @@ fn custom_polling() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) { - } + fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) {} } impl Foo { - fn foo(&mut self, _: &mut std::task::Context, _: &mut impl libp2p::swarm::PollParameters) -> std::task::Poll> { std::task::Poll::Pending } + fn foo( + &mut self, + _: &mut std::task::Context, + _: &mut impl libp2p::swarm::PollParameters, + ) -> std::task::Poll> { + std::task::Poll::Pending + } } #[allow(dead_code)] @@ -179,13 +175,11 @@ fn custom_event_no_polling() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) { - } + fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) {} } #[allow(dead_code)] @@ -205,17 +199,21 @@ fn custom_event_and_polling() { } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn 
inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } impl libp2p::swarm::NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) { - } + fn inject_event(&mut self, _: libp2p::identify::IdentifyEvent) {} } impl Foo { - fn foo(&mut self, _: &mut std::task::Context, _: &mut impl libp2p::swarm::PollParameters) -> std::task::Poll> { std::task::Poll::Pending } + fn foo( + &mut self, + _: &mut std::task::Context, + _: &mut impl libp2p::swarm::PollParameters, + ) -> std::task::Poll> { + std::task::Poll::Pending + } } #[allow(dead_code)] @@ -251,13 +249,11 @@ fn nested_derives_with_import() { } impl NetworkBehaviourEventProcess for Foo { - fn inject_event(&mut self, _: libp2p::ping::PingEvent) { - } + fn inject_event(&mut self, _: libp2p::ping::PingEvent) {} } impl NetworkBehaviourEventProcess<()> for Bar { - fn inject_event(&mut self, _: ()) { - } + fn inject_event(&mut self, _: ()) {} } #[allow(dead_code)] @@ -270,7 +266,7 @@ fn nested_derives_with_import() { fn event_process_false() { enum BehaviourOutEvent { Ping(libp2p::ping::PingEvent), - Identify(libp2p::identify::IdentifyEvent) + Identify(libp2p::identify::IdentifyEvent), } impl From for BehaviourOutEvent { @@ -302,7 +298,7 @@ fn event_process_false() { // check that the event is bubbled up all the way to swarm let _ = async { loop { - match _swarm.select_next_some().await { + match _swarm.select_next_some().await { SwarmEvent::Behaviour(BehaviourOutEvent::Ping(_)) => break, SwarmEvent::Behaviour(BehaviourOutEvent::Identify(_)) => break, _ => {} diff --git a/swarm/src/behaviour.rs b/swarm/src/behaviour.rs index 41cf11ffc8e..a21c7a023b8 100644 --- a/swarm/src/behaviour.rs +++ b/swarm/src/behaviour.rs @@ -18,9 +18,12 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{AddressScore, AddressRecord}; use crate::protocols_handler::{IntoProtocolsHandler, ProtocolsHandler}; -use libp2p_core::{ConnectedPoint, Multiaddr, PeerId, connection::{ConnectionId, ListenerId}}; +use crate::{AddressRecord, AddressScore}; +use libp2p_core::{ + connection::{ConnectionId, ListenerId}, + ConnectedPoint, Multiaddr, PeerId, +}; use std::{error, task::Context, task::Poll}; /// A behaviour for the network. Allows customizing the swarm. @@ -90,7 +93,7 @@ pub trait NetworkBehaviour: Send + 'static { /// /// This method is only called when the first connection to the peer is established, preceded by /// [`inject_connection_established`](NetworkBehaviour::inject_connection_established). - fn inject_connected(&mut self, _: &PeerId) { } + fn inject_connected(&mut self, _: &PeerId) {} /// Indicates to the behaviour that we disconnected from the node with the given peer id. /// @@ -99,19 +102,17 @@ pub trait NetworkBehaviour: Send + 'static { /// /// This method is only called when the last established connection to the peer is closed, /// preceded by [`inject_connection_closed`](NetworkBehaviour::inject_connection_closed). - fn inject_disconnected(&mut self, _: &PeerId) { } + fn inject_disconnected(&mut self, _: &PeerId) {} /// Informs the behaviour about a newly established connection to a peer. - fn inject_connection_established(&mut self, _: &PeerId, _: &ConnectionId, _: &ConnectedPoint) - {} + fn inject_connection_established(&mut self, _: &PeerId, _: &ConnectionId, _: &ConnectedPoint) {} /// Informs the behaviour about a closed connection to a peer. 
/// /// A call to this method is always paired with an earlier call to /// `inject_connection_established` with the same peer ID, connection ID and /// endpoint. - fn inject_connection_closed(&mut self, _: &PeerId, _: &ConnectionId, _: &ConnectedPoint) - {} + fn inject_connection_closed(&mut self, _: &PeerId, _: &ConnectionId, _: &ConnectedPoint) {} /// Informs the behaviour that the [`ConnectedPoint`] of an existing connection has changed. fn inject_address_change( @@ -119,8 +120,9 @@ pub trait NetworkBehaviour: Send + 'static { _: &PeerId, _: &ConnectionId, _old: &ConnectedPoint, - _new: &ConnectedPoint - ) {} + _new: &ConnectedPoint, + ) { + } /// Informs the behaviour about an event generated by the handler dedicated to the peer identified by `peer_id`. /// for the behaviour. @@ -131,14 +133,19 @@ pub trait NetworkBehaviour: Send + 'static { &mut self, peer_id: PeerId, connection: ConnectionId, - event: <::Handler as ProtocolsHandler>::OutEvent + event: <::Handler as ProtocolsHandler>::OutEvent, ); /// Indicates to the behaviour that we tried to reach an address, but failed. /// /// If we were trying to reach a specific node, its ID is passed as parameter. If this is the /// last address to attempt for the given node, then `inject_dial_failure` is called afterwards. - fn inject_addr_reach_failure(&mut self, _peer_id: Option<&PeerId>, _addr: &Multiaddr, _error: &dyn error::Error) { + fn inject_addr_reach_failure( + &mut self, + _peer_id: Option<&PeerId>, + _addr: &Multiaddr, + _error: &dyn error::Error, + ) { } /// Indicates to the behaviour that we tried to dial all the addresses known for a node, but @@ -146,37 +153,30 @@ pub trait NetworkBehaviour: Send + 'static { /// /// The `peer_id` is guaranteed to be in a disconnected state. In other words, /// `inject_connected` has not been called, or `inject_disconnected` has been called since then. - fn inject_dial_failure(&mut self, _peer_id: &PeerId) { - } + fn inject_dial_failure(&mut self, _peer_id: &PeerId) {} /// Indicates to the behaviour that a new listener was created. - fn inject_new_listener(&mut self, _id: ListenerId) { - } + fn inject_new_listener(&mut self, _id: ListenerId) {} /// Indicates to the behaviour that we have started listening on a new multiaddr. - fn inject_new_listen_addr(&mut self, _id: ListenerId, _addr: &Multiaddr) { - } + fn inject_new_listen_addr(&mut self, _id: ListenerId, _addr: &Multiaddr) {} /// Indicates to the behaviour that a multiaddr we were listening on has expired, /// which means that we are no longer listening in it. - fn inject_expired_listen_addr(&mut self, _id: ListenerId, _addr: &Multiaddr) { - } + fn inject_expired_listen_addr(&mut self, _id: ListenerId, _addr: &Multiaddr) {} /// A listener experienced an error. fn inject_listener_error(&mut self, _id: ListenerId, _err: &(dyn std::error::Error + 'static)) { } /// A listener closed. - fn inject_listener_closed(&mut self, _id: ListenerId, _reason: Result<(), &std::io::Error>) { - } + fn inject_listener_closed(&mut self, _id: ListenerId, _reason: Result<(), &std::io::Error>) {} /// Indicates to the behaviour that we have discovered a new external address for us. - fn inject_new_external_addr(&mut self, _addr: &Multiaddr) { - } + fn inject_new_external_addr(&mut self, _addr: &Multiaddr) {} /// Indicates to the behaviour that an external address was removed. - fn inject_expired_external_addr(&mut self, _addr: &Multiaddr) { - } + fn inject_expired_external_addr(&mut self, _addr: &Multiaddr) {} /// Polls for things that swarm should do. 
/// @@ -311,47 +311,71 @@ pub enum NetworkBehaviourAction { peer_id: PeerId, /// Whether to close a specific or all connections to the given peer. connection: CloseConnection, - } + }, } impl NetworkBehaviourAction { /// Map the handler event. pub fn map_in(self, f: impl FnOnce(TInEvent) -> E) -> NetworkBehaviourAction { match self { - NetworkBehaviourAction::GenerateEvent(e) => - NetworkBehaviourAction::GenerateEvent(e), - NetworkBehaviourAction::DialAddress { address } => - NetworkBehaviourAction::DialAddress { address }, - NetworkBehaviourAction::DialPeer { peer_id, condition } => - NetworkBehaviourAction::DialPeer { peer_id, condition }, - NetworkBehaviourAction::NotifyHandler { peer_id, handler, event } => - NetworkBehaviourAction::NotifyHandler { - peer_id, - handler, - event: f(event) - }, - NetworkBehaviourAction::ReportObservedAddr { address, score } => - NetworkBehaviourAction::ReportObservedAddr { address, score }, - NetworkBehaviourAction::CloseConnection { peer_id, connection } => - NetworkBehaviourAction::CloseConnection { peer_id, connection } + NetworkBehaviourAction::GenerateEvent(e) => NetworkBehaviourAction::GenerateEvent(e), + NetworkBehaviourAction::DialAddress { address } => { + NetworkBehaviourAction::DialAddress { address } + } + NetworkBehaviourAction::DialPeer { peer_id, condition } => { + NetworkBehaviourAction::DialPeer { peer_id, condition } + } + NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event, + } => NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event: f(event), + }, + NetworkBehaviourAction::ReportObservedAddr { address, score } => { + NetworkBehaviourAction::ReportObservedAddr { address, score } + } + NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + } => NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + }, } } /// Map the event the swarm will return. 
pub fn map_out(self, f: impl FnOnce(TOutEvent) -> E) -> NetworkBehaviourAction { match self { - NetworkBehaviourAction::GenerateEvent(e) => - NetworkBehaviourAction::GenerateEvent(f(e)), - NetworkBehaviourAction::DialAddress { address } => - NetworkBehaviourAction::DialAddress { address }, - NetworkBehaviourAction::DialPeer { peer_id, condition } => - NetworkBehaviourAction::DialPeer { peer_id, condition }, - NetworkBehaviourAction::NotifyHandler { peer_id, handler, event } => - NetworkBehaviourAction::NotifyHandler { peer_id, handler, event }, - NetworkBehaviourAction::ReportObservedAddr { address, score } => - NetworkBehaviourAction::ReportObservedAddr { address, score }, - NetworkBehaviourAction::CloseConnection { peer_id, connection } => - NetworkBehaviourAction::CloseConnection { peer_id, connection } + NetworkBehaviourAction::GenerateEvent(e) => NetworkBehaviourAction::GenerateEvent(f(e)), + NetworkBehaviourAction::DialAddress { address } => { + NetworkBehaviourAction::DialAddress { address } + } + NetworkBehaviourAction::DialPeer { peer_id, condition } => { + NetworkBehaviourAction::DialPeer { peer_id, condition } + } + NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event, + } => NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event, + }, + NetworkBehaviourAction::ReportObservedAddr { address, score } => { + NetworkBehaviourAction::ReportObservedAddr { address, score } + } + NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + } => NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + }, } } } diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index 778e01f1898..8b0fe09703f 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -63,62 +63,42 @@ pub mod protocols_handler; pub mod toggle; pub use behaviour::{ - NetworkBehaviour, - NetworkBehaviourAction, - NetworkBehaviourEventProcess, - PollParameters, - NotifyHandler, - DialPeerCondition, - CloseConnection + CloseConnection, DialPeerCondition, NetworkBehaviour, NetworkBehaviourAction, + NetworkBehaviourEventProcess, NotifyHandler, PollParameters, }; pub use protocols_handler::{ - IntoProtocolsHandler, - IntoProtocolsHandlerSelect, - KeepAlive, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerSelect, - ProtocolsHandlerUpgrErr, - OneShotHandler, - OneShotHandlerConfig, - SubstreamProtocol + IntoProtocolsHandler, IntoProtocolsHandlerSelect, KeepAlive, OneShotHandler, + OneShotHandlerConfig, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerSelect, + ProtocolsHandlerUpgrErr, SubstreamProtocol, }; -pub use registry::{AddressScore, AddressRecord, AddAddressResult}; - -use protocols_handler::{ - NodeHandlerWrapperBuilder, - NodeHandlerWrapperError, -}; -use futures::{ - prelude::*, - executor::ThreadPoolBuilder, - stream::FusedStream, -}; -use libp2p_core::{Executor, Multiaddr, Negotiated, PeerId, Transport, connection::{ - ConnectionError, - ConnectionId, - ConnectionLimit, - ConnectedPoint, - EstablishedConnection, - ConnectionHandler, - IntoConnectionHandler, - ListenerId, - PendingConnectionError, - Substream - }, muxing::StreamMuxerBox, network::{ - self, - ConnectionLimits, - Network, +pub use registry::{AddAddressResult, AddressRecord, AddressScore}; + +use futures::{executor::ThreadPoolBuilder, prelude::*, stream::FusedStream}; +use libp2p_core::{ + connection::{ + ConnectedPoint, ConnectionError, ConnectionHandler, ConnectionId, ConnectionLimit, + EstablishedConnection, IntoConnectionHandler, ListenerId, PendingConnectionError, + Substream, + }, + 
muxing::StreamMuxerBox, + network::{ + self, peer::ConnectedPeer, ConnectionLimits, Network, NetworkConfig, NetworkEvent, NetworkInfo, - NetworkEvent, - NetworkConfig, - peer::ConnectedPeer, - }, transport::{self, TransportError}, upgrade::{ProtocolName}}; -use registry::{Addresses, AddressIntoIter}; + }, + transport::{self, TransportError}, + upgrade::ProtocolName, + Executor, Multiaddr, Negotiated, PeerId, Transport, +}; +use protocols_handler::{NodeHandlerWrapperBuilder, NodeHandlerWrapperError}; +use registry::{AddressIntoIter, Addresses}; use smallvec::SmallVec; -use std::{error, fmt, io, pin::Pin, task::{Context, Poll}}; use std::collections::HashSet; use std::num::{NonZeroU32, NonZeroUsize}; +use std::{ + error, fmt, io, + pin::Pin, + task::{Context, Poll}, +}; use upgrade::UpgradeInfoSend as _; /// Substream for which a protocol has been chosen. @@ -136,13 +116,16 @@ type THandler = ::ProtocolsHandler; /// Custom event that can be received by the [`ProtocolsHandler`] of the /// [`NetworkBehaviour`]. -type THandlerInEvent = < as IntoProtocolsHandler>::Handler as ProtocolsHandler>::InEvent; +type THandlerInEvent = + < as IntoProtocolsHandler>::Handler as ProtocolsHandler>::InEvent; /// Custom event that can be produced by the [`ProtocolsHandler`] of the [`NetworkBehaviour`]. -type THandlerOutEvent = < as IntoProtocolsHandler>::Handler as ProtocolsHandler>::OutEvent; +type THandlerOutEvent = + < as IntoProtocolsHandler>::Handler as ProtocolsHandler>::OutEvent; /// Custom error that can be produced by the [`ProtocolsHandler`] of the [`NetworkBehaviour`]. -type THandlerErr = < as IntoProtocolsHandler>::Handler as ProtocolsHandler>::Error; +type THandlerErr = + < as IntoProtocolsHandler>::Handler as ProtocolsHandler>::Error; /// Event generated by the `Swarm`. #[derive(Debug)] @@ -228,18 +211,18 @@ pub enum SwarmEvent { error: PendingConnectionError, }, /// One of our listeners has reported a new local listening address. - NewListenAddr{ + NewListenAddr { /// The listener that is listening on the new address. listener_id: ListenerId, /// The new address that is being listened on. - address: Multiaddr + address: Multiaddr, }, /// One of our listeners has reported the expiration of a listening address. - ExpiredListenAddr{ + ExpiredListenAddr { /// The listener that is no longer listening on the address. listener_id: ListenerId, /// The expired address. - address: Multiaddr + address: Multiaddr, }, /// One of the listeners gracefully closed. ListenerClosed { @@ -308,11 +291,7 @@ where substream_upgrade_protocol_override: Option, } -impl Unpin for Swarm -where - TBehaviour: NetworkBehaviour, -{ -} +impl Unpin for Swarm where TBehaviour: NetworkBehaviour {} impl Swarm where @@ -322,7 +301,7 @@ where pub fn new( transport: transport::Boxed<(PeerId, StreamMuxerBox)>, behaviour: TBehaviour, - local_peer_id: PeerId + local_peer_id: PeerId, ) -> Self { SwarmBuilder::new(transport, behaviour, local_peer_id).build() } @@ -352,7 +331,9 @@ where /// Initiates a new dialing attempt to the given address. pub fn dial_addr(&mut self, addr: Multiaddr) -> Result<(), DialError> { - let handler = self.behaviour.new_handler() + let handler = self + .behaviour + .new_handler() .into_node_handler_builder() .with_substream_upgrade_protocol_override(self.substream_upgrade_protocol_override); Ok(self.network.dial(&addr, handler).map(|_id| ())?) 
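As a small, hypothetical caller-side sketch of the `dial_addr` path reformatted above (the concrete multiaddress is a placeholder and not part of this patch):

```rust
use libp2p::swarm::{DialError, NetworkBehaviour, Swarm};
use libp2p::Multiaddr;

/// Dials a fixed placeholder address on an already constructed `Swarm`.
fn dial_example<B: NetworkBehaviour>(swarm: &mut Swarm<B>) -> Result<(), DialError> {
    let addr: Multiaddr = "/ip4/127.0.0.1/tcp/4001"
        .parse()
        .expect("static multiaddr parses");
    // `dial_addr` builds a fresh handler via `NetworkBehaviour::new_handler`
    // and hands the address to the underlying `Network`.
    swarm.dial_addr(addr)
}
```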
@@ -362,31 +343,37 @@ where pub fn dial(&mut self, peer_id: &PeerId) -> Result<(), DialError> { if self.banned_peers.contains(peer_id) { self.behaviour.inject_dial_failure(peer_id); - return Err(DialError::Banned) + return Err(DialError::Banned); } let self_listening = &self.listened_addrs; - let mut addrs = self.behaviour.addresses_of_peer(peer_id) + let mut addrs = self + .behaviour + .addresses_of_peer(peer_id) .into_iter() .filter(|a| !self_listening.contains(a)); - let result = - if let Some(first) = addrs.next() { - let handler = self.behaviour.new_handler() - .into_node_handler_builder() - .with_substream_upgrade_protocol_override(self.substream_upgrade_protocol_override); - self.network.peer(*peer_id) - .dial(first, addrs, handler) - .map(|_| ()) - .map_err(DialError::from) - } else { - Err(DialError::NoAddresses) - }; + let result = if let Some(first) = addrs.next() { + let handler = self + .behaviour + .new_handler() + .into_node_handler_builder() + .with_substream_upgrade_protocol_override(self.substream_upgrade_protocol_override); + self.network + .peer(*peer_id) + .dial(first, addrs, handler) + .map(|_| ()) + .map_err(DialError::from) + } else { + Err(DialError::NoAddresses) + }; if let Err(error) = &result { log::debug!( "New dialing attempt to peer {:?} failed: {:?}.", - peer_id, error); + peer_id, + error + ); self.behaviour.inject_dial_failure(&peer_id); } @@ -508,9 +495,10 @@ where /// Internal function used by everything event-related. /// /// Polls the `Swarm` for the next event. - fn poll_next_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) - -> Poll>> - { + fn poll_next_event( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { // We use a `this` variable because the compiler can't mutably borrow multiple times // across a `Deref`. 
let this = &mut *self; @@ -525,38 +513,62 @@ where let peer = connection.peer_id(); let connection = connection.id(); this.behaviour.inject_event(peer, connection, event); - }, - Poll::Ready(NetworkEvent::AddressChange { connection, new_endpoint, old_endpoint }) => { + } + Poll::Ready(NetworkEvent::AddressChange { + connection, + new_endpoint, + old_endpoint, + }) => { let peer = connection.peer_id(); let connection = connection.id(); - this.behaviour.inject_address_change(&peer, &connection, &old_endpoint, &new_endpoint); - }, - Poll::Ready(NetworkEvent::ConnectionEstablished { connection, num_established }) => { + this.behaviour.inject_address_change( + &peer, + &connection, + &old_endpoint, + &new_endpoint, + ); + } + Poll::Ready(NetworkEvent::ConnectionEstablished { + connection, + num_established, + }) => { let peer_id = connection.peer_id(); let endpoint = connection.endpoint().clone(); if this.banned_peers.contains(&peer_id) { - this.network.peer(peer_id) + this.network + .peer(peer_id) .into_connected() .expect("the Network just notified us that we were connected; QED") .disconnect(); - return Poll::Ready(SwarmEvent::BannedPeer { - peer_id, - endpoint, - }); + return Poll::Ready(SwarmEvent::BannedPeer { peer_id, endpoint }); } else { - log::debug!("Connection established: {:?}; Total (peer): {}.", - connection.connected(), num_established); + log::debug!( + "Connection established: {:?}; Total (peer): {}.", + connection.connected(), + num_established + ); let endpoint = connection.endpoint().clone(); - this.behaviour.inject_connection_established(&peer_id, &connection.id(), &endpoint); + this.behaviour.inject_connection_established( + &peer_id, + &connection.id(), + &endpoint, + ); if num_established.get() == 1 { this.behaviour.inject_connected(&peer_id); } return Poll::Ready(SwarmEvent::ConnectionEstablished { - peer_id, num_established, endpoint + peer_id, + num_established, + endpoint, }); } - }, - Poll::Ready(NetworkEvent::ConnectionClosed { id, connected, error, num_established }) => { + } + Poll::Ready(NetworkEvent::ConnectionClosed { + id, + connected, + error, + num_established, + }) => { if let Some(error) = error.as_ref() { log::debug!("Connection {:?} closed: {:?}", connected, error); } else { @@ -564,7 +576,8 @@ where } let peer_id = connected.peer_id; let endpoint = connected.endpoint; - this.behaviour.inject_connection_closed(&peer_id, &id, &endpoint); + this.behaviour + .inject_connection_closed(&peer_id, &id, &endpoint); if num_established == 0 { this.behaviour.inject_disconnected(&peer_id); } @@ -574,11 +587,15 @@ where cause: error, num_established, }); - }, + } Poll::Ready(NetworkEvent::IncomingConnection { connection, .. 
}) => { - let handler = this.behaviour.new_handler() + let handler = this + .behaviour + .new_handler() .into_node_handler_builder() - .with_substream_upgrade_protocol_override(this.substream_upgrade_protocol_override); + .with_substream_upgrade_protocol_override( + this.substream_upgrade_protocol_override, + ); let local_addr = connection.local_addr.clone(); let send_back_addr = connection.send_back_addr.clone(); if let Err(e) = this.network.accept(connection, handler) { @@ -588,36 +605,55 @@ where local_addr, send_back_addr, }); - }, - Poll::Ready(NetworkEvent::NewListenerAddress { listener_id, listen_addr }) => { + } + Poll::Ready(NetworkEvent::NewListenerAddress { + listener_id, + listen_addr, + }) => { log::debug!("Listener {:?}; New address: {:?}", listener_id, listen_addr); if !this.listened_addrs.contains(&listen_addr) { this.listened_addrs.push(listen_addr.clone()) } - this.behaviour.inject_new_listen_addr(listener_id, &listen_addr); + this.behaviour + .inject_new_listen_addr(listener_id, &listen_addr); return Poll::Ready(SwarmEvent::NewListenAddr { listener_id, - address: listen_addr + address: listen_addr, }); } - Poll::Ready(NetworkEvent::ExpiredListenerAddress { listener_id, listen_addr }) => { - log::debug!("Listener {:?}; Expired address {:?}.", listener_id, listen_addr); + Poll::Ready(NetworkEvent::ExpiredListenerAddress { + listener_id, + listen_addr, + }) => { + log::debug!( + "Listener {:?}; Expired address {:?}.", + listener_id, + listen_addr + ); this.listened_addrs.retain(|a| a != &listen_addr); - this.behaviour.inject_expired_listen_addr(listener_id, &listen_addr); - return Poll::Ready(SwarmEvent::ExpiredListenAddr{ + this.behaviour + .inject_expired_listen_addr(listener_id, &listen_addr); + return Poll::Ready(SwarmEvent::ExpiredListenAddr { listener_id, - address: listen_addr + address: listen_addr, }); } - Poll::Ready(NetworkEvent::ListenerClosed { listener_id, addresses, reason }) => { + Poll::Ready(NetworkEvent::ListenerClosed { + listener_id, + addresses, + reason, + }) => { log::debug!("Listener {:?}; Closed by {:?}.", listener_id, reason); for addr in addresses.iter() { this.behaviour.inject_expired_listen_addr(listener_id, addr); } - this.behaviour.inject_listener_closed(listener_id, match &reason { - Ok(()) => Ok(()), - Err(err) => Err(err), - }); + this.behaviour.inject_listener_closed( + listener_id, + match &reason { + Ok(()) => Ok(()), + Err(err) => Err(err), + }, + ); return Poll::Ready(SwarmEvent::ListenerClosed { listener_id, addresses, @@ -626,24 +662,31 @@ where } Poll::Ready(NetworkEvent::ListenerError { listener_id, error }) => { this.behaviour.inject_listener_error(listener_id, &error); - return Poll::Ready(SwarmEvent::ListenerError { - listener_id, - error, - }); - }, - Poll::Ready(NetworkEvent::IncomingConnectionError { local_addr, send_back_addr, error }) => { + return Poll::Ready(SwarmEvent::ListenerError { listener_id, error }); + } + Poll::Ready(NetworkEvent::IncomingConnectionError { + local_addr, + send_back_addr, + error, + }) => { log::debug!("Incoming connection failed: {:?}", error); return Poll::Ready(SwarmEvent::IncomingConnectionError { local_addr, send_back_addr, error, }); - }, - Poll::Ready(NetworkEvent::DialError { peer_id, multiaddr, error, attempts_remaining }) => { + } + Poll::Ready(NetworkEvent::DialError { + peer_id, + multiaddr, + error, + attempts_remaining, + }) => { log::debug!( "Connection attempt to {:?} via {:?} failed with {:?}. 
Attempts remaining: {}.", peer_id, multiaddr, error, attempts_remaining); - this.behaviour.inject_addr_reach_failure(Some(&peer_id), &multiaddr, &error); + this.behaviour + .inject_addr_reach_failure(Some(&peer_id), &multiaddr, &error); if attempts_remaining == 0 { this.behaviour.inject_dial_failure(&peer_id); } @@ -653,16 +696,22 @@ where error, attempts_remaining, }); - }, - Poll::Ready(NetworkEvent::UnknownPeerDialError { multiaddr, error, .. }) => { - log::debug!("Connection attempt to address {:?} of unknown peer failed with {:?}", - multiaddr, error); - this.behaviour.inject_addr_reach_failure(None, &multiaddr, &error); + } + Poll::Ready(NetworkEvent::UnknownPeerDialError { + multiaddr, error, .. + }) => { + log::debug!( + "Connection attempt to address {:?} of unknown peer failed with {:?}", + multiaddr, + error + ); + this.behaviour + .inject_addr_reach_failure(None, &multiaddr, &error); return Poll::Ready(SwarmEvent::UnknownPeerUnreachableAddr { address: multiaddr, error, }); - }, + } } // After the network had a chance to make progress, try to deliver @@ -673,18 +722,21 @@ where if let Some((peer_id, handler, event)) = this.pending_event.take() { if let Some(mut peer) = this.network.peer(peer_id).into_connected() { match handler { - PendingNotifyHandler::One(conn_id) => + PendingNotifyHandler::One(conn_id) => { if let Some(mut conn) = peer.connection(conn_id) { if let Some(event) = notify_one(&mut conn, event, cx) { this.pending_event = Some((peer_id, handler, event)); - return Poll::Pending + return Poll::Pending; } - }, + } + } PendingNotifyHandler::Any(ids) => { - if let Some((event, ids)) = notify_any::<_, _, TBehaviour>(ids, &mut peer, event, cx) { + if let Some((event, ids)) = + notify_any::<_, _, TBehaviour>(ids, &mut peer, event, cx) + { let handler = PendingNotifyHandler::Any(ids); this.pending_event = Some((peer_id, handler, event)); - return Poll::Pending + return Poll::Pending; } } } @@ -698,7 +750,7 @@ where local_peer_id: &mut this.network.local_peer_id(), supported_protocols: &this.supported_protocols, listened_addrs: &this.listened_addrs, - external_addrs: &this.external_addrs + external_addrs: &this.external_addrs, }; this.behaviour.poll(cx, &mut parameters) }; @@ -708,29 +760,34 @@ where Poll::Pending => (), Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)) => { return Poll::Ready(SwarmEvent::Behaviour(event)) - }, + } Poll::Ready(NetworkBehaviourAction::DialAddress { address }) => { let _ = Swarm::dial_addr(&mut *this, address); - }, + } Poll::Ready(NetworkBehaviourAction::DialPeer { peer_id, condition }) => { if this.banned_peers.contains(&peer_id) { this.behaviour.inject_dial_failure(&peer_id); } else { let condition_matched = match condition { - DialPeerCondition::Disconnected => this.network.is_disconnected(&peer_id), + DialPeerCondition::Disconnected => { + this.network.is_disconnected(&peer_id) + } DialPeerCondition::NotDialing => !this.network.is_dialing(&peer_id), DialPeerCondition::Always => true, }; if condition_matched { if Swarm::dial(this, &peer_id).is_ok() { - return Poll::Ready(SwarmEvent::Dialing(peer_id)) + return Poll::Ready(SwarmEvent::Dialing(peer_id)); } } else { // Even if the condition for a _new_ dialing attempt is not met, // we always add any potentially new addresses of the peer to an // ongoing dialing attempt, if there is one. 
- log::trace!("Condition for new dialing attempt to {:?} not met: {:?}", - peer_id, condition); + log::trace!( + "Condition for new dialing attempt to {:?} not met: {:?}", + peer_id, + condition + ); let self_listening = &this.listened_addrs; if let Some(mut peer) = this.network.peer(peer_id).into_dialing() { let addrs = this.behaviour.addresses_of_peer(peer.id()); @@ -743,8 +800,12 @@ where } } } - }, - Poll::Ready(NetworkBehaviourAction::NotifyHandler { peer_id, handler, event }) => { + } + Poll::Ready(NetworkBehaviourAction::NotifyHandler { + peer_id, + handler, + event, + }) => { if let Some(mut peer) = this.network.peer(peer_id).into_connected() { match handler { NotifyHandler::One(connection) => { @@ -752,27 +813,32 @@ where if let Some(event) = notify_one(&mut conn, event, cx) { let handler = PendingNotifyHandler::One(connection); this.pending_event = Some((peer_id, handler, event)); - return Poll::Pending + return Poll::Pending; } } } NotifyHandler::Any => { let ids = peer.connections().into_ids().collect(); - if let Some((event, ids)) = notify_any::<_, _, TBehaviour>(ids, &mut peer, event, cx) { + if let Some((event, ids)) = + notify_any::<_, _, TBehaviour>(ids, &mut peer, event, cx) + { let handler = PendingNotifyHandler::Any(ids); this.pending_event = Some((peer_id, handler, event)); - return Poll::Pending + return Poll::Pending; } } } } - }, + } Poll::Ready(NetworkBehaviourAction::ReportObservedAddr { address, score }) => { for addr in this.network.address_translation(&address) { this.add_external_address(addr, score); } - }, - Poll::Ready(NetworkBehaviourAction::CloseConnection { peer_id, connection }) => { + } + Poll::Ready(NetworkBehaviourAction::CloseConnection { + peer_id, + connection, + }) => { if let Some(mut peer) = this.network.peer(peer_id).into_connected() { match connection { CloseConnection::One(connection_id) => { @@ -785,7 +851,7 @@ where } } } - }, + } } } } @@ -814,8 +880,7 @@ fn notify_one<'a, THandlerInEvent>( conn: &mut EstablishedConnection<'a, THandlerInEvent>, event: THandlerInEvent, cx: &mut Context<'_>, -) -> Option -{ +) -> Option { match conn.poll_ready_notify_handler(cx) { Poll::Pending => Some(event), Poll::Ready(Err(())) => None, // connection is closing @@ -847,7 +912,10 @@ where TTrans: Transport, TBehaviour: NetworkBehaviour, THandler: IntoConnectionHandler, - THandler::Handler: ConnectionHandler, OutEvent = THandlerOutEvent> + THandler::Handler: ConnectionHandler< + InEvent = THandlerInEvent, + OutEvent = THandlerOutEvent, + >, { let mut pending = SmallVec::new(); let mut event = Some(event); // (1) @@ -861,19 +929,20 @@ where if let Err(e) = conn.notify_handler(e) { event = Some(e) // (2) } else { - break + break; } } } } } - event.and_then(|e| + event.and_then(|e| { if !pending.is_empty() { Some((e, pending)) } else { None - }) + } + }) } /// Stream of events returned by [`Swarm`]. @@ -890,9 +959,7 @@ where type Item = SwarmEvent, THandlerErr>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.as_mut() - .poll_next_event(cx) - .map(Some) + self.as_mut().poll_next_event(cx).map(Some) } } @@ -957,7 +1024,7 @@ where pub fn new( transport: transport::Boxed<(PeerId, StreamMuxerBox)>, behaviour: TBehaviour, - local_peer_id: PeerId + local_peer_id: PeerId, ) -> Self { SwarmBuilder { local_peer_id, @@ -1042,7 +1109,8 @@ where /// Builds a `Swarm` with the current configuration. 
pub fn build(mut self) -> Swarm { - let supported_protocols = self.behaviour + let supported_protocols = self + .behaviour .new_handler() .inbound_protocol() .protocol_info() @@ -1051,20 +1119,19 @@ where .collect(); // If no executor has been explicitly configured, try to set up a thread pool. - let network_cfg = self.network_config.or_else_with_executor(|| { - match ThreadPoolBuilder::new() - .name_prefix("libp2p-swarm-task-") - .create() - { - Ok(tp) => { - Some(Box::new(move |f| tp.spawn_ok(f))) - }, - Err(err) => { - log::warn!("Failed to create executor thread pool: {:?}", err); - None + let network_cfg = + self.network_config.or_else_with_executor(|| { + match ThreadPoolBuilder::new() + .name_prefix("libp2p-swarm-task-") + .create() + { + Ok(tp) => Some(Box::new(move |f| tp.spawn_ok(f))), + Err(err) => { + log::warn!("Failed to create executor thread pool: {:?}", err); + None + } } - } - }); + }); let network = Network::new(self.transport, self.local_peer_id, network_cfg); @@ -1093,7 +1160,7 @@ pub enum DialError { InvalidAddress(Multiaddr), /// [`NetworkBehaviour::addresses_of_peer`] returned no addresses /// for the peer to dial. - NoAddresses + NoAddresses, } impl From for DialError { @@ -1111,7 +1178,7 @@ impl fmt::Display for DialError { DialError::ConnectionLimit(err) => write!(f, "Dial error: {}", err), DialError::NoAddresses => write!(f, "Dial error: no addresses for peer."), DialError::InvalidAddress(a) => write!(f, "Dial error: invalid address: {}", a), - DialError::Banned => write!(f, "Dial error: peer is banned.") + DialError::Banned => write!(f, "Dial error: peer is banned."), } } } @@ -1122,7 +1189,7 @@ impl error::Error for DialError { DialError::ConnectionLimit(err) => Some(err), DialError::InvalidAddress(_) => None, DialError::NoAddresses => None, - DialError::Banned => None + DialError::Banned => None, } } } @@ -1130,14 +1197,12 @@ impl error::Error for DialError { /// Dummy implementation of [`NetworkBehaviour`] that doesn't do anything. 
#[derive(Clone)] pub struct DummyBehaviour { - keep_alive: KeepAlive + keep_alive: KeepAlive, } impl DummyBehaviour { pub fn with_keep_alive(keep_alive: KeepAlive) -> Self { - Self { - keep_alive - } + Self { keep_alive } } pub fn keep_alive_mut(&mut self) -> &mut KeepAlive { @@ -1148,7 +1213,7 @@ impl DummyBehaviour { impl Default for DummyBehaviour { fn default() -> Self { Self { - keep_alive: KeepAlive::No + keep_alive: KeepAlive::No, } } } @@ -1159,7 +1224,7 @@ impl NetworkBehaviour for DummyBehaviour { fn new_handler(&mut self) -> Self::ProtocolsHandler { protocols_handler::DummyProtocolsHandler { - keep_alive: self.keep_alive + keep_alive: self.keep_alive, } } @@ -1167,31 +1232,33 @@ impl NetworkBehaviour for DummyBehaviour { &mut self, _: PeerId, _: ConnectionId, - event: ::OutEvent + event: ::OutEvent, ) { void::unreachable(event) } - fn poll(&mut self, _: &mut Context<'_>, _: &mut impl PollParameters) -> - Poll::InEvent, Self::OutEvent>> - { + fn poll( + &mut self, + _: &mut Context<'_>, + _: &mut impl PollParameters, + ) -> Poll< + NetworkBehaviourAction< + ::InEvent, + Self::OutEvent, + >, + > { Poll::Pending } } #[cfg(test)] mod tests { + use super::*; use crate::protocols_handler::DummyProtocolsHandler; - use crate::test::{MockBehaviour, CallTraceBehaviour}; - use futures::{future, executor}; - use libp2p_core::{ - identity, - multiaddr, - transport - }; + use crate::test::{CallTraceBehaviour, MockBehaviour}; + use futures::{executor, future}; + use libp2p_core::{identity, multiaddr, transport}; use libp2p_noise as noise; - use super::*; // Test execution state. // Connection => Disconnecting => Connecting. @@ -1204,11 +1271,13 @@ mod tests { where T: ProtocolsHandler + Clone, T::OutEvent: Clone, - O: Send + 'static + O: Send + 'static, { let id_keys = identity::Keypair::generate_ed25519(); let pubkey = id_keys.public(); - let noise_keys = noise::Keypair::::new().into_authentic(&id_keys).unwrap(); + let noise_keys = noise::Keypair::::new() + .into_authentic(&id_keys) + .unwrap(); let transport = transport::MemoryTransport::default() .upgrade() .authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated()) @@ -1274,7 +1343,9 @@ mod tests { fn test_connect_disconnect_ban() { // Since the test does not try to open any substreams, we can // use the dummy protocols handler. - let handler_proto = DummyProtocolsHandler { keep_alive: KeepAlive::Yes }; + let handler_proto = DummyProtocolsHandler { + keep_alive: KeepAlive::Yes, + }; let mut swarm1 = new_test_swarm::<_, ()>(handler_proto.clone()); let mut swarm2 = new_test_swarm::<_, ()>(handler_proto); @@ -1305,7 +1376,7 @@ mod tests { State::Connecting => { if swarms_connected(&swarm1, &swarm2, num_connections) { if banned { - return Poll::Ready(()) + return Poll::Ready(()); } swarm2.ban_peer_id(swarm1_id.clone()); swarm1.behaviour.reset(); @@ -1317,7 +1388,7 @@ mod tests { State::Disconnecting => { if swarms_disconnected(&swarm1, &swarm2, num_connections) { if unbanned { - return Poll::Ready(()) + return Poll::Ready(()); } // Unban the first peer and reconnect. swarm2.unban_peer_id(swarm1_id.clone()); @@ -1333,7 +1404,7 @@ mod tests { } if poll1.is_pending() && poll2.is_pending() { - return Poll::Pending + return Poll::Pending; } } })) @@ -1349,7 +1420,9 @@ mod tests { fn test_swarm_disconnect() { // Since the test does not try to open any substreams, we can // use the dummy protocols handler. 
- let handler_proto = DummyProtocolsHandler { keep_alive: KeepAlive::Yes }; + let handler_proto = DummyProtocolsHandler { + keep_alive: KeepAlive::Yes, + }; let mut swarm1 = new_test_swarm::<_, ()>(handler_proto.clone()); let mut swarm2 = new_test_swarm::<_, ()>(handler_proto); @@ -1370,41 +1443,41 @@ mod tests { } let mut state = State::Connecting; - executor::block_on(future::poll_fn(move |cx| { - loop { - let poll1 = Swarm::poll_next_event(Pin::new(&mut swarm1), cx); - let poll2 = Swarm::poll_next_event(Pin::new(&mut swarm2), cx); - match state { - State::Connecting => { - if swarms_connected(&swarm1, &swarm2, num_connections) { - if reconnected { - return Poll::Ready(()) - } - swarm2.disconnect_peer_id(swarm1_id.clone()).expect("Error disconnecting"); - swarm1.behaviour.reset(); - swarm2.behaviour.reset(); - state = State::Disconnecting; + executor::block_on(future::poll_fn(move |cx| loop { + let poll1 = Swarm::poll_next_event(Pin::new(&mut swarm1), cx); + let poll2 = Swarm::poll_next_event(Pin::new(&mut swarm2), cx); + match state { + State::Connecting => { + if swarms_connected(&swarm1, &swarm2, num_connections) { + if reconnected { + return Poll::Ready(()); } + swarm2 + .disconnect_peer_id(swarm1_id.clone()) + .expect("Error disconnecting"); + swarm1.behaviour.reset(); + swarm2.behaviour.reset(); + state = State::Disconnecting; } - State::Disconnecting => { - if swarms_disconnected(&swarm1, &swarm2, num_connections) { - if reconnected { - return Poll::Ready(()) - } - reconnected = true; - swarm1.behaviour.reset(); - swarm2.behaviour.reset(); - for _ in 0..num_connections { - swarm2.dial_addr(addr1.clone()).unwrap(); - } - state = State::Connecting; + } + State::Disconnecting => { + if swarms_disconnected(&swarm1, &swarm2, num_connections) { + if reconnected { + return Poll::Ready(()); + } + reconnected = true; + swarm1.behaviour.reset(); + swarm2.behaviour.reset(); + for _ in 0..num_connections { + swarm2.dial_addr(addr1.clone()).unwrap(); } + state = State::Connecting; } } + } - if poll1.is_pending() && poll2.is_pending() { - return Poll::Pending - } + if poll1.is_pending() && poll2.is_pending() { + return Poll::Pending; } })) } @@ -1420,7 +1493,9 @@ mod tests { fn test_behaviour_disconnect_all() { // Since the test does not try to open any substreams, we can // use the dummy protocols handler. 
- let handler_proto = DummyProtocolsHandler { keep_alive: KeepAlive::Yes }; + let handler_proto = DummyProtocolsHandler { + keep_alive: KeepAlive::Yes, + }; let mut swarm1 = new_test_swarm::<_, ()>(handler_proto.clone()); let mut swarm2 = new_test_swarm::<_, ()>(handler_proto); @@ -1441,48 +1516,44 @@ mod tests { } let mut state = State::Connecting; - executor::block_on(future::poll_fn(move |cx| { - loop { - let poll1 = Swarm::poll_next_event(Pin::new(&mut swarm1), cx); - let poll2 = Swarm::poll_next_event(Pin::new(&mut swarm2), cx); - match state { - State::Connecting => { - if swarms_connected(&swarm1, &swarm2, num_connections) { - if reconnected { - return Poll::Ready(()) - } - swarm2 - .behaviour - .inner() - .next_action - .replace(NetworkBehaviourAction::CloseConnection { - peer_id: swarm1_id.clone(), - connection: CloseConnection::All, - }); - swarm1.behaviour.reset(); - swarm2.behaviour.reset(); - state = State::Disconnecting; + executor::block_on(future::poll_fn(move |cx| loop { + let poll1 = Swarm::poll_next_event(Pin::new(&mut swarm1), cx); + let poll2 = Swarm::poll_next_event(Pin::new(&mut swarm2), cx); + match state { + State::Connecting => { + if swarms_connected(&swarm1, &swarm2, num_connections) { + if reconnected { + return Poll::Ready(()); } + swarm2.behaviour.inner().next_action.replace( + NetworkBehaviourAction::CloseConnection { + peer_id: swarm1_id.clone(), + connection: CloseConnection::All, + }, + ); + swarm1.behaviour.reset(); + swarm2.behaviour.reset(); + state = State::Disconnecting; } - State::Disconnecting => { - if swarms_disconnected(&swarm1, &swarm2, num_connections) { - if reconnected { - return Poll::Ready(()) - } - reconnected = true; - swarm1.behaviour.reset(); - swarm2.behaviour.reset(); - for _ in 0..num_connections { - swarm2.dial_addr(addr1.clone()).unwrap(); - } - state = State::Connecting; + } + State::Disconnecting => { + if swarms_disconnected(&swarm1, &swarm2, num_connections) { + if reconnected { + return Poll::Ready(()); } + reconnected = true; + swarm1.behaviour.reset(); + swarm2.behaviour.reset(); + for _ in 0..num_connections { + swarm2.dial_addr(addr1.clone()).unwrap(); + } + state = State::Connecting; } } + } - if poll1.is_pending() && poll2.is_pending() { - return Poll::Pending - } + if poll1.is_pending() && poll2.is_pending() { + return Poll::Pending; } })) } @@ -1498,7 +1569,9 @@ mod tests { fn test_behaviour_disconnect_one() { // Since the test does not try to open any substreams, we can // use the dummy protocols handler. 
- let handler_proto = DummyProtocolsHandler { keep_alive: KeepAlive::Yes }; + let handler_proto = DummyProtocolsHandler { + keep_alive: KeepAlive::Yes, + }; let mut swarm1 = new_test_swarm::<_, ()>(handler_proto.clone()); let mut swarm2 = new_test_swarm::<_, ()>(handler_proto); @@ -1519,49 +1592,48 @@ mod tests { let mut state = State::Connecting; let mut disconnected_conn_id = None; - executor::block_on(future::poll_fn(move |cx| { - loop { - let poll1 = Swarm::poll_next_event(Pin::new(&mut swarm1), cx); - let poll2 = Swarm::poll_next_event(Pin::new(&mut swarm2), cx); - match state { - State::Connecting => { - if swarms_connected(&swarm1, &swarm2, num_connections) { - disconnected_conn_id = { - let conn_id = swarm2.behaviour.inject_connection_established[num_connections / 2].1; - swarm2 - .behaviour - .inner() - .next_action - .replace(NetworkBehaviourAction::CloseConnection { - peer_id: swarm1_id.clone(), - connection: CloseConnection::One(conn_id), - }); - Some(conn_id) - }; - swarm1.behaviour.reset(); - swarm2.behaviour.reset(); - state = State::Disconnecting; - } + executor::block_on(future::poll_fn(move |cx| loop { + let poll1 = Swarm::poll_next_event(Pin::new(&mut swarm1), cx); + let poll2 = Swarm::poll_next_event(Pin::new(&mut swarm2), cx); + match state { + State::Connecting => { + if swarms_connected(&swarm1, &swarm2, num_connections) { + disconnected_conn_id = { + let conn_id = swarm2.behaviour.inject_connection_established + [num_connections / 2] + .1; + swarm2.behaviour.inner().next_action.replace( + NetworkBehaviourAction::CloseConnection { + peer_id: swarm1_id.clone(), + connection: CloseConnection::One(conn_id), + }, + ); + Some(conn_id) + }; + swarm1.behaviour.reset(); + swarm2.behaviour.reset(); + state = State::Disconnecting; } - State::Disconnecting => { - for s in &[&swarm1, &swarm2] { - assert_eq!(s.behaviour.inject_disconnected.len(), 0); - assert_eq!(s.behaviour.inject_connection_established.len(), 0); - assert_eq!(s.behaviour.inject_connected.len(), 0); - } - if [&swarm1, &swarm2].iter().all(|s| { - s.behaviour.inject_connection_closed.len() == 1 - }) { - let conn_id = swarm2.behaviour.inject_connection_closed[0].1; - assert_eq!(Some(conn_id), disconnected_conn_id); - return Poll::Ready(()); - } + } + State::Disconnecting => { + for s in &[&swarm1, &swarm2] { + assert_eq!(s.behaviour.inject_disconnected.len(), 0); + assert_eq!(s.behaviour.inject_connection_established.len(), 0); + assert_eq!(s.behaviour.inject_connected.len(), 0); + } + if [&swarm1, &swarm2] + .iter() + .all(|s| s.behaviour.inject_connection_closed.len() == 1) + { + let conn_id = swarm2.behaviour.inject_connection_closed[0].1; + assert_eq!(Some(conn_id), disconnected_conn_id); + return Poll::Ready(()); } } + } - if poll1.is_pending() && poll2.is_pending() { - return Poll::Pending - } + if poll1.is_pending() && poll2.is_pending() { + return Poll::Pending; } })) } diff --git a/swarm/src/protocols_handler.rs b/swarm/src/protocols_handler.rs index 58ae351673b..911693f32af 100644 --- a/swarm/src/protocols_handler.rs +++ b/swarm/src/protocols_handler.rs @@ -40,23 +40,14 @@ mod dummy; mod map_in; mod map_out; +pub mod multi; mod node_handler; mod one_shot; mod select; -pub mod multi; -pub use crate::upgrade::{ - InboundUpgradeSend, - OutboundUpgradeSend, - UpgradeInfoSend, -}; - -use libp2p_core::{ - ConnectedPoint, - Multiaddr, - PeerId, - upgrade::UpgradeError, -}; +pub use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend}; + +use libp2p_core::{upgrade::UpgradeError, 
ConnectedPoint, Multiaddr, PeerId}; use std::{cmp::Ordering, error, fmt, task::Context, task::Poll, time::Duration}; use wasm_timer::Instant; @@ -128,7 +119,7 @@ pub trait ProtocolsHandler: Send + 'static { fn inject_fully_negotiated_inbound( &mut self, protocol: ::Output, - info: Self::InboundOpenInfo + info: Self::InboundOpenInfo, ); /// Injects the output of a successful upgrade on a new outbound substream. @@ -138,7 +129,7 @@ pub trait ProtocolsHandler: Send + 'static { fn inject_fully_negotiated_outbound( &mut self, protocol: ::Output, - info: Self::OutboundOpenInfo + info: Self::OutboundOpenInfo, ); /// Injects an event coming from the outside in the handler. @@ -151,17 +142,16 @@ pub trait ProtocolsHandler: Send + 'static { fn inject_dial_upgrade_error( &mut self, info: Self::OutboundOpenInfo, - error: ProtocolsHandlerUpgrErr< - ::Error - > + error: ProtocolsHandlerUpgrErr<::Error>, ); /// Indicates to the handler that upgrading an inbound substream to the given protocol has failed. fn inject_listen_upgrade_error( &mut self, _: Self::InboundOpenInfo, - _: ProtocolsHandlerUpgrErr<::Error> - ) {} + _: ProtocolsHandlerUpgrErr<::Error>, + ) { + } /// Returns until when the connection should be kept alive. /// @@ -186,8 +176,16 @@ pub trait ProtocolsHandler: Send + 'static { fn connection_keep_alive(&self) -> KeepAlive; /// Should behave like `Stream::poll()`. - fn poll(&mut self, cx: &mut Context<'_>) -> Poll< - ProtocolsHandlerEvent + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, >; /// Adds a closure that turns the input event into something else. @@ -315,7 +313,7 @@ pub enum ProtocolsHandlerEvent + protocol: SubstreamProtocol, }, /// Close the connection for the given reason. @@ -341,7 +339,7 @@ impl match self { ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol } => { ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: protocol.map_info(map) + protocol: protocol.map_info(map), } } ProtocolsHandlerEvent::Custom(val) => ProtocolsHandlerEvent::Custom(val), @@ -361,7 +359,7 @@ impl match self { ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol } => { ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: protocol.map_upgrade(map) + protocol: protocol.map_upgrade(map), } } ProtocolsHandlerEvent::Custom(val) => ProtocolsHandlerEvent::Custom(val), @@ -419,12 +417,12 @@ impl ProtocolsHandlerUpgrErr { /// Map the inner [`UpgradeError`] type. 
pub fn map_upgrade_err(self, f: F) -> ProtocolsHandlerUpgrErr where - F: FnOnce(UpgradeError) -> UpgradeError + F: FnOnce(UpgradeError) -> UpgradeError, { match self { ProtocolsHandlerUpgrErr::Timeout => ProtocolsHandlerUpgrErr::Timeout, ProtocolsHandlerUpgrErr::Timer => ProtocolsHandlerUpgrErr::Timer, - ProtocolsHandlerUpgrErr::Upgrade(e) => ProtocolsHandlerUpgrErr::Upgrade(f(e)) + ProtocolsHandlerUpgrErr::Upgrade(e) => ProtocolsHandlerUpgrErr::Upgrade(f(e)), } } } @@ -437,10 +435,10 @@ where match self { ProtocolsHandlerUpgrErr::Timeout => { write!(f, "Timeout error while opening a substream") - }, + } ProtocolsHandlerUpgrErr::Timer => { write!(f, "Timer error while opening a substream") - }, + } ProtocolsHandlerUpgrErr::Upgrade(err) => write!(f, "{}", err), } } @@ -448,7 +446,7 @@ where impl error::Error for ProtocolsHandlerUpgrErr where - TUpgrErr: error::Error + 'static + TUpgrErr: error::Error + 'static, { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { @@ -467,7 +465,11 @@ pub trait IntoProtocolsHandler: Send + 'static { /// Builds the protocols handler. /// /// The `PeerId` is the id of the node the handler is going to handle. - fn into_handler(self, remote_peer_id: &PeerId, connected_point: &ConnectedPoint) -> Self::Handler; + fn into_handler( + self, + remote_peer_id: &PeerId, + connected_point: &ConnectedPoint, + ) -> Self::Handler; /// Return the handler's inbound protocol. fn inbound_protocol(&self) -> ::InboundProtocol; @@ -492,7 +494,8 @@ pub trait IntoProtocolsHandler: Send + 'static { } impl IntoProtocolsHandler for T -where T: ProtocolsHandler +where + T: ProtocolsHandler, { type Handler = Self; @@ -537,9 +540,9 @@ impl Ord for KeepAlive { use self::KeepAlive::*; match (self, other) { - (No, No) | (Yes, Yes) => Ordering::Equal, - (No, _) | (_, Yes) => Ordering::Less, - (_, No) | (Yes, _) => Ordering::Greater, + (No, No) | (Yes, Yes) => Ordering::Equal, + (No, _) | (_, Yes) => Ordering::Less, + (_, No) | (Yes, _) => Ordering::Greater, (Until(t1), Until(t2)) => t1.cmp(t2), } } diff --git a/swarm/src/protocols_handler/dummy.rs b/swarm/src/protocols_handler/dummy.rs index 764f95fe2cf..97dd55ce793 100644 --- a/swarm/src/protocols_handler/dummy.rs +++ b/swarm/src/protocols_handler/dummy.rs @@ -18,15 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
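The `Ord` impl for `KeepAlive`, whose match arms are merely realigned above, defines the total order `No < Until(_) < Yes`; that ordering is what lets combined handlers aggregate their keep-alive votes with `.max()` (see `connection_keep_alive` in the multi and select handlers later in this patch). A standalone sketch of the same semantics, with `u64` deadlines standing in for `wasm_timer::Instant`:

    use std::cmp::Ordering;

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum KeepAlive {
        Until(u64),
        Yes,
        No,
    }

    impl Ord for KeepAlive {
        fn cmp(&self, other: &KeepAlive) -> Ordering {
            use KeepAlive::*;
            match (self, other) {
                (No, No) | (Yes, Yes) => Ordering::Equal,
                (No, _) | (_, Yes) => Ordering::Less,
                (_, No) | (Yes, _) => Ordering::Greater,
                (Until(t1), Until(t2)) => t1.cmp(t2),
            }
        }
    }

    impl PartialOrd for KeepAlive {
        fn partial_cmp(&self, other: &KeepAlive) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    fn main() {
        // Aggregation as done by `connection_keep_alive` over several handlers:
        // the latest deadline wins among `Until` values...
        let votes = [KeepAlive::No, KeepAlive::Until(10), KeepAlive::Until(20)];
        assert_eq!(
            votes.iter().copied().max().unwrap_or(KeepAlive::No),
            KeepAlive::Until(20)
        );
        // ...and any handler voting `Yes` dominates.
        let votes = [KeepAlive::Until(10), KeepAlive::Yes];
        assert_eq!(votes.iter().copied().max(), Some(KeepAlive::Yes));
    }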
-use crate::NegotiatedSubstream; use crate::protocols_handler::{ - KeepAlive, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr + KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, SubstreamProtocol, +}; +use crate::NegotiatedSubstream; +use libp2p_core::{ + upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade}, + Multiaddr, }; -use libp2p_core::{Multiaddr, upgrade::{InboundUpgrade, OutboundUpgrade, DeniedUpgrade}}; use std::task::{Context, Poll}; use void::Void; @@ -39,7 +38,7 @@ pub struct DummyProtocolsHandler { impl Default for DummyProtocolsHandler { fn default() -> Self { DummyProtocolsHandler { - keep_alive: KeepAlive::No + keep_alive: KeepAlive::No, } } } @@ -60,14 +59,14 @@ impl ProtocolsHandler for DummyProtocolsHandler { fn inject_fully_negotiated_inbound( &mut self, _: >::Output, - _: Self::InboundOpenInfo + _: Self::InboundOpenInfo, ) { } fn inject_fully_negotiated_outbound( &mut self, _: >::Output, - _: Self::OutboundOpenInfo + _: Self::OutboundOpenInfo, ) { } @@ -75,9 +74,23 @@ impl ProtocolsHandler for DummyProtocolsHandler { fn inject_address_change(&mut self, _: &Multiaddr) {} - fn inject_dial_upgrade_error(&mut self, _: Self::OutboundOpenInfo, _: ProtocolsHandlerUpgrErr<>::Error>) {} + fn inject_dial_upgrade_error( + &mut self, + _: Self::OutboundOpenInfo, + _: ProtocolsHandlerUpgrErr< + >::Error, + >, + ) { + } - fn inject_listen_upgrade_error(&mut self, _: Self::InboundOpenInfo, _: ProtocolsHandlerUpgrErr<>::Error>) {} + fn inject_listen_upgrade_error( + &mut self, + _: Self::InboundOpenInfo, + _: ProtocolsHandlerUpgrErr< + >::Error, + >, + ) { + } fn connection_keep_alive(&self) -> KeepAlive { self.keep_alive @@ -87,7 +100,12 @@ impl ProtocolsHandler for DummyProtocolsHandler { &mut self, _: &mut Context<'_>, ) -> Poll< - ProtocolsHandlerEvent, + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, > { Poll::Pending } diff --git a/swarm/src/protocols_handler/map_in.rs b/swarm/src/protocols_handler/map_in.rs index 77ac5f912d9..1c1e436e42d 100644 --- a/swarm/src/protocols_handler/map_in.rs +++ b/swarm/src/protocols_handler/map_in.rs @@ -18,14 +18,10 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
-use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use crate::protocols_handler::{ - KeepAlive, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr + KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, SubstreamProtocol, }; +use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use libp2p_core::Multiaddr; use std::{fmt::Debug, marker::PhantomData, task::Context, task::Poll}; @@ -69,7 +65,7 @@ where fn inject_fully_negotiated_inbound( &mut self, protocol: ::Output, - info: Self::InboundOpenInfo + info: Self::InboundOpenInfo, ) { self.inner.inject_fully_negotiated_inbound(protocol, info) } @@ -77,7 +73,7 @@ where fn inject_fully_negotiated_outbound( &mut self, protocol: ::Output, - info: Self::OutboundOpenInfo + info: Self::OutboundOpenInfo, ) { self.inner.inject_fully_negotiated_outbound(protocol, info) } @@ -92,11 +88,19 @@ where self.inner.inject_address_change(addr) } - fn inject_dial_upgrade_error(&mut self, info: Self::OutboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_dial_upgrade_error( + &mut self, + info: Self::OutboundOpenInfo, + error: ProtocolsHandlerUpgrErr<::Error>, + ) { self.inner.inject_dial_upgrade_error(info, error) } - fn inject_listen_upgrade_error(&mut self, info: Self::InboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_listen_upgrade_error( + &mut self, + info: Self::InboundOpenInfo, + error: ProtocolsHandlerUpgrErr<::Error>, + ) { self.inner.inject_listen_upgrade_error(info, error) } @@ -108,7 +112,12 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll< - ProtocolsHandlerEvent, + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, > { self.inner.poll(cx) } diff --git a/swarm/src/protocols_handler/map_out.rs b/swarm/src/protocols_handler/map_out.rs index 9df2ace9256..77d0e1eac93 100644 --- a/swarm/src/protocols_handler/map_out.rs +++ b/swarm/src/protocols_handler/map_out.rs @@ -18,14 +18,10 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use crate::protocols_handler::{ - KeepAlive, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr + KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, SubstreamProtocol, }; +use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use libp2p_core::Multiaddr; use std::fmt::Debug; use std::task::{Context, Poll}; @@ -39,10 +35,7 @@ pub struct MapOutEvent { impl MapOutEvent { /// Creates a `MapOutEvent`. 
pub(crate) fn new(inner: TProtoHandler, map: TMap) -> Self { - MapOutEvent { - inner, - map, - } + MapOutEvent { inner, map } } } @@ -68,7 +61,7 @@ where fn inject_fully_negotiated_inbound( &mut self, protocol: ::Output, - info: Self::InboundOpenInfo + info: Self::InboundOpenInfo, ) { self.inner.inject_fully_negotiated_inbound(protocol, info) } @@ -76,7 +69,7 @@ where fn inject_fully_negotiated_outbound( &mut self, protocol: ::Output, - info: Self::OutboundOpenInfo + info: Self::OutboundOpenInfo, ) { self.inner.inject_fully_negotiated_outbound(protocol, info) } @@ -89,11 +82,19 @@ where self.inner.inject_address_change(addr) } - fn inject_dial_upgrade_error(&mut self, info: Self::OutboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_dial_upgrade_error( + &mut self, + info: Self::OutboundOpenInfo, + error: ProtocolsHandlerUpgrErr<::Error>, + ) { self.inner.inject_dial_upgrade_error(info, error) } - fn inject_listen_upgrade_error(&mut self, info: Self::InboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_listen_upgrade_error( + &mut self, + info: Self::InboundOpenInfo, + error: ProtocolsHandlerUpgrErr<::Error>, + ) { self.inner.inject_listen_upgrade_error(info, error) } @@ -105,15 +106,18 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll< - ProtocolsHandlerEvent, + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, > { - self.inner.poll(cx).map(|ev| { - match ev { - ProtocolsHandlerEvent::Custom(ev) => ProtocolsHandlerEvent::Custom((self.map)(ev)), - ProtocolsHandlerEvent::Close(err) => ProtocolsHandlerEvent::Close(err), - ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol } => { - ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol } - } + self.inner.poll(cx).map(|ev| match ev { + ProtocolsHandlerEvent::Custom(ev) => ProtocolsHandlerEvent::Custom((self.map)(ev)), + ProtocolsHandlerEvent::Close(err) => ProtocolsHandlerEvent::Close(err), + ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol } => { + ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol } } }) } diff --git a/swarm/src/protocols_handler/multi.rs b/swarm/src/protocols_handler/multi.rs index 64821ca3d35..f865443766c 100644 --- a/swarm/src/protocols_handler/multi.rs +++ b/swarm/src/protocols_handler/multi.rs @@ -21,23 +21,15 @@ //! A [`ProtocolsHandler`] implementation that combines multiple other `ProtocolsHandler`s //! indexed by some key. -use crate::NegotiatedSubstream; use crate::protocols_handler::{ - KeepAlive, - IntoProtocolsHandler, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr, - SubstreamProtocol -}; -use crate::upgrade::{ - InboundUpgradeSend, - OutboundUpgradeSend, - UpgradeInfoSend + IntoProtocolsHandler, KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, SubstreamProtocol, }; +use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend}; +use crate::NegotiatedSubstream; use futures::{future::BoxFuture, prelude::*}; +use libp2p_core::upgrade::{NegotiationError, ProtocolError, ProtocolName, UpgradeError}; use libp2p_core::{ConnectedPoint, Multiaddr, PeerId}; -use libp2p_core::upgrade::{ProtocolName, UpgradeError, NegotiationError, ProtocolError}; use rand::Rng; use std::{ cmp, @@ -47,19 +39,19 @@ use std::{ hash::Hash, iter::{self, FromIterator}, task::{Context, Poll}, - time::Duration + time::Duration, }; /// A [`ProtocolsHandler`] for multiple `ProtocolsHandler`s of the same type. 
#[derive(Clone)] pub struct MultiHandler { - handlers: HashMap + handlers: HashMap, } impl fmt::Debug for MultiHandler where K: fmt::Debug + Eq + Hash, - H: fmt::Debug + H: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MultiHandler") @@ -71,17 +63,23 @@ where impl MultiHandler where K: Hash + Eq, - H: ProtocolsHandler + H: ProtocolsHandler, { /// Create and populate a `MultiHandler` from the given handler iterator. /// /// It is an error for any two protocols handlers to share the same protocol name. pub fn try_from_iter(iter: I) -> Result where - I: IntoIterator + I: IntoIterator, { - let m = MultiHandler { handlers: HashMap::from_iter(iter) }; - uniq_proto_names(m.handlers.values().map(|h| h.listen_protocol().into_upgrade().0))?; + let m = MultiHandler { + handlers: HashMap::from_iter(iter), + }; + uniq_proto_names( + m.handlers + .values() + .map(|h| h.listen_protocol().into_upgrade().0), + )?; Ok(m) } } @@ -91,7 +89,7 @@ where K: Clone + Debug + Hash + Eq + Send + 'static, H: ProtocolsHandler, H::InboundProtocol: InboundUpgradeSend, - H::OutboundProtocol: OutboundUpgradeSend + H::OutboundProtocol: OutboundUpgradeSend, { type InEvent = (K, ::InEvent); type OutEvent = (K, ::OutEvent); @@ -102,28 +100,31 @@ where type OutboundOpenInfo = (K, ::OutboundOpenInfo); fn listen_protocol(&self) -> SubstreamProtocol { - let (upgrade, info, timeout) = self.handlers.iter() + let (upgrade, info, timeout) = self + .handlers + .iter() .map(|(key, handler)| { let proto = handler.listen_protocol(); let timeout = *proto.timeout(); let (upgrade, info) = proto.into_upgrade(); (key.clone(), (upgrade, info, timeout)) }) - .fold((Upgrade::new(), Info::new(), Duration::from_secs(0)), + .fold( + (Upgrade::new(), Info::new(), Duration::from_secs(0)), |(mut upg, mut inf, mut timeout), (k, (u, i, t))| { upg.upgrades.push((k.clone(), u)); inf.infos.push((k, i)); timeout = cmp::max(timeout, t); (upg, inf, timeout) - } + }, ); SubstreamProtocol::new(upgrade, info).with_timeout(timeout) } - fn inject_fully_negotiated_outbound ( + fn inject_fully_negotiated_outbound( &mut self, protocol: ::Output, - (key, arg): Self::OutboundOpenInfo + (key, arg): Self::OutboundOpenInfo, ) { if let Some(h) = self.handlers.get_mut(&key) { h.inject_fully_negotiated_outbound(protocol, arg) @@ -132,10 +133,10 @@ where } } - fn inject_fully_negotiated_inbound ( + fn inject_fully_negotiated_inbound( &mut self, (key, arg): ::Output, - mut info: Self::InboundOpenInfo + mut info: Self::InboundOpenInfo, ) { if let Some(h) = self.handlers.get_mut(&key) { if let Some(i) = info.take(&key) { @@ -160,10 +161,10 @@ where } } - fn inject_dial_upgrade_error ( + fn inject_dial_upgrade_error( &mut self, (key, arg): Self::OutboundOpenInfo, - error: ProtocolsHandlerUpgrErr<::Error> + error: ProtocolsHandlerUpgrErr<::Error>, ) { if let Some(h) = self.handlers.get_mut(&key) { h.inject_dial_upgrade_error(arg, error) @@ -175,77 +176,118 @@ where fn inject_listen_upgrade_error( &mut self, mut info: Self::InboundOpenInfo, - error: ProtocolsHandlerUpgrErr<::Error> + error: ProtocolsHandlerUpgrErr<::Error>, ) { match error { - ProtocolsHandlerUpgrErr::Timer => + ProtocolsHandlerUpgrErr::Timer => { for (k, h) in &mut self.handlers { if let Some(i) = info.take(k) { h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Timer) } } - ProtocolsHandlerUpgrErr::Timeout => + } + ProtocolsHandlerUpgrErr::Timeout => { for (k, h) in &mut self.handlers { if let Some(i) = info.take(k) { h.inject_listen_upgrade_error(i, 
ProtocolsHandlerUpgrErr::Timeout) } } - ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => + } + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { for (k, h) in &mut self.handlers { if let Some(i) = info.take(k) { - h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed))) + h.inject_listen_upgrade_error( + i, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select( + NegotiationError::Failed, + )), + ) } } - ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::ProtocolError(e))) => - match e { - ProtocolError::IoError(e) => - for (k, h) in &mut self.handlers { - if let Some(i) = info.take(k) { - let e = NegotiationError::ProtocolError(ProtocolError::IoError(e.kind().into())); - h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e))) - } + } + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select( + NegotiationError::ProtocolError(e), + )) => match e { + ProtocolError::IoError(e) => { + for (k, h) in &mut self.handlers { + if let Some(i) = info.take(k) { + let e = NegotiationError::ProtocolError(ProtocolError::IoError( + e.kind().into(), + )); + h.inject_listen_upgrade_error( + i, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)), + ) } - ProtocolError::InvalidMessage => - for (k, h) in &mut self.handlers { - if let Some(i) = info.take(k) { - let e = NegotiationError::ProtocolError(ProtocolError::InvalidMessage); - h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e))) - } + } + } + ProtocolError::InvalidMessage => { + for (k, h) in &mut self.handlers { + if let Some(i) = info.take(k) { + let e = NegotiationError::ProtocolError(ProtocolError::InvalidMessage); + h.inject_listen_upgrade_error( + i, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)), + ) } - ProtocolError::InvalidProtocol => - for (k, h) in &mut self.handlers { - if let Some(i) = info.take(k) { - let e = NegotiationError::ProtocolError(ProtocolError::InvalidProtocol); - h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e))) - } + } + } + ProtocolError::InvalidProtocol => { + for (k, h) in &mut self.handlers { + if let Some(i) = info.take(k) { + let e = NegotiationError::ProtocolError(ProtocolError::InvalidProtocol); + h.inject_listen_upgrade_error( + i, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)), + ) } - ProtocolError::TooManyProtocols => - for (k, h) in &mut self.handlers { - if let Some(i) = info.take(k) { - let e = NegotiationError::ProtocolError(ProtocolError::TooManyProtocols); - h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e))) - } + } + } + ProtocolError::TooManyProtocols => { + for (k, h) in &mut self.handlers { + if let Some(i) = info.take(k) { + let e = + NegotiationError::ProtocolError(ProtocolError::TooManyProtocols); + h.inject_listen_upgrade_error( + i, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)), + ) } + } } - ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply((k, e))) => + }, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply((k, e))) => { if let Some(h) = self.handlers.get_mut(&k) { if let Some(i) = info.take(&k) { - h.inject_listen_upgrade_error(i, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e))) + h.inject_listen_upgrade_error( + i, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)), + ) } } + } } } fn connection_keep_alive(&self) -> 
KeepAlive { - self.handlers.values() + self.handlers + .values() .map(|h| h.connection_keep_alive()) .max() .unwrap_or(KeepAlive::No) } - fn poll(&mut self, cx: &mut Context<'_>) - -> Poll> - { + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, + > { // Calling `gen_range(0, 0)` (see below) would panic, so we have return early to avoid // that situation. if self.handlers.is_empty() { @@ -257,15 +299,19 @@ where for (k, h) in self.handlers.iter_mut().skip(pos) { if let Poll::Ready(e) = h.poll(cx) { - let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p)); - return Poll::Ready(e) + let e = e + .map_outbound_open_info(|i| (k.clone(), i)) + .map_custom(|p| (k.clone(), p)); + return Poll::Ready(e); } } for (k, h) in self.handlers.iter_mut().take(pos) { if let Poll::Ready(e) = h.poll(cx) { - let e = e.map_outbound_open_info(|i| (k.clone(), i)).map_custom(|p| (k.clone(), p)); - return Poll::Ready(e) + let e = e + .map_outbound_open_info(|i| (k.clone(), i)) + .map_custom(|p| (k.clone(), p)); + return Poll::Ready(e); } } @@ -276,13 +322,13 @@ where /// A [`IntoProtocolsHandler`] for multiple other `IntoProtocolsHandler`s. #[derive(Clone)] pub struct IntoMultiHandler { - handlers: HashMap + handlers: HashMap, } impl fmt::Debug for IntoMultiHandler where K: fmt::Debug + Eq + Hash, - H: fmt::Debug + H: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("IntoMultiHandler") @@ -291,20 +337,21 @@ where } } - impl IntoMultiHandler where K: Hash + Eq, - H: IntoProtocolsHandler + H: IntoProtocolsHandler, { /// Create and populate an `IntoMultiHandler` from the given iterator. /// /// It is an error for any two protocols handlers to share the same protocol name. pub fn try_from_iter(iter: I) -> Result where - I: IntoIterator + I: IntoIterator, { - let m = IntoMultiHandler { handlers: HashMap::from_iter(iter) }; + let m = IntoMultiHandler { + handlers: HashMap::from_iter(iter), + }; uniq_proto_names(m.handlers.values().map(|h| h.inbound_protocol()))?; Ok(m) } @@ -313,23 +360,27 @@ where impl IntoProtocolsHandler for IntoMultiHandler where K: Debug + Clone + Eq + Hash + Send + 'static, - H: IntoProtocolsHandler + H: IntoProtocolsHandler, { type Handler = MultiHandler; fn into_handler(self, p: &PeerId, c: &ConnectedPoint) -> Self::Handler { MultiHandler { - handlers: self.handlers.into_iter() + handlers: self + .handlers + .into_iter() .map(|(k, h)| (k, h.into_handler(p, c))) - .collect() + .collect(), } } fn inbound_protocol(&self) -> ::InboundProtocol { Upgrade { - upgrades: self.handlers.iter() + upgrades: self + .handlers + .iter() .map(|(k, h)| (k.clone(), h.inbound_protocol())) - .collect() + .collect(), } } } @@ -347,7 +398,7 @@ impl ProtocolName for IndexedProtoName { /// The aggregated `InboundOpenInfo`s of supported inbound substream protocols. #[derive(Clone)] pub struct Info { - infos: Vec<(K, I)> + infos: Vec<(K, I)>, } impl Info { @@ -357,7 +408,7 @@ impl Info { pub fn take(&mut self, k: &K) -> Option { if let Some(p) = self.infos.iter().position(|(key, _)| key == k) { - return Some(self.infos.remove(p).1) + return Some(self.infos.remove(p).1); } None } @@ -366,19 +417,21 @@ impl Info { /// Inbound and outbound upgrade for all `ProtocolsHandler`s. 
#[derive(Clone)] pub struct Upgrade { - upgrades: Vec<(K, H)> + upgrades: Vec<(K, H)>, } impl Upgrade { fn new() -> Self { - Upgrade { upgrades: Vec::new() } + Upgrade { + upgrades: Vec::new(), + } } } impl fmt::Debug for Upgrade where K: fmt::Debug + Eq + Hash, - H: fmt::Debug + H: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Upgrade") @@ -390,13 +443,15 @@ where impl UpgradeInfoSend for Upgrade where H: UpgradeInfoSend, - K: Send + 'static + K: Send + 'static, { type Info = IndexedProtoName; type InfoIter = std::vec::IntoIter; fn protocol_info(&self) -> Self::InfoIter { - self.upgrades.iter().enumerate() + self.upgrades + .iter() + .enumerate() .map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info())) .flatten() .map(|(i, h)| IndexedProtoName(i, h)) @@ -408,21 +463,20 @@ where impl InboundUpgradeSend for Upgrade where H: InboundUpgradeSend, - K: Send + 'static + K: Send + 'static, { type Output = (K, ::Output); - type Error = (K, ::Error); + type Error = (K, ::Error); type Future = BoxFuture<'static, Result>; fn upgrade_inbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future { let IndexedProtoName(index, info) = info; let (key, upgrade) = self.upgrades.remove(index); - upgrade.upgrade_inbound(resource, info) - .map(move |out| { - match out { - Ok(o) => Ok((key, o)), - Err(e) => Err((key, e)) - } + upgrade + .upgrade_inbound(resource, info) + .map(move |out| match out { + Ok(o) => Ok((key, o)), + Err(e) => Err((key, e)), }) .boxed() } @@ -431,21 +485,20 @@ where impl OutboundUpgradeSend for Upgrade where H: OutboundUpgradeSend, - K: Send + 'static + K: Send + 'static, { type Output = (K, ::Output); - type Error = (K, ::Error); + type Error = (K, ::Error); type Future = BoxFuture<'static, Result>; fn upgrade_outbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future { let IndexedProtoName(index, info) = info; let (key, upgrade) = self.upgrades.remove(index); - upgrade.upgrade_outbound(resource, info) - .map(move |out| { - match out { - Ok(o) => Ok((key, o)), - Err(e) => Err((key, e)) - } + upgrade + .upgrade_outbound(resource, info) + .map(move |out| match out { + Ok(o) => Ok((key, o)), + Err(e) => Err((key, e)), }) .boxed() } @@ -455,14 +508,14 @@ where fn uniq_proto_names(iter: I) -> Result<(), DuplicateProtonameError> where I: Iterator, - T: UpgradeInfoSend + T: UpgradeInfoSend, { let mut set = HashSet::new(); for infos in iter { for i in infos.protocol_info() { let v = Vec::from(i.protocol_name()); if set.contains(&v) { - return Err(DuplicateProtonameError(v)) + return Err(DuplicateProtonameError(v)); } else { set.insert(v); } diff --git a/swarm/src/protocols_handler/node_handler.rs b/swarm/src/protocols_handler/node_handler.rs index 72730117cc3..edb383282cd 100644 --- a/swarm/src/protocols_handler/node_handler.rs +++ b/swarm/src/protocols_handler/node_handler.rs @@ -18,29 +18,22 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
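`MultiHandler::try_from_iter` and `IntoMultiHandler::try_from_iter` above both reject handler sets whose protocols collide, via `uniq_proto_names`. A minimal standalone version of that check — simplified in that it takes plain `&str` protocol names rather than iterating each upgrade's `protocol_info()`:

    use std::collections::HashSet;

    // Simplified error type; the real code stores the duplicate name as bytes.
    #[derive(Debug)]
    struct DuplicateProtonameError(String);

    fn uniq_proto_names<'a, I>(names: I) -> Result<(), DuplicateProtonameError>
    where
        I: IntoIterator<Item = &'a str>,
    {
        let mut seen = HashSet::new();
        for name in names {
            // `insert` returns `false` if the name was already present.
            if !seen.insert(name) {
                return Err(DuplicateProtonameError(name.to_string()));
            }
        }
        Ok(())
    }

    fn main() {
        assert!(uniq_proto_names(vec!["/ping/1.0.0", "/identify/1.0.0"]).is_ok());
        assert!(uniq_proto_names(vec!["/ping/1.0.0", "/ping/1.0.0"]).is_err());
    }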
-use crate::upgrade::SendWrapper; use crate::protocols_handler::{ - KeepAlive, - ProtocolsHandler, - IntoProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr + IntoProtocolsHandler, KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, }; +use crate::upgrade::SendWrapper; use futures::prelude::*; use futures::stream::FuturesUnordered; use libp2p_core::{ - Multiaddr, - Connected, connection::{ - ConnectionHandler, - ConnectionHandlerEvent, - IntoConnectionHandler, - Substream, + ConnectionHandler, ConnectionHandlerEvent, IntoConnectionHandler, Substream, SubstreamEndpoint, }, muxing::StreamMuxerBox, - upgrade::{self, InboundUpgradeApply, OutboundUpgradeApply, UpgradeError} + upgrade::{self, InboundUpgradeApply, OutboundUpgradeApply, UpgradeError}, + Connected, Multiaddr, }; use std::{error, fmt, pin::Pin, task::Context, task::Poll, time::Duration}; use wasm_timer::{Delay, Instant}; @@ -55,7 +48,7 @@ pub struct NodeHandlerWrapperBuilder { impl NodeHandlerWrapperBuilder where - TIntoProtoHandler: IntoProtocolsHandler + TIntoProtoHandler: IntoProtocolsHandler, { /// Builds a `NodeHandlerWrapperBuilder`. pub(crate) fn new(handler: TIntoProtoHandler) -> Self { @@ -67,7 +60,7 @@ where pub(crate) fn with_substream_upgrade_protocol_override( mut self, - version: Option + version: Option, ) -> Self { self.substream_upgrade_protocol_override = version; self @@ -84,7 +77,9 @@ where fn into_handler(self, connected: &Connected) -> Self::Handler { NodeHandlerWrapper { - handler: self.handler.into_handler(&connected.peer_id, &connected.endpoint), + handler: self + .handler + .into_handler(&connected.peer_id, &connected.endpoint), negotiating_in: Default::default(), negotiating_out: Default::default(), queued_dial_upgrades: Vec::new(), @@ -105,15 +100,25 @@ where /// The underlying handler. handler: TProtoHandler, /// Futures that upgrade incoming substreams. - negotiating_in: FuturesUnordered, SendWrapper>, - >>, + negotiating_in: FuturesUnordered< + SubstreamUpgrade< + TProtoHandler::InboundOpenInfo, + InboundUpgradeApply< + Substream, + SendWrapper, + >, + >, + >, /// Futures that upgrade outgoing substreams. - negotiating_out: FuturesUnordered, SendWrapper>, - >>, + negotiating_out: FuturesUnordered< + SubstreamUpgrade< + TProtoHandler::OutboundOpenInfo, + OutboundUpgradeApply< + Substream, + SendWrapper, + >, + >, + >, /// For each outbound substream request, how to upgrade it. The first element of the tuple /// is the unique identifier (see `unique_dial_upgrade_id`). 
queued_dial_upgrades: Vec<(u64, SendWrapper)>, @@ -137,28 +142,43 @@ impl Future for SubstreamUpgrad where Upgrade: Future>> + Unpin, { - type Output = (UserData, Result>); + type Output = ( + UserData, + Result>, + ); fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { match self.timeout.poll_unpin(cx) { - Poll::Ready(Ok(_)) => return Poll::Ready(( - self.user_data.take().expect("Future not to be polled again once ready."), - Err(ProtocolsHandlerUpgrErr::Timeout)), - ), - Poll::Ready(Err(_)) => return Poll::Ready(( - self.user_data.take().expect("Future not to be polled again once ready."), - Err(ProtocolsHandlerUpgrErr::Timer)), - ), - Poll::Pending => {}, + Poll::Ready(Ok(_)) => { + return Poll::Ready(( + self.user_data + .take() + .expect("Future not to be polled again once ready."), + Err(ProtocolsHandlerUpgrErr::Timeout), + )) + } + Poll::Ready(Err(_)) => { + return Poll::Ready(( + self.user_data + .take() + .expect("Future not to be polled again once ready."), + Err(ProtocolsHandlerUpgrErr::Timer), + )) + } + Poll::Pending => {} } match self.upgrade.poll_unpin(cx) { Poll::Ready(Ok(upgrade)) => Poll::Ready(( - self.user_data.take().expect("Future not to be polled again once ready."), + self.user_data + .take() + .expect("Future not to be polled again once ready."), Ok(upgrade), )), Poll::Ready(Err(err)) => Poll::Ready(( - self.user_data.take().expect("Future not to be polled again once ready."), + self.user_data + .take() + .expect("Future not to be polled again once ready."), Err(ProtocolsHandlerUpgrErr::Upgrade(err)), )), Poll::Pending => Poll::Pending, @@ -166,7 +186,6 @@ where } } - /// The options for a planned connection & handler shutdown. /// /// A shutdown is planned anew based on the the return value of @@ -182,7 +201,7 @@ enum Shutdown { /// A shut down is planned as soon as possible. Asap, /// A shut down is planned for when a `Delay` has elapsed. - Later(Delay, Instant) + Later(Delay, Instant), } /// Error generated by the `NodeHandlerWrapper`. 
@@ -202,20 +221,21 @@ impl From for NodeHandlerWrapperError { impl fmt::Display for NodeHandlerWrapperError where - TErr: fmt::Display + TErr: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { NodeHandlerWrapperError::Handler(err) => write!(f, "{}", err), - NodeHandlerWrapperError::KeepAliveTimeout => - write!(f, "Connection closed due to expired keep-alive timeout."), + NodeHandlerWrapperError::KeepAliveTimeout => { + write!(f, "Connection closed due to expired keep-alive timeout.") + } } } } impl error::Error for NodeHandlerWrapperError where - TErr: error::Error + 'static + TErr: error::Error + 'static, { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { @@ -272,7 +292,11 @@ where let mut version = upgrade::Version::default(); if let Some(v) = self.substream_upgrade_protocol_override { if v != version { - log::debug!("Substream upgrade protocol override: {:?} -> {:?}", version, v); + log::debug!( + "Substream upgrade protocol override: {:?} -> {:?}", + version, + v + ); version = v; } } @@ -295,19 +319,25 @@ where self.handler.inject_address_change(new_address); } - fn poll(&mut self, cx: &mut Context<'_>) -> Poll< - Result, Self::Error> - > { + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> + { while let Poll::Ready(Some((user_data, res))) = self.negotiating_in.poll_next_unpin(cx) { match res { - Ok(upgrade) => self.handler.inject_fully_negotiated_inbound(upgrade, user_data), + Ok(upgrade) => self + .handler + .inject_fully_negotiated_inbound(upgrade, user_data), Err(err) => self.handler.inject_listen_upgrade_error(user_data, err), } } while let Poll::Ready(Some((user_data, res))) = self.negotiating_out.poll_next_unpin(cx) { match res { - Ok(upgrade) => self.handler.inject_fully_negotiated_outbound(upgrade, user_data), + Ok(upgrade) => self + .handler + .inject_fully_negotiated_outbound(upgrade, user_data), Err(err) => self.handler.inject_dial_upgrade_error(user_data, err), } } @@ -319,14 +349,15 @@ where // Ask the handler whether it wants the connection (and the handler itself) // to be kept alive, which determines the planned shutdown, if any. match (&mut self.shutdown, self.handler.connection_keep_alive()) { - (Shutdown::Later(timer, deadline), KeepAlive::Until(t)) => + (Shutdown::Later(timer, deadline), KeepAlive::Until(t)) => { if *deadline != t { *deadline = t; timer.reset_at(t) - }, + } + } (_, KeepAlive::Until(t)) => self.shutdown = Shutdown::Later(Delay::new_at(t), t), (_, KeepAlive::No) => self.shutdown = Shutdown::Asap, - (_, KeepAlive::Yes) => self.shutdown = Shutdown::None + (_, KeepAlive::Yes) => self.shutdown = Shutdown::None, }; match poll_result { @@ -339,9 +370,9 @@ where self.unique_dial_upgrade_id += 1; let (upgrade, info) = protocol.into_upgrade(); self.queued_dial_upgrades.push((id, SendWrapper(upgrade))); - return Poll::Ready(Ok( - ConnectionHandlerEvent::OutboundSubstreamRequest((id, info, timeout)), - )); + return Poll::Ready(Ok(ConnectionHandlerEvent::OutboundSubstreamRequest(( + id, info, timeout, + )))); } Poll::Ready(ProtocolsHandlerEvent::Close(err)) => return Poll::Ready(Err(err.into())), Poll::Pending => (), @@ -351,12 +382,16 @@ where // As long as we're still negotiating substreams, shutdown is always postponed. 
if self.negotiating_in.is_empty() && self.negotiating_out.is_empty() { match self.shutdown { - Shutdown::None => {}, - Shutdown::Asap => return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)), + Shutdown::None => {} + Shutdown::Asap => { + return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)) + } Shutdown::Later(ref mut delay, _) => match Future::poll(Pin::new(delay), cx) { - Poll::Ready(_) => return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)), + Poll::Ready(_) => { + return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)) + } Poll::Pending => {} - } + }, } } diff --git a/swarm/src/protocols_handler/one_shot.rs b/swarm/src/protocols_handler/one_shot.rs index d19dd89d39e..01a2951efc5 100644 --- a/swarm/src/protocols_handler/one_shot.rs +++ b/swarm/src/protocols_handler/one_shot.rs @@ -18,14 +18,10 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use crate::protocols_handler::{ - KeepAlive, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr, - SubstreamProtocol + KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, SubstreamProtocol, }; +use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use smallvec::SmallVec; use std::{error, fmt::Debug, task::Context, task::Poll, time::Duration}; @@ -53,8 +49,7 @@ where config: OneShotHandlerConfig, } -impl - OneShotHandler +impl OneShotHandler where TOutbound: OutboundUpgradeSend, { @@ -102,8 +97,7 @@ where } } -impl Default - for OneShotHandler +impl Default for OneShotHandler where TOutbound: OutboundUpgradeSend, TInbound: InboundUpgradeSend + Default, @@ -111,7 +105,7 @@ where fn default() -> Self { OneShotHandler::new( SubstreamProtocol::new(Default::default(), ()), - OneShotHandlerConfig::default() + OneShotHandlerConfig::default(), ) } } @@ -128,9 +122,7 @@ where { type InEvent = TOutbound; type OutEvent = TEvent; - type Error = ProtocolsHandlerUpgrErr< - ::Error, - >; + type Error = ProtocolsHandlerUpgrErr<::Error>; type InboundProtocol = TInbound; type OutboundProtocol = TOutbound; type OutboundOpenInfo = (); @@ -143,7 +135,7 @@ where fn inject_fully_negotiated_inbound( &mut self, out: ::Output, - (): Self::InboundOpenInfo + (): Self::InboundOpenInfo, ) { // If we're shutting down the connection for inactivity, reset the timeout. 
if !self.keep_alive.is_yes() { @@ -169,9 +161,7 @@ where fn inject_dial_upgrade_error( &mut self, _info: Self::OutboundOpenInfo, - error: ProtocolsHandlerUpgrErr< - ::Error, - >, + error: ProtocolsHandlerUpgrErr<::Error>, ) { if self.pending_error.is_none() { self.pending_error = Some(error); @@ -186,16 +176,19 @@ where &mut self, _: &mut Context<'_>, ) -> Poll< - ProtocolsHandlerEvent, + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, > { if let Some(err) = self.pending_error.take() { - return Poll::Ready(ProtocolsHandlerEvent::Close(err)) + return Poll::Ready(ProtocolsHandlerEvent::Close(err)); } if !self.events_out.is_empty() { - return Poll::Ready(ProtocolsHandlerEvent::Custom( - self.events_out.remove(0) - )); + return Poll::Ready(ProtocolsHandlerEvent::Custom(self.events_out.remove(0))); } else { self.events_out.shrink_to_fit(); } @@ -204,12 +197,10 @@ where if self.dial_negotiated < self.config.max_dial_negotiated { self.dial_negotiated += 1; let upgrade = self.dial_queue.remove(0); - return Poll::Ready( - ProtocolsHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(upgrade, ()) - .with_timeout(self.config.outbound_substream_timeout) - }, - ); + return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(upgrade, ()) + .with_timeout(self.config.outbound_substream_timeout), + }); } } else { self.dial_queue.shrink_to_fit(); @@ -256,18 +247,19 @@ mod tests { #[test] fn do_not_keep_idle_connection_alive() { let mut handler: OneShotHandler<_, DeniedUpgrade, Void> = OneShotHandler::new( - SubstreamProtocol::new(DeniedUpgrade{}, ()), + SubstreamProtocol::new(DeniedUpgrade {}, ()), Default::default(), ); - block_on(poll_fn(|cx| { - loop { - if let Poll::Pending = handler.poll(cx) { - return Poll::Ready(()) - } + block_on(poll_fn(|cx| loop { + if let Poll::Pending = handler.poll(cx) { + return Poll::Ready(()); } })); - assert!(matches!(handler.connection_keep_alive(), KeepAlive::Until(_))); + assert!(matches!( + handler.connection_keep_alive(), + KeepAlive::Until(_) + )); } } diff --git a/swarm/src/protocols_handler/select.rs b/swarm/src/protocols_handler/select.rs index d8005eef79d..b5891c25d1f 100644 --- a/swarm/src/protocols_handler/select.rs +++ b/swarm/src/protocols_handler/select.rs @@ -18,22 +18,16 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::upgrade::{SendWrapper, InboundUpgradeSend, OutboundUpgradeSend}; use crate::protocols_handler::{ - KeepAlive, - SubstreamProtocol, - IntoProtocolsHandler, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr, + IntoProtocolsHandler, KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, SubstreamProtocol, }; +use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper}; use libp2p_core::{ - ConnectedPoint, - Multiaddr, - PeerId, either::{EitherError, EitherOutput}, - upgrade::{EitherUpgrade, SelectUpgrade, UpgradeError, NegotiationError, ProtocolError} + upgrade::{EitherUpgrade, NegotiationError, ProtocolError, SelectUpgrade, UpgradeError}, + ConnectedPoint, Multiaddr, PeerId, }; use std::{cmp, task::Context, task::Poll}; @@ -49,10 +43,7 @@ pub struct IntoProtocolsHandlerSelect { impl IntoProtocolsHandlerSelect { /// Builds a `IntoProtocolsHandlerSelect`. 
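The `OneShotHandler::poll` body reformatted above releases queued outbound requests only while `dial_negotiated` stays below `max_dial_negotiated`. A standalone sketch of that bookkeeping; `DialState`, `next_dial`, and `dial_done` are illustrative names (in the real handler the counter is decremented in `inject_fully_negotiated_outbound`), and the protocol strings are placeholders:

    struct DialState {
        dial_queue: Vec<&'static str>,
        dial_negotiated: u32,
        max_dial_negotiated: u32,
    }

    impl DialState {
        // Hand out the next queued protocol only if the in-flight limit allows it.
        fn next_dial(&mut self) -> Option<&'static str> {
            if !self.dial_queue.is_empty() && self.dial_negotiated < self.max_dial_negotiated {
                self.dial_negotiated += 1;
                Some(self.dial_queue.remove(0))
            } else {
                None
            }
        }

        // Called when an outbound negotiation finishes (successfully or not).
        fn dial_done(&mut self) {
            self.dial_negotiated -= 1;
        }
    }

    fn main() {
        let mut state = DialState {
            dial_queue: vec!["/ping/1.0.0", "/identify/1.0.0"],
            dial_negotiated: 0,
            max_dial_negotiated: 1,
        };
        assert_eq!(state.next_dial(), Some("/ping/1.0.0"));
        // The second request is held back until the first negotiation completes.
        assert_eq!(state.next_dial(), None);
        state.dial_done();
        assert_eq!(state.next_dial(), Some("/identify/1.0.0"));
    }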
pub(crate) fn new(proto1: TProto1, proto2: TProto2) -> Self { - IntoProtocolsHandlerSelect { - proto1, - proto2, - } + IntoProtocolsHandlerSelect { proto1, proto2 } } } @@ -63,7 +54,11 @@ where { type Handler = ProtocolsHandlerSelect; - fn into_handler(self, remote_peer_id: &PeerId, connected_point: &ConnectedPoint) -> Self::Handler { + fn into_handler( + self, + remote_peer_id: &PeerId, + connected_point: &ConnectedPoint, + ) -> Self::Handler { ProtocolsHandlerSelect { proto1: self.proto1.into_handler(remote_peer_id, connected_point), proto2: self.proto2.into_handler(remote_peer_id, connected_point), @@ -71,7 +66,10 @@ where } fn inbound_protocol(&self) -> ::InboundProtocol { - SelectUpgrade::new(SendWrapper(self.proto1.inbound_protocol()), SendWrapper(self.proto2.inbound_protocol())) + SelectUpgrade::new( + SendWrapper(self.proto1.inbound_protocol()), + SendWrapper(self.proto2.inbound_protocol()), + ) } } @@ -87,10 +85,7 @@ pub struct ProtocolsHandlerSelect { impl ProtocolsHandlerSelect { /// Builds a `ProtocolsHandlerSelect`. pub(crate) fn new(proto1: TProto1, proto2: TProto2) -> Self { - ProtocolsHandlerSelect { - proto1, - proto2, - } + ProtocolsHandlerSelect { proto1, proto2 } } } @@ -102,8 +97,14 @@ where type InEvent = EitherOutput; type OutEvent = EitherOutput; type Error = EitherError; - type InboundProtocol = SelectUpgrade::InboundProtocol>, SendWrapper<::InboundProtocol>>; - type OutboundProtocol = EitherUpgrade, SendWrapper>; + type InboundProtocol = SelectUpgrade< + SendWrapper<::InboundProtocol>, + SendWrapper<::InboundProtocol>, + >; + type OutboundProtocol = EitherUpgrade< + SendWrapper, + SendWrapper, + >; type OutboundOpenInfo = EitherOutput; type InboundOpenInfo = (TProto1::InboundOpenInfo, TProto2::InboundOpenInfo); @@ -117,25 +118,39 @@ where SubstreamProtocol::new(choice, (i1, i2)).with_timeout(timeout) } - fn inject_fully_negotiated_outbound(&mut self, protocol: ::Output, endpoint: Self::OutboundOpenInfo) { + fn inject_fully_negotiated_outbound( + &mut self, + protocol: ::Output, + endpoint: Self::OutboundOpenInfo, + ) { match (protocol, endpoint) { - (EitherOutput::First(protocol), EitherOutput::First(info)) => - self.proto1.inject_fully_negotiated_outbound(protocol, info), - (EitherOutput::Second(protocol), EitherOutput::Second(info)) => - self.proto2.inject_fully_negotiated_outbound(protocol, info), - (EitherOutput::First(_), EitherOutput::Second(_)) => - panic!("wrong API usage: the protocol doesn't match the upgrade info"), - (EitherOutput::Second(_), EitherOutput::First(_)) => + (EitherOutput::First(protocol), EitherOutput::First(info)) => { + self.proto1.inject_fully_negotiated_outbound(protocol, info) + } + (EitherOutput::Second(protocol), EitherOutput::Second(info)) => { + self.proto2.inject_fully_negotiated_outbound(protocol, info) + } + (EitherOutput::First(_), EitherOutput::Second(_)) => { panic!("wrong API usage: the protocol doesn't match the upgrade info") + } + (EitherOutput::Second(_), EitherOutput::First(_)) => { + panic!("wrong API usage: the protocol doesn't match the upgrade info") + } } } - fn inject_fully_negotiated_inbound(&mut self, protocol: ::Output, (i1, i2): Self::InboundOpenInfo) { + fn inject_fully_negotiated_inbound( + &mut self, + protocol: ::Output, + (i1, i2): Self::InboundOpenInfo, + ) { match protocol { - EitherOutput::First(protocol) => - self.proto1.inject_fully_negotiated_inbound(protocol, i1), - EitherOutput::Second(protocol) => + EitherOutput::First(protocol) => { + self.proto1.inject_fully_negotiated_inbound(protocol, i1) + 
} + EitherOutput::Second(protocol) => { self.proto2.inject_fully_negotiated_inbound(protocol, i2) + } } } @@ -151,60 +166,108 @@ where self.proto2.inject_address_change(new_address) } - fn inject_dial_upgrade_error(&mut self, info: Self::OutboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_dial_upgrade_error( + &mut self, + info: Self::OutboundOpenInfo, + error: ProtocolsHandlerUpgrErr<::Error>, + ) { match (info, error) { - (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Timer) => { - self.proto1.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timer) - }, - (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Timeout) => { - self.proto1.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timeout) - }, - (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err))) => { - self.proto1.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err))) - }, - (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(err)))) => { - self.proto1.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err))) - }, - (EitherOutput::First(_), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(_)))) => { + (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Timer) => self + .proto1 + .inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timer), + (EitherOutput::First(info), ProtocolsHandlerUpgrErr::Timeout) => self + .proto1 + .inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timeout), + ( + EitherOutput::First(info), + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err)), + ) => self.proto1.inject_dial_upgrade_error( + info, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err)), + ), + ( + EitherOutput::First(info), + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(err))), + ) => self.proto1.inject_dial_upgrade_error( + info, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err)), + ), + ( + EitherOutput::First(_), + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(_))), + ) => { panic!("Wrong API usage; the upgrade error doesn't match the outbound open info"); - }, - (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Timeout) => { - self.proto2.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timeout) - }, - (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Timer) => { - self.proto2.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timer) - }, - (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err))) => { - self.proto2.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err))) - }, - (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(err)))) => { - self.proto2.inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err))) - }, - (EitherOutput::Second(_), ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(_)))) => { + } + (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Timeout) => self + .proto2 + .inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timeout), + (EitherOutput::Second(info), ProtocolsHandlerUpgrErr::Timer) => self + .proto2 + .inject_dial_upgrade_error(info, ProtocolsHandlerUpgrErr::Timer), + ( + EitherOutput::Second(info), + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err)), + ) => self.proto2.inject_dial_upgrade_error( + info, + 
ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(err)), + ), + ( + EitherOutput::Second(info), + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(err))), + ) => self.proto2.inject_dial_upgrade_error( + info, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(err)), + ), + ( + EitherOutput::Second(_), + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(_))), + ) => { panic!("Wrong API usage; the upgrade error doesn't match the outbound open info"); - }, + } } } - fn inject_listen_upgrade_error(&mut self, (i1, i2): Self::InboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_listen_upgrade_error( + &mut self, + (i1, i2): Self::InboundOpenInfo, + error: ProtocolsHandlerUpgrErr<::Error>, + ) { match error { ProtocolsHandlerUpgrErr::Timer => { - self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Timer); - self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Timer) + self.proto1 + .inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Timer); + self.proto2 + .inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Timer) } ProtocolsHandlerUpgrErr::Timeout => { - self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Timeout); - self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Timeout) + self.proto1 + .inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Timeout); + self.proto2 + .inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Timeout) } ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { - self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed))); - self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed))); + self.proto1.inject_listen_upgrade_error( + i1, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select( + NegotiationError::Failed, + )), + ); + self.proto2.inject_listen_upgrade_error( + i2, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select( + NegotiationError::Failed, + )), + ); } - ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::ProtocolError(e))) => { + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select( + NegotiationError::ProtocolError(e), + )) => { let (e1, e2); match e { ProtocolError::IoError(e) => { - e1 = NegotiationError::ProtocolError(ProtocolError::IoError(e.kind().into())); + e1 = NegotiationError::ProtocolError(ProtocolError::IoError( + e.kind().into(), + )); e2 = NegotiationError::ProtocolError(ProtocolError::IoError(e)) } ProtocolError::InvalidMessage => { @@ -220,55 +283,80 @@ where e2 = NegotiationError::ProtocolError(ProtocolError::TooManyProtocols) } } - self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e1))); - self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e2))) + self.proto1.inject_listen_upgrade_error( + i1, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e1)), + ); + self.proto2.inject_listen_upgrade_error( + i2, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e2)), + ) } ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::A(e))) => { - self.proto1.inject_listen_upgrade_error(i1, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e))) + self.proto1.inject_listen_upgrade_error( + i1, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)), + ) } 
ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(EitherError::B(e))) => { - self.proto2.inject_listen_upgrade_error(i2, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e))) + self.proto2.inject_listen_upgrade_error( + i2, + ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)), + ) } } } fn connection_keep_alive(&self) -> KeepAlive { - cmp::max(self.proto1.connection_keep_alive(), self.proto2.connection_keep_alive()) + cmp::max( + self.proto1.connection_keep_alive(), + self.proto2.connection_keep_alive(), + ) } - fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, + > { match self.proto1.poll(cx) { Poll::Ready(ProtocolsHandlerEvent::Custom(event)) => { return Poll::Ready(ProtocolsHandlerEvent::Custom(EitherOutput::First(event))); - }, + } Poll::Ready(ProtocolsHandlerEvent::Close(event)) => { return Poll::Ready(ProtocolsHandlerEvent::Close(EitherError::A(event))); - }, + } Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol }) => { return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol: protocol .map_upgrade(|u| EitherUpgrade::A(SendWrapper(u))) - .map_info(EitherOutput::First) + .map_info(EitherOutput::First), }); - }, - Poll::Pending => () + } + Poll::Pending => (), }; match self.proto2.poll(cx) { Poll::Ready(ProtocolsHandlerEvent::Custom(event)) => { return Poll::Ready(ProtocolsHandlerEvent::Custom(EitherOutput::Second(event))); - }, + } Poll::Ready(ProtocolsHandlerEvent::Close(event)) => { return Poll::Ready(ProtocolsHandlerEvent::Close(EitherError::B(event))); - }, + } Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol }) => { return Poll::Ready(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol: protocol .map_upgrade(|u| EitherUpgrade::B(SendWrapper(u))) - .map_info(EitherOutput::Second) + .map_info(EitherOutput::Second), }); - }, - Poll::Pending => () + } + Poll::Pending => (), }; Poll::Pending diff --git a/swarm/src/registry.rs b/swarm/src/registry.rs index 310639296d8..5819ecf1e4e 100644 --- a/swarm/src/registry.rs +++ b/swarm/src/registry.rs @@ -20,8 +20,8 @@ use libp2p_core::Multiaddr; use smallvec::SmallVec; -use std::{collections::VecDeque, cmp::Ordering, num::NonZeroUsize}; use std::ops::{Add, Sub}; +use std::{cmp::Ordering, collections::VecDeque, num::NonZeroUsize}; /// A ranked collection of [`Multiaddr`] values. /// @@ -77,9 +77,7 @@ struct Report { impl AddressRecord { fn new(addr: Multiaddr, score: AddressScore) -> Self { - AddressRecord { - addr, score, - } + AddressRecord { addr, score } } } @@ -117,14 +115,10 @@ impl Ord for AddressScore { fn cmp(&self, other: &AddressScore) -> Ordering { // Semantics of cardinal numbers with a single infinite cardinal. 
         match (self, other) {
-            (AddressScore::Infinite, AddressScore::Infinite) =>
-                Ordering::Equal,
-            (AddressScore::Infinite, AddressScore::Finite(_)) =>
-                Ordering::Greater,
-            (AddressScore::Finite(_), AddressScore::Infinite) =>
-                Ordering::Less,
-            (AddressScore::Finite(a), AddressScore::Finite(b)) =>
-                a.cmp(b),
+            (AddressScore::Infinite, AddressScore::Infinite) => Ordering::Equal,
+            (AddressScore::Infinite, AddressScore::Finite(_)) => Ordering::Greater,
+            (AddressScore::Finite(_), AddressScore::Infinite) => Ordering::Less,
+            (AddressScore::Finite(a), AddressScore::Finite(b)) => a.cmp(b),
         }
     }
 }
@@ -135,14 +129,12 @@ impl Add for AddressScore {
     fn add(self, rhs: AddressScore) -> Self::Output {
         // Semantics of cardinal numbers with a single infinite cardinal.
         match (self, rhs) {
-            (AddressScore::Infinite, AddressScore::Infinite) =>
-                AddressScore::Infinite,
-            (AddressScore::Infinite, AddressScore::Finite(_)) =>
-                AddressScore::Infinite,
-            (AddressScore::Finite(_), AddressScore::Infinite) =>
-                AddressScore::Infinite,
-            (AddressScore::Finite(a), AddressScore::Finite(b)) =>
+            (AddressScore::Infinite, AddressScore::Infinite) => AddressScore::Infinite,
+            (AddressScore::Infinite, AddressScore::Finite(_)) => AddressScore::Infinite,
+            (AddressScore::Finite(_), AddressScore::Infinite) => AddressScore::Infinite,
+            (AddressScore::Finite(a), AddressScore::Finite(b)) => {
                 AddressScore::Finite(a.saturating_add(b))
+            }
         }
     }
 }
@@ -154,7 +146,7 @@ impl Sub for AddressScore {
         // Semantics of cardinal numbers with a single infinite cardinal.
         match self {
             AddressScore::Infinite => AddressScore::Infinite,
-            AddressScore::Finite(score) => AddressScore::Finite(score.saturating_sub(rhs))
+            AddressScore::Finite(score) => AddressScore::Finite(score.saturating_sub(rhs)),
         }
     }
 }
@@ -168,8 +160,12 @@ impl Default for Addresses {
 /// The result of adding an address to an ordered list of
 /// addresses with associated scores.
 pub enum AddAddressResult {
-    Inserted { expired: SmallVec<[AddressRecord; 8]> },
-    Updated { expired: SmallVec<[AddressRecord; 8]> },
+    Inserted {
+        expired: SmallVec<[AddressRecord; 8]>,
+    },
+    Updated {
+        expired: SmallVec<[AddressRecord; 8]>,
+    },
 }
 
 impl Addresses {
@@ -207,7 +203,12 @@ impl Addresses {
 
         // Remove addresses that have a score of 0.
         let mut expired = SmallVec::new();
-        while self.registry.last().map(|e| e.score.is_zero()).unwrap_or(false) {
+        while self
+            .registry
+            .last()
+            .map(|e| e.score.is_zero())
+            .unwrap_or(false)
+        {
             if let Some(addr) = self.registry.pop() {
                 expired.push(addr);
             }
@@ -215,7 +216,10 @@ impl Addresses {
 
         // If the address score is finite, remember this report.
         if let AddressScore::Finite(score) = score {
-            self.reports.push_back(Report { addr: addr.clone(), score });
+            self.reports.push_back(Report {
+                addr: addr.clone(),
+                score,
+            });
         }
 
         // If the address is already in the collection, increase its score.
@@ -223,7 +227,7 @@ impl Addresses {
             if r.addr == addr {
                 r.score = r.score + score;
                 isort(&mut self.registry);
-                return AddAddressResult::Updated { expired }
+                return AddAddressResult::Updated { expired };
             }
         }
 
@@ -249,14 +253,19 @@ impl Addresses {
     ///
     /// The iteration is ordered by descending score.
    pub fn iter(&self) -> AddressIter<'_> {
-        AddressIter { items: &self.registry, offset: 0 }
+        AddressIter {
+            items: &self.registry,
+            offset: 0,
+        }
     }
 
     /// Return an iterator over all [`Multiaddr`] values.
     ///
     /// The iteration is ordered by descending score.
pub fn into_iter(self) -> AddressIntoIter { - AddressIntoIter { items: self.registry } + AddressIntoIter { + items: self.registry, + } } } @@ -264,7 +273,7 @@ impl Addresses { #[derive(Clone)] pub struct AddressIter<'a> { items: &'a [AddressRecord], - offset: usize + offset: usize, } impl<'a> Iterator for AddressIter<'a> { @@ -272,7 +281,7 @@ impl<'a> Iterator for AddressIter<'a> { fn next(&mut self) -> Option { if self.offset == self.items.len() { - return None + return None; } let item = &self.items[self.offset]; self.offset += 1; @@ -314,10 +323,10 @@ impl ExactSizeIterator for AddressIntoIter {} // Reverse insertion sort. fn isort(xs: &mut [AddressRecord]) { - for i in 1 .. xs.len() { - for j in (1 ..= i).rev() { + for i in 1..xs.len() { + for j in (1..=i).rev() { if xs[j].score <= xs[j - 1].score { - break + break; } xs.swap(j, j - 1) } @@ -326,15 +335,16 @@ fn isort(xs: &mut [AddressRecord]) { #[cfg(test)] mod tests { + use super::*; use libp2p_core::multiaddr::{Multiaddr, Protocol}; use quickcheck::*; use rand::Rng; - use std::num::{NonZeroUsize, NonZeroU8}; - use super::*; + use std::num::{NonZeroU8, NonZeroUsize}; impl Arbitrary for AddressScore { fn arbitrary(g: &mut G) -> AddressScore { - if g.gen_range(0, 10) == 0 { // ~10% "Infinitely" scored addresses + if g.gen_range(0, 10) == 0 { + // ~10% "Infinitely" scored addresses AddressScore::Infinite } else { AddressScore::Finite(g.gen()) @@ -353,13 +363,14 @@ mod tests { #[test] fn isort_sorts() { fn property(xs: Vec) { - let mut xs = xs.into_iter() + let mut xs = xs + .into_iter() .map(|score| AddressRecord::new(Multiaddr::empty(), score)) .collect::>(); isort(&mut xs); - for i in 1 .. xs.len() { + for i in 1..xs.len() { assert!(xs[i - 1].score >= xs[i].score) } } @@ -371,7 +382,7 @@ mod tests { fn score_retention() { fn prop(first: AddressRecord, other: AddressRecord) -> TestResult { if first.addr == other.addr { - return TestResult::discard() + return TestResult::discard(); } let mut addresses = Addresses::default(); @@ -383,7 +394,7 @@ mod tests { // Add another address so often that the initial report of // the first address may be purged and, since it was the // only report, the address removed. - for _ in 0 .. addresses.limit.get() + 1 { + for _ in 0..addresses.limit.get() + 1 { addresses.add(other.addr.clone(), other.score); } @@ -398,7 +409,7 @@ mod tests { TestResult::passed() } - quickcheck(prop as fn(_,_) -> _); + quickcheck(prop as fn(_, _) -> _); } #[test] @@ -412,16 +423,22 @@ mod tests { } // Count the finitely scored addresses. - let num_finite = addresses.iter().filter(|r| match r { - AddressRecord { score: AddressScore::Finite(_), .. } => true, - _ => false, - }).count(); + let num_finite = addresses + .iter() + .filter(|r| match r { + AddressRecord { + score: AddressScore::Finite(_), + .. + } => true, + _ => false, + }) + .count(); // Check against the limit. assert!(num_finite <= limit.get() as usize); } - quickcheck(prop as fn(_,_)); + quickcheck(prop as fn(_, _)); } #[test] @@ -438,16 +455,16 @@ mod tests { // Check that each address in the registry has the expected score. 
for r in &addresses.registry { - let expected_score = records.iter().fold( - None::, |sum, rec| - if &rec.addr == &r.addr { - sum.map_or(Some(rec.score), |s| Some(s + rec.score)) - } else { - sum - }); + let expected_score = records.iter().fold(None::, |sum, rec| { + if &rec.addr == &r.addr { + sum.map_or(Some(rec.score), |s| Some(s + rec.score)) + } else { + sum + } + }); if Some(r.score) != expected_score { - return false + return false; } } diff --git a/swarm/src/test.rs b/swarm/src/test.rs index 4ae647d38db..5cb05d7baf3 100644 --- a/swarm/src/test.rs +++ b/swarm/src/test.rs @@ -19,17 +19,13 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - NetworkBehaviour, - NetworkBehaviourAction, + IntoProtocolsHandler, NetworkBehaviour, NetworkBehaviourAction, PollParameters, ProtocolsHandler, - IntoProtocolsHandler, - PollParameters }; use libp2p_core::{ - ConnectedPoint, - PeerId, connection::{ConnectionId, ListenerId}, multiaddr::Multiaddr, + ConnectedPoint, PeerId, }; use std::collections::HashMap; use std::task::{Context, Poll}; @@ -54,7 +50,7 @@ where impl MockBehaviour where - THandler: ProtocolsHandler + THandler: ProtocolsHandler, { pub fn new(handler_proto: THandler) -> Self { MockBehaviour { @@ -82,12 +78,13 @@ where self.addresses.get(p).map_or(Vec::new(), |v| v.clone()) } - fn inject_event(&mut self, _: PeerId, _: ConnectionId, _: THandler::OutEvent) { - } + fn inject_event(&mut self, _: PeerId, _: ConnectionId, _: THandler::OutEvent) {} - fn poll(&mut self, _: &mut Context, _: &mut impl PollParameters) -> - Poll> - { + fn poll( + &mut self, + _: &mut Context, + _: &mut impl PollParameters, + ) -> Poll> { self.next_action.take().map_or(Poll::Pending, Poll::Ready) } } @@ -106,7 +103,11 @@ where pub inject_disconnected: Vec, pub inject_connection_established: Vec<(PeerId, ConnectionId, ConnectedPoint)>, pub inject_connection_closed: Vec<(PeerId, ConnectionId, ConnectedPoint)>, - pub inject_event: Vec<(PeerId, ConnectionId, <::Handler as ProtocolsHandler>::OutEvent)>, + pub inject_event: Vec<( + PeerId, + ConnectionId, + <::Handler as ProtocolsHandler>::OutEvent, + )>, pub inject_addr_reach_failure: Vec<(Option, Multiaddr)>, pub inject_dial_failure: Vec, pub inject_new_listener: Vec, @@ -121,7 +122,7 @@ where impl CallTraceBehaviour where - TInner: NetworkBehaviour + TInner: NetworkBehaviour, { pub fn new(inner: TInner) -> Self { Self { @@ -162,13 +163,16 @@ where self.poll = 0; } - pub fn inner(&mut self) -> &mut TInner { &mut self.inner } + pub fn inner(&mut self) -> &mut TInner { + &mut self.inner + } } impl NetworkBehaviour for CallTraceBehaviour where TInner: NetworkBehaviour, - <::Handler as ProtocolsHandler>::OutEvent: Clone, + <::Handler as ProtocolsHandler>::OutEvent: + Clone, { type ProtocolsHandler = TInner::ProtocolsHandler; type OutEvent = TInner::OutEvent; @@ -188,7 +192,8 @@ where } fn inject_connection_established(&mut self, p: &PeerId, c: &ConnectionId, e: &ConnectedPoint) { - self.inject_connection_established.push((p.clone(), c.clone(), e.clone())); + self.inject_connection_established + .push((p.clone(), c.clone(), e.clone())); self.inner.inject_connection_established(p, c, e); } @@ -198,16 +203,27 @@ where } fn inject_connection_closed(&mut self, p: &PeerId, c: &ConnectionId, e: &ConnectedPoint) { - self.inject_connection_closed.push((p.clone(), c.clone(), e.clone())); + self.inject_connection_closed + .push((p.clone(), c.clone(), e.clone())); self.inner.inject_connection_closed(p, c, e); } - fn inject_event(&mut self, p: PeerId, c: ConnectionId, e: <::Handler as 
ProtocolsHandler>::OutEvent) { + fn inject_event( + &mut self, + p: PeerId, + c: ConnectionId, + e: <::Handler as ProtocolsHandler>::OutEvent, + ) { self.inject_event.push((p.clone(), c.clone(), e.clone())); self.inner.inject_event(p, c, e); } - fn inject_addr_reach_failure(&mut self, p: Option<&PeerId>, a: &Multiaddr, e: &dyn std::error::Error) { + fn inject_addr_reach_failure( + &mut self, + p: Option<&PeerId>, + a: &Multiaddr, + e: &dyn std::error::Error, + ) { self.inject_addr_reach_failure.push((p.cloned(), a.clone())); self.inner.inject_addr_reach_failure(p, a, e); } diff --git a/swarm/src/toggle.rs b/swarm/src/toggle.rs index d986f00fb01..5a86a4824ed 100644 --- a/swarm/src/toggle.rs +++ b/swarm/src/toggle.rs @@ -18,24 +18,20 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::{NetworkBehaviour, NetworkBehaviourAction, NetworkBehaviourEventProcess, PollParameters}; -use crate::upgrade::{SendWrapper, InboundUpgradeSend, OutboundUpgradeSend}; use crate::protocols_handler::{ - KeepAlive, - SubstreamProtocol, - ProtocolsHandler, - ProtocolsHandlerEvent, - ProtocolsHandlerUpgrErr, - IntoProtocolsHandler + IntoProtocolsHandler, KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, + ProtocolsHandlerUpgrErr, SubstreamProtocol, +}; +use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper}; +use crate::{ + NetworkBehaviour, NetworkBehaviourAction, NetworkBehaviourEventProcess, PollParameters, }; use either::Either; use libp2p_core::{ - ConnectedPoint, - PeerId, - Multiaddr, connection::{ConnectionId, ListenerId}, either::{EitherError, EitherOutput}, - upgrade::{DeniedUpgrade, EitherUpgrade} + upgrade::{DeniedUpgrade, EitherUpgrade}, + ConnectedPoint, Multiaddr, PeerId, }; use std::{error, task::Context, task::Poll}; @@ -71,19 +67,22 @@ impl From> for Toggle { impl NetworkBehaviour for Toggle where - TBehaviour: NetworkBehaviour + TBehaviour: NetworkBehaviour, { type ProtocolsHandler = ToggleIntoProtoHandler; type OutEvent = TBehaviour::OutEvent; fn new_handler(&mut self) -> Self::ProtocolsHandler { ToggleIntoProtoHandler { - inner: self.inner.as_mut().map(|i| i.new_handler()) + inner: self.inner.as_mut().map(|i| i.new_handler()), } } fn addresses_of_peer(&mut self, peer_id: &PeerId) -> Vec { - self.inner.as_mut().map(|b| b.addresses_of_peer(peer_id)).unwrap_or_else(Vec::new) + self.inner + .as_mut() + .map(|b| b.addresses_of_peer(peer_id)) + .unwrap_or_else(Vec::new) } fn inject_connected(&mut self, peer_id: &PeerId) { @@ -98,19 +97,35 @@ where } } - fn inject_connection_established(&mut self, peer_id: &PeerId, connection: &ConnectionId, endpoint: &ConnectedPoint) { + fn inject_connection_established( + &mut self, + peer_id: &PeerId, + connection: &ConnectionId, + endpoint: &ConnectedPoint, + ) { if let Some(inner) = self.inner.as_mut() { inner.inject_connection_established(peer_id, connection, endpoint) } } - fn inject_connection_closed(&mut self, peer_id: &PeerId, connection: &ConnectionId, endpoint: &ConnectedPoint) { + fn inject_connection_closed( + &mut self, + peer_id: &PeerId, + connection: &ConnectionId, + endpoint: &ConnectedPoint, + ) { if let Some(inner) = self.inner.as_mut() { inner.inject_connection_closed(peer_id, connection, endpoint) } } - fn inject_address_change(&mut self, peer_id: &PeerId, connection: &ConnectionId, old: &ConnectedPoint, new: &ConnectedPoint) { + fn inject_address_change( + &mut self, + peer_id: &PeerId, + connection: &ConnectionId, + old: &ConnectedPoint, + new: &ConnectedPoint, 
+ ) { if let Some(inner) = self.inner.as_mut() { inner.inject_address_change(peer_id, connection, old, new) } @@ -120,14 +135,19 @@ where &mut self, peer_id: PeerId, connection: ConnectionId, - event: <::Handler as ProtocolsHandler>::OutEvent + event: <::Handler as ProtocolsHandler>::OutEvent, ) { if let Some(inner) = self.inner.as_mut() { inner.inject_event(peer_id, connection, event); } } - fn inject_addr_reach_failure(&mut self, peer_id: Option<&PeerId>, addr: &Multiaddr, error: &dyn error::Error) { + fn inject_addr_reach_failure( + &mut self, + peer_id: Option<&PeerId>, + addr: &Multiaddr, + error: &dyn error::Error, + ) { if let Some(inner) = self.inner.as_mut() { inner.inject_addr_reach_failure(peer_id, addr, error) } @@ -194,7 +214,7 @@ where impl NetworkBehaviourEventProcess for Toggle where - TBehaviour: NetworkBehaviourEventProcess + TBehaviour: NetworkBehaviourEventProcess, { fn inject_event(&mut self, event: TEvent) { if let Some(inner) = self.inner.as_mut() { @@ -210,13 +230,19 @@ pub struct ToggleIntoProtoHandler { impl IntoProtocolsHandler for ToggleIntoProtoHandler where - TInner: IntoProtocolsHandler + TInner: IntoProtocolsHandler, { type Handler = ToggleProtoHandler; - fn into_handler(self, remote_peer_id: &PeerId, connected_point: &ConnectedPoint) -> Self::Handler { + fn into_handler( + self, + remote_peer_id: &PeerId, + connected_point: &ConnectedPoint, + ) -> Self::Handler { ToggleProtoHandler { - inner: self.inner.map(|h| h.into_handler(remote_peer_id, connected_point)) + inner: self + .inner + .map(|h| h.into_handler(remote_peer_id, connected_point)), } } @@ -241,25 +267,30 @@ where type InEvent = TInner::InEvent; type OutEvent = TInner::OutEvent; type Error = TInner::Error; - type InboundProtocol = EitherUpgrade, SendWrapper>; + type InboundProtocol = + EitherUpgrade, SendWrapper>; type OutboundProtocol = TInner::OutboundProtocol; type OutboundOpenInfo = TInner::OutboundOpenInfo; type InboundOpenInfo = Either; fn listen_protocol(&self) -> SubstreamProtocol { if let Some(inner) = self.inner.as_ref() { - inner.listen_protocol() + inner + .listen_protocol() .map_upgrade(|u| EitherUpgrade::A(SendWrapper(u))) .map_info(Either::Left) } else { - SubstreamProtocol::new(EitherUpgrade::B(SendWrapper(DeniedUpgrade)), Either::Right(())) + SubstreamProtocol::new( + EitherUpgrade::B(SendWrapper(DeniedUpgrade)), + Either::Right(()), + ) } } fn inject_fully_negotiated_inbound( &mut self, out: ::Output, - info: Self::InboundOpenInfo + info: Self::InboundOpenInfo, ) { let out = match out { EitherOutput::First(out) => out, @@ -267,7 +298,8 @@ where }; if let Either::Left(info) = info { - self.inner.as_mut() + self.inner + .as_mut() .expect("Can't receive an inbound substream if disabled; QED") .inject_fully_negotiated_inbound(out, info) } else { @@ -278,14 +310,18 @@ where fn inject_fully_negotiated_outbound( &mut self, out: ::Output, - info: Self::OutboundOpenInfo + info: Self::OutboundOpenInfo, ) { - self.inner.as_mut().expect("Can't receive an outbound substream if disabled; QED") + self.inner + .as_mut() + .expect("Can't receive an outbound substream if disabled; QED") .inject_fully_negotiated_outbound(out, info) } fn inject_event(&mut self, event: Self::InEvent) { - self.inner.as_mut().expect("Can't receive events if disabled; QED") + self.inner + .as_mut() + .expect("Can't receive events if disabled; QED") .inject_event(event) } @@ -295,12 +331,22 @@ where } } - fn inject_dial_upgrade_error(&mut self, info: Self::OutboundOpenInfo, err: ProtocolsHandlerUpgrErr<::Error>) { - 
self.inner.as_mut().expect("Can't receive an outbound substream if disabled; QED") + fn inject_dial_upgrade_error( + &mut self, + info: Self::OutboundOpenInfo, + err: ProtocolsHandlerUpgrErr<::Error>, + ) { + self.inner + .as_mut() + .expect("Can't receive an outbound substream if disabled; QED") .inject_dial_upgrade_error(info, err) } - fn inject_listen_upgrade_error(&mut self, info: Self::InboundOpenInfo, err: ProtocolsHandlerUpgrErr<::Error>) { + fn inject_listen_upgrade_error( + &mut self, + info: Self::InboundOpenInfo, + err: ProtocolsHandlerUpgrErr<::Error>, + ) { let (inner, info) = match (self.inner.as_mut(), info) { (Some(inner), Either::Left(info)) => (inner, info), // Ignore listen upgrade errors in disabled state. @@ -313,24 +359,26 @@ where "Unexpected `Either::Left` inbound info through \ `inject_listen_upgrade_error` in disabled state.", ), - }; let err = match err { ProtocolsHandlerUpgrErr::Timeout => ProtocolsHandlerUpgrErr::Timeout, ProtocolsHandlerUpgrErr::Timer => ProtocolsHandlerUpgrErr::Timer, - ProtocolsHandlerUpgrErr::Upgrade(err) => + ProtocolsHandlerUpgrErr::Upgrade(err) => { ProtocolsHandlerUpgrErr::Upgrade(err.map_err(|err| match err { EitherError::A(e) => e, - EitherError::B(v) => void::unreachable(v) + EitherError::B(v) => void::unreachable(v), })) + } }; inner.inject_listen_upgrade_error(info, err) } fn connection_keep_alive(&self) -> KeepAlive { - self.inner.as_ref().map(|h| h.connection_keep_alive()) + self.inner + .as_ref() + .map(|h| h.connection_keep_alive()) .unwrap_or(KeepAlive::No) } @@ -338,7 +386,12 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll< - ProtocolsHandlerEvent + ProtocolsHandlerEvent< + Self::OutboundProtocol, + Self::OutboundOpenInfo, + Self::OutEvent, + Self::Error, + >, > { if let Some(inner) = self.inner.as_mut() { inner.poll(cx) @@ -369,9 +422,7 @@ mod tests { /// [`ToggleProtoHandler`] should ignore the error in both of these cases. #[test] fn ignore_listen_upgrade_error_when_disabled() { - let mut handler = ToggleProtoHandler:: { - inner: None, - }; + let mut handler = ToggleProtoHandler:: { inner: None }; handler.inject_listen_upgrade_error(Either::Right(()), ProtocolsHandlerUpgrErr::Timeout); } diff --git a/transports/deflate/src/lib.rs b/transports/deflate/src/lib.rs index d93e6ed2e39..698b6cab6a9 100644 --- a/transports/deflate/src/lib.rs +++ b/transports/deflate/src/lib.rs @@ -105,11 +105,12 @@ impl DeflateOutput { /// Tries to write the content of `self.write_out` to `self.inner`. /// Returns `Ready(Ok(()))` if `self.write_out` is empty. fn flush_write_out(&mut self, cx: &mut Context<'_>) -> Poll> - where S: AsyncWrite + Unpin + where + S: AsyncWrite + Unpin, { loop { if self.write_out.is_empty() { - return Poll::Ready(Ok(())) + return Poll::Ready(Ok(())); } match AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, &self.write_out) { @@ -123,9 +124,14 @@ impl DeflateOutput { } impl AsyncRead for DeflateOutput - where S: AsyncRead + Unpin +where + S: AsyncRead + Unpin, { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { // We use a `this` variable because the compiler doesn't allow multiple mutable borrows // across a `Deref`. let this = &mut *self; @@ -133,31 +139,38 @@ impl AsyncRead for DeflateOutput loop { // Read from `self.inner` into `self.read_interm` if necessary. 
if this.read_interm.is_empty() && !this.inner_read_eof { - this.read_interm.resize(this.read_interm.capacity() + 256, 0); + this.read_interm + .resize(this.read_interm.capacity() + 256, 0); match AsyncRead::poll_read(Pin::new(&mut this.inner), cx, &mut this.read_interm) { Poll::Ready(Ok(0)) => { this.inner_read_eof = true; this.read_interm.clear(); } - Poll::Ready(Ok(n)) => { - this.read_interm.truncate(n) - }, + Poll::Ready(Ok(n)) => this.read_interm.truncate(n), Poll::Ready(Err(err)) => { this.read_interm.clear(); - return Poll::Ready(Err(err)) - }, + return Poll::Ready(Err(err)); + } Poll::Pending => { this.read_interm.clear(); - return Poll::Pending - }, + return Poll::Pending; + } } } debug_assert!(!this.read_interm.is_empty() || this.inner_read_eof); let before_out = this.decompress.total_out(); let before_in = this.decompress.total_in(); - let ret = this.decompress.decompress(&this.read_interm, buf, if this.inner_read_eof { flate2::FlushDecompress::Finish } else { flate2::FlushDecompress::None })?; + let ret = this.decompress.decompress( + &this.read_interm, + buf, + if this.inner_read_eof { + flate2::FlushDecompress::Finish + } else { + flate2::FlushDecompress::None + }, + )?; // Remove from `self.read_interm` the bytes consumed by the decompressor. let consumed = (this.decompress.total_in() - before_in) as usize; @@ -165,18 +178,21 @@ impl AsyncRead for DeflateOutput let read = (this.decompress.total_out() - before_out) as usize; if read != 0 || ret == flate2::Status::StreamEnd { - return Poll::Ready(Ok(read)) + return Poll::Ready(Ok(read)); } } } } impl AsyncWrite for DeflateOutput - where S: AsyncWrite + Unpin +where + S: AsyncWrite + Unpin, { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) - -> Poll> - { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { // We use a `this` variable because the compiler doesn't allow multiple mutable borrows // across a `Deref`. let this = &mut *self; @@ -195,8 +211,12 @@ impl AsyncWrite for DeflateOutput // Instead, we invoke the compressor in a loop until it accepts some of our data. loop { let before_in = this.compress.total_in(); - this.write_out.reserve(256); // compress_vec uses the Vec's capacity - let ret = this.compress.compress_vec(buf, &mut this.write_out, flate2::FlushCompress::None)?; + this.write_out.reserve(256); // compress_vec uses the Vec's capacity + let ret = this.compress.compress_vec( + buf, + &mut this.write_out, + flate2::FlushCompress::None, + )?; let written = (this.compress.total_in() - before_in) as usize; if written != 0 || ret == flate2::Status::StreamEnd { @@ -211,15 +231,17 @@ impl AsyncWrite for DeflateOutput let this = &mut *self; ready!(this.flush_write_out(cx))?; - this.compress.compress_vec(&[], &mut this.write_out, flate2::FlushCompress::Sync)?; + this.compress + .compress_vec(&[], &mut this.write_out, flate2::FlushCompress::Sync)?; loop { ready!(this.flush_write_out(cx))?; debug_assert!(this.write_out.is_empty()); // We ask the compressor to flush everything into `self.write_out`. 
- this.write_out.reserve(256); // compress_vec uses the Vec's capacity - this.compress.compress_vec(&[], &mut this.write_out, flate2::FlushCompress::None)?; + this.write_out.reserve(256); // compress_vec uses the Vec's capacity + this.compress + .compress_vec(&[], &mut this.write_out, flate2::FlushCompress::None)?; if this.write_out.is_empty() { break; } @@ -238,8 +260,9 @@ impl AsyncWrite for DeflateOutput // We ask the compressor to flush everything into `self.write_out`. debug_assert!(this.write_out.is_empty()); - this.write_out.reserve(256); // compress_vec uses the Vec's capacity - this.compress.compress_vec(&[], &mut this.write_out, flate2::FlushCompress::Finish)?; + this.write_out.reserve(256); // compress_vec uses the Vec's capacity + this.compress + .compress_vec(&[], &mut this.write_out, flate2::FlushCompress::Finish)?; if this.write_out.is_empty() { break; } diff --git a/transports/deflate/tests/test.rs b/transports/deflate/tests/test.rs index 896fb491349..6027c4f4afb 100644 --- a/transports/deflate/tests/test.rs +++ b/transports/deflate/tests/test.rs @@ -28,7 +28,7 @@ use quickcheck::{QuickCheck, RngCore, TestResult}; fn deflate() { fn prop(message: Vec) -> TestResult { if message.is_empty() { - return TestResult::discard() + return TestResult::discard(); } async_std::task::block_on(run(message)); TestResult::passed() @@ -44,16 +44,24 @@ fn lot_of_data() { } async fn run(message1: Vec) { - let transport = TcpConfig::new() - .and_then(|conn, endpoint| { - upgrade::apply(conn, DeflateConfig::default(), endpoint, upgrade::Version::V1) - }); + let transport = TcpConfig::new().and_then(|conn, endpoint| { + upgrade::apply( + conn, + DeflateConfig::default(), + endpoint, + upgrade::Version::V1, + ) + }); - let mut listener = transport.clone() + let mut listener = transport + .clone() .listen_on("/ip4/0.0.0.0/tcp/0".parse().expect("multiaddr")) .expect("listener"); - let listen_addr = listener.by_ref().next().await + let listen_addr = listener + .by_ref() + .next() + .await .expect("some event") .expect("no error") .into_new_address() @@ -82,7 +90,11 @@ async fn run(message1: Vec) { conn.close().await.expect("close") }); - let mut conn = transport.dial(listen_addr).expect("dialer").await.expect("connection"); + let mut conn = transport + .dial(listen_addr) + .expect("dialer") + .await + .expect("connection"); conn.write_all(&message1).await.expect("write_all"); conn.close().await.expect("close"); diff --git a/transports/dns/src/lib.rs b/transports/dns/src/lib.rs index 499f33e8e5e..6174c1e362c 100644 --- a/transports/dns/src/lib.rs +++ b/transports/dns/src/lib.rs @@ -54,27 +54,23 @@ //! 
//![trust-dns-resolver]: https://docs.rs/trust-dns-resolver/latest/trust_dns_resolver/#dns-over-tls-and-dns-over-https -use futures::{prelude::*, future::BoxFuture}; +#[cfg(feature = "async-std")] +use async_std_resolver::{AsyncStdConnection, AsyncStdConnectionProvider}; +use futures::{future::BoxFuture, prelude::*}; use libp2p_core::{ + multiaddr::{Multiaddr, Protocol}, + transport::{ListenerEvent, TransportError}, Transport, - multiaddr::{Protocol, Multiaddr}, - transport::{TransportError, ListenerEvent} }; use smallvec::SmallVec; -use std::{convert::TryFrom, error, fmt, iter, net::IpAddr, str}; #[cfg(any(feature = "async-std", feature = "tokio"))] use std::io; +use std::{convert::TryFrom, error, fmt, iter, net::IpAddr, str}; #[cfg(any(feature = "async-std", feature = "tokio"))] use trust_dns_resolver::system_conf; -use trust_dns_resolver::{ - AsyncResolver, - ConnectionProvider, - proto::xfer::dns_handle::DnsHandle, -}; +use trust_dns_resolver::{proto::xfer::dns_handle::DnsHandle, AsyncResolver, ConnectionProvider}; #[cfg(feature = "tokio")] use trust_dns_resolver::{TokioAsyncResolver, TokioConnection, TokioConnectionProvider}; -#[cfg(feature = "async-std")] -use async_std_resolver::{AsyncStdConnection, AsyncStdConnectionProvider}; pub use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; pub use trust_dns_resolver::error::{ResolveError, ResolveErrorKind}; @@ -112,7 +108,7 @@ pub type TokioDnsConfig = GenDnsConfig where C: DnsHandle, - P: ConnectionProvider + P: ConnectionProvider, { /// The underlying transport. inner: T, @@ -129,12 +125,14 @@ impl DnsConfig { } /// Creates a [`DnsConfig`] with a custom resolver configuration and options. - pub async fn custom(inner: T, cfg: ResolverConfig, opts: ResolverOpts) - -> Result, io::Error> - { + pub async fn custom( + inner: T, + cfg: ResolverConfig, + opts: ResolverOpts, + ) -> Result, io::Error> { Ok(DnsConfig { inner, - resolver: async_std_resolver::resolver(cfg, opts).await? + resolver: async_std_resolver::resolver(cfg, opts).await?, }) } } @@ -149,12 +147,14 @@ impl TokioDnsConfig { /// Creates a [`TokioDnsConfig`] with a custom resolver configuration /// and options. - pub fn custom(inner: T, cfg: ResolverConfig, opts: ResolverOpts) - -> Result, io::Error> - { + pub fn custom( + inner: T, + cfg: ResolverConfig, + opts: ResolverOpts, + ) -> Result, io::Error> { Ok(TokioDnsConfig { inner, - resolver: TokioAsyncResolver::tokio(cfg, opts)? + resolver: TokioAsyncResolver::tokio(cfg, opts)?, }) } } @@ -181,24 +181,29 @@ where type Output = T::Output; type Error = DnsErr; type Listener = stream::MapErr< - stream::MapOk) - -> ListenerEvent>, - fn(T::Error) -> Self::Error>; + stream::MapOk< + T::Listener, + fn( + ListenerEvent, + ) -> ListenerEvent, + >, + fn(T::Error) -> Self::Error, + >; type ListenerUpgrade = future::MapErr Self::Error>; type Dial = future::Either< future::MapErr Self::Error>, - BoxFuture<'static, Result> + BoxFuture<'static, Result>, >; fn listen_on(self, addr: Multiaddr) -> Result> { - let listener = self.inner.listen_on(addr).map_err(|err| err.map(DnsErr::Transport))?; + let listener = self + .inner + .listen_on(addr) + .map_err(|err| err.map(DnsErr::Transport))?; let listener = listener .map_ok::<_, fn(_) -> _>(|event| { event - .map(|upgr| { - upgr.map_err::<_, fn(_) -> _>(DnsErr::Transport) - }) + .map(|upgr| upgr.map_err::<_, fn(_) -> _>(DnsErr::Transport)) .map_err(DnsErr::Transport) }) .map_err::<_, fn(_) -> _>(DnsErr::Transport); @@ -225,24 +230,24 @@ where // address. 
while let Some(addr) = unresolved.pop() { if let Some((i, name)) = addr.iter().enumerate().find(|(_, p)| match p { - Protocol::Dns(_) | - Protocol::Dns4(_) | - Protocol::Dns6(_) | - Protocol::Dnsaddr(_) => true, - _ => false + Protocol::Dns(_) + | Protocol::Dns4(_) + | Protocol::Dns6(_) + | Protocol::Dnsaddr(_) => true, + _ => false, }) { if dns_lookups == MAX_DNS_LOOKUPS { log::debug!("Too many DNS lookups. Dropping unresolved {}.", addr); last_err = Some(DnsErr::TooManyLookups); // There may still be fully resolved addresses in `unresolved`, // so keep going until `unresolved` is empty. - continue + continue; } dns_lookups += 1; match resolve(&name, &resolver).await { Err(e) => { if unresolved.is_empty() { - return Err(e) + return Err(e); } // If there are still unresolved addresses, there is // a chance of success, but we track the last error. @@ -256,7 +261,8 @@ where Ok(Resolved::Many(ips)) => { for ip in ips { log::trace!("Resolved {} -> {}", name, ip); - let addr = addr.replace(i, |_| Some(ip)).expect("`i` is a valid index"); + let addr = + addr.replace(i, |_| Some(ip)).expect("`i` is a valid index"); unresolved.push(addr); } } @@ -269,10 +275,14 @@ where if n < MAX_TXT_RECORDS { n += 1; log::trace!("Resolved {} -> {}", name, a); - let addr = prefix.iter().chain(a.iter()).collect::(); + let addr = + prefix.iter().chain(a.iter()).collect::(); unresolved.push(addr); } else { - log::debug!("Too many TXT records. Dropping resolved {}.", a); + log::debug!( + "Too many TXT records. Dropping resolved {}.", + a + ); } } } @@ -291,9 +301,10 @@ where dial_attempts += 1; out.await.map_err(DnsErr::Transport) } - Err(TransportError::MultiaddrNotSupported(a)) => - Err(DnsErr::MultiaddrNotSupported(a)), - Err(TransportError::Other(err)) => Err(DnsErr::Transport(err)) + Err(TransportError::MultiaddrNotSupported(a)) => { + Err(DnsErr::MultiaddrNotSupported(a)) + } + Err(TransportError::Other(err)) => Err(DnsErr::Transport(err)), }; match result { @@ -301,11 +312,14 @@ where Err(err) => { log::debug!("Dial error: {:?}.", err); if unresolved.is_empty() { - return Err(err) + return Err(err); } if dial_attempts == MAX_DIAL_ATTEMPTS { - log::debug!("Aborting dialing after {} attempts.", MAX_DIAL_ATTEMPTS); - return Err(err) + log::debug!( + "Aborting dialing after {} attempts.", + MAX_DIAL_ATTEMPTS + ); + return Err(err); } last_err = Some(err); } @@ -317,10 +331,12 @@ where // attempt, return that error. Otherwise there were no valid DNS records // for the given address to begin with (i.e. DNS lookups succeeded but // produced no records relevant for the given `addr`). - Err(last_err.unwrap_or_else(|| - DnsErr::ResolveError( - ResolveErrorKind::Message("No matching records found.").into()))) - }.boxed().right_future()) + Err(last_err.unwrap_or_else(|| { + DnsErr::ResolveError(ResolveErrorKind::Message("No matching records found.").into()) + })) + } + .boxed() + .right_future()) } fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { @@ -348,7 +364,8 @@ pub enum DnsErr { } impl fmt::Display for DnsErr -where TErr: fmt::Display +where + TErr: fmt::Display, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -361,7 +378,8 @@ where TErr: fmt::Display } impl error::Error for DnsErr -where TErr: error::Error + 'static +where + TErr: error::Error + 'static, { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { @@ -393,96 +411,111 @@ enum Resolved<'a> { /// [`Resolved::One`]. 
fn resolve<'a, E: 'a + Send, C, P>( proto: &Protocol<'a>, - resolver: &'a AsyncResolver, + resolver: &'a AsyncResolver, ) -> BoxFuture<'a, Result, DnsErr>> where C: DnsHandle, P: ConnectionProvider, { match proto { - Protocol::Dns(ref name) => { - resolver.lookup_ip(name.clone().into_owned()).map(move |res| match res { + Protocol::Dns(ref name) => resolver + .lookup_ip(name.clone().into_owned()) + .map(move |res| match res { Ok(ips) => { let mut ips = ips.into_iter(); - let one = ips.next() + let one = ips + .next() .expect("If there are no results, `Err(NoRecordsFound)` is expected."); if let Some(two) = ips.next() { Ok(Resolved::Many( - iter::once(one).chain(iter::once(two)) + iter::once(one) + .chain(iter::once(two)) .chain(ips) .map(Protocol::from) - .collect())) + .collect(), + )) } else { Ok(Resolved::One(Protocol::from(one))) } } - Err(e) => Err(DnsErr::ResolveError(e)) - }).boxed() - } - Protocol::Dns4(ref name) => { - resolver.ipv4_lookup(name.clone().into_owned()).map(move |res| match res { + Err(e) => Err(DnsErr::ResolveError(e)), + }) + .boxed(), + Protocol::Dns4(ref name) => resolver + .ipv4_lookup(name.clone().into_owned()) + .map(move |res| match res { Ok(ips) => { let mut ips = ips.into_iter(); - let one = ips.next() + let one = ips + .next() .expect("If there are no results, `Err(NoRecordsFound)` is expected."); if let Some(two) = ips.next() { Ok(Resolved::Many( - iter::once(one).chain(iter::once(two)) + iter::once(one) + .chain(iter::once(two)) .chain(ips) .map(IpAddr::from) .map(Protocol::from) - .collect())) + .collect(), + )) } else { Ok(Resolved::One(Protocol::from(IpAddr::from(one)))) } } - Err(e) => Err(DnsErr::ResolveError(e)) - }).boxed() - } - Protocol::Dns6(ref name) => { - resolver.ipv6_lookup(name.clone().into_owned()).map(move |res| match res { + Err(e) => Err(DnsErr::ResolveError(e)), + }) + .boxed(), + Protocol::Dns6(ref name) => resolver + .ipv6_lookup(name.clone().into_owned()) + .map(move |res| match res { Ok(ips) => { let mut ips = ips.into_iter(); - let one = ips.next() + let one = ips + .next() .expect("If there are no results, `Err(NoRecordsFound)` is expected."); if let Some(two) = ips.next() { Ok(Resolved::Many( - iter::once(one).chain(iter::once(two)) + iter::once(one) + .chain(iter::once(two)) .chain(ips) .map(IpAddr::from) .map(Protocol::from) - .collect())) + .collect(), + )) } else { Ok(Resolved::One(Protocol::from(IpAddr::from(one)))) } } - Err(e) => Err(DnsErr::ResolveError(e)) - }).boxed() - }, + Err(e) => Err(DnsErr::ResolveError(e)), + }) + .boxed(), Protocol::Dnsaddr(ref name) => { let name = [DNSADDR_PREFIX, name].concat(); - resolver.txt_lookup(name).map(move |res| match res { - Ok(txts) => { - let mut addrs = Vec::new(); - for txt in txts { - if let Some(chars) = txt.txt_data().first() { - match parse_dnsaddr_txt(chars) { - Err(e) => { - // Skip over seemingly invalid entries. - log::debug!("Invalid TXT record: {:?}", e); - } - Ok(a) => { - addrs.push(a); + resolver + .txt_lookup(name) + .map(move |res| match res { + Ok(txts) => { + let mut addrs = Vec::new(); + for txt in txts { + if let Some(chars) = txt.txt_data().first() { + match parse_dnsaddr_txt(chars) { + Err(e) => { + // Skip over seemingly invalid entries. 
+ log::debug!("Invalid TXT record: {:?}", e); + } + Ok(a) => { + addrs.push(a); + } } } } + Ok(Resolved::Addrs(addrs)) } - Ok(Resolved::Addrs(addrs)) - } - Err(e) => Err(DnsErr::ResolveError(e)) - }).boxed() + Err(e) => Err(DnsErr::ResolveError(e)), + }) + .boxed() } - proto => future::ready(Ok(Resolved::One(proto.clone()))).boxed() + proto => future::ready(Ok(Resolved::One(proto.clone()))).boxed(), } } @@ -491,7 +524,7 @@ fn parse_dnsaddr_txt(txt: &[u8]) -> io::Result { let s = str::from_utf8(txt).map_err(invalid_data)?; match s.strip_prefix("dnsaddr=") { None => Err(invalid_data("Missing `dnsaddr=` prefix.")), - Some(a) => Ok(Multiaddr::try_from(a).map_err(invalid_data)?) + Some(a) => Ok(Multiaddr::try_from(a).map_err(invalid_data)?), } } @@ -504,11 +537,10 @@ mod tests { use super::*; use futures::{future::BoxFuture, stream::BoxStream}; use libp2p_core::{ - Transport, - PeerId, - multiaddr::{Protocol, Multiaddr}, + multiaddr::{Multiaddr, Protocol}, transport::ListenerEvent, transport::TransportError, + PeerId, Transport, }; #[test] @@ -521,19 +553,27 @@ mod tests { impl Transport for CustomTransport { type Output = (); type Error = std::io::Error; - type Listener = BoxStream<'static, Result, Self::Error>>; + type Listener = BoxStream< + 'static, + Result, Self::Error>, + >; type ListenerUpgrade = BoxFuture<'static, Result>; type Dial = BoxFuture<'static, Result>; - fn listen_on(self, _: Multiaddr) -> Result> { + fn listen_on( + self, + _: Multiaddr, + ) -> Result> { unreachable!() } fn dial(self, addr: Multiaddr) -> Result> { // Check that all DNS components have been resolved, i.e. replaced. assert!(!addr.iter().any(|p| match p { - Protocol::Dns(_) | Protocol::Dns4(_) | Protocol::Dns6(_) | Protocol::Dnsaddr(_) - => true, + Protocol::Dns(_) + | Protocol::Dns4(_) + | Protocol::Dns6(_) + | Protocol::Dnsaddr(_) => true, _ => false, })); Ok(Box::pin(future::ready(Ok(())))) @@ -598,13 +638,17 @@ mod tests { // an entry with a random `p2p` suffix. match transport .clone() - .dial(format!("/dnsaddr/bootstrap.libp2p.io/p2p/{}", PeerId::random()).parse().unwrap()) + .dial( + format!("/dnsaddr/bootstrap.libp2p.io/p2p/{}", PeerId::random()) + .parse() + .unwrap(), + ) .unwrap() .await { - Err(DnsErr::ResolveError(_)) => {}, + Err(DnsErr::ResolveError(_)) => {} Err(e) => panic!("Unexpected error: {:?}", e), - Ok(_) => panic!("Unexpected success.") + Ok(_) => panic!("Unexpected success."), } // Failure due to no records. @@ -615,7 +659,7 @@ mod tests { .await { Err(DnsErr::ResolveError(e)) => match e.kind() { - ResolveErrorKind::NoRecordsFound { .. } => {}, + ResolveErrorKind::NoRecordsFound { .. } => {} _ => panic!("Unexpected DNS error: {:?}", e), }, Err(e) => panic!("Unexpected error: {:?}", e), @@ -630,7 +674,7 @@ mod tests { let config = ResolverConfig::quad9(); let opts = ResolverOpts::default(); async_std_crate::task::block_on( - DnsConfig::custom(CustomTransport, config, opts).then(|dns| run(dns.unwrap())) + DnsConfig::custom(CustomTransport, config, opts).then(|dns| run(dns.unwrap())), ); } @@ -645,7 +689,9 @@ mod tests { .enable_time() .build() .unwrap(); - rt.block_on(run(TokioDnsConfig::custom(CustomTransport, config, opts).unwrap())); + rt.block_on(run( + TokioDnsConfig::custom(CustomTransport, config, opts).unwrap() + )); } } } diff --git a/transports/noise/build.rs b/transports/noise/build.rs index b13c29b5197..c9cf60412cd 100644 --- a/transports/noise/build.rs +++ b/transports/noise/build.rs @@ -19,5 +19,5 @@ // DEALINGS IN THE SOFTWARE. 
fn main() { - prost_build::compile_protos(&["src/io/handshake/payload.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/io/handshake/payload.proto"], &["src"]).unwrap(); } diff --git a/transports/noise/src/error.rs b/transports/noise/src/error.rs index 8b836d5ea78..4e1d240fe74 100644 --- a/transports/noise/src/error.rs +++ b/transports/noise/src/error.rs @@ -90,4 +90,3 @@ impl From for NoiseError { NoiseError::SigningError(e) } } - diff --git a/transports/noise/src/io.rs b/transports/noise/src/io.rs index c7bd110c773..37e35ecbeee 100644 --- a/transports/noise/src/io.rs +++ b/transports/noise/src/io.rs @@ -24,11 +24,16 @@ mod framed; pub mod handshake; use bytes::Bytes; -use framed::{MAX_FRAME_LEN, NoiseFramed}; -use futures::ready; +use framed::{NoiseFramed, MAX_FRAME_LEN}; use futures::prelude::*; +use futures::ready; use log::trace; -use std::{cmp::min, fmt, io, pin::Pin, task::{Context, Poll}}; +use std::{ + cmp::min, + fmt, io, + pin::Pin, + task::{Context, Poll}, +}; /// A noise session to a remote. /// @@ -43,9 +48,7 @@ pub struct NoiseOutput { impl fmt::Debug for NoiseOutput { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("NoiseOutput") - .field("io", &self.io) - .finish() + f.debug_struct("NoiseOutput").field("io", &self.io).finish() } } @@ -62,13 +65,17 @@ impl NoiseOutput { } impl AsyncRead for NoiseOutput { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { loop { let len = self.recv_buffer.len(); let off = self.recv_offset; if len > 0 { let n = min(len - off, buf.len()); - buf[.. n].copy_from_slice(&self.recv_buffer[off .. off + n]); + buf[..n].copy_from_slice(&self.recv_buffer[off..off + n]); trace!("read: copied {}/{} bytes", off + n, len); self.recv_offset += n; if len == self.recv_offset { @@ -77,7 +84,7 @@ impl AsyncRead for NoiseOutput { // the buffer when polling for the next frame below. self.recv_buffer = Bytes::new(); } - return Poll::Ready(Ok(n)) + return Poll::Ready(Ok(n)); } match Pin::new(&mut self.io).poll_next(cx) { @@ -94,7 +101,11 @@ impl AsyncRead for NoiseOutput { } impl AsyncWrite for NoiseOutput { - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { let this = Pin::into_inner(self); let mut io = Pin::new(&mut this.io); let frame_buf = &mut this.send_buffer; @@ -111,7 +122,7 @@ impl AsyncWrite for NoiseOutput { let n = min(MAX_FRAME_LEN, off.saturating_add(buf.len())); this.send_buffer.resize(n, 0u8); let n = min(MAX_FRAME_LEN - off, buf.len()); - this.send_buffer[off .. off + n].copy_from_slice(&buf[.. n]); + this.send_buffer[off..off + n].copy_from_slice(&buf[..n]); this.send_offset += n; trace!("write: buffered {} bytes", this.send_offset); @@ -134,7 +145,7 @@ impl AsyncWrite for NoiseOutput { io.as_mut().poll_flush(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>{ + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.as_mut().poll_flush(cx))?; Pin::new(&mut self.io).poll_close(cx) } diff --git a/transports/noise/src/io/framed.rs b/transports/noise/src/io/framed.rs index 000300bdfcc..4ca228ebe54 100644 --- a/transports/noise/src/io/framed.rs +++ b/transports/noise/src/io/framed.rs @@ -21,13 +21,17 @@ //! This module provides a `Sink` and `Stream` for length-delimited //! 
Noise protocol messages in form of [`NoiseFramed`]. -use bytes::{Bytes, BytesMut}; -use crate::{NoiseError, Protocol, PublicKey}; use crate::io::NoiseOutput; -use futures::ready; +use crate::{NoiseError, Protocol, PublicKey}; +use bytes::{Bytes, BytesMut}; use futures::prelude::*; +use futures::ready; use log::{debug, trace}; -use std::{fmt, io, pin::Pin, task::{Context, Poll}}; +use std::{ + fmt, io, + pin::Pin, + task::{Context, Poll}, +}; /// Max. size of a noise message. const MAX_NOISE_MSG_LEN: usize = 65535; @@ -88,14 +92,14 @@ impl NoiseFramed { /// present, cannot be parsed. pub fn into_transport(self) -> Result<(Option>, NoiseOutput), NoiseError> where - C: Protocol + AsRef<[u8]> + C: Protocol + AsRef<[u8]>, { let dh_remote_pubkey = match self.session.get_remote_static() { None => None, Some(k) => match C::public_from_bytes(k) { Err(e) => return Err(e), - Ok(dh_pk) => Some(dh_pk) - } + Ok(dh_pk) => Some(dh_pk), + }, }; match self.session.into_transport_mode() { Err(e) => Err(e.into()), @@ -129,7 +133,7 @@ enum ReadState { /// The associated result signals if the EOF was unexpected or not. Eof(Result<(), ()>), /// A decryption error occurred (terminal state). - DecErr + DecErr, } /// The states for writing Noise protocol frames. @@ -138,19 +142,23 @@ enum WriteState { /// Ready to write another frame. Ready, /// Writing the frame length. - WriteLen { len: usize, buf: [u8; 2], off: usize }, + WriteLen { + len: usize, + buf: [u8; 2], + off: usize, + }, /// Writing the frame data. WriteData { len: usize, off: usize }, /// EOF has been reached unexpectedly (terminal state). Eof, /// An encryption error occurred (terminal state). - EncErr + EncErr, } impl WriteState { fn is_ready(&self) -> bool { if let WriteState::Ready = self { - return true + return true; } false } @@ -159,7 +167,7 @@ impl WriteState { impl futures::stream::Stream for NoiseFramed where T: AsyncRead + Unpin, - S: SessionState + Unpin + S: SessionState + Unpin, { type Item = io::Result; @@ -169,7 +177,10 @@ where trace!("read state: {:?}", this.read_state); match this.read_state { ReadState::Ready => { - this.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 }; + this.read_state = ReadState::ReadLen { + buf: [0, 0], + off: 0, + }; } ReadState::ReadLen { mut buf, mut off } => { let n = match read_frame_len(&mut this.io, cx, &mut buf, &mut off) { @@ -177,11 +188,9 @@ where Poll::Ready(Ok(None)) => { trace!("read: eof"); this.read_state = ReadState::Eof(Ok(())); - return Poll::Ready(None) - } - Poll::Ready(Err(e)) => { - return Poll::Ready(Some(Err(e))) + return Poll::Ready(None); } + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), Poll::Pending => { this.read_state = ReadState::ReadLen { buf, off }; return Poll::Pending; @@ -191,14 +200,18 @@ where if n == 0 { trace!("read: empty frame"); this.read_state = ReadState::Ready; - continue + continue; } this.read_buffer.resize(usize::from(n), 0u8); - this.read_state = ReadState::ReadData { len: usize::from(n), off: 0 } + this.read_state = ReadState::ReadData { + len: usize::from(n), + off: 0, + } } ReadState::ReadData { len, ref mut off } => { let n = { - let f = Pin::new(&mut this.io).poll_read(cx, &mut this.read_buffer[*off .. 
len]); + let f = + Pin::new(&mut this.io).poll_read(cx, &mut this.read_buffer[*off..len]); match ready!(f) { Ok(n) => n, Err(e) => return Poll::Ready(Some(Err(e))), @@ -208,13 +221,16 @@ where if n == 0 { trace!("read: eof"); this.read_state = ReadState::Eof(Err(())); - return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))) + return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))); } *off += n; if len == *off { trace!("read: decrypting {} bytes", len); this.decrypt_buffer.resize(len, 0); - if let Ok(n) = this.session.read_message(&this.read_buffer, &mut this.decrypt_buffer) { + if let Ok(n) = this + .session + .read_message(&this.read_buffer, &mut this.decrypt_buffer) + { this.decrypt_buffer.truncate(n); trace!("read: payload len = {} bytes", n); this.read_state = ReadState::Ready; @@ -223,23 +239,25 @@ where // read, the `BytesMut` will reuse the same buffer // for the next frame. let view = this.decrypt_buffer.split().freeze(); - return Poll::Ready(Some(Ok(view))) + return Poll::Ready(Some(Ok(view))); } else { debug!("read: decryption error"); this.read_state = ReadState::DecErr; - return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into()))) + return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into()))); } } } ReadState::Eof(Ok(())) => { trace!("read: eof"); - return Poll::Ready(None) + return Poll::Ready(None); } ReadState::Eof(Err(())) => { trace!("read: eof (unexpected)"); - return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))) + return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))); + } + ReadState::DecErr => { + return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into()))) } - ReadState::DecErr => return Poll::Ready(Some(Err(io::ErrorKind::InvalidData.into()))) } } } @@ -248,7 +266,7 @@ where impl futures::sink::Sink<&Vec> for NoiseFramed where T: AsyncWrite + Unpin, - S: SessionState + Unpin + S: SessionState + Unpin, { type Error = io::Error; @@ -267,21 +285,20 @@ where Poll::Ready(Ok(false)) => { trace!("write: eof"); this.write_state = WriteState::Eof; - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) - } - Poll::Ready(Err(e)) => { - return Poll::Ready(Err(e)) + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => { this.write_state = WriteState::WriteLen { len, buf, off }; - return Poll::Pending + return Poll::Pending; } } this.write_state = WriteState::WriteData { len, off: 0 } } WriteState::WriteData { len, ref mut off } => { let n = { - let f = Pin::new(&mut this.io).poll_write(cx, &this.write_buffer[*off .. 
len]); + let f = + Pin::new(&mut this.io).poll_write(cx, &this.write_buffer[*off..len]); match ready!(f) { Ok(n) => n, Err(e) => return Poll::Ready(Err(e)), @@ -290,7 +307,7 @@ where if n == 0 { trace!("write: eof"); this.write_state = WriteState::Eof; - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } *off += n; trace!("write: {}/{} bytes written", *off, len); @@ -301,9 +318,9 @@ where } WriteState::Eof => { trace!("write: eof"); - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())) + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } - WriteState::EncErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into())) + WriteState::EncErr => return Poll::Ready(Err(io::ErrorKind::InvalidData.into())), } } } @@ -313,15 +330,19 @@ where let mut this = Pin::into_inner(self); assert!(this.write_state.is_ready()); - this.write_buffer.resize(frame.len() + EXTRA_ENCRYPT_SPACE, 0u8); - match this.session.write_message(frame, &mut this.write_buffer[..]) { + this.write_buffer + .resize(frame.len() + EXTRA_ENCRYPT_SPACE, 0u8); + match this + .session + .write_message(frame, &mut this.write_buffer[..]) + { Ok(n) => { trace!("write: cipher text len = {} bytes", n); this.write_buffer.truncate(n); this.write_state = WriteState::WriteLen { len: n, buf: u16::to_be_bytes(n as u16), - off: 0 + off: 0, }; Ok(()) } @@ -386,7 +407,7 @@ fn read_frame_len( off: &mut usize, ) -> Poll>> { loop { - match ready!(Pin::new(&mut io).poll_read(cx, &mut buf[*off ..])) { + match ready!(Pin::new(&mut io).poll_read(cx, &mut buf[*off..])) { Ok(n) => { if n == 0 { return Poll::Ready(Ok(None)); @@ -395,10 +416,10 @@ fn read_frame_len( if *off == 2 { return Poll::Ready(Ok(Some(u16::from_be_bytes(*buf)))); } - }, + } Err(e) => { return Poll::Ready(Err(e)); - }, + } } } } @@ -419,14 +440,14 @@ fn write_frame_len( off: &mut usize, ) -> Poll> { loop { - match ready!(Pin::new(&mut io).poll_write(cx, &buf[*off ..])) { + match ready!(Pin::new(&mut io).poll_write(cx, &buf[*off..])) { Ok(n) => { if n == 0 { - return Poll::Ready(Ok(false)) + return Poll::Ready(Ok(false)); } *off += n; if *off == 2 { - return Poll::Ready(Ok(true)) + return Poll::Ready(Ok(true)); } } Err(e) => { @@ -435,4 +456,3 @@ fn write_frame_len( } } } - diff --git a/transports/noise/src/io/handshake.rs b/transports/noise/src/io/handshake.rs index 21faa84d100..fa97798fb23 100644 --- a/transports/noise/src/io/handshake.rs +++ b/transports/noise/src/io/handshake.rs @@ -24,14 +24,14 @@ mod payload_proto { include!(concat!(env!("OUT_DIR"), "/payload.proto.rs")); } -use bytes::Bytes; -use crate::LegacyConfig; use crate::error::NoiseError; -use crate::protocol::{Protocol, PublicKey, KeypairIdentity}; -use crate::io::{NoiseOutput, framed::NoiseFramed}; -use libp2p_core::identity; +use crate::io::{framed::NoiseFramed, NoiseOutput}; +use crate::protocol::{KeypairIdentity, Protocol, PublicKey}; +use crate::LegacyConfig; +use bytes::Bytes; use futures::prelude::*; use futures::task; +use libp2p_core::identity; use prost::Message; use std::{io, pin::Pin, task::Context}; @@ -59,7 +59,7 @@ pub enum RemoteIdentity { /// > **Note**: To rule out active attacks like a MITM, trust in the public key must /// > still be established, e.g. by comparing the key against an expected or /// > otherwise known public key. - IdentityKey(identity::PublicKey) + IdentityKey(identity::PublicKey), } /// The options for identity exchange in an authenticated handshake. 
@@ -87,14 +87,12 @@ pub enum IdentityExchange { /// /// The remote identity is known, thus identities must be mutually known /// in order for the handshake to succeed. - None { remote: identity::PublicKey } + None { remote: identity::PublicKey }, } /// A future performing a Noise handshake pattern. pub struct Handshake( - Pin, NoiseOutput), NoiseError>, - > + Send>> + Pin, NoiseOutput), NoiseError>> + Send>>, ); impl Future for Handshake { @@ -131,7 +129,7 @@ pub fn rt1_initiator( ) -> Handshake where T: AsyncWrite + AsyncRead + Send + Unpin + 'static, - C: Protocol + AsRef<[u8]> + C: Protocol + AsRef<[u8]>, { Handshake(Box::pin(async move { let mut state = State::new(io, session, identity, identity_x, legacy)?; @@ -166,7 +164,7 @@ pub fn rt1_responder( ) -> Handshake where T: AsyncWrite + AsyncRead + Send + Unpin + 'static, - C: Protocol + AsRef<[u8]> + C: Protocol + AsRef<[u8]>, { Handshake(Box::pin(async move { let mut state = State::new(io, session, identity, identity_x, legacy)?; @@ -203,7 +201,7 @@ pub fn rt15_initiator( ) -> Handshake where T: AsyncWrite + AsyncRead + Unpin + Send + 'static, - C: Protocol + AsRef<[u8]> + C: Protocol + AsRef<[u8]>, { Handshake(Box::pin(async move { let mut state = State::new(io, session, identity, identity_x, legacy)?; @@ -241,7 +239,7 @@ pub fn rt15_responder( ) -> Handshake where T: AsyncWrite + AsyncRead + Unpin + Send + 'static, - C: Protocol + AsRef<[u8]> + C: Protocol + AsRef<[u8]>, { Handshake(Box::pin(async move { let mut state = State::new(io, session, identity, identity_x, legacy)?; @@ -289,28 +287,25 @@ impl State { IdentityExchange::Mutual => (None, true), IdentityExchange::Send { remote } => (Some(remote), true), IdentityExchange::Receive => (None, false), - IdentityExchange::None { remote } => (Some(remote), false) + IdentityExchange::None { remote } => (Some(remote), false), }; - session.map(|s| - State { - identity, - io: NoiseFramed::new(io, s), - dh_remote_pubkey_sig: None, - id_remote_pubkey, - send_identity, - legacy, - } - ) + session.map(|s| State { + identity, + io: NoiseFramed::new(io, s), + dh_remote_pubkey_sig: None, + id_remote_pubkey, + send_identity, + legacy, + }) } } -impl State -{ +impl State { /// Finish a handshake, yielding the established remote identity and the /// [`NoiseOutput`] for communicating on the encrypted channel. fn finish(self) -> Result<(RemoteIdentity, NoiseOutput), NoiseError> where - C: Protocol + AsRef<[u8]> + C: Protocol + AsRef<[u8]>, { let (pubkey, io) = self.io.into_transport()?; let remote = match (self.id_remote_pubkey, pubkey) { @@ -320,7 +315,7 @@ impl State if C::verify(&id_pk, &dh_pk, &self.dh_remote_pubkey_sig) { RemoteIdentity::IdentityKey(id_pk) } else { - return Err(NoiseError::InvalidKey) + return Err(NoiseError::InvalidKey); } } }; @@ -334,7 +329,7 @@ impl State /// A future for receiving a Noise handshake message. async fn recv(state: &mut State) -> Result where - T: AsyncRead + Unpin + T: AsyncRead + Unpin, { match state.io.next().await { None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "eof").into()), @@ -346,13 +341,13 @@ where /// A future for receiving a Noise handshake message with an empty payload. 
async fn recv_empty(state: &mut State) -> Result<(), NoiseError> where - T: AsyncRead + Unpin + T: AsyncRead + Unpin, { let msg = recv(state).await?; if !msg.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Unexpected handshake payload.").into()) + return Err( + io::Error::new(io::ErrorKind::InvalidData, "Unexpected handshake payload.").into(), + ); } Ok(()) } @@ -360,7 +355,7 @@ where /// A future for sending a Noise handshake message with an empty payload. async fn send_empty(state: &mut State) -> Result<(), NoiseError> where - T: AsyncWrite + Unpin + T: AsyncWrite + Unpin, { state.io.send(&Vec::new()).await?; Ok(()) @@ -390,12 +385,12 @@ where pb_result = pb_result.or_else(|e| { if msg.len() > 2 { let mut buf = [0, 0]; - buf.copy_from_slice(&msg[.. 2]); + buf.copy_from_slice(&msg[..2]); // If there is a second length it must be 2 bytes shorter than the // frame length, because each length is encoded as a `u16`. if usize::from(u16::from_be_bytes(buf)) + 2 == msg.len() { log::debug!("Attempting fallback legacy protobuf decoding."); - payload_proto::NoiseHandshakePayload::decode(&msg[2 ..]) + payload_proto::NoiseHandshakePayload::decode(&msg[2..]) } else { Err(e) } @@ -411,7 +406,7 @@ where .map_err(|_| NoiseError::InvalidKey)?; if let Some(ref k) = state.id_remote_pubkey { if k != &pk { - return Err(NoiseError::InvalidKey) + return Err(NoiseError::InvalidKey); } } state.id_remote_pubkey = Some(pk); @@ -439,16 +434,16 @@ where pb.identity_sig = sig.clone() } - let mut msg = - if state.legacy.send_legacy_handshake { - let mut msg = Vec::with_capacity(2 + pb.encoded_len()); - msg.extend_from_slice(&(pb.encoded_len() as u16).to_be_bytes()); - msg - } else { - Vec::with_capacity(pb.encoded_len()) - }; + let mut msg = if state.legacy.send_legacy_handshake { + let mut msg = Vec::with_capacity(2 + pb.encoded_len()); + msg.extend_from_slice(&(pb.encoded_len() as u16).to_be_bytes()); + msg + } else { + Vec::with_capacity(pb.encoded_len()) + }; - pb.encode(&mut msg).expect("Vec provides capacity as needed"); + pb.encode(&mut msg) + .expect("Vec provides capacity as needed"); state.io.send(&msg).await?; Ok(()) diff --git a/transports/noise/src/lib.rs b/transports/noise/src/lib.rs index 7abf06915e3..28261e85dd1 100644 --- a/transports/noise/src/lib.rs +++ b/transports/noise/src/lib.rs @@ -59,15 +59,15 @@ mod io; mod protocol; pub use error::NoiseError; -pub use io::NoiseOutput; pub use io::handshake; -pub use io::handshake::{Handshake, RemoteIdentity, IdentityExchange}; -pub use protocol::{Keypair, AuthenticKeypair, KeypairIdentity, PublicKey, SecretKey}; -pub use protocol::{Protocol, ProtocolParams, IX, IK, XX}; +pub use io::handshake::{Handshake, IdentityExchange, RemoteIdentity}; +pub use io::NoiseOutput; pub use protocol::{x25519::X25519, x25519_spec::X25519Spec}; +pub use protocol::{AuthenticKeypair, Keypair, KeypairIdentity, PublicKey, SecretKey}; +pub use protocol::{Protocol, ProtocolParams, IK, IX, XX}; use futures::prelude::*; -use libp2p_core::{identity, PeerId, UpgradeInfo, InboundUpgrade, OutboundUpgrade}; +use libp2p_core::{identity, InboundUpgrade, OutboundUpgrade, PeerId, UpgradeInfo}; use std::pin::Pin; use zeroize::Zeroize; @@ -78,7 +78,7 @@ pub struct NoiseConfig { params: ProtocolParams, legacy: LegacyConfig, remote: R, - _marker: std::marker::PhantomData

<C>
+ _marker: std::marker::PhantomData<C>
, } impl NoiseConfig { @@ -97,7 +97,7 @@ impl NoiseConfig { impl NoiseConfig where - C: Protocol + Zeroize + C: Protocol + Zeroize, { /// Create a new `NoiseConfig` for the `IX` handshake pattern. pub fn ix(dh_keys: AuthenticKeypair) -> Self { @@ -106,14 +106,14 @@ where params: C::params_ix(), legacy: LegacyConfig::default(), remote: (), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } } impl NoiseConfig where - C: Protocol + Zeroize + C: Protocol + Zeroize, { /// Create a new `NoiseConfig` for the `XX` handshake pattern. pub fn xx(dh_keys: AuthenticKeypair) -> Self { @@ -122,14 +122,14 @@ where params: C::params_xx(), legacy: LegacyConfig::default(), remote: (), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } } impl NoiseConfig where - C: Protocol + Zeroize + C: Protocol + Zeroize, { /// Create a new `NoiseConfig` for the `IK` handshake pattern (recipient side). /// @@ -141,14 +141,14 @@ where params: C::params_ik(), legacy: LegacyConfig::default(), remote: (), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } } impl NoiseConfig, identity::PublicKey)> where - C: Protocol + Zeroize + C: Protocol + Zeroize, { /// Create a new `NoiseConfig` for the `IK` handshake pattern (initiator side). /// @@ -157,14 +157,14 @@ where pub fn ik_dialer( dh_keys: AuthenticKeypair, remote_id: identity::PublicKey, - remote_dh: PublicKey + remote_dh: PublicKey, ) -> Self { NoiseConfig { dh_keys, params: C::params_ik(), legacy: LegacyConfig::default(), remote: (remote_dh, remote_id), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } } @@ -182,14 +182,19 @@ where type Future = Handshake; fn upgrade_inbound(self, socket: T, _: Self::Info) -> Self::Future { - let session = self.params.into_builder() + let session = self + .params + .into_builder() .local_private_key(self.dh_keys.secret().as_ref()) .build_responder() .map_err(NoiseError::from); - handshake::rt1_responder(socket, session, + handshake::rt1_responder( + socket, + session, self.dh_keys.into_identity(), IdentityExchange::Mutual, - self.legacy) + self.legacy, + ) } } @@ -204,14 +209,19 @@ where type Future = Handshake; fn upgrade_outbound(self, socket: T, _: Self::Info) -> Self::Future { - let session = self.params.into_builder() + let session = self + .params + .into_builder() .local_private_key(self.dh_keys.secret().as_ref()) .build_initiator() .map_err(NoiseError::from); - handshake::rt1_initiator(socket, session, - self.dh_keys.into_identity(), - IdentityExchange::Mutual, - self.legacy) + handshake::rt1_initiator( + socket, + session, + self.dh_keys.into_identity(), + IdentityExchange::Mutual, + self.legacy, + ) } } @@ -228,14 +238,19 @@ where type Future = Handshake; fn upgrade_inbound(self, socket: T, _: Self::Info) -> Self::Future { - let session = self.params.into_builder() + let session = self + .params + .into_builder() .local_private_key(self.dh_keys.secret().as_ref()) .build_responder() .map_err(NoiseError::from); - handshake::rt15_responder(socket, session, + handshake::rt15_responder( + socket, + session, self.dh_keys.into_identity(), IdentityExchange::Mutual, - self.legacy) + self.legacy, + ) } } @@ -250,14 +265,19 @@ where type Future = Handshake; fn upgrade_outbound(self, socket: T, _: Self::Info) -> Self::Future { - let session = self.params.into_builder() + let session = self + .params + .into_builder() .local_private_key(self.dh_keys.secret().as_ref()) .build_initiator() .map_err(NoiseError::from); - 
handshake::rt15_initiator(socket, session, + handshake::rt15_initiator( + socket, + session, self.dh_keys.into_identity(), IdentityExchange::Mutual, - self.legacy) + self.legacy, + ) } } @@ -274,14 +294,19 @@ where type Future = Handshake; fn upgrade_inbound(self, socket: T, _: Self::Info) -> Self::Future { - let session = self.params.into_builder() + let session = self + .params + .into_builder() .local_private_key(self.dh_keys.secret().as_ref()) .build_responder() .map_err(NoiseError::from); - handshake::rt1_responder(socket, session, + handshake::rt1_responder( + socket, + session, self.dh_keys.into_identity(), IdentityExchange::Receive, - self.legacy) + self.legacy, + ) } } @@ -296,15 +321,22 @@ where type Future = Handshake; fn upgrade_outbound(self, socket: T, _: Self::Info) -> Self::Future { - let session = self.params.into_builder() + let session = self + .params + .into_builder() .local_private_key(self.dh_keys.secret().as_ref()) .remote_public_key(self.remote.0.as_ref()) .build_initiator() .map_err(NoiseError::from); - handshake::rt1_initiator(socket, session, + handshake::rt1_initiator( + socket, + session, self.dh_keys.into_identity(), - IdentityExchange::Send { remote: self.remote.1 }, - self.legacy) + IdentityExchange::Send { + remote: self.remote.1, + }, + self.legacy, + ) } } @@ -323,12 +355,12 @@ where /// transport for use with a [`Network`](libp2p_core::Network). #[derive(Clone)] pub struct NoiseAuthenticated { - config: NoiseConfig + config: NoiseConfig, } impl UpgradeInfo for NoiseAuthenticated where - NoiseConfig: UpgradeInfo + NoiseConfig: UpgradeInfo, { type Info = as UpgradeInfo>::Info; type InfoIter = as UpgradeInfo>::InfoIter; @@ -340,10 +372,9 @@ where impl InboundUpgrade for NoiseAuthenticated where - NoiseConfig: UpgradeInfo + InboundUpgrade, NoiseOutput), - Error = NoiseError - > + 'static, + NoiseConfig: UpgradeInfo + + InboundUpgrade, NoiseOutput), Error = NoiseError> + + 'static, as InboundUpgrade>::Future: Send, T: AsyncRead + AsyncWrite + Send + 'static, C: Protocol + AsRef<[u8]> + Zeroize + Send + 'static, @@ -353,20 +384,22 @@ where type Future = Pin> + Send>>; fn upgrade_inbound(self, socket: T, info: Self::Info) -> Self::Future { - Box::pin(self.config.upgrade_inbound(socket, info) - .and_then(|(remote, io)| match remote { - RemoteIdentity::IdentityKey(pk) => future::ok((pk.to_peer_id(), io)), - _ => future::err(NoiseError::AuthenticationFailed) - })) + Box::pin( + self.config + .upgrade_inbound(socket, info) + .and_then(|(remote, io)| match remote { + RemoteIdentity::IdentityKey(pk) => future::ok((pk.to_peer_id(), io)), + _ => future::err(NoiseError::AuthenticationFailed), + }), + ) } } impl OutboundUpgrade for NoiseAuthenticated where - NoiseConfig: UpgradeInfo + OutboundUpgrade, NoiseOutput), - Error = NoiseError - > + 'static, + NoiseConfig: UpgradeInfo + + OutboundUpgrade, NoiseOutput), Error = NoiseError> + + 'static, as OutboundUpgrade>::Future: Send, T: AsyncRead + AsyncWrite + Send + 'static, C: Protocol + AsRef<[u8]> + Zeroize + Send + 'static, @@ -376,11 +409,14 @@ where type Future = Pin> + Send>>; fn upgrade_outbound(self, socket: T, info: Self::Info) -> Self::Future { - Box::pin(self.config.upgrade_outbound(socket, info) - .and_then(|(remote, io)| match remote { - RemoteIdentity::IdentityKey(pk) => future::ok((pk.to_peer_id(), io)), - _ => future::err(NoiseError::AuthenticationFailed) - })) + Box::pin( + self.config + .upgrade_outbound(socket, info) + .and_then(|(remote, io)| match remote { + RemoteIdentity::IdentityKey(pk) => 
future::ok((pk.to_peer_id(), io)), + _ => future::err(NoiseError::AuthenticationFailed), + }), + ) } } diff --git a/transports/noise/src/protocol.rs b/transports/noise/src/protocol.rs index 7c61274acb5..aa2acb150a9 100644 --- a/transports/noise/src/protocol.rs +++ b/transports/noise/src/protocol.rs @@ -92,16 +92,17 @@ pub trait Protocol { #[allow(deprecated)] fn verify(id_pk: &identity::PublicKey, dh_pk: &PublicKey, sig: &Option>) -> bool where - C: AsRef<[u8]> + C: AsRef<[u8]>, { Self::linked(id_pk, dh_pk) - || - sig.as_ref().map_or(false, |s| id_pk.verify(dh_pk.as_ref(), s)) + || sig + .as_ref() + .map_or(false, |s| id_pk.verify(dh_pk.as_ref(), s)) } fn sign(id_keys: &identity::Keypair, dh_pk: &PublicKey) -> Result, NoiseError> where - C: AsRef<[u8]> + C: AsRef<[u8]>, { Ok(id_keys.sign(dh_pk.as_ref())?) } @@ -118,7 +119,7 @@ pub struct Keypair { #[derive(Clone)] pub struct AuthenticKeypair { keypair: Keypair, - identity: KeypairIdentity + identity: KeypairIdentity, } impl AuthenticKeypair { @@ -143,7 +144,7 @@ pub struct KeypairIdentity { /// The public identity key. pub public: identity::PublicKey, /// The signature over the public DH key. - pub signature: Option> + pub signature: Option>, } impl Keypair { @@ -159,19 +160,25 @@ impl Keypair { /// Turn this DH keypair into a [`AuthenticKeypair`], i.e. a DH keypair that /// is authentic w.r.t. the given identity keypair, by signing the DH public key. - pub fn into_authentic(self, id_keys: &identity::Keypair) -> Result, NoiseError> + pub fn into_authentic( + self, + id_keys: &identity::Keypair, + ) -> Result, NoiseError> where T: AsRef<[u8]>, - T: Protocol + T: Protocol, { let sig = T::sign(id_keys, &self.public)?; let identity = KeypairIdentity { public: id_keys.public(), - signature: Some(sig) + signature: Some(sig), }; - Ok(AuthenticKeypair { keypair: self, identity }) + Ok(AuthenticKeypair { + keypair: self, + identity, + }) } } @@ -228,7 +235,10 @@ impl snow::resolvers::CryptoResolver for Resolver { } } - fn resolve_hash(&self, choice: &snow::params::HashChoice) -> Option> { + fn resolve_hash( + &self, + choice: &snow::params::HashChoice, + ) -> Option> { #[cfg(target_arch = "wasm32")] { snow::resolvers::DefaultResolver.resolve_hash(choice) @@ -239,7 +249,10 @@ impl snow::resolvers::CryptoResolver for Resolver { } } - fn resolve_cipher(&self, choice: &snow::params::CipherChoice) -> Option> { + fn resolve_cipher( + &self, + choice: &snow::params::CipherChoice, + ) -> Option> { #[cfg(target_arch = "wasm32")] { snow::resolvers::DefaultResolver.resolve_cipher(choice) diff --git a/transports/noise/src/protocol/x25519.rs b/transports/noise/src/protocol/x25519.rs index c4e79bc33ae..c0d3936ee36 100644 --- a/transports/noise/src/protocol/x25519.rs +++ b/transports/noise/src/protocol/x25519.rs @@ -29,8 +29,8 @@ use lazy_static::lazy_static; use libp2p_core::UpgradeInfo; use libp2p_core::{identity, identity::ed25519}; use rand::Rng; -use sha2::{Sha512, Digest}; -use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519}; +use sha2::{Digest, Sha512}; +use x25519_dalek::{x25519, X25519_BASEPOINT_BYTES}; use zeroize::Zeroize; use super::*; @@ -40,12 +40,10 @@ lazy_static! 
{ .parse() .map(ProtocolParams) .expect("Invalid protocol name"); - static ref PARAMS_IX: ProtocolParams = "Noise_IX_25519_ChaChaPoly_SHA256" .parse() .map(ProtocolParams) .expect("Invalid protocol name"); - static ref PARAMS_XX: ProtocolParams = "Noise_XX_25519_ChaChaPoly_SHA256" .parse() .map(ProtocolParams) @@ -115,7 +113,7 @@ impl Protocol for X25519 { fn public_from_bytes(bytes: &[u8]) -> Result, NoiseError> { if bytes.len() != 32 { - return Err(NoiseError::InvalidKey) + return Err(NoiseError::InvalidKey); } let mut pk = [0u8; 32]; pk.copy_from_slice(bytes); @@ -137,7 +135,7 @@ impl Keypair { pub(super) fn default() -> Self { Keypair { secret: SecretKey(X25519([0u8; 32])), - public: PublicKey(X25519([0u8; 32])) + public: PublicKey(X25519([0u8; 32])), } } @@ -170,14 +168,14 @@ impl Keypair { let kp = Keypair::from(SecretKey::from_ed25519(&p.secret())); let id = KeypairIdentity { public: id_keys.public(), - signature: None + signature: None, }; Some(AuthenticKeypair { keypair: kp, - identity: id + identity: id, }) } - _ => None + _ => None, } } } @@ -193,10 +191,13 @@ impl From> for Keypair { impl PublicKey { /// Construct a curve25519 public key from an Ed25519 public key. pub fn from_ed25519(pk: &ed25519::PublicKey) -> Self { - PublicKey(X25519(CompressedEdwardsY(pk.encode()) - .decompress() - .expect("An Ed25519 public key is a valid point by construction.") - .to_montgomery().0)) + PublicKey(X25519( + CompressedEdwardsY(pk.encode()) + .decompress() + .expect("An Ed25519 public key is a valid point by construction.") + .to_montgomery() + .0, + )) } } @@ -227,11 +228,21 @@ impl SecretKey { #[doc(hidden)] impl snow::types::Dh for Keypair { - fn name(&self) -> &'static str { "25519" } - fn pub_len(&self) -> usize { 32 } - fn priv_len(&self) -> usize { 32 } - fn pubkey(&self) -> &[u8] { self.public.as_ref() } - fn privkey(&self) -> &[u8] { self.secret.as_ref() } + fn name(&self) -> &'static str { + "25519" + } + fn pub_len(&self) -> usize { + 32 + } + fn priv_len(&self) -> usize { + 32 + } + fn pubkey(&self) -> &[u8] { + self.public.as_ref() + } + fn privkey(&self) -> &[u8] { + self.secret.as_ref() + } fn set(&mut self, sk: &[u8]) { let mut secret = [0u8; 32]; @@ -251,20 +262,20 @@ impl snow::types::Dh for Keypair { fn dh(&self, pk: &[u8], shared_secret: &mut [u8]) -> Result<(), ()> { let mut p = [0; 32]; - p.copy_from_slice(&pk[.. 32]); + p.copy_from_slice(&pk[..32]); let ss = x25519((self.secret.0).0, p); - shared_secret[.. 
32].copy_from_slice(&ss[..]); + shared_secret[..32].copy_from_slice(&ss[..]); Ok(()) } } #[cfg(test)] mod tests { + use super::*; use libp2p_core::identity::ed25519; use quickcheck::*; use sodiumoxide::crypto::sign; use std::os::raw::c_int; - use super::*; use x25519_dalek::StaticSecret; // ed25519 to x25519 keypair conversion must yield the same results as @@ -276,7 +287,8 @@ mod tests { let x25519 = Keypair::from(SecretKey::from_ed25519(&ed25519.secret())); let sodium_sec = ed25519_sk_to_curve25519(&sign::SecretKey(ed25519.encode())); - let sodium_pub = ed25519_pk_to_curve25519(&sign::PublicKey(ed25519.public().encode().clone())); + let sodium_pub = + ed25519_pk_to_curve25519(&sign::PublicKey(ed25519.public().encode().clone())); let our_pub = x25519.public.0; // libsodium does the [clamping] of the scalar upon key construction, @@ -288,8 +300,7 @@ mod tests { // [clamping]: http://www.lix.polytechnique.fr/~smith/ECC/#scalar-clamping let our_sec = StaticSecret::from((x25519.secret.0).0).to_bytes(); - sodium_sec.as_ref() == Some(&our_sec) && - sodium_pub.as_ref() == Some(&our_pub.0) + sodium_sec.as_ref() == Some(&our_sec) && sodium_pub.as_ref() == Some(&our_pub.0) } quickcheck(prop as fn() -> _); @@ -340,4 +351,3 @@ mod tests { } } } - diff --git a/transports/noise/src/protocol/x25519_spec.rs b/transports/noise/src/protocol/x25519_spec.rs index 16e3ffeafee..2f2c24237a6 100644 --- a/transports/noise/src/protocol/x25519_spec.rs +++ b/transports/noise/src/protocol/x25519_spec.rs @@ -23,13 +23,13 @@ //! [libp2p-noise-spec]: https://github.com/libp2p/specs/tree/master/noise use crate::{NoiseConfig, NoiseError, Protocol, ProtocolParams}; -use libp2p_core::UpgradeInfo; use libp2p_core::identity; +use libp2p_core::UpgradeInfo; use rand::Rng; -use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519}; +use x25519_dalek::{x25519, X25519_BASEPOINT_BYTES}; use zeroize::Zeroize; -use super::{*, x25519::X25519}; +use super::{x25519::X25519, *}; /// Prefix of static key signatures for domain separation. const STATIC_KEY_DOMAIN: &str = "noise-libp2p-static-key:"; @@ -117,32 +117,48 @@ impl Protocol for X25519Spec { fn public_from_bytes(bytes: &[u8]) -> Result, NoiseError> { if bytes.len() != 32 { - return Err(NoiseError::InvalidKey) + return Err(NoiseError::InvalidKey); } let mut pk = [0u8; 32]; pk.copy_from_slice(bytes); Ok(PublicKey(X25519Spec(pk))) } - fn verify(id_pk: &identity::PublicKey, dh_pk: &PublicKey, sig: &Option>) -> bool - { + fn verify( + id_pk: &identity::PublicKey, + dh_pk: &PublicKey, + sig: &Option>, + ) -> bool { sig.as_ref().map_or(false, |s| { id_pk.verify(&[STATIC_KEY_DOMAIN.as_bytes(), dh_pk.as_ref()].concat(), s) }) } - fn sign(id_keys: &identity::Keypair, dh_pk: &PublicKey) -> Result, NoiseError> { + fn sign( + id_keys: &identity::Keypair, + dh_pk: &PublicKey, + ) -> Result, NoiseError> { Ok(id_keys.sign(&[STATIC_KEY_DOMAIN.as_bytes(), dh_pk.as_ref()].concat())?) 
} } #[doc(hidden)] impl snow::types::Dh for Keypair { - fn name(&self) -> &'static str { "25519" } - fn pub_len(&self) -> usize { 32 } - fn priv_len(&self) -> usize { 32 } - fn pubkey(&self) -> &[u8] { self.public.as_ref() } - fn privkey(&self) -> &[u8] { self.secret.as_ref() } + fn name(&self) -> &'static str { + "25519" + } + fn pub_len(&self) -> usize { + 32 + } + fn priv_len(&self) -> usize { + 32 + } + fn pubkey(&self) -> &[u8] { + self.public.as_ref() + } + fn privkey(&self) -> &[u8] { + self.secret.as_ref() + } fn set(&mut self, sk: &[u8]) { let mut secret = [0u8; 32]; @@ -162,9 +178,9 @@ impl snow::types::Dh for Keypair { fn dh(&self, pk: &[u8], shared_secret: &mut [u8]) -> Result<(), ()> { let mut p = [0; 32]; - p.copy_from_slice(&pk[.. 32]); + p.copy_from_slice(&pk[..32]); let ss = x25519((self.secret.0).0, p); - shared_secret[.. 32].copy_from_slice(&ss[..]); + shared_secret[..32].copy_from_slice(&ss[..]); Ok(()) } } diff --git a/transports/noise/tests/smoke.rs b/transports/noise/tests/smoke.rs index 829ebc6bf08..dc5a386dbf8 100644 --- a/transports/noise/tests/smoke.rs +++ b/transports/noise/tests/smoke.rs @@ -19,11 +19,16 @@ // DEALINGS IN THE SOFTWARE. use async_io::Async; -use futures::{future::{self, Either}, prelude::*}; +use futures::{ + future::{self, Either}, + prelude::*, +}; use libp2p_core::identity; -use libp2p_core::upgrade::{self, Negotiated, apply_inbound, apply_outbound}; -use libp2p_core::transport::{Transport, ListenerEvent}; -use libp2p_noise::{Keypair, X25519, X25519Spec, NoiseConfig, RemoteIdentity, NoiseError, NoiseOutput}; +use libp2p_core::transport::{ListenerEvent, Transport}; +use libp2p_core::upgrade::{self, apply_inbound, apply_outbound, Negotiated}; +use libp2p_noise::{ + Keypair, NoiseConfig, NoiseError, NoiseOutput, RemoteIdentity, X25519Spec, X25519, +}; use libp2p_tcp::TcpConfig; use log::info; use quickcheck::QuickCheck; @@ -50,24 +55,40 @@ fn xx_spec() { let server_id_public = server_id.public(); let client_id_public = client_id.public(); - let server_dh = Keypair::::new().into_authentic(&server_id).unwrap(); + let server_dh = Keypair::::new() + .into_authentic(&server_id) + .unwrap(); let server_transport = TcpConfig::new() .and_then(move |output, endpoint| { - upgrade::apply(output, NoiseConfig::xx(server_dh), endpoint, upgrade::Version::V1) + upgrade::apply( + output, + NoiseConfig::xx(server_dh), + endpoint, + upgrade::Version::V1, + ) }) .and_then(move |out, _| expect_identity(out, &client_id_public)); - let client_dh = Keypair::::new().into_authentic(&client_id).unwrap(); + let client_dh = Keypair::::new() + .into_authentic(&client_id) + .unwrap(); let client_transport = TcpConfig::new() .and_then(move |output, endpoint| { - upgrade::apply(output, NoiseConfig::xx(client_dh), endpoint, upgrade::Version::V1) + upgrade::apply( + output, + NoiseConfig::xx(client_dh), + endpoint, + upgrade::Version::V1, + ) }) .and_then(move |out, _| expect_identity(out, &server_id_public)); run(server_transport, client_transport, messages); true } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) + QuickCheck::new() + .max_tests(30) + .quickcheck(prop as fn(Vec) -> bool) } #[test] @@ -84,21 +105,33 @@ fn xx() { let server_dh = Keypair::::new().into_authentic(&server_id).unwrap(); let server_transport = TcpConfig::new() .and_then(move |output, endpoint| { - upgrade::apply(output, NoiseConfig::xx(server_dh), endpoint, upgrade::Version::V1) + upgrade::apply( + output, + NoiseConfig::xx(server_dh), + endpoint, + upgrade::Version::V1, + ) }) 
.and_then(move |out, _| expect_identity(out, &client_id_public)); let client_dh = Keypair::::new().into_authentic(&client_id).unwrap(); let client_transport = TcpConfig::new() .and_then(move |output, endpoint| { - upgrade::apply(output, NoiseConfig::xx(client_dh), endpoint, upgrade::Version::V1) + upgrade::apply( + output, + NoiseConfig::xx(client_dh), + endpoint, + upgrade::Version::V1, + ) }) .and_then(move |out, _| expect_identity(out, &server_id_public)); run(server_transport, client_transport, messages); true } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) + QuickCheck::new() + .max_tests(30) + .quickcheck(prop as fn(Vec) -> bool) } #[test] @@ -115,21 +148,33 @@ fn ix() { let server_dh = Keypair::::new().into_authentic(&server_id).unwrap(); let server_transport = TcpConfig::new() .and_then(move |output, endpoint| { - upgrade::apply(output, NoiseConfig::ix(server_dh), endpoint, upgrade::Version::V1) + upgrade::apply( + output, + NoiseConfig::ix(server_dh), + endpoint, + upgrade::Version::V1, + ) }) .and_then(move |out, _| expect_identity(out, &client_id_public)); let client_dh = Keypair::::new().into_authentic(&client_id).unwrap(); let client_transport = TcpConfig::new() .and_then(move |output, endpoint| { - upgrade::apply(output, NoiseConfig::ix(client_dh), endpoint, upgrade::Version::V1) + upgrade::apply( + output, + NoiseConfig::ix(client_dh), + endpoint, + upgrade::Version::V1, + ) }) .and_then(move |out, _| expect_identity(out, &server_id_public)); run(server_transport, client_transport, messages); true } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) + QuickCheck::new() + .max_tests(30) + .quickcheck(prop as fn(Vec) -> bool) } #[test] @@ -150,8 +195,11 @@ fn ik_xx() { if endpoint.is_listener() { Either::Left(apply_inbound(output, NoiseConfig::ik_listener(server_dh))) } else { - Either::Right(apply_outbound(output, NoiseConfig::xx(server_dh), - upgrade::Version::V1)) + Either::Right(apply_outbound( + output, + NoiseConfig::xx(server_dh), + upgrade::Version::V1, + )) } }) .and_then(move |out, _| expect_identity(out, &client_id_public)); @@ -161,9 +209,11 @@ fn ik_xx() { let client_transport = TcpConfig::new() .and_then(move |output, endpoint| { if endpoint.is_dialer() { - Either::Left(apply_outbound(output, + Either::Left(apply_outbound( + output, NoiseConfig::ik_dialer(client_dh, server_id_public, server_dh_public), - upgrade::Version::V1)) + upgrade::Version::V1, + )) } else { Either::Right(apply_inbound(output, NoiseConfig::xx(client_dh))) } @@ -173,7 +223,9 @@ fn ik_xx() { run(server_transport, client_transport, messages); true } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec) -> bool) + QuickCheck::new() + .max_tests(30) + .quickcheck(prop as fn(Vec) -> bool) } type Output = (RemoteIdentity, NoiseOutput>>); @@ -188,14 +240,15 @@ where U::Dial: Send + 'static, U::Listener: Send + 'static, U::ListenerUpgrade: Send + 'static, - I: IntoIterator + Clone + I: IntoIterator + Clone, { futures::executor::block_on(async { let mut server: T::Listener = server_transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) .unwrap(); - let server_address = server.try_next() + let server_address = server + .try_next() .await .expect("some event") .expect("no error") @@ -204,7 +257,8 @@ where let outbound_msgs = messages.clone(); let client_fut = async { - let mut client_session = client_transport.dial(server_address.clone()) + let mut client_session = client_transport + .dial(server_address.clone()) .unwrap() .await .map(|(_, 
session)| session) @@ -219,7 +273,8 @@ where }; let server_fut = async { - let mut server_session = server.try_next() + let mut server_session = server + .try_next() .await .expect("some event") .map(ListenerEvent::into_upgrade) @@ -236,12 +291,15 @@ where match server_session.read_exact(&mut n).await { Ok(()) => u64::from_be_bytes(n), Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => 0, - Err(e) => panic!("error reading len: {}", e) + Err(e) => panic!("error reading len: {}", e), } }; info!("server: reading message ({} bytes)", len); let mut server_buffer = vec![0; len.try_into().unwrap()]; - server_session.read_exact(&mut server_buffer).await.expect("no error"); + server_session + .read_exact(&mut server_buffer) + .await + .expect("no error"); assert_eq!(server_buffer, m.0) } }; @@ -250,12 +308,13 @@ where }) } -fn expect_identity(output: Output, pk: &identity::PublicKey) - -> impl Future, NoiseError>> -{ +fn expect_identity( + output: Output, + pk: &identity::PublicKey, +) -> impl Future, NoiseError>> { match output.0 { RemoteIdentity::IdentityKey(ref k) if k == pk => future::ok(output), - _ => panic!("Unexpected remote identity") + _ => panic!("Unexpected remote identity"), } } diff --git a/transports/plaintext/build.rs b/transports/plaintext/build.rs index 1b0feff6a40..56c7b20121a 100644 --- a/transports/plaintext/build.rs +++ b/transports/plaintext/build.rs @@ -19,6 +19,5 @@ // DEALINGS IN THE SOFTWARE. fn main() { - prost_build::compile_protos(&["src/structs.proto"], &["src"]).unwrap(); + prost_build::compile_protos(&["src/structs.proto"], &["src"]).unwrap(); } - diff --git a/transports/plaintext/src/error.rs b/transports/plaintext/src/error.rs index 7ede99af60c..9f512c4f58e 100644 --- a/transports/plaintext/src/error.rs +++ b/transports/plaintext/src/error.rs @@ -47,16 +47,14 @@ impl error::Error for PlainTextError { impl fmt::Display for PlainTextError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - PlainTextError::IoError(e) => - write!(f, "I/O error: {}", e), - PlainTextError::InvalidPayload(protobuf_error) => { - match protobuf_error { - Some(e) => write!(f, "Protobuf error: {}", e), - None => f.write_str("Failed to parse one of the handshake protobuf messages") - } + PlainTextError::IoError(e) => write!(f, "I/O error: {}", e), + PlainTextError::InvalidPayload(protobuf_error) => match protobuf_error { + Some(e) => write!(f, "Protobuf error: {}", e), + None => f.write_str("Failed to parse one of the handshake protobuf messages"), }, - PlainTextError::InvalidPeerId => - f.write_str("The peer id of the exchange isn't consistent with the remote public key"), + PlainTextError::InvalidPeerId => f.write_str( + "The peer id of the exchange isn't consistent with the remote public key", + ), } } } diff --git a/transports/plaintext/src/handshake.rs b/transports/plaintext/src/handshake.rs index d981df4d964..6534c6d7abd 100644 --- a/transports/plaintext/src/handshake.rs +++ b/transports/plaintext/src/handshake.rs @@ -18,14 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
-use crate::PlainText2Config; use crate::error::PlainTextError; use crate::structs_proto::Exchange; +use crate::PlainText2Config; +use asynchronous_codec::{Framed, FramedParts}; use bytes::{Bytes, BytesMut}; use futures::prelude::*; -use asynchronous_codec::{Framed, FramedParts}; -use libp2p_core::{PublicKey, PeerId}; +use libp2p_core::{PeerId, PublicKey}; use log::{debug, trace}; use prost::Message; use std::io::{Error as IoError, ErrorKind as IoErrorKind}; @@ -33,7 +33,7 @@ use unsigned_varint::codec::UviBytes; struct HandshakeContext { config: PlainText2Config, - state: T + state: T, } // HandshakeContext<()> --with_local-> HandshakeContext @@ -54,28 +54,31 @@ impl HandshakeContext { fn new(config: PlainText2Config) -> Self { let exchange = Exchange { id: Some(config.local_public_key.to_peer_id().to_bytes()), - pubkey: Some(config.local_public_key.to_protobuf_encoding()) + pubkey: Some(config.local_public_key.to_protobuf_encoding()), }; let mut buf = Vec::with_capacity(exchange.encoded_len()); - exchange.encode(&mut buf).expect("Vec provides capacity as needed"); + exchange + .encode(&mut buf) + .expect("Vec provides capacity as needed"); Self { config, state: Local { - exchange_bytes: buf - } + exchange_bytes: buf, + }, } } - fn with_remote(self, exchange_bytes: BytesMut) - -> Result, PlainTextError> - { + fn with_remote( + self, + exchange_bytes: BytesMut, + ) -> Result, PlainTextError> { let prop = match Exchange::decode(exchange_bytes) { Ok(prop) => prop, Err(e) => { debug!("failed to parse remote's exchange protobuf message"); return Err(PlainTextError::InvalidPayload(Some(e))); - }, + } }; let pb_pubkey = prop.pubkey.unwrap_or_default(); @@ -84,20 +87,20 @@ impl HandshakeContext { Err(_) => { debug!("failed to parse remote's exchange's pubkey protobuf"); return Err(PlainTextError::InvalidPayload(None)); - }, + } }; let peer_id = match PeerId::from_bytes(&prop.id.unwrap_or_default()) { Ok(p) => p, Err(_) => { debug!("failed to parse remote's exchange's id protobuf"); return Err(PlainTextError::InvalidPayload(None)); - }, + } }; // Check the validity of the remote's `Exchange`. if peer_id != public_key.to_peer_id() { debug!("the remote's `PeerId` isn't consistent with the remote's public key"); - return Err(PlainTextError::InvalidPeerId) + return Err(PlainTextError::InvalidPeerId); } Ok(HandshakeContext { @@ -105,13 +108,15 @@ impl HandshakeContext { state: Remote { peer_id, public_key, - } + }, }) } } -pub async fn handshake(socket: S, config: PlainText2Config) - -> Result<(S, Remote, Bytes), PlainTextError> +pub async fn handshake( + socket: S, + config: PlainText2Config, +) -> Result<(S, Remote, Bytes), PlainTextError> where S: AsyncRead + AsyncWrite + Send + Unpin, { @@ -122,7 +127,9 @@ where let context = HandshakeContext::new(config); trace!("sending exchange to remote"); - framed_socket.send(BytesMut::from(&context.state.exchange_bytes[..])).await?; + framed_socket + .send(BytesMut::from(&context.state.exchange_bytes[..])) + .await?; trace!("receiving the remote's exchange"); let context = match framed_socket.next().await { @@ -134,9 +141,17 @@ where } }; - trace!("received exchange from remote; pubkey = {:?}", context.state.public_key); - - let FramedParts { io, read_buffer, write_buffer, .. } = framed_socket.into_parts(); + trace!( + "received exchange from remote; pubkey = {:?}", + context.state.public_key + ); + + let FramedParts { + io, + read_buffer, + write_buffer, + .. 
+ } = framed_socket.into_parts(); assert!(write_buffer.is_empty()); Ok((io, context.state, read_buffer.freeze())) } diff --git a/transports/plaintext/src/lib.rs b/transports/plaintext/src/lib.rs index c647ba0c474..99fbf2fd3d6 100644 --- a/transports/plaintext/src/lib.rs +++ b/transports/plaintext/src/lib.rs @@ -21,19 +21,16 @@ use crate::error::PlainTextError; use bytes::Bytes; +use futures::future::BoxFuture; use futures::future::{self, Ready}; use futures::prelude::*; -use futures::future::BoxFuture; -use libp2p_core::{ - identity, - InboundUpgrade, - OutboundUpgrade, - UpgradeInfo, - PeerId, - PublicKey, -}; +use libp2p_core::{identity, InboundUpgrade, OutboundUpgrade, PeerId, PublicKey, UpgradeInfo}; use log::debug; -use std::{io, iter, pin::Pin, task::{Context, Poll}}; +use std::{ + io, iter, + pin::Pin, + task::{Context, Poll}, +}; use void::Void; mod error; @@ -42,7 +39,6 @@ mod structs_proto { include!(concat!(env!("OUT_DIR"), "/structs.rs")); } - /// `PlainText1Config` is an insecure connection handshake for testing purposes only. /// /// > **Note**: Given that `PlainText1Config` has no notion of exchanging peer identity information it is not compatible @@ -119,7 +115,7 @@ impl UpgradeInfo for PlainText2Config { impl InboundUpgrade for PlainText2Config where - C: AsyncRead + AsyncWrite + Send + Unpin + 'static + C: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Output = (PeerId, PlainTextOutput); type Error = PlainTextError; @@ -132,7 +128,7 @@ where impl OutboundUpgrade for PlainText2Config where - C: AsyncRead + AsyncWrite + Send + Unpin + 'static + C: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Output = (PeerId, PlainTextOutput); type Error = PlainTextError; @@ -146,7 +142,7 @@ where impl PlainText2Config { async fn handshake(self, socket: T) -> Result<(PeerId, PlainTextOutput), PlainTextError> where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { debug!("Starting plaintext handshake."); let (socket, remote, read_buffer) = handshake::handshake(socket, self).await?; @@ -158,7 +154,7 @@ impl PlainText2Config { socket, remote_key: remote.public_key, read_buffer, - } + }, )) } } @@ -179,35 +175,35 @@ where } impl AsyncRead for PlainTextOutput { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) - -> Poll> - { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { if !self.read_buffer.is_empty() { let n = std::cmp::min(buf.len(), self.read_buffer.len()); let b = self.read_buffer.split_to(n); buf[..n].copy_from_slice(&b[..]); - return Poll::Ready(Ok(n)) + return Poll::Ready(Ok(n)); } AsyncRead::poll_read(Pin::new(&mut self.socket), cx, buf) } } impl AsyncWrite for PlainTextOutput { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) - -> Poll> - { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { AsyncWrite::poll_write(Pin::new(&mut self.socket), cx, buf) } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) - -> Poll> - { + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { AsyncWrite::poll_flush(Pin::new(&mut self.socket), cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) - -> Poll> - { + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { AsyncWrite::poll_close(Pin::new(&mut self.socket), cx) } } diff --git a/transports/plaintext/tests/smoke.rs b/transports/plaintext/tests/smoke.rs 
index 20a79c32e1d..ce155bdd92e 100644 --- a/transports/plaintext/tests/smoke.rs +++ b/transports/plaintext/tests/smoke.rs @@ -18,12 +18,12 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures::io::{AsyncWriteExt, AsyncReadExt}; +use futures::io::{AsyncReadExt, AsyncWriteExt}; use futures::stream::TryStreamExt; use libp2p_core::{ identity, multiaddr::Multiaddr, - transport::{Transport, ListenerEvent}, + transport::{ListenerEvent, Transport}, upgrade, }; use libp2p_plaintext::PlainText2Config; @@ -45,38 +45,40 @@ fn variable_msg_length() { let client_id_public = client_id.public(); futures::executor::block_on(async { - let server_transport = libp2p_core::transport::MemoryTransport{}.and_then( - move |output, endpoint| { + let server_transport = + libp2p_core::transport::MemoryTransport {}.and_then(move |output, endpoint| { upgrade::apply( output, - PlainText2Config{local_public_key: server_id_public}, + PlainText2Config { + local_public_key: server_id_public, + }, endpoint, libp2p_core::upgrade::Version::V1, ) - } - ); + }); - let client_transport = libp2p_core::transport::MemoryTransport{}.and_then( - move |output, endpoint| { + let client_transport = + libp2p_core::transport::MemoryTransport {}.and_then(move |output, endpoint| { upgrade::apply( output, - PlainText2Config{local_public_key: client_id_public}, + PlainText2Config { + local_public_key: client_id_public, + }, endpoint, libp2p_core::upgrade::Version::V1, ) - } - ); + }); - - let server_address: Multiaddr = format!( - "/memory/{}", - std::cmp::Ord::max(1, rand::random::()) - ).parse().unwrap(); + let server_address: Multiaddr = + format!("/memory/{}", std::cmp::Ord::max(1, rand::random::())) + .parse() + .unwrap(); let mut server = server_transport.listen_on(server_address.clone()).unwrap(); // Ignore server listen address event. 
- let _ = server.try_next() + let _ = server + .try_next() .await .expect("some event") .expect("no error") @@ -85,17 +87,25 @@ fn variable_msg_length() { let client_fut = async { debug!("dialing {:?}", server_address); - let (received_server_id, mut client_channel) = client_transport.dial(server_address).unwrap().await.unwrap(); + let (received_server_id, mut client_channel) = client_transport + .dial(server_address) + .unwrap() + .await + .unwrap(); assert_eq!(received_server_id, server_id.public().to_peer_id()); debug!("Client: writing message."); - client_channel.write_all(&mut msg_to_send).await.expect("no error"); + client_channel + .write_all(&mut msg_to_send) + .await + .expect("no error"); debug!("Client: flushing channel."); client_channel.flush().await.expect("no error"); }; let server_fut = async { - let mut server_channel = server.try_next() + let mut server_channel = server + .try_next() .await .expect("some event") .map(ListenerEvent::into_upgrade) @@ -108,7 +118,10 @@ fn variable_msg_length() { let mut server_buffer = vec![0; msg_to_receive.len()]; debug!("Server: reading message."); - server_channel.read_exact(&mut server_buffer).await.expect("reading client message"); + server_channel + .read_exact(&mut server_buffer) + .await + .expect("reading client message"); assert_eq!(server_buffer, msg_to_receive); }; @@ -117,5 +130,7 @@ fn variable_msg_length() { }) } - QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec)) + QuickCheck::new() + .max_tests(30) + .quickcheck(prop as fn(Vec)) } diff --git a/transports/pnet/src/crypt_writer.rs b/transports/pnet/src/crypt_writer.rs index a61957d395d..e13bb446ce6 100644 --- a/transports/pnet/src/crypt_writer.rs +++ b/transports/pnet/src/crypt_writer.rs @@ -100,7 +100,9 @@ fn poll_flush_buf( if written > 0 { buf.drain(..written); } - if let Poll::Ready(Ok(())) = ret { debug_assert!(buf.is_empty()); } + if let Poll::Ready(Ok(())) = ret { + debug_assert!(buf.is_empty()); + } ret } diff --git a/transports/pnet/src/lib.rs b/transports/pnet/src/lib.rs index 468e82cd7c9..efd27b14667 100644 --- a/transports/pnet/src/lib.rs +++ b/transports/pnet/src/lib.rs @@ -74,7 +74,10 @@ impl PreSharedKey { cipher.apply_keystream(&mut enc); let mut hasher = Shake128::default(); hasher.write_all(&enc).expect("shake128 failed"); - hasher.finalize_xof().read_exact(&mut out).expect("shake128 failed"); + hasher + .finalize_xof() + .read_exact(&mut out) + .expect("shake128 failed"); Fingerprint(out) } } diff --git a/transports/tcp/src/lib.rs b/transports/tcp/src/lib.rs index 5cf4f0fcebd..e556bf39087 100644 --- a/transports/tcp/src/lib.rs +++ b/transports/tcp/src/lib.rs @@ -57,14 +57,14 @@ use socket2::{Domain, Socket, Type}; use std::{ collections::HashSet, io, - net::{SocketAddr, IpAddr, TcpListener}, + net::{IpAddr, SocketAddr, TcpListener}, pin::Pin, sync::{Arc, RwLock}, task::{Context, Poll}, time::Duration, }; -use provider::{Provider, IfEvent}; +use provider::{IfEvent, Provider}; /// The configuration for a TCP/IP transport capability for libp2p. /// @@ -101,7 +101,7 @@ enum PortReuse { Enabled { /// The addresses and ports of the listening sockets /// registered as eligible for port reuse when dialing. 
- listen_addrs: Arc>> + listen_addrs: Arc>>, }, } @@ -151,7 +151,7 @@ impl PortReuse { if ip.is_ipv4() == remote_ip.is_ipv4() && ip.is_loopback() == remote_ip.is_loopback() { - return Some(SocketAddr::new(*ip, *port)) + return Some(SocketAddr::new(*ip, *port)); } } } @@ -302,7 +302,7 @@ where pub fn port_reuse(mut self, port_reuse: bool) -> Self { self.port_reuse = if port_reuse { PortReuse::Enabled { - listen_addrs: Arc::new(RwLock::new(HashSet::new())) + listen_addrs: Arc::new(RwLock::new(HashSet::new())), } } else { PortReuse::Disabled @@ -385,8 +385,7 @@ where return Err(TransportError::MultiaddrNotSupported(addr)); }; log::debug!("listening on {}", socket_addr); - self.do_listen(socket_addr) - .map_err(TransportError::Other) + self.do_listen(socket_addr).map_err(TransportError::Other) } fn dial(self, addr: Multiaddr) -> Result> { @@ -439,19 +438,19 @@ enum InAddr { /// The stream accepts connections on a single interface. One { addr: IpAddr, - out: Option + out: Option, }, /// The stream accepts connections on all interfaces. Any { addrs: HashSet, if_watch: IfWatch, - } + }, } /// A stream of incoming connections on one or more interfaces. pub struct TcpListenStream where - T: Provider + T: Provider, { /// The socket address that the listening socket is bound to, /// which may be a "wildcard address" like `INADDR_ANY` or `IN6ADDR_ANY` @@ -481,7 +480,7 @@ where impl TcpListenStream where - T: Provider + T: Provider, { /// Constructs a `TcpListenStream` for incoming connections around /// the given `TcpListener`. @@ -527,7 +526,7 @@ where match &self.in_addr { InAddr::One { addr, .. } => { self.port_reuse.unregister(*addr, self.listen_addr.port()); - }, + } InAddr::Any { addrs, .. } => { for addr in addrs { self.port_reuse.unregister(*addr, self.listen_addr.port()); @@ -539,7 +538,7 @@ where impl Drop for TcpListenStream where - T: Provider + T: Provider, { fn drop(&mut self) { self.disable_port_reuse(); @@ -565,7 +564,7 @@ where IfWatch::Pending(f) => match ready!(Pin::new(f).poll(cx)) { Ok(w) => { *if_watch = IfWatch::Ready(w); - continue + continue; } Err(err) => { log::debug! { @@ -578,42 +577,52 @@ where } }, // Consume all events for up/down interface changes. 
- IfWatch::Ready(watch) => while let Poll::Ready(ev) = T::poll_interfaces(watch, cx) { - match ev { - Ok(IfEvent::Up(inet)) => { - let ip = inet.addr(); - if me.listen_addr.is_ipv4() == ip.is_ipv4() && addrs.insert(ip) { - let ma = ip_to_multiaddr(ip, me.listen_addr.port()); - log::debug!("New listen address: {}", ma); - me.port_reuse.register(ip, me.listen_addr.port()); - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(ma)))); + IfWatch::Ready(watch) => { + while let Poll::Ready(ev) = T::poll_interfaces(watch, cx) { + match ev { + Ok(IfEvent::Up(inet)) => { + let ip = inet.addr(); + if me.listen_addr.is_ipv4() == ip.is_ipv4() && addrs.insert(ip) + { + let ma = ip_to_multiaddr(ip, me.listen_addr.port()); + log::debug!("New listen address: {}", ma); + me.port_reuse.register(ip, me.listen_addr.port()); + return Poll::Ready(Some(Ok(ListenerEvent::NewAddress( + ma, + )))); + } } - } - Ok(IfEvent::Down(inet)) => { - let ip = inet.addr(); - if me.listen_addr.is_ipv4() == ip.is_ipv4() && addrs.remove(&ip) { - let ma = ip_to_multiaddr(ip, me.listen_addr.port()); - log::debug!("Expired listen address: {}", ma); - me.port_reuse.unregister(ip, me.listen_addr.port()); - return Poll::Ready(Some(Ok(ListenerEvent::AddressExpired(ma)))); + Ok(IfEvent::Down(inet)) => { + let ip = inet.addr(); + if me.listen_addr.is_ipv4() == ip.is_ipv4() && addrs.remove(&ip) + { + let ma = ip_to_multiaddr(ip, me.listen_addr.port()); + log::debug!("Expired listen address: {}", ma); + me.port_reuse.unregister(ip, me.listen_addr.port()); + return Poll::Ready(Some(Ok( + ListenerEvent::AddressExpired(ma), + ))); + } + } + Err(err) => { + log::debug! { + "Failure polling interfaces: {:?}. Scheduling retry.", + err + }; + me.pause = Some(Delay::new(me.sleep_on_error)); + return Poll::Ready(Some(Ok(ListenerEvent::Error(err)))); } - } - Err(err) => { - log::debug! { - "Failure polling interfaces: {:?}. Scheduling retry.", - err - }; - me.pause = Some(Delay::new(me.sleep_on_error)); - return Poll::Ready(Some(Ok(ListenerEvent::Error(err)))); } } - }, + } }, // If the listener is bound to a single interface, make sure the // address is registered for port reuse and reported once. 
- InAddr::One { addr, out } => if let Some(multiaddr) = out.take() { - me.port_reuse.register(*addr, me.listen_addr.port()); - return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(multiaddr)))) + InAddr::One { addr, out } => { + if let Some(multiaddr) = out.take() { + me.port_reuse.register(*addr, me.listen_addr.port()); + return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(multiaddr)))); + } } } @@ -640,7 +649,8 @@ where }; let local_addr = ip_to_multiaddr(incoming.local_addr.ip(), incoming.local_addr.port()); - let remote_addr = ip_to_multiaddr(incoming.remote_addr.ip(), incoming.remote_addr.port()); + let remote_addr = + ip_to_multiaddr(incoming.remote_addr.ip(), incoming.remote_addr.port()); log::debug!("Incoming connection from {} at {}", remote_addr, local_addr); @@ -666,18 +676,18 @@ fn multiaddr_to_socketaddr(mut addr: Multiaddr) -> Result { match proto { Protocol::Ip4(ipv4) => match port { Some(port) => return Ok(SocketAddr::new(ipv4.into(), port)), - None => return Err(()) + None => return Err(()), }, Protocol::Ip6(ipv6) => match port { Some(port) => return Ok(SocketAddr::new(ipv6.into(), port)), - None => return Err(()) + None => return Err(()), }, Protocol::Tcp(portnum) => match port { Some(_) => return Err(()), - None => { port = Some(portnum) } - } + None => port = Some(portnum), + }, Protocol::P2p(_) => {} - _ => return Err(()) + _ => return Err(()), } } Err(()) @@ -685,15 +695,13 @@ fn multiaddr_to_socketaddr(mut addr: Multiaddr) -> Result { // Create a [`Multiaddr`] from the given IP address and port number. fn ip_to_multiaddr(ip: IpAddr, port: u16) -> Multiaddr { - Multiaddr::empty() - .with(ip.into()) - .with(Protocol::Tcp(port)) + Multiaddr::empty().with(ip.into()).with(Protocol::Tcp(port)) } #[cfg(test)] mod tests { - use futures::channel::mpsc; use super::*; + use futures::channel::mpsc; #[test] fn multiaddr_to_tcp_conversion() { @@ -748,7 +756,7 @@ mod tests { fn communicating_between_dialer_and_listener() { env_logger::try_init().ok(); - async fn listener(addr: Multiaddr, mut ready_tx: mpsc::Sender) { + async fn listener(addr: Multiaddr, mut ready_tx: mpsc::Sender) { let tcp = GenTcpConfig::::new(); let mut listener = tcp.listen_on(addr).unwrap(); loop { @@ -762,7 +770,7 @@ mod tests { upgrade.read_exact(&mut buf).await.unwrap(); assert_eq!(buf, [1, 2, 3]); upgrade.write_all(&[4, 5, 6]).await.unwrap(); - return + return; } e => panic!("Unexpected listener event: {:?}", e), } @@ -798,7 +806,10 @@ mod tests { let (ready_tx, ready_rx) = mpsc::channel(1); let listener = listener::(addr.clone(), ready_tx); let dialer = dialer::(ready_rx); - let rt = tokio_crate::runtime::Builder::new_current_thread().enable_io().build().unwrap(); + let rt = tokio_crate::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); let tasks = tokio_crate::task::LocalSet::new(); let listener = tasks.spawn_local(listener); tasks.block_on(&rt, dialer); @@ -833,7 +844,7 @@ mod tests { panic!("No TCP port in address: {}", a) } ready_tx.send(a).await.ok(); - return + return; } _ => {} } @@ -862,7 +873,10 @@ mod tests { let (ready_tx, ready_rx) = mpsc::channel(1); let listener = listener::(addr.clone(), ready_tx); let dialer = dialer::(ready_rx); - let rt = tokio_crate::runtime::Builder::new_current_thread().enable_io().build().unwrap(); + let rt = tokio_crate::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); let tasks = tokio_crate::task::LocalSet::new(); let listener = tasks.spawn_local(listener); tasks.block_on(&rt, dialer); @@ -892,7 +906,7 @@ 
mod tests { upgrade.read_exact(&mut buf).await.unwrap(); assert_eq!(buf, [1, 2, 3]); upgrade.write_all(&[4, 5, 6]).await.unwrap(); - return + return; } e => panic!("Unexpected event: {:?}", e), } @@ -913,7 +927,7 @@ mod tests { socket.read_exact(&mut buf).await.unwrap(); assert_eq!(buf, [4, 5, 6]); } - e => panic!("Unexpected listener event: {:?}", e) + e => panic!("Unexpected listener event: {:?}", e), } } @@ -933,7 +947,10 @@ mod tests { let (ready_tx, ready_rx) = mpsc::channel(1); let listener = listener::(addr.clone(), ready_tx); let dialer = dialer::(addr.clone(), ready_rx); - let rt = tokio_crate::runtime::Builder::new_current_thread().enable_io().build().unwrap(); + let rt = tokio_crate::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); let tasks = tokio_crate::task::LocalSet::new(); let listener = tasks.spawn_local(listener); tasks.block_on(&rt, dialer); @@ -959,7 +976,7 @@ mod tests { match listener2.next().await.unwrap().unwrap() { ListenerEvent::NewAddress(addr2) => { assert_eq!(addr1, addr2); - return + return; } e => panic!("Unexpected listener event: {:?}", e), } @@ -978,7 +995,10 @@ mod tests { #[cfg(feature = "tokio")] { let listener = listen_twice::(addr.clone()); - let rt = tokio_crate::runtime::Builder::new_current_thread().enable_io().build().unwrap(); + let rt = tokio_crate::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); rt.block_on(listener); } } @@ -1011,7 +1031,10 @@ mod tests { #[cfg(feature = "tokio")] { - let rt = tokio_crate::runtime::Builder::new_current_thread().enable_io().build().unwrap(); + let rt = tokio_crate::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); let new_addr = rt.block_on(listen::(addr.clone())); assert!(!new_addr.to_string().contains("tcp/0")); } diff --git a/transports/tcp/src/provider.rs b/transports/tcp/src/provider.rs index 091a6691087..7ebeaa49ee8 100644 --- a/transports/tcp/src/provider.rs +++ b/transports/tcp/src/provider.rs @@ -26,12 +26,12 @@ pub mod async_io; #[cfg(feature = "tokio")] pub mod tokio; -use futures::io::{AsyncRead, AsyncWrite}; use futures::future::BoxFuture; +use futures::io::{AsyncRead, AsyncWrite}; use ipnet::IpNet; +use std::net::{SocketAddr, TcpListener, TcpStream}; use std::task::{Context, Poll}; use std::{fmt, io}; -use std::net::{SocketAddr, TcpListener, TcpStream}; /// An event relating to a change of availability of an address /// on a network interface. @@ -73,7 +73,10 @@ pub trait Provider: Clone + Send + 'static { /// Polls a [`Self::Listener`] for an incoming connection, ensuring a task wakeup, /// if necessary. - fn poll_accept(_: &mut Self::Listener, _: &mut Context<'_>) -> Poll>>; + fn poll_accept( + _: &mut Self::Listener, + _: &mut Context<'_>, + ) -> Poll>>; /// Polls a [`Self::IfWatcher`] for network interface changes, ensuring a task wakeup, /// if necessary. diff --git a/transports/tcp/src/provider/async_io.rs b/transports/tcp/src/provider/async_io.rs index b4ce74d6901..ab65544d872 100644 --- a/transports/tcp/src/provider/async_io.rs +++ b/transports/tcp/src/provider/async_io.rs @@ -18,15 +18,13 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
-use super::{Provider, IfEvent, Incoming}; +use super::{IfEvent, Incoming, Provider}; use async_io_crate::Async; -use futures::{ - future::{BoxFuture, FutureExt}, -}; +use futures::future::{BoxFuture, FutureExt}; use std::io; -use std::task::{Poll, Context}; use std::net; +use std::task::{Context, Poll}; #[derive(Copy, Clone)] pub enum Tcp {} @@ -49,10 +47,14 @@ impl Provider for Tcp { let stream = Async::new(s)?; stream.writable().await?; Ok(stream) - }.boxed() + } + .boxed() } - fn poll_accept(l: &mut Self::Listener, cx: &mut Context<'_>) -> Poll>> { + fn poll_accept( + l: &mut Self::Listener, + cx: &mut Context<'_>, + ) -> Poll>> { let (stream, remote_addr) = loop { match l.poll_readable(cx) { Poll::Pending => return Poll::Pending, @@ -64,13 +66,17 @@ impl Provider for Tcp { // Since it doesn't do any harm, account for false positives of // `poll_readable` just in case, i.e. try again. } - } + }, } }; let local_addr = stream.get_ref().local_addr()?; - Poll::Ready(Ok(Incoming { stream, local_addr, remote_addr })) + Poll::Ready(Ok(Incoming { + stream, + local_addr, + remote_addr, + })) } fn poll_interfaces(w: &mut Self::IfWatcher, cx: &mut Context<'_>) -> Poll> { diff --git a/transports/tcp/src/provider/tokio.rs b/transports/tcp/src/provider/tokio.rs index 0e8136f2c60..257bccd2926 100644 --- a/transports/tcp/src/provider/tokio.rs +++ b/transports/tcp/src/provider/tokio.rs @@ -18,22 +18,22 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use super::{Provider, IfEvent, Incoming}; +use super::{IfEvent, Incoming, Provider}; use futures::{ future::{self, BoxFuture, FutureExt}, prelude::*, }; use futures_timer::Delay; -use if_addrs::{IfAddr, get_if_addrs}; +use if_addrs::{get_if_addrs, IfAddr}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use std::collections::HashSet; use std::convert::TryFrom; use std::io; -use std::task::{Poll, Context}; -use std::time::Duration; use std::net; use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; #[derive(Copy, Clone)] pub enum Tcp {} @@ -50,13 +50,12 @@ impl Provider for Tcp { type IfWatcher = IfWatcher; fn if_watcher() -> BoxFuture<'static, io::Result> { - future::ready(Ok( - IfWatcher { - addrs: HashSet::new(), - delay: Delay::new(Duration::from_secs(0)), - pending: Vec::new(), - } - )).boxed() + future::ready(Ok(IfWatcher { + addrs: HashSet::new(), + delay: Delay::new(Duration::from_secs(0)), + pending: Vec::new(), + })) + .boxed() } fn new_listener(l: net::TcpListener) -> io::Result { @@ -68,48 +67,59 @@ impl Provider for Tcp { let stream = tokio_crate::net::TcpStream::try_from(s)?; stream.writable().await?; Ok(TcpStream(stream)) - }.boxed() + } + .boxed() } - fn poll_accept(l: &mut Self::Listener, cx: &mut Context<'_>) - -> Poll>> - { + fn poll_accept( + l: &mut Self::Listener, + cx: &mut Context<'_>, + ) -> Poll>> { let (stream, remote_addr) = match l.poll_accept(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok((stream, remote_addr))) => (stream, remote_addr) + Poll::Ready(Ok((stream, remote_addr))) => (stream, remote_addr), }; let local_addr = stream.local_addr()?; let stream = TcpStream(stream); - Poll::Ready(Ok(Incoming { stream, local_addr, remote_addr })) + Poll::Ready(Ok(Incoming { + stream, + local_addr, + remote_addr, + })) } fn poll_interfaces(w: &mut Self::IfWatcher, cx: &mut Context<'_>) -> Poll> { loop { if let Some(event) = w.pending.pop() { - return Poll::Ready(Ok(event)) + return Poll::Ready(Ok(event)); 
} match Pin::new(&mut w.delay).poll(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(()) => { let ifs = get_if_addrs()?; - let addrs = ifs.into_iter().map(|iface| match iface.addr { - IfAddr::V4(ip4) => { - let prefix_len = (!u32::from_be_bytes(ip4.netmask.octets())).leading_zeros(); - let ipnet = Ipv4Net::new(ip4.ip, prefix_len as u8) - .expect("prefix_len can not exceed 32"); - IpNet::V4(ipnet) - } - IfAddr::V6(ip6) => { - let prefix_len = (!u128::from_be_bytes(ip6.netmask.octets())).leading_zeros(); - let ipnet = Ipv6Net::new(ip6.ip, prefix_len as u8) - .expect("prefix_len can not exceed 128"); - IpNet::V6(ipnet) - } - }).collect::>(); + let addrs = ifs + .into_iter() + .map(|iface| match iface.addr { + IfAddr::V4(ip4) => { + let prefix_len = + (!u32::from_be_bytes(ip4.netmask.octets())).leading_zeros(); + let ipnet = Ipv4Net::new(ip4.ip, prefix_len as u8) + .expect("prefix_len can not exceed 32"); + IpNet::V4(ipnet) + } + IfAddr::V6(ip6) => { + let prefix_len = + (!u128::from_be_bytes(ip6.netmask.octets())).leading_zeros(); + let ipnet = Ipv6Net::new(ip6.ip, prefix_len as u8) + .expect("prefix_len can not exceed 128"); + IpNet::V6(ipnet) + } + }) + .collect::>(); for down in w.addrs.difference(&addrs) { w.pending.push(IfEvent::Down(*down)); @@ -138,15 +148,27 @@ impl Into for TcpStream { } impl AsyncRead for TcpStream { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { let mut read_buf = tokio_crate::io::ReadBuf::new(buf); - futures::ready!(tokio_crate::io::AsyncRead::poll_read(Pin::new(&mut self.0), cx, &mut read_buf))?; + futures::ready!(tokio_crate::io::AsyncRead::poll_read( + Pin::new(&mut self.0), + cx, + &mut read_buf + ))?; Poll::Ready(Ok(read_buf.filled().len())) } } impl AsyncWrite for TcpStream { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { tokio_crate::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf) } @@ -161,7 +183,7 @@ impl AsyncWrite for TcpStream { fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>] + bufs: &[io::IoSlice<'_>], ) -> Poll> { tokio_crate::io::AsyncWrite::poll_write_vectored(Pin::new(&mut self.0), cx, bufs) } diff --git a/transports/uds/src/lib.rs b/transports/uds/src/lib.rs index 67da6c5fd85..34ac4eb51c3 100644 --- a/transports/uds/src/lib.rs +++ b/transports/uds/src/lib.rs @@ -34,12 +34,15 @@ #![cfg(all(unix, not(target_os = "emscripten")))] #![cfg_attr(docsrs, doc(cfg(all(unix, not(target_os = "emscripten")))))] -use futures::{prelude::*, future::{BoxFuture, Ready}}; use futures::stream::BoxStream; +use futures::{ + future::{BoxFuture, Ready}, + prelude::*, +}; use libp2p_core::{ + multiaddr::{Multiaddr, Protocol}, + transport::{ListenerEvent, TransportError}, Transport, - multiaddr::{Protocol, Multiaddr}, - transport::{ListenerEvent, TransportError} }; use log::debug; use std::{io, path::PathBuf}; @@ -145,14 +148,14 @@ fn multiaddr_to_path(addr: &Multiaddr) -> Result { Some(Protocol::Unix(ref path)) => { let path = PathBuf::from(path.as_ref()); if !path.is_absolute() { - return Err(()) + return Err(()); } match protocols.next() { None | Some(Protocol::P2p(_)) => Ok(path), - Some(_) => Err(()) + Some(_) => Err(()), } } - _ => Err(()) + _ => Err(()), } } @@ -160,15 +163,17 @@ fn multiaddr_to_path(addr: &Multiaddr) -> Result { mod 
tests { use super::{multiaddr_to_path, UdsConfig}; use futures::{channel::oneshot, prelude::*}; + use libp2p_core::{ + multiaddr::{Multiaddr, Protocol}, + Transport, + }; use std::{self, borrow::Cow, path::Path}; - use libp2p_core::{Transport, multiaddr::{Protocol, Multiaddr}}; use tempfile; #[test] fn multiaddr_to_path_conversion() { assert!( - multiaddr_to_path(&"/ip4/127.0.0.1/udp/1234".parse::().unwrap()) - .is_err() + multiaddr_to_path(&"/ip4/127.0.0.1/udp/1234".parse::().unwrap()).is_err() ); assert_eq!( @@ -185,21 +190,27 @@ mod tests { fn communicating_between_dialer_and_listener() { let temp_dir = tempfile::tempdir().unwrap(); let socket = temp_dir.path().join("socket"); - let addr = Multiaddr::from(Protocol::Unix(Cow::Owned(socket.to_string_lossy().into_owned()))); + let addr = Multiaddr::from(Protocol::Unix(Cow::Owned( + socket.to_string_lossy().into_owned(), + ))); let (tx, rx) = oneshot::channel(); async_std::task::spawn(async move { let mut listener = UdsConfig::new().listen_on(addr).unwrap(); - let listen_addr = listener.try_next().await.unwrap() + let listen_addr = listener + .try_next() + .await + .unwrap() .expect("some event") .into_new_address() .expect("listen address"); tx.send(listen_addr).unwrap(); - let (sock, _addr) = listener.try_filter_map(|e| future::ok(e.into_upgrade())) + let (sock, _addr) = listener + .try_filter_map(|e| future::ok(e.into_upgrade())) .try_next() .await .unwrap() @@ -220,18 +231,16 @@ mod tests { } #[test] - #[ignore] // TODO: for the moment unix addresses fail to parse + #[ignore] // TODO: for the moment unix addresses fail to parse fn larger_addr_denied() { let uds = UdsConfig::new(); - let addr = "/unix//foo/bar" - .parse::() - .unwrap(); + let addr = "/unix//foo/bar".parse::().unwrap(); assert!(uds.listen_on(addr).is_err()); } #[test] - #[ignore] // TODO: for the moment unix addresses fail to parse + #[ignore] // TODO: for the moment unix addresses fail to parse fn relative_addr_denied() { assert!("/unix/./foo/bar".parse::().is_err()); } diff --git a/transports/wasm-ext/src/lib.rs b/transports/wasm-ext/src/lib.rs index cec2ad1c1b9..27aafdb70c3 100644 --- a/transports/wasm-ext/src/lib.rs +++ b/transports/wasm-ext/src/lib.rs @@ -32,11 +32,11 @@ //! module. //! -use futures::{prelude::*, future::Ready}; +use futures::{future::Ready, prelude::*}; use libp2p_core::{transport::ListenerEvent, transport::TransportError, Multiaddr, Transport}; use parity_send_wrapper::SendWrapper; use std::{collections::VecDeque, error, fmt, io, mem, pin::Pin, task::Context, task::Poll}; -use wasm_bindgen::{JsCast, prelude::*}; +use wasm_bindgen::{prelude::*, JsCast}; use wasm_bindgen_futures::JsFuture; /// Contains the definition that one must match on the JavaScript side. 
@@ -172,16 +172,13 @@ impl Transport for ExtTransport { type Dial = Dial; fn listen_on(self, addr: Multiaddr) -> Result> { - let iter = self - .inner - .listen_on(&addr.to_string()) - .map_err(|err| { - if is_not_supported_error(&err) { - TransportError::MultiaddrNotSupported(addr) - } else { - TransportError::Other(JsErr::from(err)) - } - })?; + let iter = self.inner.listen_on(&addr.to_string()).map_err(|err| { + if is_not_supported_error(&err) { + TransportError::MultiaddrNotSupported(addr) + } else { + TransportError::Other(JsErr::from(err)) + } + })?; Ok(Listen { iterator: SendWrapper::new(iter), @@ -191,16 +188,13 @@ impl Transport for ExtTransport { } fn dial(self, addr: Multiaddr) -> Result> { - let promise = self - .inner - .dial(&addr.to_string()) - .map_err(|err| { - if is_not_supported_error(&err) { - TransportError::MultiaddrNotSupported(addr) - } else { - TransportError::Other(JsErr::from(err)) - } - })?; + let promise = self.inner.dial(&addr.to_string()).map_err(|err| { + if is_not_supported_error(&err) { + TransportError::MultiaddrNotSupported(addr) + } else { + TransportError::Other(JsErr::from(err)) + } + })?; Ok(Dial { inner: SendWrapper::new(promise.into()), @@ -315,7 +309,9 @@ impl Stream for Listen { .flat_map(|e| e.to_vec().into_iter()) { match js_value_to_addr(&addr) { - Ok(addr) => self.pending_events.push_back(ListenerEvent::NewAddress(addr)), + Ok(addr) => self + .pending_events + .push_back(ListenerEvent::NewAddress(addr)), Err(err) => self.pending_events.push_back(ListenerEvent::Error(err)), } } @@ -375,10 +371,16 @@ impl fmt::Debug for Connection { } impl AsyncRead for Connection { - fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { loop { match mem::replace(&mut self.read_state, ConnectionReadState::Finished) { - ConnectionReadState::Finished => break Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())), + ConnectionReadState::Finished => { + break Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } ConnectionReadState::PendingData(ref data) if data.is_empty() => { let iter_next = self.read_iterator.next().map_err(JsErr::from)?; @@ -411,7 +413,9 @@ impl AsyncRead for Connection { let data = match Future::poll(Pin::new(&mut *promise), cx) { Poll::Ready(Ok(ref data)) if data.is_null() => break Poll::Ready(Ok(0)), Poll::Ready(Ok(data)) => data, - Poll::Ready(Err(err)) => break Poll::Ready(Err(io::Error::from(JsErr::from(err)))), + Poll::Ready(Err(err)) => { + break Poll::Ready(Err(io::Error::from(JsErr::from(err)))) + } Poll::Pending => { self.read_state = ConnectionReadState::Waiting(promise); break Poll::Pending; @@ -439,14 +443,20 @@ impl AsyncRead for Connection { } impl AsyncWrite for Connection { - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { // Note: as explained in the doc-comments of `Connection`, each call to this function must // map to exactly one call to `self.inner.write()`. 
if let Some(mut promise) = self.previous_write_promise.take() { match Future::poll(Pin::new(&mut *promise), cx) { Poll::Ready(Ok(_)) => (), - Poll::Ready(Err(err)) => return Poll::Ready(Err(io::Error::from(JsErr::from(err)))), + Poll::Ready(Err(err)) => { + return Poll::Ready(Err(io::Error::from(JsErr::from(err)))) + } Poll::Pending => { self.previous_write_promise = Some(promise); return Poll::Pending; diff --git a/transports/websocket/src/error.rs b/transports/websocket/src/error.rs index 65a5d8350c0..47421d4c069 100644 --- a/transports/websocket/src/error.rs +++ b/transports/websocket/src/error.rs @@ -18,8 +18,8 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use libp2p_core::Multiaddr; use crate::tls; +use libp2p_core::Multiaddr; use std::{error, fmt}; /// Error in WebSockets. @@ -38,7 +38,7 @@ pub enum Error { /// The location header URL was invalid. InvalidRedirectLocation, /// Websocket base framing error. - Base(Box) + Base(Box), } impl fmt::Display for Error { @@ -50,7 +50,7 @@ impl fmt::Display for Error { Error::InvalidMultiaddr(ma) => write!(f, "invalid multi-address: {}", ma), Error::TooManyRedirects => f.write_str("too many redirects"), Error::InvalidRedirectLocation => f.write_str("invalid redirect location"), - Error::Base(err) => write!(f, "{}", err) + Error::Base(err) => write!(f, "{}", err), } } } @@ -64,7 +64,7 @@ impl error::Error for Error { Error::Base(err) => Some(&**err), Error::InvalidMultiaddr(_) | Error::TooManyRedirects - | Error::InvalidRedirectLocation => None + | Error::InvalidRedirectLocation => None, } } } diff --git a/transports/websocket/src/framed.rs b/transports/websocket/src/framed.rs index 204eddd836f..dc57cb8e220 100644 --- a/transports/websocket/src/framed.rs +++ b/transports/websocket/src/framed.rs @@ -18,15 +18,15 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. 
-use futures_rustls::{webpki, client, server}; use crate::{error::Error, tls}; use either::Either; use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream}; +use futures_rustls::{client, server, webpki}; use libp2p_core::{ - Transport, either::EitherOutput, - multiaddr::{Protocol, Multiaddr}, - transport::{ListenerEvent, TransportError} + multiaddr::{Multiaddr, Protocol}, + transport::{ListenerEvent, TransportError}, + Transport, }; use log::{debug, trace}; use soketto::{connection, extension::deflate::Deflate, handshake}; @@ -45,7 +45,7 @@ pub struct WsConfig { max_data_size: usize, tls_config: tls::Config, max_redirects: u8, - use_deflate: bool + use_deflate: bool, } impl WsConfig { @@ -56,7 +56,7 @@ impl WsConfig { max_data_size: MAX_DATA_SIZE, tls_config: tls::Config::client(), max_redirects: 0, - use_deflate: false + use_deflate: false, } } @@ -104,11 +104,12 @@ where T::Dial: Send + 'static, T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, - T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static + T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Output = Connection; type Error = Error; - type Listener = BoxStream<'static, Result, Self::Error>>; + type Listener = + BoxStream<'static, Result, Self::Error>>; type ListenerUpgrade = BoxFuture<'static, Result>; type Dial = BoxFuture<'static, Result>; @@ -116,24 +117,28 @@ where let mut inner_addr = addr.clone(); let (use_tls, proto) = match inner_addr.pop() { - Some(p@Protocol::Wss(_)) => + Some(p @ Protocol::Wss(_)) => { if self.tls_config.server.is_some() { (true, p) } else { debug!("/wss address but TLS server support is not configured"); - return Err(TransportError::MultiaddrNotSupported(addr)) + return Err(TransportError::MultiaddrNotSupported(addr)); } - Some(p@Protocol::Ws(_)) => (false, p), + } + Some(p @ Protocol::Ws(_)) => (false, p), _ => { debug!("{} is not a websocket multiaddr", addr); - return Err(TransportError::MultiaddrNotSupported(addr)) + return Err(TransportError::MultiaddrNotSupported(addr)); } }; let tls_config = self.tls_config; let max_size = self.max_data_size; let use_deflate = self.use_deflate; - let transport = self.transport.listen_on(inner_addr).map_err(|e| e.map(Error::Transport))?; + let transport = self + .transport + .listen_on(inner_addr) + .map_err(|e| e.map(Error::Transport))?; let listen = transport .map_err(Error::Transport) .map_ok(move |event| match event { @@ -146,10 +151,12 @@ where a = a.with(proto.clone()); ListenerEvent::AddressExpired(a) } - ListenerEvent::Error(err) => { - ListenerEvent::Error(Error::Transport(err)) - } - ListenerEvent::Upgrade { upgrade, mut local_addr, mut remote_addr } => { + ListenerEvent::Error(err) => ListenerEvent::Error(Error::Transport(err)), + ListenerEvent::Upgrade { + upgrade, + mut local_addr, + mut remote_addr, + } => { local_addr = local_addr.with(proto.clone()); remote_addr = remote_addr.with(proto.clone()); let remote1 = remote_addr.clone(); // used for logging @@ -160,28 +167,30 @@ where let stream = upgrade.map_err(Error::Transport).await?; trace!("incoming connection from {}", remote1); - let stream = - if use_tls { // begin TLS session - let server = tls_config - .server - .expect("for use_tls we checked server is not none"); + let stream = if use_tls { + // begin TLS session + let server = tls_config + .server + .expect("for use_tls we checked server is not none"); - trace!("awaiting TLS handshake with {}", remote1); + trace!("awaiting TLS handshake with {}", remote1); - let stream = 
server.accept(stream) - .map_err(move |e| { - debug!("TLS handshake with {} failed: {}", remote1, e); - Error::Tls(tls::Error::from(e)) - }) - .await?; + let stream = server + .accept(stream) + .map_err(move |e| { + debug!("TLS handshake with {} failed: {}", remote1, e); + Error::Tls(tls::Error::from(e)) + }) + .await?; - let stream: TlsOrPlain<_> = - EitherOutput::First(EitherOutput::Second(stream)); + let stream: TlsOrPlain<_> = + EitherOutput::First(EitherOutput::Second(stream)); - stream - } else { // continue with plain stream - EitherOutput::Second(stream) - }; + stream + } else { + // continue with plain stream + EitherOutput::Second(stream) + }; trace!("receiving websocket handshake request from {}", remote2); @@ -192,7 +201,8 @@ where } let ws_key = { - let request = server.receive_request() + let request = server + .receive_request() .map_err(|e| Error::Handshake(Box::new(e))) .await?; request.into_key() @@ -200,13 +210,13 @@ where trace!("accepting websocket handshake request from {}", remote2); - let response = - handshake::server::Response::Accept { - key: &ws_key, - protocol: None - }; + let response = handshake::server::Response::Accept { + key: &ws_key, + protocol: None, + }; - server.send_response(&response) + server + .send_response(&response) .map_err(|e| Error::Handshake(Box::new(e))) .await?; @@ -223,7 +233,7 @@ where ListenerEvent::Upgrade { upgrade: Box::pin(upgrade) as BoxFuture<'static, _>, local_addr, - remote_addr + remote_addr, } } }); @@ -233,7 +243,9 @@ where fn dial(self, addr: Multiaddr) -> Result> { let addr = match parse_ws_dial_addr(addr) { Ok(addr) => addr, - Err(Error::InvalidMultiaddr(a)) => return Err(TransportError::MultiaddrNotSupported(a)), + Err(Error::InvalidMultiaddr(a)) => { + return Err(TransportError::MultiaddrNotSupported(a)) + } Err(e) => return Err(TransportError::Other(e)), }; @@ -247,13 +259,13 @@ where Ok(Either::Left(redirect)) => { if remaining_redirects == 0 { debug!("Too many redirects (> {})", self.max_redirects); - return Err(Error::TooManyRedirects) + return Err(Error::TooManyRedirects); } remaining_redirects -= 1; addr = parse_ws_dial_addr(location_to_multiaddr(&redirect)?)? } Ok(Either::Right(conn)) => return Ok(conn), - Err(e) => return Err(e) + Err(e) => return Err(e), } } }; @@ -269,37 +281,45 @@ where impl WsConfig where T: Transport, - T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static + T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static, { /// Attempts to dial the given address and perform a websocket handshake. 
- async fn dial_once(self, addr: WsAddress) -> Result>, Error> { + async fn dial_once( + self, + addr: WsAddress, + ) -> Result>, Error> { trace!("Dialing websocket address: {:?}", addr); - let dial = self.transport.dial(addr.tcp_addr) - .map_err(|e| match e { - TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a), - TransportError::Other(e) => Error::Transport(e) - })?; + let dial = self.transport.dial(addr.tcp_addr).map_err(|e| match e { + TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a), + TransportError::Other(e) => Error::Transport(e), + })?; let stream = dial.map_err(Error::Transport).await?; trace!("TCP connection to {} established.", addr.host_port); - let stream = - if addr.use_tls { // begin TLS session - let dns_name = addr.dns_name.expect("for use_tls we have checked that dns_name is some"); - trace!("Starting TLS handshake with {:?}", dns_name); - let stream = self.tls_config.client.connect(dns_name.as_ref(), stream) - .map_err(|e| { - debug!("TLS handshake with {:?} failed: {}", dns_name, e); - Error::Tls(tls::Error::from(e)) - }) - .await?; - - let stream: TlsOrPlain<_> = EitherOutput::First(EitherOutput::First(stream)); - stream - } else { // continue with plain stream - EitherOutput::Second(stream) - }; + let stream = if addr.use_tls { + // begin TLS session + let dns_name = addr + .dns_name + .expect("for use_tls we have checked that dns_name is some"); + trace!("Starting TLS handshake with {:?}", dns_name); + let stream = self + .tls_config + .client + .connect(dns_name.as_ref(), stream) + .map_err(|e| { + debug!("TLS handshake with {:?} failed: {}", dns_name, e); + Error::Tls(tls::Error::from(e)) + }) + .await?; + + let stream: TlsOrPlain<_> = EitherOutput::First(EitherOutput::First(stream)); + stream + } else { + // continue with plain stream + EitherOutput::Second(stream) + }; trace!("Sending websocket handshake to {}", addr.host_port); @@ -309,9 +329,19 @@ where client.add_extension(Box::new(Deflate::new(connection::Mode::Client))); } - match client.handshake().map_err(|e| Error::Handshake(Box::new(e))).await? { - handshake::ServerResponse::Redirect { status_code, location } => { - debug!("received redirect ({}); location: {}", status_code, location); + match client + .handshake() + .map_err(|e| Error::Handshake(Box::new(e))) + .await? 
+ { + handshake::ServerResponse::Redirect { + status_code, + location, + } => { + debug!( + "received redirect ({}); location: {}", + status_code, location + ); Ok(Either::Left(location)) } handshake::ServerResponse::Rejected { status_code } => { @@ -349,20 +379,26 @@ fn parse_ws_dial_addr(addr: Multiaddr) -> Result> { let mut tcp = protocols.next(); let (host_port, dns_name) = loop { match (ip, tcp) { - (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) - => break (format!("{}:{}", ip, port), None), - (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) - => break (format!("{}:{}", ip, port), None), - (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port))) | - (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port))) | - (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) | - (Some(Protocol::Dnsaddr(h)), Some(Protocol::Tcp(port))) - => break (format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?.to_owned())), + (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => { + break (format!("{}:{}", ip, port), None) + } + (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => { + break (format!("{}:{}", ip, port), None) + } + (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port))) + | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port))) + | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) + | (Some(Protocol::Dnsaddr(h)), Some(Protocol::Tcp(port))) => { + break ( + format!("{}:{}", &h, port), + Some(tls::dns_name_ref(&h)?.to_owned()), + ) + } (Some(_), Some(p)) => { ip = Some(p); tcp = protocols.next(); } - _ => return Err(Error::InvalidMultiaddr(addr)) + _ => return Err(Error::InvalidMultiaddr(addr)), } }; @@ -373,16 +409,16 @@ fn parse_ws_dial_addr(addr: Multiaddr) -> Result> { let mut p2p = None; let (use_tls, path) = loop { match protocols.pop() { - p@Some(Protocol::P2p(_)) => { p2p = p } + p @ Some(Protocol::P2p(_)) => p2p = p, Some(Protocol::Ws(path)) => break (false, path.into_owned()), Some(Protocol::Wss(path)) => { if dns_name.is_none() { debug!("Missing DNS name in WSS address: {}", addr); - return Err(Error::InvalidMultiaddr(addr)) + return Err(Error::InvalidMultiaddr(addr)); } - break (true, path.into_owned()) + break (true, path.into_owned()); } - _ => return Err(Error::InvalidMultiaddr(addr)) + _ => return Err(Error::InvalidMultiaddr(addr)), } }; @@ -390,7 +426,7 @@ fn parse_ws_dial_addr(addr: Multiaddr) -> Result> { // makes up the the address for the inner TCP-based transport. 
let tcp_addr = match p2p { Some(p) => protocols.with(p), - None => protocols + None => protocols, }; Ok(WsAddress { @@ -408,16 +444,10 @@ fn location_to_multiaddr(location: &str) -> Result> { Ok(url) => { let mut a = Multiaddr::empty(); match url.host() { - Some(url::Host::Domain(h)) => { - a.push(Protocol::Dns(h.into())) - } - Some(url::Host::Ipv4(ip)) => { - a.push(Protocol::Ip4(ip)) - } - Some(url::Host::Ipv6(ip)) => { - a.push(Protocol::Ip6(ip)) - } - None => return Err(Error::InvalidRedirectLocation) + Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())), + Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)), + Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)), + None => return Err(Error::InvalidRedirectLocation), } if let Some(p) = url.port() { a.push(Protocol::Tcp(p)) @@ -429,7 +459,7 @@ fn location_to_multiaddr(location: &str) -> Result> { a.push(Protocol::Ws(url.path().into())) } else { debug!("unsupported scheme: {}", s); - return Err(Error::InvalidRedirectLocation) + return Err(Error::InvalidRedirectLocation); } Ok(a) } @@ -444,7 +474,7 @@ fn location_to_multiaddr(location: &str) -> Result> { pub struct Connection { receiver: BoxStream<'static, Result>, sender: Pin + Send>>, - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } /// Data received over the websocket connection. @@ -455,7 +485,7 @@ pub enum IncomingData { /// UTF-8 encoded application data. Text(Vec), /// PONG control frame data. - Pong(Vec) + Pong(Vec), } impl IncomingData { @@ -464,22 +494,34 @@ impl IncomingData { } pub fn is_binary(&self) -> bool { - if let IncomingData::Binary(_) = self { true } else { false } + if let IncomingData::Binary(_) = self { + true + } else { + false + } } pub fn is_text(&self) -> bool { - if let IncomingData::Text(_) = self { true } else { false } + if let IncomingData::Text(_) = self { + true + } else { + false + } } pub fn is_pong(&self) -> bool { - if let IncomingData::Pong(_) = self { true } else { false } + if let IncomingData::Pong(_) = self { + true + } else { + false + } } pub fn into_bytes(self) -> Vec { match self { IncomingData::Binary(d) => d, IncomingData::Text(d) => d, - IncomingData::Pong(d) => d + IncomingData::Pong(d) => d, } } } @@ -489,7 +531,7 @@ impl AsRef<[u8]> for IncomingData { match self { IncomingData::Binary(d) => d, IncomingData::Text(d) => d, - IncomingData::Pong(d) => d + IncomingData::Pong(d) => d, } } } @@ -503,7 +545,7 @@ pub enum OutgoingData { Ping(Vec), /// Send an unsolicited PONG message. /// (Incoming PINGs are answered automatically.) - Pong(Vec) + Pong(Vec), } impl fmt::Debug for Connection { @@ -514,7 +556,7 @@ impl fmt::Debug for Connection { impl Connection where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { fn new(builder: connection::Builder>) -> Self { let (sender, receiver) = builder.finish(); @@ -536,29 +578,31 @@ where sender.send_pong(data).await? } quicksink::Action::Flush => sender.flush().await?, - quicksink::Action::Close => sender.close().await? 
+ quicksink::Action::Close => sender.close().await?, } Ok(sender) }); let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async { match receiver.receive(&mut data).await { - Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => { - Some((Ok(IncomingData::Text(mem::take(&mut data))), (data, receiver))) - } - Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => { - Some((Ok(IncomingData::Binary(mem::take(&mut data))), (data, receiver))) - } + Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => Some(( + Ok(IncomingData::Text(mem::take(&mut data))), + (data, receiver), + )), + Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => Some(( + Ok(IncomingData::Binary(mem::take(&mut data))), + (data, receiver), + )), Ok(soketto::Incoming::Pong(pong)) => { Some((Ok(IncomingData::Pong(Vec::from(pong))), (data, receiver))) } Err(connection::Error::Closed) => None, - Err(e) => Some((Err(e), (data, receiver))) + Err(e) => Some((Err(e), (data, receiver))), } }); Connection { receiver: stream.boxed(), sender: Box::pin(sink), - _marker: std::marker::PhantomData + _marker: std::marker::PhantomData, } } @@ -580,22 +624,20 @@ where impl Stream for Connection where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Item = io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let item = ready!(self.receiver.poll_next_unpin(cx)); - let item = item.map(|result| { - result.map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - }); + let item = item.map(|result| result.map_err(|e| io::Error::new(io::ErrorKind::Other, e))); Poll::Ready(item) } } impl Sink for Connection where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Error = io::Error; diff --git a/transports/websocket/src/lib.rs b/transports/websocket/src/lib.rs index 4473ed65d73..387aee7c72c 100644 --- a/transports/websocket/src/lib.rs +++ b/transports/websocket/src/lib.rs @@ -26,20 +26,26 @@ pub mod tls; use error::Error; use framed::Connection; -use futures::{future::BoxFuture, prelude::*, stream::BoxStream, ready}; +use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream}; use libp2p_core::{ - ConnectedPoint, - Transport, multiaddr::Multiaddr, - transport::{map::{MapFuture, MapStream}, ListenerEvent, TransportError} + transport::{ + map::{MapFuture, MapStream}, + ListenerEvent, TransportError, + }, + ConnectedPoint, Transport, }; use rw_stream_sink::RwStreamSink; -use std::{io, pin::Pin, task::{Context, Poll}}; +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; /// A Websocket transport. 
#[derive(Debug, Clone)] pub struct WsConfig { - transport: framed::WsConfig + transport: framed::WsConfig, } impl WsConfig { @@ -92,9 +98,7 @@ impl WsConfig { impl From> for WsConfig { fn from(framed: framed::WsConfig) -> Self { - WsConfig { - transport: framed - } + WsConfig { transport: framed } } } @@ -105,7 +109,7 @@ where T::Dial: Send + 'static, T::Listener: Send + 'static, T::ListenerUpgrade: Send + 'static, - T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static + T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Output = RwStreamSink>; type Error = Error; @@ -114,11 +118,15 @@ where type Dial = MapFuture, WrapperFn>; fn listen_on(self, addr: Multiaddr) -> Result> { - self.transport.map(wrap_connection as WrapperFn).listen_on(addr) + self.transport + .map(wrap_connection as WrapperFn) + .listen_on(addr) } fn dial(self, addr: Multiaddr) -> Result> { - self.transport.map(wrap_connection as WrapperFn).dial(addr) + self.transport + .map(wrap_connection as WrapperFn) + .dial(addr) } fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { @@ -127,7 +135,8 @@ where } /// Type alias corresponding to `framed::WsConfig::Listener`. -pub type InnerStream = BoxStream<'static, Result, Error>, Error>>; +pub type InnerStream = + BoxStream<'static, Result, Error>, Error>>; /// Type alias corresponding to `framed::WsConfig::Dial` and `framed::WsConfig::ListenerUpgrade`. pub type InnerFuture = BoxFuture<'static, Result, Error>>; @@ -139,7 +148,7 @@ pub type WrapperFn = fn(Connection, ConnectedPoint) -> RwStreamSink(c: Connection, _: ConnectedPoint) -> RwStreamSink> where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { RwStreamSink::new(BytesConnection(c)) } @@ -150,7 +159,7 @@ pub struct BytesConnection(Connection); impl Stream for BytesConnection where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Item = io::Result>; @@ -158,10 +167,10 @@ where loop { if let Some(item) = ready!(self.0.try_poll_next_unpin(cx)?) 
{ if item.is_data() { - return Poll::Ready(Some(Ok(item.into_bytes()))) + return Poll::Ready(Some(Ok(item.into_bytes()))); } } else { - return Poll::Ready(None) + return Poll::Ready(None); } } } @@ -169,7 +178,7 @@ where impl Sink> for BytesConnection where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, { type Error = io::Error; @@ -194,10 +203,10 @@ where #[cfg(test)] mod tests { - use libp2p_core::{Multiaddr, PeerId, Transport, multiaddr::Protocol}; - use libp2p_tcp as tcp; - use futures::prelude::*; use super::WsConfig; + use futures::prelude::*; + use libp2p_core::{multiaddr::Protocol, Multiaddr, PeerId, Transport}; + use libp2p_tcp as tcp; #[test] fn dialer_connects_to_listener_ipv4() { @@ -214,11 +223,11 @@ mod tests { async fn connect(listen_addr: Multiaddr) { let ws_config = WsConfig::new(tcp::TcpConfig::new()); - let mut listener = ws_config.clone() - .listen_on(listen_addr) - .expect("listener"); + let mut listener = ws_config.clone().listen_on(listen_addr).expect("listener"); - let addr = listener.try_next().await + let addr = listener + .try_next() + .await .expect("some event") .expect("no error") .into_new_address() @@ -228,7 +237,8 @@ mod tests { assert_ne!(Some(Protocol::Tcp(0)), addr.iter().nth(1)); let inbound = async move { - let (conn, _addr) = listener.try_filter_map(|e| future::ready(Ok(e.into_upgrade()))) + let (conn, _addr) = listener + .try_filter_map(|e| future::ready(Ok(e.into_upgrade()))) .try_next() .await .unwrap() @@ -236,7 +246,9 @@ mod tests { conn.await }; - let outbound = ws_config.dial(addr.with(Protocol::P2p(PeerId::random().into()))).unwrap(); + let outbound = ws_config + .dial(addr.with(Protocol::P2p(PeerId::random().into()))) + .unwrap(); let (a, b) = futures::join!(inbound, outbound); a.and(b).unwrap(); diff --git a/transports/websocket/src/tls.rs b/transports/websocket/src/tls.rs index d72535cdcc3..5aab39fe59b 100644 --- a/transports/websocket/src/tls.rs +++ b/transports/websocket/src/tls.rs @@ -18,14 +18,14 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use futures_rustls::{rustls, webpki, TlsConnector, TlsAcceptor}; +use futures_rustls::{rustls, webpki, TlsAcceptor, TlsConnector}; use std::{fmt, io, sync::Arc}; /// TLS configuration. #[derive(Clone)] pub struct Config { pub(crate) client: TlsConnector, - pub(crate) server: Option + pub(crate) server: Option, } impl fmt::Debug for Config { @@ -60,7 +60,7 @@ impl Config { /// Create a new TLS configuration with the given server key and certificate chain. pub fn new(key: PrivateKey, certs: I) -> Result where - I: IntoIterator + I: IntoIterator, { let mut builder = Config::builder(); builder.server(key, certs)?; @@ -71,45 +71,55 @@ impl Config { pub fn client() -> Self { Config { client: Arc::new(client_config()).into(), - server: None + server: None, } } /// Create a new TLS configuration builder. pub fn builder() -> Builder { - Builder { client: client_config(), server: None } + Builder { + client: client_config(), + server: None, + } } } /// Setup the rustls client configuration. fn client_config() -> rustls::ClientConfig { let mut client = rustls::ClientConfig::new(); - client.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + client + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); client } /// TLS configuration builder. 
 pub struct Builder {
     client: rustls::ClientConfig,
-    server: Option<rustls::ServerConfig>
+    server: Option<rustls::ServerConfig>,
 }
 impl Builder {
     /// Set server key and certificate chain.
     pub fn server<I>(&mut self, key: PrivateKey, certs: I) -> Result<&mut Self, Error>
     where
-        I: IntoIterator<Item = Certificate>
+        I: IntoIterator<Item = Certificate>,
     {
         let mut server = rustls::ServerConfig::new(rustls::NoClientAuth::new());
         let certs = certs.into_iter().map(|c| c.0).collect();
-        server.set_single_cert(certs, key.0).map_err(|e| Error::Tls(Box::new(e)))?;
+        server
+            .set_single_cert(certs, key.0)
+            .map_err(|e| Error::Tls(Box::new(e)))?;
         self.server = Some(server);
         Ok(self)
     }
     /// Add an additional trust anchor.
     pub fn add_trust(&mut self, cert: &Certificate) -> Result<&mut Self, Error> {
-        self.client.root_store.add(&cert.0).map_err(|e| Error::Tls(Box::new(e)))?;
+        self.client
+            .root_store
+            .add(&cert.0)
+            .map_err(|e| Error::Tls(Box::new(e)))?;
         Ok(self)
     }
@@ -117,7 +127,7 @@ impl Builder {
     pub fn finish(self) -> Config {
         Config {
             client: Arc::new(self.client).into(),
-            server: self.server.map(|s| Arc::new(s).into())
+            server: self.server.map(|s| Arc::new(s).into()),
         }
     }
 }
@@ -155,7 +165,7 @@ impl std::error::Error for Error {
         match self {
             Error::Io(e) => Some(e),
             Error::Tls(e) => Some(&**e),
-            Error::InvalidDnsName(_) => None
+            Error::InvalidDnsName(_) => None,
         }
     }
 }

From 62e88f888959541842e5af0d86f9789048d5be1f Mon Sep 17 00:00:00 2001
From: Max Inden
Date: Sat, 20 Nov 2021 12:39:01 +0100
Subject: [PATCH 21/23] misc/multistream-select/src/protocol.rs: Fix typo

---
 misc/multistream-select/src/protocol.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/misc/multistream-select/src/protocol.rs b/misc/multistream-select/src/protocol.rs
index 920d78919f2..b7f1611c27f 100644
--- a/misc/multistream-select/src/protocol.rs
+++ b/misc/multistream-select/src/protocol.rs
@@ -303,7 +303,7 @@ impl<R> MessageIO<R> {
         }
     }
-    /// Draops the [`MessageIO`] resource, yielding the underlying I/O stream.
+    /// Drops the [`MessageIO`] resource, yielding the underlying I/O stream.
     ///
     /// # Panics
     ///

From 624c2de6ad2fce91e0d86cade181f3e7917159ef Mon Sep 17 00:00:00 2001
From: Max Inden
Date: Sat, 20 Nov 2021 12:55:04 +0100
Subject: [PATCH 22/23] misc/multistream-select: Suppress needless collect warning

---
 misc/multistream-select/src/dialer_select.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs
index 5e8b14de63a..decd266348d 100644
--- a/misc/multistream-select/src/dialer_select.rs
+++ b/misc/multistream-select/src/dialer_select.rs
@@ -333,6 +333,7 @@ where
                         *this.state = SeqState::SendProtocol { io, protocol };
                     }
                     Role::Responder => {
+                        #[allow(clippy::needless_collect)]
                         let protocols: Vec<_> = this.protocols.collect();
                         *this.state = SeqState::Responder {
                             responder: crate::listener_select::listener_select_proto_no_header(

From 29eb0970805931534f59503ce880ae2df442b3d6 Mon Sep 17 00:00:00 2001
From: Max Inden
Date: Sun, 21 Nov 2021 13:21:06 +0100
Subject: [PATCH 23/23] misc/multistream-select: Fix version usage in tests

---
 core/src/upgrade/apply.rs                  |  9 +++++++++
 misc/multistream-select/tests/transport.rs | 20 +++++++++++---------
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs
index b63450468ae..cd8e81b8279 100644
--- a/core/src/upgrade/apply.rs
+++ b/core/src/upgrade/apply.rs
@@ -323,6 +323,15 @@ impl From<Version> for multistream_select::Version {
     }
 }
+impl From<Version> for AuthenticationVersion {
+    fn from(v: Version) -> Self {
+        match v {
+            Version::V1 => AuthenticationVersion::V1,
+            Version::V1Lazy => AuthenticationVersion::V1Lazy,
+        }
+    }
+}
+
 /// Applies an authentication upgrade to the inbound or outbound direction of a connection.
 ///
 /// Note: This is like [`apply`] with additional support for transports allowing simultaneously
diff --git a/misc/multistream-select/tests/transport.rs b/misc/multistream-select/tests/transport.rs
index 53f079a7ac1..d5c47354d4c 100644
--- a/misc/multistream-select/tests/transport.rs
+++ b/misc/multistream-select/tests/transport.rs
@@ -39,18 +39,20 @@ use std::{
 type TestTransport = transport::Boxed<(PeerId, StreamMuxerBox)>;
 type TestNetwork = Network<TestTransport>;
-// TODO: Fix _up
-fn mk_transport(_up: upgrade::Version) -> (PeerId, TestTransport) {
+fn mk_transport(version: upgrade::Version) -> (PeerId, TestTransport) {
     let keys = identity::Keypair::generate_ed25519();
     let id = keys.public().to_peer_id();
     (
         id,
         MemoryTransport::default()
             .upgrade()
-            .authenticate(PlainText2Config {
-                local_public_key: keys.public(),
-            })
-            .multiplex(MplexConfig::default())
+            .authenticate_with_version(
+                PlainText2Config {
+                    local_public_key: keys.public(),
+                },
+                version.into(),
+            )
+            .multiplex_with_version(MplexConfig::default(), version)
             .boxed(),
     )
 }
@@ -61,9 +63,9 @@ fn mk_transport(_up: upgrade::Version) -> (PeerId, TestTransport) {
 fn transport_upgrade() {
     let _ = env_logger::try_init();
-    fn run(up: upgrade::Version) {
-        let (dialer_id, dialer_transport) = mk_transport(up);
-        let (listener_id, listener_transport) = mk_transport(up);
+    fn run(version: upgrade::Version) {
+        let (dialer_id, dialer_transport) = mk_transport(version);
+        let (listener_id, listener_transport) = mk_transport(version);
         let listen_addr = Multiaddr::from(Protocol::Memory(random::<u64>()));