diff --git a/server/conn.go b/server/conn.go index 561dae86c7b80..c3c6631e62250 100644 --- a/server/conn.go +++ b/server/conn.go @@ -2184,8 +2184,13 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error { if err != nil { logutil.Logger(ctx).Debug("close old context failed", zap.Error(err)) } +<<<<<<< HEAD err = cc.openSessionAndDoAuth(pass) if err != nil { +======= + cc.ctx = nil + if err := cc.openSessionAndDoAuth(pass, ""); err != nil { +>>>>>>> 9fecc8a9b... server: set 'clientConn.ctx = nil' to clean the context when changeUser. (#33703) return err } return cc.handleCommonConnectionReset(ctx) diff --git a/server/conn_test.go b/server/conn_test.go index 0c99937eac746..d4a6a7caaa21b 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -41,6 +41,7 @@ import ( "github.com/tikv/client-go/v2/testutils" ) +<<<<<<< HEAD type ConnTestSuite struct { dom *domain.Domain store kv.Storage @@ -75,6 +76,128 @@ func (ts *ConnTestSuite) TearDownSuite(c *C) { func (ts *ConnTestSuite) TestMalformHandshakeHeader(c *C) { c.Parallel() +======= +type Issue33699CheckType struct { + name string + defVal string + setVal string + isSessionVariable bool +} + +func (c *Issue33699CheckType) toSetSessionVar() string { + if c.isSessionVariable { + return fmt.Sprintf("set session %s=%s", c.name, c.setVal) + } + return fmt.Sprintf("set @%s=%s", c.name, c.setVal) +} + +func (c *Issue33699CheckType) toGetSessionVar() string { + if c.isSessionVariable { + return fmt.Sprintf("select @@session.%s", c.name) + } + return fmt.Sprintf("select @%s", c.name) +} + +func TestIssue33699(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + var outBuffer bytes.Buffer + tidbdrv := NewTiDBDriver(store) + cfg := newTestConfig() + cfg.Port, cfg.Status.StatusPort = 0, 0 + cfg.Status.ReportStatus = false + server, err := NewServer(cfg, tidbdrv) + require.NoError(t, err) + defer server.Close() + + cc := &clientConn{ + connectionID: 1, + salt: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14}, + server: server, + pkt: &packetIO{ + bufWriter: bufio.NewWriter(&outBuffer), + }, + collation: mysql.DefaultCollationID, + peerHost: "localhost", + alloc: arena.NewAllocator(512), + chunkAlloc: chunk.NewAllocator(), + capability: mysql.ClientProtocol41, + } + + tk := testkit.NewTestKit(t, store) + ctx := &TiDBContext{Session: tk.Session()} + cc.ctx = ctx + + // change user. + doChangeUser := func() { + userData := append([]byte("root"), 0x0, 0x0) + userData = append(userData, []byte("test")...) + userData = append(userData, 0x0) + changeUserReq := dispatchInput{ + com: mysql.ComChangeUser, + in: userData, + err: nil, + out: []byte{0x7, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0}, + } + inBytes := append([]byte{changeUserReq.com}, changeUserReq.in...) + err = cc.dispatch(context.Background(), inBytes) + require.Equal(t, changeUserReq.err, err) + if err == nil { + err = cc.flush(context.TODO()) + require.NoError(t, err) + require.Equal(t, changeUserReq.out, outBuffer.Bytes()) + } else { + _ = cc.flush(context.TODO()) + } + outBuffer.Reset() + } + // check variable. + checks := []Issue33699CheckType{ + { // self define. + "a", + "", + "1", + false, + }, + { // session variable + "net_read_timeout", + "30", + "1234", + true, + }, + { + "net_write_timeout", + "60", + "1234", + true, + }, + } + + // default; + for _, ck := range checks { + tk.MustQuery(ck.toGetSessionVar()).Check(testkit.Rows(ck.defVal)) + } + // set; + for _, ck := range checks { + tk.MustExec(ck.toSetSessionVar()) + } + // check after set. + for _, ck := range checks { + tk.MustQuery(ck.toGetSessionVar()).Check(testkit.Rows(ck.setVal)) + } + doChangeUser() + require.NotEqual(t, ctx, cc.ctx) + require.NotEqual(t, ctx.Session, cc.ctx.Session) + // new session,so values is defaults; + tk.SetSession(cc.ctx.Session) // set new session. + for _, ck := range checks { + tk.MustQuery(ck.toGetSessionVar()).Check(testkit.Rows(ck.defVal)) + } +} + +func TestMalformHandshakeHeader(t *testing.T) { +>>>>>>> 9fecc8a9b... server: set 'clientConn.ctx = nil' to clean the context when changeUser. (#33703) data := []byte{0x00} var p handshakeResponse41 _, err := parseHandshakeResponseHeader(context.Background(), &p, data)