executor: support window func for aggregate without frame clause #8899

Merged Jan 15, 2019 · 18 commits · Changes from 17 commits
24 changes: 18 additions & 6 deletions executor/aggregate.go
@@ -760,10 +760,7 @@ type StreamAggExec struct {
// isChildReturnEmpty indicates whether the child executor only returns an empty input.
isChildReturnEmpty bool
defaultVal *chunk.Chunk
StmtCtx *stmtctx.StatementContext
GroupByItems []expression.Expression
curGroupKey []types.Datum
tmpGroupKey []types.Datum
groupChecker *groupChecker
inputIter *chunk.Iterator4Chunk
inputRow chunk.Row
aggFuncs []aggfuncs.AggFunc
@@ -824,7 +821,7 @@ func (e *StreamAggExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) error {
return errors.Trace(err)
}
for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() {
meetNewGroup, err := e.meetNewGroup(e.inputRow)
meetNewGroup, err := e.groupChecker.meetNewGroup(e.inputRow)
if err != nil {
return errors.Trace(err)
}
@@ -911,8 +908,23 @@ func (e *StreamAggExec) appendResult2Chunk(chk *chunk.Chunk) error {
return nil
}

type groupChecker struct {
StmtCtx *stmtctx.StatementContext
GroupByItems []expression.Expression
curGroupKey []types.Datum
tmpGroupKey []types.Datum
}

func newGroupChecker(stmtCtx *stmtctx.StatementContext, items []expression.Expression) *groupChecker {
return &groupChecker{
StmtCtx: stmtCtx,
GroupByItems: items,
}
}

// meetNewGroup returns a value that represents if the new group is different from last group.
func (e *StreamAggExec) meetNewGroup(row chunk.Row) (bool, error) {
// TODO: Since all the group by items are only a column reference, guaranteed by building projection below aggregation, we can directly compare data in a chunk.
func (e *groupChecker) meetNewGroup(row chunk.Row) (bool, error) {
if len(e.GroupByItems) == 0 {
return false, nil
}
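For orientation, the refactor above extracts the group-boundary state out of StreamAggExec so the new WindowExec (executor/window.go below) can reuse it. Below is a minimal sketch of how a consumer drives the extracted groupChecker, mirroring consumeOneGroup above; drainOneChunk and onGroupEnd are hypothetical names for illustration only, and the snippet assumes the executor package context (github.com/pingcap/errors and util/chunk imports) rather than being code from this PR.

// Sketch, not PR code: walk one chunk of child input and fire a callback at
// every group boundary detected by the groupChecker.
func drainOneChunk(gc *groupChecker, iter *chunk.Iterator4Chunk, onGroupEnd func() error) error {
	for row := iter.Begin(); row != iter.End(); row = iter.Next() {
		meetNewGroup, err := gc.meetNewGroup(row)
		if err != nil {
			return errors.Trace(err)
		}
		if meetNewGroup {
			// The rows seen since the previous boundary form a complete group:
			// finalize that group before accumulating the row that starts the next one.
			if err := onGroupEnd(); err != nil {
				return errors.Trace(err)
			}
		}
		// Accumulate row into the current group here.
	}
	return nil
}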
78 changes: 28 additions & 50 deletions executor/builder.go
@@ -167,6 +167,8 @@ func (b *executorBuilder) build(p plannercore.Plan) Executor {
return b.buildIndexReader(v)
case *plannercore.PhysicalIndexLookUpReader:
return b.buildIndexLookUpReader(v)
case *plannercore.PhysicalWindow:
return b.buildWindow(v)
default:
b.err = ErrUnknownPlan.GenWithStack("Unknown Plan %T", p)
return nil
@@ -922,54 +924,6 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executor {
return e
}

// wrapCastForAggArgs wraps the args of an aggregate function with a cast function.
func (b *executorBuilder) wrapCastForAggArgs(funcs []*aggregation.AggFuncDesc) {
for _, f := range funcs {
// We do not need to wrap cast upon these functions,
// since the EvalXXX method called by the arg is determined by the corresponding arg type.
if f.Name == ast.AggFuncCount || f.Name == ast.AggFuncMin || f.Name == ast.AggFuncMax || f.Name == ast.AggFuncFirstRow {
continue
}
var castFunc func(ctx sessionctx.Context, expr expression.Expression) expression.Expression
switch retTp := f.RetTp; retTp.EvalType() {
case types.ETInt:
castFunc = expression.WrapWithCastAsInt
case types.ETReal:
castFunc = expression.WrapWithCastAsReal
case types.ETString:
castFunc = expression.WrapWithCastAsString
case types.ETDecimal:
castFunc = expression.WrapWithCastAsDecimal
default:
panic("should never happen in executorBuilder.wrapCastForAggArgs")
}
for i := range f.Args {
f.Args[i] = castFunc(b.ctx, f.Args[i])
if f.Name != ast.AggFuncAvg && f.Name != ast.AggFuncSum {
continue
}
// After wrapping cast on the argument, flen etc. may not the same
// as the type of the aggregation function. The following part set
// the type of the argument exactly as the type of the aggregation
// function.
// Note: If the `Tp` of argument is the same as the `Tp` of the
// aggregation function, it will not wrap cast function on it
// internally. The reason of the special handling for `Column` is
// that the `RetType` of `Column` refers to the `infoschema`, so we
// need to set a new variable for it to avoid modifying the
// definition in `infoschema`.
if col, ok := f.Args[i].(*expression.Column); ok {
col.RetType = types.NewFieldType(col.RetType.Tp)
}
// originTp is used when the the `Tp` of column is TypeFloat32 while
// the type of the aggregation function is TypeFloat64.
originTp := f.Args[i].GetType().Tp
*(f.Args[i].GetType()) = *(f.RetTp)
f.Args[i].GetType().Tp = originTp
}
}
}

func (b *executorBuilder) buildHashAgg(v *plannercore.PhysicalHashAgg) Executor {
src := b.build(v.Children()[0])
if b.err != nil {
@@ -1059,9 +1013,8 @@ func (b *executorBuilder) buildStreamAgg(v *plannercore.PhysicalStreamAgg) Executor {
}
e := &StreamAggExec{
baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), src),
StmtCtx: b.ctx.GetSessionVars().StmtCtx,
groupChecker: newGroupChecker(b.ctx.GetSessionVars().StmtCtx, v.GroupByItems),
aggFuncs: make([]aggfuncs.AggFunc, 0, len(v.AggFuncs)),
GroupByItems: v.GroupByItems,
}
if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) {
e.defaultVal = nil
@@ -1926,3 +1879,28 @@ func buildKvRangesForIndexJoin(sc *stmtctx.StatementContext, tableID, indexID in
})
return kvRanges, nil
}

func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec {
childExec := b.build(v.Children()[0])
if b.err != nil {
return nil
}
base := newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), childExec)
var groupByItems []expression.Expression
for _, item := range v.PartitionBy {
groupByItems = append(groupByItems, item.Col)
}
aggDesc := aggregation.NewAggFuncDesc(b.ctx, v.WindowFuncDesc.Name, v.WindowFuncDesc.Args, false)
resultColIdx := len(v.Schema().Columns) - 1
agg := aggfuncs.Build(b.ctx, aggDesc, resultColIdx)
if agg == nil {
b.err = errors.Trace(errors.New("window evaluator only support aggregation functions without frame now"))
return nil
}
e := &WindowExec{baseExecutor: base,
windowFunc: agg,
partialResult: agg.AllocPartialResult(),
groupChecker: newGroupChecker(b.ctx.GetSessionVars().StmtCtx, groupByItems),
}
return e
}
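To make the schema handling in buildWindow concrete, here is the layout it relies on for a hypothetical query; the query and the annotations are illustrative, inferred from the code above and from copyChk in executor/window.go below.

// Hypothetical example: SELECT a, b, SUM(a) OVER (PARTITION BY b) FROM t
//
//   v.Schema().Columns = [a, b, sum(a) over (partition by b)]  // child columns plus one appended result column
//   resultColIdx       = len(v.Schema().Columns) - 1           // = 2, the column WindowExec writes into
//   v.PartitionBy      = [b]                                   // each item.Col becomes a groupByItem for the groupChecker
//
// copyChk in executor/window.go references only Columns[:len-1] back to the child
// chunk, leaving the last column for AppendFinalResult2Chunk to fill.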
176 changes: 176 additions & 0 deletions executor/window.go
@@ -0,0 +1,176 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package executor

import (
"context"
"time"

"github.com/opentracing/opentracing-go"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/util/chunk"
)

// WindowExec is the executor for window functions. Note that it only supports aggregation without frame clause now.
type WindowExec struct {
baseExecutor

groupChecker *groupChecker
inputIter *chunk.Iterator4Chunk
inputRow chunk.Row
groupRows []chunk.Row
childResults []*chunk.Chunk
windowFunc aggfuncs.AggFunc
partialResult aggfuncs.PartialResult
executed bool
meetNewGroup bool
remainingRowsInGroup int64
remainingRowsInChunk int
}

// Close implements the Executor Close interface.
func (e *WindowExec) Close() error {
e.childResults = nil
return errors.Trace(e.baseExecutor.Close())
}

// Next implements the Executor Next interface.
func (e *WindowExec) Next(ctx context.Context, chk *chunk.RecordBatch) error {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span1 := span.Tracer().StartSpan("windowExec.Next", opentracing.ChildOf(span.Context()))
defer span1.Finish()
}
if e.runtimeStats != nil {
start := time.Now()
defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }()
}
chk.Reset()
if e.meetNewGroup && e.remainingRowsInGroup > 0 {
err := e.appendResult2Chunk(chk.Chunk)
if err != nil {
return err
}
}
for !e.executed && (chk.NumRows() == 0 || e.remainingRowsInChunk > 0) {
err := e.consumeOneGroup(ctx, chk.Chunk)
if err != nil {
e.executed = true
return errors.Trace(err)
}
}
return nil
}

func (e *WindowExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) error {
var err error
if err = e.fetchChildIfNecessary(ctx, chk); err != nil {
return errors.Trace(err)
}
for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() {
e.meetNewGroup, err = e.groupChecker.meetNewGroup(e.inputRow)
if err != nil {
return errors.Trace(err)
}
if e.meetNewGroup {
err := e.consumeGroupRows()
if err != nil {
return errors.Trace(err)
}
err = e.appendResult2Chunk(chk)
if err != nil {
return errors.Trace(err)
}
}
e.groupRows = append(e.groupRows, e.inputRow)
if e.meetNewGroup {
e.inputRow = e.inputIter.Next()
return nil
}
}
return nil
}

func (e *WindowExec) consumeGroupRows() error {
if len(e.groupRows) == 0 {
return nil
}
err := e.windowFunc.UpdatePartialResult(e.ctx, e.groupRows, e.partialResult)
if err != nil {
return errors.Trace(err)
}
e.remainingRowsInGroup += int64(len(e.groupRows))
e.groupRows = e.groupRows[:0]
return nil
}

func (e *WindowExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Chunk) (err error) {
if e.inputIter != nil && e.inputRow != e.inputIter.End() {
return nil
}

// Before fetching a new batch of input, we should consume the last group rows.
err = e.consumeGroupRows()
if err != nil {
return errors.Trace(err)
}

childResult := e.children[0].newFirstChunk()
[Review thread on childResult := e.children[0].newFirstChunk(); a sketch of the growth strategy discussed here follows this file's diff.]
Reviewer: Should we allocate a Chunk with tidb_max_chunk_size capacity?
Author: tidb_max_chunk_size may produce unnecessary results.
Reviewer: Yeah, to avoid unnecessary results we should use a smaller chunk size for the first Chunk used to fetch the child result, but later fetches should use a bigger chunk size. This function always uses newFirstChunk() to fetch the child result, which is not efficient for big tables.
Author: I tried the approach mentioned in the comment; currently we need to call chunk.Renew to create a larger chunk, and that requires remembering the last chunk. Maybe we can refine it after we can pass the chunk size down to the child.
Reviewer: Okay.
Reviewer: Could you file an issue to track this problem?
Author: Done, #9045

err = e.children[0].Next(ctx, &chunk.RecordBatch{Chunk: childResult})
if err != nil {
return errors.Trace(err)
}
e.childResults = append(e.childResults, childResult)
// No more data.
if childResult.NumRows() == 0 {
e.executed = true
err = e.appendResult2Chunk(chk)
return errors.Trace(err)
}

e.inputIter = chunk.NewIterator4Chunk(childResult)
e.inputRow = e.inputIter.Begin()
return nil
}

// appendResult2Chunk appends result of the window function to the result chunk.
func (e *WindowExec) appendResult2Chunk(chk *chunk.Chunk) error {
e.copyChk(chk)
for e.remainingRowsInGroup > 0 && e.remainingRowsInChunk > 0 {
// TODO: We can extend the agg func interface to avoid the `for` loop here.
err := e.windowFunc.AppendFinalResult2Chunk(e.ctx, e.partialResult, chk)
if err != nil {
return err
}
e.remainingRowsInGroup--
e.remainingRowsInChunk--
}
if e.remainingRowsInGroup == 0 {
e.windowFunc.ResetPartialResult(e.partialResult)
}
return nil
}

func (e *WindowExec) copyChk(chk *chunk.Chunk) {
if len(e.childResults) == 0 || chk.NumRows() > 0 {
return
}
childResult := e.childResults[0]
e.childResults = e.childResults[1:]
e.remainingRowsInChunk = childResult.NumRows()
columns := e.Schema().Columns[:len(e.Schema().Columns)-1]
for i, col := range columns {
chk.MakeRefTo(i, childResult, col.Index)
}
}
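Below is a sketch of the chunk-growth strategy discussed in the review thread inside fetchChildIfNecessary above. It is not what this PR does (the PR calls newFirstChunk() for every fetch and defers the optimization to #9045), and it assumes chunk.Renew and SessionVars.MaxChunkSize behave as that thread describes; fetchChildChunkSketch is a hypothetical name.

// Sketch only, not PR code: use a small first chunk to avoid producing
// unnecessary rows, then grow toward tidb_max_chunk_size on later fetches.
func (e *WindowExec) fetchChildChunkSketch(ctx context.Context) (*chunk.Chunk, error) {
	var childResult *chunk.Chunk
	if len(e.childResults) == 0 {
		// First fetch: small initial capacity.
		childResult = e.children[0].newFirstChunk()
	} else {
		// Later fetches: renew from the last chunk with a larger capacity.
		last := e.childResults[len(e.childResults)-1]
		childResult = chunk.Renew(last, e.ctx.GetSessionVars().MaxChunkSize)
	}
	if err := e.children[0].Next(ctx, &chunk.RecordBatch{Chunk: childResult}); err != nil {
		return nil, errors.Trace(err)
	}
	return childResult, nil
}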
43 changes: 43 additions & 0 deletions executor/window_test.go
@@ -0,0 +1,43 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package executor_test

import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/util/testkit"
)

func (s *testSuite2) TestWindowFunctions(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int, c int)")
tk.MustExec("set @@tidb_enable_window_function = 1")
defer func() {
tk.MustExec("set @@tidb_enable_window_function = 0")
}()
tk.MustExec("insert into t values (1,2,3),(4,3,2),(2,3,4)")
result := tk.MustQuery("select count(a) over () from t")
result.Check(testkit.Rows("3", "3", "3"))
result = tk.MustQuery("select sum(a) over () + count(a) over () from t")
result.Check(testkit.Rows("10", "10", "10"))
result = tk.MustQuery("select sum(a) over (partition by a) from t")
result.Check(testkit.Rows("1", "2", "4"))
result = tk.MustQuery("select 1 + sum(a) over (), count(a) over () from t")
result.Check(testkit.Rows("8 3", "8 3", "8 3"))
result = tk.MustQuery("select sum(t1.a) over() from t t1, t t2")
result.Check(testkit.Rows("21", "21", "21", "21", "21", "21", "21", "21", "21"))
result = tk.MustQuery("select _tidb_rowid, sum(t.a) over() from t")
result.Check(testkit.Rows("1 7", "2 7", "3 7"))
}
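To tie the partition-by case back to the executor: the planner is expected to sort the child on the partition columns (the output order 1, 2, 4 in the test reflects this), so WindowExec sees each partition contiguously. The trace below is a hypothetical walk-through of the state transitions for the third query above, reconstructed from the control flow in executor/window.go rather than captured from a run.

// Hypothetical trace for: select sum(a) over (partition by a) from t
// with the child sorted on a, i.e. rows arriving as a = 1, 2, 4.
//
//   a=1: meetNewGroup=false, groupRows=[1]
//   a=2: meetNewGroup=true  -> consumeGroupRows (sum=1) -> appendResult2Chunk writes "1"; groupRows=[2]
//   a=4: meetNewGroup=true  -> consumeGroupRows (sum=2) -> appendResult2Chunk writes "2"; groupRows=[4]
//   child exhausted: consumeGroupRows (sum=4), then the final appendResult2Chunk writes "4"
//
// Hence testkit.Rows("1", "2", "4") in the test above.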