Skip to content

Commit

Permalink
ddl: refine some context usage (#56243)
Browse files Browse the repository at this point in the history
ref #54436, ref #56017
  • Loading branch information
lance6716 authored Sep 27, 2024
1 parent 9feedd9 commit bad2ecd
Show file tree
Hide file tree
Showing 27 changed files with 184 additions and 114 deletions.
2 changes: 1 addition & 1 deletion br/pkg/backup/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,7 @@ func WriteBackupDDLJobs(metaWriter *metautil.MetaWriter, g glue.Glue, store kv.S
newestMeta := meta.NewSnapshotMeta(store.GetSnapshot(kv.NewVersion(version.Ver)))
var allJobs []*model.Job
err = g.UseOneShotSession(store, !needDomain, func(se glue.Session) error {
allJobs, err = ddl.GetAllDDLJobs(se.GetSessionCtx())
allJobs, err = ddl.GetAllDDLJobs(context.Background(), se.GetSessionCtx())
if err != nil {
return errors.Trace(err)
}
Expand Down
14 changes: 8 additions & 6 deletions pkg/ddl/backfilling.go
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,9 @@ func (dc *ddlCtx) runAddIndexInLocalIngestMode(
return errors.Trace(err)
}
job := reorgInfo.Job
opCtx := NewLocalOperatorCtx(ctx, job.ID)
opCtx, cancel := NewLocalOperatorCtx(ctx, job.ID)
defer cancel()

idxCnt := len(reorgInfo.elements)
indexIDs := make([]int64, 0, idxCnt)
indexInfos := make([]*model.IndexInfo, 0, idxCnt)
Expand Down Expand Up @@ -705,11 +707,6 @@ func (dc *ddlCtx) runAddIndexInLocalIngestMode(
return errors.Trace(err)
}
defer ingest.LitBackCtxMgr.Unregister(job.ID)
sctx, err := sessPool.Get()
if err != nil {
return errors.Trace(err)
}
defer sessPool.Put(sctx)

cpMgr, err := ingest.NewCheckpointManager(
ctx,
Expand Down Expand Up @@ -737,6 +734,11 @@ func (dc *ddlCtx) runAddIndexInLocalIngestMode(
metrics.GenerateReorgLabel("add_idx_rate", job.SchemaName, job.TableName)),
}

sctx, err := sessPool.Get()
if err != nil {
return errors.Trace(err)
}
defer sessPool.Put(sctx)
avgRowSize := estimateTableRowSize(ctx, dc.store, sctx.GetRestrictedSQLExecutor(), t)

engines, err := bcCtx.Register(indexIDs, uniques, t)
Expand Down
24 changes: 12 additions & 12 deletions pkg/ddl/backfilling_operators.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,33 +83,33 @@ type OperatorCtx struct {
}

// NewDistTaskOperatorCtx is used for adding index with dist framework.
func NewDistTaskOperatorCtx(ctx context.Context, taskID, subtaskID int64) *OperatorCtx {
func NewDistTaskOperatorCtx(
ctx context.Context,
taskID, subtaskID int64,
) (*OperatorCtx, context.CancelFunc) {
opCtx, cancel := context.WithCancel(ctx)
opCtx = logutil.WithFields(opCtx, zap.Int64("task-id", taskID), zap.Int64("subtask-id", subtaskID))
opCtx = logutil.WithFields(opCtx,
zap.Int64("task-id", taskID),
zap.Int64("subtask-id", subtaskID))
return &OperatorCtx{
Context: opCtx,
cancel: cancel,
}
}, cancel
}

// NewLocalOperatorCtx is used for adding index with local ingest mode.
func NewLocalOperatorCtx(ctx context.Context, jobID int64) *OperatorCtx {
func NewLocalOperatorCtx(ctx context.Context, jobID int64) (*OperatorCtx, context.CancelFunc) {
opCtx, cancel := context.WithCancel(ctx)
opCtx = logutil.WithFields(opCtx, zap.Int64("jobID", jobID))
return &OperatorCtx{
Context: opCtx,
cancel: cancel,
}
}, cancel
}

func (ctx *OperatorCtx) onError(err error) {
tracedErr := errors.Trace(err)
ctx.cancel()
ctx.err.CompareAndSwap(nil, &tracedErr)
}

// Cancel cancels the pipeline.
func (ctx *OperatorCtx) Cancel() {
ctx.cancel()
}

Expand Down Expand Up @@ -769,7 +769,7 @@ func (w *indexIngestLocalWorker) HandleTask(ck IndexRecordChunk, send func(Index
return
}
w.rowCntListener.Written(rs.Added)
flushed, imported, err := w.backendCtx.Flush(ingest.FlushModeAuto)
flushed, imported, err := w.backendCtx.Flush(w.ctx, ingest.FlushModeAuto)
if err != nil {
w.ctx.onError(err)
return
Expand Down Expand Up @@ -949,7 +949,7 @@ func (s *indexWriteResultSink) flush() error {
failpoint.Inject("mockFlushError", func(_ failpoint.Value) {
failpoint.Return(errors.New("mock flush error"))
})
flushed, imported, err := s.backendCtx.Flush(ingest.FlushModeForceFlushAndImport)
flushed, imported, err := s.backendCtx.Flush(s.ctx, ingest.FlushModeForceFlushAndImport)
if s.cpMgr != nil {
// Try to advance watermark even if there is an error.
s.cpMgr.AdvanceWatermark(flushed, imported)
Expand Down
4 changes: 2 additions & 2 deletions pkg/ddl/backfilling_read_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ func (r *readIndexExecutor) RunSubtask(ctx context.Context, subtask *proto.Subta
return err
}

opCtx := NewDistTaskOperatorCtx(ctx, subtask.TaskID, subtask.ID)
defer opCtx.Cancel()
opCtx, cancel := NewDistTaskOperatorCtx(ctx, subtask.TaskID, subtask.ID)
defer cancel()
r.curRowCount.Store(0)

if len(r.cloudStorageURI) > 0 {
Expand Down
2 changes: 1 addition & 1 deletion pkg/ddl/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func checkAndSetFlashbackClusterInfo(ctx context.Context, se sessionctx.Context,
}
}

jobs, err := GetAllDDLJobs(se)
jobs, err := GetAllDDLJobs(ctx, se)
if err != nil {
return errors.Trace(err)
}
Expand Down
10 changes: 6 additions & 4 deletions pkg/ddl/db_change_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1297,11 +1297,12 @@ func prepareTestControlParallelExecSQL(t *testing.T, store kv.Storage) (*testkit
return
}
var qLen int
ctx := context.Background()
for {
sess := testkit.NewTestKit(t, store).Session()
err := sessiontxn.NewTxn(context.Background(), sess)
err := sessiontxn.NewTxn(ctx, sess)
require.NoError(t, err)
jobs, err := ddl.GetAllDDLJobs(sess)
jobs, err := ddl.GetAllDDLJobs(ctx, sess)
require.NoError(t, err)
qLen = len(jobs)
if qLen == 2 {
Expand All @@ -1321,11 +1322,12 @@ func prepareTestControlParallelExecSQL(t *testing.T, store kv.Storage) (*testkit
// Make sure the sql1 is put into the DDLJobQueue.
go func() {
var qLen int
ctx := context.Background()
for {
sess := testkit.NewTestKit(t, store).Session()
err := sessiontxn.NewTxn(context.Background(), sess)
err := sessiontxn.NewTxn(ctx, sess)
require.NoError(t, err)
jobs, err := ddl.GetAllDDLJobs(sess)
jobs, err := ddl.GetAllDDLJobs(ctx, sess)
require.NoError(t, err)
qLen = len(jobs)
if qLen == 1 {
Expand Down
61 changes: 35 additions & 26 deletions pkg/ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -1229,15 +1229,16 @@ func GetDDLInfo(s sessionctx.Context) (*Info, error) {

func get2JobsFromTable(sess *sess.Session) (*model.Job, *model.Job, error) {
var generalJob, reorgJob *model.Job
jobs, err := getJobsBySQL(sess, JobTable, "not reorg order by job_id limit 1")
ctx := context.Background()
jobs, err := getJobsBySQL(ctx, sess, JobTable, "not reorg order by job_id limit 1")
if err != nil {
return nil, nil, errors.Trace(err)
}

if len(jobs) != 0 {
generalJob = jobs[0]
}
jobs, err = getJobsBySQL(sess, JobTable, "reorg order by job_id limit 1")
jobs, err = getJobsBySQL(ctx, sess, JobTable, "reorg order by job_id limit 1")
if err != nil {
return nil, nil, errors.Trace(err)
}
Expand Down Expand Up @@ -1309,6 +1310,7 @@ func resumePausedJob(_ *sess.Session, job *model.Job,

// processJobs command on the Job according to the process
func processJobs(
ctx context.Context,
process func(*sess.Session, *model.Job, model.AdminCommandOperator) (err error),
sessCtx sessionctx.Context,
ids []int64,
Expand Down Expand Up @@ -1336,11 +1338,11 @@ func processJobs(
idsStr = append(idsStr, strconv.FormatInt(id, 10))
}

err = ns.Begin(context.Background())
err = ns.Begin(ctx)
if err != nil {
return nil, err
}
jobs, err := getJobsBySQL(ns, JobTable, fmt.Sprintf("job_id in (%s) order by job_id", strings.Join(idsStr, ", ")))
jobs, err := getJobsBySQL(ctx, ns, JobTable, fmt.Sprintf("job_id in (%s) order by job_id", strings.Join(idsStr, ", ")))
if err != nil {
ns.Rollback()
return nil, err
Expand All @@ -1362,7 +1364,7 @@ func processJobs(
continue
}

err = updateDDLJob2Table(ns, job, false)
err = updateDDLJob2Table(ctx, ns, job, false)
if err != nil {
jobErrs[i] = err
continue
Expand All @@ -1376,7 +1378,7 @@ func processJobs(
})

// There may be some conflict during the update, try it again
if err = ns.Commit(context.Background()); err != nil {
if err = ns.Commit(ctx); err != nil {
continue
}

Expand All @@ -1391,43 +1393,50 @@ func processJobs(
}

// CancelJobs cancels the DDL jobs according to user command.
func CancelJobs(se sessionctx.Context, ids []int64) (errs []error, err error) {
return processJobs(cancelRunningJob, se, ids, model.AdminCommandByEndUser)
func CancelJobs(ctx context.Context, se sessionctx.Context, ids []int64) (errs []error, err error) {
return processJobs(ctx, cancelRunningJob, se, ids, model.AdminCommandByEndUser)
}

// PauseJobs pause all the DDL jobs according to user command.
func PauseJobs(se sessionctx.Context, ids []int64) ([]error, error) {
return processJobs(pauseRunningJob, se, ids, model.AdminCommandByEndUser)
func PauseJobs(ctx context.Context, se sessionctx.Context, ids []int64) ([]error, error) {
return processJobs(ctx, pauseRunningJob, se, ids, model.AdminCommandByEndUser)
}

// ResumeJobs resume all the DDL jobs according to user command.
func ResumeJobs(se sessionctx.Context, ids []int64) ([]error, error) {
return processJobs(resumePausedJob, se, ids, model.AdminCommandByEndUser)
func ResumeJobs(ctx context.Context, se sessionctx.Context, ids []int64) ([]error, error) {
return processJobs(ctx, resumePausedJob, se, ids, model.AdminCommandByEndUser)
}

// CancelJobsBySystem cancels Jobs because of internal reasons.
func CancelJobsBySystem(se sessionctx.Context, ids []int64) (errs []error, err error) {
return processJobs(cancelRunningJob, se, ids, model.AdminCommandBySystem)
ctx := context.Background()
return processJobs(ctx, cancelRunningJob, se, ids, model.AdminCommandBySystem)
}

// PauseJobsBySystem pauses Jobs because of internal reasons.
func PauseJobsBySystem(se sessionctx.Context, ids []int64) (errs []error, err error) {
return processJobs(pauseRunningJob, se, ids, model.AdminCommandBySystem)
ctx := context.Background()
return processJobs(ctx, pauseRunningJob, se, ids, model.AdminCommandBySystem)
}

// ResumeJobsBySystem resumes Jobs that are paused by TiDB itself.
func ResumeJobsBySystem(se sessionctx.Context, ids []int64) (errs []error, err error) {
return processJobs(resumePausedJob, se, ids, model.AdminCommandBySystem)
ctx := context.Background()
return processJobs(ctx, resumePausedJob, se, ids, model.AdminCommandBySystem)
}

// pprocessAllJobs processes all the jobs in the job table, 100 jobs at a time in case of high memory usage.
func processAllJobs(process func(*sess.Session, *model.Job, model.AdminCommandOperator) (err error),
se sessionctx.Context, byWho model.AdminCommandOperator) (map[int64]error, error) {
func processAllJobs(
ctx context.Context,
process func(*sess.Session, *model.Job, model.AdminCommandOperator) (err error),
se sessionctx.Context,
byWho model.AdminCommandOperator,
) (map[int64]error, error) {
var err error
var jobErrs = make(map[int64]error)

ns := sess.NewSession(se)
err = ns.Begin(context.Background())
err = ns.Begin(ctx)
if err != nil {
return nil, err
}
Expand All @@ -1437,7 +1446,7 @@ func processAllJobs(process func(*sess.Session, *model.Job, model.AdminCommandOp
var limit = 100
for {
var jobs []*model.Job
jobs, err = getJobsBySQL(ns, JobTable,
jobs, err = getJobsBySQL(ctx, ns, JobTable,
fmt.Sprintf("job_id >= %s order by job_id asc limit %s",
strconv.FormatInt(jobID, 10),
strconv.FormatInt(int64(limit), 10)))
Expand All @@ -1453,7 +1462,7 @@ func processAllJobs(process func(*sess.Session, *model.Job, model.AdminCommandOp
continue
}

err = updateDDLJob2Table(ns, job, false)
err = updateDDLJob2Table(ctx, ns, job, false)
if err != nil {
jobErrs[job.ID] = err
continue
Expand All @@ -1473,7 +1482,7 @@ func processAllJobs(process func(*sess.Session, *model.Job, model.AdminCommandOp
jobID = jobIDMax + 1
}

err = ns.Commit(context.Background())
err = ns.Commit(ctx)
if err != nil {
return nil, err
}
Expand All @@ -1482,23 +1491,23 @@ func processAllJobs(process func(*sess.Session, *model.Job, model.AdminCommandOp

// PauseAllJobsBySystem pauses all running Jobs because of internal reasons.
func PauseAllJobsBySystem(se sessionctx.Context) (map[int64]error, error) {
return processAllJobs(pauseRunningJob, se, model.AdminCommandBySystem)
return processAllJobs(context.Background(), pauseRunningJob, se, model.AdminCommandBySystem)
}

// ResumeAllJobsBySystem resumes all paused Jobs because of internal reasons.
func ResumeAllJobsBySystem(se sessionctx.Context) (map[int64]error, error) {
return processAllJobs(resumePausedJob, se, model.AdminCommandBySystem)
return processAllJobs(context.Background(), resumePausedJob, se, model.AdminCommandBySystem)
}

// GetAllDDLJobs get all DDL jobs and sorts jobs by job.ID.
func GetAllDDLJobs(se sessionctx.Context) ([]*model.Job, error) {
return getJobsBySQL(sess.NewSession(se), JobTable, "1 order by job_id")
func GetAllDDLJobs(ctx context.Context, se sessionctx.Context) ([]*model.Job, error) {
return getJobsBySQL(ctx, sess.NewSession(se), JobTable, "1 order by job_id")
}

// IterAllDDLJobs will iterates running DDL jobs first, return directly if `finishFn` return true or error,
// then iterates history DDL jobs until the `finishFn` return true or error.
func IterAllDDLJobs(ctx sessionctx.Context, txn kv.Transaction, finishFn func([]*model.Job) (bool, error)) error {
jobs, err := GetAllDDLJobs(ctx)
jobs, err := GetAllDDLJobs(context.Background(), ctx)
if err != nil {
return err
}
Expand Down
8 changes: 5 additions & 3 deletions pkg/ddl/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func TestGetDDLJobs(t *testing.T) {

cnt := 10
jobs := make([]*model.Job, cnt)
ctx := context.Background()
var currJobs2 []*model.Job
for i := 0; i < cnt; i++ {
jobs[i] = &model.Job{
Expand All @@ -59,7 +60,7 @@ func TestGetDDLJobs(t *testing.T) {
err := addDDLJobs(sess, txn, jobs[i])
require.NoError(t, err)

currJobs, err := ddl.GetAllDDLJobs(sess)
currJobs, err := ddl.GetAllDDLJobs(ctx, sess)
require.NoError(t, err)
require.Len(t, currJobs, i+1)

Expand All @@ -77,7 +78,7 @@ func TestGetDDLJobs(t *testing.T) {
require.Len(t, currJobs2, i+1)
}

currJobs, err := ddl.GetAllDDLJobs(sess)
currJobs, err := ddl.GetAllDDLJobs(ctx, sess)
require.NoError(t, err)

for i, job := range jobs {
Expand All @@ -93,6 +94,7 @@ func TestGetDDLJobs(t *testing.T) {

func TestGetDDLJobsIsSort(t *testing.T) {
store := testkit.CreateMockStore(t)
ctx := context.Background()

sess := testkit.NewTestKit(t, store).Session()
_, err := sess.Execute(context.Background(), "begin")
Expand All @@ -110,7 +112,7 @@ func TestGetDDLJobsIsSort(t *testing.T) {
// insert add index jobs to AddIndexJobListKey queue
enQueueDDLJobs(t, sess, txn, model.ActionAddIndex, 5, 10)

currJobs, err := ddl.GetAllDDLJobs(sess)
currJobs, err := ddl.GetAllDDLJobs(ctx, sess)
require.NoError(t, err)
require.Len(t, currJobs, 15)

Expand Down
3 changes: 2 additions & 1 deletion pkg/ddl/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func FetchChunk4Test(copCtx copr.CopContext, tbl table.PhysicalTable, startKey,
for i := 0; i < 10; i++ {
srcChkPool <- chunk.NewChunkWithCapacity(copCtx.GetBase().FieldTypes, batchSize)
}
opCtx := ddl.NewLocalOperatorCtx(context.Background(), 1)
opCtx, cancel := ddl.NewLocalOperatorCtx(context.Background(), 1)
defer cancel()
src := testutil.NewOperatorTestSource(ddl.TableScanTask{ID: 1, Start: startKey, End: endKey})
scanOp := ddl.NewTableScanOperator(opCtx, sessPool, copCtx, srcChkPool, 1, nil, 0)
sink := testutil.NewOperatorTestSink[ddl.IndexRecordChunk]()
Expand Down
Loading

0 comments on commit bad2ecd

Please sign in to comment.