From 65a948a9d4001f864ccda74211b53e7bab435dff Mon Sep 17 00:00:00 2001 From: Sergey Bykov <8248806+sergeybykov@users.noreply.github.com> Date: Tue, 16 Mar 2021 09:12:16 -0700 Subject: [PATCH] Refactor cert provider interfaces to simplify them (#1379) * Renamed fields with cached values for clarity * Remove GetSettings() method from CertProvider interface * Remove ServerName() and DisableHostVerification() methods from ClientCertProvider interface --- .../rpc/encryption/localStoreCertProvider.go | 12 ---- .../localStorePerHostCertProviderMap.go | 21 ++++-- .../rpc/encryption/localStoreTlsProvider.go | 65 ++++++++++--------- .../rpc/encryption/testDynamicCertProvider.go | 14 +--- .../testDynamicTLSConfigProvider.go | 8 +-- common/rpc/encryption/tlsFactory.go | 5 +- 6 files changed, 59 insertions(+), 66 deletions(-) diff --git a/common/rpc/encryption/localStoreCertProvider.go b/common/rpc/encryption/localStoreCertProvider.go index 5281505be7e..40e17e1ab12 100644 --- a/common/rpc/encryption/localStoreCertProvider.go +++ b/common/rpc/encryption/localStoreCertProvider.go @@ -70,10 +70,6 @@ type x509CertFetcher func() (*x509.Certificate, error) type x509CertPoolFetcher func() (*x509.CertPool, error) type tlsCertFetcher func() (*tls.Certificate, error) -func (s *localStoreCertProvider) GetSettings() *config.GroupTLS { - return s.tlsSettings -} - func (s *localStoreCertProvider) FetchServerCertificate() (*tls.Certificate, error) { if s.tlsSettings == nil { @@ -258,14 +254,6 @@ func (s *localStoreCertProvider) FetchCertificate(cachedCert **tls.Certificate, return *cachedCert, nil } -func (s *localStoreCertProvider) ServerName(isWorker bool) string { - return s.getClientTLSSettings(isWorker).ServerName -} - -func (s *localStoreCertProvider) DisableHostVerification(isWorker bool) bool { - return s.getClientTLSSettings(isWorker).DisableHostVerification -} - func (s *localStoreCertProvider) getClientTLSSettings(isWorker bool) *config.ClientTLS { if isWorker && s.workerTLSSettings != nil { return &s.workerTLSSettings.Client // explicit system worker case diff --git a/common/rpc/encryption/localStorePerHostCertProviderMap.go b/common/rpc/encryption/localStorePerHostCertProviderMap.go index f586ffe6a4f..a0ca60aa01a 100644 --- a/common/rpc/encryption/localStorePerHostCertProviderMap.go +++ b/common/rpc/encryption/localStorePerHostCertProviderMap.go @@ -36,6 +36,7 @@ var _ CertExpirationChecker = (*localStorePerHostCertProviderMap)(nil) type localStorePerHostCertProviderMap struct { certProviderCache map[string]*localStoreCertProvider + clientAuthCache map[string]bool } func newLocalStorePerHostCertProviderMap(overrides map[string]config.ServerTLS) *localStorePerHostCertProviderMap { @@ -46,28 +47,36 @@ func newLocalStorePerHostCertProviderMap(overrides map[string]config.ServerTLS) } factory.certProviderCache = make(map[string]*localStoreCertProvider, len(overrides)) + factory.clientAuthCache = make(map[string]bool, len(overrides)) for host, settings := range overrides { - factory.certProviderCache[strings.ToLower(host)] = &localStoreCertProvider{ + lcHost := strings.ToLower(host) + factory.certProviderCache[lcHost] = &localStoreCertProvider{ tlsSettings: &config.GroupTLS{ Server: settings, }, } + factory.clientAuthCache[lcHost] = settings.RequireClientAuth } return factory } -func (f *localStorePerHostCertProviderMap) GetCertProvider(hostName string) (CertProvider, error) { +// GetCertProvider for a given host name returns a cert provider (nil if not found) and if client authentication is required +func (f *localStorePerHostCertProviderMap) GetCertProvider(hostName string) (CertProvider, bool, error) { + + clientAuthRequired := true + lcHostName := strings.ToLower(hostName) if f.certProviderCache == nil { - return nil, nil + return nil, clientAuthRequired, nil } - cachedCertProvider, ok := f.certProviderCache[strings.ToLower(hostName)] + cachedCertProvider, ok := f.certProviderCache[lcHostName] if !ok { - return nil, nil + return nil, clientAuthRequired, nil } - return cachedCertProvider, nil + clientAuthRequired = f.clientAuthCache[lcHostName] + return cachedCertProvider, clientAuthRequired, nil } func (f *localStorePerHostCertProviderMap) GetExpiringCerts(timeWindow time.Duration, diff --git a/common/rpc/encryption/localStoreTlsProvider.go b/common/rpc/encryption/localStoreTlsProvider.go index 2aaabb15119..ba5c96e0753 100644 --- a/common/rpc/encryption/localStoreTlsProvider.go +++ b/common/rpc/encryption/localStoreTlsProvider.go @@ -55,10 +55,10 @@ type localStoreTlsProvider struct { frontendPerHostCertProviderMap *localStorePerHostCertProviderMap - internodeServerConfig *tls.Config - internodeClientConfig *tls.Config - frontendServerConfig *tls.Config - frontendClientConfig *tls.Config + cachedInternodeServerConfig *tls.Config + cachedInternodeClientConfig *tls.Config + cachedFrontendServerConfig *tls.Config + cachedFrontendClientConfig *tls.Config ticker *time.Ticker logger log.Logger @@ -117,43 +117,47 @@ func (s *localStoreTlsProvider) Close() { } func (s *localStoreTlsProvider) GetInternodeClientConfig() (*tls.Config, error) { + + client := &s.settings.Internode.Client return s.getOrCreateConfig( - &s.internodeClientConfig, + &s.cachedInternodeClientConfig, func() (*tls.Config, error) { - return newClientTLSConfig(s.internodeClientCertProvider, - s.internodeCertProvider.GetSettings().Server.RequireClientAuth, false) + return newClientTLSConfig(s.internodeClientCertProvider, client.ServerName, + s.settings.Internode.Server.RequireClientAuth, false, !client.DisableHostVerification) }, - s.internodeCertProvider.GetSettings().IsEnabled(), + s.settings.Internode.IsEnabled(), ) } func (s *localStoreTlsProvider) GetFrontendClientConfig() (*tls.Config, error) { + + client := &s.settings.Frontend.Client return s.getOrCreateConfig( - &s.frontendClientConfig, + &s.cachedFrontendClientConfig, func() (*tls.Config, error) { - return newClientTLSConfig(s.workerCertProvider, - s.frontendCertProvider.GetSettings().Server.RequireClientAuth, true) + return newClientTLSConfig(s.workerCertProvider, client.ServerName, + s.settings.Frontend.Server.RequireClientAuth, true, !client.DisableHostVerification) }, - s.internodeCertProvider.GetSettings().IsEnabled(), + s.settings.Internode.IsEnabled(), ) } func (s *localStoreTlsProvider) GetFrontendServerConfig() (*tls.Config, error) { return s.getOrCreateConfig( - &s.frontendServerConfig, + &s.cachedFrontendServerConfig, func() (*tls.Config, error) { - return newServerTLSConfig(s.frontendCertProvider, s.frontendPerHostCertProviderMap) + return newServerTLSConfig(s.frontendCertProvider, s.frontendPerHostCertProviderMap, &s.settings.Frontend) }, - s.frontendCertProvider.GetSettings().IsEnabled()) + s.settings.Frontend.IsEnabled()) } func (s *localStoreTlsProvider) GetInternodeServerConfig() (*tls.Config, error) { return s.getOrCreateConfig( - &s.internodeServerConfig, + &s.cachedInternodeServerConfig, func() (*tls.Config, error) { - return newServerTLSConfig(s.internodeCertProvider, nil) + return newServerTLSConfig(s.internodeCertProvider, nil, &s.settings.Internode) }, - s.internodeCertProvider.GetSettings().IsEnabled()) + s.settings.Internode.IsEnabled()) } func (s *localStoreTlsProvider) GetExpiringCerts(timeWindow time.Duration, @@ -225,31 +229,33 @@ func (s *localStoreTlsProvider) getOrCreateConfig( func newServerTLSConfig( certProvider CertProvider, perHostCertProviderMap PerHostCertProviderMap, + config *config.GroupTLS, ) (*tls.Config, error) { - tlsConfig, err := getServerTLSConfigFromCertProvider(certProvider) + clientAuthRequired := config.Server.RequireClientAuth + tlsConfig, err := getServerTLSConfigFromCertProvider(certProvider, clientAuthRequired) if err != nil { return nil, err } tlsConfig.GetConfigForClient = func(c *tls.ClientHelloInfo) (*tls.Config, error) { if perHostCertProviderMap != nil { - perHostCertProvider, err := perHostCertProviderMap.GetCertProvider(c.ServerName) + perHostCertProvider, hostClientAuthRequired, err := perHostCertProviderMap.GetCertProvider(c.ServerName) if err != nil { return nil, err } - if perHostCertProvider == nil { - return getServerTLSConfigFromCertProvider(certProvider) + if perHostCertProvider != nil { + return getServerTLSConfigFromCertProvider(perHostCertProvider, hostClientAuthRequired) } - return getServerTLSConfigFromCertProvider(perHostCertProvider) + return getServerTLSConfigFromCertProvider(certProvider, clientAuthRequired) } - return getServerTLSConfigFromCertProvider(certProvider) + return getServerTLSConfigFromCertProvider(certProvider, clientAuthRequired) } return tlsConfig, nil } -func getServerTLSConfigFromCertProvider(certProvider CertProvider) (*tls.Config, error) { +func getServerTLSConfigFromCertProvider(certProvider CertProvider, requireClientAuth bool) (*tls.Config, error) { // Get serverCert from disk serverCert, err := certProvider.FetchServerCertificate() if err != nil { @@ -266,7 +272,7 @@ func getServerTLSConfigFromCertProvider(certProvider CertProvider) (*tls.Config, var clientCaPool *x509.CertPool // If mTLS enabled - if certProvider.GetSettings().Server.RequireClientAuth { + if requireClientAuth { clientAuthType = tls.RequireAndVerifyClientCert ca, err := certProvider.FetchClientCAs() @@ -282,7 +288,8 @@ func getServerTLSConfigFromCertProvider(certProvider CertProvider) (*tls.Config, clientCaPool), nil } -func newClientTLSConfig(clientProvider ClientCertProvider, isAuthRequired bool, isWorker bool) (*tls.Config, error) { +func newClientTLSConfig(clientProvider ClientCertProvider, serverName string, isAuthRequired bool, + isWorker bool, enableHostVerification bool) (*tls.Config, error) { // Optional ServerCA for client if not already trusted by host serverCa, err := clientProvider.FetchServerRootCAsForClient(isWorker) if err != nil { @@ -309,8 +316,8 @@ func newClientTLSConfig(clientProvider ClientCertProvider, isAuthRequired bool, return auth.NewDynamicTLSClientConfig( getCert, serverCa, - clientProvider.ServerName(isWorker), - !clientProvider.DisableHostVerification(isWorker), + serverName, + enableHostVerification, ), nil } diff --git a/common/rpc/encryption/testDynamicCertProvider.go b/common/rpc/encryption/testDynamicCertProvider.go index 1aadcf28770..ab250039cf2 100644 --- a/common/rpc/encryption/testDynamicCertProvider.go +++ b/common/rpc/encryption/testDynamicCertProvider.go @@ -82,19 +82,11 @@ func (t *TestDynamicCertProvider) FetchServerRootCAsForClient(_ bool) (*x509.Cer return t.caCerts, nil } -func (t *TestDynamicCertProvider) ServerName(_ bool) string { - return t.serverName -} - -func (t *TestDynamicCertProvider) DisableHostVerification(_ bool) bool { - return false -} - -func (t *TestDynamicCertProvider) GetCertProvider(hostName string) (CertProvider, error) { +func (t *TestDynamicCertProvider) GetCertProvider(hostName string) (CertProvider, bool, error) { if hostName == "localhost" { - return t, nil + return t, false, nil } - return nil, nil + return nil, false, nil } func (t *TestDynamicCertProvider) SwitchToWrongServerRootCACerts() { diff --git a/common/rpc/encryption/testDynamicTLSConfigProvider.go b/common/rpc/encryption/testDynamicTLSConfigProvider.go index f1533c79eae..2563b3adf7d 100644 --- a/common/rpc/encryption/testDynamicTLSConfigProvider.go +++ b/common/rpc/encryption/testDynamicTLSConfigProvider.go @@ -49,19 +49,19 @@ type TestDynamicTLSConfigProvider struct { } func (t *TestDynamicTLSConfigProvider) GetInternodeServerConfig() (*tls.Config, error) { - return newServerTLSConfig(t.InternodeCertProvider, nil) + return newServerTLSConfig(t.InternodeCertProvider, nil, &t.settings.Internode) } func (t *TestDynamicTLSConfigProvider) GetInternodeClientConfig() (*tls.Config, error) { - return newClientTLSConfig(t.InternodeClientCertProvider, true, false) + return newClientTLSConfig(t.InternodeClientCertProvider, t.settings.Internode.Client.ServerName, true, false, true) } func (t *TestDynamicTLSConfigProvider) GetFrontendServerConfig() (*tls.Config, error) { - return newServerTLSConfig(t.FrontendCertProvider, t.FrontendPerHostCertProviderMap) + return newServerTLSConfig(t.FrontendCertProvider, t.FrontendPerHostCertProviderMap, &t.settings.Frontend) } func (t *TestDynamicTLSConfigProvider) GetFrontendClientConfig() (*tls.Config, error) { - return newClientTLSConfig(t.WorkerCertProvider, true, false) + return newClientTLSConfig(t.WorkerCertProvider, t.settings.Frontend.Client.ServerName, true, false, true) } var _ TLSConfigProvider = (*TestDynamicTLSConfigProvider)(nil) diff --git a/common/rpc/encryption/tlsFactory.go b/common/rpc/encryption/tlsFactory.go index 0283e70e698..5fc2b511820 100644 --- a/common/rpc/encryption/tlsFactory.go +++ b/common/rpc/encryption/tlsFactory.go @@ -47,20 +47,17 @@ type ( CertProvider interface { FetchServerCertificate() (*tls.Certificate, error) FetchClientCAs() (*x509.CertPool, error) - GetSettings() *config.GroupTLS } // ClientCertProvider is an interface to load raw TLS/X509 primitives for configuring clients. ClientCertProvider interface { FetchClientCertificate(isWorker bool) (*tls.Certificate, error) FetchServerRootCAsForClient(isWorker bool) (*x509.CertPool, error) - ServerName(isWorker bool) string - DisableHostVerification(isWorker bool) bool } // PerHostCertProviderMap returns a CertProvider for a given host name. PerHostCertProviderMap interface { - GetCertProvider(hostName string) (CertProvider, error) + GetCertProvider(hostName string) (provider CertProvider, clientAuthRequired bool, err error) } CertThumbprint [16]byte