diff --git a/SpecFlow.DependencyInjection/DependencyInjectionPlugin.cs b/SpecFlow.DependencyInjection/DependencyInjectionPlugin.cs index 51314de..3b8fb9d 100644 --- a/SpecFlow.DependencyInjection/DependencyInjectionPlugin.cs +++ b/SpecFlow.DependencyInjection/DependencyInjectionPlugin.cs @@ -1,9 +1,16 @@ using System; +using System.Collections.Concurrent; using BoDi; using Microsoft.Extensions.DependencyInjection; using TechTalk.SpecFlow; +using TechTalk.SpecFlow.Bindings; +using TechTalk.SpecFlow.Bindings.Discovery; +using TechTalk.SpecFlow.BindingSkeletons; +using TechTalk.SpecFlow.Configuration; +using TechTalk.SpecFlow.ErrorHandling; using TechTalk.SpecFlow.Infrastructure; using TechTalk.SpecFlow.Plugins; +using TechTalk.SpecFlow.Tracing; using TechTalk.SpecFlow.UnitTestProvider; [assembly: RuntimePlugin(typeof(SolidToken.SpecFlow.DependencyInjection.DependencyInjectionPlugin))] @@ -12,113 +19,174 @@ namespace SolidToken.SpecFlow.DependencyInjection { public class DependencyInjectionPlugin : IRuntimePlugin { - private readonly object registrationLock = new object(); - + private static readonly ConcurrentDictionary BindMapping = + new ConcurrentDictionary(); + + private static readonly ConcurrentDictionary ActiveServiceScopes = + new ConcurrentDictionary(); + + private readonly object _registrationLock = new object(); + public void Initialize(RuntimePluginEvents runtimePluginEvents, RuntimePluginParameters runtimePluginParameters, UnitTestProviderConfiguration unitTestProviderConfiguration) { - runtimePluginEvents.CustomizeGlobalDependencies += (sender, args) => + runtimePluginEvents.CustomizeGlobalDependencies += CustomizeGlobalDependencies; + runtimePluginEvents.CustomizeFeatureDependencies += CustomizeFeatureDependenciesEventHandler; + runtimePluginEvents.CustomizeScenarioDependencies += CustomizeScenarioDependenciesEventHandler; + } + + private void CustomizeGlobalDependencies(object sender, CustomizeGlobalDependenciesEventArgs args) + { + if (!args.ObjectContainer.IsRegistered()) { - if (!args.ObjectContainer.IsRegistered()) + lock (_registrationLock) { - lock (registrationLock) + if (!args.ObjectContainer.IsRegistered()) { - if (!args.ObjectContainer.IsRegistered()) - { - args.ObjectContainer.RegisterTypeAs(); - args.ObjectContainer.RegisterTypeAs(); - } + args.ObjectContainer.RegisterTypeAs(); + args.ObjectContainer.RegisterTypeAs(); } - args.ObjectContainer.Resolve(); - } - }; - runtimePluginEvents.CustomizeScenarioDependencies += (sender, args) => - { - args.ObjectContainer.RegisterFactoryAs(() => - { - var serviceCollectionFinder = args.ObjectContainer.Resolve(); - var createScenarioServiceCollection = serviceCollectionFinder.GetCreateScenarioServiceCollection(); - var services = createScenarioServiceCollection(); + // We store the service provider in the global container, we create it only once + // It must be lazy (hence factory) because at this point we still don't have the bindings mapped. + args.ObjectContainer.RegisterFactoryAs(() => + { + var serviceCollectionFinder = args.ObjectContainer.Resolve(); + var (services, scoping) = serviceCollectionFinder.GetServiceCollection(); - RegisterObjectContainer(args.ObjectContainer, services); - RegisterScenarioSpecFlowDependencies(services); - RegisterFeatureSpecFlowDependencies(services); - RegisterTestThreadSpecFlowDependencies(services); + RegisterProxyBindings(args.ObjectContainer, services); + return new RootServiceProviderContainer(services.BuildServiceProvider(), scoping); + }); - return services.BuildServiceProvider(); - }); - }; + args.ObjectContainer.RegisterFactoryAs(() => + { + return args.ObjectContainer.Resolve().ServiceProvider; + }); + + // Will make sure DI scope is disposed. + var lcEvents = args.ObjectContainer.Resolve(); + lcEvents.AfterScenario += AfterScenarioPluginLifecycleEventHandler; + lcEvents.AfterFeature += AfterFeaturePluginLifecycleEventHandler; + } + args.ObjectContainer.Resolve(); + } + } + + private static void CustomizeFeatureDependenciesEventHandler(object sender, CustomizeFeatureDependenciesEventArgs args) + { + // At this point we have the bindings, we can resolve the service provider, which will build it if it's the first time. + var spContainer = args.ObjectContainer.Resolve(); - runtimePluginEvents.CustomizeFeatureDependencies += (sender, args) => + if (spContainer.Scoping == ScopeLevelType.Feature) { + var serviceProvider = spContainer.ServiceProvider; + + // Now we can register a new scoped service provider args.ObjectContainer.RegisterFactoryAs(() => { - var serviceCollectionFinder = args.ObjectContainer.Resolve(); - var createScenarioServiceCollection = serviceCollectionFinder.GetCreateScenarioServiceCollection(); - var services = createScenarioServiceCollection(); - - RegisterObjectContainer(args.ObjectContainer, services); - RegisterFeatureSpecFlowDependencies(services); - RegisterTestThreadSpecFlowDependencies(services); - - return services.BuildServiceProvider(); + var scope = serviceProvider.CreateScope(); + BindMapping.TryAdd(scope.ServiceProvider, args.ObjectContainer.Resolve()); + ActiveServiceScopes.TryAdd(args.ObjectContainer.Resolve(), scope); + return scope.ServiceProvider; }); - }; + } + } - runtimePluginEvents.CustomizeTestThreadDependencies += (sender, args) => + private static void CustomizeScenarioDependenciesEventHandler(object sender, CustomizeScenarioDependenciesEventArgs args) + { + // At this point we have the bindings, we can resolve the service provider, which will build it if it's the first time. + var spContainer = args.ObjectContainer.Resolve(); + + if (spContainer.Scoping == ScopeLevelType.Scenario) { + var serviceProvider = spContainer.ServiceProvider; + // Now we can register a new scoped service provider args.ObjectContainer.RegisterFactoryAs(() => { - var serviceCollectionFinder = args.ObjectContainer.Resolve(); - var createScenarioServiceCollection = serviceCollectionFinder.GetCreateScenarioServiceCollection(); - var services = createScenarioServiceCollection(); - - RegisterObjectContainer(args.ObjectContainer, services); - RegisterTestThreadSpecFlowDependencies(services); - - return services.BuildServiceProvider(); + var scope = serviceProvider.CreateScope(); + ActiveServiceScopes.TryAdd(args.ObjectContainer.Resolve(), scope); + return scope.ServiceProvider; }); - }; + } } - - private static void RegisterObjectContainer( - IObjectContainer objectContainer, - IServiceCollection services) + + private static void AfterScenarioPluginLifecycleEventHandler(object sender, RuntimePluginAfterScenarioEventArgs eventArgs) { - services.AddTransient(ctx => objectContainer); + if (ActiveServiceScopes.TryRemove(eventArgs.ObjectContainer.Resolve(), out var serviceScope)) + { + BindMapping.TryRemove(serviceScope.ServiceProvider, out _); + serviceScope.Dispose(); + } } - - private static void RegisterScenarioSpecFlowDependencies( - IServiceCollection services) + + private static void AfterFeaturePluginLifecycleEventHandler(object sender, RuntimePluginAfterFeatureEventArgs eventArgs) { - services.AddTransient(ctx => + if (ActiveServiceScopes.TryRemove(eventArgs.ObjectContainer.Resolve(), out var serviceScope)) { - var specflowContainer = ctx.GetService(); - var scenarioContext = specflowContainer.Resolve(); - return scenarioContext; - }); + BindMapping.TryRemove(serviceScope.ServiceProvider, out _); + serviceScope.Dispose(); + } } - private static void RegisterFeatureSpecFlowDependencies( - IServiceCollection services) + private static void RegisterProxyBindings(IObjectContainer objectContainer, IServiceCollection services) { - services.AddTransient(ctx => + // Required for DI of binding classes that want container injections + // While they can (and should) use the method params for injection, we can support it. + // Note that in Feature mode, one can't inject "ScenarioContext", this can only be done from method params. + + // Bases on this: https://docs.specflow.org/projects/specflow/en/latest/Extend/Available-Containers-%26-Registrations.html + // Might need to add more... + + services.AddSingleton(objectContainer); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + services.AddSingleton(sp => objectContainer.Resolve()); + + services.AddTransient(sp => { - var specflowContainer = ctx.GetService(); - var featureContext = specflowContainer.Resolve(); - return featureContext; + var container = BindMapping.TryGetValue(sp, out var ctx) + ? ctx.ScenarioContext?.ScenarioContainer ?? + ctx.FeatureContext?.FeatureContainer ?? + ctx.TestThreadContext?.TestThreadContainer ?? + objectContainer + : objectContainer; + + return container.Resolve(); }); + + services.AddTransient(sp => BindMapping[sp]); + services.AddTransient(sp => BindMapping[sp].TestThreadContext); + services.AddTransient(sp => BindMapping[sp].FeatureContext); + services.AddTransient(sp => BindMapping[sp].ScenarioContext); + services.AddTransient(sp => BindMapping[sp].TestThreadContext.TestThreadContainer.Resolve()); + services.AddTransient(sp => BindMapping[sp].TestThreadContext.TestThreadContainer.Resolve()); + services.AddTransient(sp => BindMapping[sp].TestThreadContext.TestThreadContainer.Resolve()); + services.AddTransient(sp => BindMapping[sp].TestThreadContext.TestThreadContainer.Resolve()); } - private static void RegisterTestThreadSpecFlowDependencies( - IServiceCollection services) + private class RootServiceProviderContainer { - services.AddTransient(ctx => + public IServiceProvider ServiceProvider { get; } + public ScopeLevelType Scoping { get; } + + public RootServiceProviderContainer(IServiceProvider sp, ScopeLevelType scoping) { - var specflowContainer = ctx.GetService(); - var testThreadContext = specflowContainer.Resolve(); - return testThreadContext; - }); + ServiceProvider = sp; + Scoping = scoping; + } } } } diff --git a/SpecFlow.DependencyInjection/DependencyInjectionTestObjectResolver.cs b/SpecFlow.DependencyInjection/DependencyInjectionTestObjectResolver.cs index d1ecbfc..1261870 100644 --- a/SpecFlow.DependencyInjection/DependencyInjectionTestObjectResolver.cs +++ b/SpecFlow.DependencyInjection/DependencyInjectionTestObjectResolver.cs @@ -1,15 +1,50 @@ using System; +using System.Collections.Concurrent; +using System.Reflection; using BoDi; +using Microsoft.Extensions.DependencyInjection; using TechTalk.SpecFlow.Infrastructure; namespace SolidToken.SpecFlow.DependencyInjection { + /* TODO + If SpecFlow will add an "IObjectContainer.IsRegistered(Type type)" method next to the existing "IsRegistered()" + We can remove most of the code here! + */ public class DependencyInjectionTestObjectResolver : ITestObjectResolver { - public object ResolveBindingInstance(Type bindingType, IObjectContainer scenarioContainer) + // Can remove if IsRegistered(Type type) exists + private static readonly ConcurrentDictionary IsRegisteredMethodInfoCache = + new ConcurrentDictionary(); + + // Can remove if IsRegistered(Type type) exists + private static readonly MethodInfo IsRegisteredMethodInfo = typeof(DependencyInjectionTestObjectResolver) + .GetMethod(nameof(IsRegistered), BindingFlags.Instance | BindingFlags.Public); + + // Can remove if IsRegistered(Type type) exists + private static MethodInfo CreateGenericMethodInfo(Type t) => IsRegisteredMethodInfo.MakeGenericMethod(t); + + public object ResolveBindingInstance(Type bindingType, IObjectContainer container) + { + // Can remove if IsRegistered(Type type) exists + var mi = IsRegisteredMethodInfoCache.GetOrAdd(bindingType, CreateGenericMethodInfo); + var registered = (bool) mi.Invoke(this, new object[] { container }); + // var registered = container.IsRegistered(bindingType); + + return registered + ? container.Resolve(bindingType) + : container.Resolve().GetRequiredService(bindingType); + } + + public bool IsRegistered(IObjectContainer container) { - var provider = scenarioContainer.Resolve(); - return provider.GetService(bindingType); + if (container.IsRegistered()) + return true; + + // IsRegistered is not recursive, it will only check the current container + if (container is ObjectContainer c && c.BaseContainer != null) + return IsRegistered(c.BaseContainer); + return false; } } } diff --git a/SpecFlow.DependencyInjection/IServiceCollectionFinder.cs b/SpecFlow.DependencyInjection/IServiceCollectionFinder.cs index efd2c10..3604258 100644 --- a/SpecFlow.DependencyInjection/IServiceCollectionFinder.cs +++ b/SpecFlow.DependencyInjection/IServiceCollectionFinder.cs @@ -5,6 +5,6 @@ namespace SolidToken.SpecFlow.DependencyInjection { public interface IServiceCollectionFinder { - Func GetCreateScenarioServiceCollection(); + (IServiceCollection, ScopeLevelType) GetServiceCollection(); } } diff --git a/SpecFlow.DependencyInjection/ScenarioDependenciesAttribute.cs b/SpecFlow.DependencyInjection/ScenarioDependenciesAttribute.cs index 31e6779..0ffa03a 100644 --- a/SpecFlow.DependencyInjection/ScenarioDependenciesAttribute.cs +++ b/SpecFlow.DependencyInjection/ScenarioDependenciesAttribute.cs @@ -2,6 +2,18 @@ namespace SolidToken.SpecFlow.DependencyInjection { + public enum ScopeLevelType + { + /// + /// Scoping is created for every scenario and it is destroyed once the scenario ends. + /// + Scenario, + /// + /// Scoping is created for Feature scenario and it is destroyed once the Feature ends. + /// + Feature + } + [AttributeUsage(AttributeTargets.Method)] public class ScenarioDependenciesAttribute : Attribute { @@ -9,5 +21,9 @@ public class ScenarioDependenciesAttribute : Attribute /// Automatically register all SpecFlow bindings. /// public bool AutoRegisterBindings { get; set; } = true; + /// + /// Define when to create and destroy scope. + /// + public ScopeLevelType ScopeLevel { get; set; } = ScopeLevelType.Scenario; } } diff --git a/SpecFlow.DependencyInjection/ServiceCollectionFinder.cs b/SpecFlow.DependencyInjection/ServiceCollectionFinder.cs index 24bd146..84e1086 100644 --- a/SpecFlow.DependencyInjection/ServiceCollectionFinder.cs +++ b/SpecFlow.DependencyInjection/ServiceCollectionFinder.cs @@ -11,26 +11,18 @@ namespace SolidToken.SpecFlow.DependencyInjection public class ServiceCollectionFinder : IServiceCollectionFinder { private readonly IBindingRegistry bindingRegistry; - private readonly Lazy> createScenarioServiceCollection; - + private (IServiceCollection, ScopeLevelType) _cache; + public ServiceCollectionFinder(IBindingRegistry bindingRegistry) { this.bindingRegistry = bindingRegistry; - createScenarioServiceCollection = new Lazy>(FindCreateScenarioServiceCollection, true); } - public Func GetCreateScenarioServiceCollection() - { - var services = createScenarioServiceCollection.Value; - if (services == null) - { - throw new MissingScenarioDependenciesException(); - } - return services; - } - - protected virtual Func FindCreateScenarioServiceCollection() + public (IServiceCollection, ScopeLevelType) GetServiceCollection() { + if (_cache != default) + return _cache; + var assemblies = bindingRegistry.GetBindingAssemblies(); foreach (var assembly in assemblies) { @@ -42,22 +34,20 @@ protected virtual Func FindCreateScenarioServiceCollection() if (scenarioDependenciesAttribute != null) { - return () => + var serviceCollection = GetServiceCollection(methodInfo); + if (scenarioDependenciesAttribute.AutoRegisterBindings) { - var serviceCollection = GetServiceCollection(methodInfo); - if (scenarioDependenciesAttribute.AutoRegisterBindings) - { - AddBindingAttributes(assemblies, serviceCollection); - } - return serviceCollection; - }; + AddBindingAttributes(assemblies, serviceCollection); + } + return _cache = (serviceCollection, scenarioDependenciesAttribute.ScopeLevel); } } } } - return null; + throw new MissingScenarioDependenciesException(); } + private static IServiceCollection GetServiceCollection(MethodBase methodInfo) { return (IServiceCollection)methodInfo.Invoke(null, null); @@ -69,7 +59,7 @@ private static void AddBindingAttributes(IEnumerable bindingAssemblies { foreach (var type in assembly.GetTypes().Where(t => Attribute.IsDefined(t, typeof(BindingAttribute)))) { - serviceCollection.AddSingleton(type); + serviceCollection.AddScoped(type); } } }