From a059ef079ceffdec6364e174e782b171cbbdc25b Mon Sep 17 00:00:00 2001 From: hperl <34397+hperl@users.noreply.github.com> Date: Fri, 26 Aug 2022 20:59:04 +0200 Subject: [PATCH 1/3] fix: do not invalidate recovery addr on update --- identity/identity_recovery.go | 6 +++ persistence/sql/persister_identity.go | 46 +++++++++++++++++-- selfservice/strategy/link/test/persistence.go | 27 ++++++++++- 3 files changed, 72 insertions(+), 7 deletions(-) diff --git a/identity/identity_recovery.go b/identity/identity_recovery.go index 0a3a2b662086..f354a8b13304 100644 --- a/identity/identity_recovery.go +++ b/identity/identity_recovery.go @@ -2,6 +2,7 @@ package identity import ( "context" + "fmt" "time" "github.com/gofrs/uuid" @@ -55,6 +56,11 @@ func (a RecoveryAddress) ValidateNID() error { return nil } +// Hash returns a unique string representation for the recovery address. +func (a RecoveryAddress) Hash() string { + return fmt.Sprintf("%s|%s|%s|%s", a.Value, a.Via, a.IdentityID, a.NID) +} + func NewRecoveryEmailAddress( value string, identity uuid.UUID, diff --git a/persistence/sql/persister_identity.go b/persistence/sql/persister_identity.go index 22e13f1b5e62..eaa8a2f95449 100644 --- a/persistence/sql/persister_identity.go +++ b/persistence/sql/persister_identity.go @@ -202,6 +202,43 @@ func (p *Persister) createVerifiableAddresses(ctx context.Context, i *identity.I return nil } +func (p *Persister) updateRecoveryAddresses(ctx context.Context, i *identity.Identity) error { + var addressesInDb []identity.RecoveryAddress + if err := p.GetConnection(ctx).Where("identity_id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).Order("id ASC").All(&addressesInDb); err != nil { + return err + } + + newAddresses := make(map[string]*identity.RecoveryAddress) + oldAddresses := make(map[string]*identity.RecoveryAddress) + for j, a := range i.RecoveryAddresses { + i.RecoveryAddresses[j].IdentityID = i.ID + i.RecoveryAddresses[j].NID = p.NetworkID(ctx) + i.RecoveryAddresses[j].Value = stringToLowerTrim(i.RecoveryAddresses[j].Value) + newAddresses[a.Hash()] = &i.RecoveryAddresses[j] + } + for j, a := range addressesInDb { + oldAddresses[a.Hash()] = &addressesInDb[j] + } + for h, a := range newAddresses { + if _, found := oldAddresses[h]; found { + // Ignore addresses that are already in the db + oldAddresses[h] = nil + } else { + if err := p.GetConnection(ctx).Create(a); err != nil { + return err + } + } + } + for _, a := range oldAddresses { + if a != nil { + if err := p.GetConnection(ctx).Destroy(a); err != nil { + return err + } + } + } + return nil +} + func (p *Persister) createRecoveryAddresses(ctx context.Context, i *identity.Identity) error { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createRecoveryAddresses") defer span.End() @@ -350,10 +387,13 @@ func (p *Persister) UpdateIdentity(ctx context.Context, i *identity.Identity) er return sql.ErrNoRows } + if err := p.updateRecoveryAddresses(ctx, i); err != nil { + return err + } + for _, tn := range []string{ new(identity.Credentials).TableName(ctx), new(identity.VerifiableAddress).TableName(ctx), - new(identity.RecoveryAddress).TableName(ctx), } { /* #nosec G201 TableName is static */ if err := tx.RawQuery(fmt.Sprintf( @@ -370,10 +410,6 @@ func (p *Persister) UpdateIdentity(ctx context.Context, i *identity.Identity) er return err } - if err := p.createRecoveryAddresses(ctx, i); err != nil { - return err - } - return p.createIdentityCredentials(ctx, i) })) } diff --git a/selfservice/strategy/link/test/persistence.go b/selfservice/strategy/link/test/persistence.go index b6ed26f00c61..f8648bf82676 100644 --- a/selfservice/strategy/link/test/persistence.go +++ b/selfservice/strategy/link/test/persistence.go @@ -93,8 +93,31 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface { assert.NotEqual(t, expected.Token, actual.Token) assert.EqualValues(t, expected.FlowID, actual.FlowID) - _, err = p.UseRecoveryToken(ctx, f.ID, expected.Token) - require.Error(t, err) + t.Run("double spend", func(t *testing.T) { + _, err = p.UseRecoveryToken(ctx, f.ID, expected.Token) + require.Error(t, err) + }) + }) + + t.Run("case=update to identity should not invalidate token", func(t *testing.T) { + expected, f := newRecoveryToken(t, "some-user@ory.sh") + + require.NoError(t, p.CreateRecoveryToken(ctx, expected)) + id, err := p.GetIdentity(ctx, expected.IdentityID) + require.NoError(t, err) + require.NoError(t, p.UpdateIdentity(ctx, id)) + + actual, err := p.UseRecoveryToken(ctx, f.ID, expected.Token) + require.NoError(t, err) + assert.Equal(t, nid, actual.NID) + assert.Equal(t, expected.IdentityID, actual.IdentityID) + assert.NotEqual(t, expected.Token, actual.Token) + assert.EqualValues(t, expected.FlowID, actual.FlowID) + + t.Run("double spend", func(t *testing.T) { + _, err = p.UseRecoveryToken(ctx, f.ID, expected.Token) + require.Error(t, err) + }) }) }) From ee917e7431727cb2e06f912a409cf912412c7e35 Mon Sep 17 00:00:00 2001 From: hperl <34397+hperl@users.noreply.github.com> Date: Mon, 29 Aug 2022 11:15:22 +0200 Subject: [PATCH 2/3] feat: add diffable update for verifiable addresses --- go.mod | 2 +- identity/identity_verification.go | 5 ++ package-lock.json | 5 +- persistence/sql/persister_identity.go | 99 +++++++++++++++++---------- persistence/sql/persister_test.go | 2 +- 5 files changed, 71 insertions(+), 42 deletions(-) diff --git a/go.mod b/go.mod index c2517839d518..6d9399cec73d 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/ory/kratos -go 1.17 +go 1.18 replace ( github.com/bradleyjkemp/cupaloy/v2 => github.com/aeneasr/cupaloy/v2 v2.6.1-0.20210924214125-3dfdd01210a3 diff --git a/identity/identity_verification.go b/identity/identity_verification.go index 0fc757d1418c..83bce7366ac5 100644 --- a/identity/identity_verification.go +++ b/identity/identity_verification.go @@ -2,6 +2,7 @@ package identity import ( "context" + "fmt" "time" "github.com/gofrs/uuid" @@ -129,3 +130,7 @@ func (a VerifiableAddress) GetNID() uuid.UUID { func (a VerifiableAddress) ValidateNID() error { return nil } + +func (a VerifiableAddress) Hash() string { + return fmt.Sprintf("%s|%v|%s|%s|%s|%s", a.Value, a.Verified, a.Via, a.Status, a.IdentityID, a.NID) +} diff --git a/package-lock.json b/package-lock.json index cb26077ca427..6429c6d8531f 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,5 +1,5 @@ { - "name": "tmp.T01PPIJfY2", + "name": "kratos", "lockfileVersion": 2, "requires": true, "packages": { @@ -5088,7 +5088,8 @@ "version": "7.5.7", "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.7.tgz", "integrity": "sha512-KMvVuFzpKBuiIXW3E4u3mySRO2/mCHSyZDJQM5NQ9Q9KHWHWh0NHgfbRMLLrceUK5qAL4ytALJbpRMjixFZh8A==", - "dev": true + "dev": true, + "requires": {} }, "y18n": { "version": "5.0.8", diff --git a/persistence/sql/persister_identity.go b/persistence/sql/persister_identity.go index eaa8a2f95449..c4230fd2054d 100644 --- a/persistence/sql/persister_identity.go +++ b/persistence/sql/persister_identity.go @@ -192,9 +192,6 @@ func (p *Persister) createVerifiableAddresses(ctx context.Context, i *identity.I defer span.End() for k := range i.VerifiableAddresses { - i.VerifiableAddresses[k].IdentityID = i.ID - i.VerifiableAddresses[k].NID = p.NetworkID(ctx) - i.VerifiableAddresses[k].Value = stringToLowerTrim(i.VerifiableAddresses[k].Value) if err := p.GetConnection(ctx).Create(&i.VerifiableAddresses[k]); err != nil { return err } @@ -202,51 +199,76 @@ func (p *Persister) createVerifiableAddresses(ctx context.Context, i *identity.I return nil } -func (p *Persister) updateRecoveryAddresses(ctx context.Context, i *identity.Identity) error { - var addressesInDb []identity.RecoveryAddress - if err := p.GetConnection(ctx).Where("identity_id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).Order("id ASC").All(&addressesInDb); err != nil { +func updateAssociation[T interface { + Hash() string +}](ctx context.Context, p *Persister, i *identity.Identity, inID []T) error { + var inDB []T + if err := p.GetConnection(ctx). + Where("identity_id = ? AND nid = ?", i.ID, p.NetworkID(ctx)). + Order("id ASC"). + All(&inDB); err != nil { + return err } - newAddresses := make(map[string]*identity.RecoveryAddress) - oldAddresses := make(map[string]*identity.RecoveryAddress) - for j, a := range i.RecoveryAddresses { - i.RecoveryAddresses[j].IdentityID = i.ID - i.RecoveryAddresses[j].NID = p.NetworkID(ctx) - i.RecoveryAddresses[j].Value = stringToLowerTrim(i.RecoveryAddresses[j].Value) - newAddresses[a.Hash()] = &i.RecoveryAddresses[j] + newAssocs := make(map[string]*T) + oldAssocs := make(map[string]*T) + for i, a := range inID { + newAssocs[a.Hash()] = &inID[i] } - for j, a := range addressesInDb { - oldAddresses[a.Hash()] = &addressesInDb[j] + for i, a := range inDB { + oldAssocs[a.Hash()] = &inDB[i] } - for h, a := range newAddresses { - if _, found := oldAddresses[h]; found { - // Ignore addresses that are already in the db - oldAddresses[h] = nil + + // Subtle: we delete the old associations from the DB first, because else + // they could cause UNIQUE constraints to fail on insert. + for h, a := range oldAssocs { + if _, found := newAssocs[h]; found { + newAssocs[h] = nil // Ignore associations that are already in the db. } else { - if err := p.GetConnection(ctx).Create(a); err != nil { + if err := p.GetConnection(ctx).Destroy(a); err != nil { return err } } } - for _, a := range oldAddresses { + + for _, a := range newAssocs { if a != nil { - if err := p.GetConnection(ctx).Destroy(a); err != nil { + if err := p.GetConnection(ctx).Create(a); err != nil { return err } } } + return nil } +func (p *Persister) normalizeAllAddressess(ctx context.Context, id *identity.Identity) { + p.normalizeRecoveryAddresses(ctx, id) + p.normalizeVerifiableAddresses(ctx, id) +} + +func (p *Persister) normalizeVerifiableAddresses(ctx context.Context, id *identity.Identity) { + for k := range id.VerifiableAddresses { + id.VerifiableAddresses[k].IdentityID = id.ID + id.VerifiableAddresses[k].NID = p.NetworkID(ctx) + id.VerifiableAddresses[k].Value = stringToLowerTrim(id.VerifiableAddresses[k].Value) + } +} + +func (p *Persister) normalizeRecoveryAddresses(ctx context.Context, id *identity.Identity) { + for k := range id.RecoveryAddresses { + id.RecoveryAddresses[k].IdentityID = id.ID + id.RecoveryAddresses[k].NID = p.NetworkID(ctx) + id.RecoveryAddresses[k].Value = stringToLowerTrim(id.RecoveryAddresses[k].Value) + } +} + func (p *Persister) createRecoveryAddresses(ctx context.Context, i *identity.Identity) error { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createRecoveryAddresses") defer span.End() for k := range i.RecoveryAddresses { - i.RecoveryAddresses[k].IdentityID = i.ID - i.RecoveryAddresses[k].NID = p.NetworkID(ctx) - i.RecoveryAddresses[k].Value = stringToLowerTrim(i.RecoveryAddresses[k].Value) if err := p.GetConnection(ctx).Create(&i.RecoveryAddresses[k]); err != nil { return err } @@ -322,6 +344,8 @@ func (p *Persister) CreateIdentity(ctx context.Context, i *identity.Identity) er return sqlcon.HandleError(err) } + p.normalizeAllAddressess(ctx, i) + if err := p.createVerifiableAddresses(ctx, i); err != nil { return sqlcon.HandleError(err) } @@ -387,26 +411,25 @@ func (p *Persister) UpdateIdentity(ctx context.Context, i *identity.Identity) er return sql.ErrNoRows } - if err := p.updateRecoveryAddresses(ctx, i); err != nil { + p.normalizeAllAddressess(ctx, i) + if err := updateAssociation(ctx, p, i, i.RecoveryAddresses); err != nil { return err } - - for _, tn := range []string{ - new(identity.Credentials).TableName(ctx), - new(identity.VerifiableAddress).TableName(ctx), - } { - /* #nosec G201 TableName is static */ - if err := tx.RawQuery(fmt.Sprintf( - `DELETE FROM %s WHERE identity_id = ? AND nid = ?`, tn), i.ID, p.NetworkID(ctx)).Exec(); err != nil { - return err - } + if err := updateAssociation(ctx, p, i, i.VerifiableAddresses); err != nil { + return err } - if err := p.update(WithTransaction(ctx, tx), i); err != nil { + /* #nosec G201 TableName is static */ + if err := tx.RawQuery( + fmt.Sprintf( + `DELETE FROM %s WHERE identity_id = ? AND nid = ?`, + new(identity.Credentials).TableName(ctx)), + i.ID, p.NetworkID(ctx)).Exec(); err != nil { + return err } - if err := p.createVerifiableAddresses(ctx, i); err != nil { + if err := p.update(WithTransaction(ctx, tx), i); err != nil { return err } diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index 0474a6b7324f..ae3eb1f8e001 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -285,7 +285,7 @@ func TestPersister_Transaction(t *testing.T) { Traits: ri.Traits(`{}`), } errMessage := "failing because why not" - err := p.Transaction(context.Background(), func(ctx context.Context, connection *pop.Connection) error { + err := p.Transaction(context.Background(), func(_ context.Context, connection *pop.Connection) error { require.NoError(t, connection.Create(i)) return errors.Errorf(errMessage) }) From bece07609f1d46f8e24da4f51468edbb4a8847b3 Mon Sep 17 00:00:00 2001 From: hperl <34397+hperl@users.noreply.github.com> Date: Fri, 2 Sep 2022 12:42:26 +0200 Subject: [PATCH 3/3] test: add tests for Hash() --- identity/identity_recovery.go | 2 +- identity/identity_recovery_test.go | 41 +++++++++++++- identity/identity_verification.go | 3 +- identity/identity_verification_test.go | 76 ++++++++++++++++++++++++++ persistence/sql/persister_identity.go | 8 +-- 5 files changed, 123 insertions(+), 7 deletions(-) diff --git a/identity/identity_recovery.go b/identity/identity_recovery.go index f354a8b13304..234b0341d460 100644 --- a/identity/identity_recovery.go +++ b/identity/identity_recovery.go @@ -58,7 +58,7 @@ func (a RecoveryAddress) ValidateNID() error { // Hash returns a unique string representation for the recovery address. func (a RecoveryAddress) Hash() string { - return fmt.Sprintf("%s|%s|%s|%s", a.Value, a.Via, a.IdentityID, a.NID) + return fmt.Sprintf("%v|%v|%v|%v", a.Value, a.Via, a.IdentityID, a.NID) } func NewRecoveryEmailAddress( diff --git a/identity/identity_recovery_test.go b/identity/identity_recovery_test.go index f3f2dd71de32..b96fac0f9ed9 100644 --- a/identity/identity_recovery_test.go +++ b/identity/identity_recovery_test.go @@ -2,9 +2,9 @@ package identity import ( "testing" + "time" "github.com/gofrs/uuid" - "github.com/stretchr/testify/assert" "github.com/ory/kratos/x" @@ -19,3 +19,42 @@ func TestNewRecoveryEmailAddress(t *testing.T) { assert.Equal(t, iid, a.IdentityID) assert.Equal(t, uuid.Nil, a.ID) } + +// TestRecoveryAddress_Hash tests that the hash considers all fields that are +// written to the database (ignoring some well-known fields like the ID or +// timestamps). +func TestRecoveryAddress_Hash(t *testing.T) { + cases := []struct { + name string + a RecoveryAddress + }{ + { + name: "full fields", + a: RecoveryAddress{ + ID: x.NewUUID(), + Value: "foo@bar.me", + Via: AddressTypeEmail, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + IdentityID: x.NewUUID(), + NID: x.NewUUID(), + }, + }, { + name: "empty fields", + a: RecoveryAddress{}, + }, { + name: "constructor", + a: *NewRecoveryEmailAddress("foo@ory.sh", x.NewUUID()), + }, + } + + for _, tc := range cases { + t.Run("case="+tc.name, func(t *testing.T) { + assert.Equal(t, + reflectiveHash(tc.a), + tc.a.Hash(), + ) + }) + } + +} diff --git a/identity/identity_verification.go b/identity/identity_verification.go index 83bce7366ac5..d1e6e55c3a43 100644 --- a/identity/identity_verification.go +++ b/identity/identity_verification.go @@ -131,6 +131,7 @@ func (a VerifiableAddress) ValidateNID() error { return nil } +// Hash returns a unique string representation for the recovery address. func (a VerifiableAddress) Hash() string { - return fmt.Sprintf("%s|%v|%s|%s|%s|%s", a.Value, a.Verified, a.Via, a.Status, a.IdentityID, a.NID) + return fmt.Sprintf("%v|%v|%v|%v|%v|%v", a.Value, a.Verified, a.Via, a.Status, a.IdentityID, a.NID) } diff --git a/identity/identity_verification_test.go b/identity/identity_verification_test.go index e81901d88211..5750c225d5be 100644 --- a/identity/identity_verification_test.go +++ b/identity/identity_verification_test.go @@ -1,7 +1,11 @@ package identity import ( + "fmt" + "reflect" + "strings" "testing" + "time" "github.com/gofrs/uuid" @@ -25,3 +29,75 @@ func TestNewVerifiableEmailAddress(t *testing.T) { assert.Equal(t, iid, a.IdentityID) assert.Equal(t, uuid.Nil, a.ID) } + +var tagsIgnoredForHashing = map[string]struct{}{ + "id": {}, + "created_at": {}, + "updated_at": {}, + "verified_at": {}, +} + +func reflectiveHash(record any) string { + var ( + val = reflect.ValueOf(record) + typ = reflect.TypeOf(record) + values = []string{} + ) + for i := 0; i < val.NumField(); i++ { + dbTag, ok := typ.Field(i).Tag.Lookup("db") + if !ok { + continue + } + if _, ignore := tagsIgnoredForHashing[dbTag]; ignore { + continue + } + if !val.Field(i).CanInterface() { + continue + } + values = append(values, fmt.Sprintf("%v", val.Field(i).Interface())) + } + return strings.Join(values, "|") +} + +// TestVerifiableAddress_Hash tests that the hash considers all fields that are +// written to the database (ignoring some well-known fields like the ID or +// timestamps). +func TestVerifiableAddress_Hash(t *testing.T) { + now := sqlxx.NullTime(time.Now()) + cases := []struct { + name string + a VerifiableAddress + }{ + { + name: "full fields", + a: VerifiableAddress{ + ID: x.NewUUID(), + Value: "foo@bar.me", + Verified: false, + Via: AddressTypeEmail, + Status: VerifiableAddressStatusPending, + VerifiedAt: &now, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + IdentityID: x.NewUUID(), + NID: x.NewUUID(), + }, + }, { + name: "empty fields", + a: VerifiableAddress{}, + }, { + name: "constructor", + a: *NewVerifiableEmailAddress("foo@ory.sh", x.NewUUID()), + }, + } + + for _, tc := range cases { + t.Run("case="+tc.name, func(t *testing.T) { + assert.Equal(t, + reflectiveHash(tc.a), + tc.a.Hash(), + ) + }) + } + +} diff --git a/persistence/sql/persister_identity.go b/persistence/sql/persister_identity.go index c4230fd2054d..d6bdba318cca 100644 --- a/persistence/sql/persister_identity.go +++ b/persistence/sql/persister_identity.go @@ -208,7 +208,7 @@ func updateAssociation[T interface { Order("id ASC"). All(&inDB); err != nil { - return err + return sqlcon.HandleError(err) } newAssocs := make(map[string]*T) @@ -227,7 +227,7 @@ func updateAssociation[T interface { newAssocs[h] = nil // Ignore associations that are already in the db. } else { if err := p.GetConnection(ctx).Destroy(a); err != nil { - return err + return sqlcon.HandleError(err) } } } @@ -235,7 +235,7 @@ func updateAssociation[T interface { for _, a := range newAssocs { if a != nil { if err := p.GetConnection(ctx).Create(a); err != nil { - return err + return sqlcon.HandleError(err) } } } @@ -426,7 +426,7 @@ func (p *Persister) UpdateIdentity(ctx context.Context, i *identity.Identity) er new(identity.Credentials).TableName(ctx)), i.ID, p.NetworkID(ctx)).Exec(); err != nil { - return err + return sqlcon.HandleError(err) } if err := p.update(WithTransaction(ctx, tx), i); err != nil {