Commit
planner, executor: merge window functions with same specification name
alivxxx authored and db-storage committed May 29, 2019
1 parent 08c5028 commit e79d8f6
Showing 17 changed files with 394 additions and 194 deletions.
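
In plain terms: before this commit each window function got its own Window plan node and executor pass, even when several functions referred to the same named window specification; with this change, functions that share a specification name are merged into a single PhysicalWindow/WindowExec that evaluates all of them over the same partitions and frames. A test-style sketch of the kind of query this affects, in the spirit of the new cases in executor/window_test.go (the table schema and the testkit setup are assumed here, not taken from the patch):

// Assumed setup: tk is a *testkit.TestKit bound to a TiDB test session,
// created elsewhere in the suite.
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int)")
tk.MustExec("insert into t values (1, 1), (1, 2), (2, 3), (2, 4)")
// Both aggregates reference the same named window w, so after this commit
// they are planned into one Window operator and computed in a single pass.
tk.MustQuery("select sum(a) over w, sum(b) over w from t window w as (partition by a order by b)")
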
34 changes: 21 additions & 13 deletions executor/builder.go
@@ -2114,43 +2114,51 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec
     for _, item := range v.PartitionBy {
         groupByItems = append(groupByItems, item.Col)
     }
-    aggDesc := aggregation.NewAggFuncDesc(b.ctx, v.WindowFuncDesc.Name, v.WindowFuncDesc.Args, false)
-    resultColIdx := len(v.Schema().Columns) - 1
     orderByCols := make([]*expression.Column, 0, len(v.OrderBy))
     for _, item := range v.OrderBy {
         orderByCols = append(orderByCols, item.Col)
     }
-    agg := aggfuncs.BuildWindowFunctions(b.ctx, aggDesc, resultColIdx, orderByCols)
+    windowFuncs := make([]aggfuncs.AggFunc, 0, len(v.WindowFuncDescs))
+    partialResults := make([]aggfuncs.PartialResult, 0, len(v.WindowFuncDescs))
+    resultColIdx := v.Schema().Len() - len(v.WindowFuncDescs)
+    for _, desc := range v.WindowFuncDescs {
+        aggDesc := aggregation.NewAggFuncDesc(b.ctx, desc.Name, desc.Args, false)
+        agg := aggfuncs.BuildWindowFunctions(b.ctx, aggDesc, resultColIdx, orderByCols)
+        windowFuncs = append(windowFuncs, agg)
+        partialResults = append(partialResults, agg.AllocPartialResult())
+        resultColIdx++
+    }
     var processor windowProcessor
     if v.Frame == nil {
         processor = &aggWindowProcessor{
-            windowFunc:    agg,
-            partialResult: agg.AllocPartialResult(),
+            windowFuncs:    windowFuncs,
+            partialResults: partialResults,
         }
     } else if v.Frame.Type == ast.Rows {
         processor = &rowFrameWindowProcessor{
-            windowFunc:    agg,
-            partialResult: agg.AllocPartialResult(),
-            start:         v.Frame.Start,
-            end:           v.Frame.End,
+            windowFuncs:    windowFuncs,
+            partialResults: partialResults,
+            start:          v.Frame.Start,
+            end:            v.Frame.End,
         }
     } else {
         cmpResult := int64(-1)
         if len(v.OrderBy) > 0 && v.OrderBy[0].Desc {
             cmpResult = 1
         }
         processor = &rangeFrameWindowProcessor{
-            windowFunc:        agg,
-            partialResult:     agg.AllocPartialResult(),
+            windowFuncs:       windowFuncs,
+            partialResults:    partialResults,
             start:             v.Frame.Start,
             end:               v.Frame.End,
             orderByCols:       orderByCols,
             expectedCmpResult: cmpResult,
         }
     }
     return &WindowExec{baseExecutor: base,
-        processor:    processor,
-        groupChecker: newGroupChecker(b.ctx.GetSessionVars().StmtCtx, groupByItems),
+        processor:      processor,
+        groupChecker:   newGroupChecker(b.ctx.GetSessionVars().StmtCtx, groupByItems),
+        numWindowFuncs: len(v.WindowFuncDescs),
     }
 }

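The key index arithmetic in the new buildWindow loop: the window operator's schema is the child's columns followed by one result column per window function, so the first result column sits at Schema().Len() - len(WindowFuncDescs), and each subsequent function writes one column to the right. A toy, standalone sketch of that layout (illustrative only, no TiDB types):

package main

import "fmt"

func main() {
    // Illustrative numbers: a child plan producing 3 columns and a window
    // operator evaluating 2 window functions over it.
    childCols := 3
    numWindowFuncs := 2

    // The window schema appends one result column per window function.
    schemaLen := childCols + numWindowFuncs

    // Mirrors resultColIdx := v.Schema().Len() - len(v.WindowFuncDescs)
    // followed by resultColIdx++ inside the build loop.
    resultColIdx := schemaLen - numWindowFuncs
    for i := 0; i < numWindowFuncs; i++ {
        fmt.Printf("window function #%d writes schema column %d\n", i, resultColIdx)
        resultColIdx++
    }
}
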
98 changes: 57 additions & 41 deletions executor/window.go
@@ -41,6 +41,7 @@ type WindowExec struct {
     meetNewGroup         bool
     remainingRowsInGroup int
     remainingRowsInChunk int
+    numWindowFuncs       int
     processor            windowProcessor
 }

@@ -171,7 +172,7 @@ func (e *WindowExec) copyChk(chk *chunk.Chunk) {
     childResult := e.childResults[0]
     e.childResults = e.childResults[1:]
     e.remainingRowsInChunk = childResult.NumRows()
-    columns := e.Schema().Columns[:len(e.Schema().Columns)-1]
+    columns := e.Schema().Columns[:len(e.Schema().Columns)-e.numWindowFuncs]
     for i, col := range columns {
         chk.MakeRefTo(i, childResult, col.Index)
     }
@@ -190,38 +191,47 @@ type windowProcessor interface {
 }

 type aggWindowProcessor struct {
-    windowFunc    aggfuncs.AggFunc
-    partialResult aggfuncs.PartialResult
+    windowFuncs    []aggfuncs.AggFunc
+    partialResults []aggfuncs.PartialResult
 }

 func (p *aggWindowProcessor) consumeGroupRows(ctx sessionctx.Context, rows []chunk.Row) ([]chunk.Row, error) {
-    err := p.windowFunc.UpdatePartialResult(ctx, rows, p.partialResult)
+    for i, windowFunc := range p.windowFuncs {
+        err := windowFunc.UpdatePartialResult(ctx, rows, p.partialResults[i])
+        if err != nil {
+            return nil, err
+        }
+    }
     rows = rows[:0]
-    return rows, err
+    return rows, nil
 }

 func (p *aggWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, rows []chunk.Row, chk *chunk.Chunk, remained int) ([]chunk.Row, error) {
     for remained > 0 {
-        // TODO: We can extend the agg func interface to avoid the `for` loop here.
-        err := p.windowFunc.AppendFinalResult2Chunk(ctx, p.partialResult, chk)
-        if err != nil {
-            return rows, err
+        for i, windowFunc := range p.windowFuncs {
+            // TODO: We can extend the agg func interface to avoid the `for` loop here.
+            err := windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk)
+            if err != nil {
+                return nil, err
+            }
         }
         remained--
     }
     return rows, nil
 }

 func (p *aggWindowProcessor) resetPartialResult() {
-    p.windowFunc.ResetPartialResult(p.partialResult)
+    for i, windowFunc := range p.windowFuncs {
+        windowFunc.ResetPartialResult(p.partialResults[i])
+    }
 }

 type rowFrameWindowProcessor struct {
-    windowFunc    aggfuncs.AggFunc
-    partialResult aggfuncs.PartialResult
-    start         *core.FrameBound
-    end           *core.FrameBound
-    curRowIdx     uint64
+    windowFuncs    []aggfuncs.AggFunc
+    partialResults []aggfuncs.PartialResult
+    start          *core.FrameBound
+    end            *core.FrameBound
+    curRowIdx      uint64
 }

 func (p *rowFrameWindowProcessor) getStartOffset(numRows uint64) uint64 {
@@ -283,33 +293,36 @@ func (p *rowFrameWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, row
         p.curRowIdx++
         remained--
         if start >= end {
-            err := p.windowFunc.AppendFinalResult2Chunk(ctx, p.partialResult, chk)
-            if err != nil {
-                return nil, err
+            for i, windowFunc := range p.windowFuncs {
+                err := windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk)
+                if err != nil {
+                    return nil, err
+                }
             }
             continue
         }
-        err := p.windowFunc.UpdatePartialResult(ctx, rows[start:end], p.partialResult)
-        if err != nil {
-            return nil, err
-        }
-        err = p.windowFunc.AppendFinalResult2Chunk(ctx, p.partialResult, chk)
-        if err != nil {
-            return nil, err
+        for i, windowFunc := range p.windowFuncs {
+            err := windowFunc.UpdatePartialResult(ctx, rows[start:end], p.partialResults[i])
+            if err != nil {
+                return nil, err
+            }
+            err = windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk)
+            if err != nil {
+                return nil, err
+            }
+            windowFunc.ResetPartialResult(p.partialResults[i])
         }
-        p.windowFunc.ResetPartialResult(p.partialResult)
     }
     return rows, nil
 }

 func (p *rowFrameWindowProcessor) resetPartialResult() {
-    p.windowFunc.ResetPartialResult(p.partialResult)
     p.curRowIdx = 0
 }

 type rangeFrameWindowProcessor struct {
-    windowFunc        aggfuncs.AggFunc
-    partialResult     aggfuncs.PartialResult
+    windowFuncs       []aggfuncs.AggFunc
+    partialResults    []aggfuncs.PartialResult
     start             *core.FrameBound
     end               *core.FrameBound
     curRowIdx         uint64
@@ -385,21 +398,25 @@ func (p *rangeFrameWindowProcessor) appendResult2Chunk(ctx sessionctx.Context, r
         p.curRowIdx++
         remained--
         if start >= end {
-            err := p.windowFunc.AppendFinalResult2Chunk(ctx, p.partialResult, chk)
-            if err != nil {
-                return nil, err
+            for i, windowFunc := range p.windowFuncs {
+                err := windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk)
+                if err != nil {
+                    return nil, err
+                }
             }
             continue
         }
-        err = p.windowFunc.UpdatePartialResult(ctx, rows[start:end], p.partialResult)
-        if err != nil {
-            return nil, err
-        }
-        err = p.windowFunc.AppendFinalResult2Chunk(ctx, p.partialResult, chk)
-        if err != nil {
-            return nil, err
+        for i, windowFunc := range p.windowFuncs {
+            err := windowFunc.UpdatePartialResult(ctx, rows[start:end], p.partialResults[i])
+            if err != nil {
+                return nil, err
+            }
+            err = windowFunc.AppendFinalResult2Chunk(ctx, p.partialResults[i], chk)
+            if err != nil {
+                return nil, err
+            }
+            windowFunc.ResetPartialResult(p.partialResults[i])
         }
-        p.windowFunc.ResetPartialResult(p.partialResult)
     }
     return rows, nil
 }
@@ -409,7 +426,6 @@ func (p *rangeFrameWindowProcessor) consumeGroupRows(ctx sessionctx.Context, row
 }

 func (p *rangeFrameWindowProcessor) resetPartialResult() {
-    p.windowFunc.ResetPartialResult(p.partialResult)
     p.curRowIdx = 0
     p.lastStartOffset = 0
     p.lastEndOffset = 0
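All three processors (aggWindowProcessor, rowFrameWindowProcessor, rangeFrameWindowProcessor) now hold slices of window functions and partial results and apply the same update/append/reset sequence to each of them per frame, and copyChk now copies only the child's columns, skipping the last numWindowFuncs result columns instead of just the last one. A minimal standalone sketch of that per-frame evaluation pattern (toy interface, not TiDB's aggfuncs API):

package main

import "fmt"

// winFunc is a toy stand-in for aggfuncs.AggFunc with its partial result
// folded into the value itself.
type winFunc interface {
    Update(frame []int) error // stand-in for UpdatePartialResult
    Append(out *[]int) error  // stand-in for AppendFinalResult2Chunk
    Reset()                   // stand-in for ResetPartialResult
}

type sumFunc struct{ acc int }

func (s *sumFunc) Update(frame []int) error {
    for _, v := range frame {
        s.acc += v
    }
    return nil
}
func (s *sumFunc) Append(out *[]int) error { *out = append(*out, s.acc); return nil }
func (s *sumFunc) Reset()                  { s.acc = 0 }

// evalFrame mirrors the new per-row loop: every window function that shares
// the specification is updated, finalized, and reset against the same frame.
func evalFrame(funcs []winFunc, frame []int, out *[]int) error {
    for _, f := range funcs {
        if err := f.Update(frame); err != nil {
            return err
        }
        if err := f.Append(out); err != nil {
            return err
        }
        f.Reset()
    }
    return nil
}

func main() {
    funcs := []winFunc{&sumFunc{}, &sumFunc{}}
    var out []int
    if err := evalFrame(funcs, []int{1, 2, 3}, &out); err != nil {
        panic(err)
    }
    fmt.Println(out) // [6 6]
}
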
7 changes: 7 additions & 0 deletions executor/window_test.go
@@ -159,4 +159,11 @@ func (s *testSuite4) TestWindowFunctions(c *C) {
"5 <nil> 2013-01-01 00:00:00 15",
),
)

result = tk.MustQuery("select sum(a) over w, sum(b) over w from t window w as (order by a)")
result.Check(testkit.Rows("2 3", "2 3", "6 6", "6 6"))
result = tk.MustQuery("select row_number() over w, sum(b) over w from t window w as (order by a)")
result.Check(testkit.Rows("1 3", "2 3", "3 6", "4 6"))
result = tk.MustQuery("select row_number() over w, sum(b) over w from t window w as (rows between 1 preceding and 1 following)")
result.Check(testkit.Rows("1 3", "2 4", "3 5", "4 3"))
}
8 changes: 4 additions & 4 deletions planner/core/exhaust_physical_plans.go
@@ -1158,10 +1158,10 @@ func (p *LogicalWindow) exhaustPhysicalPlans(prop *property.PhysicalProperty) []
         return nil
     }
     window := PhysicalWindow{
-        WindowFuncDesc: p.WindowFuncDesc,
-        PartitionBy:    p.PartitionBy,
-        OrderBy:        p.OrderBy,
-        Frame:          p.Frame,
+        WindowFuncDescs: p.WindowFuncDescs,
+        PartitionBy:     p.PartitionBy,
+        OrderBy:         p.OrderBy,
+        Frame:           p.Frame,
     }.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProperty)
     window.SetSchema(p.Schema())
     return []PhysicalPlan{window}
13 changes: 12 additions & 1 deletion planner/core/explain.go
@@ -323,7 +323,8 @@ func (p *PhysicalWindow) formatFrameBound(buffer *bytes.Buffer, bound *FrameBoun

 // ExplainInfo implements PhysicalPlan interface.
 func (p *PhysicalWindow) ExplainInfo() string {
-    buffer := bytes.NewBufferString(p.WindowFuncDesc.String())
+    buffer := bytes.NewBufferString("")
+    formatWindowFuncDescs(buffer, p.WindowFuncDescs)
     buffer.WriteString(" over(")
     isFirst := true
     if len(p.PartitionBy) > 0 {
@@ -370,3 +371,13 @@ func (p *PhysicalWindow) ExplainInfo() string {
     buffer.WriteString(")")
     return buffer.String()
 }
+
+func formatWindowFuncDescs(buffer *bytes.Buffer, descs []*aggregation.WindowFuncDesc) *bytes.Buffer {
+    for i, desc := range descs {
+        if i != 0 {
+            buffer.WriteString(", ")
+        }
+        buffer.WriteString(desc.String())
+    }
+    return buffer
+}
44 changes: 22 additions & 22 deletions planner/core/expression_rewriter.go
@@ -29,7 +29,7 @@ import (
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/types"
driver "github.com/pingcap/tidb/types/parser_driver"
"github.com/pingcap/tidb/types/parser_driver"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/stringutil"
)
@@ -82,14 +82,14 @@ func (b *PlanBuilder) rewriteInsertOnDuplicateUpdate(exprNode ast.ExprNode, mock
 // asScalar means whether this expression must be treated as a scalar expression.
 // And this function returns a result expression, a new plan that may have apply or semi-join.
 func (b *PlanBuilder) rewrite(exprNode ast.ExprNode, p LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, asScalar bool) (expression.Expression, LogicalPlan, error) {
-    expr, resultPlan, err := b.rewriteWithPreprocess(exprNode, p, aggMapper, asScalar, nil)
+    expr, resultPlan, err := b.rewriteWithPreprocess(exprNode, p, aggMapper, nil, asScalar, nil)
     return expr, resultPlan, err
 }

 // rewriteWithPreprocess is for handling the situation that we need to adjust the input ast tree
 // before really using its node in `expressionRewriter.Leave`. In that case, we first call
 // er.preprocess(expr), which returns a new expr. Then we use the new expr in `Leave`.
-func (b *PlanBuilder) rewriteWithPreprocess(exprNode ast.ExprNode, p LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, asScalar bool, preprocess func(ast.Node) ast.Node) (expression.Expression, LogicalPlan, error) {
+func (b *PlanBuilder) rewriteWithPreprocess(exprNode ast.ExprNode, p LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int, asScalar bool, preprocess func(ast.Node) ast.Node) (expression.Expression, LogicalPlan, error) {
     b.rewriterCounter++
     defer func() { b.rewriterCounter-- }()

@@ -103,6 +103,7 @@ func (b *PlanBuilder) rewriteWithPreprocess(exprNode ast.ExprNode, p LogicalPlan
     }

     rewriter.aggrMap = aggMapper
+    rewriter.windowMap = windowMapper
     rewriter.asScalar = asScalar
     rewriter.preprocess = preprocess

@@ -153,13 +154,14 @@ func (b *PlanBuilder) rewriteExprNode(rewriter *expressionRewriter, exprNode ast
 }

 type expressionRewriter struct {
-    ctxStack []expression.Expression
-    p        LogicalPlan
-    schema   *expression.Schema
-    err      error
-    aggrMap  map[*ast.AggregateFuncExpr]int
-    b        *PlanBuilder
-    ctx      sessionctx.Context
+    ctxStack  []expression.Expression
+    p         LogicalPlan
+    schema    *expression.Schema
+    err       error
+    aggrMap   map[*ast.AggregateFuncExpr]int
+    windowMap map[*ast.WindowFuncExpr]int
+    b         *PlanBuilder
+    ctx       sessionctx.Context

     // asScalar indicates the return value must be a scalar value.
     // NOTE: This value can be changed during expression rewritten.
@@ -315,7 +317,16 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
         er.ctxStack = append(er.ctxStack, expression.NewValuesFunc(er.ctx, col.Index, col.RetType))
         return inNode, true
     case *ast.WindowFuncExpr:
-        return er.handleWindowFunction(v)
+        index, ok := -1, false
+        if er.windowMap != nil {
+            index, ok = er.windowMap[v]
+        }
+        if !ok {
+            er.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(v.F)
+            return inNode, true
+        }
+        er.ctxStack = append(er.ctxStack, er.schema.Columns[index])
+        return inNode, true
     case *ast.FuncCallExpr:
         if _, ok := expression.DisableFoldFunctions[v.FnName.L]; ok {
             er.disableFoldCounter++
@@ -326,17 +337,6 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
     return inNode, false
 }

-func (er *expressionRewriter) handleWindowFunction(v *ast.WindowFuncExpr) (ast.Node, bool) {
-    windowPlan, err := er.b.buildWindowFunction(er.p, v, er.aggrMap)
-    if err != nil {
-        er.err = err
-        return v, false
-    }
-    er.ctxStack = append(er.ctxStack, windowPlan.GetWindowResultColumn())
-    er.p = windowPlan
-    return v, true
-}
-
 func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r expression.Expression, not bool) {
     var condition expression.Expression
     if rCol, ok := r.(*expression.Column); ok && (er.asScalar || not) {
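The rewriter no longer builds a Window plan on the spot when it meets a WindowFuncExpr (handleWindowFunction is deleted). Instead it receives a windowMapper from each WindowFuncExpr to the schema column index that will hold its result, so the rewriter simply substitutes that column, or reports ErrWindowInvalidWindowFuncUse when no mapping exists; the code that fills the mapper lives in planner files not shown on this page. A toy sketch of that lookup (illustrative types only, not TiDB's ast/expression packages):

package main

import (
    "errors"
    "fmt"
)

// Toy stand-ins for ast.WindowFuncExpr and expression.Column.
type windowFuncExpr struct{ name string }
type column struct{ idx int }

var errInvalidWindowFuncUse = errors.New("window function is not allowed here")

// resolveWindowFunc mirrors the new case *ast.WindowFuncExpr branch: look the
// expression up in the precomputed mapper and return the schema column that
// already carries its result.
func resolveWindowFunc(windowMap map[*windowFuncExpr]int, schema []column, v *windowFuncExpr) (column, error) {
    index, ok := -1, false
    if windowMap != nil {
        index, ok = windowMap[v]
    }
    if !ok {
        return column{}, errInvalidWindowFuncUse
    }
    return schema[index], nil
}

func main() {
    expr := &windowFuncExpr{name: "sum"}
    schema := []column{{idx: 0}, {idx: 1}, {idx: 2}}
    windowMap := map[*windowFuncExpr]int{expr: 2}

    col, err := resolveWindowFunc(windowMap, schema, expr)
    fmt.Println(col, err) // {2} <nil>

    _, err = resolveWindowFunc(windowMap, schema, &windowFuncExpr{name: "rank"})
    fmt.Println(err) // window function is not allowed here
}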