Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

txn: manage the initialization of RCCheckTS by transaction context provider #35554

Merged
merged 24 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions executor/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ type ExecStmt struct {
Ti *TelemetryInfo
}

// GetStmtNode returns the stmtNode inside Statement
func (a ExecStmt) GetStmtNode() ast.StmtNode {
return a.StmtNode
}

// PointGet short path for point exec directly from plan, keep only necessary steps
func (a *ExecStmt) PointGet(ctx context.Context, is infoschema.InfoSchema) (*recordSet, error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
Expand Down
15 changes: 0 additions & 15 deletions executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1904,11 +1904,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.NotFillCache = !opts.SQLCache
}
sc.WeakConsistency = isWeakConsistencyRead(ctx, stmt)
// Try to mark the `RCCheckTS` flag for the first time execution of in-transaction read requests
// using read-consistency isolation level.
if NeedSetRCCheckTSFlag(ctx, stmt) {
sc.RCCheckTS = true
}
case *ast.SetOprStmt:
sc.InSelectStmt = true
sc.OverflowAsWarning = true
Expand Down Expand Up @@ -2042,13 +2037,3 @@ func isWeakConsistencyRead(ctx sessionctx.Context, node ast.Node) bool {
return sessionVars.ConnectionID > 0 && sessionVars.ReadConsistency.IsWeak() &&
plannercore.IsAutoCommitTxn(ctx) && plannercore.IsReadOnly(node, sessionVars)
}

// NeedSetRCCheckTSFlag checks whether it's needed to set `RCCheckTS` flag in current stmtctx.
func NeedSetRCCheckTSFlag(ctx sessionctx.Context, node ast.Node) bool {
sessionVars := ctx.GetSessionVars()
if sessionVars.ConnectionID > 0 && sessionVars.RcReadCheckTS && sessionVars.InTxn() &&
sessionVars.IsPessimisticReadConsistency() && !sessionVars.RetryInfo.Retrying && plannercore.IsReadOnly(node, sessionVars) {
return true
}
return false
}
2 changes: 1 addition & 1 deletion executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3470,7 +3470,7 @@ func TestUnreasonablyClose(t *testing.T) {
err = sessiontxn.NewTxn(context.Background(), tk.Session())
require.NoError(t, err, comment)

err = sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO())
err = sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO(), stmt)
require.NoError(t, err, comment)

