Skip to content

Commit

Permalink
executor: introduce a new execution framework for aggregate functions (
Browse files Browse the repository at this point in the history
  • Loading branch information
zz-jason committed Jun 29, 2018
1 parent 5a5aeb8 commit 3c05d77
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 11 deletions.
82 changes: 82 additions & 0 deletions executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"unsafe"

"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)

// All the AggFunc implementations are listed here for navigation.
var (
// All the AggFunc implementations for "COUNT" are listed here.
// All the AggFunc implementations for "SUM" are listed here.
// All the AggFunc implementations for "AVG" are listed here.
// All the AggFunc implementations for "FIRSTROW" are listed here.
// All the AggFunc implementations for "MAX" are listed here.
// All the AggFunc implementations for "MIN" are listed here.
// All the AggFunc implementations for "GROUP_CONCAT" are listed here.
// All the AggFunc implementations for "BIT_OR" are listed here.
// All the AggFunc implementations for "BIT_XOR" are listed here.
// All the AggFunc implementations for "BIT_AND" are listed here.
)

// PartialResult represents data structure to store the partial result for the
// aggregate functions. Here we use unsafe.Pointer to allow the partial result
// to be any type.
type PartialResult unsafe.Pointer

// AggFunc is the interface to evaluate the aggregate functions.
type AggFunc interface {
// AllocPartialResult allocates a specific data structure to store the
// partial result, initializes it, and converts it to PartialResult to
// return back. Aggregate operator implementation, no matter it's a hash
// or stream, should hold this allocated PartialResult for the further
// operations like: "ResetPartialResult", "UpdatePartialResult".
AllocPartialResult() PartialResult

// ResetPartialResult resets the partial result to the original state for a
// specific aggregate function. It converts the input PartialResult to the
// specific data structure which stores the partial result and then reset
// every field to the proper original state.
ResetPartialResult(pr PartialResult)

// UpdatePartialResult updates the specific partial result for an aggregate
// function using the input rows which all belonging to the same data group.
// It converts the PartialResult to the specific data structure which stores
// the partial result and then iterates on the input rows and update that
// partial result according to the functionality and the state of the
// aggregate function.
UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error

// AppendFinalResult2Chunk finalizes the partial result and append the
// final result to the input chunk. Like other operations, it converts the
// input PartialResult to the specific data structure which stores the
// partial result and then calculates the final result and append that
// final result to the chunk provided.
AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error
}

type baseAggFunc struct {
// args stores the input arguments for an aggregate function, we should
// call arg.EvalXXX to get the actual input data for this function.
args []expression.Expression

// ordinal stores the ordinal of the columns in the output chunk, which is
// used to append the final result of this function.
ordinal int
}
97 changes: 97 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/expression/aggregation"
)

// Build is used to build a specific AggFunc implementation according to the
// input aggFuncDesc.
func Build(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
switch aggFuncDesc.Name {
case ast.AggFuncCount:
return buildCount(aggFuncDesc, ordinal)
case ast.AggFuncSum:
return buildSum(aggFuncDesc, ordinal)
case ast.AggFuncAvg:
return buildAvg(aggFuncDesc, ordinal)
case ast.AggFuncFirstRow:
return buildFirstRow(aggFuncDesc, ordinal)
case ast.AggFuncMax:
return buildMax(aggFuncDesc, ordinal)
case ast.AggFuncMin:
return buildMin(aggFuncDesc, ordinal)
case ast.AggFuncGroupConcat:
return buildGroupConcat(aggFuncDesc, ordinal)
case ast.AggFuncBitOr:
return buildBitOr(aggFuncDesc, ordinal)
case ast.AggFuncBitXor:
return buildBitXor(aggFuncDesc, ordinal)
case ast.AggFuncBitAnd:
return buildBitAnd(aggFuncDesc, ordinal)
}
return nil
}

// buildCount builds the AggFunc implementation for function "COUNT".
func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return nil
}

// buildCount builds the AggFunc implementation for function "SUM".
func buildSum(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return nil
}

// buildCount builds the AggFunc implementation for function "AVG".
func buildAvg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return nil
}

// buildCount builds the AggFunc implementation for function "FIRST_ROW".
func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return nil
}

// buildCount builds the AggFunc implementation for function "MAX".
func buildMax(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return nil
}

// buildCount builds the AggFunc implementation for function "MIN".
func buildMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return nil
}

// buildCount builds the AggFunc implementation for function "GROUP_CONCAT".
func buildGroupConcat(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return nil
}

// buildCount builds the AggFunc implementation for function "BIT_OR".
func buildBitOr(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return nil
}

// buildCount builds the AggFunc implementation for function "BIT_XOR".
func buildBitXor(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return nil
}

// buildCount builds the AggFunc implementation for function "BIT_AND".
func buildBitAnd(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return nil
}
81 changes: 71 additions & 10 deletions executor/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package executor

