diff --git a/executor/adapter.go b/executor/adapter.go index 31a8c8d20150f..3a4ace6ba582b 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -223,6 +223,11 @@ type ExecStmt struct { Ti *TelemetryInfo } +// GetStmtNode returns the stmtNode inside Statement +func (a ExecStmt) GetStmtNode() ast.StmtNode { + return a.StmtNode +} + // PointGet short path for point exec directly from plan, keep only necessary steps func (a *ExecStmt) PointGet(ctx context.Context, is infoschema.InfoSchema) (*recordSet, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { diff --git a/executor/executor.go b/executor/executor.go index 9bb2cd5789d65..7b370615582cc 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1922,11 +1922,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.NotFillCache = !opts.SQLCache } sc.WeakConsistency = isWeakConsistencyRead(ctx, stmt) - // Try to mark the `RCCheckTS` flag for the first time execution of in-transaction read requests - // using read-consistency isolation level. - if NeedSetRCCheckTSFlag(ctx, stmt) { - sc.RCCheckTS = true - } case *ast.SetOprStmt: sc.InSelectStmt = true sc.OverflowAsWarning = true @@ -2065,13 +2060,3 @@ func isWeakConsistencyRead(ctx sessionctx.Context, node ast.Node) bool { return sessionVars.ConnectionID > 0 && sessionVars.ReadConsistency.IsWeak() && plannercore.IsAutoCommitTxn(ctx) && plannercore.IsReadOnly(node, sessionVars) } - -// NeedSetRCCheckTSFlag checks whether it's needed to set `RCCheckTS` flag in current stmtctx. -func NeedSetRCCheckTSFlag(ctx sessionctx.Context, node ast.Node) bool { - sessionVars := ctx.GetSessionVars() - if sessionVars.ConnectionID > 0 && sessionVars.RcReadCheckTS && sessionVars.InTxn() && - sessionVars.IsPessimisticReadConsistency() && !sessionVars.RetryInfo.Retrying && plannercore.IsReadOnly(node, sessionVars) { - return true - } - return false -} diff --git a/executor/executor_test.go b/executor/executor_test.go index c8e4304f1c22c..fb0b9816ccbf7 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -3470,7 +3470,7 @@ func TestUnreasonablyClose(t *testing.T) { err = sessiontxn.NewTxn(context.Background(), tk.Session()) require.NoError(t, err, comment) - err = sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO()) + err = sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO(), stmt) require.NoError(t, err, comment) executorBuilder := executor.NewMockExecutorBuilderForTest(tk.Session(), is, nil, oracle.GlobalTxnScope) diff --git a/executor/prepared.go b/executor/prepared.go index e6395371c95d5..abe974e899310 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -333,15 +333,11 @@ func (e *DeallocateExec) Next(ctx context.Context, req *chunk.Chunk) error { // CompileExecutePreparedStmt compiles a session Execute command to a stmt.Statement. func CompileExecutePreparedStmt(ctx context.Context, sctx sessionctx.Context, - ID uint32, is infoschema.InfoSchema, snapshotTS uint64, replicaReadScope string, args []types.Datum) (*ExecStmt, bool, bool, error) { + execStmt *ast.ExecuteStmt, is infoschema.InfoSchema, snapshotTS uint64, replicaReadScope string, args []types.Datum) (*ExecStmt, bool, bool, error) { startTime := time.Now() defer func() { sctx.GetSessionVars().DurationCompile = time.Since(startTime) }() - execStmt := &ast.ExecuteStmt{ExecID: ID} - if err := ResetContextOfStmt(sctx, execStmt); err != nil { - return nil, false, false, err - } isStaleness := snapshotTS != 0 sctx.GetSessionVars().StmtCtx.IsStaleness = isStaleness execStmt.BinaryArgs = args @@ -369,7 +365,7 @@ func CompileExecutePreparedStmt(ctx context.Context, sctx sessionctx.Context, Ti: &TelemetryInfo{}, ReplicaReadScope: replicaReadScope, } - if preparedPointer, ok := sctx.GetSessionVars().PreparedStmts[ID]; ok { + if preparedPointer, ok := sctx.GetSessionVars().PreparedStmts[execStmt.ExecID]; ok { preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt) if !ok { return nil, false, false, errors.Errorf("invalid CachedPrepareStmt type") diff --git a/executor/seqtest/prepared_test.go b/executor/seqtest/prepared_test.go index b39f66d3030ee..b4b018e8ac3ee 100644 --- a/executor/seqtest/prepared_test.go +++ b/executor/seqtest/prepared_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/mysql" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" @@ -157,8 +158,9 @@ func TestPrepared(t *testing.T) { require.NoError(t, err) tk.ResultSetToResult(rs, fmt.Sprintf("%v", rs)).Check(testkit.Rows()) + execStmt := &ast.ExecuteStmt{ExecID: stmtID} // Check that ast.Statement created by executor.CompileExecutePreparedStmt has query text. - stmt, _, _, err := executor.CompileExecutePreparedStmt(context.TODO(), tk.Session(), stmtID, + stmt, _, _, err := executor.CompileExecutePreparedStmt(context.TODO(), tk.Session(), execStmt, tk.Session().GetInfoSchema().(infoschema.InfoSchema), 0, kv.GlobalReplicaScope, []types.Datum{types.NewDatum(1)}) require.NoError(t, err) require.Equal(t, query, stmt.OriginText()) diff --git a/go.mod b/go.mod index 9ad9135be4085..8943a2403025a 100644 --- a/go.mod +++ b/go.mod @@ -117,7 +117,7 @@ require ( github.com/benbjohnson/clock v1.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect - github.com/cockroachdb/errors v1.8.1 // indirect + github.com/cockroachdb/errors v1.8.1 github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f // indirect github.com/cockroachdb/redact v1.0.8 // indirect github.com/cockroachdb/sentry-go v0.6.1-cockroachdb.2 // indirect diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 3d55d3f6c14d2..df47167354c2a 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -1653,7 +1653,7 @@ func (p *preprocessor) updateStateFromStaleReadProcessor() error { if err := txnManager.EnterNewTxn(context.TODO(), newTxnRequest); err != nil { return err } - if err := txnManager.OnStmtStart(context.TODO()); err != nil { + if err := txnManager.OnStmtStart(context.TODO(), txnManager.GetCurrentStmt()); err != nil { return err } } diff --git a/planner/funcdep/extract_fd_test.go b/planner/funcdep/extract_fd_test.go index b5bb646cba073..aed58f6dd2957 100644 --- a/planner/funcdep/extract_fd_test.go +++ b/planner/funcdep/extract_fd_test.go @@ -214,7 +214,7 @@ func TestFDSet_ExtractFD(t *testing.T) { for i, tt := range tests { comment := fmt.Sprintf("case:%v sql:%s", i, tt.sql) require.NoError(t, tk.Session().PrepareTxnCtx(context.TODO())) - require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO())) + require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO(), nil)) stmt, err := par.ParseOneStmt(tt.sql, "", "") require.NoError(t, err, comment) tk.Session().GetSessionVars().PlanID = 0 @@ -312,7 +312,7 @@ func TestFDSet_ExtractFDForApply(t *testing.T) { is := testGetIS(t, tk.Session()) for i, tt := range tests { require.NoError(t, tk.Session().PrepareTxnCtx(context.TODO())) - require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO())) + require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO(), nil)) comment := fmt.Sprintf("case:%v sql:%s", i, tt.sql) stmt, err := par.ParseOneStmt(tt.sql, "", "") require.NoError(t, err, comment) diff --git a/session/bench_test.go b/session/bench_test.go index c1164ec32de40..75be9443cf7e6 100644 --- a/session/bench_test.go +++ b/session/bench_test.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/benchdaily" @@ -1812,8 +1813,9 @@ func BenchmarkCompileExecutePreparedStmt(b *testing.B) { is := se.GetInfoSchema() b.ResetTimer() + stmtExec := &ast.ExecuteStmt{ExecID: stmtID} for i := 0; i < b.N; i++ { - _, _, _, err := executor.CompileExecutePreparedStmt(context.Background(), se, stmtID, is.(infoschema.InfoSchema), 0, kv.GlobalTxnScope, args) + _, _, _, err := executor.CompileExecutePreparedStmt(context.Background(), se, stmtExec, is.(infoschema.InfoSchema), 0, kv.GlobalTxnScope, args) if err != nil { b.Fatal(err) } diff --git a/session/session.go b/session/session.go index 05b36262e147e..5d6a50246bce1 100644 --- a/session/session.go +++ b/session/session.go @@ -1111,7 +1111,7 @@ func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { } _, digest := s.sessionVars.StmtCtx.SQLDigest() s.txn.onStmtStart(digest.String()) - if err = sessiontxn.GetTxnManager(s).OnStmtStart(ctx); err == nil { + if err = sessiontxn.GetTxnManager(s).OnStmtStart(ctx, st.GetStmtNode()); err == nil { _, err = st.Exec(ctx) } s.txn.onStmtEnd() @@ -1917,7 +1917,7 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex s.txn.onStmtStart(digest.String()) defer s.txn.onStmtEnd() - if err := s.onTxnManagerStmtStartOrRetry(ctx); err != nil { + if err := s.onTxnManagerStmtStartOrRetry(ctx, stmtNode); err != nil { return nil, err } @@ -1980,11 +1980,11 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex return recordSet, nil } -func (s *session) onTxnManagerStmtStartOrRetry(ctx context.Context) error { +func (s *session) onTxnManagerStmtStartOrRetry(ctx context.Context, node ast.StmtNode) error { if s.sessionVars.RetryInfo.Retrying { return sessiontxn.GetTxnManager(s).OnStmtRetry(ctx) } - return sessiontxn.GetTxnManager(s).OnStmtStart(ctx) + return sessiontxn.GetTxnManager(s).OnStmtStart(ctx, node) } func (s *session) validateStatementReadOnlyInStaleness(stmtNode ast.StmtNode) error { @@ -2201,7 +2201,8 @@ func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields return } - if err = s.onTxnManagerStmtStartOrRetry(ctx); err != nil { + prepareStmt := &ast.PrepareStmt{SQLText: sql} + if err = s.onTxnManagerStmtStartOrRetry(ctx, prepareStmt); err != nil { return } @@ -2222,7 +2223,7 @@ func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields func (s *session) preparedStmtExec(ctx context.Context, is infoschema.InfoSchema, snapshotTS uint64, - stmtID uint32, prepareStmt *plannercore.CachedPrepareStmt, replicaReadScope string, args []types.Datum) (sqlexec.RecordSet, error) { + execStmt *ast.ExecuteStmt, prepareStmt *plannercore.CachedPrepareStmt, replicaReadScope string, args []types.Datum) (sqlexec.RecordSet, error) { failpoint.Inject("assertTxnManagerInPreparedStmtExec", func() { sessiontxn.RecordAssert(s, "assertTxnManagerInPreparedStmtExec", true) @@ -2232,7 +2233,7 @@ func (s *session) preparedStmtExec(ctx context.Context, } }) - st, tiFlashPushDown, tiFlashExchangePushDown, err := executor.CompileExecutePreparedStmt(ctx, s, stmtID, is, snapshotTS, replicaReadScope, args) + st, tiFlashPushDown, tiFlashExchangePushDown, err := executor.CompileExecutePreparedStmt(ctx, s, execStmt, is, snapshotTS, replicaReadScope, args) if err != nil { return nil, err } @@ -2252,14 +2253,9 @@ func (s *session) preparedStmtExec(ctx context.Context, // cachedPointPlanExec is a short path currently ONLY for cached "point select plan" execution func (s *session) cachedPointPlanExec(ctx context.Context, - is infoschema.InfoSchema, stmtID uint32, prepareStmt *plannercore.CachedPrepareStmt, replicaReadScope string, args []types.Datum) (sqlexec.RecordSet, bool, error) { + is infoschema.InfoSchema, execAst *ast.ExecuteStmt, prepareStmt *plannercore.CachedPrepareStmt, replicaReadScope string, args []types.Datum) (sqlexec.RecordSet, bool, error) { prepared := prepareStmt.PreparedAst - // compile ExecStmt - execAst := &ast.ExecuteStmt{ExecID: stmtID} - if err := executor.ResetContextOfStmt(s, execAst); err != nil { - return nil, false, err - } failpoint.Inject("assertTxnManagerInCachedPlanExec", func() { sessiontxn.RecordAssert(s, "assertTxnManagerInCachedPlanExec", true) @@ -2419,12 +2415,17 @@ func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args [ s.txn.onStmtStart(preparedStmt.SQLDigest.String()) defer s.txn.onStmtEnd() - if err = s.onTxnManagerStmtStartOrRetry(ctx); err != nil { + execStmt := &ast.ExecuteStmt{ExecID: stmtID} + if err := executor.ResetContextOfStmt(s, execStmt); err != nil { + return nil, err + } + + if err = s.onTxnManagerStmtStartOrRetry(ctx, execStmt); err != nil { return nil, err } if ok { - rs, ok, err := s.cachedPointPlanExec(ctx, txnManager.GetTxnInfoSchema(), stmtID, preparedStmt, replicaReadScope, args) + rs, ok, err := s.cachedPointPlanExec(ctx, txnManager.GetTxnInfoSchema(), execStmt, preparedStmt, replicaReadScope, args) if err != nil { return nil, err } @@ -2432,7 +2433,7 @@ func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args [ return rs, nil } } - return s.preparedStmtExec(ctx, txnManager.GetTxnInfoSchema(), snapshotTS, stmtID, preparedStmt, replicaReadScope, args) + return s.preparedStmtExec(ctx, txnManager.GetTxnInfoSchema(), snapshotTS, execStmt, preparedStmt, replicaReadScope, args) } func (s *session) DropPreparedStmt(stmtID uint32) error { diff --git a/session/txnmanager.go b/session/txnmanager.go index 19d28ae014f70..3d5a049307eb9 100644 --- a/session/txnmanager.go +++ b/session/txnmanager.go @@ -46,6 +46,7 @@ func getTxnManager(sctx sessionctx.Context) sessiontxn.TxnManager { type txnManager struct { sctx sessionctx.Context + stmtNode ast.StmtNode ctxProvider sessiontxn.TxnContextProvider } @@ -116,14 +117,21 @@ func (m *txnManager) EnterNewTxn(ctx context.Context, r *sessiontxn.EnterNewTxnR func (m *txnManager) OnTxnEnd() { m.ctxProvider = nil + m.stmtNode = nil +} + +func (m *txnManager) GetCurrentStmt() ast.StmtNode { + return m.stmtNode } // OnStmtStart is the hook that should be called when a new statement started -func (m *txnManager) OnStmtStart(ctx context.Context) error { +func (m *txnManager) OnStmtStart(ctx context.Context, node ast.StmtNode) error { + m.stmtNode = node + if m.ctxProvider == nil { return errors.New("context provider not set") } - return m.ctxProvider.OnStmtStart(ctx) + return m.ctxProvider.OnStmtStart(ctx, m.stmtNode) } // OnStmtErrorForNextAction is the hook that should be called when a new statement get an error diff --git a/sessiontxn/interface.go b/sessiontxn/interface.go index d1febc88c8a48..ad41877e7439d 100644 --- a/sessiontxn/interface.go +++ b/sessiontxn/interface.go @@ -124,7 +124,7 @@ type TxnContextProvider interface { // OnInitialize is the hook that should be called when enter a new txn with this provider OnInitialize(ctx context.Context, enterNewTxnType EnterNewTxnType) error // OnStmtStart is the hook that should be called when a new statement started - OnStmtStart(ctx context.Context) error + OnStmtStart(ctx context.Context, node ast.StmtNode) error // OnStmtErrorForNextAction is the hook that should be called when a new statement get an error OnStmtErrorForNextAction(point StmtErrorHandlePoint, err error) (StmtErrorAction, error) // OnStmtRetry is the hook that should be called when a statement is retried internally. @@ -150,7 +150,7 @@ type TxnManager interface { // OnTxnEnd is the hook that should be called after transaction commit or rollback OnTxnEnd() // OnStmtStart is the hook that should be called when a new statement started - OnStmtStart(ctx context.Context) error + OnStmtStart(ctx context.Context, node ast.StmtNode) error // OnStmtErrorForNextAction is the hook that should be called when a new statement get an error // This method is not required to be called for every error in the statement, // it is only required to be called for some errors handled in some specified points given by the parameter `point`. @@ -158,6 +158,8 @@ type TxnManager interface { OnStmtErrorForNextAction(point StmtErrorHandlePoint, err error) (StmtErrorAction, error) // OnStmtRetry is the hook that should be called when a statement retry OnStmtRetry(ctx context.Context) error + // GetCurrentStmt returns the current statement node + GetCurrentStmt() ast.StmtNode } // NewTxn starts a new optimistic and active txn, it can be used for the below scenes: @@ -178,7 +180,8 @@ func NewTxnInStmt(ctx context.Context, sctx sessionctx.Context) error { if err := NewTxn(ctx, sctx); err != nil { return err } - return GetTxnManager(sctx).OnStmtStart(ctx) + txnManager := GetTxnManager(sctx) + return txnManager.OnStmtStart(ctx, txnManager.GetCurrentStmt()) } // GetTxnManager returns the TxnManager object from session context diff --git a/sessiontxn/isolation/base.go b/sessiontxn/isolation/base.go index d5c3bcbca7ab1..92e691d5a94df 100644 --- a/sessiontxn/isolation/base.go +++ b/sessiontxn/isolation/base.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/sessiontxn" @@ -138,7 +139,7 @@ func (p *baseTxnContextProvider) GetStmtForUpdateTS() (uint64, error) { return p.getStmtForUpdateTSFunc() } -func (p *baseTxnContextProvider) OnStmtStart(ctx context.Context) error { +func (p *baseTxnContextProvider) OnStmtStart(ctx context.Context, _ ast.StmtNode) error { p.ctx = ctx return nil } diff --git a/sessiontxn/isolation/optimistic_test.go b/sessiontxn/isolation/optimistic_test.go index 25390dca1aa19..d85a53149a752 100644 --- a/sessiontxn/isolation/optimistic_test.go +++ b/sessiontxn/isolation/optimistic_test.go @@ -49,7 +49,7 @@ func TestOptimisticTxnContextProviderTS(t *testing.T) { se := tk.Session() compareTS := getOracleTS(t, se) provider := initializeOptimisticProvider(t, tk, true) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) readTS, err := provider.GetStmtReadTS() require.NoError(t, err) updateTS, err := provider.GetStmtForUpdateTS() @@ -59,7 +59,7 @@ func TestOptimisticTxnContextProviderTS(t *testing.T) { compareTS = readTS // for optimistic mode ts, ts should be the same for all statements - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) readTS, err = provider.GetStmtReadTS() require.NoError(t, err) updateTS, err = provider.GetStmtForUpdateTS() @@ -72,7 +72,7 @@ func TestOptimisticTxnContextProviderTS(t *testing.T) { require.NoError(t, err) stmt := stmts[0] provider = initializeOptimisticProvider(t, tk, false) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), stmt)) plan, _, err := planner.Optimize(context.TODO(), tk.Session(), stmt, provider.GetTxnInfoSchema()) require.NoError(t, err) require.NoError(t, provider.AdviseOptimizeWithPlan(plan)) @@ -85,7 +85,7 @@ func TestOptimisticTxnContextProviderTS(t *testing.T) { // if the oracle future is prepared fist, `math.MaxUint64` should still be used after plan provider = initializeOptimisticProvider(t, tk, false) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), stmt)) require.NoError(t, provider.AdviseWarmup()) plan, _, err = planner.Optimize(context.TODO(), tk.Session(), stmt, provider.GetTxnInfoSchema()) require.NoError(t, err) @@ -100,7 +100,7 @@ func TestOptimisticTxnContextProviderTS(t *testing.T) { // when it is in explicit txn, we should not use `math.MaxUint64` compareTS = getOracleTS(t, se) provider = initializeOptimisticProvider(t, tk, true) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), stmt)) plan, _, err = planner.Optimize(context.TODO(), tk.Session(), stmt, provider.GetTxnInfoSchema()) require.NoError(t, err) require.NoError(t, provider.AdviseOptimizeWithPlan(plan)) @@ -115,7 +115,7 @@ func TestOptimisticTxnContextProviderTS(t *testing.T) { tk.MustExec("set @@autocommit=0") compareTS = getOracleTS(t, se) provider = initializeOptimisticProvider(t, tk, false) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), stmt)) plan, _, err = planner.Optimize(context.TODO(), tk.Session(), stmt, provider.GetTxnInfoSchema()) require.NoError(t, err) require.NoError(t, provider.AdviseOptimizeWithPlan(plan)) @@ -175,7 +175,7 @@ func TestOptimisticHandleError(t *testing.T) { } for _, c := range cases { - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) action, err := provider.OnStmtErrorForNextAction(c.point, c.err) if c.point == sessiontxn.StmtErrAfterPessimisticLock { require.Error(t, err) @@ -183,7 +183,7 @@ func TestOptimisticHandleError(t *testing.T) { require.Equal(t, sessiontxn.StmtActionError, action) // next statement should not update ts - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkTS() } else { require.NoError(t, err) @@ -194,13 +194,13 @@ func TestOptimisticHandleError(t *testing.T) { checkTS() // OnStmtErrorForNextAction again - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) action, err = provider.OnStmtErrorForNextAction(c.point, c.err) require.NoError(t, err) require.Equal(t, sessiontxn.StmtActionNoIdea, action) // next statement should not update ts - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkTS() } } @@ -280,7 +280,7 @@ func TestOptimisticProviderInitialize(t *testing.T) { assertAfterActive.couldRetry = c.autocommit || !c.disableTxnAutoRetry require.NoError(t, se.PrepareTxnCtx(context.TODO())) provider := assert.CheckAndGetProvider(t) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) ts, err := provider.GetStmtReadTS() require.NoError(t, err) assertAfterActive.Check(t) @@ -343,12 +343,12 @@ func TestTidbSnapshotVarInOptimisticTxn(t *testing.T) { } // information schema and ts should equal to snapshot when tidb_snapshot is set - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkUseSnapshot() // information schema and ts will restore when set tidb_snapshot to empty tk.MustExec("set @@tidb_snapshot=''") - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkUseTxn() // txn will not be active after `GetStmtReadTS` or `GetStmtForUpdateTS` when `tidb_snapshot` is set @@ -368,7 +368,7 @@ func TestTidbSnapshotVarInOptimisticTxn(t *testing.T) { assertAfterUseSnapshot := activeSnapshotTxnAssert(se, se.GetSessionVars().SnapshotTS, "") require.NoError(t, se.PrepareTxnCtx(context.TODO())) provider = assert.CheckAndGetProvider(t) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkUseSnapshot() assertAfterUseSnapshot.Check(t) }() diff --git a/sessiontxn/isolation/readcommitted.go b/sessiontxn/isolation/readcommitted.go index 985b6a04800f3..9918f8b226e5b 100644 --- a/sessiontxn/isolation/readcommitted.go +++ b/sessiontxn/isolation/readcommitted.go @@ -77,13 +77,30 @@ func NewPessimisticRCTxnContextProvider(sctx sessionctx.Context, causalConsisten } // OnStmtStart is the hook that should be called when a new statement started -func (p *PessimisticRCTxnContextProvider) OnStmtStart(ctx context.Context) error { - if err := p.baseTxnContextProvider.OnStmtStart(ctx); err != nil { +func (p *PessimisticRCTxnContextProvider) OnStmtStart(ctx context.Context, node ast.StmtNode) error { + if err := p.baseTxnContextProvider.OnStmtStart(ctx, node); err != nil { return err } + + // Try to mark the `RCCheckTS` flag for the first time execution of in-transaction read requests + // using read-consistency isolation level. + if node != nil && NeedSetRCCheckTSFlag(p.sctx, node) { + p.sctx.GetSessionVars().StmtCtx.RCCheckTS = true + } + return p.prepareStmt(!p.isTxnPrepared) } +// NeedSetRCCheckTSFlag checks whether it's needed to set `RCCheckTS` flag in current stmtctx. +func NeedSetRCCheckTSFlag(ctx sessionctx.Context, node ast.Node) bool { + sessionVars := ctx.GetSessionVars() + if sessionVars.ConnectionID > 0 && sessionVars.RcReadCheckTS && sessionVars.InTxn() && + !sessionVars.RetryInfo.Retrying && plannercore.IsReadOnly(node, sessionVars) { + return true + } + return false +} + // OnStmtErrorForNextAction is the hook that should be called when a new statement get an error func (p *PessimisticRCTxnContextProvider) OnStmtErrorForNextAction(point sessiontxn.StmtErrorHandlePoint, err error) (sessiontxn.StmtErrorAction, error) { switch point { diff --git a/sessiontxn/isolation/readcommitted_test.go b/sessiontxn/isolation/readcommitted_test.go index 5c747eba4fa2c..a01066f3588b1 100644 --- a/sessiontxn/isolation/readcommitted_test.go +++ b/sessiontxn/isolation/readcommitted_test.go @@ -29,10 +29,12 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessiontxn" "github.com/pingcap/tidb/sessiontxn/isolation" "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/types" "github.com/stretchr/testify/require" tikverr "github.com/tikv/client-go/v2/error" ) @@ -55,26 +57,26 @@ func TestPessimisticRCTxnContextProviderRCCheck(t *testing.T) { forUpdateStmt := stmts[0] compareTS := se.GetSessionVars().TxnCtx.StartTS - // first ts should request from tso + // first ts should use the txn startTS require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), readOnlyStmt)) ts, err := provider.GetStmtReadTS() require.NoError(t, err) require.Equal(t, ts, compareTS) rcCheckTS := ts - // second ts should reuse first ts + // second ts should reuse the txn startTS require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), readOnlyStmt)) ts, err = provider.GetStmtReadTS() require.NoError(t, err) require.Equal(t, rcCheckTS, ts) // when one statement did not getStmtReadTS, the next one should still reuse the first ts require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), readOnlyStmt)) require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), readOnlyStmt)) ts, err = provider.GetStmtReadTS() require.NoError(t, err) require.Equal(t, rcCheckTS, ts) @@ -93,7 +95,7 @@ func TestPessimisticRCTxnContextProviderRCCheck(t *testing.T) { // if retry succeed next statement will still use rc check require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), readOnlyStmt)) ts, err = provider.GetStmtReadTS() require.NoError(t, err) require.Equal(t, rcCheckTS, ts) @@ -103,14 +105,14 @@ func TestPessimisticRCTxnContextProviderRCCheck(t *testing.T) { require.NoError(t, err) require.Equal(t, sessiontxn.StmtActionNoIdea, nextAction) require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), readOnlyStmt)) ts, err = provider.GetStmtReadTS() require.NoError(t, err) require.Equal(t, rcCheckTS, ts) // `StmtErrAfterPessimisticLock` will still disable rc check require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), readOnlyStmt)) ts, err = provider.GetStmtReadTS() require.NoError(t, err) require.Equal(t, rcCheckTS, ts) @@ -128,7 +130,7 @@ func TestPessimisticRCTxnContextProviderRCCheck(t *testing.T) { // only read-only stmt can retry for rc check require.NoError(t, executor.ResetContextOfStmt(se, forUpdateStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), forUpdateStmt)) ts, err = provider.GetStmtReadTS() require.NoError(t, err) require.Greater(t, ts, compareTS) @@ -137,6 +139,60 @@ func TestPessimisticRCTxnContextProviderRCCheck(t *testing.T) { require.Equal(t, sessiontxn.StmtActionNoIdea, nextAction) } +func TestPessimisticRCTxnContextProviderRCCheckForPrepareExecute(t *testing.T) { + store, _, clean := testkit.CreateMockStoreAndDomain(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk2 := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk2.MustExec("use test") + tk.MustExec("create table t (id int primary key, v int)") + tk2.MustExec("insert into t values(1, 1)") + + tk.MustExec("set @@tidb_rc_read_check_ts=1") + se := tk.Session() + ctx := context.Background() + provider := initializePessimisticRCProvider(t, tk) + txnStartTS := se.GetSessionVars().TxnCtx.StartTS + + // first ts should use the txn startTS + stmt, _, _, err := tk.Session().PrepareStmt("select * from t") + require.NoError(t, err) + rs, err := tk.Session().ExecutePreparedStmt(ctx, stmt, []types.Datum{}) + tk.ResultSetToResult(rs, fmt.Sprintf("%v", rs)).Check(testkit.Rows("1 1")) + require.NoError(t, err) + ts, err := provider.GetStmtForUpdateTS() + require.NoError(t, err) + require.Equal(t, txnStartTS, ts) + + // second ts should reuse the txn startTS + rs, err = tk.Session().ExecutePreparedStmt(ctx, stmt, []types.Datum{}) + tk.ResultSetToResult(rs, fmt.Sprintf("%v", rs)).Check(testkit.Rows("1 1")) + require.NoError(t, err) + ts, err = provider.GetStmtForUpdateTS() + require.NoError(t, err) + require.Equal(t, txnStartTS, ts) + + tk2.MustExec("update t set v = v + 10 where id = 1") + compareTS := getOracleTS(t, se) + rs, err = tk.Session().ExecutePreparedStmt(ctx, stmt, []types.Datum{}) + require.NoError(t, err) + _, err = session.ResultSetToStringSlice(ctx, tk.Session(), rs) + require.Error(t, err) + ts, err = provider.GetStmtForUpdateTS() + require.NoError(t, err) + require.Greater(t, compareTS, ts) + // retry + tk.Session().GetSessionVars().RetryInfo.Retrying = true + rs, err = tk.Session().ExecutePreparedStmt(ctx, stmt, []types.Datum{}) + require.NoError(t, err) + tk.ResultSetToResult(rs, fmt.Sprintf("%v", rs)).Check(testkit.Rows("1 11")) + ts, err = provider.GetStmtForUpdateTS() + require.NoError(t, err) + require.Greater(t, ts, compareTS) +} + func TestPessimisticRCTxnContextProviderLockError(t *testing.T) { store, _, clean := testkit.CreateMockStoreAndDomain(t) defer clean() @@ -155,7 +211,7 @@ func TestPessimisticRCTxnContextProviderLockError(t *testing.T) { &tikverr.ErrDeadlock{Deadlock: &kvrpcpb.Deadlock{}, IsRetryable: true}, } { require.NoError(t, executor.ResetContextOfStmt(se, stmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), stmt)) nextAction, err := provider.OnStmtErrorForNextAction(sessiontxn.StmtErrAfterPessimisticLock, lockErr) require.NoError(t, err) require.Equal(t, sessiontxn.StmtActionRetryReady, nextAction) @@ -167,7 +223,7 @@ func TestPessimisticRCTxnContextProviderLockError(t *testing.T) { errors.New("err"), } { require.NoError(t, executor.ResetContextOfStmt(se, stmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), stmt)) nextAction, err := provider.OnStmtErrorForNextAction(sessiontxn.StmtErrAfterPessimisticLock, lockErr) require.Same(t, lockErr, err) require.Equal(t, sessiontxn.StmtActionError, nextAction) @@ -189,7 +245,7 @@ func TestPessimisticRCTxnContextProviderTS(t *testing.T) { // first read require.NoError(t, executor.ResetContextOfStmt(se, stmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), stmt)) readTS, err := provider.GetStmtReadTS() require.NoError(t, err) forUpdateTS, err := provider.GetStmtForUpdateTS() @@ -202,7 +258,7 @@ func TestPessimisticRCTxnContextProviderTS(t *testing.T) { compareTS = getOracleTS(t, se) require.Greater(t, compareTS, readTS) require.NoError(t, executor.ResetContextOfStmt(se, stmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), stmt)) readTS, err = provider.GetStmtReadTS() require.NoError(t, err) forUpdateTS, err = provider.GetStmtForUpdateTS() @@ -267,7 +323,7 @@ func TestRCProviderInitialize(t *testing.T) { assertAfterActive := activeRCTxnAssert(t, se, true) require.NoError(t, se.PrepareTxnCtx(context.TODO())) provider := assert.CheckAndGetProvider(t) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) ts, err := provider.GetStmtReadTS() require.NoError(t, err) assertAfterActive.Check(t) @@ -280,7 +336,7 @@ func TestRCProviderInitialize(t *testing.T) { assertAfterActive = activeRCTxnAssert(t, se, true) require.NoError(t, se.PrepareTxnCtx(context.TODO())) provider = assert.CheckAndGetProvider(t) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) ts, err = provider.GetStmtReadTS() require.NoError(t, err) assertAfterActive.Check(t) @@ -346,12 +402,12 @@ func TestTidbSnapshotVarInRC(t *testing.T) { } // information schema and ts should equal to snapshot when tidb_snapshot is set - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkUseSnapshot() // information schema and ts will restore when set tidb_snapshot to empty tk.MustExec("set @@tidb_snapshot=''") - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkUseTxn(false) // txn will not be active after `GetStmtReadTS` or `GetStmtForUpdateTS` when `tidb_snapshot` is set @@ -372,7 +428,7 @@ func TestTidbSnapshotVarInRC(t *testing.T) { assertAfterUseSnapshot := activeSnapshotTxnAssert(se, se.GetSessionVars().SnapshotTS, "READ-COMMITTED") require.NoError(t, se.PrepareTxnCtx(context.TODO())) provider = assert.CheckAndGetProvider(t) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkUseSnapshot() assertAfterUseSnapshot.Check(t) }() diff --git a/sessiontxn/isolation/repeatable_read.go b/sessiontxn/isolation/repeatable_read.go index 571d2754be9a3..4827446d27a6c 100644 --- a/sessiontxn/isolation/repeatable_read.go +++ b/sessiontxn/isolation/repeatable_read.go @@ -122,8 +122,8 @@ func (p *PessimisticRRTxnContextProvider) updateForUpdateTS() (err error) { } // OnStmtStart is the hook that should be called when a new statement started -func (p *PessimisticRRTxnContextProvider) OnStmtStart(ctx context.Context) error { - if err := p.baseTxnContextProvider.OnStmtStart(ctx); err != nil { +func (p *PessimisticRRTxnContextProvider) OnStmtStart(ctx context.Context, node ast.StmtNode) error { + if err := p.baseTxnContextProvider.OnStmtStart(ctx, node); err != nil { return err } diff --git a/sessiontxn/isolation/repeatable_read_test.go b/sessiontxn/isolation/repeatable_read_test.go index c60c1c3da560d..c1487a1bb0ae7 100644 --- a/sessiontxn/isolation/repeatable_read_test.go +++ b/sessiontxn/isolation/repeatable_read_test.go @@ -74,7 +74,7 @@ func TestPessimisticRRErrorHandle(t *testing.T) { nextAction, err = provider.OnStmtErrorForNextAction(sessiontxn.StmtErrAfterPessimisticLock, lockErr) require.NoError(t, err) require.Equal(t, sessiontxn.StmtActionRetryReady, nextAction) - err = provider.OnStmtStart(context.TODO()) + err = provider.OnStmtStart(context.TODO(), nil) // Unlike StmtRetry which uses forUpdateTS got in OnStmtErrorForNextAction, OnStmtStart will reset provider's forUpdateTS, // which leads GetStmtForUpdateTS to acquire the latest ts. compareTS2 = getOracleTS(t, se) @@ -111,7 +111,7 @@ func TestPessimisticRRErrorHandle(t *testing.T) { nextAction, err = provider.OnStmtErrorForNextAction(sessiontxn.StmtErrAfterPessimisticLock, lockErr) require.NoError(t, err) require.Equal(t, sessiontxn.StmtActionRetryReady, nextAction) - err = provider.OnStmtStart(context.TODO()) + err = provider.OnStmtStart(context.TODO(), nil) require.NoError(t, err) // Unlike StmtRetry which uses forUpdateTS got in OnStmtErrorForNextAction, OnStmtStart will reset provider's forUpdateTS, // which leads GetStmtForUpdateTS to acquire the latest ts. @@ -153,7 +153,7 @@ func TestRepeatableReadProviderTS(t *testing.T) { compareTS := getOracleTS(t, se) // The read ts should be less than the compareTS require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) CurrentTS, err = provider.GetStmtReadTS() require.NoError(t, err) require.Greater(t, compareTS, CurrentTS) @@ -161,7 +161,7 @@ func TestRepeatableReadProviderTS(t *testing.T) { // The read ts should also be less than the compareTS in a new statement (after calling OnStmtStart) require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) CurrentTS, err = provider.GetStmtReadTS() require.NoError(t, err) require.Equal(t, CurrentTS, prevTS) @@ -175,14 +175,14 @@ func TestRepeatableReadProviderTS(t *testing.T) { // The for update read ts should be larger than the compareTS require.NoError(t, executor.ResetContextOfStmt(se, forUpdateStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) forUpdateTS, err := provider.GetStmtForUpdateTS() require.NoError(t, err) require.Greater(t, forUpdateTS, compareTS) // But the read ts is still less than the compareTS require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) CurrentTS, err = provider.GetStmtReadTS() require.NoError(t, err) require.Equal(t, CurrentTS, prevTS) @@ -228,7 +228,7 @@ func TestRepeatableReadProviderInitialize(t *testing.T) { assertAfterActive := activePessimisticRRAssert(t, se, true) require.NoError(t, se.PrepareTxnCtx(context.TODO())) provider := assert.CheckAndGetProvider(t) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) ts, err := provider.GetStmtReadTS() require.NoError(t, err) assertAfterActive.Check(t) @@ -241,7 +241,7 @@ func TestRepeatableReadProviderInitialize(t *testing.T) { assertAfterActive = activePessimisticRRAssert(t, se, true) require.NoError(t, se.PrepareTxnCtx(context.TODO())) provider = assert.CheckAndGetProvider(t) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) ts, err = provider.GetStmtReadTS() require.NoError(t, err) assertAfterActive.Check(t) @@ -303,12 +303,12 @@ func TestTidbSnapshotVarInPessimisticRepeatableRead(t *testing.T) { } // information schema and ts should equal to snapshot when tidb_snapshot is set - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkUseSnapshot() // information schema and ts will restore when set tidb_snapshot to empty tk.MustExec("set @@tidb_snapshot=''") - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkUseTxn() // txn will not be active after `GetStmtReadTS` or `GetStmtForUpdateTS` when `tidb_snapshot` is set @@ -329,7 +329,7 @@ func TestTidbSnapshotVarInPessimisticRepeatableRead(t *testing.T) { assertAfterUseSnapshot := activeSnapshotTxnAssert(se, se.GetSessionVars().SnapshotTS, "REPEATABLE-READ") require.NoError(t, se.PrepareTxnCtx(context.TODO())) provider = assert.CheckAndGetProvider(t) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkUseSnapshot() assertAfterUseSnapshot.Check(t) }() @@ -390,11 +390,11 @@ func TestOptimizeWithPlanInPessimisticRR(t *testing.T) { for _, c := range cases { compareTS = getOracleTS(t, se) - require.NoError(t, txnManager.OnStmtStart(context.TODO())) + require.NoError(t, txnManager.OnStmtStart(context.TODO(), nil)) stmt, err = parser.New().ParseOneStmt(c.sql, "", "") require.NoError(t, err) - err = provider.OnStmtStart(context.TODO()) + err = provider.OnStmtStart(context.TODO(), nil) require.NoError(t, err) compiler = executor.Compiler{Ctx: se} @@ -432,9 +432,9 @@ func TestOptimizeWithPlanInPessimisticRR(t *testing.T) { // Test use startTS after optimize when autocommit=0 activeAssert := activePessimisticRRAssert(t, tk.Session(), true) provider = initializeRepeatableReadProvider(t, tk, false) - require.NoError(t, txnManager.OnStmtStart(context.TODO())) stmt, err = parser.New().ParseOneStmt("update t set v = v + 10 where id = 1", "", "") require.NoError(t, err) + require.NoError(t, txnManager.OnStmtStart(context.TODO(), stmt)) execStmt, err = compiler.Compile(context.TODO(), stmt) require.NoError(t, err) err = txnManager.AdviseOptimizeWithPlan(execStmt.Plan) @@ -448,9 +448,9 @@ func TestOptimizeWithPlanInPessimisticRR(t *testing.T) { compareTS = getOracleTS(t, se) activeAssert = activePessimisticRRAssert(t, tk.Session(), true) provider = initializeRepeatableReadProvider(t, tk, false) - require.NoError(t, txnManager.OnStmtStart(context.TODO())) stmt, err = parser.New().ParseOneStmt("select * from t", "", "") require.NoError(t, err) + require.NoError(t, txnManager.OnStmtStart(context.TODO(), stmt)) execStmt, err = compiler.Compile(context.TODO(), stmt) require.NoError(t, err) err = txnManager.AdviseOptimizeWithPlan(execStmt.Plan) diff --git a/sessiontxn/isolation/serializable_test.go b/sessiontxn/isolation/serializable_test.go index a28e455195cbf..f192adf909369 100644 --- a/sessiontxn/isolation/serializable_test.go +++ b/sessiontxn/isolation/serializable_test.go @@ -54,7 +54,7 @@ func TestPessimisticSerializableTxnProviderTS(t *testing.T) { compareTS := getOracleTS(t, se) require.NoError(t, executor.ResetContextOfStmt(se, readOnlyStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) ts, err := provider.GetStmtReadTS() require.NoError(t, err) require.Greater(t, compareTS, ts) @@ -62,7 +62,7 @@ func TestPessimisticSerializableTxnProviderTS(t *testing.T) { // In Oracle-like serializable isolation, readTS equals to the for update ts require.NoError(t, executor.ResetContextOfStmt(se, forUpdateStmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) ts, err = provider.GetStmtForUpdateTS() require.NoError(t, err) require.Greater(t, compareTS, ts) @@ -87,7 +87,7 @@ func TestPessimisticSerializableTxnContextProviderLockError(t *testing.T) { &tikverr.ErrDeadlock{Deadlock: &kvrpcpb.Deadlock{}, IsRetryable: true}, } { require.NoError(t, executor.ResetContextOfStmt(se, stmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) nextAction, err := provider.OnStmtErrorForNextAction(sessiontxn.StmtErrAfterPessimisticLock, lockErr) require.Same(t, lockErr, err) require.Equal(t, sessiontxn.StmtActionError, nextAction) @@ -99,7 +99,7 @@ func TestPessimisticSerializableTxnContextProviderLockError(t *testing.T) { errors.New("err"), } { require.NoError(t, executor.ResetContextOfStmt(se, stmt)) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) nextAction, err := provider.OnStmtErrorForNextAction(sessiontxn.StmtErrAfterPessimisticLock, lockErr) require.Same(t, lockErr, err) require.Equal(t, sessiontxn.StmtActionError, nextAction) @@ -147,7 +147,7 @@ func TestSerializableInitialize(t *testing.T) { assertAfterActive := activeSerializableAssert(t, se, true) require.NoError(t, se.PrepareTxnCtx(context.TODO())) provider := assert.CheckAndGetProvider(t) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) ts, err := provider.GetStmtReadTS() require.NoError(t, err) assertAfterActive.Check(t) @@ -160,7 +160,7 @@ func TestSerializableInitialize(t *testing.T) { assertAfterActive = activeSerializableAssert(t, se, true) require.NoError(t, se.PrepareTxnCtx(context.TODO())) provider = assert.CheckAndGetProvider(t) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) ts, err = provider.GetStmtReadTS() require.NoError(t, err) assertAfterActive.Check(t) @@ -223,12 +223,12 @@ func TestTidbSnapshotVarInSerialize(t *testing.T) { } // information schema and ts should equal to snapshot when tidb_snapshot is set - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkUseSnapshot() // information schema and ts will restore when set tidb_snapshot to empty tk.MustExec("set @@tidb_snapshot=''") - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkUseTxn() // txn will not be active after `GetStmtReadTS` or `GetStmtForUpdateTS` when `tidb_snapshot` is set @@ -250,7 +250,7 @@ func TestTidbSnapshotVarInSerialize(t *testing.T) { assertAfterUseSnapshot := activeSnapshotTxnAssert(se, se.GetSessionVars().SnapshotTS, "SERIALIZABLE") require.NoError(t, se.PrepareTxnCtx(context.TODO())) provider = assert.CheckAndGetProvider(t) - require.NoError(t, provider.OnStmtStart(context.TODO())) + require.NoError(t, provider.OnStmtStart(context.TODO(), nil)) checkUseSnapshot() assertAfterUseSnapshot.Check(t) }() diff --git a/sessiontxn/staleread/provider.go b/sessiontxn/staleread/provider.go index cc77cdd214b37..f76f500ed31c7 100644 --- a/sessiontxn/staleread/provider.go +++ b/sessiontxn/staleread/provider.go @@ -19,6 +19,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/sessiontxn" @@ -86,7 +87,7 @@ func (p *StalenessTxnContextProvider) OnInitialize(ctx context.Context, tp sessi } // OnStmtStart is the hook that should be called when a new statement started -func (p *StalenessTxnContextProvider) OnStmtStart(_ context.Context) error { +func (p *StalenessTxnContextProvider) OnStmtStart(_ context.Context, _ ast.StmtNode) error { return nil } diff --git a/sessiontxn/txn_manager_test.go b/sessiontxn/txn_manager_test.go index e32f8bc2b3784..7fe8024b69de5 100644 --- a/sessiontxn/txn_manager_test.go +++ b/sessiontxn/txn_manager_test.go @@ -135,7 +135,7 @@ func TestEnterNewTxn(t *testing.T) { Type: sessiontxn.EnterNewTxnBeforeStmt, }) require.NoError(t, err) - require.NoError(t, mgr.OnStmtStart(context.TODO())) + require.NoError(t, mgr.OnStmtStart(context.TODO(), nil)) require.NoError(t, mgr.AdviseWarmup()) }, request: &sessiontxn.EnterNewTxnRequest{ diff --git a/tests/realtikvtest/sessiontest/temporary_table_test.go b/tests/realtikvtest/sessiontest/temporary_table_test.go index 6eb2ceddb5d3c..796f51cd68287 100644 --- a/tests/realtikvtest/sessiontest/temporary_table_test.go +++ b/tests/realtikvtest/sessiontest/temporary_table_test.go @@ -321,7 +321,7 @@ func TestTemporaryTableInterceptor(t *testing.T) { for _, initFunc := range initTxnFuncs { require.NoError(t, initFunc()) - require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO())) + require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO(), nil)) txn, err := tk.Session().Txn(true) require.NoError(t, err) diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 5a0d39361f1ab..2db32927705f9 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -173,6 +173,9 @@ type Statement interface { // RebuildPlan rebuilds the plan of the statement. RebuildPlan(ctx context.Context) (schemaVersion int64, err error) + + // GetStmtNode returns the stmtNode inside Statement + GetStmtNode() ast.StmtNode } // RecordSet is an abstract result set interface to help get data from Plan.