diff --git a/.golangci.yml b/.golangci.yml index f1b4ea3df5..f3de207048 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -104,11 +104,13 @@ linters-settings: - 'errors.Wrap' gomoddirectives: + replace-local: true replace-allow-list: # See go.mod for the explanation why these are needed. - github.com/ulikunitz/xz - github.com/gogo/protobuf - google.golang.org/protobuf + - github.com/lightningnetwork/lnd/sqldb linters: diff --git a/build/version.go b/build/version.go index f0face37e2..b3bcfb611d 100644 --- a/build/version.go +++ b/build/version.go @@ -47,7 +47,7 @@ const ( // AppPreRelease MUST only contain characters from semanticAlphabet per // the semantic versioning spec. - AppPreRelease = "beta.rc2" + AppPreRelease = "beta.rc3" ) func init() { diff --git a/channeldb/invoices.go b/channeldb/invoices.go index df124b632a..9da504a5d8 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -269,7 +269,9 @@ func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) ( // For each key found, we'll look up the actual // invoice, then accumulate it into our return value. - invoice, err := fetchInvoice(invoiceKey, invoices) + invoice, err := fetchInvoice( + invoiceKey, invoices, nil, false, + ) if err != nil { return err } @@ -341,7 +343,9 @@ func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) ( // An invoice was found, retrieve the remainder of the invoice // body. - i, err := fetchInvoice(invoiceNum, invoices, setID) + i, err := fetchInvoice( + invoiceNum, invoices, []*invpkg.SetID{setID}, true, + ) if err != nil { return err } @@ -468,7 +472,7 @@ func (d *DB) FetchPendingInvoices(_ context.Context) ( return nil } - invoice, err := fetchInvoice(v, invoices) + invoice, err := fetchInvoice(v, invoices, nil, false) if err != nil { return err } @@ -526,7 +530,9 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) ( // characteristics for our query and returns the number of items // we have added to our set of invoices. accumulateInvoices := func(_, indexValue []byte) (bool, error) { - invoice, err := fetchInvoice(indexValue, invoices) + invoice, err := fetchInvoice( + indexValue, invoices, nil, false, + ) if err != nil { return false, err } @@ -654,7 +660,9 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef, if setIDHint != nil { invSetID = *setIDHint } - invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID) + invoice, err := fetchInvoice( + invoiceNum, invoices, []*invpkg.SetID{&invSetID}, false, + ) if err != nil { return err } @@ -676,8 +684,17 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef, updatedInvoice, err = invpkg.UpdateInvoice( payHash, updater.invoice, now, callback, updater, ) + if err != nil { + return err + } - return err + // If this is an AMP update, then limit the returned AMP state + // to only the requested set ID. + if setIDHint != nil { + filterInvoiceAMPState(updatedInvoice, &invSetID) + } + + return nil }, func() { updatedInvoice = nil }) @@ -685,6 +702,25 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef, return updatedInvoice, err } +// filterInvoiceAMPState filters the AMP state of the invoice to only include +// state for the specified set IDs. +func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) { + filteredAMPState := make(invpkg.AMPInvoiceState) + + for _, setID := range setIDs { + if setID == nil { + return + } + + ampState, ok := invoice.AMPState[*setID] + if ok { + filteredAMPState[*setID] = ampState + } + } + + invoice.AMPState = filteredAMPState +} + // ampHTLCsMap is a map of AMP HTLCs affected by an invoice update. type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC @@ -1056,7 +1092,8 @@ func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) ( // For each key found, we'll look up the actual // invoice, then accumulate it into our return value. invoice, err := fetchInvoice( - invoiceKey[:], invoices, setID, + invoiceKey[:], invoices, []*invpkg.SetID{setID}, + true, ) if err != nil { return err @@ -1485,7 +1522,7 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte, // specified by the invoice number. If the setID fields are set, then only the // HTLC information pertaining to those set IDs is returned. func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, - setIDs ...*invpkg.SetID) (invpkg.Invoice, error) { + setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) { invoiceBytes := invoices.Get(invoiceNum) if invoiceBytes == nil { @@ -1518,6 +1555,10 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, log.Errorf("unable to fetch amp htlcs for inv "+ "%v and setIDs %v: %w", invoiceNum, setIDs, err) } + + if filterAMPState { + filterInvoiceAMPState(&invoice, setIDs...) + } } return invoice, nil @@ -2163,7 +2204,7 @@ func (d *DB) DeleteCanceledInvoices(_ context.Context) error { return nil } - invoice, err := fetchInvoice(v, invoices) + invoice, err := fetchInvoice(v, invoices, nil, false) if err != nil { return err } diff --git a/docs/release-notes/release-notes-0.18.3.md b/docs/release-notes/release-notes-0.18.3.md index 6389e26329..804f332a9a 100644 --- a/docs/release-notes/release-notes-0.18.3.md +++ b/docs/release-notes/release-notes-0.18.3.md @@ -253,6 +253,11 @@ that validate `ChannelAnnouncement` messages. our health checker to correctly shut down LND if network partitioning occurs towards the etcd cluster. +* [Fix](https://github.com/lightningnetwork/lnd/pull/9050) some inconsistencies + to make the native SQL invoice DB compatible with the KV implementation. + Furthermore fix a native SQL invoice issue where AMP subinvoice HTLCs are + sometimes updated incorrectly on settlement. + ## Code Health * [Move graph building and @@ -269,6 +274,7 @@ that validate `ChannelAnnouncement` messages. # Contributors (Alphabetical Order) +* Alex Akselrod * Andras Banki-Horvath * bitromortac * Bufo diff --git a/go.mod b/go.mod index 0e628d18b0..ee656cae08 100644 --- a/go.mod +++ b/go.mod @@ -204,6 +204,9 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2 // allows us to specify that as an option. replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.30.0-hex-display +// Temporary replace until the next version of sqldb is taged. +replace github.com/lightningnetwork/lnd/sqldb => ./sqldb + // If you change this please also update .github/pull_request_template.md, // docs/INSTALL.md and GO_IMAGE in lnrpc/gen_protos_docker.sh. go 1.21.4 diff --git a/go.sum b/go.sum index d556042dd3..cc4ce1bfd9 100644 --- a/go.sum +++ b/go.sum @@ -458,8 +458,6 @@ github.com/lightningnetwork/lnd/kvdb v1.4.10 h1:vK89IVv1oVH9ubQWU+EmoCQFeVRaC8kf github.com/lightningnetwork/lnd/kvdb v1.4.10/go.mod h1:J2diNABOoII9UrMnxXS5w7vZwP7CA1CStrl8MnIrb3A= github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI= github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4= -github.com/lightningnetwork/lnd/sqldb v1.0.3 h1:zLfAwOvM+6+3+hahYO9Q3h8pVV0TghAR7iJ5YMLCd3I= -github.com/lightningnetwork/lnd/sqldb v1.0.3/go.mod h1:4cQOkdymlZ1znnjuRNvMoatQGJkRneTj2CoPSPaQhWo= github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM= github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA= github.com/lightningnetwork/lnd/tlv v1.2.3 h1:If5ibokA/UoCBGuCKaY6Vn2SJU0l9uAbehCnhTZjEP8= diff --git a/invoices/sql_store.go b/invoices/sql_store.go index 4b488715ba..eb465eabb4 100644 --- a/invoices/sql_store.go +++ b/invoices/sql_store.go @@ -10,6 +10,7 @@ import ( "strconv" "time" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntypes" @@ -46,6 +47,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat GetInvoice(ctx context.Context, arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error) + GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice, + error) + GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]sqlc.InvoiceFeature, error) @@ -343,7 +347,22 @@ func (i *SQLStore) fetchInvoice(ctx context.Context, params.SetID = ref.SetID()[:] } - rows, err := db.GetInvoice(ctx, params) + var ( + rows []sqlc.Invoice + err error + ) + + // We need to split the query based on how we intend to look up the + // invoice. If only the set ID is given then we want to have an exact + // match on the set ID. If other fields are given, we want to match on + // those fields and the set ID but with a less strict join condition. + if params.Hash == nil && params.PaymentAddr == nil && + params.SetID != nil { + + rows, err = db.GetInvoiceBySetID(ctx, params.SetID) + } else { + rows, err = db.GetInvoice(ctx, params) + } switch { case len(rows) == 0: return nil, ErrInvoiceNotFound @@ -351,8 +370,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context, case len(rows) > 1: // In case the reference is ambiguous, meaning it matches more // than one invoice, we'll return an error. - return nil, fmt.Errorf("ambiguous invoice ref: %s", - ref.String()) + return nil, fmt.Errorf("ambiguous invoice ref: %s: %s", + ref.String(), spew.Sdump(rows)) case err != nil: return nil, fmt.Errorf("unable to fetch invoice: %w", err) @@ -906,8 +925,10 @@ func (i *SQLStore) QueryInvoices(ctx context.Context, } if q.CreationDateEnd != 0 { + // We need to add 1 to the end date as we're + // checking less than the end date in SQL. params.CreatedBefore = sqldb.SQLTime( - time.Unix(q.CreationDateEnd, 0).UTC(), + time.Unix(q.CreationDateEnd+1, 0).UTC(), ) } @@ -1116,6 +1137,9 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte, SetID: setID[:], HtlcID: int64(circuitKey.HtlcID), Preimage: preimage[:], + ChanID: strconv.FormatUint( + circuitKey.ChanID.ToUint64(), 10, + ), }, ) if err != nil { @@ -1280,6 +1304,13 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte, return err } + if settleIndex.Valid { + updatedState := s.invoice.AMPState[setID] + updatedState.SettleIndex = uint64(settleIndex.Int64) + updatedState.SettleDate = s.updateTime.UTC() + s.invoice.AMPState[setID] = updatedState + } + return nil } @@ -1298,13 +1329,24 @@ func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error { // invoice and is therefore atomic. The fields to update are controlled by the // supplied callback. func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef, - _ *SetID, callback InvoiceUpdateCallback) ( + setID *SetID, callback InvoiceUpdateCallback) ( *Invoice, error) { var updatedInvoice *Invoice txOpt := SQLInvoiceQueriesTxOptions{readOnly: false} txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error { + if setID != nil { + // Make sure to use the set ID if this is an AMP update. + var setIDBytes [32]byte + copy(setIDBytes[:], setID[:]) + ref.setID = &setIDBytes + + // If we're updating an AMP invoice, we'll also only + // need to fetch the HTLCs for the given set ID. + ref.refModifier = HtlcSetOnlyModifier + } + invoice, err := i.fetchInvoice(ctx, db, ref) if err != nil { return err diff --git a/itest/lnd_amp_test.go b/itest/lnd_amp_test.go index 23bfd8654d..4b4cfb5a29 100644 --- a/itest/lnd_amp_test.go +++ b/itest/lnd_amp_test.go @@ -260,7 +260,8 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) { invoiceNtfn := ht.ReceiveInvoiceUpdate(invSubscription) // The notification should signal that the invoice is now settled, and - // should also include the set ID, and show the proper amount paid. + // should also include the set ID, show the proper amount paid, and have + // the correct settle index and time. require.True(ht, invoiceNtfn.Settled) require.Equal(ht, lnrpc.Invoice_SETTLED, invoiceNtfn.State) require.Equal(ht, paymentAmt, int(invoiceNtfn.AmtPaidSat)) @@ -270,6 +271,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) { firstSetID, _ = hex.DecodeString(setIDStr) require.Equal(ht, lnrpc.InvoiceHTLCState_SETTLED, ampState.State) + require.GreaterOrEqual(ht, ampState.SettleTime, + rpcInvoice.CreationDate) + require.Equal(ht, uint64(1), ampState.SettleIndex) } // Pay the invoice again, we should get another notification that Dave @@ -299,9 +303,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) { // return the "projected" sub-invoice for a given setID. require.Equal(ht, 1, len(invoiceNtfn.Htlcs)) - // However the AMP state index should show that there've been two - // repeated payments to this invoice so far. - require.Equal(ht, 2, len(invoiceNtfn.AmpInvoiceState)) + // The AMP state should also be restricted to a single entry for the + // "projected" sub-invoice. + require.Equal(ht, 1, len(invoiceNtfn.AmpInvoiceState)) // Now we'll look up the invoice using the new LookupInvoice2 RPC call // by the set ID of each of the invoices. @@ -360,7 +364,7 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) { // through. backlogInv := ht.ReceiveInvoiceUpdate(invSub2) require.Equal(ht, 1, len(backlogInv.Htlcs)) - require.Equal(ht, 2, len(backlogInv.AmpInvoiceState)) + require.Equal(ht, 1, len(backlogInv.AmpInvoiceState)) require.True(ht, backlogInv.Settled) require.Equal(ht, paymentAmt*2, int(backlogInv.AmtPaidSat)) } diff --git a/lnwallet/channel.go b/lnwallet/channel.go index e3a777caab..737c1d92e9 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -3905,6 +3905,27 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { }, nil } +// resignMusigCommit is used to resign a commitment transaction for taproot +// channels when we need to retransmit a signature after a channel reestablish +// message. Taproot channels use musig2, which means we must use fresh nonces +// each time. After we receive the channel reestablish message, we learn the +// nonce we need to use for the remote party. As a result, we need to generate +// the partial signature again with the new nonce. +func (lc *LightningChannel) resignMusigCommit(commitTx *wire.MsgTx, +) (lnwire.OptPartialSigWithNonceTLV, error) { + + remoteSession := lc.musigSessions.RemoteSession + musig, err := remoteSession.SignCommit(commitTx) + if err != nil { + var none lnwire.OptPartialSigWithNonceTLV + return none, err + } + + partialSig := lnwire.MaybePartialSigWithNonce(musig.ToWireSig()) + + return partialSig, nil +} + // ProcessChanSyncMsg processes a ChannelReestablish message sent by the remote // connection upon re establishment of our connection with them. This method // will return a single message if we are currently out of sync, otherwise a @@ -4182,12 +4203,23 @@ func (lc *LightningChannel) ProcessChanSyncMsg( commitUpdates = append(commitUpdates, logUpdate.UpdateMsg) } + // If this is a taproot channel, then we need to regenerate the + // musig2 signature for the remote party, using their fresh + // nonce. + if lc.channelState.ChanType.IsTaproot() { + partialSig, err := lc.resignMusigCommit( + commitDiff.Commitment.CommitTx, + ) + if err != nil { + return nil, nil, nil, err + } + + commitDiff.CommitSig.PartialSig = partialSig + } + // With the batch of updates accumulated, we'll now re-send the // original CommitSig message required to re-sync their remote // commitment chain with our local version of their chain. - // - // TODO(roasbeef): need to re-sign commitment states w/ - // fresh nonce commitUpdates = append(commitUpdates, commitDiff.CommitSig) // NOTE: If a revocation is not owed, then updates is empty. diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index fcaf0fc08f..0281c59bd4 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -47,8 +47,8 @@ func createHTLC(id int, amount lnwire.MilliSatoshi) (*lnwire.UpdateAddHTLC, [32] } func assertOutputExistsByValue(t *testing.T, commitTx *wire.MsgTx, - value btcutil.Amount) { - + value btcutil.Amount, +) { for _, txOut := range commitTx.TxOut { if txOut.Value == int64(value) { return @@ -63,8 +63,8 @@ func assertOutputExistsByValue(t *testing.T, commitTx *wire.MsgTx, // add, the settle an HTLC between themselves. func testAddSettleWorkflow(t *testing.T, tweakless bool, chanTypeModifier channeldb.ChannelType, - storeFinalHtlcResolutions bool) { - + storeFinalHtlcResolutions bool, +) { // Create a test channel which will be used for the duration of this // unittest. The channel will be funded evenly with Alice having 5 BTC, // and Bob having 5 BTC. @@ -514,7 +514,6 @@ func TestCheckCommitTxSize(t *testing.T) { if 0 > diff || BaseCommitmentTxSizeEstimationError < diff { t.Fatalf("estimation is wrong, diff: %v", diff) } - } // Create a test channel which will be used for the duration of this @@ -1467,8 +1466,8 @@ func TestHTLCSigNumber(t *testing.T) { // createChanWithHTLC is a helper method that sets ut two channels, and // adds HTLCs with the passed values to the channels. createChanWithHTLC := func(htlcValues ...btcutil.Amount) ( - *LightningChannel, *LightningChannel) { - + *LightningChannel, *LightningChannel, + ) { // Create a test channel funded evenly with Alice having 5 BTC, // and Bob having 5 BTC. Alice's dustlimit is 200 sat, while // Bob has 1300 sat. @@ -2367,7 +2366,6 @@ func TestUpdateFeeFail(t *testing.T) { if err == nil { t.Fatalf("expected bob to fail receiving alice's signature") } - } // TestUpdateFeeConcurrentSig tests that the channel can properly handle a fee @@ -2547,7 +2545,6 @@ func TestUpdateFeeSenderCommits(t *testing.T) { // Bob receives revocation from Alice. _, _, _, _, err = bobChannel.ReceiveRevocation(aliceRevocation) require.NoError(t, err, "bob unable to process alice's revocation") - } // TestUpdateFeeReceiverCommits tests that the state machine progresses as @@ -2857,8 +2854,8 @@ func TestAddHTLCNegativeBalance(t *testing.T) { // two channels conclude that they're fully synchronized and don't need to // retransmit any new messages. func assertNoChanSyncNeeded(t *testing.T, aliceChannel *LightningChannel, - bobChannel *LightningChannel) { - + bobChannel *LightningChannel, +) { _, _, line, _ := runtime.Caller(1) aliceChanSyncMsg, err := aliceChannel.channelState.ChanSyncMsg() @@ -3007,19 +3004,11 @@ func restartChannel(channelOld *LightningChannel) (*LightningChannel, error) { return channelNew, nil } -// TestChanSyncOweCommitment tests that if Bob restarts (and then Alice) before -// he receives Alice's CommitSig message, then Alice concludes that she needs -// to re-send the CommitDiff. After the diff has been sent, both nodes should -// resynchronize and be able to complete the dangling commit. -func TestChanSyncOweCommitment(t *testing.T) { - t.Parallel() - +func testChanSyncOweCommitment(t *testing.T, chanType channeldb.ChannelType) { // Create a test channel which will be used for the duration of this // unittest. The channel will be funded evenly with Alice having 5 BTC, // and Bob having 5 BTC. - aliceChannel, bobChannel, err := CreateTestChannels( - t, channeldb.SingleFunderTweaklessBit, - ) + aliceChannel, bobChannel, err := CreateTestChannels(t, chanType) require.NoError(t, err, "unable to create test channels") var fakeOnionBlob [lnwire.OnionPacketSize]byte @@ -3094,6 +3083,15 @@ func TestChanSyncOweCommitment(t *testing.T) { aliceNewCommit, err := aliceChannel.SignNextCommitment() require.NoError(t, err, "unable to sign commitment") + // If this is a taproot channel, then we'll generate fresh verification + // nonce for both sides. + if chanType.IsTaproot() { + _, err = aliceChannel.GenMusigNonces() + require.NoError(t, err) + _, err = bobChannel.GenMusigNonces() + require.NoError(t, err) + } + // Bob doesn't get this message so upon reconnection, they need to // synchronize. Alice should conclude that she owes Bob a commitment, // while Bob should think he's properly synchronized. @@ -3105,7 +3103,7 @@ func TestChanSyncOweCommitment(t *testing.T) { // This is a helper function that asserts Alice concludes that she // needs to retransmit the exact commitment that we failed to send // above. - assertAliceCommitRetransmit := func() { + assertAliceCommitRetransmit := func() *lnwire.CommitSig { aliceMsgsToSend, _, _, err := aliceChannel.ProcessChanSyncMsg( bobSyncMsg, ) @@ -3170,12 +3168,25 @@ func TestChanSyncOweCommitment(t *testing.T) { len(commitSigMsg.HtlcSigs)) } for i, htlcSig := range commitSigMsg.HtlcSigs { - if htlcSig != aliceNewCommit.HtlcSigs[i] { + if !bytes.Equal(htlcSig.RawBytes(), + aliceNewCommit.HtlcSigs[i].RawBytes()) { + t.Fatalf("htlc sig msgs don't match: "+ - "expected %x got %x", - aliceNewCommit.HtlcSigs[i], htlcSig) + "expected %v got %v", + spew.Sdump(aliceNewCommit.HtlcSigs[i]), + spew.Sdump(htlcSig)) } } + + // If this is a taproot channel, then partial sig information + // should be present in the commit sig sent over. This + // signature will be re-regenerated, so we can't compare it + // with the old one. + if chanType.IsTaproot() { + require.True(t, commitSigMsg.PartialSig.IsSome()) + } + + return commitSigMsg } // Alice should detect that she needs to re-send 5 messages: the 3 @@ -3196,14 +3207,19 @@ func TestChanSyncOweCommitment(t *testing.T) { // send the exact same set of messages. aliceChannel, err = restartChannel(aliceChannel) require.NoError(t, err, "unable to restart alice") - assertAliceCommitRetransmit() - // TODO(roasbeef): restart bob as well??? + // To properly simulate a restart, we'll use the *new* signature that + // would send in an actual p2p setting. + aliceReCommitSig := assertAliceCommitRetransmit() // At this point, we should be able to resume the prior state update // without any issues, resulting in Alice settling the 3 htlc's, and // adding one of her own. - err = bobChannel.ReceiveNewCommitment(aliceNewCommit.CommitSigs) + err = bobChannel.ReceiveNewCommitment(&CommitSigs{ + CommitSig: aliceReCommitSig.CommitSig, + HtlcSigs: aliceReCommitSig.HtlcSigs, + PartialSig: aliceReCommitSig.PartialSig, + }) require.NoError(t, err, "bob unable to process alice's commitment") bobRevocation, _, _, err := bobChannel.RevokeCurrentCommitment() require.NoError(t, err, "unable to revoke bob commitment") @@ -3290,16 +3306,46 @@ func TestChanSyncOweCommitment(t *testing.T) { } } -// TestChanSyncOweCommitmentPendingRemote asserts that local updates are applied -// to the remote commit across restarts. -func TestChanSyncOweCommitmentPendingRemote(t *testing.T) { +// TestChanSyncOweCommitment tests that if Bob restarts (and then Alice) before +// he receives Alice's CommitSig message, then Alice concludes that she needs +// to re-send the CommitDiff. After the diff has been sent, both nodes should +// resynchronize and be able to complete the dangling commit. +func TestChanSyncOweCommitment(t *testing.T) { t.Parallel() + testCases := []struct { + name string + chanType channeldb.ChannelType + }{ + { + name: "tweakless", + chanType: channeldb.SingleFunderTweaklessBit, + }, + { + name: "anchors", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit, + }, + { + name: "taproot", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testChanSyncOweCommitment(t, tc.chanType) + }) + } +} + +func testChanSyncOweCommitmentPendingRemote(t *testing.T, + chanType channeldb.ChannelType, +) { // Create a test channel which will be used for the duration of this // unittest. - aliceChannel, bobChannel, err := CreateTestChannels( - t, channeldb.SingleFunderTweaklessBit, - ) + aliceChannel, bobChannel, err := CreateTestChannels(t, chanType) require.NoError(t, err, "unable to create test channels") var fakeOnionBlob [lnwire.OnionPacketSize]byte @@ -3382,6 +3428,12 @@ func TestChanSyncOweCommitmentPendingRemote(t *testing.T) { bobChannel, err = restartChannel(bobChannel) require.NoError(t, err, "unable to restart bob") + // If this is a taproot channel, then since Bob just restarted, we need + // to exchange nonces once again. + if chanType.IsTaproot() { + require.NoError(t, initMusigNonce(aliceChannel, bobChannel)) + } + // Bob signs the commitment he owes. bobNewCommit, err := bobChannel.SignNextCommitment() require.NoError(t, err, "unable to sign commitment") @@ -3407,6 +3459,38 @@ func TestChanSyncOweCommitmentPendingRemote(t *testing.T) { } } +// TestChanSyncOweCommitmentPendingRemote asserts that local updates are applied +// to the remote commit across restarts. +func TestChanSyncOweCommitmentPendingRemote(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + chanType channeldb.ChannelType + }{ + { + name: "tweakless", + chanType: channeldb.SingleFunderTweaklessBit, + }, + { + name: "anchors", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit, + }, + { + name: "taproot", + chanType: channeldb.SingleFunderTweaklessBit | + channeldb.AnchorOutputsBit | + channeldb.SimpleTaprootFeatureBit, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testChanSyncOweCommitmentPendingRemote(t, tc.chanType) + }) + } +} + // testChanSyncOweRevocation is the internal version of // TestChanSyncOweRevocation that is parameterized based on the type of channel // being used in the test. @@ -3556,8 +3640,6 @@ func testChanSyncOweRevocation(t *testing.T, chanType channeldb.ChannelType) { assertAliceOwesRevoke() - // TODO(roasbeef): restart bob too??? - // We'll continue by then allowing bob to process Alice's revocation // message. _, _, _, _, err = bobChannel.ReceiveRevocation(aliceRevocation) @@ -3606,11 +3688,19 @@ func TestChanSyncOweRevocation(t *testing.T) { testChanSyncOweRevocation(t, taprootBits) }) + t.Run("taproot", func(t *testing.T) { + taprootBits := channeldb.SimpleTaprootFeatureBit | + channeldb.AnchorOutputsBit | + channeldb.ZeroHtlcTxFeeBit | + channeldb.SingleFunderTweaklessBit + + testChanSyncOweRevocation(t, taprootBits) + }) } func testChanSyncOweRevocationAndCommit(t *testing.T, - chanType channeldb.ChannelType) { - + chanType channeldb.ChannelType, +) { // Create a test channel which will be used for the duration of this // unittest. The channel will be funded evenly with Alice having 5 BTC, // and Bob having 5 BTC. @@ -3735,6 +3825,14 @@ func testChanSyncOweRevocationAndCommit(t *testing.T, bobNewCommit.HtlcSigs[i]) } } + + // If this is a taproot channel, then partial sig information + // should be present in the commit sig sent over. This + // signature will be re-regenerated, so we can't compare it + // with the old one. + if chanType.IsTaproot() { + require.True(t, bobReCommitSigMsg.PartialSig.IsSome()) + } } // We expect Bob to send exactly two messages: first his revocation @@ -3794,8 +3892,8 @@ func TestChanSyncOweRevocationAndCommit(t *testing.T) { } func testChanSyncOweRevocationAndCommitForceTransition(t *testing.T, - chanType channeldb.ChannelType) { - + chanType channeldb.ChannelType, +) { // Create a test channel which will be used for the duration of this // unittest. The channel will be funded evenly with Alice having 5 BTC, // and Bob having 5 BTC. @@ -4803,8 +4901,8 @@ func TestChanAvailableBandwidth(t *testing.T) { ) assertBandwidthEstimateCorrect := func(aliceInitiate bool, - numNonDustHtlcsOnCommit lntypes.WeightUnit) { - + numNonDustHtlcsOnCommit lntypes.WeightUnit, + ) { // With the HTLC's added, we'll now query the AvailableBalance // method for the current available channel bandwidth from // Alice's PoV. @@ -4976,8 +5074,8 @@ func TestChanAvailableBalanceNearHtlcFee(t *testing.T) { // Helper method to check the current reported balance. checkBalance := func(t *testing.T, expBalanceAlice, - expBalanceBob lnwire.MilliSatoshi) { - + expBalanceBob lnwire.MilliSatoshi, + ) { t.Helper() aliceBalance := aliceChannel.AvailableBalance() if aliceBalance != expBalanceAlice { @@ -6248,8 +6346,8 @@ func TestMaxPendingAmount(t *testing.T) { } func assertChannelBalances(t *testing.T, alice, bob *LightningChannel, - aliceBalance, bobBalance btcutil.Amount) { - + aliceBalance, bobBalance btcutil.Amount, +) { _, _, line, _ := runtime.Caller(1) aliceSelfBalance := alice.channelState.LocalCommitment.LocalBalance.ToSatoshis() @@ -6940,7 +7038,8 @@ func assertInLog(t *testing.T, log *updateLog, numAdds, numFails int) { // assertInLogs asserts that the expected number of Adds and Fails occurs in // the local and remote update log of the given channel. func assertInLogs(t *testing.T, channel *LightningChannel, numAddsLocal, - numFailsLocal, numAddsRemote, numFailsRemote int) { + numFailsLocal, numAddsRemote, numFailsRemote int, +) { assertInLog(t, channel.localUpdateLog, numAddsLocal, numFailsLocal) assertInLog(t, channel.remoteUpdateLog, numAddsRemote, numFailsRemote) } @@ -6949,8 +7048,8 @@ func assertInLogs(t *testing.T, channel *LightningChannel, numAddsLocal, // state, and asserts that the new channel has had its logs restored to the // expected state. func restoreAndAssert(t *testing.T, channel *LightningChannel, numAddsLocal, - numFailsLocal, numAddsRemote, numFailsRemote int) { - + numFailsLocal, numAddsRemote, numFailsRemote int, +) { newChannel, err := NewLightningChannel( channel.Signer, channel.channelState, channel.sigPool, @@ -7211,8 +7310,8 @@ func TestChannelRestoreCommitHeight(t *testing.T) { // log after a restore. restoreAndAssertCommitHeights := func(t *testing.T, channel *LightningChannel, remoteLog bool, htlcIndex uint64, - expLocal, expRemote uint64) *LightningChannel { - + expLocal, expRemote uint64, + ) *LightningChannel { newChannel, err := NewLightningChannel( channel.Signer, channel.channelState, channel.sigPool, ) @@ -7667,10 +7766,9 @@ func TestIdealCommitFeeRate(t *testing.T) { // inputs fed to IdealCommitFeeRate. propertyTest := func(c *LightningChannel) func(ma maxAlloc, netFee, minRelayFee, maxAnchorFee fee) bool { - return func(ma maxAlloc, netFee, minRelayFee, - maxAnchorFee fee) bool { - + maxAnchorFee fee, + ) bool { idealFeeRate := c.IdealCommitFeeRate( chainfee.SatPerKWeight(netFee), chainfee.SatPerKWeight(minRelayFee), @@ -7715,8 +7813,8 @@ func TestIdealCommitFeeRate(t *testing.T) { // a channel is allowed to allocate to fees. It does not take a minimum // fee rate into account. maxFeeRate := func(c *LightningChannel, - maxFeeAlloc float64) chainfee.SatPerKWeight { - + maxFeeAlloc float64, + ) chainfee.SatPerKWeight { balance, weight := c.availableBalance(AdditionalHtlc) feeRate := c.localCommitChain.tip().feePerKw currentFee := feeRate.FeeForWeight(weight) @@ -7885,8 +7983,8 @@ func TestIdealCommitFeeRate(t *testing.T) { assertIdealFeeRate := func(c *LightningChannel, netFee, minRelay, maxAnchorCommit chainfee.SatPerKWeight, - maxFeeAlloc float64, expectedFeeRate chainfee.SatPerKWeight) { - + maxFeeAlloc float64, expectedFeeRate chainfee.SatPerKWeight, + ) { feeRate := c.IdealCommitFeeRate( netFee, minRelay, maxAnchorCommit, maxFeeAlloc, ) @@ -8582,8 +8680,8 @@ func TestEvaluateView(t *testing.T) { // checkExpectedHtlcs checks that a set of htlcs that we have contains all the // htlcs we expect. func checkExpectedHtlcs(t *testing.T, actual []*PaymentDescriptor, - expected map[uint64]bool) { - + expected map[uint64]bool, +) { if len(expected) != len(actual) { t.Fatalf("expected: %v htlcs, got: %v", len(expected), len(actual)) @@ -9178,13 +9276,11 @@ func TestProcessAddRemoveEntry(t *testing.T) { EntryType: test.updateType, } - var ( - // Start both parties off with an initial - // balance. Copy by value here so that we do - // not mutate the startBalance constant. - ourBalance, theirBalance = startBalance, - startBalance - ) + // Start both parties off with an initial + // balance. Copy by value here so that we do + // not mutate the startBalance constant. + ourBalance, theirBalance := startBalance, + startBalance // Choose the processing function we need based on the // update type. Process remove is used for settles, @@ -9710,8 +9806,8 @@ func TestIsChannelClean(t *testing.T) { // assertCleanOrDirty is a helper function that asserts that both channels are // clean if clean is true, and dirty if clean is false. func assertCleanOrDirty(clean bool, alice, bob *LightningChannel, - t *testing.T) { - + t *testing.T, +) { t.Helper() if clean { @@ -9749,8 +9845,8 @@ func testGetDustSum(t *testing.T, chantype channeldb.ChannelType) { // Use a function closure to assert the dust sum for a passed channel's // local and remote commitments match the expected values. checkDust := func(c *LightningChannel, expLocal, - expRemote lnwire.MilliSatoshi) { - + expRemote lnwire.MilliSatoshi, + ) { localDustSum := c.GetDustSum( lntypes.Local, fn.None[chainfee.SatPerKWeight](), ) @@ -9905,8 +10001,8 @@ func testGetDustSum(t *testing.T, chantype channeldb.ChannelType) { // deriveDummyRetributionParams is a helper function that derives a list of // dummy params to assist retribution creation related tests. func deriveDummyRetributionParams(chanState *channeldb.OpenChannel) (uint32, - *CommitmentKeyRing, chainhash.Hash) { - + *CommitmentKeyRing, chainhash.Hash, +) { config := chanState.RemoteChanCfg commitHash := chanState.RemoteCommitment.CommitTx.TxHash() keyRing := DeriveCommitmentKeys( @@ -10283,8 +10379,8 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // assertRetribution is a helper closure that checks a given breach // retribution has the expected values on certain fields. assertRetribution := func(br *BreachRetribution, - localIndex, remoteIndex uint32) { - + localIndex, remoteIndex uint32, + ) { require.Equal(t, txid, br.BreachTxHash) require.Equal(t, chainHash, br.ChainHash) require.Equal(t, breachHeight, br.BreachHeight) @@ -10399,8 +10495,8 @@ func TestExtractPayDescs(t *testing.T) { // assertPayDescMatchHTLC compares a PaymentDescriptor to a channeldb.HTLC and // asserts that the fields are matched. func assertPayDescMatchHTLC(t *testing.T, pd PaymentDescriptor, - htlc channeldb.HTLC) { - + htlc channeldb.HTLC, +) { require := require.New(t) require.EqualValues(htlc.RHash, pd.RHash, "RHash") diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index 0e5d527234..1e5af77031 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -419,6 +419,28 @@ func CreateTestChannels(t *testing.T, chanType channeldb.ChannelType, return channelAlice, channelBob, nil } +// initMusigNonce is used to manually setup musig2 nonces for a new channel, +// outside the normal chan-reest flow. +func initMusigNonce(chanA, chanB *LightningChannel) error { + chanANonces, err := chanA.GenMusigNonces() + if err != nil { + return err + } + chanBNonces, err := chanB.GenMusigNonces() + if err != nil { + return err + } + + if err := chanA.InitRemoteMusigNonces(chanBNonces); err != nil { + return err + } + if err := chanB.InitRemoteMusigNonces(chanANonces); err != nil { + return err + } + + return nil +} + // initRevocationWindows simulates a new channel being opened within the p2p // network by populating the initial revocation windows of the passed // commitment state machines. @@ -427,19 +449,7 @@ func initRevocationWindows(chanA, chanB *LightningChannel) error { // either FundingLocked or ChannelReestablish by calling // InitRemoteMusigNonces for both sides. if chanA.channelState.ChanType.IsTaproot() { - chanANonces, err := chanA.GenMusigNonces() - if err != nil { - return err - } - chanBNonces, err := chanB.GenMusigNonces() - if err != nil { - return err - } - - if err := chanA.InitRemoteMusigNonces(chanBNonces); err != nil { - return err - } - if err := chanB.InitRemoteMusigNonces(chanANonces); err != nil { + if err := initMusigNonce(chanA, chanB); err != nil { return err } } diff --git a/sqldb/sqlc/amp_invoices.sql.go b/sqldb/sqlc/amp_invoices.sql.go index 3fcfe4b27b..e47b1c803d 100644 --- a/sqldb/sqlc/amp_invoices.sql.go +++ b/sqldb/sqlc/amp_invoices.sql.go @@ -268,15 +268,16 @@ func (q *Queries) InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubI const updateAMPSubInvoiceHTLCPreimage = `-- name: UpdateAMPSubInvoiceHTLCPreimage :execresult UPDATE amp_sub_invoice_htlcs AS a -SET preimage = $4 +SET preimage = $5 WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = ( - SELECT id FROM invoice_htlcs AS i WHERE i.htlc_id = $3 + SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4 ) ` type UpdateAMPSubInvoiceHTLCPreimageParams struct { InvoiceID int64 SetID []byte + ChanID string HtlcID int64 Preimage []byte } @@ -285,6 +286,7 @@ func (q *Queries) UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context, arg Updat return q.db.ExecContext(ctx, updateAMPSubInvoiceHTLCPreimage, arg.InvoiceID, arg.SetID, + arg.ChanID, arg.HtlcID, arg.Preimage, ) diff --git a/sqldb/sqlc/invoices.sql.go b/sqldb/sqlc/invoices.sql.go index fde02391c6..9e31380abb 100644 --- a/sqldb/sqlc/invoices.sql.go +++ b/sqldb/sqlc/invoices.sql.go @@ -78,7 +78,7 @@ WHERE ( created_at >= $6 OR $6 IS NULL ) AND ( - created_at <= $7 OR + created_at < $7 OR $7 IS NULL ) AND ( CASE @@ -170,21 +170,22 @@ const getInvoice = `-- name: GetInvoice :many SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at FROM invoices i -LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id +LEFT JOIN amp_sub_invoices a +ON i.id = a.invoice_id +AND ( + a.set_id = $1 OR $1 IS NULL +) WHERE ( - i.id = $1 OR - $1 IS NULL -) AND ( - i.hash = $2 OR + i.id = $2 OR $2 IS NULL ) AND ( - i.preimage = $3 OR + i.hash = $3 OR $3 IS NULL ) AND ( - i.payment_addr = $4 OR + i.preimage = $4 OR $4 IS NULL ) AND ( - a.set_id = $5 OR + i.payment_addr = $5 OR $5 IS NULL ) GROUP BY i.id @@ -192,11 +193,11 @@ LIMIT 2 ` type GetInvoiceParams struct { + SetID []byte AddIndex sql.NullInt64 Hash []byte Preimage []byte PaymentAddr []byte - SetID []byte } // This method may return more than one invoice if filter using multiple fields @@ -204,11 +205,11 @@ type GetInvoiceParams struct { // we bubble up an error in those cases. func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) { rows, err := q.db.QueryContext(ctx, getInvoice, + arg.SetID, arg.AddIndex, arg.Hash, arg.Preimage, arg.PaymentAddr, - arg.SetID, ) if err != nil { return nil, err @@ -250,6 +251,55 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi return items, nil } +const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many +SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at +FROM invoices i +INNER JOIN amp_sub_invoices a +ON i.id = a.invoice_id AND a.set_id = $1 +` + +func (q *Queries) GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error) { + rows, err := q.db.QueryContext(ctx, getInvoiceBySetID, setID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Invoice + for rows.Next() { + var i Invoice + if err := rows.Scan( + &i.ID, + &i.Hash, + &i.Preimage, + &i.SettleIndex, + &i.SettledAt, + &i.Memo, + &i.AmountMsat, + &i.CltvDelta, + &i.Expiry, + &i.PaymentAddr, + &i.PaymentRequest, + &i.PaymentRequestHash, + &i.State, + &i.AmountPaidMsat, + &i.IsAmp, + &i.IsHodl, + &i.IsKeysend, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getInvoiceFeatures = `-- name: GetInvoiceFeatures :many SELECT feature, invoice_id FROM invoice_features diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index d55d8090a7..04b61c7007 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -21,6 +21,7 @@ type Querier interface { // from different invoices. It is the caller's responsibility to ensure that // we bubble up an error in those cases. GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) + GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error) GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error) GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error) GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error) diff --git a/sqldb/sqlc/queries/amp_invoices.sql b/sqldb/sqlc/queries/amp_invoices.sql index 3b6ee76ac3..1fad75e0da 100644 --- a/sqldb/sqlc/queries/amp_invoices.sql +++ b/sqldb/sqlc/queries/amp_invoices.sql @@ -61,7 +61,7 @@ WHERE ( -- name: UpdateAMPSubInvoiceHTLCPreimage :execresult UPDATE amp_sub_invoice_htlcs AS a -SET preimage = $4 +SET preimage = $5 WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = ( - SELECT id FROM invoice_htlcs AS i WHERE i.htlc_id = $3 + SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4 ); diff --git a/sqldb/sqlc/queries/invoices.sql b/sqldb/sqlc/queries/invoices.sql index 07c5ca418b..2a49553e65 100644 --- a/sqldb/sqlc/queries/invoices.sql +++ b/sqldb/sqlc/queries/invoices.sql @@ -26,7 +26,11 @@ WHERE invoice_id = $1; -- name: GetInvoice :many SELECT i.* FROM invoices i -LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id +LEFT JOIN amp_sub_invoices a +ON i.id = a.invoice_id +AND ( + a.set_id = sqlc.narg('set_id') OR sqlc.narg('set_id') IS NULL +) WHERE ( i.id = sqlc.narg('add_index') OR sqlc.narg('add_index') IS NULL @@ -39,13 +43,16 @@ WHERE ( ) AND ( i.payment_addr = sqlc.narg('payment_addr') OR sqlc.narg('payment_addr') IS NULL -) AND ( - a.set_id = sqlc.narg('set_id') OR - sqlc.narg('set_id') IS NULL ) GROUP BY i.id LIMIT 2; +-- name: GetInvoiceBySetID :many +SELECT i.* +FROM invoices i +INNER JOIN amp_sub_invoices a +ON i.id = a.invoice_id AND a.set_id = $1; + -- name: FilterInvoices :many SELECT invoices.* @@ -69,7 +76,7 @@ WHERE ( created_at >= sqlc.narg('created_after') OR sqlc.narg('created_after') IS NULL ) AND ( - created_at <= sqlc.narg('created_before') OR + created_at < sqlc.narg('created_before') OR sqlc.narg('created_before') IS NULL ) AND ( CASE