Skip to content

Commit

Permalink
pkg/security(ticdc): support online load tls config (#7927)
Browse files Browse the repository at this point in the history
close #7908
  • Loading branch information
CharlesCheung96 committed Dec 20, 2022
1 parent f2bee61 commit 3266836
Show file tree
Hide file tree
Showing 4 changed files with 318 additions and 140 deletions.
101 changes: 69 additions & 32 deletions cdc/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import (
mock_etcd "github.com/pingcap/tiflow/pkg/etcd/mock"
"github.com/pingcap/tiflow/pkg/httputil"
"github.com/pingcap/tiflow/pkg/retry"
security2 "github.com/pingcap/tiflow/pkg/security"
"github.com/pingcap/tiflow/pkg/security"
"github.com/pingcap/tiflow/pkg/util"
"github.com/stretchr/testify/require"
"github.com/tikv/pd/pkg/tempurl"
Expand Down Expand Up @@ -174,12 +174,12 @@ const retryTime = 20
func TestServerTLSWithoutCommonName(t *testing.T) {
addr := tempurl.Alloc()[len("http://"):]
// Do not specify common name
security, err := security2.NewCredential4Test("")
_, securityCfg, err := security.NewServerCredential4Test("")
require.Nil(t, err)
conf := config.GetDefaultServerConfig()
conf.Addr = addr
conf.AdvertiseAddr = addr
conf.Security = &security
conf.Security = securityCfg
config.StoreGlobalServerConfig(conf)

server, err := New([]string{"https://127.0.0.1:2379"})
Expand All @@ -205,7 +205,7 @@ func TestServerTLSWithoutCommonName(t *testing.T) {
go func() {
defer wg.Done()
err := server.tcpServer.Run(ctx)
require.Contains(t, err.Error(), "ErrTCPServerClosed")
require.ErrorContains(t, err, "ErrTCPServerClosed")
}()

// test cli sends request without a cert will success
Expand All @@ -227,12 +227,13 @@ func TestServerTLSWithoutCommonName(t *testing.T) {
require.Nil(t, err)
require.Equal(t, info.ID, captureInfo.ID)
return nil
}, retry.WithMaxTries(retryTime), retry.WithBackoffBaseDelay(50), retry.WithIsRetryableErr(cerrors.IsRetryableError))
}, retry.WithMaxTries(retryTime), retry.WithBackoffBaseDelay(50),
retry.WithIsRetryableErr(cerrors.IsRetryableError))
require.Nil(t, err)

// test cli sends request with a cert will success
err = retry.Do(ctx, func() error {
cli, err := httputil.NewClient(&security)
cli, err := httputil.NewClient(securityCfg)
require.Nil(t, err)
resp, err := cli.Get(ctx, statusURL)
if err != nil {
Expand All @@ -247,22 +248,25 @@ func TestServerTLSWithoutCommonName(t *testing.T) {
require.Equal(t, info.ID, captureInfo.ID)
resp.Body.Close()
return nil
}, retry.WithMaxTries(retryTime), retry.WithBackoffBaseDelay(50), retry.WithIsRetryableErr(cerrors.IsRetryableError))
}, retry.WithMaxTries(retryTime), retry.WithBackoffBaseDelay(50),
retry.WithIsRetryableErr(cerrors.IsRetryableError))
require.Nil(t, err)

cancel()
wg.Wait()
}

func TestServerTLSWithCommonName(t *testing.T) {
func TestServerTLSWithCommonNameAndRotate(t *testing.T) {
addr := tempurl.Alloc()[len("http://"):]
// specify a common name
security, err := security2.NewCredential4Test("test")
ca, securityCfg, err := security.NewServerCredential4Test("server")
securityCfg.CertAllowedCN = append(securityCfg.CertAllowedCN, "client1")
require.Nil(t, err)

conf := config.GetDefaultServerConfig()
conf.Addr = addr
conf.AdvertiseAddr = addr
conf.Security = &security
conf.Security = securityCfg
config.StoreGlobalServerConfig(conf)

server, err := New([]string{"https://127.0.0.1:2379"})
Expand All @@ -280,15 +284,15 @@ func TestServerTLSWithCommonName(t *testing.T) {
}()

statusURL := fmt.Sprintf("https://%s/api/v1/status", addr)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
err := server.tcpServer.Run(ctx)
require.Contains(t, err.Error(), "ErrTCPServerClosed")
require.ErrorContains(t, err, "ErrTCPServerClosed")
}()

// test cli sends request without a cert will fail
Expand All @@ -311,26 +315,59 @@ func TestServerTLSWithCommonName(t *testing.T) {
require.Equal(t, info.ID, captureInfo.ID)
resp.Body.Close()
return nil
}, retry.WithMaxTries(retryTime), retry.WithBackoffBaseDelay(50), retry.WithIsRetryableErr(cerrors.IsRetryableError))
require.Contains(t, err.Error(), "remote error: tls: bad certificate")
}, retry.WithMaxTries(retryTime), retry.WithBackoffBaseDelay(50),
retry.WithIsRetryableErr(cerrors.IsRetryableError))
require.ErrorContains(t, err, "remote error: tls: bad certificate")

