From 9e1bb522c2d5f700352dc50460667e5107bb5f6b Mon Sep 17 00:00:00 2001 From: Nick Pillitteri Date: Mon, 19 Aug 2024 08:44:39 -0400 Subject: [PATCH] Deduplicate servers by ID when doing service discovery resolution Avoid duplicate servers when a SRV lookup returns the same host multiple times with different ports. We deduplicate based on server ID. We keep the same behavior of ignoring the port number from SRV records. Fixes #184 --- mtop-client/src/discovery.rs | 68 ++++++++++++++++++++++++++++++------ 1 file changed, 58 insertions(+), 10 deletions(-) diff --git a/mtop-client/src/discovery.rs b/mtop-client/src/discovery.rs index 75ca1ed..ea364e8 100644 --- a/mtop-client/src/discovery.rs +++ b/mtop-client/src/discovery.rs @@ -1,7 +1,8 @@ use crate::core::MtopError; -use crate::dns::{DefaultDnsClient, DnsClient, Name, Record, RecordClass, RecordData, RecordType}; +use crate::dns::{DefaultDnsClient, DnsClient, Message, Name, RecordClass, RecordData, RecordType}; use rustls_pki_types::ServerName; use std::cmp::Ordering; +use std::collections::HashSet; use std::fmt; use std::net::{IpAddr, SocketAddr}; @@ -152,7 +153,7 @@ where let name = host.parse()?; let res = self.client.resolve(name, RecordType::SRV, RecordClass::INET).await?; - Ok(Self::servers_from_answers(port, &server_name, res.answers())) + Ok(Self::servers_from_answers(port, &server_name, &res)) } async fn resolve_a_aaaa(&self, name: &str) -> Result, MtopError> { @@ -161,10 +162,10 @@ where let name: Name = host.parse()?; let res = self.client.resolve(name.clone(), RecordType::A, RecordClass::INET).await?; - let mut out = Self::servers_from_answers(port, &server_name, res.answers()); + let mut out = Self::servers_from_answers(port, &server_name, &res); let res = self.client.resolve(name, RecordType::AAAA, RecordClass::INET).await?; - out.extend(Self::servers_from_answers(port, &server_name, res.answers())); + out.extend(Self::servers_from_answers(port, &server_name, &res)); Ok(out) } @@ -181,10 +182,10 @@ where Ok(vec![Server::new(ServerID::from((host, port)), server_name)]) } - fn servers_from_answers(port: u16, server_name: &ServerName, answers: &[Record]) -> Vec { - let mut out = Vec::new(); + fn servers_from_answers(port: u16, server_name: &ServerName, message: &Message) -> Vec { + let mut ids = HashSet::new(); - for answer in answers { + for answer in message.answers() { let server_id = match answer.rdata() { RecordData::A(data) => ServerID::from(SocketAddr::new(IpAddr::V4(data.addr()), port)), RecordData::AAAA(data) => ServerID::from(SocketAddr::new(IpAddr::V6(data.addr()), port)), @@ -195,11 +196,14 @@ where } }; - let server = Server::new(server_id, server_name.to_owned()); - out.push(server); + // Insert IDs into a HashSet to deduplicate them. We can potentially end up with duplicates + // when a SRV query returns multiple answers per hostname (such as when each host has more + // than a single port). Because we ignore the port number from the SRV answer we need to + // deduplicate here. + ids.insert(server_id); } - out + ids.into_iter().map(|id| Server::new(id, server_name.to_owned())).collect() } fn host_and_port(name: &str) -> Result<(&str, u16), MtopError> { @@ -422,6 +426,48 @@ mod test { assert!(ids.contains(&id2), "expected {:?} to contain {:?}", ids, id2); } + #[tokio::test] + async fn test_dns_client_resolve_srv_dupes() { + let response = response_with_answers( + RecordType::SRV, + vec![ + Record::new( + Name::from_str("_cache.example.com.").unwrap(), + RecordType::SRV, + RecordClass::INET, + 300, + RecordData::SRV(RecordDataSRV::new( + 100, + 10, + 11211, + Name::from_str("cache01.example.com.").unwrap(), + )), + ), + Record::new( + Name::from_str("_cache.example.com.").unwrap(), + RecordType::SRV, + RecordClass::INET, + 300, + RecordData::SRV(RecordDataSRV::new( + 100, + 10, + 9105, + Name::from_str("cache01.example.com.").unwrap(), + )), + ), + ], + ); + + let client = MockDnsClient::new(vec![response]); + let discovery = Discovery::new(client); + let servers = discovery.resolve_by_proto("dnssrv+_cache.example.com:11211").await.unwrap(); + let ids = servers.iter().map(|s| s.id().clone()).collect::>(); + + let id = ServerID::from(("cache01.example.com.", 11211)); + + assert_eq!(ids, vec![id]); + } + #[tokio::test] async fn test_dns_client_resolve_socket_addr() { let name = "127.0.0.2:11211"; @@ -433,6 +479,7 @@ mod test { let ids = servers.iter().map(|s| s.id().clone()).collect::>(); let id = ServerID::from(addr); + assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id); } @@ -446,6 +493,7 @@ mod test { let ids = servers.iter().map(|s| s.id().clone()).collect::>(); let id = ServerID::from(("localhost", 11211)); + assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id); } }