Skip to content

Commit

Permalink
expression: return null when cast to huge binary type (#8768) (#9350)
Browse files Browse the repository at this point in the history
  • Loading branch information
eurekaka authored and jackysp committed Feb 19, 2019
1 parent f9df5fa commit a4113cd
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 18 deletions.
62 changes: 50 additions & 12 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
Expand Down Expand Up @@ -520,8 +521,11 @@ func (b *builtinCastIntAsStringSig) evalString(row types.Row) (res string, isNul
}
res = strconv.FormatUint(uVal, 10)
}
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx)
return res, false, errors.Trace(err)
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx, false)
if err != nil {
return res, false, err
}
return padZeroForBinaryType(res, b.tp, b.ctx)
}

type builtinCastIntAsTimeSig struct {
Expand Down Expand Up @@ -781,8 +785,11 @@ func (b *builtinCastRealAsStringSig) evalString(row types.Row) (res string, isNu
if isNull || err != nil {
return res, isNull, errors.Trace(err)
}
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, 64), b.tp, b.ctx.GetSessionVars().StmtCtx)
return res, isNull, errors.Trace(err)
res, err = types.ProduceStrWithSpecifiedTp(strconv.FormatFloat(val, 'f', -1, 64), b.tp, b.ctx.GetSessionVars().StmtCtx, false)
if err != nil {
return res, false, err
}
return padZeroForBinaryType(res, b.tp, b.ctx)
}

type builtinCastRealAsTimeSig struct {
Expand Down Expand Up @@ -908,8 +915,11 @@ func (b *builtinCastDecimalAsStringSig) evalString(row types.Row) (res string, i
return res, isNull, errors.Trace(err)
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc)
return res, false, errors.Trace(err)
res, err = types.ProduceStrWithSpecifiedTp(string(val.ToString()), b.tp, sc, false)
if err != nil {
return res, false, err
}
return padZeroForBinaryType(res, b.tp, b.ctx)
}

type builtinCastDecimalAsRealSig struct {
Expand Down Expand Up @@ -1000,8 +1010,11 @@ func (b *builtinCastStringAsStringSig) evalString(row types.Row) (res string, is
return res, isNull, errors.Trace(err)
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, sc)
return res, false, errors.Trace(err)
res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, sc, false)
if err != nil {
return res, false, err
}
return padZeroForBinaryType(res, b.tp, b.ctx)
}

type builtinCastStringAsIntSig struct {
Expand Down Expand Up @@ -1291,8 +1304,11 @@ func (b *builtinCastTimeAsStringSig) evalString(row types.Row) (res string, isNu
return res, isNull, errors.Trace(err)
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc)
return res, false, errors.Trace(err)
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc, false)
if err != nil {
return res, false, err
}
return padZeroForBinaryType(res, b.tp, b.ctx)
}

type builtinCastTimeAsDurationSig struct {
Expand Down Expand Up @@ -1415,8 +1431,30 @@ func (b *builtinCastDurationAsStringSig) evalString(row types.Row) (res string,
return res, isNull, errors.Trace(err)
}
sc := b.ctx.GetSessionVars().StmtCtx
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc)
return res, false, errors.Trace(err)
res, err = types.ProduceStrWithSpecifiedTp(val.String(), b.tp, sc, false)
if err != nil {
return res, false, err
}
return padZeroForBinaryType(res, b.tp, b.ctx)
}

func padZeroForBinaryType(s string, tp *types.FieldType, ctx sessionctx.Context) (string, bool, error) {
flen := tp.Flen
if tp.Tp == mysql.TypeString && types.IsBinaryStr(tp) && len(s) < flen {
sc := ctx.GetSessionVars().StmtCtx
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return "", false, err
}
if uint64(flen) > maxAllowedPacket {
sc.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("cast_as_binary", maxAllowedPacket))
return "", true, nil
}
padding := make([]byte, flen-len(s))
s = string(append([]byte(s), padding...))
}
return s, false, nil
}

type builtinCastDurationAsTimeSig struct {
Expand Down
15 changes: 13 additions & 2 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ func (s *testEvaluatorSuite) TestCast(c *C) {
c.Assert(len(res.GetString()), Equals, 5)
c.Assert(res.GetString(), Equals, string([]byte{'a', 0x00, 0x00, 0x00, 0x00}))

// cast(str as binary(N)), N > len([]byte(str)).
// cast("a" as binary(4294967295))
tp.Flen = 4294967295
f = BuildCastFunction(ctx, &Constant{Value: types.NewDatum("a"), RetType: types.NewFieldType(mysql.TypeString)}, tp)
res, err = f.Eval(chunk.Row{})
c.Assert(err, IsNil)
c.Assert(res.IsNull(), IsTrue)
warnings := sc.GetWarnings()
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err))

origSc := sc
sc.InSelectStmt = true
sc.OverflowAsWarning = true
Expand All @@ -93,8 +104,8 @@ func (s *testEvaluatorSuite) TestCast(c *C) {
c.Assert(err, IsNil)
c.Assert(res.GetUint64() == math.MaxUint64, IsTrue)

warnings := sc.GetWarnings()
lastWarn := warnings[len(warnings)-1]
warnings = sc.GetWarnings()
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn.Err), IsTrue)

f = BuildCastFunction(ctx, &Constant{Value: types.NewDatum("-1"), RetType: types.NewFieldType(mysql.TypeString)}, tp1)
Expand Down
9 changes: 5 additions & 4 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -789,16 +789,17 @@ func (d *Datum) convertToString(sc *stmtctx.StatementContext, target *FieldType)
default:
return invalidConv(d, target.Tp)
}
s, err := ProduceStrWithSpecifiedTp(s, target, sc)
s, err := ProduceStrWithSpecifiedTp(s, target, sc, true)
ret.SetString(s)
if target.Charset == charset.CharsetBin {
ret.k = KindBytes
}
return ret, errors.Trace(err)
}

// ProduceStrWithSpecifiedTp produces a new string according to `flen` and `chs`.
func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementContext) (_ string, err error) {
// ProduceStrWithSpecifiedTp produces a new string according to `flen` and `chs`. Param `padZero` indicates
// whether we should pad `\0` for `binary(flen)` type.
func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementContext, padZero bool) (_ string, err error) {
flen, chs := tp.Flen, tp.Charset
if flen >= 0 {
// Flen is the rune length, not binary length, for UTF8 charset, we need to calculate the
Expand Down Expand Up @@ -827,7 +828,7 @@ func ProduceStrWithSpecifiedTp(s string, tp *FieldType, sc *stmtctx.StatementCon
} else if len(s) > flen {
err = ErrDataTooLong.Gen("Data Too Long, field len %d, data len %d", flen, len(s))
s = truncateStr(s, flen)
} else if tp.Tp == mysql.TypeString && IsBinaryStr(tp) && len(s) < flen {
} else if tp.Tp == mysql.TypeString && IsBinaryStr(tp) && len(s) < flen && padZero {
padding := make([]byte, flen-len(s))
s = string(append([]byte(s), padding...))
}
Expand Down

0 comments on commit a4113cd

Please sign in to comment.