From 8d9a8df756d2624747e7e890d4d6761103986e1c Mon Sep 17 00:00:00 2001 From: Haibin Xie Date: Thu, 7 Mar 2019 16:14:18 +0800 Subject: [PATCH 1/2] executor: support window function nth_value --- executor/aggfuncs/builder.go | 12 +++++++ executor/aggfuncs/func_value.go | 47 +++++++++++++++++++++++++++ executor/window_test.go | 9 +++++ expression/aggregation/base_func.go | 2 +- expression/aggregation/window_func.go | 7 ++++ expression/util.go | 29 +++++++++++++++++ planner/core/logical_plan_builder.go | 8 +++-- planner/core/logical_plan_test.go | 8 +++++ 8 files changed, 119 insertions(+), 3 deletions(-) diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 7d5c559d26ffc..726a50858c583 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -69,6 +69,8 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag return buildFirstValue(windowFuncDesc, ordinal) case ast.WindowFuncLastValue: return buildLastValue(windowFuncDesc, ordinal) + case ast.WindowFuncNthValue: + return buildNthValue(windowFuncDesc, ordinal) default: return Build(ctx, windowFuncDesc, ordinal) } @@ -368,3 +370,13 @@ func buildLastValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { } return &lastValue{baseAggFunc: base, tp: aggFuncDesc.RetTp} } + +func buildNthValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + base := baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + } + // Already checked when building the function description. + nth, _, _ := expression.GetUint64FromConstant(aggFuncDesc.Args[1]) + return &nthValue{baseAggFunc: base, tp: aggFuncDesc.RetTp, nth: nth} +} diff --git a/executor/aggfuncs/func_value.go b/executor/aggfuncs/func_value.go index 18c552e7a1bed..36d6535eca97f 100644 --- a/executor/aggfuncs/func_value.go +++ b/executor/aggfuncs/func_value.go @@ -300,3 +300,50 @@ func (v *lastValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialR } return nil } + +type nthValue struct { + baseAggFunc + + tp *types.FieldType + nth uint64 +} + +type partialResult4NthValue struct { + seenRows uint64 + evaluator valueEvaluator +} + +func (v *nthValue) AllocPartialResult() PartialResult { + return PartialResult(&partialResult4NthValue{evaluator: buildValueEvaluator(v.tp)}) +} + +func (v *nthValue) ResetPartialResult(pr PartialResult) { + p := (*partialResult4NthValue)(pr) + p.seenRows = 0 +} + +func (v *nthValue) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + if v.nth == 0 { + return nil + } + p := (*partialResult4NthValue)(pr) + numRows := uint64(len(rowsInGroup)) + if v.nth > p.seenRows && v.nth-p.seenRows <= numRows { + err := p.evaluator.evaluateRow(sctx, v.args[0], rowsInGroup[v.nth-p.seenRows-1]) + if err != nil { + return err + } + } + p.seenRows += numRows + return nil +} + +func (v *nthValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4NthValue)(pr) + if v.nth == 0 || p.seenRows < v.nth { + chk.AppendNull(v.ordinal) + } else { + p.evaluator.appendResult(chk, v.ordinal) + } + return nil +} diff --git a/executor/window_test.go b/executor/window_test.go index 4870251fa3b40..d49ab6328e3c2 100644 --- a/executor/window_test.go +++ b/executor/window_test.go @@ -103,4 +103,13 @@ func (s *testSuite2) TestWindowFunctions(c *C) { result = tk.MustQuery("select a, first_value(rand(0)) over(), last_value(rand(0)) over() from t") result.Check(testkit.Rows("1 0.9451961492941164 0.05434383959970039", "1 0.9451961492941164 0.05434383959970039", "2 0.9451961492941164 0.05434383959970039", "2 0.9451961492941164 0.05434383959970039")) + + result = tk.MustQuery("select a, nth_value(a, null) over() from t") + result.Check(testkit.Rows("1 ", "1 ", "2 ", "2 ")) + result = tk.MustQuery("select a, nth_value(a, 1) over() from t") + result.Check(testkit.Rows("1 1", "1 1", "2 1", "2 1")) + result = tk.MustQuery("select a, nth_value(a, 4) over() from t") + result.Check(testkit.Rows("1 2", "1 2", "2 2", "2 2")) + result = tk.MustQuery("select a, nth_value(a, 5) over() from t") + result.Check(testkit.Rows("1 ", "1 ", "2 ", "2 ")) } diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index cfd1f16b74e39..78942bc5648d6 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -92,7 +92,7 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) { case ast.AggFuncGroupConcat: a.typeInfer4GroupConcat(ctx) case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow, - ast.WindowFuncFirstValue, ast.WindowFuncLastValue: + ast.WindowFuncFirstValue, ast.WindowFuncLastValue, ast.WindowFuncNthValue: a.typeInfer4MaxMin(ctx) case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: a.typeInfer4BitFuncs(ctx) diff --git a/expression/aggregation/window_func.go b/expression/aggregation/window_func.go index a0e629350aec6..aa436b957aab7 100644 --- a/expression/aggregation/window_func.go +++ b/expression/aggregation/window_func.go @@ -28,6 +28,13 @@ type WindowFuncDesc struct { // NewWindowFuncDesc creates a window function signature descriptor. func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) *WindowFuncDesc { + if strings.ToLower(name) == ast.WindowFuncNthValue { + val, isNull, ok := expression.GetUint64FromConstant(args[1]) + // nth_value does not allow `0`, but allows `null`. + if !ok || (val == 0 && !isNull) { + return nil + } + } return &WindowFuncDesc{newBaseFuncDesc(ctx, name, args)} } diff --git a/expression/util.go b/expression/util.go index 98220b9275fb6..1184f06011ee8 100644 --- a/expression/util.go +++ b/expression/util.go @@ -670,3 +670,32 @@ func RemoveDupExprs(ctx sessionctx.Context, exprs []Expression) []Expression { } return res } + +// GetUint64FromConstant gets a uint64 from constant expression. +func GetUint64FromConstant(expr Expression) (uint64, bool, bool) { + con, ok := expr.(*Constant) + if !ok { + return 0, false, false + } + dt := con.Value + if con.DeferredExpr != nil { + var err error + dt, err = con.DeferredExpr.Eval(chunk.Row{}) + if err != nil { + return 0, false, false + } + } + switch dt.Kind() { + case types.KindNull: + return 0, true, true + case types.KindInt64: + val := dt.GetInt64() + if val < 0 { + return 0, false, false + } + return uint64(val), false, true + case types.KindUint64: + return dt.GetUint64(), false, true + } + return 0, false, false +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 1ca56be8d876a..d42b8076f226b 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2762,8 +2762,9 @@ func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, expr *ast.WindowFu return nil, nil, nil, nil, err } p = np - if col, ok := newArg.(*expression.Column); ok { - newArgList = append(newArgList, col) + switch newArg.(type) { + case *expression.Column, *expression.Constant: + newArgList = append(newArgList, newArg) continue } proj.Exprs = append(proj.Exprs, newArg) @@ -2948,6 +2949,9 @@ func (b *PlanBuilder) buildWindowFunction(p LogicalPlan, expr *ast.WindowFuncExp return nil, err } desc := aggregation.NewWindowFuncDesc(b.ctx, expr.F, args) + if desc == nil { + return nil, ErrWrongArguments.GenWithStackByArgs(expr.F) + } // TODO: Check if the function is aggregation function after we support more functions. desc.WrapCastForAggArgs(b.ctx) window := LogicalWindow{ diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index e735348b75c12..0da33d6d932ba 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -2195,6 +2195,14 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { sql: "select row_number() over(rows between 1 preceding and 1 following) from t", result: "TableReader(Table(t))->Window(row_number() over())->Projection", }, + { + sql: "select nth_value(a, 1.0) over() from t", + result: "[planner:1210]Incorrect arguments to nth_value", + }, + { + sql: "select nth_value(a, 0) over() from t", + result: "[planner:1210]Incorrect arguments to nth_value", + }, } s.Parser.EnableWindowFunc(true) From 1f966918e7f03394122a832d16de810f4884f851 Mon Sep 17 00:00:00 2001 From: Haibin Xie Date: Fri, 8 Mar 2019 14:22:13 +0800 Subject: [PATCH 2/2] log error message --- expression/util.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/expression/util.go b/expression/util.go index 1184f06011ee8..82b1071b409ca 100644 --- a/expression/util.go +++ b/expression/util.go @@ -20,6 +20,7 @@ import ( "unicode" "github.com/pingcap/errors" + "github.com/pingcap/log" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" @@ -28,6 +29,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" + "go.uber.org/zap" "golang.org/x/tools/container/intsets" ) @@ -675,6 +677,7 @@ func RemoveDupExprs(ctx sessionctx.Context, exprs []Expression) []Expression { func GetUint64FromConstant(expr Expression) (uint64, bool, bool) { con, ok := expr.(*Constant) if !ok { + log.Warn("not a constant expression", zap.Any("value", expr)) return 0, false, false } dt := con.Value @@ -682,6 +685,7 @@ func GetUint64FromConstant(expr Expression) (uint64, bool, bool) { var err error dt, err = con.DeferredExpr.Eval(chunk.Row{}) if err != nil { + log.Warn("eval deferred expr failed", zap.Error(err)) return 0, false, false } }