Skip to content

Commit

Permalink
feat: add key manager strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
aarmam committed Jan 7, 2022
1 parent 6965b87 commit ff11892
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 6 deletions.
3 changes: 2 additions & 1 deletion driver/registry_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ func (m *RegistrySQL) Init(ctx context.Context) error {
}

if m.C.HsmEnabled() {
m.defaultKeyManager = hsm.NewKeyManager(m.HsmContext())
hardwareKeyManager := hsm.NewKeyManager(m.HsmContext())
m.defaultKeyManager = jwk.NewManagerStrategy(hardwareKeyManager, m.persister)
} else {
m.defaultKeyManager = m.persister
}
Expand Down
3 changes: 1 addition & 2 deletions hsm/manager_hsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ import (
)

type KeyManager struct {
jwk.Manager
sync.RWMutex
Context
}

var _ jwk.Manager = &KeyManager{}

var ErrPreGeneratedKeys = &fosite.RFC6749Error{
CodeField: http.StatusBadRequest,
ErrorField: http.StatusText(http.StatusBadRequest),
Expand Down
2 changes: 1 addition & 1 deletion hsm/manager_hsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestDefaultKeyManager_HsmEnabled(t *testing.T) {
reg.WithHsmContext(mockHsmContext)
err := reg.Init(context.Background())
assert.NoError(t, err)
assert.IsType(t, &hsm.KeyManager{}, reg.KeyManager())
assert.IsType(t, &jwk.ManagerStrategy{}, reg.KeyManager())
assert.IsType(t, &sql.Persister{}, reg.SoftwareKeyManager())
}

Expand Down
3 changes: 1 addition & 2 deletions hsm/manager_nohsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@ type Context interface {
}

type KeyManager struct {
jwk.Manager
sync.RWMutex
Context
}

var _ jwk.Manager = &KeyManager{}

var ErrOpSysNotSupported = errors.New("Hardware Security Module is not supported on this platform.")

func NewContext(c *config.Provider, l *logrusx.Logger) Context {
Expand Down
86 changes: 86 additions & 0 deletions jwk/manager_strategy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package jwk

import (
"context"

"github.com/pkg/errors"
"gopkg.in/square/go-jose.v2"

"github.com/ory/hydra/x"
)

type ManagerStrategy struct {
hardwareKeyManager Manager
softwareKeyManager Manager
}

func NewManagerStrategy(hardwareKeyManager Manager, softwareKeyManager Manager) *ManagerStrategy {
return &ManagerStrategy{
hardwareKeyManager: hardwareKeyManager,
softwareKeyManager: softwareKeyManager,
}
}

func (m ManagerStrategy) GenerateAndPersistKeySet(ctx context.Context, set, kid, alg, use string) (*jose.JSONWebKeySet, error) {
return m.hardwareKeyManager.GenerateAndPersistKeySet(ctx, set, kid, alg, use)
}

func (m ManagerStrategy) AddKey(ctx context.Context, set string, key *jose.JSONWebKey) error {
return m.softwareKeyManager.AddKey(ctx, set, key)
}

func (m ManagerStrategy) AddKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error {
return m.softwareKeyManager.AddKeySet(ctx, set, keys)
}

func (m ManagerStrategy) UpdateKey(ctx context.Context, set string, key *jose.JSONWebKey) error {
return m.softwareKeyManager.UpdateKey(ctx, set, key)
}

func (m ManagerStrategy) UpdateKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error {
return m.softwareKeyManager.UpdateKeySet(ctx, set, keys)
}

func (m ManagerStrategy) GetKey(ctx context.Context, set, kid string) (*jose.JSONWebKeySet, error) {
keySet, err := m.hardwareKeyManager.GetKey(ctx, set, kid)
if err != nil && !errors.Is(err, x.ErrNotFound) {
return nil, err
} else if keySet != nil {
return keySet, nil
} else {
return m.softwareKeyManager.GetKey(ctx, set, kid)
}
}

func (m ManagerStrategy) GetKeySet(ctx context.Context, set string) (*jose.JSONWebKeySet, error) {
keySet, err := m.hardwareKeyManager.GetKeySet(ctx, set)
if err != nil && !errors.Is(err, x.ErrNotFound) {
return nil, err
} else if keySet != nil {
return keySet, nil
} else {
return m.softwareKeyManager.GetKeySet(ctx, set)
}
}

func (m ManagerStrategy) DeleteKey(ctx context.Context, set, kid string) error {
err := m.hardwareKeyManager.DeleteKey(ctx, set, kid)
if err != nil && !errors.Is(err, x.ErrNotFound) {
return err
} else if errors.Is(err, x.ErrNotFound) {
return m.softwareKeyManager.DeleteKey(ctx, set, kid)
} else {
return nil
}
}

func (m ManagerStrategy) DeleteKeySet(ctx context.Context, set string) error {
err := m.hardwareKeyManager.DeleteKeySet(ctx, set)
if err != nil && !errors.Is(err, x.ErrNotFound) {
return err
} else if errors.Is(err, x.ErrNotFound) {
return m.softwareKeyManager.DeleteKeySet(ctx, set)
} else {
return nil
}
}
205 changes: 205 additions & 0 deletions jwk/manager_strategy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package jwk_test

import (
"testing"

"github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"gopkg.in/square/go-jose.v2"

"github.com/ory/hydra/jwk"
"github.com/ory/hydra/x"
)

func TestKeyManagerStrategy(t *testing.T) {
ctrl := gomock.NewController(t)
softwareKeyManager := NewMockManager(ctrl)
hardwareKeyManager := NewMockManager(ctrl)
keyManager := jwk.NewManagerStrategy(hardwareKeyManager, softwareKeyManager)
defer ctrl.Finish()
hwKeySet := &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{{
KeyID: "hwKeyID",
}},
}
swKeySet := &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{{
KeyID: "swKeyID",
}},
}

t.Run("GenerateAndPersistKeySet_WithResult", func(t *testing.T) {
hardwareKeyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1"), gomock.Any(), gomock.Any()).Return(hwKeySet, nil)
resultKeySet, err := keyManager.GenerateAndPersistKeySet(nil, "set1", "kid1", "RS256", "sig")
assert.NoError(t, err)
assert.Equal(t, hwKeySet, resultKeySet)
})

t.Run("GenerateAndPersistKeySet_WithError", func(t *testing.T) {
hardwareKeyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1"), gomock.Any(), gomock.Any()).Return(nil, errors.New("test"))
resultKeySet, err := keyManager.GenerateAndPersistKeySet(nil, "set1", "kid1", "RS256", "sig")
assert.Error(t, err, "test")
assert.Nil(t, resultKeySet)
})

t.Run("AddKey", func(t *testing.T) {
softwareKeyManager.EXPECT().AddKey(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(nil)
err := keyManager.AddKey(nil, "set1", nil)
assert.NoError(t, err)
})

t.Run("AddKey_WithError", func(t *testing.T) {
softwareKeyManager.EXPECT().AddKey(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(errors.New("test"))
err := keyManager.AddKey(nil, "set1", nil)
assert.Error(t, err, "test")
})

t.Run("AddKeySet", func(t *testing.T) {
softwareKeyManager.EXPECT().AddKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(nil)
err := keyManager.AddKeySet(nil, "set1", nil)
assert.NoError(t, err)
})

t.Run("AddKeySet_WithError", func(t *testing.T) {
softwareKeyManager.EXPECT().AddKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(errors.New("test"))
err := keyManager.AddKeySet(nil, "set1", nil)
assert.Error(t, err, "test")
})

t.Run("UpdateKey", func(t *testing.T) {
softwareKeyManager.EXPECT().UpdateKey(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(nil)
err := keyManager.UpdateKey(nil, "set1", nil)
assert.NoError(t, err)
})

t.Run("UpdateKey_WithError", func(t *testing.T) {
softwareKeyManager.EXPECT().UpdateKey(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(errors.New("test"))
err := keyManager.UpdateKey(nil, "set1", nil)
assert.Error(t, err, "test")
})

t.Run("UpdateKeySet", func(t *testing.T) {
softwareKeyManager.EXPECT().UpdateKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(nil)
err := keyManager.UpdateKeySet(nil, "set1", nil)
assert.NoError(t, err)
})

t.Run("UpdateKeySet_WithError", func(t *testing.T) {
softwareKeyManager.EXPECT().UpdateKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(errors.New("test"))
err := keyManager.UpdateKeySet(nil, "set1", nil)
assert.Error(t, err, "test")
})

t.Run("GetKey_WithResultFromHardwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(hwKeySet, nil)
resultKeySet, err := keyManager.GetKey(nil, "set1", "kid1")
assert.NoError(t, err)
assert.Equal(t, hwKeySet, resultKeySet)
})

t.Run("GetKey_WithErrorFromHardwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil, errors.New("test"))
resultKeySet, err := keyManager.GetKey(nil, "set1", "kid1")
assert.Error(t, err, "test")
assert.Nil(t, resultKeySet)
})

t.Run("GetKey_WithErrNotFoundFromHardwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil, errors.WithStack(x.ErrNotFound))
softwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(swKeySet, nil)
resultKeySet, err := keyManager.GetKey(nil, "set1", "kid1")
assert.NoError(t, err)
assert.Equal(t, swKeySet, resultKeySet)
})

t.Run("GetKey_WithErrNotFoundFromSoftwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil, errors.WithStack(x.ErrNotFound))
softwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil, errors.WithStack(x.ErrNotFound))
resultKeySet, err := keyManager.GetKey(nil, "set1", "kid1")
assert.Error(t, err, "Not Found")
assert.Nil(t, resultKeySet)
})

