diff --git a/executor/insert_test.go b/executor/insert_test.go index b6dd964ccf57c..df243fe853cbb 100644 --- a/executor/insert_test.go +++ b/executor/insert_test.go @@ -15,7 +15,9 @@ package executor_test import ( "fmt" + "strconv" "strings" + "sync" . "github.com/pingcap/check" "github.com/pingcap/parser/terror" @@ -763,6 +765,43 @@ func (s *testSuite3) TestBit(c *C) { } +func (s *testSuite3) TestAllocateContinuousRowID(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1 (a int,b int, key I_a(a));`) + wg := sync.WaitGroup{} + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + tk := testkit.NewTestKitWithInit(c, s.store) + for j := 0; j < 10; j++ { + k := strconv.Itoa(idx*100 + j) + sql := "insert into t1(a,b) values (" + k + ", 2)" + for t := 0; t < 20; t++ { + sql += ",(" + k + ",2)" + } + tk.MustExec(sql) + q := "select _tidb_rowid from t1 where a=" + k + fmt.Printf("query: %v\n", q) + rows := tk.MustQuery(q).Rows() + c.Assert(len(rows), Equals, 21) + last := 0 + for _, r := range rows { + c.Assert(len(r), Equals, 1) + v, err := strconv.Atoi(r[0].(string)) + c.Assert(err, Equals, nil) + if last > 0 { + c.Assert(last+1, Equals, v) + } + last = v + } + } + }(i) + } + wg.Wait() +} + func (s *testSuite3) TestJiraIssue5366(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test`) diff --git a/infoschema/tables.go b/infoschema/tables.go index 681a03817968a..3a033f79d74cf 100644 --- a/infoschema/tables.go +++ b/infoschema/tables.go @@ -2424,6 +2424,11 @@ func (it *infoschemaTable) AllocHandle(ctx sessionctx.Context) (int64, error) { return 0, table.ErrUnsupportedOp } +// AllocHandleIDs implements table.Table AllocHandleIDs interface. +func (it *infoschemaTable) AllocHandleIDs(ctx sessionctx.Context, n uint64) (int64, int64, error) { + return 0, 0, table.ErrUnsupportedOp +} + // Allocator implements table.Table Allocator interface. func (it *infoschemaTable) Allocator(ctx sessionctx.Context) autoid.Allocator { return nil @@ -2541,6 +2546,11 @@ func (vt *VirtualTable) AllocHandle(ctx sessionctx.Context) (int64, error) { return 0, table.ErrUnsupportedOp } +// AllocHandleIDs implements table.Table AllocHandleIDs interface. +func (vt *VirtualTable) AllocHandleIDs(ctx sessionctx.Context, n uint64) (int64, int64, error) { + return 0, 0, table.ErrUnsupportedOp +} + // Allocator implements table.Table Allocator interface. func (vt *VirtualTable) Allocator(ctx sessionctx.Context) autoid.Allocator { return nil diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index e04d961b1ff3c..6a0f382189d66 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -118,6 +118,9 @@ type StatementContext struct { // InsertID is the given insert ID of an auto_increment column. InsertID uint64 + BaseRowID int64 + MaxRowID int64 + // Copied from SessionVars.TimeZone. TimeZone *time.Location Priority mysql.PriorityEnum @@ -428,6 +431,8 @@ func (sc *StatementContext) ResetForRetry() { sc.mu.execDetails = execdetails.ExecDetails{} sc.mu.allExecDetails = make([]*execdetails.ExecDetails, 0, 4) sc.mu.Unlock() + sc.MaxRowID = 0 + sc.BaseRowID = 0 sc.TableIDs = sc.TableIDs[:0] sc.IndexNames = sc.IndexNames[:0] } diff --git a/table/table.go b/table/table.go index 41c77bf9a4ac6..acde4846fd94b 100644 --- a/table/table.go +++ b/table/table.go @@ -168,6 +168,9 @@ type Table interface { // AllocHandle allocates a handle for a new row. AllocHandle(ctx sessionctx.Context) (int64, error) + // AllocHandleIds allocates multiple handle for rows. + AllocHandleIDs(ctx sessionctx.Context, n uint64) (int64, int64, error) + // Allocator returns Allocator. Allocator(ctx sessionctx.Context) autoid.Allocator diff --git a/table/tables/tables.go b/table/tables/tables.go index 4e8620332ed0f..38d2b3e36eefa 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -455,9 +455,22 @@ func (t *tableCommon) AddRecord(ctx sessionctx.Context, r []types.Datum, opts .. } } if !hasRecordID { - recordID, err = t.AllocHandle(ctx) - if err != nil { - return 0, err + stmtCtx := ctx.GetSessionVars().StmtCtx + rows := stmtCtx.RecordRows() + if rows > 1 { + if stmtCtx.BaseRowID >= stmtCtx.MaxRowID { + stmtCtx.BaseRowID, stmtCtx.MaxRowID, err = t.AllocHandleIDs(ctx, rows) + if err != nil { + return 0, err + } + } + stmtCtx.BaseRowID += 1 + recordID = stmtCtx.BaseRowID + } else { + recordID, err = t.AllocHandle(ctx) + if err != nil { + return 0, err + } } } @@ -947,13 +960,19 @@ func GetColDefaultValue(ctx sessionctx.Context, col *table.Column, defaultVals [ // AllocHandle implements table.Table AllocHandle interface. func (t *tableCommon) AllocHandle(ctx sessionctx.Context) (int64, error) { - _, rowID, err := t.Allocator(ctx).Alloc(t.tableID, 1) + _, rowID, err := t.AllocHandleIDs(ctx, 1) + return rowID, err +} + +// AllocHandle implements table.Table AllocHandle interface. +func (t *tableCommon) AllocHandleIDs(ctx sessionctx.Context, n uint64) (int64, int64, error) { + base, maxID, err := t.Allocator(ctx).Alloc(t.tableID, n) if err != nil { - return 0, err + return 0, 0, err } if t.meta.ShardRowIDBits > 0 { // Use max record ShardRowIDBits to check overflow. - if OverflowShardBits(rowID, t.meta.MaxShardRowIDBits) { + if OverflowShardBits(maxID, t.meta.MaxShardRowIDBits) { // If overflow, the rowID may be duplicated. For examples, // t.meta.ShardRowIDBits = 4 // rowID = 0010111111111111111111111111111111111111111111111111111111111111 @@ -961,16 +980,17 @@ func (t *tableCommon) AllocHandle(ctx sessionctx.Context) (int64, error) { // will be duplicated with: // rowID = 0100111111111111111111111111111111111111111111111111111111111111 // shard = 0010000000000000000000000000000000000000000000000000000000000000 - return 0, autoid.ErrAutoincReadFailed + return 0, 0, autoid.ErrAutoincReadFailed } txnCtx := ctx.GetSessionVars().TxnCtx if txnCtx.Shard == nil { shard := t.calcShard(txnCtx.StartTS) txnCtx.Shard = &shard } - rowID |= *txnCtx.Shard + base |= *txnCtx.Shard + maxID |= *txnCtx.Shard } - return rowID, nil + return base, maxID, nil } // OverflowShardBits checks whether the rowID overflow `1<<(64-shardRowIDBits-1) -1`.