Skip to content

Commit

Permalink
executor: add window function PERCENT_RANK (#9671)
Browse files Browse the repository at this point in the history
  • Loading branch information
winoros committed Mar 13, 2019
1 parent aef66b2 commit f5a4dd9
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 0 deletions.
9 changes: 9 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag
return buildCumeDist(ordinal, orderByCols)
case ast.WindowFuncNthValue:
return buildNthValue(windowFuncDesc, ordinal)
case ast.WindowFuncPercentRank:
return buildPercenRank(ordinal, orderByCols)
default:
return Build(ctx, windowFuncDesc, ordinal)
}
Expand Down Expand Up @@ -386,3 +388,10 @@ func buildNthValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
nth, _, _ := expression.GetUint64FromConstant(aggFuncDesc.Args[1])
return &nthValue{baseAggFunc: base, tp: aggFuncDesc.RetTp, nth: nth}
}

func buildPercenRank(ordinal int, orderByCols []*expression.Column) AggFunc {
base := baseAggFunc{
ordinal: ordinal,
}
return &percentRank{baseAggFunc: base, rowComparer: buildRowComparer(orderByCols)}
}
61 changes: 61 additions & 0 deletions executor/aggfuncs/func_percent_rank.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)

// percentRank calculates the percentage of partition values less than the value in the current row, excluding the highest value.
// It can be calculated as `(rank - 1) / (total_rows_in_set - 1).
type percentRank struct {
baseAggFunc
rowComparer
}

func (pr *percentRank) AllocPartialResult() PartialResult {
return PartialResult(&partialResult4Rank{})
}

func (pr *percentRank) ResetPartialResult(partial PartialResult) {
p := (*partialResult4Rank)(partial)
p.curIdx = 0
p.lastRank = 0
p.rows = p.rows[:0]
}

func (pr *percentRank) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, partial PartialResult) error {
p := (*partialResult4Rank)(partial)
p.rows = append(p.rows, rowsInGroup...)
return nil
}

func (pr *percentRank) AppendFinalResult2Chunk(sctx sessionctx.Context, partial PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4Rank)(partial)
numRows := int64(len(p.rows))
p.curIdx++
if p.curIdx == 1 {
p.lastRank = 1
chk.AppendFloat64(pr.ordinal, 0)
return nil
}
if pr.compareRows(p.rows[p.curIdx-2], p.rows[p.curIdx-1]) == 0 {
chk.AppendFloat64(pr.ordinal, float64(p.lastRank-1)/float64(numRows-1))
return nil
}
p.lastRank = p.curIdx
chk.AppendFloat64(pr.ordinal, float64(p.lastRank-1)/float64(numRows-1))
return nil
}
7 changes: 7 additions & 0 deletions executor/window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,11 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
result.Check(testkit.Rows("1 2", "1 2", "2 2", "2 2"))
result = tk.MustQuery("select a, nth_value(a, 5) over() from t")
result.Check(testkit.Rows("1 <nil>", "1 <nil>", "2 <nil>", "2 <nil>"))

result = tk.MustQuery("select a, percent_rank() over() from t")
result.Check(testkit.Rows("1 0", "1 0", "2 0", "2 0"))
result = tk.MustQuery("select a, percent_rank() over(order by a) from t")
result.Check(testkit.Rows("1 0", "1 0", "2 0.6666666666666666", "2 0.6666666666666666"))
result = tk.MustQuery("select a, b, percent_rank() over(order by a, b) from t")
result.Check(testkit.Rows("1 1 0", "1 2 0.3333333333333333", "2 1 0.6666666666666666", "2 2 1"))
}
7 changes: 7 additions & 0 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
a.typeInfer4NumberFuncs()
case ast.WindowFuncCumeDist:
a.typeInfer4CumeDist()
case ast.WindowFuncPercentRank:
a.typeInfer4PercentRank()
default:
panic("unsupported agg function: " + a.Name)
}
Expand Down Expand Up @@ -200,6 +202,11 @@ func (a *baseFuncDesc) typeInfer4CumeDist() {
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec
}

func (a *baseFuncDesc) typeInfer4PercentRank() {
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flag, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec
}

// GetDefaultValue gets the default value when the function's input is null.
// According to MySQL, default values of the function are listed as follows:
// e.g.
Expand Down

0 comments on commit f5a4dd9

Please sign in to comment.