diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go
index 4ef95f7f4a701..4340e38da44e0 100644
--- a/executor/aggfuncs/builder.go
+++ b/executor/aggfuncs/builder.go
@@ -73,6 +73,8 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag
 		return buildCumeDist(ordinal, orderByCols)
 	case ast.WindowFuncNthValue:
 		return buildNthValue(windowFuncDesc, ordinal)
+	case ast.WindowFuncPercentRank:
+		return buildPercentRank(ordinal, orderByCols)
 	default:
 		return Build(ctx, windowFuncDesc, ordinal)
 	}
@@ -386,3 +388,11 @@ func buildNthValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
 	nth, _, _ := expression.GetUint64FromConstant(aggFuncDesc.Args[1])
 	return &nthValue{baseAggFunc: base, tp: aggFuncDesc.RetTp, nth: nth}
 }
+
+// buildPercentRank creates the percentRank window function; its row comparer is built from orderByCols.
+func buildPercentRank(ordinal int, orderByCols []*expression.Column) AggFunc {
+	base := baseAggFunc{
+		ordinal: ordinal,
+	}
+	return &percentRank{baseAggFunc: base, rowComparer: buildRowComparer(orderByCols)}
+}
diff --git a/executor/aggfuncs/func_percent_rank.go b/executor/aggfuncs/func_percent_rank.go
new file mode 100644
index 0000000000000..0ea863a048bb6
--- /dev/null
+++ b/executor/aggfuncs/func_percent_rank.go
@@ -0,0 +1,61 @@
+// Copyright 2019 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package aggfuncs
+
+import (
+	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/util/chunk"
+)
+
+// percentRank calculates the percentage of partition values less than the value in the current row, excluding the highest value.
+// It can be calculated as `(rank - 1) / (total_rows_in_set - 1)`.
+type percentRank struct {
+	baseAggFunc
+	rowComparer // decides whether two adjacent rows are peers
+}
+
+func (pr *percentRank) AllocPartialResult() PartialResult {
+	return PartialResult(new(partialResult4Rank))
+}
+
+func (pr *percentRank) ResetPartialResult(partial PartialResult) {
+	state := (*partialResult4Rank)(partial)
+	state.curIdx, state.lastRank = 0, 0
+	state.rows = state.rows[:0] // truncate in place so the buffer is reused
+}
+
+func (pr *percentRank) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, partial PartialResult) error {
+	state := (*partialResult4Rank)(partial)
+	state.rows = append(state.rows, rowsInGroup...) // only buffer; ranks are produced on output
+	return nil
+}
+
+// AppendFinalResult2Chunk appends the percent rank of the next buffered row to chk; each call emits one row.
+func (pr *percentRank) AppendFinalResult2Chunk(sctx sessionctx.Context, partial PartialResult, chk *chunk.Chunk) error {
+	state := (*partialResult4Rank)(partial)
+	state.curIdx++
+	if state.curIdx == 1 {
+		// First row: rank 1, percent rank 0. Returning here also avoids 0/0 when the partition has one row.
+		state.lastRank = 1
+		chk.AppendFloat64(pr.ordinal, 0)
+		return nil
+	}
+	// A row that differs from its predecessor starts a new rank equal to its 1-based position.
+	if pr.compareRows(state.rows[state.curIdx-2], state.rows[state.curIdx-1]) != 0 {
+		state.lastRank = state.curIdx
+	}
+	total := int64(len(state.rows))
+	chk.AppendFloat64(pr.ordinal, float64(state.lastRank-1)/float64(total-1))
+	return nil
+}
diff --git a/executor/window_test.go b/executor/window_test.go
index 76cc408a654d9..5b108ddabb25d 100644
--- a/executor/window_test.go
+++ b/executor/window_test.go
@@ -119,4 +119,11 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
 	result.Check(testkit.Rows("1 2", "1 2", "2 2", "2 2"))
 	result = tk.MustQuery("select a, nth_value(a, 5) over() from t")
 	result.Check(testkit.Rows("1 ", "1 ", "2 ", "2 "))
+
+	result = tk.MustQuery("select a, percent_rank() over() from t")
+	result.Check(testkit.Rows("1 0", "1 0", "2 0", "2 0"))
+	result = tk.MustQuery("select a, percent_rank() over(order by a) from t")
+	result.Check(testkit.Rows("1 0", "1 0", "2 0.6666666666666666", "2 0.6666666666666666"))
+	result = tk.MustQuery("select a, b, percent_rank() over(order by a, b) from t")
+	result.Check(testkit.Rows("1 1 0", "1 2 0.3333333333333333", "2 1 0.6666666666666666", "2 2 1"))
 }
diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go
index 9e8db54b5d560..91462d5ed62da 100644
--- a/expression/aggregation/base_func.go
+++ b/expression/aggregation/base_func.go
@@ -100,6 +100,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
 		a.typeInfer4NumberFuncs()
 	case ast.WindowFuncCumeDist:
 		a.typeInfer4CumeDist()
+	case ast.WindowFuncPercentRank:
+		a.typeInfer4PercentRank()
 	default:
 		panic("unsupported agg function: " + a.Name)
 	}
@@ -200,6 +202,14 @@ func (a *baseFuncDesc) typeInfer4CumeDist() {
 	a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec
 }
 
+// typeInfer4PercentRank sets the return type of the percent_rank window function
+// to a double whose Flen is MaxRealWidth and whose Decimal is NotFixedDec.
+func (a *baseFuncDesc) typeInfer4PercentRank() {
+	a.RetTp = types.NewFieldType(mysql.TypeDouble)
+	// MaxRealWidth goes into Flen, mirroring typeInfer4CumeDist.
+	a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec
+}
+
 // GetDefaultValue gets the default value when the function's input is null.
 // According to MySQL, default values of the function are listed as follows:
 // e.g.