Skip to content

Commit

Permalink
feat: hot-reload TLS certificate (#3265)
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Oct 4, 2022
1 parent 5842946 commit 1d13be6
Show file tree
Hide file tree
Showing 8 changed files with 270 additions and 45 deletions.
8 changes: 6 additions & 2 deletions cmd/server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,10 @@ func serve(
}

var tlsConfig *tls.Config
stopReload := make(chan struct{})
if tc := d.Config().TLS(ctx, iface); tc.Enabled() {
// #nosec G402 - This is a false positive because we use graceful.WithDefaults which sets the correct TLS settings.
tlsConfig = &tls.Config{Certificates: GetOrCreateTLSCertificate(ctx, cmd, d, iface)}
tlsConfig = &tls.Config{GetCertificate: GetOrCreateTLSCertificate(ctx, d, iface, stopReload)}
}

var srv = graceful.WithDefaults(&http.Server{
Expand Down Expand Up @@ -362,7 +363,10 @@ func serve(
}

return srv.Serve(listener)
}, srv.Shutdown); err != nil {
}, func(ctx context.Context) error {
close(stopReload)
return srv.Shutdown(ctx)
}); err != nil {
d.Logger().WithError(err).Fatal("Could not gracefully run server")
}
}
49 changes: 26 additions & 23 deletions cmd/server/helper_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import (
"github.com/ory/hydra/driver/config"

"github.com/pkg/errors"
"github.com/spf13/cobra"

"github.com/ory/x/tlsx"

Expand All @@ -57,66 +56,70 @@ func AttachCertificate(priv *jose.JSONWebKey, cert *x509.Certificate) {
priv.CertificateThumbprintSHA1 = sig1[:]
}

var mapLock sync.RWMutex
var locks = map[string]*sync.RWMutex{}
var lock sync.Mutex

func getLock(set string) *sync.RWMutex {
mapLock.Lock()
defer mapLock.Unlock()
if _, ok := locks[set]; !ok {
locks[set] = new(sync.RWMutex)
}
return locks[set]
}

func GetOrCreateTLSCertificate(ctx context.Context, cmd *cobra.Command, d driver.Registry, iface config.ServeInterface) []tls.Certificate {
getLock(TlsKeyName).Lock()
defer getLock(TlsKeyName).Unlock()

cert, err := d.Config().TLS(ctx, iface).Certificate()
// GetOrCreateTLSCertificate returns a function for use with
// "net/tls".Config.GetCertificate. If the certificate and key are read from
// disk, they will be automatically reloaded until stopReload is close()'d.
func GetOrCreateTLSCertificate(ctx context.Context, d driver.Registry, iface config.ServeInterface, stopReload <-chan struct{}) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
lock.Lock()
defer lock.Unlock()

// check if certificates are configured
certFunc, err := d.Config().TLS(ctx, iface).GetCertificateFunc(stopReload, d.Logger())
if err == nil {
return cert
return certFunc
} else if !errors.Is(err, tlsx.ErrNoCertificatesConfigured) {
d.Logger().WithError(err).Fatalf("Unable to load HTTPS TLS Certificate")
d.Logger().WithError(err).Fatal("Unable to load HTTPS TLS Certificate")
return nil // in case Fatal is hooked
}

// no certificates configured: self-sign a new cert
priv, err := jwk.GetOrGenerateKeys(ctx, d, d.SoftwareKeyManager(), TlsKeyName, uuid.Must(uuid.NewV4()).String(), "RS256")
if err != nil {
d.Logger().WithError(err).Fatal("Unable to fetch or generate HTTPS TLS key pair")
return nil // in case Fatal is hooked
}

if len(priv.Certificates) == 0 {
cert, err := tlsx.CreateSelfSignedCertificate(priv.Key)
if err != nil {
d.Logger().WithError(err).Fatalf(`Could not generate a self signed TLS certificate`)
d.Logger().WithError(err).Fatal(`Could not generate a self signed TLS certificate`)
return nil // in case Fatal is hooked
}

AttachCertificate(priv, cert)
if err := d.SoftwareKeyManager().DeleteKey(ctx, TlsKeyName, priv.KeyID); err != nil {
d.Logger().WithError(err).Fatal(`Could not update (delete) the self signed TLS certificate`)
return nil // in case Fatal is hooked
}

if err := d.SoftwareKeyManager().AddKey(ctx, TlsKeyName, priv); err != nil {
d.Logger().WithError(err).Fatalf(`Could not update (add) the self signed TLS certificate: %s %x %d`, cert.SignatureAlgorithm, cert.Signature, len(cert.Signature))
return nil // in case Fatalf is hooked
}
}

block, err := jwk.PEMBlockForKey(priv.Key)
if err != nil {
d.Logger().WithError(err).Fatalf("Could not encode key to PEM")
d.Logger().WithError(err).Fatal("Could not encode key to PEM")
return nil // in case Fatal is hooked
}

if len(priv.Certificates) == 0 {
d.Logger().Fatal("TLS certificate chain can not be empty")
return nil // in case Fatal is hooked
}

pemCert := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: priv.Certificates[0].Raw})
pemKey := pem.EncodeToMemory(block)
ct, err := tls.X509KeyPair(pemCert, pemKey)
if err != nil {
d.Logger().WithError(err).Fatalf("Could not decode certificate")
d.Logger().WithError(err).Fatal("Could not decode certificate")
return nil // in case Fatal is hooked
}

