Skip to content

Commit

Permalink
Refactor cert provider interfaces to simplify them (#1379)
Browse files Browse the repository at this point in the history
* Renamed fields with cached values for clarity

* Remove GetSettings() method from CertProvider interface

* Remove ServerName() and DisableHostVerification() methods from ClientCertProvider interface
  • Loading branch information
sergeybykov authored Mar 16, 2021
1 parent af60aea commit 65a948a
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 66 deletions.
12 changes: 0 additions & 12 deletions common/rpc/encryption/localStoreCertProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ type x509CertFetcher func() (*x509.Certificate, error)
type x509CertPoolFetcher func() (*x509.CertPool, error)
type tlsCertFetcher func() (*tls.Certificate, error)

func (s *localStoreCertProvider) GetSettings() *config.GroupTLS {
return s.tlsSettings
}

func (s *localStoreCertProvider) FetchServerCertificate() (*tls.Certificate, error) {

if s.tlsSettings == nil {
Expand Down Expand Up @@ -258,14 +254,6 @@ func (s *localStoreCertProvider) FetchCertificate(cachedCert **tls.Certificate,
return *cachedCert, nil
}

func (s *localStoreCertProvider) ServerName(isWorker bool) string {
return s.getClientTLSSettings(isWorker).ServerName
}

func (s *localStoreCertProvider) DisableHostVerification(isWorker bool) bool {
return s.getClientTLSSettings(isWorker).DisableHostVerification
}

func (s *localStoreCertProvider) getClientTLSSettings(isWorker bool) *config.ClientTLS {
if isWorker && s.workerTLSSettings != nil {
return &s.workerTLSSettings.Client // explicit system worker case
Expand Down
21 changes: 15 additions & 6 deletions common/rpc/encryption/localStorePerHostCertProviderMap.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ var _ CertExpirationChecker = (*localStorePerHostCertProviderMap)(nil)

type localStorePerHostCertProviderMap struct {
certProviderCache map[string]*localStoreCertProvider
clientAuthCache map[string]bool
}

func newLocalStorePerHostCertProviderMap(overrides map[string]config.ServerTLS) *localStorePerHostCertProviderMap {
Expand All @@ -46,28 +47,36 @@ func newLocalStorePerHostCertProviderMap(overrides map[string]config.ServerTLS)
}

factory.certProviderCache = make(map[string]*localStoreCertProvider, len(overrides))
factory.clientAuthCache = make(map[string]bool, len(overrides))

for host, settings := range overrides {
factory.certProviderCache[strings.ToLower(host)] = &localStoreCertProvider{
lcHost := strings.ToLower(host)
factory.certProviderCache[lcHost] = &localStoreCertProvider{
tlsSettings: &config.GroupTLS{
Server: settings,
},
}
factory.clientAuthCache[lcHost] = settings.RequireClientAuth
}

return factory
}

func (f *localStorePerHostCertProviderMap) GetCertProvider(hostName string) (CertProvider, error) {
// GetCertProvider for a given host name returns a cert provider (nil if not found) and if client authentication is required
func (f *localStorePerHostCertProviderMap) GetCertProvider(hostName string) (CertProvider, bool, error) {

clientAuthRequired := true
lcHostName := strings.ToLower(hostName)

if f.certProviderCache == nil {
return nil, nil
return nil, clientAuthRequired, nil
}
cachedCertProvider, ok := f.certProviderCache[strings.ToLower(hostName)]
cachedCertProvider, ok := f.certProviderCache[lcHostName]
if !ok {
return nil, nil
return nil, clientAuthRequired, nil
}
return cachedCertProvider, nil
clientAuthRequired = f.clientAuthCache[lcHostName]
return cachedCertProvider, clientAuthRequired, nil
}

func (f *localStorePerHostCertProviderMap) GetExpiringCerts(timeWindow time.Duration,
Expand Down
65 changes: 36 additions & 29 deletions common/rpc/encryption/localStoreTlsProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ type localStoreTlsProvider struct {

frontendPerHostCertProviderMap *localStorePerHostCertProviderMap

internodeServerConfig *tls.Config
internodeClientConfig *tls.Config
frontendServerConfig *tls.Config
frontendClientConfig *tls.Config
cachedInternodeServerConfig *tls.Config
cachedInternodeClientConfig *tls.Config
cachedFrontendServerConfig *tls.Config
cachedFrontendClientConfig *tls.Config

ticker *time.Ticker
logger log.Logger
Expand Down Expand Up @@ -117,43 +117,47 @@ func (s *localStoreTlsProvider) Close() {
}

func (s *localStoreTlsProvider) GetInternodeClientConfig() (*tls.Config, error) {

client := &s.settings.Internode.Client
return s.getOrCreateConfig(
&s.internodeClientConfig,
&s.cachedInternodeClientConfig,
func() (*tls.Config, error) {
return newClientTLSConfig(s.internodeClientCertProvider,
s.internodeCertProvider.GetSettings().Server.RequireClientAuth, false)
return newClientTLSConfig(s.internodeClientCertProvider, client.ServerName,
s.settings.Internode.Server.RequireClientAuth, false, !client.DisableHostVerification)
},
s.internodeCertProvider.GetSettings().IsEnabled(),
s.settings.Internode.IsEnabled(),
)
}

func (s *localStoreTlsProvider) GetFrontendClientConfig() (*tls.Config, error) {

client := &s.settings.Frontend.Client
return s.getOrCreateConfig(
&s.frontendClientConfig,
&s.cachedFrontendClientConfig,
func() (*tls.Config, error) {
return newClientTLSConfig(s.workerCertProvider,
s.frontendCertProvider.GetSettings().Server.RequireClientAuth, true)
return newClientTLSConfig(s.workerCertProvider, client.ServerName,
s.settings.Frontend.Server.RequireClientAuth, true, !client.DisableHostVerification)
},
s.internodeCertProvider.GetSettings().IsEnabled(),
s.settings.Internode.IsEnabled(),
)
}

func (s *localStoreTlsProvider) GetFrontendServerConfig() (*tls.Config, error) {
return s.getOrCreateConfig(
&s.frontendServerConfig,
&s.cachedFrontendServerConfig,
func() (*tls.Config, error) {
return newServerTLSConfig(s.frontendCertProvider, s.frontendPerHostCertProviderMap)
return newServerTLSConfig(s.frontendCertProvider, s.frontendPerHostCertProviderMap, &s.settings.Frontend)
},
s.frontendCertProvider.GetSettings().IsEnabled())
s.settings.Frontend.IsEnabled())
}

func (s *localStoreTlsProvider) GetInternodeServerConfig() (*tls.Config, error) {
return s.getOrCreateConfig(
&s.internodeServerConfig,
&s.cachedInternodeServerConfig,
func() (*tls.Config, error) {
return newServerTLSConfig(s.internodeCertProvider, nil)
return newServerTLSConfig(s.internodeCertProvider, nil, &s.settings.Internode)
},
s.internodeCertProvider.GetSettings().IsEnabled())
s.settings.Internode.IsEnabled())
}

func (s *localStoreTlsProvider) GetExpiringCerts(timeWindow time.Duration,
Expand Down Expand Up @@ -225,31 +229,33 @@ func (s *localStoreTlsProvider) getOrCreateConfig(
func newServerTLSConfig(
certProvider CertProvider,
perHostCertProviderMap PerHostCertProviderMap,
config *config.GroupTLS,
) (*tls.Config, error) {

tlsConfig, err := getServerTLSConfigFromCertProvider(certProvider)
clientAuthRequired := config.Server.RequireClientAuth
tlsConfig, err := getServerTLSConfigFromCertProvider(certProvider, clientAuthRequired)
if err != nil {
return nil, err
}

tlsConfig.GetConfigForClient = func(c *tls.ClientHelloInfo) (*tls.Config, error) {
if perHostCertProviderMap != nil {
perHostCertProvider, err := perHostCertProviderMap.GetCertProvider(c.ServerName)
perHostCertProvider, hostClientAuthRequired, err := perHostCertProviderMap.GetCertProvider(c.ServerName)
if err != nil {
return nil, err
}

if perHostCertProvider == nil {
return getServerTLSConfigFromCertProvider(certProvider)
if perHostCertProvider != nil {
return getServerTLSConfigFromCertProvider(perHostCertProvider, hostClientAuthRequired)
}
return getServerTLSConfigFromCertProvider(perHostCertProvider)
return getServerTLSConfigFromCertProvider(certProvider, clientAuthRequired)
}
return getServerTLSConfigFromCertProvider(certProvider)
return getServerTLSConfigFromCertProvider(certProvider, clientAuthRequired)
}
return tlsConfig, nil
}

func getServerTLSConfigFromCertProvider(certProvider CertProvider) (*tls.Config, error) {
func getServerTLSConfigFromCertProvider(certProvider CertProvider, requireClientAuth bool) (*tls.Config, error) {
// Get serverCert from disk
serverCert, err := certProvider.FetchServerCertificate()
if err != nil {
Expand All @@ -266,7 +272,7 @@ func getServerTLSConfigFromCertProvider(certProvider CertProvider) (*tls.Config,
var clientCaPool *x509.CertPool

// If mTLS enabled
if certProvider.GetSettings().Server.RequireClientAuth {
if requireClientAuth {
clientAuthType = tls.RequireAndVerifyClientCert

ca, err := certProvider.FetchClientCAs()
Expand All @@ -282,7 +288,8 @@ func getServerTLSConfigFromCertProvider(certProvider CertProvider) (*tls.Config,
clientCaPool), nil
}

func newClientTLSConfig(clientProvider ClientCertProvider, isAuthRequired bool, isWorker bool) (*tls.Config, error) {
func newClientTLSConfig(clientProvider ClientCertProvider, serverName string, isAuthRequired bool,
isWorker bool, enableHostVerification bool) (*tls.Config, error) {
// Optional ServerCA for client if not already trusted by host
serverCa, err := clientProvider.FetchServerRootCAsForClient(isWorker)
if err != nil {
Expand All @@ -309,8 +316,8 @@ func newClientTLSConfig(clientProvider ClientCertProvider, isAuthRequired bool,
return auth.NewDynamicTLSClientConfig(
getCert,
serverCa,
clientProvider.ServerName(isWorker),
!clientProvider.DisableHostVerification(isWorker),
serverName,
enableHostVerification,
), nil
}

Expand Down
14 changes: 3 additions & 11 deletions common/rpc/encryption/testDynamicCertProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,11 @@ func (t *TestDynamicCertProvider) FetchServerRootCAsForClient(_ bool) (*x509.Cer
return t.caCerts, nil
}

func (t *TestDynamicCertProvider) ServerName(_ bool) string {
return t.serverName
}

func (t *TestDynamicCertProvider) DisableHostVerification(_ bool) bool {
return false
}

func (t *TestDynamicCertProvider) GetCertProvider(hostName string) (CertProvider, error) {
func (t *TestDynamicCertProvider) GetCertProvider(hostName string) (CertProvider, bool, error) {
if hostName == "localhost" {
return t, nil
return t, false, nil
}
return nil, nil
return nil, false, nil
}

func (t *TestDynamicCertProvider) SwitchToWrongServerRootCACerts() {
Expand Down
8 changes: 4 additions & 4 deletions common/rpc/encryption/testDynamicTLSConfigProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,19 @@ type TestDynamicTLSConfigProvider struct {
}

func (t *TestDynamicTLSConfigProvider) GetInternodeServerConfig() (*tls.Config, error) {
return newServerTLSConfig(t.InternodeCertProvider, nil)
return newServerTLSConfig(t.InternodeCertProvider, nil, &t.settings.Internode)
}

func (t *TestDynamicTLSConfigProvider) GetInternodeClientConfig() (*tls.Config, error) {
return newClientTLSConfig(t.InternodeClientCertProvider, true, false)
return newClientTLSConfig(t.InternodeClientCertProvider, t.settings.Internode.Client.ServerName, true, false, true)
}

func (t *TestDynamicTLSConfigProvider) GetFrontendServerConfig() (*tls.Config, error) {
return newServerTLSConfig(t.FrontendCertProvider, t.FrontendPerHostCertProviderMap)
return newServerTLSConfig(t.FrontendCertProvider, t.FrontendPerHostCertProviderMap, &t.settings.Frontend)
}

func (t *TestDynamicTLSConfigProvider) GetFrontendClientConfig() (*tls.Config, error) {
return newClientTLSConfig(t.WorkerCertProvider, true, false)
return newClientTLSConfig(t.WorkerCertProvider, t.settings.Frontend.Client.ServerName, true, false, true)
}

var _ TLSConfigProvider = (*TestDynamicTLSConfigProvider)(nil)
Expand Down
5 changes: 1 addition & 4 deletions common/rpc/encryption/tlsFactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,17 @@ type (
CertProvider interface {
FetchServerCertificate() (*tls.Certificate, error)
FetchClientCAs() (*x509.CertPool, error)
GetSettings() *config.GroupTLS
}

// ClientCertProvider is an interface to load raw TLS/X509 primitives for configuring clients.
ClientCertProvider interface {
FetchClientCertificate(isWorker bool) (*tls.Certificate, error)
FetchServerRootCAsForClient(isWorker bool) (*x509.CertPool, error)
ServerName(isWorker bool) string
DisableHostVerification(isWorker bool) bool
}

// PerHostCertProviderMap returns a CertProvider for a given host name.
PerHostCertProviderMap interface {
GetCertProvider(hostName string) (CertProvider, error)
GetCertProvider(hostName string) (provider CertProvider, clientAuthRequired bool, err error)
}

CertThumbprint [16]byte
Expand Down

0 comments on commit 65a948a

Please sign in to comment.