Skip to content

Commit

Permalink
server: use max_allowed_packet to limit the packet size. (#33651)
Browse files Browse the repository at this point in the history
close #31422
  • Loading branch information
CbcWestwolf committed Apr 18, 2022
1 parent 345b1a8 commit 4d3a3c2
Show file tree
Hide file tree
Showing 14 changed files with 274 additions and 172 deletions.
2 changes: 1 addition & 1 deletion errno/errname.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ var MySQLErrName = map[uint16]*mysql.ErrMessage{
ErrDelayedCantChangeLock: mysql.Message("Delayed insert thread couldn't get requested lock for table %-.192s", nil),
ErrTooManyDelayedThreads: mysql.Message("Too many delayed threads in use", nil),
ErrAbortingConnection: mysql.Message("Aborted connection %d to db: '%-.192s' user: '%-.48s' (%-.64s)", nil),
ErrNetPacketTooLarge: mysql.Message("Got a packet bigger than 'maxAllowedPacket' bytes", nil),
ErrNetPacketTooLarge: mysql.Message("Got a packet bigger than 'max_allowed_packet' bytes", nil),
ErrNetReadErrorFromPipe: mysql.Message("Got a read error from the connection pipe", nil),
ErrNetFcntl: mysql.Message("Got an error from fcntl()", nil),
ErrNetPacketsOutOfOrder: mysql.Message("Got packets out of order", nil),
Expand Down
14 changes: 14 additions & 0 deletions executor/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,20 @@ func TestSetVar(t *testing.T) {
tk.MustQuery("select @@global.tidb_remove_orderby_in_subquery").Check(testkit.Rows("0")) // default value is 0
tk.MustExec("set global tidb_remove_orderby_in_subquery=1")
tk.MustQuery("select @@global.tidb_remove_orderby_in_subquery").Check(testkit.Rows("1"))

// the value of max_allowed_packet should be a multiple of 1024
tk.MustExec("set @@global.max_allowed_packet=16385")
tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect max_allowed_packet value: '16385'"))
result := tk.MustQuery("select @@global.max_allowed_packet;")
result.Check(testkit.Rows("16384"))
tk.MustExec("set @@max_allowed_packet=2047")
tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect max_allowed_packet value: '2047'"))
result = tk.MustQuery("select @@max_allowed_packet;")
result.Check(testkit.Rows("1024"))
tk.MustExec("set @@global.max_allowed_packet=0")
tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect max_allowed_packet value: '0'"))
result = tk.MustQuery("select @@global.max_allowed_packet;")
result.Check(testkit.Rows("1024"))
}

func TestTruncateIncorrectIntSessionVar(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion parser/mysql/errname.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ var MySQLErrName = map[uint16]*ErrMessage{
ErrDelayedCantChangeLock: Message("Delayed insert thread couldn't get requested lock for table %-.192s", nil),
ErrTooManyDelayedThreads: Message("Too many delayed threads in use", nil),
ErrAbortingConnection: Message("Aborted connection %d to db: '%-.192s' user: '%-.48s' (%-.64s)", nil),
ErrNetPacketTooLarge: Message("Got a packet bigger than 'maxAllowedPacket' bytes", nil),
ErrNetPacketTooLarge: Message("Got a packet bigger than 'max_allowed_packet' bytes", nil),
ErrNetReadErrorFromPipe: Message("Got a read error from the connection pipe", nil),
ErrNetFcntl: Message("Got an error from fcntl()", nil),
ErrNetPacketsOutOfOrder: Message("Got packets out of order", nil),
Expand Down
8 changes: 8 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,9 @@ func (cc *clientConn) writeInitialHandshake(ctx context.Context) error {
}

func (cc *clientConn) readPacket() ([]byte, error) {
if cc.ctx != nil {
cc.pkt.setMaxAllowedPacket(cc.ctx.GetSessionVars().MaxAllowedPacket)
}
return cc.pkt.readPacket()
}

Expand Down Expand Up @@ -1075,6 +1078,11 @@ func (cc *clientConn) Run(ctx context.Context) {
zap.Error(err),
)
}
} else if errors.ErrorEqual(err, errNetPacketTooLarge) {
err := cc.writeError(ctx, err)
if err != nil {
terror.Log(err)
}
} else {
errStack := errors.ErrorStack(err)
if !strings.Contains(errStack, "use of closed network connection") {
Expand Down
29 changes: 29 additions & 0 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/binary"
"fmt"
"io"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -1263,3 +1264,31 @@ func TestAuthPlugin2(t *testing.T) {
require.NoError(t, err)

}

func TestMaxAllowedPacket(t *testing.T) {
// Test cases from issue 31422: https://github.com/pingcap/tidb/issues/31422
// The string "SELECT length('') as len;" has 25 chars,
// so if the string inside '' has a length of 999, the total query reaches the max allowed packet size.

const maxAllowedPacket = 1024
var inBuffer bytes.Buffer
bytes := append([]byte{0x00, 0x04, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 999)))...)
_, err := inBuffer.Write(bytes)
require.NoError(t, err)
brc := newBufferedReadConn(&bytesConn{inBuffer})
pkt := newPacketIO(brc)
pkt.setMaxAllowedPacket(maxAllowedPacket)
_, err = pkt.readPacket()
require.NoError(t, err)
require.Equal(t, uint8(1), pkt.sequence)

inBuffer.Reset()
bytes = append([]byte{0x01, 0x04, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 1000)))...)
_, err = inBuffer.Write(bytes)
require.NoError(t, err)
brc = newBufferedReadConn(&bytesConn{inBuffer})
pkt = newPacketIO(brc)
pkt.setMaxAllowedPacket(maxAllowedPacket)
_, err = pkt.readPacket()
require.Error(t, err)
}
17 changes: 17 additions & 0 deletions server/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx/variable"
)

