diff --git a/expression/aggregation/window_func.go b/expression/aggregation/window_func.go new file mode 100644 index 0000000000000..9555f900bdb73 --- /dev/null +++ b/expression/aggregation/window_func.go @@ -0,0 +1,29 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregation + +import ( + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/sessionctx" +) + +// WindowFuncDesc describes a window function signature, only used in planner. +type WindowFuncDesc struct { + baseFuncDesc +} + +// NewWindowFuncDesc creates a window function signature descriptor. 
+func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) *WindowFuncDesc { + return &WindowFuncDesc{newBaseFuncDesc(ctx, name, args)} +} diff --git a/go.mod b/go.mod index c7e83e66896cf..aac0314e95b31 100644 --- a/go.mod +++ b/go.mod @@ -48,7 +48,7 @@ require ( github.com/pingcap/gofail v0.0.0-20181217135706-6a951c1e42c3 github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e github.com/pingcap/kvproto v0.0.0-20181203065228-c14302da291c - github.com/pingcap/parser v0.0.0-20181225032741-ff56f7f11ed6 + github.com/pingcap/parser v0.0.0-20190103075927-c065c7404641 github.com/pingcap/pd v2.1.0-rc.4+incompatible github.com/pingcap/tidb-tools v2.1.1-0.20181218072513-b2235d442b06+incompatible github.com/pingcap/tipb v0.0.0-20181012112600-11e33c750323 diff --git a/go.sum b/go.sum index 7488fcf41cf2f..9fc08d73e161b 100644 --- a/go.sum +++ b/go.sum @@ -110,8 +110,8 @@ github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e h1:P73/4dPCL96rG github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e/go.mod h1:O17XtbryoCJhkKGbT62+L2OlrniwqiGLSqrmdHCMzZw= github.com/pingcap/kvproto v0.0.0-20181203065228-c14302da291c h1:Qf5St5XGwKgKQLar9lEXoeO0hJMVaFBj3JqvFguWtVg= github.com/pingcap/kvproto v0.0.0-20181203065228-c14302da291c/go.mod h1:Ja9XPjot9q4/3JyCZodnWDGNXt4pKemhIYCvVJM7P24= -github.com/pingcap/parser v0.0.0-20181225032741-ff56f7f11ed6 h1:ooapyJxH6uSHNvpYjPOggHtd2dzLKwSvYLVzs3OjoM0= -github.com/pingcap/parser v0.0.0-20181225032741-ff56f7f11ed6/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= +github.com/pingcap/parser v0.0.0-20190103075927-c065c7404641 h1:KTGU8kr2wY+FRiHHs8I5lp385b+OzYnbOr3/tPVw7mU= +github.com/pingcap/parser v0.0.0-20190103075927-c065c7404641/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= github.com/pingcap/pd v2.1.0-rc.4+incompatible h1:/buwGk04aHO5odk/+O8ZOXGs4qkUjYTJ2UpCJXna8NE= github.com/pingcap/pd v2.1.0-rc.4+incompatible/go.mod h1:nD3+EoYes4+aNNODO99ES59V83MZSI+dFbhyr667a0E= 
github.com/pingcap/tidb-tools v2.1.1-0.20181218072513-b2235d442b06+incompatible h1:Bsd+NHosPVowEGB3BCx+2d8wUQGDTXSSC5ljeNS6cXo= diff --git a/planner/core/errors.go b/planner/core/errors.go index 6d029c490f533..b49a3c9719328 100644 --- a/planner/core/errors.go +++ b/planner/core/errors.go @@ -51,6 +51,9 @@ const ( codeWrongNumberOfColumnsInSelect = mysql.ErrWrongNumberOfColumnsInSelect codeWrongValueCountOnRow = mysql.ErrWrongValueCountOnRow codeTablenameNotAllowedHere = mysql.ErrTablenameNotAllowedHere + + codeWindowInvalidWindowFuncUse = mysql.ErrWindowInvalidWindowFuncUse + codeWindowInvalidWindowFuncAliasUse = mysql.ErrWindowInvalidWindowFuncAliasUse ) // error definitions. @@ -88,6 +91,9 @@ var ( ErrNonUniqTable = terror.ClassOptimizer.New(codeNonUniqTable, mysql.MySQLErrName[mysql.ErrNonuniqTable]) ErrWrongValueCountOnRow = terror.ClassOptimizer.New(mysql.ErrWrongValueCountOnRow, mysql.MySQLErrName[mysql.ErrWrongValueCountOnRow]) ErrViewInvalid = terror.ClassOptimizer.New(mysql.ErrViewInvalid, mysql.MySQLErrName[mysql.ErrViewInvalid]) + + ErrWindowInvalidWindowFuncUse = terror.ClassOptimizer.New(codeWindowInvalidWindowFuncUse, mysql.MySQLErrName[mysql.ErrWindowInvalidWindowFuncUse]) + ErrWindowInvalidWindowFuncAliasUse = terror.ClassOptimizer.New(codeWindowInvalidWindowFuncAliasUse, mysql.MySQLErrName[mysql.ErrWindowInvalidWindowFuncAliasUse]) ) func init() { @@ -115,6 +121,9 @@ func init() { codeNonUniqTable: mysql.ErrNonuniqTable, codeWrongNumberOfColumnsInSelect: mysql.ErrWrongNumberOfColumnsInSelect, codeWrongValueCountOnRow: mysql.ErrWrongValueCountOnRow, + + codeWindowInvalidWindowFuncUse: mysql.ErrWindowInvalidWindowFuncUse, + codeWindowInvalidWindowFuncAliasUse: mysql.ErrWindowInvalidWindowFuncAliasUse, } terror.ErrClassToMySQLCodes[terror.ClassOptimizer] = mysqlErrCodeMap } diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index b77d7b2d244e2..ee8796fb19994 100644 --- a/planner/core/exhaust_physical_plans.go 
+++ b/planner/core/exhaust_physical_plans.go @@ -789,6 +789,18 @@ func (la *LogicalApply) exhaustPhysicalPlans(prop *property.PhysicalProperty) [] return []PhysicalPlan{apply} } +func (p *LogicalWindow) exhaustPhysicalPlans(prop *property.PhysicalProperty) []PhysicalPlan { + childProperty := &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, Items: p.ByItems, Enforced: true} + if !prop.IsPrefix(childProperty) { + return nil + } + window := PhysicalWindow{ + WindowFuncDesc: p.WindowFuncDesc, + }.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProperty) + window.SetSchema(p.Schema()) + return []PhysicalPlan{window} +} + // exhaustPhysicalPlans is only for implementing interface. DataSource and Dual generate task in `findBestTask` directly. func (p *baseLogicalPlan) exhaustPhysicalPlans(_ *property.PhysicalProperty) []PhysicalPlan { panic("baseLogicalPlan.exhaustPhysicalPlans() should never be called.") diff --git a/planner/core/explain.go b/planner/core/explain.go index f4b37efd869d1..44e9ef4c81099 100644 --- a/planner/core/explain.go +++ b/planner/core/explain.go @@ -297,3 +297,9 @@ func (p *PhysicalTopN) ExplainInfo() string { fmt.Fprintf(buffer, ", offset:%v, count:%v", p.Offset, p.Count) return buffer.String() } + +// ExplainInfo implements PhysicalPlan interface. +func (p *PhysicalWindow) ExplainInfo() string { + // TODO: Add explain info for partition by, order by and frame. 
+ return p.WindowFuncDesc.String() +} diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index c25a63f7a4c7e..4f00fbb23c962 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -305,12 +305,25 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { } er.ctxStack = append(er.ctxStack, expression.NewValuesFunc(er.ctx, col.Index, col.RetType)) return inNode, true + case *ast.WindowFuncExpr: + return er.handleWindowFunction(v) default: er.asScalar = true } return inNode, false } +func (er *expressionRewriter) handleWindowFunction(v *ast.WindowFuncExpr) (ast.Node, bool) { + windowPlan, err := er.b.buildWindowFunction(er.p, v, er.aggrMap) + if err != nil { + er.err = err + return v, false + } + er.ctxStack = append(er.ctxStack, windowPlan.GetWindowResultColumn()) + er.p = windowPlan + return v, true +} + func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) (ast.Node, bool) { v.L.Accept(er) if er.err != nil { @@ -753,7 +766,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok } switch v := inNode.(type) { case *ast.AggregateFuncExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause, - *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr: + *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr, *ast.WindowFuncExpr: case *driver.ValueExpr: value := &expression.Constant{Value: v.Datum, RetType: &v.Type} er.ctxStack = append(er.ctxStack, value) diff --git a/planner/core/initialize.go b/planner/core/initialize.go index 9c2b698c927f0..ff2a3dbc7b059 100644 --- a/planner/core/initialize.go +++ b/planner/core/initialize.go @@ -81,6 +81,8 @@ const ( TypeTableReader = "TableReader" // TypeIndexReader is the type of IndexReader. TypeIndexReader = "IndexReader" + // TypeWindow is the type of Window. 
+ TypeWindow = "Window" ) // Init initializes LogicalAggregation. @@ -231,6 +233,20 @@ func (p PhysicalMaxOneRow) Init(ctx sessionctx.Context, stats *property.StatsInf return &p } +// Init initializes LogicalWindow. +func (p LogicalWindow) Init(ctx sessionctx.Context) *LogicalWindow { + p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeWindow, &p) + return &p +} + +// Init initializes PhysicalWindow. +func (p PhysicalWindow) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalWindow { + p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeWindow, &p) + p.childrenReqProps = props + p.stats = stats + return &p +} + // Init initializes Update. func (p Update) Init(ctx sessionctx.Context) *Update { p.basePlan = newBasePlan(ctx, TypeUpdate) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 82a7788ec0032..20e76e036e787 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -33,6 +33,7 @@ import ( "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" @@ -596,13 +597,35 @@ func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectFi } // buildProjection returns a Projection plan and non-aux columns length. 
-func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int) (LogicalPlan, int, error) { +func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int, considerWindow bool) (LogicalPlan, int, error) { b.optFlag |= flagEliminateProjection b.curClause = fieldList proj := LogicalProjection{Exprs: make([]expression.Expression, 0, len(fields))}.Init(b.ctx) schema := expression.NewSchema(make([]*expression.Column, 0, len(fields))...) oldLen := 0 - for _, field := range fields { + for i, field := range fields { + if !field.Auxiliary { + oldLen++ + } + + isWindowFuncField := ast.HasWindowFlag(field.Expr) + // Although window functions occur in the select fields, they have to be processed after the having clause. + // So when we build the projection for select fields, we need to skip the window functions. + // When `considerWindow` is false, we will only build fields for non-window functions, so we add fake placeholders + // for window functions. These fake placeholders will be erased in column pruning. + // When `considerWindow` is true, all the non-window fields have been built, so we just use the schema columns. 
+ if (considerWindow && !isWindowFuncField) || (!considerWindow && isWindowFuncField) { + var expr expression.Expression + if isWindowFuncField { + expr = expression.Zero + } else { + expr = p.Schema().Columns[i] + } + proj.Exprs = append(proj.Exprs, expr) + col := b.buildProjectionField(proj.id, schema.Len()+1, field, expr) + schema.Append(col) + continue + } newExpr, np, err := b.rewrite(field.Expr, p, mapper, true) if err != nil { return nil, 0, errors.Trace(err) @@ -613,10 +636,6 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, col := b.buildProjectionField(proj.id, schema.Len()+1, field, newExpr) schema.Append(col) - - if !field.Auxiliary { - oldLen++ - } } proj.SetSchema(schema) proj.SetChildren(p) @@ -999,10 +1018,11 @@ func resolveFromSelectFields(v *ast.ColumnNameExpr, fields []*ast.SelectField, i return } -// havingAndOrderbyExprResolver visits Expr tree. +// havingWindowAndOrderbyExprResolver visits Expr tree. // It converts ColunmNameExpr to AggregateFuncExpr and collects AggregateFuncExpr. -type havingAndOrderbyExprResolver struct { +type havingWindowAndOrderbyExprResolver struct { inAggFunc bool + inWindowFunc bool inExpr bool orderBy bool err error @@ -1016,10 +1036,12 @@ type havingAndOrderbyExprResolver struct { } // Enter implements Visitor interface. -func (a *havingAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChildren bool) { +func (a *havingWindowAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChildren bool) { switch n.(type) { case *ast.AggregateFuncExpr: a.inAggFunc = true + case *ast.WindowFuncExpr: + a.inWindowFunc = true case *driver.ParamMarkerExpr, *ast.ColumnNameExpr, *ast.ColumnName: case *ast.SubqueryExpr, *ast.ExistsSubqueryExpr: // Enter a new context, skip it. 
@@ -1031,7 +1053,7 @@ func (a *havingAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChi return n, false } -func (a *havingAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, schema *expression.Schema) (int, error) { +func (a *havingWindowAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, schema *expression.Schema) (int, error) { col, err := schema.FindColumn(v.Name) if err != nil { return -1, errors.Trace(err) @@ -1045,7 +1067,7 @@ func (a *havingAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, Name: col.ColName, } for i, field := range a.selectFields { - if c, ok := field.Expr.(*ast.ColumnNameExpr); ok && colMatch(newColName, c.Name) { + if c, ok := field.Expr.(*ast.ColumnNameExpr); ok && colMatch(c.Name, newColName) { return i, nil } } @@ -1059,7 +1081,7 @@ func (a *havingAndOrderbyExprResolver) resolveFromSchema(v *ast.ColumnNameExpr, } // Leave implements Visitor interface. -func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool) { +func (a *havingWindowAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool) { switch v := n.(type) { case *ast.AggregateFuncExpr: a.inAggFunc = false @@ -1069,9 +1091,15 @@ func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool Expr: v, AsName: model.NewCIStr(fmt.Sprintf("sel_agg_%d", len(a.selectFields))), }) + case *ast.WindowFuncExpr: + a.inWindowFunc = false + if a.curClause == havingClause { + a.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(v.F) + return node, false + } case *ast.ColumnNameExpr: resolveFieldsFirst := true - if a.inAggFunc || (a.orderBy && a.inExpr) { + if a.inAggFunc || a.inWindowFunc || (a.orderBy && a.inExpr) { resolveFieldsFirst = false } if !a.inAggFunc && !a.orderBy { @@ -1089,6 +1117,10 @@ func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool if a.err != nil { return node, false } + if index != -1 && a.curClause == havingClause && 
ast.HasWindowFlag(a.selectFields[index].Expr) { + a.err = ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O) + return node, false + } if index == -1 { if a.orderBy { index, a.err = a.resolveFromSchema(v, a.p.Schema()) @@ -1102,8 +1134,12 @@ func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool var err error index, err = a.resolveFromSchema(v, a.p.Schema()) _ = err - if index == -1 { + if index == -1 && a.curClause != windowClause { index, a.err = resolveFromSelectFields(v, a.selectFields, false) + if index != -1 && a.curClause == havingClause && ast.HasWindowFlag(a.selectFields[index].Expr) { + a.err = ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O) + return node, false + } } } if a.err != nil { @@ -1137,7 +1173,7 @@ func (a *havingAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool // When we rewrite the order by / having expression, we will find column in map at first. func (b *PlanBuilder) resolveHavingAndOrderBy(sel *ast.SelectStmt, p LogicalPlan) ( map[*ast.AggregateFuncExpr]int, map[*ast.AggregateFuncExpr]int, error) { - extractor := &havingAndOrderbyExprResolver{ + extractor := &havingWindowAndOrderbyExprResolver{ p: p, selectFields: sel.Fields.Fields, aggMapper: make(map[*ast.AggregateFuncExpr]int), @@ -1190,6 +1226,31 @@ func (b *PlanBuilder) extractAggFuncs(fields []*ast.SelectField) ([]*ast.Aggrega return aggList, totalAggMapper } +// resolveWindowFunction will process window functions and resolve the columns that don't exist in select fields. 
+func (b *PlanBuilder) resolveWindowFunction(sel *ast.SelectStmt, p LogicalPlan) ( + map[*ast.AggregateFuncExpr]int, error) { + extractor := &havingWindowAndOrderbyExprResolver{ + p: p, + selectFields: sel.Fields.Fields, + aggMapper: make(map[*ast.AggregateFuncExpr]int), + colMapper: b.colMapper, + outerSchemas: b.outerSchemas, + } + extractor.curClause = windowClause + for _, field := range sel.Fields.Fields { + if !ast.HasWindowFlag(field.Expr) { + continue + } + n, ok := field.Expr.Accept(extractor) + if !ok { + return nil, extractor.err + } + field.Expr = n.(ast.ExprNode) + } + sel.Fields.Fields = extractor.selectFields + return extractor.aggMapper, nil +} + // gbyResolver resolves group by items from select fields. type gbyResolver struct { ctx sessionctx.Context @@ -1234,6 +1295,8 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { ret.Accept(extractor) if len(extractor.AggFuncs) != 0 { err = ErrIllegalReference.GenWithStackByArgs(v.Name.OrigColName(), "reference to group function") + } else if ast.HasWindowFlag(ret) { + err = ErrIllegalReference.GenWithStackByArgs(v.Name.OrigColName(), "reference to window function") } else { return ret, true } @@ -1727,6 +1790,7 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error var ( aggFuncs []*ast.AggregateFuncExpr havingMap, orderMap, totalMap map[*ast.AggregateFuncExpr]int + windowMap map[*ast.AggregateFuncExpr]int gbyCols []expression.Expression ) @@ -1759,6 +1823,13 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } } + hasWindowFuncField := b.detectSelectWindow(sel) + if hasWindowFuncField { + windowMap, err = b.resolveWindowFunction(sel, p) + if err != nil { + return nil, err + } + } // We must resolve having and order by clause before build projection, // because when the query is "select a+1 as b from t having sum(b) < 0", we must replace sum(b) to sum(a+1), // which only can be done before building projection and extracting Agg 
functions. @@ -1792,7 +1863,9 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } var oldLen int - p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, totalMap) + // According to https://dev.mysql.com/doc/refman/8.0/en/window-functions-usage.html, + // we can only process window functions after having clause, so `considerWindow` is false now. + p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, totalMap, false) if err != nil { return nil, errors.Trace(err) } @@ -1805,6 +1878,14 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } } + if hasWindowFuncField { + // Now we build the window function fields. + p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, windowMap, true) + if err != nil { + return nil, err + } + } + if sel.Distinct { p = b.buildDistinct(p, oldLen) } @@ -2503,6 +2584,100 @@ func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { return del, nil } +// buildProjectionForWindow builds the projection for expressions in the window specification that are not columns, +// so after the projection, window functions only need to deal with columns. +func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, expr *ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, []property.Item, []expression.Expression, error) { + b.optFlag |= flagEliminateProjection + + var items []*ast.ByItem + spec := expr.Spec + if spec.PartitionBy != nil { + items = append(items, spec.PartitionBy.Items...) + } + if spec.OrderBy != nil { + items = append(items, spec.OrderBy.Items...) + } + projLen := len(p.Schema().Columns) + len(items) + len(expr.Args) + proj := LogicalProjection{Exprs: make([]expression.Expression, 0, projLen)}.Init(b.ctx) + schema := expression.NewSchema(make([]*expression.Column, 0, projLen)...) 
+ for _, col := range p.Schema().Columns { + proj.Exprs = append(proj.Exprs, col) + schema.Append(col) + } + + transformer := &itemTransformer{} + propertyItems := make([]property.Item, 0, len(items)) + for _, item := range items { + newExpr, _ := item.Expr.Accept(transformer) + item.Expr = newExpr.(ast.ExprNode) + it, np, err := b.rewrite(item.Expr, p, aggMap, true) + if err != nil { + return nil, nil, nil, err + } + p = np + if col, ok := it.(*expression.Column); ok { + propertyItems = append(propertyItems, property.Item{Col: col, Desc: item.Desc}) + continue + } + proj.Exprs = append(proj.Exprs, it) + col := &expression.Column{ + ColName: model.NewCIStr(fmt.Sprintf("%d_proj_window_%d", p.ID(), schema.Len())), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: it.GetType(), + } + schema.Append(col) + propertyItems = append(propertyItems, property.Item{Col: col, Desc: item.Desc}) + } + + newArgList := make([]expression.Expression, 0, len(expr.Args)) + for _, arg := range expr.Args { + newArg, np, err := b.rewrite(arg, p, aggMap, true) + if err != nil { + return nil, nil, nil, err + } + p = np + if col, ok := newArg.(*expression.Column); ok { + newArgList = append(newArgList, col) + continue + } + proj.Exprs = append(proj.Exprs, newArg) + col := &expression.Column{ + ColName: model.NewCIStr(fmt.Sprintf("%d_proj_window_%d", p.ID(), schema.Len())), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: newArg.GetType(), + } + schema.Append(col) + newArgList = append(newArgList, col) + } + + proj.SetSchema(schema) + proj.SetChildren(p) + return proj, propertyItems, newArgList, nil +} + +func (b *PlanBuilder) buildWindowFunction(p LogicalPlan, expr *ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (*LogicalWindow, error) { + p, byItems, args, err := b.buildProjectionForWindow(p, expr, aggMap) + if err != nil { + return nil, err + } + + desc := aggregation.NewWindowFuncDesc(b.ctx, expr.F, args) + window := LogicalWindow{ + 
WindowFuncDesc: desc, + ByItems: byItems, + }.Init(b.ctx) + schema := p.Schema().Clone() + schema.Append(&expression.Column{ + ColName: model.NewCIStr(fmt.Sprintf("%d_window_%d", window.id, p.Schema().Len())), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + IsReferenced: true, + RetType: desc.RetTp, + }) + window.SetChildren(p) + window.SetSchema(schema) + return window, nil +} + // extractTableList extracts all the TableNames from node. func extractTableList(node ast.ResultSetNode, input []*ast.TableName) []*ast.TableName { switch x := node.(type) { diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index d0c173e7ae374..43a00e4cff1c1 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -1897,3 +1897,96 @@ func (s *testPlanSuite) TestSelectView(c *C) { c.Assert(ToString(p), Equals, tt.best, comment) } } + +func (s *testPlanSuite) TestWindowFunction(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + sql string + result string + }{ + { + sql: "select a, avg(a) over(partition by a) from t", + result: "TableReader(Table(t))->Window(avg(test.t.a))->Projection", + }, + { + sql: "select a, avg(a) over(partition by b) from t", + result: "TableReader(Table(t))->Sort->Window(avg(test.t.a))->Projection", + }, + { + sql: "select a, avg(a+1) over(partition by (a+1)) from t", + result: "TableReader(Table(t))->Projection->Sort->Window(avg(2_proj_window_3))->Projection", + }, + { + sql: "select a, avg(a) over(order by a asc, b desc) from t order by a asc, b desc", + result: "TableReader(Table(t))->Sort->Window(avg(test.t.a))->Projection", + }, + { + sql: "select a, b as a, avg(a) over(partition by a) from t", + result: "TableReader(Table(t))->Window(avg(test.t.a))->Projection", + }, + { + sql: "select a, b as z, sum(z) over() from t", + result: "[planner:1054]Unknown column 'z' in 'field list'", + }, + { + sql: "select a, b as z from t order by (sum(z) over())", + result: 
"TableReader(Table(t))->Window(sum(test.t.z))->Sort->Projection", + }, + { + sql: "select sum(avg(a)) over() from t", + result: "TableReader(Table(t)->StreamAgg)->StreamAgg->Window(sum(sel_agg_2))->Projection", + }, + { + sql: "select b from t order by(sum(a) over())", + result: "TableReader(Table(t))->Window(sum(test.t.a))->Sort->Projection", + }, + { + sql: "select b from t order by(sum(a) over(partition by a))", + result: "TableReader(Table(t))->Window(sum(test.t.a))->Sort->Projection", + }, + { + sql: "select b from t order by(sum(avg(a)) over())", + result: "TableReader(Table(t)->StreamAgg)->StreamAgg->Window(sum(sel_agg_2))->Sort->Projection", + }, + { + sql: "select a from t having (select sum(a) over() as w from t tt where a > t.a)", + result: "Apply{TableReader(Table(t))->TableReader(Table(t)->Sel([gt(tt.a, test.t.a)]))->Window(sum(tt.a))->MaxOneRow->Sel([w])}->Projection", + }, + { + sql: "select avg(a) over() as w from t having w > 1", + result: "[planner:3594]You cannot use the alias 'w' of an expression containing a window function in this context.'", + }, + { + sql: "select sum(a) over() as sum_a from t group by sum_a", + result: "[planner:1247]Reference 'sum_a' not supported (reference to window function)", + }, + } + + s.Parser.EnableWindowFunc(true) + defer func() { + s.Parser.EnableWindowFunc(false) + }() + for i, tt := range tests { + comment := Commentf("case:%v sql:%s", i, tt.sql) + stmt, err := s.ParseOneStmt(tt.sql, "", "") + c.Assert(err, IsNil, comment) + Preprocess(s.ctx, stmt, s.is, false) + builder := &PlanBuilder{ + ctx: MockContext(), + is: s.is, + colMapper: make(map[*ast.ColumnNameExpr]int), + } + p, err := builder.Build(stmt) + if err != nil { + c.Assert(err.Error(), Equals, tt.result, comment) + continue + } + c.Assert(err, IsNil) + p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) + c.Assert(err, IsNil) + lp, ok := p.(LogicalPlan) + c.Assert(ok, IsTrue) + p, err = physicalOptimize(lp) + c.Assert(ToString(p), Equals, 
tt.result, comment) + } +} diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index aa2cc510b67c9..1a625a93ad709 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" + "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" @@ -42,6 +43,7 @@ var ( _ LogicalPlan = &LogicalSort{} _ LogicalPlan = &LogicalLock{} _ LogicalPlan = &LogicalLimit{} + _ LogicalPlan = &LogicalWindow{} ) // JoinType contains CrossJoin, InnerJoin, LeftOuterJoin, RightOuterJoin, FullOuterJoin, SemiJoin. @@ -617,3 +619,17 @@ type LogicalLock struct { Lock ast.SelectLockType } + +// LogicalWindow represents a logical window function plan. +type LogicalWindow struct { + logicalSchemaProducer + + WindowFuncDesc *aggregation.WindowFuncDesc + ByItems []property.Item // ByItems is composed of `PARTITION BY` and `ORDER BY` items. + // TODO: add frame clause +} + +// GetWindowResultColumn returns the column storing the result of the window function. +func (p *LogicalWindow) GetWindowResultColumn() *expression.Column { + return p.schema.Columns[p.schema.Len()-1] +} diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 4c66330fabdfe..4f404e88d9f2c 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -47,6 +47,7 @@ var ( _ PhysicalPlan = &PhysicalHashJoin{} _ PhysicalPlan = &PhysicalMergeJoin{} _ PhysicalPlan = &PhysicalUnionScan{} + _ PhysicalPlan = &PhysicalWindow{} ) // PhysicalTableReader is the table reader in tidb. @@ -373,3 +374,10 @@ type PhysicalTableDual struct { RowCount int } + +// PhysicalWindow is the physical operator of window function. 
+type PhysicalWindow struct { + physicalSchemaProducer + + WindowFuncDesc *aggregation.WindowFuncDesc +} diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index b5dc12c6ab878..8d0d44ec576a3 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -93,6 +93,7 @@ const ( onClause orderByClause whereClause + windowClause groupByClause showStatement globalOrderByClause @@ -108,6 +109,7 @@ var clauseMsg = map[clauseCode]string{ groupByClause: "group statement", showStatement: "show statement", globalOrderByClause: "global ORDER clause", + windowClause: "field list", // For window functions that are in the field list. } // PlanBuilder builds Plan from an ast.Node. @@ -300,6 +302,15 @@ func (b *PlanBuilder) detectSelectAgg(sel *ast.SelectStmt) bool { return false } +func (b *PlanBuilder) detectSelectWindow(sel *ast.SelectStmt) bool { + for _, f := range sel.Fields.Fields { + if ast.HasWindowFlag(f.Expr) { + return true + } + } + return false +} + func getPathByIndexName(paths []*accessPath, idxName model.CIStr, tblInfo *model.TableInfo) *accessPath { var tablePath *accessPath for _, path := range paths { diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 05b239ee3baf7..bf2d86a916575 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -15,7 +15,6 @@ package core import ( "fmt" - "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/expression" @@ -265,3 +264,30 @@ func (p *LogicalLock) PruneColumns(parentUsedCols []*expression.Column) { p.children[0].PruneColumns(parentUsedCols) } } + +// PruneColumns implements LogicalPlan interface. 
+func (p *LogicalWindow) PruneColumns(parentUsedCols []*expression.Column) { + windowColumn := p.GetWindowResultColumn() + len := 0 + for _, col := range parentUsedCols { + if !windowColumn.Equal(nil, col) { + parentUsedCols[len] = col + len++ + } + } + parentUsedCols = parentUsedCols[:len] + parentUsedCols = p.extractUsedCols(parentUsedCols) + p.children[0].PruneColumns(parentUsedCols) + p.SetSchema(p.children[0].Schema().Clone()) + p.Schema().Append(windowColumn) +} + +func (p *LogicalWindow) extractUsedCols(parentUsedCols []*expression.Column) []*expression.Column { + for _, arg := range p.WindowFuncDesc.Args { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(arg)...) + } + for _, by := range p.ByItems { + parentUsedCols = append(parentUsedCols, by.Col) + } + return parentUsedCols +} diff --git a/planner/core/rule_eliminate_projection.go b/planner/core/rule_eliminate_projection.go index 3e70bdb998327..597ce50a9e8d6 100644 --- a/planner/core/rule_eliminate_projection.go +++ b/planner/core/rule_eliminate_projection.go @@ -117,6 +117,8 @@ func (pe *projectionEliminater) eliminate(p LogicalPlan, replace map[string]*exp childFlag = false } else if _, isAgg := p.(*LogicalAggregation); isAgg || isProj { childFlag = true + } else if _, isWindow := p.(*LogicalWindow); isWindow { + childFlag = true } for i, child := range p.Children() { p.Children()[i] = pe.eliminate(child, replace, childFlag) @@ -204,3 +206,12 @@ func (lt *LogicalTopN) replaceExprColumns(replace map[string]*expression.Column) resolveExprAndReplace(byItem.Expr, replace) } } + +func (p *LogicalWindow) replaceExprColumns(replace map[string]*expression.Column) { + for _, arg := range p.WindowFuncDesc.Args { + resolveExprAndReplace(arg, replace) + } + for _, item := range p.ByItems { + resolveColumnAndReplace(item.Col, replace) + } +} diff --git a/planner/core/rule_predicate_push_down.go b/planner/core/rule_predicate_push_down.go index bf352289da908..1523ebeaf136f 100644 --- 
a/planner/core/rule_predicate_push_down.go +++ b/planner/core/rule_predicate_push_down.go @@ -465,3 +465,10 @@ func (p *LogicalJoin) outerJoinPropConst(predicates []expression.Expression) []e p.attachOnConds(joinConds) return predicates } + +// PredicatePushDown implements LogicalPlan PredicatePushDown interface. +func (p *LogicalWindow) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan) { + // Window function forbids any condition to push down. + p.baseLogicalPlan.PredicatePushDown(nil) + return predicates, p +} diff --git a/planner/core/stats.go b/planner/core/stats.go index 3d39c7fa8473c..32972f6132a33 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -328,3 +328,16 @@ func (p *LogicalMaxOneRow) DeriveStats(childStats []*property.StatsInfo) (*prope p.stats = getSingletonStats(p.Schema().Len()) return p.stats, nil } + +// DeriveStats implement LogicalPlan DeriveStats interface. +func (p *LogicalWindow) DeriveStats(childStats []*property.StatsInfo) (*property.StatsInfo, error) { + childProfile := childStats[0] + childLen := len(childProfile.Cardinality) + p.stats = &property.StatsInfo{ + RowCount: childProfile.RowCount, + Cardinality: make([]float64, childLen+1), + } + copy(p.stats.Cardinality, childProfile.Cardinality) + p.stats.Cardinality[childLen] = childProfile.RowCount + return p.stats, nil +} diff --git a/planner/core/stringer.go b/planner/core/stringer.go index 6cd078354469a..62459b66d4524 100644 --- a/planner/core/stringer.go +++ b/planner/core/stringer.go @@ -220,6 +220,10 @@ func toString(in Plan, strs []string, idxs []int) ([]string, []int) { if x.SelectPlan != nil { str = fmt.Sprintf("%s->Insert", ToString(x.SelectPlan)) } + case *LogicalWindow: + str = fmt.Sprintf("Window(%s)", x.WindowFuncDesc.String()) + case *PhysicalWindow: + str = fmt.Sprintf("Window(%s)", x.WindowFuncDesc.String()) default: str = fmt.Sprintf("%T", in) }