diff --git a/execution/executor/sortition.go b/execution/executor/sortition.go index 67807acf4..fc294bb23 100644 --- a/execution/executor/sortition.go +++ b/execution/executor/sortition.go @@ -77,10 +77,10 @@ func (e *SortitionExecutor) joinCommittee(sb sandbox.Sandbox, joiningPower := int64(0) committee := sb.Committee() currentHeight := sb.CurrentHeight() - sb.IterateValidators(func(vs *sandbox.ValidatorStatus) { - if vs.Validator.LastJoinedHeight() == currentHeight { - if !committee.Contains(vs.Validator.Address()) { - joiningPower += vs.Validator.Power() + sb.IterateValidators(func(val *validator.Validator, updated bool) { + if val.LastJoinedHeight() == currentHeight { + if !committee.Contains(val.Address()) { + joiningPower += val.Power() joiningNum++ } } diff --git a/execution/executor/sortition_test.go b/execution/executor/sortition_test.go index 703bea4bb..544c54566 100644 --- a/execution/executor/sortition_test.go +++ b/execution/executor/sortition_test.go @@ -5,7 +5,6 @@ import ( "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/crypto/bls" - "github.com/pactus-project/pactus/sandbox" "github.com/pactus-project/pactus/sortition" "github.com/pactus-project/pactus/types/tx" "github.com/pactus-project/pactus/types/validator" @@ -211,9 +210,9 @@ func TestOldestDidNotPropose(t *testing.T) { assert.NoError(t, exe.Execute(trx2, tSandbox)) joined := make([]*validator.Validator, 0) - tSandbox.IterateValidators(func(vs *sandbox.ValidatorStatus) { - if vs.Validator.LastJoinedHeight() == tSandbox.CurrentHeight() { - joined = append(joined, &vs.Validator) + tSandbox.IterateValidators(func(val *validator.Validator, updated bool) { + if val.LastJoinedHeight() == tSandbox.CurrentHeight() { + joined = append(joined, val) } }) tSandbox.TestCommittee.Update(0, joined) diff --git a/sandbox/interface.go b/sandbox/interface.go index 95499b478..57617415d 100644 --- a/sandbox/interface.go +++ b/sandbox/interface.go @@ -28,6 +28,6 @@ type Sandbox interface { Params() param.Params CurrentHeight() uint32 - IterateAccounts(consumer func(crypto.Address, *AccountStatus)) - IterateValidators(consumer func(*ValidatorStatus)) + IterateAccounts(consumer func(addr crypto.Address, acc *account.Account, updated bool)) + IterateValidators(consumer func(val *validator.Validator, updated bool)) } diff --git a/sandbox/mock.go b/sandbox/mock.go index e0e0e09e3..861c597aa 100644 --- a/sandbox/mock.go +++ b/sandbox/mock.go @@ -81,21 +81,15 @@ func (m *MockSandbox) Params() param.Params { func (m *MockSandbox) RecentBlockByStamp(stamp hash.Stamp) (uint32, *block.Block) { return m.TestStore.RecentBlockByStamp(stamp) } -func (m *MockSandbox) IterateAccounts(consumer func(crypto.Address, *AccountStatus)) { +func (m *MockSandbox) IterateAccounts(consumer func(crypto.Address, *account.Account, bool)) { m.TestStore.IterateAccounts(func(addr crypto.Address, acc *account.Account) bool { - consumer(addr, &AccountStatus{ - Account: *acc, - Updated: true, - }) + consumer(addr, acc, true) return false }) } -func (m *MockSandbox) IterateValidators(consumer func(*ValidatorStatus)) { +func (m *MockSandbox) IterateValidators(consumer func(*validator.Validator, bool)) { m.TestStore.IterateValidators(func(val *validator.Validator) bool { - consumer(&ValidatorStatus{ - Validator: *val, - Updated: true, - }) + consumer(val, true) return false }) } diff --git a/sandbox/sandbox.go b/sandbox/sandbox.go index e391052cf..d589af933 100644 --- a/sandbox/sandbox.go +++ b/sandbox/sandbox.go @@ -23,21 +23,21 @@ type sandbox struct { store store.Reader committee committee.Reader - accounts map[crypto.Address]*AccountStatus - validators map[crypto.Address]*ValidatorStatus + accounts map[crypto.Address]*sandboxAccount + validators map[crypto.Address]*sandboxValidator params param.Params totalAccounts int32 totalValidators int32 } -type ValidatorStatus struct { - Validator validator.Validator - Updated bool +type sandboxValidator struct { + validator *validator.Validator + updated bool } -type AccountStatus struct { - Account account.Account - Updated bool +type sandboxAccount struct { + account *account.Account + updated bool } func NewSandbox(store store.Reader, params param.Params, committee committee.Reader) Sandbox { @@ -47,8 +47,8 @@ func NewSandbox(store store.Reader, params param.Params, committee committee.Rea params: params, } - sb.accounts = make(map[crypto.Address]*AccountStatus) - sb.validators = make(map[crypto.Address]*ValidatorStatus) + sb.accounts = make(map[crypto.Address]*sandboxAccount) + sb.validators = make(map[crypto.Address]*sandboxValidator) sb.totalAccounts = sb.store.TotalAccounts() sb.totalValidators = sb.store.TotalValidators() @@ -80,21 +80,20 @@ func (sb *sandbox) Account(addr crypto.Address) *account.Account { s, ok := sb.accounts[addr] if ok { - clone := new(account.Account) - *clone = s.Account - return clone + return s.account.Clone() } acc, err := sb.store.Account(addr) if err != nil { return nil } - sb.accounts[addr] = &AccountStatus{ - Account: *acc, + sb.accounts[addr] = &sandboxAccount{ + account: acc, } - return acc + return acc.Clone() } + func (sb *sandbox) MakeNewAccount(addr crypto.Address) *account.Account { sb.lk.Lock() defer sb.lk.Unlock() @@ -104,14 +103,17 @@ func (sb *sandbox) MakeNewAccount(addr crypto.Address) *account.Account { } acc := account.NewAccount(sb.totalAccounts) - sb.accounts[addr] = &AccountStatus{ - Account: *acc, - Updated: true, + sb.accounts[addr] = &sandboxAccount{ + account: acc, + updated: true, } sb.totalAccounts++ - return acc + return acc.Clone() } +// This function takes ownership of the account pointer. +// It is important that the caller should not modify the account data and +// keep it immutable. func (sb *sandbox) UpdateAccount(addr crypto.Address, acc *account.Account) { sb.lk.Lock() defer sb.lk.Unlock() @@ -120,8 +122,8 @@ func (sb *sandbox) UpdateAccount(addr crypto.Address, acc *account.Account) { if !ok { sb.shouldPanicForUnknownAddress() } - s.Account = *acc - s.Updated = true + s.account = acc + s.updated = true } func (sb *sandbox) Validator(addr crypto.Address) *validator.Validator { @@ -130,19 +132,17 @@ func (sb *sandbox) Validator(addr crypto.Address) *validator.Validator { s, ok := sb.validators[addr] if ok { - clone := new(validator.Validator) - *clone = s.Validator - return clone + return s.validator.Clone() } val, err := sb.store.Validator(addr) if err != nil { return nil } - sb.validators[addr] = &ValidatorStatus{ - Validator: *val, + sb.validators[addr] = &sandboxValidator{ + validator: val, } - return val + return val.Clone() } func (sb *sandbox) MakeNewValidator(pub *bls.PublicKey) *validator.Validator { @@ -155,14 +155,17 @@ func (sb *sandbox) MakeNewValidator(pub *bls.PublicKey) *validator.Validator { } val := validator.NewValidator(pub, sb.totalValidators) - sb.validators[addr] = &ValidatorStatus{ - Validator: *val, - Updated: true, + sb.validators[addr] = &sandboxValidator{ + validator: val, + updated: true, } sb.totalValidators++ - return val + return val.Clone() } +// This function takes ownership of the validator pointer. +// It is important that the caller should not modify the validator data and +// keep it immutable. func (sb *sandbox) UpdateValidator(val *validator.Validator) { sb.lk.Lock() defer sb.lk.Unlock() @@ -173,8 +176,8 @@ func (sb *sandbox) UpdateValidator(val *validator.Validator) { sb.shouldPanicForUnknownAddress() } - s.Validator = *val - s.Updated = true + s.validator = val + s.updated = true } func (sb *sandbox) Params() param.Params { @@ -194,21 +197,21 @@ func (sb *sandbox) currentHeight() uint32 { return h + 1 } -func (sb *sandbox) IterateAccounts(consumer func(crypto.Address, *AccountStatus)) { +func (sb *sandbox) IterateAccounts(consumer func(crypto.Address, *account.Account, bool)) { sb.lk.RLock() defer sb.lk.RUnlock() - for addr, as := range sb.accounts { - consumer(addr, as) + for addr, sa := range sb.accounts { + consumer(addr, sa.account, sa.updated) } } -func (sb *sandbox) IterateValidators(consumer func(*ValidatorStatus)) { +func (sb *sandbox) IterateValidators(consumer func(*validator.Validator, bool)) { sb.lk.RLock() defer sb.lk.RUnlock() - for _, vs := range sb.validators { - consumer(vs) + for _, sv := range sb.validators { + consumer(sv.validator, sv.updated) } } diff --git a/sandbox/sandbox_test.go b/sandbox/sandbox_test.go index 6adbe1556..3f72ac1c1 100644 --- a/sandbox/sandbox_test.go +++ b/sandbox/sandbox_test.go @@ -60,13 +60,13 @@ func TestAccountChange(t *testing.T) { invAddr := crypto.GenerateTestAddress() assert.Nil(t, tSandbox.Account(invAddr)) - tSandbox.IterateAccounts(func(_ crypto.Address, _ *AccountStatus) { + tSandbox.IterateAccounts(func(_ crypto.Address, _ *account.Account, _ bool) { panic("should be empty") }) }) t.Run("Retrieve an account from store and update it", func(t *testing.T) { - acc, signer := account.GenerateTestAccount(888) + acc, signer := account.GenerateTestAccount(util.RandInt32(0)) addr := signer.Address() bal := acc.Balance() seq := acc.Sequence() @@ -78,11 +78,11 @@ func TestAccountChange(t *testing.T) { sbAcc1.IncSequence() sbAcc1.AddToBalance(1) - assert.False(t, tSandbox.accounts[addr].Updated) + assert.False(t, tSandbox.accounts[addr].updated) assert.Equal(t, tSandbox.Account(addr).Balance(), bal) assert.Equal(t, tSandbox.Account(addr).Sequence(), seq) tSandbox.UpdateAccount(addr, sbAcc1) - assert.True(t, tSandbox.accounts[addr].Updated) + assert.True(t, tSandbox.accounts[addr].updated) assert.Equal(t, tSandbox.Account(addr).Balance(), bal+1) assert.Equal(t, tSandbox.Account(addr).Sequence(), seq+1) @@ -91,20 +91,20 @@ func TestAccountChange(t *testing.T) { sbAcc2.IncSequence() sbAcc2.AddToBalance(1) - assert.True(t, tSandbox.accounts[addr].Updated, "it is updated before") + assert.True(t, tSandbox.accounts[addr].updated, "it is updated before") assert.Equal(t, tSandbox.Account(addr).Balance(), bal+1) assert.Equal(t, tSandbox.Account(addr).Sequence(), seq+1) tSandbox.UpdateAccount(addr, sbAcc2) - assert.True(t, tSandbox.accounts[addr].Updated) + assert.True(t, tSandbox.accounts[addr].updated) assert.Equal(t, tSandbox.Account(addr).Balance(), bal+2) assert.Equal(t, tSandbox.Account(addr).Sequence(), seq+2) }) t.Run("Should be iterated", func(t *testing.T) { - tSandbox.IterateAccounts(func(a crypto.Address, as *AccountStatus) { + tSandbox.IterateAccounts(func(a crypto.Address, acc *account.Account, updated bool) { assert.Equal(t, addr, a) - assert.True(t, as.Updated) - assert.Equal(t, as.Account.Balance(), bal+2) + assert.True(t, updated) + assert.Equal(t, acc.Balance(), bal+2) }) }) }) @@ -121,10 +121,10 @@ func TestAccountChange(t *testing.T) { assert.Equal(t, acc, sbAcc) t.Run("Should be iterated", func(t *testing.T) { - tSandbox.IterateAccounts(func(a crypto.Address, as *AccountStatus) { + tSandbox.IterateAccounts(func(a crypto.Address, acc *account.Account, updated bool) { if a == addr { - assert.True(t, as.Updated) - assert.Equal(t, as.Account.Balance(), int64(1)) + assert.True(t, updated) + assert.Equal(t, acc.Balance(), int64(1)) } }) }) @@ -138,13 +138,13 @@ func TestValidatorChange(t *testing.T) { invAddr := crypto.GenerateTestAddress() assert.Nil(t, tSandbox.Validator(invAddr)) - tSandbox.IterateValidators(func(_ *ValidatorStatus) { + tSandbox.IterateValidators(func(_ *validator.Validator, _ bool) { panic("should be empty") }) }) t.Run("Retrieve an validator from store and update it", func(t *testing.T) { - val, _ := validator.GenerateTestValidator(888) + val, _ := validator.GenerateTestValidator(util.RandInt32(0)) addr := val.Address() stk := val.Stake() seq := val.Sequence() @@ -156,11 +156,11 @@ func TestValidatorChange(t *testing.T) { sbVal1.IncSequence() sbVal1.AddToStake(1) - assert.False(t, tSandbox.validators[addr].Updated) + assert.False(t, tSandbox.validators[addr].updated) assert.Equal(t, tSandbox.Validator(addr).Stake(), stk) assert.Equal(t, tSandbox.Validator(addr).Sequence(), seq) tSandbox.UpdateValidator(sbVal1) - assert.True(t, tSandbox.validators[sbVal1.Address()].Updated) + assert.True(t, tSandbox.validators[sbVal1.Address()].updated) assert.Equal(t, tSandbox.Validator(addr).Stake(), stk+1) assert.Equal(t, tSandbox.Validator(addr).Sequence(), seq+1) @@ -169,19 +169,19 @@ func TestValidatorChange(t *testing.T) { sbVal2.IncSequence() sbVal2.AddToStake(1) - assert.True(t, tSandbox.validators[addr].Updated, "it is updated before") + assert.True(t, tSandbox.validators[addr].updated, "it is updated before") assert.Equal(t, tSandbox.Validator(addr).Stake(), stk+1) assert.Equal(t, tSandbox.Validator(addr).Sequence(), seq+1) tSandbox.UpdateValidator(sbVal2) - assert.True(t, tSandbox.validators[sbVal1.Address()].Updated) + assert.True(t, tSandbox.validators[sbVal1.Address()].updated) assert.Equal(t, tSandbox.Validator(addr).Stake(), stk+2) assert.Equal(t, tSandbox.Validator(addr).Sequence(), seq+2) }) t.Run("Should be iterated", func(t *testing.T) { - tSandbox.IterateValidators(func(vs *ValidatorStatus) { - assert.True(t, vs.Updated) - assert.Equal(t, vs.Validator.Stake(), stk+2) + tSandbox.IterateValidators(func(val *validator.Validator, updated bool) { + assert.True(t, updated) + assert.Equal(t, val.Stake(), stk+2) }) }) }) @@ -198,10 +198,10 @@ func TestValidatorChange(t *testing.T) { assert.Equal(t, val, sbVal) t.Run("Should be iterated", func(t *testing.T) { - tSandbox.IterateValidators(func(vs *ValidatorStatus) { - if vs.Validator.PublicKey() == pub { - assert.True(t, vs.Updated) - assert.Equal(t, vs.Validator.Stake(), int64(1)) + tSandbox.IterateValidators(func(val *validator.Validator, updated bool) { + if val.PublicKey() == pub { + assert.True(t, updated) + assert.Equal(t, val.Stake(), int64(1)) } }) }) @@ -278,7 +278,7 @@ func TestUpdateFromOutsideTheSandbox(t *testing.T) { t.Errorf("The code did not panic") } }() - acc, signer := account.GenerateTestAccount(999) + acc, signer := account.GenerateTestAccount(util.RandInt32(0)) tSandbox.UpdateAccount(signer.Address(), acc) }) @@ -288,42 +288,67 @@ func TestUpdateFromOutsideTheSandbox(t *testing.T) { t.Errorf("The code did not panic") } }() - val, _ := validator.GenerateTestValidator(999) + val, _ := validator.GenerateTestValidator(util.RandInt32(0)) tSandbox.UpdateValidator(val) }) } -func TestDeepCopy(t *testing.T) { +func TestAccountDeepCopy(t *testing.T) { setup(t) - addr := crypto.GenerateTestAddress() - pub, _ := bls.GenerateTestKeyPair() - acc1 := tSandbox.MakeNewAccount(addr) - val1 := tSandbox.MakeNewValidator(pub) + t.Run("non existing account", func(t *testing.T) { + addr := crypto.GenerateTestAddress() + acc := tSandbox.MakeNewAccount(addr) + acc.IncSequence() - acc2 := tSandbox.Account(addr) - val2 := tSandbox.Validator(pub.Address()) + assert.NotEqual(t, tSandbox.Account(addr), acc) + }) - assert.Equal(t, acc1, acc2) - assert.Equal(t, val1.Hash(), val2.Hash()) + t.Run("existing account", func(t *testing.T) { + addr := crypto.TreasuryAddress + acc := tSandbox.Account(addr) + acc.IncSequence() - acc1.IncSequence() - val1.IncSequence() + assert.NotEqual(t, tSandbox.Account(addr), acc) + }) - acc2.AddToBalance(1) - val2.AddToStake(1) + t.Run("sandbox account", func(t *testing.T) { + addr := crypto.TreasuryAddress + acc := tSandbox.Account(addr) + acc.IncSequence() - assert.NotEqual(t, acc1.Hash(), acc2.Hash()) - assert.NotEqual(t, val1.Hash(), val2.Hash()) + assert.NotEqual(t, tSandbox.Account(addr), acc) + assert.NotEqual(t, acc.Sequence(), 1) + }) +} - acc3 := tSandbox.accounts[addr] - val3 := tSandbox.validators[pub.Address()] +func TestValidatorDeepCopy(t *testing.T) { + setup(t) - assert.NotEqual(t, acc1.Hash(), acc3.Account.Hash()) - assert.NotEqual(t, val1.Hash(), val3.Validator.Hash()) + t.Run("non existing validator", func(t *testing.T) { + pub, _ := bls.GenerateTestKeyPair() + acc := tSandbox.MakeNewValidator(pub) + acc.IncSequence() + + assert.NotEqual(t, tSandbox.Validator(pub.Address()), acc) + }) + + val0, _ := tStore.ValidatorByNumber(0) + addr := val0.Address() + t.Run("existing validator", func(t *testing.T) { + acc := tSandbox.Validator(addr) + acc.IncSequence() - assert.NotEqual(t, acc2.Hash(), acc3.Account.Hash()) - assert.NotEqual(t, val2.Hash(), val3.Validator.Hash()) + assert.NotEqual(t, tSandbox.Validator(addr), acc) + }) + + t.Run("sandbox validator", func(t *testing.T) { + acc := tSandbox.Validator(addr) + acc.IncSequence() + + assert.NotEqual(t, tSandbox.Validator(addr), acc) + assert.NotEqual(t, acc.Sequence(), 1) + }) } func TestRecentBlockByStamp(t *testing.T) { diff --git a/state/state.go b/state/state.go index 46cb63aa7..01818e9ca 100644 --- a/state/state.go +++ b/state/state.go @@ -495,26 +495,26 @@ func (st *state) Fingerprint() string { func (st *state) commitSandbox(sb sandbox.Sandbox, round int16) { joined := make([]*validator.Validator, 0) currentHeight := sb.CurrentHeight() - sb.IterateValidators(func(vs *sandbox.ValidatorStatus) { - if vs.Validator.LastJoinedHeight() == currentHeight { - st.logger.Info("new validator joined", "address", vs.Validator.Address(), "power", vs.Validator.Power()) + sb.IterateValidators(func(val *validator.Validator, updated bool) { + if val.LastJoinedHeight() == currentHeight { + st.logger.Info("new validator joined", "address", val.Address(), "power", val.Power()) - joined = append(joined, &vs.Validator) + joined = append(joined, val) } }) st.committee.Update(round, joined) - sb.IterateAccounts(func(addr crypto.Address, as *sandbox.AccountStatus) { - if as.Updated { - st.store.UpdateAccount(addr, &as.Account) - st.accountMerkle.SetHash(int(as.Account.Number()), as.Account.Hash()) + sb.IterateAccounts(func(addr crypto.Address, acc *account.Account, updated bool) { + if updated { + st.store.UpdateAccount(addr, acc) + st.accountMerkle.SetHash(int(acc.Number()), acc.Hash()) } }) - sb.IterateValidators(func(vs *sandbox.ValidatorStatus) { - if vs.Updated { - st.store.UpdateValidator(&vs.Validator) - st.validatorMerkle.SetHash(int(vs.Validator.Number()), vs.Validator.Hash()) + sb.IterateValidators(func(val *validator.Validator, updated bool) { + if updated { + st.store.UpdateValidator(val) + st.validatorMerkle.SetHash(int(val.Number()), val.Hash()) } }) } diff --git a/store/account.go b/store/account.go index 4e9ea85f4..664970d35 100644 --- a/store/account.go +++ b/store/account.go @@ -1,8 +1,6 @@ package store import ( - "fmt" - "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/types/account" "github.com/pactus-project/pactus/util/logger" @@ -11,83 +9,81 @@ import ( ) type accountStore struct { - db *leveldb.DB - numberMap map[int32]*account.Account - total int32 + db *leveldb.DB + numberMap map[int32]*account.Account + addressMap map[crypto.Address]*account.Account + total int32 } func accountKey(addr crypto.Address) []byte { return append(accountPrefix, addr.Bytes()...) } func newAccountStore(db *leveldb.DB) *accountStore { - as := &accountStore{ - db: db, - numberMap: make(map[int32]*account.Account), - } total := int32(0) - as.iterateAccounts(func(_ crypto.Address, acc *account.Account) bool { - as.numberMap[acc.Number()] = acc + numberMap := make(map[int32]*account.Account) + addressMap := make(map[crypto.Address]*account.Account) + r := util.BytesPrefix(accountPrefix) + iter := db.NewIterator(r, nil) + for iter.Next() { + key := iter.Key() + value := iter.Value() + + acc, err := account.FromBytes(value) + if err != nil { + logger.Panic("unable to decode account", "err", err) + } + + var addr crypto.Address + copy(addr[:], key[1:]) + + numberMap[acc.Number()] = acc + addressMap[addr] = acc total++ - return false - }) - as.total = total + } + iter.Release() - return as + return &accountStore{ + db: db, + total: total, + numberMap: numberMap, + addressMap: addressMap, + } } func (as *accountStore) hasAccount(addr crypto.Address) bool { - has, err := as.db.Has(accountKey(addr), nil) - if err != nil { - return false - } - return has + _, ok := as.addressMap[addr] + return ok } func (as *accountStore) account(addr crypto.Address) (*account.Account, error) { - bs, err := tryGet(as.db, accountKey(addr)) - if err != nil { - return nil, err - } - - acc, err := account.FromBytes(bs) - if err != nil { - return nil, err + acc, ok := as.addressMap[addr] + if ok { + return acc.Clone(), nil } - return acc, nil + return nil, ErrNotFound } func (as *accountStore) accountByNumber(number int32) (*account.Account, error) { - val, ok := as.numberMap[number] + acc, ok := as.numberMap[number] if ok { - return val, nil + return acc.Clone(), nil } - return nil, fmt.Errorf("account not found") + return nil, ErrNotFound } func (as *accountStore) iterateAccounts(consumer func(crypto.Address, *account.Account) (stop bool)) { - r := util.BytesPrefix(accountPrefix) - iter := as.db.NewIterator(r, nil) - for iter.Next() { - key := iter.Key() - value := iter.Value() - - var addr crypto.Address - copy(addr[:], key[1:]) - - acc, err := account.FromBytes(value) - if err != nil { - logger.Panic("unable to decode account", "err", err) - } - - stopped := consumer(addr, acc) + for addr, acc := range as.addressMap { + stopped := consumer(addr, acc.Clone()) if stopped { return } } - iter.Release() } +// This function takes ownership of the account pointer. +// It is important that the caller should not modify the account data and +// keep it immutable. func (as *accountStore) updateAccount(batch *leveldb.Batch, addr crypto.Address, acc *account.Account) { data, err := acc.Bytes() if err != nil { @@ -97,6 +93,7 @@ func (as *accountStore) updateAccount(batch *leveldb.Batch, addr crypto.Address, as.total++ } as.numberMap[acc.Number()] = acc + as.addressMap[addr] = acc batch.Put(accountKey(addr), data) } diff --git a/store/account_test.go b/store/account_test.go index 797b79ff3..10963cab0 100644 --- a/store/account_test.go +++ b/store/account_test.go @@ -3,6 +3,8 @@ package store import ( "testing" + "github.com/pactus-project/pactus/crypto" + "github.com/pactus-project/pactus/crypto/hash" "github.com/pactus-project/pactus/types/account" "github.com/pactus-project/pactus/util" "github.com/stretchr/testify/assert" @@ -12,9 +14,10 @@ import ( func TestAccountCounter(t *testing.T) { setup(t) - acc, signer := account.GenerateTestAccount(util.RandInt32(1000)) + num := util.RandInt32(1000) + acc, signer := account.GenerateTestAccount(num) - t.Run("Update count after adding new account", func(t *testing.T) { + t.Run("Add new account, should increase the total accounts number", func(t *testing.T) { assert.Zero(t, tStore.TotalAccounts()) tStore.UpdateAccount(signer.Address(), acc) @@ -22,70 +25,177 @@ func TestAccountCounter(t *testing.T) { assert.Equal(t, tStore.TotalAccounts(), int32(1)) }) - t.Run("Update account, should not increase counter", func(t *testing.T) { + t.Run("Update account, should not increase the total accounts number", func(t *testing.T) { acc.AddToBalance(1) tStore.UpdateAccount(signer.Address(), acc) + assert.NoError(t, tStore.WriteBatch()) assert.Equal(t, tStore.TotalAccounts(), int32(1)) }) t.Run("Get account", func(t *testing.T) { - assert.True(t, tStore.HasAccount(signer.Address())) - acc2, err := tStore.Account(signer.Address()) + acc1, err := tStore.Account(signer.Address()) + assert.NoError(t, err) + + acc2, err := tStore.AccountByNumber(num) assert.NoError(t, err) - assert.Equal(t, acc2.Hash(), acc.Hash()) + + assert.Equal(t, acc1.Hash(), acc2.Hash()) + assert.Equal(t, tStore.TotalAccounts(), int32(1)) + assert.True(t, tStore.HasAccount(signer.Address())) }) } func TestAccountBatchSaving(t *testing.T) { setup(t) - t.Run("Add 100 accounts", func(t *testing.T) { - for i := 0; i < 100; i++ { - acc, signer := account.GenerateTestAccount(int32(i)) + total := util.RandInt32(100) + t.Run("Add some accounts", func(t *testing.T) { + for i := int32(0); i < total; i++ { + acc, signer := account.GenerateTestAccount(i) tStore.UpdateAccount(signer.Address(), acc) } assert.NoError(t, tStore.WriteBatch()) - assert.Equal(t, tStore.TotalAccounts(), int32(100)) + assert.Equal(t, tStore.TotalAccounts(), total) }) + t.Run("Close and load db", func(t *testing.T) { tStore.Close() store, _ := NewStore(tStore.config, 21) - assert.Equal(t, store.TotalAccounts(), int32(100)) + assert.Equal(t, store.TotalAccounts(), total) }) } func TestAccountByNumber(t *testing.T) { setup(t) + total := util.RandInt32(100) + 1 // +1 when random number is zero t.Run("Add some accounts", func(t *testing.T) { - for i := 0; i < 10; i++ { - val, signer := account.GenerateTestAccount(int32(i)) - tStore.UpdateAccount(signer.Address(), val) + for i := int32(0); i < total; i++ { + acc, signer := account.GenerateTestAccount(i) + tStore.UpdateAccount(signer.Address(), acc) } assert.NoError(t, tStore.WriteBatch()) + assert.Equal(t, tStore.TotalAccounts(), total) + }) - v, err := tStore.AccountByNumber(5) + t.Run("Get a random account", func(t *testing.T) { + num := util.RandInt32(total) + acc, err := tStore.AccountByNumber(num) assert.NoError(t, err) - require.NotNil(t, v) - assert.Equal(t, v.Number(), int32(5)) + require.NotNil(t, acc) + assert.Equal(t, acc.Number(), num) + }) - v, err = tStore.AccountByNumber(11) + t.Run("negative number", func(t *testing.T) { + acc, err := tStore.AccountByNumber(-1) assert.Error(t, err) - assert.Nil(t, v) + assert.Nil(t, acc) + }) + + t.Run("Non existing account", func(t *testing.T) { + acc, err := tStore.AccountByNumber(total + 1) + assert.Error(t, err) + assert.Nil(t, acc) }) t.Run("Reopen the store", func(t *testing.T) { tStore.Close() store, _ := NewStore(tStore.config, 21) - v, err := store.AccountByNumber(5) + num := util.RandInt32(total) + acc, err := store.AccountByNumber(num) assert.NoError(t, err) - require.NotNil(t, v) - assert.Equal(t, v.Number(), int32(5)) + require.NotNil(t, acc) + assert.Equal(t, acc.Number(), num) - v, err = tStore.AccountByNumber(11) + acc, err = tStore.AccountByNumber(total + 1) assert.Error(t, err) - assert.Nil(t, v) + assert.Nil(t, acc) }) } + +func TestAccountByAddress(t *testing.T) { + setup(t) + + total := util.RandInt32(100) + var lastAddr crypto.Address + t.Run("Add some accounts", func(t *testing.T) { + for i := int32(0); i < total; i++ { + acc, signer := account.GenerateTestAccount(i) + tStore.UpdateAccount(signer.Address(), acc) + + lastAddr = signer.Address() + } + assert.NoError(t, tStore.WriteBatch()) + assert.Equal(t, tStore.TotalAccounts(), total) + }) + + t.Run("Get random account", func(t *testing.T) { + acc, err := tStore.Account(lastAddr) + assert.NoError(t, err) + require.NotNil(t, acc) + assert.Equal(t, acc.Number(), total-1) + }) + + t.Run("Unknown address", func(t *testing.T) { + acc, err := tStore.Account(crypto.GenerateTestAddress()) + assert.Error(t, err) + assert.Nil(t, acc) + }) + + t.Run("Reopen the store", func(t *testing.T) { + tStore.Close() + store, _ := NewStore(tStore.config, 21) + + acc, err := store.Account(lastAddr) + assert.NoError(t, err) + require.NotNil(t, acc) + assert.Equal(t, acc.Number(), total-1) + }) +} + +func TestIterateAccounts(t *testing.T) { + setup(t) + + total := util.RandInt32(100) + accs1 := []hash.Hash{} + for i := int32(0); i < total; i++ { + acc, signer := account.GenerateTestAccount(i) + tStore.UpdateAccount(signer.Address(), acc) + accs1 = append(accs1, acc.Hash()) + } + assert.NoError(t, tStore.WriteBatch()) + + accs2 := []hash.Hash{} + tStore.IterateAccounts(func(_ crypto.Address, acc *account.Account) bool { + accs2 = append(accs2, acc.Hash()) + return false + }) + assert.ElementsMatch(t, accs1, accs2) + + stopped := false + tStore.IterateAccounts(func(addr crypto.Address, acc *account.Account) bool { + if acc.Hash().EqualsTo(accs1[0]) { + stopped = true + } + return stopped + }) + assert.True(t, stopped) +} + +func TestAccountDeepCopy(t *testing.T) { + setup(t) + + num := util.RandInt32(1000) + acc1, signer := account.GenerateTestAccount(num) + tStore.UpdateAccount(signer.Address(), acc1) + + acc2, _ := tStore.AccountByNumber(num) + acc2.IncSequence() + assert.NotEqual(t, tStore.accountStore.numberMap[num].Hash(), acc2.Hash()) + + acc3, _ := tStore.Account(signer.Address()) + acc3.IncSequence() + assert.NotEqual(t, tStore.accountStore.numberMap[num].Hash(), acc3.Hash()) +} diff --git a/store/store.go b/store/store.go index d57e95462..d4342c595 100644 --- a/store/store.go +++ b/store/store.go @@ -2,7 +2,7 @@ package store import ( "bytes" - "fmt" + "errors" "sync" "github.com/pactus-project/pactus/crypto" @@ -19,6 +19,11 @@ import ( "github.com/syndtr/goleveldb/leveldb/opt" ) +var ( + ErrNotFound = errors.New("not found") + ErrBadOffset = errors.New("offset is out of range") +) + const lastStoreVersion = int32(1) // TODO: add cache for me @@ -202,7 +207,7 @@ func (s *store) Transaction(id tx.ID) (*StoredTx, error) { start := pos.offset end := pos.offset + pos.length if end > uint32(len(data)) { - return nil, fmt.Errorf("offset is out of range") // TODO: Shall we panic here? + return nil, ErrBadOffset } blockTime := util.SliceToUint32(data[hash.HashSize+1 : hash.HashSize+5]) @@ -327,6 +332,8 @@ func (s *store) WriteBatch() error { defer s.lk.Unlock() if err := s.db.Write(s.batch, nil); err != nil { + // TODO: Should we panic here? + // The store is unreliable if the stored data does not match the cached data. return err } s.batch.Reset() diff --git a/store/store_test.go b/store/store_test.go index 0dd2b7a98..defd6ff07 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -3,11 +3,8 @@ package store import ( "testing" - "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/crypto/hash" - "github.com/pactus-project/pactus/types/account" "github.com/pactus-project/pactus/types/block" - "github.com/pactus-project/pactus/types/validator" "github.com/pactus-project/pactus/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -57,32 +54,14 @@ func TestBlockHeight(t *testing.T) { assert.Equal(t, tStore.BlockHeight(sb.BlockHash), uint32(1)) } -func TestReturnNilForNonExistingItems(t *testing.T) { +func TestUnknownTransactionID(t *testing.T) { setup(t) - lastHeight, _ := tStore.LastCertificate() - - assert.Equal(t, tStore.BlockHash(lastHeight+1), hash.UndefHash) - assert.Equal(t, tStore.BlockHash(0), hash.UndefHash) - - block, err := tStore.Block(lastHeight + 1) - assert.Error(t, err) - assert.Nil(t, block) - tx, err := tStore.Transaction(hash.GenerateTestHash()) assert.Error(t, err) assert.Nil(t, tx) - - acc, err := tStore.Account(crypto.GenerateTestAddress()) - assert.Error(t, err) - assert.Nil(t, acc) - - val, err := tStore.Validator(crypto.GenerateTestAddress()) - assert.Error(t, err) - assert.Nil(t, val) - - assert.NoError(t, tStore.Close()) } + func TestWriteAndClosePeacefully(t *testing.T) { setup(t) @@ -107,128 +86,6 @@ func TestRetrieveBlockAndTransactions(t *testing.T) { } } -func TestRetrieveAccount(t *testing.T) { - setup(t) - - acc, signer := account.GenerateTestAccount(util.RandInt32(10000)) - - t.Run("Add account, should able to retrieve", func(t *testing.T) { - assert.False(t, tStore.HasAccount(signer.Address())) - tStore.UpdateAccount(signer.Address(), acc) - assert.NoError(t, tStore.WriteBatch()) - assert.True(t, tStore.HasAccount(signer.Address())) - acc2, err := tStore.Account(signer.Address()) - assert.NoError(t, err) - assert.Equal(t, acc, acc2) - }) - - t.Run("Update account, should update database", func(t *testing.T) { - acc.AddToBalance(1) - tStore.UpdateAccount(signer.Address(), acc) - assert.NoError(t, tStore.WriteBatch()) - acc2, err := tStore.Account(signer.Address()) - assert.NoError(t, err) - assert.Equal(t, acc, acc2) - }) - assert.Equal(t, tStore.TotalAccounts(), int32(1)) - - // Should not crash - assert.NoError(t, tStore.Close()) - _, err := tStore.Account(signer.Address()) - assert.Error(t, err) -} - -func TestRetrieveValidator(t *testing.T) { - setup(t) - - val, _ := validator.GenerateTestValidator(util.RandInt32(1000)) - - t.Run("Add validator, should able to retrieve", func(t *testing.T) { - assert.False(t, tStore.HasValidator(val.Address())) - tStore.UpdateValidator(val) - assert.NoError(t, tStore.WriteBatch()) - assert.True(t, tStore.HasValidator(val.Address())) - val2, err := tStore.Validator(val.Address()) - assert.NoError(t, err) - assert.Equal(t, val.Hash(), val2.Hash()) - }) - - t.Run("Update validator, should update database", func(t *testing.T) { - val.AddToStake(1) - tStore.UpdateValidator(val) - assert.NoError(t, tStore.WriteBatch()) - val2, err := tStore.Validator(val.Address()) - assert.NoError(t, err) - assert.Equal(t, val.Hash(), val2.Hash()) - }) - - assert.Equal(t, tStore.TotalValidators(), int32(1)) - val2, _ := tStore.ValidatorByNumber(val.Number()) - assert.Equal(t, val.Hash(), val2.Hash()) - - assert.NoError(t, tStore.Close()) - _, err := tStore.Validator(val.Address()) - assert.Error(t, err) -} - -func TestIterateAccounts(t *testing.T) { - setup(t) - - accs1 := []hash.Hash{} - for i := 0; i < 10; i++ { - acc, signer := account.GenerateTestAccount(int32(i)) - tStore.UpdateAccount(signer.Address(), acc) - assert.NoError(t, tStore.WriteBatch()) - accs1 = append(accs1, acc.Hash()) - } - - stopped := false - tStore.IterateAccounts(func(addr crypto.Address, acc *account.Account) bool { - if acc.Hash().EqualsTo(accs1[0]) { - stopped = true - } - return stopped - }) - assert.True(t, stopped) - - accs2 := []hash.Hash{} - tStore.IterateAccounts(func(addr crypto.Address, acc *account.Account) bool { - accs2 = append(accs2, acc.Hash()) - return false - }) - - assert.ElementsMatch(t, accs1, accs2) -} - -func TestIterateValidators(t *testing.T) { - setup(t) - - vals1 := []hash.Hash{} - for i := 0; i < 10; i++ { - val, _ := validator.GenerateTestValidator(int32(i)) - tStore.UpdateValidator(val) - assert.NoError(t, tStore.WriteBatch()) - vals1 = append(vals1, val.Hash()) - } - - stopped := false - tStore.IterateValidators(func(val *validator.Validator) bool { - if val.Hash().EqualsTo(vals1[0]) { - stopped = true - } - return stopped - }) - assert.True(t, stopped) - - vals2 := []hash.Hash{} - tStore.IterateValidators(func(val *validator.Validator) bool { - vals2 = append(vals2, val.Hash()) - return false - }) - - assert.ElementsMatch(t, vals1, vals2) -} - func TestRecentBlockByStamp(t *testing.T) { setup(t) diff --git a/store/validator.go b/store/validator.go index 95d76aef9..277799ae8 100644 --- a/store/validator.go +++ b/store/validator.go @@ -1,8 +1,6 @@ package store import ( - "fmt" - "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/types/validator" "github.com/pactus-project/pactus/util/logger" @@ -11,83 +9,78 @@ import ( ) type validatorStore struct { - db *leveldb.DB - valMap map[int32]*validator.Validator - total int32 + db *leveldb.DB + numberMap map[int32]*validator.Validator + addressMap map[crypto.Address]*validator.Validator + total int32 } func validatorKey(addr crypto.Address) []byte { return append(validatorPrefix, addr.Bytes()...) } func newValidatorStore(db *leveldb.DB) *validatorStore { - vs := &validatorStore{ - db: db, - } - total := int32(0) - valMap := make(map[int32]*validator.Validator) - vs.iterateValidators(func(val *validator.Validator) bool { - valMap[val.Number()] = val - total++ - return false - }) + numberMap := make(map[int32]*validator.Validator) + addressMap := make(map[crypto.Address]*validator.Validator) + r := util.BytesPrefix(validatorPrefix) + iter := db.NewIterator(r, nil) + for iter.Next() { + // key := iter.Key() + value := iter.Value() + + val, err := validator.FromBytes(value) + if err != nil { + logger.Panic("unable to decode validator", "err", err) + } - vs.total = total - vs.valMap = valMap + numberMap[val.Number()] = val + addressMap[val.Address()] = val + total++ + } + iter.Release() - return vs + return &validatorStore{ + db: db, + total: total, + numberMap: numberMap, + addressMap: addressMap, + } } func (vs *validatorStore) hasValidator(addr crypto.Address) bool { - has, err := vs.db.Has(validatorKey(addr), nil) - if err != nil { - return false - } - return has + _, ok := vs.addressMap[addr] + return ok } func (vs *validatorStore) validator(addr crypto.Address) (*validator.Validator, error) { - data, err := tryGet(vs.db, validatorKey(addr)) - if err != nil { - return nil, err - } - - val, err := validator.FromBytes(data) - if err != nil { - return nil, err + val, ok := vs.addressMap[addr] + if ok { + return val.Clone(), nil } - return val, nil + return nil, ErrNotFound } func (vs *validatorStore) validatorByNumber(num int32) (*validator.Validator, error) { - val, ok := vs.valMap[num] + val, ok := vs.numberMap[num] if ok { - return val, nil + return val.Clone(), nil } - return nil, fmt.Errorf("not found") + return nil, ErrNotFound } func (vs *validatorStore) iterateValidators(consumer func(*validator.Validator) (stop bool)) { - r := util.BytesPrefix(validatorPrefix) - iter := vs.db.NewIterator(r, nil) - for iter.Next() { - // key := iter.Key() - value := iter.Value() - - val, err := validator.FromBytes(value) - if err != nil { - logger.Panic("unable to decode validator", "err", err) - } - - stopped := consumer(val) + for _, val := range vs.addressMap { + stopped := consumer(val.Clone()) if stopped { return } } - iter.Release() } +// This function takes ownership of the validator pointer. +// It is important that the caller should not modify the validator data and +// keep it immutable. func (vs *validatorStore) updateValidator(batch *leveldb.Batch, val *validator.Validator) { data, err := val.Bytes() if err != nil { @@ -96,7 +89,8 @@ func (vs *validatorStore) updateValidator(batch *leveldb.Batch, val *validator.V if !vs.hasValidator(val.Address()) { vs.total++ } - vs.valMap[val.Number()] = val + vs.numberMap[val.Number()] = val + vs.addressMap[val.Address()] = val batch.Put(validatorKey(val.Address()), data) } diff --git a/store/validator_test.go b/store/validator_test.go index 213ac2098..f4ddb6507 100644 --- a/store/validator_test.go +++ b/store/validator_test.go @@ -3,6 +3,8 @@ package store import ( "testing" + "github.com/pactus-project/pactus/crypto" + "github.com/pactus-project/pactus/crypto/hash" "github.com/pactus-project/pactus/types/validator" "github.com/pactus-project/pactus/util" "github.com/stretchr/testify/assert" @@ -11,99 +13,190 @@ import ( func TestValidatorCounter(t *testing.T) { setup(t) - val, _ := validator.GenerateTestValidator(util.RandInt32(1000)) - t.Run("Update count after adding new validator", func(t *testing.T) { + num := util.RandInt32(1000) + val, _ := validator.GenerateTestValidator(num) + + t.Run("Add new validator, should increase the total validators number", func(t *testing.T) { assert.Zero(t, tStore.TotalValidators()) + tStore.UpdateValidator(val) assert.NoError(t, tStore.WriteBatch()) assert.Equal(t, tStore.TotalValidators(), int32(1)) }) - t.Run("Update validator, should not increase counter", func(t *testing.T) { + t.Run("Update validator, should not increase the total validators number", func(t *testing.T) { val.AddToStake(1) - tStore.UpdateValidator(val) + assert.NoError(t, tStore.WriteBatch()) assert.Equal(t, tStore.TotalValidators(), int32(1)) }) t.Run("Get validator", func(t *testing.T) { - assert.True(t, tStore.HasValidator(val.Address())) - val2, err := tStore.Validator(val.Address()) + val1, err := tStore.Validator(val.Address()) + assert.NoError(t, err) + + val2, err := tStore.ValidatorByNumber(num) assert.NoError(t, err) - assert.Equal(t, val2.Hash(), val.Hash()) + + assert.Equal(t, val1.Hash(), val2.Hash()) + assert.Equal(t, tStore.TotalValidators(), int32(1)) + assert.True(t, tStore.HasValidator(val.Address())) }) } func TestValidatorBatchSaving(t *testing.T) { setup(t) - t.Run("Add 100 validators", func(t *testing.T) { - for i := 0; i < 100; i++ { - val, _ := validator.GenerateTestValidator(int32(i)) + total := util.RandInt32(100) + t.Run("Add some validators", func(t *testing.T) { + for i := int32(0); i < total; i++ { + val, _ := validator.GenerateTestValidator(i) tStore.UpdateValidator(val) - assert.NoError(t, tStore.WriteBatch()) } - - assert.Equal(t, tStore.TotalValidators(), int32(100)) + assert.NoError(t, tStore.WriteBatch()) + assert.Equal(t, tStore.TotalValidators(), total) }) + t.Run("Close and load db", func(t *testing.T) { tStore.Close() store, _ := NewStore(tStore.config, 21) - assert.Equal(t, store.TotalValidators(), int32(100)) + assert.Equal(t, store.TotalValidators(), total) }) } func TestValidatorByNumber(t *testing.T) { setup(t) + total := util.RandInt32(100) + 1 // +1 when random number is zero t.Run("Add some validators", func(t *testing.T) { - for i := 0; i < 10; i++ { - val, _ := validator.GenerateTestValidator(int32(i)) + for i := int32(0); i < total; i++ { + val, _ := validator.GenerateTestValidator(i) tStore.UpdateValidator(val) } assert.NoError(t, tStore.WriteBatch()) + assert.Equal(t, tStore.TotalValidators(), total) + }) - v, err := tStore.ValidatorByNumber(5) + t.Run("Get a random Validator", func(t *testing.T) { + num := util.RandInt32(total) + val, err := tStore.ValidatorByNumber(num) assert.NoError(t, err) - require.NotNil(t, v) - assert.Equal(t, v.Number(), int32(5)) + require.NotNil(t, val) + assert.Equal(t, val.Number(), num) + }) - v, err = tStore.ValidatorByNumber(11) + t.Run("Negative number", func(t *testing.T) { + val, err := tStore.ValidatorByNumber(-1) assert.Error(t, err) - assert.Nil(t, v) + assert.Nil(t, val) + }) + + t.Run("Non existing validator", func(t *testing.T) { + val, err := tStore.ValidatorByNumber(total + 1) + assert.Error(t, err) + assert.Nil(t, val) }) t.Run("Reopen the store", func(t *testing.T) { tStore.Close() store, _ := NewStore(tStore.config, 21) - v, err := store.ValidatorByNumber(5) + num := util.RandInt32(total) + val, err := store.ValidatorByNumber(num) assert.NoError(t, err) - require.NotNil(t, v) - assert.Equal(t, v.Number(), int32(5)) + require.NotNil(t, val) + assert.Equal(t, val.Number(), num) - v, err = tStore.ValidatorByNumber(11) + val, err = tStore.ValidatorByNumber(total + 1) assert.Error(t, err) - assert.Nil(t, v) + assert.Nil(t, val) }) } -func TestUpdateValidator(t *testing.T) { +func TestValidatorByAddress(t *testing.T) { setup(t) - val1, _ := validator.GenerateTestValidator(0) - tStore.UpdateValidator(val1) + total := util.RandInt32(100) + 1 + t.Run("Add some validators", func(t *testing.T) { + for i := int32(0); i < total; i++ { + val, _ := validator.GenerateTestValidator(i) + tStore.UpdateValidator(val) + } + assert.NoError(t, tStore.WriteBatch()) + assert.Equal(t, tStore.TotalValidators(), total) + }) + + t.Run("Get random validator", func(t *testing.T) { + num := util.RandInt32(total) + val0, _ := tStore.ValidatorByNumber(num) + val, err := tStore.Validator(val0.Address()) + assert.NoError(t, err) + require.NotNil(t, val) + assert.Equal(t, val.Number(), num) + }) + + t.Run("Unknown address", func(t *testing.T) { + val, err := tStore.Validator(crypto.GenerateTestAddress()) + assert.Error(t, err) + assert.Nil(t, val) + }) + + t.Run("Reopen the store", func(t *testing.T) { + tStore.Close() + store, _ := NewStore(tStore.config, 21) + + num := util.RandInt32(total) + val0, _ := store.ValidatorByNumber(num) + val, err := store.Validator(val0.Address()) + assert.NoError(t, err) + require.NotNil(t, val) + assert.Equal(t, val.Number(), num) + }) +} + +func TestIterateValidators(t *testing.T) { + setup(t) + + total := util.RandInt32(100) + vals1 := []hash.Hash{} + for i := int32(0); i < total; i++ { + val, _ := validator.GenerateTestValidator(i) + tStore.UpdateValidator(val) + vals1 = append(vals1, val.Hash()) + } assert.NoError(t, tStore.WriteBatch()) - val2, _ := tStore.ValidatorByNumber(val1.Number()) - assert.Equal(t, val1.Hash(), val2.Hash()) + vals2 := []hash.Hash{} + tStore.IterateValidators(func(val *validator.Validator) bool { + vals2 = append(vals2, val.Hash()) + return false + }) + assert.ElementsMatch(t, vals1, vals2) + + stopped := false + tStore.IterateValidators(func(val *validator.Validator) bool { + if val.Hash().EqualsTo(vals1[0]) { + stopped = true + } + return stopped + }) + assert.True(t, stopped) +} + +func TestValidatorDeepCopy(t *testing.T) { + setup(t) + + num := util.RandInt32(1000) + val1, _ := validator.GenerateTestValidator(num) + tStore.UpdateValidator(val1) + + val2, _ := tStore.ValidatorByNumber(num) + val2.IncSequence() + assert.NotEqual(t, tStore.validatorStore.numberMap[num].Hash(), val2.Hash()) val3, _ := tStore.Validator(val1.Address()) - val3.AddToStake(10000) - tStore.UpdateValidator(val3) - assert.NoError(t, tStore.WriteBatch()) - val4, _ := tStore.ValidatorByNumber(val1.Number()) - assert.Equal(t, val4.Hash(), val3.Hash()) + val3.IncSequence() + assert.NotEqual(t, tStore.validatorStore.numberMap[num].Hash(), val3.Hash()) } diff --git a/types/account/account.go b/types/account/account.go index 31dc57371..acdf36de5 100644 --- a/types/account/account.go +++ b/types/account/account.go @@ -1,3 +1,4 @@ +// Package account provides functionality for managing account information. package account import ( @@ -10,18 +11,19 @@ import ( "github.com/pactus-project/pactus/util/encoding" ) -// Account represents a structure for an account information. +// The Account struct represents a account object. type Account struct { data accountData } +// accountData contains the data associated with a account. type accountData struct { Number int32 Sequence int32 Balance int64 } -// NewAccount constructs a new account object. +// NewAccount constructs a new account from the given number. func NewAccount(number int32) *Account { return &Account{ data: accountData{ @@ -46,22 +48,37 @@ func FromBytes(data []byte) (*Account, error) { return acc, nil } -func (acc Account) Number() int32 { return acc.data.Number } -func (acc Account) Sequence() int32 { return acc.data.Sequence } -func (acc Account) Balance() int64 { return acc.data.Balance } +// Number returns the number of the account. +func (acc Account) Number() int32 { + return acc.data.Number +} + +// Sequence returns the sequence number of the account. +func (acc Account) Sequence() int32 { + return acc.data.Sequence +} + +// Balance returns the balance of the account. +func (acc Account) Balance() int64 { + return acc.data.Balance +} +// SubtractFromBalance subtracts the given amount from the account's balance. func (acc *Account) SubtractFromBalance(amt int64) { acc.data.Balance -= amt } +// AddToBalance adds the given amount to the account's balance. func (acc *Account) AddToBalance(amt int64) { acc.data.Balance += amt } +// IncSequence increases the sequence anytime this account signs a transaction. func (acc *Account) IncSequence() { acc.data.Sequence++ } +// Hash calculates and returns the hash of the account. func (acc *Account) Hash() hash.Hash { bs, err := acc.Bytes() if err != nil { @@ -69,10 +86,13 @@ func (acc *Account) Hash() hash.Hash { } return hash.CalcHash(bs) } + +// SerializeSize returns the size in bytes required to serialize the account. func (acc *Account) SerializeSize() int { return 16 // 4+4+8 } +// Bytes returns the serialized byte representation of the account. func (acc *Account) Bytes() ([]byte, error) { w := bytes.NewBuffer(make([]byte, 0, acc.SerializeSize())) err := encoding.WriteElements(w, @@ -86,7 +106,14 @@ func (acc *Account) Bytes() ([]byte, error) { return w.Bytes(), nil } -// GenerateTestAccount generates an account for testing purpose. +// Clone creates a deep copy of the account. +func (acc *Account) Clone() *Account { + cloned := new(Account) + *cloned = *acc + return cloned +} + +// GenerateTestAccount generates an account for testing purposes. func GenerateTestAccount(number int32) (*Account, crypto.Signer) { signer := bls.GenerateTestSigner() acc := NewAccount(number) diff --git a/types/account/account_test.go b/types/account/account_test.go index 8128027c1..337b87116 100644 --- a/types/account/account_test.go +++ b/types/account/account_test.go @@ -59,3 +59,11 @@ func TestSubtractFromBalance(t *testing.T) { acc.SubtractFromBalance(1) assert.Equal(t, acc.Balance(), bal-1) } + +func TestClone(t *testing.T) { + acc, _ := GenerateTestAccount(100) + cloned := acc.Clone() + cloned.IncSequence() + + assert.NotEqual(t, acc.Sequence(), cloned.Sequence()) +} diff --git a/types/validator/validator.go b/types/validator/validator.go index 368c2e1c7..f5e94903f 100644 --- a/types/validator/validator.go +++ b/types/validator/validator.go @@ -1,3 +1,4 @@ +// Package validator provides functionality for managing validator information. package validator import ( @@ -10,10 +11,12 @@ import ( "github.com/pactus-project/pactus/util/encoding" ) +// The Validator struct represents a validator object. type Validator struct { data validatorData } +// validatorData contains the data associated with a validator. type validatorData struct { PublicKey *bls.PublicKey Number int32 @@ -24,7 +27,7 @@ type validatorData struct { LastJoinedHeight uint32 } -// NewValidator constructs a new validator object. +// NewValidator constructs a new validator from the given public key and number. func NewValidator(publicKey *bls.PublicKey, number int32) *Validator { val := &Validator{ data: validatorData{ @@ -35,7 +38,7 @@ func NewValidator(publicKey *bls.PublicKey, number int32) *Validator { return val } -// FromBytes constructs a new validator from byte array. +// FromBytes constructs a new validator from a byte array. func FromBytes(data []byte) (*Validator, error) { acc := new(Validator) r := bytes.NewReader(data) @@ -61,43 +64,64 @@ func FromBytes(data []byte) (*Validator, error) { return acc, nil } -func (val *Validator) PublicKey() *bls.PublicKey { return val.data.PublicKey } -func (val *Validator) Address() crypto.Address { return val.data.PublicKey.Address() } -func (val *Validator) Number() int32 { return val.data.Number } -func (val *Validator) Sequence() int32 { return val.data.Sequence } -func (val *Validator) Stake() int64 { return val.data.Stake } +// PublicKey returns the public key of the validator. +func (val *Validator) PublicKey() *bls.PublicKey { + return val.data.PublicKey +} + +// Address returns the address of the validator. +func (val *Validator) Address() crypto.Address { + return val.data.PublicKey.Address() +} + +// Number returns the number of the validator. +func (val *Validator) Number() int32 { + return val.data.Number +} + +// Sequence returns the sequence number of the validator. +func (val *Validator) Sequence() int32 { + return val.data.Sequence +} + +// Stake returns the stake of the validator. +func (val *Validator) Stake() int64 { + return val.data.Stake +} -// LastBondingHeight returns the last height in which validator bonded stake +// LastBondingHeight returns the last height in which the validator bonded stake. func (val *Validator) LastBondingHeight() uint32 { return val.data.LastBondingHeight } -// UnbondingHeight returns the last height in which validator unbonded stake +// UnbondingHeight returns the last height in which the validator unbonded stake. func (val *Validator) UnbondingHeight() uint32 { return val.data.UnbondingHeight } -// LastJoinedHeight returns the last height in which validator joined into the committee +// LastJoinedHeight returns the last height in which the validator joined the committee. func (val *Validator) LastJoinedHeight() uint32 { return val.data.LastJoinedHeight } +// Power returns the power of the validator. func (val Validator) Power() int64 { if val.data.UnbondingHeight > 0 { - // Power for unbonded validators set to zero. + // Power for unbonded validators is set to zero. return 0 } else if val.data.Stake == 0 { - // Only bootstrap validators at the genesis block have no stake + // Only bootstrap validators at the genesis block have no stake. return 1 } return val.data.Stake } +// SubtractFromStake subtracts the given amount from the validator's stake. func (val *Validator) SubtractFromStake(amt int64) { val.data.Stake -= amt } -// AddToStake increases the stake by bonding transaction. +// AddToStake adds the given amount to the validator's stake. func (val *Validator) AddToStake(amt int64) { val.data.Stake += amt } @@ -107,12 +131,12 @@ func (val *Validator) IncSequence() { val.data.Sequence++ } -// UpdateLastJoinedHeight updates the last height that this validator joined the committee. +// UpdateLastJoinedHeight updates the last height at which the validator joined the committee. func (val *Validator) UpdateLastJoinedHeight(height uint32) { val.data.LastJoinedHeight = height } -// UpdateLastBondingHeight updates the last height that this validator bonded some stakes. +// UpdateLastBondingHeight updates the last height at which the validator bonded some stakes. func (val *Validator) UpdateLastBondingHeight(height uint32) { val.data.LastBondingHeight = height } @@ -122,7 +146,7 @@ func (val *Validator) UpdateUnbondingHeight(height uint32) { val.data.UnbondingHeight = height } -// Hash return the hash of this validator. +// Hash calculates and returns the hash of the validator. func (val *Validator) Hash() hash.Hash { bs, err := val.Bytes() if err != nil { @@ -131,10 +155,12 @@ func (val *Validator) Hash() hash.Hash { return hash.CalcHash(bs) } +// SerializeSize returns the size in bytes required to serialize the validator. func (val *Validator) SerializeSize() int { return 124 // 96+4+4+8+4+4+4 } +// Bytes returns returns the serialized byte representation of the validator. func (val *Validator) Bytes() ([]byte, error) { w := bytes.NewBuffer(make([]byte, 0, val.SerializeSize())) @@ -156,7 +182,14 @@ func (val *Validator) Bytes() ([]byte, error) { return w.Bytes(), nil } -// GenerateTestValidator generates a validator for testing purpose. +// Clone creates a deep copy of the validator. +func (val *Validator) Clone() *Validator { + cloned := new(Validator) + *cloned = *val + return cloned +} + +// GenerateTestValidator generates a validator for testing purposes. func GenerateTestValidator(number int32) (*Validator, crypto.Signer) { pub, pv := bls.GenerateTestKeyPair() val := NewValidator(pub, number) diff --git a/types/validator/validator_test.go b/types/validator/validator_test.go index b540be2b3..885048746 100644 --- a/types/validator/validator_test.go +++ b/types/validator/validator_test.go @@ -85,3 +85,11 @@ func TestSubtractFromStake(t *testing.T) { val.SubtractFromStake(1) assert.Equal(t, val.Stake(), stake-1) } + +func TestClone(t *testing.T) { + val, _ := GenerateTestValidator(100) + cloned := val.Clone() + cloned.IncSequence() + + assert.NotEqual(t, val.Sequence(), cloned.Sequence()) +}