Skip to content

Commit

Permalink
expression: add built-in function json_array_insert (#11076)
Browse files Browse the repository at this point in the history
  • Loading branch information
foreyes authored and eurekaka committed Jul 8, 2019
1 parent a737d26 commit d47f168
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 1 deletion.
69 changes: 68 additions & 1 deletion expression/builtin_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ var (
_ builtinFunc = &builtinJSONUnquoteSig{}
_ builtinFunc = &builtinJSONArraySig{}
_ builtinFunc = &builtinJSONArrayAppendSig{}
_ builtinFunc = &builtinJSONArrayInsertSig{}
_ builtinFunc = &builtinJSONObjectSig{}
_ builtinFunc = &builtinJSONExtractSig{}
_ builtinFunc = &builtinJSONSetSig{}
Expand Down Expand Up @@ -816,8 +817,74 @@ type jsonArrayInsertFunctionClass struct {
baseFunctionClass
}

type builtinJSONArrayInsertSig struct {
baseBuiltinFunc
}

func (c *jsonArrayInsertFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "JSON_ARRAY_INSERT")
if err := c.verifyArgs(args); err != nil {
return nil, err
}
if len(args)&1 != 1 {
return nil, ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName)
}

argTps := make([]types.EvalType, 0, len(args))
argTps = append(argTps, types.ETJson)
for i := 1; i < len(args)-1; i += 2 {
argTps = append(argTps, types.ETString, types.ETJson)
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETJson, argTps...)
for i := 2; i < len(args); i += 2 {
DisableParseJSONFlag4Expr(args[i])
}
sig := &builtinJSONArrayInsertSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_JsonArrayInsertSig)
return sig, nil
}

func (b *builtinJSONArrayInsertSig) Clone() builtinFunc {
newSig := &builtinJSONArrayInsertSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinJSONArrayInsertSig) evalJSON(row chunk.Row) (res json.BinaryJSON, isNull bool, err error) {
res, isNull, err = b.args[0].EvalJSON(b.ctx, row)
if err != nil || isNull {
return res, true, err
}

for i := 1; i < len(b.args)-1; i += 2 {
// If JSON path is NULL, MySQL breaks and returns NULL.
s, isNull, err := b.args[i].EvalString(b.ctx, row)
if err != nil || isNull {
return res, true, err
}

pathExpr, err := json.ParseJSONPathExpr(s)
if err != nil {
return res, true, json.ErrInvalidJSONPath.GenWithStackByArgs(s)
}
if pathExpr.ContainsAnyAsterisk() {
return res, true, json.ErrInvalidJSONPathWildcard.GenWithStackByArgs(s)
}

value, isnull, err := b.args[i+1].EvalJSON(b.ctx, row)
if err != nil {
return res, true, err
}

if isnull {
value = json.CreateBinary(nil)
}

res, err = res.ArrayInsert(pathExpr, value)
if err != nil {
return res, true, err
}
}
return res, false, nil
}

type jsonMergePatchFunctionClass struct {
Expand Down
73 changes: 73 additions & 0 deletions expression/builtin_json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -865,3 +865,76 @@ func (s *testEvaluatorSuite) TestJSONSearch(c *C) {
}
}
}

