diff --git a/cmd/cmd.go b/cmd/cmd.go index 32a1791cb..b9f05d627 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -293,7 +293,7 @@ func CreateNode(numValidators int, chain genesis.ChainType, workingDir string, if err := genDoc.SaveToFile(genPath); err != nil { return nil, nil, err } - conf := config.DefaultConfigTestnet() + conf := config.DefaultConfigTestnet(genDoc.Params()) if err := conf.Save(confPath); err != nil { return nil, nil, err } @@ -304,7 +304,7 @@ func CreateNode(numValidators int, chain genesis.ChainType, workingDir string, return nil, nil, err } - conf := config.DefaultConfigLocalnet() + conf := config.DefaultConfigLocalnet(genDoc.Params()) if err := conf.Save(confPath); err != nil { return nil, nil, err } @@ -338,7 +338,7 @@ func StartNode(workingDir string, passwordFetcher func(*wallet.Wallet) (string, } confPath := PactusConfigPath(workingDir) - conf, err := tryLoadConfig(gen.ChainType(), confPath) + conf, err := tryLoadConfig(gen, confPath) if err != nil { return nil, nil, err } @@ -461,15 +461,15 @@ func makeLocalGenesis(w wallet.Wallet) *genesis.Genesis { return gen } -func tryLoadConfig(chainType genesis.ChainType, confPath string) (*config.Config, error) { +func tryLoadConfig(genDoc *genesis.Genesis, confPath string) (*config.Config, error) { var defConf *config.Config - switch chainType { + switch genDoc.ChainType() { case genesis.Mainnet: panic("not yet implemented!") case genesis.Testnet: - defConf = config.DefaultConfigTestnet() + defConf = config.DefaultConfigTestnet(genDoc.Params()) case genesis.Localnet: - defConf = config.DefaultConfigLocalnet() + defConf = config.DefaultConfigLocalnet(genDoc.Params()) } conf, err := config.LoadFromFile(confPath, true, defConf) @@ -496,7 +496,7 @@ func tryLoadConfig(chainType genesis.ChainType, confPath string) (*config.Config } PrintSuccessMsgf("Config updated.") } else { - switch chainType { + switch genDoc.ChainType() { case genesis.Mainnet: err = config.SaveMainnetConfig(confPath) if err != nil { diff --git a/config/config.go b/config/config.go index 284727938..a497bf269 100644 --- a/config/config.go +++ b/config/config.go @@ -11,6 +11,7 @@ import ( "github.com/pactus-project/pactus/store" "github.com/pactus-project/pactus/sync" "github.com/pactus-project/pactus/txpool" + "github.com/pactus-project/pactus/types/param" "github.com/pactus-project/pactus/util" "github.com/pactus-project/pactus/util/errors" "github.com/pactus-project/pactus/util/logger" @@ -77,14 +78,21 @@ func defaultConfig() *Config { return conf } -func DefaultConfigMainnet() *Config { +func DefaultConfigMainnet(genParams *param.Params) *Config { conf := defaultConfig() // TO BE DEFINED + + // Store private configs + conf.Store.TxCacheSize = genParams.TransactionToLiveInterval + conf.Store.SortitionCacheSize = genParams.SortitionInterval + conf.Store.AccountCacheSize = 1024 + conf.Store.PublicKeyCacheSize = 1024 + return conf } //nolint:lll // long multi-address -func DefaultConfigTestnet() *Config { +func DefaultConfigTestnet(genParams *param.Params) *Config { conf := defaultConfig() conf.Network.ListenAddrStrings = []string{ "/ip4/0.0.0.0/tcp/21777", "/ip4/0.0.0.0/udp/21777/quic-v1", @@ -121,10 +129,16 @@ func DefaultConfigTestnet() *Config { conf.Nanomsg.Enable = false conf.Nanomsg.Listen = "tcp://127.0.0.1:40799" + // Store private configs + conf.Store.TxCacheSize = genParams.TransactionToLiveInterval + conf.Store.SortitionCacheSize = genParams.SortitionInterval + conf.Store.AccountCacheSize = 1024 + conf.Store.PublicKeyCacheSize = 1024 + return conf } -func DefaultConfigLocalnet() *Config { +func DefaultConfigLocalnet(genParams *param.Params) *Config { conf := defaultConfig() conf.Network.ListenAddrStrings = []string{} conf.Network.EnableRelay = false @@ -143,6 +157,12 @@ func DefaultConfigLocalnet() *Config { conf.Nanomsg.Enable = true conf.Nanomsg.Listen = "tcp://127.0.0.1:0" + // Store private configs + conf.Store.TxCacheSize = genParams.TransactionToLiveInterval + conf.Store.SortitionCacheSize = genParams.SortitionInterval + conf.Store.AccountCacheSize = 1024 + conf.Store.PublicKeyCacheSize = 1024 + return conf } diff --git a/config/config_test.go b/config/config_test.go index 3eef8482b..348c4250a 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -4,6 +4,7 @@ import ( "strings" "testing" + "github.com/pactus-project/pactus/types/param" "github.com/pactus-project/pactus/util" "github.com/pactus-project/pactus/util/testsuite" "github.com/stretchr/testify/assert" @@ -13,11 +14,15 @@ func TestSaveMainnetConfig(t *testing.T) { path := util.TempFilePath() assert.NoError(t, SaveMainnetConfig(path)) - defConf := DefaultConfigMainnet() + defConf := DefaultConfigMainnet(param.DefaultParams()) conf, err := LoadFromFile(path, true, defConf) assert.NoError(t, err) assert.NoError(t, conf.BasicCheck()) + assert.Equal(t, conf.Store.TxCacheSize, param.DefaultParams().TransactionToLiveInterval) + assert.Equal(t, conf.Store.SortitionCacheSize, param.DefaultParams().SortitionInterval) + assert.Equal(t, conf.Store.AccountCacheSize, 1024) + assert.Equal(t, conf.Store.PublicKeyCacheSize, 1024) } func TestSaveConfig(t *testing.T) { @@ -25,28 +30,51 @@ func TestSaveConfig(t *testing.T) { conf := defaultConfig() assert.NoError(t, conf.Save(path)) - defConf := DefaultConfigTestnet() + defConf := DefaultConfigTestnet(param.DefaultParams()) conf, err := LoadFromFile(path, true, defConf) assert.NoError(t, err) assert.NoError(t, conf.BasicCheck()) assert.Equal(t, conf.Network.NetworkName, "pactus-testnet-v2") assert.Equal(t, conf.Network.DefaultPort, 21777) + assert.Equal(t, conf.Store.TxCacheSize, param.DefaultParams().TransactionToLiveInterval) + assert.Equal(t, conf.Store.SortitionCacheSize, param.DefaultParams().SortitionInterval) + assert.Equal(t, conf.Store.AccountCacheSize, 1024) + assert.Equal(t, conf.Store.PublicKeyCacheSize, 1024) } func TestLocalnetConfig(t *testing.T) { - conf := DefaultConfigLocalnet() + conf := DefaultConfigLocalnet(param.DefaultParams()) assert.NoError(t, conf.BasicCheck()) assert.Empty(t, conf.Network.ListenAddrStrings) assert.Empty(t, conf.Network.RelayAddrStrings) assert.Equal(t, conf.Network.NetworkName, "pactus-localnet") assert.Equal(t, conf.Network.DefaultPort, 21666) + assert.Equal(t, conf.Store.TxCacheSize, param.DefaultParams().TransactionToLiveInterval) + assert.Equal(t, conf.Store.SortitionCacheSize, param.DefaultParams().SortitionInterval) + assert.Equal(t, conf.Store.AccountCacheSize, 1024) + assert.Equal(t, conf.Store.PublicKeyCacheSize, 1024) +} + +func TestTestnetConfig(t *testing.T) { + conf := DefaultConfigTestnet(param.DefaultParams()) + + assert.NoError(t, conf.BasicCheck()) + assert.NotEmpty(t, conf.Network.ListenAddrStrings) + assert.NotEmpty(t, conf.Network.DefaultRelayAddrStrings) + assert.Empty(t, conf.Network.RelayAddrStrings) + assert.Equal(t, conf.Network.NetworkName, "pactus-testnet-v2") + assert.Equal(t, conf.Network.DefaultPort, 21777) + assert.Equal(t, conf.Store.TxCacheSize, param.DefaultParams().TransactionToLiveInterval) + assert.Equal(t, conf.Store.SortitionCacheSize, param.DefaultParams().SortitionInterval) + assert.Equal(t, conf.Store.AccountCacheSize, 1024) + assert.Equal(t, conf.Store.PublicKeyCacheSize, 1024) } func TestLoadFromFile(t *testing.T) { path := util.TempFilePath() - defConf := DefaultConfigTestnet() + defConf := DefaultConfigTestnet(param.DefaultParams()) _, err := LoadFromFile(path, true, defConf) assert.Error(t, err, "not exists") @@ -58,6 +86,10 @@ func TestLoadFromFile(t *testing.T) { conf, err := LoadFromFile(path, false, defConf) assert.NoError(t, err) assert.Equal(t, conf, defConf) + assert.Equal(t, conf.Store.TxCacheSize, param.DefaultParams().TransactionToLiveInterval) + assert.Equal(t, conf.Store.SortitionCacheSize, param.DefaultParams().SortitionInterval) + assert.Equal(t, conf.Store.AccountCacheSize, 1024) + assert.Equal(t, conf.Store.PublicKeyCacheSize, 1024) } func TestExampleConfig(t *testing.T) { @@ -73,7 +105,7 @@ func TestExampleConfig(t *testing.T) { } } - defaultConf := DefaultConfigMainnet() + defaultConf := DefaultConfigMainnet(param.DefaultParams()) defaultToml := string(defaultConf.toTOML()) exampleToml = strings.ReplaceAll(exampleToml, "##", "") @@ -82,6 +114,10 @@ func TestExampleConfig(t *testing.T) { defaultToml = strings.ReplaceAll(defaultToml, "\n\n", "\n") assert.Equal(t, defaultToml, exampleToml) + assert.Equal(t, defaultConf.Store.TxCacheSize, param.DefaultParams().TransactionToLiveInterval) + assert.Equal(t, defaultConf.Store.SortitionCacheSize, param.DefaultParams().SortitionInterval) + assert.Equal(t, defaultConf.Store.AccountCacheSize, 1024) + assert.Equal(t, defaultConf.Store.PublicKeyCacheSize, 1024) } func TestNodeConfigBasicCheck(t *testing.T) { diff --git a/go.mod b/go.mod index a49775a5a..1004628b7 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21 require ( github.com/fxamacker/cbor/v2 v2.5.0 + github.com/gofrs/flock v0.8.1 github.com/google/uuid v1.4.0 github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 @@ -53,7 +54,6 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect - github.com/gofrs/flock v0.8.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.4 // indirect diff --git a/node/node.go b/node/node.go index aa0c2b895..0726f9bdc 100644 --- a/node/node.go +++ b/node/node.go @@ -155,6 +155,7 @@ func (n *Node) Stop() { } // these methods are using by GUI. + func (n *Node) ConsManager() consensus.ManagerReader { return n.consMgr } diff --git a/node/node_test.go b/node/node_test.go index dbfccb14c..03d088076 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -30,7 +30,7 @@ func TestRunningNode(t *testing.T) { gen := genesis.MakeGenesis(util.Now(), map[crypto.Address]*account.Account{crypto.TreasuryAddress: acc}, []*validator.Validator{val}, param.DefaultParams()) - conf := config.DefaultConfigMainnet() + conf := config.DefaultConfigMainnet(param.DefaultParams()) conf.GRPC.Enable = false conf.HTTP.Enable = false conf.Store.Path = util.TempDirPath() diff --git a/sandbox/sandbox.go b/sandbox/sandbox.go index 3965f5ac9..4919788e3 100644 --- a/sandbox/sandbox.go +++ b/sandbox/sandbox.go @@ -277,18 +277,14 @@ func (sb *sandbox) PowerDelta() int64 { // VerifyProof verifies proof of a sortition transaction. func (sb *sandbox) VerifyProof(blockHeight uint32, proof sortition.Proof, val *validator.Validator) bool { - committedBlock, err := sb.store.Block(blockHeight) - if err != nil { - return false - } - // TODO: improvement: - // We can get the sortition seed without parsing the block - blk, err := committedBlock.ToBlock() - if err != nil { + sb.lk.RLock() + defer sb.lk.RUnlock() + + seed := sb.store.SortitionSeed(blockHeight) + if seed == nil { return false } - seed := blk.Header().SortitionSeed() - return sortition.VerifyProof(seed, proof, val.PublicKey(), sb.totalPower, val.Power()) + return sortition.VerifyProof(*seed, proof, val.PublicKey(), sb.totalPower, val.Power()) } func (sb *sandbox) CommitTransaction(trx *tx.Tx) { diff --git a/store/account.go b/store/account.go index e58f38554..453328794 100644 --- a/store/account.go +++ b/store/account.go @@ -1,6 +1,7 @@ package store import ( + lru "github.com/hashicorp/golang-lru/v2" "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/types/account" "github.com/pactus-project/pactus/util/logger" @@ -10,32 +11,22 @@ import ( type accountStore struct { db *leveldb.DB - addressMap map[crypto.Address]*account.Account + accInCache *lru.Cache[crypto.Address, *account.Account] total int32 } func accountKey(addr crypto.Address) []byte { return append(accountPrefix, addr.Bytes()...) } -func newAccountStore(db *leveldb.DB) *accountStore { +func newAccountStore(db *leveldb.DB, cacheSize int) *accountStore { total := int32(0) - numberMap := make(map[int32]*account.Account) - addressMap := make(map[crypto.Address]*account.Account) + addrLruCache, err := lru.New[crypto.Address, *account.Account](cacheSize) + if err != nil { + logger.Panic("unable to create new instance of lru cache", "error", err) + } + r := util.BytesPrefix(accountPrefix) iter := db.NewIterator(r, nil) for iter.Next() { - key := iter.Key() - value := iter.Value() - - acc, err := account.FromBytes(value) - if err != nil { - logger.Panic("unable to decode account", "error", err) - } - - var addr crypto.Address - copy(addr[:], key[1:]) - - numberMap[acc.Number()] = acc - addressMap[addr] = acc total++ } iter.Release() @@ -43,31 +34,59 @@ func newAccountStore(db *leveldb.DB) *accountStore { return &accountStore{ db: db, total: total, - addressMap: addressMap, + accInCache: addrLruCache, } } func (as *accountStore) hasAccount(addr crypto.Address) bool { - _, ok := as.addressMap[addr] + ok := as.accInCache.Contains(addr) + if !ok { + ok = tryHas(as.db, accountKey(addr)) + } return ok } func (as *accountStore) account(addr crypto.Address) (*account.Account, error) { - acc, ok := as.addressMap[addr] + acc, ok := as.accInCache.Get(addr) if ok { return acc.Clone(), nil } - return nil, ErrNotFound + rawData, err := tryGet(as.db, accountKey(addr)) + if err != nil { + return nil, err + } + + acc, err = account.FromBytes(rawData) + if err != nil { + return nil, err + } + + as.accInCache.Add(addr, acc.Clone()) + return acc, nil } func (as *accountStore) iterateAccounts(consumer func(crypto.Address, *account.Account) (stop bool)) { - for addr, acc := range as.addressMap { - stopped := consumer(addr, acc.Clone()) + r := util.BytesPrefix(accountPrefix) + iter := as.db.NewIterator(r, nil) + for iter.Next() { + key := iter.Key() + value := iter.Value() + + acc, err := account.FromBytes(value) + if err != nil { + logger.Panic("unable to decode account", "error", err) + } + + var addr crypto.Address + copy(addr[:], key[1:]) + + stopped := consumer(addr, acc) if stopped { return } } + iter.Release() } // This function takes ownership of the account pointer. @@ -81,7 +100,7 @@ func (as *accountStore) updateAccount(batch *leveldb.Batch, addr crypto.Address, if !as.hasAccount(addr) { as.total++ } - as.addressMap[addr] = acc + as.accInCache.Add(addr, acc) batch.Put(accountKey(addr), data) } diff --git a/store/account_test.go b/store/account_test.go index 6470e0026..6d95cf2bb 100644 --- a/store/account_test.go +++ b/store/account_test.go @@ -11,7 +11,7 @@ import ( ) func TestAccountCounter(t *testing.T) { - td := setup(t) + td := setup(t, nil) num := td.RandInt32(1000) acc, addr := td.GenerateTestAccount(num) @@ -43,7 +43,7 @@ func TestAccountCounter(t *testing.T) { } func TestAccountBatchSaving(t *testing.T) { - td := setup(t) + td := setup(t, nil) total := td.RandInt32NonZero(100) t.Run("Add some accounts", func(t *testing.T) { @@ -63,7 +63,7 @@ func TestAccountBatchSaving(t *testing.T) { } func TestAccountByAddress(t *testing.T) { - td := setup(t) + td := setup(t, nil) total := td.RandInt32NonZero(100) var lastAddr crypto.Address @@ -103,7 +103,7 @@ func TestAccountByAddress(t *testing.T) { } func TestIterateAccounts(t *testing.T) { - td := setup(t) + td := setup(t, nil) total := td.RandInt32NonZero(100) hashes1 := []hash.Hash{} @@ -132,7 +132,7 @@ func TestIterateAccounts(t *testing.T) { } func TestAccountDeepCopy(t *testing.T) { - td := setup(t) + td := setup(t, nil) num := td.RandInt32(1000) acc1, addr := td.GenerateTestAccount(num) @@ -140,5 +140,9 @@ func TestAccountDeepCopy(t *testing.T) { acc2, _ := td.store.Account(addr) acc2.AddToBalance(1) - assert.NotEqual(t, td.store.accountStore.addressMap[addr].Hash(), acc2.Hash()) + accCache, _ := td.store.accountStore.accInCache.Get(addr) + assert.NotEqual(t, accCache.Hash(), acc2.Hash()) + + expectedAcc, _ := td.store.accountStore.accInCache.Get(addr) + assert.NotEqual(t, expectedAcc.Hash(), acc2.Hash()) } diff --git a/store/block.go b/store/block.go index c134cbe95..5bda93358 100644 --- a/store/block.go +++ b/store/block.go @@ -5,10 +5,12 @@ import ( "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/crypto/hash" + "github.com/pactus-project/pactus/sortition" "github.com/pactus-project/pactus/types/block" "github.com/pactus-project/pactus/util" "github.com/pactus-project/pactus/util/encoding" "github.com/pactus-project/pactus/util/logger" + "github.com/pactus-project/pactus/util/pairslice" "github.com/syndtr/goleveldb/leveldb" ) @@ -22,12 +24,16 @@ func blockHashKey(h hash.Hash) []byte { } type blockStore struct { - db *leveldb.DB + db *leveldb.DB + sortitionSeedCache *pairslice.PairSlice[uint32, *sortition.VerifiableSeed] + sortitionCacheSize uint32 } -func newBlockStore(db *leveldb.DB) *blockStore { +func newBlockStore(db *leveldb.DB, sortitionCacheSize uint32) *blockStore { return &blockStore{ - db: db, + db: db, + sortitionSeedCache: pairslice.New[uint32, *sortition.VerifiableSeed](int(sortitionCacheSize)), + sortitionCacheSize: sortitionCacheSize, } } @@ -93,6 +99,9 @@ func (bs *blockStore) saveBlock(batch *leveldb.Batch, height uint32, blk *block. batch.Put(blockKey, w.Bytes()) batch.Put(blockHashKey, util.Uint32ToSlice(height)) + sortitionSeed := blk.Header().SortitionSeed() + bs.saveToCache(height, sortitionSeed) + return regs } @@ -113,18 +122,33 @@ func (bs *blockStore) blockHeight(h hash.Hash) uint32 { return util.SliceToUint32(data) } -func (bs *blockStore) hasBlock(height uint32) bool { - has, err := bs.db.Has(blockKey(height), nil) - if err != nil { - return false +func (bs *blockStore) sortitionSeed(blockHeight uint32) *sortition.VerifiableSeed { + startHeight, _, _ := bs.sortitionSeedCache.First() + + if blockHeight < startHeight { + return nil } - return has + + index := blockHeight - startHeight + _, sortitionSeed, ok := bs.sortitionSeedCache.Get(int(index)) + if !ok { + return nil + } + + return sortitionSeed +} + +func (bs *blockStore) hasBlock(height uint32) bool { + return tryHas(bs.db, blockKey(height)) } func (bs *blockStore) hasPublicKey(addr crypto.Address) bool { - has, err := bs.db.Has(publicKeyKey(addr), nil) - if err != nil { - return false + return tryHas(bs.db, publicKeyKey(addr)) +} + +func (bs *blockStore) saveToCache(blockHeight uint32, sortitionSeed sortition.VerifiableSeed) { + bs.sortitionSeedCache.Append(blockHeight, &sortitionSeed) + if bs.sortitionSeedCache.Len() > int(bs.sortitionCacheSize) { + bs.sortitionSeedCache.RemoveFirst() } - return has } diff --git a/store/block_test.go b/store/block_test.go index f8469d77d..d5a0c2e1a 100644 --- a/store/block_test.go +++ b/store/block_test.go @@ -8,12 +8,12 @@ import ( ) func TestBlockStore(t *testing.T) { - td := setup(t) + td := setup(t, nil) lastCert := td.store.LastCertificate() lastHeight := lastCert.Height() - nextBlk, nextCrert := td.GenerateTestBlock(lastHeight + 1) - nextNextBlk, nextNextCrert := td.GenerateTestBlock(lastHeight + 2) + nextBlk, nextCert := td.GenerateTestBlock(lastHeight + 1) + nextNextBlk, nextNextCert := td.GenerateTestBlock(lastHeight + 2) t.Run("Missed block, Should panic ", func(t *testing.T) { defer func() { @@ -21,18 +21,18 @@ func TestBlockStore(t *testing.T) { t.Errorf("The code did not panic") } }() - td.store.SaveBlock(nextNextBlk, nextNextCrert) + td.store.SaveBlock(nextNextBlk, nextNextCert) }) t.Run("Add block, don't batch write", func(t *testing.T) { - td.store.SaveBlock(nextBlk, nextCrert) + td.store.SaveBlock(nextBlk, nextCert) b2, err := td.store.Block(lastHeight + 1) assert.Error(t, err) assert.Nil(t, b2) }) t.Run("Add block, batch write", func(t *testing.T) { - td.store.SaveBlock(nextBlk, nextCrert) + td.store.SaveBlock(nextBlk, nextCert) assert.NoError(t, td.store.WriteBatch()) committedBlock, err := td.store.Block(lastHeight + 1) @@ -44,7 +44,7 @@ func TestBlockStore(t *testing.T) { cert := td.store.LastCertificate() assert.NoError(t, err) - assert.Equal(t, cert.Hash(), nextCrert.Hash()) + assert.Equal(t, cert.Hash(), nextCert.Hash()) }) t.Run("Duplicated block, Should panic ", func(t *testing.T) { @@ -53,6 +53,36 @@ func TestBlockStore(t *testing.T) { t.Errorf("The code did not panic") } }() - td.store.SaveBlock(nextBlk, nextCrert) + td.store.SaveBlock(nextBlk, nextCert) + }) +} + +func TestSortitionSeed(t *testing.T) { + conf := testConfig() + conf.SortitionCacheSize = 7 + + td := setup(t, conf) + lastHeight := td.store.LastCertificate().Height() + + t.Run("Test height zero", func(t *testing.T) { + assert.Nil(t, td.store.SortitionSeed(0)) + }) + + t.Run("Test non existing height", func(t *testing.T) { + assert.Nil(t, td.store.SortitionSeed(lastHeight+1)) + }) + + t.Run("Test not cached height", func(t *testing.T) { + assert.Nil(t, td.store.SortitionSeed(3)) + }) + + t.Run("OK", func(t *testing.T) { + rndInt := td.RandUint32(conf.SortitionCacheSize) + rndInt += lastHeight - conf.SortitionCacheSize + + committedBlk, _ := td.store.Block(rndInt) + blk, _ := committedBlk.ToBlock() + expectedSortition := blk.Header().SortitionSeed() + assert.Equal(t, &expectedSortition, td.store.SortitionSeed(rndInt)) }) } diff --git a/store/config.go b/store/config.go index 20ec421ed..df4c09c64 100644 --- a/store/config.go +++ b/store/config.go @@ -10,11 +10,21 @@ import ( type Config struct { Path string `toml:"path"` + + // Private configs + TxCacheSize uint32 `toml:"-"` + SortitionCacheSize uint32 `toml:"-"` + AccountCacheSize int `toml:"-"` + PublicKeyCacheSize int `toml:"-"` } func DefaultConfig() *Config { return &Config{ - Path: "data", + Path: "data", + TxCacheSize: 0, + SortitionCacheSize: 0, + AccountCacheSize: 0, + PublicKeyCacheSize: 0, } } diff --git a/store/interface.go b/store/interface.go index 830a364a4..ee2f36b2a 100644 --- a/store/interface.go +++ b/store/interface.go @@ -4,6 +4,7 @@ import ( "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/crypto/bls" "github.com/pactus-project/pactus/crypto/hash" + "github.com/pactus-project/pactus/sortition" "github.com/pactus-project/pactus/types/account" "github.com/pactus-project/pactus/types/block" "github.com/pactus-project/pactus/types/certificate" @@ -81,6 +82,7 @@ type Reader interface { Block(height uint32) (*CommittedBlock, error) BlockHeight(h hash.Hash) uint32 BlockHash(height uint32) hash.Hash + SortitionSeed(blockHeight uint32) *sortition.VerifiableSeed Transaction(id tx.ID) (*CommittedTx, error) AnyRecentTransaction(id tx.ID) bool PublicKey(addr crypto.Address) (*bls.PublicKey, error) diff --git a/store/mock.go b/store/mock.go index 3114060be..53e54f7cd 100644 --- a/store/mock.go +++ b/store/mock.go @@ -6,6 +6,7 @@ import ( "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/crypto/bls" "github.com/pactus-project/pactus/crypto/hash" + "github.com/pactus-project/pactus/sortition" "github.com/pactus-project/pactus/types/account" "github.com/pactus-project/pactus/types/block" "github.com/pactus-project/pactus/types/certificate" @@ -65,6 +66,14 @@ func (m *MockStore) BlockHeight(h hash.Hash) uint32 { return 0 } +func (m *MockStore) SortitionSeed(blockHeight uint32) *sortition.VerifiableSeed { + if blk, ok := m.Blocks[blockHeight]; ok { + sortitionSeed := blk.Header().SortitionSeed() + return &sortitionSeed + } + return nil +} + func (m *MockStore) PublicKey(addr crypto.Address) (*bls.PublicKey, error) { for _, block := range m.Blocks { for _, trx := range block.Transactions() { diff --git a/store/store.go b/store/store.go index c0847bdef..27b4d9313 100644 --- a/store/store.go +++ b/store/store.go @@ -5,9 +5,11 @@ import ( "errors" "sync" + lru "github.com/hashicorp/golang-lru/v2" "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/crypto/bls" "github.com/pactus-project/pactus/crypto/hash" + "github.com/pactus-project/pactus/sortition" "github.com/pactus-project/pactus/types/account" "github.com/pactus-project/pactus/types/block" "github.com/pactus-project/pactus/types/certificate" @@ -25,9 +27,9 @@ var ( ErrBadOffset = errors.New("offset is out of range") ) -const lastStoreVersion = int32(1) - -// TODO: add cache for me +const ( + lastStoreVersion = int32(1) +) var ( lastInfoKey = []byte{0x00} @@ -43,16 +45,26 @@ func tryGet(db *leveldb.DB, key []byte) ([]byte, error) { data, err := db.Get(key, nil) if err != nil { // Probably key doesn't exist in database - logger.Trace("database error", "error", err, "key", key) + logger.Trace("database `get` error", "error", err, "key", key) return nil, err } return data, nil } +func tryHas(db *leveldb.DB, key []byte) bool { + ok, err := db.Has(key, nil) + if err != nil { + logger.Error("database `has` error", "error", err, "key", key) + return false + } + return ok +} + type store struct { lk sync.RWMutex config *Config + pubKeyCache *lru.Cache[crypto.Address, *bls.PublicKey] db *leveldb.DB batch *leveldb.Batch blockStore *blockStore @@ -66,6 +78,12 @@ func NewStore(conf *Config) (Store, error) { Strict: opt.DefaultStrict, Compression: opt.NoCompression, } + + pubKeyCache, err := lru.New[crypto.Address, *bls.PublicKey](conf.PublicKeyCacheSize) + if err != nil { + return nil, err + } + db, err := leveldb.OpenFile(conf.StorePath(), options) if err != nil { return nil, err @@ -73,12 +91,44 @@ func NewStore(conf *Config) (Store, error) { s := &store{ config: conf, db: db, + pubKeyCache: pubKeyCache, batch: new(leveldb.Batch), - blockStore: newBlockStore(db), - txStore: newTxStore(db), - accountStore: newAccountStore(db), + blockStore: newBlockStore(db, conf.SortitionCacheSize), + txStore: newTxStore(db, conf.TxCacheSize), + accountStore: newAccountStore(db, conf.AccountCacheSize), validatorStore: newValidatorStore(db), } + + lc := s.LastCertificate() + if lc == nil { + return s, nil + } + + currentHeight := lc.Height() + startHeight := uint32(1) + if currentHeight > conf.TxCacheSize { + startHeight = currentHeight - conf.TxCacheSize + } + + for i := startHeight; i < currentHeight+1; i++ { + committedBlock, err := s.Block(i) + if err != nil { + return nil, err + } + blk, err := committedBlock.ToBlock() + if err != nil { + return nil, err + } + + txs := blk.Transactions() + for _, transaction := range txs { + s.txStore.saveToCache(transaction.ID(), i) + } + + sortitionSeed := blk.Header().SortitionSeed() + s.blockStore.saveToCache(i, sortitionSeed) + } + return s, nil } @@ -94,10 +144,9 @@ func (s *store) SaveBlock(blk *block.Block, cert *certificate.Certificate) { defer s.lk.Unlock() height := cert.Height() - reg := s.blockStore.saveBlock(s.batch, height, blk) - for i, trx := range blk.Transactions() { - s.txStore.saveTx(s.batch, trx.ID(), ®[i]) - } + regs := s.blockStore.saveBlock(s.batch, height, blk) + s.txStore.saveTxs(s.batch, blk.Transactions(), regs) + s.txStore.pruneCache(height) // Save last certificate: [version: 4 bytes]+[certificate: variant] w := bytes.NewBuffer(make([]byte, 0, 4+cert.SerializeSize())) @@ -155,7 +204,18 @@ func (s *store) BlockHash(height uint32) hash.Hash { return hash.UndefHash } +func (s *store) SortitionSeed(blockHeight uint32) *sortition.VerifiableSeed { + s.lk.Lock() + defer s.lk.Unlock() + + return s.blockStore.sortitionSeed(blockHeight) +} + func (s *store) PublicKey(addr crypto.Address) (*bls.PublicKey, error) { + if pubKey, ok := s.pubKeyCache.Get(addr); ok { + return pubKey, nil + } + bs, err := tryGet(s.db, publicKeyKey(addr)) if err != nil { return nil, err @@ -165,6 +225,7 @@ func (s *store) PublicKey(addr crypto.Address) (*bls.PublicKey, error) { return nil, err } + s.pubKeyCache.Add(addr, pubKey) return pubKey, err } @@ -196,14 +257,11 @@ func (s *store) Transaction(id tx.ID) (*CommittedTx, error) { }, nil } -// TODO implement Dequeue for this function, for the better performance. func (s *store) AnyRecentTransaction(id tx.ID) bool { s.lk.Lock() defer s.lk.Unlock() - pos, _ := s.txStore.tx(id) - - return pos != nil + return s.txStore.hasTX(id) } func (s *store) HasAccount(addr crypto.Address) bool { diff --git a/store/store_test.go b/store/store_test.go index 91b9d5a2d..52cfeb886 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -19,15 +19,26 @@ type testData struct { store *store } -func setup(t *testing.T) *testData { +func testConfig() *Config { + return &Config{ + Path: util.TempDirPath(), + TxCacheSize: 1024, + SortitionCacheSize: 1024, + AccountCacheSize: 1024, + PublicKeyCacheSize: 1024, + } +} + +func setup(t *testing.T, config *Config) *testData { t.Helper() ts := testsuite.NewTestSuite(t) - conf := &Config{ - Path: util.TempDirPath(), + if config == nil { + config = testConfig() } - s, err := NewStore(conf) + + s, err := NewStore(config) require.NoError(t, err) td := &testData{ @@ -38,7 +49,6 @@ func setup(t *testing.T) *testData { // Save 10 blocks for height := uint32(0); height < 10; height++ { blk, cert := td.GenerateTestBlock(height + 1) - td.store.SaveBlock(blk, cert) assert.NoError(t, td.store.WriteBatch()) } @@ -46,7 +56,7 @@ func setup(t *testing.T) *testData { } func TestBlockHash(t *testing.T) { - td := setup(t) + td := setup(t, nil) sb, _ := td.store.Block(1) @@ -56,7 +66,7 @@ func TestBlockHash(t *testing.T) { } func TestBlockHeight(t *testing.T) { - td := setup(t) + td := setup(t, nil) sb, _ := td.store.Block(1) @@ -66,7 +76,7 @@ func TestBlockHeight(t *testing.T) { } func TestUnknownTransactionID(t *testing.T) { - td := setup(t) + td := setup(t, nil) trx, err := td.store.Transaction(td.RandHash()) assert.Error(t, err) @@ -74,7 +84,7 @@ func TestUnknownTransactionID(t *testing.T) { } func TestWriteAndClosePeacefully(t *testing.T) { - td := setup(t) + td := setup(t, nil) // After closing db, we should not crash assert.NoError(t, td.store.Close()) @@ -82,7 +92,7 @@ func TestWriteAndClosePeacefully(t *testing.T) { } func TestRetrieveBlockAndTransactions(t *testing.T) { - td := setup(t) + td := setup(t, nil) lastCert := td.store.LastCertificate() lastHeight := lastCert.Height() @@ -99,20 +109,23 @@ func TestRetrieveBlockAndTransactions(t *testing.T) { assert.Equal(t, trx.ID(), committedTx.TxID) assert.Equal(t, lastHeight, committedTx.Height) trx2, _ := committedTx.ToTx() - assert.Equal(t, trx2.ID(), trx.ID()) + assert.Equal(t, trx.ID(), trx2.ID()) } } func TestIndexingPublicKeys(t *testing.T) { - td := setup(t) + td := setup(t, nil) committedBlock, _ := td.store.Block(1) blk, _ := committedBlock.ToBlock() for _, trx := range blk.Transactions() { addr := trx.Payload().Signer() pub, found := td.store.PublicKey(addr) + pubKeyLruCache, ok := td.store.pubKeyCache.Get(addr) assert.NoError(t, found) + assert.True(t, ok) + assert.Equal(t, pub, pubKeyLruCache) if addr.IsAccountAddress() { assert.Equal(t, pub.AccountAddress(), addr) @@ -121,13 +134,18 @@ func TestIndexingPublicKeys(t *testing.T) { } } - pub, found := td.store.PublicKey(td.RandValAddress()) + randValAddress := td.RandValAddress() + pub, found := td.store.PublicKey(randValAddress) + pubKeyLruCache, ok := td.store.pubKeyCache.Get(randValAddress) + assert.Error(t, found) assert.Nil(t, pub) + assert.False(t, ok) + assert.Nil(t, pubKeyLruCache) } func TestStrippedPublicKey(t *testing.T) { - td := setup(t) + td := setup(t, nil) // Find a public key that we have already indexed in the database. committedBlock1, _ := td.store.Block(1) diff --git a/store/tx.go b/store/tx.go index 89b61ee0a..16e28cbe4 100644 --- a/store/tx.go +++ b/store/tx.go @@ -3,8 +3,10 @@ package store import ( "bytes" + "github.com/pactus-project/pactus/types/block" "github.com/pactus-project/pactus/types/tx" "github.com/pactus-project/pactus/util/encoding" + "github.com/pactus-project/pactus/util/linkedmap" "github.com/syndtr/goleveldb/leveldb" ) @@ -17,24 +19,50 @@ type blockRegion struct { func txKey(id tx.ID) []byte { return append(txPrefix, id.Bytes()...) } type txStore struct { - db *leveldb.DB + db *leveldb.DB + txIDCache *linkedmap.LinkedMap[tx.ID, uint32] + txCacheSize uint32 } -func newTxStore(db *leveldb.DB) *txStore { +func newTxStore(db *leveldb.DB, txCacheSize uint32) *txStore { return &txStore{ - db: db, + db: db, + txIDCache: linkedmap.New[tx.ID, uint32](0), + txCacheSize: txCacheSize, } } -func (ts *txStore) saveTx(batch *leveldb.Batch, id tx.ID, reg *blockRegion) { - w := bytes.NewBuffer(make([]byte, 0, 32+4)) - err := encoding.WriteElements(w, ®.height, ®.offset, ®.length) - if err != nil { - panic(err) +func (ts *txStore) saveTxs(batch *leveldb.Batch, txs block.Txs, regs []blockRegion) { + for i, trx := range txs { + w := bytes.NewBuffer(make([]byte, 0, 32+4)) + + reg := regs[i] + err := encoding.WriteElements(w, ®.height, ®.offset, ®.length) + if err != nil { + panic(err) + } + + id := trx.ID() + key := txKey(id) + batch.Put(key, w.Bytes()) + ts.saveToCache(id, reg.height) } +} + +func (ts *txStore) pruneCache(currentHeight uint32) { + for { + head := ts.txIDCache.HeadNode() + txHeight := head.Data.Value - txKey := txKey(id) - batch.Put(txKey, w.Bytes()) + if currentHeight-txHeight <= ts.txCacheSize { + break + } + ts.txIDCache.RemoveHead() + } +} + +func (ts *txStore) hasTX(id tx.ID) bool { + return ts.txIDCache.Has(id) } func (ts *txStore) tx(id tx.ID) (*blockRegion, error) { @@ -44,9 +72,12 @@ func (ts *txStore) tx(id tx.ID) (*blockRegion, error) { } r := bytes.NewReader(data) reg := new(blockRegion) - err = encoding.ReadElements(r, ®.height, ®.offset, ®.length) - if err != nil { + if err := encoding.ReadElements(r, ®.height, ®.offset, ®.length); err != nil { return nil, err } return reg, nil } + +func (ts *txStore) saveToCache(id tx.ID, height uint32) { + ts.txIDCache.PushBack(id, height) +} diff --git a/store/validator_test.go b/store/validator_test.go index aab47d47c..9e823416f 100644 --- a/store/validator_test.go +++ b/store/validator_test.go @@ -11,7 +11,7 @@ import ( ) func TestValidatorCounter(t *testing.T) { - td := setup(t) + td := setup(t, nil) num := td.RandInt32(1000) val, _ := td.GenerateTestValidator(num) @@ -46,7 +46,7 @@ func TestValidatorCounter(t *testing.T) { } func TestValidatorBatchSaving(t *testing.T) { - td := setup(t) + td := setup(t, nil) total := td.RandInt32NonZero(100) t.Run("Add some validators", func(t *testing.T) { @@ -66,7 +66,7 @@ func TestValidatorBatchSaving(t *testing.T) { } func TestValidatorAddresses(t *testing.T) { - td := setup(t) + td := setup(t, nil) total := td.RandInt32NonZero(100) addrs1 := make([]crypto.Address, 0, total) @@ -82,7 +82,7 @@ func TestValidatorAddresses(t *testing.T) { } func TestValidatorByNumber(t *testing.T) { - td := setup(t) + td := setup(t, nil) total := td.RandInt32NonZero(100) t.Run("Add some validators", func(t *testing.T) { @@ -131,7 +131,7 @@ func TestValidatorByNumber(t *testing.T) { } func TestValidatorByAddress(t *testing.T) { - td := setup(t) + td := setup(t, nil) total := td.RandInt32NonZero(100) t.Run("Add some validators", func(t *testing.T) { @@ -172,7 +172,7 @@ func TestValidatorByAddress(t *testing.T) { } func TestIterateValidators(t *testing.T) { - td := setup(t) + td := setup(t, nil) total := td.RandInt32NonZero(100) hashes1 := []hash.Hash{} @@ -201,7 +201,7 @@ func TestIterateValidators(t *testing.T) { } func TestValidatorDeepCopy(t *testing.T) { - td := setup(t) + td := setup(t, nil) num := td.RandInt32NonZero(1000) val1, _ := td.GenerateTestValidator(num) diff --git a/tests/main_test.go b/tests/main_test.go index f33ed244f..71d5596c3 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -67,7 +67,7 @@ func TestMain(m *testing.M) { tValKeys[i][0] = bls.NewValidatorKey(key0) tValKeys[i][1] = bls.NewValidatorKey(key1) tValKeys[i][2] = bls.NewValidatorKey(key2) - tConfigs[i] = config.DefaultConfigMainnet() + tConfigs[i] = config.DefaultConfigMainnet(param.DefaultParams()) tConfigs[i].Store.Path = util.TempDirPath() tConfigs[i].Consensus.ChangeProposerTimeout = 4 * time.Second @@ -173,7 +173,7 @@ func TestMain(m *testing.M) { panic("Sortition didn't work") } - // Let's shutdown the nodes + // Lets shutdown the nodes tCtx.Done() for i := 0; i < tTotalNodes; i++ { tNodes[i].Stop() diff --git a/txpool/pool.go b/txpool/pool.go index 6f63154d4..79d159d43 100644 --- a/txpool/pool.go +++ b/txpool/pool.go @@ -29,11 +29,11 @@ type txPool struct { func NewTxPool(conf *Config, broadcastCh chan message.Message) TxPool { pending := make(map[payload.Type]*linkedmap.LinkedMap[tx.ID, *tx.Tx]) - pending[payload.TypeTransfer] = linkedmap.NewLinkedMap[tx.ID, *tx.Tx](conf.sendPoolSize()) - pending[payload.TypeBond] = linkedmap.NewLinkedMap[tx.ID, *tx.Tx](conf.bondPoolSize()) - pending[payload.TypeUnbond] = linkedmap.NewLinkedMap[tx.ID, *tx.Tx](conf.unbondPoolSize()) - pending[payload.TypeWithdraw] = linkedmap.NewLinkedMap[tx.ID, *tx.Tx](conf.withdrawPoolSize()) - pending[payload.TypeSortition] = linkedmap.NewLinkedMap[tx.ID, *tx.Tx](conf.sortitionPoolSize()) + pending[payload.TypeTransfer] = linkedmap.New[tx.ID, *tx.Tx](conf.sendPoolSize()) + pending[payload.TypeBond] = linkedmap.New[tx.ID, *tx.Tx](conf.bondPoolSize()) + pending[payload.TypeUnbond] = linkedmap.New[tx.ID, *tx.Tx](conf.unbondPoolSize()) + pending[payload.TypeWithdraw] = linkedmap.New[tx.ID, *tx.Tx](conf.withdrawPoolSize()) + pending[payload.TypeSortition] = linkedmap.New[tx.ID, *tx.Tx](conf.sortitionPoolSize()) pool := &txPool{ config: conf, diff --git a/util/encoding/encoding.go b/util/encoding/encoding.go index 8399435d6..07e2f461f 100644 --- a/util/encoding/encoding.go +++ b/util/encoding/encoding.go @@ -12,7 +12,7 @@ import ( const ( // MaxPayloadSize is the maximum bytes a message can be regardless of other // individual limits imposed by messages themselves. - MaxPayloadSize = (1024 * 1024 * 32) // 32MB + MaxPayloadSize = 1024 * 1024 * 32 // 32MB // binaryFreeListMaxItems is the number of buffers to keep in the free // list to use for binary serialization and deserialization. binaryFreeListMaxItems = 1024 @@ -157,7 +157,7 @@ func (l binaryFreeList) PutUint64(w io.Writer, val uint64) error { // deserializing primitive integer values to and from io.Readers and io.Writers. var binarySerializer binaryFreeList = make(chan []byte, binaryFreeListMaxItems) -// readElement reads the next sequence of bytes from r using little endian +// ReadElement reads the next sequence of bytes from r using little endian // depending on the concrete type of element pointed to. func ReadElement(r io.Reader, element interface{}) error { // Attempt to read the element based on the concrete type via fast @@ -207,7 +207,7 @@ func ReadElement(r io.Reader, element interface{}) error { return err } -// readElements reads multiple items from r. It is equivalent to multiple +// ReadElements reads multiple items from r. It is equivalent to multiple // calls to readElement. func ReadElements(r io.Reader, elements ...interface{}) error { for _, element := range elements { @@ -219,7 +219,7 @@ func ReadElements(r io.Reader, elements ...interface{}) error { return nil } -// writeElement writes the little endian representation of element to w. +// WriteElement writes the little endian representation of element to w. func WriteElement(w io.Writer, element interface{}) error { // Attempt to write the element based on the concrete type via fast // type assertions first. @@ -258,7 +258,7 @@ func WriteElement(w io.Writer, element interface{}) error { return err } -// writeElements writes multiple items to w. It is equivalent to multiple +// WriteElements writes multiple items to w. It is equivalent to multiple // calls to writeElement. func WriteElements(w io.Writer, elements ...interface{}) error { for _, element := range elements { diff --git a/util/linkedmap/linkedmap.go b/util/linkedmap/linkedmap.go index fcfa09e70..bd0cdd42e 100644 --- a/util/linkedmap/linkedmap.go +++ b/util/linkedmap/linkedmap.go @@ -15,8 +15,8 @@ type LinkedMap[K comparable, V any] struct { capacity int } -// NewLinkedMap creates a new LinkedMap with the specified capacity. -func NewLinkedMap[K comparable, V any](capacity int) *LinkedMap[K, V] { +// New creates a new LinkedMap with the specified capacity. +func New[K comparable, V any](capacity int) *LinkedMap[K, V] { return &LinkedMap[K, V]{ list: ll.New[Pair[K, V]](), hashmap: make(map[K]*ll.Element[Pair[K, V]]), @@ -87,6 +87,10 @@ func (lm *LinkedMap[K, V]) TailNode() *ll.Element[Pair[K, V]] { return ln } +func (lm *LinkedMap[K, V]) RemoveTail() { + lm.remove(lm.list.Tail, lm.list.Tail.Data.Key) +} + // HeadNode returns the LinkNode at the beginning (head) of the LinkedMap. func (lm *LinkedMap[K, V]) HeadNode() *ll.Element[Pair[K, V]] { ln := lm.list.Head @@ -96,6 +100,10 @@ func (lm *LinkedMap[K, V]) HeadNode() *ll.Element[Pair[K, V]] { return ln } +func (lm *LinkedMap[K, V]) RemoveHead() { + lm.remove(lm.list.Head, lm.list.Head.Data.Key) +} + // Remove removes the key-value pair with the specified key from the LinkedMap. // It returns true if the key was found and removed, otherwise false. func (lm *LinkedMap[K, V]) Remove(key K) bool { @@ -107,6 +115,12 @@ func (lm *LinkedMap[K, V]) Remove(key K) bool { return found } +// remove removes the key-value pair with the specified key from the LinkedMap and linkedlist.LinkedList. +func (lm *LinkedMap[K, V]) remove(element *ll.Element[Pair[K, V]], key K) { + lm.list.Delete(element) + delete(lm.hashmap, key) +} + // Empty checks if the LinkedMap is empty (contains no key-value pairs). func (lm *LinkedMap[K, V]) Empty() bool { return lm.Size() == 0 @@ -135,10 +149,14 @@ func (lm *LinkedMap[K, V]) Clear() { // prune removes excess elements from the LinkedMap if its size exceeds the capacity. func (lm *LinkedMap[K, V]) prune() { + if lm.capacity == 0 { + return + } + for lm.list.Length() > lm.capacity { - front := lm.list.Head - key := front.Data.Key - lm.list.Delete(front) + head := lm.list.Head + key := head.Data.Key + lm.list.Delete(head) delete(lm.hashmap, key) } } diff --git a/util/linkedmap/linkedmap_test.go b/util/linkedmap/linkedmap_test.go index 48f68956b..887d002c6 100644 --- a/util/linkedmap/linkedmap_test.go +++ b/util/linkedmap/linkedmap_test.go @@ -8,7 +8,7 @@ import ( func TestLinkedMap(t *testing.T) { t.Run("Test FirstNode", func(t *testing.T) { - lm := NewLinkedMap[int, string](4) + lm := New[int, string](4) assert.Nil(t, lm.HeadNode()) lm.PushFront(3, "c") @@ -20,7 +20,7 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Test LastNode", func(t *testing.T) { - lm := NewLinkedMap[int, string](4) + lm := New[int, string](4) assert.Nil(t, lm.TailNode()) lm.PushBack(1, "a") @@ -32,7 +32,7 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Test Get", func(t *testing.T) { - lm := NewLinkedMap[int, string](4) + lm := New[int, string](4) lm.PushBack(2, "b") lm.PushBack(1, "a") @@ -46,7 +46,7 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Test Remove", func(t *testing.T) { - lm := NewLinkedMap[int, string](4) + lm := New[int, string](4) lm.PushBack(0, "-") lm.PushBack(2, "b") @@ -55,8 +55,30 @@ func TestLinkedMap(t *testing.T) { assert.False(t, lm.Remove(2)) }) + t.Run("Test RemoveTail", func(t *testing.T) { + lm := New[int, string](4) + lm.PushBack(0, "-") + lm.PushBack(1, "a") + lm.PushBack(2, "b") + + lm.RemoveTail() + assert.Equal(t, lm.TailNode().Data.Value, "a") + assert.NotEqual(t, lm.TailNode().Data.Value, "b") + }) + + t.Run("Test RemoveHead", func(t *testing.T) { + lm := New[int, string](4) + lm.PushBack(0, "-") + lm.PushBack(1, "a") + lm.PushBack(2, "b") + + lm.RemoveHead() + assert.Equal(t, lm.HeadNode().Data.Value, "a") + assert.NotEqual(t, lm.HeadNode().Data.Value, "-") + }) + t.Run("Should updates v", func(t *testing.T) { - lm := NewLinkedMap[int, string](4) + lm := New[int, string](4) lm.PushBack(1, "a") lm.PushBack(1, "b") @@ -71,7 +93,7 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Should prunes oldest item", func(t *testing.T) { - lm := NewLinkedMap[int, string](4) + lm := New[int, string](4) lm.PushBack(1, "a") lm.PushBack(2, "b") @@ -89,7 +111,7 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Should prunes by changing capacity", func(t *testing.T) { - lm := NewLinkedMap[int, string](4) + lm := New[int, string](4) lm.PushBack(1, "a") lm.PushBack(2, "b") @@ -110,7 +132,7 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Test PushBack and prune", func(t *testing.T) { - lm := NewLinkedMap[int, string](3) + lm := New[int, string](3) lm.PushBack(1, "a") // This item should be pruned lm.PushBack(2, "b") @@ -123,7 +145,7 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Test PushFront and prune", func(t *testing.T) { - lm := NewLinkedMap[int, string](3) + lm := New[int, string](3) lm.PushFront(1, "a") lm.PushFront(2, "b") @@ -136,7 +158,7 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Delete first ", func(t *testing.T) { - lm := NewLinkedMap[int, string](3) + lm := New[int, string](3) lm.PushBack(1, "a") lm.PushBack(2, "b") @@ -149,7 +171,7 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Delete last", func(t *testing.T) { - lm := NewLinkedMap[int, string](3) + lm := New[int, string](3) lm.PushBack(1, "a") lm.PushBack(2, "b") @@ -162,7 +184,7 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Test Has function", func(t *testing.T) { - lm := NewLinkedMap[int, string](2) + lm := New[int, string](2) lm.PushBack(1, "a") @@ -171,7 +193,7 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Test Clear", func(t *testing.T) { - lm := NewLinkedMap[int, string](2) + lm := New[int, string](2) lm.PushBack(1, "a") lm.Clear() @@ -180,7 +202,202 @@ func TestLinkedMap(t *testing.T) { } func TestCapacity(t *testing.T) { - capacity := 100 - lm := NewLinkedMap[int, string](capacity) - assert.Equal(t, lm.Capacity(), capacity) + t.Run("Check Capacity", func(t *testing.T) { + capacity := 100 + lm := New[int, string](capacity) + assert.Equal(t, lm.Capacity(), capacity) + }) + + t.Run("Test FirstNode with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + assert.Nil(t, lm.HeadNode()) + + lm.PushFront(3, "c") + lm.PushFront(2, "b") + lm.PushFront(1, "a") + + assert.Equal(t, lm.HeadNode().Data.Key, 1) + assert.Equal(t, lm.HeadNode().Data.Value, "a") + }) + + t.Run("Test LastNode with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + assert.Nil(t, lm.TailNode()) + + lm.PushBack(1, "a") + lm.PushBack(2, "b") + lm.PushBack(3, "c") + + assert.Equal(t, lm.TailNode().Data.Key, 3) + assert.Equal(t, lm.TailNode().Data.Value, "c") + }) + + t.Run("Test Get with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + + lm.PushBack(2, "b") + lm.PushBack(1, "a") + + n := lm.GetNode(2) + assert.Equal(t, n.Data.Key, 2) + assert.Equal(t, n.Data.Value, "b") + + n = lm.GetNode(5) + assert.Nil(t, n) + }) + + t.Run("Test Remove with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + + lm.PushBack(0, "-") + lm.PushBack(2, "b") + lm.PushBack(1, "a") + assert.True(t, lm.Remove(2)) + assert.False(t, lm.Remove(2)) + }) + + t.Run("Test RemoveTail with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + lm.PushBack(0, "-") + lm.PushBack(1, "a") + lm.PushBack(2, "b") + + lm.RemoveTail() + assert.Equal(t, lm.TailNode().Data.Value, "a") + assert.NotEqual(t, lm.TailNode().Data.Value, "b") + }) + + t.Run("Test RemoveHead with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + lm.PushBack(0, "-") + lm.PushBack(1, "a") + lm.PushBack(2, "b") + + lm.RemoveHead() + assert.Equal(t, lm.HeadNode().Data.Value, "a") + assert.NotEqual(t, lm.HeadNode().Data.Value, "-") + }) + + t.Run("Should updates v with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + lm.PushBack(1, "a") + + lm.PushBack(1, "b") + n := lm.GetNode(1) + assert.Equal(t, n.Data.Key, 1) + assert.Equal(t, n.Data.Value, "b") + + lm.PushFront(1, "c") + n = lm.GetNode(1) + assert.Equal(t, n.Data.Key, 1) + assert.Equal(t, n.Data.Value, "c") + }) + + t.Run("Should not prunes oldest item with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + + lm.PushBack(1, "a") + lm.PushBack(2, "b") + lm.PushBack(3, "c") + lm.PushBack(4, "d") + + n := lm.GetNode(1) + assert.Equal(t, n.Data.Key, 1) + assert.Equal(t, n.Data.Value, "a") + + lm.PushBack(5, "e") + + n = lm.GetNode(1) + assert.NotNil(t, n) + }) + + t.Run("Should prunes by changing capacity with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + + lm.PushBack(1, "a") + lm.PushBack(2, "b") + lm.PushBack(3, "c") + lm.PushBack(4, "d") + + lm.SetCapacity(6) + + n := lm.GetNode(2) + assert.Equal(t, n.Data.Key, 2) + assert.Equal(t, n.Data.Value, "b") + + lm.SetCapacity(2) + assert.True(t, lm.Full()) + + n = lm.GetNode(2) + assert.Nil(t, n) + }) + + t.Run("Test PushBack and should not prune with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + + lm.PushBack(1, "a") // This item should be pruned + lm.PushBack(2, "b") + lm.PushBack(3, "c") + lm.PushBack(4, "d") + + n := lm.TailNode() + assert.Equal(t, n.Data.Key, 4) + assert.Equal(t, n.Data.Value, "d") + }) + + t.Run("Test PushFront and prune with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + + lm.PushFront(1, "a") + lm.PushFront(2, "b") + lm.PushFront(3, "c") + lm.PushFront(4, "d") // This item should be pruned + + n := lm.TailNode() + assert.Equal(t, n.Data.Key, 1) + assert.Equal(t, n.Data.Value, "a") + }) + + t.Run("Delete first with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + + lm.PushBack(1, "a") + lm.PushBack(2, "b") + lm.PushBack(3, "c") + + lm.Remove(1) + + assert.Equal(t, lm.HeadNode().Data.Key, 2) + assert.Equal(t, lm.HeadNode().Data.Value, "b") + }) + + t.Run("Delete last with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + + lm.PushBack(1, "a") + lm.PushBack(2, "b") + lm.PushBack(3, "c") + + lm.Remove(3) + + assert.Equal(t, lm.TailNode().Data.Key, 2) + assert.Equal(t, lm.TailNode().Data.Value, "b") + }) + + t.Run("Test Has function with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + + lm.PushBack(1, "a") + + assert.True(t, lm.Has(1)) + assert.False(t, lm.Has(2)) + }) + + t.Run("Test Clear with Zero Capacity", func(t *testing.T) { + lm := New[int, string](0) + + lm.PushBack(1, "a") + lm.Clear() + assert.True(t, lm.Empty()) + }) } diff --git a/util/pairslice/pairslice.go b/util/pairslice/pairslice.go new file mode 100644 index 000000000..65b176080 --- /dev/null +++ b/util/pairslice/pairslice.go @@ -0,0 +1,69 @@ +package pairslice + +import ( + "golang.org/x/exp/slices" +) + +// Pair represents a key-value pair. +type Pair[K comparable, V any] struct { + First K + Second V +} + +// PairSlice represents a slice of key-value pairs. +type PairSlice[K comparable, V any] struct { + pairs []*Pair[K, V] +} + +// New creates a new instance of PairSlice with a specified capacity. +func New[K comparable, V any](capacity int) *PairSlice[K, V] { + return &PairSlice[K, V]{ + pairs: make([]*Pair[K, V], 0, capacity), + } +} + +// Append adds the first and second to the end of the slice. +func (ps *PairSlice[K, V]) Append(first K, second V) { + ps.pairs = append(ps.pairs, &Pair[K, V]{first, second}) +} + +// RemoveFirst removes the first element from PairSlice. +func (ps *PairSlice[K, V]) RemoveFirst() { + ps.remove(0) +} + +// RemoveLast removes the last element from PairSlice. +func (ps *PairSlice[K, V]) RemoveLast() { + ps.remove(ps.Len() - 1) +} + +// Len returns the number of elements in the PairSlice. +func (ps *PairSlice[K, V]) Len() int { + return len(ps.pairs) +} + +// remove removes the element at the specified index from PairSlice. +func (ps *PairSlice[K, V]) remove(index int) { + ps.pairs = slices.Delete(ps.pairs, index, index+1) +} + +// Get returns the properties at the specified index. If the index is out of bounds, it returns false. +func (ps *PairSlice[K, V]) Get(index int) (K, V, bool) { + if index < 0 || index >= len(ps.pairs) { + var first K + var second V + return first, second, false + } + pair := ps.pairs[index] + return pair.First, pair.Second, true +} + +// First returns the first properties in the PairSlice. If the PairSlice is empty, it returns false. +func (ps *PairSlice[K, V]) First() (K, V, bool) { + return ps.Get(0) +} + +// Last returns the last properties in the PairSlice. If the PairSlice is empty, it returns false. +func (ps *PairSlice[K, V]) Last() (K, V, bool) { + return ps.Get(ps.Len() - 1) +} diff --git a/util/pairslice/pairslice_test.go b/util/pairslice/pairslice_test.go new file mode 100644 index 000000000..94820dce1 --- /dev/null +++ b/util/pairslice/pairslice_test.go @@ -0,0 +1,130 @@ +package pairslice + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + ps := New[int, string](10) + + assert.NotNil(t, ps) + assert.Equal(t, 10, cap(ps.pairs)) + assert.Equal(t, 0, len(ps.pairs)) +} + +func TestPairSlice(t *testing.T) { + t.Run("Test Append", func(t *testing.T) { + ps := New[int, string](4) + + ps.Append(1, "a") + ps.Append(2, "b") + ps.Append(3, "c") + + assert.Equal(t, 3, ps.Len()) + assert.Equal(t, ps.pairs[2].First, 3) + assert.Equal(t, ps.pairs[2].Second, "c") + }) + + t.Run("Test RemoveFirst", func(t *testing.T) { + ps := New[int, string](4) + + ps.Append(1, "a") + ps.Append(2, "b") + ps.Append(3, "c") + + ps.RemoveFirst() + + assert.Equal(t, ps.pairs[0].First, 2) + assert.Equal(t, ps.pairs[0].Second, "b") + }) + + t.Run("Test RemoveLast", func(t *testing.T) { + ps := New[int, string](4) + + ps.Append(1, "a") + ps.Append(2, "b") + ps.Append(3, "c") + ps.Append(4, "d") + + ps.RemoveLast() + + assert.Equal(t, ps.pairs[2].First, 3) + assert.Equal(t, ps.pairs[2].Second, "c") + }) + + t.Run("Test Len", func(t *testing.T) { + ps := New[int, string](4) + + ps.Append(1, "a") + ps.Append(2, "b") + + assert.Equal(t, 2, ps.Len()) + }) + + t.Run("Test Remove", func(t *testing.T) { + ps := New[int, string](4) + + ps.Append(1, "a") + ps.Append(2, "b") + ps.Append(3, "c") + ps.Append(4, "d") + + ps.remove(1) + + assert.Equal(t, ps.pairs[1].First, 3) + assert.Equal(t, ps.pairs[1].Second, "c") + }) + + t.Run("Test Get", func(t *testing.T) { + ps := New[int, string](4) + + ps.Append(1, "a") + ps.Append(2, "b") + ps.Append(3, "c") + ps.Append(4, "d") + + first, second, _ := ps.Get(2) + assert.Equal(t, ps.pairs[2].First, first) + assert.Equal(t, ps.pairs[2].Second, second) + }) + + t.Run("Test Get negative index or bigger than len", func(t *testing.T) { + ps := New[int, string](4) + + ps.Append(1, "a") + ps.Append(4, "d") + + _, _, result1 := ps.Get(-1) + _, _, result2 := ps.Get(10) + assert.False(t, result1) + assert.False(t, result2) + }) + + t.Run("Test First", func(t *testing.T) { + ps := New[int, string](4) + + ps.Append(1, "a") + ps.Append(2, "b") + ps.Append(3, "c") + ps.Append(4, "d") + + first, second, _ := ps.First() + assert.Equal(t, ps.pairs[0].First, first) + assert.Equal(t, ps.pairs[0].Second, second) + }) + + t.Run("Test Last", func(t *testing.T) { + ps := New[int, string](4) + + ps.Append(1, "a") + ps.Append(2, "b") + ps.Append(3, "c") + ps.Append(4, "d") + + first, second, _ := ps.Last() + assert.Equal(t, ps.pairs[3].First, first) + assert.Equal(t, ps.pairs[3].Second, second) + }) +}