Skip to content
This repository has been archived by the owner on Aug 25, 2023. It is now read-only.

Commit

Permalink
Merge pull request #338 from fqutishat/update
Browse files Browse the repository at this point in the history
fix: aws kms create method
  • Loading branch information
fqutishat authored Sep 26, 2022
2 parents 25df4d4 + ba4fcc1 commit 06e5709
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 22 deletions.
36 changes: 27 additions & 9 deletions pkg/aws/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
arieskms "github.com/hyperledger/aries-framework-go/pkg/kms"
Expand All @@ -24,7 +23,7 @@ type awsClient interface {
GetPublicKey(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error)
Verify(input *kms.VerifyInput) (*kms.VerifyOutput, error)
DescribeKey(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error)
CreateKeyRequest(input *kms.CreateKeyInput) (req *request.Request, output *kms.CreateKeyOutput)
CreateKey(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error)
}

type metricsProvider interface {
Expand Down Expand Up @@ -60,10 +59,14 @@ func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) {
startTime := time.Now()

defer func() {
s.metrics.SignTime(time.Since(startTime))
if s.metrics != nil {
s.metrics.SignTime(time.Since(startTime))
}
}()

s.metrics.SignCount()
if s.metrics != nil {
s.metrics.SignCount()
}

keyID, err := getKeyID(kh.(string))
if err != nil {
Expand Down Expand Up @@ -110,10 +113,14 @@ func (s *Service) ExportPubKeyBytes(keyURI string) ([]byte, arieskms.KeyType, er
startTime := time.Now()

defer func() {
s.metrics.ExportPublicKeyTime(time.Since(startTime))
if s.metrics != nil {
s.metrics.ExportPublicKeyTime(time.Since(startTime))
}
}()

s.metrics.ExportPublicKeyCount()
if s.metrics != nil {
s.metrics.ExportPublicKeyCount()
}

keyID, err := getKeyID(keyURI)
if err != nil {
Expand All @@ -137,10 +144,14 @@ func (s *Service) Verify(signature, msg []byte, kh interface{}) error {
startTime := time.Now()

defer func() {
s.metrics.VerifyTime(time.Since(startTime))
if s.metrics != nil {
s.metrics.VerifyTime(time.Since(startTime))
}
}()

s.metrics.VerifyCount()
if s.metrics != nil {
s.metrics.VerifyCount()
}

keyID, err := getKeyID(kh.(string))
if err != nil {
Expand Down Expand Up @@ -177,7 +188,10 @@ func (s *Service) Create(kt arieskms.KeyType) (string, interface{}, error) {
return "", nil, fmt.Errorf("key not supported %s", kt)
}

_, result := s.client.CreateKeyRequest(&kms.CreateKeyInput{KeySpec: &keySpec, KeyUsage: &keyUsage})
result, err := s.client.CreateKey(&kms.CreateKeyInput{KeySpec: &keySpec, KeyUsage: &keyUsage})
if err != nil {
return "", nil, err
}

return *result.KeyMetadata.KeyId, *result.KeyMetadata.KeyId, nil
}
Expand Down Expand Up @@ -209,6 +223,10 @@ func (s *Service) SignMulti(messages [][]byte, kh interface{}) ([]byte, error) {
}

func getKeyID(keyURI string) (string, error) {
if !strings.Contains(keyURI, "aws-kms") {
return keyURI, nil
}

// keyURI must have the following format: 'aws-kms://arn:<partition>:kms:<region>:[:path]'.
// See http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html.
re1 := regexp.MustCompile(`aws-kms://arn:(aws[a-zA-Z0-9-_]*):kms:([a-z0-9-]+):([a-z0-9-]+):key/(.+)`)
Expand Down
23 changes: 10 additions & 13 deletions pkg/aws/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
arieskms "github.com/hyperledger/aries-framework-go/pkg/kms"
Expand Down Expand Up @@ -79,7 +78,7 @@ func TestSign(t *testing.T) {

svc := New(awsSession, &mockMetrics{}, "")

_, err = svc.Sign([]byte("msg"), "key1")
_, err = svc.Sign([]byte("msg"), "aws-kms://arn:aws:kms:key1")
require.Error(t, err)
require.Contains(t, err.Error(), "extracting key id from URI failed")
})
Expand Down Expand Up @@ -142,9 +141,8 @@ func TestCreate(t *testing.T) {

keyID := "key1"

svc.client = &mockAWSClient{createKeyFunc: func(input *kms.CreateKeyInput) (req *request.Request,
output *kms.CreateKeyOutput) {
return nil, &kms.CreateKeyOutput{KeyMetadata: &kms.KeyMetadata{KeyId: &keyID}}
svc.client = &mockAWSClient{createKeyFunc: func(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) {
return &kms.CreateKeyOutput{KeyMetadata: &kms.KeyMetadata{KeyId: &keyID}}, nil
}}

result, _, err := svc.Create(arieskms.ECDSAP256DER)
Expand Down Expand Up @@ -210,8 +208,8 @@ func TestCreateAndPubKeyBytes(t *testing.T) {
SigningAlgorithms: []*string{&signingAlgo},
}, nil
},
createKeyFunc: func(input *kms.CreateKeyInput) (req *request.Request, output *kms.CreateKeyOutput) {
return nil, &kms.CreateKeyOutput{KeyMetadata: &kms.KeyMetadata{KeyId: &keyID}}
createKeyFunc: func(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) {
return &kms.CreateKeyOutput{KeyMetadata: &kms.KeyMetadata{KeyId: &keyID}}, nil
},
}

Expand Down Expand Up @@ -298,7 +296,7 @@ func TestPubKeyBytes(t *testing.T) {

svc := New(awsSession, &mockMetrics{}, "")

_, _, err = svc.ExportPubKeyBytes("key1")
_, _, err = svc.ExportPubKeyBytes("aws-kms://arn:aws:kms:key1")
require.Error(t, err)
require.Contains(t, err.Error(), "extracting key id from URI failed")
})
Expand Down Expand Up @@ -357,7 +355,7 @@ func TestVerify(t *testing.T) {

svc := New(awsSession, &mockMetrics{}, "")

err = svc.Verify([]byte("sign"), []byte("msg"), "key1")
err = svc.Verify([]byte("sign"), []byte("msg"), "aws-kms://arn:aws:kms:key1")
require.Error(t, err)
require.Contains(t, err.Error(), "extracting key id from URI failed")
})
Expand All @@ -368,7 +366,7 @@ type mockAWSClient struct {
getPublicKeyFunc func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error)
verifyFunc func(input *kms.VerifyInput) (*kms.VerifyOutput, error)
describeKeyFunc func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error)
createKeyFunc func(input *kms.CreateKeyInput) (req *request.Request, output *kms.CreateKeyOutput)
createKeyFunc func(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error)
}

func (m *mockAWSClient) Sign(input *kms.SignInput) (*kms.SignOutput, error) {
Expand Down Expand Up @@ -403,13 +401,12 @@ func (m *mockAWSClient) DescribeKey(input *kms.DescribeKeyInput) (*kms.DescribeK
return nil, nil //nolint:nilnil
}

func (m *mockAWSClient) CreateKeyRequest(input *kms.CreateKeyInput) (req *request.Request,
output *kms.CreateKeyOutput) {
func (m *mockAWSClient) CreateKey(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) {
if m.createKeyFunc != nil {
return m.createKeyFunc(input)
}

return nil, nil
return nil, nil //nolint:nilnil
}

type mockMetrics struct{}
Expand Down

0 comments on commit 06e5709

Please sign in to comment.