Skip to content

Commit

Permalink
implement RSA encrypt and decrypt
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Mar 18, 2022
1 parent 17459e4 commit 90e2201
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 15 deletions.
4 changes: 2 additions & 2 deletions cng/aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ func (g *aesGCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte {

info := bcrypt.NewAUTHENTICATED_CIPHER_MODE_INFO(nonce, additionalData, out[len(out)-gcmTagSize:])
var encSize uint32
err := bcrypt.Encrypt(g.kh, plaintext, info, nil, out, &encSize, 0)
err := bcrypt.Encrypt(g.kh, plaintext, unsafe.Pointer(info), nil, out, &encSize, 0)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -345,7 +345,7 @@ func (g *aesGCM) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, er

info := bcrypt.NewAUTHENTICATED_CIPHER_MODE_INFO(nonce, additionalData, tag)
var decSize uint32
err := bcrypt.Decrypt(g.kh, ciphertext, info, nil, out, &decSize, 0)
err := bcrypt.Decrypt(g.kh, ciphertext, unsafe.Pointer(info), nil, out, &decSize, 0)
if err != nil || int(decSize) != len(ciphertext) {
for i := range out {
out[i] = 0
Expand Down
71 changes: 71 additions & 0 deletions cng/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
package cng

import (
"errors"
"hash"
"math/big"
"runtime"
"sync"
Expand Down Expand Up @@ -171,3 +173,72 @@ func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) []byte {
}
return blob
}

func DecryptRSAOAEP(h hash.Hash, priv *PrivateKeyRSA, ciphertext, label []byte) ([]byte, error) {
defer runtime.KeepAlive(priv)
return rsaOAEP(h, priv.pkey, ciphertext, label, false)
}

func EncryptRSAOAEP(h hash.Hash, pub *PublicKeyRSA, msg, label []byte) ([]byte, error) {
defer runtime.KeepAlive(pub)
return rsaOAEP(h, pub.pkey, msg, label, true)
}

func DecryptRSAPKCS1(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) {
defer runtime.KeepAlive(priv)
return rsaCrypt(priv.pkey, nil, ciphertext, bcrypt.PAD_PKCS1, false)
}

func EncryptRSAPKCS1(pub *PublicKeyRSA, msg []byte) ([]byte, error) {
defer runtime.KeepAlive(pub)
return rsaCrypt(pub.pkey, nil, msg, bcrypt.PAD_PKCS1, true)
}

func DecryptRSANoPadding(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) {
defer runtime.KeepAlive(priv)
return rsaCrypt(priv.pkey, nil, ciphertext, bcrypt.PAD_NONE, false)

This comment has been minimized.

Copy link
@jaredpar

jaredpar Jun 30, 2022

Member

Nit: extra blank line

}

func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) {
defer runtime.KeepAlive(pub)
return rsaCrypt(pub.pkey, nil, msg, bcrypt.PAD_NONE, true)
}

func rsaCrypt(pkey bcrypt.KEY_HANDLE, info unsafe.Pointer, in []byte, flags bcrypt.EncryptFlags, encrypt bool) ([]byte, error) {
var size uint32
var err error
if encrypt {
err = bcrypt.Encrypt(pkey, in, info, nil, nil, &size, flags)
} else {
err = bcrypt.Decrypt(pkey, in, info, nil, nil, &size, flags)
}
if err != nil {
return nil, err
}
out := make([]byte, size)
if encrypt {
err = bcrypt.Encrypt(pkey, in, info, nil, out, &size, flags)
} else {
err = bcrypt.Decrypt(pkey, in, info, nil, out, &size, flags)
}
if err != nil {
return nil, err
}
return out[:size], nil
}

func rsaOAEP(h hash.Hash, pkey bcrypt.KEY_HANDLE, in, label []byte, encrypt bool) ([]byte, error) {
hashID := hashToID(h)
if hashID == "" {
return nil, errors.New("crypto/rsa: unsupported hash function")
}
info := bcrypt.OAEP_PADDING_INFO{
AlgId: utf16PtrFromString(hashID),
LabelSize: uint32(len(label)),
}
if len(label) > 0 {
info.Label = &label[0]
}
return rsaCrypt(pkey, unsafe.Pointer(&info), in, bcrypt.PAD_OAEP, encrypt)
}
88 changes: 81 additions & 7 deletions cng/rsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,99 @@
package cng

import (
"bytes"
"strconv"
"testing"
)

func newRSAKey(t *testing.T, size int) (*PrivateKeyRSA, *PublicKeyRSA) {
t.Helper()
N, E, D, P, Q, Dp, Dq, Qinv, err := GenerateKeyRSA(size)
if err != nil {
t.Fatalf("GenerateKeyRSA(%d): %v", size, err)
}
priv, err := NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv)
if err != nil {
t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err)
}
pub, err := NewPublicKeyRSA(N, E)
if err != nil {
t.Fatalf("NewPublicKeyRSA(%d): %v", size, err)
}
return priv, pub
}

