From 40efd0ef83becf5eb8c98d739676fa7b3245099a Mon Sep 17 00:00:00 2001
From: tiancaiamao
Date: Tue, 28 Jun 2022 12:08:39 +0800
Subject: [PATCH] cherry pick #35443 to release-4.0

Signed-off-by: ti-srebot
---
 executor/aggregate_test.go                 | 176 +++++++++++++++++++++
 executor/builder.go                        |  10 +-
 expression/aggregation/descriptor.go       |   2 -
 planner/core/physical_plans.go             |  30 ++++
 planner/core/rule_aggregation_push_down.go |  16 ++
 planner/core/rule_eliminate_projection.go  |  24 +++
 planner/core/task.go                       | 100 ++++++++++++
 7 files changed, 354 insertions(+), 4 deletions(-)

diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go
index 1b7e6a9b566dd..6ee37cf27983e 100644
--- a/executor/aggregate_test.go
+++ b/executor/aggregate_test.go
@@ -1212,3 +1212,179 @@ func (s *testSuiteAgg) TestIssue23277(c *C) {
 	tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000"))
 	tk.MustExec("drop table t;")
 }
+<<<<<<< HEAD
+=======
+
+func TestAvgDecimal(t *testing.T) {
+	store, clean := testkit.CreateMockStore(t)
+	defer clean()
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test;")
+	tk.MustExec("drop table if exists td;")
+	tk.MustExec("create table td (col_bigint bigint(20), col_smallint smallint(6));")
+	tk.MustExec("insert into td values (null, 22876);")
+	tk.MustExec("insert into td values (9220557287087669248, 32767);")
+	tk.MustExec("insert into td values (28030, 32767);")
+	tk.MustExec("insert into td values (-3309864251140603904,32767);")
+	tk.MustExec("insert into td values (4,0);")
+	tk.MustExec("insert into td values (null,0);")
+	tk.MustExec("insert into td values (4,-23828);")
+	tk.MustExec("insert into td values (54720,32767);")
+	tk.MustExec("insert into td values (0,29815);")
+	tk.MustExec("insert into td values (10017,-32661);")
+	tk.MustQuery(" SELECT AVG( col_bigint / col_smallint) AS field1 FROM td;").Sort().Check(testkit.Rows("25769363061037.62077260"))
+	tk.MustQuery(" SELECT AVG(col_bigint) OVER (PARTITION BY col_smallint) as field2 FROM td where col_smallint = -23828;").Sort().Check(testkit.Rows("4.0000"))
+	tk.MustExec("drop table td;")
+}
+
+// https://github.com/pingcap/tidb/issues/23314
+func TestIssue23314(t *testing.T) {
+	store, clean := testkit.CreateMockStore(t)
+	defer clean()
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t1")
+	tk.MustExec("create table t1(col1 time(2) NOT NULL)")
+	tk.MustExec("insert into t1 values(\"16:40:20.01\")")
+	res := tk.MustQuery("select col1 from t1 group by col1")
+	res.Check(testkit.Rows("16:40:20.01"))
+}
+
+func TestAggInDisk(t *testing.T) {
+	store, clean := testkit.CreateMockStore(t)
+	defer clean()
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("set tidb_hashagg_final_concurrency = 1;")
+	tk.MustExec("set tidb_hashagg_partial_concurrency = 1;")
+	tk.MustExec("set tidb_mem_quota_query = 4194304")
+	tk.MustExec("drop table if exists t1")
+	tk.MustExec("create table t(a int)")
+	sql := "insert into t values (0)"
+	for i := 1; i <= 200; i++ {
+		sql += fmt.Sprintf(",(%v)", i)
+	}
+	sql += ";"
+	tk.MustExec(sql)
+	rows := tk.MustQuery("desc analyze select /*+ HASH_AGG() */ avg(t1.a) from t t1 join t t2 group by t1.a, t2.a;").Rows()
+	for _, row := range rows {
+		length := len(row)
+		line := fmt.Sprintf("%v", row)
+		disk := fmt.Sprintf("%v", row[length-1])
+		if strings.Contains(line, "HashAgg") {
+			require.False(t, strings.Contains(disk, "0 Bytes"))
+			require.True(t, strings.Contains(disk, "MB") ||
+				strings.Contains(disk, "KB") ||
+				strings.Contains(disk, "Bytes"))
+		}
+	}
+
+	// Add code coverage.
+	// Test spill chunk. Add a row so that the tmp spill chunk is not always full.
+	tk.MustExec("insert into t values(0)")
+	tk.MustQuery("select sum(tt.b) from ( select /*+ HASH_AGG() */ avg(t1.a) as b from t t1 join t t2 group by t1.a, t2.a) as tt").Check(
+		testkit.Rows("4040100.0000"))
+	// Test no group by and no data.
+	tk.MustExec("drop table t;")
+	tk.MustExec("create table t(c int, c1 int);")
+	tk.MustQuery("select /*+ HASH_AGG() */ count(c) from t;").Check(testkit.Rows("0"))
+	tk.MustQuery("select /*+ HASH_AGG() */ count(c) from t group by c1;").Check(testkit.Rows())
+}
+
+func TestRandomPanicAggConsume(t *testing.T) {
+	store, clean := testkit.CreateMockStore(t)
+	defer clean()
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("set @@tidb_max_chunk_size=32")
+	tk.MustExec("set @@tidb_init_chunk_size=1")
+	tk.MustExec("drop table if exists t")
+	tk.MustExec("create table t(a int)")
+	for i := 0; i <= 1000; i++ {
+		tk.MustExec(fmt.Sprintf("insert into t values(%v),(%v),(%v)", i, i, i))
+	}
+
+	fpName := "github.com/pingcap/tidb/executor/ConsumeRandomPanic"
+	require.NoError(t, failpoint.Enable(fpName, "5%panic(\"ERROR 1105 (HY000): Out Of Memory Quota![conn_id=1]\")"))
+	defer func() {
+		require.NoError(t, failpoint.Disable(fpName))
+	}()
+
+	// Trigger the panic 10 times for each AggExec.
+	var res sqlexec.RecordSet
+	for i := 1; i <= 10; i++ {
+		var err error
+		for err == nil {
+			// Test parallel hash agg.
+			res, err = tk.Exec("select /*+ HASH_AGG() */ count(a) from t group by a")
+			if err == nil {
+				_, err = session.GetRows4Test(context.Background(), tk.Session(), res)
+				require.NoError(t, res.Close())
+			}
+		}
+		require.EqualError(t, err, "failpoint panic: ERROR 1105 (HY000): Out Of Memory Quota![conn_id=1]")
+
+		err = nil
+		for err == nil {
+			// Test non-parallel hash agg.
+			res, err = tk.Exec("select /*+ HASH_AGG() */ count(distinct a) from t")
+			if err == nil {
+				_, err = session.GetRows4Test(context.Background(), tk.Session(), res)
+				require.NoError(t, res.Close())
+			}
+		}
+		require.EqualError(t, err, "failpoint panic: ERROR 1105 (HY000): Out Of Memory Quota![conn_id=1]")
+
+		err = nil
+		for err == nil {
+			// Test stream agg.
+			res, err = tk.Exec("select /*+ STREAM_AGG() */ count(a) from t")
+			if err == nil {
+				_, err = session.GetRows4Test(context.Background(), tk.Session(), res)
+				require.NoError(t, res.Close())
+			}
+		}
+		require.EqualError(t, err, "failpoint panic: ERROR 1105 (HY000): Out Of Memory Quota![conn_id=1]")
+	}
+}
+
+func TestIssue35295(t *testing.T) {
+	store, clean := testkit.CreateMockStore(t)
+	defer clean()
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t100")
+	// This bug only happens when partition prune mode = 'static'.
+	tk.MustExec("set @@tidb_partition_prune_mode = 'static'")
+	tk.MustExec(`CREATE TABLE t100 (
+ID bigint(20) unsigned NOT NULL AUTO_INCREMENT,
+col1 int(10) NOT NULL DEFAULT '0' COMMENT 'test',
+money bigint(20) NOT NULL COMMENT 'test',
+logtime datetime NOT NULL COMMENT 'record time',
+PRIMARY KEY (ID,logtime)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin AUTO_INCREMENT=1 COMMENT='test'
+PARTITION BY RANGE COLUMNS(logtime) (
+PARTITION p20220608 VALUES LESS THAN ("20220609"),
+PARTITION p20220609 VALUES LESS THAN ("20220610"),
+PARTITION p20220610 VALUES LESS THAN ("20220611"),
+PARTITION p20220611 VALUES LESS THAN ("20220612"),
+PARTITION p20220612 VALUES LESS THAN ("20220613"),
+PARTITION p20220613 VALUES LESS THAN ("20220614"),
+PARTITION p20220614 VALUES LESS THAN ("20220615"),
+PARTITION p20220615 VALUES LESS THAN ("20220616"),
+PARTITION p20220616 VALUES LESS THAN ("20220617"),
+PARTITION p20220617 VALUES LESS THAN ("20220618"),
+PARTITION p20220618 VALUES LESS THAN ("20220619"),
+PARTITION p20220619 VALUES LESS THAN ("20220620"),
+PARTITION p20220620 VALUES LESS THAN ("20220621"),
+PARTITION p20220621 VALUES LESS THAN ("20220622"),
+PARTITION p20220622 VALUES LESS THAN ("20220623"),
+PARTITION p20220623 VALUES LESS THAN ("20220624"),
+PARTITION p20220624 VALUES LESS THAN ("20220625")
+ );`)
+	tk.MustExec("insert into t100(col1,money,logtime) values (100,10,'2022-06-09 00:00:00');")
+	tk.MustExec("insert into t100(col1,money,logtime) values (100,10,'2022-06-10 00:00:00');")
+	tk.MustQuery("SELECT /*+STREAM_AGG()*/ col1,sum(money) FROM t100 WHERE logtime>='2022-06-09 00:00:00' AND col1=100 ;").Check(testkit.Rows("100 20"))
+	tk.MustQuery("SELECT /*+HASH_AGG()*/ col1,sum(money) FROM t100 WHERE logtime>='2022-06-09 00:00:00' AND col1=100 ;").Check(testkit.Rows("100 20"))
+}
+>>>>>>> d99b35822... *: only add default value for final aggregation to fix the aggregate push down (partition) union case (#35443)
diff --git a/executor/builder.go b/executor/builder.go
index 906a518e59c5f..b6a200915d3eb 100644
--- a/executor/builder.go
+++ b/executor/builder.go
@@ -1182,7 +1182,9 @@ func (b *executorBuilder) buildHashAgg(v *plannercore.PhysicalHashAgg) Executor
 	if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) {
 		e.defaultVal = nil
 	} else {
-		e.defaultVal = chunk.NewChunkWithCapacity(retTypes(e), 1)
+		if v.IsFinalAgg() {
+			e.defaultVal = chunk.NewChunkWithCapacity(retTypes(e), 1)
+		}
 	}
 	for _, aggDesc := range v.AggFuncs {
 		if aggDesc.HasDistinct || len(aggDesc.OrderByItems) > 0 {
@@ -1238,10 +1240,14 @@ func (b *executorBuilder) buildStreamAgg(v *plannercore.PhysicalStreamAgg) Execu
 		groupChecker: newVecGroupChecker(b.ctx, v.GroupByItems),
 		aggFuncs:     make([]aggfuncs.AggFunc, 0, len(v.AggFuncs)),
 	}
+
 	if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) {
 		e.defaultVal = nil
 	} else {
-		e.defaultVal = chunk.NewChunkWithCapacity(retTypes(e), 1)
+		// Only do this for a final agg, see issues #35295 and #30923.
+		if v.IsFinalAgg() {
+			e.defaultVal = chunk.NewChunkWithCapacity(retTypes(e), 1)
+		}
 	}
 	for i, aggDesc := range v.AggFuncs {
 		aggFunc := aggfuncs.Build(b.ctx, aggDesc, i)
diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go
index af16d26b1f81a..c1366a7fc6222 100644
--- a/expression/aggregation/descriptor.go
+++ b/expression/aggregation/descriptor.go
@@ -103,8 +103,6 @@ func (a *AggFuncDesc) Split(ordinal []int) (partialAggDesc, finalAggDesc *AggFun
 		partialAggDesc.Mode = Partial1Mode
 	} else if a.Mode == FinalMode {
 		partialAggDesc.Mode = Partial2Mode
-	} else {
-		panic("Error happened during AggFuncDesc.Split, the AggFunctionMode is not CompleteMode or FinalMode.")
 	}
 	finalAggDesc = &AggFuncDesc{
 		Mode: FinalMode, // We only support FinalMode now in final phase.
diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go
index c0f8820d5f03f..f708e1344e973 100644
--- a/planner/core/physical_plans.go
+++ b/planner/core/physical_plans.go
@@ -500,8 +500,38 @@ type PhysicalUnionAll struct {
 type basePhysicalAgg struct {
 	physicalSchemaProducer
 
+<<<<<<< HEAD
 	AggFuncs     []*aggregation.AggFuncDesc
 	GroupByItems []expression.Expression
+=======
+	AggFuncs         []*aggregation.AggFuncDesc
+	GroupByItems     []expression.Expression
+	MppRunMode       AggMppRunMode
+	MppPartitionCols []*property.MPPPartitionColumn
+}
+
+func (p *basePhysicalAgg) IsFinalAgg() bool {
+	if len(p.AggFuncs) > 0 {
+		if p.AggFuncs[0].Mode == aggregation.FinalMode || p.AggFuncs[0].Mode == aggregation.CompleteMode {
+			return true
+		}
+	}
+	return false
+}
+
+func (p *basePhysicalAgg) cloneWithSelf(newSelf PhysicalPlan) (*basePhysicalAgg, error) {
+	cloned := new(basePhysicalAgg)
+	base, err := p.physicalSchemaProducer.cloneWithSelf(newSelf)
+	if err != nil {
+		return nil, err
+	}
+	cloned.physicalSchemaProducer = *base
+	for _, aggDesc := range p.AggFuncs {
+		cloned.AggFuncs = append(cloned.AggFuncs, aggDesc.Clone())
+	}
+	cloned.GroupByItems = cloneExprs(p.GroupByItems)
+	return cloned, nil
+>>>>>>> d99b35822... *: only add default value for final aggregation to fix the aggregate push down (partition) union case (#35443)
 }
 
 func (p *basePhysicalAgg) numDistinctFunc() (num int) {
diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go
index 8668efc2138dd..a9056f3a11c07 100644
--- a/planner/core/rule_aggregation_push_down.go
+++ b/planner/core/rule_aggregation_push_down.go
@@ -376,6 +376,22 @@ func (a *aggregationPushDownSolver) tryAggPushDownForUnion(union *LogicalUnionAl
 		}
 	}
 	pushedAgg := a.splitPartialAgg(agg)
+<<<<<<< HEAD
+=======
+	if pushedAgg == nil {
+		return nil
+	}
+
+	// Update the agg mode for the pushed-down aggregation.
+	for _, aggFunc := range pushedAgg.AggFuncs {
+		if aggFunc.Mode == aggregation.CompleteMode {
+			aggFunc.Mode = aggregation.Partial1Mode
+		} else if aggFunc.Mode == aggregation.FinalMode {
+			aggFunc.Mode = aggregation.Partial2Mode
+		}
+	}
+
+>>>>>>> d99b35822... *: only add default value for final aggregation to fix the aggregate push down (partition) union case (#35443)
 	newChildren := make([]LogicalPlan, 0, len(union.Children()))
 	for _, child := range union.Children() {
 		newChild, err := a.pushAggCrossUnion(pushedAgg, union.Schema(), child)
diff --git a/planner/core/rule_eliminate_projection.go b/planner/core/rule_eliminate_projection.go
index 51ab76a34dc7c..67bf7a9fe6a83 100644
--- a/planner/core/rule_eliminate_projection.go
+++ b/planner/core/rule_eliminate_projection.go
@@ -34,6 +34,30 @@ func canProjectionBeEliminatedLoose(p *LogicalProjection) bool {
 // canProjectionBeEliminatedStrict checks whether a projection can be
 // eliminated, returns true if the projection just copies its child's output.
 func canProjectionBeEliminatedStrict(p *PhysicalProjection) bool {
+<<<<<<< HEAD
+=======
+	// This is due to an incompatibility between TiFlash and TiDB:
+	// for TiDB, the output schema of a final agg is all the aggregate functions, while for
+	// TiFlash, the output schema of an agg (TiFlash is not aware of the aggregation mode) is
+	// the aggregate functions plus the group-by columns. So to make things work, for a
+	// final-mode aggregation that needs to run in TiFlash, always add an extra Projection to
+	// align the output schema. In the future, we can resolve this incompatibility by
+	// passing down the aggregation mode to TiFlash.
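+	// An illustrative example (not from the original patch): for an MPP final HashAgg of
+	// `select sum(a) from t group by b`, TiFlash would output [sum(a), b] while TiDB expects
+	// only [sum(a)], so the parent Projection that aligns the schema must not be eliminated.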
+	if physicalAgg, ok := p.Children()[0].(*PhysicalHashAgg); ok {
+		if physicalAgg.MppRunMode == Mpp1Phase || physicalAgg.MppRunMode == Mpp2Phase || physicalAgg.MppRunMode == MppScalar {
+			if physicalAgg.IsFinalAgg() {
+				return false
+			}
+		}
+	}
+	if physicalAgg, ok := p.Children()[0].(*PhysicalStreamAgg); ok {
+		if physicalAgg.MppRunMode == Mpp1Phase || physicalAgg.MppRunMode == Mpp2Phase || physicalAgg.MppRunMode == MppScalar {
+			if physicalAgg.IsFinalAgg() {
+				return false
+			}
+		}
+	}
+>>>>>>> d99b35822... *: only add default value for final aggregation to fix the aggregate push down (partition) union case (#35443)
 	// If this projection is specially added for `DO`, we keep it.
 	if p.CalculateNoDelay {
 		return false
diff --git a/planner/core/task.go b/planner/core/task.go
index 0386c253d765d..c6a7c7cfa2203 100644
--- a/planner/core/task.go
+++ b/planner/core/task.go
@@ -1186,8 +1186,22 @@ func BuildFinalModeAggregation(
 			}
 		}
 
+<<<<<<< HEAD
 		finalAggFunc.HasDistinct = true
 		finalAggFunc.Mode = aggregation.CompleteMode
+=======
+		finalAggFunc.OrderByItems = byItems
+		finalAggFunc.HasDistinct = aggFunc.HasDistinct
+		// During logical optimization, Agg->PartitionUnion->TableReader may become
+		// Agg1->PartitionUnion->Agg2->TableReader, where Agg2 is a partial aggregation.
+		// So in the push-down here we need an extra check: if the original agg mode is
+		// already partial, the finalAggFunc's mode becomes Partial2Mode.
+		if aggFunc.Mode == aggregation.CompleteMode {
+			finalAggFunc.Mode = aggregation.CompleteMode
+		} else if aggFunc.Mode == aggregation.Partial1Mode || aggFunc.Mode == aggregation.Partial2Mode {
+			finalAggFunc.Mode = aggregation.Partial2Mode
+		}
+>>>>>>> d99b35822... *: only add default value for final aggregation to fix the aggregate push down (partition) union case (#35443)
 	} else {
 		if aggregation.NeedCount(finalAggFunc.Name) {
 			ft := types.NewFieldType(mysql.TypeLonglong)
@@ -1236,8 +1250,20 @@ func BuildFinalModeAggregation(
 			partial.AggFuncs = append(partial.AggFuncs, aggFunc)
 		}
 
+<<<<<<< HEAD
 		finalAggFunc.Mode = aggregation.FinalMode
 		funcMap[aggFunc] = finalAggFunc
+=======
+		// During logical optimization, Agg->PartitionUnion->TableReader may become
+		// Agg1->PartitionUnion->Agg2->TableReader, where Agg2 is a partial aggregation.
+		// So in the push-down here we need an extra check: if the original agg mode is
+		// already partial, the finalAggFunc's mode becomes Partial2Mode.
+		if aggFunc.Mode == aggregation.CompleteMode {
+			finalAggFunc.Mode = aggregation.FinalMode
+		} else if aggFunc.Mode == aggregation.Partial1Mode || aggFunc.Mode == aggregation.Partial2Mode {
+			finalAggFunc.Mode = aggregation.Partial2Mode
+		}
+>>>>>>> d99b35822... *: only add default value for final aggregation to fix the aggregate push down (partition) union case (#35443)
 	}
 
 	finalAggFunc.Args = args
@@ -1248,7 +1274,81 @@ func BuildFinalModeAggregation(
 	return
 }
 
+<<<<<<< HEAD
 func (p *basePhysicalAgg) newPartialAggregate(copTaskType kv.StoreType) (partial, final PhysicalPlan) {
+=======
+// convertAvgForMPP converts avg(arg) to sum(arg)/(case when count(arg)=0 then 1 else count(arg) end). In detail:
+// 1. rewrite avg() in the final aggregation to count() and sum(), and reconstruct its schema;
+// 2. replace avg() with sum(arg)/(case when count(arg)=0 then 1 else count(arg) end) and reuse the original schema of the final aggregation.
+// If there is no avg, nothing is changed and nil is returned.
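+// An illustrative example (not from the original patch): given input rows (1, 2, NULL), the
+// rewrite produces count(arg)=2 and sum(arg)=3, so the projection computes 3/2 = 1.5, which
+// matches avg(arg). The CASE expression guards the division when all inputs are NULL:
+// count(arg)=0 is mapped to 1, and since sum(arg) is NULL the result is still NULL, as avg(arg) would be.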
+func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection {
+	newSchema := expression.NewSchema()
+	newSchema.Keys = p.schema.Keys
+	newSchema.UniqueKeys = p.schema.UniqueKeys
+	newAggFuncs := make([]*aggregation.AggFuncDesc, 0, 2*len(p.AggFuncs))
+	exprs := make([]expression.Expression, 0, 2*len(p.schema.Columns))
+	// add the schema of the agg functions
+	for i, aggFunc := range p.AggFuncs {
+		if aggFunc.Name == ast.AggFuncAvg {
+			// insert a count(column)
+			avgCount := aggFunc.Clone()
+			avgCount.Name = ast.AggFuncCount
+			err := avgCount.TypeInfer(p.ctx)
+			if err != nil { // must not happen
+				return nil
+			}
+			newAggFuncs = append(newAggFuncs, avgCount)
+			avgCountCol := &expression.Column{
+				UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(),
+				RetType:  avgCount.RetTp,
+			}
+			newSchema.Append(avgCountCol)
+			// insert a sum(column)
+			avgSum := aggFunc.Clone()
+			avgSum.Name = ast.AggFuncSum
+			avgSum.TypeInfer4AvgSum(avgSum.RetTp)
+			newAggFuncs = append(newAggFuncs, avgSum)
+			avgSumCol := &expression.Column{
+				UniqueID: p.schema.Columns[i].UniqueID,
+				RetType:  avgSum.RetTp,
+			}
+			newSchema.Append(avgSumCol)
+			// avgSumCol/(case when avgCountCol=0 then 1 else avgCountCol end)
+			eq := expression.NewFunctionInternal(p.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), avgCountCol, expression.NewZero())
+			caseWhen := expression.NewFunctionInternal(p.ctx, ast.Case, avgCountCol.RetType, eq, expression.NewOne(), avgCountCol)
+			divide := expression.NewFunctionInternal(p.ctx, ast.Div, avgSumCol.RetType, avgSumCol, caseWhen)
+			divide.(*expression.ScalarFunction).RetType = p.schema.Columns[i].RetType
+			exprs = append(exprs, divide)
+		} else {
+			newAggFuncs = append(newAggFuncs, aggFunc)
+			newSchema.Append(p.schema.Columns[i])
+			exprs = append(exprs, p.schema.Columns[i])
+		}
+	}
+	// If there is no avg, return nil, except that for a final agg we always add the projection
+	// due to the incompatibility between TiDB and TiFlash described above.
+	if len(p.schema.Columns) == len(newSchema.Columns) && !p.IsFinalAgg() {
+		return nil
+	}
+	// add the remaining columns to exprs
+	for i := len(p.AggFuncs); i < len(p.schema.Columns); i++ {
+		exprs = append(exprs, p.schema.Columns[i])
+	}
+	proj := PhysicalProjection{
+		Exprs:                exprs,
+		CalculateNoDelay:     false,
+		AvoidColumnEvaluator: false,
+	}.Init(p.SCtx(), p.stats, p.SelectBlockOffset(), p.GetChildReqProps(0).CloneEssentialFields())
+	proj.SetSchema(p.schema)
+
+	p.AggFuncs = newAggFuncs
+	p.schema = newSchema
+
+	return proj
+}
+
+func (p *basePhysicalAgg) newPartialAggregate(copTaskType kv.StoreType, isMPPTask bool) (partial, final PhysicalPlan) {
+>>>>>>> d99b35822... *: only add default value for final aggregation to fix the aggregate push down (partition) union case (#35443)
 	// Check if this aggregation can push down.
 	if !CheckAggCanPushCop(p.ctx, p.AggFuncs, p.GroupByItems, copTaskType) {
 		return nil, p.self