Skip to content

Commit

Permalink
expression: handle max_allowed_packet warnings for repeat function. (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
hhu-cc authored and ngaut committed Aug 22, 2018
1 parent 5c1db1e commit fc14119
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 2 deletions.
17 changes: 15 additions & 2 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,17 +554,24 @@ func (c *repeatFunctionClass) getFunction(ctx sessionctx.Context, args []Express
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString, types.ETInt)
bf.tp.Flen = mysql.MaxBlobWidth
SetBinFlagOrBinStr(args[0].GetType(), bf.tp)
sig := &builtinRepeatSig{bf}
valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}
sig := &builtinRepeatSig{bf, maxAllowedPacket}
return sig, nil
}

type builtinRepeatSig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinRepeatSig) Clone() builtinFunc {
newSig := &builtinRepeatSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand All @@ -575,6 +582,7 @@ func (b *builtinRepeatSig) evalString(row types.Row) (d string, isNull bool, err
if isNull || err != nil {
return "", isNull, errors.Trace(err)
}
byteLength := len(str)

num, isNull, err := b.args[1].EvalInt(b.ctx, row)
if isNull || err != nil {
Expand All @@ -587,7 +595,12 @@ func (b *builtinRepeatSig) evalString(row types.Row) (d string, isNull bool, err
num = math.MaxInt32
}

if int64(len(str)) > int64(b.tp.Flen)/num {
if uint64(byteLength)*uint64(num) > b.maxAllowedPacket {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("repeat", b.maxAllowedPacket))
return "", true, nil
}

if int64(byteLength) > int64(b.tp.Flen)/num {
return "", true, nil
}
return strings.Repeat(str, int(num)), false, nil
Expand Down
45 changes: 45 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,51 @@ func (s *testEvaluatorSuite) TestRepeat(c *C) {
c.Assert(v.GetString(), Equals, "")
}

func (s *testEvaluatorSuite) TestRepeatSig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
{Tp: mysql.TypeLonglong},
}
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000}
args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
&Column{Index: 1, RetType: colTypes[1]},
}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
repeat := &builtinRepeatSig{base, 1000}

cases := []struct {
args []interface{}
warning int
res string
}{
{[]interface{}{"a", int64(6)}, 0, "aaaaaa"},
{[]interface{}{"a", int64(10001)}, 1, ""},
{[]interface{}{"毅", int64(6)}, 0, "毅毅毅毅毅毅"},
{[]interface{}{"毅", int64(334)}, 2, ""},
}

for _, t := range cases {
input := chunk.NewChunkWithCapacity(colTypes, 10)
input.AppendString(0, t.args[0].(string))
input.AppendInt64(1, t.args[1].(int64))

res, isNull, err := repeat.evalString(input.GetRow(0))
c.Assert(res, Equals, t.res)
c.Assert(err, IsNil)
if t.warning == 0 {
c.Assert(isNull, IsFalse)
} else {
c.Assert(isNull, IsTrue)
c.Assert(err, IsNil)
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, t.warning)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
}
}
}

func (s *testEvaluatorSuite) TestLower(c *C) {
defer testleak.AfterTest(c)()
cases := []struct {
Expand Down

0 comments on commit fc14119

Please sign in to comment.