Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Preview] AAD: Fixes token refresh interval, exception handling, and retry logic #2481

Merged
merged 13 commits into from
May 20, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@ internal sealed class AuthorizationTokenProviderTokenCredential : AuthorizationT
public AuthorizationTokenProviderTokenCredential(
TokenCredential tokenCredential,
Uri accountEndpoint,
TimeSpan requestTimeout,
TimeSpan? backgroundTokenCredentialRefreshInterval)
{
this.tokenCredentialCache = new TokenCredentialCache(
tokenCredential: tokenCredential,
accountEndpoint: accountEndpoint,
requestTimeout: requestTimeout,
backgroundTokenCredentialRefreshInterval: backgroundTokenCredentialRefreshInterval);
}

Expand All @@ -38,9 +36,12 @@ public AuthorizationTokenProviderTokenCredential(
INameValueCollection headers,
AuthorizationTokenType tokenType)
{
string token = AuthorizationTokenProviderTokenCredential.GenerateAadAuthorizationSignature(
await this.tokenCredentialCache.GetTokenAsync(NoOpTrace.Singleton));
return (token, default);
using (Trace trace = Trace.GetRootTrace(nameof(GetUserAuthorizationTokenAsync), TraceComponent.Authorization, TraceLevel.Info))
{
string token = AuthorizationTokenProviderTokenCredential.GenerateAadAuthorizationSignature(
await this.tokenCredentialCache.GetTokenAsync(trace));
return (token, default);
}
}

