diff --git a/pkg/expression/aggregation/aggregation.go b/pkg/expression/aggregation/aggregation.go index 0455ba1b284f2..6db82f2761119 100644 --- a/pkg/expression/aggregation/aggregation.go +++ b/pkg/expression/aggregation/aggregation.go @@ -51,38 +51,58 @@ type Aggregation interface { } // NewDistAggFunc creates new Aggregate function for mock tikv. -func NewDistAggFunc(expr *tipb.Expr, fieldTps []*types.FieldType, ctx sessionctx.Context) (Aggregation, error) { +func NewDistAggFunc(expr *tipb.Expr, fieldTps []*types.FieldType, ctx sessionctx.Context) (Aggregation, *AggFuncDesc, error) { args := make([]expression.Expression, 0, len(expr.Children)) for _, child := range expr.Children { arg, err := expression.PBToExpr(ctx, child, fieldTps) if err != nil { - return nil, err + return nil, nil, err } args = append(args, arg) } switch expr.Tp { case tipb.ExprType_Sum: - return &sumFunction{aggFunction: newAggFunc(ast.AggFuncSum, args, false)}, nil + aggF := newAggFunc(ast.AggFuncSum, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &sumFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_Count: - return &countFunction{aggFunction: newAggFunc(ast.AggFuncCount, args, false)}, nil + aggF := newAggFunc(ast.AggFuncCount, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &countFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_Avg: - return &avgFunction{aggFunction: newAggFunc(ast.AggFuncAvg, args, false)}, nil + aggF := newAggFunc(ast.AggFuncAvg, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &avgFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_GroupConcat: - return &concatFunction{aggFunction: newAggFunc(ast.AggFuncGroupConcat, args, false)}, nil + aggF := newAggFunc(ast.AggFuncGroupConcat, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &concatFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_Max: - return &maxMinFunction{aggFunction: newAggFunc(ast.AggFuncMax, args, false), isMax: true, ctor: collate.GetCollator(args[0].GetType().GetCollate())}, nil + aggF := newAggFunc(ast.AggFuncMax, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &maxMinFunction{aggFunction: aggF, isMax: true, ctor: collate.GetCollator(args[0].GetType().GetCollate())}, aggF.AggFuncDesc, nil case tipb.ExprType_Min: - return &maxMinFunction{aggFunction: newAggFunc(ast.AggFuncMin, args, false), ctor: collate.GetCollator(args[0].GetType().GetCollate())}, nil + aggF := newAggFunc(ast.AggFuncMin, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &maxMinFunction{aggFunction: aggF, ctor: collate.GetCollator(args[0].GetType().GetCollate())}, aggF.AggFuncDesc, nil case tipb.ExprType_First: - return &firstRowFunction{aggFunction: newAggFunc(ast.AggFuncFirstRow, args, false)}, nil + aggF := newAggFunc(ast.AggFuncFirstRow, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &firstRowFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_Agg_BitOr: - return &bitOrFunction{aggFunction: newAggFunc(ast.AggFuncBitOr, args, false)}, nil + aggF := newAggFunc(ast.AggFuncBitOr, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &bitOrFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_Agg_BitXor: - return &bitXorFunction{aggFunction: newAggFunc(ast.AggFuncBitXor, args, false)}, nil + aggF := newAggFunc(ast.AggFuncBitXor, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &bitXorFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil case tipb.ExprType_Agg_BitAnd: - return &bitAndFunction{aggFunction: newAggFunc(ast.AggFuncBitAnd, args, false)}, nil + aggF := newAggFunc(ast.AggFuncBitAnd, args, false) + aggF.Mode = AggFunctionMode(*expr.AggFuncMode) + return &bitAndFunction{aggFunction: aggF}, aggF.AggFuncDesc, nil } - return nil, errors.Errorf("Unknown aggregate function type %v", expr.Tp) + return nil, nil, errors.Errorf("Unknown aggregate function type %v", expr.Tp) } // AggEvaluateContext is used to store intermediate result when calculating aggregate functions. diff --git a/pkg/store/mockstore/mockcopr/cop_handler_dag.go b/pkg/store/mockstore/mockcopr/cop_handler_dag.go index 7d383a59c3745..8bdf3da053186 100644 --- a/pkg/store/mockstore/mockcopr/cop_handler_dag.go +++ b/pkg/store/mockstore/mockcopr/cop_handler_dag.go @@ -325,7 +325,7 @@ func (h coprHandler) getAggInfo(ctx *dagContext, executor *tipb.Executor) ([]agg var relatedColOffsets []int for _, expr := range executor.Aggregation.AggFunc { var aggExpr aggregation.Aggregation - aggExpr, err = aggregation.NewDistAggFunc(expr, ctx.evalCtx.fieldTps, ctx.evalCtx.sctx) + aggExpr, _, err = aggregation.NewDistAggFunc(expr, ctx.evalCtx.fieldTps, ctx.evalCtx.sctx) if err != nil { return nil, nil, nil, errors.Trace(err) } diff --git a/pkg/store/mockstore/unistore/cophandler/cop_handler.go b/pkg/store/mockstore/unistore/cophandler/cop_handler.go index 67d707f7b6c4a..20b8555d3c2a4 100644 --- a/pkg/store/mockstore/unistore/cophandler/cop_handler.go +++ b/pkg/store/mockstore/unistore/cophandler/cop_handler.go @@ -331,7 +331,7 @@ func getAggInfo(ctx *dagContext, pbAgg *tipb.Aggregation) ([]aggregation.Aggrega var err error for _, expr := range pbAgg.AggFunc { var aggExpr aggregation.Aggregation - aggExpr, err = aggregation.NewDistAggFunc(expr, ctx.fieldTps, ctx.sctx) + aggExpr, _, err = aggregation.NewDistAggFunc(expr, ctx.fieldTps, ctx.sctx) if err != nil { return nil, nil, errors.Trace(err) } diff --git a/pkg/store/mockstore/unistore/cophandler/mpp.go b/pkg/store/mockstore/unistore/cophandler/mpp.go index 92f5507c3484e..cbf3799c6ef70 100644 --- a/pkg/store/mockstore/unistore/cophandler/mpp.go +++ b/pkg/store/mockstore/unistore/cophandler/mpp.go @@ -512,7 +512,7 @@ func (b *mppExecBuilder) buildMPPAgg(agg *tipb.Aggregation) (*aggExec, error) { for _, aggFunc := range agg.AggFunc { ft := expression.PbTypeToFieldType(aggFunc.FieldType) e.fieldTypes = append(e.fieldTypes, ft) - aggExpr, err := aggregation.NewDistAggFunc(aggFunc, chExec.getFieldTypes(), b.sctx) + aggExpr, _, err := aggregation.NewDistAggFunc(aggFunc, chExec.getFieldTypes(), b.sctx) if err != nil { return nil, errors.Trace(err) }