Skip to content

Commit

Permalink
Merge pull request #159 from 56quarters/resolution
Browse files Browse the repository at this point in the history
Don't attempt to resolve bare hostnames
  • Loading branch information
56quarters authored Jul 18, 2024
2 parents 6b22d3b + bad1a81 commit cdd0d54
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 19 deletions.
54 changes: 42 additions & 12 deletions mtop-client/src/discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,25 @@ impl DiscoveryDefault {
/// the targets will happen at connection time using the system resolver.
/// * No prefix with an IPv4 or IPv6 address will use the address as a Memcached
/// server.
/// * No prefix with a non-IP address will resolve the hostname into A or AAAA
/// records and pick a single one as a Memcached server.
pub async fn resolve_by_proto(&self, name: &str) -> Result<Vec<Server>, MtopError> {
/// * No prefix with a non-IP address will use the host as a Memcached server.
/// Resolution of the host will happen at connection time using the system
/// resolver.
pub async fn resolve(&self, name: &str) -> Result<Vec<Server>, MtopError> {
Self::resolve_by_proto(&self.client, name).await
}

async fn resolve_by_proto<C>(client: C, name: &str) -> Result<Vec<Server>, MtopError>
where
C: AsyncDnsClient,
{
if name.starts_with(DNS_A_PREFIX) {
Ok(Self::resolve_a_aaaa(&self.client, name.trim_start_matches(DNS_A_PREFIX)).await?)
Ok(Self::resolve_a_aaaa(client, name.trim_start_matches(DNS_A_PREFIX)).await?)
} else if name.starts_with(DNS_SRV_PREFIX) {
Ok(Self::resolve_srv(&self.client, name.trim_start_matches(DNS_SRV_PREFIX)).await?)
Ok(Self::resolve_srv(client, name.trim_start_matches(DNS_SRV_PREFIX)).await?)
} else if let Ok(addr) = name.parse::<SocketAddr>() {
Ok(Self::resolv_socket_addr(name, addr)?)
} else {
Ok(Self::resolve_a_aaaa(&self.client, name).await?.pop().into_iter().collect())
Ok(Self::resolv_bare_host(name)?)
}
}

Expand Down Expand Up @@ -184,6 +192,12 @@ impl DiscoveryDefault {
Ok(vec![Server::new(ServerID::from(addr), server_name)])
}

fn resolv_bare_host(name: &str) -> Result<Vec<Server>, MtopError> {
let (host, port) = Self::host_and_port(name)?;
let server_name = Self::server_name(host)?;
Ok(vec![Server::new(ServerID::from((host, port)), server_name)])
}

fn servers_from_answers(port: u16, server_name: &ServerName, answers: &[Record]) -> Vec<Server> {
let mut out = Vec::new();

Expand Down Expand Up @@ -370,7 +384,9 @@ mod test {
);

let client = MockDnsClient::new(vec![response_a, response_aaaa]);
let servers = DiscoveryDefault::resolve_a_aaaa(&client, "example.com:11211").await.unwrap();
let servers = DiscoveryDefault::resolve_by_proto(&client, "dns+example.com:11211")
.await
.unwrap();
let ids = servers.iter().map(|s| s.id()).collect::<Vec<_>>();

let id_a = ServerID::from(("10.1.1.1", 11211));
Expand Down Expand Up @@ -413,7 +429,7 @@ mod test {
);

let client = MockDnsClient::new(vec![response]);
let servers = DiscoveryDefault::resolve_srv(&client, "_cache.example.com:11211")
let servers = DiscoveryDefault::resolve_by_proto(&client, "dnssrv+_cache.example.com:11211")
.await
.unwrap();
let ids = servers.iter().map(|s| s.id()).collect::<Vec<_>>();
Expand All @@ -425,14 +441,28 @@ mod test {
assert!(ids.contains(&id2), "expected {:?} to contain {:?}", ids, id2);
}

#[test]
fn test_dns_client_resolve_socket_addr() {
#[tokio::test]
async fn test_dns_client_resolve_socket_addr() {
let name = "127.0.0.2:11211";
let addr = "127.0.0.2:11211".parse().unwrap();
let servers = DiscoveryDefault::resolv_socket_addr(name, addr).unwrap();
let addr: SocketAddr = "127.0.0.2:11211".parse().unwrap();

let client = MockDnsClient::new(vec![]);
let servers = DiscoveryDefault::resolve_by_proto(&client, name).await.unwrap();
let ids = servers.iter().map(|s| s.id()).collect::<Vec<_>>();

let id = ServerID::from(addr);
assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id);
}

#[tokio::test]
async fn test_dns_client_resolve_bare_host() {
let name = "localhost:11211";

let client = MockDnsClient::new(vec![]);
let servers = DiscoveryDefault::resolve_by_proto(&client, name).await.unwrap();
let ids = servers.iter().map(|s| s.id()).collect::<Vec<_>>();

let id = ServerID::from(("localhost", 11211));
assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id);
}
}
6 changes: 3 additions & 3 deletions mtop/src/bin/mc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ async fn main() -> ExitCode {
let resolver = DiscoveryDefault::new(dns_client);
let timeout = Duration::from_secs(opts.timeout_secs);
let servers = match resolver
.resolve_by_proto(&opts.host)
.timeout(timeout, "resolver.resolve_by_proto")
.instrument(tracing::span!(Level::INFO, "resolver.resolve_by_proto"))
.resolve(&opts.host)
.timeout(timeout, "resolver.resolve")
.instrument(tracing::span!(Level::INFO, "resolver.resolve"))
.await
{
Ok(v) => v,
Expand Down
6 changes: 3 additions & 3 deletions mtop/src/bin/mtop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ async fn expand_hosts(
for host in hosts {
out.extend(
resolver
.resolve_by_proto(host)
.timeout(timeout, "resolver.resolve_by_proto")
.instrument(tracing::span!(Level::INFO, "resolver.resolve_by_proto"))
.resolve(host)
.timeout(timeout, "resolver.resolve")
.instrument(tracing::span!(Level::INFO, "resolver.resolve"))
.await?,
);
}
Expand Down
2 changes: 1 addition & 1 deletion mtop/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl Checker {
let dns_start = Instant::now();
let server = match self
.resolver
.resolve_by_proto(host)
.resolve(host)
.timeout(self.timeout, "resolver.resolve_by_proto")
.await
.map(|mut v| v.pop())
Expand Down

0 comments on commit cdd0d54

Please sign in to comment.