Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read DNS settings from a resolv.conf file #134

Merged
merged 1 commit into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions mtop-client/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,7 @@ impl TryFrom<&HashMap<String, String>> for Slabs {
// $active_slabs + 1.
let mut ids = BTreeSet::new();
for k in value.keys() {
let key_id: Option<u64> = k
.split_once(':')
.map(|(raw, _rest)| raw)
.and_then(|raw| raw.parse().ok());
let key_id: Option<u64> = k.split_once(':').map(|(raw, _rest)| raw).and_then(|raw| raw.parse().ok());

if let Some(id) = key_id {
ids.insert(id);
Expand Down Expand Up @@ -1287,9 +1284,7 @@ mod test {
macro_rules! test_store_command_success {
($method:ident, $verb:expr) => {
let (mut rx, mut client) = client!("STORED\r\n");
let res = client
.$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes())
.await;
let res = client.$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes()).await;

assert!(res.is_ok());
let bytes = rx.recv().await.unwrap();
Expand All @@ -1301,9 +1296,7 @@ mod test {
macro_rules! test_store_command_error {
($method:ident, $verb:expr) => {
let (mut rx, mut client) = client!("NOT_STORED\r\n");
let res = client
.$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes())
.await;
let res = client.$method(&Key::one("test").unwrap(), 0, 300, "val".as_bytes()).await;

assert!(res.is_err());
let err = res.unwrap_err();
Expand Down
149 changes: 86 additions & 63 deletions mtop-client/src/discovery.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::core::MtopError;
use crate::dns::{DnsClient, RecordData};
use crate::dns::{DnsClient, Name, Record, RecordClass, RecordData, RecordType};
use std::cmp::Ordering;
use std::fmt;
use std::net::{IpAddr, SocketAddr};
Expand Down Expand Up @@ -60,47 +60,25 @@ impl AsRef<str> for ServerID {
/// An individual server that is part of a Memcached cluster.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Server {
repr: ServerRepr,
}

#[derive(Debug, Clone, Eq, PartialEq, Hash)]
enum ServerRepr {
Resolved(ServerID, ServerName<'static>, SocketAddr),
Unresolved(ServerID, ServerName<'static>),
id: ServerID,
name: ServerName<'static>,
}

impl Server {
pub fn from_id(id: ServerID, name: ServerName<'static>) -> Self {
Self {
repr: ServerRepr::Unresolved(id, name),
}
}

pub fn from_addr(addr: SocketAddr, name: ServerName<'static>) -> Self {
Self {
repr: ServerRepr::Resolved(ServerID::from(addr), name, addr),
}
pub fn new(id: ServerID, name: ServerName<'static>) -> Self {
Self { id, name }
}

pub fn id(&self) -> ServerID {
match &self.repr {
ServerRepr::Resolved(id, _, _) => id.clone(),
ServerRepr::Unresolved(id, _) => id.clone(),
}
self.id.clone()
}

pub fn server_name(&self) -> ServerName<'static> {
match &self.repr {
ServerRepr::Resolved(_, name, _) => name.clone(),
ServerRepr::Unresolved(_, name) => name.clone(),
}
self.name.clone()
}

pub fn address(&self) -> String {
match &self.repr {
ServerRepr::Resolved(_, _, addr) => addr.to_string(),
ServerRepr::Unresolved(id, _) => id.to_string(),
}
self.id.to_string()
}
}

Expand Down Expand Up @@ -134,44 +112,56 @@ impl DiscoveryDefault {

pub async fn resolve_by_proto(&self, name: &str) -> Result<Vec<Server>, MtopError> {
if name.starts_with(DNS_A_PREFIX) {
Ok(self.resolve_a(name.trim_start_matches(DNS_A_PREFIX)).await?)
Ok(self.resolve_a_aaaa(name.trim_start_matches(DNS_A_PREFIX)).await?)
} else if name.starts_with(DNS_SRV_PREFIX) {
Ok(self.resolve_srv(name.trim_start_matches(DNS_SRV_PREFIX)).await?)
} else {
Ok(self.resolve_a(name).await?.pop().into_iter().collect())
Ok(self.resolve_a_aaaa(name).await?.pop().into_iter().collect())
}
}

async fn resolve_srv(&self, name: &str) -> Result<Vec<Server>, MtopError> {
let server_name = Self::server_name(name)?;
let (host_name, port) = Self::host_and_port(name)?;
let mut out = Vec::new();
let (host, port) = Self::host_and_port(name)?;
let server_name = Self::server_name(host)?;
let name = host.parse()?;

let res = self.client.resolve_srv(host_name).await?;
for a in res.answers() {
let target = if let RecordData::SRV(srv) = a.rdata() {
srv.target().to_string()
} else {
tracing::warn!(message = "unexpected record data for answer", name = host_name, answer = ?a);
continue;
};
let server_id = ServerID::from((target, port));
let server = Server::from_id(server_id, server_name.clone());
out.push(server);
}
let res = self.client.resolve(name, RecordType::SRV, RecordClass::INET).await?;
Ok(self.servers_from_answers(port, &server_name, res.answers()))
}

async fn resolve_a_aaaa(&self, name: &str) -> Result<Vec<Server>, MtopError> {
let (host, port) = Self::host_and_port(name)?;
let server_name = Self::server_name(host)?;
let name: Name = host.parse()?;

let res = self.client.resolve(name.clone(), RecordType::A, RecordClass::INET).await?;
let mut out = self.servers_from_answers(port, &server_name, res.answers());

let res = self.client.resolve(name, RecordType::AAAA, RecordClass::INET).await?;
out.extend(self.servers_from_answers(port, &server_name, res.answers()));

Ok(out)
}

async fn resolve_a(&self, name: &str) -> Result<Vec<Server>, MtopError> {
let server_name = Self::server_name(name)?;

fn servers_from_answers(&self, port: u16, server_name: &ServerName, answers: &[Record]) -> Vec<Server> {
let mut out = Vec::new();
for addr in tokio::net::lookup_host(name).await? {
out.push(Server::from_addr(addr, server_name.clone()));

for answer in answers {
let server_id = match answer.rdata() {
RecordData::A(data) => ServerID::from(SocketAddr::new(IpAddr::V4(data.addr()), port)),
RecordData::AAAA(data) => ServerID::from(SocketAddr::new(IpAddr::V6(data.addr()), port)),
RecordData::SRV(data) => ServerID::from((data.target().to_string(), port)),
_ => {
tracing::warn!(message = "unexpected record data for answer", answer = ?answer);
continue;
}
};

let server = Server::new(server_id, server_name.to_owned());
out.push(server);
}

Ok(out)
out
}

fn host_and_port(name: &str) -> Result<(&str, u16), MtopError> {
Expand All @@ -185,27 +175,26 @@ impl DiscoveryDefault {
// IPv6 addresses use brackets around them to disambiguate them from a port number.
// Since we're parsing the host and port, strip the brackets because they aren't
// needed or valid to include in a TLS ServerName.
.map(|(hostname, port)| (hostname.trim_start_matches('[').trim_end_matches(']'), port))
.and_then(|(hostname, port)| {
port.parse().map(|p| (hostname, p)).map_err(|e| {
.map(|(host, port)| (host.trim_start_matches('[').trim_end_matches(']'), port))
.and_then(|(host, port)| {
port.parse().map(|p| (host, p)).map_err(|e| {
MtopError::configuration_cause(format!("unable to parse port number from '{}'", name), e)
})
})
}

fn server_name(name: &str) -> Result<ServerName<'static>, MtopError> {
Self::host_and_port(name).and_then(|(host, _)| {
ServerName::try_from(host)
.map(|s| s.to_owned())
.map_err(|e| MtopError::configuration_cause(format!("invalid server name '{}'", host), e))
})
fn server_name(host: &str) -> Result<ServerName<'static>, MtopError> {
ServerName::try_from(host)
.map(|s| s.to_owned())
.map_err(|e| MtopError::configuration_cause(format!("invalid server name '{}'", host), e))
}
}

#[cfg(test)]
mod test {
use super::ServerID;
use super::{Server, ServerID};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use webpki::types::ServerName;

#[test]
fn test_server_id_from_ipv4_addr() {
Expand Down Expand Up @@ -241,4 +230,38 @@ mod test {
let id = ServerID::from(pair);
assert_eq!("cache.example.com:11211", id.to_string());
}

#[test]
fn test_server_resolved_id() {
let addr = SocketAddr::from(([127, 0, 0, 1], 11211));
let id = ServerID::from(addr);
let name = ServerName::try_from("cache.example.com").unwrap();
let server = Server::new(id, name);
assert_eq!("127.0.0.1:11211", server.id().to_string());
}

#[test]
fn test_server_resolved_address() {
let addr = SocketAddr::from(([127, 0, 0, 1], 11211));
let id = ServerID::from(addr);
let name = ServerName::try_from("cache.example.com").unwrap();
let server = Server::new(id, name);
assert_eq!("127.0.0.1:11211", server.address());
}

#[test]
fn test_server_unresolved_id() {
let id = ServerID::from(("cache01.example.com", 11211));
let name = ServerName::try_from("cache.example.com").unwrap();
let server = Server::new(id, name);
assert_eq!("cache01.example.com:11211", server.id().to_string());
}

#[test]
fn test_server_unresolved_address() {
let id = ServerID::from(("cache01.example.com", 11211));
let name = ServerName::try_from("cache.example.com").unwrap();
let server = Server::new(id, name);
assert_eq!("cache01.example.com:11211", server.address());
}
}
93 changes: 74 additions & 19 deletions mtop-client/src/dns/client.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,79 @@
use crate::core::MtopError;
use crate::dns::core::RecordType;
use crate::dns::core::{RecordClass, RecordType};
use crate::dns::message::{Flags, Message, MessageId, Question};
use crate::dns::name::Name;
use crate::dns::resolv::ResolvConf;
use crate::timeout::Timeout;
use std::io::Cursor;
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::net::UdpSocket;

const DEFAULT_RECV_BUF: usize = 512;

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct DnsClient {
local: SocketAddr,
server: SocketAddr,
config: ResolvConf,
server: AtomicUsize,
}

impl DnsClient {
pub fn new(local: SocketAddr, server: SocketAddr) -> Self {
Self { local, server }
}

pub async fn exchange(&self, msg: &Message) -> Result<Message, MtopError> {
let id = msg.id();
let sock = self.connect_udp().await?;
self.send_udp(&sock, msg).await?;
self.recv_udp(&sock, id).await
/// Create a new DnsClient that will use a local address to open UDP or TCP
/// connections and behavior based on a resolv.conf configuration file.
pub fn new(local: SocketAddr, config: ResolvConf) -> Self {
Self {
local,
config,
server: AtomicUsize::new(0),
}
}

pub async fn resolve_srv(&self, name: &str) -> Result<Message, MtopError> {
let n = Name::from_str(name)?;
/// Perform a DNS lookup with the configured nameservers.
///
/// Timeouts and network errors will result in up to one additional attempt
/// to perform a DNS lookup when using the default configuration.
pub async fn resolve(&self, name: Name, rtype: RecordType, rclass: RecordClass) -> Result<Message, MtopError> {
let full = name.to_fqdn();
let id = MessageId::random();
let flags = Flags::default().set_recursion_desired();
let msg = Message::new(id, flags).add_question(Question::new(n, RecordType::SRV));
let question = Question::new(full, rtype).set_qclass(rclass);
let message = Message::new(id, flags).add_question(question);

let mut attempt = 0;
loop {
match self.exchange(&message, usize::from(attempt)).await {
Ok(v) => return Ok(v),
Err(e) => {
if attempt + 1 >= self.config.options.attempts {
return Err(e);
}

self.exchange(&msg).await
tracing::debug!(message = "retrying failed query", attempt = attempt + 1, max_attempts = self.config.options.attempts, err = %e);
attempt += 1;
}
}
}
}
async fn exchange(&self, msg: &Message, attempt: usize) -> Result<Message, MtopError> {
let id = msg.id();
let server = self.nameserver(attempt);

async fn connect_udp(&self) -> Result<UdpSocket, MtopError> {
// Wrap creating the socket, sending, and receiving in an async block
// so that we can apply a single timeout to all operations and ensure
// we have access to the nameserver to make the timeout message useful.
async {
let sock = self.connect_udp(server).await?;
self.send_udp(&sock, msg).await?;
self.recv_udp(&sock, id).await
}
.timeout(self.config.options.timeout, format!("client.exchange {}", server))
.await
}

async fn connect_udp(&self, server: SocketAddr) -> Result<UdpSocket, MtopError> {
let sock = UdpSocket::bind(&self.local).await?;
sock.connect(&self.server).await?;
sock.connect(server).await?;
Ok(sock)
}

Expand All @@ -59,4 +94,24 @@ impl DnsClient {
}
}
}

fn nameserver(&self, attempt: usize) -> SocketAddr {
let idx = if self.config.options.rotate {
self.server.fetch_add(1, Ordering::Relaxed)
} else {
attempt
};

self.config.nameservers[idx % self.config.nameservers.len()]
}
}

impl Clone for DnsClient {
fn clone(&self) -> Self {
Self {
local: self.local,
config: self.config.clone(),
server: AtomicUsize::new(0),
}
}
}
Loading