Skip to content

Commit

Permalink
Problem: mempool iteration is not thread safe
Browse files Browse the repository at this point in the history
Closes: cosmos#675

Solution:
- hold the lock during iteration
  • Loading branch information
yihuang committed Aug 27, 2024
1 parent 50f1fa0 commit a4171e2
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Ref: https://keepachangelog.com/en/1.0.0/
### Bug Fixes

* (x/bank) [#20028](https://github.com/cosmos/cosmos-sdk/pull/20028) Align query with multi denoms for send-enabled.
* (baseapp) [#]() Fix data race in mempool iteration.

## [Unreleased-Upstream]

Expand Down
32 changes: 19 additions & 13 deletions baseapp/abci_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,16 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
return &abci.ResponsePrepareProposal{Txs: h.txSelector.SelectedTxs(ctx)}, nil
}

iterator := h.mempool.Select(ctx, req.Txs)
selectedTxsSignersSeqs := make(map[string]uint64)
var selectedTxsNums int
for iterator != nil {
memTx := iterator.Tx()
signerData, err := h.signerExtAdapter.GetSigners(memTx.Tx)
var (
err error
selectedTxsNums int
)
h.mempool.SelectBy(ctx, req.Txs, func(memTx mempool.Tx) bool {
var signerData []mempool.SignerData
signerData, err = h.signerExtAdapter.GetSigners(memTx.Tx)
if err != nil {
return nil, err
return false
}

// If the signers aren't in selectedTxsSignersSeqs then we haven't seen them before
Expand All @@ -315,24 +317,24 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
txSignersSeqs[signer.Signer.String()] = signer.Sequence
}
if !shouldAdd {
iterator = iterator.Next()
continue
return true
}

// NOTE: Since transaction verification was already executed in CheckTx,
// which calls mempool.Insert, in theory everything in the pool should be
// valid. But some mempool implementations may insert invalid txs, so we
// check again.
txBz, err := h.txVerifier.PrepareProposalVerifyTx(memTx.Tx)
var txBz []byte
txBz, err = h.txVerifier.PrepareProposalVerifyTx(memTx.Tx)
if err != nil {
err := h.mempool.Remove(memTx.Tx)
err = h.mempool.Remove(memTx.Tx)
if err != nil && !errors.Is(err, mempool.ErrTxNotFound) {
return nil, err
return false
}
} else {
stop := h.txSelector.SelectTxForProposal(ctx, uint64(req.MaxTxBytes), maxBlockGas, memTx.Tx, txBz, memTx.GasWanted)
if stop {
break
return false
}

txsLen := len(h.txSelector.SelectedTxs(ctx))
Expand All @@ -353,7 +355,11 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
selectedTxsNums = txsLen
}

iterator = iterator.Next()
return true
})

if err != nil {
return nil, err
}

return &abci.ResponsePrepareProposal{Txs: h.txSelector.SelectedTxs(ctx)}, nil
Expand Down
3 changes: 3 additions & 0 deletions types/mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ type Mempool interface {
// closed by the caller.
Select(context.Context, [][]byte) Iterator

// SelectBy use callback to iterate over the mempool.
SelectBy(context.Context, [][]byte, func(Tx) bool)

// CountTx returns the number of transactions currently in the mempool.
CountTx() int

Expand Down
1 change: 1 addition & 0 deletions types/mempool/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ type NoOpMempool struct{}
func (NoOpMempool) Insert(context.Context, sdk.Tx) error { return nil }
func (NoOpMempool) InsertWithGasWanted(context.Context, sdk.Tx, uint64) error { return nil }
func (NoOpMempool) Select(context.Context, [][]byte) Iterator { return nil }
func (NoOpMempool) SelectBy(context.Context, [][]byte, func(Tx) bool) {}
func (NoOpMempool) CountTx() int { return 0 }
func (NoOpMempool) Remove(sdk.Tx) error { return nil }
24 changes: 24 additions & 0 deletions types/mempool/priority_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,30 @@ func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterato
return iterator.iteratePriority()
}

func (mp *PriorityNonceMempool[C]) SelectBy(_ context.Context, _ [][]byte, callback func(Tx) bool) {
mp.mtx.Lock()
defer mp.mtx.Unlock()

if mp.priorityIndex.Len() == 0 {
return
}

mp.reorderPriorityTies()

iterator := &PriorityNonceIterator[C]{
mempool: mp,
senderCursors: make(map[string]*skiplist.Element),
}

iter := iterator.iteratePriority()
for iter != nil {
if !callback(iter.Tx()) {
break
}
iter = iter.Next()
}
}

type reorderKey[C comparable] struct {
deleteKey txMeta[C]
insertKey txMeta[C]
Expand Down
84 changes: 84 additions & 0 deletions types/mempool/priority_nonce_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package mempool_test

import (
"context"
"fmt"
"math"
"math/rand"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -396,6 +398,88 @@ func (s *MempoolTestSuite) TestIterator() {
}
}

func (s *MempoolTestSuite) TestIteratorConcurrency() {
t := s.T()
ctx := sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2)
sa := accounts[0].Address
sb := accounts[1].Address

tests := []struct {
txs []txSpec
fail bool
}{
{
txs: []txSpec{
{p: 20, n: 1, a: sa},
{p: 15, n: 1, a: sb},
{p: 6, n: 2, a: sa},
{p: 21, n: 4, a: sa},
{p: 8, n: 2, a: sb},
},
},
{
txs: []txSpec{
{p: 20, n: 1, a: sa},
{p: 15, n: 1, a: sb},
{p: 6, n: 2, a: sa},
{p: 21, n: 4, a: sa},
{p: math.MinInt64, n: 2, a: sb},
},
},
}

for i, tt := range tests {
t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
pool := mempool.DefaultPriorityMempool()

// create test txs and insert into mempool
for i, ts := range tt.txs {
tx := testTx{id: i, priority: int64(ts.p), nonce: uint64(ts.n), address: ts.a}
c := ctx.WithPriority(tx.priority)
err := pool.Insert(c, tx)
require.NoError(t, err)
}

// iterate through txs
stdCtx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()

id := len(tt.txs)
for {
select {
case <-stdCtx.Done():
return
default:
id++
tx := testTx{id: id, priority: int64(rand.Intn(100)), nonce: uint64(id), address: sa}
c := ctx.WithPriority(tx.priority)
err := pool.Insert(c, tx)
require.NoError(t, err)
}
}
}()

var i int
pool.SelectBy(ctx, nil, func(memTx mempool.Tx) bool {
tx := memTx.Tx.(testTx)
if tx.id < len(tt.txs) {
require.Equal(t, tt.txs[tx.id].p, int(tx.priority))
require.Equal(t, tt.txs[tx.id].n, int(tx.nonce))
require.Equal(t, tt.txs[tx.id].a, tx.address)
i++
}
return i < len(tt.txs)
})
cancel()
wg.Wait()
})
}
}

