diff --git a/Cargo.lock b/Cargo.lock index 866f8e719..c542522c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -484,6 +484,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +[[package]] +name = "hex-literal" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" + [[package]] name = "home" version = "0.5.5" @@ -1441,6 +1447,7 @@ dependencies = [ "derive_more", "dns-lookup", "etcetera", + "hex-literal", "humantime", "indexmap", "itertools", diff --git a/Cargo.toml b/Cargo.toml index 89c5fda38..23732fe3d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,7 @@ windows-sys = { version = "0.48.0", features = [ ] } [dev-dependencies] +hex-literal = "0.4.1" rand = "0.8.5" test-case = "3.2.1" diff --git a/src/backend.rs b/src/backend.rs index f2042dfcc..74dbe43dc 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -7,7 +7,8 @@ use std::sync::Arc; use std::time::Duration; use tracing::instrument; use trippy::tracing::{ - Probe, ProbeStatus, Tracer, TracerChannel, TracerChannelConfig, TracerConfig, TracerRound, + Probe, ProbeResponseExtensions, ProbeStatus, Tracer, TracerChannel, TracerChannelConfig, + TracerConfig, TracerRound, }; /// The state of all hops in a trace. @@ -114,6 +115,8 @@ impl Trace { } let host = probe.host.unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)); *hop.addrs.entry(host).or_default() += 1; + // TODO should we combine extensions across rounds? + hop.extensions = probe.extensions.clone(); } ProbeStatus::Awaited => { let index = usize::from(probe.ttl.0) - 1; @@ -164,6 +167,7 @@ pub struct Hop { mean: f64, m2: f64, samples: Vec, + extensions: Option, } impl Hop { @@ -186,6 +190,10 @@ impl Hop { self.addrs.len() } + pub fn extensions(&self) -> Option<&ProbeResponseExtensions> { + self.extensions.as_ref() + } + /// The total number of probes sent. pub fn total_sent(&self) -> usize { self.total_sent @@ -259,6 +267,7 @@ impl Default for Hop { mean: 0f64, m2: 0f64, samples: Vec::default(), + extensions: None, } } } diff --git a/src/config.rs b/src/config.rs index fa2f57c3c..6f704ad46 100644 --- a/src/config.rs +++ b/src/config.rs @@ -191,6 +191,7 @@ pub struct TrippyConfig { pub max_inflight: u8, pub initial_sequence: u16, pub tos: u8, + pub icmp_extensions: bool, pub read_timeout: Duration, pub packet_size: u16, pub payload_pattern: u8, @@ -340,6 +341,13 @@ impl TryFrom<(Args, u16)> for TrippyConfig { cfg_file_strategy.tos, constants::DEFAULT_STRATEGY_TOS, ); + + let icmp_extensions = cfg_layer_bool_flag( + args.icmp_extensions, + cfg_file_strategy.icmp_extensions, + false, + ); + let read_timeout = cfg_layer( args.read_timeout, cfg_file_strategy.read_timeout, @@ -510,6 +518,7 @@ impl TryFrom<(Args, u16)> for TrippyConfig { packet_size, payload_pattern, tos, + icmp_extensions, source_addr, interface, port_direction, diff --git a/src/config/cmd.rs b/src/config/cmd.rs index e5e435dcc..4fb4e9d7e 100644 --- a/src/config/cmd.rs +++ b/src/config/cmd.rs @@ -105,6 +105,10 @@ pub struct Args { #[arg(short = 'Q', long)] pub tos: Option, + /// Parse ICMP extensions + #[arg(short = 'e', long)] + pub icmp_extensions: bool, + /// The socket read timeout [default: 10ms] #[arg(long)] pub read_timeout: Option, diff --git a/src/config/file.rs b/src/config/file.rs index eeebd595d..c51e5c6be 100644 --- a/src/config/file.rs +++ b/src/config/file.rs @@ -110,6 +110,7 @@ pub struct ConfigStrategy { pub packet_size: Option, pub payload_pattern: Option, pub tos: Option, + pub icmp_extensions: Option, pub read_timeout: Option, } diff --git a/src/frontend/render/table.rs b/src/frontend/render/table.rs index 2c6683ca1..689bac9ad 100644 --- a/src/frontend/render/table.rs +++ b/src/frontend/render/table.rs @@ -13,6 +13,7 @@ use ratatui::widgets::{Block, BorderType, Borders, Cell, Row, Table}; use ratatui::Frame; use std::net::IpAddr; use std::rc::Rc; +use trippy::tracing::ProbeResponseExtension; /// Render the table of data about the hops. /// @@ -233,12 +234,33 @@ fn format_address( let addr_fmt = match config.address_mode { AddressMode::IP => addr.to_string(), AddressMode::Host => { - if config.lookup_as_info { + let hostname = if config.lookup_as_info { let entry = dns.lazy_reverse_lookup_with_asinfo(*addr); format_dns_entry(entry, true, config.as_mode) } else { let entry = dns.lazy_reverse_lookup(*addr); format_dns_entry(entry, false, config.as_mode) + }; + + // TODO just a hack for now... + if let Some(extensions) = hop.extensions() { + let mpls = extensions + .extensions + .iter() + .map(|ext| match ext { + ProbeResponseExtension::Unknown => todo!(), + ProbeResponseExtension::Mpls(mpls) => mpls + .members + .iter() + .map(|member| { + format!("[MPLS label: {}, ttl: {}]", member.label, member.ttl) + }) + .join("\n"), + }) + .join("\n"); + format!("{hostname} {mpls}") + } else { + hostname } } AddressMode::Both => { diff --git a/src/main.rs b/src/main.rs index fb20954b8..7cb7418d5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -235,6 +235,7 @@ fn make_channel_config( args.payload_pattern, args.multipath_strategy, args.tos, + args.icmp_extensions, args.read_timeout, args.min_round_duration, ) diff --git a/src/tracing.rs b/src/tracing.rs index 0949442b2..fd4370fac 100644 --- a/src/tracing.rs +++ b/src/tracing.rs @@ -16,5 +16,7 @@ pub use config::{ }; pub use net::channel::TracerChannel; pub use net::source::SourceAddr; -pub use probe::{IcmpPacketType, Probe, ProbeStatus}; +pub use probe::{ + IcmpPacketType, Probe, ProbeResponseExtension, ProbeResponseExtensions, ProbeStatus, +}; pub use tracer::{Tracer, TracerRound}; diff --git a/src/tracing/config.rs b/src/tracing/config.rs index c68c67bd7..1f7e566b6 100644 --- a/src/tracing/config.rs +++ b/src/tracing/config.rs @@ -157,6 +157,7 @@ pub struct TracerChannelConfig { pub payload_pattern: PayloadPattern, pub multipath_strategy: MultipathStrategy, pub tos: TypeOfService, + pub icmp_extensions: bool, pub read_timeout: Duration, pub tcp_connect_timeout: Duration, } @@ -173,6 +174,7 @@ impl TracerChannelConfig { payload_pattern: u8, multipath_strategy: MultipathStrategy, tos: u8, + icmp_extensions: bool, read_timeout: Duration, tcp_connect_timeout: Duration, ) -> Self { @@ -185,6 +187,7 @@ impl TracerChannelConfig { payload_pattern: PayloadPattern(payload_pattern), multipath_strategy, tos: TypeOfService(tos), + icmp_extensions, read_timeout, tcp_connect_timeout, } diff --git a/src/tracing/net.rs b/src/tracing/net.rs index 3c047d9e0..d2b627ea0 100644 --- a/src/tracing/net.rs +++ b/src/tracing/net.rs @@ -8,6 +8,9 @@ mod ipv4; /// IPv6 implementation. mod ipv6; +/// ICMP extensions. +mod extension; + /// Platform specific network code. mod platform; diff --git a/src/tracing/net/channel.rs b/src/tracing/net/channel.rs index 035fe0629..3c176d5af 100644 --- a/src/tracing/net/channel.rs +++ b/src/tracing/net/channel.rs @@ -27,6 +27,7 @@ pub struct TracerChannel { payload_pattern: PayloadPattern, multipath_strategy: MultipathStrategy, tos: TypeOfService, + icmp_extensions: bool, read_timeout: Duration, tcp_connect_timeout: Duration, send_socket: Option, @@ -64,6 +65,7 @@ impl TracerChannel { payload_pattern: config.payload_pattern, multipath_strategy: config.multipath_strategy, tos: config.tos, + icmp_extensions: config.icmp_extensions, read_timeout: config.read_timeout, tcp_connect_timeout: config.tcp_connect_timeout, send_socket, @@ -91,7 +93,7 @@ impl Network for TracerChannel { resp => Ok(resp), }, }?; - if let Some(resp) = prob_response { + if let Some(resp) = &prob_response { tracing::debug!(?resp); } Ok(prob_response) @@ -163,10 +165,10 @@ impl TracerChannel { fn dispatch_tcp_probe(&mut self, probe: Probe) -> TraceResult<()> { let socket = match (self.src_addr, self.dest_addr) { (IpAddr::V4(src_addr), IpAddr::V4(dest_addr)) => { - ipv4::dispatch_tcp_probe(probe, src_addr, dest_addr, self.tos) + ipv4::dispatch_tcp_probe(&probe, src_addr, dest_addr, self.tos) } (IpAddr::V6(src_addr), IpAddr::V6(dest_addr)) => { - ipv6::dispatch_tcp_probe(probe, src_addr, dest_addr) + ipv6::dispatch_tcp_probe(&probe, src_addr, dest_addr) } _ => unreachable!(), }?; @@ -180,8 +182,16 @@ impl TracerChannel { fn recv_icmp_probe(&mut self) -> TraceResult> { if self.recv_socket.is_readable(self.read_timeout)? { match self.dest_addr { - IpAddr::V4(_) => ipv4::recv_icmp_probe(&mut self.recv_socket, self.protocol), - IpAddr::V6(_) => ipv6::recv_icmp_probe(&mut self.recv_socket, self.protocol), + IpAddr::V4(_) => ipv4::recv_icmp_probe( + &mut self.recv_socket, + self.protocol, + self.icmp_extensions, + ), + IpAddr::V6(_) => ipv6::recv_icmp_probe( + &mut self.recv_socket, + self.protocol, + self.icmp_extensions, + ), } } else { Ok(None) diff --git a/src/tracing/net/extension.rs b/src/tracing/net/extension.rs new file mode 100644 index 000000000..a5b431cd5 --- /dev/null +++ b/src/tracing/net/extension.rs @@ -0,0 +1,60 @@ +use crate::tracing::error::TracerError; +use crate::tracing::packet::icmp_extension::extension_object::{ClassNum, ExtensionObject}; +use crate::tracing::packet::icmp_extension::extension_structure::ExtensionStructure; +use crate::tracing::packet::icmp_extension::mpls_label_stack::MplsLabelStack; +use crate::tracing::packet::icmp_extension::mpls_label_stack_member::MplsLabelStackMember; +use crate::tracing::probe::{ + MplsExtensionData, MplsExtensionMember, ProbeResponseExtension, ProbeResponseExtensions, +}; +use crate::tracing::util::Required; + +impl TryFrom<&[u8]> for ProbeResponseExtensions { + type Error = TracerError; + + fn try_from(value: &[u8]) -> Result { + Self::try_from(ExtensionStructure::new_view(value).req()?) + } +} + +impl TryFrom> for ProbeResponseExtensions { + type Error = TracerError; + + fn try_from(value: ExtensionStructure<'_>) -> Result { + let extensions = value + .objects() + .flat_map(|obj| ExtensionObject::new_view(obj).req()) + .map(|obj| match obj.get_class_num() { + ClassNum::MultiProtocolLabelSwitchingLabelStack => { + MplsLabelStack::new_view(obj.payload()) + .req() + .map(|mpls| ProbeResponseExtension::Mpls(MplsExtensionData::from(mpls))) + } + _ => Ok(ProbeResponseExtension::Unknown), + }) + .collect::>()?; + Ok(Self { extensions }) + } +} + +impl From> for MplsExtensionData { + fn from(value: MplsLabelStack<'_>) -> Self { + Self { + members: value + .members() + .flat_map(|member| MplsLabelStackMember::new_view(member).req()) + .map(MplsExtensionMember::from) + .collect(), + } + } +} + +impl From> for MplsExtensionMember { + fn from(value: MplsLabelStackMember<'_>) -> Self { + Self { + label: value.get_label(), + exp: value.get_exp(), + bos: value.get_bos(), + ttl: value.get_ttl(), + } + } +} diff --git a/src/tracing/net/ipv4.rs b/src/tracing/net/ipv4.rs index 18f3f6e27..df59a0e58 100644 --- a/src/tracing/net/ipv4.rs +++ b/src/tracing/net/ipv4.rs @@ -15,8 +15,8 @@ use crate::tracing::packet::tcp::TcpPacket; use crate::tracing::packet::udp::UdpPacket; use crate::tracing::packet::IpProtocol; use crate::tracing::probe::{ - ProbeResponse, ProbeResponseData, ProbeResponseSeq, ProbeResponseSeqIcmp, ProbeResponseSeqTcp, - ProbeResponseSeqUdp, + ProbeResponse, ProbeResponseData, ProbeResponseExtensions, ProbeResponseSeq, + ProbeResponseSeqIcmp, ProbeResponseSeqTcp, ProbeResponseSeqUdp, }; use crate::tracing::types::{PacketSize, PayloadPattern, Sequence, TraceId, TypeOfService}; use crate::tracing::util::Required; @@ -147,7 +147,7 @@ fn swap_checksum_and_payload(udp: &mut UdpPacket<'_>) { #[instrument(skip(probe))] pub fn dispatch_tcp_probe( - probe: Probe, + probe: &Probe, src_addr: Ipv4Addr, dest_addr: Ipv4Addr, tos: TypeOfService, @@ -187,12 +187,13 @@ pub fn dispatch_tcp_probe( pub fn recv_icmp_probe( recv_socket: &mut Socket, protocol: TracerProtocol, + icmp_extensions: bool, ) -> TraceResult> { let mut buf = [0_u8; MAX_PACKET_SIZE]; match recv_socket.read(&mut buf) { Ok(bytes_read) => { let ipv4 = Ipv4Packet::new_view(&buf[..bytes_read]).req()?; - Ok(extract_probe_resp(protocol, &ipv4)?) + Ok(extract_probe_resp(protocol, icmp_extensions, &ipv4)?) } Err(err) => match err.kind() { ErrorKind::WouldBlock => Ok(None), @@ -229,11 +230,10 @@ pub fn recv_tcp_socket( } if platform::is_host_unreachable_error(code) { let error_addr = tcp_socket.icmp_error_info()?; - return Ok(Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( - SystemTime::now(), - error_addr, - resp_seq, - )))); + return Ok(Some(ProbeResponse::TimeExceeded( + ProbeResponseData::new(SystemTime::now(), error_addr, resp_seq), + None, + ))); } } } @@ -324,6 +324,7 @@ fn udp_payload_size(packet_size: usize) -> usize { #[instrument] fn extract_probe_resp( protocol: TracerProtocol, + icmp_extensions: bool, ipv4: &Ipv4Packet<'_>, ) -> TraceResult> { let recv = SystemTime::now(); @@ -332,16 +333,36 @@ fn extract_probe_resp( Ok(match icmp_v4.get_icmp_type() { IcmpType::TimeExceeded => { let packet = TimeExceededPacket::new_view(icmp_v4.packet()).req()?; - let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; - Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( - recv, src, resp_seq, - ))) + let payload = packet.payload(); + let extension = if icmp_extensions { + packet + .extension() + .map(ProbeResponseExtensions::try_from) + .transpose()? + } else { + None + }; + let resp_seq = extract_probe_resp_seq(payload, protocol)?; + Some(ProbeResponse::TimeExceeded( + ProbeResponseData::new(recv, src, resp_seq), + extension, + )) } IcmpType::DestinationUnreachable => { let packet = DestinationUnreachablePacket::new_view(icmp_v4.packet()).req()?; - let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; + let payload = packet.payload(); + let extension = if icmp_extensions { + packet + .extension() + .map(ProbeResponseExtensions::try_from) + .transpose()? + } else { + None + }; + let resp_seq = extract_probe_resp_seq(payload, protocol)?; Some(ProbeResponse::DestinationUnreachable( ProbeResponseData::new(recv, src, resp_seq), + extension, )) } IcmpType::EchoReply => match protocol { diff --git a/src/tracing/net/ipv6.rs b/src/tracing/net/ipv6.rs index bd81e5dcf..23217ccc3 100644 --- a/src/tracing/net/ipv6.rs +++ b/src/tracing/net/ipv6.rs @@ -19,7 +19,7 @@ use crate::tracing::probe::{ }; use crate::tracing::types::{PacketSize, PayloadPattern, Sequence, TraceId}; use crate::tracing::util::Required; -use crate::tracing::{Probe, TracerProtocol}; +use crate::tracing::{Probe, ProbeResponseExtensions, TracerProtocol}; use std::io::ErrorKind; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::time::SystemTime; @@ -101,7 +101,7 @@ pub fn dispatch_udp_probe( #[instrument(skip(probe))] pub fn dispatch_tcp_probe( - probe: Probe, + probe: &Probe, src_addr: Ipv6Addr, dest_addr: Ipv6Addr, ) -> TraceResult { @@ -139,6 +139,7 @@ pub fn dispatch_tcp_probe( pub fn recv_icmp_probe( recv_socket: &mut Socket, protocol: TracerProtocol, + icmp_extensions: bool, ) -> TraceResult> { let mut buf = [0_u8; MAX_PACKET_SIZE]; match recv_socket.recv_from(&mut buf) { @@ -148,7 +149,12 @@ pub fn recv_icmp_probe( SocketAddr::V6(addr) => addr.ip(), SocketAddr::V4(_) => panic!(), }; - Ok(extract_probe_resp(protocol, &icmp_v6, *src_addr)?) + Ok(extract_probe_resp( + protocol, + icmp_extensions, + &icmp_v6, + *src_addr, + )?) } Err(err) => match err.kind() { ErrorKind::WouldBlock => Ok(None), @@ -185,11 +191,10 @@ pub fn recv_tcp_socket( } if platform::is_host_unreachable_error(code) { let error_addr = tcp_socket.icmp_error_info()?; - return Ok(Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( - SystemTime::now(), - error_addr, - resp_seq, - )))); + return Ok(Some(ProbeResponse::TimeExceeded( + ProbeResponseData::new(SystemTime::now(), error_addr, resp_seq), + None, + ))); } } } @@ -254,6 +259,7 @@ fn udp_payload_size(packet_size: usize) -> usize { fn extract_probe_resp( protocol: TracerProtocol, + icmp_extensions: bool, icmp_v6: &IcmpPacket<'_>, src: Ipv6Addr, ) -> TraceResult> { @@ -262,16 +268,34 @@ fn extract_probe_resp( Ok(match icmp_v6.get_icmp_type() { IcmpType::TimeExceeded => { let packet = TimeExceededPacket::new_view(icmp_v6.packet()).req()?; + let extension = if icmp_extensions { + packet + .extension() + .map(ProbeResponseExtensions::try_from) + .transpose()? + } else { + None + }; let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; - Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( - recv, ip, resp_seq, - ))) + Some(ProbeResponse::TimeExceeded( + ProbeResponseData::new(recv, ip, resp_seq), + extension, + )) } IcmpType::DestinationUnreachable => { let packet = DestinationUnreachablePacket::new_view(icmp_v6.packet()).req()?; + let extension = if icmp_extensions { + packet + .extension() + .map(ProbeResponseExtensions::try_from) + .transpose()? + } else { + None + }; let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; Some(ProbeResponse::DestinationUnreachable( ProbeResponseData::new(recv, ip, resp_seq), + extension, )) } IcmpType::EchoReply => match protocol { diff --git a/src/tracing/packet.rs b/src/tracing/packet.rs index ab9f56160..29e742a15 100644 --- a/src/tracing/packet.rs +++ b/src/tracing/packet.rs @@ -9,6 +9,9 @@ pub mod icmpv4; /// `ICMPv6` packets. pub mod icmpv6; +/// ICMP extensions +pub mod icmp_extension; + /// `IPv4` packets. pub mod ipv4; @@ -21,7 +24,8 @@ pub mod udp; /// `TCP` packets. pub mod tcp; -fn fmt_payload(bytes: &[u8]) -> String { +#[must_use] +pub fn fmt_payload(bytes: &[u8]) -> String { use itertools::Itertools as _; format!("{:02x}", bytes.iter().format(" ")) } diff --git a/src/tracing/packet/icmp_extension.rs b/src/tracing/packet/icmp_extension.rs new file mode 100644 index 000000000..5a60dd30b --- /dev/null +++ b/src/tracing/packet/icmp_extension.rs @@ -0,0 +1,964 @@ +pub mod extension_structure { + use crate::tracing::packet::buffer::Buffer; + use crate::tracing::packet::icmp_extension::extension_object::ExtensionObject; + + /// Represents an ICMP `ExtensionStructure` pseudo object. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct ExtensionStructure<'a> { + buf: Buffer<'a>, + } + + impl<'a> ExtensionStructure<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + + #[must_use] + pub fn header(&self) -> &[u8] { + &self.buf.as_slice()[..Self::minimum_packet_size()] + } + + /// An iterator of Extension Objects contained within this `ExtensionStructure`. + #[must_use] + pub fn objects(&self) -> ExtensionObjectIter<'_> { + ExtensionObjectIter::new(&self.buf) + } + } + + pub struct ExtensionObjectIter<'a> { + buf: &'a Buffer<'a>, + offset: usize, + } + + impl<'a> ExtensionObjectIter<'a> { + #[must_use] + pub fn new(buf: &'a Buffer<'_>) -> Self { + Self { + buf, + offset: ExtensionStructure::minimum_packet_size(), + } + } + } + + impl<'a> Iterator for ExtensionObjectIter<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + if self.offset >= self.buf.as_slice().len() { + None + } else { + let object_bytes = &self.buf.as_slice()[self.offset..]; + if let Some(object) = ExtensionObject::new_view(object_bytes) { + self.offset += usize::from(object.get_length()); + Some(object_bytes) + } else { + None + } + } + } + } + + #[cfg(test)] + mod tests { + use super::*; + use crate::tracing::packet::icmp_extension::extension_header::ExtensionHeader; + use crate::tracing::packet::icmp_extension::extension_object::{ + ClassNum, ClassSubType, ExtensionObject, + }; + + #[test] + fn test_header() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extensions = ExtensionStructure::new_view(&buf).unwrap(); + let header = ExtensionHeader::new_view(extensions.header()).unwrap(); + assert_eq!(2, header.get_version()); + assert_eq!(0x993A, header.get_checksum()); + } + + #[test] + fn test_object_iterator() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extensions = ExtensionStructure::new_view(&buf).unwrap(); + let mut object_iter = extensions.objects(); + let object_bytes = object_iter.next().unwrap(); + let object = ExtensionObject::new_view(object_bytes).unwrap(); + assert_eq!(8, object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + object.get_class_num() + ); + assert_eq!(ClassSubType(1), object.get_class_subtype()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], object.payload()); + assert!(object_iter.next().is_none()); + } + } +} + +pub mod extension_header { + use crate::tracing::packet::buffer::Buffer; + use std::fmt::{Debug, Formatter}; + + const VERSION_OFFSET: usize = 0; + const CHECKSUM_OFFSET: usize = 2; + + /// Represents an ICMP `ExtensionHeader`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct ExtensionHeader<'a> { + buf: Buffer<'a>, + } + + impl<'a> ExtensionHeader<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn get_version(&self) -> u8 { + (self.buf.read(VERSION_OFFSET) & 0xf0) >> 4 + } + + #[must_use] + pub fn get_checksum(&self) -> u16 { + u16::from_be_bytes(self.buf.get_bytes(CHECKSUM_OFFSET)) + } + + pub fn set_version(&mut self, val: u8) { + *self.buf.write(VERSION_OFFSET) = + (self.buf.read(VERSION_OFFSET) & 0xf) | ((val & 0xf) << 4); + } + + pub fn set_checksum(&mut self, val: u16) { + self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + } + + impl Debug for ExtensionHeader<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtensionHeader") + .field("version", &self.get_version()) + .field("checksum", &self.get_checksum()) + // .field("payload", &fmt_payload(self.payload())) + .finish() + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_version() { + let mut buf = [0_u8; ExtensionHeader::minimum_packet_size()]; + let mut extension = ExtensionHeader::new(&mut buf).unwrap(); + extension.set_version(0); + assert_eq!(0, extension.get_version()); + assert_eq!([0x00], extension.packet()[0..1]); + extension.set_version(2); + assert_eq!(2, extension.get_version()); + assert_eq!([0x20], extension.packet()[0..1]); + extension.set_version(15); + assert_eq!(15, extension.get_version()); + assert_eq!([0xF0], extension.packet()[0..1]); + } + + #[test] + fn test_checksum() { + let mut buf = [0_u8; ExtensionHeader::minimum_packet_size()]; + let mut extension = ExtensionHeader::new(&mut buf).unwrap(); + extension.set_checksum(0); + assert_eq!(0, extension.get_checksum()); + assert_eq!([0x00, 0x00], extension.packet()[2..=3]); + extension.set_checksum(1999); + assert_eq!(1999, extension.get_checksum()); + assert_eq!([0x07, 0xCF], extension.packet()[2..=3]); + extension.set_checksum(39226); + assert_eq!(39226, extension.get_checksum()); + assert_eq!([0x99, 0x3A], extension.packet()[2..=3]); + extension.set_checksum(u16::MAX); + assert_eq!(u16::MAX, extension.get_checksum()); + assert_eq!([0xFF, 0xFF], extension.packet()[2..=3]); + } + + #[test] + fn test_extension_header_view() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extension = ExtensionHeader::new_view(&buf).unwrap(); + assert_eq!(2, extension.get_version()); + assert_eq!(0x993A, extension.get_checksum()); + } + } +} + +pub mod extension_object { + use crate::tracing::packet::buffer::Buffer; + use crate::tracing::packet::fmt_payload; + use std::fmt::{Debug, Formatter}; + + /// The ICMP Extension Object Class Num. + #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] + pub enum ClassNum { + MultiProtocolLabelSwitchingLabelStack, + InterfaceInformationObject, + InterfaceIdentificationObject, + ExtendedInformation, + Other(u8), + } + + impl ClassNum { + #[must_use] + pub fn id(&self) -> u8 { + match self { + Self::MultiProtocolLabelSwitchingLabelStack => 1, + Self::InterfaceInformationObject => 2, + Self::InterfaceIdentificationObject => 3, + Self::ExtendedInformation => 4, + Self::Other(id) => *id, + } + } + } + + impl From for ClassNum { + fn from(val: u8) -> Self { + match val { + 1 => Self::MultiProtocolLabelSwitchingLabelStack, + 2 => Self::InterfaceInformationObject, + 3 => Self::InterfaceIdentificationObject, + 4 => Self::ExtendedInformation, + id => Self::Other(id), + } + } + } + + /// The ICMP Extension Object Class Sub-type. + #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] + pub struct ClassSubType(pub u8); + + impl From for ClassSubType { + fn from(val: u8) -> Self { + Self(val) + } + } + + const LENGTH_OFFSET: usize = 0; + const CLASS_NUM_OFFSET: usize = 2; + const CLASS_SUBTYPE_OFFSET: usize = 3; + + /// Represents an ICMP `ExtensionObject`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct ExtensionObject<'a> { + buf: Buffer<'a>, + } + + impl<'a> ExtensionObject<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + pub fn set_length(&mut self, val: u16) { + self.buf.set_bytes(LENGTH_OFFSET, val.to_be_bytes()); + } + + pub fn set_class_num(&mut self, val: ClassNum) { + *self.buf.write(CLASS_NUM_OFFSET) = val.id(); + } + + pub fn set_class_subtype(&mut self, val: ClassSubType) { + *self.buf.write(CLASS_SUBTYPE_OFFSET) = val.0; + } + + pub fn set_payload(&mut self, vals: &[u8]) { + let current_offset = Self::minimum_packet_size(); + self.buf.as_slice_mut()[current_offset..current_offset + vals.len()] + .copy_from_slice(vals); + } + + #[must_use] + pub fn get_length(&self) -> u16 { + u16::from_be_bytes(self.buf.get_bytes(LENGTH_OFFSET)) + } + + #[must_use] + pub fn get_class_num(&self) -> ClassNum { + ClassNum::from(self.buf.read(CLASS_NUM_OFFSET)) + } + + #[must_use] + pub fn get_class_subtype(&self) -> ClassSubType { + ClassSubType::from(self.buf.read(CLASS_SUBTYPE_OFFSET)) + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + + #[must_use] + pub fn payload(&self) -> &[u8] { + &self.buf.as_slice()[Self::minimum_packet_size()..usize::from(self.get_length())] + } + } + + impl Debug for ExtensionObject<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtensionObject") + .field("length", &self.get_length()) + .field("class_num", &self.get_class_num()) + .field("class_subtype", &self.get_class_subtype()) + .field("payload", &fmt_payload(self.payload())) + .finish() + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_length() { + let mut buf = [0_u8; ExtensionObject::minimum_packet_size()]; + let mut extension = ExtensionObject::new(&mut buf).unwrap(); + extension.set_length(0); + assert_eq!(0, extension.get_length()); + assert_eq!([0x00, 0x00], extension.packet()[0..=1]); + extension.set_length(8); + assert_eq!(8, extension.get_length()); + assert_eq!([0x00, 0x08], extension.packet()[0..=1]); + extension.set_length(u16::MAX); + assert_eq!(u16::MAX, extension.get_length()); + assert_eq!([0xFF, 0xFF], extension.packet()[0..=1]); + } + + #[test] + fn test_class_num() { + let mut buf = [0_u8; ExtensionObject::minimum_packet_size()]; + let mut extension = ExtensionObject::new(&mut buf).unwrap(); + extension.set_class_num(ClassNum::MultiProtocolLabelSwitchingLabelStack); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + extension.get_class_num() + ); + assert_eq!([0x01], extension.packet()[2..3]); + extension.set_class_num(ClassNum::InterfaceInformationObject); + assert_eq!( + ClassNum::InterfaceInformationObject, + extension.get_class_num() + ); + assert_eq!([0x02], extension.packet()[2..3]); + extension.set_class_num(ClassNum::InterfaceIdentificationObject); + assert_eq!( + ClassNum::InterfaceIdentificationObject, + extension.get_class_num() + ); + assert_eq!([0x03], extension.packet()[2..3]); + extension.set_class_num(ClassNum::ExtendedInformation); + assert_eq!(ClassNum::ExtendedInformation, extension.get_class_num()); + assert_eq!([0x04], extension.packet()[2..3]); + extension.set_class_num(ClassNum::Other(255)); + assert_eq!(ClassNum::Other(255), extension.get_class_num()); + assert_eq!([0xFF], extension.packet()[2..3]); + } + + #[test] + fn test_class_subtype() { + let mut buf = [0_u8; ExtensionObject::minimum_packet_size()]; + let mut extension = ExtensionObject::new(&mut buf).unwrap(); + extension.set_class_subtype(ClassSubType(0)); + assert_eq!(ClassSubType(0), extension.get_class_subtype()); + assert_eq!([0x00], extension.packet()[3..4]); + extension.set_class_subtype(ClassSubType(1)); + assert_eq!(ClassSubType(1), extension.get_class_subtype()); + assert_eq!([0x01], extension.packet()[3..4]); + extension.set_class_subtype(ClassSubType(255)); + assert_eq!(ClassSubType(255), extension.get_class_subtype()); + assert_eq!([0xff], extension.packet()[3..4]); + } + + #[test] + fn test_extension_header_view() { + let buf = [0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01]; + let object = ExtensionObject::new_view(&buf).unwrap(); + assert_eq!(8, object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + object.get_class_num() + ); + assert_eq!(ClassSubType(1), object.get_class_subtype()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], object.payload()); + } + } +} + +pub mod mpls_label_stack { + use crate::tracing::packet::buffer::Buffer; + use crate::tracing::packet::icmp_extension::mpls_label_stack_member::MplsLabelStackMember; + + /// Represents an ICMP `MplsLabelStack`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct MplsLabelStack<'a> { + buf: Buffer<'a>, + } + + impl<'a> MplsLabelStack<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + + #[must_use] + pub fn members(&self) -> MplsLabelStackIter<'_> { + MplsLabelStackIter::new(&self.buf) + } + } + + pub struct MplsLabelStackIter<'a> { + buf: &'a Buffer<'a>, + offset: usize, + bos: u8, + } + + impl<'a> MplsLabelStackIter<'a> { + #[must_use] + pub fn new(buf: &'a Buffer<'_>) -> Self { + Self { + buf, + offset: 0, + bos: 0, + } + } + } + + impl<'a> Iterator for MplsLabelStackIter<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + if self.bos > 0 || self.offset >= self.buf.as_slice().len() { + None + } else { + let member_bytes = &self.buf.as_slice()[self.offset..]; + if let Some(member) = MplsLabelStackMember::new_view(member_bytes) { + self.bos = member.get_bos(); + self.offset += MplsLabelStackMember::minimum_packet_size(); + Some(member_bytes) + } else { + None + } + } + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_stack_member_iterator() { + let buf = [0x04, 0xbb, 0x41, 0x01]; + let stack = MplsLabelStack::new_view(&buf).unwrap(); + let mut member_iter = stack.members(); + let member_bytes = member_iter.next().unwrap(); + let member = MplsLabelStackMember::new_view(member_bytes).unwrap(); + assert_eq!(19380, member.get_label()); + assert_eq!(0, member.get_exp()); + assert_eq!(1, member.get_bos()); + assert_eq!(1, member.get_ttl()); + assert!(member_iter.next().is_none()); + } + } +} + +pub mod mpls_label_stack_member { + use crate::tracing::packet::buffer::Buffer; + use std::fmt::{Debug, Formatter}; + + const LABEL_OFFSET: usize = 0; + const EXP_OFFSET: usize = 2; + const BOS_OFFSET: usize = 2; + const TTL_OFFSET: usize = 3; + + /// Represents an ICMP `MplsLabelStackMember`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct MplsLabelStackMember<'a> { + buf: Buffer<'a>, + } + + impl<'a> MplsLabelStackMember<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn get_label(&self) -> u32 { + u32::from_be_bytes([ + 0x0, + self.buf.read(LABEL_OFFSET), + self.buf.read(LABEL_OFFSET + 1), + self.buf.read(LABEL_OFFSET + 2), + ]) >> 4 + } + + #[must_use] + pub fn get_exp(&self) -> u8 { + (self.buf.read(EXP_OFFSET) & 0x0e) >> 1 + } + + #[must_use] + pub fn get_bos(&self) -> u8 { + self.buf.read(BOS_OFFSET) & 0x01 + } + + #[must_use] + pub fn get_ttl(&self) -> u8 { + self.buf.read(TTL_OFFSET) + } + + pub fn set_label(&mut self, val: u32) { + let bytes = (val << 4).to_be_bytes(); + *self.buf.write(LABEL_OFFSET) = bytes[1]; + *self.buf.write(LABEL_OFFSET + 1) = bytes[2]; + *self.buf.write(LABEL_OFFSET + 2) = + (self.buf.read(LABEL_OFFSET + 2) & 0x0f) | (bytes[3] & 0xf0); + } + + pub fn set_exp(&mut self, exp: u8) { + *self.buf.write(EXP_OFFSET) = (self.buf.read(EXP_OFFSET) & 0xf1) | ((exp << 1) & 0x0e); + } + + pub fn set_bos(&mut self, bos: u8) { + *self.buf.write(BOS_OFFSET) = (self.buf.read(BOS_OFFSET) & 0xfe) | (bos & 0x01); + } + + pub fn set_ttl(&mut self, ttl: u8) { + *self.buf.write(TTL_OFFSET) = ttl; + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + } + + impl Debug for MplsLabelStackMember<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MplsLabelStackMember") + .field("label", &self.get_label()) + .field("exp", &self.get_exp()) + .field("bos", &self.get_bos()) + .field("ttl", &self.get_ttl()) + .finish() + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_label() { + let mut buf = [0_u8; MplsLabelStackMember::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMember::new(&mut buf).unwrap(); + mpls_extension.set_label(0); + assert_eq!(0, mpls_extension.get_label()); + assert_eq!([0x00, 0x00, 0x00], mpls_extension.packet()[0..3]); + mpls_extension.set_label(19380); + assert_eq!(19380, mpls_extension.get_label()); + assert_eq!([0x04, 0xbb, 0x40], mpls_extension.packet()[0..3]); + mpls_extension.set_label(1_048_575); + assert_eq!(1_048_575, mpls_extension.get_label()); + assert_eq!([0xff, 0xff, 0xf0], mpls_extension.packet()[0..3]); + } + + #[test] + fn test_exp() { + let mut buf = [0_u8; MplsLabelStackMember::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMember::new(&mut buf).unwrap(); + mpls_extension.set_exp(0); + assert_eq!(0, mpls_extension.get_exp()); + assert_eq!([0x00], mpls_extension.packet()[2..3]); + mpls_extension.set_exp(7); + assert_eq!(7, mpls_extension.get_exp()); + assert_eq!([0x0e], mpls_extension.packet()[2..3]); + } + + #[test] + fn test_bos() { + let mut buf = [0_u8; MplsLabelStackMember::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMember::new(&mut buf).unwrap(); + mpls_extension.set_bos(0); + assert_eq!(0, mpls_extension.get_bos()); + assert_eq!([0x00], mpls_extension.packet()[2..3]); + mpls_extension.set_bos(1); + assert_eq!(1, mpls_extension.get_bos()); + assert_eq!([0x01], mpls_extension.packet()[2..3]); + } + + #[test] + fn test_ttl() { + let mut buf = [0_u8; MplsLabelStackMember::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMember::new(&mut buf).unwrap(); + mpls_extension.set_ttl(0); + assert_eq!(0, mpls_extension.get_ttl()); + assert_eq!([0x00], mpls_extension.packet()[3..4]); + mpls_extension.set_ttl(1); + assert_eq!(1, mpls_extension.get_ttl()); + assert_eq!([0x01], mpls_extension.packet()[3..4]); + mpls_extension.set_ttl(255); + assert_eq!(255, mpls_extension.get_ttl()); + assert_eq!([0xff], mpls_extension.packet()[3..4]); + } + + #[test] + fn test_combined() { + let mut buf = [0_u8; MplsLabelStackMember::minimum_packet_size()]; + let mut mpls_extension = MplsLabelStackMember::new(&mut buf).unwrap(); + mpls_extension.set_label(19380); + mpls_extension.set_exp(0); + mpls_extension.set_bos(1); + mpls_extension.set_ttl(1); + assert_eq!(19380, mpls_extension.get_label()); + assert_eq!(0, mpls_extension.get_exp()); + assert_eq!(1, mpls_extension.get_bos()); + assert_eq!(1, mpls_extension.get_ttl()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], mpls_extension.packet()[0..4]); + mpls_extension.set_label(1_048_575); + mpls_extension.set_exp(7); + mpls_extension.set_bos(1); + mpls_extension.set_ttl(255); + assert_eq!(1_048_575, mpls_extension.get_label()); + assert_eq!(7, mpls_extension.get_exp()); + assert_eq!(1, mpls_extension.get_bos()); + assert_eq!(255, mpls_extension.get_ttl()); + assert_eq!([0xff, 0xff, 0xff, 0xff], mpls_extension.packet()[0..4]); + } + + #[test] + fn test_view() { + let buf = [0x04, 0xbb, 0x41, 0x01]; + let object = MplsLabelStackMember::new_view(&buf).unwrap(); + assert_eq!(19380, object.get_label()); + assert_eq!(0, object.get_exp()); + assert_eq!(1, object.get_bos()); + assert_eq!(1, object.get_ttl()); + } + } +} + +pub mod extension_splitter { + const ICMP_ORIG_DATAGRAM_MIN_LENGTH: usize = 128; + + /// Separate an ICMP payload from ICMP extensions as defined in rfc4884. + /// + /// Applies to `TimeExceeded` and `DestinationUnreachable` ICMP messages only. + #[must_use] + pub fn split(rfc4884_length: u8, icmp_payload: &[u8]) -> (&[u8], Option<&[u8]>) { + let orig_datagram_length = usize::from(rfc4884_length * 4); + + // TODO what to do if the claimed orig_datagram_length is bigger than the actual payload? + // we could truncate or we can err or we could return empty? + if orig_datagram_length > icmp_payload.len() { + return (&[], None); + } + + if orig_datagram_length > 0 { + // compliant message case + if icmp_payload.len() > orig_datagram_length { + // extension case (untested): the icmp_payload is longer than the orig_datagram and so whatever remains must be an extension + let extension_len = icmp_payload.len() - orig_datagram_length; + let extension = + &icmp_payload[orig_datagram_length..orig_datagram_length + extension_len]; + ( + &icmp_payload[..orig_datagram_length - extension_len], + Some(extension), + ) + } else { + (&icmp_payload[..orig_datagram_length], None) + } + // "Specifically, when a TRACEROUTE application operating in non- + // compliant mode receives a sufficiently long ICMP message that does + // not specify a length attribute, it will parse for a valid extension + // header at a fixed location, assuming a 128-octet "original datagram" + // field." + // TODO - have to include length of the extension header here? MTR does + } else if orig_datagram_length == 0 && icmp_payload.len() > ICMP_ORIG_DATAGRAM_MIN_LENGTH { + // extension present, non-compliant message + let extension_len = icmp_payload.len() - ICMP_ORIG_DATAGRAM_MIN_LENGTH; + let extension = &icmp_payload + [ICMP_ORIG_DATAGRAM_MIN_LENGTH..ICMP_ORIG_DATAGRAM_MIN_LENGTH + extension_len]; + ( + &icmp_payload[..icmp_payload.len() - extension_len], + Some(extension), + ) + } else { + // no extension present + (icmp_payload, None) + } + } + + #[cfg(test)] + mod tests { + use crate::tracing::packet::icmp_extension::extension_header::ExtensionHeader; + use crate::tracing::packet::icmp_extension::extension_object::{ + ClassNum, ClassSubType, ExtensionObject, + }; + use crate::tracing::packet::icmp_extension::extension_structure::ExtensionStructure; + use crate::tracing::packet::icmp_extension::mpls_label_stack::MplsLabelStack; + use crate::tracing::packet::icmp_extension::mpls_label_stack_member::MplsLabelStackMember; + use crate::tracing::packet::icmpv4::echo_request::EchoRequestPacket; + use crate::tracing::packet::icmpv4::time_exceeded::TimeExceededPacket; + use crate::tracing::packet::icmpv4::{IcmpCode, IcmpType}; + use crate::tracing::packet::ipv4::Ipv4Packet; + use std::net::Ipv4Addr; + + // This ICMP TimeExceeded packet which contains single `MPLS` extension object with a single member. The + // packet does not have a `length` field and is therefore rfc4884 non-complaint. + #[test] + #[allow(clippy::cognitive_complexity)] + fn test_split_extension_ipv4_time_exceeded_non_compliant_mpls() { + let buf = hex_literal::hex!( + " + 0b 00 f4 ff 00 00 00 00 45 00 00 54 cc 1c 40 00 + 01 01 b5 f4 c0 a8 01 15 5d b8 d8 22 08 00 0f e3 + 65 da 82 42 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 20 00 99 3a 00 08 01 01 + 04 bb 41 01 + " + ); + let time_exceeded_packet = TimeExceededPacket::new_view(&buf).unwrap(); + assert_eq!(IcmpType::TimeExceeded, time_exceeded_packet.get_icmp_type()); + assert_eq!(IcmpCode(0), time_exceeded_packet.get_icmp_code()); + assert_eq!(62719, time_exceeded_packet.get_checksum()); + assert_eq!(0, time_exceeded_packet.get_length()); + assert_eq!(&buf[8..136], time_exceeded_packet.payload()); + assert_eq!(Some(&buf[136..]), time_exceeded_packet.extension()); + + let nested_ipv4 = Ipv4Packet::new_view(&buf[8..136]).unwrap(); + assert_eq!(Ipv4Addr::from([192, 168, 1, 21]), nested_ipv4.get_source()); + assert_eq!( + Ipv4Addr::from([93, 184, 216, 34]), + nested_ipv4.get_destination() + ); + assert_eq!(&buf[28..136], nested_ipv4.payload()); + + let nested_echo = EchoRequestPacket::new_view(nested_ipv4.payload()).unwrap(); + assert_eq!(IcmpCode(0), nested_echo.get_icmp_code()); + assert_eq!(IcmpType::EchoRequest, nested_echo.get_icmp_type()); + assert_eq!(0x0FE3, nested_echo.get_checksum()); + assert_eq!(26074, nested_echo.get_identifier()); + assert_eq!(33346, nested_echo.get_sequence()); + assert_eq!(&buf[36..136], nested_echo.payload()); + + let extensions = + ExtensionStructure::new_view(time_exceeded_packet.extension().unwrap()).unwrap(); + + let extension_header = ExtensionHeader::new_view(extensions.header()).unwrap(); + assert_eq!(2, extension_header.get_version()); + assert_eq!(0x993A, extension_header.get_checksum()); + + let object_bytes = extensions.objects().next().unwrap(); + let extension_object = ExtensionObject::new_view(object_bytes).unwrap(); + + assert_eq!(8, extension_object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + extension_object.get_class_num() + ); + assert_eq!(ClassSubType(1), extension_object.get_class_subtype()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], extension_object.payload()); + + let mpls_stack = MplsLabelStack::new_view(extension_object.payload()).unwrap(); + let mpls_stack_member_bytes = mpls_stack.members().next().unwrap(); + let mpls_stack_member = + MplsLabelStackMember::new_view(mpls_stack_member_bytes).unwrap(); + assert_eq!(19380, mpls_stack_member.get_label()); + assert_eq!(0, mpls_stack_member.get_exp()); + assert_eq!(1, mpls_stack_member.get_bos()); + assert_eq!(1, mpls_stack_member.get_ttl()); + } + + // This ICMP TimeExceeded packet does not have any ICMP extensions. It has a rfc4884 complaint `length` field. + #[test] + fn test_split_extension_ipv4_time_exceeded_compliant_no_extension() { + let buf = hex_literal::hex!( + " + 0b 00 f4 ee 00 11 00 00 45 00 00 54 a2 ee 40 00 + 01 01 df 22 c0 a8 01 15 5d b8 d8 22 08 00 0f e1 + 65 da 82 44 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 + " + ); + let time_exceeded_packet = TimeExceededPacket::new_view(&buf).unwrap(); + assert_eq!(IcmpType::TimeExceeded, time_exceeded_packet.get_icmp_type()); + assert_eq!(IcmpCode(0), time_exceeded_packet.get_icmp_code()); + assert_eq!(62702, time_exceeded_packet.get_checksum()); + assert_eq!(17, time_exceeded_packet.get_length()); + assert_eq!(&buf[8..76], time_exceeded_packet.payload()); + assert_eq!(None, time_exceeded_packet.extension()); + + let nested_ipv4 = Ipv4Packet::new_view(&buf[8..76]).unwrap(); + assert_eq!(Ipv4Addr::from([192, 168, 1, 21]), nested_ipv4.get_source()); + assert_eq!( + Ipv4Addr::from([93, 184, 216, 34]), + nested_ipv4.get_destination() + ); + assert_eq!(&buf[28..76], nested_ipv4.payload()); + + let nested_echo = EchoRequestPacket::new_view(nested_ipv4.payload()).unwrap(); + assert_eq!(IcmpCode(0), nested_echo.get_icmp_code()); + assert_eq!(IcmpType::EchoRequest, nested_echo.get_icmp_type()); + assert_eq!(0x0FE1, nested_echo.get_checksum()); + assert_eq!(26074, nested_echo.get_identifier()); + assert_eq!(33348, nested_echo.get_sequence()); + assert_eq!(&buf[36..76], nested_echo.payload()); + } + } +} diff --git a/src/tracing/packet/icmpv4.rs b/src/tracing/packet/icmpv4.rs index 577e41ec0..77032f0c0 100644 --- a/src/tracing/packet/icmpv4.rs +++ b/src/tracing/packet/icmpv4.rs @@ -631,12 +631,14 @@ pub mod echo_reply { pub mod time_exceeded { use crate::tracing::packet::buffer::Buffer; use crate::tracing::packet::fmt_payload; + use crate::tracing::packet::icmp_extension::extension_splitter::split; use crate::tracing::packet::icmpv4::{IcmpCode, IcmpType}; use std::fmt::{Debug, Formatter}; const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; + const LENGTH_OFFSET: usize = 5; /// Represents an ICMP `TimeExceeded` packet. /// @@ -689,6 +691,11 @@ pub mod time_exceeded { u16::from_be_bytes(self.buf.get_bytes(CHECKSUM_OFFSET)) } + #[must_use] + pub fn get_length(&self) -> u8 { + self.buf.read(LENGTH_OFFSET) + } + pub fn set_icmp_type(&mut self, val: IcmpType) { *self.buf.write(TYPE_OFFSET) = val.id(); } @@ -701,6 +708,10 @@ pub mod time_exceeded { self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); } + pub fn set_length(&mut self, val: u8) { + *self.buf.write(LENGTH_OFFSET) = val; + } + pub fn set_payload(&mut self, vals: &[u8]) { let current_offset = Self::minimum_packet_size(); self.buf.as_slice_mut()[current_offset..current_offset + vals.len()] @@ -714,7 +725,20 @@ pub mod time_exceeded { #[must_use] pub fn payload(&self) -> &[u8] { - &self.buf.as_slice()[Self::minimum_packet_size()..] + let (payload, _) = self.split_payload_extension(); + payload + } + + #[must_use] + pub fn extension(&self) -> Option<&[u8]> { + let (_, extension) = self.split_payload_extension(); + extension + } + + fn split_payload_extension(&self) -> (&[u8], Option<&[u8]>) { + let rfc4884_length = self.get_length(); + let icmp_payload = &self.buf.as_slice()[Self::minimum_packet_size()..]; + split(rfc4884_length, icmp_payload) } } @@ -724,6 +748,7 @@ pub mod time_exceeded { .field("icmp_type", &self.get_icmp_type()) .field("icmp_code", &self.get_icmp_code()) .field("checksum", &self.get_checksum()) + .field("length", &self.get_length()) .field("payload", &fmt_payload(self.payload())) .finish() } @@ -784,6 +809,21 @@ pub mod time_exceeded { assert_eq!([0xFF, 0xFF], packet.packet()[2..=3]); } + #[test] + fn test_length() { + let mut buf = [0_u8; TimeExceededPacket::minimum_packet_size()]; + let mut packet = TimeExceededPacket::new(&mut buf).unwrap(); + packet.set_length(0); + assert_eq!(0, packet.get_length()); + assert_eq!([0x00], packet.packet()[5..6]); + packet.set_length(8); + assert_eq!(8, packet.get_length()); + assert_eq!([0x08], packet.packet()[5..6]); + packet.set_length(u8::MAX); + assert_eq!(u8::MAX, packet.get_length()); + assert_eq!([0xFF], packet.packet()[5..6]); + } + #[test] fn test_view() { let buf = [0x0b, 0x00, 0xf4, 0xee, 0x00, 0x11, 0x00, 0x00]; @@ -791,6 +831,7 @@ pub mod time_exceeded { assert_eq!(IcmpType::TimeExceeded, packet.get_icmp_type()); assert_eq!(IcmpCode(0), packet.get_icmp_code()); assert_eq!(62702, packet.get_checksum()); + assert_eq!(17, packet.get_length()); assert!(packet.payload().is_empty()); } } @@ -799,13 +840,14 @@ pub mod time_exceeded { pub mod destination_unreachable { use crate::tracing::packet::buffer::Buffer; use crate::tracing::packet::fmt_payload; + use crate::tracing::packet::icmp_extension::extension_splitter::split; use crate::tracing::packet::icmpv4::{IcmpCode, IcmpType}; use std::fmt::{Debug, Formatter}; const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; - const UNUSED_OFFSET: usize = 4; + const LENGTH_OFFSET: usize = 5; const NEXT_HOP_MTU_OFFSET: usize = 6; /// Represents an ICMP `DestinationUnreachable` packet. @@ -860,8 +902,8 @@ pub mod destination_unreachable { } #[must_use] - pub fn get_unused(&self) -> u16 { - u16::from_be_bytes(self.buf.get_bytes(UNUSED_OFFSET)) + pub fn get_length(&self) -> u8 { + self.buf.read(LENGTH_OFFSET) } #[must_use] @@ -881,8 +923,8 @@ pub mod destination_unreachable { self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); } - pub fn set_unused(&mut self, val: u16) { - self.buf.set_bytes(UNUSED_OFFSET, val.to_be_bytes()); + pub fn set_length(&mut self, val: u8) { + *self.buf.write(LENGTH_OFFSET) = val; } pub fn set_next_hop_mtu(&mut self, val: u16) { @@ -902,7 +944,20 @@ pub mod destination_unreachable { #[must_use] pub fn payload(&self) -> &[u8] { - &self.buf.as_slice()[Self::minimum_packet_size()..] + let (payload, _) = self.split_payload_extension(); + payload + } + + #[must_use] + pub fn extension(&self) -> Option<&[u8]> { + let (_, extension) = self.split_payload_extension(); + extension + } + + fn split_payload_extension(&self) -> (&[u8], Option<&[u8]>) { + let rfc4884_length = self.get_length(); + let icmp_payload = &self.buf.as_slice()[Self::minimum_packet_size()..]; + split(rfc4884_length, icmp_payload) } } @@ -912,7 +967,7 @@ pub mod destination_unreachable { .field("icmp_type", &self.get_icmp_type()) .field("icmp_code", &self.get_icmp_code()) .field("checksum", &self.get_checksum()) - .field("unused", &self.get_unused()) + .field("length", &self.get_length()) .field("next_hop_mtu", &self.get_next_hop_mtu()) .field("payload", &fmt_payload(self.payload())) .finish() @@ -974,6 +1029,21 @@ pub mod destination_unreachable { assert_eq!([0xFF, 0xFF], packet.packet()[2..=3]); } + #[test] + fn test_length() { + let mut buf = [0_u8; DestinationUnreachablePacket::minimum_packet_size()]; + let mut packet = DestinationUnreachablePacket::new(&mut buf).unwrap(); + packet.set_length(0); + assert_eq!(0, packet.get_length()); + assert_eq!([0x00], packet.packet()[5..6]); + packet.set_length(8); + assert_eq!(8, packet.get_length()); + assert_eq!([0x08], packet.packet()[5..6]); + packet.set_length(u8::MAX); + assert_eq!(u8::MAX, packet.get_length()); + assert_eq!([0xFF], packet.packet()[5..6]); + } + #[test] fn test_view() { let buf = [0x03, 0x03, 0xdf, 0xdc, 0x00, 0x00, 0x00, 0x00]; @@ -981,6 +1051,7 @@ pub mod destination_unreachable { assert_eq!(IcmpType::DestinationUnreachable, packet.get_icmp_type()); assert_eq!(IcmpCode(3), packet.get_icmp_code()); assert_eq!(57308, packet.get_checksum()); + assert_eq!(0, packet.get_length()); assert!(packet.payload().is_empty()); } } diff --git a/src/tracing/packet/icmpv6.rs b/src/tracing/packet/icmpv6.rs index 704f9e41e..ef6acc19b 100644 --- a/src/tracing/packet/icmpv6.rs +++ b/src/tracing/packet/icmpv6.rs @@ -631,12 +631,14 @@ pub mod echo_reply { pub mod time_exceeded { use crate::tracing::packet::buffer::Buffer; use crate::tracing::packet::fmt_payload; + use crate::tracing::packet::icmp_extension::extension_splitter::split; use crate::tracing::packet::icmpv6::{IcmpCode, IcmpType}; use std::fmt::{Debug, Formatter}; const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; + const LENGTH_OFFSET: usize = 4; /// Represents an ICMP `TimeExceeded` packet. /// @@ -689,6 +691,11 @@ pub mod time_exceeded { u16::from_be_bytes(self.buf.get_bytes(CHECKSUM_OFFSET)) } + #[must_use] + pub fn get_length(&self) -> u8 { + self.buf.read(LENGTH_OFFSET) + } + pub fn set_icmp_type(&mut self, val: IcmpType) { *self.buf.write(TYPE_OFFSET) = val.id(); } @@ -701,6 +708,10 @@ pub mod time_exceeded { self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); } + pub fn set_length(&mut self, val: u8) { + *self.buf.write(LENGTH_OFFSET) = val; + } + pub fn set_payload(&mut self, vals: &[u8]) { let current_offset = Self::minimum_packet_size(); self.buf.as_slice_mut()[current_offset..current_offset + vals.len()] @@ -714,7 +725,20 @@ pub mod time_exceeded { #[must_use] pub fn payload(&self) -> &[u8] { - &self.buf.as_slice()[Self::minimum_packet_size()..] + let (payload, _) = self.split_payload_extension(); + payload + } + + #[must_use] + pub fn extension(&self) -> Option<&[u8]> { + let (_, extension) = self.split_payload_extension(); + extension + } + + fn split_payload_extension(&self) -> (&[u8], Option<&[u8]>) { + let rfc4884_length = self.get_length(); + let icmp_payload = &self.buf.as_slice()[Self::minimum_packet_size()..]; + split(rfc4884_length, icmp_payload) } } @@ -724,6 +748,7 @@ pub mod time_exceeded { .field("icmp_type", &self.get_icmp_type()) .field("icmp_code", &self.get_icmp_code()) .field("checksum", &self.get_checksum()) + .field("length", &self.get_length()) .field("payload", &fmt_payload(self.payload())) .finish() } @@ -784,13 +809,29 @@ pub mod time_exceeded { assert_eq!([0xFF, 0xFF], packet.packet()[2..=3]); } + #[test] + fn test_length() { + let mut buf = [0_u8; TimeExceededPacket::minimum_packet_size()]; + let mut packet = TimeExceededPacket::new(&mut buf).unwrap(); + packet.set_length(0); + assert_eq!(0, packet.get_length()); + assert_eq!([0x00], packet.packet()[4..5]); + packet.set_length(8); + assert_eq!(8, packet.get_length()); + assert_eq!([0x08], packet.packet()[4..5]); + packet.set_length(u8::MAX); + assert_eq!(u8::MAX, packet.get_length()); + assert_eq!([0xFF], packet.packet()[4..5]); + } + #[test] fn test_view() { - let buf = [0x03, 0x00, 0xf4, 0xee, 0x00, 0x11, 0x00, 0x00]; + let buf = [0x03, 0x00, 0xf4, 0xee, 0x11, 0x00, 0x00, 0x00]; let packet = TimeExceededPacket::new_view(&buf).unwrap(); assert_eq!(IcmpType::TimeExceeded, packet.get_icmp_type()); assert_eq!(IcmpCode(0), packet.get_icmp_code()); assert_eq!(62702, packet.get_checksum()); + assert_eq!(17, packet.get_length()); assert!(packet.payload().is_empty()); } } @@ -799,13 +840,14 @@ pub mod time_exceeded { pub mod destination_unreachable { use crate::tracing::packet::buffer::Buffer; use crate::tracing::packet::fmt_payload; + use crate::tracing::packet::icmp_extension::extension_splitter::split; use crate::tracing::packet::icmpv6::{IcmpCode, IcmpType}; use std::fmt::{Debug, Formatter}; const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; - const UNUSED_OFFSET: usize = 4; + const LENGTH_OFFSET: usize = 4; const NEXT_HOP_MTU_OFFSET: usize = 6; /// Represents an ICMP `DestinationUnreachable` packet. @@ -860,8 +902,8 @@ pub mod destination_unreachable { } #[must_use] - pub fn get_unused(&self) -> u16 { - u16::from_be_bytes(self.buf.get_bytes(UNUSED_OFFSET)) + pub fn get_length(&self) -> u8 { + self.buf.read(LENGTH_OFFSET) } #[must_use] @@ -881,8 +923,8 @@ pub mod destination_unreachable { self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); } - pub fn set_unused(&mut self, val: u16) { - self.buf.set_bytes(UNUSED_OFFSET, val.to_be_bytes()); + pub fn set_length(&mut self, val: u8) { + *self.buf.write(LENGTH_OFFSET) = val; } pub fn set_next_hop_mtu(&mut self, val: u16) { @@ -902,7 +944,20 @@ pub mod destination_unreachable { #[must_use] pub fn payload(&self) -> &[u8] { - &self.buf.as_slice()[Self::minimum_packet_size()..] + let (payload, _) = self.split_payload_extension(); + payload + } + + #[must_use] + pub fn extension(&self) -> Option<&[u8]> { + let (_, extension) = self.split_payload_extension(); + extension + } + + fn split_payload_extension(&self) -> (&[u8], Option<&[u8]>) { + let rfc4884_length = self.get_length(); + let icmp_payload = &self.buf.as_slice()[Self::minimum_packet_size()..]; + split(rfc4884_length, icmp_payload) } } @@ -912,7 +967,7 @@ pub mod destination_unreachable { .field("icmp_type", &self.get_icmp_type()) .field("icmp_code", &self.get_icmp_code()) .field("checksum", &self.get_checksum()) - .field("unused", &self.get_unused()) + .field("length", &self.get_length()) .field("next_hop_mtu", &self.get_next_hop_mtu()) .field("payload", &fmt_payload(self.payload())) .finish() @@ -974,6 +1029,21 @@ pub mod destination_unreachable { assert_eq!([0xFF, 0xFF], packet.packet()[2..=3]); } + #[test] + fn test_length() { + let mut buf = [0_u8; DestinationUnreachablePacket::minimum_packet_size()]; + let mut packet = DestinationUnreachablePacket::new(&mut buf).unwrap(); + packet.set_length(0); + assert_eq!(0, packet.get_length()); + assert_eq!([0x00], packet.packet()[4..5]); + packet.set_length(8); + assert_eq!(8, packet.get_length()); + assert_eq!([0x08], packet.packet()[4..5]); + packet.set_length(u8::MAX); + assert_eq!(u8::MAX, packet.get_length()); + assert_eq!([0xFF], packet.packet()[4..5]); + } + #[test] fn test_view() { let buf = [0x01, 0x03, 0xdf, 0xdc, 0x00, 0x00, 0x00, 0x00]; @@ -981,6 +1051,7 @@ pub mod destination_unreachable { assert_eq!(IcmpType::DestinationUnreachable, packet.get_icmp_type()); assert_eq!(IcmpCode(3), packet.get_icmp_code()); assert_eq!(57308, packet.get_checksum()); + assert_eq!(0, packet.get_length()); assert!(packet.payload().is_empty()); } } diff --git a/src/tracing/probe.rs b/src/tracing/probe.rs index c9518d5e3..cbb4ac10c 100644 --- a/src/tracing/probe.rs +++ b/src/tracing/probe.rs @@ -3,7 +3,7 @@ use std::net::IpAddr; use std::time::{Duration, SystemTime}; /// The state of an ICMP echo request/response -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct Probe { /// The sequence of the probe. pub sequence: Sequence, @@ -27,6 +27,8 @@ pub struct Probe { pub received: Option, /// The type of ICMP response packet received for the probe. pub icmp_packet_type: Option, + /// The ICMP response extensions. + pub extensions: Option, } impl Probe { @@ -53,6 +55,7 @@ impl Probe { host: None, received: None, icmp_packet_type: None, + extensions: None, } } @@ -67,12 +70,12 @@ impl Probe { } #[must_use] - pub const fn with_status(self, status: ProbeStatus) -> Self { + pub fn with_status(self, status: ProbeStatus) -> Self { Self { status, ..self } } #[must_use] - pub const fn with_icmp_packet_type(self, icmp_packet_type: IcmpPacketType) -> Self { + pub fn with_icmp_packet_type(self, icmp_packet_type: IcmpPacketType) -> Self { Self { icmp_packet_type: Some(icmp_packet_type), ..self @@ -80,7 +83,7 @@ impl Probe { } #[must_use] - pub const fn with_host(self, host: IpAddr) -> Self { + pub fn with_host(self, host: IpAddr) -> Self { Self { host: Some(host), ..self @@ -88,12 +91,17 @@ impl Probe { } #[must_use] - pub const fn with_received(self, received: SystemTime) -> Self { + pub fn with_received(self, received: SystemTime) -> Self { Self { received: Some(received), ..self } } + + #[must_use] + pub fn with_extensions(self, extensions: Option) -> Self { + Self { extensions, ..self } + } } /// The status of a `Echo` for a single TTL. @@ -128,17 +136,46 @@ pub enum IcmpPacketType { } /// The response to a probe. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub enum ProbeResponse { - TimeExceeded(ProbeResponseData), - DestinationUnreachable(ProbeResponseData), + TimeExceeded(ProbeResponseData, Option), + DestinationUnreachable(ProbeResponseData, Option), EchoReply(ProbeResponseData), TcpReply(ProbeResponseData), TcpRefused(ProbeResponseData), } +/// The ICMP extensions for a probe response. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct ProbeResponseExtensions { + pub extensions: Vec, +} + +/// A probe response extension. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum ProbeResponseExtension { + #[default] + Unknown, + Mpls(MplsExtensionData), +} + +/// The members of a MPLS probe response extension. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct MplsExtensionData { + pub members: Vec, +} + +/// A member of a MPLS probe response extension. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct MplsExtensionMember { + pub label: u32, + pub exp: u8, + pub bos: u8, + pub ttl: u8, +} + /// The data in the probe response. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ProbeResponseData { /// Timestamp of the probe response. pub recv: SystemTime, @@ -158,14 +195,14 @@ impl ProbeResponseData { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub enum ProbeResponseSeq { Icmp(ProbeResponseSeqIcmp), Udp(ProbeResponseSeqUdp), Tcp(ProbeResponseSeqTcp), } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ProbeResponseSeqIcmp { pub identifier: u16, pub sequence: u16, @@ -180,7 +217,7 @@ impl ProbeResponseSeqIcmp { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ProbeResponseSeqUdp { pub identifier: u16, pub src_port: u16, @@ -199,7 +236,7 @@ impl ProbeResponseSeqUdp { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] pub struct ProbeResponseSeqTcp { pub src_port: u16, pub dest_port: u16, diff --git a/src/tracing/tracer.rs b/src/tracing/tracer.rs index f3c8c5aca..53fa6c917 100644 --- a/src/tracing/tracer.rs +++ b/src/tracing/tracer.rs @@ -144,17 +144,19 @@ impl)> Tracer { fn recv_response(&self, network: &mut N, st: &mut TracerState) -> TraceResult<()> { let next = network.recv_probe()?; match next { - Some(ProbeResponse::TimeExceeded(data)) => { + Some(ProbeResponse::TimeExceeded(data, extensions)) => { let (trace_id, sequence, received, host) = self.extract(&data); let is_target = host == self.config.target_addr; if self.check_trace_id(trace_id) && st.in_round(sequence) { - st.complete_probe_time_exceeded(sequence, host, received, is_target); + st.complete_probe_time_exceeded( + sequence, host, received, is_target, extensions, + ); } } - Some(ProbeResponse::DestinationUnreachable(data)) => { + Some(ProbeResponse::DestinationUnreachable(data, extensions)) => { let (trace_id, sequence, received, host) = self.extract(&data); if self.check_trace_id(trace_id) && st.in_round(sequence) { - st.complete_probe_unreachable(sequence, host, received); + st.complete_probe_unreachable(sequence, host, received, extensions); } } Some(ProbeResponse::EchoReply(data)) => { @@ -279,11 +281,13 @@ impl)> Tracer { /// the `TracerState` struct. mod state { use crate::tracing::constants::MAX_SEQUENCE_PER_ROUND; + use crate::tracing::probe::ProbeResponseExtensions; use crate::tracing::types::{MaxRounds, Port, Round, Sequence, TimeToLive, TraceId}; use crate::tracing::{ IcmpPacketType, MultipathStrategy, PortDirection, Probe, ProbeStatus, TracerConfig, TracerProtocol, }; + use std::array::from_fn; use std::net::IpAddr; use std::time::SystemTime; use tracing::instrument; @@ -347,7 +351,7 @@ mod state { pub fn new(config: TracerConfig) -> Self { Self { config, - buffer: [Probe::default(); BUFFER_SIZE as usize], + buffer: from_fn(|_| Probe::default()), sequence: config.initial_sequence, round_sequence: config.initial_sequence, ttl: config.first_ttl, @@ -368,7 +372,7 @@ mod state { /// Get the `Probe` for `sequence` pub fn probe_at(&self, sequence: Sequence) -> Probe { - self.buffer[usize::from(sequence - self.round_sequence)] + self.buffer[usize::from(sequence - self.round_sequence)].clone() } pub const fn ttl(&self) -> TimeToLive { @@ -430,7 +434,7 @@ mod state { self.round, SystemTime::now(), ); - self.buffer[usize::from(self.sequence - self.round_sequence)] = probe; + self.buffer[usize::from(self.sequence - self.round_sequence)] = probe.clone(); debug_assert!(self.ttl < TimeToLive(u8::MAX)); self.ttl += TimeToLive(1); debug_assert!(self.sequence < Sequence(u16::MAX)); @@ -460,7 +464,7 @@ mod state { self.round, SystemTime::now(), ); - self.buffer[usize::from(self.sequence - self.round_sequence)] = probe; + self.buffer[usize::from(self.sequence - self.round_sequence)] = probe.clone(); debug_assert!(self.sequence < Sequence(u16::MAX)); self.sequence += Sequence(1); probe @@ -553,6 +557,7 @@ mod state { host: IpAddr, received: SystemTime, is_target: bool, + extensions: Option, ) { self.complete_probe( sequence, @@ -560,6 +565,7 @@ mod state { host, received, is_target, + extensions, ); } @@ -570,8 +576,16 @@ mod state { sequence: Sequence, host: IpAddr, received: SystemTime, + extensions: Option, ) { - self.complete_probe(sequence, IcmpPacketType::Unreachable, host, received, true); + self.complete_probe( + sequence, + IcmpPacketType::Unreachable, + host, + received, + true, + extensions, + ); } /// Mark the `Probe` at `sequence` completed as `EchoReply` and update the round state. @@ -582,7 +596,14 @@ mod state { host: IpAddr, received: SystemTime, ) { - self.complete_probe(sequence, IcmpPacketType::EchoReply, host, received, true); + self.complete_probe( + sequence, + IcmpPacketType::EchoReply, + host, + received, + true, + None, + ); } /// Mark the `Probe` at `sequence` completed as `NotApplicable` and update the round state. @@ -599,6 +620,7 @@ mod state { host, received, true, + None, ); } @@ -623,6 +645,7 @@ mod state { host: IpAddr, received: SystemTime, is_target: bool, + extensions: Option, ) { // Retrieve and update the `Probe` at `sequence`. let probe = self @@ -630,8 +653,9 @@ mod state { .with_status(ProbeStatus::Complete) .with_icmp_packet_type(icmp_packet_type) .with_host(host) - .with_received(received); - self.buffer[usize::from(sequence - self.round_sequence)] = probe; + .with_received(received) + .with_extensions(extensions); + self.buffer[usize::from(sequence - self.round_sequence)] = probe.clone(); // If this `Probe` found the target then we set the `target_tll` if not already set, // being careful to account for `Probes` being received out-of-order. @@ -737,7 +761,7 @@ mod state { // Update the state of the probe 1 after receiving a TimeExceeded let received_1 = SystemTime::now(); let host = IpAddr::V4(Ipv4Addr::LOCALHOST); - state.complete_probe_time_exceeded(Sequence(33000), host, received_1, false); + state.complete_probe_time_exceeded(Sequence(33000), host, received_1, false, None); // Validate the state of the probe 1 after the update let probe_1_fetch = state.probe_at(Sequence(33000)); @@ -766,8 +790,8 @@ mod state { // Validate the probes() iterator returns returns only a single probe { let mut probe_iter = state.probes().iter(); - let probe_next1 = *probe_iter.next().unwrap(); - assert_eq!(probe_1_fetch, probe_next1); + let probe_next1 = probe_iter.next().unwrap(); + assert_eq!(&probe_1_fetch, probe_next1); assert_eq!(None, probe_iter.next()); } @@ -809,7 +833,7 @@ mod state { // Update the state of probe 2 after receiving a TimeExceeded let received_2 = SystemTime::now(); let host = IpAddr::V4(Ipv4Addr::LOCALHOST); - state.complete_probe_time_exceeded(Sequence(33001), host, received_2, false); + state.complete_probe_time_exceeded(Sequence(33001), host, received_2, false, None); let probe_2_recv = state.probe_at(Sequence(33001)); // Validate the TracerState after the update to probe 2 @@ -825,10 +849,10 @@ mod state { // Validate the probes() iterator returns the two probes in the states we expect { let mut probe_iter = state.probes().iter(); - let probe_next1 = *probe_iter.next().unwrap(); - assert_eq!(probe_2_recv, probe_next1); - let probe_next2 = *probe_iter.next().unwrap(); - assert_eq!(probe_3, probe_next2); + let probe_next1 = probe_iter.next().unwrap(); + assert_eq!(&probe_2_recv, probe_next1); + let probe_next2 = probe_iter.next().unwrap(); + assert_eq!(&probe_3, probe_next2); } // Update the state of probe 3 after receiving a EchoReply @@ -850,10 +874,10 @@ mod state { // Validate the probes() iterator returns the two probes in the states we expect { let mut probe_iter = state.probes().iter(); - let probe_next1 = *probe_iter.next().unwrap(); - assert_eq!(probe_2_recv, probe_next1); - let probe_next2 = *probe_iter.next().unwrap(); - assert_eq!(probe_3_recv, probe_next2); + let probe_next1 = probe_iter.next().unwrap(); + assert_eq!(&probe_2_recv, probe_next1); + let probe_next2 = probe_iter.next().unwrap(); + assert_eq!(&probe_3_recv, probe_next2); } } diff --git a/trippy-config-sample.toml b/trippy-config-sample.toml index 18745d4ba..0a0674808 100644 --- a/trippy-config-sample.toml +++ b/trippy-config-sample.toml @@ -151,6 +151,18 @@ payload-pattern = 0 # This is also known as DSCP+ECN. tos = 0 +# Whether to parse ICMP extensions. +# +# If enabled, all extensions attached to incoming ICMP TimeExceeded and DestinationUnavailable messages will be parsed +# and proved as part of the trace response data. +# +# The following ICMP Extension Object Classes are supported: +# 1 - MPLS Label Stack Class (RFC4950) +# +# Extension objects with an unknown class will be parsed to capture generic information including the class, subtype, +# length and payload bytes. +icmp_extensions = false + # The socket read timeout [default: 10ms] read-timeout = "10ms"