return []tls.Certificate{ct}
return func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &ct, nil
}
}
118 changes: 118 additions & 0 deletions cmd/server/helper_cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,138 @@ package server_test
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"os"
"testing"
"time"

"github.com/google/uuid"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"

"github.com/ory/x/configx"
"github.com/ory/x/logrusx"
"github.com/ory/x/tlsx"

"github.com/ory/hydra/cmd/server"
"github.com/ory/hydra/driver"
"github.com/ory/hydra/driver/config"
"github.com/ory/hydra/internal/testhelpers"
"github.com/ory/hydra/jwk"
)

func TestGetOrCreateTLSCertificate(t *testing.T) {
certPath, keyPath, cert, priv := testhelpers.GenerateTLSCertificateFilesForTests(t)
logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) { t.Fatalf("Logger called os.Exit(%v)", code) }
hook := test.NewLocal(logger.Logger)
cfg := config.MustNew(
context.Background(),
logger,
configx.WithValues(map[string]interface{}{
"dsn": config.DSNMemory,
"serve.tls.enabled": true,
"serve.tls.cert.path": certPath,
"serve.tls.key.path": keyPath,
}),
)
d, err := driver.NewRegistryWithoutInit(cfg, logger)
require.NoError(t, err)
getCert := server.GetOrCreateTLSCertificate(context.Background(), d, config.AdminInterface, nil)
require.NotNil(t, getCert)
tlsCert, err := getCert(nil)
require.NoError(t, err)
require.NotNil(t, tlsCert)
if tlsCert.Leaf == nil {
tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
require.NoError(t, err)
}
require.True(t, tlsCert.Leaf.Equal(cert))
require.True(t, priv.Equal(tlsCert.PrivateKey))

// generate new cert+key
newCertPath, newKeyPath, newCert, newPriv := testhelpers.GenerateTLSCertificateFilesForTests(t)
require.False(t, cert.Equal(newCert))
require.False(t, priv.Equal(newPriv))
require.NotEqual(t, certPath, newCertPath)
require.NotEqual(t, keyPath, newKeyPath)

// move them into place
require.NoError(t, os.Rename(newKeyPath, keyPath))
require.NoError(t, os.Rename(newCertPath, certPath))

// give it some time and check we're reloaded
time.Sleep(150 * time.Millisecond)
require.Nil(t, hook.LastEntry())

// request another certificate: it should be the new one
tlsCert, err = getCert(nil)
require.NoError(t, err)
if tlsCert.Leaf == nil {
tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
require.NoError(t, err)
}
require.True(t, tlsCert.Leaf.Equal(newCert))
require.True(t, newPriv.Equal(tlsCert.PrivateKey))

require.NoError(t, os.WriteFile(certPath, []byte{'j', 'u', 'n', 'k'}, 0))

timeout := time.After(500 * time.Millisecond)
for {
if hook.LastEntry() != nil {
break
}
select {
case <-timeout:
require.FailNow(t, "expected error log entry")
default:
}
}
require.Contains(t, hook.LastEntry().Message, "Failed to reload TLS certificates. Using the previously loaded certificates.")
}

func TestGetOrCreateTLSCertificateBase64(t *testing.T) {
certPath, keyPath, cert, priv := testhelpers.GenerateTLSCertificateFilesForTests(t)
certPEM, err := os.ReadFile(certPath)
require.NoError(t, err)
certBase64 := base64.StdEncoding.EncodeToString(certPEM)
keyPEM, err := os.ReadFile(keyPath)
require.NoError(t, err)
keyBase64 := base64.StdEncoding.EncodeToString(keyPEM)

logger := logrusx.New("", "")
logger.Logger.ExitFunc = func(code int) { t.Fatalf("Logger called os.Exit(%v)", code) }
hook := test.NewLocal(logger.Logger)
_ = hook
cfg := config.MustNew(
context.Background(),
logger,
configx.WithValues(map[string]interface{}{
"dsn": config.DSNMemory,
"serve.tls.enabled": true,
"serve.tls.cert.base64": certBase64,
"serve.tls.key.base64": keyBase64,
}),
)
d, err := driver.NewRegistryWithoutInit(cfg, logger)
require.NoError(t, err)
getCert := server.GetOrCreateTLSCertificate(context.Background(), d, config.AdminInterface, nil)
require.NotNil(t, getCert)
tlsCert, err := getCert(nil)
require.NoError(t, err)
require.NotNil(t, tlsCert)
if tlsCert.Leaf == nil {
tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
require.NoError(t, err)
}
require.True(t, tlsCert.Leaf.Equal(cert))
require.True(t, priv.Equal(tlsCert.PrivateKey))
}

