diff --git a/internal/api/api.go b/internal/api/api.go index 85292775f..49b810696 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -274,7 +274,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne tollbooth.NewLimiter(api.config.SAML.RateLimitAssertion/(60*5), &limiter.ExpirableOptions{ DefaultExpirationTTL: time.Hour, }).SetBurst(30), - )).Post("/acs", api.SAMLACS) + )).Post("/acs", api.SamlAcs) }) }) diff --git a/internal/api/external.go b/internal/api/external.go index ef6032d9a..4df4c6502 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -138,7 +138,7 @@ func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) e if err != nil { return err } - a.redirectErrors(a.internalExternalProviderCallback, w, r, u) + redirectErrors(a.internalExternalProviderCallback, w, r, u) return nil } @@ -478,7 +478,17 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p return user, nil } -func (a *API) loadExternalState(ctx context.Context, state string) (context.Context, error) { +func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.Context, error) { + var state string + switch r.Method { + case http.MethodPost: + state = r.FormValue("state") + default: + state = r.URL.Query().Get("state") + } + if state == "" { + return ctx, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing") + } config := a.config claims := ExternalProviderClaims{} p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) @@ -486,10 +496,10 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont return []byte(config.JWT.Secret), nil }) if err != nil { - return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err) + return ctx, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err) } if claims.Provider == "" { - return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)") + return ctx, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)") } if claims.InviteToken != "" { ctx = withInviteToken(ctx, claims.InviteToken) @@ -573,7 +583,7 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide } } -func (a *API) redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) { +func redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) { ctx := r.Context() log := observability.GetLogEntry(r).Entry errorID := utilities.GetRequestID(ctx) diff --git a/internal/api/external_oauth.go b/internal/api/external_oauth.go index af3dd51f4..cb098e373 100644 --- a/internal/api/external_oauth.go +++ b/internal/api/external_oauth.go @@ -10,6 +10,7 @@ import ( "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/api/provider" "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/utilities" ) // OAuthProviderData contains the userData and token returned by the oauth provider @@ -23,17 +24,6 @@ type OAuthProviderData struct { // loadFlowState parses the `state` query parameter as a JWS payload, // extracting the provider requested func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) { - var state string - if r.Method == http.MethodPost { - state = r.FormValue("state") - } else { - state = r.URL.Query().Get("state") - } - - if state == "" { - return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing") - } - ctx := r.Context() oauthToken := r.URL.Query().Get("oauth_token") if oauthToken != "" { @@ -43,7 +33,21 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con if oauthVerifier != "" { ctx = withOAuthVerifier(ctx, oauthVerifier) } - return a.loadExternalState(ctx, state) + + var err error + ctx, err = a.loadExternalState(ctx, r) + if err != nil { + u, uerr := url.ParseRequestURI(a.config.SiteURL) + if uerr != nil { + return ctx, internalServerError("site url is improperly formatted").WithInternalError(uerr) + } + + q := getErrorQueryString(err, utilities.GetRequestID(ctx), observability.GetLogEntry(r).Entry, u.Query()) + u.RawQuery = q.Encode() + + http.Redirect(w, r, u.String(), http.StatusSeeOther) + } + return ctx, err } func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) { diff --git a/internal/api/external_test.go b/internal/api/external_test.go index 09fdcc433..bef89d736 100644 --- a/internal/api/external_test.go +++ b/internal/api/external_test.go @@ -235,7 +235,7 @@ func (ts *ExternalTestSuite) TestRedirectErrorsShouldPreserveParams() { parsedURL, err := url.Parse(c.RedirectURL) require.Equal(ts.T(), err, nil) - ts.API.redirectErrors(ts.API.internalExternalProviderCallback, w, req, parsedURL) + redirectErrors(ts.API.internalExternalProviderCallback, w, req, parsedURL) parsedParams, err := url.ParseQuery(parsedURL.RawQuery) require.Equal(ts.T(), err, nil) diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index 0916a7235..907efcd4c 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -43,8 +43,22 @@ func IsSAMLMetadataStale(idpMetadata *saml.EntityDescriptor, samlProvider models return hasValidityExpired || hasCacheDurationExceeded || needsForceUpdate } -// SAMLACS implements the main Assertion Consumer Service endpoint behavior. -func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { +func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error { + if err := a.handleSamlAcs(w, r); err != nil { + u, uerr := url.Parse(a.config.SiteURL) + if uerr != nil { + return internalServerError("site url is improperly formattted").WithInternalError(err) + } + + q := getErrorQueryString(err, utilities.GetRequestID(r.Context()), observability.GetLogEntry(r).Entry, u.Query()) + u.RawQuery = q.Encode() + http.Redirect(w, r, u.String(), http.StatusSeeOther) + } + return nil +} + +// handleSamlAcs implements the main Assertion Consumer Service endpoint behavior. +func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() db := a.db.WithContext(ctx) @@ -61,7 +75,6 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { var requestIds []string var flowState *models.FlowState - flowState = nil if relayStateUUID != uuid.Nil { // relay state is a valid UUID, therefore this is likely a SP initiated flow