diff --git a/pkg/parser/ast/dml.go b/pkg/parser/ast/dml.go index 898ba923e0fba..8ba60c26f91a7 100644 --- a/pkg/parser/ast/dml.go +++ b/pkg/parser/ast/dml.go @@ -282,6 +282,11 @@ type TableName struct { TableSample *TableSample // AS OF is used to see the data as it was at a specific point in time. AsOf *AsOfClause + // IsAlias is true if this table name is an alias. + // sometime, we need to distinguish the table name is an alias or not. + // for example ```delete tt1 from t1 tt1,(select max(id) id from t2)tt2 where tt1.id<=tt2.id``` + // ```tt1``` is a alias name. so we need to set IsAlias to true and restore the table name without database name. + IsAlias bool } func (*TableName) resultSet() {} @@ -293,7 +298,7 @@ func (n *TableName) restoreName(ctx *format.RestoreCtx) { if n.Schema.String() != "" { ctx.WriteName(n.Schema.String()) ctx.WritePlain(".") - } else if ctx.DefaultDB != "" { + } else if ctx.DefaultDB != "" && !n.IsAlias { // Try CTE, for a CTE table name, we shouldn't write the database name. if !ctx.IsCTETableName(n.Name.L) { ctx.WriteName(ctx.DefaultDB) diff --git a/pkg/planner/core/preprocess.go b/pkg/planner/core/preprocess.go index 209f540a7b25f..34408a5c5d9ab 100644 --- a/pkg/planner/core/preprocess.go +++ b/pkg/planner/core/preprocess.go @@ -571,7 +571,9 @@ func (p *preprocessor) checkBindGrammar(originNode, hintedNode ast.StmtNode, def TableInfo: tableInfo, }) } - + aliasChecker := &aliasChecker{} + originNode.Accept(aliasChecker) + hintedNode.Accept(aliasChecker) originSQL := parser.NormalizeForBinding(utilparser.RestoreWithDefaultDB(originNode, defaultDB, originNode.Text()), false) hintedSQL := parser.NormalizeForBinding(utilparser.RestoreWithDefaultDB(hintedNode, defaultDB, hintedNode.Text()), false) if originSQL != hintedSQL { @@ -1991,3 +1993,59 @@ func (p *preprocessor) skipLockMDL() bool { // skip lock mdl for ANALYZE statement. return p.flag&inImportInto > 0 || p.flag&inAnalyze > 0 } + +// aliasChecker is used to check the alias of the table in delete statement. +// +// for example: delete tt1 from t1 tt1,(select max(id) id from t2)tt2 where tt1.id<=tt2.id +// `delete tt1` will be transformed to `delete current_database.t1` by default. +// because `tt1` cannot be used as alias in delete statement. +// so we have to set `tt1` as alias by aliasChecker. +type aliasChecker struct{} + +func (*aliasChecker) Enter(in ast.Node) (ast.Node, bool) { + if deleteStmt, ok := in.(*ast.DeleteStmt); ok { + // 1. check the tableRefs of deleteStmt to find the alias + var aliases []*pmodel.CIStr + if deleteStmt.TableRefs != nil && deleteStmt.TableRefs.TableRefs != nil { + tableRefs := deleteStmt.TableRefs.TableRefs + if val := getTableRefsAlias(tableRefs.Left); val != nil { + aliases = append(aliases, val) + } + if val := getTableRefsAlias(tableRefs.Right); val != nil { + aliases = append(aliases, val) + } + } + // 2. check the Tables to tag the alias + if deleteStmt.Tables != nil && deleteStmt.Tables.Tables != nil { + for _, table := range deleteStmt.Tables.Tables { + if table.Schema.String() != "" { + continue + } + for _, alias := range aliases { + if table.Name.L == alias.L { + table.IsAlias = true + break + } + } + } + } + return in, true + } + return in, false +} + +func getTableRefsAlias(tableRefs ast.ResultSetNode) *pmodel.CIStr { + switch v := tableRefs.(type) { + case *ast.Join: + if v.Left != nil { + return getTableRefsAlias(v.Left) + } + case *ast.TableSource: + return &v.AsName + } + return nil +} + +func (*aliasChecker) Leave(in ast.Node) (ast.Node, bool) { + return in, true +} diff --git a/pkg/planner/core/preprocess_test.go b/pkg/planner/core/preprocess_test.go index f3eea5b6f54b9..9a8653e8ad409 100644 --- a/pkg/planner/core/preprocess_test.go +++ b/pkg/planner/core/preprocess_test.go @@ -414,3 +414,14 @@ func TestPreprocessCTE(t *testing.T) { require.Equal(t, tc.after, rs.String()) } } + +func TestPreprocessDeleteFromWithAlias(t *testing.T) { + // https://github.com/pingcap/tidb/issues/56726 + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("create table t1(id int);") + tk.MustExec(" create table t2(id int);") + tk.MustExec("delete tt1 from t1 tt1,(select max(id) id from t2)tt2 where tt1.id<=tt2.id;") + tk.MustExec("create global binding for delete tt1 from t1 tt1,(select max(id) id from t2)tt2 where tt1.id<=tt2.id using delete /*+ MAX_EXECUTION_TIME(10)*/ tt1 from t1 tt1,(select max(id) id from t2)tt2 where tt1.id<=tt2.id;") +}