t.Run("GetKeySet_WithResultFromHardwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(hwKeySet, nil)
resultKeySet, err := keyManager.GetKeySet(nil, "set1")
assert.NoError(t, err)
assert.Equal(t, hwKeySet, resultKeySet)
})

t.Run("GetKeySet_WithErrorFromHardwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil, errors.New("test"))
resultKeySet, err := keyManager.GetKeySet(nil, "set1")
assert.Error(t, err, "test")
assert.Nil(t, resultKeySet)
})

t.Run("GetKeySet_WithErrNotFoundFromHardwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil, errors.WithStack(x.ErrNotFound))
softwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(swKeySet, nil)
resultKeySet, err := keyManager.GetKeySet(nil, "set1")
assert.NoError(t, err)
assert.Equal(t, swKeySet, resultKeySet)
})

t.Run("GetKeySet_WithErrNotFoundFromSoftwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil, errors.WithStack(x.ErrNotFound))
softwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil, errors.WithStack(x.ErrNotFound))
resultKeySet, err := keyManager.GetKeySet(nil, "set1")
assert.Error(t, err, "Not Found")
assert.Nil(t, resultKeySet)
})

t.Run("DeleteKey_FromHardwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil)
err := keyManager.DeleteKey(nil, "set1", "kid1")
assert.NoError(t, err)
})

