diff --git a/executor/adapter.go b/executor/adapter.go index 9466ba4f80144..fc4b69c929be6 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -612,7 +612,7 @@ func (a *ExecStmt) handleForeignKeyTrigger(ctx context.Context, e Executor, dept } func (a *ExecStmt) handleForeignKeyCascade(ctx context.Context, fkc *FKCascadeExec, depth int) error { - if len(fkc.fkValues) == 0 { + if len(fkc.fkValues) == 0 && len(fkc.fkUpdatedValuesMap) == 0 { return nil } if depth > maxForeignKeyCascadeDepth { diff --git a/executor/builder.go b/executor/builder.go index b08bf123b3c52..4df7a5f123a1d 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -2257,6 +2257,10 @@ func (b *executorBuilder) buildUpdate(v *plannercore.Update) Executor { if b.err != nil { return nil } + updateExec.fkCascades, b.err = b.buildTblID2FKCascadeExecs(tblID2table, v.FKCascades) + if b.err != nil { + return nil + } return updateExec } diff --git a/executor/fktest/foreign_key_test.go b/executor/fktest/foreign_key_test.go index fe977e2dd506b..ac20a0c43c9df 100644 --- a/executor/fktest/foreign_key_test.go +++ b/executor/fktest/foreign_key_test.go @@ -23,6 +23,7 @@ import ( "time" "github.com/pingcap/tidb/executor" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser/model" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/testkit" @@ -1264,11 +1265,26 @@ func TestForeignKeyGenerateCascadeSQL(t *testing.T) { sql, err = executor.GenCascadeSetNullSQL(model.NewCIStr("test"), model.NewCIStr("t"), model.NewCIStr(""), fk, fkValues) require.NoError(t, err) - require.Equal(t, "UPDATE `test`.`t` SET `c0`=NULL, `c1`=NULL WHERE (`c0`, `c1`) IN ((1,'a'), (2,'b'))", sql) + require.Equal(t, "UPDATE `test`.`t` SET `c0` = NULL, `c1` = NULL WHERE (`c0`, `c1`) IN ((1,'a'), (2,'b'))", sql) sql, err = executor.GenCascadeSetNullSQL(model.NewCIStr("test"), model.NewCIStr("t"), model.NewCIStr("idx"), fk, fkValues) require.NoError(t, err) - require.Equal(t, "UPDATE `test`.`t` USE INDEX(`idx`) SET `c0`=NULL, `c1`=NULL WHERE (`c0`, `c1`) IN ((1,'a'), (2,'b'))", sql) + require.Equal(t, "UPDATE `test`.`t` USE INDEX(`idx`) SET `c0` = NULL, `c1` = NULL WHERE (`c0`, `c1`) IN ((1,'a'), (2,'b'))", sql) + + newValue1 := []types.Datum{types.NewDatum(10), types.NewDatum("aa")} + couple := &executor.UpdatedValuesCouple{ + NewValues: newValue1, + OldValuesList: fkValues, + } + sql, err = executor.GenCascadeUpdateSQL(model.NewCIStr("test"), model.NewCIStr("t"), model.NewCIStr(""), fk, couple) + require.NoError(t, err) + require.Equal(t, "UPDATE `test`.`t` SET `c0` = 10, `c1` = 'aa' WHERE (`c0`, `c1`) IN ((1,'a'), (2,'b'))", sql) + + newValue2 := []types.Datum{types.NewDatum(nil), types.NewDatum(nil)} + couple.NewValues = newValue2 + sql, err = executor.GenCascadeUpdateSQL(model.NewCIStr("test"), model.NewCIStr("t"), model.NewCIStr("idx"), fk, couple) + require.NoError(t, err) + require.Equal(t, "UPDATE `test`.`t` USE INDEX(`idx`) SET `c0` = NULL, `c1` = NULL WHERE (`c0`, `c1`) IN ((1,'a'), (2,'b'))", sql) } func TestForeignKeyOnDeleteSetNull(t *testing.T) { @@ -1585,6 +1601,330 @@ func TestForeignKeyOnDeleteSetNull2(t *testing.T) { tk.MustQuery("select count(*) from t2 where id is null").Check(testkit.Rows("32768")) } +func TestForeignKeyOnUpdateCascade(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@global.tidb_enable_foreign_key=1") + tk.MustExec("set @@foreign_key_checks=1") + tk.MustExec("use test") + + cases := []struct { + prepareSQLs []string + }{ + // Case-1: test unique index only contain foreign key columns. + { + prepareSQLs: []string{ + "create table t1 (id int, a int, b int, unique index(a, b));", + "create table t2 (b int, name varchar(10), a int, id int, unique index (a,b), foreign key fk(a, b) references t1(a, b) ON UPDATE CASCADE);", + }, + }, + // Case-2: test unique index contain foreign key columns and other columns. + { + prepareSQLs: []string{ + "create table t1 (id int key, a int, b int, unique index(a, b, id));", + "create table t2 (b int, name varchar(10), a int, id int key, unique index (a,b, id), foreign key fk(a, b) references t1(a, b) ON UPDATE CASCADE);", + }, + }, + // Case-3: test non-unique index only contain foreign key columns. + { + prepareSQLs: []string{ + "create table t1 (id int key,a int, b int, index(a, b));", + "create table t2 (b int, a int, name varchar(10), id int key, index (a, b), foreign key fk(a, b) references t1(a, b) ON UPDATE CASCADE);", + }, + }, + // Case-4: test non-unique index contain foreign key columns and other columns. + { + prepareSQLs: []string{ + "create table t1 (id int key,a int, b int, index(a, b, id));", + "create table t2 (name varchar(10), b int, id int key, a int, index (a, b, id), foreign key fk(a, b) references t1(a, b) ON UPDATE CASCADE);", + }, + }, + } + + for idx, ca := range cases { + tk.MustExec("drop table if exists t2;") + tk.MustExec("drop table if exists t1;") + for _, sql := range ca.prepareSQLs { + tk.MustExec(sql) + } + tk.MustExec("insert into t1 (id, a, b) values (1, 11, 21),(2, 12, 22), (3, 13, 23), (4, 14, 24), (5, 15, null), (6, null, 26), (7, null, null);") + tk.MustExec("insert into t2 (id, a, b, name) values (1, 11, 21, 'a'),(2, 12, 22, 'b'), (3, 13, 23, 'c'), (4, 14, 24, 'd'), (5, 15, null, 'e'), (6, null, 26, 'f'), (7, null, null, 'g');") + tk.MustExec("update t1 set a=a+100, b = b+200 where id in (1, 2)") + tk.MustQuery("select id, a, b from t1 where id in (1,2) order by id").Check(testkit.Rows("1 111 221", "2 112 222")) + tk.MustQuery("select id, a, b, name from t2 where id in (1,2,3) order by id").Check(testkit.Rows("1 111 221 a", "2 112 222 b", "3 13 23 c")) + // Test update fk column to null + tk.MustExec("update t1 set a=101, b=null where id = 1 or b = 222") + tk.MustQuery("select id, a, b from t1 where id in (1,2) order by id").Check(testkit.Rows("1 101 ", "2 101 ")) + tk.MustQuery("select id, a, b, name from t2 where id in (1,2,3) order by id").Check(testkit.Rows("1 101 a", "2 101 b", "3 13 23 c")) + tk.MustExec("update t1 set a=null where b is null") + tk.MustQuery("select id, a, b from t1 where b is null order by id").Check(testkit.Rows("1 ", "2 ", "5 ", "7 ")) + tk.MustQuery("select id, a, b, name from t2 where b is null order by id").Check(testkit.Rows("1 101 a", "2 101 b", "5 15 e", "7 g")) + // Test update fk column from null to not-null value + tk.MustExec("update t1 set a=0, b = 0 where id = 7") + tk.MustQuery("select id, a, b from t1 where a=0 and b=0 order by id").Check(testkit.Rows("7 0 0")) + tk.MustQuery("select id, a, b from t2 where a=0 and b=0 order by id").Check(testkit.Rows()) + + // Test in transaction. + tk.MustExec("delete from t2") + tk.MustExec("delete from t1") + tk.MustExec("begin") + tk.MustExec("insert into t1 values (1, 1, 1),(2, 2, 2), (3, 3, 3), (4, 4, 4), (5, 5, null), (6, null, 6), (7, null, null);") + tk.MustExec("insert into t2 (id, a, b, name) values (1, 1, 1, 'a'),(2, 2, 2, 'b'), (3, 3, 3, 'c'), (4, 4, 4, 'd'), (5, 5, null, 'e'), (6, null, 6, 'f'), (7, null, null, 'g');") + tk.MustExec("update t1 set a=a+100, b = b+200 where id in (1, 2)") + tk.MustQuery("select id, a, b, name from t2 order by id").Check(testkit.Rows("1 101 201 a", "2 102 202 b", "3 3 3 c", "4 4 4 d", "5 5 e", "6 6 f", "7 g")) + tk.MustExec("rollback") + tk.MustQuery("select * from t1").Check(testkit.Rows()) + tk.MustQuery("select * from t2").Check(testkit.Rows()) + + tk.MustExec("insert into t1 values (1, 1, 1),(2, 2, 2);") + tk.MustExec("begin") + tk.MustExec("insert into t2 (id, a, b, name) values (1, 1, 1, 'a'),(2, 2, 2, 'b')") + tk.MustExec("update t1 set a=101 where a = 1") + tk.MustQuery("select id, a, b from t1 order by id").Check(testkit.Rows("1 101 1", "2 2 2")) + tk.MustQuery("select id, a, b, name from t2 order by id").Check(testkit.Rows("1 101 1 a", "2 2 2 b")) + err := tk.ExecToErr("insert into t2 (id, a, b, name) values (3, 1, 1, 'c')") + require.Error(t, err) + require.True(t, plannercore.ErrNoReferencedRow2.Equal(err), err.Error()) + tk.MustExec("insert into t1 values (3, 1, 1);") + tk.MustExec("insert into t2 (id, a, b, name) values (3, 1, 1, 'c')") + tk.MustQuery("select id, a, b from t1 order by id").Check(testkit.Rows("1 101 1", "2 2 2", "3 1 1")) + tk.MustQuery("select id, a, b, name from t2 order by id, a").Check(testkit.Rows("1 101 1 a", "2 2 2 b", "3 1 1 c")) + tk.MustExec("update t1 set a=null, b=2000 where id in (1, 2)") + tk.MustExec("commit") + tk.MustQuery("select id, a, b from t1 order by id").Check(testkit.Rows("1 2000", "2 2000", "3 1 1")) + tk.MustQuery("select id, a, b, name from t2 order by id").Check(testkit.Rows("1 2000 a", "2 2000 b", "3 1 1 c")) + + // only test in non-unique index + if idx >= 2 { + tk.MustExec("delete from t2") + tk.MustExec("delete from t1") + tk.MustExec("insert into t1 values (1, 1, 1),(2, 1, 1);") + tk.MustExec("begin") + tk.MustExec("update t1 set a=101 where id = 1") + tk.MustExec("insert into t2 (id, a, b, name) values (1, 1, 1, 'a')") + tk.MustExec("update t1 set b=102 where id = 2") + tk.MustQuery("select * from t1").Check(testkit.Rows("1 101 1", "2 1 102")) + tk.MustQuery("select id, a, b, name from t2").Check(testkit.Rows("1 1 102 a")) + err := tk.ExecToErr("insert into t2 (id, a, b, name) values (3, 1, 1, 'e')") + require.Error(t, err) + require.True(t, plannercore.ErrNoReferencedRow2.Equal(err), err.Error()) + tk.MustExec("insert into t1 values (3, 1, 1);") + tk.MustExec("insert into t2 (id, a, b, name) values (3, 1, 1, 'e')") + tk.MustExec("commit") + tk.MustQuery("select id, a, b from t1 order by id").Check(testkit.Rows("1 101 1", "2 1 102", "3 1 1")) + tk.MustQuery("select id, a, b, name from t2 order by id").Check(testkit.Rows("1 1 102 a", "3 1 1 e")) + + tk.MustExec("delete from t2") + tk.MustExec("delete from t1") + tk.MustExec("begin") + tk.MustExec("insert into t1 values (1, 1, 1),(2, 1, 1);") + tk.MustExec("insert into t2 (id, a, b, name) values (1, 1, 1, 'a'), (2, 1, 1, 'b')") + tk.MustExec("update t1 set a=101, b=102 where id = 1") + tk.MustExec("commit") + tk.MustQuery("select id, a, b from t1 order by id").Check(testkit.Rows("1 101 102", "2 1 1")) + tk.MustQuery("select id, a, b, name from t2 order by id").Check(testkit.Rows("1 101 102 a", "2 101 102 b")) + } + } + + cases = []struct { + prepareSQLs []string + }{ + // Case-5: test primary key only contain foreign key columns, and disable tidb_enable_clustered_index. + { + prepareSQLs: []string{ + "set @@tidb_enable_clustered_index=0;", + "create table t1 (id int, a int, b int, primary key (a, b));", + "create table t2 (b int, a int, name varchar(10), id int, primary key (a, b), foreign key fk(a, b) references t1(a, b) ON UPDATE CASCADE);", + }, + }, + // Case-6: test primary key only contain foreign key columns, and enable tidb_enable_clustered_index. + { + prepareSQLs: []string{ + "set @@tidb_enable_clustered_index=1;", + "create table t1 (id int, a int, b int, primary key (a, b));", + "create table t2 (name varchar(10), b int, a int, id int, primary key (a, b), foreign key fk(a, b) references t1(a, b) ON UPDATE CASCADE);", + }, + }, + // Case-7: test primary key contain foreign key columns and other column, and disable tidb_enable_clustered_index. + { + prepareSQLs: []string{ + "set @@tidb_enable_clustered_index=0;", + "create table t1 (id int, a int, b int, primary key (a, b, id));", + "create table t2 (b int, name varchar(10), a int, id int, primary key (a, b, id), foreign key fk(a, b) references t1(a, b) ON UPDATE CASCADE);", + }, + }, + // Case-8: test primary key contain foreign key columns and other column, and enable tidb_enable_clustered_index. + { + prepareSQLs: []string{ + "set @@tidb_enable_clustered_index=1;", + "create table t1 (id int, a int, b int, primary key (a, b, id));", + "create table t2 (b int, a int, id int, name varchar(10), primary key (a, b, id), foreign key fk(a, b) references t1(a, b) ON UPDATE CASCADE);", + }, + }, + } + for idx, ca := range cases { + tk.MustExec("drop table if exists t2;") + tk.MustExec("drop table if exists t1;") + for _, sql := range ca.prepareSQLs { + tk.MustExec(sql) + } + tk.MustExec("insert into t1 (id, a, b) values (1, 11, 21),(2, 12, 22), (3, 13, 23), (4, 14, 24)") + tk.MustExec("insert into t2 (id, a, b, name) values (1, 11, 21, 'a'),(2, 12, 22, 'b'), (3, 13, 23, 'c'), (4, 14, 24, 'd')") + tk.MustExec("update t1 set a=a+100, b = b+200 where id in (1, 2)") + tk.MustQuery("select id, a, b from t1 where id in (1,2) order by id").Check(testkit.Rows("1 111 221", "2 112 222")) + tk.MustQuery("select id, a, b, name from t2 where id in (1,2,3) order by id").Check(testkit.Rows("1 111 221 a", "2 112 222 b", "3 13 23 c")) + tk.MustExec("update t1 set a=101 where id = 1 or b = 222") + tk.MustQuery("select id, a, b from t1 where id in (1,2) order by id").Check(testkit.Rows("1 101 221", "2 101 222")) + tk.MustQuery("select id, a, b, name from t2 where id in (1,2,3) order by id").Check(testkit.Rows("1 101 221 a", "2 101 222 b", "3 13 23 c")) + + if idx < 2 { + tk.MustGetDBError("update t1 set b=200 where id in (1,2);", kv.ErrKeyExists) + } + + // test in transaction. + tk.MustExec("delete from t2") + tk.MustExec("delete from t1") + tk.MustExec("begin") + tk.MustExec("insert into t1 values (1, 1, 1),(2, 2, 2), (3, 3, 3), (4, 4, 4);") + tk.MustExec("insert into t2 (id, a, b, name) values (1, 1, 1, 'a'),(2, 2, 2, 'b'), (3, 3, 3, 'c'), (4, 4, 4, 'd');") + tk.MustExec("update t1 set a=a+100, b=b+200 where id = 1 or a = 2") + tk.MustExec("update t1 set a=a+1000, b=b+2000 where a in (2,3,4) or b in (5,6,7) or id=2") + tk.MustQuery("select id, a, b from t2 order by id").Check(testkit.Rows("1 101 201", "2 1102 2202", "3 1003 2003", "4 1004 2004")) + tk.MustQuery("select id, a, b, name from t2 order by id").Check(testkit.Rows("1 101 201 a", "2 1102 2202 b", "3 1003 2003 c", "4 1004 2004 d")) + tk.MustExec("commit") + tk.MustQuery("select id, a, b from t2 order by id").Check(testkit.Rows("1 101 201", "2 1102 2202", "3 1003 2003", "4 1004 2004")) + tk.MustQuery("select id, a, b, name from t2 order by id").Check(testkit.Rows("1 101 201 a", "2 1102 2202 b", "3 1003 2003 c", "4 1004 2004 d")) + + tk.MustExec("delete from t2") + tk.MustExec("delete from t1") + tk.MustExec("insert into t1 values (1, 1, 1),(2, 2, 2);") + tk.MustExec("begin") + tk.MustExec("insert into t2 (id, a, b, name) values (1, 1, 1, 'a'),(2, 2, 2, 'b')") + tk.MustExec("update t1 set a=a+100, b=b+200 where id = 1") + tk.MustQuery("select id, a, b from t1 order by id").Check(testkit.Rows("1 101 201", "2 2 2")) + tk.MustQuery("select id, a, b, name from t2 order by id").Check(testkit.Rows("1 101 201 a", "2 2 2 b")) + err := tk.ExecToErr("insert into t2 (id, a, b, name) values (3, 1, 1, 'e')") + require.Error(t, err) + require.True(t, plannercore.ErrNoReferencedRow2.Equal(err), err.Error()) + tk.MustExec("insert into t1 values (3, 1, 1);") + tk.MustExec("insert into t2 (id, a, b, name) values (3, 1, 1, 'c')") + tk.MustQuery("select id, a, b from t1 order by id").Check(testkit.Rows("1 101 201", "2 2 2", "3 1 1")) + tk.MustQuery("select id, a, b, name from t2 order by id").Check(testkit.Rows("1 101 201 a", "2 2 2 b", "3 1 1 c")) + tk.MustExec("update t1 set a=a+1000, b=b+2000 where a>1") + tk.MustExec("commit") + tk.MustQuery("select id, a, b from t1 order by id").Check(testkit.Rows("1 1101 2201", "2 1002 2002", "3 1 1")) + tk.MustQuery("select id, a, b, name from t2 order by id").Check(testkit.Rows("1 1101 2201 a", "2 1002 2002 b", "3 1 1 c")) + } + + // Case-9: test primary key is handle and contain foreign key column. + tk.MustExec("drop table if exists t2;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("set @@tidb_enable_clustered_index=0;") + tk.MustExec("create table t1 (id int, a int, b int, primary key (id));") + tk.MustExec("create table t2 (b int, a int, id int, name varchar(10), primary key (a), foreign key fk(a) references t1(id) ON UPDATE CASCADE);") + tk.MustExec("insert into t1 (id, a, b) values (1, 11, 21),(2, 12, 22), (3, 13, 23), (4, 14, 24)") + tk.MustExec("insert into t2 (id, a, b, name) values (11, 1, 21, 'a'),(12, 2, 22, 'b'), (13, 3, 23, 'c'), (14, 4, 24, 'd')") + tk.MustExec("update t1 set id = id + 100 where id in (1, 2, 3)") + tk.MustQuery("select id, a, b from t1 order by id").Check(testkit.Rows("4 14 24", "101 11 21", "102 12 22", "103 13 23")) + tk.MustQuery("select id, a, b, name from t2 order by id").Check(testkit.Rows("11 101 21 a", "12 102 22 b", "13 103 23 c", "14 4 24 d")) +} + +func TestForeignKeyOnUpdateCascade2(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@global.tidb_enable_foreign_key=1") + tk.MustExec("set @@foreign_key_checks=1") + tk.MustExec("use test") + + // Test update same old row in parent, but only the first old row do cascade update + tk.MustExec("create table t1 (id int key, a int, index (a));") + tk.MustExec("create table t2 (id int key, pid int, constraint fk_pid foreign key (pid) references t1(a) ON UPDATE CASCADE);") + tk.MustExec("insert into t1 (id, a) values (1,1), (2, 1)") + tk.MustExec("insert into t2 (id, pid) values (1,1), (2, 1)") + tk.MustExec("update t1 set a=id+1") + tk.MustQuery("select id, a from t1 order by id").Check(testkit.Rows("1 2", "2 3")) + tk.MustQuery("select id, pid from t2 order by id").Check(testkit.Rows("1 2", "2 2")) + + // Test cascade delete in self table. + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (id int key, name varchar(10), leader int, index(leader), foreign key (leader) references t1(id) ON UPDATE CASCADE);") + tk.MustExec("insert into t1 values (1, 'boss', null), (10, 'l1_a', 1), (11, 'l1_b', 1), (12, 'l1_c', 1)") + tk.MustExec("insert into t1 values (100, 'l2_a1', 10)") + tk.MustExec("insert into t1 values (110, 'l2_b1', 11)") + tk.MustExec("insert into t1 values (1000,'l3_a1', 100)") + tk.MustExec("update t1 set id=id+10000 where id=11") + tk.MustQuery("select id, name, leader from t1 order by id").Check(testkit.Rows("1 boss ", "10 l1_a 1", "12 l1_c 1", "100 l2_a1 10", "110 l2_b1 10011", "1000 l3_a1 100", "10011 l1_b 1")) + tk.MustExec("update t1 set id=0 where id=1") + tk.MustQuery("select id, name, leader from t1 order by id").Check(testkit.Rows("0 boss ", "10 l1_a 0", "12 l1_c 0", "100 l2_a1 10", "110 l2_b1 10011", "1000 l3_a1 100", "10011 l1_b 0")) + + // Test explain analyze with foreign key cascade. + tk.MustExec("explain analyze update t1 set id=1 where id=10") + tk.MustQuery("select id, name, leader from t1 order by id").Check(testkit.Rows("0 boss ", "1 l1_a 0", "12 l1_c 0", "100 l2_a1 1", "110 l2_b1 10011", "1000 l3_a1 100", "10011 l1_b 0")) + + // Test cascade delete in self table with string type foreign key. + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1 (id varchar(100) key, name varchar(10), leader varchar(100), index(leader), foreign key (leader) references t1(id) ON UPDATE CASCADE);") + tk.MustExec("insert into t1 values (1, 'boss', null), (10, 'l1_a', 1), (11, 'l1_b', 1), (12, 'l1_c', 1)") + tk.MustExec("insert into t1 values (100, 'l2_a1', 10)") + tk.MustExec("insert into t1 values (110, 'l2_b1', 11)") + tk.MustExec("insert into t1 values (1000,'l3_a1', 100)") + tk.MustExec("update t1 set id=id+10000 where id=11") + tk.MustQuery("select id, name, leader from t1 order by name").Check(testkit.Rows("1 boss ", "10 l1_a 1", "10011 l1_b 1", "12 l1_c 1", "100 l2_a1 10", "110 l2_b1 10011", "1000 l3_a1 100")) + tk.MustExec("update t1 set id=0 where id=1") + tk.MustQuery("select id, name, leader from t1 order by name").Check(testkit.Rows("0 boss ", "10 l1_a 0", "10011 l1_b 0", "12 l1_c 0", "100 l2_a1 10", "110 l2_b1 10011", "1000 l3_a1 100")) + + // Test cascade delete depth error. + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t0 (id int, unique index(id))") + tk.MustExec("insert into t0 values (1)") + for i := 1; i < 17; i++ { + tk.MustExec(fmt.Sprintf("create table t%v (id int, unique index(id), foreign key (id) references t%v(id) on update cascade)", i, i-1)) + tk.MustExec(fmt.Sprintf("insert into t%v values (1)", i)) + } + tk.MustGetDBError("update t0 set id=10 where id=1;", executor.ErrForeignKeyCascadeDepthExceeded) + tk.MustQuery("select id from t0").Check(testkit.Rows("1")) + tk.MustQuery("select id from t15").Check(testkit.Rows("1")) + tk.MustExec("drop table if exists t16") + tk.MustExec("update t0 set id=10 where id=1;") + tk.MustQuery("select id from t0").Check(testkit.Rows("10")) + tk.MustQuery("select id from t15").Check(testkit.Rows("10")) + for i := 16; i > -1; i-- { + tk.MustExec("drop table if exists t" + strconv.Itoa(i)) + } + + // Test handle many foreign key value in one cascade. + tk.MustExec("create table t1 (id int auto_increment key, b int, index(b));") + tk.MustExec("create table t2 (id int, b int, foreign key fk(b) references t1(b) on update cascade)") + tk.MustExec("insert into t1 (b) values (1),(2),(3),(4),(5),(6),(7),(8);") + for i := 0; i < 12; i++ { + tk.MustExec("insert into t1 (b) select id from t1") + } + tk.MustQuery("select count(*) from t1").Check(testkit.Rows("32768")) + tk.MustExec("insert into t2 select * from t1") + tk.MustExec("update t1 set b=2") + tk.MustQuery("select count(*) from t1 join t2 where t1.id=t2.id and t1.b=t2.b").Check(testkit.Rows("32768")) +} + +func TestForeignKeyOnUpdateSetNull(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@global.tidb_enable_foreign_key=1") + tk.MustExec("set @@foreign_key_checks=1") + tk.MustExec("use test") + + // Test handle many foreign key value in one cascade. + tk.MustExec("create table t1 (id int auto_increment key, b int, index(b));") + tk.MustExec("create table t2 (id int, b int, foreign key fk(b) references t1(b) on update set null)") + tk.MustExec("insert into t1 (b) values (1),(2),(3),(4),(5),(6),(7),(8);") + for i := 0; i < 12; i++ { + tk.MustExec("insert into t1 (b) select id from t1") + } + tk.MustQuery("select count(*) from t1").Check(testkit.Rows("32768")) + tk.MustExec("insert into t2 select * from t1") + tk.MustExec("update t1 set b=b+100000000") + tk.MustQuery("select count(*) from t2 where b is null").Check(testkit.Rows("32768")) +} + func TestShowCreateTableWithForeignKey(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) @@ -1611,3 +1951,19 @@ func TestShowCreateTableWithForeignKey(t *testing.T) { " CONSTRAINT `fk2` FOREIGN KEY (`leader2`) REFERENCES `test`.`t1` (`id`)\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) } + +func TestForeignKeyCascadeOnDiffColumnType(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@global.tidb_enable_foreign_key=1") + tk.MustExec("set @@foreign_key_checks=1") + tk.MustExec("use test") + tk.MustExec("create table t1 (id bit(10), index(id));") + tk.MustExec("create table t2 (id int key, b bit(10), constraint fk foreign key (b) references t1(id) ON DELETE CASCADE ON UPDATE CASCADE);") + tk.MustExec("insert into t1 values (b'01'), (b'10');") + tk.MustExec("insert into t2 values (1, b'01'), (2, b'10');") + tk.MustExec("delete from t1 where id = b'01';") + tk.MustExec("update t1 set id = b'110' where id = b'10';") + tk.MustQuery("select cast(id as unsigned) from t1;").Check(testkit.Rows("6")) + tk.MustQuery("select id, cast(b as unsigned) from t2;").Check(testkit.Rows("2 6")) +} diff --git a/executor/foreign_key.go b/executor/foreign_key.go index 03b724ada59cf..ea45e9bd394b5 100644 --- a/executor/foreign_key.go +++ b/executor/foreign_key.go @@ -65,7 +65,17 @@ type FKCascadeExec struct { referredFK *model.ReferredFKInfo childTable *model.TableInfo fk *model.FKInfo - fkValues [][]types.Datum + // On delete statement, fkValues stores the delete foreign key values. + // On update statement and the foreign key cascade is `SET NULL`, fkValues stores the old foreign key values. + fkValues [][]types.Datum + // new-value-key => UpdatedValuesCouple + fkUpdatedValuesMap map[string]*UpdatedValuesCouple +} + +// UpdatedValuesCouple contains the updated new row the old rows, exporting for test. +type UpdatedValuesCouple struct { + NewValues []types.Datum + OldValuesList [][]types.Datum } func buildTblID2FKCheckExecs(sctx sessionctx.Context, tblID2Table map[int64]table.Table, tblID2FKChecks map[int64][]*plannercore.FKCheck) (map[int64][]*FKCheckExec, error) { @@ -561,12 +571,13 @@ func (b *executorBuilder) buildFKCascadeExec(tbl table.Table, fkCascade *planner fkValuesSet: set.NewStringSet(), } return &FKCascadeExec{ - b: b, - fkValueHelper: helper, - tp: fkCascade.Tp, - referredFK: fkCascade.ReferredFK, - childTable: fkCascade.ChildTable.Meta(), - fk: fkCascade.FK, + b: b, + fkValueHelper: helper, + tp: fkCascade.Tp, + referredFK: fkCascade.ReferredFK, + childTable: fkCascade.ChildTable.Meta(), + fk: fkCascade.FK, + fkUpdatedValuesMap: make(map[string]*UpdatedValuesCouple), }, nil } @@ -579,6 +590,34 @@ func (fkc *FKCascadeExec) onDeleteRow(sc *stmtctx.StatementContext, row []types. return nil } +func (fkc *FKCascadeExec) onUpdateRow(sc *stmtctx.StatementContext, oldRow, newRow []types.Datum) error { + oldVals, err := fkc.fetchFKValuesWithCheck(sc, oldRow) + if err != nil || len(oldVals) == 0 { + return err + } + if model.ReferOptionType(fkc.fk.OnUpdate) == model.ReferOptionSetNull { + fkc.fkValues = append(fkc.fkValues, oldVals) + return nil + } + newVals, err := fkc.fetchFKValues(newRow) + if err != nil { + return err + } + newValsKey, err := codec.EncodeKey(sc, nil, newVals...) + if err != nil { + return err + } + couple := fkc.fkUpdatedValuesMap[string(newValsKey)] + if couple == nil { + couple = &UpdatedValuesCouple{ + NewValues: newVals, + } + } + couple.OldValuesList = append(couple.OldValuesList, oldVals) + fkc.fkUpdatedValuesMap[string(newValsKey)] = couple + return nil +} + func (fkc *FKCascadeExec) buildExecutor(ctx context.Context) (Executor, error) { p, err := fkc.buildFKCascadePlan(ctx) if err != nil || p == nil { @@ -591,17 +630,9 @@ func (fkc *FKCascadeExec) buildExecutor(ctx context.Context) (Executor, error) { var maxHandleFKValueInOneCascade = 1024 func (fkc *FKCascadeExec) buildFKCascadePlan(ctx context.Context) (plannercore.Plan, error) { - if len(fkc.fkValues) == 0 { + if len(fkc.fkValues) == 0 && len(fkc.fkUpdatedValuesMap) == 0 { return nil, nil } - var fkValues [][]types.Datum - if len(fkc.fkValues) <= maxHandleFKValueInOneCascade { - fkValues = fkc.fkValues - fkc.fkValues = nil - } else { - fkValues = fkc.fkValues[:maxHandleFKValueInOneCascade] - fkc.fkValues = fkc.fkValues[maxHandleFKValueInOneCascade:] - } var indexName model.CIStr indexForFK := model.FindIndexByColumns(fkc.childTable, fkc.fk.Cols...) if indexForFK != nil { @@ -611,12 +642,24 @@ func (fkc *FKCascadeExec) buildFKCascadePlan(ctx context.Context) (plannercore.P var err error switch fkc.tp { case plannercore.FKCascadeOnDelete: + fkValues := fkc.fetchOnDeleteOrUpdateFKValues() switch model.ReferOptionType(fkc.fk.OnDelete) { case model.ReferOptionCascade: sqlStr, err = GenCascadeDeleteSQL(fkc.referredFK.ChildSchema, fkc.childTable.Name, indexName, fkc.fk, fkValues) case model.ReferOptionSetNull: sqlStr, err = GenCascadeSetNullSQL(fkc.referredFK.ChildSchema, fkc.childTable.Name, indexName, fkc.fk, fkValues) } + case plannercore.FKCascadeOnUpdate: + switch model.ReferOptionType(fkc.fk.OnUpdate) { + case model.ReferOptionCascade: + couple := fkc.fetchUpdatedValuesCouple() + if couple != nil && len(couple.NewValues) != 0 { + sqlStr, err = GenCascadeUpdateSQL(fkc.referredFK.ChildSchema, fkc.childTable.Name, indexName, fkc.fk, couple) + } + case model.ReferOptionSetNull: + fkValues := fkc.fetchOnDeleteOrUpdateFKValues() + sqlStr, err = GenCascadeSetNullSQL(fkc.referredFK.ChildSchema, fkc.childTable.Name, indexName, fkc.fk, fkValues) + } } if err != nil { return nil, err @@ -646,6 +689,34 @@ func (fkc *FKCascadeExec) buildFKCascadePlan(ctx context.Context) (plannercore.P return finalPlan, err } +func (fkc *FKCascadeExec) fetchOnDeleteOrUpdateFKValues() [][]types.Datum { + var fkValues [][]types.Datum + if len(fkc.fkValues) <= maxHandleFKValueInOneCascade { + fkValues = fkc.fkValues + fkc.fkValues = nil + } else { + fkValues = fkc.fkValues[:maxHandleFKValueInOneCascade] + fkc.fkValues = fkc.fkValues[maxHandleFKValueInOneCascade:] + } + return fkValues +} + +func (fkc *FKCascadeExec) fetchUpdatedValuesCouple() *UpdatedValuesCouple { + for k, couple := range fkc.fkUpdatedValuesMap { + if len(couple.OldValuesList) <= maxHandleFKValueInOneCascade { + delete(fkc.fkUpdatedValuesMap, k) + return couple + } + result := &UpdatedValuesCouple{ + NewValues: couple.NewValues, + OldValuesList: couple.OldValuesList[:maxHandleFKValueInOneCascade], + } + couple.OldValuesList = couple.OldValuesList[maxHandleFKValueInOneCascade:] + return result + } + return nil +} + // GenCascadeDeleteSQL uses to generate cascade delete SQL, export for test. func GenCascadeDeleteSQL(schema, table, idx model.CIStr, fk *model.FKInfo, fkValues [][]types.Datum) (string, error) { buf := bytes.NewBuffer(make([]byte, 0, 48+8*len(fkValues))) @@ -669,6 +740,19 @@ func GenCascadeDeleteSQL(schema, table, idx model.CIStr, fk *model.FKInfo, fkVal // GenCascadeSetNullSQL uses to generate foreign key `SET NULL` SQL, export for test. func GenCascadeSetNullSQL(schema, table, idx model.CIStr, fk *model.FKInfo, fkValues [][]types.Datum) (string, error) { + newValues := make([]types.Datum, len(fk.Cols)) + for i := range fk.Cols { + newValues[i] = types.NewDatum(nil) + } + couple := &UpdatedValuesCouple{ + NewValues: newValues, + OldValuesList: fkValues, + } + return GenCascadeUpdateSQL(schema, table, idx, fk, couple) +} + +// GenCascadeUpdateSQL uses to generate cascade update SQL, export for test. +func GenCascadeUpdateSQL(schema, table, idx model.CIStr, fk *model.FKInfo, couple *UpdatedValuesCouple) (string, error) { buf := bytes.NewBuffer(nil) buf.WriteString("UPDATE `") buf.WriteString(schema.L) @@ -686,10 +770,15 @@ func GenCascadeSetNullSQL(schema, table, idx model.CIStr, fk *model.FKInfo, fkVa if i > 0 { buf.WriteString(", ") } - buf.WriteString("`" + col.L + "`") - buf.WriteString("=NULL") + buf.WriteString("`" + col.L) + buf.WriteString("` = ") + val, err := genFKValueString(couple.NewValues[i]) + if err != nil { + return "", err + } + buf.WriteString(val) } - err := genCascadeSQLWhereCondition(buf, fk, fkValues) + err := genCascadeSQLWhereCondition(buf, fk, couple.OldValuesList) if err != nil { return "", err } @@ -728,6 +817,12 @@ func genCascadeSQLWhereCondition(buf *bytes.Buffer, fk *model.FKInfo, fkValues [ } func genFKValueString(v types.Datum) (string, error) { + switch v.Kind() { + case types.KindNull: + return "NULL", nil + case types.KindMysqlBit: + return v.GetBinaryLiteral().ToBitLiteralString(true), nil + } val, err := v.ToString() if err != nil { return "", err diff --git a/executor/insert.go b/executor/insert.go index 34463c7d780b4..931b8bbf6b480 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -420,7 +420,7 @@ func (e *InsertExec) doDupRowUpdate(ctx context.Context, handle kv.Handle, oldRo } newData := e.row4Update[:len(oldRow)] - _, err := updateRecord(ctx, e.ctx, handle, oldRow, newData, assignFlag, e.Table, true, e.memTracker, e.fkChecks) + _, err := updateRecord(ctx, e.ctx, handle, oldRow, newData, assignFlag, e.Table, true, e.memTracker, e.fkChecks, e.fkCascades) if err != nil { return err } diff --git a/executor/insert_common.go b/executor/insert_common.go index d708a59770fdd..905a19b2ca015 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -98,7 +98,8 @@ type InsertValues struct { isLoadData bool txnInUse sync.Mutex // fkChecks contains the foreign key checkers. - fkChecks []*FKCheckExec + fkChecks []*FKCheckExec + fkCascades []*FKCascadeExec } type defaultVal struct { diff --git a/executor/update.go b/executor/update.go index acfd883a330cf..d9cffabd08355 100644 --- a/executor/update.go +++ b/executor/update.go @@ -68,6 +68,8 @@ type UpdateExec struct { matches []bool // fkChecks contains the foreign key checkers. the map is tableID -> []*FKCheckExec fkChecks map[int64][]*FKCheckExec + // fkCascades contains the foreign key cascade. the map is tableID -> []*FKCascadeExec + fkCascades map[int64][]*FKCascadeExec } // prepare `handles`, `tableUpdatable`, `changed` to avoid re-computations. @@ -194,7 +196,8 @@ func (e *UpdateExec) exec(ctx context.Context, schema *expression.Schema, row, n // Update row fkChecks := e.fkChecks[content.TblID] - changed, err1 := updateRecord(ctx, e.ctx, handle, oldData, newTableData, flags, tbl, false, e.memTracker, fkChecks) + fkCascades := e.fkCascades[content.TblID] + changed, err1 := updateRecord(ctx, e.ctx, handle, oldData, newTableData, flags, tbl, false, e.memTracker, fkChecks, fkCascades) if err1 == nil { _, exist := e.updatedRowKeys[content.Start].Get(handle) memDelta := e.updatedRowKeys[content.Start].Set(handle, changed) @@ -546,10 +549,14 @@ func (e *UpdateExec) GetFKChecks() []*FKCheckExec { // GetFKCascades implements WithForeignKeyTrigger interface. func (e *UpdateExec) GetFKCascades() []*FKCascadeExec { - return nil + fkCascades := make([]*FKCascadeExec, 0, len(e.fkChecks)) + for _, fkc := range e.fkCascades { + fkCascades = append(fkCascades, fkc...) + } + return fkCascades } // HasFKCascades implements WithForeignKeyTrigger interface. func (e *UpdateExec) HasFKCascades() bool { - return false + return len(e.fkCascades) > 0 } diff --git a/executor/write.go b/executor/write.go index 8d4336b07b960..01359b56a2571 100644 --- a/executor/write.go +++ b/executor/write.go @@ -51,7 +51,7 @@ var ( // 1. changed (bool) : does the update really change the row values. e.g. update set i = 1 where i = 1; // 2. err (error) : error in the update. func updateRecord(ctx context.Context, sctx sessionctx.Context, h kv.Handle, oldData, newData []types.Datum, modified []bool, t table.Table, - onDup bool, memTracker *memory.Tracker, fkChecks []*FKCheckExec) (bool, error) { + onDup bool, memTracker *memory.Tracker, fkChecks []*FKCheckExec, fkCascades []*FKCascadeExec) (bool, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("executor.updateRecord", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -220,6 +220,12 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h kv.Handle, old return false, err } } + for _, fkc := range fkCascades { + err := fkc.onUpdateRow(sc, oldData, newData) + if err != nil { + return false, err + } + } if onDup { sc.AddAffectedRows(2) } else { diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index e8fb5b0a486b8..7dd628b88718d 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -429,7 +429,8 @@ type Update struct { tblID2Table map[int64]table.Table - FKChecks map[int64][]*FKCheck + FKChecks map[int64][]*FKCheck + FKCascades map[int64][]*FKCascade } // MemoryUsage return the memory usage of Update diff --git a/planner/core/foreign_key.go b/planner/core/foreign_key.go index 86a49e3d2a0b8..8c5b03384ae6e 100644 --- a/planner/core/foreign_key.go +++ b/planner/core/foreign_key.go @@ -53,6 +53,8 @@ type FKCascadeType int8 const ( // FKCascadeOnDelete indicates in delete statement. FKCascadeOnDelete FKCascadeType = 1 + // FKCascadeOnUpdate indicates in update statement. + FKCascadeOnUpdate FKCascadeType = 2 emptyFkCheckSize = int64(unsafe.Sizeof(FKCheck{})) emptyFkCascadeSize = int64(unsafe.Sizeof(FKCascade{})) @@ -80,17 +82,17 @@ func (f *FKCascade) MemoryUsage() (sum int64) { return } -func (p *Insert) buildOnInsertFKChecks(ctx sessionctx.Context, is infoschema.InfoSchema, dbName string) ([]*FKCheck, error) { +func (p *Insert) buildOnInsertFKChecks(ctx sessionctx.Context, is infoschema.InfoSchema, dbName string) error { if !ctx.GetSessionVars().ForeignKeyChecks { - return nil, nil + return nil } tblInfo := p.Table.Meta() fkChecks := make([]*FKCheck, 0, len(tblInfo.ForeignKeys)) updateCols := p.buildOnDuplicateUpdateColumns() if len(updateCols) > 0 { - referredFKChecks, err := buildOnUpdateReferredFKChecks(is, dbName, tblInfo, updateCols) + referredFKChecks, _, err := buildOnUpdateReferredFKTriggers(is, dbName, tblInfo, updateCols) if err != nil { - return nil, err + return err } if len(referredFKChecks) > 0 { fkChecks = append(fkChecks, referredFKChecks...) @@ -103,13 +105,14 @@ func (p *Insert) buildOnInsertFKChecks(ctx sessionctx.Context, is infoschema.Inf failedErr := ErrNoReferencedRow2.FastGenByArgs(fk.String(dbName, tblInfo.Name.L)) fkCheck, err := buildFKCheckOnModifyChildTable(is, fk, failedErr) if err != nil { - return nil, err + return err } if fkCheck != nil { fkChecks = append(fkChecks, fkCheck) } } - return fkChecks, nil + p.FKChecks = fkChecks + return nil } func (p *Insert) buildOnDuplicateUpdateColumns() map[string]struct{} { @@ -120,12 +123,13 @@ func (p *Insert) buildOnDuplicateUpdateColumns() map[string]struct{} { return m } -func (updt *Update) buildOnUpdateFKChecks(ctx sessionctx.Context, is infoschema.InfoSchema, tblID2table map[int64]table.Table) error { +func (updt *Update) buildOnUpdateFKTriggers(ctx sessionctx.Context, is infoschema.InfoSchema, tblID2table map[int64]table.Table) error { if !ctx.GetSessionVars().ForeignKeyChecks { return nil } tblID2UpdateColumns := updt.buildTbl2UpdateColumns() fkChecks := make(map[int64][]*FKCheck) + fkCascades := make(map[int64][]*FKCascade) for tid, tbl := range tblID2table { tblInfo := tbl.Meta() dbInfo, exist := is.SchemaByTable(tblInfo) @@ -137,13 +141,16 @@ func (updt *Update) buildOnUpdateFKChecks(ctx sessionctx.Context, is infoschema. if len(updateCols) == 0 { continue } - referredFKChecks, err := buildOnUpdateReferredFKChecks(is, dbInfo.Name.L, tblInfo, updateCols) + referredFKChecks, referredFKCascades, err := buildOnUpdateReferredFKTriggers(is, dbInfo.Name.L, tblInfo, updateCols) if err != nil { return err } if len(referredFKChecks) > 0 { fkChecks[tid] = append(fkChecks[tid], referredFKChecks...) } + if len(referredFKCascades) > 0 { + fkCascades[tid] = append(fkCascades[tid], referredFKCascades...) + } childFKChecks, err := buildOnUpdateChildFKChecks(is, dbInfo.Name.L, tblInfo, updateCols) if err != nil { return err @@ -153,6 +160,7 @@ func (updt *Update) buildOnUpdateFKChecks(ctx sessionctx.Context, is infoschema. } } updt.FKChecks = fkChecks + updt.FKCascades = fkCascades return nil } @@ -170,7 +178,7 @@ func (del *Delete) buildOnDeleteFKTriggers(ctx sessionctx.Context, is infoschema } referredFKs := is.GetTableReferredForeignKeys(dbInfo.Name.L, tblInfo.Name.L) for _, referredFK := range referredFKs { - fkCheck, fkCascade, err := buildOnDeleteFKTrigger(is, referredFK) + fkCheck, fkCascade, err := buildOnDeleteOrUpdateFKTrigger(is, referredFK, FKCascadeOnDelete) if err != nil { return err } @@ -187,22 +195,26 @@ func (del *Delete) buildOnDeleteFKTriggers(ctx sessionctx.Context, is infoschema return nil } -func buildOnUpdateReferredFKChecks(is infoschema.InfoSchema, dbName string, tblInfo *model.TableInfo, updateCols map[string]struct{}) ([]*FKCheck, error) { +func buildOnUpdateReferredFKTriggers(is infoschema.InfoSchema, dbName string, tblInfo *model.TableInfo, updateCols map[string]struct{}) ([]*FKCheck, []*FKCascade, error) { referredFKs := is.GetTableReferredForeignKeys(dbName, tblInfo.Name.L) fkChecks := make([]*FKCheck, 0, len(referredFKs)) + fkCascades := make([]*FKCascade, 0, len(referredFKs)) for _, referredFK := range referredFKs { if !isMapContainAnyCols(updateCols, referredFK.Cols...) { continue } - fkCheck, err := buildFKCheckOnModifyReferTable(is, referredFK) + fkCheck, fkCascade, err := buildOnDeleteOrUpdateFKTrigger(is, referredFK, FKCascadeOnUpdate) if err != nil { - return nil, err + return nil, nil, err } if fkCheck != nil { fkChecks = append(fkChecks, fkCheck) } + if fkCascade != nil { + fkCascades = append(fkCascades, fkCascade) + } } - return fkChecks, nil + return fkChecks, fkCascades, nil } func buildOnUpdateChildFKChecks(is infoschema.InfoSchema, dbName string, tblInfo *model.TableInfo, updateCols map[string]struct{}) ([]*FKCheck, error) { @@ -260,7 +272,7 @@ func (updt *Update) buildTbl2UpdateColumns() map[int64]map[string]struct{} { return tblID2UpdateColumns } -func buildOnDeleteFKTrigger(is infoschema.InfoSchema, referredFK *model.ReferredFKInfo) (*FKCheck, *FKCascade, error) { +func buildOnDeleteOrUpdateFKTrigger(is infoschema.InfoSchema, referredFK *model.ReferredFKInfo, tp FKCascadeType) (*FKCheck, *FKCascade, error) { childTable, err := is.TableByName(referredFK.ChildSchema, referredFK.ChildTable) if err != nil { return nil, nil, nil @@ -273,12 +285,17 @@ func buildOnDeleteFKTrigger(is infoschema.InfoSchema, referredFK *model.Referred if fk.State != model.StatePublic { fkReferOption = model.ReferOptionRestrict } else { - fkReferOption = model.ReferOptionType(fk.OnDelete) + switch tp { + case FKCascadeOnDelete: + fkReferOption = model.ReferOptionType(fk.OnDelete) + case FKCascadeOnUpdate: + fkReferOption = model.ReferOptionType(fk.OnUpdate) + } } switch fkReferOption { case model.ReferOptionCascade, model.ReferOptionSetNull: fkCascade := &FKCascade{ - Tp: FKCascadeOnDelete, + Tp: tp, ReferredFK: referredFK, ChildTable: childTable, FK: fk, diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index e29e74919f59f..7e9d6c1a01cf8 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -5416,7 +5416,7 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) ( } updt.PartitionedTable = b.partitionedTable updt.tblID2Table = tblID2table - err = updt.buildOnUpdateFKChecks(b.ctx, b.is, tblID2table) + err = updt.buildOnUpdateFKTriggers(b.ctx, b.is, tblID2table) return updt, err } diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index cee0fc324d5b6..8886d5c4aab4e 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -3621,7 +3621,7 @@ func (b *PlanBuilder) buildInsert(ctx context.Context, insert *ast.InsertStmt) ( if err != nil { return nil, err } - insertPlan.FKChecks, err = insertPlan.buildOnInsertFKChecks(b.ctx, b.is, tn.DBInfo.Name.L) + err = insertPlan.buildOnInsertFKChecks(b.ctx, b.is, tn.DBInfo.Name.L) return insertPlan, err } diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 125e716522947..de69438257f4e 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -1593,7 +1593,7 @@ func buildPointUpdatePlan(ctx sessionctx.Context, pointPlan PhysicalPlan, dbName updatePlan.PartitionedTable = append(updatePlan.PartitionedTable, pt) } } - err := updatePlan.buildOnUpdateFKChecks(ctx, is, updatePlan.tblID2Table) + err := updatePlan.buildOnUpdateFKTriggers(ctx, is, updatePlan.tblID2Table) if err != nil { return nil }