Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MAX_EXECUTION_TIME. #10541

Merged
merged 21 commits into from
Jun 24, 2019
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions executor/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ import (

// processinfoSetter is the interface use to set current running process info.
type processinfoSetter interface {
SetProcessInfo(string, time.Time, byte)
SetProcessInfo(string, time.Time, byte, uint64)
}

// recordSet wraps an executor, implements sqlexec.RecordSet interface
Expand Down Expand Up @@ -240,8 +240,9 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) {
sql = ss.SecureText()
}
}
maxExecutionTime := getMaxExecutionTime(sctx, a.StmtNode)
// Update processinfo, ShowProcess() will use it.
pi.SetProcessInfo(sql, time.Now(), cmd)
pi.SetProcessInfo(sql, time.Now(), cmd, maxExecutionTime)
a.Ctx.GetSessionVars().StmtCtx.StmtType = GetStmtLabel(a.StmtNode)
}

Expand Down Expand Up @@ -280,6 +281,20 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) {
}, nil
}

// getMaxExecutionTime get the max execution timeout value.
func getMaxExecutionTime(sctx sessionctx.Context, stmtNode ast.StmtNode) uint64 {
ret := sctx.GetSessionVars().MaxExecutionTime
if sel, ok := stmtNode.(*ast.SelectStmt); ok {
for _, hint := range sel.TableHints {
if hint.HintName.L == variable.MaxExecutionTime {
ret = hint.MaxExecutionTime
break
}
}
}
return ret
}

