Skip to content

Commit

Permalink
*: watch tls certificate changes
Browse files Browse the repository at this point in the history
  • Loading branch information
s-urbaniak committed Jun 20, 2019
1 parent 03224ae commit c1502bb
Show file tree
Hide file tree
Showing 3 changed files with 495 additions and 5 deletions.
30 changes: 25 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"os"
"os/signal"
"syscall"
"time"

"github.com/ghodss/yaml"
"github.com/golang/glog"
Expand All @@ -46,6 +47,7 @@ import (
"github.com/brancz/kube-rbac-proxy/pkg/authn"
"github.com/brancz/kube-rbac-proxy/pkg/authz"
"github.com/brancz/kube-rbac-proxy/pkg/proxy"
rbac_proxy_tls "github.com/brancz/kube-rbac-proxy/pkg/tls"
)

type config struct {
Expand All @@ -60,10 +62,11 @@ type config struct {
}

type tlsConfig struct {
certFile string
keyFile string
minVersion string
cipherSuites []string
certFile string
keyFile string
minVersion string
cipherSuites []string
reloadInterval time.Duration
}

type configfile struct {
Expand Down Expand Up @@ -113,6 +116,7 @@ func main() {
flagset.StringVar(&cfg.tls.keyFile, "tls-private-key-file", "", "File containing the default x509 private key matching --tls-cert-file.")
flagset.StringVar(&cfg.tls.minVersion, "tls-min-version", "VersionTLS12", "Minimum TLS version supported. Value must match version names from https://golang.org/pkg/crypto/tls/#pkg-constants.")
flagset.StringSliceVar(&cfg.tls.cipherSuites, "tls-cipher-suites", nil, "Comma-separated list of cipher suites for the server. Values are from tls package constants (https://golang.org/pkg/crypto/tls/#pkg-constants). If omitted, the default Go cipher suites will be used")
flagset.DurationVar(&cfg.tls.reloadInterval, "tls-reload-interval", time.Minute, "The interval at which to watch for TLS certificate changes, by default set to 1 minute.")

// Auth flags
flagset.StringVar(&cfg.auth.Authentication.X509.ClientCAFile, "client-ca-file", "", "If set, any request presenting a client certificate signed by one of the authorities in the client-ca-file is authenticated with an identity corresponding to the CommonName of the client certificate.")
Expand Down Expand Up @@ -233,6 +237,21 @@ func main() {
}

srv.TLSConfig.Certificates = []tls.Certificate{cert}
} else {
glog.Info("Reading certificate files")
ctx, cancel := context.WithCancel(context.Background())
r, err := rbac_proxy_tls.NewCertReloader(cfg.tls.certFile, cfg.tls.keyFile, cfg.tls.reloadInterval)
if err != nil {
glog.Fatalf("Failed to initialize certificate reloader: %v", err)
}

srv.TLSConfig.GetCertificate = r.GetCertificate

gr.Add(func() error {
return r.Watch(ctx)
}, func(error) {
cancel()
})
}

version, err := tlsVersion(cfg.tls.minVersion)
Expand Down Expand Up @@ -260,7 +279,8 @@ func main() {

gr.Add(func() error {
glog.Infof("Listening securely on %v", cfg.secureListenAddress)
return srv.ServeTLS(l, cfg.tls.certFile, cfg.tls.keyFile)
tlsListener := tls.NewListener(l, srv.TLSConfig)
return srv.Serve(tlsListener)
}, func(err error) {
if err := srv.Shutdown(context.Background()); err != nil {
glog.Errorf("failed to gracefully shutdown server: %v", err)
Expand Down
125 changes: 125 additions & 0 deletions pkg/tls/reloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
Copyright 2017 Frederic Branczyk All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package tls

import (
"context"
"crypto/tls"
"fmt"
"hash/fnv"
"io/ioutil"
"sync"
"time"

"github.com/golang/glog"
)

// CertReloader is the struct that parses a certificate/key pair,
// providing a goroutine safe GetCertificate method to retrieve the parsed content.
//
// The GetCertificate signature is compatible with https://golang.org/pkg/crypto/tls/#Config.GetCertificate
// and can be used to hot-reload a certificate/key pair.
//
// For hot-reloading the Watch method must be started explicitly.
type CertReloader struct {
certPath, keyPath string
interval time.Duration

mu sync.RWMutex // protects the fields below
cert *tls.Certificate
hash uint64
}

func NewCertReloader(certPath, keyPath string, interval time.Duration) (*CertReloader, error) {
r := &CertReloader{
certPath: certPath,
keyPath: keyPath,
interval: interval,
}

if err := r.reload(); err != nil {
return nil, fmt.Errorf("error loading certificates: %v", err)
}

return r, nil
}

// Watch watches the configured certificate and key path and blocks the current goroutine
// until the scenario context is done or an error occurred during reloading.
func (r *CertReloader) Watch(ctx context.Context) error {
t := time.NewTicker(r.interval)

for {
select {
case <-t.C:
case <-ctx.Done():
return nil
}

if err := r.reload(); err != nil {
return fmt.Errorf("reloading failed: %v", err)
}
}
}

func (r *CertReloader) reload() error {
r.mu.Lock()
defer r.mu.Unlock()

var err error

crt, err := ioutil.ReadFile(r.certPath)
key, err := ioutil.ReadFile(r.keyPath)

if err != nil {
return fmt.Errorf("error loading certificate: %v", err)
}

h := fnv.New64a()
_, err = h.Write(crt)
_, err = h.Write(key)

if err != nil {
return fmt.Errorf("error hashing certificate: %v", err)
}

newHash := h.Sum64()
if newHash == r.hash {
return nil
}

cert, err := tls.X509KeyPair(crt, key)
if err != nil {
return fmt.Errorf("error parsing certificate: %v", err)
}

glog.V(4).Info("reloading key ", r.keyPath, " certificate ", r.certPath)

r.cert = &cert
r.hash = newHash
return nil
}

// GetCertificate returns the current valid certificate.
// The ClientHello message is ignored
// and is just there to be compatible with https://golang.org/pkg/crypto/tls/#Config.GetCertificate.
func (r *CertReloader) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
r.mu.RLock()
defer r.mu.RUnlock()

return r.cert, nil
}
Loading

0 comments on commit c1502bb

Please sign in to comment.