Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: avoid sum from avg overflow (#30010) #30378

Merged
merged 8 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,8 @@ type baseAggFunc struct {
// used to append the final result of this function.
ordinal int

// frac stores digits of the fractional part of decimals,
// which makes the decimal be the result of type inferring.
frac int
// retTp means the target type of the final agg should return.
retTp *types.FieldType
}

func (*baseAggFunc) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) (memDelta int64, err error) {
Expand Down
31 changes: 5 additions & 26 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}

// If HasDistinct and mode is CompleteMode or Partial1Mode, we should
Expand Down Expand Up @@ -252,13 +253,9 @@ func buildSum(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi
baseAggFunc: baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
},
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
Expand Down Expand Up @@ -286,16 +283,8 @@ func buildAvg(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}
frac := base.args[0].GetType().Decimal
if len(base.args) == 2 {
frac = base.args[1].GetType().Decimal
}
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

switch aggFuncDesc.Mode {
// Build avg functions which consume the original data and remove the
// duplicated input of the same group.
Expand Down Expand Up @@ -339,13 +328,8 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
evalType = types.ETString
Expand Down Expand Up @@ -391,15 +375,10 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool)
baseAggFunc: baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
retTp: aggFuncDesc.RetTp,
},
isMax: isMax,
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)

evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
evalType = types.ETString
Expand Down
20 changes: 18 additions & 2 deletions executor/aggfuncs/func_avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package aggfuncs
import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -71,7 +73,14 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
if err != nil {
return err
}
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of avg should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down Expand Up @@ -259,7 +268,14 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co
if err != nil {
return err
}
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of avg should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
11 changes: 10 additions & 1 deletion executor/aggfuncs/func_first_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package aggfuncs
import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
Expand Down Expand Up @@ -475,7 +477,14 @@ func (e *firstRow4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr P
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of first_row should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
10 changes: 9 additions & 1 deletion executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
Expand Down Expand Up @@ -811,7 +812,14 @@ func (e *maxMin4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of max or min should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
11 changes: 10 additions & 1 deletion executor/aggfuncs/func_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package aggfuncs
import (
"unsafe"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -168,7 +170,14 @@ func (e *sum4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Partia
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if e.retTp == nil {
return errors.New("e.retTp of sum should not be nil")
}
frac := e.retTp.Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
if err != nil {
return err
}
Expand Down
22 changes: 22 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1460,6 +1460,28 @@ func (s *testSuiteAgg) TestIssue23277(c *C) {
tk.MustExec("drop table t;")
}

func TestAvgDecimal(t *testing.T) {
t.Parallel()
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test;")
tk.MustExec("drop table if exists td;")
tk.MustExec("create table td (col_bigint bigint(20), col_smallint smallint(6));")
tk.MustExec("insert into td values (null, 22876);")
tk.MustExec("insert into td values (9220557287087669248, 32767);")
tk.MustExec("insert into td values (28030, 32767);")
tk.MustExec("insert into td values (-3309864251140603904,32767);")
tk.MustExec("insert into td values (4,0);")
tk.MustExec("insert into td values (null,0);")
tk.MustExec("insert into td values (4,-23828);")
tk.MustExec("insert into td values (54720,32767);")
tk.MustExec("insert into td values (0,29815);")
tk.MustExec("insert into td values (10017,-32661);")
tk.MustQuery(" SELECT AVG( col_bigint / col_smallint) AS field1 FROM td;").Sort().Check(testkit.Rows("25769363061037.62077260"))
tk.MustExec("drop table td;")
}

// https://github.com/pingcap/tidb/issues/23314
func (s *testSuiteAgg) TestIssue23314(c *C) {
tk := testkit.NewTestKit(c, s.store)
Expand Down
42 changes: 42 additions & 0 deletions executor/tiflash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,48 @@ func (s *tiflashTestSuite) TestUnionWithEmptyDualTable(c *C) {
tk.MustQuery("select count(*) from (select a , b from t union all select a , c from t1 where false) tt").Check(testkit.Rows("1"))
}

func (s *tiflashTestSuite) TestAvgOverflow(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
// avg int
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a decimal(1,0))")
tk.MustExec("alter table t set tiflash replica 1")
tb := testGetTableByName(c, tk.Se, "test", "t")
err := domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
c.Assert(err, IsNil)
tk.MustExec("insert into t values(9)")
for i := 0; i < 16; i++ {
tk.MustExec("insert into t select * from t")
}
tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")
tk.MustExec("set @@session.tidb_enforce_mpp=ON")
tk.MustQuery("select avg(a) from t group by a").Check(testkit.Rows("9.0000"))
tk.MustExec("drop table if exists t")

// avg decimal
tk.MustExec("drop table if exists td;")
tk.MustExec("create table td (col_bigint bigint(20), col_smallint smallint(6));")
tk.MustExec("alter table td set tiflash replica 1")
tb = testGetTableByName(c, tk.Se, "test", "td")
err = domain.GetDomain(tk.Se).DDL().UpdateTableReplicaInfo(tk.Se, tb.Meta().ID, true)
c.Assert(err, IsNil)
tk.MustExec("insert into td values (null, 22876);")
tk.MustExec("insert into td values (9220557287087669248, 32767);")
tk.MustExec("insert into td values (28030, 32767);")
tk.MustExec("insert into td values (-3309864251140603904,32767);")
tk.MustExec("insert into td values (4,0);")
tk.MustExec("insert into td values (null,0);")
tk.MustExec("insert into td values (4,-23828);")
tk.MustExec("insert into td values (54720,32767);")
tk.MustExec("insert into td values (0,29815);")
tk.MustExec("insert into td values (10017,-32661);")
tk.MustExec("set @@session.tidb_isolation_read_engines=\"tiflash\"")
tk.MustExec("set @@session.tidb_enforce_mpp=ON")
tk.MustQuery(" SELECT AVG( col_bigint / col_smallint) AS field1 FROM td;").Sort().Check(testkit.Rows("25769363061037.62077260"))
tk.MustExec("drop table if exists td;")
}

func (s *tiflashTestSuite) TestMppApply(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
32 changes: 17 additions & 15 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type baseFuncDesc struct {

func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (baseFuncDesc, error) {
b := baseFuncDesc{Name: strings.ToLower(name), Args: args}
err := b.typeInfer(ctx)
err := b.TypeInfer(ctx)
return b, err
}

Expand Down Expand Up @@ -83,8 +83,8 @@ func (a *baseFuncDesc) String() string {
return buffer.String()
}

// typeInfer infers the arguments and return types of an function.
func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error {
// TypeInfer infers the arguments and return types of an function.
func (a *baseFuncDesc) TypeInfer(ctx sessionctx.Context) error {
switch a.Name {
case ast.AggFuncCount:
a.typeInfer4Count(ctx)
Expand Down Expand Up @@ -206,6 +206,14 @@ func (a *baseFuncDesc) typeInfer4Sum(ctx sessionctx.Context) {
types.SetBinChsClnFlag(a.RetTp)
}

// TypeInfer4AvgSum infers the type of sum from avg, which should extend the precision of decimal
// compatible with mysql.
func (a *baseFuncDesc) TypeInfer4AvgSum(avgRetType *types.FieldType) {
if avgRetType.Tp == mysql.TypeNewDecimal {
a.RetTp.Flen = mathutil.Min(mysql.MaxDecimalWidth, a.RetTp.Flen+22)
}
}

// typeInfer4Avg should returns a "decimal", otherwise it returns a "double".
// Because child returns integer or decimal type.
func (a *baseFuncDesc) typeInfer4Avg(ctx sessionctx.Context) {
Expand Down Expand Up @@ -245,6 +253,12 @@ func (a *baseFuncDesc) typeInfer4GroupConcat(ctx sessionctx.Context) {

a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxBlobWidth, 0
// TODO: a.Args[i] = expression.WrapWithCastAsString(ctx, a.Args[i])
for i := 0; i < len(a.Args)-1; i++ {
if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal {
a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp)
}
}

}

func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) {
Expand Down Expand Up @@ -368,18 +382,6 @@ var noNeedCastAggFuncs = map[string]struct{}{
ast.AggFuncJsonObjectAgg: {},
}

// WrapCastAsDecimalForAggArgs wraps the args of some specific aggregate functions
// with a cast as decimal function. See issue #19426
func (a *baseFuncDesc) WrapCastAsDecimalForAggArgs(ctx sessionctx.Context) {
if a.Name == ast.AggFuncGroupConcat {
for i := 0; i < len(a.Args)-1; i++ {
if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal {
a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp)
}
}
}
}

// WrapCastForAggArgs wraps the args of an aggregate function with a cast function.
func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
if len(a.Args) == 0 {
Expand Down
1 change: 0 additions & 1 deletion planner/core/rule_inject_extra_projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ func injectProjBelowUnion(un *PhysicalUnionAll) *PhysicalUnionAll {
// since the types of the args are already the expected.
func wrapCastForAggFuncs(sctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc) {
for i := range aggFuncs {
aggFuncs[i].WrapCastAsDecimalForAggArgs(sctx)
if aggFuncs[i].Mode != aggregation.FinalMode && aggFuncs[i].Mode != aggregation.Partial2Mode {
aggFuncs[i].WrapCastForAggArgs(sctx)
}
Expand Down
Loading