diff --git a/cng/aes.go b/cng/aes.go index 43bd772..1f0f57a 100644 --- a/cng/aes.go +++ b/cng/aes.go @@ -10,7 +10,6 @@ import ( "crypto/cipher" "errors" "runtime" - "sync" "unsafe" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" @@ -19,57 +18,38 @@ import ( const aesBlockSize = 16 -var aesCache sync.Map - type aesAlgorithm struct { h bcrypt.ALG_HANDLE allowedKeySizes []int } -type aesCacheEntry struct { - id string - mode string -} - -func loadAes(id string, mode string) (h aesAlgorithm, err error) { - if v, ok := aesCache.Load(aesCacheEntry{id, mode}); ok { - return v.(aesAlgorithm), nil - } - err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, bcrypt.ALG_NONE_FLAG) - if err != nil { - return - } - defer func() { +func loadAes(mode string) (aesAlgorithm, error) { + v, err := loadOrStoreAlg(bcrypt.AES_ALGORITHM, bcrypt.ALG_NONE_FLAG, mode, func(h bcrypt.ALG_HANDLE) (interface{}, error) { + // Windows 8 added support to set the CipherMode value on a key, + // but Windows 7 requires that it be set on the algorithm before key creation. + err := setString(bcrypt.HANDLE(h), bcrypt.CHAINING_MODE, mode) if err != nil { - bcrypt.CloseAlgorithmProvider(h.h, 0) - h.h = 0 + return nil, err } - }() - // Windows 8 added support to set the CipherMode value on a key, - // but Windows 7 requires that it be set on the algorithm before key creation. - err = setString(bcrypt.HANDLE(h.h), bcrypt.CHAINING_MODE, mode) - if err != nil { - return - } - var info bcrypt.KEY_LENGTHS_STRUCT - var discard uint32 - err = bcrypt.GetProperty(bcrypt.HANDLE(h.h), utf16PtrFromString(bcrypt.KEY_LENGTHS), (*[unsafe.Sizeof(info)]byte)(unsafe.Pointer(&info))[:], &discard, 0) + var info bcrypt.KEY_LENGTHS_STRUCT + var discard uint32 + err = bcrypt.GetProperty(bcrypt.HANDLE(h), utf16PtrFromString(bcrypt.KEY_LENGTHS), (*[unsafe.Sizeof(info)]byte)(unsafe.Pointer(&info))[:], &discard, 0) + if err != nil { + return nil, err + } + if info.Increment == 0 || info.MinLength > info.MaxLength { + return nil, errors.New("invalid BCRYPT_KEY_LENGTHS_STRUCT") + } + var allowedKeySizes []int + for size := info.MinLength; size <= info.MaxLength; size += info.Increment { + allowedKeySizes = append(allowedKeySizes, int(size)) + } + return aesAlgorithm{h, allowedKeySizes}, nil + }) if err != nil { - return + return aesAlgorithm{}, nil } - if info.Increment == 0 || info.MinLength > info.MaxLength { - err = errors.New("invalid BCRYPT_KEY_LENGTHS_STRUCT") - return - } - for size := info.MinLength; size <= info.MaxLength; size += info.Increment { - h.allowedKeySizes = append(h.allowedKeySizes, int(size)) - } - if existing, loaded := aesCache.LoadOrStore(aesCacheEntry{id, mode}, h); loaded { - // We can safely use a provider that has already been cached in another concurrent goroutine. - bcrypt.CloseAlgorithmProvider(h.h, 0) - h = existing.(aesAlgorithm) - } - return + return v.(aesAlgorithm), nil } type aesCipher struct { @@ -78,7 +58,7 @@ type aesCipher struct { } func NewAESCipher(key []byte) (cipher.Block, error) { - h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_ECB) + h, err := loadAes(bcrypt.CHAIN_MODE_ECB) if err != nil { return nil, err } @@ -194,7 +174,7 @@ type aesCBC struct { } func newCBC(encrypt bool, key, iv []byte) *aesCBC { - h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_CBC) + h, err := loadAes(bcrypt.CHAIN_MODE_CBC) if err != nil { panic(err) } @@ -268,7 +248,7 @@ func (g *aesGCM) finalize() { } func newGCM(key []byte, tls bool) (*aesGCM, error) { - h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_GCM) + h, err := loadAes(bcrypt.CHAIN_MODE_GCM) if err != nil { return nil, err } @@ -328,7 +308,7 @@ func (g *aesGCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte { info := bcrypt.NewAUTHENTICATED_CIPHER_MODE_INFO(nonce, additionalData, out[len(out)-gcmTagSize:]) var encSize uint32 - err := bcrypt.Encrypt(g.kh, plaintext, info, nil, out, &encSize, 0) + err := bcrypt.Encrypt(g.kh, plaintext, unsafe.Pointer(info), nil, out, &encSize, 0) if err != nil { panic(err) } @@ -365,7 +345,7 @@ func (g *aesGCM) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, er info := bcrypt.NewAUTHENTICATED_CIPHER_MODE_INFO(nonce, additionalData, tag) var decSize uint32 - err := bcrypt.Decrypt(g.kh, ciphertext, info, nil, out, &decSize, 0) + err := bcrypt.Decrypt(g.kh, ciphertext, unsafe.Pointer(info), nil, out, &decSize, 0) if err != nil || int(decSize) != len(ciphertext) { for i := range out { out[i] = 0 diff --git a/cng/bbig/big.go b/cng/bbig/big.go new file mode 100644 index 0000000..584f206 --- /dev/null +++ b/cng/bbig/big.go @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +package bbig + +import ( + "math/big" + + "github.com/microsoft/go-crypto-winnative/cng" +) + +func Enc(b *big.Int) cng.BigInt { + if b == nil { + return nil + } + x := b.Bytes() + if len(x) == 0 { + return cng.BigInt{} + } + return x +} + +func Dec(b cng.BigInt) *big.Int { + if b == nil { + return nil + } + if len(b) == 0 { + return new(big.Int) + } + return new(big.Int).SetBytes(b) +} diff --git a/cng/big.go b/cng/big.go new file mode 100644 index 0000000..0069d31 --- /dev/null +++ b/cng/big.go @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +package cng + +// This file does not have build constraints to +// facilitate using BigInt in Go crypto. +// Go crypto references BigInt unconditionally, +// even if it is not finally used. + +// A BigInt is the big-endian bytes from a math/big BigInt. +// Windows BCrypt accepts this specific data format. +// This definition allows us to avoid importing math/big. +// Conversion between BigInt and *big.Int is in cng/bbig. +type BigInt []byte diff --git a/cng/cng.go b/cng/cng.go index 7a54ab0..a30fdc9 100644 --- a/cng/cng.go +++ b/cng/cng.go @@ -10,6 +10,7 @@ import ( "math" "reflect" "runtime" + "sync" "syscall" "unsafe" @@ -25,9 +26,36 @@ func lenU32(s []byte) int { return len(s) } -type algCacheEntry struct { - id string - flags uint32 +var algCache sync.Map + +type newAlgEntryFn func(h bcrypt.ALG_HANDLE) (interface{}, error) + +func loadOrStoreAlg(id string, flags bcrypt.AlgorithmProviderFlags, mode string, fn newAlgEntryFn) (interface{}, error) { + var entryKey = struct { + id string + flags bcrypt.AlgorithmProviderFlags + mode string + }{id, flags, mode} + + if v, ok := algCache.Load(entryKey); ok { + return v, nil + } + var h bcrypt.ALG_HANDLE + err := bcrypt.OpenAlgorithmProvider(&h, utf16PtrFromString(id), nil, flags) + if err != nil { + return nil, err + } + v, err := fn(h) + if err != nil { + bcrypt.CloseAlgorithmProvider(h, 0) + return nil, err + } + if existing, loaded := algCache.LoadOrStore(entryKey, v); loaded { + // We can safely use a provider that has already been cached in another concurrent goroutine. + bcrypt.CloseAlgorithmProvider(h, 0) + v = existing + } + return v, nil } func utf16PtrFromString(s string) *uint16 { diff --git a/cng/rsa.go b/cng/rsa.go new file mode 100644 index 0000000..49c1f8d --- /dev/null +++ b/cng/rsa.go @@ -0,0 +1,332 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build windows +// +build windows + +package cng + +import ( + "crypto" + "errors" + "hash" + "runtime" + "unsafe" + + "github.com/microsoft/go-crypto-winnative/internal/bcrypt" +) + +type rsaAlgorithm struct { + h bcrypt.ALG_HANDLE +} + +func loadRsa() (rsaAlgorithm, error) { + v, err := loadOrStoreAlg(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { + return rsaAlgorithm{h}, nil + }) + if err != nil { + return rsaAlgorithm{}, nil + } + return v.(rsaAlgorithm), nil +} + +const sizeOfRSABlobHeader = uint32(unsafe.Sizeof(bcrypt.RSAKEY_BLOB{})) + +func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) { + bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) { + return nil, nil, nil, nil, nil, nil, nil, nil, e + } + h, err := loadRsa() + if err != nil { + return bad(err) + } + var hkey bcrypt.KEY_HANDLE + err = bcrypt.GenerateKeyPair(h.h, &hkey, uint32(bits), 0) + if err != nil { + return bad(err) + } + defer bcrypt.DestroyKey(hkey) + // The key cannot be used until BcryptFinalizeKeyPair has been called. + err = bcrypt.FinalizeKeyPair(hkey, 0) + if err != nil { + return bad(err) + } + + var size uint32 + err = bcrypt.ExportKey(hkey, 0, utf16PtrFromString(bcrypt.RSAFULLPRIVATE_BLOB), nil, &size, 0) + if err != nil { + return bad(err) + } + + if size < sizeOfRSABlobHeader { + return bad(errors.New("crypto/rsa: exported key is corrupted")) + } + + blob := make([]byte, size) + err = bcrypt.ExportKey(hkey, 0, utf16PtrFromString(bcrypt.RSAFULLPRIVATE_BLOB), blob, &size, 0) + if err != nil { + return bad(err) + } + hdr := (*(*bcrypt.RSAKEY_BLOB)(unsafe.Pointer(&blob[0]))) + if hdr.Magic != bcrypt.RSAFULLPRIVATE_MAGIC || hdr.BitLength != uint32(bits) { + return bad(errors.New("crypto/rsa: exported key is corrupted")) + } + data := blob[sizeOfRSABlobHeader:] + consumeBigInt := func(size uint32) BigInt { + b := make(BigInt, size) + copy(b, data) + data = data[size:] + return b + } + E = consumeBigInt(hdr.PublicExpSize) + N = consumeBigInt(hdr.ModulusSize) + P = consumeBigInt(hdr.Prime1Size) + Q = consumeBigInt(hdr.Prime2Size) + Dp = consumeBigInt(hdr.Prime1Size) + Dq = consumeBigInt(hdr.Prime2Size) + Qinv = consumeBigInt(hdr.Prime1Size) + D = consumeBigInt(hdr.ModulusSize) + return +} + +type PublicKeyRSA struct { + pkey bcrypt.KEY_HANDLE +} + +func NewPublicKeyRSA(N, E BigInt) (*PublicKeyRSA, error) { + h, err := loadRsa() + if err != nil { + return nil, err + } + blob := encodeRSAKey(N, E, nil, nil, nil, nil, nil, nil) + k := new(PublicKeyRSA) + err = bcrypt.ImportKeyPair(h.h, 0, utf16PtrFromString(bcrypt.RSAPUBLIC_KEY_BLOB), &k.pkey, blob, 0) + if err != nil { + return nil, err + } + runtime.SetFinalizer(k, (*PublicKeyRSA).finalize) + return k, nil +} + +func (k *PublicKeyRSA) finalize() { + bcrypt.DestroyKey(k.pkey) +} + +type PrivateKeyRSA struct { + pkey bcrypt.KEY_HANDLE +} + +func (k *PrivateKeyRSA) finalize() { + bcrypt.DestroyKey(k.pkey) +} + +func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv BigInt) (*PrivateKeyRSA, error) { + h, err := loadRsa() + if err != nil { + return nil, err + } + blob := encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv) + k := new(PrivateKeyRSA) + err = bcrypt.ImportKeyPair(h.h, 0, utf16PtrFromString(bcrypt.RSAFULLPRIVATE_BLOB), &k.pkey, blob, 0) + if err != nil { + return nil, err + } + runtime.SetFinalizer(k, (*PrivateKeyRSA).finalize) + return k, nil +} + +func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv BigInt) []byte { + hdr := bcrypt.RSAKEY_BLOB{ + BitLength: uint32(len(N) * 8), + PublicExpSize: uint32(len(E)), + ModulusSize: uint32(len(N)), + } + var blob []byte + if D == nil { + hdr.Magic = bcrypt.RSAPUBLIC_MAGIC + blob = make([]byte, sizeOfRSABlobHeader+hdr.PublicExpSize+hdr.ModulusSize) + } else { + hdr.Magic = bcrypt.RSAFULLPRIVATE_MAGIC + hdr.Prime1Size = uint32(len(P)) + hdr.Prime2Size = uint32(len(Q)) + blob = make([]byte, sizeOfRSABlobHeader+hdr.PublicExpSize+hdr.ModulusSize*2+hdr.Prime1Size*3+hdr.Prime2Size*2) + } + copy(blob, (*(*[sizeOfRSABlobHeader]byte)(unsafe.Pointer(&hdr)))[:]) + data := blob[sizeOfRSABlobHeader:] + encode := func(b BigInt, size uint32) { + copy(data, b) + data = data[size:] + } + encode(E, hdr.PublicExpSize) + encode(N, hdr.ModulusSize) + if D != nil { + encode(P, hdr.Prime1Size) + encode(Q, hdr.Prime2Size) + encode(Dp, hdr.Prime1Size) + encode(Dq, hdr.Prime2Size) + encode(Qinv, hdr.Prime1Size) + encode(D, hdr.ModulusSize) + } + return blob +} + +func DecryptRSAOAEP(h hash.Hash, priv *PrivateKeyRSA, ciphertext, label []byte) ([]byte, error) { + defer runtime.KeepAlive(priv) + return rsaOAEP(h, priv.pkey, ciphertext, label, false) +} + +func EncryptRSAOAEP(h hash.Hash, pub *PublicKeyRSA, msg, label []byte) ([]byte, error) { + defer runtime.KeepAlive(pub) + return rsaOAEP(h, pub.pkey, msg, label, true) +} + +func DecryptRSAPKCS1(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) { + defer runtime.KeepAlive(priv) + return rsaCrypt(priv.pkey, nil, ciphertext, bcrypt.PAD_PKCS1, false) +} + +func EncryptRSAPKCS1(pub *PublicKeyRSA, msg []byte) ([]byte, error) { + defer runtime.KeepAlive(pub) + return rsaCrypt(pub.pkey, nil, msg, bcrypt.PAD_PKCS1, true) +} + +func DecryptRSANoPadding(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) { + defer runtime.KeepAlive(priv) + return rsaCrypt(priv.pkey, nil, ciphertext, bcrypt.PAD_NONE, false) + +} + +func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) { + defer runtime.KeepAlive(pub) + return rsaCrypt(pub.pkey, nil, msg, bcrypt.PAD_NONE, true) +} + +func SignRSAPSS(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte, saltLen int) ([]byte, error) { + defer runtime.KeepAlive(priv) + info, err := newPSS_PADDING_INFO(h, saltLen) + if err != nil { + return nil, err + } + return rsaSign(priv.pkey, unsafe.Pointer(&info), hashed, bcrypt.PAD_PSS) +} + +func VerifyRSAPSS(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte, saltLen int) error { + defer runtime.KeepAlive(pub) + info, err := newPSS_PADDING_INFO(h, saltLen) + if err != nil { + return err + } + return rsaVerify(pub.pkey, unsafe.Pointer(&info), hashed, sig, bcrypt.PAD_PSS) +} + +func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte, error) { + defer runtime.KeepAlive(priv) + info, err := newPKCS1_PADDING_INFO(h) + if err != nil { + return nil, err + } + return rsaSign(priv.pkey, unsafe.Pointer(&info), hashed, bcrypt.PAD_PKCS1) +} + +func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) error { + defer runtime.KeepAlive(pub) + info, err := newPKCS1_PADDING_INFO(h) + if err != nil { + return err + } + return rsaVerify(pub.pkey, unsafe.Pointer(&info), hashed, sig, bcrypt.PAD_PKCS1) +} + +func rsaCrypt(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, in []byte, flags bcrypt.PadMode, encrypt bool) ([]byte, error) { + var size uint32 + var err error + if encrypt { + err = bcrypt.Encrypt(pkey, in, info, nil, nil, &size, flags) + } else { + err = bcrypt.Decrypt(pkey, in, info, nil, nil, &size, flags) + } + if err != nil { + return nil, err + } + out := make([]byte, size) + if encrypt { + err = bcrypt.Encrypt(pkey, in, info, nil, out, &size, flags) + } else { + err = bcrypt.Decrypt(pkey, in, info, nil, out, &size, flags) + } + if err != nil { + return nil, err + } + return out[:size], nil +} + +func rsaOAEP(h hash.Hash, pkey bcrypt.KEY_HANDLE, in, label []byte, encrypt bool) ([]byte, error) { + hashID := hashToID(h) + if hashID == "" { + return nil, errors.New("crypto/rsa: unsupported hash function") + } + info := bcrypt.OAEP_PADDING_INFO{ + AlgId: utf16PtrFromString(hashID), + LabelSize: uint32(len(label)), + } + if len(label) > 0 { + info.Label = &label[0] + } + return rsaCrypt(pkey, unsafe.Pointer(&info), in, bcrypt.PAD_OAEP, encrypt) +} + +func rsaSign(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, hashed []byte, flags bcrypt.PadMode) ([]byte, error) { + var size uint32 + err := bcrypt.SignHash(pkey, info, hashed, nil, &size, flags) + if err != nil { + return nil, err + } + out := make([]byte, size) + err = bcrypt.SignHash(pkey, info, hashed, out, &size, flags) + if err != nil { + return nil, err + } + return out[:size], nil +} + +func rsaVerify(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, hashed, sig []byte, flags bcrypt.PadMode) error { + return bcrypt.VerifySignature(pkey, info, hashed, sig, flags) +} + +func newPSS_PADDING_INFO(h crypto.Hash, saltLen int) (info bcrypt.PSS_PADDING_INFO, err error) { + hashID := cryptoHashToID(h) + if hashID == "" { + return info, errors.New("crypto/rsa: unsupported hash function") + } + info.AlgId = utf16PtrFromString(hashID) + info.Salt = uint32(saltLen) + return +} + +func newPKCS1_PADDING_INFO(h crypto.Hash) (info bcrypt.PKCS1_PADDING_INFO, err error) { + if h != 0 { + hashID := cryptoHashToID(h) + if hashID == "" { + err = errors.New("crypto/rsa: unsupported hash function") + } else { + info.AlgId = utf16PtrFromString(hashID) + } + } + return +} + +func cryptoHashToID(ch crypto.Hash) string { + switch ch { + case crypto.MD5: + return bcrypt.MD5_ALGORITHM + case crypto.SHA1: + return bcrypt.SHA1_ALGORITHM + case crypto.SHA256: + return bcrypt.SHA256_ALGORITHM + case crypto.SHA384: + return bcrypt.SHA384_ALGORITHM + case crypto.SHA512: + return bcrypt.SHA512_ALGORITHM + } + return "" +} diff --git a/cng/rsa_test.go b/cng/rsa_test.go new file mode 100644 index 0000000..c384fac --- /dev/null +++ b/cng/rsa_test.go @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build windows +// +build windows + +package cng_test + +import ( + "bytes" + "crypto" + "math/big" + "strconv" + "testing" + + "github.com/microsoft/go-crypto-winnative/cng" + "github.com/microsoft/go-crypto-winnative/cng/bbig" +) + +func newRSAKey(t *testing.T, size int) (*cng.PrivateKeyRSA, *cng.PublicKeyRSA) { + t.Helper() + N, E, D, P, Q, Dp, Dq, Qinv, err := cng.GenerateKeyRSA(size) + if err != nil { + t.Fatalf("GenerateKeyRSA(%d): %v", size, err) + } + priv, err := cng.NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv) + if err != nil { + t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err) + } + pub, err := cng.NewPublicKeyRSA(N, E) + if err != nil { + t.Fatalf("NewPublicKeyRSA(%d): %v", size, err) + } + return priv, pub +} + +func TestRSAKeyGeneration(t *testing.T) { + for _, size := range []int{2048, 3072} { + t.Run(strconv.Itoa(size), func(t *testing.T) { + priv, pub := newRSAKey(t, size) + msg := []byte("hi!") + enc, err := cng.EncryptRSAPKCS1(pub, msg) + if err != nil { + t.Fatalf("EncryptPKCS1v15: %v", err) + } + dec, err := cng.DecryptRSAPKCS1(priv, enc) + if err != nil { + t.Fatalf("DecryptPKCS1v15: %v", err) + } + if !bytes.Equal(dec, msg) { + t.Fatalf("got:%x want:%x", dec, msg) + } + }) + } +} + +func TestEncryptDecryptOAEP(t *testing.T) { + sha256 := cng.NewSHA256() + msg := []byte("hi!") + label := []byte("ho!") + priv, pub := newRSAKey(t, 2048) + enc, err := cng.EncryptRSAOAEP(sha256, pub, msg, label) + if err != nil { + t.Fatal(err) + } + dec, err := cng.DecryptRSAOAEP(sha256, priv, enc, label) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(dec, msg) { + t.Errorf("got:%x want:%x", dec, msg) + } +} + +func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) { + sha256 := cng.NewSHA256() + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + enc, err := cng.EncryptRSAOAEP(sha256, pub, msg, []byte("ho!")) + if err != nil { + t.Fatal(err) + } + dec, err := cng.DecryptRSAOAEP(sha256, priv, enc, []byte("wrong!")) + if err == nil { + t.Errorf("error expected") + } + if dec != nil { + t.Errorf("got:%x want: nil", dec) + } +} + +func TestEncryptDecryptNoPadding(t *testing.T) { + const bits = 2048 + var msg [bits / 8]byte + msg[0] = 1 + msg[255] = 1 + priv, pub := newRSAKey(t, bits) + enc, err := cng.EncryptRSANoPadding(pub, msg[:]) + if err != nil { + t.Fatal(err) + } + dec, err := cng.DecryptRSANoPadding(priv, enc) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(dec, msg[:]) { + t.Errorf("got:%x want:%x", dec, msg) + } +} + +func TestSignVerifyPKCS1v15(t *testing.T) { + sha256 := cng.NewSHA256() + priv, pub := newRSAKey(t, 2048) + sha256.Write([]byte("hi!")) + hashed := sha256.Sum(nil) + signed, err := cng.SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + if err != nil { + t.Fatal(err) + } + err = cng.VerifyRSAPKCS1v15(pub, crypto.SHA256, hashed, signed) + if err != nil { + t.Fatal(err) + } +} + +func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) { + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + signed, err := cng.SignRSAPKCS1v15(priv, 0, msg) + if err != nil { + t.Fatal(err) + } + err = cng.VerifyRSAPKCS1v15(pub, 0, msg, signed) + if err != nil { + t.Fatal(err) + } +} + +func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { + sha256 := cng.NewSHA256() + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + sha256.Write(msg) + hashed := sha256.Sum(nil) + signed, err := cng.SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + if err != nil { + t.Fatal(err) + } + err = cng.VerifyRSAPKCS1v15(pub, crypto.SHA256, msg, signed) + if err == nil { + t.Fatal("error expected") + } +} + +func TestSignVerifyRSAPSS(t *testing.T) { + sha256 := cng.NewSHA256() + priv, pub := newRSAKey(t, 2048) + sha256.Write([]byte("testing")) + hashed := sha256.Sum(nil) + signed, err := cng.SignRSAPSS(priv, crypto.SHA256, hashed, 0) + if err != nil { + t.Fatal(err) + } + err = cng.VerifyRSAPSS(pub, crypto.SHA256, hashed, signed, 0) + if err != nil { + t.Fatal(err) + } +} + +func fromBase36(base36 string) *big.Int { + i, ok := new(big.Int).SetString(base36, 36) + if !ok { + panic("bad number: " + base36) + } + return i +} + +func BenchmarkEncryptRSAPKCS1(b *testing.B) { + b.StopTimer() + // Public key length should be at least of 2048 bits, else OpenSSL will report an error when running in FIPS mode. + n := fromBase36("14314132931241006650998084889274020608918049032671858325988396851334124245188214251956198731333464217832226406088020736932173064754214329009979944037640912127943488972644697423190955557435910767690712778463524983667852819010259499695177313115447116110358524558307947613422897787329221478860907963827160223559690523660574329011927531289655711860504630573766609239332569210831325633840174683944553667352219670930408593321661375473885147973879086994006440025257225431977751512374815915392249179976902953721486040787792801849818254465486633791826766873076617116727073077821584676715609985777563958286637185868165868520557") + test2048PubKey, err := cng.NewPublicKeyRSA(bbig.Enc(n), bbig.Enc(big.NewInt(3))) + if err != nil { + b.Fatal(err) + } + b.StartTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if _, err := cng.EncryptRSAPKCS1(test2048PubKey, []byte("testing")); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkGenerateKeyRSA(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _, _, _, _, _, _, _, err := cng.GenerateKeyRSA(2048) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/cng/sha.go b/cng/sha.go index dac8696..0cd8bee 100644 --- a/cng/sha.go +++ b/cng/sha.go @@ -9,7 +9,6 @@ package cng import ( "hash" "runtime" - "sync" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) @@ -105,8 +104,6 @@ func NewSHA512() hash.Hash { return newSHAX(bcrypt.SHA512_ALGORITHM, nil) } -var shaCache sync.Map - type shaAlgorithm struct { h bcrypt.ALG_HANDLE size uint32 @@ -114,39 +111,26 @@ type shaAlgorithm struct { objectLength uint32 } -func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (h shaAlgorithm, err error) { - if v, ok := shaCache.Load(algCacheEntry{id, uint32(flags)}); ok { - return v.(shaAlgorithm), nil - } - err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags) - if err != nil { - return - } - defer func() { +func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (shaAlgorithm, error) { + v, err := loadOrStoreAlg(id, flags, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { + size, err := getUint32(bcrypt.HANDLE(h), bcrypt.HASH_LENGTH) if err != nil { - bcrypt.CloseAlgorithmProvider(h.h, 0) - h.h = 0 + return nil, err } - }() - h.size, err = getUint32(bcrypt.HANDLE(h.h), bcrypt.HASH_LENGTH) - if err != nil { - return - } - h.blockSize, err = getUint32(bcrypt.HANDLE(h.h), bcrypt.HASH_BLOCK_LENGTH) - if err != nil { - return - } - h.objectLength, err = getUint32(bcrypt.HANDLE(h.h), bcrypt.OBJECT_LENGTH) + blockSize, err := getUint32(bcrypt.HANDLE(h), bcrypt.HASH_BLOCK_LENGTH) + if err != nil { + return nil, err + } + objectLength, err := getUint32(bcrypt.HANDLE(h), bcrypt.OBJECT_LENGTH) + if err != nil { + return nil, err + } + return shaAlgorithm{h, size, blockSize, objectLength}, nil + }) if err != nil { - bcrypt.CloseAlgorithmProvider(h.h, 0) - return - } - if existing, loaded := shaCache.LoadOrStore(algCacheEntry{id, uint32(flags)}, h); loaded { - // We can safely use a provider that has already been cached in another concurrent goroutine. - bcrypt.CloseAlgorithmProvider(h.h, 0) - h = existing.(shaAlgorithm) + return shaAlgorithm{}, err } - return + return v.(shaAlgorithm), nil } type shaXHash struct { diff --git a/internal/bcrypt/bcrypt_windows.go b/internal/bcrypt/bcrypt_windows.go index 938e988..8ace412 100644 --- a/internal/bcrypt/bcrypt_windows.go +++ b/internal/bcrypt/bcrypt_windows.go @@ -16,6 +16,8 @@ const ( SHA384_ALGORITHM = "SHA384" SHA512_ALGORITHM = "SHA512" AES_ALGORITHM = "AES" + RSA_ALGORITHM = "RSA" + MD5_ALGORITHM = "MD5" ) const ( @@ -27,12 +29,28 @@ const ( CHAIN_MODE_GCM = "ChainingModeGCM" KEY_LENGTHS = "KeyLengths" OBJECT_LENGTH = "ObjectLength" + BLOCK_LENGTH = "BlockLength" +) + +const ( + RSAPUBLIC_KEY_BLOB = "RSAPUBLICBLOB" + RSAFULLPRIVATE_BLOB = "RSAFULLPRIVATEBLOB" ) const ( USE_SYSTEM_PREFERRED_RNG = 0x00000002 ) +type PadMode uint32 + +const ( + PAD_UNDEFINED PadMode = 0x0 + PAD_NONE PadMode = 0x1 + PAD_PKCS1 PadMode = 0x2 + PAD_OAEP PadMode = 0x4 + PAD_PSS PadMode = 0x8 +) + type AlgorithmProviderFlags uint32 const ( @@ -40,6 +58,13 @@ const ( ALG_HANDLE_HMAC_FLAG AlgorithmProviderFlags = 0x00000008 ) +type KeyBlobMagicNumber uint32 + +const ( + RSAPUBLIC_MAGIC KeyBlobMagicNumber = 0x31415352 + RSAFULLPRIVATE_MAGIC KeyBlobMagicNumber = 0x33415352 +) + type ( HANDLE syscall.Handle ALG_HANDLE HANDLE @@ -89,6 +114,34 @@ func NewAUTHENTICATED_CIPHER_MODE_INFO(nonce, additionalData, tag []byte) *AUTHE return &info } +// https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_oaep_padding_info +type OAEP_PADDING_INFO struct { + AlgId *uint16 + Label *byte + LabelSize uint32 +} + +// https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_pkcs1_padding_info +type PKCS1_PADDING_INFO struct { + AlgId *uint16 +} + +// https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_pss_padding_info +type PSS_PADDING_INFO struct { + AlgId *uint16 + Salt uint32 +} + +// https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_rsakey_blob +type RSAKEY_BLOB struct { + Magic KeyBlobMagicNumber + BitLength uint32 + PublicExpSize uint32 + ModulusSize uint32 + Prime1Size uint32 + Prime2Size uint32 +} + //sys SetProperty(hObject HANDLE, pszProperty *uint16, pbInput []byte, dwFlags uint32) (s error) = bcrypt.BCryptSetProperty //sys GetProperty(hObject HANDLE, pszProperty *uint16, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptGetProperty //sys OpenAlgorithmProvider(phAlgorithm *ALG_HANDLE, pszAlgId *uint16, pszImplementation *uint16, dwFlags AlgorithmProviderFlags) (s error) = bcrypt.BCryptOpenAlgorithmProvider @@ -109,6 +162,12 @@ func NewAUTHENTICATED_CIPHER_MODE_INFO(nonce, additionalData, tag []byte) *AUTHE // Keys //sys GenerateSymmetricKey(hAlgorithm ALG_HANDLE, phKey *KEY_HANDLE, pbKeyObject []byte, pbSecret []byte, dwFlags uint32) (s error) = bcrypt.BCryptGenerateSymmetricKey +//sys GenerateKeyPair(hAlgorithm ALG_HANDLE, phKey *KEY_HANDLE, dwLength uint32, dwFlags uint32) (s error) = bcrypt.BCryptGenerateKeyPair +//sys FinalizeKeyPair(hKey KEY_HANDLE, dwFlags uint32) (s error) = bcrypt.BCryptFinalizeKeyPair +//sys ImportKeyPair (hAlgorithm ALG_HANDLE, hImportKey KEY_HANDLE, pszBlobType *uint16, phKey *KEY_HANDLE, pbInput []byte, dwFlags uint32) (s error) = bcrypt.BCryptImportKeyPair +//sys ExportKey(hKey KEY_HANDLE, hExportKey KEY_HANDLE, pszBlobType *uint16, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptExportKey //sys DestroyKey(hKey KEY_HANDLE) (s error) = bcrypt.BCryptDestroyKey -//sys Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER_MODE_INFO, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptEncrypt -//sys Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER_MODE_INFO, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptDecrypt +//sys Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) = bcrypt.BCryptEncrypt +//sys Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) = bcrypt.BCryptDecrypt +//sys SignHash (hKey KEY_HANDLE, pPaddingInfo unsafe.Pointer, pbInput []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) = bcrypt.BCryptSignHash +//sys VerifySignature(hKey KEY_HANDLE, pPaddingInfo unsafe.Pointer, pbHash []byte, pbSignature []byte, dwFlags PadMode) (s error) = bcrypt.BCryptVerifySignature diff --git a/internal/bcrypt/zsyscall_windows.go b/internal/bcrypt/zsyscall_windows.go index e8a0655..5fe28c5 100644 --- a/internal/bcrypt/zsyscall_windows.go +++ b/internal/bcrypt/zsyscall_windows.go @@ -46,13 +46,19 @@ var ( procBCryptDestroyKey = modbcrypt.NewProc("BCryptDestroyKey") procBCryptDuplicateHash = modbcrypt.NewProc("BCryptDuplicateHash") procBCryptEncrypt = modbcrypt.NewProc("BCryptEncrypt") + procBCryptExportKey = modbcrypt.NewProc("BCryptExportKey") + procBCryptFinalizeKeyPair = modbcrypt.NewProc("BCryptFinalizeKeyPair") procBCryptFinishHash = modbcrypt.NewProc("BCryptFinishHash") procBCryptGenRandom = modbcrypt.NewProc("BCryptGenRandom") + procBCryptGenerateKeyPair = modbcrypt.NewProc("BCryptGenerateKeyPair") procBCryptGenerateSymmetricKey = modbcrypt.NewProc("BCryptGenerateSymmetricKey") procBCryptGetProperty = modbcrypt.NewProc("BCryptGetProperty") procBCryptHashData = modbcrypt.NewProc("BCryptHashData") + procBCryptImportKeyPair = modbcrypt.NewProc("BCryptImportKeyPair") procBCryptOpenAlgorithmProvider = modbcrypt.NewProc("BCryptOpenAlgorithmProvider") procBCryptSetProperty = modbcrypt.NewProc("BCryptSetProperty") + procBCryptSignHash = modbcrypt.NewProc("BCryptSignHash") + procBCryptVerifySignature = modbcrypt.NewProc("BCryptVerifySignature") ) func CloseAlgorithmProvider(hAlgorithm ALG_HANDLE, dwFlags uint32) (s error) { @@ -79,7 +85,7 @@ func CreateHash(hAlgorithm ALG_HANDLE, phHash *HASH_HANDLE, pbHashObject []byte, return } -func Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER_MODE_INFO, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) { +func Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) { var _p0 *byte if len(pbInput) > 0 { _p0 = &pbInput[0] @@ -92,7 +98,7 @@ func Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER if len(pbOutput) > 0 { _p2 = &pbOutput[0] } - r0, _, _ := syscall.Syscall12(procBCryptDecrypt.Addr(), 10, uintptr(hKey), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(unsafe.Pointer(pPaddingInfo)), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbIV)), uintptr(unsafe.Pointer(_p2)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0, 0) + r0, _, _ := syscall.Syscall12(procBCryptDecrypt.Addr(), 10, uintptr(hKey), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(pPaddingInfo), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbIV)), uintptr(unsafe.Pointer(_p2)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0, 0) if r0 != 0 { s = syscall.Errno(r0) } @@ -127,7 +133,7 @@ func DuplicateHash(hHash HASH_HANDLE, phNewHash *HASH_HANDLE, pbHashObject []byt return } -func Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER_MODE_INFO, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) { +func Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) { var _p0 *byte if len(pbInput) > 0 { _p0 = &pbInput[0] @@ -140,7 +146,27 @@ func Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER if len(pbOutput) > 0 { _p2 = &pbOutput[0] } - r0, _, _ := syscall.Syscall12(procBCryptEncrypt.Addr(), 10, uintptr(hKey), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(unsafe.Pointer(pPaddingInfo)), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbIV)), uintptr(unsafe.Pointer(_p2)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0, 0) + r0, _, _ := syscall.Syscall12(procBCryptEncrypt.Addr(), 10, uintptr(hKey), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(pPaddingInfo), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbIV)), uintptr(unsafe.Pointer(_p2)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0, 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +} + +func ExportKey(hKey KEY_HANDLE, hExportKey KEY_HANDLE, pszBlobType *uint16, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) { + var _p0 *byte + if len(pbOutput) > 0 { + _p0 = &pbOutput[0] + } + r0, _, _ := syscall.Syscall9(procBCryptExportKey.Addr(), 7, uintptr(hKey), uintptr(hExportKey), uintptr(unsafe.Pointer(pszBlobType)), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0, 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +} + +func FinalizeKeyPair(hKey KEY_HANDLE, dwFlags uint32) (s error) { + r0, _, _ := syscall.Syscall(procBCryptFinalizeKeyPair.Addr(), 2, uintptr(hKey), uintptr(dwFlags), 0) if r0 != 0 { s = syscall.Errno(r0) } @@ -171,6 +197,14 @@ func GenRandom(hAlgorithm ALG_HANDLE, pbBuffer []byte, dwFlags uint32) (s error) return } +func GenerateKeyPair(hAlgorithm ALG_HANDLE, phKey *KEY_HANDLE, dwLength uint32, dwFlags uint32) (s error) { + r0, _, _ := syscall.Syscall6(procBCryptGenerateKeyPair.Addr(), 4, uintptr(hAlgorithm), uintptr(unsafe.Pointer(phKey)), uintptr(dwLength), uintptr(dwFlags), 0, 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +} + func GenerateSymmetricKey(hAlgorithm ALG_HANDLE, phKey *KEY_HANDLE, pbKeyObject []byte, pbSecret []byte, dwFlags uint32) (s error) { var _p0 *byte if len(pbKeyObject) > 0 { @@ -211,6 +245,18 @@ func HashData(hHash HASH_HANDLE, pbInput []byte, dwFlags uint32) (s error) { return } +func ImportKeyPair(hAlgorithm ALG_HANDLE, hImportKey KEY_HANDLE, pszBlobType *uint16, phKey *KEY_HANDLE, pbInput []byte, dwFlags uint32) (s error) { + var _p0 *byte + if len(pbInput) > 0 { + _p0 = &pbInput[0] + } + r0, _, _ := syscall.Syscall9(procBCryptImportKeyPair.Addr(), 7, uintptr(hAlgorithm), uintptr(hImportKey), uintptr(unsafe.Pointer(pszBlobType)), uintptr(unsafe.Pointer(phKey)), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(dwFlags), 0, 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +} + func OpenAlgorithmProvider(phAlgorithm *ALG_HANDLE, pszAlgId *uint16, pszImplementation *uint16, dwFlags AlgorithmProviderFlags) (s error) { r0, _, _ := syscall.Syscall6(procBCryptOpenAlgorithmProvider.Addr(), 4, uintptr(unsafe.Pointer(phAlgorithm)), uintptr(unsafe.Pointer(pszAlgId)), uintptr(unsafe.Pointer(pszImplementation)), uintptr(dwFlags), 0, 0) if r0 != 0 { @@ -230,3 +276,35 @@ func SetProperty(hObject HANDLE, pszProperty *uint16, pbInput []byte, dwFlags ui } return } + +func SignHash(hKey KEY_HANDLE, pPaddingInfo unsafe.Pointer, pbInput []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) { + var _p0 *byte + if len(pbInput) > 0 { + _p0 = &pbInput[0] + } + var _p1 *byte + if len(pbOutput) > 0 { + _p1 = &pbOutput[0] + } + r0, _, _ := syscall.Syscall9(procBCryptSignHash.Addr(), 8, uintptr(hKey), uintptr(pPaddingInfo), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +} + +func VerifySignature(hKey KEY_HANDLE, pPaddingInfo unsafe.Pointer, pbHash []byte, pbSignature []byte, dwFlags PadMode) (s error) { + var _p0 *byte + if len(pbHash) > 0 { + _p0 = &pbHash[0] + } + var _p1 *byte + if len(pbSignature) > 0 { + _p1 = &pbSignature[0] + } + r0, _, _ := syscall.Syscall9(procBCryptVerifySignature.Addr(), 7, uintptr(hKey), uintptr(pPaddingInfo), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbHash)), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbSignature)), uintptr(dwFlags), 0, 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +}