diff --git a/src/socket/icmp.rs b/src/socket/icmp.rs index 2a301275a..eb42d130a 100644 --- a/src/socket/icmp.rs +++ b/src/socket/icmp.rs @@ -61,12 +61,14 @@ impl std::error::Error for SendError {} #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum RecvError { Exhausted, + Truncated(IpAddress), } impl core::fmt::Display for RecvError { fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { match self { RecvError::Exhausted => write!(f, "exhausted"), + RecvError::Truncated(_) => write!(f, "truncated"), } } } @@ -130,8 +132,8 @@ impl<'a> Socket<'a> { /// Create an ICMP socket with the given buffers. pub fn new(rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>) -> Socket<'a> { Socket { - rx_buffer: rx_buffer, - tx_buffer: tx_buffer, + rx_buffer, + tx_buffer, endpoint: Default::default(), hop_limit: None, #[cfg(feature = "async")] @@ -394,12 +396,20 @@ impl<'a> Socket<'a> { /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice, /// and return the amount of octets copied as well as the `IpAddress` /// + /// The payload is copied partially when the size of the given slice is smaller than the size + /// of the payload. In this case, a `RecvError::Truncated` error is returned. + /// /// See also [recv](#method.recv). pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpAddress), RecvError> { let (buffer, endpoint) = self.recv()?; let length = cmp::min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); - Ok((length, endpoint)) + + if data.len() < buffer.len() { + Err(RecvError::Truncated(endpoint)) + } else { + Ok((length, endpoint)) + } } /// Filter determining which packets received by the interface are appended to @@ -555,7 +565,7 @@ impl<'a> Socket<'a> { dst_addr, next_header: IpProtocol::Icmp, payload_len: repr.buffer_len(), - hop_limit: hop_limit, + hop_limit, }); emit(cx, (ip_repr, IcmpRepr::Ipv4(repr))) } @@ -592,7 +602,7 @@ impl<'a> Socket<'a> { dst_addr, next_header: IpProtocol::Icmpv6, payload_len: repr.buffer_len(), - hop_limit: hop_limit, + hop_limit, }); emit(cx, (ip_repr, IcmpRepr::Ipv6(repr))) } @@ -1096,6 +1106,41 @@ mod test_ipv6 { assert!(!socket.can_recv()); } + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_truncated_recv_slice(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(1)); + assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(())); + + let checksum = ChecksumCapabilities::default(); + + let mut bytes = [0xff; 24]; + let mut packet = Icmpv6Packet::new_unchecked(&mut bytes[..]); + ECHOV6_REPR.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); + let data = &*packet.into_inner(); + + assert!(socket.accepts(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into())); + socket.process(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()); + assert!(socket.can_recv()); + + assert!(socket.accepts(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into())); + socket.process(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()); + + let mut buffer = [0u8; 1]; + assert_eq!(socket.recv_slice(&mut buffer[..]), Err(RecvError::Truncated(REMOTE_IPV6.into()))); + assert_eq!(buffer[0], data[0]); + assert!(!socket.can_recv()); + } + #[rstest] #[case::ethernet(Medium::Ethernet)] #[cfg(feature = "medium-ethernet")] diff --git a/src/socket/raw.rs b/src/socket/raw.rs index 4f85d3227..b625202ff 100644 --- a/src/socket/raw.rs +++ b/src/socket/raw.rs @@ -57,12 +57,14 @@ impl std::error::Error for SendError {} #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum RecvError { Exhausted, + Truncated, } impl core::fmt::Display for RecvError { fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { match self { RecvError::Exhausted => write!(f, "exhausted"), + RecvError::Truncated => write!(f, "truncated"), } } } @@ -273,12 +275,20 @@ impl<'a> Socket<'a> { /// Dequeue a packet, and copy the payload into the given slice. /// + /// The payload is copied partially when the size of the given slice is smaller than the size + /// of the payload. In this case, a `RecvError::Truncated` error is returned. + /// /// See also [recv](#method.recv). pub fn recv_slice(&mut self, data: &mut [u8]) -> Result { let buffer = self.recv()?; let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); - Ok(length) + + if data.len() < buffer.len() { + Err(RecvError::Truncated) + } else { + Ok(length) + } } /// Peek at a packet in the receive buffer and return a pointer to the @@ -308,7 +318,12 @@ impl<'a> Socket<'a> { let buffer = self.peek()?; let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); - Ok(length) + + if data.len() < buffer.len() { + Err(RecvError::Truncated) + } else { + Ok(length) + } } pub(crate) fn accepts(&self, ip_repr: &IpRepr) -> bool { @@ -602,7 +617,7 @@ mod test { socket.process(&mut cx, &$hdr, &$payload); let mut slice = [0; 4]; - assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4)); + assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated)); assert_eq!(&slice, &$packet[..slice.len()]); } @@ -641,9 +656,9 @@ mod test { socket.process(&mut cx, &$hdr, &$payload); let mut slice = [0; 4]; - assert_eq!(socket.peek_slice(&mut slice[..]), Ok(4)); + assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Truncated)); assert_eq!(&slice, &$packet[..slice.len()]); - assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4)); + assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated)); assert_eq!(&slice, &$packet[..slice.len()]); assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted)); } diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index 8a40f8261..8f41196f6 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -1331,7 +1331,7 @@ impl<'a> Socket<'a> { // Rate-limit to 1 per second max. self.challenge_ack_timer = cx.now() + Duration::from_secs(1); - return Some(self.ack_reply(ip_repr, repr)); + Some(self.ack_reply(ip_repr, repr)) } pub(crate) fn accepts(&self, _cx: &mut Context, ip_repr: &IpRepr, repr: &TcpRepr) -> bool { diff --git a/src/socket/udp.rs b/src/socket/udp.rs index 5cc1da928..1962ecccf 100644 --- a/src/socket/udp.rs +++ b/src/socket/udp.rs @@ -88,12 +88,14 @@ impl std::error::Error for SendError {} #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum RecvError { Exhausted, + Truncated(UdpMetadata), } impl core::fmt::Display for RecvError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { RecvError::Exhausted => write!(f, "exhausted"), + RecvError::Truncated(_) => write!(f, "truncated"), } } } @@ -393,12 +395,20 @@ impl<'a> Socket<'a> { /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice, /// and return the amount of octets copied as well as the endpoint. /// + /// The payload is copied partially when the size of the given slice is smaller than the size + /// of the payload. In this case, a `RecvError::Truncated` error is returned. + /// /// See also [recv](#method.recv). pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> { let (buffer, endpoint) = self.recv().map_err(|_| RecvError::Exhausted)?; let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); - Ok((length, endpoint)) + + if data.len() < buffer.len() { + Err(RecvError::Truncated(endpoint)) + } else { + Ok((length, endpoint)) + } } /// Peek at a packet received from a remote endpoint, and return the endpoint as well @@ -426,12 +436,20 @@ impl<'a> Socket<'a> { /// packet from the receive buffer. /// This function otherwise behaves identically to [recv_slice](#method.recv_slice). /// + /// The payload is copied partially when the size of the given slice is smaller than the size + /// of the payload. In this case, a `RecvError::Truncated` error is returned. + /// /// See also [peek](#method.peek). pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<(usize, &UdpMetadata), RecvError> { let (buffer, endpoint) = self.peek()?; let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); - Ok((length, endpoint)) + + if data.len() < buffer.len() { + Err(RecvError::Truncated(*endpoint)) + } else { + Ok((length, endpoint)) + } } pub(crate) fn accepts(&self, cx: &mut Context, ip_repr: &IpRepr, repr: &UdpRepr) -> bool { @@ -853,7 +871,7 @@ mod test { let mut slice = [0; 4]; assert_eq!( socket.recv_slice(&mut slice[..]), - Ok((4, REMOTE_END.into())) + Err(RecvError::Truncated(REMOTE_END.into())) ); assert_eq!(&slice, b"abcd"); } @@ -884,12 +902,12 @@ mod test { let mut slice = [0; 4]; assert_eq!( socket.peek_slice(&mut slice[..]), - Ok((4, &REMOTE_END.into())) + Err(RecvError::Truncated(REMOTE_END.into())) ); assert_eq!(&slice, b"abcd"); assert_eq!( socket.recv_slice(&mut slice[..]), - Ok((4, REMOTE_END.into())) + Err(RecvError::Truncated(REMOTE_END.into())) ); assert_eq!(&slice, b"abcd"); assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted));