Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

txn: manage the initialization of RCCheckTS by transaction context provider #35554

Merged
merged 24 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions executor/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ type ExecStmt struct {
Ti *TelemetryInfo
}

// GetStmtNode returns the stmtNode inside Statement
func (a ExecStmt) GetStmtNode() ast.StmtNode {
return a.StmtNode
}

// PointGet short path for point exec directly from plan, keep only necessary steps
func (a *ExecStmt) PointGet(ctx context.Context, is infoschema.InfoSchema) (*recordSet, error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
Expand Down
15 changes: 0 additions & 15 deletions executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1904,11 +1904,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.NotFillCache = !opts.SQLCache
}
sc.WeakConsistency = isWeakConsistencyRead(ctx, stmt)
// Try to mark the `RCCheckTS` flag for the first time execution of in-transaction read requests
// using read-consistency isolation level.
if NeedSetRCCheckTSFlag(ctx, stmt) {
sc.RCCheckTS = true
}
case *ast.SetOprStmt:
sc.InSelectStmt = true
sc.OverflowAsWarning = true
Expand Down Expand Up @@ -2042,13 +2037,3 @@ func isWeakConsistencyRead(ctx sessionctx.Context, node ast.Node) bool {
return sessionVars.ConnectionID > 0 && sessionVars.ReadConsistency.IsWeak() &&
plannercore.IsAutoCommitTxn(ctx) && plannercore.IsReadOnly(node, sessionVars)
}

// NeedSetRCCheckTSFlag checks whether it's needed to set `RCCheckTS` flag in current stmtctx.
func NeedSetRCCheckTSFlag(ctx sessionctx.Context, node ast.Node) bool {
sessionVars := ctx.GetSessionVars()
if sessionVars.ConnectionID > 0 && sessionVars.RcReadCheckTS && sessionVars.InTxn() &&
sessionVars.IsPessimisticReadConsistency() && !sessionVars.RetryInfo.Retrying && plannercore.IsReadOnly(node, sessionVars) {
return true
}
return false
}
2 changes: 1 addition & 1 deletion executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3470,7 +3470,7 @@ func TestUnreasonablyClose(t *testing.T) {
err = sessiontxn.NewTxn(context.Background(), tk.Session())
require.NoError(t, err, comment)

err = sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO())
err = sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO(), stmt)
require.NoError(t, err, comment)

