diff --git a/distsql/distsql.go b/distsql/distsql.go index 45f8665fac2a7..4034978581471 100644 --- a/distsql/distsql.go +++ b/distsql/distsql.go @@ -16,6 +16,7 @@ package distsql import ( "github.com/juju/errors" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/types" @@ -54,6 +55,10 @@ func Select(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request, fie }, nil } + label := metrics.LblGeneral + if sctx.GetSessionVars().InRestrictedSQL { + label = metrics.LblInternal + } return &selectResult{ label: "dag", resp: resp, @@ -63,21 +68,28 @@ func Select(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request, fie fieldTypes: fieldTypes, ctx: sctx, feedback: fb, + sqlType: label, }, nil } // Analyze do a analyze request. -func Analyze(ctx context.Context, client kv.Client, kvReq *kv.Request, vars *kv.Variables) (SelectResult, error) { +func Analyze(ctx context.Context, client kv.Client, kvReq *kv.Request, vars *kv.Variables, + isRestrict bool) (SelectResult, error) { resp := client.Send(ctx, kvReq, vars) if resp == nil { return nil, errors.New("client returns nil response") } + label := metrics.LblGeneral + if isRestrict { + label = metrics.LblInternal + } result := &selectResult{ label: "analyze", resp: resp, results: make(chan resultWithErr, kvReq.Concurrency), closed: make(chan struct{}), feedback: statistics.NewQueryFeedback(0, nil, 0, false), + sqlType: label, } return result, nil } @@ -94,6 +106,7 @@ func Checksum(ctx context.Context, client kv.Client, kvReq *kv.Request, vars *kv results: make(chan resultWithErr, kvReq.Concurrency), closed: make(chan struct{}), feedback: statistics.NewQueryFeedback(0, nil, 0, false), + sqlType: metrics.LblGeneral, } return result, nil } diff --git a/distsql/distsql_test.go b/distsql/distsql_test.go index c933debaa52de..40e4419f7d025 100644 --- a/distsql/distsql_test.go +++ b/distsql/distsql_test.go @@ -64,6 +64,7 @@ func (s *testSuite) TestSelectNormal(c *C) { result, ok := response.(*selectResult) c.Assert(ok, IsTrue) c.Assert(result.label, Equals, "dag") + c.Assert(result.sqlType, Equals, "general") c.Assert(result.rowLen, Equals, len(colTypes)) response.Fetch(context.TODO()) @@ -143,12 +144,13 @@ func (s *testSuite) TestAnalyze(c *C) { Build() c.Assert(err, IsNil) - response, err := Analyze(context.TODO(), s.sctx.GetClient(), request, kv.DefaultVars) + response, err := Analyze(context.TODO(), s.sctx.GetClient(), request, kv.DefaultVars, true) c.Assert(err, IsNil) result, ok := response.(*selectResult) c.Assert(ok, IsTrue) c.Assert(result.label, Equals, "analyze") + c.Assert(result.sqlType, Equals, "internal") response.Fetch(context.TODO()) diff --git a/distsql/select_result.go b/distsql/select_result.go index ecdf76a79a3ad..77e3b2839cc18 100644 --- a/distsql/select_result.go +++ b/distsql/select_result.go @@ -67,6 +67,7 @@ type selectResult struct { feedback *statistics.QueryFeedback partialCount int64 // number of partial results. 
+ sqlType string } func (r *selectResult) Fetch(ctx context.Context) { @@ -78,7 +79,7 @@ func (r *selectResult) fetch(ctx context.Context) { defer func() { close(r.results) duration := time.Since(startTime) - metrics.DistSQLQueryHistgram.WithLabelValues(r.label).Observe(duration.Seconds()) + metrics.DistSQLQueryHistgram.WithLabelValues(r.label, r.sqlType).Observe(duration.Seconds()) }() for { resultSubset, err := r.resp.Next(ctx) diff --git a/executor/adapter.go b/executor/adapter.go index 7b9877a5778c2..13d74fc21206b 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -354,15 +354,19 @@ func (a *ExecStmt) logSlowQuery(txnTS uint64, succ bool) { if len(sessVars.StmtCtx.IndexIDs) > 0 { indexIDs = strings.Replace(fmt.Sprintf("index_ids:%v ", a.Ctx.GetSessionVars().StmtCtx.IndexIDs), " ", ",", -1) } - user := a.Ctx.GetSessionVars().User + user := sessVars.User + var internal string + if sessVars.InRestrictedSQL { + internal = "[INTERNAL] " + } if costTime < threshold { logutil.SlowQueryLogger.Debugf( - "[QUERY] cost_time:%v %s succ:%v con:%v user:%s txn_start_ts:%v database:%v %v%vsql:%v", - costTime, sessVars.StmtCtx.GetExecDetails(), succ, connID, user, txnTS, currentDB, tableIDs, indexIDs, sql) + "[QUERY] %vcost_time:%v %s succ:%v con:%v user:%s txn_start_ts:%v database:%v %v%vsql:%v", + internal, costTime, sessVars.StmtCtx.GetExecDetails(), succ, connID, user, txnTS, currentDB, tableIDs, indexIDs, sql) } else { logutil.SlowQueryLogger.Warnf( - "[SLOW_QUERY] cost_time:%v %s succ:%v con:%v user:%s txn_start_ts:%v database:%v %v%vsql:%v", - costTime, sessVars.StmtCtx.GetExecDetails(), succ, connID, user, txnTS, currentDB, tableIDs, indexIDs, sql) + "[SLOW_QUERY] %vcost_time:%v %s succ:%v con:%v user:%s txn_start_ts:%v database:%v %v%vsql:%v", + internal, costTime, sessVars.StmtCtx.GetExecDetails(), succ, connID, user, txnTS, currentDB, tableIDs, indexIDs, sql) } } diff --git a/executor/analyze.go b/executor/analyze.go index f7bc176636197..46b825706791a 100644 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -176,7 +176,7 @@ func (e *AnalyzeIndexExec) open() error { SetConcurrency(e.concurrency). 
Build() ctx := context.TODO() - e.result, err = distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars) + e.result, err = distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars, e.ctx.GetSessionVars().InRestrictedSQL) if err != nil { return errors.Trace(err) } @@ -295,7 +295,7 @@ func (e *AnalyzeColumnsExec) buildResp(ranges []*ranger.Range) (distsql.SelectRe return nil, errors.Trace(err) } ctx := context.TODO() - result, err := distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars) + result, err := distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars, e.ctx.GetSessionVars().InRestrictedSQL) if err != nil { return nil, errors.Trace(err) } diff --git a/expression/builtin.go b/expression/builtin.go index 5717b40c80be5..536f9f3ca42e2 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -572,14 +572,15 @@ var funcs = map[string]functionClass{ ast.ValidatePasswordStrength: &validatePasswordStrengthFunctionClass{baseFunctionClass{ast.ValidatePasswordStrength, 1, 1}}, // json functions - ast.JSONType: &jsonTypeFunctionClass{baseFunctionClass{ast.JSONType, 1, 1}}, - ast.JSONExtract: &jsonExtractFunctionClass{baseFunctionClass{ast.JSONExtract, 2, -1}}, - ast.JSONUnquote: &jsonUnquoteFunctionClass{baseFunctionClass{ast.JSONUnquote, 1, 1}}, - ast.JSONSet: &jsonSetFunctionClass{baseFunctionClass{ast.JSONSet, 3, -1}}, - ast.JSONInsert: &jsonInsertFunctionClass{baseFunctionClass{ast.JSONInsert, 3, -1}}, - ast.JSONReplace: &jsonReplaceFunctionClass{baseFunctionClass{ast.JSONReplace, 3, -1}}, - ast.JSONRemove: &jsonRemoveFunctionClass{baseFunctionClass{ast.JSONRemove, 2, -1}}, - ast.JSONMerge: &jsonMergeFunctionClass{baseFunctionClass{ast.JSONMerge, 2, -1}}, - ast.JSONObject: &jsonObjectFunctionClass{baseFunctionClass{ast.JSONObject, 0, -1}}, - ast.JSONArray: &jsonArrayFunctionClass{baseFunctionClass{ast.JSONArray, 0, -1}}, + ast.JSONType: &jsonTypeFunctionClass{baseFunctionClass{ast.JSONType, 1, 1}}, + ast.JSONExtract: &jsonExtractFunctionClass{baseFunctionClass{ast.JSONExtract, 2, -1}}, + ast.JSONUnquote: &jsonUnquoteFunctionClass{baseFunctionClass{ast.JSONUnquote, 1, 1}}, + ast.JSONSet: &jsonSetFunctionClass{baseFunctionClass{ast.JSONSet, 3, -1}}, + ast.JSONInsert: &jsonInsertFunctionClass{baseFunctionClass{ast.JSONInsert, 3, -1}}, + ast.JSONReplace: &jsonReplaceFunctionClass{baseFunctionClass{ast.JSONReplace, 3, -1}}, + ast.JSONRemove: &jsonRemoveFunctionClass{baseFunctionClass{ast.JSONRemove, 2, -1}}, + ast.JSONMerge: &jsonMergeFunctionClass{baseFunctionClass{ast.JSONMerge, 2, -1}}, + ast.JSONObject: &jsonObjectFunctionClass{baseFunctionClass{ast.JSONObject, 0, -1}}, + ast.JSONArray: &jsonArrayFunctionClass{baseFunctionClass{ast.JSONArray, 0, -1}}, + ast.JSONContains: &jsonContainsFunctionClass{baseFunctionClass{ast.JSONContains, 2, 3}}, } diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 50b88e0bff3dd..79dfb430b2ca3 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -968,7 +968,7 @@ func (b *builtinCastDecimalAsTimeSig) evalTime(row chunk.Row) (res types.Time, i return res, isNull, errors.Trace(err) } sc := b.ctx.GetSessionVars().StmtCtx - res, err = types.ParseTime(sc, string(val.ToString()), b.tp.Tp, b.tp.Decimal) + res, err = types.ParseTimeFromFloatString(sc, string(val.ToString()), b.tp.Tp, b.tp.Decimal) if err != nil { return res, false, errors.Trace(err) } diff --git a/expression/builtin_json.go b/expression/builtin_json.go index 
f088c902f010a..fc32544a65699 100644 --- a/expression/builtin_json.go +++ b/expression/builtin_json.go @@ -35,6 +35,7 @@ var ( _ functionClass = &jsonMergeFunctionClass{} _ functionClass = &jsonObjectFunctionClass{} _ functionClass = &jsonArrayFunctionClass{} + _ functionClass = &jsonContainsFunctionClass{} // Type of JSON value. _ builtinFunc = &builtinJSONTypeSig{} @@ -56,6 +57,8 @@ var ( _ builtinFunc = &builtinJSONRemoveSig{} // Merge JSON documents, preserving duplicate keys. _ builtinFunc = &builtinJSONMergeSig{} + // Check JSON document contains specific target. + _ builtinFunc = &builtinJSONContainsSig{} ) type jsonTypeFunctionClass struct { @@ -548,3 +551,66 @@ func jsonModify(ctx sessionctx.Context, args []Expression, row chunk.Row, mt jso } return res, false, nil } + +type jsonContainsFunctionClass struct { + baseFunctionClass +} + +type builtinJSONContainsSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONContainsSig) Clone() builtinFunc { + newSig := &builtinJSONContainsSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonContainsFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, errors.Trace(err) + } + argTps := []types.EvalType{types.ETJson, types.ETJson} + if len(args) == 3 { + argTps = append(argTps, types.ETString) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, argTps...) + sig := &builtinJSONContainsSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonContainsSig) + return sig, nil +} + +func (b *builtinJSONContainsSig) evalInt(row chunk.Row) (res int64, isNull bool, err error) { + obj, isNull, err := b.args[0].EvalJSON(b.ctx, row) + if isNull || err != nil { + return res, isNull, errors.Trace(err) + } + target, isNull, err := b.args[1].EvalJSON(b.ctx, row) + if isNull || err != nil { + return res, isNull, errors.Trace(err) + } + var pathExpr json.PathExpression + if len(b.args) == 3 { + path, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return res, isNull, errors.Trace(err) + } + pathExpr, err = json.ParseJSONPathExpr(path) + if err != nil { + return res, true, errors.Trace(err) + } + if pathExpr.ContainsAnyAsterisk() { + return res, true, json.ErrInvalidJSONPathWildcard + } + var exists bool + obj, exists = obj.Extract([]json.PathExpression{pathExpr}) + if !exists { + return res, true, nil + } + } + + if json.ContainsBinary(obj, target) { + return 1, false, nil + } + return 0, false, nil +} diff --git a/expression/builtin_json_test.go b/expression/builtin_json_test.go index ec137fbe4dba1..51e631ce291fa 100644 --- a/expression/builtin_json_test.go +++ b/expression/builtin_json_test.go @@ -259,7 +259,7 @@ func (s *testEvaluatorSuite) TestJSONObject(c *C) { } } -func (s *testEvaluatorSuite) TestJSONORemove(c *C) { +func (s *testEvaluatorSuite) TestJSONRemove(c *C) { defer testleak.AfterTest(c)() fc := funcs[ast.JSONRemove] tbl := []struct { @@ -308,3 +308,66 @@ func (s *testEvaluatorSuite) TestJSONORemove(c *C) { } } } + +func (s *testEvaluatorSuite) TestJSONContains(c *C) { + defer testleak.AfterTest(c)() + fc := funcs[ast.JSONContains] + tbl := []struct { + input []interface{} + expected interface{} + success bool + }{ + // Tests nil arguments + {[]interface{}{nil, `1`, "$.c"}, nil, true}, + {[]interface{}{`{"a": [1, 2, {"aa": "xx"}]}`, nil, "$.a[3]"}, nil, true}, + {[]interface{}{`{"a": [1, 2, {"aa": "xx"}]}`, `1`, nil}, nil, true}, + // Tests with path expression + 
{[]interface{}{`[1,2,[1,[5,[3]]]]`, `[1,3]`, "$[2]"}, 1, true}, + {[]interface{}{`[1,2,[1,[5,{"a":[2,3]}]]]`, `[1,{"a":[3]}]`, "$[2]"}, 1, true}, + {[]interface{}{`[{"a":1}]`, `{"a":1}`, "$"}, 1, true}, + {[]interface{}{`[{"a":1,"b":2}]`, `{"a":1,"b":2}`, "$"}, 1, true}, + {[]interface{}{`[{"a":{"a":1},"b":2}]`, `{"a":1}`, "$.a"}, 0, true}, + // Tests without path expression + {[]interface{}{`{}`, `{}`}, 1, true}, + {[]interface{}{`{"a":1}`, `{}`}, 1, true}, + {[]interface{}{`{"a":1}`, `1`}, 0, true}, + {[]interface{}{`{"a":[1]}`, `[1]`}, 0, true}, + {[]interface{}{`{"b":2, "c":3}`, `{"c":3}`}, 1, true}, + {[]interface{}{`1`, `1`}, 1, true}, + {[]interface{}{`[1]`, `1`}, 1, true}, + {[]interface{}{`[1,2]`, `[1]`}, 1, true}, + {[]interface{}{`[1,2]`, `[1,3]`}, 0, true}, + {[]interface{}{`[1,2]`, `["1"]`}, 0, true}, + {[]interface{}{`[1,2,[1,3]]`, `[1,3]`}, 1, true}, + {[]interface{}{`[1,2,[1,[5,[3]]]]`, `[1,3]`}, 1, true}, + {[]interface{}{`[1,2,[1,[5,{"a":[2,3]}]]]`, `[1,{"a":[3]}]`}, 1, true}, + {[]interface{}{`[{"a":1}]`, `{"a":1}`}, 1, true}, + {[]interface{}{`[{"a":1,"b":2}]`, `{"a":1}`}, 1, true}, + {[]interface{}{`[{"a":{"a":1},"b":2}]`, `{"a":1}`}, 0, true}, + // Tests path expression contains any asterisk + {[]interface{}{`{"a": [1, 2, {"aa": "xx"}]}`, `1`, "$.*"}, nil, false}, + {[]interface{}{`{"a": [1, 2, {"aa": "xx"}]}`, `1`, "$[*]"}, nil, false}, + {[]interface{}{`{"a": [1, 2, {"aa": "xx"}]}`, `1`, "$**.a"}, nil, false}, + // Tests path expression does not identify a section of the target document + {[]interface{}{`{"a": [1, 2, {"aa": "xx"}]}`, `1`, "$.c"}, nil, true}, + {[]interface{}{`{"a": [1, 2, {"aa": "xx"}]}`, `1`, "$.a[3]"}, nil, true}, + {[]interface{}{`{"a": [1, 2, {"aa": "xx"}]}`, `1`, "$.a[2].b"}, nil, true}, + } + for _, t := range tbl { + args := types.MakeDatums(t.input...) 
+ f, err := fc.getFunction(s.ctx, s.datumsToConstants(args)) + c.Assert(err, IsNil) + d, err := evalBuiltinFunc(f, chunk.Row{}) + + if t.success { + c.Assert(err, IsNil) + if t.expected == nil { + c.Assert(d.IsNull(), IsTrue) + } else { + c.Assert(d.GetInt64(), Equals, int64(t.expected.(int))) + } + } else { + c.Assert(err, NotNil) + } + } +} diff --git a/expression/builtin_time.go b/expression/builtin_time.go index 5b0d76d7cf069..cdd6f54a2326e 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -993,7 +993,10 @@ func (b *builtinMonthSig) evalInt(row chunk.Row) (int64, bool, error) { } if date.IsZero() { - return 0, true, errors.Trace(handleInvalidTimeError(b.ctx, types.ErrIncorrectDatetimeValue.GenByArgs(date.String()))) + if b.ctx.GetSessionVars().SQLMode.HasNoZeroDateMode() { + return 0, true, errors.Trace(handleInvalidTimeError(b.ctx, types.ErrIncorrectDatetimeValue.GenByArgs(date.String()))) + } + return 0, false, nil } return int64(date.Time.Month()), false, nil @@ -1030,9 +1033,9 @@ func (b *builtinMonthNameSig) evalString(row chunk.Row) (string, bool, error) { return "", true, errors.Trace(handleInvalidTimeError(b.ctx, err)) } mon := arg.Time.Month() - if arg.IsZero() || mon < 0 || mon > len(types.MonthNames) { + if (arg.IsZero() && b.ctx.GetSessionVars().SQLMode.HasNoZeroDateMode()) || mon < 0 || mon > len(types.MonthNames) { return "", true, errors.Trace(handleInvalidTimeError(b.ctx, types.ErrIncorrectDatetimeValue.GenByArgs(arg.String()))) - } else if mon == 0 { + } else if mon == 0 || arg.IsZero() { return "", true, nil } return types.MonthNames[mon-1], false, nil @@ -1111,7 +1114,10 @@ func (b *builtinDayOfMonthSig) evalInt(row chunk.Row) (int64, bool, error) { return 0, true, errors.Trace(handleInvalidTimeError(b.ctx, err)) } if arg.IsZero() { - return 0, true, errors.Trace(handleInvalidTimeError(b.ctx, types.ErrIncorrectDatetimeValue.GenByArgs(arg.String()))) + if b.ctx.GetSessionVars().SQLMode.HasNoZeroDateMode() { + return 0, true, errors.Trace(handleInvalidTimeError(b.ctx, types.ErrIncorrectDatetimeValue.GenByArgs(arg.String()))) + } + return 0, false, nil } return int64(arg.Time.Day()), false, nil } @@ -1393,9 +1399,11 @@ func (b *builtinYearSig) evalInt(row chunk.Row) (int64, bool, error) { } if date.IsZero() { - return 0, true, errors.Trace(handleInvalidTimeError(b.ctx, types.ErrIncorrectDatetimeValue.GenByArgs(date.String()))) + if b.ctx.GetSessionVars().SQLMode.HasNoZeroDateMode() { + return 0, true, errors.Trace(handleInvalidTimeError(b.ctx, types.ErrIncorrectDatetimeValue.GenByArgs(date.String()))) + } + return 0, false, nil } - return int64(date.Time.Year()), false, nil } diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 1e7c472f2641c..e516c6da08d25 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -169,13 +169,115 @@ func (s *testEvaluatorSuite) TestDate(c *C) { Week interface{} WeekOfYear interface{} YearWeek interface{} + }{ + {nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil}, + {"0000-00-00 00:00:00", 0, 0, nil, 0, nil, nil, nil, nil, nil, nil, nil}, + {"0000-00-00", 0, 0, nil, 0, nil, nil, nil, nil, nil, nil, nil}, + } + + dtblNil := tblToDtbl(tblNil) + for _, t := range dtblNil { + fc := funcs[ast.Year] + f, err := fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) + c.Assert(err, IsNil) + v, err := evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v, testutil.DatumEquals, t["Year"][0]) + + fc = funcs[ast.Month] + f, err = 
fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) + c.Assert(err, IsNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v, testutil.DatumEquals, t["Month"][0]) + + fc = funcs[ast.MonthName] + f, err = fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) + c.Assert(err, IsNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v, testutil.DatumEquals, t["MonthName"][0]) + + fc = funcs[ast.DayOfMonth] + f, err = fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) + c.Assert(err, IsNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v, testutil.DatumEquals, t["DayOfMonth"][0]) + + fc = funcs[ast.DayOfWeek] + f, err = fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) + c.Assert(err, IsNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v, testutil.DatumEquals, t["DayOfWeek"][0]) + + fc = funcs[ast.DayOfYear] + f, err = fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) + c.Assert(err, IsNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v, testutil.DatumEquals, t["DayOfYear"][0]) + + fc = funcs[ast.Weekday] + f, err = fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) + c.Assert(err, IsNil) + c.Assert(f, NotNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v, testutil.DatumEquals, t["WeekDay"][0]) + + fc = funcs[ast.DayName] + f, err = fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) + c.Assert(err, IsNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v, testutil.DatumEquals, t["DayName"][0]) + + fc = funcs[ast.Week] + f, err = fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) + c.Assert(err, IsNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v, testutil.DatumEquals, t["Week"][0]) + + fc = funcs[ast.WeekOfYear] + f, err = fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) + c.Assert(err, IsNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v, testutil.DatumEquals, t["WeekOfYear"][0]) + + fc = funcs[ast.YearWeek] + f, err = fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) + c.Assert(err, IsNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v, testutil.DatumEquals, t["YearWeek"][0]) + } + + // test nil with 'NO_ZERO_DATE' set in sql_mode + tblNil = []struct { + Input interface{} + Year interface{} + Month interface{} + MonthName interface{} + DayOfMonth interface{} + DayOfWeek interface{} + DayOfYear interface{} + WeekDay interface{} + DayName interface{} + Week interface{} + WeekOfYear interface{} + YearWeek interface{} }{ {nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil}, {"0000-00-00 00:00:00", nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil}, {"0000-00-00", nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil}, } - dtblNil := tblToDtbl(tblNil) + dtblNil = tblToDtbl(tblNil) + s.ctx.GetSessionVars().SetSystemVar("sql_mode", "NO_ZERO_DATE") for _, t := range dtblNil { fc := funcs[ast.Year] f, err := fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index d396fe048f4e4..3c9720f13e2c0 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -430,6 +430,8 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti f = &builtinJSONRemoveSig{base} case tipb.ScalarFuncSig_JsonMergeSig: f = 
&builtinJSONMergeSig{base} + case tipb.ScalarFuncSig_JsonContainsSig: + f = &builtinJSONContainsSig{base} case tipb.ScalarFuncSig_LikeSig: f = &builtinLikeSig{base} diff --git a/expression/integration_test.go b/expression/integration_test.go index c37b478e5453e..92e3e8bcafdfa 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1091,10 +1091,16 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { _, err := tk.Exec(`insert into t select year("aa")`) c.Assert(err, NotNil) c.Assert(terror.ErrorEqual(err, types.ErrInvalidTimeFormat), IsTrue, Commentf("err %v", err)) - _, err = tk.Exec(`insert into t select year("0000-00-00 00:00:00")`) + tk.MustExec(`insert into t select year("0000-00-00 00:00:00")`) + tk.MustExec(`set sql_mode="NO_ZERO_DATE";`) + tk.MustExec(`insert into t select year("0000-00-00 00:00:00")`) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Incorrect datetime value: '0000-00-00 00:00:00.000000'")) + tk.MustExec(`set sql_mode="NO_ZERO_DATE,STRICT_TRANS_TABLES";`) + _, err = tk.Exec(`insert into t select year("0000-00-00 00:00:00");`) c.Assert(err, NotNil) c.Assert(types.ErrIncorrectDatetimeValue.Equal(err), IsTrue, Commentf("err %v", err)) tk.MustExec(`insert into t select 1`) + tk.MustExec(`set sql_mode="STRICT_TRANS_TABLES,NO_ENGINE_SUBSTITUTION";`) _, err = tk.Exec(`update t set a = year("aa")`) c.Assert(terror.ErrorEqual(err, types.ErrInvalidTimeFormat), IsTrue, Commentf("err %v", err)) _, err = tk.Exec(`delete from t where a = year("aa")`) @@ -1114,9 +1120,16 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { _, err = tk.Exec(`insert into t select month("aa")`) c.Assert(err, NotNil) c.Assert(terror.ErrorEqual(err, types.ErrInvalidTimeFormat), IsTrue) - _, err = tk.Exec(`insert into t select month("0000-00-00 00:00:00")`) + tk.MustExec(`insert into t select month("0000-00-00 00:00:00")`) + tk.MustExec(`set sql_mode="NO_ZERO_DATE";`) + tk.MustExec(`insert into t select month("0000-00-00 00:00:00")`) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Incorrect datetime value: '0000-00-00 00:00:00.000000'")) + tk.MustExec(`set sql_mode="NO_ZERO_DATE,STRICT_TRANS_TABLES";`) + _, err = tk.Exec(`insert into t select month("0000-00-00 00:00:00");`) c.Assert(err, NotNil) - c.Assert(types.ErrIncorrectDatetimeValue.Equal(err), IsTrue) + c.Assert(types.ErrIncorrectDatetimeValue.Equal(err), IsTrue, Commentf("err %v", err)) + tk.MustExec(`insert into t select 1`) + tk.MustExec(`set sql_mode="STRICT_TRANS_TABLES,NO_ENGINE_SUBSTITUTION";`) tk.MustExec(`insert into t select 1`) _, err = tk.Exec(`update t set a = month("aa")`) c.Assert(terror.ErrorEqual(err, types.ErrInvalidTimeFormat), IsTrue) @@ -1442,6 +1455,14 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { result = tk.MustQuery(`select dayOfYear(null), dayOfYear("2017-08-12"), dayOfYear("0000-00-00"), dayOfYear("2017-00-00"), dayOfYear("0000-00-00 12:12:12"), dayOfYear("2017-00-00 12:12:12")`) result.Check(testkit.Rows(" 224 ")) result = tk.MustQuery(`select dayOfMonth(null), dayOfMonth("2017-08-12"), dayOfMonth("0000-00-00"), dayOfMonth("2017-00-00"), dayOfMonth("0000-00-00 12:12:12"), dayOfMonth("2017-00-00 12:12:12")`) + result.Check(testkit.Rows(" 12 0 0 0 0")) + + tk.MustExec("set sql_mode = 'NO_ZERO_DATE'") + result = tk.MustQuery(`select dayOfWeek(null), dayOfWeek("2017-08-12"), dayOfWeek("0000-00-00"), dayOfWeek("2017-00-00"), dayOfWeek("0000-00-00 12:12:12"), dayOfWeek("2017-00-00 12:12:12")`) + 
result.Check(testkit.Rows(" 7 ")) + result = tk.MustQuery(`select dayOfYear(null), dayOfYear("2017-08-12"), dayOfYear("0000-00-00"), dayOfYear("2017-00-00"), dayOfYear("0000-00-00 12:12:12"), dayOfYear("2017-00-00 12:12:12")`) + result.Check(testkit.Rows(" 224 ")) + result = tk.MustQuery(`select dayOfMonth(null), dayOfMonth("2017-08-12"), dayOfMonth("0000-00-00"), dayOfMonth("2017-00-00"), dayOfMonth("0000-00-00 12:12:12"), dayOfMonth("2017-00-00 12:12:12")`) result.Check(testkit.Rows(" 12 0 0 0")) tk.MustExec(`drop table if exists t`) @@ -1458,6 +1479,13 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { _, err = tk.Exec("insert into t value(dayOfMonth('2017-00-00'))") c.Assert(types.ErrIncorrectDatetimeValue.Equal(err), IsTrue) + tk.MustExec("insert into t value(dayOfMonth('0000-00-00'))") + tk.MustExec(`update t set a = dayOfMonth("0000-00-00")`) + tk.MustExec("set sql_mode = 'NO_ZERO_DATE';") + tk.MustExec("insert into t value(dayOfMonth('0000-00-00'))") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Incorrect datetime value: '0000-00-00 00:00:00.000000'")) + tk.MustExec(`update t set a = dayOfMonth("0000-00-00")`) + tk.MustExec("set sql_mode = 'NO_ZERO_DATE,STRICT_TRANS_TABLES';") _, err = tk.Exec("insert into t value(dayOfMonth('0000-00-00'))") c.Assert(types.ErrIncorrectDatetimeValue.Equal(err), IsTrue) tk.MustExec("insert into t value(0)") @@ -1539,8 +1567,15 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { tk.MustExec(`insert into t value("abc")`) tk.MustExec("set sql_mode = 'STRICT_TRANS_TABLES'") - _, err = tk.Exec("insert into t value(monthname('0000-00-00'))") - c.Assert(types.ErrIncorrectDatetimeValue.Equal(err), IsTrue) + tk.MustExec("insert into t value(monthname('0000-00-00'))") + tk.MustExec(`update t set a = monthname("0000-00-00")`) + tk.MustExec("set sql_mode = 'NO_ZERO_DATE'") + tk.MustExec("insert into t value(monthname('0000-00-00'))") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Incorrect datetime value: '0000-00-00 00:00:00.000000'")) + tk.MustExec(`update t set a = monthname("0000-00-00")`) + tk.MustExec("set sql_mode = ''") + tk.MustExec("insert into t value(monthname('0000-00-00'))") + tk.MustExec("set sql_mode = 'STRICT_TRANS_TABLES,NO_ZERO_DATE'") _, err = tk.Exec(`update t set a = monthname("0000-00-00")`) c.Assert(types.ErrIncorrectDatetimeValue.Equal(err), IsTrue) _, err = tk.Exec(`delete from t where a = monthname(123)`) @@ -1880,6 +1915,9 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) { result = tk.MustQuery(`select cast(cast('2017-01-01 01:01:11.12' as date) as datetime(2));`) result.Check(testkit.Rows("2017-01-01 00:00:00.00")) + result = tk.MustQuery(`select cast(20170118.999 as datetime);`) + result.Check(testkit.Rows("2017-01-18 00:00:00")) + // for ISNULL tk.MustExec("drop table if exists t") tk.MustExec("create table t (a int, b int, c int, d char(10), e datetime, f float, g decimal(10, 3))") @@ -3205,6 +3243,19 @@ func (s *testIntegrationSuite) TestFuncJSON(c *C) { tk.MustExec(`update table_json set a=json_set(a,'$.a',json_object('a',1,'b',2)) where json_extract(a,'$.a[1]') = '2'`) r = tk.MustQuery(`select json_extract(a, '$.a.a'), json_extract(a, '$.a.b') from table_json`) r.Check(testkit.Rows("1 2", " ")) + + r = tk.MustQuery(`select json_contains(NULL, '1'), json_contains('1', NULL), json_contains('1', '1', NULL)`) + r.Check(testkit.Rows(" ")) + r = tk.MustQuery(`select json_contains('{}','{}'), json_contains('[1]','1'), json_contains('[1]','"1"'), 
json_contains('[1,2,[1,[5,[3]]]]', '[1,3]', '$[2]'), json_contains('[1,2,[1,[5,{"a":[2,3]}]]]', '[1,{"a":[3]}]', "$[2]"), json_contains('{"a":1}', '{"a":1,"b":2}', "$")`) + r.Check(testkit.Rows("1 1 0 1 1 0")) + r = tk.MustQuery(`select json_contains('{"a": 1}', '1', "$.c"), json_contains('{"a": [1, 2]}', '1', "$.a[2]"), json_contains('{"a": [1, {"a": 1}]}', '1', "$.a[1].b")`) + r.Check(testkit.Rows(" ")) + rs, err := tk.Exec("select json_contains('1','1','$.*')") + c.Assert(err, IsNil) + c.Assert(rs, NotNil) + _, err = session.GetRows4Test(context.Background(), tk.Se, rs) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[json:3149]In this situation, path expressions may not contain the * and ** tokens.") } func (s *testIntegrationSuite) TestColumnInfoModified(c *C) { diff --git a/metrics/distsql.go b/metrics/distsql.go index d93283eb81abb..71df60daa1f9e 100644 --- a/metrics/distsql.go +++ b/metrics/distsql.go @@ -26,7 +26,7 @@ var ( Name: "handle_query_duration_seconds", Help: "Bucketed histogram of processing time (s) of handled queries.", Buckets: prometheus.ExponentialBuckets(0.0005, 2, 13), - }, []string{LblType}) + }, []string{LblType, LblSQLType}) DistSQLScanKeysPartialHistogram = prometheus.NewHistogram( prometheus.HistogramOpts{ diff --git a/metrics/server.go b/metrics/server.go index 9d61c27e27bd9..61ecbf0de9d61 100644 --- a/metrics/server.go +++ b/metrics/server.go @@ -23,14 +23,14 @@ import ( // Metrics var ( - QueryDurationHistogram = prometheus.NewHistogram( + QueryDurationHistogram = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "tidb", Subsystem: "server", Name: "handle_query_duration_seconds", Help: "Bucketed histogram of processing time (s) of handled queries.", Buckets: prometheus.ExponentialBuckets(0.0005, 2, 22), - }) + }, []string{LblSQLType}) QueryTotalCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ diff --git a/metrics/session.go b/metrics/session.go index 5f75ace2b2c69..4ba4022f65f3a 100644 --- a/metrics/session.go +++ b/metrics/session.go @@ -17,30 +17,30 @@ import "github.com/prometheus/client_golang/prometheus" // Session metrics. 
var ( - SessionExecuteParseDuration = prometheus.NewHistogram( + SessionExecuteParseDuration = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "tidb", Subsystem: "session", Name: "parse_duration_seconds", Help: "Bucketed histogram of processing time (s) in parse SQL.", Buckets: prometheus.LinearBuckets(0.00004, 0.00001, 13), - }) - SessionExecuteCompileDuration = prometheus.NewHistogram( + }, []string{LblSQLType}) + SessionExecuteCompileDuration = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "tidb", Subsystem: "session", Name: "compile_duration_seconds", Help: "Bucketed histogram of processing time (s) in query optimize.", Buckets: prometheus.LinearBuckets(0.00004, 0.00001, 13), - }) - SessionExecuteRunDuration = prometheus.NewHistogram( + }, []string{LblSQLType}) + SessionExecuteRunDuration = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "tidb", Subsystem: "session", Name: "execute_duration_seconds", Help: "Bucketed histogram of processing time (s) in running executor.", Buckets: prometheus.ExponentialBuckets(0.0001, 2, 13), - }) + }, []string{LblSQLType}) SchemaLeaseErrorCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: "tidb", @@ -107,6 +107,9 @@ const ( LblRollback = "rollback" LblType = "type" LblResult = "result" + LblSQLType = "sql_type" + LblGeneral = "general" + LblInternal = "internal" ) func init() { diff --git a/mysql/errcode.go b/mysql/errcode.go index dcc24bc5c3799..f748f212b0ec7 100644 --- a/mysql/errcode.go +++ b/mysql/errcode.go @@ -889,6 +889,7 @@ const ( ErrInvalidJSONText = 3140 ErrInvalidJSONPath = 3143 ErrInvalidJSONData = 3146 + ErrInvalidJSONPathWildcard = 3149 ErrJSONUsedAsKey = 3152 // TiDB self-defined errors. diff --git a/mysql/errname.go b/mysql/errname.go index 5deda27e6716c..d83a137e808bf 100644 --- a/mysql/errname.go +++ b/mysql/errname.go @@ -886,6 +886,7 @@ var MySQLErrName = map[uint16]string{ ErrInvalidJSONText: "Invalid JSON text: %-.192s", ErrInvalidJSONPath: "Invalid JSON path expression %s.", ErrInvalidJSONData: "Invalid data type for JSON data", + ErrInvalidJSONPathWildcard: "In this situation, path expressions may not contain the * and ** tokens.", ErrJSONUsedAsKey: "JSON column '%-.192s' cannot be used in key specification.", // TiDB errors. diff --git a/mysql/state.go b/mysql/state.go index 117b6e0e0517f..a31f61c01fdc0 100644 --- a/mysql/state.go +++ b/mysql/state.go @@ -253,5 +253,6 @@ var MySQLState = map[uint16]string{ ErrInvalidJSONText: "22032", ErrInvalidJSONPath: "42000", ErrInvalidJSONData: "22032", + ErrInvalidJSONPathWildcard: "42000", ErrJSONUsedAsKey: "42000", } diff --git a/server/conn.go b/server/conn.go index 1dcf3d2d1544b..21c74d0587886 100644 --- a/server/conn.go +++ b/server/conn.go @@ -569,7 +569,7 @@ func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) { } else { metrics.QueryTotalCounter.WithLabelValues(label, "OK").Inc() } - metrics.QueryDurationHistogram.Observe(time.Since(startTime).Seconds()) + metrics.QueryDurationHistogram.WithLabelValues(metrics.LblGeneral).Observe(time.Since(startTime).Seconds()) } // dispatch handles client request based on command which is the first byte of the data. 
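(Aside, not part of the patch: the changes above and below all follow the same pattern of turning a plain Histogram into a HistogramVec keyed by a new sql_type label that is observed with either "general" or "internal". The standalone sketch below shows that pattern in isolation; the package name metricsketch and the helper observeQuery are illustrative only, the real metric definitions live in metrics/server.go and metrics/session.go as shown in this diff.)

package metricsketch

import (
	"time"

	"github.com/prometheus/client_golang/prometheus"
)

// queryDuration mirrors what metrics.QueryDurationHistogram becomes after this
// patch: one histogram per sql_type label value ("general" for user SQL,
// "internal" for restricted SQL executed by TiDB itself).
var queryDuration = prometheus.NewHistogramVec(
	prometheus.HistogramOpts{
		Namespace: "tidb",
		Subsystem: "server",
		Name:      "handle_query_duration_seconds",
		Help:      "Bucketed histogram of processing time (s) of handled queries.",
		Buckets:   prometheus.ExponentialBuckets(0.0005, 2, 22),
	}, []string{"sql_type"})

func init() {
	prometheus.MustRegister(queryDuration)
}

// observeQuery records one query's latency under the matching label, the same
// way clientConn.addMetrics and ExecRestrictedSQL do in the hunks above.
func observeQuery(start time.Time, isInternal bool) {
	label := "general"
	if isInternal {
		label = "internal"
	}
	queryDuration.WithLabelValues(label).Observe(time.Since(start).Seconds())
}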
diff --git a/session/session.go b/session/session.go index 11f5641553d9a..9f5e662edcfdf 100644 --- a/session/session.go +++ b/session/session.go @@ -542,6 +542,7 @@ func (s *session) ExecRestrictedSQL(sctx sessionctx.Context, sql string) ([]chun defer s.sysSessionPool().Put(tmp) metrics.SessionRestrictedSQLCounter.Inc() + startTime := time.Now() recordSets, err := se.Execute(ctx, sql) if err != nil { return nil, nil, errors.Trace(err) @@ -566,6 +567,7 @@ func (s *session) ExecRestrictedSQL(sctx sessionctx.Context, sql string) ([]chun fields = rs.Fields() } } + metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds()) return rows, fields, nil } @@ -732,7 +734,11 @@ func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode } return nil, errors.Trace(err) } - metrics.SessionExecuteRunDuration.Observe(time.Since(startTime).Seconds()) + label := metrics.LblGeneral + if s.sessionVars.InRestrictedSQL { + label = metrics.LblInternal + } + metrics.SessionExecuteRunDuration.WithLabelValues(label).Observe(time.Since(startTime).Seconds()) if recordSet != nil { recordSets = append(recordSets, recordSet) @@ -766,7 +772,11 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []ast.Rec log.Warnf("con:%d parse error:\n%v\n%s", connID, err, sql) return nil, errors.Trace(err) } - metrics.SessionExecuteParseDuration.Observe(time.Since(startTS).Seconds()) + label := metrics.LblGeneral + if s.sessionVars.InRestrictedSQL { + label = metrics.LblInternal + } + metrics.SessionExecuteParseDuration.WithLabelValues(label).Observe(time.Since(startTS).Seconds()) compiler := executor.Compiler{Ctx: s} for _, stmtNode := range stmtNodes { @@ -784,7 +794,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []ast.Rec log.Warnf("con:%d compile error:\n%v\n%s", connID, err, sql) return nil, errors.Trace(err) } - metrics.SessionExecuteCompileDuration.Observe(time.Since(startTS).Seconds()) + metrics.SessionExecuteCompileDuration.WithLabelValues(label).Observe(time.Since(startTS).Seconds()) // Step3: Execute the physical plan. 
if recordSets, err = s.executeStatement(ctx, connID, stmtNode, stmt, recordSets); err != nil { diff --git a/types/json/binary_functions.go b/types/json/binary_functions.go index 4f9d58e245680..677ea59d2557f 100644 --- a/types/json/binary_functions.go +++ b/types/json/binary_functions.go @@ -689,3 +689,45 @@ func PeekBytesAsJSON(b []byte) (n int, err error) { err = errors.New("Invalid JSON bytes") return } + +// ContainsBinary check whether JSON document contains specific target according the following rules: +// 1) object contains a target object if and only if every key is contained in source object and the value associated with the target key is contained in the value associated with the source key; +// 2) array contains a target nonarray if and only if the target is contained in some element of the array; +// 3) array contains a target array if and only if every element is contained in some element of the array; +// 4) scalar contains a target scalar if and only if they are comparable and are equal; +func ContainsBinary(obj, target BinaryJSON) bool { + switch obj.TypeCode { + case TypeCodeObject: + if target.TypeCode == TypeCodeObject { + len := target.getElemCount() + for i := 0; i < len; i++ { + key := target.objectGetKey(i) + val := target.objectGetVal(i) + if exp, exists := obj.objectSearchKey(key); !exists || !ContainsBinary(exp, val) { + return false + } + } + return true + } + return false + case TypeCodeArray: + if target.TypeCode == TypeCodeArray { + len := target.getElemCount() + for i := 0; i < len; i++ { + if !ContainsBinary(obj, target.arrayGetElem(i)) { + return false + } + } + return true + } + len := obj.getElemCount() + for i := 0; i < len; i++ { + if ContainsBinary(obj.arrayGetElem(i), target) { + return true + } + } + return false + default: + return CompareBinary(obj, target) == 0 + } +} diff --git a/types/json/binary_test.go b/types/json/binary_test.go index d03a02d9afe58..8e88ebaf1317f 100644 --- a/types/json/binary_test.go +++ b/types/json/binary_test.go @@ -293,3 +293,34 @@ func BenchmarkBinaryMarshal(b *testing.B) { bj.MarshalJSON() } } + +func (s *testJSONSuite) TestBinaryJSONContains(c *C) { + var tests = []struct { + input string + target string + expected bool + }{ + {`{}`, `{}`, true}, + {`{"a":1}`, `{}`, true}, + {`{"a":1}`, `1`, false}, + {`{"a":[1]}`, `[1]`, false}, + {`{"b":2, "c":3}`, `{"c":3}`, true}, + {`1`, `1`, true}, + {`[1]`, `1`, true}, + {`[1,2]`, `[1]`, true}, + {`[1,2]`, `[1,3]`, false}, + {`[1,2]`, `["1"]`, false}, + {`[1,2,[1,3]]`, `[1,3]`, true}, + {`[1,2,[1,[5,[3]]]]`, `[1,3]`, true}, + {`[1,2,[1,[5,{"a":[2,3]}]]]`, `[1,{"a":[3]}]`, true}, + {`[{"a":1}]`, `{"a":1}`, true}, + {`[{"a":1,"b":2}]`, `{"a":1}`, true}, + {`[{"a":{"a":1},"b":2}]`, `{"a":1}`, false}, + } + + for _, tt := range tests { + obj := mustParseBinaryFromString(c, tt.input) + target := mustParseBinaryFromString(c, tt.target) + c.Assert(ContainsBinary(obj, target), Equals, tt.expected) + } +} diff --git a/types/json/constants.go b/types/json/constants.go index 83e2f4aa1328c..452f1accd3ebd 100644 --- a/types/json/constants.go +++ b/types/json/constants.go @@ -212,12 +212,15 @@ var ( ErrInvalidJSONPath = terror.ClassJSON.New(mysql.ErrInvalidJSONPath, mysql.MySQLErrName[mysql.ErrInvalidJSONPath]) // ErrInvalidJSONData means invalid JSON data. ErrInvalidJSONData = terror.ClassJSON.New(mysql.ErrInvalidJSONData, mysql.MySQLErrName[mysql.ErrInvalidJSONData]) + // ErrInvalidJSONPathWildcard means invalid JSON path that contain wildcard characters. 
+ ErrInvalidJSONPathWildcard = terror.ClassJSON.New(mysql.ErrInvalidJSONPathWildcard, mysql.MySQLErrName[mysql.ErrInvalidJSONPathWildcard]) ) func init() { terror.ErrClassToMySQLCodes[terror.ClassJSON] = map[terror.ErrCode]uint16{ - mysql.ErrInvalidJSONText: mysql.ErrInvalidJSONText, - mysql.ErrInvalidJSONPath: mysql.ErrInvalidJSONPath, - mysql.ErrInvalidJSONData: mysql.ErrInvalidJSONData, + mysql.ErrInvalidJSONText: mysql.ErrInvalidJSONText, + mysql.ErrInvalidJSONPath: mysql.ErrInvalidJSONPath, + mysql.ErrInvalidJSONData: mysql.ErrInvalidJSONData, + mysql.ErrInvalidJSONPathWildcard: mysql.ErrInvalidJSONPathWildcard, } } diff --git a/types/json/path_expr.go b/types/json/path_expr.go index 2a522de94b247..8bb9b7a7d6551 100644 --- a/types/json/path_expr.go +++ b/types/json/path_expr.go @@ -117,6 +117,11 @@ func (pe PathExpression) popOneLastLeg() (PathExpression, pathLeg) { return PathExpression{legs: pe.legs[:lastLegIdx]}, lastLeg } +// ContainsAnyAsterisk returns true if pe contains any asterisk. +func (pe PathExpression) ContainsAnyAsterisk() bool { + return pe.flags.containsAnyAsterisk() +} + // ParseJSONPathExpr parses a JSON path expression. Returns a PathExpression // object which can be used in JSON_EXTRACT, JSON_SET and so on. func ParseJSONPathExpr(pathExpr string) (pe PathExpression, err error) {
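(Aside, not part of the patch: as a worked illustration of the four containment rules documented on ContainsBinary in types/json/binary_functions.go above, here is a small self-contained sketch that applies the same rules to values decoded with encoding/json — map[string]interface{}, []interface{}, and scalars — rather than to TiDB's BinaryJSON representation. The function name contains, the helper mustDecode, and the sample documents are assumptions made for the example only.)

package main

import (
	"encoding/json"
	"fmt"
	"reflect"
)

// contains applies the documented rules:
//  1) an object contains a target object iff every target key exists in the
//     source object and the corresponding values are (recursively) contained;
//  2) an array contains a non-array target iff some element contains it;
//  3) an array contains a target array iff every target element is contained;
//  4) scalars are contained iff they compare equal.
func contains(obj, target interface{}) bool {
	switch o := obj.(type) {
	case map[string]interface{}:
		t, ok := target.(map[string]interface{})
		if !ok {
			return false
		}
		for k, tv := range t {
			ov, exists := o[k]
			if !exists || !contains(ov, tv) {
				return false
			}
		}
		return true
	case []interface{}:
		if t, ok := target.([]interface{}); ok {
			for _, tv := range t {
				if !contains(obj, tv) {
					return false
				}
			}
			return true
		}
		for _, ov := range o {
			if contains(ov, target) {
				return true
			}
		}
		return false
	default:
		// encoding/json decodes numbers as float64 and strings as string, so
		// `1` and `"1"` are not equal, matching the `[1,2]` vs `["1"]` test case.
		return reflect.DeepEqual(obj, target)
	}
}

func mustDecode(s string) interface{} {
	var v interface{}
	if err := json.Unmarshal([]byte(s), &v); err != nil {
		panic(err)
	}
	return v
}

func main() {
	obj := mustDecode(`[1,2,[1,[5,{"a":[2,3]}]]]`)
	target := mustDecode(`[1,{"a":[3]}]`)
	fmt.Println(contains(obj, target)) // true, as in TestBinaryJSONContains above
}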