From ba4fcc1ea65b40a9e93c7bc8bd900e06b0aa1ca2 Mon Sep 17 00:00:00 2001 From: Firas Qutishat Date: Mon, 26 Sep 2022 10:35:19 -0400 Subject: [PATCH] fix: aws kms create method Signed-off-by: Firas Qutishat --- pkg/aws/service.go | 36 +++++++++++++++++++++++++++--------- pkg/aws/service_test.go | 23 ++++++++++------------- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/pkg/aws/service.go b/pkg/aws/service.go index e9fd5580..0b4dea0d 100644 --- a/pkg/aws/service.go +++ b/pkg/aws/service.go @@ -13,7 +13,6 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kms" arieskms "github.com/hyperledger/aries-framework-go/pkg/kms" @@ -24,7 +23,7 @@ type awsClient interface { GetPublicKey(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) Verify(input *kms.VerifyInput) (*kms.VerifyOutput, error) DescribeKey(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) - CreateKeyRequest(input *kms.CreateKeyInput) (req *request.Request, output *kms.CreateKeyOutput) + CreateKey(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) } type metricsProvider interface { @@ -60,10 +59,14 @@ func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) { startTime := time.Now() defer func() { - s.metrics.SignTime(time.Since(startTime)) + if s.metrics != nil { + s.metrics.SignTime(time.Since(startTime)) + } }() - s.metrics.SignCount() + if s.metrics != nil { + s.metrics.SignCount() + } keyID, err := getKeyID(kh.(string)) if err != nil { @@ -110,10 +113,14 @@ func (s *Service) ExportPubKeyBytes(keyURI string) ([]byte, arieskms.KeyType, er startTime := time.Now() defer func() { - s.metrics.ExportPublicKeyTime(time.Since(startTime)) + if s.metrics != nil { + s.metrics.ExportPublicKeyTime(time.Since(startTime)) + } }() - s.metrics.ExportPublicKeyCount() + if s.metrics != nil { + s.metrics.ExportPublicKeyCount() + } keyID, err := getKeyID(keyURI) if err != nil { @@ -137,10 +144,14 @@ func (s *Service) Verify(signature, msg []byte, kh interface{}) error { startTime := time.Now() defer func() { - s.metrics.VerifyTime(time.Since(startTime)) + if s.metrics != nil { + s.metrics.VerifyTime(time.Since(startTime)) + } }() - s.metrics.VerifyCount() + if s.metrics != nil { + s.metrics.VerifyCount() + } keyID, err := getKeyID(kh.(string)) if err != nil { @@ -177,7 +188,10 @@ func (s *Service) Create(kt arieskms.KeyType) (string, interface{}, error) { return "", nil, fmt.Errorf("key not supported %s", kt) } - _, result := s.client.CreateKeyRequest(&kms.CreateKeyInput{KeySpec: &keySpec, KeyUsage: &keyUsage}) + result, err := s.client.CreateKey(&kms.CreateKeyInput{KeySpec: &keySpec, KeyUsage: &keyUsage}) + if err != nil { + return "", nil, err + } return *result.KeyMetadata.KeyId, *result.KeyMetadata.KeyId, nil } @@ -209,6 +223,10 @@ func (s *Service) SignMulti(messages [][]byte, kh interface{}) ([]byte, error) { } func getKeyID(keyURI string) (string, error) { + if !strings.Contains(keyURI, "aws-kms") { + return keyURI, nil + } + // keyURI must have the following format: 'aws-kms://arn::kms::[:path]'. // See http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html. re1 := regexp.MustCompile(`aws-kms://arn:(aws[a-zA-Z0-9-_]*):kms:([a-z0-9-]+):([a-z0-9-]+):key/(.+)`) diff --git a/pkg/aws/service_test.go b/pkg/aws/service_test.go index 731ed621..e5642e18 100644 --- a/pkg/aws/service_test.go +++ b/pkg/aws/service_test.go @@ -12,7 +12,6 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kms" arieskms "github.com/hyperledger/aries-framework-go/pkg/kms" @@ -79,7 +78,7 @@ func TestSign(t *testing.T) { svc := New(awsSession, &mockMetrics{}, "") - _, err = svc.Sign([]byte("msg"), "key1") + _, err = svc.Sign([]byte("msg"), "aws-kms://arn:aws:kms:key1") require.Error(t, err) require.Contains(t, err.Error(), "extracting key id from URI failed") }) @@ -142,9 +141,8 @@ func TestCreate(t *testing.T) { keyID := "key1" - svc.client = &mockAWSClient{createKeyFunc: func(input *kms.CreateKeyInput) (req *request.Request, - output *kms.CreateKeyOutput) { - return nil, &kms.CreateKeyOutput{KeyMetadata: &kms.KeyMetadata{KeyId: &keyID}} + svc.client = &mockAWSClient{createKeyFunc: func(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) { + return &kms.CreateKeyOutput{KeyMetadata: &kms.KeyMetadata{KeyId: &keyID}}, nil }} result, _, err := svc.Create(arieskms.ECDSAP256DER) @@ -210,8 +208,8 @@ func TestCreateAndPubKeyBytes(t *testing.T) { SigningAlgorithms: []*string{&signingAlgo}, }, nil }, - createKeyFunc: func(input *kms.CreateKeyInput) (req *request.Request, output *kms.CreateKeyOutput) { - return nil, &kms.CreateKeyOutput{KeyMetadata: &kms.KeyMetadata{KeyId: &keyID}} + createKeyFunc: func(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) { + return &kms.CreateKeyOutput{KeyMetadata: &kms.KeyMetadata{KeyId: &keyID}}, nil }, } @@ -298,7 +296,7 @@ func TestPubKeyBytes(t *testing.T) { svc := New(awsSession, &mockMetrics{}, "") - _, _, err = svc.ExportPubKeyBytes("key1") + _, _, err = svc.ExportPubKeyBytes("aws-kms://arn:aws:kms:key1") require.Error(t, err) require.Contains(t, err.Error(), "extracting key id from URI failed") }) @@ -357,7 +355,7 @@ func TestVerify(t *testing.T) { svc := New(awsSession, &mockMetrics{}, "") - err = svc.Verify([]byte("sign"), []byte("msg"), "key1") + err = svc.Verify([]byte("sign"), []byte("msg"), "aws-kms://arn:aws:kms:key1") require.Error(t, err) require.Contains(t, err.Error(), "extracting key id from URI failed") }) @@ -368,7 +366,7 @@ type mockAWSClient struct { getPublicKeyFunc func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error) verifyFunc func(input *kms.VerifyInput) (*kms.VerifyOutput, error) describeKeyFunc func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) - createKeyFunc func(input *kms.CreateKeyInput) (req *request.Request, output *kms.CreateKeyOutput) + createKeyFunc func(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) } func (m *mockAWSClient) Sign(input *kms.SignInput) (*kms.SignOutput, error) { @@ -403,13 +401,12 @@ func (m *mockAWSClient) DescribeKey(input *kms.DescribeKeyInput) (*kms.DescribeK return nil, nil //nolint:nilnil } -func (m *mockAWSClient) CreateKeyRequest(input *kms.CreateKeyInput) (req *request.Request, - output *kms.CreateKeyOutput) { +func (m *mockAWSClient) CreateKey(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) { if m.createKeyFunc != nil { return m.createKeyFunc(input) } - return nil, nil + return nil, nil //nolint:nilnil } type mockMetrics struct{}