Skip to content

Commit

Permalink
consolidate algorithms cache
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Mar 17, 2022
1 parent 1c22c15 commit 844a921
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 33 deletions.
25 changes: 11 additions & 14 deletions cng/aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"crypto/cipher"
"errors"
"runtime"
"sync"
"unsafe"

"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
Expand All @@ -19,20 +18,18 @@ import (

const aesBlockSize = 16

var aesCache sync.Map

type aesAlgorithm struct {
h bcrypt.ALG_HANDLE
allowedKeySized []int
}

type aesCacheEntry struct {
id string
mode string
}

func loadAes(id string, mode string) (h aesAlgorithm, err error) {
if v, ok := aesCache.Load(aesCacheEntry{id, mode}); ok {
func loadAes(mode string) (h aesAlgorithm, err error) {
const id = bcrypt.AES_ALGORITHM
type aesCacheEntry struct {
id string
mode string
}
if v, ok := algCache.Load(aesCacheEntry{id, mode}); ok {
return v.(aesAlgorithm), nil
}
err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, bcrypt.ALG_NONE_FLAG)
Expand All @@ -54,7 +51,7 @@ func loadAes(id string, mode string) (h aesAlgorithm, err error) {
for size := info.MinLength; size <= info.MaxLength; size += info.Increment {
h.allowedKeySized = append(h.allowedKeySized, int(size))
}
aesCache.Store(aesCacheEntry{id, mode}, h)
algCache.Store(aesCacheEntry{id, mode}, h)
return
}

Expand All @@ -64,7 +61,7 @@ type aesCipher struct {
}

func NewAESCipher(key []byte) (cipher.Block, error) {
h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_ECB)
h, err := loadAes(bcrypt.CHAIN_MODE_ECB)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -174,7 +171,7 @@ type aesCBC struct {
}

func newCBC(encrypt bool, key, iv []byte) *aesCBC {
h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_CBC)
h, err := loadAes(bcrypt.CHAIN_MODE_CBC)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -248,7 +245,7 @@ func (g *aesGCM) finalize() {
}

func newGCM(key []byte, tls bool) (*aesGCM, error) {
h, err := loadAes(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_GCM)
h, err := loadAes(bcrypt.CHAIN_MODE_GCM)
if err != nil {
return nil, err
}
Expand Down
6 changes: 2 additions & 4 deletions cng/cng.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,14 @@
package cng

import (
"sync"
"syscall"
"unsafe"

"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
)

type algCacheEntry struct {
id string
flags uint32
}
var algCache sync.Map

func utf16PtrFromString(s string) *uint16 {
str, err := syscall.UTF16PtrFromString(s)
Expand Down
18 changes: 8 additions & 10 deletions cng/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,25 @@ import (
"hash"
"math/big"
"runtime"
"sync"
"unsafe"

"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
)

var rsaCache sync.Map

type rsaAlgorithm struct {
h bcrypt.ALG_HANDLE
}

func loadRsa(id string, flags bcrypt.AlgorithmProviderFlags) (h rsaAlgorithm, err error) {
if v, ok := rsaCache.Load(algCacheEntry{id, uint32(flags)}); ok {
func loadRsa() (h rsaAlgorithm, err error) {
const id = bcrypt.RSA_ALGORITHM
if v, ok := algCache.Load(id); ok {
return v.(rsaAlgorithm), nil
}
err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags)
err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, bcrypt.ALG_NONE_FLAG)
if err != nil {
return
}
rsaCache.Store(algCacheEntry{id, uint32(flags)}, h)
algCache.Store(id, h)
return
}

Expand All @@ -42,7 +40,7 @@ func GenerateKeyRSA(bits int) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error)
bad := func(e error) (N, E, D, P, Q, Dp, Dq, Qinv *big.Int, err error) {
return nil, nil, nil, nil, nil, nil, nil, nil, e
}
h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG)
h, err := loadRsa()
if err != nil {
return bad(err)
}
Expand Down Expand Up @@ -95,7 +93,7 @@ type PublicKeyRSA struct {
}

func NewPublicKeyRSA(N, E *big.Int) (*PublicKeyRSA, error) {
h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG)
h, err := loadRsa()
if err != nil {
return nil, err
}
Expand All @@ -122,7 +120,7 @@ func (k *PrivateKeyRSA) finalize() {
}

func NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv *big.Int) (*PrivateKeyRSA, error) {
h, err := loadRsa(bcrypt.RSA_ALGORITHM, bcrypt.ALG_NONE_FLAG)
h, err := loadRsa()
if err != nil {
return nil, err
}
Expand Down
11 changes: 6 additions & 5 deletions cng/sha.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ package cng
import (
"hash"
"runtime"
"sync"

"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
)
Expand All @@ -34,16 +33,18 @@ func NewSHA512() hash.Hash {
return newSHAX(bcrypt.SHA512_ALGORITHM, nil)
}

var shaCache sync.Map

type shaAlgorithm struct {
h bcrypt.ALG_HANDLE
size uint32
blockSize uint32
}

func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (h shaAlgorithm, err error) {
if v, ok := shaCache.Load(algCacheEntry{id, uint32(flags)}); ok {
type entry struct {
id string
flags uint32
}
if v, ok := algCache.Load(entry{id, uint32(flags)}); ok {
return v.(shaAlgorithm), nil
}
err = bcrypt.OpenAlgorithmProvider(&h.h, utf16PtrFromString(id), nil, flags)
Expand All @@ -58,7 +59,7 @@ func loadSha(id string, flags bcrypt.AlgorithmProviderFlags) (h shaAlgorithm, er
if err != nil {
return
}
shaCache.Store(algCacheEntry{id, uint32(flags)}, h)
algCache.Store(entry{id, uint32(flags)}, h)
return
}

Expand Down

0 comments on commit 844a921

Please sign in to comment.