diff --git a/executor/batch_point_get.go b/executor/batch_point_get.go index 0ce745d172e4a..ce917f7a650ec 100644 --- a/executor/batch_point_get.go +++ b/executor/batch_point_get.go @@ -450,15 +450,16 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { if !e.txn.Valid() { return kv.ErrInvalidTxn } - membuf := e.txn.GetMemBuffer() - for _, idxKey := range indexKeys { - handleVal := handleVals[string(idxKey)] - if len(handleVal) == 0 { - continue - } - err = membuf.Set(idxKey, handleVal) - if err != nil { - return err + txn, ok := e.txn.(interface { + ChangeLockIntoPut(context.Context, kv.Key, []byte) bool + }) + if ok { + for _, idxKey := range indexKeys { + handleVal := handleVals[string(idxKey)] + if len(handleVal) == 0 { + continue + } + txn.ChangeLockIntoPut(ctx, idxKey, handleVal) } } } diff --git a/executor/insert_common.go b/executor/insert_common.go index 207b762400919..fbc801a2b57eb 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -1128,6 +1128,10 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D } } else { e.ctx.GetSessionVars().StmtCtx.AppendWarning(r.handleKey.dupErr) + if txnCtx := e.ctx.GetSessionVars().TxnCtx; txnCtx.IsPessimistic { + // lock duplicated row key on insert-ignore + txnCtx.AddUnchangedRowKey(r.handleKey.newKey) + } continue } } else if !kv.IsErrNotFound(err) { @@ -1139,6 +1143,10 @@ func (e *InsertValues) batchCheckAndInsert(ctx context.Context, rows [][]types.D if err == nil { // If duplicate keys were found in BatchGet, mark row = nil. e.ctx.GetSessionVars().StmtCtx.AppendWarning(uk.dupErr) + if txnCtx := e.ctx.GetSessionVars().TxnCtx; txnCtx.IsPessimistic { + // lock duplicated unique key on insert-ignore + txnCtx.AddUnchangedRowKey(uk.newKey) + } skip = true break } @@ -1187,6 +1195,10 @@ func (e *InsertValues) removeRow(ctx context.Context, txn kv.Transaction, r toBe return err } if identical { + _, err := appendUnchangedRowForLock(e.ctx, r.t, handle, oldRow) + if err != nil { + return err + } return nil } diff --git a/executor/insert_test.go b/executor/insert_test.go index 8497b7ac88f43..204ec5f65c468 100644 --- a/executor/insert_test.go +++ b/executor/insert_test.go @@ -2032,3 +2032,70 @@ func TestIssue32213(t *testing.T) { tk.MustQuery("select cast(test.t1.c1 as decimal(5, 3)) from test.t1").Check(testkit.Rows("99.999")) tk.MustQuery("select cast(test.t1.c1 as decimal(6, 3)) from test.t1").Check(testkit.Rows("100.000")) } + +func TestInsertLock(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk1 := testkit.NewTestKit(t, store) + tk2 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + tk2.MustExec("use test") + + for _, tt := range []struct { + name string + ddl string + dml string + }{ + { + "replace-pk", + "create table t (c int primary key clustered)", + "replace into t values (1)", + }, + { + "replace-uk", + "create table t (c int unique key)", + "replace into t values (1)", + }, + { + "insert-ingore-pk", + "create table t (c int primary key clustered)", + "insert ignore into t values (1)", + }, + { + "insert-ingore-uk", + "create table t (c int unique key)", + "insert ignore into t values (1)", + }, + { + "insert-update-pk", + "create table t (c int primary key clustered)", + "insert into t values (1) on duplicate key update c = values(c)", + }, + { + "insert-update-uk", + "create table t (c int unique key)", + "insert into t values (1) on duplicate key update c = values(c)", + }, + } { + t.Run(tt.name, func(t *testing.T) { + tk1.MustExec("drop table if exists t") + tk1.MustExec(tt.ddl) + tk1.MustExec("insert into t values (1)") + tk1.MustExec("begin pessimistic") + tk1.MustExec(tt.dml) + done := make(chan struct{}) + go func() { + tk2.MustExec("delete from t") + done <- struct{}{} + }() + select { + case <-done: + require.Fail(t, fmt.Sprintf("txn2 is not blocked by %q", tt.dml)) + case <-time.After(100 * time.Millisecond): + } + tk1.MustExec("commit") + <-done + tk1.MustQuery("select * from t").Check([][]interface{}{}) + }) + } +} diff --git a/executor/point_get.go b/executor/point_get.go index f33ba20b5dd5a..a9cc5cb8fa3cd 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -270,10 +270,11 @@ func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.Chunk) error { if !e.txn.Valid() { return kv.ErrInvalidTxn } - memBuffer := e.txn.GetMemBuffer() - err = memBuffer.Set(e.idxKey, e.handleVal) - if err != nil { - return err + txn, ok := e.txn.(interface { + ChangeLockIntoPut(context.Context, kv.Key, []byte) bool + }) + if ok { + txn.ChangeLockIntoPut(ctx, e.idxKey, e.handleVal) } } } diff --git a/executor/replace.go b/executor/replace.go index 221cbf87b2504..d9f64ee67bf3c 100644 --- a/executor/replace.go +++ b/executor/replace.go @@ -82,6 +82,10 @@ func (e *ReplaceExec) removeRow(ctx context.Context, txn kv.Transaction, handle } if rowUnchanged { e.ctx.GetSessionVars().StmtCtx.AddAffectedRows(1) + _, err := appendUnchangedRowForLock(e.ctx, r.t, handle, oldRow) + if err != nil { + return false, err + } return true, nil } diff --git a/executor/write.go b/executor/write.go index 7c9aba9331c4d..7caf2afbb37de 100644 --- a/executor/write.go +++ b/executor/write.go @@ -126,22 +126,8 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h kv.Handle, old if sctx.GetSessionVars().ClientCapability&mysql.ClientFoundRows > 0 { sc.AddAffectedRows(1) } - - physicalID := t.Meta().ID - if pt, ok := t.(table.PartitionedTable); ok { - p, err := pt.GetPartitionByRow(sctx, oldData) - if err != nil { - return false, err - } - physicalID = p.GetPhysicalID() - } - - unchangedRowKey := tablecodec.EncodeRowKeyWithHandle(physicalID, h) - txnCtx := sctx.GetSessionVars().TxnCtx - if txnCtx.IsPessimistic { - txnCtx.AddUnchangedRowKey(unchangedRowKey) - } - return false, nil + _, err := appendUnchangedRowForLock(sctx, t, h, oldData) + return false, err } // Fill values into on-update-now fields, only if they are really changed. @@ -207,6 +193,24 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h kv.Handle, old return true, nil } +func appendUnchangedRowForLock(sctx sessionctx.Context, t table.Table, h kv.Handle, row []types.Datum) (bool, error) { + txnCtx := sctx.GetSessionVars().TxnCtx + if !txnCtx.IsPessimistic { + return false, nil + } + physicalID := t.Meta().ID + if pt, ok := t.(table.PartitionedTable); ok { + p, err := pt.GetPartitionByRow(sctx, row) + if err != nil { + return false, err + } + physicalID = p.GetPhysicalID() + } + unchangedRowKey := tablecodec.EncodeRowKeyWithHandle(physicalID, h) + txnCtx.AddUnchangedRowKey(unchangedRowKey) + return true, nil +} + func rebaseAutoRandomValue(ctx context.Context, sctx sessionctx.Context, t table.Table, newData *types.Datum, col *table.Column) error { tableInfo := t.Meta() if !tableInfo.ContainsAutoRandomBits() { diff --git a/go.mod b/go.mod index 95872b570f544..e964ac7e0aba3 100644 --- a/go.mod +++ b/go.mod @@ -63,7 +63,7 @@ require ( github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.7.2-0.20220504104629-106ec21d14df github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 - github.com/tikv/client-go/v2 v2.0.1-0.20230117081319-35a262e90d9b + github.com/tikv/client-go/v2 v2.0.1-0.20230329072435-fc18f677df02 github.com/tikv/pd/client v0.0.0-20220307081149-841fa61e9710 github.com/twmb/murmur3 v1.1.3 github.com/uber/jaeger-client-go v2.22.1+incompatible diff --git a/go.sum b/go.sum index 40babd153241d..79def501685e8 100644 --- a/go.sum +++ b/go.sum @@ -755,8 +755,8 @@ github.com/stretchr/testify v1.7.2-0.20220504104629-106ec21d14df/go.mod h1:6Fq8o github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 h1:mbAskLJ0oJfDRtkanvQPiooDH8HvJ2FBh+iKT/OmiQQ= github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2/go.mod h1:2PfKggNGDuadAa0LElHrByyrz4JPZ9fFx6Gs7nx7ZZU= -github.com/tikv/client-go/v2 v2.0.1-0.20230117081319-35a262e90d9b h1:IUH/4BrP9BJm7to+XUJslcwaZONuIEwwClBnlrO7zJM= -github.com/tikv/client-go/v2 v2.0.1-0.20230117081319-35a262e90d9b/go.mod h1:VTlli8fRRpcpISj9I2IqroQmcAFfaTyBquiRhofOcDs= +github.com/tikv/client-go/v2 v2.0.1-0.20230329072435-fc18f677df02 h1:5dBj57AfcdDSU6uV2RzGZcZDYOHIi+6aUbwREkVggxg= +github.com/tikv/client-go/v2 v2.0.1-0.20230329072435-fc18f677df02/go.mod h1:VTlli8fRRpcpISj9I2IqroQmcAFfaTyBquiRhofOcDs= github.com/tikv/pd/client v0.0.0-20220307081149-841fa61e9710 h1:jxgmKOscXSjaFEKQGRyY5qOpK8hLqxs2irb/uDJMtwk= github.com/tikv/pd/client v0.0.0-20220307081149-841fa61e9710/go.mod h1:AtvppPwkiyUgQlR1W9qSqfTB+OsOIu19jDCOxOsPkmU= github.com/tklauser/go-sysconf v0.3.9 h1:JeUVdAOWhhxVcU6Eqr/ATFHgXk/mmiItdKeJPev3vTo= diff --git a/session/txn.go b/session/txn.go index 8d359f94e60f4..9ee3da39bf3aa 100644 --- a/session/txn.go +++ b/session/txn.go @@ -408,6 +408,24 @@ func (txn *LazyTxn) LockKeysFunc(ctx context.Context, lockCtx *kv.LockCtx, fn fu return txn.Transaction.LockKeysFunc(ctx, lockCtx, lockFunc, keys...) } +// ChangeLockIntoPut tries to cache a locked key-value pair that might be converted to PUT on commit, returns true if +// the key-value pair has been cached. +func (txn *LazyTxn) ChangeLockIntoPut(ctx context.Context, key kv.Key, value []byte) bool { + if len(value) == 0 { + return false + } + cache, ok := txn.Transaction.(interface{ SetLockedKeyValue([]byte, []byte) }) + if !ok { + return false + } + _, err := txn.GetMemBuffer().Get(ctx, key) + if !kv.IsErrNotFound(err) { + return false + } + cache.SetLockedKeyValue(key, value) + return true +} + func (txn *LazyTxn) reset() { txn.cleanup() txn.changeToInvalid() diff --git a/tests/realtikvtest/pessimistictest/pessimistic_test.go b/tests/realtikvtest/pessimistictest/pessimistic_test.go index 262f827f49a97..29a42968e7a2c 100644 --- a/tests/realtikvtest/pessimistictest/pessimistic_test.go +++ b/tests/realtikvtest/pessimistictest/pessimistic_test.go @@ -543,7 +543,7 @@ func TestOptimisticConflicts(t *testing.T) { tk.MustExec("begin pessimistic") // This SQL use BatchGet and cache data in the txn snapshot. // It can be changed to other SQLs that use BatchGet. - tk.MustExec("insert ignore into conflict values (1, 2)") + tk.MustExec("select * from conflict where id in (1, 2, 3)") tk2.MustExec("update conflict set c = c - 1") @@ -2913,6 +2913,40 @@ func TestChangeLockToPut(t *testing.T) { tk.MustExec("admin check table t1") } +func TestIssue28011(t *testing.T) { + store, clean := realtikvtest.CreateMockStoreAndSetup(t) + defer clean() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + + for _, tt := range []struct { + name string + lockQuery string + finalRows [][]interface{} + }{ + {"Update", "update t set b = 'x' where a = 'a'", testkit.Rows("a x", "b y", "c z")}, + {"BatchUpdate", "update t set b = 'x' where a in ('a', 'b', 'c')", testkit.Rows("a x", "b y", "c x")}, + {"SelectForUpdate", "select a from t where a = 'a' for update", testkit.Rows("a x", "b y", "c z")}, + {"BatchSelectForUpdate", "select a from t where a in ('a', 'b', 'c') for update", testkit.Rows("a x", "b y", "c z")}, + } { + t.Run(tt.name, func(t *testing.T) { + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a varchar(10) primary key nonclustered, b varchar(10))") + tk.MustExec("insert into t values ('a', 'x'), ('b', 'x'), ('c', 'z')") + tk.MustExec("begin pessimistic") + tk.MustExec(tt.lockQuery) + tk.MustQuery("select a from t").Check(testkit.Rows("a", "b", "c")) + tk.MustExec("replace into t values ('b', 'y')") + tk.MustQuery("select a from t").Check(testkit.Rows("a", "b", "c")) + tk.MustQuery("select a, b from t order by a").Check(tt.finalRows) + tk.MustExec("commit") + tk.MustQuery("select a, b from t order by a").Check(tt.finalRows) + tk.MustExec("admin check table t") + }) + } +} + func createTable(part bool, columnNames []string, columnTypes []string) string { var str string str = "create table t("