From b3741f535132bc25bf4a3c8a6b87254ed4a259bf Mon Sep 17 00:00:00 2001 From: AilinKid <314806019@qq.com> Date: Mon, 5 Feb 2024 15:27:56 +0800 Subject: [PATCH] add default value logic for scalar final agg in mock mpp executor Signed-off-by: AilinKid <314806019@qq.com> --- .../mockstore/unistore/cophandler/mpp.go | 25 ++++++- .../mockstore/unistore/cophandler/mpp_exec.go | 66 ++++++++++++++----- 2 files changed, 75 insertions(+), 16 deletions(-) diff --git a/pkg/store/mockstore/unistore/cophandler/mpp.go b/pkg/store/mockstore/unistore/cophandler/mpp.go index cbf3799c6ef70..50679ec99a736 100644 --- a/pkg/store/mockstore/unistore/cophandler/mpp.go +++ b/pkg/store/mockstore/unistore/cophandler/mpp.go @@ -509,13 +509,16 @@ func (b *mppExecBuilder) buildMPPAgg(agg *tipb.Aggregation) (*aggExec, error) { return nil, errors.Trace(err) } e.children = []mppExec{chExec} + // restore the aggDesc from tipb-agg. + tmpAggDescs := make([]*aggregation.AggFuncDesc, 0, len(agg.AggFunc)) for _, aggFunc := range agg.AggFunc { ft := expression.PbTypeToFieldType(aggFunc.FieldType) e.fieldTypes = append(e.fieldTypes, ft) - aggExpr, _, err := aggregation.NewDistAggFunc(aggFunc, chExec.getFieldTypes(), b.sctx) + aggExpr, aggDesc, err := aggregation.NewDistAggFunc(aggFunc, chExec.getFieldTypes(), b.sctx) if err != nil { return nil, errors.Trace(err) } + tmpAggDescs = append(tmpAggDescs, aggDesc) e.aggExprs = append(e.aggExprs, aggExpr) } e.sctx = b.sctx @@ -530,6 +533,26 @@ func (b *mppExecBuilder) buildMPPAgg(agg *tipb.Aggregation) (*aggExec, error) { } e.groupByExprs = append(e.groupByExprs, gbyExpr) } + + // fill the default value. logic copied from `func (b *executorBuilder) buildHashAgg` + if len(e.groupByExprs) != 0 || aggregation.IsAllFirstRow(tmpAggDescs) { + e.DefaultVal = nil + } else { + // Only do this for final agg, see issue #35295, #30923 + isFinalAgg := false + if len(tmpAggDescs) > 0 { + if tmpAggDescs[0].Mode == aggregation.FinalMode || tmpAggDescs[0].Mode == aggregation.CompleteMode { + isFinalAgg = true + } + } + if isFinalAgg { + e.DefaultVal = chunk.NewChunkWithCapacity(e.fieldTypes, 1) + for i, aggDesc := range tmpAggDescs { + result := aggDesc.GetDefaultValue() + e.DefaultVal.AppendDatum(i, &result) + } + } + } return e, nil } diff --git a/pkg/store/mockstore/unistore/cophandler/mpp_exec.go b/pkg/store/mockstore/unistore/cophandler/mpp_exec.go index ca2c7d600c424..80034b839eaf3 100644 --- a/pkg/store/mockstore/unistore/cophandler/mpp_exec.go +++ b/pkg/store/mockstore/unistore/cophandler/mpp_exec.go @@ -995,6 +995,7 @@ type aggExec struct { groupByRows []chunk.Row groupByTypes []*types.FieldType + DefaultVal *chunk.Chunk processed bool } @@ -1039,6 +1040,29 @@ func (e *aggExec) getContexts(groupKey []byte) []*aggregation.AggEvaluateContext return aggCtxs } +// processAllRows handles the aggregation logic inside. +// Special case for first_row in scalar agg case: +// 1: If all the aggregation functions are first_row/any_value, we should fill nothing, and let it empty. +// 2: If there exists some other non-first-value aggregations, we should fill the default value for both of them. +// +// mysql> select any_value(a) from (select * from t3) s; +// Empty set (0.01 sec) +// +// mysql> select count(a) from (select * from t3) s; +// +----------+ +// | count(a) | +// +----------+ +// | 0 | +// +----------+ +// 1 row in set (0.01 sec) +// +// mysql> select count(a), any_value(a) from (select * from t3) s; +// +----------+--------------+ +// | count(a) | any_value(a) | +// +----------+--------------+ +// | 0 | NULL | +// +----------+--------------+ +// 1 row in set (0.01 sec) func (e *aggExec) processAllRows() (*chunk.Chunk, error) { for { chk, err := e.children[0].next() @@ -1075,24 +1099,36 @@ func (e *aggExec) processAllRows() (*chunk.Chunk, error) { chk := chunk.NewChunkWithCapacity(e.fieldTypes, 0) - for i, gk := range e.groupKeys { - newRow := chunk.MutRowFromTypes(e.fieldTypes) - aggCtxs := e.getContexts(gk) - for i, agg := range e.aggExprs { - result := agg.GetResult(aggCtxs[i]) - if e.fieldTypes[i].GetType() == mysql.TypeLonglong && result.Kind() == types.KindMysqlDecimal { - var err error - result, err = result.ConvertTo(e.sctx.GetSessionVars().StmtCtx.TypeCtx(), e.fieldTypes[i]) - if err != nil { - return nil, errors.Trace(err) + // where len(e.groupKeys) equals to 0, that means there is no data in the below child source. + // And when e.DefaultVal is not nil, it means it's a scalar agg. Some default value should be + // filled. + // In the contrary with those even without group by items, the whole data(not empty) will be + // seen as one group, and classified into that group with key built as "". So its groupKeys here + // is not the size of 0. + // 1: len(e.groupKeys) == 0 indicates whether the source data is equal to empty-set. + // 2: e.DefaultVal != nil indicates whether this aggregate is a scalar aggregation. + if len(e.groupKeys) == 0 && e.DefaultVal != nil { + chk.Append(e.DefaultVal, 0, 1) + } else { + for i, gk := range e.groupKeys { + newRow := chunk.MutRowFromTypes(e.fieldTypes) + aggCtxs := e.getContexts(gk) + for i, agg := range e.aggExprs { + result := agg.GetResult(aggCtxs[i]) + if e.fieldTypes[i].GetType() == mysql.TypeLonglong && result.Kind() == types.KindMysqlDecimal { + var err error + result, err = result.ConvertTo(e.sctx.GetSessionVars().StmtCtx.TypeCtx(), e.fieldTypes[i]) + if err != nil { + return nil, errors.Trace(err) + } } + newRow.SetDatum(i, result) } - newRow.SetDatum(i, result) - } - if len(e.groupByRows) > 0 { - newRow.ShallowCopyPartialRow(len(e.aggExprs), e.groupByRows[i]) + if len(e.groupByRows) > 0 { + newRow.ShallowCopyPartialRow(len(e.aggExprs), e.groupByRows[i]) + } + chk.AppendRow(newRow.ToRow()) } - chk.AppendRow(newRow.ToRow()) } e.execSummary.updateOnlyRows(chk.NumRows()) return chk, nil