Skip to content

Commit

Permalink
feat: link credentials when second login is OIDC (CORE-2041)
Browse files Browse the repository at this point in the history
  • Loading branch information
splaunov committed May 8, 2023
1 parent ef9918f commit 2a0b706
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 88 deletions.
1 change: 1 addition & 0 deletions cmd/clidoc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func init() {
"NewInfoSelfServiceSettingsUpdateLinkSAML": text.NewInfoSelfServiceSettingsUpdateLinkSAML(),
"NewInfoSelfServiceSettingsUpdateUnlinkSAML": text.NewInfoSelfServiceSettingsUpdateUnlinkSAML("{provider}"),
"NewInfoSelfServiceLoginLinkCredentials": text.NewInfoSelfServiceLoginLinkCredentials(),
"NewErrorValidationLoginLinkedCredentialsDoNotMatch": text.NewErrorValidationLoginLinkedCredentialsDoNotMatch(),
"NewErrorValidationSAMLProviderNotFound": text.NewErrorValidationSAMLProviderNotFound(),
}
}
Expand Down
10 changes: 10 additions & 0 deletions schema/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,13 @@ func NewNoWebAuthnCredentials() error {
Messages: new(text.Messages).Add(text.NewErrorValidationSuchNoWebAuthnUser()),
})
}

func NewLinkedCredentialsDoNotMatch() error {
return errors.WithStack(&ValidationError{
ValidationError: &jsonschema.ValidationError{
Message: `linked credentials do not match`,
InstancePtr: "#/",
},
Messages: new(text.Messages).Add(text.NewErrorValidationLoginLinkedCredentialsDoNotMatch()),
})
}
5 changes: 3 additions & 2 deletions selfservice/flow/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ const InternalContextDuplicateCredentialsPath = "registration_duplicate_credenti
const InternalContextLinkCredentialsPath = "link_credentials"

type RegistrationDuplicateCredentials struct {
CredentialsType identity.CredentialsType
CredentialsConfig sqlxx.JSONRawMessage
CredentialsType identity.CredentialsType
CredentialsConfig sqlxx.JSONRawMessage
DuplicateIdentifier string
}

func AppendFlowTo(src *url.URL, id uuid.UUID) *url.URL {
Expand Down
74 changes: 0 additions & 74 deletions selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package login
import (
_ "embed"
"encoding/json"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"net/http"
Expand Down Expand Up @@ -753,11 +752,6 @@ continueLogin:
sess = session.NewInactiveSession()
}

if err := h.linkCredentials(r, sess, interim, f); err != nil {
h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, group, err)
return
}

method := ss.CompletedAuthenticationMethod(r.Context())
sess.CompletedLoginFor(method.Method, method.AAL)
i = interim
Expand All @@ -779,71 +773,3 @@ continueLogin:
return
}
}

func (h *Handler) linkCredentials(r *http.Request, s *session.Session, i *identity.Identity, f *Flow) error {
var lc flow.RegistrationDuplicateCredentials

var p struct {
FlowID string `json:"linkCredentialsFlow" form:"linkCredentialsFlow"`
}

if err := h.hd.Decode(r, &p,
decoderx.HTTPDecoderSetValidatePayloads(true),
decoderx.MustHTTPRawJSONSchemaCompiler(linkCredentialsSchema),
decoderx.HTTPDecoderJSONFollowsFormFormat()); err != nil {
return err
}

if p.FlowID != "" {
linkCredentialsFlowID, innerErr := uuid.FromString(p.FlowID)
if innerErr != nil {
return innerErr
}
linkCredentialsFlow, innerErr := h.d.LoginFlowPersister().GetLoginFlow(r.Context(), linkCredentialsFlowID)
if innerErr != nil {
return innerErr
}
innerErr = h.getInternalContextLinkCredentials(linkCredentialsFlow, flow.InternalContextDuplicateCredentialsPath, &lc)
if innerErr != nil {
return innerErr
}
}

if lc.CredentialsType == "" {
err := h.getInternalContextLinkCredentials(f, flow.InternalContextLinkCredentialsPath, &lc)
if err != nil {
return err
}
}

if lc.CredentialsType != "" {
strategy, err := h.d.AllLoginStrategies().Strategy(lc.CredentialsType)
if err != nil {
return err
}

linkableStrategy, ok := strategy.(LinkableStrategy)
if !ok {
return errors.New(fmt.Sprintf("Strategy is not linkable: %T", linkableStrategy))
}

if err := linkableStrategy.Link(r.Context(), i, lc.CredentialsConfig); err != nil {
return err
}

method := strategy.CompletedAuthenticationMethod(r.Context())
s.CompletedLoginFor(method.Method, method.AAL)
}

return nil
}

