Skip to content

Commit

Permalink
refactor: distinguish between first and multi factor credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Mar 7, 2022
1 parent b0488ef commit 8de9d01
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 108 deletions.
2 changes: 1 addition & 1 deletion selfservice/strategy/lookup/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ func (s *Strategy) identityHasLookup(ctx context.Context, id uuid.UUID) (bool, e
return false, err
}

count, err := s.CountActiveCredentials(confidential.Credentials)
count, err := s.CountActiveFirstFactorCredentials(confidential.Credentials)
if err != nil {
return false, err
}
Expand Down
6 changes: 5 additions & 1 deletion selfservice/strategy/lookup/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ func NewStrategy(d registrationStrategyDependencies) *Strategy {
}
}

func (s *Strategy) CountActiveCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
return 0, nil
}

func (s *Strategy) CountActiveMultiFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
for _, c := range cc {
if c.Type == s.ID() && len(c.Config) > 0 {
var conf CredentialsConfig
Expand Down
106 changes: 57 additions & 49 deletions selfservice/strategy/lookup/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,57 +13,65 @@ import (
"github.com/ory/kratos/internal"
)

func TestCountActiveCredentials(t *testing.T) {
func TestCountActiveFirstFactorCredentials(t *testing.T) {
_, reg := internal.NewFastRegistryWithMocks(t)
strategy := lookup.NewStrategy(reg)

for k, tc := range []struct {
in identity.CredentialsCollection
expected int
}{
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Config: []byte{},
}},
expected: 0,
},
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Config: []byte(`{"recovery_codes": []}`),
}},
expected: 0,
},
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Identifiers: []string{"foo"},
Config: []byte(`{"recovery_codes": [{}]}`),
}},
expected: 1,
},
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Config: []byte(`{}`),
}},
expected: 0,
},
{
in: identity.CredentialsCollection{{}, {}},
expected: 0,
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
cc := map[identity.CredentialsType]identity.Credentials{}
for _, c := range tc.in {
cc[c.Type] = c
}
t.Run("first factor", func(t *testing.T) {
actual, err := strategy.CountActiveFirstFactorCredentials(nil)
require.NoError(t, err)
assert.Equal(t, 0, actual)
})

actual, err := strategy.CountActiveCredentials(cc)
require.NoError(t, err)
assert.Equal(t, tc.expected, actual)
})
}
t.Run("multi factor", func(t *testing.T) {
for k, tc := range []struct {
in identity.CredentialsCollection
expected int
}{
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Config: []byte{},
}},
expected: 0,
},
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Config: []byte(`{"recovery_codes": []}`),
}},
expected: 0,
},
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Identifiers: []string{"foo"},
Config: []byte(`{"recovery_codes": [{}]}`),
}},
expected: 1,
},
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Config: []byte(`{}`),
}},
expected: 0,
},
{
in: identity.CredentialsCollection{{}, {}},
expected: 0,
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
cc := map[identity.CredentialsType]identity.Credentials{}
for _, c := range tc.in {
cc[c.Type] = c
}

actual, err := strategy.CountActiveMultiFactorCredentials(cc)
require.NoError(t, err)
assert.Equal(t, tc.expected, actual)
})
}
})
}
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type dependencies interface {
identity.ValidationProvider
identity.PrivilegedPoolProvider
identity.ActiveCredentialsCounterStrategyProvider
identity.ManagementProvider

session.ManagementProvider
session.HandlerProvider
Expand Down Expand Up @@ -112,7 +113,7 @@ type authCodeContainer struct {
Traits json.RawMessage `json:"traits"`
}

func (s *Strategy) CountActiveCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
for _, c := range cc {
if c.Type == s.ID() && gjson.ValidBytes(c.Config) {
var conf identity.CredentialsOIDC
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/oidc/strategy_settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (s *Strategy) linkedProviders(ctx context.Context, r *http.Request, conf *C
return nil, errors.WithStack(err)
}

count, err := s.d.CountActiveCredentials(ctx, confidential)
count, err := s.d.IdentityManager().CountActiveFirstFactorCredentials(ctx, confidential)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions selfservice/strategy/oidc/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ func TestStrategy(t *testing.T) {
})
}

func TestCountActiveCredentials(t *testing.T) {
func TestCountActiveFirstFactorCredentials(t *testing.T) {
_, reg := internal.NewFastRegistryWithMocks(t)
strategy := oidc.NewStrategy(reg)

Expand Down Expand Up @@ -586,7 +586,7 @@ func TestCountActiveCredentials(t *testing.T) {
for _, v := range tc.in {
in[v.Type] = v
}
actual, err := strategy.CountActiveCredentials(in)
actual, err := strategy.CountActiveFirstFactorCredentials(in)
require.NoError(t, err)
assert.Equal(t, tc.expected, actual)
})
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/password/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func NewStrategy(d registrationStrategyDependencies) *Strategy {
}
}

func (s *Strategy) CountActiveCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
for _, c := range cc {
if c.Type == s.ID() && len(c.Config) > 0 {
var conf identity.CredentialsPassword
Expand Down
4 changes: 2 additions & 2 deletions selfservice/strategy/password/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"github.com/ory/kratos/selfservice/strategy/password"
)

func TestCountActiveCredentials(t *testing.T) {
func TestCountActiveFirstFactorCredentials(t *testing.T) {
_, reg := internal.NewFastRegistryWithMocks(t)
strategy := password.NewStrategy(reg)

Expand Down Expand Up @@ -91,7 +91,7 @@ func TestCountActiveCredentials(t *testing.T) {
cc[c.Type] = c
}

actual, err := strategy.CountActiveCredentials(cc)
actual, err := strategy.CountActiveFirstFactorCredentials(cc)
require.NoError(t, err)
assert.Equal(t, tc.expected, actual)
})
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/totp/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func (s *Strategy) identityHasTOTP(ctx context.Context, id uuid.UUID) (bool, err
return false, err
}

count, err := s.CountActiveCredentials(confidential.Credentials)
count, err := s.CountActiveFirstFactorCredentials(confidential.Credentials)
if err != nil {
return false, err
}
Expand Down
6 changes: 5 additions & 1 deletion selfservice/strategy/totp/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ func NewStrategy(d registrationStrategyDependencies) *Strategy {
}
}

func (s *Strategy) CountActiveCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
func (s *Strategy) CountActiveFirstFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
return 0, nil
}

