From 8913779051b0991c532658f8b484de294a5ebfd1 Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Fri, 1 Mar 2024 19:47:13 +0800 Subject: [PATCH 1/5] fix: add generic method to parse request params --- internal/api/helpers.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/internal/api/helpers.go b/internal/api/helpers.go index 282ae46d3..716c48249 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -75,3 +75,14 @@ func isStringInSlice(checkValue string, list []string) bool { func getBodyBytes(req *http.Request) ([]byte, error) { return utilities.GetBodyBytes(req) } + +func retrieveRequestParams[A any](r *http.Request, params *A) (*A, error) { + body, err := getBodyBytes(r) + if err != nil { + return nil, badRequestError("Could not read body into byte slice").WithInternalError(err) + } + if err := json.Unmarshal(body, params); err != nil { + return nil, badRequestError("Could not read request body: %v", err) + } + return params, nil +} From d2cdf38731bec008d3c75d3fe1ed5675745e62a8 Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Fri, 1 Mar 2024 19:48:25 +0800 Subject: [PATCH 2/5] chore: use generic method for parsing request params --- internal/api/admin.go | 21 +++++---------------- internal/api/invite.go | 11 ++--------- internal/api/mail.go | 11 ++--------- internal/api/mfa.go | 24 ++++++------------------ internal/api/middleware.go | 12 ++++-------- internal/api/otp.go | 16 ++++------------ internal/api/recover.go | 11 ++--------- internal/api/resend.go | 11 ++--------- internal/api/signup.go | 9 ++------- internal/api/sso.go | 11 ++--------- internal/api/ssoadmin.go | 19 ++++--------------- internal/api/token.go | 20 ++++---------------- internal/api/token_oidc.go | 11 ++--------- internal/api/token_refresh.go | 11 ++--------- internal/api/user.go | 11 ++--------- internal/api/verify.go | 9 +++------ 16 files changed, 48 insertions(+), 170 deletions(-) diff --git a/internal/api/admin.go b/internal/api/admin.go index d6fd17dac..283843448 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -85,18 +85,12 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex } func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) { - params := AdminUserParams{} - - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &AdminUserParams{}) if err != nil { - return nil, badRequestError("Could not read body").WithInternalError(err) + return nil, err } - if err := json.Unmarshal(body, ¶ms); err != nil { - return nil, badRequestError("Could not decode admin user params: %v", err) - } - - return ¶ms, nil + return params, nil } // adminUsers responds with a list of all users in a given audience @@ -564,14 +558,9 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro factor := getFactor(ctx) user := getUser(ctx) adminUser := getAdminUser(ctx) - params := &adminUserUpdateFactorParams{} - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &adminUserUpdateFactorParams{}) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read factor update params: %v", err) + return err } err = a.db.Transaction(func(tx *storage.Connection) error { diff --git a/internal/api/invite.go b/internal/api/invite.go index 65a651985..b17d82a6d 100644 --- a/internal/api/invite.go +++ b/internal/api/invite.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "net/http" "github.com/fatih/structs" @@ -23,15 +22,9 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) config := a.config adminUser := getAdminUser(ctx) - params := &InviteParams{} - - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &InviteParams{}) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read Invite params: %v", err) + return err } params.Email, err = validateEmail(params.Email) diff --git a/internal/api/mail.go b/internal/api/mail.go index 057c451b8..0bea23cd0 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "net/http" "net/url" "strings" @@ -48,15 +47,9 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { config := a.config mailer := a.Mailer(ctx) adminUser := getAdminUser(ctx) - params := &GenerateLinkParams{} - - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &GenerateLinkParams{}) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not parse JSON: %v", err) + return err } params.Email, err = validateEmail(params.Email) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index b67610d1d..515882198 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -2,7 +2,6 @@ package api import ( "bytes" - "encoding/json" "fmt" "net/http" "net/url" @@ -65,16 +64,11 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { session := getSession(ctx) config := a.config - params := &EnrollFactorParams{} - issuer := "" - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &EnrollFactorParams{}) if err != nil { - return internalServerError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + return err } + issuer := "" if params.FactorType != models.TOTP { return badRequestError("factor_type needs to be totp") @@ -205,17 +199,11 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { factor := getFactor(ctx) config := a.config - params := &VerifyFactorParams{} - currentIP := utilities.GetIPAddress(r) - - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &VerifyFactorParams{}) if err != nil { - return internalServerError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + return err } + currentIP := utilities.GetIPAddress(r) if !factor.IsOwnedBy(user) { return internalServerError(InvalidFactorOwnerErrorMessage) diff --git a/internal/api/middleware.go b/internal/api/middleware.go index cf83f6629..126851497 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -95,22 +95,18 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { if shouldRateLimitEmail || shouldRateLimitPhone { if req.Method == "PUT" || req.Method == "POST" { - bodyBytes, err := getBodyBytes(req) - if err != nil { - return c, internalServerError("Error invalid request body").WithInternalError(err) - } - var requestBody struct { Email string `json:"email"` Phone string `json:"phone"` } - if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { + params, err := retrieveRequestParams(req, &requestBody) + if err != nil { return c, badRequestError("Error invalid request body").WithInternalError(err) } if shouldRateLimitEmail { - if requestBody.Email != "" { + if params.Email != "" { if err := tollbooth.LimitByKeys(emailLimiter, []string{"email_functions"}); err != nil { emailRateLimitCounter.Add( req.Context(), @@ -123,7 +119,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { } if shouldRateLimitPhone { - if requestBody.Phone != "" { + if params.Phone != "" { if err := tollbooth.LimitByKeys(phoneLimiter, []string{"phone_functions"}); err != nil { return c, httpError(http.StatusTooManyRequests, "Sms rate limit exceeded") } diff --git a/internal/api/otp.go b/internal/api/otp.go index 752437bb6..b3ee8e6ae 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -61,6 +61,7 @@ func (p *SmsParams) Validate(smsProvider string) error { // Otp returns the MagicLink or SmsOtp handler based on the request body params func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { + var err error params := &OtpParams{ CreateUser: true, } @@ -68,15 +69,11 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { params.Data = make(map[string]interface{}) } - body, err := getBodyBytes(r) + params, err = retrieveRequestParams(r, params) if err != nil { return err } - if err = json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read verification params: %v", err) - } - if err := params.Validate(); err != nil { return err } @@ -114,16 +111,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } var err error - params := &SmsParams{} - - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &SmsParams{}) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return err } - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read sms otp params: %v", err) - } // For backwards compatibility, we default to SMS if params Channel is not specified if params.Phone != "" && params.Channel == "" { params.Channel = sms_provider.SMSProvider diff --git a/internal/api/recover.go b/internal/api/recover.go index 9a5757565..bc6edce48 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "errors" "net/http" @@ -36,15 +35,9 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() db := a.db.WithContext(ctx) config := a.config - params := &RecoverParams{} - - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &RecoverParams{}) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read verification params: %v", err) + return err } flowType := getFlowFromChallenge(params.CodeChallenge) diff --git a/internal/api/resend.go b/internal/api/resend.go index a49a55e7e..7f703ac3f 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "errors" "net/http" "time" @@ -67,15 +66,9 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() db := a.db.WithContext(ctx) config := a.config - params := &ResendConfirmationParams{} - - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &ResendConfirmationParams{}) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read params: %v", err) + return err } if err := params.Validate(config); err != nil { diff --git a/internal/api/signup.go b/internal/api/signup.go index e2c858f1f..a9135b353 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "fmt" "net/http" "time" @@ -109,13 +108,9 @@ func (params *SignupParams) ToUserModel(isSSOUser bool) (user *models.User, err } func retrieveSignupParams(r *http.Request) (*SignupParams, error) { - params := &SignupParams{} - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &SignupParams{}) if err != nil { - return nil, internalServerError("Could not read body").WithInternalError(err) - } - if err := json.Unmarshal(body, params); err != nil { - return nil, badRequestError("Could not read Signup params: %v", err) + return nil, err } return params, nil } diff --git a/internal/api/sso.go b/internal/api/sso.go index d93ff82dc..a07a82c02 100644 --- a/internal/api/sso.go +++ b/internal/api/sso.go @@ -1,7 +1,6 @@ package api import ( - "encoding/json" "net/http" "github.com/crewjam/saml" @@ -41,15 +40,9 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() db := a.db.WithContext(ctx) - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &SingleSignOnParams{}) if err != nil { - return internalServerError("Unable to read request body").WithInternalError(err) - } - - var params SingleSignOnParams - - if err := json.Unmarshal(body, ¶ms); err != nil { - return badRequestError("Unable to parse request body as JSON").WithInternalError(err) + return err } hasProviderID := false diff --git a/internal/api/ssoadmin.go b/internal/api/ssoadmin.go index 4fdecc0f8..fbe2de242 100644 --- a/internal/api/ssoadmin.go +++ b/internal/api/ssoadmin.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "io" "net/http" "net/url" @@ -184,14 +183,9 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er ctx := r.Context() db := a.db.WithContext(ctx) - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &CreateSSOProviderParams{}) if err != nil { - return internalServerError("Unable to read request body").WithInternalError(err) - } - - var params CreateSSOProviderParams - if err := json.Unmarshal(body, ¶ms); err != nil { - return badRequestError("Unable to parse JSON").WithInternalError(err) + return err } if err := params.validate(false /* <- forUpdate */); err != nil { @@ -264,14 +258,9 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er ctx := r.Context() db := a.db.WithContext(ctx) - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &CreateSSOProviderParams{}) if err != nil { - return internalServerError("Unable to read request body").WithInternalError(err) - } - - var params CreateSSOProviderParams - if err := json.Unmarshal(body, ¶ms); err != nil { - return badRequestError("Unable to parse JSON").WithInternalError(err) + return err } if err := params.validate(true /* <- forUpdate */); err != nil { diff --git a/internal/api/token.go b/internal/api/token.go index 1675a6579..7fb91eeeb 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "errors" "net/http" "net/url" @@ -100,15 +99,9 @@ func (a *API) Token(w http.ResponseWriter, r *http.Request) error { func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) - params := &PasswordGrantParams{} - - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &PasswordGrantParams{}) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read password grant params: %v", err) + return err } aud := a.requestAud(ctx, r) @@ -235,14 +228,9 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) // can be told to at least propagate the User-Agent header. grantParams.FillGrantParams(r) - params := &PKCEGrantParams{} - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &PKCEGrantParams{}) if err != nil { - return internalServerError("Could not read body").WithInternalError(err) - } - - if err = json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + return err } if params.AuthCode == "" || params.CodeVerifier == "" { diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index c380856c3..59c5e57b2 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -3,7 +3,6 @@ package api import ( "context" "crypto/sha256" - "encoding/json" "fmt" "net/http" @@ -113,15 +112,9 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R db := a.db.WithContext(ctx) config := a.config - params := &IdTokenGrantParams{} - - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &IdTokenGrantParams{}) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read id token grant params: %v", err) + return err } if params.IdToken == "" { diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index ebe4b5f2d..458c2938a 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" mathRand "math/rand" "net/http" "time" @@ -25,15 +24,9 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h db := a.db.WithContext(ctx) config := a.config - params := &RefreshTokenGrantParams{} - - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &RefreshTokenGrantParams{}) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read refresh token grant params: %v", err) + return err } if params.RefreshToken == "" { diff --git a/internal/api/user.go b/internal/api/user.go index e31a3eceb..852343ea4 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "errors" "net/http" "time" @@ -83,15 +82,9 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { config := a.config aud := a.requestAud(ctx, r) - params := &UserUpdateParams{} - - body, err := getBodyBytes(r) + params, err := retrieveRequestParams(r, &UserUpdateParams{}) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read User Update params: %v", err) + return err } user := getUser(ctx) diff --git a/internal/api/verify.go b/internal/api/verify.go index 35f8253eb..4ee8abf16 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -2,7 +2,6 @@ package api import ( "context" - "encoding/json" "errors" "net/http" "net/url" @@ -107,12 +106,10 @@ func (a *API) Verify(w http.ResponseWriter, r *http.Request) error { } return a.verifyGet(w, r, params) case http.MethodPost: - body, err := getBodyBytes(r) + var err error + params, err = retrieveRequestParams(r, params) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) - } - if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not parse verification params: %v", err) + return err } if err := params.Validate(r); err != nil { return err From b746fc6b26111c53d5470c6a048ebde081c23064 Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Sun, 3 Mar 2024 12:16:51 +0800 Subject: [PATCH 3/5] chore: return 500 error --- internal/api/helpers.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/api/helpers.go b/internal/api/helpers.go index 716c48249..b0fe353a5 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -79,7 +79,7 @@ func getBodyBytes(req *http.Request) ([]byte, error) { func retrieveRequestParams[A any](r *http.Request, params *A) (*A, error) { body, err := getBodyBytes(r) if err != nil { - return nil, badRequestError("Could not read body into byte slice").WithInternalError(err) + return nil, internalServerError("Could not read body into byte slice").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { return nil, badRequestError("Could not read request body: %v", err) From e8033706976516fcfa806603441a262291d871fd Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Sun, 3 Mar 2024 17:44:08 +0800 Subject: [PATCH 4/5] fix: add type constraints --- internal/api/helpers.go | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/internal/api/helpers.go b/internal/api/helpers.go index b0fe353a5..8b7cfec0d 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -76,7 +76,34 @@ func getBodyBytes(req *http.Request) ([]byte, error) { return utilities.GetBodyBytes(req) } -func retrieveRequestParams[A any](r *http.Request, params *A) (*A, error) { +type RequestParams interface { + AdminUserParams | + CreateSSOProviderParams | + EnrollFactorParams | + GenerateLinkParams | + IdTokenGrantParams | + InviteParams | + OtpParams | + PKCEGrantParams | + PasswordGrantParams | + RecoverParams | + RefreshTokenGrantParams | + ResendConfirmationParams | + SignupParams | + SingleSignOnParams | + SmsParams | + UserUpdateParams | + VerifyFactorParams | + VerifyParams | + adminUserUpdateFactorParams | + struct { + Email string `json:"email"` + Phone string `json:"phone"` + } +} + +// retrieveRequestParams is a generic method that unmarshals the request body into the params struct provided +func retrieveRequestParams[A RequestParams](r *http.Request, params *A) (*A, error) { body, err := getBodyBytes(r) if err != nil { return nil, internalServerError("Could not read body into byte slice").WithInternalError(err) From fb0b447c239bf90269032e066d550039314282f8 Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Sun, 3 Mar 2024 17:48:14 +0800 Subject: [PATCH 5/5] chore: replace retrieveSignupParams --- internal/api/admin.go | 10 +++++----- internal/api/anonymous.go | 4 ++-- internal/api/api.go | 4 ++-- internal/api/helpers.go | 8 ++++---- internal/api/invite.go | 5 +++-- internal/api/mail.go | 5 +++-- internal/api/mfa.go | 8 ++++---- internal/api/middleware.go | 7 +++---- internal/api/otp.go | 8 +++----- internal/api/recover.go | 5 +++-- internal/api/resend.go | 5 +++-- internal/api/signup.go | 13 +++---------- internal/api/sso.go | 5 +++-- internal/api/ssoadmin.go | 8 ++++---- internal/api/token.go | 10 ++++++---- internal/api/token_oidc.go | 4 ++-- internal/api/token_refresh.go | 4 ++-- internal/api/user.go | 6 +++--- internal/api/verify.go | 4 +--- 19 files changed, 59 insertions(+), 64 deletions(-) diff --git a/internal/api/admin.go b/internal/api/admin.go index 283843448..89f7af975 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -85,8 +85,8 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex } func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) { - params, err := retrieveRequestParams(r, &AdminUserParams{}) - if err != nil { + params := &AdminUserParams{} + if err := retrieveRequestParams(r, params); err != nil { return nil, err } @@ -558,12 +558,12 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro factor := getFactor(ctx) user := getUser(ctx) adminUser := getAdminUser(ctx) - params, err := retrieveRequestParams(r, &adminUserUpdateFactorParams{}) - if err != nil { + params := &adminUserUpdateFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } - err = a.db.Transaction(func(tx *storage.Connection) error { + err := a.db.Transaction(func(tx *storage.Connection) error { if params.FriendlyName != "" { if terr := factor.UpdateFriendlyName(tx, params.FriendlyName); terr != nil { return terr diff --git a/internal/api/anonymous.go b/internal/api/anonymous.go index 11412639e..5316525a4 100644 --- a/internal/api/anonymous.go +++ b/internal/api/anonymous.go @@ -18,8 +18,8 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error { return forbiddenError("Signups not allowed for this instance") } - params, err := retrieveSignupParams(r) - if err != nil { + params := &SignupParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } params.Aud = aud diff --git a/internal/api/api.go b/internal/api/api.go index eb27c4dce..73d810fa2 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -149,8 +149,8 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati DefaultExpirationTTL: time.Hour, }).SetBurst(int(api.config.RateLimitAnonymousUsers)).SetMethods([]string{"POST"}) r.Post("/", func(w http.ResponseWriter, r *http.Request) error { - params, err := retrieveSignupParams(r) - if err != nil { + params := &SignupParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } if params.Email == "" && params.Phone == "" { diff --git a/internal/api/helpers.go b/internal/api/helpers.go index 8b7cfec0d..ea4102f2e 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -103,13 +103,13 @@ type RequestParams interface { } // retrieveRequestParams is a generic method that unmarshals the request body into the params struct provided -func retrieveRequestParams[A RequestParams](r *http.Request, params *A) (*A, error) { +func retrieveRequestParams[A RequestParams](r *http.Request, params *A) error { body, err := getBodyBytes(r) if err != nil { - return nil, internalServerError("Could not read body into byte slice").WithInternalError(err) + return internalServerError("Could not read body into byte slice").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return nil, badRequestError("Could not read request body: %v", err) + return badRequestError("Could not read request body: %v", err) } - return params, nil + return nil } diff --git a/internal/api/invite.go b/internal/api/invite.go index b17d82a6d..45d94878c 100644 --- a/internal/api/invite.go +++ b/internal/api/invite.go @@ -22,11 +22,12 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) config := a.config adminUser := getAdminUser(ctx) - params, err := retrieveRequestParams(r, &InviteParams{}) - if err != nil { + params := &InviteParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } + var err error params.Email, err = validateEmail(params.Email) if err != nil { return err diff --git a/internal/api/mail.go b/internal/api/mail.go index 0bea23cd0..0ab561ab7 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -47,11 +47,12 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { config := a.config mailer := a.Mailer(ctx) adminUser := getAdminUser(ctx) - params, err := retrieveRequestParams(r, &GenerateLinkParams{}) - if err != nil { + params := &GenerateLinkParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } + var err error params.Email, err = validateEmail(params.Email) if err != nil { return err diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 515882198..077a6966a 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -64,8 +64,8 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { session := getSession(ctx) config := a.config - params, err := retrieveRequestParams(r, &EnrollFactorParams{}) - if err != nil { + params := &EnrollFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } issuer := "" @@ -199,8 +199,8 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { factor := getFactor(ctx) config := a.config - params, err := retrieveRequestParams(r, &VerifyFactorParams{}) - if err != nil { + params := &VerifyFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } currentIP := utilities.GetIPAddress(r) diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 126851497..ab72e32c0 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -100,13 +100,12 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { Phone string `json:"phone"` } - params, err := retrieveRequestParams(req, &requestBody) - if err != nil { + if err := retrieveRequestParams(req, &requestBody); err != nil { return c, badRequestError("Error invalid request body").WithInternalError(err) } if shouldRateLimitEmail { - if params.Email != "" { + if requestBody.Email != "" { if err := tollbooth.LimitByKeys(emailLimiter, []string{"email_functions"}); err != nil { emailRateLimitCounter.Add( req.Context(), @@ -119,7 +118,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { } if shouldRateLimitPhone { - if params.Phone != "" { + if requestBody.Phone != "" { if err := tollbooth.LimitByKeys(phoneLimiter, []string{"phone_functions"}); err != nil { return c, httpError(http.StatusTooManyRequests, "Sms rate limit exceeded") } diff --git a/internal/api/otp.go b/internal/api/otp.go index b3ee8e6ae..0e437faa1 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -61,7 +61,6 @@ func (p *SmsParams) Validate(smsProvider string) error { // Otp returns the MagicLink or SmsOtp handler based on the request body params func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { - var err error params := &OtpParams{ CreateUser: true, } @@ -69,8 +68,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { params.Data = make(map[string]interface{}) } - params, err = retrieveRequestParams(r, params) - if err != nil { + if err := retrieveRequestParams(r, params); err != nil { return err } @@ -111,8 +109,8 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } var err error - params, err := retrieveRequestParams(r, &SmsParams{}) - if err != nil { + params := &SmsParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } diff --git a/internal/api/recover.go b/internal/api/recover.go index bc6edce48..dcf574d1d 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -35,8 +35,8 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() db := a.db.WithContext(ctx) config := a.config - params, err := retrieveRequestParams(r, &RecoverParams{}) - if err != nil { + params := &RecoverParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } @@ -46,6 +46,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { } var user *models.User + var err error aud := a.requestAud(ctx, r) user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) diff --git a/internal/api/resend.go b/internal/api/resend.go index 7f703ac3f..cb8c4da24 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -66,8 +66,8 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() db := a.db.WithContext(ctx) config := a.config - params, err := retrieveRequestParams(r, &ResendConfirmationParams{}) - if err != nil { + params := &ResendConfirmationParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } @@ -76,6 +76,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { } var user *models.User + var err error aud := a.requestAud(ctx, r) if params.Email != "" { user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) diff --git a/internal/api/signup.go b/internal/api/signup.go index a9135b353..7093fe3be 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -107,14 +107,6 @@ func (params *SignupParams) ToUserModel(isSSOUser bool) (user *models.User, err return user, nil } -func retrieveSignupParams(r *http.Request) (*SignupParams, error) { - params, err := retrieveRequestParams(r, &SignupParams{}) - if err != nil { - return nil, err - } - return params, nil -} - // Signup is the endpoint for registering a new user func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() @@ -125,8 +117,8 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return forbiddenError("Signups not allowed for this instance") } - params, err := retrieveSignupParams(r) - if err != nil { + params := &SignupParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } @@ -137,6 +129,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { } var codeChallengeMethod models.CodeChallengeMethod + var err error flowType := getFlowFromChallenge(params.CodeChallenge) if isPKCEFlow(flowType) { diff --git a/internal/api/sso.go b/internal/api/sso.go index a07a82c02..0b4fd8907 100644 --- a/internal/api/sso.go +++ b/internal/api/sso.go @@ -40,11 +40,12 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() db := a.db.WithContext(ctx) - params, err := retrieveRequestParams(r, &SingleSignOnParams{}) - if err != nil { + params := &SingleSignOnParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } + var err error hasProviderID := false if hasProviderID, err = params.validate(); err != nil { diff --git a/internal/api/ssoadmin.go b/internal/api/ssoadmin.go index fbe2de242..0f966780e 100644 --- a/internal/api/ssoadmin.go +++ b/internal/api/ssoadmin.go @@ -183,8 +183,8 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er ctx := r.Context() db := a.db.WithContext(ctx) - params, err := retrieveRequestParams(r, &CreateSSOProviderParams{}) - if err != nil { + params := &CreateSSOProviderParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } @@ -258,8 +258,8 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er ctx := r.Context() db := a.db.WithContext(ctx) - params, err := retrieveRequestParams(r, &CreateSSOProviderParams{}) - if err != nil { + params := &CreateSSOProviderParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } diff --git a/internal/api/token.go b/internal/api/token.go index 7fb91eeeb..c94b7b1ad 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -99,8 +99,8 @@ func (a *API) Token(w http.ResponseWriter, r *http.Request) error { func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) - params, err := retrieveRequestParams(r, &PasswordGrantParams{}) - if err != nil { + params := &PasswordGrantParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } @@ -113,6 +113,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri var user *models.User var grantParams models.GrantParams var provider string + var err error grantParams.FillGrantParams(r) @@ -228,8 +229,9 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) // can be told to at least propagate the User-Agent header. grantParams.FillGrantParams(r) - params, err := retrieveRequestParams(r, &PKCEGrantParams{}) - if err != nil { + params := &PKCEGrantParams{} + + if err := retrieveRequestParams(r, params); err != nil { return err } diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index 59c5e57b2..5695cb489 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -112,8 +112,8 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R db := a.db.WithContext(ctx) config := a.config - params, err := retrieveRequestParams(r, &IdTokenGrantParams{}) - if err != nil { + params := &IdTokenGrantParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index 458c2938a..65bbbb031 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -24,8 +24,8 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h db := a.db.WithContext(ctx) config := a.config - params, err := retrieveRequestParams(r, &RefreshTokenGrantParams{}) - if err != nil { + params := &RefreshTokenGrantParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } diff --git a/internal/api/user.go b/internal/api/user.go index 852343ea4..ddf497f73 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -82,8 +82,8 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { config := a.config aud := a.requestAud(ctx, r) - params, err := retrieveRequestParams(r, &UserUpdateParams{}) - if err != nil { + params := &UserUpdateParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } @@ -163,7 +163,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } } - err = db.Transaction(func(tx *storage.Connection) error { + err := db.Transaction(func(tx *storage.Connection) error { var terr error if params.Password != nil { var sessionID *uuid.UUID diff --git a/internal/api/verify.go b/internal/api/verify.go index 4ee8abf16..6dd29be05 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -106,9 +106,7 @@ func (a *API) Verify(w http.ResponseWriter, r *http.Request) error { } return a.verifyGet(w, r, params) case http.MethodPost: - var err error - params, err = retrieveRequestParams(r, params) - if err != nil { + if err := retrieveRequestParams(r, params); err != nil { return err } if err := params.Validate(r); err != nil {