diff --git a/src/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProvider.cs b/src/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProvider.cs index 37078e0196..e8c84b69ec 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProvider.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery/Configuration/ConfigurationServiceEndpointProvider.cs @@ -17,6 +17,7 @@ internal sealed partial class ConfigurationServiceEndpointProvider : IServiceEnd private const string DefaultEndpointName = "default"; private readonly string _serviceName; private readonly string? _endpointName; + private readonly bool _includeAllSchemes; private readonly string[] _schemes; private readonly IConfiguration _configuration; private readonly ILogger _logger; @@ -39,6 +40,7 @@ public ConfigurationServiceEndpointProvider( { _serviceName = query.ServiceName; _endpointName = query.EndpointName; + _includeAllSchemes = serviceDiscoveryOptions.Value.AllowAllSchemes && query.IncludedSchemes.Count == 0; _schemes = ServiceDiscoveryOptions.ApplyAllowedSchemes(query.IncludedSchemes, serviceDiscoveryOptions.Value.AllowedSchemes, serviceDiscoveryOptions.Value.AllowAllSchemes); _configuration = configuration; _logger = logger; @@ -74,27 +76,17 @@ public ValueTask PopulateAsync(IServiceEndpointBuilder endpoints, CancellationTo string endpointName; if (string.IsNullOrWhiteSpace(_endpointName)) { - if (_schemes.Length == 0) + // Treat the scheme as the endpoint name and use the first section with a matching endpoint name which exists + endpointName = DefaultEndpointName; + ReadOnlySpan candidateNames = [DefaultEndpointName, .. _schemes]; + foreach (var scheme in candidateNames) { - // Use the section named "default". - endpointName = DefaultEndpointName; - namedSection = section.GetSection(endpointName); - } - else - { - // Set the ideal endpoint name for error messages. - endpointName = _schemes[0]; - - // Treat the scheme as the endpoint name and use the first section with a matching endpoint name which exists - foreach (var scheme in _schemes) + var candidate = section.GetSection(scheme); + if (candidate.Exists()) { - var candidate = section.GetSection(scheme); - if (candidate.Exists()) - { - endpointName = scheme; - namedSection = candidate; - break; - } + endpointName = scheme; + namedSection = candidate; + break; } } } @@ -135,46 +127,60 @@ public ValueTask PopulateAsync(IServiceEndpointBuilder endpoints, CancellationTo } } - // Filter the resolved endpoints to only include those which match the specified scheme. - var minIndex = _schemes.Length; - foreach (var ep in resolved) + int resolvedEndpointCount; + if (_includeAllSchemes) { - if (ep.EndPoint is UriEndPoint uri && uri.Uri.Scheme is { } scheme) + // Include all endpoints. + foreach (var ep in resolved) { - var index = Array.IndexOf(_schemes, scheme); - if (index >= 0 && index < minIndex) - { - minIndex = index; - } + endpoints.Endpoints.Add(ep); } + + resolvedEndpointCount = resolved.Count; } - - var added = 0; - foreach (var ep in resolved) + else { - if (ep.EndPoint is UriEndPoint uri && uri.Uri.Scheme is { } scheme) + // Filter the resolved endpoints to only include those which match the specified, allowed schemes. + resolvedEndpointCount = 0; + var minIndex = _schemes.Length; + foreach (var ep in resolved) { - var index = Array.IndexOf(_schemes, scheme); - if (index >= 0 && index <= minIndex) + if (ep.EndPoint is UriEndPoint uri && uri.Uri.Scheme is { } scheme) { - ++added; - endpoints.Endpoints.Add(ep); + var index = Array.IndexOf(_schemes, scheme); + if (index >= 0 && index < minIndex) + { + minIndex = index; + } } } - else + + foreach (var ep in resolved) { - ++added; - endpoints.Endpoints.Add(ep); + if (ep.EndPoint is UriEndPoint uri && uri.Uri.Scheme is { } scheme) + { + var index = Array.IndexOf(_schemes, scheme); + if (index >= 0 && index <= minIndex) + { + ++resolvedEndpointCount; + endpoints.Endpoints.Add(ep); + } + } + else + { + ++resolvedEndpointCount; + endpoints.Endpoints.Add(ep); + } } } - if (added == 0) + if (resolvedEndpointCount == 0) { Log.ServiceConfigurationNotFound(_logger, _serviceName, configPath); } else { - Log.ConfiguredEndpoints(_logger, _serviceName, configPath, endpoints.Endpoints, added); + Log.ConfiguredEndpoints(_logger, _serviceName, configPath, endpoints.Endpoints, resolvedEndpointCount); } return default; diff --git a/src/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryOptions.cs b/src/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryOptions.cs index 02ce1af162..edc652507d 100644 --- a/src/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryOptions.cs +++ b/src/Microsoft.Extensions.ServiceDiscovery/ServiceDiscoveryOptions.cs @@ -32,29 +32,35 @@ public sealed class ServiceDiscoveryOptions internal static string[] ApplyAllowedSchemes(IReadOnlyList schemes, IList allowedSchemes, bool allowAllSchemes) { - if (allowAllSchemes) + if (schemes.Count > 0) { - if (schemes is string[] array) + if (allowAllSchemes) { - return array; - } + if (schemes is string[] array && array.Length > 0) + { + return array; + } - return schemes.ToArray(); - } + return schemes.ToArray(); + } - List result = []; - foreach (var s in schemes) - { - foreach (var allowed in allowedSchemes) + List result = []; + foreach (var s in schemes) { - if (string.Equals(s, allowed, StringComparison.OrdinalIgnoreCase)) + foreach (var allowed in allowedSchemes) { - result.Add(s); - break; + if (string.Equals(s, allowed, StringComparison.OrdinalIgnoreCase)) + { + result.Add(s); + break; + } } } + + return result.ToArray(); } - return result.ToArray(); + // If no schemes were specified, but a set of allowed schemes were specified, allow those. + return allowedSchemes.ToArray(); } } diff --git a/tests/Microsoft.Extensions.ServiceDiscovery.Tests/ConfigurationServiceEndpointResolverTests.cs b/tests/Microsoft.Extensions.ServiceDiscovery.Tests/ConfigurationServiceEndpointResolverTests.cs index db72078210..6955cc1e8e 100644 --- a/tests/Microsoft.Extensions.ServiceDiscovery.Tests/ConfigurationServiceEndpointResolverTests.cs +++ b/tests/Microsoft.Extensions.ServiceDiscovery.Tests/ConfigurationServiceEndpointResolverTests.cs @@ -18,7 +18,7 @@ namespace Microsoft.Extensions.ServiceDiscovery.Tests; public class ConfigurationServiceEndpointResolverTests { [Fact] - public async Task ResolveServiceEndpoint_Configuration_SingleResult() + public async Task ResolveServiceEndpoint_Configuration_SingleResult_NoScheme() { var config = new ConfigurationBuilder().AddInMemoryCollection(new Dictionary { @@ -87,8 +87,23 @@ public async Task ResolveServiceEndpoint_Configuration_DisallowedScheme() Assert.Empty(initialResult.EndpointSource.Endpoints); } + // Specifying no scheme. + // We should get the HTTPS endpoint back, since it is explicitly allowed + await using ((watcher = watcherFactory.CreateWatcher("_foo.basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task.ConfigureAwait(false); + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + var ep = Assert.Single(initialResult.EndpointSource.Endpoints); + Assert.Equal(new UriEndPoint(new Uri("https://localhost")), ep.EndPoint); + } + // Specifying either https or http. - // The result should be that we only get the http endpoint back. + // We should only get the https endpoint back. await using ((watcher = watcherFactory.CreateWatcher("https+http://_foo.basket")).ConfigureAwait(false)) { Assert.NotNull(watcher); @@ -103,7 +118,7 @@ public async Task ResolveServiceEndpoint_Configuration_DisallowedScheme() } // Specifying either https or http, but in reverse. - // The result should be that we only get the http endpoint back. + // We should only get the https endpoint back. await using ((watcher = watcherFactory.CreateWatcher("http+https://_foo.basket")).ConfigureAwait(false)) { Assert.NotNull(watcher); @@ -118,6 +133,144 @@ public async Task ResolveServiceEndpoint_Configuration_DisallowedScheme() } } + [Fact] + public async Task ResolveServiceEndpoint_Configuration_DefaultEndpointName() + { + var config = new ConfigurationBuilder().AddInMemoryCollection(new Dictionary + { + ["services:basket:default:0"] = "https://localhost:8080", + ["services:basket:otlp:0"] = "https://localhost:8888", + }); + var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .AddConfigurationServiceEndpointProvider(o => + { + o.ShouldApplyHostNameMetadata = _ => true; + }) + .Configure(o => + { + o.AllowAllSchemes = false; + o.AllowedSchemes = ["https"]; + }) + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + + // Explicitly specifying https as the scheme, but the endpoint section in configuration is the default value ("default"). + // We should get the endpoint back because it is an https endpoint (allowed) with the default endpoint name. + await using ((watcher = watcherFactory.CreateWatcher("https://basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task.ConfigureAwait(false); + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + Assert.Equal(1, initialResult.EndpointSource.Endpoints.Count); + Assert.Equal(new UriEndPoint(new Uri("https://localhost:8080")), initialResult.EndpointSource.Endpoints[0].EndPoint); + + Assert.All(initialResult.EndpointSource.Endpoints, ep => + { + var hostNameFeature = ep.Features.Get(); + Assert.NotNull(hostNameFeature); + Assert.Equal("basket", hostNameFeature.HostName); + }); + } + + // Not specifying the scheme or endpoint name. + // We should get the endpoint back because it is an https endpoint (allowed) with the default endpoint name. + await using ((watcher = watcherFactory.CreateWatcher("basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task.ConfigureAwait(false); + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + Assert.Equal(1, initialResult.EndpointSource.Endpoints.Count); + Assert.Equal(new UriEndPoint(new Uri("https://localhost:8080")), initialResult.EndpointSource.Endpoints[0].EndPoint); + } + + // Not specifying the scheme, but specifying the default endpoint name. + // We should get the endpoint back because it is an https endpoint (allowed) with the default endpoint name. + await using ((watcher = watcherFactory.CreateWatcher("_default.basket")).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task.ConfigureAwait(false); + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + Assert.Equal(1, initialResult.EndpointSource.Endpoints.Count); + Assert.Equal(new UriEndPoint(new Uri("https://localhost:8080")), initialResult.EndpointSource.Endpoints[0].EndPoint); + } + } + + /// + /// Checks that when there is no named endpoint, configuration resolves first from the "default" section, then sections named by the scheme names. + /// + [Theory] + [InlineData(true, true, "https://basket", "https://default-host:8080")] + [InlineData(false, true, "https://basket","https://https-host:8080")] + [InlineData(true, false, "https://basket", "https://default-host:8080")] + [InlineData(true, true, "basket", "https://default-host:8080")] + [InlineData(false, true, "basket", null)] + [InlineData(true, false, "basket", "https://default-host:8080")] + [InlineData(true, true, "http+https://basket", "https://default-host:8080")] + [InlineData(false, true, "http+https://basket","https://https-host:8080")] + [InlineData(true, false, "http+https://basket", "https://default-host:8080")] + public async Task ResolveServiceEndpoint_Configuration_DefaultEndpointName_ResolutionOrder( + bool includeDefault, + bool includeSchemeNamed, + string serviceName, + string? expectedResult) + { + var data = new Dictionary(); + if (includeDefault) + { + data["services:basket:default:0"] = "https://default-host:8080"; + } + + if (includeSchemeNamed) + { + data["services:basket:https:0"] = "https://https-host:8080"; + } + + var config = new ConfigurationBuilder().AddInMemoryCollection(data); + var services = new ServiceCollection() + .AddSingleton(config.Build()) + .AddServiceDiscoveryCore() + .AddConfigurationServiceEndpointProvider() + .BuildServiceProvider(); + var watcherFactory = services.GetRequiredService(); + ServiceEndpointWatcher watcher; + + // Scheme in query + await using ((watcher = watcherFactory.CreateWatcher(serviceName)).ConfigureAwait(false)) + { + Assert.NotNull(watcher); + var tcs = new TaskCompletionSource(); + watcher.OnEndpointsUpdated = tcs.SetResult; + watcher.Start(); + var initialResult = await tcs.Task.ConfigureAwait(false); + Assert.NotNull(initialResult); + Assert.True(initialResult.ResolvedSuccessfully); + if (expectedResult is not null) + { + Assert.Equal(1, initialResult.EndpointSource.Endpoints.Count); + Assert.Equal(new UriEndPoint(new Uri(expectedResult)), initialResult.EndpointSource.Endpoints[0].EndPoint); + } + else + { + Assert.Empty(initialResult.EndpointSource.Endpoints); + } + } + } + [Fact] public async Task ResolveServiceEndpoint_Configuration_MultipleResults() { @@ -125,8 +278,8 @@ public async Task ResolveServiceEndpoint_Configuration_MultipleResults() { InitialData = new Dictionary { - ["services:basket:http:0"] = "http://localhost:8080", - ["services:basket:http:1"] = "http://remotehost:9090", + ["services:basket:default:0"] = "http://localhost:8080", + ["services:basket:default:1"] = "http://remotehost:9090", } }; var config = new ConfigurationBuilder().Add(configSource); @@ -274,19 +427,4 @@ public async Task ResolveServiceEndpoint_Configuration_MultipleProtocols_WithSpe }); } } - - public class MyConfigurationProvider : ConfigurationProvider, IConfigurationSource - { - public IConfigurationProvider Build(IConfigurationBuilder builder) => this; - public void SetValues(IEnumerable> values) - { - Data.Clear(); - foreach (var (key, value) in values) - { - Data[key] = value; - } - - OnReload(); - } - } }