Skip to content

Commit

Permalink
Add TCP support to DNS client
Browse files Browse the repository at this point in the history
Add fallback to TCP for DNS client when the UDP message is trunctated.

Part of #107
  • Loading branch information
56quarters committed May 19, 2024
1 parent d5c88bf commit c3ac0e0
Showing 1 changed file with 52 additions and 4 deletions.
56 changes: 52 additions & 4 deletions mtop-client/src/dns/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use crate::timeout::Timeout;
use std::io::Cursor;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::net::UdpSocket;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpStream, UdpSocket};

const DEFAULT_RECV_BUF: usize = 512;

Expand Down Expand Up @@ -62,13 +63,27 @@ impl DnsClient {
// Wrap creating the socket, sending, and receiving in an async block
// so that we can apply a single timeout to all operations and ensure
// we have access to the nameserver to make the timeout message useful.
async {
let res = async {
let sock = self.connect_udp(server).await?;
self.send_udp(&sock, msg).await?;
self.recv_udp(&sock, id).await
}
.timeout(self.config.options.timeout, format!("client.exchange {}", server))
.await
.timeout(self.config.options.timeout, format!("client.exchange udp://{}", server))
.await?;

// If the UDP response indicates the message was truncated, we discard
// it and repeat the query using TCP.
if res.flags().is_truncated() {
tracing::debug!(message = "UDP response truncated, repeating with TCP", flags = ?res.flags(), server = %server);
async {
let mut sock = self.connect_tcp(server).await?;
self.send_recv_tcp(&mut sock, msg).await
}
.timeout(self.config.options.timeout, format!("client.exchange tcp://{}", server))
.await
} else {
Ok(res)
}
}

async fn connect_udp(&self, server: SocketAddr) -> Result<UdpSocket, MtopError> {
Expand All @@ -95,6 +110,39 @@ impl DnsClient {
}
}

async fn connect_tcp(&self, server: SocketAddr) -> Result<TcpStream, MtopError> {
Ok(TcpStream::connect(server).await?)
}

async fn send_recv_tcp(&self, stream: &mut TcpStream, msg: &Message) -> Result<Message, MtopError> {
let mut buf = Vec::with_capacity(DEFAULT_RECV_BUF);
let (read, write) = stream.split();
let mut read = BufReader::new(read);
let mut write = BufWriter::new(write);

// Write the message to a local buffer and then send it, prefixed
// with the size of the message.
msg.write_network_bytes(&mut buf)?;
write.write_u16(buf.len() as u16).await?;
write.write_all(&buf).await?;
write.flush().await?;

// Read the prefixed size of the response and then read exactly that
// many bytes into our buffer.
let sz = read.read_u16().await?;
buf.clear();
buf.resize(usize::from(sz), 0);
read.read_exact(&mut buf).await?;

let mut cur = Cursor::new(buf);
let res = Message::read_network_bytes(&mut cur)?;
if res.id() != msg.id() {
Err(MtopError::runtime(format!("unexpected DNS MessageId. expected {}, got {}", msg.id(), res.id())))
} else {
Ok(res)
}
}

fn nameserver(&self, attempt: usize) -> SocketAddr {
let idx = if self.config.options.rotate {
self.server.fetch_add(1, Ordering::Relaxed)
Expand Down

0 comments on commit c3ac0e0

Please sign in to comment.