executorBuilder := executor.NewMockExecutorBuilderForTest(tk.Session(), is, nil, oracle.GlobalTxnScope)
Expand Down
8 changes: 2 additions & 6 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,15 +333,11 @@ func (e *DeallocateExec) Next(ctx context.Context, req *chunk.Chunk) error {

// CompileExecutePreparedStmt compiles a session Execute command to a stmt.Statement.
func CompileExecutePreparedStmt(ctx context.Context, sctx sessionctx.Context,
ID uint32, is infoschema.InfoSchema, snapshotTS uint64, replicaReadScope string, args []types.Datum) (*ExecStmt, bool, bool, error) {
execStmt *ast.ExecuteStmt, is infoschema.InfoSchema, snapshotTS uint64, replicaReadScope string, args []types.Datum) (*ExecStmt, bool, bool, error) {
startTime := time.Now()
defer func() {
sctx.GetSessionVars().DurationCompile = time.Since(startTime)
}()
execStmt := &ast.ExecuteStmt{ExecID: ID}
if err := ResetContextOfStmt(sctx, execStmt); err != nil {
return nil, false, false, err
}
isStaleness := snapshotTS != 0
sctx.GetSessionVars().StmtCtx.IsStaleness = isStaleness
execStmt.BinaryArgs = args
Expand Down Expand Up @@ -369,7 +365,7 @@ func CompileExecutePreparedStmt(ctx context.Context, sctx sessionctx.Context,
Ti: &TelemetryInfo{},
ReplicaReadScope: replicaReadScope,
}
if preparedPointer, ok := sctx.GetSessionVars().PreparedStmts[ID]; ok {
if preparedPointer, ok := sctx.GetSessionVars().PreparedStmts[execStmt.ExecID]; ok {
preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt)
if !ok {
return nil, false, false, errors.Errorf("invalid CachedPrepareStmt type")
Expand Down
4 changes: 3 additions & 1 deletion executor/seqtest/prepared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/session"
Expand Down Expand Up @@ -157,8 +158,9 @@ func TestPrepared(t *testing.T) {
require.NoError(t, err)
tk.ResultSetToResult(rs, fmt.Sprintf("%v", rs)).Check(testkit.Rows())

execStmt := &ast.ExecuteStmt{ExecID: stmtID}
// Check that ast.Statement created by executor.CompileExecutePreparedStmt has query text.
stmt, _, _, err := executor.CompileExecutePreparedStmt(context.TODO(), tk.Session(), stmtID,
stmt, _, _, err := executor.CompileExecutePreparedStmt(context.TODO(), tk.Session(), execStmt,
tk.Session().GetInfoSchema().(infoschema.InfoSchema), 0, kv.GlobalReplicaScope, []types.Datum{types.NewDatum(1)})
require.NoError(t, err)
require.Equal(t, query, stmt.OriginText())
Expand Down
2 changes: 1 addition & 1 deletion planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -1653,7 +1653,7 @@ func (p *preprocessor) updateStateFromStaleReadProcessor() error {
if err := txnManager.EnterNewTxn(context.TODO(), newTxnRequest); err != nil {
return err
}
if err := txnManager.OnStmtStart(context.TODO()); err != nil {
if err := txnManager.OnStmtStart(context.TODO(), txnManager.GetCurrentStmt()); err != nil {
return err
}
}
Expand Down
4 changes: 2 additions & 2 deletions planner/funcdep/extract_fd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func TestFDSet_ExtractFD(t *testing.T) {
for i, tt := range tests {
comment := fmt.Sprintf("case:%v sql:%s", i, tt.sql)
require.NoError(t, tk.Session().PrepareTxnCtx(context.TODO()))
require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO()))
require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO(), nil))
stmt, err := par.ParseOneStmt(tt.sql, "", "")
require.NoError(t, err, comment)
tk.Session().GetSessionVars().PlanID = 0
Expand Down Expand Up @@ -312,7 +312,7 @@ func TestFDSet_ExtractFDForApply(t *testing.T) {
is := testGetIS(t, tk.Session())
for i, tt := range tests {
require.NoError(t, tk.Session().PrepareTxnCtx(context.TODO()))
require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO()))
require.NoError(t, sessiontxn.GetTxnManager(tk.Session()).OnStmtStart(context.TODO(), nil))
comment := fmt.Sprintf("case:%v sql:%s", i, tt.sql)
stmt, err := par.ParseOneStmt(tt.sql, "", "")
require.NoError(t, err, comment)
Expand Down
4 changes: 3 additions & 1 deletion session/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/benchdaily"
Expand Down Expand Up @@ -1812,8 +1813,9 @@ func BenchmarkCompileExecutePreparedStmt(b *testing.B) {
is := se.GetInfoSchema()

b.ResetTimer()
stmtExec := &ast.ExecuteStmt{ExecID: stmtID}
for i := 0; i < b.N; i++ {
_, _, _, err := executor.CompileExecutePreparedStmt(context.Background(), se, stmtID, is.(infoschema.InfoSchema), 0, kv.GlobalTxnScope, args)
_, _, _, err := executor.CompileExecutePreparedStmt(context.Background(), se, stmtExec, is.(infoschema.InfoSchema), 0, kv.GlobalTxnScope, args)
if err != nil {
b.Fatal(err)
}
Expand Down
33 changes: 17 additions & 16 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ func (s *session) retry(ctx context.Context, maxCnt uint) (err error) {
}
_, digest := s.sessionVars.StmtCtx.SQLDigest()
s.txn.onStmtStart(digest.String())
if err = sessiontxn.GetTxnManager(s).OnStmtStart(ctx); err == nil {
if err = sessiontxn.GetTxnManager(s).OnStmtStart(ctx, st.GetStmtNode()); err == nil {
_, err = st.Exec(ctx)
}
s.txn.onStmtEnd()
Expand Down Expand Up @@ -1918,7 +1918,7 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex
s.txn.onStmtStart(digest.String())
defer s.txn.onStmtEnd()

if err := s.onTxnManagerStmtStartOrRetry(ctx); err != nil {
if err := s.onTxnManagerStmtStartOrRetry(ctx, stmtNode); err != nil {
return nil, err
}

Expand Down Expand Up @@ -1981,11 +1981,11 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex
return recordSet, nil
}

func (s *session) onTxnManagerStmtStartOrRetry(ctx context.Context) error {
func (s *session) onTxnManagerStmtStartOrRetry(ctx context.Context, node ast.StmtNode) error {
if s.sessionVars.RetryInfo.Retrying {
return sessiontxn.GetTxnManager(s).OnStmtRetry(ctx)
}
return sessiontxn.GetTxnManager(s).OnStmtStart(ctx)
return sessiontxn.GetTxnManager(s).OnStmtStart(ctx, node)
}

func (s *session) validateStatementReadOnlyInStaleness(stmtNode ast.StmtNode) error {
Expand Down Expand Up @@ -2202,7 +2202,8 @@ func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields
return
}

if err = s.onTxnManagerStmtStartOrRetry(ctx); err != nil {
prepareStmt := &ast.PrepareStmt{SQLText: sql}
if err = s.onTxnManagerStmtStartOrRetry(ctx, prepareStmt); err != nil {
return
}

Expand All @@ -2223,7 +2224,7 @@ func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields

func (s *session) preparedStmtExec(ctx context.Context,
is infoschema.InfoSchema, snapshotTS uint64,
stmtID uint32, prepareStmt *plannercore.CachedPrepareStmt, replicaReadScope string, args []types.Datum) (sqlexec.RecordSet, error) {
execStmt *ast.ExecuteStmt, prepareStmt *plannercore.CachedPrepareStmt, replicaReadScope string, args []types.Datum) (sqlexec.RecordSet, error) {

failpoint.Inject("assertTxnManagerInPreparedStmtExec", func() {
sessiontxn.RecordAssert(s, "assertTxnManagerInPreparedStmtExec", true)
Expand All @@ -2233,7 +2234,7 @@ func (s *session) preparedStmtExec(ctx context.Context,
}
})

st, tiFlashPushDown, tiFlashExchangePushDown, err := executor.CompileExecutePreparedStmt(ctx, s, stmtID, is, snapshotTS, replicaReadScope, args)
st, tiFlashPushDown, tiFlashExchangePushDown, err := executor.CompileExecutePreparedStmt(ctx, s, execStmt, is, snapshotTS, replicaReadScope, args)
if err != nil {
return nil, err
}
Expand All @@ -2253,14 +2254,9 @@ func (s *session) preparedStmtExec(ctx context.Context,

// cachedPointPlanExec is a short path currently ONLY for cached "point select plan" execution
func (s *session) cachedPointPlanExec(ctx context.Context,
is infoschema.InfoSchema, stmtID uint32, prepareStmt *plannercore.CachedPrepareStmt, replicaReadScope string, args []types.Datum) (sqlexec.RecordSet, bool, error) {
is infoschema.InfoSchema, execAst *ast.ExecuteStmt, prepareStmt *plannercore.CachedPrepareStmt, replicaReadScope string, args []types.Datum) (sqlexec.RecordSet, bool, error) {

prepared := prepareStmt.PreparedAst
// compile ExecStmt
execAst := &ast.ExecuteStmt{ExecID: stmtID}
if err := executor.ResetContextOfStmt(s, execAst); err != nil {
return nil, false, err
}

failpoint.Inject("assertTxnManagerInCachedPlanExec", func() {
sessiontxn.RecordAssert(s, "assertTxnManagerInCachedPlanExec", true)
Expand Down Expand Up @@ -2423,7 +2419,12 @@ func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args [
s.txn.onStmtStart(preparedStmt.SQLDigest.String())
defer s.txn.onStmtEnd()

if err = s.onTxnManagerStmtStartOrRetry(ctx); err != nil {
execStmt := &ast.ExecuteStmt{ExecID: stmtID}
if err := executor.ResetContextOfStmt(s, execStmt); err != nil {
return nil, err
}

if err = s.onTxnManagerStmtStartOrRetry(ctx, execStmt); err != nil {
return nil, err
}

Expand All @@ -2432,15 +2433,15 @@ func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args [
}

if ok {
rs, ok, err := s.cachedPointPlanExec(ctx, txnManager.GetTxnInfoSchema(), stmtID, preparedStmt, replicaReadScope, args)
rs, ok, err := s.cachedPointPlanExec(ctx, txnManager.GetTxnInfoSchema(), execStmt, preparedStmt, replicaReadScope, args)
if err != nil {
return nil, err
}
if ok { // fallback to preparedStmtExec if we cannot get a valid point select plan in cachedPointPlanExec
return rs, nil
}
}
return s.preparedStmtExec(ctx, txnManager.GetTxnInfoSchema(), snapshotTS, stmtID, preparedStmt, replicaReadScope, args)
return s.preparedStmtExec(ctx, txnManager.GetTxnInfoSchema(), snapshotTS, execStmt, preparedStmt, replicaReadScope, args)
}

func (s *session) DropPreparedStmt(stmtID uint32) error {
Expand Down
12 changes: 10 additions & 2 deletions session/txnmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func getTxnManager(sctx sessionctx.Context) sessiontxn.TxnManager {
type txnManager struct {
sctx sessionctx.Context

stmtNode ast.StmtNode
ctxProvider sessiontxn.TxnContextProvider
}

Expand Down Expand Up @@ -116,14 +117,21 @@ func (m *txnManager) EnterNewTxn(ctx context.Context, r *sessiontxn.EnterNewTxnR

func (m *txnManager) OnTxnEnd() {
m.ctxProvider = nil
m.stmtNode = nil
}

func (m *txnManager) GetCurrentStmt() ast.StmtNode {
return m.stmtNode
}
SpadeA-Tang marked this conversation as resolved.
Show resolved Hide resolved

// OnStmtStart is the hook that should be called when a new statement started
func (m *txnManager) OnStmtStart(ctx context.Context) error {
func (m *txnManager) OnStmtStart(ctx context.Context, node ast.StmtNode) error {
m.stmtNode = node

if m.ctxProvider == nil {
return errors.New("context provider not set")
}
return m.ctxProvider.OnStmtStart(ctx)
return m.ctxProvider.OnStmtStart(ctx, m.stmtNode)
}

// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error
Expand Down
9 changes: 6 additions & 3 deletions sessiontxn/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ type TxnContextProvider interface {
// OnInitialize is the hook that should be called when enter a new txn with this provider
OnInitialize(ctx context.Context, enterNewTxnType EnterNewTxnType) error
// OnStmtStart is the hook that should be called when a new statement started
OnStmtStart(ctx context.Context) error
OnStmtStart(ctx context.Context, node ast.StmtNode) error
// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error
OnStmtErrorForNextAction(point StmtErrorHandlePoint, err error) (StmtErrorAction, error)
// OnStmtRetry is the hook that should be called when a statement is retried internally.
Expand All @@ -150,14 +150,16 @@ type TxnManager interface {
// OnTxnEnd is the hook that should be called after transaction commit or rollback
OnTxnEnd()
// OnStmtStart is the hook that should be called when a new statement started
OnStmtStart(ctx context.Context) error
OnStmtStart(ctx context.Context, node ast.StmtNode) error
// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error
// This method is not required to be called for every error in the statement,
// it is only required to be called for some errors handled in some specified points given by the parameter `point`.
// When the return error is not nil the return action is 'StmtActionError' and vice versa.
OnStmtErrorForNextAction(point StmtErrorHandlePoint, err error) (StmtErrorAction, error)
// OnStmtRetry is the hook that should be called when a statement retry
OnStmtRetry(ctx context.Context) error
// GetCurrentStmt returns the current statement node
GetCurrentStmt() ast.StmtNode
}

// NewTxn starts a new optimistic and active txn, it can be used for the below scenes:
Expand All @@ -178,7 +180,8 @@ func NewTxnInStmt(ctx context.Context, sctx sessionctx.Context) error {
if err := NewTxn(ctx, sctx); err != nil {
return err
}
return GetTxnManager(sctx).OnStmtStart(ctx)
txnManager := GetTxnManager(sctx)
return txnManager.OnStmtStart(ctx, txnManager.GetCurrentStmt())
}

// GetTxnManager returns the TxnManager object from session context
Expand Down
3 changes: 2 additions & 1 deletion sessiontxn/isolation/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/sessiontxn"
Expand Down Expand Up @@ -138,7 +139,7 @@ func (p *baseTxnContextProvider) GetStmtForUpdateTS() (uint64, error) {
return p.getStmtForUpdateTSFunc()
}

func (p *baseTxnContextProvider) OnStmtStart(ctx context.Context) error {
func (p *baseTxnContextProvider) OnStmtStart(ctx context.Context, _ ast.StmtNode) error {
p.ctx = ctx
return nil
}
Expand Down
Loading