From d47f1687f3dc598a3b67100c955d3b2df3ef0907 Mon Sep 17 00:00:00 2001 From: Foreyes Date: Mon, 8 Jul 2019 14:54:56 +0800 Subject: [PATCH] expression: add built-in function `json_array_insert` (#11076) --- expression/builtin_json.go | 69 ++++++++++++++++++++++++++++++- expression/builtin_json_test.go | 73 +++++++++++++++++++++++++++++++++ expression/distsql_builtin.go | 2 + types/json/binary_functions.go | 43 +++++++++++++++++++ types/json/constants.go | 3 ++ 5 files changed, 189 insertions(+), 1 deletion(-) diff --git a/expression/builtin_json.go b/expression/builtin_json.go index 462ea88ed0d6a..e3df97b333379 100644 --- a/expression/builtin_json.go +++ b/expression/builtin_json.go @@ -58,6 +58,7 @@ var ( _ builtinFunc = &builtinJSONUnquoteSig{} _ builtinFunc = &builtinJSONArraySig{} _ builtinFunc = &builtinJSONArrayAppendSig{} + _ builtinFunc = &builtinJSONArrayInsertSig{} _ builtinFunc = &builtinJSONObjectSig{} _ builtinFunc = &builtinJSONExtractSig{} _ builtinFunc = &builtinJSONSetSig{} @@ -816,8 +817,74 @@ type jsonArrayInsertFunctionClass struct { baseFunctionClass } +type builtinJSONArrayInsertSig struct { + baseBuiltinFunc +} + func (c *jsonArrayInsertFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { - return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "JSON_ARRAY_INSERT") + if err := c.verifyArgs(args); err != nil { + return nil, err + } + if len(args)&1 != 1 { + return nil, ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) + } + + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETJson) + for i := 1; i < len(args)-1; i += 2 { + argTps = append(argTps, types.ETString, types.ETJson) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETJson, argTps...) + for i := 2; i < len(args); i += 2 { + DisableParseJSONFlag4Expr(args[i]) + } + sig := &builtinJSONArrayInsertSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonArrayInsertSig) + return sig, nil +} + +func (b *builtinJSONArrayInsertSig) Clone() builtinFunc { + newSig := &builtinJSONArrayInsertSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinJSONArrayInsertSig) evalJSON(row chunk.Row) (res json.BinaryJSON, isNull bool, err error) { + res, isNull, err = b.args[0].EvalJSON(b.ctx, row) + if err != nil || isNull { + return res, true, err + } + + for i := 1; i < len(b.args)-1; i += 2 { + // If JSON path is NULL, MySQL breaks and returns NULL. + s, isNull, err := b.args[i].EvalString(b.ctx, row) + if err != nil || isNull { + return res, true, err + } + + pathExpr, err := json.ParseJSONPathExpr(s) + if err != nil { + return res, true, json.ErrInvalidJSONPath.GenWithStackByArgs(s) + } + if pathExpr.ContainsAnyAsterisk() { + return res, true, json.ErrInvalidJSONPathWildcard.GenWithStackByArgs(s) + } + + value, isnull, err := b.args[i+1].EvalJSON(b.ctx, row) + if err != nil { + return res, true, err + } + + if isnull { + value = json.CreateBinary(nil) + } + + res, err = res.ArrayInsert(pathExpr, value) + if err != nil { + return res, true, err + } + } + return res, false, nil } type jsonMergePatchFunctionClass struct { diff --git a/expression/builtin_json_test.go b/expression/builtin_json_test.go index dba46f94b3afb..6b44ffdb0097b 100644 --- a/expression/builtin_json_test.go +++ b/expression/builtin_json_test.go @@ -865,3 +865,76 @@ func (s *testEvaluatorSuite) TestJSONSearch(c *C) { } } } + +func (s *testEvaluatorSuite) TestJSONArrayInsert(c *C) { + defer testleak.AfterTest(c)() + fc := funcs[ast.JSONArrayInsert] + tbl := []struct { + input []interface{} + expected interface{} + success bool + err *terror.Error + }{ + // Success + {[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.b[1]`, `z`}, `{"a": 1, "b": [2, "z", 3], "c": 4}`, true, nil}, + {[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.a[1]`, `z`}, `{"a": 1, "b": [2, 3], "c": 4}`, true, nil}, + {[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.d[1]`, `z`}, `{"a": 1, "b": [2, 3], "c": 4}`, true, nil}, + {[]interface{}{`[{"a": 1, "b": [2, 3], "c": 4}]`, `$[1]`, `w`}, `[{"a": 1, "b": [2, 3], "c": 4}, "w"]`, true, nil}, + {[]interface{}{`[{"a": 1, "b": [2, 3], "c": 4}]`, `$[0]`, nil}, `[null, {"a": 1, "b": [2, 3], "c": 4}]`, true, nil}, + {[]interface{}{`[1, 2, 3]`, `$[100]`, `{"b": 2}`}, `[1, 2, 3, "{\"b\": 2}"]`, true, nil}, + // About null + {[]interface{}{nil, `$`, nil}, nil, true, nil}, + {[]interface{}{nil, `$`, `a`}, nil, true, nil}, + {[]interface{}{`[]`, `$[0]`, nil}, `[null]`, true, nil}, + {[]interface{}{`{}`, `$[0]`, nil}, `{}`, true, nil}, + // Bad arguments + {[]interface{}{`asdf`, `$`, nil}, nil, false, json.ErrInvalidJSONText}, + {[]interface{}{``, `$`, nil}, nil, false, json.ErrInvalidJSONText}, + {[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.d`}, nil, false, ErrIncorrectParameterCount}, + {[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.c`, `y`, `$.b`}, nil, false, ErrIncorrectParameterCount}, + {[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, nil, nil}, nil, true, nil}, + {[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `asdf`, nil}, nil, false, json.ErrInvalidJSONPath}, + {[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, 42, nil}, nil, false, json.ErrInvalidJSONPath}, + {[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.*`, nil}, nil, false, json.ErrInvalidJSONPathWildcard}, + {[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.b[0]`, nil, `$.a`, nil}, nil, false, json.ErrInvalidJSONPathArrayCell}, + {[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.a`, nil}, nil, false, json.ErrInvalidJSONPathArrayCell}, + // Following tests come from MySQL doc. + {[]interface{}{`["a", {"b": [1, 2]}, [3, 4]]`, `$[1]`, `x`}, `["a", "x", {"b": [1, 2]}, [3, 4]]`, true, nil}, + {[]interface{}{`["a", {"b": [1, 2]}, [3, 4]]`, `$[100]`, `x`}, `["a", {"b": [1, 2]}, [3, 4], "x"]`, true, nil}, + {[]interface{}{`["a", {"b": [1, 2]}, [3, 4]]`, `$[1].b[0]`, `x`}, `["a", {"b": ["x", 1, 2]}, [3, 4]]`, true, nil}, + {[]interface{}{`["a", {"b": [1, 2]}, [3, 4]]`, `$[2][1]`, `y`}, `["a", {"b": [1, 2]}, [3, "y", 4]]`, true, nil}, + {[]interface{}{`["a", {"b": [1, 2]}, [3, 4]]`, `$[0]`, `x`, `$[2][1]`, `y`}, `["x", "a", {"b": [1, 2]}, [3, 4]]`, true, nil}, + // More test cases + {[]interface{}{`["a", {"b": [1, 2]}, [3, 4]]`, `$[0]`, `x`, `$[0]`, `y`}, `["y", "x", "a", {"b": [1, 2]}, [3, 4]]`, true, nil}, + } + for _, t := range tbl { + args := types.MakeDatums(t.input...) + f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) + // Parameter count error + if err != nil { + c.Assert(t.err, NotNil) + c.Assert(t.err.Equal(err), Equals, true) + continue + } + + d, err := evalBuiltinFunc(f, chunk.Row{}) + + if t.success { + c.Assert(err, IsNil) + switch x := t.expected.(type) { + case string: + var j1, j2 json.BinaryJSON + j1, err = json.ParseBinaryFromString(x) + c.Assert(err, IsNil) + j2 = d.GetMysqlJSON() + var cmp int + cmp = json.CompareBinary(j1, j2) + c.Assert(cmp, Equals, 0) + case nil: + c.Assert(d.IsNull(), IsTrue) + } + } else { + c.Assert(t.err.Equal(err), Equals, true) + } + } +} diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index 127e6baa95c78..a603858d89591 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -422,6 +422,8 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti f = &builtinJSONArraySig{base} case tipb.ScalarFuncSig_JsonArrayAppendSig: f = &builtinJSONArrayAppendSig{base} + case tipb.ScalarFuncSig_JsonArrayInsertSig: + f = &builtinJSONArrayInsertSig{base} case tipb.ScalarFuncSig_JsonObjectSig: f = &builtinJSONObjectSig{base} case tipb.ScalarFuncSig_JsonExtractSig: diff --git a/types/json/binary_functions.go b/types/json/binary_functions.go index 9cc87569e3542..18dc0d2eed1b3 100644 --- a/types/json/binary_functions.go +++ b/types/json/binary_functions.go @@ -369,6 +369,49 @@ func (bj BinaryJSON) Modify(pathExprList []PathExpression, values []BinaryJSON, return bj, nil } +// ArrayInsert insert a BinaryJSON into the given array cell. +// All path expressions cannot contain * or ** wildcard. +// If any error occurs, the input won't be changed. +func (bj BinaryJSON) ArrayInsert(pathExpr PathExpression, value BinaryJSON) (res BinaryJSON, err error) { + // Check the path is a index + if len(pathExpr.legs) < 1 { + return bj, ErrInvalidJSONPathArrayCell + } + parentPath, lastLeg := pathExpr.popOneLastLeg() + if lastLeg.typ != pathLegIndex { + return bj, ErrInvalidJSONPathArrayCell + } + // Find the target array + obj, exists := bj.Extract([]PathExpression{parentPath}) + if !exists || obj.TypeCode != TypeCodeArray { + return bj, nil + } + + idx := lastLeg.arrayIndex + count := obj.GetElemCount() + if idx >= count { + idx = count + } + // Insert into the array + newArray := make([]BinaryJSON, 0, count+1) + for i := 0; i < idx; i++ { + elem := obj.arrayGetElem(i) + newArray = append(newArray, elem) + } + newArray = append(newArray, value) + for i := idx; i < count; i++ { + elem := obj.arrayGetElem(i) + newArray = append(newArray, elem) + } + obj = buildBinaryArray(newArray) + + bj, err = bj.Modify([]PathExpression{parentPath}, []BinaryJSON{obj}, ModifySet) + if err != nil { + return bj, err + } + return bj, nil +} + // Remove removes the elements indicated by pathExprList from JSON. func (bj BinaryJSON) Remove(pathExprList []PathExpression) (BinaryJSON, error) { for _, pathExpr := range pathExprList { diff --git a/types/json/constants.go b/types/json/constants.go index 03c9a5aa7a5a2..1f8186fb807e2 100644 --- a/types/json/constants.go +++ b/types/json/constants.go @@ -216,6 +216,8 @@ var ( ErrInvalidJSONPathWildcard = terror.ClassJSON.New(mysql.ErrInvalidJSONPathWildcard, mysql.MySQLErrName[mysql.ErrInvalidJSONPathWildcard]) // ErrInvalidJSONContainsPathType means invalid JSON contains path type. ErrInvalidJSONContainsPathType = terror.ClassJSON.New(mysql.ErrInvalidJSONContainsPathType, mysql.MySQLErrName[mysql.ErrInvalidJSONContainsPathType]) + // ErrInvalidJSONPathArrayCell means invalid JSON path for an array cell. + ErrInvalidJSONPathArrayCell = terror.ClassJSON.New(mysql.ErrInvalidJSONPathArrayCell, mysql.MySQLErrName[mysql.ErrInvalidJSONPathArrayCell]) ) func init() { @@ -225,6 +227,7 @@ func init() { mysql.ErrInvalidJSONData: mysql.ErrInvalidJSONData, mysql.ErrInvalidJSONPathWildcard: mysql.ErrInvalidJSONPathWildcard, mysql.ErrInvalidJSONContainsPathType: mysql.ErrInvalidJSONContainsPathType, + mysql.ErrInvalidJSONPathArrayCell: mysql.ErrInvalidJSONPathArrayCell, } }