diff --git a/src/limiter/base_limiter.go b/src/limiter/base_limiter.go new file mode 100644 index 00000000..a5a3c1be --- /dev/null +++ b/src/limiter/base_limiter.go @@ -0,0 +1,177 @@ +package limiter + +import ( + "github.com/coocood/freecache" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/assert" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/utils" + logger "github.com/sirupsen/logrus" + "math" + "math/rand" +) + +type BaseRateLimiter struct { + timeSource utils.TimeSource + JitterRand *rand.Rand + ExpirationJitterMaxSeconds int64 + cacheKeyGenerator CacheKeyGenerator + localCache *freecache.Cache + nearLimitRatio float32 +} + +type LimitInfo struct { + limit *config.RateLimit + limitBeforeIncrease uint32 + limitAfterIncrease uint32 + nearLimitThreshold uint32 + overLimitThreshold uint32 +} + +func NewRateLimitInfo(limit *config.RateLimit, limitBeforeIncrease uint32, limitAfterIncrease uint32, + nearLimitThreshold uint32, overLimitThreshold uint32) *LimitInfo { + return &LimitInfo{limit: limit, limitBeforeIncrease: limitBeforeIncrease, limitAfterIncrease: limitAfterIncrease, + nearLimitThreshold: nearLimitThreshold, overLimitThreshold: overLimitThreshold} +} + +// Generates cache keys for given rate limit request. Each cache key is represented by a concatenation of +// domain, descriptor and current timestamp. +func (this *BaseRateLimiter) GenerateCacheKeys(request *pb.RateLimitRequest, + limits []*config.RateLimit, hitsAddend uint32) []CacheKey { + assert.Assert(len(request.Descriptors) == len(limits)) + cacheKeys := make([]CacheKey, len(request.Descriptors)) + now := this.timeSource.UnixNow() + for i := 0; i < len(request.Descriptors); i++ { + // generateCacheKey() returns an empty string in the key if there is no limit + // so that we can keep the arrays all the same size. + cacheKeys[i] = this.cacheKeyGenerator.GenerateCacheKey(request.Domain, request.Descriptors[i], limits[i], now) + // Increase statistics for limits hit by their respective requests. + if limits[i] != nil { + limits[i].Stats.TotalHits.Add(uint64(hitsAddend)) + } + } + return cacheKeys +} + +// Returns `true` in case local cache is enabled and contains value for provided cache key, `false` otherwise. +func (this *BaseRateLimiter) IsOverLimitWithLocalCache(key string) bool { + if this.localCache != nil { + // Get returns the value or not found error. + _, err := this.localCache.Get([]byte(key)) + if err == nil { + return true + } + } + return false +} + +// Generates response descriptor status based on cache key, over the limit with local cache, over the limit and +// near the limit thresholds. Thresholds are checked in order and are mutually exclusive. +func (this *BaseRateLimiter) GetResponseDescriptorStatus(key string, limitInfo *LimitInfo, + isOverLimitWithLocalCache bool, hitsAddend uint32) *pb.RateLimitResponse_DescriptorStatus { + if key == "" { + return this.generateResponseDescriptorStatus(pb.RateLimitResponse_OK, + nil, 0) + } + if isOverLimitWithLocalCache { + limitInfo.limit.Stats.OverLimit.Add(uint64(hitsAddend)) + limitInfo.limit.Stats.OverLimitWithLocalCache.Add(uint64(hitsAddend)) + return this.generateResponseDescriptorStatus(pb.RateLimitResponse_OVER_LIMIT, + limitInfo.limit.Limit, 0) + } + var responseDescriptorStatus *pb.RateLimitResponse_DescriptorStatus + limitInfo.overLimitThreshold = limitInfo.limit.Limit.RequestsPerUnit + // The nearLimitThreshold is the number of requests that can be made before hitting the nearLimitRatio. + // We need to know it in both the OK and OVER_LIMIT scenarios. + limitInfo.nearLimitThreshold = uint32(math.Floor(float64(float32(limitInfo.overLimitThreshold) * this.nearLimitRatio))) + logger.Debugf("cache key: %s current: %d", key, limitInfo.limitAfterIncrease) + if limitInfo.limitAfterIncrease > limitInfo.overLimitThreshold { + responseDescriptorStatus = this.generateResponseDescriptorStatus(pb.RateLimitResponse_OVER_LIMIT, + limitInfo.limit.Limit, 0) + + checkOverLimitThreshold(limitInfo, hitsAddend) + + if this.localCache != nil { + // Set the TTL of the local_cache to be the entire duration. + // Since the cache_key gets changed once the time crosses over current time slot, the over-the-limit + // cache keys in local_cache lose effectiveness. + // For example, if we have an hour limit on all mongo connections, the cache key would be + // similar to mongo_1h, mongo_2h, etc. In the hour 1 (0h0m - 0h59m), the cache key is mongo_1h, we start + // to get ratelimited in the 50th minute, the ttl of local_cache will be set as 1 hour(0h50m-1h49m). + // In the time of 1h1m, since the cache key becomes different (mongo_2h), it won't get ratelimited. + err := this.localCache.Set([]byte(key), []byte{}, int(utils.UnitToDivider(limitInfo.limit.Limit.Unit))) + if err != nil { + logger.Errorf("Failing to set local cache key: %s", key) + } + } + } else { + responseDescriptorStatus = this.generateResponseDescriptorStatus(pb.RateLimitResponse_OK, + limitInfo.limit.Limit, limitInfo.overLimitThreshold-limitInfo.limitAfterIncrease) + + // The limit is OK but we additionally want to know if we are near the limit. + checkNearLimitThreshold(limitInfo, hitsAddend) + } + return responseDescriptorStatus +} + +func NewBaseRateLimit(timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, + localCache *freecache.Cache, nearLimitRatio float32) *BaseRateLimiter { + return &BaseRateLimiter{ + timeSource: timeSource, + JitterRand: jitterRand, + ExpirationJitterMaxSeconds: expirationJitterMaxSeconds, + cacheKeyGenerator: NewCacheKeyGenerator(), + localCache: localCache, + nearLimitRatio: nearLimitRatio, + } +} + +func checkOverLimitThreshold(limitInfo *LimitInfo, hitsAddend uint32) { + // Increase over limit statistics. Because we support += behavior for increasing the limit, we need to + // assess if the entire hitsAddend were over the limit. That is, if the limit's value before adding the + // N hits was over the limit, then all the N hits were over limit. + // Otherwise, only the difference between the current limit value and the over limit threshold + // were over limit hits. + if limitInfo.limitBeforeIncrease >= limitInfo.overLimitThreshold { + limitInfo.limit.Stats.OverLimit.Add(uint64(hitsAddend)) + } else { + limitInfo.limit.Stats.OverLimit.Add(uint64(limitInfo.limitAfterIncrease - limitInfo.overLimitThreshold)) + + // If the limit before increase was below the over limit value, then some of the hits were + // in the near limit range. + limitInfo.limit.Stats.NearLimit.Add(uint64(limitInfo.overLimitThreshold - + utils.Max(limitInfo.nearLimitThreshold, limitInfo.limitBeforeIncrease))) + } +} + +func checkNearLimitThreshold(limitInfo *LimitInfo, hitsAddend uint32) { + if limitInfo.limitAfterIncrease > limitInfo.nearLimitThreshold { + // Here we also need to assess which portion of the hitsAddend were in the near limit range. + // If all the hits were over the nearLimitThreshold, then all hits are near limit. Otherwise, + // only the difference between the current limit value and the near limit threshold were near + // limit hits. + if limitInfo.limitBeforeIncrease >= limitInfo.nearLimitThreshold { + limitInfo.limit.Stats.NearLimit.Add(uint64(hitsAddend)) + } else { + limitInfo.limit.Stats.NearLimit.Add(uint64(limitInfo.limitAfterIncrease - limitInfo.nearLimitThreshold)) + } + } +} + +func (this *BaseRateLimiter) generateResponseDescriptorStatus(responseCode pb.RateLimitResponse_Code, + limit *pb.RateLimitResponse_RateLimit, limitRemaining uint32) *pb.RateLimitResponse_DescriptorStatus { + if limit != nil { + return &pb.RateLimitResponse_DescriptorStatus{ + Code: responseCode, + CurrentLimit: limit, + LimitRemaining: limitRemaining, + DurationUntilReset: utils.CalculateReset(limit, this.timeSource), + } + } else { + return &pb.RateLimitResponse_DescriptorStatus{ + Code: responseCode, + CurrentLimit: limit, + LimitRemaining: limitRemaining, + } + } +} diff --git a/src/memcached/cache_impl.go b/src/memcached/cache_impl.go index 52f8fae4..10a87044 100644 --- a/src/memcached/cache_impl.go +++ b/src/memcached/cache_impl.go @@ -17,7 +17,6 @@ package memcached import ( "context" - "math" "math/rand" "strconv" "sync" @@ -31,7 +30,6 @@ import ( pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - "github.com/envoyproxy/ratelimit/src/assert" "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/settings" @@ -47,17 +45,11 @@ type rateLimitMemcacheImpl struct { localCache *freecache.Cache waitGroup sync.WaitGroup nearLimitRatio float32 + baseRateLimiter *limiter.BaseRateLimiter } var _ limiter.RateLimitCache = (*rateLimitMemcacheImpl)(nil) -func max(a uint32, b uint32) uint32 { - if a > b { - return a - } - return b -} - func (this *rateLimitMemcacheImpl) DoLimit( ctx context.Context, request *pb.RateLimitRequest, @@ -66,22 +58,10 @@ func (this *rateLimitMemcacheImpl) DoLimit( logger.Debugf("starting cache lookup") // request.HitsAddend could be 0 (default value) if not specified by the caller in the Ratelimit request. - hitsAddend := max(1, request.HitsAddend) - - // First build a list of all cache keys that we are actually going to hit. generateCacheKey() - // returns an empty string in the key if there is no limit so that we can keep the arrays - // all the same size. - assert.Assert(len(request.Descriptors) == len(limits)) - cacheKeys := make([]limiter.CacheKey, len(request.Descriptors)) - now := this.timeSource.UnixNow() - for i := 0; i < len(request.Descriptors); i++ { - cacheKeys[i] = this.cacheKeyGenerator.GenerateCacheKey(request.Domain, request.Descriptors[i], limits[i], now) - - // Increase statistics for limits hit by their respective requests. - if limits[i] != nil { - limits[i].Stats.TotalHits.Add(uint64(hitsAddend)) - } - } + hitsAddend := utils.Max(1, request.HitsAddend) + + // First build a list of all cache keys that we are actually going to hit. + cacheKeys := this.baseRateLimiter.GenerateCacheKeys(request, limits, hitsAddend) isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) @@ -92,14 +72,11 @@ func (this *rateLimitMemcacheImpl) DoLimit( continue } - if this.localCache != nil { - // Get returns the value or not found error. - _, err := this.localCache.Get([]byte(cacheKey.Key)) - if err == nil { - isOverLimitWithLocalCache[i] = true - logger.Debugf("cache key is over the limit: %s", cacheKey.Key) - continue - } + // Check if key is over the limit in local cache. + if this.baseRateLimiter.IsOverLimitWithLocalCache(cacheKey.Key) { + isOverLimitWithLocalCache[i] = true + logger.Debugf("cache key is over the limit: %s", cacheKey.Key) + continue } logger.Debugf("looking up cache key: %s", cacheKey.Key) @@ -121,28 +98,6 @@ func (this *rateLimitMemcacheImpl) DoLimit( } for i, cacheKey := range cacheKeys { - if cacheKey.Key == "" { - responseDescriptorStatuses[i] = - &pb.RateLimitResponse_DescriptorStatus{ - Code: pb.RateLimitResponse_OK, - CurrentLimit: nil, - LimitRemaining: 0, - } - continue - } - - if isOverLimitWithLocalCache[i] { - responseDescriptorStatuses[i] = - &pb.RateLimitResponse_DescriptorStatus{ - Code: pb.RateLimitResponse_OVER_LIMIT, - CurrentLimit: limits[i].Limit, - LimitRemaining: 0, - DurationUntilReset: utils.CalculateReset(limits[i].Limit, this.timeSource), - } - limits[i].Stats.OverLimit.Add(uint64(hitsAddend)) - limits[i].Stats.OverLimitWithLocalCache.Add(uint64(hitsAddend)) - continue - } rawMemcacheValue, ok := memcacheValues[cacheKey.Key] var limitBeforeIncrease uint32 @@ -157,70 +112,11 @@ func (this *rateLimitMemcacheImpl) DoLimit( } limitAfterIncrease := limitBeforeIncrease + hitsAddend - overLimitThreshold := limits[i].Limit.RequestsPerUnit - // The nearLimitThreshold is the number of requests that can be made before hitting the NearLimitRatio. - // We need to know it in both the OK and OVER_LIMIT scenarios. - nearLimitThreshold := uint32(math.Floor(float64(float32(overLimitThreshold) * this.nearLimitRatio))) - - logger.Debugf("cache key: %s current: %d", cacheKey.Key, limitAfterIncrease) - if limitAfterIncrease > overLimitThreshold { - responseDescriptorStatuses[i] = - &pb.RateLimitResponse_DescriptorStatus{ - Code: pb.RateLimitResponse_OVER_LIMIT, - CurrentLimit: limits[i].Limit, - LimitRemaining: 0, - DurationUntilReset: utils.CalculateReset(limits[i].Limit, this.timeSource), - } - // Increase over limit statistics. Because we support += behavior for increasing the limit, we need to - // assess if the entire hitsAddend were over the limit. That is, if the limit's value before adding the - // N hits was over the limit, then all the N hits were over limit. - // Otherwise, only the difference between the current limit value and the over limit threshold - // were over limit hits. - if limitBeforeIncrease >= overLimitThreshold { - limits[i].Stats.OverLimit.Add(uint64(hitsAddend)) - } else { - limits[i].Stats.OverLimit.Add(uint64(limitAfterIncrease - overLimitThreshold)) - - // If the limit before increase was below the over limit value, then some of the hits were - // in the near limit range. - limits[i].Stats.NearLimit.Add(uint64(overLimitThreshold - max(nearLimitThreshold, limitBeforeIncrease))) - } - if this.localCache != nil { - // Set the TTL of the local_cache to be the entire duration. - // Since the cache_key gets changed once the time crosses over current time slot, the over-the-limit - // cache keys in local_cache lose effectiveness. - // For example, if we have an hour limit on all mongo connections, the cache key would be - // similar to mongo_1h, mongo_2h, etc. In the hour 1 (0h0m - 0h59m), the cache key is mongo_1h, we start - // to get ratelimited in the 50th minute, the ttl of local_cache will be set as 1 hour(0h50m-1h49m). - // In the time of 1h1m, since the cache key becomes different (mongo_2h), it won't get ratelimited. - err := this.localCache.Set([]byte(cacheKey.Key), []byte{}, int(utils.UnitToDivider(limits[i].Limit.Unit))) - if err != nil { - logger.Errorf("Failing to set local cache key: %s", cacheKey.Key) - } - } - } else { - responseDescriptorStatuses[i] = - &pb.RateLimitResponse_DescriptorStatus{ - Code: pb.RateLimitResponse_OK, - CurrentLimit: limits[i].Limit, - LimitRemaining: overLimitThreshold - limitAfterIncrease, - DurationUntilReset: utils.CalculateReset(limits[i].Limit, this.timeSource), - } + limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0) - // The limit is OK but we additionally want to know if we are near the limit. - if limitAfterIncrease > nearLimitThreshold { - // Here we also need to assess which portion of the hitsAddend were in the near limit range. - // If all the hits were over the nearLimitThreshold, then all hits are near limit. Otherwise, - // only the difference between the current limit value and the near limit threshold were near - // limit hits. - if limitBeforeIncrease >= nearLimitThreshold { - limits[i].Stats.NearLimit.Add(uint64(hitsAddend)) - } else { - limits[i].Stats.NearLimit.Add(uint64(limitAfterIncrease - nearLimitThreshold)) - } - } - } + responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key, + limitInfo, isOverLimitWithLocalCache[i], hitsAddend) } this.waitGroup.Add(1) @@ -229,7 +125,8 @@ func (this *rateLimitMemcacheImpl) DoLimit( return responseDescriptorStatuses } -func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, isOverLimitWithLocalCache []bool, limits []*config.RateLimit, hitsAddend uint64) { +func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, isOverLimitWithLocalCache []bool, + limits []*config.RateLimit, hitsAddend uint64) { defer this.waitGroup.Done() for i, cacheKey := range cacheKeys { if cacheKey.Key == "" || isOverLimitWithLocalCache[i] { @@ -243,7 +140,7 @@ func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, i expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) } - // Need to add instead of increment + // Need to add instead of increment. err = this.client.Add(&memcache.Item{ Key: cacheKey.Key, Value: []byte(strconv.FormatUint(hitsAddend, 10)), @@ -272,7 +169,8 @@ func (this *rateLimitMemcacheImpl) Flush() { this.waitGroup.Wait() } -func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, scope stats.Scope, nearLimitRatio float32) limiter.RateLimitCache { +func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRand *rand.Rand, + expirationJitterMaxSeconds int64, localCache *freecache.Cache, scope stats.Scope, nearLimitRatio float32) limiter.RateLimitCache { return &rateLimitMemcacheImpl{ client: client, timeSource: timeSource, @@ -281,10 +179,12 @@ func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRan expirationJitterMaxSeconds: expirationJitterMaxSeconds, localCache: localCache, nearLimitRatio: nearLimitRatio, + baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio), } } -func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand, localCache *freecache.Cache, scope stats.Scope) limiter.RateLimitCache { +func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand, + localCache *freecache.Cache, scope stats.Scope) limiter.RateLimitCache { return NewRateLimitCacheImpl( memcache.New(s.MemcacheHostPort), timeSource, diff --git a/src/redis/fixed_cache_impl.go b/src/redis/fixed_cache_impl.go index 6ecb5309..bd2502ac 100644 --- a/src/redis/fixed_cache_impl.go +++ b/src/redis/fixed_cache_impl.go @@ -1,12 +1,10 @@ package redis import ( - "math" "math/rand" "github.com/coocood/freecache" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" - "github.com/envoyproxy/ratelimit/src/assert" "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/utils" @@ -20,20 +18,8 @@ type fixedRateLimitCacheImpl struct { // If this client is nil, then the Cache will use the client for all // limits regardless of unit. If this client is not nil, then it // is used for limits that have a SECOND unit. - perSecondClient Client - timeSource utils.TimeSource - jitterRand *rand.Rand - expirationJitterMaxSeconds int64 - cacheKeyGenerator limiter.CacheKeyGenerator - localCache *freecache.Cache - nearLimitRatio float32 -} - -func max(a uint32, b uint32) uint32 { - if a > b { - return a - } - return b + perSecondClient Client + baseRateLimiter *limiter.BaseRateLimiter } func pipelineAppend(client Client, pipeline *Pipeline, key string, hitsAddend uint32, result *uint32, expirationSeconds int64) { @@ -48,24 +34,11 @@ func (this *fixedRateLimitCacheImpl) DoLimit( logger.Debugf("starting cache lookup") - // request.HitsAddend could be 0 (default value) if not specified by the caller in the Ratelimit request. - hitsAddend := max(1, request.HitsAddend) - - // First build a list of all cache keys that we are actually going to hit. GenerateCacheKey() - // returns an empty string in the key if there is no limit so that we can keep the arrays - // all the same size. - assert.Assert(len(request.Descriptors) == len(limits)) - cacheKeys := make([]limiter.CacheKey, len(request.Descriptors)) - now := this.timeSource.UnixNow() - for i := 0; i < len(request.Descriptors); i++ { - cacheKeys[i] = this.cacheKeyGenerator.GenerateCacheKey( - request.Domain, request.Descriptors[i], limits[i], now) - - // Increase statistics for limits hit by their respective requests. - if limits[i] != nil { - limits[i].Stats.TotalHits.Add(uint64(hitsAddend)) - } - } + // request.HitsAddend could be 0 (default value) if not specified by the caller in the RateLimit request. + hitsAddend := utils.Max(1, request.HitsAddend) + + // First build a list of all cache keys that we are actually going to hit. + cacheKeys := this.baseRateLimiter.GenerateCacheKeys(request, limits, hitsAddend) isOverLimitWithLocalCache := make([]bool, len(request.Descriptors)) results := make([]uint32, len(request.Descriptors)) @@ -77,21 +50,18 @@ func (this *fixedRateLimitCacheImpl) DoLimit( continue } - if this.localCache != nil { - // Get returns the value or not found error. - _, err := this.localCache.Get([]byte(cacheKey.Key)) - if err == nil { - isOverLimitWithLocalCache[i] = true - logger.Debugf("cache key is over the limit: %s", cacheKey.Key) - continue - } + // Check if key is over the limit in local cache. + if this.baseRateLimiter.IsOverLimitWithLocalCache(cacheKey.Key) { + isOverLimitWithLocalCache[i] = true + logger.Debugf("cache key is over the limit: %s", cacheKey.Key) + continue } logger.Debugf("looking up cache key: %s", cacheKey.Key) expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit) - if this.expirationJitterMaxSeconds > 0 { - expirationSeconds += this.jitterRand.Int63n(this.expirationJitterMaxSeconds) + if this.baseRateLimiter.ExpirationJitterMaxSeconds > 0 { + expirationSeconds += this.baseRateLimiter.JitterRand.Int63n(this.baseRateLimiter.ExpirationJitterMaxSeconds) } // Use the perSecondConn if it is not nil and the cacheKey represents a per second Limit. @@ -119,95 +89,15 @@ func (this *fixedRateLimitCacheImpl) DoLimit( responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus, len(request.Descriptors)) for i, cacheKey := range cacheKeys { - if cacheKey.Key == "" { - responseDescriptorStatuses[i] = - &pb.RateLimitResponse_DescriptorStatus{ - Code: pb.RateLimitResponse_OK, - CurrentLimit: nil, - LimitRemaining: 0, - } - continue - } - - if isOverLimitWithLocalCache[i] { - responseDescriptorStatuses[i] = - &pb.RateLimitResponse_DescriptorStatus{ - Code: pb.RateLimitResponse_OVER_LIMIT, - CurrentLimit: limits[i].Limit, - LimitRemaining: 0, - DurationUntilReset: utils.CalculateReset(limits[i].Limit, this.timeSource), - } - limits[i].Stats.OverLimit.Add(uint64(hitsAddend)) - limits[i].Stats.OverLimitWithLocalCache.Add(uint64(hitsAddend)) - continue - } limitAfterIncrease := results[i] limitBeforeIncrease := limitAfterIncrease - hitsAddend - overLimitThreshold := limits[i].Limit.RequestsPerUnit - // The nearLimitThreshold is the number of requests that can be made before hitting the NearLimitRatio. - // We need to know it in both the OK and OVER_LIMIT scenarios. - nearLimitThreshold := uint32(math.Floor(float64(float32(overLimitThreshold) * this.nearLimitRatio))) - - logger.Debugf("cache key: %s current: %d", cacheKey.Key, limitAfterIncrease) - if limitAfterIncrease > overLimitThreshold { - responseDescriptorStatuses[i] = - &pb.RateLimitResponse_DescriptorStatus{ - Code: pb.RateLimitResponse_OVER_LIMIT, - CurrentLimit: limits[i].Limit, - LimitRemaining: 0, - DurationUntilReset: utils.CalculateReset(limits[i].Limit, this.timeSource), - } - - // Increase over limit statistics. Because we support += behavior for increasing the limit, we need to - // assess if the entire hitsAddend were over the limit. That is, if the limit's value before adding the - // N hits was over the limit, then all the N hits were over limit. - // Otherwise, only the difference between the current limit value and the over limit threshold - // were over limit hits. - if limitBeforeIncrease >= overLimitThreshold { - limits[i].Stats.OverLimit.Add(uint64(hitsAddend)) - } else { - limits[i].Stats.OverLimit.Add(uint64(limitAfterIncrease - overLimitThreshold)) - - // If the limit before increase was below the over limit value, then some of the hits were - // in the near limit range. - limits[i].Stats.NearLimit.Add(uint64(overLimitThreshold - max(nearLimitThreshold, limitBeforeIncrease))) - } - if this.localCache != nil { - // Set the TTL of the local_cache to be the entire duration. - // Since the cache_key gets changed once the time crosses over current time slot, the over-the-limit - // cache keys in local_cache lose effectiveness. - // For example, if we have an hour limit on all mongo connections, the cache key would be - // similar to mongo_1h, mongo_2h, etc. In the hour 1 (0h0m - 0h59m), the cache key is mongo_1h, we start - // to get ratelimited in the 50th minute, the ttl of local_cache will be set as 1 hour(0h50m-1h49m). - // In the time of 1h1m, since the cache key becomes different (mongo_2h), it won't get ratelimited. - err := this.localCache.Set([]byte(cacheKey.Key), []byte{}, int(utils.UnitToDivider(limits[i].Limit.Unit))) - if err != nil { - logger.Errorf("Failing to set local cache key: %s", cacheKey.Key) - } - } - } else { - responseDescriptorStatuses[i] = - &pb.RateLimitResponse_DescriptorStatus{ - Code: pb.RateLimitResponse_OK, - CurrentLimit: limits[i].Limit, - LimitRemaining: overLimitThreshold - limitAfterIncrease, - DurationUntilReset: utils.CalculateReset(limits[i].Limit, this.timeSource), - } - - // The limit is OK but we additionally want to know if we are near the limit. - if limitAfterIncrease > nearLimitThreshold { - // Here we also need to assess which portion of the hitsAddend were in the near limit range. - // If all the hits were over the nearLimitThreshold, then all hits are near limit. Otherwise, - // only the difference between the current limit value and the near limit threshold were near - // limit hits. - if limitBeforeIncrease >= nearLimitThreshold { - limits[i].Stats.NearLimit.Add(uint64(hitsAddend)) - } else { - limits[i].Stats.NearLimit.Add(uint64(limitAfterIncrease - nearLimitThreshold)) - } - } - } + + limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0) + + responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key, + limitInfo, isOverLimitWithLocalCache[i], hitsAddend) + } return responseDescriptorStatuses @@ -216,15 +106,11 @@ func (this *fixedRateLimitCacheImpl) DoLimit( // Flush() is a no-op with redis since quota reads and updates happen synchronously. func (this *fixedRateLimitCacheImpl) Flush() {} -func NewFixedRateLimitCacheImpl(client Client, perSecondClient Client, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32) limiter.RateLimitCache { +func NewFixedRateLimitCacheImpl(client Client, perSecondClient Client, timeSource utils.TimeSource, + jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32) limiter.RateLimitCache { return &fixedRateLimitCacheImpl{ - client: client, - perSecondClient: perSecondClient, - timeSource: timeSource, - jitterRand: jitterRand, - expirationJitterMaxSeconds: expirationJitterMaxSeconds, - cacheKeyGenerator: limiter.NewCacheKeyGenerator(), - localCache: localCache, - nearLimitRatio: nearLimitRatio, + client: client, + perSecondClient: perSecondClient, + baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio), } } diff --git a/src/utils/utilities.go b/src/utils/utilities.go index 8bfbb641..e6029f5b 100644 --- a/src/utils/utilities.go +++ b/src/utils/utilities.go @@ -34,3 +34,10 @@ func CalculateReset(currentLimit *pb.RateLimitResponse_RateLimit, timeSource Tim now := timeSource.UnixNow() return &duration.Duration{Seconds: sec - now%sec} } + +func Max(a uint32, b uint32) uint32 { + if a > b { + return a + } + return b +} diff --git a/test/limiter/base_limiter_test.go b/test/limiter/base_limiter_test.go new file mode 100644 index 00000000..0694ca00 --- /dev/null +++ b/test/limiter/base_limiter_test.go @@ -0,0 +1,125 @@ +package limiter + +import ( + "github.com/coocood/freecache" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + "github.com/envoyproxy/ratelimit/test/common" + mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" + "github.com/golang/mock/gomock" + stats "github.com/lyft/gostats" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" +) + +func TestGenerateCacheKeys(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + timeSource := mock_utils.NewMockTimeSource(controller) + jitterSource := mock_utils.NewMockJitterRandSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + timeSource.EXPECT().UnixNow().Return(int64(1234)) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, rand.New(jitterSource), 3600, nil, 0.8) + request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + assert.Equal(uint64(0), limits[0].Stats.TotalHits.Value()) + cacheKeys := baseRateLimit.GenerateCacheKeys(request, limits, 1) + assert.Equal(1, len(cacheKeys)) + assert.Equal("domain_key_value_1234", cacheKeys[0].Key) + assert.Equal(uint64(1), limits[0].Stats.TotalHits.Value()) +} + +func TestOverLimitWithLocalCache(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + localCache := freecache.NewCache(100) + localCache.Set([]byte("key"), []byte("value"), 100) + baseRateLimit := limiter.NewBaseRateLimit(nil, nil, 3600, localCache, 0.8) + // Returns true, as local cache contains over limit value for the key. + assert.Equal(true, baseRateLimit.IsOverLimitWithLocalCache("key")) +} + +func TestNoOverLimitWithLocalCache(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + baseRateLimit := limiter.NewBaseRateLimit(nil, nil, 3600, nil, 0.8) + // Returns false, as local cache is nil. + assert.Equal(false, baseRateLimit.IsOverLimitWithLocalCache("domain_key_value_1234")) + localCache := freecache.NewCache(100) + baseRateLimitWithLocalCache := limiter.NewBaseRateLimit(nil, nil, 3600, localCache, 0.8) + // Returns false, as local cache does not contain value for cache key. + assert.Equal(false, baseRateLimitWithLocalCache.IsOverLimitWithLocalCache("domain_key_value_1234")) +} + +func TestGetResponseStatusEmptyKey(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + baseRateLimit := limiter.NewBaseRateLimit(nil, nil, 3600, nil, 0.8) + responseStatus := baseRateLimit.GetResponseDescriptorStatus("", nil, false, 1) + assert.Equal(pb.RateLimitResponse_OK, responseStatus.GetCode()) + assert.Equal(uint32(0), responseStatus.GetLimitRemaining()) +} + +func TestGetResponseStatusOverLimitWithLocalCache(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + timeSource := mock_utils.NewMockTimeSource(controller) + timeSource.EXPECT().UnixNow().Return(int64(1234)) + statsStore := stats.NewStore(stats.NewNullSink(), false) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8) + limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 6, 4, 5) + // As `isOverLimitWithLocalCache` is passed as `true`, immediate response is returned with no checks of the limits. + responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, true, 2) + assert.Equal(pb.RateLimitResponse_OVER_LIMIT, responseStatus.GetCode()) + assert.Equal(uint32(0), responseStatus.GetLimitRemaining()) + assert.Equal(limits[0].Limit, responseStatus.GetCurrentLimit()) + assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(2), limits[0].Stats.OverLimitWithLocalCache.Value()) +} + +func TestGetResponseStatusOverLimit(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + timeSource := mock_utils.NewMockTimeSource(controller) + timeSource.EXPECT().UnixNow().Return(int64(1234)) + statsStore := stats.NewStore(stats.NewNullSink(), false) + localCache := freecache.NewCache(100) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, localCache, 0.8) + limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 7, 4, 5) + responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, false, 1) + assert.Equal(pb.RateLimitResponse_OVER_LIMIT, responseStatus.GetCode()) + assert.Equal(uint32(0), responseStatus.GetLimitRemaining()) + assert.Equal(limits[0].Limit, responseStatus.GetCurrentLimit()) + result, _ := localCache.Get([]byte("key")) + // Local cache should have been populated with over the limit key. + assert.Equal("", string(result)) + assert.Equal(uint64(2), limits[0].Stats.OverLimit.Value()) + assert.Equal(uint64(1), limits[0].Stats.NearLimit.Value()) +} + +func TestGetResponseStatusBelowLimit(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + timeSource := mock_utils.NewMockTimeSource(controller) + timeSource.EXPECT().UnixNow().Return(int64(1234)) + statsStore := stats.NewStore(stats.NewNullSink(), false) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8) + limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, "key_value", statsStore)} + limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 6, 9, 10) + responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, false, 1) + assert.Equal(pb.RateLimitResponse_OK, responseStatus.GetCode()) + assert.Equal(uint32(4), responseStatus.GetLimitRemaining()) + assert.Equal(uint64(0), limits[0].Stats.NearLimit.Value()) + assert.Equal(limits[0].Limit, responseStatus.GetCurrentLimit()) +}