Skip to content

Commit

Permalink
consolidate algorithms cache
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Mar 24, 2022
1 parent 58bf0b5 commit e1f71ab
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 82 deletions.
65 changes: 25 additions & 40 deletions cng/aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"crypto/cipher"
"errors"
"runtime"
"sync"
"unsafe"

"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
Expand All @@ -19,49 +18,35 @@ import (

const aesBlockSize = 16

var aesCache sync.Map

type aesAlgorithm struct {
h bcrypt.ALG_HANDLE
allowedKeySized []int
}

type aesCacheEntry struct {
id string
mode string
}

func loadAes(id string, mode string) (h aesAlgorithm, err error) {
if v, ok := aesCache.Load(aesCacheEntry{id, mode}); ok {
return v.(aesAlgorithm), nil
}
err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, bcrypt.ALG_NONE_FLAG)
if err != nil {
return
}
// Windows 8 added support to set the CipherMode value on a key,
// but Windows 7 requires that it be set on the algorithm before key creation.
err = setString(bcrypt.HANDLE(h.h), bcrypt.CHAINING_MODE, mode)
if err != nil {
bcrypt.CloseAlgorithmProvider(h.h, 0)
return
}
var info bcrypt.KEY_LENGTHS_STRUCT
var discard uint32
err = bcrypt.GetProperty(bcrypt.HANDLE(h.h), utf16PtrFromString(bcrypt.KEY_LENGTHS), (*(*[1<<31 - 1]byte)(unsafe.Pointer(&info)))[:unsafe.Sizeof(info)], &discard, 0)
func loadAes(mode string) (aesAlgorithm, error) {
v, err := loadOrStoreAlg(bcrypt.AES_ALGORITHM, bcrypt.ALG_NONE_FLAG, mode, func(h bcrypt.ALG_HANDLE) (interface{}, error) {
// Windows 8 added support to set the CipherMode value on a key,
// but Windows 7 requires that it be set on the algorithm before key creation.
err := setString(bcrypt.HANDLE(h), bcrypt.CHAINING_MODE, mode)
if err != nil {
return nil, err
}
var info bcrypt.KEY_LENGTHS_STRUCT
var discard uint32
err = bcrypt.GetProperty(bcrypt.HANDLE(h), utf16PtrFromString(bcrypt.KEY_LENGTHS), (*(*[1<<31 - 1]byte)(unsafe.Pointer(&info)))[:unsafe.Sizeof(info)], &discard, 0)
if err != nil {
return nil, err
}
var allowedKeySized []int
for size := info.MinLength; size <= info.MaxLength; size += info.Increment {
allowedKeySized = append(allowedKeySized, int(size))
}
return aesAlgorithm{h, allowedKeySized}, nil
})
if err != nil {
bcrypt.CloseAlgorithmProvider(h.h, 0)
return
}
for size := info.MinLength; size <= info.MaxLength; size += info.Increment {
h.allowedKeySized = append(h.allowedKeySized, int(size))
return aesAlgorithm{}, nil
}
if existing, loaded := aesCache.LoadOrStore(aesCacheEntry{id, mode}, h); loaded {
// We can safely use a provider that has already been cached in another concurrent goroutine.
bcrypt.CloseAlgorithmProvider(h.h, 0)
h = existing.(aesAlgorithm)
}
return
return v.(aesAlgorithm), nil
}

type aesCipher struct {
Expand All @@ -70,7 +55,7 @@ type aesCipher struct {
}

func NewAESCipher(key []byte) (cipher.Block, error) {
h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_ECB)
h, err := loadAes(bcrypt.CHAIN_MODE_ECB)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -180,7 +165,7 @@ type aesCBC struct {
}

func newCBC(encrypt bool, key, iv []byte) *aesCBC {
h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_CBC)
h, err := loadAes(bcrypt.CHAIN_MODE_CBC)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -254,7 +239,7 @@ func (g *aesGCM) finalize() {
}

func newGCM(key []byte, tls bool) (*aesGCM, error) {
h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_GCM)
h, err := loadAes(bcrypt.CHAIN_MODE_GCM)
if err != nil {
return nil, err
}
Expand Down
34 changes: 31 additions & 3 deletions cng/cng.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"math"
"reflect"
"runtime"
"sync"
"syscall"
"unsafe"

Expand All @@ -25,9 +26,36 @@ func lenU32(s []byte) int {
return len(s)
}

type algCacheEntry struct {
id string
flags uint32
var algCache sync.Map

type newAlgEntryFn func(h bcrypt.ALG_HANDLE) (interface{}, error)

func loadOrStoreAlg(id string, flags bcrypt.AlgorithmProviderFlags, mode string, fn newAlgEntryFn) (interface{}, error) {
var entryKey = struct {
id string
flags bcrypt.AlgorithmProviderFlags
mode string
}{id, flags, mode}

if v, ok := algCache.Load(entryKey); ok {
return v, nil
}
var h bcrypt.ALG_HANDLE
err := bcrypt.OpenAlgorithmProvider(&h, utf16PtrFromString(id), nil, flags)
if err != nil {
return nil, err
}
v, err := fn(h)
if err != nil {
bcrypt.CloseAlgorithmProvider(h, 0)
return nil, err
}
if existing, loaded := algCache.LoadOrStore(entryKey, v); loaded {
// We can safely use a provider that has already been cached in another concurrent goroutine.
bcrypt.CloseAlgorithmProvider(h, 0)
v = existing
}
return v, nil
}

func utf16PtrFromString(s string) *uint16 {
Expand Down
23 changes: 9 additions & 14 deletions cng/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,23 @@ import (
"hash"
"math/big"
"runtime"
"sync"
"unsafe"

"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
)

var rsaCache sync.Map

type rsaAlgorithm struct {
h bcrypt.ALG_HANDLE
}

func loadRsa(id string, flags bcrypt.AlgorithmProviderFlags) (h rsaAlgorithm, err error) {
if v, ok := rsaCache.Load(algCacheEntry{id, uint32(flags)}); ok {
return v.(rsaAlgorithm), nil
}
err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags)
func loadRsa() (rsaAlgorithm, error) {
v, err := loadOrStoreAlg(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
return rsaAlgorithm{h}, nil
})
if err != nil {
return
return rsaAlgorithm{}, nil
}
rsaCache.Store(algCacheEntry{id, uint32(flags)}, h)
return
return v.(rsaAlgorithm), nil
}

const sizeOfRSABlobHeader = uint32(unsafe.Sizeof(bcrypt.RSAKEY_BLOB{}))
Expand All @@ -42,7 +37,7 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error)
bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) {
return nil, nil, nil, nil, nil, nil, nil, nil, e
}
h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG)
h, err := loadRsa()
if err != nil {
return bad(err)
}
Expand Down Expand Up @@ -95,7 +90,7 @@ type PublicKeyRSA struct {
}

