Skip to content
This repository has been archived by the owner on Aug 25, 2023. It is now read-only.

Commit

Permalink
fix: send digest msg instead of raw to aws kms
Browse files Browse the repository at this point in the history
Signed-off-by: Firas Qutishat <[email protected]>
  • Loading branch information
fqutishat committed Oct 23, 2022
1 parent 412f152 commit 7e11a5d
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 110 deletions.
118 changes: 79 additions & 39 deletions pkg/aws/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,22 @@ SPDX-License-Identifier: Apache-2.0
package aws

import (
"crypto/elliptic"
"crypto/sha512"
"encoding/asn1"
"fmt"
"hash"
"math/big"
"regexp"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/btcsuite/btcd/btcec"
arieskms "github.com/hyperledger/aries-framework-go/pkg/kms"
"github.com/minio/sha256-simd"
)

type awsClient interface {
Expand All @@ -35,18 +42,34 @@ type metricsProvider interface {
VerifyTime(value time.Duration)
}

type ecdsaSignature struct {
R, S *big.Int
}

// Service aws kms.
type Service struct {
client awsClient
metrics metricsProvider
healthCheckKeyID string
}

const (
signingAlgorithmEcdsaSha256 = "ECDSA_SHA_256"
signingAlgorithmEcdsaSha384 = "ECDSA_SHA_384"
signingAlgorithmEcdsaSha512 = "ECDSA_SHA_512"
bitSize = 8
)

// nolint: gochecknoglobals
var kmsKeyTypes = map[string]arieskms.KeyType{
"ECDSA_SHA_256": arieskms.ECDSAP256DER,
"ECDSA_SHA_384": arieskms.ECDSAP384DER,
"ECDSA_SHA_521": arieskms.ECDSAP521DER,
signingAlgorithmEcdsaSha256: arieskms.ECDSAP256DER,
signingAlgorithmEcdsaSha384: arieskms.ECDSAP384DER,
signingAlgorithmEcdsaSha512: arieskms.ECDSAP521DER,
}

// nolint: gochecknoglobals
var keySpecToCurve = map[string]elliptic.Curve{
kms.KeySpecEccSecgP256k1: btcec.S256(),
}

// New return aws service.
Expand All @@ -55,7 +78,7 @@ func New(awsSession *session.Session, metrics metricsProvider, healthCheckKeyID
}

// Sign data.
func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) {
func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) { //nolint: funlen
startTime := time.Now()

defer func() {
Expand All @@ -78,10 +101,15 @@ func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) {
return nil, err
}

digest, err := hashMessage(msg, *describeKey.KeyMetadata.SigningAlgorithms[0])
if err != nil {
return nil, err
}

input := &kms.SignInput{
KeyId: aws.String(keyID),
Message: msg,
MessageType: aws.String("RAW"),
Message: digest,
MessageType: aws.String("DIGEST"),
SigningAlgorithm: describeKey.KeyMetadata.SigningAlgorithms[0],
}

Expand All @@ -90,6 +118,31 @@ func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) {
return nil, err
}

if *describeKey.KeyMetadata.KeySpec == kms.KeySpecEccSecgP256k1 {
signature := ecdsaSignature{}

_, err = asn1.Unmarshal(result.Signature, &signature)
if err != nil {
return nil, err
}

curveBits := keySpecToCurve[*describeKey.KeyMetadata.KeySpec].Params().BitSize

keyBytes := curveBits / bitSize
if curveBits%bitSize > 0 {
keyBytes++
}

copyPadded := func(source []byte, size int) []byte {
dest := make([]byte, size)
copy(dest[size-len(source):], source)

return dest
}

return append(copyPadded(signature.R.Bytes(), keyBytes), copyPadded(signature.S.Bytes(), keyBytes)...), nil
}

return result.Signature, nil
}

