diff --git a/planner/cascades/testdata/integration_suite_in.json b/planner/cascades/testdata/integration_suite_in.json index b4f563265813c..7db86db1f37fa 100644 --- a/planner/cascades/testdata/integration_suite_in.json +++ b/planner/cascades/testdata/integration_suite_in.json @@ -60,7 +60,8 @@ "select /*+ STREAM_AGG() */ count(distinct c) from t;", // should push down after stream agg implemented "select /*+ HASH_AGG() */ count(distinct c) from t;", "select count(distinct c) from t group by c;", - "select count(distinct c) from t;" + "select count(distinct c) from t;", + "select count(*) from t group by a having avg(distinct a)>1;" // #24449 Projection should be add between HashAgg and TableReader ] }, { diff --git a/planner/cascades/testdata/integration_suite_out.json b/planner/cascades/testdata/integration_suite_out.json index 3b2719c02d293..f5f134212ac8e 100644 --- a/planner/cascades/testdata/integration_suite_out.json +++ b/planner/cascades/testdata/integration_suite_out.json @@ -601,6 +601,21 @@ "Result": [ "2" ] + }, + { + "SQL": "select count(*) from t group by a having avg(distinct a)>1;", + "Plan": [ + "Projection_14 6400.00 root Column#5", + "└─Selection_15 6400.00 root gt(Column#6, 1)", + " └─HashAgg_20 8000.00 root group by:test.t.a, funcs:count(Column#8)->Column#5, funcs:avg(distinct Column#10)->Column#6", + " └─Projection_21 8000.00 root Column#8, cast(test.t.a, decimal(15,4) BINARY)->Column#10, test.t.a", + " └─TableReader_22 8000.00 root data:HashAgg_23", + " └─HashAgg_23 8000.00 cop[tikv] group by:test.t.a, funcs:count(1)->Column#8", + " └─TableFullScan_19 10000.00 cop[tikv] table:t keep order:false, stats:pseudo" + ], + "Result": [ + "1" + ] } ] }, diff --git a/planner/cascades/transformation_rules.go b/planner/cascades/transformation_rules.go index 9961509299a52..28adca873a974 100644 --- a/planner/cascades/transformation_rules.go +++ b/planner/cascades/transformation_rules.go @@ -1882,7 +1882,7 @@ func (*outerJoinEliminator) prepareForEliminateOuterJoin(joinExpr *memo.GroupExp return } -// check whether one of unique keys sets is contained by inner join keys. +// isInnerJoinKeysContainUniqueKey check whether one of unique keys sets is contained by inner join keys. func (*outerJoinEliminator) isInnerJoinKeysContainUniqueKey(innerGroup *memo.Group, joinKeys *expression.Schema) (bool, error) { // builds UniqueKey info of innerGroup. innerGroup.BuildKeyInfo() @@ -2129,7 +2129,7 @@ func (r *TransformAggregateCaseToSelection) isOnlyOneNotNull(ctx sessionctx.Cont return !args[outputIdx].Equal(ctx, expression.NewNull()) && (argsNum == 2 || args[3-outputIdx].Equal(ctx, expression.NewNull())) } -// TransformAggregateCaseToSelection only support `case when cond then var end` and `case when cond then var1 else var2 end`. +// isTwoOrThreeArgCase represents that TransformAggregateCaseToSelection only support `case when cond then var end` and `case when cond then var1 else var2 end`. func (r *TransformAggregateCaseToSelection) isTwoOrThreeArgCase(expr expression.Expression) bool { scalarFunc, ok := expr.(*expression.ScalarFunction) if !ok { @@ -2315,7 +2315,7 @@ func NewRuleInjectProjectionBelowAgg() Transformation { // Match implements Transformation interface. func (r *InjectProjectionBelowAgg) Match(expr *memo.ExprIter) bool { agg := expr.GetExpr().ExprNode.(*plannercore.LogicalAggregation) - return agg.IsCompleteModeAgg() + return agg.HasCompleteModeAgg() } // OnTransform implements Transformation interface. @@ -2326,9 +2326,15 @@ func (r *InjectProjectionBelowAgg) OnTransform(old *memo.ExprIter) (newExprs []* hasScalarFunc := false copyFuncs := make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs)) for _, aggFunc := range agg.AggFuncs { - copyFunc := aggFunc.Clone() // WrapCastForAggArgs will modify AggFunc, so we should clone AggFunc. - copyFunc.WrapCastForAggArgs(agg.SCtx()) + copyFunc := aggFunc.Clone() + + // if aggFunc input is from 'partial data', no need to wrap cast for agg args + copyFunc.WrapCastAsDecimalForAggArgs(agg.SCtx()) + if copyFunc.Mode != aggregation.FinalMode && copyFunc.Mode != aggregation.Partial2Mode { + copyFunc.WrapCastForAggArgs(agg.SCtx()) + } + copyFuncs = append(copyFuncs, copyFunc) for _, arg := range copyFunc.Args { _, isScalarFunc := arg.(*expression.ScalarFunction) diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index dedd379f1b91d..dc483197f5079 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -326,6 +326,19 @@ func (la *LogicalAggregation) HasDistinct() bool { return false } +// HasCompleteModeAgg shows whether LogicalAggregation has functions with CompleteMode. +func (la *LogicalAggregation) HasCompleteModeAgg() bool { + // not all of the AggFunctions has the same AggMode + // for example: when cascades planner on, after PushAggDownGather transformed, + // some aggFunctions are CompleteMode, and the others are FinalMode + for _, aggFunc := range la.AggFuncs { + if aggFunc.Mode == aggregation.CompleteMode { + return true + } + } + return false +} + // CopyAggHints copies the aggHints from another LogicalAggregation. func (la *LogicalAggregation) CopyAggHints(agg *LogicalAggregation) { // TODO: Copy the hint may make the un-applicable hint throw the @@ -391,6 +404,7 @@ func (la *LogicalAggregation) GetUsedCols() (usedCols []*expression.Column) { type LogicalSelection struct { baseLogicalPlan + // Conditions represents a list of AND conditions. // Originally the WHERE or ON condition is parsed into a single expression, // but after we converted to CNF(Conjunctive normal form), it can be // split into a list of AND conditions. @@ -495,12 +509,13 @@ type DataSource struct { // possibleAccessPaths stores all the possible access path for physical plan, including table scan. possibleAccessPaths []*util.AccessPath + // isPartition represents whether the data source is a partition. // The data source may be a partition, rather than a real table. isPartition bool physicalTableID int64 partitionNames []model.CIStr - // handleCol represents the handle column for the datasource, either the + // handleCols represents the handle column for the datasource, either the // int primary key column or extra handle column. // handleCol *expression.Column handleCols HandleCols @@ -558,7 +573,7 @@ type LogicalTableScan struct { // LogicalIndexScan is the logical index scan operator for TiKV. type LogicalIndexScan struct { logicalSchemaProducer - // DataSource should be read-only here. + // Source should be read-only here. Source *DataSource IsDoubleRead bool @@ -1191,7 +1206,7 @@ type LogicalShowDDLJobs struct { // CTEClass holds the information and plan for a CTE. Most of the fields in this struct are the same as cteInfo. // But the cteInfo is used when building the plan, and CTEClass is used also for building the executor. type CTEClass struct { - // The union between seed part and recursive part is DISTINCT or DISTINCT ALL. + // IsDistinct represents the union between seed part and recursive part is DISTINCT or DISTINCT ALL. IsDistinct bool // seedPartLogicalPlan and recursivePartLogicalPlan are the logical plans for the seed part and recursive part of this CTE. seedPartLogicalPlan LogicalPlan @@ -1201,7 +1216,7 @@ type CTEClass struct { recursivePartPhysicalPlan PhysicalPlan // cteTask is the physical plan for this CTE, is a wrapper of the PhysicalCTE. cteTask task - // storageID for this CTE. + // IDForStorage represents the storageID for this CTE. IDForStorage int // optFlag is the optFlag for the whole CTE. optFlag uint64 diff --git a/planner/core/logical_plans_test.go b/planner/core/logical_plans_test.go index 08617ce93d480..ac85b3b8db0c1 100644 --- a/planner/core/logical_plans_test.go +++ b/planner/core/logical_plans_test.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" @@ -207,3 +208,28 @@ func (s *testUnitTestSuit) TestIndexPathSplitCorColCond(c *C) { } collate.SetNewCollationEnabledForTest(false) } + +func (s *testUnitTestSuit) TestHasCompleteModeAgg(c *C) { + defer testleak.AfterTest(c)() + + aggFuncs := make([]*aggregation.AggFuncDesc, 2) + aggFuncs[0] = &aggregation.AggFuncDesc{ + Mode: aggregation.FinalMode, + HasDistinct: true, + } + aggFuncs[1] = &aggregation.AggFuncDesc{ + Mode: aggregation.CompleteMode, + HasDistinct: true, + } + + newAgg := &LogicalAggregation{ + AggFuncs: aggFuncs, + } + c.Assert(newAgg.HasCompleteModeAgg(), Equals, true) + + aggFuncs[1] = &aggregation.AggFuncDesc{ + Mode: aggregation.FinalMode, + HasDistinct: true, + } + c.Assert(newAgg.HasCompleteModeAgg(), Equals, false) +}