Skip to content
This repository has been archived by the owner on May 22, 2024. It is now read-only.

Commit

Permalink
Enhance dependency injection scoping (#74)
Browse files Browse the repository at this point in the history
Contributor: @shlomiassaf
  • Loading branch information
shlomiassaf committed May 7, 2022
1 parent 8128ae8 commit 1258fb0
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 102 deletions.
216 changes: 142 additions & 74 deletions SpecFlow.DependencyInjection/DependencyInjectionPlugin.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
using System;
using System.Collections.Concurrent;
using BoDi;
using Microsoft.Extensions.DependencyInjection;
using TechTalk.SpecFlow;
using TechTalk.SpecFlow.Bindings;
using TechTalk.SpecFlow.Bindings.Discovery;
using TechTalk.SpecFlow.BindingSkeletons;
using TechTalk.SpecFlow.Configuration;
using TechTalk.SpecFlow.ErrorHandling;
using TechTalk.SpecFlow.Infrastructure;
using TechTalk.SpecFlow.Plugins;
using TechTalk.SpecFlow.Tracing;
using TechTalk.SpecFlow.UnitTestProvider;

[assembly: RuntimePlugin(typeof(SolidToken.SpecFlow.DependencyInjection.DependencyInjectionPlugin))]
Expand All @@ -12,113 +19,174 @@ namespace SolidToken.SpecFlow.DependencyInjection
{
public class DependencyInjectionPlugin : IRuntimePlugin
{
private readonly object registrationLock = new object();

private static readonly ConcurrentDictionary<IServiceProvider, IContextManager> BindMapping =
new ConcurrentDictionary<IServiceProvider, IContextManager>();

private static readonly ConcurrentDictionary<ISpecFlowContext, IServiceScope> ActiveServiceScopes =
new ConcurrentDictionary<ISpecFlowContext, IServiceScope>();

private readonly object _registrationLock = new object();

public void Initialize(RuntimePluginEvents runtimePluginEvents, RuntimePluginParameters runtimePluginParameters, UnitTestProviderConfiguration unitTestProviderConfiguration)
{
runtimePluginEvents.CustomizeGlobalDependencies += (sender, args) =>
runtimePluginEvents.CustomizeGlobalDependencies += CustomizeGlobalDependencies;
runtimePluginEvents.CustomizeFeatureDependencies += CustomizeFeatureDependenciesEventHandler;
runtimePluginEvents.CustomizeScenarioDependencies += CustomizeScenarioDependenciesEventHandler;
}

private void CustomizeGlobalDependencies(object sender, CustomizeGlobalDependenciesEventArgs args)
{
if (!args.ObjectContainer.IsRegistered<IServiceCollectionFinder>())
{
if (!args.ObjectContainer.IsRegistered<IServiceCollectionFinder>())
lock (_registrationLock)
{
lock (registrationLock)
if (!args.ObjectContainer.IsRegistered<IServiceCollectionFinder>())
{
if (!args.ObjectContainer.IsRegistered<IServiceCollectionFinder>())
{
args.ObjectContainer.RegisterTypeAs<DependencyInjectionTestObjectResolver, ITestObjectResolver>();
args.ObjectContainer.RegisterTypeAs<ServiceCollectionFinder, IServiceCollectionFinder>();
}
args.ObjectContainer.RegisterTypeAs<DependencyInjectionTestObjectResolver, ITestObjectResolver>();
args.ObjectContainer.RegisterTypeAs<ServiceCollectionFinder, IServiceCollectionFinder>();
}
args.ObjectContainer.Resolve<IServiceCollectionFinder>();
}
};

runtimePluginEvents.CustomizeScenarioDependencies += (sender, args) =>
{
args.ObjectContainer.RegisterFactoryAs<IServiceProvider>(() =>
{
var serviceCollectionFinder = args.ObjectContainer.Resolve<IServiceCollectionFinder>();
var createScenarioServiceCollection = serviceCollectionFinder.GetCreateScenarioServiceCollection();
var services = createScenarioServiceCollection();
// We store the service provider in the global container, we create it only once
// It must be lazy (hence factory) because at this point we still don't have the bindings mapped.
args.ObjectContainer.RegisterFactoryAs<RootServiceProviderContainer>(() =>
{
var serviceCollectionFinder = args.ObjectContainer.Resolve<IServiceCollectionFinder>();
var (services, scoping) = serviceCollectionFinder.GetServiceCollection();
RegisterObjectContainer(args.ObjectContainer, services);
RegisterScenarioSpecFlowDependencies(services);
RegisterFeatureSpecFlowDependencies(services);
RegisterTestThreadSpecFlowDependencies(services);
RegisterProxyBindings(args.ObjectContainer, services);
return new RootServiceProviderContainer(services.BuildServiceProvider(), scoping);
});

return services.BuildServiceProvider();
});
};
args.ObjectContainer.RegisterFactoryAs<IServiceProvider>(() =>
{
return args.ObjectContainer.Resolve<RootServiceProviderContainer>().ServiceProvider;
});

// Will make sure DI scope is disposed.
var lcEvents = args.ObjectContainer.Resolve<RuntimePluginTestExecutionLifecycleEvents>();
lcEvents.AfterScenario += AfterScenarioPluginLifecycleEventHandler;
lcEvents.AfterFeature += AfterFeaturePluginLifecycleEventHandler;
}
args.ObjectContainer.Resolve<IServiceCollectionFinder>();
}
}

private static void CustomizeFeatureDependenciesEventHandler(object sender, CustomizeFeatureDependenciesEventArgs args)
{
// At this point we have the bindings, we can resolve the service provider, which will build it if it's the first time.
var spContainer = args.ObjectContainer.Resolve<RootServiceProviderContainer>();

runtimePluginEvents.CustomizeFeatureDependencies += (sender, args) =>
if (spContainer.Scoping == ScopeLevelType.Feature)
{
var serviceProvider = spContainer.ServiceProvider;

// Now we can register a new scoped service provider
args.ObjectContainer.RegisterFactoryAs<IServiceProvider>(() =>
{
var serviceCollectionFinder = args.ObjectContainer.Resolve<IServiceCollectionFinder>();
var createScenarioServiceCollection = serviceCollectionFinder.GetCreateScenarioServiceCollection();
var services = createScenarioServiceCollection();
RegisterObjectContainer(args.ObjectContainer, services);
RegisterFeatureSpecFlowDependencies(services);
RegisterTestThreadSpecFlowDependencies(services);
return services.BuildServiceProvider();
var scope = serviceProvider.CreateScope();
BindMapping.TryAdd(scope.ServiceProvider, args.ObjectContainer.Resolve<IContextManager>());
ActiveServiceScopes.TryAdd(args.ObjectContainer.Resolve<FeatureContext>(), scope);
return scope.ServiceProvider;
});
};
}
}

