Skip to content

Commit

Permalink
type: fix float over flow when converting a decimal to a float and th…
Browse files Browse the repository at this point in the history
…en converting a float to an uint (#10730)
  • Loading branch information
qw4990 authored and zz-jason committed Jun 6, 2019
1 parent e974759 commit d07bdb2
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 5 deletions.
9 changes: 9 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4048,3 +4048,12 @@ func (s *testIntegrationSuite) TestIssue9710(c *C) {
break
}
}

func (s *testIntegrationSuite) TestIssue10181(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a bigint unsigned primary key);`)
tk.MustExec(`insert into t values(9223372036854775807), (18446744073709551615)`)
tk.MustQuery(`select * from t where a > 9223372036854775807-0.5 order by a`).Check(testkit.Rows(`9223372036854775807`, `18446744073709551615`))
}
93 changes: 93 additions & 0 deletions types/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,99 @@ func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound u
return uint64(val), nil
}

// convertScientificNotation converts a decimal string with scientific notation to a normal decimal string.
// 1E6 => 1000000, .12345E+5 => 12345
func convertScientificNotation(str string) (string, error) {
// https://golang.org/ref/spec#Floating-point_literals
eIdx := -1
point := -1
for i := 0; i < len(str); i++ {
if str[i] == '.' {
point = i
}
if str[i] == 'e' || str[i] == 'E' {
eIdx = i
if point == -1 {
point = i
}
break
}
}
if eIdx == -1 {
return str, nil
}
exp, err := strconv.ParseInt(str[eIdx+1:], 10, 64)
if err != nil {
return "", errors.WithStack(err)
}

f := str[:eIdx]
if exp == 0 {
return f, nil
} else if exp > 0 { // move point right
if point+int(exp) == len(f)-1 { // 123.456 >> 3 = 123456. = 123456
return f[:point] + f[point+1:], nil
} else if point+int(exp) < len(f)-1 { // 123.456 >> 2 = 12345.6
return f[:point] + f[point+1:point+1+int(exp)] + "." + f[point+1+int(exp):], nil
}
// 123.456 >> 5 = 12345600
return f[:point] + f[point+1:] + strings.Repeat("0", point+int(exp)-len(f)+1), nil
} else { // move point left
exp = -exp
if int(exp) < point { // 123.456 << 2 = 1.23456
return f[:point-int(exp)] + "." + f[point-int(exp):point] + f[point+1:], nil
}
// 123.456 << 5 = 0.00123456
return "0." + strings.Repeat("0", int(exp)-point) + f[:point] + f[point+1:], nil
}
}

func convertDecimalStrToUint(sc *stmtctx.StatementContext, str string, upperBound uint64, tp byte) (uint64, error) {
str, err := convertScientificNotation(str)
if err != nil {
return 0, err
}

var intStr, fracStr string
p := strings.Index(str, ".")
if p == -1 {
intStr = str
} else {
intStr = str[:p]
fracStr = str[p+1:]
}
intStr = strings.TrimLeft(intStr, "0")
if intStr == "" {
intStr = "0"
}
if sc.ShouldClipToZero() && intStr[0] == '-' {
return 0, overflow(intStr, tp)
}

var round uint64
if fracStr != "" && fracStr[0] >= '5' {
round++
}

upperBound -= round
upperStr := strconv.FormatUint(upperBound, 10)
if len(intStr) > len(upperStr) ||
(len(intStr) == len(upperStr) && intStr > upperStr) {
return upperBound, overflow(str, tp)
}

val, err := strconv.ParseUint(intStr, 10, 64)
if err != nil {
return val, err
}
return val + round, nil
}

// ConvertDecimalToUint converts a decimal to a uint by converting it to a string first to avoid float overflow (#10181).
func ConvertDecimalToUint(sc *stmtctx.StatementContext, d *MyDecimal, upperBound uint64, tp byte) (uint64, error) {
return convertDecimalStrToUint(sc, string(d.ToString()), upperBound, tp)
}

// StrToInt converts a string to an integer at the best-effort.
func StrToInt(sc *stmtctx.StatementContext, str string) (int64, error) {
str = strings.TrimSpace(str)
Expand Down
85 changes: 85 additions & 0 deletions types/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -880,3 +880,88 @@ func (s *testTypeConvertSuite) TestNumberToDuration(c *C) {
c.Assert(dur.Duration, Equals, tc.dur)
}
}

func (s *testTypeConvertSuite) TestStrToDuration(c *C) {
sc := new(stmtctx.StatementContext)
var tests = []struct {
str string
fsp int
isDuration bool
}{
{"20190412120000", 4, false},
{"20190101180000", 6, false},
{"20190101180000", 1, false},
{"20190101181234", 3, false},
}
for _, tt := range tests {
_, _, isDuration, err := StrToDuration(sc, tt.str, tt.fsp)
c.Assert(err, IsNil)
c.Assert(isDuration, Equals, tt.isDuration)
}
}

func (s *testTypeConvertSuite) TestConvertScientificNotation(c *C) {
cases := []struct {
input string
output string
succ bool
}{
{"123.456e0", "123.456", true},
{"123.456e1", "1234.56", true},
{"123.456e3", "123456", true},
{"123.456e4", "1234560", true},
{"123.456e5", "12345600", true},
{"123.456e6", "123456000", true},
{"123.456e7", "1234560000", true},
{"123.456e-1", "12.3456", true},
{"123.456e-2", "1.23456", true},
{"123.456e-3", "0.123456", true},
{"123.456e-4", "0.0123456", true},
{"123.456e-5", "0.00123456", true},
{"123.456e-6", "0.000123456", true},
{"123.456e-7", "0.0000123456", true},
{"123.456e-", "", false},
{"123.456e-7.5", "", false},
{"123.456e", "", false},
}
for _, ca := range cases {
result, err := convertScientificNotation(ca.input)
if !ca.succ {
c.Assert(err, NotNil)
} else {
c.Assert(err, IsNil)
c.Assert(ca.output, Equals, result)
}
}
}

func (s *testTypeConvertSuite) TestConvertDecimalStrToUint(c *C) {
cases := []struct {
input string
result uint64
succ bool
}{
{"0.", 0, true},
{"72.40", 72, true},
{"072.40", 72, true},
{"123.456e2", 12346, true},
{"123.456e-2", 1, true},
{"072.50000000001", 73, true},
{".5757", 1, true},
{".12345E+4", 1235, true},
{"9223372036854775807.5", 9223372036854775808, true},
{"9223372036854775807.4999", 9223372036854775807, true},
{"18446744073709551614.55", 18446744073709551615, true},
{"18446744073709551615.344", 18446744073709551615, true},
{"18446744073709551615.544", 0, false},
}
for _, ca := range cases {
result, err := convertDecimalStrToUint(&stmtctx.StatementContext{}, ca.input, math.MaxUint64, 0)
if !ca.succ {
c.Assert(err, NotNil)
} else {
c.Assert(err, IsNil)
c.Assert(result, Equals, ca.result)
}
}
}
6 changes: 1 addition & 5 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -902,11 +902,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
val, err = ConvertIntToUint(sc, ival, upperBound, tp)
}
case KindMysqlDecimal:
fval, err1 := d.GetMysqlDecimal().ToFloat64()
val, err = ConvertFloatToUint(sc, fval, upperBound, tp)
if err == nil {
err = err1
}
val, err = ConvertDecimalToUint(sc, d.GetMysqlDecimal(), upperBound, tp)
case KindMysqlEnum:
val, err = ConvertFloatToUint(sc, d.GetMysqlEnum().ToNumber(), upperBound, tp)
case KindMysqlSet:
Expand Down

0 comments on commit d07bdb2

Please sign in to comment.