From 068d9146f36b7cafec1e1a970e3ad18c14ef95ba Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=ADctor=20Rold=C3=A1n=20Betancort?=
Date: Tue, 19 Dec 2023 10:31:48 +0000
Subject: [PATCH] introduce restorer tests and adjust implementation accordingly

---
 internal/cmd/backup.go        |  12 +-
 internal/cmd/backup_test.go   |   2 +-
 internal/cmd/restorer.go      | 111 ++++++++++------
 internal/cmd/restorer_test.go | 242 ++++++++++++++++++++++++++++++++++
 4 files changed, 318 insertions(+), 49 deletions(-)
 create mode 100644 internal/cmd/restorer_test.go

diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go
index 9dce090b..add94092 100644
--- a/internal/cmd/backup.go
+++ b/internal/cmd/backup.go
@@ -350,7 +350,7 @@ func openRestoreFile(filename string) (*os.File, int64, error) {
 }
 
 func backupRestoreCmdFunc(cmd *cobra.Command, args []string) error {
-	decoder, closer, err := decoderFromArgs(cmd, args)
+	decoder, closer, err := decoderFromArgs(args...)
 	if err != nil {
 		return err
 	}
@@ -406,7 +406,7 @@ func backupRestoreCmdFunc(cmd *cobra.Command, args []string) error {
 }
 
 func backupParseSchemaCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error {
-	decoder, closer, err := decoderFromArgs(cmd, args)
+	decoder, closer, err := decoderFromArgs(args...)
 	if err != nil {
 		return err
 	}
@@ -434,8 +434,8 @@ func backupParseSchemaCmdFunc(cmd *cobra.Command, out io.Writer, args []string)
 	return err
 }
 
-func backupParseRevisionCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error {
-	decoder, closer, err := decoderFromArgs(cmd, args)
+func backupParseRevisionCmdFunc(_ *cobra.Command, out io.Writer, args []string) error {
+	decoder, closer, err := decoderFromArgs(args...)
 	if err != nil {
 		return err
 	}
@@ -454,7 +454,7 @@ func backupParseRevisionCmdFunc(cmd *cobra.Command, out io.Writer, args []string
 
 func backupParseRelsCmdFunc(cmd *cobra.Command, out io.Writer, args []string) error {
 	prefix := cobrautil.MustGetString(cmd, "prefix-filter")
-	decoder, closer, err := decoderFromArgs(cmd, args)
+	decoder, closer, err := decoderFromArgs(args...)
 	if err != nil {
 		return err
 	}
@@ -480,7 +480,7 @@ func backupParseRelsCmdFunc(cmd *cobra.Command, out io.Writer, args []string) er
 	return nil
 }
 
-func decoderFromArgs(_ *cobra.Command, args []string) (*backupformat.Decoder, io.Closer, error) {
+func decoderFromArgs(args ...string) (*backupformat.Decoder, io.Closer, error) {
 	filename := "" // Default to stdin.
 	if len(args) > 0 {
 		filename = args[0]
diff --git a/internal/cmd/backup_test.go b/internal/cmd/backup_test.go
index de33a791..22324dca 100644
--- a/internal/cmd/backup_test.go
+++ b/internal/cmd/backup_test.go
@@ -309,7 +309,7 @@ func TestBackupCreateCmdFunc(t *testing.T) {
 	err = backupCreateCmdFunc(cmd, []string{f})
 	require.NoError(t, err)
 
-	d, closer, err := decoderFromArgs(cmd, []string{f})
+	d, closer, err := decoderFromArgs(f)
 	require.NoError(t, err)
 	defer func() {
 		_ = d.Close()
diff --git a/internal/cmd/restorer.go b/internal/cmd/restorer.go
index 92c5ce59..35309010 100644
--- a/internal/cmd/restorer.go
+++ b/internal/cmd/restorer.go
@@ -20,6 +20,12 @@ import (
 	"github.com/authzed/zed/pkg/backupformat"
 )
 
+// FIXME temporary hack until a proper error is exposed from the API, specific to CRDB
+var (
+	txConflictCodes     = []string{"SQLSTATE 23505"}
+	retryableErrorCodes = []string{"retryable error"}
+)
+
 type restorer struct {
 	decoder      *backupformat.Decoder
 	client       client.Client
@@ -32,11 +38,14 @@ type restorer struct {
 	bar          *progressbar.ProgressBar
 
 	// stats
-	relsWritten    int64
-	batchesWritten int64
-	relsSkipped    int64
-	duplicateRels  int64
-	totalRetries   int64
+	filteredOutRels  int64
+	writtenRels      int64
+	writtenBatches   int64
+	skippedRels      int64
+	skippedBatches   int64
+	duplicateRels    int64
+	duplicateBatches int64
+	totalRetries     int64
 }
 
 func newRestorer(decoder *backupformat.Decoder, client client.Client, prefixFilter string, batchSize int,
@@ -59,7 +68,7 @@ func (r *restorer) restoreFromDecoder(ctx context.Context) error {
 	relationshipWriteStart := time.Now()
 	defer func() {
 		if err := r.bar.Finish(); err != nil {
-			log.Err(err).Msg("error finalizing progress bar")
+			log.Warn().Err(err).Msg("error finalizing progress bar")
 		}
 	}()
 
@@ -72,10 +81,12 @@ func (r *restorer) restoreFromDecoder(ctx context.Context) error {
 	batchesToBeCommitted := make([][]*v1.Relationship, 0, r.batchesPerTransaction)
 	for rel, err := r.decoder.Next(); rel != nil && err == nil; rel, err = r.decoder.Next() {
 		if err := ctx.Err(); err != nil {
+			r.bar.Describe("backup restore aborted")
 			return fmt.Errorf("aborted restore: %w", err)
 		}
 
 		if !hasRelPrefix(rel, r.prefixFilter) {
+			r.filteredOutRels++
 			continue
 		}
 
@@ -87,8 +98,6 @@ func (r *restorer) restoreFromDecoder(ctx context.Context) error {
 			Relationships: batch,
 		})
 		if err != nil {
-			r.totalRetries++
-
 			// It feels non-idiomatic to check for error and perform an operation, but in gRPC, when an element
 			// sent over the stream fails, we need to call recvAndClose() to get the error.
 			if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil {
@@ -106,12 +115,12 @@ func (r *restorer) restoreFromDecoder(ctx context.Context) error {
 			continue
 		}
 
-		// Reset the relationships in the batch. Do not reuse in case failure happens on subsequent batch in the tx
+		// The batch just sent is kept in batchesToBeCommitted, which is used for retries.
+		// Therefore, we cannot reuse the batch. Batches may fail on send, or on commit (CloseAndRecv).
 		batch = make([]*v1.Relationship, 0, r.batchSize)
-		r.batchesWritten++
 
 		// if we've sent the maximum number of batches per transaction, proceed to commit
-		if r.batchesWritten%r.batchesPerTransaction != 0 {
+		if int64(len(batchesToBeCommitted))%r.batchesPerTransaction != 0 {
 			continue
 		}
 
@@ -130,30 +139,32 @@ func (r *restorer) restoreFromDecoder(ctx context.Context) error {
 	// Write the last batch
 	if len(batch) > 0 {
-		if err := relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{
-			Relationships: batch,
-		}); err != nil {
-			return fmt.Errorf("error sending last batch to server: %w", err)
-		}
+		// Since we are going to close the stream anyway after the last batch, and given the actual error
+		// is only returned on CloseAndRecv(), we have to ignore the error here in order to get the actual
+		// underlying error that caused Send() to fail. It also gives us the opportunity to retry it
+		// in case it failed.
+		batchesToBeCommitted = append(batchesToBeCommitted, batch)
+		_ = relationshipWriter.Send(&v1.BulkImportRelationshipsRequest{Relationships: batch})
 	}
 
 	if err := r.commitStream(ctx, relationshipWriter, batchesToBeCommitted); err != nil {
 		return fmt.Errorf("error committing last set of batches: %w", err)
 	}
 
-	r.bar.Describe("complected import")
+	r.bar.Describe("completed import")
 	if err := r.bar.Finish(); err != nil {
-		log.Err(err).Msg("error finalizing progress bar")
+		log.Warn().Err(err).Msg("error finalizing progress bar")
 	}
 
 	totalTime := time.Since(relationshipWriteStart)
 	log.Info().
-		Int64("batches", r.batchesWritten).
-		Int64("relationships_loaded", r.relsWritten).
-		Int64("relationships_skipped", r.relsSkipped).
+		Int64("batches", r.writtenBatches).
+		Int64("relationships_loaded", r.writtenRels).
+		Int64("relationships_skipped", r.skippedRels).
 		Int64("duplicate_relationships", r.duplicateRels).
+		Int64("relationships_filtered_out", r.filteredOutRels).
 		Int64("retried_errors", r.totalRetries).
-		Uint64("perSecond", perSec(uint64(r.relsWritten), totalTime)).
+		Uint64("perSecond", perSec(uint64(r.writtenRels+r.skippedRels), totalTime)).
 		Stringer("duration", totalTime).
Msg("finished restore") return nil @@ -183,54 +194,63 @@ func (r *restorer) commitStream(ctx context.Context, bulkImportClient v1.Experim case retryable && r.disableRetryErrors: return err case conflict && r.skipOnConflicts: - r.relsSkipped += int64(expectedLoaded) + r.skippedRels += int64(expectedLoaded) + r.skippedBatches += int64(len(batchesToBeCommitted)) + r.duplicateBatches += int64(len(batchesToBeCommitted)) r.duplicateRels += int64(expectedLoaded) - numLoaded = expectedLoaded r.bar.Describe("skipping conflicting batch") case conflict && r.touchOnConflicts: r.bar.Describe("retrying conflicting batch") r.duplicateRels += int64(expectedLoaded) + r.duplicateBatches += int64(len(batchesToBeCommitted)) + r.totalRetries++ numLoaded, retries, err = r.writeBatchesWithRetry(ctx, batchesToBeCommitted) if err != nil { return fmt.Errorf("failed to write retried batch: %w", err) } - case conflict && !r.touchOnConflicts: + + retries++ // account for the initial attempt + r.writtenBatches += int64(len(batchesToBeCommitted)) + r.writtenRels += int64(numLoaded) + case conflict && (!r.touchOnConflicts && !r.skipOnConflicts): r.bar.Describe("conflict detected, aborting restore") return fmt.Errorf("duplicate relationships found") case retryable: r.bar.Describe("retrying after error") + r.totalRetries++ numLoaded, retries, err = r.writeBatchesWithRetry(ctx, batchesToBeCommitted) if err != nil { return fmt.Errorf("failed to write retried batch: %w", err) } + + retries++ // account for the initial attempt + r.writtenBatches += int64(len(batchesToBeCommitted)) + r.writtenRels += int64(numLoaded) default: r.bar.Describe("restoring from backup") + r.writtenBatches += int64(len(batchesToBeCommitted)) } // it was a successful transaction commit without duplicates if resp != nil { - numLoaded = resp.NumLoaded - - var expected uint64 - for _, b := range batchesToBeCommitted { - expected += uint64(len(b)) - } - - if expected != numLoaded { - log.Warn().Uint64("loaded", numLoaded).Uint64("expected", expected).Msg("unexpected number of relationships loaded") + r.writtenRels += int64(resp.NumLoaded) + if expectedLoaded != resp.NumLoaded { + log.Warn().Uint64("loaded", resp.NumLoaded).Uint64("expected", expectedLoaded).Msg("unexpected number of relationships loaded") } } - r.relsWritten += int64(numLoaded) - if err := r.bar.Set64(r.relsWritten); err != nil { + if err := r.bar.Set64(r.writtenRels + r.skippedRels); err != nil { return fmt.Errorf("error incrementing progress bar: %w", err) } if !isatty.IsTerminal(os.Stderr.Fd()) { log.Trace(). - Int64("batches_written", r.batchesWritten). - Int64("relationships_written", r.relsWritten). + Int64("batches_written", r.writtenBatches). + Int64("relationships_written", r.writtenRels). + Int64("duplicate_batches", r.duplicateBatches). Int64("duplicate_relationships", r.duplicateRels). + Int64("skipped_batches", r.skippedBatches). + Int64("skipped_relationships", r.skippedRels). Uint64("retries", retries). 
Msg("restore progress") } @@ -292,8 +312,13 @@ func isAlreadyExistsError(err error) bool { } } - // FIXME temporary hack until a proper error is exposed from the API, specific to CRDB - return strings.Contains(err.Error(), "SQLSTATE 23505") + for _, code := range txConflictCodes { + if strings.Contains(err.Error(), code) { + return true + } + } + + return false } func isRetryableError(err error) bool { @@ -301,8 +326,10 @@ func isRetryableError(err error) bool { return false } - if strings.Contains(err.Error(), "RETRY_SERIALIZABLE") { // FIXME hack until SpiceDB exposes proper typed err - return true + for _, code := range retryableErrorCodes { + if strings.Contains(err.Error(), code) { + return true + } } return false diff --git a/internal/cmd/restorer_test.go b/internal/cmd/restorer_test.go new file mode 100644 index 00000000..def2ecd1 --- /dev/null +++ b/internal/cmd/restorer_test.go @@ -0,0 +1,242 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "testing" + + v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" + "github.com/authzed/spicedb/pkg/tuple" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + + "github.com/authzed/zed/internal/client" +) + +var unrecoverableError = []string{"unrecoverable"} + +func TestRestorer(t *testing.T) { + for _, tt := range []struct { + name string + prefixFilter string + batchSize int + batchesPerTransaction int64 + skipOnConflicts bool + touchOnConflicts bool + disableRetryErrors bool + sendErrors []string + commitErrors []string + touchErrors []string + relationships []string + }{ + {"honors batch size = 1", "", 1, 1, false, false, false, nil, nil, nil, testRelationships}, + {"correctly handles remainder batch", "", 2, 1, false, false, false, nil, nil, nil, testRelationships}, + {"correctly handles batch size == total rels", "", 3, 1, false, false, false, nil, nil, nil, testRelationships}, + {"correctly handles batch size > total rels", "", 4, 1, false, false, false, nil, nil, nil, testRelationships}, + {"correctly handles empty set", "", 1, 1, false, false, false, nil, nil, nil, nil}, + {"skips conflicting writes when skipOnConflict is enabled", "", 1, 1, true, false, false, nil, txConflictCodes, nil, testRelationships}, + {"applies touch when touchOnConflict is enabled", "", 1, 1, false, true, false, nil, txConflictCodes, nil, testRelationships}, + {"skips on conflict when skipOnConflict is enabled", "", 2, 1, true, false, false, nil, txConflictCodes, nil, testRelationships}, + {"failed batches are written individually when touchOnConflict is enabled", "", 1, 2, false, true, false, nil, txConflictCodes, nil, testRelationships}, + {"fails on conflict if touchOnConflict=false && skipOnConflict=false", "", 1, 1, false, false, false, txConflictCodes, nil, nil, testRelationships}, + {"fails on unexpected commit error", "", 1, 1, false, false, false, nil, []string{"unexpected"}, nil, testRelationships}, + {"retries commit retryable errors", "", 1, 1, false, false, false, nil, retryableErrorCodes, nil, testRelationships}, + {"retries on conflict when fallback WriteRelationships fails", "", 1, 1, false, true, false, nil, txConflictCodes, retryableErrorCodes, testRelationships}, + {"returns error on retryable error if retries are disabled", "", 1, 1, false, false, true, nil, retryableErrorCodes, nil, testRelationships}, + {"fails fast if conflict-triggered touch fails with an unrecoverable error", "", 1, 1, false, true, false, nil, txConflictCodes, unrecoverableError, 
testRelationships}, + {"retries if error happens right after sending a batch over the stream", "", 1, 1, false, true, false, txConflictCodes, txConflictCodes, nil, testRelationships}, + {"filters relationships", "test", 1, 1, false, false, false, nil, nil, nil, append([]string{"foo/resource:1#reader@foo/user:1"}, testRelationships...)}, + {"handles gracefully all rels as filtered", "invalid", 1, 1, false, false, false, nil, nil, nil, testRelationships}, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + backupFileName := createTestBackup(t, testSchema, tt.relationships) + d, closer, err := decoderFromArgs(backupFileName) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, closer.Close()) + require.NoError(t, os.Remove(backupFileName)) + }) + + expectedFilteredRels := make([]string, 0, len(tt.relationships)) + for _, rel := range tt.relationships { + if !hasRelPrefix(tuple.ParseRel(rel), tt.prefixFilter) { + continue + } + + expectedFilteredRels = append(expectedFilteredRels, rel) + } + + expectedBatches := len(expectedFilteredRels) / tt.batchSize + // there is always one extra commit, regardless there is or not a remainder batch + expectedCommits := expectedBatches/int(tt.batchesPerTransaction) + 1 + remainderBatch := false + if len(expectedFilteredRels)%tt.batchSize != 0 { + expectedBatches++ + remainderBatch = true + } + + c := &mockClient{ + t: t, + remainderBatch: remainderBatch, + expectedRels: expectedFilteredRels, + expectedBatches: expectedBatches, + requestedBatchSize: tt.batchSize, + requestedBatchesPerTransaction: tt.batchesPerTransaction, + commitErrors: tt.commitErrors, + touchErrors: tt.touchErrors, + sendErrors: tt.sendErrors, + } + + expectedConflicts := 0 + expectedRetries := 0 + var expectsError string + for _, err := range tt.commitErrors { + if isRetryableError(errors.New(err)) { + expectedRetries++ + if tt.disableRetryErrors { + expectsError = err + } + } else if isAlreadyExistsError(errors.New(err)) { + expectedConflicts++ + } else { + expectsError = err + } + } + for _, err := range tt.touchErrors { + if isRetryableError(errors.New(err)) { + expectedRetries++ + if tt.disableRetryErrors { + expectsError = err + } + } else { + expectsError = err + } + } + + // if skip is enabled, there will be N less relationships written, where N is the number of conflicts + expectedWrittenRels := len(expectedFilteredRels) + if tt.skipOnConflicts { + expectedWrittenRels -= expectedConflicts * tt.batchSize + } + + expectedWrittenBatches := len(expectedFilteredRels) / tt.batchSize + if tt.skipOnConflicts { + expectedWrittenBatches -= expectedConflicts + } + if remainderBatch { + expectedWrittenBatches++ + } + + expectedTouchedBatches := expectedRetries + expectedTouchedRels := expectedRetries * tt.batchSize + if tt.touchOnConflicts { + expectedTouchedBatches += expectedConflicts * int(tt.batchesPerTransaction) + expectedTouchedRels += expectedConflicts * int(tt.batchesPerTransaction) * tt.batchSize + } + + expectedSkippedBatches := 0 + expectedSkippedRels := 0 + if tt.skipOnConflicts { + expectedSkippedBatches += expectedConflicts + expectedSkippedRels += expectedConflicts * tt.batchSize + } + + r := newRestorer(d, c, tt.prefixFilter, tt.batchSize, tt.batchesPerTransaction, tt.skipOnConflicts, tt.touchOnConflicts, tt.disableRetryErrors) + err = r.restoreFromDecoder(context.Background()) + if expectsError != "" || (expectedConflicts > 0 && !tt.skipOnConflicts && !tt.touchOnConflicts) { + require.ErrorContains(t, err, expectsError) + return + } + + require.NoError(t, 
err) + + // assert on mock stats + require.Equal(t, expectedBatches, c.receivedBatches, "unexpected number of received batches") + require.Equal(t, expectedCommits, c.receivedCommits, "unexpected number of batch commits") + require.Equal(t, len(expectedFilteredRels), c.receivedRels, "unexpected number of received relationships") + require.Equal(t, expectedTouchedBatches, c.touchedBatches, "unexpected number of touched batches") + require.Equal(t, expectedTouchedRels, c.touchedRels, "unexpected number of touched commits") + + // assert on restorer stats + require.Equal(t, expectedWrittenRels, int(r.writtenRels), "unexpected number of written relationships") + require.Equal(t, expectedWrittenBatches, int(r.writtenBatches), "unexpected number of written relationships") + require.Equal(t, expectedSkippedBatches, int(r.skippedBatches), "unexpected number of conflicting batches skipped") + require.Equal(t, expectedSkippedRels, int(r.skippedRels), "unexpected number of conflicting relationships skipped") + require.Equal(t, int64(expectedConflicts)*tt.batchesPerTransaction, r.duplicateBatches, "unexpected number of duplicate batches detected") + require.Equal(t, expectedConflicts*int(tt.batchesPerTransaction)*tt.batchSize, int(r.duplicateRels), "unexpected number of duplicate relationships detected") + require.Equal(t, int64(expectedRetries+expectedConflicts-expectedSkippedBatches), r.totalRetries, "unexpected number of retries") + require.Equal(t, len(tt.relationships)-len(expectedFilteredRels), int(r.filteredOutRels), "unexpected number of filtered out relationships") + }) + } +} + +type mockClient struct { + client.Client + v1.ExperimentalService_BulkImportRelationshipsClient + t *testing.T + remainderBatch bool + expectedRels []string + expectedBatches int + requestedBatchSize int + requestedBatchesPerTransaction int64 + receivedBatches int + receivedCommits int + receivedRels int + touchedBatches int + touchedRels int + lastReceivedBatch []*v1.Relationship + sendErrors []string + commitErrors []string + touchErrors []string +} + +func (m *mockClient) BulkImportRelationships(_ context.Context, _ ...grpc.CallOption) (v1.ExperimentalService_BulkImportRelationshipsClient, error) { + return m, nil +} + +func (m *mockClient) Send(req *v1.BulkImportRelationshipsRequest) error { + m.receivedBatches++ + m.receivedRels += len(req.Relationships) + m.lastReceivedBatch = req.Relationships + if m.receivedBatches <= len(m.sendErrors) { + return fmt.Errorf(m.sendErrors[m.receivedBatches-1]) + } + + if m.receivedBatches == m.expectedBatches && m.remainderBatch { + require.Equal(m.t, len(m.expectedRels)%m.requestedBatchSize, len(req.Relationships)) + } else { + require.Equal(m.t, m.requestedBatchSize, len(req.Relationships)) + } + + for i, rel := range req.Relationships { + require.True(m.t, proto.Equal(rel, tuple.ParseRel(m.expectedRels[((m.receivedBatches-1)*m.requestedBatchSize)+i]))) + } + + return nil +} + +func (m *mockClient) WriteRelationships(_ context.Context, in *v1.WriteRelationshipsRequest, _ ...grpc.CallOption) (*v1.WriteRelationshipsResponse, error) { + m.touchedBatches++ + m.touchedRels += len(in.Updates) + if m.touchedBatches <= len(m.touchErrors) { + return nil, fmt.Errorf(m.touchErrors[m.touchedBatches-1]) + } + + return &v1.WriteRelationshipsResponse{}, nil +} + +func (m *mockClient) CloseAndRecv() (*v1.BulkImportRelationshipsResponse, error) { + m.receivedCommits++ + lastBatch := m.lastReceivedBatch + defer func() { m.lastReceivedBatch = nil }() + + if m.receivedCommits <= 
len(m.commitErrors) { + return nil, fmt.Errorf(m.commitErrors[m.receivedCommits-1]) + } + + return &v1.BulkImportRelationshipsResponse{NumLoaded: uint64(len(lastBatch))}, nil +}
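
The comments in restoreFromDecoder lean on a gRPC client-streaming behavior that is easy to miss: when Send() fails mid-stream, it typically reports only a sentinel error, and the real server-side cause is surfaced by CloseAndRecv(). A minimal sketch of that pattern against the same BulkImportRelationships stream API that mockClient mocks above is shown below; the package name, the helper name commitBatches, and the choice to fail immediately instead of retrying are illustrative assumptions, not code from this patch.

package example // illustrative sketch only, not part of this change

import (
	"context"

	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
)

// commitBatches sends every batch on a single BulkImportRelationships stream
// and then commits the stream. When Send fails mid-stream, it calls
// CloseAndRecv to obtain the underlying error (conflict, retryable
// transaction error, ...) rather than the sentinel returned by Send.
func commitBatches(ctx context.Context, c v1.ExperimentalServiceClient, batches [][]*v1.Relationship) (uint64, error) {
	stream, err := c.BulkImportRelationships(ctx)
	if err != nil {
		return 0, err
	}

	for _, batch := range batches {
		if err := stream.Send(&v1.BulkImportRelationshipsRequest{Relationships: batch}); err != nil {
			// Send usually reports only io.EOF here; the actual cause is
			// returned by CloseAndRecv.
			if _, recvErr := stream.CloseAndRecv(); recvErr != nil {
				return 0, recvErr
			}
			return 0, err
		}
	}

	resp, err := stream.CloseAndRecv()
	if err != nil {
		return 0, err
	}
	return resp.NumLoaded, nil
}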