diff --git a/expression/integration_test.go b/expression/integration_test.go index 086f6162e0c94..86bfcc4d46264 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2315,7 +2315,6 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) { defer s.cleanEnv(c) tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - ctx := context.Background() // for is true && is false tk.MustExec("drop table if exists t") @@ -2753,14 +2752,6 @@ func (s *testIntegrationSuite2) TestBuiltin(c *C) { _, err = tk.Exec("insert into t values(-9223372036854775809)") c.Assert(err, NotNil) - // test case decimal precision less than the scale. - rs, err := tk.Exec("select cast(12.1 as decimal(3, 4));") - c.Assert(err, IsNil) - _, err = session.GetRows4Test(ctx, tk.Se, rs) - c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '').") - c.Assert(rs.Close(), IsNil) - // test unhex and hex result = tk.MustQuery("select unhex('4D7953514C')") result.Check(testkit.Rows("MySQL")) diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 11b44d83e038d..4c5374bfff7f1 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -79,7 +79,7 @@ func TryAddExtraLimit(ctx sessionctx.Context, node ast.StmtNode) ast.StmtNode { // Preprocess resolves table names of the node, and checks some statements validation. func Preprocess(ctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema, preprocessOpt ...PreprocessOpt) error { - v := preprocessor{is: is, ctx: ctx, tableAliasInJoin: make([]map[string]interface{}, 0)} + v := preprocessor{is: is, ctx: ctx, sql: node.Text(), tableAliasInJoin: make([]map[string]interface{}, 0)} for _, optFn := range preprocessOpt { optFn(&v) } @@ -110,6 +110,7 @@ const ( type preprocessor struct { is infoschema.InfoSchema ctx sessionctx.Context + sql string err error flag preprocessorFlag @@ -184,6 +185,8 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { if node.FnName.L == ast.NextVal || node.FnName.L == ast.LastVal || node.FnName.L == ast.SetVal { p.flag |= inSequenceFunction } + case *ast.FuncCastExpr: + p.checkCastGrammar(node) case *ast.BRIEStmt: if node.Kind == ast.BRIEKindRestore { p.flag |= inCreateOrDropTable @@ -802,18 +805,21 @@ func checkColumn(colDef *ast.ColumnDef) error { } // Check column type. - tp := colDef.Tp + return checkTp(colDef.Tp, colDef.Name.Name.O, "") +} + +func checkTp(tp *types.FieldType, colName, val string) error { if tp == nil { return nil } if tp.Flen > math.MaxUint32 { - return types.ErrTooBigDisplayWidth.GenWithStack("Display width out of range for column '%s' (max = %d)", colDef.Name.Name.O, math.MaxUint32) + return types.ErrTooBigDisplayWidth.GenWithStack("Display width out of range for column '%s' (max = %d)", colName, math.MaxUint32) } switch tp.Tp { case mysql.TypeString: if tp.Flen != types.UnspecifiedLength && tp.Flen > mysql.MaxFieldCharLength { - return types.ErrTooBigFieldLength.GenWithStack("Column length too big for column '%s' (max = %d); use BLOB or TEXT instead", colDef.Name.Name.O, mysql.MaxFieldCharLength) + return types.ErrTooBigFieldLength.GenWithStack("Column length too big for column '%s' (max = %d); use BLOB or TEXT instead", colName, mysql.MaxFieldCharLength) } case mysql.TypeVarchar: if len(tp.Charset) == 0 { @@ -822,7 +828,7 @@ func checkColumn(colDef *ast.ColumnDef) error { // return nil, to make the check in the ddl.CreateTable. return nil } - err := ddl.IsTooBigFieldLength(colDef.Tp.Flen, colDef.Name.Name.O, tp.Charset) + err := ddl.IsTooBigFieldLength(tp.Flen, colName, tp.Charset) if err != nil { return err } @@ -835,41 +841,58 @@ func checkColumn(colDef *ast.ColumnDef) error { // For Double type Flen and Decimal check is moved to parser component default: if tp.Flen > mysql.MaxDoublePrecisionLength { - return types.ErrWrongFieldSpec.GenWithStackByArgs(colDef.Name.Name.O) + return types.ErrWrongFieldSpec.GenWithStackByArgs(colName) } } } else { if tp.Decimal > mysql.MaxFloatingTypeScale { - return types.ErrTooBigScale.GenWithStackByArgs(tp.Decimal, colDef.Name.Name.O, mysql.MaxFloatingTypeScale) + return types.ErrTooBigScale.GenWithStackByArgs(tp.Decimal, colName, mysql.MaxFloatingTypeScale) } if tp.Flen > mysql.MaxFloatingTypeWidth { - return types.ErrTooBigDisplayWidth.GenWithStackByArgs(colDef.Name.Name.O, mysql.MaxFloatingTypeWidth) + return types.ErrTooBigDisplayWidth.GenWithStackByArgs(colName, mysql.MaxFloatingTypeWidth) } } case mysql.TypeSet: if len(tp.Elems) > mysql.MaxTypeSetMembers { - return types.ErrTooBigSet.GenWithStack("Too many strings for column %s and SET", colDef.Name.Name.O) + return types.ErrTooBigSet.GenWithStack("Too many strings for column %s and SET", colName) } // Check set elements. See https://dev.mysql.com/doc/refman/5.7/en/set.html. - for _, str := range colDef.Tp.Elems { + for _, str := range tp.Elems { if strings.Contains(str, ",") { return types.ErrIllegalValueForType.GenWithStackByArgs(types.TypeStr(tp.Tp), str) } } case mysql.TypeNewDecimal: if tp.Decimal > mysql.MaxDecimalScale { - return types.ErrTooBigScale.GenWithStackByArgs(tp.Decimal, colDef.Name.Name.O, mysql.MaxDecimalScale) + var arg string + if colName == "" { + arg = val + } else { + arg = colName + } + return types.ErrTooBigScale.GenWithStackByArgs(tp.Decimal, arg, mysql.MaxDecimalScale) + } if tp.Flen > mysql.MaxDecimalWidth { - return types.ErrTooBigPrecision.GenWithStackByArgs(tp.Flen, colDef.Name.Name.O, mysql.MaxDecimalWidth) + var arg string + if colName == "" { + arg = val + } else { + arg = colName + } + return types.ErrTooBigPrecision.GenWithStackByArgs(tp.Flen, arg, mysql.MaxDecimalWidth) + } + + if tp.Flen < tp.Decimal { + return types.ErrMBiggerThanD.GenWithStackByArgs(colName) } case mysql.TypeBit: if tp.Flen <= 0 { - return types.ErrInvalidFieldSize.GenWithStackByArgs(colDef.Name.Name.O) + return types.ErrInvalidFieldSize.GenWithStackByArgs(colName) } if tp.Flen > mysql.MaxBitDisplayWidth { - return types.ErrTooBigDisplayWidth.GenWithStackByArgs(colDef.Name.Name.O, mysql.MaxBitDisplayWidth) + return types.ErrTooBigDisplayWidth.GenWithStackByArgs(colName, mysql.MaxBitDisplayWidth) } default: // TODO: Add more types. @@ -1063,3 +1086,23 @@ func (p *preprocessor) resolveCreateSequenceStmt(stmt *ast.CreateSequenceStmt) { return } } +func (p *preprocessor) checkCastGrammar(node *ast.FuncCastExpr) { + var val string + switch x := node.Expr.(type) { + case ast.ValueExpr: + val = x.GetDatumString() + if val == "" { + val = fmt.Sprintf("%v", x.GetValue()) + } else { + wrapChar := p.sql[x.OriginTextPosition() : x.OriginTextPosition()+1] + val = wrapChar + val + wrapChar + } + case *ast.ColumnNameExpr: + val = x.Name.Name.O + default: + } + if err := checkTp(node.Tp, "", val); err != nil { + p.err = err + return + } +} diff --git a/planner/core/preprocess_test.go b/planner/core/preprocess_test.go index 4f6d90422a391..b76af468efb8b 100644 --- a/planner/core/preprocess_test.go +++ b/planner/core/preprocess_test.go @@ -269,6 +269,31 @@ func (s *testValidatorSuite) TestValidator(c *C) { {"CREATE TABLE t (a int, index(a));", false, nil}, {"CREATE INDEX `` on t (a);", true, errors.New("[ddl:1280]Incorrect index name ''")}, {"CREATE INDEX `` on t ((lower(a)));", true, errors.New("[ddl:1280]Incorrect index name ''")}, + + // for ErrTooBigScale + {`select convert('0.0', Decimal(41, 40))`, false, errors.New(`[types:1425]Too big scale 40 specified for column ''0.0''. Maximum is 30` + ".")}, + // for cast decimal ErrTooBigPrecision + {`select * from t where d = cast(d as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column 'd'. Maximum is 65` + ".")}, + {`select * from t where d = cast(111 as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '111'. Maximum is 65` + ".")}, + {`select * from t where d = cast("abc" as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '"abc"'. Maximum is 65` + ".")}, + {`select * from t where d = cast('d' as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column ''d''. Maximum is 65` + ".")}, + {`select cast(d as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column 'd'. Maximum is 65` + ".")}, + {`select cast(111 as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '111'. Maximum is 65` + ".")}, + {`select cast("abc" as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '"abc"'. Maximum is 65` + ".")}, + {`select cast("'d'" as decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '"'d'"'. Maximum is 65` + ".")}, + // for cast decimal ErrMBiggerThanD + {`select * from t where d = cast(d as decimal(10,20))`, false, errors.New(`[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '')` + ".")}, + {`select * from t where d = cast("d" as decimal(10,20))`, false, errors.New(`[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '')` + ".")}, + {`select * from t where d = cast("'d'" as decimal(10,20))`, false, errors.New(`[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '')` + ".")}, + // for convert decimal ErrTooBigPrecision + {`select * from t where d = convert(d, decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column 'd'. Maximum is 65` + ".")}, + {`select * from t where d = convert(111, decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '111'. Maximum is 65` + ".")}, + {`select * from t where d = convert("abc", decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column '"abc"'. Maximum is 65` + ".")}, + {`select * from t where d = convert('d', decimal(1000,20))`, false, errors.New(`[types:1426]Too big precision 1000 specified for column ''d''. Maximum is 65` + ".")}, + // for convert decimal ErrMBiggerThanD + {`select * from t where d = convert(d , decimal(10,20))`, false, errors.New(`[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '')` + ".")}, + {`select * from t where d = convert("d", decimal(10,20))`, false, errors.New(`[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '')` + ".")}, + {`select * from t where d = convert("'d'", decimal(10,20))`, false, errors.New(`[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '')` + ".")}, } _, err := s.se.Execute(context.Background(), "use test")