func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) {
h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG)
h, err := loadRsa()
if err != nil {
return nil, err
}
Expand All @@ -122,7 +117,7 @@ func (k *PrivateKeyRSA) finalize() {
}

func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, error) {
h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG)
h, err := loadRsa()
if err != nil {
return nil, err
}
Expand Down
39 changes: 14 additions & 25 deletions cng/sha.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ package cng
import (
"hash"
"runtime"
"sync"

"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
)
Expand All @@ -34,38 +33,28 @@ func NewSHA512() hash.Hash {
return newSHAX(bcrypt.SHA512_ALGORITHM, nil)
}

var shaCache sync.Map

type shaAlgorithm struct {
h bcrypt.ALG_HANDLE
size uint32
blockSize uint32
}

func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (h shaAlgorithm, err error) {
if v, ok := shaCache.Load(algCacheEntry{id, uint32(flags)}); ok {
return v.(shaAlgorithm), nil
}
err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags)
if err != nil {
return
}
h.size, err = getUint32(bcrypt.HANDLE(h.h), bcrypt.HASH_LENGTH)
func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (shaAlgorithm, error) {
v, err := loadOrStoreAlg(id, flags, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) {
size, err := getUint32(bcrypt.HANDLE(h), bcrypt.HASH_LENGTH)
if err != nil {
return nil, err
}
blockSize, err := getUint32(bcrypt.HANDLE(h), bcrypt.HASH_BLOCK_LENGTH)
if err != nil {
return nil, err
}
return shaAlgorithm{h, size, blockSize}, nil
})
if err != nil {
bcrypt.CloseAlgorithmProvider(h.h, 0)
return
}
h.blockSize, err = getUint32(bcrypt.HANDLE(h.h), bcrypt.HASH_BLOCK_LENGTH)
if err != nil {
bcrypt.CloseAlgorithmProvider(h.h, 0)
return
}
if existing, loaded := shaCache.LoadOrStore(algCacheEntry{id, uint32(flags)}, h); loaded {
// We can safely use a provider that has already been cached in another concurrent goroutine.
bcrypt.CloseAlgorithmProvider(h.h, 0)
h = existing.(shaAlgorithm)
return shaAlgorithm{}, err
}
return
return v.(shaAlgorithm), nil
}

type shaXHash struct {
Expand Down

0 comments on commit e1f71ab

Please sign in to comment.