Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QUIC] Add QuicStream.WaitForWriteCompletionAsync #58236

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/libraries/System.Net.Quic/ref/System.Net.Quic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ public override void Flush() { }
public override void SetLength(long value) { }
public void Shutdown() { }
public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask WaitForWriteCompletionAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
JamesNK marked this conversation as resolved.
Show resolved Hide resolved
public override void Write(byte[] buffer, int offset, int count) { }
public override void Write(System.ReadOnlySpan<byte> buffer) { }
public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence<byte> buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using System.Collections.Concurrent;
using System.Collections.Generic;

namespace System.Net.Quic.Implementations.Mock
{
Expand Down Expand Up @@ -244,6 +246,9 @@ internal MockStream OpenStream(long streamId, bool bidirectional)
}

MockStream.StreamState streamState = new MockStream.StreamState(streamId, bidirectional);
// TODO Streams are never removed from a connection. Consider cleaning up in the future.
state._streams[streamState._streamId] = streamState;

Channel<MockStream.StreamState> streamChannel = _isClient ? state._clientInitiatedStreamChannel : state._serverInitiatedStreamChannel;
streamChannel.Writer.TryWrite(streamState);

Expand Down Expand Up @@ -320,6 +325,12 @@ internal override ValueTask CloseAsync(long errorCode, CancellationToken cancell
state._serverErrorCode = errorCode;
DrainAcceptQueue(errorCode, -1);
}

foreach (KeyValuePair<long, MockStream.StreamState> kvp in state._streams)
JamesNK marked this conversation as resolved.
Show resolved Hide resolved
{
kvp.Value._outboundWritesCompletedTcs.TrySetException(new QuicConnectionAbortedException(errorCode));
kvp.Value._inboundWritesCompletedTcs.TrySetException(new QuicConnectionAbortedException(errorCode));
}
}

Dispose();
Expand Down Expand Up @@ -474,8 +485,9 @@ public PeerStreamLimit(int maxUnidirectional, int maxBidirectional)
internal sealed class ConnectionState
{
public readonly SslApplicationProtocol _applicationProtocol;
public Channel<MockStream.StreamState> _clientInitiatedStreamChannel;
public Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
public readonly Channel<MockStream.StreamState> _clientInitiatedStreamChannel;
public readonly Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
public readonly ConcurrentDictionary<long, MockStream.StreamState> _streams;

public PeerStreamLimit? _clientStreamLimit;
public PeerStreamLimit? _serverStreamLimit;
Expand All @@ -490,6 +502,7 @@ public ConnectionState(SslApplicationProtocol applicationProtocol)
_clientInitiatedStreamChannel = Channel.CreateUnbounded<MockStream.StreamState>();
_serverInitiatedStreamChannel = Channel.CreateUnbounded<MockStream.StreamState>();
_clientErrorCode = _serverErrorCode = -1;
_streams = new ConcurrentDictionary<long, MockStream.StreamState>();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, bool e
if (endStream)
{
streamBuffer.EndWrite();
WritesCompletedTcs.TrySetResult();
JamesNK marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down Expand Up @@ -206,10 +207,12 @@ internal override void AbortRead(long errorCode)
if (_isInitiator)
{
_streamState._outboundWriteErrorCode = errorCode;
_streamState._inboundWritesCompletedTcs.TrySetException(new QuicStreamAbortedException(errorCode));
}
else
{
_streamState._inboundWriteErrorCode = errorCode;
_streamState._outboundWritesCompletedTcs.TrySetException(new QuicOperationAbortedException());
}

ReadStreamBuffer?.AbortRead();
Expand All @@ -220,10 +223,12 @@ internal override void AbortWrite(long errorCode)
if (_isInitiator)
{
_streamState._outboundReadErrorCode = errorCode;
_streamState._outboundWritesCompletedTcs.TrySetException(new QuicStreamAbortedException(errorCode));
}
else
{
_streamState._inboundReadErrorCode = errorCode;
_streamState._inboundWritesCompletedTcs.TrySetException(new QuicOperationAbortedException());
}

WriteStreamBuffer?.EndWrite();
Expand Down Expand Up @@ -251,6 +256,8 @@ internal override void Shutdown()
{
_connection.LocalStreamLimit!.Bidirectional.Decrement();
}

WritesCompletedTcs.TrySetResult();
}

private void CheckDisposed()
Expand Down Expand Up @@ -283,6 +290,17 @@ public override ValueTask DisposeAsync()
return default;
}

internal override ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default)
{
CheckDisposed();

return new ValueTask(WritesCompletedTcs.Task);
}

private TaskCompletionSource WritesCompletedTcs => _isInitiator
? _streamState._outboundWritesCompletedTcs
: _streamState._inboundWritesCompletedTcs;

