diff --git a/cng/rsa.go b/cng/rsa.go index fd14fcd..b958b0c 100644 --- a/cng/rsa.go +++ b/cng/rsa.go @@ -206,6 +206,24 @@ func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) { return rsaCrypt(pub.pkey, nil, msg, bcrypt.PAD_NONE, true) } +func SignRSAPSS(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte, saltLen int) ([]byte, error) { + defer runtime.KeepAlive(priv) + info, err := newPSS_PADDING_INFO(h, saltLen) + if err != nil { + return nil, err + } + return rsaSign(priv.pkey, unsafe.Pointer(&info), hashed, bcrypt.PAD_PSS) +} + +func VerifyRSAPSS(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte, saltLen int) error { + defer runtime.KeepAlive(pub) + info, err := newPSS_PADDING_INFO(h, saltLen) + if err != nil { + return err + } + return rsaVerify(pub.pkey, unsafe.Pointer(&info), hashed, sig, bcrypt.PAD_PSS) +} + func SignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, hashed []byte) ([]byte, error) { defer runtime.KeepAlive(priv) info, err := newPKCS1_PADDING_INFO(h) @@ -280,6 +298,16 @@ func rsaVerify(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, hashed, sig []byte, return bcrypt.VerifySignature(pkey, info, hashed, sig, flags) } +func newPSS_PADDING_INFO(h crypto.Hash, saltLen int) (info bcrypt.PSS_PADDING_INFO, err error) { + hashID := cryptoHashToID(h) + if hashID == "" { + return info, errors.New("crypto/rsa: unsupported hash function") + } + info.AlgId = utf16PtrFromString(hashID) + info.Salt = uint32(saltLen) + return +} + func newPKCS1_PADDING_INFO(h crypto.Hash) (info bcrypt.PKCS1_PADDING_INFO, err error) { if h != 0 { hashID := cryptoHashToID(h) diff --git a/cng/rsa_test.go b/cng/rsa_test.go index 5c2c26a..888dc5b 100644 --- a/cng/rsa_test.go +++ b/cng/rsa_test.go @@ -148,3 +148,18 @@ func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { t.Fatal("error expected") } } + +func TestSignVerifyRSAPSS(t *testing.T) { + sha256 := NewSHA256() + priv, pub := newRSAKey(t, 2048) + sha256.Write([]byte("testing")) + hashed := sha256.Sum(nil) + signed, err := SignRSAPSS(priv, crypto.SHA256, hashed, 0) + if err != nil { + t.Fatal(err) + } + err = VerifyRSAPSS(pub, crypto.SHA256, hashed, signed, 0) + if err != nil { + t.Fatal(err) + } +} diff --git a/internal/bcrypt/bcrypt_windows.go b/internal/bcrypt/bcrypt_windows.go index 60bcfb6..bfd9a69 100644 --- a/internal/bcrypt/bcrypt_windows.go +++ b/internal/bcrypt/bcrypt_windows.go @@ -125,6 +125,12 @@ type PKCS1_PADDING_INFO struct { AlgId *uint16 } +// https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_pss_padding_info +type PSS_PADDING_INFO struct { + AlgId *uint16 + Salt uint32 +} + // https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_rsakey_blob type RSAKEY_BLOB struct { Magic KeyBlobMagicNumber