diff --git a/src/socket/icmp.rs b/src/socket/icmp.rs index 2a301275a..0c4516070 100644 --- a/src/socket/icmp.rs +++ b/src/socket/icmp.rs @@ -61,12 +61,14 @@ impl std::error::Error for SendError {} #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum RecvError { Exhausted, + Truncated, } impl core::fmt::Display for RecvError { fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { match self { RecvError::Exhausted => write!(f, "exhausted"), + RecvError::Truncated => write!(f, "truncated"), } } } @@ -130,8 +132,8 @@ impl<'a> Socket<'a> { /// Create an ICMP socket with the given buffers. pub fn new(rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>) -> Socket<'a> { Socket { - rx_buffer: rx_buffer, - tx_buffer: tx_buffer, + rx_buffer, + tx_buffer, endpoint: Default::default(), hop_limit: None, #[cfg(feature = "async")] @@ -394,9 +396,17 @@ impl<'a> Socket<'a> { /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice, /// and return the amount of octets copied as well as the `IpAddress` /// + /// **Note**: when the size of the provided buffer is smaller than the size of the payload, + /// the packet is dropped and a `RecvError::Truncated` error is returned. + /// /// See also [recv](#method.recv). pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpAddress), RecvError> { let (buffer, endpoint) = self.recv()?; + + if data.len() < buffer.len() { + return Err(RecvError::Truncated); + } + let length = cmp::min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); Ok((length, endpoint)) @@ -555,7 +565,7 @@ impl<'a> Socket<'a> { dst_addr, next_header: IpProtocol::Icmp, payload_len: repr.buffer_len(), - hop_limit: hop_limit, + hop_limit, }); emit(cx, (ip_repr, IcmpRepr::Ipv4(repr))) } @@ -592,7 +602,7 @@ impl<'a> Socket<'a> { dst_addr, next_header: IpProtocol::Icmpv6, payload_len: repr.buffer_len(), - hop_limit: hop_limit, + hop_limit, }); emit(cx, (ip_repr, IcmpRepr::Ipv6(repr))) } @@ -1096,6 +1106,42 @@ mod test_ipv6 { assert!(!socket.can_recv()); } + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_truncated_recv_slice(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(1)); + assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(())); + + let checksum = ChecksumCapabilities::default(); + + let mut bytes = [0xff; 24]; + let mut packet = Icmpv6Packet::new_unchecked(&mut bytes[..]); + ECHOV6_REPR.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); + + assert!(socket.accepts(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into())); + socket.process(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()); + assert!(socket.can_recv()); + + assert!(socket.accepts(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into())); + socket.process(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()); + + let mut buffer = [0u8; 1]; + assert_eq!( + socket.recv_slice(&mut buffer[..]), + Err(RecvError::Truncated) + ); + assert!(!socket.can_recv()); + } + #[rstest] #[case::ethernet(Medium::Ethernet)] #[cfg(feature = "medium-ethernet")] diff --git a/src/socket/raw.rs b/src/socket/raw.rs index 4f85d3227..bb3a204ad 100644 --- a/src/socket/raw.rs +++ b/src/socket/raw.rs @@ -57,12 +57,14 @@ impl std::error::Error for SendError {} #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum RecvError { Exhausted, + Truncated, } impl core::fmt::Display for RecvError { fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { match self { RecvError::Exhausted => write!(f, "exhausted"), + RecvError::Truncated => write!(f, "truncated"), } } } @@ -273,9 +275,16 @@ impl<'a> Socket<'a> { /// Dequeue a packet, and copy the payload into the given slice. /// + /// **Note**: when the size of the provided buffer is smaller than the size of the payload, + /// the packet is dropped and a `RecvError::Truncated` error is returned. + /// /// See also [recv](#method.recv). pub fn recv_slice(&mut self, data: &mut [u8]) -> Result { let buffer = self.recv()?; + if data.len() < buffer.len() { + return Err(RecvError::Truncated); + } + let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); Ok(length) @@ -303,9 +312,16 @@ impl<'a> Socket<'a> { /// and return the amount of octets copied without removing the packet from the receive buffer. /// This function otherwise behaves identically to [recv_slice](#method.recv_slice). /// + /// **Note**: when the size of the provided buffer is smaller than the size of the payload, + /// no data is copied into the provided buffer and a `RecvError::Truncated` error is returned. + /// /// See also [peek](#method.peek). pub fn peek_slice(&mut self, data: &mut [u8]) -> Result { let buffer = self.peek()?; + if data.len() < buffer.len() { + return Err(RecvError::Truncated); + } + let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); Ok(length) @@ -602,8 +618,7 @@ mod test { socket.process(&mut cx, &$hdr, &$payload); let mut slice = [0; 4]; - assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4)); - assert_eq!(&slice, &$packet[..slice.len()]); + assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated)); } #[rstest] @@ -641,10 +656,8 @@ mod test { socket.process(&mut cx, &$hdr, &$payload); let mut slice = [0; 4]; - assert_eq!(socket.peek_slice(&mut slice[..]), Ok(4)); - assert_eq!(&slice, &$packet[..slice.len()]); - assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4)); - assert_eq!(&slice, &$packet[..slice.len()]); + assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Truncated)); + assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated)); assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted)); } } diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index 8a40f8261..8f41196f6 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -1331,7 +1331,7 @@ impl<'a> Socket<'a> { // Rate-limit to 1 per second max. self.challenge_ack_timer = cx.now() + Duration::from_secs(1); - return Some(self.ack_reply(ip_repr, repr)); + Some(self.ack_reply(ip_repr, repr)) } pub(crate) fn accepts(&self, _cx: &mut Context, ip_repr: &IpRepr, repr: &TcpRepr) -> bool { diff --git a/src/socket/udp.rs b/src/socket/udp.rs index 5cc1da928..9eb2bcf59 100644 --- a/src/socket/udp.rs +++ b/src/socket/udp.rs @@ -88,12 +88,14 @@ impl std::error::Error for SendError {} #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum RecvError { Exhausted, + Truncated, } impl core::fmt::Display for RecvError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { RecvError::Exhausted => write!(f, "exhausted"), + RecvError::Truncated => write!(f, "truncated"), } } } @@ -393,9 +395,17 @@ impl<'a> Socket<'a> { /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice, /// and return the amount of octets copied as well as the endpoint. /// + /// **Note**: when the size of the provided buffer is smaller than the size of the payload, + /// the packet is dropped and a `RecvError::Truncated` error is returned. + /// /// See also [recv](#method.recv). pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> { let (buffer, endpoint) = self.recv().map_err(|_| RecvError::Exhausted)?; + + if data.len() < buffer.len() { + return Err(RecvError::Truncated); + } + let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); Ok((length, endpoint)) @@ -426,9 +436,17 @@ impl<'a> Socket<'a> { /// packet from the receive buffer. /// This function otherwise behaves identically to [recv_slice](#method.recv_slice). /// + /// **Note**: when the size of the provided buffer is smaller than the size of the payload, + /// no data is copied into the provided buffer and a `RecvError::Truncated` error is returned. + /// /// See also [peek](#method.peek). pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<(usize, &UdpMetadata), RecvError> { let (buffer, endpoint) = self.peek()?; + + if data.len() < buffer.len() { + return Err(RecvError::Truncated); + } + let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); Ok((length, endpoint)) @@ -851,11 +869,7 @@ mod test { ); let mut slice = [0; 4]; - assert_eq!( - socket.recv_slice(&mut slice[..]), - Ok((4, REMOTE_END.into())) - ); - assert_eq!(&slice, b"abcd"); + assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated)); } #[rstest] @@ -882,16 +896,8 @@ mod test { ); let mut slice = [0; 4]; - assert_eq!( - socket.peek_slice(&mut slice[..]), - Ok((4, &REMOTE_END.into())) - ); - assert_eq!(&slice, b"abcd"); - assert_eq!( - socket.recv_slice(&mut slice[..]), - Ok((4, REMOTE_END.into())) - ); - assert_eq!(&slice, b"abcd"); + assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Truncated)); + assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated)); assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted)); }