Skip to content

Commit

Permalink
Merge pull request #316 from authzed/backup-import-with-retries-and-c…
Browse files Browse the repository at this point in the history
…onflict-handling

refactor backup restore to handle serialization errors and conflicts
  • Loading branch information
vroldanbet authored Jan 2, 2024
2 parents 3be450c + 461353e commit 4fb2ba0
Show file tree
Hide file tree
Showing 9 changed files with 821 additions and 165 deletions.
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/authzed/grpcutil v0.0.0-20230908193239-4286bb1d6403
github.com/authzed/spicedb v1.28.1-0.20231208003000-90be4e6762da
github.com/brianvoe/gofakeit/v6 v6.26.3
github.com/cenkalti/backoff/v4 v4.2.1
github.com/charmbracelet/lipgloss v0.9.1
github.com/google/uuid v1.4.0
github.com/gookit/color v1.5.4
Expand All @@ -28,6 +29,7 @@ require (
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.8.4
github.com/xlab/treeprint v1.2.0
golang.org/x/exp v0.0.0-20231206192017-f3f8817b8deb
golang.org/x/mod v0.14.0
golang.org/x/sync v0.5.0
golang.org/x/term v0.15.0
Expand Down Expand Up @@ -57,7 +59,6 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/bits-and-blooms/bitset v1.10.0 // indirect
github.com/bits-and-blooms/bloom/v3 v3.6.0 // indirect
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect
github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
Expand Down Expand Up @@ -189,7 +190,6 @@ require (
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/exp v0.0.0-20231206192017-f3f8817b8deb // indirect
golang.org/x/net v0.19.0 // indirect
golang.org/x/oauth2 v0.15.0 // indirect
golang.org/x/sys v0.15.0 // indirect
Expand Down
161 changes: 47 additions & 114 deletions internal/cmd/backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"github.com/rs/zerolog/log"
"github.com/schollz/progressbar/v3"
"github.com/spf13/cobra"
"golang.org/x/exp/constraints"
"golang.org/x/exp/maps"

"github.com/authzed/zed/internal/client"
"github.com/authzed/zed/internal/commands"
Expand Down Expand Up @@ -79,25 +81,24 @@ var (

func registerBackupCmd(rootCmd *cobra.Command) {
rootCmd.AddCommand(backupCmd)
registerBackupCreateFlags(backupCmd)

backupCmd.AddCommand(backupCreateCmd)
backupCreateCmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix")
backupCreateCmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax")
registerBackupCreateFlags(backupCreateCmd)

backupCmd.AddCommand(backupRestoreCmd)
backupRestoreCmd.Flags().Int("batch-size", 1_000, "restore relationship write batch size")
backupRestoreCmd.Flags().Int64("batches-per-transaction", 10, "number of batches per transaction")
backupRestoreCmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix")
backupRestoreCmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax")
registerBackupRestoreFlags(backupRestoreCmd)

// Restore used to be on the root, so add it there too, but hidden.
rootCmd.AddCommand(&cobra.Command{
restoreCmd := &cobra.Command{
Use: "restore <filename>",
Short: "Restore a permission system from a file",
Args: cobra.MaximumNArgs(1),
RunE: backupRestoreCmdFunc,
Hidden: true,
})
}
rootCmd.AddCommand(restoreCmd)
registerBackupRestoreFlags(restoreCmd)

backupCmd.AddCommand(backupParseSchemaCmd)
backupParseSchemaCmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix")
Expand All @@ -108,6 +109,21 @@ func registerBackupCmd(rootCmd *cobra.Command) {
backupParseRelsCmd.Flags().String("prefix-filter", "", "Include only relationships with a given prefix")
}

func registerBackupRestoreFlags(cmd *cobra.Command) {
cmd.Flags().Int("batch-size", 1_000, "restore relationship write batch size")
cmd.Flags().Uint("batches-per-transaction", 10, "number of batches per transaction")
cmd.Flags().String("conflict-strategy", "fail", "strategy used when a conflicting relationship is found. Possible values: fail, skip, touch")
cmd.Flags().Bool("disable-retries", false, "retries when an errors is determined to be retryable (e.g. serialization errors)")
cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix")
cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax")
cmd.Flags().Duration("request-timeout", 30*time.Second, "timeout for each request performed during restore")
}

func registerBackupCreateFlags(cmd *cobra.Command) {
cmd.Flags().String("prefix-filter", "", "include only schema and relationships with a given prefix")
cmd.Flags().Bool("rewrite-legacy", false, "potentially modify the schema to exclude legacy/broken syntax")
}

func createBackupFile(filename string) (*os.File, error) {
if filename == "-" {
log.Trace().Str("filename", "- (stdout)").Send()
Expand Down Expand Up @@ -347,7 +363,7 @@ func openRestoreFile(filename string) (*os.File, int64, error) {
}

func backupRestoreCmdFunc(cmd *cobra.Command, args []string) error {
decoder, closer, err := decoderFromArgs(cmd, args)
decoder, closer, err := decoderFromArgs(args...)
if err != nil {
return err
}
Expand Down Expand Up @@ -377,122 +393,39 @@ func backupRestoreCmdFunc(cmd *cobra.Command, args []string) error {
}
log.Debug().Str("schema", schema).Bool("filtered", prefixFilter != "").Msg("parsed schema")

client, err := client.NewClient(cmd)
c, err := client.NewClient(cmd)
if err != nil {
return fmt.Errorf("unable to initialize client: %w", err)
}

ctx := cmd.Context()
if _, err := client.WriteSchema(ctx, &v1.WriteSchemaRequest{
Schema: schema,
}); err != nil {
return fmt.Errorf("unable to write schema: %w", err)
}

relationshipWriteStart := time.Now()

relationshipWriter, err := client.BulkImportRelationships(ctx)
if err != nil {
return fmt.Errorf("error creating writer stream: %w", err)
}

batchSize := cobrautil.MustGetInt(cmd, "batch-size")
batchesPerTransaction := cobrautil.MustGetInt64(cmd, "batches-per-transaction")

batch := make([]*v1.Relationship, 0, batchSize)
var written, batchesWritten int64
bar := relProgressBar("restoring from backup")
for rel, err := decoder.Next(); rel != nil && err == nil; rel, err = decoder.Next() {
if err := ctx.Err(); err != nil {
return fmt.Errorf("aborted restore: %w", err)
}

if !hasRelPrefix(rel, prefixFilter) {
continue
}

batch = append(batch, rel)

if len(batch)%batchSize == 0 {
if err := relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{
Relationships: batch,
}); err != nil {
_, closeErr := relationshipWriter.CloseAndRecv()
return fmt.Errorf("error sending batch to server: %w", errors.Join(err, closeErr))
}

// Reset the relationships in the batch
batch = batch[:0]
batchesWritten++
batchesPerTransaction := cobrautil.MustGetUint(cmd, "batches-per-transaction")

if batchesWritten%batchesPerTransaction == 0 {
resp, err := relationshipWriter.CloseAndRecv()
if err != nil {
return fmt.Errorf("error finalizing write of %d batches: %w", batchesPerTransaction, err)
}
written += int64(resp.NumLoaded)
if err := bar.Set64(written); err != nil {
return fmt.Errorf("error incrementing progress bar: %w", err)
}

if !isatty.IsTerminal(os.Stderr.Fd()) {
log.Trace().
Int64("batches", batchesWritten).
Int64("relationships", written).
Msg("restore progress")
}

relationshipWriter, err = client.BulkImportRelationships(ctx)
if err != nil {
return fmt.Errorf("error creating new writer stream: %w", err)
}
}
}
}

// Write the last batch
if err := relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{
Relationships: batch,
}); err != nil {
return fmt.Errorf("error sending last batch to server: %w", err)
}

// Finish the stream
resp, err := relationshipWriter.CloseAndRecv()
strategy, err := GetEnum[ConflictStrategy](cmd, "conflict-strategy", conflictStrategyMapping)
if err != nil {
return fmt.Errorf("error finalizing last write: %w", err)
}
batchesWritten++
written += int64(resp.NumLoaded)
if err := bar.Set64(written); err != nil {
return fmt.Errorf("error incrementing progress bar: %w", err)
}
totalTime := time.Since(relationshipWriteStart)

if err := bar.Finish(); err != nil {
return fmt.Errorf("error finalizing progress bar: %w", err)
return err
}
disableRetries := cobrautil.MustGetBool(cmd, "disable-retries")
requestTimeout := cobrautil.MustGetDuration(cmd, "request-timeout")

log.Info().
Int64("batches", batchesWritten).
Int64("relationships", written).
Uint64("perSecond", perSec(uint64(written), totalTime)).
Stringer("duration", totalTime).
Msg("finished restore")

return nil
return newRestorer(schema, decoder, c, prefixFilter, batchSize, batchesPerTransaction, strategy,
disableRetries, requestTimeout).restoreFromDecoder(cmd.Context())
}

func perSec(i uint64, d time.Duration) uint64 {
secs := uint64(d.Seconds())
if secs == 0 {
return i
// GetEnum is a helper for getting an enum value from a string cobra flag.
func GetEnum[E constraints.Integer](cmd *cobra.Command, name string, mapping map[string]E) (E, error) {
value := cobrautil.MustGetString(cmd, name)
value = strings.TrimSpace(strings.ToLower(value))
if enum, ok := mapping[value]; ok {
return enum, nil
}
return i / secs

var zeroValueE E
return zeroValueE, fmt.Errorf("unexpected flag '%s' value '%s': should be one of %v", name, value, maps.Keys(mapping))
}

func backupParseSchemaCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error {
decoder, closer, err := decoderFromArgs(cmd, args)
decoder, closer, err := decoderFromArgs(args...)
if err != nil {
return err
}
Expand Down Expand Up @@ -520,8 +453,8 @@ func backupParseSchemaCmdFunc(cmd *cobra.Command, out io.Writer, args []string)
return err
}

func backupParseRevisionCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error {
decoder, closer, err := decoderFromArgs(cmd, args)
func backupParseRevisionCmdFunc(_ *cobra.Command, out io.Writer, args []string) error {
decoder, closer, err := decoderFromArgs(args...)
if err != nil {
return err
}
Expand All @@ -540,7 +473,7 @@ func backupParseRevisionCmdFunc(cmd *cobra.Command, out io.Writer, args []string

func backupParseRelsCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error {
prefix := cobrautil.MustGetString(cmd, "prefix-filter")
decoder, closer, err := decoderFromArgs(cmd, args)
decoder, closer, err := decoderFromArgs(args...)
if err != nil {
return err
}
Expand All @@ -566,7 +499,7 @@ func backupParseRelsCmdFunc(cmd *cobra.Command, out io.Writer, args []string) er
return nil
}

func decoderFromArgs(_ *cobra.Command, args []string) (*backupformat.Decoder, io.Closer, error) {
func decoderFromArgs(args ...string) (*backupformat.Decoder, io.Closer, error) {
filename := "" // Default to stdin.
if len(args) > 0 {
filename = args[0]
Expand Down
13 changes: 11 additions & 2 deletions internal/cmd/backup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"path/filepath"
"strings"
"testing"
"time"

"github.com/authzed/zed/internal/client"

Expand Down Expand Up @@ -309,7 +310,7 @@ func TestBackupCreateCmdFunc(t *testing.T) {
err = backupCreateCmdFunc(cmd, []string{f})
require.NoError(t, err)

d, closer, err := decoderFromArgs(cmd, []string{f})
d, closer, err := decoderFromArgs(f)
require.NoError(t, err)
defer func() {
_ = d.Close()
Expand All @@ -323,12 +324,20 @@ func TestBackupCreateCmdFunc(t *testing.T) {
require.Equal(t, resp.WrittenAt.Token, d.ZedToken().Token)
}

type durationFlag struct {
flagName string
flagValue time.Duration
}

func TestBackupRestoreCmdFunc(t *testing.T) {
cmd := createTestCobraCommandWithFlagValue(t,
stringFlag{"prefix-filter", "test"},
boolFlag{"rewrite-legacy", false},
stringFlag{"conflict-strategy", "fail"},
boolFlag{"disable-retries", false},
intFlag{"batch-size", 100},
int64Flag{"batches-per-transaction", 10},
uintFlag{"batches-per-transaction", 10},
durationFlag{"request-timeout", 0},
)
backupName := createTestBackup(t, testSchema, testRelationships)

Expand Down
18 changes: 17 additions & 1 deletion internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,24 @@ package cmd

import (
"context"
"errors"
"os"
"os/signal"
"syscall"

"github.com/jzelinskie/cobrautil/v2"
"github.com/jzelinskie/cobrautil/v2/cobrazerolog"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"

"github.com/authzed/zed/internal/commands"
)

var SyncFlagsCmdFunc = cobrautil.SyncViperPreRunE("ZED")
var (
SyncFlagsCmdFunc = cobrautil.SyncViperPreRunE("ZED")
errParsing = errors.New("parsing error")
)

func Run() {
zl := cobrazerolog.New(cobrazerolog.WithPreRunLevel(zerolog.DebugLevel))
Expand All @@ -27,7 +32,14 @@ func Run() {
zl.RunE(),
SyncFlagsCmdFunc,
),
SilenceErrors: true,
SilenceUsage: true,
}
rootCmd.SetFlagErrorFunc(func(cmd *cobra.Command, err error) error {
cmd.Println(err)
cmd.Println(cmd.UsageString())
return errParsing
})

zl.RegisterFlags(rootCmd.PersistentFlags())

Expand Down Expand Up @@ -91,6 +103,10 @@ func Run() {
}()

if err := rootCmd.ExecuteContext(ctx); err != nil {
if !errors.Is(err, errParsing) {
log.Err(err).Msg("terminated with errors")
}

os.Exit(1)
}
}
10 changes: 6 additions & 4 deletions internal/cmd/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ type intFlag struct {
flagValue int
}

type int64Flag struct {
type uintFlag struct {
flagName string
flagValue int64
flagValue uint
}

func createTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *cobra.Command {
Expand All @@ -81,8 +81,10 @@ func createTestCobraCommandWithFlagValue(t *testing.T, flagAndValues ...any) *co
c.Flags().Bool(f.flagName, f.flagValue, "")
case intFlag:
c.Flags().Int(f.flagName, f.flagValue, "")
case int64Flag:
c.Flags().Int64(f.flagName, f.flagValue, "")
case uintFlag:
c.Flags().Uint(f.flagName, f.flagValue, "")
case durationFlag:
c.Flags().Duration(f.flagName, f.flagValue, "")
default:
t.Fatalf("unknown flag type: %T", f)
}
Expand Down
Loading

0 comments on commit 4fb2ba0

Please sign in to comment.