testTlSClient := func(securityCfg *security.Credential) error {
return retry.Do(ctx, func() error {
cli, err := httputil.NewClient(securityCfg)
require.Nil(t, err)
resp, err := cli.Get(ctx, statusURL)
if err != nil {
return err
}
decoder := json.NewDecoder(resp.Body)
captureInfo := &model.CaptureInfo{}
err = decoder.Decode(captureInfo)
require.Nil(t, err)
info, err := server.capture.Info()
require.Nil(t, err)
require.Equal(t, info.ID, captureInfo.ID)
resp.Body.Close()
return nil
}, retry.WithMaxTries(retryTime), retry.WithBackoffBaseDelay(50),
retry.WithIsRetryableErr(cerrors.IsRetryableError))
}

// test cli sends request with a cert will success
err = retry.Do(ctx, func() error {
cli, err := httputil.NewClient(&security)
require.Nil(t, err)
resp, err := cli.Get(ctx, statusURL)
if err != nil {
return err
}
decoder := json.NewDecoder(resp.Body)
captureInfo := &model.CaptureInfo{}
err = decoder.Decode(captureInfo)
require.Nil(t, err)
info, err := server.capture.Info()
require.Nil(t, err)
require.Equal(t, info.ID, captureInfo.ID)
resp.Body.Close()
return nil
}, retry.WithMaxTries(retryTime), retry.WithBackoffBaseDelay(50), retry.WithIsRetryableErr(cerrors.IsRetryableError))
require.Nil(t, err)

// test peer success
require.NoError(t, testTlSClient(securityCfg))

// test rotate
serverCert, serverkey, err := ca.GenerateCerts("rotate")
require.NoError(t, err)
err = os.WriteFile(securityCfg.CertPath, serverCert, 0o600)
require.NoError(t, err)
err = os.WriteFile(securityCfg.KeyPath, serverkey, 0o600)
require.NoError(t, err)
// peer fail due to invalid common name `rotate`
require.ErrorContains(t, testTlSClient(securityCfg), "client certificate authentication failed")

cert, key, err := ca.GenerateCerts("client1")
require.NoError(t, err)
certPath, err := security.WriteFile("ticdc-test-client-cert", cert)
require.NoError(t, err)
keyPath, err := security.WriteFile("ticdc-test-client-key", key)
require.NoError(t, err)
require.NoError(t, testTlSClient(&security.Credential{
CAPath: securityCfg.CAPath,
CertPath: certPath,
KeyPath: keyPath,
CertAllowedCN: []string{"rotate"},
}))

cancel()
wg.Wait()
}
10 changes: 8 additions & 2 deletions engine/pkg/client/executor_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package client

