diff --git a/cmd/zed/import.go b/cmd/zed/import.go index b0fd21a2..883c3c9d 100644 --- a/cmd/zed/import.go +++ b/cmd/zed/import.go @@ -15,11 +15,14 @@ import ( "github.com/spf13/cobra" "github.com/authzed/zed/internal/decode" + "github.com/authzed/zed/internal/grpcutil" "github.com/authzed/zed/internal/storage" ) func registerImportCmd(rootCmd *cobra.Command) { rootCmd.AddCommand(importCmd) + importCmd.Flags().Int("batch-size", 1000, "import batch size") + importCmd.Flags().Int("workers", 1, "number of concurrent batching workers") importCmd.Flags().Bool("schema", true, "import schema") importCmd.Flags().Bool("relationships", true, "import relationships") importCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before importing") @@ -47,7 +50,7 @@ var importCmd = &cobra.Command{ From a local file (no prefix): zed import authzed-x7izWU8_2Gw3.yaml - Only schema: + Only schema: zed import --relationships=false file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml Only relationships: @@ -104,7 +107,9 @@ func importCmdFunc(cmd *cobra.Command, args []string) error { } if cobrautil.MustGetBool(cmd, "relationships") { - if err := importRelationships(client, p.Relationships, prefix); err != nil { + batchSize := cobrautil.MustGetInt(cmd, "batch-size") + workers := cobrautil.MustGetInt(cmd, "workers") + if err := importRelationships(client, p.Relationships, prefix, batchSize, workers); err != nil { return err } } @@ -132,7 +137,7 @@ func importSchema(client *authzed.Client, schema string, definitionPrefix string return nil } -func importRelationships(client *authzed.Client, relationships string, definitionPrefix string) error { +func importRelationships(client *authzed.Client, relationships string, definitionPrefix string, batchSize int, workers int) error { relationshipUpdates := make([]*v1.RelationshipUpdate, 0) scanner := bufio.NewScanner(strings.NewReader(relationships)) for scanner.Scan() { @@ -164,13 +169,24 @@ func importRelationships(client *authzed.Client, relationships string, definitio return err } - request := &v1.WriteRelationshipsRequest{Updates: relationshipUpdates} - log.Trace().Interface("request", request).Msg("writing relationships") - log.Info().Int("count", len(relationshipUpdates)).Msg("importing relationships") + log.Info(). + Int("batch_size", batchSize). + Int("workers", workers). + Int("count", len(relationshipUpdates)). + Msg("importing relationships") - if _, err := client.WriteRelationships(context.Background(), request); err != nil { - return err - } + err := grpcutil.ConcurrentBatch(context.Background(), len(relationshipUpdates), batchSize, workers, func(ctx context.Context, no int, start int, end int) error { + request := &v1.WriteRelationshipsRequest{Updates: relationshipUpdates[start:end]} + _, err := client.WriteRelationships(ctx, request) + if err != nil { + return err + } - return nil + log.Info(). + Int("batch_no", no). + Int("count", len(relationshipUpdates[start:end])). + Msg("imported relationships") + return nil + }) + return err } diff --git a/go.mod b/go.mod index 321a51e6..95e26ce4 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/rs/zerolog v1.26.1 github.com/spf13/cobra v1.4.0 github.com/stretchr/testify v1.7.1 + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 google.golang.org/grpc v1.45.0 google.golang.org/protobuf v1.28.0 @@ -140,7 +141,6 @@ require ( golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 // indirect - golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 // indirect diff --git a/internal/grpcutil/batch.go b/internal/grpcutil/batch.go new file mode 100644 index 00000000..98929015 --- /dev/null +++ b/internal/grpcutil/batch.go @@ -0,0 +1,68 @@ +package grpcutil + +import ( + "context" + "errors" + "fmt" + "runtime" + + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" +) + +func min(a int, b int) int { + if a <= b { + return a + } + return b +} + +// EachFunc is a callback function that is called for each batch. no is the +// batch number, start is the starting index of this batch in the slice, and +// end is the ending index of this batch in the slice. +type EachFunc func(ctx context.Context, no int, start int, end int) error + +// ConcurrentBatch will calculate the minimum number of batches to required to batch n items +// with batchSize batches. For each batch, it will execute the each function. +// These functions will be processed in parallel using maxWorkers number of +// goroutines. If maxWorkers is 1, then batching will happen sychronously. If +// maxWorkers is 0, then GOMAXPROCS number of workers will be used. +// +// If an error occurs during a batch, all the worker's contexts are cancelled +// and the original error is returned. +func ConcurrentBatch(ctx context.Context, n int, batchSize int, maxWorkers int, each EachFunc) error { + if n < 0 { + return errors.New("cannot batch items of length < 0") + } else if n == 0 { + // Batching zero items is a noop. + return nil + } + + if batchSize < 1 { + return errors.New("cannot batch items with batch size < 1") + } + + if maxWorkers < 0 { + return errors.New("cannot batch items with workers < 0") + } else if maxWorkers == 0 { + maxWorkers = runtime.GOMAXPROCS(0) + } + + sem := semaphore.NewWeighted(int64(maxWorkers)) + g, ctx := errgroup.WithContext(ctx) + numBatches := (n + batchSize - 1) / batchSize + for i := 0; i < numBatches; i++ { + if err := sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("failed to acquire semaphore for batch number %d: %w", i, err) + } + + batchNum := i + g.Go(func() error { + defer sem.Release(1) + start := batchNum * batchSize + end := min(start+batchSize, n) + return each(ctx, batchNum, start, end) + }) + } + return g.Wait() +} diff --git a/internal/grpcutil/batch_test.go b/internal/grpcutil/batch_test.go new file mode 100644 index 00000000..db4de439 --- /dev/null +++ b/internal/grpcutil/batch_test.go @@ -0,0 +1,140 @@ +package grpcutil + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" +) + +type batch struct { + no int + start int + end int +} + +func generateItems(n int) []string { + items := make([]string, n) + for i := 0; i < n; i++ { + items[i] = fmt.Sprintf("item %d", i) + } + return items +} + +func TestConcurrentBatchOrdering(t *testing.T) { + const batchSize = 3 + const workers = 1 // Set to one to keep everything synchronous. + + tests := []struct { + name string + items []string + want []batch + }{ + { + name: "1 item", + items: generateItems(1), + want: []batch{ + {0, 0, 1}, + }, + }, + { + name: "3 items", + items: generateItems(3), + want: []batch{ + {0, 0, 3}, + }, + }, + { + name: "5 items", + items: generateItems(5), + want: []batch{ + {0, 0, 3}, + {1, 3, 5}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + gotCh := make(chan batch, len(tt.items)) + fn := func(ctx context.Context, no, start, end int) error { + gotCh <- batch{no, start, end} + return nil + } + + err := ConcurrentBatch(context.Background(), len(tt.items), batchSize, workers, fn) + require.NoError(err) + + got := make([]batch, len(gotCh)) + i := 0 + for span := range gotCh { + got[i] = span + i++ + + if i == len(got) { + break + } + } + require.Equal(tt.want, got) + }) + } +} + +func TestConcurrentBatch(t *testing.T) { + tests := []struct { + name string + items []string + batchSize int + workers int + wantCalls int + }{ + { + name: "5 batches", + items: generateItems(50), + batchSize: 10, + workers: 3, + wantCalls: 5, + }, + { + name: "0 batches", + items: []string{}, + batchSize: 10, + workers: 3, + wantCalls: 0, + }, + { + name: "1 batch", + items: generateItems(10), + batchSize: 10, + workers: 3, + wantCalls: 1, + }, + { + name: "1 full batch, 1 partial batch", + items: generateItems(15), + batchSize: 10, + workers: 3, + wantCalls: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + var calls int64 + fn := func(ctx context.Context, no, start, end int) error { + atomic.AddInt64(&calls, 1) + return nil + } + err := ConcurrentBatch(context.Background(), len(tt.items), tt.batchSize, tt.workers, fn) + + require.NoError(err) + require.Equal(tt.wantCalls, int(calls)) + }) + } +}