diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index 418669bbc0ff2..5da87ec057f28 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -191,6 +191,35 @@ HashAgg_18 24000.00 root group by:c1, funcs:firstrow(join_agg_0) └─IndexReader_62 8000.00 root index:StreamAgg_53 └─StreamAgg_53 8000.00 cop group by:test.t2.c1, funcs:firstrow(test.t2.c1), firstrow(test.t2.c1) └─IndexScan_60 10000.00 cop table:t2, index:c1, range:[NULL,+inf], keep order:true, stats:pseudo +explain select count(1) from (select count(1) from (select * from t1 where c3 = 100) k) k2; +id count task operator info +StreamAgg_13 1.00 root funcs:count(1) +└─StreamAgg_28 1.00 root funcs:firstrow(col_0) + └─TableReader_29 1.00 root data:StreamAgg_17 + └─StreamAgg_17 1.00 cop funcs:firstrow(1) + └─Selection_27 10.00 cop eq(test.t1.c3, 100) + └─TableScan_26 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +explain select 1 from (select count(c2), count(c3) from t1) k; +id count task operator info +Projection_5 1.00 root 1 +└─StreamAgg_17 1.00 root funcs:firstrow(col_0) + └─TableReader_18 1.00 root data:StreamAgg_9 + └─StreamAgg_9 1.00 cop funcs:firstrow(1) + └─TableScan_16 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +explain select count(1) from (select max(c2), count(c3) as m from t1) k; +id count task operator info +StreamAgg_11 1.00 root funcs:count(1) +└─StreamAgg_23 1.00 root funcs:firstrow(col_0) + └─TableReader_24 1.00 root data:StreamAgg_15 + └─StreamAgg_15 1.00 cop funcs:firstrow(1) + └─TableScan_22 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +explain select count(1) from (select count(c2) from t1 group by c3) k; +id count task operator info +StreamAgg_11 1.00 root funcs:count(1) +└─HashAgg_23 8000.00 root group by:col_1, funcs:firstrow(col_0) + └─TableReader_24 8000.00 root data:HashAgg_20 + └─HashAgg_20 8000.00 cop group by:test.t1.c3, funcs:firstrow(1) + └─TableScan_15 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo set @@session.tidb_opt_insubq_to_join_and_agg=0; explain select sum(t1.c1 in (select c1 from t2)) from t1; id count task operator info diff --git a/cmd/explaintest/t/explain_easy.test b/cmd/explaintest/t/explain_easy.test index c5044c96069cd..fe056d84d0208 100644 --- a/cmd/explaintest/t/explain_easy.test +++ b/cmd/explaintest/t/explain_easy.test @@ -35,6 +35,12 @@ explain select if(10, t1.c1, t1.c2) from t1; explain select c1 from t2 union select c1 from t2 union all select c1 from t2; explain select c1 from t2 union all select c1 from t2 union select c1 from t2; +# https://github.com/pingcap/tidb/issues/9125 +explain select count(1) from (select count(1) from (select * from t1 where c3 = 100) k) k2; +explain select 1 from (select count(c2), count(c3) from t1) k; +explain select count(1) from (select max(c2), count(c3) as m from t1) k; +explain select count(1) from (select count(c2) from t1 group by c3) k; + set @@session.tidb_opt_insubq_to_join_and_agg=0; explain select sum(t1.c1 in (select c1 from t2)) from t1; diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 71598d87d2203..702a3f7d1a5d4 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -1148,6 +1148,24 @@ func (s *testPlanSuite) TestColumnPruning(c *C) { 12: {"test.t4.a"}, }, }, + { + sql: "select 1 from (select count(b) as cnt from t) t1;", + ans: map[int][]string{ + 1: {"test.t.a"}, + }, + }, + { + sql: "select count(1) from (select count(b) as cnt from t) t1;", + ans: map[int][]string{ + 1: {"test.t.a"}, + }, + }, + { + sql: "select count(1) from (select count(b) as cnt from t group by c) t1;", + ans: map[int][]string{ + 1: {"test.t.c"}, + }, + }, } for _, tt := range tests { comment := Commentf("for %s", tt.sql) diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 1a78a8ecb4d90..6a5a278921388 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -20,7 +20,9 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/types" ) type columnPruner struct { @@ -117,6 +119,21 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) for _, aggrFunc := range la.AggFuncs { selfUsedCols = expression.ExtractColumnsFromExpressions(selfUsedCols, aggrFunc.Args, nil) } + if len(la.AggFuncs) == 0 { + // If all the aggregate functions are pruned, we should add an aggregate function to keep the correctness. + one, err := aggregation.NewAggFuncDesc(la.ctx, ast.AggFuncFirstRow, []expression.Expression{expression.One}, false) + if err != nil { + return err + } + la.AggFuncs = []*aggregation.AggFuncDesc{one} + col := &expression.Column{ + ColName: model.NewCIStr("dummy_agg"), + UniqueID: la.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: types.NewFieldType(mysql.TypeLonglong), + } + la.schema.Columns = []*expression.Column{col} + } + if len(la.GroupByItems) > 0 { for i := len(la.GroupByItems) - 1; i >= 0; i-- { cols := expression.ExtractColumns(la.GroupByItems[i])