func (h *Handler) getInternalContextLinkCredentials(f *Flow, internalContextPath string, lc *flow.RegistrationDuplicateCredentials) error {
internalContextLinkCredentials := gjson.GetBytes(f.InternalContext, internalContextPath)
if internalContextLinkCredentials.IsObject() {
if err := json.Unmarshal([]byte(internalContextLinkCredentials.Raw), lc); err != nil {
return err
}
}
return nil
}
100 changes: 100 additions & 0 deletions selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ package login

import (
"context"
"encoding/json"
"fmt"
"github.com/gofrs/uuid"
"github.com/ory/kratos/schema"
"github.com/ory/x/decoderx"
"github.com/tidwall/gjson"
"net/http"
"path"
"time"
Expand Down Expand Up @@ -44,13 +49,16 @@ type (
executorDependencies interface {
config.Provider
hydra.HydraProvider
identity.PrivilegedPoolProvider
session.ManagementProvider
session.PersistenceProvider
x.CSRFTokenGeneratorProvider
x.WriterProvider
x.LoggingProvider

FlowPersistenceProvider
HooksProvider
StrategyProvider
}
HookExecutor struct {
d executorDependencies
Expand Down Expand Up @@ -111,6 +119,10 @@ func (e *HookExecutor) handleLoginError(_ http.ResponseWriter, r *http.Request,
}

func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, i *identity.Identity, s *session.Session) error {
if err := e.linkCredentials(r, s, i, a); err != nil {
return err
}

if err := s.Activate(r, i, e.d.Config(), time.Now().UTC()); err != nil {
return err
}
Expand Down Expand Up @@ -261,6 +273,21 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g n
return nil
}

func (e *HookExecutor) checkDuplecateCredentialsIdentifierMatch(ctx context.Context, identityID uuid.UUID, match string) error {
i, err := e.d.PrivilegedIdentityPool().GetIdentityConfidential(ctx, identityID)
if err != nil {
return err
}
for _, credentials := range i.Credentials {
for _, identifier := range credentials.Identifiers {
if identifier == match {
return nil
}
}
}
return schema.NewLinkedCredentialsDoNotMatch()
}

func (e *HookExecutor) PreLoginHook(w http.ResponseWriter, r *http.Request, a *Flow) error {
for _, executor := range e.d.PreLoginHooks(r.Context()) {
if err := executor.ExecuteLoginPreHook(w, r, a); err != nil {
Expand All @@ -270,3 +297,76 @@ func (e *HookExecutor) PreLoginHook(w http.ResponseWriter, r *http.Request, a *F

return nil
}

func (e *HookExecutor) linkCredentials(r *http.Request, s *session.Session, i *identity.Identity, f *Flow) error {
var lc flow.RegistrationDuplicateCredentials

if r.Method == "POST" {
var p struct {
FlowID string `json:"linkCredentialsFlow" form:"linkCredentialsFlow"`
}

if err := decoderx.NewHTTP().Decode(r, &p,
decoderx.HTTPDecoderSetValidatePayloads(true),
decoderx.MustHTTPRawJSONSchemaCompiler(linkCredentialsSchema),
decoderx.HTTPDecoderJSONFollowsFormFormat()); err != nil {
return err
}

if p.FlowID != "" {
linkCredentialsFlowID, innerErr := uuid.FromString(p.FlowID)
if innerErr != nil {
return innerErr
}
linkCredentialsFlow, innerErr := e.d.LoginFlowPersister().GetLoginFlow(r.Context(), linkCredentialsFlowID)
if innerErr != nil {
return innerErr
}
innerErr = e.getInternalContextLinkCredentials(linkCredentialsFlow, flow.InternalContextDuplicateCredentialsPath, &lc)
if innerErr != nil {
return innerErr
}
}
}

if lc.CredentialsType == "" {
err := e.getInternalContextLinkCredentials(f, flow.InternalContextLinkCredentialsPath, &lc)
if err != nil {
return err
}
}

if lc.CredentialsType != "" {
if err := e.checkDuplecateCredentialsIdentifierMatch(r.Context(), i.ID, lc.DuplicateIdentifier); err != nil {
return err
}
strategy, err := e.d.AllLoginStrategies().Strategy(lc.CredentialsType)
if err != nil {
return err
}

linkableStrategy, ok := strategy.(LinkableStrategy)
if !ok {
return errors.New(fmt.Sprintf("Strategy is not linkable: %T", linkableStrategy))
}

if err := linkableStrategy.Link(r.Context(), i, lc.CredentialsConfig); err != nil {
return err
}

method := strategy.CompletedAuthenticationMethod(r.Context())
s.CompletedLoginFor(method.Method, method.AAL)
}

return nil
}

func (e *HookExecutor) getInternalContextLinkCredentials(f *Flow, internalContextPath string, lc *flow.RegistrationDuplicateCredentials) error {
internalContextLinkCredentials := gjson.GetBytes(f.InternalContext, internalContextPath)
if internalContextLinkCredentials.IsObject() {
if err := json.Unmarshal([]byte(internalContextLinkCredentials.Raw), lc); err != nil {
return err
}
}
return nil
}
26 changes: 24 additions & 2 deletions selfservice/flow/registration/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ type (
executorDependencies interface {
config.Provider
identity.ManagementProvider
identity.PrivilegedPoolProvider
identity.ValidationProvider
login.FlowPersistenceProvider
login.StrategyProvider
Expand Down Expand Up @@ -151,9 +152,14 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque
_, ok := strategy.(login.LinkableStrategy)

if ok {
duplicateIdentifier, err := e.getDuplicateIdentifier(r.Context(), i)
if err != nil {
return err
}
registrationDuplicateCredentials := flow.RegistrationDuplicateCredentials{
CredentialsType: ct,
CredentialsConfig: i.Credentials[ct].Config,
CredentialsType: ct,
CredentialsConfig: i.Credentials[ct].Config,
DuplicateIdentifier: duplicateIdentifier,
}
loginFlowID, err := a.GetOuterLoginFlowID()
if err != nil {
Expand Down Expand Up @@ -290,6 +296,22 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque
return nil
}

func (e *HookExecutor) getDuplicateIdentifier(ctx context.Context, i *identity.Identity) (string, error) {
for ct, credentials := range i.Credentials {
for _, identifier := range credentials.Identifiers {
_, _, err := e.d.PrivilegedIdentityPool().FindByCredentialsIdentifier(ctx, ct, identifier)
if err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
continue
}
return "", err
}
return identifier, nil
}
}
return "", errors.New("Duplicate credential not found")
}

func (e *HookExecutor) PreRegistrationHook(w http.ResponseWriter, r *http.Request, a *Flow) error {
for _, executor := range e.d.PreRegistrationHooks(r.Context()) {
if err := executor.ExecuteRegistrationPreHook(w, r, a); err != nil {
Expand Down
6 changes: 3 additions & 3 deletions selfservice/strategy/saml/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,11 @@ func TestStrategy(t *testing.T) {
"%s", string(identityConfidential.Credentials["saml"].Config[:]))
path := ""
if isAPI {
path = "session.authentication_methods.0.method"
path = "session.authentication_methods"
} else {
path = "authentication_methods.0.method"
path = "authentication_methods"
}
assert.Equal(t, "saml", gjson.Get(body, path).String(), "%s", body)
assert.Contains(t, gjson.Get(body, path).String(), "saml", "%s", body)
}

t.Run("case=browser", func(t *testing.T) {
Expand Down
15 changes: 8 additions & 7 deletions text/id.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,14 @@ const (
)

const (
ErrorValidationLogin ID = 4010000 + iota // 4010000
ErrorValidationLoginFlowExpired // 4010001
ErrorValidationLoginNoStrategyFound // 4010002
ErrorValidationRegistrationNoStrategyFound // 4010003
ErrorValidationSettingsNoStrategyFound // 4010004
ErrorValidationRecoveryNoStrategyFound // 4010005
ErrorValidationVerificationNoStrategyFound // 4010006
ErrorValidationLogin ID = 4010000 + iota // 4010000
ErrorValidationLoginFlowExpired // 4010001
ErrorValidationLoginNoStrategyFound // 4010002
ErrorValidationRegistrationNoStrategyFound // 4010003
ErrorValidationSettingsNoStrategyFound // 4010004
ErrorValidationRecoveryNoStrategyFound // 4010005
ErrorValidationVerificationNoStrategyFound // 4010006
ErrorValidationLoginLinkedCredentialsDoNotMatch // 4010007
)

const (
Expand Down
8 changes: 8 additions & 0 deletions text/message_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,11 @@ func NewInfoSelfServiceLoginLinkCredentials() *Message {
Context: context(nil),
}
}

func NewErrorValidationLoginLinkedCredentialsDoNotMatch() *Message {
return &Message{
ID: ErrorValidationLoginLinkedCredentialsDoNotMatch,
Text: "Linked credentials do not match.",
Type: Error,
}
}

0 comments on commit 2a0b706

Please sign in to comment.