runtimePluginEvents.CustomizeTestThreadDependencies += (sender, args) =>
private static void CustomizeScenarioDependenciesEventHandler(object sender, CustomizeScenarioDependenciesEventArgs args)
{
// At this point we have the bindings, we can resolve the service provider, which will build it if it's the first time.
var spContainer = args.ObjectContainer.Resolve<RootServiceProviderContainer>();

if (spContainer.Scoping == ScopeLevelType.Scenario)
{
var serviceProvider = spContainer.ServiceProvider;
// Now we can register a new scoped service provider
args.ObjectContainer.RegisterFactoryAs<IServiceProvider>(() =>
{
var serviceCollectionFinder = args.ObjectContainer.Resolve<IServiceCollectionFinder>();
var createScenarioServiceCollection = serviceCollectionFinder.GetCreateScenarioServiceCollection();
var services = createScenarioServiceCollection();
RegisterObjectContainer(args.ObjectContainer, services);
RegisterTestThreadSpecFlowDependencies(services);
return services.BuildServiceProvider();
var scope = serviceProvider.CreateScope();
ActiveServiceScopes.TryAdd(args.ObjectContainer.Resolve<ScenarioContext>(), scope);
return scope.ServiceProvider;
});
};
}
}

private static void RegisterObjectContainer(
IObjectContainer objectContainer,
IServiceCollection services)

private static void AfterScenarioPluginLifecycleEventHandler(object sender, RuntimePluginAfterScenarioEventArgs eventArgs)
{
services.AddTransient<IObjectContainer>(ctx => objectContainer);
if (ActiveServiceScopes.TryRemove(eventArgs.ObjectContainer.Resolve<ScenarioContext>(), out var serviceScope))
{
BindMapping.TryRemove(serviceScope.ServiceProvider, out _);
serviceScope.Dispose();
}
}