internal sealed class StreamState
{
public readonly long _streamId;
Expand All @@ -292,6 +310,8 @@ internal sealed class StreamState
public long _inboundReadErrorCode;
public long _outboundWriteErrorCode;
public long _inboundWriteErrorCode;
public TaskCompletionSource _outboundWritesCompletedTcs;
public TaskCompletionSource _inboundWritesCompletedTcs;

private const int InitialBufferSize =
#if DEBUG
Expand All @@ -310,6 +330,8 @@ public StreamState(long streamId, bool bidirectional)
_streamId = streamId;
_outboundStreamBuffer = new StreamBuffer(initialBufferSize: InitialBufferSize, maxBufferSize: MaxBufferSize);
_inboundStreamBuffer = (bidirectional ? new StreamBuffer() : null);
_outboundWritesCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
_inboundWritesCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ private sealed class State
// Resettable completions to be used for multiple calls to send.
public readonly ResettableCompletionSource<uint> SendResettableCompletionSource = new ResettableCompletionSource<uint>();

public ShutdownWriteState ShutdownWriteState;

// Set once writes have been shutdown.
public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

public ShutdownState ShutdownState;
// The value makes sure that we release the handles only once.
public int ShutdownDone;
Expand Down Expand Up @@ -577,12 +582,26 @@ internal override void AbortWrite(long errorCode)
return;
}

bool shouldComplete = false;

lock (_state)
{
if (_state.SendState < SendState.Aborted)
{
_state.SendState = SendState.Aborted;
}

if (_state.ShutdownWriteState == ShutdownWriteState.None)
{
_state.ShutdownWriteState = ShutdownWriteState.Canceled;
shouldComplete = true;
}
}

if (shouldComplete)
{
_state.ShutdownWriteCompletionSource.SetException(
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException("Write was aborted.")));
}

StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND, errorCode);
Expand Down Expand Up @@ -629,6 +648,23 @@ internal override async ValueTask ShutdownCompleted(CancellationToken cancellati
await _state.ShutdownCompletionSource.Task.ConfigureAwait(false);
}

internal override ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default)
{
// TODO: What should happen if this is called for a unidirectional stream and there are no writes?

ThrowIfDisposed();

lock (_state)
{
if (_state.ShutdownWriteState == ShutdownWriteState.ConnectionClosed)
{
throw GetConnectionAbortedException(_state);
}
}

return new ValueTask(_state.ShutdownWriteCompletionSource.Task.WaitAsync(cancellationToken));
}

internal override void Shutdown()
{
ThrowIfDisposed();
Expand Down Expand Up @@ -861,6 +897,11 @@ private static uint HandleEvent(State state, ref StreamEvent evt)
// Peer has stopped receiving data, don't send anymore.
case QUIC_STREAM_EVENT_TYPE.PEER_RECEIVE_ABORTED:
return HandleEventPeerRecvAborted(state, ref evt);
// Occurs when shutdown is completed for the send side.
// This only happens for shutdown on sending, not receiving
// Receive shutdown can only be abortive.
case QUIC_STREAM_EVENT_TYPE.SEND_SHUTDOWN_COMPLETE:
return HandleEventSendShutdownComplete(state, ref evt);
// Shutdown for both sending and receiving is completed.
case QUIC_STREAM_EVENT_TYPE.SHUTDOWN_COMPLETE:
return HandleEventShutdownComplete(state, ref evt);
Expand Down Expand Up @@ -993,23 +1034,37 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt)

private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt)
{
bool shouldComplete = false;
bool shouldSendComplete = false;
bool shouldShutdownWriteComplete = false;
lock (state)
{
if (state.SendState == SendState.None || state.SendState == SendState.Pending)
{
shouldComplete = true;
shouldSendComplete = true;
}

if (state.ShutdownWriteState == ShutdownWriteState.None)
{
state.ShutdownWriteState = ShutdownWriteState.Canceled;
shouldShutdownWriteComplete = true;
}

state.SendState = SendState.Aborted;
state.SendErrorCode = (long)evt.Data.PeerReceiveAborted.ErrorCode;
}

if (shouldComplete)
if (shouldSendComplete)
{
state.SendResettableCompletionSource.CompleteException(
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode)));
}

if (shouldShutdownWriteComplete)
{
state.ShutdownWriteCompletionSource.SetException(
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode)));
}

return MsQuicStatusCodes.Success;
}

Expand All @@ -1021,6 +1076,38 @@ private static uint HandleEventStartComplete(State state, ref StreamEvent evt)
return MsQuicStatusCodes.Success;
}

