diff --git a/pkg/disttask/framework/mock/scheduler_mock.go b/pkg/disttask/framework/mock/scheduler_mock.go index 933212d10557b..f6c6ab30ae80e 100644 --- a/pkg/disttask/framework/mock/scheduler_mock.go +++ b/pkg/disttask/framework/mock/scheduler_mock.go @@ -394,6 +394,20 @@ func (mr *MockTaskManagerMockRecorder) PauseTask(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PauseTask", reflect.TypeOf((*MockTaskManager)(nil).PauseTask), arg0, arg1) } +// PausedTask mocks base method. +func (m *MockTaskManager) PausedTask(arg0 context.Context, arg1 int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PausedTask", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// PausedTask indicates an expected call of PausedTask. +func (mr *MockTaskManagerMockRecorder) PausedTask(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PausedTask", reflect.TypeOf((*MockTaskManager)(nil).PausedTask), arg0, arg1) +} + // ResumeSubtasks mocks base method. func (m *MockTaskManager) ResumeSubtasks(arg0 context.Context, arg1 int64) error { m.ctrl.T.Helper() @@ -408,6 +422,20 @@ func (mr *MockTaskManagerMockRecorder) ResumeSubtasks(arg0, arg1 any) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResumeSubtasks", reflect.TypeOf((*MockTaskManager)(nil).ResumeSubtasks), arg0, arg1) } +// RevertedTask mocks base method. +func (m *MockTaskManager) RevertedTask(arg0 context.Context, arg1 int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevertedTask", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// RevertedTask indicates an expected call of RevertedTask. +func (mr *MockTaskManagerMockRecorder) RevertedTask(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevertedTask", reflect.TypeOf((*MockTaskManager)(nil).RevertedTask), arg0, arg1) +} + // SucceedTask mocks base method. func (m *MockTaskManager) SucceedTask(arg0 context.Context, arg1 int64) error { m.ctrl.T.Helper() diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index f2508db179fca..77c949dd2c661 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -36,10 +36,18 @@ type TaskManager interface { GetAllNodes(ctx context.Context) ([]proto.ManagedNode, error) DeleteDeadNodes(ctx context.Context, nodes []string) error TransferTasks2History(ctx context.Context, tasks []*proto.Task) error + // CancelTask updated task state to canceling. CancelTask(ctx context.Context, taskID int64) error // FailTask updates task state to Failed and updates task error. FailTask(ctx context.Context, taskID int64, currentState proto.TaskState, taskErr error) error + // RevertedTask updates task state to reverted. + RevertedTask(ctx context.Context, taskID int64) error + // PauseTask updated task state to pausing. PauseTask(ctx context.Context, taskKey string) (bool, error) + // PausedTask updated task state to paused. + PausedTask(ctx context.Context, taskID int64) error + // SucceedTask updates a task to success state. + SucceedTask(ctx context.Context, taskID int64) error // SwitchTaskStep switches the task to the next step and add subtasks in one // transaction. It will change task state too if we're switch from InitStep to // next step. @@ -51,8 +59,6 @@ type TaskManager interface { // And each subtask of this step must be different, to handle the network // partition or owner change. SwitchTaskStepInBatch(ctx context.Context, task *proto.Task, nextState proto.TaskState, nextStep proto.Step, subtasks []*proto.Subtask) error - // SucceedTask updates a task to success state. - SucceedTask(ctx context.Context, taskID int64) error // GetUsedSlotsOnNodes returns the used slots on nodes that have subtask scheduled. // subtasks of each task on one node is only accounted once as we don't support // running them concurrently. diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index 92d170f69eaa7..a034b6a733746 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -245,7 +245,7 @@ func (s *BaseScheduler) onPausing() error { runningPendingCnt := cntByStates[proto.SubtaskStateRunning] + cntByStates[proto.SubtaskStatePending] if runningPendingCnt == 0 { logutil.Logger(s.logCtx).Info("all running subtasks paused, update the task to paused state") - return s.updateTask(proto.TaskStatePaused, nil, RetrySQLTimes) + return s.taskMgr.PausedTask(s.ctx, s.Task.ID) } logutil.Logger(s.logCtx).Debug("on pausing state, this task keeps current state", zap.Stringer("state", s.Task.State)) return nil @@ -302,7 +302,7 @@ func (s *BaseScheduler) onReverting() error { if err = s.OnDone(s.ctx, s, s.Task); err != nil { return errors.Trace(err) } - return s.updateTask(proto.TaskStateReverted, nil, RetrySQLTimes) + return s.taskMgr.RevertedTask(s.ctx, s.Task.ID) } // Wait all subtasks in this step finishes. s.OnTick(s.ctx, s.Task) @@ -641,11 +641,11 @@ func (s *BaseScheduler) handlePlanErr(err error) error { return err } s.Task.Error = err - if err = s.OnDone(s.ctx, s, s.Task); err != nil { return errors.Trace(err) } - return s.updateTask(proto.TaskStateFailed, nil, RetrySQLTimes) + + return s.taskMgr.FailTask(s.ctx, s.Task.ID, s.Task.State, s.Task.Error) } // MockServerInfo exported for scheduler_test.go diff --git a/pkg/disttask/framework/storage/BUILD.bazel b/pkg/disttask/framework/storage/BUILD.bazel index 154082d24a25a..ff150e67c3092 100644 --- a/pkg/disttask/framework/storage/BUILD.bazel +++ b/pkg/disttask/framework/storage/BUILD.bazel @@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "storage", srcs = [ + "task_state.go", "task_table.go", "util.go", ], diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index 114c6927b9b7f..99ac86e6bc30a 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -133,11 +133,16 @@ func TestTaskTable(t *testing.T) { require.NoError(t, err) require.Equal(t, proto.TaskStatePending, task.State) require.Nil(t, task.Error) + curTime := time.Unix(time.Now().Unix(), 0) require.NoError(t, gm.FailTask(ctx, id, proto.TaskStatePending, errors.New("test error"))) task, err = gm.GetTaskByID(ctx, id) require.NoError(t, err) require.Equal(t, proto.TaskStateFailed, task.State) require.ErrorContains(t, task.Error, "test error") + endTime, err := storage.GetTaskEndTimeForTest(ctx, gm, id) + require.NoError(t, err) + require.LessOrEqual(t, endTime.Sub(curTime), time.Since(curTime)) + require.GreaterOrEqual(t, endTime, curTime) // succeed a pending task, no effect id, err = gm.CreateTask(ctx, "key-success", "test", 4, []byte("test")) @@ -157,6 +162,44 @@ func TestTaskTable(t *testing.T) { require.NoError(t, err) checkTaskStateStep(t, task, proto.TaskStateSucceed, proto.StepDone) require.GreaterOrEqual(t, task.StateUpdateTime, startTime) + + // reverted a pending task, no effect + id, err = gm.CreateTask(ctx, "key-reverted", "test", 4, []byte("test")) + require.NoError(t, err) + require.NoError(t, gm.RevertedTask(ctx, id)) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + checkTaskStateStep(t, task, proto.TaskStatePending, proto.StepInit) + // reverted a reverting task + task.State = proto.TaskStateReverting + _, err = gm.UpdateTaskAndAddSubTasks(ctx, task, nil, proto.TaskStatePending) + require.NoError(t, err) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, proto.TaskStateReverting, task.State) + require.NoError(t, gm.RevertedTask(ctx, task.ID)) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, proto.TaskStateReverted, task.State) + // paused + + id, err = gm.CreateTask(ctx, "key-paused", "test", 4, []byte("test")) + require.NoError(t, err) + require.NoError(t, gm.PausedTask(ctx, id)) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + checkTaskStateStep(t, task, proto.TaskStatePending, proto.StepInit) + // reverted a reverting task + task.State = proto.TaskStatePausing + _, err = gm.UpdateTaskAndAddSubTasks(ctx, task, nil, proto.TaskStatePending) + require.NoError(t, err) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, proto.TaskStatePausing, task.State) + require.NoError(t, gm.PausedTask(ctx, task.ID)) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, proto.TaskStatePaused, task.State) } func checkAfterSwitchStep(t *testing.T, startTime time.Time, task *proto.Task, subtasks []*proto.Subtask, step proto.Step) { @@ -543,6 +586,10 @@ func TestSubTaskTable(t *testing.T) { require.Greater(t, subtask.StartTime, ts) require.Greater(t, subtask.UpdateTime, ts) + endTime, err := storage.GetSubtaskEndTimeForTest(ctx, sm, subtask.ID) + require.NoError(t, err) + require.Greater(t, endTime, ts) + // test FinishSubtask do update update time testutil.CreateSubTask(t, sm, 4, proto.StepInit, "for_test1", []byte("test"), proto.TaskTypeExample, 11, false) subtask, err = sm.GetFirstSubtaskInStates(ctx, "for_test1", 4, proto.StepInit, proto.SubtaskStatePending) @@ -554,12 +601,16 @@ func TestSubTaskTable(t *testing.T) { require.NoError(t, err) require.Greater(t, subtask.StartTime, ts) require.Greater(t, subtask.UpdateTime, ts) + ts = time.Now() time.Sleep(time.Second) require.NoError(t, sm.FinishSubtask(ctx, "for_test1", subtask.ID, []byte{})) subtask2, err = sm.GetFirstSubtaskInStates(ctx, "for_test1", 4, proto.StepInit, proto.SubtaskStateSucceed) require.NoError(t, err) require.Equal(t, subtask2.StartTime, subtask.StartTime) require.Greater(t, subtask2.UpdateTime, subtask.UpdateTime) + endTime, err = storage.GetSubtaskEndTimeForTest(ctx, sm, subtask.ID) + require.NoError(t, err) + require.Greater(t, endTime, ts) // test UpdateFailedTaskExecutorIDs and IsTaskExecutorCanceled canceled, err := sm.IsTaskExecutorCanceled(ctx, "for_test999", 4) diff --git a/pkg/disttask/framework/storage/task_state.go b/pkg/disttask/framework/storage/task_state.go new file mode 100644 index 0000000000000..d0723964e48d3 --- /dev/null +++ b/pkg/disttask/framework/storage/task_state.go @@ -0,0 +1,152 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/sqlexec" +) + +// CancelTask cancels task. +func (stm *TaskManager) CancelTask(ctx context.Context, taskID int64) error { + _, err := stm.executeSQLWithNewSession(ctx, + `update mysql.tidb_global_task + set state = %?, + state_update_time = CURRENT_TIMESTAMP() + where id = %? and state in (%?, %?)`, + proto.TaskStateCancelling, taskID, proto.TaskStatePending, proto.TaskStateRunning, + ) + return err +} + +// CancelTaskByKeySession cancels task by key using input session. +func (*TaskManager) CancelTaskByKeySession(ctx context.Context, se sessionctx.Context, taskKey string) error { + _, err := sqlexec.ExecSQL(ctx, se, + `update mysql.tidb_global_task + set state = %?, + state_update_time = CURRENT_TIMESTAMP() + where task_key = %? and state in (%?, %?)`, + proto.TaskStateCancelling, taskKey, proto.TaskStatePending, proto.TaskStateRunning) + return err +} + +// FailTask implements the scheduler.TaskManager interface. +func (stm *TaskManager) FailTask(ctx context.Context, taskID int64, currentState proto.TaskState, taskErr error) error { + _, err := stm.executeSQLWithNewSession(ctx, + `update mysql.tidb_global_task + set state = %?, + error = %?, + state_update_time = CURRENT_TIMESTAMP(), + end_time = CURRENT_TIMESTAMP() + where id = %? and state = %?`, + proto.TaskStateFailed, serializeErr(taskErr), taskID, currentState, + ) + return err +} + +// RevertedTask implements the scheduler.TaskManager interface. +func (stm *TaskManager) RevertedTask(ctx context.Context, taskID int64) error { + _, err := stm.executeSQLWithNewSession(ctx, + `update mysql.tidb_global_task + set state = %?, + state_update_time = CURRENT_TIMESTAMP(), + end_time = CURRENT_TIMESTAMP() + where id = %? and state = %?`, + proto.TaskStateReverted, taskID, proto.TaskStateReverting, + ) + return err +} + +// PauseTask pauses the task. +func (stm *TaskManager) PauseTask(ctx context.Context, taskKey string) (bool, error) { + found := false + err := stm.WithNewSession(func(se sessionctx.Context) error { + _, err := sqlexec.ExecSQL(ctx, se, + `update mysql.tidb_global_task + set state = %?, + state_update_time = CURRENT_TIMESTAMP() + where task_key = %? and state in (%?, %?)`, + proto.TaskStatePausing, taskKey, proto.TaskStatePending, proto.TaskStateRunning, + ) + if err != nil { + return err + } + if se.GetSessionVars().StmtCtx.AffectedRows() != 0 { + found = true + } + return err + }) + if err != nil { + return found, err + } + return found, nil +} + +// PausedTask update the task state from pausing to paused. +func (stm *TaskManager) PausedTask(ctx context.Context, taskID int64) error { + _, err := stm.executeSQLWithNewSession(ctx, + `update mysql.tidb_global_task + set state = %?, + state_update_time = CURRENT_TIMESTAMP(), + end_time = CURRENT_TIMESTAMP() + where id = %? and state = %?`, + proto.TaskStatePaused, taskID, proto.TaskStatePausing, + ) + return err +} + +// ResumeTask resumes the task. +func (stm *TaskManager) ResumeTask(ctx context.Context, taskKey string) (bool, error) { + found := false + err := stm.WithNewSession(func(se sessionctx.Context) error { + _, err := sqlexec.ExecSQL(ctx, se, + `update mysql.tidb_global_task + set state = %?, + state_update_time = CURRENT_TIMESTAMP() + where task_key = %? and state = %?`, + proto.TaskStateResuming, taskKey, proto.TaskStatePaused, + ) + if err != nil { + return err + } + if se.GetSessionVars().StmtCtx.AffectedRows() != 0 { + found = true + } + return err + }) + if err != nil { + return found, err + } + return found, nil +} + +// SucceedTask update task state from running to succeed. +func (stm *TaskManager) SucceedTask(ctx context.Context, taskID int64) error { + return stm.WithNewSession(func(se sessionctx.Context) error { + _, err := sqlexec.ExecSQL(ctx, se, ` + update mysql.tidb_global_task + set state = %?, + step = %?, + state_update_time = CURRENT_TIMESTAMP(), + end_time = CURRENT_TIMESTAMP() + where id = %? and state = %?`, + proto.TaskStateSucceed, proto.StepDone, taskID, proto.TaskStateRunning, + ) + return err + }) +} diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 1972d3418245c..31c64968a9595 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -257,22 +257,6 @@ func (stm *TaskManager) CreateTaskWithSession(ctx context.Context, se sessionctx return taskID, nil } -// SucceedTask implements dispatcher.TaskManager interface. -func (stm *TaskManager) SucceedTask(ctx context.Context, taskID int64) error { - return stm.WithNewSession(func(se sessionctx.Context) error { - _, err := sqlexec.ExecSQL(ctx, se, ` - update mysql.tidb_global_task - set state = %?, - step = %?, - state_update_time = CURRENT_TIMESTAMP(), - end_time = CURRENT_TIMESTAMP() - where id = %? and state = %?`, - proto.TaskStateSucceed, proto.StepDone, taskID, proto.TaskStateRunning, - ) - return err - }) -} - // GetOneTask get a task from task table, it's used by scheduler only. func (stm *TaskManager) GetOneTask(ctx context.Context) (task *proto.Task, err error) { rs, err := stm.executeSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task where state = %? limit 1", proto.TaskStatePending) @@ -403,19 +387,6 @@ func (stm *TaskManager) GetTaskByKeyWithHistory(ctx context.Context, key string) return row2Task(rs[0]), nil } -// FailTask implements the scheduler.TaskManager interface. -func (stm *TaskManager) FailTask(ctx context.Context, taskID int64, currentState proto.TaskState, taskErr error) error { - _, err := stm.executeSQLWithNewSession(ctx, - `update mysql.tidb_global_task - set state=%?, - error = %?, - state_update_time = CURRENT_TIMESTAMP() - where id=%? and state=%?`, - proto.TaskStateFailed, serializeErr(taskErr), taskID, currentState, - ) - return err -} - // GetUsedSlotsOnNodes implements the scheduler.TaskManager interface. func (stm *TaskManager) GetUsedSlotsOnNodes(ctx context.Context) (map[string]int, error) { // concurrency of subtasks of some step is the same, we use max(concurrency) @@ -545,10 +516,12 @@ func (stm *TaskManager) GetFirstSubtaskInStates(ctx context.Context, tidbID stri } // UpdateSubtaskExecID updates the subtask's exec_id, used for testing now. -func (stm *TaskManager) UpdateSubtaskExecID(ctx context.Context, execID string, subtaskID int64) error { - _, err := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask - set exec_id = %?, state_update_time = unix_timestamp() where id = %?`, - execID, subtaskID) +func (stm *TaskManager) UpdateSubtaskExecID(ctx context.Context, tidbID string, subtaskID int64) error { + _, err := stm.executeSQLWithNewSession(ctx, + `update mysql.tidb_background_subtask + set exec_id = %?, state_update_time = unix_timestamp() + where id = %?`, + tidbID, subtaskID) return err } @@ -559,7 +532,11 @@ func (stm *TaskManager) UpdateErrorToSubtask(ctx context.Context, execID string, } _, err1 := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask - set state = %?, error = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp() + set state = %?, + error = %?, + start_time = unix_timestamp(), + state_update_time = unix_timestamp(), + end_time = CURRENT_TIMESTAMP() where exec_id = %? and task_key = %? and state in (%?, %?) @@ -608,7 +585,8 @@ func (stm *TaskManager) GetSubtaskRowCount(ctx context.Context, taskID int64, st // UpdateSubtaskRowCount updates the subtask row count. func (stm *TaskManager) UpdateSubtaskRowCount(ctx context.Context, subtaskID int64, rowCount int64) error { - _, err := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask + _, err := stm.executeSQLWithNewSession(ctx, + `update mysql.tidb_background_subtask set summary = json_set(summary, '$.row_count', %?) where id = %?`, rowCount, subtaskID) return err @@ -737,7 +715,8 @@ func (stm *TaskManager) UpdateSubtaskStateAndError( // FinishSubtask updates the subtask meta and mark state to succeed. func (stm *TaskManager) FinishSubtask(ctx context.Context, execID string, id int64, meta []byte) error { _, err := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask - set meta = %?, state = %?, state_update_time = unix_timestamp() where id = %? and exec_id = %?`, + set meta = %?, state = %?, state_update_time = unix_timestamp(), end_time = CURRENT_TIMESTAMP() + where id = %? and exec_id = %?`, meta, proto.TaskStateSucceed, id, execID) return err } @@ -1087,25 +1066,6 @@ func serializeErr(err error) []byte { return errBytes } -// CancelTask cancels task. -func (stm *TaskManager) CancelTask(ctx context.Context, taskID int64) error { - _, err := stm.executeSQLWithNewSession(ctx, - "update mysql.tidb_global_task set state=%?, state_update_time = CURRENT_TIMESTAMP() "+ - "where id=%? and state in (%?, %?)", - proto.TaskStateCancelling, taskID, proto.TaskStatePending, proto.TaskStateRunning, - ) - return err -} - -// CancelTaskByKeySession cancels task by key using input session. -func (*TaskManager) CancelTaskByKeySession(ctx context.Context, se sessionctx.Context, taskKey string) error { - _, err := sqlexec.ExecSQL(ctx, se, - "update mysql.tidb_global_task set state=%?, state_update_time = CURRENT_TIMESTAMP() "+ - "where task_key=%? and state in (%?, %?)", - proto.TaskStateCancelling, taskKey, proto.TaskStatePending, proto.TaskStateRunning) - return err -} - // IsTaskCancelling checks whether the task state is cancelling. func (stm *TaskManager) IsTaskCancelling(ctx context.Context, taskID int64) (bool, error) { rs, err := stm.executeSQLWithNewSession(ctx, "select 1 from mysql.tidb_global_task where id=%? and state = %?", @@ -1119,52 +1079,6 @@ func (stm *TaskManager) IsTaskCancelling(ctx context.Context, taskID int64) (boo return len(rs) > 0, nil } -// PauseTask pauses the task. -func (stm *TaskManager) PauseTask(ctx context.Context, taskKey string) (bool, error) { - found := false - err := stm.WithNewSession(func(se sessionctx.Context) error { - _, err := sqlexec.ExecSQL(ctx, se, - "update mysql.tidb_global_task set state=%?, state_update_time = CURRENT_TIMESTAMP() "+ - "where task_key = %? and state in (%?, %?)", - proto.TaskStatePausing, taskKey, proto.TaskStatePending, proto.TaskStateRunning, - ) - if err != nil { - return err - } - if se.GetSessionVars().StmtCtx.AffectedRows() != 0 { - found = true - } - return err - }) - if err != nil { - return found, err - } - return found, nil -} - -// ResumeTask resumes the task. -func (stm *TaskManager) ResumeTask(ctx context.Context, taskKey string) (bool, error) { - found := false - err := stm.WithNewSession(func(se sessionctx.Context) error { - _, err := sqlexec.ExecSQL(ctx, se, - "update mysql.tidb_global_task set state=%?, state_update_time = CURRENT_TIMESTAMP() "+ - "where task_key = %? and state = %?", - proto.TaskStateResuming, taskKey, proto.TaskStatePaused, - ) - if err != nil { - return err - } - if se.GetSessionVars().StmtCtx.AffectedRows() != 0 { - found = true - } - return err - }) - if err != nil { - return found, err - } - return found, nil -} - // GetSubtasksForImportInto gets the subtasks for import into(show import jobs). func (stm *TaskManager) GetSubtasksForImportInto(ctx context.Context, taskID int64, step proto.Step) ([]*proto.Subtask, error) { var ( diff --git a/pkg/disttask/framework/storage/util.go b/pkg/disttask/framework/storage/util.go index 6a1922c7624bb..e70e03d496c0b 100644 --- a/pkg/disttask/framework/storage/util.go +++ b/pkg/disttask/framework/storage/util.go @@ -17,6 +17,7 @@ package storage import ( "context" "fmt" + "time" "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/disttask/framework/proto" @@ -71,6 +72,38 @@ func GetTasksFromHistoryForTest(ctx context.Context, stm *TaskManager) (int, err return len(rs), nil } +// GetTaskEndTimeForTest gets task's endTime for test. +func GetTaskEndTimeForTest(ctx context.Context, stm *TaskManager, taskID int64) (time.Time, error) { + rs, err := stm.executeSQLWithNewSession(ctx, + `select end_time + from mysql.tidb_global_task + where id = %?`, taskID) + + if err != nil { + return time.Time{}, nil + } + if !rs[0].IsNull(0) { + return rs[0].GetTime(0).GoTime(time.Local) + } + return time.Time{}, nil +} + +// GetSubtaskEndTimeForTest gets subtask's endTime for test. +func GetSubtaskEndTimeForTest(ctx context.Context, stm *TaskManager, subtaskID int64) (time.Time, error) { + rs, err := stm.executeSQLWithNewSession(ctx, + `select end_time + from mysql.tidb_background_subtask + where id = %?`, subtaskID) + + if err != nil { + return time.Time{}, nil + } + if !rs[0].IsNull(0) { + return rs[0].GetTime(0).GoTime(time.Local) + } + return time.Time{}, nil +} + // PrintSubtaskInfo log the subtask info by taskKey. Only used for UT. func (stm *TaskManager) PrintSubtaskInfo(ctx context.Context, taskID int64) { rs, _ := stm.executeSQLWithNewSession(ctx,