public override async ValueTask<string> GetUserAuthorizationTokenAsync(
Expand All @@ -61,10 +62,13 @@ public override async ValueTask AddAuthorizationHeaderAsync(
string verb,
AuthorizationTokenType tokenType)
{
string token = AuthorizationTokenProviderTokenCredential.GenerateAadAuthorizationSignature(
await this.tokenCredentialCache.GetTokenAsync(NoOpTrace.Singleton));
using (Trace trace = Trace.GetRootTrace(nameof(GetUserAuthorizationTokenAsync), TraceComponent.Authorization, TraceLevel.Info))
{
string token = AuthorizationTokenProviderTokenCredential.GenerateAadAuthorizationSignature(
await this.tokenCredentialCache.GetTokenAsync(trace));

headersCollection.Add(HttpConstants.HttpHeaders.Authorization, token);
headersCollection.Add(HttpConstants.HttpHeaders.Authorization, token);
}
}

public override void TraceUnauthorized(
Expand Down
286 changes: 152 additions & 134 deletions Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion Microsoft.Azure.Cosmos/src/CosmosClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ public CosmosClient(
this.AuthorizationTokenProvider = new AuthorizationTokenProviderTokenCredential(
tokenCredential,
this.Endpoint,
clientOptions.RequestTimeout,
clientOptions.TokenCredentialBackgroundRefreshInterval);

this.ClientContext = ClientContextCore.Create(
Expand Down
2 changes: 1 addition & 1 deletion Microsoft.Azure.Cosmos/src/CosmosClientOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ public int GatewayModeMaxConnectionLimit
/// This avoids latency issues because the old token is used until the new token is retrieved.
/// </summary>
/// <remarks>
/// The recommended minimum value is 5 minutes. The default value is 25% of the token expire time.
/// The recommended minimum value is 5 minutes. The default value is 50% of the token expire time.
/// </remarks>
#if PREVIEW
public
Expand Down
16 changes: 6 additions & 10 deletions Microsoft.Azure.Cosmos/src/Routing/ClientCollectionCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,13 @@ namespace Microsoft.Azure.Cosmos.Routing
internal class ClientCollectionCache : CollectionCache
{
private readonly IStoreModel storeModel;
private readonly IAuthorizationTokenProvider tokenProvider;
private readonly ICosmosAuthorizationTokenProvider tokenProvider;
private readonly IRetryPolicyFactory retryPolicy;
private readonly ISessionContainer sessionContainer;

public ClientCollectionCache(ISessionContainer sessionContainer, IStoreModel storeModel, IAuthorizationTokenProvider tokenProvider, IRetryPolicyFactory retryPolicy)
public ClientCollectionCache(ISessionContainer sessionContainer, IStoreModel storeModel, ICosmosAuthorizationTokenProvider tokenProvider, IRetryPolicyFactory retryPolicy)
{
if (storeModel == null)
{
throw new ArgumentNullException("storeModel");
}

this.storeModel = storeModel;
this.storeModel = storeModel ?? throw new ArgumentNullException("storeModel");
this.tokenProvider = tokenProvider;
this.retryPolicy = retryPolicy;
this.sessionContainer = sessionContainer;
Expand Down Expand Up @@ -89,12 +84,13 @@ private async Task<ContainerProperties> ReadCollectionAsync(string collectionLin
childTrace.AddDatum("Client Side Request Stats", request.RequestContext.ClientRequestStatistics);
}

(string authorizationToken, string payload) = await this.tokenProvider.GetUserAuthorizationAsync(
string authorizationToken = await this.tokenProvider.GetUserAuthorizationTokenAsync(
request.ResourceAddress,
PathsHelper.GetResourcePath(request.ResourceType),
HttpConstants.HttpMethods.Get,
request.Headers,
AuthorizationTokenType.PrimaryMasterKey);
AuthorizationTokenType.PrimaryMasterKey,
childTrace);

request.Headers[HttpConstants.HttpHeaders.Authorization] = authorizationToken;

Expand Down
126 changes: 68 additions & 58 deletions Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,18 @@ internal class GatewayAddressCache : IAddressCache

private readonly Protocol protocol;
private readonly string protocolFilter;
private readonly IAuthorizationTokenProvider tokenProvider;
private readonly ICosmosAuthorizationTokenProvider tokenProvider;
private readonly bool enableTcpConnectionEndpointRediscovery;

private CosmosHttpClient httpClient;
private readonly CosmosHttpClient httpClient;

private Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> masterPartitionAddressCache;
private DateTime suboptimalMasterPartitionTimestamp;

public GatewayAddressCache(
Uri serviceEndpoint,
Protocol protocol,
IAuthorizationTokenProvider tokenProvider,
ICosmosAuthorizationTokenProvider tokenProvider,
IServiceConfigurationReader serviceConfigReader,
CosmosHttpClient httpClient,
long suboptimalPartitionForceRefreshIntervalInSeconds = 600,
Expand Down Expand Up @@ -432,29 +432,33 @@ private async Task<DocumentServiceResponse> GetMasterAddressesViaGatewayAsync(
string resourceTypeToSign = PathsHelper.GetResourcePath(resourceType);

headers.Set(HttpConstants.HttpHeaders.XDate, DateTime.UtcNow.ToString("r", CultureInfo.InvariantCulture));
(string token, string _) = await this.tokenProvider.GetUserAuthorizationAsync(
resourceAddress,
resourceTypeToSign,
HttpConstants.HttpMethods.Get,
headers,
AuthorizationTokenType.PrimaryMasterKey);

headers.Set(HttpConstants.HttpHeaders.Authorization, token);

Uri targetEndpoint = UrlUtility.SetQuery(this.addressEndpoint, UrlUtility.CreateQuery(addressQuery));

string identifier = GatewayAddressCache.LogAddressResolutionStart(request, targetEndpoint);
using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync(
uri: targetEndpoint,
additionalHeaders: headers,
resourceType: resourceType,
timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.Instance,
clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics,
cancellationToken: default))
using (ITrace trace = Trace.GetRootTrace(nameof(GetMasterAddressesViaGatewayAsync), TraceComponent.Authorization, TraceLevel.Info))
{
DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage);
GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
return documentServiceResponse;
string token = await this.tokenProvider.GetUserAuthorizationTokenAsync(
resourceAddress,
resourceTypeToSign,
HttpConstants.HttpMethods.Get,
headers,
AuthorizationTokenType.PrimaryMasterKey,
trace);

headers.Set(HttpConstants.HttpHeaders.Authorization, token);

Uri targetEndpoint = UrlUtility.SetQuery(this.addressEndpoint, UrlUtility.CreateQuery(addressQuery));

string identifier = GatewayAddressCache.LogAddressResolutionStart(request, targetEndpoint);
using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync(
uri: targetEndpoint,
additionalHeaders: headers,
resourceType: resourceType,
timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.Instance,
clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics,
cancellationToken: default))
{
DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage);
GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
return documentServiceResponse;
}
}
}

Expand Down Expand Up @@ -489,47 +493,53 @@ private async Task<DocumentServiceResponse> GetServerAddressesViaGatewayAsync(

headers.Set(HttpConstants.HttpHeaders.XDate, DateTime.UtcNow.ToString("r", CultureInfo.InvariantCulture));
string token = null;
try
{
token = (await this.tokenProvider.GetUserAuthorizationAsync(
collectionRid,
resourceTypeToSign,
HttpConstants.HttpMethods.Get,
headers,
AuthorizationTokenType.PrimaryMasterKey)).token;
}
catch (UnauthorizedException)
{
}

if (token == null && request != null && request.IsNameBased)
using (ITrace trace = Trace.GetRootTrace(nameof(GetMasterAddressesViaGatewayAsync), TraceComponent.Authorization, TraceLevel.Info))
{
// User doesn't have rid based resource token. Maybe he has name based.
string collectionAltLink = PathsHelper.GetCollectionPath(request.ResourceAddress);
token = (await this.tokenProvider.GetUserAuthorizationAsync(
collectionAltLink,
try
{
token = await this.tokenProvider.GetUserAuthorizationTokenAsync(
collectionRid,
resourceTypeToSign,
HttpConstants.HttpMethods.Get,
headers,
AuthorizationTokenType.PrimaryMasterKey)).token;
}
AuthorizationTokenType.PrimaryMasterKey,
trace);
}
catch (UnauthorizedException)
{
}

headers.Set(HttpConstants.HttpHeaders.Authorization, token);
if (token == null && request != null && request.IsNameBased)
{
// User doesn't have rid based resource token. Maybe he has name based.
string collectionAltLink = PathsHelper.GetCollectionPath(request.ResourceAddress);
token = await this.tokenProvider.GetUserAuthorizationTokenAsync(
collectionAltLink,
resourceTypeToSign,
HttpConstants.HttpMethods.Get,
headers,
AuthorizationTokenType.PrimaryMasterKey,
trace);
}

Uri targetEndpoint = UrlUtility.SetQuery(this.addressEndpoint, UrlUtility.CreateQuery(addressQuery));
headers.Set(HttpConstants.HttpHeaders.Authorization, token);

string identifier = GatewayAddressCache.LogAddressResolutionStart(request, targetEndpoint);
using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync(
uri: targetEndpoint,
additionalHeaders: headers,
resourceType: ResourceType.Document,
timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.Instance,
clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics,
cancellationToken: default))
{
DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage);
GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
return documentServiceResponse;
Uri targetEndpoint = UrlUtility.SetQuery(this.addressEndpoint, UrlUtility.CreateQuery(addressQuery));

string identifier = GatewayAddressCache.LogAddressResolutionStart(request, targetEndpoint);
using (HttpResponseMessage httpResponseMessage = await this.httpClient.GetAsync(
uri: targetEndpoint,
additionalHeaders: headers,
resourceType: ResourceType.Document,
timeoutPolicy: HttpTimeoutPolicyControlPlaneRetriableHotPath.Instance,
clientSideRequestStatistics: request.RequestContext?.ClientRequestStatistics,
cancellationToken: default))
{
DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage);
GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
return documentServiceResponse;
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions Microsoft.Azure.Cosmos/src/Routing/GlobalAddressResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ internal sealed class GlobalAddressResolver : IAddressResolver
private readonly GlobalEndpointManager endpointManager;
private readonly GlobalPartitionEndpointManager partitionKeyRangeLocationCache;
private readonly Protocol protocol;
private readonly IAuthorizationTokenProvider tokenProvider;
private readonly ICosmosAuthorizationTokenProvider tokenProvider;
private readonly CollectionCache collectionCache;
private readonly PartitionKeyRangeCache routingMapProvider;
private readonly int maxEndpoints;
Expand All @@ -40,7 +40,7 @@ public GlobalAddressResolver(
GlobalEndpointManager endpointManager,
GlobalPartitionEndpointManager partitionKeyRangeLocationCache,
Protocol protocol,
IAuthorizationTokenProvider tokenProvider,
ICosmosAuthorizationTokenProvider tokenProvider,
CollectionCache collectionCache,
PartitionKeyRangeCache routingMapProvider,
IServiceConfigurationReader serviceConfigReader,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ internal class PartitionKeyRangeCache : IRoutingMapProvider, ICollectionRoutingM

private readonly AsyncCache<string, CollectionRoutingMap> routingMapCache;

private readonly IAuthorizationTokenProvider authorizationTokenProvider;
private readonly ICosmosAuthorizationTokenProvider authorizationTokenProvider;
private readonly IStoreModel storeModel;
private readonly CollectionCache collectionCache;

public PartitionKeyRangeCache(
IAuthorizationTokenProvider authorizationTokenProvider,
ICosmosAuthorizationTokenProvider authorizationTokenProvider,
IStoreModel storeModel,
CollectionCache collectionCache)
{
Expand Down Expand Up @@ -248,12 +248,13 @@ private async Task<DocumentServiceResponse> ExecutePartitionKeyRangeReadChangeFe
string authorizationToken = null;
try
{
authorizationToken = (await this.authorizationTokenProvider.GetUserAuthorizationAsync(
authorizationToken = await this.authorizationTokenProvider.GetUserAuthorizationTokenAsync(
request.ResourceAddress,
PathsHelper.GetResourcePath(request.ResourceType),
HttpConstants.HttpMethods.Get,
request.Headers,
AuthorizationTokenType.PrimaryMasterKey)).token;
AuthorizationTokenType.PrimaryMasterKey,
childTrace);
}
catch (UnauthorizedException)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ void GetAadTokenCallBack(
simpleEmulatorTokenCredential,
clientOptions))
{
Assert.AreEqual(3, getAadTokenCount);
Assert.AreEqual(2, getAadTokenCount);
j82w marked this conversation as resolved.
Show resolved Hide resolved
await Task.Delay(TimeSpan.FromSeconds(1));
ResponseMessage responseMessage = await aadClient.GetDatabase(Guid.NewGuid().ToString()).ReadStreamAsync();
Assert.IsNotNull(responseMessage);
Expand Down Expand Up @@ -197,15 +197,15 @@ void GetAadTokenCallBack(
simpleEmulatorTokenCredential,
clientOptions))
{
Assert.AreEqual(3, getAadTokenCount);
Assert.AreEqual(2, getAadTokenCount);
await Task.Delay(TimeSpan.FromSeconds(1));
try
{
ResponseMessage responseMessage =
await aadClient.GetDatabase(Guid.NewGuid().ToString()).ReadStreamAsync();
Assert.Fail("Should throw auth error.");
}
catch (CosmosException ce) when (ce.StatusCode == HttpStatusCode.Unauthorized)
catch (RequestFailedException ce) when (ce.Status == (int)HttpStatusCode.RequestTimeout)
{
Assert.IsNotNull(ce.Message);
Assert.IsTrue(ce.ToString().Contains(errorMessage));
Expand Down
Loading