Skip to content

Commit

Permalink
fewer driver.ErrBadConn to prevent repeated queries (#302)
Browse files Browse the repository at this point in the history
According to the database/sql/driver documentation, ErrBadConn should only
be used when the database was not affected. The driver restarts the same
query on a different connection, then.
The mysql driver did not follow this advice, so queries were repeated if
ErrBadConn is returned but a query succeeded.

This is fixed by changing most ErrBadConn errors to ErrInvalidConn.

The only valid returns of ErrBadConn are at the beginning of a database
interaction when no data was sent to the database yet.

Those valid cases are located the following funcs before attempting to write
to the network or if 0 bytes were written:

* Begin
* BeginTx
* Exec
* ExecContext
* Prepare
* PrepareContext
* Query
* QueryContext

Commit and Rollback could arguably also be on that list, but are left out as
some engines like MyISAM are not supporting transactions.

Tests in b/packets_test.go were changed because they simulate a read not
preceded by a write to the db. This cannot happen as the client has to send
the query first.
  • Loading branch information
arnehormann authored Aug 22, 2017
1 parent 21d7e97 commit 26471af
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 28 deletions.
23 changes: 16 additions & 7 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ func (mc *mysqlConn) handleParams() (err error) {
return
}

func (mc *mysqlConn) markBadConn(err error) error {
if mc == nil {
return err
}
if err != errBadConnNoWrite {
return err
}
return driver.ErrBadConn
}

func (mc *mysqlConn) Begin() (driver.Tx, error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
Expand All @@ -90,8 +100,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
if err == nil {
return &mysqlTx{mc}, err
}

return nil, err
return nil, mc.markBadConn(err)
}

func (mc *mysqlConn) Close() (err error) {
Expand Down Expand Up @@ -142,7 +151,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
// Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil {
return nil, err
return nil, mc.markBadConn(err)
}

stmt := &mysqlStmt{
Expand Down Expand Up @@ -176,7 +185,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
if buf == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return "", driver.ErrBadConn
return "", ErrInvalidConn
}
buf = buf[:0]
argPos := 0
Expand Down Expand Up @@ -314,14 +323,14 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
insertId: int64(mc.insertId),
}, err
}
return nil, err
return nil, mc.markBadConn(err)
}

// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return err
return mc.markBadConn(err)
}

// Read Result
Expand Down Expand Up @@ -390,7 +399,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
return rows, err
}
}
return nil, err
return nil, mc.markBadConn(err)
}

// Gets the value of the given MySQL System Variable
Expand Down
6 changes: 6 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ var (
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
ErrBusyBuffer = errors.New("busy buffer")

// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
// to trigger a resend.
// See https://github.com/go-sql-driver/mysql/pull/302
errBadConnNoWrite = errors.New("bad connection")
)

var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
Expand Down
28 changes: 16 additions & 12 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
}
errLog.Print(err)
mc.Close()
return nil, driver.ErrBadConn
return nil, ErrInvalidConn
}

// packet length [24 bit]
Expand All @@ -57,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
if prevData == nil {
errLog.Print(ErrMalformPkt)
mc.Close()
return nil, driver.ErrBadConn
return nil, ErrInvalidConn
}

return prevData, nil
Expand All @@ -71,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
}
errLog.Print(err)
mc.Close()
return nil, driver.ErrBadConn
return nil, ErrInvalidConn
}

// return data if this was the last packet
Expand Down Expand Up @@ -137,10 +137,14 @@ func (mc *mysqlConn) writePacket(data []byte) error {
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
if n == 0 && pktLen == len(data)-4 {
// only for the first loop iteration when nothing was written yet
return errBadConnNoWrite
}
mc.cleanup()
errLog.Print(err)
}
return driver.ErrBadConn
return ErrInvalidConn
}
}

Expand Down Expand Up @@ -274,7 +278,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
return errBadConnNoWrite
}

// ClientFlags [32 bit]
Expand Down Expand Up @@ -362,7 +366,7 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
return errBadConnNoWrite
}

// Add the scrambled password [null terminated string]
Expand All @@ -381,7 +385,7 @@ func (mc *mysqlConn) writeClearAuthPacket() error {
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
return errBadConnNoWrite
}

// Add the clear password [null terminated string]
Expand All @@ -404,7 +408,7 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
return errBadConnNoWrite
}

// Add the scramble
Expand All @@ -425,7 +429,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
return errBadConnNoWrite
}

// Add command byte
Expand All @@ -444,7 +448,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
return errBadConnNoWrite
}

// Add command byte
Expand All @@ -465,7 +469,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
return errBadConnNoWrite
}

// Add command byte
Expand Down Expand Up @@ -931,7 +935,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
return errBadConnNoWrite
}

// command [1 byte]
Expand Down
13 changes: 6 additions & 7 deletions packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
package mysql

import (
"database/sql/driver"
"errors"
"net"
"testing"
Expand Down Expand Up @@ -252,8 +251,8 @@ func TestReadPacketFail(t *testing.T) {
conn.data = []byte{0x00, 0x00, 0x00, 0x00}
conn.maxReads = 1
_, err := mc.readPacket()
if err != driver.ErrBadConn {
t.Errorf("expected ErrBadConn, got %v", err)
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %v", err)
}

// reset
Expand All @@ -264,8 +263,8 @@ func TestReadPacketFail(t *testing.T) {
// fail to read header
conn.closed = true
_, err = mc.readPacket()
if err != driver.ErrBadConn {
t.Errorf("expected ErrBadConn, got %v", err)
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %v", err)
}

// reset
Expand All @@ -277,7 +276,7 @@ func TestReadPacketFail(t *testing.T) {
// fail to read body
conn.maxReads = 1
_, err = mc.readPacket()
if err != driver.ErrBadConn {
t.Errorf("expected ErrBadConn, got %v", err)
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %v", err)
}
}
4 changes: 2 additions & 2 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, err
return nil, stmt.mc.markBadConn(err)
}

mc := stmt.mc
Expand Down Expand Up @@ -100,7 +100,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, err
return nil, stmt.mc.markBadConn(err)
}

mc := stmt.mc
Expand Down

0 comments on commit 26471af

Please sign in to comment.