Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: preserve rate limiters in memory across configuration reloads #1792

Merged
merged 2 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions cmd/serve_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ func serve(ctx context.Context) {
addr := net.JoinHostPort(config.API.Host, config.API.Port)
logrus.Infof("GoTrue API started on: %s", addr)

a := api.NewAPIWithVersion(config, db, utilities.Version)
opts := []api.Option{
api.NewLimiterOptions(config),
}
a := api.NewAPIWithVersion(config, db, utilities.Version, opts...)
ah := reloader.NewAtomicHandler(a)

baseCtx, baseCancel := context.WithCancel(context.Background())
Expand All @@ -74,7 +77,8 @@ func serve(ctx context.Context) {

fn := func(latestCfg *conf.GlobalConfiguration) {
log.Info("reloading api with new configuration")
latestAPI := api.NewAPIWithVersion(latestCfg, db, utilities.Version)
latestAPI := api.NewAPIWithVersion(
latestCfg, db, utilities.Version, opts...)
ah.Store(latestAPI)
}

Expand Down
121 changes: 40 additions & 81 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"regexp"
"time"

"github.com/didip/tollbooth/v5"
"github.com/didip/tollbooth/v5/limiter"
"github.com/rs/cors"
"github.com/sebest/xff"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -37,6 +35,8 @@ type API struct {

// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
overrideTime func() time.Time

limiterOpts *LimiterOptions
}

func (a *API) Now() time.Time {
Expand All @@ -48,8 +48,8 @@ func (a *API) Now() time.Time {
}

// NewAPI instantiates a new REST API
func NewAPI(globalConfig *conf.GlobalConfiguration, db *storage.Connection) *API {
return NewAPIWithVersion(globalConfig, db, defaultVersion)
func NewAPI(globalConfig *conf.GlobalConfiguration, db *storage.Connection, opt ...Option) *API {
return NewAPIWithVersion(globalConfig, db, defaultVersion, opt...)
}

func (a *API) deprecationNotices() {
Expand All @@ -67,9 +67,15 @@ func (a *API) deprecationNotices() {
}

// NewAPIWithVersion creates a new REST API using the specified version
func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string) *API {
func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string, opt ...Option) *API {
api := &API{config: globalConfig, db: db, version: version}

for _, o := range opt {
o.apply(api)
}
if api.limiterOpts == nil {
api.limiterOpts = NewLimiterOptions(globalConfig)
}
if api.config.Password.HIBP.Enabled {
httpClient := &http.Client{
// all HIBP API requests should finish quickly to avoid
Expand Down Expand Up @@ -134,18 +140,12 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.Get("/authorize", api.ExternalProviderRedirect)

sharedLimiter := api.limitEmailOrPhoneSentHandler()
sharedLimiter := api.limitEmailOrPhoneSentHandler(api.limiterOpts)
r.With(sharedLimiter).With(api.requireAdminCredentials).Post("/invite", api.Invite)
r.With(sharedLimiter).With(api.verifyCaptcha).Route("/signup", func(r *router) {
// rate limit per hour
limitAnonymousSignIns := tollbooth.NewLimiter(api.config.RateLimitAnonymousUsers/(60*60), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(api.config.RateLimitAnonymousUsers)).SetMethods([]string{"POST"})

limitSignups := tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

limitAnonymousSignIns := api.limiterOpts.AnonymousSignIns
limitSignups := api.limiterOpts.Signups
r.Post("/", func(w http.ResponseWriter, r *http.Request) error {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
Expand All @@ -172,47 +172,22 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
return api.Signup(w, r)
})
})
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.RateLimitTokenRefresh/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(api.verifyCaptcha).Post("/token", api.Token)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.RateLimitVerify/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).Route("/verify", func(r *router) {
r.With(api.limitHandler(api.limiterOpts.Recover)).
With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)

r.With(api.limitHandler(api.limiterOpts.Resend)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend)

r.With(api.limitHandler(api.limiterOpts.MagicLink)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)

r.With(api.limitHandler(api.limiterOpts.Otp)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp)

r.With(api.limitHandler(api.limiterOpts.Token)).
With(api.verifyCaptcha).Post("/token", api.Token)

r.With(api.limitHandler(api.limiterOpts.Verify)).Route("/verify", func(r *router) {
r.Get("/", api.Verify)
r.Post("/", api.Verify)
})
Expand All @@ -225,12 +200,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.With(api.requireAuthentication).Route("/user", func(r *router) {
r.Get("/", api.UserGet)
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).Put("/", api.UserUpdate)
r.With(api.limitHandler(api.limiterOpts.User)).
With(sharedLimiter).Put("/", api.UserUpdate)

r.Route("/identities", func(r *router) {
r.Use(api.requireManualLinkingEnabled)
Expand All @@ -245,37 +216,25 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
r.Route("/{factor_id}", func(r *router) {
r.Use(api.loadFactor)

r.With(api.limitHandler(
tollbooth.NewLimiter(api.config.MFA.RateLimitChallengeAndVerify/60, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Minute,
}).SetBurst(30))).Post("/verify", api.VerifyFactor)
r.With(api.limitHandler(
tollbooth.NewLimiter(api.config.MFA.RateLimitChallengeAndVerify/60, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Minute,
}).SetBurst(30))).Post("/challenge", api.ChallengeFactor)
r.With(api.limitHandler(api.limiterOpts.FactorVerify)).
Post("/verify", api.VerifyFactor)
r.With(api.limitHandler(api.limiterOpts.FactorChallenge)).
Post("/challenge", api.ChallengeFactor)
r.Delete("/", api.UnenrollFactor)

})
})

r.Route("/sso", func(r *router) {
r.Use(api.requireSAMLEnabled)
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.RateLimitSso/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(api.verifyCaptcha).Post("/", api.SingleSignOn)
r.With(api.limitHandler(api.limiterOpts.SSO)).
With(api.verifyCaptcha).Post("/", api.SingleSignOn)

r.Route("/saml", func(r *router) {
r.Get("/metadata", api.SAMLMetadata)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
tollbooth.NewLimiter(api.config.SAML.RateLimitAssertion/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).Post("/acs", api.SamlAcs)
r.With(api.limitHandler(api.limiterOpts.SAMLAssertion)).
Post("/acs", api.SamlAcs)
})
})

Expand Down
3 changes: 2 additions & 1 deletion internal/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func setupAPIForTestWithCallback(cb func(*conf.GlobalConfiguration, *storage.Con
cb(nil, conn)
}

return NewAPIWithVersion(config, conn, apiTestVersion), config, nil
limiterOpts := NewLimiterOptions(config)
return NewAPIWithVersion(config, conn, apiTestVersion, limiterOpts), config, nil
}

func TestEmailEnabledByDefault(t *testing.T) {
Expand Down
18 changes: 3 additions & 15 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,7 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
}
}

func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {
// limit per hour
emailFreq := a.config.RateLimitEmailSent / (60 * 60)
smsFreq := a.config.RateLimitSmsSent / (60 * 60)

emailLimiter := tollbooth.NewLimiter(emailFreq, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(a.config.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"})

phoneLimiter := tollbooth.NewLimiter(smsFreq, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(a.config.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"})

func (a *API) limitEmailOrPhoneSentHandler(limiterOptions *LimiterOptions) middlewareHandler {
return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
c := req.Context()
config := a.config
Expand All @@ -100,8 +88,8 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {
if req.Method == "PUT" || req.Method == "POST" {
// store rate limiter in request context
c = withLimiter(c, &SharedLimiter{
EmailLimiter: emailLimiter,
PhoneLimiter: phoneLimiter,
EmailLimiter: limiterOptions.Email,
PhoneLimiter: limiterOptions.Phone,
})
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() {
},
}

limiter := ts.API.limitEmailOrPhoneSentHandler()
limiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config))
for _, c := range cases {
ts.Run(c.desc, func() {
var buffer bytes.Buffer
Expand Down Expand Up @@ -484,7 +484,7 @@ func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() {
ts.Config.RateLimitEmailSent = c.sharedLimiterConfig.RateLimitEmailSent
ts.Config.RateLimitSmsSent = c.sharedLimiterConfig.RateLimitSmsSent
lmt := ts.API.limitHandler(ipBasedLimiter(c.ipBasedLimiterConfig))
sharedLimiter := ts.API.limitEmailOrPhoneSentHandler()
sharedLimiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config))

// get the minimum amount to reach the threshold just before the rate limit is exceeded
threshold := min(c.sharedLimiterConfig.RateLimitEmailSent, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig)
Expand Down
107 changes: 107 additions & 0 deletions internal/api/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package api

import (
"time"

"github.com/didip/tollbooth/v5"
"github.com/didip/tollbooth/v5/limiter"
"github.com/supabase/auth/internal/conf"
)

type Option interface {
apply(*API)
}

type LimiterOptions struct {
Email *limiter.Limiter
Phone *limiter.Limiter
Signups *limiter.Limiter
AnonymousSignIns *limiter.Limiter
Recover *limiter.Limiter
Resend *limiter.Limiter
MagicLink *limiter.Limiter
Otp *limiter.Limiter
Token *limiter.Limiter
Verify *limiter.Limiter
User *limiter.Limiter
FactorVerify *limiter.Limiter
FactorChallenge *limiter.Limiter
SSO *limiter.Limiter
SAMLAssertion *limiter.Limiter
}

func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo }

func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions {
o := &LimiterOptions{}

o.Email = tollbooth.NewLimiter(gc.RateLimitEmailSent/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"})

o.Phone = tollbooth.NewLimiter(gc.RateLimitSmsSent/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"})

o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitAnonymousUsers)).SetMethods([]string{"POST"})

o.Token = tollbooth.NewLimiter(gc.RateLimitTokenRefresh/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.Verify = tollbooth.NewLimiter(gc.RateLimitVerify/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.User = tollbooth.NewLimiter(gc.RateLimitOtp/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.FactorVerify = tollbooth.NewLimiter(gc.MFA.RateLimitChallengeAndVerify/60,
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Minute,
}).SetBurst(30)

o.FactorChallenge = tollbooth.NewLimiter(gc.MFA.RateLimitChallengeAndVerify/60,
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Minute,
}).SetBurst(30)

o.SSO = tollbooth.NewLimiter(gc.RateLimitSso/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.SAMLAssertion = tollbooth.NewLimiter(gc.SAML.RateLimitAssertion/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

o.Signups = tollbooth.NewLimiter(gc.RateLimitOtp/(60*5),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

// These all use the OTP limit per 5 min with 1hour ttl and burst of 30.
o.Recover = newLimiterPer5mOver1h(gc.RateLimitOtp)
o.Resend = newLimiterPer5mOver1h(gc.RateLimitOtp)
o.MagicLink = newLimiterPer5mOver1h(gc.RateLimitOtp)
o.Otp = newLimiterPer5mOver1h(gc.RateLimitOtp)
return o
}

func newLimiterPer5mOver1h(rate float64) *limiter.Limiter {
freq := rate / (60 * 5)
lim := tollbooth.NewLimiter(freq, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)
return lim
}
Loading