diff --git a/main.go b/main.go index cac594414..b917d6323 100644 --- a/main.go +++ b/main.go @@ -29,6 +29,7 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/ghodss/yaml" "github.com/golang/glog" @@ -46,6 +47,7 @@ import ( "github.com/brancz/kube-rbac-proxy/pkg/authn" "github.com/brancz/kube-rbac-proxy/pkg/authz" "github.com/brancz/kube-rbac-proxy/pkg/proxy" + rbac_proxy_tls "github.com/brancz/kube-rbac-proxy/pkg/tls" ) type config struct { @@ -60,10 +62,11 @@ type config struct { } type tlsConfig struct { - certFile string - keyFile string - minVersion string - cipherSuites []string + certFile string + keyFile string + minVersion string + cipherSuites []string + reloadInterval time.Duration } type configfile struct { @@ -113,6 +116,7 @@ func main() { flagset.StringVar(&cfg.tls.keyFile, "tls-private-key-file", "", "File containing the default x509 private key matching --tls-cert-file.") flagset.StringVar(&cfg.tls.minVersion, "tls-min-version", "VersionTLS12", "Minimum TLS version supported. Value must match version names from https://golang.org/pkg/crypto/tls/#pkg-constants.") flagset.StringSliceVar(&cfg.tls.cipherSuites, "tls-cipher-suites", nil, "Comma-separated list of cipher suites for the server. Values are from tls package constants (https://golang.org/pkg/crypto/tls/#pkg-constants). If omitted, the default Go cipher suites will be used") + flagset.DurationVar(&cfg.tls.reloadInterval, "tls-reload-interval", time.Minute, "The interval at which to watch for TLS certificate changes, by default set to 1 minute.") // Auth flags flagset.StringVar(&cfg.auth.Authentication.X509.ClientCAFile, "client-ca-file", "", "If set, any request presenting a client certificate signed by one of the authorities in the client-ca-file is authenticated with an identity corresponding to the CommonName of the client certificate.") @@ -233,6 +237,21 @@ func main() { } srv.TLSConfig.Certificates = []tls.Certificate{cert} + } else { + glog.Info("Reading certificate files") + ctx, cancel := context.WithCancel(context.Background()) + r, err := rbac_proxy_tls.NewCertReloader(cfg.tls.certFile, cfg.tls.keyFile, cfg.tls.reloadInterval) + if err != nil { + glog.Fatalf("Failed to initialize certificate reloader: %v", err) + } + + srv.TLSConfig.GetCertificate = r.GetCertificate + + gr.Add(func() error { + return r.Watch(ctx) + }, func(error) { + cancel() + }) } version, err := tlsVersion(cfg.tls.minVersion) @@ -260,7 +279,8 @@ func main() { gr.Add(func() error { glog.Infof("Listening securely on %v", cfg.secureListenAddress) - return srv.ServeTLS(l, cfg.tls.certFile, cfg.tls.keyFile) + tlsListener := tls.NewListener(l, srv.TLSConfig) + return srv.Serve(tlsListener) }, func(err error) { if err := srv.Shutdown(context.Background()); err != nil { glog.Errorf("failed to gracefully shutdown server: %v", err) diff --git a/pkg/tls/reloader.go b/pkg/tls/reloader.go new file mode 100644 index 000000000..6b335ca26 --- /dev/null +++ b/pkg/tls/reloader.go @@ -0,0 +1,125 @@ +/* +Copyright 2017 Frederic Branczyk All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tls + +import ( + "context" + "crypto/tls" + "fmt" + "hash/fnv" + "io/ioutil" + "sync" + "time" + + "github.com/golang/glog" +) + +// CertReloader is the struct that parses a certificate/key pair, +// providing a goroutine safe GetCertificate method to retrieve the parsed content. +// +// The GetCertificate signature is compatible with https://golang.org/pkg/crypto/tls/#Config.GetCertificate +// and can be used to hot-reload a certificate/key pair. +// +// For hot-reloading the Watch method must be started explicitly. +type CertReloader struct { + certPath, keyPath string + interval time.Duration + + mu sync.RWMutex // protects the fields below + cert *tls.Certificate + hash uint64 +} + +func NewCertReloader(certPath, keyPath string, interval time.Duration) (*CertReloader, error) { + r := &CertReloader{ + certPath: certPath, + keyPath: keyPath, + interval: interval, + } + + if err := r.reload(); err != nil { + return nil, fmt.Errorf("error loading certificates: %v", err) + } + + return r, nil +} + +// Watch watches the configured certificate and key path and blocks the current goroutine +// until the scenario context is done or an error occurred during reloading. +func (r *CertReloader) Watch(ctx context.Context) error { + t := time.NewTicker(r.interval) + + for { + select { + case <-t.C: + case <-ctx.Done(): + return nil + } + + if err := r.reload(); err != nil { + return fmt.Errorf("reloading failed: %v", err) + } + } +} + +func (r *CertReloader) reload() error { + r.mu.Lock() + defer r.mu.Unlock() + + var err error + + crt, err := ioutil.ReadFile(r.certPath) + key, err := ioutil.ReadFile(r.keyPath) + + if err != nil { + return fmt.Errorf("error loading certificate: %v", err) + } + + h := fnv.New64a() + _, err = h.Write(crt) + _, err = h.Write(key) + + if err != nil { + return fmt.Errorf("error hashing certificate: %v", err) + } + + newHash := h.Sum64() + if newHash == r.hash { + return nil + } + + cert, err := tls.X509KeyPair(crt, key) + if err != nil { + return fmt.Errorf("error parsing certificate: %v", err) + } + + glog.V(4).Info("reloading key ", r.keyPath, " certificate ", r.certPath) + + r.cert = &cert + r.hash = newHash + return nil +} + +// GetCertificate returns the current valid certificate. +// The ClientHello message is ignored +// and is just there to be compatible with https://golang.org/pkg/crypto/tls/#Config.GetCertificate. +func (r *CertReloader) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.cert, nil +} diff --git a/pkg/tls/reloader_test.go b/pkg/tls/reloader_test.go new file mode 100644 index 000000000..2961ad304 --- /dev/null +++ b/pkg/tls/reloader_test.go @@ -0,0 +1,345 @@ +/* +Copyright 2017 Frederic Branczyk All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tls + +import ( + "context" + "crypto/x509" + "flag" + "fmt" + "io" + "io/ioutil" + "log" + "os" + "path" + "strings" + "testing" + "time" + + "github.com/golang/glog" + "k8s.io/apimachinery/pkg/util/wait" + + certutil "k8s.io/client-go/util/cert" +) + +func TestReloader(t *testing.T) { + cases := []struct { + name string + given stepFunc + check checkFunc + }{ + { + name: "match cn", + given: steps( + newSelfSignedCert("foo"), + newCertReloader, + ), + check: commonNameIs("foo"), + }, + { + name: "change", + given: steps( + newSelfSignedCert("foo"), + newCertReloader, + startWatching, + newSelfSignedCert("baz"), + swapCert, + ), + check: commonNameIs("baz"), + }, + { + name: "double symlink", + given: steps( + newSelfSignedCert("foo"), + doubleSymlinkCert, + newCertReloader, + startWatching, + newSelfSignedCert("bar"), + swapSymlink, + ), + check: commonNameIs("bar"), + }, + { + name: "swap double symlink twice", + given: steps( + newSelfSignedCert("foo"), + doubleSymlinkCert, + newCertReloader, + startWatching, + newSelfSignedCert("bar"), + swapSymlink, + newSelfSignedCert("baz"), + swapSymlink, + ), + check: commonNameIs("baz"), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + s := &scenario{} + + tc.given(t, s) + + if err := tc.check(s); err != nil { + t.Error(err) + } + + for _, cleanup := range s.cleanups { + cleanup() + } + }) + } +} + +func TestMain(m *testing.M) { + var err error + err = flag.Set("alsologtostderr", "true") + err = flag.Set("v", "5") + if err != nil { + log.Fatal(err) + } + + flag.Parse() + os.Exit(m.Run()) +} + +type scenario struct { + certPath, keyPath string + reloader *CertReloader + cleanups []func() +} + +type stepFunc func(*testing.T, *scenario) + +type checkFunc func(*scenario) error + +func commonNameIs(want string) checkFunc { + return func(g *scenario) error { + return poll(10*time.Millisecond, 100*time.Millisecond, func() (err error) { + cert, err := g.reloader.GetCertificate(nil) + if err != nil { + return fmt.Errorf("error getting certificate: %v", err) + } + + first, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fmt.Errorf("error parsing certificate: %v", err) + } + + if !strings.HasPrefix(first.Subject.CommonName, want) { + return fmt.Errorf("want subject common name to start with %q, got %q", want, first.Subject.CommonName) + } + + return nil + }) + } +} + +func newCertReloader(t *testing.T, s *scenario) { + r, err := NewCertReloader(s.certPath, s.keyPath, 10*time.Millisecond) + if err != nil { + t.Fatalf("error creating cert reloader: %v", err) + } + s.reloader = r +} + +func startWatching(t *testing.T, s *scenario) { + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error) + + go func() { + done <- s.reloader.Watch(ctx) + }() + + cleanup := func() { + cancel() + + if err := <-done; err != nil { + t.Fatal(err) + } + } + + s.cleanups = append([]func(){cleanup}, s.cleanups...) +} + +func newSelfSignedCert(hostname string) stepFunc { + return func(t *testing.T, s *scenario) { + var err error + certBytes, keyBytes, err := certutil.GenerateSelfSignedCertKey(hostname, nil, nil) + if err != nil { + t.Fatalf("generation of self signed cert and key failed: %v", err) + } + + certPath, err := writeTempFile("cert", certBytes) + keyPath, err := writeTempFile("key", keyBytes) + if err != nil { + t.Fatalf("error writing cert/key data: %v", err) + } + + s.certPath = certPath + s.keyPath = keyPath + + s.cleanups = append(s.cleanups, func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }) + } +} + +func doubleSymlinkCert(t *testing.T, s *scenario) { + name, err := ioutil.TempDir("", "keys") + if err != nil { + t.Fatal(err) + } + + keyPath := path.Join(name, "key") + if err := os.Rename(s.keyPath, keyPath); err != nil { + t.Fatal(err) + } + + certPath := path.Join(name, "cert") + if err := os.Rename(s.certPath, certPath); err != nil { + t.Fatal(err) + } + + keysdir := path.Join(os.TempDir(), "keys") + if err := os.Symlink(name, keysdir); err != nil { + t.Fatal(err) + } + + keyLink := path.Join(os.TempDir(), "key") + _ = os.Symlink(path.Join(keysdir, "key"), keyLink) + + certLink := path.Join(os.TempDir(), "cert") + _ = os.Symlink(path.Join(keysdir, "cert"), certLink) + + s.keyPath = keyLink + s.certPath = certLink + + s.cleanups = append(s.cleanups, func() { + _ = os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyLink) + _ = os.Remove(certLink) + _ = os.Remove(keysdir) + _ = os.RemoveAll(name) + }) +} + +func swapCert(t *testing.T, s *scenario) { + t.Log("renaming", s.keyPath, "to", s.reloader.keyPath) + if err := os.Rename(s.certPath, s.reloader.certPath); err != nil { + t.Fatal(err) + } + + if err := os.Rename(s.keyPath, s.reloader.keyPath); err != nil { + t.Fatal(err) + } + + s.certPath = s.reloader.certPath + s.keyPath = s.reloader.keyPath +} + +func swapSymlink(t *testing.T, s *scenario) { + name, err := ioutil.TempDir("", "keys") + if err != nil { + t.Fatal(err) + } + + keyPath := path.Join(name, "key") + if err := os.Rename(s.keyPath, keyPath); err != nil { + t.Fatal(err) + } + + certPath := path.Join(name, "cert") + if err := os.Rename(s.certPath, certPath); err != nil { + t.Fatal(err) + } + + tmp := path.Join(os.TempDir(), "keys.tmp") + if err := os.Symlink(name, tmp); err != nil { + t.Fatal(err) + } + + keysdir := path.Join(os.TempDir(), "keys") + if err := os.Rename(tmp, keysdir); err != nil { + t.Fatal(err) + } + + s.keyPath = path.Join(os.TempDir(), "key") + s.certPath = path.Join(os.TempDir(), "cert") + + s.cleanups = append(s.cleanups, func() { + _ = os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keysdir) + _ = os.RemoveAll(name) + }) +} + +func steps(gs ...stepFunc) stepFunc { + return func(t *testing.T, g *scenario) { + for _, gf := range gs { + gf(t, g) + } + } +} + +func writeTempFile(pattern string, data []byte) (string, error) { + f, err := ioutil.TempFile("", pattern) + if err != nil { + return "", fmt.Errorf("error creating temp file: %v", err) + } + defer f.Close() + + n, err := f.Write(data) + if err == nil && n < len(data) { + err = io.ErrShortWrite + } + + if err != nil { + return "", fmt.Errorf("error writing temporary file: %v", err) + } + + return f.Name(), nil +} + +// poll calls the scenario function f every scenario interval +// until it returns no error or the scenario timeout occurs. +// If a timeout occurs, the last observed error is returned +// or wait.ErrWaitTimeout if no error occurred. +func poll(interval, timeout time.Duration, f func() error) error { + var lastErr error + + err := wait.Poll(interval, timeout, func() (bool, error) { + lastErr = f() + + if lastErr != nil { + glog.V(4).Infof("error loading certificate: %v, retrying ...", lastErr) + return false, nil + } + + return true, nil + }) + + if err != nil && err == wait.ErrWaitTimeout && lastErr != nil { + err = fmt.Errorf("%v: %v", err, lastErr) + } + + return err +}