Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support TLS connection in flagd provider #48

Merged
merged 3 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/OpenFeature.Contrib.Providers.Flagd/FlagdConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ internal class FlagdConfig
internal const string EnvVarHost = "FLAGD_HOST";
internal const string EnvVarPort = "FLAGD_PORT";
internal const string EnvVarTLS = "FLAGD_TLS";
internal const string EnvCertPart = "FLAGD_SERVER_CERT_PATH";
internal const string EnvVarSocketPath = "FLAGD_SOCKET_PATH";
internal const string EnvVarCache = "FLAGD_CACHE";
internal const string EnvVarMaxCacheSize = "FLAGD_MAX_CACHE_SIZE";
Expand All @@ -29,6 +30,17 @@ internal int MaxCacheSize
get { return _maxCacheSize; }
}

internal bool UseCertificate
{
get { return _cert.Length > 0; }
}

internal string CertificatePath
{
get { return _cert; }
set { _cert = value; }
}

internal int MaxEventStreamRetries
{
get { return _maxEventStreamRetries; }
Expand All @@ -38,6 +50,7 @@ internal int MaxEventStreamRetries
private string _host;
private string _port;
private bool _useTLS;
private string _cert;
private string _socketPath;
private bool _cache;
private int _maxCacheSize;
Expand All @@ -48,6 +61,7 @@ internal FlagdConfig()
_host = Environment.GetEnvironmentVariable(EnvVarHost) ?? "localhost";
_port = Environment.GetEnvironmentVariable(EnvVarPort) ?? "8013";
_useTLS = bool.Parse(Environment.GetEnvironmentVariable(EnvVarTLS) ?? "false");
_cert = Environment.GetEnvironmentVariable(EnvCertPart) ?? "";
_socketPath = Environment.GetEnvironmentVariable(EnvVarSocketPath) ?? "";
var cacheStr = Environment.GetEnvironmentVariable(EnvVarCache) ?? "";

Expand Down
39 changes: 32 additions & 7 deletions src/OpenFeature.Contrib.Providers.Flagd/FlagdProvider.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using System;
using System.IO;
using System.Text;
using System.Linq;
using System.Threading.Tasks;
using System.Security.Cryptography.X509Certificates;

using Google.Protobuf.WellKnownTypes;
using Grpc.Core;
Expand Down Expand Up @@ -40,6 +43,7 @@ public sealed class FlagdProvider : FeatureProvider
/// FLAGD_HOST - The host name of the flagd server (default="localhost")
/// FLAGD_PORT - The port of the flagd server (default="8013")
/// FLAGD_TLS - Determines whether to use https or not (default="false")
/// FLAGD_FLAGD_SERVER_CERT_PATH - The path to the client certificate (default="")
/// FLAGD_SOCKET_PATH - Path to the unix socket (default="")
/// FLAGD_CACHE - Enable or disable the cache (default="false")
/// FLAGD_MAX_CACHE_SIZE - The maximum size of the cache (default="10")
Expand All @@ -49,7 +53,7 @@ public FlagdProvider()
{
_config = new FlagdConfig();

_client = buildClientForPlatform(_config.GetUri());
_client = BuildClientForPlatform(_config.GetUri());

_mtx = new System.Threading.Mutex();

Expand Down Expand Up @@ -77,7 +81,7 @@ public FlagdProvider(Uri url)

_mtx = new System.Threading.Mutex();

_client = buildClientForPlatform(url);
_client = BuildClientForPlatform(url);
}


Expand Down Expand Up @@ -499,20 +503,41 @@ private static Value ConvertToPrimitiveValue(ProtoValue value)
}
}

private static Service.ServiceClient buildClientForPlatform(Uri url)
private Service.ServiceClient BuildClientForPlatform(Uri url)
{
var useUnixSocket = url.ToString().StartsWith("unix://");

if (!useUnixSocket)
{
#if NET462
return new Service.ServiceClient(GrpcChannel.ForAddress(url, new GrpcChannelOptions
var handler = new WinHttpHandler();
#else
var handler = new HttpClientHandler();
#endif
if (_config.UseCertificate)
{
HttpHandler = new WinHttpHandler(),
}));
#if NET5_0_OR_GREATER
if (File.Exists(_config.CertificatePath)) {
X509Certificate2 certificate = new X509Certificate2(_config.CertificatePath);
handler.ServerCertificateCustomValidationCallback = (message, cert, chain, _) => {
// the the custom cert to the chain, Build returns a bool if valid.
chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
chain.ChainPolicy.CustomTrustStore.Add(certificate);
return chain.Build(cert);
};
} else {
throw new ArgumentException("Specified certificate cannot be found.");
}
#else
return new Service.ServiceClient(GrpcChannel.ForAddress(url));
// Pre-NET5.0 APIs for custom CA validation are cumbersome.
// Looking for additional contributions here.
throw new ArgumentException("Custom certificate authorities not supported on this platform.");
#endif
}
return new Service.ServiceClient(GrpcChannel.ForAddress(url, new GrpcChannelOptions
{
HttpHandler = handler
}));
}

