From 914916f2322d71a4a53dfac827f4873b009a2d50 Mon Sep 17 00:00:00 2001 From: ragnarok87 Date: Tue, 13 Jun 2023 16:16:20 +0800 Subject: [PATCH 1/2] fix: updating linkedmap to use generic --- sandbox/sandbox.go | 24 ++--- store/store.go | 26 ++--- txpool/pool.go | 57 +++++------ util/encoding/encoding.go | 4 +- util/linkedmap/doublylink.go | 142 +++++++++++++++++++++++++++ util/linkedmap/doublylink_test.go | 114 ++++++++++++++++++++++ util/linkedmap/linkedmap.go | 144 +++++++++++----------------- util/linkedmap/linkedmap_test.go | 143 +++++++++++---------------- util/orderedmap/ordered_map.go | 78 --------------- util/orderedmap/ordered_map_test.go | 60 ------------ 10 files changed, 419 insertions(+), 373 deletions(-) create mode 100644 util/linkedmap/doublylink.go create mode 100644 util/linkedmap/doublylink_test.go delete mode 100644 util/orderedmap/ordered_map.go delete mode 100644 util/orderedmap/ordered_map_test.go diff --git a/sandbox/sandbox.go b/sandbox/sandbox.go index d5cbbca8b..b0c798ec3 100644 --- a/sandbox/sandbox.go +++ b/sandbox/sandbox.go @@ -74,8 +74,8 @@ func (sb *sandbox) shouldPanicForUnknownAddress() { } func (sb *sandbox) Account(addr crypto.Address) *account.Account { - sb.lk.RLock() - defer sb.lk.RUnlock() + sb.lk.Lock() + defer sb.lk.Unlock() s, ok := sb.accounts[addr] if ok { @@ -95,8 +95,8 @@ func (sb *sandbox) Account(addr crypto.Address) *account.Account { return acc } func (sb *sandbox) MakeNewAccount(addr crypto.Address) *account.Account { - sb.lk.RLock() - defer sb.lk.RUnlock() + sb.lk.Lock() + defer sb.lk.Unlock() if sb.store.HasAccount(addr) { sb.shouldPanicForDuplicatedAddress() @@ -112,8 +112,8 @@ func (sb *sandbox) MakeNewAccount(addr crypto.Address) *account.Account { } func (sb *sandbox) UpdateAccount(addr crypto.Address, acc *account.Account) { - sb.lk.RLock() - defer sb.lk.RUnlock() + sb.lk.Lock() + defer sb.lk.Unlock() s, ok := sb.accounts[addr] if !ok { @@ -124,8 +124,8 @@ func (sb *sandbox) UpdateAccount(addr crypto.Address, acc *account.Account) { } func (sb *sandbox) Validator(addr crypto.Address) *validator.Validator { - sb.lk.RLock() - defer sb.lk.RUnlock() + sb.lk.Lock() + defer sb.lk.Unlock() s, ok := sb.validators[addr] if ok { @@ -145,8 +145,8 @@ func (sb *sandbox) Validator(addr crypto.Address) *validator.Validator { } func (sb *sandbox) MakeNewValidator(pub *bls.PublicKey) *validator.Validator { - sb.lk.RLock() - defer sb.lk.RUnlock() + sb.lk.Lock() + defer sb.lk.Unlock() addr := pub.Address() if sb.store.HasValidator(addr) { @@ -163,8 +163,8 @@ func (sb *sandbox) MakeNewValidator(pub *bls.PublicKey) *validator.Validator { } func (sb *sandbox) UpdateValidator(val *validator.Validator) { - sb.lk.RLock() - defer sb.lk.RUnlock() + sb.lk.Lock() + defer sb.lk.Unlock() addr := val.Address() s, ok := sb.validators[addr] diff --git a/store/store.go b/store/store.go index f5015108f..07116c962 100644 --- a/store/store.go +++ b/store/store.go @@ -57,7 +57,7 @@ type store struct { txStore *txStore accountStore *accountStore validatorStore *validatorStore - stampLookup *linkedmap.LinkedMap + stampLookup *linkedmap.LinkedMap[hash.Stamp, hashPair] } func NewStore(conf *Config, stampLookupCapacity int) (Store, error) { @@ -78,7 +78,7 @@ func NewStore(conf *Config, stampLookupCapacity int) (Store, error) { txStore: newTxStore(db), accountStore: newAccountStore(db), validatorStore: newValidatorStore(db), - stampLookup: linkedmap.NewLinkedMap(stampLookupCapacity), + stampLookup: linkedmap.NewLinkedMap[hash.Stamp, hashPair](stampLookupCapacity), } lastHeight, _ := s.LastCertificate() @@ -99,7 +99,7 @@ func (s *store) Close() error { } func (s *store) appendStamp(hash hash.Hash, height uint32) { - pair := &hashPair{ + pair := hashPair{ Height: height, Hash: hash, } @@ -174,31 +174,31 @@ func (s *store) BlockHash(height uint32) hash.Hash { } func (s *store) FindBlockHashByStamp(stamp hash.Stamp) (hash.Hash, bool) { - s.lk.Lock() - defer s.lk.Unlock() + s.lk.RLock() + defer s.lk.RUnlock() if stamp.EqualsTo(hash.UndefHash.Stamp()) { return hash.UndefHash, true } - v, ok := s.stampLookup.Get(stamp) - if ok { - return v.(*hashPair).Hash, true + n := s.stampLookup.GetNode(stamp) + if n != nil { + return n.Data.Value.Hash, true } return hash.UndefHash, false } func (s *store) FindBlockHeightByStamp(stamp hash.Stamp) (uint32, bool) { - s.lk.Lock() - defer s.lk.Unlock() + s.lk.RLock() + defer s.lk.RUnlock() if stamp.EqualsTo(hash.UndefHash.Stamp()) { return 0, true } - v, ok := s.stampLookup.Get(stamp) - if ok { - return v.(*hashPair).Height, true + n := s.stampLookup.GetNode(stamp) + if n != nil { + return n.Data.Value.Height, true } return 0, false } diff --git a/txpool/pool.go b/txpool/pool.go index a5c48f23a..236853364 100644 --- a/txpool/pool.go +++ b/txpool/pool.go @@ -1,7 +1,6 @@ package txpool import ( - "container/list" "fmt" "sync" @@ -21,24 +20,24 @@ type txPool struct { config *Config checker *execution.Execution sandbox sandbox.Sandbox - pools map[payload.Type]*linkedmap.LinkedMap + pools map[payload.Type]*linkedmap.LinkedMap[tx.ID, *tx.Tx] broadcastCh chan message.Message logger *logger.Logger } func NewTxPool(conf *Config, broadcastCh chan message.Message) TxPool { - pendings := make(map[payload.Type]*linkedmap.LinkedMap) + pending := make(map[payload.Type]*linkedmap.LinkedMap[tx.ID, *tx.Tx]) - pendings[payload.PayloadTypeTransfer] = linkedmap.NewLinkedMap(conf.sendPoolSize()) - pendings[payload.PayloadTypeBond] = linkedmap.NewLinkedMap(conf.bondPoolSize()) - pendings[payload.PayloadTypeUnbond] = linkedmap.NewLinkedMap(conf.unbondPoolSize()) - pendings[payload.PayloadTypeWithdraw] = linkedmap.NewLinkedMap(conf.withdrawPoolSize()) - pendings[payload.PayloadTypeSortition] = linkedmap.NewLinkedMap(conf.sortitionPoolSize()) + pending[payload.PayloadTypeTransfer] = linkedmap.NewLinkedMap[tx.ID, *tx.Tx](conf.sendPoolSize()) + pending[payload.PayloadTypeBond] = linkedmap.NewLinkedMap[tx.ID, *tx.Tx](conf.bondPoolSize()) + pending[payload.PayloadTypeUnbond] = linkedmap.NewLinkedMap[tx.ID, *tx.Tx](conf.unbondPoolSize()) + pending[payload.PayloadTypeWithdraw] = linkedmap.NewLinkedMap[tx.ID, *tx.Tx](conf.withdrawPoolSize()) + pending[payload.PayloadTypeSortition] = linkedmap.NewLinkedMap[tx.ID, *tx.Tx](conf.sortitionPoolSize()) pool := &txPool{ config: conf, checker: execution.NewChecker(), - pools: pendings, + pools: pending, broadcastCh: broadcastCh, } @@ -53,11 +52,11 @@ func (p *txPool) SetNewSandboxAndRecheck(sb sandbox.Sandbox) { p.sandbox = sb p.logger.Debug("set new sandbox") - var next *list.Element + var next *linkedmap.LinkNode[linkedmap.Pair[tx.ID, *tx.Tx]] for _, pool := range p.pools { - for e := pool.FirstElement(); e != nil; e = next { - next = e.Next() - trx := e.Value.(*linkedmap.Pair).Second.(*tx.Tx) + for e := pool.FirstNode(); e != nil; e = next { + next = e.Next + trx := e.Data.Value if err := p.checkTx(trx); err != nil { p.logger.Debug("invalid transaction after rechecking", "id", trx.ID()) @@ -136,10 +135,9 @@ func (p *txPool) PendingTx(id tx.ID) *tx.Tx { defer p.lk.Unlock() for _, pool := range p.pools { - val, found := pool.Get(id) - if found { - trx := val.(*tx.Tx) - return trx + n := pool.GetNode(id) + if n != nil { + return n.Data.Value } } @@ -154,37 +152,32 @@ func (p *txPool) PrepareBlockTransactions() block.Txs { // Appending one sortition transaction poolSortition := p.pools[payload.PayloadTypeSortition] - for e := poolSortition.FirstElement(); e != nil; e = e.Next() { - trx := e.Value.(*linkedmap.Pair).Second.(*tx.Tx) - trxs = append(trxs, trx) + for n := poolSortition.FirstNode(); n != nil; n = n.Next { + trxs = append(trxs, n.Data.Value) } // Appending bond transactions poolBond := p.pools[payload.PayloadTypeBond] - for e := poolBond.FirstElement(); e != nil; e = e.Next() { - trx := e.Value.(*linkedmap.Pair).Second.(*tx.Tx) - trxs = append(trxs, trx) + for n := poolBond.FirstNode(); n != nil; n = n.Next { + trxs = append(trxs, n.Data.Value) } // Appending unbond transactions poolUnbond := p.pools[payload.PayloadTypeUnbond] - for e := poolUnbond.FirstElement(); e != nil; e = e.Next() { - trx := e.Value.(*linkedmap.Pair).Second.(*tx.Tx) - trxs = append(trxs, trx) + for n := poolUnbond.FirstNode(); n != nil; n = n.Next { + trxs = append(trxs, n.Data.Value) } // Appending withdraw transactions poolWithdraw := p.pools[payload.PayloadTypeWithdraw] - for e := poolWithdraw.FirstElement(); e != nil; e = e.Next() { - trx := e.Value.(*linkedmap.Pair).Second.(*tx.Tx) - trxs = append(trxs, trx) + for n := poolWithdraw.FirstNode(); n != nil; n = n.Next { + trxs = append(trxs, n.Data.Value) } // Appending send transactions poolSend := p.pools[payload.PayloadTypeTransfer] - for e := poolSend.FirstElement(); e != nil; e = e.Next() { - trx := e.Value.(*linkedmap.Pair).Second.(*tx.Tx) - trxs = append(trxs, trx) + for n := poolSend.FirstNode(); n != nil; n = n.Next { + trxs = append(trxs, n.Data.Value) } return trxs diff --git a/util/encoding/encoding.go b/util/encoding/encoding.go index 47a4f63b7..7828404bc 100644 --- a/util/encoding/encoding.go +++ b/util/encoding/encoding.go @@ -20,8 +20,8 @@ const ( ) var ( - ErrOverflow = errors.New("Overflow") - ErrNonCanonical = errors.New("NonCanonical") + ErrOverflow = errors.New("overflow") + ErrNonCanonical = errors.New("non canonical") ) // binaryFreeList defines a concurrent safe free list of byte slices (up to the diff --git a/util/linkedmap/doublylink.go b/util/linkedmap/doublylink.go new file mode 100644 index 000000000..6fd191a9c --- /dev/null +++ b/util/linkedmap/doublylink.go @@ -0,0 +1,142 @@ +package linkedmap + +type LinkNode[T any] struct { + Data T + Next *LinkNode[T] + Prev *LinkNode[T] +} + +func NewLinkNode[T any](data T) *LinkNode[T] { + return &LinkNode[T]{ + Data: data, + Next: nil, + Prev: nil, + } +} + +// DoublyLinkedList represents a doubly linked list +type DoublyLinkedList[T any] struct { + Head *LinkNode[T] + Tail *LinkNode[T] + length int +} + +func NewDoublyLinkedList[T any]() *DoublyLinkedList[T] { + return &DoublyLinkedList[T]{ + Head: nil, + Tail: nil, + length: 0, + } +} + +// InsertAtHead inserts a new node at the head of the list +func (l *DoublyLinkedList[T]) InsertAtHead(data T) *LinkNode[T] { + newNode := NewLinkNode(data) + + if l.Head == nil { + // Empty list case + l.Head = newNode + l.Tail = newNode + } else { + newNode.Next = l.Head + l.Head.Prev = newNode + l.Head = newNode + } + + l.length++ + + return newNode +} + +// InsertAtTail appends a new node at the tail of the list +func (l *DoublyLinkedList[T]) InsertAtTail(data T) *LinkNode[T] { + newNode := NewLinkNode(data) + + if l.Head == nil { + // Empty list case + l.Head = newNode + l.Tail = newNode + } else { + newNode.Prev = l.Tail + l.Tail.Next = newNode + l.Tail = newNode + } + + l.length++ + + return newNode +} + +// DeleteAtHead deletes the node at the head of the list +func (l *DoublyLinkedList[T]) DeleteAtHead() { + if l.Head == nil { + // Empty list case + return + } + + l.Head = l.Head.Next + if l.Head != nil { + l.Head.Prev = nil + } else { + l.Tail = nil + } + + l.length-- +} + +// DeleteAtTail deletes the node at the tail of the list +func (l *DoublyLinkedList[T]) DeleteAtTail() { + if l.Tail == nil { + // Empty list case + return + } + + l.Tail = l.Tail.Prev + if l.Tail != nil { + l.Tail.Next = nil + } else { + l.Head = nil + } + + l.length-- +} + +// Delete removes a specific node from the list +func (l *DoublyLinkedList[T]) Delete(ln *LinkNode[T]) { + if ln.Prev != nil { + ln.Prev.Next = ln.Next + } else { + l.Head = ln.Next + } + + if ln.Next != nil { + ln.Next.Prev = ln.Prev + } else { + l.Tail = ln.Prev + } + + l.length-- +} + +// Length returns the number of nodes in the list +func (l *DoublyLinkedList[T]) Length() int { + return l.length +} + +// Values returns a slice of values in the list +func (l *DoublyLinkedList[T]) Values() []T { + values := []T{} + cur := l.Head + for cur != nil { + values = append(values, cur.Data) + cur = cur.Next + } + return values +} + +// Clear removes all nodes from the list, making it empty +func (l *DoublyLinkedList[T]) Clear() { + l.Head = nil + l.Tail = nil + l.length = 0 +} diff --git a/util/linkedmap/doublylink_test.go b/util/linkedmap/doublylink_test.go new file mode 100644 index 000000000..d9f4e11ef --- /dev/null +++ b/util/linkedmap/doublylink_test.go @@ -0,0 +1,114 @@ +package linkedmap + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDoublyLink_InsertAtHead(t *testing.T) { + link := NewDoublyLinkedList[int]() + link.InsertAtHead(1) + link.InsertAtHead(2) + link.InsertAtHead(3) + link.InsertAtHead(4) + + assert.Equal(t, link.Values(), []int{4, 3, 2, 1}) + assert.Equal(t, link.Length(), 4) + assert.Equal(t, link.Head.Data, 4) + assert.Equal(t, link.Tail.Data, 1) +} + +func TestSinglyLink_InsertAtTail(t *testing.T) { + link := NewDoublyLinkedList[int]() + link.InsertAtTail(1) + link.InsertAtTail(2) + link.InsertAtTail(3) + link.InsertAtTail(4) + + assert.Equal(t, link.Values(), []int{1, 2, 3, 4}) + assert.Equal(t, link.Length(), 4) + assert.Equal(t, link.Head.Data, 1) + assert.Equal(t, link.Tail.Data, 4) +} + +func TestDeleteAtHead(t *testing.T) { + link := NewDoublyLinkedList[int]() + link.InsertAtTail(1) + link.InsertAtTail(2) + link.InsertAtTail(3) + + link.DeleteAtHead() + assert.Equal(t, link.Values(), []int{2, 3}) + assert.Equal(t, link.Length(), 2) + + link.DeleteAtHead() + assert.Equal(t, link.Values(), []int{3}) + assert.Equal(t, link.Length(), 1) + + link.DeleteAtHead() + assert.Equal(t, link.Values(), []int{}) + assert.Equal(t, link.Length(), 0) + + link.DeleteAtHead() + assert.Equal(t, link.Values(), []int{}) + assert.Equal(t, link.Length(), 0) +} + +func TestDeleteAtTail(t *testing.T) { + link := NewDoublyLinkedList[int]() + link.InsertAtTail(1) + link.InsertAtTail(2) + link.InsertAtTail(3) + + link.DeleteAtTail() + assert.Equal(t, link.Values(), []int{1, 2}) + assert.Equal(t, link.Length(), 2) + + link.DeleteAtTail() + assert.Equal(t, link.Values(), []int{1}) + assert.Equal(t, link.Length(), 1) + + link.DeleteAtTail() + assert.Equal(t, link.Values(), []int{}) + assert.Equal(t, link.Length(), 0) + + link.DeleteAtTail() + assert.Equal(t, link.Values(), []int{}) + assert.Equal(t, link.Length(), 0) +} + +func TestDelete(t *testing.T) { + link := NewDoublyLinkedList[int]() + n1 := link.InsertAtTail(1) + n2 := link.InsertAtTail(2) + n3 := link.InsertAtTail(3) + n4 := link.InsertAtTail(4) + + link.Delete(n1) + assert.Equal(t, link.Values(), []int{2, 3, 4}) + assert.Equal(t, link.Length(), 3) + + link.Delete(n4) + assert.Equal(t, link.Values(), []int{2, 3}) + assert.Equal(t, link.Length(), 2) + + link.Delete(n2) + assert.Equal(t, link.Values(), []int{3}) + assert.Equal(t, link.Length(), 1) + + link.Delete(n3) + assert.Equal(t, link.Values(), []int{}) + assert.Equal(t, link.Length(), 0) +} + +func TestClear(t *testing.T) { + link := NewDoublyLinkedList[int]() + link.InsertAtTail(1) + link.InsertAtTail(2) + link.InsertAtTail(3) + + link.Clear() + assert.Equal(t, link.Values(), []int{}) + assert.Equal(t, link.Length(), 0) +} diff --git a/util/linkedmap/linkedmap.go b/util/linkedmap/linkedmap.go index 92f1225f9..fd1c39a9a 100644 --- a/util/linkedmap/linkedmap.go +++ b/util/linkedmap/linkedmap.go @@ -1,158 +1,126 @@ package linkedmap -import ( - "container/list" -) - // TODO: should be thread safe or not? -type Pair struct { - First, Second interface{} +type Pair[K comparable, V any] struct { + Key K + Value V } -type LinkedMap struct { - list *list.List - hashmap map[interface{}]*list.Element +type LinkedMap[K comparable, V any] struct { + list *DoublyLinkedList[Pair[K, V]] + hashmap map[K]*LinkNode[Pair[K, V]] capacity int } -func NewLinkedMap(capacity int) *LinkedMap { - return &LinkedMap{ - list: list.New(), - hashmap: make(map[interface{}]*list.Element), +func NewLinkedMap[K comparable, V any](capacity int) *LinkedMap[K, V] { + return &LinkedMap[K, V]{ + list: NewDoublyLinkedList[Pair[K, V]](), + hashmap: make(map[K]*LinkNode[Pair[K, V]]), capacity: capacity, } } -func (lm *LinkedMap) SetCapacity(capacity int) { +func (lm *LinkedMap[K, V]) SetCapacity(capacity int) { lm.capacity = capacity lm.prune() } -func (lm *LinkedMap) Has(key interface{}) bool { +func (lm *LinkedMap[K, V]) Has(key K) bool { _, found := lm.hashmap[key] return found } -func (lm *LinkedMap) PushBack(first interface{}, second interface{}) { - el, found := lm.hashmap[first] +func (lm *LinkedMap[K, V]) PushBack(key K, value V) { + ln, found := lm.hashmap[key] if found { // update the second - el.Value.(*Pair).Second = second + ln.Data.Value = value return } - el = lm.list.PushBack(&Pair{first, second}) - lm.hashmap[first] = el + p := Pair[K, V]{Key: key, Value: value} + ln = lm.list.InsertAtTail(p) + lm.hashmap[key] = ln lm.prune() } -func (lm *LinkedMap) PushFront(first interface{}, second interface{}) { - el, found := lm.hashmap[first] +func (lm *LinkedMap[K, V]) PushFront(key K, value V) { + ln, found := lm.hashmap[key] if found { // update the second - el.Value.(*Pair).Second = second + ln.Data.Value = value return } - el = lm.list.PushFront(&Pair{first, second}) - lm.hashmap[first] = el + p := Pair[K, V]{Key: key, Value: value} + ln = lm.list.InsertAtHead(p) + lm.hashmap[key] = ln lm.prune() } -func (lm *LinkedMap) Get(first interface{}) (interface{}, bool) { - el, found := lm.hashmap[first] +func (lm *LinkedMap[K, V]) GetNode(key K) *LinkNode[Pair[K, V]] { + ln, found := lm.hashmap[key] if found { - return el.Value.(*Pair).Second, true + return ln } - return nil, false + return nil } -func (lm *LinkedMap) Last() (interface{}, interface{}) { - el := lm.list.Back() - if el == nil { - return nil, nil +func (lm *LinkedMap[K, V]) LastNode() *LinkNode[Pair[K, V]] { + ln := lm.list.Tail + if ln == nil { + return nil } - p := el.Value.(*Pair) - return p.First, p.Second + return ln } -func (lm *LinkedMap) First() (interface{}, interface{}) { - el := lm.list.Front() - if el == nil { - return nil, nil +func (lm *LinkedMap[K, V]) FirstNode() *LinkNode[Pair[K, V]] { + ln := lm.list.Head + if ln == nil { + return nil } - p := el.Value.(*Pair) - return p.First, p.Second -} - -func (lm *LinkedMap) LastElement() *list.Element { - return lm.list.Back() + return ln } -func (lm *LinkedMap) FirstElement() *list.Element { - return lm.list.Front() -} - -func (lm *LinkedMap) Remove(first interface{}) bool { - el, found := lm.hashmap[first] +func (lm *LinkedMap[K, V]) Remove(key K) bool { + nl, found := lm.hashmap[key] if found { - lm.list.Remove(el) - delete(lm.hashmap, el.Value.(*Pair).First) + lm.list.Delete(nl) + delete(lm.hashmap, nl.Data.Key) } return found } -func (lm *LinkedMap) Empty() bool { +func (lm *LinkedMap[K, V]) Empty() bool { return lm.Size() == 0 } -func (lm *LinkedMap) Capacity() int { +func (lm *LinkedMap[K, V]) Capacity() int { return lm.capacity } -func (lm *LinkedMap) Size() int { - return lm.list.Len() +func (lm *LinkedMap[K, V]) Size() int { + return lm.list.Length() } -func (lm *LinkedMap) Full() bool { - return lm.list.Len() == lm.capacity +func (lm *LinkedMap[K, V]) Full() bool { + return lm.list.Length() == lm.capacity } -func (lm *LinkedMap) Clear() { - lm.list = list.New() - lm.hashmap = make(map[interface{}]*list.Element) +func (lm *LinkedMap[K, V]) Clear() { + lm.list.Clear() + lm.hashmap = make(map[K]*LinkNode[Pair[K, V]]) } -func (lm *LinkedMap) prune() { - for lm.list.Len() > lm.capacity { - front := lm.list.Front() - key := front.Value.(*Pair).First - lm.list.Remove(front) +func (lm *LinkedMap[K, V]) prune() { + for lm.list.Length() > lm.capacity { + front := lm.list.Head + key := front.Data.Key + lm.list.Delete(front) delete(lm.hashmap, key) } } - -func (lm *LinkedMap) SortList(cmp func(left interface{}, right interface{}) bool) { - index := lm.list.Front() - if index == nil { - return - } - - for index != nil { - current := index.Next() - for current != nil { - if cmp(current.Value.(*Pair).Second, index.Value.(*Pair).Second) { - lm.list.MoveBefore(current, index) - index = current - current = index - } - current = current.Next() - } - - index = index.Next() - } -} diff --git a/util/linkedmap/linkedmap_test.go b/util/linkedmap/linkedmap_test.go index 87ae5f646..68fda44b8 100644 --- a/util/linkedmap/linkedmap_test.go +++ b/util/linkedmap/linkedmap_test.go @@ -8,57 +8,46 @@ import ( ) func TestLinkedMap(t *testing.T) { - t.Run("Test FirstElement", func(t *testing.T) { - lm := NewLinkedMap(4) - k, v := lm.First() - assert.Nil(t, lm.FirstElement()) - assert.Nil(t, k) - assert.Nil(t, v) + t.Run("Test FirstNode", func(t *testing.T) { + lm := NewLinkedMap[int, string](4) + assert.Nil(t, lm.FirstNode()) lm.PushFront(3, "c") lm.PushFront(2, "b") lm.PushFront(1, "a") - k, v = lm.First() - assert.Equal(t, lm.FirstElement().Value, &Pair{1, "a"}) - assert.Equal(t, k, 1) - assert.Equal(t, v, "a") + assert.Equal(t, lm.FirstNode().Data.Key, 1) + assert.Equal(t, lm.FirstNode().Data.Value, "a") }) - t.Run("Test LastElement", func(t *testing.T) { - lm := NewLinkedMap(4) - k, v := lm.Last() - assert.Nil(t, lm.LastElement()) - assert.Nil(t, k) - assert.Nil(t, v) + t.Run("Test LastNode", func(t *testing.T) { + lm := NewLinkedMap[int, string](4) + assert.Nil(t, lm.LastNode()) lm.PushBack(1, "a") lm.PushBack(2, "b") lm.PushBack(3, "c") - k, v = lm.Last() - assert.Equal(t, lm.LastElement().Value, &Pair{3, "c"}) - assert.Equal(t, k, 3) - assert.Equal(t, v, "c") + assert.Equal(t, lm.LastNode().Data.Key, 3) + assert.Equal(t, lm.LastNode().Data.Value, "c") }) t.Run("Test Get", func(t *testing.T) { - lm := NewLinkedMap(4) + lm := NewLinkedMap[int, string](4) lm.PushBack(2, "b") lm.PushBack(1, "a") - v, ok := lm.Get(2) - assert.Equal(t, ok, true) - assert.Equal(t, v, "b") + n := lm.GetNode(2) + assert.Equal(t, n.Data.Key, 2) + assert.Equal(t, n.Data.Value, "b") - v, ok = lm.Get(5) - assert.Equal(t, ok, false) - assert.Equal(t, v, nil) + n = lm.GetNode(5) + assert.Nil(t, n) }) t.Run("Test Remove", func(t *testing.T) { - lm := NewLinkedMap(4) + lm := NewLinkedMap[int, string](4) lm.PushBack(0, "-") lm.PushBack(2, "b") @@ -68,41 +57,40 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Should updates v", func(t *testing.T) { - lm := NewLinkedMap(4) + lm := NewLinkedMap[int, string](4) lm.PushBack(1, "a") lm.PushBack(1, "b") - v, ok := lm.Get(1) - assert.Equal(t, ok, true) - assert.Equal(t, v, "b") + n := lm.GetNode(1) + assert.Equal(t, n.Data.Key, 1) + assert.Equal(t, n.Data.Value, "b") lm.PushFront(1, "c") - v, ok = lm.Get(1) - assert.Equal(t, ok, true) - assert.Equal(t, v, "c") + n = lm.GetNode(1) + assert.Equal(t, n.Data.Key, 1) + assert.Equal(t, n.Data.Value, "c") }) t.Run("Should prunes oldest item", func(t *testing.T) { - lm := NewLinkedMap(4) + lm := NewLinkedMap[int, string](4) lm.PushBack(1, "a") lm.PushBack(2, "b") lm.PushBack(3, "c") lm.PushBack(4, "d") - v, ok := lm.Get(1) - assert.Equal(t, ok, true) - assert.Equal(t, v, "a") + n := lm.GetNode(1) + assert.Equal(t, n.Data.Key, 1) + assert.Equal(t, n.Data.Value, "a") lm.PushBack(5, "e") - v, ok = lm.Get(1) - assert.Equal(t, ok, false) - assert.Equal(t, v, nil) + n = lm.GetNode(1) + assert.Nil(t, n) }) t.Run("Should prunes by changing capacity", func(t *testing.T) { - lm := NewLinkedMap(4) + lm := NewLinkedMap[int, string](4) lm.PushBack(1, "a") lm.PushBack(2, "b") @@ -111,48 +99,45 @@ func TestLinkedMap(t *testing.T) { lm.SetCapacity(6) - v, ok := lm.Get(2) - assert.Equal(t, ok, true) - assert.Equal(t, v, "b") + n := lm.GetNode(2) + assert.Equal(t, n.Data.Key, 2) + assert.Equal(t, n.Data.Value, "b") lm.SetCapacity(2) assert.True(t, lm.Full()) - v, ok = lm.Get(2) - assert.Equal(t, ok, false) - assert.Equal(t, v, nil) + n = lm.GetNode(2) + assert.Nil(t, n) }) t.Run("Test PushBack and prune", func(t *testing.T) { - lm := NewLinkedMap(3) + lm := NewLinkedMap[int, string](3) lm.PushBack(1, "a") // This item should be pruned lm.PushBack(2, "b") lm.PushBack(3, "c") lm.PushBack(4, "d") - k, v := lm.First() - assert.Equal(t, lm.FirstElement().Value, &Pair{2, "b"}) - assert.Equal(t, k, 2) - assert.Equal(t, v, "b") + n := lm.FirstNode() + assert.Equal(t, n.Data.Key, 2) + assert.Equal(t, n.Data.Value, "b") }) t.Run("Test PushFront and prune", func(t *testing.T) { - lm := NewLinkedMap(3) + lm := NewLinkedMap[int, string](3) lm.PushFront(1, "a") lm.PushFront(2, "b") lm.PushFront(3, "c") lm.PushFront(4, "d") // This item should be pruned - k, v := lm.Last() - assert.Equal(t, lm.LastElement().Value, &Pair{1, "a"}) - assert.Equal(t, k, 1) - assert.Equal(t, v, "a") + n := lm.LastNode() + assert.Equal(t, n.Data.Key, 1) + assert.Equal(t, n.Data.Value, "a") }) - t.Run("Deletd first ", func(t *testing.T) { - lm := NewLinkedMap(3) + t.Run("Delete first ", func(t *testing.T) { + lm := NewLinkedMap[int, string](3) lm.PushBack(1, "a") lm.PushBack(2, "b") @@ -160,22 +145,25 @@ func TestLinkedMap(t *testing.T) { lm.Remove(1) - assert.Equal(t, lm.FirstElement().Value, &Pair{2, "b"}) + assert.Equal(t, lm.FirstNode().Data.Key, 2) + assert.Equal(t, lm.FirstNode().Data.Value, "b") }) t.Run("Delete last", func(t *testing.T) { - lm := NewLinkedMap(3) + lm := NewLinkedMap[int, string](3) lm.PushBack(1, "a") lm.PushBack(2, "b") lm.PushBack(3, "c") lm.Remove(3) - assert.Equal(t, lm.LastElement().Value, &Pair{2, "b"}) + + assert.Equal(t, lm.LastNode().Data.Key, 2) + assert.Equal(t, lm.LastNode().Data.Value, "b") }) t.Run("Test Has function", func(t *testing.T) { - lm := NewLinkedMap(2) + lm := NewLinkedMap[int, string](2) lm.PushBack(1, "a") @@ -184,7 +172,7 @@ func TestLinkedMap(t *testing.T) { }) t.Run("Test Clear", func(t *testing.T) { - lm := NewLinkedMap(2) + lm := NewLinkedMap[int, string](2) lm.PushBack(1, "a") lm.Clear() @@ -192,29 +180,8 @@ func TestLinkedMap(t *testing.T) { }) } -func TestSortingLinkedMap(t *testing.T) { - lm := NewLinkedMap(6) - - cmp := func(left interface{}, right interface{}) bool { - return left.(string) < right.(string) - } - lm.SortList(cmp) - assert.Nil(t, lm.FirstElement()) - - lm.PushBack(3, "c") - lm.PushBack(5, "e") - lm.PushBack(1, "a") - lm.PushBack(2, "b") - lm.PushBack(4, "d") - - lm.SortList(cmp) - assert.Equal(t, lm.FirstElement().Value, &Pair{1, "a"}) - assert.Equal(t, lm.LastElement().Value, &Pair{5, "e"}) - assert.Equal(t, lm.Size(), 5) -} - func TestCapacity(t *testing.T) { capacity := int(util.RandInt32(1000)) - lm := NewLinkedMap(capacity) + lm := NewLinkedMap[int, string](capacity) assert.Equal(t, lm.Capacity(), capacity) } diff --git a/util/orderedmap/ordered_map.go b/util/orderedmap/ordered_map.go deleted file mode 100644 index 88c1ce823..000000000 --- a/util/orderedmap/ordered_map.go +++ /dev/null @@ -1,78 +0,0 @@ -package orderedmap - -import ( - "github.com/google/btree" -) - -type OrderedMap struct { - bt *btree.BTree - lesser func(l, r interface{}) bool -} - -type item struct { - less func(l, r interface{}) bool - key interface{} - value interface{} -} - -func (me item) Less(right btree.Item) bool { - return me.less(me.key, right.(*item).key) -} - -func NewMap(lesser func(l, r interface{}) bool) *OrderedMap { - return &OrderedMap{ - bt: btree.New(32), - lesser: lesser, - } -} - -func (me *OrderedMap) Set(key interface{}, value interface{}) { - me.bt.ReplaceOrInsert(&item{me.lesser, key, value}) -} - -func (me *OrderedMap) Get(key interface{}) interface{} { - ret, _ := me.GetOk(key) - return ret -} - -func (me *OrderedMap) GetOk(key interface{}) (interface{}, bool) { - i := me.bt.Get(&item{me.lesser, key, nil}) - if i == nil { - return nil, false - } - return i.(*item).value, true -} - -// Callback receives a value and returns true if another value should be -// received or false to stop iteration. -type callback func(key, value interface{}) (more bool) - -func (me *OrderedMap) Iter(f callback) { - me.bt.Ascend(func(i btree.Item) bool { - return f(i.(*item).key, i.(*item).value) - }) -} - -func (me *OrderedMap) Unset(key interface{}) { - me.bt.Delete(&item{me.lesser, key, nil}) -} - -func (me *OrderedMap) Len() int { - return me.bt.Len() -} - -func (me *OrderedMap) MinKey() (interface{}, bool) { - min := me.bt.Min() - if min == nil { - return nil, false - } - return min.(*item).key, true -} - -func (me *OrderedMap) MaxKey() (interface{}, bool) { - max := me.bt.Max() - if max == nil { - return nil, false - } - return max.(*item).key, true -} diff --git a/util/orderedmap/ordered_map_test.go b/util/orderedmap/ordered_map_test.go deleted file mode 100644 index 83de39cb2..000000000 --- a/util/orderedmap/ordered_map_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package orderedmap - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func slice(om *OrderedMap) (ret []interface{}) { - om.Iter(func(k, v interface{}) bool { - ret = append(ret, om.Get(k)) - return true - }) - return -} - -func TestSimple(t *testing.T) { - om := NewMap(func(l, r interface{}) bool { - return l.(int) < r.(int) - }) - om.Set(3, 1) - om.Set(2, 2) - om.Set(1, 3) - assert.EqualValues(t, []interface{}{3, 2, 1}, slice(om)) - om.Set(3, 2) - om.Unset(2) - assert.EqualValues(t, []interface{}{3, 2}, slice(om)) - om.Set(-1, 4) - assert.EqualValues(t, []interface{}{4, 3, 2}, slice(om)) -} - -func TestIterEmpty(t *testing.T) { - om := NewMap(nil) - om.Iter(func(key, value interface{}) (more bool) { - assert.Fail(t, "Iterating empty map.") - return false - }) -} - -func TestGetMinMax(t *testing.T) { - om := NewMap(func(l, r interface{}) bool { - return l.(int) < r.(int) - }) - - _, ok := om.MinKey() - assert.False(t, ok) - - om.Set(3, 'a') - om.Set(5, 'b') - om.Set(1, 'c') - om.Set(4, 'd') - - min, ok := om.MinKey() - assert.True(t, ok) - assert.Equal(t, min, 1) - - max, ok := om.MaxKey() - assert.True(t, ok) - assert.Equal(t, max, 5) -} From dfc378d10dcc98331bf868926c45d51b7d12353b Mon Sep 17 00:00:00 2001 From: ragnarok87 Date: Tue, 13 Jun 2023 19:21:47 +0800 Subject: [PATCH 2/2] chore: adding comments for LinkedMap functions --- txpool/pool.go | 12 ++++++------ util/linkedmap/linkedmap.go | 32 +++++++++++++++++++++++--------- util/linkedmap/linkedmap_test.go | 24 ++++++++++++------------ 3 files changed, 41 insertions(+), 27 deletions(-) diff --git a/txpool/pool.go b/txpool/pool.go index 236853364..a13aa2c61 100644 --- a/txpool/pool.go +++ b/txpool/pool.go @@ -54,7 +54,7 @@ func (p *txPool) SetNewSandboxAndRecheck(sb sandbox.Sandbox) { var next *linkedmap.LinkNode[linkedmap.Pair[tx.ID, *tx.Tx]] for _, pool := range p.pools { - for e := pool.FirstNode(); e != nil; e = next { + for e := pool.HeadNode(); e != nil; e = next { next = e.Next trx := e.Data.Value @@ -152,31 +152,31 @@ func (p *txPool) PrepareBlockTransactions() block.Txs { // Appending one sortition transaction poolSortition := p.pools[payload.PayloadTypeSortition] - for n := poolSortition.FirstNode(); n != nil; n = n.Next { + for n := poolSortition.HeadNode(); n != nil; n = n.Next { trxs = append(trxs, n.Data.Value) } // Appending bond transactions poolBond := p.pools[payload.PayloadTypeBond] - for n := poolBond.FirstNode(); n != nil; n = n.Next { + for n := poolBond.HeadNode(); n != nil; n = n.Next { trxs = append(trxs, n.Data.Value) } // Appending unbond transactions poolUnbond := p.pools[payload.PayloadTypeUnbond] - for n := poolUnbond.FirstNode(); n != nil; n = n.Next { + for n := poolUnbond.HeadNode(); n != nil; n = n.Next { trxs = append(trxs, n.Data.Value) } // Appending withdraw transactions poolWithdraw := p.pools[payload.PayloadTypeWithdraw] - for n := poolWithdraw.FirstNode(); n != nil; n = n.Next { + for n := poolWithdraw.HeadNode(); n != nil; n = n.Next { trxs = append(trxs, n.Data.Value) } // Appending send transactions poolSend := p.pools[payload.PayloadTypeTransfer] - for n := poolSend.FirstNode(); n != nil; n = n.Next { + for n := poolSend.HeadNode(); n != nil; n = n.Next { trxs = append(trxs, n.Data.Value) } diff --git a/util/linkedmap/linkedmap.go b/util/linkedmap/linkedmap.go index fd1c39a9a..f6a53b3ac 100644 --- a/util/linkedmap/linkedmap.go +++ b/util/linkedmap/linkedmap.go @@ -1,7 +1,5 @@ package linkedmap -// TODO: should be thread safe or not? - type Pair[K comparable, V any] struct { Key K Value V @@ -13,6 +11,7 @@ type LinkedMap[K comparable, V any] struct { capacity int } +// NewLinkedMap creates a new LinkedMap with the specified capacity. func NewLinkedMap[K comparable, V any](capacity int) *LinkedMap[K, V] { return &LinkedMap[K, V]{ list: NewDoublyLinkedList[Pair[K, V]](), @@ -21,21 +20,24 @@ func NewLinkedMap[K comparable, V any](capacity int) *LinkedMap[K, V] { } } +// SetCapacity sets the capacity of the LinkedMap and prunes the excess elements if needed. func (lm *LinkedMap[K, V]) SetCapacity(capacity int) { lm.capacity = capacity lm.prune() } +// Has checks if the specified key exists in the LinkedMap. func (lm *LinkedMap[K, V]) Has(key K) bool { _, found := lm.hashmap[key] return found } +// PushBack adds a new key-value pair to the end of the LinkedMap. func (lm *LinkedMap[K, V]) PushBack(key K, value V) { ln, found := lm.hashmap[key] if found { - // update the second + // Update the value if the key already exists ln.Data.Value = value return } @@ -47,10 +49,11 @@ func (lm *LinkedMap[K, V]) PushBack(key K, value V) { lm.prune() } +// PushFront adds a new key-value pair to the beginning of the LinkedMap. func (lm *LinkedMap[K, V]) PushFront(key K, value V) { ln, found := lm.hashmap[key] if found { - // update the second + // Update the value if the key already exists ln.Data.Value = value return } @@ -62,6 +65,7 @@ func (lm *LinkedMap[K, V]) PushFront(key K, value V) { lm.prune() } +// GetNode returns the LinkNode corresponding to the specified key. func (lm *LinkedMap[K, V]) GetNode(key K) *LinkNode[Pair[K, V]] { ln, found := lm.hashmap[key] if found { @@ -70,7 +74,8 @@ func (lm *LinkedMap[K, V]) GetNode(key K) *LinkNode[Pair[K, V]] { return nil } -func (lm *LinkedMap[K, V]) LastNode() *LinkNode[Pair[K, V]] { +// TailNode returns the LinkNode at the end (tail) of the LinkedMap. +func (lm *LinkedMap[K, V]) TailNode() *LinkNode[Pair[K, V]] { ln := lm.list.Tail if ln == nil { return nil @@ -78,7 +83,8 @@ func (lm *LinkedMap[K, V]) LastNode() *LinkNode[Pair[K, V]] { return ln } -func (lm *LinkedMap[K, V]) FirstNode() *LinkNode[Pair[K, V]] { +// HeadNode returns the LinkNode at the beginning (head) of the LinkedMap. +func (lm *LinkedMap[K, V]) HeadNode() *LinkNode[Pair[K, V]] { ln := lm.list.Head if ln == nil { return nil @@ -86,36 +92,44 @@ func (lm *LinkedMap[K, V]) FirstNode() *LinkNode[Pair[K, V]] { return ln } +// Remove removes the key-value pair with the specified key from the LinkedMap. +// It returns true if the key was found and removed, otherwise false. func (lm *LinkedMap[K, V]) Remove(key K) bool { - nl, found := lm.hashmap[key] + ln, found := lm.hashmap[key] if found { - lm.list.Delete(nl) - delete(lm.hashmap, nl.Data.Key) + lm.list.Delete(ln) + delete(lm.hashmap, ln.Data.Key) } return found } +// Empty checks if the LinkedMap is empty (contains no key-value pairs). func (lm *LinkedMap[K, V]) Empty() bool { return lm.Size() == 0 } +// Capacity returns the capacity of the LinkedMap. func (lm *LinkedMap[K, V]) Capacity() int { return lm.capacity } +// Size returns the number of key-value pairs in the LinkedMap. func (lm *LinkedMap[K, V]) Size() int { return lm.list.Length() } +// Full checks if the LinkedMap is full (reached its capacity). func (lm *LinkedMap[K, V]) Full() bool { return lm.list.Length() == lm.capacity } +// Clear removes all key-value pairs from the LinkedMap, making it empty. func (lm *LinkedMap[K, V]) Clear() { lm.list.Clear() lm.hashmap = make(map[K]*LinkNode[Pair[K, V]]) } +// prune removes excess elements from the LinkedMap if its size exceeds the capacity. func (lm *LinkedMap[K, V]) prune() { for lm.list.Length() > lm.capacity { front := lm.list.Head diff --git a/util/linkedmap/linkedmap_test.go b/util/linkedmap/linkedmap_test.go index 68fda44b8..5f6b7c9a7 100644 --- a/util/linkedmap/linkedmap_test.go +++ b/util/linkedmap/linkedmap_test.go @@ -10,26 +10,26 @@ import ( func TestLinkedMap(t *testing.T) { t.Run("Test FirstNode", func(t *testing.T) { lm := NewLinkedMap[int, string](4) - assert.Nil(t, lm.FirstNode()) + assert.Nil(t, lm.HeadNode()) lm.PushFront(3, "c") lm.PushFront(2, "b") lm.PushFront(1, "a") - assert.Equal(t, lm.FirstNode().Data.Key, 1) - assert.Equal(t, lm.FirstNode().Data.Value, "a") + assert.Equal(t, lm.HeadNode().Data.Key, 1) + assert.Equal(t, lm.HeadNode().Data.Value, "a") }) t.Run("Test LastNode", func(t *testing.T) { lm := NewLinkedMap[int, string](4) - assert.Nil(t, lm.LastNode()) + assert.Nil(t, lm.TailNode()) lm.PushBack(1, "a") lm.PushBack(2, "b") lm.PushBack(3, "c") - assert.Equal(t, lm.LastNode().Data.Key, 3) - assert.Equal(t, lm.LastNode().Data.Value, "c") + assert.Equal(t, lm.TailNode().Data.Key, 3) + assert.Equal(t, lm.TailNode().Data.Value, "c") }) t.Run("Test Get", func(t *testing.T) { @@ -118,7 +118,7 @@ func TestLinkedMap(t *testing.T) { lm.PushBack(3, "c") lm.PushBack(4, "d") - n := lm.FirstNode() + n := lm.HeadNode() assert.Equal(t, n.Data.Key, 2) assert.Equal(t, n.Data.Value, "b") }) @@ -131,7 +131,7 @@ func TestLinkedMap(t *testing.T) { lm.PushFront(3, "c") lm.PushFront(4, "d") // This item should be pruned - n := lm.LastNode() + n := lm.TailNode() assert.Equal(t, n.Data.Key, 1) assert.Equal(t, n.Data.Value, "a") }) @@ -145,8 +145,8 @@ func TestLinkedMap(t *testing.T) { lm.Remove(1) - assert.Equal(t, lm.FirstNode().Data.Key, 2) - assert.Equal(t, lm.FirstNode().Data.Value, "b") + assert.Equal(t, lm.HeadNode().Data.Key, 2) + assert.Equal(t, lm.HeadNode().Data.Value, "b") }) t.Run("Delete last", func(t *testing.T) { @@ -158,8 +158,8 @@ func TestLinkedMap(t *testing.T) { lm.Remove(3) - assert.Equal(t, lm.LastNode().Data.Key, 2) - assert.Equal(t, lm.LastNode().Data.Value, "b") + assert.Equal(t, lm.TailNode().Data.Key, 2) + assert.Equal(t, lm.TailNode().Data.Value, "b") }) t.Run("Test Has function", func(t *testing.T) {