diff --git a/mysqldef/decimal.go b/mysqldef/decimal.go index 7eaed3d09f414..2ff129af1988e 100644 --- a/mysqldef/decimal.go +++ b/mysqldef/decimal.go @@ -199,17 +199,28 @@ func NewDecimalFromUint(value uint64, exp int32) Decimal { // func ParseDecimal(value string) (Decimal, error) { var intString string - var exp int32 + var exp = int32(0) + + n := strings.IndexAny(value, "eE") + if n > 0 { + // It is scientific notation, like 3.14e10 + expInt, err := strconv.Atoi(value[n+1:]) + if err != nil { + return Decimal{}, fmt.Errorf("can't convert %s to decimal, incorrect exponent", value) + } + value = value[0:n] + exp = int32(expInt) + } + parts := strings.Split(value, ".") if len(parts) == 1 { // There is no decimal point, we can just parse the original string as // an int. intString = value - exp = 0 } else if len(parts) == 2 { intString = parts[0] + parts[1] expInt := -len(parts[1]) - exp = int32(expInt) + exp += int32(expInt) } else { return Decimal{}, fmt.Errorf("can't convert %s to decimal: too many .s", value) } diff --git a/mysqldef/decimal_test.go b/mysqldef/decimal_test.go index 88448d1f7d2cb..60bb409c88182 100644 --- a/mysqldef/decimal_test.go +++ b/mysqldef/decimal_test.go @@ -869,3 +869,43 @@ func didPanic(f func()) bool { return ret } + +func TestDecimalScientificNotation(t *testing.T) { + tbl := []struct { + Input string + Expected float64 + }{ + {"314e-2", 3.14}, + {"1e2", 100}, + {"2E-1", 0.2}, + {"2E0", 2}, + {"2.2E-1", 0.22}, + } + + for _, c := range tbl { + n, err := ParseDecimal(c.Input) + if err != nil { + t.Error(err) + } + + f, _ := n.Float64() + if f != c.Expected { + t.Errorf("%f != %f", f, c.Expected) + } + } + + tblErr := []string{ + "12ee", + "ae10", + "12e1a", + "12e1.2", + "e1", + } + + for _, c := range tblErr { + _, err := ParseDecimal(c) + if err == nil { + t.Errorf("%s must be invalid decimal", c) + } + } +}