diff --git a/cng/bbig/big.go b/cng/bbig/big.go new file mode 100644 index 0000000..584f206 --- /dev/null +++ b/cng/bbig/big.go @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +package bbig + +import ( + "math/big" + + "github.com/microsoft/go-crypto-winnative/cng" +) + +func Enc(b *big.Int) cng.BigInt { + if b == nil { + return nil + } + x := b.Bytes() + if len(x) == 0 { + return cng.BigInt{} + } + return x +} + +func Dec(b cng.BigInt) *big.Int { + if b == nil { + return nil + } + if len(b) == 0 { + return new(big.Int) + } + return new(big.Int).SetBytes(b) +} diff --git a/cng/big.go b/cng/big.go new file mode 100644 index 0000000..dfecd03 --- /dev/null +++ b/cng/big.go @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +package cng + +// This file does not have build constraints to +// facilitate using BigInt in Go crypto. +// Go crypto references BigInt unconditionally, +// even if it is not finally used. + +// A BigInt is the raw big-endian bytes from a BigInt. +// This definition allows us to avoid importing math/big. +// Conversion between BigInt and *big.Int is in cng/bbig. +type BigInt []byte diff --git a/cng/rsa.go b/cng/rsa.go index 87ffb6d..49c1f8d 100644 --- a/cng/rsa.go +++ b/cng/rsa.go @@ -10,7 +10,6 @@ import ( "crypto" "errors" "hash" - "math/big" "runtime" "unsafe" @@ -33,8 +32,8 @@ func loadRsa() (rsaAlgorithm, error) { const sizeOfRSABlobHeader = uint32(unsafe.Sizeof(bcrypt.RSAKEY_BLOB{})) -func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) { - bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) { +func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) { + bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) { return nil, nil, nil, nil, nil, nil, nil, nil, e } h, err := loadRsa() @@ -73,8 +72,9 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) return bad(errors.New("crypto/rsa: exported key is corrupted")) } data := blob[sizeOfRSABlobHeader:] - consumeBigInt := func(size uint32) *big.Int { - b := new(big.Int).SetBytes(data[:size]) + consumeBigInt := func(size uint32) BigInt { + b := make(BigInt, size) + copy(b, data) data = data[size:] return b } @@ -93,7 +93,7 @@ type PublicKeyRSA struct { pkey bcrypt.KEY_HANDLE } -func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) { +func NewPublicKeyRSA(N, E BigInt) (*PublicKeyRSA, error) { h, err := loadRsa() if err != nil { return nil, err @@ -120,7 +120,7 @@ func (k *PrivateKeyRSA) finalize() { bcrypt.DestroyKey(k.pkey) } -func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, error) { +func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv BigInt) (*PrivateKeyRSA, error) { h, err := loadRsa() if err != nil { return nil, err @@ -135,15 +135,11 @@ func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, err return k, nil } -func bigIntBytesLen(b *big.Int) uint32 { - return uint32(b.BitLen()+7) / 8 -} - -func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) []byte { +func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv BigInt) []byte { hdr := bcrypt.RSAKEY_BLOB{ - BitLength: uint32(N.BitLen()), - PublicExpSize: bigIntBytesLen(E), - ModulusSize: bigIntBytesLen(N), + BitLength: uint32(len(N) * 8), + PublicExpSize: uint32(len(E)), + ModulusSize: uint32(len(N)), } var blob []byte if D == nil { @@ -151,14 +147,14 @@ func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) []byte { blob = make([]byte, sizeOfRSABlobHeader+hdr.PublicExpSize+hdr.ModulusSize) } else { hdr.Magic = bcrypt.RSAFULLPRIVATE_MAGIC - hdr.Prime1Size = bigIntBytesLen(P) - hdr.Prime2Size = bigIntBytesLen(Q) + hdr.Prime1Size = uint32(len(P)) + hdr.Prime2Size = uint32(len(Q)) blob = make([]byte, sizeOfRSABlobHeader+hdr.PublicExpSize+hdr.ModulusSize*2+hdr.Prime1Size*3+hdr.Prime2Size*2) } copy(blob, (*(*[sizeOfRSABlobHeader]byte)(unsafe.Pointer(&hdr)))[:]) data := blob[sizeOfRSABlobHeader:] - encode := func(b *big.Int, size uint32) { - b.FillBytes(data[:size]) + encode := func(b BigInt, size uint32) { + copy(data, b) data = data[size:] } encode(E, hdr.PublicExpSize) diff --git a/cng/rsa_test.go b/cng/rsa_test.go index dea2a14..c384fac 100644 --- a/cng/rsa_test.go +++ b/cng/rsa_test.go @@ -4,7 +4,7 @@ //go:build windows // +build windows -package cng +package cng_test import ( "bytes" @@ -12,19 +12,22 @@ import ( "math/big" "strconv" "testing" + + "github.com/microsoft/go-crypto-winnative/cng" + "github.com/microsoft/go-crypto-winnative/cng/bbig" ) -func newRSAKey(t *testing.T, size int) (*PrivateKeyRSA, *PublicKeyRSA) { +func newRSAKey(t *testing.T, size int) (*cng.PrivateKeyRSA, *cng.PublicKeyRSA) { t.Helper() - N, E, D, P, Q, Dp, Dq, Qinv, err := GenerateKeyRSA(size) + N, E, D, P, Q, Dp, Dq, Qinv, err := cng.GenerateKeyRSA(size) if err != nil { t.Fatalf("GenerateKeyRSA(%d): %v", size, err) } - priv, err := NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv) + priv, err := cng.NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv) if err != nil { t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err) } - pub, err := NewPublicKeyRSA(N, E) + pub, err := cng.NewPublicKeyRSA(N, E) if err != nil { t.Fatalf("NewPublicKeyRSA(%d): %v", size, err) } @@ -34,14 +37,13 @@ func newRSAKey(t *testing.T, size int) (*PrivateKeyRSA, *PublicKeyRSA) { func TestRSAKeyGeneration(t *testing.T) { for _, size := range []int{2048, 3072} { t.Run(strconv.Itoa(size), func(t *testing.T) { - t.Parallel() priv, pub := newRSAKey(t, size) msg := []byte("hi!") - enc, err := EncryptRSAPKCS1(pub, msg) + enc, err := cng.EncryptRSAPKCS1(pub, msg) if err != nil { t.Fatalf("EncryptPKCS1v15: %v", err) } - dec, err := DecryptRSAPKCS1(priv, enc) + dec, err := cng.DecryptRSAPKCS1(priv, enc) if err != nil { t.Fatalf("DecryptPKCS1v15: %v", err) } @@ -53,15 +55,15 @@ func TestRSAKeyGeneration(t *testing.T) { } func TestEncryptDecryptOAEP(t *testing.T) { - sha256 := NewSHA256() + sha256 := cng.NewSHA256() msg := []byte("hi!") label := []byte("ho!") priv, pub := newRSAKey(t, 2048) - enc, err := EncryptRSAOAEP(sha256, pub, msg, label) + enc, err := cng.EncryptRSAOAEP(sha256, pub, msg, label) if err != nil { t.Fatal(err) } - dec, err := DecryptRSAOAEP(sha256, priv, enc, label) + dec, err := cng.DecryptRSAOAEP(sha256, priv, enc, label) if err != nil { t.Fatal(err) } @@ -71,14 +73,14 @@ func TestEncryptDecryptOAEP(t *testing.T) { } func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) { - sha256 := NewSHA256() + sha256 := cng.NewSHA256() msg := []byte("hi!") priv, pub := newRSAKey(t, 2048) - enc, err := EncryptRSAOAEP(sha256, pub, msg, []byte("ho!")) + enc, err := cng.EncryptRSAOAEP(sha256, pub, msg, []byte("ho!")) if err != nil { t.Fatal(err) } - dec, err := DecryptRSAOAEP(sha256, priv, enc, []byte("wrong!")) + dec, err := cng.DecryptRSAOAEP(sha256, priv, enc, []byte("wrong!")) if err == nil { t.Errorf("error expected") } @@ -93,11 +95,11 @@ func TestEncryptDecryptNoPadding(t *testing.T) { msg[0] = 1 msg[255] = 1 priv, pub := newRSAKey(t, bits) - enc, err := EncryptRSANoPadding(pub, msg[:]) + enc, err := cng.EncryptRSANoPadding(pub, msg[:]) if err != nil { t.Fatal(err) } - dec, err := DecryptRSANoPadding(priv, enc) + dec, err := cng.DecryptRSANoPadding(priv, enc) if err != nil { t.Fatal(err) } @@ -107,15 +109,15 @@ func TestEncryptDecryptNoPadding(t *testing.T) { } func TestSignVerifyPKCS1v15(t *testing.T) { - sha256 := NewSHA256() + sha256 := cng.NewSHA256() priv, pub := newRSAKey(t, 2048) sha256.Write([]byte("hi!")) hashed := sha256.Sum(nil) - signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + signed, err := cng.SignRSAPKCS1v15(priv, crypto.SHA256, hashed) if err != nil { t.Fatal(err) } - err = VerifyRSAPKCS1v15(pub, crypto.SHA256, hashed, signed) + err = cng.VerifyRSAPKCS1v15(pub, crypto.SHA256, hashed, signed) if err != nil { t.Fatal(err) } @@ -124,42 +126,42 @@ func TestSignVerifyPKCS1v15(t *testing.T) { func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) { msg := []byte("hi!") priv, pub := newRSAKey(t, 2048) - signed, err := SignRSAPKCS1v15(priv, 0, msg) + signed, err := cng.SignRSAPKCS1v15(priv, 0, msg) if err != nil { t.Fatal(err) } - err = VerifyRSAPKCS1v15(pub, 0, msg, signed) + err = cng.VerifyRSAPKCS1v15(pub, 0, msg, signed) if err != nil { t.Fatal(err) } } func TestSignVerifyPKCS1v15_Invalid(t *testing.T) { - sha256 := NewSHA256() + sha256 := cng.NewSHA256() msg := []byte("hi!") priv, pub := newRSAKey(t, 2048) sha256.Write(msg) hashed := sha256.Sum(nil) - signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed) + signed, err := cng.SignRSAPKCS1v15(priv, crypto.SHA256, hashed) if err != nil { t.Fatal(err) } - err = VerifyRSAPKCS1v15(pub, crypto.SHA256, msg, signed) + err = cng.VerifyRSAPKCS1v15(pub, crypto.SHA256, msg, signed) if err == nil { t.Fatal("error expected") } } func TestSignVerifyRSAPSS(t *testing.T) { - sha256 := NewSHA256() + sha256 := cng.NewSHA256() priv, pub := newRSAKey(t, 2048) sha256.Write([]byte("testing")) hashed := sha256.Sum(nil) - signed, err := SignRSAPSS(priv, crypto.SHA256, hashed, 0) + signed, err := cng.SignRSAPSS(priv, crypto.SHA256, hashed, 0) if err != nil { t.Fatal(err) } - err = VerifyRSAPSS(pub, crypto.SHA256, hashed, signed, 0) + err = cng.VerifyRSAPSS(pub, crypto.SHA256, hashed, signed, 0) if err != nil { t.Fatal(err) } @@ -177,14 +179,24 @@ func BenchmarkEncryptRSAPKCS1(b *testing.B) { b.StopTimer() // Public key length should be at least of 2048 bits, else OpenSSL will report an error when running in FIPS mode. n := fromBase36("14314132931241006650998084889274020608918049032671858325988396851334124245188214251956198731333464217832226406088020736932173064754214329009979944037640912127943488972644697423190955557435910767690712778463524983667852819010259499695177313115447116110358524558307947613422897787329221478860907963827160223559690523660574329011927531289655711860504630573766609239332569210831325633840174683944553667352219670930408593321661375473885147973879086994006440025257225431977751512374815915392249179976902953721486040787792801849818254465486633791826766873076617116727073077821584676715609985777563958286637185868165868520557") - test2048PubKey, err := NewPublicKeyRSA(n, big.NewInt(3)) + test2048PubKey, err := cng.NewPublicKeyRSA(bbig.Enc(n), bbig.Enc(big.NewInt(3))) if err != nil { b.Fatal(err) } b.StartTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - if _, err := EncryptRSAPKCS1(test2048PubKey, []byte("testing")); err != nil { + if _, err := cng.EncryptRSAPKCS1(test2048PubKey, []byte("testing")); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkGenerateKeyRSA(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _, _, _, _, _, _, _, err := cng.GenerateKeyRSA(2048) + if err != nil { b.Fatal(err) } }