Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: redirect invalid state errors to site url #1722

Merged
merged 6 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
tollbooth.NewLimiter(api.config.SAML.RateLimitAssertion/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).Post("/acs", api.SAMLACS)
)).Post("/acs", api.SamlAcs)
})
})

Expand Down
20 changes: 15 additions & 5 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) e
if err != nil {
return err
}
a.redirectErrors(a.internalExternalProviderCallback, w, r, u)
redirectErrors(a.internalExternalProviderCallback, w, r, u)
return nil
}

Expand Down Expand Up @@ -478,18 +478,28 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p
return user, nil
}

func (a *API) loadExternalState(ctx context.Context, state string) (context.Context, error) {
func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.Context, error) {
var state string
switch r.Method {
case http.MethodPost:
state = r.FormValue("state")
default:
state = r.URL.Query().Get("state")
}
if state == "" {
return ctx, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing")
}
config := a.config
claims := ExternalProviderClaims{}
p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
_, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) {
return []byte(config.JWT.Secret), nil
})
if err != nil {
return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err)
return ctx, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err)
}
if claims.Provider == "" {
return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)")
return ctx, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)")
}
if claims.InviteToken != "" {
ctx = withInviteToken(ctx, claims.InviteToken)
Expand Down Expand Up @@ -573,7 +583,7 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide
}
}

func (a *API) redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) {
func redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) {
ctx := r.Context()
log := observability.GetLogEntry(r).Entry
errorID := utilities.GetRequestID(ctx)
Expand Down
28 changes: 16 additions & 12 deletions internal/api/external_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/utilities"
)

// OAuthProviderData contains the userData and token returned by the oauth provider
Expand All @@ -23,17 +24,6 @@ type OAuthProviderData struct {
// loadFlowState parses the `state` query parameter as a JWS payload,
// extracting the provider requested
func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) {
var state string
if r.Method == http.MethodPost {
state = r.FormValue("state")
} else {
state = r.URL.Query().Get("state")
}

if state == "" {
return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing")
}

ctx := r.Context()
oauthToken := r.URL.Query().Get("oauth_token")
if oauthToken != "" {
Expand All @@ -43,7 +33,21 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con
if oauthVerifier != "" {
ctx = withOAuthVerifier(ctx, oauthVerifier)
}
return a.loadExternalState(ctx, state)

var err error
ctx, err = a.loadExternalState(ctx, r)
if err != nil {
u, uerr := url.ParseRequestURI(a.config.SiteURL)
if uerr != nil {
return ctx, internalServerError("site url is improperly formatted").WithInternalError(uerr)
}
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved

q := getErrorQueryString(err, utilities.GetRequestID(ctx), observability.GetLogEntry(r).Entry, u.Query())
u.RawQuery = q.Encode()

http.Redirect(w, r, u.String(), http.StatusSeeOther)
}
return ctx, err
}

func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/api/external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func (ts *ExternalTestSuite) TestRedirectErrorsShouldPreserveParams() {
parsedURL, err := url.Parse(c.RedirectURL)
require.Equal(ts.T(), err, nil)

ts.API.redirectErrors(ts.API.internalExternalProviderCallback, w, req, parsedURL)
redirectErrors(ts.API.internalExternalProviderCallback, w, req, parsedURL)

parsedParams, err := url.ParseQuery(parsedURL.RawQuery)
require.Equal(ts.T(), err, nil)
Expand Down
19 changes: 16 additions & 3 deletions internal/api/samlacs.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,22 @@ func IsSAMLMetadataStale(idpMetadata *saml.EntityDescriptor, samlProvider models
return hasValidityExpired || hasCacheDurationExceeded || needsForceUpdate
}

// SAMLACS implements the main Assertion Consumer Service endpoint behavior.
func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error {
func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error {
if err := a.handleSamlAcs(w, r); err != nil {
u, uerr := url.Parse(a.config.SiteURL)
if uerr != nil {
return internalServerError("site url is improperly formattted").WithInternalError(err)
}

q := getErrorQueryString(err, utilities.GetRequestID(r.Context()), observability.GetLogEntry(r).Entry, u.Query())
u.RawQuery = q.Encode()
http.Redirect(w, r, u.String(), http.StatusSeeOther)
}
return nil
}

// handleSamlAcs implements the main Assertion Consumer Service endpoint behavior.
func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()

db := a.db.WithContext(ctx)
Expand All @@ -61,7 +75,6 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error {
var requestIds []string

var flowState *models.FlowState
flowState = nil
if relayStateUUID != uuid.Nil {
// relay state is a valid UUID, therefore this is likely a SP initiated flow

Expand Down
Loading