Skip to content
This repository has been archived by the owner on Aug 25, 2023. It is now read-only.

Commit

Permalink
chore: add health check key
Browse files Browse the repository at this point in the history
Signed-off-by: Firas Qutishat <[email protected]>
  • Loading branch information
fqutishat committed May 26, 2022
1 parent 52184bd commit 4bc3ce7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 55 deletions.
19 changes: 7 additions & 12 deletions pkg/aws/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type awsClient interface {
Sign(input *kms.SignInput) (*kms.SignOutput, error)
GetPublicKey(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error)
Verify(input *kms.VerifyInput) (*kms.VerifyOutput, error)
ListKeys(input *kms.ListKeysInput) (*kms.ListKeysOutput, error)
DescribeKey(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error)
}

type metricsProvider interface {
Expand All @@ -36,8 +36,9 @@ type metricsProvider interface {

// Service aws kms.
type Service struct {
client awsClient
metrics metricsProvider
client awsClient
metrics metricsProvider
healthCheckKeyID string
}

// nolint: gochecknoglobals
Expand All @@ -48,8 +49,8 @@ var kmsKeyTypes = map[string]arieskms.KeyType{
}

// New return aws service.
func New(awsSession *session.Session, metrics metricsProvider) *Service {
return &Service{client: kms.New(awsSession), metrics: metrics}
func New(awsSession *session.Session, metrics metricsProvider, healthCheckKeyID string) *Service {
return &Service{client: kms.New(awsSession), metrics: metrics, healthCheckKeyID: healthCheckKeyID}
}

// Sign data.
Expand Down Expand Up @@ -89,17 +90,11 @@ func (s *Service) Get(keyID string) (interface{}, error) {

// HealthCheck check kms.
func (s *Service) HealthCheck() error {
var limit int64 = 1

result, err := s.client.ListKeys(&kms.ListKeysInput{Limit: &limit})
_, err := s.client.DescribeKey(&kms.DescribeKeyInput{KeyId: &s.healthCheckKeyID})
if err != nil {
return err
}

if len(result.Keys) == 0 {
return fmt.Errorf("list of keys are empty")
}

return nil
}

Expand Down
62 changes: 19 additions & 43 deletions pkg/aws/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestSign(t *testing.T) {
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})
svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{signFunc: func(input *kms.SignInput) (*kms.SignOutput, error) {
return &kms.SignOutput{
Expand All @@ -54,7 +54,7 @@ func TestSign(t *testing.T) {
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})
svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{signFunc: func(input *kms.SignInput) (*kms.SignOutput, error) {
return nil, fmt.Errorf("failed to sign")
Expand All @@ -75,7 +75,7 @@ func TestSign(t *testing.T) {
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})
svc := New(awsSession, &mockMetrics{}, "")

_, err = svc.Sign([]byte("msg"), "key1")
require.Error(t, err)
Expand All @@ -93,14 +93,10 @@ func TestHealthCheck(t *testing.T) {
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})
svc := New(awsSession, &mockMetrics{}, "")

keyID := "key1"

svc.client = &mockAWSClient{listKeysFunc: func(input *kms.ListKeysInput) (*kms.ListKeysOutput, error) {
return &kms.ListKeysOutput{
Keys: []*kms.KeyListEntry{{KeyId: &keyID}},
}, nil
svc.client = &mockAWSClient{describeKeyFunc: func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) {
return &kms.DescribeKeyOutput{}, nil
}}

err = svc.HealthCheck()
Expand All @@ -116,36 +112,16 @@ func TestHealthCheck(t *testing.T) {
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})
svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{listKeysFunc: func(input *kms.ListKeysInput) (*kms.ListKeysOutput, error) {
svc.client = &mockAWSClient{describeKeyFunc: func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) {
return nil, fmt.Errorf("failed to list keys")
}}

err = svc.HealthCheck()
require.Error(t, err)
require.Contains(t, err.Error(), "failed to list keys")
})

t.Run("empty keys", func(t *testing.T) {
endpoint := localhost
awsSession, err := session.NewSession(&aws.Config{
Endpoint: &endpoint,
Region: aws.String("ca"),
CredentialsChainVerboseErrors: aws.Bool(true),
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})

svc.client = &mockAWSClient{listKeysFunc: func(input *kms.ListKeysInput) (*kms.ListKeysOutput, error) {
return &kms.ListKeysOutput{}, nil
}}

err = svc.HealthCheck()
require.Error(t, err)
require.Contains(t, err.Error(), "list of keys are empty")
})
}

func TestGet(t *testing.T) {
Expand All @@ -158,7 +134,7 @@ func TestGet(t *testing.T) {
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})
svc := New(awsSession, &mockMetrics{}, "")

keyID, err := svc.Get("key1")
require.NoError(t, err)
Expand All @@ -176,7 +152,7 @@ func TestPubKeyBytes(t *testing.T) {
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})
svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{getPublicKeyFunc: func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) {
signingAlgo := "ECDSA_SHA_256"
Expand All @@ -203,7 +179,7 @@ func TestPubKeyBytes(t *testing.T) {
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})
svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{getPublicKeyFunc: func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) {
return nil, fmt.Errorf("failed to export public key")
Expand All @@ -224,7 +200,7 @@ func TestPubKeyBytes(t *testing.T) {
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})
svc := New(awsSession, &mockMetrics{}, "")

_, _, err = svc.ExportPubKeyBytes("key1")
require.Error(t, err)
Expand All @@ -242,7 +218,7 @@ func TestVerify(t *testing.T) {
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})
svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{verifyFunc: func(input *kms.VerifyInput) (*kms.VerifyOutput, error) {
return &kms.VerifyOutput{}, nil
Expand All @@ -262,7 +238,7 @@ func TestVerify(t *testing.T) {
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})
svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{verifyFunc: func(input *kms.VerifyInput) (*kms.VerifyOutput, error) {
return nil, fmt.Errorf("failed to verify")
Expand All @@ -283,7 +259,7 @@ func TestVerify(t *testing.T) {
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{})
svc := New(awsSession, &mockMetrics{}, "")

err = svc.Verify([]byte("sign"), []byte("msg"), "key1")
require.Error(t, err)
Expand All @@ -295,7 +271,7 @@ type mockAWSClient struct {
signFunc func(input *kms.SignInput) (*kms.SignOutput, error)
getPublicKeyFunc func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error)
verifyFunc func(input *kms.VerifyInput) (*kms.VerifyOutput, error)
listKeysFunc func(input *kms.ListKeysInput) (*kms.ListKeysOutput, error)
describeKeyFunc func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error)
}

func (m *mockAWSClient) Sign(input *kms.SignInput) (*kms.SignOutput, error) {
Expand All @@ -322,9 +298,9 @@ func (m *mockAWSClient) Verify(input *kms.VerifyInput) (*kms.VerifyOutput, error
return nil, nil
}

func (m *mockAWSClient) ListKeys(input *kms.ListKeysInput) (*kms.ListKeysOutput, error) {
if m.listKeysFunc != nil {
return m.listKeysFunc(input)
func (m *mockAWSClient) DescribeKey(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) {
if m.describeKeyFunc != nil {
return m.describeKeyFunc(input)
}

return nil, nil
Expand Down

0 comments on commit 4bc3ce7

Please sign in to comment.