diff --git a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs index 790ad3feb2939..be0b43d4200f3 100644 --- a/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs +++ b/src/libraries/System.Net.Quic/ref/System.Net.Quic.cs @@ -109,6 +109,7 @@ public override void Flush() { } public override void SetLength(long value) { } public void Shutdown() { } public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public System.Threading.Tasks.ValueTask WaitForWriteCompletionAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public override void Write(byte[] buffer, int offset, int count) { } public override void Write(System.ReadOnlySpan buffer) { } public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockConnection.cs index 7487a958db91f..e409b962d0296 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockConnection.cs @@ -8,6 +8,8 @@ using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; +using System.Collections.Concurrent; +using System.Collections.Generic; namespace System.Net.Quic.Implementations.Mock { @@ -244,6 +246,9 @@ internal MockStream OpenStream(long streamId, bool bidirectional) } MockStream.StreamState streamState = new MockStream.StreamState(streamId, bidirectional); + // TODO Streams are never removed from a connection. Consider cleaning up in the future. + state._streams[streamState._streamId] = streamState; + Channel streamChannel = _isClient ? state._clientInitiatedStreamChannel : state._serverInitiatedStreamChannel; streamChannel.Writer.TryWrite(streamState); @@ -320,6 +325,12 @@ internal override ValueTask CloseAsync(long errorCode, CancellationToken cancell state._serverErrorCode = errorCode; DrainAcceptQueue(errorCode, -1); } + + foreach (KeyValuePair kvp in state._streams) + { + kvp.Value._outboundWritesCompletedTcs.TrySetException(new QuicConnectionAbortedException(errorCode)); + kvp.Value._inboundWritesCompletedTcs.TrySetException(new QuicConnectionAbortedException(errorCode)); + } } Dispose(); @@ -474,8 +485,9 @@ public PeerStreamLimit(int maxUnidirectional, int maxBidirectional) internal sealed class ConnectionState { public readonly SslApplicationProtocol _applicationProtocol; - public Channel _clientInitiatedStreamChannel; - public Channel _serverInitiatedStreamChannel; + public readonly Channel _clientInitiatedStreamChannel; + public readonly Channel _serverInitiatedStreamChannel; + public readonly ConcurrentDictionary _streams; public PeerStreamLimit? _clientStreamLimit; public PeerStreamLimit? _serverStreamLimit; @@ -490,6 +502,7 @@ public ConnectionState(SslApplicationProtocol applicationProtocol) _clientInitiatedStreamChannel = Channel.CreateUnbounded(); _serverInitiatedStreamChannel = Channel.CreateUnbounded(); _clientErrorCode = _serverErrorCode = -1; + _streams = new ConcurrentDictionary(); } } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs index 588da85d32f42..fbace756f2fbd 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs @@ -164,6 +164,7 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory buffer, bool e if (endStream) { streamBuffer.EndWrite(); + WritesCompletedTcs.TrySetResult(); } } @@ -206,10 +207,12 @@ internal override void AbortRead(long errorCode) if (_isInitiator) { _streamState._outboundWriteErrorCode = errorCode; + _streamState._inboundWritesCompletedTcs.TrySetException(new QuicStreamAbortedException(errorCode)); } else { _streamState._inboundWriteErrorCode = errorCode; + _streamState._outboundWritesCompletedTcs.TrySetException(new QuicOperationAbortedException()); } ReadStreamBuffer?.AbortRead(); @@ -220,10 +223,12 @@ internal override void AbortWrite(long errorCode) if (_isInitiator) { _streamState._outboundReadErrorCode = errorCode; + _streamState._outboundWritesCompletedTcs.TrySetException(new QuicStreamAbortedException(errorCode)); } else { _streamState._inboundReadErrorCode = errorCode; + _streamState._inboundWritesCompletedTcs.TrySetException(new QuicOperationAbortedException()); } WriteStreamBuffer?.EndWrite(); @@ -251,6 +256,8 @@ internal override void Shutdown() { _connection.LocalStreamLimit!.Bidirectional.Decrement(); } + + WritesCompletedTcs.TrySetResult(); } private void CheckDisposed() @@ -283,6 +290,17 @@ public override ValueTask DisposeAsync() return default; } + internal override ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default) + { + CheckDisposed(); + + return new ValueTask(WritesCompletedTcs.Task); + } + + private TaskCompletionSource WritesCompletedTcs => _isInitiator + ? _streamState._outboundWritesCompletedTcs + : _streamState._inboundWritesCompletedTcs; + internal sealed class StreamState { public readonly long _streamId; @@ -292,6 +310,8 @@ internal sealed class StreamState public long _inboundReadErrorCode; public long _outboundWriteErrorCode; public long _inboundWriteErrorCode; + public TaskCompletionSource _outboundWritesCompletedTcs; + public TaskCompletionSource _inboundWritesCompletedTcs; private const int InitialBufferSize = #if DEBUG @@ -310,6 +330,8 @@ public StreamState(long streamId, bool bidirectional) _streamId = streamId; _outboundStreamBuffer = new StreamBuffer(initialBufferSize: InitialBufferSize, maxBufferSize: MaxBufferSize); _inboundStreamBuffer = (bidirectional ? new StreamBuffer() : null); + _outboundWritesCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _inboundWritesCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); } } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs index 83ccf08906c16..54f0ed3f4ec85 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs @@ -69,6 +69,11 @@ private sealed class State // Resettable completions to be used for multiple calls to send. public readonly ResettableCompletionSource SendResettableCompletionSource = new ResettableCompletionSource(); + public ShutdownWriteState ShutdownWriteState; + + // Set once writes have been shutdown. + public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + public ShutdownState ShutdownState; // The value makes sure that we release the handles only once. public int ShutdownDone; @@ -577,12 +582,26 @@ internal override void AbortWrite(long errorCode) return; } + bool shouldComplete = false; + lock (_state) { if (_state.SendState < SendState.Aborted) { _state.SendState = SendState.Aborted; } + + if (_state.ShutdownWriteState == ShutdownWriteState.None) + { + _state.ShutdownWriteState = ShutdownWriteState.Canceled; + shouldComplete = true; + } + } + + if (shouldComplete) + { + _state.ShutdownWriteCompletionSource.SetException( + ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException("Write was aborted."))); } StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND, errorCode); @@ -629,6 +648,23 @@ internal override async ValueTask ShutdownCompleted(CancellationToken cancellati await _state.ShutdownCompletionSource.Task.ConfigureAwait(false); } + internal override ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default) + { + // TODO: What should happen if this is called for a unidirectional stream and there are no writes? + + ThrowIfDisposed(); + + lock (_state) + { + if (_state.ShutdownWriteState == ShutdownWriteState.ConnectionClosed) + { + throw GetConnectionAbortedException(_state); + } + } + + return new ValueTask(_state.ShutdownWriteCompletionSource.Task.WaitAsync(cancellationToken)); + } + internal override void Shutdown() { ThrowIfDisposed(); @@ -861,6 +897,11 @@ private static uint HandleEvent(State state, ref StreamEvent evt) // Peer has stopped receiving data, don't send anymore. case QUIC_STREAM_EVENT_TYPE.PEER_RECEIVE_ABORTED: return HandleEventPeerRecvAborted(state, ref evt); + // Occurs when shutdown is completed for the send side. + // This only happens for shutdown on sending, not receiving + // Receive shutdown can only be abortive. + case QUIC_STREAM_EVENT_TYPE.SEND_SHUTDOWN_COMPLETE: + return HandleEventSendShutdownComplete(state, ref evt); // Shutdown for both sending and receiving is completed. case QUIC_STREAM_EVENT_TYPE.SHUTDOWN_COMPLETE: return HandleEventShutdownComplete(state, ref evt); @@ -993,23 +1034,37 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt) private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt) { - bool shouldComplete = false; + bool shouldSendComplete = false; + bool shouldShutdownWriteComplete = false; lock (state) { if (state.SendState == SendState.None || state.SendState == SendState.Pending) { - shouldComplete = true; + shouldSendComplete = true; + } + + if (state.ShutdownWriteState == ShutdownWriteState.None) + { + state.ShutdownWriteState = ShutdownWriteState.Canceled; + shouldShutdownWriteComplete = true; } + state.SendState = SendState.Aborted; state.SendErrorCode = (long)evt.Data.PeerReceiveAborted.ErrorCode; } - if (shouldComplete) + if (shouldSendComplete) { state.SendResettableCompletionSource.CompleteException( ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode))); } + if (shouldShutdownWriteComplete) + { + state.ShutdownWriteCompletionSource.SetException( + ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode))); + } + return MsQuicStatusCodes.Success; } @@ -1021,6 +1076,38 @@ private static uint HandleEventStartComplete(State state, ref StreamEvent evt) return MsQuicStatusCodes.Success; } + private static uint HandleEventSendShutdownComplete(State state, ref StreamEvent evt) + { + // Graceful will be false in three situations: + // 1. The peer aborted reads and the PEER_RECEIVE_ABORTED event was raised. + // ShutdownWriteCompletionSource is already complete with an error. + // 2. We aborted writes. + // ShutdownWriteCompletionSource is already complete with an error. + // 3. The connection was closed. + // SHUTDOWN_COMPLETE event will be raised immediately after this event. It will handle completing with an error. + // + // Only use this event with sends gracefully completed. + if (evt.Data.SendShutdownComplete.Graceful != 0) + { + bool shouldComplete = false; + lock (state) + { + if (state.ShutdownWriteState == ShutdownWriteState.None) + { + state.ShutdownWriteState = ShutdownWriteState.Finished; + shouldComplete = true; + } + } + + if (shouldComplete) + { + state.ShutdownWriteCompletionSource.SetResult(); + } + } + + return MsQuicStatusCodes.Success; + } + private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt) { StreamEventDataShutdownComplete shutdownCompleteEvent = evt.Data.ShutdownComplete; @@ -1031,6 +1118,7 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt } bool shouldReadComplete = false; + bool shouldShutdownWriteComplete = false; bool shouldShutdownComplete = false; lock (state) @@ -1040,6 +1128,15 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt shouldReadComplete = CleanupReadStateAndCheckPending(state, ReadState.ReadsCompleted); + if (state.ShutdownWriteState == ShutdownWriteState.None) + { + // TODO: We can get to this point if the stream is unidirectional and there are no writes. + // Consider what is the best behavior here with write shutdown and the read side of + // unidirecitonal streams in the future. + state.ShutdownWriteState = ShutdownWriteState.Finished; + shouldShutdownWriteComplete = true; + } + if (state.ShutdownState == ShutdownState.None) { state.ShutdownState = ShutdownState.Finished; @@ -1052,6 +1149,11 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt state.ReceiveResettableCompletionSource.Complete(0); } + if (shouldShutdownWriteComplete) + { + state.ShutdownWriteCompletionSource.SetResult(); + } + if (shouldShutdownComplete) { state.ShutdownCompletionSource.SetResult(); @@ -1361,6 +1463,7 @@ private static uint HandleEventConnectionClose(State state) bool shouldCompleteRead = false; bool shouldCompleteSend = false; + bool shouldCompleteShutdownWrite = false; bool shouldCompleteShutdown = false; lock (state) @@ -1373,6 +1476,12 @@ private static uint HandleEventConnectionClose(State state) } state.SendState = SendState.ConnectionClosed; + if (state.ShutdownWriteState == ShutdownWriteState.None) + { + shouldCompleteShutdownWrite = true; + } + state.ShutdownWriteState = ShutdownWriteState.ConnectionClosed; + if (state.ShutdownState == ShutdownState.None) { shouldCompleteShutdown = true; @@ -1392,6 +1501,12 @@ private static uint HandleEventConnectionClose(State state) ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state))); } + if (shouldCompleteShutdownWrite) + { + state.ShutdownWriteCompletionSource.SetException( + ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state))); + } + if (shouldCompleteShutdown) { state.ShutdownCompletionSource.SetException( @@ -1493,6 +1608,14 @@ private enum ReadState Closed } + private enum ShutdownWriteState + { + None = 0, + Canceled, + Finished, + ConnectionClosed + } + private enum ShutdownState { None = 0, diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs index 66c9a8b6e51c2..215ce1304c1a4 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs @@ -47,6 +47,8 @@ internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable internal abstract ValueTask ShutdownCompleted(CancellationToken cancellationToken = default); + internal abstract ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default); + internal abstract void Shutdown(); internal abstract void Flush(); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs index 8a6dbe496ed4a..912d32c9ad889 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs @@ -117,6 +117,8 @@ public override int WriteTimeout public ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownCompleted(cancellationToken); + public ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default) => _provider.WaitForWriteCompletionAsync(cancellationToken); + public void Shutdown() => _provider.Shutdown(); protected override void Dispose(bool disposing) diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index 098bd4d2af1d2..ed371262c924c 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -769,6 +769,230 @@ await RunClientServer( } ); } + + [Fact] + public async Task WaitForWriteCompletionAsync_ClientReadAborted_Throws() + { + const int ExpectedErrorCode = 0xfffffff; + + TaskCompletionSource waitForAbortTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + SemaphoreSlim sem = new SemaphoreSlim(0); + + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1], endStream: true); + + // Wait for server to read data + await sem.WaitAsync(); + + clientStream.AbortRead(ExpectedErrorCode); + }, + async serverStream => + { + var writeCompletionTask = ReleaseOnWriteCompletionAsync(); + + int received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(1, received); + received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(0, received); + + Assert.False(writeCompletionTask.IsCompleted, "Server is still writing."); + + // Tell client that data has been read and it can abort its reads. + sem.Release(); + + long sendAbortErrorCode = await waitForAbortTcs.Task; + Assert.Equal(ExpectedErrorCode, sendAbortErrorCode); + + await writeCompletionTask; + + async ValueTask ReleaseOnWriteCompletionAsync() + { + try + { + await serverStream.WaitForWriteCompletionAsync(); + waitForAbortTcs.SetException(new Exception("WaitForWriteCompletionAsync didn't throw stream aborted.")); + } + catch (QuicStreamAbortedException ex) + { + waitForAbortTcs.SetResult(ex.ErrorCode); + } + catch (Exception ex) + { + waitForAbortTcs.SetException(ex); + } + }; + }); + } + + [Fact] + public async Task WaitForWriteCompletionAsync_ServerWriteAborted_Throws() + { + const int ExpectedErrorCode = 0xfffffff; + + TaskCompletionSource waitForAbortTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1], endStream: true); + }, + async serverStream => + { + var writeCompletionTask = ReleaseOnWriteCompletionAsync(); + + int received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(1, received); + received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(0, received); + + Assert.False(writeCompletionTask.IsCompleted, "Server is still writing."); + + serverStream.AbortWrite(ExpectedErrorCode); + + await waitForAbortTcs.Task; + await writeCompletionTask; + + async ValueTask ReleaseOnWriteCompletionAsync() + { + try + { + await serverStream.WaitForWriteCompletionAsync(); + waitForAbortTcs.SetException(new Exception("WaitForWriteCompletionAsync didn't throw stream aborted.")); + } + catch (QuicOperationAbortedException) + { + waitForAbortTcs.SetResult(); + } + catch (Exception ex) + { + waitForAbortTcs.SetException(ex); + } + }; + }); + } + + [Fact] + public async Task WaitForWriteCompletionAsync_ServerShutdown_Success() + { + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1], endStream: true); + + int readCount = await clientStream.ReadAsync(new byte[1]); + Assert.Equal(1, readCount); + + readCount = await clientStream.ReadAsync(new byte[1]); + Assert.Equal(0, readCount); + }, + async serverStream => + { + var writeCompletionTask = serverStream.WaitForWriteCompletionAsync(); + + int received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(1, received); + received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(0, received); + + await serverStream.WriteAsync(new byte[1]); + + Assert.False(writeCompletionTask.IsCompleted, "Server is still writing."); + + serverStream.Shutdown(); + + await writeCompletionTask; + }); + } + + [Fact] + public async Task WaitForWriteCompletionAsync_GracefulShutdown_Success() + { + await RunBidirectionalClientServer( + async clientStream => + { + await clientStream.WriteAsync(new byte[1], endStream: true); + + int readCount = await clientStream.ReadAsync(new byte[1]); + Assert.Equal(1, readCount); + + readCount = await clientStream.ReadAsync(new byte[1]); + Assert.Equal(0, readCount); + }, + async serverStream => + { + var writeCompletionTask = serverStream.WaitForWriteCompletionAsync(); + + int received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(1, received); + received = await serverStream.ReadAsync(new byte[1]); + Assert.Equal(0, received); + + Assert.False(writeCompletionTask.IsCompleted, "Server is still writing."); + + await serverStream.WriteAsync(new byte[1], endStream: true); + + await writeCompletionTask; + }); + } + + [Fact] + public async Task WaitForWriteCompletionAsync_ConnectionClosed_Throws() + { + const int ExpectedErrorCode = 0xfffffff; + + using SemaphoreSlim sem = new SemaphoreSlim(0); + TaskCompletionSource waitForAbortTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await RunClientServer( + serverFunction: async connection => + { + await using QuicStream stream = await connection.AcceptStreamAsync(); + + var writeCompletionTask = ReleaseOnWriteCompletionAsync(); + + int received = await stream.ReadAsync(new byte[1]); + Assert.Equal(1, received); + received = await stream.ReadAsync(new byte[1]); + Assert.Equal(0, received); + + // Signal that the server has read data + sem.Release(); + + long closeErrorCode = await waitForAbortTcs.Task; + Assert.Equal(ExpectedErrorCode, closeErrorCode); + + await writeCompletionTask; + + async ValueTask ReleaseOnWriteCompletionAsync() + { + try + { + await stream.WaitForWriteCompletionAsync(); + waitForAbortTcs.SetException(new Exception("WaitForWriteCompletionAsync didn't throw connection aborted.")); + } + catch (QuicConnectionAbortedException ex) + { + waitForAbortTcs.SetResult(ex.ErrorCode); + } + }; + }, + clientFunction: async connection => + { + await using QuicStream stream = connection.OpenBidirectionalStream(); + + await stream.WriteAsync(new byte[1], endStream: true); + + await stream.WaitForWriteCompletionAsync(); + + // Wait for the server to read data before closing the connection + await sem.WaitAsync(); + + await connection.CloseAsync(ExpectedErrorCode); + } + ); + } } public sealed class QuicStreamTests_MockProvider : QuicStreamTests