Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: use max_allowed_packet to limit the packet size. (#33651) #34059

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion errno/errname.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ var MySQLErrName = map[uint16]*mysql.ErrMessage{
ErrDelayedCantChangeLock: mysql.Message("Delayed insert thread couldn't get requested lock for table %-.192s", nil),
ErrTooManyDelayedThreads: mysql.Message("Too many delayed threads in use", nil),
ErrAbortingConnection: mysql.Message("Aborted connection %d to db: '%-.192s' user: '%-.48s' (%-.64s)", nil),
ErrNetPacketTooLarge: mysql.Message("Got a packet bigger than 'maxAllowedPacket' bytes", nil),
ErrNetPacketTooLarge: mysql.Message("Got a packet bigger than 'max_allowed_packet' bytes", nil),
ErrNetReadErrorFromPipe: mysql.Message("Got a read error from the connection pipe", nil),
ErrNetFcntl: mysql.Message("Got an error from fcntl()", nil),
ErrNetPacketsOutOfOrder: mysql.Message("Got packets out of order", nil),
Expand Down
25 changes: 25 additions & 0 deletions executor/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,31 @@ func TestSetVar(t *testing.T) {
tk.MustExec("set global tidb_ignore_prepared_cache_close_stmt=0")
tk.MustQuery("select @@global.tidb_ignore_prepared_cache_close_stmt").Check(testkit.Rows("0"))
tk.MustQuery("show global variables like 'tidb_ignore_prepared_cache_close_stmt'").Check(testkit.Rows("tidb_ignore_prepared_cache_close_stmt OFF"))
<<<<<<< HEAD
=======

// test for tidb_remove_orderby_in_subquery
tk.MustQuery("select @@session.tidb_remove_orderby_in_subquery").Check(testkit.Rows("0")) // default value is 0
tk.MustExec("set session tidb_remove_orderby_in_subquery=1")
tk.MustQuery("select @@session.tidb_remove_orderby_in_subquery").Check(testkit.Rows("1"))
tk.MustQuery("select @@global.tidb_remove_orderby_in_subquery").Check(testkit.Rows("0")) // default value is 0
tk.MustExec("set global tidb_remove_orderby_in_subquery=1")
tk.MustQuery("select @@global.tidb_remove_orderby_in_subquery").Check(testkit.Rows("1"))

// the value of max_allowed_packet should be a multiple of 1024
tk.MustExec("set @@global.max_allowed_packet=16385")
tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect max_allowed_packet value: '16385'"))
result := tk.MustQuery("select @@global.max_allowed_packet;")
result.Check(testkit.Rows("16384"))
tk.MustExec("set @@max_allowed_packet=2047")
tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect max_allowed_packet value: '2047'"))
result = tk.MustQuery("select @@max_allowed_packet;")
result.Check(testkit.Rows("1024"))
tk.MustExec("set @@global.max_allowed_packet=0")
tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect max_allowed_packet value: '0'"))
result = tk.MustQuery("select @@global.max_allowed_packet;")
result.Check(testkit.Rows("1024"))
>>>>>>> 4d3a3c259... server: use max_allowed_packet to limit the packet size. (#33651)
}

func TestTruncateIncorrectIntSessionVar(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion parser/mysql/errname.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ var MySQLErrName = map[uint16]*ErrMessage{
ErrDelayedCantChangeLock: Message("Delayed insert thread couldn't get requested lock for table %-.192s", nil),
ErrTooManyDelayedThreads: Message("Too many delayed threads in use", nil),
ErrAbortingConnection: Message("Aborted connection %d to db: '%-.192s' user: '%-.48s' (%-.64s)", nil),
ErrNetPacketTooLarge: Message("Got a packet bigger than 'maxAllowedPacket' bytes", nil),
ErrNetPacketTooLarge: Message("Got a packet bigger than 'max_allowed_packet' bytes", nil),
ErrNetReadErrorFromPipe: Message("Got a read error from the connection pipe", nil),
ErrNetFcntl: Message("Got an error from fcntl()", nil),
ErrNetPacketsOutOfOrder: Message("Got packets out of order", nil),
Expand Down
8 changes: 8 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,9 @@ func (cc *clientConn) writeInitialHandshake(ctx context.Context) error {
}

func (cc *clientConn) readPacket() ([]byte, error) {
if cc.ctx != nil {
cc.pkt.setMaxAllowedPacket(cc.ctx.GetSessionVars().MaxAllowedPacket)
}
return cc.pkt.readPacket()
}

Expand Down Expand Up @@ -1075,6 +1078,11 @@ func (cc *clientConn) Run(ctx context.Context) {
zap.Error(err),
)
}
} else if errors.ErrorEqual(err, errNetPacketTooLarge) {
err := cc.writeError(ctx, err)
if err != nil {
terror.Log(err)
}
} else {
errStack := errors.ErrorStack(err)
if !strings.Contains(errStack, "use of closed network connection") {
Expand Down
29 changes: 29 additions & 0 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/binary"
"fmt"
"io"
"strings"
"testing"

"github.com/pingcap/failpoint"
Expand Down Expand Up @@ -1133,3 +1134,31 @@ func TestAuthPlugin2(t *testing.T) {
require.NoError(t, err)

}

func TestMaxAllowedPacket(t *testing.T) {
// Test cases from issue 31422: https://github.com/pingcap/tidb/issues/31422
// The string "SELECT length('') as len;" has 25 chars,
// so if the string inside '' has a length of 999, the total query reaches the max allowed packet size.

const maxAllowedPacket = 1024
var inBuffer bytes.Buffer
bytes := append([]byte{0x00, 0x04, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 999)))...)
_, err := inBuffer.Write(bytes)
require.NoError(t, err)
brc := newBufferedReadConn(&bytesConn{inBuffer})
pkt := newPacketIO(brc)
pkt.setMaxAllowedPacket(maxAllowedPacket)
_, err = pkt.readPacket()
require.NoError(t, err)
require.Equal(t, uint8(1), pkt.sequence)

inBuffer.Reset()
bytes = append([]byte{0x01, 0x04, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 1000)))...)
_, err = inBuffer.Write(bytes)
require.NoError(t, err)
brc = newBufferedReadConn(&bytesConn{inBuffer})
pkt = newPacketIO(brc)
pkt.setMaxAllowedPacket(maxAllowedPacket)
_, err = pkt.readPacket()
require.Error(t, err)
}
17 changes: 17 additions & 0 deletions server/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx/variable"
)

