diff --git a/src/Common/src/Interop/Windows/Winsock/AddressInfoEx.cs b/src/Common/src/Interop/Windows/Winsock/AddressInfoEx.cs new file mode 100644 index 000000000000..0972101831a6 --- /dev/null +++ b/src/Common/src/Interop/Windows/Winsock/AddressInfoEx.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Net.Internals; +using System.Runtime.InteropServices; + +namespace System.Net.Sockets +{ + [StructLayout(LayoutKind.Sequential)] + internal unsafe struct AddressInfoEx + { + internal AddressInfoHints ai_flags; + internal AddressFamily ai_family; + internal SocketType ai_socktype; + internal ProtocolFamily ai_protocol; + internal int ai_addrlen; + internal IntPtr ai_canonname; // Ptr to the canonical name - check for NULL + internal byte* ai_addr; // Ptr to the sockaddr structure + internal IntPtr ai_blob; // Unused ptr to blob data about provider + internal int ai_bloblen; + internal IntPtr ai_provider; // Unused ptr to the namespace provider guid + internal AddressInfoEx* ai_next; // Next structure in linked list + } +} diff --git a/src/Common/src/Interop/Windows/Winsock/Interop.GetAddrInfoExW.cs b/src/Common/src/Interop/Windows/Winsock/Interop.GetAddrInfoExW.cs new file mode 100644 index 000000000000..cb2070de81e5 --- /dev/null +++ b/src/Common/src/Interop/Windows/Winsock/Interop.GetAddrInfoExW.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Threading; + +internal static partial class Interop +{ + internal static partial class Winsock + { + internal const string GetAddrInfoExCancelFunctionName = "GetAddrInfoExCancel"; + + internal unsafe delegate void LPLOOKUPSERVICE_COMPLETION_ROUTINE([In] int dwError, [In] int dwBytes, [In] NativeOverlapped* lpOverlapped); + + [DllImport(Interop.Libraries.Ws2_32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern unsafe int GetAddrInfoExW( + [In] string pName, + [In] string pServiceName, + [In] int dwNamespace, + [In] IntPtr lpNspId, + [In] ref AddressInfoEx pHints, + [Out] out AddressInfoEx* ppResult, + [In] IntPtr timeout, + [In] ref NativeOverlapped lpOverlapped, + [In] LPLOOKUPSERVICE_COMPLETION_ROUTINE lpCompletionRoutine, + [Out] out IntPtr lpNameHandle + ); + + [DllImport("ws2_32.dll", ExactSpelling = true, SetLastError = true)] + internal static extern unsafe void FreeAddrInfoEx([In] AddressInfoEx* pAddrInfo); + } +} + diff --git a/src/Common/src/Interop/Windows/kernel32/Interop.LoadLibraryEx.cs b/src/Common/src/Interop/Windows/kernel32/Interop.LoadLibraryEx.cs index 4ba2fd65a453..d24d4d3c6a88 100644 --- a/src/Common/src/Interop/Windows/kernel32/Interop.LoadLibraryEx.cs +++ b/src/Common/src/Interop/Windows/kernel32/Interop.LoadLibraryEx.cs @@ -12,6 +12,7 @@ internal partial class Interop internal partial class Kernel32 { public const int LOAD_LIBRARY_AS_DATAFILE = 0x00000002; + public const int LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800; [DllImport(Libraries.Kernel32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)] public static extern SafeLibraryHandle LoadLibraryExW([In] string lpwLibFileName, [In] IntPtr hFile, [In] uint dwFlags); diff --git a/src/Common/src/System/Net/SocketAddressPal.Unix.cs b/src/Common/src/System/Net/SocketAddressPal.Unix.cs index f7eaf04fbde2..94360635e210 100644 --- a/src/Common/src/System/Net/SocketAddressPal.Unix.cs +++ b/src/Common/src/System/Net/SocketAddressPal.Unix.cs @@ -102,11 +102,11 @@ public static unsafe void SetPort(byte[] buffer, ushort port) ThrowOnFailure(err); } - public static unsafe uint GetIPv4Address(byte[] buffer) + public static unsafe uint GetIPv4Address(ReadOnlySpan buffer) { uint ipAddress; Interop.Error err; - fixed (byte* rawAddress = buffer) + fixed (byte* rawAddress = &MemoryMarshal.GetReference(buffer)) { err = Interop.Sys.GetIPv4Address(rawAddress, buffer.Length, &ipAddress); } @@ -115,11 +115,11 @@ public static unsafe uint GetIPv4Address(byte[] buffer) return ipAddress; } - public static unsafe void GetIPv6Address(byte[] buffer, Span address, out uint scope) + public static unsafe void GetIPv6Address(ReadOnlySpan buffer, Span address, out uint scope) { uint localScope; Interop.Error err; - fixed (byte* rawAddress = buffer) + fixed (byte* rawAddress = &MemoryMarshal.GetReference(buffer)) fixed (byte* ipAddress = &MemoryMarshal.GetReference(address)) { err = Interop.Sys.GetIPv6Address(rawAddress, buffer.Length, ipAddress, address.Length, &localScope); diff --git a/src/Common/src/System/Net/SocketAddressPal.Windows.cs b/src/Common/src/System/Net/SocketAddressPal.Windows.cs index 32685d48826b..404a381ae8d9 100644 --- a/src/Common/src/System/Net/SocketAddressPal.Windows.cs +++ b/src/Common/src/System/Net/SocketAddressPal.Windows.cs @@ -38,7 +38,7 @@ public static unsafe void SetPort(byte[] buffer, ushort port) port.HostToNetworkBytes(buffer, 2); } - public static unsafe uint GetIPv4Address(byte[] buffer) + public static unsafe uint GetIPv4Address(ReadOnlySpan buffer) { unchecked { @@ -49,7 +49,7 @@ public static unsafe uint GetIPv4Address(byte[] buffer) } } - public static unsafe void GetIPv6Address(byte[] buffer, Span address, out uint scope) + public static unsafe void GetIPv6Address(ReadOnlySpan buffer, Span address, out uint scope) { for (int i = 0; i < address.Length; i++) { diff --git a/src/System.Net.NameResolution/src/System.Net.NameResolution.csproj b/src/System.Net.NameResolution/src/System.Net.NameResolution.csproj index ef34f9675be7..f6887340b60b 100644 --- a/src/System.Net.NameResolution/src/System.Net.NameResolution.csproj +++ b/src/System.Net.NameResolution/src/System.Net.NameResolution.csproj @@ -5,7 +5,7 @@ System.Net.NameResolution {1714448C-211E-48C1-8B7E-4EE667D336A1} true - + @@ -46,6 +46,10 @@ Common\System\Net\IPEndPointStatics.cs + + Common\System\Net\ByteOrder.cs + + @@ -72,6 +76,9 @@ Common\System\Net\SocketProtocolSupportPal.Windows + + Common\System\Net\SocketAddressPal.Windows + Interop\Windows\Interop.Libraries.cs @@ -118,12 +125,27 @@ Interop\Windows\Winsock\SafeFreeAddrInfo.cs + + Interop\Windows\Winsock\AddressInfoEx.cs + + + Interop\Windows\Winsock\Interop.GetAddrInfoExW.cs + + + Common\Microsoft\Win32\SafeHandles\SafeLibraryHandle.cs + + + Interop\Windows\Kernel32\Interop.GetProcAddress.cs + + + Interop\Windows\Kernel32\Interop.LoadLibraryEx.cs + + + Interop\Windows\Kernel32\Interop.FreeLibrary.cs + - - Common\System\Net\Internals\ByteOrder.cs - Common\System\Net\ContextAwareResult.Unix.cs @@ -189,7 +211,8 @@ + - + \ No newline at end of file diff --git a/src/System.Net.NameResolution/src/System/Net/DNS.cs b/src/System.Net.NameResolution/src/System/Net/DNS.cs index a51ca67828c1..f15f811f6143 100644 --- a/src/System.Net.NameResolution/src/System/Net/DNS.cs +++ b/src/System.Net.NameResolution/src/System/Net/DNS.cs @@ -41,17 +41,22 @@ public static IPHostEntry GetHostByName(string hostName) return InternalGetHostByName(hostName, false); } - private static IPHostEntry InternalGetHostByName(string hostName, bool includeIPv6) + private static void ValidateHostName(string hostName) { - if (NetEventSource.IsEnabled) NetEventSource.Enter(null, hostName); - IPHostEntry ipHostEntry = null; - if (hostName.Length > MaxHostName // If 255 chars, the last one must be a dot. || hostName.Length == MaxHostName && hostName[MaxHostName - 1] != '.') { throw new ArgumentOutOfRangeException(nameof(hostName), SR.Format(SR.net_toolong, nameof(hostName), MaxHostName.ToString(NumberFormatInfo.CurrentInfo))); } + } + + private static IPHostEntry InternalGetHostByName(string hostName, bool includeIPv6) + { + if (NetEventSource.IsEnabled) NetEventSource.Enter(null, hostName); + IPHostEntry ipHostEntry = null; + + ValidateHostName(hostName); // // IPv6 Changes: IPv6 requires the use of getaddrinfo() rather @@ -252,42 +257,19 @@ public static IPHostEntry Resolve(string hostName) return ipHostEntry; } - private class ResolveAsyncResult : ContextAwareResult - { - // Forward lookup - internal ResolveAsyncResult(string hostName, object myObject, bool includeIPv6, object myState, AsyncCallback myCallBack) : - base(myObject, myState, myCallBack) - { - this.hostName = hostName; - this.includeIPv6 = includeIPv6; - } - - // Reverse lookup - internal ResolveAsyncResult(IPAddress address, object myObject, bool includeIPv6, object myState, AsyncCallback myCallBack) : - base(myObject, myState, myCallBack) - { - this.includeIPv6 = includeIPv6; - this.address = address; - } - - internal readonly string hostName; - internal bool includeIPv6; - internal IPAddress address; - } - private static void ResolveCallback(object context) { - ResolveAsyncResult result = (ResolveAsyncResult)context; + DnsResolveAsyncResult result = (DnsResolveAsyncResult)context; IPHostEntry hostEntry; try { - if (result.address != null) + if (result.IpAddress != null) { - hostEntry = InternalGetHostByAddress(result.address, result.includeIPv6); + hostEntry = InternalGetHostByAddress(result.IpAddress, result.IncludeIPv6); } else { - hostEntry = InternalGetHostByName(result.hostName, result.includeIPv6); + hostEntry = InternalGetHostByName(result.HostName, result.IncludeIPv6); } } catch (OutOfMemoryException) @@ -315,20 +297,20 @@ private static IAsyncResult HostResolutionBeginHelper(string hostName, bool just if (NetEventSource.IsEnabled) NetEventSource.Info(null, hostName); // See if it's an IP Address. - IPAddress address; - ResolveAsyncResult asyncResult; - if (IPAddress.TryParse(hostName, out address)) + IPAddress ipAddress; + DnsResolveAsyncResult asyncResult; + if (IPAddress.TryParse(hostName, out ipAddress)) { - if (throwOnIIPAny && (address.Equals(IPAddress.Any) || address.Equals(IPAddress.IPv6Any))) + if (throwOnIIPAny && (ipAddress.Equals(IPAddress.Any) || ipAddress.Equals(IPAddress.IPv6Any))) { throw new ArgumentException(SR.net_invalid_ip_addr, nameof(hostName)); } - asyncResult = new ResolveAsyncResult(address, null, includeIPv6, state, requestCallback); + asyncResult = new DnsResolveAsyncResult(ipAddress, null, includeIPv6, state, requestCallback); if (justReturnParsedIp) { - IPHostEntry hostEntry = NameResolutionUtilities.GetUnresolvedAnswer(address); + IPHostEntry hostEntry = NameResolutionUtilities.GetUnresolvedAnswer(ipAddress); asyncResult.StartPostingAsyncOp(false); asyncResult.InvokeCallback(hostEntry); asyncResult.FinishPostingAsyncOp(); @@ -337,19 +319,29 @@ private static IAsyncResult HostResolutionBeginHelper(string hostName, bool just } else { - asyncResult = new ResolveAsyncResult(hostName, null, includeIPv6, state, requestCallback); + asyncResult = new DnsResolveAsyncResult(hostName, null, includeIPv6, state, requestCallback); } // Set up the context, possibly flow. asyncResult.StartPostingAsyncOp(false); - // Start the resolve. - Task.Factory.StartNew( - s => ResolveCallback(s), - asyncResult, - CancellationToken.None, - TaskCreationOptions.DenyChildAttach, - TaskScheduler.Default); + // If the OS supports it and 'hostName' is not an IP Address, resolve the name asynchronously + // instead of calling the synchronous version in the ThreadPool. + if (NameResolutionPal.SupportsGetAddrInfoAsync && ipAddress == null) + { + ValidateHostName(hostName); + NameResolutionPal.GetAddrInfoAsync(asyncResult); + } + else + { + // Start the resolve. + Task.Factory.StartNew( + s => ResolveCallback(s), + asyncResult, + CancellationToken.None, + TaskCreationOptions.DenyChildAttach, + TaskScheduler.Default); + } // Finish the flowing, maybe it completed? This does nothing if we didn't initiate the flowing above. asyncResult.FinishPostingAsyncOp(); @@ -371,7 +363,7 @@ private static IAsyncResult HostResolutionBeginHelper(IPAddress address, bool fl if (NetEventSource.IsEnabled) NetEventSource.Info(null, address); // Set up the context, possibly flow. - ResolveAsyncResult asyncResult = new ResolveAsyncResult(address, null, includeIPv6, state, requestCallback); + DnsResolveAsyncResult asyncResult = new DnsResolveAsyncResult(address, null, includeIPv6, state, requestCallback); if (flowContext) { asyncResult.StartPostingAsyncOp(false); @@ -399,7 +391,7 @@ private static IPHostEntry HostResolutionEndHelper(IAsyncResult asyncResult) { throw new ArgumentNullException(nameof(asyncResult)); } - ResolveAsyncResult castedResult = asyncResult as ResolveAsyncResult; + DnsResolveAsyncResult castedResult = asyncResult as DnsResolveAsyncResult; if (castedResult == null) { throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult)); @@ -611,7 +603,7 @@ public static IPHostEntry EndResolve(IAsyncResult asyncResult) } catch (SocketException ex) { - IPAddress address = ((ResolveAsyncResult)asyncResult).address; + IPAddress address = ((DnsResolveAsyncResult)asyncResult).IpAddress; if (address == null) throw; // BeginResolve was called with a HostName, not an IPAddress diff --git a/src/System.Net.NameResolution/src/System/Net/DnsResolveAsyncResult.cs b/src/System.Net.NameResolution/src/System/Net/DnsResolveAsyncResult.cs new file mode 100644 index 000000000000..ca4f700af124 --- /dev/null +++ b/src/System.Net.NameResolution/src/System/Net/DnsResolveAsyncResult.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Net +{ + internal sealed class DnsResolveAsyncResult : ContextAwareResult + { + internal string HostName { get; } + internal bool IncludeIPv6 { get; } + internal IPAddress IpAddress { get; } + + // Forward lookup + internal DnsResolveAsyncResult(string hostName, object myObject, bool includeIPv6, object myState, AsyncCallback myCallBack) + : base(myObject, myState, myCallBack) + { + HostName = hostName; + IncludeIPv6 = includeIPv6; + } + + // Reverse lookup + internal DnsResolveAsyncResult(IPAddress ipAddress, object myObject, bool includeIPv6, object myState, AsyncCallback myCallBack) + : base(myObject, myState, myCallBack) + { + IncludeIPv6 = includeIPv6; + IpAddress = ipAddress; + } + } +} diff --git a/src/System.Net.NameResolution/src/System/Net/NameResolutionPal.Unix.cs b/src/System.Net.NameResolution/src/System/Net/NameResolutionPal.Unix.cs index 139e2740f9fd..e823898e4d39 100644 --- a/src/System.Net.NameResolution/src/System/Net/NameResolutionPal.Unix.cs +++ b/src/System.Net.NameResolution/src/System/Net/NameResolutionPal.Unix.cs @@ -13,6 +13,8 @@ namespace System.Net { internal static partial class NameResolutionPal { + public const bool SupportsGetAddrInfoAsync = false; + private static SocketError GetSocketErrorForErrno(int errno) { switch (errno) @@ -194,6 +196,11 @@ public static unsafe SocketError TryGetAddrInfo(string name, out IPHostEntry hos return SocketError.Success; } + internal static void GetAddrInfoAsync(DnsResolveAsyncResult asyncResult) + { + throw new NotSupportedException(); + } + public static unsafe string TryGetNameInfo(IPAddress addr, out SocketError socketError, out int nativeErrorCode) { byte* buffer = stackalloc byte[Interop.Sys.NI_MAXHOST + 1 /*for null*/]; diff --git a/src/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs b/src/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs index 36a69f8637aa..1b01aa44ecca 100644 --- a/src/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs +++ b/src/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs @@ -7,6 +7,8 @@ using System.Runtime.InteropServices; using System.Text; using System.Threading; +using System.Threading.Tasks; +using Microsoft.Win32.SafeHandles; using ProtocolFamily = System.Net.Internals.ProtocolFamily; namespace System.Net @@ -17,9 +19,22 @@ internal static class NameResolutionPal // used by GetHostName() to preallocate a buffer for the call to gethostname. // private const int HostNameBufferLength = 256; + private static bool s_initialized; private static readonly object s_initializedLock = new object(); + private static readonly unsafe Interop.Winsock.LPLOOKUPSERVICE_COMPLETION_ROUTINE s_getAddrInfoExCallback = GetAddressInfoExCallback; + private static bool s_getAddrInfoExSupported; + + public static bool SupportsGetAddrInfoAsync + { + get + { + EnsureSocketsAreInitialized(); + return s_getAddrInfoExSupported; + } + } + /*++ Routine Description: @@ -232,7 +247,6 @@ public static unsafe SocketError TryGetAddrInfo(string name, out IPHostEntry hos // while (pAddressInfo != null) { - SocketAddress sockaddr; // // Retrieve the canonical name for the host - only appears in the first AddressInfo // entry in the returned array. @@ -247,29 +261,17 @@ public static unsafe SocketError TryGetAddrInfo(string name, out IPHostEntry hos // We also filter based on whether IPv6 is supported on the current // platform / machine. // - if ((pAddressInfo->ai_family == AddressFamily.InterNetwork) || // Never filter v4 - (pAddressInfo->ai_family == AddressFamily.InterNetworkV6 && SocketProtocolSupportPal.OSSupportsIPv6)) + var socketAddress = new ReadOnlySpan(pAddressInfo->ai_addr, pAddressInfo->ai_addrlen); + + if (pAddressInfo->ai_family == AddressFamily.InterNetwork) { - sockaddr = new SocketAddress(pAddressInfo->ai_family, pAddressInfo->ai_addrlen); - // - // Push address data into the socket address buffer - // - for (int d = 0; d < pAddressInfo->ai_addrlen; d++) - { - sockaddr[d] = *(pAddressInfo->ai_addr + d); - } - // - // NOTE: We need an IPAddress now, the only way to create it from a - // SocketAddress is via IPEndPoint. This ought to be simpler. - // - if (pAddressInfo->ai_family == AddressFamily.InterNetwork) - { - addresses.Add(((IPEndPoint)IPEndPointStatics.Any.Create(sockaddr)).Address); - } - else - { - addresses.Add(((IPEndPoint)IPEndPointStatics.IPv6Any.Create(sockaddr)).Address); - } + if (socketAddress.Length == SocketAddressPal.IPv4AddressSize) + addresses.Add(CreateIPv4Address(socketAddress)); + } + else if (pAddressInfo->ai_family == AddressFamily.InterNetworkV6 && SocketProtocolSupportPal.OSSupportsIPv6) + { + if (socketAddress.Length == SocketAddressPal.IPv6AddressSize) + addresses.Add(CreateIPv6Address(socketAddress)); } // // Next addressinfo entry @@ -385,10 +387,206 @@ public static void EnsureSocketsAreInitialized() throw new SocketException((int)errorCode); } + s_getAddrInfoExSupported = GetAddrInfoExSupportsOverlapped(); + Volatile.Write(ref s_initialized, true); } } } } + + private static bool GetAddrInfoExSupportsOverlapped() + { + using (SafeLibraryHandle libHandle = Interop.Kernel32.LoadLibraryExW(Interop.Libraries.Ws2_32, IntPtr.Zero, Interop.Kernel32.LOAD_LIBRARY_SEARCH_SYSTEM32)) + { + if (libHandle.IsInvalid) + return false; + + // We can't just check that 'GetAddrInfoEx' exists, because it existed before supporting overlapped. + // The existance of 'GetAddrInfoExCancel' indicates that overlapped is supported. + return Interop.Kernel32.GetProcAddress(libHandle, Interop.Winsock.GetAddrInfoExCancelFunctionName) != IntPtr.Zero; + } + } + + public static unsafe void GetAddrInfoAsync(DnsResolveAsyncResult asyncResult) + { + GetAddrInfoExContext* context = GetAddrInfoExContext.AllocateContext(); + + try + { + var state = new GetAddrInfoExState(asyncResult); + context->QueryStateHandle = state.CreateHandle(); + } + catch + { + GetAddrInfoExContext.FreeContext(context); + throw; + } + + AddressInfoEx hints = new AddressInfoEx(); + hints.ai_flags = AddressInfoHints.AI_CANONNAME; + hints.ai_family = AddressFamily.Unspecified; // Gets all address families + + SocketError errorCode = + (SocketError)Interop.Winsock.GetAddrInfoExW(asyncResult.HostName, null, 0 /* NS_ALL*/, IntPtr.Zero, ref hints, out context->Result, IntPtr.Zero, ref context->Overlapped, s_getAddrInfoExCallback, out context->CancelHandle); + + if (errorCode != SocketError.IOPending) + ProcessResult(errorCode, context); + } + + private static unsafe void GetAddressInfoExCallback([In] int error, [In] int bytes, [In] NativeOverlapped* overlapped) + { + // Can be casted directly to GetAddrInfoExContext* because the overlapped is its first field + GetAddrInfoExContext* context = (GetAddrInfoExContext*)overlapped; + + ProcessResult((SocketError)error, context); + } + + private static unsafe void ProcessResult(SocketError errorCode, GetAddrInfoExContext* context) + { + try + { + GetAddrInfoExState state = GetAddrInfoExState.FromHandleAndFree(context->QueryStateHandle); + + if (errorCode != SocketError.Success) + { + state.CompleteAsyncResult(new SocketException((int)errorCode)); + return; + } + + AddressInfoEx* result = context->Result; + string canonicalName = null; + + List addresses = new List(); + + while (result != null) + { + if (canonicalName == null && result->ai_canonname != IntPtr.Zero) + canonicalName = Marshal.PtrToStringUni(result->ai_canonname); + + var socketAddress = new ReadOnlySpan(result->ai_addr, result->ai_addrlen); + + if (result->ai_family == AddressFamily.InterNetwork) + { + if (socketAddress.Length == SocketAddressPal.IPv4AddressSize) + addresses.Add(CreateIPv4Address(socketAddress)); + } + else if (SocketProtocolSupportPal.OSSupportsIPv6 && result->ai_family == AddressFamily.InterNetworkV6) + { + if (socketAddress.Length == SocketAddressPal.IPv6AddressSize) + addresses.Add(CreateIPv6Address(socketAddress)); + } + + result = result->ai_next; + } + + if (canonicalName == null) + canonicalName = state.HostName; + + state.CompleteAsyncResult(new IPHostEntry + { + HostName = canonicalName, + Aliases = Array.Empty(), + AddressList = addresses.ToArray() + }); + } + finally + { + GetAddrInfoExContext.FreeContext(context); + } + } + + private static unsafe IPAddress CreateIPv4Address(ReadOnlySpan socketAddress) + { + long address = (long)SocketAddressPal.GetIPv4Address(socketAddress) & 0x0FFFFFFFF; + return new IPAddress(address); + } + + private static unsafe IPAddress CreateIPv6Address(ReadOnlySpan socketAddress) + { + Span address = stackalloc byte[IPAddressParserStatics.IPv6AddressBytes]; + uint scope; + SocketAddressPal.GetIPv6Address(socketAddress, address, out scope); + + return new IPAddress(address, (long)scope); + } + + #region GetAddrInfoAsync Helper Classes + + // + // Warning: If this ever ported to NETFX, AppDomain unloads needs to be handled + // to protect against AppDomainUnloadException if there are pending operations. + // + + private sealed class GetAddrInfoExState + { + private DnsResolveAsyncResult _asyncResult; + private object _result; + + public string HostName => _asyncResult.HostName; + + public GetAddrInfoExState(DnsResolveAsyncResult asyncResult) + { + _asyncResult = asyncResult; + } + + public void CompleteAsyncResult(object o) + { + // We don't want to expose the GetAddrInfoEx callback thread to user code. + // The callback occurs in a native windows thread pool. + + _result = o; + + Task.Factory.StartNew(s => + { + var self = (GetAddrInfoExState)s; + self._asyncResult.InvokeCallback(self._result); + }, this, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); + } + + public IntPtr CreateHandle() + { + GCHandle handle = GCHandle.Alloc(this, GCHandleType.Normal); + return GCHandle.ToIntPtr(handle); + } + + public static GetAddrInfoExState FromHandleAndFree(IntPtr handle) + { + GCHandle gcHandle = GCHandle.FromIntPtr(handle); + var state = (GetAddrInfoExState)gcHandle.Target; + gcHandle.Free(); + + return state; + } + } + + [StructLayout(LayoutKind.Sequential)] + private unsafe struct GetAddrInfoExContext + { + private static readonly int Size = sizeof(GetAddrInfoExContext); + + public NativeOverlapped Overlapped; + public AddressInfoEx* Result; + public IntPtr CancelHandle; + public IntPtr QueryStateHandle; + + public static GetAddrInfoExContext* AllocateContext() + { + var context = (GetAddrInfoExContext*)Marshal.AllocHGlobal(Size); + *context = default; + + return context; + } + + public static void FreeContext(GetAddrInfoExContext* context) + { + if (context->Result != null) + Interop.Winsock.FreeAddrInfoEx(context->Result); + + Marshal.FreeHGlobal((IntPtr)context); + } + } + + #endregion } } diff --git a/src/System.Net.NameResolution/tests/PalTests/Configurations.props b/src/System.Net.NameResolution/tests/PalTests/Configurations.props index eddfd3a9acc5..1040c9ba37f0 100644 --- a/src/System.Net.NameResolution/tests/PalTests/Configurations.props +++ b/src/System.Net.NameResolution/tests/PalTests/Configurations.props @@ -2,8 +2,6 @@ - netstandard-Windows_NT; - netstandard-Unix; netcoreapp-Windows_NT; netcoreapp-Unix; diff --git a/src/System.Net.NameResolution/tests/PalTests/Fakes/FakeContextAwareResult.cs b/src/System.Net.NameResolution/tests/PalTests/Fakes/FakeContextAwareResult.cs new file mode 100644 index 000000000000..8a9615c9d2da --- /dev/null +++ b/src/System.Net.NameResolution/tests/PalTests/Fakes/FakeContextAwareResult.cs @@ -0,0 +1,91 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Threading; + +namespace System.Net +{ + internal partial class ContextAwareResult : IAsyncResult + { + private AsyncCallback _callback; + + private static Func _resultFactory; + + public static void FakeSetResultFactory(Func resultFactory) + { + _resultFactory = resultFactory; + } + + public object AsyncState + { + get + { + throw new NotImplementedException(); + } + } + + internal bool EndCalled + { + get; + set; + } + + internal object Result + { + get + { + return _resultFactory?.Invoke(); + } + } + + public WaitHandle AsyncWaitHandle + { + get + { + throw new NotImplementedException(); + } + } + + public bool CompletedSynchronously + { + get + { + // Simulate sync completion: + return true; + } + } + + public bool IsCompleted + { + get + { + throw new NotImplementedException(); + } + } + + internal ContextAwareResult(object myObject, object myState, AsyncCallback myCallBack) + { + _callback = myCallBack; + } + + internal object StartPostingAsyncOp(bool lockCapture) + { + return null; + } + + internal bool FinishPostingAsyncOp() + { + return true; + } + + internal void InvokeCallback(object result) + { + _callback.Invoke(this); + } + + internal void InternalWaitForCompletion() { } + + + } +} diff --git a/src/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj b/src/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj index 333aa11cc124..e01b311648b6 100644 --- a/src/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj +++ b/src/System.Net.NameResolution/tests/PalTests/System.Net.NameResolution.Pal.Tests.csproj @@ -24,6 +24,7 @@ + @@ -34,6 +35,9 @@ ProductionCode\System\Net\NameResolutionUtilities.cs + + ProductionCode\System\Net\DnsResolveAsyncResult.cs + Common\System\Net\Sockets\ProtocolType.cs @@ -51,6 +55,9 @@ Common\System\Net\Configuration.Http.cs + + + Common\System\Net\ByteOrder.cs @@ -65,6 +72,9 @@ System\Net\SocketProtocolSupportPal.Windows + + + Common\System\Net\SocketAddressPal.Windows @@ -116,6 +126,24 @@ Interop\Windows\Winsock\SafeFreeAddrInfo.cs + + Interop\Windows\Winsock\AddressInfoEx.cs + + + Interop\Windows\Winsock\Interop.GetAddrInfoExW.cs + + + Common\Microsoft\Win32\SafeHandles\SafeLibraryHandle.cs + + + Interop\Windows\Kernel32\Interop.GetProcAddress.cs + + + Interop\Windows\Kernel32\Interop.LoadLibraryEx.cs + + + Interop\Windows\Kernel32\Interop.FreeLibrary.cs + @@ -124,9 +152,6 @@ Common\System\Net\Internals\Interop.CheckedAccess.cs - - Common\System\Net\Internals\ByteOrder.cs - Common\System\Net\InteropIPAddressExtensions.Unix.cs @@ -171,4 +196,4 @@ - + \ No newline at end of file diff --git a/src/System.Net.NameResolution/tests/UnitTests/Fakes/FakeNameResolutionPal.cs b/src/System.Net.NameResolution/tests/UnitTests/Fakes/FakeNameResolutionPal.cs index 74da893f328e..69d1bc992961 100644 --- a/src/System.Net.NameResolution/tests/UnitTests/Fakes/FakeNameResolutionPal.cs +++ b/src/System.Net.NameResolution/tests/UnitTests/Fakes/FakeNameResolutionPal.cs @@ -10,6 +10,8 @@ namespace System.Net { internal static class NameResolutionPal { + public static bool SupportsGetAddrInfoAsync => false; + internal static int FakesEnsureSocketsAreInitializedCallCount { get; @@ -49,6 +51,11 @@ internal static string TryGetNameInfo(IPAddress address, out SocketError errorCo throw new NotImplementedException(); } + internal static void GetAddrInfoAsync(DnsResolveAsyncResult asyncResult) + { + throw new NotImplementedException(); + } + internal static IPHostEntry GetHostByAddr(IPAddress address) { throw new NotImplementedException(); diff --git a/src/System.Net.NameResolution/tests/UnitTests/System.Net.NameResolution.Unit.Tests.csproj b/src/System.Net.NameResolution/tests/UnitTests/System.Net.NameResolution.Unit.Tests.csproj index e99d16ebed38..b52ff2aafce0 100644 --- a/src/System.Net.NameResolution/tests/UnitTests/System.Net.NameResolution.Unit.Tests.csproj +++ b/src/System.Net.NameResolution/tests/UnitTests/System.Net.NameResolution.Unit.Tests.csproj @@ -26,6 +26,9 @@ ProductionCode\System\Net\DNS.cs + + + ProductionCode\System\Net\DnsResolveAsyncResult.cs