diff --git a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs index e505f3ef8475c..75aa0d4aa8f60 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.OpenSsl.cs @@ -253,15 +253,15 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50 return context; } - internal static bool DoSslHandshake(SafeSslHandle context, byte[] recvBuf, int recvOffset, int recvCount, out byte[] sendBuf, out int sendCount) + internal static bool DoSslHandshake(SafeSslHandle context, ReadOnlySpan input, out byte[] sendBuf, out int sendCount) { sendBuf = null; sendCount = 0; Exception handshakeException = null; - if ((recvBuf != null) && (recvCount > 0)) + if (input.Length > 0) { - if (BioWrite(context.InputBio, recvBuf, recvOffset, recvCount) <= 0) + if (Ssl.BioWrite(context.InputBio, ref MemoryMarshal.GetReference(input), input.Length) != input.Length) { // Make sure we clear out the error that is stored in the queue throw Crypto.CreateOpenSslCryptographicException(); @@ -321,7 +321,7 @@ internal static bool DoSslHandshake(SafeSslHandle context, byte[] recvBuf, int r return stateOk; } - internal static int Encrypt(SafeSslHandle context, ReadOnlyMemory input, ref byte[] output, out Ssl.SslErrorCode errorCode) + internal static int Encrypt(SafeSslHandle context, ReadOnlySpan input, ref byte[] output, out Ssl.SslErrorCode errorCode) { #if DEBUG ulong assertNoError = Crypto.ErrPeekError(); @@ -334,13 +334,7 @@ internal static int Encrypt(SafeSslHandle context, ReadOnlyMemory input, r lock (context) { - unsafe - { - using (MemoryHandle handle = input.Pin()) - { - retVal = Ssl.SslWrite(context, (byte*)handle.Pointer, input.Length); - } - } + retVal = Ssl.SslWrite(context, ref MemoryMarshal.GetReference(input), input.Length); if (retVal != input.Length) { diff --git a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs index 02148f3ea3a80..2246b5f4bf9a7 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs @@ -72,7 +72,7 @@ internal static byte[] SslGetAlpnSelected(SafeSslHandle ssl) } [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslWrite")] - internal static extern unsafe int SslWrite(SafeSslHandle ssl, byte* buf, int num); + internal static extern unsafe int SslWrite(SafeSslHandle ssl, ref byte buf, int num); [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslRead")] internal static extern unsafe int SslRead(SafeSslHandle ssl, byte* buf, int num); @@ -101,6 +101,9 @@ internal static byte[] SslGetAlpnSelected(SafeSslHandle ssl) [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioWrite")] internal static extern unsafe int BioWrite(SafeBioHandle b, byte* data, int len); + [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_BioWrite")] + internal static extern unsafe int BioWrite(SafeBioHandle b, ref byte data, int len); + [DllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetPeerCertificate")] internal static extern SafeX509Handle SslGetPeerCertificate(SafeSslHandle ssl); diff --git a/src/libraries/Common/tests/System/Net/Capability.Security.cs b/src/libraries/Common/tests/System/Net/Capability.Security.cs index 3de857a3f436e..fe78e01da7d23 100644 --- a/src/libraries/Common/tests/System/Net/Capability.Security.cs +++ b/src/libraries/Common/tests/System/Net/Capability.Security.cs @@ -50,6 +50,18 @@ public static bool Http2ForceUnencryptedLoopback() { return true; } + + return false; + } + + public static bool SecurityForceSocketStreams() + { + string value = Configuration.Security.SecurityForceSocketStreams; + if (value != null && (value.Equals("true", StringComparison.OrdinalIgnoreCase) || value.Equals("1"))) + { + return true; + } + return false; } diff --git a/src/libraries/Common/tests/System/Net/Configuration.Security.cs b/src/libraries/Common/tests/System/Net/Configuration.Security.cs index c60e158a93dd2..b09ad6785fe53 100644 --- a/src/libraries/Common/tests/System/Net/Configuration.Security.cs +++ b/src/libraries/Common/tests/System/Net/Configuration.Security.cs @@ -35,6 +35,8 @@ public static partial class Security // 127.0.0.1 testclienteku.contoso.com public static string HostsFileNamesInstalled => GetValue("COREFX_NET_SECURITY_HOSTS_FILE_INSTALLED"); + // Allows packet captures. + public static string SecurityForceSocketStreams => GetValue("COREFX_NET_SECURITY_FORCE_SOCKET_STREAMS"); } } } diff --git a/src/libraries/System.Net.Security/src/System/Net/HelperAsyncResults.cs b/src/libraries/System.Net.Security/src/System/Net/HelperAsyncResults.cs index 5d87ae5a0e132..a6a2ffdfdc071 100644 --- a/src/libraries/System.Net.Security/src/System/Net/HelperAsyncResults.cs +++ b/src/libraries/System.Net.Security/src/System/Net/HelperAsyncResults.cs @@ -20,10 +20,6 @@ namespace System.Net // internal class AsyncProtocolRequest { -#if DEBUG - internal object _DebugAsyncChain; // Optionally used to track chains of async calls. -#endif - private AsyncProtocolCallback _callback; private int _completionStatus; @@ -33,7 +29,6 @@ internal class AsyncProtocolRequest public LazyAsyncResult UserAsyncResult; public int Result; - public object AsyncState; public readonly CancellationToken CancellationToken; public byte[] Buffer; // Temporary buffer reused by a protocol. diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/Pal.OSX/SafeDeleteSslContext.cs b/src/libraries/System.Net.Security/src/System/Net/Security/Pal.OSX/SafeDeleteSslContext.cs index 82a273ac3ecc0..1db55ec02444a 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/Pal.OSX/SafeDeleteSslContext.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/Pal.OSX/SafeDeleteSslContext.cs @@ -217,15 +217,18 @@ internal void Write(byte[] buf, int offset, int count) Debug.Assert(count >= 0); Debug.Assert(count <= buf.Length - offset); + Write(buf.AsSpan(offset, count)); + } + internal void Write(ReadOnlySpan buf) + { lock (_fromConnection) { - for (int i = 0; i < count; i++) + foreach (byte b in buf) { - _fromConnection.Enqueue(buf[offset + i]); + _fromConnection.Enqueue(b); } } - } internal int BytesReadyForConnection => _toConnection.Count; diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs index e3cd0062d55e0..7486f633fd3a0 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs @@ -626,7 +626,7 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint) // // Acquire Server Side Certificate information and set it on the class. // - private bool AcquireServerCredentials(ref byte[] thumbPrint, byte[] clientHello) + private bool AcquireServerCredentials(ref byte[] thumbPrint, ReadOnlySpan clientHello) { if (NetEventSource.IsEnabled) NetEventSource.Enter(this); @@ -797,7 +797,7 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref if (_refreshCredentialNeeded) { cachedCreds = _sslAuthenticationOptions.IsServer - ? AcquireServerCredentials(ref thumbPrint, input) + ? AcquireServerCredentials(ref thumbPrint, new ReadOnlySpan(input, offset, count)) : AcquireClientCredentials(ref thumbPrint); } @@ -806,7 +806,7 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref status = SslStreamPal.AcceptSecurityContext( ref _credentialsHandle, ref _securityContext, - input != null ? new ArraySegment(input, offset, count) : default, + input, offset, count, ref result, _sslAuthenticationOptions); } @@ -816,7 +816,7 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref ref _credentialsHandle, ref _securityContext, _sslAuthenticationOptions.TargetHost, - input != null ? new ArraySegment(input, offset, count) : default, + input, offset, count, ref result, _sslAuthenticationOptions); } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SniHelper.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SniHelper.cs index 63eb8b4ca6549..c52e0ba84e75e 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SniHelper.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SniHelper.cs @@ -16,12 +16,7 @@ internal class SniHelper private static readonly IdnMapping s_idnMapping = CreateIdnMapping(); private static readonly Encoding s_encoding = CreateEncoding(); - public static string GetServerName(byte[] clientHello) - { - return GetSniFromSslPlainText(clientHello); - } - - private static string GetSniFromSslPlainText(ReadOnlySpan sslPlainText) + public static string GetServerName(ReadOnlySpan sslPlainText) { // https://tools.ietf.org/html/rfc6101#section-5.2.1 // struct { diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs index be9300a44c006..e400afb535071 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs @@ -19,9 +19,6 @@ namespace System.Net.Security public partial class SslStream { private static int s_uniqueNameInteger = 123; - private static readonly AsyncProtocolCallback s_partialFrameCallback = new AsyncProtocolCallback(PartialFrameCallback); - private static readonly AsyncProtocolCallback s_readFrameCallback = new AsyncProtocolCallback(ReadFrameCallback); - private static readonly AsyncCallback s_writeCallback = new AsyncCallback(WriteCallback); private SslAuthenticationOptions _sslAuthenticationOptions; @@ -38,11 +35,30 @@ private enum CachedSessionStatus : byte } private CachedSessionStatus _CachedSession; + private enum Framing + { + Unknown = 0, + BeforeSSL3, + SinceSSL3, + Unified, + Invalid + } + + // This is set on the first packet to figure out the framing style. + private Framing _framing = Framing.Unknown; + + // SSL3/TLS protocol frames definitions. + private enum FrameType : byte + { + ChangeCipherSpec = 20, + Alert = 21, + Handshake = 22, + AppData = 23 + } + // This block is used by re-handshake code to buffer data decrypted with the old key. private byte[] _queuedReadData; private int _queuedReadCount; - private bool _pendingReHandshake; - private const int MaxQueuedReadBytes = 1024 * 128; // // This block is used to rule the >>re-handshakes<< that are concurrent with read/write I/O requests. @@ -191,31 +207,6 @@ private SecurityStatusPal PrivateDecryptData(byte[] buffer, ref int offset, ref return _context.Decrypt(buffer, ref offset, ref count); } - // - // Called by re-handshake if found data decrypted with the old key - // - private Exception EnqueueOldKeyDecryptedData(byte[] buffer, int offset, int count) - { - lock (SyncLock) - { - if (_queuedReadCount + count > MaxQueuedReadBytes) - { - return ExceptionDispatchInfo.SetCurrentStackTrace( - new IOException(SR.Format(SR.net_auth_ignored_reauth, MaxQueuedReadBytes.ToString(NumberFormatInfo.CurrentInfo)))); - } - - if (count != 0) - { - // This is inefficient yet simple and that should be a rare case of receiving data encrypted with "old" key. - _queuedReadData = EnsureBufferSize(_queuedReadData, _queuedReadCount, _queuedReadCount + count); - Buffer.BlockCopy(buffer, offset, _queuedReadData, _queuedReadCount, count); - _queuedReadCount += count; - FinishHandshakeRead(LockHandshake); - } - } - return null; - } - // // When re-handshaking the "old" key decrypted data are queued until the handshake is done. // When stream calls for decryption we will feed it queued data left from "old" encryption key. @@ -249,34 +240,29 @@ private int CheckOldKeyDecryptedData(Memory buffer) // This method assumes that a SSPI context is already in a good shape. // For example it is either a fresh context or already authenticated context that needs renegotiation. // - private void ProcessAuthentication(LazyAsyncResult lazyResult, CancellationToken cancellationToken) + private Task ProcessAuthentication(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default) { + Task result = null; if (Interlocked.Exchange(ref _nestedAuth, 1) == 1) { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, lazyResult == null ? "BeginAuthenticate" : "Authenticate", "authenticate")); + throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, isApm ? "BeginAuthenticate" : "Authenticate", "authenticate")); } try { ThrowIfExceptional(); - AsyncProtocolRequest asyncRequest = null; - if (lazyResult != null) - { - asyncRequest = new AsyncProtocolRequest(lazyResult, cancellationToken); - asyncRequest.Buffer = null; -#if DEBUG - lazyResult._debugAsyncChain = asyncRequest; -#endif - } // A trick to discover and avoid cached sessions. _CachedSession = CachedSessionStatus.Unknown; - ForceAuthentication(_context.IsServer, null, asyncRequest); - - // Not aync so the connection is completed at this point. - if (lazyResult == null && NetEventSource.IsEnabled) + if (isAsync) { + result = ForceAuthenticationAsync(_context.IsServer, null, cancellationToken); + } + else + { + ForceAuthentication(_context.IsServer, null); + if (NetEventSource.IsEnabled) NetEventSource.Log.SspiSelectedCipherSuite(nameof(ProcessAuthentication), SslProtocol, @@ -288,50 +274,28 @@ private void ProcessAuthentication(LazyAsyncResult lazyResult, CancellationToken KeyExchangeStrength); } } - catch (Exception) - { - // If an exception emerges synchronously, the asynchronous operation was not - // initiated, so no operation is in progress. - _nestedAuth = 0; - throw; - } finally { - // For synchronous operations, the operation has completed. - if (lazyResult == null) - { - _nestedAuth = 0; - } + // Operation has completed. + _nestedAuth = 0; } + + return result; } // // This is used to reply on re-handshake when received SEC_I_RENEGOTIATE on Read(). // - private void ReplyOnReAuthentication(byte[] buffer, CancellationToken cancellationToken) + private async Task ReplyOnReAuthenticationAsync(byte[] buffer, CancellationToken cancellationToken) { lock (SyncLock) { // Note we are already inside the read, so checking for already going concurrent handshake. _lockReadState = LockHandshake; - - if (_pendingReHandshake) - { - // A concurrent handshake is pending, resume. - FinishRead(buffer); - return; - } } - // Start rehandshake from here. - - // Forcing async mode. The caller will queue another Read as soon as we return using its preferred - // calling convention, which will be woken up when the handshake completes. The callback is just - // to capture any SocketErrors that happen during the handshake so they can be surfaced from the Read. - AsyncProtocolRequest asyncRequest = new AsyncProtocolRequest(new LazyAsyncResult(this, null, new AsyncCallback(RehandshakeCompleteCallback)), cancellationToken); - // Buffer contains a result from DecryptMessage that will be passed to ISC/ASC - asyncRequest.Buffer = buffer; - ForceAuthentication(false, buffer, asyncRequest); + await ForceAuthenticationAsync(false, buffer, cancellationToken).ConfigureAwait(false); + FinishHandshakeRead(LockNone); } // @@ -339,35 +303,28 @@ private void ReplyOnReAuthentication(byte[] buffer, CancellationToken cancellati // Incoming buffer is either null or is the result of "renegotiate" decrypted message // If write is in progress the method will either wait or be put on hold // - private void ForceAuthentication(bool receiveFirst, byte[] buffer, AsyncProtocolRequest asyncRequest) + private void ForceAuthentication(bool receiveFirst, byte[] buffer) { - if (CheckEnqueueHandshake(buffer, asyncRequest)) - { - // Async handshake is enqueued and will resume later. - return; - } - // Either Sync handshake is ready to go or async handshake won the race over write. - // This will tell that we don't know the framing yet (what SSL version is) - _Framing = Framing.Unknown; + _framing = Framing.Unknown; try { if (receiveFirst) { // Listen for a client blob. - StartReceiveBlob(buffer, asyncRequest); + ReceiveBlob(buffer); } else { // We start with the first blob. - StartSendBlob(buffer, (buffer == null ? 0 : buffer.Length), asyncRequest); + SendBlob(buffer, (buffer == null ? 0 : buffer.Length)); } } catch (Exception e) { // Failed auth, reset the framing if any. - _Framing = Framing.Unknown; + _framing = Framing.Unknown; _handshakeCompleted = false; SetException(e); @@ -382,67 +339,66 @@ private void ForceAuthentication(bool receiveFirst, byte[] buffer, AsyncProtocol if (_exception != null) { // This a failed handshake. Release waiting IO if any. - FinishHandshake(null, null); + FinishHandshake(null); } } } - private void EndProcessAuthentication(IAsyncResult result) + internal async Task ForceAuthenticationAsync(bool receiveFirst, byte[] buffer, CancellationToken cancellationToken) { - if (result == null) - { - throw new ArgumentNullException("asyncResult"); - } + _framing = Framing.Unknown; + ProtocolToken message; + SslReadAsync adapter = new SslReadAsync(this, cancellationToken); - LazyAsyncResult lazyResult = result as LazyAsyncResult; - if (lazyResult == null) + if (!receiveFirst) { - throw new ArgumentException(SR.Format(SR.net_io_async_result, result.GetType().FullName), "asyncResult"); + message = _context.NextMessage(buffer, 0, (buffer == null ? 0 : buffer.Length)); + if (message.Failed) + { + // tracing done in NextMessage() + throw new AuthenticationException(SR.net_auth_SSPI, message.GetException()); + } + + await InnerStream.WriteAsync(message.Payload, cancellationToken).ConfigureAwait(false); } - if (Interlocked.Exchange(ref _nestedAuth, 0) == 0) + do { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, "EndAuthenticate")); - } + message = await ReceiveBlobAsync(adapter, buffer, cancellationToken).ConfigureAwait(false); + if (message.Size > 0) + { + // If there is message send it out even if call failed. It may contain TLS Alert. + await InnerStream.WriteAsync(message.Payload, cancellationToken).ConfigureAwait(false); + } - InternalEndProcessAuthentication(lazyResult); + if (message.Failed) + { + throw new AuthenticationException(SR.net_auth_SSPI, message.GetException()); + } + } while (message.Status.ErrorCode != SecurityStatusPalErrorCode.OK); - // Connection is completed at this point. - if (NetEventSource.IsEnabled) + ProtocolToken alertToken = null; + if (!CompleteHandshake(ref alertToken)) { - if (NetEventSource.IsEnabled) - NetEventSource.Log.SspiSelectedCipherSuite(nameof(EndProcessAuthentication), - SslProtocol, - CipherAlgorithm, - CipherStrength, - HashAlgorithm, - HashStrength, - KeyExchangeAlgorithm, - KeyExchangeStrength); + SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null))); } - } - - private void InternalEndProcessAuthentication(LazyAsyncResult lazyResult) - { - // No "artificial" timeouts implemented so far, InnerStream controls that. - lazyResult.InternalWaitForCompletion(); - Exception e = lazyResult.Result as Exception; - if (e != null) - { - // Failed auth, reset the framing if any. - _Framing = Framing.Unknown; - _handshakeCompleted = false; + if (NetEventSource.IsEnabled) + NetEventSource.Log.SspiSelectedCipherSuite(nameof(ForceAuthenticationAsync), + SslProtocol, + CipherAlgorithm, + CipherStrength, + HashAlgorithm, + HashStrength, + KeyExchangeAlgorithm, + KeyExchangeStrength); - SetException(e); - ThrowIfExceptional(); - } } // // Client side starts here, but server also loops through this method. // - private void StartSendBlob(byte[] incoming, int count, AsyncProtocolRequest asyncRequest) + private void SendBlob(byte[] incoming, int count) { ProtocolToken message = _context.NextMessage(incoming, 0, count); _securityStatus = message.Status; @@ -458,125 +414,65 @@ private void StartSendBlob(byte[] incoming, int count, AsyncProtocolRequest asyn _CachedSession = message.Size < 200 ? CachedSessionStatus.IsCached : CachedSessionStatus.IsNotCached; } - if (_Framing == Framing.Unified) + if (_framing == Framing.Unified) { - _Framing = DetectFraming(message.Payload, message.Payload.Length); + _framing = DetectFraming(message.Payload, message.Payload.Length); } - if (asyncRequest == null) - { - InnerStream.Write(message.Payload, 0, message.Size); - } - else - { - asyncRequest.AsyncState = message; - Task t = InnerStream.WriteAsync(message.Payload, 0, message.Size, asyncRequest.CancellationToken); - if (t.IsCompleted) - { - t.GetAwaiter().GetResult(); - } - else - { - IAsyncResult ar = TaskToApm.Begin(t, s_writeCallback, asyncRequest); - if (!ar.CompletedSynchronously) - { -#if DEBUG - asyncRequest._DebugAsyncChain = ar; -#endif - return; - } - TaskToApm.End(ar); - } - } + InnerStream.Write(message.Payload, 0, message.Size); } - CheckCompletionBeforeNextReceive(message, asyncRequest); + CheckCompletionBeforeNextReceive(message); } // // This will check and logically complete / fail the auth handshake. // - private void CheckCompletionBeforeNextReceive(ProtocolToken message, AsyncProtocolRequest asyncRequest) + private void CheckCompletionBeforeNextReceive(ProtocolToken message) { if (message.Failed) { - StartSendAuthResetSignal(null, asyncRequest, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_auth_SSPI, message.GetException()))); + SendAuthResetSignal(null, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_auth_SSPI, message.GetException()))); return; } - else if (message.Done && !_pendingReHandshake) + else if (message.Done) { ProtocolToken alertToken = null; if (!CompleteHandshake(ref alertToken)) { - StartSendAuthResetSignal(alertToken, asyncRequest, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null))); + SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_validation, null))); return; } // Release waiting IO if any. Presumably it should not throw. // Otherwise application may get not expected type of the exception. - FinishHandshake(null, asyncRequest); + FinishHandshake(null); return; } - StartReceiveBlob(message.Payload, asyncRequest); + ReceiveBlob(message.Payload); } // // Server side starts here, but client also loops through this method. // - private void StartReceiveBlob(byte[] buffer, AsyncProtocolRequest asyncRequest) + private void ReceiveBlob(byte[] buffer) { - if (_pendingReHandshake) - { - if (CheckEnqueueHandshakeRead(ref buffer, asyncRequest)) - { - return; - } - - if (!_pendingReHandshake) - { - // Renegotiate: proceed to the next step. - ProcessReceivedBlob(buffer, buffer.Length, asyncRequest); - return; - } - } - //This is first server read. buffer = EnsureBufferSize(buffer, 0, SecureChannel.ReadHeaderSize); - int readBytes = 0; - if (asyncRequest == null) - { - readBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, 0, SecureChannel.ReadHeaderSize); - } - else - { - asyncRequest.SetNextRequest(buffer, 0, SecureChannel.ReadHeaderSize, s_partialFrameCallback); - _ = FixedSizeReader.ReadPacketAsync(_innerStream, asyncRequest); - if (!asyncRequest.MustCompleteSynchronously) - { - return; - } - - readBytes = asyncRequest.Result; - } + int readBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, 0, SecureChannel.ReadHeaderSize); - StartReadFrame(buffer, readBytes, asyncRequest); - } - - // - private void StartReadFrame(byte[] buffer, int readBytes, AsyncProtocolRequest asyncRequest) - { if (readBytes == 0) { // EOF received throw new IOException(SR.net_auth_eof); } - if (_Framing == Framing.Unknown) + if (_framing == Framing.Unknown) { - _Framing = DetectFraming(buffer, readBytes); + _framing = DetectFraming(buffer, readBytes); } int restBytes = GetRemainingFrameSize(buffer, 0, readBytes); @@ -594,81 +490,57 @@ private void StartReadFrame(byte[] buffer, int readBytes, AsyncProtocolRequest a buffer = EnsureBufferSize(buffer, readBytes, readBytes + restBytes); - if (asyncRequest == null) - { - restBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, readBytes, restBytes); - } - else - { - asyncRequest.SetNextRequest(buffer, readBytes, restBytes, s_readFrameCallback); - _ = FixedSizeReader.ReadPacketAsync(_innerStream, asyncRequest); - if (!asyncRequest.MustCompleteSynchronously) - { - return; - } + restBytes = FixedSizeReader.ReadPacket(_innerStream, buffer, readBytes, restBytes); - restBytes = asyncRequest.Result; - if (restBytes == 0) - { - //EOF received: fail. - readBytes = 0; - } - } - ProcessReceivedBlob(buffer, readBytes + restBytes, asyncRequest); + SendBlob(buffer, readBytes + restBytes); } - private void ProcessReceivedBlob(byte[] buffer, int count, AsyncProtocolRequest asyncRequest) + private async ValueTask ReceiveBlobAsync(SslReadAsync adapter, byte[] buffer, CancellationToken cancellationToken) { - if (count == 0) + ResetReadBuffer(); + int readBytes = await FillBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false); + if (readBytes == 0) { - // EOF received. - throw new AuthenticationException(SR.net_auth_eof, null); + throw new IOException(SR.net_io_eof); } - if (_pendingReHandshake) + if (_framing == Framing.Unified || _framing == Framing.Unknown) { - int offset = 0; - SecurityStatusPal status = PrivateDecryptData(buffer, ref offset, ref count); + _framing = DetectFraming(_internalBuffer, readBytes); + } - if (status.ErrorCode == SecurityStatusPalErrorCode.OK) - { - Exception e = EnqueueOldKeyDecryptedData(buffer, offset, count); - if (e != null) - { - StartSendAuthResetSignal(null, asyncRequest, ExceptionDispatchInfo.Capture(e)); - return; - } + int payloadBytes = GetRemainingFrameSize(_internalBuffer, _internalOffset, readBytes); + if (payloadBytes < 0) + { + throw new IOException(SR.net_frame_read_size); + } - _Framing = Framing.Unknown; - StartReceiveBlob(buffer, asyncRequest); - return; - } - else if (status.ErrorCode != SecurityStatusPalErrorCode.Renegotiate) - { - // Fail re-handshake. - ProtocolToken message = new ProtocolToken(null, status); - StartSendAuthResetSignal(null, asyncRequest, ExceptionDispatchInfo.Capture( - ExceptionDispatchInfo.SetCurrentStackTrace(new AuthenticationException(SR.net_auth_SSPI, message.GetException())))); - return; - } + int frameSize = SecureChannel.ReadHeaderSize + payloadBytes; - // We expect only handshake messages from now. - _pendingReHandshake = false; - if (offset != 0) + if (readBytes < frameSize) + { + readBytes = await FillBufferAsync(adapter, frameSize).ConfigureAwait(false); + Debug.Assert(readBytes >= 0); + if (readBytes == 0) { - Buffer.BlockCopy(buffer, offset, buffer, 0, count); + throw new IOException(SR.net_io_eof); } } - StartSendBlob(buffer, count, asyncRequest); + ProtocolToken token = _context.NextMessage(_internalBuffer, _internalOffset, frameSize); + ConsumeBufferedBytes(frameSize); + + return token; } // // This is to reset auth state on remote side. // If this write succeeds we will allow auth retrying. // - private void StartSendAuthResetSignal(ProtocolToken message, AsyncProtocolRequest asyncRequest, ExceptionDispatchInfo exception) + private void SendAuthResetSignal(ProtocolToken message, ExceptionDispatchInfo exception) { + SetException(exception.SourceException); + if (message == null || message.Size == 0) { // @@ -677,28 +549,7 @@ private void StartSendAuthResetSignal(ProtocolToken message, AsyncProtocolReques exception.Throw(); } - if (asyncRequest == null) - { - InnerStream.Write(message.Payload, 0, message.Size); - } - else - { - asyncRequest.AsyncState = exception; - Task t = InnerStream.WriteAsync(message.Payload, 0, message.Size, asyncRequest.CancellationToken); - if (t.IsCompleted) - { - t.GetAwaiter().GetResult(); - } - else - { - IAsyncResult ar = TaskToApm.Begin(t, s_writeCallback, asyncRequest); - if (!ar.CompletedSynchronously) - { - return; - } - TaskToApm.End(ar); - } - } + InnerStream.Write(message.Payload, 0, message.Size); exception.Throw(); } @@ -734,144 +585,6 @@ private bool CompleteHandshake(ref ProtocolToken alertToken) return true; } - private static void WriteCallback(IAsyncResult transportResult) - { - if (transportResult.CompletedSynchronously) - { - return; - } - - AsyncProtocolRequest asyncRequest; - SslStream sslState; - -#if DEBUG - try - { -#endif - asyncRequest = (AsyncProtocolRequest)transportResult.AsyncState; - sslState = (SslStream)asyncRequest.AsyncObject; -#if DEBUG - } - catch (Exception exception) when (!ExceptionCheck.IsFatal(exception)) - { - NetEventSource.Fail(null, $"Exception while decoding context: {exception}"); - throw; - } -#endif - - // Async completion. - try - { - TaskToApm.End(transportResult); - - // Special case for an error notification. - object asyncState = asyncRequest.AsyncState; - ExceptionDispatchInfo exception = asyncState as ExceptionDispatchInfo; - if (exception != null) - { - exception.Throw(); - } - - sslState.CheckCompletionBeforeNextReceive((ProtocolToken)asyncState, asyncRequest); - } - catch (Exception e) - { - if (asyncRequest.IsUserCompleted) - { - // This will throw on a worker thread. - throw; - } - - sslState.FinishHandshake(e, asyncRequest); - } - } - - private static void PartialFrameCallback(AsyncProtocolRequest asyncRequest) - { - if (NetEventSource.IsEnabled) - NetEventSource.Enter(null); - - // Async ONLY completion. - SslStream sslState = (SslStream)asyncRequest.AsyncObject; - try - { - sslState.StartReadFrame(asyncRequest.Buffer, asyncRequest.Result, asyncRequest); - } - catch (Exception e) - { - if (asyncRequest.IsUserCompleted) - { - // This will throw on a worker thread. - throw; - } - - sslState.FinishHandshake(e, asyncRequest); - } - } - - // - // - private static void ReadFrameCallback(AsyncProtocolRequest asyncRequest) - { - if (NetEventSource.IsEnabled) - NetEventSource.Enter(null); - - // Async ONLY completion. - SslStream sslState = (SslStream)asyncRequest.AsyncObject; - try - { - if (asyncRequest.Result == 0) - { - //EOF received: will fail. - asyncRequest.Offset = 0; - } - - sslState.ProcessReceivedBlob(asyncRequest.Buffer, asyncRequest.Offset + asyncRequest.Result, asyncRequest); - } - catch (Exception e) - { - if (asyncRequest.IsUserCompleted) - { - // This will throw on a worker thread. - throw; - } - - sslState.FinishHandshake(e, asyncRequest); - } - } - - private bool CheckEnqueueHandshakeRead(ref byte[] buffer, AsyncProtocolRequest request) - { - LazyAsyncResult lazyResult = null; - lock (SyncLock) - { - if (_lockReadState == LockPendingRead) - { - return false; - } - - int lockState = Interlocked.Exchange(ref _lockReadState, LockHandshake); - if (lockState != LockRead) - { - return false; - } - - if (request != null) - { - _queuedReadStateRequest = request; - return true; - } - - lazyResult = new LazyAsyncResult(null, null, /*must be */ null); - _queuedReadStateRequest = lazyResult; - } - - // Need to exit from lock before waiting. - lazyResult.InternalWaitForCompletion(); - buffer = (byte[])lazyResult.Result; - return false; - } - private void FinishHandshakeRead(int newState) { lock (SyncLock) @@ -885,7 +598,6 @@ private void FinishHandshakeRead(int newState) } _lockReadState = LockRead; - HandleQueuedCallback(ref _queuedReadStateRequest); } } @@ -966,33 +678,6 @@ private ValueTask CheckEnqueueReadAsync(Memory buffer) } } - private void FinishRead(byte[] renegotiateBuffer) - { - int lockState = Interlocked.CompareExchange(ref _lockReadState, LockNone, LockRead); - - if (lockState != LockHandshake) - { - return; - } - - lock (SyncLock) - { - LazyAsyncResult ar = _queuedReadStateRequest as LazyAsyncResult; - if (ar != null) - { - _queuedReadStateRequest = null; - ar.InvokeCallback(renegotiateBuffer); - } - else - { - AsyncProtocolRequest request = (AsyncProtocolRequest)_queuedReadStateRequest; - request.Buffer = renegotiateBuffer; - _queuedReadStateRequest = null; - ThreadPool.QueueUserWorkItem(s => s.sslState.AsyncResumeHandshakeRead(s.request), (sslState: this, request), preferLocal: false); - } - } - } - private Task CheckEnqueueWriteAsync() { // Clear previous request. @@ -1057,123 +742,28 @@ private void FinishWrite() { return; } - - lock (SyncLock) - { - HandleQueuedCallback(ref _queuedWriteStateRequest); - } - } - - private void HandleQueuedCallback(ref object queuedStateRequest) - { - object obj = queuedStateRequest; - if (obj == null) - { - return; - } - queuedStateRequest = null; - - switch (obj) - { - case LazyAsyncResult lazy: - lazy.InvokeCallback(); - break; - case TaskCompletionSource taskCompletionSource when taskCompletionSource.Task.AsyncState != null: - Memory array = (Memory)taskCompletionSource.Task.AsyncState; - int oldKeyResult = -1; - try - { - oldKeyResult = CheckOldKeyDecryptedData(array); - } - catch (Exception exc) - { - taskCompletionSource.SetException(exc); - break; - } - taskCompletionSource.SetResult(oldKeyResult); - break; - case TaskCompletionSource taskCompletionSource: - taskCompletionSource.SetResult(0); - break; - default: - ThreadPool.QueueUserWorkItem(s => s.sslState.AsyncResumeHandshake(s.obj), (sslState: this, obj), preferLocal: false); - break; - } } - // Returns: - // true - operation queued - // false - operation can proceed - private bool CheckEnqueueHandshake(byte[] buffer, AsyncProtocolRequest asyncRequest) + private void FinishHandshake(Exception e) { - LazyAsyncResult lazyResult = null; - lock (SyncLock) { - if (_lockWriteState == LockPendingWrite) + if (e != null) { - return false; + SetException(e); } - int lockState = Interlocked.Exchange(ref _lockWriteState, LockHandshake); - if (lockState != LockWrite) - { - // Proceed with handshake. - return false; - } + // Release read if any. + FinishHandshakeRead(LockNone); - if (asyncRequest != null) + // If there is a pending write we want to keep it's lock state. + int lockState = Interlocked.CompareExchange(ref _lockWriteState, LockNone, LockHandshake); + if (lockState != LockPendingWrite) { - asyncRequest.Buffer = buffer; - _queuedWriteStateRequest = asyncRequest; - return true; + return; } - lazyResult = new LazyAsyncResult(null, null, /*must be*/null); - _queuedWriteStateRequest = lazyResult; - } - lazyResult.InternalWaitForCompletion(); - return false; - } - - private void FinishHandshake(Exception e, AsyncProtocolRequest asyncRequest) - { - try - { - lock (SyncLock) - { - if (e != null) - { - SetException(e); - } - - // Release read if any. - FinishHandshakeRead(LockNone); - - // If there is a pending write we want to keep it's lock state. - int lockState = Interlocked.CompareExchange(ref _lockWriteState, LockNone, LockHandshake); - if (lockState != LockPendingWrite) - { - return; - } - - _lockWriteState = LockWrite; - HandleQueuedCallback(ref _queuedWriteStateRequest); - } - } - finally - { - if (asyncRequest != null) - { - if (e != null) - { - asyncRequest.CompleteUserWithError(e); - } - else - { - asyncRequest.CompleteUser(); - } - } + _lockWriteState = LockWrite; } } @@ -1307,8 +897,6 @@ private async ValueTask ReadAsyncInternal(TReadAdapter adapte { copyBytes = CopyDecryptedData(buffer); - FinishRead(null); - return copyBytes; } @@ -1368,18 +956,17 @@ private async ValueTask ReadAsyncInternal(TReadAdapter adapte { if (!_sslAuthenticationOptions.AllowRenegotiation) { + if (NetEventSource.IsEnabled) NetEventSource.Fail(this, "Renegotiation was requested but it is disallowed"); throw new IOException(SR.net_ssl_io_renego); } - ReplyOnReAuthentication(extraBuffer, adapter.CancellationToken); - + await ReplyOnReAuthenticationAsync(extraBuffer, adapter.CancellationToken).ConfigureAwait(false); // Loop on read. continue; } if (message.CloseConnection) { - FinishRead(null); return 0; } @@ -1389,8 +976,6 @@ private async ValueTask ReadAsyncInternal(TReadAdapter adapte } catch (Exception e) { - FinishRead(null); - if (e is IOException || (e is OperationCanceledException && adapter.CancellationToken.IsCancellationRequested)) { throw; @@ -1526,6 +1111,7 @@ private int CopyDecryptedData(Memory buffer) _decryptedBytesOffset += copyBytes; _decryptedBytesCount -= copyBytes; } + ReturnReadBufferIfEmpty(); return copyBytes; } @@ -1562,27 +1148,6 @@ private static byte[] EnsureBufferSize(byte[] buffer, int copyCount, int size) return buffer; } - private enum Framing - { - Unknown = 0, - BeforeSSL3, - SinceSSL3, - Unified, - Invalid - } - - // This is set on the first packet to figure out the framing style. - private Framing _Framing = Framing.Unknown; - - // SSL3/TLS protocol frames definitions. - private enum FrameType : byte - { - ChangeCipherSpec = 20, - Alert = 21, - Handshake = 22, - AppData = 23 - } - // We need at least 5 bytes to determine what we have. private Framing DetectFraming(byte[] bytes, int length) { @@ -1735,7 +1300,7 @@ private Framing DetectFraming(byte[] bytes, int length) // If this is the first packet, the client may start with an SSL2 packet // but stating that the version is 3.x, so check the full range. // For the subsequent packets we assume that an SSL2 packet should have a 2.x version. - if (_Framing == Framing.Unknown) + if (_framing == Framing.Unknown) { if (version != 0x0002 && (version < 0x200 || version >= 0x500)) { @@ -1752,7 +1317,7 @@ private Framing DetectFraming(byte[] bytes, int length) } // When server has replied the framing is already fixed depending on the prior client packet - if (!_context.IsServer || _Framing == Framing.Unified) + if (!_context.IsServer || _framing == Framing.Unified) { return Framing.BeforeSSL3; } @@ -1768,7 +1333,7 @@ private int GetRemainingFrameSize(byte[] buffer, int offset, int dataSize) NetEventSource.Enter(this, buffer, offset, dataSize); int payloadSize = -1; - switch (_Framing) + switch (_framing) { case Framing.Unified: case Framing.BeforeSSL3: @@ -1809,104 +1374,5 @@ private int GetRemainingFrameSize(byte[] buffer, int offset, int dataSize) NetEventSource.Exit(this, payloadSize); return payloadSize; } - - // - // Called with no user stack. - // - private void AsyncResumeHandshake(object state) - { - AsyncProtocolRequest request = state as AsyncProtocolRequest; - Debug.Assert(request != null, "Expected an AsyncProtocolRequest reference."); - - try - { - ForceAuthentication(_context.IsServer, request.Buffer, request); - } - catch (Exception e) - { - request.CompleteUserWithError(e); - } - } - - // - // Called with no user stack. - // - private void AsyncResumeHandshakeRead(AsyncProtocolRequest asyncRequest) - { - try - { - if (_pendingReHandshake) - { - // Resume as read a blob. - StartReceiveBlob(asyncRequest.Buffer, asyncRequest); - } - else - { - // Resume as process the blob. - ProcessReceivedBlob(asyncRequest.Buffer, asyncRequest.Buffer == null ? 0 : asyncRequest.Buffer.Length, asyncRequest); - } - } - catch (Exception e) - { - if (asyncRequest.IsUserCompleted) - { - // This will throw on a worker thread. - throw; - } - - FinishHandshake(e, asyncRequest); - } - } - - private void RehandshakeCompleteCallback(IAsyncResult result) - { - LazyAsyncResult lazyAsyncResult = (LazyAsyncResult)result; - if (lazyAsyncResult == null) - { - NetEventSource.Fail(this, "result is null!"); - } - - if (!lazyAsyncResult.InternalPeekCompleted) - { - NetEventSource.Fail(this, "result is not completed!"); - } - - // If the rehandshake succeeded, FinishHandshake has already been called; if there was a SocketException - // during the handshake, this gets called directly from FixedSizeReader, and we need to call - // FinishHandshake to wake up the Read that triggered this rehandshake so the error gets back to the caller - Exception exception = lazyAsyncResult.InternalWaitForCompletion() as Exception; - if (exception != null) - { - // We may be calling FinishHandshake reentrantly, as FinishHandshake can call - // asyncRequest.CompleteWithError, which will result in this method being called. - // This is not a problem because: - // - // 1. We pass null as the asyncRequest parameter, so this second call to FinishHandshake won't loop - // back here. - // - // 2. _QueuedWriteStateRequest and _QueuedReadStateRequest are set to null after the first call, - // so, we won't invoke their callbacks again. - // - // 3. SetException won't overwrite an already-set _Exception. - // - // 4. There are three possibilities for _LockReadState and _LockWriteState: - // - // a. They were set back to None by the first call to FinishHandshake, and this will set them to - // None again: a no-op. - // - // b. They were set to None by the first call to FinishHandshake, but as soon as the lock was given - // up, another thread took a read/write lock. Calling FinishHandshake again will set them back - // to None, but that's fine because that thread will be throwing _Exception before it actually - // does any reading or writing and setting them back to None in a catch block anyways. - // - // c. If there is a Read/Write going on another thread, and the second FinishHandshake clears its - // read/write lock, it's fine because no other Read/Write can look at the lock until the current - // one gives up _SslStream._NestedRead/Write, and no handshake will look at the lock because - // handshakes are only triggered in response to successful reads (which won't happen once - // _Exception is set). - - FinishHandshake(exception, null); - } - } } } diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs index 568e44f30e755..8dc2e8b8ee819 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs @@ -226,19 +226,10 @@ public virtual IAsyncResult BeginAuthenticateAsClient(string targetHost, X509Cer return BeginAuthenticateAsClient(options, CancellationToken.None, asyncCallback, asyncState); } - internal IAsyncResult BeginAuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState) - { - SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback); - SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback); - - ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate); + internal IAsyncResult BeginAuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState) => + TaskToApm.Begin(AuthenticateAsClientApm(sslClientAuthenticationOptions, cancellationToken), asyncCallback, asyncState); - LazyAsyncResult result = new LazyAsyncResult(this, asyncState, asyncCallback); - ProcessAuthentication(result, cancellationToken); - return result; - } - - public virtual void EndAuthenticateAsClient(IAsyncResult asyncResult) => EndProcessAuthentication(asyncResult); + public virtual void EndAuthenticateAsClient(IAsyncResult asyncResult) => TaskToApm.End(asyncResult); // // Server side auth. @@ -248,7 +239,7 @@ public virtual IAsyncResult BeginAuthenticateAsServer(X509Certificate serverCert { return BeginAuthenticateAsServer(serverCertificate, false, SecurityProtocol.SystemDefaultSecurityProtocols, false, asyncCallback, - asyncState); + asyncState); } public virtual IAsyncResult BeginAuthenticateAsServer(X509Certificate serverCertificate, bool clientCertificateRequired, @@ -274,34 +265,14 @@ public virtual IAsyncResult BeginAuthenticateAsServer(X509Certificate serverCert return BeginAuthenticateAsServer(options, CancellationToken.None, asyncCallback, asyncState); } - private IAsyncResult BeginAuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState) - { - SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); - - ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions)); - - LazyAsyncResult result = new LazyAsyncResult(this, asyncState, asyncCallback); - ProcessAuthentication(result, cancellationToken); - return result; - } - - public virtual void EndAuthenticateAsServer(IAsyncResult asyncResult) => EndProcessAuthentication(asyncResult); - - internal IAsyncResult BeginShutdown(AsyncCallback asyncCallback, object asyncState) - { - ThrowIfExceptionalOrNotAuthenticatedOrShutdown(); + private IAsyncResult BeginAuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken, AsyncCallback asyncCallback, object asyncState) => + TaskToApm.Begin(AuthenticateAsServerApm(sslServerAuthenticationOptions, cancellationToken), asyncCallback, asyncState); - ProtocolToken message = _context.CreateShutdownToken(); - return TaskToApm.Begin(InnerStream.WriteAsync(message.Payload, 0, message.Payload.Length), asyncCallback, asyncState); - } + public virtual void EndAuthenticateAsServer(IAsyncResult asyncResult) => TaskToApm.End(asyncResult); - internal void EndShutdown(IAsyncResult asyncResult) - { - ThrowIfExceptionalOrNotAuthenticatedOrShutdown(); + internal IAsyncResult BeginShutdown(AsyncCallback asyncCallback, object asyncState) => TaskToApm.Begin(ShutdownAsync(), asyncCallback, asyncState); - TaskToApm.End(asyncResult); - _shutdown = true; - } + internal void EndShutdown(IAsyncResult asyncResult) => TaskToApm.End(asyncResult); public TransportContext TransportContext => new SslStreamContext(this); @@ -338,7 +309,7 @@ private void AuthenticateAsClient(SslClientAuthenticationOptions sslClientAuthen SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback); ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate); - ProcessAuthentication(null, default); + ProcessAuthentication(); } public virtual void AuthenticateAsServer(X509Certificate serverCertificate) @@ -370,86 +341,103 @@ private void AuthenticateAsServer(SslServerAuthenticationOptions sslServerAuthen SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions)); - ProcessAuthentication(null, default); + ProcessAuthentication(); } #endregion #region Task-based async public methods - public virtual Task AuthenticateAsClientAsync(string targetHost) => - Task.Factory.FromAsync( - (arg1, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, callback, state), - iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar), - targetHost, - this); - - public virtual Task AuthenticateAsClientAsync(string targetHost, X509CertificateCollection clientCertificates, bool checkCertificateRevocation) => - Task.Factory.FromAsync( - (arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, SecurityProtocol.SystemDefaultSecurityProtocols, arg3, callback, state), - iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar), - targetHost, clientCertificates, checkCertificateRevocation, - this); + public virtual Task AuthenticateAsClientAsync(string targetHost) => AuthenticateAsClientAsync(targetHost, null, false); + + public virtual Task AuthenticateAsClientAsync(string targetHost, X509CertificateCollection clientCertificates, bool checkCertificateRevocation) => AuthenticateAsClientAsync(targetHost, clientCertificates, SecurityProtocol.SystemDefaultSecurityProtocols, checkCertificateRevocation); public virtual Task AuthenticateAsClientAsync(string targetHost, X509CertificateCollection clientCertificates, SslProtocols enabledSslProtocols, bool checkCertificateRevocation) { - var beginMethod = checkCertificateRevocation ? (Func) - ((arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, arg3, true, callback, state)) : - ((arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, arg3, false, callback, state)); - return Task.Factory.FromAsync( - beginMethod, - iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar), - targetHost, clientCertificates, enabledSslProtocols, - this); + SslClientAuthenticationOptions options = new SslClientAuthenticationOptions() + { + TargetHost = targetHost, + ClientCertificates = clientCertificates, + EnabledSslProtocols = enabledSslProtocols, + CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, + EncryptionPolicy = _encryptionPolicy, + }; + + return AuthenticateAsClientAsync(options); } public Task AuthenticateAsClientAsync(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken = default) { - return Task.Factory.FromAsync( - (arg1, arg2, callback, state) => ((SslStream)state).BeginAuthenticateAsClient(arg1, arg2, callback, state), - iar => ((SslStream)iar.AsyncState).EndAuthenticateAsClient(iar), - sslClientAuthenticationOptions, cancellationToken, - this); + SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback); + SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback); + + ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate); + + return ProcessAuthentication(true, false, cancellationToken); + } + + private Task AuthenticateAsClientApm(SslClientAuthenticationOptions sslClientAuthenticationOptions, CancellationToken cancellationToken = default) + { + SetAndVerifyValidationCallback(sslClientAuthenticationOptions.RemoteCertificateValidationCallback); + SetAndVerifySelectionCallback(sslClientAuthenticationOptions.LocalCertificateSelectionCallback); + + ValidateCreateContext(sslClientAuthenticationOptions, _certValidationDelegate, _certSelectionDelegate); + + return ProcessAuthentication(true, true, cancellationToken); } public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate) => - Task.Factory.FromAsync( - (arg1, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, callback, state), - iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar), - serverCertificate, - this); - - public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, bool clientCertificateRequired, bool checkCertificateRevocation) => - Task.Factory.FromAsync( - (arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, SecurityProtocol.SystemDefaultSecurityProtocols, arg3, callback, state), - iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar), - serverCertificate, clientCertificateRequired, checkCertificateRevocation, - this); + AuthenticateAsServerAsync(serverCertificate, false, SecurityProtocol.SystemDefaultSecurityProtocols, false); + + public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, bool clientCertificateRequired, bool checkCertificateRevocation) + { + SslServerAuthenticationOptions options = new SslServerAuthenticationOptions + { + ServerCertificate = serverCertificate, + ClientCertificateRequired = clientCertificateRequired, + CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, + EncryptionPolicy = _encryptionPolicy, + }; + + return AuthenticateAsServerAsync(options); + } public virtual Task AuthenticateAsServerAsync(X509Certificate serverCertificate, bool clientCertificateRequired, SslProtocols enabledSslProtocols, bool checkCertificateRevocation) { - var beginMethod = checkCertificateRevocation ? (Func) - ((arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, arg3, true, callback, state)) : - ((arg1, arg2, arg3, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, arg3, false, callback, state)); - return Task.Factory.FromAsync( - beginMethod, - iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar), - serverCertificate, clientCertificateRequired, enabledSslProtocols, - this); + SslServerAuthenticationOptions options = new SslServerAuthenticationOptions + { + ServerCertificate = serverCertificate, + ClientCertificateRequired = clientCertificateRequired, + EnabledSslProtocols = enabledSslProtocols, + CertificateRevocationCheckMode = checkCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, + EncryptionPolicy = _encryptionPolicy, + }; + + return AuthenticateAsServerAsync(options); } public Task AuthenticateAsServerAsync(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken = default) { - return Task.Factory.FromAsync( - (arg1, arg2, callback, state) => ((SslStream)state).BeginAuthenticateAsServer(arg1, arg2, callback, state), - iar => ((SslStream)iar.AsyncState).EndAuthenticateAsServer(iar), - sslServerAuthenticationOptions, cancellationToken, - this); + SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); + ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions)); + + return ProcessAuthentication(true, false, cancellationToken); } - public virtual Task ShutdownAsync() => - Task.Factory.FromAsync( - (callback, state) => ((SslStream)state).BeginShutdown(callback, state), - iar => ((SslStream)iar.AsyncState).EndShutdown(iar), - this); + private Task AuthenticateAsServerApm(SslServerAuthenticationOptions sslServerAuthenticationOptions, CancellationToken cancellationToken = default) + { + SetAndVerifyValidationCallback(sslServerAuthenticationOptions.RemoteCertificateValidationCallback); + ValidateCreateContext(CreateAuthenticationOptions(sslServerAuthenticationOptions)); + + return ProcessAuthentication(true, true, cancellationToken); + } + + public virtual Task ShutdownAsync() + { + ThrowIfExceptionalOrNotAuthenticatedOrShutdown(); + + ProtocolToken message = _context.CreateShutdownToken(); + _shutdown = true; + return InnerStream.WriteAsync(message.Payload, default).AsTask(); + } #endregion public override bool IsAuthenticated => _context != null && _context.IsValidContext && _exception == null && _handshakeCompleted; diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs index f7d62460876c0..eda667a9a9264 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.OSX.cs @@ -35,22 +35,22 @@ public static void VerifyPackageInfo() public static SecurityStatusPal AcceptSecurityContext( ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, - ArraySegment inputBuffer, + byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credential, ref context, new ReadOnlySpan(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions); } public static SecurityStatusPal InitializeSecurityContext( ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, - ArraySegment inputBuffer, + byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credential, ref context, new ReadOnlySpan(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions); } public static SafeFreeCredentials AcquireCredentialsHandle( @@ -233,7 +233,7 @@ public static void QueryContextConnectionInfo( private static SecurityStatusPal HandshakeInternal( SafeFreeCredentials credential, ref SafeDeleteSslContext context, - ArraySegment inputBuffer, + ReadOnlySpan inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { @@ -260,9 +260,9 @@ private static SecurityStatusPal HandshakeInternal( } } - if (inputBuffer.Array != null && inputBuffer.Count > 0) + if (inputBuffer.Length > 0) { - sslContext.Write(inputBuffer.Array, inputBuffer.Offset, inputBuffer.Count); + sslContext.Write(inputBuffer); } SafeSslHandle sslHandle = sslContext.SslContext; diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs index 4abb400dfbd14..40867646b0231 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Unix.cs @@ -28,15 +28,15 @@ public static void VerifyPackageInfo() } public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, - ArraySegment inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) + byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credential, ref context, new ReadOnlySpan(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions); } public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, - ArraySegment inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) + byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { - return HandshakeInternal(credential, ref context, inputBuffer, ref outputBuffer, sslAuthenticationOptions); + return HandshakeInternal(credential, ref context, new ReadOnlySpan(inputBuffer, offset, count), ref outputBuffer, sslAuthenticationOptions); } public static SafeFreeCredentials AcquireCredentialsHandle(X509Certificate certificate, @@ -100,7 +100,7 @@ public static byte[] ConvertAlpnProtocolListToByteArray(List inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) + ReadOnlySpan inputBuffer, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { Debug.Assert(!credential.IsInvalid); @@ -114,16 +114,7 @@ private static SecurityStatusPal HandshakeInternal(SafeFreeCredentials credentia context = new SafeDeleteSslContext(credential as SafeFreeSslCredentials, sslAuthenticationOptions); } - bool done; - - if (inputBuffer.Array == null) - { - done = Interop.OpenSsl.DoSslHandshake(((SafeDeleteSslContext)context).SslContext, null, 0, 0, out output, out outputSize); - } - else - { - done = Interop.OpenSsl.DoSslHandshake(((SafeDeleteSslContext)context).SslContext, inputBuffer.Array, inputBuffer.Offset, inputBuffer.Count, out output, out outputSize); - } + bool done = Interop.OpenSsl.DoSslHandshake(((SafeDeleteSslContext)context).SslContext, inputBuffer, out output, out outputSize); // When the handshake is done, and the context is server, check if the alpnHandle target was set to null during ALPN. // If it was, then that indicates ALPN failed, send failure. @@ -172,7 +163,7 @@ private static SecurityStatusPal EncryptDecryptHelper(SafeDeleteContext security if (encrypt) { - resultSize = Interop.OpenSsl.Encrypt(scHandle, input, ref output, out errorCode); + resultSize = Interop.OpenSsl.Encrypt(scHandle, input.Span, ref output, out errorCode); } else { diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs index 5b9fa7358dcfe..2e4b77c0786c0 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStreamPal.Windows.cs @@ -46,9 +46,10 @@ public static byte[] ConvertAlpnProtocolListToByteArray(List input, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) + public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { Interop.SspiCli.ContextFlags unusedAttributes = default; + ArraySegment input = inputBuffer != null ? new ArraySegment(inputBuffer, offset, count) : default; ThreeSecurityBuffers threeSecurityBuffers = default; SecurityBuffer? incomingSecurity = input.Array != null ? @@ -73,9 +74,10 @@ public static SecurityStatusPal AcceptSecurityContext(ref SafeFreeCredentials cr return SecurityStatusAdapterPal.GetSecurityStatusPalFromNativeInt(errorCode); } - public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, string targetName, ArraySegment input, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) + public static SecurityStatusPal InitializeSecurityContext(ref SafeFreeCredentials credentialsHandle, ref SafeDeleteSslContext context, string targetName, byte[] inputBuffer, int offset, int count, ref byte[] outputBuffer, SslAuthenticationOptions sslAuthenticationOptions) { Interop.SspiCli.ContextFlags unusedAttributes = default; + ArraySegment input = inputBuffer != null ? new ArraySegment(inputBuffer, offset, count) : default; ThreeSecurityBuffers threeSecurityBuffers = default; SecurityBuffer? incomingSecurity = input.Array != null ? diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNegotiatedCipherSuiteTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNegotiatedCipherSuiteTest.cs index b12ad9beb6bcd..14b4918a24ca2 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNegotiatedCipherSuiteTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNegotiatedCipherSuiteTest.cs @@ -593,7 +593,7 @@ private static CipherSuitesPolicy BuildPolicy(params TlsCipherSuite[] cipherSuit return new CipherSuitesPolicy(cipherSuites); } - private static async Task WaitForSecureConnection(VirtualNetwork connection, Func server, Func client) + private static async Task WaitForSecureConnection(SslStream client, SslClientAuthenticationOptions clientOptions, SslStream server, SslServerAuthenticationOptions serverOptions) { Task serverTask = null; Task clientTask = null; @@ -601,12 +601,13 @@ private static async Task WaitForSecureConnection(VirtualNetwork conn // check if failed synchronously try { - serverTask = server(); - clientTask = client(); + serverTask = server.AuthenticateAsServerAsync(serverOptions, CancellationToken.None); + clientTask = client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None); } catch (Exception e) { - connection.BreakConnection(); + client.Close(); + server.Close(); if (!(e is AuthenticationException || e is Win32Exception)) { @@ -625,6 +626,7 @@ private static async Task WaitForSecureConnection(VirtualNetwork conn catch (AuthenticationException) { } catch (Win32Exception) { } catch (VirtualNetwork.VirtualNetworkConnectionBroken) { } + catch (IOException) { } } return e; @@ -635,32 +637,42 @@ private static async Task WaitForSecureConnection(VirtualNetwork conn // Now we expect both sides to fail or both to succeed Exception failure = null; + Task task = null; try { - await serverTask.ConfigureAwait(false); + task = await Task.WhenAny(serverTask, clientTask).TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds).ConfigureAwait(false); + await task ; } catch (Exception e) when (e is AuthenticationException || e is Win32Exception) { failure = e; - // avoid client waiting for server's response - connection.BreakConnection(); + if (task == serverTask) + { + server.Close(); + } + else + { + client.Close(); + } } try { - await clientTask.ConfigureAwait(false); + // Now wait for the other task to finish. + task = (task == serverTask ? clientTask : serverTask); + await task.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds).ConfigureAwait(false); // Fail if server has failed but client has succeeded Assert.Null(failure); } - catch (Exception e) when (e is VirtualNetwork.VirtualNetworkConnectionBroken || e is AuthenticationException || e is Win32Exception) + catch (Exception e) when (e is VirtualNetwork.VirtualNetworkConnectionBroken || e is AuthenticationException || e is Win32Exception || e is IOException) { // Fail if server has succeeded but client has failed Assert.NotNull(failure); - if (e.GetType() != typeof(VirtualNetwork.VirtualNetworkConnectionBroken)) + if (e.GetType() != typeof(VirtualNetwork.VirtualNetworkConnectionBroken) && e.GetType() != typeof(IOException)) { failure = new AggregateException(new Exception[] { failure, e }); } @@ -671,9 +683,10 @@ private static async Task WaitForSecureConnection(VirtualNetwork conn private static NegotiatedParams ConnectAndGetNegotiatedParams(ConnectionParams serverParams, ConnectionParams clientParams) { - VirtualNetwork vn = new VirtualNetwork(); - using (VirtualNetworkStream serverStream = new VirtualNetworkStream(vn, isServer: true), - clientStream = new VirtualNetworkStream(vn, isServer: false)) + (Stream clientStream, Stream serverStream) = TestHelper.GetConnectedStreams(); + + using (clientStream) + using (serverStream) using (SslStream server = new SslStream(serverStream, leaveInnerStreamOpen: false), client = new SslStream(clientStream, leaveInnerStreamOpen: false)) { @@ -696,10 +709,7 @@ private static NegotiatedParams ConnectAndGetNegotiatedParams(ConnectionParams s return true; }); - Func serverTask = () => server.AuthenticateAsServerAsync(serverOptions, CancellationToken.None); - Func clientTask = () => client.AuthenticateAsClientAsync(clientOptions, CancellationToken.None); - - Exception failure = WaitForSecureConnection(vn, serverTask, clientTask).Result; + Exception failure = WaitForSecureConnection(client, clientOptions, server, serverOptions).GetAwaiter().GetResult(); if (failure == null) { diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs index dfee6bca85778..838c673c5d06a 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs @@ -920,23 +920,18 @@ public async Task AuthenticateAsClientAsync_VirtualNetwork_CanceledAfterStart_Th [Fact] public async Task AuthenticateAsClientAsync_Sockets_CanceledAfterStart_ThrowsOperationCanceledException() { - using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - listener.Listen(1); + (Stream client, Stream server) = TestHelper.GetConnectedTcpStreams(); - await client.ConnectAsync(listener.LocalEndPoint); - using (Socket server = await listener.AcceptAsync()) - using (var clientSslStream = new SslStream(new NetworkStream(client), false, AllowAnyServerCertificate)) - using (var serverSslStream = new SslStream(new NetworkStream(server))) - using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate()) - { - var cts = new CancellationTokenSource(); - Task t = clientSslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions() { TargetHost = certificate.GetNameInfo(X509NameType.SimpleName, false) }, cts.Token); - cts.Cancel(); - await Assert.ThrowsAnyAsync(() => t); - } + using (client) + using (server) + using (var clientSslStream = new SslStream(client, false, AllowAnyServerCertificate)) + using (var serverSslStream = new SslStream(server)) + using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate()) + { + var cts = new CancellationTokenSource(); + Task t = clientSslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions() { TargetHost = certificate.GetNameInfo(X509NameType.SimpleName, false) }, cts.Token); + cts.Cancel(); + await Assert.ThrowsAnyAsync(() => t); } } @@ -958,23 +953,18 @@ public async Task AuthenticateAsServerAsync_VirtualNetwork_CanceledAfterStart_Th [Fact] public async Task AuthenticateAsServerAsync_Sockets_CanceledAfterStart_ThrowsOperationCanceledException() { - using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - listener.Listen(1); + (Stream client, Stream server) = TestHelper.GetConnectedTcpStreams(); - await client.ConnectAsync(listener.LocalEndPoint); - using (Socket server = await listener.AcceptAsync()) - using (var clientSslStream = new SslStream(new NetworkStream(client), false, AllowAnyServerCertificate)) - using (var serverSslStream = new SslStream(new NetworkStream(server))) - using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate()) - { - var cts = new CancellationTokenSource(); - Task t = serverSslStream.AuthenticateAsServerAsync(new SslServerAuthenticationOptions() { ServerCertificate = certificate }, cts.Token); - cts.Cancel(); - await Assert.ThrowsAnyAsync(() => t); - } + using (client) + using (server) + using (var clientSslStream = new SslStream(client, false, AllowAnyServerCertificate)) + using (var serverSslStream = new SslStream(server)) + using (X509Certificate2 certificate = Configuration.Certificates.GetServerCertificate()) + { + var cts = new CancellationTokenSource(); + Task t = serverSslStream.AuthenticateAsServerAsync(new SslServerAuthenticationOptions() { ServerCertificate = certificate }, cts.Token); + cts.Cancel(); + await Assert.ThrowsAnyAsync(() => t); } } } diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj b/src/libraries/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj index afddd56a90bee..b876bdb9b8480 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/System.Net.Security.Tests.csproj @@ -11,6 +11,7 @@ + @@ -128,4 +129,4 @@ - \ No newline at end of file + diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs new file mode 100644 index 0000000000000..e695cde0f5984 --- /dev/null +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Net.Sockets; +using System.Net.Test.Common; + +namespace System.Net.Security.Tests +{ + public static class TestHelper + { + public static (Stream ClientStream, Stream ServerStream) GetConnectedStreams() + { + if (Capability.SecurityForceSocketStreams()) + { + return GetConnectedTcpStreams(); + } + + return GetConnectedVirtualStreams(); + } + + internal static (NetworkStream ClientStream, NetworkStream ServerStream) GetConnectedTcpStreams() + { + using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + listener.Listen(1); + + var clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + clientSocket.Connect(listener.LocalEndPoint); + Socket serverSocket = listener.Accept(); + + return (new NetworkStream(clientSocket, ownsSocket: true), new NetworkStream(serverSocket, ownsSocket: true)); + } + + } + + internal static (VirtualNetworkStream ClientStream, VirtualNetworkStream ServerStream) GetConnectedVirtualStreams() + { + VirtualNetwork vn = new VirtualNetwork(); + + return (new VirtualNetworkStream(vn, isServer: false), new VirtualNetworkStream(vn, isServer: true)); + } + } +} diff --git a/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs b/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs index dbeaeada915a9..4cd25bda7362e 100644 --- a/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs +++ b/src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs @@ -65,12 +65,9 @@ private void CloseInternal() // This method assumes that a SSPI context is already in a good shape. // For example it is either a fresh context or already authenticated context that needs renegotiation. // - private void ProcessAuthentication(LazyAsyncResult lazyResult, CancellationToken cancellationToken) - { - } - - private void EndProcessAuthentication(IAsyncResult result) + private Task ProcessAuthentication(bool isAsync = false, bool isApm = false, CancellationToken cancellationToken = default) { + return Task.Run(() => {}); } private void ReturnReadBufferIfEmpty()