const defaultWriterSize = 16 * 1024
Expand All @@ -54,16 +55,22 @@ var (
)

// packetIO is a helper to read and write data in packet format.
// MySQL Packets: https://dev.mysql.com/doc/internals/en/mysql-packet.html
type packetIO struct {
bufReadConn *bufferedReadConn
bufWriter *bufio.Writer
sequence uint8
readTimeout time.Duration
// maxAllowedPacket is the maximum size of one packet in readPacket.
maxAllowedPacket uint64
// accumulatedLength count the length of totally received 'payload' in readPacket.
accumulatedLength uint64
}

func newPacketIO(bufReadConn *bufferedReadConn) *packetIO {
p := &packetIO{sequence: 0}
p.setBufferedReadConn(bufReadConn)
p.setMaxAllowedPacket(variable.DefMaxAllowedPacket)
return p
}

Expand Down Expand Up @@ -96,6 +103,12 @@ func (p *packetIO) readOnePacket() ([]byte, error) {

length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)

// Accumulated payload length exceeds the limit.
if p.accumulatedLength += uint64(length); p.accumulatedLength > p.maxAllowedPacket {
terror.Log(errNetPacketTooLarge)
return nil, errNetPacketTooLarge
}

data := make([]byte, length)
if p.readTimeout > 0 {
if err := p.bufReadConn.SetReadDeadline(time.Now().Add(p.readTimeout)); err != nil {
Expand All @@ -108,6 +121,10 @@ func (p *packetIO) readOnePacket() ([]byte, error) {
return data, nil
}

func (p *packetIO) setMaxAllowedPacket(maxAllowedPacket uint64) {
p.maxAllowedPacket = maxAllowedPacket
}

func (p *packetIO) readPacket() ([]byte, error) {
if p.readTimeout == 0 {
if err := p.bufReadConn.SetReadDeadline(time.Time{}); err != nil {
Expand Down
10 changes: 5 additions & 5 deletions server/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ func TestPacketIORead(t *testing.T) {
// Test read one packet
brc := newBufferedReadConn(&bytesConn{inBuffer})
pkt := newPacketIO(brc)
bytes, err := pkt.readPacket()
readBytes, err := pkt.readPacket()
require.NoError(t, err)
require.Equal(t, uint8(1), pkt.sequence)
require.Equal(t, []byte{0x01}, bytes)
require.Equal(t, []byte{0x01}, readBytes)

inBuffer.Reset()
buf := make([]byte, mysql.MaxPayloadLen+9)
Expand All @@ -79,11 +79,11 @@ func TestPacketIORead(t *testing.T) {
// Test read multiple packets
brc = newBufferedReadConn(&bytesConn{inBuffer})
pkt = newPacketIO(brc)
bytes, err = pkt.readPacket()
readBytes, err = pkt.readPacket()
require.NoError(t, err)
require.Equal(t, uint8(2), pkt.sequence)
require.Equal(t, mysql.MaxPayloadLen+1, len(bytes))
require.Equal(t, byte(0x0a), bytes[mysql.MaxPayloadLen])
require.Equal(t, mysql.MaxPayloadLen+1, len(readBytes))
require.Equal(t, byte(0x0a), readBytes[mysql.MaxPayloadLen])
}

type bytesConn struct {
Expand Down
1 change: 1 addition & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ var (
errMultiStatementDisabled = dbterror.ClassServer.NewStd(errno.ErrMultiStatementDisabled)
errNewAbortingConnection = dbterror.ClassServer.NewStd(errno.ErrNewAbortingConnection)
errNotSupportedAuthMode = dbterror.ClassServer.NewStd(errno.ErrNotSupportedAuthMode)
errNetPacketTooLarge = dbterror.ClassServer.NewStd(errno.ErrNetPacketTooLarge)
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand Down
2 changes: 1 addition & 1 deletion session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,7 @@ func createSessionFunc(store kv.Storage) pools.Factory {
if err != nil {
return nil, errors.Trace(err)
}
err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxAllowedPacket, "67108864")
err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxAllowedPacket, strconv.FormatUint(variable.DefMaxAllowedPacket, 10))
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
14 changes: 9 additions & 5 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,11 @@ func testTxnLazyInitialize(s *testSessionSuite, c *C, isPessimistic bool) {

func (s *testSessionSuite) TestGlobalVarAccessor(c *C) {
varName := "max_allowed_packet"
varValue := "67108864" // This is the default value for max_allowed_packet
varValue := strconv.FormatUint(variable.DefMaxAllowedPacket, 10) // This is the default value for max_allowed_packet

// The value of max_allowed_packet should be a multiple of 1024,
// so the setting of varValue1 and varValue2 would be truncated to varValue0
varValue0 := "4194304"
varValue1 := "4194305"
varValue2 := "4194306"

Expand All @@ -660,25 +664,25 @@ func (s *testSessionSuite) TestGlobalVarAccessor(c *C) {
c.Assert(err, IsNil)
v, err = se.GetGlobalSysVar(varName)
c.Assert(err, IsNil)
c.Assert(v, Equals, varValue1)
c.Assert(v, Equals, varValue0)
c.Assert(tk.Se.CommitTxn(context.TODO()), IsNil)

tk1 := testkit.NewTestKitWithInit(c, s.store)
se1 := tk1.Se.(variable.GlobalVarAccessor)
v, err = se1.GetGlobalSysVar(varName)
c.Assert(err, IsNil)
c.Assert(v, Equals, varValue1)
c.Assert(v, Equals, varValue0)
err = se1.SetGlobalSysVar(varName, varValue2)
c.Assert(err, IsNil)
v, err = se1.GetGlobalSysVar(varName)
c.Assert(err, IsNil)
c.Assert(v, Equals, varValue2)
c.Assert(v, Equals, varValue0)
c.Assert(tk1.Se.CommitTxn(context.TODO()), IsNil)

// Make sure the change is visible to any client that accesses that global variable.
v, err = se.GetGlobalSysVar(varName)
c.Assert(err, IsNil)
c.Assert(v, Equals, varValue2)
c.Assert(v, Equals, varValue0)

// For issue 10955, make sure the new session load `max_execution_time` into sessionVars.
tk1.MustExec("set @@global.max_execution_time = 100")
Expand Down
4 changes: 4 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,9 @@ type SessionVars struct {
RcReadCheckTS bool
// RemoveOrderbyInSubquery indicates whether to remove ORDER BY in subquery.
RemoveOrderbyInSubquery bool

// MaxAllowedPacket indicates the maximum size of a packet for the MySQL protocol.
MaxAllowedPacket uint64
}

// InitStatementContext initializes a StatementContext, the object is reused to reduce allocation.
Expand Down Expand Up @@ -1262,6 +1265,7 @@ func NewSessionVars() *SessionVars {
StatsLoadSyncWait: StatsLoadSyncWait.Load(),
EnableLegacyInstanceScope: DefEnableLegacyInstanceScope,
RemoveOrderbyInSubquery: DefTiDBRemoveOrderbyInSubquery,
MaxAllowedPacket: DefMaxAllowedPacket,
}
vars.KVVars = tikvstore.NewVariables(&vars.Killed)
vars.Concurrency = Concurrency{
Expand Down
27 changes: 26 additions & 1 deletion sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,32 @@ var defaultSysVars = []*SysVar{
}
return nil
}},
{Scope: ScopeGlobal | ScopeSession, Name: MaxAllowedPacket, Value: "67108864", Type: TypeUnsigned, MinValue: 1024, MaxValue: MaxOfMaxAllowedPacket},
{Scope: ScopeGlobal | ScopeSession, Name: MaxAllowedPacket, Value: strconv.FormatUint(DefMaxAllowedPacket, 10), Type: TypeUnsigned, MinValue: 1024, MaxValue: MaxOfMaxAllowedPacket,
Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) {
// Truncate the value of max_allowed_packet to be a multiple of 1024,
// nonmultiples are rounded down to the nearest multiple.
u, err := strconv.ParseUint(normalizedValue, 10, 64)
if err != nil {
return normalizedValue, err
}
remainder := u % 1024
if remainder != 0 {
vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenWithStackByArgs(MaxAllowedPacket, normalizedValue))
u -= remainder
}
return strconv.FormatUint(u, 10), nil
},
GetSession: func(s *SessionVars) (string, error) {
return strconv.FormatUint(s.MaxAllowedPacket, 10), nil
},
SetSession: func(s *SessionVars, val string) error {
var err error
if s.MaxAllowedPacket, err = strconv.ParseUint(val, 10, 64); err != nil {
return err
}
return nil
},
},
{Scope: ScopeGlobal | ScopeSession, Name: WindowingUseHighPrecision, Value: On, Type: TypeBool, IsHintUpdatable: true, SetSession: func(s *SessionVars, val string) error {
s.WindowingUseHighPrecision = TiDBOptOn(val)
return nil
Expand Down
Loading

0 comments on commit 4d3a3c2

Please sign in to comment.