Skip to content

Commit

Permalink
Fix mysqlConn.{affectedRows,insertedIds} growing on each call to Quer…
Browse files Browse the repository at this point in the history
…y(), QueryContext() and Ping().
  • Loading branch information
mherr-google committed Feb 7, 2022
1 parent 241fd91 commit 2e231fa
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 46 deletions.
6 changes: 3 additions & 3 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
case 1:
switch authData[0] {
case cachingSha2PasswordFastAuthSuccess:
if err = mc.readResultOK(); err == nil {
if err = mc.readResultOK(resultUnchanged); err == nil {
return nil // auth successful
}

Expand Down Expand Up @@ -391,7 +391,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
return err
}
}
return mc.readResultOK()
return mc.readResultOK(resultUnchanged)

default:
return ErrMalformPkt
Expand All @@ -416,7 +416,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
if err != nil {
return err
}
return mc.readResultOK()
return mc.readResultOK(resultUnchanged)
}

default:
Expand Down
66 changes: 52 additions & 14 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ import (
type mysqlConn struct {
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows []int64
insertIds []int64
rawConn net.Conn // underlying connection when netConn is TLS connection.
result mysqlResult // managed by clearResult() and handleOkPacket().
cfg *Config
maxAllowedPacket int
maxWriteSize int
Expand All @@ -45,6 +44,43 @@ type mysqlConn struct {
closed atomicBool // set when conn is closed, before closech is closed
}

// To correctly manage mysqlConn.result (updated by handleOkPacket()), we need
// to ensure all callpaths have either:
//
// 1. cleared it using clearResult() before sending the command, or
// 2. don't need to (eg. in call paths which are accumulating resultsets).
//
// handleOkPacket() takes an argument of this type to ensure exhaustively that
// all callpaths manage this state correctly.
type resultState int

const (
// mysqlConn.result was cleared (ie. a new command or query is being run.)
//
// This value is obtained by calling mysqlConn.clearResult().
resultCleared resultState = iota + 1
// mysqlConn.result was unchanged (ie. additional resultsets are being
// fetched, or the fields did not need to be cleared.)
resultUnchanged
)

// clearResult clears the connection's stored affectedRows and insertIds
// fields.
//
// Ref: https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
//
// It returns a resultCleared status, to be passed directly (or
// indirectly) to handleOkPacket().
//
// All call paths ending in handleOkPacket() must either:
//
// 1. call clearResult(), and pass its result to handleOkPacket().
// 2. pass resultUnchanged to handleOkPacket().
func (mc *mysqlConn) clearResult() resultState {
mc.result = mysqlResult{}
return resultCleared
}

// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
var cmdSet strings.Builder
Expand Down Expand Up @@ -124,6 +160,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if !mc.closed.IsSet() {
mc.clearResult()
err = mc.writeCommandPacket(comQuit)
}

Expand Down Expand Up @@ -310,28 +347,25 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
}
query = prepared
}
mc.affectedRows = nil
mc.insertIds = nil

err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: mc.affectedRows,
insertIds: mc.insertIds,
}, err
copied := mc.result
return &copied, err
}
return nil, mc.markBadConn(err)
}

// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
resultCleared := mc.clearResult()
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err)
}

// Read Result
resLen, err := mc.readResultSetHeaderPacket()
resLen, err := mc.readResultSetHeaderPacket(resultCleared)
if err != nil {
return err
}
Expand All @@ -348,14 +382,16 @@ func (mc *mysqlConn) exec(query string) error {
}
}

return mc.discardResults()
return mc.discardResults(resultUnchanged)
}

func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
return mc.query(query, args)
}

func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
resultCleared := mc.clearResult()

if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
Expand All @@ -376,7 +412,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
resLen, err = mc.readResultSetHeaderPacket(resultCleared)
if err == nil {
rows := new(textRows)
rows.mc = mc
Expand Down Expand Up @@ -404,12 +440,13 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
resultCleared := mc.clearResult()
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
}

// Read Result
resLen, err := mc.readResultSetHeaderPacket()
resLen, err := mc.readResultSetHeaderPacket(resultCleared)
if err == nil {
rows := new(textRows)
rows.mc = mc
Expand Down Expand Up @@ -460,11 +497,12 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
}
defer mc.finish()

resultCleared := mc.clearResult()
if err = mc.writeCommandPacket(comPing); err != nil {
return mc.markBadConn(err)
}

return mc.readResultOK()
return mc.readResultOK(resultCleared)
}

// BeginTx implements driver.ConnBeginTx interface
Expand Down
76 changes: 76 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2155,11 +2155,51 @@ func TestRejectReadOnly(t *testing.T) {
}

func TestPing(t *testing.T) {
ctx := context.Background()
runTests(t, dsn, func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
dbt.fail("Ping", "Ping", err)
}
})

runTests(t, dsn, func(dbt *DBTest) {
conn, err := dbt.db.Conn(ctx)
if err != nil {
dbt.fail("db", "Conn", err)
}

// Check that affectedRows and insertIds are cleared after each call.
conn.Raw(func(conn interface{}) error {
c := conn.(*mysqlConn)

// Issue a query that sets affectedRows and insertIds.
q, err := c.Query(`SELECT 1`, nil)
if err != nil {
dbt.fail("Conn", "Query", err)
}
if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want)
}
if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) {
dbt.Fatalf("bad insertIds: got %v, want=%v", got, want)
}
q.Close()

// Verify that Ping() clears both fields.
for i := 0; i < 2; i++ {
if err := c.Ping(ctx); err != nil {
dbt.fail("Pinger", "Ping", err)
}
if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
}
return nil
})
})
}

