From f4d21426e9ec079d62ecca4e8d1936cb8ad299b7 Mon Sep 17 00:00:00 2001 From: Florian Bacher Date: Mon, 3 Apr 2023 18:37:22 +0200 Subject: [PATCH] feat: implemented LRU caching for flagd provider (#47) Signed-off-by: Florian Bacher --- .../Cache.cs | 173 ++++++++++++ .../FlagdConfig.cs | 83 ++++++ .../FlagdProvider.cs | 185 +++++++++--- .../README.md | 15 +- .../CacheTest.cs | 181 ++++++++++++ .../FlagdConfigTest.cs | 72 +++++ .../FlagdProviderTest.cs | 263 +++++++++++++++++- 7 files changed, 922 insertions(+), 50 deletions(-) create mode 100644 src/OpenFeature.Contrib.Providers.Flagd/Cache.cs create mode 100644 src/OpenFeature.Contrib.Providers.Flagd/FlagdConfig.cs create mode 100644 test/OpenFeature.Contrib.Providers.Flagd.Test/CacheTest.cs create mode 100644 test/OpenFeature.Contrib.Providers.Flagd.Test/FlagdConfigTest.cs diff --git a/src/OpenFeature.Contrib.Providers.Flagd/Cache.cs b/src/OpenFeature.Contrib.Providers.Flagd/Cache.cs new file mode 100644 index 00000000..ec146add --- /dev/null +++ b/src/OpenFeature.Contrib.Providers.Flagd/Cache.cs @@ -0,0 +1,173 @@ +using System.Collections.Generic; +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("DynamicProxyGenAssembly2")] + +namespace OpenFeature.Contrib.Providers.Flagd +{ + internal interface ICache + { + void Add(TKey key, TValue value); + TValue TryGet(TKey key); + void Delete(TKey key); + void Purge(); + } + class LRUCache : ICache where TValue : class + { + private readonly int _capacity; + private readonly Dictionary _map; + private Node _head; + private Node _tail; + + private System.Threading.Mutex _mtx; + + public LRUCache(int capacity) + { + _capacity = capacity; + _map = new Dictionary(); + _mtx = new System.Threading.Mutex(); + } + + public TValue TryGet(TKey key) + { + using (var mtx = new Mutex(ref _mtx)) + { + mtx.Lock(); + if (_map.TryGetValue(key, out Node node)) + { + MoveToFront(node); + return node.Value; + } + return default(TValue); + } + } + + public void Add(TKey key, TValue value) + { + using (var mtx = new Mutex(ref _mtx)) + { + mtx.Lock(); + if (_map.TryGetValue(key, out Node node)) + { + node.Value = value; + MoveToFront(node); + } + else + { + if (_map.Count >= _capacity) + { + _map.Remove(_tail.Key); + RemoveTail(); + } + node = new Node(key, value); + _map.Add(key, node); + AddToFront(node); + } + } + } + + public void Delete(TKey key) + { + using (var mtx = new Mutex(ref _mtx)) + { + mtx.Lock(); + if (_map.TryGetValue(key, out Node node)) + { + if (node == _head) + { + _head = node.Next; + } + else + { + node.Prev.Next = node.Next; + } + if (node.Next != null) + { + node.Next.Prev = node.Prev; + } + _map.Remove(key); + } + } + } + + public void Purge() + { + using (var mtx = new Mutex(ref _mtx)) + { + mtx.Lock(); + _map.Clear(); + } + } + + private void MoveToFront(Node node) + { + if (node == _head) + return; + node.Prev.Next = node.Next; + if (node == _tail) + _tail = node.Prev; + else + node.Next.Prev = node.Prev; + AddToFront(node); + } + + private void AddToFront(Node node) + { + if (_head == null) + { + _head = node; + _tail = node; + return; + } + node.Next = _head; + _head.Prev = node; + _head = node; + } + + private void RemoveTail() + { + _tail = _tail.Prev; + if (_tail != null) + _tail.Next = null; + else + _head = null; + } + + private class Node + { + public TKey Key; + public TValue Value; + public Node Next; + public Node Prev; + + public Node(TKey key, TValue value) + { + Key = key; + Value = value; + } + } + + private class Mutex : System.IDisposable + { + + public System.Threading.Mutex _mtx; + + public Mutex(ref System.Threading.Mutex mtx) + { + _mtx = mtx; + } + + public void Lock() + { + _mtx.WaitOne(); + } + + public void Dispose() + { + _mtx.ReleaseMutex(); + } + } + } + +} + diff --git a/src/OpenFeature.Contrib.Providers.Flagd/FlagdConfig.cs b/src/OpenFeature.Contrib.Providers.Flagd/FlagdConfig.cs new file mode 100644 index 00000000..b855ee73 --- /dev/null +++ b/src/OpenFeature.Contrib.Providers.Flagd/FlagdConfig.cs @@ -0,0 +1,83 @@ +using System; + +namespace OpenFeature.Contrib.Providers.Flagd + +{ + internal class FlagdConfig + { + internal const string EnvVarHost = "FLAGD_HOST"; + internal const string EnvVarPort = "FLAGD_PORT"; + internal const string EnvVarTLS = "FLAGD_TLS"; + internal const string EnvVarSocketPath = "FLAGD_SOCKET_PATH"; + internal const string EnvVarCache = "FLAGD_CACHE"; + internal const string EnvVarMaxCacheSize = "FLAGD_MAX_CACHE_SIZE"; + internal const string EnvVarMaxEventStreamRetries = "FLAGD_MAX_EVENT_STREAM_RETRIES"; + internal static int CacheSizeDefault = 10; + internal string Host + { + get { return _host; } + } + + internal bool CacheEnabled + { + get { return _cache; } + set { _cache = value; } + } + + internal int MaxCacheSize + { + get { return _maxCacheSize; } + } + + internal int MaxEventStreamRetries + { + get { return _maxEventStreamRetries; } + set { _maxEventStreamRetries = value; } + } + + private string _host; + private string _port; + private bool _useTLS; + private string _socketPath; + private bool _cache; + private int _maxCacheSize; + private int _maxEventStreamRetries; + + internal FlagdConfig() + { + _host = Environment.GetEnvironmentVariable(EnvVarHost) ?? "localhost"; + _port = Environment.GetEnvironmentVariable(EnvVarPort) ?? "8013"; + _useTLS = bool.Parse(Environment.GetEnvironmentVariable(EnvVarTLS) ?? "false"); + _socketPath = Environment.GetEnvironmentVariable(EnvVarSocketPath) ?? ""; + var cacheStr = Environment.GetEnvironmentVariable(EnvVarCache) ?? ""; + + if (cacheStr.ToUpper().Equals("LRU")) + { + _cache = true; + _maxCacheSize = int.Parse(Environment.GetEnvironmentVariable(EnvVarMaxCacheSize) ?? $"{CacheSizeDefault}"); + _maxEventStreamRetries = int.Parse(Environment.GetEnvironmentVariable(EnvVarMaxEventStreamRetries) ?? "3"); + } + } + + internal Uri GetUri() + { + Uri uri; + if (_socketPath != "") + { + uri = new Uri("unix://" + _socketPath); + } + else + { + var protocol = "http"; + + if (_useTLS) + { + protocol = "https"; + } + + uri = new Uri(protocol + "://" + _host + ":" + _port); + } + return uri; + } + } +} \ No newline at end of file diff --git a/src/OpenFeature.Contrib.Providers.Flagd/FlagdProvider.cs b/src/OpenFeature.Contrib.Providers.Flagd/FlagdProvider.cs index bcab5830..cb37bae4 100644 --- a/src/OpenFeature.Contrib.Providers.Flagd/FlagdProvider.cs +++ b/src/OpenFeature.Contrib.Providers.Flagd/FlagdProvider.cs @@ -12,8 +12,9 @@ using Metadata = OpenFeature.Model.Metadata; using Value = OpenFeature.Model.Value; using ProtoValue = Google.Protobuf.WellKnownTypes.Value; -using System.Net.Http; using System.Net.Sockets; +using System.Net.Http; +using System.Collections.Generic; namespace OpenFeature.Contrib.Providers.Flagd { @@ -22,44 +23,44 @@ namespace OpenFeature.Contrib.Providers.Flagd /// public sealed class FlagdProvider : FeatureProvider { + static int EventStreamRetryBaseBackoff = 1; + private readonly FlagdConfig _config; private readonly Service.ServiceClient _client; private readonly Metadata _providerMetadata = new Metadata("flagd Provider"); + private readonly ICache _cache; + private int _eventStreamRetries; + private int _eventStreamRetryBackoff = EventStreamRetryBaseBackoff; + + private System.Threading.Mutex _mtx; + /// /// Constructor of the provider. This constructor uses the value of the following /// environment variables to initialise its client: - /// FLAGD_HOST - The host name of the flagd server (default="localhost") - /// FLAGD_PORT - The port of the flagd server (default="8013") - /// FLAGD_TLS - Determines whether to use https or not (default="false") - /// FLAGD_SOCKET_PATH - Path to the unix socket (default="") + /// FLAGD_HOST - The host name of the flagd server (default="localhost") + /// FLAGD_PORT - The port of the flagd server (default="8013") + /// FLAGD_TLS - Determines whether to use https or not (default="false") + /// FLAGD_SOCKET_PATH - Path to the unix socket (default="") + /// FLAGD_CACHE - Enable or disable the cache (default="false") + /// FLAGD_MAX_CACHE_SIZE - The maximum size of the cache (default="10") + /// FLAGD_MAX_EVENT_STREAM_RETRIES - The maximum amount of retries for establishing the EventStream /// public FlagdProvider() { - var flagdHost = Environment.GetEnvironmentVariable("FLAGD_HOST") ?? "localhost"; - var flagdPort = Environment.GetEnvironmentVariable("FLAGD_PORT") ?? "8013"; - var flagdUseTLSStr = Environment.GetEnvironmentVariable("FLAGD_TLS") ?? "false"; - var flagdSocketPath = Environment.GetEnvironmentVariable("FLAGD_SOCKET_PATH") ?? ""; + _config = new FlagdConfig(); + _client = buildClientForPlatform(_config.GetUri()); - Uri uri; - if (flagdSocketPath != "") - { - uri = new Uri("unix://" + flagdSocketPath); - } - else - { - var protocol = "http"; - var useTLS = bool.Parse(flagdUseTLSStr); + _mtx = new System.Threading.Mutex(); - if (useTLS) + if (_config.CacheEnabled) + { + _cache = new LRUCache(_config.MaxCacheSize); + Task.Run(async () => { - protocol = "https"; - } - - uri = new Uri(protocol + "://" + flagdHost + ":" + flagdPort); + await HandleEvents(); + }); } - - _client = buildClientForPlatform(uri); } /// @@ -74,13 +75,27 @@ public FlagdProvider(Uri url) throw new ArgumentNullException(nameof(url)); } + _mtx = new System.Threading.Mutex(); + _client = buildClientForPlatform(url); } + // just for testing, internal but visible in tests - internal FlagdProvider(Service.ServiceClient client) + internal FlagdProvider(Service.ServiceClient client, FlagdConfig config, ICache cache = null) { + _mtx = new System.Threading.Mutex(); _client = client; + _config = config; + _cache = cache; + + if (_config.CacheEnabled) + { + Task.Run(async () => + { + await HandleEvents(); + }); + } } /// @@ -110,7 +125,7 @@ public static string GetProviderName() /// A ResolutionDetails object containing the value of your flag public override async Task> ResolveBooleanValue(string flagKey, bool defaultValue, EvaluationContext context = null) { - return await ResolveValue(async contextStruct => + return await ResolveValue(flagKey, async contextStruct => { var resolveBooleanResponse = await _client.ResolveBooleanAsync(new ResolveBooleanRequest { @@ -120,7 +135,7 @@ public override async Task> ResolveBooleanValue(string f return new ResolutionDetails( flagKey: flagKey, - value: resolveBooleanResponse.Value, + value: (bool)resolveBooleanResponse.Value, reason: resolveBooleanResponse.Reason, variant: resolveBooleanResponse.Variant ); @@ -136,7 +151,7 @@ public override async Task> ResolveBooleanValue(string f /// A ResolutionDetails object containing the value of your flag public override async Task> ResolveStringValue(string flagKey, string defaultValue, EvaluationContext context = null) { - return await ResolveValue(async contextStruct => + return await ResolveValue(flagKey, async contextStruct => { var resolveStringResponse = await _client.ResolveStringAsync(new ResolveStringRequest { @@ -162,7 +177,7 @@ public override async Task> ResolveStringValue(string /// A ResolutionDetails object containing the value of your flag public override async Task> ResolveIntegerValue(string flagKey, int defaultValue, EvaluationContext context = null) { - return await ResolveValue(async contextStruct => + return await ResolveValue(flagKey, async contextStruct => { var resolveIntResponse = await _client.ResolveIntAsync(new ResolveIntRequest { @@ -188,7 +203,7 @@ public override async Task> ResolveIntegerValue(string fl /// A ResolutionDetails object containing the value of your flag public override async Task> ResolveDoubleValue(string flagKey, double defaultValue, EvaluationContext context = null) { - return await ResolveValue(async contextStruct => + return await ResolveValue(flagKey, async contextStruct => { var resolveDoubleResponse = await _client.ResolveFloatAsync(new ResolveFloatRequest { @@ -214,7 +229,7 @@ public override async Task> ResolveDoubleValue(string /// A ResolutionDetails object containing the value of your flag public override async Task> ResolveStructureValue(string flagKey, Value defaultValue, EvaluationContext context = null) { - return await ResolveValue(async contextStruct => + return await ResolveValue(flagKey, async contextStruct => { var resolveObjectResponse = await _client.ResolveObjectAsync(new ResolveObjectRequest { @@ -231,12 +246,26 @@ public override async Task> ResolveStructureValue(strin }, context); } - private async Task> ResolveValue(Func>> resolveDelegate, EvaluationContext context = null) + private async Task> ResolveValue(string flagKey, Func>> resolveDelegate, EvaluationContext context = null) { try { + if (_config.CacheEnabled) + { + var value = _cache.TryGet(flagKey); + + if (value != null) + { + return (ResolutionDetails)value; + } + } var result = await resolveDelegate.Invoke(ConvertToContext(context)); + if (result.Reason.Equals("STATIC") && _config.CacheEnabled) + { + _cache.Add(flagKey, result); + } + return result; } catch (RpcException e) @@ -265,6 +294,92 @@ private FeatureProviderException GetOFException(Grpc.Core.RpcException e) } } + private async Task HandleEvents() + { + while (_eventStreamRetries < _config.MaxEventStreamRetries) + { + var call = _client.EventStream(new Empty()); + try + { + // Read the response stream asynchronously + while (await call.ResponseStream.MoveNext()) + { + var response = call.ResponseStream.Current; + + switch (response.Type.ToLower()) + { + case "configuration_change": + HandleConfigurationChangeEvent(response.Data); + break; + case "provider_ready": + HandleProviderReadyEvent(); + break; + default: + break; + } + } + } + catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable) + { + // Handle the dropped connection by reconnecting and retrying the stream + await HandleErrorEvent(); + } + } + } + + private void HandleConfigurationChangeEvent(Struct data) + { + // if we don't have a cache, we don't need to remove anything + if (!_config.CacheEnabled || !data.Fields.ContainsKey("flags")) + { + return; + } + + try + { + if (data.Fields.TryGetValue("flags", out ProtoValue val)) + { + if (val.KindCase == ProtoValue.KindOneofCase.StructValue) + { + val.StructValue.Fields.ToList().ForEach(flag => + { + _cache.Delete(flag.Key); + }); + } + var structVal = val.StructValue; + } + } + catch (Exception) + { + // purge the cache if we could not handle the configuration change event + _cache.Purge(); + } + + } + + private void HandleProviderReadyEvent() + { + _mtx.WaitOne(); + _eventStreamRetries = 0; + _eventStreamRetryBackoff = EventStreamRetryBaseBackoff; + _mtx.ReleaseMutex(); + _cache.Purge(); + } + + private async Task HandleErrorEvent() + { + _mtx.WaitOne(); + _eventStreamRetries++; + + if (_eventStreamRetries > _config.MaxEventStreamRetries) + { + return; + } + _eventStreamRetryBackoff = _eventStreamRetryBackoff * 2; + _mtx.ReleaseMutex(); + await Task.Delay(_eventStreamRetryBackoff * 1000); + } + /// /// ConvertToContext converts the given EvaluationContext to a Struct. /// @@ -393,7 +508,7 @@ private static Service.ServiceClient buildClientForPlatform(Uri url) #if NET462 return new Service.ServiceClient(GrpcChannel.ForAddress(url, new GrpcChannelOptions { - HttpHandler = new WinHttpHandler() + HttpHandler = new WinHttpHandler(), })); #else return new Service.ServiceClient(GrpcChannel.ForAddress(url)); @@ -412,7 +527,7 @@ private static Service.ServiceClient buildClientForPlatform(Uri url) // see https://learn.microsoft.com/en-us/aspnet/core/grpc/interprocess-uds?view=aspnetcore-7.0 for more details return new Service.ServiceClient(GrpcChannel.ForAddress("http://localhost", new GrpcChannelOptions { - HttpHandler = socketsHttpHandler + HttpHandler = socketsHttpHandler, })); #endif // unix socket support is not available in this dotnet version diff --git a/src/OpenFeature.Contrib.Providers.Flagd/README.md b/src/OpenFeature.Contrib.Providers.Flagd/README.md index 8351a0af..470af5e1 100644 --- a/src/OpenFeature.Contrib.Providers.Flagd/README.md +++ b/src/OpenFeature.Contrib.Providers.Flagd/README.md @@ -76,12 +76,15 @@ namespace OpenFeatureTestApp The URI of the flagd server to which the `flagd Provider` connects to can either be passed directly to the constructor, or be configured using the following environment variables: -| Option name | Environment variable name | Type | Default | Values | -|------------------|--------------------------------|---------|-----------| ------------- | -| host | FLAGD_HOST | string | localhost | | -| port | FLAGD_PORT | number | 8013 | | -| tls | FLAGD_TLS | boolean | false | | -| unix socket path | FLAGD_SOCKET_PATH | string | | | +| Option name | Environment variable name | Type | Default | Values | +|------------------------------|--------------------------------|---------|-----------| ------------- | +| host | FLAGD_HOST | string | localhost | | +| port | FLAGD_PORT | number | 8013 | | +| tls | FLAGD_TLS | boolean | false | | +| unix socket path | FLAGD_SOCKET_PATH | string | | | +| Caching | FLAGD_CACHE | string | | LRU | +| Maximum cache size | FLAGD_MAX_CACHE_SIZE | number | 10 | | +| Maximum event stream retries | FLAGD_MAX_EVENT_STREAM_RETRIES | number | 3 | | Note that if `FLAGD_SOCKET_PATH` is set, this value takes precedence, and the other variables (`FLAGD_HOST`, `FLAGD_PORT`, `FLAGD_TLS`) are disregarded. diff --git a/test/OpenFeature.Contrib.Providers.Flagd.Test/CacheTest.cs b/test/OpenFeature.Contrib.Providers.Flagd.Test/CacheTest.cs new file mode 100644 index 00000000..490d0640 --- /dev/null +++ b/test/OpenFeature.Contrib.Providers.Flagd.Test/CacheTest.cs @@ -0,0 +1,181 @@ +using Xunit; + +namespace OpenFeature.Contrib.Providers.Flagd.Test +{ + + public class UnitTestLRUCache + { + [Fact] + public void TestCacheSetGet() + { + int capacity = 5; + var cache = new LRUCache(capacity); + + cache.Add("my-key", "my-value"); + + var value = cache.TryGet("my-key"); + Assert.Equal("my-value", value); + } + [Fact] + public void TestCacheCapacity() + { + int capacity = 5; + var cache = new LRUCache(capacity); + + var tasks = new System.Collections.Generic.List(); + + for (int i = 0; i < capacity; i++) + { + cache.Add($"key-{i}", $"value-{i}"); + } + + var e = tasks.GetEnumerator(); + while (e.MoveNext()) + { + e.Current.Wait(); + } + + string value; + // verify that we can retrieve all items + for (int i = 0; i < capacity; i++) + { + value = cache.TryGet($"key-{i}"); + + Assert.Equal($"value-{i}", value); + } + + // add another item - now the least recently used item ("key-0") should be replaced + cache.Add("new-item", "new-value"); + + value = cache.TryGet("key-0"); + Assert.Null(value); + + value = cache.TryGet("new-item"); + Assert.Equal("new-value", value); + } + + [Fact] + public void TestCacheCapacityMultiThreaded() + { + int capacity = 5; + var cache = new LRUCache(capacity); + + var tasks = new System.Collections.Generic.List(); + + var counter = 0; + for (int i = 0; i < capacity; i++) + { + tasks.Add(System.Threading.Tasks.Task.Run(() => + { + var id = System.Threading.Interlocked.Increment(ref counter); + cache.Add($"key-{id}", $"value-{id}"); + })); + //cache.Add($"key-{i}", $"value-{i}"); + } + + var e = tasks.GetEnumerator(); + while (e.MoveNext()) + { + e.Current.Wait(); + } + + string value; + // verify that we can retrieve all items + for (int i = 1; i <= capacity; i++) + { + value = cache.TryGet($"key-{i}"); + + Assert.Equal($"value-{i}", value); + } + + // add another item - now the least recently used item ("key-0") should be replaced + cache.Add("new-item", "new-value"); + + value = cache.TryGet("key-0"); + Assert.Null(value); + + value = cache.TryGet("new-item"); + Assert.Equal("new-value", value); + } + + [Fact] + public void TestCacheDeleteOnlyItem() + { + int capacity = 5; + var cache = new LRUCache(capacity); + + var key = "my-key"; + var expectedValue = "my-value"; + + cache.Add(key, expectedValue); + + string value; + + value = cache.TryGet(key); + Assert.Equal(expectedValue, value); + + cache.Delete(key); + + value = cache.TryGet(key); + Assert.Null(value); + + // check that we can add something again after deleting the last item + cache.Add(key, expectedValue); + Assert.Equal(expectedValue, cache.TryGet(key)); + } + + [Fact] + public void TestCacheDeleteHead() + { + int capacity = 5; + var cache = new LRUCache(capacity); + + + cache.Add("tail", "tail"); + cache.Add("middle", "middle"); + cache.Add("head", "head"); + + cache.Delete("head"); + + Assert.Null(cache.TryGet("head")); + Assert.Equal("middle", cache.TryGet("middle")); + Assert.Equal("tail", cache.TryGet("tail")); + } + + [Fact] + public void TestCacheDeleteMiddle() + { + int capacity = 5; + var cache = new LRUCache(capacity); + + + cache.Add("tail", "tail"); + cache.Add("middle", "middle"); + cache.Add("head", "head"); + + cache.Delete("middle"); + + Assert.Null(cache.TryGet("middle")); + Assert.Equal("head", cache.TryGet("head")); + Assert.Equal("tail", cache.TryGet("tail")); + } + + [Fact] + public void TestPurge() + { + int capacity = 5; + var cache = new LRUCache(capacity); + + + cache.Add("tail", "tail"); + cache.Add("middle", "middle"); + cache.Add("head", "head"); + + cache.Purge(); + + Assert.Null(cache.TryGet("head")); + Assert.Null(cache.TryGet("middle")); + Assert.Null(cache.TryGet("tail")); + } + } +} \ No newline at end of file diff --git a/test/OpenFeature.Contrib.Providers.Flagd.Test/FlagdConfigTest.cs b/test/OpenFeature.Contrib.Providers.Flagd.Test/FlagdConfigTest.cs new file mode 100644 index 00000000..95d6e744 --- /dev/null +++ b/test/OpenFeature.Contrib.Providers.Flagd.Test/FlagdConfigTest.cs @@ -0,0 +1,72 @@ +using Xunit; + +namespace OpenFeature.Contrib.Providers.Flagd.Test +{ + public class UnitTestFlagdConfig + { + [Fact] + public void TestFlagdConfigDefault() + { + CleanEnvVars(); + var config = new FlagdConfig(); + + Assert.False(config.CacheEnabled); + Assert.Equal(new System.Uri("http://localhost:8013"), config.GetUri()); + } + + [Fact] + public void TestFlagdConfigUseTLS() + { + CleanEnvVars(); + System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarTLS, "true"); + + var config = new FlagdConfig(); + + Assert.Equal(new System.Uri("https://localhost:8013"), config.GetUri()); + } + + [Fact] + public void TestFlagdConfigUnixSocket() + { + CleanEnvVars(); + System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarSocketPath, "tmp.sock"); + + var config = new FlagdConfig(); + + Assert.Equal(new System.Uri("unix://tmp.sock/"), config.GetUri()); + } + + [Fact] + public void TestFlagdConfigEnabledCacheDefaultCacheSize() + { + CleanEnvVars(); + System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarCache, "LRU"); + + var config = new FlagdConfig(); + + Assert.True(config.CacheEnabled); + Assert.Equal(FlagdConfig.CacheSizeDefault, config.MaxCacheSize); + } + + [Fact] + public void TestFlagdConfigEnabledCacheApplyCacheSize() + { + CleanEnvVars(); + System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarCache, "LRU"); + System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarMaxCacheSize, "20"); + + var config = new FlagdConfig(); + + Assert.True(config.CacheEnabled); + Assert.Equal(20, config.MaxCacheSize); + } + + private void CleanEnvVars() + { + System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarTLS, ""); + System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarSocketPath, ""); + System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarCache, ""); + System.Environment.SetEnvironmentVariable(FlagdConfig.EnvVarMaxCacheSize, ""); + } + } +} \ No newline at end of file diff --git a/test/OpenFeature.Contrib.Providers.Flagd.Test/FlagdProviderTest.cs b/test/OpenFeature.Contrib.Providers.Flagd.Test/FlagdProviderTest.cs index 7a62ee46..c7681987 100644 --- a/test/OpenFeature.Contrib.Providers.Flagd.Test/FlagdProviderTest.cs +++ b/test/OpenFeature.Contrib.Providers.Flagd.Test/FlagdProviderTest.cs @@ -5,6 +5,10 @@ using Google.Protobuf.WellKnownTypes; using OpenFeature.Error; using ProtoValue = Google.Protobuf.WellKnownTypes.Value; +using System.Collections.Generic; +using System.Linq; +using OpenFeature.Model; +using System.Threading; namespace OpenFeature.Contrib.Providers.Flagd.Test { @@ -45,7 +49,7 @@ public void TestResolveBooleanValue() It.IsAny(), null, null, System.Threading.CancellationToken.None)) .Returns(grpcResp); - var flagdProvider = new FlagdProvider(mockGrpcClient.Object); + var flagdProvider = new FlagdProvider(mockGrpcClient.Object, new FlagdConfig()); // resolve with default set to false to make sure we return what the grpc server gives us var val = flagdProvider.ResolveBooleanValue("my-key", false, null); @@ -72,7 +76,7 @@ public void TestResolveStringValue() It.IsAny(), null, null, System.Threading.CancellationToken.None)) .Returns(grpcResp); - var flagdProvider = new FlagdProvider(mockGrpcClient.Object); + var flagdProvider = new FlagdProvider(mockGrpcClient.Object, new FlagdConfig()); var val = flagdProvider.ResolveStringValue("my-key", "", null); @@ -98,7 +102,7 @@ public void TestResolveIntegerValue() It.IsAny(), null, null, System.Threading.CancellationToken.None)) .Returns(grpcResp); - var flagdProvider = new FlagdProvider(mockGrpcClient.Object); + var flagdProvider = new FlagdProvider(mockGrpcClient.Object, new FlagdConfig()); var val = flagdProvider.ResolveIntegerValue("my-key", 0, null); @@ -124,7 +128,7 @@ public void TestResolveDoubleValue() It.IsAny(), null, null, System.Threading.CancellationToken.None)) .Returns(grpcResp); - var flagdProvider = new FlagdProvider(mockGrpcClient.Object); + var flagdProvider = new FlagdProvider(mockGrpcClient.Object, new FlagdConfig()); var val = flagdProvider.ResolveDoubleValue("my-key", 0.0, null); @@ -155,7 +159,7 @@ public void TestResolveStructureValue() It.IsAny(), null, null, System.Threading.CancellationToken.None)) .Returns(grpcResp); - var flagdProvider = new FlagdProvider(mockGrpcClient.Object); + var flagdProvider = new FlagdProvider(mockGrpcClient.Object, new FlagdConfig()); var val = flagdProvider.ResolveStructureValue("my-key", null, null); @@ -180,7 +184,7 @@ public void TestResolveFlagNotFound() It.IsAny(), null, null, System.Threading.CancellationToken.None)) .Returns(grpcResp); - var flagdProvider = new FlagdProvider(mockGrpcClient.Object); + var flagdProvider = new FlagdProvider(mockGrpcClient.Object, new FlagdConfig()); // make sure the correct exception is thrown Assert.ThrowsAsync(async () => @@ -216,7 +220,7 @@ public void TestResolveGrpcHostUnavailable() It.IsAny(), null, null, System.Threading.CancellationToken.None)) .Returns(grpcResp); - var flagdProvider = new FlagdProvider(mockGrpcClient.Object); + var flagdProvider = new FlagdProvider(mockGrpcClient.Object, new FlagdConfig()); // make sure the correct exception is thrown Assert.ThrowsAsync(async () => @@ -252,7 +256,7 @@ public void TestResolveTypeMismatch() It.IsAny(), null, null, System.Threading.CancellationToken.None)) .Returns(grpcResp); - var flagdProvider = new FlagdProvider(mockGrpcClient.Object); + var flagdProvider = new FlagdProvider(mockGrpcClient.Object, new FlagdConfig()); // make sure the correct exception is thrown Assert.ThrowsAsync(async () => @@ -288,7 +292,7 @@ public void TestResolveUnknownError() It.IsAny(), null, null, System.Threading.CancellationToken.None)) .Returns(grpcResp); - var flagdProvider = new FlagdProvider(mockGrpcClient.Object); + var flagdProvider = new FlagdProvider(mockGrpcClient.Object, new FlagdConfig()); // make sure the correct exception is thrown Assert.ThrowsAsync(async () => @@ -305,5 +309,246 @@ public void TestResolveUnknownError() } }); } + + [Fact] + public void TestCache() + { + var resp = new ResolveBooleanResponse(); + resp.Value = true; + resp.Reason = "STATIC"; + + var grpcResp = new AsyncUnaryCall( + System.Threading.Tasks.Task.FromResult(resp), + System.Threading.Tasks.Task.FromResult(new Grpc.Core.Metadata()), + () => Status.DefaultSuccess, + () => new Grpc.Core.Metadata(), + () => { }); + + var mockGrpcClient = new Mock(); + mockGrpcClient + .Setup(m => m.ResolveBooleanAsync( + It.IsAny(), null, null, System.Threading.CancellationToken.None)) + .Returns(grpcResp); + + var asyncStreamReader = new Mock>(); + + var l = new List + { + new EventStreamResponse{ + Type = "provider_ready" + } + }; + + var enumerator = l.GetEnumerator(); + + // create an autoResetEvent which we will wait for in our test verification + var _autoResetEvent = new AutoResetEvent(false); + + asyncStreamReader.Setup(a => a.MoveNext(It.IsAny())).ReturnsAsync(() => enumerator.MoveNext()); + asyncStreamReader.Setup(a => a.Current).Returns(() => + { + // set the autoResetEvent since this path should be the last one that's reached in the background task + _autoResetEvent.Set(); + return enumerator.Current; + }); + + var grpcEventStreamResp = new AsyncServerStreamingCall( + asyncStreamReader.Object, + null, + null, + null, + null, + null + ); + + mockGrpcClient + .Setup(m => m.EventStream( + It.IsAny(), null, null, System.Threading.CancellationToken.None)) + .Returns(grpcEventStreamResp); + + var mockCache = new Mock>(); + mockCache.Setup(c => c.TryGet(It.Is(s => s == "my-key"))).Returns(() => null); + mockCache.Setup(c => c.Add(It.Is(s => s == "my-key"), It.IsAny())); + + + var config = new FlagdConfig(); + config.CacheEnabled = true; + config.MaxEventStreamRetries = 1; + var flagdProvider = new FlagdProvider(mockGrpcClient.Object, config, mockCache.Object); + + // resolve with default set to false to make sure we return what the grpc server gives us + var val = flagdProvider.ResolveBooleanValue("my-key", false, null); + Assert.True(val.Result.Value); + + Assert.True(_autoResetEvent.WaitOne(10000)); + mockCache.VerifyAll(); + mockGrpcClient.VerifyAll(); + } + + [Fact] + public void TestCacheHit() + { + + var mockGrpcClient = new Mock(); + + var asyncStreamReader = new Mock>(); + + var l = new List + { + new EventStreamResponse{ + Type = "provider_ready" + } + }; + + var enumerator = l.GetEnumerator(); + + // create an autoResetEvent which we will wait for in our test verification + AutoResetEvent _autoResetEvent = new AutoResetEvent(false); + + asyncStreamReader.Setup(a => a.MoveNext(It.IsAny())).ReturnsAsync(() => enumerator.MoveNext()); + asyncStreamReader.Setup(a => a.Current).Returns(() => + { + // set the autoResetEvent since this path should be the last one that's reached in the background task + _autoResetEvent.Set(); + return enumerator.Current; + }); + + var grpcEventStreamResp = new AsyncServerStreamingCall( + asyncStreamReader.Object, + null, + null, + null, + null, + null + ); + + mockGrpcClient + .Setup(m => m.EventStream( + It.IsAny(), null, null, System.Threading.CancellationToken.None)) + .Returns(grpcEventStreamResp); + + var mockCache = new Mock>(); + mockCache.Setup(c => c.TryGet(It.Is(s => s == "my-key"))).Returns( + () => new ResolutionDetails("my-key", true) + ); + + var config = new FlagdConfig(); + config.CacheEnabled = true; + config.MaxEventStreamRetries = 1; + var flagdProvider = new FlagdProvider(mockGrpcClient.Object, config, mockCache.Object); + + // resolve with default set to false to make sure we return what the grpc server gives us + var val = flagdProvider.ResolveBooleanValue("my-key", false, null); + Assert.True(val.Result.Value); + + // wait for the autoReset event to be fired before verifying the invocation of the mocked functions + Assert.True(_autoResetEvent.WaitOne(10000)); + mockCache.VerifyAll(); + mockGrpcClient.VerifyAll(); + } + + [Fact] + public void TestCacheInvalidation() + { + var resp = new ResolveBooleanResponse(); + resp.Value = true; + resp.Reason = "STATIC"; + + var grpcResp = new AsyncUnaryCall( + System.Threading.Tasks.Task.FromResult(resp), + System.Threading.Tasks.Task.FromResult(new Grpc.Core.Metadata()), + () => Status.DefaultSuccess, + () => new Grpc.Core.Metadata(), + () => { }); + + var mockGrpcClient = new Mock(); + mockGrpcClient + .Setup(m => m.ResolveBooleanAsync( + It.IsAny(), null, null, System.Threading.CancellationToken.None)) + .Returns(grpcResp); + + var asyncStreamReader = new Mock>(); + + var configurationChangeData = new Struct(); + var changedFlag = new Struct(); + changedFlag.Fields.Add("my-key", new Google.Protobuf.WellKnownTypes.Value()); + configurationChangeData.Fields.Add("flags", ProtoValue.ForStruct(changedFlag)); + + + var firstCall = true; + + asyncStreamReader.Setup(a => a.MoveNext(It.IsAny())).ReturnsAsync(() => true); + // as long as we did not send our first request to the provider, we will not send a configuration_change event + // after the value of the flag has been retrieved the first time, we will send a configuration_change to test if the + // item is deleted from the cache + + // create an autoResetEvent which we will wait for in our test verification + AutoResetEvent _autoResetEvent = new AutoResetEvent(false); + + asyncStreamReader.Setup(a => a.Current).Returns( + () => + { + if (firstCall) + { + return new EventStreamResponse + { + Type = "provider_ready" + }; + } + return new EventStreamResponse + { + Type = "configuration_change", + Data = configurationChangeData + }; + } + ); + + var grpcEventStreamResp = new AsyncServerStreamingCall( + asyncStreamReader.Object, + null, + null, + null, + null, + null + ); + + mockGrpcClient + .Setup(m => m.EventStream( + It.IsAny(), null, null, System.Threading.CancellationToken.None)) + .Returns(() => + { + return grpcEventStreamResp; + }); + + var mockCache = new Mock>(); + mockCache.Setup(c => c.TryGet(It.Is(s => s == "my-key"))).Returns(() => null); + mockCache.Setup(c => c.Add(It.Is(s => s == "my-key"), It.IsAny())); + mockCache.Setup(c => c.Delete(It.Is(s => s == "my-key"))).Callback(() => + { + // set the autoResetEvent since this path should be the last one that's reached in the background task + _autoResetEvent.Set(); + }); + + + var config = new FlagdConfig(); + config.CacheEnabled = true; + config.MaxEventStreamRetries = 1; + var flagdProvider = new FlagdProvider(mockGrpcClient.Object, config, mockCache.Object); + + // resolve with default set to false to make sure we return what the grpc server gives us + var val = flagdProvider.ResolveBooleanValue("my-key", false, null); + Assert.True(val.Result.Value); + + // set firstCall to true to make the mock EventStream return a configuration_change event + firstCall = false; + + val = flagdProvider.ResolveBooleanValue("my-key", false, null); + Assert.True(val.Result.Value); + + Assert.True(_autoResetEvent.WaitOne(10000)); + + mockCache.VerifyAll(); + mockGrpcClient.VerifyAll(); + } } }