From afbfa60ad3014c4be647caea637da5024cae4400 Mon Sep 17 00:00:00 2001 From: Nick Pillitteri Date: Mon, 8 Apr 2024 21:38:30 -0400 Subject: [PATCH] Read DNS settings from a resolv.conf file Allow some DNS settings to be read from a system's /etc/resolv.conf file. Noteably, only the `nameserver` and a few `option`s are supported for now. This change also switches resolution of A and AAAA names to our DNS client instead of using the system resolver via `lookup_host`. Part of #107 --- mtop-client/src/core.rs | 13 +- mtop-client/src/discovery.rs | 149 ++++++++------ mtop-client/src/dns/client.rs | 93 +++++++-- mtop-client/src/dns/message.rs | 18 +- mtop-client/src/dns/mod.rs | 2 + mtop-client/src/dns/name.rs | 157 +++++++++++++-- mtop-client/src/dns/resolv.rs | 345 +++++++++++++++++++++++++++++++++ mtop-client/src/pool.rs | 36 ++-- mtop/src/bin/dns.rs | 39 ++-- mtop/src/bin/mc.rs | 16 +- mtop/src/bin/mtop.rs | 18 +- mtop/src/check.rs | 13 +- mtop/src/dns.rs | 43 ++++ mtop/src/lib.rs | 1 + rustfmt.toml | 1 + 15 files changed, 762 insertions(+), 182 deletions(-) create mode 100644 mtop-client/src/dns/resolv.rs create mode 100644 mtop/src/dns.rs diff --git a/mtop-client/src/core.rs b/mtop-client/src/core.rs index baa29c0..fdfee68 100644 --- a/mtop-client/src/core.rs +++ b/mtop-client/src/core.rs @@ -206,10 +206,7 @@ impl TryFrom<&HashMap> for Slabs { // $active_slabs + 1. let mut ids = BTreeSet::new(); for k in value.keys() { - let key_id: Option = k - .split_once(':') - .map(|(raw, _rest)| raw) - .and_then(|raw| raw.parse().ok()); + let key_id: Option = k.split_once(':').map(|(raw, _rest)| raw).and_then(|raw| raw.parse().ok()); if let Some(id) = key_id { ids.insert(id); @@ -1287,9 +1284,7 @@ mod test { macro_rules! test_store_command_success { ($method:ident, $verb:expr) => { let (mut rx, mut client) = client!("STORED\r\n"); - let res = client - .$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes()) - .await; + let res = client.$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes()).await; assert!(res.is_ok()); let bytes = rx.recv().await.unwrap(); @@ -1301,9 +1296,7 @@ mod test { macro_rules! test_store_command_error { ($method:ident, $verb:expr) => { let (mut rx, mut client) = client!("NOT_STORED\r\n"); - let res = client - .$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes()) - .await; + let res = client.$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes()).await; assert!(res.is_err()); let err = res.unwrap_err(); diff --git a/mtop-client/src/discovery.rs b/mtop-client/src/discovery.rs index 662455c..bc7c0ed 100644 --- a/mtop-client/src/discovery.rs +++ b/mtop-client/src/discovery.rs @@ -1,5 +1,5 @@ use crate::core::MtopError; -use crate::dns::{DnsClient, RecordData}; +use crate::dns::{DnsClient, Name, Record, RecordClass, RecordData, RecordType}; use std::cmp::Ordering; use std::fmt; use std::net::{IpAddr, SocketAddr}; @@ -60,47 +60,25 @@ impl AsRef for ServerID { /// An individual server that is part of a Memcached cluster. #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct Server { - repr: ServerRepr, -} - -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -enum ServerRepr { - Resolved(ServerID, ServerName<'static>, SocketAddr), - Unresolved(ServerID, ServerName<'static>), + id: ServerID, + name: ServerName<'static>, } impl Server { - pub fn from_id(id: ServerID, name: ServerName<'static>) -> Self { - Self { - repr: ServerRepr::Unresolved(id, name), - } - } - - pub fn from_addr(addr: SocketAddr, name: ServerName<'static>) -> Self { - Self { - repr: ServerRepr::Resolved(ServerID::from(addr), name, addr), - } + pub fn new(id: ServerID, name: ServerName<'static>) -> Self { + Self { id, name } } pub fn id(&self) -> ServerID { - match &self.repr { - ServerRepr::Resolved(id, _, _) => id.clone(), - ServerRepr::Unresolved(id, _) => id.clone(), - } + self.id.clone() } pub fn server_name(&self) -> ServerName<'static> { - match &self.repr { - ServerRepr::Resolved(_, name, _) => name.clone(), - ServerRepr::Unresolved(_, name) => name.clone(), - } + self.name.clone() } pub fn address(&self) -> String { - match &self.repr { - ServerRepr::Resolved(_, _, addr) => addr.to_string(), - ServerRepr::Unresolved(id, _) => id.to_string(), - } + self.id.to_string() } } @@ -134,44 +112,56 @@ impl DiscoveryDefault { pub async fn resolve_by_proto(&self, name: &str) -> Result, MtopError> { if name.starts_with(DNS_A_PREFIX) { - Ok(self.resolve_a(name.trim_start_matches(DNS_A_PREFIX)).await?) + Ok(self.resolve_a_aaaa(name.trim_start_matches(DNS_A_PREFIX)).await?) } else if name.starts_with(DNS_SRV_PREFIX) { Ok(self.resolve_srv(name.trim_start_matches(DNS_SRV_PREFIX)).await?) } else { - Ok(self.resolve_a(name).await?.pop().into_iter().collect()) + Ok(self.resolve_a_aaaa(name).await?.pop().into_iter().collect()) } } async fn resolve_srv(&self, name: &str) -> Result, MtopError> { - let server_name = Self::server_name(name)?; - let (host_name, port) = Self::host_and_port(name)?; - let mut out = Vec::new(); + let (host, port) = Self::host_and_port(name)?; + let server_name = Self::server_name(host)?; + let name = host.parse()?; - let res = self.client.resolve_srv(host_name).await?; - for a in res.answers() { - let target = if let RecordData::SRV(srv) = a.rdata() { - srv.target().to_string() - } else { - tracing::warn!(message = "unexpected record data for answer", name = host_name, answer = ?a); - continue; - }; - let server_id = ServerID::from((target, port)); - let server = Server::from_id(server_id, server_name.clone()); - out.push(server); - } + let res = self.client.resolve(name, RecordType::SRV, RecordClass::INET).await?; + Ok(self.servers_from_answers(port, &server_name, res.answers())) + } + + async fn resolve_a_aaaa(&self, name: &str) -> Result, MtopError> { + let (host, port) = Self::host_and_port(name)?; + let server_name = Self::server_name(host)?; + let name: Name = host.parse()?; + + let res = self.client.resolve(name.clone(), RecordType::A, RecordClass::INET).await?; + let mut out = self.servers_from_answers(port, &server_name, res.answers()); + + let res = self.client.resolve(name, RecordType::AAAA, RecordClass::INET).await?; + out.extend(self.servers_from_answers(port, &server_name, res.answers())); Ok(out) } - async fn resolve_a(&self, name: &str) -> Result, MtopError> { - let server_name = Self::server_name(name)?; - + fn servers_from_answers(&self, port: u16, server_name: &ServerName, answers: &[Record]) -> Vec { let mut out = Vec::new(); - for addr in tokio::net::lookup_host(name).await? { - out.push(Server::from_addr(addr, server_name.clone())); + + for answer in answers { + let server_id = match answer.rdata() { + RecordData::A(data) => ServerID::from(SocketAddr::new(IpAddr::V4(data.addr()), port)), + RecordData::AAAA(data) => ServerID::from(SocketAddr::new(IpAddr::V6(data.addr()), port)), + RecordData::SRV(data) => ServerID::from((data.target().to_string(), port)), + _ => { + tracing::warn!(message = "unexpected record data for answer", answer = ?answer); + continue; + } + }; + + let server = Server::new(server_id, server_name.to_owned()); + out.push(server); } - Ok(out) + out } fn host_and_port(name: &str) -> Result<(&str, u16), MtopError> { @@ -185,27 +175,26 @@ impl DiscoveryDefault { // IPv6 addresses use brackets around them to disambiguate them from a port number. // Since we're parsing the host and port, strip the brackets because they aren't // needed or valid to include in a TLS ServerName. - .map(|(hostname, port)| (hostname.trim_start_matches('[').trim_end_matches(']'), port)) - .and_then(|(hostname, port)| { - port.parse().map(|p| (hostname, p)).map_err(|e| { + .map(|(host, port)| (host.trim_start_matches('[').trim_end_matches(']'), port)) + .and_then(|(host, port)| { + port.parse().map(|p| (host, p)).map_err(|e| { MtopError::configuration_cause(format!("unable to parse port number from '{}'", name), e) }) }) } - fn server_name(name: &str) -> Result, MtopError> { - Self::host_and_port(name).and_then(|(host, _)| { - ServerName::try_from(host) - .map(|s| s.to_owned()) - .map_err(|e| MtopError::configuration_cause(format!("invalid server name '{}'", host), e)) - }) + fn server_name(host: &str) -> Result, MtopError> { + ServerName::try_from(host) + .map(|s| s.to_owned()) + .map_err(|e| MtopError::configuration_cause(format!("invalid server name '{}'", host), e)) } } #[cfg(test)] mod test { - use super::ServerID; + use super::{Server, ServerID}; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; + use webpki::types::ServerName; #[test] fn test_server_id_from_ipv4_addr() { @@ -241,4 +230,38 @@ mod test { let id = ServerID::from(pair); assert_eq!("cache.example.com:11211", id.to_string()); } + + #[test] + fn test_server_resolved_id() { + let addr = SocketAddr::from(([127, 0, 0, 1], 11211)); + let id = ServerID::from(addr); + let name = ServerName::try_from("cache.example.com").unwrap(); + let server = Server::new(id, name); + assert_eq!("127.0.0.1:11211", server.id().to_string()); + } + + #[test] + fn test_server_resolved_address() { + let addr = SocketAddr::from(([127, 0, 0, 1], 11211)); + let id = ServerID::from(addr); + let name = ServerName::try_from("cache.example.com").unwrap(); + let server = Server::new(id, name); + assert_eq!("127.0.0.1:11211", server.address()); + } + + #[test] + fn test_server_unresolved_id() { + let id = ServerID::from(("cache01.example.com", 11211)); + let name = ServerName::try_from("cache.example.com").unwrap(); + let server = Server::new(id, name); + assert_eq!("cache01.example.com:11211", server.id().to_string()); + } + + #[test] + fn test_server_unresolved_address() { + let id = ServerID::from(("cache01.example.com", 11211)); + let name = ServerName::try_from("cache.example.com").unwrap(); + let server = Server::new(id, name); + assert_eq!("cache01.example.com:11211", server.address()); + } } diff --git a/mtop-client/src/dns/client.rs b/mtop-client/src/dns/client.rs index 66e80de..771922b 100644 --- a/mtop-client/src/dns/client.rs +++ b/mtop-client/src/dns/client.rs @@ -1,44 +1,79 @@ use crate::core::MtopError; -use crate::dns::core::RecordType; +use crate::dns::core::{RecordClass, RecordType}; use crate::dns::message::{Flags, Message, MessageId, Question}; use crate::dns::name::Name; +use crate::dns::resolv::ResolvConf; +use crate::timeout::Timeout; use std::io::Cursor; use std::net::SocketAddr; -use std::str::FromStr; +use std::sync::atomic::{AtomicUsize, Ordering}; use tokio::net::UdpSocket; const DEFAULT_RECV_BUF: usize = 512; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct DnsClient { local: SocketAddr, - server: SocketAddr, + config: ResolvConf, + server: AtomicUsize, } impl DnsClient { - pub fn new(local: SocketAddr, server: SocketAddr) -> Self { - Self { local, server } - } - - pub async fn exchange(&self, msg: &Message) -> Result { - let id = msg.id(); - let sock = self.connect_udp().await?; - self.send_udp(&sock, msg).await?; - self.recv_udp(&sock, id).await + /// Create a new DnsClient that will use a local address to open UDP or TCP + /// connections and behavior based on a resolv.conf configuration file. + pub fn new(local: SocketAddr, config: ResolvConf) -> Self { + Self { + local, + config, + server: AtomicUsize::new(0), + } } - pub async fn resolve_srv(&self, name: &str) -> Result { - let n = Name::from_str(name)?; + /// Perform a DNS lookup with the configured nameservers. + /// + /// Timeouts and network errors will result in up to one additional attempt + /// to perform a DNS lookup when using the default configuration. + pub async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result { + let full = name.to_fqdn(); let id = MessageId::random(); let flags = Flags::default().set_recursion_desired(); - let msg = Message::new(id, flags).add_question(Question::new(n, RecordType::SRV)); + let question = Question::new(full, rtype).set_qclass(rclass); + let message = Message::new(id, flags).add_question(question); + + let mut attempt = 0; + loop { + match self.exchange(&message, usize::from(attempt)).await { + Ok(v) => return Ok(v), + Err(e) => { + if attempt + 1 >= self.config.options.attempts { + return Err(e); + } - self.exchange(&msg).await + tracing::debug!(message = "retrying failed query", attempt = attempt + 1, max_attempts = self.config.options.attempts, err = %e); + attempt += 1; + } + } + } } + async fn exchange(&self, msg: &Message, attempt: usize) -> Result { + let id = msg.id(); + let server = self.nameserver(attempt); - async fn connect_udp(&self) -> Result { + // Wrap creating the socket, sending, and receiving in an async block + // so that we can apply a single timeout to all operations and ensure + // we have access to the nameserver to make the timeout message useful. + async { + let sock = self.connect_udp(server).await?; + self.send_udp(&sock, msg).await?; + self.recv_udp(&sock, id).await + } + .timeout(self.config.options.timeout, format!("client.exchange {}", server)) + .await + } + + async fn connect_udp(&self, server: SocketAddr) -> Result { let sock = UdpSocket::bind(&self.local).await?; - sock.connect(&self.server).await?; + sock.connect(server).await?; Ok(sock) } @@ -59,4 +94,24 @@ impl DnsClient { } } } + + fn nameserver(&self, attempt: usize) -> SocketAddr { + let idx = if self.config.options.rotate { + self.server.fetch_add(1, Ordering::Relaxed) + } else { + attempt + }; + + self.config.nameservers[idx % self.config.nameservers.len()] + } +} + +impl Clone for DnsClient { + fn clone(&self) -> Self { + Self { + local: self.local, + config: self.config.clone(), + server: AtomicUsize::new(0), + } + } } diff --git a/mtop-client/src/dns/message.rs b/mtop-client/src/dns/message.rs index d369e00..7611036 100644 --- a/mtop-client/src/dns/message.rs +++ b/mtop-client/src/dns/message.rs @@ -240,18 +240,12 @@ impl Header { pub struct Flags(u16); impl Flags { - const MASK_QR: u16 = 0b1000_0000_0000_0000; - // query / response - const MASK_OP: u16 = 0b0111_1000_0000_0000; - // 4 bits, op code - const MASK_AA: u16 = 0b0000_0100_0000_0000; - // authoritative answer - const MASK_TC: u16 = 0b0000_0010_0000_0000; - // truncated - const MASK_RD: u16 = 0b0000_0001_0000_0000; - // recursion desired - const MASK_RA: u16 = 0b0000_0000_1000_0000; - // recursion available + const MASK_QR: u16 = 0b1000_0000_0000_0000; // query / response + const MASK_OP: u16 = 0b0111_1000_0000_0000; // 4 bits, op code + const MASK_AA: u16 = 0b0000_0100_0000_0000; // authoritative answer + const MASK_TC: u16 = 0b0000_0010_0000_0000; // truncated + const MASK_RD: u16 = 0b0000_0001_0000_0000; // recursion desired + const MASK_RA: u16 = 0b0000_0000_1000_0000; // recursion available const MASK_RC: u16 = 0b0000_0000_0000_1111; // 4 bits, response code const OFFSET_QR: usize = 15; diff --git a/mtop-client/src/dns/mod.rs b/mtop-client/src/dns/mod.rs index d8502bb..e1d4d3a 100644 --- a/mtop-client/src/dns/mod.rs +++ b/mtop-client/src/dns/mod.rs @@ -3,6 +3,7 @@ mod core; mod message; mod name; mod rdata; +mod resolv; pub use crate::dns::client::DnsClient; pub use crate::dns::core::{RecordClass, RecordType}; @@ -12,3 +13,4 @@ pub use crate::dns::rdata::{ RecordData, RecordDataA, RecordDataAAAA, RecordDataCNAME, RecordDataNS, RecordDataSOA, RecordDataSRV, RecordDataTXT, RecordDataUnknown, }; +pub use resolv::{config, ResolvConf, ResolvConfOptions}; diff --git a/mtop-client/src/dns/name.rs b/mtop-client/src/dns/name.rs index 4511bcd..2472317 100644 --- a/mtop-client/src/dns/name.rs +++ b/mtop-client/src/dns/name.rs @@ -8,6 +8,7 @@ use std::str::FromStr; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Name { labels: Vec, + is_fqdn: bool, } impl Name { @@ -16,7 +17,10 @@ impl Name { const MAX_POINTERS: u32 = 64; pub fn root() -> Self { - Name { labels: Vec::new() } + Name { + labels: Vec::new(), + is_fqdn: true, + } } pub fn size(&self) -> u16 { @@ -24,13 +28,39 @@ impl Name { } pub fn is_root(&self) -> bool { - self.labels.is_empty() + self.labels.is_empty() && self.is_fqdn + } + + pub fn is_fqdn(&self) -> bool { + self.is_fqdn + } + + pub fn to_fqdn(mut self) -> Self { + self.is_fqdn = true; + self + } + + pub fn append(mut self, other: Name) -> Self { + if self.is_fqdn { + return self; + } + + self.labels.extend(other.labels); + Self { + labels: self.labels, + is_fqdn: other.is_fqdn, + } } pub fn write_network_bytes(&self, mut buf: T) -> Result<(), MtopError> where T: WriteBytesExt, { + // We convert all incoming Names to fully qualified names. If we missed doing + // that, it's a bug and we should panic here. Encoded names all end with the + // root so trying to encode something that doesn't makes no sense. + assert!(self.is_fqdn, "only fully qualified domains can be encoded"); + for label in self.labels.iter() { buf.write_u8(label.len() as u8)?; buf.write_all(label.as_bytes())?; @@ -167,7 +197,8 @@ impl Name { impl Display for Name { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}.", self.labels.join(".")) + let suffix = if self.is_fqdn { "." } else { "" }; + write!(f, "{}{}", self.labels.join("."), suffix) } } @@ -187,13 +218,7 @@ impl FromStr for Name { ))); } - if !s.ends_with('.') { - return Err(MtopError::runtime(format!( - "Names must be fully qualified and end with a '.': {}", - s - ))); - } - + let is_fqdn = s.ends_with('.'); let mut labels = Vec::new(); for label in s.trim_end_matches('.').split('.') { let len = label.len(); @@ -227,7 +252,7 @@ impl FromStr for Name { labels.push(label.to_lowercase()); } - Ok(Name { labels }) + Ok(Name { labels, is_fqdn }) } } @@ -276,27 +301,120 @@ mod test { } #[test] - fn test_name_from_str_error_not_fqdn() { - let res = Name::from_str("localhost"); - assert!(res.is_err()); + fn test_name_from_str_success_not_fqdn() { + let name = Name::from_str("example.com").unwrap(); + assert!(!name.is_root()); + assert!(!name.is_fqdn()); } #[test] fn test_name_from_str_success_fqdn() { let name = Name::from_str("example.com.").unwrap(); assert!(!name.is_root()); + assert!(name.is_fqdn()); } #[test] fn test_name_from_str_success_root_empty() { let name = Name::from_str("").unwrap(); assert!(name.is_root()); + assert!(name.is_fqdn()); } #[test] fn test_name_from_str_success_root_dot() { let name = Name::from_str(".").unwrap(); assert!(name.is_root()); + assert!(name.is_fqdn()); + } + + #[test] + fn test_name_to_string_not_fqdn() { + let name = Name::from_str("example.com").unwrap(); + assert_eq!("example.com", name.to_string()); + assert!(!name.is_fqdn()); + } + + #[test] + fn test_name_to_string_fqdn() { + let name = Name::from_str("example.com.").unwrap(); + assert_eq!("example.com.", name.to_string()); + assert!(name.is_fqdn()); + } + + #[test] + fn test_name_to_string_root() { + let name = Name::root(); + assert_eq!(".", name.to_string()); + assert!(name.is_fqdn()); + } + + #[test] + fn test_to_fqdn_not_fqdn() { + let name = Name::from_str("example.com").unwrap(); + assert!(!name.is_fqdn()); + + let fqdn = name.to_fqdn(); + assert!(fqdn.is_fqdn()); + } + + #[test] + fn test_to_fqdn_already_fqdn() { + let name = Name::from_str("example.com.").unwrap(); + assert!(name.is_fqdn()); + + let fqdn = name.to_fqdn(); + assert!(fqdn.is_fqdn()); + } + + #[test] + fn test_name_append_already_fqdn() { + let name1 = Name::from_str("example.com.").unwrap(); + let name2 = Name::from_str("example.net.").unwrap(); + let combined = name1.clone().append(name2); + + assert_eq!(name1, combined); + assert!(combined.is_fqdn()); + } + + #[test] + fn test_name_append_with_non_fqdn() { + let name1 = Name::from_str("www").unwrap(); + let name2 = Name::from_str("example").unwrap(); + let combined = name1.clone().append(name2); + + assert_eq!(Name::from_str("www.example").unwrap(), combined); + assert!(!combined.is_fqdn()); + } + + #[test] + fn test_name_append_with_fqdn() { + let name1 = Name::from_str("www").unwrap(); + let name2 = Name::from_str("example.net.").unwrap(); + let combined = name1.clone().append(name2); + + assert_eq!(Name::from_str("www.example.net.").unwrap(), combined); + assert!(combined.is_fqdn()); + } + + #[test] + fn test_name_append_with_root() { + let name = Name::from_str("example.com").unwrap(); + let combined = name.clone().append(Name::root()); + + assert_eq!(Name::from_str("example.com.").unwrap(), combined); + assert!(combined.is_fqdn()); + } + + #[test] + fn test_name_append_multiple() { + let name1 = Name::from_str("dev").unwrap(); + let name2 = Name::from_str("www").unwrap(); + let name3 = Name::from_str("example.com").unwrap(); + + let combined = name1.append(name2).append(name3).append(Name::root()); + assert_eq!(Name::from_str("dev.www.example.com.").unwrap(), combined); + assert!(combined.is_fqdn()); } #[test] @@ -341,6 +459,14 @@ mod test { ); } + #[should_panic] + #[test] + fn test_name_write_network_bytes_not_fqdn() { + let mut cur = Cursor::new(Vec::new()); + let name = Name::from_str("example.com").unwrap(); + let _ = name.write_network_bytes(&mut cur); + } + #[rustfmt::skip] #[test] fn test_name_read_network_bytes_no_pointer() { @@ -354,6 +480,7 @@ mod test { let name = Name::read_network_bytes(cur).unwrap(); assert_eq!("example.com.", name.to_string()); + assert!(name.is_fqdn()); } #[rustfmt::skip] @@ -374,6 +501,7 @@ mod test { let name = Name::read_network_bytes(cur).unwrap(); assert_eq!("www.example.com.", name.to_string()); + assert!(name.is_fqdn()); } #[rustfmt::skip] @@ -397,6 +525,7 @@ mod test { let name = Name::read_network_bytes(cur).unwrap(); assert_eq!("dev.www.example.com.", name.to_string()); + assert!(name.is_fqdn()); } #[test] diff --git a/mtop-client/src/dns/resolv.rs b/mtop-client/src/dns/resolv.rs new file mode 100644 index 0000000..5fe5220 --- /dev/null +++ b/mtop-client/src/dns/resolv.rs @@ -0,0 +1,345 @@ +use crate::core::MtopError; +use std::fmt::Debug; +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; +use std::time::Duration; +use tokio::io::{AsyncBufReadExt, AsyncRead, BufReader}; + +const DEFAULT_PORT: u16 = 53; +const MAX_NAMESERVERS: usize = 3; + +/// Configuration for a DNS client based on a parsed resolv.conf file. +/// +/// Note that only the `nameserver` setting and a few `option`s are supported. +#[derive(Debug, Default, Clone, Eq, PartialEq)] +pub struct ResolvConf { + pub nameservers: Vec, + pub options: ResolvConfOptions, +} + +/// Options to change the behavior of a DNS client based on a resolv.conf file. +/// +/// Note that only a subset of options are supported. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct ResolvConfOptions { + pub timeout: Duration, + pub attempts: u8, + pub rotate: bool, +} + +impl Default for ResolvConfOptions { + fn default() -> Self { + // Defaults picked based on `man 5 resolv.conf` + Self { + timeout: Duration::from_secs(5), + attempts: 2, + rotate: false, + } + } +} + +/// Read settings for a DNS client from a resolv.conf configuration file. +pub async fn config(read: R) -> Result +where + R: AsyncRead + Send + Sync + Unpin + 'static, +{ + let mut lines = BufReader::new(read).lines(); + let mut conf = ResolvConf::default(); + + while let Some(line) = lines.next_line().await? { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + continue; + } + + let mut parts = line.split_whitespace(); + let key = match parts.next() { + Some(k) => k, + None => { + tracing::debug!(message = "skipping malformed resolv.conf line", line = line); + continue; + } + }; + + match Token::get(key) { + Some(Token::NameServer) => { + if conf.nameservers.len() < MAX_NAMESERVERS { + conf.nameservers.push(parse_nameserver(line, parts)?); + } + } + Some(Token::Options) => { + for opt in parse_options(parts) { + match opt { + OptionsToken::Timeout(t) => { + conf.options.timeout = Duration::from_secs(u64::from(t)); + } + OptionsToken::Attempts(n) => { + conf.options.attempts = n; + } + OptionsToken::Rotate => { + conf.options.rotate = true; + } + } + } + } + None => { + tracing::debug!( + message = "skipping unknown resolv.conf setting", + setting = key, + line = line + ); + continue; + } + } + } + + Ok(conf) +} + +/// Parse a single nameserver IP address, adding a default port of 53, from a `nameserver` +/// line in a resolv.conf file, returning an error if the address is malformed. +fn parse_nameserver<'a>(line: &str, mut parts: impl Iterator) -> Result { + if let Some(part) = parts.next() { + part.parse::() + .map(|ip| (ip, DEFAULT_PORT).into()) + .map_err(|e| MtopError::configuration_cause(format!("malformed nameserver address '{}'", part), e)) + } else { + Err(MtopError::configuration(format!( + "malformed nameserver configuration '{}'", + line + ))) + } +} + +/// Parse one or more options from an `option` line in a resolv.conf file, ignoring any +/// malformed or unsupported options. +fn parse_options<'a>(parts: impl Iterator) -> Vec { + let mut out = Vec::new(); + + for part in parts { + let opt = match part.parse() { + Ok(o) => o, + Err(e) => { + tracing::debug!(message = "skipping unknown resolv.conf option", option = part, err = %e); + continue; + } + }; + + out.push(opt); + } + out +} + +/// Top-level configuration setting in a resolv.conf file. +/// +/// Note that only a subset of all possible settings are supported. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +enum Token { + NameServer, + Options, +} + +impl Token { + fn get(s: &str) -> Option { + match s { + "nameserver" => Some(Self::NameServer), + "options" => Some(Self::Options), + _ => None, + } + } +} + +/// Keyword or key-value pair associated with an option token. +/// +/// Note that only a subset of all possible options are supported. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +enum OptionsToken { + Timeout(u8), + Attempts(u8), + Rotate, +} + +impl OptionsToken { + const MAX_TIMEOUT: u8 = 30; + const MAX_ATTEMPTS: u8 = 5; + + fn parse(line: &str, val: &str, max: u8) -> Result { + let n = val + .parse() + .map_err(|e| MtopError::configuration_cause(format!("unable to parse {} value '{}'", line, val), e))?; + if n > max { + Ok(max) + } else { + Ok(n) + } + } +} + +impl FromStr for OptionsToken { + type Err = MtopError; + + fn from_str(s: &str) -> Result { + if s == "rotate" { + Ok(Self::Rotate) + } else { + match s.split_once(':') { + Some(("timeout", v)) => Ok(Self::Timeout(Self::parse(s, v, Self::MAX_TIMEOUT)?)), + Some(("attempts", v)) => Ok(Self::Attempts(Self::parse(s, v, Self::MAX_ATTEMPTS)?)), + _ => Err(MtopError::configuration(format!("unknown option {}", s))), + } + } + } +} + +#[cfg(test)] +mod test { + use super::{config, OptionsToken, Token}; + use crate::core::ErrorKind; + use crate::dns::{ResolvConf, ResolvConfOptions}; + use std::io::{Cursor, Error as IOError, ErrorKind as IOErrorKind}; + use std::pin::Pin; + use std::str::FromStr; + use std::task::{Context, Poll}; + use std::time::Duration; + use tokio::io::{AsyncRead, ReadBuf}; + + #[test] + fn test_configuration() { + assert_eq!(Some(Token::NameServer), Token::get("nameserver")); + assert_eq!(Some(Token::Options), Token::get("options")); + assert_eq!(None, Token::get("invalid")); + } + + #[test] + fn test_configuration_option_success() { + assert_eq!(OptionsToken::Rotate, OptionsToken::from_str("rotate").unwrap()); + assert_eq!(OptionsToken::Timeout(3), OptionsToken::from_str("timeout:3").unwrap()); + assert_eq!(OptionsToken::Attempts(4), OptionsToken::from_str("attempts:4").unwrap()); + } + + #[test] + fn test_configuration_option_limits() { + assert_eq!(OptionsToken::Timeout(30), OptionsToken::from_str("timeout:35").unwrap()); + assert_eq!( + OptionsToken::Attempts(5), + OptionsToken::from_str("attempts:10").unwrap() + ); + } + + #[test] + fn test_configuration_option_error() { + assert!(OptionsToken::from_str("ndots:bad").is_err()); + assert!(OptionsToken::from_str("timeout:bad").is_err()); + assert!(OptionsToken::from_str("attempts:-5").is_err()); + } + + #[tokio::test] + async fn test_config_read_error() { + struct ErrAsyncRead; + impl AsyncRead for ErrAsyncRead { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Err(IOError::new(IOErrorKind::UnexpectedEof, "test error"))) + } + } + + let reader = ErrAsyncRead; + let res = config(reader).await.unwrap_err(); + assert_eq!(ErrorKind::IO, res.kind()); + } + + #[tokio::test] + async fn test_config_no_content() { + let reader = Cursor::new(Vec::new()); + let res = config(reader).await.unwrap(); + assert_eq!(ResolvConf::default(), res); + } + + #[tokio::test] + async fn test_config_all_comments() { + #[rustfmt::skip] + let reader = Cursor::new(concat!( + "# this is a comment\n", + "# another comment\n", + )); + let res = config(reader).await.unwrap(); + assert_eq!(ResolvConf::default(), res); + } + + #[tokio::test] + async fn test_config_all_unsupported() { + #[rustfmt::skip] + let reader = Cursor::new(concat!( + "scrambler 127.0.0.5\n", + "invalid directive\n", + )); + let res = config(reader).await.unwrap(); + assert_eq!(ResolvConf::default(), res); + } + + #[tokio::test] + async fn test_config_nameservers_search_invalid_options() { + #[rustfmt::skip] + let reader = Cursor::new(concat!( + "# this is a comment\n", + "nameserver 127.0.0.53\n", + "options casual-fridays:true\n", + )); + + let expected = ResolvConf { + nameservers: vec!["127.0.0.53:53".parse().unwrap()], + options: Default::default(), + }; + + let res = config(reader).await.unwrap(); + assert_eq!(expected, res); + } + + #[tokio::test] + async fn test_config_nameservers_search_no_options() { + #[rustfmt::skip] + let reader = Cursor::new(concat!( + "# this is a comment\n", + "nameserver 127.0.0.53\n", + )); + + let expected = ResolvConf { + nameservers: vec!["127.0.0.53:53".parse().unwrap()], + options: Default::default(), + }; + + let res = config(reader).await.unwrap(); + assert_eq!(expected, res); + } + + #[tokio::test] + async fn test_config_nameservers_search_options() { + #[rustfmt::skip] + let reader = Cursor::new(concat!( + "# this is a comment\n", + "nameserver 127.0.0.53\n", + "nameserver 127.0.0.54\n", + "nameserver 127.0.0.55\n", + "options ndots:3 attempts:5 timeout:10 rotate use-vc edns0\n", + )); + + let expected = ResolvConf { + nameservers: vec![ + "127.0.0.53:53".parse().unwrap(), + "127.0.0.54:53".parse().unwrap(), + "127.0.0.55:53".parse().unwrap(), + ], + options: ResolvConfOptions { + timeout: Duration::from_secs(10), + attempts: 5, + rotate: true, + }, + }; + + let res = config(reader).await.unwrap(); + assert_eq!(expected, res); + } +} diff --git a/mtop-client/src/pool.rs b/mtop-client/src/pool.rs index 3b2b024..553865d 100644 --- a/mtop-client/src/pool.rs +++ b/mtop-client/src/pool.rs @@ -179,12 +179,7 @@ impl MemcachedPool { async fn connect(&self, server: &Server) -> Result { if let Some(cfg) = &self.client_config { - let name = self - .config - .tls - .server_name - .clone() - .unwrap_or_else(|| server.server_name()); + let name = self.config.tls.server_name.clone().unwrap_or_else(|| server.server_name()); tracing::debug!(message = "using server name for TLS validation", server_name = ?name); tls_connect(server.address(), name, cfg.clone()).await } else { @@ -272,9 +267,11 @@ where #[cfg(test)] mod test { - use super::{MemcachedPool, PoolConfig, PooledMemcached, Server}; + use super::{MemcachedPool, PoolConfig, PooledMemcached}; use crate::core::{ErrorKind, Memcached, MtopError}; + use crate::discovery::{Server, ServerID}; use std::io::{self, Cursor}; + use std::net::SocketAddr; use tokio::runtime::Handle; use webpki::types::ServerName; @@ -292,10 +289,8 @@ mod test { async fn test_get_new_connection() { let cfg = PoolConfig::default(); let pool = MemcachedPool::new(Handle::current(), cfg).await.unwrap(); - let server = Server::from_addr( - "127.0.0.1:11211".parse().unwrap(), - ServerName::try_from("localhost").unwrap().to_owned(), - ); + let id = ServerID::from("127.0.0.1:11211".parse::().unwrap()); + let server = Server::new(id, ServerName::try_from("localhost").unwrap().to_owned()); let connect = async { Ok(client!( @@ -313,10 +308,9 @@ mod test { async fn test_get_existing_connection() { let cfg = PoolConfig::default(); let pool = MemcachedPool::new(Handle::current(), cfg).await.unwrap(); - let server = Server::from_addr( - "127.0.0.1:11211".parse().unwrap(), - ServerName::try_from("localhost").unwrap().to_owned(), - ); + + let id = ServerID::from("127.0.0.1:11211".parse::().unwrap()); + let server = Server::new(id, ServerName::try_from("localhost").unwrap().to_owned()); pool.put(PooledMemcached { host: server.clone(), @@ -339,10 +333,8 @@ mod test { }; let pool = MemcachedPool::new(Handle::current(), cfg).await.unwrap(); - let server = Server::from_addr( - "127.0.0.1:11211".parse().unwrap(), - ServerName::try_from("localhost").unwrap().to_owned(), - ); + let id = ServerID::from("127.0.0.1:11211".parse::().unwrap()); + let server = Server::new(id, ServerName::try_from("localhost").unwrap().to_owned()); pool.put(PooledMemcached { host: server.clone(), @@ -363,10 +355,8 @@ mod test { let cfg = PoolConfig::default(); let pool = MemcachedPool::new(Handle::current(), cfg).await.unwrap(); - let server = Server::from_addr( - "127.0.0.1:11211".parse().unwrap(), - ServerName::try_from("localhost").unwrap().to_owned(), - ); + let id = ServerID::from("127.0.0.1:11211".parse::().unwrap()); + let server = Server::new(id, ServerName::try_from("localhost").unwrap().to_owned()); let connect = async { Err(MtopError::from(io::Error::new(io::ErrorKind::TimedOut, "timeout"))) }; let res = pool.get_with_connect(&server, connect).await; diff --git a/mtop/src/bin/dns.rs b/mtop/src/bin/dns.rs index 74b3331..6ab617e 100644 --- a/mtop/src/bin/dns.rs +++ b/mtop/src/bin/dns.rs @@ -1,17 +1,16 @@ -use clap::{Args, Parser, Subcommand}; -use mtop_client::dns::{DnsClient, Flags, Message, MessageId, Name, Question, Record, RecordClass, RecordType}; -use mtop_client::Timeout; +use clap::{Args, Parser, Subcommand, ValueHint}; +use mtop_client::dns::{Flags, Message, MessageId, Name, Question, Record, RecordClass, RecordType}; use std::fmt::Write; use std::io::Cursor; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::path::PathBuf; use std::process::ExitCode; use std::str::FromStr; -use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tracing::Level; const DEFAULT_DNS_LOCAL: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); -const DEFAULT_DNS_SERVER: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 53); +const DEFAULT_LOG_LEVEL: Level = Level::INFO; const DEFAULT_TIMEOUT_SECS: u64 = 5; const DEFAULT_RECORD_TYPE: RecordType = RecordType::A; const DEFAULT_RECORD_CLASS: RecordClass = RecordClass::INET; @@ -20,6 +19,11 @@ const DEFAULT_RECORD_CLASS: RecordClass = RecordClass::INET; #[derive(Debug, Parser)] #[command(name = "dns", version = clap::crate_version!())] struct DnsConfig { + /// Logging verbosity. Allowed values are 'trace', 'debug', 'info', 'warn', and 'error' + /// (case-insensitive). + #[arg(long, default_value_t = DEFAULT_LOG_LEVEL)] + log_level: Level, + #[command(subcommand)] mode: Action, } @@ -38,9 +42,10 @@ struct QueryCommand { #[arg(long, default_value_t = DEFAULT_DNS_LOCAL)] dns_local: SocketAddr, - /// DNS server for service discovery in the form 'address:port' - #[arg(long, default_value_t = DEFAULT_DNS_SERVER)] - dns_server: SocketAddr, + /// Path to resolv.conf file for loading DNS configuration information. If this file + /// can't be loaded, default values for DNS configuration are used instead. + #[arg(long, default_value = default_resolv_conf().into_os_string(), value_hint = ValueHint::FilePath)] + resolv_conf: PathBuf, /// Timeout for making requests to a DNS server, in seconds. #[arg(long, default_value_t = DEFAULT_TIMEOUT_SECS)] @@ -65,6 +70,10 @@ struct QueryCommand { name: String, } +fn default_resolv_conf() -> PathBuf { + PathBuf::from("/etc/resolv.conf") +} + /// Read a binary format DNS message from standard input and display it as dig-like text output. #[derive(Debug, Args)] struct ReadCommand {} @@ -89,7 +98,8 @@ struct WriteCommand { async fn main() -> ExitCode { let opts = DnsConfig::parse(); - let console_subscriber = mtop::tracing::console_subscriber(Level::DEBUG).expect("failed to setup console logging"); + let console_subscriber = + mtop::tracing::console_subscriber(opts.log_level).expect("failed to setup console logging"); tracing::subscriber::set_global_default(console_subscriber).expect("failed to initialize console logging"); match &opts.mode { @@ -100,7 +110,7 @@ async fn main() -> ExitCode { } async fn run_query(cmd: &QueryCommand) -> ExitCode { - let timeout = Duration::from_secs(cmd.timeout_secs); + let client = mtop::dns::new_client(cmd.dns_local, &cmd.resolv_conf).await; let name = match Name::from_str(&cmd.name) { Ok(n) => n, Err(e) => { @@ -109,15 +119,10 @@ async fn run_query(cmd: &QueryCommand) -> ExitCode { } }; - let id = MessageId::random(); - let msg = Message::new(id, Flags::default().set_query().set_recursion_desired()) - .add_question(Question::new(name, cmd.rtype).set_qclass(cmd.rclass)); - - let client = DnsClient::new(cmd.dns_local, cmd.dns_server); - let response = match client.exchange(&msg).timeout(timeout, "client.exchange").await { + let response = match client.resolve(name, cmd.rtype, cmd.rclass).await { Ok(r) => r, Err(e) => { - tracing::error!(message = "unable to exchange message", "server" = %cmd.dns_server, err = %e); + tracing::error!(message = "unable to perform DNS query", err = %e); return ExitCode::FAILURE; } }; diff --git a/mtop/src/bin/mc.rs b/mtop/src/bin/mc.rs index ed7ae90..e3f0f52 100644 --- a/mtop/src/bin/mc.rs +++ b/mtop/src/bin/mc.rs @@ -2,7 +2,6 @@ use clap::{Args, Parser, Subcommand, ValueHint}; use mtop::bench::{Bencher, Percent, Summary}; use mtop::check::{Checker, TimingBundle}; use mtop::profile::Profiler; -use mtop_client::dns::DnsClient; use mtop_client::{ DiscoveryDefault, MemcachedClient, MemcachedPool, Meta, MtopError, PoolConfig, SelectorRendezvous, Server, TLSConfig, Timeout, Value, @@ -20,7 +19,6 @@ use tracing::{Instrument, Level}; use webpki::types::{InvalidDnsNameError, ServerName}; const DEFAULT_DNS_LOCAL: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); -const DEFAULT_DNS_SERVER: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 53); const DEFAULT_LOG_LEVEL: Level = Level::INFO; const DEFAULT_HOST: &str = "localhost:11211"; const DEFAULT_TIMEOUT_SECS: u64 = 30; @@ -39,9 +37,10 @@ struct McConfig { #[arg(long, default_value_t = DEFAULT_DNS_LOCAL)] dns_local: SocketAddr, - /// DNS server for service discovery in the form 'address:port' - #[arg(long, default_value_t = DEFAULT_DNS_SERVER)] - dns_server: SocketAddr, + /// Path to resolv.conf file for loading DNS configuration information. If this file + /// can't be loaded, default values for DNS configuration are used instead. + #[arg(long, default_value = default_resolv_conf().into_os_string(), value_hint = ValueHint::FilePath)] + resolv_conf: PathBuf, /// Memcached host to connect to in the form 'hostname:port'. #[arg(long, default_value_t = DEFAULT_HOST.to_owned(), value_hint = ValueHint::Hostname)] @@ -88,6 +87,10 @@ struct McConfig { mode: Action, } +fn default_resolv_conf() -> PathBuf { + PathBuf::from("/etc/resolv.conf") +} + fn parse_server_name(s: &str) -> Result, InvalidDnsNameError> { ServerName::try_from(s).map(|n| n.to_owned()) } @@ -296,8 +299,9 @@ async fn main() -> ExitCode { mtop::tracing::console_subscriber(opts.log_level).expect("failed to setup console logging"); tracing::subscriber::set_global_default(console_subscriber).expect("failed to initialize console logging"); + let dns_client = mtop::dns::new_client(opts.dns_local, &opts.resolv_conf).await; + let resolver = DiscoveryDefault::new(dns_client); let timeout = Duration::from_secs(opts.timeout_secs); - let resolver = DiscoveryDefault::new(DnsClient::new(opts.dns_local, opts.dns_server)); let servers = match resolver .resolve_by_proto(&opts.host) .timeout(timeout, "resolver.resolve_by_proto") diff --git a/mtop/src/bin/mtop.rs b/mtop/src/bin/mtop.rs index 2339f9a..c2f62bb 100644 --- a/mtop/src/bin/mtop.rs +++ b/mtop/src/bin/mtop.rs @@ -1,7 +1,6 @@ use clap::{Parser, ValueHint}; use mtop::queue::{BlockingStatsQueue, Host, StatsQueue}; use mtop::ui::{Theme, TAILWIND}; -use mtop_client::dns::DnsClient; use mtop_client::{ DiscoveryDefault, MemcachedClient, MemcachedPool, MtopError, PoolConfig, SelectorRendezvous, Server, TLSConfig, Timeout, @@ -19,7 +18,6 @@ use tracing::{Instrument, Level}; use webpki::types::{InvalidDnsNameError, ServerName}; const DEFAULT_DNS_LOCAL: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); -const DEFAULT_DNS_SERVER: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 53); const DEFAULT_LOG_LEVEL: Level = Level::INFO; const DEFAULT_THEME: Theme = TAILWIND; // Update interval of more than a second to minimize the chance that stats returned by the @@ -41,9 +39,10 @@ struct MtopConfig { #[arg(long, default_value_t = DEFAULT_DNS_LOCAL)] dns_local: SocketAddr, - /// DNS server for service discovery in the form 'address:port' - #[arg(long, default_value_t = DEFAULT_DNS_SERVER)] - dns_server: SocketAddr, + /// Path to resolv.conf file for loading DNS configuration information. If this file + /// can't be loaded, default values for DNS configuration are used instead. + #[arg(long, default_value = default_resolv_conf().into_os_string(), value_hint = ValueHint::FilePath)] + resolv_conf: PathBuf, /// Timeout for connecting to Memcached and fetching statistics, in seconds. #[arg(long, default_value_t = DEFAULT_TIMEOUT_SECS)] @@ -51,7 +50,7 @@ struct MtopConfig { /// File to log errors to since they cannot be logged to the console. If the path is not /// writable, mtop will not start. - #[arg(long, default_value=default_log_file().into_os_string(), value_hint = ValueHint::FilePath)] + #[arg(long, default_value = default_log_file().into_os_string(), value_hint = ValueHint::FilePath)] log_file: PathBuf, /// Color scheme to use for the UI. Available options are "ansi", "material", and "tailwind". @@ -94,6 +93,10 @@ struct MtopConfig { hosts: Vec, } +fn default_resolv_conf() -> PathBuf { + PathBuf::from("/etc/resolv.conf") +} + fn parse_server_name(s: &str) -> Result, InvalidDnsNameError> { ServerName::try_from(s).map(|n| n.to_owned()) } @@ -122,7 +125,8 @@ async fn main() -> ExitCode { let timeout = Duration::from_secs(opts.timeout_secs); let measurements = Arc::new(StatsQueue::new(NUM_MEASUREMENTS)); - let resolver = DiscoveryDefault::new(DnsClient::new(opts.dns_local, opts.dns_server)); + let dns_client = mtop::dns::new_client(opts.dns_local, &opts.resolv_conf).await; + let resolver = DiscoveryDefault::new(dns_client); let servers = match expand_hosts(&opts.hosts, &resolver, timeout).await { Ok(v) => v, diff --git a/mtop/src/check.rs b/mtop/src/check.rs index 26d39e3..43288d5 100644 --- a/mtop/src/check.rs +++ b/mtop/src/check.rs @@ -82,12 +82,7 @@ impl<'a> Checker<'a> { let dns_time = dns_start.elapsed(); let conn_start = Instant::now(); - let mut conn = match self - .client - .raw_open(&server) - .timeout(self.timeout, "client.raw_open") - .await - { + let mut conn = match self.client.raw_open(&server).timeout(self.timeout, "client.raw_open").await { Ok(v) => v, Err(e) => { tracing::warn!(message = "failed to connect to host", host = host, addr = %server.address(), err = %e); @@ -99,11 +94,7 @@ impl<'a> Checker<'a> { let conn_time = conn_start.elapsed(); let set_start = Instant::now(); - match conn - .set(&key, 0, 60, &val) - .timeout(self.timeout, "connection.set") - .await - { + match conn.set(&key, 0, 60, &val).timeout(self.timeout, "connection.set").await { Ok(_) => {} Err(e) => { tracing::warn!(message = "failed to set key", host = host, addr = %server.address(), err = %e); diff --git a/mtop/src/dns.rs b/mtop/src/dns.rs new file mode 100644 index 0000000..f903b95 --- /dev/null +++ b/mtop/src/dns.rs @@ -0,0 +1,43 @@ +use mtop_client::dns::{DnsClient, ResolvConf}; +use mtop_client::MtopError; +use std::fmt; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::path::Path; +use tokio::fs::File; + +const DEFAULT_SERVER: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 53); + +/// Load configuration from the provided resolv.conf file and crated a new DnsClient +/// based on it. If the resolv.conf file cannot be opened or is malformed, default +/// configuration values will be used. See `man 5 resolv.conf` for more information. +pub async fn new_client

(local: SocketAddr, resolv: P) -> DnsClient +where + P: AsRef + fmt::Debug, +{ + let mut cfg = match load_config(&resolv).await { + Ok(cfg) => cfg, + Err(e) => { + tracing::warn!(message = "unable to load resolv.conf", path = ?resolv, err = %e); + ResolvConf::default() + } + }; + + // Either the resolv.conf file doesn't list any nameservers or we had to + // use Default::default() which also doesn't include any. Use localhost in + // the hopes that it will work. + if cfg.nameservers.is_empty() { + cfg.nameservers.push(DEFAULT_SERVER); + } + + DnsClient::new(local, cfg) +} + +async fn load_config

(resolv: P) -> Result +where + P: AsRef + fmt::Debug, +{ + let handle = File::open(&resolv) + .await + .map_err(|e| MtopError::configuration_cause(format!("unable to open {:?}", resolv), e))?; + mtop_client::dns::config(handle).await +} diff --git a/mtop/src/lib.rs b/mtop/src/lib.rs index c779aea..784bba6 100644 --- a/mtop/src/lib.rs +++ b/mtop/src/lib.rs @@ -2,6 +2,7 @@ pub mod bench; pub mod check; +pub mod dns; pub mod profile; pub mod queue; pub mod tracing; diff --git a/rustfmt.toml b/rustfmt.toml index 7530651..aa6c8fa 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1 +1,2 @@ max_width = 120 +chain_width = 80