diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go index 38fc3522..09f259bd 100644 --- a/internal/cmd/backup.go +++ b/internal/cmd/backup.go @@ -29,18 +29,30 @@ var backupCmd = &cobra.Command{ RunE: backupCmdFunc, } -func backupCmdFunc(cmd *cobra.Command, args []string) error { - filename := args[0] +func createBackupFile(filename string) (*os.File, error) { + if filename == "-" { + log.Trace().Str("filename", "- (stdout)").Send() + return os.Stdout, nil + } log.Trace().Str("filename", filename).Send() if _, err := os.Stat(filename); err == nil { - return fmt.Errorf("backup file already exists: %s", filename) + return nil, fmt.Errorf("backup file already exists: %s", filename) } f, err := os.Create(filename) if err != nil { - return fmt.Errorf("unable to create backup file: %w", err) + return nil, fmt.Errorf("unable to create backup file: %w", err) + } + + return f, nil +} + +func backupCmdFunc(cmd *cobra.Command, args []string) error { + f, err := createBackupFile(args[0]) + if err != nil { + return err } client, err := client.NewClient(cmd) @@ -116,8 +128,10 @@ func backupCmdFunc(cmd *cobra.Command, args []string) error { return fmt.Errorf("error closing backup encoder: %w", err) } - if err := f.Sync(); err != nil { - return fmt.Errorf("error syncing backup file: %w", err) + if f != os.Stdout { + if err := f.Sync(); err != nil { + return fmt.Errorf("error syncing backup file: %w", err) + } } if err := f.Close(); err != nil { diff --git a/internal/cmd/restore.go b/internal/cmd/restore.go index 734e424a..82c212d3 100644 --- a/internal/cmd/restore.go +++ b/internal/cmd/restore.go @@ -28,23 +28,41 @@ func registerRestoreCmd(rootCmd *cobra.Command) { var restoreCmd = &cobra.Command{ Use: "restore ", Short: "Restore a permission system from a file", - Args: cobra.ExactArgs(1), + Args: cobra.MaximumNArgs(1), RunE: restoreCmdFunc, } -func restoreCmdFunc(cmd *cobra.Command, args []string) error { - filename := args[0] +func openRestoreFile(filename string) (*os.File, int64, error) { + if filename == "" { + log.Trace().Str("filename", "(stdin)").Send() + return os.Stdin, -1, nil + } log.Trace().Str("filename", filename).Send() stats, err := os.Stat(filename) if err != nil { - return fmt.Errorf("unable to stat restore file: %w", err) + return nil, 0, fmt.Errorf("unable to stat restore file: %w", err) } f, err := os.Open(filename) if err != nil { - return fmt.Errorf("unable to open restore file: %w", err) + return nil, 0, fmt.Errorf("unable to open restore file: %w", err) + } + + return f, stats.Size(), nil +} + +func restoreCmdFunc(cmd *cobra.Command, args []string) error { + filename := "" // Default to stdin. + + if len(args) > 0 { + filename = args[0] + } + + f, fSize, err := openRestoreFile(filename) + if err != nil { + return err } printZTOnly := cobrautil.MustGetBool(cmd, "print-zedtoken-only") @@ -52,7 +70,7 @@ func restoreCmdFunc(cmd *cobra.Command, args []string) error { var hasProgressbar bool var restoreReader io.Reader = f if isatty.IsTerminal(os.Stderr.Fd()) && !printZTOnly { - bar := progressbar.DefaultBytes(stats.Size(), "restoring") + bar := progressbar.DefaultBytes(fSize, "restoring") restoreReader = io.TeeReader(f, bar) hasProgressbar = true }