diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index 83fed3a888..a20fa378d8 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -13,6 +13,7 @@ import ( "unicode" "github.com/kyleconroy/sqlc/internal/catalog" + "github.com/kyleconroy/sqlc/internal/migrations" core "github.com/kyleconroy/sqlc/internal/pg" "github.com/kyleconroy/sqlc/internal/postgres" "github.com/kyleconroy/sqlc/internal/postgresql/ast" @@ -86,8 +87,7 @@ func ReadSQLFiles(path string) ([]string, error) { if strings.HasPrefix(filepath.Base(filename), ".") { continue } - // Remove golang-migrate rollback files. - if strings.HasSuffix(filename, ".down.sql") { + if migrations.IsDown(filename) { continue } sql = append(sql, filename) @@ -109,7 +109,7 @@ func ParseCatalog(schema string) (core.Catalog, error) { merr.Add(filename, "", 0, err) continue } - contents := RemoveRollbackStatements(string(blob)) + contents := migrations.RemoveRollbackStatements(string(blob)) tree, err := pg.Parse(contents) if err != nil { merr.Add(filename, contents, 0, err) diff --git a/internal/dinosql/testdata/migrations/1.down.sql b/internal/dinosql/testdata/migrations/1.down.sql deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/internal/dinosql/testdata/migrations/1.up.sql b/internal/dinosql/testdata/migrations/1.up.sql deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/internal/dinosql/testdata/migrations/2.down.sql b/internal/dinosql/testdata/migrations/2.down.sql deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/internal/dinosql/testdata/migrations/2.sql b/internal/dinosql/testdata/migrations/2.sql deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/internal/dinosql/testdata/migrations/foo.sql b/internal/dinosql/testdata/migrations/foo.sql deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/internal/dinosql/migrations.go b/internal/migrations/migrations.go similarity index 80% rename from internal/dinosql/migrations.go rename to internal/migrations/migrations.go index 20b253f45d..ea37bd4a59 100644 --- a/internal/dinosql/migrations.go +++ b/internal/migrations/migrations.go @@ -1,4 +1,4 @@ -package dinosql +package migrations import ( "bufio" @@ -27,3 +27,8 @@ func RemoveRollbackStatements(contents string) string { } return strings.Join(lines, "\n") } + +func IsDown(filename string) bool { + // Remove golang-migrate rollback files. + return strings.HasSuffix(filename, ".down.sql") +} diff --git a/internal/dinosql/migrations_test.go b/internal/migrations/migrations_test.go similarity index 79% rename from internal/dinosql/migrations_test.go rename to internal/migrations/migrations_test.go index 223156d33e..bde915b1c9 100644 --- a/internal/dinosql/migrations_test.go +++ b/internal/migrations/migrations_test.go @@ -1,10 +1,9 @@ -package dinosql +package migrations import ( "testing" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" ) const inputGoose = ` @@ -60,21 +59,19 @@ func TestRemoveRollback(t *testing.T) { } func TestRemoveGolangMigrateRollback(t *testing.T) { - want := []string{ + filenames := map[string]bool{ // make sure we let through golang-migrate files that aren't rollbacks - "testdata/migrations/1.up.sql", + "migrations/1.up.sql": false, // make sure we let through other sql files - "testdata/migrations/2.sql", - "testdata/migrations/foo.sql", + "migrations/2.sql": false, + "migrations/foo.sql": false, + "migrations/1.down.sql": true, } - got, err := ReadSQLFiles("./testdata/migrations") - if err != nil { - t.Fatal(err) - } - - less := func(a, b string) bool { return a < b } - if diff := cmp.Diff(want, got, cmpopts.SortSlices(less)); diff != "" { - t.Errorf("golang-migrate filtering mismatch: \n %s", diff) + for filename, want := range filenames { + got := IsDown(filename) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("IsDown mismatch: %s\n %s", filename, diff) + } } } diff --git a/internal/mysql/parse.go b/internal/mysql/parse.go index ee1404905d..be377aa611 100644 --- a/internal/mysql/parse.go +++ b/internal/mysql/parse.go @@ -11,6 +11,7 @@ import ( "github.com/kyleconroy/sqlc/internal/config" "github.com/kyleconroy/sqlc/internal/dinosql" + "github.com/kyleconroy/sqlc/internal/migrations" ) // Query holds the data for walking and validating mysql querys @@ -44,7 +45,7 @@ func parsePath(sqlPath string, generator PackageGenerator) (*Result, error) { if err != nil { parseErrors.Add(filename, "", 0, err) } - contents := dinosql.RemoveRollbackStatements(string(blob)) + contents := migrations.RemoveRollbackStatements(string(blob)) if err != nil { parseErrors.Add(filename, "", 0, err) continue