Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: store the hints of session variable (#45814) #46046

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions planner/core/plan_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ func getPointQueryPlan(stmt *ast.Prepared, sessVars *variable.SessionVars, stmtC
}
sessVars.FoundInPlanCache = true
stmtCtx.PointExec = true
if pointGetPlan, ok := plan.(*PointGetPlan); ok && pointGetPlan != nil && pointGetPlan.stmtHints != nil {
sessVars.StmtCtx.StmtHints = *pointGetPlan.stmtHints
}
return plan, names, true, nil
}

Expand Down Expand Up @@ -251,6 +254,7 @@ func getGeneralPlan(sctx sessionctx.Context, isGeneralPlanCache bool, cacheKey k
planCacheCounter.Inc()
}
stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest)
stmtCtx.StmtHints = *cachedVal.stmtHints
return cachedVal.Plan, cachedVal.OutPutNames, true, nil
}

Expand Down Expand Up @@ -289,7 +293,7 @@ func generateNewPlan(ctx context.Context, sctx sessionctx.Context, isGeneralPlan
}
sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{}
}
cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, paramTypes)
cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, paramTypes, &stmtCtx.StmtHints)
stmt.NormalizedPlan, stmt.PlanDigest = NormalizePlan(p)
stmtCtx.SetPlan(p)
stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest)
Expand Down Expand Up @@ -687,12 +691,15 @@ func tryCachePointPlan(_ context.Context, sctx sessionctx.Context,
names types.NameSlice
)

