From d9ab5a90387b9f7b94a818b8cafec5932160549f Mon Sep 17 00:00:00 2001 From: David Engel Date: Tue, 14 Mar 2023 13:05:01 -0700 Subject: [PATCH] Fix | Throttling of token requests by calling AcquireTokenSilent (#1925) * Address throttling of token requests by calling AcquireTokenSilent in Integrated/Password flows when the account is already cached. Addresses issue #1915 Co-authored-by: Lawrence Cheung <31262254+lcheunglci@users.noreply.github.com> Co-authored-by: DavoudEshtehari <61173489+DavoudEshtehari@users.noreply.github.com> --- .../ActiveDirectoryAuthenticationProvider.cs | 193 ++++++++++++------ 1 file changed, 126 insertions(+), 67 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 573c36ee55..6e57bb6c07 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -4,7 +4,10 @@ using System; using System.Collections.Concurrent; -using System.Security; +using System.Linq; +using System.Runtime.Caching; +using System.Security.Cryptography; +using System.Text; using System.Threading; using System.Threading.Tasks; using Azure.Core; @@ -24,6 +27,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro /// private static ConcurrentDictionary s_pcaMap = new ConcurrentDictionary(); + private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider)); + private static readonly int s_accountPwCacheTtlInHours = 2; private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient"; private static readonly string s_defaultScopeSuffix = "/.default"; private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name; @@ -172,7 +177,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } - AuthenticationResult result; + AuthenticationResult result = null; if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal) { AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); @@ -208,82 +213,82 @@ public override async Task AcquireTokenAsync(SqlAuthenti if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { - if (!string.IsNullOrEmpty(parameters.UserId)) - { - result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) - .WithCorrelationId(parameters.ConnectionId) - .WithUsername(parameters.UserId) - .ExecuteAsync(cancellationToken: cts.Token) - .ConfigureAwait(false); - } - else - { - result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) - .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync(cancellationToken: cts.Token) - .ConfigureAwait(false); - } - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn); - } - else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword) - { - result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, parameters.Password) - .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync(cancellationToken: cts.Token) - .ConfigureAwait(false); - - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); - } - else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || - parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) - { - // Fetch available accounts from 'app' instance - System.Collections.Generic.IEnumerator accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator(); + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); - IAccount account = default; - if (accounts.MoveNext()) + if (null == result) { if (!string.IsNullOrEmpty(parameters.UserId)) { - do - { - IAccount currentVal = accounts.Current; - if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0) - { - account = currentVal; - break; - } - } - while (accounts.MoveNext()); + result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) + .WithCorrelationId(parameters.ConnectionId) + .WithUsername(parameters.UserId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); } else { - account = accounts.Current; + result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); } + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn); + } + } + else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword) + { + string pwCacheKey = GetAccountPwCacheKey(parameters); + object previousPw = s_accountPwCache.Get(pwCacheKey); + byte[] currPwHash = GetHash(parameters.Password); + + if (null != previousPw && + previousPw is byte[] previousPwBytes && + // Only get the cached token if the current password hash matches the previously used password hash + currPwHash.SequenceEqual(previousPwBytes)) + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); } - if (null != account) + if (null == result) { - try - { - // If 'account' is available in 'app', we use the same to acquire token silently. - // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent - result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); - } - catch (MsalUiRequiredException) + result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, parameters.Password) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + + // We cache the password hash to ensure future connection requests include a validated password + // when we check for a cached MSAL account. Otherwise, a connection request with the same username + // against the same tenant could succeed with an invalid password when we re-use the cached token. + if (!s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours))) { - // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, - // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), - // or the user needs to perform two factor authentication. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + s_accountPwCache.Remove(pwCacheKey); + s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours)); } + + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); } - else + } + else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || + parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) + { + try + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + catch (MsalUiRequiredException) + { + // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, + // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), + // or the user needs to perform two factor authentication. + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + + if (null == result) { // If no existing 'account' is found, we request user to sign in interactively. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false); + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); } } @@ -296,8 +301,49 @@ public override async Task AcquireTokenAsync(SqlAuthenti return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); } - private async Task AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, - SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts) + private static async Task TryAcquireTokenSilent(IPublicClientApplication app, SqlAuthenticationParameters parameters, + string[] scopes, CancellationTokenSource cts) + { + AuthenticationResult result = null; + + // Fetch available accounts from 'app' instance + System.Collections.Generic.IEnumerator accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator(); + + IAccount account = default; + if (accounts.MoveNext()) + { + if (!string.IsNullOrEmpty(parameters.UserId)) + { + do + { + IAccount currentVal = accounts.Current; + if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0) + { + account = currentVal; + break; + } + } + while (accounts.MoveNext()); + } + else + { + account = accounts.Current; + } + } + + if (null != account) + { + // If 'account' is available in 'app', we use the same to acquire token silently. + // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent + result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + + return result; + } + + private static async Task AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, + SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func deviceCodeFlowCallback) { try { @@ -316,11 +362,11 @@ private async Task AcquireTokenInteractiveDeviceFlowAsync( */ ctsInteractive.CancelAfter(180000); #endif - if (_customWebUI != null) + if (customWebUI != null) { return await app.AcquireTokenInteractive(scopes) .WithCorrelationId(connectionId) - .WithCustomWebUi(_customWebUI) + .WithCustomWebUi(customWebUI) .WithLoginHint(userId) .ExecuteAsync(ctsInteractive.Token) .ConfigureAwait(false); @@ -354,7 +400,7 @@ private async Task AcquireTokenInteractiveDeviceFlowAsync( else { AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes, - deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult)) + deviceCodeResult => deviceCodeFlowCallback(deviceCodeResult)) .WithCorrelationId(connectionId) .ExecuteAsync(cancellationToken: cts.Token) .ConfigureAwait(false); @@ -407,6 +453,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p return clientApplicationInstance; } + private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters) + { + return parameters.Authority + "+" + parameters.UserId; + } + + private static byte[] GetHash(string input) + { + byte[] unhashedBytes = Encoding.Unicode.GetBytes(input); + SHA256 sha256 = SHA256.Create(); + byte[] hashedBytes = sha256.ComputeHash(unhashedBytes); + return hashedBytes; + } + private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey) { IPublicClientApplication publicClientApplication;