Skip to content

Commit

Permalink
Synchronize OperationInternal calls (#28375)
Browse files Browse the repository at this point in the history
* Initial commit

* Minimize code duplication

* Fix tests

* Fix OperationInternal

* Use AsyncLockWithValue from Core

* Fix operation naming

* Address PR comments

* Address PR comment

* More tests

* Fix one test by hand

* More fixed test records

* More fixed test records

* More fixed test records

* More fixed test records

* Fix test records

* Add tests

* Added comments

* Address CR comments
  • Loading branch information
AlexanderSher authored and paterasMSFT committed Jun 14, 2022
1 parent 4a0dcef commit 9153155
Show file tree
Hide file tree
Showing 69 changed files with 4,394 additions and 4,772 deletions.
1 change: 1 addition & 0 deletions eng/Directory.Build.Common.targets
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@

<!-- *********** Files needed for LRO ************* -->
<ItemGroup Condition="'$(IncludeOperationsSharedSource)' == 'true'">
<Compile Include="$(AzureCoreSharedSources)AsyncLockWithValue.cs" LinkBase="Shared/Core" />
<Compile Include="$(AzureCoreSharedSources)OperationHelpers.cs" LinkBase="Shared/Core" />
<Compile Include="$(AzureCoreSharedSources)OperationInternal.cs" LinkBase="Shared/Core" />
<Compile Include="$(AzureCoreSharedSources)OperationInternalBase.cs" LinkBase="Shared/Core" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,13 @@ async ValueTask<OperationState> IOperation.UpdateStateAsync(bool async, Cancella
.ConfigureAwait(false)
: _client.GetTransactionStatus(Id, new RequestContext { CancellationToken = cancellationToken, ErrorOptions = ErrorOptions.NoThrow });

_operationInternal.RawResponse = statusResponse;

if (statusResponse.Status != (int)HttpStatusCode.OK)
{
var error = new ResponseError(null, exceptionMessage);
var ex = async
? await _client.ClientDiagnostics.CreateRequestFailedExceptionAsync(statusResponse, error).ConfigureAwait(false)
: _client.ClientDiagnostics.CreateRequestFailedException(statusResponse, error);
return OperationState.Failure(GetRawResponse(), new RequestFailedException(exceptionMessage, ex));
return OperationState.Failure(statusResponse, new RequestFailedException(exceptionMessage, ex));
}
string status = JsonDocument.Parse(statusResponse.Content)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,19 @@ internal static object InvokeWaitForCompletionResponse(Operation operation, Canc
return InjectZeroPoller().WaitForCompletionResponseAsync(operation, null, cancellationToken);
}

internal static object InvokeWaitForCompletion<T>(Operation<T> operation, CancellationToken cancellationToken)
{
return InjectZeroPoller().WaitForCompletionAsync(operation, null, cancellationToken);
}

internal static object InvokeWaitForCompletion(object target, Type targetType, CancellationToken cancellationToken)
{
// get the concrete instance of OperationPoller.ValueTask<Response<T>> WaitForCompletionAsync<T>(Operation<T>, TimeSpan?, CancellationToken)
var poller = InjectZeroPoller();
var genericMethod = poller.GetType().GetMethods().Where(m => m.Name == PollerWaitForCompletionAsyncName).FirstOrDefault(m => m.GetParameters().Length == 6);
var method = genericMethod.MakeGenericMethod(GetOperationOfT(targetType).GetGenericArguments());
var method = typeof(OperationInterceptor)
.GetMethods(BindingFlags.Static | BindingFlags.NonPublic)
.First(m => m.IsGenericMethodDefinition && m.Name == nameof(InvokeWaitForCompletion))
.MakeGenericMethod(GetOperationOfT(targetType).GetGenericArguments());

var methodParams = method.GetParameters();
var updateStatus = Delegate.CreateDelegate(methodParams[0].ParameterType, target, targetType.GetMethod("UpdateStatusAsync"));
var hasCompleted = Delegate.CreateDelegate(methodParams[1].ParameterType, target, targetType.GetMethod("get_HasCompleted"));
var value = Delegate.CreateDelegate(methodParams[2].ParameterType, target, targetType.GetMethod("get_Value"));
var getResponse = Delegate.CreateDelegate(methodParams[3].ParameterType, target, targetType.GetMethod("GetRawResponse"));
//need to call the method that takes in the delegates so we used the runtime versions which allows mocking
//public virtual async ValueTask<Response<T>> WaitForCompletionAsync<T>(UpdateStatusAsync updateStatusAsync, HasCompleted hasCompleted, Value<T> value, GetRawResponse getRawResponse, TimeSpan? suggestedInterval, CancellationToken cancellationToken)
return method.Invoke(poller, new object[] { updateStatus, hasCompleted, value, getResponse, null, cancellationToken});
return method.Invoke(null, new[] {target, cancellationToken});
}

private void CheckArguments(object[] invocationArguments)
Expand Down
4 changes: 4 additions & 0 deletions sdk/core/Azure.Core/src/Azure.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
<ItemGroup>
<Compile Remove="Shared\**\*.cs" />
<Compile Include="Shared\AppContextSwitchHelper.cs" />
<Compile Include="Shared\AsyncLockWithValue.cs" />
<Compile Include="Shared\Argument.cs" />
<Compile Include="Shared\AuthorizationChallengeParser.cs" />
<Compile Include="Shared\AzureEventSource.cs" />
Expand All @@ -59,6 +60,9 @@
<Compile Include="Shared\HttpMessageSanitizer.cs" />
<Compile Include="Shared\InitializationConstructorAttribute.cs" />
<Compile Include="Shared\NullableAttributes.cs" />
<Compile Include="Shared\OperationInternalBase.cs" />
<Compile Include="Shared\OperationInternal.cs" />
<Compile Include="Shared\OperationInternalOfT.cs" />
<Compile Include="Shared\RetryAfterDelayStrategy.cs" />
<Compile Include="Shared\SerializationConstructorAttribute.cs" />
<Compile Include="Shared\TaskExtensions.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#nullable enable

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core.Pipeline;

namespace Azure.Identity
namespace Azure.Core
{
/// <summary>
/// Primitive that combines async lock and value cache
/// </summary>
/// <typeparam name="T"></typeparam>
internal sealed class AsyncLockWithValue<T>
{
private readonly object _syncObj = new object();
private Queue<TaskCompletionSource<Lock>> _waiters;
private readonly object _syncObj = new();
private Queue<TaskCompletionSource<LockOrValue>>? _waiters;
private bool _isLocked;
private bool _hasValue;
private T _value;
private long _index;
private T? _value;

public bool HasValue
{
get
{
lock (_syncObj)
{
return _hasValue;
}
}
}

public AsyncLockWithValue() { }

public AsyncLockWithValue(T value)
{
_hasValue = true;
_value = value;
}

public bool TryGetValue(out T? value)
{
lock (_syncObj)
{
if (_hasValue)
{
value = _value;
return true;
}
}

value = default;
return false;
}

/// <summary>
/// Method that either returns cached value or acquire a lock.
Expand All @@ -30,31 +68,32 @@ internal sealed class AsyncLockWithValue<T>
/// <param name="async"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public async ValueTask<Lock> GetLockOrValueAsync(bool async, CancellationToken cancellationToken = default)
public async ValueTask<LockOrValue> GetLockOrValueAsync(bool async, CancellationToken cancellationToken = default)
{
TaskCompletionSource<Lock> valueTcs;
TaskCompletionSource<LockOrValue> valueTcs;
lock (_syncObj)
{
// If there is a value, just return it
if (_hasValue)
{
return new Lock(_value);
return new LockOrValue(_value!);
}

// If lock isn't acquire yet, acquire it and return to the caller
if (!_isLocked)
{
_isLocked = true;
return new Lock(this);
_index = unchecked(_index + 1);
return new LockOrValue(this, _index);
}

// Check cancellationToken before instantiating waiter
cancellationToken.ThrowIfCancellationRequested();

// If lock is already taken, create a waiter and wait either until value is set or lock can be acquired by this waiter
_waiters ??= new Queue<TaskCompletionSource<Lock>>();
_waiters ??= new Queue<TaskCompletionSource<LockOrValue>>();
// if async == false, valueTcs will be waited only in this thread and only synchronously, so RunContinuationsAsynchronously isn't needed.
valueTcs = new TaskCompletionSource<Lock>(async ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None);
valueTcs = new TaskCompletionSource<LockOrValue>(async ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None);
_waiters.Enqueue(valueTcs);
}

Expand Down Expand Up @@ -89,13 +128,20 @@ public async ValueTask<Lock> GetLockOrValueAsync(bool async, CancellationToken c
/// Set value to the cache and to all the waiters
/// </summary>
/// <param name="value"></param>
private void SetValue(T value)
/// <param name="lockIndex"></param>
private void SetValue(T value, in long lockIndex)
{
Queue<TaskCompletionSource<Lock>> waiters;
Queue<TaskCompletionSource<LockOrValue>> waiters;
lock (_syncObj)
{
if (lockIndex != _index)
{
throw new InvalidOperationException($"Disposed {nameof(LockOrValue)} tries to set value. Current index: {_index}, {nameof(LockOrValue)} index: {lockIndex}");
}

_value = value;
_hasValue = true;
_index = 0;
_isLocked = false;
if (_waiters == default)
{
Expand All @@ -108,73 +154,107 @@ private void SetValue(T value)

while (waiters.Count > 0)
{
waiters.Dequeue().TrySetResult(new Lock(value));
waiters.Dequeue().TrySetResult(new LockOrValue(value));
}
}

/// <summary>
/// Release the lock and allow next waiter acquire it
/// </summary>
private void Reset()
private void Reset(in long lockIndex)
{
TaskCompletionSource<Lock> nextWaiter = UnlockOrGetNextWaiter();
while (nextWaiter != default && !nextWaiter.TrySetResult(new Lock(this)))
UnlockOrGetNextWaiter(lockIndex, out var nextWaiter);
while (nextWaiter != default && !nextWaiter.TrySetResult(new LockOrValue(this, unchecked(lockIndex + 1))))
{
nextWaiter = UnlockOrGetNextWaiter();
UnlockOrGetNextWaiter(lockIndex, out nextWaiter);
}
}

private TaskCompletionSource<Lock> UnlockOrGetNextWaiter()
private void UnlockOrGetNextWaiter(in long lockIndex, out TaskCompletionSource<LockOrValue>? nextWaiter)
{
lock (_syncObj)
{
if (!_isLocked)
nextWaiter = default;
// If lock isn't acquired, just return
if (!_isLocked || lockIndex != _index)
{
return default;
return;
}

_index = unchecked(lockIndex + 1);

// If lock was acquired, but there are no waiters, unlock and return
if (_waiters == default)
{
_isLocked = false;
return default;
return;
}

// Find the next waiter
while (_waiters.Count > 0)
{
var nextWaiter = _waiters.Dequeue();
nextWaiter = _waiters.Dequeue();
if (!nextWaiter.Task.IsCompleted)
{
// Return the waiter only if it wasn't canceled already
return nextWaiter;
return;
}
}

// If no next waiter has been found, unlock and return
_isLocked = false;
return default;
}
}

public readonly struct Lock : IDisposable
public readonly struct LockOrValue : IDisposable
{
private readonly AsyncLockWithValue<T> _owner;
private readonly AsyncLockWithValue<T>? _owner;
private readonly T? _value;
private readonly long _index;

/// <summary>
/// Returns true if lock contains the cached value. Otherwise false.
/// </summary>
public bool HasValue => _owner == default;
public T Value { get; }

public Lock(T value)
/// <summary>
/// Returns cached value if it was set when lock has been created. Throws exception otherwise.
/// </summary>
/// <exception cref="InvalidOperationException">Value isn't set.</exception>
public T Value => HasValue ? _value! : throw new InvalidOperationException("Value isn't set");

public LockOrValue(T value)
{
_owner = default;
Value = value;
_value = value;
_index = 0;
}

public Lock(AsyncLockWithValue<T> owner)
public LockOrValue(AsyncLockWithValue<T> owner, long index)
{
_owner = owner;
Value = default;
_index = index;
_value = default;
}

public void SetValue(T value) => _owner.SetValue(value);
/// <summary>
/// Set value to the cache and to all the waiters.
/// </summary>
/// <param name="value"></param>
/// <exception cref="InvalidOperationException">Value is set already.</exception>
public void SetValue(T value)
{
if (_owner != null)
{
_owner.SetValue(value, _index);
}
else
{
throw new InvalidOperationException("Value for the lock is set already");
}
}

public void Dispose() => _owner?.Reset();
public void Dispose() => _owner?.Reset(_index);
}
}
}
Loading

0 comments on commit 9153155

Please sign in to comment.