diff --git a/dm/pkg/conn/basedb.go b/dm/pkg/conn/basedb.go index 250adb6aaeb..f9f02b1ca9a 100644 --- a/dm/pkg/conn/basedb.go +++ b/dm/pkg/conn/basedb.go @@ -23,20 +23,20 @@ import ( "sync" "sync/atomic" - "github.com/DATA-DOG/go-sqlmock" + "github.com/go-sql-driver/mysql" "github.com/pingcap/failpoint" + toolutils "github.com/pingcap/tidb-tools/pkg/utils" "github.com/pingcap/tiflow/dm/dm/config" "github.com/pingcap/tiflow/dm/pkg/retry" "github.com/pingcap/tiflow/dm/pkg/terror" "github.com/pingcap/tiflow/dm/pkg/utils" - - "github.com/go-sql-driver/mysql" - toolutils "github.com/pingcap/tidb-tools/pkg/utils" ) var customID int64 +var netTimeout = utils.DefaultDBTimeout + // DBProvider providers BaseDB instance. type DBProvider interface { Apply(config *config.DBConfig) (*BaseDB, error) @@ -52,9 +52,6 @@ func init() { DefaultDBProvider = &DefaultDBProviderImpl{} } -// mockDB is used in unit test. -var mockDB sqlmock.Sqlmock - // Apply will build BaseDB with DBConfig. func (d *DefaultDBProviderImpl) Apply(config *config.DBConfig) (*BaseDB, error) { // maxAllowedPacket=0 can be used to automatically fetch the max_allowed_packet variable from server on every connection. @@ -113,14 +110,8 @@ func (d *DefaultDBProviderImpl) Apply(config *config.DBConfig) (*BaseDB, error) if err != nil { return nil, terror.DBErrorAdapt(err, terror.ErrDBDriverError) } - failpoint.Inject("failDBPing", func(_ failpoint.Value) { - db.Close() - db, mockDB, _ = sqlmock.New() - mockDB.ExpectPing() - mockDB.ExpectClose() - }) - ctx, cancel := context.WithTimeout(context.Background(), utils.DefaultDBTimeout) + ctx, cancel := context.WithTimeout(context.Background(), netTimeout) defer cancel() err = db.PingContext(ctx) failpoint.Inject("failDBPing", func(_ failpoint.Value) { @@ -159,6 +150,8 @@ func NewBaseDB(db *sql.DB, doFuncInClose ...func()) *BaseDB { // GetBaseConn retrieves *BaseConn which has own retryStrategy. func (d *BaseDB) GetBaseConn(ctx context.Context) (*BaseConn, error) { + ctx, cancel := context.WithTimeout(ctx, netTimeout) + defer cancel() conn, err := d.DB.Conn(ctx) if err != nil { return nil, terror.DBErrorAdapt(err, terror.ErrDBDriverError) diff --git a/dm/pkg/conn/basedb_test.go b/dm/pkg/conn/basedb_test.go index 92860cde982..f503f9a7743 100644 --- a/dm/pkg/conn/basedb_test.go +++ b/dm/pkg/conn/basedb_test.go @@ -14,65 +14,97 @@ package conn import ( - . "github.com/pingcap/check" - "github.com/pingcap/failpoint" + "context" + "database/sql" + "fmt" + "net" + "testing" + "time" - sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/phayes/freeport" + "github.com/pingcap/tiflow/dm/pkg/utils" + "github.com/stretchr/testify/require" + + "github.com/DATA-DOG/go-sqlmock" "github.com/pingcap/tiflow/dm/dm/config" tcontext "github.com/pingcap/tiflow/dm/pkg/context" ) -var _ = Suite(&testBaseDBSuite{}) - -type testBaseDBSuite struct{} - -func (t *testBaseDBSuite) TestGetBaseConn(c *C) { +func TestGetBaseConn(t *testing.T) { db, mock, err := sqlmock.New() - c.Assert(err, IsNil) + require.NoError(t, err) baseDB := NewBaseDB(db) tctx := tcontext.Background() dbConn, err := baseDB.GetBaseConn(tctx.Context()) - c.Assert(dbConn, NotNil) - c.Assert(err, IsNil) + require.NoError(t, err) + require.NotNil(t, dbConn) mock.ExpectQuery("select 1").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1")) // nolint:sqlclosecheck rows, err := dbConn.QuerySQL(tctx, "select 1") - c.Assert(err, IsNil) + require.NoError(t, err) ids := make([]int, 0, 1) for rows.Next() { var id int err = rows.Scan(&id) - c.Assert(err, IsNil) + require.NoError(t, err) ids = append(ids, id) } - c.Assert(ids, HasLen, 1) - c.Assert(ids[0], Equals, 1) + require.Equal(t, []int{1}, ids) mock.ExpectBegin() mock.ExpectExec("create database test").WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() affected, err := dbConn.ExecuteSQL(tctx, testStmtHistogram, "test", []string{"create database test"}) - c.Assert(err, IsNil) - c.Assert(affected, Equals, 1) - c.Assert(baseDB.Close(), IsNil) + require.NoError(t, err) + require.Equal(t, 1, affected) + require.NoError(t, baseDB.Close()) } -func (t *testBaseDBSuite) TestFailDBPing(c *C) { - c.Assert(failpoint.Enable("github.com/pingcap/tiflow/dm/pkg/conn/failDBPing", "return"), IsNil) - //nolint:errcheck - defer failpoint.Disable("github.com/pingcap/tiflow/dm/pkg/conn/failDBPing") +func TestFailDBPing(t *testing.T) { + netTimeout = time.Second + defer func() { + netTimeout = utils.DefaultDBTimeout + }() + port := freeport.GetPort() + addr := fmt.Sprintf("127.0.0.1:%d", port) - cfg := &config.DBConfig{User: "root", Host: "127.0.0.1", Port: 3306} + l, err := net.Listen("tcp", addr) + require.NoError(t, err) + defer l.Close() + + cfg := &config.DBConfig{User: "root", Host: "127.0.0.1", Port: port} cfg.Adjust() - db, err := DefaultDBProvider.Apply(cfg) - c.Assert(db, IsNil) - c.Assert(err, NotNil) + impl := &DefaultDBProviderImpl{} + db, err := impl.Apply(cfg) + require.Error(t, err) + require.Nil(t, db) +} + +func TestGetBaseConnWontBlock(t *testing.T) { + netTimeout = time.Second + defer func() { + netTimeout = utils.DefaultDBTimeout + }() + ctx := context.Background() + + port := freeport.GetPort() + addr := fmt.Sprintf("127.0.0.1:%d", port) + + l, err := net.Listen("tcp", addr) + require.NoError(t, err) + defer l.Close() + + // no such MySQL listening on port, so Conn will block + db, err := sql.Open("mysql", "root:@tcp("+addr+")/test") + require.NoError(t, err) + + baseDB := NewBaseDB(db) - err = mockDB.ExpectationsWereMet() - c.Assert(err, IsNil) + _, err = baseDB.GetBaseConn(ctx) + require.Error(t, err) }