From c0111648342b64d2a7c86185430da3be4d3b7287 Mon Sep 17 00:00:00 2001 From: Zhou Kunqin <25057648+time-and-fate@users.noreply.github.com> Date: Fri, 6 Sep 2024 12:31:49 +0800 Subject: [PATCH] planner: fix incorrect maintenance of `handleColHelper` for recursive CTE (#55732) close pingcap/tidb#55666 --- pkg/planner/core/logical_plan_builder.go | 34 ++++++++++-- pkg/planner/core/planbuilder.go | 3 +- tests/integrationtest/r/cte.result | 64 ++++++++++++++++++++++ tests/integrationtest/t/cte.test | 67 ++++++++++++++++++++++++ 4 files changed, 162 insertions(+), 6 deletions(-) diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index ee3fbb427d537..22712cfe77090 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -4169,6 +4169,9 @@ func getLatestVersionFromStatsTable(ctx sessionctx.Context, tblInfo *model.Table return version } +// tryBuildCTE considers the input tn as a reference to a CTE and tries to build the logical plan for it like building +// DataSource for normal tables. +// tryBuildCTE will push an entry into handleHelper when successful. func (b *PlanBuilder) tryBuildCTE(ctx context.Context, tn *ast.TableName, asName *pmodel.CIStr) (base.LogicalPlan, error) { for i := len(b.outerCTEs) - 1; i >= 0; i-- { cte := b.outerCTEs[i] @@ -4193,6 +4196,7 @@ func (b *PlanBuilder) tryBuildCTE(ctx context.Context, tn *ast.TableName, asName p := logicalop.LogicalCTETable{Name: cte.def.Name.String(), IDForStorage: cte.storageID, SeedStat: cte.seedStat, SeedSchema: cte.seedLP.Schema()}.Init(b.ctx, b.getSelectOffset()) p.SetSchema(getResultCTESchema(cte.seedLP.Schema(), b.ctx.GetSessionVars())) p.SetOutputNames(cte.seedLP.OutputNames()) + b.handleHelper.pushMap(nil) return p, nil } @@ -6899,6 +6903,8 @@ func isJoinHintSupportedInMPPMode(preferJoinType uint) bool { return onesCount < 1 } +// buildCte prepares for a CTE. It works together with buildWith(). +// It will push one entry into b.handleHelper. func (b *PlanBuilder) buildCte(ctx context.Context, cte *ast.CommonTableExpression, isRecursive bool) (p base.LogicalPlan, err error) { saveBuildingCTE := b.buildingCTE b.buildingCTE = true @@ -6934,6 +6940,7 @@ func (b *PlanBuilder) buildCte(ctx context.Context, cte *ast.CommonTableExpressi } // buildRecursiveCTE handles the with clause `with recursive xxx as xx`. +// It will push one entry into b.handleHelper. func (b *PlanBuilder) buildRecursiveCTE(ctx context.Context, cte ast.ResultSetNode) error { b.isCTE = true cInfo := b.outerCTEs[len(b.outerCTEs)-1] @@ -6963,6 +6970,7 @@ func (b *PlanBuilder) buildRecursiveCTE(ctx context.Context, cte ast.ResultSetNo for i := 0; i < len(x.SelectList.Selects); i++ { var p base.LogicalPlan var err error + originalLen := b.handleHelper.stackTail var afterOpr *ast.SetOprType switch y := x.SelectList.Selects[i].(type) { @@ -6974,6 +6982,22 @@ func (b *PlanBuilder) buildRecursiveCTE(ctx context.Context, cte ast.ResultSetNo afterOpr = y.AfterSetOperator } + // This is for maintain b.handleHelper instead of normal error handling. Since one error is expected if + // expectSeed && cInfo.useRecursive, error handling is in the "if expectSeed" block below. + if err == nil { + b.handleHelper.popMap() + } else { + // Be careful with this tricky case. One error is expected here when building the first recursive + // part, however, the b.handleHelper won't be restored if error occurs, which means there could be + // more than one entry pushed into b.handleHelper without being poped. + // For example: with recursive cte1 as (select ... union all select ... from tbl join cte1 ...) ... + // This violates the semantic of buildSelect() and buildSetOpr(), which should only push exactly one + // entry into b.handleHelper. So we use a special logic to restore the b.handleHelper here. + for b.handleHelper.stackTail > originalLen { + b.handleHelper.popMap() + } + } + if expectSeed { if cInfo.useRecursive { // 3. If it fail to build a plan, it may be the recursive part. Then we build the seed part plan, and rebuild it. @@ -7004,14 +7028,11 @@ func (b *PlanBuilder) buildRecursiveCTE(ctx context.Context, cte ast.ResultSetNo // Build seed part plan. saveSelect := x.SelectList.Selects x.SelectList.Selects = x.SelectList.Selects[:i] - // We're rebuilding the seed part, so we pop the result we built previously. - for _i := 0; _i < i; _i++ { - b.handleHelper.popMap() - } p, err = b.buildSetOpr(ctx, x) if err != nil { return err } + b.handleHelper.popMap() x.SelectList.Selects = saveSelect p, err = b.adjustCTEPlanOutputName(p, cInfo.def) if err != nil { @@ -7080,6 +7101,7 @@ func (b *PlanBuilder) buildRecursiveCTE(ctx context.Context, cte ast.ResultSetNo limit.SetChildren(limit.Children()[:0]...) cInfo.limitLP = limit } + b.handleHelper.pushMap(nil) return nil default: p, err := b.buildResultSetNode(ctx, x, true) @@ -7175,7 +7197,9 @@ func (b *PlanBuilder) buildWith(ctx context.Context, w *ast.WithClause) ([]*cteI b.outerCTEs[len(b.outerCTEs)-1].optFlag = b.optFlag b.outerCTEs[len(b.outerCTEs)-1].isBuilding = false b.optFlag = saveFlag - // each cte (select statement) will generate a handle map, pop it out here. + // buildCte() will push one entry into handleHelper. As said in comments for b.handleHelper, building CTE + // should not affect the handleColHelper, so we pop it out here, then buildWith() as a whole will not modify + // the handleColHelper. b.handleHelper.popMap() ctes = append(ctes, b.outerCTEs[len(b.outerCTEs)-1]) } diff --git a/pkg/planner/core/planbuilder.go b/pkg/planner/core/planbuilder.go index b41065ba06ac2..18fceaf2602d9 100644 --- a/pkg/planner/core/planbuilder.go +++ b/pkg/planner/core/planbuilder.go @@ -235,7 +235,8 @@ type PlanBuilder struct { // If it's a aggregation, we pop the map and push a nil map since no handle information left. // If it's a union, we pop all children's and push a nil map. // If it's a join, we pop its children's out then merge them and push the new map to stack. - // If we meet a subquery, it's clearly that it's a independent problem so we just pop one map out when we finish building the subquery. + // If we meet a subquery or CTE, it's clearly that it's an independent problem so we just pop one map out when we + // finish building the subquery or CTE. handleHelper *handleColHelper hintProcessor *hint.QBHintHandler diff --git a/tests/integrationtest/r/cte.result b/tests/integrationtest/r/cte.result index de3721f2b739c..483fe0cdae93e 100644 --- a/tests/integrationtest/r/cte.result +++ b/tests/integrationtest/r/cte.result @@ -1271,3 +1271,67 @@ FROM product_detail col_4 Product A Product A +drop table if exists t1, t2; +create table t1(a int, b int); +create table t2(a int, b int); +insert into t1 value(5,5); +insert into t2 value(1,1); +with recursive cte1 as (select 1 as a union all select cte1.a+1 from t1 join cte1 on t1.a > cte1.a) select * from cte1; +a +1 +2 +3 +4 +5 +update t2 set b=2 where a in (with recursive cte1 as (select 1 as a union all select cte1.a+1 from t1 join cte1 on t1.a > cte1.a) select * from cte1); +select * from t2; +a b +1 2 +delete from t2 where a in (with recursive cte1 as (select 1 as a union all select cte1.a+1 from t1 join cte1 on t1.a > cte1.a) select * from cte1); +select * from t2; +a b +drop table if exists table_abc1; +drop table if exists table_abc2; +drop table if exists table_abc3; +drop table if exists table_abc4; +CREATE TABLE `table_abc1` ( +`column_abc1` varchar(10) DEFAULT NULL, +`column_abc2` varchar(10) DEFAULT NULL, +`column_abc3` varchar(10) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; +CREATE TABLE `table_abc3` ( +`column_abc5` varchar(10) DEFAULT NULL, +`column_abc6` varchar(10) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; +CREATE TABLE `table_abc4` ( +`column_abc3` varchar(10) DEFAULT NULL, +`column_abc7` varchar(10) DEFAULT NULL, +`column_abc5` varchar(10) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; +INSERT INTO `table_abc1` VALUES ('KTL157','KTL157','KTL157'); +INSERT INTO `table_abc3` VALUES ('1000','20240819'); +INSERT INTO `table_abc4` VALUES ('KTL157','test','1000'); +DELETE FROM table_abc3 t_abc3 +WHERE t_abc3.column_abc5 IN ( +SELECT a.column_abc5 +FROM ( +WITH tree_cte1 AS ( +WITH RECURSIVE tree_cte AS ( +SELECT t.column_abc1, t.column_abc2, t.column_abc3, 0 AS lv +FROM table_abc1 t +WHERE t.column_abc1 IN ('KTL157', 'KTL159') +UNION ALL +SELECT t.column_abc1, t.column_abc2, t.column_abc3, tcte.lv + 1 +FROM table_abc1 t +JOIN tree_cte tcte ON t.column_abc1 = tcte.column_abc2 +WHERE tcte.lv <= 1 +) +SELECT * FROM tree_cte +) +SELECT e.column_abc5 +FROM ( +SELECT DISTINCT * FROM tree_cte1 +) aa +LEFT JOIN table_abc4 e ON e.column_abc3 = aa.column_abc3 +) a +); diff --git a/tests/integrationtest/t/cte.test b/tests/integrationtest/t/cte.test index 5e8ab2e7c1f66..ae9fc79d53338 100644 --- a/tests/integrationtest/t/cte.test +++ b/tests/integrationtest/t/cte.test @@ -680,3 +680,70 @@ SELECT col_4 FROM product_detail ) a; +# Tests for PR #55732 +drop table if exists t1, t2; +create table t1(a int, b int); +create table t2(a int, b int); +insert into t1 value(5,5); +insert into t2 value(1,1); +with recursive cte1 as (select 1 as a union all select cte1.a+1 from t1 join cte1 on t1.a > cte1.a) select * from cte1; +# This UPDATE should update t2.b to 2 +update t2 set b=2 where a in (with recursive cte1 as (select 1 as a union all select cte1.a+1 from t1 join cte1 on t1.a > cte1.a) select * from cte1); +select * from t2; +# This DELETE should delete all rows in t2 +delete from t2 where a in (with recursive cte1 as (select 1 as a union all select cte1.a+1 from t1 join cte1 on t1.a > cte1.a) select * from cte1); +select * from t2; + +# Issue #55666 +drop table if exists table_abc1; +drop table if exists table_abc2; +drop table if exists table_abc3; +drop table if exists table_abc4; + +CREATE TABLE `table_abc1` ( + `column_abc1` varchar(10) DEFAULT NULL, + `column_abc2` varchar(10) DEFAULT NULL, + `column_abc3` varchar(10) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; + +CREATE TABLE `table_abc3` ( + `column_abc5` varchar(10) DEFAULT NULL, + `column_abc6` varchar(10) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; + +CREATE TABLE `table_abc4` ( + `column_abc3` varchar(10) DEFAULT NULL, + `column_abc7` varchar(10) DEFAULT NULL, + `column_abc5` varchar(10) DEFAULT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin; + +INSERT INTO `table_abc1` VALUES ('KTL157','KTL157','KTL157'); +INSERT INTO `table_abc3` VALUES ('1000','20240819'); +INSERT INTO `table_abc4` VALUES ('KTL157','test','1000'); + +DELETE FROM table_abc3 t_abc3 +WHERE t_abc3.column_abc5 IN ( + SELECT a.column_abc5 + FROM ( + WITH tree_cte1 AS ( + WITH RECURSIVE tree_cte AS ( + SELECT t.column_abc1, t.column_abc2, t.column_abc3, 0 AS lv + FROM table_abc1 t + WHERE t.column_abc1 IN ('KTL157', 'KTL159') + UNION ALL + SELECT t.column_abc1, t.column_abc2, t.column_abc3, tcte.lv + 1 + FROM table_abc1 t + JOIN tree_cte tcte ON t.column_abc1 = tcte.column_abc2 + WHERE tcte.lv <= 1 + ) + SELECT * FROM tree_cte + ) + SELECT e.column_abc5 + FROM ( + SELECT DISTINCT * FROM tree_cte1 + ) aa + LEFT JOIN table_abc4 e ON e.column_abc3 = aa.column_abc3 + ) a +); + +