if _, _ok := p.(*PointGetPlan); _ok {
if plan, _ok := p.(*PointGetPlan); _ok {
ok, err = IsPointGetWithPKOrUniqueKeyByAutoCommit(sctx, p)
names = p.OutputNames()
if err != nil {
return err
}
if ok {
plan.stmtHints = sctx.GetSessionVars().StmtCtx.StmtHints.Clone()
}
}

if ok {
Expand Down
6 changes: 5 additions & 1 deletion planner/core/plan_cache_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
driver "github.com/pingcap/tidb/types/parser_driver"
Expand Down Expand Up @@ -348,6 +349,8 @@ type PlanCacheValue struct {
TblInfo2UnionScan map[*model.TableInfo]bool
ParamTypes FieldSlice
memoryUsage int64
// stmtHints stores the hints which set session variables, because the hints won't be processed using cached plan.
stmtHints *stmtctx.StmtHints
}

func (v *PlanCacheValue) varTypesUnchanged(txtVarTps []*types.FieldType) bool {
Expand Down Expand Up @@ -395,7 +398,7 @@ func (v *PlanCacheValue) MemoryUsage() (sum int64) {

// NewPlanCacheValue creates a SQLCacheValue.
func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.TableInfo]bool,
paramTypes []*types.FieldType) *PlanCacheValue {
paramTypes []*types.FieldType, stmtHints *stmtctx.StmtHints) *PlanCacheValue {
dstMap := make(map[*model.TableInfo]bool)
for k, v := range srcMap {
dstMap[k] = v
Expand All @@ -409,6 +412,7 @@ func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.Ta
OutPutNames: names,
TblInfo2UnionScan: dstMap,
ParamTypes: userParamTypes,
stmtHints: stmtHints.Clone(),
}
}

Expand Down
2 changes: 2 additions & 0 deletions planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ type PointGetPlan struct {
// probeParents records the IndexJoins and Applys with this operator in their inner children.
// Please see comments in PhysicalPlan for details.
probeParents []PhysicalPlan
// stmtHints should restore in executing context.
stmtHints *stmtctx.StmtHints
}

func (p *PointGetPlan) getEstRowCountForDisplay() float64 {
Expand Down
1 change: 1 addition & 0 deletions session/session_test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ go_test(
"//privilege/privileges",
"//session",
"//sessionctx",
"//sessionctx/stmtctx",
"//sessionctx/variable",
"//store/copr",
"//store/mockstore",
Expand Down
61 changes: 61 additions & 0 deletions session/session_test/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
"github.com/pingcap/tidb/privilege/privileges"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/store/copr"
"github.com/pingcap/tidb/store/mockstore"
Expand Down Expand Up @@ -4139,3 +4140,63 @@ func TestSQLModeOp(t *testing.T) {
a = mysql.SetSQLMode(s, mysql.ModeAllowInvalidDates)
require.Equal(t, mysql.ModeNoBackslashEscapes|mysql.ModeOnlyFullGroupBy|mysql.ModeAllowInvalidDates, a)
}

func TestPrepareExecuteWithSQLHints(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
se := tk.Session()
se.SetConnectionID(1)
tk.MustExec("use test")
tk.MustExec("create table t(a int primary key)")

type hintCheck struct {
hint string
check func(*stmtctx.StmtHints)
}

hintChecks := []hintCheck{
{
hint: "MEMORY_QUOTA(1024 MB)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasMemQuotaHint)
require.Equal(t, int64(1024*1024*1024), stmtHint.MemQuotaQuery)
},
},
{
hint: "READ_CONSISTENT_REPLICA()",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasReplicaReadHint)
require.Equal(t, byte(kv.ReplicaReadFollower), stmtHint.ReplicaRead)
},
},
{
hint: "MAX_EXECUTION_TIME(1000)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasMaxExecutionTime)
require.Equal(t, uint64(1000), stmtHint.MaxExecutionTime)
},
},
{
hint: "USE_TOJA(TRUE)",
check: func(stmtHint *stmtctx.StmtHints) {
require.True(t, stmtHint.HasAllowInSubqToJoinAndAggHint)
require.True(t, stmtHint.AllowInSubqToJoinAndAgg)
},
},
}

for i, check := range hintChecks {
// common path
tk.MustExec(fmt.Sprintf("prepare stmt%d from 'select /*+ %s */ * from t'", i, check.hint))
for j := 0; j < 10; j++ {
tk.MustQuery(fmt.Sprintf("execute stmt%d", i))
check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints)
}
// fast path
tk.MustExec(fmt.Sprintf("prepare fast%d from 'select /*+ %s */ * from t where a = 1'", i, check.hint))
for j := 0; j < 10; j++ {
tk.MustQuery(fmt.Sprintf("execute fast%d", i))
check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints)
}
}
}
2 changes: 1 addition & 1 deletion sessionctx/stmtctx/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ go_test(
],
embed = [":stmtctx"],
flaky = True,
shard_count = 5,
shard_count = 6,
deps = [
"//kv",
"//sessionctx/variable",
Expand Down
36 changes: 35 additions & 1 deletion sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ type StatementContext struct {
type StmtHints struct {
// Hint Information
MemQuotaQuery int64
ApplyCacheCapacity int64
MaxExecutionTime uint64
ReplicaRead byte
AllowInSubqToJoinAndAgg bool
Expand Down Expand Up @@ -418,6 +417,41 @@ func (sh *StmtHints) TaskMapNeedBackUp() bool {
return sh.ForceNthPlan != -1
}

// Clone the StmtHints struct and returns the pointer of the new one.
func (sh *StmtHints) Clone() *StmtHints {
var (
vars map[string]string
tableHints []*ast.TableOptimizerHint
)
if len(sh.SetVars) > 0 {
vars = make(map[string]string, len(sh.SetVars))
for k, v := range sh.SetVars {
vars[k] = v
}
}
if len(sh.OriginalTableHints) > 0 {
tableHints = make([]*ast.TableOptimizerHint, len(sh.OriginalTableHints))
copy(tableHints, sh.OriginalTableHints)
}
return &StmtHints{
MemQuotaQuery: sh.MemQuotaQuery,
MaxExecutionTime: sh.MaxExecutionTime,
ReplicaRead: sh.ReplicaRead,
AllowInSubqToJoinAndAgg: sh.AllowInSubqToJoinAndAgg,
NoIndexMergeHint: sh.NoIndexMergeHint,
StraightJoinOrder: sh.StraightJoinOrder,
EnableCascadesPlanner: sh.EnableCascadesPlanner,
ForceNthPlan: sh.ForceNthPlan,
HasAllowInSubqToJoinAndAggHint: sh.HasAllowInSubqToJoinAndAggHint,
HasMemQuotaHint: sh.HasMemQuotaHint,
HasReplicaReadHint: sh.HasReplicaReadHint,
HasMaxExecutionTime: sh.HasMaxExecutionTime,
HasEnableCascadesPlannerHint: sh.HasEnableCascadesPlannerHint,
SetVars: vars,
OriginalTableHints: tableHints,
}
}

// StmtCacheKey represents the key type in the StmtCache.
type StmtCacheKey int

Expand Down
23 changes: 23 additions & 0 deletions sessionctx/stmtctx/stmtctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"fmt"
"math/rand"
"reflect"
"sort"
"testing"
"time"
Expand Down Expand Up @@ -272,3 +273,25 @@ func TestApproxRuntimeInfo(t *testing.T) {
require.Equal(t, d.TotBackoffTime[backoff], timeSum)
}
}

func TestStmtHintsClone(t *testing.T) {
hints := stmtctx.StmtHints{}
value := reflect.ValueOf(&hints).Elem()
for i := 0; i < value.NumField(); i++ {
field := value.Field(i)
switch field.Kind() {
case reflect.Int, reflect.Int32, reflect.Int64:
field.SetInt(1)
case reflect.Uint, reflect.Uint32, reflect.Uint64:
field.SetUint(1)
case reflect.Uint8: // byte
field.SetUint(1)
case reflect.Bool:
field.SetBool(true)
case reflect.String:
field.SetString("test")
default:
}
}
require.Equal(t, hints, *hints.Clone())
}