diff --git a/app/pkg/dbx/migrate.go b/app/pkg/dbx/migrate.go index 9204c5cf7..eebe43154 100644 --- a/app/pkg/dbx/migrate.go +++ b/app/pkg/dbx/migrate.go @@ -4,7 +4,9 @@ import ( "context" "database/sql" stdErrors "errors" + "github.com/lib/pq" "os" + "slices" "sort" "strconv" "strings" @@ -63,20 +65,23 @@ func Migrate(ctx context.Context, path string) error { totalMigrationsExecuted := 0 + pendingVersions, err := getPendingMigrations(versions) + if err != nil { + return errors.Wrap(err, "failed to get pending migrations") + } + // Apply all migrations - for _, version := range versions { - if version > lastVersion { - fileName := versionFiles[version] - log.Infof(ctx, "Running Version: @{Version} (@{FileName})", dto.Props{ - "Version": version, - "FileName": fileName, - }) - err := runMigration(ctx, version, path, fileName) - if err != nil { - return errors.Wrap(err, "failed to run migration '%s'", fileName) - } - totalMigrationsExecuted++ + for _, version := range pendingVersions { + fileName := versionFiles[version] + log.Infof(ctx, "Running Version: @{Version} (@{FileName})", dto.Props{ + "Version": version, + "FileName": fileName, + }) + err := runMigration(ctx, version, path, fileName) + if err != nil { + return errors.Wrap(err, "failed to run migration '%s'", fileName) } + totalMigrationsExecuted++ } if totalMigrationsExecuted > 0 { @@ -140,3 +145,26 @@ func getLastMigration() (int, error) { return int(lastVersion.Int64), nil } + +func getPendingMigrations(versions []int) ([]int, error) { + pendingMigrations := append([]int(nil), versions...) + + rows, err := conn.Query("SELECT version FROM migrations_history WHERE version = ANY($1)", pq.Array(pendingMigrations)) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var version int + if err := rows.Scan(&version); err != nil { + return nil, errors.Wrap(err, "failed to scan version") + } + i := slices.Index(pendingMigrations, version) + if i != -1 { + pendingMigrations = slices.Delete(pendingMigrations, i, i+1) + } + } + + return pendingMigrations, nil +} diff --git a/app/pkg/dbx/migrate_test.go b/app/pkg/dbx/migrate_test.go index 7ee8cc504..183107bc6 100644 --- a/app/pkg/dbx/migrate_test.go +++ b/app/pkg/dbx/migrate_test.go @@ -39,6 +39,30 @@ func TestMigrate_Success(t *testing.T) { trx.MustRollback() } +func TestMigrate_SuccessWithPastMigration(t *testing.T) { + setupMigrationTest(t) + ctx := context.Background() + + err := dbx.Migrate(ctx, "/app/pkg/dbx/testdata/migration_success") + Expect(err).IsNil() + + err = dbx.Migrate(ctx, "/app/pkg/dbx/testdata/migration_success_with_new_migrations") + Expect(err).IsNil() + + trx, _ := dbx.BeginTx(ctx) + var version string + err = trx.Scalar(&version, "SELECT version FROM migrations_history WHERE version = '209901010000' LIMIT 1") + Expect(err).IsNil() + Expect(version).Equals("209901010000") + + var count int + err = trx.Scalar(&count, "SELECT COUNT(*) FROM migrations_history WHERE version IN (209901010000,210001010002)") + Expect(err).IsNil() + Expect(count).Equals(2) + + trx.MustRollback() +} + func TestMigrate_Failure(t *testing.T) { setupMigrationTest(t) ctx := context.Background() diff --git a/app/pkg/dbx/testdata/migration_success_with_new_migrations/209901010000_create.sql b/app/pkg/dbx/testdata/migration_success_with_new_migrations/209901010000_create.sql new file mode 100644 index 000000000..6f9bf34f9 --- /dev/null +++ b/app/pkg/dbx/testdata/migration_success_with_new_migrations/209901010000_create.sql @@ -0,0 +1 @@ +insert into dummy (id, description) values (400, 'Description 400A'); \ No newline at end of file diff --git a/app/pkg/dbx/testdata/migration_success_with_new_migrations/210001010002_delete.sql b/app/pkg/dbx/testdata/migration_success_with_new_migrations/210001010002_delete.sql new file mode 100644 index 000000000..2655a0b24 --- /dev/null +++ b/app/pkg/dbx/testdata/migration_success_with_new_migrations/210001010002_delete.sql @@ -0,0 +1 @@ +DELETE FROM dummy WHERE id = 400; \ No newline at end of file