From 8f9bb03e98586bcaedcc7d8150bee6a49986de52 Mon Sep 17 00:00:00 2001
From: Brandur
Date: Sat, 16 Mar 2024 11:47:44 -0700
Subject: [PATCH] Migration CLI: Add `migrate-get` command and
 `--dry-run`/`--show-sql` options

This one's in pursuit of resolving #209, which I've decided to tackle now
because we're going to have to cut a CLI release across River package versions
for #258 anyway, so this'll reuse some work.

A new `river migrate-get` command becomes available, whose only job is to dump
SQL from River migrations so that it can easily be plugged into other migration
frameworks. Its use looks like:

    river migrate-get --version 3 --down > version3.down.sql
    river migrate-get --version 3 --up > version3.up.sql

It can also take multiple versions:

    river migrate-get --version 3,2,1 --down > river.down.sql
    river migrate-get --version 1,2,3 --up > river.up.sql

It can also dump _all_ migrations, which will be useful in cases where users
want to avoid River's internal migration framework completely and use their
own:

    river migrate-get --all --exclude-version 1 --up > river_all.up.sql
    river migrate-get --all --exclude-version 1 --down > river_all.down.sql

Along with that, `migrate-down` and `migrate-up` get two new useful options:

* `--dry-run`: Prints information on migrations that would be run, but doesn't
  modify the database in any way.

* `--show-sql`: Prints SQL for each migration step that was applied.

This gives users an easy way, after a River upgrade, to see which commands
would be run were they to migrate, but without actually performing the
migration. Most production users will probably want this as a cautious
preflight step:

    river migrate-up --dry-run --show-sql

I've also done a little cleanup around the River CLI's `main.go`. The
`--verbose` and `--debug` flags added in #258 are now promoted to persistent
flags so they're available to all commands, and we now have one standardized
way of initializing an appropriate logger.

Fixes #209.
---
 .github/workflows/ci.yaml          |  12 ++
 CHANGELOG.md                       |   2 +
 cmd/river/go.mod                   |   6 +-
 cmd/river/go.sum                   |   8 +
 cmd/river/main.go                  | 286 +++++++++++++++++++++++------
 cmd/river/main_test.go             |  16 ++
 rivermigrate/river_migrate.go      | 134 +++++++++-----
 rivermigrate/river_migrate_test.go | 100 ++++++++--
 8 files changed, 443 insertions(+), 121 deletions(-)
 create mode 100644 cmd/river/main_test.go

diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index e15b8ea2..d9ec7943 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -124,6 +124,18 @@ jobs:
       - name: Create database
         run: psql --echo-errors --quiet -c '\timing off' -c "CREATE DATABASE river_dev;" ${ADMIN_DATABASE_URL}
 
+      - run: ./river migrate-get --down --version 3
+        working-directory: ./cmd/river
+
+      - run: ./river migrate-get --up --version 3
+        working-directory: ./cmd/river
+
+      - run: ./river migrate-get --all --exclude-version 1 --down
+        working-directory: ./cmd/river
+
+      - run: ./river migrate-get --all --exclude-version 1 --up
+        working-directory: ./cmd/river
+
       - name: river migrate-up
         run: ./river migrate-up --database-url $DATABASE_URL
         working-directory: ./cmd/river
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 591c7033..fd0f78d6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
 - The River CLI now supports `river bench` to benchmark River's job throughput against a database. [PR #254](https://github.com/riverqueue/river/pull/254).
+- The River CLI now has a `river migrate-get` command to dump SQL for River migrations for use in alternative migration frameworks. Use it like `river migrate-get --up --version 3 > version3.up.sql`. [PR #273](https://github.com/riverqueue/river/pull/273).
+- The River CLI's `migrate-down` and `migrate-up` commands get two new options, `--dry-run` and `--show-sql`. They can be combined to run an easy preflight check on a River upgrade, showing which migration commands would be run against a database without actually running them. [PR #273](https://github.com/riverqueue/river/pull/273).
 - The River client gets a new `Client.SubscribeConfig` function that lets a subscriber specify the maximum size of their subscription channel. [PR #258](https://github.com/riverqueue/river/pull/258).
 
 ### Changed
diff --git a/cmd/river/go.mod b/cmd/river/go.mod
index 89a83c03..5678a1ff 100644
--- a/cmd/river/go.mod
+++ b/cmd/river/go.mod
@@ -12,21 +12,25 @@ replace github.com/riverqueue/river/riverdriver/riverpgxv5 => ../../riverdriver/
 
 require (
 	github.com/jackc/pgx/v5 v5.5.5
+	github.com/lmittmann/tint v1.0.4
 	github.com/riverqueue/river v0.0.17
 	github.com/riverqueue/river/riverdriver v0.0.25
 	github.com/riverqueue/river/riverdriver/riverpgxv5 v0.0.25
 	github.com/riverqueue/river/rivertype v0.0.25
 	github.com/spf13/cobra v1.8.0
+	github.com/stretchr/testify v1.9.0
 )
 
 require (
+	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/inconshreveable/mousetrap v1.1.0 // indirect
 	github.com/jackc/pgpassfile v1.0.0 // indirect
 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
 	github.com/jackc/puddle/v2 v2.2.1 // indirect
-	github.com/lmittmann/tint v1.0.4 // indirect
+	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/spf13/pflag v1.0.5 // indirect
 	golang.org/x/crypto v0.17.0 // indirect
 	golang.org/x/sync v0.6.0 // indirect
 	golang.org/x/text v0.14.0 // indirect
+	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/cmd/river/go.sum b/cmd/river/go.sum
index f472581d..66358676 100644
--- a/cmd/river/go.sum
+++ b/cmd/river/go.sum
@@ -14,6 +14,10 @@ github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
 github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
 github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
 github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
+github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
+github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
 github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
 github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
 github.com/lmittmann/tint v1.0.4 h1:LeYihpJ9hyGvE0w+K2okPTGUdVLfng1+nDNVR4vWISc=
@@ -24,6 +28,8 @@ github.com/riverqueue/river/rivertype v0.0.25 h1:iyReBD59MUan83gp3geGoHKU5eUrB9J
 github.com/riverqueue/river/rivertype v0.0.25/go.mod h1:PvsLQ/xSATmmn9gdjB3pnIaj9ZSLmWhDTI4EPrK3AJ0=
 github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
 github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
+github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
+github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= @@ -43,6 +49,8 @@ golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/cmd/river/main.go b/cmd/river/main.go index 57cb0f20..4f0e1032 100644 --- a/cmd/river/main.go +++ b/cmd/river/main.go @@ -4,9 +4,12 @@ import ( "context" "errors" "fmt" + "io" "log/slog" "os" + "slices" "strconv" + "strings" "time" "github.com/jackc/pgx/v5/pgxpool" @@ -19,30 +22,49 @@ import ( ) func main() { + var rootOpts struct { + Debug bool + Verbose bool + } + rootCmd := &cobra.Command{ Use: "river", Short: "Provides command line facilities for the River job queue", - Long: ` + Long: strings.TrimSpace(` Provides command line facilities for the River job queue. - `, + `), Run: func(cmd *cobra.Command, args []string) { _ = cmd.Usage() }, } + rootCmd.PersistentFlags().BoolVar(&rootOpts.Debug, "debug", false, "output maximum logging verbosity (debug level)") + rootCmd.PersistentFlags().BoolVarP(&rootOpts.Verbose, "verbose", "v", false, "output additional logging verbosity (info level)") + rootCmd.MarkFlagsMutuallyExclusive("debug", "verbose") ctx := context.Background() execHandlingError := func(f func() (bool, error)) { ok, err := f() if err != nil { - fmt.Printf("failed: %s\n", err) + fmt.Fprintf(os.Stderr, "failed: %s\n", err) } if err != nil || !ok { os.Exit(1) } } - mustMarkFlagRequired := func(cmd *cobra.Command, name string) { //nolint:unparam + makeLogger := func() *slog.Logger { + switch { + case rootOpts.Debug: + return slog.New(tint.NewHandler(os.Stdout, &tint.Options{Level: slog.LevelDebug})) + case rootOpts.Verbose: + return slog.New(tint.NewHandler(os.Stdout, nil)) + default: + return slog.New(tint.NewHandler(os.Stdout, &tint.Options{Level: slog.LevelWarn})) + } + } + + mustMarkFlagRequired := func(cmd *cobra.Command, name string) { // We just panic here because this will never happen outside of an error // in development. if err := cmd.MarkFlagRequired(name); err != nil { @@ -57,7 +79,7 @@ Provides command line facilities for the River job queue. cmd := &cobra.Command{ Use: "bench", Short: "Run River benchmark", - Long: ` + Long: strings.TrimSpace(` Run a River benchmark which inserts and works jobs continually, giving a rough idea of jobs per second and time to work a single job. @@ -69,66 +91,123 @@ before starting the client, and works until all jobs are finished. The database in --database-url will have its jobs table truncated, so make sure to use a development database only. 
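+
+For example, to run a one-minute benchmark (the URL here is a placeholder;
+point it at a disposable development database):
+
+    river bench --database-url postgres://localhost/river_dev --duration 1m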
- `, + `), Run: func(cmd *cobra.Command, args []string) { - execHandlingError(func() (bool, error) { return bench(ctx, &opts) }) + execHandlingError(func() (bool, error) { return bench(ctx, makeLogger(), os.Stdout, &opts) }) }, } cmd.Flags().StringVar(&opts.DatabaseURL, "database-url", "", "URL of the database to benchmark (should look like `postgres://...`") - cmd.Flags().BoolVar(&opts.Debug, "debug", false, "output maximum logging verbosity (debug level)") cmd.Flags().DurationVar(&opts.Duration, "duration", 0, "duration after which to stop benchmark, accepting Go-style durations like 1m, 5m30s") cmd.Flags().IntVarP(&opts.NumTotalJobs, "num-total-jobs", "n", 0, "number of jobs to insert before starting and which are worked down until finish") - cmd.Flags().BoolVarP(&opts.Verbose, "verbose", "v", false, "output additional logging verbosity (info level)") mustMarkFlagRequired(cmd, "database-url") - cmd.MarkFlagsMutuallyExclusive("debug", "verbose") rootCmd.AddCommand(cmd) } + // migrate-down and migrate-up share a set of options, so this is a way of + // plugging in all the right flags to both so options and docstrings stay + // consistent. + addMigrateFlags := func(cmd *cobra.Command, opts *migrateOpts) { + cmd.Flags().StringVar(&opts.DatabaseURL, "database-url", "", "URL of the database to migrate (should look like `postgres://...`") + cmd.Flags().BoolVar(&opts.DryRun, "dry-run", false, "print information on migrations, but don't apply them") + cmd.Flags().IntVar(&opts.MaxSteps, "max-steps", 0, "maximum number of steps to migrate") + cmd.Flags().BoolVar(&opts.ShowSQL, "show-sql", false, "show SQL of each migration") + cmd.Flags().IntVar(&opts.TargetVersion, "target-version", 0, "target version to migrate to (final state includes this version, but none after it)") + mustMarkFlagRequired(cmd, "database-url") + } + // migrate-down { - var opts migrateDownOpts + var opts migrateOpts cmd := &cobra.Command{ Use: "migrate-down", Short: "Run River schema down migrations", - Long: ` + Long: strings.TrimSpace(` Run down migrations to reverse the River database schema changes. Defaults to running a single down migration. This behavior can be changed with --max-steps or --target-version. - `, + +SQL being run can be output using --show-sql, and executing real database +operations can be prevented with --dry-run. Combine --show-sql and --dry-run to +dump prospective migrations that would be applied to stdout. + `), Run: func(cmd *cobra.Command, args []string) { - execHandlingError(func() (bool, error) { return migrateDown(ctx, &opts) }) + execHandlingError(func() (bool, error) { return migrateDown(ctx, makeLogger(), os.Stdout, &opts) }) }, } - cmd.Flags().StringVar(&opts.DatabaseURL, "database-url", "", "URL of the database to migrate (should look like `postgres://...`") - cmd.Flags().IntVar(&opts.MaxSteps, "max-steps", 1, "maximum number of steps to migrate") - cmd.Flags().IntVar(&opts.TargetVersion, "target-version", 0, "target version to migrate to (final state includes this version, but none after it)") - mustMarkFlagRequired(cmd, "database-url") + addMigrateFlags(cmd, &opts) + rootCmd.AddCommand(cmd) + } + + // migrate-get + { + var opts migrateGetOpts + + cmd := &cobra.Command{ + Use: "migrate-get", + Short: "Get SQL for specific River migration", + Long: strings.TrimSpace(` +Retrieve SQL for a single migration version. This command is aimed at cases +where using River's internal migration framework isn't desirable by allowing +migration SQL to be dumped for use elsewhere. 
+ +Specify a version with --version, and one of --down or --up: + + river migrate-get --version 3 --up > river3.up.sql + river migrate-get --version 3 --down > river3.down.sql + +Can also take multiple versions by separating them with commas or passing +--version multiple times: + + river migrate-get --version 1,2,3 --up > river.up.sql + river migrate-get --version 3,2,1 --down > river.down.sql + +Or use --all to print all known migrations in either direction. Often used in +conjunction with --exclude-version 1 to exclude the tables for River's migration +framework, which aren't necessary if using an external framework: + + river migrate-get --all --exclude-version 1 --up > river_all.up.sql + river migrate-get --all --exclude-version 1 --down > river_all.down.sql + `), + Run: func(cmd *cobra.Command, args []string) { + execHandlingError(func() (bool, error) { return migrateGet(ctx, makeLogger(), os.Stdout, &opts) }) + }, + } + cmd.Flags().BoolVar(&opts.All, "all", false, "print all migrations; down migrations are printed in descending order") + cmd.Flags().BoolVar(&opts.Down, "down", false, "print down migration") + cmd.Flags().IntSliceVar(&opts.ExcludeVersion, "exclude-version", nil, "exclude version(s), usually version 1, containing River's migration tables") + cmd.Flags().BoolVar(&opts.Up, "up", false, "print up migration") + cmd.Flags().IntSliceVar(&opts.Version, "version", nil, "version(s) to print (can be multiple versions)") + cmd.MarkFlagsMutuallyExclusive("all", "version") + cmd.MarkFlagsOneRequired("all", "version") + cmd.MarkFlagsMutuallyExclusive("down", "up") + cmd.MarkFlagsOneRequired("down", "up") rootCmd.AddCommand(cmd) } // migrate-up { - var opts migrateUpOpts + var opts migrateOpts cmd := &cobra.Command{ Use: "migrate-up", Short: "Run River schema up migrations", - Long: ` + Long: strings.TrimSpace(` Run up migrations to raise the database schema necessary to run River. Defaults to running all up migrations that aren't yet run. This behavior can be restricted with --max-steps or --target-version. - `, + +SQL being run can be output using --show-sql, and executing real database +operations can be prevented with --dry-run. Combine --show-sql and --dry-run to +dump prospective migrations that would be applied to stdout. + `), Run: func(cmd *cobra.Command, args []string) { - execHandlingError(func() (bool, error) { return migrateUp(ctx, &opts) }) + execHandlingError(func() (bool, error) { return migrateUp(ctx, makeLogger(), os.Stdout, &opts) }) }, } - cmd.Flags().StringVar(&opts.DatabaseURL, "database-url", "", "URL of the database to migrate (should look like `postgres://...`") - cmd.Flags().IntVar(&opts.MaxSteps, "max-steps", 0, "maximum number of steps to migrate") - cmd.Flags().IntVar(&opts.TargetVersion, "target-version", 0, "target version to migrate to (final state includes this version)") - mustMarkFlagRequired(cmd, "database-url") + addMigrateFlags(cmd, &opts) rootCmd.AddCommand(cmd) } @@ -139,12 +218,15 @@ restricted with --max-steps or --target-version. cmd := &cobra.Command{ Use: "validate", Short: "Validate River schema", - Long: ` + Long: strings.TrimSpace(` Validates the current River schema, exiting with a non-zero status in case there are outstanding migrations that still need to be run. - `, + +Can be paired with river migrate-up --dry-run --show-sql to dump information on +migrations that need to be run, but without running them. 
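+
+For example (assuming DATABASE_URL points at your database):
+
+    river validate --database-url "$DATABASE_URL" ||
+        river migrate-up --dry-run --show-sql --database-url "$DATABASE_URL"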
+	`),
 		Run: func(cmd *cobra.Command, args []string) {
-			execHandlingError(func() (bool, error) { return validate(ctx, &opts) })
+			execHandlingError(func() (bool, error) { return validate(ctx, makeLogger(), os.Stdout, &opts) })
 		},
 	}
 	cmd.Flags().StringVar(&opts.DatabaseURL, "database-url", "", "URL of the database to validate (should look like `postgres://...`"
@@ -152,7 +234,10 @@ are outstanding migrations that still need to be run.
 		rootCmd.AddCommand(cmd)
 	}
 
-	execHandlingError(func() (bool, error) { return true, rootCmd.Execute() })
+	// Cobra will already print an error on an unknown command, and there aren't
+	// really any other important top-level error cases to worry about as far as
+	// I can tell, so ignore the returned error here to avoid printing it twice.
+	_ = rootCmd.Execute()
 }
 
 func openDBPool(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) {
@@ -204,7 +289,7 @@ func (o *benchOpts) validate() error {
 	return nil
 }
 
-func bench(ctx context.Context, opts *benchOpts) (bool, error) {
+func bench(ctx context.Context, logger *slog.Logger, _ io.Writer, opts *benchOpts) (bool, error) {
 	if err := opts.validate(); err != nil {
 		return false, err
 	}
@@ -215,16 +300,6 @@ func bench(ctx context.Context, opts *benchOpts) (bool, error) {
 	}
 	defer dbPool.Close()
 
-	var logger *slog.Logger
-	switch {
-	case opts.Debug:
-		logger = slog.New(tint.NewHandler(os.Stdout, &tint.Options{Level: slog.LevelDebug}))
-	case opts.Verbose:
-		logger = slog.New(tint.NewHandler(os.Stdout, nil))
-	default:
-		logger = slog.New(tint.NewHandler(os.Stdout, &tint.Options{Level: slog.LevelWarn}))
-	}
-
 	benchmarker := riverbench.NewBenchmarker(riverpgxv5.New(dbPool), logger, opts.Duration, opts.NumTotalJobs)
 
 	if err := benchmarker.Run(ctx); err != nil {
@@ -234,13 +309,15 @@ func bench(ctx context.Context, opts *benchOpts) (bool, error) {
 	return true, nil
 }
 
-type migrateDownOpts struct {
+type migrateOpts struct {
 	DatabaseURL   string
+	DryRun        bool
+	ShowSQL       bool
 	MaxSteps      int
 	TargetVersion int
 }
 
-func (o *migrateDownOpts) validate() error {
+func (o *migrateOpts) validate() error {
 	if o.DatabaseURL == "" {
 		return errors.New("database URL cannot be empty")
 	}
@@ -248,20 +325,26 @@ func (o *migrateDownOpts) validate() error {
 	return nil
 }
 
-func migrateDown(ctx context.Context, opts *migrateDownOpts) (bool, error) {
+func migrateDown(ctx context.Context, logger *slog.Logger, out io.Writer, opts *migrateOpts) (bool, error) {
 	if err := opts.validate(); err != nil {
 		return false, err
 	}
 
+	// Default to applying a maximum of one migration in the down direction.
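+	// This guards against a plain `migrate-down` reverting every migration at
+	// once; pass --max-steps or --target-version explicitly to go further.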
+ if opts.MaxSteps == 0 && opts.TargetVersion == 0 { + opts.MaxSteps = 1 + } + dbPool, err := openDBPool(ctx, opts.DatabaseURL) if err != nil { return false, err } defer dbPool.Close() - migrator := rivermigrate.New(riverpgxv5.New(dbPool), nil) + migrator := rivermigrate.New(riverpgxv5.New(dbPool), &rivermigrate.Config{Logger: logger}) - _, err = migrator.Migrate(ctx, rivermigrate.DirectionDown, &rivermigrate.MigrateOpts{ + res, err := migrator.Migrate(ctx, rivermigrate.DirectionDown, &rivermigrate.MigrateOpts{ + DryRun: opts.DryRun, MaxSteps: opts.MaxSteps, TargetVersion: opts.TargetVersion, }) @@ -269,24 +352,106 @@ func migrateDown(ctx context.Context, opts *migrateDownOpts) (bool, error) { return false, err } + migratePrintResult(out, opts, res, rivermigrate.DirectionDown) + return true, nil } -type migrateUpOpts struct { - DatabaseURL string - MaxSteps int - TargetVersion int +func migratePrintResult(out io.Writer, opts *migrateOpts, res *rivermigrate.MigrateResult, direction rivermigrate.Direction) { + if len(res.Versions) < 1 { + fmt.Fprintf(out, "no migrations to apply\n") + return + } + + for _, migrateVersion := range res.Versions { + if opts.DryRun { + fmt.Fprintf(out, "migration %03d [%s] [DRY RUN]\n", migrateVersion.Version, direction) + } else { + fmt.Fprintf(out, "applied migration %03d [%s] [%s]\n", migrateVersion.Version, direction, migrateVersion.Duration) + } + + if opts.ShowSQL { + fmt.Fprintf(out, "%s\n", strings.Repeat("-", 80)) + fmt.Fprintf(out, "%s\n", migrationComment(migrateVersion.Version, direction)) + fmt.Fprintf(out, "%s\n\n", strings.TrimSpace(migrateVersion.SQL)) + } + } + + // Only prints if more steps than available were requested. + if opts.MaxSteps > 0 && len(res.Versions) < opts.MaxSteps { + fmt.Fprintf(out, "no more migrations to apply\n") + } } -func (o *migrateUpOpts) validate() error { - if o.DatabaseURL == "" { - return errors.New("database URL cannot be empty") +// An informational comment that's tagged on top of any migration's SQL to help +// attribute what it is for when it's copied elsewhere like other migration +// frameworks. 
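+// For example: "-- River migration 002 [up]".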
+func migrationComment(version int, direction rivermigrate.Direction) string { + return fmt.Sprintf("-- River migration %03d [%s]", version, direction) +} + +type migrateGetOpts struct { + All bool + Down bool + ExcludeVersion []int + Up bool + Version []int +} + +func migrateGet(_ context.Context, logger *slog.Logger, out io.Writer, opts *migrateGetOpts) (bool, error) { + migrator := rivermigrate.New(riverpgxv5.New(nil), &rivermigrate.Config{Logger: logger}) + + var migrations []rivermigrate.Migration + if opts.All { + migrations = migrator.AllVersions() + if opts.Down { + slices.Reverse(migrations) + } + } else { + for _, version := range opts.Version { + migration, err := migrator.GetVersion(version) + if err != nil { + return false, err + } + + migrations = append(migrations, migration) + } } - return nil + var printedOne bool + + for _, migration := range migrations { + if slices.Contains(opts.ExcludeVersion, migration.Version) { + continue + } + + // print newlines between multiple versions + if printedOne { + fmt.Fprintf(out, "\n") + } + + var ( + direction rivermigrate.Direction + sql string + ) + switch { + case opts.Down: + direction = rivermigrate.DirectionDown + sql = migration.SQLDown + case opts.Up: + direction = rivermigrate.DirectionUp + sql = migration.SQLUp + } + + printedOne = true + fmt.Fprintf(out, "%s\n", migrationComment(migration.Version, direction)) + fmt.Fprintf(out, "%s\n", strings.TrimSpace(sql)) + } + + return true, nil } -func migrateUp(ctx context.Context, opts *migrateUpOpts) (bool, error) { +func migrateUp(ctx context.Context, logger *slog.Logger, out io.Writer, opts *migrateOpts) (bool, error) { if err := opts.validate(); err != nil { return false, err } @@ -297,9 +462,10 @@ func migrateUp(ctx context.Context, opts *migrateUpOpts) (bool, error) { } defer dbPool.Close() - migrator := rivermigrate.New(riverpgxv5.New(dbPool), nil) + migrator := rivermigrate.New(riverpgxv5.New(dbPool), &rivermigrate.Config{Logger: logger}) - _, err = migrator.Migrate(ctx, rivermigrate.DirectionUp, &rivermigrate.MigrateOpts{ + res, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, &rivermigrate.MigrateOpts{ + DryRun: opts.DryRun, MaxSteps: opts.MaxSteps, TargetVersion: opts.TargetVersion, }) @@ -307,6 +473,8 @@ func migrateUp(ctx context.Context, opts *migrateUpOpts) (bool, error) { return false, err } + migratePrintResult(out, opts, res, rivermigrate.DirectionUp) + return true, nil } @@ -322,7 +490,7 @@ func (o *validateOpts) validate() error { return nil } -func validate(ctx context.Context, opts *validateOpts) (bool, error) { +func validate(ctx context.Context, logger *slog.Logger, _ io.Writer, opts *validateOpts) (bool, error) { if err := opts.validate(); err != nil { return false, err } @@ -333,7 +501,7 @@ func validate(ctx context.Context, opts *validateOpts) (bool, error) { } defer dbPool.Close() - migrator := rivermigrate.New(riverpgxv5.New(dbPool), nil) + migrator := rivermigrate.New(riverpgxv5.New(dbPool), &rivermigrate.Config{Logger: logger}) res, err := migrator.Validate(ctx) if err != nil { diff --git a/cmd/river/main_test.go b/cmd/river/main_test.go new file mode 100644 index 00000000..7a47fe70 --- /dev/null +++ b/cmd/river/main_test.go @@ -0,0 +1,16 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/rivermigrate" +) + +func TestMigrationComment(t *testing.T) { + t.Parallel() + + require.Equal(t, "-- River migration 001 [down]", migrationComment(1, rivermigrate.DirectionDown)) + require.Equal(t, 
"-- River migration 002 [up]", migrationComment(2, rivermigrate.DirectionUp)) +} diff --git a/rivermigrate/river_migrate.go b/rivermigrate/river_migrate.go index 10128130..7f02bb6b 100644 --- a/rivermigrate/river_migrate.go +++ b/rivermigrate/river_migrate.go @@ -24,12 +24,17 @@ import ( "github.com/riverqueue/river/riverdriver" ) -// A bundled migration containing a version (1, 2, 3), and SQL for up and down -// directions. -type migrationBundle struct { +// Migration is a bundled migration containing a version (e.g. 1, 2, 3), and SQL +// for up and down directions. +type Migration struct { + // SQLDown is the s SQL for the migration's down direction. + SQLDown string + + // SQLUp is the s SQL for the migration's up direction. + SQLUp string + + // Version is the integer version number of this migration. Version int - Up string - Down string } //nolint:gochecknoglobals @@ -55,7 +60,7 @@ type Migrator[TTx any] struct { baseservice.BaseService driver riverdriver.Driver[TTx] - migrations map[int]*migrationBundle // allows us to inject test migrations + migrations map[int]*Migration // allows us to inject test migrations } // New returns a new migrator with the given database driver and configuration. @@ -107,6 +112,8 @@ func New[TTx any](driver riverdriver.Driver[TTx], config *Config) *Migrator[TTx] // MigrateOpts are options for a migrate operation. type MigrateOpts struct { + DryRun bool + // MaxSteps is the maximum number of migrations to apply either up or down. // When migrating in the up direction, migrates an unlimited number of steps // by default. When migrating in the down direction, migrates only a single @@ -143,10 +150,17 @@ type MigrateResult struct { // MigrateVersion is the result for a single applied migration. type MigrateVersion struct { + // Duration is the amount of time it took to apply the migration. + Duration time.Duration + + // SQL is the SQL that was applied along with the migration. + SQL string + // Version is the version of the migration applied. Version int } +func migrationToInt(migration Migration) int { return migration.Version } func migrateVersionToInt(version MigrateVersion) int { return version.Version } type Direction string @@ -156,6 +170,26 @@ const ( DirectionUp Direction = "up" ) +// AllVersions gets information on all known migration versions. +func (m *Migrator[TTx]) AllVersions() []Migration { + migrations := maputil.Values(m.migrations) + slices.SortFunc(migrations, func(v1, v2 *Migration) int { return v1.Version - v2.Version }) + return sliceutil.Map(migrations, func(m *Migration) Migration { return *m }) +} + +// GetVersion gets information about a specific migration version. An error is +// returned if a versions is requested that doesn't exist. +func (m *Migrator[TTx]) GetVersion(version int) (Migration, error) { + migration, ok := m.migrations[version] + if !ok { + availableVersions := maputil.Keys(m.migrations) + slices.Sort(availableVersions) + return Migration{}, fmt.Errorf("migration %d not found (available versions: %v)", version, availableVersions) + } + + return *migration, nil +} + // Migrate migrates the database in the given direction (up or down). The opts // parameter may be omitted for convenience. 
// @@ -171,12 +205,12 @@ const ( // // handle error // } func (m *Migrator[TTx]) Migrate(ctx context.Context, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { - return dbutil.WithTxV(ctx, m.driver.GetExecutor(), func(ctx context.Context, tx riverdriver.ExecutorTx) (*MigrateResult, error) { + return dbutil.WithTxV(ctx, m.driver.GetExecutor(), func(ctx context.Context, exec riverdriver.ExecutorTx) (*MigrateResult, error) { switch direction { case DirectionDown: - return m.migrateDown(ctx, tx, direction, opts) + return m.migrateDown(ctx, exec, direction, opts) case DirectionUp: - return m.migrateUp(ctx, tx, direction, opts) + return m.migrateUp(ctx, exec, direction, opts) } panic("invalid direction: " + direction) @@ -257,7 +291,7 @@ func (m *Migrator[TTx]) migrateDown(ctx context.Context, exec riverdriver.Execut } sortedTargetMigrations := maputil.Values(targetMigrations) - slices.SortFunc(sortedTargetMigrations, func(a, b *migrationBundle) int { return b.Version - a.Version }) // reverse order + slices.SortFunc(sortedTargetMigrations, func(a, b *Migration) int { return b.Version - a.Version }) // reverse order res, err := m.applyMigrations(ctx, exec, direction, opts, sortedTargetMigrations) if err != nil { @@ -278,8 +312,10 @@ func (m *Migrator[TTx]) migrateDown(ctx context.Context, exec riverdriver.Execut return res, nil } - if _, err := exec.MigrationDeleteByVersionMany(ctx, sliceutil.Map(res.Versions, migrateVersionToInt)); err != nil { - return nil, fmt.Errorf("error deleting migration rows for versions %+v: %w", res.Versions, err) + if !opts.DryRun { + if _, err := exec.MigrationDeleteByVersionMany(ctx, sliceutil.Map(res.Versions, migrateVersionToInt)); err != nil { + return nil, fmt.Errorf("error deleting migration rows for versions %+v: %w", res.Versions, err) + } } return res, nil @@ -298,15 +334,17 @@ func (m *Migrator[TTx]) migrateUp(ctx context.Context, exec riverdriver.Executor } sortedTargetMigrations := maputil.Values(targetMigrations) - slices.SortFunc(sortedTargetMigrations, func(a, b *migrationBundle) int { return a.Version - b.Version }) + slices.SortFunc(sortedTargetMigrations, func(a, b *Migration) int { return a.Version - b.Version }) res, err := m.applyMigrations(ctx, exec, direction, opts, sortedTargetMigrations) if err != nil { return nil, err } - if _, err := exec.MigrationInsertMany(ctx, sliceutil.Map(res.Versions, migrateVersionToInt)); err != nil { - return nil, fmt.Errorf("error inserting migration rows for versions %+v: %w", res.Versions, err) + if opts == nil || !opts.DryRun { + if _, err := exec.MigrationInsertMany(ctx, sliceutil.Map(res.Versions, migrateVersionToInt)); err != nil { + return nil, fmt.Errorf("error inserting migration rows for versions %+v: %w", res.Versions, err) + } } return res, nil @@ -341,7 +379,7 @@ func (m *Migrator[TTx]) validate(ctx context.Context, exec riverdriver.Executor) // Common code shared between the up and down migration directions that walks // through each target migration and applies it, logging appropriately. 
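+// When opts.DryRun is set, each step's SQL is still collected into the
+// returned result, but nothing is executed against the database.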
-func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts, sortedTargetMigrations []*migrationBundle) (*MigrateResult, error) { +func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Executor, direction Direction, opts *MigrateOpts, sortedTargetMigrations []*Migration) (*MigrateResult, error) { if opts == nil { opts = &MigrateOpts{} } @@ -356,7 +394,7 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Ex switch { case maxSteps < 0: - sortedTargetMigrations = []*migrationBundle{} + sortedTargetMigrations = []*Migration{} case maxSteps > 0: sortedTargetMigrations = sortedTargetMigrations[0:min(maxSteps, len(sortedTargetMigrations))] } @@ -366,7 +404,7 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Ex return nil, fmt.Errorf("version %d is not a valid River migration version", opts.TargetVersion) } - targetIndex := slices.IndexFunc(sortedTargetMigrations, func(b *migrationBundle) bool { return b.Version == opts.TargetVersion }) + targetIndex := slices.IndexFunc(sortedTargetMigrations, func(b *Migration) bool { return b.Version == opts.TargetVersion }) if targetIndex == -1 { return nil, fmt.Errorf("version %d is not in target list of valid migrations to apply", opts.TargetVersion) } @@ -391,28 +429,34 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Ex } for _, versionBundle := range sortedTargetMigrations { - sql := versionBundle.Up - if direction == DirectionDown { - sql = versionBundle.Down + var sql string + switch direction { + case DirectionDown: + sql = versionBundle.SQLDown + case DirectionUp: + sql = versionBundle.SQLUp } - m.Logger.InfoContext(ctx, fmt.Sprintf(m.Name+": Applying migration %03d [%s]", versionBundle.Version, strings.ToUpper(string(direction))), - slog.String("direction", string(direction)), - slog.Int("version", versionBundle.Version), - ) + var duration time.Duration - _, err := exec.Exec(ctx, sql) - if err != nil { - return nil, fmt.Errorf("error applying version %03d [%s]: %w", - versionBundle.Version, strings.ToUpper(string(direction)), err) + if !opts.DryRun { + start := time.Now() + _, err := exec.Exec(ctx, sql) + if err != nil { + return nil, fmt.Errorf("error applying version %03d [%s]: %w", + versionBundle.Version, strings.ToUpper(string(direction)), err) + } + duration = time.Since(start) } - res.Versions = append(res.Versions, MigrateVersion{Version: versionBundle.Version}) - } + m.Logger.InfoContext(ctx, m.Name+": Applied migration", + slog.String("direction", string(direction)), + slog.Bool("dry_run", opts.DryRun), + slog.Duration("duration", duration), + slog.Int("version", versionBundle.Version), + ) - // Only prints if more steps than available were requested. - if opts.MaxSteps > 0 && len(res.Versions) < opts.MaxSteps { - m.Logger.InfoContext(ctx, m.Name+": No more migrations to apply") + res.Versions = append(res.Versions, MigrateVersion{Duration: duration, SQL: sql, Version: versionBundle.Version}) } return res, nil @@ -443,12 +487,12 @@ func (m *Migrator[TTx]) existingMigrations(ctx context.Context, exec riverdriver // Reads a series of migration bundles from a file system, which practically // speaking will always be the embedded FS read from the contents of the // `migration/` subdirectory. 
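+// Files are expected to carry a numeric version prefix (001, 002, ...) and to
+// end in `.up.sql` or `.down.sql`; both directions are folded into a single
+// Migration per version.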
-func migrationsFromFS(migrationFS fs.FS) ([]*migrationBundle, error) { +func migrationsFromFS(migrationFS fs.FS) ([]*Migration, error) { const subdir = "migration" var ( - bundles []*migrationBundle - lastBundle *migrationBundle + bundles []*Migration + lastBundle *Migration ) err := fs.WalkDir(migrationFS, subdir, func(path string, entry fs.DirEntry, err error) error { @@ -478,7 +522,7 @@ func migrationsFromFS(migrationFS fs.FS) ([]*migrationBundle, error) { // This works because `fs.WalkDir` guarantees lexical order, so all 001* // files always appear before all 002* files, etc. if lastBundle == nil || lastBundle.Version != version { - lastBundle = &migrationBundle{Version: version} + lastBundle = &Migration{Version: version} bundles = append(bundles, lastBundle) } @@ -494,9 +538,9 @@ func migrationsFromFS(migrationFS fs.FS) ([]*migrationBundle, error) { switch { case strings.HasSuffix(name, ".down.sql"): - lastBundle.Down = string(contents) + lastBundle.SQLDown = string(contents) case strings.HasSuffix(name, ".up.sql"): - lastBundle.Up = string(contents) + lastBundle.SQLUp = string(contents) default: return fmt.Errorf("file %q should end with either '.down.sql' or '.up.sql'", name) } @@ -511,7 +555,7 @@ func migrationsFromFS(migrationFS fs.FS) ([]*migrationBundle, error) { } // Same as the above, but for convenience, panics on an error. -func mustMigrationsFromFS(migrationFS fs.FS) []*migrationBundle { +func mustMigrationsFromFS(migrationFS fs.FS) []*Migration { bundles, err := migrationsFromFS(migrationFS) if err != nil { panic(err) @@ -523,15 +567,15 @@ func mustMigrationsFromFS(migrationFS fs.FS) []*migrationBundle { // of configuration problems as new migrations are introduced. e.g. Checks for // missing fields or accidentally duplicated version numbers from copy/pasta // problems. -func validateAndInit(versions []*migrationBundle) map[int]*migrationBundle { +func validateAndInit(versions []*Migration) map[int]*Migration { lastVersion := 0 - migrations := make(map[int]*migrationBundle, len(versions)) + migrations := make(map[int]*Migration, len(versions)) for _, versionBundle := range versions { - if versionBundle.Down == "" { + if versionBundle.SQLDown == "" { panic(fmt.Sprintf("version bundle should specify Down: %+v", versionBundle)) } - if versionBundle.Up == "" { + if versionBundle.SQLUp == "" { panic(fmt.Sprintf("version bundle should specify Up: %+v", versionBundle)) } if versionBundle.Version == 0 { diff --git a/rivermigrate/river_migrate_test.go b/rivermigrate/river_migrate_test.go index ef99434f..82e2d5f8 100644 --- a/rivermigrate/river_migrate_test.go +++ b/rivermigrate/river_migrate_test.go @@ -30,16 +30,16 @@ var ( // numbers so that the tests don't break anytime we add a new one. 
riverMigrationsMaxVersion = riverMigrations[len(riverMigrations)-1].Version - testVersions = []*migrationBundle{ + testVersions = []*Migration{ { Version: riverMigrationsMaxVersion + 1, - Up: "CREATE TABLE test_table(id bigserial PRIMARY KEY);", - Down: "DROP TABLE test_table;", + SQLUp: "CREATE TABLE test_table(id bigserial PRIMARY KEY);", + SQLDown: "DROP TABLE test_table;", }, { Version: riverMigrationsMaxVersion + 2, - Up: "ALTER TABLE test_table ADD COLUMN name varchar(200); CREATE INDEX idx_test_table_name ON test_table(name);", - Down: "DROP INDEX idx_test_table_name; ALTER TABLE test_table DROP COLUMN name;", + SQLUp: "ALTER TABLE test_table ADD COLUMN name varchar(200); CREATE INDEX idx_test_table_name ON test_table(name);", + SQLDown: "DROP INDEX idx_test_table_name; ALTER TABLE test_table DROP COLUMN name;", }, } @@ -107,6 +107,15 @@ func TestMigrator(t *testing.T) { return migrator, tx } + t.Run("AllVersions", func(t *testing.T) { + t.Parallel() + + migrator, _ := setup(t) + + migrations := migrator.AllVersions() + require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion), sliceutil.Map(migrations, migrationToInt)) + }) + t.Run("MigrateDownDefault", func(t *testing.T) { t.Parallel() @@ -165,7 +174,7 @@ func TestMigrator(t *testing.T) { migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion-2), - sliceutil.Map(migrations, migrationToInt)) + sliceutil.Map(migrations, driverMigrationToInt)) err = dbExecError(ctx, bundle.driver.UnwrapExecutor(bundle.tx), "SELECT name FROM test_table") require.Error(t, err) @@ -186,7 +195,7 @@ func TestMigrator(t *testing.T) { migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(3), - sliceutil.Map(migrations, migrationToInt)) + sliceutil.Map(migrations, driverMigrationToInt)) }) t.Run("MigrateDownWithDatabaseSQLDriver", func(t *testing.T) { @@ -202,7 +211,7 @@ func TestMigrator(t *testing.T) { migrations, err := migrator.driver.UnwrapExecutor(tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(2), - sliceutil.Map(migrations, migrationToInt)) + sliceutil.Map(migrations, driverMigrationToInt)) }) t.Run("MigrateDownWithTargetVersion", func(t *testing.T) { @@ -221,7 +230,7 @@ func TestMigrator(t *testing.T) { migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(3), - sliceutil.Map(migrations, migrationToInt)) + sliceutil.Map(migrations, driverMigrationToInt)) err = dbExecError(ctx, bundle.driver.UnwrapExecutor(bundle.tx), "SELECT name FROM test_table") require.Error(t, err) @@ -262,6 +271,45 @@ func TestMigrator(t *testing.T) { } }) + t.Run("MigrateDownDryRun", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + _, err := migrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{}) + require.NoError(t, err) + + res, err := migrator.MigrateTx(ctx, bundle.tx, DirectionDown, &MigrateOpts{DryRun: true}) + require.NoError(t, err) + require.Equal(t, []int{riverMigrationsWithTestVersionsMaxVersion}, sliceutil.Map(res.Versions, migrateVersionToInt)) + + // Migrate down returned a result above for a migration that was + // removed, but because we're in a dry run, the database still shows + // this version. 
+ migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) + require.NoError(t, err) + require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion), + sliceutil.Map(migrations, driverMigrationToInt)) + }) + + t.Run("GetVersion", func(t *testing.T) { + t.Parallel() + + migrator, _ := setup(t) + + { + migrateVersion, err := migrator.GetVersion(riverMigrationsWithTestVersionsMaxVersion) + require.NoError(t, err) + require.Equal(t, riverMigrationsWithTestVersionsMaxVersion, migrateVersion.Version) + } + + { + _, err := migrator.GetVersion(99_999) + availableVersions := seqOneTo(riverMigrationsWithTestVersionsMaxVersion) + require.EqualError(t, err, fmt.Sprintf("migration %d not found (available versions: %v)", 99_999, availableVersions)) + } + }) + t.Run("MigrateNilOpts", func(t *testing.T) { t.Parallel() @@ -288,7 +336,7 @@ func TestMigrator(t *testing.T) { migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion), - sliceutil.Map(migrations, migrationToInt)) + sliceutil.Map(migrations, driverMigrationToInt)) _, err = bundle.tx.Exec(ctx, "SELECT * FROM test_table") require.NoError(t, err) @@ -304,7 +352,7 @@ func TestMigrator(t *testing.T) { migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion), - sliceutil.Map(migrations, migrationToInt)) + sliceutil.Map(migrations, driverMigrationToInt)) _, err = bundle.tx.Exec(ctx, "SELECT * FROM test_table") require.NoError(t, err) @@ -324,7 +372,7 @@ func TestMigrator(t *testing.T) { migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion-1), - sliceutil.Map(migrations, migrationToInt)) + sliceutil.Map(migrations, driverMigrationToInt)) // Column `name` is only added in the second test version. 
err = dbExecError(ctx, bundle.driver.UnwrapExecutor(bundle.tx), "SELECT name FROM test_table") @@ -350,7 +398,7 @@ func TestMigrator(t *testing.T) { migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(3), - sliceutil.Map(migrations, migrationToInt)) + sliceutil.Map(migrations, driverMigrationToInt)) }) t.Run("MigrateUpWithDatabaseSQLDriver", func(t *testing.T) { @@ -366,7 +414,7 @@ func TestMigrator(t *testing.T) { migrations, err := migrator.driver.UnwrapExecutor(tx).MigrationGetAll(ctx) require.NoError(t, err) require.Equal(t, seqOneTo(riverMigrationsMaxVersion+1), - sliceutil.Map(migrations, migrationToInt)) + sliceutil.Map(migrations, driverMigrationToInt)) }) t.Run("MigrateUpWithTargetVersion", func(t *testing.T) { @@ -381,7 +429,7 @@ func TestMigrator(t *testing.T) { migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) require.NoError(t, err) - require.Equal(t, seqOneTo(5), sliceutil.Map(migrations, migrationToInt)) + require.Equal(t, seqOneTo(5), sliceutil.Map(migrations, driverMigrationToInt)) }) t.Run("MigrateUpWithTargetVersionInvalid", func(t *testing.T) { @@ -402,6 +450,26 @@ func TestMigrator(t *testing.T) { } }) + t.Run("MigrateUpDryRun", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + res, err := migrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{DryRun: true}) + require.NoError(t, err) + require.Equal(t, DirectionUp, res.Direction) + require.Equal(t, []int{riverMigrationsWithTestVersionsMaxVersion - 1, riverMigrationsWithTestVersionsMaxVersion}, + sliceutil.Map(res.Versions, migrateVersionToInt)) + + // Migrate up returned a result above for migrations that were applied, + // but because we're in a dry run, the database still shows the test + // migration versions not applied. + migrations, err := bundle.driver.UnwrapExecutor(bundle.tx).MigrationGetAll(ctx) + require.NoError(t, err) + require.Equal(t, seqOneTo(riverMigrationsMaxVersion), + sliceutil.Map(migrations, driverMigrationToInt)) + }) + t.Run("ValidateSuccess", func(t *testing.T) { t.Parallel() @@ -439,7 +507,7 @@ func dbExecError(ctx context.Context, exec riverdriver.Executor, sql string) err }) } -func migrationToInt(r *riverdriver.Migration) int { return r.Version } +func driverMigrationToInt(r *riverdriver.Migration) int { return r.Version } func seqOneTo(max int) []int { seq := make([]int, max)