Skip to content

Commit

Permalink
conn(dm): add timeout for Conn (#6055) (#6091)
Browse files Browse the repository at this point in the history
close #3733
  • Loading branch information
ti-chi-bot authored Jul 5, 2022
1 parent 781562b commit 01a9b2e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 42 deletions.
21 changes: 7 additions & 14 deletions dm/pkg/conn/basedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@ import (
"sync"
"sync/atomic"

"github.com/DATA-DOG/go-sqlmock"
"github.com/go-sql-driver/mysql"
"github.com/pingcap/failpoint"
toolutils "github.com/pingcap/tidb-tools/pkg/utils"

"github.com/pingcap/tiflow/dm/dm/config"
"github.com/pingcap/tiflow/dm/pkg/retry"
"github.com/pingcap/tiflow/dm/pkg/terror"
"github.com/pingcap/tiflow/dm/pkg/utils"

"github.com/go-sql-driver/mysql"
toolutils "github.com/pingcap/tidb-tools/pkg/utils"
)

var customID int64

var netTimeout = utils.DefaultDBTimeout

// DBProvider providers BaseDB instance.
type DBProvider interface {
Apply(config *config.DBConfig) (*BaseDB, error)
Expand All @@ -52,9 +52,6 @@ func init() {
DefaultDBProvider = &DefaultDBProviderImpl{}
}

// mockDB is used in unit test.
var mockDB sqlmock.Sqlmock

// Apply will build BaseDB with DBConfig.
func (d *DefaultDBProviderImpl) Apply(config *config.DBConfig) (*BaseDB, error) {
// maxAllowedPacket=0 can be used to automatically fetch the max_allowed_packet variable from server on every connection.
Expand Down Expand Up @@ -113,14 +110,8 @@ func (d *DefaultDBProviderImpl) Apply(config *config.DBConfig) (*BaseDB, error)
if err != nil {
return nil, terror.DBErrorAdapt(err, terror.ErrDBDriverError)
}
failpoint.Inject("failDBPing", func(_ failpoint.Value) {
db.Close()
db, mockDB, _ = sqlmock.New()
mockDB.ExpectPing()
mockDB.ExpectClose()
})

ctx, cancel := context.WithTimeout(context.Background(), utils.DefaultDBTimeout)
ctx, cancel := context.WithTimeout(context.Background(), netTimeout)
defer cancel()
err = db.PingContext(ctx)
failpoint.Inject("failDBPing", func(_ failpoint.Value) {
Expand Down Expand Up @@ -159,6 +150,8 @@ func NewBaseDB(db *sql.DB, doFuncInClose ...func()) *BaseDB {

// GetBaseConn retrieves *BaseConn which has own retryStrategy.
func (d *BaseDB) GetBaseConn(ctx context.Context) (*BaseConn, error) {
ctx, cancel := context.WithTimeout(ctx, netTimeout)
defer cancel()
conn, err := d.DB.Conn(ctx)
if err != nil {
return nil, terror.DBErrorAdapt(err, terror.ErrDBDriverError)
Expand Down
88 changes: 60 additions & 28 deletions dm/pkg/conn/basedb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,65 +14,97 @@
package conn

import (
. "github.com/pingcap/check"
"github.com/pingcap/failpoint"
"context"
"database/sql"
"fmt"
"net"
"testing"
"time"

sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/phayes/freeport"
"github.com/pingcap/tiflow/dm/pkg/utils"
"github.com/stretchr/testify/require"

"github.com/DATA-DOG/go-sqlmock"

"github.com/pingcap/tiflow/dm/dm/config"
tcontext "github.com/pingcap/tiflow/dm/pkg/context"
)

var _ = Suite(&testBaseDBSuite{})

type testBaseDBSuite struct{}

func (t *testBaseDBSuite) TestGetBaseConn(c *C) {
func TestGetBaseConn(t *testing.T) {
db, mock, err := sqlmock.New()
c.Assert(err, IsNil)
require.NoError(t, err)

baseDB := NewBaseDB(db)

tctx := tcontext.Background()

dbConn, err := baseDB.GetBaseConn(tctx.Context())
c.Assert(dbConn, NotNil)
c.Assert(err, IsNil)
require.NoError(t, err)
require.NotNil(t, dbConn)

mock.ExpectQuery("select 1").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1"))
// nolint:sqlclosecheck
rows, err := dbConn.QuerySQL(tctx, "select 1")
c.Assert(err, IsNil)
require.NoError(t, err)
ids := make([]int, 0, 1)
for rows.Next() {
var id int
err = rows.Scan(&id)
c.Assert(err, IsNil)
require.NoError(t, err)
ids = append(ids, id)
}
c.Assert(ids, HasLen, 1)
c.Assert(ids[0], Equals, 1)
require.Equal(t, []int{1}, ids)

mock.ExpectBegin()
mock.ExpectExec("create database test").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
affected, err := dbConn.ExecuteSQL(tctx, testStmtHistogram, "test", []string{"create database test"})
c.Assert(err, IsNil)
c.Assert(affected, Equals, 1)
c.Assert(baseDB.Close(), IsNil)
require.NoError(t, err)
require.Equal(t, 1, affected)
require.NoError(t, baseDB.Close())
}

func (t *testBaseDBSuite) TestFailDBPing(c *C) {
c.Assert(failpoint.Enable("github.com/pingcap/tiflow/dm/pkg/conn/failDBPing", "return"), IsNil)
//nolint:errcheck
defer failpoint.Disable("github.com/pingcap/tiflow/dm/pkg/conn/failDBPing")
func TestFailDBPing(t *testing.T) {
netTimeout = time.Second
defer func() {
netTimeout = utils.DefaultDBTimeout
}()
port := freeport.GetPort()
addr := fmt.Sprintf("127.0.0.1:%d", port)

cfg := &config.DBConfig{User: "root", Host: "127.0.0.1", Port: 3306}
l, err := net.Listen("tcp", addr)
require.NoError(t, err)
defer l.Close()

cfg := &config.DBConfig{User: "root", Host: "127.0.0.1", Port: port}
cfg.Adjust()
db, err := DefaultDBProvider.Apply(cfg)
c.Assert(db, IsNil)
c.Assert(err, NotNil)
impl := &DefaultDBProviderImpl{}
db, err := impl.Apply(cfg)
require.Error(t, err)
require.Nil(t, db)
}

func TestGetBaseConnWontBlock(t *testing.T) {
netTimeout = time.Second
defer func() {
netTimeout = utils.DefaultDBTimeout
}()
ctx := context.Background()

port := freeport.GetPort()
addr := fmt.Sprintf("127.0.0.1:%d", port)

l, err := net.Listen("tcp", addr)
require.NoError(t, err)
defer l.Close()

// no such MySQL listening on port, so Conn will block
db, err := sql.Open("mysql", "root:@tcp("+addr+")/test")
require.NoError(t, err)

baseDB := NewBaseDB(db)

err = mockDB.ExpectationsWereMet()
c.Assert(err, IsNil)
_, err = baseDB.GetBaseConn(ctx)
require.Error(t, err)
}

0 comments on commit 01a9b2e

Please sign in to comment.