Skip to content

Commit

Permalink
executor: support window func for aggregate without frame clause
Browse files Browse the repository at this point in the history
  • Loading branch information
alivxxx committed Jan 3, 2019
1 parent 91cdbf2 commit d7f739b
Show file tree
Hide file tree
Showing 16 changed files with 444 additions and 122 deletions.
23 changes: 17 additions & 6 deletions executor/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,10 +760,7 @@ type StreamAggExec struct {
// isChildReturnEmpty indicates whether the child executor only returns an empty input.
isChildReturnEmpty bool
defaultVal *chunk.Chunk
StmtCtx *stmtctx.StatementContext
GroupByItems []expression.Expression
curGroupKey []types.Datum
tmpGroupKey []types.Datum
groupChecker *groupChecker
inputIter *chunk.Iterator4Chunk
inputRow chunk.Row
aggFuncs []aggfuncs.AggFunc
Expand Down Expand Up @@ -824,7 +821,7 @@ func (e *StreamAggExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) e
return errors.Trace(err)
}
for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() {
meetNewGroup, err := e.meetNewGroup(e.inputRow)
meetNewGroup, err := e.groupChecker.meetNewGroup(e.inputRow)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -911,8 +908,22 @@ func (e *StreamAggExec) appendResult2Chunk(chk *chunk.Chunk) error {
return nil
}

// groupChecker detects group boundaries in a stream of input rows. Both
// StreamAggExec and WindowExec call its meetNewGroup method on every row
// they read from their child.
type groupChecker struct {
// StmtCtx is the statement context handed to newGroupChecker by the
// executor builder (taken from the session variables there).
StmtCtx *stmtctx.StatementContext
// GroupByItems are the expressions that define a group. When the slice is
// empty, meetNewGroup reports that no new group has been met.
GroupByItems []expression.Expression
// curGroupKey and tmpGroupKey are scratch key buffers.
// NOTE(review): presumably the key of the current group and of the row
// being examined; the part of meetNewGroup that uses them is not visible
// here, confirm before relying on this.
curGroupKey []types.Datum
tmpGroupKey []types.Datum
}

// newGroupChecker creates a groupChecker that uses stmtCtx as its statement
// context and items as the expressions defining a group. The key buffers are
// left at their zero value.
func newGroupChecker(stmtCtx *stmtctx.StatementContext, items []expression.Expression) *groupChecker {
	checker := new(groupChecker)
	checker.StmtCtx = stmtCtx
	checker.GroupByItems = items
	return checker
}

// meetNewGroup returns a value that represents if the new group is different from last group.
func (e *StreamAggExec) meetNewGroup(row chunk.Row) (bool, error) {
func (e *groupChecker) meetNewGroup(row chunk.Row) (bool, error) {
if len(e.GroupByItems) == 0 {
return false, nil
}
Expand Down
77 changes: 27 additions & 50 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/distsql"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/executor/windowfuncs"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/infoschema"
Expand Down Expand Up @@ -167,6 +168,8 @@ func (b *executorBuilder) build(p plannercore.Plan) Executor {
return b.buildIndexReader(v)
case *plannercore.PhysicalIndexLookUpReader:
return b.buildIndexLookUpReader(v)
case *plannercore.PhysicalWindow:
return b.buildWindow(v)
default:
b.err = ErrUnknownPlan.GenWithStack("Unknown Plan %T", p)
return nil
Expand Down Expand Up @@ -918,54 +921,6 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo
return e
}

// wrapCastForAggArgs rewrites, in place, the arguments of each aggregate
// function in funcs so that they are wrapped with a cast chosen from the
// evaluation type of the function's return type.
//
// COUNT, MIN, MAX and FIRST_ROW are left untouched. The function panics if
// the return type evaluates to anything other than int, real, string or
// decimal.
func (b *executorBuilder) wrapCastForAggArgs(funcs []*aggregation.AggFuncDesc) {
	for _, desc := range funcs {
		// These functions need no cast: the EvalXXX method called on the
		// argument is determined by the argument's own type.
		switch desc.Name {
		case ast.AggFuncCount, ast.AggFuncMin, ast.AggFuncMax, ast.AggFuncFirstRow:
			continue
		}

		var wrap func(ctx sessionctx.Context, expr expression.Expression) expression.Expression
		switch desc.RetTp.EvalType() {
		case types.ETInt:
			wrap = expression.WrapWithCastAsInt
		case types.ETReal:
			wrap = expression.WrapWithCastAsReal
		case types.ETString:
			wrap = expression.WrapWithCastAsString
		case types.ETDecimal:
			wrap = expression.WrapWithCastAsDecimal
		default:
			panic("should never happen in executorBuilder.wrapCastForAggArgs")
		}

		// Only AVG and SUM additionally get their argument type aligned with
		// the return type of the aggregate function.
		alignWithRetTp := desc.Name == ast.AggFuncAvg || desc.Name == ast.AggFuncSum
		for i, arg := range desc.Args {
			desc.Args[i] = wrap(b.ctx, arg)
			if !alignWithRetTp {
				continue
			}
			// After the cast has been applied, flen etc. of the argument may
			// differ from those of the aggregate function, so the argument's
			// type is overwritten with the function's type below.
			//
			// If the argument already had the same `Tp` as the aggregate
			// function, no cast is wrapped around it internally and it may
			// still be a bare `Column`. The `RetType` of a `Column` refers to
			// the `infoschema`, so give the column a fresh FieldType first to
			// avoid modifying the definition in `infoschema`.
			if col, isCol := desc.Args[i].(*expression.Column); isCol {
				col.RetType = types.NewFieldType(col.RetType.Tp)
			}
			// Keep the argument's own `Tp`: the column may be TypeFloat32
			// while the aggregate function's type is TypeFloat64.
			originTp := desc.Args[i].GetType().Tp
			*(desc.Args[i].GetType()) = *(desc.RetTp)
			desc.Args[i].GetType().Tp = originTp
		}
	}
}

func (b *executorBuilder) buildHashAgg(v *plannercore.PhysicalHashAgg) Executor {
src := b.build(v.Children()[0])
if b.err != nil {
Expand Down Expand Up @@ -1055,9 +1010,8 @@ func (b *executorBuilder) buildStreamAgg(v *plannercore.PhysicalStreamAgg) Execu
}
e := &StreamAggExec{
baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), src),
StmtCtx: b.ctx.GetSessionVars().StmtCtx,
groupChecker: newGroupChecker(b.ctx.GetSessionVars().StmtCtx, v.GroupByItems),
aggFuncs: make([]aggfuncs.AggFunc, 0, len(v.AggFuncs)),
GroupByItems: v.GroupByItems,
}
if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) {
e.defaultVal = nil
Expand Down Expand Up @@ -1922,3 +1876,26 @@ func buildKvRangesForIndexJoin(sc *stmtctx.StatementContext, tableID, indexID in
})
return kvRanges, nil
}

