From 67add75d15d4a3cc1e8b23a1cec25420badee2f1 Mon Sep 17 00:00:00 2001 From: Haibin Xie Date: Wed, 2 Jan 2019 13:06:15 +0800 Subject: [PATCH] executor: support window func for aggregate without frame clause --- executor/aggregate.go | 23 ++- executor/builder.go | 77 ++++----- executor/window.go | 158 +++++++++++++++++++ executor/window_test.go | 41 +++++ executor/windowfuncs/window_funcs.go | 67 ++++++++ expression/aggregation/base_func.go | 46 ++++++ expression/aggregation/window_func.go | 29 ++++ expression/column.go | 1 + go.mod | 2 + go.sum | 1 + planner/core/errors.go | 103 ++++++------ planner/core/exhaust_physical_plans.go | 17 ++ planner/core/explain.go | 6 + planner/core/expression_rewriter.go | 15 +- planner/core/initialize.go | 16 ++ planner/core/logical_plan_builder.go | 207 +++++++++++++++++++++++-- planner/core/logical_plan_test.go | 93 +++++++++++ planner/core/logical_plans.go | 17 ++ planner/core/physical_plans.go | 12 ++ planner/core/planbuilder.go | 11 ++ planner/core/resolve_indices.go | 21 +++ planner/core/rule_column_pruning.go | 31 +++- planner/core/stats.go | 13 ++ planner/core/stringer.go | 4 + util/chunk/chunk.go | 10 ++ 25 files changed, 898 insertions(+), 123 deletions(-) create mode 100644 executor/window.go create mode 100644 executor/window_test.go create mode 100644 executor/windowfuncs/window_funcs.go create mode 100644 expression/aggregation/window_func.go diff --git a/executor/aggregate.go b/executor/aggregate.go index f32f5de50ce22..0ee21880796c1 100644 --- a/executor/aggregate.go +++ b/executor/aggregate.go @@ -760,10 +760,7 @@ type StreamAggExec struct { // isChildReturnEmpty indicates whether the child executor only returns an empty input. isChildReturnEmpty bool defaultVal *chunk.Chunk - StmtCtx *stmtctx.StatementContext - GroupByItems []expression.Expression - curGroupKey []types.Datum - tmpGroupKey []types.Datum + group *group inputIter *chunk.Iterator4Chunk inputRow chunk.Row aggFuncs []aggfuncs.AggFunc @@ -824,7 +821,7 @@ func (e *StreamAggExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) e return errors.Trace(err) } for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() { - meetNewGroup, err := e.meetNewGroup(e.inputRow) + meetNewGroup, err := e.group.meetNewGroup(e.inputRow) if err != nil { return errors.Trace(err) } @@ -911,8 +908,22 @@ func (e *StreamAggExec) appendResult2Chunk(chk *chunk.Chunk) error { return nil } +type group struct { + StmtCtx *stmtctx.StatementContext + GroupByItems []expression.Expression + curGroupKey []types.Datum + tmpGroupKey []types.Datum +} + +func newGroup(stmtCtx *stmtctx.StatementContext, items []expression.Expression) *group { + return &group{ + StmtCtx: stmtCtx, + GroupByItems: items, + } +} + // meetNewGroup returns a value that represents if the new group is different from last group. -func (e *StreamAggExec) meetNewGroup(row chunk.Row) (bool, error) { +func (e *group) meetNewGroup(row chunk.Row) (bool, error) { if len(e.GroupByItems) == 0 { return false, nil } diff --git a/executor/builder.go b/executor/builder.go index 659432fba12e3..c8f9cf16a6cfb 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -17,6 +17,7 @@ import ( "bytes" "context" "fmt" + "github.com/pingcap/tidb/executor/windowfuncs" "math" "sort" "strings" @@ -168,6 +169,8 @@ func (b *executorBuilder) build(p plannercore.Plan) Executor { return b.buildIndexReader(v) case *plannercore.PhysicalIndexLookUpReader: return b.buildIndexLookUpReader(v) + case *plannercore.PhysicalWindow: + return b.buildWindow(v) default: b.err = ErrUnknownPlan.GenWithStack("Unknown Plan %T", p) return nil @@ -919,54 +922,6 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo return e } -// wrapCastForAggArgs wraps the args of an aggregate function with a cast function. -func (b *executorBuilder) wrapCastForAggArgs(funcs []*aggregation.AggFuncDesc) { - for _, f := range funcs { - // We do not need to wrap cast upon these functions, - // since the EvalXXX method called by the arg is determined by the corresponding arg type. - if f.Name == ast.AggFuncCount || f.Name == ast.AggFuncMin || f.Name == ast.AggFuncMax || f.Name == ast.AggFuncFirstRow { - continue - } - var castFunc func(ctx sessionctx.Context, expr expression.Expression) expression.Expression - switch retTp := f.RetTp; retTp.EvalType() { - case types.ETInt: - castFunc = expression.WrapWithCastAsInt - case types.ETReal: - castFunc = expression.WrapWithCastAsReal - case types.ETString: - castFunc = expression.WrapWithCastAsString - case types.ETDecimal: - castFunc = expression.WrapWithCastAsDecimal - default: - panic("should never happen in executorBuilder.wrapCastForAggArgs") - } - for i := range f.Args { - f.Args[i] = castFunc(b.ctx, f.Args[i]) - if f.Name != ast.AggFuncAvg && f.Name != ast.AggFuncSum { - continue - } - // After wrapping cast on the argument, flen etc. may not the same - // as the type of the aggregation function. The following part set - // the type of the argument exactly as the type of the aggregation - // function. - // Note: If the `Tp` of argument is the same as the `Tp` of the - // aggregation function, it will not wrap cast function on it - // internally. The reason of the special handling for `Column` is - // that the `RetType` of `Column` refers to the `infoschema`, so we - // need to set a new variable for it to avoid modifying the - // definition in `infoschema`. - if col, ok := f.Args[i].(*expression.Column); ok { - col.RetType = types.NewFieldType(col.RetType.Tp) - } - // originTp is used when the the `Tp` of column is TypeFloat32 while - // the type of the aggregation function is TypeFloat64. - originTp := f.Args[i].GetType().Tp - *(f.Args[i].GetType()) = *(f.RetTp) - f.Args[i].GetType().Tp = originTp - } - } -} - // buildProjBelowAgg builds a ProjectionExec below AggregationExec. // If all the args of `aggFuncs`, and all the item of `groupByItems` // are columns or constants, we do not need to build the `proj`. @@ -975,7 +930,9 @@ func (b *executorBuilder) buildProjBelowAgg(aggFuncs []*aggregation.AggFuncDesc, // If the mode is FinalMode, we do not need to wrap cast upon the args, // since the types of the args are already the expected. if len(aggFuncs) > 0 && aggFuncs[0].Mode != aggregation.FinalMode { - b.wrapCastForAggArgs(aggFuncs) + for _, agg := range aggFuncs { + agg.WrapCastForAggArgs(b.ctx) + } } for i := 0; !hasScalarFunc && i < len(aggFuncs); i++ { f := aggFuncs[i] @@ -1128,9 +1085,8 @@ func (b *executorBuilder) buildStreamAgg(v *plannercore.PhysicalStreamAgg) Execu src = b.buildProjBelowAgg(v.AggFuncs, v.GroupByItems, src) e := &StreamAggExec{ baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), src), - StmtCtx: b.ctx.GetSessionVars().StmtCtx, + group: newGroup(b.ctx.GetSessionVars().StmtCtx, v.GroupByItems), aggFuncs: make([]aggfuncs.AggFunc, 0, len(v.AggFuncs)), - GroupByItems: v.GroupByItems, } if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) { e.defaultVal = nil @@ -1995,3 +1951,22 @@ func buildKvRangesForIndexJoin(sc *stmtctx.StatementContext, tableID, indexID in }) return kvRanges, nil } + +func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec { + childExec := b.build(v.Children()[0]) + if b.err != nil { + b.err = errors.Trace(b.err) + return nil + } + base := newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), childExec) + var groupByItems []expression.Expression + for _, item := range v.PartitionBy { + groupByItems = append(groupByItems, item.Col) + } + e := &WindowExec{baseExecutor: base, + wf: windowfuncs.BuildWindowFunc(b.ctx, v.WindowFuncDesc, len(v.Schema().Columns)-1), + group: newGroup(b.ctx.GetSessionVars().StmtCtx, groupByItems), + childCols: v.ChildCols, + } + return e +} diff --git a/executor/window.go b/executor/window.go new file mode 100644 index 0000000000000..919aecbc6178d --- /dev/null +++ b/executor/window.go @@ -0,0 +1,158 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + "time" + + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/executor/windowfuncs" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/util/chunk" +) + +// WindowExec is the executor for window functions. +type WindowExec struct { + baseExecutor + + group *group + inputIter *chunk.Iterator4Chunk + inputRow chunk.Row + groupRows []chunk.Row + childResults []*chunk.Chunk + wf windowfuncs.WindowFunc + executed bool + childCols []*expression.Column +} + +// Close implements the Executor Close interface. +func (e *WindowExec) Close() error { + e.childResults = nil + return errors.Trace(e.baseExecutor.Close()) +} + +// Next implements the Executor Next interface. +func (e *WindowExec) Next(ctx context.Context, chk *chunk.Chunk) error { + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("windowExec.Next", opentracing.ChildOf(span.Context())) + defer span1.Finish() + } + if e.runtimeStats != nil { + start := time.Now() + defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() + } + chk.Reset() + if e.wf.RemainResult() { + e.appendResult2Chunk(chk) + } + for !e.executed && (chk.NumRows() == 0 || chk.RemainedRows(chk.NumCols()-1) > 0) { + err := e.consumeOneGroup(ctx, chk) + if err != nil { + e.executed = true + return errors.Trace(err) + } + } + return nil +} + +func (e *WindowExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) error { + for !e.executed { + if err := e.fetchChildIfNecessary(ctx, chk); err != nil { + return errors.Trace(err) + } + for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() { + meetNewGroup, err := e.group.meetNewGroup(e.inputRow) + if err != nil { + return errors.Trace(err) + } + if meetNewGroup { + err := e.consumeGroupRows(chk) + if err != nil { + return errors.Trace(err) + } + err = e.appendResult2Chunk(chk) + if err != nil { + return errors.Trace(err) + } + } + e.groupRows = append(e.groupRows, e.inputRow) + if meetNewGroup { + e.inputRow = e.inputIter.Next() + return nil + } + } + } + return nil +} + +func (e *WindowExec) consumeGroupRows(chk *chunk.Chunk) error { + if len(e.groupRows) == 0 { + return nil + } + e.copyChk(chk) + var err error + e.groupRows, err = e.wf.ProcessOneChunk(e.ctx, e.groupRows, chk) + return err +} + +func (e *WindowExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Chunk) (err error) { + if e.inputIter != nil && e.inputRow != e.inputIter.End() { + return nil + } + + // Before fetching a new batch of input, we should consume the last group. + err = e.consumeGroupRows(chk) + if err != nil { + return errors.Trace(err) + } + + childResult := e.children[0].newFirstChunk() + err = e.children[0].Next(ctx, childResult) + if err != nil { + return errors.Trace(err) + } + e.childResults = append(e.childResults, childResult) + // No more data. + if childResult.NumRows() == 0 { + e.executed = true + err = e.appendResult2Chunk(chk) + return errors.Trace(err) + } + + e.inputIter = chunk.NewIterator4Chunk(childResult) + e.inputRow = e.inputIter.Begin() + return nil +} + +// appendResult2Chunk appends result of all the aggregation functions to the +// result chunk, and reset the evaluation context for each aggregation. +func (e *WindowExec) appendResult2Chunk(chk *chunk.Chunk) error { + e.copyChk(chk) + var err error + e.groupRows, err = e.wf.ExhaustResult(e.ctx, e.groupRows, chk) + return err +} + +func (e *WindowExec) copyChk(chk *chunk.Chunk) { + if len(e.childResults) == 0 || chk.NumRows() > 0 { + return + } + childResult := e.childResults[0] + e.childResults = e.childResults[1:] + for i, col := range e.childCols { + chk.CopyColumns(childResult, i, col.Index) + } +} diff --git a/executor/window_test.go b/executor/window_test.go new file mode 100644 index 0000000000000..4bda1bd23e2a0 --- /dev/null +++ b/executor/window_test.go @@ -0,0 +1,41 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/util/testkit" +) + +func (s *testSuite2) TestWindowFunctions(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int, b int, c int)") + tk.MustExec("set @@tidb_enable_window_function = 1") + defer func() { + tk.MustExec("set @@tidb_enable_window_function = 0") + }() + tk.MustExec("insert into t values (1,2,3),(4,3,2),(2,3,4)") + result := tk.MustQuery("select count(a) over () from t") + result.Check(testkit.Rows("3", "3", "3")) + result = tk.MustQuery("select sum(a) over () + count(a) over () from t") + result.Check(testkit.Rows("10", "10", "10")) + result = tk.MustQuery("select sum(a) over (partition by a) from t") + result.Check(testkit.Rows("1", "2", "4")) + result = tk.MustQuery("select 1 + sum(a) over (), count(a) over () from t") + result.Check(testkit.Rows("8 3", "8 3", "8 3")) + result = tk.MustQuery("select sum(t1.a) over() from t t1, t t2") + result.Check(testkit.Rows("21", "21", "21", "21", "21", "21", "21", "21", "21")) +} diff --git a/executor/windowfuncs/window_funcs.go b/executor/windowfuncs/window_funcs.go new file mode 100644 index 0000000000000..554afec2e704e --- /dev/null +++ b/executor/windowfuncs/window_funcs.go @@ -0,0 +1,67 @@ +package windowfuncs + +import ( + "github.com/pingcap/tidb/executor/aggfuncs" + "github.com/pingcap/tidb/expression/aggregation" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/chunk" +) + +// WindowFunc is the interface for processing window functions. +type WindowFunc interface { + // ProcessOneChunk processes one chunk. + ProcessOneChunk(sctx sessionctx.Context, rows []chunk.Row, dest *chunk.Chunk) ([]chunk.Row, error) + // ExhaustResult exhausts result to the result chunk. + ExhaustResult(sctx sessionctx.Context, rows []chunk.Row, dest *chunk.Chunk) ([]chunk.Row, error) + // RemainResult checks if there are some remained results to be exhausted. + RemainResult() bool +} + +// aggNoFrame deals with agg functions with no frame specification. +type aggNoFrame struct { + result aggfuncs.PartialResult + agg aggfuncs.AggFunc + remained int64 +} + +// ProcessOneChunk implements the WindowFunc interface. +func (wf *aggNoFrame) ProcessOneChunk(sctx sessionctx.Context, rows []chunk.Row, dest *chunk.Chunk) ([]chunk.Row, error) { + err := wf.agg.UpdatePartialResult(sctx, rows, wf.result) + if err != nil { + return nil, err + } + wf.remained += int64(len(rows)) + rows = rows[:0] + return rows, nil +} + +// ExhaustResult implements the WindowFunc interface. +func (wf *aggNoFrame) ExhaustResult(sctx sessionctx.Context, rows []chunk.Row, dest *chunk.Chunk) ([]chunk.Row, error) { + rows = rows[:0] + for wf.remained > 0 && dest.RemainedRows(dest.NumCols()-1) > 0 { + err := wf.agg.AppendFinalResult2Chunk(sctx, wf.result, dest) + if err != nil { + return rows, err + } + wf.remained-- + } + if wf.remained == 0 { + wf.agg.ResetPartialResult(wf.result) + } + return rows, nil +} + +// RemainResult implements the WindowFunc interface. +func (wf *aggNoFrame) RemainResult() bool { + return wf.remained > 0 +} + +// BuildWindowFunc builds window functions according to the window functions description. +func BuildWindowFunc(ctx sessionctx.Context, window *aggregation.WindowFuncDesc, ordinal int) WindowFunc { + aggDesc := aggregation.NewAggFuncDesc(ctx, window.Name, window.Args, false) + agg := aggfuncs.Build(ctx, aggDesc, ordinal) + return &aggNoFrame{ + agg: agg, + result: agg.AllocPartialResult(), + } +} diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index d7ab7a3d0816a..400ca2210a902 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -211,3 +211,49 @@ func (a *baseFuncDesc) GetDefaultValue() (v types.Datum) { } return } + +// WrapCastForAggArgs wraps the args of an aggregate function with a cast function. +func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) { + // We do not need to wrap cast upon these functions, + // since the EvalXXX method called by the arg is determined by the corresponding arg type. + if a.Name == ast.AggFuncCount || a.Name == ast.AggFuncMin || a.Name == ast.AggFuncMax || a.Name == ast.AggFuncFirstRow { + return + } + var castFunc func(ctx sessionctx.Context, expr expression.Expression) expression.Expression + switch retTp := a.RetTp; retTp.EvalType() { + case types.ETInt: + castFunc = expression.WrapWithCastAsInt + case types.ETReal: + castFunc = expression.WrapWithCastAsReal + case types.ETString: + castFunc = expression.WrapWithCastAsString + case types.ETDecimal: + castFunc = expression.WrapWithCastAsDecimal + default: + panic("should never happen in executorBuilder.wrapCastForAggArgs") + } + for i := range a.Args { + a.Args[i] = castFunc(ctx, a.Args[i]) + if a.Name != ast.AggFuncAvg && a.Name != ast.AggFuncSum { + continue + } + // After wrapping cast on the argument, flen etc. may not the same + // as the type of the aggregation function. The following part set + // the type of the argument exactly as the type of the aggregation + // function. + // Note: If the `Tp` of argument is the same as the `Tp` of the + // aggregation function, it will not wrap cast function on it + // internally. The reason of the special handling for `Column` is + // that the `RetType` of `Column` refers to the `infoschema`, so we + // need to set a new variable for it to avoid modifying the + // definition in `infoschema`. + if col, ok := a.Args[i].(*expression.Column); ok { + col.RetType = types.NewFieldType(col.RetType.Tp) + } + // originTp is used when the the `Tp` of column is TypeFloat32 while + // the type of the aggregation function is TypeFloat64. + originTp := a.Args[i].GetType().Tp + *(a.Args[i].GetType()) = *(a.RetTp) + a.Args[i].GetType().Tp = originTp + } +} diff --git a/expression/aggregation/window_func.go b/expression/aggregation/window_func.go new file mode 100644 index 0000000000000..9555f900bdb73 --- /dev/null +++ b/expression/aggregation/window_func.go @@ -0,0 +1,29 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregation + +import ( + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/sessionctx" +) + +// WindowFuncDesc describes a window function signature, only used in planner. +type WindowFuncDesc struct { + baseFuncDesc +} + +// NewWindowFuncDesc creates a window function signature descriptor. +func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) *WindowFuncDesc { + return &WindowFuncDesc{newBaseFuncDesc(ctx, name, args)} +} diff --git a/expression/column.go b/expression/column.go index 59ca27daa6706..681de551dbb8a 100644 --- a/expression/column.go +++ b/expression/column.go @@ -216,6 +216,7 @@ func (col *Column) EvalInt(ctx sessionctx.Context, row chunk.Row) (int64, bool, // EvalReal returns real representation of Column. func (col *Column) EvalReal(ctx sessionctx.Context, row chunk.Row) (float64, bool, error) { + log.Warning(col.Index) if row.IsNull(col.Index) { return 0, true, nil } diff --git a/go.mod b/go.mod index b3beb315f66cb..233e593fc2961 100644 --- a/go.mod +++ b/go.mod @@ -86,3 +86,5 @@ require ( sourcegraph.com/sourcegraph/appdash v0.0.0-20180531100431-4c381bd170b4 sourcegraph.com/sourcegraph/appdash-data v0.0.0-20151005221446-73f23eafcf67 ) + +replace github.com/pingcap/parser => github.com/lamxTyler/parser v0.0.0-20181226064458-13cec0b9d426 diff --git a/go.sum b/go.sum index 60c7b9c477f0d..c6977fa3d8415 100644 --- a/go.sum +++ b/go.sum @@ -81,6 +81,7 @@ github.com/klauspost/cpuid v0.0.0-20170728055534-ae7887de9fa5 h1:2U0HzY8BJ8hVwDK github.com/klauspost/cpuid v0.0.0-20170728055534-ae7887de9fa5/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/lamxTyler/parser v0.0.0-20181226064458-13cec0b9d426/go.mod h1:6c1rwSy9dUuNebYdr1IMI4+/sT3/Q65MXP2UCg7/vJI= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/montanaflynn/stats v0.0.0-20180911141734-db72e6cae808 h1:pmpDGKLw4n82EtrNiLqB+xSz/JQwFOaZuMALYUHwX5s= diff --git a/planner/core/errors.go b/planner/core/errors.go index 6d029c490f533..0541b4b8d7bfc 100644 --- a/planner/core/errors.go +++ b/planner/core/errors.go @@ -51,6 +51,9 @@ const ( codeWrongNumberOfColumnsInSelect = mysql.ErrWrongNumberOfColumnsInSelect codeWrongValueCountOnRow = mysql.ErrWrongValueCountOnRow codeTablenameNotAllowedHere = mysql.ErrTablenameNotAllowedHere + + codeWindowInvalidWindowFuncUse = mysql.ErrWindowInvalidWindowFuncUse + codeWindowInvalidWindowFuncAliasUse = mysql.ErrWindowInvalidWindowFuncAliasUse ) // error definitions. @@ -63,58 +66,62 @@ var ( ErrSchemaChanged = terror.ClassOptimizer.New(codeSchemaChanged, "Schema has changed") ErrTablenameNotAllowedHere = terror.ClassOptimizer.New(codeTablenameNotAllowedHere, "Table '%s' from one of the %ss cannot be used in %s") - ErrWrongUsage = terror.ClassOptimizer.New(codeWrongUsage, mysql.MySQLErrName[mysql.ErrWrongUsage]) - ErrAmbiguous = terror.ClassOptimizer.New(codeAmbiguous, mysql.MySQLErrName[mysql.ErrNonUniq]) - ErrUnknown = terror.ClassOptimizer.New(codeUnknown, mysql.MySQLErrName[mysql.ErrUnknown]) - ErrUnknownColumn = terror.ClassOptimizer.New(codeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadField]) - ErrUnknownTable = terror.ClassOptimizer.New(codeUnknownTable, mysql.MySQLErrName[mysql.ErrUnknownTable]) - ErrWrongArguments = terror.ClassOptimizer.New(codeWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments]) - ErrWrongNumberOfColumnsInSelect = terror.ClassOptimizer.New(codeWrongNumberOfColumnsInSelect, mysql.MySQLErrName[mysql.ErrWrongNumberOfColumnsInSelect]) - ErrBadGeneratedColumn = terror.ClassOptimizer.New(codeBadGeneratedColumn, mysql.MySQLErrName[mysql.ErrBadGeneratedColumn]) - ErrFieldNotInGroupBy = terror.ClassOptimizer.New(codeFieldNotInGroupBy, mysql.MySQLErrName[mysql.ErrFieldNotInGroupBy]) - ErrBadTable = terror.ClassOptimizer.New(codeBadTable, mysql.MySQLErrName[mysql.ErrBadTable]) - ErrKeyDoesNotExist = terror.ClassOptimizer.New(codeKeyDoesNotExist, mysql.MySQLErrName[mysql.ErrKeyDoesNotExist]) - ErrOperandColumns = terror.ClassOptimizer.New(codeOperandColumns, mysql.MySQLErrName[mysql.ErrOperandColumns]) - ErrInvalidWildCard = terror.ClassOptimizer.New(codeInvalidWildCard, "Wildcard fields without any table name appears in wrong place") - ErrInvalidGroupFuncUse = terror.ClassOptimizer.New(codeInvalidGroupFuncUse, mysql.MySQLErrName[mysql.ErrInvalidGroupFuncUse]) - ErrIllegalReference = terror.ClassOptimizer.New(codeIllegalReference, mysql.MySQLErrName[mysql.ErrIllegalReference]) - ErrNoDB = terror.ClassOptimizer.New(codeNoDB, mysql.MySQLErrName[mysql.ErrNoDB]) - ErrUnknownExplainFormat = terror.ClassOptimizer.New(codeUnknownExplainFormat, mysql.MySQLErrName[mysql.ErrUnknownExplainFormat]) - ErrWrongGroupField = terror.ClassOptimizer.New(codeWrongGroupField, mysql.MySQLErrName[mysql.ErrWrongGroupField]) - ErrDupFieldName = terror.ClassOptimizer.New(codeDupFieldName, mysql.MySQLErrName[mysql.ErrDupFieldName]) - ErrNonUpdatableTable = terror.ClassOptimizer.New(codeNonUpdatableTable, mysql.MySQLErrName[mysql.ErrNonUpdatableTable]) - ErrInternal = terror.ClassOptimizer.New(codeInternal, mysql.MySQLErrName[mysql.ErrInternal]) - ErrMixOfGroupFuncAndFields = terror.ClassOptimizer.New(codeMixOfGroupFuncAndFields, "In aggregated query without GROUP BY, expression #%d of SELECT list contains nonaggregated column '%s'; this is incompatible with sql_mode=only_full_group_by") - ErrNonUniqTable = terror.ClassOptimizer.New(codeNonUniqTable, mysql.MySQLErrName[mysql.ErrNonuniqTable]) - ErrWrongValueCountOnRow = terror.ClassOptimizer.New(mysql.ErrWrongValueCountOnRow, mysql.MySQLErrName[mysql.ErrWrongValueCountOnRow]) - ErrViewInvalid = terror.ClassOptimizer.New(mysql.ErrViewInvalid, mysql.MySQLErrName[mysql.ErrViewInvalid]) + ErrWrongUsage = terror.ClassOptimizer.New(codeWrongUsage, mysql.MySQLErrName[mysql.ErrWrongUsage]) + ErrAmbiguous = terror.ClassOptimizer.New(codeAmbiguous, mysql.MySQLErrName[mysql.ErrNonUniq]) + ErrUnknown = terror.ClassOptimizer.New(codeUnknown, mysql.MySQLErrName[mysql.ErrUnknown]) + ErrUnknownColumn = terror.ClassOptimizer.New(codeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadField]) + ErrUnknownTable = terror.ClassOptimizer.New(codeUnknownTable, mysql.MySQLErrName[mysql.ErrUnknownTable]) + ErrWrongArguments = terror.ClassOptimizer.New(codeWrongArguments, mysql.MySQLErrName[mysql.ErrWrongArguments]) + ErrWrongNumberOfColumnsInSelect = terror.ClassOptimizer.New(codeWrongNumberOfColumnsInSelect, mysql.MySQLErrName[mysql.ErrWrongNumberOfColumnsInSelect]) + ErrBadGeneratedColumn = terror.ClassOptimizer.New(codeBadGeneratedColumn, mysql.MySQLErrName[mysql.ErrBadGeneratedColumn]) + ErrFieldNotInGroupBy = terror.ClassOptimizer.New(codeFieldNotInGroupBy, mysql.MySQLErrName[mysql.ErrFieldNotInGroupBy]) + ErrBadTable = terror.ClassOptimizer.New(codeBadTable, mysql.MySQLErrName[mysql.ErrBadTable]) + ErrKeyDoesNotExist = terror.ClassOptimizer.New(codeKeyDoesNotExist, mysql.MySQLErrName[mysql.ErrKeyDoesNotExist]) + ErrOperandColumns = terror.ClassOptimizer.New(codeOperandColumns, mysql.MySQLErrName[mysql.ErrOperandColumns]) + ErrInvalidWildCard = terror.ClassOptimizer.New(codeInvalidWildCard, "Wildcard fields without any table name appears in wrong place") + ErrInvalidGroupFuncUse = terror.ClassOptimizer.New(codeInvalidGroupFuncUse, mysql.MySQLErrName[mysql.ErrInvalidGroupFuncUse]) + ErrIllegalReference = terror.ClassOptimizer.New(codeIllegalReference, mysql.MySQLErrName[mysql.ErrIllegalReference]) + ErrNoDB = terror.ClassOptimizer.New(codeNoDB, mysql.MySQLErrName[mysql.ErrNoDB]) + ErrUnknownExplainFormat = terror.ClassOptimizer.New(codeUnknownExplainFormat, mysql.MySQLErrName[mysql.ErrUnknownExplainFormat]) + ErrWrongGroupField = terror.ClassOptimizer.New(codeWrongGroupField, mysql.MySQLErrName[mysql.ErrWrongGroupField]) + ErrDupFieldName = terror.ClassOptimizer.New(codeDupFieldName, mysql.MySQLErrName[mysql.ErrDupFieldName]) + ErrNonUpdatableTable = terror.ClassOptimizer.New(codeNonUpdatableTable, mysql.MySQLErrName[mysql.ErrNonUpdatableTable]) + ErrInternal = terror.ClassOptimizer.New(codeInternal, mysql.MySQLErrName[mysql.ErrInternal]) + ErrMixOfGroupFuncAndFields = terror.ClassOptimizer.New(codeMixOfGroupFuncAndFields, "In aggregated query without GROUP BY, expression #%d of SELECT list contains nonaggregated column '%s'; this is incompatible with sql_mode=only_full_group_by") + ErrNonUniqTable = terror.ClassOptimizer.New(codeNonUniqTable, mysql.MySQLErrName[mysql.ErrNonuniqTable]) + ErrWrongValueCountOnRow = terror.ClassOptimizer.New(mysql.ErrWrongValueCountOnRow, mysql.MySQLErrName[mysql.ErrWrongValueCountOnRow]) + ErrViewInvalid = terror.ClassOptimizer.New(mysql.ErrViewInvalid, mysql.MySQLErrName[mysql.ErrViewInvalid]) + ErrWindowInvalidWindowFuncUse = terror.ClassOptimizer.New(codeWindowInvalidWindowFuncUse, mysql.MySQLErrName[mysql.ErrWindowInvalidWindowFuncUse]) + ErrWindowInvalidWindowFuncAliasUse = terror.ClassOptimizer.New(codeWindowInvalidWindowFuncAliasUse, mysql.MySQLErrName[mysql.ErrWindowInvalidWindowFuncAliasUse]) ) func init() { mysqlErrCodeMap := map[terror.ErrCode]uint16{ - codeWrongUsage: mysql.ErrWrongUsage, - codeAmbiguous: mysql.ErrNonUniq, - codeUnknownColumn: mysql.ErrBadField, - codeUnknownTable: mysql.ErrBadTable, - codeWrongArguments: mysql.ErrWrongArguments, - codeBadGeneratedColumn: mysql.ErrBadGeneratedColumn, - codeFieldNotInGroupBy: mysql.ErrFieldNotInGroupBy, - codeBadTable: mysql.ErrBadTable, - codeKeyDoesNotExist: mysql.ErrKeyDoesNotExist, - codeOperandColumns: mysql.ErrOperandColumns, - codeInvalidWildCard: mysql.ErrParse, - codeInvalidGroupFuncUse: mysql.ErrInvalidGroupFuncUse, - codeIllegalReference: mysql.ErrIllegalReference, - codeNoDB: mysql.ErrNoDB, - codeUnknownExplainFormat: mysql.ErrUnknownExplainFormat, - codeWrongGroupField: mysql.ErrWrongGroupField, - codeDupFieldName: mysql.ErrDupFieldName, - codeNonUpdatableTable: mysql.ErrUnknownTable, - codeInternal: mysql.ErrInternal, - codeMixOfGroupFuncAndFields: mysql.ErrMixOfGroupFuncAndFields, - codeNonUniqTable: mysql.ErrNonuniqTable, - codeWrongNumberOfColumnsInSelect: mysql.ErrWrongNumberOfColumnsInSelect, - codeWrongValueCountOnRow: mysql.ErrWrongValueCountOnRow, + codeWrongUsage: mysql.ErrWrongUsage, + codeAmbiguous: mysql.ErrNonUniq, + codeUnknownColumn: mysql.ErrBadField, + codeUnknownTable: mysql.ErrBadTable, + codeWrongArguments: mysql.ErrWrongArguments, + codeBadGeneratedColumn: mysql.ErrBadGeneratedColumn, + codeFieldNotInGroupBy: mysql.ErrFieldNotInGroupBy, + codeBadTable: mysql.ErrBadTable, + codeKeyDoesNotExist: mysql.ErrKeyDoesNotExist, + codeOperandColumns: mysql.ErrOperandColumns, + codeInvalidWildCard: mysql.ErrParse, + codeInvalidGroupFuncUse: mysql.ErrInvalidGroupFuncUse, + codeIllegalReference: mysql.ErrIllegalReference, + codeNoDB: mysql.ErrNoDB, + codeUnknownExplainFormat: mysql.ErrUnknownExplainFormat, + codeWrongGroupField: mysql.ErrWrongGroupField, + codeDupFieldName: mysql.ErrDupFieldName, + codeNonUpdatableTable: mysql.ErrUnknownTable, + codeInternal: mysql.ErrInternal, + codeMixOfGroupFuncAndFields: mysql.ErrMixOfGroupFuncAndFields, + codeNonUniqTable: mysql.ErrNonuniqTable, + codeWrongNumberOfColumnsInSelect: mysql.ErrWrongNumberOfColumnsInSelect, + codeWrongValueCountOnRow: mysql.ErrWrongValueCountOnRow, + codeWindowInvalidWindowFuncUse: mysql.ErrWindowInvalidWindowFuncUse, + codeWindowInvalidWindowFuncAliasUse: mysql.ErrWindowInvalidWindowFuncAliasUse, } terror.ErrClassToMySQLCodes[terror.ClassOptimizer] = mysqlErrCodeMap } diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index b77d7b2d244e2..fba5ee6e906d3 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -789,6 +789,23 @@ func (la *LogicalApply) exhaustPhysicalPlans(prop *property.PhysicalProperty) [] return []PhysicalPlan{apply} } +func (p *LogicalWindow) exhaustPhysicalPlans(prop *property.PhysicalProperty) []PhysicalPlan { + var byItems []property.Item + byItems = append(byItems, p.PartitionBy...) + byItems = append(byItems, p.OrderBy...) + childProperty := &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, Items: byItems, Enforced: true} + if !prop.IsPrefix(childProperty) { + return nil + } + window := PhysicalWindow{ + WindowFuncDesc: p.WindowFuncDesc, + PartitionBy: p.PartitionBy, + OrderBy: p.OrderBy, + }.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProperty) + window.SetSchema(p.Schema()) + return []PhysicalPlan{window} +} + // exhaustPhysicalPlans is only for implementing interface. DataSource and Dual generate task in `findBestTask` directly. func (p *baseLogicalPlan) exhaustPhysicalPlans(_ *property.PhysicalProperty) []PhysicalPlan { panic("baseLogicalPlan.exhaustPhysicalPlans() should never be called.") diff --git a/planner/core/explain.go b/planner/core/explain.go index f4b37efd869d1..44e9ef4c81099 100644 --- a/planner/core/explain.go +++ b/planner/core/explain.go @@ -297,3 +297,9 @@ func (p *PhysicalTopN) ExplainInfo() string { fmt.Fprintf(buffer, ", offset:%v, count:%v", p.Offset, p.Count) return buffer.String() } + +// ExplainInfo implements PhysicalPlan interface. +func (p *PhysicalWindow) ExplainInfo() string { + // TODO: Add explain info for partition by, order by and frame. + return p.WindowFuncDesc.String() +} diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 2a4af05983ace..4bae008b4d74b 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -303,12 +303,25 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { } er.ctxStack = append(er.ctxStack, expression.NewValuesFunc(er.ctx, col.Index, col.RetType)) return inNode, true + case *ast.WindowFuncExpr: + return er.handleWindowFunction(v) default: er.asScalar = true } return inNode, false } +func (er *expressionRewriter) handleWindowFunction(v *ast.WindowFuncExpr) (ast.Node, bool) { + windowPlan, err := er.b.buildWindowFunction(er.p, v, er.aggrMap) + if err != nil { + er.err = err + return v, false + } + er.ctxStack = append(er.ctxStack, windowPlan.GetWindowResultColumn()) + er.p = windowPlan + return v, true +} + func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) (ast.Node, bool) { v.L.Accept(er) if er.err != nil { @@ -751,7 +764,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok } switch v := inNode.(type) { case *ast.AggregateFuncExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause, - *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr: + *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr, *ast.WindowFuncExpr: case *driver.ValueExpr: value := &expression.Constant{Value: v.Datum, RetType: &v.Type} er.ctxStack = append(er.ctxStack, value) diff --git a/planner/core/initialize.go b/planner/core/initialize.go index 9c2b698c927f0..ff2a3dbc7b059 100644 --- a/planner/core/initialize.go +++ b/planner/core/initialize.go @@ -81,6 +81,8 @@ const ( TypeTableReader = "TableReader" // TypeIndexReader is the type of IndexReader. TypeIndexReader = "IndexReader" + // TypeWindow is the type of Window. + TypeWindow = "Window" ) // Init initializes LogicalAggregation. @@ -231,6 +233,20 @@ func (p PhysicalMaxOneRow) Init(ctx sessionctx.Context, stats *property.StatsInf return &p } +// Init initializes LogicalWindow. +func (p LogicalWindow) Init(ctx sessionctx.Context) *LogicalWindow { + p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeWindow, &p) + return &p +} + +// Init initializes PhysicalWindow. +func (p PhysicalWindow) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalWindow { + p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeWindow, &p) + p.childrenReqProps = props + p.stats = stats + return &p +} + // Init initializes Update. func (p Update) Init(ctx sessionctx.Context) *Update { p.basePlan = newBasePlan(ctx, TypeUpdate) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 592bc86519e1f..10db4ffd933cf 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" @@ -596,13 +597,32 @@ func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectFi } // buildProjection returns a Projection plan and non-aux columns length. -func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int) (LogicalPlan, int, error) { +func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int, considerWindow bool) (LogicalPlan, int, error) { b.optFlag |= flagEliminateProjection b.curClause = fieldList proj := LogicalProjection{Exprs: make([]expression.Expression, 0, len(fields))}.Init(b.ctx) schema := expression.NewSchema(make([]*expression.Column, 0, len(fields))...) oldLen := 0 - for _, field := range fields { + for i, field := range fields { + if !field.Auxiliary { + oldLen++ + } + + isWindowFuncField := ast.HasWindowFlag(field.Expr) + // When `considerWindow` is false, we will only build fields for non-window functions, so we add fake placeholders. + // When `considerWindow` is true, all the non-window fields have been built, so we just use the schema columns. + if (considerWindow && !isWindowFuncField) || (!considerWindow && isWindowFuncField) { + var expr expression.Expression + if isWindowFuncField { + expr = expression.Zero + } else { + expr = p.Schema().Columns[i] + } + proj.Exprs = append(proj.Exprs, expr) + col := b.buildProjectionField(proj.id, schema.Len()+1, field, expr) + schema.Append(col) + continue + } newExpr, np, err := b.rewrite(field.Expr, p, mapper, true) if err != nil { return nil, 0, errors.Trace(err) @@ -613,10 +633,6 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, col := b.buildProjectionField(proj.id, schema.Len()+1, field, newExpr) schema.Append(col) - - if !field.Auxiliary { - oldLen++ - } } proj.SetSchema(schema) proj.SetChildren(p) @@ -999,10 +1015,11 @@ func resolveFromSelectFields(v *ast.ColumnNameExpr, fields []*ast.SelectField, i return } -// havingAndOrderbyExprResolver visits Expr tree. +// havingWindowAndOrderbyExprResolver visits Expr tree. // It converts ColunmNameExpr to AggregateFuncExpr and collects AggregateFuncExpr. -type havingAndOrderbyExprResolver struct { +type havingWindowAndOrderbyExprResolver struct { inAggFunc bool + inWindowFunc bool inExpr bool orderBy bool err error @@ -1016,10 +1033,12 @@ type havingAndOrderbyExprResolver struct { } // Enter implements Visitor interface. -func (a *havingAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChildren bool) { +func (a *havingWindowAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChildren bool) { switch n.(type) { case *ast.AggregateFuncExpr: a.inAggFunc = true + case *ast.WindowFuncExpr: + a.inWindowFunc = true case *driver.ParamMarkerExpr, *ast.ColumnNameExpr, *ast.ColumnName: case *ast.SubqueryExpr, *ast.ExistsSubqueryExpr: // Enter a new context, skip it. @@ -1031,7 +1050,7 @@ func (a *havingAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChi return n, false } -func (a *havingAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, schema *expression.Schema) (int, error) { +func (a *havingWindowAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, schema *expression.Schema) (int, error) { col, err := schema.FindColumn(v.Name) if err != nil { return -1, errors.Trace(err) @@ -1045,7 +1064,7 @@ func (a *havingAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, Name: col.ColName, } for i, field := range a.selectFields { - if c, ok := field.Expr.(*ast.ColumnNameExpr); ok && colMatch(newColName, c.Name) { + if c, ok := field.Expr.(*ast.ColumnNameExpr); ok && colMatch(c.Name, newColName) { return i, nil } } @@ -1059,7 +1078,7 @@ func (a *havingAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, } // Leave implements Visitor interface. -func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool) { +func (a *havingWindowAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool) { switch v := n.(type) { case *ast.AggregateFuncExpr: a.inAggFunc = false @@ -1069,9 +1088,15 @@ func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool Expr: v, AsName: model.NewCIStr(fmt.Sprintf("sel_agg_%d", len(a.selectFields))), }) + case *ast.WindowFuncExpr: + a.inWindowFunc = false + if a.curClause == havingClause { + a.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(v.F) + return node, false + } case *ast.ColumnNameExpr: resolveFieldsFirst := true - if a.inAggFunc || (a.orderBy && a.inExpr) { + if a.inAggFunc || a.inWindowFunc || (a.orderBy && a.inExpr) { resolveFieldsFirst = false } if !a.inAggFunc && !a.orderBy { @@ -1089,6 +1114,10 @@ func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool if a.err != nil { return node, false } + if index != -1 && a.curClause == havingClause && ast.HasWindowFlag(a.selectFields[index].Expr) { + a.err = ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O) + return node, false + } if index == -1 { if a.orderBy { index, a.err = a.resolveFromSchema(v, a.p.Schema()) @@ -1102,8 +1131,12 @@ func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool var err error index, err = a.resolveFromSchema(v, a.p.Schema()) _ = err - if index == -1 { + if index == -1 && a.curClause != windowClause { index, a.err = resolveFromSelectFields(v, a.selectFields, false) + if index != -1 && a.curClause == havingClause && ast.HasWindowFlag(a.selectFields[index].Expr) { + a.err = ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O) + return node, false + } } } if a.err != nil { @@ -1137,7 +1170,7 @@ func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool // When we rewrite the order by / having expression, we will find column in map at first. func (b *PlanBuilder) resolveHavingAndOrderBy(sel *ast.SelectStmt, p LogicalPlan) ( map[*ast.AggregateFuncExpr]int, map[*ast.AggregateFuncExpr]int, error) { - extractor := &havingAndOrderbyExprResolver{ + extractor := &havingWindowAndOrderbyExprResolver{ p: p, selectFields: sel.Fields.Fields, aggMapper: make(map[*ast.AggregateFuncExpr]int), @@ -1190,6 +1223,31 @@ func (b *PlanBuilder) extractAggFuncs(fields []*ast.SelectField) ([]*ast.Aggrega return aggList, totalAggMapper } +// resolveWindowFunction will process window functions and resolve the columns that don't exist in select fields. +func (b *PlanBuilder) resolveWindowFunction(sel *ast.SelectStmt, p LogicalPlan) ( + map[*ast.AggregateFuncExpr]int, error) { + extractor := &havingWindowAndOrderbyExprResolver{ + p: p, + selectFields: sel.Fields.Fields, + aggMapper: make(map[*ast.AggregateFuncExpr]int), + colMapper: b.colMapper, + outerSchemas: b.outerSchemas, + } + extractor.curClause = windowClause + for _, field := range sel.Fields.Fields { + if !ast.HasWindowFlag(field.Expr) { + continue + } + n, ok := field.Expr.Accept(extractor) + if !ok { + return nil, extractor.err + } + field.Expr = n.(ast.ExprNode) + } + sel.Fields.Fields = extractor.selectFields + return extractor.aggMapper, nil +} + // gbyResolver resolves group by items from select fields. type gbyResolver struct { ctx sessionctx.Context @@ -1234,6 +1292,8 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { ret.Accept(extractor) if len(extractor.AggFuncs) != 0 { err = ErrIllegalReference.GenWithStackByArgs(v.Name.OrigColName(), "reference to group function") + } else if ast.HasWindowFlag(ret) { + err = ErrIllegalReference.GenWithStackByArgs(v.Name.OrigColName(), "reference to window function") } else { return ret, true } @@ -1727,6 +1787,7 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error var ( aggFuncs []*ast.AggregateFuncExpr havingMap, orderMap, totalMap map[*ast.AggregateFuncExpr]int + windowMap map[*ast.AggregateFuncExpr]int gbyCols []expression.Expression ) @@ -1759,6 +1820,13 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } } + hasWindowFuncField := b.detectSelectWindow(sel) + if hasWindowFuncField { + windowMap, err = b.resolveWindowFunction(sel, p) + if err != nil { + return nil, err + } + } // We must resolve having and order by clause before build projection, // because when the query is "select a+1 as b from t having sum(b) < 0", we must replace sum(b) to sum(a+1), // which only can be done before building projection and extracting Agg functions. @@ -1792,7 +1860,8 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } var oldLen int - p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, totalMap) + // `considerWindow` is false now because we can only process window functions after having clause. + p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, totalMap, false) if err != nil { return nil, errors.Trace(err) } @@ -1805,6 +1874,14 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } } + if hasWindowFuncField { + // Now we build the window function fields. + p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, windowMap, true) + if err != nil { + return nil, err + } + } + if sel.Distinct { p = b.buildDistinct(p, oldLen) } @@ -2484,6 +2561,104 @@ func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { return del, nil } +func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, expr *ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, []property.Item, []expression.Expression, error) { + b.optFlag |= flagEliminateProjection + + var items []*ast.ByItem + spec := expr.Spec + if spec.PartitionBy != nil { + items = append(items, spec.PartitionBy.Items...) + } + if spec.OrderBy != nil { + items = append(items, spec.OrderBy.Items...) + } + projLen := len(p.Schema().Columns) + len(items) + len(expr.Args) + proj := LogicalProjection{Exprs: make([]expression.Expression, 0, projLen)}.Init(b.ctx) + schema := expression.NewSchema(make([]*expression.Column, 0, projLen)...) + for _, col := range p.Schema().Columns { + proj.Exprs = append(proj.Exprs, col) + schema.Append(col) + } + + transformer := &itemTransformer{} + propertyItems := make([]property.Item, 0, len(items)) + for _, item := range items { + newExpr, _ := item.Expr.Accept(transformer) + item.Expr = newExpr.(ast.ExprNode) + it, np, err := b.rewrite(item.Expr, p, aggMap, true) + if err != nil { + return nil, nil, nil, err + } + p = np + if col, ok := it.(*expression.Column); ok { + propertyItems = append(propertyItems, property.Item{Col: col, Desc: item.Desc}) + continue + } + proj.Exprs = append(proj.Exprs, it) + col := &expression.Column{ + ColName: model.NewCIStr(fmt.Sprintf("%d_proj_window_%d", p.ID(), schema.Len())), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: it.GetType(), + } + schema.Append(col) + propertyItems = append(propertyItems, property.Item{Col: col, Desc: item.Desc}) + } + + newArgList := make([]expression.Expression, 0, len(expr.Args)) + for _, arg := range expr.Args { + newArg, np, err := b.rewrite(arg, p, aggMap, true) + if err != nil { + return nil, nil, nil, err + } + p = np + if col, ok := newArg.(*expression.Column); ok { + newArgList = append(newArgList, col) + continue + } + proj.Exprs = append(proj.Exprs, newArg) + col := &expression.Column{ + ColName: model.NewCIStr(fmt.Sprintf("%d_proj_window_%d", p.ID(), schema.Len())), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: newArg.GetType(), + } + schema.Append(col) + newArgList = append(newArgList, col) + } + + proj.SetSchema(schema) + proj.SetChildren(p) + return proj, propertyItems, newArgList, nil +} + +func (b *PlanBuilder) buildWindowFunction(p LogicalPlan, expr *ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (*LogicalWindow, error) { + p, byItems, args, err := b.buildProjectionForWindow(p, expr, aggMap) + if err != nil { + return nil, err + } + + desc := aggregation.NewWindowFuncDesc(b.ctx, expr.F, args) + desc.WrapCastForAggArgs(b.ctx) + lenPartition := 0 + if expr.Spec.PartitionBy != nil { + lenPartition = len(expr.Spec.PartitionBy.Items) + } + window := LogicalWindow{ + WindowFuncDesc: desc, + PartitionBy: byItems[0:lenPartition], + OrderBy: byItems[lenPartition:], + }.Init(b.ctx) + schema := p.Schema().Clone() + schema.Append(&expression.Column{ + ColName: model.NewCIStr(fmt.Sprintf("%d_window_%d", window.id, p.Schema().Len())), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + IsReferenced: true, + RetType: desc.RetTp, + }) + window.SetChildren(p) + window.SetSchema(schema) + return window, nil +} + // extractTableList extracts all the TableNames from node. func extractTableList(node ast.ResultSetNode, input []*ast.TableName) []*ast.TableName { switch x := node.(type) { diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index d0c173e7ae374..d59b11f37acb7 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -1897,3 +1897,96 @@ func (s *testPlanSuite) TestSelectView(c *C) { c.Assert(ToString(p), Equals, tt.best, comment) } } + +func (s *testPlanSuite) TestWindowFunction(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + sql string + result string + }{ + { + sql: "select a, avg(a) over(partition by a) from t", + result: "TableReader(Table(t))->Window(avg(cast(test.t.a)))->Projection", + }, + { + sql: "select a, avg(a) over(partition by b) from t", + result: "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a)))->Projection", + }, + { + sql: "select a, avg(a+1) over(partition by (a+1)) from t", + result: "TableReader(Table(t))->Projection->Sort->Window(avg(cast(2_proj_window_3)))->Projection", + }, + { + sql: "select a, avg(a) over(order by a asc, b desc) from t order by a asc, b desc", + result: "TableReader(Table(t))->Sort->Window(avg(cast(test.t.a)))->Projection", + }, + { + sql: "select a, b as a, avg(a) over(partition by a) from t", + result: "TableReader(Table(t))->Window(avg(cast(test.t.a)))->Projection", + }, + { + sql: "select a, b as z, sum(z) over() from t", + result: "[planner:1054]Unknown column 'z' in 'field list'", + }, + { + sql: "select a, b as z from t order by (sum(z) over())", + result: "TableReader(Table(t))->Window(sum(cast(test.t.z)))->Sort->Projection", + }, + { + sql: "select sum(avg(a)) over() from t", + result: "TableReader(Table(t)->StreamAgg)->StreamAgg->Window(sum(sel_agg_2))->Projection", + }, + { + sql: "select b from t order by(sum(a) over())", + result: "TableReader(Table(t))->Window(sum(cast(test.t.a)))->Sort->Projection", + }, + { + sql: "select b from t order by(sum(a) over(partition by a))", + result: "TableReader(Table(t))->Window(sum(cast(test.t.a)))->Sort->Projection", + }, + { + sql: "select b from t order by(sum(avg(a)) over())", + result: "TableReader(Table(t)->StreamAgg)->StreamAgg->Window(sum(sel_agg_2))->Sort->Projection", + }, + { + sql: "select a from t having (select sum(a) over() as w from t tt where a > t.a)", + result: "Apply{TableReader(Table(t))->TableReader(Table(t)->Sel([gt(tt.a, test.t.a)]))->Window(sum(cast(tt.a)))->MaxOneRow->Sel([w])}->Projection", + }, + { + sql: "select avg(a) over() as w from t having w > 1", + result: "[planner:3594]You cannot use the alias 'w' of an expression containing a window function in this context.'", + }, + { + sql: "select sum(a) over() as sum_a from t group by sum_a", + result: "[planner:1247]Reference 'sum_a' not supported (reference to window function)", + }, + } + + s.Parser.EnableWindowFunc(true) + defer func() { + s.Parser.EnableWindowFunc(false) + }() + for i, tt := range tests { + comment := Commentf("case:%v sql:%s", i, tt.sql) + stmt, err := s.ParseOneStmt(tt.sql, "", "") + c.Assert(err, IsNil, comment) + Preprocess(s.ctx, stmt, s.is, false) + builder := &PlanBuilder{ + ctx: MockContext(), + is: s.is, + colMapper: make(map[*ast.ColumnNameExpr]int), + } + p, err := builder.Build(stmt) + if err != nil { + c.Assert(err.Error(), Equals, tt.result, comment) + continue + } + c.Assert(err, IsNil) + p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) + c.Assert(err, IsNil) + lp, ok := p.(LogicalPlan) + c.Assert(ok, IsTrue) + p, err = physicalOptimize(lp) + c.Assert(ToString(p), Equals, tt.result, comment) + } +} diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index aa2cc510b67c9..b6f7463ee3d65 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" + "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" @@ -42,6 +43,7 @@ var ( _ LogicalPlan = &LogicalSort{} _ LogicalPlan = &LogicalLock{} _ LogicalPlan = &LogicalLimit{} + _ LogicalPlan = &LogicalWindow{} ) // JoinType contains CrossJoin, InnerJoin, LeftOuterJoin, RightOuterJoin, FullOuterJoin, SemiJoin. @@ -617,3 +619,18 @@ type LogicalLock struct { Lock ast.SelectLockType } + +// LogicalWindow represents a logical window function plan. +type LogicalWindow struct { + logicalSchemaProducer + + WindowFuncDesc *aggregation.WindowFuncDesc + PartitionBy []property.Item + OrderBy []property.Item + // TODO: add frame clause +} + +// GetWindowResultColumn returns the column storing the result of the window function. +func (p *LogicalWindow) GetWindowResultColumn() *expression.Column { + return p.schema.Columns[p.schema.Len()-1] +} diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 4c66330fabdfe..bb9f6188c2c33 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -18,6 +18,7 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" + "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/types" @@ -47,6 +48,7 @@ var ( _ PhysicalPlan = &PhysicalHashJoin{} _ PhysicalPlan = &PhysicalMergeJoin{} _ PhysicalPlan = &PhysicalUnionScan{} + _ PhysicalPlan = &PhysicalWindow{} ) // PhysicalTableReader is the table reader in tidb. @@ -373,3 +375,13 @@ type PhysicalTableDual struct { RowCount int } + +// PhysicalWindow is the physical operator of window function. +type PhysicalWindow struct { + physicalSchemaProducer + + WindowFuncDesc *aggregation.WindowFuncDesc + PartitionBy []property.Item + OrderBy []property.Item + ChildCols []*expression.Column +} diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 87ec96a350385..ee5708b95b2ae 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -93,6 +93,7 @@ const ( onClause orderByClause whereClause + windowClause groupByClause showStatement globalOrderByClause @@ -108,6 +109,7 @@ var clauseMsg = map[clauseCode]string{ groupByClause: "group statement", showStatement: "show statement", globalOrderByClause: "global ORDER clause", + windowClause: "field list", // For window functions that in field list. } // PlanBuilder builds Plan from an ast.Node. @@ -300,6 +302,15 @@ func (b *PlanBuilder) detectSelectAgg(sel *ast.SelectStmt) bool { return false } +func (b *PlanBuilder) detectSelectWindow(sel *ast.SelectStmt) bool { + for _, f := range sel.Fields.Fields { + if ast.HasWindowFlag(f.Expr) { + return true + } + } + return false +} + func getPathByIndexName(paths []*accessPath, idxName model.CIStr, tblInfo *model.TableInfo) *accessPath { var tablePath *accessPath for _, path := range paths { diff --git a/planner/core/resolve_indices.go b/planner/core/resolve_indices.go index 27562bc21838b..23c0a3e07bb9f 100644 --- a/planner/core/resolve_indices.go +++ b/planner/core/resolve_indices.go @@ -184,6 +184,27 @@ func (p *PhysicalSort) ResolveIndices() { } } +// ResolveIndices implements Plan interface. +func (p *PhysicalWindow) ResolveIndices() { + p.physicalSchemaProducer.ResolveIndices() + p.ChildCols = p.Schema().Columns[:len(p.Schema().Columns)-1] + for i, col := range p.ChildCols { + newCol := col.ResolveIndices(p.children[0].Schema()) + p.ChildCols[i] = newCol.(*expression.Column) + } + for i, item := range p.PartitionBy { + newCol := item.Col.ResolveIndices(p.children[0].Schema()) + p.PartitionBy[i].Col = newCol.(*expression.Column) + } + for i, item := range p.OrderBy { + newCol := item.Col.ResolveIndices(p.children[0].Schema()) + p.OrderBy[i].Col = newCol.(*expression.Column) + } + for i, arg := range p.WindowFuncDesc.Args { + p.WindowFuncDesc.Args[i] = arg.ResolveIndices(p.children[0].Schema()) + } +} + // ResolveIndices implements Plan interface. func (p *PhysicalTopN) ResolveIndices() { p.basePhysicalPlan.ResolveIndices() diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 05b239ee3baf7..07f6e7fd7b38c 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -15,7 +15,6 @@ package core import ( "fmt" - "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/expression" @@ -265,3 +264,33 @@ func (p *LogicalLock) PruneColumns(parentUsedCols []*expression.Column) { p.children[0].PruneColumns(parentUsedCols) } } + +// PruneColumns implements LogicalPlan interface. +func (p *LogicalWindow) PruneColumns(parentUsedCols []*expression.Column) { + windowColumn := p.GetWindowResultColumn() + len := 0 + for _, col := range parentUsedCols { + if !windowColumn.Equal(nil, col) { + parentUsedCols[len] = col + len++ + } + } + parentUsedCols = parentUsedCols[:len] + parentUsedCols = p.extractUsedCols(parentUsedCols) + p.children[0].PruneColumns(parentUsedCols) + p.SetSchema(p.children[0].Schema().Clone()) + p.Schema().Append(windowColumn) +} + +func (p *LogicalWindow) extractUsedCols(parentUsedCols []*expression.Column) []*expression.Column { + for _, arg := range p.WindowFuncDesc.Args { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(arg)...) + } + for _, by := range p.PartitionBy { + parentUsedCols = append(parentUsedCols, by.Col) + } + for _, by := range p.OrderBy { + parentUsedCols = append(parentUsedCols, by.Col) + } + return parentUsedCols +} diff --git a/planner/core/stats.go b/planner/core/stats.go index 3d39c7fa8473c..32972f6132a33 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -328,3 +328,16 @@ func (p *LogicalMaxOneRow) DeriveStats(childStats []*property.StatsInfo) (*prope p.stats = getSingletonStats(p.Schema().Len()) return p.stats, nil } + +// DeriveStats implement LogicalPlan DeriveStats interface. +func (p *LogicalWindow) DeriveStats(childStats []*property.StatsInfo) (*property.StatsInfo, error) { + childProfile := childStats[0] + childLen := len(childProfile.Cardinality) + p.stats = &property.StatsInfo{ + RowCount: childProfile.RowCount, + Cardinality: make([]float64, childLen+1), + } + copy(p.stats.Cardinality, childProfile.Cardinality) + p.stats.Cardinality[childLen] = childProfile.RowCount + return p.stats, nil +} diff --git a/planner/core/stringer.go b/planner/core/stringer.go index 6cd078354469a..62459b66d4524 100644 --- a/planner/core/stringer.go +++ b/planner/core/stringer.go @@ -220,6 +220,10 @@ func toString(in Plan, strs []string, idxs []int) ([]string, []int) { if x.SelectPlan != nil { str = fmt.Sprintf("%s->Insert", ToString(x.SelectPlan)) } + case *LogicalWindow: + str = fmt.Sprintf("Window(%s)", x.WindowFuncDesc.String()) + case *PhysicalWindow: + str = fmt.Sprintf("Window(%s)", x.WindowFuncDesc.String()) default: str = fmt.Sprintf("%T", in) } diff --git a/util/chunk/chunk.go b/util/chunk/chunk.go index 3b91cf3bb693d..85e216e784cf6 100644 --- a/util/chunk/chunk.go +++ b/util/chunk/chunk.go @@ -187,6 +187,11 @@ func (c *Chunk) SwapColumns(other *Chunk) { c.numVirtualRows, other.numVirtualRows = other.numVirtualRows, c.numVirtualRows } +// CopyColumns copies columns `other.columns[from]` to `c.columns[dst]`. +func (c *Chunk) CopyColumns(other *Chunk, dst, from int) { + c.columns[dst] = other.columns[from] +} + // SetNumVirtualRows sets the virtual row number for a Chunk. // It should only be used when there exists no column in the Chunk. func (c *Chunk) SetNumVirtualRows(numVirtualRows int) { @@ -523,3 +528,8 @@ func readTime(buf []byte) types.Time { Fsp: fsp, } } + +// RemainedRows returns the number of rows needs to be appended in specific column. +func (c *Chunk) RemainedRows(colIdx int) int { + return c.columns[0].length - c.columns[colIdx].length +}