From 8f60e99d427429c1072cb5baa6b45d75d33b077c Mon Sep 17 00:00:00 2001 From: crazycs Date: Tue, 23 Oct 2018 22:19:21 +0800 Subject: [PATCH 1/3] expression: add round when cast json to decimal --- expression/builtin_cast.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index f30437dc0c39a..099672aa91623 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -1545,6 +1545,9 @@ func (b *builtinCastJSONAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyD if err == nil { res = new(types.MyDecimal) err = res.FromFloat64(f64) + if err == nil { + err = res.Round(res, b.tp.Decimal, types.ModeHalfEven) + } } return res, false, errors.Trace(err) } From b512a284b22b7f0cc015864b3ec24579aa8cddfa Mon Sep 17 00:00:00 2001 From: crazycs Date: Wed, 24 Oct 2018 10:30:37 +0800 Subject: [PATCH 2/3] expression: fix json cast to decimal bug of precision and big decimal number --- expression/builtin_cast.go | 11 ++++----- expression/builtin_cast_test.go | 42 +++++++++++++++++++++++++++++++++ types/convert.go | 15 ++++++++++++ types/convert_test.go | 21 +++++++++++++++++ 4 files changed, 82 insertions(+), 7 deletions(-) diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 099672aa91623..8e20ca110d6ae 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -1541,14 +1541,11 @@ func (b *builtinCastJSONAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyD return res, isNull, errors.Trace(err) } sc := b.ctx.GetSessionVars().StmtCtx - f64, err := types.ConvertJSONToFloat(sc, val) - if err == nil { - res = new(types.MyDecimal) - err = res.FromFloat64(f64) - if err == nil { - err = res.Round(res, b.tp.Decimal, types.ModeHalfEven) - } + res, err = types.ConvertJSONToDecimal(sc, val) + if err != nil { + return res, false, errors.Trace(err) } + res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, sc) return res, false, errors.Trace(err) } diff --git a/expression/builtin_cast_test.go b/expression/builtin_cast_test.go index ab92b03021295..dd9d00daed014 100644 --- a/expression/builtin_cast_test.go +++ b/expression/builtin_cast_test.go @@ -1082,6 +1082,48 @@ func (s *testEvaluatorSuite) TestCastFuncSig(c *C) { c.Assert(iRes, Equals, int64(0)) } +func (s *testEvaluatorSuite) TestCastJSONAsDecimalSig(c *C) { + ctx, sc := s.ctx, s.ctx.GetSessionVars().StmtCtx + originIgnoreTruncate := sc.IgnoreTruncate + sc.IgnoreTruncate = true + defer func() { + sc.IgnoreTruncate = originIgnoreTruncate + }() + + col := &Column{RetType: types.NewFieldType(mysql.TypeJSON), Index: 0} + decFunc := newBaseBuiltinCastFunc(newBaseBuiltinFunc(ctx, []Expression{col}), false) + decFunc.tp = types.NewFieldType(mysql.TypeNewDecimal) + decFunc.tp.Flen = 60 + decFunc.tp.Decimal = 2 + sig := &builtinCastJSONAsDecimalSig{decFunc} + + var tests = []struct { + In string + Out *types.MyDecimal + }{ + {`{}`, types.NewDecFromStringForTest("0")}, + {`[]`, types.NewDecFromStringForTest("0")}, + {`3`, types.NewDecFromStringForTest("3")}, + {`-3`, types.NewDecFromStringForTest("-3")}, + {`4.5`, types.NewDecFromStringForTest("4.5")}, + {`"1234"`, types.NewDecFromStringForTest("1234")}, + // test truncate + {`"1234.1234"`, types.NewDecFromStringForTest("1234.12")}, + {`"1234.4567"`, types.NewDecFromStringForTest("1234.46")}, + // test big decimal + {`"1234567890123456789012345678901234567890123456789012345"`, types.NewDecFromStringForTest("1234567890123456789012345678901234567890123456789012345")}, + } + for _, tt := range tests { + j, err := json.ParseBinaryFromString(tt.In) + c.Assert(err, IsNil) + row := chunk.MutRowFromDatums([]types.Datum{types.NewDatum(j)}) + res, isNull, err := sig.evalDecimal(row.ToRow()) + c.Assert(isNull, Equals, false) + c.Assert(err, IsNil) + c.Assert(res.Compare(tt.Out), Equals, 0) + } +} + // TestWrapWithCastAsTypesClasses tests WrapWithCastAsInt/Real/String/Decimal. func (s *testEvaluatorSuite) TestWrapWithCastAsTypesClasses(c *C) { ctx := s.ctx diff --git a/types/convert.go b/types/convert.go index 8997426901b73..7f10d322313c9 100644 --- a/types/convert.go +++ b/types/convert.go @@ -378,6 +378,21 @@ func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float6 return 0, errors.New("Unknown type code in JSON") } +// ConvertJSONToDecimal casts JSON into decimal. +func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j json.BinaryJSON) (*MyDecimal, error) { + res := new(MyDecimal) + if j.TypeCode != json.TypeCodeString { + f64, err := ConvertJSONToFloat(sc, j) + if err != nil { + return res, errors.Trace(err) + } + err = res.FromFloat64(f64) + return res, errors.Trace(err) + } + err := sc.HandleTruncate(res.FromString([]byte(j.GetString()))) + return res, errors.Trace(err) +} + // getValidFloatPrefix gets prefix of string which can be successfully parsed as float. func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, err error) { var ( diff --git a/types/convert_test.go b/types/convert_test.go index f74431f2ef843..660182cf3ce7b 100644 --- a/types/convert_test.go +++ b/types/convert_test.go @@ -804,6 +804,27 @@ func (s *testTypeConvertSuite) TestConvertJSONToFloat(c *C) { } } +func (s *testTypeConvertSuite) TestConvertJSONToDecimal(c *C) { + var tests = []struct { + In string + Out *MyDecimal + }{ + {`{}`, NewDecFromStringForTest("0")}, + {`[]`, NewDecFromStringForTest("0")}, + {`3`, NewDecFromStringForTest("3")}, + {`-3`, NewDecFromStringForTest("-3")}, + {`4.5`, NewDecFromStringForTest("4.5")}, + {`"1234"`, NewDecFromStringForTest("1234")}, + {`"1234567890123456789012345678901234567890123456789012345"`, NewDecFromStringForTest("1234567890123456789012345678901234567890123456789012345")}, + } + for _, tt := range tests { + j, err := json.ParseBinaryFromString(tt.In) + c.Assert(err, IsNil) + casted, _ := ConvertJSONToDecimal(new(stmtctx.StatementContext), j) + c.Assert(casted.Compare(tt.Out), Equals, 0) + } +} + func (s *testTypeConvertSuite) TestNumberToDuration(c *C) { var testCases = []struct { number int64 From f232453afb183d40670c41837e4a506982b9839f Mon Sep 17 00:00:00 2001 From: crazycs Date: Thu, 25 Oct 2018 10:01:38 +0800 Subject: [PATCH 3/3] add test --- executor/executor_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/executor/executor_test.go b/executor/executor_test.go index 7c5841cc3bf80..a107e1af38e95 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1369,6 +1369,21 @@ func (s *testSuite) TestJSON(c *C) { // check CAST AS JSON. result = tk.MustQuery(`select CAST('3' AS JSON), CAST('{}' AS JSON), CAST(null AS JSON)`) result.Check(testkit.Rows(`3 {} `)) + + // Check cast json to decimal. + tk.MustExec("drop table if exists test_json") + tk.MustExec("create table test_json ( a decimal(60,2) as (JSON_EXTRACT(b,'$.c')), b json );") + tk.MustExec(`insert into test_json (b) values + ('{"c": "1267.1"}'), + ('{"c": "1267.01"}'), + ('{"c": "1267.1234"}'), + ('{"c": "1267.3456"}'), + ('{"c": "1234567890123456789012345678901234567890123456789012345"}'), + ('{"c": "1234567890123456789012345678901234567890123456789012345.12345"}');`) + + tk.MustQuery("select a from test_json;").Check(testkit.Rows("1267.10", "1267.01", "1267.12", + "1267.35", "1234567890123456789012345678901234567890123456789012345.00", + "1234567890123456789012345678901234567890123456789012345.12")) } func (s *testSuite) TestMultiUpdate(c *C) {