Skip to content

Commit

Permalink
Merge pull request #120 from bryanhuhta/batch-import
Browse files Browse the repository at this point in the history
Add batching functionality
  • Loading branch information
jzelinskie authored May 13, 2022
2 parents 9f8cd79 + 15d6c98 commit 3c0bf27
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 11 deletions.
36 changes: 26 additions & 10 deletions cmd/zed/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ import (
"github.com/spf13/cobra"

"github.com/authzed/zed/internal/decode"
"github.com/authzed/zed/internal/grpcutil"
"github.com/authzed/zed/internal/storage"
)

func registerImportCmd(rootCmd *cobra.Command) {
rootCmd.AddCommand(importCmd)
importCmd.Flags().Int("batch-size", 1000, "import batch size")
importCmd.Flags().Int("workers", 1, "number of concurrent batching workers")
importCmd.Flags().Bool("schema", true, "import schema")
importCmd.Flags().Bool("relationships", true, "import relationships")
importCmd.Flags().String("schema-definition-prefix", "", "prefix to add to the schema's definition(s) before importing")
Expand Down Expand Up @@ -47,7 +50,7 @@ var importCmd = &cobra.Command{
From a local file (no prefix):
zed import authzed-x7izWU8_2Gw3.yaml
Only schema:
Only schema:
zed import --relationships=false file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml
Only relationships:
Expand Down Expand Up @@ -104,7 +107,9 @@ func importCmdFunc(cmd *cobra.Command, args []string) error {
}

if cobrautil.MustGetBool(cmd, "relationships") {
if err := importRelationships(client, p.Relationships, prefix); err != nil {
batchSize := cobrautil.MustGetInt(cmd, "batch-size")
workers := cobrautil.MustGetInt(cmd, "workers")
if err := importRelationships(client, p.Relationships, prefix, batchSize, workers); err != nil {
return err
}
}
Expand Down Expand Up @@ -132,7 +137,7 @@ func importSchema(client *authzed.Client, schema string, definitionPrefix string
return nil
}

func importRelationships(client *authzed.Client, relationships string, definitionPrefix string) error {
func importRelationships(client *authzed.Client, relationships string, definitionPrefix string, batchSize int, workers int) error {
relationshipUpdates := make([]*v1.RelationshipUpdate, 0)
scanner := bufio.NewScanner(strings.NewReader(relationships))
for scanner.Scan() {
Expand Down Expand Up @@ -164,13 +169,24 @@ func importRelationships(client *authzed.Client, relationships string, definitio
return err
}

request := &v1.WriteRelationshipsRequest{Updates: relationshipUpdates}
log.Trace().Interface("request", request).Msg("writing relationships")
log.Info().Int("count", len(relationshipUpdates)).Msg("importing relationships")
log.Info().
Int("batch_size", batchSize).
Int("workers", workers).
Int("count", len(relationshipUpdates)).
Msg("importing relationships")

if _, err := client.WriteRelationships(context.Background(), request); err != nil {
return err
}
err := grpcutil.ConcurrentBatch(context.Background(), len(relationshipUpdates), batchSize, workers, func(ctx context.Context, no int, start int, end int) error {
request := &v1.WriteRelationshipsRequest{Updates: relationshipUpdates[start:end]}
_, err := client.WriteRelationships(ctx, request)
if err != nil {
return err
}

return nil
log.Info().
Int("batch_no", no).
Int("count", len(relationshipUpdates[start:end])).
Msg("imported relationships")
return nil
})
return err
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ require (
github.com/rs/zerolog v1.26.1
github.com/spf13/cobra v1.4.0
github.com/stretchr/testify v1.7.1
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
google.golang.org/grpc v1.45.0
google.golang.org/protobuf v1.28.0
Expand Down Expand Up @@ -140,7 +141,6 @@ require (
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 // indirect
Expand Down
68 changes: 68 additions & 0 deletions internal/grpcutil/batch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package grpcutil

import (
"context"
"errors"
"fmt"
"runtime"

"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)

func min(a int, b int) int {
if a <= b {
return a
}
return b
}

// EachFunc is a callback function that is called for each batch. no is the
// batch number, start is the starting index of this batch in the slice, and
// end is the ending index of this batch in the slice.
type EachFunc func(ctx context.Context, no int, start int, end int) error

// ConcurrentBatch will calculate the minimum number of batches to required to batch n items
// with batchSize batches. For each batch, it will execute the each function.
// These functions will be processed in parallel using maxWorkers number of
// goroutines. If maxWorkers is 1, then batching will happen sychronously. If
// maxWorkers is 0, then GOMAXPROCS number of workers will be used.
//
// If an error occurs during a batch, all the worker's contexts are cancelled
// and the original error is returned.
func ConcurrentBatch(ctx context.Context, n int, batchSize int, maxWorkers int, each EachFunc) error {
if n < 0 {
return errors.New("cannot batch items of length < 0")
} else if n == 0 {
// Batching zero items is a noop.
return nil
}

if batchSize < 1 {
return errors.New("cannot batch items with batch size < 1")
}

if maxWorkers < 0 {
return errors.New("cannot batch items with workers < 0")
} else if maxWorkers == 0 {
maxWorkers = runtime.GOMAXPROCS(0)
}

sem := semaphore.NewWeighted(int64(maxWorkers))
g, ctx := errgroup.WithContext(ctx)
numBatches := (n + batchSize - 1) / batchSize
for i := 0; i < numBatches; i++ {
if err := sem.Acquire(ctx, 1); err != nil {
return fmt.Errorf("failed to acquire semaphore for batch number %d: %w", i, err)
}

batchNum := i
g.Go(func() error {
defer sem.Release(1)
start := batchNum * batchSize
end := min(start+batchSize, n)
return each(ctx, batchNum, start, end)
})
}
return g.Wait()
}
140 changes: 140 additions & 0 deletions internal/grpcutil/batch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package grpcutil

import (
"context"
"fmt"
"sync/atomic"
"testing"

"github.com/stretchr/testify/require"
)

type batch struct {
no int
start int
end int
}

func generateItems(n int) []string {
items := make([]string, n)
for i := 0; i < n; i++ {
items[i] = fmt.Sprintf("item %d", i)
}
return items
}

func TestConcurrentBatchOrdering(t *testing.T) {
const batchSize = 3
const workers = 1 // Set to one to keep everything synchronous.

tests := []struct {
name string
items []string
want []batch
}{
{
name: "1 item",
items: generateItems(1),
want: []batch{
{0, 0, 1},
},
},
{
name: "3 items",
items: generateItems(3),
want: []batch{
{0, 0, 3},
},
},
{
name: "5 items",
items: generateItems(5),
want: []batch{
{0, 0, 3},
{1, 3, 5},
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)

gotCh := make(chan batch, len(tt.items))
fn := func(ctx context.Context, no, start, end int) error {
gotCh <- batch{no, start, end}
return nil
}

err := ConcurrentBatch(context.Background(), len(tt.items), batchSize, workers, fn)
require.NoError(err)

got := make([]batch, len(gotCh))
i := 0
for span := range gotCh {
got[i] = span
i++

if i == len(got) {
break
}
}
require.Equal(tt.want, got)
})
}
}

func TestConcurrentBatch(t *testing.T) {
tests := []struct {
name string
items []string
batchSize int
workers int
wantCalls int
}{
{
name: "5 batches",
items: generateItems(50),
batchSize: 10,
workers: 3,
wantCalls: 5,
},
{
name: "0 batches",
items: []string{},
batchSize: 10,
workers: 3,
wantCalls: 0,
},
{
name: "1 batch",
items: generateItems(10),
batchSize: 10,
workers: 3,
wantCalls: 1,
},
{
name: "1 full batch, 1 partial batch",
items: generateItems(15),
batchSize: 10,
workers: 3,
wantCalls: 2,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)

var calls int64
fn := func(ctx context.Context, no, start, end int) error {
atomic.AddInt64(&calls, 1)
return nil
}
err := ConcurrentBatch(context.Background(), len(tt.items), tt.batchSize, tt.workers, fn)

require.NoError(err)
require.Equal(tt.wantCalls, int(calls))
})
}
}

0 comments on commit 3c0bf27

Please sign in to comment.