diff --git a/cmd/explaintest/r/new_character_set_builtin.result b/cmd/explaintest/r/new_character_set_builtin.result index bb4e9feaae201..104da7324a001 100644 --- a/cmd/explaintest/r/new_character_set_builtin.result +++ b/cmd/explaintest/r/new_character_set_builtin.result @@ -54,7 +54,6 @@ set @@tidb_enable_vectorized_expression = true; select to_base64(a), to_base64(b), to_base64(c) from t; to_base64(a) to_base64(b) to_base64(c) 5LiA5LqM5LiJ 0ru2/sj9 5LiA5LqM5LiJAAAAAAAAAAAAAAA= -set @@tidb_enable_vectorized_expression = false; drop table if exists t; create table t(a char(10)); insert into t values ('中文'), ('啊'), ('a'), ('1'), ('ㅂ'); @@ -152,3 +151,207 @@ select decode(encode(a,"monty"),"monty") = a, md5(decode(encode(b,"monty"),"mont decode(encode(a,"monty"),"monty") = a md5(decode(encode(b,"monty"),"monty")) = md5(b) decode(encode(c,"monty"),"monty") = c 1 1 1 set @@tidb_enable_vectorized_expression = false; +drop table if exists t; +create table t (a char(20) charset utf8mb4, b char(20) charset gbk, c binary(20)); +insert into t values ('一', '一', 0xe4b880); +insert into t values ('一', '一', 0xd2bb); +insert into t values ('一', '一', 0xe4ba8c); +insert into t values ('一', '一', 0xb6fe); +set @@tidb_enable_vectorized_expression = true; +select hex(concat(a, c)), hex(concat(b, c)) from t; +hex(concat(a, c)) hex(concat(b, c)) +E4B880E4B8800000000000000000000000000000000000 D2BBE4B8800000000000000000000000000000000000 +E4B880D2BB000000000000000000000000000000000000 D2BBD2BB000000000000000000000000000000000000 +E4B880E4BA8C0000000000000000000000000000000000 D2BBE4BA8C0000000000000000000000000000000000 +E4B880B6FE000000000000000000000000000000000000 D2BBB6FE000000000000000000000000000000000000 +select hex(concat(a, 0xe4b880)), hex(concat(b, 0xd2bb)) from t; +hex(concat(a, 0xe4b880)) hex(concat(b, 0xd2bb)) +E4B880E4B880 D2BBD2BB +E4B880E4B880 D2BBD2BB +E4B880E4B880 D2BBD2BB +E4B880E4B880 D2BBD2BB +select a = 0xe4b880, b = 0xd2bb from t; +a = 0xe4b880 b = 0xd2bb +1 1 +1 1 +1 1 +1 1 +select a = c, b = c from t; +a = c b = c +0 0 +0 0 +0 0 +0 0 +select hex(insert(a, 1, 2, 0xe4ba8c)), hex(insert(b, 1, 2, 0xb6fe)) from t; +hex(insert(a, 1, 2, 0xe4ba8c)) hex(insert(b, 1, 2, 0xb6fe)) +E4BA8C B6FE +E4BA8C B6FE +E4BA8C B6FE +E4BA8C B6FE +select hex(insert(a, 1, 2, c)), hex(insert(b, 1, 2, c)) from t; +hex(insert(a, 1, 2, c)) hex(insert(b, 1, 2, c)) +E4B880000000000000000000000000000000000080 E4B8800000000000000000000000000000000000 +D2BB00000000000000000000000000000000000080 D2BB000000000000000000000000000000000000 +E4BA8C000000000000000000000000000000000080 E4BA8C0000000000000000000000000000000000 +B6FE00000000000000000000000000000000000080 B6FE000000000000000000000000000000000000 +select hex(lpad(a, 5, 0xe4ba8c)), hex(lpad(b, 5, 0xb6fe)) from t; +hex(lpad(a, 5, 0xe4ba8c)) hex(lpad(b, 5, 0xb6fe)) +E4BA8CE4BA8CE4BA8CE4BA8CE4B880 B6FEB6FEB6FEB6FED2BB +E4BA8CE4BA8CE4BA8CE4BA8CE4B880 B6FEB6FEB6FEB6FED2BB +E4BA8CE4BA8CE4BA8CE4BA8CE4B880 B6FEB6FEB6FEB6FED2BB +E4BA8CE4BA8CE4BA8CE4BA8CE4B880 B6FEB6FEB6FEB6FED2BB +select hex(lpad(a, 5, c)), hex(lpad(b, 5, c)) from t; +hex(lpad(a, 5, c)) hex(lpad(b, 5, c)) +E4B8E4B880 E4B880D2BB +D2BBE4B880 D2BB00D2BB +E4BAE4B880 E4BA8CD2BB +B6FEE4B880 B6FE00D2BB +select hex(rpad(a, 5, 0xe4ba8c)), hex(rpad(b, 5, 0xb6fe)) from t; +hex(rpad(a, 5, 0xe4ba8c)) hex(rpad(b, 5, 0xb6fe)) +E4B880E4BA8CE4BA8CE4BA8CE4BA8C D2BBB6FEB6FEB6FEB6FE +E4B880E4BA8CE4BA8CE4BA8CE4BA8C D2BBB6FEB6FEB6FEB6FE +E4B880E4BA8CE4BA8CE4BA8CE4BA8C D2BBB6FEB6FEB6FEB6FE +E4B880E4BA8CE4BA8CE4BA8CE4BA8C D2BBB6FEB6FEB6FEB6FE +select hex(rpad(a, 5, c)), hex(rpad(b, 5, c)) from t; +hex(rpad(a, 5, c)) hex(rpad(b, 5, c)) +E4B880E4B8 D2BBE4B880 +E4B880D2BB D2BBD2BB00 +E4B880E4BA D2BBE4BA8C +E4B880B6FE D2BBB6FE00 +select hex(elt(2, a, 0xe4ba8c)), hex(elt(2, b, 0xb6fe)) from t; +hex(elt(2, a, 0xe4ba8c)) hex(elt(2, b, 0xb6fe)) +E4BA8C B6FE +E4BA8C B6FE +E4BA8C B6FE +E4BA8C B6FE +select hex(elt(2, a, c)), hex(elt(2, b, c)) from t; +hex(elt(2, a, c)) hex(elt(2, b, c)) +E4B8800000000000000000000000000000000000 E4B8800000000000000000000000000000000000 +D2BB000000000000000000000000000000000000 D2BB000000000000000000000000000000000000 +E4BA8C0000000000000000000000000000000000 E4BA8C0000000000000000000000000000000000 +B6FE000000000000000000000000000000000000 B6FE000000000000000000000000000000000000 +select hex(instr(a, 0xe4b880)), hex(instr(b, 0xd2bb)) from t; +hex(instr(a, 0xe4b880)) hex(instr(b, 0xd2bb)) +1 1 +1 1 +1 1 +1 1 +select hex(position(a in 0xe4b880)), hex(position(b in 0xd2bb)) from t; +hex(position(a in 0xe4b880)) hex(position(b in 0xd2bb)) +1 1 +1 1 +1 1 +1 1 +select a like 0xe4b880, b like 0xd2bb from t; +a like 0xe4b880 b like 0xd2bb +1 1 +1 1 +1 1 +1 1 +select a = 0xb6fe from t; +Error 3854: Cannot convert string 'B6FE' from binary to utf8mb4 +select b = 0xe4ba8c from t; +Error 3854: Cannot convert string 'E4BA8C' from binary to gbk +select concat(a, 0xb6fe) from t; +Error 3854: Cannot convert string 'B6FE' from binary to utf8mb4 +select concat(b, 0xe4ba8c) from t; +Error 3854: Cannot convert string 'E4BA8C' from binary to gbk +set @@tidb_enable_vectorized_expression = false; +select hex(concat(a, c)), hex(concat(b, c)) from t; +hex(concat(a, c)) hex(concat(b, c)) +E4B880E4B8800000000000000000000000000000000000 D2BBE4B8800000000000000000000000000000000000 +E4B880D2BB000000000000000000000000000000000000 D2BBD2BB000000000000000000000000000000000000 +E4B880E4BA8C0000000000000000000000000000000000 D2BBE4BA8C0000000000000000000000000000000000 +E4B880B6FE000000000000000000000000000000000000 D2BBB6FE000000000000000000000000000000000000 +select hex(concat(a, 0xe4b880)), hex(concat(b, 0xd2bb)) from t; +hex(concat(a, 0xe4b880)) hex(concat(b, 0xd2bb)) +E4B880E4B880 D2BBD2BB +E4B880E4B880 D2BBD2BB +E4B880E4B880 D2BBD2BB +E4B880E4B880 D2BBD2BB +select a = 0xe4b880, b = 0xd2bb from t; +a = 0xe4b880 b = 0xd2bb +1 1 +1 1 +1 1 +1 1 +select a = c, b = c from t; +a = c b = c +0 0 +0 0 +0 0 +0 0 +select hex(insert(a, 1, 2, 0xe4ba8c)), hex(insert(b, 1, 2, 0xb6fe)) from t; +hex(insert(a, 1, 2, 0xe4ba8c)) hex(insert(b, 1, 2, 0xb6fe)) +E4BA8C B6FE +E4BA8C B6FE +E4BA8C B6FE +E4BA8C B6FE +select hex(insert(a, 1, 2, c)), hex(insert(b, 1, 2, c)) from t; +hex(insert(a, 1, 2, c)) hex(insert(b, 1, 2, c)) +E4B880000000000000000000000000000000000080 E4B8800000000000000000000000000000000000 +D2BB00000000000000000000000000000000000080 D2BB000000000000000000000000000000000000 +E4BA8C000000000000000000000000000000000080 E4BA8C0000000000000000000000000000000000 +B6FE00000000000000000000000000000000000080 B6FE000000000000000000000000000000000000 +select hex(lpad(a, 5, 0xe4ba8c)), hex(lpad(b, 5, 0xb6fe)) from t; +hex(lpad(a, 5, 0xe4ba8c)) hex(lpad(b, 5, 0xb6fe)) +E4BA8CE4BA8CE4BA8CE4BA8CE4B880 B6FEB6FEB6FEB6FED2BB +E4BA8CE4BA8CE4BA8CE4BA8CE4B880 B6FEB6FEB6FEB6FED2BB +E4BA8CE4BA8CE4BA8CE4BA8CE4B880 B6FEB6FEB6FEB6FED2BB +E4BA8CE4BA8CE4BA8CE4BA8CE4B880 B6FEB6FEB6FEB6FED2BB +select hex(lpad(a, 5, c)), hex(lpad(b, 5, c)) from t; +hex(lpad(a, 5, c)) hex(lpad(b, 5, c)) +E4B8E4B880 E4B880D2BB +D2BBE4B880 D2BB00D2BB +E4BAE4B880 E4BA8CD2BB +B6FEE4B880 B6FE00D2BB +select hex(rpad(a, 5, 0xe4ba8c)), hex(rpad(b, 5, 0xb6fe)) from t; +hex(rpad(a, 5, 0xe4ba8c)) hex(rpad(b, 5, 0xb6fe)) +E4B880E4BA8CE4BA8CE4BA8CE4BA8C D2BBB6FEB6FEB6FEB6FE +E4B880E4BA8CE4BA8CE4BA8CE4BA8C D2BBB6FEB6FEB6FEB6FE +E4B880E4BA8CE4BA8CE4BA8CE4BA8C D2BBB6FEB6FEB6FEB6FE +E4B880E4BA8CE4BA8CE4BA8CE4BA8C D2BBB6FEB6FEB6FEB6FE +select hex(rpad(a, 5, c)), hex(rpad(b, 5, c)) from t; +hex(rpad(a, 5, c)) hex(rpad(b, 5, c)) +E4B880E4B8 D2BBE4B880 +E4B880D2BB D2BBD2BB00 +E4B880E4BA D2BBE4BA8C +E4B880B6FE D2BBB6FE00 +select hex(elt(2, a, 0xe4ba8c)), hex(elt(2, b, 0xb6fe)) from t; +hex(elt(2, a, 0xe4ba8c)) hex(elt(2, b, 0xb6fe)) +E4BA8C B6FE +E4BA8C B6FE +E4BA8C B6FE +E4BA8C B6FE +select hex(elt(2, a, c)), hex(elt(2, b, c)) from t; +hex(elt(2, a, c)) hex(elt(2, b, c)) +E4B8800000000000000000000000000000000000 E4B8800000000000000000000000000000000000 +D2BB000000000000000000000000000000000000 D2BB000000000000000000000000000000000000 +E4BA8C0000000000000000000000000000000000 E4BA8C0000000000000000000000000000000000 +B6FE000000000000000000000000000000000000 B6FE000000000000000000000000000000000000 +select hex(instr(a, 0xe4b880)), hex(instr(b, 0xd2bb)) from t; +hex(instr(a, 0xe4b880)) hex(instr(b, 0xd2bb)) +1 1 +1 1 +1 1 +1 1 +select hex(position(a in 0xe4b880)), hex(position(b in 0xd2bb)) from t; +hex(position(a in 0xe4b880)) hex(position(b in 0xd2bb)) +1 1 +1 1 +1 1 +1 1 +select a like 0xe4b880, b like 0xd2bb from t; +a like 0xe4b880 b like 0xd2bb +1 1 +1 1 +1 1 +1 1 +select a = 0xb6fe from t; +Error 3854: Cannot convert string 'B6FE' from binary to utf8mb4 +select b = 0xe4ba8c from t; +Error 3854: Cannot convert string 'E4BA8C' from binary to gbk +select concat(a, 0xb6fe) from t; +Error 3854: Cannot convert string 'B6FE' from binary to utf8mb4 +select concat(b, 0xe4ba8c) from t; +Error 3854: Cannot convert string 'E4BA8C' from binary to gbk diff --git a/cmd/explaintest/t/new_character_set_builtin.test b/cmd/explaintest/t/new_character_set_builtin.test index d5d0bcc9a14f5..1cec990d1a5c0 100644 --- a/cmd/explaintest/t/new_character_set_builtin.test +++ b/cmd/explaintest/t/new_character_set_builtin.test @@ -31,7 +31,6 @@ insert into t values ('一二三', '一二三', '一二三'); select to_base64(a), to_base64(b), to_base64(c) from t; set @@tidb_enable_vectorized_expression = true; select to_base64(a), to_base64(b), to_base64(c) from t; -set @@tidb_enable_vectorized_expression = false; -- test for builtin function convert() drop table if exists t; @@ -75,3 +74,62 @@ select decode(encode(a,"monty"),"monty") = a, md5(decode(encode(b,"monty"),"mont set @@tidb_enable_vectorized_expression = true; select decode(encode(a,"monty"),"monty") = a, md5(decode(encode(b,"monty"),"monty")) = md5(b), decode(encode(c,"monty"),"monty") = c from t; set @@tidb_enable_vectorized_expression = false; + +drop table if exists t; +create table t (a char(20) charset utf8mb4, b char(20) charset gbk, c binary(20)); +insert into t values ('一', '一', 0xe4b880); +insert into t values ('一', '一', 0xd2bb); +insert into t values ('一', '一', 0xe4ba8c); +insert into t values ('一', '一', 0xb6fe); + +set @@tidb_enable_vectorized_expression = true; +select hex(concat(a, c)), hex(concat(b, c)) from t; +select hex(concat(a, 0xe4b880)), hex(concat(b, 0xd2bb)) from t; +select a = 0xe4b880, b = 0xd2bb from t; +select a = c, b = c from t; +select hex(insert(a, 1, 2, 0xe4ba8c)), hex(insert(b, 1, 2, 0xb6fe)) from t; +select hex(insert(a, 1, 2, c)), hex(insert(b, 1, 2, c)) from t; +select hex(lpad(a, 5, 0xe4ba8c)), hex(lpad(b, 5, 0xb6fe)) from t; +select hex(lpad(a, 5, c)), hex(lpad(b, 5, c)) from t; +select hex(rpad(a, 5, 0xe4ba8c)), hex(rpad(b, 5, 0xb6fe)) from t; +select hex(rpad(a, 5, c)), hex(rpad(b, 5, c)) from t; +select hex(elt(2, a, 0xe4ba8c)), hex(elt(2, b, 0xb6fe)) from t; +select hex(elt(2, a, c)), hex(elt(2, b, c)) from t; +select hex(instr(a, 0xe4b880)), hex(instr(b, 0xd2bb)) from t; +select hex(position(a in 0xe4b880)), hex(position(b in 0xd2bb)) from t; +select a like 0xe4b880, b like 0xd2bb from t; + +--error ER_CANNOT_CONVERT_STRING +select a = 0xb6fe from t; +--error ER_CANNOT_CONVERT_STRING +select b = 0xe4ba8c from t; +--error ER_CANNOT_CONVERT_STRING +select concat(a, 0xb6fe) from t; +--error ER_CANNOT_CONVERT_STRING +select concat(b, 0xe4ba8c) from t; + +set @@tidb_enable_vectorized_expression = false; +select hex(concat(a, c)), hex(concat(b, c)) from t; +select hex(concat(a, 0xe4b880)), hex(concat(b, 0xd2bb)) from t; +select a = 0xe4b880, b = 0xd2bb from t; +select a = c, b = c from t; +select hex(insert(a, 1, 2, 0xe4ba8c)), hex(insert(b, 1, 2, 0xb6fe)) from t; +select hex(insert(a, 1, 2, c)), hex(insert(b, 1, 2, c)) from t; +select hex(lpad(a, 5, 0xe4ba8c)), hex(lpad(b, 5, 0xb6fe)) from t; +select hex(lpad(a, 5, c)), hex(lpad(b, 5, c)) from t; +select hex(rpad(a, 5, 0xe4ba8c)), hex(rpad(b, 5, 0xb6fe)) from t; +select hex(rpad(a, 5, c)), hex(rpad(b, 5, c)) from t; +select hex(elt(2, a, 0xe4ba8c)), hex(elt(2, b, 0xb6fe)) from t; +select hex(elt(2, a, c)), hex(elt(2, b, c)) from t; +select hex(instr(a, 0xe4b880)), hex(instr(b, 0xd2bb)) from t; +select hex(position(a in 0xe4b880)), hex(position(b in 0xd2bb)) from t; +select a like 0xe4b880, b like 0xd2bb from t; + +--error ER_CANNOT_CONVERT_STRING +select a = 0xb6fe from t; +--error ER_CANNOT_CONVERT_STRING +select b = 0xe4ba8c from t; +--error ER_CANNOT_CONVERT_STRING +select concat(a, 0xb6fe) from t; +--error ER_CANNOT_CONVERT_STRING +select concat(b, 0xe4ba8c) from t; diff --git a/errno/errcode.go b/errno/errcode.go index c4b5234a554f2..04fe354c9067b 100644 --- a/errno/errcode.go +++ b/errno/errcode.go @@ -901,6 +901,7 @@ const ( ErrFKIncompatibleColumns = 3780 ErrFunctionalIndexRowValueIsNotAllowed = 3800 ErrDependentByFunctionalIndex = 3837 + ErrCannotConvertString = 3854 ErrInvalidJSONValueForFuncIndex = 3903 ErrJSONValueOutOfRangeForFuncIndex = 3904 ErrFunctionalIndexDataIsTooLong = 3907 diff --git a/errno/errname.go b/errno/errname.go index 799f74af63f08..df3661ab7dbc7 100644 --- a/errno/errname.go +++ b/errno/errname.go @@ -896,6 +896,7 @@ var MySQLErrName = map[uint16]*mysql.ErrMessage{ ErrFKIncompatibleColumns: mysql.Message("Referencing column '%s' in foreign key constraint '%s' are incompatible", nil), ErrFunctionalIndexRowValueIsNotAllowed: mysql.Message("Expression of expression index '%s' cannot refer to a row value", nil), ErrDependentByFunctionalIndex: mysql.Message("Column '%s' has an expression index dependency and cannot be dropped or renamed", nil), + ErrCannotConvertString: mysql.Message("Cannot convert string '%.64s' from %s to %s", nil), ErrInvalidJSONValueForFuncIndex: mysql.Message("Invalid JSON value for CAST for expression index '%s'", nil), ErrJSONValueOutOfRangeForFuncIndex: mysql.Message("Out of range JSON value for CAST for expression index '%s'", nil), ErrFunctionalIndexDataIsTooLong: mysql.Message("Data too long for expression index '%s'", nil), diff --git a/expression/builtin.go b/expression/builtin.go index 2ec8672c5cce1..a1e2cbba919da 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -91,7 +91,7 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expressi if ctx == nil { return baseBuiltinFunc{}, errors.New("unexpected nil session ctx") } - ec, err := deriveCollation(ctx, funcName, args, retType, retType) + ec, _, err := deriveCollation(ctx, funcName, args, retType, retType) if err != nil { return baseBuiltinFunc{}, err } @@ -125,7 +125,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex // derive collation information for string function, and we must do it // before doing implicit cast. - ec, err := deriveCollation(ctx, funcName, args, retType, argTps...) + ec, retTp, err := deriveCollation(ctx, funcName, args, retType, argTps...) if err != nil { return } @@ -139,7 +139,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex case types.ETDecimal: args[i] = WrapWithCastAsDecimal(ctx, args[i]) case types.ETString: - args[i] = WrapWithCastAsString(ctx, args[i]) + args[i] = WrapWithCastAsStringWithTp(ctx, args[i], retTp) case types.ETDatetime: args[i] = WrapWithCastAsTime(ctx, args[i], types.NewFieldType(mysql.TypeDatetime)) case types.ETTimestamp: diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index b155370d64462..933061930c128 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -23,12 +23,16 @@ package expression import ( + "fmt" "math" "strconv" "strings" + "unicode/utf8" "github.com/pingcap/errors" + "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser/ast" + "github.com/pingcap/tidb/parser/charset" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" @@ -37,6 +41,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/dbterror" "github.com/pingcap/tipb/go-tipb" ) @@ -108,6 +113,11 @@ var ( _ builtinFunc = &builtinCastJSONAsJSONSig{} ) +var ( + // errCannotConvertString returns when the string can not convert to other charset. + errCannotConvertString = dbterror.ClassExpression.NewStd(errno.ErrCannotConvertString) +) + type castAsIntFunctionClass struct { baseFunctionClass @@ -1112,6 +1122,23 @@ func (b *builtinCastStringAsStringSig) evalString(row chunk.Row) (res string, is if isNull || err != nil { return res, isNull, err } + ov := res + fromChs := b.args[0].GetType().Charset + toChs := b.tp.Charset + if toChs == charset.CharsetBin && fromChs != charset.CharsetBin { + res, err = charset.NewEncoding(fromChs).EncodeString(res) + } else if toChs != charset.CharsetBin && fromChs == charset.CharsetBin { + res, err = charset.NewEncoding(toChs).DecodeString(res) + // If toChs is utf8 or utf8mb4, DecodeString will do nothing and return nil error, but we need check if the binary literal is able to convert to utf8. + if toChs == charset.CharsetUTF8 || toChs == charset.CharsetUTF8MB4 { + if !utf8.ValidString(res) { + return "", false, errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", ov), fromChs, toChs) + } + } + } + if err != nil { + return "", false, errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", ov), fromChs, toChs) + } sc := b.ctx.GetSessionVars().StmtCtx res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, sc, false) if err != nil { @@ -1907,6 +1934,25 @@ func WrapWithCastAsDecimal(ctx sessionctx.Context, expr Expression) Expression { return BuildCastFunction(ctx, expr, tp) } +// WrapWithCastAsStringWithTp wraps `expr` with `cast`. +func WrapWithCastAsStringWithTp(ctx sessionctx.Context, expr Expression, toTp *types.FieldType) Expression { + if expr.GetType().EvalType() == types.ETString && toTp != nil { + if expr.GetType().Charset == toTp.Charset { + return expr + } + toTp = &types.FieldType{ + Tp: mysql.TypeVarString, + Decimal: expr.GetType().Decimal, // keep original Decimal + Charset: toTp.Charset, + Collate: toTp.Collate, + Flen: expr.GetType().Flen, // keep original Flen + } + return BuildCastFunction(ctx, expr, toTp) + } + + return WrapWithCastAsString(ctx, expr) +} + // WrapWithCastAsString wraps `expr` with `cast` if the return type of expr is // not type string, otherwise, returns `expr` directly. func WrapWithCastAsString(ctx sessionctx.Context, expr Expression) Expression { diff --git a/expression/builtin_cast_vec.go b/expression/builtin_cast_vec.go index 95609069dcba6..0410beff889b7 100644 --- a/expression/builtin_cast_vec.go +++ b/expression/builtin_cast_vec.go @@ -15,10 +15,13 @@ package expression import ( + "fmt" "math" "strconv" "strings" + "unicode/utf8" + "github.com/pingcap/tidb/parser/charset" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" @@ -1820,6 +1823,24 @@ func (b *builtinCastStringAsStringSig) vecEvalString(input *chunk.Chunk, result var res string var isNull bool + + fromChs := b.args[0].GetType().Charset + toChs := b.tp.Charset + transferString := func(s string) (string, error) { return s, nil } + if toChs == charset.CharsetBin && fromChs != charset.CharsetBin { + transferString = charset.NewEncoding(fromChs).EncodeString + } else if toChs != charset.CharsetBin && fromChs == charset.CharsetBin { + transferString = charset.NewEncoding(toChs).DecodeString + if toChs == charset.CharsetUTF8 || toChs == charset.CharsetUTF8MB4 { + transferString = func(s string) (string, error) { + if !utf8.ValidString(s) { + return "", errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", s), fromChs, toChs) + } + return s, nil + } + } + } + sc := b.ctx.GetSessionVars().StmtCtx result.ReserveString(n) for i := 0; i < n; i++ { @@ -1827,7 +1848,11 @@ func (b *builtinCastStringAsStringSig) vecEvalString(input *chunk.Chunk, result result.AppendNull() continue } - res, err = types.ProduceStrWithSpecifiedTp(buf.GetString(i), b.tp, sc, false) + res, err = transferString(buf.GetString(i)) + if err != nil { + return errCannotConvertString.GenWithStackByArgs(fmt.Sprintf("%X", buf.GetString(i)), fromChs, toChs) + } + res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, sc, false) if err != nil { return err } diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index e5c2dcb08fcfa..dcd241fdbe4d7 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -1210,7 +1210,7 @@ func GetCmpFunction(ctx sessionctx.Context, lhs, rhs Expression) CompareFunc { case types.ETDecimal: return CompareDecimal case types.ETString: - coll, _ := CheckAndDeriveCollationFromExprs(ctx, "", types.ETInt, lhs, rhs) + coll, _, _ := CheckAndDeriveCollationFromExprs(ctx, "", types.ETInt, lhs, rhs) return genCompareString(coll.Collation) case types.ETDuration: return CompareDuration diff --git a/expression/builtin_control.go b/expression/builtin_control.go index e9b39bf36ab5c..4aa9e93ce9b55 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -94,7 +94,7 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp } if types.IsNonBinaryStr(lhs) && !types.IsBinaryStr(rhs) { - ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp) + ec, _, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp) if err != nil { return nil, err } @@ -104,7 +104,7 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp resultFieldType.Flag |= mysql.BinaryFlag } } else if types.IsNonBinaryStr(rhs) && !types.IsBinaryStr(lhs) { - ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp) + ec, _, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp) if err != nil { return nil, err } diff --git a/expression/builtin_string.go b/expression/builtin_string.go index b5b495321e16c..f7b835d624912 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -2972,6 +2972,10 @@ func (c *quoteFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi } SetBinFlagOrBinStr(args[0].GetType(), bf.tp) bf.tp.Flen = 2*args[0].GetType().Flen + 2 + // If arg is NULL, quote function will return 'NULL', the Flen should be 4. + if args[0].GetType().Tp == mysql.TypeNull { + bf.tp.Flen = 4 + } if bf.tp.Flen > mysql.MaxBlobWidth { bf.tp.Flen = mysql.MaxBlobWidth } diff --git a/expression/collation.go b/expression/collation.go index 66e0c2e33c9c4..c1f6ce31b7e15 100644 --- a/expression/collation.go +++ b/expression/collation.go @@ -192,7 +192,8 @@ func deriveCoercibilityForColumn(c *Column) Coercibility { return CoercibilityImplicit } -func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression, retType types.EvalType, argTps ...types.EvalType) (ec *ExprCollation, err error) { +// retTp is the type that the function's args should cast to, only use for string args, return nil means no need to do cast. +func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression, retType types.EvalType, argTps ...types.EvalType) (ec *ExprCollation, retTp *types.FieldType, err error) { switch funcName { case ast.Concat, ast.ConcatWS, ast.Lower, ast.Lcase, ast.Reverse, ast.Upper, ast.Ucase, ast.Quote, ast.Coalesce: return CheckAndDeriveCollationFromExprs(ctx, funcName, retType, args...) @@ -215,53 +216,48 @@ func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression, case ast.GE, ast.LE, ast.GT, ast.LT, ast.EQ, ast.NE, ast.NullEQ, ast.Strcmp: // if compare type is string, we should determine which collation should be used. if argTps[0] == types.ETString { - ec, err = CheckAndDeriveCollationFromExprs(ctx, funcName, types.ETInt, args...) + ec, retTp, err = CheckAndDeriveCollationFromExprs(ctx, funcName, types.ETInt, args...) if err != nil { - return nil, err + return nil, nil, err } ec.Coer = CoercibilityNumeric ec.Repe = ASCII - return ec, nil + return ec, retTp, nil } case ast.If: return CheckAndDeriveCollationFromExprs(ctx, funcName, retType, args[1], args[2]) case ast.Ifnull: return CheckAndDeriveCollationFromExprs(ctx, funcName, retType, args[0], args[1]) case ast.Like: - ec, err = CheckAndDeriveCollationFromExprs(ctx, funcName, types.ETInt, args[0], args[1]) + ec, retTp, err = CheckAndDeriveCollationFromExprs(ctx, funcName, types.ETInt, args[0], args[1]) if err != nil { - return nil, err + return nil, nil, err } ec.Coer = CoercibilityNumeric ec.Repe = ASCII - return ec, nil + return ec, retTp, nil case ast.In: if args[0].GetType().EvalType() == types.ETString { return CheckAndDeriveCollationFromExprs(ctx, funcName, types.ETInt, args...) } case ast.DateFormat, ast.TimeFormat: charsetInfo, collation := ctx.GetSessionVars().GetCharsetInfo() - return &ExprCollation{args[1].Coercibility(), args[1].Repertoire(), charsetInfo, collation}, nil + return &ExprCollation{args[1].Coercibility(), args[1].Repertoire(), charsetInfo, collation}, nil, nil case ast.Cast: - // We assume all the cast are implicit. - ec = &ExprCollation{args[0].Coercibility(), args[0].Repertoire(), args[0].GetType().Charset, args[0].GetType().Collate} - // Non-string type cast to string type should use @@character_set_connection and @@collation_connection. - // String type cast to string type should keep its original charset and collation. It should not happen. - if retType == types.ETString && argTps[0] != types.ETString { - ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo() - } - return ec, nil + // We assume all the cast are implicit, keep the collation related fields to its original value. + return &ExprCollation{args[0].Coercibility(), args[0].Repertoire(), args[0].GetType().Charset, args[0].GetType().Collate}, nil, nil case ast.Case: // FIXME: case function aggregate collation is not correct. - return CheckAndDeriveCollationFromExprs(ctx, funcName, retType, args...) + ec, _, err = CheckAndDeriveCollationFromExprs(ctx, funcName, retType, args...) + return ec, nil, err case ast.Database, ast.User, ast.CurrentUser, ast.Version, ast.CurrentRole, ast.TiDBVersion: chs, coll := charset.GetDefaultCharsetAndCollate() - return &ExprCollation{CoercibilitySysconst, UNICODE, chs, coll}, nil + return &ExprCollation{CoercibilitySysconst, UNICODE, chs, coll}, nil, nil case ast.Format, ast.Space, ast.ToBase64, ast.UUID, ast.Hex, ast.MD5, ast.SHA, ast.SHA2: // should return ASCII repertoire, MySQL's doc says it depends on character_set_connection, but it not true from its source code. ec = &ExprCollation{Coer: CoercibilityCoercible, Repe: ASCII} ec.Charset, ec.Collation = ctx.GetSessionVars().GetCharsetInfo() - return ec, nil + return ec, nil, nil } ec = &ExprCollation{CoercibilityNumeric, ASCII, charset.CharsetBin, charset.CollationBin} @@ -272,7 +268,7 @@ func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression, ec.Repe = UNICODE } } - return ec, nil + return ec, nil, nil } // DeriveCollationFromExprs derives collation information from these expressions. @@ -284,14 +280,14 @@ func DeriveCollationFromExprs(ctx sessionctx.Context, exprs ...Expression) (dstC } // CheckAndDeriveCollationFromExprs derives collation information from these expressions, return error if derives collation error. -func CheckAndDeriveCollationFromExprs(ctx sessionctx.Context, funcName string, evalType types.EvalType, args ...Expression) (et *ExprCollation, err error) { +func CheckAndDeriveCollationFromExprs(ctx sessionctx.Context, funcName string, evalType types.EvalType, args ...Expression) (et *ExprCollation, retTp *types.FieldType, err error) { ec := inferCollation(args...) if ec == nil { - return nil, illegalMixCollationErr(funcName, args) + return nil, nil, illegalMixCollationErr(funcName, args) } if evalType != types.ETString && ec.Coer == CoercibilityNone { - return nil, illegalMixCollationErr(funcName, args) + return nil, nil, illegalMixCollationErr(funcName, args) } if evalType == types.ETString && ec.Coer == CoercibilityNumeric { @@ -301,10 +297,9 @@ func CheckAndDeriveCollationFromExprs(ctx sessionctx.Context, funcName string, e } if !safeConvert(ctx, ec, args...) { - return nil, illegalMixCollationErr(funcName, args) + return nil, nil, illegalMixCollationErr(funcName, args) } - - return ec, nil + return ec, &types.FieldType{Charset: ec.Charset, Collate: ec.Collation}, nil } func safeConvert(ctx sessionctx.Context, ec *ExprCollation, args ...Expression) bool { @@ -322,7 +317,11 @@ func safeConvert(ctx sessionctx.Context, ec *ExprCollation, args ...Expression) if err != nil { return false } - if !isNull && !isValidString(str, ec.Charset) { + // if value is NULL or binary string, just skip it. + if isNull || types.IsBinaryStr(c.GetType()) { + continue + } + if !isValidString(str, ec.Charset) { return false } } else { diff --git a/expression/collation_test.go b/expression/collation_test.go index 1a8541fef3060..fcd1f7e2578b2 100644 --- a/expression/collation_test.go +++ b/expression/collation_test.go @@ -612,23 +612,11 @@ func TestDeriveCollation(t *testing.T) { false, &ExprCollation{CoercibilitySysconst, UNICODE, charset.CharsetUTF8MB4, charset.CollationUTF8MB4}, }, - { - []string{ - ast.Cast, - }, - []Expression{ - newColInt(CoercibilityExplicit), - }, - []types.EvalType{types.ETInt}, - types.ETString, - false, - &ExprCollation{CoercibilityExplicit, ASCII, charset.CharsetUTF8MB4, charset.CollationUTF8MB4}, - }, } for i, test := range tests { for _, fc := range test.fcs { - ec, err := deriveCollation(ctx, fc, test.args, test.retTp, test.argTps...) + ec, _, err := deriveCollation(ctx, fc, test.args, test.retTp, test.argTps...) if test.err { require.Error(t, err, "Number: %d, function: %s", i, fc) require.Nil(t, ec, i) diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index 99693677091bc..502601a9d1218 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -1218,7 +1218,12 @@ func convertUint(val []byte) (*Constant, error) { func convertString(val []byte, tp *tipb.FieldType) (*Constant, error) { var d types.Datum d.SetBytesAsString(val, protoToCollation(tp.Collate), uint32(tp.Flen)) - return &Constant{Value: d, RetType: types.NewFieldType(mysql.TypeVarString)}, nil + return &Constant{Value: d, RetType: &types.FieldType{ + Tp: mysql.TypeString, + Flag: uint(tp.Flag), + Charset: tp.Charset, + Flen: int(tp.Flen), + }}, nil } func convertFloat(val []byte, f32 bool) (*Constant, error) { diff --git a/expression/integration_test.go b/expression/integration_test.go index 59984a2c97f61..c7688b9afcaa3 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1180,7 +1180,7 @@ func (s *testIntegrationSuite2) TestStringBuiltin(c *C) { // for insert result = tk.MustQuery(`select insert("中文", 1, 1, cast("aaa" as binary)), insert("ba", -1, 1, "aaa"), insert("ba", 1, 100, "aaa"), insert("ba", 100, 1, "aaa");`) - result.Check(testkit.Rows("aaa文 ba aaa ba")) + result.Check(testkit.Rows("aaa\xb8\xad文 ba aaa ba")) result = tk.MustQuery(`select insert("bb", NULL, 1, "aa"), insert("bb", 1, NULL, "aa"), insert(NULL, 1, 1, "aaa"), insert("bb", 1, 1, NULL);`) result.Check(testkit.Rows(" ")) result = tk.MustQuery(`SELECT INSERT("bb", 0, 1, NULL), INSERT("bb", 0, NULL, "aaa");`) diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 13f5d81380e7e..fea83846729f6 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -275,7 +275,6 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"CONCAT('T', 'i', 'DB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 24, types.UnspecifiedLength}, {"CONCAT_WS('-', 'T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, mysql.NotNullFlag, 6, types.UnspecifiedLength}, {"CONCAT_WS(',', 'TiDB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 25, types.UnspecifiedLength}, - {"CONCAT(c_bchar, 0x80)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 23, types.UnspecifiedLength}, {"left(c_int_d, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, {"right(c_int_d, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, {"lower(c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 20, types.UnspecifiedLength}, @@ -490,6 +489,7 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"quote(c_bigint_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 42, types.UnspecifiedLength}, {"quote(c_float_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 0, types.UnspecifiedLength}, {"quote(c_double_d )", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 0, types.UnspecifiedLength}, + {"quote(null )", mysql.TypeVarString, charset.CharsetBinary, mysql.BinaryFlag, 4, types.UnspecifiedLength}, {"convert(c_double_d using utf8mb4)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"convert(c_binary using utf8mb4)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, diff --git a/parser/charset/encoding_table.go b/parser/charset/encoding_table.go index ea7e6d8915798..18aac4b75a968 100644 --- a/parser/charset/encoding_table.go +++ b/parser/charset/encoding_table.go @@ -175,14 +175,14 @@ var encodings = map[string]struct { "cp819": {charmap.Windows1252, "windows-1252"}, "csisolatin1": {charmap.Windows1252, "windows-1252"}, "ibm819": {charmap.Windows1252, "windows-1252"}, - "iso-8859-1": {charmap.Windows1252, "windows-1252"}, + "iso-8859-1": {charmap.ISO8859_1, "iso-8859-1"}, "iso-ir-100": {charmap.Windows1252, "windows-1252"}, "iso8859-1": {charmap.Windows1252, "windows-1252"}, "iso88591": {charmap.Windows1252, "windows-1252"}, "iso_8859-1": {charmap.Windows1252, "windows-1252"}, "iso_8859-1:1987": {charmap.Windows1252, "windows-1252"}, "l1": {charmap.Windows1252, "windows-1252"}, - "latin1": {charmap.Windows1252, "windows-1252"}, + "latin1": {charmap.ISO8859_1, "iso-8859-1"}, "us-ascii": {charmap.Windows1252, "windows-1252"}, "windows-1252": {charmap.Windows1252, "windows-1252"}, "x-cp1252": {charmap.Windows1252, "windows-1252"}, @@ -273,6 +273,9 @@ func FindNextCharacterLength(label string) func([]byte) int { var encodingNextCharacterLength = map[string]func([]byte) int{ // https://en.wikipedia.org/wiki/GBK_(character_encoding)#Layout_diagram + "windows-1252": func(bs []byte) int { + return 1 + }, "gbk": characterLengthGBK, "utf-8": characterLengthUTF8, "binary": func(bs []byte) int { diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 0581301d1a791..f21c0ce634adf 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -545,7 +545,7 @@ func (er *expressionRewriter) handleCompareSubquery(ctx context.Context, v *ast. // Lexpr cannot compare with rexpr by different collate opString := new(strings.Builder) v.Op.Format(opString) - _, er.err = expression.CheckAndDeriveCollationFromExprs(er.sctx, opString.String(), types.ETInt, lexpr, rexpr) + _, _, er.err = expression.CheckAndDeriveCollationFromExprs(er.sctx, opString.String(), types.ETInt, lexpr, rexpr) if er.err != nil { return v, true } @@ -1670,7 +1670,7 @@ func (er *expressionRewriter) betweenToExpression(v *ast.BetweenExpr) { expr, lexp, rexp := er.wrapExpWithCast() - coll, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "BETWEEN", types.ETInt, expr, lexp, rexp) + coll, _, err := expression.CheckAndDeriveCollationFromExprs(er.sctx, "BETWEEN", types.ETInt, expr, lexp, rexp) er.err = err if er.err != nil { return diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 175b5395c7fd7..6c12a27d62b55 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -1378,7 +1378,7 @@ func (b *PlanBuilder) buildProjection4Union(ctx context.Context, u *LogicalUnion childTp := u.children[j].Schema().Columns[i].RetType resultTp = unionJoinFieldType(resultTp, childTp) } - collation, err := expression.CheckAndDeriveCollationFromExprs(b.ctx, "UNION", resultTp.EvalType(), tmpExprs...) + collation, _, err := expression.CheckAndDeriveCollationFromExprs(b.ctx, "UNION", resultTp.EvalType(), tmpExprs...) if err != nil || collation.Coer == expression.CoercibilityNone { return collate.ErrIllegalMixCollation.GenWithStackByArgs("UNION") } diff --git a/types/datum.go b/types/datum.go index d79a086cd7346..26f9b761e4840 100644 --- a/types/datum.go +++ b/types/datum.go @@ -996,9 +996,9 @@ func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementCon // overflowed part is all whitespaces var overflowed string var characterLen int - // Flen is the rune length, not binary length, for UTF8 charset, we need to calculate the + // Flen is the rune length, not binary length, for Non-binary charset, we need to calculate the // rune count and truncate to Flen runes if it is too long. - if chs == charset.CharsetUTF8 || chs == charset.CharsetUTF8MB4 { + if chs != charset.CharsetBinary { characterLen = utf8.RuneCountInString(s) if characterLen > flen { // 1. If len(s) is 0 and flen is 0, truncateLen will be 0, don't truncate s.