diff --git a/txnsync/incoming.go b/txnsync/incoming.go
index 8cf349bf46..83ac416f1a 100644
--- a/txnsync/incoming.go
+++ b/txnsync/incoming.go
@@ -20,8 +20,6 @@ import (
 	"errors"
 	"time"
 
-	"github.com/algorand/go-deadlock"
-
 	"github.com/algorand/go-algorand/data/pooldata"
 )
 
@@ -43,70 +41,6 @@ type incomingMessage struct {
 	timeReceived int64
 }
 
-// incomingMessageQueue manages the global incoming message queue across all the incoming peers.
-type incomingMessageQueue struct {
-	incomingMessages chan incomingMessage
-	enqueuedPeers    map[*Peer]struct{}
-	enqueuedPeersMu  deadlock.Mutex
-}
-
-// maxPeersCount defines the maximum number of supported peers that can have their messages waiting
-// in the incoming message queue at the same time. This number can be lower then the actual number of
-// connected peers, as it's used only for pending messages.
-const maxPeersCount = 1024
-
-// makeIncomingMessageQueue creates an incomingMessageQueue object and initializes all the internal variables.
-func makeIncomingMessageQueue() incomingMessageQueue {
-	return incomingMessageQueue{
-		incomingMessages: make(chan incomingMessage, maxPeersCount),
-		enqueuedPeers:    make(map[*Peer]struct{}, maxPeersCount),
-	}
-}
-
-// getIncomingMessageChannel returns the incoming messages channel, which would contain entries once
-// we have one ( or more ) pending incoming messages.
-func (imq *incomingMessageQueue) getIncomingMessageChannel() <-chan incomingMessage {
-	return imq.incomingMessages
-}
-
-// enqueue places the given message on the queue, if and only if it's associated peer doesn't
-// appear on the incoming message queue already. In the case there is no peer, the message
-// would be placed on the queue as is.
-// The method returns false if the incoming message doesn't have it's peer on the queue and
-// the method has failed to place the message on the queue. True is returned otherwise.
-func (imq *incomingMessageQueue) enqueue(m incomingMessage) bool {
-	if m.peer != nil {
-		imq.enqueuedPeersMu.Lock()
-		defer imq.enqueuedPeersMu.Unlock()
-		if _, has := imq.enqueuedPeers[m.peer]; has {
-			return true
-		}
-	}
-	select {
-	case imq.incomingMessages <- m:
-		// if we successfully enqueued the message, set the enqueuedPeers so that we won't enqueue the same peer twice.
-		if m.peer != nil {
-			// at this time, the enqueuedPeersMu is still under lock ( due to the above defer ), so we can access
-			// the enqueuedPeers here.
-			imq.enqueuedPeers[m.peer] = struct{}{}
-		}
-		return true
-	default:
-		return false
-	}
-}
-
-// clear removes the peer that is associated with the message ( if any ) from
-// the enqueuedPeers map, allowing future messages from this peer to be placed on the
-// incoming message queue.
-func (imq *incomingMessageQueue) clear(m incomingMessage) {
-	if m.peer != nil {
-		imq.enqueuedPeersMu.Lock()
-		defer imq.enqueuedPeersMu.Unlock()
-		delete(imq.enqueuedPeers, m.peer)
-	}
-}
-
 // incomingMessageHandler
 // note - this message is called by the network go-routine dispatch pool, and is not synchronized with the rest of the transaction synchronizer
 func (s *syncState) asyncIncomingMessageHandler(networkPeer interface{}, peer *Peer, message []byte, sequenceNumber uint64, receivedTimestamp int64) (err error) {
@@ -126,12 +60,14 @@ func (s *syncState) asyncIncomingMessageHandler(networkPeer interface{}, peer *P
 	if err != nil {
 		// if we received a message that we cannot parse, disconnect.
 		s.log.Infof("received unparsable transaction sync message from peer. disconnecting from peer.")
+		s.incomingMessagesQ.erase(peer, networkPeer)
 		return err
 	}
 
 	if incomingMessage.message.Version != txnBlockMessageVersion {
 		// we receive a message from a version that we don't support, disconnect.
 		s.log.Infof("received unsupported transaction sync message version from peer (%d). disconnecting from peer.", incomingMessage.message.Version)
+		s.incomingMessagesQ.erase(peer, networkPeer)
 		return errUnsupportedTransactionSyncMessageVersion
 	}
 
@@ -140,6 +76,7 @@ func (s *syncState) asyncIncomingMessageHandler(networkPeer interface{}, peer *P
 	bloomFilter, err := decodeBloomFilter(incomingMessage.message.TxnBloomFilter)
 	if err != nil {
 		s.log.Infof("Invalid bloom filter received from peer : %v", err)
+		s.incomingMessagesQ.erase(peer, networkPeer)
 		return errInvalidBloomFilter
 	}
 	incomingMessage.bloomFilter = bloomFilter
@@ -151,6 +88,7 @@ func (s *syncState) asyncIncomingMessageHandler(networkPeer interface{}, peer *P
 	incomingMessage.transactionGroups, err = decodeTransactionGroups(incomingMessage.message.TransactionGroups, s.genesisID, s.genesisHash)
 	if err != nil {
 		s.log.Infof("failed to decode received transactions groups: %v\n", err)
+		s.incomingMessagesQ.erase(peer, networkPeer)
 		return errDecodingReceivedTransactionGroupsFailed
 	}
 
@@ -159,10 +97,21 @@ func (s *syncState) asyncIncomingMessageHandler(networkPeer interface{}, peer *P
 	// all the peer objects are created synchronously.
 	enqueued := s.incomingMessagesQ.enqueue(incomingMessage)
 	if !enqueued {
-		// if we can't enqueue that, return an error, which would disconnect the peer.
-		// ( we have to disconnect, since otherwise, we would have no way to synchronize the sequence number)
-		s.log.Infof("unable to enqueue incoming message from a peer without txsync allocated data; incoming messages queue is full. disconnecting from peer.")
-		return errTransactionSyncIncomingMessageQueueFull
+		// if we failed to enqueue, it means that the queue is full. Try to remove disconnected
+		// peers from the queue before re-attempting.
+		peers := s.node.GetPeers()
+		if s.incomingMessagesQ.prunePeers(peers) {
+			// if we were successful in removing at least a single peer, then try to add the entry again.
+			enqueued = s.incomingMessagesQ.enqueue(incomingMessage)
+		}
+		if !enqueued {
+			// if we can't enqueue that, return an error, which would disconnect the peer.
+			// ( we have to disconnect, since otherwise, we would have no way to synchronize the sequence number)
+			s.log.Infof("unable to enqueue incoming message from a peer without txsync allocated data; incoming messages queue is full. disconnecting from peer.")
+			s.incomingMessagesQ.erase(peer, networkPeer)
+			return errTransactionSyncIncomingMessageQueueFull
+		}
+	}
 	return nil
 }
@@ -171,15 +120,26 @@ func (s *syncState) asyncIncomingMessageHandler(networkPeer interface{}, peer *P
 	if err != nil {
 		// if the incoming message queue for this peer is full, disconnect from this peer.
 		s.log.Infof("unable to enqueue incoming message into peer incoming message backlog. disconnecting from peer.")
+		s.incomingMessagesQ.erase(peer, networkPeer)
 		return err
 	}
 
 	// (maybe) place the peer message on the main queue. This would get skipped if the peer is already on the queue.
 	enqueued := s.incomingMessagesQ.enqueue(incomingMessage)
 	if !enqueued {
-		// if we can't enqueue that, return an error, which would disconnect the peer.
-		s.log.Infof("unable to enqueue incoming message from a peer with txsync allocated data; incoming messages queue is full. disconnecting from peer.")
-		return errTransactionSyncIncomingMessageQueueFull
+		// if we failed to enqueue, it means that the queue is full. Try to remove disconnected
+		// peers from the queue before re-attempting.
+		peers := s.node.GetPeers()
+		if s.incomingMessagesQ.prunePeers(peers) {
+			// if we were successful in removing at least a single peer, then try to add the entry again.
+			enqueued = s.incomingMessagesQ.enqueue(incomingMessage)
+		}
+		if !enqueued {
+			// if we can't enqueue that, return an error, which would disconnect the peer.
+			s.log.Infof("unable to enqueue incoming message from a peer with txsync allocated data; incoming messages queue is full. disconnecting from peer.")
+			s.incomingMessagesQ.erase(peer, networkPeer)
+			return errTransactionSyncIncomingMessageQueueFull
+		}
 	}
 	return nil
 }
@@ -208,9 +168,7 @@ func (s *syncState) evaluateIncomingMessage(message incomingMessage) {
 			return
 		}
 	}
-	// clear the peer that is associated with this incoming message from the message queue, allowing future
-	// messages from the peer to be placed on the message queue.
-	s.incomingMessagesQ.clear(message)
+
 	messageProcessed := false
 	transactionPoolSize := 0
 	totalAccumulatedTransactionsCount := 0 // the number of transactions that were added during the execution of this method
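
The two queue-full paths above share one recovery sequence: prune queue entries that belong to peers which have since disconnected, retry the enqueue once, and only if it still fails erase the peer's remaining entries and return the error that disconnects it. A minimal standalone sketch of that flow, written against a hypothetical msgQueue interface rather than the real incomingMessageQueue:

package main

import (
	"errors"
	"fmt"
)

var errQueueFull = errors.New("incoming message queue is full")

// msgQueue is a stand-in for the real incomingMessageQueue API.
type msgQueue interface {
	enqueue(msg string) bool
	prunePeers(activePeers []string) bool // reports whether any entry was removed
	erase(peer string)
}

// enqueueWithRetry mirrors the handlers' control flow: prune, retry once,
// then clean up after the peer and let the caller disconnect it.
func enqueueWithRetry(q msgQueue, peer, msg string, activePeers []string) error {
	if q.enqueue(msg) {
		return nil
	}
	if q.prunePeers(activePeers) && q.enqueue(msg) {
		return nil
	}
	q.erase(peer)
	return errQueueFull
}

// boundedQueue is a toy implementation, just enough to run the flow above.
type boundedQueue struct{ used, cap int }

func (b *boundedQueue) enqueue(string) bool {
	if b.used == b.cap {
		return false
	}
	b.used++
	return true
}
func (b *boundedQueue) prunePeers([]string) bool { return false }
func (b *boundedQueue) erase(string)             { b.used = 0 }

func main() {
	q := &boundedQueue{cap: 1}
	fmt.Println(enqueueWithRetry(q, "p1", "m1", nil)) // <nil>
	fmt.Println(enqueueWithRetry(q, "p1", "m2", nil)) // incoming message queue is full
}

The erase call before the error return matters: once the caller disconnects the peer, nothing else would release the queue slots that peer still occupies.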
diff --git a/txnsync/incomingMsgQ.go b/txnsync/incomingMsgQ.go
new file mode 100644
index 0000000000..aecb673bdd
--- /dev/null
+++ b/txnsync/incomingMsgQ.go
@@ -0,0 +1,372 @@
+// Copyright (C) 2019-2021 Algorand, Inc.
+// This file is part of go-algorand
+//
+// go-algorand is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// go-algorand is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with go-algorand. If not, see <https://www.gnu.org/licenses/>.
+
+package txnsync
+
+import (
+	"sync"
+
+	"github.com/algorand/go-deadlock"
+)
+
+// queuedMsgEntry is a helper struct used to manage the entries on the incoming
+// message queue.
+type queuedMsgEntry struct {
+	msg  incomingMessage
+	next *queuedMsgEntry
+	prev *queuedMsgEntry
+}
+
+type queuedMsgList struct {
+	head *queuedMsgEntry
+}
+
+// incomingMessageQueue manages the global incoming message queue across all the incoming peers.
+type incomingMessageQueue struct {
+	outboundPeerCh    chan incomingMessage
+	enqueuedPeersMap  map[*Peer]*queuedMsgEntry
+	messages          queuedMsgList
+	freelist          queuedMsgList
+	enqueuedPeersMu   deadlock.Mutex
+	enqueuedPeersCond *sync.Cond
+	shutdownRequest   chan struct{}
+	shutdownConfirmed chan struct{}
+	deletePeersCh     chan interface{}
+	peerlessCount     int
+}
+
+// maxPeersCount defines the maximum number of supported peers that can have their messages waiting
+// in the incoming message queue at the same time. This number can be lower than the actual number of
+// connected peers, as it's used only for pending messages.
+const maxPeersCount = 2048
+
+// maxPeerlessCount is the maximum number of pending messages that don't have a Peer object allocated
+// for them ( yet )
+const maxPeerlessCount = 512
+
+// makeIncomingMessageQueue creates an incomingMessageQueue object and initializes all the internal variables.
+func makeIncomingMessageQueue() *incomingMessageQueue {
+	imq := &incomingMessageQueue{
+		outboundPeerCh:    make(chan incomingMessage),
+		enqueuedPeersMap:  make(map[*Peer]*queuedMsgEntry, maxPeersCount),
+		shutdownRequest:   make(chan struct{}, 1),
+		shutdownConfirmed: make(chan struct{}, 1),
+		deletePeersCh:     make(chan interface{}),
+	}
+	imq.enqueuedPeersCond = sync.NewCond(&imq.enqueuedPeersMu)
+	imq.freelist.initialize(maxPeersCount)
+	go imq.messagePump()
+	return imq
+}
+
+// dequeueHead removes the head message from the linked list.
+func (ml *queuedMsgList) dequeueHead() (out *queuedMsgEntry) {
+	if ml.head == nil {
+		return nil
+	}
+	entry := ml.head
+	out = entry
+	if entry.next == entry {
+		ml.head = nil
+		return
+	}
+	entry.next.prev = entry.prev
+	entry.prev.next = entry.next
+	ml.head = entry.next
+	out.next = out
+	out.prev = out
+	return
+}
+
+// initialize initializes a list to have msgCount entries.
+func (ml *queuedMsgList) initialize(msgCount int) {
+	msgs := make([]queuedMsgEntry, msgCount)
+	for i := 0; i < msgCount; i++ {
+		msgs[i].next = &msgs[(i+1)%msgCount]
+		msgs[i].prev = &msgs[(i+msgCount-1)%msgCount]
+	}
+	ml.head = &msgs[0]
+}
+
+// empty tests whether the linked list is empty.
+func (ml *queuedMsgList) empty() bool {
+	return ml.head == nil
+}
+
+// remove removes the given msg from the linked list. The method
+// is written with the assumption that the given msg is known to be
+// part of the linked list.
+func (ml *queuedMsgList) remove(msg *queuedMsgEntry) {
+	if msg.next == msg {
+		ml.head = nil
+		return
+	}
+	msg.prev.next = msg.next
+	msg.next.prev = msg.prev
+	if ml.head == msg {
+		ml.head = msg.next
+	}
+	msg.prev = msg
+	msg.next = msg
+}
+
+// filterRemove removes zero or more messages from the linked list, for which the given
+// removeFunc returns true. The removed linked list entries are returned as a linked list.
+func (ml *queuedMsgList) filterRemove(removeFunc func(*queuedMsgEntry) bool) *queuedMsgEntry {
+	if ml.empty() {
+		return nil
+	}
+	// do we have a single item ?
+	if ml.head.next == ml.head {
+		if removeFunc(ml.head) {
+			out := ml.head
+			ml.head = nil
+			return out
+		}
+		return nil
+	}
+	current := ml.head
+	last := ml.head.prev
+	var letGo queuedMsgList
+	for {
+		next := current.next
+		if removeFunc(current) {
+			ml.remove(current)
+			letGo.enqueueTail(current)
+		}
+		if current == last {
+			break
+		}
+		current = next
+	}
+	return letGo.head
+}
+
+// enqueueTail adds to the current linked list another linked list whose head is msg.
+func (ml *queuedMsgList) enqueueTail(msg *queuedMsgEntry) {
+	if ml.head == nil {
+		ml.head = msg
+		return
+	} else if msg == nil {
+		return
+	}
+	lastEntryOld := ml.head.prev
+	lastEntryNew := msg.prev
+	lastEntryOld.next = msg
+	ml.head.prev = lastEntryNew
+	msg.prev = lastEntryOld
+	lastEntryNew.next = ml.head
+}
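
queuedMsgList is an intrusive circular doubly-linked list: initialize carves all the entries out of a single allocation, and each entry then migrates between the messages list and the freelist with no further allocation. A standalone sketch of the same shape and of the freelist/messages round trip (node and ring are illustrative names, not part of the patch):

package main

import "fmt"

type node struct {
	next, prev *node
	id         int
}

type ring struct{ head *node }

// initialize links count entries into one circular list, backed by a single slice.
func (r *ring) initialize(count int) {
	nodes := make([]node, count)
	for i := range nodes {
		nodes[i].id = i
		nodes[i].next = &nodes[(i+1)%count]
		nodes[i].prev = &nodes[(i+count-1)%count]
	}
	r.head = &nodes[0]
}

// dequeueHead detaches the head entry and leaves it self-linked.
func (r *ring) dequeueHead() *node {
	h := r.head
	if h == nil {
		return nil
	}
	if h.next == h {
		r.head = nil
	} else {
		h.prev.next, h.next.prev = h.next, h.prev
		r.head = h.next
	}
	h.next, h.prev = h, h
	return h
}

// enqueueTail splices the ring headed by n onto the tail of r.
func (r *ring) enqueueTail(n *node) {
	if n == nil {
		return
	}
	if r.head == nil {
		r.head = n
		return
	}
	oldTail, newTail := r.head.prev, n.prev
	oldTail.next, n.prev = n, oldTail
	newTail.next, r.head.prev = r.head, newTail
}

func main() {
	var freelist, messages ring
	freelist.initialize(3)          // one up-front allocation
	entry := freelist.dequeueHead() // enqueue: freelist -> messages
	messages.enqueueTail(entry)
	recycled := messages.dequeueHead() // dequeue: messages -> freelist
	freelist.enqueueTail(recycled)
	fmt.Println(recycled.id) // 0
}

Keeping detached entries self-linked (next and prev pointing back at the entry) is what lets enqueueTail splice a single entry and a multi-entry ring through the same code path.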
+
+// shutdown signals to the message pump to shut down and waits until the message pump goroutine
+// aborts.
+func (imq *incomingMessageQueue) shutdown() {
+	imq.enqueuedPeersMu.Lock()
+	close(imq.shutdownRequest)
+	imq.enqueuedPeersCond.Signal()
+	imq.enqueuedPeersMu.Unlock()
+	<-imq.shutdownConfirmed
+}
+
+// messagePump is the incoming message queue message pump. It takes messages from the messages list
+// and attempts to write them to the outboundPeerCh.
+func (imq *incomingMessageQueue) messagePump() {
+	defer close(imq.shutdownConfirmed)
+	imq.enqueuedPeersMu.Lock()
+	defer imq.enqueuedPeersMu.Unlock()
+
+	for {
+		// check if we need to shutdown.
+		select {
+		case <-imq.shutdownRequest:
+			return
+		default:
+		}
+
+		// do we have any item to enqueue ?
+		if !imq.messages.empty() {
+			msgEntry := imq.messages.dequeueHead()
+			msg := msgEntry.msg
+			imq.freelist.enqueueTail(msgEntry)
+			if msg.peer != nil {
+				delete(imq.enqueuedPeersMap, msg.peer)
+			} else {
+				imq.peerlessCount--
+			}
+			imq.enqueuedPeersMu.Unlock()
+		writeOutboundMessage:
+			select {
+			case imq.outboundPeerCh <- msg:
+				imq.enqueuedPeersMu.Lock()
+				continue
+			case <-imq.shutdownRequest:
+				imq.enqueuedPeersMu.Lock()
+				return
+			// see if this msg needs to be delivered or not.
+			case droppedPeer := <-imq.deletePeersCh:
+				if msg.networkPeer == droppedPeer {
+					// we want to skip this message.
+					imq.enqueuedPeersMu.Lock()
+					continue
+				}
+				goto writeOutboundMessage
+			}
+		}
+		imq.enqueuedPeersCond.Wait()
+	}
+}
+
+// getIncomingMessageChannel returns the incoming messages channel, which would contain entries once
+// we have one ( or more ) pending incoming messages.
+func (imq *incomingMessageQueue) getIncomingMessageChannel() <-chan incomingMessage {
+	return imq.outboundPeerCh
+}
+
+// enqueue places the given message on the queue, if and only if its associated peer doesn't
+// appear on the incoming message queue already. In the case there is no peer, the message
+// would be placed on the queue as is.
+// The method returns false if the incoming message doesn't have its peer on the queue and
+// the method has failed to place the message on the queue. True is returned otherwise.
+func (imq *incomingMessageQueue) enqueue(m incomingMessage) bool {
+	imq.enqueuedPeersMu.Lock()
+	defer imq.enqueuedPeersMu.Unlock()
+	if m.peer != nil {
+		if _, has := imq.enqueuedPeersMap[m.peer]; has {
+			return true
+		}
+	} else {
+		// do we have enough "room" for peerless messages ?
+		if imq.peerlessCount >= maxPeerlessCount {
+			return false
+		}
+	}
+	// do we have enough room in the message queue for the new message ?
+	if imq.freelist.empty() {
+		// no - we don't have enough room in the circular buffer.
+		return false
+	}
+	freeMsgEntry := imq.freelist.dequeueHead()
+	freeMsgEntry.msg = m
+	imq.messages.enqueueTail(freeMsgEntry)
+	// if we successfully enqueued the message, set the enqueuedPeersMap so that we won't enqueue the same peer twice.
+	if m.peer != nil {
+		imq.enqueuedPeersMap[m.peer] = freeMsgEntry
+	} else {
+		imq.peerlessCount++
+	}
+	imq.enqueuedPeersCond.Signal()
+	return true
+}
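
enqueue and messagePump cooperate through enqueuedPeersMu and enqueuedPeersCond: producers mutate the lists under the mutex and signal the condition variable, while the single pump goroutine holds the mutex at all times except around the blocking channel send, so a slow consumer never stalls enqueue or erase. A condensed, runnable sketch of that pattern (the pump type is illustrative; shutdown and the deletePeersCh case are elided):

package main

import (
	"fmt"
	"sync"
)

type pump struct {
	mu    sync.Mutex
	cond  *sync.Cond
	queue []int
	out   chan int
}

func newPump() *pump {
	p := &pump{out: make(chan int)}
	p.cond = sync.NewCond(&p.mu)
	go p.run()
	return p
}

func (p *pump) run() {
	p.mu.Lock()
	defer p.mu.Unlock()
	for {
		for len(p.queue) == 0 {
			p.cond.Wait() // parked until a producer signals
		}
		msg := p.queue[0]
		p.queue = p.queue[1:]
		p.mu.Unlock() // never hold the lock across the blocking send
		p.out <- msg
		p.mu.Lock()
	}
}

func (p *pump) enqueue(v int) {
	p.mu.Lock()
	p.queue = append(p.queue, v)
	p.cond.Signal()
	p.mu.Unlock()
}

func main() {
	p := newPump()
	p.enqueue(42)
	fmt.Println(<-p.out) // 42
}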
+// erase removes all the entries associated with the given network peer.
+// this method isn't very efficient, and should be used only in cases where
+// we disconnect from a peer and want to clean up all the pending tasks associated
+// with that peer.
+func (imq *incomingMessageQueue) erase(peer *Peer, networkPeer interface{}) {
+	imq.enqueuedPeersMu.Lock()
+
+	var peerMsgEntry *queuedMsgEntry
+	if peer == nil {
+		// lookup for a Peer object.
+		for peer, peerMsgEntry = range imq.enqueuedPeersMap {
+			if peer.networkPeer != networkPeer {
+				continue
+			}
+			break
+		}
+	} else {
+		var has bool
+		if peerMsgEntry, has = imq.enqueuedPeersMap[peer]; !has {
+			// the peer object is not in the map.
+			peer = nil
+		}
+	}
+
+	if peer != nil {
+		delete(imq.enqueuedPeersMap, peer)
+		imq.messages.remove(peerMsgEntry)
+		imq.freelist.enqueueTail(peerMsgEntry)
+		imq.enqueuedPeersMu.Unlock()
+		select {
+		case imq.deletePeersCh <- networkPeer:
+		default:
+		}
+		return
+	}
+
+	imq.removeMessageByNetworkPeer(networkPeer)
+	imq.enqueuedPeersMu.Unlock()
+	select {
+	case imq.deletePeersCh <- networkPeer:
+	default:
+	}
+}
+
+// removeMessageByNetworkPeer removes the messages associated with the given network peer from the
+// queue.
+// note: the method expects that the enqueuedPeersMu lock is taken.
+func (imq *incomingMessageQueue) removeMessageByNetworkPeer(networkPeer interface{}) {
+	peerlessCount := 0
+	removeByNetworkPeer := func(msg *queuedMsgEntry) bool {
+		if msg.msg.networkPeer == networkPeer {
+			if msg.msg.peer == nil {
+				peerlessCount++
+			}
+			return true
+		}
+		return false
+	}
+	removeList := imq.messages.filterRemove(removeByNetworkPeer)
+	imq.freelist.enqueueTail(removeList)
+	imq.peerlessCount -= peerlessCount
+}
+
+// prunePeers removes from the message queue all the entries whose peers do not appear in the
+// given activePeers slice.
+func (imq *incomingMessageQueue) prunePeers(activePeers []PeerInfo) (peerRemoved bool) {
+	activePeersMap := make(map[*Peer]bool)
+	activeNetworkPeersMap := make(map[interface{}]bool)
+	for _, activePeer := range activePeers {
+		if activePeer.TxnSyncPeer != nil {
+			activePeersMap[activePeer.TxnSyncPeer] = true
+		}
+		if activePeer.NetworkPeer != nil {
+			activeNetworkPeersMap[activePeer.NetworkPeer] = true
+		}
+	}
+	imq.enqueuedPeersMu.Lock()
+	defer imq.enqueuedPeersMu.Unlock()
+	peerlessCount := 0
+	isPeerMissing := func(msg *queuedMsgEntry) bool {
+		if msg.msg.peer != nil {
+			if !activePeersMap[msg.msg.peer] {
+				return true
+			}
+		}
+		if !activeNetworkPeersMap[msg.msg.networkPeer] {
+			if msg.msg.peer == nil {
+				peerlessCount++
+			}
+			return true
+		}
+		return false
+	}
+	removeList := imq.messages.filterRemove(isPeerMissing)
+	peerRemoved = removeList != nil
+	imq.freelist.enqueueTail(removeList)
+	imq.peerlessCount -= peerlessCount
+	return
+}
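
erase also has to cover the window in which the pump has dequeued a message but not yet delivered it. The non-blocking send on deletePeersCh handles exactly that case: if the pump is parked on the outbound select it receives the dropped peer and skips the in-flight message, and otherwise the send falls through to the default branch instead of blocking erase. A small standalone sketch of that notification (tryNotify is an illustrative name):

package main

import "fmt"

// tryNotify performs the same non-blocking send erase uses on deletePeersCh.
func tryNotify(ch chan<- interface{}, peer interface{}) bool {
	select {
	case ch <- peer:
		return true
	default:
		return false
	}
}

func main() {
	ch := make(chan interface{})         // unbuffered, like deletePeersCh
	fmt.Println(tryNotify(ch, "peer-a")) // false: no pump is listening

	got := make(chan interface{})
	go func() { got <- <-ch }() // a "pump" parked on the channel
	for !tryNotify(ch, "peer-b") {
	} // retried until the receiver is parked
	fmt.Println(<-got) // peer-b
}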
diff --git a/txnsync/incomingMsgQ_test.go b/txnsync/incomingMsgQ_test.go
new file mode 100644
index 0000000000..9ff1adf4b4
--- /dev/null
+++ b/txnsync/incomingMsgQ_test.go
@@ -0,0 +1,158 @@
+// Copyright (C) 2019-2021 Algorand, Inc.
+// This file is part of go-algorand
+//
+// go-algorand is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// go-algorand is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with go-algorand. If not, see <https://www.gnu.org/licenses/>.
+
+package txnsync
+
+import (
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/algorand/go-algorand/test/partitiontest"
+)
+
+// fillMessageQueue fills the message queue with the given message.
+func (imq *incomingMessageQueue) fillMessageQueue(msg incomingMessage) {
+	imq.enqueuedPeersMu.Lock()
+	for i := 0; i < maxPeersCount; i++ {
+		msgEntry := imq.freelist.dequeueHead()
+		msgEntry.msg = msg
+		imq.messages.enqueueTail(msgEntry)
+	}
+	if msg.peer == nil {
+		imq.peerlessCount += maxPeersCount
+	}
+	imq.enqueuedPeersCond.Signal()
+	imq.enqueuedPeersMu.Unlock()
+
+	// wait for a single message to be consumed by the message pump.
+	for {
+		imq.enqueuedPeersMu.Lock()
+		if !imq.freelist.empty() {
+			break
+		}
+		imq.enqueuedPeersMu.Unlock()
+		time.Sleep(time.Millisecond)
+	}
+	for !imq.freelist.empty() {
+		msgEntry := imq.freelist.dequeueHead()
+		msgEntry.msg = msg
+		imq.messages.enqueueTail(msgEntry)
+	}
+	imq.enqueuedPeersCond.Signal()
+	imq.enqueuedPeersMu.Unlock()
+}
+
+// count counts the number of messages in the list.
+func (ml *queuedMsgList) count() int {
+	first := ml.head
+	cur := first
+	count := 0
+	for cur != nil {
+		next := cur.next
+		if next == first {
+			next = nil
+		}
+		count++
+		cur = next
+	}
+	return count
+}
+
+// validateLinking verifies that the entries in the list are correctly connected.
+func (ml *queuedMsgList) validateLinking(t *testing.T) {
+	cur := ml.head
+	if cur == nil {
+		return
+	}
+	seen := make(map[*queuedMsgEntry]bool)
+	list := make([]*queuedMsgEntry, 0)
+	for {
+		if seen[cur] {
+			break
+		}
+		seen[cur] = true
+		require.NotNil(t, cur.prev)
+		require.NotNil(t, cur.next)
+		list = append(list, cur)
+		cur = cur.next
+	}
+	for i := range list {
+		require.Equal(t, list[i], list[(i+len(list)-1)%len(list)].next)
+		require.Equal(t, list[i], list[(i+1)%len(list)].prev)
+	}
+}
+
+// TestMsgQCounts tests the message queue add/remove manipulations.
+func TestMsgQCounts(t *testing.T) {
+	partitiontest.PartitionTest(t)
+
+	var list queuedMsgList
+	list.initialize(7)
+	list.validateLinking(t)
+	require.Equal(t, 7, list.count())
+	list.dequeueHead()
+	list.validateLinking(t)
+	require.Equal(t, 6, list.count())
+	var anotherList queuedMsgList
+	anotherList.initialize(4)
+	require.Equal(t, 4, anotherList.count())
+	list.enqueueTail(anotherList.head)
+	list.validateLinking(t)
+	require.Equal(t, 10, list.count())
+}
+
+// TestMsgQFiltering tests the message queue filtering.
+func TestMsgQFiltering(t *testing.T) {
+	partitiontest.PartitionTest(t)
+
+	item1 := &queuedMsgEntry{}
+	item2 := &queuedMsgEntry{}
+	item3 := &queuedMsgEntry{}
+	item1.next = item1
+	item1.prev = item1
+	item2.next = item2
+	item2.prev = item2
+	item3.next = item3
+	item3.prev = item3
+
+	var list queuedMsgList
+	list.enqueueTail(item1)
+	list.enqueueTail(item2)
+	list.enqueueTail(item3)
+
+	// test removing head.
+	removedItem1 := list.filterRemove(func(msg *queuedMsgEntry) bool {
+		return msg == item1
+	})
+	require.Equal(t, item1, removedItem1)
+	require.Equal(t, 2, list.count())
+
+	// test removing tail
+	removedItem3 := list.filterRemove(func(msg *queuedMsgEntry) bool {
+		return msg == item3
+	})
+	require.Equal(t, item3, removedItem3)
+	require.Equal(t, 1, list.count())
+
+	// test removing last item
+	removedItem2 := list.filterRemove(func(msg *queuedMsgEntry) bool {
+		return msg == item2
+	})
+	require.Equal(t, item2, removedItem2)
+	require.True(t, list.empty())
+}
diff --git a/txnsync/incoming_test.go b/txnsync/incoming_test.go
index e8a213dd3f..978a1c4cae 100644
--- a/txnsync/incoming_test.go
+++ b/txnsync/incoming_test.go
@@ -58,9 +58,10 @@ func TestAsyncIncomingMessageHandlerAndErrors(t *testing.T) {
 	cfg := config.GetDefaultLocal()
 	mNodeConnector := &mockNodeConnector{transactionPoolSize: 3}
 	s := syncState{
-		log:   wrapLogger(&incLogger, &cfg),
-		node:  mNodeConnector,
-		clock: mNodeConnector.Clock(),
+		log:               wrapLogger(&incLogger, &cfg),
+		node:              mNodeConnector,
+		clock:             mNodeConnector.Clock(),
+		incomingMessagesQ: makeIncomingMessageQueue(),
 	}
 
 	// expect UnmarshalMsg error
@@ -92,24 +93,34 @@ func TestAsyncIncomingMessageHandlerAndErrors(t *testing.T) {
 	messageBytes = message.MarshalMsg(nil)
 	err = s.asyncIncomingMessageHandler(nil, nil, messageBytes, sequenceNumber, 0)
 	require.Equal(t, errDecodingReceivedTransactionGroupsFailed, err)
+	s.incomingMessagesQ.shutdown()
+
+	peer := Peer{networkPeer: &s}
 
 	// error queue full
 	message.TransactionGroups = packedTransactionGroups{}
 	messageBytes = message.MarshalMsg(nil)
+	s.incomingMessagesQ = makeIncomingMessageQueue()
+	s.incomingMessagesQ.fillMessageQueue(incomingMessage{peer: &peer, networkPeer: &s.incomingMessagesQ})
+	mNodeConnector.peers = append(mNodeConnector.peers, PeerInfo{TxnSyncPeer: &peer, NetworkPeer: &s.incomingMessagesQ})
 	err = s.asyncIncomingMessageHandler(nil, nil, messageBytes, sequenceNumber, 0)
 	require.Equal(t, errTransactionSyncIncomingMessageQueueFull, err)
+	s.incomingMessagesQ.shutdown()
 
 	// Success where peer == nil
 	s.incomingMessagesQ = makeIncomingMessageQueue()
 	err = s.asyncIncomingMessageHandler(nil, nil, messageBytes, sequenceNumber, 0)
 	require.NoError(t, err)
-
-	peer := Peer{}
+	s.incomingMessagesQ.shutdown()
 
 	// error when placing the peer message on the main queue (incomingMessages cannot accept messages)
-	s.incomingMessagesQ = incomingMessageQueue{}
+	s.incomingMessagesQ = makeIncomingMessageQueue()
+	s.incomingMessagesQ.fillMessageQueue(incomingMessage{peer: nil, networkPeer: &s})
+	mNodeConnector.peers = append(mNodeConnector.peers, PeerInfo{NetworkPeer: &s})
+
 	err = s.asyncIncomingMessageHandler(nil, &peer, messageBytes, sequenceNumber, 0)
 	require.Equal(t, errTransactionSyncIncomingMessageQueueFull, err)
+	s.incomingMessagesQ.shutdown()
 
 	s.incomingMessagesQ = makeIncomingMessageQueue()
 	err = nil
@@ -119,6 +130,7 @@ func TestAsyncIncomingMessageHandlerAndErrors(t *testing.T) {
 		err = s.asyncIncomingMessageHandler(nil, &peer, messageBytes, sequenceNumber, 0)
 	}
 	require.Equal(t, errHeapReachedCapacity, err)
+	s.incomingMessagesQ.shutdown()
 }
 
 func TestEvaluateIncomingMessagePart1(t *testing.T) {
@@ -151,28 +163,29 @@ func TestEvaluateIncomingMessagePart1(t *testing.T) {
 	mNodeConnector.updatingPeers = false
 
 	s.incomingMessagesQ = makeIncomingMessageQueue()
-	// Add a peer here, and make sure it is cleared
-	s.incomingMessagesQ.enqueuedPeers[peer] = struct{}{}
+	defer s.incomingMessagesQ.shutdown()
+	message.peer = peer
+	require.True(t, s.incomingMessagesQ.enqueue(message))
 	mNodeConnector.peerInfo.TxnSyncPeer = peer
 	peer.incomingMessages = messageOrderingHeap{}
 	// TxnSyncPeer in peerInfo
 	s.evaluateIncomingMessage(message)
 	require.False(t, mNodeConnector.updatingPeers)
-	_, found := s.incomingMessagesQ.enqueuedPeers[peer]
+	<-s.incomingMessagesQ.getIncomingMessageChannel()
+	_, found := s.incomingMessagesQ.enqueuedPeersMap[peer]
 	require.False(t, found)
 
-	// fill the hip with messageOrderingHeapLimit elements so that the incomingMessages enqueue fails
+	// fill the heap with messageOrderingHeapLimit elements so that the incomingMessages enqueue fails
+	message.networkPeer = &s
+	message.peer = nil
 	for x := 0; x < messageOrderingHeapLimit; x++ {
 		err := peer.incomingMessages.enqueue(message)
 		require.NoError(t, err)
 	}
-	// Add a peer here, and make sure it is not cleared after the error
-	s.incomingMessagesQ.enqueuedPeers[peer] = struct{}{}
+	mNodeConnector.peers = []PeerInfo{{TxnSyncPeer: peer, NetworkPeer: &s}}
 	// TxnSyncPeer in peerInfo
 	s.evaluateIncomingMessage(message)
 	require.False(t, mNodeConnector.updatingPeers)
-	_, found = s.incomingMessagesQ.enqueuedPeers[peer]
-	require.True(t, found)
 }
 
 func TestEvaluateIncomingMessagePart2(t *testing.T) {
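
Every makeIncomingMessageQueue in these tests is now paired with a shutdown call, because each queue owns a message pump goroutine. The handshake in shutdown is the standard one: close a request channel under the lock, wake the goroutine through the condition variable, and block on a confirmation channel the goroutine closes on exit. A standalone sketch of just that handshake (worker is an illustrative stand-in for the pump):

package main

import "sync"

type worker struct {
	mu        sync.Mutex
	cond      *sync.Cond
	request   chan struct{}
	confirmed chan struct{}
}

func startWorker() *worker {
	w := &worker{
		request:   make(chan struct{}),
		confirmed: make(chan struct{}),
	}
	w.cond = sync.NewCond(&w.mu)
	go w.run()
	return w
}

func (w *worker) run() {
	defer close(w.confirmed)
	w.mu.Lock()
	defer w.mu.Unlock()
	for {
		select {
		case <-w.request:
			return
		default:
		}
		w.cond.Wait() // in the real pump this also wakes for new work
	}
}

// shutdown mirrors incomingMessageQueue.shutdown: request, wake, then wait.
func (w *worker) shutdown() {
	w.mu.Lock()
	close(w.request)
	w.cond.Signal()
	w.mu.Unlock()
	<-w.confirmed
}

func main() {
	w := startWorker()
	w.shutdown() // returns only after run() has exited
}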
diff --git a/txnsync/mainloop.go b/txnsync/mainloop.go
index 00413aaabf..138206be0f 100644
--- a/txnsync/mainloop.go
+++ b/txnsync/mainloop.go
@@ -70,7 +70,7 @@ type syncState struct {
 	scheduler                  peerScheduler
 	interruptablePeers         []*Peer
 	interruptablePeersMap      map[*Peer]int // map a peer into the index of interruptablePeers
-	incomingMessagesQ          incomingMessageQueue
+	incomingMessagesQ          *incomingMessageQueue
 	outgoingMessagesCallbackCh chan sentMessageMetadata
 	nextOffsetRollingCh        <-chan time.Time
 	requestsOffset             uint64
@@ -105,6 +105,7 @@ func (s *syncState) mainloop(serviceCtx context.Context, wg *sync.WaitGroup) {
 	s.clock = s.node.Clock()
 
 	s.incomingMessagesQ = makeIncomingMessageQueue()
+	defer s.incomingMessagesQ.shutdown()
 	s.outgoingMessagesCallbackCh = make(chan sentMessageMetadata, 1024)
 	s.interruptablePeersMap = make(map[*Peer]int)
 	s.scheduler.node = s.node
diff --git a/txnsync/outgoing.go b/txnsync/outgoing.go
index bfb685c6fe..fa18bd3535 100644
--- a/txnsync/outgoing.go
+++ b/txnsync/outgoing.go
@@ -67,6 +67,7 @@ type messageAsyncEncoder struct {
 func (encoder *messageAsyncEncoder) asyncMessageSent(enqueued bool, sequenceNumber uint64) error {
 	if !enqueued {
 		encoder.state.log.Infof("unable to send message to peer. disconnecting from peer.")
+		encoder.state.incomingMessagesQ.erase(encoder.messageData.peer, encoder.messageData.peer.networkPeer)
 		return errTransactionSyncOutgoingMessageSendFailed
 	}
 	// record the sequence number here, so that we can store that later on.
@@ -78,6 +79,7 @@ func (encoder *messageAsyncEncoder) asyncMessageSent(enqueued bool, sequenceNumb
 	default:
 		// if we can't place it on the channel, return an error so that the node could disconnect from this peer.
 		encoder.state.log.Infof("unable to enqueue outgoing message confirmation; outgoingMessagesCallbackCh is full. disconnecting from peer.")
+		encoder.state.incomingMessagesQ.erase(encoder.messageData.peer, encoder.messageData.peer.networkPeer)
 		return errTransactionSyncOutgoingMessageQueueFull
 	}
 }
diff --git a/txnsync/outgoing_test.go b/txnsync/outgoing_test.go
index 493242448f..3136161ef3 100644
--- a/txnsync/outgoing_test.go
+++ b/txnsync/outgoing_test.go
@@ -62,6 +62,8 @@ func TestAsyncMessageSent(t *testing.T) {
 	var s syncState
 	s.clock = timers.MakeMonotonicClock(time.Now())
 	s.log = mockAsyncLogger{}
+	s.incomingMessagesQ = makeIncomingMessageQueue()
+	defer s.incomingMessagesQ.shutdown()
 
 	asyncEncoder := messageAsyncEncoder{
 		state: &s,
diff --git a/txnsync/service_test.go b/txnsync/service_test.go
index b0cddbcc5e..2b262fc185 100644
--- a/txnsync/service_test.go
+++ b/txnsync/service_test.go
@@ -43,6 +43,7 @@ type mockNodeConnector struct {
 	peerInfo            PeerInfo
 	updatingPeers       bool
 	transactionPoolSize int
+	peers               []PeerInfo
 }
 
 func makeMockNodeConnector(calledEvents *bool) mockNodeConnector {
@@ -69,7 +70,7 @@ func (fn *mockNodeConnector) Random(rng uint64) uint64 {
 	return rv % rng
 }
 
-func (fn *mockNodeConnector) GetPeers() []PeerInfo { return nil }
+func (fn *mockNodeConnector) GetPeers() []PeerInfo { return fn.peers }
 
 func (fn *mockNodeConnector) GetPeer(interface{}) (out PeerInfo) { return fn.peerInfo