diff --git a/cng/aes.go b/cng/aes.go index 621a513..8fc62c0 100644 --- a/cng/aes.go +++ b/cng/aes.go @@ -10,7 +10,6 @@ import ( "crypto/cipher" "errors" "runtime" - "sync" "unsafe" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" @@ -19,49 +18,35 @@ import ( const aesBlockSize = 16 -var aesCache sync.Map - type aesAlgorithm struct { h bcrypt.ALG_HANDLE allowedKeySized []int } -type aesCacheEntry struct { - id string - mode string -} - -func loadAes(id string, mode string) (h aesAlgorithm, err error) { - if v, ok := aesCache.Load(aesCacheEntry{id, mode}); ok { - return v.(aesAlgorithm), nil - } - err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, bcrypt.ALG_NONE_FLAG) - if err != nil { - return - } - // Windows 8 added support to set the CipherMode value on a key, - // but Windows 7 requires that it be set on the algorithm before key creation. - err = setString(bcrypt.HANDLE(h.h), bcrypt.CHAINING_MODE, mode) - if err != nil { - bcrypt.CloseAlgorithmProvider(h.h, 0) - return - } - var info bcrypt.KEY_LENGTHS_STRUCT - var discard uint32 - err = bcrypt.GetProperty(bcrypt.HANDLE(h.h), utf16PtrFromString(bcrypt.KEY_LENGTHS), (*(*[1<<31 - 1]byte)(unsafe.Pointer(&info)))[:unsafe.Sizeof(info)], &discard, 0) +func loadAes(mode string) (aesAlgorithm, error) { + v, err := loadOrStoreAlg(bcrypt.AES_ALGORITHM, bcrypt.ALG_NONE_FLAG, mode, func(h bcrypt.ALG_HANDLE) (interface{}, error) { + // Windows 8 added support to set the CipherMode value on a key, + // but Windows 7 requires that it be set on the algorithm before key creation. + err := setString(bcrypt.HANDLE(h), bcrypt.CHAINING_MODE, mode) + if err != nil { + return nil, err + } + var info bcrypt.KEY_LENGTHS_STRUCT + var discard uint32 + err = bcrypt.GetProperty(bcrypt.HANDLE(h), utf16PtrFromString(bcrypt.KEY_LENGTHS), (*(*[1<<31 - 1]byte)(unsafe.Pointer(&info)))[:unsafe.Sizeof(info)], &discard, 0) + if err != nil { + return nil, err + } + var allowedKeySized []int + for size := info.MinLength; size <= info.MaxLength; size += info.Increment { + allowedKeySized = append(allowedKeySized, int(size)) + } + return aesAlgorithm{h, allowedKeySized}, nil + }) if err != nil { - bcrypt.CloseAlgorithmProvider(h.h, 0) - return - } - for size := info.MinLength; size <= info.MaxLength; size += info.Increment { - h.allowedKeySized = append(h.allowedKeySized, int(size)) + return aesAlgorithm{}, nil } - if existing, loaded := aesCache.LoadOrStore(aesCacheEntry{id, mode}, h); loaded { - // We can safely use a provider that has already been cached in another concurrent goroutine. - bcrypt.CloseAlgorithmProvider(h.h, 0) - h = existing.(aesAlgorithm) - } - return + return v.(aesAlgorithm), nil } type aesCipher struct { @@ -70,7 +55,7 @@ type aesCipher struct { } func NewAESCipher(key []byte) (cipher.Block, error) { - h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_ECB) + h, err := loadAes(bcrypt.CHAIN_MODE_ECB) if err != nil { return nil, err } @@ -180,7 +165,7 @@ type aesCBC struct { } func newCBC(encrypt bool, key, iv []byte) *aesCBC { - h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_CBC) + h, err := loadAes(bcrypt.CHAIN_MODE_CBC) if err != nil { panic(err) } @@ -254,7 +239,7 @@ func (g *aesGCM) finalize() { } func newGCM(key []byte, tls bool) (*aesGCM, error) { - h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_GCM) + h, err := loadAes(bcrypt.CHAIN_MODE_GCM) if err != nil { return nil, err } diff --git a/cng/cng.go b/cng/cng.go index 7a54ab0..a30fdc9 100644 --- a/cng/cng.go +++ b/cng/cng.go @@ -10,6 +10,7 @@ import ( "math" "reflect" "runtime" + "sync" "syscall" "unsafe" @@ -25,9 +26,36 @@ func lenU32(s []byte) int { return len(s) } -type algCacheEntry struct { - id string - flags uint32 +var algCache sync.Map + +type newAlgEntryFn func(h bcrypt.ALG_HANDLE) (interface{}, error) + +func loadOrStoreAlg(id string, flags bcrypt.AlgorithmProviderFlags, mode string, fn newAlgEntryFn) (interface{}, error) { + var entryKey = struct { + id string + flags bcrypt.AlgorithmProviderFlags + mode string + }{id, flags, mode} + + if v, ok := algCache.Load(entryKey); ok { + return v, nil + } + var h bcrypt.ALG_HANDLE + err := bcrypt.OpenAlgorithmProvider(&h, utf16PtrFromString(id), nil, flags) + if err != nil { + return nil, err + } + v, err := fn(h) + if err != nil { + bcrypt.CloseAlgorithmProvider(h, 0) + return nil, err + } + if existing, loaded := algCache.LoadOrStore(entryKey, v); loaded { + // We can safely use a provider that has already been cached in another concurrent goroutine. + bcrypt.CloseAlgorithmProvider(h, 0) + v = existing + } + return v, nil } func utf16PtrFromString(s string) *uint16 { diff --git a/cng/rsa.go b/cng/rsa.go index b958b0c..15fe68a 100644 --- a/cng/rsa.go +++ b/cng/rsa.go @@ -12,28 +12,23 @@ import ( "hash" "math/big" "runtime" - "sync" "unsafe" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) -var rsaCache sync.Map - type rsaAlgorithm struct { h bcrypt.ALG_HANDLE } -func loadRsa(id string, flags bcrypt.AlgorithmProviderFlags) (h rsaAlgorithm, err error) { - if v, ok := rsaCache.Load(algCacheEntry{id, uint32(flags)}); ok { - return v.(rsaAlgorithm), nil - } - err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags) +func loadRsa() (rsaAlgorithm, error) { + v, err := loadOrStoreAlg(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { + return rsaAlgorithm{h}, nil + }) if err != nil { - return + return rsaAlgorithm{}, nil } - rsaCache.Store(algCacheEntry{id, uint32(flags)}, h) - return + return v.(rsaAlgorithm), nil } const sizeOfRSABlobHeader = uint32(unsafe.Sizeof(bcrypt.RSAKEY_BLOB{})) @@ -42,7 +37,7 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) { return nil, nil, nil, nil, nil, nil, nil, nil, e } - h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG) + h, err := loadRsa() if err != nil { return bad(err) } @@ -95,7 +90,7 @@ type PublicKeyRSA struct { } func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) { - h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG) + h, err := loadRsa() if err != nil { return nil, err } @@ -122,7 +117,7 @@ func (k *PrivateKeyRSA) finalize() { } func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, error) { - h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG) + h, err := loadRsa() if err != nil { return nil, err } diff --git a/cng/sha.go b/cng/sha.go index cc93f20..8413e6b 100644 --- a/cng/sha.go +++ b/cng/sha.go @@ -9,7 +9,6 @@ package cng import ( "hash" "runtime" - "sync" "github.com/microsoft/go-crypto-winnative/internal/bcrypt" ) @@ -34,38 +33,28 @@ func NewSHA512() hash.Hash { return newSHAX(bcrypt.SHA512_ALGORITHM, nil) } -var shaCache sync.Map - type shaAlgorithm struct { h bcrypt.ALG_HANDLE size uint32 blockSize uint32 } -func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (h shaAlgorithm, err error) { - if v, ok := shaCache.Load(algCacheEntry{id, uint32(flags)}); ok { - return v.(shaAlgorithm), nil - } - err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags) - if err != nil { - return - } - h.size, err = getUint32(bcrypt.HANDLE(h.h), bcrypt.HASH_LENGTH) +func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (shaAlgorithm, error) { + v, err := loadOrStoreAlg(id, flags, "", func(h bcrypt.ALG_HANDLE) (interface{}, error) { + size, err := getUint32(bcrypt.HANDLE(h), bcrypt.HASH_LENGTH) + if err != nil { + return nil, err + } + blockSize, err := getUint32(bcrypt.HANDLE(h), bcrypt.HASH_BLOCK_LENGTH) + if err != nil { + return nil, err + } + return shaAlgorithm{h, size, blockSize}, nil + }) if err != nil { - bcrypt.CloseAlgorithmProvider(h.h, 0) - return - } - h.blockSize, err = getUint32(bcrypt.HANDLE(h.h), bcrypt.HASH_BLOCK_LENGTH) - if err != nil { - bcrypt.CloseAlgorithmProvider(h.h, 0) - return - } - if existing, loaded := shaCache.LoadOrStore(algCacheEntry{id, uint32(flags)}, h); loaded { - // We can safely use a provider that has already been cached in another concurrent goroutine. - bcrypt.CloseAlgorithmProvider(h.h, 0) - h = existing.(shaAlgorithm) + return shaAlgorithm{}, err } - return + return v.(shaAlgorithm), nil } type shaXHash struct {