Skip to content

Commit

Permalink
fix: refactor request params to use generics (supabase#1464)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?
* Introduce a new method `retrieveRequestParams` which makes use of
generics to parse a request
* This will help to simplify parsing a request from:
```go

params := RequestParams{}
body, err := getBodyBytes(r)
if err != nil {
  return nil, badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, &params); err != nil {
  return nil, badRequestError("Could not decode request params: %v", err)
}
```
to 
```go
params := &Request{}
err := retrieveRequestParams(req, params)
```

## TODO
- [x] Add type constraint instead of using `any`
  • Loading branch information
kangmingtay authored Mar 4, 2024
1 parent b536d36 commit e1cdf5c
Show file tree
Hide file tree
Showing 19 changed files with 98 additions and 187 deletions.
25 changes: 7 additions & 18 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,12 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex
}

func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) {
params := AdminUserParams{}

body, err := getBodyBytes(r)
if err != nil {
return nil, badRequestError("Could not read body").WithInternalError(err)
params := &AdminUserParams{}
if err := retrieveRequestParams(r, params); err != nil {
return nil, err
}

if err := json.Unmarshal(body, &params); err != nil {
return nil, badRequestError("Could not decode admin user params: %v", err)
}

return &params, nil
return params, nil
}

// adminUsers responds with a list of all users in a given audience
Expand Down Expand Up @@ -565,16 +559,11 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
user := getUser(ctx)
adminUser := getAdminUser(ctx)
params := &adminUserUpdateFactorParams{}
body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read factor update params: %v", err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}

err = a.db.Transaction(func(tx *storage.Connection) error {
err := a.db.Transaction(func(tx *storage.Connection) error {
if params.FriendlyName != "" {
if terr := factor.UpdateFriendlyName(tx, params.FriendlyName); terr != nil {
return terr
Expand Down
4 changes: 2 additions & 2 deletions internal/api/anonymous.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
return forbiddenError("Signups not allowed for this instance")
}

params, err := retrieveSignupParams(r)
if err != nil {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}
params.Aud = aud
Expand Down
4 changes: 2 additions & 2 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(api.config.RateLimitAnonymousUsers)).SetMethods([]string{"POST"})
r.Post("/", func(w http.ResponseWriter, r *http.Request) error {
params, err := retrieveSignupParams(r)
if err != nil {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}
if params.Email == "" && params.Phone == "" {
Expand Down
38 changes: 38 additions & 0 deletions internal/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,41 @@ func isStringInSlice(checkValue string, list []string) bool {
func getBodyBytes(req *http.Request) ([]byte, error) {
return utilities.GetBodyBytes(req)
}

type RequestParams interface {
AdminUserParams |
CreateSSOProviderParams |
EnrollFactorParams |
GenerateLinkParams |
IdTokenGrantParams |
InviteParams |
OtpParams |
PKCEGrantParams |
PasswordGrantParams |
RecoverParams |
RefreshTokenGrantParams |
ResendConfirmationParams |
SignupParams |
SingleSignOnParams |
SmsParams |
UserUpdateParams |
VerifyFactorParams |
VerifyParams |
adminUserUpdateFactorParams |
struct {
Email string `json:"email"`
Phone string `json:"phone"`
}
}

// retrieveRequestParams is a generic method that unmarshals the request body into the params struct provided
func retrieveRequestParams[A RequestParams](r *http.Request, params *A) error {
body, err := getBodyBytes(r)
if err != nil {
return internalServerError("Could not read body into byte slice").WithInternalError(err)
}
if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read request body: %v", err)
}
return nil
}
12 changes: 3 additions & 9 deletions internal/api/invite.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"net/http"

"github.com/fatih/structs"
Expand All @@ -24,16 +23,11 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error {
config := a.config
adminUser := getAdminUser(ctx)
params := &InviteParams{}

body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read Invite params: %v", err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}

var err error
params.Email, err = validateEmail(params.Email)
if err != nil {
return err
Expand Down
12 changes: 3 additions & 9 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -49,16 +48,11 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error {
mailer := a.Mailer(ctx)
adminUser := getAdminUser(ctx)
params := &GenerateLinkParams{}

body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not parse JSON: %v", err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}

var err error
params.Email, err = validateEmail(params.Email)
if err != nil {
return err
Expand Down
24 changes: 6 additions & 18 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/url"
Expand Down Expand Up @@ -66,15 +65,10 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error {
config := a.config

params := &EnrollFactorParams{}
issuer := ""
body, err := getBodyBytes(r)
if err != nil {
return internalServerError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("invalid body: unable to parse JSON").WithInternalError(err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}
issuer := ""

if params.FactorType != models.TOTP {
return badRequestError("factor_type needs to be totp")
Expand Down Expand Up @@ -206,16 +200,10 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
config := a.config

params := &VerifyFactorParams{}
currentIP := utilities.GetIPAddress(r)

body, err := getBodyBytes(r)
if err != nil {
return internalServerError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("invalid body: unable to parse JSON").WithInternalError(err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}
currentIP := utilities.GetIPAddress(r)

if !factor.IsOwnedBy(user) {
return internalServerError(InvalidFactorOwnerErrorMessage)
Expand Down
7 changes: 1 addition & 6 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,12 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler {

if shouldRateLimitEmail || shouldRateLimitPhone {
if req.Method == "PUT" || req.Method == "POST" {
bodyBytes, err := getBodyBytes(req)
if err != nil {
return c, internalServerError("Error invalid request body").WithInternalError(err)
}

var requestBody struct {
Email string `json:"email"`
Phone string `json:"phone"`
}

if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
if err := retrieveRequestParams(req, &requestBody); err != nil {
return c, badRequestError("Error invalid request body").WithInternalError(err)
}

Expand Down
16 changes: 3 additions & 13 deletions internal/api/otp.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,10 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error {
params.Data = make(map[string]interface{})
}

body, err := getBodyBytes(r)
if err != nil {
if err := retrieveRequestParams(r, params); err != nil {
return err
}

if err = json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read verification params: %v", err)
}

if err := params.Validate(); err != nil {
return err
}
Expand Down Expand Up @@ -115,15 +110,10 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
var err error

params := &SmsParams{}

body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read sms otp params: %v", err)
}
// For backwards compatibility, we default to SMS if params Channel is not specified
if params.Phone != "" && params.Channel == "" {
params.Channel = sms_provider.SMSProvider
Expand Down
12 changes: 3 additions & 9 deletions internal/api/recover.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"errors"
"net/http"

Expand Down Expand Up @@ -37,14 +36,8 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
db := a.db.WithContext(ctx)
config := a.config
params := &RecoverParams{}

body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read verification params: %v", err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}

flowType := getFlowFromChallenge(params.CodeChallenge)
Expand All @@ -53,6 +46,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
}

var user *models.User
var err error
aud := a.requestAud(ctx, r)

user, err = models.FindUserByEmailAndAudience(db, params.Email, aud)
Expand Down
12 changes: 3 additions & 9 deletions internal/api/resend.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package api

import (
"encoding/json"
"errors"
"net/http"
"time"
Expand Down Expand Up @@ -68,21 +67,16 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
db := a.db.WithContext(ctx)
config := a.config
params := &ResendConfirmationParams{}

body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read params: %v", err)
if err := retrieveRequestParams(r, params); err != nil {
return err
}

if err := params.Validate(config); err != nil {
return err
}

var user *models.User
var err error
aud := a.requestAud(ctx, r)
if params.Email != "" {
user, err = models.FindUserByEmailAndAudience(db, params.Email, aud)
Expand Down
18 changes: 3 additions & 15 deletions internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
Expand Down Expand Up @@ -108,18 +107,6 @@ func (params *SignupParams) ToUserModel(isSSOUser bool) (user *models.User, err
return user, nil
}

func retrieveSignupParams(r *http.Request) (*SignupParams, error) {
params := &SignupParams{}
body, err := getBodyBytes(r)
if err != nil {
return nil, internalServerError("Could not read body").WithInternalError(err)
}
if err := json.Unmarshal(body, params); err != nil {
return nil, badRequestError("Could not read Signup params: %v", err)
}
return params, nil
}

// Signup is the endpoint for registering a new user
func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
Expand All @@ -130,8 +117,8 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
return forbiddenError("Signups not allowed for this instance")
}

params, err := retrieveSignupParams(r)
if err != nil {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
return err
}

Expand All @@ -142,6 +129,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
}

var codeChallengeMethod models.CodeChallengeMethod
var err error
flowType := getFlowFromChallenge(params.CodeChallenge)

if isPKCEFlow(flowType) {
Expand Down
Loading

0 comments on commit e1cdf5c

Please sign in to comment.