Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: return postCheck error to abci client #507

Merged
merged 15 commits into from
Nov 25, 2022
520 changes: 330 additions & 190 deletions abci/types/types.pb.go

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions consensus/replay_stubs.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func (emptyMempool) Update(
_ *types.Block,
_ []*abci.ResponseDeliverTx,
_ mempl.PreCheckFunc,
_ mempl.PostCheckFunc,
) error {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion mempool/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestCacheAfterUpdate(t *testing.T) {
updateTxs = append(updateTxs, tx)
}
err := mempool.Update(newTestBlock(int64(tcIndex), updateTxs),
abciResponses(len(updateTxs), abci.CodeTypeOK), nil)
abciResponses(len(updateTxs), abci.CodeTypeOK), nil, nil)
require.NoError(t, err)

for _, v := range tc.reAddIndices {
Expand Down
50 changes: 40 additions & 10 deletions mempool/clist_mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ type CListMempool struct {

chReqCheckTx chan *requestCheckTxAsync

postCheck PostCheckFunc

wal *auto.AutoFile // a log of mempool txs
txs *clist.CList // concurrent linked-list of good txs
proxyAppConn proxy.AppConnMempool
Expand Down Expand Up @@ -131,6 +133,13 @@ func WithPreCheck(f PreCheckFunc) CListMempoolOption {
return func(mem *CListMempool) { mem.preCheck = f }
}

// WithPostCheck sets a filter for the mempool to reject a tx if f(tx) returns
// false. This is ran after CheckTx. Only applies to the first created block.
// After that, Update overwrites the existing value.
func WithPostCheck(f PostCheckFunc) CListMempoolOption {
return func(mem *CListMempool) { mem.postCheck = f }
}

// WithMetrics sets the metrics.
func WithMetrics(metrics *Metrics) CListMempoolOption {
return func(mem *CListMempool) { mem.metrics = metrics }
Expand Down Expand Up @@ -542,23 +551,40 @@ func (mem *CListMempool) resCbFirstTime(
func (mem *CListMempool) resCbRecheck(req *abci.Request, res *abci.Response) {
switch r := res.Value.(type) {
case *abci.Response_CheckTx:
tx := req.GetCheckTx().Tx
txHash := TxKey(tx)
e, ok := mem.txsMap.Load(txHash)

tnasu marked this conversation as resolved.
Show resolved Hide resolved
if r.CheckTx.Code == abci.CodeTypeOK {
// Good, nothing to do.
if !ok {
panic(fmt.Sprintf("Unexpected tx response from proxy during recheck\ntxHash=%s, tx=%X", txHash, tx))
tnasu marked this conversation as resolved.
Show resolved Hide resolved
}
if mem.postCheck == nil {
return
}
postCheckErr := mem.postCheck(tx, r.CheckTx)
if postCheckErr == nil {
return
}
celem := e.(*clist.CElement)
// Tx became invalidated due to newly committed block.
mem.logger.Debug("tx is no longer valid", "tx", txID(tx), "res", r, "err", postCheckErr)
// NOTE: we remove tx from the cache because it might be good later
mem.removeTx(tx, celem, !mem.config.KeepInvalidTxsInCache)
r.CheckTx.MempoolError = postCheckErr.Error()
} else {
tx := req.GetCheckTx().Tx
txHash := TxKey(tx)
if e, ok := mem.txsMap.Load(txHash); ok {
celem := e.(*clist.CElement)
// Tx became invalidated due to newly committed block.
mem.logger.Debug("tx is no longer valid", "tx", txID(tx), "res", r)
// NOTE: we remove tx from the cache because it might be good later
mem.removeTx(tx, celem, true)
} else {
if !ok {
mem.logger.Debug(
"re-CheckTx transaction does not exist",
"expected", types.Tx(tx),
)
return
}
celem := e.(*clist.CElement)
// Tx became invalidated due to newly committed block.
mem.logger.Debug("tx is no longer valid", "tx", txID(tx), "res", r)
// NOTE: we remove tx from the cache because it might be good later
mem.removeTx(tx, celem, true)
}
default:
// ignore other messages
Expand Down Expand Up @@ -678,6 +704,7 @@ func (mem *CListMempool) Update(
block *types.Block,
deliverTxResponses []*abci.ResponseDeliverTx,
preCheck PreCheckFunc,
postCheck PostCheckFunc,
) (err error) {
// Set height
mem.height = block.Height
Expand All @@ -686,6 +713,9 @@ func (mem *CListMempool) Update(
if preCheck != nil {
mem.preCheck = preCheck
}
if postCheck != nil {
mem.postCheck = postCheck
}

for i, tx := range block.Txs {
if deliverTxResponses[i].Code == abci.CodeTypeOK {
Expand Down
2 changes: 1 addition & 1 deletion mempool/clist_mempool_system_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func commitBlock(ctx context.Context, t *testing.T,
mem *CListMempool, block *types.Block, deliverTxResponses []*abci.ResponseDeliverTx) {
mem.Lock()
defer mem.Unlock()
err := mem.Update(block, deliverTxResponses, nil)
err := mem.Update(block, deliverTxResponses, nil, nil)
require.NoError(t, err)
}

Expand Down
76 changes: 66 additions & 10 deletions mempool/clist_mempool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io/ioutil"
mrand "math/rand"
"os"
"path/filepath"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -173,7 +175,7 @@ func TestMempoolFilters(t *testing.T) {
}
for tcIndex, tt := range tests {
err := mempool.Update(newTestBlock(1, emptyTxArr),
abciResponses(len(emptyTxArr), abci.CodeTypeOK), tt.preFilter)
abciResponses(len(emptyTxArr), abci.CodeTypeOK), tt.preFilter, nil)
require.NoError(t, err)
checkTxs(t, mempool, tt.numTxsToCreate, UnknownPeerID)
require.Equal(t, tt.expectedNumTxs, mempool.Size(), "mempool had the incorrect size, on test case %d", tcIndex)
Expand All @@ -190,7 +192,7 @@ func TestMempoolUpdate(t *testing.T) {
// 1. Adds valid txs to the cache
{
err := mempool.Update(newTestBlock(1, []types.Tx{[]byte{0x01}}),
abciResponses(1, abci.CodeTypeOK), nil)
abciResponses(1, abci.CodeTypeOK), nil, nil)
require.NoError(t, err)
_, err = mempool.CheckTxSync([]byte{0x01}, TxInfo{})
if assert.Error(t, err) {
Expand All @@ -202,7 +204,7 @@ func TestMempoolUpdate(t *testing.T) {
{
_, err := mempool.CheckTxSync([]byte{0x02}, TxInfo{})
require.NoError(t, err)
err = mempool.Update(newTestBlock(1, []types.Tx{[]byte{0x02}}), abciResponses(1, abci.CodeTypeOK), nil)
err = mempool.Update(newTestBlock(1, []types.Tx{[]byte{0x02}}), abciResponses(1, abci.CodeTypeOK), nil, nil)
require.NoError(t, err)
assert.Zero(t, mempool.Size())
}
Expand All @@ -211,7 +213,7 @@ func TestMempoolUpdate(t *testing.T) {
{
_, err := mempool.CheckTxSync([]byte{0x03}, TxInfo{})
require.NoError(t, err)
err = mempool.Update(newTestBlock(1, []types.Tx{[]byte{0x03}}), abciResponses(1, 1), nil)
err = mempool.Update(newTestBlock(1, []types.Tx{[]byte{0x03}}), abciResponses(1, 1), nil, nil)
require.NoError(t, err)
assert.Zero(t, mempool.Size())

Expand Down Expand Up @@ -243,7 +245,7 @@ func TestMempool_KeepInvalidTxsInCache(t *testing.T) {
_ = app.DeliverTx(abci.RequestDeliverTx{Tx: a})
_ = app.DeliverTx(abci.RequestDeliverTx{Tx: b})
err = mempool.Update(newTestBlock(1, []types.Tx{a, b}),
[]*abci.ResponseDeliverTx{{Code: abci.CodeTypeOK}, {Code: 2}}, nil)
[]*abci.ResponseDeliverTx{{Code: abci.CodeTypeOK}, {Code: 2}}, nil, nil)
require.NoError(t, err)

// a must be added to the cache
Expand Down Expand Up @@ -299,7 +301,7 @@ func TestTxsAvailable(t *testing.T) {
// since there are still txs left
committedTxs, txs := txs[:50], txs[50:]
if err := mempool.Update(newTestBlock(1, committedTxs),
abciResponses(len(committedTxs), abci.CodeTypeOK), nil); err != nil {
abciResponses(len(committedTxs), abci.CodeTypeOK), nil, nil); err != nil {
t.Error(err)
}
ensureFire(t, mempool.TxsAvailable(), timeoutMS)
Expand All @@ -312,7 +314,7 @@ func TestTxsAvailable(t *testing.T) {
// now call update with all the txs. it should not fire as there are no txs left
committedTxs = append(txs, moreTxs...) // nolint: gocritic
if err := mempool.Update(newTestBlock(2, committedTxs),
abciResponses(len(committedTxs), abci.CodeTypeOK), nil); err != nil {
abciResponses(len(committedTxs), abci.CodeTypeOK), nil, nil); err != nil {
t.Error(err)
}
ensureNoFire(t, mempool.TxsAvailable(), timeoutMS)
Expand Down Expand Up @@ -372,7 +374,7 @@ func TestSerialReap(t *testing.T) {
txs = append(txs, txBytes)
}
if err := mempool.Update(newTestBlock(0, txs),
abciResponses(len(txs), abci.CodeTypeOK), nil); err != nil {
abciResponses(len(txs), abci.CodeTypeOK), nil, nil); err != nil {
t.Error(err)
}
}
Expand Down Expand Up @@ -544,7 +546,7 @@ func TestMempoolTxsBytes(t *testing.T) {

// 3. zero again after tx is removed by Update
err = mempool.Update(newTestBlock(1, []types.Tx{[]byte{0x01}}),
abciResponses(1, abci.CodeTypeOK), nil)
abciResponses(1, abci.CodeTypeOK), nil, nil)
require.NoError(t, err)
assert.EqualValues(t, 0, mempool.TxsBytes())

Expand Down Expand Up @@ -594,7 +596,7 @@ func TestMempoolTxsBytes(t *testing.T) {
require.NotEmpty(t, res2.Data)

// Pretend like we committed nothing so txBytes gets rechecked and removed.
err = mempool.Update(newTestBlock(1, []types.Tx{}), abciResponses(0, abci.CodeTypeOK), nil)
err = mempool.Update(newTestBlock(1, []types.Tx{}), abciResponses(0, abci.CodeTypeOK), nil, nil)
require.NoError(t, err)
assert.EqualValues(t, 0, mempool.TxsBytes())

Expand Down Expand Up @@ -699,3 +701,57 @@ func abciResponses(n int, code uint32) []*abci.ResponseDeliverTx {
}
return responses
}

func TestTxMempoolPostCheckError(t *testing.T) {
cases := []struct {
name string
err error
}{
{
name: "error",
err: errors.New("test error"),
},
{
name: "no error",
err: nil,
},
}
for _, tc := range cases {
testCase := tc
t.Run(testCase.name, func(t *testing.T) {
app := kvstore.NewApplication()
cc := proxy.NewLocalClientCreator(app)
mempool, cleanup := newMempoolWithApp(cc)
defer cleanup()

mempool.postCheck = func(_ types.Tx, _ *abci.ResponseCheckTx) error {
return testCase.err
}

tx := types.Tx{1}
_, err := mempool.CheckTxSync(tx, TxInfo{})
require.NoError(t, err)

req := abci.RequestCheckTx{
Tx: tx,
Type: abci.CheckTxType_Recheck,
}
res := &abci.Response{}

m := sync.Mutex{}
m.Lock()
mempool.proxyAppConn.CheckTxAsync(req, func(r *abci.Response) {
res = r
m.Unlock()
})

checkTxRes, ok := res.Value.(*abci.Response_CheckTx)
require.True(t, ok)
expectedErrString := ""
if testCase.err != nil {
expectedErrString = testCase.err.Error()
}
require.Equal(t, expectedErrString, checkTxRes.CheckTx.MempoolError)
})
}
}
25 changes: 25 additions & 0 deletions mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type Mempool interface {
block *types.Block,
deliverTxResponses []*abci.ResponseDeliverTx,
newPreFn PreCheckFunc,
newPostFn PostCheckFunc,
) error

// FlushAppConn flushes the mempool connection to ensure async reqResCb calls are
Expand Down Expand Up @@ -87,6 +88,11 @@ type Mempool interface {
// transaction doesn't exceeded the block size.
type PreCheckFunc func(types.Tx) error

// PostCheckFunc is an optional filter executed after CheckTx and rejects
// transaction if false is returned. An example would be to ensure a
// transaction doesn't require more gas than available for the block.
type PostCheckFunc func(types.Tx, *abci.ResponseCheckTx) error

// TxInfo are parameters that get passed when attempting to add a tx to the
// mempool.
type TxInfo struct {
Expand All @@ -111,3 +117,22 @@ func PreCheckMaxBytes(maxBytes int64) PreCheckFunc {
return nil
}
}

// PostCheckMaxGas checks that the wanted gas is smaller or equal to the passed
// maxGas. Returns nil if maxGas is -1.
func PostCheckMaxGas(maxGas int64) PostCheckFunc {
return func(tx types.Tx, res *abci.ResponseCheckTx) error {
if maxGas == -1 {
return nil
}
if res.GasWanted < 0 {
return fmt.Errorf("gas wanted %d is negative",
res.GasWanted)
}
if res.GasWanted > maxGas {
return fmt.Errorf("gas wanted %d is greater than max gas %d",
res.GasWanted, maxGas)
}
return nil
}
}
29 changes: 29 additions & 0 deletions mempool/mempool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package mempool

import (
"testing"

abci "github.com/line/ostracon/abci/types"
"github.com/stretchr/testify/require"
)

func TestPostCheckMaxGas(t *testing.T) {
tests := []struct {
res *abci.ResponseCheckTx
postCheck PostCheckFunc
ok bool
}{
{&abci.ResponseCheckTx{GasWanted: 10}, PostCheckMaxGas(10), true},
{&abci.ResponseCheckTx{GasWanted: 10}, PostCheckMaxGas(-1), true},
{&abci.ResponseCheckTx{GasWanted: -1}, PostCheckMaxGas(10), false},
{&abci.ResponseCheckTx{GasWanted: 11}, PostCheckMaxGas(10), false},
}
for tcIndex, tt := range tests {
err := tt.postCheck(nil, tt.res)
if tt.ok {
require.NoError(t, err, "postCheck should not return error, on test case %d", tcIndex)
} else {
require.Error(t, err, "postCheck should return error, on test case %d", tcIndex)
}
}
}
1 change: 1 addition & 0 deletions mempool/mock/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func (Mempool) Update(
_ *types.Block,
_ []*abci.ResponseDeliverTx,
_ mempl.PreCheckFunc,
_ mempl.PostCheckFunc,
) error {
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions mempool/reactor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func TestReactorConcurrency(t *testing.T) {
for i := range txs {
deliverTxResponses[i] = &abci.ResponseDeliverTx{Code: 0}
}
err := reactors[0].mempool.Update(newTestBlock(1, txs), deliverTxResponses, nil)
err := reactors[0].mempool.Update(newTestBlock(1, txs), deliverTxResponses, nil, nil)
assert.NoError(t, err)
}()

Expand All @@ -120,7 +120,7 @@ func TestReactorConcurrency(t *testing.T) {
reactors[1].mempool.Lock()
defer reactors[1].mempool.Unlock()
err := reactors[1].mempool.Update(newTestBlock(1, []types.Tx{}),
make([]*abci.ResponseDeliverTx, 0), nil)
make([]*abci.ResponseDeliverTx, 0), nil, nil)
assert.NoError(t, err)
}()

Expand Down
1 change: 1 addition & 0 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ func createMempoolAndMempoolReactor(config *cfg.Config, proxyApp proxy.AppConns,
state.LastBlockHeight,
mempl.WithMetrics(memplMetrics),
mempl.WithPreCheck(sm.TxPreCheck(state)),
mempl.WithPostCheck(sm.TxPostCheck(state)),
)
mempoolLogger := logger.With("module", "mempool")
mempoolReactor := mempl.NewReactor(config.Mempool, config.P2P.RecvAsync, config.P2P.MempoolRecvBufSize, mempool)
Expand Down
Loading