executorBuilder := executor.NewMockExecutorBuilderForTest(tk.Session(), is, nil, oracle.GlobalTxnScope)
Expand Down
5 changes: 1 addition & 4 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,7 @@ func CompileExecutePreparedStmt(ctx context.Context, sctx sessionctx.Context,
defer func() {
sctx.GetSessionVars().DurationCompile = time.Since(startTime)
}()
execStmt := &ast.ExecuteStmt{ExecID: ID}
if err := ResetContextOfStmt(sctx, execStmt); err != nil {
return nil, false, false, err
}
execStmt := sessiontxn.GetTxnManager(sctx).GetCurrentStmt().(*ast.ExecuteStmt)
isStaleness := snapshotTS != 0
sctx.GetSessionVars().StmtCtx.IsStaleness = isStaleness
execStmt.BinaryArgs = args
Expand Down
2 changes: 1 addition & 1 deletion planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -1653,7 +1653,7 @@ func (p *preprocessor) updateStateFromStaleReadProcessor() error {
if err := txnManager.EnterNewTxn(context.TODO(), newTxnRequest); err != nil {
return err
}
if err := txnManager.OnStmtStart(context.TODO()); err != nil {
if err := txnManager.OnStmtStart(context.TODO(), txnManager.GetCurrentStmt()); err != nil {
return err
}
}
Expand Down
4 changes: 2 additions & 2 deletions planner/funcdep/extract_fd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func TestFDSet_ExtractFD(t *testing.T) {
for i, tt := range tests {
comment := fmt.Sprintf("case:%v sql:%s", i, tt.sql)
require.NoError(t, tk.Session().PrepareTxnCtx(context.TODO()))
require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO()))
require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO(), nil))
stmt, err := par.ParseOneStmt(tt.sql, "", "")
require.NoError(t, err, comment)
tk.Session().GetSessionVars().PlanID = 0
Expand Down Expand Up @@ -312,7 +312,7 @@ func TestFDSet_ExtractFDForApply(t *testing.T) {
is := testGetIS(t, tk.Session())
for i, tt := range tests {
require.NoError(t, tk.Session().PrepareTxnCtx(context.TODO()))
require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO()))
require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO(), nil))
comment := fmt.Sprintf("case:%v sql:%s", i, tt.sql)
stmt, err := par.ParseOneStmt(tt.sql, "", "")
require.NoError(t, err, comment)
Expand Down
23 changes: 13 additions & 10 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,7 @@ func (s *session) retry(ctx context.Context, maxCnt uint) (err error) {
}
_, digest := s.sessionVars.StmtCtx.SQLDigest()
s.txn.onStmtStart(digest.String())
if err = sessiontxn.GetTxnManager(s).OnStmtStart(ctx); err == nil {
if err = sessiontxn.GetTxnManager(s).OnStmtStart(ctx, st.GetStmtNode()); err == nil {
_, err = st.Exec(ctx)
}
s.txn.onStmtEnd()
Expand Down Expand Up @@ -1904,7 +1904,7 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex
s.txn.onStmtStart(digest.String())
defer s.txn.onStmtEnd()

if err := s.onTxnManagerStmtStartOrRetry(ctx); err != nil {
if err := s.onTxnManagerStmtStartOrRetry(ctx, stmtNode); err != nil {
return nil, err
}

Expand Down Expand Up @@ -1967,11 +1967,11 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex
return recordSet, nil
}

func (s *session) onTxnManagerStmtStartOrRetry(ctx context.Context) error {
func (s *session) onTxnManagerStmtStartOrRetry(ctx context.Context, node ast.StmtNode) error {
if s.sessionVars.RetryInfo.Retrying {
return sessiontxn.GetTxnManager(s).OnStmtRetry(ctx)
}
return sessiontxn.GetTxnManager(s).OnStmtStart(ctx)
return sessiontxn.GetTxnManager(s).OnStmtStart(ctx, node)
}

func (s *session) validateStatementReadOnlyInStaleness(stmtNode ast.StmtNode) error {
Expand Down Expand Up @@ -2185,7 +2185,8 @@ func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields
return
}

if err = s.onTxnManagerStmtStartOrRetry(ctx); err != nil {
prepareStmt := &ast.PrepareStmt{SQLText: sql}
if err = s.onTxnManagerStmtStartOrRetry(ctx, prepareStmt); err != nil {
return
}

Expand Down Expand Up @@ -2240,10 +2241,7 @@ func (s *session) cachedPointPlanExec(ctx context.Context,

prepared := prepareStmt.PreparedAst
// compile ExecStmt
execAst := &ast.ExecuteStmt{ExecID: stmtID}
if err := executor.ResetContextOfStmt(s, execAst); err != nil {
return nil, false, err
}
execAst := sessiontxn.GetTxnManager(s).GetCurrentStmt().(*ast.ExecuteStmt)

failpoint.Inject("assertTxnManagerInCachedPlanExec", func() {
sessiontxn.RecordAssert(s, "assertTxnManagerInCachedPlanExec", true)
Expand Down Expand Up @@ -2406,7 +2404,12 @@ func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args [
s.txn.onStmtStart(preparedStmt.SQLDigest.String())
defer s.txn.onStmtEnd()

if err = s.onTxnManagerStmtStartOrRetry(ctx); err != nil {
execStmt := &ast.ExecuteStmt{ExecID: stmtID}
if err := executor.ResetContextOfStmt(s, execStmt); err != nil {
return nil, err
}

if err = s.onTxnManagerStmtStartOrRetry(ctx, execStmt); err != nil {
return nil, err
}

Expand Down
12 changes: 10 additions & 2 deletions session/txnmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func getTxnManager(sctx sessionctx.Context) sessiontxn.TxnManager {
type txnManager struct {
sctx sessionctx.Context

stmtNode ast.StmtNode
ctxProvider sessiontxn.TxnContextProvider
}

Expand Down Expand Up @@ -116,14 +117,21 @@ func (m *txnManager) EnterNewTxn(ctx context.Context, r *sessiontxn.EnterNewTxnR

func (m *txnManager) OnTxnEnd() {
m.ctxProvider = nil
m.stmtNode = nil
}

func (m *txnManager) GetCurrentStmt() ast.StmtNode {
return m.stmtNode
}
SpadeA-Tang marked this conversation as resolved.
Show resolved Hide resolved

// OnStmtStart is the hook that should be called when a new statement started
func (m *txnManager) OnStmtStart(ctx context.Context) error {
func (m *txnManager) OnStmtStart(ctx context.Context, node ast.StmtNode) error {
m.stmtNode = node

if m.ctxProvider == nil {
return errors.New("context provider not set")
}
return m.ctxProvider.OnStmtStart(ctx)
return m.ctxProvider.OnStmtStart(ctx, m.stmtNode)
}

// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error
Expand Down
9 changes: 6 additions & 3 deletions sessiontxn/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ type TxnContextProvider interface {
// OnInitialize is the hook that should be called when enter a new txn with this provider
OnInitialize(ctx context.Context, enterNewTxnType EnterNewTxnType) error
// OnStmtStart is the hook that should be called when a new statement started
OnStmtStart(ctx context.Context) error
OnStmtStart(ctx context.Context, node ast.StmtNode) error
// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error
OnStmtErrorForNextAction(point StmtErrorHandlePoint, err error) (StmtErrorAction, error)
// OnStmtRetry is the hook that should be called when a statement is retried internally.
Expand All @@ -150,14 +150,16 @@ type TxnManager interface {
// OnTxnEnd is the hook that should be called after transaction commit or rollback
OnTxnEnd()
// OnStmtStart is the hook that should be called when a new statement started
OnStmtStart(ctx context.Context) error
OnStmtStart(ctx context.Context, node ast.StmtNode) error
// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error
// This method is not required to be called for every error in the statement,
// it is only required to be called for some errors handled in some specified points given by the parameter `point`.
// When the return error is not nil the return action is 'StmtActionError' and vice versa.
OnStmtErrorForNextAction(point StmtErrorHandlePoint, err error) (StmtErrorAction, error)
// OnStmtRetry is the hook that should be called when a statement retry
OnStmtRetry(ctx context.Context) error
// GetCurrentStmt returns the current statement node
GetCurrentStmt() ast.StmtNode
}

// NewTxn starts a new optimistic and active txn, it can be used for the below scenes:
Expand All @@ -178,7 +180,8 @@ func NewTxnInStmt(ctx context.Context, sctx sessionctx.Context) error {
if err := NewTxn(ctx, sctx); err != nil {
return err
}
return GetTxnManager(sctx).OnStmtStart(ctx)
txnManager := GetTxnManager(sctx)
return txnManager.OnStmtStart(ctx, txnManager.GetCurrentStmt())
}

// GetTxnManager returns the TxnManager object from session context
Expand Down
3 changes: 2 additions & 1 deletion sessiontxn/isolation/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/sessiontxn"
Expand Down Expand Up @@ -138,7 +139,7 @@ func (p *baseTxnContextProvider) GetStmtForUpdateTS() (uint64, error) {
return p.getStmtForUpdateTSFunc()
}

func (p *baseTxnContextProvider) OnStmtStart(ctx context.Context) error {
func (p *baseTxnContextProvider) OnStmtStart(ctx context.Context, _ ast.StmtNode) error {
p.ctx = ctx
return nil
}
Expand Down
28 changes: 14 additions & 14 deletions sessiontxn/isolation/optimistic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestOptimisticTxnContextProviderTS(t *testing.T) {
se := tk.Session()
compareTS := getOracleTS(t, se)
provider := initializeOptimisticProvider(t, tk, true)
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), nil))
readTS, err := provider.GetStmtReadTS()
require.NoError(t, err)
updateTS, err := provider.GetStmtForUpdateTS()
Expand All @@ -59,7 +59,7 @@ func TestOptimisticTxnContextProviderTS(t *testing.T) {
compareTS = readTS

// for optimistic mode ts, ts should be the same for all statements
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), nil))
readTS, err = provider.GetStmtReadTS()
require.NoError(t, err)
updateTS, err = provider.GetStmtForUpdateTS()
Expand All @@ -72,7 +72,7 @@ func TestOptimisticTxnContextProviderTS(t *testing.T) {
require.NoError(t, err)
stmt := stmts[0]
provider = initializeOptimisticProvider(t, tk, false)
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), stmt))
plan, _, err := planner.Optimize(context.TODO(), tk.Session(), stmt, provider.GetTxnInfoSchema())
require.NoError(t, err)
require.NoError(t, provider.AdviseOptimizeWithPlan(plan))
Expand All @@ -85,7 +85,7 @@ func TestOptimisticTxnContextProviderTS(t *testing.T) {

// if the oracle future is prepared fist, `math.MaxUint64` should still be used after plan
provider = initializeOptimisticProvider(t, tk, false)
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), stmt))
require.NoError(t, provider.AdviseWarmup())
plan, _, err = planner.Optimize(context.TODO(), tk.Session(), stmt, provider.GetTxnInfoSchema())
require.NoError(t, err)
Expand All @@ -100,7 +100,7 @@ func TestOptimisticTxnContextProviderTS(t *testing.T) {
// when it is in explicit txn, we should not use `math.MaxUint64`
compareTS = getOracleTS(t, se)
provider = initializeOptimisticProvider(t, tk, true)
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), stmt))
plan, _, err = planner.Optimize(context.TODO(), tk.Session(), stmt, provider.GetTxnInfoSchema())
require.NoError(t, err)
require.NoError(t, provider.AdviseOptimizeWithPlan(plan))
Expand All @@ -115,7 +115,7 @@ func TestOptimisticTxnContextProviderTS(t *testing.T) {
tk.MustExec("set @@autocommit=0")
compareTS = getOracleTS(t, se)
provider = initializeOptimisticProvider(t, tk, false)
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), stmt))
plan, _, err = planner.Optimize(context.TODO(), tk.Session(), stmt, provider.GetTxnInfoSchema())
require.NoError(t, err)
require.NoError(t, provider.AdviseOptimizeWithPlan(plan))
Expand Down Expand Up @@ -175,15 +175,15 @@ func TestOptimisticHandleError(t *testing.T) {
}

for _, c := range cases {
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), nil))
action, err := provider.OnStmtErrorForNextAction(c.point, c.err)
if c.point == sessiontxn.StmtErrAfterPessimisticLock {
require.Error(t, err)
require.Same(t, c.err, err)
require.Equal(t, sessiontxn.StmtActionError, action)

// next statement should not update ts
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), nil))
checkTS()
} else {
require.NoError(t, err)
Expand All @@ -194,13 +194,13 @@ func TestOptimisticHandleError(t *testing.T) {
checkTS()

// OnStmtErrorForNextAction again
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), nil))
action, err = provider.OnStmtErrorForNextAction(c.point, c.err)
require.NoError(t, err)
require.Equal(t, sessiontxn.StmtActionNoIdea, action)