import (
"github.com/juju/errors"
"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/mysql"
Expand Down Expand Up @@ -189,11 +190,15 @@ type StreamAggExec struct {
curGroupKey []types.Datum
tmpGroupKey []types.Datum

// for chunk execution.
inputIter *chunk.Iterator4Chunk
inputRow chunk.Row
mutableRow chunk.MutRow
rowBuffer []types.Datum

// for the new execution framework of aggregate functions
newAggFuncs []aggfuncs.AggFunc
partialResults []aggfuncs.PartialResult
groupRows []chunk.Row
}

// Open implements the Executor Open interface.
Expand All @@ -209,9 +214,16 @@ func (e *StreamAggExec) Open(ctx context.Context) error {
e.mutableRow = chunk.MutRowFromTypes(e.retTypes())
e.rowBuffer = make([]types.Datum, 0, e.Schema().Len())

e.aggCtxs = make([]*aggregation.AggEvaluateContext, 0, len(e.AggFuncs))
for _, agg := range e.AggFuncs {
e.aggCtxs = append(e.aggCtxs, agg.CreateContext(e.ctx.GetSessionVars().StmtCtx))
if e.newAggFuncs != nil {
e.partialResults = make([]aggfuncs.PartialResult, 0, len(e.newAggFuncs))
for _, newAggFunc := range e.newAggFuncs {
e.partialResults = append(e.partialResults, newAggFunc.AllocPartialResult())
}
} else {
e.aggCtxs = make([]*aggregation.AggEvaluateContext, 0, len(e.AggFuncs))
for _, agg := range e.AggFuncs {
e.aggCtxs = append(e.aggCtxs, agg.CreateContext(e.ctx.GetSessionVars().StmtCtx))
}
}

return nil
Expand Down Expand Up @@ -242,23 +254,55 @@ func (e *StreamAggExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) e
return errors.Trace(err)
}
if meetNewGroup {
e.appendResult2Chunk(chk)
}
for i, af := range e.AggFuncs {
err := af.Update(e.aggCtxs[i], e.StmtCtx, e.inputRow)
err := e.consumeGroupRows()
if err != nil {
return errors.Trace(err)
}
err = e.appendResult2Chunk(chk)
if err != nil {
return errors.Trace(err)
}
}
if e.newAggFuncs != nil {
e.groupRows = append(e.groupRows, e.inputRow)
} else {
for i, af := range e.AggFuncs {
err := af.Update(e.aggCtxs[i], e.StmtCtx, e.inputRow)
if err != nil {
return errors.Trace(err)
}
}
}
if meetNewGroup {
e.inputRow = e.inputIter.Next()
return nil
}
}
if e.newAggFuncs != nil {
err := e.consumeGroupRows()
if err != nil {
return errors.Trace(err)
}
}
}
return nil
}

func (e *StreamAggExec) consumeGroupRows() error {
if len(e.groupRows) == 0 {
return nil
}

for i, newAggFunc := range e.newAggFuncs {
err := newAggFunc.UpdatePartialResult(e.ctx, e.groupRows, e.partialResults[i])
if err != nil {
return errors.Trace(err)
}
}
e.groupRows = e.groupRows[:0]
return nil
}

func (e *StreamAggExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Chunk) error {
if e.inputRow != e.inputIter.End() {
return nil
Expand All @@ -271,7 +315,10 @@ func (e *StreamAggExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Ch
// No more data.
if e.childrenResults[0].NumRows() == 0 {
if e.hasData || len(e.GroupByItems) == 0 {
e.appendResult2Chunk(chk)
err := e.appendResult2Chunk(chk)
if err != nil {
return errors.Trace(err)
}
}
e.executed = true
return nil
Expand All @@ -285,14 +332,28 @@ func (e *StreamAggExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Ch

// appendResult2Chunk appends result of all the aggregation functions to the
// result chunk, and reset the evaluation context for each aggregation.
func (e *StreamAggExec) appendResult2Chunk(chk *chunk.Chunk) {
func (e *StreamAggExec) appendResult2Chunk(chk *chunk.Chunk) error {
if e.newAggFuncs != nil {
for i, newAggFunc := range e.newAggFuncs {
err := newAggFunc.AppendFinalResult2Chunk(e.ctx, e.partialResults[i], chk)
if err != nil {
return errors.Trace(err)
}
newAggFunc.ResetPartialResult(e.partialResults[i])
}
if len(e.newAggFuncs) == 0 {
chk.SetNumVirtualRows(chk.NumRows() + 1)
}
return nil
}
e.rowBuffer = e.rowBuffer[:0]
for i, af := range e.AggFuncs {
e.rowBuffer = append(e.rowBuffer, af.GetResult(e.aggCtxs[i]))
af.ResetContext(e.ctx.GetSessionVars().StmtCtx, e.aggCtxs[i])
}
e.mutableRow.SetDatums(e.rowBuffer...)
chk.AppendRow(e.mutableRow.ToRow())
return nil
}

// meetNewGroup returns a value that represents if the new group is different from last group.
Expand Down
15 changes: 14 additions & 1 deletion executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/distsql"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/infoschema"
Expand Down Expand Up @@ -919,8 +920,20 @@ func (b *executorBuilder) buildStreamAgg(v *plan.PhysicalStreamAgg) Executor {
AggFuncs: make([]aggregation.Aggregation, 0, len(v.AggFuncs)),
GroupByItems: v.GroupByItems,
}
for _, aggDesc := range v.AggFuncs {
newAggFuncs := make([]aggfuncs.AggFunc, 0, len(v.AggFuncs))
for i, aggDesc := range v.AggFuncs {
e.AggFuncs = append(e.AggFuncs, aggDesc.GetAggFunc())
newAggFunc := aggfuncs.Build(aggDesc, i)
if newAggFunc != nil {
newAggFuncs = append(newAggFuncs, newAggFunc)
}
}

// Once we have successfully build all the aggregate functions to the new
// aggregate function execution framework, we can store them to the stream
// aggregate operator to indicate it using the new execution framework.
if len(newAggFuncs) == len(v.AggFuncs) {
e.newAggFuncs = newAggFuncs
}
metrics.ExecutorCounter.WithLabelValues("StreamAggExec").Inc()
return e
Expand Down

0 comments on commit 3c05d77

Please sign in to comment.