diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 344d517618..028d5321b9 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -2783,7 +2783,7 @@ private void CleanUpStateObject(bool isCancelRequested = true) { _stateObj.CancelRequest(); } - _stateObj._internalTimeout = false; + _stateObj.SetTimeoutStateStopped(); _stateObj.CloseSession(); _stateObj._bulkCopyOpperationInProgress = false; _stateObj._bulkCopyWriteTimeout = false; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 83fa69a74f..674f7b1a2a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -956,7 +956,7 @@ private bool TryCloseInternal(bool closeReader) { _sharedState._dataReady = true; // set _sharedState._dataReady to not confuse CleanPartialRead } - _stateObj._internalTimeout = false; + _stateObj.SetTimeoutStateStopped(); if (_sharedState._dataReady) { cleanDataFailed = true; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 9ce5c7a8d2..511ceb57c2 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -1896,7 +1896,7 @@ internal bool TryRun(RunBehavior runBehavior, SqlCommand cmdHandler, SqlDataRead // If there is data ready, but we didn't exit the loop, then something is wrong Debug.Assert(!dataReady, "dataReady not expected - did we forget to skip the row?"); - if (stateObj._internalTimeout) + if (stateObj.IsTimeoutStateExpired) { runBehavior = RunBehavior.Attention; } @@ -2520,7 +2520,7 @@ internal bool TryRun(RunBehavior runBehavior, SqlCommand cmdHandler, SqlDataRead stateObj._attentionSent = false; stateObj.HasReceivedAttention = false; - if (RunBehavior.Clean != (RunBehavior.Clean & runBehavior) && !stateObj._internalTimeout) + if (RunBehavior.Clean != (RunBehavior.Clean & runBehavior) && !stateObj.IsTimeoutStateExpired) { // Add attention error to collection - if not RunBehavior.Clean! stateObj.AddError(new SqlError(0, 0, TdsEnums.MIN_ERROR_CLASS, _server, SQLMessage.OperationCancelled(), "", 0)); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 5d36213755..55c73b99b9 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -36,6 +36,23 @@ internal enum SnapshottedStateFlags : byte AttentionReceived = 1 << 5 // NOTE: Received is not volatile as it is only ever accessed\modified by TryRun its callees (i.e. single threaded access) } + private sealed class TimeoutState + { + public const int Stopped = 0; + public const int Running = 1; + public const int ExpiredAsync = 2; + public const int ExpiredSync = 3; + + private readonly int _value; + + public TimeoutState(int value) + { + _value = value; + } + + public int IdentityValue => _value; + } + private const int AttentionTimeoutSeconds = 5; private static readonly ContextCallback s_readAdyncCallbackComplete = ReadAsyncCallbackComplete; @@ -113,9 +130,17 @@ internal enum SnapshottedStateFlags : byte // Timeout variables private long _timeoutMilliseconds; private long _timeoutTime; // variable used for timeout computations, holds the value of the hi-res performance counter at which this request should expire + private int _timeoutState; // expected to be one of the constant values TimeoutStopped, TimeoutRunning, TimeoutExpiredAsync, TimeoutExpiredSync + private int _timeoutIdentitySource; + private volatile int _timeoutIdentityValue; internal volatile bool _attentionSent; // true if we sent an Attention to the server internal volatile bool _attentionSending; - internal bool _internalTimeout; // an internal timeout occurred + + // Below 2 properties are used to enforce timeout delays in code to + // reproduce issues related to theadpool starvation and timeout delay. + // It should always be set to false by default, and only be enabled during testing. + internal bool _enforceTimeoutDelay = false; + internal int _enforcedTimeoutDelayInMilliSeconds = 5000; private readonly LastIOTimer _lastSuccessfulIOTimer; @@ -760,7 +785,7 @@ private void ResetCancelAndProcessAttention() // operations. Parser.ProcessPendingAck(this); } - _internalTimeout = false; + SetTimeoutStateStopped(); } } @@ -1042,7 +1067,7 @@ internal bool TryProcessHeader() return false; } - if (_internalTimeout) + if (IsTimeoutStateExpired) { ThrowExceptionAndWarning(); return true; @@ -1447,7 +1472,7 @@ internal bool TryReadInt16(out short value) { // The entire int16 is in the packet and in the buffer, so just return it // and take care of the counters. - buffer = _inBuff.AsSpan(_inBytesUsed,2); + buffer = _inBuff.AsSpan(_inBytesUsed, 2); _inBytesUsed += 2; _inBytesPacket -= 2; } @@ -1481,7 +1506,7 @@ internal bool TryReadInt32(out int value) } AssertValidState(); - value = (buffer[3] << 24) + (buffer[2] <<16) + (buffer[1] << 8) + buffer[0]; + value = (buffer[3] << 24) + (buffer[2] << 16) + (buffer[1] << 8) + buffer[0]; return true; } @@ -2247,11 +2272,62 @@ internal void OnConnectionClosed() } } - private void OnTimeout(object state) + public void SetTimeoutStateStopped() + { + Interlocked.Exchange(ref _timeoutState, TimeoutState.Stopped); + _timeoutIdentityValue = 0; + } + + public bool IsTimeoutStateExpired + { + get + { + int state = _timeoutState; + return state == TimeoutState.ExpiredAsync || state == TimeoutState.ExpiredSync; + } + } + + private void OnTimeoutAsync(object state) { - if (!_internalTimeout) + if (_enforceTimeoutDelay) { - _internalTimeout = true; + Thread.Sleep(_enforcedTimeoutDelayInMilliSeconds); + } + + int currentIdentityValue = _timeoutIdentityValue; + TimeoutState timeoutState = (TimeoutState)state; + if (timeoutState.IdentityValue == _timeoutIdentityValue) + { + // the return value is not useful here because no choice is going to be made using it + // we only want to make this call to set the state knowing that it will be seen later + OnTimeoutCore(TimeoutState.Running, TimeoutState.ExpiredAsync); + } + else + { + Debug.WriteLine($"OnTimeoutAsync called with identity state={timeoutState.IdentityValue} but current identity is {currentIdentityValue} so it is being ignored"); + } + } + + private bool OnTimeoutSync() + { + return OnTimeoutCore(TimeoutState.Running, TimeoutState.ExpiredSync); + } + + /// + /// attempts to change the timout state from the expected state to the target state and if it succeeds + /// will setup the the stateobject into the timeout expired state + /// + /// the state that is the expected current state, state will change only if this is correct + /// the state that will be changed to if the expected state is correct + /// boolean value indicating whether the call changed the timeout state + private bool OnTimeoutCore(int expectedState, int targetState) + { + Debug.Assert(targetState == TimeoutState.ExpiredAsync || targetState == TimeoutState.ExpiredSync, "OnTimeoutCore must have an expiry state as the targetState"); + + bool retval = false; + if (Interlocked.CompareExchange(ref _timeoutState, targetState, expectedState) == expectedState) + { + retval = true; // lock protects against Close and Cancel lock (this) { @@ -2349,6 +2425,7 @@ private void OnTimeout(object state) } } } + return retval; } internal void ReadSni(TaskCompletionSource completion) @@ -2383,19 +2460,32 @@ internal void ReadSni(TaskCompletionSource completion) { Debug.Assert(completion != null, "Async on but null asyncResult passed"); - if (_networkPacketTimeout == null) + // if the state is currently stopped then change it to running and allocate a new identity value from + // the identity source. The identity value is used to correlate timer callback events to the currently + // running timeout and prevents a late timer callback affecting a result it does not relate to + int previousTimeoutState = Interlocked.CompareExchange(ref _timeoutState, TimeoutState.Running, TimeoutState.Stopped); + if (previousTimeoutState == TimeoutState.Stopped) { - _networkPacketTimeout = ADP.UnsafeCreateTimer( - new TimerCallback(OnTimeout), - null, - Timeout.Infinite, - Timeout.Infinite); + Debug.Assert(_timeoutIdentityValue == 0, "timer was previously stopped without resetting the _identityValue"); + _timeoutIdentityValue = Interlocked.Increment(ref _timeoutIdentitySource); } + _networkPacketTimeout?.Dispose(); + + _networkPacketTimeout = ADP.UnsafeCreateTimer( + new TimerCallback(OnTimeoutAsync), + new TimeoutState(_timeoutIdentityValue), + Timeout.Infinite, + Timeout.Infinite + ); + + // -1 == Infinite // 0 == Already timed out (NOTE: To simulate the same behavior as sync we will only timeout on 0 if we receive an IO Pending from SNI) // >0 == Actual timeout remaining int msecsRemaining = GetTimeoutRemaining(); + + Debug.Assert(previousTimeoutState == TimeoutState.Stopped, "previous timeout state was not Stopped"); if (msecsRemaining > 0) { ChangeNetworkPacketTimeout(msecsRemaining, Timeout.Infinite); @@ -2445,12 +2535,15 @@ internal void ReadSni(TaskCompletionSource completion) _networkPacketTaskSource.TrySetResult(null); } // Disable timeout timer on error + SetTimeoutStateStopped(); ChangeNetworkPacketTimeout(Timeout.Infinite, Timeout.Infinite); } else if (msecsRemaining == 0) - { // Got IO Pending, but we have no time left to wait - // Immediately schedule the timeout timer to fire - ChangeNetworkPacketTimeout(0, Timeout.Infinite); + { + // Got IO Pending, but we have no time left to wait + // disable the timer and set the error state by calling OnTimeoutSync + ChangeNetworkPacketTimeout(Timeout.Infinite, Timeout.Infinite); + OnTimeoutSync(); } // DO NOT HANDLE PENDING READ HERE - which is TdsEnums.SNI_SUCCESS_IO_PENDING state. // That is handled by user who initiated async read, or by ReadNetworkPacket which is sync over async. @@ -2565,13 +2658,13 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) Debug.Assert(_syncOverAsync, "Should never reach here with async on!"); bool fail = false; - if (_internalTimeout) + if (IsTimeoutStateExpired) { // This is now our second timeout - time to give up. fail = true; } else { - stateObj._internalTimeout = true; + stateObj.SetTimeoutStateStopped(); Debug.Assert(_parser.Connection != null, "SqlConnectionInternalTds handler can not be null at this point."); AddError(new SqlError(TdsEnums.TIMEOUT_EXPIRED, (byte)0x00, TdsEnums.MIN_ERROR_CLASS, _parser.Server, _parser.Connection.TimeoutErrorInternal.GetErrorMessage(), "", 0, TdsEnums.SNI_WAIT_TIMEOUT)); @@ -2794,6 +2887,25 @@ public void ReadAsyncCallback(IntPtr key, PacketHandle packet, uint error) ChangeNetworkPacketTimeout(Timeout.Infinite, Timeout.Infinite); + // The timer thread may be unreliable under high contention scenarios. It cannot be + // assumed that the timeout has happened on the timer thread callback. Check the timeout + // synchrnously and then call OnTimeoutSync to force an atomic change of state. + if (TimeoutHasExpired) + { + OnTimeoutSync(); + } + + // try to change to the stopped state but only do so if currently in the running state + // and use cmpexch so that all changes out of the running state are atomic + int previousState = Interlocked.CompareExchange(ref _timeoutState, TimeoutState.Running, TimeoutState.Stopped); + + // if the state is anything other than running then this query has reached an end so + // set the correlation _timeoutIdentityValue to 0 to prevent late callbacks executing + if (_timeoutState != TimeoutState.Running) + { + _timeoutIdentityValue = 0; + } + ProcessSniPacket(packet, error); } catch (Exception e) @@ -3454,7 +3566,6 @@ internal void SendAttention(bool mustTakeWriteLock = false) // Set _attentionSending to true before sending attention and reset after setting _attentionSent // This prevents a race condition between receiving the attention ACK and setting _attentionSent _attentionSending = true; - #if DEBUG if (!_skipSendAttention) { @@ -3489,7 +3600,7 @@ internal void SendAttention(bool mustTakeWriteLock = false) } } #if DEBUG - } + } #endif SetTimeoutSeconds(AttentionTimeoutSeconds); // Initialize new attention timeout of 5 seconds. @@ -3862,7 +3973,7 @@ internal void AssertStateIsClean() // Attention\Cancellation\Timeouts Debug.Assert(!HasReceivedAttention && !_attentionSent && !_attentionSending, $"StateObj is still dealing with attention: Sent: {_attentionSent}, Received: {HasReceivedAttention}, Sending: {_attentionSending}"); Debug.Assert(!_cancelled, "StateObj still has cancellation set"); - Debug.Assert(!_internalTimeout, "StateObj still has internal timeout set"); + Debug.Assert(_timeoutState == TimeoutState.Stopped, "StateObj still has internal timeout set"); // Errors and Warnings Debug.Assert(!_hasErrorOrWarning, "StateObj still has stored errors or warnings"); } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index f95a76348d..1a463f1f5d 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -3039,7 +3039,7 @@ private void CleanUpStateObject(bool isCancelRequested = true) { _stateObj.CancelRequest(); } - _stateObj._internalTimeout = false; + _stateObj.SetTimeoutStateStopped(); _stateObj.CloseSession(); _stateObj._bulkCopyOpperationInProgress = false; _stateObj._bulkCopyWriteTimeout = false; diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 176b3ddf3c..b453bbc52a 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -1067,7 +1067,7 @@ private bool TryCloseInternal(bool closeReader) { _sharedState._dataReady = true; // set _sharedState._dataReady to not confuse CleanPartialRead } - _stateObj._internalTimeout = false; + _stateObj.SetTimeoutStateStopped(); if (_sharedState._dataReady) { cleanDataFailed = true; diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index bb4ab023ef..d38633b3c9 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -2262,7 +2262,7 @@ internal bool TryRun(RunBehavior runBehavior, SqlCommand cmdHandler, SqlDataRead // If there is data ready, but we didn't exit the loop, then something is wrong Debug.Assert(!dataReady, "dataReady not expected - did we forget to skip the row?"); - if (stateObj._internalTimeout) + if (stateObj.IsTimeoutStateExpired) { runBehavior = RunBehavior.Attention; } @@ -2891,7 +2891,7 @@ internal bool TryRun(RunBehavior runBehavior, SqlCommand cmdHandler, SqlDataRead stateObj._attentionSent = false; stateObj._attentionReceived = false; - if (RunBehavior.Clean != (RunBehavior.Clean & runBehavior) && !stateObj._internalTimeout) + if (RunBehavior.Clean != (RunBehavior.Clean & runBehavior) && !stateObj.IsTimeoutStateExpired) { // Add attention error to collection - if not RunBehavior.Clean! stateObj.AddError(new SqlError(0, 0, TdsEnums.MIN_ERROR_CLASS, _server, SQLMessage.OperationCancelled(), "", 0)); diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 7e526ebd9c..0b2baeb0be 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -24,7 +24,7 @@ sealed internal class LastIOTimer sealed internal class TdsParserStateObject { - const int AttentionTimeoutSeconds = 5; + private const int AttentionTimeoutSeconds = 5; // Ticks to consider a connection "good" after a successful I/O (10,000 ticks = 1 ms) // The resolution of the timer is typically in the range 10 to 16 milliseconds according to msdn. @@ -33,6 +33,23 @@ sealed internal class TdsParserStateObject // of very small open, query, close loops. private const long CheckConnectionWindow = 50000; + private sealed class TimeoutState + { + public const int Stopped = 0; + public const int Running = 1; + public const int ExpiredAsync = 2; + public const int ExpiredSync = 3; + + private readonly int _value; + + public TimeoutState(int value) + { + _value = value; + } + + public int IdentityValue => _value; + } + private static int _objectTypeCount; // EventSource Counter internal readonly int _objectID = System.Threading.Interlocked.Increment(ref _objectTypeCount); @@ -103,10 +120,19 @@ internal int ObjectID // Timeout variables private long _timeoutMilliseconds; private long _timeoutTime; // variable used for timeout computations, holds the value of the hi-res performance counter at which this request should expire + private int _timeoutState; // expected to be one of the constant values TimeoutStopped, TimeoutRunning, TimeoutExpiredAsync, TimeoutExpiredSync + private int _timeoutIdentitySource; + private volatile int _timeoutIdentityValue; internal volatile bool _attentionSent = false; // true if we sent an Attention to the server internal bool _attentionReceived = false; // NOTE: Received is not volatile as it is only ever accessed\modified by TryRun its callees (i.e. single threaded access) internal volatile bool _attentionSending = false; - internal bool _internalTimeout = false; // an internal timeout occurred + + // Below 2 properties are used to enforce timeout delays in code to + // reproduce issues related to theadpool starvation and timeout delay. + // It should always be set to false by default, and only be enabled during testing. + internal bool _enforceTimeoutDelay = false; + internal int _enforcedTimeoutDelayInMilliSeconds = 5000; + private readonly LastIOTimer _lastSuccessfulIOTimer; // secure password information to be stored @@ -804,7 +830,7 @@ private void ResetCancelAndProcessAttention() } #endif //DEBUG } - _internalTimeout = false; + SetTimeoutStateStopped(); } } @@ -1155,7 +1181,7 @@ internal bool TryProcessHeader() return false; } - if (_internalTimeout) + if (IsTimeoutStateExpired) { ThrowExceptionAndWarning(); // TODO: see the comment above @@ -2328,11 +2354,62 @@ internal void OnConnectionClosed() } - private void OnTimeout(object state) + public void SetTimeoutStateStopped() + { + Interlocked.Exchange(ref _timeoutState, TimeoutState.Stopped); + _timeoutIdentityValue = 0; + } + + public bool IsTimeoutStateExpired + { + get + { + int state = _timeoutState; + return state == TimeoutState.ExpiredAsync || state == TimeoutState.ExpiredSync; + } + } + + private void OnTimeoutAsync(object state) { - if (!_internalTimeout) + if (_enforceTimeoutDelay) + { + Thread.Sleep(_enforcedTimeoutDelayInMilliSeconds); + } + + int currentIdentityValue = _timeoutIdentityValue; + TimeoutState timeoutState = (TimeoutState)state; + if (timeoutState.IdentityValue == _timeoutIdentityValue) { - _internalTimeout = true; + // the return value is not useful here because no choice is going to be made using it + // we only want to make this call to set the state knowing that it will be seen later + OnTimeoutCore(TimeoutState.Running, TimeoutState.ExpiredAsync); + } + else + { + Debug.WriteLine($"OnTimeoutAsync called with identity state={timeoutState.IdentityValue} but current identity is {currentIdentityValue} so it is being ignored"); + } + } + + private bool OnTimeoutSync() + { + return OnTimeoutCore(TimeoutState.Running, TimeoutState.ExpiredSync); + } + + /// + /// attempts to change the timout state from the expected state to the target state and if it succeeds + /// will setup the the stateobject into the timeout expired state + /// + /// the state that is the expected current state, state will change only if this is correct + /// the state that will be changed to if the expected state is correct + /// boolean value indicating whether the call changed the timeout state + private bool OnTimeoutCore(int expectedState, int targetState) + { + Debug.Assert(targetState == TimeoutState.ExpiredAsync || targetState == TimeoutState.ExpiredSync, "OnTimeoutCore must have an expiry state as the targetState"); + + bool retval = false; + if (Interlocked.CompareExchange(ref _timeoutState, targetState, expectedState) == expectedState) + { + retval = true; // lock protects against Close and Cancel lock (this) { @@ -2432,6 +2509,7 @@ private void OnTimeout(object state) } } } + return retval; } internal void ReadSni(TaskCompletionSource completion) @@ -2464,15 +2542,31 @@ internal void ReadSni(TaskCompletionSource completion) { Debug.Assert(completion != null, "Async on but null asyncResult passed"); - if (_networkPacketTimeout == null) + // if the state is currently stopped then change it to running and allocate a new identity value from + // the identity source. The identity value is used to correlate timer callback events to the currently + // running timeout and prevents a late timer callback affecting a result it does not relate to + int previousTimeoutState = Interlocked.CompareExchange(ref _timeoutState, TimeoutState.Running, TimeoutState.Stopped); + if (previousTimeoutState == TimeoutState.Stopped) { - _networkPacketTimeout = new Timer(OnTimeout, null, Timeout.Infinite, Timeout.Infinite); + Debug.Assert(_timeoutIdentityValue == 0, "timer was previously stopped without resetting the _identityValue"); + _timeoutIdentityValue = Interlocked.Increment(ref _timeoutIdentitySource); } + _networkPacketTimeout?.Dispose(); + + _networkPacketTimeout = new Timer( + new TimerCallback(OnTimeoutAsync), + new TimeoutState(_timeoutIdentityValue), + Timeout.Infinite, + Timeout.Infinite + ); + // -1 == Infinite // 0 == Already timed out (NOTE: To simulate the same behavior as sync we will only timeout on 0 if we receive an IO Pending from SNI) // >0 == Actual timeout remaining int msecsRemaining = GetTimeoutRemaining(); + + Debug.Assert(previousTimeoutState == TimeoutState.Stopped, "previous timeout state was not Stopped"); if (msecsRemaining > 0) { ChangeNetworkPacketTimeout(msecsRemaining, Timeout.Infinite); @@ -2529,12 +2623,15 @@ internal void ReadSni(TaskCompletionSource completion) _networkPacketTaskSource.TrySetResult(null); } // Disable timeout timer on error + SetTimeoutStateStopped(); ChangeNetworkPacketTimeout(Timeout.Infinite, Timeout.Infinite); } else if (msecsRemaining == 0) - { // Got IO Pending, but we have no time left to wait - // Immediately schedule the timeout timer to fire - ChangeNetworkPacketTimeout(0, Timeout.Infinite); + { + // Got IO Pending, but we have no time left to wait + // disable the timer and set the error state by calling OnTimeoutSync + ChangeNetworkPacketTimeout(Timeout.Infinite, Timeout.Infinite); + OnTimeoutSync(); } // DO NOT HANDLE PENDING READ HERE - which is TdsEnums.SNI_SUCCESS_IO_PENDING state. // That is handled by user who initiated async read, or by ReadNetworkPacket which is sync over async. @@ -2672,13 +2769,13 @@ private void ReadSniError(TdsParserStateObject stateObj, UInt32 error) Debug.Assert(_syncOverAsync, "Should never reach here with async on!"); bool fail = false; - if (_internalTimeout) + if (IsTimeoutStateExpired) { // This is now our second timeout - time to give up. fail = true; } else { - stateObj._internalTimeout = true; + stateObj.SetTimeoutStateStopped(); Debug.Assert(_parser.Connection != null, "SqlConnectionInternalTds handler can not be null at this point."); AddError(new SqlError(TdsEnums.TIMEOUT_EXPIRED, (byte)0x00, TdsEnums.MIN_ERROR_CLASS, _parser.Server, _parser.Connection.TimeoutErrorInternal.GetErrorMessage(), "", 0, TdsEnums.SNI_WAIT_TIMEOUT)); @@ -2876,6 +2973,25 @@ public void ReadAsyncCallback(IntPtr key, IntPtr packet, UInt32 error) ChangeNetworkPacketTimeout(Timeout.Infinite, Timeout.Infinite); + // The timer thread may be unreliable under high contention scenarios. It cannot be + // assumed that the timeout has happened on the timer thread callback. Check the timeout + // synchrnously and then call OnTimeoutSync to force an atomic change of state. + if (TimeoutHasExpired) + { + OnTimeoutSync(); + } + + // try to change to the stopped state but only do so if currently in the running state + // and use cmpexch so that all changes out of the running state are atomic + int previousState = Interlocked.CompareExchange(ref _timeoutState, TimeoutState.Running, TimeoutState.Stopped); + + // if the state is anything other than running then this query has reached an end so + // set the correlation _timeoutIdentityValue to 0 to prevent late callbacks executing + if (_timeoutState != TimeoutState.Running) + { + _timeoutIdentityValue = 0; + } + ProcessSniPacket(packet, error); } catch (Exception e) @@ -4011,7 +4127,7 @@ internal void AssertStateIsClean() // Attention\Cancellation\Timeouts Debug.Assert(!_attentionReceived && !_attentionSent && !_attentionSending, $"StateObj is still dealing with attention: Sent: {_attentionSent}, Received: {_attentionReceived}, Sending: {_attentionSending}"); Debug.Assert(!_cancelled, "StateObj still has cancellation set"); - Debug.Assert(!_internalTimeout, "StateObj still has internal timeout set"); + Debug.Assert(_timeoutState == TimeoutState.Stopped, "StateObj still has internal timeout set"); // Errors and Warnings Debug.Assert(!_hasErrorOrWarning, "StateObj still has stored errors or warnings"); } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 2a5bf658c5..45fc752361 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -66,6 +66,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncTimeoutTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncTimeoutTest.cs new file mode 100644 index 0000000000..0ba98d83b6 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncTimeoutTest.cs @@ -0,0 +1,209 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Data; +using System.Threading.Tasks; +using System.Xml; +using Microsoft.Data.SqlClient.ManualTesting.Tests.SystemDataInternals; +using Xunit; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests +{ + public static class AsyncTimeoutTest + { + static string delayQuery2s = "WAITFOR DELAY '00:00:02'"; + static string delayQuery10s = "WAITFOR DELAY '00:00:10'"; + + public enum AsyncAPI + { + ExecuteReaderAsync, + ExecuteScalarAsync, + ExecuteXmlReaderAsync + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [ClassData(typeof(AsyncTimeoutTestVariations))] + public static void TestDelayedAsyncTimeout(AsyncAPI api, string commonObj, int delayPeriod, bool marsEnabled) => + RunTest(api, commonObj, delayPeriod, marsEnabled); + + public class AsyncTimeoutTestVariations : IEnumerable + { + public IEnumerator GetEnumerator() + { + yield return new object[] { AsyncAPI.ExecuteReaderAsync, "Connection", 8000, true }; + yield return new object[] { AsyncAPI.ExecuteReaderAsync, "Connection", 5000, true }; + yield return new object[] { AsyncAPI.ExecuteReaderAsync, "Connection", 0, true }; + yield return new object[] { AsyncAPI.ExecuteReaderAsync, "Connection", 8000, false }; + yield return new object[] { AsyncAPI.ExecuteReaderAsync, "Connection", 5000, false }; + yield return new object[] { AsyncAPI.ExecuteReaderAsync, "Connection", 0, false }; + + yield return new object[] { AsyncAPI.ExecuteScalarAsync, "Connection", 8000, true }; + yield return new object[] { AsyncAPI.ExecuteScalarAsync, "Connection", 5000, true }; + yield return new object[] { AsyncAPI.ExecuteScalarAsync, "Connection", 0, true }; + yield return new object[] { AsyncAPI.ExecuteScalarAsync, "Connection", 8000, false }; + yield return new object[] { AsyncAPI.ExecuteScalarAsync, "Connection", 5000, false }; + yield return new object[] { AsyncAPI.ExecuteScalarAsync, "Connection", 0, false }; + + yield return new object[] { AsyncAPI.ExecuteXmlReaderAsync, "Connection", 8000, true }; + yield return new object[] { AsyncAPI.ExecuteXmlReaderAsync, "Connection", 5000, true }; + yield return new object[] { AsyncAPI.ExecuteXmlReaderAsync, "Connection", 0, true }; + yield return new object[] { AsyncAPI.ExecuteXmlReaderAsync, "Connection", 8000, false }; + yield return new object[] { AsyncAPI.ExecuteXmlReaderAsync, "Connection", 5000, false }; + yield return new object[] { AsyncAPI.ExecuteXmlReaderAsync, "Connection", 0, false }; + + yield return new object[] { AsyncAPI.ExecuteReaderAsync, "Command", 8000, true }; + yield return new object[] { AsyncAPI.ExecuteReaderAsync, "Command", 5000, true }; + yield return new object[] { AsyncAPI.ExecuteReaderAsync, "Command", 0, true }; + yield return new object[] { AsyncAPI.ExecuteReaderAsync, "Command", 8000, false }; + yield return new object[] { AsyncAPI.ExecuteReaderAsync, "Command", 5000, false }; + yield return new object[] { AsyncAPI.ExecuteReaderAsync, "Command", 0, false }; + + yield return new object[] { AsyncAPI.ExecuteScalarAsync, "Command", 8000, true }; + yield return new object[] { AsyncAPI.ExecuteScalarAsync, "Command", 5000, true }; + yield return new object[] { AsyncAPI.ExecuteScalarAsync, "Command", 0, true }; + yield return new object[] { AsyncAPI.ExecuteScalarAsync, "Command", 8000, false }; + yield return new object[] { AsyncAPI.ExecuteScalarAsync, "Command", 5000, false }; + yield return new object[] { AsyncAPI.ExecuteScalarAsync, "Command", 0, false }; + + yield return new object[] { AsyncAPI.ExecuteXmlReaderAsync, "Command", 8000, true }; + yield return new object[] { AsyncAPI.ExecuteXmlReaderAsync, "Command", 5000, true }; + yield return new object[] { AsyncAPI.ExecuteXmlReaderAsync, "Command", 0, true }; + yield return new object[] { AsyncAPI.ExecuteXmlReaderAsync, "Command", 8000, false }; + yield return new object[] { AsyncAPI.ExecuteXmlReaderAsync, "Command", 5000, false }; + yield return new object[] { AsyncAPI.ExecuteXmlReaderAsync, "Command", 0, false }; + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + + private static void RunTest(AsyncAPI api, string commonObj, int timeoutDelay, bool marsEnabled) + { + string connString = new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString) + { + MultipleActiveResultSets = marsEnabled + }.ConnectionString; + + using (SqlConnection sqlConnection = new SqlConnection(connString)) + { + sqlConnection.Open(); + if (timeoutDelay != 0) + { + ConnectionHelper.SetEnforcedTimeout(sqlConnection, true, timeoutDelay); + } + switch (commonObj) + { + case "Connection": + QueryAndValidate(api, 1, delayQuery2s, 1, true, true, sqlConnection).Wait(); + QueryAndValidate(api, 2, delayQuery2s, 5, false, true, sqlConnection).Wait(); + QueryAndValidate(api, 3, delayQuery10s, 1, true, true, sqlConnection).Wait(); + QueryAndValidate(api, 4, delayQuery2s, 10, false, true, sqlConnection).Wait(); + break; + case "Command": + using (SqlCommand cmd = sqlConnection.CreateCommand()) + { + QueryAndValidate(api, 1, delayQuery2s, 1, true, false, sqlConnection, cmd).Wait(); + QueryAndValidate(api, 2, delayQuery2s, 5, false, false, sqlConnection, cmd).Wait(); + QueryAndValidate(api, 3, delayQuery10s, 1, true, false, sqlConnection, cmd).Wait(); + QueryAndValidate(api, 4, delayQuery2s, 10, false, false, sqlConnection, cmd).Wait(); + } + break; + } + } + } + + private static async Task QueryAndValidate(AsyncAPI api, int index, string delayQuery, int timeout, + bool timeoutExExpected = false, bool useTransaction = false, SqlConnection cn = null, SqlCommand cmd = null) + { + SqlTransaction tx = null; + try + { + if (cn != null) + { + if (cn.State != ConnectionState.Open) + { + await cn.OpenAsync(); + } + cmd = cn.CreateCommand(); + if (useTransaction) + { + tx = cn.BeginTransaction(IsolationLevel.ReadCommitted); + cmd.Transaction = tx; + } + } + + cmd.CommandTimeout = timeout; + if (api != AsyncAPI.ExecuteXmlReaderAsync) + { + cmd.CommandText = delayQuery + $";select {index} as Id;"; + } + else + { + cmd.CommandText = delayQuery + $";select {index} as Id FOR XML PATH;"; + } + + var result = -1; + switch (api) + { + case AsyncAPI.ExecuteReaderAsync: + using (SqlDataReader reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false)) + { + while (await reader.ReadAsync().ConfigureAwait(false)) + { + var columnIndex = reader.GetOrdinal("Id"); + result = reader.GetInt32(columnIndex); + break; + } + } + break; + case AsyncAPI.ExecuteScalarAsync: + result = (int)await cmd.ExecuteScalarAsync().ConfigureAwait(false); + break; + case AsyncAPI.ExecuteXmlReaderAsync: + using (XmlReader reader = await cmd.ExecuteXmlReaderAsync().ConfigureAwait(false)) + { + try + { + Assert.True(reader.Settings.Async); + reader.ReadToDescendant("Id"); + result = reader.ReadElementContentAsInt(); + } + catch (Exception ex) + { + Assert.False(true, "Exception occurred: " + ex.Message); + } + } + break; + } + + if (result != index) + { + throw new Exception("High Alert! Wrong data received for index: " + index); + } + else + { + Assert.True(!timeoutExExpected && result == index); + } + } + catch (SqlException e) + { + if (!timeoutExExpected) + throw new Exception("Index " + index + " failed with: " + e.Message); + else + Assert.True(timeoutExExpected && e.Class == 11 && e.Number == -2); + } + finally + { + if (cn != null) + { + if (useTransaction) + tx.Commit(); + cn.Close(); + } + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionHelper.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionHelper.cs index 2b4f533dd5..54561e1be9 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionHelper.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionHelper.cs @@ -10,15 +10,22 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SystemDataInternals { internal static class ConnectionHelper { - private static Assembly s_systemDotData = Assembly.Load(new AssemblyName(typeof(SqlConnection).GetTypeInfo().Assembly.FullName)); - private static Type s_sqlConnection = s_systemDotData.GetType("Microsoft.Data.SqlClient.SqlConnection"); - private static Type s_sqlInternalConnection = s_systemDotData.GetType("Microsoft.Data.SqlClient.SqlInternalConnection"); - private static Type s_sqlInternalConnectionTds = s_systemDotData.GetType("Microsoft.Data.SqlClient.SqlInternalConnectionTds"); - private static Type s_dbConnectionInternal = s_systemDotData.GetType("Microsoft.Data.ProviderBase.DbConnectionInternal"); + private static Assembly s_MicrosoftDotData = Assembly.Load(new AssemblyName(typeof(SqlConnection).GetTypeInfo().Assembly.FullName)); + private static Type s_sqlConnection = s_MicrosoftDotData.GetType("Microsoft.Data.SqlClient.SqlConnection"); + private static Type s_sqlInternalConnection = s_MicrosoftDotData.GetType("Microsoft.Data.SqlClient.SqlInternalConnection"); + private static Type s_sqlInternalConnectionTds = s_MicrosoftDotData.GetType("Microsoft.Data.SqlClient.SqlInternalConnectionTds"); + private static Type s_dbConnectionInternal = s_MicrosoftDotData.GetType("Microsoft.Data.ProviderBase.DbConnectionInternal"); + private static Type s_tdsParser = s_MicrosoftDotData.GetType("Microsoft.Data.SqlClient.TdsParser"); + private static Type s_tdsParserStateObject = s_MicrosoftDotData.GetType("Microsoft.Data.SqlClient.TdsParserStateObject"); private static PropertyInfo s_sqlConnectionInternalConnection = s_sqlConnection.GetProperty("InnerConnection", BindingFlags.Instance | BindingFlags.NonPublic); private static PropertyInfo s_dbConnectionInternalPool = s_dbConnectionInternal.GetProperty("Pool", BindingFlags.Instance | BindingFlags.NonPublic); private static MethodInfo s_dbConnectionInternalIsConnectionAlive = s_dbConnectionInternal.GetMethod("IsConnectionAlive", BindingFlags.Instance | BindingFlags.NonPublic); private static FieldInfo s_sqlInternalConnectionTdsParser = s_sqlInternalConnectionTds.GetField("_parser", BindingFlags.Instance | BindingFlags.NonPublic); + private static PropertyInfo s_innerConnectionProperty = s_sqlConnection.GetProperty("InnerConnection", BindingFlags.Instance | BindingFlags.NonPublic); + private static PropertyInfo s_tdsParserProperty = s_sqlInternalConnectionTds.GetProperty("Parser", BindingFlags.Instance | BindingFlags.NonPublic); + private static FieldInfo s_tdsParserStateObjectProperty = s_tdsParser.GetField("_physicalStateObj", BindingFlags.Instance | BindingFlags.NonPublic); + private static FieldInfo s_enforceTimeoutDelayProperty = s_tdsParserStateObject.GetField("_enforceTimeoutDelay", BindingFlags.Instance | BindingFlags.NonPublic); + private static FieldInfo s_enforcedTimeoutDelayInMilliSeconds = s_tdsParserStateObject.GetField("_enforcedTimeoutDelayInMilliSeconds", BindingFlags.Instance | BindingFlags.NonPublic); public static object GetConnectionPool(object internalConnection) { @@ -28,12 +35,12 @@ public static object GetConnectionPool(object internalConnection) public static object GetInternalConnection(this SqlConnection connection) { + VerifyObjectIsConnection(connection); object internalConnection = s_sqlConnectionInternalConnection.GetValue(connection, null); Debug.Assert(((internalConnection != null) && (s_dbConnectionInternal.IsInstanceOfType(internalConnection))), "Connection provided has an invalid internal connection"); return internalConnection; } - public static bool IsConnectionAlive(object internalConnection) { VerifyObjectIsInternalConnection(internalConnection); @@ -45,7 +52,15 @@ private static void VerifyObjectIsInternalConnection(object internalConnection) if (internalConnection == null) throw new ArgumentNullException(nameof(internalConnection)); if (!s_dbConnectionInternal.IsInstanceOfType(internalConnection)) - throw new ArgumentException("Object provided was not a DbConnectionInternal", "internalConnection"); + throw new ArgumentException("Object provided was not a DbConnectionInternal", nameof(internalConnection)); + } + + private static void VerifyObjectIsConnection(object connection) + { + if (connection == null) + throw new ArgumentNullException(nameof(connection)); + if (!s_sqlConnection.IsInstanceOfType(connection)) + throw new ArgumentException("Object provided was not a SqlConnection", nameof(connection)); } public static object GetParser(object internalConnection) @@ -53,5 +68,16 @@ public static object GetParser(object internalConnection) VerifyObjectIsInternalConnection(internalConnection); return s_sqlInternalConnectionTdsParser.GetValue(internalConnection); } + + public static void SetEnforcedTimeout(this SqlConnection connection, bool enforce, int timeout) + { + VerifyObjectIsConnection(connection); + var stateObj = s_tdsParserStateObjectProperty.GetValue( + s_tdsParserProperty.GetValue( + s_innerConnectionProperty.GetValue( + connection, null), null)); + s_enforceTimeoutDelayProperty.SetValue(stateObj, enforce); + s_enforcedTimeoutDelayInMilliSeconds.SetValue(stateObj, timeout); + } } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionPoolHelper.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionPoolHelper.cs index 6ae73f5571..d7c5471427 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionPoolHelper.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Common/SystemDataInternals/ConnectionPoolHelper.cs @@ -13,13 +13,13 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SystemDataInternals { internal static class ConnectionPoolHelper { - private static Assembly s_systemDotData = Assembly.Load(new AssemblyName(typeof(SqlConnection).GetTypeInfo().Assembly.FullName)); - private static Type s_dbConnectionPool = s_systemDotData.GetType("Microsoft.Data.ProviderBase.DbConnectionPool"); - private static Type s_dbConnectionPoolGroup = s_systemDotData.GetType("Microsoft.Data.ProviderBase.DbConnectionPoolGroup"); - private static Type s_dbConnectionPoolIdentity = s_systemDotData.GetType("Microsoft.Data.ProviderBase.DbConnectionPoolIdentity"); - private static Type s_dbConnectionFactory = s_systemDotData.GetType("Microsoft.Data.ProviderBase.DbConnectionFactory"); - private static Type s_sqlConnectionFactory = s_systemDotData.GetType("Microsoft.Data.SqlClient.SqlConnectionFactory"); - private static Type s_dbConnectionPoolKey = s_systemDotData.GetType("Microsoft.Data.Common.DbConnectionPoolKey"); + private static Assembly s_MicrosoftDotData = Assembly.Load(new AssemblyName(typeof(SqlConnection).GetTypeInfo().Assembly.FullName)); + private static Type s_dbConnectionPool = s_MicrosoftDotData.GetType("Microsoft.Data.ProviderBase.DbConnectionPool"); + private static Type s_dbConnectionPoolGroup = s_MicrosoftDotData.GetType("Microsoft.Data.ProviderBase.DbConnectionPoolGroup"); + private static Type s_dbConnectionPoolIdentity = s_MicrosoftDotData.GetType("Microsoft.Data.ProviderBase.DbConnectionPoolIdentity"); + private static Type s_dbConnectionFactory = s_MicrosoftDotData.GetType("Microsoft.Data.ProviderBase.DbConnectionFactory"); + private static Type s_sqlConnectionFactory = s_MicrosoftDotData.GetType("Microsoft.Data.SqlClient.SqlConnectionFactory"); + private static Type s_dbConnectionPoolKey = s_MicrosoftDotData.GetType("Microsoft.Data.Common.DbConnectionPoolKey"); private static Type s_dictStringPoolGroup = typeof(Dictionary<,>).MakeGenericType(s_dbConnectionPoolKey, s_dbConnectionPoolGroup); private static Type s_dictPoolIdentityPool = typeof(ConcurrentDictionary<,>).MakeGenericType(s_dbConnectionPoolIdentity, s_dbConnectionPool); private static PropertyInfo s_dbConnectionPoolCount = s_dbConnectionPool.GetProperty("Count", BindingFlags.Instance | BindingFlags.NonPublic); @@ -123,7 +123,6 @@ internal static int CountConnectionsInPool(object pool) return (int)s_dbConnectionPoolCount.GetValue(pool, null); } - private static void VerifyObjectIsPool(object pool) { if (pool == null)