Skip to content
This repository has been archived by the owner on Aug 23, 2024. It is now read-only.

Commit

Permalink
Postgres and pgx - Add x-migrations-table-quoted url query option to …
Browse files Browse the repository at this point in the history
…postgres and pgx drivers (golang-migrate#95)

By default, gomigrate quote migrations table name, if `x-migrations-table-quoted` is enabled, then you must to quote migrations table name manually, for instance `"gomigrate"."schema_migrations"`
  • Loading branch information
stephane-klein committed Apr 4, 2021
1 parent 511ae9f commit f31e3b5
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 20 deletions.
1 change: 1 addition & 0 deletions database/pgx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, gomigrate quote migrations table name, if `x-migrations-table-quoted` is enabled, then you must quote migrations table name manually, for instance `"gomigrate"."schema_migrations"` |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
Expand Down
52 changes: 43 additions & 9 deletions database/pgx/pgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"io/ioutil"
nurl "net/url"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -46,6 +47,7 @@ type Config struct {
DatabaseName string
SchemaName string
StatementTimeout time.Duration
MigrationsTableQuoted bool
MultiStatementEnabled bool
MultiStatementMaxSize int
}
Expand Down Expand Up @@ -137,6 +139,17 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
}

migrationsTable := purl.Query().Get("x-migrations-table")
migrationsTableQuoted := false
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
migrationsTableQuoted, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
}
}
if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"gomigrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
}

statementTimeoutString := purl.Query().Get("x-statement-timeout")
statementTimeout := 0
if statementTimeoutString != "" {
Expand Down Expand Up @@ -168,6 +181,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
MigrationsTableQuoted: migrationsTableQuoted,
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
MultiStatementEnabled: multiStatementEnabled,
MultiStatementMaxSize: multiStatementMaxSize,
Expand Down Expand Up @@ -321,7 +335,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}

query := `TRUNCATE ` + quoteIdentifier(p.config.MigrationsTable)
query := `TRUNCATE ` + p.quoteIdentifier(p.config.MigrationsTable)
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
Expand All @@ -333,7 +347,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query = `INSERT INTO ` + quoteIdentifier(p.config.MigrationsTable) +
query = `INSERT INTO ` + p.quoteIdentifier(p.config.MigrationsTable) +
` (version, dirty) VALUES ($1, $2)`
if _, err := tx.Exec(query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
Expand All @@ -351,7 +365,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
}

func (p *Postgres) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
query := `SELECT version, dirty FROM ` + p.quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
Expand Down Expand Up @@ -401,7 +415,7 @@ func (p *Postgres) Drop() (err error) {
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
query = `DROP TABLE IF EXISTS ` + p.quoteIdentifier(t) + ` CASCADE`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
Expand Down Expand Up @@ -433,10 +447,27 @@ func (p *Postgres) ensureVersionTable() (err error) {
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable)
var row *sql.Row
tableName := []byte(p.config.MigrationsTable)
schemaName := []byte("")
if p.config.MigrationsTableQuoted {
re := regexp.MustCompile(`"(.*?)"`)
result := re.FindAllSubmatch([]byte(p.config.MigrationsTable), -1)
tableName = result[len(result)-1][1]
if len(result) > 1 {
schemaName = result[0][1]
}
}
var query string
if len(schemaName) > 0 {
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2 LIMIT 1`
row = p.conn.QueryRowContext(context.Background(), query, tableName, schemaName)
} else {
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
row = p.conn.QueryRowContext(context.Background(), query, tableName)
}

var count int
err = row.Scan(&count)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
Expand All @@ -446,7 +477,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
return nil
}

query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
query = `CREATE TABLE IF NOT EXISTS ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
Expand All @@ -455,7 +486,10 @@ func (p *Postgres) ensureVersionTable() (err error) {
}

// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
func (p *Postgres) quoteIdentifier(name string) string {
if p.config.MigrationsTableQuoted {
return name
}
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
Expand Down
107 changes: 107 additions & 0 deletions database/pgx/pgx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,65 @@ func TestWithSchema(t *testing.T) {
})
}

func TestMigrationTableOption(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

addr := pgConnectionString(ip, port)
p := &Postgres{}
d, _ := p.Open(addr)
defer func() {
if err := d.Close(); err != nil {
t.Fatal(err)
}
}()

// create gomigrate schema
if err := d.Run(strings.NewReader("CREATE SCHEMA gomigrate AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}

// bad unquoted x-migrations-table parameter
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=gomigrate.schema_migrations&x-migrations-table-quoted=1",
pgPassword, ip, port))
if err == nil {
t.Fatal("expected x-migrations-table must be quoted...")
}

// good quoted x-migrations-table parameter
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"gomigrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}

// make sure gomigrate.schema_migrations table exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'gomigrate')").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table gomigrate.schema_migrations to exist")
}

d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=gomigrate.schema_migrations",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'gomigrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table 'gomigrate.schema_migrations' to exist")
}

})
}

func TestFailToCreateTableWithoutPermissions(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
Expand Down Expand Up @@ -373,6 +432,18 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) {
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}

// re-connect using that x-migrations-table and x-migrations-table-quoted
d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
pgPassword, ip, port))

if !errors.As(err, &e) || err == nil {
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
}

if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}
})
}

Expand Down Expand Up @@ -679,5 +750,41 @@ func Test_computeLineFromPos(t *testing.T) {
run(true, true)
})
}
}

func Test_quoteIdentifier(t *testing.T) {
testcases := []struct {
migrationsTableQuoted bool
migrationsTable string
want string
}{
{
false,
"schema_name.table_name",
"\"schema_name.table_name\"",
},
{
false,
"table_name",
"\"table_name\"",
},
{
true,
"\"schema_name\".\"table.name\"",
"\"schema_name\".\"table.name\"",
},
}
p := &Postgres{
config: &Config{
MigrationsTableQuoted: false,
},
}

for _, tc := range testcases {
p.config.MigrationsTableQuoted = tc.migrationsTableQuoted
got := p.quoteIdentifier(tc.migrationsTable)
if tc.want != got {
t.Fatalf("expected %s but got %s", tc.want, got)
}
}
}
1 change: 1 addition & 0 deletions database/postgres/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, gomigrate quote migrations table name, if `x-migrations-table-quoted` is enabled, then you must quote migrations table name manually, for instance `"gomigrate"."schema_migrations"` |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
Expand Down
Loading

0 comments on commit f31e3b5

Please sign in to comment.