Skip to content

Commit

Permalink
Merge pull request #5423 from lyonlai/ylai/2019-11-11/one-copy-of-tls…
Browse files Browse the repository at this point in the history
…-certificate-per-file

Load tls certificate and x509 cert pool once per file to reduce memory usage
  • Loading branch information
Daniel Kozlowski authored Nov 20, 2019
2 parents 0b3de7c + 6888606 commit 83bfb04
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 40 deletions.
156 changes: 138 additions & 18 deletions go/vt/tlstest/tlstest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package tlstest

import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
Expand All @@ -29,6 +30,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"vitess.io/vitess/go/vt/vttls"
)

Expand All @@ -45,26 +47,20 @@ func TestClientServer(t *testing.T) {
}
defer os.RemoveAll(root)

// Create the certs and configs.
CreateCA(root)

CreateSignedCert(root, CA, "01", "servers", "Servers CA")
CreateSignedCert(root, "servers", "01", "server-instance", "server.example.com")
clientServerKeyPairs := createClientServerCertPairs(root)

CreateSignedCert(root, CA, "02", "clients", "Clients CA")
CreateSignedCert(root, "clients", "01", "client-instance", "Client Instance")
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-instance-cert.pem"),
path.Join(root, "server-instance-key.pem"),
path.Join(root, "clients-cert.pem"))
clientServerKeyPairs.serverCert,
clientServerKeyPairs.serverKey,
clientServerKeyPairs.clientCA)
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
clientConfig, err := vttls.ClientConfig(
path.Join(root, "client-instance-cert.pem"),
path.Join(root, "client-instance-key.pem"),
path.Join(root, "servers-cert.pem"),
"server.example.com")
clientServerKeyPairs.clientCert,
clientServerKeyPairs.clientKey,
clientServerKeyPairs.serverCA,
clientServerKeyPairs.serverName)
if err != nil {
t.Fatalf("TLSClientConfig failed: %v", err)
}
Expand Down Expand Up @@ -121,10 +117,10 @@ func TestClientServer(t *testing.T) {
//

badClientConfig, err := vttls.ClientConfig(
path.Join(root, "server-instance-cert.pem"),
path.Join(root, "server-instance-key.pem"),
path.Join(root, "servers-cert.pem"),
"server.example.com")
clientServerKeyPairs.serverCert,
clientServerKeyPairs.serverKey,
clientServerKeyPairs.serverCA,
clientServerKeyPairs.serverName)
if err != nil {
t.Fatalf("TLSClientConfig failed: %v", err)
}
Expand Down Expand Up @@ -168,3 +164,127 @@ func TestClientServer(t *testing.T) {
t.Errorf("Wrong error returned: %v", err)
}
}

var serialCounter = 0

type clientServerKeyPairs struct {
serverCert string
serverKey string
serverCA string
serverName string
clientCert string
clientKey string
clientCA string
}

func createClientServerCertPairs(root string) clientServerKeyPairs {

// Create the certs and configs.
CreateCA(root)

serverSerial := fmt.Sprintf("%03d", serialCounter*2+1)
clientSerial := fmt.Sprintf("%03d", serialCounter*2+2)

serialCounter = serialCounter + 1

serverName := fmt.Sprintf("server-%s", serverSerial)
serverCACommonName := fmt.Sprintf("Server %s CA", serverSerial)
serverCertName := fmt.Sprintf("server-instance-%s", serverSerial)
serverCertCommonName := fmt.Sprintf("server%s.example.com", serverSerial)

clientName := fmt.Sprintf("clients-%s", serverSerial)
clientCACommonName := fmt.Sprintf("Clients %s CA", serverSerial)
clientCertName := fmt.Sprintf("client-instance-%s", serverSerial)
clientCertCommonName := fmt.Sprintf("Client Instance %s", serverSerial)

CreateSignedCert(root, CA, serverSerial, serverName, serverCACommonName)
CreateSignedCert(root, serverName, serverSerial, serverCertName, serverCertCommonName)

CreateSignedCert(root, CA, clientSerial, clientName, clientCACommonName)
CreateSignedCert(root, clientName, serverSerial, clientCertName, clientCertCommonName)

return clientServerKeyPairs{
serverCert: path.Join(root, fmt.Sprintf("%s-cert.pem", serverCertName)),
serverKey: path.Join(root, fmt.Sprintf("%s-key.pem", serverCertName)),
serverCA: path.Join(root, fmt.Sprintf("%s-cert.pem", serverName)),
clientCert: path.Join(root, fmt.Sprintf("%s-cert.pem", clientCertName)),
clientKey: path.Join(root, fmt.Sprintf("%s-key.pem", clientCertName)),
clientCA: path.Join(root, fmt.Sprintf("%s-cert.pem", clientName)),
serverName: serverCertCommonName,
}

}

