Skip to content

Commit

Permalink
fix: carry oauth2_login_challenge over to registration flow
Browse files Browse the repository at this point in the history
Fixes #3321
  • Loading branch information
jonas-jonas committed Aug 7, 2023
1 parent 3a07af4 commit 7d6f143
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 20 deletions.
7 changes: 7 additions & 0 deletions selfservice/flow/registration/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/julienschmidt/httprouter"
"github.com/pkg/errors"

"github.com/ory/x/sqlxx"
"github.com/ory/x/urlx"

"github.com/ory/kratos/driver/config"
Expand Down Expand Up @@ -108,6 +109,12 @@ func WithFlowReturnTo(returnTo string) FlowOption {
}
}

func WithFlowOAuth2LoginChallenge(loginChallenge string) FlowOption {
return func(f *Flow) {
f.OAuth2LoginChallenge = sqlxx.NullString(loginChallenge)
}
}

func (h *Handler) NewRegistrationFlow(w http.ResponseWriter, r *http.Request, ft flow.Type, opts ...FlowOption) (*Flow, error) {
if !h.d.Config().SelfServiceFlowRegistrationEnabled(r.Context()) {
return nil, errors.WithStack(ErrRegistrationDisabled)
Expand Down
48 changes: 28 additions & 20 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ type UpdateLoginFlowWithOidcMethod struct {
UpstreamParameters json.RawMessage `json:"upstream_parameters"`
}

func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, a *login.Flow, token *oauth2.Token, claims *Claims, provider Provider, container *authCodeContainer) (*registration.Flow, error) {
func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlow *login.Flow, token *oauth2.Token, claims *Claims, provider Provider, container *authCodeContainer) (*registration.Flow, error) {
i, c, err := s.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(r.Context(), identity.CredentialsTypeOIDC, identity.OIDCUniqueID(provider.Config().ID, claims.Subject))
if err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
Expand All @@ -97,56 +97,64 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, a *login
// not need additional consent/login.

// This is kinda hacky but the only way to ensure seamless login/registration flows when using OIDC.
s.d.Logger().WithField("provider", provider.Config().ID).WithField("subject", claims.Subject).Debug("Received successful OpenID Connect callback but user is not registered. Re-initializing registration flow now.")
s.d.
Logger().
WithField("provider", provider.Config().ID).
WithField("subject", claims.Subject).
Debug("Received successful OpenID Connect callback but user is not registered. Re-initializing registration flow now.")

// If return_to was set before, we need to preserve it.
var opts []registration.FlowOption
if len(a.ReturnTo) > 0 {
opts = append(opts, registration.WithFlowReturnTo(a.ReturnTo))
if len(loginFlow.ReturnTo) > 0 {
opts = append(opts, registration.WithFlowReturnTo(loginFlow.ReturnTo))
}

aa, err := s.d.RegistrationHandler().NewRegistrationFlow(w, r, a.Type, opts...)
if loginFlow.OAuth2LoginChallenge.String() != "" {
opts = append(opts, registration.WithFlowOAuth2LoginChallenge(loginFlow.ReturnTo))
}

registrationFlow, err := s.d.RegistrationHandler().NewRegistrationFlow(w, r, loginFlow.Type, opts...)
if err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err)
}

err = s.d.SessionTokenExchangePersister().MoveToNewFlow(r.Context(), a.ID, aa.ID)
err = s.d.SessionTokenExchangePersister().MoveToNewFlow(r.Context(), loginFlow.ID, registrationFlow.ID)
if err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err)
}

aa.RequestURL, err = x.TakeOverReturnToParameter(a.RequestURL, aa.RequestURL)
registrationFlow.RequestURL, err = x.TakeOverReturnToParameter(loginFlow.RequestURL, registrationFlow.RequestURL)
if err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err)
}

if _, err := s.processRegistration(w, r, aa, token, claims, provider, container); err != nil {
return aa, err
if _, err := s.processRegistration(w, r, registrationFlow, token, claims, provider, container); err != nil {
return registrationFlow, err
}

return nil, nil
}

return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err)
}

var o identity.CredentialsOIDC
if err := json.NewDecoder(bytes.NewBuffer(c.Config)).Decode(&o); err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("The password credentials could not be decoded properly").WithDebug(err.Error())))
var oidcCredentials identity.CredentialsOIDC
if err := json.NewDecoder(bytes.NewBuffer(c.Config)).Decode(&oidcCredentials); err != nil {
return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("The password credentials could not be decoded properly").WithDebug(err.Error())))
}

sess := session.NewInactiveSession()
sess.CompletedLoginForWithProvider(s.ID(), identity.AuthenticatorAssuranceLevel1, provider.Config().ID)
for _, c := range o.Providers {
for _, c := range oidcCredentials.Providers {
if c.Subject == claims.Subject && c.Provider == provider.Config().ID {
if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, a, i, sess, provider.Config().ID); err != nil {
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID); err != nil {
return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err)
}
return nil, nil
}
}

return nil, s.handleError(w, r, a, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to find matching OpenID Connect Credentials.").WithDebugf(`Unable to find credentials that match the given provider "%s" and subject "%s".`, provider.Config().ID, claims.Subject)))
return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, errors.WithStack(herodot.ErrInternalServerError.WithReason("Unable to find matching OpenID Connect Credentials.").WithDebugf(`Unable to find credentials that match the given provider "%s" and subject "%s".`, provider.Config().ID, claims.Subject)))
}

func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, _ uuid.UUID) (i *identity.Identity, err error) {
Expand Down

0 comments on commit 7d6f143

Please sign in to comment.