Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executer: fix the last_insert_id in insert on duplicate key update #7534

Merged
merged 4 commits into from
Aug 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 4 additions & 20 deletions executor/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,6 @@ type InsertExec struct {
finished bool
}

func (e *InsertExec) insertOneRow(row []types.Datum) (int64, error) {
e.ctx.Txn().SetOption(kv.PresumeKeyNotExists, nil)
h, err := e.Table.AddRecord(e.ctx, row, false)
e.ctx.Txn().DelOption(kv.PresumeKeyNotExists)
if err != nil {
return 0, errors.Trace(err)
}
return h, nil
}

func (e *InsertExec) exec(rows [][]types.Datum) error {
// If tidb_batch_insert is ON and not in a transaction, we could use BatchInsert mode.
sessVars := e.ctx.GetSessionVars()
Expand All @@ -67,20 +57,17 @@ func (e *InsertExec) exec(rows [][]types.Datum) error {
return errors.Trace(err)
}
} else if ignoreErr {
err := e.batchCheckAndInsert(rows, e.insertOneRow)
err := e.batchCheckAndInsert(rows, e.addRecord)
if err != nil {
return errors.Trace(err)
}
} else {
for _, row := range rows {
if _, err := e.insertOneRow(row); err != nil {
if _, err := e.addRecord(row); err != nil {
return errors.Trace(err)
}
}
}
if e.lastInsertID != 0 {
sessVars.SetLastInsertID(e.lastInsertID)
}
e.finished = true
return nil
}
Expand Down Expand Up @@ -131,7 +118,7 @@ func (e *InsertExec) batchUpdateDupRows(newRows [][]types.Datum) error {
// and key-values should be filled back to dupOldRowValues for the further row check,
// due to there may be duplicate keys inside the insert statement.
if newRows[i] != nil {
newHandle, err := e.insertOneRow(newRows[i])
newHandle, err := e.addRecord(newRows[i])
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -220,13 +207,10 @@ func (e *InsertExec) doDupRowUpdate(handle int64, oldRow []types.Datum, newRow [
}

newData := row4Update[:len(oldRow)]
_, handleChanged, newHandle, lastInsertID, err := updateRecord(e.ctx, handle, oldRow, newData, assignFlag, e.Table, true)
_, handleChanged, newHandle, err := updateRecord(e.ctx, handle, oldRow, newData, assignFlag, e.Table, true)
if err != nil {
return nil, false, 0, errors.Trace(err)
}
if lastInsertID != 0 {
e.lastInsertID = lastInsertID
}
return newData, handleChanged, newHandle, nil
}

Expand Down
21 changes: 17 additions & 4 deletions executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (e *InsertValues) insertRows(cols []*table.Column, exec func(rows [][]types

rows := make([][]types.Datum, len(e.Lists))
for i, list := range e.Lists {
e.rowCount = uint64(i)
e.rowCount++
rows[i], err = e.getRow(cols, list, i)
if err != nil {
return errors.Trace(err)
Expand Down Expand Up @@ -445,7 +445,7 @@ func (e *InsertValues) adjustAutoIncrementDatum(row []types.Datum, i int, c *tab
return errors.Trace(err)
}
// It's compatible with mysql. So it sets last insert id to the first row.
if e.rowCount == 0 {
if e.rowCount == 1 {
e.lastInsertID = uint64(recordID)
}
}
Expand Down Expand Up @@ -474,7 +474,7 @@ func (e *InsertValues) handleWarning(err error, logInfo string) {

// batchCheckAndInsert checks rows with duplicate errors.
// All duplicate rows will be ignored and appended as duplicate warnings.
func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, insertOneRow func(row []types.Datum) (int64, error)) error {
func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, addRecord func(row []types.Datum) (int64, error)) error {
// all the rows will be checked, so it is safe to set BatchCheck = true
e.ctx.GetSessionVars().StmtCtx.BatchCheck = true
err := e.batchGetInsertKeys(e.ctx, e.Table, rows)
Expand Down Expand Up @@ -502,7 +502,7 @@ func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, insertOneRow fu
// it should be add to values map for the further row check.
// There may be duplicate keys inside the insert statement.
if rows[i] != nil {
_, err = insertOneRow(rows[i])
_, err = addRecord(rows[i])
if err != nil {
return errors.Trace(err)
}
Expand All @@ -516,3 +516,16 @@ func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, insertOneRow fu
}
return nil
}

func (e *InsertValues) addRecord(row []types.Datum) (int64, error) {
e.ctx.Txn().SetOption(kv.PresumeKeyNotExists, nil)
h, err := e.Table.AddRecord(e.ctx, row, false)
e.ctx.Txn().DelOption(kv.PresumeKeyNotExists)
if err != nil {
return 0, errors.Trace(err)
}
if e.lastInsertID != 0 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems the e.lastInsertID is not set anywhere?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Set it at line 449.

e.ctx.GetSessionVars().SetLastInsertID(e.lastInsertID)
}
return h, nil
}
10 changes: 3 additions & 7 deletions executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,10 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error
break
}
}
err := e.batchCheckAndInsert(rows, e.insertData)
err := e.batchCheckAndInsert(rows, e.addRecordLD)
if err != nil {
return nil, reachLimit, errors.Trace(err)
}
if e.lastInsertID != 0 {
e.ctx.GetSessionVars().SetLastInsertID(e.lastInsertID)
}

return curData, reachLimit, nil
}

Expand All @@ -282,11 +278,11 @@ func (e *LoadDataInfo) colsToRow(cols []field) []types.Datum {
return row
}

func (e *LoadDataInfo) insertData(row []types.Datum) (int64, error) {
func (e *LoadDataInfo) addRecordLD(row []types.Datum) (int64, error) {
if row == nil {
return 0, nil
}
h, err := e.Table.AddRecord(e.ctx, row, false)
h, err := e.addRecord(row)
if err != nil {
e.handleWarning(err,
fmt.Sprintf("Load Data: insert data:%v failed:%v", e.row, errors.ErrorStack(err)))
Expand Down
18 changes: 1 addition & 17 deletions executor/replace.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ package executor

import (
"github.com/juju/errors"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/table/tables"
"github.com/pingcap/tidb/tablecodec"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -83,18 +82,6 @@ func (e *ReplaceExec) removeRow(handle int64, r toBeCheckedRow) (bool, error) {
return false, nil
}

// addRow adds a row when all the duplicate key were checked.
func (e *ReplaceExec) addRow(row []types.Datum) (int64, error) {
// Set kv.PresumeKeyNotExists is safe here, because we've already removed all duplicated rows.
e.ctx.Txn().SetOption(kv.PresumeKeyNotExists, nil)
h, err := e.Table.AddRecord(e.ctx, row, false)
e.ctx.Txn().DelOption(kv.PresumeKeyNotExists)
if err != nil {
return 0, errors.Trace(err)
}
return h, nil
}

// replaceRow removes all duplicate rows for one row, then inserts it.
func (e *ReplaceExec) replaceRow(r toBeCheckedRow) error {
if r.handleKey != nil {
Expand Down Expand Up @@ -129,7 +116,7 @@ func (e *ReplaceExec) replaceRow(r toBeCheckedRow) error {
}

// No duplicated rows now, insert the row.
newHandle, err := e.addRow(r.row)
newHandle, err := e.addRecord(r.row)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -190,9 +177,6 @@ func (e *ReplaceExec) exec(newRows [][]types.Datum) error {
return errors.Trace(err)
}
}
if e.lastInsertID != 0 {
e.ctx.GetSessionVars().SetLastInsertID(e.lastInsertID)
}
e.finished = true
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion executor/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (e *UpdateExec) exec(schema *expression.Schema) ([]types.Datum, error) {
}

// Update row
changed, _, _, _, err1 := updateRecord(e.ctx, handle, oldData, newTableData, flags, tbl, false)
changed, _, _, err1 := updateRecord(e.ctx, handle, oldData, newTableData, flags, tbl, false)
if err1 == nil {
if changed {
e.updatedRowKeys[id][handle] = struct{}{}
Expand Down
33 changes: 15 additions & 18 deletions executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,15 @@ var (
// 1. changed (bool) : does the update really change the row values. e.g. update set i = 1 where i = 1;
// 2. handleChanged (bool) : is the handle changed after the update.
// 3. newHandle (int64) : if handleChanged == true, the newHandle means the new handle after update.
// 4. lastInsertID (uint64) : the lastInsertID should be set by the newData.
// 5. err (error) : error in the update.
// 4. err (error) : error in the update.
func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datum, modified []bool, t table.Table,
onDup bool) (bool, bool, int64, uint64, error) {
onDup bool) (bool, bool, int64, error) {
var sc = ctx.GetSessionVars().StmtCtx
var changed, handleChanged = false, false
// onUpdateSpecified is for "UPDATE SET ts_field = old_value", the
// timestamp field is explicitly set, but not changed in fact.
var onUpdateSpecified = make(map[int]bool)
var newHandle int64
var lastInsertID uint64

// We can iterate on public columns not writable columns,
// because all of them are sorted by their `Offset`, which
Expand All @@ -61,7 +59,7 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
// Cast changed fields with respective columns.
v, err := table.CastValue(ctx, newData[i], col.ToInfo())
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}
newData[i] = v
}
Expand All @@ -70,27 +68,26 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
var err error
newData[i], err = table.GetColDefaultValue(ctx, col.ToInfo())
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}
}
// Rebase auto increment id if the field is changed.
if mysql.HasAutoIncrementFlag(col.Flag) {
if newData[i].IsNull() {
return false, handleChanged, newHandle, 0, table.ErrColumnCantNull.GenByArgs(col.Name)
return false, handleChanged, newHandle, table.ErrColumnCantNull.GenByArgs(col.Name)
}
val, errTI := newData[i].ToInt64(sc)
if errTI != nil {
return false, handleChanged, newHandle, 0, errors.Trace(errTI)
return false, handleChanged, newHandle, errors.Trace(errTI)
}
lastInsertID = uint64(val)
err := t.RebaseAutoID(ctx, val, true)
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}
}
cmp, err := newData[i].CompareDatum(sc, &oldData[i])
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}
if cmp != 0 {
changed = true
Expand All @@ -111,23 +108,23 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
// Check the not-null constraints.
err := table.CheckNotNull(t.Cols(), newData)
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}

if !changed {
// See https://dev.mysql.com/doc/refman/5.7/en/mysql-real-connect.html CLIENT_FOUND_ROWS
if ctx.GetSessionVars().ClientCapability&mysql.ClientFoundRows > 0 {
sc.AddAffectedRows(1)
}
return false, handleChanged, newHandle, lastInsertID, nil
return false, handleChanged, newHandle, nil
}

// Fill values into on-update-now fields, only if they are really changed.
for i, col := range t.Cols() {
if mysql.HasOnUpdateNowFlag(col.Flag) && !modified[i] && !onUpdateSpecified[i] {
v, errGT := expression.GetTimeValue(ctx, strings.ToUpper(ast.CurrentTimestamp), col.Tp, col.Decimal)
if errGT != nil {
return false, handleChanged, newHandle, 0, errors.Trace(errGT)
return false, handleChanged, newHandle, errors.Trace(errGT)
}
newData[i] = v
modified[i] = true
Expand All @@ -140,21 +137,21 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
// if the new handle exists. `UPDATE IGNORE` will avoid removing record, and do nothing.
err = tables.CheckHandleExists(ctx, t, newHandle, newData)
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}
skipHandleCheck = true
}
err = t.RemoveRecord(ctx, h, oldData)
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}
newHandle, err = t.AddRecord(ctx, newData, skipHandleCheck)
} else {
// Update record to new value and update index.
err = t.UpdateRecord(ctx, h, oldData, newData, modified)
}
if err != nil {
return false, handleChanged, newHandle, 0, errors.Trace(err)
return false, handleChanged, newHandle, errors.Trace(err)
}

if onDup {
Expand All @@ -173,7 +170,7 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu
}
}
ctx.GetSessionVars().TxnCtx.UpdateDeltaForTable(t.Meta().ID, 0, 1, colSize)
return true, handleChanged, newHandle, lastInsertID, nil
return true, handleChanged, newHandle, nil
}