t.Run("DeleteKey_WithErrorFromHardwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(errors.New("test"))
err := keyManager.DeleteKey(nil, "set1", "kid1")
assert.Error(t, err, "test")
})

t.Run("DeleteKey_WithErrNotFoundFromHardwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(errors.WithStack(x.ErrNotFound))
softwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil)
err := keyManager.DeleteKey(nil, "set1", "kid1")
assert.NoError(t, err)
})

t.Run("DeleteKey_WithErrNotFoundFromSoftwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(errors.WithStack(x.ErrNotFound))
softwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(errors.WithStack(x.ErrNotFound))
err := keyManager.DeleteKey(nil, "set1", "kid1")
assert.Error(t, err, "Not Found")
})

t.Run("DeleteKeySet_FromHardwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil)
err := keyManager.DeleteKeySet(nil, "set1")
assert.NoError(t, err)
})

t.Run("DeleteKeySet_WithErrorFromHardwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(errors.New("test"))
err := keyManager.DeleteKeySet(nil, "set1")
assert.Error(t, err, "test")
})

t.Run("DeleteKeySet_WithErrNotFoundFromHardwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(errors.WithStack(x.ErrNotFound))
softwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil)
err := keyManager.DeleteKeySet(nil, "set1")
assert.NoError(t, err)
})

t.Run("DeleteKeySet_WithErrNotFoundFromSoftwareKeyManager", func(t *testing.T) {
hardwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(errors.WithStack(x.ErrNotFound))
softwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(errors.WithStack(x.ErrNotFound))
err := keyManager.DeleteKeySet(nil, "set1")
assert.Error(t, err, "Not Found")
})
}

0 comments on commit ff11892

Please sign in to comment.