func getServerConfig(keypairs clientServerKeyPairs) (*tls.Config, error) {
return vttls.ServerConfig(
keypairs.clientCert,
keypairs.clientKey,
keypairs.serverCA)
}

func getClientConfig(keypairs clientServerKeyPairs) (*tls.Config, error) {
return vttls.ClientConfig(
keypairs.clientCert,
keypairs.clientKey,
keypairs.serverCA,
keypairs.serverName)
}

func TestServerTLSConfigCaching(t *testing.T) {
testConfigGeneration(t, "servertlstest", getServerConfig, func(config *tls.Config) *x509.CertPool {
return config.ClientCAs
})
}

func TestClientTLSConfigCaching(t *testing.T) {
testConfigGeneration(t, "clienttlstest", getClientConfig, func(config *tls.Config) *x509.CertPool {
return config.RootCAs
})
}

func testConfigGeneration(t *testing.T, rootPrefix string, generateConfig func(clientServerKeyPairs) (*tls.Config, error), getCertPool func(tlsConfig *tls.Config) *x509.CertPool) {
// Our test root.
root, err := ioutil.TempDir("", rootPrefix)
if err != nil {
t.Fatalf("TempDir failed: %v", err)
}
defer os.RemoveAll(root)

const configsToGenerate = 1

firstClientServerKeyPairs := createClientServerCertPairs(root)
secondClientServerKeyPairs := createClientServerCertPairs(root)

firstExpectedConfig, _ := generateConfig(firstClientServerKeyPairs)
secondExpectedConfig, _ := generateConfig(secondClientServerKeyPairs)
firstConfigChannel := make(chan *tls.Config, configsToGenerate)
secondConfigChannel := make(chan *tls.Config, configsToGenerate)

var configCounter = 0

for i := 1; i <= configsToGenerate; i++ {
go func() {
firstConfig, _ := generateConfig(firstClientServerKeyPairs)
firstConfigChannel <- firstConfig
secondConfig, _ := generateConfig(secondClientServerKeyPairs)
secondConfigChannel <- secondConfig
}()
}

for {
select {
case firstConfig := <-firstConfigChannel:
assert.Equal(t, &firstExpectedConfig.Certificates, &firstConfig.Certificates)
assert.Equal(t, getCertPool(firstExpectedConfig), getCertPool(firstConfig))
case secondConfig := <-secondConfigChannel:
assert.Equal(t, &secondExpectedConfig.Certificates, &secondConfig.Certificates)
assert.Equal(t, getCertPool(secondExpectedConfig), getCertPool(secondConfig))
}
configCounter = configCounter + 1

if configCounter >= 2*configsToGenerate {
break
}
}

}
132 changes: 110 additions & 22 deletions go/vt/vttls/vttls.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ package vttls
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"strings"
"sync"

"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
)

// Updated list of acceptable cipher suits to address
Expand Down Expand Up @@ -51,31 +55,33 @@ func newTLSConfig() *tls.Config {
}
}

var onceByKeys = sync.Map{}

