diff --git a/.github/patches/postagecontract.patch b/.github/patches/postagecontract.patch new file mode 100644 index 00000000000..7d015ebfb36 --- /dev/null +++ b/.github/patches/postagecontract.patch @@ -0,0 +1,4 @@ +25c25 +< BucketDepth = uint8(16) +--- +> BucketDepth = uint8(10) diff --git a/.github/patches/postagereserve.patch b/.github/patches/postagereserve.patch new file mode 100644 index 00000000000..6312bf7e881 --- /dev/null +++ b/.github/patches/postagereserve.patch @@ -0,0 +1,4 @@ +43c43 +< var DefaultDepth = uint8(12) // 12 is the testnet depth at the time of merging to master +--- +> var DefaultDepth = uint8(5) // 12 is the testnet depth at the time of merging to master diff --git a/.github/patches/postagereserve_gc.patch b/.github/patches/postagereserve_gc.patch new file mode 100644 index 00000000000..d73a9bd451b --- /dev/null +++ b/.github/patches/postagereserve_gc.patch @@ -0,0 +1,4 @@ +48c48 +< var Capacity = exp2(23) +--- +> var Capacity = exp2(10) diff --git a/.github/workflows/beekeeper.yml b/.github/workflows/beekeeper.yml index af98de473ff..a0d8afc20af 100644 --- a/.github/workflows/beekeeper.yml +++ b/.github/workflows/beekeeper.yml @@ -56,11 +56,15 @@ jobs: mkdir -p ~/.kube cp /etc/rancher/k3s/k3s.yaml ~/.kube/config echo "kubeconfig: ${HOME}/.kube/config" > ~/.beekeeper.yaml + - name: Apply patches + run: | + patch pkg/postage/batchstore/reserve.go .github/patches/postagereserve.patch + patch pkg/postage/postagecontract/contract.go .github/patches/postagecontract.patch - name: Set testing cluster (DNS discovery) run: | echo -e "127.0.0.10\tregistry.localhost" | sudo tee -a /etc/hosts for ((i=0; i target { done = false @@ -201,76 +197,6 @@ func (db *DB) collectGarbage() (collectedCount uint64, done bool, err error) { return collectedCount, done, nil } -// removeChunksInExcludeIndexFromGC removed any recently chunks in the exclude Index, from the gcIndex. -func (db *DB) removeChunksInExcludeIndexFromGC() (err error) { - db.metrics.GCExcludeCounter.Inc() - defer totalTimeMetric(db.metrics.TotalTimeGCExclude, time.Now()) - defer func() { - if err != nil { - db.metrics.GCExcludeError.Inc() - } - }() - - batch := new(leveldb.Batch) - excludedCount := 0 - var gcSizeChange int64 - err = db.gcExcludeIndex.Iterate(func(item shed.Item) (stop bool, err error) { - // Get access timestamp - retrievalAccessIndexItem, err := db.retrievalAccessIndex.Get(item) - if err != nil { - return false, err - } - item.AccessTimestamp = retrievalAccessIndexItem.AccessTimestamp - - // Get the binId - retrievalDataIndexItem, err := db.retrievalDataIndex.Get(item) - if err != nil { - return false, err - } - item.BinID = retrievalDataIndexItem.BinID - - // Check if this item is in gcIndex and remove it - ok, err := db.gcIndex.Has(item) - if err != nil { - return false, nil - } - if ok { - err = db.gcIndex.DeleteInBatch(batch, item) - if err != nil { - return false, nil - } - if _, err := db.gcIndex.Get(item); err == nil { - gcSizeChange-- - } - excludedCount++ - err = db.gcExcludeIndex.DeleteInBatch(batch, item) - if err != nil { - return false, nil - } - } - - return false, nil - }, nil) - if err != nil { - return err - } - - // update the gc size based on the no of entries deleted in gcIndex - err = db.incGCSizeInBatch(batch, gcSizeChange) - if err != nil { - return err - } - - db.metrics.GCExcludeCounter.Add(float64(excludedCount)) - err = db.shed.WriteBatch(batch) - if err != nil { - db.metrics.GCExcludeWriteBatchError.Inc() - return err - } - - return nil -} - // gcTrigger retruns the absolute value for garbage collection // target value, calculated from db.capacity and gcTargetRatio. func (db *DB) gcTarget() (target uint64) { @@ -331,3 +257,5 @@ var testHookCollectGarbage func(collectedCount uint64) // when the GC is done collecting candidate items for // eviction. var testHookGCIteratorDone func() + +var withinRadiusFn func(*DB, shed.Item) bool diff --git a/pkg/localstore/gc_test.go b/pkg/localstore/gc_test.go index 3f4a157b17c..4b10f2457eb 100644 --- a/pkg/localstore/gc_test.go +++ b/pkg/localstore/gc_test.go @@ -73,29 +73,33 @@ func testDBCollectGarbageWorker(t *testing.T) { case <-closed: } })) + + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return false })) db := newTestDB(t, &Options{ Capacity: 100, }) closed = db.close - addrs := make([]swarm.Address, 0) - + addrs := make([]swarm.Address, chunkCount) + ctx := context.Background() // upload random chunks for i := 0; i < chunkCount; i++ { ch := generateTestRandomChunk() - - _, err := db.Put(context.Background(), storage.ModePutUpload, ch) + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) + _, err := db.Put(ctx, storage.ModePutUpload, ch) if err != nil { t.Fatal(err) } - err = db.Set(context.Background(), storage.ModeSetSync, ch.Address()) + err = db.Set(ctx, storage.ModeSetSync, ch.Address()) if err != nil { t.Fatal(err) } - addrs = append(addrs, ch.Address()) - + addrs[i] = ch.Address() } gcTarget := db.gcTarget() @@ -150,7 +154,6 @@ func testDBCollectGarbageWorker(t *testing.T) { // Pin a file, upload chunks to go past the gc limit to trigger GC, // check if the pinned files are still around and removed from gcIndex func TestPinGC(t *testing.T) { - chunkCount := 150 pinChunksCount := 50 dbCapacity := uint64(100) @@ -171,6 +174,7 @@ func TestPinGC(t *testing.T) { case <-closed: } })) + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return false })) db := newTestDB(t, &Options{ Capacity: dbCapacity, @@ -183,6 +187,10 @@ func TestPinGC(t *testing.T) { // upload random chunks for i := 0; i < chunkCount; i++ { ch := generateTestRandomChunk() + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) mode := storage.ModePutUpload if i < pinChunksCount { @@ -221,8 +229,6 @@ func TestPinGC(t *testing.T) { t.Run("pin Index count", newItemsCountTest(db.pinIndex, pinChunksCount)) - t.Run("gc exclude index count", newItemsCountTest(db.gcExcludeIndex, pinChunksCount)) - t.Run("pull index count", newItemsCountTest(db.pullIndex, int(gcTarget)+pinChunksCount)) t.Run("gc index count", newItemsCountTest(db.gcIndex, int(gcTarget))) @@ -277,6 +283,10 @@ func TestGCAfterPin(t *testing.T) { // upload random chunks for i := 0; i < chunkCount; i++ { ch := generateTestRandomChunk() + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) _, err := db.Put(context.Background(), storage.ModePutUpload, ch) if err != nil { @@ -298,8 +308,6 @@ func TestGCAfterPin(t *testing.T) { t.Run("pin Index count", newItemsCountTest(db.pinIndex, chunkCount)) - t.Run("gc exclude index count", newItemsCountTest(db.gcExcludeIndex, chunkCount)) - t.Run("gc index count", newItemsCountTest(db.gcIndex, int(0))) for _, hash := range pinAddrs { @@ -314,10 +322,6 @@ func TestGCAfterPin(t *testing.T) { // to test garbage collection runs by uploading, syncing and // requesting a number of chunks. func TestDB_collectGarbageWorker_withRequests(t *testing.T) { - db := newTestDB(t, &Options{ - Capacity: 100, - }) - testHookCollectGarbageChan := make(chan uint64) defer setTestHookCollectGarbage(func(collectedCount uint64) { // don't trigger if we haven't collected anything - this may @@ -330,11 +334,21 @@ func TestDB_collectGarbageWorker_withRequests(t *testing.T) { testHookCollectGarbageChan <- collectedCount })() + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return false })) + + db := newTestDB(t, &Options{ + Capacity: 100, + }) + addrs := make([]swarm.Address, 0) // upload random chunks just up to the capacity for i := 0; i < int(db.capacity)-1; i++ { ch := generateTestRandomChunk() + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) _, err := db.Put(context.Background(), storage.ModePutUpload, ch) if err != nil { @@ -357,7 +371,7 @@ func TestDB_collectGarbageWorker_withRequests(t *testing.T) { close(testHookUpdateGCChan) }) - // request the latest synced chunk + // request the oldest synced chunk // to prioritize it in the gc index // not to be collected _, err := db.Get(context.Background(), storage.ModeGetRequest, addrs[0]) @@ -379,6 +393,11 @@ func TestDB_collectGarbageWorker_withRequests(t *testing.T) { // upload and sync another chunk to trigger // garbage collection ch := generateTestRandomChunk() + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) + _, err = db.Put(context.Background(), storage.ModePutUpload, ch) if err != nil { t.Fatal(err) @@ -468,6 +487,10 @@ func TestDB_gcSize(t *testing.T) { for i := 0; i < count; i++ { ch := generateTestRandomChunk() + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) _, err := db.Put(context.Background(), storage.ModePutUpload, ch) if err != nil { @@ -502,6 +525,13 @@ func setTestHookCollectGarbage(h func(collectedCount uint64)) (reset func()) { return reset } +func setWithinRadiusFunc(h func(*DB, shed.Item) bool) (reset func()) { + current := withinRadiusFn + reset = func() { withinRadiusFn = current } + withinRadiusFn = h + return reset +} + // TestSetTestHookCollectGarbage tests if setTestHookCollectGarbage changes // testHookCollectGarbage function correctly and if its reset function // resets the original function. @@ -556,6 +586,7 @@ func TestSetTestHookCollectGarbage(t *testing.T) { } func TestPinAfterMultiGC(t *testing.T) { + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return false })) db := newTestDB(t, &Options{ Capacity: 10, }) @@ -565,6 +596,11 @@ func TestPinAfterMultiGC(t *testing.T) { // upload random chunks above db capacity to see if chunks are still pinned for i := 0; i < 20; i++ { ch := generateTestRandomChunk() + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) + _, err := db.Put(context.Background(), storage.ModePutUpload, ch) if err != nil { t.Fatal(err) @@ -581,6 +617,11 @@ func TestPinAfterMultiGC(t *testing.T) { } for i := 0; i < 20; i++ { ch := generateTestRandomChunk() + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) + _, err := db.Put(context.Background(), storage.ModePutUpload, ch) if err != nil { t.Fatal(err) @@ -592,6 +633,11 @@ func TestPinAfterMultiGC(t *testing.T) { } for i := 0; i < 20; i++ { ch := generateTestRandomChunk() + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) + _, err := db.Put(context.Background(), storage.ModePutUpload, ch) if err != nil { t.Fatal(err) @@ -622,21 +668,26 @@ func TestPinAfterMultiGC(t *testing.T) { func generateAndPinAChunk(t *testing.T, db *DB) swarm.Chunk { // Create a chunk and pin it - pinnedChunk := generateTestRandomChunk() + ch := generateTestRandomChunk() - _, err := db.Put(context.Background(), storage.ModePutUpload, pinnedChunk) + _, err := db.Put(context.Background(), storage.ModePutUpload, ch) if err != nil { t.Fatal(err) } - err = db.Set(context.Background(), storage.ModeSetPin, pinnedChunk.Address()) + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) + + err = db.Set(context.Background(), storage.ModeSetPin, ch.Address()) if err != nil { t.Fatal(err) } - err = db.Set(context.Background(), storage.ModeSetSync, pinnedChunk.Address()) + err = db.Set(context.Background(), storage.ModeSetSync, ch.Address()) if err != nil { t.Fatal(err) } - return pinnedChunk + return ch } func TestPinSyncAndAccessPutSetChunkMultipleTimes(t *testing.T) { @@ -716,6 +767,8 @@ func addRandomChunks(t *testing.T, count int, db *DB, pin bool) []swarm.Chunk { var chunks []swarm.Chunk for i := 0; i < count; i++ { ch := generateTestRandomChunk() + unreserveChunkBatch(t, db, 0, ch) + _, err := db.Put(context.Background(), storage.ModePutUpload, ch) if err != nil { t.Fatal(err) @@ -750,6 +803,7 @@ func addRandomChunks(t *testing.T, count int, db *DB, pin bool) []swarm.Chunk { func TestGC_NoEvictDirty(t *testing.T) { // lower the maximal number of chunks in a single // gc batch to ensure multiple batches. + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return false })) defer func(s uint64) { gcBatchSize = s }(gcBatchSize) gcBatchSize = 1 @@ -810,6 +864,7 @@ func TestGC_NoEvictDirty(t *testing.T) { // upload random chunks for i := 0; i < chunkCount; i++ { ch := generateTestRandomChunk() + unreserveChunkBatch(t, db, 0, ch) _, err := db.Put(context.Background(), storage.ModePutUpload, ch) if err != nil { @@ -879,7 +934,6 @@ func TestGC_NoEvictDirty(t *testing.T) { t.Fatal(err) } }) - } // setTestHookGCIteratorDone sets testHookGCIteratorDone and @@ -891,3 +945,13 @@ func setTestHookGCIteratorDone(h func()) (reset func()) { testHookGCIteratorDone = h return reset } + +func unreserveChunkBatch(t *testing.T, db *DB, radius uint8, chs ...swarm.Chunk) { + t.Helper() + for _, ch := range chs { + err := db.UnreserveBatch(ch.Stamp().BatchID(), radius) + if err != nil { + t.Fatal(err) + } + } +} diff --git a/pkg/localstore/index_test.go b/pkg/localstore/index_test.go index 05a6bdb5941..7a2f50a3508 100644 --- a/pkg/localstore/index_test.go +++ b/pkg/localstore/index_test.go @@ -22,6 +22,7 @@ import ( "math/rand" "testing" + "github.com/ethersphere/bee/pkg/shed" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" ) @@ -78,6 +79,7 @@ func TestDB_pullIndex(t *testing.T) { // a chunk with and performing operations using synced, access and // request modes. func TestDB_gcIndex(t *testing.T) { + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return false })) db := newTestDB(t, nil) chunkCount := 50 @@ -87,6 +89,10 @@ func TestDB_gcIndex(t *testing.T) { // upload random chunks for i := 0; i < chunkCount; i++ { ch := generateTestRandomChunk() + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) _, err := db.Put(context.Background(), storage.ModePutUpload, ch) if err != nil { diff --git a/pkg/localstore/localstore.go b/pkg/localstore/localstore.go index d73b490618b..72594442acf 100644 --- a/pkg/localstore/localstore.go +++ b/pkg/localstore/localstore.go @@ -25,6 +25,7 @@ import ( "time" "github.com/ethersphere/bee/pkg/logging" + "github.com/ethersphere/bee/pkg/postage" "github.com/ethersphere/bee/pkg/shed" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" @@ -85,12 +86,15 @@ type DB struct { // garbage collection index gcIndex shed.Index - // garbage collection exclude index for pinned contents - gcExcludeIndex shed.Index - // pin files Index pinIndex shed.Index + // postage chunks index + postageChunksIndex shed.Index + + // postage chunks index + postageRadiusIndex shed.Index + // field that stores number of intems in gc index gcSize shed.Uint64Field @@ -213,6 +217,10 @@ func New(path string, baseKey []byte, o *Options, logger logging.Logger) (db *DB DisableSeeksCompaction: o.DisableSeeksCompaction, } + if withinRadiusFn == nil { + withinRadiusFn = withinRadius + } + db.shed, err = shed.NewDB(path, shedOpts) if err != nil { return nil, err @@ -248,7 +256,8 @@ func New(path string, baseKey []byte, o *Options, logger logging.Logger) (db *DB } // Index storing actual chunk address, data and bin id. - db.retrievalDataIndex, err = db.shed.NewIndex("Address->StoreTimestamp|BinID|Data", shed.IndexFuncs{ + headerSize := 16 + postage.StampSize + db.retrievalDataIndex, err = db.shed.NewIndex("Address->StoreTimestamp|BinID|BatchID|Sig|Data", shed.IndexFuncs{ EncodeKey: func(fields shed.Item) (key []byte, err error) { return fields.Address, nil }, @@ -257,16 +266,27 @@ func New(path string, baseKey []byte, o *Options, logger logging.Logger) (db *DB return e, nil }, EncodeValue: func(fields shed.Item) (value []byte, err error) { - b := make([]byte, 16) + b := make([]byte, headerSize) binary.BigEndian.PutUint64(b[:8], fields.BinID) binary.BigEndian.PutUint64(b[8:16], uint64(fields.StoreTimestamp)) + stamp, err := postage.NewStamp(fields.BatchID, fields.Sig).MarshalBinary() + if err != nil { + return nil, err + } + copy(b[16:], stamp) value = append(b, fields.Data...) return value, nil }, DecodeValue: func(keyItem shed.Item, value []byte) (e shed.Item, err error) { e.StoreTimestamp = int64(binary.BigEndian.Uint64(value[8:16])) e.BinID = binary.BigEndian.Uint64(value[:8]) - e.Data = value[16:] + stamp := new(postage.Stamp) + if err = stamp.UnmarshalBinary(value[16:headerSize]); err != nil { + return e, err + } + e.BatchID = stamp.BatchID() + e.Sig = stamp.Sig() + e.Data = value[headerSize:] return e, nil }, }) @@ -297,9 +317,9 @@ func New(path string, baseKey []byte, o *Options, logger logging.Logger) (db *DB return nil, err } // pull index allows history and live syncing per po bin - db.pullIndex, err = db.shed.NewIndex("PO|BinID->Hash|Tag", shed.IndexFuncs{ + db.pullIndex, err = db.shed.NewIndex("PO|BinID->Hash", shed.IndexFuncs{ EncodeKey: func(fields shed.Item) (key []byte, err error) { - key = make([]byte, 41) + key = make([]byte, 9) key[0] = db.po(swarm.NewAddress(fields.Address)) binary.BigEndian.PutUint64(key[1:9], fields.BinID) return key, nil @@ -309,20 +329,14 @@ func New(path string, baseKey []byte, o *Options, logger logging.Logger) (db *DB return e, nil }, EncodeValue: func(fields shed.Item) (value []byte, err error) { - value = make([]byte, 36) // 32 bytes address, 4 bytes tag + value = make([]byte, 64) // 32 bytes address, 32 bytes batch id copy(value, fields.Address) - - if fields.Tag != 0 { - binary.BigEndian.PutUint32(value[32:], fields.Tag) - } - + copy(value[32:], fields.BatchID) return value, nil }, DecodeValue: func(keyItem shed.Item, value []byte) (e shed.Item, err error) { e.Address = value[:32] - if len(value) > 32 { - e.Tag = binary.BigEndian.Uint32(value[32:]) - } + e.BatchID = value[32:64] return e, nil }, }) @@ -367,7 +381,7 @@ func New(path string, baseKey []byte, o *Options, logger logging.Logger) (db *DB // create a push syncing triggers used by SubscribePush function db.pushTriggers = make([]chan<- struct{}, 0) // gc index for removable chunk ordered by ascending last access time - db.gcIndex, err = db.shed.NewIndex("AccessTimestamp|BinID|Hash->nil", shed.IndexFuncs{ + db.gcIndex, err = db.shed.NewIndex("AccessTimestamp|BinID|Hash->BatchID", shed.IndexFuncs{ EncodeKey: func(fields shed.Item) (key []byte, err error) { b := make([]byte, 16, 16+len(fields.Address)) binary.BigEndian.PutUint64(b[:8], uint64(fields.AccessTimestamp)) @@ -382,9 +396,14 @@ func New(path string, baseKey []byte, o *Options, logger logging.Logger) (db *DB return e, nil }, EncodeValue: func(fields shed.Item) (value []byte, err error) { - return nil, nil + value = make([]byte, 32) + copy(value, fields.BatchID) + return value, nil }, DecodeValue: func(keyItem shed.Item, value []byte) (e shed.Item, err error) { + e = keyItem + e.BatchID = make([]byte, 32) + copy(e.BatchID, value) return e, nil }, }) @@ -415,13 +434,17 @@ func New(path string, baseKey []byte, o *Options, logger logging.Logger) (db *DB return nil, err } - // Create a index structure for excluding pinned chunks from gcIndex - db.gcExcludeIndex, err = db.shed.NewIndex("Hash->nil", shed.IndexFuncs{ + db.postageChunksIndex, err = db.shed.NewIndex("BatchID|PO|Hash->nil", shed.IndexFuncs{ EncodeKey: func(fields shed.Item) (key []byte, err error) { - return fields.Address, nil + key = make([]byte, 65) + copy(key[:32], fields.BatchID) + key[32] = db.po(swarm.NewAddress(fields.Address)) + copy(key[33:], fields.Address) + return key, nil }, DecodeKey: func(key []byte) (e shed.Item, err error) { - e.Address = key + e.BatchID = key[:32] + e.Address = key[33:65] return e, nil }, EncodeValue: func(fields shed.Item) (value []byte, err error) { @@ -435,6 +458,28 @@ func New(path string, baseKey []byte, o *Options, logger logging.Logger) (db *DB return nil, err } + db.postageRadiusIndex, err = db.shed.NewIndex("BatchID->Radius", shed.IndexFuncs{ + EncodeKey: func(fields shed.Item) (key []byte, err error) { + key = make([]byte, 32) + copy(key[:32], fields.BatchID) + return key, nil + }, + DecodeKey: func(key []byte) (e shed.Item, err error) { + e.BatchID = key[:32] + return e, nil + }, + EncodeValue: func(fields shed.Item) (value []byte, err error) { + return []byte{fields.Radius}, nil + }, + DecodeValue: func(keyItem shed.Item, value []byte) (e shed.Item, err error) { + e.Radius = value[0] + return e, nil + }, + }) + if err != nil { + return nil, err + } + // start garbage collection worker go db.collectGarbageWorker() return db, nil @@ -485,8 +530,9 @@ func (db *DB) DebugIndices() (indexInfo map[string]int, err error) { "pushIndex": db.pushIndex, "pullIndex": db.pullIndex, "gcIndex": db.gcIndex, - "gcExcludeIndex": db.gcExcludeIndex, "pinIndex": db.pinIndex, + "postageChunksIndex": db.postageChunksIndex, + "postageRadiusIndex": db.postageRadiusIndex, } { indexSize, err := v.Count() if err != nil { @@ -509,6 +555,10 @@ func chunkToItem(ch swarm.Chunk) shed.Item { Address: ch.Address().Bytes(), Data: ch.Data(), Tag: ch.TagID(), + BatchID: ch.Stamp().BatchID(), + Sig: ch.Stamp().Sig(), + Depth: ch.Depth(), + Radius: ch.Radius(), } } diff --git a/pkg/localstore/localstore_test.go b/pkg/localstore/localstore_test.go index c2165d15139..220f6704d72 100644 --- a/pkg/localstore/localstore_test.go +++ b/pkg/localstore/localstore_test.go @@ -30,6 +30,7 @@ import ( "time" "github.com/ethersphere/bee/pkg/logging" + "github.com/ethersphere/bee/pkg/postage" "github.com/ethersphere/bee/pkg/shed" "github.com/ethersphere/bee/pkg/storage" chunktesting "github.com/ethersphere/bee/pkg/storage/testing" @@ -171,8 +172,9 @@ func newTestDB(t testing.TB, o *Options) *DB { } var ( - generateTestRandomChunk = chunktesting.GenerateTestRandomChunk - generateTestRandomChunks = chunktesting.GenerateTestRandomChunks + generateTestRandomChunk = chunktesting.GenerateTestRandomChunk + generateTestRandomChunks = chunktesting.GenerateTestRandomChunks + generateTestRandomChunkAt = chunktesting.GenerateTestRandomChunkAt ) // chunkAddresses return chunk addresses of provided chunks. @@ -251,7 +253,7 @@ func newRetrieveIndexesTest(db *DB, chunk swarm.Chunk, storeTimestamp, accessTim if err != nil { t.Fatal(err) } - validateItem(t, item, chunk.Address().Bytes(), chunk.Data(), storeTimestamp, 0) + validateItem(t, item, chunk.Address().Bytes(), chunk.Data(), storeTimestamp, 0, chunk.Stamp()) // access index should not be set wantErr := leveldb.ErrNotFound @@ -272,15 +274,14 @@ func newRetrieveIndexesTestWithAccess(db *DB, ch swarm.Chunk, storeTimestamp, ac if err != nil { t.Fatal(err) } - validateItem(t, item, ch.Address().Bytes(), ch.Data(), storeTimestamp, 0) if accessTimestamp > 0 { - item, err = db.retrievalAccessIndex.Get(addressToItem(ch.Address())) + item, err = db.retrievalAccessIndex.Get(item) if err != nil { t.Fatal(err) } - validateItem(t, item, ch.Address().Bytes(), nil, 0, accessTimestamp) } + validateItem(t, item, ch.Address().Bytes(), ch.Data(), storeTimestamp, accessTimestamp, ch.Stamp()) } } @@ -298,7 +299,7 @@ func newPullIndexTest(db *DB, ch swarm.Chunk, binID uint64, wantError error) fun t.Errorf("got error %v, want %v", err, wantError) } if err == nil { - validateItem(t, item, ch.Address().Bytes(), nil, 0, 0) + validateItem(t, item, ch.Address().Bytes(), nil, 0, 0, postage.NewStamp(ch.Stamp().BatchID(), nil)) } } } @@ -317,14 +318,14 @@ func newPushIndexTest(db *DB, ch swarm.Chunk, storeTimestamp int64, wantError er t.Errorf("got error %v, want %v", err, wantError) } if err == nil { - validateItem(t, item, ch.Address().Bytes(), nil, storeTimestamp, 0) + validateItem(t, item, ch.Address().Bytes(), nil, storeTimestamp, 0, postage.NewStamp(nil, nil)) } } } // newGCIndexTest returns a test function that validates if the right // chunk values are in the GC index. -func newGCIndexTest(db *DB, chunk swarm.Chunk, storeTimestamp, accessTimestamp int64, binID uint64, wantError error) func(t *testing.T) { +func newGCIndexTest(db *DB, chunk swarm.Chunk, storeTimestamp, accessTimestamp int64, binID uint64, wantError error, stamp *postage.Stamp) func(t *testing.T) { return func(t *testing.T) { t.Helper() @@ -337,7 +338,7 @@ func newGCIndexTest(db *DB, chunk swarm.Chunk, storeTimestamp, accessTimestamp i t.Errorf("got error %v, want %v", err, wantError) } if err == nil { - validateItem(t, item, chunk.Address().Bytes(), nil, 0, accessTimestamp) + validateItem(t, item, chunk.Address().Bytes(), nil, 0, accessTimestamp, stamp) } } } @@ -355,7 +356,7 @@ func newPinIndexTest(db *DB, chunk swarm.Chunk, wantError error) func(t *testing t.Errorf("got error %v, want %v", err, wantError) } if err == nil { - validateItem(t, item, chunk.Address().Bytes(), nil, 0, 0) + validateItem(t, item, chunk.Address().Bytes(), nil, 0, 0, postage.NewStamp(nil, nil)) } } } @@ -438,7 +439,7 @@ func testItemsOrder(t *testing.T, i shed.Index, chunks []testIndexChunk, sortFun } // validateItem is a helper function that checks Item values. -func validateItem(t *testing.T, item shed.Item, address, data []byte, storeTimestamp, accessTimestamp int64) { +func validateItem(t *testing.T, item shed.Item, address, data []byte, storeTimestamp, accessTimestamp int64, stamp swarm.Stamp) { t.Helper() if !bytes.Equal(item.Address, address) { @@ -453,6 +454,12 @@ func validateItem(t *testing.T, item shed.Item, address, data []byte, storeTimes if item.AccessTimestamp != accessTimestamp { t.Errorf("got item access timestamp %v, want %v", item.AccessTimestamp, accessTimestamp) } + if !bytes.Equal(item.BatchID, stamp.BatchID()) { + t.Errorf("got batch ID %x, want %x", item.BatchID, stamp.BatchID()) + } + if !bytes.Equal(item.Sig, stamp.Sig()) { + t.Errorf("got signature %x, want %x", item.Sig, stamp.Sig()) + } } // setNow replaces now function and @@ -514,7 +521,7 @@ func TestSetNow(t *testing.T) { } } -func testIndexCounts(t *testing.T, pushIndex, pullIndex, gcIndex, gcExcludeIndex, pinIndex, retrievalDataIndex, retrievalAccessIndex int, indexInfo map[string]int) { +func testIndexCounts(t *testing.T, pushIndex, pullIndex, gcIndex, pinIndex, retrievalDataIndex, retrievalAccessIndex int, indexInfo map[string]int) { t.Helper() if indexInfo["pushIndex"] != pushIndex { t.Fatalf("pushIndex count mismatch. got %d want %d", indexInfo["pushIndex"], pushIndex) @@ -528,10 +535,6 @@ func testIndexCounts(t *testing.T, pushIndex, pullIndex, gcIndex, gcExcludeIndex t.Fatalf("gcIndex count mismatch. got %d want %d", indexInfo["gcIndex"], gcIndex) } - if indexInfo["gcExcludeIndex"] != gcExcludeIndex { - t.Fatalf("gcExcludeIndex count mismatch. got %d want %d", indexInfo["gcExcludeIndex"], gcExcludeIndex) - } - if indexInfo["pinIndex"] != pinIndex { t.Fatalf("pinIndex count mismatch. got %d want %d", indexInfo["pinIndex"], pinIndex) } @@ -568,7 +571,7 @@ func TestDBDebugIndexes(t *testing.T) { t.Fatal(err) } - testIndexCounts(t, 1, 1, 0, 0, 0, 1, 0, indexCounts) + testIndexCounts(t, 1, 1, 0, 0, 1, 0, indexCounts) // set the chunk for pinning and expect the index count to grow err = db.Set(ctx, storage.ModeSetPin, ch.Address()) @@ -582,5 +585,5 @@ func TestDBDebugIndexes(t *testing.T) { } // assert that there's a pin and gc exclude entry now - testIndexCounts(t, 1, 1, 0, 1, 1, 1, 0, indexCounts) + testIndexCounts(t, 1, 1, 0, 1, 1, 0, indexCounts) } diff --git a/pkg/localstore/mode_get.go b/pkg/localstore/mode_get.go index a74c9338272..0bd64363c74 100644 --- a/pkg/localstore/mode_get.go +++ b/pkg/localstore/mode_get.go @@ -21,11 +21,11 @@ import ( "errors" "time" - "github.com/syndtr/goleveldb/leveldb" - + "github.com/ethersphere/bee/pkg/postage" "github.com/ethersphere/bee/pkg/shed" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" + "github.com/syndtr/goleveldb/leveldb" ) // Get returns a chunk from the database. If the chunk is @@ -50,7 +50,8 @@ func (db *DB) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address) } return nil, err } - return swarm.NewChunk(swarm.NewAddress(out.Address), out.Data), nil + return swarm.NewChunk(swarm.NewAddress(out.Address), out.Data). + WithStamp(postage.NewStamp(out.BatchID, out.Sig)), nil } // get returns Item from the retrieval index @@ -152,25 +153,28 @@ func (db *DB) updateGC(item shed.Item) (err error) { if err != nil { return err } - // update access timestamp + + // update the gc item timestamp in case + // it exists + _, err = db.gcIndex.Get(item) item.AccessTimestamp = now() - // update retrieve access index - err = db.retrievalAccessIndex.PutInBatch(batch, item) - if err != nil { + if err == nil { + err = db.gcIndex.PutInBatch(batch, item) + if err != nil { + return err + } + } else if !errors.Is(err, leveldb.ErrNotFound) { return err } + // if the item is not in the gc we don't + // update the gc index, since the item is + // in the reserve. - // add new entry to gc index ONLY if it is not present in pinIndex - ok, err := db.pinIndex.Has(item) + // update retrieve access index + err = db.retrievalAccessIndex.PutInBatch(batch, item) if err != nil { return err } - if !ok { - err = db.gcIndex.PutInBatch(batch, item) - if err != nil { - return err - } - } return db.shed.WriteBatch(batch) } diff --git a/pkg/localstore/mode_get_multi.go b/pkg/localstore/mode_get_multi.go index ba618d464a3..468215ccbb1 100644 --- a/pkg/localstore/mode_get_multi.go +++ b/pkg/localstore/mode_get_multi.go @@ -21,6 +21,7 @@ import ( "errors" "time" + "github.com/ethersphere/bee/pkg/postage" "github.com/ethersphere/bee/pkg/shed" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" @@ -50,7 +51,8 @@ func (db *DB) GetMulti(ctx context.Context, mode storage.ModeGet, addrs ...swarm } chunks = make([]swarm.Chunk, len(out)) for i, ch := range out { - chunks[i] = swarm.NewChunk(swarm.NewAddress(ch.Address), ch.Data) + chunks[i] = swarm.NewChunk(swarm.NewAddress(ch.Address), ch.Data). + WithStamp(postage.NewStamp(ch.BatchID, ch.Sig)) } return chunks, nil } diff --git a/pkg/localstore/mode_get_test.go b/pkg/localstore/mode_get_test.go index f356ad52a9c..d2250e8fc76 100644 --- a/pkg/localstore/mode_get_test.go +++ b/pkg/localstore/mode_get_test.go @@ -22,11 +22,14 @@ import ( "testing" "time" + "github.com/ethersphere/bee/pkg/postage" + "github.com/ethersphere/bee/pkg/shed" "github.com/ethersphere/bee/pkg/storage" ) // TestModeGetRequest validates ModeGetRequest index values on the provided DB. func TestModeGetRequest(t *testing.T) { + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return false })) db := newTestDB(t, nil) uploadTimestamp := time.Now().UTC().UnixNano() @@ -35,6 +38,10 @@ func TestModeGetRequest(t *testing.T) { })() ch := generateTestRandomChunk() + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) _, err := db.Put(context.Background(), storage.ModePutUpload, ch) if err != nil { @@ -98,8 +105,9 @@ func TestModeGetRequest(t *testing.T) { t.Run("retrieve indexes", newRetrieveIndexesTestWithAccess(db, ch, uploadTimestamp, uploadTimestamp)) - t.Run("gc index", newGCIndexTest(db, ch, uploadTimestamp, uploadTimestamp, 1, nil)) + t.Run("gc index", newGCIndexTest(db, ch, uploadTimestamp, uploadTimestamp, 1, nil, postage.NewStamp(ch.Stamp().BatchID(), nil))) + t.Run("access count", newItemsCountTest(db.retrievalAccessIndex, 1)) t.Run("gc index count", newItemsCountTest(db.gcIndex, 1)) t.Run("gc size", newIndexGCSizeTest(db)) @@ -128,8 +136,9 @@ func TestModeGetRequest(t *testing.T) { t.Run("retrieve indexes", newRetrieveIndexesTestWithAccess(db, ch, uploadTimestamp, accessTimestamp)) - t.Run("gc index", newGCIndexTest(db, ch, uploadTimestamp, accessTimestamp, 1, nil)) + t.Run("gc index", newGCIndexTest(db, ch, uploadTimestamp, accessTimestamp, 1, nil, postage.NewStamp(ch.Stamp().BatchID(), nil))) + t.Run("access count", newItemsCountTest(db.retrievalAccessIndex, 1)) t.Run("gc index count", newItemsCountTest(db.gcIndex, 1)) t.Run("gc size", newIndexGCSizeTest(db)) @@ -153,7 +162,7 @@ func TestModeGetRequest(t *testing.T) { t.Run("retrieve indexes", newRetrieveIndexesTestWithAccess(db, ch, uploadTimestamp, uploadTimestamp)) - t.Run("gc index", newGCIndexTest(db, ch, uploadTimestamp, uploadTimestamp, 1, nil)) + t.Run("gc index", newGCIndexTest(db, ch, uploadTimestamp, uploadTimestamp, 1, nil, postage.NewStamp(ch.Stamp().BatchID(), nil))) t.Run("gc index count", newItemsCountTest(db.gcIndex, 1)) diff --git a/pkg/localstore/mode_put.go b/pkg/localstore/mode_put.go index 653c811e8ee..3857f8f325d 100644 --- a/pkg/localstore/mode_put.go +++ b/pkg/localstore/mode_put.go @@ -91,25 +91,21 @@ func (db *DB) put(mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err e binIDs := make(map[uint8]uint64) switch mode { - case storage.ModePutRequest, storage.ModePutRequestPin: + case storage.ModePutRequest, storage.ModePutRequestPin, storage.ModePutRequestCache: for i, ch := range chs { if containsChunk(ch.Address(), chs[:i]...) { exist[i] = true continue } - exists, c, err := db.putRequest(batch, binIDs, chunkToItem(ch)) + item := chunkToItem(ch) + pin := mode == storage.ModePutRequestPin // force pin in this mode + cache := mode == storage.ModePutRequestCache // force cache + exists, c, err := db.putRequest(batch, binIDs, item, pin, cache) if err != nil { return nil, err } exist[i] = exists gcSizeChange += c - - if mode == storage.ModePutRequestPin { - err = db.setPin(batch, ch.Address()) - if err != nil { - return nil, err - } - } } case storage.ModePutUpload, storage.ModePutUploadPin: @@ -118,7 +114,8 @@ func (db *DB) put(mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err e exist[i] = true continue } - exists, c, err := db.putUpload(batch, binIDs, chunkToItem(ch)) + item := chunkToItem(ch) + exists, c, err := db.putUpload(batch, binIDs, item) if err != nil { return nil, err } @@ -131,11 +128,12 @@ func (db *DB) put(mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err e } gcSizeChange += c if mode == storage.ModePutUploadPin { - err = db.setPin(batch, ch.Address()) + c, err = db.setPin(batch, item) if err != nil { return nil, err } } + gcSizeChange += c } case storage.ModePutSync: @@ -189,12 +187,12 @@ func (db *DB) put(mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err e // - it does not enter the syncpool // The batch can be written to the database. // Provided batch and binID map are updated. -func (db *DB) putRequest(batch *leveldb.Batch, binIDs map[uint8]uint64, item shed.Item) (exists bool, gcSizeChange int64, err error) { - has, err := db.retrievalDataIndex.Has(item) +func (db *DB) putRequest(batch *leveldb.Batch, binIDs map[uint8]uint64, item shed.Item, forcePin, forceCache bool) (exists bool, gcSizeChange int64, err error) { + exists, err = db.retrievalDataIndex.Has(item) if err != nil { return false, 0, err } - if has { + if exists { return true, 0, nil } @@ -203,17 +201,35 @@ func (db *DB) putRequest(batch *leveldb.Batch, binIDs map[uint8]uint64, item she if err != nil { return false, 0, err } + err = db.retrievalDataIndex.PutInBatch(batch, item) + if err != nil { + return false, 0, err + } + err = db.postageChunksIndex.PutInBatch(batch, item) + if err != nil { + return false, 0, err + } - gcSizeChange, err = db.setGC(batch, item) + item.AccessTimestamp = now() + err = db.retrievalAccessIndex.PutInBatch(batch, item) if err != nil { return false, 0, err } - err = db.retrievalDataIndex.PutInBatch(batch, item) + gcSizeChange, err = db.preserveOrCache(batch, item, forcePin, forceCache) if err != nil { return false, 0, err } + if !forceCache { + // if we are here it means the chunk has a valid stamp + // therefore we'd like to be able to pullsync it + err = db.pullIndex.PutInBatch(batch, item) + if err != nil { + return false, 0, err + } + } + return false, gcSizeChange, nil } @@ -248,6 +264,10 @@ func (db *DB) putUpload(batch *leveldb.Batch, binIDs map[uint8]uint64, item shed return false, 0, err } + err = db.postageChunksIndex.PutInBatch(batch, item) + if err != nil { + return false, 0, err + } return false, 0, nil } @@ -277,7 +297,18 @@ func (db *DB) putSync(batch *leveldb.Batch, binIDs map[uint8]uint64, item shed.I if err != nil { return false, 0, err } - gcSizeChange, err = db.setGC(batch, item) + err = db.postageChunksIndex.PutInBatch(batch, item) + if err != nil { + return false, 0, err + } + + item.AccessTimestamp = now() + err = db.retrievalAccessIndex.PutInBatch(batch, item) + if err != nil { + return false, 0, err + } + + gcSizeChange, err = db.preserveOrCache(batch, item, false, false) if err != nil { return false, 0, err } @@ -285,38 +316,20 @@ func (db *DB) putSync(batch *leveldb.Batch, binIDs map[uint8]uint64, item shed.I return false, gcSizeChange, nil } -// setGC is a helper function used to add chunks to the retrieval access -// index and the gc index in the cases that the putToGCCheck condition -// warrants a gc set. this is to mitigate index leakage in edge cases where -// a chunk is added to a node's localstore and given that the chunk is -// already within that node's NN (thus, it can be added to the gc index -// safely) -func (db *DB) setGC(batch *leveldb.Batch, item shed.Item) (gcSizeChange int64, err error) { - if item.BinID == 0 { - i, err := db.retrievalDataIndex.Get(item) - if err != nil { - return 0, err - } - item.BinID = i.BinID - } - i, err := db.retrievalAccessIndex.Get(item) - switch { - case err == nil: - item.AccessTimestamp = i.AccessTimestamp - err = db.gcIndex.DeleteInBatch(batch, item) - if err != nil { - return 0, err - } - gcSizeChange-- - case errors.Is(err, leveldb.ErrNotFound): - // the chunk is not accessed before - default: - return 0, err - } - item.AccessTimestamp = now() - err = db.retrievalAccessIndex.PutInBatch(batch, item) +// preserveOrCache is a helper function used to add chunks to either a pinned reserve or gc cache +// (the retrieval access index and the gc index) +func (db *DB) preserveOrCache(batch *leveldb.Batch, item shed.Item, forcePin, forceCache bool) (gcSizeChange int64, err error) { + // item needs to be populated with Radius + item2, err := db.postageRadiusIndex.Get(item) if err != nil { - return 0, err + // if there's an error, assume the chunk needs to be GCd + forceCache = true + } else { + item.Radius = item2.Radius + } + + if !forceCache && (withinRadiusFn(db, item) || forcePin) { + return db.setPin(batch, item) } // add new entry to gc index ONLY if it is not present in pinIndex @@ -324,13 +337,21 @@ func (db *DB) setGC(batch *leveldb.Batch, item shed.Item) (gcSizeChange int64, e if err != nil { return 0, err } - if !ok { - err = db.gcIndex.PutInBatch(batch, item) - if err != nil { - return 0, err - } - gcSizeChange++ + if ok { + return gcSizeChange, nil + } + exists, err := db.gcIndex.Has(item) + if err != nil && !errors.Is(err, leveldb.ErrNotFound) { + return 0, err + } + if exists { + return 0, nil + } + err = db.gcIndex.PutInBatch(batch, item) + if err != nil { + return 0, err } + gcSizeChange++ return gcSizeChange, nil } diff --git a/pkg/localstore/mode_put_test.go b/pkg/localstore/mode_put_test.go index 1c76f31e939..fecb484d0d8 100644 --- a/pkg/localstore/mode_put_test.go +++ b/pkg/localstore/mode_put_test.go @@ -24,6 +24,7 @@ import ( "testing" "time" + "github.com/ethersphere/bee/pkg/shed" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" "github.com/syndtr/goleveldb/leveldb" @@ -31,11 +32,16 @@ import ( // TestModePutRequest validates ModePutRequest index values on the provided DB. func TestModePutRequest(t *testing.T) { + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return false })) for _, tc := range multiChunkTestCases { t.Run(tc.name, func(t *testing.T) { db := newTestDB(t, nil) chunks := generateTestRandomChunks(tc.count) + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, chunks...) // keep the record when the chunk is stored var storeTimestamp int64 @@ -58,6 +64,7 @@ func TestModePutRequest(t *testing.T) { } newItemsCountTest(db.gcIndex, tc.count)(t) + newItemsCountTest(db.pullIndex, tc.count)(t) newIndexGCSizeTest(db)(t) }) @@ -77,6 +84,7 @@ func TestModePutRequest(t *testing.T) { } newItemsCountTest(db.gcIndex, tc.count)(t) + newItemsCountTest(db.pullIndex, tc.count)(t) newIndexGCSizeTest(db)(t) }) }) @@ -85,17 +93,21 @@ func TestModePutRequest(t *testing.T) { // TestModePutRequestPin validates ModePutRequestPin index values on the provided DB. func TestModePutRequestPin(t *testing.T) { + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return false })) for _, tc := range multiChunkTestCases { t.Run(tc.name, func(t *testing.T) { db := newTestDB(t, nil) chunks := generateTestRandomChunks(tc.count) + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, chunks...) wantTimestamp := time.Now().UTC().UnixNano() defer setNow(func() (t int64) { return wantTimestamp })() - _, err := db.Put(context.Background(), storage.ModePutRequestPin, chunks...) if err != nil { t.Fatal(err) @@ -106,6 +118,46 @@ func TestModePutRequestPin(t *testing.T) { newPinIndexTest(db, ch, nil)(t) } + // gc index should be always 0 since we're pinning + newItemsCountTest(db.gcIndex, 0)(t) + }) + } +} + +// TestModePutRequestCache validates ModePutRequestCache index values on the provided DB. +func TestModePutRequestCache(t *testing.T) { + // note: we set WithinRadius to be true, and verify that nevertheless + // the chunk lands in the cache + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return true })) + for _, tc := range multiChunkTestCases { + t.Run(tc.name, func(t *testing.T) { + db := newTestDB(t, nil) + var chunks []swarm.Chunk + for i := 0; i < tc.count; i++ { + chunk := generateTestRandomChunkAt(swarm.NewAddress(db.baseKey), 2) + chunks = append(chunks, chunk) + } + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database. in the following case + // the radius is 2, and since chunk PO is 2, it falls within + // radius. + unreserveChunkBatch(t, db, 2, chunks...) + + wantTimestamp := time.Now().UTC().UnixNano() + defer setNow(func() (t int64) { + return wantTimestamp + })() + _, err := db.Put(context.Background(), storage.ModePutRequestCache, chunks...) + if err != nil { + t.Fatal(err) + } + + for _, ch := range chunks { + newRetrieveIndexesTestWithAccess(db, ch, wantTimestamp, wantTimestamp)(t) + newPinIndexTest(db, ch, leveldb.ErrNotFound)(t) + } + newItemsCountTest(db.gcIndex, tc.count)(t) }) } @@ -113,6 +165,7 @@ func TestModePutRequestPin(t *testing.T) { // TestModePutSync validates ModePutSync index values on the provided DB. func TestModePutSync(t *testing.T) { + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return false })) for _, tc := range multiChunkTestCases { t.Run(tc.name, func(t *testing.T) { db := newTestDB(t, nil) @@ -123,6 +176,10 @@ func TestModePutSync(t *testing.T) { })() chunks := generateTestRandomChunks(tc.count) + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, chunks...) _, err := db.Put(context.Background(), storage.ModePutSync, chunks...) if err != nil { @@ -141,6 +198,8 @@ func TestModePutSync(t *testing.T) { newItemsCountTest(db.gcIndex, tc.count)(t) newIndexGCSizeTest(db)(t) } + newItemsCountTest(db.gcIndex, tc.count)(t) + newIndexGCSizeTest(db)(t) }) } } @@ -157,6 +216,10 @@ func TestModePutUpload(t *testing.T) { })() chunks := generateTestRandomChunks(tc.count) + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, chunks...) _, err := db.Put(context.Background(), storage.ModePutUpload, chunks...) if err != nil { @@ -190,6 +253,10 @@ func TestModePutUploadPin(t *testing.T) { })() chunks := generateTestRandomChunks(tc.count) + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, chunks...) _, err := db.Put(context.Background(), storage.ModePutUploadPin, chunks...) if err != nil { @@ -270,6 +337,11 @@ func TestModePutUpload_parallel(t *testing.T) { go func() { for i := 0; i < uploadsCount; i++ { chs := generateTestRandomChunks(tc.count) + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, chunks...) + select { case chunksChan <- chs: case <-doneChan: @@ -308,8 +380,9 @@ func TestModePutUpload_parallel(t *testing.T) { } // TestModePut_sameChunk puts the same chunk multiple times -// and validates that all relevant indexes have only one item -// in them. +// and validates that all relevant indexes have the correct counts. +// The test assumes that chunk fall into the reserve part of +// the store. func TestModePut_sameChunk(t *testing.T) { for _, tc := range multiChunkTestCases { t.Run(tc.name, func(t *testing.T) { @@ -324,6 +397,18 @@ func TestModePut_sameChunk(t *testing.T) { { name: "ModePutRequest", mode: storage.ModePutRequest, + pullIndex: true, + pushIndex: false, + }, + { + name: "ModePutRequestPin", + mode: storage.ModePutRequest, + pullIndex: true, + pushIndex: false, + }, + { + name: "ModePutRequestCache", + mode: storage.ModePutRequestCache, pullIndex: false, pushIndex: false, }, @@ -342,6 +427,10 @@ func TestModePut_sameChunk(t *testing.T) { } { t.Run(tcn.name, func(t *testing.T) { db := newTestDB(t, nil) + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, chunks...) for i := 0; i < 10; i++ { exist, err := db.Put(context.Background(), tcn.mode, chunks...) @@ -390,6 +479,7 @@ func TestPutDuplicateChunks(t *testing.T) { db := newTestDB(t, nil) ch := generateTestRandomChunk() + unreserveChunkBatch(t, db, 0, ch) exist, err := db.Put(context.Background(), mode, ch, ch) if err != nil { diff --git a/pkg/localstore/mode_set.go b/pkg/localstore/mode_set.go index ca87d81ae6b..e7875da316c 100644 --- a/pkg/localstore/mode_set.go +++ b/pkg/localstore/mode_set.go @@ -19,13 +19,14 @@ package localstore import ( "context" "errors" + "fmt" "time" - "github.com/syndtr/goleveldb/leveldb" - + "github.com/ethersphere/bee/pkg/shed" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/tags" + "github.com/syndtr/goleveldb/leveldb" ) // Set updates database indexes for @@ -44,8 +45,6 @@ func (db *DB) Set(ctx context.Context, mode storage.ModeSet, addrs ...swarm.Addr // set updates database indexes for // chunks represented by provided addresses. -// It acquires lockAddr to protect two calls -// of this function for the same address in parallel. func (db *DB) set(mode storage.ModeSet, addrs ...swarm.Address) (err error) { // protect parallel updates db.batchMu.Lock() @@ -62,9 +61,10 @@ func (db *DB) set(mode storage.ModeSet, addrs ...swarm.Address) (err error) { triggerPullFeed := make(map[uint8]struct{}) // signal pull feed subscriptions to iterate switch mode { + case storage.ModeSetSync: for _, addr := range addrs { - c, err := db.setSync(batch, addr, mode) + c, err := db.setSync(batch, addr) if err != nil { return err } @@ -73,7 +73,8 @@ func (db *DB) set(mode storage.ModeSet, addrs ...swarm.Address) (err error) { case storage.ModeSetRemove: for _, addr := range addrs { - c, err := db.setRemove(batch, addr) + item := addressToItem(addr) + c, err := db.setRemove(batch, item, true) if err != nil { return err } @@ -82,26 +83,20 @@ func (db *DB) set(mode storage.ModeSet, addrs ...swarm.Address) (err error) { case storage.ModeSetPin: for _, addr := range addrs { - has, err := db.retrievalDataIndex.Has(addressToItem(addr)) - if err != nil { - return err - } - - if !has { - return storage.ErrNotFound - } - - err = db.setPin(batch, addr) + item := addressToItem(addr) + c, err := db.setPin(batch, item) if err != nil { return err } + gcSizeChange += c } case storage.ModeSetUnpin: for _, addr := range addrs { - err := db.setUnpin(batch, addr) + c, err := db.setUnpin(batch, addr) if err != nil { return err } + gcSizeChange += c } default: return ErrInvalidMode @@ -127,7 +122,7 @@ func (db *DB) set(mode storage.ModeSet, addrs ...swarm.Address) (err error) { // from push sync index // - update to gc index happens given item does not exist in pin index // Provided batch is updated. -func (db *DB) setSync(batch *leveldb.Batch, addr swarm.Address, mode storage.ModeSet) (gcSizeChange int64, err error) { +func (db *DB) setSync(batch *leveldb.Batch, addr swarm.Address) (gcSizeChange int64, err error) { item := addressToItem(addr) // need to get access timestamp here as it is not @@ -151,6 +146,7 @@ func (db *DB) setSync(batch *leveldb.Batch, addr swarm.Address, mode storage.Mod } item.StoreTimestamp = i.StoreTimestamp item.BinID = i.BinID + item.BatchID = i.BatchID i, err = db.pushIndex.Get(item) if err != nil { @@ -182,65 +178,51 @@ func (db *DB) setSync(batch *leveldb.Batch, addr swarm.Address, mode storage.Mod return 0, err } - i, err = db.retrievalAccessIndex.Get(item) - switch { - case err == nil: - item.AccessTimestamp = i.AccessTimestamp - err = db.gcIndex.DeleteInBatch(batch, item) - if err != nil { + i1, err := db.retrievalAccessIndex.Get(item) + if err != nil { + if !errors.Is(err, leveldb.ErrNotFound) { return 0, err } - gcSizeChange-- - case errors.Is(err, leveldb.ErrNotFound): - // the chunk is not accessed before - default: - return 0, err - } - item.AccessTimestamp = now() - err = db.retrievalAccessIndex.PutInBatch(batch, item) - if err != nil { - return 0, err - } - - // Add in gcIndex only if this chunk is not pinned - ok, err := db.pinIndex.Has(item) - if err != nil { - return 0, err - } - if !ok { - err = db.gcIndex.PutInBatch(batch, item) + item.AccessTimestamp = now() + err := db.retrievalAccessIndex.PutInBatch(batch, item) if err != nil { return 0, err } - gcSizeChange++ + } else { + item.AccessTimestamp = i1.AccessTimestamp } - - return gcSizeChange, nil + // item needs to be populated with Radius + item2, err := db.postageRadiusIndex.Get(item) + if err != nil { + return 0, fmt.Errorf("postage chunks index: %w", err) + } + item.Radius = item2.Radius + return db.preserveOrCache(batch, item, false, false) } // setRemove removes the chunk by updating indexes: // - delete from retrieve, pull, gc // Provided batch is updated. -func (db *DB) setRemove(batch *leveldb.Batch, addr swarm.Address) (gcSizeChange int64, err error) { - item := addressToItem(addr) - - // need to get access timestamp here as it is not - // provided by the access function, and it is not - // a property of a chunk provided to Accessor.Put. - i, err := db.retrievalAccessIndex.Get(item) - switch { - case err == nil: - item.AccessTimestamp = i.AccessTimestamp - case errors.Is(err, leveldb.ErrNotFound): - default: - return 0, err +func (db *DB) setRemove(batch *leveldb.Batch, item shed.Item, check bool) (gcSizeChange int64, err error) { + if item.AccessTimestamp == 0 { + i, err := db.retrievalAccessIndex.Get(item) + switch { + case err == nil: + item.AccessTimestamp = i.AccessTimestamp + case errors.Is(err, leveldb.ErrNotFound): + default: + return 0, err + } } - i, err = db.retrievalDataIndex.Get(item) - if err != nil { - return 0, err + if item.StoreTimestamp == 0 { + item, err = db.retrievalDataIndex.Get(item) + if err != nil { + return 0, err + } } - item.StoreTimestamp = i.StoreTimestamp - item.BinID = i.BinID + + db.metrics.GCStoreTimeStamps.Set(float64(item.StoreTimestamp)) + db.metrics.GCStoreAccessTimeStamps.Set(float64(item.AccessTimestamp)) err = db.retrievalDataIndex.DeleteInBatch(batch, item) if err != nil { @@ -254,81 +236,128 @@ func (db *DB) setRemove(batch *leveldb.Batch, addr swarm.Address) (gcSizeChange if err != nil { return 0, err } - err = db.gcIndex.DeleteInBatch(batch, item) + err = db.postageChunksIndex.DeleteInBatch(batch, item) if err != nil { return 0, err } + // unless called by GC which iterates through the gcIndex // a check is needed for decrementing gcSize - // as delete is not reporting if the key/value pair - // is deleted or not - if _, err := db.gcIndex.Get(item); err == nil { - gcSizeChange = -1 + // as delete is not reporting if the key/value pair is deleted or not + if check { + _, err := db.gcIndex.Get(item) + if err != nil { + if !errors.Is(err, leveldb.ErrNotFound) { + return 0, err + } + return 0, db.pinIndex.DeleteInBatch(batch, item) + } } - - return gcSizeChange, nil + err = db.gcIndex.DeleteInBatch(batch, item) + if err != nil { + return 0, err + } + return -1, nil } // setPin increments pin counter for the chunk by updating // pin index and sets the chunk to be excluded from garbage collection. // Provided batch is updated. -func (db *DB) setPin(batch *leveldb.Batch, addr swarm.Address) (err error) { - item := addressToItem(addr) - +func (db *DB) setPin(batch *leveldb.Batch, item shed.Item) (gcSizeChange int64, err error) { // Get the existing pin counter of the chunk - existingPinCounter := uint64(0) - pinnedChunk, err := db.pinIndex.Get(item) + i, err := db.pinIndex.Get(item) + item.PinCounter = i.PinCounter if err != nil { - if errors.Is(err, leveldb.ErrNotFound) { - // If this Address is not present in DB, then its a new entry - existingPinCounter = 0 + if !errors.Is(err, leveldb.ErrNotFound) { + return 0, err + } + // if this Address is not pinned yet, then + i, err := db.retrievalAccessIndex.Get(item) + if err != nil { + if !errors.Is(err, leveldb.ErrNotFound) { + return 0, err + } + // not synced yet + } else { + item.AccessTimestamp = i.AccessTimestamp + i, err = db.retrievalDataIndex.Get(item) + if err != nil { + return 0, err + } + item.StoreTimestamp = i.StoreTimestamp + item.BinID = i.BinID - // Add in gcExcludeIndex of the chunk is not pinned already - err = db.gcExcludeIndex.PutInBatch(batch, item) + err = db.gcIndex.DeleteInBatch(batch, item) if err != nil { - return err + return 0, err } - } else { - return err + gcSizeChange = -1 } - } else { - existingPinCounter = pinnedChunk.PinCounter } // Otherwise increase the existing counter by 1 - item.PinCounter = existingPinCounter + 1 + item.PinCounter++ err = db.pinIndex.PutInBatch(batch, item) if err != nil { - return err + return 0, err } - - return nil + return gcSizeChange, nil } // setUnpin decrements pin counter for the chunk by updating pin index. // Provided batch is updated. -func (db *DB) setUnpin(batch *leveldb.Batch, addr swarm.Address) (err error) { +func (db *DB) setUnpin(batch *leveldb.Batch, addr swarm.Address) (gcSizeChange int64, err error) { item := addressToItem(addr) // Get the existing pin counter of the chunk - pinnedChunk, err := db.pinIndex.Get(item) + i, err := db.pinIndex.Get(item) if err != nil { - return err + return 0, err } - + item.PinCounter = i.PinCounter // Decrement the pin counter or // delete it from pin index if the pin counter has reached 0 - if pinnedChunk.PinCounter > 1 { - item.PinCounter = pinnedChunk.PinCounter - 1 - err = db.pinIndex.PutInBatch(batch, item) - if err != nil { - return err + if item.PinCounter > 1 { + item.PinCounter-- + return 0, db.pinIndex.PutInBatch(batch, item) + } + + // PinCounter == 0 + + err = db.pinIndex.DeleteInBatch(batch, item) + if err != nil { + return 0, err + } + i, err = db.retrievalDataIndex.Get(item) + if err != nil { + return 0, err + } + item.StoreTimestamp = i.StoreTimestamp + item.BinID = i.BinID + item.BatchID = i.BatchID + i, err = db.pushIndex.Get(item) + if !errors.Is(err, leveldb.ErrNotFound) { + // err is either nil or not leveldb.ErrNotFound + return 0, err + } + + i, err = db.retrievalAccessIndex.Get(item) + if err != nil { + if !errors.Is(err, leveldb.ErrNotFound) { + return 0, err } - } else { - err = db.pinIndex.DeleteInBatch(batch, item) + item.AccessTimestamp = now() + err = db.retrievalAccessIndex.PutInBatch(batch, item) if err != nil { - return err + return 0, err } + } else { + item.AccessTimestamp = i.AccessTimestamp + } + err = db.gcIndex.PutInBatch(batch, item) + if err != nil { + return 0, err } - return nil + gcSizeChange++ + return gcSizeChange, nil } diff --git a/pkg/localstore/mode_set_test.go b/pkg/localstore/mode_set_test.go index 3f3aad25457..48ad2fc58f3 100644 --- a/pkg/localstore/mode_set_test.go +++ b/pkg/localstore/mode_set_test.go @@ -19,78 +19,12 @@ package localstore import ( "context" "errors" - "io/ioutil" "testing" - "github.com/ethersphere/bee/pkg/logging" - statestore "github.com/ethersphere/bee/pkg/statestore/mock" - - "github.com/ethersphere/bee/pkg/shed" "github.com/ethersphere/bee/pkg/storage" - "github.com/ethersphere/bee/pkg/tags" - tagtesting "github.com/ethersphere/bee/pkg/tags/testing" "github.com/syndtr/goleveldb/leveldb" ) -// here we try to set a normal tag (that should be handled by pushsync) -// as a result we should expect the tag value to remain in the pull index -// and we expect that the tag should not be incremented by pull sync set -func TestModeSetSyncNormalTag(t *testing.T) { - mockStatestore := statestore.NewStateStore() - logger := logging.New(ioutil.Discard, 0) - db := newTestDB(t, &Options{Tags: tags.NewTags(mockStatestore, logger)}) - - tag, err := db.tags.Create(1) - if err != nil { - t.Fatal(err) - } - - ch := generateTestRandomChunk().WithTagID(tag.Uid) - _, err = db.Put(context.Background(), storage.ModePutUpload, ch) - if err != nil { - t.Fatal(err) - } - - err = tag.Inc(tags.StateStored) // so we don't get an error on tag.Status later on - if err != nil { - t.Fatal(err) - } - - item, err := db.pullIndex.Get(shed.Item{ - Address: ch.Address().Bytes(), - BinID: 1, - }) - if err != nil { - t.Fatal(err) - } - - if item.Tag != tag.Uid { - t.Fatalf("unexpected tag id value got %d want %d", item.Tag, tag.Uid) - } - - err = db.Set(context.Background(), storage.ModeSetSync, ch.Address()) - if err != nil { - t.Fatal(err) - } - - item, err = db.pullIndex.Get(shed.Item{ - Address: ch.Address().Bytes(), - BinID: 1, - }) - if err != nil { - t.Fatal(err) - } - - // expect the same tag Uid because when we set pull sync on a normal tag - // the tag Uid should remain untouched in pull index - if item.Tag != tag.Uid { - t.Fatalf("unexpected tag id value got %d want %d", item.Tag, tag.Uid) - } - - // 1 stored (because incremented manually in test), 1 sent, 1 synced, 1 total - tagtesting.CheckTag(t, tag, 0, 1, 0, 1, 1, 1) -} - // TestModeSetRemove validates ModeSetRemove index values on the provided DB. func TestModeSetRemove(t *testing.T) { for _, tc := range multiChunkTestCases { diff --git a/pkg/localstore/pin.go b/pkg/localstore/pin.go new file mode 100644 index 00000000000..5371ef8b7d7 --- /dev/null +++ b/pkg/localstore/pin.go @@ -0,0 +1,26 @@ +package localstore + +import ( + "errors" + + "github.com/ethersphere/bee/pkg/shed" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" + "github.com/syndtr/goleveldb/leveldb" +) + +// pinCounter returns the pin counter for a given swarm address, provided that the +// address has been pinned. +func (db *DB) pinCounter(address swarm.Address) (uint64, error) { + out, err := db.pinIndex.Get(shed.Item{ + Address: address.Bytes(), + }) + + if err != nil { + if errors.Is(err, leveldb.ErrNotFound) { + return 0, storage.ErrNotFound + } + return 0, err + } + return out.PinCounter, nil +} diff --git a/pkg/localstore/pin_test.go b/pkg/localstore/pin_test.go new file mode 100644 index 00000000000..1bc912a8907 --- /dev/null +++ b/pkg/localstore/pin_test.go @@ -0,0 +1,208 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package localstore + +import ( + "context" + "errors" + "testing" + + "github.com/ethersphere/bee/pkg/shed" + "github.com/ethersphere/bee/pkg/storage" +) + +func TestPinCounter(t *testing.T) { + chunk := generateTestRandomChunk() + db := newTestDB(t, nil) + addr := chunk.Address() + ctx := context.Background() + _, err := db.Put(ctx, storage.ModePutUpload, chunk) + if err != nil { + t.Fatal(err) + } + var pinCounter uint64 + t.Run("+1 after first pin", func(t *testing.T) { + err := db.Set(ctx, storage.ModeSetPin, addr) + if err != nil { + t.Fatal(err) + } + pinCounter, err = db.pinCounter(addr) + if err != nil { + t.Fatal(err) + } + if pinCounter != 1 { + t.Fatalf("want pin counter %d but got %d", 1, pinCounter) + } + }) + t.Run("2 after second pin", func(t *testing.T) { + err = db.Set(ctx, storage.ModeSetPin, addr) + if err != nil { + t.Fatal(err) + } + pinCounter, err = db.pinCounter(addr) + if err != nil { + t.Fatal(err) + } + if pinCounter != 2 { + t.Fatalf("want pin counter %d but got %d", 2, pinCounter) + } + }) + t.Run("1 after first unpin", func(t *testing.T) { + err = db.Set(ctx, storage.ModeSetUnpin, addr) + if err != nil { + t.Fatal(err) + } + pinCounter, err = db.pinCounter(addr) + if err != nil { + t.Fatal(err) + } + if pinCounter != 1 { + t.Fatalf("want pin counter %d but got %d", 1, pinCounter) + } + }) + t.Run("not found after second unpin", func(t *testing.T) { + err = db.Set(ctx, storage.ModeSetUnpin, addr) + if err != nil { + t.Fatal(err) + } + _, err = db.pinCounter(addr) + if !errors.Is(err, storage.ErrNotFound) { + t.Fatal(err) + } + }) +} + +// Pin a file, upload chunks to go past the gc limit to trigger GC, +// check if the pinned files are still around and removed from gcIndex +func TestPinIndexes(t *testing.T) { + ctx := context.Background() + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return false })) + + db := newTestDB(t, &Options{ + Capacity: 150, + }) + + ch := generateTestRandomChunk() + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) + + addr := ch.Address() + _, err := db.Put(ctx, storage.ModePutUpload, ch) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "putUpload", db, 1, 0, 1, 1, 0, 0) + + err = db.Set(ctx, storage.ModeSetSync, addr) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "setSync", db, 1, 1, 0, 1, 0, 1) + + err = db.Set(ctx, storage.ModeSetPin, addr) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "setPin", db, 1, 1, 0, 1, 1, 0) + + err = db.Set(ctx, storage.ModeSetPin, addr) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "setPin 2", db, 1, 1, 0, 1, 1, 0) + + err = db.Set(ctx, storage.ModeSetUnpin, addr) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "setUnPin", db, 1, 1, 0, 1, 1, 0) + + err = db.Set(ctx, storage.ModeSetUnpin, addr) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "setUnPin 2", db, 1, 1, 0, 1, 0, 1) + +} + +func TestPinIndexesSync(t *testing.T) { + ctx := context.Background() + t.Cleanup(setWithinRadiusFunc(func(_ *DB, _ shed.Item) bool { return false })) + + db := newTestDB(t, &Options{ + Capacity: 150, + }) + + ch := generateTestRandomChunk() + // call unreserve on the batch with radius 0 so that + // localstore is aware of the batch and the chunk can + // be inserted into the database + unreserveChunkBatch(t, db, 0, ch) + + addr := ch.Address() + _, err := db.Put(ctx, storage.ModePutUpload, ch) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "putUpload", db, 1, 0, 1, 1, 0, 0) + + err = db.Set(ctx, storage.ModeSetPin, addr) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "setPin", db, 1, 0, 1, 1, 1, 0) + + err = db.Set(ctx, storage.ModeSetPin, addr) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "setPin 2", db, 1, 0, 1, 1, 1, 0) + + err = db.Set(ctx, storage.ModeSetUnpin, addr) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "setUnPin", db, 1, 0, 1, 1, 1, 0) + + err = db.Set(ctx, storage.ModeSetUnpin, addr) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "setUnPin 2", db, 1, 0, 1, 1, 0, 0) + + err = db.Set(ctx, storage.ModeSetPin, addr) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "setPin 3", db, 1, 0, 1, 1, 1, 0) + + err = db.Set(ctx, storage.ModeSetSync, addr) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "setSync", db, 1, 1, 0, 1, 1, 0) + + err = db.Set(ctx, storage.ModeSetUnpin, addr) + if err != nil { + t.Fatal(err) + } + runCountsTest(t, "setUnPin", db, 1, 1, 0, 1, 0, 1) +} + +func runCountsTest(t *testing.T, name string, db *DB, r, a, push, pull, pin, gc int) { + t.Helper() + t.Run(name, func(t *testing.T) { + t.Helper() + t.Run("retrieval data Index count", newItemsCountTest(db.retrievalDataIndex, r)) + t.Run("retrieval access Index count", newItemsCountTest(db.retrievalAccessIndex, a)) + t.Run("push Index count", newItemsCountTest(db.pushIndex, push)) + t.Run("pull Index count", newItemsCountTest(db.pullIndex, pull)) + t.Run("pin Index count", newItemsCountTest(db.pinIndex, pin)) + t.Run("gc index count", newItemsCountTest(db.gcIndex, gc)) + t.Run("gc size", newIndexGCSizeTest(db)) + }) +} diff --git a/pkg/localstore/reserve.go b/pkg/localstore/reserve.go new file mode 100644 index 00000000000..94755e0bd25 --- /dev/null +++ b/pkg/localstore/reserve.go @@ -0,0 +1,93 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package localstore + +import ( + "errors" + "fmt" + + "github.com/ethersphere/bee/pkg/shed" + "github.com/ethersphere/bee/pkg/swarm" + "github.com/syndtr/goleveldb/leveldb" +) + +// UnreserveBatch atomically unpins chunks of a batch in proximity order upto and including po. +// Unpinning will result in all chunks with pincounter 0 to be put in the gc index +// so if a chunk was only pinned by the reserve, unreserving it will make it gc-able. +func (db *DB) UnreserveBatch(id []byte, radius uint8) error { + db.batchMu.Lock() + defer db.batchMu.Unlock() + var ( + item = shed.Item{ + BatchID: id, + } + batch = new(leveldb.Batch) + ) + + i, err := db.postageRadiusIndex.Get(item) + if err != nil { + if !errors.Is(err, leveldb.ErrNotFound) { + return err + } + item.Radius = radius + if err := db.postageRadiusIndex.PutInBatch(batch, item); err != nil { + return err + } + return db.shed.WriteBatch(batch) + } + oldRadius := i.Radius + var gcSizeChange int64 // number to add or subtract from gcSize + unpin := func(item shed.Item) (stop bool, err error) { + c, err := db.setUnpin(batch, swarm.NewAddress(item.Address)) + if err != nil { + return false, fmt.Errorf("unpin: %w", err) + } + + gcSizeChange += c + return false, err + } + + // iterate over chunk in bins + for bin := oldRadius; bin < radius; bin++ { + err := db.postageChunksIndex.Iterate(unpin, &shed.IterateOptions{Prefix: append(id, bin)}) + if err != nil { + return err + } + // adjust gcSize + if err := db.incGCSizeInBatch(batch, gcSizeChange); err != nil { + return err + } + item.Radius = bin + if err := db.postageRadiusIndex.PutInBatch(batch, item); err != nil { + return err + } + if bin == swarm.MaxPO { + if err := db.postageRadiusIndex.DeleteInBatch(batch, item); err != nil { + return err + } + } + if err := db.shed.WriteBatch(batch); err != nil { + return err + } + batch = new(leveldb.Batch) + gcSizeChange = 0 + } + + gcSize, err := db.gcSize.Get() + if err != nil && !errors.Is(err, leveldb.ErrNotFound) { + return err + } + // trigger garbage collection if we reached the capacity + if gcSize >= db.capacity { + db.triggerGarbageCollection() + } + + return nil +} + +func withinRadius(db *DB, item shed.Item) bool { + po := db.po(swarm.NewAddress(item.Address)) + return po >= item.Radius +} diff --git a/pkg/localstore/reserve_test.go b/pkg/localstore/reserve_test.go new file mode 100644 index 00000000000..6fce295551c --- /dev/null +++ b/pkg/localstore/reserve_test.go @@ -0,0 +1,430 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package localstore + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/ethersphere/bee/pkg/shed" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" + "github.com/syndtr/goleveldb/leveldb" +) + +// TestDB_ReserveGC_AllOutOfRadius tests that when all chunks fall outside of +// batch radius, all end up in the cache and that gc size eventually +// converges to the correct value. +func TestDB_ReserveGC_AllOutOfRadius(t *testing.T) { + chunkCount := 150 + + var closed chan struct{} + testHookCollectGarbageChan := make(chan uint64) + t.Cleanup(setTestHookCollectGarbage(func(collectedCount uint64) { + select { + case testHookCollectGarbageChan <- collectedCount: + case <-closed: + } + })) + + db := newTestDB(t, &Options{ + Capacity: 100, + }) + closed = db.close + + addrs := make([]swarm.Address, 0) + + for i := 0; i < chunkCount; i++ { + ch := generateTestRandomChunkAt(swarm.NewAddress(db.baseKey), 2).WithBatch(3, 3) + err := db.UnreserveBatch(ch.Stamp().BatchID(), 4) + if err != nil { + t.Fatal(err) + } + _, err = db.Put(context.Background(), storage.ModePutUpload, ch) + if err != nil { + t.Fatal(err) + } + err = db.Set(context.Background(), storage.ModeSetSync, ch.Address()) + if err != nil { + t.Fatal(err) + } + + addrs = append(addrs, ch.Address()) + } + + gcTarget := db.gcTarget() + + for { + select { + case <-testHookCollectGarbageChan: + case <-time.After(10 * time.Second): + t.Fatal("collect garbage timeout") + } + gcSize, err := db.gcSize.Get() + if err != nil { + t.Fatal(err) + } + if gcSize == gcTarget { + break + } + } + + t.Run("pull index count", newItemsCountTest(db.pullIndex, int(gcTarget))) + + t.Run("postage chunks index count", newItemsCountTest(db.postageChunksIndex, int(gcTarget))) + + // postageRadiusIndex gets removed only when the batches are called with evict on MaxPO+1 + // therefore, the expected index count here is larger than one would expect. + t.Run("postage radius index count", newItemsCountTest(db.postageRadiusIndex, chunkCount)) + + t.Run("gc index count", newItemsCountTest(db.gcIndex, int(gcTarget))) + + t.Run("gc size", newIndexGCSizeTest(db)) + + // the first synced chunk should be removed + t.Run("get the first synced chunk", func(t *testing.T) { + _, err := db.Get(context.Background(), storage.ModeGetRequest, addrs[0]) + if !errors.Is(err, storage.ErrNotFound) { + t.Errorf("got error %v, want %v", err, storage.ErrNotFound) + } + }) + + t.Run("only first inserted chunks should be removed", func(t *testing.T) { + for i := 0; i < (chunkCount - int(gcTarget)); i++ { + _, err := db.Get(context.Background(), storage.ModeGetRequest, addrs[i]) + if !errors.Is(err, storage.ErrNotFound) { + t.Errorf("got error %v, want %v", err, storage.ErrNotFound) + } + } + }) + + // last synced chunk should not be removed + t.Run("get most recent synced chunk", func(t *testing.T) { + _, err := db.Get(context.Background(), storage.ModeGetRequest, addrs[len(addrs)-1]) + if err != nil { + t.Fatal(err) + } + }) +} + +// TestDB_ReserveGC_AllWithinRadius tests that when all chunks fall within +// batch radius, none get collected. +func TestDB_ReserveGC_AllWithinRadius(t *testing.T) { + chunkCount := 150 + + var closed chan struct{} + testHookCollectGarbageChan := make(chan uint64) + t.Cleanup(setTestHookCollectGarbage(func(collectedCount uint64) { + select { + case testHookCollectGarbageChan <- collectedCount: + case <-closed: + } + })) + + db := newTestDB(t, &Options{ + Capacity: 100, + }) + closed = db.close + + addrs := make([]swarm.Address, 0) + + for i := 0; i < chunkCount; i++ { + ch := generateTestRandomChunkAt(swarm.NewAddress(db.baseKey), 2).WithBatch(2, 3) + err := db.UnreserveBatch(ch.Stamp().BatchID(), 2) + if err != nil { + t.Fatal(err) + } + _, err = db.Put(context.Background(), storage.ModePutUpload, ch) + if err != nil { + t.Fatal(err) + } + err = db.Set(context.Background(), storage.ModeSetSync, ch.Address()) + if err != nil { + t.Fatal(err) + } + + addrs = append(addrs, ch.Address()) + } + + select { + case <-testHookCollectGarbageChan: + t.Fatal("gc ran but shouldnt have") + case <-time.After(1 * time.Second): + } + + t.Run("pull index count", newItemsCountTest(db.pullIndex, chunkCount)) + + t.Run("postage chunks index count", newItemsCountTest(db.postageChunksIndex, chunkCount)) + + t.Run("postage radius index count", newItemsCountTest(db.postageRadiusIndex, chunkCount)) + + t.Run("gc index count", newItemsCountTest(db.gcIndex, 0)) + + t.Run("gc size", newIndexGCSizeTest(db)) + + t.Run("all chunks should be accessible", func(t *testing.T) { + for _, a := range addrs { + _, err := db.Get(context.Background(), storage.ModeGetRequest, a) + if err != nil { + t.Errorf("got error %v, want none", err) + } + } + }) +} + +// TestDB_ReserveGC_Unreserve tests that after calling UnreserveBatch +// with a certain radius change, the correct chunks get put into the +// GC index and eventually get garbage collected. +// batch radius, none get collected. +func TestDB_ReserveGC_Unreserve(t *testing.T) { + chunkCount := 150 + + var closed chan struct{} + testHookCollectGarbageChan := make(chan uint64) + t.Cleanup(setTestHookCollectGarbage(func(collectedCount uint64) { + select { + case testHookCollectGarbageChan <- collectedCount: + case <-closed: + } + })) + + db := newTestDB(t, &Options{ + Capacity: 100, + }) + closed = db.close + + // put the first chunkCount chunks within radius + for i := 0; i < chunkCount; i++ { + ch := generateTestRandomChunkAt(swarm.NewAddress(db.baseKey), 2).WithBatch(2, 3) + err := db.UnreserveBatch(ch.Stamp().BatchID(), 2) + if err != nil { + t.Fatal(err) + } + _, err = db.Put(context.Background(), storage.ModePutUpload, ch) + if err != nil { + t.Fatal(err) + } + err = db.Set(context.Background(), storage.ModeSetSync, ch.Address()) + if err != nil { + t.Fatal(err) + } + } + + var po4Chs []swarm.Chunk + for i := 0; i < chunkCount; i++ { + ch := generateTestRandomChunkAt(swarm.NewAddress(db.baseKey), 4).WithBatch(2, 3) + err := db.UnreserveBatch(ch.Stamp().BatchID(), 2) + if err != nil { + t.Fatal(err) + } + _, err = db.Put(context.Background(), storage.ModePutUpload, ch) + if err != nil { + t.Fatal(err) + } + err = db.Set(context.Background(), storage.ModeSetSync, ch.Address()) + if err != nil { + t.Fatal(err) + } + po4Chs = append(po4Chs, ch) + } + + var gcChs []swarm.Chunk + for i := 0; i < 100; i++ { + gcch := generateTestRandomChunkAt(swarm.NewAddress(db.baseKey), 2).WithBatch(2, 3) + err := db.UnreserveBatch(gcch.Stamp().BatchID(), 2) + if err != nil { + t.Fatal(err) + } + _, err = db.Put(context.Background(), storage.ModePutUpload, gcch) + if err != nil { + t.Fatal(err) + } + err = db.Set(context.Background(), storage.ModeSetSync, gcch.Address()) + if err != nil { + t.Fatal(err) + } + gcChs = append(gcChs, gcch) + } + + // radius increases from 2 to 3, chunk is in PO 2, therefore it should be + // GCd + for _, ch := range gcChs { + err := db.UnreserveBatch(ch.Stamp().BatchID(), 3) + if err != nil { + t.Fatal(err) + } + } + + gcTarget := db.gcTarget() + + for { + select { + case <-testHookCollectGarbageChan: + case <-time.After(10 * time.Second): + t.Fatal("collect garbage timeout") + } + gcSize, err := db.gcSize.Get() + if err != nil { + t.Fatal(err) + } + if gcSize == gcTarget { + break + } + } + t.Run("pull index count", newItemsCountTest(db.pullIndex, chunkCount*2+90)) + + t.Run("postage chunks index count", newItemsCountTest(db.postageChunksIndex, chunkCount*2+90)) + + // postageRadiusIndex gets removed only when the batches are called with evict on MaxPO+1 + // therefore, the expected index count here is larger than one would expect. + t.Run("postage radius index count", newItemsCountTest(db.postageRadiusIndex, chunkCount*2+100)) + + t.Run("gc index count", newItemsCountTest(db.gcIndex, 90)) + + t.Run("gc size", newIndexGCSizeTest(db)) + + t.Run("first ten unreserved chunks should not be accessible", func(t *testing.T) { + for _, ch := range gcChs[:10] { + _, err := db.Get(context.Background(), storage.ModeGetRequest, ch.Address()) + if err == nil { + t.Error("got no error, want NotFound") + } + } + }) + + t.Run("the rest should be accessible", func(t *testing.T) { + for _, ch := range gcChs[10:] { + _, err := db.Get(context.Background(), storage.ModeGetRequest, ch.Address()) + if err != nil { + t.Errorf("got error %v but want none", err) + } + } + }) + + t.Run("po 4 chunks accessible", func(t *testing.T) { + for _, ch := range po4Chs { + _, err := db.Get(context.Background(), storage.ModeGetRequest, ch.Address()) + if err != nil { + t.Errorf("got error %v but want none", err) + } + } + }) +} + +// TestDB_ReserveGC_EvictMaxPO tests that when unreserving a batch at +// swarm.MaxPO+1 results in the correct behaviour. +func TestDB_ReserveGC_EvictMaxPO(t *testing.T) { + chunkCount := 150 + + var closed chan struct{} + testHookCollectGarbageChan := make(chan uint64) + t.Cleanup(setTestHookCollectGarbage(func(collectedCount uint64) { + select { + case testHookCollectGarbageChan <- collectedCount: + case <-closed: + } + })) + + db := newTestDB(t, &Options{ + Capacity: 100, + }) + closed = db.close + + // put the first chunkCount chunks within radius + for i := 0; i < chunkCount; i++ { + ch := generateTestRandomChunkAt(swarm.NewAddress(db.baseKey), 2).WithBatch(2, 3) + err := db.UnreserveBatch(ch.Stamp().BatchID(), 2) + if err != nil { + t.Fatal(err) + } + _, err = db.Put(context.Background(), storage.ModePutUpload, ch) + if err != nil { + t.Fatal(err) + } + err = db.Set(context.Background(), storage.ModeSetSync, ch.Address()) + if err != nil { + t.Fatal(err) + } + } + + var gcChs []swarm.Chunk + for i := 0; i < 100; i++ { + gcch := generateTestRandomChunkAt(swarm.NewAddress(db.baseKey), 2).WithBatch(2, 3) + err := db.UnreserveBatch(gcch.Stamp().BatchID(), 2) + if err != nil { + t.Fatal(err) + } + _, err = db.Put(context.Background(), storage.ModePutUpload, gcch) + if err != nil { + t.Fatal(err) + } + err = db.Set(context.Background(), storage.ModeSetSync, gcch.Address()) + if err != nil { + t.Fatal(err) + } + gcChs = append(gcChs, gcch) + } + + for _, ch := range gcChs { + err := db.UnreserveBatch(ch.Stamp().BatchID(), swarm.MaxPO+1) + if err != nil { + t.Fatal(err) + } + } + + gcTarget := db.gcTarget() + + for { + select { + case <-testHookCollectGarbageChan: + case <-time.After(10 * time.Second): + t.Fatal("collect garbage timeout") + } + gcSize, err := db.gcSize.Get() + if err != nil { + t.Fatal(err) + } + if gcSize == gcTarget { + break + } + } + t.Run("pull index count", newItemsCountTest(db.pullIndex, chunkCount+90)) + + t.Run("postage chunks index count", newItemsCountTest(db.postageChunksIndex, chunkCount+90)) + + t.Run("postage radius index count", newItemsCountTest(db.postageRadiusIndex, chunkCount)) + + t.Run("gc index count", newItemsCountTest(db.gcIndex, 90)) + + t.Run("gc size", newIndexGCSizeTest(db)) + + t.Run("first ten unreserved chunks should not be accessible", func(t *testing.T) { + for _, ch := range gcChs[:10] { + _, err := db.Get(context.Background(), storage.ModeGetRequest, ch.Address()) + if err == nil { + t.Error("got no error, want NotFound") + } + } + }) + + t.Run("the rest should be accessible", func(t *testing.T) { + for _, ch := range gcChs[10:] { + _, err := db.Get(context.Background(), storage.ModeGetRequest, ch.Address()) + if err != nil { + t.Errorf("got error %v but want none", err) + } + } + }) + t.Run("batches for the all evicted batches should be evicted", func(t *testing.T) { + for _, ch := range gcChs { + item := shed.Item{BatchID: ch.Stamp().BatchID()} + if _, err := db.postageRadiusIndex.Get(item); !errors.Is(err, leveldb.ErrNotFound) { + t.Fatalf("wanted ErrNotFound but got %v", err) + } + } + }) +} diff --git a/pkg/localstore/subscription_push.go b/pkg/localstore/subscription_push.go index 337ebdacc5d..b3bc83096d8 100644 --- a/pkg/localstore/subscription_push.go +++ b/pkg/localstore/subscription_push.go @@ -22,6 +22,7 @@ import ( "time" "github.com/ethersphere/bee/pkg/flipflop" + "github.com/ethersphere/bee/pkg/postage" "github.com/ethersphere/bee/pkg/shed" "github.com/ethersphere/bee/pkg/swarm" ) @@ -75,8 +76,9 @@ func (db *DB) SubscribePush(ctx context.Context) (c <-chan swarm.Chunk, stop fun return true, err } + stamp := postage.NewStamp(dataItem.BatchID, dataItem.Sig) select { - case chunks <- swarm.NewChunk(swarm.NewAddress(dataItem.Address), dataItem.Data).WithTagID(item.Tag): + case chunks <- swarm.NewChunk(swarm.NewAddress(dataItem.Address), dataItem.Data).WithTagID(item.Tag).WithStamp(stamp): count++ // set next iteration start item // when its chunk is successfully sent to channel diff --git a/pkg/localstore/subscription_push_test.go b/pkg/localstore/subscription_push_test.go index 83318902b9f..e65c00df0f7 100644 --- a/pkg/localstore/subscription_push_test.go +++ b/pkg/localstore/subscription_push_test.go @@ -19,6 +19,7 @@ package localstore import ( "bytes" "context" + "errors" "fmt" "sync" "testing" @@ -74,8 +75,11 @@ func TestDB_SubscribePush(t *testing.T) { // receive and validate addresses from the subscription go func() { - var err error - var i int // address index + var ( + err, ierr error + i int // address index + gotStamp, wantStamp []byte + ) for { select { case got, ok := <-ch: @@ -93,6 +97,16 @@ func TestDB_SubscribePush(t *testing.T) { if !got.Address().Equal(want.Address()) { err = fmt.Errorf("got chunk %v address %s, want %s", i, got.Address(), want.Address()) } + if gotStamp, ierr = got.Stamp().MarshalBinary(); ierr != nil { + err = ierr + } + if wantStamp, ierr = want.Stamp().MarshalBinary(); ierr != nil { + err = ierr + } + if !bytes.Equal(gotStamp, wantStamp) { + err = errors.New("stamps don't match") + } + i++ // send one and only one error per received address select { diff --git a/pkg/netstore/netstore.go b/pkg/netstore/netstore.go index 8a706687df3..e9dd331a0b2 100644 --- a/pkg/netstore/netstore.go +++ b/pkg/netstore/netstore.go @@ -25,6 +25,7 @@ type store struct { storage.Storer retrieval retrieval.Interface logger logging.Logger + validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error) recoveryCallback recovery.Callback // this is the callback to be executed when a chunk fails to be retrieved } @@ -33,8 +34,8 @@ var ( ) // New returns a new NetStore that wraps a given Storer. -func New(s storage.Storer, rcb recovery.Callback, r retrieval.Interface, logger logging.Logger) storage.Storer { - return &store{Storer: s, recoveryCallback: rcb, retrieval: r, logger: logger} +func New(s storage.Storer, validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error), rcb recovery.Callback, r retrieval.Interface, logger logging.Logger) storage.Storer { + return &store{Storer: s, validStamp: validStamp, recoveryCallback: rcb, retrieval: r, logger: logger} } // Get retrieves a given chunk address. @@ -55,13 +56,25 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres go s.recoveryCallback(addr, targets) return nil, ErrRecoveryAttempt } + stamp, err := ch.Stamp().MarshalBinary() + if err != nil { + return nil, err + } putMode := storage.ModePutRequest if mode == storage.ModeGetRequestPin { putMode = storage.ModePutRequestPin } - _, err = s.Storer.Put(ctx, putMode, ch) + cch, err := s.validStamp(ch, stamp) + if err != nil { + // if a chunk with an invalid postage stamp was received + // we force it into the cache. + putMode = storage.ModePutRequestCache + cch = ch + } + + _, err = s.Storer.Put(ctx, putMode, cch) if err != nil { return nil, fmt.Errorf("netstore retrieve put: %w", err) } diff --git a/pkg/netstore/netstore_test.go b/pkg/netstore/netstore_test.go index c677a4e0e11..c56c0c1dc31 100644 --- a/pkg/netstore/netstore_test.go +++ b/pkg/netstore/netstore_test.go @@ -16,6 +16,7 @@ import ( "github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/netstore" + postagetesting "github.com/ethersphere/bee/pkg/postage/testing" "github.com/ethersphere/bee/pkg/pss" "github.com/ethersphere/bee/pkg/recovery" "github.com/ethersphere/bee/pkg/sctx" @@ -25,11 +26,12 @@ import ( ) var chunkData = []byte("mockdata") +var chunkStamp = postagetesting.MustNewStamp() // TestNetstoreRetrieval verifies that a chunk is asked from the network whenever // it is not found locally func TestNetstoreRetrieval(t *testing.T) { - retrieve, store, nstore := newRetrievingNetstore(nil) + retrieve, store, nstore := newRetrievingNetstore(nil, noopValidStamp) addr := swarm.MustParseHexAddress("000001") _, err := nstore.Get(context.Background(), storage.ModeGetRequest, addr) if err != nil { @@ -73,7 +75,7 @@ func TestNetstoreRetrieval(t *testing.T) { // TestNetstoreNoRetrieval verifies that a chunk is not requested from the network // whenever it is found locally. func TestNetstoreNoRetrieval(t *testing.T) { - retrieve, store, nstore := newRetrievingNetstore(nil) + retrieve, store, nstore := newRetrievingNetstore(nil, noopValidStamp) addr := swarm.MustParseHexAddress("000001") // store should have the chunk in advance @@ -103,7 +105,7 @@ func TestRecovery(t *testing.T) { callbackC: callbackWasCalled, } - retrieve, _, nstore := newRetrievingNetstore(rec.recovery) + retrieve, _, nstore := newRetrievingNetstore(rec.recovery, noopValidStamp) addr := swarm.MustParseHexAddress("deadbeef") retrieve.failure = true ctx := context.Background() @@ -123,7 +125,7 @@ func TestRecovery(t *testing.T) { } func TestInvalidRecoveryFunction(t *testing.T) { - retrieve, _, nstore := newRetrievingNetstore(nil) + retrieve, _, nstore := newRetrievingNetstore(nil, noopValidStamp) addr := swarm.MustParseHexAddress("deadbeef") retrieve.failure = true ctx := context.Background() @@ -135,12 +137,60 @@ func TestInvalidRecoveryFunction(t *testing.T) { } } +func TestInvalidPostageStamp(t *testing.T) { + f := func(c swarm.Chunk, _ []byte) (swarm.Chunk, error) { + return nil, errors.New("invalid postage stamp") + } + retrieve, store, nstore := newRetrievingNetstore(nil, f) + addr := swarm.MustParseHexAddress("000001") + _, err := nstore.Get(context.Background(), storage.ModeGetRequest, addr) + if err != nil { + t.Fatal(err) + } + if !retrieve.called { + t.Fatal("retrieve request not issued") + } + if retrieve.callCount != 1 { + t.Fatalf("call count %d", retrieve.callCount) + } + if !retrieve.addr.Equal(addr) { + t.Fatalf("addresses not equal. got %s want %s", retrieve.addr, addr) + } + + // store should have the chunk now + d, err := store.Get(context.Background(), storage.ModeGetRequest, addr) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(d.Data(), chunkData) { + t.Fatal("chunk data not equal to expected data") + } + + if mode := store.GetModePut(addr); mode != storage.ModePutRequestCache { + t.Fatalf("wanted ModePutRequestCache but got %s", mode) + } + + // check that the second call does not result in another retrieve request + d, err = nstore.Get(context.Background(), storage.ModeGetRequest, addr) + if err != nil { + t.Fatal(err) + } + + if retrieve.callCount != 1 { + t.Fatalf("call count %d", retrieve.callCount) + } + if !bytes.Equal(d.Data(), chunkData) { + t.Fatal("chunk data not equal to expected data") + } +} + // returns a mock retrieval protocol, a mock local storage and a netstore -func newRetrievingNetstore(rec recovery.Callback) (ret *retrievalMock, mockStore, ns storage.Storer) { +func newRetrievingNetstore(rec recovery.Callback, validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error)) (ret *retrievalMock, mockStore *mock.MockStorer, ns storage.Storer) { retrieve := &retrievalMock{} store := mock.NewStorer() logger := logging.New(ioutil.Discard, 0) - return retrieve, store, netstore.New(store, rec, retrieve, logger) + return retrieve, store, netstore.New(store, validStamp, rec, retrieve, logger) } type retrievalMock struct { @@ -157,7 +207,7 @@ func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address) ( r.called = true atomic.AddInt32(&r.callCount, 1) r.addr = addr - return swarm.NewChunk(addr, chunkData), nil + return swarm.NewChunk(addr, chunkData).WithStamp(chunkStamp), nil } type mockRecovery struct { @@ -172,3 +222,7 @@ func (mr *mockRecovery) recovery(chunkAddress swarm.Address, targets pss.Targets func (r *mockRecovery) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) { return nil, fmt.Errorf("chunk not found") } + +var noopValidStamp = func(c swarm.Chunk, _ []byte) (swarm.Chunk, error) { + return c, nil +} diff --git a/pkg/node/node.go b/pkg/node/node.go index f70c8892bc2..95b1832ffec 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -10,6 +10,7 @@ package node import ( "context" "crypto/ecdsa" + "errors" "fmt" "io" "log" @@ -20,7 +21,6 @@ import ( "time" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/ethclient" "github.com/ethersphere/bee/pkg/accounting" "github.com/ethersphere/bee/pkg/addressbook" "github.com/ethersphere/bee/pkg/api" @@ -36,6 +36,11 @@ import ( "github.com/ethersphere/bee/pkg/p2p/libp2p" "github.com/ethersphere/bee/pkg/pingpong" "github.com/ethersphere/bee/pkg/pinning" + "github.com/ethersphere/bee/pkg/postage" + "github.com/ethersphere/bee/pkg/postage/batchservice" + "github.com/ethersphere/bee/pkg/postage/batchstore" + "github.com/ethersphere/bee/pkg/postage/listener" + "github.com/ethersphere/bee/pkg/postage/postagecontract" "github.com/ethersphere/bee/pkg/pricer" "github.com/ethersphere/bee/pkg/pricing" "github.com/ethersphere/bee/pkg/pss" @@ -51,7 +56,6 @@ import ( "github.com/ethersphere/bee/pkg/settlement/pseudosettle" "github.com/ethersphere/bee/pkg/settlement/swap" "github.com/ethersphere/bee/pkg/settlement/swap/chequebook" - "github.com/ethersphere/bee/pkg/settlement/swap/transaction" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/tags" @@ -84,6 +88,7 @@ type Bee struct { ethClientCloser func() transactionMonitorCloser io.Closer recoveryHandleCleanup func() + listenerCloser io.Closer } type Options struct { @@ -119,6 +124,8 @@ type Options struct { SwapInitialDeposit string SwapEnable bool FullNodeMode bool + PostageContractAddress string + PriceOracleAddress string } func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, signer crypto.Signer, networkID uint64, logger logging.Logger, libp2pPrivateKey, pssPrivateKey *ecdsa.PrivateKey, o Options) (b *Bee, err error) { @@ -193,30 +200,26 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, addressbook := addressbook.New(stateStore) - var swapBackend *ethclient.Client - var overlayEthAddress common.Address - var chainID int64 - var transactionService transaction.Service - var transactionMonitor transaction.Monitor var chequebookFactory chequebook.Factory var chequebookService chequebook.Service var chequeStore chequebook.ChequeStore var cashoutService chequebook.CashoutService - if o.SwapEnable { - swapBackend, overlayEthAddress, chainID, transactionMonitor, transactionService, err = InitChain( - p2pCtx, - logger, - stateStore, - o.SwapEndpoint, - signer, - ) - if err != nil { - return nil, err - } - b.ethClientCloser = swapBackend.Close - b.transactionMonitorCloser = transactionMonitor + swapBackend, overlayEthAddress, chainID, transactionMonitor, transactionService, err := InitChain( + p2pCtx, + logger, + stateStore, + o.SwapEndpoint, + signer, + ) + if err != nil { + return nil, err + } + + b.ethClientCloser = swapBackend.Close + b.transactionMonitorCloser = transactionMonitor + if o.SwapEnable { chequebookFactory, err = InitChequebookFactory( logger, swapBackend, @@ -274,6 +277,75 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, } b.p2pService = p2ps + // localstore depends on batchstore + var path string + + if o.DataDir != "" { + path = filepath.Join(o.DataDir, "localstore") + } + lo := &localstore.Options{ + Capacity: o.DBCapacity, + OpenFilesLimit: o.DBOpenFilesLimit, + BlockCacheCapacity: o.DBBlockCacheCapacity, + WriteBufferSize: o.DBWriteBufferSize, + DisableSeeksCompaction: o.DBDisableSeeksCompaction, + } + + storer, err := localstore.New(path, swarmAddress.Bytes(), lo, logger) + if err != nil { + return nil, fmt.Errorf("localstore: %w", err) + } + b.localstoreCloser = storer + + batchStore, err := batchstore.New(stateStore, storer.UnreserveBatch) + if err != nil { + return nil, fmt.Errorf("batchstore: %w", err) + } + validStamp := postage.ValidStamp(batchStore) + post := postage.NewService(stateStore, chainID) + + var ( + postageContractService postagecontract.Interface + batchSvc postage.EventUpdater + ) + + if !o.Standalone { + postageContractAddress, priceOracleAddress, found := listener.DiscoverAddresses(chainID) + if o.PostageContractAddress != "" { + if !common.IsHexAddress(o.PostageContractAddress) { + return nil, errors.New("malformed postage stamp address") + } + postageContractAddress = common.HexToAddress(o.PostageContractAddress) + } + if o.PriceOracleAddress != "" { + if !common.IsHexAddress(o.PriceOracleAddress) { + return nil, errors.New("malformed price oracle address") + } + priceOracleAddress = common.HexToAddress(o.PriceOracleAddress) + } + if (o.PostageContractAddress == "" || o.PriceOracleAddress == "") && !found { + return nil, errors.New("no known postage stamp addresses for this network") + } + + eventListener := listener.New(logger, swapBackend, postageContractAddress, priceOracleAddress) + b.listenerCloser = eventListener + + batchSvc = batchservice.New(batchStore, logger, eventListener) + + erc20Address, err := postagecontract.LookupERC20Address(p2pCtx, transactionService, postageContractAddress) + if err != nil { + return nil, err + } + + postageContractService = postagecontract.New( + overlayEthAddress, + postageContractAddress, + erc20Address, + transactionService, + post, + ) + } + if !o.Standalone { if natManager := p2ps.NATManager(); natManager != nil { // wait for nat manager to init @@ -325,6 +397,17 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, b.topologyCloser = kad hive.SetAddPeersHandler(kad.AddPeers) p2ps.SetPickyNotifier(kad) + batchStore.SetRadiusSetter(kad) + syncedChan := batchSvc.Start() + + // wait for the postage contract listener to sync + logger.Info("waiting to sync postage contract data, this may take a while... more info available in Debug loglevel") + + // arguably this is not a very nice solution since we dont support + // interrupts at this stage of the application lifecycle. some changes + // would be needed on the cmd level to support context cancellation at + // this stage + <-syncedChan paymentThreshold, ok := new(big.Int).SetString(o.PaymentThreshold, 10) if !ok { @@ -396,24 +479,6 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, pricing.SetPaymentThresholdObserver(acc) settlement.SetNotifyPaymentFunc(acc.AsyncNotifyPayment) - var path string - - if o.DataDir != "" { - path = filepath.Join(o.DataDir, "localstore") - } - lo := &localstore.Options{ - Capacity: o.DBCapacity, - OpenFilesLimit: o.DBOpenFilesLimit, - BlockCacheCapacity: o.DBBlockCacheCapacity, - WriteBufferSize: o.DBWriteBufferSize, - DisableSeeksCompaction: o.DBDisableSeeksCompaction, - } - storer, err := localstore.New(path, swarmAddress.Bytes(), lo, logger) - if err != nil { - return nil, fmt.Errorf("localstore: %w", err) - } - b.localstoreCloser = storer - retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, pricer, tracer) tagService := tags.NewTags(stateStore, logger) b.tagsCloser = tagService @@ -425,16 +490,16 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, if o.GlobalPinningEnabled { // create recovery callback for content repair recoverFunc := recovery.NewCallback(pssService) - ns = netstore.New(storer, recoverFunc, retrieve, logger) + ns = netstore.New(storer, validStamp, recoverFunc, retrieve, logger) } else { - ns = netstore.New(storer, nil, retrieve, logger) + ns = netstore.New(storer, validStamp, nil, retrieve, logger) } traversalService := traversal.NewService(ns) pinningService := pinning.NewService(storer, stateStore, traversalService) - pushSyncProtocol := pushsync.New(swarmAddress, p2ps, storer, kad, tagService, o.FullNodeMode, pssService.TryUnwrap, logger, acc, pricer, signer, tracer) + pushSyncProtocol := pushsync.New(swarmAddress, p2ps, storer, kad, tagService, o.FullNodeMode, pssService.TryUnwrap, validStamp, logger, acc, pricer, signer, tracer) // set the pushSyncer in the PSS pssService.SetPushSyncer(pushSyncProtocol) @@ -450,7 +515,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, pullStorage := pullstorage.New(storer) - pullSyncProtocol := pullsync.New(p2ps, pullStorage, pssService.TryUnwrap, logger) + pullSyncProtocol := pullsync.New(p2ps, pullStorage, pssService.TryUnwrap, validStamp, logger) b.pullSyncCloser = pullSyncProtocol pullerService := puller.New(stateStore, kad, pullSyncProtocol, logger, puller.Options{}) @@ -489,7 +554,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, if o.APIAddr != "" { // API server feedFactory := factory.New(ns) - apiService = api.New(tagService, ns, multiResolver, pssService, traversalService, pinningService, feedFactory, logger, tracer, api.Options{ + apiService = api.New(tagService, ns, multiResolver, pssService, traversalService, pinningService, feedFactory, post, postageContractService, signer, logger, tracer, api.Options{ CORSAllowedOrigins: o.CORSAllowedOrigins, GatewayMode: o.GatewayMode, WsPingPeriod: 60 * time.Second, @@ -531,6 +596,10 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, debugAPIService.MustRegisterMetrics(pullSyncProtocol.Metrics()...) debugAPIService.MustRegisterMetrics(retrieve.Metrics()...) + if bs, ok := batchStore.(metrics.Collector); ok { + debugAPIService.MustRegisterMetrics(bs.Metrics()...) + } + if pssServiceMetrics, ok := pssService.(metrics.Collector); ok { debugAPIService.MustRegisterMetrics(pssServiceMetrics.Metrics()...) } @@ -547,7 +616,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, } // inject dependencies and configure full debug api http path routes - debugAPIService.Configure(p2ps, pingPong, kad, lightNodes, storer, tagService, acc, settlement, o.SwapEnable, swapService, chequebookService) + debugAPIService.Configure(p2ps, pingPong, kad, lightNodes, storer, tagService, acc, settlement, o.SwapEnable, swapService, chequebookService, batchStore) } if err := kad.Start(p2pCtx); err != nil { @@ -630,6 +699,12 @@ func (b *Bee) Shutdown(ctx context.Context) error { errs.add(fmt.Errorf("tag persistence: %w", err)) } + if b.listenerCloser != nil { + if err := b.listenerCloser.Close(); err != nil { + errs.add(fmt.Errorf("error listener: %w", err)) + } + } + if err := b.stateStoreCloser.Close(); err != nil { errs.add(fmt.Errorf("statestore: %w", err)) } diff --git a/pkg/postage/batch.go b/pkg/postage/batch.go new file mode 100644 index 00000000000..3ba87a98df8 --- /dev/null +++ b/pkg/postage/batch.go @@ -0,0 +1,45 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package postage + +import ( + "encoding/binary" + "math/big" +) + +// Batch represents a postage batch, a payment on the blockchain. +type Batch struct { + ID []byte // batch ID + Value *big.Int // normalised balance of the batch + Start uint64 // block number the batch was created + Owner []byte // owner's ethereum address + Depth uint8 // batch depth, i.e., size = 2^{depth} + Radius uint8 // reserve radius, non-serialised +} + +// MarshalBinary implements BinaryMarshaller. It will attempt to serialize the +// postage batch to a byte slice. +// serialised as ID(32)|big endian value(32)|start block(8)|owner addr(20)|depth(1) +func (b *Batch) MarshalBinary() ([]byte, error) { + out := make([]byte, 93) + copy(out, b.ID) + value := b.Value.Bytes() + copy(out[64-len(value):], value) + binary.BigEndian.PutUint64(out[64:72], b.Start) + copy(out[72:], b.Owner) + out[92] = b.Depth + return out, nil +} + +// UnmarshalBinary implements BinaryUnmarshaller. It will attempt deserialize +// the given byte slice into the batch. +func (b *Batch) UnmarshalBinary(buf []byte) error { + b.ID = buf[:32] + b.Value = big.NewInt(0).SetBytes(buf[32:64]) + b.Start = binary.BigEndian.Uint64(buf[64:72]) + b.Owner = buf[72:92] + b.Depth = buf[92] + return nil +} diff --git a/pkg/postage/batch_test.go b/pkg/postage/batch_test.go new file mode 100644 index 00000000000..90452fcbf43 --- /dev/null +++ b/pkg/postage/batch_test.go @@ -0,0 +1,45 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package postage_test + +import ( + "bytes" + "testing" + + "github.com/ethersphere/bee/pkg/postage" + postagetesting "github.com/ethersphere/bee/pkg/postage/testing" +) + +// TestBatchMarshalling tests the idempotence of binary marshal/unmarshal for a +// Batch. +func TestBatchMarshalling(t *testing.T) { + a := postagetesting.MustNewBatch() + buf, err := a.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(buf) != 93 { + t.Fatalf("invalid length for serialised batch. expected 93, got %d", len(buf)) + } + b := &postage.Batch{} + if err := b.UnmarshalBinary(buf); err != nil { + t.Fatalf("unexpected error unmarshalling batch: %v", err) + } + if !bytes.Equal(b.ID, a.ID) { + t.Fatalf("id mismatch, expected %x, got %x", a.ID, b.ID) + } + if !bytes.Equal(b.Owner, a.Owner) { + t.Fatalf("owner mismatch, expected %x, got %x", a.Owner, b.Owner) + } + if a.Value.Uint64() != b.Value.Uint64() { + t.Fatalf("value mismatch, expected %d, got %d", a.Value.Uint64(), b.Value.Uint64()) + } + if a.Start != b.Start { + t.Fatalf("start mismatch, expected %d, got %d", a.Start, b.Start) + } + if a.Depth != b.Depth { + t.Fatalf("depth mismatch, expected %d, got %d", a.Depth, b.Depth) + } +} diff --git a/pkg/postage/batchservice/batchservice.go b/pkg/postage/batchservice/batchservice.go new file mode 100644 index 00000000000..e13606f40c3 --- /dev/null +++ b/pkg/postage/batchservice/batchservice.go @@ -0,0 +1,114 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package batchservice + +import ( + "encoding/hex" + "fmt" + "math/big" + + "github.com/ethersphere/bee/pkg/logging" + "github.com/ethersphere/bee/pkg/postage" +) + +type batchService struct { + storer postage.Storer + logger logging.Logger + listener postage.Listener +} + +type Interface interface { + postage.EventUpdater +} + +// New will create a new BatchService. +func New(storer postage.Storer, logger logging.Logger, listener postage.Listener) Interface { + return &batchService{storer, logger, listener} +} + +// Create will create a new batch with the given ID, owner value and depth and +// stores it in the BatchStore. +func (svc *batchService) Create(id, owner []byte, normalisedBalance *big.Int, depth uint8) error { + b := &postage.Batch{ + ID: id, + Owner: owner, + Value: big.NewInt(0), + Start: svc.storer.GetChainState().Block, + Depth: depth, + } + + err := svc.storer.Put(b, normalisedBalance, depth) + if err != nil { + return fmt.Errorf("put: %w", err) + } + + svc.logger.Debugf("batch service: created batch id %s", hex.EncodeToString(b.ID)) + return nil +} + +// TopUp implements the EventUpdater interface. It tops ups a batch with the +// given ID with the given amount. +func (svc *batchService) TopUp(id []byte, normalisedBalance *big.Int) error { + b, err := svc.storer.Get(id) + if err != nil { + return fmt.Errorf("get: %w", err) + } + + err = svc.storer.Put(b, normalisedBalance, b.Depth) + if err != nil { + return fmt.Errorf("put: %w", err) + } + + svc.logger.Debugf("batch service: topped up batch id %s from %v to %v", hex.EncodeToString(b.ID), b.Value, normalisedBalance) + return nil +} + +// UpdateDepth implements the EventUpdater inteface. It sets the new depth of a +// batch with the given ID. +func (svc *batchService) UpdateDepth(id []byte, depth uint8, normalisedBalance *big.Int) error { + b, err := svc.storer.Get(id) + if err != nil { + return fmt.Errorf("get: %w", err) + } + err = svc.storer.Put(b, normalisedBalance, depth) + if err != nil { + return fmt.Errorf("put: %w", err) + } + + svc.logger.Debugf("batch service: updated depth of batch id %s from %d to %d", hex.EncodeToString(b.ID), b.Depth, depth) + return nil +} + +// UpdatePrice implements the EventUpdater interface. It sets the current +// price from the chain in the service chain state. +func (svc *batchService) UpdatePrice(price *big.Int) error { + cs := svc.storer.GetChainState() + cs.Price = price + if err := svc.storer.PutChainState(cs); err != nil { + return fmt.Errorf("put chain state: %w", err) + } + + svc.logger.Debugf("batch service: updated chain price to %s", price) + return nil +} + +func (svc *batchService) UpdateBlockNumber(blockNumber uint64) error { + cs := svc.storer.GetChainState() + diff := big.NewInt(0).SetUint64(blockNumber - cs.Block) + + cs.Total.Add(cs.Total, diff.Mul(diff, cs.Price)) + cs.Block = blockNumber + if err := svc.storer.PutChainState(cs); err != nil { + return fmt.Errorf("put chain state: %w", err) + } + + svc.logger.Debugf("batch service: updated block height to %d", blockNumber) + return nil +} + +func (svc *batchService) Start() <-chan struct{} { + cs := svc.storer.GetChainState() + return svc.listener.Listen(cs.Block+1, svc) +} diff --git a/pkg/postage/batchservice/batchservice_test.go b/pkg/postage/batchservice/batchservice_test.go new file mode 100644 index 00000000000..9d23947da33 --- /dev/null +++ b/pkg/postage/batchservice/batchservice_test.go @@ -0,0 +1,260 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package batchservice_test + +import ( + "bytes" + "errors" + "io/ioutil" + "math/big" + "testing" + + "github.com/ethersphere/bee/pkg/logging" + "github.com/ethersphere/bee/pkg/postage" + "github.com/ethersphere/bee/pkg/postage/batchservice" + "github.com/ethersphere/bee/pkg/postage/batchstore/mock" + postagetesting "github.com/ethersphere/bee/pkg/postage/testing" +) + +var ( + testLog = logging.New(ioutil.Discard, 0) + errTest = errors.New("fails") +) + +type mockListener struct { +} + +func (*mockListener) Listen(from uint64, updater postage.EventUpdater) <-chan struct{} { return nil } +func (*mockListener) Close() error { return nil } + +func newMockListener() *mockListener { + return &mockListener{} +} + +func TestBatchServiceCreate(t *testing.T) { + testBatch := postagetesting.MustNewBatch() + testChainState := postagetesting.NewChainState() + + t.Run("expect put create put error", func(t *testing.T) { + svc, _ := newTestStoreAndService( + mock.WithChainState(testChainState), + mock.WithPutErr(errTest, 0), + ) + + if err := svc.Create( + testBatch.ID, + testBatch.Owner, + testBatch.Value, + testBatch.Depth, + ); err == nil { + t.Fatalf("expected error") + } + }) + + t.Run("passes", func(t *testing.T) { + svc, batchStore := newTestStoreAndService( + mock.WithChainState(testChainState), + ) + + if err := svc.Create( + testBatch.ID, + testBatch.Owner, + testBatch.Value, + testBatch.Depth, + ); err != nil { + t.Fatalf("got error %v", err) + } + + got, err := batchStore.Get(testBatch.ID) + if err != nil { + t.Fatalf("batch store get: %v", err) + } + + if !bytes.Equal(got.ID, testBatch.ID) { + t.Fatalf("batch id: want %v, got %v", testBatch.ID, got.ID) + } + if !bytes.Equal(got.Owner, testBatch.Owner) { + t.Fatalf("batch owner: want %v, got %v", testBatch.Owner, got.Owner) + } + if got.Value.Cmp(testBatch.Value) != 0 { + t.Fatalf("batch value: want %v, got %v", testBatch.Value.String(), got.Value.String()) + } + if got.Depth != testBatch.Depth { + t.Fatalf("batch depth: want %v, got %v", got.Depth, testBatch.Depth) + } + if got.Start != testChainState.Block { + t.Fatalf("batch start block different form chain state: want %v, got %v", got.Start, testChainState.Block) + } + }) + +} + +func TestBatchServiceTopUp(t *testing.T) { + testBatch := postagetesting.MustNewBatch() + testNormalisedBalance := big.NewInt(2000000000000) + + t.Run("expect get error", func(t *testing.T) { + svc, _ := newTestStoreAndService( + mock.WithGetErr(errTest, 0), + ) + + if err := svc.TopUp(testBatch.ID, testNormalisedBalance); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("expect put error", func(t *testing.T) { + svc, batchStore := newTestStoreAndService( + mock.WithPutErr(errTest, 1), + ) + putBatch(t, batchStore, testBatch) + + if err := svc.TopUp(testBatch.ID, testNormalisedBalance); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("passes", func(t *testing.T) { + svc, batchStore := newTestStoreAndService() + putBatch(t, batchStore, testBatch) + + want := testNormalisedBalance + + if err := svc.TopUp(testBatch.ID, testNormalisedBalance); err != nil { + t.Fatalf("top up: %v", err) + } + + got, err := batchStore.Get(testBatch.ID) + if err != nil { + t.Fatalf("batch store get: %v", err) + } + + if got.Value.Cmp(want) != 0 { + t.Fatalf("topped up amount: got %v, want %v", got.Value, want) + } + }) +} + +func TestBatchServiceUpdateDepth(t *testing.T) { + const testNewDepth = 30 + testNormalisedBalance := big.NewInt(2000000000000) + testBatch := postagetesting.MustNewBatch() + + t.Run("expect get error", func(t *testing.T) { + svc, _ := newTestStoreAndService( + mock.WithGetErr(errTest, 0), + ) + + if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance); err == nil { + t.Fatal("expected get error") + } + }) + + t.Run("expect put error", func(t *testing.T) { + svc, batchStore := newTestStoreAndService( + mock.WithPutErr(errTest, 1), + ) + putBatch(t, batchStore, testBatch) + + if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance); err == nil { + t.Fatal("expected put error") + } + }) + + t.Run("passes", func(t *testing.T) { + svc, batchStore := newTestStoreAndService() + putBatch(t, batchStore, testBatch) + + if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance); err != nil { + t.Fatalf("update depth: %v", err) + } + + val, err := batchStore.Get(testBatch.ID) + if err != nil { + t.Fatalf("batch store get: %v", err) + } + + if val.Depth != testNewDepth { + t.Fatalf("wrong batch depth set: want %v, got %v", testNewDepth, val.Depth) + } + }) +} + +func TestBatchServiceUpdatePrice(t *testing.T) { + testChainState := postagetesting.NewChainState() + testChainState.Price = big.NewInt(100000) + testNewPrice := big.NewInt(20000000) + + t.Run("expect put error", func(t *testing.T) { + svc, batchStore := newTestStoreAndService( + mock.WithChainState(testChainState), + mock.WithPutErr(errTest, 1), + ) + putChainState(t, batchStore, testChainState) + + if err := svc.UpdatePrice(testNewPrice); err == nil { + t.Fatal("expected error") + } + }) + + t.Run("passes", func(t *testing.T) { + svc, batchStore := newTestStoreAndService( + mock.WithChainState(testChainState), + ) + + if err := svc.UpdatePrice(testNewPrice); err != nil { + t.Fatalf("update price: %v", err) + } + + cs := batchStore.GetChainState() + if cs.Price.Cmp(testNewPrice) != 0 { + t.Fatalf("bad price: want %v, got %v", cs.Price, testNewPrice) + } + }) +} +func TestBatchServiceUpdateBlockNumber(t *testing.T) { + testChainState := &postage.ChainState{ + Block: 1, + Price: big.NewInt(100), + Total: big.NewInt(100), + } + svc, batchStore := newTestStoreAndService( + mock.WithChainState(testChainState), + ) + + // advance the block number and expect total cumulative payout to update + nextBlock := uint64(4) + + if err := svc.UpdateBlockNumber(nextBlock); err != nil { + t.Fatalf("update price: %v", err) + } + nn := big.NewInt(400) + cs := batchStore.GetChainState() + if cs.Total.Cmp(nn) != 0 { + t.Fatalf("bad price: want %v, got %v", nn, cs.Total) + } +} + +func newTestStoreAndService(opts ...mock.Option) (postage.EventUpdater, postage.Storer) { + store := mock.New(opts...) + svc := batchservice.New(store, testLog, newMockListener()) + return svc, store +} + +func putBatch(t *testing.T, store postage.Storer, b *postage.Batch) { + t.Helper() + + if err := store.Put(b, big.NewInt(0), 0); err != nil { + t.Fatalf("store put batch: %v", err) + } +} + +func putChainState(t *testing.T, store postage.Storer, cs *postage.ChainState) { + t.Helper() + + if err := store.PutChainState(cs); err != nil { + t.Fatalf("store put chain state: %v", err) + } +} diff --git a/pkg/postage/batchstore/export_test.go b/pkg/postage/batchstore/export_test.go new file mode 100644 index 00000000000..0db306f0bd4 --- /dev/null +++ b/pkg/postage/batchstore/export_test.go @@ -0,0 +1,43 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package batchstore + +import ( + "fmt" + "math/big" + + "github.com/ethersphere/bee/pkg/postage" +) + +// ChainStateKey is the statestore key for the chain state. +const StateKey = chainStateKey + +// BatchKey returns the index key for the batch ID used in the by-ID batch index. +var BatchKey = batchKey + +// power of 2 function +var Exp2 = exp2 + +// iterates through all batches +func IterateAll(bs postage.Storer, f func(b *postage.Batch) (bool, error)) error { + s := bs.(*store) + return s.store.Iterate(batchKeyPrefix, func(key []byte, _ []byte) (bool, error) { + b, err := s.Get(key[len(key)-32:]) + if err != nil { + return true, err + } + return f(b) + }) +} + +// GetReserve extracts the inner limit and depth of reserve +func GetReserve(si postage.Storer) (*big.Int, uint8) { + s, _ := si.(*store) + return s.rs.Inner, s.rs.Radius +} + +func (s *store) String() string { + return fmt.Sprintf("inner=%d,outer=%d", s.rs.Inner.Uint64(), s.rs.Outer.Uint64()) +} diff --git a/pkg/postage/batchstore/metrics.go b/pkg/postage/batchstore/metrics.go new file mode 100644 index 00000000000..6068389fc5c --- /dev/null +++ b/pkg/postage/batchstore/metrics.go @@ -0,0 +1,52 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package batchstore + +import ( + m "github.com/ethersphere/bee/pkg/metrics" + "github.com/prometheus/client_golang/prometheus" +) + +type metrics struct { + AvailableCapacity prometheus.Gauge + Inner prometheus.Gauge + Outer prometheus.Gauge + Radius prometheus.Gauge +} + +func newMetrics() metrics { + subsystem := "batchstore" + + return metrics{ + AvailableCapacity: prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: m.Namespace, + Subsystem: subsystem, + Name: "available_capacity", + Help: "Available capacity observed by the batchstore.", + }), + Inner: prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: m.Namespace, + Subsystem: subsystem, + Name: "inner", + Help: "Inner storage tier value observed by the batchstore.", + }), + Outer: prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: m.Namespace, + Subsystem: subsystem, + Name: "outer", + Help: "Outer storage tier value observed by the batchstore.", + }), + Radius: prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: m.Namespace, + Subsystem: subsystem, + Name: "radius", + Help: "Radius of responsibility observed by the batchstore.", + }), + } +} + +func (s *store) Metrics() []prometheus.Collector { + return m.PrometheusCollectorsFromFields(s.metrics) +} diff --git a/pkg/postage/batchstore/mock/store.go b/pkg/postage/batchstore/mock/store.go new file mode 100644 index 00000000000..e6965de7a7c --- /dev/null +++ b/pkg/postage/batchstore/mock/store.go @@ -0,0 +1,129 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mock + +import ( + "bytes" + "errors" + "math/big" + + "github.com/ethersphere/bee/pkg/postage" +) + +var _ postage.Storer = (*BatchStore)(nil) + +// BatchStore is a mock BatchStorer +type BatchStore struct { + rs *postage.Reservestate + cs *postage.ChainState + id []byte + batch *postage.Batch + getErr error + getErrDelayCnt int + putErr error + putErrDelayCnt int +} + +// Option is a an option passed to New +type Option func(*BatchStore) + +// New creates a new mock BatchStore +func New(opts ...Option) *BatchStore { + bs := &BatchStore{} + bs.cs = &postage.ChainState{} + for _, o := range opts { + o(bs) + } + + return bs +} + +// WithChainState will set the initial chainstate in the ChainStore mock. +func WithReserveState(rs *postage.Reservestate) Option { + return func(bs *BatchStore) { + bs.rs = rs + } +} + +// WithChainState will set the initial chainstate in the ChainStore mock. +func WithChainState(cs *postage.ChainState) Option { + return func(bs *BatchStore) { + bs.cs = cs + } +} + +// WithGetErr will set the get error returned by the ChainStore mock. The error +// will be returned on each subsequent call after delayCnt calls to Get have +// been made. +func WithGetErr(err error, delayCnt int) Option { + return func(bs *BatchStore) { + bs.getErr = err + bs.getErrDelayCnt = delayCnt + } +} + +// WithPutErr will set the put error returned by the ChainStore mock. The error +// will be returned on each subsequent call after delayCnt calls to Put have +// been made. +func WithPutErr(err error, delayCnt int) Option { + return func(bs *BatchStore) { + bs.putErr = err + bs.putErrDelayCnt = delayCnt + } +} + +// Get mocks the Get method from the BatchStore +func (bs *BatchStore) Get(id []byte) (*postage.Batch, error) { + if bs.getErr != nil { + if bs.getErrDelayCnt == 0 { + return nil, bs.getErr + } + bs.getErrDelayCnt-- + } + if !bytes.Equal(bs.id, id) { + return nil, errors.New("no such id") + } + return bs.batch, nil +} + +// Put mocks the Put method from the BatchStore +func (bs *BatchStore) Put(batch *postage.Batch, newValue *big.Int, newDepth uint8) error { + if bs.putErr != nil { + if bs.putErrDelayCnt == 0 { + return bs.putErr + } + bs.putErrDelayCnt-- + } + bs.batch = batch + batch.Depth = newDepth + batch.Value.Set(newValue) + bs.id = batch.ID + return nil +} + +// GetChainState mocks the GetChainState method from the BatchStore +func (bs *BatchStore) GetChainState() *postage.ChainState { + return bs.cs +} + +// PutChainState mocks the PutChainState method from the BatchStore +func (bs *BatchStore) PutChainState(cs *postage.ChainState) error { + if bs.putErr != nil { + if bs.putErrDelayCnt == 0 { + return bs.putErr + } + bs.putErrDelayCnt-- + } + bs.cs = cs + return nil +} + +func (bs *BatchStore) GetReserveState() *postage.Reservestate { + return bs.rs +} + +func (bs *BatchStore) SetRadiusSetter(r postage.RadiusSetter) { + panic("not implemented") +} diff --git a/pkg/postage/batchstore/mock/store_test.go b/pkg/postage/batchstore/mock/store_test.go new file mode 100644 index 00000000000..6b4dc009639 --- /dev/null +++ b/pkg/postage/batchstore/mock/store_test.go @@ -0,0 +1,67 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mock_test + +import ( + "errors" + "math/big" + "testing" + + "github.com/ethersphere/bee/pkg/postage/batchstore/mock" + postagetesting "github.com/ethersphere/bee/pkg/postage/testing" +) + +func TestBatchStorePutGet(t *testing.T) { + const testCnt = 3 + + testBatch := postagetesting.MustNewBatch() + batchStore := mock.New( + mock.WithGetErr(errors.New("fails"), testCnt), + mock.WithPutErr(errors.New("fails"), testCnt), + ) + + // Put should return error after a number of tries: + for i := 0; i < testCnt; i++ { + if err := batchStore.Put(testBatch, big.NewInt(0), 0); err != nil { + t.Fatal(err) + } + } + if err := batchStore.Put(testBatch, big.NewInt(0), 0); err == nil { + t.Fatal("expected error") + } + + // Get should fail on wrong id, and after a number of tries: + if _, err := batchStore.Get(postagetesting.MustNewID()); err == nil { + t.Fatal("expected error") + } + for i := 0; i < testCnt-1; i++ { + if _, err := batchStore.Get(testBatch.ID); err != nil { + t.Fatal(err) + } + } + if _, err := batchStore.Get(postagetesting.MustNewID()); err == nil { + t.Fatal("expected error") + } +} + +func TestBatchStorePutChainState(t *testing.T) { + const testCnt = 3 + + testChainState := postagetesting.NewChainState() + batchStore := mock.New( + mock.WithChainState(testChainState), + mock.WithPutErr(errors.New("fails"), testCnt), + ) + + // PutChainState should return an error after a number of tries: + for i := 0; i < testCnt; i++ { + if err := batchStore.PutChainState(testChainState); err != nil { + t.Fatal(err) + } + } + if err := batchStore.PutChainState(testChainState); err == nil { + t.Fatal("expected error") + } +} diff --git a/pkg/postage/batchstore/reserve.go b/pkg/postage/batchstore/reserve.go new file mode 100644 index 00000000000..0430163a45d --- /dev/null +++ b/pkg/postage/batchstore/reserve.go @@ -0,0 +1,324 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package batchstore implements the reserve +// the reserve serves to maintain chunks in the area of responsibility +// it has two components +// - the batchstore reserve which maintains information about batches, their values, priorities and synchronises with the blockchain +// - the localstore which stores chunks and manages garbage collection +// +// when a new chunk arrives in the localstore, the batchstore reserve is asked to check +// the batch used in the postage stamp attached to the chunk. +// Depending on the value of the batch (reserve depth of the batch), the localstore +// either pins the chunk (thereby protecting it from garbage collection) or not. +// the chunk stays pinned until it is 'unreserved' based on changes in relative priority of the batch it belongs to +// +// the atomic db operation is unreserving a batch down to a depth +// the intended semantics of unreserve is to unpin the chunks +// in the relevant POs, belonging to the batch and (unless they are otherwise pinned) +// allow them to be gargage collected. +// +// the rules of the reserve +// - if batch a is unreserved and val(b) < val(a) then b is unreserved on any po +// - if a batch is unreserved on po p, then it is unreserved also on any p'

