Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes race condition on captive scoped services #53325

Merged
merged 8 commits into from
Jun 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,12 @@ public override Func<ServiceProviderEngineScope, object> RealizeService(ServiceC
int callCount = 0;
return scope =>
{
// We want to directly use the callsite value if it's set and the scope is the root scope.
// We've already called into the RuntimeResolver and pre-computed any singletons or root scope
// Avoid the compilation for singletons (or promoted singletons)
if (scope.IsRootScope && callSite.Value != null)
{
return callSite.Value;
}

// Resolve the result before we increment the call count, this ensures that singletons
// won't cause any side effects during the compilation of the resolve function.
var result = CallSiteRuntimeResolver.Instance.Resolve(callSite, scope);

if (Interlocked.Increment(ref callCount) == 2)
{
// This second check is to avoid the race where we end up kicking off a background thread
// if multiple calls to GetService race and resolve the values for singletons before the initial check above.
if (scope.IsRootScope && callSite.Value != null)
{
return callSite.Value;
}

// Don't capture the ExecutionContext when forking to build the compiled version of the
// resolve function
_ = ThreadPool.UnsafeQueueUserWorkItem(_ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ internal sealed class ExpressionResolverBuilder : CallSiteVisitor<object, Expres
internal static readonly MethodInfo InvokeFactoryMethodInfo = GetMethodInfo<Action<Func<IServiceProvider, object>, IServiceProvider>>((a, b) => a.Invoke(b));
internal static readonly MethodInfo CaptureDisposableMethodInfo = GetMethodInfo<Func<ServiceProviderEngineScope, object, object>>((a, b) => a.CaptureDisposable(b));
internal static readonly MethodInfo TryGetValueMethodInfo = GetMethodInfo<Func<IDictionary<ServiceCacheKey, object>, ServiceCacheKey, object, bool>>((a, b, c) => a.TryGetValue(b, out c));
internal static readonly MethodInfo ResolveCallSiteAndScopeMethodInfo = GetMethodInfo<Func<CallSiteRuntimeResolver, ServiceCallSite, ServiceProviderEngineScope, object>>((a, b, c) => a.Resolve(b, c));
internal static readonly MethodInfo AddMethodInfo = GetMethodInfo<Action<IDictionary<ServiceCacheKey, object>, ServiceCacheKey, object>>((a, b, c) => a.Add(b, c));
internal static readonly MethodInfo MonitorEnterMethodInfo = GetMethodInfo<Action<object, bool>>((lockObj, lockTaken) => Monitor.Enter(lockObj, ref lockTaken));
internal static readonly MethodInfo MonitorExitMethodInfo = GetMethodInfo<Action<object>>(lockObj => Monitor.Exit(lockObj));
Expand Down Expand Up @@ -44,6 +45,10 @@ internal sealed class ExpressionResolverBuilder : CallSiteVisitor<object, Expres
Expression.Call(ScopeParameter, CaptureDisposableMethodInfo, CaptureDisposableParameter),
CaptureDisposableParameter);

private static readonly ConstantExpression CallSiteRuntimeResolverInstanceExpression = Expression.Constant(
CallSiteRuntimeResolver.Instance,
typeof(CallSiteRuntimeResolver));

private readonly ServiceProviderEngineScope _rootScope;

private readonly ConcurrentDictionary<ServiceCacheKey, Func<ServiceProviderEngineScope, object>> _scopeResolverCache;
Expand Down Expand Up @@ -203,6 +208,19 @@ protected override Expression VisitScopeCache(ServiceCallSite callSite, object c
// Move off the main stack
private Expression BuildScopedExpression(ServiceCallSite callSite)
{
ConstantExpression callSiteExpression = Expression.Constant(
callSite,
typeof(ServiceCallSite));

// We want to directly use the callsite value if it's set and the scope is the root scope.
// We've already called into the RuntimeResolver and pre-computed any singletons or root scope
// Avoid the compilation for singletons (or promoted singletons)
MethodCallExpression resolveRootScopeExpression = Expression.Call(
CallSiteRuntimeResolverInstanceExpression,
ResolveCallSiteAndScopeMethodInfo,
callSiteExpression,
ScopeParameter);

ConstantExpression keyExpression = Expression.Constant(
callSite.Cache.Key,
typeof(ServiceCacheKey));
Expand Down Expand Up @@ -254,10 +272,17 @@ private Expression BuildScopedExpression(ServiceCallSite callSite)
BlockExpression tryBody = Expression.Block(monitorEnter, blockExpression);
ConditionalExpression finallyBody = Expression.IfThen(lockWasTaken, monitorExit);

return Expression.Block(
typeof(object),
new[] { lockWasTaken },
Expression.TryFinally(tryBody, finallyBody));
return Expression.Condition(
Expression.Property(
ScopeParameter,
typeof(ServiceProviderEngineScope)
.GetProperty(nameof(ServiceProviderEngineScope.IsRootScope), BindingFlags.Instance | BindingFlags.Public)),
resolveRootScopeExpression,
Expression.Block(
typeof(object),
new[] { lockWasTaken },
Expression.TryFinally(tryBody, finallyBody))
);
}

private static MethodInfo GetMethodInfo<T>(Expression<T> expr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Reflection;
using System.Reflection.Emit;
using Microsoft.Extensions.DependencyInjection.ServiceLookup;

namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
Expand All @@ -17,6 +18,15 @@ internal sealed class ILEmitResolverBuilder : CallSiteVisitor<ILEmitResolverBuil
private static readonly MethodInfo ScopeLockGetter = typeof(ServiceProviderEngineScope).GetProperty(
nameof(ServiceProviderEngineScope.Sync), BindingFlags.Instance | BindingFlags.NonPublic).GetMethod;

private static readonly MethodInfo ScopeIsRootScope = typeof(ServiceProviderEngineScope).GetProperty(
nameof(ServiceProviderEngineScope.IsRootScope), BindingFlags.Instance | BindingFlags.Public).GetMethod;

private static readonly MethodInfo CallSiteRuntimeResolverResolveMethod = typeof(CallSiteRuntimeResolver).GetMethod(
nameof(CallSiteRuntimeResolver.Resolve), BindingFlags.Public | BindingFlags.Instance);

private static readonly MethodInfo CallSiteRuntimeResolverInstanceField = typeof(CallSiteRuntimeResolver).GetProperty(
nameof(CallSiteRuntimeResolver.Instance), BindingFlags.Static | BindingFlags.Public | BindingFlags.Instance).GetMethod;

private static readonly FieldInfo FactoriesField = typeof(ILEmitResolverBuilderRuntimeContext).GetField(nameof(ILEmitResolverBuilderRuntimeContext.Factories));
private static readonly FieldInfo ConstantsField = typeof(ILEmitResolverBuilderRuntimeContext).GetField(nameof(ILEmitResolverBuilderRuntimeContext.Constants));
private static readonly MethodInfo GetTypeFromHandleMethod = typeof(Type).GetMethod(nameof(Type.GetTypeFromHandle));
Expand Down Expand Up @@ -99,7 +109,7 @@ private GeneratedMethod BuildTypeNoCache(ServiceCallSite callSite)
"ResolveService", MethodAttributes.Public | MethodAttributes.Static, CallingConventions.Standard, typeof(object),
new[] { typeof(ILEmitResolverBuilderRuntimeContext), typeof(ServiceProviderEngineScope) });

GenerateMethodBody(callSite, method.GetILGenerator(), info);
GenerateMethodBody(callSite, method.GetILGenerator());
type.CreateTypeInfo();
assembly.Save(assemblyName + ".dll");
#endif
Expand Down Expand Up @@ -281,6 +291,10 @@ private ILEmitResolverBuilderRuntimeContext GenerateMethodBody(ServiceCallSite c
Factories = null
};

// if (scope.IsRootScope)
// {
// return CallSiteRuntimeResolver.Instance.Resolve(callSite, scope);
// }
// var cacheKey = scopedCallSite.CacheKey;
// try
// {
Expand Down Expand Up @@ -309,8 +323,21 @@ private ILEmitResolverBuilderRuntimeContext GenerateMethodBody(ServiceCallSite c

Label skipCreationLabel = context.Generator.DefineLabel();
Label returnLabel = context.Generator.DefineLabel();
Label defaultLabel = context.Generator.DefineLabel();

// Check if scope IsRootScope
context.Generator.Emit(OpCodes.Ldarg_1);
maryamariyan marked this conversation as resolved.
Show resolved Hide resolved
context.Generator.Emit(OpCodes.Callvirt, ScopeIsRootScope);
context.Generator.Emit(OpCodes.Brfalse_S, defaultLabel);

context.Generator.Emit(OpCodes.Call, CallSiteRuntimeResolverInstanceField);
AddConstant(context, callSite);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a nice change. @pakrym but we're paying for bounds checks!! 😄

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can change AddConstant to produce unsafe code with fixed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think it matters though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is locking and has a try finally so I'd be surprised if it mattered as well.

We can change AddConstant to produce unsafe code with fixed.

Unsafe.Add

context.Generator.Emit(OpCodes.Ldarg_1);
context.Generator.Emit(OpCodes.Callvirt, CallSiteRuntimeResolverResolveMethod);
context.Generator.Emit(OpCodes.Ret);

// Generate cache key
context.Generator.MarkLabel(defaultLabel);
AddCacheKey(context, callSite.Cache.Key);
// and store to local
Stloc(context.Generator, cacheKeyLocal.LocalIndex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
internal sealed class ServiceProviderEngineScope : IServiceScope, IServiceProvider, IAsyncDisposable, IServiceScopeFactory
{
// For testing only
internal Action<object> _captureDisposableCallback;
internal IList<object> Disposables => _disposables ?? (IList<object>)Array.Empty<object>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removes a pointer 😄. Saving the planet.


private bool _disposed;
private List<object> _disposables;
Expand Down Expand Up @@ -49,8 +49,6 @@ public object GetService(Type serviceType)

internal object CaptureDisposable(object service)
{
_captureDisposableCallback?.Invoke(service);

if (ReferenceEquals(this, service) || !(service is IDisposable || service is IAsyncDisposable))
{
return service;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,13 @@ public void BuildExpressionAddsDisposableCaptureForDisposableServices(ServiceLif

var disposables = new List<object>();
var provider = new ServiceProvider(descriptors, ServiceProviderOptions.Default);
provider.Root._captureDisposableCallback = obj =>
{
disposables.Add(obj);
};

var callSite = provider.CallSiteFactory.GetCallSite(typeof(ServiceC), new CallSiteChain());
var compiledCallSite = CompileCallSite(callSite, provider);

var serviceC = (DisposableServiceC)compiledCallSite(provider.Root);

Assert.Equal(3, disposables.Count);
Assert.Equal(3, provider.Root.Disposables.Count);
}

[Theory]
Expand All @@ -161,16 +158,13 @@ public void BuildExpressionAddsDisposableCaptureForDisposableFactoryServices(Ser

var disposables = new List<object>();
var provider = new ServiceProvider(descriptors, ServiceProviderOptions.Default);
provider.Root._captureDisposableCallback = obj =>
{
disposables.Add(obj);
};

var callSite = provider.CallSiteFactory.GetCallSite(typeof(ServiceC), new CallSiteChain());
var compiledCallSite = CompileCallSite(callSite, provider);

var serviceC = (DisposableServiceC)compiledCallSite(provider.Root);

Assert.Equal(3, disposables.Count);
Assert.Equal(3, provider.Root.Disposables.Count);
}

[Theory]
Expand All @@ -190,16 +184,13 @@ public void BuildExpressionElidesDisposableCaptureForNonDisposableServices(Servi

var disposables = new List<object>();
var provider = new ServiceProvider(descriptors, ServiceProviderOptions.Default);
provider.Root._captureDisposableCallback = obj =>
{
disposables.Add(obj);
};

var callSite = provider.CallSiteFactory.GetCallSite(typeof(ServiceC), new CallSiteChain());
var compiledCallSite = CompileCallSite(callSite, provider);

var serviceC = (ServiceC)compiledCallSite(provider.Root);

Assert.Empty(disposables);
Assert.Empty(provider.Root.Disposables);
}

[Theory]
Expand All @@ -215,16 +206,13 @@ public void BuildExpressionElidesDisposableCaptureForEnumerableServices(ServiceL

var disposables = new List<object>();
var provider = new ServiceProvider(descriptors, ServiceProviderOptions.Default);
provider.Root._captureDisposableCallback = obj =>
{
disposables.Add(obj);
};

var callSite = provider.CallSiteFactory.GetCallSite(typeof(ServiceD), new CallSiteChain());
var compiledCallSite = CompileCallSite(callSite, provider);

var serviceD = (ServiceD)compiledCallSite(provider.Root);

Assert.Empty(disposables);
Assert.Empty(provider.Root.Disposables);
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,54 @@ public void ScopedServiceResolvedFromSingletonAfterCompilation()
}
}

[Theory]
[InlineData(ServiceProviderMode.Default)]
[InlineData(ServiceProviderMode.Dynamic)]
[InlineData(ServiceProviderMode.Runtime)]
[InlineData(ServiceProviderMode.Expressions)]
[InlineData(ServiceProviderMode.ILEmit)]
private void ScopedServiceResolvedFromSingletonAfterCompilation2(ServiceProviderMode mode)
{
ServiceProvider sp = new ServiceCollection()
.AddScoped<A>()
.AddSingleton<IFakeOpenGenericService<A>, FakeOpenGenericService<A>>()
.BuildServiceProvider(mode);

var scope = sp.CreateScope();
maryamariyan marked this conversation as resolved.
Show resolved Hide resolved
for (int i = 0; i < 50; i++)
{
scope.ServiceProvider.GetRequiredService<A>();
maryamariyan marked this conversation as resolved.
Show resolved Hide resolved
Thread.Sleep(10); // Give the background thread time to compile
}

Assert.Same(sp.GetRequiredService<IFakeOpenGenericService<A>>().Value, sp.GetRequiredService<A>());
}

[Theory]
[InlineData(ServiceProviderMode.Default)]
[InlineData(ServiceProviderMode.Dynamic)]
[InlineData(ServiceProviderMode.Runtime)]
[InlineData(ServiceProviderMode.Expressions)]
[InlineData(ServiceProviderMode.ILEmit)]
private void ScopedServiceResolvedFromSingletonAfterCompilation3(ServiceProviderMode mode)
{
// Singleton IFakeX<A> -> Scoped A -> Scoped Aa
ServiceProvider sp = new ServiceCollection()
.AddScoped<Aa>()
.AddScoped<A>()
.AddSingleton<IFakeOpenGenericService<Aa>, FakeOpenGenericService<Aa>>()
.BuildServiceProvider(mode);

var scope = sp.CreateScope();
for (int i = 0; i < 50; i++)
{
scope.ServiceProvider.GetRequiredService<A>();
Thread.Sleep(10); // Give the background thread time to compile
}

Assert.Same(sp.GetRequiredService<IFakeOpenGenericService<Aa>>().Value.PropertyA, sp.GetRequiredService<A>());
maryamariyan marked this conversation as resolved.
Show resolved Hide resolved
}

private async Task<bool> ResolveUniqueServicesConcurrently()
{
var types = new Type[]
Expand Down Expand Up @@ -1150,5 +1198,13 @@ private class G { }
private class H { }
private class I { }
private class J { }
private class Aa
{
public Aa(A a)
{
PropertyA = a;
}
public A PropertyA { get; }
}
}
}