From 4422a23d9b732f62245523318e2be2e2439fb3ea Mon Sep 17 00:00:00 2001 From: Haibin Xie Date: Thu, 14 Mar 2019 14:40:51 +0800 Subject: [PATCH] executor: support window function lead and lag (#9672) --- executor/aggfuncs/builder.go | 29 +++++++++ executor/aggfuncs/func_lead_lag.go | 89 +++++++++++++++++++++++++++ executor/window_test.go | 9 +++ expression/aggregation/base_func.go | 15 +++++ expression/aggregation/window_func.go | 11 +++- expression/builtin_control.go | 8 +-- 6 files changed, 156 insertions(+), 5 deletions(-) create mode 100644 executor/aggfuncs/func_lead_lag.go diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 4340e38da44e0..77f08de8515b6 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -75,6 +75,10 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag return buildNthValue(windowFuncDesc, ordinal) case ast.WindowFuncPercentRank: return buildPercenRank(ordinal, orderByCols) + case ast.WindowFuncLead: + return buildLead(windowFuncDesc, ordinal) + case ast.WindowFuncLag: + return buildLag(windowFuncDesc, ordinal) default: return Build(ctx, windowFuncDesc, ordinal) } @@ -395,3 +399,28 @@ func buildPercenRank(ordinal int, orderByCols []*expression.Column) AggFunc { } return &percentRank{baseAggFunc: base, rowComparer: buildRowComparer(orderByCols)} } + +func buildLeadLag(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) baseLeadLag { + offset := uint64(1) + if len(aggFuncDesc.Args) >= 2 { + offset, _, _ = expression.GetUint64FromConstant(aggFuncDesc.Args[1]) + } + var defaultExpr expression.Expression + defaultExpr = expression.Null + if len(aggFuncDesc.Args) == 3 { + defaultExpr = aggFuncDesc.Args[2] + } + base := baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + } + return baseLeadLag{baseAggFunc: base, offset: offset, defaultExpr: defaultExpr, valueEvaluator: buildValueEvaluator(aggFuncDesc.RetTp)} +} + +func buildLead(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + return &lead{buildLeadLag(aggFuncDesc, ordinal)} +} + +func buildLag(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + return &lag{buildLeadLag(aggFuncDesc, ordinal)} +} diff --git a/executor/aggfuncs/func_lead_lag.go b/executor/aggfuncs/func_lead_lag.go new file mode 100644 index 0000000000000..ba53e9eb47809 --- /dev/null +++ b/executor/aggfuncs/func_lead_lag.go @@ -0,0 +1,89 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs + +import ( + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/chunk" +) + +type baseLeadLag struct { + baseAggFunc + valueEvaluator // TODO: move it to partial result when parallel execution is supported. + + defaultExpr expression.Expression + offset uint64 +} + +type partialResult4LeadLag struct { + rows []chunk.Row + curIdx uint64 +} + +func (v *baseLeadLag) AllocPartialResult() PartialResult { + return PartialResult(&partialResult4LeadLag{}) +} + +func (v *baseLeadLag) ResetPartialResult(pr PartialResult) { + p := (*partialResult4LeadLag)(pr) + p.rows = p.rows[:0] + p.curIdx = 0 +} + +func (v *baseLeadLag) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4LeadLag)(pr) + p.rows = append(p.rows, rowsInGroup...) + return nil +} + +type lead struct { + baseLeadLag +} + +func (v *lead) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4LeadLag)(pr) + var err error + if p.curIdx+v.offset < uint64(len(p.rows)) { + err = v.evaluateRow(sctx, v.args[0], p.rows[p.curIdx+v.offset]) + } else { + err = v.evaluateRow(sctx, v.defaultExpr, p.rows[p.curIdx]) + } + if err != nil { + return err + } + v.appendResult(chk, v.ordinal) + p.curIdx++ + return nil +} + +type lag struct { + baseLeadLag +} + +func (v *lag) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4LeadLag)(pr) + var err error + if p.curIdx >= v.offset { + err = v.evaluateRow(sctx, v.args[0], p.rows[p.curIdx-v.offset]) + } else { + err = v.evaluateRow(sctx, v.defaultExpr, p.rows[p.curIdx]) + } + if err != nil { + return err + } + v.appendResult(chk, v.ordinal) + p.curIdx++ + return nil +} diff --git a/executor/window_test.go b/executor/window_test.go index 5b108ddabb25d..dc17af8a3e870 100644 --- a/executor/window_test.go +++ b/executor/window_test.go @@ -126,4 +126,13 @@ func (s *testSuite2) TestWindowFunctions(c *C) { result.Check(testkit.Rows("1 0", "1 0", "2 0.6666666666666666", "2 0.6666666666666666")) result = tk.MustQuery("select a, b, percent_rank() over(order by a, b) from t") result.Check(testkit.Rows("1 1 0", "1 2 0.3333333333333333", "2 1 0.6666666666666666", "2 2 1")) + + result = tk.MustQuery("select a, lead(a) over (), lag(a) over() from t") + result.Check(testkit.Rows("1 1 ", "1 2 1", "2 2 1", "2 2")) + result = tk.MustQuery("select a, lead(a, 0) over(), lag(a, 0) over() from t") + result.Check(testkit.Rows("1 1 1", "1 1 1", "2 2 2", "2 2 2")) + result = tk.MustQuery("select a, lead(a, 1, a) over(), lag(a, 1, a) over() from t") + result.Check(testkit.Rows("1 1 1", "1 2 1", "2 2 1", "2 2 2")) + result = tk.MustQuery("select a, lead(a, 1, 'lead') over(), lag(a, 1, 'lag') over() from t") + result.Check(testkit.Rows("1 1 lag", "1 2 1", "2 2 1", "2 lead 2")) } diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 91462d5ed62da..4a0fc4dc165d8 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -102,6 +102,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) { a.typeInfer4CumeDist() case ast.WindowFuncPercentRank: a.typeInfer4PercentRank() + case ast.WindowFuncLead, ast.WindowFuncLag: + a.typeInfer4LeadLag(ctx) default: panic("unsupported agg function: " + a.Name) } @@ -207,6 +209,15 @@ func (a *baseFuncDesc) typeInfer4PercentRank() { a.RetTp.Flag, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec } +func (a *baseFuncDesc) typeInfer4LeadLag(ctx sessionctx.Context) { + if len(a.Args) <= 2 { + a.typeInfer4MaxMin(ctx) + } else { + // Merge the type of first and third argument. + a.RetTp = expression.InferType4ControlFuncs(a.Args[0].GetType(), a.Args[2].GetType()) + } +} + // GetDefaultValue gets the default value when the function's input is null. // According to MySQL, default values of the function are listed as follows: // e.g. @@ -265,6 +276,10 @@ func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) { panic("should never happen in baseFuncDesc.WrapCastForAggArgs") } for i := range a.Args { + // Do not cast the second args of these functions, as they are simply non-negative numbers. + if i == 1 && (a.Name == ast.WindowFuncLead || a.Name == ast.WindowFuncLag || a.Name == ast.WindowFuncNthValue) { + continue + } a.Args[i] = castFunc(ctx, a.Args[i]) if a.Name != ast.AggFuncAvg && a.Name != ast.AggFuncSum { continue diff --git a/expression/aggregation/window_func.go b/expression/aggregation/window_func.go index aa436b957aab7..41c17f16682fb 100644 --- a/expression/aggregation/window_func.go +++ b/expression/aggregation/window_func.go @@ -28,12 +28,21 @@ type WindowFuncDesc struct { // NewWindowFuncDesc creates a window function signature descriptor. func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) *WindowFuncDesc { - if strings.ToLower(name) == ast.WindowFuncNthValue { + switch strings.ToLower(name) { + case ast.WindowFuncNthValue: val, isNull, ok := expression.GetUint64FromConstant(args[1]) // nth_value does not allow `0`, but allows `null`. if !ok || (val == 0 && !isNull) { return nil } + case ast.WindowFuncLead, ast.WindowFuncLag: + if len(args) < 2 { + break + } + _, isNull, ok := expression.GetUint64FromConstant(args[1]) + if !ok || isNull { + return nil + } } return &WindowFuncDesc{newBaseFuncDesc(ctx, name, args)} } diff --git a/expression/builtin_control.go b/expression/builtin_control.go index ef9c3b49adc2d..015f9705f2844 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -53,8 +53,8 @@ var ( _ builtinFunc = &builtinIfJSONSig{} ) -// inferType4ControlFuncs infer result type for builtin IF, IFNULL && NULLIF. -func inferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType { +// InferType4ControlFuncs infer result type for builtin IF, IFNULL, NULLIF, LEAD and LAG. +func InferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType { resultFieldType := &types.FieldType{} if lhs.Tp == mysql.TypeNull { *resultFieldType = *rhs @@ -470,7 +470,7 @@ func (c *ifFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) if err = c.verifyArgs(args); err != nil { return nil, err } - retTp := inferType4ControlFuncs(args[1].GetType(), args[2].GetType()) + retTp := InferType4ControlFuncs(args[1].GetType(), args[2].GetType()) evalTps := retTp.EvalType() bf := newBaseBuiltinFuncWithTp(ctx, args, evalTps, types.ETInt, evalTps, evalTps) retTp.Flag |= bf.tp.Flag @@ -680,7 +680,7 @@ func (c *ifNullFunctionClass) getFunction(ctx sessionctx.Context, args []Express return nil, err } lhs, rhs := args[0].GetType(), args[1].GetType() - retTp := inferType4ControlFuncs(lhs, rhs) + retTp := InferType4ControlFuncs(lhs, rhs) retTp.Flag |= (lhs.Flag & mysql.NotNullFlag) | (rhs.Flag & mysql.NotNullFlag) if lhs.Tp == mysql.TypeNull && rhs.Tp == mysql.TypeNull { retTp.Tp = mysql.TypeNull