Skip to content

Commit

Permalink
executor: avoid sum from avg overflow (#30010) (#30380)
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-srebot authored Dec 17, 2021
1 parent 44cedf4 commit 663e51d
Show file tree
Hide file tree
Showing 12 changed files with 156 additions and 63 deletions.
5 changes: 2 additions & 3 deletions executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,8 @@ type baseAggFunc struct {
// used to append the final result of this function.
ordinal int

// frac stores digits of the fractional part of decimals,
// which makes the decimal be the result of type inferring.
frac int
// retTp means the target type of the final agg should return.
retTp *types.FieldType
}

func (*baseAggFunc) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) (memDelta int64, err error) {
Expand Down
32 changes: 5 additions & 27 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"fmt"
"strconv"

"github.com/cznic/mathutil"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
Expand Down Expand Up @@ -191,6 +190,7 @@ func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}

// If HasDistinct and mode is CompleteMode or Partial1Mode, we should
Expand Down Expand Up @@ -250,13 +250,9 @@ func buildSum(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi
baseAggFunc: baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
},
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
Expand Down Expand Up @@ -284,16 +280,8 @@ func buildAvg(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}
frac := base.args[0].GetType().Decimal
if len(base.args) == 2 {
frac = base.args[1].GetType().Decimal
}
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

