From d07bdb2e2b53160df0463a4667753ae4d567d430 Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Thu, 6 Jun 2019 11:34:42 +0800 Subject: [PATCH] type: fix float over flow when converting a decimal to a float and then converting a float to an uint (#10730) --- expression/integration_test.go | 9 ++++ types/convert.go | 93 ++++++++++++++++++++++++++++++++++ types/convert_test.go | 85 +++++++++++++++++++++++++++++++ types/datum.go | 6 +-- 4 files changed, 188 insertions(+), 5 deletions(-) diff --git a/expression/integration_test.go b/expression/integration_test.go index 4e9eff689d92d..71ecef6c6b692 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -4048,3 +4048,12 @@ func (s *testIntegrationSuite) TestIssue9710(c *C) { break } } + +func (s *testIntegrationSuite) TestIssue10181(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a bigint unsigned primary key);`) + tk.MustExec(`insert into t values(9223372036854775807), (18446744073709551615)`) + tk.MustQuery(`select * from t where a > 9223372036854775807-0.5 order by a`).Check(testkit.Rows(`9223372036854775807`, `18446744073709551615`)) +} diff --git a/types/convert.go b/types/convert.go index 87e3cd82e24a8..4764e62dd1c7e 100644 --- a/types/convert.go +++ b/types/convert.go @@ -143,6 +143,99 @@ func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound u return uint64(val), nil } +// convertScientificNotation converts a decimal string with scientific notation to a normal decimal string. +// 1E6 => 1000000, .12345E+5 => 12345 +func convertScientificNotation(str string) (string, error) { + // https://golang.org/ref/spec#Floating-point_literals + eIdx := -1 + point := -1 + for i := 0; i < len(str); i++ { + if str[i] == '.' { + point = i + } + if str[i] == 'e' || str[i] == 'E' { + eIdx = i + if point == -1 { + point = i + } + break + } + } + if eIdx == -1 { + return str, nil + } + exp, err := strconv.ParseInt(str[eIdx+1:], 10, 64) + if err != nil { + return "", errors.WithStack(err) + } + + f := str[:eIdx] + if exp == 0 { + return f, nil + } else if exp > 0 { // move point right + if point+int(exp) == len(f)-1 { // 123.456 >> 3 = 123456. = 123456 + return f[:point] + f[point+1:], nil + } else if point+int(exp) < len(f)-1 { // 123.456 >> 2 = 12345.6 + return f[:point] + f[point+1:point+1+int(exp)] + "." + f[point+1+int(exp):], nil + } + // 123.456 >> 5 = 12345600 + return f[:point] + f[point+1:] + strings.Repeat("0", point+int(exp)-len(f)+1), nil + } else { // move point left + exp = -exp + if int(exp) < point { // 123.456 << 2 = 1.23456 + return f[:point-int(exp)] + "." + f[point-int(exp):point] + f[point+1:], nil + } + // 123.456 << 5 = 0.00123456 + return "0." + strings.Repeat("0", int(exp)-point) + f[:point] + f[point+1:], nil + } +} + +func convertDecimalStrToUint(sc *stmtctx.StatementContext, str string, upperBound uint64, tp byte) (uint64, error) { + str, err := convertScientificNotation(str) + if err != nil { + return 0, err + } + + var intStr, fracStr string + p := strings.Index(str, ".") + if p == -1 { + intStr = str + } else { + intStr = str[:p] + fracStr = str[p+1:] + } + intStr = strings.TrimLeft(intStr, "0") + if intStr == "" { + intStr = "0" + } + if sc.ShouldClipToZero() && intStr[0] == '-' { + return 0, overflow(intStr, tp) + } + + var round uint64 + if fracStr != "" && fracStr[0] >= '5' { + round++ + } + + upperBound -= round + upperStr := strconv.FormatUint(upperBound, 10) + if len(intStr) > len(upperStr) || + (len(intStr) == len(upperStr) && intStr > upperStr) { + return upperBound, overflow(str, tp) + } + + val, err := strconv.ParseUint(intStr, 10, 64) + if err != nil { + return val, err + } + return val + round, nil +} + +// ConvertDecimalToUint converts a decimal to a uint by converting it to a string first to avoid float overflow (#10181). +func ConvertDecimalToUint(sc *stmtctx.StatementContext, d *MyDecimal, upperBound uint64, tp byte) (uint64, error) { + return convertDecimalStrToUint(sc, string(d.ToString()), upperBound, tp) +} + // StrToInt converts a string to an integer at the best-effort. func StrToInt(sc *stmtctx.StatementContext, str string) (int64, error) { str = strings.TrimSpace(str) diff --git a/types/convert_test.go b/types/convert_test.go index b5f3f737be0d5..51816860f2795 100644 --- a/types/convert_test.go +++ b/types/convert_test.go @@ -880,3 +880,88 @@ func (s *testTypeConvertSuite) TestNumberToDuration(c *C) { c.Assert(dur.Duration, Equals, tc.dur) } } + +func (s *testTypeConvertSuite) TestStrToDuration(c *C) { + sc := new(stmtctx.StatementContext) + var tests = []struct { + str string + fsp int + isDuration bool + }{ + {"20190412120000", 4, false}, + {"20190101180000", 6, false}, + {"20190101180000", 1, false}, + {"20190101181234", 3, false}, + } + for _, tt := range tests { + _, _, isDuration, err := StrToDuration(sc, tt.str, tt.fsp) + c.Assert(err, IsNil) + c.Assert(isDuration, Equals, tt.isDuration) + } +} + +func (s *testTypeConvertSuite) TestConvertScientificNotation(c *C) { + cases := []struct { + input string + output string + succ bool + }{ + {"123.456e0", "123.456", true}, + {"123.456e1", "1234.56", true}, + {"123.456e3", "123456", true}, + {"123.456e4", "1234560", true}, + {"123.456e5", "12345600", true}, + {"123.456e6", "123456000", true}, + {"123.456e7", "1234560000", true}, + {"123.456e-1", "12.3456", true}, + {"123.456e-2", "1.23456", true}, + {"123.456e-3", "0.123456", true}, + {"123.456e-4", "0.0123456", true}, + {"123.456e-5", "0.00123456", true}, + {"123.456e-6", "0.000123456", true}, + {"123.456e-7", "0.0000123456", true}, + {"123.456e-", "", false}, + {"123.456e-7.5", "", false}, + {"123.456e", "", false}, + } + for _, ca := range cases { + result, err := convertScientificNotation(ca.input) + if !ca.succ { + c.Assert(err, NotNil) + } else { + c.Assert(err, IsNil) + c.Assert(ca.output, Equals, result) + } + } +} + +func (s *testTypeConvertSuite) TestConvertDecimalStrToUint(c *C) { + cases := []struct { + input string + result uint64 + succ bool + }{ + {"0.", 0, true}, + {"72.40", 72, true}, + {"072.40", 72, true}, + {"123.456e2", 12346, true}, + {"123.456e-2", 1, true}, + {"072.50000000001", 73, true}, + {".5757", 1, true}, + {".12345E+4", 1235, true}, + {"9223372036854775807.5", 9223372036854775808, true}, + {"9223372036854775807.4999", 9223372036854775807, true}, + {"18446744073709551614.55", 18446744073709551615, true}, + {"18446744073709551615.344", 18446744073709551615, true}, + {"18446744073709551615.544", 0, false}, + } + for _, ca := range cases { + result, err := convertDecimalStrToUint(&stmtctx.StatementContext{}, ca.input, math.MaxUint64, 0) + if !ca.succ { + c.Assert(err, NotNil) + } else { + c.Assert(err, IsNil) + c.Assert(result, Equals, ca.result) + } + } +} diff --git a/types/datum.go b/types/datum.go index 880e6312bcf97..7fa790c4ded85 100644 --- a/types/datum.go +++ b/types/datum.go @@ -902,11 +902,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( val, err = ConvertIntToUint(sc, ival, upperBound, tp) } case KindMysqlDecimal: - fval, err1 := d.GetMysqlDecimal().ToFloat64() - val, err = ConvertFloatToUint(sc, fval, upperBound, tp) - if err == nil { - err = err1 - } + val, err = ConvertDecimalToUint(sc, d.GetMysqlDecimal(), upperBound, tp) case KindMysqlEnum: val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp) case KindMysqlSet: