From 8a9929bf0b6ae93ea10f4c3e71b03b9958411c1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marie=20P=C3=ADchov=C3=A1?= Date: Mon, 12 Feb 2024 20:40:41 +0000 Subject: [PATCH 1/5] [release/7.0] Update MsQuic --- eng/Versions.props | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eng/Versions.props b/eng/Versions.props index ce588cb6a1cf9..6c494c3348c9a 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -179,7 +179,7 @@ 7.0.0-rtm.24060.3 - 2.2.3 + 2.2.5-ci.444313 7.0.0-alpha.1.22459.1 11.1.0-alpha.1.23115.1 From 24a00b5b75b4e60620768e02c11e79bc65386e5b Mon Sep 17 00:00:00 2001 From: Natalia Kondratyeva Date: Wed, 14 Feb 2024 18:03:02 +0000 Subject: [PATCH 2/5] [release/7.0] Fix HTTP/2 WebSocket Abort --- .../Net/Http/Http2LoopbackConnection.cs | 4 +- .../SocketsHttpHandler/Http2Connection.cs | 8 + .../Http/SocketsHttpHandler/Http2Stream.cs | 97 +++++-- .../HttpClientHandlerTest.Http2.cs | 212 ++++++++++++--- .../tests/AbortTest.Loopback.cs | 246 ++++++++++++++++++ .../tests/LoopbackHelper.cs | 21 +- .../LoopbackServer/Http2LoopbackStream.cs | 100 +++++++ .../LoopbackServer/LoopbackWebSocketServer.cs | 148 +++++++++++ .../WebSocketHandshakeHelper.cs | 134 ++++++++++ .../LoopbackServer/WebSocketRequestData.cs | 20 ++ .../System.Net.WebSockets.Client.Tests.csproj | 5 + 11 files changed, 929 insertions(+), 66 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs diff --git a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs index 6263483807c76..9863d0adbe9ab 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs @@ -402,7 +402,7 @@ public async Task ReadRequestHeaderFrameAsync(bool expectEndOfStre return (HeadersFrame)frame; } - public async Task ReadDataFrameAsync() + public async Task ReadDataFrameAsync() { // Receive DATA frame for request. Frame frame = await ReadFrameAsync(_timeout).ConfigureAwait(false); @@ -412,7 +412,7 @@ public async Task ReadDataFrameAsync() } Assert.Equal(FrameType.Data, frame.Type); - return frame; + return (DataFrame)frame; } private static (int bytesConsumed, int value) DecodeInteger(ReadOnlySpan headerBlock, byte prefixMask) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index 45dc67d7342c5..00b8f21fa8dd3 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -1444,6 +1444,14 @@ private int WriteHeaderCollection(HttpRequestMessage request, HttpHeaders header continue; } + // Extended connect requests will use the response content stream for bidirectional communication. + // We will ignore any content set for such requests in Http2Stream.SendRequestBodyAsync, as it has no defined semantics. + // Drop the Content-Length header as well in the unlikely case it was set. + if (knownHeader == KnownHeaders.ContentLength && request.IsExtendedConnectRequest) + { + continue; + } + // For all other known headers, send them via their pre-encoded name and the associated value. WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer); string? separator = null; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs index 89cacb066ece5..edde62b82a56a 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs @@ -105,7 +105,9 @@ public Http2Stream(HttpRequestMessage request, Http2Connection connection) _headerBudgetRemaining = connection._pool.Settings.MaxResponseHeadersByteLength; - if (_request.Content == null) + // Extended connect requests will use the response content stream for bidirectional communication. + // We will ignore any content set for such requests in SendRequestBodyAsync, as it has no defined semantics. + if (_request.Content == null || _request.IsExtendedConnectRequest) { _requestCompletionState = StreamCompletionState.Completed; if (_request.IsExtendedConnectRequest) @@ -173,7 +175,9 @@ public HttpResponseMessage GetAndClearResponse() public async Task SendRequestBodyAsync(CancellationToken cancellationToken) { - if (_request.Content == null) + // Extended connect requests will use the response content stream for bidirectional communication. + // Ignore any content set for such requests, as it has no defined semantics. + if (_request.Content == null || _request.IsExtendedConnectRequest) { Debug.Assert(_requestCompletionState == StreamCompletionState.Completed); return; @@ -250,6 +254,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken) // and we also don't want to propagate any error to the caller, in particular for non-duplex scenarios. Debug.Assert(_responseCompletionState == StreamCompletionState.Completed); _requestCompletionState = StreamCompletionState.Completed; + Debug.Assert(!ConnectProtocolEstablished); Complete(); return; } @@ -261,6 +266,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken) _requestCompletionState = StreamCompletionState.Failed; SendReset(); + Debug.Assert(!ConnectProtocolEstablished); Complete(); } @@ -313,6 +319,7 @@ public async Task SendRequestBodyAsync(CancellationToken cancellationToken) if (complete) { + Debug.Assert(!ConnectProtocolEstablished); Complete(); } } @@ -420,7 +427,17 @@ private void Cancel() if (sendReset) { SendReset(); - Complete(); + + // Extended CONNECT notes: + // + // To prevent from calling it *twice*, Extended CONNECT stream's Complete() is only + // called from CloseResponseBody(), as CloseResponseBody() is *always* called + // from Extended CONNECT stream's Dispose(). + + if (!ConnectProtocolEstablished) + { + Complete(); + } } } @@ -810,7 +827,20 @@ public void OnHeadersComplete(bool endStream) Debug.Assert(_responseCompletionState == StreamCompletionState.InProgress, $"Response already completed with state={_responseCompletionState}"); _responseCompletionState = StreamCompletionState.Completed; - if (_requestCompletionState == StreamCompletionState.Completed) + + // Extended CONNECT notes: + // + // To prevent from calling it *prematurely*, Extended CONNECT stream's Complete() is only + // called from CloseResponseBody(), as CloseResponseBody() is *only* called + // from Extended CONNECT stream's Dispose(). + // + // Due to bidirectional streaming nature of the Extended CONNECT request, + // the *write side* of the stream can only be completed by calling Dispose(). + // + // The streaming in both ways happens over the single "response" stream instance, which makes + // _requestCompletionState *not indicative* of the actual state of the write side of the stream. + + if (_requestCompletionState == StreamCompletionState.Completed && !ConnectProtocolEstablished) { Complete(); } @@ -871,7 +901,20 @@ public void OnResponseData(ReadOnlySpan buffer, bool endStream) Debug.Assert(_responseCompletionState == StreamCompletionState.InProgress, $"Response already completed with state={_responseCompletionState}"); _responseCompletionState = StreamCompletionState.Completed; - if (_requestCompletionState == StreamCompletionState.Completed) + + // Extended CONNECT notes: + // + // To prevent from calling it *prematurely*, Extended CONNECT stream's Complete() is only + // called from CloseResponseBody(), as CloseResponseBody() is *only* called + // from Extended CONNECT stream's Dispose(). + // + // Due to bidirectional streaming nature of the Extended CONNECT request, + // the *write side* of the stream can only be completed by calling Dispose(). + // + // The streaming in both ways happens over the single "response" stream instance, which makes + // _requestCompletionState *not indicative* of the actual state of the write side of the stream. + + if (_requestCompletionState == StreamCompletionState.Completed && !ConnectProtocolEstablished) { Complete(); } @@ -1036,17 +1079,17 @@ public async Task ReadResponseHeadersAsync(CancellationToken cancellationToken) Debug.Assert(_response != null && _response.Content != null); // Start to process the response body. var responseContent = (HttpConnectionResponseContent)_response.Content; - if (emptyResponse) + if (ConnectProtocolEstablished) + { + responseContent.SetStream(new Http2ReadWriteStream(this, closeResponseBodyOnDispose: true)); + } + else if (emptyResponse) { // If there are any trailers, copy them over to the response. Normally this would be handled by // the response stream hitting EOF, but if there is no response body, we do it here. MoveTrailersToResponseMessage(_response); responseContent.SetStream(EmptyReadStream.Instance); } - else if (ConnectProtocolEstablished) - { - responseContent.SetStream(new Http2ReadWriteStream(this)); - } else { responseContent.SetStream(new Http2ReadStream(this)); @@ -1309,8 +1352,25 @@ private async ValueTask SendDataAsync(ReadOnlyMemory buffer, CancellationT } } + // This method should only be called from Http2ReadWriteStream.Dispose() private void CloseResponseBody() { + // Extended CONNECT notes: + // + // Due to bidirectional streaming nature of the Extended CONNECT request, + // the *write side* of the stream can only be completed by calling Dispose() + // (which, for Extended CONNECT case, will in turn call CloseResponseBody()) + // + // Similarly to QuicStream, disposal *gracefully* closes the write side of the stream + // (unless we've received RST_STREAM before) and *abortively* closes the read side + // of the stream (unless we've received EOS before). + + if (ConnectProtocolEstablished && _resetException is null) + { + // Gracefully close the write side of the Extended CONNECT stream + _connection.LogExceptions(_connection.SendEndStreamAsync(StreamId)); + } + // Check if the response body has been fully consumed. bool fullyConsumed = false; Debug.Assert(!Monitor.IsEntered(SyncObject)); @@ -1323,6 +1383,7 @@ private void CloseResponseBody() } // If the response body isn't completed, cancel it now. + // This includes aborting the read side of the Extended CONNECT stream. if (!fullyConsumed) { Cancel(); @@ -1337,6 +1398,12 @@ private void CloseResponseBody() lock (SyncObject) { + if (ConnectProtocolEstablished) + { + // This should be the only place where Extended Connect stream is completed + Complete(); + } + _responseBuffer.Dispose(); } } @@ -1430,10 +1497,7 @@ private enum StreamCompletionState : byte private sealed class Http2ReadStream : Http2ReadWriteStream { - public Http2ReadStream(Http2Stream http2Stream) : base(http2Stream) - { - base.CloseResponseBodyOnDispose = true; - } + public Http2ReadStream(Http2Stream http2Stream) : base(http2Stream, closeResponseBodyOnDispose: true) { } public override bool CanWrite => false; @@ -1482,12 +1546,13 @@ public class Http2ReadWriteStream : HttpBaseStream private Http2Stream? _http2Stream; private readonly HttpResponseMessage _responseMessage; - public Http2ReadWriteStream(Http2Stream http2Stream) + public Http2ReadWriteStream(Http2Stream http2Stream, bool closeResponseBodyOnDispose = false) { Debug.Assert(http2Stream != null); Debug.Assert(http2Stream._response != null); _http2Stream = http2Stream; _responseMessage = _http2Stream._response; + CloseResponseBodyOnDispose = closeResponseBodyOnDispose; } ~Http2ReadWriteStream() @@ -1503,7 +1568,7 @@ public Http2ReadWriteStream(Http2Stream http2Stream) } } - protected bool CloseResponseBodyOnDispose { get; set; } + protected bool CloseResponseBodyOnDispose { get; private init; } protected override void Dispose(bool disposing) { diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs index 7b14f99393c28..153cdf9c8c6d6 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http2.cs @@ -2516,66 +2516,198 @@ public async Task PostAsyncDuplex_ServerSendsEndStream_Success() } } - [Fact] - public async Task ConnectAsync_ReadWriteWebSocketStream() + [Theory] + [MemberData(nameof(UseSsl_MemberData))] + public async Task ExtendedConnect_ReadWriteResponseStream(bool useSsl) { - var clientMessage = new byte[] { 1, 2, 3 }; - var serverMessage = new byte[] { 4, 5, 6, 7 }; + const int MessageCount = 3; + byte[] clientMessage = new byte[] { 1, 2, 3 }; + byte[] serverMessage = new byte[] { 4, 5, 6, 7 }; - using Http2LoopbackServer server = Http2LoopbackServer.CreateServer(); - Http2LoopbackConnection connection = null; + TaskCompletionSource clientCompleted = new(TaskCreationOptions.RunContinuationsAsynchronously); - Task serverTask = Task.Run(async () => + await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync(async uri => { - connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }); + using HttpClient client = CreateHttpClient(); - // read request headers - (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); + HttpRequestMessage request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true); + request.Headers.Protocol = "foo"; + + bool readFromContentStream = false; + + // We won't send the content bytes, but we will send content headers. + // Since we're dropping the content, we'll also drop the Content-Length header. + request.Content = new StreamContent(new DelegateStream( + readAsyncFunc: (_, _, _, _) => + { + readFromContentStream = true; + throw new UnreachableException(); + })); + + request.Headers.Add("User-Agent", "foo"); + request.Content.Headers.Add("Content-Language", "bar"); + request.Content.Headers.ContentLength = 42; + + using HttpResponseMessage response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); + + using Stream responseStream = await response.Content.ReadAsStreamAsync(); + + for (int i = 0; i < MessageCount; i++) + { + await responseStream.WriteAsync(clientMessage); + await responseStream.FlushAsync(); + + byte[] readBuffer = new byte[serverMessage.Length]; + await responseStream.ReadExactlyAsync(readBuffer); + Assert.Equal(serverMessage, readBuffer); + } + + // Receive server's EOS + Assert.Equal(0, await responseStream.ReadAsync(new byte[1])); + + Assert.False(readFromContentStream); + + clientCompleted.SetResult(); + }, + async server => + { + await using Http2LoopbackConnection connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }); + + (int streamId, HttpRequestData request) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); + + Assert.Equal("foo", request.GetSingleHeaderValue("User-Agent")); + Assert.Equal("bar", request.GetSingleHeaderValue("Content-Language")); + Assert.Equal(0, request.GetHeaderValueCount("Content-Length")); - // send response headers await connection.SendResponseHeadersAsync(streamId, endStream: false).ConfigureAwait(false); - // send reply - await connection.SendResponseDataAsync(streamId, serverMessage, endStream: false); + for (int i = 0; i < MessageCount; i++) + { + DataFrame dataFrame = await connection.ReadDataFrameAsync(); + Assert.Equal(clientMessage, dataFrame.Data.ToArray()); - // send server EOS - await connection.SendResponseDataAsync(streamId, Array.Empty(), endStream: true); - }); + await connection.SendResponseDataAsync(streamId, serverMessage, endStream: i == MessageCount - 1); + } - StreamingHttpContent requestContent = new StreamingHttpContent(); + await clientCompleted.Task.WaitAsync(TestHelper.PassingTestTimeout); + }, options: new GenericLoopbackOptions { UseSsl = useSsl }); + } - using var handler = CreateSocketsHttpHandler(allowAllCertificates: true); - using HttpClient client = new HttpClient(handler); + public static IEnumerable UseSsl_MemberData() + { + yield return new object[] { false }; - HttpRequestMessage request = new(HttpMethod.Connect, server.Address); - request.Version = HttpVersion.Version20; - request.VersionPolicy = HttpVersionPolicy.RequestVersionExact; - request.Headers.Protocol = "websocket"; + if (PlatformDetection.SupportsAlpn) + { + yield return new object[] { true }; + } + } - // initiate request - var responseTask = client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); + [Theory] + [MemberData(nameof(UseSsl_MemberData))] + public async Task ExtendedConnect_ServerSideEOS_ReceivedByClient(bool useSsl) + { + var timeoutTcs = new CancellationTokenSource(TestHelper.PassingTestTimeout); + var serverReceivedEOS = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - using HttpResponseMessage response = await responseTask.WaitAsync(TimeSpan.FromSeconds(10)); + await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync( + clientFunc: async uri => + { + var client = CreateHttpClient(); + var request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true); + request.Headers.Protocol = "foo"; + + var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutTcs.Token); + var responseStream = await response.Content.ReadAsStreamAsync(timeoutTcs.Token); + + // receive server's EOS + Assert.Equal(0, await responseStream.ReadAsync(new byte[1], timeoutTcs.Token)); + + // send client's EOS + responseStream.Dispose(); + + // wait for "ack" from server + await serverReceivedEOS.Task.WaitAsync(timeoutTcs.Token); + + // can dispose handler now + client.Dispose(); + }, + serverFunc: async server => + { + await using var connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync( + new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }); + + (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); + await connection.SendResponseHeadersAsync(streamId, endStream: false); + + // send server's EOS + await connection.SendResponseDataAsync(streamId, Array.Empty(), endStream: true); + + // receive client's EOS "in response" to server's EOS + var eosFrame = Assert.IsType(await connection.ReadFrameAsync(timeoutTcs.Token)); + Assert.Equal(streamId, eosFrame.StreamId); + Assert.Equal(0, eosFrame.Data.Length); + Assert.True(eosFrame.EndStreamFlag); + + serverReceivedEOS.SetResult(); + + // on handler dispose, client should shutdown the connection without sending additional frames + await connection.WaitForClientDisconnectAsync().WaitAsync(timeoutTcs.Token); + }, + options: new GenericLoopbackOptions { UseSsl = useSsl }); + } + + [Theory] + [MemberData(nameof(UseSsl_MemberData))] + public async Task ExtendedConnect_ClientSideEOS_ReceivedByServer(bool useSsl) + { + var timeoutTcs = new CancellationTokenSource(TestHelper.PassingTestTimeout); + var serverReceivedRst = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await Http2LoopbackServerFactory.Singleton.CreateClientAndServerAsync( + clientFunc: async uri => + { + var client = CreateHttpClient(); + var request = CreateRequest(HttpMethod.Connect, uri, UseVersion, exactVersion: true); + request.Headers.Protocol = "foo"; + + var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, timeoutTcs.Token); + var responseStream = await response.Content.ReadAsStreamAsync(timeoutTcs.Token); - await serverTask.WaitAsync(TimeSpan.FromSeconds(60)); + // send client's EOS + // this will also send RST_STREAM as we didn't receive server's EOS before + responseStream.Dispose(); + + // wait for "ack" from server + await serverReceivedRst.Task.WaitAsync(timeoutTcs.Token); + + // can dispose handler now + client.Dispose(); + }, + serverFunc: async server => + { + await using var connection = await ((Http2LoopbackServer)server).EstablishConnectionAsync( + new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }); - var responseStream = await response.Content.ReadAsStreamAsync(); + (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync(readBody: false); + await connection.SendResponseHeadersAsync(streamId, endStream: false); - // receive data - var readBuffer = new byte[10]; - int bytesRead = await responseStream.ReadAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10)); - Assert.Equal(bytesRead, serverMessage.Length); - Assert.Equal(serverMessage, readBuffer[..bytesRead]); + // receive client's EOS + var eosFrame = Assert.IsType(await connection.ReadFrameAsync(timeoutTcs.Token)); + Assert.Equal(streamId, eosFrame.StreamId); + Assert.Equal(0, eosFrame.Data.Length); + Assert.True(eosFrame.EndStreamFlag); - await responseStream.WriteAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10)); + // receive client's RST_STREAM as we didn't send server's EOS before + var rstFrame = Assert.IsType(await connection.ReadFrameAsync(timeoutTcs.Token)); + Assert.Equal(streamId, rstFrame.StreamId); - // Send client's EOS - requestContent.CompleteStream(); - // Receive server's EOS - Assert.Equal(0, await responseStream.ReadAsync(readBuffer).AsTask().WaitAsync(TimeSpan.FromSeconds(10))); + serverReceivedRst.SetResult(); - Assert.NotNull(connection); - await connection.DisposeAsync(); + // on handler dispose, client should shutdown the connection without sending additional frames + await connection.WaitForClientDisconnectAsync().WaitAsync(timeoutTcs.Token); + }, + options: new GenericLoopbackOptions { UseSsl = useSsl }); } [Fact] diff --git a/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs new file mode 100644 index 0000000000000..0aa83697a9de7 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/AbortTest.Loopback.cs @@ -0,0 +1,246 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.WebSockets.Client.Tests +{ + [ConditionalClass(typeof(ClientWebSocketTestBase), nameof(WebSocketsSupported))] + [SkipOnPlatform(TestPlatforms.Browser, "System.Net.Sockets are not supported on browser")] + public abstract class AbortTest_Loopback : ClientWebSocketTestBase + { + public AbortTest_Loopback(ITestOutputHelper output) : base(output) { } + + protected virtual Version HttpVersion => Net.HttpVersion.Version11; + + [Theory] + [MemberData(nameof(AbortClient_MemberData))] + public Task AbortClient_ServerGetsCorrectException(AbortType abortType, bool useSsl, bool verifySendReceive) + { + var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 }; + var serverMsg = new byte[] { 42 }; + var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds); + + return LoopbackWebSocketServer.RunAsync( + async (clientWebSocket, token) => + { + if (verifySendReceive) + { + await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token); + } + + switch (abortType) + { + case AbortType.Abort: + clientWebSocket.Abort(); + break; + case AbortType.Dispose: + clientWebSocket.Dispose(); + break; + } + }, + async (serverWebSocket, token) => + { + if (verifySendReceive) + { + await VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, token); + } + + var readBuffer = new byte[1]; + var exception = await Assert.ThrowsAsync(async () => + await serverWebSocket.ReceiveAsync(readBuffer, token)); + + Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode); + Assert.Equal(WebSocketState.Aborted, serverWebSocket.State); + }, + new LoopbackWebSocketServer.Options(HttpVersion, useSsl, GetInvoker()), + timeoutCts.Token); + } + + [Theory] + [MemberData(nameof(ServerPrematureEos_MemberData))] + public Task ServerPrematureEos_ClientGetsCorrectException(ServerEosType serverEosType, bool useSsl) + { + var clientMsg = new byte[] { 1, 2, 3, 4, 5, 6 }; + var serverMsg = new byte[] { 42 }; + var clientAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var serverAckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var timeoutCts = new CancellationTokenSource(TimeOutMilliseconds); + + var globalOptions = new LoopbackWebSocketServer.Options(HttpVersion, useSsl, HttpInvoker: null) + { + DisposeServerWebSocket = false, + ManualServerHandshakeResponse = true + }; + + var serverReceivedEosTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientReceivedEosTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + return LoopbackWebSocketServer.RunAsync( + async uri => + { + var token = timeoutCts.Token; + var clientOptions = globalOptions with { HttpInvoker = GetInvoker() }; + var clientWebSocket = await LoopbackWebSocketServer.GetConnectedClientAsync(uri, clientOptions, token).ConfigureAwait(false); + + if (serverEosType == ServerEosType.AfterSomeData) + { + await VerifySendReceiveAsync(clientWebSocket, clientMsg, serverMsg, clientAckTcs, serverAckTcs.Task, token).ConfigureAwait(false); + } + + // only one side of the stream was closed. the other should work + await clientWebSocket.SendAsync(clientMsg, WebSocketMessageType.Binary, endOfMessage: true, token).ConfigureAwait(false); + + var exception = await Assert.ThrowsAsync(() => clientWebSocket.ReceiveAsync(new byte[1], token)); + Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode); + + clientReceivedEosTcs.SetResult(); + clientWebSocket.Dispose(); + }, + async (requestData, token) => + { + WebSocket serverWebSocket = null!; + await SendServerResponseAndEosAsync( + requestData, + serverEosType, + (wsData, ct) => + { + var wsOptions = new WebSocketCreationOptions { IsServer = true }; + serverWebSocket = WebSocket.CreateFromStream(wsData.WebSocketStream, wsOptions); + + return serverEosType == ServerEosType.AfterSomeData + ? VerifySendReceiveAsync(serverWebSocket, serverMsg, clientMsg, serverAckTcs, clientAckTcs.Task, ct) + : Task.CompletedTask; + }, + token); + + Assert.NotNull(serverWebSocket); + + // only one side of the stream was closed. the other should work + var readBuffer = new byte[clientMsg.Length]; + var result = await serverWebSocket.ReceiveAsync(readBuffer, token); + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + Assert.Equal(clientMsg.Length, result.Count); + Assert.True(result.EndOfMessage); + Assert.Equal(clientMsg, readBuffer); + + await clientReceivedEosTcs.Task.WaitAsync(token).ConfigureAwait(false); + + var exception = await Assert.ThrowsAsync(() => serverWebSocket.ReceiveAsync(readBuffer, token)); + Assert.Equal(WebSocketError.ConnectionClosedPrematurely, exception.WebSocketErrorCode); + + serverWebSocket.Dispose(); + }, + globalOptions, + timeoutCts.Token); + } + + protected virtual Task SendServerResponseAndEosAsync(WebSocketRequestData requestData, ServerEosType serverEosType, Func serverFunc, CancellationToken cancellationToken) + => WebSocketHandshakeHelper.SendHttp11ServerResponseAndEosAsync(requestData, serverFunc, cancellationToken); // override for HTTP/2 + + private static readonly bool[] Bool_Values = new[] { false, true }; + private static readonly bool[] UseSsl_Values = PlatformDetection.SupportsAlpn ? Bool_Values : new[] { false }; + + public static IEnumerable AbortClient_MemberData() + { + foreach (var abortType in Enum.GetValues()) + { + foreach (var useSsl in UseSsl_Values) + { + foreach (var verifySendReceive in Bool_Values) + { + yield return new object[] { abortType, useSsl, verifySendReceive }; + } + } + } + } + + public static IEnumerable ServerPrematureEos_MemberData() + { + foreach (var serverEosType in Enum.GetValues()) + { + foreach (var useSsl in UseSsl_Values) + { + yield return new object[] { serverEosType, useSsl }; + } + } + } + + public enum AbortType + { + Abort, + Dispose + } + + public enum ServerEosType + { + WithHeaders, + RightAfterHeaders, + AfterSomeData + } + + private static async Task VerifySendReceiveAsync(WebSocket ws, byte[] localMsg, byte[] remoteMsg, + TaskCompletionSource localAckTcs, Task remoteAck, CancellationToken cancellationToken) + { + var sendTask = ws.SendAsync(localMsg, WebSocketMessageType.Binary, endOfMessage: true, cancellationToken); + + var recvBuf = new byte[remoteMsg.Length * 2]; + var recvResult = await ws.ReceiveAsync(recvBuf, cancellationToken).ConfigureAwait(false); + + Assert.Equal(WebSocketMessageType.Binary, recvResult.MessageType); + Assert.Equal(remoteMsg.Length, recvResult.Count); + Assert.True(recvResult.EndOfMessage); + Assert.Equal(remoteMsg, recvBuf[..recvResult.Count]); + + localAckTcs.SetResult(); + + await sendTask.ConfigureAwait(false); + await remoteAck.WaitAsync(cancellationToken).ConfigureAwait(false); + } + } + + // --- HTTP/1.1 WebSocket loopback tests --- + + public class AbortTest_Invoker_Loopback : AbortTest_Loopback + { + public AbortTest_Invoker_Loopback(ITestOutputHelper output) : base(output) { } + protected override bool UseCustomInvoker => true; + } + + public class AbortTest_HttpClient_Loopback : AbortTest_Loopback + { + public AbortTest_HttpClient_Loopback(ITestOutputHelper output) : base(output) { } + protected override bool UseHttpClient => true; + } + + public class AbortTest_SharedHandler_Loopback : AbortTest_Loopback + { + public AbortTest_SharedHandler_Loopback(ITestOutputHelper output) : base(output) { } + } + + // --- HTTP/2 WebSocket loopback tests --- + + public class AbortTest_Invoker_Http2 : AbortTest_Invoker_Loopback + { + public AbortTest_Invoker_Http2(ITestOutputHelper output) : base(output) { } + protected override Version HttpVersion => Net.HttpVersion.Version20; + protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func callback, CancellationToken ct) + => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct); + } + + public class AbortTest_HttpClient_Http2 : AbortTest_HttpClient_Loopback + { + public AbortTest_HttpClient_Http2(ITestOutputHelper output) : base(output) { } + protected override Version HttpVersion => Net.HttpVersion.Version20; + protected override Task SendServerResponseAndEosAsync(WebSocketRequestData rd, ServerEosType eos, Func callback, CancellationToken ct) + => WebSocketHandshakeHelper.SendHttp2ServerResponseAndEosAsync(rd, eosInHeadersFrame: eos == ServerEosType.WithHeaders, callback, ct); + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs index 48d167b072f78..cee509ee06846 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs @@ -28,14 +28,7 @@ public static async Task> WebSocketHandshakeAsync(Loo if (headerName == "Sec-WebSocket-Key") { string headerValue = tokens[1].Trim(); - string responseSecurityAcceptValue = ComputeWebSocketHandshakeSecurityAcceptValue(headerValue); - serverResponse = - "HTTP/1.1 101 Switching Protocols\r\n" + - "Content-Length: 0\r\n" + - "Upgrade: websocket\r\n" + - "Connection: Upgrade\r\n" + - (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") + - "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n"; + serverResponse = GetServerResponseString(headerValue, extensions); } } } @@ -50,6 +43,18 @@ public static async Task> WebSocketHandshakeAsync(Loo return null; } + public static string GetServerResponseString(string secWebSocketKey, string? extensions = null) + { + var responseSecurityAcceptValue = ComputeWebSocketHandshakeSecurityAcceptValue(secWebSocketKey); + return + "HTTP/1.1 101 Switching Protocols\r\n" + + "Content-Length: 0\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") + + "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n"; + } + private static string ComputeWebSocketHandshakeSecurityAcceptValue(string secWebSocketKey) { // GUID specified by RFC 6455. diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs new file mode 100644 index 0000000000000..1b3b51840ec99 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/Http2LoopbackStream.cs @@ -0,0 +1,100 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.Test.Common +{ + public class Http2LoopbackStream : Stream + { + private readonly Http2LoopbackConnection _connection; + private readonly int _streamId; + private bool _readEnded; + private ReadOnlyMemory _leftoverReadData; + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + + public Http2LoopbackConnection Connection => _connection; + public int StreamId => _streamId; + + public Http2LoopbackStream(Http2LoopbackConnection connection, int streamId) + { + _connection = connection; + _streamId = streamId; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (!_leftoverReadData.IsEmpty) + { + int read = Math.Min(buffer.Length, _leftoverReadData.Length); + _leftoverReadData.Span.Slice(0, read).CopyTo(buffer.Span); + _leftoverReadData = _leftoverReadData.Slice(read); + return read; + } + + if (_readEnded) + { + return 0; + } + + DataFrame dataFrame = (DataFrame)await _connection.ReadFrameAsync(cancellationToken); + Assert.Equal(_streamId, dataFrame.StreamId); + _leftoverReadData = dataFrame.Data; + _readEnded = dataFrame.EndStreamFlag; + + return await ReadAsync(buffer, cancellationToken); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => + ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + await _connection.SendResponseDataAsync(_streamId, buffer, endStream: false); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => + WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + + protected override void Dispose(bool disposing) => DisposeAsync().GetAwaiter().GetResult(); + + public override async ValueTask DisposeAsync() + { + try + { + await _connection.SendResponseDataAsync(_streamId, Memory.Empty, endStream: true).ConfigureAwait(false); + + if (!_readEnded) + { + var rstFrame = new RstStreamFrame(FrameFlags.None, (int)ProtocolErrors.NO_ERROR, _streamId); + await _connection.WriteFrameAsync(rstFrame).ConfigureAwait(false); + } + } + catch (IOException) + { + // Ignore connection errors + } + catch (SocketException) + { + // Ignore connection errors + } + } + + public override void Flush() { } + public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; + + public override int Read(byte[] buffer, int offset, int count) => throw new NotImplementedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException(); + public override void SetLength(long value) => throw new NotImplementedException(); + public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException(); + public override long Length => throw new NotImplementedException(); + public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs new file mode 100644 index 0000000000000..b24e2e20d40df --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/LoopbackWebSocketServer.cs @@ -0,0 +1,148 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.Http; +using System.Net.Test.Common; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.WebSockets.Client.Tests +{ + public static class LoopbackWebSocketServer + { + public static Task RunAsync( + Func clientWebSocketFunc, + Func serverWebSocketFunc, + Options options, + CancellationToken cancellationToken) + { + Assert.False(options.ManualServerHandshakeResponse, "Not supported in this overload"); + + return RunAsyncPrivate( + uri => RunClientAsync(uri, clientWebSocketFunc, options, cancellationToken), + (requestData, token) => RunServerAsync(requestData, serverWebSocketFunc, options, token), + options, + cancellationToken); + } + + public static Task RunAsync( + Func loopbackClientFunc, + Func loopbackServerFunc, + Options options, + CancellationToken cancellationToken) + { + Assert.False(options.DisposeClientWebSocket, "Not supported in this overload"); + Assert.False(options.DisposeServerWebSocket, "Not supported in this overload"); + Assert.False(options.DisposeHttpInvoker, "Not supported in this overload"); + Assert.Null(options.HttpInvoker); // Not supported in this overload + + return RunAsyncPrivate(loopbackClientFunc, loopbackServerFunc, options, cancellationToken); + } + + private static Task RunAsyncPrivate( + Func loopbackClientFunc, + Func loopbackServerFunc, + Options options, + CancellationToken cancellationToken) + { + bool sendDefaultServerHandshakeResponse = !options.ManualServerHandshakeResponse; + if (options.HttpVersion == HttpVersion.Version11) + { + return LoopbackServer.CreateClientAndServerAsync( + loopbackClientFunc, + async server => + { + await server.AcceptConnectionAsync(async connection => + { + var requestData = await WebSocketHandshakeHelper.ProcessHttp11RequestAsync(connection, sendDefaultServerHandshakeResponse, cancellationToken).ConfigureAwait(false); + await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false); + }); + }, + new LoopbackServer.Options { WebSocketEndpoint = true, UseSsl = options.UseSsl }); + } + else if (options.HttpVersion == HttpVersion.Version20) + { + return Http2LoopbackServer.CreateClientAndServerAsync( + loopbackClientFunc, + async server => + { + var requestData = await WebSocketHandshakeHelper.ProcessHttp2RequestAsync(server, sendDefaultServerHandshakeResponse, cancellationToken).ConfigureAwait(false); + var http2Connection = requestData.Http2Connection!; + var http2StreamId = requestData.Http2StreamId.Value; + + await loopbackServerFunc(requestData, cancellationToken).ConfigureAwait(false); + + await http2Connection.DisposeAsync().ConfigureAwait(false); + }, + new Http2Options { WebSocketEndpoint = true, UseSsl = options.UseSsl }); + } + else + { + throw new ArgumentException(nameof(options.HttpVersion)); + } + } + + private static async Task RunServerAsync( + WebSocketRequestData requestData, + Func serverWebSocketFunc, + Options options, + CancellationToken cancellationToken) + { + var wsOptions = new WebSocketCreationOptions { IsServer = true }; + var serverWebSocket = WebSocket.CreateFromStream(requestData.WebSocketStream, wsOptions); + + await serverWebSocketFunc(serverWebSocket, cancellationToken).ConfigureAwait(false); + + if (options.DisposeServerWebSocket) + { + serverWebSocket.Dispose(); + } + } + + private static async Task RunClientAsync( + Uri uri, + Func clientWebSocketFunc, + Options options, + CancellationToken cancellationToken) + { + var clientWebSocket = await GetConnectedClientAsync(uri, options, cancellationToken).ConfigureAwait(false); + + await clientWebSocketFunc(clientWebSocket, cancellationToken).ConfigureAwait(false); + + if (options.DisposeClientWebSocket) + { + clientWebSocket.Dispose(); + } + + if (options.DisposeHttpInvoker) + { + options.HttpInvoker?.Dispose(); + } + } + + public static async Task GetConnectedClientAsync(Uri uri, Options options, CancellationToken cancellationToken) + { + var clientWebSocket = new ClientWebSocket(); + clientWebSocket.Options.HttpVersion = options.HttpVersion; + clientWebSocket.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact; + + if (options.UseSsl && options.HttpInvoker is null) + { + clientWebSocket.Options.RemoteCertificateValidationCallback = delegate { return true; }; + } + + await clientWebSocket.ConnectAsync(uri, options.HttpInvoker, cancellationToken).ConfigureAwait(false); + + return clientWebSocket; + } + + public record class Options(Version HttpVersion, bool UseSsl, HttpMessageInvoker? HttpInvoker) + { + public bool DisposeServerWebSocket { get; set; } = true; + public bool DisposeClientWebSocket { get; set; } + public bool DisposeHttpInvoker { get; set; } + public bool ManualServerHandshakeResponse { get; set; } + } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs new file mode 100644 index 0000000000000..f4d2f42f5edbb --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketHandshakeHelper.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Net.Sockets; +using System.Net.Test.Common; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.WebSockets.Client.Tests +{ + public static class WebSocketHandshakeHelper + { + public static async Task ProcessHttp11RequestAsync(LoopbackServer.Connection connection, bool sendServerResponse = true, CancellationToken cancellationToken = default) + { + List headers = await connection.ReadRequestHeaderAsync().WaitAsync(cancellationToken).ConfigureAwait(false); + + var data = new WebSocketRequestData() + { + HttpVersion = HttpVersion.Version11, + Http11Connection = connection + }; + + foreach (string header in headers.Skip(1)) + { + string[] tokens = header.Split(new char[] { ':' }, StringSplitOptions.RemoveEmptyEntries); + if (tokens.Length is 1 or 2) + { + data.Headers.Add( + tokens[0].Trim(), + tokens.Length == 2 ? tokens[1].Trim() : null); + } + } + + var isValidOpeningHandshake = data.Headers.TryGetValue("Sec-WebSocket-Key", out var secWebSocketKey); + Assert.True(isValidOpeningHandshake); + + if (sendServerResponse) + { + await SendHttp11ServerResponseAsync(connection, secWebSocketKey, cancellationToken).ConfigureAwait(false); + } + + data.WebSocketStream = connection.Stream; + return data; + } + + private static async Task SendHttp11ServerResponseAsync(LoopbackServer.Connection connection, string secWebSocketKey, CancellationToken cancellationToken) + { + var serverResponse = LoopbackHelper.GetServerResponseString(secWebSocketKey); + await connection.WriteStringAsync(serverResponse).WaitAsync(cancellationToken).ConfigureAwait(false); + } + + public static async Task ProcessHttp2RequestAsync(Http2LoopbackServer server, bool sendServerResponse = true, CancellationToken cancellationToken = default) + { + var connection = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.EnableConnect, Value = 1 }) + .WaitAsync(cancellationToken).ConfigureAwait(false); + + (int streamId, var httpRequestData) = await connection.ReadAndParseRequestHeaderAsync(readBody: false) + .WaitAsync(cancellationToken).ConfigureAwait(false); + + var data = new WebSocketRequestData + { + HttpVersion = HttpVersion.Version20, + Http2Connection = connection, + Http2StreamId = streamId + }; + + foreach (var header in httpRequestData.Headers) + { + Assert.NotNull(header.Name); + data.Headers.Add(header.Name, header.Value); + } + + var isValidOpeningHandshake = httpRequestData.Method == HttpMethod.Connect.ToString() && data.Headers.ContainsKey(":protocol"); + Assert.True(isValidOpeningHandshake); + + if (sendServerResponse) + { + await SendHttp2ServerResponseAsync(connection, streamId, cancellationToken: cancellationToken).ConfigureAwait(false); + } + + data.WebSocketStream = new Http2LoopbackStream(connection, streamId); + return data; + } + + private static async Task SendHttp2ServerResponseAsync(Http2LoopbackConnection connection, int streamId, bool endStream = false, CancellationToken cancellationToken = default) + { + // send status 200 OK to establish websocket + // we don't need to send anything additional as Sec-WebSocket-Key is not used for HTTP/2 + // note: endStream=true is abnormal and used for testing premature EOS scenarios only + await connection.SendResponseHeadersAsync(streamId, endStream: endStream).WaitAsync(cancellationToken).ConfigureAwait(false); + } + + public static async Task SendHttp11ServerResponseAndEosAsync(WebSocketRequestData requestData, Func? requestDataCallback, CancellationToken cancellationToken) + { + Assert.Equal(HttpVersion.Version11, requestData.HttpVersion); + + // sending default handshake response + await SendHttp11ServerResponseAsync(requestData.Http11Connection!, requestData.Headers["Sec-WebSocket-Key"], cancellationToken).ConfigureAwait(false); + + if (requestDataCallback is not null) + { + await requestDataCallback(requestData, cancellationToken).ConfigureAwait(false); + } + + // send server EOS (half-closing from server side) + requestData.Http11Connection!.Socket.Shutdown(SocketShutdown.Send); + } + + public static async Task SendHttp2ServerResponseAndEosAsync(WebSocketRequestData requestData, bool eosInHeadersFrame, Func? requestDataCallback, CancellationToken cancellationToken) + { + Assert.Equal(HttpVersion.Version20, requestData.HttpVersion); + + var connection = requestData.Http2Connection!; + var streamId = requestData.Http2StreamId!.Value; + + await SendHttp2ServerResponseAsync(connection, streamId, endStream: eosInHeadersFrame, cancellationToken).ConfigureAwait(false); + + if (requestDataCallback is not null) + { + await requestDataCallback(requestData, cancellationToken).ConfigureAwait(false); + } + + if (!eosInHeadersFrame) + { + // send server EOS (half-closing from server side) + await connection.SendResponseDataAsync(streamId, Array.Empty(), endStream: true).ConfigureAwait(false); + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs new file mode 100644 index 0000000000000..799157a370f07 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackServer/WebSocketRequestData.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.IO; +using System.Net.Test.Common; + +namespace System.Net.WebSockets.Client.Tests +{ + public class WebSocketRequestData + { + public Dictionary Headers { get; set; } = new Dictionary(); + public Stream? WebSocketStream { get; set; } + + public Version HttpVersion { get; set; } + public LoopbackServer.Connection? Http11Connection { get; set; } + public Http2LoopbackConnection? Http2Connection { get; set; } + public int? Http2StreamId { get; set; } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj index 7aff0244108ee..0f918866e8a91 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj +++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj @@ -59,6 +59,7 @@ + @@ -68,6 +69,10 @@ + + + + From fdf31ccdec3c64c090bd141ccc2fc360c7b8eb00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20S=C3=A1nchez=20L=C3=B3pez?= <1175054+carlossanlop@users.noreply.github.com> Date: Tue, 12 Mar 2024 12:23:26 -0700 Subject: [PATCH 3/5] Update MicrosoftNativeQuicMsQuicVersion to 2.3.5 --- eng/Versions.props | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eng/Versions.props b/eng/Versions.props index e168907d30f50..5d7da41853be3 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -179,7 +179,7 @@ 7.0.0-rtm.24115.1 - 2.2.5-ci.444313 + 2.3.5 7.0.0-alpha.1.22459.1 11.1.0-alpha.1.23115.1 From b6e8c6a6fe1b42c66f54ba3678697567fca7113d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20S=C3=A1nchez=20L=C3=B3pez?= <1175054+carlossanlop@users.noreply.github.com> Date: Wed, 13 Mar 2024 12:21:06 -0700 Subject: [PATCH 4/5] [7.0] Append job attempt number to log artifact name to get rid of file exists failure on reattempts (#99711) * Append job attempt number to log artifact name to get rid of file exists failure on reattempts. * Revert unrelated change --- eng/pipelines/official/jobs/prepare-signed-artifacts.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eng/pipelines/official/jobs/prepare-signed-artifacts.yml b/eng/pipelines/official/jobs/prepare-signed-artifacts.yml index 213d56b3cf69e..c5a0bc137249e 100644 --- a/eng/pipelines/official/jobs/prepare-signed-artifacts.yml +++ b/eng/pipelines/official/jobs/prepare-signed-artifacts.yml @@ -2,7 +2,7 @@ parameters: dependsOn: [] PublishRidAgnosticPackagesFromPlatform: '' isOfficialBuild: false - logArtifactName: 'Logs-PrepareSignedArtifacts' + logArtifactName: 'Logs-PrepareSignedArtifacts_Attempt$(System.JobAttempt)' jobs: - job: PrepareSignedArtifacts From 0d59506d6c3912d5ee7c659cedaf92846aaefe78 Mon Sep 17 00:00:00 2001 From: vseanreesermsft <78103370+vseanreesermsft@users.noreply.github.com> Date: Tue, 19 Mar 2024 14:44:29 -0700 Subject: [PATCH 5/5] Fix exporting certificate keys on macOS 14.4. (#99981) Apple changed the error code we get back from a failed data-key export. This caused us to not attempt to export the key using the legacy APIs and assume the key export failed. This change adds the additional error code returned from macOS 14.4. Co-authored-by: Kevin Jones --- .../Interop.SecKeyRef.cs | 6 +- .../tests/CertTests.cs | 105 +++++++++++++++++- 2 files changed, 109 insertions(+), 2 deletions(-) diff --git a/src/libraries/Common/src/Interop/OSX/System.Security.Cryptography.Native.Apple/Interop.SecKeyRef.cs b/src/libraries/Common/src/Interop/OSX/System.Security.Cryptography.Native.Apple/Interop.SecKeyRef.cs index ad5a34517b1a5..b9ce7b498ef61 100644 --- a/src/libraries/Common/src/Interop/OSX/System.Security.Cryptography.Native.Apple/Interop.SecKeyRef.cs +++ b/src/libraries/Common/src/Interop/OSX/System.Security.Cryptography.Native.Apple/Interop.SecKeyRef.cs @@ -127,6 +127,10 @@ internal static bool TrySecKeyCopyExternalRepresentation( { const int errSecPassphraseRequired = -25260; + // macOS Sonoma 14.4 started returning errSecInvalidKeyAttributeMask when a key could not be exported + // because it must be exported with a password. + const int errSecInvalidKeyAttributeMask = -67738; + int result = AppleCryptoNative_SecKeyCopyExternalRepresentation( key, out SafeCFDataHandle data, @@ -141,7 +145,7 @@ internal static bool TrySecKeyCopyExternalRepresentation( externalRepresentation = CoreFoundation.CFGetData(data); return true; case kErrorSeeError: - if (Interop.CoreFoundation.GetErrorCode(errorHandle) == errSecPassphraseRequired) + if (Interop.CoreFoundation.GetErrorCode(errorHandle) is errSecPassphraseRequired or errSecInvalidKeyAttributeMask) { externalRepresentation = Array.Empty(); return false; diff --git a/src/libraries/System.Security.Cryptography.X509Certificates/tests/CertTests.cs b/src/libraries/System.Security.Cryptography.X509Certificates/tests/CertTests.cs index 439252e588c08..15e9be90ae1fd 100644 --- a/src/libraries/System.Security.Cryptography.X509Certificates/tests/CertTests.cs +++ b/src/libraries/System.Security.Cryptography.X509Certificates/tests/CertTests.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Runtime.InteropServices; +using System.Security.Cryptography.X509Certificates.Tests.CertificateCreation; using System.Threading; using Microsoft.DotNet.XUnitExtensions; using Test.Cryptography; @@ -24,6 +25,108 @@ public CertTests(ITestOutputHelper output) _log = output; } + [Fact] + public static void PrivateKey_FromCertificate_CanExportPrivate_ECDsa() + { + using (ECDsa ca = ECDsa.Create(ECCurve.NamedCurves.nistP256)) + { + CertificateRequest req = new("CN=potatos", ca, HashAlgorithmName.SHA256); + + using (X509Certificate2 cert = req.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddDays(3))) + using (ECDsa certKey = cert.GetECDsaPrivateKey()) + { + ECParameters certParameters = certKey.ExportParameters(true); + ECParameters originalParameters = ca.ExportParameters(true); + AssertExtensions.SequenceEqual(originalParameters.D, certParameters.D); + } + } + } + + [Fact] + public static void PrivateKey_FromCertificate_CanExportPrivate_RSA() + { + using (RSA ca = RSA.Create(2048)) + { + CertificateRequest req = new("CN=potatos", ca, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + + using (X509Certificate2 cert = req.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddDays(3))) + using (RSA certKey = cert.GetRSAPrivateKey()) + { + RSAParameters certParameters = certKey.ExportParameters(true); + RSAParameters originalParameters = ca.ExportParameters(true); + AssertExtensions.SequenceEqual(originalParameters.P, certParameters.P); + AssertExtensions.SequenceEqual(originalParameters.Q, certParameters.Q); + } + } + } + + [Fact] + [SkipOnPlatform(PlatformSupport.MobileAppleCrypto, "DSA is not available")] + public static void PrivateKey_FromCertificate_CanExportPrivate_DSA() + { + DSAParameters originalParameters = TestData.GetDSA1024Params(); + + using (DSA ca = DSA.Create()) + { + ca.ImportParameters(originalParameters); + DSAX509SignatureGenerator gen = new DSAX509SignatureGenerator(ca); + X500DistinguishedName dn = new X500DistinguishedName("CN=potatos"); + + CertificateRequest req = new CertificateRequest( + dn, + gen.PublicKey, + HashAlgorithmName.SHA1); + + using (X509Certificate2 cert = req.Create(dn, gen, DateTimeOffset.Now, DateTimeOffset.Now.AddDays(3), new byte[] { 1, 2, 3 })) + using (X509Certificate2 certWithKey = cert.CopyWithPrivateKey(ca)) + using (DSA certKey = certWithKey.GetDSAPrivateKey()) + { + DSAParameters certParameters = certKey.ExportParameters(true); + AssertExtensions.SequenceEqual(originalParameters.X, certParameters.X); + } + } + } + + [Fact] + public static void PrivateKey_FromCertificate_CanExportPrivate_ECDiffieHellman() + { + using (ECDsa ca = ECDsa.Create(ECCurve.NamedCurves.nistP256)) + using (ECDiffieHellman ecdh = ECDiffieHellman.Create(ECCurve.NamedCurves.nistP256)) + { + CertificateRequest issuerRequest = new CertificateRequest( + new X500DistinguishedName("CN=root"), + ca, + HashAlgorithmName.SHA256); + + issuerRequest.CertificateExtensions.Add( + new X509BasicConstraintsExtension(true, false, 0, true)); + + CertificateRequest request = new CertificateRequest( + new X500DistinguishedName("CN=potato"), + new PublicKey(ecdh), + HashAlgorithmName.SHA256); + + request.CertificateExtensions.Add( + new X509BasicConstraintsExtension(false, false, 0, true)); + request.CertificateExtensions.Add( + new X509KeyUsageExtension(X509KeyUsageFlags.KeyAgreement, true)); + + DateTimeOffset notBefore = DateTimeOffset.UtcNow; + DateTimeOffset notAfter = notBefore.AddDays(30); + byte[] serial = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; + + using (X509Certificate2 issuer = issuerRequest.CreateSelfSigned(notBefore, notAfter)) + using (X509Certificate2 cert = request.Create(issuer, notBefore, notAfter, serial)) + using (X509Certificate2 certWithKey = cert.CopyWithPrivateKey(ecdh)) + using (ECDiffieHellman certKey = certWithKey.GetECDiffieHellmanPrivateKey()) + { + ECParameters certParameters = certKey.ExportParameters(true); + ECParameters originalParameters = ecdh.ExportParameters(true); + AssertExtensions.SequenceEqual(originalParameters.D, certParameters.D); + } + } + } + [Fact] public static void RaceUseAndDisposeDoesNotCrash() { @@ -79,7 +182,7 @@ public static void X509CertTest() // GetSerialNumber() returns in little-endian order. Array.Reverse(expectedSerial); AssertExtensions.SequenceEqual(expectedSerial, serial1); - + Assert.Equal("1.2.840.113549.1.1.1", cert.GetKeyAlgorithm()); int pklen = cert.GetPublicKey().Length;