From f674458ff4f767f36651c3791a17addd8c3b0598 Mon Sep 17 00:00:00 2001 From: Phoenix Kahlo Date: Wed, 14 Feb 2024 22:31:32 -0600 Subject: [PATCH] (todo refactor commits) address feedback --- quinn-proto/src/connection/mod.rs | 12 ++++- quinn-proto/src/endpoint.rs | 29 +++++------- quinn-proto/src/tests/mod.rs | 23 +++++++-- quinn-proto/src/tests/util.rs | 79 ++++++++++++++++++++++++------- quinn/src/connection.rs | 5 +- quinn/src/endpoint.rs | 29 ++++++------ quinn/src/incoming_connection.rs | 26 +++------- quinn/src/lib.rs | 2 +- quinn/src/tests.rs | 6 ++- 9 files changed, 134 insertions(+), 77 deletions(-) diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 95eb49ae96..eeba0eed59 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -2203,7 +2203,10 @@ impl Connection { } ConnectionError::VersionMismatch => State::Draining, ConnectionError::LocallyClosed => { - unreachable!("LocallyClosed isn't generated by packet processing") + unreachable!("LocallyClosed isn't generated by packet processing"); + } + ConnectionError::ConnectionLimitExceeded => { + unreachable!("ConnectionLimitExceeded isn't generated by packet processing"); } }; } @@ -3501,6 +3504,9 @@ pub enum ConnectionError { /// The local application closed the connection #[error("closed")] LocallyClosed, + /// The connection could not be created without exceeding the endpoint's connection limit + #[error("connection limit exceeded")] + ConnectionLimitExceeded, } impl From for ConnectionError { @@ -3520,7 +3526,9 @@ impl From for io::Error { TimedOut => io::ErrorKind::TimedOut, Reset => io::ErrorKind::ConnectionReset, ApplicationClosed(_) | ConnectionClosed(_) => io::ErrorKind::ConnectionAborted, - TransportError(_) | VersionMismatch | LocallyClosed => io::ErrorKind::Other, + TransportError(_) | VersionMismatch | LocallyClosed | ConnectionLimitExceeded => { + io::ErrorKind::Other + } }; Self::new(kind, x) } diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 2e531fd00c..8ef98957a2 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -506,7 +506,7 @@ impl Endpoint { incoming: IncomingConnection, now: Instant, buf: &mut BytesMut, - ) -> Result<(ConnectionHandle, Connection), Option> { + ) -> Result<(ConnectionHandle, Connection), (ConnectionError, Option)> { self.check_connection_limit( incoming.version, incoming.addresses, @@ -514,7 +514,7 @@ impl Endpoint { &incoming.src_cid, buf, ) - .map_err(Some)?; + .map_err(|response| (ConnectionError::ConnectionLimitExceeded, Some(response)))?; let server_config = self.server_config.as_ref().unwrap().clone(); @@ -567,17 +567,18 @@ impl Endpoint { Err(e) => { debug!("handshake failed: {}", e); self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained)); - match e { - ConnectionError::TransportError(e) => Err(Some(self.initial_close( + let response = match e { + ConnectionError::TransportError(ref e) => Some(self.initial_close( incoming.version, incoming.addresses, &incoming.crypto, &incoming.src_cid, - e, + e.clone(), buf, - ))), - _ => Err(None), - } + )), + _ => None, + }; + Err((e, response)) } } } @@ -1050,9 +1051,11 @@ pub enum ConnectError { UnsupportedVersion, } -/// Error for attempting to retry an [`IncomingConnection`] that can not be retried +/// Error for attempting to retry an [`IncomingConnection`] which already bears an address +/// validation token from a previous retry #[derive(Debug, Error)] -pub struct RetryError(pub IncomingConnection); +#[error("retry() with validated IncomingConnection")] +pub struct RetryError(IncomingConnection); impl RetryError { /// Get the [`IncomingConnection`] @@ -1061,12 +1064,6 @@ impl RetryError { } } -impl fmt::Display for RetryError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str("retry() with validated IncomingConnection") - } -} - /// Reset Tokens which are associated with peer socket addresses /// /// The standard `HashMap` is used since both `SocketAddr` and `ResetToken` are diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index cdb3434eea..0a8cc4c5f6 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -166,7 +166,7 @@ fn draft_version_compat() { fn stateless_retry() { let _guard = subscribe(); let mut pair = Pair::default(); - pair.server.use_retry = true; + pair.server.retry_policy = RetryPolicy::yes(); pair.connect(); } @@ -455,7 +455,7 @@ fn high_latency_handshake() { fn zero_rtt_happypath() { let _guard = subscribe(); let mut pair = Pair::default(); - pair.server.use_retry = true; + pair.server.retry_policy = RetryPolicy::yes(); let config = client_config(); // Establish normal connection @@ -1980,7 +1980,7 @@ fn connect_too_low_mtu() { pair.begin_connect(client_config()); pair.drive(); - pair.server.assert_no_accept() + pair.server.assert_no_accept(); } #[test] @@ -2750,3 +2750,20 @@ fn reject_new_connections() { pair.server.assert_no_accept(); assert!(pair.client.connections.get(&client_ch).unwrap().is_closed()); } + +#[test] +fn reject_manually() { + let _guard = subscribe(); + let mut pair = Pair::default(); + pair.server.retry_policy = RetryPolicy(Box::new(|_| IncomingConnectionResponse::Reject)); + + // The server should now reject incoming connections. + let client_ch = pair.begin_connect(client_config()); + pair.drive(); + let e = pair.server.assert_accept_error(); + assert!( + matches!(e, crate::ConnectionError::ConnectionClosed(_)), + "wrong error" + ); + assert!(pair.client.connections.get(&client_ch).unwrap().is_closed()); +} diff --git a/quinn-proto/src/tests/util.rs b/quinn-proto/src/tests/util.rs index 67929b0c94..7ead7f67dc 100644 --- a/quinn-proto/src/tests/util.rs +++ b/quinn-proto/src/tests/util.rs @@ -287,12 +287,39 @@ pub(super) struct TestEndpoint { pub(super) outbound: VecDeque<(Transmit, Bytes)>, delayed: VecDeque<(Transmit, Bytes)>, pub(super) inbound: VecDeque<(Instant, Option, BytesMut)>, - accepted: Option, + accepted: Option>, pub(super) connections: HashMap, conn_events: HashMap>, pub(super) captured_packets: Vec>, pub(super) capture_inbound_packets: bool, - pub(super) use_retry: bool, + pub(super) retry_policy: RetryPolicy, +} + +pub(super) struct RetryPolicy( + pub(super) Box IncomingConnectionResponse>, +); + +impl RetryPolicy { + pub(super) fn no() -> Self { + Self(Box::new(|_| IncomingConnectionResponse::Accept)) + } + + pub(super) fn yes() -> Self { + Self(Box::new(|incoming| { + if incoming.remote_address_validated() { + IncomingConnectionResponse::Accept + } else { + IncomingConnectionResponse::Retry + } + })) + } +} + +#[derive(Debug, Copy, Clone)] +pub(super) enum IncomingConnectionResponse { + Accept, + Reject, + Retry, } impl TestEndpoint { @@ -319,7 +346,7 @@ impl TestEndpoint { conn_events: HashMap::default(), captured_packets: Vec::new(), capture_inbound_packets: false, - use_retry: false, + retry_policy: RetryPolicy::no(), } } @@ -343,10 +370,16 @@ impl TestEndpoint { { match event { DatagramEvent::NewConnection(incoming) => { - if self.use_retry && !incoming.remote_address_validated() { - self.retry(incoming); - } else { - self.try_accept(incoming, now); + match (self.retry_policy.0)(&incoming) { + IncomingConnectionResponse::Accept => { + let _ = self.try_accept(incoming, now); + } + IncomingConnectionResponse::Reject => { + self.reject(incoming); + } + IncomingConnectionResponse::Retry => { + self.retry(incoming); + } } } DatagramEvent::ConnectionEvent(ch, event) => { @@ -427,23 +460,23 @@ impl TestEndpoint { &mut self, incoming: IncomingConnection, now: Instant, - ) -> Option { + ) -> Result { let mut buf = BytesMut::new(); - match self.endpoint.accept(incoming, now, &mut buf) { - Ok((ch, conn)) => { + self.endpoint + .accept(incoming, now, &mut buf) + .map(|(ch, conn)| { self.connections.insert(ch, conn); - self.accepted = Some(ch); - Some(ch) - } - Err(transmit) => { + self.accepted = Some(Ok(ch)); + ch + }) + .map_err(|(e, transmit)| { if let Some(transmit) = transmit { let size = transmit.size; self.outbound .extend(split_transmit(transmit, buf.split_to(size).freeze())); } - None - } - } + e + }) } pub(super) fn retry(&mut self, incoming: IncomingConnection) { @@ -463,7 +496,17 @@ impl TestEndpoint { } pub(super) fn assert_accept(&mut self) -> ConnectionHandle { - self.accepted.take().expect("server didn't connect") + self.accepted + .take() + .expect("server didn't try connecting") + .expect("server experienced error connecting") + } + + pub(super) fn assert_accept_error(&mut self) -> ConnectionError { + self.accepted + .take() + .expect("server didn't try connecting") + .expect_err("server did unexpectedly connect without error") } pub(super) fn assert_no_accept(&self) { diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index 5bed610cb5..a3c2076f1a 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -28,7 +28,6 @@ use proto::congestion::Controller; /// In-progress connection attempt future #[derive(Debug)] -#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"] pub struct Connecting { conn: Option, connected: oneshot::Receiver, @@ -152,6 +151,8 @@ impl Connecting { /// /// On all non-supported platforms the local IP address will not be available, /// and the method will return `None`. + /// + /// Will panic if called after `poll` has returned `Ready`. pub fn local_ip(&self) -> Option { let conn = self.conn.as_ref().unwrap(); let inner = conn.state.lock("local_ip"); @@ -159,7 +160,7 @@ impl Connecting { inner.inner.local_ip() } - /// The peer's UDP address. + /// The peer's UDP address /// /// Will panic if called after `poll` has returned `Ready`. pub fn remote_address(&self) -> SocketAddr { diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 564eb2524d..8906c154a4 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -16,7 +16,8 @@ use crate::runtime::{default_runtime, AsyncUdpSocket, Runtime}; use bytes::{Bytes, BytesMut}; use pin_project_lite::pin_project; use proto::{ - self as proto, ClientConfig, ConnectError, ConnectionHandle, DatagramEvent, ServerConfig, + self as proto, ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, + ServerConfig, }; use rustc_hash::FxHashMap; use tokio::sync::{futures::Notified, mpsc, Notify}; @@ -137,9 +138,11 @@ impl Endpoint { /// Get the next incoming connection attempt from a client /// - /// Yields [`IncomingConnection`]s, which can be `await`ed to obtain the final - /// [`Connection`](crate::Connection) or used in more complex ways such as to perform retries, - /// or `None` if the endpoint is [`close`](Self::close)d. + /// Yields [`IncomingConnection`]s, or `None` if the endpoint is [`close`](Self::close)d. + /// [`IncomingConnection`] can be `await`ed to obtain the final + /// [`Connection`](crate::Connection), or used to eg. filter connection attempts or force + /// address validation, or converted into an intermediate `Connecting` future which can be + /// used to eg. send 0.5-RTT data. pub fn accept(&self) -> Accept<'_> { Accept { endpoint: self, @@ -770,24 +773,22 @@ impl EndpointInner { &self, incoming: proto::IncomingConnection, mut response_buffer: BytesMut, - ) -> Option { + ) -> Result { let mut state = self.state.lock().unwrap(); - match state + state .inner .accept(incoming, Instant::now(), &mut response_buffer) - { - Ok((handle, conn)) => { + .map(|(handle, conn)| { let socket = state.socket.clone(); let runtime = state.runtime.clone(); - Some(state.connections.insert(handle, conn, socket, runtime)) - } - Err(response) => { + state.connections.insert(handle, conn, socket, runtime) + }) + .map_err(|(e, response)| { if let Some(transmit) = response { state.transmit_state.respond(transmit, response_buffer); } - None - } - } + e + }) } pub(crate) fn reject( diff --git a/quinn/src/incoming_connection.rs b/quinn/src/incoming_connection.rs index 503594c9d7..247cf75194 100644 --- a/quinn/src/incoming_connection.rs +++ b/quinn/src/incoming_connection.rs @@ -16,7 +16,6 @@ use crate::{ }; /// An incoming connection for which the server has not yet begun its part of the handshake -#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"] pub struct IncomingConnection(Option); struct State { @@ -60,16 +59,7 @@ impl IncomingConnection { /// Attempt to accept this incoming connection (an error may still occur) pub fn accept(mut self) -> Result { let state = self.0.take().unwrap(); - state - .endpoint - .accept(state.inner, state.response_buffer) - .ok_or_else(|| { - ConnectionError::TransportError(proto::TransportError { - code: proto::TransportErrorCode::PROTOCOL_VIOLATION, - frame: None, - reason: "Problem with initial packet".to_owned(), - }) - }) + state.endpoint.accept(state.inner, state.response_buffer) } /// Reject this incoming connection attempt @@ -88,7 +78,7 @@ impl IncomingConnection { .retry(state.inner, state.response_buffer) .map_err(|(e, response_buffer)| { RetryError(Self(Some(State { - inner: e.0, + inner: e.into_incoming(), endpoint: state.endpoint, response_buffer, }))) @@ -121,9 +111,11 @@ impl fmt::Debug for IncomingConnection { } } -/// Error for attempting to retry an [`IncomingConnection`] that can not be retried +/// Error for attempting to retry an [`IncomingConnection`] which already bears an address +/// validation token from a previous retry #[derive(Debug, Error)] -pub struct RetryError(pub IncomingConnection); +#[error("retry() with validated IncomingConnection")] +pub struct RetryError(IncomingConnection); impl RetryError { /// Get the [`IncomingConnection`] @@ -132,12 +124,6 @@ impl RetryError { } } -impl fmt::Display for RetryError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str("retry() with validated IncomingConnection") - } -} - /// Basic adapter to let [`IncomingConnection`] be `await`-ed like a [`Connecting`] #[derive(Debug)] pub struct IncomingConnectionFuture(Result); diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index 1658051f15..dbf2dabc41 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -65,7 +65,7 @@ use bytes::Bytes; pub use proto::{ congestion, crypto, AckFrequencyConfig, ApplicationClose, Chunk, ClientConfig, ConfigError, ConnectError, ConnectionClose, ConnectionError, EndpointConfig, IdleTimeout, - MtuDiscoveryConfig, ServerConfig, StreamId, Transmit, TransportConfig, VarInt, + MtuDiscoveryConfig, ServerConfig, StreamId, Transmit, TransportConfig, TransportError, VarInt, }; #[cfg(feature = "tls-rustls")] pub use rustls; diff --git a/quinn/src/tests.rs b/quinn/src/tests.rs index b75fe8ec95..97a4e0058e 100755 --- a/quinn/src/tests.rs +++ b/quinn/src/tests.rs @@ -217,11 +217,15 @@ async fn ip_blocking() { }); tokio::join!( async move { - client_1 + let e = client_1 .connect(server_addr, "localhost") .unwrap() .await .expect_err("server should have blocked this"); + assert!( + matches!(e, crate::ConnectionError::ConnectionClosed(_)), + "wrong error" + ); }, async move { client_2