Skip to content

Commit

Permalink
Added a method to create a mysql database from a connection object
Browse files Browse the repository at this point in the history
  • Loading branch information
Sébastien CAPARROS committed Jun 17, 2021
1 parent 3b3c1b6 commit 1b8c309
Showing 1 changed file with 51 additions and 16 deletions.
67 changes: 51 additions & 16 deletions database/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,32 +54,37 @@ type Mysql struct {
config *Config
}

// instance must have `multiStatements` set to true
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
// connection instance must have `multiStatements` set to true
func WithConnection(connection *sql.Conn, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}

if err := instance.Ping(); err != nil {
mx := &Mysql{
conn: connection,
db: nil,
config: config,
}

if err := mx.setupDefaultConfig(); err != nil {
return nil, err
}

if config.DatabaseName == "" {
query := `SELECT DATABASE()`
var databaseName sql.NullString
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if err := mx.ensureVersionTable(); err != nil {
return nil, err
}

if len(databaseName.String) == 0 {
return nil, ErrNoDatabaseName
}
return mx, nil
}

config.DatabaseName = databaseName.String
// instance must have `multiStatements` set to true
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}

if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
if err := instance.Ping(); err != nil {
return nil, err
}

conn, err := instance.Conn(context.Background())
Expand All @@ -93,13 +98,39 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
config: config,
}

if err := mx.setupDefaultConfig(); err != nil {
return nil, err
}

if err := mx.ensureVersionTable(); err != nil {
return nil, err
}

return mx, nil
}

func (m *Mysql) setupDefaultConfig() error {
if m.config.DatabaseName == "" {
query := `SELECT DATABASE()`
var databaseName sql.NullString
if err := m.conn.QueryRowContext(context.Background(), query).Scan(&databaseName); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

if len(databaseName.String) == 0 {
return ErrNoDatabaseName
}

m.config.DatabaseName = databaseName.String
}

if len(m.config.MigrationsTable) == 0 {
m.config.MigrationsTable = DefaultMigrationsTable
}

return nil
}

// extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
// mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
Expand Down Expand Up @@ -243,7 +274,11 @@ func (m *Mysql) Open(url string) (database.Driver, error) {

func (m *Mysql) Close() error {
connErr := m.conn.Close()
dbErr := m.db.Close()
var dbErr error
if m.db != nil {
dbErr = m.db.Close()
}

if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
Expand Down

0 comments on commit 1b8c309

Please sign in to comment.