Skip to content

Commit

Permalink
disttask: check affacted row before start subtask (#49810)
Browse files Browse the repository at this point in the history
ref #49008
  • Loading branch information
ywqzzy authored Dec 29, 2023
1 parent 8a79c0d commit fbe232e
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 98 deletions.
8 changes: 4 additions & 4 deletions pkg/disttask/framework/mock/task_executor_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/disttask/framework/scheduler/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ type TaskManager interface {
GetManagedNodes(ctx context.Context) ([]string, error)
GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error)
GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error)
GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error)
GetSubtasksByExecIdsAndStepAndState(ctx context.Context, execIDs []string, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error)
GetTaskExecutorIDsByTaskIDAndStep(ctx context.Context, taskID int64, step proto.Step) ([]string, error)

WithNewSession(fn func(se sessionctx.Context) error) error
Expand Down
10 changes: 8 additions & 2 deletions pkg/disttask/framework/storage/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,11 @@ func TestSubTaskTable(t *testing.T) {

ts := time.Now()
time.Sleep(time.Second)
require.NoError(t, sm.StartSubtask(ctx, 1))
err = sm.StartSubtask(ctx, 1, "tidb1")
require.NoError(t, err)

err = sm.StartSubtask(ctx, 1, "tidb2")
require.Error(t, storage.ErrSubtaskNotFound, err)

subtask, err = sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepInit, proto.TaskStatePending)
require.NoError(t, err)
Expand Down Expand Up @@ -565,7 +569,9 @@ func TestSubTaskTable(t *testing.T) {
testutil.CreateSubTask(t, sm, 4, proto.StepInit, "for_test1", []byte("test"), proto.TaskTypeExample, 11, false)
subtask, err = sm.GetFirstSubtaskInStates(ctx, "for_test1", 4, proto.StepInit, proto.TaskStatePending)
require.NoError(t, err)
require.NoError(t, sm.StartSubtask(ctx, subtask.ID))
err = sm.StartSubtask(ctx, subtask.ID, "for_test1")
require.NoError(t, err)

subtask, err = sm.GetFirstSubtaskInStates(ctx, "for_test1", 4, proto.StepInit, proto.TaskStateRunning)
require.NoError(t, err)
require.Greater(t, subtask.StartTime, ts)
Expand Down
111 changes: 56 additions & 55 deletions pkg/disttask/framework/storage/task_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ var (
// unstable, i.e. count, order and content of the subtasks are changed on
// different call.
ErrUnstableSubtasks = errors.New("unstable subtasks")
// ErrSubtaskNotFound is the error when can't find subtask by subtask_id and execId,
// i.e. scheduler change the subtask's execId when subtask need to balance to other nodes.
ErrSubtaskNotFound = errors.New("subtask not found")
)

// SessionExecutor defines the interface for executing SQLs in a session.
Expand Down Expand Up @@ -114,7 +117,6 @@ func SetTaskManager(is *TaskManager) {
}

// ExecSQL executes the sql and returns the result.
// TODO: consider retry.
func ExecSQL(ctx context.Context, se sessionctx.Context, sql string, args ...interface{}) ([]chunk.Row, error) {
rs, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql, args...)
if err != nil {
Expand Down Expand Up @@ -490,8 +492,8 @@ func row2SubTask(r chunk.Row) *proto.Subtask {
}

// GetSubtasksByStepAndStates gets all subtasks by given states.
func (stm *TaskManager) GetSubtasksByStepAndStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) ([]*proto.Subtask, error) {
args := []interface{}{tidbID, taskID, step}
func (stm *TaskManager) GetSubtasksByStepAndStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...interface{}) ([]*proto.Subtask, error) {
args := []interface{}{execID, taskID, step}
args = append(args, states...)
rs, err := stm.executeSQLWithNewSession(ctx, `select `+subtaskColumns+` from mysql.tidb_background_subtask
where exec_id = %? and task_key = %? and step = %?
Expand All @@ -508,14 +510,14 @@ func (stm *TaskManager) GetSubtasksByStepAndStates(ctx context.Context, tidbID s
}

// GetSubtasksByExecIdsAndStepAndState gets all subtasks by given taskID, exec_id, step and state.
func (stm *TaskManager) GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) {
func (stm *TaskManager) GetSubtasksByExecIdsAndStepAndState(ctx context.Context, execIDs []string, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) {
args := []interface{}{taskID, step, state}
for _, tidbID := range tidbIDs {
args = append(args, tidbID)
for _, execID := range execIDs {
args = append(args, execID)
}
rs, err := stm.executeSQLWithNewSession(ctx, `select `+subtaskColumns+` from mysql.tidb_background_subtask
where task_key = %? and step = %? and state = %?
and exec_id in (`+strings.Repeat("%?,", len(tidbIDs)-1)+"%?)", args...)
and exec_id in (`+strings.Repeat("%?,", len(execIDs)-1)+"%?)", args...)
if err != nil {
return nil, err
}
Expand All @@ -528,8 +530,8 @@ func (stm *TaskManager) GetSubtasksByExecIdsAndStepAndState(ctx context.Context,
}

// GetFirstSubtaskInStates gets the first subtask by given states.
func (stm *TaskManager) GetFirstSubtaskInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) (*proto.Subtask, error) {
args := []interface{}{tidbID, taskID, step}
func (stm *TaskManager) GetFirstSubtaskInStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...interface{}) (*proto.Subtask, error) {
args := []interface{}{execID, taskID, step}
args = append(args, states...)
rs, err := stm.executeSQLWithNewSession(ctx, `select `+subtaskColumns+` from mysql.tidb_background_subtask
where exec_id = %? and task_key = %? and step = %?
Expand All @@ -545,49 +547,34 @@ func (stm *TaskManager) GetFirstSubtaskInStates(ctx context.Context, tidbID stri
}

// UpdateSubtaskExecID updates the subtask's exec_id, used for testing now.
func (stm *TaskManager) UpdateSubtaskExecID(ctx context.Context, tidbID string, subtaskID int64) error {
func (stm *TaskManager) UpdateSubtaskExecID(ctx context.Context, execID string, subtaskID int64) error {
_, err := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask
set exec_id = %?, state_update_time = unix_timestamp() where id = %?`,
tidbID, subtaskID)
execID, subtaskID)
return err
}

// UpdateErrorToSubtask updates the error to subtask.
func (stm *TaskManager) UpdateErrorToSubtask(ctx context.Context, tidbID string, taskID int64, err error) error {
func (stm *TaskManager) UpdateErrorToSubtask(ctx context.Context, execID string, taskID int64, err error) error {
if err == nil {
return nil
}
_, err1 := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask
_, err1 := stm.executeSQLWithNewSession(ctx,
`update mysql.tidb_background_subtask
set state = %?, error = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp()
where exec_id = %? and task_key = %? and state in (%?, %?) limit 1;`,
proto.TaskStateFailed, serializeErr(err), tidbID, taskID, proto.TaskStatePending, proto.TaskStateRunning)
where exec_id = %? and
task_key = %? and
state in (%?, %?)
limit 1;`,
proto.TaskStateFailed,
serializeErr(err),
execID,
taskID,
proto.TaskStatePending,
proto.TaskStateRunning)
return err1
}

// PrintSubtaskInfo log the subtask info by taskKey. Only used for UT.
func (stm *TaskManager) PrintSubtaskInfo(ctx context.Context, taskID int64) {
rs, _ := stm.executeSQLWithNewSession(ctx,
`select `+subtaskColumns+` from mysql.tidb_background_subtask_history where task_key = %?`, taskID)
rs2, _ := stm.executeSQLWithNewSession(ctx,
`select `+subtaskColumns+` from mysql.tidb_background_subtask where task_key = %?`, taskID)
rs = append(rs, rs2...)

for _, r := range rs {
errBytes := r.GetBytes(13)
var err error
if len(errBytes) > 0 {
stdErr := errors.Normalize("")
err1 := stdErr.UnmarshalJSON(errBytes)
if err1 != nil {
err = err1
} else {
err = stdErr
}
}
logutil.BgLogger().Info(fmt.Sprintf("subTask: %v\n", row2SubTask(r)), zap.Error(err))
}
}

// GetSubtasksByStepAndState gets the subtask by step and state.
func (stm *TaskManager) GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) {
rs, err := stm.executeSQLWithNewSession(ctx, `select `+subtaskColumns+` from mysql.tidb_background_subtask
Expand Down Expand Up @@ -673,8 +660,8 @@ func (stm *TaskManager) CollectSubTaskError(ctx context.Context, taskID int64) (
}

// HasSubtasksInStates checks if there are subtasks in the states.
func (stm *TaskManager) HasSubtasksInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) (bool, error) {
args := []interface{}{tidbID, taskID, step}
func (stm *TaskManager) HasSubtasksInStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...interface{}) (bool, error) {
args := []interface{}{execID, taskID, step}
args = append(args, states...)
rs, err := stm.executeSQLWithNewSession(ctx, `select 1 from mysql.tidb_background_subtask
where exec_id = %? and task_key = %? and step = %?
Expand All @@ -687,35 +674,49 @@ func (stm *TaskManager) HasSubtasksInStates(ctx context.Context, tidbID string,
}

// StartSubtask updates the subtask state to running.
func (stm *TaskManager) StartSubtask(ctx context.Context, subtaskID int64) error {
_, err := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask
set state = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp()
where id = %?`,
proto.TaskStateRunning, subtaskID)
func (stm *TaskManager) StartSubtask(ctx context.Context, subtaskID int64, execID string) error {
err := stm.WithNewTxn(ctx, func(se sessionctx.Context) error {
vars := se.GetSessionVars()
_, err := ExecSQL(ctx,
se,
`update mysql.tidb_background_subtask
set state = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp()
where id = %? and exec_id = %?`,
proto.TaskStateRunning,
subtaskID,
execID)
if err != nil {
return err
}
if vars.StmtCtx.AffectedRows() == 0 {
return ErrSubtaskNotFound
}
return nil
})
return err
}

// StartManager insert the manager information into dist_framework_meta.
func (stm *TaskManager) StartManager(ctx context.Context, tidbID string, role string) error {
func (stm *TaskManager) StartManager(ctx context.Context, execID string, role string) error {
_, err := stm.executeSQLWithNewSession(ctx, `insert into mysql.dist_framework_meta(host, role, keyspace_id)
SELECT %?, %?,-1
WHERE NOT EXISTS (SELECT 1 FROM mysql.dist_framework_meta WHERE host = %?)`, tidbID, role, tidbID)
WHERE NOT EXISTS (SELECT 1 FROM mysql.dist_framework_meta WHERE host = %?)`, execID, role, execID)
return err
}

// UpdateSubtaskStateAndError updates the subtask state.
func (stm *TaskManager) UpdateSubtaskStateAndError(ctx context.Context, tidbID string, id int64, state proto.TaskState, subTaskErr error) error {
func (stm *TaskManager) UpdateSubtaskStateAndError(ctx context.Context, execID string, id int64, state proto.TaskState, subTaskErr error) error {
_, err := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask
set state = %?, error = %?, state_update_time = unix_timestamp() where id = %? and exec_id = %?`,
state, serializeErr(subTaskErr), id, tidbID)
state, serializeErr(subTaskErr), id, execID)
return err
}

// FinishSubtask updates the subtask meta and mark state to succeed.
func (stm *TaskManager) FinishSubtask(ctx context.Context, tidbID string, id int64, meta []byte) error {
func (stm *TaskManager) FinishSubtask(ctx context.Context, execID string, id int64, meta []byte) error {
_, err := stm.executeSQLWithNewSession(ctx, `update mysql.tidb_background_subtask
set meta = %?, state = %?, state_update_time = unix_timestamp() where id = %? and exec_id = %?`,
meta, proto.TaskStateSucceed, id, tidbID)
meta, proto.TaskStateSucceed, id, execID)
return err
}

Expand Down Expand Up @@ -825,9 +826,9 @@ func (stm *TaskManager) DeleteDeadNodes(ctx context.Context, nodes []string) err
}

// PauseSubtasks update all running/pending subtasks to pasued state.
func (stm *TaskManager) PauseSubtasks(ctx context.Context, tidbID string, taskID int64) error {
func (stm *TaskManager) PauseSubtasks(ctx context.Context, execID string, taskID int64) error {
_, err := stm.executeSQLWithNewSession(ctx,
`update mysql.tidb_background_subtask set state = "paused" where task_key = %? and state in ("running", "pending") and exec_id = %?`, taskID, tidbID)
`update mysql.tidb_background_subtask set state = "paused" where task_key = %? and state in ("running", "pending") and exec_id = %?`, taskID, execID)
return err
}

Expand Down Expand Up @@ -864,7 +865,7 @@ func (stm *TaskManager) SwitchTaskStep(
}
if vars.StmtCtx.AffectedRows() == 0 {
// on network partition or owner change, there might be multiple
// dispatchers for the same task, if other dispatcher has switched
// schedulers for the same task, if other scheduler has switched
// the task to next step, skip the update process.
// Or when there is no such task.
return nil
Expand Down
28 changes: 28 additions & 0 deletions pkg/disttask/framework/storage/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ package storage

import (
"context"
"fmt"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/util/logutil"
"go.uber.org/zap"
)

// GetSubtasksFromHistoryForTest gets subtasks from history table for test.
Expand Down Expand Up @@ -66,3 +70,27 @@ func GetTasksFromHistoryForTest(ctx context.Context, stm *TaskManager) (int, err
}
return len(rs), nil
}

// PrintSubtaskInfo log the subtask info by taskKey. Only used for UT.
func (stm *TaskManager) PrintSubtaskInfo(ctx context.Context, taskID int64) {
rs, _ := stm.executeSQLWithNewSession(ctx,
`select `+subtaskColumns+` from mysql.tidb_background_subtask_history where task_key = %?`, taskID)
rs2, _ := stm.executeSQLWithNewSession(ctx,
`select `+subtaskColumns+` from mysql.tidb_background_subtask where task_key = %?`, taskID)
rs = append(rs, rs2...)

for _, r := range rs {
errBytes := r.GetBytes(13)
var err error
if len(errBytes) > 0 {
stdErr := errors.Normalize("")
err1 := stdErr.UnmarshalJSON(errBytes)
if err1 != nil {
err = err1
} else {
err = stdErr
}
}
logutil.BgLogger().Info(fmt.Sprintf("subTask: %v\n", row2SubTask(r)), zap.Error(err))
}
}
22 changes: 12 additions & 10 deletions pkg/disttask/framework/taskexecutor/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,19 @@ type TaskTable interface {
GetTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error)
GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error)

GetSubtasksByStepAndStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) ([]*proto.Subtask, error)
GetSubtasksByStepAndStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...interface{}) ([]*proto.Subtask, error)
GetFirstSubtaskInStates(ctx context.Context, instanceID string, taskID int64, step proto.Step, states ...interface{}) (*proto.Subtask, error)
StartManager(ctx context.Context, tidbID string, role string) error
StartSubtask(ctx context.Context, subtaskID int64) error
UpdateSubtaskStateAndError(ctx context.Context, tidbID string, subtaskID int64, state proto.TaskState, err error) error
FinishSubtask(ctx context.Context, tidbID string, subtaskID int64, meta []byte) error

HasSubtasksInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) (bool, error)
UpdateErrorToSubtask(ctx context.Context, tidbID string, taskID int64, err error) error
IsTaskExecutorCanceled(ctx context.Context, tidbID string, taskID int64) (bool, error)
PauseSubtasks(ctx context.Context, tidbID string, taskID int64) error
StartManager(ctx context.Context, execID string, role string) error
// StartSubtask try to update the subtask's state to running if the subtask is owned by execID.
// If the update success, it means the execID's related task executor own the subtask.
StartSubtask(ctx context.Context, subtaskID int64, execID string) error
UpdateSubtaskStateAndError(ctx context.Context, execID string, subtaskID int64, state proto.TaskState, err error) error
FinishSubtask(ctx context.Context, execID string, subtaskID int64, meta []byte) error

HasSubtasksInStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...interface{}) (bool, error)
UpdateErrorToSubtask(ctx context.Context, execID string, taskID int64, err error) error
IsTaskExecutorCanceled(ctx context.Context, execID string, taskID int64) (bool, error)
PauseSubtasks(ctx context.Context, execID string, taskID int64) error
}

// Pool defines the interface of a pool.
Expand Down
Loading

0 comments on commit fbe232e

Please sign in to comment.