// See Issue #799
Expand Down Expand Up @@ -2436,6 +2476,42 @@ func TestSkipResults(t *testing.T) {
})
}

func TestQueryMultipleResults(t *testing.T) {
ctx := context.Background()
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
dbt.mustExec(`
CREATE TABLE test (
id INT NOT NULL AUTO_INCREMENT,
value VARCHAR(255),
PRIMARY KEY (id)
)`)
conn, err := dbt.db.Conn(ctx)
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
conn.Raw(func(conn interface{}) error {
qr := conn.(driver.Queryer)

c := conn.(*mysqlConn)

// Demonstrate that repeated queries reset the affectedRows
for i := 0; i < 2; i++ {
_, err := qr.Query(`
INSERT INTO test (value) VALUES ('a'), ('b');
INSERT INTO test (value) VALUES ('c'), ('d'), ('e');
`, nil)
if err != nil {
t.Fatalf("insert statements failed: %v", err)
}
if got, want := c.result.affectedRows, []int64{2, 3}; !reflect.DeepEqual(got, want) {
t.Errorf("bad affectedRows: got %v, want=%v", got, want)
}
}
return nil
})
})
}

func TestPingContext(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
ctx, cancel := context.WithCancel(context.Background())
Expand Down
4 changes: 2 additions & 2 deletions infile.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) {
}
}

func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
func (mc *mysqlConn) handleInFileRequest(name string, resultState resultState) (err error) {
var rdr io.Reader
var data []byte
packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
Expand Down Expand Up @@ -174,7 +174,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {

// read OK packet
if err == nil {
return mc.readResultOK()
return mc.readResultOK(resultState)
}

mc.readPacket()
Expand Down
40 changes: 25 additions & 15 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
switch data[0] {

case iOK:
return nil, "", mc.handleOkPacket(data)
// resultUnchanged, since auth happens before any queries or
// commands have been executed.
return nil, "", mc.handleOkPacket(data, resultUnchanged)

case iAuthMoreData:
return data[1:], "", err
Expand All @@ -520,37 +522,37 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
}

// Returns error if Packet is not an 'Result OK'-Packet
func (mc *mysqlConn) readResultOK() error {
func (mc *mysqlConn) readResultOK(resultState resultState) error {
data, err := mc.readPacket()
if err != nil {
return err
}

if data[0] == iOK {
return mc.handleOkPacket(data)
return mc.handleOkPacket(data, resultState)
}
return mc.handleErrorPacket(data)
}

// Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
func (mc *mysqlConn) readResultSetHeaderPacket(resultState resultState) (int, error) {
// handleOkPacket replaces both values; other cases leave the values unchanged.
mc.affectedRows = append(mc.affectedRows, 0)
mc.insertIds = append(mc.insertIds, 0)
mc.result.affectedRows = append(mc.result.affectedRows, 0)
mc.result.insertIds = append(mc.result.insertIds, 0)

data, err := mc.readPacket()
if err == nil {
switch data[0] {

case iOK:
return 0, mc.handleOkPacket(data)
return 0, mc.handleOkPacket(data, resultState)

case iERR:
return 0, mc.handleErrorPacket(data)

case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
return 0, mc.handleInFileRequest(string(data[1:]), resultState)
}

// column count
Expand Down Expand Up @@ -613,7 +615,11 @@ func readStatus(b []byte) statusFlag {

// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error {
//
// The resultState argument ensures that the caller has either cleared the
// affectedRows and insertIds fields (by calling clearResult()) before
// the call, or intentionally left them unchanged.
func (mc *mysqlConn) handleOkPacket(data []byte, resultState resultState) error {
var n, m int
var affectedRows, insertId uint64

Expand All @@ -627,11 +633,11 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {

// Update for the current statement result (only used by
// readResultSetHeaderPacket).
if len(mc.affectedRows) > 0 {
mc.affectedRows[len(mc.affectedRows)-1] = int64(affectedRows)
if len(mc.result.affectedRows) > 0 {
mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows)
}
if len(mc.insertIds) > 0 {
mc.insertIds[len(mc.insertIds)-1] = int64(insertId)
if len(mc.result.insertIds) > 0 {
mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId)
}

// server_status [2 bytes]
Expand Down Expand Up @@ -1165,9 +1171,13 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {

// For each remaining resultset in the stream, discards its rows and updates
// mc.affectedRows and mc.insertIds.
func (mc *mysqlConn) discardResults() error {
//
// The resultState argument ensures that the caller has either reset the
// affectedRows and insertIds counters before the call by calling
// resetStoredOKPackets(), or intentionally left them unchanged.
func (mc *mysqlConn) discardResults(resultState resultState) error {
for mc.status&statusMoreResultsExists != 0 {
resLen, err := mc.readResultSetHeaderPacket()
resLen, err := mc.readResultSetHeaderPacket(resultState)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 2e231fa

Please sign in to comment.