Skip to content

Commit

Permalink
Guard secrets.Store middleware with a mutex (#625)
Browse files Browse the repository at this point in the history
Sometimes we have to call AddMiddleware in a goroutine which can lead to race conditions with other calls to AddMiddleware. Adding a mutex around updating and calling the middleware functions protects us from this.
  • Loading branch information
pacejackson authored Jul 27, 2023
1 parent cb54626 commit 8b80dad
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
38 changes: 31 additions & 7 deletions secrets/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"io/fs"
"os"
"sync"
"time"

"github.com/reddit/baseplate.go/filewatcher"
Expand All @@ -32,7 +33,11 @@ func nopSecretHandlerFunc(sec *Secrets) {}
type Store struct {
watcher filewatcher.FileWatcher

secretHandlerFunc SecretHandlerFunc
// mutex to guard unsafeSecretHandlerFunc
// call handler function using the secretHandlerFunc function rather than
// calling unsafeSecretHandlerFunc directly
mu sync.Mutex
unsafeSecretHandlerFunc SecretHandlerFunc
}

// NewStore returns a new instance of Store by configuring it
Expand All @@ -48,7 +53,7 @@ func NewStore(ctx context.Context, path string, logger log.Wrapper, middlewares
// Used in tests to override FSEventsDelay
func newStore(ctx context.Context, fsEventsDelay time.Duration, path string, logger log.Wrapper, middlewares ...SecretMiddleware) (*Store, error) {
store := &Store{
secretHandlerFunc: nopSecretHandlerFunc,
unsafeSecretHandlerFunc: nopSecretHandlerFunc,
}
store.secretHandler(middlewares...)
fileInfo, err := os.Stat(path)
Expand Down Expand Up @@ -103,11 +108,30 @@ func (s *Store) dirParser(dir fs.FS) (any, error) {

// secretHandler creates the middleware chain.
func (s *Store) secretHandler(middlewares ...SecretMiddleware) {
s.mu.Lock()
defer s.mu.Unlock()

for _, m := range middlewares {
s.secretHandlerFunc = m(s.secretHandlerFunc)
s.unsafeSecretHandlerFunc = m(s.unsafeSecretHandlerFunc)
}
}

// secretHandlerFunc guards calling s.unsafeSecretHandlerFunc with a mutex.
func (s *Store) secretHandlerFunc(sec *Secrets) {
// grab current secret handler func while holding the lock to guard against
// updates to the handler func.
currentSecretHandlerFunc := func() SecretHandlerFunc {
s.mu.Lock()
defer s.mu.Unlock()

return s.unsafeSecretHandlerFunc
}()

// execute the secret handler func outside the lock to not tie it up once
// we safely have a handler func.
currentSecretHandlerFunc(sec)
}

func (s *Store) getSecrets() *Secrets {
return s.watcher.Get().(*Secrets)
}
Expand Down Expand Up @@ -137,17 +161,17 @@ func (s *Store) AddMiddlewares(middlewares ...SecretMiddleware) {
}

// GetSimpleSecret loads secrets from watcher, and fetches a simple secret from secrets
func (s Store) GetSimpleSecret(path string) (SimpleSecret, error) {
func (s *Store) GetSimpleSecret(path string) (SimpleSecret, error) {
return s.getSecrets().GetSimpleSecret(path)
}

// GetVersionedSecret loads secrets from watcher, and fetches a versioned secret from secrets
func (s Store) GetVersionedSecret(path string) (VersionedSecret, error) {
func (s *Store) GetVersionedSecret(path string) (VersionedSecret, error) {
return s.getSecrets().GetVersionedSecret(path)
}

// GetCredentialSecret loads secrets from watcher, and fetches a credential secret from secrets
func (s Store) GetCredentialSecret(path string) (CredentialSecret, error) {
func (s *Store) GetCredentialSecret(path string) (CredentialSecret, error) {
return s.getSecrets().GetCredentialSecret(path)
}

Expand All @@ -156,6 +180,6 @@ func (s Store) GetCredentialSecret(path string) (CredentialSecret, error) {
// role. This is only necessary if talking directly to Vault.
//
// This function always returns nil error.
func (s Store) GetVault() (Vault, error) {
func (s *Store) GetVault() (Vault, error) {
return s.getSecrets().vault, nil
}
2 changes: 1 addition & 1 deletion secrets/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func NewTestSecrets(ctx context.Context, raw map[string]GenericSecret, middlewar
}

store := &Store{
secretHandlerFunc: nopSecretHandlerFunc,
unsafeSecretHandlerFunc: nopSecretHandlerFunc,
}
store.secretHandler(middlewares...)

Expand Down

0 comments on commit 8b80dad

Please sign in to comment.