diff --git a/pkg/postage/batch.go b/pkg/postage/batch.go index 2f5e0ecda3c..073df452dd4 100644 --- a/pkg/postage/batch.go +++ b/pkg/postage/batch.go @@ -11,23 +11,26 @@ import ( // Batch represents a postage batch, a payment on the blockchain. type Batch struct { - ID []byte // batch ID - Value *big.Int // overall balance of the batch - Start uint64 // block number the batch was created - Owner []byte // owner's ethereum address - Depth uint8 // batch depth, i.e., size = 2^{depth} + ID []byte // batch ID + Value *big.Int // overall balance of the batch + Start uint64 // block number the batch was created + Owner []byte // owner's ethereum address + Depth uint8 // batch depth, i.e., size = 2^{depth} + NormalisedBalance *big.Int // normalised balance of the batch } // MarshalBinary implements BinaryMarshaller. It will attempt to serialize the // postage batch to a byte slice. func (b *Batch) MarshalBinary() ([]byte, error) { - out := make([]byte, 93) + out := make([]byte, 157) copy(out, b.ID) value := b.Value.Bytes() copy(out[64-len(value):], value) binary.BigEndian.PutUint64(out[64:72], b.Start) copy(out[72:], b.Owner) out[92] = b.Depth + normalisedBalance := b.NormalisedBalance.Bytes() + copy(out[157-len(normalisedBalance):], normalisedBalance) return out, nil } @@ -39,5 +42,6 @@ func (b *Batch) UnmarshalBinary(buf []byte) error { b.Start = binary.BigEndian.Uint64(buf[64:72]) b.Owner = buf[72:92] b.Depth = buf[92] + b.NormalisedBalance = big.NewInt(0).SetBytes(buf[93:]) return nil } diff --git a/pkg/postage/batch_test.go b/pkg/postage/batch_test.go index d8c06ff9629..617041a14dd 100644 --- a/pkg/postage/batch_test.go +++ b/pkg/postage/batch_test.go @@ -21,7 +21,7 @@ func TestBatchMarshalling(t *testing.T) { if err != nil { t.Fatal(err) } - if len(buf) != 93 { + if len(buf) != 157 { t.Fatalf("invalid length for serialised batch. expected 93, got %d", len(buf)) } b := &postage.Batch{} @@ -43,4 +43,7 @@ func TestBatchMarshalling(t *testing.T) { if a.Depth != b.Depth { t.Fatalf("depth mismatch, expected %d, got %d", a.Depth, b.Depth) } + if a.NormalisedBalance.Uint64() != b.NormalisedBalance.Uint64() { + t.Fatalf("normalised balance mismatch, expected %d, got %d", a.NormalisedBalance.Uint64(), b.NormalisedBalance.Uint64()) + } } diff --git a/pkg/postage/batchservice/batchservice.go b/pkg/postage/batchservice/batchservice.go index 28c19821ae9..ad33bccf37b 100644 --- a/pkg/postage/batchservice/batchservice.go +++ b/pkg/postage/batchservice/batchservice.go @@ -37,13 +37,14 @@ func New(storer postage.Storer, logger logging.Logger) (postage.EventUpdater, er // Create will create a new batch with the given ID, owner value and depth and // stores it in the BatchStore. -func (svc *batchService) Create(id, owner []byte, value *big.Int, depth uint8) error { +func (svc *batchService) Create(id, owner []byte, value *big.Int, normalisedBalance *big.Int, depth uint8) error { b := &postage.Batch{ - ID: id, - Owner: owner, - Value: value, - Start: svc.cs.Block, - Depth: depth, + ID: id, + Owner: owner, + Value: value, + Start: svc.cs.Block, + Depth: depth, + NormalisedBalance: normalisedBalance, } err := svc.storer.Put(b) @@ -57,13 +58,14 @@ func (svc *batchService) Create(id, owner []byte, value *big.Int, depth uint8) e // TopUp implements the EventUpdater interface. It tops ups a batch with the // given ID with the given amount. -func (svc *batchService) TopUp(id []byte, amount *big.Int) error { +func (svc *batchService) TopUp(id []byte, amount *big.Int, normalisedBalance *big.Int) error { b, err := svc.storer.Get(id) if err != nil { return fmt.Errorf("get: %w", err) } b.Value.Add(b.Value, amount) + b.NormalisedBalance.Set(normalisedBalance) err = svc.storer.Put(b) if err != nil { @@ -76,13 +78,14 @@ func (svc *batchService) TopUp(id []byte, amount *big.Int) error { // UpdateDepth implements the EventUpdater inteface. It sets the new depth of a // batch with the given ID. -func (svc *batchService) UpdateDepth(id []byte, depth uint8) error { +func (svc *batchService) UpdateDepth(id []byte, depth uint8, normalisedBalance *big.Int) error { b, err := svc.storer.Get(id) if err != nil { return fmt.Errorf("get: %w", err) } b.Depth = depth + b.NormalisedBalance.Set(normalisedBalance) err = svc.storer.Put(b) if err != nil { diff --git a/pkg/postage/batchservice/batchservice_test.go b/pkg/postage/batchservice/batchservice_test.go index 4390fd9d9f2..30d98820da6 100644 --- a/pkg/postage/batchservice/batchservice_test.go +++ b/pkg/postage/batchservice/batchservice_test.go @@ -61,6 +61,7 @@ func TestBatchServiceCreate(t *testing.T) { testBatch.ID, testBatch.Owner, testBatch.Value, + testBatch.NormalisedBalance, testBatch.Depth, ); err == nil { t.Fatalf("expected error") @@ -77,6 +78,7 @@ func TestBatchServiceCreate(t *testing.T) { testBatch.ID, testBatch.Owner, testBatch.Value, + testBatch.NormalisedBalance, testBatch.Depth, ); err != nil { t.Fatalf("got error %v", err) @@ -102,6 +104,9 @@ func TestBatchServiceCreate(t *testing.T) { if got.Start != testChainState.Block { t.Fatalf("batch start block different form chain state: want %v, got %v", got.Start, testChainState.Block) } + if got.NormalisedBalance.Cmp(testBatch.NormalisedBalance) != 0 { + t.Fatalf("normalised batch value: want %v, got %v", testBatch.NormalisedBalance.String(), got.NormalisedBalance.String()) + } }) @@ -110,6 +115,7 @@ func TestBatchServiceCreate(t *testing.T) { func TestBatchServiceTopUp(t *testing.T) { testBatch := postagetesting.MustNewBatch() testTopUpAmount := big.NewInt(10000000000000) + testNormalisedBalance := big.NewInt(2000000000000) t.Run("expect get error", func(t *testing.T) { svc, _ := newTestStoreAndService( @@ -118,7 +124,7 @@ func TestBatchServiceTopUp(t *testing.T) { mock.WithGetErr(testErr, 1), ) - if err := svc.TopUp(testBatch.ID, testTopUpAmount); err == nil { + if err := svc.TopUp(testBatch.ID, testTopUpAmount, testNormalisedBalance); err == nil { t.Fatal("expected error") } }) @@ -130,7 +136,7 @@ func TestBatchServiceTopUp(t *testing.T) { ) putBatch(t, batchStore, testBatch) - if err := svc.TopUp(testBatch.ID, testTopUpAmount); err == nil { + if err := svc.TopUp(testBatch.ID, testTopUpAmount, testNormalisedBalance); err == nil { t.Fatal("expected error") } }) @@ -142,7 +148,7 @@ func TestBatchServiceTopUp(t *testing.T) { want := testBatch.Value want.Add(want, testTopUpAmount) - if err := svc.TopUp(testBatch.ID, testTopUpAmount); err != nil { + if err := svc.TopUp(testBatch.ID, testTopUpAmount, testNormalisedBalance); err != nil { t.Fatalf("top up: %v", err) } @@ -155,11 +161,15 @@ func TestBatchServiceTopUp(t *testing.T) { t.Fatalf("topped up amount: got %v, want %v", got.Value, want) } + if got.NormalisedBalance.Cmp(testNormalisedBalance) != 0 { + t.Fatalf("normalised batch value: want %v, got %v", testNormalisedBalance.String(), got.NormalisedBalance.String()) + } }) } func TestBatchServiceUpdateDepth(t *testing.T) { const testNewDepth = 30 + testNormalisedBalance := big.NewInt(2000000000000) testBatch := postagetesting.MustNewBatch() t.Run("expect get error", func(t *testing.T) { @@ -168,7 +178,7 @@ func TestBatchServiceUpdateDepth(t *testing.T) { mock.WithGetErr(testErr, 1), ) - if err := svc.UpdateDepth(testBatch.ID, testNewDepth); err == nil { + if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance); err == nil { t.Fatal("expected get error") } }) @@ -180,7 +190,7 @@ func TestBatchServiceUpdateDepth(t *testing.T) { ) putBatch(t, batchStore, testBatch) - if err := svc.UpdateDepth(testBatch.ID, testNewDepth); err == nil { + if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance); err == nil { t.Fatal("expected put error") } }) @@ -189,7 +199,7 @@ func TestBatchServiceUpdateDepth(t *testing.T) { svc, batchStore := newTestStoreAndService(t) putBatch(t, batchStore, testBatch) - if err := svc.UpdateDepth(testBatch.ID, testNewDepth); err != nil { + if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance); err != nil { t.Fatalf("update depth: %v", err) } @@ -201,6 +211,10 @@ func TestBatchServiceUpdateDepth(t *testing.T) { if val.Depth != testNewDepth { t.Fatalf("wrong batch depth set: want %v, got %v", testNewDepth, val.Depth) } + + if val.NormalisedBalance.Cmp(testNormalisedBalance) != 0 { + t.Fatalf("normalised batch value: want %v, got %v", testNormalisedBalance.String(), val.NormalisedBalance.String()) + } }) } diff --git a/pkg/postage/interface.go b/pkg/postage/interface.go index 6da023c71ad..79f313565e5 100644 --- a/pkg/postage/interface.go +++ b/pkg/postage/interface.go @@ -11,9 +11,9 @@ import ( // EventUpdater interface definitions reflect the updates triggered by events // emitted by the postage contract on the blockchain. type EventUpdater interface { - Create(id []byte, owner []byte, amount *big.Int, depth uint8) error - TopUp(id []byte, amount *big.Int) error - UpdateDepth(id []byte, depth uint8) error + Create(id []byte, owner []byte, amount *big.Int, normalisedBalance *big.Int, depth uint8) error + TopUp(id []byte, amount *big.Int, normalisedBalance *big.Int) error + UpdateDepth(id []byte, depth uint8, normalisedBalance *big.Int) error UpdatePrice(price *big.Int) error } diff --git a/pkg/postage/testing/batch.go b/pkg/postage/testing/batch.go index 51bf5b07566..1676151a221 100644 --- a/pkg/postage/testing/batch.go +++ b/pkg/postage/testing/batch.go @@ -50,10 +50,11 @@ func NewBigInt() *big.Int { // be filled with random data. Panics on errors. func MustNewBatch(opts ...BatchOption) *postage.Batch { b := &postage.Batch{ - ID: MustNewID(), - Value: NewBigInt(), - Start: rand.Uint64(), - Depth: defaultDepth, + ID: MustNewID(), + Value: NewBigInt(), + Start: rand.Uint64(), + Depth: defaultDepth, + NormalisedBalance: NewBigInt(), } for _, opt := range opts { @@ -95,4 +96,7 @@ func CompareBatches(t *testing.T, want, got *postage.Batch) { if want.Depth != got.Depth { t.Fatalf("depth: want %v, got %v", want.Depth, got.Depth) } + if want.NormalisedBalance.Cmp(got.NormalisedBalance) != 0 { + t.Fatalf("normalised balance: want %v, got %v", want.NormalisedBalance, got.NormalisedBalance) + } }