Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: support aggregate function stddev_samp() and var_samp() (#19810) #20036

Merged
merged 2 commits into from
Sep 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal
return buildApproxCountDistinct(aggFuncDesc, ordinal)
case ast.AggFuncStddevPop:
return buildStdDevPop(aggFuncDesc, ordinal)
case ast.AggFuncVarSamp:
return buildVarSamp(aggFuncDesc, ordinal)
case ast.AggFuncStddevSamp:
return buildStddevSamp(aggFuncDesc, ordinal)
}
return nil
}
Expand Down Expand Up @@ -470,6 +474,44 @@ func buildStdDevPop(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
}
}

// buildVarSamp builds the AggFunc implementation for function "VAR_SAMP()"
func buildVarSamp(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseVarPopAggFunc{
baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
},
}
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
default:
if aggFuncDesc.HasDistinct {
return &varSamp4DistinctFloat64{varPop4DistinctFloat64{base}}
}
return &varSamp4Float64{varPop4Float64{base}}
}
}

// buildStddevSamp builds the AggFunc implementation for function "STDDEV_SAMP()"
func buildStddevSamp(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseVarPopAggFunc{
baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
},
}
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
default:
if aggFuncDesc.HasDistinct {
return &stddevSamp4DistinctFloat64{varPop4DistinctFloat64{base}}
}
return &stddevSamp4Float64{varPop4Float64{base}}
}
}

// buildJSONObjectAgg builds the AggFunc implementation for function "json_objectagg".
func buildJSONObjectAgg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
Expand Down
51 changes: 51 additions & 0 deletions executor/aggfuncs/func_stddevsamp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"math"

"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)

type stddevSamp4Float64 struct {
varPop4Float64
}

func (e *stddevSamp4Float64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4VarPopFloat64)(pr)
if p.count <= 1 {
chk.AppendNull(e.ordinal)
return nil
}
variance := p.variance / float64(p.count-1)
chk.AppendFloat64(e.ordinal, math.Sqrt(variance))
return nil
}

type stddevSamp4DistinctFloat64 struct {
varPop4DistinctFloat64
}

func (e *stddevSamp4DistinctFloat64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4VarPopDistinctFloat64)(pr)
if p.count <= 1 {
chk.AppendNull(e.ordinal)
return nil
}
variance := p.variance / float64(p.count-1)
chk.AppendFloat64(e.ordinal, math.Sqrt(variance))
return nil
}
25 changes: 25 additions & 0 deletions executor/aggfuncs/func_stddevsamp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package aggfuncs_test

import (
. "github.com/pingcap/check"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
)

func (s *testSuite) TestMergePartialResult4Stddevsamp(c *C) {
tests := []aggTest{
buildAggTester(ast.AggFuncStddevSamp, mysql.TypeDouble, 5, 1.5811388300841898, 1, 1.407885953173359),
}
for _, test := range tests {
s.testMergePartialResult(c, test)
}
}

func (s *testSuite) TestStddevsamp(c *C) {
tests := []aggTest{
buildAggTester(ast.AggFuncStddevSamp, mysql.TypeDouble, 5, nil, 1.5811388300841898),
}
for _, test := range tests {
s.testAggFunc(c, test)
}
}
49 changes: 49 additions & 0 deletions executor/aggfuncs/func_varsamp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)

type varSamp4Float64 struct {
varPop4Float64
}

func (e *varSamp4Float64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4VarPopFloat64)(pr)
if p.count <= 1 {
chk.AppendNull(e.ordinal)
return nil
}
variance := p.variance / float64(p.count-1)
chk.AppendFloat64(e.ordinal, variance)
return nil
}

type varSamp4DistinctFloat64 struct {
varPop4DistinctFloat64
}

func (e *varSamp4DistinctFloat64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4VarPopDistinctFloat64)(pr)
if p.count <= 1 {
chk.AppendNull(e.ordinal)
return nil
}
variance := p.variance / float64(p.count-1)
chk.AppendFloat64(e.ordinal, variance)
return nil
}
25 changes: 25 additions & 0 deletions executor/aggfuncs/func_varsamp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package aggfuncs_test

import (
. "github.com/pingcap/check"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
)

func (s *testSuite) TestMergePartialResult4Varsamp(c *C) {
tests := []aggTest{
buildAggTester(ast.AggFuncVarSamp, mysql.TypeDouble, 5, 2.5, 1, 1.9821428571428572),
}
for _, test := range tests {
s.testMergePartialResult(c, test)
}
}

