diff --git a/driver/registry_sql.go b/driver/registry_sql.go index 4d4de5c86fe..ec6c6c354b2 100644 --- a/driver/registry_sql.go +++ b/driver/registry_sql.go @@ -85,7 +85,8 @@ func (m *RegistrySQL) Init(ctx context.Context) error { } if m.C.HsmEnabled() { - m.defaultKeyManager = hsm.NewKeyManager(m.HsmContext()) + hardwareKeyManager := hsm.NewKeyManager(m.HsmContext()) + m.defaultKeyManager = jwk.NewManagerStrategy(hardwareKeyManager, m.persister) } else { m.defaultKeyManager = m.persister } diff --git a/hsm/manager_hsm.go b/hsm/manager_hsm.go index db85f10d6a0..928a61ce960 100644 --- a/hsm/manager_hsm.go +++ b/hsm/manager_hsm.go @@ -29,12 +29,11 @@ import ( ) type KeyManager struct { + jwk.Manager sync.RWMutex Context } -var _ jwk.Manager = &KeyManager{} - var ErrPreGeneratedKeys = &fosite.RFC6749Error{ CodeField: http.StatusBadRequest, ErrorField: http.StatusText(http.StatusBadRequest), diff --git a/hsm/manager_hsm_test.go b/hsm/manager_hsm_test.go index d5c42acfe00..a60b80d54d6 100644 --- a/hsm/manager_hsm_test.go +++ b/hsm/manager_hsm_test.go @@ -49,7 +49,7 @@ func TestDefaultKeyManager_HsmEnabled(t *testing.T) { reg.WithHsmContext(mockHsmContext) err := reg.Init(context.Background()) assert.NoError(t, err) - assert.IsType(t, &hsm.KeyManager{}, reg.KeyManager()) + assert.IsType(t, &jwk.ManagerStrategy{}, reg.KeyManager()) assert.IsType(t, &sql.Persister{}, reg.SoftwareKeyManager()) } diff --git a/hsm/manager_nohsm.go b/hsm/manager_nohsm.go index f67af1537b0..2e20d9a7891 100644 --- a/hsm/manager_nohsm.go +++ b/hsm/manager_nohsm.go @@ -21,12 +21,11 @@ type Context interface { } type KeyManager struct { + jwk.Manager sync.RWMutex Context } -var _ jwk.Manager = &KeyManager{} - var ErrOpSysNotSupported = errors.New("Hardware Security Module is not supported on this platform.") func NewContext(c *config.Provider, l *logrusx.Logger) Context { diff --git a/jwk/manager_strategy.go b/jwk/manager_strategy.go new file mode 100644 index 00000000000..2e7160181c7 --- /dev/null +++ b/jwk/manager_strategy.go @@ -0,0 +1,86 @@ +package jwk + +import ( + "context" + + "github.com/pkg/errors" + "gopkg.in/square/go-jose.v2" + + "github.com/ory/hydra/x" +) + +type ManagerStrategy struct { + hardwareKeyManager Manager + softwareKeyManager Manager +} + +func NewManagerStrategy(hardwareKeyManager Manager, softwareKeyManager Manager) *ManagerStrategy { + return &ManagerStrategy{ + hardwareKeyManager: hardwareKeyManager, + softwareKeyManager: softwareKeyManager, + } +} + +func (m ManagerStrategy) GenerateAndPersistKeySet(ctx context.Context, set, kid, alg, use string) (*jose.JSONWebKeySet, error) { + return m.hardwareKeyManager.GenerateAndPersistKeySet(ctx, set, kid, alg, use) +} + +func (m ManagerStrategy) AddKey(ctx context.Context, set string, key *jose.JSONWebKey) error { + return m.softwareKeyManager.AddKey(ctx, set, key) +} + +func (m ManagerStrategy) AddKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error { + return m.softwareKeyManager.AddKeySet(ctx, set, keys) +} + +func (m ManagerStrategy) UpdateKey(ctx context.Context, set string, key *jose.JSONWebKey) error { + return m.softwareKeyManager.UpdateKey(ctx, set, key) +} + +func (m ManagerStrategy) UpdateKeySet(ctx context.Context, set string, keys *jose.JSONWebKeySet) error { + return m.softwareKeyManager.UpdateKeySet(ctx, set, keys) +} + +func (m ManagerStrategy) GetKey(ctx context.Context, set, kid string) (*jose.JSONWebKeySet, error) { + keySet, err := m.hardwareKeyManager.GetKey(ctx, set, kid) + if err != nil && !errors.Is(err, x.ErrNotFound) { + return nil, err + } else if keySet != nil { + return keySet, nil + } else { + return m.softwareKeyManager.GetKey(ctx, set, kid) + } +} + +func (m ManagerStrategy) GetKeySet(ctx context.Context, set string) (*jose.JSONWebKeySet, error) { + keySet, err := m.hardwareKeyManager.GetKeySet(ctx, set) + if err != nil && !errors.Is(err, x.ErrNotFound) { + return nil, err + } else if keySet != nil { + return keySet, nil + } else { + return m.softwareKeyManager.GetKeySet(ctx, set) + } +} + +func (m ManagerStrategy) DeleteKey(ctx context.Context, set, kid string) error { + err := m.hardwareKeyManager.DeleteKey(ctx, set, kid) + if err != nil && !errors.Is(err, x.ErrNotFound) { + return err + } else if errors.Is(err, x.ErrNotFound) { + return m.softwareKeyManager.DeleteKey(ctx, set, kid) + } else { + return nil + } +} + +func (m ManagerStrategy) DeleteKeySet(ctx context.Context, set string) error { + err := m.hardwareKeyManager.DeleteKeySet(ctx, set) + if err != nil && !errors.Is(err, x.ErrNotFound) { + return err + } else if errors.Is(err, x.ErrNotFound) { + return m.softwareKeyManager.DeleteKeySet(ctx, set) + } else { + return nil + } +} diff --git a/jwk/manager_strategy_test.go b/jwk/manager_strategy_test.go new file mode 100644 index 00000000000..ce255e9fab3 --- /dev/null +++ b/jwk/manager_strategy_test.go @@ -0,0 +1,205 @@ +package jwk_test + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "gopkg.in/square/go-jose.v2" + + "github.com/ory/hydra/jwk" + "github.com/ory/hydra/x" +) + +func TestKeyManagerStrategy(t *testing.T) { + ctrl := gomock.NewController(t) + softwareKeyManager := NewMockManager(ctrl) + hardwareKeyManager := NewMockManager(ctrl) + keyManager := jwk.NewManagerStrategy(hardwareKeyManager, softwareKeyManager) + defer ctrl.Finish() + hwKeySet := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{{ + KeyID: "hwKeyID", + }}, + } + swKeySet := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{{ + KeyID: "swKeyID", + }}, + } + + t.Run("GenerateAndPersistKeySet_WithResult", func(t *testing.T) { + hardwareKeyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1"), gomock.Any(), gomock.Any()).Return(hwKeySet, nil) + resultKeySet, err := keyManager.GenerateAndPersistKeySet(nil, "set1", "kid1", "RS256", "sig") + assert.NoError(t, err) + assert.Equal(t, hwKeySet, resultKeySet) + }) + + t.Run("GenerateAndPersistKeySet_WithError", func(t *testing.T) { + hardwareKeyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1"), gomock.Any(), gomock.Any()).Return(nil, errors.New("test")) + resultKeySet, err := keyManager.GenerateAndPersistKeySet(nil, "set1", "kid1", "RS256", "sig") + assert.Error(t, err, "test") + assert.Nil(t, resultKeySet) + }) + + t.Run("AddKey", func(t *testing.T) { + softwareKeyManager.EXPECT().AddKey(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(nil) + err := keyManager.AddKey(nil, "set1", nil) + assert.NoError(t, err) + }) + + t.Run("AddKey_WithError", func(t *testing.T) { + softwareKeyManager.EXPECT().AddKey(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(errors.New("test")) + err := keyManager.AddKey(nil, "set1", nil) + assert.Error(t, err, "test") + }) + + t.Run("AddKeySet", func(t *testing.T) { + softwareKeyManager.EXPECT().AddKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(nil) + err := keyManager.AddKeySet(nil, "set1", nil) + assert.NoError(t, err) + }) + + t.Run("AddKeySet_WithError", func(t *testing.T) { + softwareKeyManager.EXPECT().AddKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(errors.New("test")) + err := keyManager.AddKeySet(nil, "set1", nil) + assert.Error(t, err, "test") + }) + + t.Run("UpdateKey", func(t *testing.T) { + softwareKeyManager.EXPECT().UpdateKey(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(nil) + err := keyManager.UpdateKey(nil, "set1", nil) + assert.NoError(t, err) + }) + + t.Run("UpdateKey_WithError", func(t *testing.T) { + softwareKeyManager.EXPECT().UpdateKey(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(errors.New("test")) + err := keyManager.UpdateKey(nil, "set1", nil) + assert.Error(t, err, "test") + }) + + t.Run("UpdateKeySet", func(t *testing.T) { + softwareKeyManager.EXPECT().UpdateKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(nil) + err := keyManager.UpdateKeySet(nil, "set1", nil) + assert.NoError(t, err) + }) + + t.Run("UpdateKeySet_WithError", func(t *testing.T) { + softwareKeyManager.EXPECT().UpdateKeySet(gomock.Any(), gomock.Eq("set1"), gomock.Any()).Return(errors.New("test")) + err := keyManager.UpdateKeySet(nil, "set1", nil) + assert.Error(t, err, "test") + }) + + t.Run("GetKey_WithResultFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(hwKeySet, nil) + resultKeySet, err := keyManager.GetKey(nil, "set1", "kid1") + assert.NoError(t, err) + assert.Equal(t, hwKeySet, resultKeySet) + }) + + t.Run("GetKey_WithErrorFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil, errors.New("test")) + resultKeySet, err := keyManager.GetKey(nil, "set1", "kid1") + assert.Error(t, err, "test") + assert.Nil(t, resultKeySet) + }) + + t.Run("GetKey_WithErrNotFoundFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil, errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(swKeySet, nil) + resultKeySet, err := keyManager.GetKey(nil, "set1", "kid1") + assert.NoError(t, err) + assert.Equal(t, swKeySet, resultKeySet) + }) + + t.Run("GetKey_WithErrNotFoundFromSoftwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil, errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().GetKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil, errors.WithStack(x.ErrNotFound)) + resultKeySet, err := keyManager.GetKey(nil, "set1", "kid1") + assert.Error(t, err, "Not Found") + assert.Nil(t, resultKeySet) + }) + + t.Run("GetKeySet_WithResultFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(hwKeySet, nil) + resultKeySet, err := keyManager.GetKeySet(nil, "set1") + assert.NoError(t, err) + assert.Equal(t, hwKeySet, resultKeySet) + }) + + t.Run("GetKeySet_WithErrorFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil, errors.New("test")) + resultKeySet, err := keyManager.GetKeySet(nil, "set1") + assert.Error(t, err, "test") + assert.Nil(t, resultKeySet) + }) + + t.Run("GetKeySet_WithErrNotFoundFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil, errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(swKeySet, nil) + resultKeySet, err := keyManager.GetKeySet(nil, "set1") + assert.NoError(t, err) + assert.Equal(t, swKeySet, resultKeySet) + }) + + t.Run("GetKeySet_WithErrNotFoundFromSoftwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil, errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil, errors.WithStack(x.ErrNotFound)) + resultKeySet, err := keyManager.GetKeySet(nil, "set1") + assert.Error(t, err, "Not Found") + assert.Nil(t, resultKeySet) + }) + + t.Run("DeleteKey_FromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil) + err := keyManager.DeleteKey(nil, "set1", "kid1") + assert.NoError(t, err) + }) + + t.Run("DeleteKey_WithErrorFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(errors.New("test")) + err := keyManager.DeleteKey(nil, "set1", "kid1") + assert.Error(t, err, "test") + }) + + t.Run("DeleteKey_WithErrNotFoundFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(nil) + err := keyManager.DeleteKey(nil, "set1", "kid1") + assert.NoError(t, err) + }) + + t.Run("DeleteKey_WithErrNotFoundFromSoftwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().DeleteKey(gomock.Any(), gomock.Eq("set1"), gomock.Eq("kid1")).Return(errors.WithStack(x.ErrNotFound)) + err := keyManager.DeleteKey(nil, "set1", "kid1") + assert.Error(t, err, "Not Found") + }) + + t.Run("DeleteKeySet_FromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil) + err := keyManager.DeleteKeySet(nil, "set1") + assert.NoError(t, err) + }) + + t.Run("DeleteKeySet_WithErrorFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(errors.New("test")) + err := keyManager.DeleteKeySet(nil, "set1") + assert.Error(t, err, "test") + }) + + t.Run("DeleteKeySet_WithErrNotFoundFromHardwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(nil) + err := keyManager.DeleteKeySet(nil, "set1") + assert.NoError(t, err) + }) + + t.Run("DeleteKeySet_WithErrNotFoundFromSoftwareKeyManager", func(t *testing.T) { + hardwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(errors.WithStack(x.ErrNotFound)) + softwareKeyManager.EXPECT().DeleteKeySet(gomock.Any(), gomock.Eq("set1")).Return(errors.WithStack(x.ErrNotFound)) + err := keyManager.DeleteKeySet(nil, "set1") + assert.Error(t, err, "Not Found") + }) +}