From eb1754058e93c6a6698ec30e4745319817e93219 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 17 Mar 2022 13:49:33 +0100 Subject: [PATCH 1/9] implement RSA public and private keys --- cng/rsa.go | 173 ++++++++++++++++++++++++++++ cng/rsa_test.go | 31 +++++ internal/bcrypt/bcrypt_windows.go | 27 +++++ internal/bcrypt/zsyscall_windows.go | 44 +++++++ 4 files changed, 275 insertions(+) create mode 100644 cng/rsa.go create mode 100644 cng/rsa_test.go diff --git a/cng/rsa.go b/cng/rsa.go new file mode 100644 index 0000000..163dfdb --- /dev/null +++ b/cng/rsa.go @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build windows +// +build windows + +package cng + +import ( + "math/big" + "runtime" + "sync" + "unsafe" + + "github.com/microsoft/go-crypto-winnative/internal/bcrypt" +) + +var rsaCache sync.Map + +type rsaAlgorithm struct { + h bcrypt.ALG_HANDLE +} + +func loadRsa(id string, flags bcrypt.AlgorithmProviderFlags) (h rsaAlgorithm, err error) { + if v, ok := rsaCache.Load(algCacheEntry{id, uint32(flags)}); ok { + return v.(rsaAlgorithm), nil + } + err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags) + if err != nil { + return + } + rsaCache.Store(algCacheEntry{id, uint32(flags)}, h) + return +} + +const sizeOfRSABlobHeader = uint32(unsafe.Sizeof(bcrypt.RSAKEY_BLOB{})) + +func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) { + bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) { + return nil, nil, nil, nil, nil, nil, nil, nil, e + } + h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG) + if err != nil { + return bad(err) + } + var hkey bcrypt.KEY_HANDLE + err = bcrypt.GenerateKeyPair(h.h, &hkey, uint32(bits), 0) + if err != nil { + return bad(err) + } + defer bcrypt.DestroyKey(hkey) + // The key cannot be used until BcryptFinalizeKeyPair has been called. + err = bcrypt.FinalizeKeyPair(hkey, 0) + if err != nil { + return bad(err) + } + + var size uint32 + err = bcrypt.ExportKey(hkey, 0, utf16PtrFromString(bcrypt.RSAFULLPRIVATE_BLOB), nil, &size, 0) + if err != nil { + return bad(err) + } + + blob := make([]byte, size) + err = bcrypt.ExportKey(hkey, 0, utf16PtrFromString(bcrypt.RSAFULLPRIVATE_BLOB), blob, &size, 0) + if err != nil { + return bad(err) + } + hdr := (*(*bcrypt.RSAKEY_BLOB)(unsafe.Pointer(&blob[0]))) + if hdr.Magic != bcrypt.RSAFULLPRIVATE_MAGIC || hdr.BitLength != uint32(bits) { + panic("crypto/rsa: exported key is corrupted") + } + data := blob[sizeOfRSABlobHeader:] + newInt := func(size uint32) *big.Int { + b := new(big.Int).SetBytes(data[:size]) + data = data[size:] + return b + } + E = newInt(hdr.PublicExpSize) + N = newInt(hdr.ModulusSize) + P = newInt(hdr.Prime1Size) + Q = newInt(hdr.Prime2Size) + Dp = newInt(hdr.Prime1Size) + Dq = newInt(hdr.Prime2Size) + Qinv = newInt(hdr.Prime1Size) + D = newInt(hdr.ModulusSize) + return +} + +type PublicKeyRSA struct { + pkey bcrypt.KEY_HANDLE +} + +func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) { + h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG) + if err != nil { + return nil, err + } + blob := encodeRSAKey(N, E, nil, nil, nil, nil, nil, nil) + k := new(PublicKeyRSA) + err = bcrypt.ImportKeyPair(h.h, 0, utf16PtrFromString(bcrypt.RSAPUBLIC_KEY_BLOB), &k.pkey, blob, 0) + if err != nil { + return nil, err + } + runtime.SetFinalizer(k, (*PublicKeyRSA).finalize) + return k, nil +} + +func (k *PublicKeyRSA) finalize() { + bcrypt.DestroyKey(k.pkey) +} + +type PrivateKeyRSA struct { + pkey bcrypt.KEY_HANDLE +} + +func (k *PrivateKeyRSA) finalize() { + bcrypt.DestroyKey(k.pkey) +} + +func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, error) { + h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG) + if err != nil { + return nil, err + } + blob := encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv) + k := new(PrivateKeyRSA) + err = bcrypt.ImportKeyPair(h.h, 0, utf16PtrFromString(bcrypt.RSAFULLPRIVATE_BLOB), &k.pkey, blob, 0) + if err != nil { + return nil, err + } + runtime.SetFinalizer(k, (*PrivateKeyRSA).finalize) + return k, nil +} + +func bigIntBytesLen(b *big.Int) uint32 { + return uint32(b.BitLen()+7) / 8 +} + +func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) []byte { + hdr := bcrypt.RSAKEY_BLOB{ + BitLength: uint32(N.BitLen()), + PublicExpSize: bigIntBytesLen(E), + ModulusSize: bigIntBytesLen(N), + } + var blob []byte + if D == nil { + hdr.Magic = bcrypt.RSAPUBLIC_MAGIC + blob = make([]byte, sizeOfRSABlobHeader+hdr.PublicExpSize+hdr.ModulusSize) + } else { + hdr.Magic = bcrypt.RSAFULLPRIVATE_MAGIC + hdr.Prime1Size = bigIntBytesLen(P) + hdr.Prime2Size = bigIntBytesLen(Q) + blob = make([]byte, sizeOfRSABlobHeader+hdr.PublicExpSize+hdr.ModulusSize*2+hdr.Prime1Size*3+hdr.Prime2Size*2) + } + copy(blob[:sizeOfRSABlobHeader], (*(*[1<<31 - 1]byte)(unsafe.Pointer(&hdr)))[:sizeOfRSABlobHeader]) + data := blob[sizeOfRSABlobHeader:] + encode := func(b *big.Int, size uint32) { + b.FillBytes(data[:size]) + data = data[size:] + } + encode(E, hdr.PublicExpSize) + encode(N, hdr.ModulusSize) + if D != nil { + encode(P, hdr.Prime1Size) + encode(Q, hdr.Prime2Size) + encode(Dp, hdr.Prime1Size) + encode(Dq, hdr.Prime2Size) + encode(Qinv, hdr.Prime1Size) + encode(D, hdr.ModulusSize) + } + return blob +} diff --git a/cng/rsa_test.go b/cng/rsa_test.go new file mode 100644 index 0000000..3aaae30 --- /dev/null +++ b/cng/rsa_test.go @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build windows +// +build windows + +package cng + +import ( + "strconv" + "testing" +) + +func TestRSAKeyGeneration(t *testing.T) { + for _, size := range []int{2048, 3072} { + t.Run(strconv.Itoa(size), func(t *testing.T) { + N, E, D, P, Q, Dp, Dq, Qinv, err := GenerateKeyRSA(size) + if err != nil { + t.Fatalf("GenerateKeyRSA(%d): %v", size, err) + } + _, err = NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv) + if err != nil { + t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err) + } + _, err = NewPublicKeyRSA(N, E) + if err != nil { + t.Fatalf("NewPublicKeyRSA(%d): %v", size, err) + } + }) + } +} diff --git a/internal/bcrypt/bcrypt_windows.go b/internal/bcrypt/bcrypt_windows.go index 938e988..87e3299 100644 --- a/internal/bcrypt/bcrypt_windows.go +++ b/internal/bcrypt/bcrypt_windows.go @@ -16,6 +16,7 @@ const ( SHA384_ALGORITHM = "SHA384" SHA512_ALGORITHM = "SHA512" AES_ALGORITHM = "AES" + RSA_ALGORITHM = "RSA" ) const ( @@ -29,6 +30,11 @@ const ( OBJECT_LENGTH = "ObjectLength" ) +const ( + RSAPUBLIC_KEY_BLOB = "RSAPUBLICBLOB" + RSAFULLPRIVATE_BLOB = "RSAFULLPRIVATEBLOB" +) + const ( USE_SYSTEM_PREFERRED_RNG = 0x00000002 ) @@ -40,6 +46,13 @@ const ( ALG_HANDLE_HMAC_FLAG AlgorithmProviderFlags = 0x00000008 ) +type KeyBlobMagicNumber uint32 + +const ( + RSAPUBLIC_MAGIC KeyBlobMagicNumber = 0x31415352 + RSAFULLPRIVATE_MAGIC KeyBlobMagicNumber = 0x33415352 +) + type ( HANDLE syscall.Handle ALG_HANDLE HANDLE @@ -89,6 +102,16 @@ func NewAUTHENTICATED_CIPHER_MODE_INFO(nonce, additionalData, tag []byte) *AUTHE return &info } +// https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_rsakey_blob +type RSAKEY_BLOB struct { + Magic KeyBlobMagicNumber + BitLength uint32 + PublicExpSize uint32 + ModulusSize uint32 + Prime1Size uint32 + Prime2Size uint32 +} + //sys SetProperty(hObject HANDLE, pszProperty *uint16, pbInput []byte, dwFlags uint32) (s error) = bcrypt.BCryptSetProperty //sys GetProperty(hObject HANDLE, pszProperty *uint16, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptGetProperty //sys OpenAlgorithmProvider(phAlgorithm *ALG_HANDLE, pszAlgId *uint16, pszImplementation *uint16, dwFlags AlgorithmProviderFlags) (s error) = bcrypt.BCryptOpenAlgorithmProvider @@ -109,6 +132,10 @@ func NewAUTHENTICATED_CIPHER_MODE_INFO(nonce, additionalData, tag []byte) *AUTHE // Keys //sys GenerateSymmetricKey(hAlgorithm ALG_HANDLE, phKey *KEY_HANDLE, pbKeyObject []byte, pbSecret []byte, dwFlags uint32) (s error) = bcrypt.BCryptGenerateSymmetricKey +//sys GenerateKeyPair(hAlgorithm ALG_HANDLE, phKey *KEY_HANDLE, dwLength uint32, dwFlags uint32) (s error) = bcrypt.BCryptGenerateKeyPair +//sys FinalizeKeyPair(hKey KEY_HANDLE, dwFlags uint32) (s error) = bcrypt.BCryptFinalizeKeyPair +//sys ImportKeyPair (hAlgorithm ALG_HANDLE, hImportKey KEY_HANDLE, pszBlobType *uint16, phKey *KEY_HANDLE, pbInput []byte, dwFlags uint32) (s error) = bcrypt.BCryptImportKeyPair +//sys ExportKey(hKey KEY_HANDLE, hExportKey KEY_HANDLE, pszBlobType *uint16, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptExportKey //sys DestroyKey(hKey KEY_HANDLE) (s error) = bcrypt.BCryptDestroyKey //sys Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER_MODE_INFO, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptEncrypt //sys Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER_MODE_INFO, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptDecrypt diff --git a/internal/bcrypt/zsyscall_windows.go b/internal/bcrypt/zsyscall_windows.go index e8a0655..549339f 100644 --- a/internal/bcrypt/zsyscall_windows.go +++ b/internal/bcrypt/zsyscall_windows.go @@ -46,11 +46,15 @@ var ( procBCryptDestroyKey = modbcrypt.NewProc("BCryptDestroyKey") procBCryptDuplicateHash = modbcrypt.NewProc("BCryptDuplicateHash") procBCryptEncrypt = modbcrypt.NewProc("BCryptEncrypt") + procBCryptExportKey = modbcrypt.NewProc("BCryptExportKey") + procBCryptFinalizeKeyPair = modbcrypt.NewProc("BCryptFinalizeKeyPair") procBCryptFinishHash = modbcrypt.NewProc("BCryptFinishHash") procBCryptGenRandom = modbcrypt.NewProc("BCryptGenRandom") + procBCryptGenerateKeyPair = modbcrypt.NewProc("BCryptGenerateKeyPair") procBCryptGenerateSymmetricKey = modbcrypt.NewProc("BCryptGenerateSymmetricKey") procBCryptGetProperty = modbcrypt.NewProc("BCryptGetProperty") procBCryptHashData = modbcrypt.NewProc("BCryptHashData") + procBCryptImportKeyPair = modbcrypt.NewProc("BCryptImportKeyPair") procBCryptOpenAlgorithmProvider = modbcrypt.NewProc("BCryptOpenAlgorithmProvider") procBCryptSetProperty = modbcrypt.NewProc("BCryptSetProperty") ) @@ -147,6 +151,26 @@ func Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER return } +func ExportKey(hKey KEY_HANDLE, hExportKey KEY_HANDLE, pszBlobType *uint16, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) { + var _p0 *byte + if len(pbOutput) > 0 { + _p0 = &pbOutput[0] + } + r0, _, _ := syscall.Syscall9(procBCryptExportKey.Addr(), 7, uintptr(hKey), uintptr(hExportKey), uintptr(unsafe.Pointer(pszBlobType)), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0, 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +} + +func FinalizeKeyPair(hKey KEY_HANDLE, dwFlags uint32) (s error) { + r0, _, _ := syscall.Syscall(procBCryptFinalizeKeyPair.Addr(), 2, uintptr(hKey), uintptr(dwFlags), 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +} + func FinishHash(hHash HASH_HANDLE, pbOutput []byte, dwFlags uint32) (s error) { var _p0 *byte if len(pbOutput) > 0 { @@ -171,6 +195,14 @@ func GenRandom(hAlgorithm ALG_HANDLE, pbBuffer []byte, dwFlags uint32) (s error) return } +func GenerateKeyPair(hAlgorithm ALG_HANDLE, phKey *KEY_HANDLE, dwLength uint32, dwFlags uint32) (s error) { + r0, _, _ := syscall.Syscall6(procBCryptGenerateKeyPair.Addr(), 4, uintptr(hAlgorithm), uintptr(unsafe.Pointer(phKey)), uintptr(dwLength), uintptr(dwFlags), 0, 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +} + func GenerateSymmetricKey(hAlgorithm ALG_HANDLE, phKey *KEY_HANDLE, pbKeyObject []byte, pbSecret []byte, dwFlags uint32) (s error) { var _p0 *byte if len(pbKeyObject) > 0 { @@ -211,6 +243,18 @@ func HashData(hHash HASH_HANDLE, pbInput []byte, dwFlags uint32) (s error) { return } +func ImportKeyPair(hAlgorithm ALG_HANDLE, hImportKey KEY_HANDLE, pszBlobType *uint16, phKey *KEY_HANDLE, pbInput []byte, dwFlags uint32) (s error) { + var _p0 *byte + if len(pbInput) > 0 { + _p0 = &pbInput[0] + } + r0, _, _ := syscall.Syscall9(procBCryptImportKeyPair.Addr(), 7, uintptr(hAlgorithm), uintptr(hImportKey), uintptr(unsafe.Pointer(pszBlobType)), uintptr(unsafe.Pointer(phKey)), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(dwFlags), 0, 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +} + func OpenAlgorithmProvider(phAlgorithm *ALG_HANDLE, pszAlgId *uint16, pszImplementation *uint16, dwFlags AlgorithmProviderFlags) (s error) { r0, _, _ := syscall.Syscall6(procBCryptOpenAlgorithmProvider.Addr(), 4, uintptr(unsafe.Pointer(phAlgorithm)), uintptr(unsafe.Pointer(pszAlgId)), uintptr(unsafe.Pointer(pszImplementation)), uintptr(dwFlags), 0, 0) if r0 != 0 { From 9e12a3153b9cd6c685a91c69c1a30ed82be3d208 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 17 Mar 2022 15:18:44 +0100 Subject: [PATCH 2/9] implement RSA encrypt and decrypt --- cng/aes.go | 4 +- cng/rsa.go | 71 +++++++++++++++++++++++ cng/rsa_test.go | 88 ++++++++++++++++++++++++++--- internal/bcrypt/bcrypt_windows.go | 22 +++++++- internal/bcrypt/zsyscall_windows.go | 8 +-- 5 files changed, 178 insertions(+), 15 deletions(-) diff --git a/cng/aes.go b/cng/aes.go index 43bd772..62f8919 100644 --- a/cng/aes.go +++ b/cng/aes.go @@ -328,7 +328,7 @@ func (g *aesGCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte { info := bcrypt.NewAUTHENTICATED_CIPHER_MODE_INFO(nonce, additionalData, out[len(out)-gcmTagSize:]) var encSize uint32 - err := bcrypt.Encrypt(g.kh, plaintext, info, nil, out, &encSize, 0) + err := bcrypt.Encrypt(g.kh, plaintext, unsafe.Pointer(info), nil, out, &encSize, 0) if err != nil { panic(err) } @@ -365,7 +365,7 @@ func (g *aesGCM) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, er info := bcrypt.NewAUTHENTICATED_CIPHER_MODE_INFO(nonce, additionalData, tag) var decSize uint32 - err := bcrypt.Decrypt(g.kh, ciphertext, info, nil, out, &decSize, 0) + err := bcrypt.Decrypt(g.kh, ciphertext, unsafe.Pointer(info), nil, out, &decSize, 0) if err != nil || int(decSize) != len(ciphertext) { for i := range out { out[i] = 0 diff --git a/cng/rsa.go b/cng/rsa.go index 163dfdb..9f6e5df 100644 --- a/cng/rsa.go +++ b/cng/rsa.go @@ -7,6 +7,8 @@ package cng import ( + "errors" + "hash" "math/big" "runtime" "sync" @@ -171,3 +173,72 @@ func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) []byte { } return blob } + +func DecryptRSAOAEP(h hash.Hash, priv *PrivateKeyRSA, ciphertext, label []byte) ([]byte, error) { + defer runtime.KeepAlive(priv) + return rsaOAEP(h, priv.pkey, ciphertext, label, false) +} + +func EncryptRSAOAEP(h hash.Hash, pub *PublicKeyRSA, msg, label []byte) ([]byte, error) { + defer runtime.KeepAlive(pub) + return rsaOAEP(h, pub.pkey, msg, label, true) +} + +func DecryptRSAPKCS1(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) { + defer runtime.KeepAlive(priv) + return rsaCrypt(priv.pkey, nil, ciphertext, bcrypt.PAD_PKCS1, false) +} + +func EncryptRSAPKCS1(pub *PublicKeyRSA, msg []byte) ([]byte, error) { + defer runtime.KeepAlive(pub) + return rsaCrypt(pub.pkey, nil, msg, bcrypt.PAD_PKCS1, true) +} + +func DecryptRSANoPadding(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) { + defer runtime.KeepAlive(priv) + return rsaCrypt(priv.pkey, nil, ciphertext, bcrypt.PAD_NONE, false) + +} + +func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) { + defer runtime.KeepAlive(pub) + return rsaCrypt(pub.pkey, nil, msg, bcrypt.PAD_NONE, true) +} + +func rsaCrypt(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, in []byte, flags bcrypt.EncryptFlags, encrypt bool) ([]byte, error) { + var size uint32 + var err error + if encrypt { + err = bcrypt.Encrypt(pkey, in, info, nil, nil, &size, flags) + } else { + err = bcrypt.Decrypt(pkey, in, info, nil, nil, &size, flags) + } + if err != nil { + return nil, err + } + out := make([]byte, size) + if encrypt { + err = bcrypt.Encrypt(pkey, in, info, nil, out, &size, flags) + } else { + err = bcrypt.Decrypt(pkey, in, info, nil, out, &size, flags) + } + if err != nil { + return nil, err + } + return out[:size], nil +} + +func rsaOAEP(h hash.Hash, pkey bcrypt.KEY_HANDLE, in, label []byte, encrypt bool) ([]byte, error) { + hashID := hashToID(h) + if hashID == "" { + return nil, errors.New("crypto/rsa: unsupported hash function") + } + info := bcrypt.OAEP_PADDING_INFO{ + AlgId: utf16PtrFromString(hashID), + LabelSize: uint32(len(label)), + } + if len(label) > 0 { + info.Label = &label[0] + } + return rsaCrypt(pkey, unsafe.Pointer(&info), in, bcrypt.PAD_OAEP, encrypt) +} diff --git a/cng/rsa_test.go b/cng/rsa_test.go index 3aaae30..de6432c 100644 --- a/cng/rsa_test.go +++ b/cng/rsa_test.go @@ -7,25 +7,99 @@ package cng import ( + "bytes" "strconv" "testing" ) +func newRSAKey(t *testing.T, size int) (*PrivateKeyRSA, *PublicKeyRSA) { + t.Helper() + N, E, D, P, Q, Dp, Dq, Qinv, err := GenerateKeyRSA(size) + if err != nil { + t.Fatalf("GenerateKeyRSA(%d): %v", size, err) + } + priv, err := NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv) + if err != nil { + t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err) + } + pub, err := NewPublicKeyRSA(N, E) + if err != nil { + t.Fatalf("NewPublicKeyRSA(%d): %v", size, err) + } + return priv, pub +} + func TestRSAKeyGeneration(t *testing.T) { for _, size := range []int{2048, 3072} { t.Run(strconv.Itoa(size), func(t *testing.T) { - N, E, D, P, Q, Dp, Dq, Qinv, err := GenerateKeyRSA(size) + t.Parallel() + priv, pub := newRSAKey(t, size) + msg := []byte("hi!") + enc, err := EncryptRSAPKCS1(pub, msg) if err != nil { - t.Fatalf("GenerateKeyRSA(%d): %v", size, err) + t.Fatalf("EncryptPKCS1v15: %v", err) } - _, err = NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv) + dec, err := DecryptRSAPKCS1(priv, enc) if err != nil { - t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err) + t.Fatalf("DecryptPKCS1v15: %v", err) } - _, err = NewPublicKeyRSA(N, E) - if err != nil { - t.Fatalf("NewPublicKeyRSA(%d): %v", size, err) + if !bytes.Equal(dec, msg) { + t.Fatalf("got:%x want:%x", dec, msg) } }) } } + +func TestEncryptDecryptOAEP(t *testing.T) { + sha256 := NewSHA256() + msg := []byte("hi!") + label := []byte("ho!") + priv, pub := newRSAKey(t, 2048) + enc, err := EncryptRSAOAEP(sha256, pub, msg, label) + if err != nil { + t.Fatal(err) + } + dec, err := DecryptRSAOAEP(sha256, priv, enc, label) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(dec, msg) { + t.Errorf("got:%x want:%x", dec, msg) + } +} + +func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) { + sha256 := NewSHA256() + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + enc, err := EncryptRSAOAEP(sha256, pub, msg, []byte("ho!")) + if err != nil { + t.Fatal(err) + } + dec, err := DecryptRSAOAEP(sha256, priv, enc, []byte("wrong!")) + if err == nil { + t.Errorf("error expected") + } + if dec != nil { + t.Errorf("got:%x want: nil", dec) + } +} + +func TestEncryptDecryptNoPadding(t *testing.T) { + const bits = 2048 + var msg [bits / 8]byte + msg[0] = 1 + msg[255] = 1 + priv, pub := newRSAKey(t, bits) + enc, err := EncryptRSANoPadding(pub, msg[:]) + if err != nil { + t.Fatal(err) + } + dec, err := DecryptRSANoPadding(priv, enc) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(dec, msg[:]) { + t.Errorf("got:%x want:%x", dec, msg) + } +} diff --git a/internal/bcrypt/bcrypt_windows.go b/internal/bcrypt/bcrypt_windows.go index 87e3299..554bcb3 100644 --- a/internal/bcrypt/bcrypt_windows.go +++ b/internal/bcrypt/bcrypt_windows.go @@ -28,6 +28,7 @@ const ( CHAIN_MODE_GCM = "ChainingModeGCM" KEY_LENGTHS = "KeyLengths" OBJECT_LENGTH = "ObjectLength" + BLOCK_LENGTH = "BlockLength" ) const ( @@ -39,6 +40,16 @@ const ( USE_SYSTEM_PREFERRED_RNG = 0x00000002 ) +type EncryptFlags uint32 + +const ( + EncrytFlagsNone EncryptFlags = 0x0 + PAD_NONE EncryptFlags = 0x1 + PAD_PKCS1 EncryptFlags = 0x2 + PAD_OAEP EncryptFlags = 0x4 + PAD_PSS EncryptFlags = 0x8 +) + type AlgorithmProviderFlags uint32 const ( @@ -102,6 +113,13 @@ func NewAUTHENTICATED_CIPHER_MODE_INFO(nonce, additionalData, tag []byte) *AUTHE return &info } +// https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_oaep_padding_info +type OAEP_PADDING_INFO struct { + AlgId *uint16 + Label *byte + LabelSize uint32 +} + // https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_rsakey_blob type RSAKEY_BLOB struct { Magic KeyBlobMagicNumber @@ -137,5 +155,5 @@ type RSAKEY_BLOB struct { //sys ImportKeyPair (hAlgorithm ALG_HANDLE, hImportKey KEY_HANDLE, pszBlobType *uint16, phKey *KEY_HANDLE, pbInput []byte, dwFlags uint32) (s error) = bcrypt.BCryptImportKeyPair //sys ExportKey(hKey KEY_HANDLE, hExportKey KEY_HANDLE, pszBlobType *uint16, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptExportKey //sys DestroyKey(hKey KEY_HANDLE) (s error) = bcrypt.BCryptDestroyKey -//sys Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER_MODE_INFO, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptEncrypt -//sys Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER_MODE_INFO, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptDecrypt +//sys Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) = bcrypt.BCryptEncrypt +//sys Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) = bcrypt.BCryptDecrypt diff --git a/internal/bcrypt/zsyscall_windows.go b/internal/bcrypt/zsyscall_windows.go index 549339f..40ee49e 100644 --- a/internal/bcrypt/zsyscall_windows.go +++ b/internal/bcrypt/zsyscall_windows.go @@ -83,7 +83,7 @@ func CreateHash(hAlgorithm ALG_HANDLE, phHash *HASH_HANDLE, pbHashObject []byte, return } -func Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER_MODE_INFO, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) { +func Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) { var _p0 *byte if len(pbInput) > 0 { _p0 = &pbInput[0] @@ -96,7 +96,7 @@ func Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER if len(pbOutput) > 0 { _p2 = &pbOutput[0] } - r0, _, _ := syscall.Syscall12(procBCryptDecrypt.Addr(), 10, uintptr(hKey), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(unsafe.Pointer(pPaddingInfo)), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbIV)), uintptr(unsafe.Pointer(_p2)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0, 0) + r0, _, _ := syscall.Syscall12(procBCryptDecrypt.Addr(), 10, uintptr(hKey), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(pPaddingInfo), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbIV)), uintptr(unsafe.Pointer(_p2)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0, 0) if r0 != 0 { s = syscall.Errno(r0) } @@ -131,7 +131,7 @@ func DuplicateHash(hHash HASH_HANDLE, phNewHash *HASH_HANDLE, pbHashObject []byt return } -func Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER_MODE_INFO, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) { +func Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) { var _p0 *byte if len(pbInput) > 0 { _p0 = &pbInput[0] @@ -144,7 +144,7 @@ func Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER if len(pbOutput) > 0 { _p2 = &pbOutput[0] } - r0, _, _ := syscall.Syscall12(procBCryptEncrypt.Addr(), 10, uintptr(hKey), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(unsafe.Pointer(pPaddingInfo)), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbIV)), uintptr(unsafe.Pointer(_p2)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0, 0) + r0, _, _ := syscall.Syscall12(procBCryptEncrypt.Addr(), 10, uintptr(hKey), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(pPaddingInfo), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbIV)), uintptr(unsafe.Pointer(_p2)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0, 0) if r0 != 0 { s = syscall.Errno(r0) } From a677ed5e39ae55c2a140704c76e438d278a05a23 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 17 Mar 2022 15:51:53 +0100 Subject: [PATCH 3/9] implement RSA PKCS1v15 sign and verify --- cng/rsa.go | 67 ++++++++++++++++++++++++++++- cng/rsa_test.go | 45 +++++++++++++++++++ internal/bcrypt/bcrypt_windows.go | 24 +++++++---- internal/bcrypt/zsyscall_windows.go | 38 +++++++++++++++- 4 files changed, 163 insertions(+), 11 deletions(-) diff --git a/cng/rsa.go b/cng/rsa.go index 9f6e5df..fd14fcd 100644 --- a/cng/rsa.go +++ b/cng/rsa.go @@ -7,6 +7,7 @@ package cng import ( + "crypto" "errors" "hash" "math/big" @@ -205,7 +206,25 @@ func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) { return rsaCrypt(pub.pkey, nil, msg, bcrypt.PAD_NONE, true) } -func rsaCrypt(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, in []byte, flags bcrypt.EncryptFlags, encrypt bool) ([]byte, error) { +func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte, error) { + defer runtime.KeepAlive(priv) + info, err := newPKCS1_PADDING_INFO(h) + if err != nil { + return nil, err + } + return rsaSign(priv.pkey, unsafe.Pointer(&info), hashed, bcrypt.PAD_PKCS1) +} + +func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) error { + defer runtime.KeepAlive(pub) + info, err := newPKCS1_PADDING_INFO(h) + if err != nil { + return err + } + return rsaVerify(pub.pkey, unsafe.Pointer(&info), hashed, sig, bcrypt.PAD_PKCS1) +} + +func rsaCrypt(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, in []byte, flags bcrypt.PadMode, encrypt bool) ([]byte, error) { var size uint32 var err error if encrypt { @@ -242,3 +261,49 @@ func rsaOAEP(h hash.Hash, pkey bcrypt.KEY_HANDLE, in, label []byte, encrypt bool } return rsaCrypt(pkey, unsafe.Pointer(&info), in, bcrypt.PAD_OAEP, encrypt) } + +func rsaSign(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, hashed []byte, flags bcrypt.PadMode) ([]byte, error) { + var size uint32 + err := bcrypt.SignHash(pkey, info, hashed, nil, &size, flags) + if err != nil { + return nil, err + } + out := make([]byte, size) + err = bcrypt.SignHash(pkey, info, hashed, out, &size, flags) + if err != nil { + return nil, err + } + return out[:size], nil +} + +func rsaVerify(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, hashed, sig []byte, flags bcrypt.PadMode) error { + return bcrypt.VerifySignature(pkey, info, hashed, sig, flags) +} + +func newPKCS1_PADDING_INFO(h crypto.Hash) (info bcrypt.PKCS1_PADDING_INFO, err error) { + if h != 0 { + hashID := cryptoHashToID(h) + if hashID == "" { + err = errors.New("crypto/rsa: unsupported hash function") + } else { + info.AlgId = utf16PtrFromString(hashID) + } + } + return +} + +func cryptoHashToID(ch crypto.Hash) string { + switch ch { + case crypto.MD5: + return bcrypt.MD5_ALGORITHM + case crypto.SHA1: + return bcrypt.SHA1_ALGORITHM + case crypto.SHA256: + return bcrypt.SHA256_ALGORITHM + case crypto.SHA384: + return bcrypt.SHA384_ALGORITHM + case crypto.SHA512: + return bcrypt.SHA512_ALGORITHM + } + return "" +} diff --git a/cng/rsa_test.go b/cng/rsa_test.go index de6432c..5c2c26a 100644 --- a/cng/rsa_test.go +++ b/cng/rsa_test.go @@ -8,6 +8,7 @@ package cng import ( "bytes" + "crypto" "strconv" "testing" ) @@ -103,3 +104,47 @@ func TestEncryptDecryptNoPadding(t *testing.T) { t.Errorf("got:%x want:%x", dec, msg) } } + +func TestSignVerifyPKCS1v15(t *testing.T) { + sha256 := NewSHA256() + priv, pub := newRSAKey(t, 2048) + sha256.Write([]byte("hi!")) + hashed := sha256.Sum(nil) + signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPKCS1v15(pub, crypto.SHA256, hashed, signed) + if err != nil { + t.Fatal(err) + } +} + +func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) { + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + signed, err := SignRSAPKCS1v15(priv, 0, msg) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPKCS1v15(pub, 0, msg, signed) + if err != nil { + t.Fatal(err) + } +} + +func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { + sha256 := NewSHA256() + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + sha256.Write(msg) + hashed := sha256.Sum(nil) + signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPKCS1v15(pub, crypto.SHA256, msg, signed) + if err == nil { + t.Fatal("error expected") + } +} diff --git a/internal/bcrypt/bcrypt_windows.go b/internal/bcrypt/bcrypt_windows.go index 554bcb3..6f462d8 100644 --- a/internal/bcrypt/bcrypt_windows.go +++ b/internal/bcrypt/bcrypt_windows.go @@ -17,6 +17,7 @@ const ( SHA512_ALGORITHM = "SHA512" AES_ALGORITHM = "AES" RSA_ALGORITHM = "RSA" + MD5_ALGORITHM = "MD5" ) const ( @@ -40,14 +41,14 @@ const ( USE_SYSTEM_PREFERRED_RNG = 0x00000002 ) -type EncryptFlags uint32 +type PadMode uint32 const ( - EncrytFlagsNone EncryptFlags = 0x0 - PAD_NONE EncryptFlags = 0x1 - PAD_PKCS1 EncryptFlags = 0x2 - PAD_OAEP EncryptFlags = 0x4 - PAD_PSS EncryptFlags = 0x8 + PAD_UNDEFINED PadMode = 0x0 + PAD_NONE PadMode = 0x1 + PAD_PKCS1 PadMode = 0x2 + PAD_OAEP PadMode = 0x4 + PAD_PSS PadMode = 0x8 ) type AlgorithmProviderFlags uint32 @@ -120,6 +121,11 @@ type OAEP_PADDING_INFO struct { LabelSize uint32 } +// https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_pkcs1_padding_info +type PKCS1_PADDING_INFO struct { + AlgId *uint16 +} + // https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_rsakey_blob type RSAKEY_BLOB struct { Magic KeyBlobMagicNumber @@ -155,5 +161,7 @@ type RSAKEY_BLOB struct { //sys ImportKeyPair (hAlgorithm ALG_HANDLE, hImportKey KEY_HANDLE, pszBlobType *uint16, phKey *KEY_HANDLE, pbInput []byte, dwFlags uint32) (s error) = bcrypt.BCryptImportKeyPair //sys ExportKey(hKey KEY_HANDLE, hExportKey KEY_HANDLE, pszBlobType *uint16, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptExportKey //sys DestroyKey(hKey KEY_HANDLE) (s error) = bcrypt.BCryptDestroyKey -//sys Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) = bcrypt.BCryptEncrypt -//sys Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) = bcrypt.BCryptDecrypt +//sys Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) = bcrypt.BCryptEncrypt +//sys Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) = bcrypt.BCryptDecrypt +//sys SignHash (hKey KEY_HANDLE, pPaddingInfo unsafe.Pointer, pbInput []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) = bcrypt.BCryptSignHash +//sys VerifySignature(hKey KEY_HANDLE, pPaddingInfo unsafe.Pointer, pbHash []byte, pbSignature []byte, dwFlags PadMode) (s error) = bcrypt.BCryptVerifySignature diff --git a/internal/bcrypt/zsyscall_windows.go b/internal/bcrypt/zsyscall_windows.go index 40ee49e..5fe28c5 100644 --- a/internal/bcrypt/zsyscall_windows.go +++ b/internal/bcrypt/zsyscall_windows.go @@ -57,6 +57,8 @@ var ( procBCryptImportKeyPair = modbcrypt.NewProc("BCryptImportKeyPair") procBCryptOpenAlgorithmProvider = modbcrypt.NewProc("BCryptOpenAlgorithmProvider") procBCryptSetProperty = modbcrypt.NewProc("BCryptSetProperty") + procBCryptSignHash = modbcrypt.NewProc("BCryptSignHash") + procBCryptVerifySignature = modbcrypt.NewProc("BCryptVerifySignature") ) func CloseAlgorithmProvider(hAlgorithm ALG_HANDLE, dwFlags uint32) (s error) { @@ -83,7 +85,7 @@ func CreateHash(hAlgorithm ALG_HANDLE, phHash *HASH_HANDLE, pbHashObject []byte, return } -func Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) { +func Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) { var _p0 *byte if len(pbInput) > 0 { _p0 = &pbInput[0] @@ -131,7 +133,7 @@ func DuplicateHash(hHash HASH_HANDLE, phNewHash *HASH_HANDLE, pbHashObject []byt return } -func Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) { +func Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) { var _p0 *byte if len(pbInput) > 0 { _p0 = &pbInput[0] @@ -274,3 +276,35 @@ func SetProperty(hObject HANDLE, pszProperty *uint16, pbInput []byte, dwFlags ui } return } + +func SignHash(hKey KEY_HANDLE, pPaddingInfo unsafe.Pointer, pbInput []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) { + var _p0 *byte + if len(pbInput) > 0 { + _p0 = &pbInput[0] + } + var _p1 *byte + if len(pbOutput) > 0 { + _p1 = &pbOutput[0] + } + r0, _, _ := syscall.Syscall9(procBCryptSignHash.Addr(), 8, uintptr(hKey), uintptr(pPaddingInfo), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +} + +func VerifySignature(hKey KEY_HANDLE, pPaddingInfo unsafe.Pointer, pbHash []byte, pbSignature []byte, dwFlags PadMode) (s error) { + var _p0 *byte + if len(pbHash) > 0 { + _p0 = &pbHash[0] + } + var _p1 *byte + if len(pbSignature) > 0 { + _p1 = &pbSignature[0] + } + r0, _, _ := syscall.Syscall9(procBCryptVerifySignature.Addr(), 7, uintptr(hKey), uintptr(pPaddingInfo), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbHash)), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbSignature)), uintptr(dwFlags), 0, 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +} From 4cc8ca89a72ee21b40e83b08df69315088fb7ab9 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 17 Mar 2022 16:11:14 +0100 Subject: [PATCH 4/9] implement RSA PSS sign and verify --- cng/rsa.go | 28 ++++++++++++++++++++++++++++ cng/rsa_test.go | 15 +++++++++++++++ internal/bcrypt/bcrypt_windows.go | 6 ++++++ 3 files changed, 49 insertions(+) diff --git a/cng/rsa.go b/cng/rsa.go index fd14fcd..b958b0c 100644 --- a/cng/rsa.go +++ b/cng/rsa.go @@ -206,6 +206,24 @@ func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) { return rsaCrypt(pub.pkey, nil, msg, bcrypt.PAD_NONE, true) } +func SignRSAPSS(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte, saltLen int) ([]byte, error) { + defer runtime.KeepAlive(priv) + info, err := newPSS_PADDING_INFO(h, saltLen) + if err != nil { + return nil, err + } + return rsaSign(priv.pkey, unsafe.Pointer(&info), hashed, bcrypt.PAD_PSS) +} + +func VerifyRSAPSS(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte, saltLen int) error { + defer runtime.KeepAlive(pub) + info, err := newPSS_PADDING_INFO(h, saltLen) + if err != nil { + return err + } + return rsaVerify(pub.pkey, unsafe.Pointer(&info), hashed, sig, bcrypt.PAD_PSS) +} + func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte, error) { defer runtime.KeepAlive(priv) info, err := newPKCS1_PADDING_INFO(h) @@ -280,6 +298,16 @@ func rsaVerify(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, hashed, sig []byte, return bcrypt.VerifySignature(pkey, info, hashed, sig, flags) } +func newPSS_PADDING_INFO(h crypto.Hash, saltLen int) (info bcrypt.PSS_PADDING_INFO, err error) { + hashID := cryptoHashToID(h) + if hashID == "" { + return info, errors.New("crypto/rsa: unsupported hash function") + } + info.AlgId = utf16PtrFromString(hashID) + info.Salt = uint32(saltLen) + return +} + func newPKCS1_PADDING_INFO(h crypto.Hash) (info bcrypt.PKCS1_PADDING_INFO, err error) { if h != 0 { hashID := cryptoHashToID(h) diff --git a/cng/rsa_test.go b/cng/rsa_test.go index 5c2c26a..888dc5b 100644 --- a/cng/rsa_test.go +++ b/cng/rsa_test.go @@ -148,3 +148,18 @@ func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { t.Fatal("error expected") } } + +func TestSignVerifyRSAPSS(t *testing.T) { + sha256 := NewSHA256() + priv, pub := newRSAKey(t, 2048) + sha256.Write([]byte("testing")) + hashed := sha256.Sum(nil) + signed, err := SignRSAPSS(priv, crypto.SHA256, hashed, 0) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPSS(pub, crypto.SHA256, hashed, signed, 0) + if err != nil { + t.Fatal(err) + } +} diff --git a/internal/bcrypt/bcrypt_windows.go b/internal/bcrypt/bcrypt_windows.go index 6f462d8..8ace412 100644 --- a/internal/bcrypt/bcrypt_windows.go +++ b/internal/bcrypt/bcrypt_windows.go @@ -126,6 +126,12 @@ type PKCS1_PADDING_INFO struct { AlgId *uint16 } +// https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_pss_padding_info +type PSS_PADDING_INFO struct { + AlgId *uint16 + Salt uint32 +} + // https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_rsakey_blob type RSAKEY_BLOB struct { Magic KeyBlobMagicNumber From 0b373add74394980b9f00da5b8c4b3748823e434 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 17 Mar 2022 18:21:26 +0100 Subject: [PATCH 5/9] add rsa benchmark --- cng/rsa_test.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/cng/rsa_test.go b/cng/rsa_test.go index 888dc5b..dea2a14 100644 --- a/cng/rsa_test.go +++ b/cng/rsa_test.go @@ -9,6 +9,7 @@ package cng import ( "bytes" "crypto" + "math/big" "strconv" "testing" ) @@ -163,3 +164,28 @@ func TestSignVerifyRSAPSS(t *testing.T) { t.Fatal(err) } } + +func fromBase36(base36 string) *big.Int { + i, ok := new(big.Int).SetString(base36, 36) + if !ok { + panic("bad number: " + base36) + } + return i +} + +func BenchmarkEncryptRSAPKCS1(b *testing.B) { + b.StopTimer() + // Public key length should be at least of 2048 bits, else OpenSSL will report an error when running in FIPS mode. + n := fromBase36("14314132931241006650998084889274020608918049032671858325988396851334124245188214251956198731333464217832226406088020736932173064754214329009979944037640912127943488972644697423190955557435910767690712778463524983667852819010259499695177313115447116110358524558307947613422897787329221478860907963827160223559690523660574329011927531289655711860504630573766609239332569210831325633840174683944553667352219670930408593321661375473885147973879086994006440025257225431977751512374815915392249179976902953721486040787792801849818254465486633791826766873076617116727073077821584676715609985777563958286637185868165868520557") + test2048PubKey, err := NewPublicKeyRSA(n, big.NewInt(3)) + if err != nil { + b.Fatal(err) + } + b.StartTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if _, err := EncryptRSAPKCS1(test2048PubKey, []byte("testing")); err != nil { + b.Fatal(err) + } + } +} From 677a924d90d2bd1ac54350da2a01aab80fdbfca2 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 17 Mar 2022 18:40:22 +0100 Subject: [PATCH 6/9] consolidate algorithms cache --- cng/aes.go | 72 ++++++++++++++++++++---------------------------------- cng/cng.go | 34 +++++++++++++++++++++++--- cng/rsa.go | 23 +++++++---------- cng/sha.go | 48 ++++++++++++------------------------ 4 files changed, 82 insertions(+), 95 deletions(-) diff --git a/cng/aes.go b/cng/aes.go index 62f8919..73b12f6 100644 --- a/cng/aes.go +++ b/cng/aes.go @@ -10,7 +10,6 @@ import ( "crypto/cipher" "errors" "runtime" - "sync" "unsafe" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" @@ -19,57 +18,38 @@ import ( const aesBlockSize = 16 -var aesCache sync.Map - type aesAlgorithm struct { h bcrypt.ALG_HANDLE allowedKeySizes []int } -type aesCacheEntry struct { - id string - mode string -} - -func loadAes(id string, mode string) (h aesAlgorithm, err error) { - if v, ok := aesCache.Load(aesCacheEntry{id, mode}); ok { - return v.(aesAlgorithm), nil - } - err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, bcrypt.ALG_NONE_FLAG) - if err != nil { - return - } - defer func() { +func loadAes(mode string) (aesAlgorithm, error) { + v, err := loadOrStoreAlg(bcrypt.AES_ALGORITHM, bcrypt.ALG_NONE_FLAG, mode, func(h bcrypt.ALG_HANDLE) (interface{}, error) { + // Windows 8 added support to set the CipherMode value on a key, + // but Windows 7 requires that it be set on the algorithm before key creation. + err := setString(bcrypt.HANDLE(h), bcrypt.CHAINING_MODE, mode) if err != nil { - bcrypt.CloseAlgorithmProvider(h.h, 0) - h.h = 0 + return nil, err } - }() - // Windows 8 added support to set the CipherMode value on a key, - // but Windows 7 requires that it be set on the algorithm before key creation. - err = setString(bcrypt.HANDLE(h.h), bcrypt.CHAINING_MODE, mode) - if err != nil { - return - } - var info bcrypt.KEY_LENGTHS_STRUCT - var discard uint32 - err = bcrypt.GetProperty(bcrypt.HANDLE(h.h), utf16PtrFromString(bcrypt.KEY_LENGTHS), (*[unsafe.Sizeof(info)]byte)(unsafe.Pointer(&info))[:], &discard, 0) + var info bcrypt.KEY_LENGTHS_STRUCT + var discard uint32 + err = bcrypt.GetProperty(bcrypt.HANDLE(h), utf16PtrFromString(bcrypt.KEY_LENGTHS), (*[unsafe.Sizeof(info)]byte)(unsafe.Pointer(&info))[:], &discard, 0) + if err != nil { + return nil, err + } + if info.Increment == 0 || info.MinLength > info.MaxLength { + return nil, errors.New("invalid BCRYPT_KEY_LENGTHS_STRUCT") + } + var allowedKeySized []int + for size := info.MinLength; size <= info.MaxLength; size += info.Increment { + allowedKeySized = append(allowedKeySized, int(size)) + } + return aesAlgorithm{h, allowedKeySized}, nil + }) if err != nil { - return + return aesAlgorithm{}, nil } - if info.Increment == 0 || info.MinLength > info.MaxLength { - err = errors.New("invalid BCRYPT_KEY_LENGTHS_STRUCT") - return - } - for size := info.MinLength; size <= info.MaxLength; size += info.Increment { - h.allowedKeySizes = append(h.allowedKeySizes, int(size)) - } - if existing, loaded := aesCache.LoadOrStore(aesCacheEntry{id, mode}, h); loaded { - // We can safely use a provider that has already been cached in another concurrent goroutine. - bcrypt.CloseAlgorithmProvider(h.h, 0) - h = existing.(aesAlgorithm) - } - return + return v.(aesAlgorithm), nil } type aesCipher struct { @@ -78,7 +58,7 @@ type aesCipher struct { } func NewAESCipher(key []byte) (cipher.Block, error) { - h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_ECB) + h, err := loadAes(bcrypt.CHAIN_MODE_ECB) if err != nil { return nil, err } @@ -194,7 +174,7 @@ type aesCBC struct { } func newCBC(encrypt bool, key, iv []byte) *aesCBC { - h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_CBC) + h, err := loadAes(bcrypt.CHAIN_MODE_CBC) if err != nil { panic(err) } @@ -268,7 +248,7 @@ func (g *aesGCM) finalize() { } func newGCM(key []byte, tls bool) (*aesGCM, error) { - h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_GCM) + h, err := loadAes(bcrypt.CHAIN_MODE_GCM) if err != nil { return nil, err } diff --git a/cng/cng.go b/cng/cng.go index 7a54ab0..a30fdc9 100644 --- a/cng/cng.go +++ b/cng/cng.go @@ -10,6 +10,7 @@ import ( "math" "reflect" "runtime" + "sync" "syscall" "unsafe" @@ -25,9 +26,36 @@ func lenU32(s []byte) int { return len(s) } -type algCacheEntry struct { - id string - flags uint32 +var algCache sync.Map + +type newAlgEntryFn func(h bcrypt.ALG_HANDLE) (interface{}, error) + +func loadOrStoreAlg(id string, flags bcrypt.AlgorithmProviderFlags, mode string, fn newAlgEntryFn) (interface{}, error) { + var entryKey = struct { + id string + flags bcrypt.AlgorithmProviderFlags + mode string + }{id, flags, mode} + + if v, ok := algCache.Load(entryKey); ok { + return v, nil + } + var h bcrypt.ALG_HANDLE + err := bcrypt.OpenAlgorithmProvider(&h, utf16PtrFromString(id), nil, flags) + if err != nil { + return nil, err + } + v, err := fn(h) + if err != nil { + bcrypt.CloseAlgorithmProvider(h, 0) + return nil, err + } + if existing, loaded := algCache.LoadOrStore(entryKey, v); loaded { + // We can safely use a provider that has already been cached in another concurrent goroutine. + bcrypt.CloseAlgorithmProvider(h, 0) + v = existing + } + return v, nil } func utf16PtrFromString(s string) *uint16 { diff --git a/cng/rsa.go b/cng/rsa.go index b958b0c..15fe68a 100644 --- a/cng/rsa.go +++ b/cng/rsa.go @@ -12,28 +12,23 @@ import ( "hash" "math/big" "runtime" - "sync" "unsafe" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) -var rsaCache sync.Map - type rsaAlgorithm struct { h bcrypt.ALG_HANDLE } -func loadRsa(id string, flags bcrypt.AlgorithmProviderFlags) (h rsaAlgorithm, err error) { - if v, ok := rsaCache.Load(algCacheEntry{id, uint32(flags)}); ok { - return v.(rsaAlgorithm), nil - } - err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags) +func loadRsa() (rsaAlgorithm, error) { + v, err := loadOrStoreAlg(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { + return rsaAlgorithm{h}, nil + }) if err != nil { - return + return rsaAlgorithm{}, nil } - rsaCache.Store(algCacheEntry{id, uint32(flags)}, h) - return + return v.(rsaAlgorithm), nil } const sizeOfRSABlobHeader = uint32(unsafe.Sizeof(bcrypt.RSAKEY_BLOB{})) @@ -42,7 +37,7 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) { return nil, nil, nil, nil, nil, nil, nil, nil, e } - h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG) + h, err := loadRsa() if err != nil { return bad(err) } @@ -95,7 +90,7 @@ type PublicKeyRSA struct { } func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) { - h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG) + h, err := loadRsa() if err != nil { return nil, err } @@ -122,7 +117,7 @@ func (k *PrivateKeyRSA) finalize() { } func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, error) { - h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG) + h, err := loadRsa() if err != nil { return nil, err } diff --git a/cng/sha.go b/cng/sha.go index dac8696..0cd8bee 100644 --- a/cng/sha.go +++ b/cng/sha.go @@ -9,7 +9,6 @@ package cng import ( "hash" "runtime" - "sync" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) @@ -105,8 +104,6 @@ func NewSHA512() hash.Hash { return newSHAX(bcrypt.SHA512_ALGORITHM, nil) } -var shaCache sync.Map - type shaAlgorithm struct { h bcrypt.ALG_HANDLE size uint32 @@ -114,39 +111,26 @@ type shaAlgorithm struct { objectLength uint32 } -func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (h shaAlgorithm, err error) { - if v, ok := shaCache.Load(algCacheEntry{id, uint32(flags)}); ok { - return v.(shaAlgorithm), nil - } - err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags) - if err != nil { - return - } - defer func() { +func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (shaAlgorithm, error) { + v, err := loadOrStoreAlg(id, flags, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { + size, err := getUint32(bcrypt.HANDLE(h), bcrypt.HASH_LENGTH) if err != nil { - bcrypt.CloseAlgorithmProvider(h.h, 0) - h.h = 0 + return nil, err } - }() - h.size, err = getUint32(bcrypt.HANDLE(h.h), bcrypt.HASH_LENGTH) - if err != nil { - return - } - h.blockSize, err = getUint32(bcrypt.HANDLE(h.h), bcrypt.HASH_BLOCK_LENGTH) - if err != nil { - return - } - h.objectLength, err = getUint32(bcrypt.HANDLE(h.h), bcrypt.OBJECT_LENGTH) + blockSize, err := getUint32(bcrypt.HANDLE(h), bcrypt.HASH_BLOCK_LENGTH) + if err != nil { + return nil, err + } + objectLength, err := getUint32(bcrypt.HANDLE(h), bcrypt.OBJECT_LENGTH) + if err != nil { + return nil, err + } + return shaAlgorithm{h, size, blockSize, objectLength}, nil + }) if err != nil { - bcrypt.CloseAlgorithmProvider(h.h, 0) - return - } - if existing, loaded := shaCache.LoadOrStore(algCacheEntry{id, uint32(flags)}, h); loaded { - // We can safely use a provider that has already been cached in another concurrent goroutine. - bcrypt.CloseAlgorithmProvider(h.h, 0) - h = existing.(shaAlgorithm) + return shaAlgorithm{}, err } - return + return v.(shaAlgorithm), nil } type shaXHash struct { From f18c222480f8c3260393776cc60ccb86d7e6e743 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Tue, 28 Jun 2022 17:51:46 +0200 Subject: [PATCH 7/9] pr feedback --- cng/aes.go | 6 +++--- cng/rsa.go | 26 +++++++++++++++----------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/cng/aes.go b/cng/aes.go index 73b12f6..1f0f57a 100644 --- a/cng/aes.go +++ b/cng/aes.go @@ -40,11 +40,11 @@ func loadAes(mode string) (aesAlgorithm, error) { if info.Increment == 0 || info.MinLength > info.MaxLength { return nil, errors.New("invalid BCRYPT_KEY_LENGTHS_STRUCT") } - var allowedKeySized []int + var allowedKeySizes []int for size := info.MinLength; size <= info.MaxLength; size += info.Increment { - allowedKeySized = append(allowedKeySized, int(size)) + allowedKeySizes = append(allowedKeySizes, int(size)) } - return aesAlgorithm{h, allowedKeySized}, nil + return aesAlgorithm{h, allowedKeySizes}, nil }) if err != nil { return aesAlgorithm{}, nil diff --git a/cng/rsa.go b/cng/rsa.go index 15fe68a..87ffb6d 100644 --- a/cng/rsa.go +++ b/cng/rsa.go @@ -59,6 +59,10 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) return bad(err) } + if size < sizeOfRSABlobHeader { + return bad(errors.New("crypto/rsa: exported key is corrupted")) + } + blob := make([]byte, size) err = bcrypt.ExportKey(hkey, 0, utf16PtrFromString(bcrypt.RSAFULLPRIVATE_BLOB), blob, &size, 0) if err != nil { @@ -66,22 +70,22 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) } hdr := (*(*bcrypt.RSAKEY_BLOB)(unsafe.Pointer(&blob[0]))) if hdr.Magic != bcrypt.RSAFULLPRIVATE_MAGIC || hdr.BitLength != uint32(bits) { - panic("crypto/rsa: exported key is corrupted") + return bad(errors.New("crypto/rsa: exported key is corrupted")) } data := blob[sizeOfRSABlobHeader:] - newInt := func(size uint32) *big.Int { + consumeBigInt := func(size uint32) *big.Int { b := new(big.Int).SetBytes(data[:size]) data = data[size:] return b } - E = newInt(hdr.PublicExpSize) - N = newInt(hdr.ModulusSize) - P = newInt(hdr.Prime1Size) - Q = newInt(hdr.Prime2Size) - Dp = newInt(hdr.Prime1Size) - Dq = newInt(hdr.Prime2Size) - Qinv = newInt(hdr.Prime1Size) - D = newInt(hdr.ModulusSize) + E = consumeBigInt(hdr.PublicExpSize) + N = consumeBigInt(hdr.ModulusSize) + P = consumeBigInt(hdr.Prime1Size) + Q = consumeBigInt(hdr.Prime2Size) + Dp = consumeBigInt(hdr.Prime1Size) + Dq = consumeBigInt(hdr.Prime2Size) + Qinv = consumeBigInt(hdr.Prime1Size) + D = consumeBigInt(hdr.ModulusSize) return } @@ -151,7 +155,7 @@ func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) []byte { hdr.Prime2Size = bigIntBytesLen(Q) blob = make([]byte, sizeOfRSABlobHeader+hdr.PublicExpSize+hdr.ModulusSize*2+hdr.Prime1Size*3+hdr.Prime2Size*2) } - copy(blob[:sizeOfRSABlobHeader], (*(*[1<<31 - 1]byte)(unsafe.Pointer(&hdr)))[:sizeOfRSABlobHeader]) + copy(blob, (*(*[sizeOfRSABlobHeader]byte)(unsafe.Pointer(&hdr)))[:]) data := blob[sizeOfRSABlobHeader:] encode := func(b *big.Int, size uint32) { b.FillBytes(data[:size]) From 800268c517c940103a71d57d53cce3abd3e72cd1 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 29 Jun 2022 10:09:14 +0200 Subject: [PATCH 8/9] remove math/big dependency --- cng/bbig/big.go | 31 ++++++++++++++++++++++ cng/big.go | 13 +++++++++ cng/rsa.go | 34 +++++++++++------------- cng/rsa_test.go | 70 +++++++++++++++++++++++++++++-------------------- 4 files changed, 100 insertions(+), 48 deletions(-) create mode 100644 cng/bbig/big.go create mode 100644 cng/big.go diff --git a/cng/bbig/big.go b/cng/bbig/big.go new file mode 100644 index 0000000..584f206 --- /dev/null +++ b/cng/bbig/big.go @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +package bbig + +import ( + "math/big" + + "github.com/microsoft/go-crypto-winnative/cng" +) + +func Enc(b *big.Int) cng.BigInt { + if b == nil { + return nil + } + x := b.Bytes() + if len(x) == 0 { + return cng.BigInt{} + } + return x +} + +func Dec(b cng.BigInt) *big.Int { + if b == nil { + return nil + } + if len(b) == 0 { + return new(big.Int) + } + return new(big.Int).SetBytes(b) +} diff --git a/cng/big.go b/cng/big.go new file mode 100644 index 0000000..dfecd03 --- /dev/null +++ b/cng/big.go @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +package cng + +// This file does not have build constraints to +// facilitate using BigInt in Go crypto. +// Go crypto references BigInt unconditionally, +// even if it is not finally used. + +// A BigInt is the raw big-endian bytes from a BigInt. +// This definition allows us to avoid importing math/big. +// Conversion between BigInt and *big.Int is in cng/bbig. +type BigInt []byte diff --git a/cng/rsa.go b/cng/rsa.go index 87ffb6d..49c1f8d 100644 --- a/cng/rsa.go +++ b/cng/rsa.go @@ -10,7 +10,6 @@ import ( "crypto" "errors" "hash" - "math/big" "runtime" "unsafe" @@ -33,8 +32,8 @@ func loadRsa() (rsaAlgorithm, error) { const sizeOfRSABlobHeader = uint32(unsafe.Sizeof(bcrypt.RSAKEY_BLOB{})) -func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) { - bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) { +func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) { + bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) { return nil, nil, nil, nil, nil, nil, nil, nil, e } h, err := loadRsa() @@ -73,8 +72,9 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) return bad(errors.New("crypto/rsa: exported key is corrupted")) } data := blob[sizeOfRSABlobHeader:] - consumeBigInt := func(size uint32) *big.Int { - b := new(big.Int).SetBytes(data[:size]) + consumeBigInt := func(size uint32) BigInt { + b := make(BigInt, size) + copy(b, data) data = data[size:] return b } @@ -93,7 +93,7 @@ type PublicKeyRSA struct { pkey bcrypt.KEY_HANDLE } -func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) { +func NewPublicKeyRSA(N, E BigInt) (*PublicKeyRSA, error) { h, err := loadRsa() if err != nil { return nil, err @@ -120,7 +120,7 @@ func (k *PrivateKeyRSA) finalize() { bcrypt.DestroyKey(k.pkey) } -func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, error) { +func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv BigInt) (*PrivateKeyRSA, error) { h, err := loadRsa() if err != nil { return nil, err @@ -135,15 +135,11 @@ func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, err return k, nil } -func bigIntBytesLen(b *big.Int) uint32 { - return uint32(b.BitLen()+7) / 8 -} - -func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) []byte { +func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv BigInt) []byte { hdr := bcrypt.RSAKEY_BLOB{ - BitLength: uint32(N.BitLen()), - PublicExpSize: bigIntBytesLen(E), - ModulusSize: bigIntBytesLen(N), + BitLength: uint32(len(N) * 8), + PublicExpSize: uint32(len(E)), + ModulusSize: uint32(len(N)), } var blob []byte if D == nil { @@ -151,14 +147,14 @@ func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) []byte { blob = make([]byte, sizeOfRSABlobHeader+hdr.PublicExpSize+hdr.ModulusSize) } else { hdr.Magic = bcrypt.RSAFULLPRIVATE_MAGIC - hdr.Prime1Size = bigIntBytesLen(P) - hdr.Prime2Size = bigIntBytesLen(Q) + hdr.Prime1Size = uint32(len(P)) + hdr.Prime2Size = uint32(len(Q)) blob = make([]byte, sizeOfRSABlobHeader+hdr.PublicExpSize+hdr.ModulusSize*2+hdr.Prime1Size*3+hdr.Prime2Size*2) } copy(blob, (*(*[sizeOfRSABlobHeader]byte)(unsafe.Pointer(&hdr)))[:]) data := blob[sizeOfRSABlobHeader:] - encode := func(b *big.Int, size uint32) { - b.FillBytes(data[:size]) + encode := func(b BigInt, size uint32) { + copy(data, b) data = data[size:] } encode(E, hdr.PublicExpSize) diff --git a/cng/rsa_test.go b/cng/rsa_test.go index dea2a14..c384fac 100644 --- a/cng/rsa_test.go +++ b/cng/rsa_test.go @@ -4,7 +4,7 @@ //go:build windows // +build windows -package cng +package cng_test import ( "bytes" @@ -12,19 +12,22 @@ import ( "math/big" "strconv" "testing" + + "github.com/microsoft/go-crypto-winnative/cng" + "github.com/microsoft/go-crypto-winnative/cng/bbig" ) -func newRSAKey(t *testing.T, size int) (*PrivateKeyRSA, *PublicKeyRSA) { +func newRSAKey(t *testing.T, size int) (*cng.PrivateKeyRSA, *cng.PublicKeyRSA) { t.Helper() - N, E, D, P, Q, Dp, Dq, Qinv, err := GenerateKeyRSA(size) + N, E, D, P, Q, Dp, Dq, Qinv, err := cng.GenerateKeyRSA(size) if err != nil { t.Fatalf("GenerateKeyRSA(%d): %v", size, err) } - priv, err := NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv) + priv, err := cng.NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv) if err != nil { t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err) } - pub, err := NewPublicKeyRSA(N, E) + pub, err := cng.NewPublicKeyRSA(N, E) if err != nil { t.Fatalf("NewPublicKeyRSA(%d): %v", size, err) } @@ -34,14 +37,13 @@ func newRSAKey(t *testing.T, size int) (*PrivateKeyRSA, *PublicKeyRSA) { func TestRSAKeyGeneration(t *testing.T) { for _, size := range []int{2048, 3072} { t.Run(strconv.Itoa(size), func(t *testing.T) { - t.Parallel() priv, pub := newRSAKey(t, size) msg := []byte("hi!") - enc, err := EncryptRSAPKCS1(pub, msg) + enc, err := cng.EncryptRSAPKCS1(pub, msg) if err != nil { t.Fatalf("EncryptPKCS1v15: %v", err) } - dec, err := DecryptRSAPKCS1(priv, enc) + dec, err := cng.DecryptRSAPKCS1(priv, enc) if err != nil { t.Fatalf("DecryptPKCS1v15: %v", err) } @@ -53,15 +55,15 @@ func TestRSAKeyGeneration(t *testing.T) { } func TestEncryptDecryptOAEP(t *testing.T) { - sha256 := NewSHA256() + sha256 := cng.NewSHA256() msg := []byte("hi!") label := []byte("ho!") priv, pub := newRSAKey(t, 2048) - enc, err := EncryptRSAOAEP(sha256, pub, msg, label) + enc, err := cng.EncryptRSAOAEP(sha256, pub, msg, label) if err != nil { t.Fatal(err) } - dec, err := DecryptRSAOAEP(sha256, priv, enc, label) + dec, err := cng.DecryptRSAOAEP(sha256, priv, enc, label) if err != nil { t.Fatal(err) } @@ -71,14 +73,14 @@ func TestEncryptDecryptOAEP(t *testing.T) { } func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) { - sha256 := NewSHA256() + sha256 := cng.NewSHA256() msg := []byte("hi!") priv, pub := newRSAKey(t, 2048) - enc, err := EncryptRSAOAEP(sha256, pub, msg, []byte("ho!")) + enc, err := cng.EncryptRSAOAEP(sha256, pub, msg, []byte("ho!")) if err != nil { t.Fatal(err) } - dec, err := DecryptRSAOAEP(sha256, priv, enc, []byte("wrong!")) + dec, err := cng.DecryptRSAOAEP(sha256, priv, enc, []byte("wrong!")) if err == nil { t.Errorf("error expected") } @@ -93,11 +95,11 @@ func TestEncryptDecryptNoPadding(t *testing.T) { msg[0] = 1 msg[255] = 1 priv, pub := newRSAKey(t, bits) - enc, err := EncryptRSANoPadding(pub, msg[:]) + enc, err := cng.EncryptRSANoPadding(pub, msg[:]) if err != nil { t.Fatal(err) } - dec, err := DecryptRSANoPadding(priv, enc) + dec, err := cng.DecryptRSANoPadding(priv, enc) if err != nil { t.Fatal(err) } @@ -107,15 +109,15 @@ func TestEncryptDecryptNoPadding(t *testing.T) { } func TestSignVerifyPKCS1v15(t *testing.T) { - sha256 := NewSHA256() + sha256 := cng.NewSHA256() priv, pub := newRSAKey(t, 2048) sha256.Write([]byte("hi!")) hashed := sha256.Sum(nil) - signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + signed, err := cng.SignRSAPKCS1v15(priv, crypto.SHA256, hashed) if err != nil { t.Fatal(err) } - err = VerifyRSAPKCS1v15(pub, crypto.SHA256, hashed, signed) + err = cng.VerifyRSAPKCS1v15(pub, crypto.SHA256, hashed, signed) if err != nil { t.Fatal(err) } @@ -124,42 +126,42 @@ func TestSignVerifyPKCS1v15(t *testing.T) { func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) { msg := []byte("hi!") priv, pub := newRSAKey(t, 2048) - signed, err := SignRSAPKCS1v15(priv, 0, msg) + signed, err := cng.SignRSAPKCS1v15(priv, 0, msg) if err != nil { t.Fatal(err) } - err = VerifyRSAPKCS1v15(pub, 0, msg, signed) + err = cng.VerifyRSAPKCS1v15(pub, 0, msg, signed) if err != nil { t.Fatal(err) } } func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { - sha256 := NewSHA256() + sha256 := cng.NewSHA256() msg := []byte("hi!") priv, pub := newRSAKey(t, 2048) sha256.Write(msg) hashed := sha256.Sum(nil) - signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + signed, err := cng.SignRSAPKCS1v15(priv, crypto.SHA256, hashed) if err != nil { t.Fatal(err) } - err = VerifyRSAPKCS1v15(pub, crypto.SHA256, msg, signed) + err = cng.VerifyRSAPKCS1v15(pub, crypto.SHA256, msg, signed) if err == nil { t.Fatal("error expected") } } func TestSignVerifyRSAPSS(t *testing.T) { - sha256 := NewSHA256() + sha256 := cng.NewSHA256() priv, pub := newRSAKey(t, 2048) sha256.Write([]byte("testing")) hashed := sha256.Sum(nil) - signed, err := SignRSAPSS(priv, crypto.SHA256, hashed, 0) + signed, err := cng.SignRSAPSS(priv, crypto.SHA256, hashed, 0) if err != nil { t.Fatal(err) } - err = VerifyRSAPSS(pub, crypto.SHA256, hashed, signed, 0) + err = cng.VerifyRSAPSS(pub, crypto.SHA256, hashed, signed, 0) if err != nil { t.Fatal(err) } @@ -177,14 +179,24 @@ func BenchmarkEncryptRSAPKCS1(b *testing.B) { b.StopTimer() // Public key length should be at least of 2048 bits, else OpenSSL will report an error when running in FIPS mode. n := fromBase36("14314132931241006650998084889274020608918049032671858325988396851334124245188214251956198731333464217832226406088020736932173064754214329009979944037640912127943488972644697423190955557435910767690712778463524983667852819010259499695177313115447116110358524558307947613422897787329221478860907963827160223559690523660574329011927531289655711860504630573766609239332569210831325633840174683944553667352219670930408593321661375473885147973879086994006440025257225431977751512374815915392249179976902953721486040787792801849818254465486633791826766873076617116727073077821584676715609985777563958286637185868165868520557") - test2048PubKey, err := NewPublicKeyRSA(n, big.NewInt(3)) + test2048PubKey, err := cng.NewPublicKeyRSA(bbig.Enc(n), bbig.Enc(big.NewInt(3))) if err != nil { b.Fatal(err) } b.StartTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - if _, err := EncryptRSAPKCS1(test2048PubKey, []byte("testing")); err != nil { + if _, err := cng.EncryptRSAPKCS1(test2048PubKey, []byte("testing")); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkGenerateKeyRSA(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _, _, _, _, _, _, _, err := cng.GenerateKeyRSA(2048) + if err != nil { b.Fatal(err) } } From ce4cf37c4de956cf1424cd64ccba365b4576359d Mon Sep 17 00:00:00 2001 From: Quim Muntal Date: Thu, 30 Jun 2022 09:06:36 +0200 Subject: [PATCH 9/9] Update cng/big.go Co-authored-by: Davis Goodin --- cng/big.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cng/big.go b/cng/big.go index dfecd03..0069d31 100644 --- a/cng/big.go +++ b/cng/big.go @@ -7,7 +7,8 @@ package cng // Go crypto references BigInt unconditionally, // even if it is not finally used. -// A BigInt is the raw big-endian bytes from a BigInt. +// A BigInt is the big-endian bytes from a math/big BigInt. +// Windows BCrypt accepts this specific data format. // This definition allows us to avoid importing math/big. // Conversion between BigInt and *big.Int is in cng/bbig. type BigInt []byte