Skip to content

Commit

Permalink
session: optimize the sharding algorithm for non-transactional DMLs (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ekexium authored Apr 18, 2022
1 parent 58e2a55 commit e10ad28
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 36 deletions.
48 changes: 29 additions & 19 deletions session/nontransactional.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/types"
driver "github.com/pingcap/tidb/types/parser_driver"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/sqlexec"
Expand Down Expand Up @@ -304,29 +305,38 @@ func buildShardJobs(ctx context.Context, stmt *ast.NonTransactionalDeleteStmt, s
break
}

newStart := chk.GetRow(0).GetDatum(0, &rs.Fields()[0].Column.FieldType)
if len(jobs) > 0 && chk.NumRows()+currentSize < batchSize {
// not enough data for a batch
currentSize += chk.NumRows()
newEnd := chk.GetRow(chk.NumRows()-1).GetDatum(0, &rs.Fields()[0].Column.FieldType)
currentEnd = *newEnd.Clone()
continue
}

// end last batch if: (1) current start != last end (2) current size >= batch size
if currentSize >= batchSize {
cmp, err := newStart.Compare(se.GetSessionVars().StmtCtx, &currentEnd, collate.GetCollator(shardColumnCollate))
if err != nil {
return nil, err
iter := chunk.NewIterator4Chunk(chk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
if currentSize == 0 {
newStart := row.GetDatum(0, &rs.Fields()[0].Column.FieldType)
currentStart = *newStart.Clone()
}
if cmp != 0 {
jobs = append(jobs, job{jobID: jobCount, start: currentStart, end: currentEnd, jobSize: currentSize})
jobCount++
currentSize = 0
newEnd := row.GetDatum(0, &rs.Fields()[0].Column.FieldType)
if currentSize >= batchSize {
cmp, err := newEnd.Compare(se.GetSessionVars().StmtCtx, &currentEnd, collate.GetCollator(shardColumnCollate))
if err != nil {
return nil, err
}
if cmp != 0 {
jobs = append(jobs, job{jobID: jobCount, start: *currentStart.Clone(), end: *currentEnd.Clone(), jobSize: currentSize})
jobCount++
currentSize = 0
currentStart = newEnd
}
}
currentEnd = newEnd
currentSize++
}

// a new batch
if currentSize == 0 {
currentStart = *newStart.Clone()
}

currentSize += chk.NumRows()
currentEndPointer := chk.GetRow(chk.NumRows()-1).GetDatum(0, &rs.Fields()[0].Column.FieldType)
currentEnd = *currentEndPointer.Clone()
currentEnd = *currentEnd.Clone()
currentStart = *currentStart.Clone()
}

return jobs, nil
Expand Down
46 changes: 29 additions & 17 deletions session/nontransactional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
"github.com/stretchr/testify/require"
)

func TestNonTransactionalDelete(t *testing.T) {
func TestNonTransactionalDeleteSharding(t *testing.T) {
store, clean := createStorage(t)
defer clean()
tk := testkit.NewTestKit(t, store)
Expand All @@ -49,30 +49,42 @@ func TestNonTransactionalDelete(t *testing.T) {
"create table t(a varchar(30), b int, unique key(a, b))",
"create table t(a varchar(30), b int, unique key(a))",
}
tableSizes := []int{0, 1, 10, 35, 40, 100}
batchSizes := []int{1, 10, 25, 35, 50, 80, 120}
for _, table := range tables {
tk.MustExec("drop table if exists t")
tk.MustExec(table)
for i := 0; i < 100; i++ {
tk.MustExec(fmt.Sprintf("insert into t values ('%d', %d)", i, i*2))
}
tk.MustExec("split on a limit 3 delete from t")
tk.MustQuery("select count(*) from t").Check(testkit.Rows("0"))

for i := 0; i < 100; i++ {
tk.MustExec(fmt.Sprintf("insert into t values ('%d', %d)", i, i*2))
}
if strings.Contains(table, "a int") {
rows := tk.MustQuery("split on a limit 3 dry run delete from t").Rows()
for _, row := range rows {
require.True(t, strings.HasPrefix(row[0].(string), "DELETE FROM `test`.`t` WHERE `a` BETWEEN"))
for _, tableSize := range tableSizes {
for _, batchSize := range batchSizes {
for i := 0; i < tableSize; i++ {
tk.MustExec(fmt.Sprintf("insert into t values ('%d', %d)", i, i*2))
}
tk.MustQuery(fmt.Sprintf("split on a limit %d delete from t", batchSize)).Check(testkit.Rows(fmt.Sprintf("%d all succeeded", (tableSize+batchSize-1)/batchSize)))
tk.MustQuery("select count(*) from t").Check(testkit.Rows("0"))
}
}
tk.MustQuery("split on a limit 3 dry run query delete from t").Check(testkit.Rows(
"SELECT `a` FROM `test`.`t` WHERE TRUE ORDER BY IF(ISNULL(`a`),0,1),`a`"))
tk.MustQuery("select count(*) from t").Check(testkit.Rows("100"))
}
}

// TestNonTransactionalDeleteDryRun exercises the two dry-run forms of a
// non-transactional delete ("dry run" and "dry run query") against a table
// holding 100 rows, and verifies that neither form removes any row.
func TestNonTransactionalDeleteDryRun(t *testing.T) {
	store, clean := createStorage(t)
	defer clean()
	tk := testkit.NewTestKit(t, store)
	// NOTE(review): presumably 35 is chosen so the 100 rows are read in
	// several chunks rather than one — confirm against buildShardJobs.
	tk.MustExec("set @@tidb_max_chunk_size=35")
	tk.MustExec("use test")
	tk.MustExec("create table t(a int, b int, primary key(a, b) clustered)")
	for i := 0; i < 100; i++ {
		tk.MustExec(fmt.Sprintf("insert into t values ('%d', %d)", i, i*2))
	}

	// "dry run" must return the split DELETE statements, each bounded by a
	// BETWEEN range on the shard column.
	rows := tk.MustQuery("split on a limit 3 dry run delete from t").Rows()
	// Guard against a vacuous pass: with an empty result the loop below would
	// not execute a single assertion.
	require.NotEmpty(t, rows)
	for _, row := range rows {
		require.True(t, strings.HasPrefix(row[0].(string), "DELETE FROM `test`.`t` WHERE `a` BETWEEN"))
	}

	// "dry run query" must return the statement used to scan the shard column.
	tk.MustQuery("split on a limit 3 dry run query delete from t").Check(testkit.Rows(
		"SELECT `a` FROM `test`.`t` WHERE TRUE ORDER BY IF(ISNULL(`a`),0,1),`a`"))

	// Neither dry-run form may delete anything: all 100 rows are still there.
	tk.MustQuery("select count(*) from t").Check(testkit.Rows("100"))
}

func TestNonTransactionalDeleteErrorMessage(t *testing.T) {
store, clean := createStorage(t)
defer clean()
Expand Down

0 comments on commit e10ad28

Please sign in to comment.