diff --git a/src/System.Private.ServiceModel/src/Internals/System/Runtime/TimeoutHelper.cs b/src/System.Private.ServiceModel/src/Internals/System/Runtime/TimeoutHelper.cs index cf8ea269312..7f6c1814fda 100644 --- a/src/System.Private.ServiceModel/src/Internals/System/Runtime/TimeoutHelper.cs +++ b/src/System.Private.ServiceModel/src/Internals/System/Runtime/TimeoutHelper.cs @@ -8,54 +8,69 @@ namespace System.Runtime { - public struct TimeoutHelper + public struct TimeoutHelper : IDisposable { - public static readonly TimeSpan MaxWait = TimeSpan.FromMilliseconds(Int32.MaxValue); - private static readonly CancellationToken s_precancelledToken = new CancellationToken(true); - private bool _cancellationTokenInitialized; private bool _deadlineSet; private CancellationToken _cancellationToken; + private CancellationTokenSource _cts; private DateTime _deadline; private TimeSpan _originalTimeout; + public static readonly TimeSpan MaxWait = TimeSpan.FromMilliseconds(Int32.MaxValue); + private static Action s_cancelOnTimeout = state => ((TimeoutHelper)state)._cts.Cancel(); public TimeoutHelper(TimeSpan timeout) { Contract.Assert(timeout >= TimeSpan.Zero, "timeout must be non-negative"); _cancellationTokenInitialized = false; + _cts = null; _originalTimeout = timeout; _deadline = DateTime.MaxValue; _deadlineSet = (timeout == TimeSpan.MaxValue); } - public CancellationToken GetCancellationToken() + // No locks as we expect this class to be used linearly. + // If another CancellationTokenSource is created, we might have a CancellationToken outstanding + // that isn't cancelled if _cts.Cancel() is called. This happens only on the Abort paths, so it's not an issue. + private void InitializeCancellationToken(TimeSpan timeout) { - return GetCancellationTokenAsync().Result; + if (timeout == TimeSpan.MaxValue || timeout == Timeout.InfiniteTimeSpan) + { + _cancellationToken = CancellationToken.None; + } + else if (timeout > TimeSpan.Zero) + { + _cts = new CancellationTokenSource(); + _cancellationToken = _cts.Token; + TimeoutTokenSource.FromTimeout((int)timeout.TotalMilliseconds).Register(s_cancelOnTimeout, this); + } + else + { + _cancellationToken = new CancellationToken(true); + } + _cancellationTokenInitialized = true; } - public async Task GetCancellationTokenAsync() + public CancellationToken CancellationToken { - if (!_cancellationTokenInitialized) + get { - var timeout = RemainingTime(); - if (timeout >= MaxWait || timeout == Timeout.InfiniteTimeSpan) - { - _cancellationToken = CancellationToken.None; - } - else if (timeout > TimeSpan.Zero) + if (!_cancellationTokenInitialized) { - _cancellationToken = await TimeoutTokenSource.FromTimeoutAsync((int)timeout.TotalMilliseconds); + InitializeCancellationToken(this.RemainingTime()); } - else - { - _cancellationToken = s_precancelledToken; - } - _cancellationTokenInitialized = true; + return _cancellationToken; } + } - return _cancellationToken; + public void CancelCancellationToken(bool throwOnFirstException = false) + { + if (_cts != null) + { + _cts.Cancel(throwOnFirstException); + } } public TimeSpan OriginalTimeout @@ -179,6 +194,16 @@ private void SetDeadline() _deadlineSet = true; } + public void Dispose() + { + if (_cancellationTokenInitialized && _cts !=null) + { + _cts.Dispose(); + _cancellationTokenInitialized = false; + _cancellationToken = default(CancellationToken); + } + } + public static void ThrowIfNegativeArgument(TimeSpan timeout) { ThrowIfNegativeArgument(timeout, "timeout"); @@ -235,29 +260,9 @@ internal static TimeoutException CreateEnterTimedOutException(TimeSpan timeout) /// internal static class TimeoutTokenSource { - /// - /// These are constants use to calculate timeout coalescing, for more description see method FromTimeoutAsync - /// - private const int CoalescingFactor = 15; - private const int GranularityFactor = 2000; - private const int SegmentationFactor = CoalescingFactor * GranularityFactor; - + private const int COALESCING_SPAN_MS = 15; private static readonly ConcurrentDictionary> s_tokenCache = new ConcurrentDictionary>(); - private static readonly Action s_deregisterToken = (object state) => - { - var args = (Tuple)state; - - Task ignored; - try - { - s_tokenCache.TryRemove(args.Item1, out ignored); - } - finally - { - args.Item2.Dispose(); - } - }; public static CancellationToken FromTimeout(int millisecondsTimeout) { @@ -273,25 +278,10 @@ public static Task FromTimeoutAsync(int millisecondsTimeout) throw new ArgumentOutOfRangeException("Invalid millisecondsTimeout value " + millisecondsTimeout); } - - // To prevent s_tokenCache growing too large, we have to adjust the granularity of the our coalesce depending - // on the value of millisecondsTimeout. The coalescing span scales proportionally with millisecondsTimeout which - // would garentee constant s_tokenCache size in the case where similar millisecondsTimeout values are accepted. - // If the method is given a wildly different millisecondsTimeout values all the time, the dictionary would still - // only grow logarithmically with respect to the range of the input values uint currentTime = (uint)Environment.TickCount; long targetTime = millisecondsTimeout + currentTime; - - // Formula for our coalescing span: - // Divide millisecondsTimeout by SegmentationFactor and take the highest bit and then multiply CoalescingFactor back - var segmentValue = millisecondsTimeout / SegmentationFactor; - var coalescingSpanMs = CoalescingFactor; - while (segmentValue > 0) - { - segmentValue >>= 1; - coalescingSpanMs <<= 1; - } - targetTime = ((targetTime + (coalescingSpanMs - 1)) / coalescingSpanMs) * coalescingSpanMs; + // round the targetTime up to the next closest 15ms + targetTime = ((targetTime + (COALESCING_SPAN_MS - 1)) / COALESCING_SPAN_MS) * COALESCING_SPAN_MS; Task tokenTask; @@ -304,11 +294,13 @@ public static Task FromTimeoutAsync(int millisecondsTimeout) { // Since this thread was successful reserving a spot in the cache, it would be the only thread // that construct the CancellationTokenSource - var tokenSource = new CancellationTokenSource((int)(targetTime - currentTime)); - var token = tokenSource.Token; + var token = new CancellationTokenSource((int)(targetTime - currentTime)).Token; // Clean up cache when Token is canceled - token.Register(s_deregisterToken, Tuple.Create(targetTime, tokenSource)); + token.Register(t => { + Task ignored; + s_tokenCache.TryRemove((long)t, out ignored); + }, targetTime); // set the result so other thread may observe the token, and return tcs.TrySetResult(token); diff --git a/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/CoreClrClientWebSocketFactory.cs b/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/CoreClrClientWebSocketFactory.cs index 01cc0b73a5b..1d9f90707b4 100644 --- a/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/CoreClrClientWebSocketFactory.cs +++ b/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/CoreClrClientWebSocketFactory.cs @@ -24,8 +24,7 @@ public override async Task CreateWebSocketAsync(Uri address, WebHeade webSocket.Options.SetRequestHeader(header, headers[header]); } - var cancelToken = await timeoutHelper.GetCancellationTokenAsync(); - await webSocket.ConnectAsync(address, cancelToken); + await webSocket.ConnectAsync(address, timeoutHelper.CancellationToken); return webSocket; } } diff --git a/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/HttpChannelFactory.cs b/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/HttpChannelFactory.cs index 0795aeeb8bb..83b43fe4ff7 100644 --- a/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/HttpChannelFactory.cs +++ b/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/HttpChannelFactory.cs @@ -617,18 +617,6 @@ internal virtual void OnHttpRequestCompleted(HttpRequestMessage request) internal class HttpClientChannelAsyncRequest : IAsyncRequest { - private static readonly Action s_cancelCts = state => - { - - try - { - ((CancellationTokenSource)state).Cancel(); - } - catch (ObjectDisposedException) - { - // ignore - } - }; private HttpClientRequestChannel _channel; private HttpChannelFactory _factory; private EndpointAddress _to; @@ -639,7 +627,6 @@ internal class HttpClientChannelAsyncRequest : IAsyncRequest private TimeoutHelper _timeoutHelper; private int _httpRequestCompleted; private HttpClient _httpClient; - private readonly CancellationTokenSource _httpSendCts; public HttpClientChannelAsyncRequest(HttpClientRequestChannel channel) { @@ -648,7 +635,6 @@ public HttpClientChannelAsyncRequest(HttpClientRequestChannel channel) _via = channel.Via; _factory = channel.Factory; _httpClient = _factory.GetHttpClient(); - _httpSendCts = new CancellationTokenSource(); } public async Task SendRequestAsync(Message message, TimeoutHelper timeoutHelper) @@ -687,13 +673,9 @@ public async Task SendRequestAsync(Message message, TimeoutHelper timeoutHelper) bool success = false; - var cancelTokenTask = _timeoutHelper.GetCancellationTokenAsync(); - try { - var timeoutToken = await cancelTokenTask; - timeoutToken.Register(s_cancelCts, _httpSendCts); - _httpResponseMessage = await _httpClient.SendAsync(_httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, _httpSendCts.Token); + _httpResponseMessage = await _httpClient.SendAsync(_httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, _timeoutHelper.CancellationToken); // As we have the response message and no exceptions have been thrown, the request message has completed it's job. // Calling Dispose() on the request message to free up resources in HttpContent, but keeping the object around // as we can still query properties once dispose'd. @@ -707,7 +689,7 @@ public async Task SendRequestAsync(Message message, TimeoutHelper timeoutHelper) } catch (OperationCanceledException) { - if (cancelTokenTask.Result.IsCancellationRequested) + if (_timeoutHelper.CancellationToken.IsCancellationRequested) { throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException(SR.Format( SR.HttpRequestTimedOut, _httpRequestMessage.RequestUri, _timeoutHelper.OriginalTimeout))); @@ -737,12 +719,11 @@ public async Task SendRequestAsync(Message message, TimeoutHelper timeoutHelper) private void Cleanup() { - s_cancelCts(_httpSendCts); - if (_httpRequestMessage != null) { var httpRequestMessageSnapshot = _httpRequestMessage; _httpRequestMessage = null; + _timeoutHelper.CancelCancellationToken(false); TryCompleteHttpRequest(httpRequestMessageSnapshot); httpRequestMessageSnapshot.Dispose(); } @@ -771,8 +752,7 @@ public async Task ReceiveReplyAsync(TimeoutHelper timeoutHelper) } catch (OperationCanceledException) { - var cancelToken = _timeoutHelper.GetCancellationToken(); - if (cancelToken.IsCancellationRequested) + if (_timeoutHelper.CancellationToken.IsCancellationRequested) { throw DiagnosticUtility.ExceptionUtility.ThrowHelperError(new TimeoutException(SR.Format( SR.HttpResponseTimedOut, _httpRequestMessage.RequestUri, timeoutHelper.OriginalTimeout))); @@ -996,8 +976,7 @@ private async Task SendPreauthenticationHeadRequestIfNeeded() RequestUri = requestUri }; - var cancelToken = await _timeoutHelper.GetCancellationTokenAsync(); - await _httpClient.SendAsync(headHttpRequestMessage, cancelToken); + await _httpClient.SendAsync(headHttpRequestMessage, _timeoutHelper.CancellationToken); } private bool AuthenticationSchemeMayRequireResend() diff --git a/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/TimeoutStream.cs b/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/TimeoutStream.cs index 3cfad90ecbe..e68f002c7ac 100644 --- a/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/TimeoutStream.cs +++ b/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/TimeoutStream.cs @@ -33,14 +33,13 @@ public override int Read(byte[] buffer, int offset, int count) return ReadAsyncInternal(buffer, offset, count, CancellationToken.None).WaitForCompletion(); } - public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { // Supporting a passed in cancellationToken as well as honoring the timeout token in this class would require // creating a linked token source on every call which is extra allocation and needs disposal. As this is an // internal classs, it's okay to add this extra constraint to usage of this method. Contract.Assert(!cancellationToken.CanBeCanceled, "cancellationToken shouldn't be cancellable"); - var cancelToken = await _timeoutHelper.GetCancellationTokenAsync(); - return await base.ReadAsync(buffer, offset, count, cancelToken); + return base.ReadAsync(buffer, offset, count, _timeoutHelper.CancellationToken); } private async Task ReadAsyncInternal(byte[] buffer, int offset, int count, CancellationToken cancellationToken) @@ -54,14 +53,13 @@ public override void Write(byte[] buffer, int offset, int count) WriteAsyncInternal(buffer, offset, count, CancellationToken.None).WaitForCompletion(); } - public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { // Supporting a passed in cancellationToken as well as honoring the timeout token in this class would require // creating a linked token source on every call which is extra allocation and needs disposal. As this is an // internal classs, it's okay to add this extra constraint to usage of this method. Contract.Assert(!cancellationToken.CanBeCanceled, "cancellationToken shouldn't be cancellable"); - var cancelToken = await _timeoutHelper.GetCancellationTokenAsync(); - await base.WriteAsync(buffer, offset, count, cancelToken); + return base.WriteAsync(buffer, offset, count, _timeoutHelper.CancellationToken); } private async Task WriteAsyncInternal(byte[] buffer, int offset, int count, CancellationToken cancellationToken) @@ -76,6 +74,7 @@ protected override void Dispose(bool disposing) { if (disposing) { + _timeoutHelper.Dispose(); _timeoutHelper = default(TimeoutHelper); } diff --git a/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/WebSocketTransportDuplexSessionChannel.cs b/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/WebSocketTransportDuplexSessionChannel.cs index f8bf09cf494..3580c5f78a6 100644 --- a/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/WebSocketTransportDuplexSessionChannel.cs +++ b/src/System.Private.ServiceModel/src/System/ServiceModel/Channels/WebSocketTransportDuplexSessionChannel.cs @@ -234,7 +234,7 @@ protected override AsyncCompletionResult StartWritingBufferedMessage(Message mes RemoteAddress != null ? RemoteAddress.ToString() : string.Empty); } - Task task = WebSocket.SendAsync(messageData, outgoingMessageType, true, helper.GetCancellationToken()); + Task task = WebSocket.SendAsync(messageData, outgoingMessageType, true, helper.CancellationToken); Contract.Assert(_pendingWritingMessageException == null, "'pendingWritingMessageException' MUST be NULL at this point."); if (task.IsCompleted) @@ -285,7 +285,7 @@ protected override AsyncCompletionResult BeginCloseOutput(TimeSpan timeout, Acti Fx.Assert(callback != null, "callback should not be null."); var helper = new TimeoutHelper(timeout); - Task task = CloseOutputAsync(helper.GetCancellationToken()); + Task task = CloseOutputAsync(helper.CancellationToken); Fx.Assert(_pendingWritingMessageException == null, "'pendingWritingMessageException' MUST be NULL at this point."); if (task.IsCompleted) @@ -326,7 +326,7 @@ protected override void OnSendCore(Message message, TimeSpan timeout) RemoteAddress != null ? RemoteAddress.ToString() : string.Empty); } - Task task = WebSocket.SendAsync(messageData, outgoingMessageType, true, helper.GetCancellationToken()); + Task task = WebSocket.SendAsync(messageData, outgoingMessageType, true, helper.CancellationToken); task.Wait(helper.RemainingTime(), WebSocketHelper.ThrowCorrectException, WebSocketHelper.SendOperation); if (TD.WebSocketAsyncWriteStopIsEnabled()) @@ -1067,8 +1067,7 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel { Contract.Assert(_messageSource != null, "messageSource should not be null in read case."); - var cancelToken = _readTimeoutHelper.GetCancellationToken(); - if (cancelToken.IsCancellationRequested) + if (_readTimeoutHelper.CancellationToken.IsCancellationRequested) { throw FxTrace.Exception.AsError(WebSocketHelper.GetTimeoutException(null, _readTimeoutHelper.OriginalTimeout, WebSocketHelper.ReceiveOperation)); @@ -1234,7 +1233,7 @@ public void WriteEndOfMessage() if (Interlocked.CompareExchange(ref _endOfMessageWritten, WebSocketHelper.OperationFinished, WebSocketHelper.OperationNotStarted) == WebSocketHelper.OperationNotStarted) { - Task task = _webSocket.SendAsync(new ArraySegment(Array.Empty(), 0, 0), _outgoingMessageType, true, timeoutHelper.GetCancellationToken()); + Task task = _webSocket.SendAsync(new ArraySegment(Array.Empty(), 0, 0), _outgoingMessageType, true, timeoutHelper.CancellationToken); task.Wait(timeoutHelper.RemainingTime(), WebSocketHelper.ThrowCorrectException, WebSocketHelper.SendOperation); } @@ -1255,38 +1254,40 @@ public async void WriteEndOfMessageAsync(Action callback, object state) string.Empty); } - var timeoutHelper = new TimeoutHelper(_closeTimeout); - var cancelTokenTask = timeoutHelper.GetCancellationTokenAsync(); - try + using (var timeoutHelper = new TimeoutHelper(_closeTimeout)) { - var cancelToken = await cancelTokenTask; - await _webSocket.SendAsync(new ArraySegment(Array.Empty(), 0, 0), _outgoingMessageType, true, cancelToken); - - if (TD.WebSocketAsyncWriteStopIsEnabled()) + try { - TD.WebSocketAsyncWriteStop(_webSocket.GetHashCode()); + await + _webSocket.SendAsync(new ArraySegment(Array.Empty(), 0, 0), _outgoingMessageType, + true, timeoutHelper.CancellationToken); + + if (TD.WebSocketAsyncWriteStopIsEnabled()) + { + TD.WebSocketAsyncWriteStop(_webSocket.GetHashCode()); + } } - } - catch (Exception ex) - { - if (Fx.IsFatal(ex)) + catch (Exception ex) { - throw; - } + if (Fx.IsFatal(ex)) + { + throw; + } - if (cancelTokenTask.Result.IsCancellationRequested) - { - throw Fx.Exception.AsError( - new TimeoutException(InternalSR.TaskTimedOutError(timeoutHelper.OriginalTimeout))); - } + if (timeoutHelper.CancellationToken.IsCancellationRequested) + { + throw Fx.Exception.AsError( + new TimeoutException(InternalSR.TaskTimedOutError(timeoutHelper.OriginalTimeout))); + } - throw WebSocketHelper.ConvertAndTraceException(ex, timeoutHelper.OriginalTimeout, - WebSocketHelper.SendOperation); + throw WebSocketHelper.ConvertAndTraceException(ex, timeoutHelper.OriginalTimeout, + WebSocketHelper.SendOperation); - } - finally - { - callback.Invoke(state); + } + finally + { + callback.Invoke(state); + } } } @@ -1319,18 +1320,20 @@ void Cleanup() if (!_endofMessageReceived && (_webSocket.State == WebSocketState.Open || _webSocket.State == WebSocketState.CloseSent)) { // Drain the reading stream - var closeTimeoutHelper = new TimeoutHelper(_closeTimeout); - do + using (var closeTimeoutHelper = new TimeoutHelper(_closeTimeout)) { - Task receiveTask = - _webSocket.ReceiveAsync(new ArraySegment(_initialReadBuffer.Array), - closeTimeoutHelper.GetCancellationToken()); - receiveTask.Wait(closeTimeoutHelper.RemainingTime(), - WebSocketHelper.ThrowCorrectException, WebSocketHelper.ReceiveOperation); - _endofMessageReceived = receiveTask.GetAwaiter().GetResult().EndOfMessage; - } while (!_endofMessageReceived && - (_webSocket.State == WebSocketState.Open || - _webSocket.State == WebSocketState.CloseSent)); + do + { + Task receiveTask = + _webSocket.ReceiveAsync(new ArraySegment(_initialReadBuffer.Array), + closeTimeoutHelper.CancellationToken); + receiveTask.Wait(closeTimeoutHelper.RemainingTime(), + WebSocketHelper.ThrowCorrectException, WebSocketHelper.ReceiveOperation); + _endofMessageReceived = receiveTask.GetAwaiter().GetResult().EndOfMessage; + } while (!_endofMessageReceived && + (_webSocket.State == WebSocketState.Open || + _webSocket.State == WebSocketState.CloseSent)); + } } } catch (Exception ex)