diff --git a/server/conn.go b/server/conn.go index e2a45b5b225c5..62f9ed08eba43 100644 --- a/server/conn.go +++ b/server/conn.go @@ -107,6 +107,7 @@ type clientConn struct { mu struct { sync.RWMutex cancelFunc context.CancelFunc + resultSets []ResultSet } } @@ -883,6 +884,15 @@ func (cc *clientConn) handleQuery(ctx context.Context, sql string) (err error) { metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err)).Inc() return errors.Trace(err) } + cc.mu.Lock() + cc.mu.resultSets = rs + status := atomic.LoadInt32(&cc.status) + if status == connStatusShutdown || status == connStatusWaitShutdown { + cc.mu.Unlock() + killConn(cc) + return errors.New("killed by another connection") + } + cc.mu.Unlock() if rs != nil { if len(rs) == 1 { err = cc.writeResultset(ctx, rs[0], false, 0, 0) diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 20991846579e1..d5490b8780e38 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -16,7 +16,6 @@ package server import ( "crypto/tls" "fmt" - "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/auth" @@ -31,6 +30,7 @@ import ( "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec" "golang.org/x/net/context" + "sync/atomic" ) // TiDBDriver implements IDriver. @@ -336,7 +336,7 @@ type tidbResultSet struct { recordSet sqlexec.RecordSet columns []*ColumnInfo rows []chunk.Row - closed bool + closed int32 } func (trs *tidbResultSet) NewChunk() *chunk.Chunk { @@ -359,10 +359,9 @@ func (trs *tidbResultSet) GetFetchedRows() []chunk.Row { } func (trs *tidbResultSet) Close() error { - if trs.closed { + if !atomic.CompareAndSwapInt32(&trs.closed, 0, 1) { return nil } - trs.closed = true return trs.recordSet.Close() } diff --git a/server/server.go b/server/server.go index 495c96390c0f2..48f9fff6c8a67 100644 --- a/server/server.go +++ b/server/server.go @@ -351,19 +351,25 @@ func (s *Server) Kill(connectionID uint64, query bool) { return } - killConn(conn, query) -} - -func killConn(conn *clientConn, query bool) { if !query { // Mark the client connection status as WaitShutdown, when the goroutine detect // this, it will end the dispatch loop and exit. atomic.StoreInt32(&conn.status, connStatusWaitShutdown) } + killConn(conn) +} +func killConn(conn *clientConn) { conn.mu.RLock() + resultSets := conn.mu.resultSets cancelFunc := conn.mu.cancelFunc conn.mu.RUnlock() + for _, resultSet := range resultSets { + // resultSet.Close() is reentrant so it's safe to kill a same connID multiple times + if err := resultSet.Close(); err != nil { + logutil.Logger(context.Background()).Error("close result set error", zap.Uint32("connID", conn.connectionID), zap.Error(err)) + } + } if cancelFunc != nil { cancelFunc() } @@ -378,12 +384,7 @@ func (s *Server) KillAllConnections() { for _, conn := range s.clients { atomic.StoreInt32(&conn.status, connStatusShutdown) terror.Log(errors.Trace(conn.closeWithoutLock())) - conn.mu.RLock() - cancelFunc := conn.mu.cancelFunc - conn.mu.RUnlock() - if cancelFunc != nil { - cancelFunc() - } + killConn(conn) } }