Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Synchronize OperationInternal calls #28375

Merged
merged 18 commits into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions eng/Directory.Build.Common.targets
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@

<!-- *********** Files needed for LRO ************* -->
<ItemGroup Condition="'$(IncludeOperationsSharedSource)' == 'true'">
<Compile Include="$(AzureCoreSharedSources)AsyncLockWithValue.cs" LinkBase="Shared/Core" />
<Compile Include="$(AzureCoreSharedSources)OperationHelpers.cs" LinkBase="Shared/Core" />
<Compile Include="$(AzureCoreSharedSources)OperationInternal.cs" LinkBase="Shared/Core" />
<Compile Include="$(AzureCoreSharedSources)OperationInternalBase.cs" LinkBase="Shared/Core" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,13 @@ async ValueTask<OperationState> IOperation.UpdateStateAsync(bool async, Cancella
.ConfigureAwait(false)
: _client.GetTransactionStatus(Id, new RequestContext { CancellationToken = cancellationToken, ErrorOptions = ErrorOptions.NoThrow });

_operationInternal.RawResponse = statusResponse;

if (statusResponse.Status != (int)HttpStatusCode.OK)
{
var error = new ResponseError(null, exceptionMessage);
var ex = async
? await _client.ClientDiagnostics.CreateRequestFailedExceptionAsync(statusResponse, error).ConfigureAwait(false)
: _client.ClientDiagnostics.CreateRequestFailedException(statusResponse, error);
return OperationState.Failure(GetRawResponse(), new RequestFailedException(exceptionMessage, ex));
return OperationState.Failure(statusResponse, new RequestFailedException(exceptionMessage, ex));
}

string status = JsonDocument.Parse(statusResponse.Content)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,19 @@ internal static object InvokeWaitForCompletionResponse(Operation operation, Canc
return InjectZeroPoller().WaitForCompletionResponseAsync(operation, null, cancellationToken);
}

internal static object InvokeWaitForCompletion<T>(Operation<T> operation, CancellationToken cancellationToken)
{
return InjectZeroPoller().WaitForCompletionAsync(operation, null, cancellationToken);
}

internal static object InvokeWaitForCompletion(object target, Type targetType, CancellationToken cancellationToken)
{
// get the concrete instance of OperationPoller.ValueTask<Response<T>> WaitForCompletionAsync<T>(Operation<T>, TimeSpan?, CancellationToken)
var poller = InjectZeroPoller();
var genericMethod = poller.GetType().GetMethods().Where(m => m.Name == PollerWaitForCompletionAsyncName).FirstOrDefault(m => m.GetParameters().Length == 6);
var method = genericMethod.MakeGenericMethod(GetOperationOfT(targetType).GetGenericArguments());
var method = typeof(OperationInterceptor)
.GetMethods(BindingFlags.Static | BindingFlags.NonPublic)
.First(m => m.IsGenericMethodDefinition && m.Name == nameof(InvokeWaitForCompletion))
.MakeGenericMethod(GetOperationOfT(targetType).GetGenericArguments());

var methodParams = method.GetParameters();
var updateStatus = Delegate.CreateDelegate(methodParams[0].ParameterType, target, targetType.GetMethod("UpdateStatusAsync"));
var hasCompleted = Delegate.CreateDelegate(methodParams[1].ParameterType, target, targetType.GetMethod("get_HasCompleted"));
var value = Delegate.CreateDelegate(methodParams[2].ParameterType, target, targetType.GetMethod("get_Value"));
var getResponse = Delegate.CreateDelegate(methodParams[3].ParameterType, target, targetType.GetMethod("GetRawResponse"));
//need to call the method that takes in the delegates so we used the runtime versions which allows mocking
//public virtual async ValueTask<Response<T>> WaitForCompletionAsync<T>(UpdateStatusAsync updateStatusAsync, HasCompleted hasCompleted, Value<T> value, GetRawResponse getRawResponse, TimeSpan? suggestedInterval, CancellationToken cancellationToken)
return method.Invoke(poller, new object[] { updateStatus, hasCompleted, value, getResponse, null, cancellationToken});
return method.Invoke(null, new[] {target, cancellationToken});
}

private void CheckArguments(object[] invocationArguments)
Expand Down
4 changes: 4 additions & 0 deletions sdk/core/Azure.Core/src/Azure.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
<ItemGroup>
<Compile Remove="Shared\**\*.cs" />
<Compile Include="Shared\AppContextSwitchHelper.cs" />
<Compile Include="Shared\AsyncLockWithValue.cs" />
<Compile Include="Shared\Argument.cs" />
<Compile Include="Shared\AuthorizationChallengeParser.cs" />
<Compile Include="Shared\AzureEventSource.cs" />
Expand All @@ -59,6 +60,9 @@
<Compile Include="Shared\HttpMessageSanitizer.cs" />
<Compile Include="Shared\InitializationConstructorAttribute.cs" />
<Compile Include="Shared\NullableAttributes.cs" />
<Compile Include="Shared\OperationInternalBase.cs" />
<Compile Include="Shared\OperationInternal.cs" />
<Compile Include="Shared\OperationInternalOfT.cs" />
<Compile Include="Shared\RetryAfterDelayStrategy.cs" />
<Compile Include="Shared\SerializationConstructorAttribute.cs" />
<Compile Include="Shared\TaskExtensions.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#nullable enable

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core.Pipeline;

namespace Azure.Identity
namespace Azure.Core
{
/// <summary>
/// Primitive that combines async lock and value cache
/// </summary>
/// <typeparam name="T"></typeparam>
internal sealed class AsyncLockWithValue<T>
AlexanderSher marked this conversation as resolved.
Show resolved Hide resolved
{
private readonly object _syncObj = new object();
private Queue<TaskCompletionSource<Lock>> _waiters;
private readonly object _syncObj = new();
private Queue<TaskCompletionSource<LockOrValue>>? _waiters;
AlexanderSher marked this conversation as resolved.
Show resolved Hide resolved
private bool _isLocked;
private bool _hasValue;
private T _value;
private long _index;
private T? _value;

public bool HasValue
{
get
{
lock (_syncObj)
{
return _hasValue;
}
}
}

public AsyncLockWithValue() { }

public AsyncLockWithValue(T value)
{
_hasValue = true;
_value = value;
}

public bool TryGetValue(out T? value)
{
lock (_syncObj)
heaths marked this conversation as resolved.
Show resolved Hide resolved
{
if (_hasValue)
AlexanderSher marked this conversation as resolved.
Show resolved Hide resolved
{
value = _value;
return true;
}
}

value = default;
return false;
}

/// <summary>
/// Method that either returns cached value or acquire a lock.
Expand All @@ -30,31 +68,32 @@ internal sealed class AsyncLockWithValue<T>
/// <param name="async"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public async ValueTask<Lock> GetLockOrValueAsync(bool async, CancellationToken cancellationToken = default)
public async ValueTask<LockOrValue> GetLockOrValueAsync(bool async, CancellationToken cancellationToken = default)
{
TaskCompletionSource<Lock> valueTcs;
TaskCompletionSource<LockOrValue> valueTcs;
lock (_syncObj)
{
// If there is a value, just return it
if (_hasValue)
{
return new Lock(_value);
return new LockOrValue(_value!);
}

// If lock isn't acquire yet, acquire it and return to the caller
if (!_isLocked)
{
_isLocked = true;
return new Lock(this);
_index = unchecked(_index + 1);
return new LockOrValue(this, _index);
}

// Check cancellationToken before instantiating waiter
cancellationToken.ThrowIfCancellationRequested();

// If lock is already taken, create a waiter and wait either until value is set or lock can be acquired by this waiter
_waiters ??= new Queue<TaskCompletionSource<Lock>>();
_waiters ??= new Queue<TaskCompletionSource<LockOrValue>>();
// if async == false, valueTcs will be waited only in this thread and only synchronously, so RunContinuationsAsynchronously isn't needed.
valueTcs = new TaskCompletionSource<Lock>(async ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None);
valueTcs = new TaskCompletionSource<LockOrValue>(async ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None);
_waiters.Enqueue(valueTcs);
}

Expand Down Expand Up @@ -89,13 +128,20 @@ public async ValueTask<Lock> GetLockOrValueAsync(bool async, CancellationToken c
/// Set value to the cache and to all the waiters
/// </summary>
/// <param name="value"></param>
private void SetValue(T value)
/// <param name="lockIndex"></param>
private void SetValue(T value, in long lockIndex)
{
Queue<TaskCompletionSource<Lock>> waiters;
Queue<TaskCompletionSource<LockOrValue>> waiters;
lock (_syncObj)
{
if (lockIndex != _index)
{
throw new InvalidOperationException($"Disposed {nameof(LockOrValue)} tries to set value. Current index: {_index}, {nameof(LockOrValue)} index: {lockIndex}");
}

_value = value;
_hasValue = true;
_index = 0;
_isLocked = false;
if (_waiters == default)
{
Expand All @@ -108,73 +154,107 @@ private void SetValue(T value)

while (waiters.Count > 0)
{
waiters.Dequeue().TrySetResult(new Lock(value));
waiters.Dequeue().TrySetResult(new LockOrValue(value));
}
}

/// <summary>
/// Release the lock and allow next waiter acquire it
/// </summary>
private void Reset()
private void Reset(in long lockIndex)
{
TaskCompletionSource<Lock> nextWaiter = UnlockOrGetNextWaiter();
while (nextWaiter != default && !nextWaiter.TrySetResult(new Lock(this)))
UnlockOrGetNextWaiter(lockIndex, out var nextWaiter);
while (nextWaiter != default && !nextWaiter.TrySetResult(new LockOrValue(this, unchecked(lockIndex + 1))))
{
nextWaiter = UnlockOrGetNextWaiter();
UnlockOrGetNextWaiter(lockIndex, out nextWaiter);
}
}

private TaskCompletionSource<Lock> UnlockOrGetNextWaiter()
private void UnlockOrGetNextWaiter(in long lockIndex, out TaskCompletionSource<LockOrValue>? nextWaiter)
{
lock (_syncObj)
{
if (!_isLocked)
nextWaiter = default;
// If lock isn't acquired, just return
if (!_isLocked || lockIndex != _index)
{
return default;
return;
}

_index = unchecked(lockIndex + 1);

// If lock was acquired, but there are no waiters, unlock and return
if (_waiters == default)
{
_isLocked = false;
return default;
return;
}

// Find the next waiter
while (_waiters.Count > 0)
{
var nextWaiter = _waiters.Dequeue();
nextWaiter = _waiters.Dequeue();
if (!nextWaiter.Task.IsCompleted)
{
// Return the waiter only if it wasn't canceled already
return nextWaiter;
return;
}
}

// If no next waiter has been found, unlock and return
_isLocked = false;
return default;
}
}

public readonly struct Lock : IDisposable
public readonly struct LockOrValue : IDisposable
AlexanderSher marked this conversation as resolved.
Show resolved Hide resolved
{
private readonly AsyncLockWithValue<T> _owner;
private readonly AsyncLockWithValue<T>? _owner;
private readonly T? _value;
private readonly long _index;

/// <summary>
/// Returns true if lock contains the cached value. Otherwise false.
/// </summary>
public bool HasValue => _owner == default;
public T Value { get; }

public Lock(T value)
/// <summary>
/// Returns cached value if it was set when lock has been created. Throws exception otherwise.
/// </summary>
/// <exception cref="InvalidOperationException">Value isn't set.</exception>
public T Value => HasValue ? _value! : throw new InvalidOperationException("Value isn't set");

public LockOrValue(T value)
{
_owner = default;
Value = value;
_value = value;
_index = 0;
}

public Lock(AsyncLockWithValue<T> owner)
public LockOrValue(AsyncLockWithValue<T> owner, long index)
{
_owner = owner;
Value = default;
_index = index;
_value = default;
}

public void SetValue(T value) => _owner.SetValue(value);
/// <summary>
/// Set value to the cache and to all the waiters.
/// </summary>
/// <param name="value"></param>
/// <exception cref="InvalidOperationException">Value is set already.</exception>
public void SetValue(T value)
{
if (_owner != null)
{
_owner.SetValue(value, _index);
}
else
{
throw new InvalidOperationException("Value for the lock is set already");
}
}

public void Dispose() => _owner?.Reset();
public void Dispose() => _owner?.Reset(_index);
}
}
}
Loading