type chunkRowRecordSet struct {
rows []chunk.Row
idx int
Expand Down
19 changes: 16 additions & 3 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (

"github.com/opentracing/opentracing-go"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
Expand Down Expand Up @@ -262,6 +263,11 @@ func (cc *clientConn) readPacket() ([]byte, error) {
}

func (cc *clientConn) writePacket(data []byte) error {
failpoint.Inject("FakeClientConn", func() {
if cc.pkt == nil {
failpoint.Return(nil)
}
})
return cc.pkt.writePacket(data)
}

Expand Down Expand Up @@ -847,8 +853,9 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
defer func() {
// if handleChangeUser failed, cc.ctx may be nil
if cc.ctx != nil {
cc.ctx.SetProcessInfo("", t, mysql.ComSleep)
cc.ctx.SetProcessInfo("", t, mysql.ComSleep, 0)
}

cc.server.releaseToken(token)
span.Finish()
}()
Expand All @@ -863,9 +870,9 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
switch cmd {
case mysql.ComPing, mysql.ComStmtClose, mysql.ComStmtSendLongData, mysql.ComStmtReset,
mysql.ComSetOption, mysql.ComChangeUser:
cc.ctx.SetProcessInfo("", t, cmd)
cc.ctx.SetProcessInfo("", t, cmd, 0)
case mysql.ComInitDB:
cc.ctx.SetProcessInfo("use "+dataStr, t, cmd)
cc.ctx.SetProcessInfo("use "+dataStr, t, cmd, 0)
}

switch cmd {
Expand Down Expand Up @@ -928,6 +935,11 @@ func (cc *clientConn) useDB(ctx context.Context, db string) (err error) {
}

func (cc *clientConn) flush() error {
failpoint.Inject("FakeClientConn", func() {
if cc.pkt == nil {
failpoint.Return(nil)
}
})
return cc.pkt.flush()
}

Expand Down Expand Up @@ -1258,6 +1270,7 @@ func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary b
if err != nil {
return err
}

return cc.flush()
}

Expand Down
2 changes: 1 addition & 1 deletion server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err
if prepared, ok := cc.ctx.GetStatement(int(stmtID)).(*TiDBStatement); ok {
sql = prepared.sql
}
cc.ctx.SetProcessInfo(sql, time.Now(), mysql.ComStmtExecute)
cc.ctx.SetProcessInfo(sql, time.Now(), mysql.ComStmtExecute, 0)
rs := stmt.GetResultSet()
if rs == nil {
return mysql.NewErr(mysql.ErrUnknownStmtHandler,
Expand Down
70 changes: 70 additions & 0 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"time"

. "github.com/pingcap/check"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/domain"
Expand Down Expand Up @@ -363,3 +366,70 @@ func mapBelong(m1, m2 map[string]string) bool {
}
return true
}

func (ts ConnTestSuite) TestConnExecutionTimeout(c *C) {
//There is no underlying netCon, use failpoint to avoid panic
c.Assert(failpoint.Enable("github.com/pingcap/tidb/server/FakeClientConn", "return(1)"), IsNil)

c.Parallel()
var err error
ts.store, err = mockstore.NewMockTikvStore()
c.Assert(err, IsNil)
ts.dom, err = session.BootstrapSession(ts.store)
c.Assert(err, IsNil)
se, err := session.CreateSession4Test(ts.store)
c.Assert(err, IsNil)

connID := 1
se.SetConnectionID(uint64(connID))
tc := &TiDBContext{
session: se,
stmts: make(map[int]*TiDBStatement),
}
cc := &clientConn{
connectionID: uint32(connID),
server: &Server{
capability: defaultCapability,
},
ctx: tc,
alloc: arena.NewAllocator(32 * 1024),
}
srv := &Server{
clients: map[uint32]*clientConn{
uint32(connID): cc,
},
}
handle := ts.dom.ExpensiveQueryHandle().SetSessionManager(srv)
go handle.Run(time.Millisecond)
defer handle.Close()

_, err = se.Execute(context.Background(), "use test;")
c.Assert(err, IsNil)
_, err = se.Execute(context.Background(), "CREATE TABLE testTable2 (id bigint PRIMARY KEY, age int)")
c.Assert(err, IsNil)
for i := 0; i < 10; i++ {
str := fmt.Sprintf("insert into testTable2 values(%d, %d)", i, i%80)
_, err = se.Execute(context.Background(), str)
c.Assert(err, IsNil)
}

_, err = se.Execute(context.Background(), "select SLEEP(1);")
c.Assert(err, IsNil)

_, err = se.Execute(context.Background(), "set @@max_execution_time = 500;")
c.Assert(err, IsNil)

err = cc.handleQuery(context.Background(), "select * FROM testTable2 WHERE SLEEP(1);")
c.Assert(err, NotNil)

_, err = se.Execute(context.Background(), "set @@max_execution_time = 0;")
c.Assert(err, IsNil)

err = cc.handleQuery(context.Background(), "select * FROM testTable2 WHERE SLEEP(1);")
c.Assert(err, IsNil)

err = cc.handleQuery(context.Background(), "select /*+ MAX_EXECUTION_TIME(100)*/ * FROM testTable2 WHERE SLEEP(1);")
c.Assert(err, NotNil)

c.Assert(failpoint.Disable("github.com/pingcap/tidb/server/FakeClientConn"), IsNil)
}
2 changes: 1 addition & 1 deletion server/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type QueryCtx interface {
// SetValue saves a value associated with this context for key.
SetValue(key fmt.Stringer, value interface{})

SetProcessInfo(sql string, t time.Time, command byte)
SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64)

// CommitTxn commits the transaction operations.
CommitTxn(ctx context.Context) error
Expand Down
4 changes: 2 additions & 2 deletions server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ func (tc *TiDBContext) CommitTxn(ctx context.Context) error {
}

// SetProcessInfo implements QueryCtx SetProcessInfo method.
func (tc *TiDBContext) SetProcessInfo(sql string, t time.Time, command byte) {
tc.session.SetProcessInfo(sql, t, command)
func (tc *TiDBContext) SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) {
tc.session.SetProcessInfo(sql, t, command, maxExecutionTime)
}

// RollbackTxn implements QueryCtx RollbackTxn method.
Expand Down
26 changes: 14 additions & 12 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,13 @@ func init() {
}

var (
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
errMaxExecTimeExceeded = terror.ClassServer.New(codeMaxExecTimeExceeded, mysql.MySQLErrName[mysql.ErrMaxExecTimeExceeded])
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand All @@ -107,7 +108,7 @@ type Server struct {
driver IDriver
listener net.Listener
socket net.Listener
rwlock *sync.RWMutex
rwlock sync.RWMutex
concurrentLimiter *TokenLimiter
clients map[uint32]*clientConn
capability uint32
Expand Down Expand Up @@ -199,7 +200,6 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
cfg: cfg,
driver: driver,
concurrentLimiter: NewTokenLimiter(cfg.TokenLimit),
rwlock: &sync.RWMutex{},
clients: make(map[uint32]*clientConn),
stopListenerCh: make(chan struct{}, 1),
}
Expand Down Expand Up @@ -618,14 +618,16 @@ const (
codeInvalidSequence = 3
codeInvalidType = 4

codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
codeNotAllowedCommand = 1148
codeAccessDenied = mysql.ErrAccessDenied
codeMaxExecTimeExceeded = mysql.ErrMaxExecTimeExceeded
)

func init() {
serverMySQLErrCodes := map[terror.ErrCode]uint16{
codeNotAllowedCommand: mysql.ErrNotAllowedCommand,
codeAccessDenied: mysql.ErrAccessDenied,
codeNotAllowedCommand: mysql.ErrNotAllowedCommand,
codeAccessDenied: mysql.ErrAccessDenied,
codeMaxExecTimeExceeded: mysql.ErrMaxExecTimeExceeded,
}
terror.ErrClassToMySQLCodes[terror.ClassServer] = serverMySQLErrCodes
}
33 changes: 21 additions & 12 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ type Session interface {
SetClientCapability(uint32) // Set client capability flags.
SetConnectionID(uint64)
SetCommandValue(byte)
SetProcessInfo(string, time.Time, byte)
SetProcessInfo(string, time.Time, byte, uint64)
SetTLSState(*tls.ConnectionState)
SetCollation(coID int) error
SetSessionManager(util.SessionManager)
Expand Down Expand Up @@ -829,6 +829,10 @@ func createSessionFunc(store kv.Storage) pools.Factory {
if err != nil {
return nil, err
}
err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxExecutionTime, types.NewUintDatum(0))
if err != nil {
return nil, errors.Trace(err)
}
se.sessionVars.CommonGlobalLoaded = true
se.sessionVars.InRestrictedSQL = true
return se, nil
Expand All @@ -845,6 +849,10 @@ func createSessionWithDomainFunc(store kv.Storage) func(*domain.Domain) (pools.R
if err != nil {
return nil, err
}
err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxExecutionTime, types.NewUintDatum(0))
if err != nil {
return nil, errors.Trace(err)
}
se.sessionVars.CommonGlobalLoaded = true
se.sessionVars.InRestrictedSQL = true
return se, nil
Expand Down Expand Up @@ -956,18 +964,19 @@ func (s *session) ParseSQL(ctx context.Context, sql, charset, collation string)
return s.parser.Parse(sql, charset, collation)
}

func (s *session) SetProcessInfo(sql string, t time.Time, command byte) {
func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) {
pi := util.ProcessInfo{
ID: s.sessionVars.ConnectionID,
DB: s.sessionVars.CurrentDB,
Command: command,
Plan: s.currentPlan,
Time: t,
State: s.Status(),
Info: sql,
CurTxnStartTS: s.sessionVars.TxnCtx.StartTS,
StmtCtx: s.sessionVars.StmtCtx,
StatsInfo: plannercore.GetStatsInfo,
ID: s.sessionVars.ConnectionID,
DB: s.sessionVars.CurrentDB,
Command: command,
Plan: s.currentPlan,
Time: t,
State: s.Status(),
Info: sql,
CurTxnStartTS: s.sessionVars.TxnCtx.StartTS,
StmtCtx: s.sessionVars.StmtCtx,
StatsInfo: plannercore.GetStatsInfo,
MaxExecutionTime: maxExecutionTime,
}
if s.sessionVars.User != nil {
pi.User = s.sessionVars.User.Username
Expand Down
26 changes: 26 additions & 0 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2640,6 +2640,32 @@ func (s *testSessionSuite) TestTxnGoString(c *C) {
c.Assert(fmt.Sprintf("%#v", txn), Equals, "Txn{state=invalid}")
}

func (s *testSessionSuite) TestMaxExeucteTime(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)

tk.MustExec("create table MaxExecTime( id int,name varchar(128),age int);")
tk.MustExec("begin")
tk.MustExec("insert into MaxExecTime (id,name,age) values (1,'john',18),(2,'lary',19),(3,'lily',18);")

tk.MustQuery("select @@MAX_EXECUTION_TIME;").Check(testkit.Rows("0"))
tk.MustQuery("select @@global.MAX_EXECUTION_TIME;").Check(testkit.Rows("0"))
tk.MustQuery("select /*+ MAX_EXECUTION_TIME(1000) */ * FROM MaxExecTime;")

tk.MustExec("set @@global.MAX_EXECUTION_TIME = 300;")
tk.MustQuery("select * FROM MaxExecTime;")

tk.MustExec("set @@MAX_EXECUTION_TIME = 150;")
tk.MustQuery("select * FROM MaxExecTime;")

tk.MustQuery("select @@global.MAX_EXECUTION_TIME;").Check(testkit.Rows("300"))
tk.MustQuery("select @@MAX_EXECUTION_TIME;").Check(testkit.Rows("150"))

tk.MustExec("set @@global.MAX_EXECUTION_TIME = 0;")
tk.MustExec("set @@MAX_EXECUTION_TIME = 0;")
tk.MustExec("commit")
tk.MustExec("drop table if exists MaxExecTime;")
}

func (s *testSessionSuite) TestGrantViewRelated(c *C) {
tkRoot := testkit.NewTestKitWithInit(c, s.store)
tkUser := testkit.NewTestKitWithInit(c, s.store)
Expand Down
Loading