0 && total >= inner + if s.rs.Inner.Cmp(big.NewInt(0)) > 0 && s.cs.Total.Cmp(s.rs.Inner) >= 0 { + // collect until total+1 + until.Add(s.cs.Total, big1) + } else { + // collect until inner (collect all outer ones) + until.Set(s.rs.Inner) + } + var multiplier int64 + err := s.store.Iterate(valueKeyPrefix, func(key, _ []byte) (bool, error) { + b, err := s.Get(valueKeyToID(key)) + if err != nil { + return true, err + } + + // if batch value >= until then continue to next. + // terminate iteration if until is passed + if b.Value.Cmp(until) >= 0 { + return true, nil + } + + // in the following if statements we check the batch value + // against the inner and outer values and set the multiplier + // to 1 or 2 depending on the value. if the batch value falls + // outside of Outer it means we are evicting twice more chunks + // than within Inner, therefore the multiplier is needed to + // estimate better how much capacity gain is leveraged from + // evicting this specific batch. + + // if multiplier == 0 && batch value >= inner + if multiplier == 0 && b.Value.Cmp(s.rs.Inner) >= 0 { + multiplier = 1 + } + // if multiplier == 1 && batch value >= outer + if multiplier == 1 && b.Value.Cmp(s.rs.Outer) >= 0 { + multiplier = 2 + } + + // unreserve batch fully + err = s.unreserve(b, swarm.MaxPO+1) + if err != nil { + return true, err + } + s.rs.Available += multiplier * exp2(b.Radius-s.rs.Radius-1) + + // if batch has no value then delete it + if b.Value.Cmp(s.cs.Total) <= 0 { + toDelete = append(toDelete, b.ID) + } + return false, nil + }) + if err != nil { + return err + } + + // set inner to either until or Outer, whichever + // is the smaller value. + + s.rs.Inner.Set(until) + + // if outer < until + if s.rs.Outer.Cmp(until) < 0 { + s.rs.Outer.Set(until) + } + if err = s.store.Put(reserveStateKey, s.rs); err != nil { + return err + } + return s.delete(toDelete...) +} + +// tier represents the sections of the reserve that can be described as value intervals +// 0 - out of the reserve +// 1 - within reserve radius = depth (inner half) +// 2 - within reserve radius = depth-1 (both inner and outer halves) +type tier int + +const ( + unreserved tier = iota // out of the reserve + inner // the mid range where chunks are kept within depth + outer // top range where chunks are kept within depth - 1 +) + +// change calculates info relevant to the value change from old to new value and old and new depth +// returns the change in capacity and the radius of reserve +func (rs *reserveState) change(oldv, newv *big.Int, oldDepth, newDepth uint8) (int64, uint8) { + oldTier := rs.tier(oldv) + newTier := rs.setLimits(newv, rs.tier(newv)) + + oldSize := rs.size(oldDepth, oldTier) + newSize := rs.size(newDepth, newTier) + + availableCapacityChange := oldSize - newSize + reserveRadius := rs.radius(newTier) + + return availableCapacityChange, reserveRadius +} + +// size returns the number of chunks the local node is responsible +// to store in its reserve. +func (rs *reserveState) size(depth uint8, t tier) int64 { + size := exp2(depth - rs.Radius - 1) + switch t { + case inner: + return size + case outer: + return 2 * size + default: + // case is unreserved + return 0 + } +} + +// tier returns which tier a value falls into +func (rs *reserveState) tier(x *big.Int) tier { + + // x < rs.Inner || x == 0 + if x.Cmp(rs.Inner) < 0 || rs.Inner.Cmp(big.NewInt(0)) == 0 { + return unreserved + } + + // x < rs.Outer + if x.Cmp(rs.Outer) < 0 { + return inner + } + + // x >= rs.Outer + return outer +} + +// radius returns the reserve radius of a batch given the depth (radius of responsibility) +// based on the tier it falls in +func (rs *reserveState) radius(t tier) uint8 { + switch t { + case unreserved: + return swarm.MaxPO + case inner: + return rs.Radius + default: + // outer + return rs.Radius - 1 + } +} + +// setLimits sets the tier 1 value limit, if new item is the minimum so far (or the very first batch) +// returns the adjusted new tier +func (rs *reserveState) setLimits(val *big.Int, newTier tier) tier { + if newTier != unreserved { + return newTier + } + + // if we're here it means that the new tier + // falls under the unreserved tier + var adjustedTier tier + + // rs.Inner == 0 || rs.Inner > val + if rs.Inner.Cmp(big.NewInt(0)) == 0 || rs.Inner.Cmp(val) > 0 { + adjustedTier = inner + // if the outer is the same as the inner + if rs.Outer.Cmp(rs.Inner) == 0 { + // the value falls below inner and outer + rs.Outer.Set(val) + adjustedTier = outer + } + // inner is decreased to val, this is done when the + // batch is diluted, decreasing the value of it. + rs.Inner.Set(val) + } + return adjustedTier +} + +// update manages what chunks of which batch are allocated to the reserve +func (s *store) update(b *postage.Batch, oldDepth uint8, oldValue *big.Int) error { + newValue := b.Value + newDepth := b.Depth + capacityChange, reserveRadius := s.rs.change(oldValue, newValue, oldDepth, newDepth) + s.rs.Available += capacityChange + + if err := s.unreserve(b, reserveRadius); err != nil { + return err + } + err := s.evictOuter(b) + if err != nil { + return err + } + + s.metrics.AvailableCapacity.Set(float64(s.rs.Available)) + s.metrics.Radius.Set(float64(s.rs.Radius)) + s.metrics.Inner.Set(float64(s.rs.Inner.Int64())) + s.metrics.Outer.Set(float64(s.rs.Outer.Int64())) + return nil +} + +// evictOuter is responsible for keeping capacity positive by unreserving lowest priority batches +func (s *store) evictOuter(last *postage.Batch) error { + // if capacity is positive nothing to evict + if s.rs.Available >= 0 { + return nil + } + err := s.store.Iterate(valueKeyPrefix, func(key, _ []byte) (bool, error) { + batchID := valueKeyToID(key) + b := last + if !bytes.Equal(b.ID, batchID) { + var err error + b, err = s.Get(batchID) + if err != nil { + return true, fmt.Errorf("release get %x %v: %w", batchID, b, err) + } + } + // FIXME: this is needed only because the statestore iterator does not allow seek, only prefix + // so we need to page through all the batches until outer limit is reached + if b.Value.Cmp(s.rs.Outer) < 0 { + return false, nil + } + // stop iteration only if we consumed all batches of the same value as the one that put capacity above zero + if s.rs.Available >= 0 && s.rs.Outer.Cmp(b.Value) != 0 { + return true, nil + } + // unreserve outer PO of the lowest priority batch until capacity is back to positive + s.rs.Available += exp2(b.Depth - s.rs.Radius - 1) + s.rs.Outer.Set(b.Value) + return false, s.unreserve(b, s.rs.Radius) + }) + if err != nil { + return err + } + // add 1 to outer limit value so we dont hit on the same batch next time we iterate + s.rs.Outer.Add(s.rs.Outer, big1) + // if we consumed all batches, ie. we unreserved all chunks on the outer = depth PO + // then its time to increase depth + if s.rs.Available < 0 { + s.rs.Radius++ + s.rs.Outer.Set(s.rs.Inner) // reset outer limit to inner limit + return s.evictOuter(last) + } + return s.store.Put(reserveStateKey, s.rs) +} + +// exp2 returns the e-th power of 2 +func exp2(e uint8) int64 { + if e == 0 { + return 1 + } + b := int64(2) + for i := uint8(1); i < e; i++ { + b *= 2 + } + return b +} diff --git a/pkg/postage/batchstore/reserve_test.go b/pkg/postage/batchstore/reserve_test.go new file mode 100644 index 00000000000..f2711df24ea --- /dev/null +++ b/pkg/postage/batchstore/reserve_test.go @@ -0,0 +1,947 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package batchstore_test + +import ( + "encoding/hex" + "errors" + "fmt" + "io/ioutil" + "math/big" + "math/rand" + "os" + "testing" + + "github.com/ethersphere/bee/pkg/logging" + "github.com/ethersphere/bee/pkg/postage" + "github.com/ethersphere/bee/pkg/postage/batchstore" + postagetest "github.com/ethersphere/bee/pkg/postage/testing" + "github.com/ethersphere/bee/pkg/statestore/leveldb" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" +) + +// random advance on the blockchain +func newBlockAdvance() uint64 { + return uint64(rand.Intn(3) + 1) +} + +// initial depth of a new batch +func newBatchDepth(depth uint8) uint8 { + return depth + uint8(rand.Intn(10)) + 4 +} + +// the factor to increase the batch depth with +func newDilutionFactor() int { + return rand.Intn(3) + 1 +} + +// new value on top of value based on random period and price +func newValue(price, value *big.Int) *big.Int { + period := rand.Intn(100) + 1000 + v := new(big.Int).Mul(price, big.NewInt(int64(period))) + return v.Add(v, value) +} + +// TestBatchStoreUnreserve is testing the correct behaviour of the reserve. +// the following assumptions are tested on each modification of the batches (top up, depth increase, price change) +// - reserve exceeds capacity +// - value-consistency of unreserved POs +func TestBatchStoreUnreserveEvents(t *testing.T) { + defer func(i int64, d uint8) { + batchstore.Capacity = i + batchstore.DefaultDepth = d + }(batchstore.Capacity, batchstore.DefaultDepth) + batchstore.DefaultDepth = 5 + batchstore.Capacity = batchstore.Exp2(16) + + bStore, unreserved := setupBatchStore(t) + bStore.SetRadiusSetter(noopRadiusSetter{}) + batches := make(map[string]*postage.Batch) + + t.Run("new batches only", func(t *testing.T) { + // iterate starting from batchstore.DefaultDepth to maxPO + _, radius := batchstore.GetReserve(bStore) + for step := 0; radius < swarm.MaxPO; step++ { + cs, err := nextChainState(bStore) + if err != nil { + t.Fatal(err) + } + var b *postage.Batch + if b, err = createBatch(bStore, cs, radius); err != nil { + t.Fatal(err) + } + batches[string(b.ID)] = b + if radius, err = checkReserve(bStore, unreserved); err != nil { + t.Fatal(err) + } + } + }) + t.Run("top up batches", func(t *testing.T) { + n := 0 + for id := range batches { + b, err := bStore.Get([]byte(id)) + if err != nil { + if errors.Is(storage.ErrNotFound, err) { + continue + } + t.Fatal(err) + } + cs, err := nextChainState(bStore) + if err != nil { + t.Fatal(err) + } + if err = topUp(bStore, cs, b); err != nil { + t.Fatal(err) + } + if _, err = checkReserve(bStore, unreserved); err != nil { + t.Fatal(err) + } + n++ + if n > len(batches)/5 { + break + } + } + }) + t.Run("dilute batches", func(t *testing.T) { + n := 0 + for id := range batches { + b, err := bStore.Get([]byte(id)) + if err != nil { + if errors.Is(storage.ErrNotFound, err) { + continue + } + t.Fatal(err) + } + cs, err := nextChainState(bStore) + if err != nil { + t.Fatal(err) + } + if err = increaseDepth(bStore, cs, b); err != nil { + t.Fatal(err) + } + if _, err = checkReserve(bStore, unreserved); err != nil { + t.Fatal(err) + } + n++ + if n > len(batches)/5 { + break + } + } + }) +} + +func TestBatchStoreUnreserveAll(t *testing.T) { + defer func(i int64, d uint8) { + batchstore.Capacity = i + batchstore.DefaultDepth = d + }(batchstore.Capacity, batchstore.DefaultDepth) + batchstore.DefaultDepth = 5 + batchstore.Capacity = batchstore.Exp2(16) + + bStore, unreserved := setupBatchStore(t) + bStore.SetRadiusSetter(noopRadiusSetter{}) + var batches [][]byte + // iterate starting from batchstore.DefaultDepth to maxPO + _, depth := batchstore.GetReserve(bStore) + for step := 0; depth < swarm.MaxPO; step++ { + cs, err := nextChainState(bStore) + if err != nil { + t.Fatal(err) + } + event := rand.Intn(6) + // 0: dilute, 1: topup, 2,3,4,5: create + var b *postage.Batch + if event < 2 && len(batches) > 10 { + for { + n := rand.Intn(len(batches)) + b, err = bStore.Get(batches[n]) + if err != nil { + if errors.Is(storage.ErrNotFound, err) { + continue + } + t.Fatal(err) + } + break + } + if event == 0 { + if err = increaseDepth(bStore, cs, b); err != nil { + t.Fatal(err) + } + } else if err = topUp(bStore, cs, b); err != nil { + t.Fatal(err) + } + } else if b, err = createBatch(bStore, cs, depth); err != nil { + t.Fatal(err) + } else { + batches = append(batches, b.ID) + } + if depth, err = checkReserve(bStore, unreserved); err != nil { + t.Fatal(err) + } + } +} + +func setupBatchStore(t *testing.T) (postage.Storer, map[string]uint8) { + t.Helper() + // we cannot use the mock statestore here since the iterator is not giving the right order + // must use the leveldb statestore + dir, err := ioutil.TempDir("", "batchstore_test") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := os.RemoveAll(dir); err != nil { + t.Fatal(err) + } + }) + logger := logging.New(ioutil.Discard, 0) + stateStore, err := leveldb.NewStateStore(dir, logger) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := stateStore.Close(); err != nil { + t.Fatal(err) + } + }) + + // set mock unreserve call + unreserved := make(map[string]uint8) + unreserveFunc := func(batchID []byte, radius uint8) error { + unreserved[hex.EncodeToString(batchID)] = radius + return nil + } + bStore, _ := batchstore.New(stateStore, unreserveFunc) + bStore.SetRadiusSetter(noopRadiusSetter{}) + + // initialise chainstate + err = bStore.PutChainState(&postage.ChainState{ + Block: 0, + Total: big.NewInt(0), + Price: big.NewInt(1), + }) + if err != nil { + t.Fatal(err) + } + return bStore, unreserved +} + +func nextChainState(bStore postage.Storer) (*postage.ChainState, error) { + cs := bStore.GetChainState() + // random advance on the blockchain + advance := newBlockAdvance() + cs = &postage.ChainState{ + Block: advance + cs.Block, + Price: cs.Price, + // settle although no price change + Total: cs.Total.Add(cs.Total, new(big.Int).Mul(cs.Price, big.NewInt(int64(advance)))), + } + return cs, bStore.PutChainState(cs) +} + +// creates a test batch with random value and depth and adds it to the batchstore +func createBatch(bStore postage.Storer, cs *postage.ChainState, depth uint8) (*postage.Batch, error) { + b := postagetest.MustNewBatch() + b.Depth = newBatchDepth(depth) + value := newValue(cs.Price, cs.Total) + b.Value = big.NewInt(0) + return b, bStore.Put(b, value, b.Depth) +} + +// tops up a batch with random amount +func topUp(bStore postage.Storer, cs *postage.ChainState, b *postage.Batch) error { + value := newValue(cs.Price, b.Value) + return bStore.Put(b, value, b.Depth) +} + +// dilutes the batch with random factor +func increaseDepth(bStore postage.Storer, cs *postage.ChainState, b *postage.Batch) error { + diff := newDilutionFactor() + value := new(big.Int).Sub(b.Value, cs.Total) + value.Div(value, big.NewInt(int64(1<= 0 { + if bDepth < depth-1 || bDepth > depth { + return true, fmt.Errorf("incorrect reserve radius. expected %d or %d. got %d", depth-1, depth, bDepth) + } + if bDepth == depth { + if inner.Cmp(b.Value) < 0 { + inner.Set(b.Value) + } + } else if outer.Cmp(b.Value) > 0 || outer.Cmp(big.NewInt(0)) == 0 { + outer.Set(b.Value) + } + if outer.Cmp(big.NewInt(0)) != 0 && outer.Cmp(inner) <= 0 { + return true, fmt.Errorf("inconsistent reserve radius: %d <= %d", outer.Uint64(), inner.Uint64()) + } + size += batchstore.Exp2(b.Depth - bDepth - 1) + } else if bDepth != swarm.MaxPO { + return true, fmt.Errorf("batch below limit expected to be fully unreserved. got found=%v, radius=%d", found, bDepth) + } + return false, nil + }) + if err != nil { + return 0, err + } + if size > batchstore.Capacity { + return 0, fmt.Errorf("reserve size beyond capacity. max %d, got %d", batchstore.Capacity, size) + } + return depth, nil +} + +// TestBatchStore_Unreserve tests that the unreserve +// hook is called with the correct batch IDs and correct +// Radius as a result of batches coming in from chain events. +// All tests share the same initial state: +// ▲ bzz/chunk +// │ +// 6 ├──┐ +// 5 │ ├──┐ +// 4 │ │ ├──┐ +// 3 │ │ │ ├──┐---inner, outer +// │ │ │ │ │ +// └──┴──┴──┴──┴───────> time +// +func TestBatchStore_Unreserve(t *testing.T) { + defer func(i int64, d uint8) { + batchstore.Capacity = i + batchstore.DefaultDepth = d + }(batchstore.Capacity, batchstore.DefaultDepth) + batchstore.DefaultDepth = 5 + batchstore.Capacity = batchstore.Exp2(5) // 32 chunks + // 8 is the initial batch depth we add the initial state batches with. + // the default radius is 5 (defined in reserve.go file), which means there + // are 2^5 neighborhoods. now, since there are 2^8 chunks in a batch (256), + // we can divide that by the number of neighborhoods (32) and get 8, which is + // the number of chunks at most that can fall inside a neighborhood for a batch + initBatchDepth := uint8(8) + + for _, tc := range []struct { + desc string + add []depthValueTuple + exp []batchUnreserveTuple + }{ + { + // add one batch with value 2 and expect that it will be called in + // evict with radius 5, which means that the outer half of chunks from + // that batch will be deleted once chunks start entering the localstore. + // inner 2, outer 4 + desc: "add one at inner", + add: []depthValueTuple{depthValue(8, 2)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 4), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 5)}, + }, { + // add one batch with value 3 and expect that it will be called in + // evict with radius 5 alongside with the other value 3 batch + // inner 3, outer 4 + desc: "add another at inner", + add: []depthValueTuple{depthValue(8, 3)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 4), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 5)}, + }, { + // add one batch with value 4 and expect that the batch with value + // 3 gets called with radius 5, and BOTH batches with value 4 will + // also be called with radius 5. + // inner 3, outer 5 + desc: "add one at inner and evict half of self", + add: []depthValueTuple{depthValue(8, 4)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 5)}, + }, { + // this builds on the previous case: + // since we over-evicted one batch before (since both 4's ended up in + // inner, then we can add another one at 4, and expect it also to be + // at inner (called with 5). + // inner 3, outer 5 (stays the same) + desc: "add one at inner and fill after over-eviction", + add: []depthValueTuple{depthValue(8, 4), depthValue(8, 4)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 5), + batchUnreserve(5, 5), + }, + }, { + // insert a batch of depth 6 (2 chunks fall under our radius) + // value is 3, expect unreserve 5, expect other value 3 to be + // at radius 5. + // inner 3, outer 4 + desc: "insert smaller at inner", + add: []depthValueTuple{depthValue(6, 3)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 4), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 5), + }, + }, { + // this case builds on the previous one: + // because we over-evicted, we can insert another batch of depth 6 + // with value 3, expect unreserve 5 + // inner 3, outer 4 + desc: "insert smaller and fill over-eviction", + add: []depthValueTuple{depthValue(6, 3), depthValue(6, 3)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 4), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 5), + batchUnreserve(5, 5), + }, + }, { + // insert a batch of depth 6 (2 chunks fall under our radius) + // value is 4, expect unreserve 5, expect other value 3 to be + // at radius 5. + // inner 3, outer 4 + desc: "insert smaller and evict cheaper", + add: []depthValueTuple{depthValue(6, 4)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 4), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 4), + }, + }, { + // insert a batch of depth 6 (2 chunks fall under our radius) + // value is 6, expect unreserve 4, expect other value 3 to be + // at radius 5. + // inner 3, outer 4 + desc: "insert at outer and evict inner", + add: []depthValueTuple{depthValue(6, 6)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 4), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 4), + }, + }, { + // insert a batch of depth 9 (16 chunks in outer tier) + // expect batches with value 3 and 4 to be unreserved with radius 5 + // including the one that was just added (evicted half of itself) + // inner 3, outer 5 + desc: "insert at inner and evict self and sister batches", + add: []depthValueTuple{depthValue(9, 3)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 5), + }, + }, { + // insert a batch of depth 9 (16 chunks in outer tier) + // expect batches with value 3 and 4 to be unreserved with radius 5 + // state is same as the last case + // inner 3, outer 5 + desc: "insert at inner and evict self and sister batches", + add: []depthValueTuple{depthValue(9, 4)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 5), + }, + }, { + // insert a batch of depth 9 (16 chunks in outer tier), and 7 (8 chunks in premium) + // expect batches with value 3 to 5 to be unreserved with radius 5 + // inner 3, outer 6 + desc: "insert at outer and evict inner", + add: []depthValueTuple{depthValue(9, 5), depthValue(7, 5)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 5), + batchUnreserve(3, 4), + batchUnreserve(4, 5), + batchUnreserve(5, 5), + }, + }, { + // insert a batch of depth 10 value 3 (32 chunks in outer tier) + // expect all batches to be called with radius 5! + // inner 3, outer 3 + desc: "insert at outer and evict everything to fit the batch", + add: []depthValueTuple{depthValue(10, 3)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 5), + batchUnreserve(3, 5), + batchUnreserve(4, 5), + }, + }, { + // builds on the last case: + // insert a batch of depth 10 value 3 (32 chunks in outer tier) + // and of depth 7 value 3. expect value 3's to be called with radius 6 + // inner 3, outer 4 + desc: "insert another at outer and expect evict self", + add: []depthValueTuple{depthValue(10, 3), depthValue(7, 3)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 6), + batchUnreserve(1, 5), + batchUnreserve(2, 5), + batchUnreserve(3, 5), + batchUnreserve(4, 6), + batchUnreserve(5, 6), + }, + }, { + // insert a batch of depth 10 value 6 (32 chunks in outer tier) + // expect all batches to be called with unreserved 5 + // inner 3, outer 3 + desc: "insert at outer and evict from all to fit the batch", + add: []depthValueTuple{depthValue(10, 6)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 5), + batchUnreserve(3, 5), + batchUnreserve(4, 5), + }, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + store, unreserved := setupBatchStore(t) + store.SetRadiusSetter(noopRadiusSetter{}) + batches := addBatch(t, store, + depthValue(initBatchDepth, 3), + depthValue(initBatchDepth, 4), + depthValue(initBatchDepth, 5), + depthValue(initBatchDepth, 6), + ) + + checkUnreserved(t, unreserved, batches, 4) + + b := addBatch(t, store, tc.add...) + batches = append(batches, b...) + + for _, v := range tc.exp { + b := []*postage.Batch{batches[v.batchIndex]} + checkUnreserved(t, unreserved, b, v.expDepth) + } + }) + } +} + +// TestBatchStore_Topup tests that the unreserve +// hook is called with the correct batch IDs and correct +// Radius as a result of batches being topped up. +// All tests share the same initial state: +// ▲ bzz/chunk +// │ +// 6 ├──┐ +// 5 │ ├──┐ +// 4 │ │ ├──┐ +// 3 │ │ │ ├──┐ +// 2 │ │ │ │ ├──┐---inner, outer +// └──┴──┴──┴──┴──┴─────> time +// +func TestBatchStore_Topup(t *testing.T) { + defer func(i int64, d uint8) { + batchstore.Capacity = i + batchstore.DefaultDepth = d + }(batchstore.Capacity, batchstore.DefaultDepth) + batchstore.DefaultDepth = 5 + batchstore.Capacity = batchstore.Exp2(5) // 32 chunks + initBatchDepth := uint8(8) + + for _, tc := range []struct { + desc string + topup []batchValueTuple + exp []batchUnreserveTuple + }{ + { + // initial state + // inner 2, outer 4 + desc: "initial state", + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 4), + }, + }, { + // builds on initial state: + // topup of batch with value 2 to value 3 should result + // in no state change. + // inner 3, outer 4. before the topup: inner 2, outer 4 + desc: "topup value 2->3, same state", + topup: []batchValueTuple{batchValue(0, 3)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 4), + }, + }, { + // topup of batch with value 2 to value 4 should result + // in the other batches (3,4) in being downgraded to inner too, so all three batches are + // at inner. there's excess capacity + // inner 3, outer 5 + desc: "topup value 2->4, same state", + topup: []batchValueTuple{batchValue(0, 4)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 5), + batchUnreserve(3, 4), + batchUnreserve(4, 4), + }, + }, { + // builds on the last case: + // add another batch at value 2, and since we've over-evicted before, + // we should be able to accommodate it. + // inner 3, outer 5 + desc: "topup value 2->4, add another one at 2, same state", + topup: []batchValueTuple{batchValue(0, 4)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 5), + batchUnreserve(3, 4), + batchUnreserve(4, 4), + }, + }, { + // builds on the last case: + // add another batch at value 2, and since we've over-evicted before, + // we should be able to accommodate it. + // inner 3, outer 5 + desc: "topup value 2->4, add another one at 2, same state", + topup: []batchValueTuple{batchValue(0, 4)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 5), + batchUnreserve(3, 4), + batchUnreserve(4, 4), + }, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + store, unreserved := setupBatchStore(t) + store.SetRadiusSetter(noopRadiusSetter{}) + batches := addBatch(t, store, + depthValue(initBatchDepth, 2), + depthValue(initBatchDepth, 3), + depthValue(initBatchDepth, 4), + depthValue(initBatchDepth, 5), + depthValue(initBatchDepth, 6), + ) + + topupBatch(t, store, batches, tc.topup...) + + for _, v := range tc.exp { + b := []*postage.Batch{batches[v.batchIndex]} + checkUnreserved(t, unreserved, b, v.expDepth) + } + }) + } +} + +// TestBatchStore_Dilution tests that the unreserve +// hook is called with the correct batch IDs and correct +// Radius as a result of batches being diluted. +// All tests share the same initial state: +// ▲ bzz/chunk +// │ +// 6 ├──┐ +// 5 │ ├──┐ +// 4 │ │ ├──┐ +// 3 │ │ │ ├──┐ +// 2 │ │ │ │ ├──┐---inner, outer +// └──┴──┴──┴──┴──┴─────> time +// +func TestBatchStore_Dilution(t *testing.T) { + defer func(i int64, d uint8) { + batchstore.Capacity = i + batchstore.DefaultDepth = d + }(batchstore.Capacity, batchstore.DefaultDepth) + batchstore.DefaultDepth = 5 + batchstore.Capacity = batchstore.Exp2(5) // 32 chunks + initBatchDepth := uint8(8) + + for _, tc := range []struct { + desc string + dilute []batchDepthTuple + topup []batchValueTuple + exp []batchUnreserveTuple + }{ + { + // initial state + // inner 2, outer 4 + desc: "initial state", + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 4), + }, + }, { + // dilution halves the value, and doubles the size of the batch + // recalculate the per chunk balance: + // ((value - total) / 2) + total => new batch value + + // expect this batch to be called with unreserved 5. + // the batch collected the outer half, so in fact when it was + // diluted it got downgraded from inner to outer, so it preserves + // the same amount of chunks. the rest stays the same + + // total is 0 at this point + + desc: "dilute most expensive", + dilute: []batchDepthTuple{batchDepth(4, 9)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 4), + batchUnreserve(3, 4), + batchUnreserve(4, 5), + }, + }, { + // expect this batch to be called with unreserved 5, but also the + // the rest of the batches to be evicted with radius 5 to fit this batch in + desc: "dilute most expensive further, evict batch from outer", + dilute: []batchDepthTuple{batchDepth(4, 10)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 5), + batchUnreserve(3, 5), + batchUnreserve(4, 5), + }, + }, { + // dilute the batch at value 3, expect to evict out the + // batch with value 4 to radius 5 + desc: "dilute cheaper batch and evict batch from outer", + dilute: []batchDepthTuple{batchDepth(1, 9)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 5), + batchUnreserve(3, 4), + batchUnreserve(4, 4), + }, + }, { + // top up the highest value batch to be value 12, then dilute it + // to be depth 9 (original 8), which causes it to be at value 6 + // expect batches with value 4 and 5 to evict outer, and the last + // batch to be at outer tier (radius 4) + // inner 2, outer 6 + desc: "dilute cheaper batch and evict batch from outer", + topup: []batchValueTuple{batchValue(4, 12)}, + dilute: []batchDepthTuple{batchDepth(4, 9)}, + exp: []batchUnreserveTuple{ + batchUnreserve(0, 5), + batchUnreserve(1, 5), + batchUnreserve(2, 5), + batchUnreserve(3, 5), + batchUnreserve(4, 4), + }, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + store, unreserved := setupBatchStore(t) + store.SetRadiusSetter(noopRadiusSetter{}) + batches := addBatch(t, store, + depthValue(initBatchDepth, 2), + depthValue(initBatchDepth, 3), + depthValue(initBatchDepth, 4), + depthValue(initBatchDepth, 5), + depthValue(initBatchDepth, 6), + ) + + topupBatch(t, store, batches, tc.topup...) + diluteBatch(t, store, batches, tc.dilute...) + + for _, v := range tc.exp { + b := []*postage.Batch{batches[v.batchIndex]} + checkUnreserved(t, unreserved, b, v.expDepth) + } + }) + } +} + +func TestBatchStore_EvictExpired(t *testing.T) { + defer func(i int64, d uint8) { + batchstore.Capacity = i + batchstore.DefaultDepth = d + }(batchstore.Capacity, batchstore.DefaultDepth) + batchstore.DefaultDepth = 5 + batchstore.Capacity = batchstore.Exp2(5) // 32 chunks + initBatchDepth := uint8(8) + + store, unreserved := setupBatchStore(t) + store.SetRadiusSetter(noopRadiusSetter{}) + batches := addBatch(t, store, + depthValue(initBatchDepth, 2), + depthValue(initBatchDepth, 3), + depthValue(initBatchDepth, 4), + depthValue(initBatchDepth, 5), + ) + + cs := store.GetChainState() + cs.Block = 4 + cs.Total = big.NewInt(4) + err := store.PutChainState(cs) + if err != nil { + t.Fatal(err) + } + + // expect the 5 to be preserved and the rest to be unreserved + checkUnreserved(t, unreserved, batches[:3], swarm.MaxPO+1) + checkUnreserved(t, unreserved, batches[3:], 4) + + // check that the batches is actually deleted from + // statestore, by trying to do a Get on the deleted + // batches, and assert that they are not found + for _, v := range batches[:3] { + _, err := store.Get(v.ID) + if !errors.Is(err, storage.ErrNotFound) { + t.Fatalf("expected err not found but got %v", err) + } + } +} + +type depthValueTuple struct { + depth uint8 + value int +} + +func depthValue(d uint8, v int) depthValueTuple { + return depthValueTuple{depth: d, value: v} +} + +type batchValueTuple struct { + batchIndex int + value *big.Int +} + +func batchValue(i, v int) batchValueTuple { + return batchValueTuple{batchIndex: i, value: big.NewInt(int64(v))} +} + +type batchUnreserveTuple struct { + batchIndex int + expDepth uint8 +} + +func batchUnreserve(i int, d uint8) batchUnreserveTuple { + return batchUnreserveTuple{batchIndex: i, expDepth: d} +} + +type batchDepthTuple struct { + batchIndex int + depth uint8 +} + +func batchDepth(i, d int) batchDepthTuple { + return batchDepthTuple{batchIndex: i, depth: uint8(d)} +} + +func topupBatch(t *testing.T, s postage.Storer, batches []*postage.Batch, bvp ...batchValueTuple) { + t.Helper() + for _, v := range bvp { + batch := batches[v.batchIndex] + err := s.Put(batch, v.value, batch.Depth) + if err != nil { + t.Fatal(err) + } + } +} + +func diluteBatch(t *testing.T, s postage.Storer, batches []*postage.Batch, bdp ...batchDepthTuple) { + t.Helper() + for _, v := range bdp { + batch := batches[v.batchIndex] + val := batch.Value + // for every depth increase we half the batch value + for i := batch.Depth; i < v.depth; i++ { + val = big.NewInt(0).Div(val, big.NewInt(2)) + } + err := s.Put(batch, val, v.depth) + if err != nil { + t.Fatal(err) + } + } +} + +func addBatch(t *testing.T, s postage.Storer, dvp ...depthValueTuple) []*postage.Batch { + t.Helper() + var batches []*postage.Batch + for _, v := range dvp { + b := postagetest.MustNewBatch() + + // this is needed since the initial batch state should be + // always zero. should be rectified with less magical test + // helpers + b.Value = big.NewInt(0) + b.Depth = uint8(0) + b.Start = 111 + + val := big.NewInt(int64(v.value)) + + err := s.Put(b, val, v.depth) + if err != nil { + t.Fatal(err) + } + batches = append(batches, b) + } + + return batches +} + +func checkUnreserved(t *testing.T, unreserved map[string]uint8, batches []*postage.Batch, exp uint8) { + t.Helper() + for _, b := range batches { + v, ok := unreserved[hex.EncodeToString(b.ID)] + if !ok { + t.Fatalf("batch %x not called with unreserve", b.ID) + } + if v != exp { + t.Fatalf("batch %x expected unreserve radius %d but got %d", b.ID, exp, v) + } + } +} diff --git a/pkg/postage/batchstore/store.go b/pkg/postage/batchstore/store.go new file mode 100644 index 00000000000..6d0033bbbca --- /dev/null +++ b/pkg/postage/batchstore/store.go @@ -0,0 +1,182 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package batchstore + +import ( + "errors" + "math/big" + + "github.com/ethersphere/bee/pkg/postage" + "github.com/ethersphere/bee/pkg/storage" +) + +const ( + batchKeyPrefix = "batchstore_batch_" + valueKeyPrefix = "batchstore_value_" + chainStateKey = "batchstore_chainstate" + reserveStateKey = "batchstore_reservestate" +) + +type unreserveFn func(batchID []byte, radius uint8) error + +// store implements postage.Storer +type store struct { + store storage.StateStorer // State store backend to persist batches. + cs *postage.ChainState // the chain state + rs *reserveState // the reserve state + unreserveFunc unreserveFn // unreserve function + metrics metrics // metrics + + radiusSetter postage.RadiusSetter // setter for radius notifications +} + +// New constructs a new postage batch store. +// It initialises both chain state and reserve state from the persistent state store +func New(st storage.StateStorer, unreserveFunc unreserveFn) (postage.Storer, error) { + cs := &postage.ChainState{} + err := st.Get(chainStateKey, cs) + if err != nil { + if !errors.Is(err, storage.ErrNotFound) { + return nil, err + } + cs = &postage.ChainState{ + Block: 0, + Total: big.NewInt(0), + Price: big.NewInt(0), + } + } + rs := &reserveState{} + err = st.Get(reserveStateKey, rs) + if err != nil { + if !errors.Is(err, storage.ErrNotFound) { + return nil, err + } + rs = &reserveState{ + Radius: DefaultDepth, + Inner: big.NewInt(0), + Outer: big.NewInt(0), + Available: Capacity, + } + } + s := &store{ + store: st, + cs: cs, + rs: rs, + unreserveFunc: unreserveFunc, + metrics: newMetrics(), + } + + return s, nil +} + +func (s *store) GetReserveState() *postage.Reservestate { + return &postage.Reservestate{ + Radius: s.rs.Radius, + Available: s.rs.Available, + Outer: new(big.Int).Set(s.rs.Outer), + Inner: new(big.Int).Set(s.rs.Inner), + } +} + +// Get returns a batch from the batchstore with the given ID. +func (s *store) Get(id []byte) (*postage.Batch, error) { + b := &postage.Batch{} + err := s.store.Get(batchKey(id), b) + if err != nil { + return nil, err + } + b.Radius = s.rs.radius(s.rs.tier(b.Value)) + return b, nil +} + +// Put stores a given batch in the batchstore and requires new values of Value and Depth +func (s *store) Put(b *postage.Batch, value *big.Int, depth uint8) error { + oldVal := new(big.Int).Set(b.Value) + oldDepth := b.Depth + err := s.store.Delete(valueKey(oldVal, b.ID)) + if err != nil { + return err + } + b.Value.Set(value) + b.Depth = depth + err = s.store.Put(valueKey(b.Value, b.ID), nil) + if err != nil { + return err + } + err = s.update(b, oldDepth, oldVal) + if err != nil { + return err + } + + if s.radiusSetter != nil { + s.radiusSetter.SetRadius(s.rs.Radius) + } + return s.store.Put(batchKey(b.ID), b) +} + +// delete removes the batches with ids given as arguments. +func (s *store) delete(ids ...[]byte) error { + for _, id := range ids { + b, err := s.Get(id) + if err != nil { + return err + } + err = s.store.Delete(valueKey(b.Value, id)) + if err != nil { + return err + } + err = s.store.Delete(batchKey(id)) + if err != nil { + return err + } + } + return nil +} + +// PutChainState implements BatchStorer. +// It purges expired batches and unreserves underfunded ones before it +// stores the chain state in the batch store. +func (s *store) PutChainState(cs *postage.ChainState) error { + s.cs = cs + err := s.evictExpired() + if err != nil { + return err + } + // this needs to be improved, since we can miss some calls on + // startup. the same goes for the other call to radiusSetter + if s.radiusSetter != nil { + s.radiusSetter.SetRadius(s.rs.Radius) + } + + return s.store.Put(chainStateKey, cs) +} + +// GetChainState implements BatchStorer. It returns the stored chain state from +// the batch store. +func (s *store) GetChainState() *postage.ChainState { + return s.cs +} + +func (s *store) SetRadiusSetter(r postage.RadiusSetter) { + s.radiusSetter = r +} + +// batchKey returns the index key for the batch ID used in the by-ID batch index. +func batchKey(id []byte) string { + return batchKeyPrefix + string(id) +} + +// valueKey returns the index key for the batch ID used in the by-ID batch index. +func valueKey(val *big.Int, id []byte) string { + value := make([]byte, 32) + val.FillBytes(value) // zero-extended big-endian byte slice + return valueKeyPrefix + string(value) + string(id) +} + +// valueKeyToID extracts the batch ID from a value key - used in value-based iteration +func valueKeyToID(key []byte) []byte { + l := len(key) + return key[l-32 : l] +} diff --git a/pkg/postage/batchstore/store_test.go b/pkg/postage/batchstore/store_test.go new file mode 100644 index 00000000000..80d7f64dce2 --- /dev/null +++ b/pkg/postage/batchstore/store_test.go @@ -0,0 +1,109 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package batchstore_test + +import ( + "testing" + + "github.com/ethersphere/bee/pkg/postage" + "github.com/ethersphere/bee/pkg/postage/batchstore" + postagetest "github.com/ethersphere/bee/pkg/postage/testing" + "github.com/ethersphere/bee/pkg/statestore/mock" + "github.com/ethersphere/bee/pkg/storage" +) + +func unreserve([]byte, uint8) error { return nil } +func TestBatchStoreGet(t *testing.T) { + testBatch := postagetest.MustNewBatch() + key := batchstore.BatchKey(testBatch.ID) + + stateStore := mock.NewStateStore() + batchStore, _ := batchstore.New(stateStore, nil) + + stateStorePut(t, stateStore, key, testBatch) + got := batchStoreGetBatch(t, batchStore, testBatch.ID) + postagetest.CompareBatches(t, testBatch, got) +} + +func TestBatchStorePut(t *testing.T) { + testBatch := postagetest.MustNewBatch() + key := batchstore.BatchKey(testBatch.ID) + + stateStore := mock.NewStateStore() + batchStore, _ := batchstore.New(stateStore, unreserve) + batchStore.SetRadiusSetter(noopRadiusSetter{}) + batchStorePutBatch(t, batchStore, testBatch) + + var got postage.Batch + stateStoreGet(t, stateStore, key, &got) + postagetest.CompareBatches(t, testBatch, &got) +} + +func TestBatchStoreGetChainState(t *testing.T) { + testChainState := postagetest.NewChainState() + + stateStore := mock.NewStateStore() + batchStore, _ := batchstore.New(stateStore, nil) + batchStore.SetRadiusSetter(noopRadiusSetter{}) + + err := batchStore.PutChainState(testChainState) + if err != nil { + t.Fatal(err) + } + got := batchStore.GetChainState() + postagetest.CompareChainState(t, testChainState, got) +} + +func TestBatchStorePutChainState(t *testing.T) { + testChainState := postagetest.NewChainState() + + stateStore := mock.NewStateStore() + batchStore, _ := batchstore.New(stateStore, nil) + batchStore.SetRadiusSetter(noopRadiusSetter{}) + + batchStorePutChainState(t, batchStore, testChainState) + var got postage.ChainState + stateStoreGet(t, stateStore, batchstore.StateKey, &got) + postagetest.CompareChainState(t, testChainState, &got) +} + +func stateStoreGet(t *testing.T, st storage.StateStorer, k string, v interface{}) { + if err := st.Get(k, v); err != nil { + t.Fatalf("store get batch: %v", err) + } +} + +func stateStorePut(t *testing.T, st storage.StateStorer, k string, v interface{}) { + if err := st.Put(k, v); err != nil { + t.Fatalf("store put batch: %v", err) + } +} + +func batchStoreGetBatch(t *testing.T, st postage.Storer, id []byte) *postage.Batch { + t.Helper() + b, err := st.Get(id) + if err != nil { + t.Fatalf("postage storer get: %v", err) + } + return b +} + +func batchStorePutBatch(t *testing.T, st postage.Storer, b *postage.Batch) { + t.Helper() + if err := st.Put(b, b.Value, b.Depth); err != nil { + t.Fatalf("postage storer put: %v", err) + } +} + +func batchStorePutChainState(t *testing.T, st postage.Storer, cs *postage.ChainState) { + t.Helper() + if err := st.PutChainState(cs); err != nil { + t.Fatalf("postage storer put chain state: %v", err) + } +} + +type noopRadiusSetter struct{} + +func (n noopRadiusSetter) SetRadius(_ uint8) {} diff --git a/pkg/postage/chainstate.go b/pkg/postage/chainstate.go new file mode 100644 index 00000000000..25818bd49f6 --- /dev/null +++ b/pkg/postage/chainstate.go @@ -0,0 +1,14 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package postage + +import "math/big" + +// ChainState contains data the batch service reads from the chain. +type ChainState struct { + Block uint64 `json:"block"` // The block number of the last postage event. + Total *big.Int `json:"total"` // Cumulative amount paid per stamp. + Price *big.Int `json:"price"` // Bzz/chunk/block normalised price. +} diff --git a/pkg/postage/export_test.go b/pkg/postage/export_test.go new file mode 100644 index 00000000000..fbd0fa25b47 --- /dev/null +++ b/pkg/postage/export_test.go @@ -0,0 +1,13 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package postage + +import ( + "github.com/ethersphere/bee/pkg/swarm" +) + +func (st *StampIssuer) Inc(a swarm.Address) error { + return st.inc(a) +} diff --git a/pkg/postage/interface.go b/pkg/postage/interface.go new file mode 100644 index 00000000000..0dc13a952c6 --- /dev/null +++ b/pkg/postage/interface.go @@ -0,0 +1,42 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package postage + +import ( + "io" + "math/big" +) + +// EventUpdater interface definitions reflect the updates triggered by events +// emitted by the postage contract on the blockchain. +type EventUpdater interface { + Create(id []byte, owner []byte, normalisedBalance *big.Int, depth uint8) error + TopUp(id []byte, normalisedBalance *big.Int) error + UpdateDepth(id []byte, depth uint8, normalisedBalance *big.Int) error + UpdatePrice(price *big.Int) error + UpdateBlockNumber(blockNumber uint64) error + Start() <-chan struct{} +} + +// Storer represents the persistence layer for batches on the current (highest +// available) block. +type Storer interface { + Get(id []byte) (*Batch, error) + Put(*Batch, *big.Int, uint8) error + PutChainState(*ChainState) error + GetChainState() *ChainState + GetReserveState() *Reservestate + SetRadiusSetter(RadiusSetter) +} + +type RadiusSetter interface { + SetRadius(uint8) +} + +// Listener provides a blockchain event iterator. +type Listener interface { + io.Closer + Listen(from uint64, updater EventUpdater) <-chan struct{} +} diff --git a/pkg/postage/listener/export_test.go b/pkg/postage/listener/export_test.go new file mode 100644 index 00000000000..596255c4764 --- /dev/null +++ b/pkg/postage/listener/export_test.go @@ -0,0 +1,17 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package listener + +var ( + PostageStampABI = postageStampABI + PriceOracleABI = priceOracleABI + + BatchCreatedTopic = batchCreatedTopic + BatchTopupTopic = batchTopupTopic + BatchDepthIncreaseTopic = batchDepthIncreaseTopic + PriceUpdateTopic = priceUpdateTopic + + TailSize = tailSize +) diff --git a/pkg/postage/listener/listener.go b/pkg/postage/listener/listener.go new file mode 100644 index 00000000000..993c1f928c0 --- /dev/null +++ b/pkg/postage/listener/listener.go @@ -0,0 +1,284 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package listener + +import ( + "context" + "errors" + "fmt" + "math/big" + "strings" + "sync" + "time" + + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethersphere/bee/pkg/logging" + "github.com/ethersphere/bee/pkg/postage" + "github.com/ethersphere/bee/pkg/settlement/swap/transaction" + "github.com/ethersphere/go-storage-incentives-abi/postageabi" +) + +const ( + blockPage = 10000 // how many blocks to sync every time + tailSize = 4 // how many blocks to tail from the tip of the chain +) + +var ( + chainUpdateInterval = 5 * time.Second + postageStampABI = parseABI(postageabi.PostageStampABIv0_1_0) + priceOracleABI = parseABI(postageabi.PriceOracleABIv0_1_0) + // batchCreatedTopic is the postage contract's batch created event topic + batchCreatedTopic = postageStampABI.Events["BatchCreated"].ID + // batchTopupTopic is the postage contract's batch topup event topic + batchTopupTopic = postageStampABI.Events["BatchTopUp"].ID + // batchDepthIncreaseTopic is the postage contract's batch dilution event topic + batchDepthIncreaseTopic = postageStampABI.Events["BatchDepthIncrease"].ID + // priceUpdateTopic is the price oracle's price update event topic + priceUpdateTopic = priceOracleABI.Events["PriceUpdate"].ID +) + +type BlockHeightContractFilterer interface { + bind.ContractFilterer + BlockNumber(context.Context) (uint64, error) +} + +type listener struct { + logger logging.Logger + ev BlockHeightContractFilterer + + postageStampAddress common.Address + priceOracleAddress common.Address + quit chan struct{} + wg sync.WaitGroup +} + +func New( + logger logging.Logger, + ev BlockHeightContractFilterer, + postageStampAddress, + priceOracleAddress common.Address, +) postage.Listener { + return &listener{ + logger: logger, + ev: ev, + + postageStampAddress: postageStampAddress, + priceOracleAddress: priceOracleAddress, + quit: make(chan struct{}), + } +} + +func (l *listener) filterQuery(from, to *big.Int) ethereum.FilterQuery { + return ethereum.FilterQuery{ + FromBlock: from, + ToBlock: to, + Addresses: []common.Address{ + l.postageStampAddress, + l.priceOracleAddress, + }, + Topics: [][]common.Hash{ + { + batchCreatedTopic, + batchTopupTopic, + batchDepthIncreaseTopic, + priceUpdateTopic, + }, + }, + } +} + +func (l *listener) processEvent(e types.Log, updater postage.EventUpdater) error { + switch e.Topics[0] { + case batchCreatedTopic: + c := &batchCreatedEvent{} + err := transaction.ParseEvent(&postageStampABI, "BatchCreated", c, e) + if err != nil { + return err + } + return updater.Create( + c.BatchId[:], + c.Owner.Bytes(), + c.NormalisedBalance, + c.Depth, + ) + case batchTopupTopic: + c := &batchTopUpEvent{} + err := transaction.ParseEvent(&postageStampABI, "BatchTopUp", c, e) + if err != nil { + return err + } + return updater.TopUp( + c.BatchId[:], + c.NormalisedBalance, + ) + case batchDepthIncreaseTopic: + c := &batchDepthIncreaseEvent{} + err := transaction.ParseEvent(&postageStampABI, "BatchDepthIncrease", c, e) + if err != nil { + return err + } + return updater.UpdateDepth( + c.BatchId[:], + c.NewDepth, + c.NormalisedBalance, + ) + case priceUpdateTopic: + c := &priceUpdateEvent{} + err := transaction.ParseEvent(&priceOracleABI, "PriceUpdate", c, e) + if err != nil { + return err + } + return updater.UpdatePrice( + c.Price, + ) + default: + return errors.New("unknown event") + } +} + +func (l *listener) Listen(from uint64, updater postage.EventUpdater) <-chan struct{} { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-l.quit + cancel() + }() + + synced := make(chan struct{}) + closeOnce := new(sync.Once) + paged := make(chan struct{}, 1) + paged <- struct{}{} + + l.wg.Add(1) + listenf := func() error { + defer l.wg.Done() + for { + select { + case <-paged: + // if we paged then it means there's more things to sync on + case <-time.After(chainUpdateInterval): + case <-l.quit: + return nil + } + to, err := l.ev.BlockNumber(ctx) + if err != nil { + return err + } + + if to < tailSize { + // in a test blockchain there might be not be enough blocks yet + continue + } + + // consider to-tailSize as the "latest" block we need to sync to + to = to - tailSize + + if to < from { + // if the blockNumber is actually less than what we already, it might mean the backend is not synced or some reorg scenario + continue + } + + // do some paging (sub-optimal) + if to-from > blockPage { + paged <- struct{}{} + to = from + blockPage + } else { + closeOnce.Do(func() { close(synced) }) + } + + events, err := l.ev.FilterLogs(ctx, l.filterQuery(big.NewInt(int64(from)), big.NewInt(int64(to)))) + if err != nil { + return err + } + + // this is called before processing the events + // so that the eviction in batchstore gets the correct + // block height context for the gc round. otherwise + // expired batches might be "revived". + err = updater.UpdateBlockNumber(to) + if err != nil { + return err + } + + for _, e := range events { + if err = l.processEvent(e, updater); err != nil { + return err + } + } + + from = to + 1 + } + } + + go func() { + err := listenf() + if err != nil { + l.logger.Errorf("event listener sync: %v", err) + } + }() + + return synced +} + +func (l *listener) Close() error { + close(l.quit) + done := make(chan struct{}) + + go func() { + defer close(done) + l.wg.Wait() + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + return errors.New("postage listener closed with running goroutines") + } + return nil +} + +func parseABI(json string) abi.ABI { + cabi, err := abi.JSON(strings.NewReader(json)) + if err != nil { + panic(fmt.Sprintf("error creating ABI for postage contract: %v", err)) + } + return cabi +} + +type batchCreatedEvent struct { + BatchId [32]byte + TotalAmount *big.Int + NormalisedBalance *big.Int + Owner common.Address + Depth uint8 +} + +type batchTopUpEvent struct { + BatchId [32]byte + TopupAmount *big.Int + NormalisedBalance *big.Int +} + +type batchDepthIncreaseEvent struct { + BatchId [32]byte + NewDepth uint8 + NormalisedBalance *big.Int +} + +type priceUpdateEvent struct { + Price *big.Int +} + +// DiscoverAddresses returns the canonical contracts for this chainID +func DiscoverAddresses(chainID int64) (postageStamp, priceOracle common.Address, found bool) { + if chainID == 5 { + // goerli + return common.HexToAddress("0xF7a041E7e2B79ccA1975852Eb6D4c6cE52986b4a"), common.HexToAddress("0x1044534090de6f4014ece6d036C699130Bd5Df43"), true + } + return common.Address{}, common.Address{}, false +} diff --git a/pkg/postage/listener/listener_test.go b/pkg/postage/listener/listener_test.go new file mode 100644 index 00000000000..e0696f0b25e --- /dev/null +++ b/pkg/postage/listener/listener_test.go @@ -0,0 +1,477 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package listener_test + +import ( + "bytes" + "context" + "io/ioutil" + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethersphere/bee/pkg/logging" + "github.com/ethersphere/bee/pkg/postage/listener" +) + +var hash common.Hash = common.HexToHash("ff6ec1ed9250a6952fabac07c6eb103550dc65175373eea432fd115ce8bb2246") +var addr common.Address = common.HexToAddress("abcdef") + +var postageStampAddress common.Address = common.HexToAddress("eeee") +var priceOracleAddress common.Address = common.HexToAddress("eeef") + +func TestListener(t *testing.T) { + logger := logging.New(ioutil.Discard, 0) + timeout := 5 * time.Second + // test that when the listener gets a certain event + // then we would like to assert the appropriate EventUpdater method was called + t.Run("create event", func(t *testing.T) { + c := createArgs{ + id: hash[:], + owner: addr[:], + amount: big.NewInt(42), + normalisedAmount: big.NewInt(43), + depth: 100, + } + + ev, evC := newEventUpdaterMock() + mf := newMockFilterer( + WithFilterLogEvents( + c.toLog(), + ), + ) + listener := listener.New(logger, mf, postageStampAddress, priceOracleAddress) + listener.Listen(0, ev) + + select { + case e := <-evC: + e.(blockNumberCall).compare(t, 0) // event args should be equal + case <-time.After(timeout): + t.Fatal("timed out waiting for event") + } + + select { + case e := <-evC: + e.(createArgs).compare(t, c) // event args should be equal + case <-time.After(timeout): + t.Fatal("timed out waiting for event") + } + }) + + t.Run("topup event", func(t *testing.T) { + topup := topupArgs{ + id: hash[:], + amount: big.NewInt(0), + normalisedBalance: big.NewInt(1), + } + + ev, evC := newEventUpdaterMock() + mf := newMockFilterer( + WithFilterLogEvents( + topup.toLog(), + ), + ) + listener := listener.New(logger, mf, postageStampAddress, priceOracleAddress) + listener.Listen(0, ev) + + select { + case e := <-evC: + e.(blockNumberCall).compare(t, 0) // event args should be equal + case <-time.After(timeout): + t.Fatal("timed out waiting for event") + } + + select { + case e := <-evC: + e.(topupArgs).compare(t, topup) // event args should be equal + case <-time.After(timeout): + t.Fatal("timed out waiting for event") + } + }) + + t.Run("depthIncrease event", func(t *testing.T) { + depthIncrease := depthArgs{ + id: hash[:], + depth: 200, + normalisedBalance: big.NewInt(2), + } + + ev, evC := newEventUpdaterMock() + mf := newMockFilterer( + WithFilterLogEvents( + depthIncrease.toLog(), + ), + ) + listener := listener.New(logger, mf, postageStampAddress, priceOracleAddress) + listener.Listen(0, ev) + + select { + case e := <-evC: + e.(blockNumberCall).compare(t, 0) // event args should be equal + case <-time.After(timeout): + t.Fatal("timed out waiting for event") + } + + select { + case e := <-evC: + e.(depthArgs).compare(t, depthIncrease) // event args should be equal + case <-time.After(timeout): + t.Fatal("timed out waiting for event") + } + }) + + t.Run("priceUpdate event", func(t *testing.T) { + priceUpdate := priceArgs{ + price: big.NewInt(500), + } + + ev, evC := newEventUpdaterMock() + mf := newMockFilterer( + WithFilterLogEvents( + priceUpdate.toLog(), + ), + ) + listener := listener.New(logger, mf, postageStampAddress, priceOracleAddress) + listener.Listen(0, ev) + + select { + case e := <-evC: + e.(blockNumberCall).compare(t, 0) // event args should be equal + case <-time.After(timeout): + t.Fatal("timed out waiting for event") + } + + select { + case e := <-evC: + e.(priceArgs).compare(t, priceUpdate) // event args should be equal + case <-time.After(timeout): + t.Fatal("timed out waiting for event") + } + }) + + t.Run("multiple events", func(t *testing.T) { + c := createArgs{ + id: hash[:], + owner: addr[:], + amount: big.NewInt(42), + normalisedAmount: big.NewInt(43), + depth: 100, + } + + topup := topupArgs{ + id: hash[:], + amount: big.NewInt(0), + normalisedBalance: big.NewInt(1), + } + + depthIncrease := depthArgs{ + id: hash[:], + depth: 200, + normalisedBalance: big.NewInt(2), + } + + priceUpdate := priceArgs{ + price: big.NewInt(500), + } + + blockNumber := uint64(500) + + ev, evC := newEventUpdaterMock() + mf := newMockFilterer( + WithFilterLogEvents( + c.toLog(), + topup.toLog(), + depthIncrease.toLog(), + priceUpdate.toLog(), + ), + WithBlockNumber(blockNumber), + ) + l := listener.New(logger, mf, postageStampAddress, priceOracleAddress) + l.Listen(0, ev) + + select { + case e := <-evC: + e.(blockNumberCall).compare(t, blockNumber-uint64(listener.TailSize)) // event args should be equal + case <-time.After(timeout): + t.Fatal("timed out waiting for block number update") + } + + select { + case e := <-evC: + e.(createArgs).compare(t, c) // event args should be equal + case <-time.After(timeout): + t.Fatal("timed out waiting for event") + } + + select { + case e := <-evC: + e.(topupArgs).compare(t, topup) // event args should be equal + case <-time.After(timeout): + t.Fatal("timed out waiting for event") + } + + select { + case e := <-evC: + e.(depthArgs).compare(t, depthIncrease) // event args should be equal + case <-time.After(timeout): + t.Fatal("timed out waiting for event") + } + + select { + case e := <-evC: + e.(priceArgs).compare(t, priceUpdate) + case <-time.After(timeout): + t.Fatal("timed out waiting for event") + } + }) +} + +func newEventUpdaterMock() (*updater, chan interface{}) { + c := make(chan interface{}) + return &updater{ + eventC: c, + }, c +} + +type updater struct { + eventC chan interface{} +} + +func (u *updater) Create(id, owner []byte, normalisedAmount *big.Int, depth uint8) error { + u.eventC <- createArgs{ + id: id, + owner: owner, + normalisedAmount: normalisedAmount, + depth: depth, + } + return nil +} + +func (u *updater) TopUp(id []byte, normalisedBalance *big.Int) error { + u.eventC <- topupArgs{ + id: id, + normalisedBalance: normalisedBalance, + } + return nil +} + +func (u *updater) UpdateDepth(id []byte, depth uint8, normalisedBalance *big.Int) error { + u.eventC <- depthArgs{ + id: id, + depth: depth, + normalisedBalance: normalisedBalance, + } + return nil +} + +func (u *updater) UpdatePrice(price *big.Int) error { + u.eventC <- priceArgs{price} + return nil +} + +func (u *updater) UpdateBlockNumber(blockNumber uint64) error { + u.eventC <- blockNumberCall{blockNumber: blockNumber} + return nil +} + +func (u *updater) Start() <-chan struct{} { return nil } + +type mockFilterer struct { + filterLogEvents []types.Log + subscriptionEvents []types.Log + sub *sub + blockNumber uint64 +} + +func newMockFilterer(opts ...Option) *mockFilterer { + mock := &mockFilterer{ + blockNumber: uint64(listener.TailSize), // use the tailSize as blockNumber by default to ensure at least block 0 is ready + } + for _, o := range opts { + o.apply(mock) + } + return mock +} + +func WithFilterLogEvents(events ...types.Log) Option { + return optionFunc(func(s *mockFilterer) { + s.filterLogEvents = events + }) +} + +func WithBlockNumber(blockNumber uint64) Option { + return optionFunc(func(s *mockFilterer) { + s.blockNumber = blockNumber + }) +} + +func (m *mockFilterer) FilterLogs(ctx context.Context, query ethereum.FilterQuery) ([]types.Log, error) { + return m.filterLogEvents, nil +} + +func (m *mockFilterer) SubscribeFilterLogs(ctx context.Context, query ethereum.FilterQuery, ch chan<- types.Log) (ethereum.Subscription, error) { + go func() { + for _, ev := range m.subscriptionEvents { + ch <- ev + } + }() + s := newSub() + return s, nil +} + +func (m *mockFilterer) Close() { + close(m.sub.c) +} + +func (m *mockFilterer) BlockNumber(context.Context) (uint64, error) { + return m.blockNumber, nil +} + +type sub struct { + c chan error +} + +func newSub() *sub { + return &sub{ + c: make(chan error), + } +} + +func (s *sub) Unsubscribe() {} +func (s *sub) Err() <-chan error { + return s.c +} + +type createArgs struct { + id []byte + owner []byte + amount *big.Int + normalisedAmount *big.Int + depth uint8 +} + +func (c createArgs) compare(t *testing.T, want createArgs) { + if !bytes.Equal(c.id, want.id) { + t.Fatalf("id mismatch. got %v want %v", c.id, want.id) + } + if !bytes.Equal(c.owner, want.owner) { + t.Fatalf("owner mismatch. got %v want %v", c.owner, want.owner) + } + if c.normalisedAmount.Cmp(want.normalisedAmount) != 0 { + t.Fatalf("normalised amount mismatch. got %v want %v", c.normalisedAmount.String(), want.normalisedAmount.String()) + } +} + +func (c createArgs) toLog() types.Log { + b, err := listener.PostageStampABI.Events["BatchCreated"].Inputs.NonIndexed().Pack(c.amount, c.normalisedAmount, common.BytesToAddress(c.owner), c.depth) + if err != nil { + panic(err) + } + return types.Log{ + Data: b, + Topics: []common.Hash{listener.BatchCreatedTopic, common.BytesToHash(c.id)}, // 1st item is the function sig digest, 2nd is always the batch id + } +} + +type topupArgs struct { + id []byte + amount *big.Int + normalisedBalance *big.Int +} + +func (ta topupArgs) compare(t *testing.T, want topupArgs) { + t.Helper() + if !bytes.Equal(ta.id, want.id) { + t.Fatalf("id mismatch. got %v want %v", ta.id, want.id) + } + if ta.normalisedBalance.Cmp(want.normalisedBalance) != 0 { + t.Fatalf("normalised balance mismatch. got %v want %v", ta.normalisedBalance.String(), want.normalisedBalance.String()) + } +} + +func (ta topupArgs) toLog() types.Log { + b, err := listener.PostageStampABI.Events["BatchTopUp"].Inputs.NonIndexed().Pack(ta.amount, ta.normalisedBalance) + if err != nil { + panic(err) + } + return types.Log{ + Data: b, + Topics: []common.Hash{listener.BatchTopupTopic, common.BytesToHash(ta.id)}, // 1st item is the function sig digest, 2nd is always the batch id + } +} + +type depthArgs struct { + id []byte + depth uint8 + normalisedBalance *big.Int +} + +func (d depthArgs) compare(t *testing.T, want depthArgs) { + t.Helper() + if !bytes.Equal(d.id, want.id) { + t.Fatalf("id mismatch. got %v want %v", d.id, want.id) + } + if d.depth != want.depth { + t.Fatalf("depth mismatch. got %d want %d", d.depth, want.depth) + } + if d.normalisedBalance.Cmp(want.normalisedBalance) != 0 { + t.Fatalf("normalised balance mismatch. got %v want %v", d.normalisedBalance.String(), want.normalisedBalance.String()) + } +} + +func (d depthArgs) toLog() types.Log { + b, err := listener.PostageStampABI.Events["BatchDepthIncrease"].Inputs.NonIndexed().Pack(d.depth, d.normalisedBalance) + if err != nil { + panic(err) + } + return types.Log{ + Data: b, + Topics: []common.Hash{listener.BatchDepthIncreaseTopic, common.BytesToHash(d.id)}, // 1st item is the function sig digest, 2nd is always the batch id + } +} + +type priceArgs struct { + price *big.Int +} + +func (p priceArgs) compare(t *testing.T, want priceArgs) { + t.Helper() + if p.price.Cmp(want.price) != 0 { + t.Fatalf("price mismatch. got %s want %s", p.price.String(), want.price.String()) + } +} + +func (p priceArgs) toLog() types.Log { + b, err := listener.PriceOracleABI.Events["PriceUpdate"].Inputs.NonIndexed().Pack(p.price) + if err != nil { + panic(err) + } + return types.Log{ + Data: b, + Topics: []common.Hash{listener.PriceUpdateTopic}, + } +} + +type blockNumberCall struct { + blockNumber uint64 +} + +func (b blockNumberCall) compare(t *testing.T, want uint64) { + t.Helper() + if b.blockNumber != want { + t.Fatalf("blockNumber mismatch. got %d want %d", b.blockNumber, want) + } +} + +type Option interface { + apply(*mockFilterer) +} + +type optionFunc func(*mockFilterer) + +func (f optionFunc) apply(r *mockFilterer) { f(r) } diff --git a/pkg/postage/mock/service.go b/pkg/postage/mock/service.go new file mode 100644 index 00000000000..f7ca44a5297 --- /dev/null +++ b/pkg/postage/mock/service.go @@ -0,0 +1,73 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mock + +import ( + "errors" + + "github.com/ethersphere/bee/pkg/postage" +) + +type optionFunc func(*mockPostage) + +// Option is an option passed to a mock postage Service. +type Option interface { + apply(*mockPostage) +} + +func (f optionFunc) apply(r *mockPostage) { f(r) } + +// New creates a new mock postage service. +func New(o ...Option) postage.Service { + m := &mockPostage{} + for _, v := range o { + v.apply(m) + } + + return m +} + +// WithAcceptAll sets the mock to return a new BatchIssuer on every +// call to GetStampIssuer. +func WithAcceptAll() Option { + return optionFunc(func(m *mockPostage) { m.acceptAll = true }) +} + +func WithIssuer(s *postage.StampIssuer) Option { + return optionFunc(func(m *mockPostage) { m.i = s }) +} + +type mockPostage struct { + i *postage.StampIssuer + acceptAll bool +} + +func (m *mockPostage) Add(s *postage.StampIssuer) { + m.i = s +} + +func (m *mockPostage) StampIssuers() []*postage.StampIssuer { + return []*postage.StampIssuer{m.i} +} + +func (m *mockPostage) GetStampIssuer(id []byte) (*postage.StampIssuer, error) { + if m.acceptAll { + return postage.NewStampIssuer("test fallback", "test identity", id, 24, 6), nil + } + + if m.i != nil { + return m.i, nil + } + + return nil, errors.New("stampissuer not found") +} + +func (m *mockPostage) Load() error { + panic("not implemented") // TODO: Implement +} + +func (m *mockPostage) Save() error { + panic("not implemented") // TODO: Implement +} diff --git a/pkg/postage/mock/stamper.go b/pkg/postage/mock/stamper.go new file mode 100644 index 00000000000..2c5b0a543f7 --- /dev/null +++ b/pkg/postage/mock/stamper.go @@ -0,0 +1,22 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mock + +import ( + "github.com/ethersphere/bee/pkg/postage" + "github.com/ethersphere/bee/pkg/swarm" +) + +type mockStamper struct{} + +// NewStamper returns anew new mock stamper. +func NewStamper() postage.Stamper { + return &mockStamper{} +} + +// Stamp implements the Stamper interface. It returns an empty postage stamp. +func (mockStamper) Stamp(_ swarm.Address) (*postage.Stamp, error) { + return &postage.Stamp{}, nil +} diff --git a/pkg/postage/postagecontract/contract.go b/pkg/postage/postagecontract/contract.go new file mode 100644 index 00000000000..41cf1ee181f --- /dev/null +++ b/pkg/postage/postagecontract/contract.go @@ -0,0 +1,238 @@ +// Copyright 2021 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package postagecontract + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "math/big" + "strings" + + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethersphere/bee/pkg/postage" + "github.com/ethersphere/bee/pkg/settlement/swap/transaction" + "github.com/ethersphere/go-storage-incentives-abi/postageabi" + "github.com/ethersphere/go-sw3-abi/sw3abi" +) + +var ( + BucketDepth = uint8(16) + + postageStampABI = parseABI(postageabi.PostageStampABIv0_1_0) + erc20ABI = parseABI(sw3abi.ERC20ABIv0_3_1) + batchCreatedTopic = postageStampABI.Events["BatchCreated"].ID + + ErrBatchCreate = errors.New("batch creation failed") + ErrInsufficientFunds = errors.New("insufficient token balance") + ErrInvalidDepth = errors.New("invalid depth") +) + +type Interface interface { + CreateBatch(ctx context.Context, initialBalance *big.Int, depth uint8, label string) ([]byte, error) +} + +type postageContract struct { + owner common.Address + postageContractAddress common.Address + bzzTokenAddress common.Address + transactionService transaction.Service + postageService postage.Service +} + +func New( + owner, + postageContractAddress, + bzzTokenAddress common.Address, + transactionService transaction.Service, + postageService postage.Service, +) Interface { + return &postageContract{ + owner: owner, + postageContractAddress: postageContractAddress, + bzzTokenAddress: bzzTokenAddress, + transactionService: transactionService, + postageService: postageService, + } +} + +func (c *postageContract) sendApproveTransaction(ctx context.Context, amount *big.Int) (*types.Receipt, error) { + callData, err := erc20ABI.Pack("approve", c.postageContractAddress, amount) + if err != nil { + return nil, err + } + + txHash, err := c.transactionService.Send(ctx, &transaction.TxRequest{ + To: &c.bzzTokenAddress, + Data: callData, + GasPrice: nil, + GasLimit: 0, + Value: big.NewInt(0), + }) + if err != nil { + return nil, err + } + + receipt, err := c.transactionService.WaitForReceipt(ctx, txHash) + if err != nil { + return nil, err + } + + if receipt.Status == 0 { + return nil, transaction.ErrTransactionReverted + } + + return receipt, nil +} + +func (c *postageContract) sendCreateBatchTransaction(ctx context.Context, owner common.Address, initialBalance *big.Int, depth uint8, nonce common.Hash) (*types.Receipt, error) { + callData, err := postageStampABI.Pack("createBatch", owner, initialBalance, depth, nonce) + if err != nil { + return nil, err + } + + request := &transaction.TxRequest{ + To: &c.postageContractAddress, + Data: callData, + GasPrice: nil, + GasLimit: 0, + Value: big.NewInt(0), + } + + txHash, err := c.transactionService.Send(ctx, request) + if err != nil { + return nil, err + } + + receipt, err := c.transactionService.WaitForReceipt(ctx, txHash) + if err != nil { + return nil, err + } + + if receipt.Status == 0 { + return nil, transaction.ErrTransactionReverted + } + + return receipt, nil +} + +func (c *postageContract) getBalance(ctx context.Context) (*big.Int, error) { + callData, err := erc20ABI.Pack("balanceOf", c.owner) + if err != nil { + return nil, err + } + + result, err := c.transactionService.Call(ctx, &transaction.TxRequest{ + To: &c.bzzTokenAddress, + Data: callData, + }) + if err != nil { + return nil, err + } + + results, err := erc20ABI.Unpack("balanceOf", result) + if err != nil { + return nil, err + } + return abi.ConvertType(results[0], new(big.Int)).(*big.Int), nil +} + +func (c *postageContract) CreateBatch(ctx context.Context, initialBalance *big.Int, depth uint8, label string) ([]byte, error) { + + if depth < BucketDepth { + return nil, ErrInvalidDepth + } + + totalAmount := big.NewInt(0).Mul(initialBalance, big.NewInt(int64(1<> (32 - depth) +} + +// Label returns the label of the issuer. +func (st *StampIssuer) Label() string { + return st.label +} + +// MarshalBinary gives the byte slice serialisation of a StampIssuer: +// = label[32]|keyID[32]|batchID[32]|batchDepth[1]|bucketDepth[1]|size_0[4]|size_1[4]|.... +func (st *StampIssuer) MarshalBinary() ([]byte, error) { + buf := make([]byte, 32+32+32+1+1+(1<<(st.bucketDepth+2))) + label := []byte(st.label) + copy(buf[32-len(label):32], label) + keyID := []byte(st.keyID) + copy(buf[64-len(keyID):64], keyID) + copy(buf[64:96], st.batchID) + buf[96] = st.batchDepth + buf[97] = st.bucketDepth + st.mu.Lock() + defer st.mu.Unlock() + for i, addr := range st.buckets { + offset := 98 + i*4 + binary.BigEndian.PutUint32(buf[offset:offset+4], addr) + } + return buf, nil +} + +// UnmarshalBinary parses a serialised StampIssuer into the receiver struct. +func (st *StampIssuer) UnmarshalBinary(buf []byte) error { + st.label = toString(buf[:32]) + st.keyID = toString(buf[32:64]) + st.batchID = buf[64:96] + st.batchDepth = buf[96] + st.bucketDepth = buf[97] + st.buckets = make([]uint32, 1< top { + top = v + } + } + + return top +} + +// ID returns the BatchID for this batch. +func (s *StampIssuer) ID() []byte { + id := make([]byte, len(s.batchID)) + copy(id, s.batchID) + return id +} diff --git a/pkg/postage/stampissuer_test.go b/pkg/postage/stampissuer_test.go new file mode 100644 index 00000000000..d4c01da8b91 --- /dev/null +++ b/pkg/postage/stampissuer_test.go @@ -0,0 +1,54 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package postage_test + +import ( + crand "crypto/rand" + "io" + "reflect" + "testing" + + "github.com/ethersphere/bee/pkg/postage" + "github.com/ethersphere/bee/pkg/swarm" +) + +// TestStampIssuerMarshalling tests the idempotence of binary marshal/unmarshal. +func TestStampIssuerMarshalling(t *testing.T) { + st := newTestStampIssuer(t) + buf, err := st.MarshalBinary() + if err != nil { + t.Fatal(err) + } + st0 := &postage.StampIssuer{} + err = st0.UnmarshalBinary(buf) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(st, st0) { + t.Fatalf("unmarshal(marshal(StampIssuer)) != StampIssuer \n%v\n%v", st, st0) + } +} + +func newTestStampIssuer(t *testing.T) *postage.StampIssuer { + t.Helper() + id := make([]byte, 32) + _, err := io.ReadFull(crand.Reader, id) + if err != nil { + t.Fatal(err) + } + st := postage.NewStampIssuer("label", "keyID", id, 16, 8) + addr := make([]byte, 32) + for i := 0; i < 1<<8; i++ { + _, err := io.ReadFull(crand.Reader, addr) + if err != nil { + t.Fatal(err) + } + err = st.Inc(swarm.NewAddress(addr)) + if err != nil { + t.Fatal(err) + } + } + return st +} diff --git a/pkg/postage/testing/batch.go b/pkg/postage/testing/batch.go new file mode 100644 index 00000000000..0a89b2181b4 --- /dev/null +++ b/pkg/postage/testing/batch.go @@ -0,0 +1,98 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testing + +import ( + "bytes" + crand "crypto/rand" + "io" + "math/big" + "math/rand" + "testing" + + "github.com/ethersphere/bee/pkg/postage" +) + +const defaultDepth = 16 + +// BatchOption is an optional parameter for NewBatch +type BatchOption func(c *postage.Batch) + +// MustNewID will generate a new random ID (32 byte slice). Panics on errors. +func MustNewID() []byte { + id := make([]byte, 32) + _, err := io.ReadFull(crand.Reader, id) + if err != nil { + panic(err) + } + return id +} + +// MustNewAddress will generate a new random address (20 byte slice). Panics on +// errors. +func MustNewAddress() []byte { + addr := make([]byte, 20) + _, err := io.ReadFull(crand.Reader, addr) + if err != nil { + panic(err) + } + return addr +} + +// NewBigInt will generate a new random big int (uint64 base value). +func NewBigInt() *big.Int { + return (new(big.Int)).SetUint64(rand.Uint64()) // skipcq: GSC-G404 +} + +// MustNewBatch will create a new test batch. Fields that are not supplied will +// be filled with random data. Panics on errors. +func MustNewBatch(opts ...BatchOption) *postage.Batch { + b := &postage.Batch{ + ID: MustNewID(), + Value: NewBigInt(), + Start: rand.Uint64(), // skipcq: GSC-G404 + Depth: defaultDepth, + } + + for _, opt := range opts { + opt(b) + } + + if b.Owner == nil { + b.Owner = MustNewAddress() + } + + return b +} + +// WithOwner will set the batch owner on a randomized batch. +func WithOwner(owner []byte) BatchOption { + return func(b *postage.Batch) { + b.Owner = owner + } +} + +// CompareBatches is a testing helper that compares two batches and fails the +// test if all fields are not equal. +// Fails on first different value and prints the comparison. +func CompareBatches(t *testing.T, want, got *postage.Batch) { + t.Helper() + + if !bytes.Equal(want.ID, got.ID) { + t.Fatalf("batch ID: want %v, got %v", want.ID, got.ID) + } + if want.Value.Cmp(got.Value) != 0 { + t.Fatalf("value: want %v, got %v", want.Value, got.Value) + } + if want.Start != got.Start { + t.Fatalf("start: want %v, got %b", want.Start, got.Start) + } + if !bytes.Equal(want.Owner, got.Owner) { + t.Fatalf("owner: want %v, got %v", want.Owner, got.Owner) + } + if want.Depth != got.Depth { + t.Fatalf("depth: want %v, got %v", want.Depth, got.Depth) + } +} diff --git a/pkg/postage/testing/chainstate.go b/pkg/postage/testing/chainstate.go new file mode 100644 index 00000000000..9b8a118c777 --- /dev/null +++ b/pkg/postage/testing/chainstate.go @@ -0,0 +1,38 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testing + +import ( + "math/rand" + "testing" + + "github.com/ethersphere/bee/pkg/postage" +) + +// NewChainState will create a new ChainState with random values. +func NewChainState() *postage.ChainState { + return &postage.ChainState{ + Block: rand.Uint64(), // skipcq: GSC-G404 + Price: NewBigInt(), + Total: NewBigInt(), + } +} + +// CompareChainState is a test helper that compares two ChainStates and fails +// the test if they are not exactly equal. +// Fails on first difference and returns a descriptive comparison. +func CompareChainState(t *testing.T, want, got *postage.ChainState) { + t.Helper() + + if want.Block != got.Block { + t.Fatalf("block: want %v, got %v", want.Block, got.Block) + } + if want.Price.Cmp(got.Price) != 0 { + t.Fatalf("price: want %v, got %v", want.Price, got.Price) + } + if want.Total.Cmp(got.Total) != 0 { + t.Fatalf("total: want %v, got %v", want.Total, got.Total) + } +} diff --git a/pkg/postage/testing/stamp.go b/pkg/postage/testing/stamp.go new file mode 100644 index 00000000000..6ab3141279e --- /dev/null +++ b/pkg/postage/testing/stamp.go @@ -0,0 +1,31 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testing + +import ( + crand "crypto/rand" + "io" + + "github.com/ethersphere/bee/pkg/postage" +) + +const signatureSize = 65 + +// MustNewSignature will create a new random signature (65 byte slice). Panics +// on errors. +func MustNewSignature() []byte { + sig := make([]byte, signatureSize) + _, err := io.ReadFull(crand.Reader, sig) + if err != nil { + panic(err) + } + return sig +} + +// MustNewStamp will generate a postage stamp with random data. Panics on +// errors. +func MustNewStamp() *postage.Stamp { + return postage.NewStamp(MustNewID(), MustNewSignature()) +} diff --git a/pkg/pss/pss.go b/pkg/pss/pss.go index 116f94f5236..daef3f306a0 100644 --- a/pkg/pss/pss.go +++ b/pkg/pss/pss.go @@ -18,6 +18,7 @@ import ( "time" "github.com/ethersphere/bee/pkg/logging" + "github.com/ethersphere/bee/pkg/postage" "github.com/ethersphere/bee/pkg/pushsync" "github.com/ethersphere/bee/pkg/swarm" ) @@ -29,7 +30,7 @@ var ( type Sender interface { // Send arbitrary byte slice with the given topic to Targets. - Send(context.Context, Topic, []byte, *ecdsa.PublicKey, Targets) error + Send(context.Context, Topic, []byte, postage.Stamper, *ecdsa.PublicKey, Targets) error } type Interface interface { @@ -84,7 +85,7 @@ type Handler func(context.Context, []byte) // Send constructs a padded message with topic and payload, // wraps it in a trojan chunk such that one of the targets is a prefix of the chunk address. // Uses push-sync to deliver message. -func (p *pss) Send(ctx context.Context, topic Topic, payload []byte, recipient *ecdsa.PublicKey, targets Targets) error { +func (p *pss) Send(ctx context.Context, topic Topic, payload []byte, stamper postage.Stamper, recipient *ecdsa.PublicKey, targets Targets) error { p.metrics.TotalMessagesSentCounter.Inc() tStart := time.Now() @@ -94,6 +95,12 @@ func (p *pss) Send(ctx context.Context, topic Topic, payload []byte, recipient * return err } + stamp, err := stamper.Stamp(tc.Address()) + if err != nil { + return err + } + tc = tc.WithStamp(stamp) + p.metrics.MessageMiningDuration.Set(time.Since(tStart).Seconds()) // push the chunk using push sync so that it reaches it destination in network diff --git a/pkg/pss/pss_test.go b/pkg/pss/pss_test.go index 9bc8b355c1b..9ccd52088db 100644 --- a/pkg/pss/pss_test.go +++ b/pkg/pss/pss_test.go @@ -13,6 +13,8 @@ import ( "github.com/ethersphere/bee/pkg/crypto" "github.com/ethersphere/bee/pkg/logging" + "github.com/ethersphere/bee/pkg/postage" + postagetesting "github.com/ethersphere/bee/pkg/postage/testing" "github.com/ethersphere/bee/pkg/pss" "github.com/ethersphere/bee/pkg/pushsync" pushsyncmock "github.com/ethersphere/bee/pkg/pushsync/mock" @@ -42,9 +44,9 @@ func TestSend(t *testing.T) { t.Fatal(err) } recipient := &privkey.PublicKey - + s := &stamper{} // call Send to store trojan chunk in localstore - if err = p.Send(ctx, topic, payload, recipient, targets); err != nil { + if err = p.Send(ctx, topic, payload, s, recipient, targets); err != nil { t.Fatal(err) } @@ -229,3 +231,9 @@ func ensureCalls(t *testing.T, calls *int, exp int) { t.Fatalf("expected %d calls, found %d", exp, *calls) } } + +type stamper struct{} + +func (s *stamper) Stamp(_ swarm.Address) (*postage.Stamp, error) { + return postagetesting.MustNewStamp(), nil +} diff --git a/pkg/pss/trojan.go b/pkg/pss/trojan.go index bc4fb5e2a04..0f11fa725ad 100644 --- a/pkg/pss/trojan.go +++ b/pkg/pss/trojan.go @@ -8,11 +8,13 @@ import ( "bytes" "context" "crypto/ecdsa" + random "crypto/rand" "encoding/binary" "encoding/hex" "errors" "fmt" - random "math/rand" + "math" + "math/big" "github.com/btcsuite/btcd/btcec" "github.com/ethersphere/bee/pkg/bmtpool" @@ -31,6 +33,8 @@ var ( // ErrVarLenTargets is returned when the given target list for a trojan chunk has addresses of different lengths ErrVarLenTargets = errors.New("target list cannot have targets of different length") + + maxUint32 = big.NewInt(math.MaxUint32) ) // Topic is the type that classifies messages, allows client applications to subscribe to @@ -202,7 +206,11 @@ func contains(col Targets, elem []byte) bool { func mine(ctx context.Context, odd bool, f func(nonce []byte) (swarm.Chunk, error)) (swarm.Chunk, error) { seeds := make([]uint32, 8) for i := range seeds { - seeds[i] = random.Uint32() + b, err := random.Int(random.Reader, maxUint32) + if err != nil { + return nil, err + } + seeds[i] = uint32(b.Int64()) } initnonce := make([]byte, 32) for i := 0; i < 8; i++ { @@ -269,7 +277,7 @@ func extractPublicKey(chunkData []byte) (*ecdsa.PublicKey, error) { // instead the hash of the secret key and the topic is matched against a hint (64 bit meta info)q // proper integrity check will disambiguate any potential collisions (false positives) // if the topic matches the hint, it returns the el-Gamal decryptor, otherwise an error -func matchTopic(key *ecdsa.PrivateKey, pubkey *ecdsa.PublicKey, hint []byte, topic []byte) (encryption.Decrypter, error) { +func matchTopic(key *ecdsa.PrivateKey, pubkey *ecdsa.PublicKey, hint, topic []byte) (encryption.Decrypter, error) { dec, err := elgamal.NewDecrypter(key, pubkey, topic, swarm.NewHasher) if err != nil { return nil, err diff --git a/pkg/pullsync/pb/pullsync.pb.go b/pkg/pullsync/pb/pullsync.pb.go index a52a00d39c0..1e1f7bd1adf 100644 --- a/pkg/pullsync/pb/pullsync.pb.go +++ b/pkg/pullsync/pb/pullsync.pb.go @@ -349,6 +349,7 @@ func (m *Want) GetBitVector() []byte { type Delivery struct { Address []byte `protobuf:"bytes,1,opt,name=Address,proto3" json:"Address,omitempty"` Data []byte `protobuf:"bytes,2,opt,name=Data,proto3" json:"Data,omitempty"` + Stamp []byte `protobuf:"bytes,3,opt,name=Stamp,proto3" json:"Stamp,omitempty"` } func (m *Delivery) Reset() { *m = Delivery{} } @@ -398,6 +399,13 @@ func (m *Delivery) GetData() []byte { return nil } +func (m *Delivery) GetStamp() []byte { + if m != nil { + return m.Stamp + } + return nil +} + func init() { proto.RegisterType((*Syn)(nil), "pullsync.Syn") proto.RegisterType((*Ack)(nil), "pullsync.Ack") @@ -412,26 +420,27 @@ func init() { func init() { proto.RegisterFile("pullsync.proto", fileDescriptor_d1dee042cf9c065c) } var fileDescriptor_d1dee042cf9c065c = []byte{ - // 295 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x6c, 0x90, 0xbf, 0x4e, 0xf3, 0x30, - 0x10, 0xc0, 0xeb, 0x24, 0xed, 0xd7, 0xef, 0x54, 0x2a, 0xe4, 0x01, 0x45, 0xa8, 0x32, 0x95, 0xc5, - 0xd0, 0x89, 0x85, 0x05, 0x36, 0x9a, 0x56, 0xc0, 0x86, 0x64, 0x22, 0x90, 0xd8, 0xdc, 0xc4, 0x85, - 0x88, 0xd4, 0x8e, 0x6c, 0x07, 0x29, 0x6f, 0xc1, 0x63, 0x31, 0x76, 0x64, 0x44, 0xc9, 0x8b, 0xa0, - 0x98, 0x44, 0x2c, 0x4c, 0xfe, 0xfd, 0xee, 0x7c, 0x7f, 0x74, 0x30, 0x2d, 0xca, 0x3c, 0x37, 0x95, - 0x4c, 0xce, 0x0a, 0xad, 0xac, 0xc2, 0xe3, 0xde, 0xe9, 0x10, 0xfc, 0xfb, 0x4a, 0xd2, 0x13, 0xf0, - 0x97, 0xc9, 0x2b, 0x0e, 0xe1, 0xdf, 0xaa, 0xd4, 0x46, 0x69, 0x13, 0xa2, 0xb9, 0xbf, 0x08, 0x58, - 0xaf, 0xf4, 0x18, 0x02, 0x56, 0x66, 0x29, 0xc6, 0x3f, 0x6f, 0x88, 0xe6, 0x68, 0x71, 0xc0, 0x1c, - 0xd3, 0x19, 0x8c, 0x56, 0x5c, 0x26, 0x22, 0xff, 0x33, 0x7b, 0x05, 0xe3, 0x1b, 0x61, 0x19, 0x97, - 0xcf, 0x02, 0x1f, 0x82, 0x1f, 0x65, 0xd2, 0xa5, 0x87, 0xac, 0xc5, 0xb6, 0xe2, 0x5a, 0xab, 0x5d, - 0xe8, 0xcd, 0xd1, 0x22, 0x60, 0x8e, 0xf1, 0x14, 0xbc, 0x58, 0x85, 0xbe, 0x8b, 0x78, 0xb1, 0xa2, - 0x97, 0x30, 0xbc, 0xdb, 0x6e, 0x85, 0x6e, 0xd7, 0x8b, 0x55, 0xb1, 0x53, 0xc6, 0xba, 0x16, 0x01, - 0xeb, 0x15, 0x1f, 0xc1, 0xe8, 0x96, 0x9b, 0x17, 0x61, 0x5c, 0xa3, 0x09, 0xeb, 0x8c, 0x9e, 0x42, - 0xf0, 0xc8, 0xa5, 0xc5, 0x33, 0xf8, 0x1f, 0x65, 0xf6, 0x41, 0x24, 0x56, 0x69, 0x57, 0x3b, 0x61, - 0xbf, 0x01, 0x7a, 0x01, 0xe3, 0xb5, 0xc8, 0xb3, 0x37, 0xa1, 0xab, 0x76, 0xc6, 0x32, 0x4d, 0xb5, - 0x30, 0xa6, 0xfb, 0xd7, 0x6b, 0xbb, 0xea, 0x9a, 0x5b, 0xde, 0x4d, 0x70, 0x1c, 0xcd, 0x3e, 0x6a, - 0x82, 0xf6, 0x35, 0x41, 0x5f, 0x35, 0x41, 0xef, 0x0d, 0x19, 0xec, 0x1b, 0x32, 0xf8, 0x6c, 0xc8, - 0xe0, 0xc9, 0x2b, 0x36, 0x9b, 0x91, 0xbb, 0xf6, 0xf9, 0x77, 0x00, 0x00, 0x00, 0xff, 0xff, 0xe5, - 0x7c, 0x69, 0x94, 0x7f, 0x01, 0x00, 0x00, + // 307 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x6c, 0x91, 0xcf, 0x4a, 0x03, 0x31, + 0x10, 0xc6, 0x9b, 0xfd, 0x53, 0xeb, 0x50, 0x8b, 0x04, 0x91, 0x45, 0x4a, 0x2c, 0xc1, 0x43, 0x4f, + 0x5e, 0x3c, 0x79, 0xb3, 0x7f, 0x50, 0x4f, 0x0a, 0x69, 0x51, 0xf0, 0x96, 0x6e, 0x53, 0x5d, 0xdc, + 0x26, 0x4b, 0x92, 0x15, 0xf6, 0x2d, 0x7c, 0x2c, 0x8f, 0x3d, 0x7a, 0x94, 0xdd, 0x17, 0x91, 0x4d, + 0x77, 0xf1, 0xe2, 0x29, 0xdf, 0x6f, 0x26, 0x33, 0xdf, 0x07, 0x03, 0x83, 0x2c, 0x4f, 0x53, 0x53, + 0xc8, 0xf8, 0x32, 0xd3, 0xca, 0x2a, 0xdc, 0x6b, 0x99, 0x86, 0xe0, 0x2f, 0x0a, 0x49, 0xcf, 0xc1, + 0x9f, 0xc4, 0xef, 0x38, 0x82, 0x83, 0x59, 0xae, 0x8d, 0xd2, 0x26, 0x42, 0x23, 0x7f, 0x1c, 0xb0, + 0x16, 0xe9, 0x19, 0x04, 0x2c, 0x4f, 0xd6, 0x18, 0xef, 0xdf, 0x08, 0x8d, 0xd0, 0xf8, 0x88, 0x39, + 0x4d, 0x87, 0xd0, 0x9d, 0x71, 0x19, 0x8b, 0xf4, 0xdf, 0xee, 0x0d, 0xf4, 0xee, 0x84, 0x65, 0x5c, + 0xbe, 0x0a, 0x7c, 0x0c, 0xfe, 0x34, 0x91, 0xae, 0x1d, 0xb2, 0x5a, 0xd6, 0x13, 0xb7, 0x5a, 0x6d, + 0x23, 0x6f, 0x84, 0xc6, 0x01, 0x73, 0x1a, 0x0f, 0xc0, 0x5b, 0xaa, 0xc8, 0x77, 0x15, 0x6f, 0xa9, + 0xe8, 0x35, 0x84, 0x8f, 0x9b, 0x8d, 0xd0, 0x75, 0xbc, 0xa5, 0xca, 0xb6, 0xca, 0x58, 0xb7, 0x22, + 0x60, 0x2d, 0xe2, 0x53, 0xe8, 0xde, 0x73, 0xf3, 0x26, 0x8c, 0x5b, 0xd4, 0x67, 0x0d, 0xd1, 0x0b, + 0x08, 0x9e, 0xb9, 0xb4, 0x78, 0x08, 0x87, 0xd3, 0xc4, 0x3e, 0x89, 0xd8, 0x2a, 0xed, 0x66, 0xfb, + 0xec, 0xaf, 0x40, 0x1f, 0xa0, 0x37, 0x17, 0x69, 0xf2, 0x21, 0x74, 0x51, 0x7b, 0x4c, 0xd6, 0x6b, + 0x2d, 0x8c, 0x69, 0xfe, 0xb5, 0x58, 0x47, 0x9d, 0x73, 0xcb, 0x1b, 0x07, 0xa7, 0xf1, 0x09, 0x84, + 0x0b, 0xcb, 0xb7, 0x99, 0x4b, 0xdb, 0x67, 0x7b, 0x98, 0x0e, 0xbf, 0x4a, 0x82, 0x76, 0x25, 0x41, + 0x3f, 0x25, 0x41, 0x9f, 0x15, 0xe9, 0xec, 0x2a, 0xd2, 0xf9, 0xae, 0x48, 0xe7, 0xc5, 0xcb, 0x56, + 0xab, 0xae, 0xbb, 0xc1, 0xd5, 0x6f, 0x00, 0x00, 0x00, 0xff, 0xff, 0xf7, 0xd1, 0x30, 0xa1, 0x95, + 0x01, 0x00, 0x00, } func (m *Syn) Marshal() (dAtA []byte, err error) { @@ -677,6 +686,13 @@ func (m *Delivery) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l + if len(m.Stamp) > 0 { + i -= len(m.Stamp) + copy(dAtA[i:], m.Stamp) + i = encodeVarintPullsync(dAtA, i, uint64(len(m.Stamp))) + i-- + dAtA[i] = 0x1a + } if len(m.Data) > 0 { i -= len(m.Data) copy(dAtA[i:], m.Data) @@ -815,6 +831,10 @@ func (m *Delivery) Size() (n int) { if l > 0 { n += 1 + l + sovPullsync(uint64(l)) } + l = len(m.Stamp) + if l > 0 { + n += 1 + l + sovPullsync(uint64(l)) + } return n } @@ -1550,6 +1570,40 @@ func (m *Delivery) Unmarshal(dAtA []byte) error { m.Data = []byte{} } iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Stamp", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPullsync + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthPullsync + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthPullsync + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Stamp = append(m.Stamp[:0], dAtA[iNdEx:postIndex]...) + if m.Stamp == nil { + m.Stamp = []byte{} + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipPullsync(dAtA[iNdEx:]) diff --git a/pkg/pullsync/pb/pullsync.proto b/pkg/pullsync/pb/pullsync.proto index 56a9753af49..0e82c5a8be7 100644 --- a/pkg/pullsync/pb/pullsync.proto +++ b/pkg/pullsync/pb/pullsync.proto @@ -40,5 +40,6 @@ message Want { message Delivery { bytes Address = 1; bytes Data = 2; + bytes Stamp = 3; } diff --git a/pkg/pullsync/pullstorage/mock/pullstorage.go b/pkg/pullsync/pullstorage/mock/pullstorage.go index 740f52e40cd..57077614823 100644 --- a/pkg/pullsync/pullstorage/mock/pullstorage.go +++ b/pkg/pullsync/pullstorage/mock/pullstorage.go @@ -35,7 +35,8 @@ func WithIntervalsResp(addrs []swarm.Address, top uint64, err error) Option { func WithChunks(chs ...swarm.Chunk) Option { return optionFunc(func(p *PullStorage) { for _, c := range chs { - p.chunks[c.Address().String()] = c.Data() + c := c + p.chunks[c.Address().String()] = c } }) } @@ -67,7 +68,7 @@ type PullStorage struct { putCalls int setCalls int - chunks map[string][]byte + chunks map[string]swarm.Chunk evilAddr swarm.Address evilChunk swarm.Chunk @@ -80,7 +81,7 @@ type PullStorage struct { // NewPullStorage returns a new PullStorage mock. func NewPullStorage(opts ...Option) *PullStorage { s := &PullStorage{ - chunks: make(map[string][]byte), + chunks: make(map[string]swarm.Chunk), } for _, v := range opts { v.apply(s) @@ -128,7 +129,7 @@ func (s *PullStorage) Get(_ context.Context, _ storage.ModeGet, addrs ...swarm.A } if v, ok := s.chunks[a.String()]; ok { - chs = append(chs, swarm.NewChunk(a, v)) + chs = append(chs, v) } else if !ok { return nil, storage.ErrNotFound } @@ -141,7 +142,8 @@ func (s *PullStorage) Put(_ context.Context, _ storage.ModePut, chs ...swarm.Chu s.mtx.Lock() defer s.mtx.Unlock() for _, c := range chs { - s.chunks[c.Address().String()] = c.Data() + c := c + s.chunks[c.Address().String()] = c } s.putCalls++ return nil diff --git a/pkg/pullsync/pullsync.go b/pkg/pullsync/pullsync.go index e81f17d6cb4..b88d3dfc533 100644 --- a/pkg/pullsync/pullsync.go +++ b/pkg/pullsync/pullsync.go @@ -60,13 +60,14 @@ type Interface interface { } type Syncer struct { - streamer p2p.Streamer - metrics metrics - logger logging.Logger - storage pullstorage.Storer - quit chan struct{} - wg sync.WaitGroup - unwrap func(swarm.Chunk) + streamer p2p.Streamer + metrics metrics + logger logging.Logger + storage pullstorage.Storer + quit chan struct{} + wg sync.WaitGroup + unwrap func(swarm.Chunk) + validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error) ruidMtx sync.Mutex ruidCtx map[uint32]func() @@ -75,16 +76,17 @@ type Syncer struct { io.Closer } -func New(streamer p2p.Streamer, storage pullstorage.Storer, unwrap func(swarm.Chunk), logger logging.Logger) *Syncer { +func New(streamer p2p.Streamer, storage pullstorage.Storer, unwrap func(swarm.Chunk), validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error), logger logging.Logger) *Syncer { return &Syncer{ - streamer: streamer, - storage: storage, - metrics: newMetrics(), - unwrap: unwrap, - logger: logger, - ruidCtx: make(map[uint32]func()), - wg: sync.WaitGroup{}, - quit: make(chan struct{}), + streamer: streamer, + storage: storage, + metrics: newMetrics(), + unwrap: unwrap, + validStamp: validStamp, + logger: logger, + ruidCtx: make(map[uint32]func()), + wg: sync.WaitGroup{}, + quit: make(chan struct{}), } } @@ -225,6 +227,10 @@ func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8 s.metrics.DeliveryCounter.Inc() chunk := swarm.NewChunk(addr, delivery.Data) + if chunk, err = s.validStamp(chunk, delivery.Stamp); err != nil { + return 0, ru.Ruid, err + } + if cac.Valid(chunk) { go s.unwrap(chunk) } else if !soc.Valid(chunk) { @@ -326,7 +332,11 @@ func (s *Syncer) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (er } for _, v := range chs { - deliver := pb.Delivery{Address: v.Address().Bytes(), Data: v.Data()} + stamp, err := v.Stamp().MarshalBinary() + if err != nil { + return fmt.Errorf("serialise stamp: %w", err) + } + deliver := pb.Delivery{Address: v.Address().Bytes(), Data: v.Data(), Stamp: stamp} if err := w.WriteMsgWithContext(ctx, &deliver); err != nil { return fmt.Errorf("write delivery: %w", err) } diff --git a/pkg/pullsync/pullsync_test.go b/pkg/pullsync/pullsync_test.go index ff1d05cb8a9..12d1a2fd137 100644 --- a/pkg/pullsync/pullsync_test.go +++ b/pkg/pullsync/pullsync_test.go @@ -14,6 +14,7 @@ import ( "github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p/streamtest" + postagetesting "github.com/ethersphere/bee/pkg/postage/testing" "github.com/ethersphere/bee/pkg/pullsync" "github.com/ethersphere/bee/pkg/pullsync/pullstorage/mock" testingc "github.com/ethersphere/bee/pkg/storage/testing" @@ -141,7 +142,8 @@ func TestIncoming_WantAll(t *testing.T) { func TestIncoming_UnsolicitedChunk(t *testing.T) { evilAddr := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000666") evilData := []byte{0x66, 0x66, 0x66} - evil := swarm.NewChunk(evilAddr, evilData) + stamp := postagetesting.MustNewStamp() + evil := swarm.NewChunk(evilAddr, evilData).WithStamp(stamp) var ( mockTopmost = uint64(5) @@ -214,5 +216,6 @@ func newPullSync(s p2p.Streamer, o ...mock.Option) (*pullsync.Syncer, *mock.Pull storage := mock.NewPullStorage(o...) logger := logging.New(ioutil.Discard, 0) unwrap := func(swarm.Chunk) {} - return pullsync.New(s, storage, unwrap, logger), storage + validStamp := func(ch swarm.Chunk, _ []byte) (swarm.Chunk, error) { return ch, nil } + return pullsync.New(s, storage, unwrap, validStamp, logger), storage } diff --git a/pkg/pusher/pusher_test.go b/pkg/pusher/pusher_test.go index e24af901b39..d91fb3863de 100644 --- a/pkg/pusher/pusher_test.go +++ b/pkg/pusher/pusher_test.go @@ -21,6 +21,7 @@ import ( "github.com/ethersphere/bee/pkg/pushsync" pushsyncmock "github.com/ethersphere/bee/pkg/pushsync/mock" "github.com/ethersphere/bee/pkg/storage" + testingc "github.com/ethersphere/bee/pkg/storage/testing" "github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/topology/mock" @@ -94,7 +95,7 @@ func TestSendChunkToSyncWithTag(t *testing.T) { t.Fatal(err) } - chunk := createChunk().WithTagID(ta.Uid) + chunk := testingc.GenerateTestRandomChunk().WithTagID(ta.Uid) _, err = storer.Put(context.Background(), storage.ModePutUpload, chunk) if err != nil { @@ -125,7 +126,7 @@ func TestSendChunkToSyncWithTag(t *testing.T) { // TestSendChunkToPushSyncWithoutTag is similar to TestSendChunkToPushSync, excep that the tags are not // present to simulate bzz api withotu splitter condition func TestSendChunkToPushSyncWithoutTag(t *testing.T) { - chunk := createChunk() + chunk := testingc.GenerateTestRandomChunk() // create a trigger and a closestpeer triggerPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") @@ -171,7 +172,7 @@ func TestSendChunkToPushSyncWithoutTag(t *testing.T) { // get a invalid receipt (not with the address of the chunk sent). The test makes sure that this error // is received and the ModeSetSync is not set for the chunk. func TestSendChunkAndReceiveInvalidReceipt(t *testing.T) { - chunk := createChunk() + chunk := testingc.GenerateTestRandomChunk() // create a trigger and a closestpeer triggerPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") @@ -209,7 +210,7 @@ func TestSendChunkAndReceiveInvalidReceipt(t *testing.T) { // expects a timeout to get instead of getting a receipt. The test makes sure that timeout error // is received and the ModeSetSync is not set for the chunk. func TestSendChunkAndTimeoutinReceivingReceipt(t *testing.T) { - chunk := createChunk() + chunk := testingc.GenerateTestRandomChunk() // create a trigger and a closestpeer triggerPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") @@ -285,7 +286,7 @@ func TestPusherClose(t *testing.T) { _, p, storer := createPusher(t, triggerPeer, pushSyncService, mock.WithClosestPeer(closestPeer)) - chunk := createChunk() + chunk := testingc.GenerateTestRandomChunk() _, err := storer.Put(context.Background(), storage.ModePutUpload, chunk) if err != nil { @@ -361,13 +362,6 @@ func TestPusherClose(t *testing.T) { } } -func createChunk() swarm.Chunk { - // chunk data to upload - chunkAddress := swarm.MustParseHexAddress("7000000000000000000000000000000000000000000000000000000000000000") - chunkData := []byte("1234") - return swarm.NewChunk(chunkAddress, chunkData).WithTagID(666) -} - func createPusher(t *testing.T, addr swarm.Address, pushSyncService pushsync.PushSyncer, mockOpts ...mock.Option) (*tags.Tags, *pusher.Service, *Store) { t.Helper() logger := logging.New(ioutil.Discard, 0) diff --git a/pkg/pushsync/pb/pushsync.pb.go b/pkg/pushsync/pb/pushsync.pb.go index 9865e9bdf37..57a74464a23 100644 --- a/pkg/pushsync/pb/pushsync.pb.go +++ b/pkg/pushsync/pb/pushsync.pb.go @@ -25,6 +25,7 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package type Delivery struct { Address []byte `protobuf:"bytes,1,opt,name=Address,proto3" json:"Address,omitempty"` Data []byte `protobuf:"bytes,2,opt,name=Data,proto3" json:"Data,omitempty"` + Stamp []byte `protobuf:"bytes,3,opt,name=Stamp,proto3" json:"Stamp,omitempty"` } func (m *Delivery) Reset() { *m = Delivery{} } @@ -74,6 +75,13 @@ func (m *Delivery) GetData() []byte { return nil } +func (m *Delivery) GetStamp() []byte { + if m != nil { + return m.Stamp + } + return nil +} + type Receipt struct { Address []byte `protobuf:"bytes,1,opt,name=Address,proto3" json:"Address,omitempty"` Signature []byte `protobuf:"bytes,2,opt,name=Signature,proto3" json:"Signature,omitempty"` @@ -134,17 +142,18 @@ func init() { func init() { proto.RegisterFile("pushsync.proto", fileDescriptor_723cf31bfc02bfd6) } var fileDescriptor_723cf31bfc02bfd6 = []byte{ - // 155 bytes of a gzipped FileDescriptorProto + // 170 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2b, 0x28, 0x2d, 0xce, - 0x28, 0xae, 0xcc, 0x4b, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x80, 0xf1, 0x95, 0x2c, + 0x28, 0xae, 0xcc, 0x4b, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x80, 0xf1, 0x95, 0xfc, 0xb8, 0x38, 0x5c, 0x52, 0x73, 0x32, 0xcb, 0x52, 0x8b, 0x2a, 0x85, 0x24, 0xb8, 0xd8, 0x1d, 0x53, 0x52, 0x8a, 0x52, 0x8b, 0x8b, 0x25, 0x18, 0x15, 0x18, 0x35, 0x78, 0x82, 0x60, 0x5c, 0x21, 0x21, - 0x2e, 0x16, 0x97, 0xc4, 0x92, 0x44, 0x09, 0x26, 0xb0, 0x30, 0x98, 0xad, 0xe4, 0xc8, 0xc5, 0x1e, - 0x94, 0x9a, 0x9c, 0x9a, 0x59, 0x50, 0x82, 0x47, 0xa3, 0x0c, 0x17, 0x67, 0x70, 0x66, 0x7a, 0x5e, - 0x62, 0x49, 0x69, 0x51, 0x2a, 0x54, 0x37, 0x42, 0xc0, 0x49, 0xe6, 0xc4, 0x23, 0x39, 0xc6, 0x0b, - 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x9c, 0xf0, 0x58, 0x8e, 0xe1, 0xc2, 0x63, 0x39, 0x86, - 0x1b, 0x8f, 0xe5, 0x18, 0xa2, 0x98, 0x0a, 0x92, 0x92, 0xd8, 0xc0, 0x6e, 0x35, 0x06, 0x04, 0x00, - 0x00, 0xff, 0xff, 0x72, 0xaf, 0x50, 0xbc, 0xbd, 0x00, 0x00, 0x00, + 0x2e, 0x16, 0x97, 0xc4, 0x92, 0x44, 0x09, 0x26, 0xb0, 0x30, 0x98, 0x2d, 0x24, 0xc2, 0xc5, 0x1a, + 0x5c, 0x92, 0x98, 0x5b, 0x20, 0xc1, 0x0c, 0x16, 0x84, 0x70, 0x94, 0x1c, 0xb9, 0xd8, 0x83, 0x52, + 0x93, 0x53, 0x33, 0x0b, 0x4a, 0xf0, 0x18, 0x27, 0xc3, 0xc5, 0x19, 0x9c, 0x99, 0x9e, 0x97, 0x58, + 0x52, 0x5a, 0x94, 0x0a, 0x35, 0x13, 0x21, 0xe0, 0x24, 0x73, 0xe2, 0x91, 0x1c, 0xe3, 0x85, 0x47, + 0x72, 0x8c, 0x0f, 0x1e, 0xc9, 0x31, 0x4e, 0x78, 0x2c, 0xc7, 0x70, 0xe1, 0xb1, 0x1c, 0xc3, 0x8d, + 0xc7, 0x72, 0x0c, 0x51, 0x4c, 0x05, 0x49, 0x49, 0x6c, 0x60, 0x1f, 0x18, 0x03, 0x02, 0x00, 0x00, + 0xff, 0xff, 0xbb, 0xdf, 0x60, 0x63, 0xd3, 0x00, 0x00, 0x00, } func (m *Delivery) Marshal() (dAtA []byte, err error) { @@ -167,6 +176,13 @@ func (m *Delivery) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l + if len(m.Stamp) > 0 { + i -= len(m.Stamp) + copy(dAtA[i:], m.Stamp) + i = encodeVarintPushsync(dAtA, i, uint64(len(m.Stamp))) + i-- + dAtA[i] = 0x1a + } if len(m.Data) > 0 { i -= len(m.Data) copy(dAtA[i:], m.Data) @@ -246,6 +262,10 @@ func (m *Delivery) Size() (n int) { if l > 0 { n += 1 + l + sovPushsync(uint64(l)) } + l = len(m.Stamp) + if l > 0 { + n += 1 + l + sovPushsync(uint64(l)) + } return n } @@ -369,6 +389,40 @@ func (m *Delivery) Unmarshal(dAtA []byte) error { m.Data = []byte{} } iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Stamp", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPushsync + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthPushsync + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthPushsync + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Stamp = append(m.Stamp[:0], dAtA[iNdEx:postIndex]...) + if m.Stamp == nil { + m.Stamp = []byte{} + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipPushsync(dAtA[iNdEx:]) diff --git a/pkg/pushsync/pb/pushsync.proto b/pkg/pushsync/pb/pushsync.proto index 9f378641cd1..76029279088 100644 --- a/pkg/pushsync/pb/pushsync.proto +++ b/pkg/pushsync/pb/pushsync.proto @@ -11,6 +11,7 @@ option go_package = "pb"; message Delivery { bytes Address = 1; bytes Data = 2; + bytes Stamp = 3; } message Receipt { diff --git a/pkg/pushsync/pushsync.go b/pkg/pushsync/pushsync.go index 28bf5154674..03a6d880088 100644 --- a/pkg/pushsync/pushsync.go +++ b/pkg/pushsync/pushsync.go @@ -64,6 +64,7 @@ type PushSync struct { pricer pricer.Interface metrics metrics tracer *tracing.Tracer + validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error) signer crypto.Signer isFullNode bool } @@ -72,12 +73,12 @@ var timeToLive = 5 * time.Second // request time to live var timeToWaitForPushsyncToNeighbor = 3 * time.Second // time to wait to get a receipt for a chunk var nPeersToPushsync = 3 // number of peers to replicate to as receipt is sent upstream -func New(address swarm.Address, streamer p2p.StreamerDisconnecter, storer storage.Putter, topologyDriver topology.Driver, tagger *tags.Tags, isFullNode bool, unwrap func(swarm.Chunk), logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, signer crypto.Signer, tracer *tracing.Tracer) *PushSync { +func New(address swarm.Address, streamer p2p.StreamerDisconnecter, storer storage.Putter, topology topology.Driver, tagger *tags.Tags, isFullNode bool, unwrap func(swarm.Chunk), validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error), logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, signer crypto.Signer, tracer *tracing.Tracer) *PushSync { ps := &PushSync{ address: address, streamer: streamer, storer: storer, - topologyDriver: topologyDriver, + topologyDriver: topology, tagger: tagger, isFullNode: isFullNode, unwrap: unwrap, @@ -86,6 +87,7 @@ func New(address swarm.Address, streamer p2p.StreamerDisconnecter, storer storag pricer: pricer, metrics: newMetrics(), tracer: tracer, + validStamp: validStamp, signer: signer, } return ps @@ -125,6 +127,9 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ps.metrics.TotalReceived.Inc() chunk := swarm.NewChunk(swarm.NewAddress(ch.Address), ch.Data) + if chunk, err = ps.validStamp(chunk, ch.Stamp); err != nil { + return fmt.Errorf("pushsync valid stamp: %w", err) + } if cac.Valid(chunk) { if ps.unwrap != nil { @@ -217,10 +222,16 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) defer streamer.Close() w := protobuf.NewWriter(streamer) - + ctx, cancel = context.WithTimeout(ctx, timeToWaitForPushsyncToNeighbor) + defer cancel() + stamp, err := chunk.Stamp().MarshalBinary() + if err != nil { + return + } err = w.WriteMsgWithContext(ctx, &pb.Delivery{ Address: chunk.Address().Bytes(), Data: chunk.Data(), + Stamp: stamp, }) if err != nil { _ = streamer.Reset() @@ -283,6 +294,11 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk) (rr *pb.R lastErr error ) + stamp, err := ch.Stamp().MarshalBinary() + if err != nil { + return nil, err + } + deferFuncs := make([]func(), 0) defersFn := func() { if len(deferFuncs) > 0 { @@ -347,6 +363,7 @@ func (ps *PushSync) pushToClosest(ctx context.Context, ch swarm.Chunk) (rr *pb.R if err := w.WriteMsgWithContext(ctxd, &pb.Delivery{ Address: ch.Address().Bytes(), Data: ch.Data(), + Stamp: stamp, }); err != nil { _ = streamer.Reset() lastErr = fmt.Errorf("chunk %s deliver to peer %s: %w", ch.Address().String(), peer.String(), err) diff --git a/pkg/pushsync/pushsync_test.go b/pkg/pushsync/pushsync_test.go index b745bea5ae2..d1c1d816ed9 100644 --- a/pkg/pushsync/pushsync_test.go +++ b/pkg/pushsync/pushsync_test.go @@ -17,16 +17,17 @@ import ( accountingmock "github.com/ethersphere/bee/pkg/accounting/mock" "github.com/ethersphere/bee/pkg/crypto" cryptomock "github.com/ethersphere/bee/pkg/crypto/mock" - "github.com/ethersphere/bee/pkg/localstore" "github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/streamtest" + "github.com/ethersphere/bee/pkg/postage" pricermock "github.com/ethersphere/bee/pkg/pricer/mock" "github.com/ethersphere/bee/pkg/pushsync" "github.com/ethersphere/bee/pkg/pushsync/pb" statestore "github.com/ethersphere/bee/pkg/statestore/mock" "github.com/ethersphere/bee/pkg/storage" + mocks "github.com/ethersphere/bee/pkg/storage/mock" testingc "github.com/ethersphere/bee/pkg/storage/testing" "github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/tags" @@ -122,8 +123,9 @@ func TestReplicateBeforeReceipt(t *testing.T) { // node that is connected to secondPeer // it's address is closer to the chunk than secondPeer but it will not receive the chunk - _, storerEmpty, _, _ := createPushSyncNode(t, emptyPeer, defaultPrices, nil, nil, defaultSigner) + psEmpty, storerEmpty, _, _ := createPushSyncNode(t, emptyPeer, defaultPrices, nil, nil, defaultSigner) defer storerEmpty.Close() + emptyRecorder := streamtest.New(streamtest.WithProtocols(psEmpty.Protocol()), streamtest.WithBaseAddr(secondPeer)) wFunc := func(addr swarm.Address) bool { return true @@ -131,7 +133,7 @@ func TestReplicateBeforeReceipt(t *testing.T) { // node that is connected to closestPeer // will receieve chunk from closestPeer - psSecond, storerSecond, _, secondAccounting := createPushSyncNode(t, secondPeer, defaultPrices, nil, nil, defaultSigner, mock.WithPeers(emptyPeer), mock.WithIsWithinFunc(wFunc)) + psSecond, storerSecond, _, secondAccounting := createPushSyncNode(t, secondPeer, defaultPrices, emptyRecorder, nil, defaultSigner, mock.WithPeers(emptyPeer), mock.WithIsWithinFunc(wFunc)) defer storerSecond.Close() secondRecorder := streamtest.New(streamtest.WithProtocols(psSecond.Protocol()), streamtest.WithBaseAddr(closestPeer)) @@ -549,14 +551,10 @@ func TestSignsReceipt(t *testing.T) { } } -func createPushSyncNode(t *testing.T, addr swarm.Address, prices pricerParameters, recorder *streamtest.Recorder, unwrap func(swarm.Chunk), signer crypto.Signer, mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags, accounting.Interface) { +func createPushSyncNode(t *testing.T, addr swarm.Address, prices pricerParameters, recorder *streamtest.Recorder, unwrap func(swarm.Chunk), signer crypto.Signer, mockOpts ...mock.Option) (*pushsync.PushSync, *mocks.MockStorer, *tags.Tags, accounting.Interface) { t.Helper() logger := logging.New(ioutil.Discard, 0) - - storer, err := localstore.New("", addr.Bytes(), nil, logger) - if err != nil { - t.Fatal(err) - } + storer := mocks.NewStorer() mockTopology := mock.NewTopologyDriver(mockOpts...) mockStatestore := statestore.NewStateStore() @@ -569,8 +567,11 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, prices pricerParameter if unwrap == nil { unwrap = func(swarm.Chunk) {} } + validStamp := func(ch swarm.Chunk, stamp []byte) (swarm.Chunk, error) { + return ch.WithStamp(postage.NewStamp(nil, nil)), nil + } - return pushsync.New(addr, recorderDisconnecter, storer, mockTopology, mtag, true, unwrap, logger, mockAccounting, mockPricer, signer, nil), storer, mtag, mockAccounting + return pushsync.New(addr, recorderDisconnecter, storer, mockTopology, mtag, true, unwrap, validStamp, logger, mockAccounting, mockPricer, signer, nil), storer, mtag, mockAccounting } func waitOnRecordAndTest(t *testing.T, peer swarm.Address, recorder *streamtest.Recorder, add swarm.Address, data []byte) { diff --git a/pkg/recovery/repair.go b/pkg/recovery/repair.go index 7a4d27e634f..e039538d7c8 100644 --- a/pkg/recovery/repair.go +++ b/pkg/recovery/repair.go @@ -35,7 +35,7 @@ func NewCallback(pssSender pss.Sender) Callback { return func(chunkAddress swarm.Address, targets pss.Targets) { payload := chunkAddress ctx := context.Background() - _ = pssSender.Send(ctx, Topic, payload.Bytes(), &recipient, targets) + _ = pssSender.Send(ctx, Topic, payload.Bytes(), nil, &recipient, targets) } } diff --git a/pkg/recovery/repair_test.go b/pkg/recovery/repair_test.go index a326e56063d..5a91a79bbd2 100644 --- a/pkg/recovery/repair_test.go +++ b/pkg/recovery/repair_test.go @@ -16,6 +16,7 @@ import ( "github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/netstore" "github.com/ethersphere/bee/pkg/p2p/streamtest" + "github.com/ethersphere/bee/pkg/postage" pricermock "github.com/ethersphere/bee/pkg/pricer/mock" "github.com/ethersphere/bee/pkg/pss" "github.com/ethersphere/bee/pkg/pushsync" @@ -231,7 +232,11 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Store streamtest.WithProtocols(server.Protocol()), ) retrieve := retrieval.New(swarm.ZeroAddress, mockStorer, recorder, ps, logger, serverMockAccounting, pricerMock, nil) - ns := netstore.New(storer, recoveryFunc, retrieve, logger) + validStamp := func(ch swarm.Chunk, stamp []byte) (swarm.Chunk, error) { + return ch.WithStamp(postage.NewStamp(nil, nil)), nil + } + + ns := netstore.New(storer, validStamp, recoveryFunc, retrieve, logger) return ns } @@ -251,7 +256,7 @@ type mockPssSender struct { } // Send mocks the pss Send function -func (mp *mockPssSender) Send(ctx context.Context, topic pss.Topic, payload []byte, recipient *ecdsa.PublicKey, targets pss.Targets) error { +func (mp *mockPssSender) Send(ctx context.Context, topic pss.Topic, payload []byte, _ postage.Stamper, recipient *ecdsa.PublicKey, targets pss.Targets) error { mp.callbackC <- true return nil } diff --git a/pkg/resolver/client/ens/ens.go b/pkg/resolver/client/ens/ens.go index 3aeb9a4683d..0fad46f4275 100644 --- a/pkg/resolver/client/ens/ens.go +++ b/pkg/resolver/client/ens/ens.go @@ -141,7 +141,7 @@ func (c *Client) Close() error { return nil } -func wrapDial(endpoint string, contractAddr string) (*ethclient.Client, *goens.Registry, error) { +func wrapDial(endpoint, contractAddr string) (*ethclient.Client, *goens.Registry, error) { // Dial the eth client. ethCl, err := ethclient.Dial(endpoint) if err != nil { diff --git a/pkg/retrieval/pb/retrieval.pb.go b/pkg/retrieval/pb/retrieval.pb.go index 85b4715a23f..11acf007415 100644 --- a/pkg/retrieval/pb/retrieval.pb.go +++ b/pkg/retrieval/pb/retrieval.pb.go @@ -67,7 +67,8 @@ func (m *Request) GetAddr() []byte { } type Delivery struct { - Data []byte `protobuf:"bytes,1,opt,name=Data,proto3" json:"Data,omitempty"` + Data []byte `protobuf:"bytes,1,opt,name=Data,proto3" json:"Data,omitempty"` + Stamp []byte `protobuf:"bytes,2,opt,name=Stamp,proto3" json:"Stamp,omitempty"` } func (m *Delivery) Reset() { *m = Delivery{} } @@ -110,24 +111,32 @@ func (m *Delivery) GetData() []byte { return nil } +func (m *Delivery) GetStamp() []byte { + if m != nil { + return m.Stamp + } + return nil +} + func init() { - proto.RegisterType((*Request)(nil), "retieval.Request") - proto.RegisterType((*Delivery)(nil), "retieval.Delivery") + proto.RegisterType((*Request)(nil), "retrieval.Request") + proto.RegisterType((*Delivery)(nil), "retrieval.Delivery") } func init() { proto.RegisterFile("retrieval.proto", fileDescriptor_fcade0a564e5dcd4) } var fileDescriptor_fcade0a564e5dcd4 = []byte{ - // 134 bytes of a gzipped FileDescriptorProto + // 146 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2f, 0x4a, 0x2d, 0x29, - 0xca, 0x4c, 0x2d, 0x4b, 0xcc, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x28, 0x4a, 0x2d, - 0x01, 0xf3, 0x95, 0x64, 0xb9, 0xd8, 0x83, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0x84, 0xb8, - 0x58, 0x1c, 0x53, 0x52, 0x8a, 0x24, 0x18, 0x15, 0x18, 0x35, 0x78, 0x82, 0xc0, 0x6c, 0x25, 0x39, - 0x2e, 0x0e, 0x97, 0xd4, 0x9c, 0xcc, 0xb2, 0xd4, 0xa2, 0x4a, 0x90, 0xbc, 0x4b, 0x62, 0x49, 0x22, - 0x4c, 0x1e, 0xc4, 0x76, 0x92, 0x39, 0xf1, 0x48, 0x8e, 0xf1, 0xc2, 0x23, 0x39, 0xc6, 0x07, 0x8f, - 0xe4, 0x18, 0x27, 0x3c, 0x96, 0x63, 0xb8, 0xf0, 0x58, 0x8e, 0xe1, 0xc6, 0x63, 0x39, 0x86, 0x28, - 0xa6, 0x82, 0xa4, 0x24, 0x36, 0xb0, 0x6d, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0xe5, 0x88, - 0xb0, 0x44, 0x80, 0x00, 0x00, 0x00, + 0xca, 0x4c, 0x2d, 0x4b, 0xcc, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x84, 0x0b, 0x28, + 0xc9, 0x72, 0xb1, 0x07, 0xa5, 0x16, 0x96, 0xa6, 0x16, 0x97, 0x08, 0x09, 0x71, 0xb1, 0x38, 0xa6, + 0xa4, 0x14, 0x49, 0x30, 0x2a, 0x30, 0x6a, 0xf0, 0x04, 0x81, 0xd9, 0x4a, 0x26, 0x5c, 0x1c, 0x2e, + 0xa9, 0x39, 0x99, 0x65, 0xa9, 0x45, 0x95, 0x20, 0x79, 0x97, 0xc4, 0x92, 0x44, 0x98, 0x3c, 0x88, + 0x2d, 0x24, 0xc2, 0xc5, 0x1a, 0x5c, 0x92, 0x98, 0x5b, 0x20, 0xc1, 0x04, 0x16, 0x84, 0x70, 0x9c, + 0x64, 0x4e, 0x3c, 0x92, 0x63, 0xbc, 0xf0, 0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, 0xc6, 0x09, 0x8f, + 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63, 0xb8, 0xf1, 0x58, 0x8e, 0x21, 0x8a, 0xa9, 0x20, 0x29, 0x89, + 0x0d, 0xec, 0x08, 0x63, 0x40, 0x00, 0x00, 0x00, 0xff, 0xff, 0xf7, 0x72, 0x32, 0x41, 0x97, 0x00, + 0x00, 0x00, } func (m *Request) Marshal() (dAtA []byte, err error) { @@ -180,6 +189,13 @@ func (m *Delivery) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l + if len(m.Stamp) > 0 { + i -= len(m.Stamp) + copy(dAtA[i:], m.Stamp) + i = encodeVarintRetrieval(dAtA, i, uint64(len(m.Stamp))) + i-- + dAtA[i] = 0x12 + } if len(m.Data) > 0 { i -= len(m.Data) copy(dAtA[i:], m.Data) @@ -224,6 +240,10 @@ func (m *Delivery) Size() (n int) { if l > 0 { n += 1 + l + sovRetrieval(uint64(l)) } + l = len(m.Stamp) + if l > 0 { + n += 1 + l + sovRetrieval(uint64(l)) + } return n } @@ -383,6 +403,40 @@ func (m *Delivery) Unmarshal(dAtA []byte) error { m.Data = []byte{} } iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Stamp", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRetrieval + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthRetrieval + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthRetrieval + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Stamp = append(m.Stamp[:0], dAtA[iNdEx:postIndex]...) + if m.Stamp == nil { + m.Stamp = []byte{} + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipRetrieval(dAtA[iNdEx:]) diff --git a/pkg/retrieval/pb/retrieval.proto b/pkg/retrieval/pb/retrieval.proto index bc3aa1de1d5..8104b3563c3 100644 --- a/pkg/retrieval/pb/retrieval.proto +++ b/pkg/retrieval/pb/retrieval.proto @@ -4,14 +4,15 @@ syntax = "proto3"; -package retieval; +package retrieval; option go_package = "pb"; message Request { - bytes Addr = 1; + bytes Addr = 1; } message Delivery { - bytes Data = 1; + bytes Data = 1; + bytes Stamp = 2; } diff --git a/pkg/retrieval/retrieval.go b/pkg/retrieval/retrieval.go index 61a454d106c..f565d902627 100644 --- a/pkg/retrieval/retrieval.go +++ b/pkg/retrieval/retrieval.go @@ -20,6 +20,7 @@ import ( "github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p/protobuf" + "github.com/ethersphere/bee/pkg/postage" "github.com/ethersphere/bee/pkg/pricer" pb "github.com/ethersphere/bee/pkg/retrieval/pb" "github.com/ethersphere/bee/pkg/soc" @@ -245,7 +246,12 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski Observe(time.Since(startTimer).Seconds()) s.metrics.TotalRetrieved.Inc() - chunk = swarm.NewChunk(addr, d.Data) + stamp := new(postage.Stamp) + err = stamp.UnmarshalBinary(d.Stamp) + if err != nil { + return nil, peer, fmt.Errorf("stamp unmarshal: %w", err) + } + chunk = swarm.NewChunk(addr, d.Data).WithStamp(stamp) if !cac.Valid(chunk) { if !soc.Valid(chunk) { s.metrics.InvalidChunkRetrieved.Inc() @@ -352,8 +358,13 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e } } + stamp, err := chunk.Stamp().MarshalBinary() + if err != nil { + return fmt.Errorf("stamp marshal: %w", err) + } if err := w.WriteMsgWithContext(ctx, &pb.Delivery{ - Data: chunk.Data(), + Data: chunk.Data(), + Stamp: stamp, }); err != nil { return fmt.Errorf("write delivery: %w peer %s", err, p.Address.String()) } diff --git a/pkg/retrieval/retrieval_test.go b/pkg/retrieval/retrieval_test.go index 2e0097bad97..28ab051c221 100644 --- a/pkg/retrieval/retrieval_test.go +++ b/pkg/retrieval/retrieval_test.go @@ -49,9 +49,13 @@ func TestDelivery(t *testing.T) { pricerMock = pricermock.NewMockService(defaultPrice, defaultPrice) ) + stamp, err := chunk.Stamp().MarshalBinary() + if err != nil { + t.Fatal(err) + } // put testdata in the mock store of the server - _, err := mockStorer.Put(context.Background(), storage.ModePutUpload, chunk) + _, err = mockStorer.Put(context.Background(), storage.ModePutUpload, chunk) if err != nil { t.Fatal(err) } @@ -84,6 +88,13 @@ func TestDelivery(t *testing.T) { if !bytes.Equal(v.Data(), chunk.Data()) { t.Fatalf("request and response data not equal. got %s want %s", v, chunk.Data()) } + vstamp, err := v.Stamp().MarshalBinary() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(vstamp, stamp) { + t.Fatal("stamp mismatch") + } records, err := recorder.Records(serverAddr, "retrieval", "1.0.0", "retrieval") if err != nil { t.Fatal(err) diff --git a/pkg/settlement/swap/transaction/event_test.go b/pkg/settlement/swap/transaction/event_test.go index a6a4258dae2..20ee242b234 100644 --- a/pkg/settlement/swap/transaction/event_test.go +++ b/pkg/settlement/swap/transaction/event_test.go @@ -25,7 +25,7 @@ type transferEvent struct { Value *big.Int } -func newTransferLog(address common.Address, from common.Address, to common.Address, value *big.Int) *types.Log { +func newTransferLog(address, from, to common.Address, value *big.Int) *types.Log { return &types.Log{ Topics: []common.Hash{ erc20ABI.Events["Transfer"].ID, diff --git a/pkg/settlement/swap/transaction/transaction.go b/pkg/settlement/swap/transaction/transaction.go index f5cc4db0fdd..272c90f28af 100644 --- a/pkg/settlement/swap/transaction/transaction.go +++ b/pkg/settlement/swap/transaction/transaction.go @@ -152,7 +152,6 @@ func (t *transactionService) Call(ctx context.Context, request *TxRequest) ([]by Gas: request.GasLimit, Value: request.Value, } - data, err := t.backend.CallContract(ctx, msg, nil) if err != nil { return nil, err diff --git a/pkg/shed/index.go b/pkg/shed/index.go index ff1353696d1..530375b889f 100644 --- a/pkg/shed/index.go +++ b/pkg/shed/index.go @@ -45,6 +45,10 @@ type Item struct { BinID uint64 PinCounter uint64 // maintains the no of time a chunk is pinned Tag uint32 + BatchID []byte // postage batch ID + Sig []byte // postage stamp + Depth uint8 // postage batch depth + Radius uint8 // postage batch reserve radius, po upto and excluding which chunks are unpinned } // Merge is a helper method to construct a new @@ -72,6 +76,18 @@ func (i Item) Merge(i2 Item) Item { if i.Tag == 0 { i.Tag = i2.Tag } + if len(i.Sig) == 0 { + i.Sig = i2.Sig + } + if len(i.BatchID) == 0 { + i.BatchID = i2.BatchID + } + if i.Depth == 0 { + i.Depth = i2.Depth + } + if i.Radius == 0 { + i.Radius = i2.Radius + } return i } diff --git a/pkg/storage/mock/storer.go b/pkg/storage/mock/storer.go index 1e4def6ac94..424cafaa127 100644 --- a/pkg/storage/mock/storer.go +++ b/pkg/storage/mock/storer.go @@ -15,7 +15,7 @@ import ( var _ storage.Storer = (*MockStorer)(nil) type MockStorer struct { - store map[string][]byte + store map[string]swarm.Chunk modePut map[string]storage.ModePut modeSet map[string]storage.ModeSet pinnedAddress []swarm.Address // Stores the pinned address @@ -52,7 +52,7 @@ func WithPartialInterval(v bool) Option { func NewStorer(opts ...Option) *MockStorer { s := &MockStorer{ - store: make(map[string][]byte), + store: make(map[string]swarm.Chunk), modePut: make(map[string]storage.ModePut), modeSet: make(map[string]storage.ModeSet), morePull: make(chan struct{}), @@ -75,7 +75,7 @@ func (m *MockStorer) Get(_ context.Context, _ storage.ModeGet, addr swarm.Addres if !has { return nil, storage.ErrNotFound } - return swarm.NewChunk(addr, v), nil + return v, nil } func (m *MockStorer) Put(ctx context.Context, mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err error) { @@ -98,7 +98,9 @@ func (m *MockStorer) Put(ctx context.Context, mode storage.ModePut, chs ...swarm // and copies the data from the call into the in-memory store b := make([]byte, len(ch.Data())) copy(b, ch.Data()) - m.store[ch.Address().String()] = b + addr := swarm.NewAddress(ch.Address().Bytes()) + stamp := ch.Stamp() + m.store[ch.Address().String()] = swarm.NewChunk(addr, b).WithStamp(stamp) m.modePut[ch.Address().String()] = mode // pin chunks if needed diff --git a/pkg/storage/store.go b/pkg/storage/store.go index 866b5f88522..445df446bc9 100644 --- a/pkg/storage/store.go +++ b/pkg/storage/store.go @@ -85,6 +85,8 @@ const ( ModePutUploadPin // ModePutRequestPin: the same as ModePutRequest but also pin the chunk with the put ModePutRequestPin + // ModePutRequestCache forces a retrieved chunk to be stored in the cache + ModePutRequestCache ) // ModeSet enumerates different Setter modes. diff --git a/pkg/storage/testing/chunk.go b/pkg/storage/testing/chunk.go index 3526b312956..506f4ef56b1 100644 --- a/pkg/storage/testing/chunk.go +++ b/pkg/storage/testing/chunk.go @@ -21,10 +21,14 @@ import ( "time" "github.com/ethersphere/bee/pkg/cac" + postagetesting "github.com/ethersphere/bee/pkg/postage/testing" "github.com/ethersphere/bee/pkg/swarm" + swarmtesting "github.com/ethersphere/bee/pkg/swarm/test" ) -// fixtreuChunks are pregenerated content-addressed chunks necessary for explicit +var mockStamp swarm.Stamp + +// fixtureChunks are pregenerated content-addressed chunks necessary for explicit // test scenarios where random generated chunks are not good enough. var fixtureChunks = map[string]swarm.Chunk{ "0025": swarm.NewChunk( @@ -48,6 +52,9 @@ var fixtureChunks = map[string]swarm.Chunk{ func init() { // needed for GenerateTestRandomChunk rand.Seed(time.Now().UnixNano()) + + mockStamp = postagetesting.MustNewStamp() + } // GenerateTestRandomChunk generates a valid content addressed chunk. @@ -55,7 +62,8 @@ func GenerateTestRandomChunk() swarm.Chunk { data := make([]byte, swarm.ChunkSize) _, _ = rand.Read(data) ch, _ := cac.New(data) - return ch + stamp := postagetesting.MustNewStamp() + return ch.WithStamp(stamp) } // GenerateTestRandomInvalidChunk generates a random, however invalid, content @@ -65,7 +73,8 @@ func GenerateTestRandomInvalidChunk() swarm.Chunk { _, _ = rand.Read(data) key := make([]byte, swarm.SectionSize) _, _ = rand.Read(key) - return swarm.NewChunk(swarm.NewAddress(key), data) + stamp := postagetesting.MustNewStamp() + return swarm.NewChunk(swarm.NewAddress(key), data).WithStamp(stamp) } // GenerateTestRandomChunks generates a slice of random @@ -78,6 +87,15 @@ func GenerateTestRandomChunks(count int) []swarm.Chunk { return chunks } +// GenerateTestRandomChunkAt generates an invalid (!) chunk with address of proximity order po wrt target. +func GenerateTestRandomChunkAt(target swarm.Address, po int) swarm.Chunk { + data := make([]byte, swarm.ChunkSize) + _, _ = rand.Read(data) + addr := swarmtesting.RandomAddressAt(target, po) + stamp := postagetesting.MustNewStamp() + return swarm.NewChunk(addr, data).WithStamp(stamp) +} + // FixtureChunk gets a pregenerated content-addressed chunk and // panics if one is not found. func FixtureChunk(prefix string) swarm.Chunk { @@ -85,5 +103,5 @@ func FixtureChunk(prefix string) swarm.Chunk { if !ok { panic("no fixture found") } - return c + return c.WithStamp(mockStamp) } diff --git a/pkg/swarm/swarm.go b/pkg/swarm/swarm.go index d6acff222f5..d34fe43faf0 100644 --- a/pkg/swarm/swarm.go +++ b/pkg/swarm/swarm.go @@ -7,6 +7,7 @@ package swarm import ( "bytes" + "encoding" "encoding/hex" "encoding/json" "errors" @@ -125,17 +126,43 @@ var ZeroAddress = NewAddress(nil) type AddressIterFunc func(address Address) error type Chunk interface { + // Address returns the chunk address. Address() Address + // Data returns the chunk data. Data() []byte + // TagID returns the tag ID for this chunk. TagID() uint32 + // WithTagID attaches the tag ID to the chunk. WithTagID(t uint32) Chunk + // Stamp returns the postage stamp associated with this chunk. + Stamp() Stamp + // WithStamp attaches a postage stamp to the chunk. + WithStamp(Stamp) Chunk + // Radius is the PO above which the batch is preserved. + Radius() uint8 + // Depth returns the batch depth of the stamp - allowed batch size = 2^{depth}. + Depth() uint8 + // WithBatch attaches batch parameters to the chunk. + WithBatch(radius, depth uint8) Chunk + // Equal checks if the chunk is equal to another. Equal(Chunk) bool } +// Stamp interface for postage.Stamp to avoid circular dependency +type Stamp interface { + BatchID() []byte + Sig() []byte + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler +} + type chunk struct { - addr Address - sdata []byte - tagID uint32 + addr Address + sdata []byte + tagID uint32 + stamp Stamp + radius uint8 + depth uint8 } func NewChunk(addr Address, data []byte) Chunk { @@ -150,6 +177,17 @@ func (c *chunk) WithTagID(t uint32) Chunk { return c } +func (c *chunk) WithStamp(stamp Stamp) Chunk { + c.stamp = stamp + return c +} + +func (c *chunk) WithBatch(radius, depth uint8) Chunk { + c.radius = radius + c.depth = depth + return c +} + func (c *chunk) Address() Address { return c.addr } @@ -162,6 +200,18 @@ func (c *chunk) TagID() uint32 { return c.tagID } +func (c *chunk) Stamp() Stamp { + return c.stamp +} + +func (c *chunk) Radius() uint8 { + return c.radius +} + +func (c *chunk) Depth() uint8 { + return c.depth +} + func (c *chunk) String() string { return fmt.Sprintf("Address: %v Chunksize: %v", c.addr.String(), len(c.sdata)) } diff --git a/pkg/swarm/test/helper.go b/pkg/swarm/test/helper.go index acbb0a622b9..4dd5b377065 100644 --- a/pkg/swarm/test/helper.go +++ b/pkg/swarm/test/helper.go @@ -26,12 +26,12 @@ func RandomAddressAt(self swarm.Address, prox int) swarm.Address { } flipbyte := byte(1 << uint8(7-trans)) transbyteb := transbytea ^ byte(255) - randbyte := byte(rand.Intn(255)) + randbyte := byte(rand.Intn(255)) // skipcq: GSC-G404 addr[pos] = ((addr[pos] & transbytea) ^ flipbyte) | randbyte&transbyteb } for i := pos + 1; i < len(addr); i++ { - addr[i] = byte(rand.Intn(255)) + addr[i] = byte(rand.Intn(255)) // skipcq: GSC-G404 } a := swarm.NewAddress(addr) diff --git a/pkg/topology/kademlia/kademlia.go b/pkg/topology/kademlia/kademlia.go index 6eddb8219aa..83f8360f392 100644 --- a/pkg/topology/kademlia/kademlia.go +++ b/pkg/topology/kademlia/kademlia.go @@ -67,6 +67,7 @@ type Kad struct { knownPeers *pslice.PSlice // both are po aware slice of addresses bootnodes []ma.Multiaddr depth uint8 // current neighborhood depth + radius uint8 // storage area of responsibility depthMu sync.RWMutex // protect depth changes manageC chan struct{} // trigger the manage forever loop to connect to new peers waitNext map[string]retryInfo // sanction connections to a peer, key is overlay string and value is a retry information @@ -87,7 +88,12 @@ type retryInfo struct { } // New returns a new Kademlia. -func New(base swarm.Address, addressbook addressbook.Interface, discovery discovery.Driver, p2p p2p.Service, logger logging.Logger, o Options) *Kad { +func New(base swarm.Address, + addressbook addressbook.Interface, + discovery discovery.Driver, + p2p p2p.Service, + logger logging.Logger, + o Options) *Kad { if o.SaturationFunc == nil { o.SaturationFunc = binSaturated } @@ -340,7 +346,7 @@ func (k *Kad) manage() { k.connectedPeers.Add(peer, po) k.depthMu.Lock() - k.depth = recalcDepth(k.connectedPeers) + k.depth = recalcDepth(k.connectedPeers, k.radius) k.depthMu.Unlock() k.logger.Debugf("connected to peer: %s for bin: %d", peer, i) @@ -421,7 +427,7 @@ func (k *Kad) manage() { k.connectedPeers.Add(peer, po) k.depthMu.Lock() - k.depth = recalcDepth(k.connectedPeers) + k.depth = recalcDepth(k.connectedPeers, k.radius) k.depthMu.Unlock() k.logger.Debugf("connected to peer: %s old depth: %d new depth: %d", peer, currentDepth, k.NeighborhoodDepth()) @@ -518,7 +524,7 @@ func (k *Kad) connectBootnodes(ctx context.Context) { // when a bin is not saturated it means we would like to proactively // initiate connections to other peers in the bin. func binSaturated(bin uint8, peers, connected *pslice.PSlice) (bool, bool) { - potentialDepth := recalcDepth(peers) + potentialDepth := recalcDepth(peers, swarm.MaxPO) // short circuit for bins which are >= depth if bin >= potentialDepth { @@ -544,7 +550,7 @@ func binSaturated(bin uint8, peers, connected *pslice.PSlice) (bool, bool) { } // recalcDepth calculates and returns the kademlia depth. -func recalcDepth(peers *pslice.PSlice) uint8 { +func recalcDepth(peers *pslice.PSlice, radius uint8) uint8 { // handle edge case separately if peers.Length() <= nnLowWatermark { return 0 @@ -590,9 +596,15 @@ func recalcDepth(peers *pslice.PSlice) uint8 { return false, false, nil }) if shallowestUnsaturated > candidate { + if radius < candidate { + return radius + } return candidate } + if radius < shallowestUnsaturated { + return radius + } return shallowestUnsaturated } @@ -761,7 +773,7 @@ func (k *Kad) connected(ctx context.Context, addr swarm.Address) error { k.waitNextMu.Unlock() k.depthMu.Lock() - k.depth = recalcDepth(k.connectedPeers) + k.depth = recalcDepth(k.connectedPeers, k.radius) k.depthMu.Unlock() k.notifyPeerSig() @@ -779,7 +791,7 @@ func (k *Kad) Disconnected(peer p2p.Peer) { k.waitNextMu.Unlock() k.depthMu.Lock() - k.depth = recalcDepth(k.connectedPeers) + k.depth = recalcDepth(k.connectedPeers, k.radius) k.depthMu.Unlock() select { @@ -1030,6 +1042,23 @@ func (k *Kad) IsBalanced(bin uint8) bool { return true } +func (k *Kad) SetRadius(r uint8) { + k.depthMu.Lock() + defer k.depthMu.Unlock() + if k.radius == r { + return + } + k.radius = r + oldD := k.depth + k.depth = recalcDepth(k.connectedPeers, k.radius) + if k.depth != oldD { + select { + case k.manageC <- struct{}{}: + default: + } + } +} + func (k *Kad) Snapshot() *topology.KadParams { var infos []topology.BinInfo for i := int(swarm.MaxPO); i >= 0; i-- { diff --git a/pkg/topology/kademlia/kademlia_test.go b/pkg/topology/kademlia/kademlia_test.go index edaabfceacd..f96eb705d15 100644 --- a/pkg/topology/kademlia/kademlia_test.go +++ b/pkg/topology/kademlia/kademlia_test.go @@ -50,6 +50,8 @@ func TestNeighborhoodDepth(t *testing.T) { base, kad, ab, _, signer = newTestKademlia(&conns, nil, kademlia.Options{}) ) + kad.SetRadius(swarm.MaxPO) // initial tests do not check for radius + if err := kad.Start(context.Background()); err != nil { t.Fatal(err) } @@ -111,8 +113,14 @@ func TestNeighborhoodDepth(t *testing.T) { // depth is 7 because bin 7 is unsaturated (1 peer) kDepth(t, kad, 7) - // expect shallow peers not in depth + // set the radius to be lower than unsaturated, expect radius as depth + kad.SetRadius(6) + kDepth(t, kad, 6) + // set the radius to MaxPO again so that intermediate checks can run + kad.SetRadius(swarm.MaxPO) + + // expect shallow peers not in depth for _, a := range shallowPeers { if kad.IsWithinDepth(a) { t.Fatal("expected address to outside of depth") @@ -141,6 +149,13 @@ func TestNeighborhoodDepth(t *testing.T) { waitConn(t, &conns) kDepth(t, kad, 8) + // again set radius to lower value, expect that as depth + kad.SetRadius(5) + kDepth(t, kad, 5) + + // reset radius to MaxPO for the rest of the checks + kad.SetRadius(swarm.MaxPO) + var addrs []swarm.Address // fill the rest up to the bin before last and check that everything works at the edges for i := 9; i < int(swarm.MaxBins); i++ { @@ -303,6 +318,8 @@ func TestManageWithBalancing(t *testing.T) { base, kad, ab, _, signer = newTestKademlia(&conns, nil, kademlia.Options{SaturationFunc: saturationFunc, BitSuffixLength: 2}) ) + kad.SetRadius(swarm.MaxPO) // don't use radius for checks + // implement satiration function (while having access to Kademlia instance) sfImpl := func(bin uint8, peers, connected *pslice.PSlice) (bool, bool) { return kad.IsBalanced(bin), false @@ -418,6 +435,7 @@ func TestOversaturation(t *testing.T) { conns int32 // how many connect calls were made to the p2p mock base, kad, ab, _, signer = newTestKademlia(&conns, nil, kademlia.Options{}) ) + kad.SetRadius(swarm.MaxPO) // don't use radius for checks if err := kad.Start(context.Background()); err != nil { t.Fatal(err) @@ -473,6 +491,7 @@ func TestOversaturationBootnode(t *testing.T) { conns int32 // how many connect calls were made to the p2p mock base, kad, ab, _, signer = newTestKademlia(&conns, nil, kademlia.Options{BootnodeMode: true}) ) + kad.SetRadius(swarm.MaxPO) // don't use radius for checks if err := kad.Start(context.Background()); err != nil { t.Fatal(err) diff --git a/pkg/tracing/tracing.go b/pkg/tracing/tracing.go index 3755cc9df92..97171616b58 100644 --- a/pkg/tracing/tracing.go +++ b/pkg/tracing/tracing.go @@ -186,11 +186,7 @@ func (t *Tracer) AddContextHTTPHeader(ctx context.Context, headers http.Header) } carrier := opentracing.HTTPHeadersCarrier(headers) - if err := t.tracer.Inject(c, opentracing.HTTPHeaders, carrier); err != nil { - return err - } - - return nil + return t.tracer.Inject(c, opentracing.HTTPHeaders, carrier) } // FromHTTPHeaders returns tracing span context from HTTP headers. If the tracing