const defaultWriterSize = 16 * 1024
Expand All @@ -54,16 +55,22 @@ var (
)

// packetIO is a helper to read and write data in packet format.
// MySQL Packets: https://dev.mysql.com/doc/internals/en/mysql-packet.html
type packetIO struct {
bufReadConn *bufferedReadConn
bufWriter *bufio.Writer
sequence uint8
readTimeout time.Duration
// maxAllowedPacket is the maximum size of one packet in readPacket.
maxAllowedPacket uint64
// accumulatedLength count the length of totally received 'payload' in readPacket.
accumulatedLength uint64
}

func newPacketIO(bufReadConn *bufferedReadConn) *packetIO {
p := &packetIO{sequence: 0}
p.setBufferedReadConn(bufReadConn)
p.setMaxAllowedPacket(variable.DefMaxAllowedPacket)
return p
}

Expand Down Expand Up @@ -96,6 +103,12 @@ func (p *packetIO) readOnePacket() ([]byte, error) {

length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)

// Accumulated payload length exceeds the limit.
if p.accumulatedLength += uint64(length); p.accumulatedLength > p.maxAllowedPacket {
terror.Log(errNetPacketTooLarge)
return nil, errNetPacketTooLarge
}

data := make([]byte, length)
if p.readTimeout > 0 {
if err := p.bufReadConn.SetReadDeadline(time.Now().Add(p.readTimeout)); err != nil {
Expand All @@ -108,6 +121,10 @@ func (p *packetIO) readOnePacket() ([]byte, error) {
return data, nil
}

func (p *packetIO) setMaxAllowedPacket(maxAllowedPacket uint64) {
p.maxAllowedPacket = maxAllowedPacket
}

func (p *packetIO) readPacket() ([]byte, error) {
if p.readTimeout == 0 {
if err := p.bufReadConn.SetReadDeadline(time.Time{}); err != nil {
Expand Down
10 changes: 5 additions & 5 deletions server/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ func TestPacketIORead(t *testing.T) {
// Test read one packet
brc := newBufferedReadConn(&bytesConn{inBuffer})
pkt := newPacketIO(brc)
bytes, err := pkt.readPacket()
readBytes, err := pkt.readPacket()
require.NoError(t, err)
require.Equal(t, uint8(1), pkt.sequence)
require.Equal(t, []byte{0x01}, bytes)
require.Equal(t, []byte{0x01}, readBytes)

inBuffer.Reset()
buf := make([]byte, mysql.MaxPayloadLen+9)
Expand All @@ -79,11 +79,11 @@ func TestPacketIORead(t *testing.T) {
// Test read multiple packets
brc = newBufferedReadConn(&bytesConn{inBuffer})
pkt = newPacketIO(brc)
bytes, err = pkt.readPacket()
readBytes, err = pkt.readPacket()
require.NoError(t, err)
require.Equal(t, uint8(2), pkt.sequence)
require.Equal(t, mysql.MaxPayloadLen+1, len(bytes))
require.Equal(t, byte(0x0a), bytes[mysql.MaxPayloadLen])
require.Equal(t, mysql.MaxPayloadLen+1, len(readBytes))
require.Equal(t, byte(0x0a), readBytes[mysql.MaxPayloadLen])
}

type bytesConn struct {
Expand Down
1 change: 1 addition & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ var (
errMultiStatementDisabled = dbterror.ClassServer.NewStd(errno.ErrMultiStatementDisabled)
errNewAbortingConnection = dbterror.ClassServer.NewStd(errno.ErrNewAbortingConnection)
errNotSupportedAuthMode = dbterror.ClassServer.NewStd(errno.ErrNotSupportedAuthMode)
errNetPacketTooLarge = dbterror.ClassServer.NewStd(errno.ErrNetPacketTooLarge)
)

// DefaultCapability is the capability of the server when it is created using the default configuration.
Expand Down
2 changes: 1 addition & 1 deletion session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1211,7 +1211,7 @@ func createSessionFunc(store kv.Storage) pools.Factory {
if err != nil {
return nil, errors.Trace(err)
}
err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxAllowedPacket, "67108864")
err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxAllowedPacket, strconv.FormatUint(variable.DefMaxAllowedPacket, 10))
if err != nil {
return nil, errors.Trace(err)
}
Expand Down
14 changes: 9 additions & 5 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,11 @@ func testTxnLazyInitialize(s *testSessionSuite, c *C, isPessimistic bool) {

func (s *testSessionSuite) TestGlobalVarAccessor(c *C) {
varName := "max_allowed_packet"
varValue := "67108864" // This is the default value for max_allowed_packet
varValue := strconv.FormatUint(variable.DefMaxAllowedPacket, 10) // This is the default value for max_allowed_packet

// The value of max_allowed_packet should be a multiple of 1024,
// so the setting of varValue1 and varValue2 would be truncated to varValue0
varValue0 := "4194304"
varValue1 := "4194305"
varValue2 := "4194306"

Expand All @@ -661,25 +665,25 @@ func (s *testSessionSuite) TestGlobalVarAccessor(c *C) {
c.Assert(err, IsNil)
v, err = se.GetGlobalSysVar(varName)
c.Assert(err, IsNil)
c.Assert(v, Equals, varValue1)
c.Assert(v, Equals, varValue0)
c.Assert(tk.Se.CommitTxn(context.TODO()), IsNil)

tk1 := testkit.NewTestKitWithInit(c, s.store)
se1 := tk1.Se.(variable.GlobalVarAccessor)
v, err = se1.GetGlobalSysVar(varName)
c.Assert(err, IsNil)
c.Assert(v, Equals, varValue1)
c.Assert(v, Equals, varValue0)
err = se1.SetGlobalSysVar(varName, varValue2)
c.Assert(err, IsNil)
v, err = se1.GetGlobalSysVar(varName)
c.Assert(err, IsNil)
c.Assert(v, Equals, varValue2)
c.Assert(v, Equals, varValue0)
c.Assert(tk1.Se.CommitTxn(context.TODO()), IsNil)

// Make sure the change is visible to any client that accesses that global variable.
v, err = se.GetGlobalSysVar(varName)
c.Assert(err, IsNil)
c.Assert(v, Equals, varValue2)
c.Assert(v, Equals, varValue0)

// For issue 10955, make sure the new session load `max_execution_time` into sessionVars.
tk1.MustExec("set @@global.max_execution_time = 100")
Expand Down
13 changes: 13 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,14 @@ type SessionVars struct {
BatchPendingTiFlashCount int
// RcReadCheckTS indicates if ts check optimization is enabled for current session.
RcReadCheckTS bool
<<<<<<< HEAD
=======
// RemoveOrderbyInSubquery indicates whether to remove ORDER BY in subquery.
RemoveOrderbyInSubquery bool

// MaxAllowedPacket indicates the maximum size of a packet for the MySQL protocol.
MaxAllowedPacket uint64
>>>>>>> 4d3a3c259... server: use max_allowed_packet to limit the packet size. (#33651)
}

// InitStatementContext initializes a StatementContext, the object is reused to reduce allocation.
Expand Down Expand Up @@ -1259,6 +1267,11 @@ func NewSessionVars() *SessionVars {
Rng: utilMath.NewWithTime(),
StatsLoadSyncWait: StatsLoadSyncWait.Load(),
EnableLegacyInstanceScope: DefEnableLegacyInstanceScope,
<<<<<<< HEAD
=======
RemoveOrderbyInSubquery: DefTiDBRemoveOrderbyInSubquery,
MaxAllowedPacket: DefMaxAllowedPacket,
>>>>>>> 4d3a3c259... server: use max_allowed_packet to limit the packet size. (#33651)
}
vars.KVVars = tikvstore.NewVariables(&vars.Killed)
vars.Concurrency = Concurrency{
Expand Down
27 changes: 26 additions & 1 deletion sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,32 @@ var defaultSysVars = []*SysVar{
}
return nil
}},
{Scope: ScopeGlobal | ScopeSession, Name: MaxAllowedPacket, Value: "67108864", Type: TypeUnsigned, MinValue: 1024, MaxValue: MaxOfMaxAllowedPacket},
{Scope: ScopeGlobal | ScopeSession, Name: MaxAllowedPacket, Value: strconv.FormatUint(DefMaxAllowedPacket, 10), Type: TypeUnsigned, MinValue: 1024, MaxValue: MaxOfMaxAllowedPacket,
Validation: func(vars *SessionVars, normalizedValue string, originalValue string, scope ScopeFlag) (string, error) {
// Truncate the value of max_allowed_packet to be a multiple of 1024,
// nonmultiples are rounded down to the nearest multiple.
u, err := strconv.ParseUint(normalizedValue, 10, 64)
if err != nil {
return normalizedValue, err
}
remainder := u % 1024
if remainder != 0 {
vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenWithStackByArgs(MaxAllowedPacket, normalizedValue))
u -= remainder
}
return strconv.FormatUint(u, 10), nil
},
GetSession: func(s *SessionVars) (string, error) {
return strconv.FormatUint(s.MaxAllowedPacket, 10), nil
},
SetSession: func(s *SessionVars, val string) error {
var err error
if s.MaxAllowedPacket, err = strconv.ParseUint(val, 10, 64); err != nil {
return err
}
return nil
},
},
{Scope: ScopeGlobal | ScopeSession, Name: WindowingUseHighPrecision, Value: On, Type: TypeBool, IsHintUpdatable: true, SetSession: func(s *SessionVars, val string) error {
s.WindowingUseHighPrecision = TiDBOptOn(val)
return nil
Expand Down
Loading