Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: add a reference count for StmtCtx #39368

Merged
merged 21 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,7 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu
Info: sql,
CurTxnStartTS: curTxnStartTS,
StmtCtx: s.sessionVars.StmtCtx,
RefCountOfStmtCtx: &s.sessionVars.RefCountOfStmtCtx,
MemTracker: s.sessionVars.MemTracker,
DiskTracker: s.sessionVars.DiskTracker,
StatsInfo: plannercore.GetStatsInfo,
Expand Down
36 changes: 36 additions & 0 deletions sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,42 @@ func (warn *SQLWarn) UnmarshalJSON(data []byte) error {
return nil
}

// ReferenceCount indicates the reference count of StmtCtx.
type ReferenceCount int32

const (
// ReferenceCountIsFrozen indicates the current StmtCtx is resetting, it'll refuse all the access from other sessions.
ReferenceCountIsFrozen int32 = -1
// ReferenceCountNoReference indicates the current StmtCtx is not accessed by other sessions.
ReferenceCountNoReference int32 = 0
)

// TryIncrease tries to increase the reference count.
// There is a small chance that TryIncrease returns true while TryFreeze and
// UnFreeze are invoked successfully during the execution of TryIncrease.
func (rf *ReferenceCount) TryIncrease() bool {
refCnt := atomic.LoadInt32((*int32)(rf))
for ; refCnt != ReferenceCountIsFrozen && !atomic.CompareAndSwapInt32((*int32)(rf), refCnt, refCnt+1); refCnt = atomic.LoadInt32((*int32)(rf)) {
}
return refCnt != ReferenceCountIsFrozen
}

// Decrease decreases the reference count.
func (rf *ReferenceCount) Decrease() {
for refCnt := atomic.LoadInt32((*int32)(rf)); !atomic.CompareAndSwapInt32((*int32)(rf), refCnt, refCnt-1); refCnt = atomic.LoadInt32((*int32)(rf)) {
}
}

// TryFreeze tries to freeze the StmtCtx to frozen before resetting the old StmtCtx.
func (rf *ReferenceCount) TryFreeze() bool {
return atomic.LoadInt32((*int32)(rf)) == ReferenceCountNoReference && atomic.CompareAndSwapInt32((*int32)(rf), ReferenceCountNoReference, ReferenceCountIsFrozen)
}

// UnFreeze unfreeze the frozen StmtCtx thus the other session can access this StmtCtx.
func (rf *ReferenceCount) UnFreeze() {
atomic.StoreInt32((*int32)(rf), ReferenceCountNoReference)
}

