Skip to content

Commit

Permalink
expression: fix the wrong behavior of char function (#17598) (#18122)
Browse files Browse the repository at this point in the history
Signed-off-by: ti-srebot <[email protected]>
  • Loading branch information
ti-srebot authored Jul 28, 2020
1 parent 889eb98 commit a84bddf
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 76 deletions.
49 changes: 22 additions & 27 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -2308,8 +2308,29 @@ func (c *charFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
if err != nil {
return nil, err
}
// The last argument represents the charset name after "using".
if _, ok := args[len(args)-1].(*Constant); !ok {
// Reaching this branch means the charset argument was not built as a constant, which indicates a bug elsewhere.
logutil.BgLogger().Warn(fmt.Sprintf("The last argument in char function must be constant, but got %T", args[len(args)-1]))
return nil, errIncorrectArgs
}
charsetName, isNull, err := args[len(args)-1].EvalString(ctx, chunk.Row{})
if err != nil {
return nil, err
}
if isNull {
// Fall back to the binary charset and collation when the charset argument is NULL.
bf.tp.Charset, bf.tp.Collate = charset.CharsetBin, charset.CollationBin
bf.tp.Flag |= mysql.BinaryFlag
} else {
bf.tp.Charset = charsetName
defaultCollate, err := charset.GetDefaultCollation(charsetName)
if err != nil {
return nil, err
}
bf.tp.Collate = defaultCollate
}
bf.tp.Flen = 4 * (len(args) - 1)
types.SetBinChsClnFlag(bf.tp)

sig := &builtinCharSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_Char)
Expand Down Expand Up @@ -2354,33 +2375,7 @@ func (b *builtinCharSig) evalString(row chunk.Row) (string, bool, error) {
}
bigints = append(bigints, val)
}
// The last argument represents the charset name after "using".
// Use default charset utf8 if it is nil.
argCharset, IsNull, err := b.args[len(b.args)-1].EvalString(b.ctx, row)
if err != nil {
return "", true, err
}

result := string(b.convertToBytes(bigints))
charsetLabel := strings.ToLower(argCharset)
if IsNull || charsetLabel == "ascii" || strings.HasPrefix(charsetLabel, "utf8") {
return result, false, nil
}

encoding, charsetName := charset.Lookup(charsetLabel)
if encoding == nil {
return "", true, errors.Errorf("unknown encoding: %s", argCharset)
}

oldStr := result
result, _, err = transform.String(encoding.NewDecoder(), result)
if err != nil {
logutil.BgLogger().Warn("change charset of string",
zap.String("string", oldStr),
zap.String("charset", charsetName),
zap.Error(err))
return "", true, err
}
return result, false, nil
}

Expand Down
7 changes: 0 additions & 7 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1285,13 +1285,6 @@ func (s *testEvaluatorSuite) TestChar(c *C) {
r, err := evalBuiltinFunc(f, chunk.Row{})
c.Assert(err, IsNil)
c.Assert(r, testutil.DatumEquals, types.NewDatum("AB"))

// Test unsupported charset.
fc = funcs[ast.CharFunc]
f, err = fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums("65", "tidb")))
c.Assert(err, IsNil)
_, err = evalBuiltinFunc(f, chunk.Row{})
c.Assert(err.Error(), Equals, "unknown encoding: tidb")
}

func (s *testEvaluatorSuite) TestCharLength(c *C) {
Expand Down
26 changes: 1 addition & 25 deletions expression/builtin_string_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,13 @@ import (
"strings"
"unicode/utf8"

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/charset"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
"golang.org/x/text/transform"
)

Expand Down Expand Up @@ -2241,9 +2238,6 @@ func (b *builtinCharSig) vecEvalString(input *chunk.Chunk, result *chunk.Column)
return err
}
defer b.bufAllocator.put(bufstr)
if err := b.args[l-1].VecEvalString(b.ctx, input, bufstr); err != nil {
return err
}
bigints := make([]int64, 0, l-1)
result.ReserveString(n)
bufint := make([]([]int64), l-1)
Expand All @@ -2259,25 +2253,7 @@ func (b *builtinCharSig) vecEvalString(input *chunk.Chunk, result *chunk.Column)
bigints = append(bigints, bufint[j][i])
}
tempString := string(b.convertToBytes(bigints))
charsetLable := strings.ToLower(bufstr.GetString(i))
if bufstr.IsNull(i) || charsetLable == "ascii" || strings.HasPrefix(charsetLable, "utf8") {
result.AppendString(tempString)
} else {
encoding, charsetName := charset.Lookup(charsetLable)
if encoding == nil {
return errors.Errorf("unknown encoding: %s", bufstr.GetString(i))
}
oldStr := tempString
tempString, _, err := transform.String(encoding.NewDecoder(), tempString)
if err != nil {
logutil.BgLogger().Warn("change charset of string",
zap.String("string", oldStr),
zap.String("charset", charsetName),
zap.Error(err))
return err
}
result.AppendString(tempString)
}
result.AppendString(tempString)
}
return nil
}
Expand Down
3 changes: 2 additions & 1 deletion expression/builtin_string_vec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ var vecBuiltinStringCases = map[string][]vecExprBenchCase{
{
retEvalType: types.ETString,
childrenTypes: []types.EvalType{types.ETInt, types.ETInt, types.ETInt, types.ETString},
geners: []dataGenerator{&charInt64Gener{}, &charInt64Gener{}, &charInt64Gener{}, &charsetStringGener{}},
geners: []dataGenerator{&charInt64Gener{}, &charInt64Gener{}, &charInt64Gener{}, nil},
constants: []*Constant{nil, nil, nil, {Value: types.NewDatum("ascii"), RetType: types.NewFieldType(mysql.TypeString)}},
},
},
ast.FindInSet: {
Expand Down
32 changes: 16 additions & 16 deletions expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,22 +382,22 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase {
{"char(c_blob_d )", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 4, types.UnspecifiedLength},
{"char(c_set )", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 4, types.UnspecifiedLength},
{"char(c_enum )", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 4, types.UnspecifiedLength},
{"char(c_int_d , c_int_d using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_bigint_d , c_bigint_d using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_float_d , c_float_d using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_double_d , c_double_d using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_decimal , c_decimal using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_datetime , c_datetime using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_time_d , c_time_d using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_timestamp_d, c_timestamp_d using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_char , c_char using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_varchar , c_varchar using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_text_d , c_text_d using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_binary , c_binary using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_varbinary , c_varbinary using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_blob_d , c_blob_d using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_set , c_set using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_enum , c_enum using utf8)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 8, types.UnspecifiedLength},
{"char(c_int_d , c_int_d using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_bigint_d , c_bigint_d using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_float_d , c_float_d using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_double_d , c_double_d using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_decimal , c_decimal using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_datetime , c_datetime using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_time_d , c_time_d using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_timestamp_d, c_timestamp_d using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_char , c_char using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_varchar , c_varchar using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_text_d , c_text_d using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_binary , c_binary using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_varbinary , c_varbinary using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_blob_d , c_blob_d using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_set , c_set using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"char(c_enum , c_enum using utf8)", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},

{"instr(c_char, c_binary )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"instr(c_char, c_char )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
Expand Down

0 comments on commit a84bddf

Please sign in to comment.