Skip to content

Commit

Permalink
fix(identity): slow query performance on MySQL
Browse files Browse the repository at this point in the history
Closes #2278
  • Loading branch information
aeneasr committed Mar 7, 2022
1 parent 0833321 commit 731b3c7
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 12 deletions.
35 changes: 26 additions & 9 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/base64"
"fmt"
"github.com/ory/x/randx"
"strconv"
"strings"
"testing"
Expand Down Expand Up @@ -567,10 +568,25 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
require.Equal(t, sqlcon.ErrNoRows, errorsx.Cause(err))
})

transform := func(k int, value string) string {
switch k % 5 {
case 0:
value = strings.ToLower(value)
case 1:
value = strings.ToUpper(value)
case 2:
value = " " + value
case 3:
value = value + " "
}
return value
}

t.Run("case=create and find", func(t *testing.T) {
addresses := make([]identity.VerifiableAddress, 15)
for k := range addresses {
addresses[k] = createIdentityWithAddresses(t, "recovery.TestPersister.Create"+strconv.Itoa(k)+"@ory.sh")
value := randx.MustString(16, randx.AlphaLowerNum) + "@ory.sh"
addresses[k] = createIdentityWithAddresses(t, transform(k, value))
require.NotEmpty(t, addresses[k].ID)
}

Expand All @@ -579,19 +595,20 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
actual.UpdatedAt = actual.UpdatedAt.UTC().Truncate(time.Hour * 24)
expected.CreatedAt = expected.CreatedAt.UTC().Truncate(time.Hour * 24)
expected.UpdatedAt = expected.UpdatedAt.UTC().Truncate(time.Hour * 24)
expected.Value = strings.TrimSpace(strings.ToLower(expected.Value))
assert.EqualValues(t, expected, actual)
}

for k, expected := range addresses {
t.Run("method=FindVerifiableAddressByValue", func(t *testing.T) {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
actual, err := p.FindVerifiableAddressByValue(ctx, expected.Via, expected.Value)
actual, err := p.FindVerifiableAddressByValue(ctx, expected.Via, transform(k+1, expected.Value))
require.NoError(t, err)
compare(t, expected, *actual)

t.Run("not if on another network", func(t *testing.T) {
_, p := testhelpers.NewNetwork(t, ctx, p)
_, err := p.FindVerifiableAddressByValue(ctx, expected.Via, expected.Value)
_, err := p.FindVerifiableAddressByValue(ctx, expected.Via, transform(k+1, expected.Value))
require.ErrorIs(t, err, sqlcon.ErrNoRows)
})
})
Expand All @@ -600,9 +617,9 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
})