Expand Down Expand Up @@ -146,39 +199,7 @@ func (s *Service) ExportPubKeyBytes(keyURI string) ([]byte, arieskms.KeyType, er

// Verify signature.
func (s *Service) Verify(signature, msg []byte, kh interface{}) error {
startTime := time.Now()

defer func() {
if s.metrics != nil {
s.metrics.VerifyTime(time.Since(startTime))
}
}()

if s.metrics != nil {
s.metrics.VerifyCount()
}

keyID, err := getKeyID(kh.(string))
if err != nil {
return err
}

describeKey, err := s.client.DescribeKey(&kms.DescribeKeyInput{KeyId: &keyID})
if err != nil {
return err
}

input := &kms.VerifyInput{
KeyId: aws.String(keyID),
Message: msg,
MessageType: aws.String("RAW"),
Signature: signature,
SigningAlgorithm: describeKey.KeyMetadata.SigningAlgorithms[0],
}

_, err = s.client.Verify(input)

return err
return fmt.Errorf("not implemented")
}

// Create key.
Expand Down Expand Up @@ -257,3 +278,22 @@ func getKeyID(keyURI string) (string, error) {

return r[4], nil
}

func hashMessage(message []byte, algorithm string) ([]byte, error) {
var digest hash.Hash

switch algorithm {
case signingAlgorithmEcdsaSha256:
digest = sha256.New()
case signingAlgorithmEcdsaSha384:
digest = sha512.New384()
case signingAlgorithmEcdsaSha512:
digest = sha512.New()
default:
return []byte{}, fmt.Errorf("unknown signing algorithm")
}

digest.Write(message)

return digest.Sum(nil), nil
}
72 changes: 1 addition & 71 deletions pkg/aws/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func TestSign(t *testing.T) {
return &kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
SigningAlgorithms: []*string{aws.String("ECDSA_SHA_256")},
KeySpec: aws.String(kms.KeySpecEccNistP256),
},
}, nil
}}
Expand Down Expand Up @@ -314,77 +315,6 @@ func TestPubKeyBytes(t *testing.T) {
})
}

func TestVerify(t *testing.T) {
t.Run("success", func(t *testing.T) {
endpoint := localhost
awsSession, err := session.NewSession(&aws.Config{
Endpoint: &endpoint,
Region: aws.String("ca"),
CredentialsChainVerboseErrors: aws.Bool(true),
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{verifyFunc: func(input *kms.VerifyInput) (*kms.VerifyOutput, error) {
return &kms.VerifyOutput{}, nil
}, describeKeyFunc: func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) {
return &kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
SigningAlgorithms: []*string{aws.String("ECDSA_SHA_256")},
},
}, nil
}}

err = svc.Verify([]byte("sign"), []byte("data"),
"aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147")
require.NoError(t, err)
})

t.Run("failed to verify", func(t *testing.T) {
endpoint := localhost
awsSession, err := session.NewSession(&aws.Config{
Endpoint: &endpoint,
Region: aws.String("ca"),
CredentialsChainVerboseErrors: aws.Bool(true),
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{verifyFunc: func(input *kms.VerifyInput) (*kms.VerifyOutput, error) {
return nil, fmt.Errorf("failed to verify")
}, describeKeyFunc: func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) {
return &kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
SigningAlgorithms: []*string{aws.String("ECDSA_SHA_256")},
},
}, nil
}}

err = svc.Verify([]byte("data"), []byte("msg"),
"aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147")
require.Error(t, err)
require.Contains(t, err.Error(), "failed to verify")
})

t.Run("failed to parse key id", func(t *testing.T) {
endpoint := localhost
awsSession, err := session.NewSession(&aws.Config{
Endpoint: &endpoint,
Region: aws.String("ca"),
CredentialsChainVerboseErrors: aws.Bool(true),
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{}, "")

err = svc.Verify([]byte("sign"), []byte("msg"), "aws-kms://arn:aws:kms:key1")
require.Error(t, err)
require.Contains(t, err.Error(), "extracting key id from URI failed")
})
}

type mockAWSClient struct {
signFunc func(input *kms.SignInput) (*kms.SignOutput, error)
getPublicKeyFunc func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error)
Expand Down

0 comments on commit 7e11a5d

Please sign in to comment.