diff --git a/executor/adapter_test.go b/executor/adapter_test.go index 012075c8e4a0c..38e7735e35292 100644 --- a/executor/adapter_test.go +++ b/executor/adapter_test.go @@ -15,13 +15,19 @@ package executor_test import ( + "context" + "sync" "testing" "time" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/executor" + "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/store/copr" "github.com/pingcap/tidb/testkit" "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/util" ) func TestQueryTime(t *testing.T) { @@ -52,3 +58,33 @@ func TestFormatSQL(t *testing.T) { val = executor.FormatSQL("aaaaaaaaaaaaaaaaaaaa") require.Equal(t, "aaaaa(len:20)", val.String()) } + +func TestContextCancelWhenReadFromCopIterator(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t(a int)") + tk.MustExec("insert into t values(1)") + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/store/copr/CtxCancelBeforeReceive", "return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/store/copr/CtxCancelBeforeReceive")) + }() + ctx := context.WithValue(context.Background(), "TestContextCancel", "test") + ctx, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + ctx = util.WithInternalSourceType(ctx, "scheduler") + rs, err := tk.Session().ExecuteInternal(ctx, "select * from test.t") + require.NoError(t, err) + _, err2 := session.ResultSetToStringSlice(ctx, tk.Session(), rs) + require.ErrorIs(t, err2, context.Canceled) + }() + <-copr.GlobalSyncChForTest + cancelFunc() + copr.GlobalSyncChForTest <- struct{}{} + wg.Wait() +} diff --git a/store/copr/coprocessor.go b/store/copr/coprocessor.go index 074081c9c687d..f1df2a10c360d 100644 --- a/store/copr/coprocessor.go +++ b/store/copr/coprocessor.go @@ -899,7 +899,16 @@ func (sender *copIteratorTaskSender) run(connID uint64) { } } +// GlobalSyncChForTest is a global channel for test. +var GlobalSyncChForTest = make(chan struct{}) + func (it *copIterator) recvFromRespCh(ctx context.Context, respCh <-chan *copResponse) (resp *copResponse, ok bool, exit bool) { + failpoint.Inject("CtxCancelBeforeReceive", func(_ failpoint.Value) { + if ctx.Value("TestContextCancel") == "test" { + GlobalSyncChForTest <- struct{}{} + <-GlobalSyncChForTest + } + }) ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() for { @@ -1036,7 +1045,7 @@ func (it *copIterator) Next(ctx context.Context) (kv.ResultSubset, error) { resp, ok, closed = it.recvFromRespCh(ctx, it.respChan) if !ok || closed { it.actionOnExceed.close() - return nil, nil + return nil, errors.Trace(ctx.Err()) } if resp == finCopResp { it.actionOnExceed.destroyTokenIfNeeded(func() { @@ -1054,8 +1063,8 @@ func (it *copIterator) Next(ctx context.Context) (kv.ResultSubset, error) { task := it.tasks[it.curr] resp, ok, closed = it.recvFromRespCh(ctx, task.respChan) if closed { - // Close() is already called, so Next() is invalid. - return nil, nil + // Close() is called or context cancelled/timeout, so Next() is invalid. + return nil, errors.Trace(ctx.Err()) } if ok { break