// ClientConfig returns the TLS config to use for a client to
// connect to a server with the provided parameters.
func ClientConfig(cert, key, ca, name string) (*tls.Config, error) {
config := newTLSConfig()

// Load the client-side cert & key if any.
if cert != "" && key != "" {
crt, err := tls.LoadX509KeyPair(cert, key)
certificates, err := loadTLSCertificate(cert, key)

if err != nil {
return nil, fmt.Errorf("failed to load cert/key: %v", err)
return nil, err
}
config.Certificates = []tls.Certificate{crt}

config.Certificates = *certificates
}

// Load the server CA if any.
if ca != "" {
b, err := ioutil.ReadFile(ca)
certificatePool, err := loadx509CertPool(ca)

if err != nil {
return nil, fmt.Errorf("failed to read ca file: %v", err)
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(b) {
return nil, fmt.Errorf("failed to append certificates")
return nil, err
}
config.RootCAs = cp

config.RootCAs = certificatePool
}

// Set the server name if any.
Expand All @@ -91,27 +97,109 @@ func ClientConfig(cert, key, ca, name string) (*tls.Config, error) {
func ServerConfig(cert, key, ca string) (*tls.Config, error) {
config := newTLSConfig()

// Load the server cert and key.
crt, err := tls.LoadX509KeyPair(cert, key)
certificates, err := loadTLSCertificate(cert, key)

if err != nil {
return nil, fmt.Errorf("failed to load cert/key: %v", err)
return nil, err
}
config.Certificates = []tls.Certificate{crt}

config.Certificates = *certificates

// if specified, load ca to validate client,
// and enforce clients present valid certs.
if ca != "" {
b, err := ioutil.ReadFile(ca)
certificatePool, err := loadx509CertPool(ca)

if err != nil {
return nil, fmt.Errorf("failed to read ca file: %v", err)
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(b) {
return nil, fmt.Errorf("failed to append certificates")
return nil, err
}
config.ClientCAs = cp

config.ClientCAs = certificatePool
config.ClientAuth = tls.RequireAndVerifyClientCert
}

return config, nil
}

var certPools = sync.Map{}

func loadx509CertPool(ca string) (*x509.CertPool, error) {
once, _ := onceByKeys.LoadOrStore(ca, &sync.Once{})

var err error
once.(*sync.Once).Do(func() {
err = doLoadx509CertPool(ca)
})
if err != nil {
return nil, err
}

result, ok := certPools.Load(ca)

if !ok {
return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "Cannot find loaded x509 cert pool for ca: %s", ca)
}

return result.(*x509.CertPool), nil
}

func doLoadx509CertPool(ca string) error {
b, err := ioutil.ReadFile(ca)
if err != nil {
return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read ca file: %s", ca)
}

cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(b) {
return vterrors.Errorf(vtrpc.Code_UNKNOWN, "failed to append certificates")
}

certPools.Store(ca, cp)

return nil
}

var tlsCertificates = sync.Map{}

func tlsCertificatesIdentifier(cert, key string) string {
return strings.Join([]string{cert, key}, ";")
}

func loadTLSCertificate(cert, key string) (*[]tls.Certificate, error) {
tlsIdentifier := tlsCertificatesIdentifier(cert, key)
once, _ := onceByKeys.LoadOrStore(tlsIdentifier, &sync.Once{})

var err error
once.(*sync.Once).Do(func() {
err = doLoadTLSCertificate(cert, key)
})

if err != nil {
return nil, err
}

result, ok := tlsCertificates.Load(tlsIdentifier)

if !ok {
return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "Cannot find loaded tls certificate with cert: %s, key%s", cert, key)
}

return result.(*[]tls.Certificate), nil
}

func doLoadTLSCertificate(cert, key string) error {
tlsIdentifier := tlsCertificatesIdentifier(cert, key)

var certificate []tls.Certificate
// Load the server cert and key.
crt, err := tls.LoadX509KeyPair(cert, key)
if err != nil {
return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to load tls certificate, cert %s, key: %s", cert, key)
}

certificate = []tls.Certificate{crt}

tlsCertificates.Store(tlsIdentifier, &certificate)

return nil
}

0 comments on commit 83bfb04

Please sign in to comment.