Skip to content

Commit

Permalink
Implement public client application global cache
Browse files Browse the repository at this point in the history
  • Loading branch information
cheenamalhotra committed Oct 22, 2020
1 parent 4107f24 commit 03e8907
Showing 1 changed file with 179 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Security;
using System.Threading;
Expand All @@ -15,6 +17,9 @@ namespace Microsoft.Data.SqlClient
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/ActiveDirectoryAuthenticationProvider/*'/>
public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider
{
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
private static readonly string s_defaultScopeSuffix = "/.default";
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
private readonly SqlClientLogger _logger = new SqlClientLogger();
Expand Down Expand Up @@ -67,10 +72,10 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
}

#if NETSTANDARD
private Func<object> parentActivityOrWindowFunc = null;
private Func<object> _parentActivityOrWindowFunc = null;

/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/SetParentActivityOrWindowFunc/*'/>
public void SetParentActivityOrWindowFunc(Func<object> parentActivityOrWindowFunc) => this.parentActivityOrWindowFunc = parentActivityOrWindowFunc;
public void SetParentActivityOrWindowFunc(Func<object> parentActivityOrWindowFunc) => this._parentActivityOrWindowFunc = parentActivityOrWindowFunc;
#endif

#if NETFRAMEWORK
Expand Down Expand Up @@ -108,51 +113,24 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
*
* https://docs.microsoft.com/en-us/azure/active-directory/develop/scenario-desktop-app-registration#redirect-uris
*/
string redirectURI = "https://login.microsoftonline.com/common/oauth2/nativeclient";
string redirectUri = s_nativeClientRedirectUri;
#if NETCOREAPP
if (parameters.AuthenticationMethod != SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
redirectURI = "http://localhost";
}
#endif
IPublicClientApplication app;
#if NETSTANDARD
if (parentActivityOrWindowFunc != null)
{
app = PublicClientApplicationBuilder.Create(_applicationClientId)
.WithAuthority(parameters.Authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(redirectURI)
.WithParentActivityOrWindow(parentActivityOrWindowFunc)
.Build();
redirectUri = "http://localhost";
}
#endif
PublicClientAppKey pcaKey = new PublicClientAppKey(parameters.Authority, redirectUri, _applicationClientId
#if NETFRAMEWORK
if (_iWin32WindowFunc != null)
{
app = PublicClientApplicationBuilder.Create(_applicationClientId)
.WithAuthority(parameters.Authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(redirectURI)
.WithParentActivityOrWindow(_iWin32WindowFunc)
.Build();
}
, _iWin32WindowFunc
#endif
#if !NETCOREAPP
else
#if NETSTANDARD
, _parentActivityOrWindowFunc
#endif
{
app = PublicClientApplicationBuilder.Create(_applicationClientId)
.WithAuthority(parameters.Authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(redirectURI)
.Build();
}
);
IPublicClientApplication app = GetPublicClientAppInstance(pcaKey);
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
Expand Down Expand Up @@ -320,5 +298,168 @@ private class CustomWebUi : ICustomWebUi
public Task<Uri> AcquireAuthorizationCodeAsync(Uri authorizationUri, Uri redirectUri, CancellationToken cancellationToken)
=> _acquireAuthorizationCodeAsyncCallback.Invoke(authorizationUri, redirectUri, cancellationToken);
}

private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey publicClientAppKey)
{
IPublicClientApplication clientApplicationInstance;

if (s_pcaMap.ContainsKey(publicClientAppKey))
{
s_pcaMap.TryGetValue(publicClientAppKey, out clientApplicationInstance);
}
else
{
clientApplicationInstance = CreateClientAppInstance(publicClientAppKey);
s_pcaMap.TryAdd(publicClientAppKey, clientApplicationInstance);
}
return clientApplicationInstance;
}

private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
{
IPublicClientApplication publicClientApplication;

#if NETSTANDARD
if (_parentActivityOrWindowFunc != null)
{
clientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
.WithAuthority(publicClientAppKey._authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(publicClientAppKey._redirectUri)
.WithParentActivityOrWindow(_parentActivityOrWindowFunc)
.Build();
}
#endif
#if NETFRAMEWORK
if (_iWin32WindowFunc != null)
{
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
.WithAuthority(publicClientAppKey._authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(publicClientAppKey._redirectUri)
.WithParentActivityOrWindow(_iWin32WindowFunc)
.Build();
}
#endif
#if !NETCOREAPP
else
#endif
{
publicClientApplication = PublicClientApplicationBuilder.Create(publicClientAppKey._applicationClientId)
.WithAuthority(publicClientAppKey._authority)
.WithClientName(Common.DbConnectionStringDefaults.ApplicationName)
.WithClientVersion(Common.ADP.GetAssemblyVersion().ToString())
.WithRedirectUri(publicClientAppKey._redirectUri)
.Build();
}

return publicClientApplication;
}

internal class PublicClientAppKey
{
public readonly string _authority;
public readonly string _redirectUri;
public readonly string _applicationClientId;
#if NETFRAMEWORK
public readonly Func<System.Windows.Forms.IWin32Window> _iWin32WindowFunc;
#endif
#if NETSTANDARD
public readonly Func<object> _parentActivityOrWindowFunc;
#endif
private int _hashValue;

public PublicClientAppKey(string authority, string redirectUri, string applicationClientId
#if NETFRAMEWORK
, Func<System.Windows.Forms.IWin32Window> iWin32WindowFunc
#endif
#if NETSTANDARD
, Func<object> parentActivityOrWindowFunc
#endif
)
{
_authority = authority;
_redirectUri = redirectUri;
_applicationClientId = applicationClientId;
#if NETFRAMEWORK
_iWin32WindowFunc = iWin32WindowFunc;
#endif
#if NETSTANDARD
_parentActivityOrWindowFunc = parentActivityOrWindowFunc;
#endif
}

public override bool Equals(object obj)
{
if (obj == null)
{
return false;
}

PublicClientAppKey pcaKey = obj as PublicClientAppKey;
return (string.CompareOrdinal(_authority, pcaKey._authority) == 0
&& string.CompareOrdinal(_redirectUri, pcaKey._redirectUri) == 0
&& string.CompareOrdinal(_applicationClientId, pcaKey._applicationClientId) == 0
#if NETFRAMEWORK
&& pcaKey._iWin32WindowFunc == _iWin32WindowFunc
#endif
#if NETSTANDARD
&& pcaKey._parentActivityOrWindowFunc == _parentActivityOrWindowFunc
#endif
);
}

public override int GetHashCode()
{
return _hashValue;
}

private void CalculateHashCode()
{
_hashValue = base.GetHashCode();

if (_authority != null)
{
unchecked
{
_hashValue = _hashValue * 17 + _authority.GetHashCode();
}
}
if (_redirectUri != null)
{
unchecked
{
_hashValue = _hashValue * 17 + _redirectUri.GetHashCode();
}
}
if (_applicationClientId != null)
{
unchecked
{
_hashValue = _hashValue * 17 + _applicationClientId.GetHashCode();
}
}
#if NETFRAMEWORK
if (_iWin32WindowFunc != null)
{
unchecked
{
_hashValue = _hashValue * 17 + _iWin32WindowFunc.GetHashCode();
}
}
#endif
#if NETSTANDARD
if (_parentActivityOrWindowFunc != null)
{
unchecked
{
_hashValue = _hashValue * 17 + _parentActivityOrWindowFunc.GetHashCode();
}
}
#endif
}
}
}
}

0 comments on commit 03e8907

Please sign in to comment.