private static void RegisterScenarioSpecFlowDependencies(
IServiceCollection services)

private static void AfterFeaturePluginLifecycleEventHandler(object sender, RuntimePluginAfterFeatureEventArgs eventArgs)
{
services.AddTransient<ScenarioContext>(ctx =>
if (ActiveServiceScopes.TryRemove(eventArgs.ObjectContainer.Resolve<FeatureContext>(), out var serviceScope))
{
var specflowContainer = ctx.GetService<IObjectContainer>();
var scenarioContext = specflowContainer.Resolve<ScenarioContext>();
return scenarioContext;
});
BindMapping.TryRemove(serviceScope.ServiceProvider, out _);
serviceScope.Dispose();
}
}

private static void RegisterFeatureSpecFlowDependencies(
IServiceCollection services)
private static void RegisterProxyBindings(IObjectContainer objectContainer, IServiceCollection services)
{
services.AddTransient<FeatureContext>(ctx =>
// Required for DI of binding classes that want container injections
// While they can (and should) use the method params for injection, we can support it.
// Note that in Feature mode, one can't inject "ScenarioContext", this can only be done from method params.

// Bases on this: https://docs.specflow.org/projects/specflow/en/latest/Extend/Available-Containers-%26-Registrations.html
// Might need to add more...

services.AddSingleton<IObjectContainer>(objectContainer);
services.AddSingleton(sp => objectContainer.Resolve<IRuntimeConfigurationProvider>());
services.AddSingleton(sp => objectContainer.Resolve<ITestRunnerManager>());
services.AddSingleton(sp => objectContainer.Resolve<IStepFormatter>());
services.AddSingleton(sp => objectContainer.Resolve<ITestTracer>());
services.AddSingleton(sp => objectContainer.Resolve<ITraceListener>());
services.AddSingleton(sp => objectContainer.Resolve<ITraceListenerQueue>());
services.AddSingleton(sp => objectContainer.Resolve<IErrorProvider>());
services.AddSingleton(sp => objectContainer.Resolve<IRuntimeBindingSourceProcessor>());
services.AddSingleton(sp => objectContainer.Resolve<IBindingRegistry>());
services.AddSingleton(sp => objectContainer.Resolve<IBindingFactory>());
services.AddSingleton(sp => objectContainer.Resolve<IStepDefinitionRegexCalculator>());
services.AddSingleton(sp => objectContainer.Resolve<IBindingInvoker>());
services.AddSingleton(sp => objectContainer.Resolve<IStepDefinitionSkeletonProvider>());
services.AddSingleton(sp => objectContainer.Resolve<ISkeletonTemplateProvider>());
services.AddSingleton(sp => objectContainer.Resolve<IStepTextAnalyzer>());
services.AddSingleton(sp => objectContainer.Resolve<IRuntimePluginLoader>());
services.AddSingleton(sp => objectContainer.Resolve<IBindingAssemblyLoader>());

services.AddTransient(sp =>
{
var specflowContainer = ctx.GetService<IObjectContainer>();
var featureContext = specflowContainer.Resolve<FeatureContext>();
return featureContext;
var container = BindMapping.TryGetValue(sp, out var ctx)
? ctx.ScenarioContext?.ScenarioContainer ??
ctx.FeatureContext?.FeatureContainer ??
ctx.TestThreadContext?.TestThreadContainer ??
objectContainer
: objectContainer;
return container.Resolve<ISpecFlowOutputHelper>();
});

services.AddTransient(sp => BindMapping[sp]);
services.AddTransient(sp => BindMapping[sp].TestThreadContext);
services.AddTransient(sp => BindMapping[sp].FeatureContext);
services.AddTransient(sp => BindMapping[sp].ScenarioContext);
services.AddTransient(sp => BindMapping[sp].TestThreadContext.TestThreadContainer.Resolve<ITestRunner>());
services.AddTransient(sp => BindMapping[sp].TestThreadContext.TestThreadContainer.Resolve<ITestExecutionEngine>());
services.AddTransient(sp => BindMapping[sp].TestThreadContext.TestThreadContainer.Resolve<IStepArgumentTypeConverter>());
services.AddTransient(sp => BindMapping[sp].TestThreadContext.TestThreadContainer.Resolve<IStepDefinitionMatchService>());
}

private static void RegisterTestThreadSpecFlowDependencies(
IServiceCollection services)
private class RootServiceProviderContainer
{
services.AddTransient<TestThreadContext>(ctx =>
public IServiceProvider ServiceProvider { get; }
public ScopeLevelType Scoping { get; }

public RootServiceProviderContainer(IServiceProvider sp, ScopeLevelType scoping)
{
var specflowContainer = ctx.GetService<IObjectContainer>();
var testThreadContext = specflowContainer.Resolve<TestThreadContext>();
return testThreadContext;
});
ServiceProvider = sp;
Scoping = scoping;
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,50 @@
using System;
using System.Collections.Concurrent;
using System.Reflection;
using BoDi;
using Microsoft.Extensions.DependencyInjection;
using TechTalk.SpecFlow.Infrastructure;

namespace SolidToken.SpecFlow.DependencyInjection
{
/* TODO
If SpecFlow will add an "IObjectContainer.IsRegistered(Type type)" method next to the existing "IsRegistered<T>()"
We can remove most of the code here!
*/
public class DependencyInjectionTestObjectResolver : ITestObjectResolver
{
public object ResolveBindingInstance(Type bindingType, IObjectContainer scenarioContainer)
// Can remove if IsRegistered(Type type) exists
private static readonly ConcurrentDictionary<Type, MethodInfo> IsRegisteredMethodInfoCache =
new ConcurrentDictionary<Type, MethodInfo>();

// Can remove if IsRegistered(Type type) exists
private static readonly MethodInfo IsRegisteredMethodInfo = typeof(DependencyInjectionTestObjectResolver)
.GetMethod(nameof(IsRegistered), BindingFlags.Instance | BindingFlags.Public);

// Can remove if IsRegistered(Type type) exists
private static MethodInfo CreateGenericMethodInfo(Type t) => IsRegisteredMethodInfo.MakeGenericMethod(t);

public object ResolveBindingInstance(Type bindingType, IObjectContainer container)
{
// Can remove if IsRegistered(Type type) exists
var mi = IsRegisteredMethodInfoCache.GetOrAdd(bindingType, CreateGenericMethodInfo);
var registered = (bool) mi.Invoke(this, new object[] { container });
// var registered = container.IsRegistered(bindingType);

return registered
? container.Resolve(bindingType)
: container.Resolve<IServiceProvider>().GetRequiredService(bindingType);
}

public bool IsRegistered<T>(IObjectContainer container)
{
var provider = scenarioContainer.Resolve<IServiceProvider>();
return provider.GetService(bindingType);
if (container.IsRegistered<T>())
return true;

// IsRegistered is not recursive, it will only check the current container
if (container is ObjectContainer c && c.BaseContainer != null)
return IsRegistered<T>(c.BaseContainer);
return false;
}
}
}
2 changes: 1 addition & 1 deletion SpecFlow.DependencyInjection/IServiceCollectionFinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ namespace SolidToken.SpecFlow.DependencyInjection
{
public interface IServiceCollectionFinder
{
Func<IServiceCollection> GetCreateScenarioServiceCollection();
(IServiceCollection, ScopeLevelType) GetServiceCollection();
}
}
16 changes: 16 additions & 0 deletions SpecFlow.DependencyInjection/ScenarioDependenciesAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,28 @@

namespace SolidToken.SpecFlow.DependencyInjection
{
public enum ScopeLevelType
{
/// <summary>
/// Scoping is created for every scenario and it is destroyed once the scenario ends.
/// </summary>
Scenario,
/// <summary>
/// Scoping is created for Feature scenario and it is destroyed once the Feature ends.
/// </summary>
Feature
}

[AttributeUsage(AttributeTargets.Method)]
public class ScenarioDependenciesAttribute : Attribute
{
/// <summary>
/// Automatically register all SpecFlow bindings.
/// </summary>
public bool AutoRegisterBindings { get; set; } = true;
/// <summary>
/// Define when to create and destroy scope.
/// </summary>
public ScopeLevelType ScopeLevel { get; set; } = ScopeLevelType.Scenario;
}
}
Loading

0 comments on commit 1258fb0

Please sign in to comment.