Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix the wrong rounding behavior of Decimal #33278

Merged
merged 26 commits into from
Apr 2, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
c0dab1e
fix ProduceDecWithSpecifiedTp
gengliqi Mar 21, 2022
6f7d6eb
Merge branch 'master' into fix-32213
gengliqi Mar 21, 2022
4454abe
update comments
gengliqi Mar 21, 2022
54c732a
fix test & add one test
gengliqi Mar 22, 2022
642114e
Merge branch 'master' into fix-32213
gengliqi Mar 22, 2022
5a3fb84
Merge branch 'master' into fix-32213
gengliqi Mar 22, 2022
634b476
Merge branch 'master' into fix-32213
gengliqi Mar 23, 2022
93e22ce
Merge branch 'master' into fix-32213
gengliqi Mar 24, 2022
140dc1d
address comments
gengliqi Mar 28, 2022
8ecbb00
Merge branch 'master' into fix-32213
gengliqi Mar 29, 2022
b679b2c
Merge branch 'master' into fix-32213
gengliqi Mar 29, 2022
95b4665
Merge branch 'master' into fix-32213
ti-chi-bot Mar 30, 2022
f0eebc7
Merge branch 'master' into fix-32213
ti-chi-bot Mar 30, 2022
5ed6649
Merge branch 'master' into fix-32213
ti-chi-bot Mar 30, 2022
7f3fe40
Merge branch 'master' into fix-32213
ti-chi-bot Mar 30, 2022
2ef3bd1
Merge branch 'master' into fix-32213
ti-chi-bot Mar 30, 2022
77d1d30
Merge branch 'master' into fix-32213
ti-chi-bot Mar 30, 2022
3bd807e
Merge branch 'master' into fix-32213
ti-chi-bot Mar 31, 2022
3489cac
fix nil pointer
gengliqi Mar 31, 2022
27ae425
Merge branch 'master' into fix-32213
ti-chi-bot Mar 31, 2022
fd18575
fix nil pointer
gengliqi Apr 1, 2022
82aaff6
fix lint
gengliqi Apr 1, 2022
e1ffc4a
Merge branch 'master' into fix-32213
gengliqi Apr 1, 2022
43fe0d2
Merge branch 'master' into fix-32213
ti-chi-bot Apr 1, 2022
34bcad3
Merge branch 'master' into fix-32213
ti-chi-bot Apr 1, 2022
bd25016
Merge branch 'master' into fix-32213
ti-chi-bot Apr 2, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions executor/aggfuncs/func_avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, frac, types.ModeHalfEven)
err = finalResult.Round(finalResult, frac, types.ModeHalfUp)
if err != nil {
return err
}
Expand Down Expand Up @@ -276,7 +276,7 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, frac, types.ModeHalfEven)
err = finalResult.Round(finalResult, frac, types.ModeHalfUp)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion executor/aggfuncs/func_first_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ func (e *firstRow4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr P
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
err := p.val.Round(&p.val, frac, types.ModeHalfUp)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ func (e *maxMin4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
err := p.val.Round(&p.val, frac, types.ModeHalfUp)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion executor/aggfuncs/func_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func (e *sum4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Partia
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err := p.val.Round(&p.val, frac, types.ModeHalfEven)
err := p.val.Round(&p.val, frac, types.ModeHalfUp)
if err != nil {
return err
}
Expand Down
19 changes: 19 additions & 0 deletions executor/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1940,3 +1940,22 @@ func TestInsertIntoSelectError(t *testing.T) {
tk.MustQuery("SELECT * FROM t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustExec("DROP TABLE t1;")
}

// https://github.com/pingcap/tidb/issues/32213.
func TestIssue32213(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)
tk.MustExec(`use test`)

tk.MustExec("create table test.t1(c1 float)")
tk.MustExec("insert into test.t1 values(999.99)")
tk.MustQuery("select cast(test.t1.c1 as decimal(4, 1)) from test.t1").Check(testkit.Rows("999.9"))
tk.MustQuery("select cast(test.t1.c1 as decimal(5, 1)) from test.t1").Check(testkit.Rows("1000.0"))

tk.MustExec("drop table if exists test.t1")
tk.MustExec("create table test.t1(c1 decimal(6, 4))")
tk.MustExec("insert into test.t1 values(99.9999)")
tk.MustQuery("select cast(test.t1.c1 as decimal(5, 3)) from test.t1").Check(testkit.Rows("99.999"))
tk.MustQuery("select cast(test.t1.c1 as decimal(6, 3)) from test.t1").Check(testkit.Rows("100.000"))
}
2 changes: 1 addition & 1 deletion expression/aggregation/avg.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (af *avgFunction) GetResult(evalCtx *AggEvaluateContext) (d types.Datum) {
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = to.Round(to, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven)
err = to.Round(to, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfUp)
terror.Log(err)
d.SetMysqlDecimal(to)
}
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ func (s *builtinArithmeticDivideDecimalSig) evalDecimal(row chunk.Row) (*types.M
} else if err == nil {
_, frac := c.PrecisionAndFrac()
if frac < s.baseBuiltinFunc.tp.Decimal {
err = c.Round(c, s.baseBuiltinFunc.tp.Decimal, types.ModeHalfEven)
err = c.Round(c, s.baseBuiltinFunc.tp.Decimal, types.ModeHalfUp)
}
} else if err == types.ErrOverflow {
err = types.ErrOverflow.GenWithStackByArgs("DECIMAL", fmt.Sprintf("(%s / %s)", s.args[0].String(), s.args[1].String()))
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_arithmetic_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (b *builtinArithmeticDivideDecimalSig) vecEvalDecimal(input *chunk.Chunk, r
} else if err == nil {
_, frac = to.PrecisionAndFrac()
if frac < b.baseBuiltinFunc.tp.Decimal {
if err = to.Round(&to, b.baseBuiltinFunc.tp.Decimal, types.ModeHalfEven); err != nil {
if err = to.Round(&to, b.baseBuiltinFunc.tp.Decimal, types.ModeHalfUp); err != nil {
return err
}
}
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ func (b *builtinCastDecimalAsIntSig) evalInt(row chunk.Row) (res int64, isNull b

// Round is needed for both unsigned and signed.
var to types.MyDecimal
err = val.Round(&to, 0, types.ModeHalfEven)
err = val.Round(&to, 0, types.ModeHalfUp)
if err != nil {
return 0, true, err
}
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_cast_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -1743,7 +1743,7 @@ func (b *builtinCastDecimalAsIntSig) vecEvalInt(input *chunk.Chunk, result *chun

// Round is needed for both unsigned and signed.
to := d64s[i]
err = d64s[i].Round(&to, 0, types.ModeHalfEven)
err = d64s[i].Round(&to, 0, types.ModeHalfUp)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions expression/builtin_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ func (b *builtinRoundDecSig) evalDecimal(row chunk.Row) (*types.MyDecimal, bool,
return nil, isNull, err
}
to := new(types.MyDecimal)
if err = val.Round(to, 0, types.ModeHalfEven); err != nil {
if err = val.Round(to, 0, types.ModeHalfUp); err != nil {
return nil, true, err
}
return to, false, nil
Expand Down Expand Up @@ -469,7 +469,7 @@ func (b *builtinRoundWithFracDecSig) evalDecimal(row chunk.Row) (*types.MyDecima
return nil, isNull, err
}
to := new(types.MyDecimal)
if err = val.Round(to, mathutil.Min(int(frac), b.tp.Decimal), types.ModeHalfEven); err != nil {
if err = val.Round(to, mathutil.Min(int(frac), b.tp.Decimal), types.ModeHalfUp); err != nil {
return nil, true, err
}
return to, false, nil
Expand Down
4 changes: 2 additions & 2 deletions expression/builtin_math_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ func (b *builtinRoundDecSig) vecEvalDecimal(input *chunk.Chunk, result *chunk.Co
if result.IsNull(i) {
continue
}
if err := d64s[i].Round(buf, 0, types.ModeHalfEven); err != nil {
if err := d64s[i].Round(buf, 0, types.ModeHalfUp); err != nil {
return err
}
d64s[i] = *buf
Expand Down Expand Up @@ -994,7 +994,7 @@ func (b *builtinRoundWithFracDecSig) vecEvalDecimal(input *chunk.Chunk, result *
continue
}
// TODO: reuse d64[i] and remove the temporary variable tmp.
if err := d64s[i].Round(tmp, mathutil.Min(int(i64s[i]), b.tp.Decimal), types.ModeHalfEven); err != nil {
if err := d64s[i].Round(tmp, mathutil.Min(int(i64s[i]), b.tp.Decimal), types.ModeHalfUp); err != nil {
return err
}
d64s[i] = *tmp
Expand Down
8 changes: 4 additions & 4 deletions expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -1734,7 +1734,7 @@ func evalFromUnixTime(ctx sessionctx.Context, fsp int, unixTimeStamp *types.MyDe

sc := ctx.GetSessionVars().StmtCtx
tmp := time.Unix(integralPart, fractionalPart).In(sc.TimeZone)
t, err := convertTimeToMysqlTime(tmp, fsp, types.ModeHalfEven)
t, err := convertTimeToMysqlTime(tmp, fsp, types.ModeHalfUp)
if err != nil {
return res, true, err
}
Expand Down Expand Up @@ -2050,7 +2050,7 @@ func (b *builtinSysDateWithFspSig) evalTime(row chunk.Row) (d types.Time, isNull

loc := b.ctx.GetSessionVars().Location()
now := time.Now().In(loc)
result, err := convertTimeToMysqlTime(now, int(fsp), types.ModeHalfEven)
result, err := convertTimeToMysqlTime(now, int(fsp), types.ModeHalfUp)
if err != nil {
return types.ZeroTime, true, err
}
Expand All @@ -2072,7 +2072,7 @@ func (b *builtinSysDateWithoutFspSig) Clone() builtinFunc {
func (b *builtinSysDateWithoutFspSig) evalTime(row chunk.Row) (d types.Time, isNull bool, err error) {
tz := b.ctx.GetSessionVars().Location()
now := time.Now().In(tz)
result, err := convertTimeToMysqlTime(now, 0, types.ModeHalfEven)
result, err := convertTimeToMysqlTime(now, 0, types.ModeHalfUp)
if err != nil {
return types.ZeroTime, true, err
}
Expand Down Expand Up @@ -2393,7 +2393,7 @@ func evalUTCTimestampWithFsp(ctx sessionctx.Context, fsp int) (types.Time, bool,
if err != nil {
return types.ZeroTime, true, err
}
result, err := convertTimeToMysqlTime(nowTs.UTC(), fsp, types.ModeHalfEven)
result, err := convertTimeToMysqlTime(nowTs.UTC(), fsp, types.ModeHalfUp)
if err != nil {
return types.ZeroTime, true, err
}
Expand Down
4 changes: 2 additions & 2 deletions expression/builtin_time_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (b *builtinSysDateWithoutFspSig) vecEvalTime(input *chunk.Chunk, result *ch

result.ResizeTime(n, false)
times := result.Times()
t, err := convertTimeToMysqlTime(now, 0, types.ModeHalfEven)
t, err := convertTimeToMysqlTime(now, 0, types.ModeHalfUp)
if err != nil {
return err
}
Expand Down Expand Up @@ -775,7 +775,7 @@ func (b *builtinSysDateWithFspSig) vecEvalTime(input *chunk.Chunk, result *chunk
if result.IsNull(i) {
continue
}
t, err := convertTimeToMysqlTime(now, int(ds[i]), types.ModeHalfEven)
t, err := convertTimeToMysqlTime(now, int(ds[i]), types.ModeHalfUp)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ func (c *Constant) EvalDecimal(ctx sessionctx.Context, row chunk.Row) (*types.My
// The decimal may be modified during plan building.
_, frac := res.PrecisionAndFrac()
if frac < c.GetType().Decimal {
err = res.Round(res, c.GetType().Decimal, types.ModeHalfEven)
err = res.Round(res, c.GetType().Decimal, types.ModeHalfUp)
}
return res, false, err
}
Expand Down
2 changes: 1 addition & 1 deletion types/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ func TestConvert(t *testing.T) {
dec := NewDecFromInt(-123)
err := dec.Shift(-5)
require.NoError(t, err)
err = dec.Round(dec, 5, ModeHalfEven)
err = dec.Round(dec, 5, ModeHalfUp)
require.NoError(t, err)
signedAccept(t, mysql.TypeNewDecimal, dec, "-0.00123")
}
Expand Down
41 changes: 24 additions & 17 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
}
case KindMysqlTime:
dec := d.GetMysqlTime().ToNumber()
err = dec.Round(dec, 0, ModeHalfEven)
err = dec.Round(dec, 0, ModeHalfUp)
ival, err1 := dec.ToInt()
if err == nil {
err = err1
Expand All @@ -1152,7 +1152,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
}
case KindMysqlDuration:
dec := d.GetMysqlDuration().ToNumber()
err = dec.Round(dec, 0, ModeHalfEven)
err = dec.Round(dec, 0, ModeHalfUp)
ival, err1 := dec.ToInt()
if err1 == nil {
val, err = ConvertIntToUint(sc, ival, upperBound, tp)
Expand Down Expand Up @@ -1429,20 +1429,27 @@ func ProduceDecWithSpecifiedTp(dec *MyDecimal, tp *FieldType, sc *stmtctx.Statem
if flen < decimal {
return nil, ErrMBiggerThanD.GenWithStackByArgs("")
}
prec, frac := dec.PrecisionAndFrac()
if !dec.IsZero() && prec-frac > flen-decimal {
dec = NewMaxOrMinDec(dec.IsNegative(), flen, decimal)
// select (cast 111 as decimal(1)) causes a warning in MySQL.
err = ErrOverflow.GenWithStackByArgs("DECIMAL", fmt.Sprintf("(%d, %d)", flen, decimal))
} else if frac != decimal {
old := *dec
err = dec.Round(dec, decimal, ModeHalfEven)
if err != nil {
return nil, err
}
if !old.IsZero() && frac > decimal && dec.Compare(&old) != 0 {
sc.AppendWarning(ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", &old))
err = nil

var old *MyDecimal
isZero := dec.IsZero()

if !isZero && int(dec.digitsFrac) > decimal {
old = new(MyDecimal)
*old = *dec
}
err = dec.Round(dec, decimal, ModeHalfUp)
if err != nil {
return nil, err
}

if !isZero {
_, digitsInt := dec.removeLeadingZeros()
if flen-decimal < digitsInt {
gengliqi marked this conversation as resolved.
Show resolved Hide resolved
dec = NewMaxOrMinDec(dec.IsNegative(), flen, decimal)
// select cast(111 as decimal(1)) causes a warning in MySQL.
err = ErrOverflow.GenWithStackByArgs("DECIMAL", fmt.Sprintf("(%d, %d)", flen, decimal))
} else if old != nil && dec.Compare(old) != 0 {
sc.AppendWarning(ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", old))
}
}
}
Expand Down Expand Up @@ -1806,7 +1813,7 @@ func (d *Datum) toSignedInteger(sc *stmtctx.StatementContext, tp byte) (int64, e
return ival, errors.Trace(err)
case KindMysqlDecimal:
var to MyDecimal
err := d.GetMysqlDecimal().Round(&to, 0, ModeHalfEven)
err := d.GetMysqlDecimal().Round(&to, 0, ModeHalfUp)
ival, err1 := to.ToInt()
if err == nil {
err = err1
Expand Down
47 changes: 47 additions & 0 deletions types/datum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/tidb/types/json"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/hack"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -566,3 +567,49 @@ func BenchmarkCompareDatumByReflect(b *testing.B) {
reflect.DeepEqual(vals, vals1)
}
}

func TestProduceDecWithSpecifiedTp(t *testing.T) {
tests := []struct {
dec string
flen int
frac int
newDec string
isOverflow bool
gengliqi marked this conversation as resolved.
Show resolved Hide resolved
}{
{"0.0000", 4, 3, "0.000", false},
{"0.0001", 4, 3, "0.000", false},
{"123", 8, 5, "123.00000", false},
{"-123", 8, 5, "-123.00000", false},
{"123.899", 5, 2, "123.90", false},
{"-123.899", 5, 2, "-123.90", false},
{"123.899", 6, 2, "123.90", false},
{"-123.899", 6, 2, "-123.90", false},
{"123.99", 4, 1, "124.0", false},
{"123.99", 3, 0, "124", false},
{"-123.99", 3, 0, "-124", false},
{"123.99", 3, 1, "99.9", true},
{"-123.99", 3, 1, "-99.9", true},
{"99.9999", 5, 3, "99.999", true},
{"-99.9999", 5, 3, "-99.999", true},
{"99.9999", 6, 3, "100.000", false},
{"-99.9999", 6, 3, "-100.000", false},
}
sc := new(stmtctx.StatementContext)
for _, tt := range tests {
tp := &FieldType{
Tp: mysql.TypeNewDecimal,
Flen: tt.flen,
Decimal: tt.frac,
}
dec := NewDecFromStringForTest(tt.dec)
newDec, err := ProduceDecWithSpecifiedTp(dec, tp, sc)
if tt.isOverflow {
if !ErrOverflow.Equal(err) {
assert.FailNow(t, "Error is not overflow", err)
}
} else {
require.NoError(t, err, tt)
}
require.Equal(t, tt.newDec, newDec.String())
}
}
Loading