Skip to content

Commit

Permalink
Flakes: Setup new fake server if it has gone away (#17023)
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Lord <[email protected]>
  • Loading branch information
mattlord authored Oct 22, 2024
1 parent 0e22a3e commit ba129c7
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 253 deletions.
70 changes: 39 additions & 31 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,24 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/utils"
"vitess.io/vitess/go/vt/tlstest"
"vitess.io/vitess/go/vt/vttls"
)

const clientCertUsername = "Client Cert"

// The listener's Accept() loop actually only ends on a connection
// error, which will occur when trying to connect after the listener
// has been closed. So this function closes the listener and then
// calls Connect to trigger the error which ends that work.
var cleanupListener = func(ctx context.Context, l *Listener, params *ConnParams) {
l.Close()
_, _ = Connect(ctx, params)
}

func TestValidCert(t *testing.T) {
ctx := utils.LeakCheckContext(t)
th := &testHandler{}

authServer := newAuthServerClientCert(string(MysqlClearPassword))
Expand All @@ -52,21 +63,6 @@ func TestValidCert(t *testing.T) {
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", clientCertUsername)
tlstest.CreateCRL(root, tlstest.CA)

// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
path.Join(root, "ca-crl.pem"),
"",
tls.VersionTLS12)
require.NoError(t, err, "TLSServerConfig failed: %v", err)

l.TLSConfig.Store(serverConfig)
go func() {
l.Accept()
}()

// Setup the right parameters.
params := &ConnParams{
Host: host,
Expand All @@ -81,7 +77,20 @@ func TestValidCert(t *testing.T) {
ServerName: "server.example.com",
}

ctx := context.Background()
// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
path.Join(root, "ca-crl.pem"),
"",
tls.VersionTLS12)
require.NoError(t, err, "TLSServerConfig failed: %v", err)

l.TLSConfig.Store(serverConfig)
go l.Accept()
defer cleanupListener(ctx, l, params)

conn, err := Connect(ctx, params)
require.NoError(t, err, "Connect failed: %v", err)

Expand All @@ -103,6 +112,7 @@ func TestValidCert(t *testing.T) {
}

func TestNoCert(t *testing.T) {
ctx := utils.LeakCheckContext(t)
th := &testHandler{}

authServer := newAuthServerClientCert(string(MysqlClearPassword))
Expand All @@ -120,6 +130,17 @@ func TestNoCert(t *testing.T) {
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
tlstest.CreateCRL(root, tlstest.CA)

// Setup the right parameters.
params := &ConnParams{
Host: host,
Port: port,
Uname: "user1",
Pass: "",
SslMode: vttls.VerifyIdentity,
SslCa: path.Join(root, "ca-cert.pem"),
ServerName: "server.example.com",
}

// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
Expand All @@ -131,22 +152,9 @@ func TestNoCert(t *testing.T) {
require.NoError(t, err, "TLSServerConfig failed: %v", err)

l.TLSConfig.Store(serverConfig)
go func() {
l.Accept()
}()

// Setup the right parameters.
params := &ConnParams{
Host: host,
Port: port,
Uname: "user1",
Pass: "",
SslMode: vttls.VerifyIdentity,
SslCa: path.Join(root, "ca-cert.pem"),
ServerName: "server.example.com",
}
go l.Accept()
defer cleanupListener(ctx, l, params)

ctx := context.Background()
conn, err := Connect(ctx, params)
assert.Error(t, err, "Connect() should have errored due to no client cert")

Expand Down
24 changes: 20 additions & 4 deletions go/mysql/auth_server_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ type AuthServerStatic struct {
// entries contains the users, passwords and user data.
entries map[string][]*AuthServerStaticEntry

// Signal handling related fields.
sigChan chan os.Signal
ticker *time.Ticker
done chan struct{} // Tell the signal related goroutines to stop
}

// AuthServerStaticEntry stores the values for a given user.
Expand Down Expand Up @@ -267,26 +269,40 @@ func (a *AuthServerStatic) installSignalHandlers() {
return
}

a.done = make(chan struct{})
a.sigChan = make(chan os.Signal, 1)
signal.Notify(a.sigChan, syscall.SIGHUP)
go func() {
for range a.sigChan {
a.reload()
for {
select {
case <-a.done:
return
case <-a.sigChan:
a.reload()
}
}
}()

// If duration is set, it will reload configuration every interval
if a.reloadInterval > 0 {
a.ticker = time.NewTicker(a.reloadInterval)
go func() {
for range a.ticker.C {
a.sigChan <- syscall.SIGHUP
for {
select {
case <-a.done:
return
case <-a.ticker.C:
a.sigChan <- syscall.SIGHUP
}
}
}()
}
}

func (a *AuthServerStatic) close() {
if a.done != nil {
close(a.done)
}
if a.ticker != nil {
a.ticker.Stop()
}
Expand Down
29 changes: 25 additions & 4 deletions go/mysql/auth_server_static_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"time"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/utils"
)

// getEntries is a test-only method for AuthServerStatic.
Expand All @@ -35,6 +37,7 @@ func (a *AuthServerStatic) getEntries() map[string][]*AuthServerStaticEntry {
}

func TestJsonConfigParser(t *testing.T) {
_ = utils.LeakCheckContext(t)
// works with legacy format
config := make(map[string][]*AuthServerStaticEntry)
jsonConfig := "{\"mysql_user\":{\"Password\":\"123\", \"UserData\":\"dummy\"}, \"mysql_user_2\": {\"Password\": \"123\", \"UserData\": \"mysql_user_2\"}}"
Expand Down Expand Up @@ -67,6 +70,7 @@ func TestJsonConfigParser(t *testing.T) {
}

func TestValidateHashGetter(t *testing.T) {
_ = utils.LeakCheckContext(t)
jsonConfig := `{"mysql_user": [{"Password": "password", "UserData": "user.name", "Groups": ["user_group"]}]}`

auth := NewAuthServerStatic("", jsonConfig, 0)
Expand All @@ -90,6 +94,7 @@ func TestValidateHashGetter(t *testing.T) {
}

func TestHostMatcher(t *testing.T) {
_ = utils.LeakCheckContext(t)
ip := net.ParseIP("192.168.0.1")
addr := &net.TCPAddr{IP: ip, Port: 9999}
match := MatchSourceHost(net.Addr(addr), "")
Expand All @@ -105,9 +110,9 @@ func TestHostMatcher(t *testing.T) {
}

func TestStaticConfigHUP(t *testing.T) {
_ = utils.LeakCheckContext(t)
tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json")
require.NoError(t, err, "couldn't create temp file: %v", err)

defer os.Remove(tmpFile.Name())

oldStr := "str5"
Expand All @@ -125,14 +130,19 @@ func TestStaticConfigHUP(t *testing.T) {

mu.Lock()
defer mu.Unlock()
// delete registered Auth server
clear(authServers)
// Delete registered Auth servers.
for k, v := range authServers {
if s, ok := v.(*AuthServerStatic); ok {
s.close()
}
delete(authServers, k)
}
}

func TestStaticConfigHUPWithRotation(t *testing.T) {
_ = utils.LeakCheckContext(t)
tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json")
require.NoError(t, err, "couldn't create temp file: %v", err)

defer os.Remove(tmpFile.Name())

oldStr := "str1"
Expand All @@ -147,6 +157,16 @@ func TestStaticConfigHUPWithRotation(t *testing.T) {

hupTestWithRotation(t, aStatic, tmpFile, oldStr, "str4")
hupTestWithRotation(t, aStatic, tmpFile, "str4", "str5")

mu.Lock()
defer mu.Unlock()
// Delete registered Auth servers.
for k, v := range authServers {
if s, ok := v.(*AuthServerStatic); ok {
s.close()
}
delete(authServers, k)
}
}

func hupTest(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) {
Expand Down Expand Up @@ -178,6 +198,7 @@ func hupTestWithRotation(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.Fi
}

func TestStaticPasswords(t *testing.T) {
_ = utils.LeakCheckContext(t)
jsonConfig := `
{
"user01": [{ "Password": "user01" }],
Expand Down
29 changes: 11 additions & 18 deletions go/mysql/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
// This file tests the handshake scenarios between our client and our server.

func TestClearTextClientAuth(t *testing.T) {
ctx := utils.LeakCheckContext(t)
th := &testHandler{}

authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword)
Expand All @@ -51,10 +52,6 @@ func TestClearTextClientAuth(t *testing.T) {
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
port := l.Addr().(*net.TCPAddr).Port
go func() {
l.Accept()
}()

// Setup the right parameters.
params := &ConnParams{
Host: host,
Expand All @@ -63,9 +60,10 @@ func TestClearTextClientAuth(t *testing.T) {
Pass: "password1",
SslMode: vttls.Disabled,
}
go l.Accept()
defer cleanupListener(ctx, l, params)

// Connection should fail, as server requires SSL for clear text auth.
ctx := context.Background()
_, err = Connect(ctx, params)
if err == nil || !strings.Contains(err.Error(), "Cannot use clear text authentication over non-SSL connections") {
t.Fatalf("unexpected connection error: %v", err)
Expand All @@ -92,6 +90,7 @@ func TestClearTextClientAuth(t *testing.T) {
// TestSSLConnection creates a server with TLS support, a client that
// also has SSL support, and connects them.
func TestSSLConnection(t *testing.T) {
ctx := utils.LeakCheckContext(t)
th := &testHandler{}

authServer := NewAuthServerStaticWithAuthMethodDescription("", "", 0, MysqlClearPassword)
Expand All @@ -103,7 +102,6 @@ func TestSSLConnection(t *testing.T) {
// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
port := l.Addr().(*net.TCPAddr).Port

Expand All @@ -122,12 +120,6 @@ func TestSSLConnection(t *testing.T) {
"",
tls.VersionTLS12)
require.NoError(t, err, "TLSServerConfig failed: %v", err)

l.TLSConfig.Store(serverConfig)
go func() {
l.Accept()
}()

// Setup the right parameters.
params := &ConnParams{
Host: host,
Expand All @@ -141,20 +133,22 @@ func TestSSLConnection(t *testing.T) {
SslKey: path.Join(root, "client-key.pem"),
ServerName: "server.example.com",
}
l.TLSConfig.Store(serverConfig)
go l.Accept()
defer cleanupListener(ctx, l, params)

t.Run("Basics", func(t *testing.T) {
testSSLConnectionBasics(t, params)
testSSLConnectionBasics(t, ctx, params)
})

// Make sure clear text auth works over SSL.
t.Run("ClearText", func(t *testing.T) {
testSSLConnectionClearText(t, params)
testSSLConnectionClearText(t, ctx, params)
})
}

func testSSLConnectionClearText(t *testing.T, params *ConnParams) {
func testSSLConnectionClearText(t *testing.T, ctx context.Context, params *ConnParams) {
// Create a client connection, connect.
ctx := context.Background()
conn, err := Connect(ctx, params)
require.NoError(t, err, "Connect failed: %v", err)

Expand All @@ -170,9 +164,8 @@ func testSSLConnectionClearText(t *testing.T, params *ConnParams) {
conn.writeComQuit()
}

func testSSLConnectionBasics(t *testing.T, params *ConnParams) {
func testSSLConnectionBasics(t *testing.T, ctx context.Context, params *ConnParams) {
// Create a client connection, connect.
ctx := context.Background()
conn, err := Connect(ctx, params)
require.NoError(t, err, "Connect failed: %v", err)

Expand Down
5 changes: 5 additions & 0 deletions go/mysql/replication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/utils"

binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
)

func TestComBinlogDump(t *testing.T) {
_ = utils.LeakCheckContext(t)
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
Expand Down Expand Up @@ -72,6 +75,7 @@ func TestComBinlogDump(t *testing.T) {
}

func TestComBinlogDumpGTID(t *testing.T) {
_ = utils.LeakCheckContext(t)
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
Expand Down Expand Up @@ -161,6 +165,7 @@ func TestComBinlogDumpGTID(t *testing.T) {
}

func TestSendSemiSyncAck(t *testing.T) {
_ = utils.LeakCheckContext(t)
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
Expand Down
Loading

0 comments on commit ba129c7

Please sign in to comment.