From 78d3d3790a56227ba915d559907f5e6db54640db Mon Sep 17 00:00:00 2001 From: Michael Date: Fri, 28 May 2021 20:32:42 +0100 Subject: [PATCH] Refactor windows sockets impl methods --- library/std/src/sys/windows/c.rs | 1 + library/std/src/sys/windows/net.rs | 305 ++++++++++++++++------------- 2 files changed, 173 insertions(+), 133 deletions(-) diff --git a/library/std/src/sys/windows/c.rs b/library/std/src/sys/windows/c.rs index b7efc884473b4..b64870401f1fd 100644 --- a/library/std/src/sys/windows/c.rs +++ b/library/std/src/sys/windows/c.rs @@ -234,6 +234,7 @@ pub const SD_RECEIVE: c_int = 0; pub const SD_SEND: c_int = 1; pub const SOCK_DGRAM: c_int = 2; pub const SOCK_STREAM: c_int = 1; +pub const SOCKET_ERROR: c_int = -1; pub const SOL_SOCKET: c_int = 0xffff; pub const SO_RCVTIMEO: c_int = 0x1006; pub const SO_SNDTIMEO: c_int = 0x1005; diff --git a/library/std/src/sys/windows/net.rs b/library/std/src/sys/windows/net.rs index 1ad13254c0846..9cea5c5e63a2d 100644 --- a/library/std/src/sys/windows/net.rs +++ b/library/std/src/sys/windows/net.rs @@ -12,7 +12,7 @@ use crate::sys_common::net; use crate::sys_common::{AsInner, FromInner, IntoInner}; use crate::time::Duration; -use libc::{c_int, c_long, c_ulong, c_void}; +use libc::{c_int, c_long, c_ulong}; pub type wrlen_t = i32; @@ -93,153 +93,177 @@ where impl Socket { pub fn new(addr: &SocketAddr, ty: c_int) -> io::Result { - let fam = match *addr { + let family = match *addr { SocketAddr::V4(..) => c::AF_INET, SocketAddr::V6(..) => c::AF_INET6, }; let socket = unsafe { - match c::WSASocketW( - fam, + c::WSASocketW( + family, ty, 0, ptr::null_mut(), 0, c::WSA_FLAG_OVERLAPPED | c::WSA_FLAG_NO_HANDLE_INHERIT, - ) { - c::INVALID_SOCKET => match c::WSAGetLastError() { - c::WSAEPROTOTYPE | c::WSAEINVAL => { - match c::WSASocketW(fam, ty, 0, ptr::null_mut(), 0, c::WSA_FLAG_OVERLAPPED) - { - c::INVALID_SOCKET => Err(last_error()), - n => { - let s = Socket(n); - s.set_no_inherit()?; - Ok(s) - } - } - } - n => Err(io::Error::from_raw_os_error(n)), - }, - n => Ok(Socket(n)), + ) + }; + + if socket != c::INVALID_SOCKET { + Ok(Self(socket)) + } else { + let error = unsafe { c::WSAGetLastError() }; + + if error != c::WSAEPROTOTYPE && error != c::WSAEINVAL { + return Err(io::Error::from_raw_os_error(error)); + } + + let socket = + unsafe { c::WSASocketW(family, ty, 0, ptr::null_mut(), 0, c::WSA_FLAG_OVERLAPPED) }; + + if socket == c::INVALID_SOCKET { + return Err(last_error()); } - }?; - Ok(socket) + + let socket = Self(socket); + socket.set_no_inherit()?; + Ok(socket) + } } pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> { self.set_nonblocking(true)?; - let r = unsafe { + let result = { let (addrp, len) = addr.into_inner(); - cvt(c::connect(self.0, addrp, len)) + let result = unsafe { c::connect(self.0, addrp, len) }; + cvt(result).map(drop) }; self.set_nonblocking(false)?; - match r { - Ok(_) => return Ok(()), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} - Err(e) => return Err(e), - } - - if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 { - return Err(io::Error::new_const( - io::ErrorKind::InvalidInput, - &"cannot set a 0 duration timeout", - )); - } - - let mut timeout = c::timeval { - tv_sec: timeout.as_secs() as c_long, - tv_usec: (timeout.subsec_nanos() / 1000) as c_long, - }; - if timeout.tv_sec == 0 && timeout.tv_usec == 0 { - timeout.tv_usec = 1; - } + match result { + Err(ref error) if error.kind() == io::ErrorKind::WouldBlock => { + if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 { + return Err(io::Error::new_const( + io::ErrorKind::InvalidInput, + &"cannot set a 0 duration timeout", + )); + } - let fds = unsafe { - let mut fds = mem::zeroed::(); - fds.fd_count = 1; - fds.fd_array[0] = self.0; - fds - }; + let mut timeout = c::timeval { + tv_sec: timeout.as_secs() as c_long, + tv_usec: (timeout.subsec_nanos() / 1000) as c_long, + }; - let mut writefds = fds; - let mut errorfds = fds; + if timeout.tv_sec == 0 && timeout.tv_usec == 0 { + timeout.tv_usec = 1; + } - let n = - unsafe { cvt(c::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout))? }; + let fds = { + let mut fds = unsafe { mem::zeroed::() }; + fds.fd_count = 1; + fds.fd_array[0] = self.0; + fds + }; + + let mut writefds = fds; + let mut errorfds = fds; + + let count = { + let result = unsafe { + c::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout) + }; + cvt(result)? + }; + + match count { + 0 => { + Err(io::Error::new_const(io::ErrorKind::TimedOut, &"connection timed out")) + } + _ => { + if writefds.fd_count != 1 { + if let Some(e) = self.take_error()? { + return Err(e); + } + } - match n { - 0 => Err(io::Error::new_const(io::ErrorKind::TimedOut, &"connection timed out")), - _ => { - if writefds.fd_count != 1 { - if let Some(e) = self.take_error()? { - return Err(e); + Ok(()) } } - Ok(()) } + _ => result, } } pub fn accept(&self, storage: *mut c::SOCKADDR, len: *mut c_int) -> io::Result { - let socket = unsafe { - match c::accept(self.0, storage, len) { - c::INVALID_SOCKET => Err(last_error()), - n => Ok(Socket(n)), - } - }?; - Ok(socket) + let socket = unsafe { c::accept(self.0, storage, len) }; + + match socket { + c::INVALID_SOCKET => Err(last_error()), + _ => Ok(Self(socket)), + } } pub fn duplicate(&self) -> io::Result { + let mut info = unsafe { mem::zeroed::() }; + let result = unsafe { c::WSADuplicateSocketW(self.0, c::GetCurrentProcessId(), &mut info) }; + cvt(result)?; let socket = unsafe { - let mut info: c::WSAPROTOCOL_INFO = mem::zeroed(); - cvt(c::WSADuplicateSocketW(self.0, c::GetCurrentProcessId(), &mut info))?; - - match c::WSASocketW( + c::WSASocketW( info.iAddressFamily, info.iSocketType, info.iProtocol, &mut info, 0, c::WSA_FLAG_OVERLAPPED | c::WSA_FLAG_NO_HANDLE_INHERIT, - ) { - c::INVALID_SOCKET => match c::WSAGetLastError() { - c::WSAEPROTOTYPE | c::WSAEINVAL => { - match c::WSASocketW( - info.iAddressFamily, - info.iSocketType, - info.iProtocol, - &mut info, - 0, - c::WSA_FLAG_OVERLAPPED, - ) { - c::INVALID_SOCKET => Err(last_error()), - n => { - let s = Socket(n); - s.set_no_inherit()?; - Ok(s) - } - } - } - n => Err(io::Error::from_raw_os_error(n)), - }, - n => Ok(Socket(n)), + ) + }; + + if socket != c::INVALID_SOCKET { + Ok(Self(socket)) + } else { + let error = unsafe { c::WSAGetLastError() }; + + if error != c::WSAEPROTOTYPE && error != c::WSAEINVAL { + return Err(io::Error::from_raw_os_error(error)); + } + + let socket = unsafe { + c::WSASocketW( + info.iAddressFamily, + info.iSocketType, + info.iProtocol, + &mut info, + 0, + c::WSA_FLAG_OVERLAPPED, + ) + }; + + if socket == c::INVALID_SOCKET { + return Err(last_error()); } - }?; - Ok(socket) + + let socket = Self(socket); + socket.set_no_inherit()?; + Ok(socket) + } } fn recv_with_flags(&self, buf: &mut [u8], flags: c_int) -> io::Result { // On unix when a socket is shut down all further reads return 0, so we // do the same on windows to map a shut down socket to returning EOF. - let len = cmp::min(buf.len(), i32::MAX as usize) as i32; - unsafe { - match c::recv(self.0, buf.as_mut_ptr() as *mut c_void, len, flags) { - -1 if c::WSAGetLastError() == c::WSAESHUTDOWN => Ok(0), - -1 => Err(last_error()), - n => Ok(n as usize), + let length = cmp::min(buf.len(), i32::MAX as usize) as i32; + let result = unsafe { c::recv(self.0, buf.as_mut_ptr() as *mut _, length, flags) }; + + match result { + c::SOCKET_ERROR => { + let error = unsafe { c::WSAGetLastError() }; + + if error == c::WSAESHUTDOWN { + Ok(0) + } else { + Err(io::Error::from_raw_os_error(error)) + } } + _ => Ok(result as usize), } } @@ -250,23 +274,31 @@ impl Socket { pub fn read_vectored(&self, bufs: &mut [IoSliceMut<'_>]) -> io::Result { // On unix when a socket is shut down all further reads return 0, so we // do the same on windows to map a shut down socket to returning EOF. - let len = cmp::min(bufs.len(), c::DWORD::MAX as usize) as c::DWORD; + let length = cmp::min(bufs.len(), c::DWORD::MAX as usize) as c::DWORD; let mut nread = 0; let mut flags = 0; - unsafe { - let ret = c::WSARecv( + let result = unsafe { + c::WSARecv( self.0, bufs.as_mut_ptr() as *mut c::WSABUF, - len, + length, &mut nread, &mut flags, ptr::null_mut(), ptr::null_mut(), - ); - match ret { - 0 => Ok(nread as usize), - _ if c::WSAGetLastError() == c::WSAESHUTDOWN => Ok(0), - _ => Err(last_error()), + ) + }; + + match result { + 0 => Ok(nread as usize), + _ => { + let error = unsafe { c::WSAGetLastError() }; + + if error == c::WSAESHUTDOWN { + Ok(0) + } else { + Err(io::Error::from_raw_os_error(error)) + } } } } @@ -285,27 +317,34 @@ impl Socket { buf: &mut [u8], flags: c_int, ) -> io::Result<(usize, SocketAddr)> { - let mut storage: c::SOCKADDR_STORAGE_LH = unsafe { mem::zeroed() }; + let mut storage = unsafe { mem::zeroed::() }; let mut addrlen = mem::size_of_val(&storage) as c::socklen_t; - let len = cmp::min(buf.len(), ::MAX as usize) as wrlen_t; + let length = cmp::min(buf.len(), ::MAX as usize) as wrlen_t; // On unix when a socket is shut down all further reads return 0, so we // do the same on windows to map a shut down socket to returning EOF. - unsafe { - match c::recvfrom( + let result = unsafe { + c::recvfrom( self.0, - buf.as_mut_ptr() as *mut c_void, - len, + buf.as_mut_ptr() as *mut _, + length, flags, &mut storage as *mut _ as *mut _, &mut addrlen, - ) { - -1 if c::WSAGetLastError() == c::WSAESHUTDOWN => { + ) + }; + + match result { + c::SOCKET_ERROR => { + let error = unsafe { c::WSAGetLastError() }; + + if error == c::WSAESHUTDOWN { Ok((0, net::sockaddr_to_addr(&storage, addrlen as usize)?)) + } else { + Err(io::Error::from_raw_os_error(error)) } - -1 => Err(last_error()), - n => Ok((n as usize, net::sockaddr_to_addr(&storage, addrlen as usize)?)), } + _ => Ok((result as usize, net::sockaddr_to_addr(&storage, addrlen as usize)?)), } } @@ -318,20 +357,20 @@ impl Socket { } pub fn write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result { - let len = cmp::min(bufs.len(), c::DWORD::MAX as usize) as c::DWORD; + let length = cmp::min(bufs.len(), c::DWORD::MAX as usize) as c::DWORD; let mut nwritten = 0; - unsafe { - cvt(c::WSASend( + let result = unsafe { + c::WSASend( self.0, - bufs.as_ptr() as *const c::WSABUF as *mut c::WSABUF, - len, + bufs.as_ptr() as *const c::WSABUF as *mut _, + length, &mut nwritten, 0, ptr::null_mut(), ptr::null_mut(), - ))?; - } - Ok(nwritten as usize) + ) + }; + cvt(result).map(|_| nwritten as usize) } #[inline] @@ -384,14 +423,14 @@ impl Socket { Shutdown::Read => c::SD_RECEIVE, Shutdown::Both => c::SD_BOTH, }; - cvt(unsafe { c::shutdown(self.0, how) })?; - Ok(()) + let result = unsafe { c::shutdown(self.0, how) }; + cvt(result).map(drop) } pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { let mut nonblocking = nonblocking as c_ulong; - let r = unsafe { c::ioctlsocket(self.0, c::FIONBIO as c_int, &mut nonblocking) }; - if r == 0 { Ok(()) } else { Err(io::Error::last_os_error()) } + let result = unsafe { c::ioctlsocket(self.0, c::FIONBIO as c_int, &mut nonblocking) }; + cvt(result).map(drop) } pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {