Skip to content

Commit

Permalink
✨ feat: Add Max Func to Limiter Middleware (#3070)
Browse files Browse the repository at this point in the history
* feat: add max calculator to limiter middleware

* docs: update docs including the new parameter

* refactor: add new line before go code in docs

* fix: use crypto/rand instead of math/rand on tests

* test: add new test with zero set as limit

* fix: repeated tests failing when generating random limits

* fix: wrong type of MaxCalculator in docs

* feat: include max calculator in limiter_sliding

* refactor: rename MaxCalculator to MaxFunc

* docs: update docs with MaxFunc parameter

* tests: rename tests and add test for limiter sliding
  • Loading branch information
luk3skyw4lker authored Jul 23, 2024
1 parent 486304d commit 011c8f8
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 10 deletions.
22 changes: 22 additions & 0 deletions docs/middleware/limiter.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ app.Use(limiter.New(limiter.Config{
return c.IP() == "127.0.0.1"
},
Max: 20,
MaxFunc: func(c fiber.Ctx) int {
return 20
},
Expiration: 30 * time.Second,
KeyGenerator: func(c fiber.Ctx) string {
return c.Get("x-forwarded-for")
Expand Down Expand Up @@ -75,12 +78,28 @@ weightOfPreviousWindow = previous window's amount request * (whenNewWindow / Exp
rate = weightOfPreviousWindow + current window's amount request.
```

## Dynamic limit

You can also calculate the limit dynamically using the MaxFunc parameter. It's a function that receives the request's context as a parameter and allow you to calculate a different limit for each request separately.

Example:

```go
app.Use(limiter.New(limiter.Config{
MaxFunc: func(c fiber.Ctx) int {
return getUserLimit(ctx.Param("id"))
},
Expiration: 30 * time.Second,
}))
```

## Config

| Property | Type | Description | Default |
|:-----------------------|:--------------------------|:--------------------------------------------------------------------------------------------|:-----------------------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` |
| Max | `int` | Max number of recent connections during `Expiration` seconds before sending a 429 response. | 5 |
| MaxFunc | `func(fiber.Ctx) int` | A function to calculate the max number of recent connections during `Expiration` seconds before sending a 429 response. | A function which returns the cfg.Max |
| KeyGenerator | `func(fiber.Ctx) string` | KeyGenerator allows you to generate custom keys, by default c.IP() is used. | A function using c.IP() as the default |
| Expiration | `time.Duration` | Expiration is the time on how long to keep records of requests in memory. | 1 * time.Minute |
| LimitReached | `fiber.Handler` | LimitReached is called when a request hits the limit. | A function sending 429 response |
Expand All @@ -101,6 +120,9 @@ A custom store can be used if it implements the `Storage` interface - more detai
```go
var ConfigDefault = Config{
Max: 5,
MaxFunc: func(c fiber.Ctx) int {
return 5
},
Expiration: 1 * time.Minute,
KeyGenerator: func(c fiber.Ctx) string {
return c.IP()
Expand Down
12 changes: 12 additions & 0 deletions middleware/limiter/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ type Config struct {
// Optional. Default: nil
Next func(c fiber.Ctx) bool

// A function to dynamically calculate the max requests supported by the rate limiter middleware
//
// Default: func(c fiber.Ctx) int {
// return c.Max
// }
MaxFunc func(c fiber.Ctx) int

// KeyGenerator allows you to generate custom keys, by default c.IP() is used
//
// Default: func(c fiber.Ctx) string {
Expand Down Expand Up @@ -101,5 +108,10 @@ func configDefault(config ...Config) Config {
if cfg.LimiterMiddleware == nil {
cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
}
if cfg.MaxFunc == nil {
cfg.MaxFunc = func(_ fiber.Ctx) int {
return cfg.Max
}
}
return cfg
}
14 changes: 8 additions & 6 deletions middleware/limiter/limiter_fixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
var (
// Limiter variables
mux = &sync.RWMutex{}
max = strconv.Itoa(cfg.Max)
expiration = uint64(cfg.Expiration.Seconds())
)

Expand All @@ -27,8 +26,11 @@ func (FixedWindow) New(cfg Config) fiber.Handler {

// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
// Generate max from generator, if no generator was provided the default value returned is 5
max := cfg.MaxFunc(c)

// Don't execute middleware if Next returns true or if the max is 0
if (cfg.Next != nil && cfg.Next(c)) || max == 0 {
return c.Next()
}

Expand Down Expand Up @@ -60,15 +62,15 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
resetInSec := e.exp - ts

// Set how many hits we have left
remaining := cfg.Max - e.currHits
remaining := max - e.currHits

// Update storage
manager.set(key, e, cfg.Expiration)

// Unlock entry
mux.Unlock()

// Check if hits exceed the cfg.Max
// Check if hits exceed the max
if remaining < 0 {
// Return response with Retry-After header
// https://tools.ietf.org/html/rfc6584
Expand Down Expand Up @@ -96,7 +98,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler {
}

// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitLimit, strconv.Itoa(max))
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))

Expand Down
10 changes: 6 additions & 4 deletions middleware/limiter/limiter_sliding.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ func (SlidingWindow) New(cfg Config) fiber.Handler {
var (
// Limiter variables
mux = &sync.RWMutex{}
max = strconv.Itoa(cfg.Max)
expiration = uint64(cfg.Expiration.Seconds())
)

Expand All @@ -28,8 +27,11 @@ func (SlidingWindow) New(cfg Config) fiber.Handler {

// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
// Generate max from generator, if no generator was provided the default value returned is 5
max := cfg.MaxFunc(c)

// Don't execute middleware if Next returns true or if the max is 0
if (cfg.Next != nil && cfg.Next(c)) || max == 0 {
return c.Next()
}

Expand Down Expand Up @@ -127,7 +129,7 @@ func (SlidingWindow) New(cfg Config) fiber.Handler {
}

// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitLimit, strconv.Itoa(max))
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))

Expand Down
126 changes: 126 additions & 0 deletions middleware/limiter/limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,132 @@ import (
"github.com/valyala/fasthttp"
)

// go test -run Test_Limiter_With_Max_Func_With_Zero -race -v
func Test_Limiter_With_Max_Func_With_Zero_And_Limiter_Sliding(t *testing.T) {
t.Parallel()
app := fiber.New()

app.Use(New(Config{
MaxFunc: func(_ fiber.Ctx) int { return 0 },
Expiration: 2 * time.Second,
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
LimiterMiddleware: SlidingWindow{},
}))

app.Get("/:status", func(c fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", nil))
require.NoError(t, err)
require.Equal(t, 400, resp.StatusCode)

resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)

resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)

time.Sleep(4*time.Second + 500*time.Millisecond)

resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", nil))
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_With_Max_Func_With_Zero -race -v
func Test_Limiter_With_Max_Func_With_Zero(t *testing.T) {
t.Parallel()
app := fiber.New()

app.Use(New(Config{
MaxFunc: func(_ fiber.Ctx) int {
return 0
},
Expiration: 2 * time.Second,
Storage: memory.New(),
}))

app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello tester!")
})

var wg sync.WaitGroup

for i := 0; i <= 4; i++ {
wg.Add(1)
go func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
assert.NoError(t, err)
assert.Equal(t, fiber.StatusOK, resp.StatusCode)

body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Hello tester!", string(body))
}(&wg)
}

wg.Wait()

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_With_Max_Func -race -v
func Test_Limiter_With_Max_Func(t *testing.T) {
t.Parallel()
app := fiber.New()
max := 10

app.Use(New(Config{
MaxFunc: func(_ fiber.Ctx) int {
return max
},
Expiration: 2 * time.Second,
Storage: memory.New(),
}))

app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello tester!")
})

var wg sync.WaitGroup

for i := 0; i <= max-1; i++ {
wg.Add(1)
go func(wg *sync.WaitGroup) {
defer wg.Done()
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
assert.NoError(t, err)
assert.Equal(t, fiber.StatusOK, resp.StatusCode)

body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
assert.Equal(t, "Hello tester!", string(body))
}(&wg)
}

wg.Wait()

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, 429, resp.StatusCode)

time.Sleep(3 * time.Second)

resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_Concurrency_Store -race -v
func Test_Limiter_Concurrency_Store(t *testing.T) {
t.Parallel()
Expand Down

1 comment on commit 011c8f8

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.50.

Benchmark suite Current: 011c8f8 Previous: 87bb93e Ratio
Benchmark_Ctx_Send 7.126 ns/op 0 B/op 0 allocs/op 4.661 ns/op 0 B/op 0 allocs/op 1.53
Benchmark_Ctx_Send - ns/op 7.126 ns/op 4.661 ns/op 1.53
Benchmark_Utils_GetOffer/1_parameter 205.8 ns/op 0 B/op 0 allocs/op 136.1 ns/op 0 B/op 0 allocs/op 1.51
Benchmark_Utils_GetOffer/1_parameter - ns/op 205.8 ns/op 136.1 ns/op 1.51
Benchmark_Utils_getGroupPath - allocs/op 4 allocs/op 2 allocs/op 2
Benchmark_Middleware_BasicAuth - B/op 80 B/op 48 B/op 1.67
Benchmark_Middleware_BasicAuth - allocs/op 5 allocs/op 3 allocs/op 1.67
Benchmark_Middleware_BasicAuth_Upper - B/op 80 B/op 48 B/op 1.67
Benchmark_Middleware_BasicAuth_Upper - allocs/op 5 allocs/op 3 allocs/op 1.67
Benchmark_CORS_NewHandler - B/op 16 B/op 0 B/op +∞
Benchmark_CORS_NewHandler - allocs/op 1 allocs/op 0 allocs/op +∞
Benchmark_CORS_NewHandlerSingleOrigin - B/op 16 B/op 0 B/op +∞
Benchmark_CORS_NewHandlerSingleOrigin - allocs/op 1 allocs/op 0 allocs/op +∞
Benchmark_CORS_NewHandlerPreflight 1170 ns/op 104 B/op 5 allocs/op 759.2 ns/op 0 B/op 0 allocs/op 1.54
Benchmark_CORS_NewHandlerPreflight - ns/op 1170 ns/op 759.2 ns/op 1.54
Benchmark_CORS_NewHandlerPreflight - B/op 104 B/op 0 B/op +∞
Benchmark_CORS_NewHandlerPreflight - allocs/op 5 allocs/op 0 allocs/op +∞
Benchmark_CORS_NewHandlerPreflightSingleOrigin 1162 ns/op 104 B/op 5 allocs/op 757.5 ns/op 0 B/op 0 allocs/op 1.53
Benchmark_CORS_NewHandlerPreflightSingleOrigin - ns/op 1162 ns/op 757.5 ns/op 1.53
Benchmark_CORS_NewHandlerPreflightSingleOrigin - B/op 104 B/op 0 B/op +∞
Benchmark_CORS_NewHandlerPreflightSingleOrigin - allocs/op 5 allocs/op 0 allocs/op +∞
Benchmark_CORS_NewHandlerPreflightWildcard 1089 ns/op 104 B/op 5 allocs/op 691 ns/op 0 B/op 0 allocs/op 1.58
Benchmark_CORS_NewHandlerPreflightWildcard - ns/op 1089 ns/op 691 ns/op 1.58
Benchmark_CORS_NewHandlerPreflightWildcard - B/op 104 B/op 0 B/op +∞
Benchmark_CORS_NewHandlerPreflightWildcard - allocs/op 5 allocs/op 0 allocs/op +∞
Benchmark_Middleware_CSRF_Check - allocs/op 11 allocs/op 7 allocs/op 1.57
Benchmark_Middleware_CSRF_GenerateToken - B/op 510 B/op 326 B/op 1.56
Benchmark_Middleware_CSRF_GenerateToken - allocs/op 10 allocs/op 6 allocs/op 1.67

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.