// StatementContext contains variables for a statement.
// It should be reset before executing a statement.
type StatementContext struct {
Expand Down
12 changes: 11 additions & 1 deletion sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,11 @@ type SessionVars struct {
// StmtCtx holds variables for current executing statement.
StmtCtx *stmtctx.StatementContext

// RefCountOfStmtCtx indicates the reference count of StmtCtx. When the
// StmtCtx is accessed by other sessions, e.g. oom-alarm-handler/expensive-query-handler, add one first.
// Note: this variable should be accessed and updated by atomic operations.
RefCountOfStmtCtx stmtctx.ReferenceCount

// AllowAggPushDown can be set to false to forbid aggregation push down.
AllowAggPushDown bool

Expand Down Expand Up @@ -1389,7 +1394,12 @@ func (s *SessionVars) InitStatementContext() *stmtctx.StatementContext {
if sc == s.StmtCtx {
sc = &s.cachedStmtCtx[1]
}
*sc = stmtctx.StatementContext{}
if s.RefCountOfStmtCtx.TryFreeze() {
*sc = stmtctx.StatementContext{}
s.RefCountOfStmtCtx.UnFreeze()
} else {
sc = &stmtctx.StatementContext{}
}
return sc
}

Expand Down
6 changes: 5 additions & 1 deletion util/expensivequery/expensivequery.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,9 @@ func (eqh *Handle) LogOnQueryExceedMemQuota(connID uint64) {

// logExpensiveQuery logs the queries which exceed the time threshold or memory threshold.
func logExpensiveQuery(costTime time.Duration, info *util.ProcessInfo, msg string) {
logutil.BgLogger().Warn(msg, util.GenLogFields(costTime, info, true)...)
fields := util.GenLogFields(costTime, info, true)
if fields == nil {
return
}
logutil.BgLogger().Warn(msg, fields...)
}
10 changes: 6 additions & 4 deletions util/memoryusagealarm/memoryusagealarm.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,16 @@ func (record *memoryUsageAlarm) printTop10SqlInfo(pinfo []*util.ProcessInfo, f *
func (record *memoryUsageAlarm) getTop10SqlInfo(cmp func(i, j *util.ProcessInfo) bool, pinfo []*util.ProcessInfo) strings.Builder {
slices.SortFunc(pinfo, cmp)
list := pinfo
if len(list) > 10 {
list = list[:10]
}
var buf strings.Builder
oomAction := variable.OOMAction.Load()
serverMemoryLimit := memory.ServerMemoryLimit.Load()
for i, info := range list {
for i, totalCnt := 0, 10; i < len(list) && totalCnt > 0; i++ {
info := list[i]
buf.WriteString(fmt.Sprintf("SQL %v: \n", i))
fields := util.GenLogFields(record.lastCheckTime.Sub(info.Time), info, false)
if fields == nil {
continue
}
fields = append(fields, zap.String("tidb_mem_oom_action", oomAction))
fields = append(fields, zap.Uint64("tidb_server_memory_limit", serverMemoryLimit))
fields = append(fields, zap.Int64("tidb_mem_quota_query", info.OOMAlarmVariablesInfo.SessionMemQuotaQuery))
Expand All @@ -294,6 +295,7 @@ func (record *memoryUsageAlarm) getTop10SqlInfo(cmp func(i, j *util.ProcessInfo)
}
buf.WriteString("\n")
}
totalCnt--
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't need to update info here?

}
buf.WriteString("\n")
return buf
Expand Down
2 changes: 2 additions & 0 deletions util/memoryusagealarm/memoryusagealarm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,14 @@ func genMockProcessInfoList(memConsumeList []int64, startTimeList []time.Time, s
for i := 0; i < size; i++ {
tracker := memory.NewTracker(0, 0)
tracker.Consume(memConsumeList[i])
var stmtCtxRefCount stmtctx.ReferenceCount = 0
processInfo := util.ProcessInfo{Time: startTimeList[i],
StmtCtx: &stmtctx.StatementContext{},
MemTracker: tracker,
StatsInfo: func(interface{}) map[string]uint64 {
return map[string]uint64{}
},
RefCountOfStmtCtx: &stmtCtxRefCount,
}
processInfoList = append(processInfoList, &processInfo)
}
Expand Down
1 change: 1 addition & 0 deletions util/processinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type ProcessInfo struct {
Time time.Time
Plan interface{}
StmtCtx *stmtctx.StatementContext
RefCountOfStmtCtx *stmtctx.ReferenceCount
MemTracker *memory.Tracker
DiskTracker *disk.Tracker
StatsInfo func(interface{}) map[string]uint64
Expand Down
5 changes: 5 additions & 0 deletions util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ func Str2Int64Map(str string) map[int64]struct{} {

// GenLogFields generate log fields.
func GenLogFields(costTime time.Duration, info *ProcessInfo, needTruncateSQL bool) []zap.Field {
if info.RefCountOfStmtCtx != nil && !info.RefCountOfStmtCtx.TryIncrease() {
return nil
}
defer info.RefCountOfStmtCtx.Decrease()

logFields := make([]zap.Field, 0, 20)
logFields = append(logFields, zap.String("cost_time", strconv.FormatFloat(costTime.Seconds(), 'f', -1, 64)+"s"))
execDetail := info.StmtCtx.GetExecDetails()
Expand Down
8 changes: 5 additions & 3 deletions util/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func TestLogFormat(t *testing.T) {
mem.Consume(1<<30 + 1<<29 + 1<<28 + 1<<27)
mockTooLongQuery := make([]byte, 1024*9)

var refCount stmtctx.ReferenceCount = 0
info := &ProcessInfo{
ID: 233,
User: "PingCAP",
Expand All @@ -38,9 +39,10 @@ func TestLogFormat(t *testing.T) {
StatsInfo: func(interface{}) map[string]uint64 {
return nil
},
StmtCtx: &stmtctx.StatementContext{},
MemTracker: mem,
RedactSQL: false,
StmtCtx: &stmtctx.StatementContext{},
RefCountOfStmtCtx: &refCount,
MemTracker: mem,
RedactSQL: false,
}
costTime := time.Second * 233
logSQLTruncateLen := 1024 * 8
Expand Down