diff --git a/error/error.go b/error/error.go index 62a75ff65..14f29a440 100644 --- a/error/error.go +++ b/error/error.go @@ -64,13 +64,14 @@ var ( // ErrTiFlashServerTimeout is the error when tiflash server is timeout. ErrTiFlashServerTimeout = errors.New("tiflash server timeout") // ErrQueryInterrupted is the error when the query is interrupted. - ErrQueryInterrupted = errors.New("query interruppted") + // It is deprecated in favor of ErrQueryInterruptedWithSignal and is kept only to pass CI; it can be removed later. + ErrQueryInterrupted = errors.New("query interrupted") // ErrTiKVStaleCommand is the error that the command is stale in tikv. ErrTiKVStaleCommand = errors.New("tikv stale command") // ErrTiKVMaxTimestampNotSynced is the error that tikv's max timestamp is not synced. ErrTiKVMaxTimestampNotSynced = errors.New("tikv max timestamp not synced") // ErrLockAcquireFailAndNoWaitSet is the error that acquire the lock failed while no wait is setted. - ErrLockAcquireFailAndNoWaitSet = errors.New("lock acquired failed and no wait is setted") + ErrLockAcquireFailAndNoWaitSet = errors.New("lock acquired failed and no wait is set") // ErrResolveLockTimeout is the error that resolve lock timeout. ErrResolveLockTimeout = errors.New("resolve lock timeout") // ErrLockWaitTimeout is the error that wait for the lock is timeout. @@ -96,11 +97,19 @@ var ( // ErrIsWitness is the error when a request is send to a witness. ErrIsWitness = errors.New("peer is witness") // ErrUnknown is the unknow error. - ErrUnknown = errors.New("unknow") + ErrUnknown = errors.New("unknown") // ErrResultUndetermined is the error when execution result is unknown. ErrResultUndetermined = errors.New("execution result undetermined") ) +type ErrQueryInterruptedWithSignal struct { + Signal uint32 +} + +func (e ErrQueryInterruptedWithSignal) Error() string { + return fmt.Sprintf("query interrupted by signal %d", e.Signal) +} + // MismatchClusterID represents the message that the cluster ID of the PD client does not match the PD. 
const MismatchClusterID = "mismatch cluster id" diff --git a/integration_tests/2pc_test.go b/integration_tests/2pc_test.go index 355dafeae..9369bfb66 100644 --- a/integration_tests/2pc_test.go +++ b/integration_tests/2pc_test.go @@ -2502,3 +2502,13 @@ func (s *testCommitterSuite) TestExtractKeyExistsErr() { s.True(txn.GetMemBuffer().TryLock()) txn.GetMemBuffer().Unlock() } + +func (s *testCommitterSuite) TestKillSignal() { + txn := s.begin() + err := txn.Set([]byte("key"), []byte("value")) + s.Nil(err) + var killed uint32 = 2 + txn.SetVars(kv.NewVariables(&killed)) + err = txn.Commit(context.Background()) + s.ErrorContains(err, "query interrupted") +} diff --git a/internal/locate/region_request.go b/internal/locate/region_request.go index a83b93dcc..514c76e17 100644 --- a/internal/locate/region_request.go +++ b/internal/locate/region_request.go @@ -1474,9 +1474,8 @@ func (s *RegionRequestSender) SendReqCtx( } // recheck whether the session/query is killed during the Next() - boVars := bo.GetVars() - if boVars != nil && boVars.Killed != nil && atomic.LoadUint32(boVars.Killed) == 1 { - return nil, nil, retryTimes, errors.WithStack(tikverr.ErrQueryInterrupted) + if err2 := bo.CheckKilled(); err2 != nil { + return nil, nil, retryTimes, err2 } if val, err := util.EvalFailpoint("mockRetrySendReqToRegion"); err == nil { if val.(bool) { diff --git a/internal/retry/backoff.go b/internal/retry/backoff.go index a2723e05b..c18577ad0 100644 --- a/internal/retry/backoff.go +++ b/internal/retry/backoff.go @@ -217,10 +217,9 @@ func (b *Backoffer) BackoffWithCfgAndMaxSleep(cfg *Config, maxSleepMs int, err e atomic.AddInt64(&detail.BackoffCount, 1) } - if b.vars != nil && b.vars.Killed != nil { - if atomic.LoadUint32(b.vars.Killed) == 1 { - return errors.WithStack(tikverr.ErrQueryInterrupted) - } + err2 := b.CheckKilled() + if err2 != nil { + return err2 } var startTs interface{} @@ -382,3 +381,17 @@ func (b *Backoffer) longestSleepCfg() (*Config, int) { } return nil, 0 } + +func (b 
*Backoffer) CheckKilled() error { + if b.vars != nil && b.vars.Killed != nil { + killed := atomic.LoadUint32(b.vars.Killed) + if killed != 0 { + logutil.BgLogger().Info( + "backoff stops because a killed signal is received", + zap.Uint32("signal", killed), + ) + return errors.WithStack(tikverr.ErrQueryInterruptedWithSignal{Signal: killed}) + } + } + return nil +} diff --git a/kv/variables.go b/kv/variables.go index 581be54d0..cae78c9c5 100644 --- a/kv/variables.go +++ b/kv/variables.go @@ -44,6 +44,10 @@ type Variables struct { // Pointer to SessionVars.Killed // Killed is a flag to indicate that this query is killed. + // This is an enum value rather than a boolean. See sqlkiller.go + // in TiDB for its definition. + // When its value is 0, the query is not killed. + // When its value is not 0, the query is killed and the value indicates the concrete reason. Killed *uint32 } diff --git a/txnkv/transaction/2pc.go b/txnkv/transaction/2pc.go index c7d99333b..b7471f87a 100644 --- a/txnkv/transaction/2pc.go +++ b/txnkv/transaction/2pc.go @@ -1048,7 +1048,27 @@ func (c *twoPhaseCommitter) doActionOnGroupMutations(bo *retry.Backoffer, action } // doActionOnBatches does action to batches in parallel. -func (c *twoPhaseCommitter) doActionOnBatches(bo *retry.Backoffer, action twoPhaseCommitAction, batches []batchMutations) error { +func (c *twoPhaseCommitter) doActionOnBatches( + bo *retry.Backoffer, action twoPhaseCommitAction, + batches []batchMutations, +) error { + // killSignal should never be nil for TiDB + if c.txn != nil && c.txn.vars != nil && c.txn.vars.Killed != nil { + // Do not reset the killed flag here. Let the upper layer reset the flag. + // Before it resets, any request is considered valid to be killed. 
+ status := atomic.LoadUint32(c.txn.vars.Killed) + if status != 0 { + logutil.BgLogger().Info( + "query is killed", zap.Uint32( + "signal", + status, + ), + ) + // TODO: There might be various signals besides a query interruption, + // but we are unable to differentiate them, because the definition is in TiDB. + return errors.WithStack(tikverr.ErrQueryInterruptedWithSignal{Signal: status}) + } + } if len(batches) == 0 { return nil } diff --git a/txnkv/transaction/pessimistic.go b/txnkv/transaction/pessimistic.go index 28835baeb..3db567028 100644 --- a/txnkv/transaction/pessimistic.go +++ b/txnkv/transaction/pessimistic.go @@ -214,18 +214,6 @@ func (action actionPessimisticLock) handleSingleBatch( return nil } } - - // Handle the killed flag when waiting for the pessimistic lock. - // When a txn runs into LockKeys() and backoff here, it has no chance to call - // executor.Next() and check the killed flag. - if action.Killed != nil { - // Do not reset the killed flag here! - // actionPessimisticLock runs on each region parallelly, we have to consider that - // the error may be dropped. - if atomic.LoadUint32(action.Killed) == 1 { - return errors.WithStack(tikverr.ErrQueryInterrupted) - } - } } } diff --git a/txnkv/transaction/txn.go b/txnkv/transaction/txn.go index 2d5407a39..94ef35964 100644 --- a/txnkv/transaction/txn.go +++ b/txnkv/transaction/txn.go @@ -1111,12 +1111,6 @@ func (txn *KVTxn) lockKeys(ctx context.Context, lockCtx *tikv.LockCtx, fn func() lockCtx.Stats.Mu.BackoffTypes = append(lockCtx.Stats.Mu.BackoffTypes, bo.GetTypes()...) lockCtx.Stats.Mu.Unlock() } - if lockCtx.Killed != nil { - // If the kill signal is received during waiting for pessimisticLock, - // pessimisticLockKeys would handle the error but it doesn't reset the flag. - // We need to reset the killed flag here. 
- atomic.CompareAndSwapUint32(lockCtx.Killed, 1, 0) - } if txn.IsInAggressiveLockingMode() { if txn.aggressiveLockingContext.maxLockedWithConflictTS < lockCtx.MaxLockedWithConflictTS { txn.aggressiveLockingContext.maxLockedWithConflictTS = lockCtx.MaxLockedWithConflictTS