Skip to content

Commit

Permalink
fix: restore original structure
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Jan 4, 2023
1 parent 73f8313 commit 75387ec
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 171 deletions.
24 changes: 12 additions & 12 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,28 @@ type (
r managerDependencies
}

managerOptions struct {
ManagerOptions struct {
ExposeValidationErrors bool
AllowWriteProtectedTraits bool
}

ManagerOption func(*managerOptions)
ManagerOption func(*ManagerOptions)
)

func NewManager(r managerDependencies) *Manager {
return &Manager{r: r}
}

func ManagerExposeValidationErrorsForInternalTypeAssertion(options *managerOptions) {
func ManagerExposeValidationErrorsForInternalTypeAssertion(options *ManagerOptions) {
options.ExposeValidationErrors = true
}

func ManagerAllowWriteProtectedTraits(options *managerOptions) {
func ManagerAllowWriteProtectedTraits(options *ManagerOptions) {
options.AllowWriteProtectedTraits = true
}

func newManagerOptions(opts []ManagerOption) *managerOptions {
var o managerOptions
func newManagerOptions(opts []ManagerOption) *ManagerOptions {
var o ManagerOptions
for _, f := range opts {
f(&o)
}
Expand All @@ -80,14 +80,14 @@ func (m *Manager) Create(ctx context.Context, i *Identity, opts ...ManagerOption
}

o := newManagerOptions(opts)
if err := m.validate(ctx, i, o); err != nil {
if err := m.ValidateIdentity(ctx, i, o); err != nil {
return err
}

return m.r.IdentityPool().(PrivilegedPool).CreateIdentity(ctx, i)
}

func (m *Manager) requiresPrivilegedAccess(ctx context.Context, original, updated *Identity, o *managerOptions) (err error) {
func (m *Manager) requiresPrivilegedAccess(ctx context.Context, original, updated *Identity, o *ManagerOptions) (err error) {
_, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.requiresPrivilegedAccess")
defer otelx.End(span, &err)

Expand All @@ -114,7 +114,7 @@ func (m *Manager) Update(ctx context.Context, updated *Identity, opts ...Manager
defer otelx.End(span, &err)

o := newManagerOptions(opts)
if err := m.validate(ctx, updated, o); err != nil {
if err := m.ValidateIdentity(ctx, updated, o); err != nil {
return err
}

Expand Down Expand Up @@ -145,7 +145,7 @@ func (m *Manager) UpdateSchemaID(ctx context.Context, id uuid.UUID, schemaID str
}

original.SchemaID = schemaID
if err := m.validate(ctx, original, o); err != nil {
if err := m.ValidateIdentity(ctx, original, o); err != nil {
return err
}

Expand All @@ -165,7 +165,7 @@ func (m *Manager) SetTraits(ctx context.Context, id uuid.UUID, traits Traits, op
// original is used to check whether protected traits were modified
updated := deepcopy.Copy(original).(*Identity)
updated.Traits = traits
if err := m.validate(ctx, updated, o); err != nil {
if err := m.ValidateIdentity(ctx, updated, o); err != nil {
return nil, err
}

Expand All @@ -188,7 +188,7 @@ func (m *Manager) UpdateTraits(ctx context.Context, id uuid.UUID, traits Traits,
return m.r.IdentityPool().(PrivilegedPool).UpdateIdentity(ctx, updated)
}

func (m *Manager) validate(ctx context.Context, i *Identity, o *managerOptions) (err error) {
func (m *Manager) ValidateIdentity(ctx context.Context, i *Identity, o *ManagerOptions) (err error) {
ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.validate")
defer otelx.End(span, &err)

Expand Down
241 changes: 120 additions & 121 deletions persistence/sql/persister_identity_test.go → identity/test/pool.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package sql_test
package test

import (
"context"
Expand All @@ -12,9 +12,6 @@ import (
"testing"
"time"

"github.com/ory/kratos/driver"
"github.com/ory/kratos/identity"

"github.com/ory/x/randx"

"github.com/tidwall/gjson"
Expand All @@ -23,6 +20,9 @@ import (

"github.com/ory/kratos/internal/testhelpers"

"github.com/ory/kratos/identity"
"github.com/ory/kratos/persistence"

"github.com/bxcodec/faker/v3"

"github.com/ory/x/sqlxx"
Expand All @@ -41,159 +41,158 @@ import (
"github.com/ory/kratos/x"
)

func (suite *PersisterTestSuite) TestIdentityExpand() {
ctx := context.Background()
expandSchema := schema.Schema{
ID: "expandSchema",
URL: urlx.ParseOrPanic("file://./stub/expand.schema.json"),
RawURL: "file://./stub/expand.schema.json",
}
func TestPool(ctx context.Context, conf *config.Config, p interface {
persistence.Persister
}, m *identity.Manager) func(t *testing.T) {
return func(t *testing.T) {
nid, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p)

suite.forAllConnections(func(t *testing.T, reg *driver.RegistryDefault) {
reg.Config().MustSet(ctx, config.ViperKeyIdentitySchemas, []config.Schema{
{
ID: expandSchema.ID,
URL: expandSchema.RawURL,
},
})
t.Run("case=expand", func(t *testing.T) {
expandSchema := schema.Schema{
ID: "expandSchema",
URL: urlx.ParseOrPanic("file://./stub/expand.schema.json"),
RawURL: "file://./stub/expand.schema.json",
}

require.NoError(t, reg.Persister().GetConnection(ctx).RawQuery("DELETE FROM identities WHERE true").Exec())
conf.MustSet(ctx, config.ViperKeyIdentitySchemas, []config.Schema{
{
ID: expandSchema.ID,
URL: expandSchema.RawURL,
},
})

expected := identity.NewIdentity(expandSchema.ID)
expected.Traits = identity.Traits(`{"email":"` + uuid.Must(uuid.NewV4()).String() + "@ory.sh" + `","name":"john doe"}`)
require.NoError(t, reg.IdentityManager().Create(ctx, expected))
require.NoError(t, identity.UpgradeCredentials(expected))
require.NoError(t, p.GetConnection(ctx).RawQuery("DELETE FROM identities WHERE true").Exec())
t.Cleanup(func() {
require.NoError(t, p.GetConnection(ctx).RawQuery("DELETE FROM identities WHERE true").Exec())
})

assert.NotEmpty(t, expected.RecoveryAddresses)
assert.NotEmpty(t, expected.VerifiableAddresses)
assert.NotEmpty(t, expected.Credentials)
assert.NotEqual(t, uuid.Nil, expected.RecoveryAddresses[0].ID)
assert.NotEqual(t, uuid.Nil, expected.VerifiableAddresses[0].ID)
expected := identity.NewIdentity(expandSchema.ID)
expected.Traits = identity.Traits(`{"email":"` + uuid.Must(uuid.NewV4()).String() + "@ory.sh" + `","name":"john doe"}`)
require.NoError(t, m.ValidateIdentity(ctx, expected, new(identity.ManagerOptions)))
require.NoError(t, p.CreateIdentity(ctx, expected))
require.NoError(t, identity.UpgradeCredentials(expected))

runner := func(t *testing.T, expand sqlxx.Expandables, cb func(*testing.T, *identity.Identity)) {
assertion := func(t *testing.T, actual *identity.Identity) {
assertx.EqualAsJSONExcept(t, expected, actual, []string{
"verifiable_addresses", "recovery_addresses", "updated_at", "created_at", "credentials",
})
cb(t, actual)
}
assert.NotEmpty(t, expected.RecoveryAddresses)
assert.NotEmpty(t, expected.VerifiableAddresses)
assert.NotEmpty(t, expected.Credentials)
assert.NotEqual(t, uuid.Nil, expected.RecoveryAddresses[0].ID)
assert.NotEqual(t, uuid.Nil, expected.VerifiableAddresses[0].ID)

t.Run("find", func(t *testing.T) {
actual, err := reg.Persister().GetIdentity(ctx, expected.ID, expand)
require.NoError(t, err)
assertion(t, actual)
})
runner := func(t *testing.T, expand sqlxx.Expandables, cb func(*testing.T, *identity.Identity)) {
assertion := func(t *testing.T, actual *identity.Identity) {
assertx.EqualAsJSONExcept(t, expected, actual, []string{
"verifiable_addresses", "recovery_addresses", "updated_at", "created_at", "credentials",
})
cb(t, actual)
}

t.Run("list", func(t *testing.T) {
actual, err := reg.Persister().ListIdentities(ctx, expand, 0, 10)
require.NoError(t, err)
require.Len(t, actual, 1)
assertion(t, &actual[0])
})
}
t.Run("find", func(t *testing.T) {
actual, err := p.GetIdentity(ctx, expected.ID, expand)
require.NoError(t, err)
assertion(t, actual)
})

t.Run("list", func(t *testing.T) {
actual, err := p.ListIdentities(ctx, expand, 0, 10)
require.NoError(t, err)
require.Len(t, actual, 1)
assertion(t, &actual[0])
})
}

t.Run("expand=nothing", func(t *testing.T) {
runner(t, identity.ExpandNothing, func(t *testing.T, actual *identity.Identity) {
assert.Empty(t, actual.RecoveryAddresses)
assert.Empty(t, actual.VerifiableAddresses)
assert.Empty(t, actual.Credentials)
assert.Empty(t, actual.InternalCredentials)
t.Run("expand=nothing", func(t *testing.T) {
runner(t, identity.ExpandNothing, func(t *testing.T, actual *identity.Identity) {
assert.Empty(t, actual.RecoveryAddresses)
assert.Empty(t, actual.VerifiableAddresses)
assert.Empty(t, actual.Credentials)
assert.Empty(t, actual.InternalCredentials)
})
})
})

t.Run("expand=credentials", func(t *testing.T) {
runner(t, identity.ExpandCredentials, func(t *testing.T, actual *identity.Identity) {
assert.Empty(t, actual.RecoveryAddresses)
assert.Empty(t, actual.VerifiableAddresses)
t.Run("expand=credentials", func(t *testing.T) {
runner(t, identity.ExpandCredentials, func(t *testing.T, actual *identity.Identity) {
assert.Empty(t, actual.RecoveryAddresses)
assert.Empty(t, actual.VerifiableAddresses)

require.Len(t, actual.InternalCredentials, 2)
require.Len(t, actual.Credentials, 2)
require.Len(t, actual.InternalCredentials, 2)
require.Len(t, actual.Credentials, 2)

assertx.EqualAsJSON(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword])
assertx.EqualAsJSON(t, expected.Credentials[identity.CredentialsTypeWebAuthn], actual.Credentials[identity.CredentialsTypeWebAuthn])
assertx.EqualAsJSON(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword])
assertx.EqualAsJSON(t, expected.Credentials[identity.CredentialsTypeWebAuthn], actual.Credentials[identity.CredentialsTypeWebAuthn])
})
})
})

t.Run("expand=recovery address", func(t *testing.T) {
runner(t, sqlxx.Expandables{identity.ExpandFieldRecoveryAddresses}, func(t *testing.T, actual *identity.Identity) {
assert.Empty(t, actual.Credentials)
assert.Empty(t, actual.InternalCredentials)
assert.Empty(t, actual.VerifiableAddresses)
t.Run("expand=recovery address", func(t *testing.T) {
runner(t, sqlxx.Expandables{identity.ExpandFieldRecoveryAddresses}, func(t *testing.T, actual *identity.Identity) {
assert.Empty(t, actual.Credentials)
assert.Empty(t, actual.InternalCredentials)
assert.Empty(t, actual.VerifiableAddresses)

require.Len(t, actual.RecoveryAddresses, 1)
assertx.EqualAsJSON(t, expected.RecoveryAddresses, actual.RecoveryAddresses)
require.Len(t, actual.RecoveryAddresses, 1)
assertx.EqualAsJSON(t, expected.RecoveryAddresses, actual.RecoveryAddresses)
})
})
})

t.Run("expand=verification address", func(t *testing.T) {
runner(t, sqlxx.Expandables{identity.ExpandFieldVerifiableAddresses}, func(t *testing.T, actual *identity.Identity) {
assert.Empty(t, actual.Credentials)
assert.Empty(t, actual.InternalCredentials)
assert.Empty(t, actual.RecoveryAddresses)
t.Run("expand=verification address", func(t *testing.T) {
runner(t, sqlxx.Expandables{identity.ExpandFieldVerifiableAddresses}, func(t *testing.T, actual *identity.Identity) {
assert.Empty(t, actual.Credentials)
assert.Empty(t, actual.InternalCredentials)
assert.Empty(t, actual.RecoveryAddresses)

require.Len(t, actual.VerifiableAddresses, 1)
assertx.EqualAsJSON(t, expected.VerifiableAddresses, actual.VerifiableAddresses)
require.Len(t, actual.VerifiableAddresses, 1)
assertx.EqualAsJSON(t, expected.VerifiableAddresses, actual.VerifiableAddresses)
})
})
})

t.Run("expand=default", func(t *testing.T) {
runner(t, identity.ExpandDefault, func(t *testing.T, actual *identity.Identity) {
t.Run("expand=default", func(t *testing.T) {
runner(t, identity.ExpandDefault, func(t *testing.T, actual *identity.Identity) {

assert.Empty(t, actual.Credentials)
assert.Empty(t, actual.InternalCredentials)
assert.Empty(t, actual.Credentials)
assert.Empty(t, actual.InternalCredentials)

require.Len(t, actual.RecoveryAddresses, 1)
assertx.EqualAsJSON(t, expected.RecoveryAddresses, actual.RecoveryAddresses)
require.Len(t, actual.RecoveryAddresses, 1)
assertx.EqualAsJSON(t, expected.RecoveryAddresses, actual.RecoveryAddresses)

require.Len(t, actual.VerifiableAddresses, 1)
assertx.EqualAsJSON(t, expected.VerifiableAddresses, actual.VerifiableAddresses)
require.Len(t, actual.VerifiableAddresses, 1)
assertx.EqualAsJSON(t, expected.VerifiableAddresses, actual.VerifiableAddresses)
})
})
})

t.Run("expand=everything", func(t *testing.T) {
runner(t, identity.ExpandEverything, func(t *testing.T, actual *identity.Identity) {
t.Run("expand=everything", func(t *testing.T) {
runner(t, identity.ExpandEverything, func(t *testing.T, actual *identity.Identity) {

require.Len(t, actual.InternalCredentials, 2)
require.Len(t, actual.Credentials, 2)
require.Len(t, actual.InternalCredentials, 2)
require.Len(t, actual.Credentials, 2)

assertx.EqualAsJSON(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword])
assertx.EqualAsJSON(t, expected.Credentials[identity.CredentialsTypeWebAuthn], actual.Credentials[identity.CredentialsTypeWebAuthn])
assertx.EqualAsJSON(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword])
assertx.EqualAsJSON(t, expected.Credentials[identity.CredentialsTypeWebAuthn], actual.Credentials[identity.CredentialsTypeWebAuthn])

require.Len(t, actual.RecoveryAddresses, 1)
assertx.EqualAsJSON(t, expected.RecoveryAddresses, actual.RecoveryAddresses)
require.Len(t, actual.RecoveryAddresses, 1)
assertx.EqualAsJSON(t, expected.RecoveryAddresses, actual.RecoveryAddresses)

require.Len(t, actual.VerifiableAddresses, 1)
assertx.EqualAsJSON(t, expected.VerifiableAddresses, actual.VerifiableAddresses)
require.Len(t, actual.VerifiableAddresses, 1)
assertx.EqualAsJSON(t, expected.VerifiableAddresses, actual.VerifiableAddresses)
})
})
})

t.Run("expand=load", func(t *testing.T) {
runner(t, identity.ExpandNothing, func(t *testing.T, actual *identity.Identity) {
require.NoError(t, reg.Persister().HydrateIdentityAssociations(ctx, actual, identity.ExpandEverything))
t.Run("expand=load", func(t *testing.T) {
runner(t, identity.ExpandNothing, func(t *testing.T, actual *identity.Identity) {
require.NoError(t, p.HydrateIdentityAssociations(ctx, actual, identity.ExpandEverything))

require.Len(t, actual.InternalCredentials, 2)
require.Len(t, actual.Credentials, 2)
require.Len(t, actual.InternalCredentials, 2)
require.Len(t, actual.Credentials, 2)

assertx.EqualAsJSON(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword])
assertx.EqualAsJSON(t, expected.Credentials[identity.CredentialsTypeWebAuthn], actual.Credentials[identity.CredentialsTypeWebAuthn])
assertx.EqualAsJSON(t, expected.Credentials[identity.CredentialsTypePassword], actual.Credentials[identity.CredentialsTypePassword])
assertx.EqualAsJSON(t, expected.Credentials[identity.CredentialsTypeWebAuthn], actual.Credentials[identity.CredentialsTypeWebAuthn])

require.Len(t, actual.RecoveryAddresses, 1)
assertx.EqualAsJSON(t, expected.RecoveryAddresses, actual.RecoveryAddresses)
require.Len(t, actual.RecoveryAddresses, 1)
assertx.EqualAsJSON(t, expected.RecoveryAddresses, actual.RecoveryAddresses)

require.Len(t, actual.VerifiableAddresses, 1)
assertx.EqualAsJSON(t, expected.VerifiableAddresses, actual.VerifiableAddresses)
require.Len(t, actual.VerifiableAddresses, 1)
assertx.EqualAsJSON(t, expected.VerifiableAddresses, actual.VerifiableAddresses)
})
})
})
})
}

func (suite *PersisterTestSuite) TestIdentity() {
ctx := context.Background()
suite.forAllConnections(func(t *testing.T, reg *driver.RegistryDefault) {
_, p := testhelpers.NewNetwork(t, ctx, reg.Persister())
conf := reg.Config()

nid, p := testhelpers.NewNetworkUnlessExisting(t, ctx, p)

exampleServerURL := urlx.ParseOrPanic("http://example.com")
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, exampleServerURL.String())
Expand Down Expand Up @@ -1027,5 +1026,5 @@ func (suite *PersisterTestSuite) TestIdentity() {
require.Len(t, i.Credentials, 1)
assert.Equal(t, "nid1", i.Credentials[m[0].Name].Identifiers[0])
})
})
}
}
Loading

0 comments on commit 75387ec

Please sign in to comment.