// next statement should not update ts
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), nil))
checkTS()
}
}
Expand Down Expand Up @@ -280,7 +280,7 @@ func TestOptimisticProviderInitialize(t *testing.T) {
assertAfterActive.couldRetry = c.autocommit || !c.disableTxnAutoRetry
require.NoError(t, se.PrepareTxnCtx(context.TODO()))
provider := assert.CheckAndGetProvider(t)
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), nil))
ts, err := provider.GetStmtReadTS()
require.NoError(t, err)
assertAfterActive.Check(t)
Expand Down Expand Up @@ -343,12 +343,12 @@ func TestTidbSnapshotVarInOptimisticTxn(t *testing.T) {
}

// information schema and ts should equal to snapshot when tidb_snapshot is set
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), nil))
checkUseSnapshot()

// information schema and ts will restore when set tidb_snapshot to empty
tk.MustExec("set @@tidb_snapshot=''")
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), nil))
checkUseTxn()

// txn will not be active after `GetStmtReadTS` or `GetStmtForUpdateTS` when `tidb_snapshot` is set
Expand All @@ -368,7 +368,7 @@ func TestTidbSnapshotVarInOptimisticTxn(t *testing.T) {
assertAfterUseSnapshot := activeSnapshotTxnAssert(se, se.GetSessionVars().SnapshotTS, "")
require.NoError(t, se.PrepareTxnCtx(context.TODO()))
provider = assert.CheckAndGetProvider(t)
require.NoError(t, provider.OnStmtStart(context.TODO()))
require.NoError(t, provider.OnStmtStart(context.TODO(), nil))
checkUseSnapshot()
assertAfterUseSnapshot.Check(t)
}()
Expand Down
Loading