private static uint HandleEventSendShutdownComplete(State state, ref StreamEvent evt)
{
// Graceful will be false in three situations:
JamesNK marked this conversation as resolved.
Show resolved Hide resolved
// 1. The peer aborted reads and the PEER_RECEIVE_ABORTED event was raised.
// ShutdownWriteCompletionSource is already complete with an error.
// 2. We aborted writes.
// ShutdownWriteCompletionSource is already complete with an error.
// 3. The connection was closed.
// SHUTDOWN_COMPLETE event will be raised immediately after this event. It will handle completing with an error.
//
// Only use this event with sends gracefully completed.
if (evt.Data.SendShutdownComplete.Graceful != 0)
{
bool shouldComplete = false;
lock (state)
{
if (state.ShutdownWriteState == ShutdownWriteState.None)
{
state.ShutdownWriteState = ShutdownWriteState.Finished;
shouldComplete = true;
}
}

if (shouldComplete)
{
state.ShutdownWriteCompletionSource.SetResult();
}
}

return MsQuicStatusCodes.Success;
}

private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt)
{
StreamEventDataShutdownComplete shutdownCompleteEvent = evt.Data.ShutdownComplete;
Expand All @@ -1031,6 +1118,7 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt
}

bool shouldReadComplete = false;
bool shouldShutdownWriteComplete = false;
bool shouldShutdownComplete = false;

lock (state)
Expand All @@ -1040,6 +1128,15 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt

shouldReadComplete = CleanupReadStateAndCheckPending(state, ReadState.ReadsCompleted);

if (state.ShutdownWriteState == ShutdownWriteState.None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure this state is expected here? As you've mentioned in comment for HandleEventSendShutdownComplete, by the time we receive SHUTDOWN_COMPLETE event, we should either completed ShutdownWriteState already, or it is a connection close, which is handled separately in HandleEventConnectionClose. I think we may leave the logic here, but guard it with Debug.Assert... what do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even as a fallback, I think it is especially strange to complete it successfully here... IMO successful completion should only happen in HandleEventSendShutdownComplete.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is one situation where you can get to this point without HandleEventSendShutdownComplete: if the stream is unidirectional and there are no writes.

Perhaps WaitForWriteCompleteAsync should error if it is called for this type of stream. I'm going to add that as a TODO comment. I'll also explain it in the logic inside HandleEventShutdownComplete. The exact behavior can be figured out in .NET 7.

{
// TODO: We can get to this point if the stream is unidirectional and there are no writes.
// Consider what is the best behavior here with write shutdown and the read side of
// unidirecitonal streams in the future.
state.ShutdownWriteState = ShutdownWriteState.Finished;
shouldShutdownWriteComplete = true;
}

if (state.ShutdownState == ShutdownState.None)
{
state.ShutdownState = ShutdownState.Finished;
Expand All @@ -1052,6 +1149,11 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt
state.ReceiveResettableCompletionSource.Complete(0);
}

if (shouldShutdownWriteComplete)
{
state.ShutdownWriteCompletionSource.SetResult();
}

if (shouldShutdownComplete)
{
state.ShutdownCompletionSource.SetResult();
Expand Down Expand Up @@ -1361,6 +1463,7 @@ private static uint HandleEventConnectionClose(State state)

bool shouldCompleteRead = false;
bool shouldCompleteSend = false;
bool shouldCompleteShutdownWrite = false;
bool shouldCompleteShutdown = false;

lock (state)
Expand All @@ -1373,6 +1476,12 @@ private static uint HandleEventConnectionClose(State state)
}
state.SendState = SendState.ConnectionClosed;

if (state.ShutdownWriteState == ShutdownWriteState.None)
{
shouldCompleteShutdownWrite = true;
}
state.ShutdownWriteState = ShutdownWriteState.ConnectionClosed;

if (state.ShutdownState == ShutdownState.None)
{
shouldCompleteShutdown = true;
Expand All @@ -1392,6 +1501,12 @@ private static uint HandleEventConnectionClose(State state)
ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state)));
}

if (shouldCompleteShutdownWrite)
{
state.ShutdownWriteCompletionSource.SetException(
ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state)));
}

if (shouldCompleteShutdown)
{
state.ShutdownCompletionSource.SetException(
Expand Down Expand Up @@ -1493,6 +1608,14 @@ private enum ReadState
Closed
}

private enum ShutdownWriteState
{
None = 0,
Canceled,
Finished,
ConnectionClosed
}

private enum ShutdownState
{
None = 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable

internal abstract ValueTask ShutdownCompleted(CancellationToken cancellationToken = default);

internal abstract ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default);

internal abstract void Shutdown();

internal abstract void Flush();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ public override int WriteTimeout

public ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownCompleted(cancellationToken);

public ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default) => _provider.WaitForWriteCompletionAsync(cancellationToken);

public void Shutdown() => _provider.Shutdown();

protected override void Dispose(bool disposing)
Expand Down
Loading