#if NET5_0_OR_GREATER
Expand Down
3 changes: 2 additions & 1 deletion src/OpenFeature.Contrib.Providers.Flagd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,13 @@ The URI of the flagd server to which the `flagd Provider` connects to can either
| host | FLAGD_HOST | string | localhost | |
| port | FLAGD_PORT | number | 8013 | |
| tls | FLAGD_TLS | boolean | false | |
| tls certPath | FLAGD_SERVER_CERT_PATH | string | | |
| unix socket path | FLAGD_SOCKET_PATH | string | | |
| Caching | FLAGD_CACHE | string | | LRU |
| Maximum cache size | FLAGD_MAX_CACHE_SIZE | number | 10 | |
| Maximum event stream retries | FLAGD_MAX_EVENT_STREAM_RETRIES | number | 3 | |

Note that if `FLAGD_SOCKET_PATH` is set, this value takes precedence, and the other variables (`FLAGD_HOST`, `FLAGD_PORT`, `FLAGD_TLS`) are disregarded.
Note that if `FLAGD_SOCKET_PATH` is set, this value takes precedence, and the other variables (`FLAGD_HOST`, `FLAGD_PORT`, `FLAGD_TLS`, `FLAGD_SERVER_CERT_PATH`) are disregarded.


If you rely on the environment variables listed above, you can use the empty constructor which then configures the provider accordingly:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,32 @@ public void TestFlagdConfigEnabledCacheApplyCacheSize()
Assert.Equal(20, config.MaxCacheSize);
}

[Fact]
public void TestFlagdConfigSetCertificatePath()
{
CleanEnvVars();
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvCertPart, "/cert/path");

var config = new FlagdConfig();

Assert.Equal("/cert/path", config.CertificatePath);
Assert.True(config.UseCertificate);

CleanEnvVars();

config = new FlagdConfig();

Assert.Equal("", config.CertificatePath);
Assert.False(config.UseCertificate);
}

private void CleanEnvVars()
{
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarTLS, "");
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarSocketPath, "");
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarCache, "");
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarMaxCacheSize, "");
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvCertPart, "");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,62 @@
using OpenFeature.Error;
using ProtoValue = Google.Protobuf.WellKnownTypes.Value;
using System.Collections.Generic;
using System.Linq;
using OpenFeature.Model;
using System.Threading;
using System;

namespace OpenFeature.Contrib.Providers.Flagd.Test
{
public class UnitTestFlagdProvider
{
[Fact]
public void BuildClientForPlatform_Should_Throw_Exception_When_FlagdCertPath_Not_Exists()
{
// Arrange
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvCertPart, "non-existing-path");
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarHost, "localhost");
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarPort, "5001");

// Act & Assert
Assert.Throws<ArgumentException>(() => new FlagdProvider());

// Cleanup
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvCertPart, "");
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarHost, "");
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarPort, "");
}

[Fact]
public void BuildClientForPlatform_Should_Return_Client_For_Non_Unix_Socket_Without_Certificate()
{
// Arrange
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarHost, "localhost");
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarPort, "5001");

// Act
var flagdProvider = new FlagdProvider();
var client = flagdProvider.GetClient();

// Assert
Assert.NotNull(client);
Assert.IsType<Service.ServiceClient>(client);

// Cleanup
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarHost, "");
System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarPort, "");
}

#if NET462
[Fact]
public void BuildClientForPlatform_Should_Throw_Exception_For_Unsupported_DotNet_Version()
{
// Arrange
var url = new Uri("unix:///var/run/flagd.sock");

// Act & Assert
Assert.Throws<Exception>(() => new FlagdProvider(url));
}
#endif
[Fact]
public void TestGetProviderName()
{
Expand Down