func (s *testEvaluatorSuite) TestJSONArrayInsert(c *C) {
defer testleak.AfterTest(c)()
fc := funcs[ast.JSONArrayInsert]
tbl := []struct {
input []interface{}
expected interface{}
success bool
err *terror.Error
}{
// Success
{[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.b[1]`, `z`}, `{"a": 1, "b": [2, "z", 3], "c": 4}`, true, nil},
{[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.a[1]`, `z`}, `{"a": 1, "b": [2, 3], "c": 4}`, true, nil},
{[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.d[1]`, `z`}, `{"a": 1, "b": [2, 3], "c": 4}`, true, nil},
{[]interface{}{`[{"a": 1, "b": [2, 3], "c": 4}]`, `$[1]`, `w`}, `[{"a": 1, "b": [2, 3], "c": 4}, "w"]`, true, nil},
{[]interface{}{`[{"a": 1, "b": [2, 3], "c": 4}]`, `$[0]`, nil}, `[null, {"a": 1, "b": [2, 3], "c": 4}]`, true, nil},
{[]interface{}{`[1, 2, 3]`, `$[100]`, `{"b": 2}`}, `[1, 2, 3, "{\"b\": 2}"]`, true, nil},
// About null
{[]interface{}{nil, `$`, nil}, nil, true, nil},
{[]interface{}{nil, `$`, `a`}, nil, true, nil},
{[]interface{}{`[]`, `$[0]`, nil}, `[null]`, true, nil},
{[]interface{}{`{}`, `$[0]`, nil}, `{}`, true, nil},
// Bad arguments
{[]interface{}{`asdf`, `$`, nil}, nil, false, json.ErrInvalidJSONText},
{[]interface{}{``, `$`, nil}, nil, false, json.ErrInvalidJSONText},
{[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.d`}, nil, false, ErrIncorrectParameterCount},
{[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.c`, `y`, `$.b`}, nil, false, ErrIncorrectParameterCount},
{[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, nil, nil}, nil, true, nil},
{[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `asdf`, nil}, nil, false, json.ErrInvalidJSONPath},
{[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, 42, nil}, nil, false, json.ErrInvalidJSONPath},
{[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.*`, nil}, nil, false, json.ErrInvalidJSONPathWildcard},
{[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.b[0]`, nil, `$.a`, nil}, nil, false, json.ErrInvalidJSONPathArrayCell},
{[]interface{}{`{"a": 1, "b": [2, 3], "c": 4}`, `$.a`, nil}, nil, false, json.ErrInvalidJSONPathArrayCell},
// Following tests come from MySQL doc.
{[]interface{}{`["a", {"b": [1, 2]}, [3, 4]]`, `$[1]`, `x`}, `["a", "x", {"b": [1, 2]}, [3, 4]]`, true, nil},
{[]interface{}{`["a", {"b": [1, 2]}, [3, 4]]`, `$[100]`, `x`}, `["a", {"b": [1, 2]}, [3, 4], "x"]`, true, nil},
{[]interface{}{`["a", {"b": [1, 2]}, [3, 4]]`, `$[1].b[0]`, `x`}, `["a", {"b": ["x", 1, 2]}, [3, 4]]`, true, nil},
{[]interface{}{`["a", {"b": [1, 2]}, [3, 4]]`, `$[2][1]`, `y`}, `["a", {"b": [1, 2]}, [3, "y", 4]]`, true, nil},
{[]interface{}{`["a", {"b": [1, 2]}, [3, 4]]`, `$[0]`, `x`, `$[2][1]`, `y`}, `["x", "a", {"b": [1, 2]}, [3, 4]]`, true, nil},
// More test cases
{[]interface{}{`["a", {"b": [1, 2]}, [3, 4]]`, `$[0]`, `x`, `$[0]`, `y`}, `["y", "x", "a", {"b": [1, 2]}, [3, 4]]`, true, nil},
}
for _, t := range tbl {
args := types.MakeDatums(t.input...)
f, err := fc.getFunction(s.ctx, s.datumsToConstants(args))
// Parameter count error
if err != nil {
c.Assert(t.err, NotNil)
c.Assert(t.err.Equal(err), Equals, true)
continue
}

d, err := evalBuiltinFunc(f, chunk.Row{})

if t.success {
c.Assert(err, IsNil)
switch x := t.expected.(type) {
case string:
var j1, j2 json.BinaryJSON
j1, err = json.ParseBinaryFromString(x)
c.Assert(err, IsNil)
j2 = d.GetMysqlJSON()
var cmp int
cmp = json.CompareBinary(j1, j2)
c.Assert(cmp, Equals, 0)
case nil:
c.Assert(d.IsNull(), IsTrue)
}
} else {
c.Assert(t.err.Equal(err), Equals, true)
}
}
}
2 changes: 2 additions & 0 deletions expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
f = &builtinJSONArraySig{base}
case tipb.ScalarFuncSig_JsonArrayAppendSig:
f = &builtinJSONArrayAppendSig{base}
case tipb.ScalarFuncSig_JsonArrayInsertSig:
f = &builtinJSONArrayInsertSig{base}
case tipb.ScalarFuncSig_JsonObjectSig:
f = &builtinJSONObjectSig{base}
case tipb.ScalarFuncSig_JsonExtractSig:
Expand Down
43 changes: 43 additions & 0 deletions types/json/binary_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,49 @@ func (bj BinaryJSON) Modify(pathExprList []PathExpression, values []BinaryJSON,
return bj, nil
}

// ArrayInsert insert a BinaryJSON into the given array cell.
// All path expressions cannot contain * or ** wildcard.
// If any error occurs, the input won't be changed.
func (bj BinaryJSON) ArrayInsert(pathExpr PathExpression, value BinaryJSON) (res BinaryJSON, err error) {
// Check the path is a index
if len(pathExpr.legs) < 1 {
return bj, ErrInvalidJSONPathArrayCell
}
parentPath, lastLeg := pathExpr.popOneLastLeg()
if lastLeg.typ != pathLegIndex {
return bj, ErrInvalidJSONPathArrayCell
}
// Find the target array
obj, exists := bj.Extract([]PathExpression{parentPath})
if !exists || obj.TypeCode != TypeCodeArray {
return bj, nil
}

idx := lastLeg.arrayIndex
count := obj.GetElemCount()
if idx >= count {
idx = count
}
// Insert into the array
newArray := make([]BinaryJSON, 0, count+1)
for i := 0; i < idx; i++ {
elem := obj.arrayGetElem(i)
newArray = append(newArray, elem)
}
newArray = append(newArray, value)
for i := idx; i < count; i++ {
elem := obj.arrayGetElem(i)
newArray = append(newArray, elem)
}
obj = buildBinaryArray(newArray)

bj, err = bj.Modify([]PathExpression{parentPath}, []BinaryJSON{obj}, ModifySet)
if err != nil {
return bj, err
}
return bj, nil
}

// Remove removes the elements indicated by pathExprList from JSON.
func (bj BinaryJSON) Remove(pathExprList []PathExpression) (BinaryJSON, error) {
for _, pathExpr := range pathExprList {
Expand Down
3 changes: 3 additions & 0 deletions types/json/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ var (
ErrInvalidJSONPathWildcard = terror.ClassJSON.New(mysql.ErrInvalidJSONPathWildcard, mysql.MySQLErrName[mysql.ErrInvalidJSONPathWildcard])
// ErrInvalidJSONContainsPathType means invalid JSON contains path type.
ErrInvalidJSONContainsPathType = terror.ClassJSON.New(mysql.ErrInvalidJSONContainsPathType, mysql.MySQLErrName[mysql.ErrInvalidJSONContainsPathType])
// ErrInvalidJSONPathArrayCell means invalid JSON path for an array cell.
ErrInvalidJSONPathArrayCell = terror.ClassJSON.New(mysql.ErrInvalidJSONPathArrayCell, mysql.MySQLErrName[mysql.ErrInvalidJSONPathArrayCell])
)

func init() {
Expand All @@ -225,6 +227,7 @@ func init() {
mysql.ErrInvalidJSONData: mysql.ErrInvalidJSONData,
mysql.ErrInvalidJSONPathWildcard: mysql.ErrInvalidJSONPathWildcard,
mysql.ErrInvalidJSONContainsPathType: mysql.ErrInvalidJSONContainsPathType,
mysql.ErrInvalidJSONPathArrayCell: mysql.ErrInvalidJSONPathArrayCell,
}
}

Expand Down

0 comments on commit d47f168

Please sign in to comment.