From 458718f83d6c53ef612d6718dd6f3a6e4d3e0baa Mon Sep 17 00:00:00 2001 From: Jawad Zaheer Date: Wed, 14 Jun 2023 10:08:49 +0000 Subject: [PATCH] Added tls configuration flags for CAPM3. --- main.go | 116 +++++++++++++++++++++++++++++++++++++++++++++++++++ main_test.go | 111 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+) create mode 100644 main_test.go diff --git a/main.go b/main.go index e7cb1d29d5..2603c87a24 100644 --- a/main.go +++ b/main.go @@ -18,10 +18,12 @@ package main import ( "context" + "crypto/tls" "flag" "fmt" "math/rand" "os" + "strings" "time" bmov1alpha1 "github.com/metal3-io/baremetal-operator/apis/metal3.io/v1alpha1" @@ -39,6 +41,7 @@ import ( _ "k8s.io/client-go/plugin/pkg/client/auth/gcp" "k8s.io/client-go/rest" "k8s.io/client-go/tools/leaderelection/resourcelock" + cliflag "k8s.io/component-base/cli/flag" "k8s.io/component-base/logs" logsv1 "k8s.io/component-base/logs/api/v1" _ "k8s.io/component-base/logs/json/register" @@ -50,6 +53,20 @@ import ( // +kubebuilder:scaffold:imports ) +type TLSVersion string + +// Constants for TLS versions. +const ( + TLSVersion12 TLSVersion = "TLS12" + TLSVersion13 TLSVersion = "TLS13" +) + +type TLSOptions struct { + TLSMaxVersion string + TLSMinVersion string + TLSCipherSuites string +} + var ( myscheme = runtime.NewScheme() setupLog = ctrl.Log.WithName("setup") @@ -76,6 +93,8 @@ var ( watchFilterValue string logOptions = logs.NewOptions() enableBMHNameBasedPreallocation bool + tlsOptions = TLSOptions{} + tlsSupportedVersions = []string{"TLS12", "TLS13"} ) func init() { @@ -105,6 +124,12 @@ func main() { restConfig.QPS = restConfigQPS restConfig.Burst = restConfigBurst restConfig.UserAgent = "cluster-api-provider-metal3-manager" + + tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(tlsOptions) + if err != nil { + setupLog.Error(err, "unable to add TLS settings to the webhook server") + os.Exit(1) + } mgr, err := ctrl.NewManager(restConfig, ctrl.Options{ Scheme: myscheme, MetricsBindAddress: metricsBindAddr, @@ -119,6 +144,7 @@ func main() { CertDir: webhookCertDir, HealthProbeBindAddress: healthAddr, Namespace: watchNamespace, + TLSOpts: tlsOptionOverrides, }) if err != nil { setupLog.Error(err, "unable to start manager") @@ -266,6 +292,23 @@ func initFlags(fs *pflag.FlagSet) { fs.IntVar(&restConfigBurst, "kube-api-burst", 30, "Maximum number of queries that should be allowed in one burst from the controller client to the Kubernetes API server. Default 30") + flag.StringVar(&tlsOptions.TLSMinVersion, "tls-min-version", "TLS12", + "The minimum TLS version in use by the webhook server.\n"+ + fmt.Sprintf("Possible values are %s.", strings.Join(tlsSupportedVersions, ", ")), + ) + + fs.StringVar(&tlsOptions.TLSMaxVersion, "tls-max-version", "TLS13", + "The maximum TLS version in use by the webhook server.\n"+ + fmt.Sprintf("Possible values are %s.", strings.Join(tlsSupportedVersions, ", ")), + ) + + tlsCipherPreferredValues := cliflag.PreferredTLSCipherNames() + tlsCipherInsecureValues := cliflag.InsecureTLSCipherNames() + fs.StringVar(&tlsOptions.TLSCipherSuites, "tls-cipher-suites", "", + "Comma-separated list of cipher suites for the webhook server. "+ + "If omitted, the default Go cipher suites will be used. \n"+ + "Preferred values: "+strings.Join(tlsCipherPreferredValues, ", ")+". \n"+ + "Insecure values: "+strings.Join(tlsCipherInsecureValues, ", ")+".") } func waitForAPIs(cfg *rest.Config) error { @@ -421,3 +464,76 @@ func setupWebhooks(mgr ctrl.Manager) { func concurrency(c int) controller.Options { return controller.Options{MaxConcurrentReconciles: c} } + +// GetTLSOptionOverrideFuncs returns a list of TLS configuration overrides to be used +// by the webhook server. +func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error) { + var tlsOptions []func(config *tls.Config) + + tlsMinVersion, err := GetTLSVersion(options.TLSMinVersion) + if err != nil { + return nil, err + } + + tlsMaxVersion, err := GetTLSVersion(options.TLSMaxVersion) + if err != nil { + return nil, err + } + + if tlsMaxVersion != 0 && tlsMinVersion > tlsMaxVersion { + return nil, fmt.Errorf("TLS version flag min version (%s) is greater than max version (%s)", + options.TLSMinVersion, options.TLSMaxVersion) + } + + tlsOptions = append(tlsOptions, func(cfg *tls.Config) { + cfg.MinVersion = tlsMinVersion + }) + + tlsOptions = append(tlsOptions, func(cfg *tls.Config) { + cfg.MaxVersion = tlsMaxVersion + }) + // Cipher suites should not be set if empty. + if options.TLSMinVersion == string(TLSVersion13) && + options.TLSMaxVersion == string(TLSVersion13) && + options.TLSCipherSuites != "" { + setupLog.Info("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers") + options.TLSCipherSuites = "" + } + + if options.TLSCipherSuites != "" { + tlsCipherSuites := strings.Split(options.TLSCipherSuites, ",") + suites, err := cliflag.TLSCipherSuites(tlsCipherSuites) + if err != nil { + return nil, err + } + + insecureCipherValues := cliflag.InsecureTLSCipherNames() + for _, cipher := range tlsCipherSuites { + for _, insecureCipherName := range insecureCipherValues { + if insecureCipherName == cipher { + setupLog.Info(fmt.Sprintf("warning: use of insecure cipher '%s' detected.", cipher)) + } + } + } + tlsOptions = append(tlsOptions, func(cfg *tls.Config) { + cfg.CipherSuites = suites + }) + } + + return tlsOptions, nil +} + +// GetTLSVersion returns the corresponding tls.Version or error. +func GetTLSVersion(version string) (uint16, error) { + var v uint16 + + switch version { + case string(TLSVersion12): + v = tls.VersionTLS12 + case string(TLSVersion13): + v = tls.VersionTLS13 + default: + return 0, fmt.Errorf("unexpected TLS version %q (must be one of: TLS12, TLS13)", version) + } + return v, nil +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000000..f3a1131112 --- /dev/null +++ b/main_test.go @@ -0,0 +1,111 @@ +/* +Copyright 2023 The Metal3 Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "bytes" + "testing" + + . "github.com/onsi/gomega" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" +) + +func TestTLSInsecureCiperSuite(t *testing.T) { + t.Run("test insecure cipher suite passed as TLS flag", func(t *testing.T) { + g := NewWithT(t) + tlsMockOptions := TLSOptions{ + TLSMaxVersion: "TLS13", + TLSMinVersion: "TLS12", + TLSCipherSuites: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", + } + ctrl.Log.WithName("setup") + ctrl.SetLogger(klog.Background()) + + bufWriter := bytes.NewBuffer(nil) + klog.SetOutput(bufWriter) + klog.LogToStderr(false) // this is important, because klog by default logs to stderr only + _, err := GetTLSOptionOverrideFuncs(tlsMockOptions) + g.Expect(err).Should(BeNil()) + g.Expect(bufWriter.String()).Should(ContainSubstring("use of insecure cipher 'TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256' detected.")) + }) +} + +func TestTLSMinAndMaxVersion(t *testing.T) { + t.Run("should fail if TLS min version is greater than max version.", func(t *testing.T) { + g := NewWithT(t) + tlsMockOptions := TLSOptions{ + TLSMaxVersion: "TLS12", + TLSMinVersion: "TLS13", + } + _, err := GetTLSOptionOverrideFuncs(tlsMockOptions) + g.Expect(err.Error()).To(Equal("TLS version flag min version (TLS13) is greater than max version (TLS12)")) + }) +} + +func Test13CipherSuite(t *testing.T) { + t.Run("should reset ciphersuite flag if TLS min and max version are set to 1.3", func(t *testing.T) { + g := NewWithT(t) + + // Here TLS_RSA_WITH_AES_128_CBC_SHA is a tls12 cipher suite. + tlsMockOptions := TLSOptions{ + TLSMaxVersion: "TLS13", + TLSMinVersion: "TLS13", + TLSCipherSuites: "TLS_RSA_WITH_AES_128_CBC_SHA,TLS_AES_256_GCM_SHA384", + } + + ctrl.Log.WithName("setup") + ctrl.SetLogger(klog.Background()) + + bufWriter := bytes.NewBuffer(nil) + klog.SetOutput(bufWriter) + klog.LogToStderr(false) // this is important, because klog by default logs to stderr only + _, err := GetTLSOptionOverrideFuncs(tlsMockOptions) + g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers")) + g.Expect(err).Should(BeNil()) + }) +} + +func TestGetTLSVersion(t *testing.T) { + t.Run("should error out when incorrect tls version passed", func(t *testing.T) { + g := NewWithT(t) + tlsVersion := "TLS11" + _, err := GetTLSVersion(tlsVersion) + g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)")) + }) + t.Run("should pass and output correct tls version", func(t *testing.T) { + const VersionTLS12 uint16 = 771 + g := NewWithT(t) + tlsVersion := "TLS12" + version, err := GetTLSVersion(tlsVersion) + g.Expect(version).To(Equal(VersionTLS12)) + g.Expect(err).Should(BeNil()) + }) +} + +func TestTLSOptions(t *testing.T) { + t.Run("should pass with all the correct options below with no error.", func(t *testing.T) { + g := NewWithT(t) + tlsMockOptions := TLSOptions{ + TLSMinVersion: "TLS12", + TLSMaxVersion: "TLS13", + TLSCipherSuites: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + } + _, err := GetTLSOptionOverrideFuncs(tlsMockOptions) + g.Expect(err).Should(BeNil()) + }) +}