diff --git a/executor/prepared_test.go b/executor/prepared_test.go index f428d2b9a891c..69c54b26af783 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -18,11 +18,13 @@ import ( "fmt" "strconv" "strings" + "sync/atomic" "testing" "github.com/pingcap/tidb/parser/auth" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/parser/terror" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/server" "github.com/pingcap/tidb/sessionctx/variable" @@ -1251,3 +1253,18 @@ func TestIssue31141(t *testing.T) { tk.MustExec("set @@tidb_txn_mode = 'optimistic'") tk.MustExec("prepare stmt1 from 'do 1'") } + +func TestMaxPreparedStmtCount(t *testing.T) { + oldVal := atomic.LoadInt64(&variable.PreparedStmtCount) + atomic.StoreInt64(&variable.PreparedStmtCount, 0) + defer func() { + atomic.StoreInt64(&variable.PreparedStmtCount, oldVal) + }() + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@global.max_prepared_stmt_count = 2") + tk.MustExec("prepare stmt1 from 'select ? as num from dual'") + tk.MustExec("prepare stmt2 from 'select ? as num from dual'") + err := tk.ExecToErr("prepare stmt3 from 'select ? as num from dual'") + require.True(t, terror.ErrorEqual(err, variable.ErrMaxPreparedStmtCountReached)) +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index cacd814528362..cbed5b3863eaf 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -2064,11 +2064,7 @@ func (s *SessionVars) GetGeneralPlanCacheStmt(sql string) interface{} { // AddPreparedStmt adds prepareStmt to current session and count in global. func (s *SessionVars) AddPreparedStmt(stmtID uint32, stmt interface{}) error { if _, exists := s.PreparedStmts[stmtID]; !exists { - valStr, _ := s.GetSystemVar(MaxPreparedStmtCount) - maxPreparedStmtCount, err := strconv.ParseInt(valStr, 10, 64) - if err != nil { - maxPreparedStmtCount = DefMaxPreparedStmtCount - } + maxPreparedStmtCount := MaxPreparedStmtCountValue.Load() newPreparedStmtCount := atomic.AddInt64(&PreparedStmtCount, 1) if maxPreparedStmtCount >= 0 && newPreparedStmtCount > maxPreparedStmtCount { atomic.AddInt64(&PreparedStmtCount, -1) diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index ad2ba72cdbbb2..dc95e507a7302 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -479,7 +479,15 @@ var defaultSysVars = []*SysVar{ }}, /* The system variables below have GLOBAL scope */ - {Scope: ScopeGlobal, Name: MaxPreparedStmtCount, Value: strconv.FormatInt(DefMaxPreparedStmtCount, 10), Type: TypeInt, MinValue: -1, MaxValue: 1048576}, + {Scope: ScopeGlobal, Name: MaxPreparedStmtCount, Value: strconv.FormatInt(DefMaxPreparedStmtCount, 10), Type: TypeInt, MinValue: -1, MaxValue: 1048576, + SetGlobal: func(_ context.Context, s *SessionVars, val string) error { + num, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return errors.Trace(err) + } + MaxPreparedStmtCountValue.Store(num) + return nil + }}, {Scope: ScopeGlobal, Name: InitConnect, Value: "", Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) { p := parser.New() p.SetSQLMode(vars.SQLMode) diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 48f557a20f1c8..c1f82a93cd86f 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -1234,6 +1234,7 @@ var ( PasswordHistory = atomic.NewInt64(DefPasswordReuseHistory) PasswordReuseInterval = atomic.NewInt64(DefPasswordReuseTime) IsSandBoxModeEnabled = atomic.NewBool(false) + MaxPreparedStmtCountValue = atomic.NewInt64(DefMaxPreparedStmtCount) ) var (