func TestRSAKeyGeneration(t *testing.T) {
for _, size := range []int{2048, 3072} {
t.Run(strconv.Itoa(size), func(t *testing.T) {
N, E, D, P, Q, Dp, Dq, Qinv, err := GenerateKeyRSA(size)
t.Parallel()
priv, pub := newRSAKey(t, size)
msg := []byte("hi!")
enc, err := EncryptRSAPKCS1(pub, msg)
if err != nil {
t.Fatalf("GenerateKeyRSA(%d): %v", size, err)
t.Fatalf("EncryptPKCS1v15: %v", err)
}
_, err = NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv)
dec, err := DecryptRSAPKCS1(priv, enc)
if err != nil {
t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err)
t.Fatalf("DecryptPKCS1v15: %v", err)
}
_, err = NewPublicKeyRSA(N, E)
if err != nil {
t.Fatalf("NewPublicKeyRSA(%d): %v", size, err)
if !bytes.Equal(dec, msg) {
t.Fatalf("got:%x want:%x", dec, msg)
}
})
}
}

func TestEncryptDecryptOAEP(t *testing.T) {
sha256 := NewSHA256()
msg := []byte("hi!")
label := []byte("ho!")
priv, pub := newRSAKey(t, 2048)
enc, err := EncryptRSAOAEP(sha256, pub, msg, label)
if err != nil {
t.Fatal(err)
}
dec, err := DecryptRSAOAEP(sha256, priv, enc, label)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(dec, msg) {
t.Errorf("got:%x want:%x", dec, msg)
}
}

func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) {
sha256 := NewSHA256()
msg := []byte("hi!")
priv, pub := newRSAKey(t, 2048)
enc, err := EncryptRSAOAEP(sha256, pub, msg, []byte("ho!"))
if err != nil {
t.Fatal(err)
}
dec, err := DecryptRSAOAEP(sha256, priv, enc, []byte("wrong!"))
if err == nil {
t.Errorf("error expected")
}
if dec != nil {
t.Errorf("got:%x want: nil", dec)
}
}

func TestEncryptDecryptNoPadding(t *testing.T) {
const bits = 2048
var msg [bits / 8]byte
msg[0] = 1
msg[255] = 1
priv, pub := newRSAKey(t, bits)
enc, err := EncryptRSANoPadding(pub, msg[:])
if err != nil {
t.Fatal(err)
}
dec, err := DecryptRSANoPadding(priv, enc)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(dec, msg[:]) {
t.Errorf("got:%x want:%x", dec, msg)
}
}
22 changes: 20 additions & 2 deletions internal/bcrypt/bcrypt_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const (
CHAIN_MODE_CBC = "ChainingModeCBC"
CHAIN_MODE_GCM = "ChainingModeGCM"
KEY_LENGTHS = "KeyLengths"
BLOCK_LENGTH = "BlockLength"
)

const (
Expand All @@ -38,6 +39,16 @@ const (
USE_SYSTEM_PREFERRED_RNG = 0x00000002
)

type EncryptFlags uint32

const (
EncrytFlagsNone EncryptFlags = 0x0
PAD_NONE EncryptFlags = 0x1
PAD_PKCS1 EncryptFlags = 0x2
PAD_OAEP EncryptFlags = 0x4
PAD_PSS EncryptFlags = 0x8
)

type AlgorithmProviderFlags uint32

const (
Expand Down Expand Up @@ -101,6 +112,13 @@ func NewAUTHENTICATED_CIPHER_MODE_INFO(nonce, additionalData, tag []byte) *AUTHE
return &info
}

// https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_oaep_padding_info
type OAEP_PADDING_INFO struct {
AlgId *uint16
Label *byte
LabelSize uint32
}

// https://docs.microsoft.com/en-us/windows/win32/api/bcrypt/ns-bcrypt-bcrypt_rsakey_blob
type RSAKEY_BLOB struct {
Magic KeyBlobMagicNumber
Expand Down Expand Up @@ -136,5 +154,5 @@ type RSAKEY_BLOB struct {
//sys ImportKeyPair (hAlgorithm ALG_HANDLE, hImportKey KEY_HANDLE, pszBlobType *uint16, phKey *KEY_HANDLE, pbInput []byte, dwFlags uint32) (s error) = bcrypt.BCryptImportKeyPair
//sys ExportKey(hKey KEY_HANDLE, hExportKey KEY_HANDLE, pszBlobType *uint16, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptExportKey
//sys DestroyKey(hKey KEY_HANDLE) (s error) = bcrypt.BCryptDestroyKey
//sys Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER_MODE_INFO, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptEncrypt
//sys Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo *AUTHENTICATED_CIPHER_MODE_INFO, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags uint32) (s error) = bcrypt.BCryptDecrypt
//sys Encrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) = bcrypt.BCryptEncrypt
//sys Decrypt(hKey KEY_HANDLE, pbInput []byte, pPaddingInfo unsafe.Pointer, pbIV []byte, pbOutput []byte, pcbResult *uint32, dwFlags EncryptFlags) (s error) = bcrypt.BCryptDecrypt
8 changes: 4 additions & 4 deletions internal/bcrypt/zsyscall_windows.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 90e2201

Please sign in to comment.