func (s *testSuite) TestVarsamp(c *C) {
tests := []aggTest{
buildAggTester(ast.AggFuncVarSamp, mysql.TypeDouble, 5, nil, 2.5),
}
for _, test := range tests {
s.testAggFunc(c, test)
}
}
10 changes: 8 additions & 2 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,6 @@ func (s *testSuiteAgg) TestAggregation(c *C) {
_, err = tk.Exec("select std_samp(a) from t")
// TODO: Fix this error message.
c.Assert(errors.Cause(err).Error(), Equals, "[expression:1305]FUNCTION test.std_samp does not exist")
_, err = tk.Exec("select var_samp(a) from t")
c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: var_samp")

// For issue #14072: wrong result when using generated column with aggregate statement
tk.MustExec("drop table if exists t1;")
Expand Down Expand Up @@ -464,6 +462,14 @@ func (s *testSuiteAgg) TestAggregation(c *C) {
tk.MustQuery("select stddev_pop(b) from t1 group by a order by a;").Check(testkit.Rows("<nil>", "0", "0"))
tk.MustQuery("select std(b) from t1 group by a order by a;").Check(testkit.Rows("<nil>", "0", "0"))
tk.MustQuery("select stddev(b) from t1 group by a order by a;").Check(testkit.Rows("<nil>", "0", "0"))

//For var_samp()/stddev_samp()
tk.MustExec("drop table if exists t1;")
tk.MustExec("CREATE TABLE t1 (id int(11),value1 float(10,2));")
tk.MustExec("INSERT INTO t1 VALUES (1,0.00),(1,1.00), (1,2.00), (2,10.00), (2,11.00), (2,12.00), (2,13.00);")
result = tk.MustQuery("select id, stddev_pop(value1), var_pop(value1), stddev_samp(value1), var_samp(value1) from t1 group by id order by id;")
result.Check(testkit.Rows("1 0.816496580927726 0.6666666666666666 1 1", "2 1.118033988749895 1.25 1.2909944487358056 1.6666666666666667"))

// For issue #19676 The result of stddev_pop(distinct xxx) is wrong
tk.MustExec("drop table if exists t1;")
tk.MustExec("CREATE TABLE t1 (id int);")
Expand Down
4 changes: 4 additions & 0 deletions expression/aggregation/agg_to_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ func AggFuncToPBExpr(sc *stmtctx.StatementContext, client kv.Client, aggFunc *Ag
tp = tipb.ExprType_JsonObjectAgg
case ast.AggFuncStddevPop:
tp = tipb.ExprType_StddevPop
case ast.AggFuncVarSamp:
tp = tipb.ExprType_VarSamp
case ast.AggFuncStddevSamp:
tp = tipb.ExprType_StddevSamp
}
if !client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tp)) {
return nil
Expand Down
16 changes: 4 additions & 12 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error {
a.typeInfer4PercentRank()
case ast.WindowFuncLead, ast.WindowFuncLag:
a.typeInfer4LeadLag(ctx)
case ast.AggFuncVarPop:
a.typeInfer4VarPop(ctx)
case ast.AggFuncStddevPop:
a.typeInfer4Std(ctx)
case ast.AggFuncVarPop, ast.AggFuncStddevPop, ast.AggFuncVarSamp, ast.AggFuncStddevSamp:
a.typeInfer4PopOrSamp(ctx)
case ast.AggFuncJsonObjectAgg:
a.typeInfer4JsonFuncs(ctx)
default:
Expand Down Expand Up @@ -253,14 +251,8 @@ func (a *baseFuncDesc) typeInfer4LeadLag(ctx sessionctx.Context) {
}
}

func (a *baseFuncDesc) typeInfer4VarPop(ctx sessionctx.Context) {
//var_pop's return value type is double
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
}

func (a *baseFuncDesc) typeInfer4Std(ctx sessionctx.Context) {
//std's return value type is double
func (a *baseFuncDesc) typeInfer4PopOrSamp(ctx sessionctx.Context) {
//var_pop/std/var_samp/stddev_samp's return value type is double
a.RetTp = types.NewFieldType(mysql.TypeDouble)
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
}
Expand Down
2 changes: 1 addition & 1 deletion planner/core/rule_aggregation_push_down.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (a *aggregationPushDownSolver) isDecomposableWithJoin(fun *aggregation.AggF
return false
}
switch fun.Name {
case ast.AggFuncAvg, ast.AggFuncGroupConcat, ast.AggFuncVarPop, ast.AggFuncJsonObjectAgg, ast.AggFuncStddevPop:
case ast.AggFuncAvg, ast.AggFuncGroupConcat, ast.AggFuncVarPop, ast.AggFuncJsonObjectAgg, ast.AggFuncStddevPop, ast.AggFuncVarSamp, ast.AggFuncStddevSamp:
// TODO: Support avg push down.
return false
case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow:
Expand Down