Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

type: fix float over flow when converting a decimal to a float and then converting a float to an uint #10405

Merged
merged 10 commits into from
May 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4344,3 +4344,12 @@ func (s *testIntegrationSuite) TestDateTimeAddReal(c *C) {
tk.MustQuery(c.sql).Check(testkit.Rows(c.result))
}
}

func (s *testIntegrationSuite) TestIssue10181(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a bigint unsigned primary key);`)
tk.MustExec(`insert into t values(9223372036854775807), (18446744073709551615)`)
tk.MustQuery(`select * from t where a > 9223372036854775807-0.5 order by a`).Check(testkit.Rows(`9223372036854775807`, `18446744073709551615`))
}
93 changes: 93 additions & 0 deletions types/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,99 @@ func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound u
return uint64(val), nil
}

// convertScientificNotation converts a decimal string with scientific notation to a normal decimal string.
// 1E6 => 1000000, .12345E+5 => 12345
func convertScientificNotation(str string) (string, error) {
// https://golang.org/ref/spec#Floating-point_literals
eIdx := -1
point := -1
for i := 0; i < len(str); i++ {
if str[i] == '.' {
point = i
}
if str[i] == 'e' || str[i] == 'E' {
eIdx = i
if point == -1 {
point = i
}
break
}
}
if eIdx == -1 {
return str, nil
}
exp, err := strconv.ParseInt(str[eIdx+1:], 10, 64)
if err != nil {
return "", errors.WithStack(err)
}

f := str[:eIdx]
if exp == 0 {
return f, nil
} else if exp > 0 { // move point right
if point+int(exp) == len(f)-1 { // 123.456 >> 3 = 123456. = 123456
return f[:point] + f[point+1:], nil
} else if point+int(exp) < len(f)-1 { // 123.456 >> 2 = 12345.6
return f[:point] + f[point+1:point+1+int(exp)] + "." + f[point+1+int(exp):], nil
}
// 123.456 >> 5 = 12345600
return f[:point] + f[point+1:] + strings.Repeat("0", point+int(exp)-len(f)+1), nil
} else { // move point left
exp = -exp
if int(exp) < point { // 123.456 << 2 = 1.23456
return f[:point-int(exp)] + "." + f[point-int(exp):point] + f[point+1:], nil
}
// 123.456 << 5 = 0.00123456
return "0." + strings.Repeat("0", int(exp)-point) + f[:point] + f[point+1:], nil
}
}

func convertDecimalStrToUint(sc *stmtctx.StatementContext, str string, upperBound uint64, tp byte) (uint64, error) {
str, err := convertScientificNotation(str)
if err != nil {
return 0, err
}

var intStr, fracStr string
p := strings.Index(str, ".")
if p == -1 {
intStr = str
} else {
intStr = str[:p]
fracStr = str[p+1:]
}
intStr = strings.TrimLeft(intStr, "0")
if intStr == "" {
intStr = "0"
}
if sc.ShouldClipToZero() && intStr[0] == '-' {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will len(intStr) == 0 be true here ? If so, intStr[0] == '-' will cause a panic?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume that len(intStr) > 0 here because MyDecimal returns a valid decimal string without scientific notation.
But I will let it support scientific notation for safety.

return 0, overflow(str, tp)
}

var round uint64
if fracStr != "" && fracStr[0] >= '5' {
round++
}

upperBound -= round
zz-jason marked this conversation as resolved.
Show resolved Hide resolved
upperStr := strconv.FormatUint(upperBound, 10)
if len(intStr) > len(upperStr) ||
(len(intStr) == len(upperStr) && intStr > upperStr) {
return upperBound, overflow(str, tp)
}

val, err := strconv.ParseUint(intStr, 10, 64)
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return val, err

return val, err
}
return val + round, nil
}

// ConvertDecimalToUint converts a decimal to a uint by converting it to a string first to avoid float overflow (#10181).
func ConvertDecimalToUint(sc *stmtctx.StatementContext, d *MyDecimal, upperBound uint64, tp byte) (uint64, error) {
return convertDecimalStrToUint(sc, string(d.ToString()), upperBound, tp)
}

// StrToInt converts a string to an integer at the best-effort.
func StrToInt(sc *stmtctx.StatementContext, str string) (int64, error) {
str = strings.TrimSpace(str)
Expand Down
66 changes: 66 additions & 0 deletions types/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -903,3 +903,69 @@ func (s *testTypeConvertSuite) TestStrToDuration(c *C) {
c.Assert(isDuration, Equals, tt.isDuration)
}
}

func (s *testTypeConvertSuite) TestConvertScientificNotation(c *C) {
cases := []struct {
input string
output string
succ bool
}{
{"123.456e0", "123.456", true},
{"123.456e1", "1234.56", true},
{"123.456e3", "123456", true},
{"123.456e4", "1234560", true},
{"123.456e5", "12345600", true},
{"123.456e6", "123456000", true},
{"123.456e7", "1234560000", true},
{"123.456e-1", "12.3456", true},
{"123.456e-2", "1.23456", true},
{"123.456e-3", "0.123456", true},
{"123.456e-4", "0.0123456", true},
{"123.456e-5", "0.00123456", true},
{"123.456e-6", "0.000123456", true},
{"123.456e-7", "0.0000123456", true},
{"123.456e-", "", false},
{"123.456e-7.5", "", false},
{"123.456e", "", false},
}
for _, ca := range cases {
result, err := convertScientificNotation(ca.input)
if !ca.succ {
c.Assert(err, NotNil)
} else {
c.Assert(err, IsNil)
c.Assert(ca.output, Equals, result)
}
}
}

func (s *testTypeConvertSuite) TestConvertDecimalStrToUint(c *C) {
cases := []struct {
input string
result uint64
succ bool
}{
{"0.", 0, true},
{"72.40", 72, true},
{"072.40", 72, true},
{"123.456e2", 12346, true},
{"123.456e-2", 1, true},
{"072.50000000001", 73, true},
{".5757", 1, true},
{".12345E+4", 1235, true},
{"9223372036854775807.5", 9223372036854775808, true},
{"9223372036854775807.4999", 9223372036854775807, true},
{"18446744073709551614.55", 18446744073709551615, true},
{"18446744073709551615.344", 18446744073709551615, true},
{"18446744073709551615.544", 0, false},
}
for _, ca := range cases {
result, err := convertDecimalStrToUint(&stmtctx.StatementContext{}, ca.input, math.MaxUint64, 0)
if !ca.succ {
c.Assert(err, NotNil)
} else {
c.Assert(err, IsNil)
c.Assert(result, Equals, ca.result)
}
}
}
6 changes: 1 addition & 5 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -911,11 +911,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
val, err = ConvertIntToUint(sc, ival, upperBound, tp)
}
case KindMysqlDecimal:
fval, err1 := d.GetMysqlDecimal().ToFloat64()
val, err = ConvertFloatToUint(sc, fval, upperBound, tp)
if err == nil {
err = err1
}
val, err = ConvertDecimalToUint(sc, d.GetMysqlDecimal(), upperBound, tp)
case KindMysqlEnum:
val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp)
case KindMysqlSet:
Expand Down