// buildWindow builds the WindowExec for the physical window plan v.
//
// The child plan is built first. The columns of v.PartitionBy become the
// group-by items of the executor's groupChecker, and the window function is
// created by windowfuncs.BuildWindowFunc, which receives the index of the last
// column of v's schema.
//
// On failure b.err is set and nil is returned.
// NOTE(review): the nil returned here is a typed (*WindowExec)(nil). build
// returns it as an Executor, so that interface value is not == nil; callers
// must check b.err instead of comparing the result with nil.
func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec {
	childExec := b.build(v.Children()[0])
	if b.err != nil {
		return nil
	}
	base := newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), childExec)

	// The final length is known up front, so size the slice once instead of
	// letting append grow it.
	groupByItems := make([]expression.Expression, 0, len(v.PartitionBy))
	for _, item := range v.PartitionBy {
		groupByItems = append(groupByItems, item.Col)
	}

	windowFunc, err := windowfuncs.BuildWindowFunc(b.ctx, v.WindowFuncDesc, len(v.Schema().Columns)-1)
	if err != nil {
		b.err = err
		return nil
	}

	return &WindowExec{
		baseExecutor: base,
		windowFunc:   windowFunc,
		groupChecker: newGroupChecker(b.ctx.GetSessionVars().StmtCtx, groupByItems),
		childCols:    v.ChildCols,
	}
}
158 changes: 158 additions & 0 deletions executor/window.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package executor

import (
"context"
"time"

"github.com/opentracing/opentracing-go"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/executor/windowfuncs"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/util/chunk"
)

// WindowExec is the executor for window functions.
// It reads chunks from its single child, splits the rows into groups with
// groupChecker, feeds each group to windowFunc and writes the results,
// together with the child's columns, into the output chunk.
type WindowExec struct {
baseExecutor

// groupChecker reports when an input row starts a new group; the builder
// creates it from the PARTITION BY columns of the plan.
groupChecker *groupChecker
// inputIter iterates over the most recently fetched child chunk, and
// inputRow is the row of that chunk currently being examined.
inputIter *chunk.Iterator4Chunk
inputRow chunk.Row
// groupRows buffers the rows handed to windowFunc; it is replaced by the
// slice that windowFunc.ProcessOneChunk / ExhaustResult return.
groupRows []chunk.Row
// childResults queues the chunks fetched from the child until copyChk moves
// their columns into an output chunk. It is released in Close.
childResults []*chunk.Chunk
// windowFunc is the window function evaluated by this executor.
windowFunc windowfuncs.WindowFunc
// executed is set once the child has returned an empty chunk or an error
// has occurred in Next; no further groups are consumed afterwards.
executed bool
// childCols lists the child columns that copyChk copies to the output
// chunk; the i-th entry is copied from the child's column childCols[i].Index.
childCols []*expression.Column
}

// Close implements the Executor Close interface.
// It drops the queued child chunks and then closes the embedded baseExecutor,
// returning that call's (traced) error.
func (e *WindowExec) Close() error {
	e.childResults = nil
	err := e.baseExecutor.Close()
	return errors.Trace(err)
}

// Next implements the Executor Next interface.
//
// chk is reset and then filled in two steps: results the window function
// still holds from a previous call are appended first, after which whole
// groups are consumed from the child until the child is drained or chk stops
// reporting free rows in its last column.
//
// Any error, including one raised while flushing the remaining results, marks
// the executor as executed and is returned to the caller.
func (e *WindowExec) Next(ctx context.Context, chk *chunk.Chunk) error {
	if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
		span1 := span.Tracer().StartSpan("windowExec.Next", opentracing.ChildOf(span.Context()))
		defer span1.Finish()
	}
	if e.runtimeStats != nil {
		start := time.Now()
		// Evaluated on return so that the final row count of chk is recorded.
		defer func() { e.runtimeStats.Record(time.Since(start), chk.NumRows()) }()
	}
	chk.Reset()
	if e.windowFunc.HasRemainingResults() {
		// The error used to be dropped here, which could hand a partially
		// filled chunk to the parent as if nothing had happened.
		if err := e.appendResult2Chunk(chk); err != nil {
			e.executed = true
			return errors.Trace(err)
		}
	}
	// NOTE(review): assumes RemainedRows(col) reports how many more rows the
	// given column can still take; confirm against the chunk package.
	for !e.executed && (chk.NumRows() == 0 || chk.RemainedRows(chk.NumCols()-1) > 0) {
		if err := e.consumeOneGroup(ctx, chk); err != nil {
			e.executed = true
			return errors.Trace(err)
		}
	}
	return nil
}

// consumeOneGroup reads input rows until the group checker reports the start
// of a new group or the child is drained.
//
// Rows of the current group are buffered in e.groupRows. When a row opens a
// new group, the buffered rows are passed to the window function and its
// results appended to chk; the boundary row then becomes the first buffered
// row of the next group and the function returns.
func (e *WindowExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) error {
	for !e.executed {
		if err := e.fetchChildIfNecessary(ctx, chk); err != nil {
			return errors.Trace(err)
		}
		for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() {
			isBoundary, err := e.groupChecker.meetNewGroup(e.inputRow)
			if err != nil {
				return errors.Trace(err)
			}
			if !isBoundary {
				// Still inside the current group: just buffer the row.
				e.groupRows = append(e.groupRows, e.inputRow)
				continue
			}

			// The previous group is complete: evaluate it and emit results.
			if err = e.consumeGroupRows(chk); err != nil {
				return errors.Trace(err)
			}
			if err = e.appendResult2Chunk(chk); err != nil {
				return errors.Trace(err)
			}
			// The boundary row starts the next group. Step past it by hand,
			// because returning skips the loop's post statement.
			e.groupRows = append(e.groupRows, e.inputRow)
			e.inputRow = e.inputIter.Next()
			return nil
		}
	}
	return nil
}

