From 011c8f8007d32ba83a7d6335e1b1fdd36ae807d6 Mon Sep 17 00:00:00 2001 From: Lucas Lemos Date: Tue, 23 Jul 2024 18:00:37 -0300 Subject: [PATCH] :sparkles: feat: Add Max Func to Limiter Middleware (#3070) * feat: add max calculator to limiter middleware * docs: update docs including the new parameter * refactor: add new line before go code in docs * fix: use crypto/rand instead of math/rand on tests * test: add new test with zero set as limit * fix: repeated tests failing when generating random limits * fix: wrong type of MaxCalculator in docs * feat: include max calculator in limiter_sliding * refactor: rename MaxCalculator to MaxFunc * docs: update docs with MaxFunc parameter * tests: rename tests and add test for limiter sliding --- docs/middleware/limiter.md | 22 +++++ middleware/limiter/config.go | 12 +++ middleware/limiter/limiter_fixed.go | 14 +-- middleware/limiter/limiter_sliding.go | 10 +- middleware/limiter/limiter_test.go | 126 ++++++++++++++++++++++++++ 5 files changed, 174 insertions(+), 10 deletions(-) diff --git a/docs/middleware/limiter.md b/docs/middleware/limiter.md index edccec20d2..ab97620595 100644 --- a/docs/middleware/limiter.md +++ b/docs/middleware/limiter.md @@ -43,6 +43,9 @@ app.Use(limiter.New(limiter.Config{ return c.IP() == "127.0.0.1" }, Max: 20, + MaxFunc: func(c fiber.Ctx) int { + return 20 + }, Expiration: 30 * time.Second, KeyGenerator: func(c fiber.Ctx) string { return c.Get("x-forwarded-for") @@ -75,12 +78,28 @@ weightOfPreviousWindow = previous window's amount request * (whenNewWindow / Exp rate = weightOfPreviousWindow + current window's amount request. ``` +## Dynamic limit + +You can also calculate the limit dynamically using the MaxFunc parameter. It's a function that receives the request's context as a parameter and allow you to calculate a different limit for each request separately. + +Example: + +```go +app.Use(limiter.New(limiter.Config{ + MaxFunc: func(c fiber.Ctx) int { + return getUserLimit(ctx.Param("id")) + }, + Expiration: 30 * time.Second, +})) +``` + ## Config | Property | Type | Description | Default | |:-----------------------|:--------------------------|:--------------------------------------------------------------------------------------------|:-----------------------------------------| | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` | | Max | `int` | Max number of recent connections during `Expiration` seconds before sending a 429 response. | 5 | +| MaxFunc | `func(fiber.Ctx) int` | A function to calculate the max number of recent connections during `Expiration` seconds before sending a 429 response. | A function which returns the cfg.Max | | KeyGenerator | `func(fiber.Ctx) string` | KeyGenerator allows you to generate custom keys, by default c.IP() is used. | A function using c.IP() as the default | | Expiration | `time.Duration` | Expiration is the time on how long to keep records of requests in memory. | 1 * time.Minute | | LimitReached | `fiber.Handler` | LimitReached is called when a request hits the limit. | A function sending 429 response | @@ -101,6 +120,9 @@ A custom store can be used if it implements the `Storage` interface - more detai ```go var ConfigDefault = Config{ Max: 5, + MaxFunc: func(c fiber.Ctx) int { + return 5 + }, Expiration: 1 * time.Minute, KeyGenerator: func(c fiber.Ctx) string { return c.IP() diff --git a/middleware/limiter/config.go b/middleware/limiter/config.go index 5570f96a12..461d21a7dd 100644 --- a/middleware/limiter/config.go +++ b/middleware/limiter/config.go @@ -22,6 +22,13 @@ type Config struct { // Optional. Default: nil Next func(c fiber.Ctx) bool + // A function to dynamically calculate the max requests supported by the rate limiter middleware + // + // Default: func(c fiber.Ctx) int { + // return c.Max + // } + MaxFunc func(c fiber.Ctx) int + // KeyGenerator allows you to generate custom keys, by default c.IP() is used // // Default: func(c fiber.Ctx) string { @@ -101,5 +108,10 @@ func configDefault(config ...Config) Config { if cfg.LimiterMiddleware == nil { cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware } + if cfg.MaxFunc == nil { + cfg.MaxFunc = func(_ fiber.Ctx) int { + return cfg.Max + } + } return cfg } diff --git a/middleware/limiter/limiter_fixed.go b/middleware/limiter/limiter_fixed.go index 1e2a1aa0e5..42b08afaf6 100644 --- a/middleware/limiter/limiter_fixed.go +++ b/middleware/limiter/limiter_fixed.go @@ -15,7 +15,6 @@ func (FixedWindow) New(cfg Config) fiber.Handler { var ( // Limiter variables mux = &sync.RWMutex{} - max = strconv.Itoa(cfg.Max) expiration = uint64(cfg.Expiration.Seconds()) ) @@ -27,8 +26,11 @@ func (FixedWindow) New(cfg Config) fiber.Handler { // Return new handler return func(c fiber.Ctx) error { - // Don't execute middleware if Next returns true - if cfg.Next != nil && cfg.Next(c) { + // Generate max from generator, if no generator was provided the default value returned is 5 + max := cfg.MaxFunc(c) + + // Don't execute middleware if Next returns true or if the max is 0 + if (cfg.Next != nil && cfg.Next(c)) || max == 0 { return c.Next() } @@ -60,7 +62,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler { resetInSec := e.exp - ts // Set how many hits we have left - remaining := cfg.Max - e.currHits + remaining := max - e.currHits // Update storage manager.set(key, e, cfg.Expiration) @@ -68,7 +70,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler { // Unlock entry mux.Unlock() - // Check if hits exceed the cfg.Max + // Check if hits exceed the max if remaining < 0 { // Return response with Retry-After header // https://tools.ietf.org/html/rfc6584 @@ -96,7 +98,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler { } // We can continue, update RateLimit headers - c.Set(xRateLimitLimit, max) + c.Set(xRateLimitLimit, strconv.Itoa(max)) c.Set(xRateLimitRemaining, strconv.Itoa(remaining)) c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10)) diff --git a/middleware/limiter/limiter_sliding.go b/middleware/limiter/limiter_sliding.go index a98593476e..1fc3138b9b 100644 --- a/middleware/limiter/limiter_sliding.go +++ b/middleware/limiter/limiter_sliding.go @@ -16,7 +16,6 @@ func (SlidingWindow) New(cfg Config) fiber.Handler { var ( // Limiter variables mux = &sync.RWMutex{} - max = strconv.Itoa(cfg.Max) expiration = uint64(cfg.Expiration.Seconds()) ) @@ -28,8 +27,11 @@ func (SlidingWindow) New(cfg Config) fiber.Handler { // Return new handler return func(c fiber.Ctx) error { - // Don't execute middleware if Next returns true - if cfg.Next != nil && cfg.Next(c) { + // Generate max from generator, if no generator was provided the default value returned is 5 + max := cfg.MaxFunc(c) + + // Don't execute middleware if Next returns true or if the max is 0 + if (cfg.Next != nil && cfg.Next(c)) || max == 0 { return c.Next() } @@ -127,7 +129,7 @@ func (SlidingWindow) New(cfg Config) fiber.Handler { } // We can continue, update RateLimit headers - c.Set(xRateLimitLimit, max) + c.Set(xRateLimitLimit, strconv.Itoa(max)) c.Set(xRateLimitRemaining, strconv.Itoa(remaining)) c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10)) diff --git a/middleware/limiter/limiter_test.go b/middleware/limiter/limiter_test.go index ed4470e9a8..41bd2e8ddf 100644 --- a/middleware/limiter/limiter_test.go +++ b/middleware/limiter/limiter_test.go @@ -14,6 +14,132 @@ import ( "github.com/valyala/fasthttp" ) +// go test -run Test_Limiter_With_Max_Func_With_Zero -race -v +func Test_Limiter_With_Max_Func_With_Zero_And_Limiter_Sliding(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + MaxFunc: func(_ fiber.Ctx) int { return 0 }, + Expiration: 2 * time.Second, + SkipFailedRequests: false, + SkipSuccessfulRequests: false, + LimiterMiddleware: SlidingWindow{}, + })) + + app.Get("/:status", func(c fiber.Ctx) error { + if c.Params("status") == "fail" { + return c.SendStatus(400) + } + return c.SendStatus(200) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil)) + require.NoError(t, err) + require.Equal(t, 400, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + + time.Sleep(4*time.Second + 500*time.Millisecond) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) +} + +// go test -run Test_Limiter_With_Max_Func_With_Zero -race -v +func Test_Limiter_With_Max_Func_With_Zero(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + MaxFunc: func(_ fiber.Ctx) int { + return 0 + }, + Expiration: 2 * time.Second, + Storage: memory.New(), + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("Hello tester!") + }) + + var wg sync.WaitGroup + + for i := 0; i <= 4; i++ { + wg.Add(1) + go func(wg *sync.WaitGroup) { + defer wg.Done() + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, "Hello tester!", string(body)) + }(&wg) + } + + wg.Wait() + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) +} + +// go test -run Test_Limiter_With_Max_Func -race -v +func Test_Limiter_With_Max_Func(t *testing.T) { + t.Parallel() + app := fiber.New() + max := 10 + + app.Use(New(Config{ + MaxFunc: func(_ fiber.Ctx) int { + return max + }, + Expiration: 2 * time.Second, + Storage: memory.New(), + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("Hello tester!") + }) + + var wg sync.WaitGroup + + for i := 0; i <= max-1; i++ { + wg.Add(1) + go func(wg *sync.WaitGroup) { + defer wg.Done() + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, "Hello tester!", string(body)) + }(&wg) + } + + wg.Wait() + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + require.Equal(t, 429, resp.StatusCode) + + time.Sleep(3 * time.Second) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) +} + // go test -run Test_Limiter_Concurrency_Store -race -v func Test_Limiter_Concurrency_Store(t *testing.T) { t.Parallel()