Skip to content

Commit

Permalink
Merge 57c5a04 into f9338e4
Browse files Browse the repository at this point in the history
  • Loading branch information
StarAurryon authored Mar 17, 2022
2 parents f9338e4 + 57c5a04 commit 45a4254
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 9 deletions.
4 changes: 2 additions & 2 deletions cmd/server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,15 +300,15 @@ func serve(
handler http.Handler,
address string,
permission *configx.UnixPermission,
cert []tls.Certificate,
getCert func(*tls.ClientHelloInfo) (*tls.Certificate, error),
) {
defer wg.Done()

var srv = graceful.WithDefaults(&http.Server{
Handler: handler,
// #nosec G402 - This is a false positive because we use graceful.WithDefaults which sets the correct TLS settings.
TLSConfig: &tls.Config{
Certificates: cert,
GetCertificate: getCert,
},
})

Expand Down
149 changes: 145 additions & 4 deletions cmd/server/helper_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"path"
"runtime"
"sync"
"time"

"gopkg.in/square/go-jose.v2"

Expand All @@ -39,6 +43,8 @@ import (
"github.com/ory/x/tlsx"

"github.com/ory/hydra/jwk"

"github.com/fsnotify/fsnotify"
)

const (
Expand All @@ -54,11 +60,11 @@ func AttachCertificate(priv *jose.JSONWebKey, cert *x509.Certificate) {
priv.CertificateThumbprintSHA1 = sig1[:]
}

func GetOrCreateTLSCertificate(cmd *cobra.Command, d driver.Registry, iface config.ServeInterface) []tls.Certificate {
cert, err := d.Config().TLS(iface).Certificate()
func GetOrCreateTLSCertificate(cmd *cobra.Command, d driver.Registry, iface config.ServeInterface) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, location, err := d.Config().TLS(iface).Certificate()

if err == nil {
return cert
return newCertificatesProvider(cert, location, d, iface).getCertificate
} else if !errors.Is(err, tlsx.ErrNoCertificatesConfigured) {
d.Logger().WithError(err).Fatalf("Unable to load HTTPS TLS Certificate")
}
Expand Down Expand Up @@ -100,5 +106,140 @@ func GetOrCreateTLSCertificate(cmd *cobra.Command, d driver.Registry, iface conf
d.Logger().WithError(err).Fatalf("Could not decode certificate")
}

return []tls.Certificate{ct}
return func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return &ct, nil
}
}

type certificatesProvider struct {
certs []tls.Certificate
mu sync.Mutex
iface config.ServeInterface
certLocation *config.CertLocation
d driver.Registry
watcher *fsnotify.Watcher
}

func newCertificatesProvider(certs []tls.Certificate, certLocation *config.CertLocation, d driver.Registry, iface config.ServeInterface) *certificatesProvider {
ret := &certificatesProvider{
certLocation: certLocation,
d: d,
iface: iface,
}
ret.load(certs)
if certLocation != nil {
ret.watchCertificatesChanges()
}

runtime.SetFinalizer(ret, func(ret *certificatesProvider) { ret.stop() })

return ret
}

func (p *certificatesProvider) load(certs []tls.Certificate) {
for i := range certs {
tlsCert := &certs[i]
if tlsCert.Leaf != nil {
continue
}
for _, bCert := range tlsCert.Certificate {
cert, _ := x509.ParseCertificate(bCert)
if !cert.IsCA {
tlsCert.Leaf = cert
}
}
}

p.mu.Lock()
defer p.mu.Unlock()
p.certs = certs
}

func (p *certificatesProvider) watchCertificatesChanges() {
var err error
p.watcher, err = fsnotify.NewWatcher()
if err != nil {
p.d.Logger().WithError(err).Fatalf("Could not activate certificate change watcher")
}

go func() {
p.d.Logger().Infof("Starting tls certificate auto-refresh")
for {
select {
case _, ok := <-p.watcher.Events:
if !ok {
return
}

p.waitForAllFilesChanges()

p.d.Logger().Infof("TLS certificates changed, updating")
certs, _, err := p.d.Config().TLS(p.iface).Certificate()
if err != nil {
p.d.Logger().WithError(err).Fatalf("Error in the new tls certificates")
return
}
p.load(certs)
case err, ok := <-p.watcher.Errors:
if !ok {
return
}
p.d.Logger().WithError(err).Fatalf("Error occured in the tls certificate change watcher")
}
}
}()

certPath := path.Dir(p.certLocation.CertPath)
keyPath := path.Dir(p.certLocation.KeyPath)

err = p.watcher.Add(certPath)
if err != nil {
p.d.Logger().WithError(err).Fatalf("Error watching the certFolder for tls certificate change")
}

if certPath != keyPath {
err = p.watcher.Add(keyPath)
if err != nil {
p.d.Logger().WithError(err).Fatalf("Error watching the keyFolder for tls certificate change")
}
}
}

func (p *certificatesProvider) waitForAllFilesChanges() {
flushUntil := time.After(2 * time.Second)
p.d.Logger().Infof("TLS certificates files changed, waiting for changes to finish")
stop := false
for {
select {
case <-flushUntil:
stop = true
case <-p.watcher.Events:
continue
}

if stop {
break
}
}
}

func (p *certificatesProvider) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
p.mu.Lock()
defer p.mu.Unlock()

if hello != nil {
for _, cert := range p.certs {
if cert.Leaf != nil && cert.Leaf.VerifyHostname(hello.ServerName) == nil {
return &cert, nil
}
}
}
return &p.certs[0], nil
}

func (p *certificatesProvider) stop() {
if p.watcher != nil {
p.watcher.Close()
p.watcher = nil
}
}
15 changes: 12 additions & 3 deletions driver/config/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (
type TLSConfig interface {
Enabled() bool
AllowTerminationFrom() []string
Certificate() ([]tls.Certificate, error)
Certificate() ([]tls.Certificate, *CertLocation, error)
}

func (p *Provider) TLS(iface ServeInterface) TLSConfig {
Expand Down Expand Up @@ -65,8 +65,17 @@ func (c *tlsConfig) AllowTerminationFrom() []string {
return c.allowTerminationFrom
}

func (c *tlsConfig) Certificate() ([]tls.Certificate, error) {
return tlsx.Certificate(c.certString, c.keyString, c.certPath, c.keyPath)
type CertLocation struct {
CertPath string
KeyPath string
}

func (c *tlsConfig) Certificate() ([]tls.Certificate, *CertLocation, error) {
certs, err := tlsx.Certificate(c.certString, c.keyString, c.certPath, c.keyPath)
if c.certString != "" && c.keyString != "" {
return certs, nil, err
}
return certs, &CertLocation{CertPath: c.certPath, KeyPath: c.keyPath}, err
}

func (p *Provider) forcedHTTP() bool {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require (
github.com/ThalesIgnite/crypto11 v1.2.4
github.com/cenkalti/backoff/v3 v3.0.0
github.com/evanphx/json-patch v4.9.0+incompatible
github.com/fsnotify/fsnotify v1.5.1 // indirect
github.com/go-bindata/go-bindata v3.1.2+incompatible
github.com/go-openapi/errors v0.20.1
github.com/go-openapi/runtime v0.20.0
Expand Down

0 comments on commit 45a4254

Please sign in to comment.