diff --git a/executor/grant.go b/executor/grant.go index cd3d8065b6591..18ae258dae0d7 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -64,7 +64,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.RecordBatch) error { dbName = e.ctx.GetSessionVars().CurrentDB } // Grant for each user - for _, user := range e.Users { + for idx, user := range e.Users { // Check if user exists. exists, err := userExists(e.ctx, user.User.Username, user.User.Hostname) if err != nil { @@ -105,6 +105,15 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if e.WithGrant { privs = append(privs, &ast.PrivElem{Priv: mysql.GrantPriv}) } + + if idx == 0 { + // Commit the old transaction, like DDL. + if err := e.ctx.NewTxn(ctx); err != nil { + return err + } + defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }() + } + // Grant each priv to the user. for _, priv := range privs { if len(priv.Cols) > 0 { diff --git a/executor/revoke.go b/executor/revoke.go index ada1a06f3dbeb..bba8101518cf3 100644 --- a/executor/revoke.go +++ b/executor/revoke.go @@ -58,7 +58,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.RecordBatch) error { e.done = true // Revoke for each user. - for _, user := range e.Users { + for idx, user := range e.Users { // Check if user exists. exists, err := userExists(e.ctx, user.User.Username, user.User.Hostname) if err != nil { @@ -68,6 +68,13 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.RecordBatch) error { return errors.Errorf("Unknown user: %s", user.User) } + if idx == 0 { + // Commit the old transaction, like DDL. + if err := e.ctx.NewTxn(ctx); err != nil { + return err + } + defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }() + } err = e.revokeOneUser(user.User.Username, user.User.Hostname) if err != nil { return err diff --git a/executor/simple.go b/executor/simple.go index 87d76ba8bdf0b..e050998f73f06 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -56,6 +56,15 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err erro if e.done { return nil } + + if e.autoNewTxn() { + // Commit the old transaction, like DDL. + if err := e.ctx.NewTxn(ctx); err != nil { + return err + } + defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }() + } + switch x := e.Statement.(type) { case *ast.GrantRoleStmt: err = e.executeGrantRole(x) @@ -70,7 +79,7 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err erro case *ast.RollbackStmt: err = e.executeRollback(x) case *ast.CreateUserStmt: - err = e.executeCreateUser(x) + err = e.executeCreateUser(ctx, x) case *ast.AlterUserStmt: err = e.executeAlterUser(x) case *ast.DropUserStmt: @@ -490,7 +499,7 @@ func (e *SimpleExec) executeRollback(s *ast.RollbackStmt) error { return nil } -func (e *SimpleExec) executeCreateUser(s *ast.CreateUserStmt) error { +func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStmt) error { users := make([]string, 0, len(s.Specs)) for _, spec := range s.Specs { exists, err1 := userExists(e.ctx, spec.User.Username, spec.User.Hostname) @@ -516,6 +525,7 @@ func (e *SimpleExec) executeCreateUser(s *ast.CreateUserStmt) error { if len(users) == 0 { return nil } + sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, Password) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) if s.IsCreateRole { sql = fmt.Sprintf(`INSERT INTO %s.%s (Host, User, Password, Account_locked) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) @@ -831,3 +841,11 @@ func (e *SimpleExec) executeDropStats(s *ast.DropStatsStmt) error { } return h.Update(GetInfoSchema(e.ctx)) } + +func (e *SimpleExec) autoNewTxn() bool { + switch e.Statement.(type) { + case *ast.CreateUserStmt, *ast.AlterUserStmt, *ast.DropUserStmt: + return true + } + return false +} diff --git a/executor/simple_test.go b/executor/simple_test.go index 4ffd55a97a0e9..267b2763c2ab3 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -439,3 +439,33 @@ func (s *testSuite3) TestUseDB(c *C) { _, err = tk.Exec("USE ``") c.Assert(terror.ErrorEqual(core.ErrNoDB, err), IsTrue, Commentf("err %v", err)) } + +func (s *testSuite3) TestStmtAutoNewTxn(c *C) { + // Some statements are like DDL, they commit the previous txn automically. + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + // Fix issue https://github.com/pingcap/tidb/issues/10705 + tk.MustExec("begin") + tk.MustExec("create user 'xxx'@'%';") + tk.MustExec("grant all privileges on *.* to 'xxx'@'%';") + + tk.MustExec("create table auto_new (id int)") + tk.MustExec("begin") + tk.MustExec("insert into auto_new values (1)") + tk.MustExec("revoke all privileges on *.* from 'xxx'@'%'") + tk.MustExec("rollback") // insert statement has already committed + tk.MustQuery("select * from auto_new").Check(testkit.Rows("1")) + + // Test the behavior when autocommit is false. + tk.MustExec("set autocommit = 0") + tk.MustExec("insert into auto_new values (2)") + tk.MustExec("create user 'yyy'@'%'") + tk.MustExec("rollback") + tk.MustQuery("select * from auto_new").Check(testkit.Rows("1", "2")) + + tk.MustExec("drop user 'yyy'@'%'") + tk.MustExec("insert into auto_new values (3)") + tk.MustExec("rollback") + tk.MustQuery("select * from auto_new").Check(testkit.Rows("1", "2")) +}