// resetErrDataTooLong reset ErrDataTooLong error msg.
Expand Down
5 changes: 5 additions & 0 deletions executor/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,11 @@ commit;`
testSQL = `SELECT LAST_INSERT_ID();`
r = tk.MustQuery(testSQL)
r.Check(testkit.Rows("1"))
testSQL = `INSERT t1 (f2) VALUES ('test') ON DUPLICATE KEY UPDATE f1 = 2;`
tk.MustExec(testSQL)
testSQL = `SELECT LAST_INSERT_ID();`
r = tk.MustQuery(testSQL)
r.Check(testkit.Rows("1"))

testSQL = `DROP TABLE IF EXISTS t1;
CREATE TABLE t1 (f1 INT);
Expand Down
4 changes: 2 additions & 2 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ func runTestLoadData(c *C, server *Server) {
dbt.Assert(err, IsNil)
lastID, err = rs.LastInsertId()
dbt.Assert(err, IsNil)
dbt.Assert(lastID, Equals, int64(6))
dbt.Assert(lastID, Equals, int64(7))
affectedRows, err = rs.RowsAffected()
dbt.Assert(err, IsNil)
dbt.Assert(affectedRows, Equals, int64(4))
Expand Down Expand Up @@ -464,7 +464,7 @@ func runTestLoadData(c *C, server *Server) {
dbt.Assert(err, IsNil)
lastID, err = rs.LastInsertId()
dbt.Assert(err, IsNil)
dbt.Assert(lastID, Equals, int64(10))
dbt.Assert(lastID, Equals, int64(11))
affectedRows, err = rs.RowsAffected()
dbt.Assert(err, IsNil)
dbt.Assert(affectedRows, Equals, int64(799))
Expand Down