Skip to content

Commit

Permalink
consolidate algorithms cache
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Mar 24, 2022
1 parent 58bf0b5 commit d65eeb2
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 29 deletions.
25 changes: 11 additions & 14 deletions cng/aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"crypto/cipher"
"errors"
"runtime"
"sync"
"unsafe"

"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
Expand All @@ -19,20 +18,18 @@ import (

const aesBlockSize = 16

var aesCache sync.Map

type aesAlgorithm struct {
h bcrypt.ALG_HANDLE
allowedKeySized []int
}

type aesCacheEntry struct {
id string
mode string
}

func loadAes(id string, mode string) (h aesAlgorithm, err error) {
if v, ok := aesCache.Load(aesCacheEntry{id, mode}); ok {
func loadAes(mode string) (h aesAlgorithm, err error) {
const id = bcrypt.AES_ALGORITHM
type aesCacheEntry struct {
id string
mode string
}
if v, ok := algCache.Load(aesCacheEntry{id, mode}); ok {
return v.(aesAlgorithm), nil
}
err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, bcrypt.ALG_NONE_FLAG)
Expand All @@ -56,7 +53,7 @@ func loadAes(id string, mode string) (h aesAlgorithm, err error) {
for size := info.MinLength; size <= info.MaxLength; size += info.Increment {
h.allowedKeySized = append(h.allowedKeySized, int(size))
}
if existing, loaded := aesCache.LoadOrStore(aesCacheEntry{id, mode}, h); loaded {
if existing, loaded := algCache.LoadOrStore(aesCacheEntry{id, mode}, h); loaded {
// We can safely use a provider that has already been cached in another concurrent goroutine.
bcrypt.CloseAlgorithmProvider(h.h, 0)
h = existing.(aesAlgorithm)
Expand All @@ -70,7 +67,7 @@ type aesCipher struct {
}

func NewAESCipher(key []byte) (cipher.Block, error) {
h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_ECB)
h, err := loadAes(bcrypt.CHAIN_MODE_ECB)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -180,7 +177,7 @@ type aesCBC struct {
}

func newCBC(encrypt bool, key, iv []byte) *aesCBC {
h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_CBC)
h, err := loadAes(bcrypt.CHAIN_MODE_CBC)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -254,7 +251,7 @@ func (g *aesGCM) finalize() {
}

func newGCM(key []byte, tls bool) (*aesGCM, error) {
h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_GCM)
h, err := loadAes(bcrypt.CHAIN_MODE_GCM)
if err != nil {
return nil, err
}
Expand Down
3 changes: 3 additions & 0 deletions cng/cng.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"math"
"reflect"
"runtime"
"sync"
"syscall"
"unsafe"

Expand All @@ -30,6 +31,8 @@ type algCacheEntry struct {
flags uint32
}

var algCache sync.Map

func utf16PtrFromString(s string) *uint16 {
str, err := syscall.UTF16PtrFromString(s)
if err != nil {
Expand Down
18 changes: 8 additions & 10 deletions cng/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,25 @@ import (
"hash"
"math/big"
"runtime"
"sync"
"unsafe"

"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
)

var rsaCache sync.Map

type rsaAlgorithm struct {
h bcrypt.ALG_HANDLE
}

func loadRsa(id string, flags bcrypt.AlgorithmProviderFlags) (h rsaAlgorithm, err error) {
if v, ok := rsaCache.Load(algCacheEntry{id, uint32(flags)}); ok {
func loadRsa() (h rsaAlgorithm, err error) {
const id = bcrypt.RSA_ALGORITHM
if v, ok := algCache.Load(id); ok {
return v.(rsaAlgorithm), nil
}
err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags)
err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, bcrypt.ALG_NONE_FLAG)
if err != nil {
return
}
rsaCache.Store(algCacheEntry{id, uint32(flags)}, h)
algCache.Store(id, h)
return
}

Expand All @@ -42,7 +40,7 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error)
bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) {
return nil, nil, nil, nil, nil, nil, nil, nil, e
}
h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG)
h, err := loadRsa()
if err != nil {
return bad(err)
}
Expand Down Expand Up @@ -95,7 +93,7 @@ type PublicKeyRSA struct {
}

func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) {
h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG)
h, err := loadRsa()
if err != nil {
return nil, err
}
Expand All @@ -122,7 +120,7 @@ func (k *PrivateKeyRSA) finalize() {
}

func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, error) {
h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG)
h, err := loadRsa()
if err != nil {
return nil, err
}
Expand Down
11 changes: 6 additions & 5 deletions cng/sha.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ package cng
import (
"hash"
"runtime"
"sync"

"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
)
Expand All @@ -34,16 +33,18 @@ func NewSHA512() hash.Hash {
return newSHAX(bcrypt.SHA512_ALGORITHM, nil)
}

var shaCache sync.Map

type shaAlgorithm struct {
h bcrypt.ALG_HANDLE
size uint32
blockSize uint32
}

func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (h shaAlgorithm, err error) {
if v, ok := shaCache.Load(algCacheEntry{id, uint32(flags)}); ok {
type entry struct {
id string
flags uint32
}
if v, ok := algCache.Load(entry{id, uint32(flags)}); ok {
return v.(shaAlgorithm), nil
}
err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags)
Expand All @@ -60,7 +61,7 @@ func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (h shaAlgorithm, er
bcrypt.CloseAlgorithmProvider(h.h, 0)
return
}
if existing, loaded := shaCache.LoadOrStore(algCacheEntry{id, uint32(flags)}, h); loaded {
if existing, loaded := algCache.LoadOrStore(algCacheEntry{id, uint32(flags)}, h); loaded {
// We can safely use a provider that has already been cached in another concurrent goroutine.
bcrypt.CloseAlgorithmProvider(h.h, 0)
h = existing.(shaAlgorithm)
Expand Down

0 comments on commit d65eeb2

Please sign in to comment.