import (
"io/ioutil"
"testing"
"time"

Expand All @@ -34,10 +35,15 @@ func TestNewExecutorClientNotBlocking(t *testing.T) {
func TestExecutorClientFactoryIllegalCredentials(t *testing.T) {
t.Parallel()

dir := t.TempDir()
caPath := dir + "/ca.pem"
err := ioutil.WriteFile(caPath, []byte("invalid ca pem"), 0o600)
require.NoError(t, err)

credentials := &security.Credential{
CAPath: "/dev/null", // illegal CA path to trigger an error
CAPath: caPath, // illegal CA path to trigger an error
}
factory := newExecutorClientFactory(credentials, nil)
_, err := factory.NewExecutorClient("127.0.0.1:1234")
_, err = factory.NewExecutorClient("127.0.0.1:1234")
require.Error(t, err)
}
97 changes: 88 additions & 9 deletions pkg/security/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"io/ioutil"
"os"
"strings"

"github.com/pingcap/tidb-tools/pkg/utils"
cerror "github.com/pingcap/tiflow/pkg/errors"
"github.com/pingcap/tiflow/pkg/errors"
pd "github.com/tikv/pd/client"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -64,15 +65,15 @@ func (s *Credential) ToGRPCDialOption() (grpc.DialOption, error) {

// ToTLSConfig generates tls's config from *Security
func (s *Credential) ToTLSConfig() (*tls.Config, error) {
cfg, err := utils.ToTLSConfig(s.CAPath, s.CertPath, s.KeyPath)
return cfg, cerror.WrapError(cerror.ErrToTLSConfigFailed, err)
cfg, err := ToTLSConfigWithVerify(s.CAPath, s.CertPath, s.KeyPath, nil)
return cfg, errors.WrapError(errors.ErrToTLSConfigFailed, err)
}

// ToTLSConfigWithVerify generates tls's config from *Security and requires
// the remote common name to be verified.
func (s *Credential) ToTLSConfigWithVerify() (*tls.Config, error) {
cfg, err := utils.ToTLSConfigWithVerify(s.CAPath, s.CertPath, s.KeyPath, s.CertAllowedCN)
return cfg, cerror.WrapError(cerror.ErrToTLSConfigFailed, err)
cfg, err := ToTLSConfigWithVerify(s.CAPath, s.CertPath, s.KeyPath, s.CertAllowedCN)
return cfg, errors.WrapError(errors.ErrToTLSConfigFailed, err)
}

func (s *Credential) getSelfCommonName() (string, error) {
Expand All @@ -81,15 +82,16 @@ func (s *Credential) getSelfCommonName() (string, error) {
}
data, err := os.ReadFile(s.CertPath)
if err != nil {
return "", cerror.WrapError(cerror.ErrToTLSConfigFailed, err)
return "", errors.WrapError(errors.ErrToTLSConfigFailed, err)
}
block, _ := pem.Decode(data)
if block == nil || block.Type != "CERTIFICATE" {
return "", cerror.ErrToTLSConfigFailed.GenWithStack("failed to decode PEM block to certificate")
return "", errors.ErrToTLSConfigFailed.
GenWithStack("failed to decode PEM block to certificate")
}
certificate, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return "", cerror.WrapError(cerror.ErrToTLSConfigFailed, err)
return "", errors.WrapError(errors.ErrToTLSConfigFailed, err)
}
return certificate.Subject.CommonName, nil
}
Expand All @@ -107,3 +109,80 @@ func (s *Credential) AddSelfCommonName() error {
s.CertAllowedCN = append(s.CertAllowedCN, cn)
return nil
}

// ToTLSConfigWithVerify constructs a `*tls.Config` from the CA, certification and key
// paths, and add verify for CN.
//
// If the CA path is empty, returns nil.
func ToTLSConfigWithVerify(
caPath, certPath, keyPath string, verifyCN []string,
) (*tls.Config, error) {
if len(caPath) == 0 {
return nil, nil
}

// Create a certificate pool from CA
certPool := x509.NewCertPool()
ca, err := ioutil.ReadFile(caPath)
if err != nil {
return nil, errors.Annotate(err, "could not read ca certificate")
}

// Append the certificates from the CA
if !certPool.AppendCertsFromPEM(ca) {
return nil, errors.New("failed to append ca certs")
}

tlsCfg := &tls.Config{
RootCAs: certPool,
ClientCAs: certPool,
NextProtos: []string{"h2", "http/1.1"}, // specify `h2` to let Go use HTTP/2.
MinVersion: tls.VersionTLS12,
}

if len(certPath) != 0 && len(keyPath) != 0 {
loadCert := func() (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, errors.Annotate(err, "could not load client key pair")
}
return &cert, nil
}
tlsCfg.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return loadCert()
}
tlsCfg.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return loadCert()
}
}

addVerifyPeerCertificate(tlsCfg, verifyCN)
return tlsCfg, nil
}

func addVerifyPeerCertificate(tlsCfg *tls.Config, verifyCN []string) {
if len(verifyCN) != 0 {
checkCN := make(map[string]struct{})
for _, cn := range verifyCN {
cn = strings.TrimSpace(cn)
checkCN[cn] = struct{}{}
}
tlsCfg.ClientAuth = tls.RequireAndVerifyClientCert
tlsCfg.VerifyPeerCertificate = func(
rawCerts [][]byte, verifiedChains [][]*x509.Certificate,
) error {
cns := make([]string, 0, len(verifiedChains))
for _, chains := range verifiedChains {
for _, chain := range chains {
cns = append(cns, chain.Subject.CommonName)
if _, match := checkCN[chain.Subject.CommonName]; match {
return nil
}
}
}
return errors.Errorf("client certificate authentication failed. "+
"The Common Name from the client certificate %v was not found "+
"in the configuration cluster-verify-cn with value: %s", cns, verifyCN)
}
}
}
Loading

0 comments on commit 3266836

Please sign in to comment.