From d65eeb2e8e1ac9eeebda9eae3a4cb75ecb5e490a Mon Sep 17 00:00:00 2001 From: qmuntal Date: Thu, 17 Mar 2022 18:40:22 +0100 Subject: [PATCH] consolidate algorithms cache --- cng/aes.go | 25 +++++++++++-------------- cng/cng.go | 3 +++ cng/rsa.go | 18 ++++++++---------- cng/sha.go | 11 ++++++----- 4 files changed, 28 insertions(+), 29 deletions(-) diff --git a/cng/aes.go b/cng/aes.go index 621a513..7b87f46 100644 --- a/cng/aes.go +++ b/cng/aes.go @@ -10,7 +10,6 @@ import ( "crypto/cipher" "errors" "runtime" - "sync" "unsafe" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" @@ -19,20 +18,18 @@ import ( const aesBlockSize = 16 -var aesCache sync.Map - type aesAlgorithm struct { h bcrypt.ALG_HANDLE allowedKeySized []int } -type aesCacheEntry struct { - id string - mode string -} - -func loadAes(id string, mode string) (h aesAlgorithm, err error) { - if v, ok := aesCache.Load(aesCacheEntry{id, mode}); ok { +func loadAes(mode string) (h aesAlgorithm, err error) { + const id = bcrypt.AES_ALGORITHM + type aesCacheEntry struct { + id string + mode string + } + if v, ok := algCache.Load(aesCacheEntry{id, mode}); ok { return v.(aesAlgorithm), nil } err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, bcrypt.ALG_NONE_FLAG) @@ -56,7 +53,7 @@ func loadAes(id string, mode string) (h aesAlgorithm, err error) { for size := info.MinLength; size <= info.MaxLength; size += info.Increment { h.allowedKeySized = append(h.allowedKeySized, int(size)) } - if existing, loaded := aesCache.LoadOrStore(aesCacheEntry{id, mode}, h); loaded { + if existing, loaded := algCache.LoadOrStore(aesCacheEntry{id, mode}, h); loaded { // We can safely use a provider that has already been cached in another concurrent goroutine. bcrypt.CloseAlgorithmProvider(h.h, 0) h = existing.(aesAlgorithm) @@ -70,7 +67,7 @@ type aesCipher struct { } func NewAESCipher(key []byte) (cipher.Block, error) { - h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_ECB) + h, err := loadAes(bcrypt.CHAIN_MODE_ECB) if err != nil { return nil, err } @@ -180,7 +177,7 @@ type aesCBC struct { } func newCBC(encrypt bool, key, iv []byte) *aesCBC { - h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_CBC) + h, err := loadAes(bcrypt.CHAIN_MODE_CBC) if err != nil { panic(err) } @@ -254,7 +251,7 @@ func (g *aesGCM) finalize() { } func newGCM(key []byte, tls bool) (*aesGCM, error) { - h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_GCM) + h, err := loadAes(bcrypt.CHAIN_MODE_GCM) if err != nil { return nil, err } diff --git a/cng/cng.go b/cng/cng.go index 7a54ab0..7b898fa 100644 --- a/cng/cng.go +++ b/cng/cng.go @@ -10,6 +10,7 @@ import ( "math" "reflect" "runtime" + "sync" "syscall" "unsafe" @@ -30,6 +31,8 @@ type algCacheEntry struct { flags uint32 } +var algCache sync.Map + func utf16PtrFromString(s string) *uint16 { str, err := syscall.UTF16PtrFromString(s) if err != nil { diff --git a/cng/rsa.go b/cng/rsa.go index b958b0c..a49cf2a 100644 --- a/cng/rsa.go +++ b/cng/rsa.go @@ -12,27 +12,25 @@ import ( "hash" "math/big" "runtime" - "sync" "unsafe" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) -var rsaCache sync.Map - type rsaAlgorithm struct { h bcrypt.ALG_HANDLE } -func loadRsa(id string, flags bcrypt.AlgorithmProviderFlags) (h rsaAlgorithm, err error) { - if v, ok := rsaCache.Load(algCacheEntry{id, uint32(flags)}); ok { +func loadRsa() (h rsaAlgorithm, err error) { + const id = bcrypt.RSA_ALGORITHM + if v, ok := algCache.Load(id); ok { return v.(rsaAlgorithm), nil } - err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags) + err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, bcrypt.ALG_NONE_FLAG) if err != nil { return } - rsaCache.Store(algCacheEntry{id, uint32(flags)}, h) + algCache.Store(id, h) return } @@ -42,7 +40,7 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) { return nil, nil, nil, nil, nil, nil, nil, nil, e } - h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG) + h, err := loadRsa() if err != nil { return bad(err) } @@ -95,7 +93,7 @@ type PublicKeyRSA struct { } func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) { - h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG) + h, err := loadRsa() if err != nil { return nil, err } @@ -122,7 +120,7 @@ func (k *PrivateKeyRSA) finalize() { } func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, error) { - h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG) + h, err := loadRsa() if err != nil { return nil, err } diff --git a/cng/sha.go b/cng/sha.go index cc93f20..31c11a8 100644 --- a/cng/sha.go +++ b/cng/sha.go @@ -9,7 +9,6 @@ package cng import ( "hash" "runtime" - "sync" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) @@ -34,8 +33,6 @@ func NewSHA512() hash.Hash { return newSHAX(bcrypt.SHA512_ALGORITHM, nil) } -var shaCache sync.Map - type shaAlgorithm struct { h bcrypt.ALG_HANDLE size uint32 @@ -43,7 +40,11 @@ type shaAlgorithm struct { } func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (h shaAlgorithm, err error) { - if v, ok := shaCache.Load(algCacheEntry{id, uint32(flags)}); ok { + type entry struct { + id string + flags uint32 + } + if v, ok := algCache.Load(entry{id, uint32(flags)}); ok { return v.(shaAlgorithm), nil } err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags) @@ -60,7 +61,7 @@ func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (h shaAlgorithm, er bcrypt.CloseAlgorithmProvider(h.h, 0) return } - if existing, loaded := shaCache.LoadOrStore(algCacheEntry{id, uint32(flags)}, h); loaded { + if existing, loaded := algCache.LoadOrStore(algCacheEntry{id, uint32(flags)}, h); loaded { // We can safely use a provider that has already been cached in another concurrent goroutine. bcrypt.CloseAlgorithmProvider(h.h, 0) h = existing.(shaAlgorithm)