diff --git a/cng/rsa.go b/cng/rsa.go index 9f6e5df..fd14fcd 100644 --- a/cng/rsa.go +++ b/cng/rsa.go @@ -7,6 +7,7 @@ package cng import ( + "crypto" "errors" "hash" "math/big" @@ -205,7 +206,25 @@ func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) { return rsaCrypt(pub.pkey, nil, msg, bcrypt.PAD_NONE, true) } -func rsaCrypt(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, in []byte, flags bcrypt.EncryptFlags, encrypt bool) ([]byte, error) { +func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte, error) { + defer runtime.KeepAlive(priv) + info, err := newPKCS1_PADDING_INFO(h) + if err != nil { + return nil, err + } + return rsaSign(priv.pkey, unsafe.Pointer(&info), hashed, bcrypt.PAD_PKCS1) +} + +func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) error { + defer runtime.KeepAlive(pub) + info, err := newPKCS1_PADDING_INFO(h) + if err != nil { + return err + } + return rsaVerify(pub.pkey, unsafe.Pointer(&info), hashed, sig, bcrypt.PAD_PKCS1) +} + +func rsaCrypt(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, in []byte, flags bcrypt.PadMode, encrypt bool) ([]byte, error) { var size uint32 var err error if encrypt { @@ -242,3 +261,49 @@ func rsaOAEP(h hash.Hash, pkey bcrypt.KEY_HANDLE, in, label []byte, encrypt bool } return rsaCrypt(pkey, unsafe.Pointer(&info), in, bcrypt.PAD_OAEP, encrypt) } + +func rsaSign(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, hashed []byte, flags bcrypt.PadMode) ([]byte, error) { + var size uint32 + err := bcrypt.SignHash(pkey, info, hashed, nil, &size, flags) + if err != nil { + return nil, err + } + out := make([]byte, size) + err = bcrypt.SignHash(pkey, info, hashed, out, &size, flags) + if err != nil { + return nil, err + } + return out[:size], nil +} + +func rsaVerify(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, hashed, sig []byte, flags bcrypt.PadMode) error { + return bcrypt.VerifySignature(pkey, info, hashed, sig, flags) +} + +func newPKCS1_PADDING_INFO(h crypto.Hash) (info bcrypt.PKCS1_PADDING_INFO, err error) { + if h != 0 { + hashID := cryptoHashToID(h) + if hashID == "" { + err = errors.New("crypto/rsa: unsupported hash function") + } else { + info.AlgId = utf16PtrFromString(hashID) + } + } + return +} + +func cryptoHashToID(ch crypto.Hash) string { + switch ch { + case crypto.MD5: + return bcrypt.MD5_ALGORITHM + case crypto.SHA1: + return bcrypt.SHA1_ALGORITHM + case crypto.SHA256: + return bcrypt.SHA256_ALGORITHM + case crypto.SHA384: + return bcrypt.SHA384_ALGORITHM + case crypto.SHA512: + return bcrypt.SHA512_ALGORITHM + } + return "" +} diff --git a/cng/rsa_test.go b/cng/rsa_test.go index de6432c..5c2c26a 100644 --- a/cng/rsa_test.go +++ b/cng/rsa_test.go @@ -8,6 +8,7 @@ package cng import ( "bytes" + "crypto" "strconv" "testing" ) @@ -103,3 +104,47 @@ func TestEncryptDecryptNoPadding(t *testing.T) { t.Errorf("got:%x want:%x", dec, msg) } } + +func TestSignVerifyPKCS1v15(t *testing.T) { + sha256 := NewSHA256() + priv, pub := newRSAKey(t, 2048) + sha256.Write([]byte("hi!")) + hashed := sha256.Sum(nil) + signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPKCS1v15(pub, crypto.SHA256, hashed, signed) + if err != nil { + t.Fatal(err) + } +} + +func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) { + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + signed, err := SignRSAPKCS1v15(priv, 0, msg) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPKCS1v15(pub, 0, msg, signed) + if err != nil { + t.Fatal(err) + } +} + +func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { + sha256 := NewSHA256() + msg := []byte("hi!") + priv, pub := newRSAKey(t, 2048) + sha256.Write(msg) + hashed := sha256.Sum(nil) + signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPKCS1v15(pub, crypto.SHA256, msg, signed) + if err == nil { + t.Fatal("error expected") + } +} diff --git a/internal/bcrypt/bcrypt_windows.go b/internal/bcrypt/bcrypt_windows.go index 73e9481..60bcfb6 100644 --- a/internal/bcrypt/bcrypt_windows.go +++ b/internal/bcrypt/bcrypt_windows.go @@ -17,6 +17,7 @@ const ( SHA512_ALGORITHM = "SHA512" AES_ALGORITHM = "AES" RSA_ALGORITHM = "RSA" + MD5_ALGORITHM = "MD5" ) const ( @@ -39,14 +40,14 @@ const ( USE_SYSTEM_PREFERRED_RNG = 0x00000002 ) -type EncryptFlags uint32 +type PadMode uint32 const ( - EncrytFlagsNone EncryptFlags = 0x0 - PAD_NONE EncryptFlags = 0x1 - PAD_PKCS1 EncryptFlags = 0x2 - PAD_OAEP EncryptFlags = 0x4 - PAD_PSS EncryptFlags = 0x8 + PAD_UNDEFINED PadMode = 0x0 + PAD_NONE PadMode = 0x1 + PAD_PKCS1 PadMode = 0x2 + PAD_OAEP PadMode = 0x4 + PAD_PSS PadMode = 0x8 ) type AlgorithmProviderFlags uint32 @@ -119,6 +120,11 @@ type OAEP_PADDING_INFO struct { LabelSize uint32 } +// https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_pkcs1_padding_info +type PKCS1_PADDING_INFO struct { + AlgId *uint16 +} + // https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_rsakey_blob type RSAKEY_BLOB struct { Magic KeyBlobMagicNumber @@ -154,5 +160,7 @@ type RSAKEY_BLOB struct { //sys ImportKeyPair (hAlgorithm ALG_HANDLE, hImportKey KEY_HANDLE, pszBlobType *uint16, phKey *KEY_HANDLE, pbInput []byte, dwFlags uint32) (s error) = bcrypt.BCryptImportKeyPair //sys ExportKey(hKey KEY_HANDLE, hExportKey KEY_HANDLE, pszBlobType *uint16, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptExportKey //sys DestroyKey(hKey KEY_HANDLE) (s error) = bcrypt.BCryptDestroyKey -//sys Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) = bcrypt.BCryptEncrypt -//sys Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) = bcrypt.BCryptDecrypt +//sys Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) = bcrypt.BCryptEncrypt +//sys Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) = bcrypt.BCryptDecrypt +//sys SignHash (hKey KEY_HANDLE, pPaddingInfo unsafe.Pointer, pbInput []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) = bcrypt.BCryptSignHash +//sys VerifySignature(hKey KEY_HANDLE, pPaddingInfo unsafe.Pointer, pbHash []byte, pbSignature []byte, dwFlags PadMode) (s error) = bcrypt.BCryptVerifySignature diff --git a/internal/bcrypt/zsyscall_windows.go b/internal/bcrypt/zsyscall_windows.go index 40ee49e..5fe28c5 100644 --- a/internal/bcrypt/zsyscall_windows.go +++ b/internal/bcrypt/zsyscall_windows.go @@ -57,6 +57,8 @@ var ( procBCryptImportKeyPair = modbcrypt.NewProc("BCryptImportKeyPair") procBCryptOpenAlgorithmProvider = modbcrypt.NewProc("BCryptOpenAlgorithmProvider") procBCryptSetProperty = modbcrypt.NewProc("BCryptSetProperty") + procBCryptSignHash = modbcrypt.NewProc("BCryptSignHash") + procBCryptVerifySignature = modbcrypt.NewProc("BCryptVerifySignature") ) func CloseAlgorithmProvider(hAlgorithm ALG_HANDLE, dwFlags uint32) (s error) { @@ -83,7 +85,7 @@ func CreateHash(hAlgorithm ALG_HANDLE, phHash *HASH_HANDLE, pbHashObject []byte, return } -func Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) { +func Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) { var _p0 *byte if len(pbInput) > 0 { _p0 = &pbInput[0] @@ -131,7 +133,7 @@ func DuplicateHash(hHash HASH_HANDLE, phNewHash *HASH_HANDLE, pbHashObject []byt return } -func Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) { +func Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) { var _p0 *byte if len(pbInput) > 0 { _p0 = &pbInput[0] @@ -274,3 +276,35 @@ func SetProperty(hObject HANDLE, pszProperty *uint16, pbInput []byte, dwFlags ui } return } + +func SignHash(hKey KEY_HANDLE, pPaddingInfo unsafe.Pointer, pbInput []byte, pbOutput []byte, pcbResult *uint32, dwFlags PadMode) (s error) { + var _p0 *byte + if len(pbInput) > 0 { + _p0 = &pbInput[0] + } + var _p1 *byte + if len(pbOutput) > 0 { + _p1 = &pbOutput[0] + } + r0, _, _ := syscall.Syscall9(procBCryptSignHash.Addr(), 8, uintptr(hKey), uintptr(pPaddingInfo), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbInput)), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbOutput)), uintptr(unsafe.Pointer(pcbResult)), uintptr(dwFlags), 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +} + +func VerifySignature(hKey KEY_HANDLE, pPaddingInfo unsafe.Pointer, pbHash []byte, pbSignature []byte, dwFlags PadMode) (s error) { + var _p0 *byte + if len(pbHash) > 0 { + _p0 = &pbHash[0] + } + var _p1 *byte + if len(pbSignature) > 0 { + _p1 = &pbSignature[0] + } + r0, _, _ := syscall.Syscall9(procBCryptVerifySignature.Addr(), 7, uintptr(hKey), uintptr(pPaddingInfo), uintptr(unsafe.Pointer(_p0)), uintptr(len(pbHash)), uintptr(unsafe.Pointer(_p1)), uintptr(len(pbSignature)), uintptr(dwFlags), 0, 0) + if r0 != 0 { + s = syscall.Errno(r0) + } + return +}