diff --git a/executor/adapter.go b/executor/adapter.go
index 7f3c7417d91d4..e6c0d2af64d2f 100644
--- a/executor/adapter.go
+++ b/executor/adapter.go
@@ -106,7 +106,7 @@ func (a *recordSet) Next(ctx context.Context, req *chunk.RecordBatch) error {
 		defer span1.Finish()
 	}

-	err := a.executor.Next(ctx, req)
+	err := Next(ctx, a.executor, req)
 	if err != nil {
 		a.lastErr = err
 		return err
@@ -385,7 +385,7 @@ func (a *ExecStmt) handleNoDelayExecutor(ctx context.Context, e Executor) (sqlex
 		a.logAudit()
 	}()

-	err = e.Next(ctx, chunk.NewRecordBatch(newFirstChunk(e)))
+	err = Next(ctx, e, chunk.NewRecordBatch(newFirstChunk(e)))
 	if err != nil {
 		return nil, err
 	}
diff --git a/executor/aggregate.go b/executor/aggregate.go
index dda7d395599b3..ad3d224b7c6a9 100644
--- a/executor/aggregate.go
+++ b/executor/aggregate.go
@@ -555,7 +555,7 @@ func (e *HashAggExec) fetchChildData(ctx context.Context) {
 			}
 			chk = input.chk
 		}
-		err = e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
+		err = Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
 		if err != nil {
 			e.finalOutputCh <- &AfFinalResult{err: err}
 			return
@@ -681,7 +681,7 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) erro
 func (e *HashAggExec) execute(ctx context.Context) (err error) {
 	inputIter := chunk.NewIterator4Chunk(e.childResult)
 	for {
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
 		if err != nil {
 			return err
 		}
@@ -870,7 +870,7 @@ func (e *StreamAggExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Ch
 		return err
 	}

-	err = e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
+	err = Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
 	if err != nil {
 		return err
 	}
diff --git a/executor/delete.go b/executor/delete.go
index ae1620dfe14e0..f198861dfe99d 100644
--- a/executor/delete.go
+++ b/executor/delete.go
@@ -105,7 +105,7 @@ func (e *DeleteExec) deleteSingleTableByChunk(ctx context.Context) error {

 	for {
 		iter := chunk.NewIterator4Chunk(chk)
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
 		if err != nil {
 			return err
 		}
@@ -187,7 +187,7 @@ func (e *DeleteExec) deleteMultiTablesByChunk(ctx context.Context) error {
 	chk := newFirstChunk(e.children[0])
 	for {
 		iter := chunk.NewIterator4Chunk(chk)
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
 		if err != nil {
 			return err
 		}
diff --git a/executor/errors.go b/executor/errors.go
index c1d426ef71535..a48152f0acdfe 100644
--- a/executor/errors.go
+++ b/executor/errors.go
@@ -52,6 +52,7 @@ var (
 	ErrWrongObject = terror.ClassExecutor.New(mysql.ErrWrongObject, mysql.MySQLErrName[mysql.ErrWrongObject])
 	ErrRoleNotGranted = terror.ClassPrivilege.New(mysql.ErrRoleNotGranted, mysql.MySQLErrName[mysql.ErrRoleNotGranted])
 	ErrDeadlock = terror.ClassExecutor.New(mysql.ErrLockDeadlock, mysql.MySQLErrName[mysql.ErrLockDeadlock])
+	ErrQueryInterrupted = terror.ClassExecutor.New(mysql.ErrQueryInterrupted, mysql.MySQLErrName[mysql.ErrQueryInterrupted])
 )

 func init() {
@@ -69,6 +70,7 @@ func init() {
 		mysql.ErrBadDB:        mysql.ErrBadDB,
 		mysql.ErrWrongObject:  mysql.ErrWrongObject,
 		mysql.ErrLockDeadlock: mysql.ErrLockDeadlock,
+		mysql.ErrQueryInterrupted: mysql.ErrQueryInterrupted,
 	}
 	terror.ErrClassToMySQLCodes[terror.ClassExecutor] = tableMySQLErrCodes
 }
diff --git a/executor/executor.go b/executor/executor.go
index 9991a2d7343a3..4481657de1443 100644
--- a/executor/executor.go
+++ b/executor/executor.go
@@ -180,6 +180,16 @@ type Executor interface {
 	Schema() *expression.Schema
 }

+// Next is a wrapper function on e.Next(), it handles some common codes.
+func Next(ctx context.Context, e Executor, req *chunk.RecordBatch) error {
+	sessVars := e.base().ctx.GetSessionVars()
+	if atomic.CompareAndSwapUint32(&sessVars.Killed, 1, 0) {
+		return ErrQueryInterrupted
+	}
+
+	return e.Next(ctx, req)
+}
+
 // CancelDDLJobsExec represents a cancel DDL jobs executor.
 type CancelDDLJobsExec struct {
 	baseExecutor
@@ -559,7 +569,7 @@ func (e *CheckIndexExec) Next(ctx context.Context, req *chunk.RecordBatch) error
 	}
 	chk := newFirstChunk(e.src)
 	for {
-		err := e.src.Next(ctx, chunk.NewRecordBatch(chk))
+		err := Next(ctx, e.src, chunk.NewRecordBatch(chk))
 		if err != nil {
 			return err
 		}
@@ -668,7 +678,7 @@ func (e *SelectLockExec) Next(ctx context.Context, req *chunk.RecordBatch) error
 	}
 	req.GrowAndReset(e.maxChunkSize)

-	err := e.children[0].Next(ctx, req)
+	err := Next(ctx, e.children[0], req)
 	if err != nil {
 		return err
 	}
@@ -728,7 +738,7 @@ func (e *LimitExec) Next(ctx context.Context, req *chunk.RecordBatch) error {
 	for !e.meetFirstBatch {
 		// transfer req's requiredRows to childResult and then adjust it in childResult
 		e.childResult = e.childResult.SetRequiredRows(req.RequiredRows(), e.maxChunkSize)
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.adjustRequiredRows(e.childResult)))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.adjustRequiredRows(e.childResult)))
 		if err != nil {
 			return err
 		}
@@ -753,7 +763,7 @@ func (e *LimitExec) Next(ctx context.Context, req *chunk.RecordBatch) error {
 		e.cursor += batchSize
 	}
 	e.adjustRequiredRows(req.Chunk)
-	err := e.children[0].Next(ctx, req)
+	err := Next(ctx, e.children[0], req)
 	if err != nil {
 		return err
 	}
@@ -823,7 +833,7 @@ func init() {
 		}
 		chk := newFirstChunk(exec)
 		for {
-			err = exec.Next(ctx, chunk.NewRecordBatch(chk))
+			err = Next(ctx, exec, chunk.NewRecordBatch(chk))
 			if err != nil {
 				return rows, err
 			}
@@ -940,7 +950,7 @@ func (e *SelectionExec) Next(ctx context.Context, req *chunk.RecordBatch) error
 			}
 			req.AppendRow(e.inputRow)
 		}
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
 		if err != nil {
 			return err
 		}
@@ -972,7 +982,7 @@ func (e *SelectionExec) unBatchedNext(ctx context.Context, chk *chunk.Chunk) err
 				return nil
 			}
 		}
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
 		if err != nil {
 			return err
 		}
@@ -1120,7 +1130,7 @@ func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.RecordBatch) error
 		return nil
 	}
 	e.evaluated = true
-	err := e.children[0].Next(ctx, req)
+	err := Next(ctx, e.children[0], req)
 	if err != nil {
 		return err
 	}
@@ -1135,7 +1145,7 @@ func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.RecordBatch) error
 	}

 	childChunk := newFirstChunk(e.children[0])
-	err = e.children[0].Next(ctx, chunk.NewRecordBatch(childChunk))
+	err = Next(ctx, e.children[0], chunk.NewRecordBatch(childChunk))
 	if err != nil {
 		return err
 	}
@@ -1246,7 +1256,7 @@ func (e *UnionExec) resultPuller(ctx context.Context, childID int) {
 			return
 		case result.chk = <-e.resourcePools[childID]:
 		}
-		result.err = e.children[childID].Next(ctx, chunk.NewRecordBatch(result.chk))
+		result.err = Next(ctx, e.children[childID], chunk.NewRecordBatch(result.chk))
 		if result.err == nil && result.chk.NumRows() == 0 {
 			return
 		}
diff --git a/executor/index_lookup_join.go b/executor/index_lookup_join.go
index ac8d82c4b42d6..f658522dccb0f 100644
--- a/executor/index_lookup_join.go
+++ b/executor/index_lookup_join.go
@@ -386,7 +386,7 @@ func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) {
 	task.memTracker.Consume(task.outerResult.MemoryUsage())

 	for !task.outerResult.IsFull() {
-		err := ow.executor.Next(ctx, chunk.NewRecordBatch(ow.executorChk))
+		err := Next(ctx, ow.executor, chunk.NewRecordBatch(ow.executorChk))
 		if err != nil {
 			return task, err
 		}
@@ -586,7 +586,7 @@ func (iw *innerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTa
 	innerResult.GetMemTracker().SetLabel(innerResultLabel)
 	innerResult.GetMemTracker().AttachTo(task.memTracker)
 	for {
-		err := innerExec.Next(ctx, chunk.NewRecordBatch(iw.executorChk))
+		err := Next(ctx, innerExec, chunk.NewRecordBatch(iw.executorChk))
 		if err != nil {
 			return err
 		}
diff --git a/executor/join.go b/executor/join.go
index 37b1b6991c39f..47148c833acc1 100644
--- a/executor/join.go
+++ b/executor/join.go
@@ -202,6 +202,7 @@ func (e *HashJoinExec) fetchOuterChunks(ctx context.Context) {
 		if e.finished.Load().(bool) {
 			return
 		}
+
 		var outerResource *outerChkResource
 		var ok bool
 		select {
@@ -217,7 +218,7 @@ func (e *HashJoinExec) fetchOuterChunks(ctx context.Context) {
 			required := int(atomic.LoadInt64(&e.requiredRows))
 			outerResult.SetRequiredRows(required, e.maxChunkSize)
 		}
-		err := e.outerExec.Next(ctx, chunk.NewRecordBatch(outerResult))
+		err := Next(ctx, e.outerExec, chunk.NewRecordBatch(outerResult))
 		if err != nil {
 			e.joinResultCh <- &hashjoinWorkerResult{
 				err: err,
@@ -244,6 +245,7 @@ func (e *HashJoinExec) fetchOuterChunks(ctx context.Context) {
 		if outerResult.NumRows() == 0 {
 			return
 		}
+
 		outerResource.dest <- outerResult
 	}
 }
@@ -276,8 +278,9 @@ func (e *HashJoinExec) fetchInnerRows(ctx context.Context) error {
 		if e.finished.Load().(bool) {
 			return nil
 		}
+
 		chk := newFirstChunk(e.children[e.innerIdx])
-		err = e.innerExec.Next(ctx, chunk.NewRecordBatch(chk))
+		err = Next(ctx, e.innerExec, chunk.NewRecordBatch(chk))
 		if err != nil || chk.NumRows() == 0 {
 			return err
 		}
@@ -512,6 +515,7 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.RecordBatch) (err er
 	if e.joinResultCh == nil {
 		return nil
 	}
+
 	result, ok := <-e.joinResultCh
 	if !ok {
 		return nil
@@ -642,7 +646,7 @@ func (e *NestedLoopApplyExec) fetchSelectedOuterRow(ctx context.Context, chk *ch
 	outerIter := chunk.NewIterator4Chunk(e.outerChunk)
 	for {
 		if e.outerChunkCursor >= e.outerChunk.NumRows() {
-			err := e.outerExec.Next(ctx, chunk.NewRecordBatch(e.outerChunk))
+			err := Next(ctx, e.outerExec, chunk.NewRecordBatch(e.outerChunk))
 			if err != nil {
 				return nil, err
 			}
@@ -679,7 +683,7 @@ func (e *NestedLoopApplyExec) fetchAllInners(ctx context.Context) error {
 	e.innerList.Reset()
 	innerIter := chunk.NewIterator4Chunk(e.innerChunk)
 	for {
-		err := e.innerExec.Next(ctx, chunk.NewRecordBatch(e.innerChunk))
+		err := Next(ctx, e.innerExec, chunk.NewRecordBatch(e.innerChunk))
 		if err != nil {
 			return err
 		}
diff --git a/executor/merge_join.go b/executor/merge_join.go
index b972c607ef1de..bc6f597a9325d 100644
--- a/executor/merge_join.go
+++ b/executor/merge_join.go
@@ -142,7 +142,7 @@ func (t *mergeJoinInnerTable) nextRow() (chunk.Row, error) {
 		if t.curRow == t.curIter.End() {
 			t.reallocReaderResult()
 			oldMemUsage := t.curResult.MemoryUsage()
-			err := t.reader.Next(t.ctx, chunk.NewRecordBatch(t.curResult))
+			err := Next(t.ctx, t.reader, chunk.NewRecordBatch(t.curResult))
 			// error happens or no more data.
 			if err != nil || t.curResult.NumRows() == 0 {
 				t.curRow = t.curIter.End()
@@ -389,7 +389,7 @@ func (e *MergeJoinExec) fetchNextOuterRows(ctx context.Context, requiredRows int
 		e.outerTable.chk.SetRequiredRows(requiredRows, e.maxChunkSize)
 	}

-	err = e.outerTable.reader.Next(ctx, chunk.NewRecordBatch(e.outerTable.chk))
+	err = Next(ctx, e.outerTable.reader, chunk.NewRecordBatch(e.outerTable.chk))
 	if err != nil {
 		return err
 	}
diff --git a/executor/projection.go b/executor/projection.go
index e22b22d3e4fc3..f316558524044 100644
--- a/executor/projection.go
+++ b/executor/projection.go
@@ -179,7 +179,7 @@ func (e *ProjectionExec) isUnparallelExec() bool {
 func (e *ProjectionExec) unParallelExecute(ctx context.Context, chk *chunk.Chunk) error {
 	// transmit the requiredRows
 	e.childResult.SetRequiredRows(chk.RequiredRows(), e.maxChunkSize)
-	err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult))
+	err := Next(ctx, e.children[0], chunk.NewRecordBatch(e.childResult))
 	if err != nil {
 		return err
 	}
@@ -312,7 +312,7 @@ func (f *projectionInputFetcher) run(ctx context.Context) {
 		requiredRows := atomic.LoadInt64(&f.proj.parentReqRows)
 		input.chk.SetRequiredRows(int(requiredRows), f.proj.maxChunkSize)

-		err := f.child.Next(ctx, chunk.NewRecordBatch(input.chk))
+		err := Next(ctx, f.child, chunk.NewRecordBatch(input.chk))
 		if err != nil || input.chk.NumRows() == 0 {
 			output.done <- err
 			return
diff --git a/executor/sort.go b/executor/sort.go
index 3e41a1b1b4aac..4d4ce8a3d22d6 100644
--- a/executor/sort.go
+++ b/executor/sort.go
@@ -111,7 +111,7 @@ func (e *SortExec) fetchRowChunks(ctx context.Context) error {
 	e.rowChunks.GetMemTracker().SetLabel(rowChunksLabel)
 	for {
 		chk := newFirstChunk(e.children[0])
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
 		if err != nil {
 			return err
 		}
@@ -282,7 +282,7 @@ func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error {
 		srcChk := newFirstChunk(e.children[0])
 		// adjust required rows by total limit
 		srcChk.SetRequiredRows(int(e.totalLimit-uint64(e.rowChunks.Len())), e.maxChunkSize)
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(srcChk))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(srcChk))
 		if err != nil {
 			return err
 		}
@@ -307,7 +307,7 @@ func (e *TopNExec) executeTopN(ctx context.Context) error {
 	}
 	childRowChk := newFirstChunk(e.children[0])
 	for {
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(childRowChk))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(childRowChk))
 		if err != nil {
 			return err
 		}
diff --git a/executor/union_scan.go b/executor/union_scan.go
index 5a8de698f49ac..2967953149f53 100644
--- a/executor/union_scan.go
+++ b/executor/union_scan.go
@@ -199,7 +199,7 @@ func (us *UnionScanExec) getSnapshotRow(ctx context.Context) ([]types.Datum, err
 	us.cursor4SnapshotRows = 0
 	us.snapshotRows = us.snapshotRows[:0]
 	for len(us.snapshotRows) == 0 {
-		err = us.children[0].Next(ctx, chunk.NewRecordBatch(us.snapshotChunkBuffer))
+		err = Next(ctx, us.children[0], chunk.NewRecordBatch(us.snapshotChunkBuffer))
 		if err != nil || us.snapshotChunkBuffer.NumRows() == 0 {
 			return nil, err
 		}
diff --git a/executor/update.go b/executor/update.go
index f6840f68ff197..9b478cc80d968 100644
--- a/executor/update.go
+++ b/executor/update.go
@@ -181,7 +181,7 @@ func (e *UpdateExec) fetchChunkRows(ctx context.Context) error {
 	chk := newFirstChunk(e.children[0])
 	e.evalBuffer = chunk.MutRowFromTypes(fields)
 	for {
-		err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk))
+		err := Next(ctx, e.children[0], chunk.NewRecordBatch(chk))
 		if err != nil {
 			return err
 		}
diff --git a/executor/window.go b/executor/window.go
index 0b51691f139e0..2ba119736e564 100644
--- a/executor/window.go
+++ b/executor/window.go
@@ -131,7 +131,7 @@ func (e *WindowExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Chunk
 	}

 	childResult := newFirstChunk(e.children[0])
-	err = e.children[0].Next(ctx, &chunk.RecordBatch{Chunk: childResult})
+	err = Next(ctx, e.children[0], &chunk.RecordBatch{Chunk: childResult})
 	if err != nil {
 		return errors.Trace(err)
 	}
diff --git a/server/conn.go b/server/conn.go
index 24d16c0af9b50..290821f758c84 100644
--- a/server/conn.go
+++ b/server/conn.go
@@ -45,7 +45,6 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
-	"sync"
 	"sync/atomic"
 	"time"

@@ -150,13 +149,6 @@ type clientConn struct {
 	peerHost     string            // peer host
 	peerPort     string            // peer port
 	lastCode     uint16            // last error code
-
-	// mu is used for cancelling the execution of current transaction.
-	mu struct {
-		sync.RWMutex
-		cancelFunc context.CancelFunc
-		resultSets []ResultSet
-	}
 }

 func (cc *clientConn) String() string {
@@ -847,11 +839,6 @@ func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) {
 func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
 	span := opentracing.StartSpan("server.dispatch")

-	ctx1, cancelFunc := context.WithCancel(ctx)
-	cc.mu.Lock()
-	cc.mu.cancelFunc = cancelFunc
-	cc.mu.Unlock()
-
 	t := time.Now()
 	cmd := data[0]
 	data = data[1:]
@@ -863,6 +850,8 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
 		span.Finish()
 	}()

+	vars := cc.ctx.GetSessionVars()
+	atomic.StoreUint32(&vars.Killed, 0)
 	if cmd < mysql.ComEnd {
 		cc.ctx.SetCommandValue(cmd)
 	}
@@ -893,11 +882,11 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
 			data = data[:len(data)-1]
 			dataStr = string(hack.String(data))
 		}
-		return cc.handleQuery(ctx1, dataStr)
+		return cc.handleQuery(ctx, dataStr)
 	case mysql.ComPing:
 		return cc.writeOK()
 	case mysql.ComInitDB:
-		if err := cc.useDB(ctx1, dataStr); err != nil {
+		if err := cc.useDB(ctx, dataStr); err != nil {
 			return err
 		}
 		return cc.writeOK()
@@ -906,9 +895,9 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
 	case mysql.ComStmtPrepare:
 		return cc.handleStmtPrepare(dataStr)
 	case mysql.ComStmtExecute:
-		return cc.handleStmtExecute(ctx1, data)
+		return cc.handleStmtExecute(ctx, data)
 	case mysql.ComStmtFetch:
-		return cc.handleStmtFetch(ctx1, data)
+		return cc.handleStmtFetch(ctx, data)
 	case mysql.ComStmtClose:
 		return cc.handleStmtClose(data)
 	case mysql.ComStmtSendLongData:
@@ -918,7 +907,7 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
 	case mysql.ComSetOption:
 		return cc.handleSetOption(data)
 	case mysql.ComChangeUser:
-		return cc.handleChangeUser(ctx1, data)
+		return cc.handleChangeUser(ctx, data)
 	default:
 		return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", cmd)
 	}
@@ -1171,15 +1160,11 @@ func (cc *clientConn) handleQuery(ctx context.Context, sql string) (err error) {
 		metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err)).Inc()
 		return err
 	}
-	cc.mu.Lock()
-	cc.mu.resultSets = rs
 	status := atomic.LoadInt32(&cc.status)
 	if status == connStatusShutdown || status == connStatusWaitShutdown {
-		cc.mu.Unlock()
 		killConn(cc)
 		return errors.New("killed by another connection")
 	}
-	cc.mu.Unlock()
 	if rs != nil {
 		if len(rs) == 1 {
 			err = cc.writeResultset(ctx, rs[0], false, 0, 0)
diff --git a/server/server.go b/server/server.go
index f45cb65e6c7a7..c258f074d0ed0 100644
--- a/server/server.go
+++ b/server/server.go
@@ -530,19 +530,8 @@ func (s *Server) Kill(connectionID uint64, query bool) {
 }

 func killConn(conn *clientConn) {
-	conn.mu.RLock()
-	resultSets := conn.mu.resultSets
-	cancelFunc := conn.mu.cancelFunc
-	conn.mu.RUnlock()
-	for _, resultSet := range resultSets {
-		// resultSet.Close() is reentrant so it's safe to kill a same connID multiple times
-		if err := resultSet.Close(); err != nil {
-			logutil.Logger(context.Background()).Error("close result set error", zap.Uint32("connID", conn.connectionID), zap.Error(err))
-		}
-	}
-	if cancelFunc != nil {
-		cancelFunc()
-	}
+	sessVars := conn.ctx.GetSessionVars()
+	atomic.CompareAndSwapUint32(&sessVars.Killed, 0, 1)
 }

 // KillAllConnections kills all connections when server is not gracefully shutdown.
diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go
index 8a651b0958129..20906cd009cdc 100644
--- a/sessionctx/variable/session.go
+++ b/sessionctx/variable/session.go
@@ -379,6 +379,9 @@ type SessionVars struct {

 	// LowResolutionTSO is used for reading data with low resolution TSO which is updated once every two seconds.
 	LowResolutionTSO bool
+
+	// Killed is a flag to indicate that this query is killed.
+	Killed uint32
 }

 // ConnectionInfo present connection used by audit.
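
For reference, the cancellation pattern this patch introduces can be summarized like this: killing a connection only flips an atomic Killed flag on the session variables, and the shared executor Next wrapper polls (and clears) that flag before fetching each chunk, so a KILL surfaces as ErrQueryInterrupted at the next chunk boundary instead of closing result sets under a mutex. The sketch below is a minimal, standalone illustration of that idea, not code from this patch; the identifiers sessionVars, executor, next, killConn, fakeExec, and errQueryInterrupted are hypothetical stand-ins for the TiDB types touched above.

// killflag_sketch.go -- standalone illustration of the kill-flag pattern; all
// names here are hypothetical stand-ins, not TiDB APIs.
package main

import (
    "context"
    "errors"
    "fmt"
    "sync/atomic"
)

// errQueryInterrupted stands in for executor.ErrQueryInterrupted.
var errQueryInterrupted = errors.New("query execution was interrupted")

// sessionVars mirrors the relevant part of variable.SessionVars.
type sessionVars struct {
    Killed uint32 // set to 1 when a KILL is requested for the current statement
}

// executor mirrors the minimal surface of executor.Executor used by the wrapper.
type executor interface {
    Next(ctx context.Context) error
    vars() *sessionVars
}

// next plays the role of the executor.Next wrapper: poll the kill flag before
// delegating. CompareAndSwap(1, 0) both observes and clears the flag, so one
// KILL aborts only the statement that is currently running.
func next(ctx context.Context, e executor) error {
    if atomic.CompareAndSwapUint32(&e.vars().Killed, 1, 0) {
        return errQueryInterrupted
    }
    return e.Next(ctx)
}

// killConn plays the role of server.killConn: raise the flag and return; the
// executing goroutine notices it at its next chunk boundary.
func killConn(v *sessionVars) {
    atomic.CompareAndSwapUint32(&v.Killed, 0, 1)
}

// fakeExec emits "chunks" one at a time, like a child executor would.
type fakeExec struct {
    sv      *sessionVars
    emitted int
}

func (f *fakeExec) vars() *sessionVars { return f.sv }

func (f *fakeExec) Next(ctx context.Context) error {
    f.emitted++
    fmt.Println("produced chunk", f.emitted)
    return nil
}

func main() {
    sv := &sessionVars{}
    child := &fakeExec{sv: sv}

    for i := 0; i < 5; i++ {
        if i == 2 {
            killConn(sv) // simulate a "KILL QUERY" arriving mid-statement
        }
        if err := next(context.Background(), child); err != nil {
            fmt.Println("aborted:", err) // printed on the third iteration
            return
        }
    }
}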