diff --git a/src/redis/cache_impl.go b/src/redis/cache_impl.go index 0b0a45b4a..30890786a 100644 --- a/src/redis/cache_impl.go +++ b/src/redis/cache_impl.go @@ -1,6 +1,7 @@ package redis import ( + "io" "math/rand" "github.com/coocood/freecache" @@ -12,15 +13,18 @@ import ( "github.com/envoyproxy/ratelimit/src/utils" ) -func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, statsManager stats.Manager) limiter.RateLimitCache { +func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, statsManager stats.Manager) (limiter.RateLimitCache, io.Closer) { + closer := &utils.MultiCloser{} var perSecondPool Client if s.RedisPerSecond { perSecondPool = NewClientImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondSocketType, s.RedisPerSecondType, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize, s.RedisPerSecondPipelineWindow, s.RedisPerSecondPipelineLimit, s.RedisTlsConfig, s.RedisHealthCheckActiveConnection, srv) + closer.Closers = append(closer.Closers, perSecondPool) } otherPool := NewClientImpl(srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisSocketType, s.RedisType, s.RedisUrl, s.RedisPoolSize, s.RedisPipelineWindow, s.RedisPipelineLimit, s.RedisTlsConfig, s.RedisHealthCheckActiveConnection, srv) + closer.Closers = append(closer.Closers, otherPool) return NewFixedRateLimitCacheImpl( otherPool, @@ -33,5 +37,5 @@ func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freeca s.CacheKeyPrefix, statsManager, s.StopCacheKeyIncrementWhenOverlimit, - ) + ), closer } diff --git a/src/service_cmd/runner/runner.go b/src/service_cmd/runner/runner.go index f645e58fd..fcd79d698 100644 --- a/src/service_cmd/runner/runner.go +++ b/src/service_cmd/runner/runner.go @@ -32,10 +32,11 @@ import ( ) type Runner struct { - statsManager stats.Manager - settings settings.Settings - srv server.Server - mu sync.Mutex + statsManager stats.Manager + settings settings.Settings + srv server.Server + mu sync.Mutex + ratelimitCloser io.Closer } func NewRunner(s settings.Settings) Runner { @@ -80,7 +81,7 @@ func (runner *Runner) GetStatsStore() gostats.Store { return runner.statsManager.GetStatsStore() } -func createLimiter(srv server.Server, s settings.Settings, localCache *freecache.Cache, statsManager stats.Manager) limiter.RateLimitCache { +func createLimiter(srv server.Server, s settings.Settings, localCache *freecache.Cache, statsManager stats.Manager) (limiter.RateLimitCache, io.Closer) { switch s.BackendType { case "redis", "": return redis.NewRateLimiterCacheImplFromSettings( @@ -99,7 +100,7 @@ func createLimiter(srv server.Server, s settings.Settings, localCache *freecache rand.New(utils.NewLockedSource(time.Now().Unix())), localCache, srv.Scope(), - statsManager) + statsManager), &utils.MultiCloser{} // memcache client can't closed default: logger.Fatalf("Invalid setting for BackendType: %s", s.BackendType) panic("This line should not be reachable") @@ -147,8 +148,11 @@ func (runner *Runner) Run() { runner.srv = srv runner.mu.Unlock() + limiter, limiterCloser := createLimiter(srv, s, localCache, runner.statsManager) + runner.ratelimitCloser = limiterCloser + service := ratelimit.NewService( - createLimiter(srv, s, localCache, runner.statsManager), + limiter, srv.Provider(), runner.statsManager, srv.HealthChecker(), @@ -184,4 +188,8 @@ func (runner *Runner) Stop() { if srv != nil { srv.Stop() } + + if runner.ratelimitCloser != nil { + _ = runner.ratelimitCloser.Close() + } } diff --git a/src/utils/multi_closer.go b/src/utils/multi_closer.go new file mode 100644 index 000000000..fead3f6b1 --- /dev/null +++ b/src/utils/multi_closer.go @@ -0,0 +1,18 @@ +package utils + +import ( + "errors" + "io" +) + +type MultiCloser struct { + Closers []io.Closer +} + +func (m *MultiCloser) Close() error { + var e error + for _, closer := range m.Closers { + e = errors.Join(closer.Close()) + } + return e +}