diff --git a/executor/executor_test.go b/executor/executor_test.go index 03ef8c4ec310f..a0f1375277584 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1332,6 +1332,21 @@ func (s *testSuite) TestJSON(c *C) { // check CAST AS JSON. result = tk.MustQuery(`select CAST('3' AS JSON), CAST('{}' AS JSON), CAST(null AS JSON)`) result.Check(testkit.Rows(`3 {} `)) + + // Check cast json to decimal. + tk.MustExec("drop table if exists test_json") + tk.MustExec("create table test_json ( a decimal(60,2) as (JSON_EXTRACT(b,'$.c')), b json );") + tk.MustExec(`insert into test_json (b) values + ('{"c": "1267.1"}'), + ('{"c": "1267.01"}'), + ('{"c": "1267.1234"}'), + ('{"c": "1267.3456"}'), + ('{"c": "1234567890123456789012345678901234567890123456789012345"}'), + ('{"c": "1234567890123456789012345678901234567890123456789012345.12345"}');`) + + tk.MustQuery("select a from test_json;").Check(testkit.Rows("1267.10", "1267.01", "1267.12", + "1267.35", "1234567890123456789012345678901234567890123456789012345.00", + "1234567890123456789012345678901234567890123456789012345.12")) } func (s *testSuite) TestMultiUpdate(c *C) { diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 719ca5a2e9614..b3e8a539b9f7f 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -1541,11 +1541,11 @@ func (b *builtinCastJSONAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyD return res, isNull, errors.Trace(err) } sc := b.ctx.GetSessionVars().StmtCtx - f64, err := types.ConvertJSONToFloat(sc, val) - if err == nil { - res = new(types.MyDecimal) - err = res.FromFloat64(f64) + res, err = types.ConvertJSONToDecimal(sc, val) + if err != nil { + return res, false, errors.Trace(err) } + res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, sc) return res, false, errors.Trace(err) } diff --git a/expression/builtin_cast_test.go b/expression/builtin_cast_test.go index ab92b03021295..dd9d00daed014 100644 --- a/expression/builtin_cast_test.go +++ b/expression/builtin_cast_test.go @@ -1082,6 +1082,48 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) { c.Assert(iRes, Equals, int64(0)) } +func (s *testEvaluatorSuite) TestCastJSONAsDecimalSig(c *C) { + ctx, sc := s.ctx, s.ctx.GetSessionVars().StmtCtx + originIgnoreTruncate := sc.IgnoreTruncate + sc.IgnoreTruncate = true + defer func() { + sc.IgnoreTruncate = originIgnoreTruncate + }() + + col := &Column{RetType: types.NewFieldType(mysql.TypeJSON), Index: 0} + decFunc := newBaseBuiltinCastFunc(newBaseBuiltinFunc(ctx, []Expression{col}), false) + decFunc.tp = types.NewFieldType(mysql.TypeNewDecimal) + decFunc.tp.Flen = 60 + decFunc.tp.Decimal = 2 + sig := &builtinCastJSONAsDecimalSig{decFunc} + + var tests = []struct { + In string + Out *types.MyDecimal + }{ + {`{}`, types.NewDecFromStringForTest("0")}, + {`[]`, types.NewDecFromStringForTest("0")}, + {`3`, types.NewDecFromStringForTest("3")}, + {`-3`, types.NewDecFromStringForTest("-3")}, + {`4.5`, types.NewDecFromStringForTest("4.5")}, + {`"1234"`, types.NewDecFromStringForTest("1234")}, + // test truncate + {`"1234.1234"`, types.NewDecFromStringForTest("1234.12")}, + {`"1234.4567"`, types.NewDecFromStringForTest("1234.46")}, + // test big decimal + {`"1234567890123456789012345678901234567890123456789012345"`, types.NewDecFromStringForTest("1234567890123456789012345678901234567890123456789012345")}, + } + for _, tt := range tests { + j, err := json.ParseBinaryFromString(tt.In) + c.Assert(err, IsNil) + row := chunk.MutRowFromDatums([]types.Datum{types.NewDatum(j)}) + res, isNull, err := sig.evalDecimal(row.ToRow()) + c.Assert(isNull, Equals, false) + c.Assert(err, IsNil) + c.Assert(res.Compare(tt.Out), Equals, 0) + } +} + // TestWrapWithCastAsTypesClasses tests WrapWithCastAsInt/Real/String/Decimal. func (s *testEvaluatorSuite) TestWrapWithCastAsTypesClasses(c *C) { ctx := s.ctx diff --git a/types/convert.go b/types/convert.go index b8eb173c4b7ce..08ec314b927bf 100644 --- a/types/convert.go +++ b/types/convert.go @@ -378,6 +378,21 @@ func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float6 return 0, errors.New("Unknown type code in JSON") } +// ConvertJSONToDecimal casts JSON into decimal. +func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j json.BinaryJSON) (*MyDecimal, error) { + res := new(MyDecimal) + if j.TypeCode != json.TypeCodeString { + f64, err := ConvertJSONToFloat(sc, j) + if err != nil { + return res, errors.Trace(err) + } + err = res.FromFloat64(f64) + return res, errors.Trace(err) + } + err := sc.HandleTruncate(res.FromString([]byte(j.GetString()))) + return res, errors.Trace(err) +} + // getValidFloatPrefix gets prefix of string which can be successfully parsed as float. func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, err error) { var ( diff --git a/types/convert_test.go b/types/convert_test.go index df35527d7e612..c9103e92c17cb 100644 --- a/types/convert_test.go +++ b/types/convert_test.go @@ -804,6 +804,27 @@ func (s *testTypeConvertSuite) TestConvertJSONToFloat(c *C) { } } +func (s *testTypeConvertSuite) TestConvertJSONToDecimal(c *C) { + var tests = []struct { + In string + Out *MyDecimal + }{ + {`{}`, NewDecFromStringForTest("0")}, + {`[]`, NewDecFromStringForTest("0")}, + {`3`, NewDecFromStringForTest("3")}, + {`-3`, NewDecFromStringForTest("-3")}, + {`4.5`, NewDecFromStringForTest("4.5")}, + {`"1234"`, NewDecFromStringForTest("1234")}, + {`"1234567890123456789012345678901234567890123456789012345"`, NewDecFromStringForTest("1234567890123456789012345678901234567890123456789012345")}, + } + for _, tt := range tests { + j, err := json.ParseBinaryFromString(tt.In) + c.Assert(err, IsNil) + casted, _ := ConvertJSONToDecimal(new(stmtctx.StatementContext), j) + c.Assert(casted.Compare(tt.Out), Equals, 0) + } +} + func (s *testTypeConvertSuite) TestNumberToDuration(c *C) { var testCases = []struct { number int64