switch aggFuncDesc.Mode {
// Build avg functions which consume the original data and remove the
// duplicated input of the same group.
Expand Down Expand Up @@ -337,13 +325,8 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
evalType = types.ETString
Expand Down Expand Up @@ -389,15 +372,10 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool)
baseAggFunc: baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
},
isMax: isMax,
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
evalType = types.ETString
Expand Down
20 changes: 18 additions & 2 deletions executor/aggfuncs/func_avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package aggfuncs
import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -71,7 +73,14 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
if err != nil {
return err
}
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of avg should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down Expand Up @@ -257,7 +266,14 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co
if err != nil {
return err
}
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of avg should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
11 changes: 10 additions & 1 deletion executor/aggfuncs/func_first_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package aggfuncs
import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
Expand Down Expand Up @@ -475,7 +477,14 @@ func (e *firstRow4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr P
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of first_row should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
11 changes: 10 additions & 1 deletion executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"container/heap"
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
Expand Down Expand Up @@ -783,7 +785,14 @@ func (e *maxMin4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of max or min should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
11 changes: 10 additions & 1 deletion executor/aggfuncs/func_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package aggfuncs
import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -166,7 +168,14 @@ func (e *sum4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Partia
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of sum should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
19 changes: 19 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1450,3 +1450,22 @@ func (s *testSuiteAgg) TestIssue23277(c *C) {
tk.MustQuery("select avg(a) from t group by a").Sort().Check(testkit.Rows("-120.0000", "127.0000"))
tk.MustExec("drop table t;")
}

func (s *testSuiteAgg) TestAvgDecimal(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec("use test;")
tk.MustExec("drop table if exists td;")
tk.MustExec("create table td (col_bigint bigint(20), col_smallint smallint(6));")
tk.MustExec("insert into td values (null, 22876);")
tk.MustExec("insert into td values (9220557287087669248, 32767);")
tk.MustExec("insert into td values (28030, 32767);")
tk.MustExec("insert into td values (-3309864251140603904,32767);")
tk.MustExec("insert into td values (4,0);")
tk.MustExec("insert into td values (null,0);")
tk.MustExec("insert into td values (4,-23828);")
tk.MustExec("insert into td values (54720,32767);")
tk.MustExec("insert into td values (0,29815);")
tk.MustExec("insert into td values (10017,-32661);")
tk.MustQuery(" SELECT AVG( col_bigint / col_smallint) AS field1 FROM td;").Sort().Check(testkit.Rows("25769363061037.62077260"))
tk.MustExec("drop table td;")
}
42 changes: 42 additions & 0 deletions executor/tiflash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,48 @@ func (s *tiflashTestSuite) TestMppUnionAll(c *C) {

}

func (s *tiflashTestSuite) TestAvgOverflow(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
// avg int
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a decimal(1,0))")
tk.MustExec("alter table t set tiflash replica 1")
tb := testGetTableByName(c, tk.Se, "test", "t")
err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
c.Assert(err, IsNil)
tk.MustExec("insert into t values(9)")
for i := 0; i < 16; i++ {
tk.MustExec("insert into t select * from t")
}
tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")
tk.MustExec("set @@session.tidb_enforce_mpp=ON")
tk.MustQuery("select avg(a) from t group by a").Check(testkit.Rows("9.0000"))
tk.MustExec("drop table if exists t")

// avg decimal
tk.MustExec("drop table if exists td;")
tk.MustExec("create table td (col_bigint bigint(20), col_smallint smallint(6));")
tk.MustExec("alter table td set tiflash replica 1")
tb = testGetTableByName(c, tk.Se, "test", "td")
err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
c.Assert(err, IsNil)
tk.MustExec("insert into td values (null, 22876);")
tk.MustExec("insert into td values (9220557287087669248, 32767);")
tk.MustExec("insert into td values (28030, 32767);")
tk.MustExec("insert into td values (-3309864251140603904,32767);")
tk.MustExec("insert into td values (4,0);")
tk.MustExec("insert into td values (null,0);")
tk.MustExec("insert into td values (4,-23828);")
tk.MustExec("insert into td values (54720,32767);")
tk.MustExec("insert into td values (0,29815);")
tk.MustExec("insert into td values (10017,-32661);")
tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")
tk.MustExec("set @@session.tidb_enforce_mpp=ON")
tk.MustQuery(" SELECT AVG( col_bigint / col_smallint) AS field1 FROM td;").Sort().Check(testkit.Rows("25769363061037.62077260"))
tk.MustExec("drop table if exists td;")
}

func (s *tiflashTestSuite) TestUnionWithEmptyDualTable(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
32 changes: 17 additions & 15 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type baseFuncDesc struct {

func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (baseFuncDesc, error) {
b := baseFuncDesc{Name: strings.ToLower(name), Args: args}
err := b.typeInfer(ctx)
err := b.TypeInfer(ctx)
return b, err
}

Expand Down Expand Up @@ -83,8 +83,8 @@ func (a *baseFuncDesc) String() string {
return buffer.String()
}

// typeInfer infers the arguments and return types of an function.
func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error {
// TypeInfer infers the arguments and return types of an function.
func (a *baseFuncDesc) TypeInfer(ctx sessionctx.Context) error {
switch a.Name {
case ast.AggFuncCount:
a.typeInfer4Count(ctx)
Expand Down Expand Up @@ -204,6 +204,14 @@ func (a *baseFuncDesc) typeInfer4Sum(ctx sessionctx.Context) {
types.SetBinChsClnFlag(a.RetTp)
}

// TypeInfer4AvgSum infers the type of sum from avg, which should extend the precision of decimal
// compatible with mysql.
func (a *baseFuncDesc) TypeInfer4AvgSum(avgRetType *types.FieldType) {
if avgRetType.Tp == mysql.TypeNewDecimal {
a.RetTp.Flen = mathutil.Min(mysql.MaxDecimalWidth, a.RetTp.Flen+22)
}
}

// typeInfer4Avg should returns a "decimal", otherwise it returns a "double".
// Because child returns integer or decimal type.
func (a *baseFuncDesc) typeInfer4Avg(ctx sessionctx.Context) {
Expand Down Expand Up @@ -243,6 +251,12 @@ func (a *baseFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) {

a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxBlobWidth, 0
// TODO: a.Args[i] = expression.WrapWithCastAsString(ctx, a.Args[i])
for i := 0; i < len(a.Args)-1; i++ {
if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal {
a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp)
}
}

}

func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) {
Expand Down Expand Up @@ -365,18 +379,6 @@ var noNeedCastAggFuncs = map[string]struct{}{
ast.AggFuncJsonObjectAgg: {},
}

// WrapCastAsDecimalForAggArgs wraps the args of some specific aggregate functions
// with a cast as decimal function. See issue #19426
func (a *baseFuncDesc) WrapCastAsDecimalForAggArgs(ctx sessionctx.Context) {
if a.Name == ast.AggFuncGroupConcat {
for i := 0; i < len(a.Args)-1; i++ {
if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal {
a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp)
}
}
}
}

// WrapCastForAggArgs wraps the args of an aggregate function with a cast function.
func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
if len(a.Args) == 0 {
Expand Down
1 change: 0 additions & 1 deletion planner/core/rule_inject_extra_projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ func injectProjBelowUnion(un *PhysicalUnionAll) *PhysicalUnionAll {
// since the types of the args are already the expected.
func wrapCastForAggFuncs(sctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc) {
for i := range aggFuncs {
aggFuncs[i].WrapCastAsDecimalForAggArgs(sctx)
if aggFuncs[i].Mode != aggregation.FinalMode && aggFuncs[i].Mode != aggregation.Partial2Mode {
aggFuncs[i].WrapCastForAggArgs(sctx)
}
Expand Down
Loading

0 comments on commit 663e51d

Please sign in to comment.