Skip to content
This repository has been archived by the owner on Aug 25, 2023. It is now read-only.

fix: send digest msg instead of raw to aws kms #341

Merged
merged 1 commit into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 82 additions & 42 deletions pkg/aws/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,22 @@ SPDX-License-Identifier: Apache-2.0
package aws

import (
"crypto/elliptic"
"crypto/sha512"
"encoding/asn1"
"fmt"
"hash"
"math/big"
"regexp"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/btcsuite/btcd/btcec"
arieskms "github.com/hyperledger/aries-framework-go/pkg/kms"
"github.com/minio/sha256-simd"
)

type awsClient interface {
Expand All @@ -35,18 +42,34 @@ type metricsProvider interface {
VerifyTime(value time.Duration)
}

type ecdsaSignature struct {
R, S *big.Int
}

// Service aws kms.
type Service struct {
client awsClient
metrics metricsProvider
healthCheckKeyID string
}

const (
signingAlgorithmEcdsaSha256 = "ECDSA_SHA_256"
signingAlgorithmEcdsaSha384 = "ECDSA_SHA_384"
signingAlgorithmEcdsaSha512 = "ECDSA_SHA_512"
bitSize = 8
)

// nolint: gochecknoglobals
var kmsKeyTypes = map[string]arieskms.KeyType{
"ECDSA_SHA_256": arieskms.ECDSAP256DER,
"ECDSA_SHA_384": arieskms.ECDSAP384DER,
"ECDSA_SHA_521": arieskms.ECDSAP521DER,
signingAlgorithmEcdsaSha256: arieskms.ECDSAP256DER,
signingAlgorithmEcdsaSha384: arieskms.ECDSAP384DER,
signingAlgorithmEcdsaSha512: arieskms.ECDSAP521DER,
}

// nolint: gochecknoglobals
var keySpecToCurve = map[string]elliptic.Curve{
kms.KeySpecEccSecgP256k1: btcec.S256(),
}

// New return aws service.
Expand All @@ -55,7 +78,7 @@ func New(awsSession *session.Session, metrics metricsProvider, healthCheckKeyID
}

// Sign data.
func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) {
func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) { //nolint: funlen
startTime := time.Now()

defer func() {
Expand All @@ -78,10 +101,15 @@ func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) {
return nil, err
}

digest, err := hashMessage(msg, *describeKey.KeyMetadata.SigningAlgorithms[0])
if err != nil {
return nil, err
}

input := &kms.SignInput{
KeyId: aws.String(keyID),
Message: msg,
MessageType: aws.String("RAW"),
Message: digest,
MessageType: aws.String("DIGEST"),
SigningAlgorithm: describeKey.KeyMetadata.SigningAlgorithms[0],
}

Expand All @@ -90,6 +118,31 @@ func (s *Service) Sign(msg []byte, kh interface{}) ([]byte, error) {
return nil, err
}

if *describeKey.KeyMetadata.KeySpec == kms.KeySpecEccSecgP256k1 {
signature := ecdsaSignature{}

_, err = asn1.Unmarshal(result.Signature, &signature)
if err != nil {
return nil, err
}

curveBits := keySpecToCurve[*describeKey.KeyMetadata.KeySpec].Params().BitSize

keyBytes := curveBits / bitSize
if curveBits%bitSize > 0 {
keyBytes++
}

copyPadded := func(source []byte, size int) []byte {
dest := make([]byte, size)
copy(dest[size-len(source):], source)

return dest
}

return append(copyPadded(signature.R.Bytes(), keyBytes), copyPadded(signature.S.Bytes(), keyBytes)...), nil
}

return result.Signature, nil
}

Expand Down Expand Up @@ -146,39 +199,7 @@ func (s *Service) ExportPubKeyBytes(keyURI string) ([]byte, arieskms.KeyType, er

// Verify signature.
func (s *Service) Verify(signature, msg []byte, kh interface{}) error {
startTime := time.Now()

defer func() {
if s.metrics != nil {
s.metrics.VerifyTime(time.Since(startTime))
}
}()

if s.metrics != nil {
s.metrics.VerifyCount()
}

keyID, err := getKeyID(kh.(string))
if err != nil {
return err
}

describeKey, err := s.client.DescribeKey(&kms.DescribeKeyInput{KeyId: &keyID})
if err != nil {
return err
}

input := &kms.VerifyInput{
KeyId: aws.String(keyID),
Message: msg,
MessageType: aws.String("RAW"),
Signature: signature,
SigningAlgorithm: describeKey.KeyMetadata.SigningAlgorithms[0],
}

_, err = s.client.Verify(input)

return err
return fmt.Errorf("not implemented")
}

// Create key.
Expand All @@ -188,11 +209,11 @@ func (s *Service) Create(kt arieskms.KeyType) (string, interface{}, error) {
keySpec := ""

switch string(kt) {
case arieskms.ECDSAP256DER, arieskms.NISTP256ECDHKW:
case arieskms.ECDSAP256DER:
keySpec = kms.KeySpecEccNistP256
case arieskms.ECDSAP384DER, arieskms.NISTP384ECDHKW:
case arieskms.ECDSAP384DER:
keySpec = kms.KeySpecEccNistP384
case arieskms.ECDSAP521DER, arieskms.NISTP521ECDHKW:
case arieskms.ECDSAP521DER:
keySpec = kms.KeySpecEccNistP521
case arieskms.ECDSASecp256k1IEEEP1363:
keySpec = kms.KeySpecEccSecgP256k1
Expand Down Expand Up @@ -257,3 +278,22 @@ func getKeyID(keyURI string) (string, error) {

return r[4], nil
}

func hashMessage(message []byte, algorithm string) ([]byte, error) {
var digest hash.Hash

switch algorithm {
case signingAlgorithmEcdsaSha256:
digest = sha256.New()
case signingAlgorithmEcdsaSha384:
digest = sha512.New384()
case signingAlgorithmEcdsaSha512:
digest = sha512.New()
default:
return []byte{}, fmt.Errorf("unknown signing algorithm")
}

digest.Write(message)

return digest.Sum(nil), nil
}
72 changes: 1 addition & 71 deletions pkg/aws/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func TestSign(t *testing.T) {
return &kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
SigningAlgorithms: []*string{aws.String("ECDSA_SHA_256")},
KeySpec: aws.String(kms.KeySpecEccNistP256),
},
}, nil
}}
Expand Down Expand Up @@ -314,77 +315,6 @@ func TestPubKeyBytes(t *testing.T) {
})
}

func TestVerify(t *testing.T) {
t.Run("success", func(t *testing.T) {
endpoint := localhost
awsSession, err := session.NewSession(&aws.Config{
Endpoint: &endpoint,
Region: aws.String("ca"),
CredentialsChainVerboseErrors: aws.Bool(true),
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{verifyFunc: func(input *kms.VerifyInput) (*kms.VerifyOutput, error) {
return &kms.VerifyOutput{}, nil
}, describeKeyFunc: func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) {
return &kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
SigningAlgorithms: []*string{aws.String("ECDSA_SHA_256")},
},
}, nil
}}

err = svc.Verify([]byte("sign"), []byte("data"),
"aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147")
require.NoError(t, err)
})

t.Run("failed to verify", func(t *testing.T) {
endpoint := localhost
awsSession, err := session.NewSession(&aws.Config{
Endpoint: &endpoint,
Region: aws.String("ca"),
CredentialsChainVerboseErrors: aws.Bool(true),
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{}, "")

svc.client = &mockAWSClient{verifyFunc: func(input *kms.VerifyInput) (*kms.VerifyOutput, error) {
return nil, fmt.Errorf("failed to verify")
}, describeKeyFunc: func(input *kms.DescribeKeyInput) (*kms.DescribeKeyOutput, error) {
return &kms.DescribeKeyOutput{
KeyMetadata: &kms.KeyMetadata{
SigningAlgorithms: []*string{aws.String("ECDSA_SHA_256")},
},
}, nil
}}

err = svc.Verify([]byte("data"), []byte("msg"),
"aws-kms://arn:aws:kms:ca-central-1:111122223333:key/800d5768-3fd7-4edd-a4b8-4c81c3e4c147")
require.Error(t, err)
require.Contains(t, err.Error(), "failed to verify")
})

t.Run("failed to parse key id", func(t *testing.T) {
endpoint := localhost
awsSession, err := session.NewSession(&aws.Config{
Endpoint: &endpoint,
Region: aws.String("ca"),
CredentialsChainVerboseErrors: aws.Bool(true),
})
require.NoError(t, err)

svc := New(awsSession, &mockMetrics{}, "")

err = svc.Verify([]byte("sign"), []byte("msg"), "aws-kms://arn:aws:kms:key1")
require.Error(t, err)
require.Contains(t, err.Error(), "extracting key id from URI failed")
})
}

type mockAWSClient struct {
signFunc func(input *kms.SignInput) (*kms.SignOutput, error)
getPublicKeyFunc func(input *kms.GetPublicKeyInput) (*kms.GetPublicKeyOutput, error)
Expand Down