t.Run("case=update", func(t *testing.T) {
address := createIdentityWithAddresses(t, "[email protected]")
address := createIdentityWithAddresses(t, "[email protected] ")

address.Value = "new-code"
address.Value = "new-codE "
require.NoError(t, p.UpdateVerifiableAddress(ctx, &address))

t.Run("not if on another network", func(t *testing.T) {
Expand Down Expand Up @@ -648,7 +665,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
actual, err := p.FindVerifiableAddressByValue(ctx, identity.VerifiableAddressTypeEmail, "[email protected]")
require.NoError(t, err)
assert.Equal(t, identity.VerifiableAddressTypeEmail, actual.Via)
assert.Equal(t, "verification.TestPersister.Update-Identity[email protected]", actual.Value)
assert.Equal(t, "verification.testpersister.update-identity[email protected]", actual.Value)

t.Run("can not find if on another network", func(t *testing.T) {
_, p := testhelpers.NewNetwork(t, ctx, p)
Expand Down Expand Up @@ -690,7 +707,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
actual, err := p.FindVerifiableAddressByValue(ctx, identity.VerifiableAddressTypeEmail, strings.ToUpper("verification.TestPersister.Update-Identity-case-insensitive-next@ory.sh"))
require.NoError(t, err)
assert.Equal(t, identity.VerifiableAddressTypeEmail, actual.Via)
assert.Equal(t, "verification.TestPersister.Update-Identity[email protected]", actual.Value)
assert.Equal(t, "verification.testpersister.update-identity[email protected]", actual.Value)

t.Run("can not find if on another network", func(t *testing.T) {
_, p := testhelpers.NewNetwork(t, ctx, p)
Expand Down Expand Up @@ -775,7 +792,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
actual, err := p.FindRecoveryAddressByValue(ctx, identity.RecoveryAddressTypeEmail, "[email protected]")
require.NoError(t, err)
assert.Equal(t, identity.RecoveryAddressTypeEmail, actual.Via)
assert.Equal(t, "recovery.TestPersister.Update[email protected]", actual.Value)
assert.Equal(t, "recovery.testpersister.update[email protected]", actual.Value)

t.Run("can not find if on another network", func(t *testing.T) {
_, p := testhelpers.NewNetwork(t, ctx, p)
Expand Down Expand Up @@ -811,7 +828,7 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
actual, err := p.FindRecoveryAddressByValue(ctx, identity.RecoveryAddressTypeEmail, strings.ToUpper("[email protected]"))
require.NoError(t, err)
assert.Equal(t, identity.RecoveryAddressTypeEmail, actual.Via)
assert.Equal(t, "recovery.TestPersister.Update[email protected]", actual.Value)
assert.Equal(t, "recovery.testpersister.update[email protected]", actual.Value)

t.Run("can not find if on another network", func(t *testing.T) {
_, p := testhelpers.NewNetwork(t, ctx, p)
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
UPDATE identity_recovery_addresses SET value = LOWER(value) WHERE TRUE;
UPDATE identity_verification_addresses SET value = LOWER(value) WHERE TRUE;
13 changes: 10 additions & 3 deletions persistence/sql/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ func (p *Persister) ListRecoveryAddresses(ctx context.Context, page, itemsPerPag
return a, err
}

func stringToLowerTrim(match string) string {
return strings.ToLower(strings.TrimSpace(match))
}

func (p *Persister) normalizeIdentifier(ct identity.CredentialsType, match string) string {
switch ct {
case identity.CredentialsTypeLookup:
Expand All @@ -61,7 +65,7 @@ func (p *Persister) normalizeIdentifier(ct identity.CredentialsType, match strin
case identity.CredentialsTypePassword:
fallthrough
case identity.CredentialsTypeWebAuthn:
return strings.ToLower(strings.TrimSpace(match))
return stringToLowerTrim(match)
}
return match
}
Expand Down Expand Up @@ -179,6 +183,7 @@ func (p *Persister) createVerifiableAddresses(ctx context.Context, i *identity.I
for k := range i.VerifiableAddresses {
i.VerifiableAddresses[k].IdentityID = i.ID
i.VerifiableAddresses[k].NID = corp.ContextualizeNID(ctx, p.nid)
i.VerifiableAddresses[k].Value = stringToLowerTrim(i.VerifiableAddresses[k].Value)
if err := p.GetConnection(ctx).Create(&i.VerifiableAddresses[k]); err != nil {
return err
}
Expand All @@ -190,6 +195,7 @@ func (p *Persister) createRecoveryAddresses(ctx context.Context, i *identity.Ide
for k := range i.RecoveryAddresses {
i.RecoveryAddresses[k].IdentityID = i.ID
i.RecoveryAddresses[k].NID = corp.ContextualizeNID(ctx, p.nid)
i.RecoveryAddresses[k].Value = stringToLowerTrim(i.RecoveryAddresses[k].Value)
if err := p.GetConnection(ctx).Create(&i.RecoveryAddresses[k]); err != nil {
return err
}
Expand Down Expand Up @@ -423,7 +429,7 @@ func (p *Persister) GetIdentityConfidential(ctx context.Context, id uuid.UUID) (

func (p *Persister) FindVerifiableAddressByValue(ctx context.Context, via identity.VerifiableAddressType, value string) (*identity.VerifiableAddress, error) {
var address identity.VerifiableAddress
if err := p.GetConnection(ctx).Where("nid = ? AND via = ? AND LOWER(value) = ?", corp.ContextualizeNID(ctx, p.nid), via, strings.ToLower(value)).First(&address); err != nil {
if err := p.GetConnection(ctx).Where("nid = ? AND via = ? AND value = ?", corp.ContextualizeNID(ctx, p.nid), via, stringToLowerTrim(value)).First(&address); err != nil {
return nil, sqlcon.HandleError(err)
}

Expand All @@ -432,7 +438,7 @@ func (p *Persister) FindVerifiableAddressByValue(ctx context.Context, via identi

func (p *Persister) FindRecoveryAddressByValue(ctx context.Context, via identity.RecoveryAddressType, value string) (*identity.RecoveryAddress, error) {
var address identity.RecoveryAddress
if err := p.GetConnection(ctx).Where("nid = ? AND via = ? AND LOWER(value) = ?", corp.ContextualizeNID(ctx, p.nid), via, strings.ToLower(value)).First(&address); err != nil {
if err := p.GetConnection(ctx).Where("nid = ? AND via = ? AND value = ?", corp.ContextualizeNID(ctx, p.nid), via, stringToLowerTrim(value)).First(&address); err != nil {
return nil, sqlcon.HandleError(err)
}

Expand Down Expand Up @@ -471,6 +477,7 @@ func (p *Persister) VerifyAddress(ctx context.Context, code string) error {

func (p *Persister) UpdateVerifiableAddress(ctx context.Context, address *identity.VerifiableAddress) error {
address.NID = corp.ContextualizeNID(ctx, p.nid)
address.Value = stringToLowerTrim(address.Value)
return p.update(ctx, address)
}

Expand Down

0 comments on commit 731b3c7

Please sign in to comment.