func (s *MempoolTestSuite) TestPriorityTies() {
ctx := sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 3)
Expand Down
36 changes: 36 additions & 0 deletions types/mempool/sender_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,42 @@ func (snm *SenderNonceMempool) Select(_ context.Context, _ [][]byte) Iterator {
return iter.Next()
}

func (snm *SenderNonceMempool) SelectBy(_ context.Context, _ [][]byte, callback func(Tx) bool) {
snm.mtx.Lock()
defer snm.mtx.Unlock()
var senders []string

senderCursors := make(map[string]*skiplist.Element)
orderedSenders := skiplist.New(skiplist.String)

// #nosec
for s := range snm.senders {
orderedSenders.Set(s, s)
}

s := orderedSenders.Front()
for s != nil {
sender := s.Value.(string)
senders = append(senders, sender)
senderCursors[sender] = snm.senders[sender].Front()
s = s.Next()
}

iterator := &senderNonceMempoolIterator{
senders: senders,
rnd: snm.rnd,
senderCursors: senderCursors,
}

iter := iterator.Next()
for iter != nil {
if !callback(iter.Tx()) {
break
}
iter = iter.Next()
}
}

// CountTx returns the total count of txs in the mempool.
func (snm *SenderNonceMempool) CountTx() int {
snm.mtx.Lock()
Expand Down

0 comments on commit a4171e2

Please sign in to comment.