diff --git a/src/net/tcp/socket.rs b/src/net/tcp/socket.rs index 89f04d277d..d9360bdb0b 100644 --- a/src/net/tcp/socket.rs +++ b/src/net/tcp/socket.rs @@ -1,14 +1,14 @@ -use crate::net::{TcpStream, TcpListener}; +use crate::net::{TcpListener, TcpStream}; use crate::sys; use std::io; use std::mem; use std::net::SocketAddr; -use std::time::Duration; #[cfg(unix)] -use std::os::unix::io::{AsRawFd, RawFd, FromRawFd}; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use std::time::Duration; /// A non-blocking TCP socket used to configure a stream or listener. /// @@ -27,18 +27,14 @@ impl TcpSocket { /// /// This calls `socket(2)`. pub fn new_v4() -> io::Result { - sys::tcp::new_v4_socket().map(|sys| TcpSocket { - sys - }) + sys::tcp::new_v4_socket().map(|sys| TcpSocket { sys }) } /// Create a new IPv6 TCP socket. /// /// This calls `socket(2)`. pub fn new_v6() -> io::Result { - sys::tcp::new_v6_socket().map(|sys| TcpSocket { - sys - }) + sys::tcp::new_v6_socket().map(|sys| TcpSocket { sys }) } pub(crate) fn new_for_addr(addr: SocketAddr) -> io::Result { @@ -163,32 +159,118 @@ impl TcpSocket { pub fn get_send_buffer_size(&self) -> io::Result { sys::tcp::get_send_buffer_size(self.sys) } - + /// Sets whether keepalive messages are enabled to be sent on this socket. /// - /// On Unix, this option will set the `SO_KEEPALIVE` as well as the - /// `TCP_KEEPALIVE` or `TCP_KEEPIDLE` option (depending on your platform). - /// On Windows, this will set the `SIO_KEEPALIVE_VALS` option. + /// This will set the `SO_KEEPALIVE` option on this socket. + pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> { + sys::tcp::set_keepalive(self.sys, keepalive) + } + + /// Returns whether or not TCP keepalive probes will be sent by this socket. + pub fn get_keepalive(&self) -> io::Result { + sys::tcp::get_keepalive(self.sys) + } + + /// Sets the amount of time after which TCP keepalive probes will be sent + /// on idle connections, if TCP keepalive is enabled on this socket. /// - /// If `None` is specified then keepalive messages are disabled, otherwise - /// the duration specified will be the time to remain idle before sending a - /// TCP keepalive probe. + /// This sets the value of `SO_KEEPALIVE` + `IPPROTO_TCP` on OpenBSD, + /// NetBSD, and Haiku, `TCP_KEEPALIVE` on macOS and iOS, and `TCP_KEEPIDLE` + /// on all other Unix operating systems. On Windows, this sets the value of + /// the `tcp_keepalive` struct's `keepalivetime` field. /// /// Some platforms specify this value in seconds, so sub-second /// specifications may be omitted. - pub fn set_keepalive(&self, dur: Option) -> io::Result<()> { - sys::tcp::set_keepalive(self.sys, dur) + /// + /// The OS may return an error if TCP keepalive was not already enabled by + /// calling `set_keepalive(true)` on this socket. + pub fn set_keepalive_time(&self, time: Duration) -> io::Result<()> { + sys::tcp::set_keepalive_time(self.sys, time) } - /// Returns the duration after which TCP keepalive probes will be sent, if - /// keepalive messages are enabled to be sent on this socket. + /// Returns the amount of time after which TCP keepalive probes will be sent + /// on idle connections. /// /// If `None`, then keepalive messages are disabled. /// + /// This returns the value of `SO_KEEPALIVE` + `IPPROTO_TCP` on OpenBSD, + /// NetBSD, and Haiku, `TCP_KEEPALIVE` on macOS and iOS, and `TCP_KEEPIDLE` + /// on all other Unix operating systems. On Windows, this returns the value of + /// the `tcp_keepalive` struct's `keepalivetime` field. + /// /// Some platforms specify this value in seconds, so sub-second /// specifications may be omitted. - pub fn get_keepalive(&self) -> io::Result> { - sys::tcp::get_keepalive(self.sys) + pub fn get_keepalive_time(&self) -> io::Result> { + sys::tcp::get_keepalive_time(self.sys) + } + + /// Sets the time interval between TCP keepalive probes, if TCP keepalive is + /// enabled on this socket. + /// + /// This sets the value of `TCP_KEEPINTVL` on supported Unix operating + /// systems. On Windows, this sets the value of the `tcp_keepalive` struct's + /// `keepaliveinterval` field. + /// + /// Some platforms specify this value in seconds, so sub-second + /// specifications may be omitted. + /// + /// The OS may return an error if TCP keepalive was not already enabled by + /// calling `set_keepalive(true)` on this socket. + #[cfg(any( + target_os = "linux", + target_os = "freebsd", + target_os = "netbsd", + target_os = "windows" + ))] + pub fn set_keepalive_interval(&self, interval: Duration) -> io::Result<()> { + sys::tcp::set_keepalive_interval(self.sys, interval) + } + + /// Returns the time interval between TCP keepalive probes, if TCP keepalive is + /// enabled on this socket. + /// + /// If `None`, then keepalive messages are disabled. + /// + /// This returns the value of `TCP_KEEPINTVL` on supported Unix operating + /// systems. On Windows, this sets the value of the `tcp_keepalive` struct's + /// `keepaliveinterval` field. + /// + /// Some platforms specify this value in seconds, so sub-second + /// specifications may be omitted. + #[cfg(any( + target_os = "linux", + target_os = "freebsd", + target_os = "netbsd", + target_os = "windows" + ))] + pub fn get_keepalive_interval(&self) -> io::Result> { + sys::tcp::get_keepalive_interval(self.sys) + } + + /// Sets the maximum number of TCP keepalive probes that will be sent before + /// dropping a connection, if TCP keepalive is enabled on this socket. + /// + /// This sets the value of `TCP_KEEPCNT` on Unix operating systems that + /// support this option. + /// + /// The OS may return an error if TCP keepalive was not already enabled by + /// calling `set_keepalive(true)` on this socket. + #[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "netbsd",))] + pub fn set_keepalive_retries(&self, retries: u32) -> io::Result<()> { + sys::tcp::set_keepalive_retries(self.sys, retries) + } + + /// Returns the maximum number of TCP keepalive probes that will be sent before + /// dropping a connection, if TCP keepalive is enabled on this socket. + /// + /// If `None`, then keepalive messages are disabled. + /// + /// This returns the value of `TCP_KEEPCNT` on Unix operating systems that + /// support this option. + #[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "netbsd",))] + pub fn get_keepalive_retries(&self) -> io::Result> { + sys::tcp::get_keepalive_retries(self.sys) } /// Returns the local address of this socket @@ -257,6 +339,8 @@ impl FromRawSocket for TcpSocket { /// The caller is responsible for ensuring that the socket is in /// non-blocking mode. unsafe fn from_raw_socket(socket: RawSocket) -> TcpSocket { - TcpSocket { sys: socket as sys::tcp::TcpSocket } + TcpSocket { + sys: socket as sys::tcp::TcpSocket, + } } } diff --git a/src/sys/shell/tcp.rs b/src/sys/shell/tcp.rs index 0f6d683f34..597c5f65aa 100644 --- a/src/sys/shell/tcp.rs +++ b/src/sys/shell/tcp.rs @@ -65,14 +65,64 @@ pub(crate) fn set_send_buffer_size(_: TcpSocket, _: u32) -> io::Result<()> { pub(crate) fn get_send_buffer_size(_: TcpSocket) -> io::Result { os_required!(); } -pub(crate) fn set_keepalive(_: TcpSocket, _: Option) -> io::Result<()> { + +pub(crate) fn set_keepalive(_: TcpSocket, _: bool) -> io::Result<()> { + os_required!(); +} + +pub(crate) fn get_keepalive(_: TcpSocket) -> io::Result { os_required!(); } -pub(crate) fn get_keepalive(_: TcpSocket) -> io::Result> { +#[cfg(any( + target_os = "linux", + target_os = "freebsd", + target_os = "netbsd", + target_os = "windows" +))] +pub(crate) fn set_keepalive_time(_: TcpSocket, _: Duration) -> io::Result<()> { os_required!(); } +#[cfg(any( + target_os = "linux", + target_os = "freebsd", + target_os = "netbsd", + target_os = "windows" +))] +pub(crate) fn get_keepalive_time(_: TcpSocket) -> io::Result> { + os_required!() +} + +#[cfg(any( + target_os = "linux", + target_os = "freebsd", + target_os = "netbsd", + target_os = "windows" +))] +pub(crate) fn set_keepalive_interval(_: TcpSocket, _: Duration) -> io::Result<()> { + os_required!() +} + +#[cfg(any( + target_os = "linux", + target_os = "freebsd", + target_os = "netbsd", + target_os = "windows" +))] +pub(crate) fn get_keepalive_interval(_: TcpSocket) -> io::Result> { + os_required!() +} + +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "netbsd"))] +pub(crate) fn set_keepalive_retries(_: TcpSocket, _: u32) -> io::Result<()> { + os_required!() +} + +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "netbsd"))] +pub(crate) fn get_keepalive_retries(socket: TcpSocket) -> io::Result> { + os_required!() +} pub fn accept(_: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { os_required!(); diff --git a/src/sys/unix/tcp.rs b/src/sys/unix/tcp.rs index 7d60e54ca4..c4687451a1 100644 --- a/src/sys/unix/tcp.rs +++ b/src/sys/unix/tcp.rs @@ -1,17 +1,17 @@ -use std::io; use std::convert::TryInto; +use std::io; use std::mem; use std::mem::{size_of, MaybeUninit}; use std::net::{self, SocketAddr}; -use std::time::Duration; use std::os::unix::io::{AsRawFd, FromRawFd}; +use std::time::Duration; use crate::sys::unix::net::{new_socket, socket_addr, to_socket_addr}; -#[cfg(any(target_os = "macos", target_os = "ios"))] -use libc::TCP_KEEPALIVE as KEEPALIVE; #[cfg(any(target_os = "openbsd", target_os = "netbsd", target_os = "haiku"))] -use libc::SO_KEEPALIVE as KEEPALIVE; +use libc::SO_KEEPALIVE as KEEPALIVE_TIME; +#[cfg(any(target_os = "macos", target_os = "ios"))] +use libc::TCP_KEEPALIVE as KEEPALIVE_TIME; #[cfg(not(any( target_os = "macos", target_os = "ios", @@ -19,8 +19,7 @@ use libc::SO_KEEPALIVE as KEEPALIVE; target_os = "netbsd", target_os = "haiku" )))] -use libc::TCP_KEEPIDLE as KEEPALIVE; - +use libc::TCP_KEEPIDLE as KEEPALIVE_TIME; pub type TcpSocket = libc::c_int; pub(crate) fn new_v4_socket() -> io::Result { @@ -41,12 +40,8 @@ pub(crate) fn connect(socket: TcpSocket, addr: SocketAddr) -> io::Result { - Err(err) - } - _ => { - Ok(unsafe { net::TcpStream::from_raw_fd(socket) }) - } + Err(err) if err.raw_os_error() != Some(libc::EINPROGRESS) => Err(err), + _ => Ok(unsafe { net::TcpStream::from_raw_fd(socket) }), } } @@ -68,7 +63,8 @@ pub(crate) fn set_reuseaddr(socket: TcpSocket, reuseaddr: bool) -> io::Result<() libc::SO_REUSEADDR, &val as *const libc::c_int as *const libc::c_void, size_of::() as libc::socklen_t, - )).map(|_| ()) + )) + .map(|_| ()) } pub(crate) fn get_reuseaddr(socket: TcpSocket) -> io::Result { @@ -96,7 +92,8 @@ pub(crate) fn set_reuseport(socket: TcpSocket, reuseport: bool) -> io::Result<() libc::SO_REUSEPORT, &val as *const libc::c_int as *const libc::c_void, size_of::() as libc::socklen_t, - )).map(|_| ()) + )) + .map(|_| ()) } #[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))] @@ -131,7 +128,9 @@ pub(crate) fn get_localaddr(socket: TcpSocket) -> io::Result { pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result<()> { let val: libc::linger = libc::linger { l_onoff: if dur.is_some() { 1 } else { 0 }, - l_linger: dur.map(|dur| dur.as_secs() as libc::c_int).unwrap_or_default(), + l_linger: dur + .map(|dur| dur.as_secs() as libc::c_int) + .unwrap_or_default(), }; syscall!(setsockopt( socket, @@ -139,7 +138,8 @@ pub(crate) fn set_linger(socket: TcpSocket, dur: Option) -> io::Result libc::SO_LINGER, &val as *const libc::linger as *const libc::c_void, size_of::() as libc::socklen_t, - )).map(|_| ()) + )) + .map(|_| ()) } pub(crate) fn set_recv_buffer_size(socket: TcpSocket, size: u32) -> io::Result<()> { @@ -154,7 +154,7 @@ pub(crate) fn set_recv_buffer_size(socket: TcpSocket, size: u32) -> io::Result<( .map(|_| ()) } -pub(crate) fn get_recv_buffer_size(socket: TcpSocket) -> io::Result { +pub(crate) fn get_recv_buffer_size(socket: TcpSocket) -> io::Result { let mut optval: libc::c_int = 0; let mut optlen = size_of::() as libc::socklen_t; syscall!(getsockopt( @@ -180,7 +180,7 @@ pub(crate) fn set_send_buffer_size(socket: TcpSocket, size: u32) -> io::Result<( .map(|_| ()) } -pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result { +pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result { let mut optval: libc::c_int = 0; let mut optlen = size_of::() as libc::socklen_t; @@ -195,31 +195,19 @@ pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result { Ok(optval as u32) } -pub(crate) fn set_keepalive(socket: TcpSocket, dur: Option) -> io::Result<()> { +pub(crate) fn set_keepalive(socket: TcpSocket, keepalive: bool) -> io::Result<()> { + let val: libc::c_int = if keepalive { 1 } else { 0 }; syscall!(setsockopt( socket, libc::SOL_SOCKET, libc::SO_KEEPALIVE, - &(dur.is_some() as libc::c_int) as *const _ as *const libc::c_void, + &val as *const _ as *const libc::c_void, size_of::() as libc::socklen_t )) - .map(|_| ())?; - - if let Some(dur) = dur { - let dur_secs = dur.as_secs().try_into().ok().unwrap_or_else(i32::max_value); - syscall!(setsockopt( - socket, - libc::IPPROTO_TCP, - KEEPALIVE, - &(dur_secs as libc::c_int) as *const _ as *const libc::c_void, - size_of::() as libc::socklen_t - )) - .map(|_| ())?; - } - Ok(()) + .map(|_| ()) } -pub(crate) fn get_keepalive(socket: TcpSocket) -> io::Result> { +pub(crate) fn get_keepalive(socket: TcpSocket) -> io::Result { let mut optval: libc::c_int = 0; let mut optlen = mem::size_of::() as libc::socklen_t; @@ -231,14 +219,36 @@ pub(crate) fn get_keepalive(socket: TcpSocket) -> io::Result> { &mut optlen, ))?; - if optval == 0 { + Ok(optval != 0) +} + +pub(crate) fn set_keepalive_time(socket: TcpSocket, time: Duration) -> io::Result<()> { + let time_secs = time + .as_secs() + .try_into() + .ok() + .unwrap_or_else(i32::max_value); + syscall!(setsockopt( + socket, + libc::IPPROTO_TCP, + KEEPALIVE_TIME, + &(time_secs as libc::c_int) as *const _ as *const libc::c_void, + size_of::() as libc::socklen_t + )) + .map(|_| ()) +} + +pub(crate) fn get_keepalive_time(socket: TcpSocket) -> io::Result> { + if !get_keepalive(socket)? { return Ok(None); } + let mut optval: libc::c_int = 0; + let mut optlen = mem::size_of::() as libc::socklen_t; syscall!(getsockopt( socket, libc::IPPROTO_TCP, - KEEPALIVE, + KEEPALIVE_TIME, &mut optval as *mut _ as *mut _, &mut optlen, ))?; @@ -246,6 +256,92 @@ pub(crate) fn get_keepalive(socket: TcpSocket) -> io::Result> { Ok(Some(Duration::from_secs(optval as u64))) } +/// Linux, FreeBSD, and NetBSD support setting the keepalive interval via +/// `TCP_KEEPINTVL`. +/// See: +/// - https://man7.org/linux/man-pages/man7/tcp.7.html +/// - https://www.freebsd.org/cgi/man.cgi?query=tcp#end +/// - http://man.netbsd.org/tcp.4#DESCRIPTION +/// +/// OpenBSD does not: +/// https://man.openbsd.org/tcp +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "netbsd"))] +pub(crate) fn set_keepalive_interval(socket: TcpSocket, interval: Duration) -> io::Result<()> { + let interval_secs = interval + .as_secs() + .try_into() + .ok() + .unwrap_or_else(i32::max_value); + syscall!(setsockopt( + socket, + libc::IPPROTO_TCP, + libc::TCP_KEEPINTVL, + &(interval_secs as libc::c_int) as *const _ as *const libc::c_void, + size_of::() as libc::socklen_t + )) + .map(|_| ()) +} + +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "netbsd"))] +pub(crate) fn get_keepalive_interval(socket: TcpSocket) -> io::Result> { + if !get_keepalive(socket)? { + return Ok(None); + } + + let mut optval: libc::c_int = 0; + let mut optlen = mem::size_of::() as libc::socklen_t; + syscall!(getsockopt( + socket, + libc::IPPROTO_TCP, + libc::TCP_KEEPINTVL, + &mut optval as *mut _ as *mut _, + &mut optlen, + ))?; + + Ok(Some(Duration::from_secs(optval as u64))) +} + +/// Linux, FreeBSD, and NetBSD support setting the number of TCP keepalive +/// retries via `TCP_KEEPCNT`. +/// See: +/// - https://man7.org/linux/man-pages/man7/tcp.7.html +/// - https://www.freebsd.org/cgi/man.cgi?query=tcp#end +/// - http://man.netbsd.org/tcp.4#DESCRIPTION +/// +/// OpenBSD does not: +/// https://man.openbsd.org/tcp +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "netbsd"))] +pub(crate) fn set_keepalive_retries(socket: TcpSocket, retries: u32) -> io::Result<()> { + let retries = retries.try_into().ok().unwrap_or_else(i32::max_value); + syscall!(setsockopt( + socket, + libc::IPPROTO_TCP, + libc::TCP_KEEPCNT, + &(retries as libc::c_int) as *const _ as *const libc::c_void, + size_of::() as libc::socklen_t + )) + .map(|_| ()) +} + +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "netbsd"))] +pub(crate) fn get_keepalive_retries(socket: TcpSocket) -> io::Result> { + if !get_keepalive(socket)? { + return Ok(None); + } + + let mut optval: libc::c_int = 0; + let mut optlen = mem::size_of::() as libc::socklen_t; + syscall!(getsockopt( + socket, + libc::IPPROTO_TCP, + libc::TCP_KEEPCNT, + &mut optval as *mut _ as *mut _, + &mut optlen, + ))?; + + Ok(Some(optval as u32)) +} + pub fn accept(listener: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> { let mut addr: MaybeUninit = MaybeUninit::uninit(); let mut length = size_of::() as libc::socklen_t; diff --git a/src/sys/windows/tcp.rs b/src/sys/windows/tcp.rs index f4edc01a4f..9ecbda65ec 100644 --- a/src/sys/windows/tcp.rs +++ b/src/sys/windows/tcp.rs @@ -15,7 +15,7 @@ use winapi::shared::mstcpip; use winapi::shared::minwindef::{BOOL, TRUE, FALSE, DWORD, LPVOID, LPDWORD}; use winapi::um::winsock2::{ self, closesocket, linger, setsockopt, getsockopt, getsockname, PF_INET, PF_INET6, SOCKET, SOCKET_ERROR, - SOCK_STREAM, SOL_SOCKET, SO_LINGER, SO_REUSEADDR, SO_RCVBUF, SO_SNDBUF, WSAIoctl, LPWSAOVERLAPPED + SOCK_STREAM, SOL_SOCKET, SO_LINGER, SO_REUSEADDR, SO_RCVBUF, SO_SNDBUF, SO_KEEPALIVE, WSAIoctl, LPWSAOVERLAPPED }; use crate::sys::windows::net::{init, new_socket, socket_addr}; @@ -211,57 +211,140 @@ pub(crate) fn get_send_buffer_size(socket: TcpSocket) -> io::Result { } } +pub(crate) fn set_keepalive(socket: TcpSocket, keepalive: bool) -> io::Result<()> { + let val: BOOL = if keepalive { TRUE } else { FALSE }; + match unsafe { setsockopt( + socket, + SOL_SOCKET, + SO_KEEPALIVE, + &val as *const _ as *const c_char, + size_of::() as c_int + ) } { + SOCKET_ERROR => Err(io::Error::last_os_error()), + _ => Ok(()), + } +} -pub(crate) fn set_keepalive(socket: TcpSocket, dur: Option) -> io::Result<()> { - // Windows takes the keepalive timeout as a u32 of milliseconds. - let dur_ms = dur.map(|dur| { - let ms = dur.as_millis(); - ms.try_into().ok().unwrap_or_else(i32::max_value) - }).unwrap_or(0); - - let keepalive = mstcpip::tcp_keepalive { - onoff: dur.is_some() as c_ulong, - keepalivetime: dur_ms as c_ulong, - keepaliveinterval: dur_ms as c_ulong, - }; +pub(crate) fn get_keepalive(socket: TcpSocket) -> io::Result { + let mut optval: c_char = 0; + let mut optlen = size_of::() as c_int; - let mut out = 0; - match unsafe { WSAIoctl( + match unsafe { getsockopt( socket, - mstcpip::SIO_KEEPALIVE_VALS, - &keepalive as *const _ as *mut mstcpip::tcp_keepalive as LPVOID, - size_of::() as DWORD, - ptr::null_mut() as LPVOID, - 0 as DWORD, - &mut out as *mut _ as LPDWORD, - ptr::null_mut() as LPWSAOVERLAPPED, - None, + SOL_SOCKET, + SO_KEEPALIVE, + &mut optval as *mut _ as *mut _, + &mut optlen, ) } { - 0 => Ok(()), - _ => Err(io::Error::last_os_error()) + SOCKET_ERROR => Err(io::Error::last_os_error()), + _ => Ok(optval != FALSE as c_char), } } -pub(crate) fn get_keepalive(socket: TcpSocket) -> io::Result> { +pub(crate) fn set_keepalive_time(socket: TcpSocket, time: Duration) -> io::Result<()> { let mut keepalive = mstcpip::tcp_keepalive { onoff: 0, keepalivetime: 0, keepaliveinterval: 0, }; + // First, populate an empty keepalive structure with the current values. + // Otherwise, if we call `WSAIoctl` with fields other than the keepalive + // time set to 0, we'll clobber the existing values. + get_keepalive_vals(socket, &mut keepalive)?; + + // Windows takes the keepalive time as a u32 of milliseconds. + let time_ms = time.as_millis().try_into().ok().unwrap_or_else(u32::max_value); + keepalive.keepalivetime = time_ms as c_ulong; + // XXX(eliza): if keepalive is disabled on the socket, do we want to turn it + // on here, or just propagate the OS error? + set_keepalive_vals(socket, &keepalive) +} + +pub(crate) fn get_keepalive_time(socket: TcpSocket) -> io::Result> { + let mut keepalive = mstcpip::tcp_keepalive { + onoff: 0, + keepalivetime: 0, + keepaliveinterval: 0, + }; + + get_keepalive_vals(socket, &mut keepalive)?; + + if keepalive.onoff == 0 { + // Keepalive is disabled on this socket. + return Ok(None); + } + Ok(Some(Duration::from_millis(keepalive.keepalivetime as u64))) +} + +pub(crate) fn set_keepalive_interval(socket: TcpSocket, interval: Duration) -> io::Result<()> { + let mut keepalive = mstcpip::tcp_keepalive { + onoff: 0, + keepalivetime: 0, + keepaliveinterval: 0, + }; + + // First, populate an empty keepalive structure with the current values. + // Otherwise, if we call `WSAIoctl` with fields other than the keepalive + // interval set to 0, we'll clobber the existing values. + get_keepalive_vals(socket, &mut keepalive)?; + + // Windows takes the keepalive interval as a u32 of milliseconds. + let interval_ms = interval.as_millis().try_into().ok().unwrap_or_else(u32::max_value); + keepalive.keepaliveinterval = interval_ms as c_ulong; + // XXX(eliza): if keepalive is disabled on the socket, do we want to turn it + // on here, or just propagate the OS error? + set_keepalive_vals(socket, &keepalive) +} + +pub(crate) fn get_keepalive_interval(socket: TcpSocket) -> io::Result> { + let mut keepalive = mstcpip::tcp_keepalive { + onoff: 0, + keepalivetime: 0, + keepaliveinterval: 0, + }; + + get_keepalive_vals(socket, &mut keepalive)?; + + if keepalive.onoff == 0 { + // Keepalive is disabled on this socket. + return Ok(None); + } + + Ok(Some(Duration::from_millis(keepalive.keepaliveinterval as u64))) +} + +fn get_keepalive_vals(socket: TcpSocket, vals: &mut mstcpip::tcp_keepalive) -> io::Result<()> { match unsafe { WSAIoctl( socket, mstcpip::SIO_KEEPALIVE_VALS, ptr::null_mut() as LPVOID, 0, - &mut keepalive as *mut _ as LPVOID, + vals as *mut _ as LPVOID, size_of::() as DWORD, ptr::null_mut() as LPDWORD, ptr::null_mut() as LPWSAOVERLAPPED, None, ) } { - 0 if keepalive.onoff == 0 || keepalive.keepaliveinterval == 0 => Ok(None), - 0 => Ok(Some(Duration::from_millis(keepalive.keepaliveinterval as u64))), + 0 => Ok(()), + _ => Err(io::Error::last_os_error()) + } +} + +fn set_keepalive_vals(socket: TcpSocket, vals: &mstcpip::tcp_keepalive) -> io::Result<()> { + let mut out = 0; + match unsafe { WSAIoctl( + socket, + mstcpip::SIO_KEEPALIVE_VALS, + vals as *const _ as *mut mstcpip::tcp_keepalive as LPVOID, + size_of::() as DWORD, + ptr::null_mut() as LPVOID, + 0 as DWORD, + &mut out as *mut _ as LPDWORD, + ptr::null_mut() as LPWSAOVERLAPPED, + None, + ) } { + 0 => Ok(()), _ => Err(io::Error::last_os_error()) } } diff --git a/tests/tcp_socket.rs b/tests/tcp_socket.rs index 4870f00bc4..fec9887f20 100644 --- a/tests/tcp_socket.rs +++ b/tests/tcp_socket.rs @@ -42,25 +42,65 @@ fn set_reuseport() { #[test] fn set_keepalive() { + let addr = "127.0.0.1:0".parse().unwrap(); + + let socket = TcpSocket::new_v4().unwrap(); + socket.set_keepalive(false).unwrap(); + assert_eq!(false, socket.get_keepalive().unwrap()); + + socket.set_keepalive(true).unwrap(); + assert_eq!(true, socket.get_keepalive().unwrap()); + + socket.bind(addr).unwrap(); + + let _ = socket.listen(128).unwrap(); +} + +#[test] +fn set_keepalive_time() { + let dur = Duration::from_secs(4); // Chosen by fair dice roll, guaranteed to be random + let addr = "127.0.0.1:0".parse().unwrap(); + + let socket = TcpSocket::new_v4().unwrap(); + socket.set_keepalive(true).unwrap(); + socket.set_keepalive_time(dur).unwrap(); + assert_eq!(Some(dur), socket.get_keepalive_time().unwrap()); + + socket.bind(addr).unwrap(); + + let _ = socket.listen(128).unwrap(); +} + +#[cfg(any( + target_os = "linux", + target_os = "freebsd", + target_os = "netbsd", + target_os = "windows" +))] +#[test] +fn set_keepalive_interval() { let dur = Duration::from_secs(4); // Chosen by fair dice roll, guaranteed to be random let addr = "127.0.0.1:0".parse().unwrap(); let socket = TcpSocket::new_v4().unwrap(); - socket.set_keepalive(Some(dur)).unwrap(); - assert_eq!(Some(dur), socket.get_keepalive().unwrap()); + socket.set_keepalive(true).unwrap(); + socket.set_keepalive_interval(dur).unwrap(); + assert_eq!(Some(dur), socket.get_keepalive_interval().unwrap()); socket.bind(addr).unwrap(); let _ = socket.listen(128).unwrap(); } +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "netbsd",))] #[test] -fn set_keepalive_none() { +fn set_keepalive_retries() { let addr = "127.0.0.1:0".parse().unwrap(); let socket = TcpSocket::new_v4().unwrap(); - socket.set_keepalive(None).unwrap(); - assert_eq!(None, socket.get_keepalive().unwrap()); + socket.set_keepalive(true).unwrap(); + socket.set_keepalive_retries(16).unwrap(); + assert_eq!(Some(16), socket.get_keepalive_retries().unwrap()); socket.bind(addr).unwrap();