func TestCreateSelfSignedCertificate(t *testing.T) {
keys, err := jwk.GenerateJWK(context.Background(), jose.RS256, uuid.New().String(), "sig")
require.NoError(t, err)

Expand Down
3 changes: 3 additions & 0 deletions driver/config/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,9 @@ func TestViperProviderValidates(t *testing.T) {
ServerURL: "http://sampling",
},
},
Zipkin: otelx.ZipkinConfig{
ServerURL: "http://zipkin/api/v2/spans",
},
},
}, c.Tracing())
}
Expand Down
38 changes: 35 additions & 3 deletions driver/config/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"context"
"crypto/tls"

"github.com/pkg/errors"

"github.com/ory/x/logrusx"
"github.com/ory/x/tlsx"
)

Expand All @@ -26,9 +29,11 @@ const (
type TLSConfig interface {
Enabled() bool
AllowTerminationFrom() []string
Certificate() ([]tls.Certificate, error)
GetCertificateFunc(stopReload <-chan struct{}, _ *logrusx.Logger) (func(*tls.ClientHelloInfo) (*tls.Certificate, error), error)
}

var _ TLSConfig = (*tlsConfig)(nil)

type tlsConfig struct {
enabled bool
allowTerminationFrom []string
Expand Down Expand Up @@ -58,6 +63,33 @@ func (p *DefaultProvider) TLS(ctx context.Context, iface ServeInterface) TLSConf
}
}

func (c *tlsConfig) Certificate() ([]tls.Certificate, error) {
return tlsx.Certificate(c.certString, c.keyString, c.certPath, c.keyPath)
func (c *tlsConfig) GetCertificateFunc(stopReload <-chan struct{}, log *logrusx.Logger) (func(*tls.ClientHelloInfo) (*tls.Certificate, error), error) {
if c.certPath != "" && c.keyPath != "" { // attempt to load from disk first (enables hot-reloading)
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-stopReload
cancel()
}()
errs := make(chan error, 1)
getCert, err := tlsx.GetCertificate(ctx, c.certPath, c.keyPath, errs)
if err != nil {
return nil, errors.WithStack(err)
}
go func() {
for err := range errs {
log.WithError(err).Error("Failed to reload TLS certificates. Using the previously loaded certificates.")
}
}()
return getCert, nil
}
if c.certString != "" && c.keyString != "" { // base64-encoded directly in config
cert, err := tlsx.CertificateFromBase64(c.certString, c.keyString)
if err != nil {
return nil, errors.WithStack(err)
}
return func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return &cert, nil
}, nil
}
return nil, tlsx.ErrNoCertificatesConfigured
}
12 changes: 6 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ require (
github.com/ory/herodot v0.9.13
github.com/ory/hydra-client-go v1.11.8
github.com/ory/jsonschema/v3 v3.0.7
github.com/ory/x v0.0.463
github.com/ory/x v0.0.474
github.com/pborman/uuid v1.2.1
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.11.0
Expand Down Expand Up @@ -150,7 +150,7 @@ require (
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/huandu/xstrings v1.3.2 // indirect
github.com/imdario/mergo v0.3.12 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.0.1 // indirect
github.com/inhies/go-bytesize v0.0.0-20210819104631-275770b98743 // indirect
github.com/instana/go-sensor v1.41.1 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
Expand Down Expand Up @@ -213,11 +213,11 @@ require (
github.com/shopspring/decimal v1.3.1 // indirect
github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d // indirect
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e // indirect
github.com/spf13/afero v1.8.2 // indirect
github.com/spf13/afero v1.9.2 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/viper v1.12.0 // indirect
github.com/subosito/gotenv v1.3.0 // indirect
github.com/subosito/gotenv v1.4.1 // indirect
github.com/thales-e-security/pool v0.0.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
Expand Down Expand Up @@ -256,15 +256,15 @@ require (
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b // indirect
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2 // indirect
golang.org/x/sys v0.0.0-20220919091848-fb04ddd9f9c8 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 // indirect
golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20220525015930-6ca3db687a9d // indirect
google.golang.org/grpc v1.46.2 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/ini.v1 v1.66.4 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/op/go-logging.v1 v1.0.0-20160211212156-b2cb9fa56473 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
Loading

0 comments on commit 1d13be6

Please sign in to comment.