Skip to content

Commit

Permalink
Implement IKeyedServiceProvider interface on ServiceProviderEngineSco…
Browse files Browse the repository at this point in the history
…pe (#89509)

* Implement IKeyedServiceProvider interface

* Add more tests
  • Loading branch information
CarnaViire authored Jul 28, 2023
1 parent 849ed0a commit ec5c223
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,114 @@ public void ResolveKeyedServiceTransientTypeWithAnyKey()
Assert.NotSame(first, second);
}

[Fact]
public void ResolveKeyedSingletonFromInjectedServiceProvider()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddKeyedSingleton<IService, Service>("key");
serviceCollection.AddSingleton<ServiceProviderAccessor>();

var provider = CreateServiceProvider(serviceCollection);
var accessor = provider.GetRequiredService<ServiceProviderAccessor>();

Assert.Null(accessor.ServiceProvider.GetService<IService>());

var service1 = accessor.ServiceProvider.GetKeyedService<IService>("key");
var service2 = accessor.ServiceProvider.GetKeyedService<IService>("key");

Assert.Same(service1, service2);
}

[Fact]
public void ResolveKeyedTransientFromInjectedServiceProvider()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddKeyedTransient<IService, Service>("key");
serviceCollection.AddSingleton<ServiceProviderAccessor>();

var provider = CreateServiceProvider(serviceCollection);
var accessor = provider.GetRequiredService<ServiceProviderAccessor>();

Assert.Null(accessor.ServiceProvider.GetService<IService>());

var service1 = accessor.ServiceProvider.GetKeyedService<IService>("key");
var service2 = accessor.ServiceProvider.GetKeyedService<IService>("key");

Assert.NotSame(service1, service2);
}

[Fact]
public void ResolveKeyedSingletonFromScopeServiceProvider()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddKeyedSingleton<IService, Service>("key");

var provider = CreateServiceProvider(serviceCollection);
var scopeA = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();
var scopeB = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();

Assert.Null(scopeA.ServiceProvider.GetService<IService>());
Assert.Null(scopeB.ServiceProvider.GetService<IService>());

var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");

var serviceB1 = scopeB.ServiceProvider.GetKeyedService<IService>("key");
var serviceB2 = scopeB.ServiceProvider.GetKeyedService<IService>("key");

Assert.Same(serviceA1, serviceA2);
Assert.Same(serviceB1, serviceB2);
Assert.Same(serviceA1, serviceB1);
}

[Fact]
public void ResolveKeyedScopedFromScopeServiceProvider()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddKeyedScoped<IService, Service>("key");

var provider = CreateServiceProvider(serviceCollection);
var scopeA = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();
var scopeB = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();

Assert.Null(scopeA.ServiceProvider.GetService<IService>());
Assert.Null(scopeB.ServiceProvider.GetService<IService>());

var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");

var serviceB1 = scopeB.ServiceProvider.GetKeyedService<IService>("key");
var serviceB2 = scopeB.ServiceProvider.GetKeyedService<IService>("key");

Assert.Same(serviceA1, serviceA2);
Assert.Same(serviceB1, serviceB2);
Assert.NotSame(serviceA1, serviceB1);
}

[Fact]
public void ResolveKeyedTransientFromScopeServiceProvider()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddKeyedTransient<IService, Service>("key");

var provider = CreateServiceProvider(serviceCollection);
var scopeA = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();
var scopeB = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();

Assert.Null(scopeA.ServiceProvider.GetService<IService>());
Assert.Null(scopeB.ServiceProvider.GetService<IService>());

var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");

var serviceB1 = scopeB.ServiceProvider.GetKeyedService<IService>("key");
var serviceB2 = scopeB.ServiceProvider.GetKeyedService<IService>("key");

Assert.NotSame(serviceA1, serviceA2);
Assert.NotSame(serviceB1, serviceB2);
Assert.NotSame(serviceA1, serviceB1);
}

internal interface IService { }

internal class Service : IService
Expand Down Expand Up @@ -358,5 +466,15 @@ internal class ServiceWithIntKey : IService

public ServiceWithIntKey([ServiceKey] int id) => _id = id;
}

internal class ServiceProviderAccessor
{
public ServiceProviderAccessor(IServiceProvider serviceProvider)
{
ServiceProvider = serviceProvider;
}

public IServiceProvider ServiceProvider { get; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
[DebuggerDisplay("{DebuggerToString(),nq}")]
[DebuggerTypeProxy(typeof(ServiceProviderEngineScopeDebugView))]
internal sealed class ServiceProviderEngineScope : IServiceScope, IServiceProvider, IAsyncDisposable, IServiceScopeFactory
internal sealed class ServiceProviderEngineScope : IServiceScope, IServiceProvider, IKeyedServiceProvider, IAsyncDisposable, IServiceScopeFactory
{
// For testing and debugging only
internal IList<object> Disposables => _disposables ?? (IList<object>)Array.Empty<object>();
Expand Down Expand Up @@ -50,6 +50,26 @@ public ServiceProviderEngineScope(ServiceProvider provider, bool isRootScope)
return RootProvider.GetService(ServiceIdentifier.FromServiceType(serviceType), this);
}

public object? GetKeyedService(Type serviceType, object? serviceKey)
{
if (_disposed)
{
ThrowHelper.ThrowObjectDisposedException();
}

return RootProvider.GetKeyedService(serviceType, serviceKey, this);
}

public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
{
if (_disposed)
{
ThrowHelper.ThrowObjectDisposedException();
}

return RootProvider.GetRequiredKeyedService(serviceType, serviceKey, this);
}

public IServiceProvider ServiceProvider => this;

public IServiceScope CreateScope() => RootProvider.CreateScope();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,17 @@ internal ServiceProvider(ICollection<ServiceDescriptor> serviceDescriptors, Serv
public object? GetService(Type serviceType) => GetService(ServiceIdentifier.FromServiceType(serviceType), Root);

public object? GetKeyedService(Type serviceType, object? serviceKey)
=> GetService(new ServiceIdentifier(serviceKey, serviceType), Root);
=> GetKeyedService(serviceType, serviceKey, Root);

internal object? GetKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
=> GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope);

public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
=> GetRequiredKeyedService(serviceType, serviceKey, Root);

internal object GetRequiredKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
{
object? service = GetKeyedService(serviceType, serviceKey);
object? service = GetKeyedService(serviceType, serviceKey, serviceProviderEngineScope);
if (service == null)
{
throw new InvalidOperationException(SR.Format(SR.NoServiceRegistered, serviceType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using Microsoft.Extensions.DependencyInjection.Specification.Fakes;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
Expand All @@ -29,5 +31,15 @@ public void RootEngineScopeDisposeTest()

Assert.Throws<ObjectDisposedException>(() => sp.GetRequiredService<IServiceProvider>());
}

[Fact]
public void ServiceProviderEngineScope_ImplementsAllServiceProviderInterfaces()
{
var engineScopeInterfaces = typeof(ServiceProviderEngineScope).GetInterfaces();
foreach (var serviceProviderInterface in typeof(ServiceProvider).GetInterfaces())
{
Assert.Contains(serviceProviderInterface, engineScopeInterfaces);
}
}
}
}

0 comments on commit ec5c223

Please sign in to comment.