Skip to content

Commit

Permalink
remove math/big dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Jun 29, 2022
1 parent a2b9d65 commit 402be0e
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 48 deletions.
31 changes: 31 additions & 0 deletions cng/bbig/big.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

package bbig

import (
"math/big"

"github.com/microsoft/go-crypto-winnative/cng"
)

func Enc(b *big.Int) cng.BigInt {
if b == nil {
return nil
}
x := b.Bytes()
if len(x) == 0 {
return cng.BigInt{}
}
return x
}

func Dec(b cng.BigInt) *big.Int {
if b == nil {
return nil
}
if len(b) == 0 {
return new(big.Int)
}
return new(big.Int).SetBytes(b)
}
13 changes: 13 additions & 0 deletions cng/big.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package cng

// This file does not have build constraints to
// facilitate using BigInt in Go crypto.
// Go crypto references BigInt unconditionally,
// even if it is not finally used.

// A BigInt is the raw big-endian bytes from a BigInt.
// This definition allows us to avoid importing math/big.
// Conversion between BigInt and *big.Int is in cng/bbig.
type BigInt []byte
34 changes: 15 additions & 19 deletions cng/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"crypto"
"errors"
"hash"
"math/big"
"runtime"
"unsafe"

Expand All @@ -33,8 +32,8 @@ func loadRsa() (rsaAlgorithm, error) {

const sizeOfRSABlobHeader = uint32(unsafe.Sizeof(bcrypt.RSAKEY_BLOB{}))

func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) {
bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) {
func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) {
bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv BigInt, err error) {
return nil, nil, nil, nil, nil, nil, nil, nil, e
}
h, err := loadRsa()
Expand Down Expand Up @@ -73,8 +72,9 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error)
return bad(errors.New("crypto/rsa: exported key is corrupted"))
}
data := blob[sizeOfRSABlobHeader:]
consumeBigInt := func(size uint32) *big.Int {
b := new(big.Int).SetBytes(data[:size])
consumeBigInt := func(size uint32) BigInt {
b := make(BigInt, size)
copy(b, data)
data = data[size:]
return b
}
Expand All @@ -93,7 +93,7 @@ type PublicKeyRSA struct {
pkey bcrypt.KEY_HANDLE
}

func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) {
func NewPublicKeyRSA(N, E BigInt) (*PublicKeyRSA, error) {
h, err := loadRsa()
if err != nil {
return nil, err
Expand All @@ -120,7 +120,7 @@ func (k *PrivateKeyRSA) finalize() {
bcrypt.DestroyKey(k.pkey)
}

func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, error) {
func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv BigInt) (*PrivateKeyRSA, error) {
h, err := loadRsa()
if err != nil {
return nil, err
Expand All @@ -135,30 +135,26 @@ func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, err
return k, nil
}

func bigIntBytesLen(b *big.Int) uint32 {
return uint32(b.BitLen()+7) / 8
}

func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) []byte {
func encodeRSAKey(N, E, D, P, Q, Dp, Dq, Qinv BigInt) []byte {
hdr := bcrypt.RSAKEY_BLOB{
BitLength: uint32(N.BitLen()),
PublicExpSize: bigIntBytesLen(E),
ModulusSize: bigIntBytesLen(N),
BitLength: uint32(len(N) * 8),
PublicExpSize: uint32(len(E)),
ModulusSize: uint32(len(N)),
}
var blob []byte
if D == nil {
hdr.Magic = bcrypt.RSAPUBLIC_MAGIC
blob = make([]byte, sizeOfRSABlobHeader+hdr.PublicExpSize+hdr.ModulusSize)
} else {
hdr.Magic = bcrypt.RSAFULLPRIVATE_MAGIC
hdr.Prime1Size = bigIntBytesLen(P)
hdr.Prime2Size = bigIntBytesLen(Q)
hdr.Prime1Size = uint32(len(P))
hdr.Prime2Size = uint32(len(Q))
blob = make([]byte, sizeOfRSABlobHeader+hdr.PublicExpSize+hdr.ModulusSize*2+hdr.Prime1Size*3+hdr.Prime2Size*2)
}
copy(blob, (*(*[sizeOfRSABlobHeader]byte)(unsafe.Pointer(&hdr)))[:])
data := blob[sizeOfRSABlobHeader:]
encode := func(b *big.Int, size uint32) {
b.FillBytes(data[:size])
encode := func(b BigInt, size uint32) {
copy(data, b)
data = data[size:]
}
encode(E, hdr.PublicExpSize)
Expand Down
70 changes: 41 additions & 29 deletions cng/rsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,30 @@
//go:build windows
// +build windows

package cng
package cng_test

import (
"bytes"
"crypto"
"math/big"
"strconv"
"testing"

"github.com/microsoft/go-crypto-winnative/cng"
"github.com/microsoft/go-crypto-winnative/cng/bbig"
)

func newRSAKey(t *testing.T, size int) (*PrivateKeyRSA, *PublicKeyRSA) {
func newRSAKey(t *testing.T, size int) (*cng.PrivateKeyRSA, *cng.PublicKeyRSA) {
t.Helper()
N, E, D, P, Q, Dp, Dq, Qinv, err := GenerateKeyRSA(size)
N, E, D, P, Q, Dp, Dq, Qinv, err := cng.GenerateKeyRSA(size)
if err != nil {
t.Fatalf("GenerateKeyRSA(%d): %v", size, err)
}
priv, err := NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv)
priv, err := cng.NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv)
if err != nil {
t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err)
}
pub, err := NewPublicKeyRSA(N, E)
pub, err := cng.NewPublicKeyRSA(N, E)
if err != nil {
t.Fatalf("NewPublicKeyRSA(%d): %v", size, err)
}
Expand All @@ -34,14 +37,13 @@ func newRSAKey(t *testing.T, size int) (*PrivateKeyRSA, *PublicKeyRSA) {
func TestRSAKeyGeneration(t *testing.T) {
for _, size := range []int{2048, 3072} {
t.Run(strconv.Itoa(size), func(t *testing.T) {
t.Parallel()
priv, pub := newRSAKey(t, size)
msg := []byte("hi!")
enc, err := EncryptRSAPKCS1(pub, msg)
enc, err := cng.EncryptRSAPKCS1(pub, msg)
if err != nil {
t.Fatalf("EncryptPKCS1v15: %v", err)
}
dec, err := DecryptRSAPKCS1(priv, enc)
dec, err := cng.DecryptRSAPKCS1(priv, enc)
if err != nil {
t.Fatalf("DecryptPKCS1v15: %v", err)
}
Expand All @@ -53,15 +55,15 @@ func TestRSAKeyGeneration(t *testing.T) {
}

func TestEncryptDecryptOAEP(t *testing.T) {
sha256 := NewSHA256()
sha256 := cng.NewSHA256()
msg := []byte("hi!")
label := []byte("ho!")
priv, pub := newRSAKey(t, 2048)
enc, err := EncryptRSAOAEP(sha256, pub, msg, label)
enc, err := cng.EncryptRSAOAEP(sha256, pub, msg, label)
if err != nil {
t.Fatal(err)
}
dec, err := DecryptRSAOAEP(sha256, priv, enc, label)
dec, err := cng.DecryptRSAOAEP(sha256, priv, enc, label)
if err != nil {
t.Fatal(err)
}
Expand All @@ -71,14 +73,14 @@ func TestEncryptDecryptOAEP(t *testing.T) {
}

func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) {
sha256 := NewSHA256()
sha256 := cng.NewSHA256()
msg := []byte("hi!")
priv, pub := newRSAKey(t, 2048)
enc, err := EncryptRSAOAEP(sha256, pub, msg, []byte("ho!"))
enc, err := cng.EncryptRSAOAEP(sha256, pub, msg, []byte("ho!"))
if err != nil {
t.Fatal(err)
}
dec, err := DecryptRSAOAEP(sha256, priv, enc, []byte("wrong!"))
dec, err := cng.DecryptRSAOAEP(sha256, priv, enc, []byte("wrong!"))
if err == nil {
t.Errorf("error expected")
}
Expand All @@ -93,11 +95,11 @@ func TestEncryptDecryptNoPadding(t *testing.T) {
msg[0] = 1
msg[255] = 1
priv, pub := newRSAKey(t, bits)
enc, err := EncryptRSANoPadding(pub, msg[:])
enc, err := cng.EncryptRSANoPadding(pub, msg[:])
if err != nil {
t.Fatal(err)
}
dec, err := DecryptRSANoPadding(priv, enc)
dec, err := cng.DecryptRSANoPadding(priv, enc)
if err != nil {
t.Fatal(err)
}
Expand All @@ -107,15 +109,15 @@ func TestEncryptDecryptNoPadding(t *testing.T) {
}

func TestSignVerifyPKCS1v15(t *testing.T) {
sha256 := NewSHA256()
sha256 := cng.NewSHA256()
priv, pub := newRSAKey(t, 2048)
sha256.Write([]byte("hi!"))
hashed := sha256.Sum(nil)
signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed)
signed, err := cng.SignRSAPKCS1v15(priv, crypto.SHA256, hashed)
if err != nil {
t.Fatal(err)
}
err = VerifyRSAPKCS1v15(pub, crypto.SHA256, hashed, signed)
err = cng.VerifyRSAPKCS1v15(pub, crypto.SHA256, hashed, signed)
if err != nil {
t.Fatal(err)
}
Expand All @@ -124,42 +126,42 @@ func TestSignVerifyPKCS1v15(t *testing.T) {
func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) {
msg := []byte("hi!")
priv, pub := newRSAKey(t, 2048)
signed, err := SignRSAPKCS1v15(priv, 0, msg)
signed, err := cng.SignRSAPKCS1v15(priv, 0, msg)
if err != nil {
t.Fatal(err)
}
err = VerifyRSAPKCS1v15(pub, 0, msg, signed)
err = cng.VerifyRSAPKCS1v15(pub, 0, msg, signed)
if err != nil {
t.Fatal(err)
}
}

func TestSignVerifyPKCS1v15_Invalid(t *testing.T) {
sha256 := NewSHA256()
sha256 := cng.NewSHA256()
msg := []byte("hi!")
priv, pub := newRSAKey(t, 2048)
sha256.Write(msg)
hashed := sha256.Sum(nil)
signed, err := SignRSAPKCS1v15(priv, crypto.SHA256, hashed)
signed, err := cng.SignRSAPKCS1v15(priv, crypto.SHA256, hashed)
if err != nil {
t.Fatal(err)
}
err = VerifyRSAPKCS1v15(pub, crypto.SHA256, msg, signed)
err = cng.VerifyRSAPKCS1v15(pub, crypto.SHA256, msg, signed)
if err == nil {
t.Fatal("error expected")
}
}

func TestSignVerifyRSAPSS(t *testing.T) {
sha256 := NewSHA256()
sha256 := cng.NewSHA256()
priv, pub := newRSAKey(t, 2048)
sha256.Write([]byte("testing"))
hashed := sha256.Sum(nil)
signed, err := SignRSAPSS(priv, crypto.SHA256, hashed, 0)
signed, err := cng.SignRSAPSS(priv, crypto.SHA256, hashed, 0)
if err != nil {
t.Fatal(err)
}
err = VerifyRSAPSS(pub, crypto.SHA256, hashed, signed, 0)
err = cng.VerifyRSAPSS(pub, crypto.SHA256, hashed, signed, 0)
if err != nil {
t.Fatal(err)
}
Expand All @@ -177,14 +179,24 @@ func BenchmarkEncryptRSAPKCS1(b *testing.B) {
b.StopTimer()
// Public key length should be at least of 2048 bits, else OpenSSL will report an error when running in FIPS mode.
n := fromBase36("14314132931241006650998084889274020608918049032671858325988396851334124245188214251956198731333464217832226406088020736932173064754214329009979944037640912127943488972644697423190955557435910767690712778463524983667852819010259499695177313115447116110358524558307947613422897787329221478860907963827160223559690523660574329011927531289655711860504630573766609239332569210831325633840174683944553667352219670930408593321661375473885147973879086994006440025257225431977751512374815915392249179976902953721486040787792801849818254465486633791826766873076617116727073077821584676715609985777563958286637185868165868520557")
test2048PubKey, err := NewPublicKeyRSA(n, big.NewInt(3))
test2048PubKey, err := cng.NewPublicKeyRSA(bbig.Enc(n), bbig.Enc(big.NewInt(3)))
if err != nil {
b.Fatal(err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := EncryptRSAPKCS1(test2048PubKey, []byte("testing")); err != nil {
if _, err := cng.EncryptRSAPKCS1(test2048PubKey, []byte("testing")); err != nil {
b.Fatal(err)
}
}
}

func BenchmarkGenerateKeyRSA(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _, _, _, _, _, _, _, err := cng.GenerateKeyRSA(2048)
if err != nil {
b.Fatal(err)
}
}
Expand Down

0 comments on commit 402be0e

Please sign in to comment.