From 61a6b06d71ccb887e8859006c363a0dd260efc72 Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Sun, 4 Feb 2024 16:32:42 +0800 Subject: [PATCH 1/2] refactor NewDistAggFunc for mpp agg aware of aggDesc's agg execution mode Signed-off-by: AilinKid <314806019@qq.com> --- pkg/expression/aggregation/aggregation.go | 46 +++++++++++++------ .../mockstore/mockcopr/cop_handler_dag.go | 2 +- .../unistore/cophandler/cop_handler.go | 2 +- .../mockstore/unistore/cophandler/mpp.go | 26 ++++++++++- 4 files changed, 60 insertions(+), 16 deletions(-) diff --git a/pkg/expression/aggregation/aggregation.go b/pkg/expression/aggregation/aggregation.go index 0455ba1b284f2..6db82f2761119 100644 --- a/pkg/expression/aggregation/aggregation.go +++ b/pkg/expression/aggregation/aggregation.go @@ -51,38 +51,58 @@ type Aggregation interface { } // NewDistAggFunc creates new Aggregate function for mock tikv. -func NewDistAggFunc(expr *tipb.Expr, fieldTps []*types.FieldType, ctx sessionctx.Context) (Aggregation, error) { +func NewDistAggFunc(expr *tipb.Expr, fieldTps []*types.FieldType, ctx sessionctx.Context) (Aggregation, *AggFuncDesc, error) { args := make([]expression.Expression, 0, len(expr.Children)) for _, child := range expr.Children { arg, err := expression.PBToExpr(ctx, child, fieldTps) if err != nil { - return nil, err + return nil, nil, err } args = append(args, arg) } switch expr.Tp { case tipb.ExprType_Sum: - return &sumFunction{aggFunction: newAggFunc(ast.AggFuncSum, args, false)}, nil + aggF := newAggFunc(ast.AggFuncSum, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &sumFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_Count: - return &countFunction{aggFunction: newAggFunc(ast.AggFuncCount, args, false)}, nil + aggF := newAggFunc(ast.AggFuncCount, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &countFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_Avg: - return &avgFunction{aggFunction: newAggFunc(ast.AggFuncAvg, args, false)}, nil + aggF := newAggFunc(ast.AggFuncAvg, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &avgFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_GroupConcat: - return &concatFunction{aggFunction: newAggFunc(ast.AggFuncGroupConcat, args, false)}, nil + aggF := newAggFunc(ast.AggFuncGroupConcat, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &concatFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_Max: - return &maxMinFunction{aggFunction: newAggFunc(ast.AggFuncMax, args, false), isMax: true, ctor: collate.GetCollator(args[0].GetType().GetCollate())}, nil + aggF := newAggFunc(ast.AggFuncMax, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &maxMinFunction{aggFunction: aggF, isMax: true, ctor: collate.GetCollator(args[0].GetType().GetCollate())}, aggF.AggFuncDesc, nil case tipb.ExprType_Min: - return &maxMinFunction{aggFunction: newAggFunc(ast.AggFuncMin, args, false), ctor: collate.GetCollator(args[0].GetType().GetCollate())}, nil + aggF := newAggFunc(ast.AggFuncMin, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &maxMinFunction{aggFunction: aggF, ctor: collate.GetCollator(args[0].GetType().GetCollate())}, aggF.AggFuncDesc, nil case tipb.ExprType_First: - return &firstRowFunction{aggFunction: newAggFunc(ast.AggFuncFirstRow, args, false)}, nil + aggF := newAggFunc(ast.AggFuncFirstRow, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &firstRowFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_Agg_BitOr: - return &bitOrFunction{aggFunction: newAggFunc(ast.AggFuncBitOr, args, false)}, nil + aggF := newAggFunc(ast.AggFuncBitOr, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &bitOrFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_Agg_BitXor: - return &bitXorFunction{aggFunction: newAggFunc(ast.AggFuncBitXor, args, false)}, nil + aggF := newAggFunc(ast.AggFuncBitXor, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &bitXorFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_Agg_BitAnd: - return &bitAndFunction{aggFunction: newAggFunc(ast.AggFuncBitAnd, args, false)}, nil + aggF := newAggFunc(ast.AggFuncBitAnd, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &bitAndFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil } - return nil, errors.Errorf("Unknown aggregate function type %v", expr.Tp) + return nil, nil, errors.Errorf("Unknown aggregate function type %v", expr.Tp) } // AggEvaluateContext is used to store intermediate result when calculating aggregate functions. diff --git a/pkg/store/mockstore/mockcopr/cop_handler_dag.go b/pkg/store/mockstore/mockcopr/cop_handler_dag.go index 7d383a59c3745..8bdf3da053186 100644 --- a/pkg/store/mockstore/mockcopr/cop_handler_dag.go +++ b/pkg/store/mockstore/mockcopr/cop_handler_dag.go @@ -325,7 +325,7 @@ func (h coprHandler) getAggInfo(ctx *dagContext, executor *tipb.Executor) ([]agg var relatedColOffsets []int for _, expr := range executor.Aggregation.AggFunc { var aggExpr aggregation.Aggregation - aggExpr, err = aggregation.NewDistAggFunc(expr, ctx.evalCtx.fieldTps, ctx.evalCtx.sctx) + aggExpr, _, err = aggregation.NewDistAggFunc(expr, ctx.evalCtx.fieldTps, ctx.evalCtx.sctx) if err != nil { return nil, nil, nil, errors.Trace(err) } diff --git a/pkg/store/mockstore/unistore/cophandler/cop_handler.go b/pkg/store/mockstore/unistore/cophandler/cop_handler.go index 67d707f7b6c4a..20b8555d3c2a4 100644 --- a/pkg/store/mockstore/unistore/cophandler/cop_handler.go +++ b/pkg/store/mockstore/unistore/cophandler/cop_handler.go @@ -331,7 +331,7 @@ func getAggInfo(ctx *dagContext, pbAgg *tipb.Aggregation) ([]aggregation.Aggrega var err error for _, expr := range pbAgg.AggFunc { var aggExpr aggregation.Aggregation - aggExpr, err = aggregation.NewDistAggFunc(expr, ctx.fieldTps, ctx.sctx) + aggExpr, _, err = aggregation.NewDistAggFunc(expr, ctx.fieldTps, ctx.sctx) if err != nil { return nil, nil, errors.Trace(err) } diff --git a/pkg/store/mockstore/unistore/cophandler/mpp.go b/pkg/store/mockstore/unistore/cophandler/mpp.go index 92f5507c3484e..7ed0d5a61273c 100644 --- a/pkg/store/mockstore/unistore/cophandler/mpp.go +++ b/pkg/store/mockstore/unistore/cophandler/mpp.go @@ -79,6 +79,7 @@ func (b *mppExecBuilder) buildMPPTableScan(pb *tipb.TableScan) (*tableScanExec, ndvs: b.ndvs, desc: pb.Desc, paging: b.paging, + ID: pb.TableId, } if b.dagCtx != nil { ts.lockStore = b.dagCtx.lockStore @@ -106,6 +107,7 @@ func (b *mppExecBuilder) buildMPPPartitionTableScan(pb *tipb.PartitionTableScan) startTS: b.dagCtx.startTS, kvRanges: ranges, dbReader: b.dbReader, + ID: pb.TableId, } for i, col := range pb.Columns { if col.ColumnId == model.ExtraPhysTblID { @@ -509,13 +511,15 @@ func (b *mppExecBuilder) buildMPPAgg(agg *tipb.Aggregation) (*aggExec, error) { return nil, errors.Trace(err) } e.children = []mppExec{chExec} + tmpAggDescs := make([]*aggregation.AggFuncDesc, 0, len(agg.AggFunc)) for _, aggFunc := range agg.AggFunc { ft := expression.PbTypeToFieldType(aggFunc.FieldType) e.fieldTypes = append(e.fieldTypes, ft) - aggExpr, err := aggregation.NewDistAggFunc(aggFunc, chExec.getFieldTypes(), b.sctx) + aggExpr, aggDesc, err := aggregation.NewDistAggFunc(aggFunc, chExec.getFieldTypes(), b.sctx) if err != nil { return nil, errors.Trace(err) } + tmpAggDescs = append(tmpAggDescs, aggDesc) e.aggExprs = append(e.aggExprs, aggExpr) } e.sctx = b.sctx @@ -530,6 +534,26 @@ func (b *mppExecBuilder) buildMPPAgg(agg *tipb.Aggregation) (*aggExec, error) { } e.groupByExprs = append(e.groupByExprs, gbyExpr) } + + // fill the default value. logic copied from `func (b *executorBuilder) buildHashAgg` + if len(e.groupByExprs) != 0 || aggregation.IsAllFirstRow(tmpAggDescs) { + e.DefaultVal = nil + } else { + // Only do this for final agg, see issue #35295, #30923 + isFinalAgg := false + if len(tmpAggDescs) > 0 { + if tmpAggDescs[0].Mode == aggregation.FinalMode || tmpAggDescs[0].Mode == aggregation.CompleteMode { + isFinalAgg = true + } + } + if isFinalAgg { + e.DefaultVal = chunk.NewChunkWithCapacity(e.fieldTypes, 1) + for i, aggDesc := range tmpAggDescs { + result := aggDesc.GetDefaultValue() + e.DefaultVal.AppendDatum(i, &result) + } + } + } return e, nil } From d34744673f017c468381caf82632b98b3e31dc0c Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Sun, 4 Feb 2024 16:35:17 +0800 Subject: [PATCH 2/2] . Signed-off-by: AilinKid <314806019@qq.com> --- .../mockstore/unistore/cophandler/mpp.go | 26 +------------------ 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/pkg/store/mockstore/unistore/cophandler/mpp.go b/pkg/store/mockstore/unistore/cophandler/mpp.go index 7ed0d5a61273c..cbf3799c6ef70 100644 --- a/pkg/store/mockstore/unistore/cophandler/mpp.go +++ b/pkg/store/mockstore/unistore/cophandler/mpp.go @@ -79,7 +79,6 @@ func (b *mppExecBuilder) buildMPPTableScan(pb *tipb.TableScan) (*tableScanExec, ndvs: b.ndvs, desc: pb.Desc, paging: b.paging, - ID: pb.TableId, } if b.dagCtx != nil { ts.lockStore = b.dagCtx.lockStore @@ -107,7 +106,6 @@ func (b *mppExecBuilder) buildMPPPartitionTableScan(pb *tipb.PartitionTableScan) startTS: b.dagCtx.startTS, kvRanges: ranges, dbReader: b.dbReader, - ID: pb.TableId, } for i, col := range pb.Columns { if col.ColumnId == model.ExtraPhysTblID { @@ -511,15 +509,13 @@ func (b *mppExecBuilder) buildMPPAgg(agg *tipb.Aggregation) (*aggExec, error) { return nil, errors.Trace(err) } e.children = []mppExec{chExec} - tmpAggDescs := make([]*aggregation.AggFuncDesc, 0, len(agg.AggFunc)) for _, aggFunc := range agg.AggFunc { ft := expression.PbTypeToFieldType(aggFunc.FieldType) e.fieldTypes = append(e.fieldTypes, ft) - aggExpr, aggDesc, err := aggregation.NewDistAggFunc(aggFunc, chExec.getFieldTypes(), b.sctx) + aggExpr, _, err := aggregation.NewDistAggFunc(aggFunc, chExec.getFieldTypes(), b.sctx) if err != nil { return nil, errors.Trace(err) } - tmpAggDescs = append(tmpAggDescs, aggDesc) e.aggExprs = append(e.aggExprs, aggExpr) } e.sctx = b.sctx @@ -534,26 +530,6 @@ func (b *mppExecBuilder) buildMPPAgg(agg *tipb.Aggregation) (*aggExec, error) { } e.groupByExprs = append(e.groupByExprs, gbyExpr) } - - // fill the default value. logic copied from `func (b *executorBuilder) buildHashAgg` - if len(e.groupByExprs) != 0 || aggregation.IsAllFirstRow(tmpAggDescs) { - e.DefaultVal = nil - } else { - // Only do this for final agg, see issue #35295, #30923 - isFinalAgg := false - if len(tmpAggDescs) > 0 { - if tmpAggDescs[0].Mode == aggregation.FinalMode || tmpAggDescs[0].Mode == aggregation.CompleteMode { - isFinalAgg = true - } - } - if isFinalAgg { - e.DefaultVal = chunk.NewChunkWithCapacity(e.fieldTypes, 1) - for i, aggDesc := range tmpAggDescs { - result := aggDesc.GetDefaultValue() - e.DefaultVal.AppendDatum(i, &result) - } - } - } return e, nil }