Skip to content

Commit

Permalink
Added a method to create a mysql database from a connection object (#583
Browse files Browse the repository at this point in the history
)

* Added a method to create a mysql database from a connection object

* Calling WithConnection from WithInstance to de-duplicate code

* Adding context and ping to mysql.WithConnection

* Interface type check at compile time
  • Loading branch information
Seb-C authored Jul 6, 2021
1 parent 805f8c8 commit 31cedbb
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions database/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"github.com/hashicorp/go-multierror"
)

var _ database.Driver = (*Mysql)(nil) // explicit compile time type check

func init() {
database.Register("mysql", &Mysql{})
}
Expand Down Expand Up @@ -50,20 +52,26 @@ type Mysql struct {
config *Config
}

// instance must have `multiStatements` set to true
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
// connection instance must have `multiStatements` set to true
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) {
if config == nil {
return nil, ErrNilConfig
}

if err := instance.Ping(); err != nil {
if err := conn.PingContext(ctx); err != nil {
return nil, err
}

mx := &Mysql{
conn: conn,
db: nil,
config: config,
}

if config.DatabaseName == "" {
query := `SELECT DATABASE()`
var databaseName sql.NullString
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

Expand All @@ -78,21 +86,33 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
config.MigrationsTable = DefaultMigrationsTable
}

conn, err := instance.Conn(context.Background())
if err != nil {
if err := mx.ensureVersionTable(); err != nil {
return nil, err
}

mx := &Mysql{
conn: conn,
db: instance,
config: config,
return mx, nil
}

// instance must have `multiStatements` set to true
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
ctx := context.Background()

if err := instance.Ping(); err != nil {
return nil, err
}

if err := mx.ensureVersionTable(); err != nil {
conn, err := instance.Conn(ctx)
if err != nil {
return nil, err
}

mx, err := WithConnection(ctx, conn, config)
if err != nil {
return nil, err
}

mx.db = instance

return mx, nil
}

Expand Down Expand Up @@ -239,7 +259,11 @@ func (m *Mysql) Open(url string) (database.Driver, error) {

func (m *Mysql) Close() error {
connErr := m.conn.Close()
dbErr := m.db.Close()
var dbErr error
if m.db != nil {
dbErr = m.db.Close()
}

if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
Expand Down

0 comments on commit 31cedbb

Please sign in to comment.