From bad1a81daa2c6f09298429a8af994153447e467c Mon Sep 17 00:00:00 2001 From: Nick Pillitteri Date: Wed, 17 Jul 2024 21:34:02 -0400 Subject: [PATCH] Don't attempt to resolve bare hostnames When connecting to Memcached servers, only attempt to resolve the hostnames to IP addresses when prefixed with `dns+`. If just a hostname is supplied, attempt to use it verbatim. This results in a better user experience since we get the benefits of using the system resolver and the expected hostname appears in the UI (instead of its IP). Fixes #151 --- mtop-client/src/discovery.rs | 54 ++++++++++++++++++++++++++++-------- mtop/src/bin/mc.rs | 6 ++-- mtop/src/bin/mtop.rs | 6 ++-- mtop/src/check.rs | 2 +- 4 files changed, 49 insertions(+), 19 deletions(-) diff --git a/mtop-client/src/discovery.rs b/mtop-client/src/discovery.rs index 615128e..0f9c7b6 100644 --- a/mtop-client/src/discovery.rs +++ b/mtop-client/src/discovery.rs @@ -135,17 +135,25 @@ impl DiscoveryDefault { /// the targets will happen at connection time using the system resolver. /// * No prefix with an IPv4 or IPv6 address will use the address as a Memcached /// server. - /// * No prefix with a non-IP address will resolve the hostname into A or AAAA - /// records and pick a single one as a Memcached server. - pub async fn resolve_by_proto(&self, name: &str) -> Result, MtopError> { + /// * No prefix with a non-IP address will use the host as a Memcached server. + /// Resolution of the host will happen at connection time using the system + /// resolver. + pub async fn resolve(&self, name: &str) -> Result, MtopError> { + Self::resolve_by_proto(&self.client, name).await + } + + async fn resolve_by_proto(client: C, name: &str) -> Result, MtopError> + where + C: AsyncDnsClient, + { if name.starts_with(DNS_A_PREFIX) { - Ok(Self::resolve_a_aaaa(&self.client, name.trim_start_matches(DNS_A_PREFIX)).await?) + Ok(Self::resolve_a_aaaa(client, name.trim_start_matches(DNS_A_PREFIX)).await?) } else if name.starts_with(DNS_SRV_PREFIX) { - Ok(Self::resolve_srv(&self.client, name.trim_start_matches(DNS_SRV_PREFIX)).await?) + Ok(Self::resolve_srv(client, name.trim_start_matches(DNS_SRV_PREFIX)).await?) } else if let Ok(addr) = name.parse::() { Ok(Self::resolv_socket_addr(name, addr)?) } else { - Ok(Self::resolve_a_aaaa(&self.client, name).await?.pop().into_iter().collect()) + Ok(Self::resolv_bare_host(name)?) } } @@ -184,6 +192,12 @@ impl DiscoveryDefault { Ok(vec![Server::new(ServerID::from(addr), server_name)]) } + fn resolv_bare_host(name: &str) -> Result, MtopError> { + let (host, port) = Self::host_and_port(name)?; + let server_name = Self::server_name(host)?; + Ok(vec![Server::new(ServerID::from((host, port)), server_name)]) + } + fn servers_from_answers(port: u16, server_name: &ServerName, answers: &[Record]) -> Vec { let mut out = Vec::new(); @@ -370,7 +384,9 @@ mod test { ); let client = MockDnsClient::new(vec![response_a, response_aaaa]); - let servers = DiscoveryDefault::resolve_a_aaaa(&client, "example.com:11211").await.unwrap(); + let servers = DiscoveryDefault::resolve_by_proto(&client, "dns+example.com:11211") + .await + .unwrap(); let ids = servers.iter().map(|s| s.id()).collect::>(); let id_a = ServerID::from(("10.1.1.1", 11211)); @@ -413,7 +429,7 @@ mod test { ); let client = MockDnsClient::new(vec![response]); - let servers = DiscoveryDefault::resolve_srv(&client, "_cache.example.com:11211") + let servers = DiscoveryDefault::resolve_by_proto(&client, "dnssrv+_cache.example.com:11211") .await .unwrap(); let ids = servers.iter().map(|s| s.id()).collect::>(); @@ -425,14 +441,28 @@ mod test { assert!(ids.contains(&id2), "expected {:?} to contain {:?}", ids, id2); } - #[test] - fn test_dns_client_resolve_socket_addr() { + #[tokio::test] + async fn test_dns_client_resolve_socket_addr() { let name = "127.0.0.2:11211"; - let addr = "127.0.0.2:11211".parse().unwrap(); - let servers = DiscoveryDefault::resolv_socket_addr(name, addr).unwrap(); + let addr: SocketAddr = "127.0.0.2:11211".parse().unwrap(); + + let client = MockDnsClient::new(vec![]); + let servers = DiscoveryDefault::resolve_by_proto(&client, name).await.unwrap(); let ids = servers.iter().map(|s| s.id()).collect::>(); let id = ServerID::from(addr); assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id); } + + #[tokio::test] + async fn test_dns_client_resolve_bare_host() { + let name = "localhost:11211"; + + let client = MockDnsClient::new(vec![]); + let servers = DiscoveryDefault::resolve_by_proto(&client, name).await.unwrap(); + let ids = servers.iter().map(|s| s.id()).collect::>(); + + let id = ServerID::from(("localhost", 11211)); + assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id); + } } diff --git a/mtop/src/bin/mc.rs b/mtop/src/bin/mc.rs index 5a2cd0f..2f7d831 100644 --- a/mtop/src/bin/mc.rs +++ b/mtop/src/bin/mc.rs @@ -297,9 +297,9 @@ async fn main() -> ExitCode { let resolver = DiscoveryDefault::new(dns_client); let timeout = Duration::from_secs(opts.timeout_secs); let servers = match resolver - .resolve_by_proto(&opts.host) - .timeout(timeout, "resolver.resolve_by_proto") - .instrument(tracing::span!(Level::INFO, "resolver.resolve_by_proto")) + .resolve(&opts.host) + .timeout(timeout, "resolver.resolve") + .instrument(tracing::span!(Level::INFO, "resolver.resolve")) .await { Ok(v) => v, diff --git a/mtop/src/bin/mtop.rs b/mtop/src/bin/mtop.rs index 42cb5e2..dedd4b4 100644 --- a/mtop/src/bin/mtop.rs +++ b/mtop/src/bin/mtop.rs @@ -211,9 +211,9 @@ async fn expand_hosts( for host in hosts { out.extend( resolver - .resolve_by_proto(host) - .timeout(timeout, "resolver.resolve_by_proto") - .instrument(tracing::span!(Level::INFO, "resolver.resolve_by_proto")) + .resolve(host) + .timeout(timeout, "resolver.resolve") + .instrument(tracing::span!(Level::INFO, "resolver.resolve")) .await?, ); } diff --git a/mtop/src/check.rs b/mtop/src/check.rs index 85af323..dc566b3 100644 --- a/mtop/src/check.rs +++ b/mtop/src/check.rs @@ -64,7 +64,7 @@ impl Checker { let dns_start = Instant::now(); let server = match self .resolver - .resolve_by_proto(host) + .resolve(host) .timeout(self.timeout, "resolver.resolve_by_proto") .await .map(|mut v| v.pop())