diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 721c9edc3ae3c..6ae26bafcec27 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -59,6 +59,10 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal return buildApproxCountDistinct(aggFuncDesc, ordinal) case ast.AggFuncStddevPop: return buildStdDevPop(aggFuncDesc, ordinal) + case ast.AggFuncVarSamp: + return buildVarSamp(aggFuncDesc, ordinal) + case ast.AggFuncStddevSamp: + return buildStddevSamp(aggFuncDesc, ordinal) } return nil } @@ -470,6 +474,44 @@ func buildStdDevPop(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { } } +// buildVarSamp builds the AggFunc implementation for function "VAR_SAMP()" +func buildVarSamp(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + base := baseVarPopAggFunc{ + baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + }, + } + switch aggFuncDesc.Mode { + case aggregation.DedupMode: + return nil + default: + if aggFuncDesc.HasDistinct { + return &varSamp4DistinctFloat64{varPop4DistinctFloat64{base}} + } + return &varSamp4Float64{varPop4Float64{base}} + } +} + +// buildStddevSamp builds the AggFunc implementation for function "STDDEV_SAMP()" +func buildStddevSamp(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + base := baseVarPopAggFunc{ + baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + }, + } + switch aggFuncDesc.Mode { + case aggregation.DedupMode: + return nil + default: + if aggFuncDesc.HasDistinct { + return &stddevSamp4DistinctFloat64{varPop4DistinctFloat64{base}} + } + return &stddevSamp4Float64{varPop4Float64{base}} + } +} + // buildJSONObjectAgg builds the AggFunc implementation for function "json_objectagg". func buildJSONObjectAgg(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { base := baseAggFunc{ diff --git a/executor/aggfuncs/func_stddevsamp.go b/executor/aggfuncs/func_stddevsamp.go new file mode 100644 index 0000000000000..443e0dc1cfd32 --- /dev/null +++ b/executor/aggfuncs/func_stddevsamp.go @@ -0,0 +1,51 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs + +import ( + "math" + + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/chunk" +) + +type stddevSamp4Float64 struct { + varPop4Float64 +} + +func (e *stddevSamp4Float64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4VarPopFloat64)(pr) + if p.count <= 1 { + chk.AppendNull(e.ordinal) + return nil + } + variance := p.variance / float64(p.count-1) + chk.AppendFloat64(e.ordinal, math.Sqrt(variance)) + return nil +} + +type stddevSamp4DistinctFloat64 struct { + varPop4DistinctFloat64 +} + +func (e *stddevSamp4DistinctFloat64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4VarPopDistinctFloat64)(pr) + if p.count <= 1 { + chk.AppendNull(e.ordinal) + return nil + } + variance := p.variance / float64(p.count-1) + chk.AppendFloat64(e.ordinal, math.Sqrt(variance)) + return nil +} diff --git a/executor/aggfuncs/func_stddevsamp_test.go b/executor/aggfuncs/func_stddevsamp_test.go new file mode 100644 index 0000000000000..50ad2c3fe9d05 --- /dev/null +++ b/executor/aggfuncs/func_stddevsamp_test.go @@ -0,0 +1,25 @@ +package aggfuncs_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/mysql" +) + +func (s *testSuite) TestMergePartialResult4Stddevsamp(c *C) { + tests := []aggTest{ + buildAggTester(ast.AggFuncStddevSamp, mysql.TypeDouble, 5, 1.5811388300841898, 1, 1.407885953173359), + } + for _, test := range tests { + s.testMergePartialResult(c, test) + } +} + +func (s *testSuite) TestStddevsamp(c *C) { + tests := []aggTest{ + buildAggTester(ast.AggFuncStddevSamp, mysql.TypeDouble, 5, nil, 1.5811388300841898), + } + for _, test := range tests { + s.testAggFunc(c, test) + } +} diff --git a/executor/aggfuncs/func_varsamp.go b/executor/aggfuncs/func_varsamp.go new file mode 100644 index 0000000000000..5b0b15338943b --- /dev/null +++ b/executor/aggfuncs/func_varsamp.go @@ -0,0 +1,49 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs + +import ( + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/chunk" +) + +type varSamp4Float64 struct { + varPop4Float64 +} + +func (e *varSamp4Float64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4VarPopFloat64)(pr) + if p.count <= 1 { + chk.AppendNull(e.ordinal) + return nil + } + variance := p.variance / float64(p.count-1) + chk.AppendFloat64(e.ordinal, variance) + return nil +} + +type varSamp4DistinctFloat64 struct { + varPop4DistinctFloat64 +} + +func (e *varSamp4DistinctFloat64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4VarPopDistinctFloat64)(pr) + if p.count <= 1 { + chk.AppendNull(e.ordinal) + return nil + } + variance := p.variance / float64(p.count-1) + chk.AppendFloat64(e.ordinal, variance) + return nil +} diff --git a/executor/aggfuncs/func_varsamp_test.go b/executor/aggfuncs/func_varsamp_test.go new file mode 100644 index 0000000000000..f68c5da6c710d --- /dev/null +++ b/executor/aggfuncs/func_varsamp_test.go @@ -0,0 +1,25 @@ +package aggfuncs_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/mysql" +) + +func (s *testSuite) TestMergePartialResult4Varsamp(c *C) { + tests := []aggTest{ + buildAggTester(ast.AggFuncVarSamp, mysql.TypeDouble, 5, 2.5, 1, 1.9821428571428572), + } + for _, test := range tests { + s.testMergePartialResult(c, test) + } +} + +func (s *testSuite) TestVarsamp(c *C) { + tests := []aggTest{ + buildAggTester(ast.AggFuncVarSamp, mysql.TypeDouble, 5, nil, 2.5), + } + for _, test := range tests { + s.testAggFunc(c, test) + } +} diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 8ce4c9b91676b..0d9b6b15d82fd 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -409,8 +409,6 @@ func (s *testSuiteAgg) TestAggregation(c *C) { _, err = tk.Exec("select std_samp(a) from t") // TODO: Fix this error message. c.Assert(errors.Cause(err).Error(), Equals, "[expression:1305]FUNCTION test.std_samp does not exist") - _, err = tk.Exec("select var_samp(a) from t") - c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: var_samp") // For issue #14072: wrong result when using generated column with aggregate statement tk.MustExec("drop table if exists t1;") @@ -464,6 +462,14 @@ func (s *testSuiteAgg) TestAggregation(c *C) { tk.MustQuery("select stddev_pop(b) from t1 group by a order by a;").Check(testkit.Rows("", "0", "0")) tk.MustQuery("select std(b) from t1 group by a order by a;").Check(testkit.Rows("", "0", "0")) tk.MustQuery("select stddev(b) from t1 group by a order by a;").Check(testkit.Rows("", "0", "0")) + + //For var_samp()/stddev_samp() + tk.MustExec("drop table if exists t1;") + tk.MustExec("CREATE TABLE t1 (id int(11),value1 float(10,2));") + tk.MustExec("INSERT INTO t1 VALUES (1,0.00),(1,1.00), (1,2.00), (2,10.00), (2,11.00), (2,12.00), (2,13.00);") + result = tk.MustQuery("select id, stddev_pop(value1), var_pop(value1), stddev_samp(value1), var_samp(value1) from t1 group by id order by id;") + result.Check(testkit.Rows("1 0.816496580927726 0.6666666666666666 1 1", "2 1.118033988749895 1.25 1.2909944487358056 1.6666666666666667")) + // For issue #19676 The result of stddev_pop(distinct xxx) is wrong tk.MustExec("drop table if exists t1;") tk.MustExec("CREATE TABLE t1 (id int);") diff --git a/expression/aggregation/agg_to_pb.go b/expression/aggregation/agg_to_pb.go index f256667f3abaa..fffb0a29e8fc8 100644 --- a/expression/aggregation/agg_to_pb.go +++ b/expression/aggregation/agg_to_pb.go @@ -63,6 +63,10 @@ func AggFuncToPBExpr(sc *stmtctx.StatementContext, client kv.Client, aggFunc *Ag tp = tipb.ExprType_JsonObjectAgg case ast.AggFuncStddevPop: tp = tipb.ExprType_StddevPop + case ast.AggFuncVarSamp: + tp = tipb.ExprType_VarSamp + case ast.AggFuncStddevSamp: + tp = tipb.ExprType_StddevSamp } if !client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tp)) { return nil diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index d74216a851663..031f924ba9bf4 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -109,10 +109,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error { a.typeInfer4PercentRank() case ast.WindowFuncLead, ast.WindowFuncLag: a.typeInfer4LeadLag(ctx) - case ast.AggFuncVarPop: - a.typeInfer4VarPop(ctx) - case ast.AggFuncStddevPop: - a.typeInfer4Std(ctx) + case ast.AggFuncVarPop, ast.AggFuncStddevPop, ast.AggFuncVarSamp, ast.AggFuncStddevSamp: + a.typeInfer4PopOrSamp(ctx) case ast.AggFuncJsonObjectAgg: a.typeInfer4JsonFuncs(ctx) default: @@ -253,14 +251,8 @@ func (a *baseFuncDesc) typeInfer4LeadLag(ctx sessionctx.Context) { } } -func (a *baseFuncDesc) typeInfer4VarPop(ctx sessionctx.Context) { - //var_pop's return value type is double - a.RetTp = types.NewFieldType(mysql.TypeDouble) - a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength -} - -func (a *baseFuncDesc) typeInfer4Std(ctx sessionctx.Context) { - //std's return value type is double +func (a *baseFuncDesc) typeInfer4PopOrSamp(ctx sessionctx.Context) { + //var_pop/std/var_samp/stddev_samp's return value type is double a.RetTp = types.NewFieldType(mysql.TypeDouble) a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength } diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index c713a0badf92f..8ad7162e50879 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -38,7 +38,7 @@ func (a *aggregationPushDownSolver) isDecomposableWithJoin(fun *aggregation.AggF return false } switch fun.Name { - case ast.AggFuncAvg, ast.AggFuncGroupConcat, ast.AggFuncVarPop, ast.AggFuncJsonObjectAgg, ast.AggFuncStddevPop: + case ast.AggFuncAvg, ast.AggFuncGroupConcat, ast.AggFuncVarPop, ast.AggFuncJsonObjectAgg, ast.AggFuncStddevPop, ast.AggFuncVarSamp, ast.AggFuncStddevSamp: // TODO: Support avg push down. return false case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow: