Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement IKeyedServiceProvider interface on ServiceProviderEngineScope #89509

Merged
merged 2 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,114 @@ public void ResolveKeyedServiceTransientTypeWithAnyKey()
Assert.NotSame(first, second);
}

[Fact]
public void ResolveKeyedSingletonFromInjectedServiceProvider()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddKeyedSingleton<IService, Service>("key");
serviceCollection.AddSingleton<ServiceProviderAccessor>();

var provider = CreateServiceProvider(serviceCollection);
var accessor = provider.GetRequiredService<ServiceProviderAccessor>();

Assert.Null(accessor.ServiceProvider.GetService<IService>());

var service1 = accessor.ServiceProvider.GetKeyedService<IService>("key");
var service2 = accessor.ServiceProvider.GetKeyedService<IService>("key");

Assert.Same(service1, service2);
}

[Fact]
public void ResolveKeyedTransientFromInjectedServiceProvider()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddKeyedTransient<IService, Service>("key");
serviceCollection.AddSingleton<ServiceProviderAccessor>();

var provider = CreateServiceProvider(serviceCollection);
var accessor = provider.GetRequiredService<ServiceProviderAccessor>();

Assert.Null(accessor.ServiceProvider.GetService<IService>());

var service1 = accessor.ServiceProvider.GetKeyedService<IService>("key");
var service2 = accessor.ServiceProvider.GetKeyedService<IService>("key");

Assert.NotSame(service1, service2);
}

[Fact]
public void ResolveKeyedSingletonFromScopeServiceProvider()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddKeyedSingleton<IService, Service>("key");

var provider = CreateServiceProvider(serviceCollection);
var scopeA = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();
var scopeB = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();

Assert.Null(scopeA.ServiceProvider.GetService<IService>());
Assert.Null(scopeB.ServiceProvider.GetService<IService>());

var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");

var serviceB1 = scopeB.ServiceProvider.GetKeyedService<IService>("key");
var serviceB2 = scopeB.ServiceProvider.GetKeyedService<IService>("key");

Assert.Same(serviceA1, serviceA2);
Assert.Same(serviceB1, serviceB2);
Assert.Same(serviceA1, serviceB1);
}

[Fact]
public void ResolveKeyedScopedFromScopeServiceProvider()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddKeyedScoped<IService, Service>("key");

var provider = CreateServiceProvider(serviceCollection);
var scopeA = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();
var scopeB = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();

Assert.Null(scopeA.ServiceProvider.GetService<IService>());
Assert.Null(scopeB.ServiceProvider.GetService<IService>());

var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");

var serviceB1 = scopeB.ServiceProvider.GetKeyedService<IService>("key");
var serviceB2 = scopeB.ServiceProvider.GetKeyedService<IService>("key");

Assert.Same(serviceA1, serviceA2);
Assert.Same(serviceB1, serviceB2);
Assert.NotSame(serviceA1, serviceB1);
}

[Fact]
public void ResolveKeyedTransientFromScopeServiceProvider()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddKeyedTransient<IService, Service>("key");

var provider = CreateServiceProvider(serviceCollection);
var scopeA = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();
var scopeB = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();

Assert.Null(scopeA.ServiceProvider.GetService<IService>());
Assert.Null(scopeB.ServiceProvider.GetService<IService>());

var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");

var serviceB1 = scopeB.ServiceProvider.GetKeyedService<IService>("key");
var serviceB2 = scopeB.ServiceProvider.GetKeyedService<IService>("key");

Assert.NotSame(serviceA1, serviceA2);
Assert.NotSame(serviceB1, serviceB2);
Assert.NotSame(serviceA1, serviceB1);
}

internal interface IService { }

internal class Service : IService
Expand Down Expand Up @@ -358,5 +466,15 @@ internal class ServiceWithIntKey : IService

public ServiceWithIntKey([ServiceKey] int id) => _id = id;
}

internal class ServiceProviderAccessor
{
public ServiceProviderAccessor(IServiceProvider serviceProvider)
{
ServiceProvider = serviceProvider;
}

public IServiceProvider ServiceProvider { get; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
[DebuggerDisplay("{DebuggerToString(),nq}")]
[DebuggerTypeProxy(typeof(ServiceProviderEngineScopeDebugView))]
internal sealed class ServiceProviderEngineScope : IServiceScope, IServiceProvider, IAsyncDisposable, IServiceScopeFactory
internal sealed class ServiceProviderEngineScope : IServiceScope, IServiceProvider, IKeyedServiceProvider, IAsyncDisposable, IServiceScopeFactory
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved
{
// For testing and debugging only
internal IList<object> Disposables => _disposables ?? (IList<object>)Array.Empty<object>();
Expand Down Expand Up @@ -50,6 +50,26 @@ public ServiceProviderEngineScope(ServiceProvider provider, bool isRootScope)
return RootProvider.GetService(ServiceIdentifier.FromServiceType(serviceType), this);
}

public object? GetKeyedService(Type serviceType, object? serviceKey)
{
if (_disposed)
{
ThrowHelper.ThrowObjectDisposedException();
}

return RootProvider.GetKeyedService(serviceType, serviceKey, this);
}

public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
{
if (_disposed)
{
ThrowHelper.ThrowObjectDisposedException();
}

return RootProvider.GetRequiredKeyedService(serviceType, serviceKey, this);
}

public IServiceProvider ServiceProvider => this;

public IServiceScope CreateScope() => RootProvider.CreateScope();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,17 @@ internal ServiceProvider(ICollection<ServiceDescriptor> serviceDescriptors, Serv
public object? GetService(Type serviceType) => GetService(ServiceIdentifier.FromServiceType(serviceType), Root);

public object? GetKeyedService(Type serviceType, object? serviceKey)
=> GetService(new ServiceIdentifier(serviceKey, serviceType), Root);
=> GetKeyedService(serviceType, serviceKey, Root);

internal object? GetKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
=> GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope);

public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
=> GetRequiredKeyedService(serviceType, serviceKey, Root);

internal object GetRequiredKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
{
object? service = GetKeyedService(serviceType, serviceKey);
object? service = GetKeyedService(serviceType, serviceKey, serviceProviderEngineScope);
if (service == null)
{
throw new InvalidOperationException(SR.Format(SR.NoServiceRegistered, serviceType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using Microsoft.Extensions.DependencyInjection.Specification.Fakes;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
Expand All @@ -29,5 +31,15 @@ public void RootEngineScopeDisposeTest()

Assert.Throws<ObjectDisposedException>(() => sp.GetRequiredService<IServiceProvider>());
}

[Fact]
public void ServiceProviderEngineScope_ImplementsAllServiceProviderInterfaces()
{
var engineScopeInterfaces = typeof(ServiceProviderEngineScope).GetInterfaces();
foreach (var serviceProviderInterface in typeof(ServiceProvider).GetInterfaces())
{
Assert.Contains(serviceProviderInterface, engineScopeInterfaces);
}
}
}
}
Loading