diff --git a/chain/messagesigner/messagesigner.go b/chain/messagesigner/messagesigner.go index 1ad83543b6f..ac94d6a3e1f 100644 --- a/chain/messagesigner/messagesigner.go +++ b/chain/messagesigner/messagesigner.go @@ -3,6 +3,7 @@ package messagesigner import ( "bytes" "context" + "sync" "github.com/filecoin-project/go-address" "github.com/filecoin-project/lotus/chain/messagepool" @@ -16,7 +17,7 @@ import ( "golang.org/x/xerrors" ) -const dsKeyActorNonce = "ActorNonce" +const dsKeyActorNonce = "ActorNextNonce" var log = logging.Logger("messagesigner") @@ -28,6 +29,7 @@ type mpoolAPI interface { // when signing a message type MessageSigner struct { wallet *wallet.Wallet + lk sync.Mutex mpool mpoolAPI ds datastore.Batching } @@ -47,25 +49,42 @@ func newMessageSigner(wallet *wallet.Wallet, mpool mpoolAPI, ds dtypes.MetadataD // SignMessage increments the nonce for the message From address, and signs // the message -func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message) (*types.SignedMessage, error) { +func (ms *MessageSigner) SignMessage(ctx context.Context, msg *types.Message, cb func(*types.SignedMessage) error) (*types.SignedMessage, error) { + ms.lk.Lock() + defer ms.lk.Unlock() + + // Get the next message nonce nonce, err := ms.nextNonce(msg.From) if err != nil { return nil, xerrors.Errorf("failed to create nonce: %w", err) } + // Sign the message with the nonce msg.Nonce = nonce sig, err := ms.wallet.Sign(ctx, msg.From, msg.Cid().Bytes()) if err != nil { return nil, xerrors.Errorf("failed to sign message: %w", err) } - return &types.SignedMessage{ + // Callback with the signed message + smsg := &types.SignedMessage{ Message: *msg, Signature: *sig, - }, nil + } + err = cb(smsg) + if err != nil { + return nil, err + } + + // If the callback executed successfully, write the nonce to the datastore + if err := ms.saveNonce(msg.From, nonce); err != nil { + return nil, xerrors.Errorf("failed to save nonce: %w", err) + } + + return smsg, nil } -// nextNonce increments the nonce. +// nextNonce gets the next nonce for the given address. // If there is no nonce in the datastore, gets the nonce from the message pool. func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) { // Nonces used to be created by the mempool and we need to support nodes @@ -77,21 +96,22 @@ func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) { return 0, xerrors.Errorf("failed to get nonce from mempool: %w", err) } - // Get the nonce for this address from the datastore - addrNonceKey := datastore.KeyWithNamespaces([]string{dsKeyActorNonce, addr.String()}) + // Get the next nonce for this address from the datastore + addrNonceKey := ms.dstoreKey(addr) dsNonceBytes, err := ms.ds.Get(addrNonceKey) switch { case xerrors.Is(err, datastore.ErrNotFound): // If a nonce for this address hasn't yet been created in the // datastore, just use the nonce from the mempool + return nonce, nil case err != nil: return 0, xerrors.Errorf("failed to get nonce from datastore: %w", err) default: - // There is a nonce in the datastore, so unmarshall and increment it - maj, val, err := cbg.CborReadHeader(bytes.NewReader(dsNonceBytes)) + // There is a nonce in the datastore, so unmarshall it + maj, dsNonce, err := cbg.CborReadHeader(bytes.NewReader(dsNonceBytes)) if err != nil { return 0, xerrors.Errorf("failed to parse nonce from datastore: %w", err) } @@ -99,26 +119,37 @@ func (ms *MessageSigner) nextNonce(addr address.Address) (uint64, error) { return 0, xerrors.Errorf("bad cbor type parsing nonce from datastore") } - dsNonce := val + 1 - // The message pool nonce should be <= than the datastore nonce if nonce <= dsNonce { nonce = dsNonce } else { log.Warnf("mempool nonce was larger than datastore nonce (%d > %d)", nonce, dsNonce) } + + return nonce, nil } +} + +// saveNonce increments the nonce for this address and writes it to the +// datastore +func (ms *MessageSigner) saveNonce(addr address.Address, nonce uint64) error { + // Increment the nonce + nonce++ - // Write the nonce for this address to the datastore + // Write the nonce to the datastore + addrNonceKey := ms.dstoreKey(addr) buf := bytes.Buffer{} - _, err = buf.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, nonce)) + _, err := buf.Write(cbg.CborEncodeMajorType(cbg.MajUnsignedInt, nonce)) if err != nil { - return 0, xerrors.Errorf("failed to marshall nonce: %w", err) + return xerrors.Errorf("failed to marshall nonce: %w", err) } err = ms.ds.Put(addrNonceKey, buf.Bytes()) if err != nil { - return 0, xerrors.Errorf("failed to write nonce to datastore: %w", err) + return xerrors.Errorf("failed to write nonce to datastore: %w", err) } + return nil +} - return nonce, nil +func (ms *MessageSigner) dstoreKey(addr address.Address) datastore.Key { + return datastore.KeyWithNamespaces([]string{dsKeyActorNonce, addr.String()}) } diff --git a/chain/messagesigner/messagesigner_test.go b/chain/messagesigner/messagesigner_test.go index 55676b25805..04869ff6dde 100644 --- a/chain/messagesigner/messagesigner_test.go +++ b/chain/messagesigner/messagesigner_test.go @@ -5,6 +5,8 @@ import ( "sync" "testing" + "golang.org/x/xerrors" + "github.com/filecoin-project/lotus/chain/wallet" "github.com/filecoin-project/go-state-types/crypto" @@ -58,6 +60,7 @@ func TestMessageSignerSignMessage(t *testing.T) { msg *types.Message mpoolNonce [1]uint64 expNonce uint64 + cbErr error } tests := []struct { name string @@ -137,6 +140,37 @@ func TestMessageSignerSignMessage(t *testing.T) { }, expNonce: 2, }}, + }, { + name: "recover from callback error", + msgs: []msgSpec{{ + // No nonce yet in datastore + msg: &types.Message{ + To: to1, + From: from1, + }, + expNonce: 0, + }, { + // Increment nonce + msg: &types.Message{ + To: to1, + From: from1, + }, + expNonce: 1, + }, { + // Callback returns error + msg: &types.Message{ + To: to1, + From: from1, + }, + cbErr: xerrors.Errorf("err"), + }, { + // Callback successful, should increment nonce in datastore + msg: &types.Message{ + To: to1, + From: from1, + }, + expNonce: 2, + }}, }} for _, tt := range tests { tt := tt @@ -149,9 +183,18 @@ func TestMessageSignerSignMessage(t *testing.T) { if len(m.mpoolNonce) == 1 { mpool.setNonce(m.msg.From, m.mpoolNonce[0]) } - smsg, err := ms.SignMessage(ctx, m.msg) - require.NoError(t, err) - require.Equal(t, m.expNonce, smsg.Message.Nonce) + merr := m.cbErr + smsg, err := ms.SignMessage(ctx, m.msg, func(message *types.SignedMessage) error { + return merr + }) + + if m.cbErr != nil { + require.Error(t, err) + require.Nil(t, smsg) + } else { + require.NoError(t, err) + require.Equal(t, m.expNonce, smsg.Message.Nonce) + } } }) } diff --git a/node/impl/full/mpool.go b/node/impl/full/mpool.go index e0dd3ecef00..1f093606c38 100644 --- a/node/impl/full/mpool.go +++ b/node/impl/full/mpool.go @@ -160,17 +160,13 @@ func (a *MpoolAPI) MpoolPushMessage(ctx context.Context, msg *types.Message, spe return nil, xerrors.Errorf("mpool push: not enough funds: %s < %s", b, msg.Value) } - smsg, err := a.MessageSigner.SignMessage(ctx, msg) - if err != nil { - return nil, xerrors.Errorf("mpool push: failed to sign message: %w", err) - } - - _, err = a.Mpool.Push(smsg) - if err != nil { - return nil, xerrors.Errorf("mpool push: failed to push message: %w", err) - } - - return smsg, err + // Sign and push the message + return a.MessageSigner.SignMessage(ctx, msg, func(smsg *types.SignedMessage) error { + if _, err := a.Mpool.Push(smsg); err != nil { + return xerrors.Errorf("mpool push: failed to push message: %w", err) + } + return nil + }) } func (a *MpoolAPI) MpoolGetNonce(ctx context.Context, addr address.Address) (uint64, error) {