func (s *Strategy) CountActiveMultiFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
for _, c := range cc {
if c.Type == s.ID() && len(c.Config) > 0 {
var conf CredentialsConfig
Expand Down
104 changes: 56 additions & 48 deletions selfservice/strategy/totp/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,53 +21,61 @@ func TestCountActiveCredentials(t *testing.T) {
key, err := totp.NewKey(context.Background(), "foo", reg)
require.NoError(t, err)

for k, tc := range []struct {
in identity.CredentialsCollection
expected int
}{
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Config: []byte{},
}},
expected: 0,
},
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Config: []byte(`{"totp_url": ""}`),
}},
expected: 0,
},
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Identifiers: []string{"foo"},
Config: []byte(`{"totp_url": "` + key.URL() + `"}`),
}},
expected: 1,
},
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Config: []byte(`{}`),
}},
expected: 0,
},
{
in: identity.CredentialsCollection{{}, {}},
expected: 0,
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
cc := map[identity.CredentialsType]identity.Credentials{}
for _, c := range tc.in {
cc[c.Type] = c
}
t.Run("first factor", func(t *testing.T) {
actual, err := strategy.CountActiveFirstFactorCredentials(nil)
require.NoError(t, err)
assert.Equal(t, 0, actual)
})

actual, err := strategy.CountActiveCredentials(cc)
require.NoError(t, err)
assert.Equal(t, tc.expected, actual)
})
}
t.Run("multi factor", func(t *testing.T) {
for k, tc := range []struct {
in identity.CredentialsCollection
expected int
}{
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Config: []byte{},
}},
expected: 0,
},
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Config: []byte(`{"totp_url": ""}`),
}},
expected: 0,
},
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Identifiers: []string{"foo"},
Config: []byte(`{"totp_url": "` + key.URL() + `"}`),
}},
expected: 1,
},
{
in: identity.CredentialsCollection{{
Type: strategy.ID(),
Config: []byte(`{}`),
}},
expected: 0,
},
{
in: identity.CredentialsCollection{{}, {}},
expected: 0,
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
cc := map[identity.CredentialsType]identity.Credentials{}
for _, c := range tc.in {
cc[c.Type] = c
}

actual, err := strategy.CountActiveMultiFactorCredentials(cc)
require.NoError(t, err)
assert.Equal(t, tc.expected, actual)
})
}
})
}

0 comments on commit 8de9d01

Please sign in to comment.