Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

type: fix float over flow when converting a decimal to a float and then converting a float to an uint (#10405) #10730

Merged
merged 2 commits into from
Jun 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4048,3 +4048,12 @@ func (s *testIntegrationSuite) TestIssue9710(c *C) {
break
}
}

func (s *testIntegrationSuite) TestIssue10181(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a bigint unsigned primary key);`)
tk.MustExec(`insert into t values(9223372036854775807), (18446744073709551615)`)
tk.MustQuery(`select * from t where a > 9223372036854775807-0.5 order by a`).Check(testkit.Rows(`9223372036854775807`, `18446744073709551615`))
}
93 changes: 93 additions & 0 deletions types/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,99 @@ func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound u
return uint64(val), nil
}

// convertScientificNotation converts a decimal string with scientific notation to a normal decimal string.
// 1E6 => 1000000, .12345E+5 => 12345
func convertScientificNotation(str string) (string, error) {
// https://golang.org/ref/spec#Floating-point_literals
eIdx := -1
point := -1
for i := 0; i < len(str); i++ {
if str[i] == '.' {
point = i
}
if str[i] == 'e' || str[i] == 'E' {
eIdx = i
if point == -1 {
point = i
}
break
}
}
if eIdx == -1 {
return str, nil
}
exp, err := strconv.ParseInt(str[eIdx+1:], 10, 64)
if err != nil {
return "", errors.WithStack(err)
}

f := str[:eIdx]
if exp == 0 {
return f, nil
} else if exp > 0 { // move point right
if point+int(exp) == len(f)-1 { // 123.456 >> 3 = 123456. = 123456
return f[:point] + f[point+1:], nil
} else if point+int(exp) < len(f)-1 { // 123.456 >> 2 = 12345.6
return f[:point] + f[point+1:point+1+int(exp)] + "." + f[point+1+int(exp):], nil
}
// 123.456 >> 5 = 12345600
return f[:point] + f[point+1:] + strings.Repeat("0", point+int(exp)-len(f)+1), nil
} else { // move point left
exp = -exp
if int(exp) < point { // 123.456 << 2 = 1.23456
return f[:point-int(exp)] + "." + f[point-int(exp):point] + f[point+1:], nil
}
// 123.456 << 5 = 0.00123456
return "0." + strings.Repeat("0", int(exp)-point) + f[:point] + f[point+1:], nil
}
}

func convertDecimalStrToUint(sc *stmtctx.StatementContext, str string, upperBound uint64, tp byte) (uint64, error) {
str, err := convertScientificNotation(str)
if err != nil {
return 0, err
}

var intStr, fracStr string
p := strings.Index(str, ".")
if p == -1 {
intStr = str
} else {
intStr = str[:p]
fracStr = str[p+1:]
}
intStr = strings.TrimLeft(intStr, "0")
if intStr == "" {
intStr = "0"
}
if sc.ShouldClipToZero() && intStr[0] == '-' {
return 0, overflow(intStr, tp)
}

var round uint64
if fracStr != "" && fracStr[0] >= '5' {
round++
}

upperBound -= round
upperStr := strconv.FormatUint(upperBound, 10)
if len(intStr) > len(upperStr) ||
(len(intStr) == len(upperStr) && intStr > upperStr) {
return upperBound, overflow(str, tp)
}

val, err := strconv.ParseUint(intStr, 10, 64)
if err != nil {
return val, err
}
return val + round, nil
}

// ConvertDecimalToUint converts a decimal to a uint by converting it to a string first to avoid float overflow (#10181).
func ConvertDecimalToUint(sc *stmtctx.StatementContext, d *MyDecimal, upperBound uint64, tp byte) (uint64, error) {
return convertDecimalStrToUint(sc, string(d.ToString()), upperBound, tp)
}

// StrToInt converts a string to an integer at the best-effort.
func StrToInt(sc *stmtctx.StatementContext, str string) (int64, error) {
str = strings.TrimSpace(str)
Expand Down
85 changes: 85 additions & 0 deletions types/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -880,3 +880,88 @@ func (s *testTypeConvertSuite) TestNumberToDuration(c *C) {
c.Assert(dur.Duration, Equals, tc.dur)
}
}

func (s *testTypeConvertSuite) TestStrToDuration(c *C) {
sc := new(stmtctx.StatementContext)
var tests = []struct {
str string
fsp int
isDuration bool
}{
{"20190412120000", 4, false},
{"20190101180000", 6, false},
{"20190101180000", 1, false},
{"20190101181234", 3, false},
}
for _, tt := range tests {
_, _, isDuration, err := StrToDuration(sc, tt.str, tt.fsp)
c.Assert(err, IsNil)
c.Assert(isDuration, Equals, tt.isDuration)
}
}

func (s *testTypeConvertSuite) TestConvertScientificNotation(c *C) {
cases := []struct {
input string
output string
succ bool
}{
{"123.456e0", "123.456", true},
{"123.456e1", "1234.56", true},
{"123.456e3", "123456", true},
{"123.456e4", "1234560", true},
{"123.456e5", "12345600", true},
{"123.456e6", "123456000", true},
{"123.456e7", "1234560000", true},
{"123.456e-1", "12.3456", true},
{"123.456e-2", "1.23456", true},
{"123.456e-3", "0.123456", true},
{"123.456e-4", "0.0123456", true},
{"123.456e-5", "0.00123456", true},
{"123.456e-6", "0.000123456", true},
{"123.456e-7", "0.0000123456", true},
{"123.456e-", "", false},
{"123.456e-7.5", "", false},
{"123.456e", "", false},
}
for _, ca := range cases {
result, err := convertScientificNotation(ca.input)
if !ca.succ {
c.Assert(err, NotNil)
} else {
c.Assert(err, IsNil)
c.Assert(ca.output, Equals, result)
}
}
}

func (s *testTypeConvertSuite) TestConvertDecimalStrToUint(c *C) {
cases := []struct {
input string
result uint64
succ bool
}{
{"0.", 0, true},
{"72.40", 72, true},
{"072.40", 72, true},
{"123.456e2", 12346, true},
{"123.456e-2", 1, true},
{"072.50000000001", 73, true},
{".5757", 1, true},
{".12345E+4", 1235, true},
{"9223372036854775807.5", 9223372036854775808, true},
{"9223372036854775807.4999", 9223372036854775807, true},
{"18446744073709551614.55", 18446744073709551615, true},
{"18446744073709551615.344", 18446744073709551615, true},
{"18446744073709551615.544", 0, false},
}
for _, ca := range cases {
result, err := convertDecimalStrToUint(&stmtctx.StatementContext{}, ca.input, math.MaxUint64, 0)
if !ca.succ {
c.Assert(err, NotNil)
} else {
c.Assert(err, IsNil)
c.Assert(result, Equals, ca.result)
}
}
}
6 changes: 1 addition & 5 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -902,11 +902,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
val, err = ConvertIntToUint(sc, ival, upperBound, tp)
}
case KindMysqlDecimal:
fval, err1 := d.GetMysqlDecimal().ToFloat64()
val, err = ConvertFloatToUint(sc, fval, upperBound, tp)
if err == nil {
err = err1
}
val, err = ConvertDecimalToUint(sc, d.GetMysqlDecimal(), upperBound, tp)
case KindMysqlEnum:
val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp)
case KindMysqlSet:
Expand Down