diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 6f3643ebc7f..5263d643df5 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -1707,7 +1707,6 @@ func (c *Conn) IsMarkedForClose() bool { return c.closing } -// GetTestConn returns a conn for testing purpose only. -func GetTestConn() *Conn { - return newConn(testConn{}) +func (c *Conn) IsShuttingDown() bool { + return c.listener.shutdown.Load() } diff --git a/go/mysql/conn_fake.go b/go/mysql/conn_fake.go index 72d944c2f3b..e61f90d33f1 100644 --- a/go/mysql/conn_fake.go +++ b/go/mysql/conn_fake.go @@ -81,3 +81,14 @@ func (m mockAddress) String() string { } var _ net.Addr = (*mockAddress)(nil) + +// GetTestConn returns a conn for testing purpose only. +func GetTestConn() *Conn { + return newConn(testConn{}) +} + +// GetTestServerConn is only meant to be used for testing. +// It creates a server connection using a testConn and the provided listener. +func GetTestServerConn(listener *Listener) *Conn { + return newServerConn(testConn{}, listener) +} diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index bfbb7b105f8..273592b5bf7 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -201,6 +201,12 @@ func startSpan(ctx context.Context, query, label string) (trace.Span, context.Co } func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { + session := vh.session(c) + if c.IsShuttingDown() && !session.InTransaction { + c.MarkForClose() + return sqlerror.NewSQLError(sqlerror.ERServerShutdown, sqlerror.SSNetError, "Server shutdown in progress") + } + ctx, cancel := context.WithCancel(context.Background()) c.UpdateCancelCtx(cancel) @@ -229,7 +235,6 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq "VTGate MySQL Connector" /* subcomponent: part of the client */) ctx = callerid.NewContext(ctx, ef, im) - session := vh.session(c) if !session.InTransaction { vh.busyConnections.Add(1) } @@ -614,11 +619,11 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys func (srv *mysqlServer) shutdownMysqlProtocolAndDrain() { if srv.tcpListener != nil { - srv.tcpListener.Close() + srv.tcpListener.Shutdown() srv.tcpListener = nil } if srv.unixListener != nil { - srv.unixListener.Close() + srv.unixListener.Shutdown() srv.unixListener = nil } if srv.sigChan != nil { diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index 1b161dfb171..1aa201b5d4c 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -342,3 +342,80 @@ func TestKillMethods(t *testing.T) { require.EqualError(t, cancelCtx.Err(), "context canceled") require.True(t, mysqlConn.IsMarkedForClose()) } + +func TestGracefulShutdown(t *testing.T) { + executor, _, _, _, _ := createExecutorEnv(t) + + vh := newVtgateHandler(&VTGate{executor: executor, timings: timings, rowsReturned: rowsReturned, rowsAffected: rowsAffected}) + th := &testHandler{} + listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0) + require.NoError(t, err) + defer listener.Close() + + // add a connection + mysqlConn := mysql.GetTestServerConn(listener) + mysqlConn.ConnectionID = 1 + mysqlConn.UserData = &mysql.StaticUserData{} + vh.connections[1] = mysqlConn + + err = vh.ComQuery(mysqlConn, "select 1", func(result *sqltypes.Result) error { + return nil + }) + assert.NoError(t, err) + + listener.Shutdown() + + err = vh.ComQuery(mysqlConn, "select 1", func(result *sqltypes.Result) error { + return nil + }) + require.EqualError(t, err, "Server shutdown in progress (errno 1053) (sqlstate 08S01)") + + require.True(t, mysqlConn.IsMarkedForClose()) +} + +func TestGracefulShutdownWithTransaction(t *testing.T) { + executor, _, _, _, _ := createExecutorEnv(t) + + vh := newVtgateHandler(&VTGate{executor: executor, timings: timings, rowsReturned: rowsReturned, rowsAffected: rowsAffected}) + th := &testHandler{} + listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0) + require.NoError(t, err) + defer listener.Close() + + // add a connection + mysqlConn := mysql.GetTestServerConn(listener) + mysqlConn.ConnectionID = 1 + mysqlConn.UserData = &mysql.StaticUserData{} + vh.connections[1] = mysqlConn + + err = vh.ComQuery(mysqlConn, "BEGIN", func(result *sqltypes.Result) error { + return nil + }) + assert.NoError(t, err) + + err = vh.ComQuery(mysqlConn, "select 1", func(result *sqltypes.Result) error { + return nil + }) + assert.NoError(t, err) + + listener.Shutdown() + + err = vh.ComQuery(mysqlConn, "select 1", func(result *sqltypes.Result) error { + return nil + }) + assert.NoError(t, err) + + err = vh.ComQuery(mysqlConn, "COMMIT", func(result *sqltypes.Result) error { + return nil + }) + assert.NoError(t, err) + + require.False(t, mysqlConn.IsMarkedForClose()) + + err = vh.ComQuery(mysqlConn, "select 1", func(result *sqltypes.Result) error { + return nil + }) + require.EqualError(t, err, "Server shutdown in progress (errno 1053) (sqlstate 08S01)") + + require.True(t, mysqlConn.IsMarkedForClose()) +}