// consumeGroupRows hands the buffered rows to the window function. It does
// nothing when no rows are buffered. Otherwise copyChk is called on chk
// first, and e.groupRows is replaced by the slice that ProcessOneChunk
// returns, whether or not it also returns an error.
func (e *WindowExec) consumeGroupRows(chk *chunk.Chunk) error {
	if len(e.groupRows) == 0 {
		return nil
	}
	e.copyChk(chk)
	remaining, err := e.windowFunc.ProcessOneChunk(e.ctx, e.groupRows, chk)
	e.groupRows = remaining
	return err
}

// fetchChildIfNecessary fetches the next chunk from the child once the
// current input chunk has been iterated to its end (or when nothing has been
// fetched yet). It is a no-op while unread rows remain.
//
// Every fetched chunk is queued in e.childResults. An empty chunk means the
// child has no more data: the executor is marked as executed and the pending
// results are appended to chk.
func (e *WindowExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Chunk) (err error) {
	hasUnreadRows := e.inputIter != nil && e.inputRow != e.inputIter.End()
	if hasUnreadRows {
		return nil
	}

	// Before fetching a new batch of input, pass the rows buffered so far to
	// the window function.
	if err = e.consumeGroupRows(chk); err != nil {
		return errors.Trace(err)
	}

	fetched := e.children[0].newFirstChunk()
	if err = e.children[0].Next(ctx, fetched); err != nil {
		return errors.Trace(err)
	}
	e.childResults = append(e.childResults, fetched)

	if fetched.NumRows() > 0 {
		e.inputIter = chunk.NewIterator4Chunk(fetched)
		e.inputRow = e.inputIter.Begin()
		return nil
	}

	// No more data.
	e.executed = true
	err = e.appendResult2Chunk(chk)
	return errors.Trace(err)
}

// appendResult2Chunk asks the window function to write its pending results
// into chk. copyChk is called on chk first, and e.groupRows is replaced by
// the slice that ExhaustResult returns, whether or not it also returns an
// error.
func (e *WindowExec) appendResult2Chunk(chk *chunk.Chunk) error {
	e.copyChk(chk)
	remaining, err := e.windowFunc.ExhaustResult(e.ctx, e.groupRows, chk)
	e.groupRows = remaining
	return err
}

// copyChk moves the oldest queued child chunk into chk: for every entry of
// e.childCols, the child's column at that entry's Index is copied to the
// column of chk at the entry's position. The child chunk is removed from the
// queue. Nothing happens when the queue is empty or chk already holds rows.
func (e *WindowExec) copyChk(chk *chunk.Chunk) {
	if len(e.childResults) == 0 {
		return
	}
	if chk.NumRows() > 0 {
		return
	}
	src := e.childResults[0]
	e.childResults = e.childResults[1:]
	for dstIdx := range e.childCols {
		chk.CopyColumns(src, dstIdx, e.childCols[dstIdx].Index)
	}
}
41 changes: 41 additions & 0 deletions executor/window_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package executor_test

import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/util/testkit"
)

// TestWindowFunctions runs aggregate window functions without a frame clause
// over a three-row table and checks the returned rows. Window functions are
// switched on for the duration of the test and switched off again afterwards.
func (s *testSuite2) TestWindowFunctions(c *C) {
	tk := testkit.NewTestKit(c, s.store)
	for _, stmt := range []string{
		"use test",
		"drop table if exists t",
		"create table t (a int, b int, c int)",
		"set @@tidb_enable_window_function = 1",
	} {
		tk.MustExec(stmt)
	}
	defer tk.MustExec("set @@tidb_enable_window_function = 0")
	tk.MustExec("insert into t values (1,2,3),(4,3,2),(2,3,4)")

	cases := []struct {
		sql  string
		rows []string
	}{
		{
			sql:  "select count(a) over () from t",
			rows: []string{"3", "3", "3"},
		},
		{
			sql:  "select sum(a) over () + count(a) over () from t",
			rows: []string{"10", "10", "10"},
		},
		{
			sql:  "select sum(a) over (partition by a) from t",
			rows: []string{"1", "2", "4"},
		},
		{
			sql:  "select 1 + sum(a) over (), count(a) over () from t",
			rows: []string{"8 3", "8 3", "8 3"},
		},
		{
			sql:  "select sum(t1.a) over() from t t1, t t2",
			rows: []string{"21", "21", "21", "21", "21", "21", "21", "21", "21"},
		},
	}
	for _, tc := range cases {
		tk.MustQuery(tc.sql).Check(testkit.Rows(tc.rows...))
	}
}
Loading

0 comments on commit d7f739b

Please sign in to comment.