From 639faeed6d8a30f3687d1850e4b0947dfb7ccb8f Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 21 Sep 2021 19:18:10 +0200 Subject: [PATCH 01/15] channeldb: use kvdb.Backend instead of channeldb.DB for the Graph --- channeldb/db.go | 2 +- channeldb/graph.go | 22 ++--- channeldb/graph_test.go | 173 ++++++++++++++++++---------------------- 3 files changed, 89 insertions(+), 108 deletions(-) diff --git a/channeldb/db.go b/channeldb/db.go index 57ebfdb243..2a33450cd2 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -269,7 +269,7 @@ func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB, dryRun: opts.dryRun, } chanDB.graph = newChannelGraph( - chanDB, opts.RejectCacheSize, opts.ChannelCacheSize, + backend, opts.RejectCacheSize, opts.ChannelCacheSize, opts.BatchCommitInterval, ) diff --git a/channeldb/graph.go b/channeldb/graph.go index 678b7ac06c..cb9268307d 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -174,7 +174,7 @@ const ( // independently. Edge removal results in the deletion of all edge information // for that edge. type ChannelGraph struct { - db *DB + db kvdb.Backend cacheMu sync.RWMutex rejectCache *rejectCache @@ -186,7 +186,7 @@ type ChannelGraph struct { // newChannelGraph allocates a new ChannelGraph backed by a DB instance. The // returned instance has its own unique reject cache and channel cache. -func newChannelGraph(db *DB, rejectCacheSize, chanCacheSize int, +func newChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, batchCommitInterval time.Duration) *ChannelGraph { g := &ChannelGraph{ @@ -195,17 +195,17 @@ func newChannelGraph(db *DB, rejectCacheSize, chanCacheSize int, chanCache: newChannelCache(chanCacheSize), } g.chanScheduler = batch.NewTimeScheduler( - db.Backend, &g.cacheMu, batchCommitInterval, + db, &g.cacheMu, batchCommitInterval, ) g.nodeScheduler = batch.NewTimeScheduler( - db.Backend, nil, batchCommitInterval, + db, nil, batchCommitInterval, ) return g } // Database returns a pointer to the underlying database. -func (c *ChannelGraph) Database() *DB { +func (c *ChannelGraph) Database() kvdb.Backend { return c.db } @@ -2232,7 +2232,7 @@ type LightningNode struct { // compatible manner. ExtraOpaqueData []byte - db *DB + db kvdb.Backend // TODO(roasbeef): discovery will need storage to keep it's last IP // address and re-announce if interface changes? @@ -2460,7 +2460,7 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro // nodeTraversal is used to traverse all channels of a node given by its // public key and passes channel information into the specified callback. -func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB, +func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { traversal := func(tx kvdb.RTx) error { @@ -2627,7 +2627,7 @@ type ChannelEdgeInfo struct { // compatible manner. ExtraOpaqueData []byte - db *DB + db kvdb.Backend } // AddNodeKeys is a setter-like method that can be used to replace the set of @@ -2988,7 +2988,7 @@ type ChannelEdgePolicy struct { // compatible manner. 
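Worth noting what this first commit buys before the diff continues: with the graph holding a bare kvdb.Backend, graph-level helpers no longer need the full *channeldb.DB handle. A minimal sketch under that assumption — the bucket and error names (nodeBucket, ErrGraphNotFound) are the ones used elsewhere in graph.go, but the helper itself is hypothetical and not part of the patch:

```go
// countGraphNodes runs a read-only transaction against any kvdb.Backend
// and counts the entries in the graph's node bucket.
func countGraphNodes(db kvdb.Backend) (int, error) {
	var count int
	err := kvdb.View(db, func(tx kvdb.RTx) error {
		nodes := tx.ReadBucket(nodeBucket)
		if nodes == nil {
			return ErrGraphNotFound
		}
		return nodes.ForEach(func(_, _ []byte) error {
			count++
			return nil
		})
	}, func() {
		// Reset callback: clear captured state in case the
		// closure is retried.
		count = 0
	})
	return count, err
}
```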
ExtraOpaqueData []byte - db *DB + db kvdb.Backend } // Signature is a channel announcement signature, which is needed for proper @@ -3406,7 +3406,7 @@ func (c *ChannelGraph) MarkEdgeZombie(chanID uint64, c.cacheMu.Lock() defer c.cacheMu.Unlock() - err := kvdb.Batch(c.db.Backend, func(tx kvdb.RwTx) error { + err := kvdb.Batch(c.db, func(tx kvdb.RwTx) error { edges := tx.ReadWriteBucket(edgeBucket) if edges == nil { return ErrGraphNoEdgesFound @@ -4102,7 +4102,7 @@ func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte, func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, nodes kvdb.RBucket, chanID []byte, - db *DB) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) { + db kvdb.Backend) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) { edgeInfo := edgeIndex.Get(chanID) if edgeInfo == nil { diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index ccd5f379ea..6d04429732 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -45,7 +45,7 @@ var ( testPub = route.Vertex{2, 202, 4} ) -func createLightningNode(db *DB, priv *btcec.PrivateKey) (*LightningNode, error) { +func createLightningNode(db kvdb.Backend, priv *btcec.PrivateKey) (*LightningNode, error) { updateTime := prand.Int63() pub := priv.PubKey().SerializeCompressed() @@ -64,7 +64,7 @@ func createLightningNode(db *DB, priv *btcec.PrivateKey) (*LightningNode, error) return n, nil } -func createTestVertex(db *DB) (*LightningNode, error) { +func createTestVertex(db kvdb.Backend) (*LightningNode, error) { priv, err := btcec.NewPrivateKey(btcec.S256()) if err != nil { return nil, err @@ -96,7 +96,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { Addresses: testAddrs, ExtraOpaqueData: []byte("extra new data"), PubKeyBytes: testPub, - db: db, + db: graph.db, } // First, insert the node into the graph DB. This should succeed @@ -180,7 +180,7 @@ func TestPartialNode(t *testing.T) { HaveNodeAnnouncement: false, LastUpdate: time.Unix(0, 0), PubKeyBytes: testPub, - db: db, + db: graph.db, } if err := compareNodes(node, dbNode); err != nil { @@ -214,7 +214,7 @@ func TestAliasLookup(t *testing.T) { // We'd like to test the alias index within the database, so first // create a new test node. - testNode, err := createTestVertex(db) + testNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -241,7 +241,7 @@ func TestAliasLookup(t *testing.T) { } // Ensure that looking up a non-existent alias results in an error. - node, err := createTestVertex(db) + node, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -268,7 +268,7 @@ func TestSourceNode(t *testing.T) { // We'd like to test the setting/getting of the source node, so we // first create a fake node to use within the test. - testNode, err := createTestVertex(db) + testNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -309,11 +309,11 @@ func TestEdgeInsertionDeletion(t *testing.T) { // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. 
- node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -441,7 +441,7 @@ func TestDisconnectBlockAtHeight(t *testing.T) { } graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) + sourceNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create source node: %v", err) } @@ -451,11 +451,11 @@ func TestDisconnectBlockAtHeight(t *testing.T) { // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -641,7 +641,7 @@ func assertEdgeInfoEqual(t *testing.T, e1 *ChannelEdgeInfo, } } -func createChannelEdge(db *DB, node1, node2 *LightningNode) (*ChannelEdgeInfo, +func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) { var ( @@ -731,14 +731,14 @@ func TestEdgeInfoUpdates(t *testing.T) { // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -747,7 +747,7 @@ func TestEdgeInfoUpdates(t *testing.T) { } // Create an edge and add it to the db. - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) // Make sure inserting the policy at this point, before the edge info // is added, will fail. @@ -825,13 +825,13 @@ func TestEdgeInfoUpdates(t *testing.T) { assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) } -func randEdgePolicy(chanID uint64, op wire.OutPoint, db *DB) *ChannelEdgePolicy { +func randEdgePolicy(chanID uint64, db kvdb.Backend) *ChannelEdgePolicy { update := prand.Int63() - return newEdgePolicy(chanID, op, db, update) + return newEdgePolicy(chanID, db, update) } -func newEdgePolicy(chanID uint64, op wire.OutPoint, db *DB, +func newEdgePolicy(chanID uint64, db kvdb.Backend, updateTime int64) *ChannelEdgePolicy { return &ChannelEdgePolicy{ @@ -866,7 +866,7 @@ func TestGraphTraversal(t *testing.T) { nodes := make([]*LightningNode, numNodes) nodeIndex := map[string]struct{}{} for i := 0; i < numNodes; i++ { - node, err := createTestVertex(db) + node, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create node: %v", err) } @@ -941,7 +941,7 @@ func TestGraphTraversal(t *testing.T) { // Create and add an edge with random data that points from // node1 -> node2. - edge := randEdgePolicy(chanID, op, db) + edge := randEdgePolicy(chanID, graph.db) edge.ChannelFlags = 0 edge.Node = secondNode edge.SigBytes = testSig.Serialize() @@ -951,7 +951,7 @@ func TestGraphTraversal(t *testing.T) { // Create another random edge that points from node2 -> node1 // this time. 
- edge = randEdgePolicy(chanID, op, db) + edge = randEdgePolicy(chanID, graph.db) edge.ChannelFlags = 1 edge.Node = firstNode edge.SigBytes = testSig.Serialize() @@ -1119,7 +1119,7 @@ func TestGraphPruning(t *testing.T) { } graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) + sourceNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create source node: %v", err) } @@ -1133,7 +1133,7 @@ func TestGraphPruning(t *testing.T) { const numNodes = 5 graphNodes := make([]*LightningNode, numNodes) for i := 0; i < numNodes; i++ { - node, err := createTestVertex(db) + node, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create node: %v", err) } @@ -1192,7 +1192,7 @@ func TestGraphPruning(t *testing.T) { // Create and add an edge with random data that points from // node_i -> node_i+1 - edge := randEdgePolicy(chanID, op, db) + edge := randEdgePolicy(chanID, graph.db) edge.ChannelFlags = 0 edge.Node = graphNodes[i] edge.SigBytes = testSig.Serialize() @@ -1202,7 +1202,7 @@ func TestGraphPruning(t *testing.T) { // Create another random edge that points from node_i+1 -> // node_i this time. - edge = randEdgePolicy(chanID, op, db) + edge = randEdgePolicy(chanID, graph.db) edge.ChannelFlags = 1 edge.Node = graphNodes[i] edge.SigBytes = testSig.Serialize() @@ -1341,11 +1341,11 @@ func TestHighestChanID(t *testing.T) { // Next, we'll insert two channels into the database, with each channel // connecting the same two nodes. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -1419,14 +1419,14 @@ func TestChanUpdatesInHorizon(t *testing.T) { } // We'll start by creating two nodes which will seed our test graph. 
- node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -1441,12 +1441,6 @@ func TestChanUpdatesInHorizon(t *testing.T) { endTime := startTime edges := make([]ChannelEdge, 0, numChans) for i := 0; i < numChans; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - channel, chanID := createEdge( uint32(i*10), 0, 0, 0, node1, node2, ) @@ -1460,7 +1454,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { endTime = endTime.Add(time.Second * 10) edge1 := newEdgePolicy( - chanID.ToUint64(), op, db, edge1UpdateTime.Unix(), + chanID.ToUint64(), graph.db, edge1UpdateTime.Unix(), ) edge1.ChannelFlags = 0 edge1.Node = node2 @@ -1470,7 +1464,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { } edge2 := newEdgePolicy( - chanID.ToUint64(), op, db, edge2UpdateTime.Unix(), + chanID.ToUint64(), graph.db, edge2UpdateTime.Unix(), ) edge2.ChannelFlags = 1 edge2.Node = node1 @@ -1602,7 +1596,7 @@ func TestNodeUpdatesInHorizon(t *testing.T) { const numNodes = 10 nodeAnns := make([]LightningNode, 0, numNodes) for i := 0; i < numNodes; i++ { - nodeAnn, err := createTestVertex(db) + nodeAnn, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test vertex: %v", err) } @@ -1716,14 +1710,14 @@ func TestFilterKnownChanIDs(t *testing.T) { } // We'll start by creating two nodes which will seed our test graph. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -1823,14 +1817,14 @@ func TestFilterChannelRange(t *testing.T) { // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -1957,14 +1951,14 @@ func TestFetchChanInfos(t *testing.T) { // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. 
- node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -1980,12 +1974,6 @@ func TestFetchChanInfos(t *testing.T) { edges := make([]ChannelEdge, 0, numChans) edgeQuery := make([]uint64, 0, numChans) for i := 0; i < numChans; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - channel, chanID := createEdge( uint32(i*10), 0, 0, 0, node1, node2, ) @@ -1998,7 +1986,7 @@ func TestFetchChanInfos(t *testing.T) { endTime = updateTime.Add(time.Second * 10) edge1 := newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), + chanID.ToUint64(), graph.db, updateTime.Unix(), ) edge1.ChannelFlags = 0 edge1.Node = node2 @@ -2008,7 +1996,7 @@ func TestFetchChanInfos(t *testing.T) { } edge2 := newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), + chanID.ToUint64(), graph.db, updateTime.Unix(), ) edge2.ChannelFlags = 1 edge2.Node = node1 @@ -2084,14 +2072,14 @@ func TestIncompleteChannelPolicies(t *testing.T) { graph := db.ChannelGraph() // Create two nodes. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2099,13 +2087,6 @@ func TestIncompleteChannelPolicies(t *testing.T) { t.Fatalf("unable to add node: %v", err) } - // Create channel between nodes. - txHash := sha256.Sum256([]byte{0}) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - channel, chanID := createEdge( uint32(0), 0, 0, 0, node1, node2, ) @@ -2156,7 +2137,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { updateTime := time.Unix(1234, 0) edgePolicy := newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), + chanID.ToUint64(), graph.db, updateTime.Unix(), ) edgePolicy.ChannelFlags = 0 edgePolicy.Node = node2 @@ -2171,7 +2152,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { // Create second policy and assert that both policies are reported // as present. edgePolicy = newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), + chanID.ToUint64(), graph.db, updateTime.Unix(), ) edgePolicy.ChannelFlags = 1 edgePolicy.Node = node1 @@ -2197,7 +2178,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { } graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) + sourceNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create source node: %v", err) } @@ -2207,14 +2188,14 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. 
- node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2229,7 +2210,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { t.Fatalf("unable to add edge: %v", err) } - edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) + edge1 := randEdgePolicy(chanID.ToUint64(), graph.db) edge1.ChannelFlags = 0 edge1.Node = node1 edge1.SigBytes = testSig.Serialize() @@ -2237,7 +2218,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { t.Fatalf("unable to update edge: %v", err) } - edge2 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) + edge2 := randEdgePolicy(chanID.ToUint64(), graph.db) edge2.ChannelFlags = 1 edge2.Node = node2 edge2.SigBytes = testSig.Serialize() @@ -2253,7 +2234,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { timestampSet[t] = struct{}{} } - err := kvdb.View(db, func(tx kvdb.RTx) error { + err := kvdb.View(graph.db, func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) if edges == nil { return ErrGraphNoEdgesFound @@ -2354,7 +2335,7 @@ func TestPruneGraphNodes(t *testing.T) { // We'll start off by inserting our source node, to ensure that it's // the only node left after we prune the graph. graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) + sourceNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create source node: %v", err) } @@ -2365,21 +2346,21 @@ func TestPruneGraphNodes(t *testing.T) { // With the source node inserted, we'll now add three nodes to the // channel graph, at the end of the scenario, only two of these nodes // should still be in the graph. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } - node3, err := createTestVertex(db) + node3, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2396,7 +2377,7 @@ func TestPruneGraphNodes(t *testing.T) { // We'll now insert an advertised edge, but it'll only be the edge that // points from the first to the second node. - edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) + edge1 := randEdgePolicy(chanID.ToUint64(), graph.db) edge1.ChannelFlags = 0 edge1.Node = node1 edge1.SigBytes = testSig.Serialize() @@ -2439,14 +2420,14 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { // To start, we'll create two nodes, and only add one of them to the // channel graph. 
- node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2493,7 +2474,7 @@ func TestNodePruningUpdateIndexDeletion(t *testing.T) { // We'll first populate our graph with a single node that will be // removed shortly. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2602,7 +2583,7 @@ func TestNodeIsPublic(t *testing.T) { // participant's graph. nodes := []*LightningNode{aliceNode, bobNode, carolNode} edges := []*ChannelEdgeInfo{&aliceBobEdge, &bobCarolEdge} - dbs := []*DB{aliceDB, bobDB, carolDB} + dbs := []kvdb.Backend{aliceGraph.db, bobGraph.db, carolGraph.db} graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph} for i, graph := range graphs { for _, node := range nodes { @@ -2711,7 +2692,7 @@ func TestDisabledChannelIDs(t *testing.T) { graph := db.ChannelGraph() // Create first node and add it to the graph. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2720,7 +2701,7 @@ func TestDisabledChannelIDs(t *testing.T) { } // Create second node and add it to the graph. - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -2729,7 +2710,7 @@ func TestDisabledChannelIDs(t *testing.T) { } // Adding a new channel edge to the graph. - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -2810,19 +2791,19 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } @@ -2862,7 +2843,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { // Attempting to deserialize these bytes should return an error. r := bytes.NewReader(stripped) - err = kvdb.View(db, func(tx kvdb.RTx) error { + err = kvdb.View(graph.db, func(tx kvdb.RTx) error { nodes := tx.ReadBucket(nodeBucket) if nodes == nil { return ErrGraphNotFound @@ -2882,7 +2863,7 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { } // Put the stripped bytes in the DB. 
- err = kvdb.Update(db, func(tx kvdb.RwTx) error { + err = kvdb.Update(graph.db, func(tx kvdb.RwTx) error { edges := tx.ReadWriteBucket(edgeBucket) if edges == nil { return ErrEdgeNotFound @@ -2987,11 +2968,11 @@ func TestGraphZombieIndex(t *testing.T) { } graph := db.ChannelGraph() - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test vertex: %v", err) } - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test vertex: %v", err) } @@ -3002,7 +2983,7 @@ func TestGraphZombieIndex(t *testing.T) { node1, node2 = node2, node1 } - edge, _, _ := createChannelEdge(db, node1, node2) + edge, _, _ := createChannelEdge(graph.db, node1, node2) if err := graph.AddChannelEdge(edge); err != nil { t.Fatalf("unable to create channel edge: %v", err) } @@ -3238,16 +3219,16 @@ func TestBatchedAddChannelEdge(t *testing.T) { defer cleanUp() graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) + sourceNode, err := createTestVertex(graph.db) require.Nil(t, err) err = graph.SetSourceNode(sourceNode) require.Nil(t, err) // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) require.Nil(t, err) - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) require.Nil(t, err) // In addition to the fake vertexes we create some fake channel @@ -3324,17 +3305,17 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. - node1, err := createTestVertex(db) + node1, err := createTestVertex(graph.db) require.Nil(t, err) err = graph.AddLightningNode(node1) require.Nil(t, err) - node2, err := createTestVertex(db) + node2, err := createTestVertex(graph.db) require.Nil(t, err) err = graph.AddLightningNode(node2) require.Nil(t, err) // Create an edge and add it to the db. - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) // Make sure inserting the policy at this point, before the edge info // is added, will fail. From 292b8e1ce62b10bb0fdae70445159cacc278a2c8 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 21 Sep 2021 19:18:12 +0200 Subject: [PATCH 02/15] channeldb: fix dangerous type casting hack --- channeldb/channel.go | 41 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index 31873ae3fa..a4a47e5874 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -875,12 +875,43 @@ func fetchChanBucket(tx kvdb.RTx, nodeKey *btcec.PublicKey, func fetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey, // nolint:interfacer outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RwBucket, error) { - readBucket, err := fetchChanBucket(tx, nodeKey, outPoint, chainHash) - if err != nil { + // First fetch the top level bucket which stores all data related to + // current, active channels. 
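The "dangerous type casting hack" this commit removes was fetchChanBucketRw delegating to fetchChanBucket and then asserting the returned read-only bucket to kvdb.RwBucket. A self-contained toy sketch (stand-in types, not the real kvdb interfaces) of why that shape is fragile — the assertion type-checks but blows up at run time whenever the concrete bucket only implements the read side:

```go
package main

// Toy stand-ins for read-only and read-write bucket interfaces.
type RBucket interface {
	Get(key []byte) []byte
}

type RwBucket interface {
	RBucket
	Put(key, value []byte) error
}

// A concrete bucket type that only supports reads.
type readOnlyBucket struct{}

func (readOnlyBucket) Get(key []byte) []byte { return nil }

func main() {
	var b RBucket = readOnlyBucket{}

	// This compiles, but panics at run time: readOnlyBucket does not
	// implement RwBucket. Re-walking the buckets with the read-write
	// accessors, as the replacement code below does, avoids the
	// assertion entirely.
	_ = b.(RwBucket)
}
```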
+ openChanBucket := tx.ReadWriteBucket(openChannelBucket) + if openChanBucket == nil { + return nil, ErrNoChanDBExists + } + + // TODO(roasbeef): CreateTopLevelBucket on the interface isn't like + // CreateIfNotExists, will return error + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + nodePub := nodeKey.SerializeCompressed() + nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub) + if nodeChanBucket == nil { + return nil, ErrNoActiveChannels + } + + // We'll then recurse down an additional layer in order to fetch the + // bucket for this particular chain. + chainBucket := nodeChanBucket.NestedReadWriteBucket(chainHash[:]) + if chainBucket == nil { + return nil, ErrNoActiveChannels + } + + // With the bucket for the node and chain fetched, we can now go down + // another level, for this channel itself. + var chanPointBuf bytes.Buffer + if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { return nil, err } + chanBucket := chainBucket.NestedReadWriteBucket(chanPointBuf.Bytes()) + if chanBucket == nil { + return nil, ErrChannelNotFound + } - return readBucket.(kvdb.RwBucket), nil + return chanBucket, nil } // fullSync syncs the contents of an OpenChannel while re-using an existing @@ -965,7 +996,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { defer c.Unlock() if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { - chanBucket, err := fetchChanBucket( + chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) if err != nil { @@ -980,7 +1011,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { channel.IsPending = false channel.ShortChannelID = openLoc - return putOpenChannel(chanBucket.(kvdb.RwBucket), channel) + return putOpenChannel(chanBucket, channel) }, func() {}); err != nil { return err } From 60cccf840970c00fdc81a3417117a59423c6716f Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 21 Sep 2021 19:18:13 +0200 Subject: [PATCH 03/15] multi: carve out LinkNodeDB from channeldb.DB for cleaner separation --- channeldb/channel.go | 5 +- channeldb/db.go | 111 +++++++++++++++++++++++++++------------- channeldb/db_test.go | 8 +-- channeldb/nodes.go | 38 ++++++++------ channeldb/nodes_test.go | 16 +++--- server.go | 2 +- 6 files changed, 115 insertions(+), 65 deletions(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index a4a47e5874..f6b5fab4d0 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -1501,7 +1501,10 @@ func syncNewChannel(tx kvdb.RwTx, c *OpenChannel, addrs []net.Addr) error { // Next, we need to establish a (possibly) new LinkNode relationship // for this channel. The LinkNode metadata contains reachability, // up-time, and service bits related information. - linkNode := c.Db.NewLinkNode(wire.MainNet, c.IdentityPub, addrs...) + linkNode := NewLinkNode( + &LinkNodeDB{backend: c.Db.Backend}, + wire.MainNet, c.IdentityPub, addrs..., + ) // TODO(roasbeef): do away with link node all together? diff --git a/channeldb/db.go b/channeldb/db.go index 2a33450cd2..a9aa0d124f 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -217,6 +217,9 @@ var ( type DB struct { kvdb.Backend + // linkNodeDB separates all DB operations on LinkNodes. 
+ linkNodeDB *LinkNodeDB + dbPath string graph *ChannelGraph clock clock.Clock @@ -265,9 +268,13 @@ func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB, chanDB := &DB{ Backend: backend, - clock: opts.clock, - dryRun: opts.dryRun, + linkNodeDB: &LinkNodeDB{ + backend: backend, + }, + clock: opts.clock, + dryRun: opts.dryRun, } + chanDB.graph = newChannelGraph( backend, opts.RejectCacheSize, opts.ChannelCacheSize, opts.BatchCommitInterval, @@ -915,7 +922,11 @@ func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( // the pending funds in a channel that has been forcibly closed have been // swept. func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { - return kvdb.Update(d, func(tx kvdb.RwTx) error { + var ( + openChannels []*OpenChannel + pruneLinkNode *btcec.PublicKey + ) + err := kvdb.Update(d, func(tx kvdb.RwTx) error { var b bytes.Buffer if err := writeOutpoint(&b, chanPoint); err != nil { return err @@ -961,19 +972,33 @@ func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { // other open channels with this peer. If we don't we'll // garbage collect it to ensure we don't establish persistent // connections to peers without open channels. - return d.pruneLinkNode(tx, chanSummary.RemotePub) - }, func() {}) + pruneLinkNode = chanSummary.RemotePub + openChannels, err = d.fetchOpenChannels(tx, pruneLinkNode) + if err != nil { + return fmt.Errorf("unable to fetch open channels for "+ + "peer %x: %v", + pruneLinkNode.SerializeCompressed(), err) + } + + return nil + }, func() { + openChannels = nil + pruneLinkNode = nil + }) + if err != nil { + return err + } + + // Decide whether we want to remove the link node, based upon the number + // of still open channels. + return d.pruneLinkNode(openChannels, pruneLinkNode) } // pruneLinkNode determines whether we should garbage collect a link node from // the database due to no longer having any open channels with it. If there are // any left, then this acts as a no-op. -func (d *DB) pruneLinkNode(tx kvdb.RwTx, remotePub *btcec.PublicKey) error { - openChannels, err := d.fetchOpenChannels(tx, remotePub) - if err != nil { - return fmt.Errorf("unable to fetch open channels for peer %x: "+ - "%v", remotePub.SerializeCompressed(), err) - } +func (d *DB) pruneLinkNode(openChannels []*OpenChannel, + remotePub *btcec.PublicKey) error { if len(openChannels) > 0 { return nil @@ -982,27 +1007,42 @@ func (d *DB) pruneLinkNode(tx kvdb.RwTx, remotePub *btcec.PublicKey) error { log.Infof("Pruning link node %x with zero open channels from database", remotePub.SerializeCompressed()) - return d.deleteLinkNode(tx, remotePub) + return d.linkNodeDB.DeleteLinkNode(remotePub) } // PruneLinkNodes attempts to prune all link nodes found within the databse with // whom we no longer have any open channels with. 
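The reshaped MarkChanFullyClosed above is a good template for the pattern this commit applies in several places: capture what you need inside the kvdb.Update closure, clear it in the reset callback (the closure may be retried), and only call into other subsystems after the transaction commits. Condensed to its shape, with the surrounding channeldb types assumed:

```go
var (
	openChannels []*OpenChannel
	remotePub    *btcec.PublicKey
)
err := kvdb.Update(d, func(tx kvdb.RwTx) error {
	// ... mutate the closed-channel state, then record whatever the
	// post-commit step needs in the captured variables ...
	return nil
}, func() {
	// Reset callback: runs before any retry of the closure, so the
	// captured state must be cleared to avoid stale values.
	openChannels = nil
	remotePub = nil
})
if err != nil {
	return err
}

// Only now, outside the transaction, decide whether the link node can
// be garbage collected.
return d.pruneLinkNode(openChannels, remotePub)
```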
func (d *DB) PruneLinkNodes() error { - return kvdb.Update(d, func(tx kvdb.RwTx) error { - linkNodes, err := d.fetchAllLinkNodes(tx) + allLinkNodes, err := d.linkNodeDB.FetchAllLinkNodes() + if err != nil { + return err + } + + for _, linkNode := range allLinkNodes { + var ( + openChannels []*OpenChannel + linkNode = linkNode + ) + err := kvdb.View(d, func(tx kvdb.RTx) error { + var err error + openChannels, err = d.fetchOpenChannels( + tx, linkNode.IdentityPub, + ) + return err + }, func() { + openChannels = nil + }) if err != nil { return err } - for _, linkNode := range linkNodes { - err := d.pruneLinkNode(tx, linkNode.IdentityPub) - if err != nil { - return err - } + err = d.pruneLinkNode(openChannels, linkNode.IdentityPub) + if err != nil { + return err } + } - return nil - }, func() {}) + return nil } // ChannelShell is a shell of a channel that is meant to be used for channel @@ -1060,19 +1100,13 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { // AddrsForNode consults the graph and channel database for all addresses known // to the passed node public key. func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { - var ( - linkNode *LinkNode - graphNode LightningNode - ) - - dbErr := kvdb.View(d, func(tx kvdb.RTx) error { - var err error - - linkNode, err = fetchLinkNode(tx, nodePub) - if err != nil { - return err - } + linkNode, err := d.linkNodeDB.FetchLinkNode(nodePub) + if err != nil { + return nil, err + } + var graphNode LightningNode + err = kvdb.View(d, func(tx kvdb.RTx) error { // We'll also query the graph for this peer to see if they have // any addresses that we don't currently have stored within the // link node database. @@ -1092,8 +1126,8 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { }, func() { linkNode = nil }) - if dbErr != nil { - return nil, dbErr + if err != nil { + return nil, err } // Now that we have both sources of addrs for this node, we'll use a @@ -1236,11 +1270,16 @@ func (d *DB) syncVersions(versions []version) error { }, func() {}) } -// ChannelGraph returns a new instance of the directed channel graph. +// ChannelGraph returns the current instance of the directed channel graph. func (d *DB) ChannelGraph() *ChannelGraph { return d.graph } +// LinkNodeDB returns the current instance of the link node database. +func (d *DB) LinkNodeDB() *LinkNodeDB { + return d.linkNodeDB +} + func getLatestDBVersion(versions []version) uint32 { return versions[len(versions)-1].number } diff --git a/channeldb/db_test.go b/channeldb/db_test.go index 04c0dcd414..ef471c84c5 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -210,8 +210,8 @@ func TestAddrsForNode(t *testing.T) { if err != nil { t.Fatalf("unable to recv node pub: %v", err) } - linkNode := cdb.NewLinkNode( - wire.MainNet, nodePub, anotherAddr, + linkNode := NewLinkNode( + cdb.linkNodeDB, wire.MainNet, nodePub, anotherAddr, ) if err := linkNode.Sync(); err != nil { t.Fatalf("unable to sync link node: %v", err) @@ -423,7 +423,9 @@ func TestRestoreChannelShells(t *testing.T) { // We should also be able to find the link node that was inserted by // its public key. 
- linkNode, err := cdb.FetchLinkNode(channelShell.Chan.IdentityPub) + linkNode, err := cdb.linkNodeDB.FetchLinkNode( + channelShell.Chan.IdentityPub, + ) if err != nil { t.Fatalf("unable to fetch link node: %v", err) } diff --git a/channeldb/nodes.go b/channeldb/nodes.go index 88d98d6ae0..ffc7414c50 100644 --- a/channeldb/nodes.go +++ b/channeldb/nodes.go @@ -56,12 +56,14 @@ type LinkNode struct { // authenticated connection for the stored identity public key. Addresses []net.Addr - db *DB + // db is the database instance this node was fetched from. This is used + // to sync back the node's state if it is updated. + db *LinkNodeDB } // NewLinkNode creates a new LinkNode from the provided parameters, which is -// backed by an instance of channeldb. -func (d *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey, +// backed by an instance of a link node DB. +func NewLinkNode(db *LinkNodeDB, bitNet wire.BitcoinNet, pub *btcec.PublicKey, addrs ...net.Addr) *LinkNode { return &LinkNode{ @@ -69,7 +71,7 @@ func (d *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey, IdentityPub: pub, LastSeen: time.Now(), Addresses: addrs, - db: d, + db: db, } } @@ -98,10 +100,9 @@ func (l *LinkNode) AddAddress(addr net.Addr) error { // Sync performs a full database sync which writes the current up-to-date data // within the struct to the database. func (l *LinkNode) Sync() error { - // Finally update the database by storing the link node and updating // any relevant indexes. - return kvdb.Update(l.db, func(tx kvdb.RwTx) error { + return kvdb.Update(l.db.backend, func(tx kvdb.RwTx) error { nodeMetaBucket := tx.ReadWriteBucket(nodeInfoBucket) if nodeMetaBucket == nil { return ErrLinkNodesNotFound @@ -127,15 +128,20 @@ func putLinkNode(nodeMetaBucket kvdb.RwBucket, l *LinkNode) error { return nodeMetaBucket.Put(nodePub, b.Bytes()) } +// LinkNodeDB is a database that keeps track of all link nodes. +type LinkNodeDB struct { + backend kvdb.Backend +} + // DeleteLinkNode removes the link node with the given identity from the // database. -func (d *DB) DeleteLinkNode(identity *btcec.PublicKey) error { - return kvdb.Update(d, func(tx kvdb.RwTx) error { - return d.deleteLinkNode(tx, identity) +func (l *LinkNodeDB) DeleteLinkNode(identity *btcec.PublicKey) error { + return kvdb.Update(l.backend, func(tx kvdb.RwTx) error { + return deleteLinkNode(tx, identity) }, func() {}) } -func (d *DB) deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error { +func deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error { nodeMetaBucket := tx.ReadWriteBucket(nodeInfoBucket) if nodeMetaBucket == nil { return ErrLinkNodesNotFound @@ -148,9 +154,9 @@ func (d *DB) deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error { // FetchLinkNode attempts to lookup the data for a LinkNode based on a target // identity public key. If a particular LinkNode for the passed identity public // key cannot be found, then ErrNodeNotFound if returned. -func (d *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { +func (l *LinkNodeDB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { var linkNode *LinkNode - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(l.backend, func(tx kvdb.RTx) error { node, err := fetchLinkNode(tx, identity) if err != nil { return err @@ -191,10 +197,10 @@ func fetchLinkNode(tx kvdb.RTx, targetPub *btcec.PublicKey) (*LinkNode, error) { // FetchAllLinkNodes starts a new database transaction to fetch all nodes with // whom we have active channels with. 
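Seen from a caller, the carved-out store reads naturally. A hedged usage sketch built only from the signatures in this commit (pubKey and addr are assumed to be in scope, and error handling is trimmed):

```go
// Obtain the link node store from the main DB handle.
linkNodeDB := cdb.LinkNodeDB()

// Creating and persisting a node now goes through the free function
// plus the store, rather than methods hanging off *channeldb.DB.
node := channeldb.NewLinkNode(linkNodeDB, wire.MainNet, pubKey, addr)
if err := node.Sync(); err != nil {
	return err
}

// Lookups and deletion live on the store as well.
if _, err := linkNodeDB.FetchLinkNode(pubKey); err != nil {
	return err
}
if err := linkNodeDB.DeleteLinkNode(pubKey); err != nil {
	return err
}
```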
-func (d *DB) FetchAllLinkNodes() ([]*LinkNode, error) { +func (l *LinkNodeDB) FetchAllLinkNodes() ([]*LinkNode, error) { var linkNodes []*LinkNode - err := kvdb.View(d, func(tx kvdb.RTx) error { - nodes, err := d.fetchAllLinkNodes(tx) + err := kvdb.View(l.backend, func(tx kvdb.RTx) error { + nodes, err := fetchAllLinkNodes(tx) if err != nil { return err } @@ -213,7 +219,7 @@ func (d *DB) FetchAllLinkNodes() ([]*LinkNode, error) { // fetchAllLinkNodes uses an existing database transaction to fetch all nodes // with whom we have active channels with. -func (d *DB) fetchAllLinkNodes(tx kvdb.RTx) ([]*LinkNode, error) { +func fetchAllLinkNodes(tx kvdb.RTx) ([]*LinkNode, error) { nodeMetaBucket := tx.ReadBucket(nodeInfoBucket) if nodeMetaBucket == nil { return nil, ErrLinkNodesNotFound diff --git a/channeldb/nodes_test.go b/channeldb/nodes_test.go index 0d649d4315..7e9231fc5f 100644 --- a/channeldb/nodes_test.go +++ b/channeldb/nodes_test.go @@ -34,8 +34,8 @@ func TestLinkNodeEncodeDecode(t *testing.T) { // Create two fresh link node instances with the above dummy data, then // fully sync both instances to disk. - node1 := cdb.NewLinkNode(wire.MainNet, pub1, addr1) - node2 := cdb.NewLinkNode(wire.TestNet3, pub2, addr2) + node1 := NewLinkNode(cdb.linkNodeDB, wire.MainNet, pub1, addr1) + node2 := NewLinkNode(cdb.linkNodeDB, wire.TestNet3, pub2, addr2) if err := node1.Sync(); err != nil { t.Fatalf("unable to sync node: %v", err) } @@ -46,7 +46,7 @@ func TestLinkNodeEncodeDecode(t *testing.T) { // Fetch all current link nodes from the database, they should exactly // match the two created above. originalNodes := []*LinkNode{node2, node1} - linkNodes, err := cdb.FetchAllLinkNodes() + linkNodes, err := cdb.linkNodeDB.FetchAllLinkNodes() if err != nil { t.Fatalf("unable to fetch nodes: %v", err) } @@ -82,7 +82,7 @@ func TestLinkNodeEncodeDecode(t *testing.T) { } // Fetch the same node from the database according to its public key. - node1DB, err := cdb.FetchLinkNode(pub1) + node1DB, err := cdb.linkNodeDB.FetchLinkNode(pub1) if err != nil { t.Fatalf("unable to find node: %v", err) } @@ -121,20 +121,20 @@ func TestDeleteLinkNode(t *testing.T) { IP: net.ParseIP("127.0.0.1"), Port: 1337, } - linkNode := cdb.NewLinkNode(wire.TestNet3, pubKey, addr) + linkNode := NewLinkNode(cdb.linkNodeDB, wire.TestNet3, pubKey, addr) if err := linkNode.Sync(); err != nil { t.Fatalf("unable to write link node to db: %v", err) } - if _, err := cdb.FetchLinkNode(pubKey); err != nil { + if _, err := cdb.linkNodeDB.FetchLinkNode(pubKey); err != nil { t.Fatalf("unable to find link node: %v", err) } - if err := cdb.DeleteLinkNode(pubKey); err != nil { + if err := cdb.linkNodeDB.DeleteLinkNode(pubKey); err != nil { t.Fatalf("unable to delete link node from db: %v", err) } - if _, err := cdb.FetchLinkNode(pubKey); err == nil { + if _, err := cdb.linkNodeDB.FetchLinkNode(pubKey); err == nil { t.Fatal("should not have found link node in db, but did") } } diff --git a/server.go b/server.go index 87cec505e0..62a0aa62d4 100644 --- a/server.go +++ b/server.go @@ -2527,7 +2527,7 @@ func (s *server) establishPersistentConnections() error { // Iterate through the list of LinkNodes to find addresses we should // attempt to connect to based on our set of previous connections. Set // the reconnection port to the default peer port. 
- linkNodes, err := s.chanStateDB.FetchAllLinkNodes() + linkNodes, err := s.chanStateDB.LinkNodeDB().FetchAllLinkNodes() if err != nil && err != channeldb.ErrLinkNodesNotFound { return err } From c1f686f8609a0bf5d4acf3528b51a31ceab8dbd8 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Tue, 21 Sep 2021 19:18:14 +0200 Subject: [PATCH 04/15] channeldb+funding: move opening channel state to DB The funding manager doesn't need to know the details of the underlying storage of the opening channel state, so we move the actual store and retrieval into the channel database. --- channeldb/db.go | 55 +++++++++++++++++++++++ funding/manager.go | 98 ++++++++++++----------------------------- funding/manager_test.go | 6 +-- 3 files changed, 87 insertions(+), 72 deletions(-) diff --git a/channeldb/db.go b/channeldb/db.go index a9aa0d124f..8b373d8d7b 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -209,6 +209,11 @@ var ( // Big endian is the preferred byte order, due to cursor scans over // integer keys iterating in order. byteOrder = binary.BigEndian + + // channelOpeningStateBucket is the database bucket used to store the + // channelOpeningState for each channel that is currently in the process + // of being opened. + channelOpeningStateBucket = []byte("channelOpeningState") ) // DB is the primary datastore for the lnd daemon. The database stores @@ -1197,6 +1202,56 @@ func (d *DB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error { return dbChan.CloseChannel(summary, ChanStatusLocalCloseInitiator) } +// SaveChannelOpeningState saves the serialized channel state for the provided +// chanPoint to the channelOpeningStateBucket. +func (d *DB) SaveChannelOpeningState(outPoint, serializedState []byte) error { + return kvdb.Update(d, func(tx kvdb.RwTx) error { + bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket) + if err != nil { + return err + } + + return bucket.Put(outPoint, serializedState) + }, func() {}) +} + +// GetChannelOpeningState fetches the serialized channel state for the provided +// outPoint from the database, or returns ErrChannelNotFound if the channel +// is not found. +func (d *DB) GetChannelOpeningState(outPoint []byte) ([]byte, error) { + var serializedState []byte + err := kvdb.View(d, func(tx kvdb.RTx) error { + bucket := tx.ReadBucket(channelOpeningStateBucket) + if bucket == nil { + // If the bucket does not exist, it means we never added + // a channel to the db, so return ErrChannelNotFound. + return ErrChannelNotFound + } + + serializedState = bucket.Get(outPoint) + if serializedState == nil { + return ErrChannelNotFound + } + + return nil + }, func() { + serializedState = nil + }) + return serializedState, err +} + +// DeleteChannelOpeningState removes any state for outPoint from the database. +func (d *DB) DeleteChannelOpeningState(outPoint []byte) error { + return kvdb.Update(d, func(tx kvdb.RwTx) error { + bucket := tx.ReadWriteBucket(channelOpeningStateBucket) + if bucket == nil { + return ErrChannelNotFound + } + + return bucket.Delete(outPoint) + }, func() {}) +} + // syncVersions function is used for safe db version synchronization. It // applies migration functions to the current database and recovers the // previous state of db if at least one error/panic appeared during migration. 
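Both sides of this move have to agree on the stored value format, so it is worth spelling out: SaveChannelOpeningState persists an opaque byte slice, and the funding manager (next file) packs it as a fixed 10-byte record — a big-endian uint16 opening state followed by the uint64 short channel ID, using the same byteOrder (binary.BigEndian) defined in db.go. In sketch form:

```go
// Encode: 2 bytes of state, then 8 bytes of short channel ID.
scratch := make([]byte, 10)
byteOrder.PutUint16(scratch[:2], uint16(state))
byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64())

// Decode, mirroring what getChannelOpeningState does on the way back out.
gotState := channelOpeningState(byteOrder.Uint16(scratch[:2]))
gotShortChanID := lnwire.NewShortChanIDFromInt(byteOrder.Uint64(scratch[2:]))
```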
diff --git a/funding/manager.go b/funding/manager.go index f600390440..3daa393f1d 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -23,7 +23,6 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnrpc" @@ -550,19 +549,6 @@ const ( addedToRouterGraph ) -var ( - // channelOpeningStateBucket is the database bucket used to store the - // channelOpeningState for each channel that is currently in the process - // of being opened. - channelOpeningStateBucket = []byte("channelOpeningState") - - // ErrChannelNotFound is an error returned when a channel is not known - // to us. In this case of the fundingManager, this error is returned - // when the channel in question is not considered being in an opening - // state. - ErrChannelNotFound = fmt.Errorf("channel not found") -) - // NewFundingManager creates and initializes a new instance of the // fundingManager. func NewFundingManager(cfg Config) (*Manager, error) { @@ -887,7 +873,7 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, channelState, shortChanID, err := f.getChannelOpeningState( &channel.FundingOutpoint, ) - if err == ErrChannelNotFound { + if err == channeldb.ErrChannelNotFound { // Channel not in fundingManager's opening database, // meaning it was successfully announced to the // network. @@ -3539,26 +3525,20 @@ func copyPubKey(pub *btcec.PublicKey) *btcec.PublicKey { // chanPoint to the channelOpeningStateBucket. func (f *Manager) saveChannelOpeningState(chanPoint *wire.OutPoint, state channelOpeningState, shortChanID *lnwire.ShortChannelID) error { - return kvdb.Update(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RwTx) error { - - bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket) - if err != nil { - return err - } - - var outpointBytes bytes.Buffer - if err = WriteOutpoint(&outpointBytes, chanPoint); err != nil { - return err - } - // Save state and the uint64 representation of the shortChanID - // for later use. - scratch := make([]byte, 10) - byteOrder.PutUint16(scratch[:2], uint16(state)) - byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64()) + var outpointBytes bytes.Buffer + if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { + return err + } - return bucket.Put(outpointBytes.Bytes(), scratch) - }, func() {}) + // Save state and the uint64 representation of the shortChanID + // for later use. + scratch := make([]byte, 10) + byteOrder.PutUint16(scratch[:2], uint16(state)) + byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64()) + return f.cfg.Wallet.Cfg.Database.SaveChannelOpeningState( + outpointBytes.Bytes(), scratch, + ) } // getChannelOpeningState fetches the channelOpeningState for the provided @@ -3567,51 +3547,31 @@ func (f *Manager) saveChannelOpeningState(chanPoint *wire.OutPoint, func (f *Manager) getChannelOpeningState(chanPoint *wire.OutPoint) ( channelOpeningState, *lnwire.ShortChannelID, error) { - var state channelOpeningState - var shortChanID lnwire.ShortChannelID - err := kvdb.View(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RTx) error { - - bucket := tx.ReadBucket(channelOpeningStateBucket) - if bucket == nil { - // If the bucket does not exist, it means we never added - // a channel to the db, so return ErrChannelNotFound. 
- return ErrChannelNotFound - } - - var outpointBytes bytes.Buffer - if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { - return err - } - - value := bucket.Get(outpointBytes.Bytes()) - if value == nil { - return ErrChannelNotFound - } + var outpointBytes bytes.Buffer + if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { + return 0, nil, err + } - state = channelOpeningState(byteOrder.Uint16(value[:2])) - shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:])) - return nil - }, func() {}) + value, err := f.cfg.Wallet.Cfg.Database.GetChannelOpeningState( + outpointBytes.Bytes(), + ) if err != nil { return 0, nil, err } + state := channelOpeningState(byteOrder.Uint16(value[:2])) + shortChanID := lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:])) return state, &shortChanID, nil } // deleteChannelOpeningState removes any state for chanPoint from the database. func (f *Manager) deleteChannelOpeningState(chanPoint *wire.OutPoint) error { - return kvdb.Update(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RwTx) error { - bucket := tx.ReadWriteBucket(channelOpeningStateBucket) - if bucket == nil { - return fmt.Errorf("bucket not found") - } - - var outpointBytes bytes.Buffer - if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { - return err - } + var outpointBytes bytes.Buffer + if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { + return err + } - return bucket.Delete(outpointBytes.Bytes()) - }, func() {}) + return f.cfg.Wallet.Cfg.Database.DeleteChannelOpeningState( + outpointBytes.Bytes(), + ) } diff --git a/funding/manager_test.go b/funding/manager_test.go index 97ee699f27..acd7ca5147 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -922,12 +922,12 @@ func assertDatabaseState(t *testing.T, node *testNode, } state, _, err = node.fundingMgr.getChannelOpeningState( fundingOutPoint) - if err != nil && err != ErrChannelNotFound { + if err != nil && err != channeldb.ErrChannelNotFound { t.Fatalf("unable to get channel state: %v", err) } // If we found the channel, check if it had the expected state. - if err != ErrChannelNotFound && state == expectedState { + if err != channeldb.ErrChannelNotFound && state == expectedState { // Got expected state, return with success. return } @@ -1165,7 +1165,7 @@ func assertErrChannelNotFound(t *testing.T, node *testNode, } state, _, err = node.fundingMgr.getChannelOpeningState( fundingOutPoint) - if err == ErrChannelNotFound { + if err == channeldb.ErrChannelNotFound { // Got expected state, return with success. return } else if err != nil { From ddea833d3109bced93f9ac366d25835aa42b7ab7 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Tue, 21 Sep 2021 19:18:16 +0200 Subject: [PATCH 05/15] multi: extract address source into interface As a preparation to have the method for querying the addresses of a node separate from the channel state, we extract that method out into its own interface. --- chanbackup/backup.go | 22 ++++++++++++++-------- chanbackup/backup_test.go | 10 ++++++---- rpcserver.go | 6 +++--- server.go | 7 ++++++- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/chanbackup/backup.go b/chanbackup/backup.go index 076e2fcf6d..dce1210b26 100644 --- a/chanbackup/backup.go +++ b/chanbackup/backup.go @@ -21,7 +21,11 @@ type LiveChannelSource interface { // passed chanPoint. Optionally an existing db tx can be supplied. 
FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( *channeldb.OpenChannel, error) +} +// AddressSource is an interface that allows us to query for the set of +// addresses a node can be connected to. +type AddressSource interface { // AddrsForNode returns all known addresses for the target node public // key. AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) @@ -31,15 +35,15 @@ type LiveChannelSource interface { // passed open channel. The backup includes all information required to restore // the channel, as well as addressing information so we can find the peer and // reconnect to them to initiate the protocol. -func assembleChanBackup(chanSource LiveChannelSource, +func assembleChanBackup(addrSource AddressSource, openChan *channeldb.OpenChannel) (*Single, error) { log.Debugf("Crafting backup for ChannelPoint(%v)", openChan.FundingOutpoint) // First, we'll query the channel source to obtain all the addresses - // that are are associated with the peer for this channel. - nodeAddrs, err := chanSource.AddrsForNode(openChan.IdentityPub) + // that are associated with the peer for this channel. + nodeAddrs, err := addrSource.AddrsForNode(openChan.IdentityPub) if err != nil { return nil, err } @@ -52,8 +56,8 @@ func assembleChanBackup(chanSource LiveChannelSource, // FetchBackupForChan attempts to create a plaintext static channel backup for // the target channel identified by its channel point. If we're unable to find // the target channel, then an error will be returned. -func FetchBackupForChan(chanPoint wire.OutPoint, - chanSource LiveChannelSource) (*Single, error) { +func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource, + addrSource AddressSource) (*Single, error) { // First, we'll query the channel source to see if the channel is known // and open within the database. @@ -66,7 +70,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint, // Once we have the target channel, we can assemble the backup using // the source to obtain any extra information that we may need. - staticChanBackup, err := assembleChanBackup(chanSource, targetChan) + staticChanBackup, err := assembleChanBackup(addrSource, targetChan) if err != nil { return nil, fmt.Errorf("unable to create chan backup: %v", err) } @@ -76,7 +80,9 @@ func FetchBackupForChan(chanPoint wire.OutPoint, // FetchStaticChanBackups will return a plaintext static channel back up for // all known active/open channels within the passed channel source. -func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) { +func FetchStaticChanBackups(chanSource LiveChannelSource, + addrSource AddressSource) ([]Single, error) { + // First, we'll query the backup source for information concerning all // currently open and available channels. openChans, err := chanSource.FetchAllChannels() @@ -89,7 +95,7 @@ func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) { // channel. 
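A side benefit of splitting AddrsForNode into its own interface: the backup path becomes easy to exercise without a full channel DB, since anything that can map a node key to addresses now qualifies. A minimal, hypothetical test double (names invented here, not taken from the patch):

```go
// staticAddrSource satisfies chanbackup.AddressSource by returning the
// same fixed address list for every node.
type staticAddrSource struct {
	addrs []net.Addr
}

func (s *staticAddrSource) AddrsForNode(
	_ *btcec.PublicKey) ([]net.Addr, error) {

	return s.addrs, nil
}
```

Such a double can then be handed to FetchBackupForChan or FetchStaticChanBackups alongside the real LiveChannelSource.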
staticChanBackups := make([]Single, 0, len(openChans)) for _, openChan := range openChans { - chanBackup, err := assembleChanBackup(chanSource, openChan) + chanBackup, err := assembleChanBackup(addrSource, openChan) if err != nil { return nil, err } diff --git a/chanbackup/backup_test.go b/chanbackup/backup_test.go index e718dce3ea..ff321c1884 100644 --- a/chanbackup/backup_test.go +++ b/chanbackup/backup_test.go @@ -124,7 +124,9 @@ func TestFetchBackupForChan(t *testing.T) { }, } for i, testCase := range testCases { - _, err := FetchBackupForChan(testCase.chanPoint, chanSource) + _, err := FetchBackupForChan( + testCase.chanPoint, chanSource, chanSource, + ) switch { // If this is a valid test case, and we failed, then we'll // return an error. @@ -167,7 +169,7 @@ func TestFetchStaticChanBackups(t *testing.T) { // With the channel source populated, we'll now attempt to create a set // of backups for all the channels. This should succeed, as all items // are populated within the channel source. - backups, err := FetchStaticChanBackups(chanSource) + backups, err := FetchStaticChanBackups(chanSource, chanSource) if err != nil { t.Fatalf("unable to create chan back ups: %v", err) } @@ -184,7 +186,7 @@ func TestFetchStaticChanBackups(t *testing.T) { copy(n[:], randomChan2.IdentityPub.SerializeCompressed()) delete(chanSource.addrs, n) - _, err = FetchStaticChanBackups(chanSource) + _, err = FetchStaticChanBackups(chanSource, chanSource) if err == nil { t.Fatalf("query with incomplete information should fail") } @@ -193,7 +195,7 @@ func TestFetchStaticChanBackups(t *testing.T) { // source at all, then we'll fail as well. chanSource = newMockChannelSource() chanSource.failQuery = true - _, err = FetchStaticChanBackups(chanSource) + _, err = FetchStaticChanBackups(chanSource, chanSource) if err == nil { t.Fatalf("query should fail") } diff --git a/rpcserver.go b/rpcserver.go index 0e27704216..79f6553750 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6469,7 +6469,7 @@ func (r *rpcServer) ExportChannelBackup(ctx context.Context, // the database. If this channel has been closed, or the outpoint is // unknown, then we'll return an error unpackedBackup, err := chanbackup.FetchBackupForChan( - chanPoint, r.server.chanStateDB, + chanPoint, r.server.chanStateDB, r.server.addrSource, ) if err != nil { return nil, err @@ -6639,7 +6639,7 @@ func (r *rpcServer) ExportAllChannelBackups(ctx context.Context, // First, we'll attempt to read back ups for ALL currently opened // channels from disk. allUnpackedBackups, err := chanbackup.FetchStaticChanBackups( - r.server.chanStateDB, + r.server.chanStateDB, r.server.addrSource, ) if err != nil { return nil, fmt.Errorf("unable to fetch all static chan "+ @@ -6766,7 +6766,7 @@ func (r *rpcServer) SubscribeChannelBackups(req *lnrpc.ChannelBackupSubscription // we'll obtains the current set of single channel // backups from disk. 
chanBackups, err := chanbackup.FetchStaticChanBackups( - r.server.chanStateDB, + r.server.chanStateDB, r.server.addrSource, ) if err != nil { return fmt.Errorf("unable to fetch all "+ diff --git a/server.go b/server.go index 62a0aa62d4..6b8830f03c 100644 --- a/server.go +++ b/server.go @@ -224,6 +224,8 @@ type server struct { chanStateDB *channeldb.DB + addrSource chanbackup.AddressSource + htlcSwitch *htlcswitch.Switch interceptableSwitch *htlcswitch.InterceptableSwitch @@ -433,6 +435,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, cfg: cfg, graphDB: dbs.graphDB.ChannelGraph(), chanStateDB: dbs.chanStateDB, + addrSource: dbs.chanStateDB, cc: cc, sigPool: lnwallet.NewSigPool(cfg.Workers.Sig, cc.Signer), writePool: writePool, @@ -1246,7 +1249,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, addrs: s.chanStateDB, } backupFile := chanbackup.NewMultiFile(cfg.BackupFilePath) - startingChans, err := chanbackup.FetchStaticChanBackups(s.chanStateDB) + startingChans, err := chanbackup.FetchStaticChanBackups( + s.chanStateDB, s.addrSource, + ) if err != nil { return nil, err } From 11cf4216e441a323d98ffa279ebb23653f2e7ae3 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 21 Sep 2021 19:18:17 +0200 Subject: [PATCH 06/15] multi: move all channelstate operations to ChannelStateDB --- chainreg/chainregistry.go | 2 +- channeldb/channel.go | 54 +++---- channeldb/channel_test.go | 54 ++++--- channeldb/db.go | 186 ++++++++++++++++--------- channeldb/db_test.go | 46 ++++-- channeldb/nodes_test.go | 8 +- channeldb/waitingproof.go | 4 +- channelnotifier/channelnotifier.go | 4 +- chanrestore.go | 2 +- contractcourt/breacharbiter.go | 6 +- contractcourt/breacharbiter_test.go | 12 +- contractcourt/chain_arbitrator.go | 16 ++- contractcourt/chain_arbitrator_test.go | 6 +- contractcourt/utils_test.go | 4 +- discovery/message_store.go | 11 +- funding/manager_test.go | 6 +- htlcswitch/circuit_map.go | 21 ++- htlcswitch/circuit_test.go | 15 +- htlcswitch/link_test.go | 4 +- htlcswitch/mock.go | 6 +- htlcswitch/payment_result.go | 14 +- htlcswitch/switch.go | 19 ++- htlcswitch/test_utils.go | 26 ++-- lnd.go | 2 +- lnrpc/invoicesrpc/addinvoice.go | 2 +- lnrpc/invoicesrpc/config_active.go | 2 +- lnwallet/config.go | 2 +- lnwallet/test/test_interface.go | 8 +- lnwallet/test_utils.go | 4 +- lnwallet/transactions_test.go | 4 +- peer/brontide.go | 2 +- peer/test_utils.go | 8 +- rpcserver.go | 16 +-- server.go | 37 +++-- subrpcserver_config.go | 2 +- 35 files changed, 377 insertions(+), 238 deletions(-) diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go index ad0b2f54b5..762ca7e2f4 100644 --- a/chainreg/chainregistry.go +++ b/chainreg/chainregistry.go @@ -75,7 +75,7 @@ type Config struct { // ChanStateDB is a pointer to the database that stores the channel // state. - ChanStateDB *channeldb.DB + ChanStateDB *channeldb.ChannelStateDB // BlockCacheSize is the size (in bytes) of blocks kept in memory. BlockCacheSize uint64 diff --git a/channeldb/channel.go b/channeldb/channel.go index f6b5fab4d0..858cead919 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -729,7 +729,7 @@ type OpenChannel struct { RevocationKeyLocator keychain.KeyLocator // TODO(roasbeef): eww - Db *DB + Db *ChannelStateDB // TODO(roasbeef): just need to store local and remote HTLC's? 
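
Stepping back to the wiring in server.go above: nothing new has to implement
AddressSource, because the full channeldb.DB already exposes AddrsForNode. A
compile-time assertion, shown here only as a sketch and not part of the patch,
makes that explicit:

var _ chanbackup.AddressSource = (*channeldb.DB)(nil)

This is why newServer can point addrSource at dbs.chanStateDB unchanged, while
this commit moves the channel-specific operations behind the new
ChannelStateDB wrapper.
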
@@ -800,7 +800,7 @@ func (c *OpenChannel) RefreshShortChanID() error { c.Lock() defer c.Unlock() - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -995,7 +995,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { c.Lock() defer c.Unlock() - if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1047,7 +1047,7 @@ func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error { func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { var commitPoint *btcec.PublicKey - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1271,7 +1271,7 @@ func (c *OpenChannel) BroadcastedCooperative() (*wire.MsgTx, error) { func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) { var closeTx *wire.MsgTx - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1305,7 +1305,7 @@ func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) { func (c *OpenChannel) putChanStatus(status ChannelStatus, fs ...func(kvdb.RwBucket) error) error { - if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1349,7 +1349,7 @@ func (c *OpenChannel) putChanStatus(status ChannelStatus, } func (c *OpenChannel) clearChanStatus(status ChannelStatus) error { - if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -1473,7 +1473,7 @@ func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error { c.FundingBroadcastHeight = pendingHeight - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { return syncNewChannel(tx, c, []net.Addr{addr}) }, func() {}) } @@ -1502,7 +1502,7 @@ func syncNewChannel(tx kvdb.RwTx, c *OpenChannel, addrs []net.Addr) error { // for this channel. The LinkNode metadata contains reachability, // up-time, and service bits related information. linkNode := NewLinkNode( - &LinkNodeDB{backend: c.Db.Backend}, + &LinkNodeDB{backend: c.Db.backend}, wire.MainNet, c.IdentityPub, addrs..., ) @@ -1532,7 +1532,7 @@ func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment, return ErrNoRestoredChannelMutation } - err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2124,7 +2124,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { return ErrNoRestoredChannelMutation } - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { // First, we'll grab the writable bucket where this channel's // data resides. 
chanBucket, err := fetchChanBucketRw( @@ -2194,7 +2194,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { // these pointers, causing the tip and the tail to point to the same entry. func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { var cd *CommitDiff - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2233,7 +2233,7 @@ func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { // updates that still need to be signed for. func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) { var updates []LogUpdate - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2267,7 +2267,7 @@ func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) { // updates that the remote still needs to sign for. func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) { var updates []LogUpdate - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2311,7 +2311,7 @@ func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error { c.RemoteNextRevocation = revKey - err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2352,7 +2352,7 @@ func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg, var newRemoteCommit *ChannelCommitment - err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2527,7 +2527,7 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { defer c.RUnlock() var fwdPkgs []*FwdPkg - if err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { var err error fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) return err @@ -2547,7 +2547,7 @@ func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error { c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { return c.Packager.AckAddHtlcs(tx, addRefs...) }, func() {}) } @@ -2560,7 +2560,7 @@ func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error { c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { return c.Packager.AckSettleFails(tx, settleFailRefs...) 
}, func() {}) } @@ -2571,7 +2571,7 @@ func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { return c.Packager.SetFwdFilter(tx, height, fwdFilter) }, func() {}) } @@ -2585,7 +2585,7 @@ func (c *OpenChannel) RemoveFwdPkgs(heights ...uint64) error { c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { for _, height := range heights { err := c.Packager.RemovePkg(tx, height) if err != nil { @@ -2613,7 +2613,7 @@ func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) { } var commit ChannelCommitment - if err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2660,7 +2660,7 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) { defer c.RUnlock() var height uint64 - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { // Get the bucket dedicated to storing the metadata for open // channels. chanBucket, err := fetchChanBucket( @@ -2697,7 +2697,7 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e defer c.RUnlock() var commit ChannelCommitment - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -2855,7 +2855,7 @@ func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary, c.Lock() defer c.Unlock() - return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { openChanBucket := tx.ReadWriteBucket(openChannelBucket) if openChanBucket == nil { return ErrNoChanDBExists @@ -3067,7 +3067,7 @@ func (c *OpenChannel) Snapshot() *ChannelSnapshot { // latest fully committed state is returned. The first commitment returned is // the local commitment, and the second returned is the remote commitment. func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) { - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) @@ -3089,7 +3089,7 @@ func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitmen // acting on a possible contract breach to ensure, that the caller has the most // up to date information required to deliver justice. func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { - err := kvdb.View(c.Db, func(tx kvdb.RTx) error { + err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, ) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index ad1b3c07cc..044308f887 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -183,7 +183,7 @@ var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption { // createTestChannel writes a test channel to the database. It takes a set of // functional options which can be used to overwrite the default of creating // a pending channel that was broadcast at height 100. 
-func createTestChannel(t *testing.T, cdb *DB, +func createTestChannel(t *testing.T, cdb *ChannelStateDB, opts ...testChannelOption) *OpenChannel { // Create a default set of parameters. @@ -221,7 +221,7 @@ func createTestChannel(t *testing.T, cdb *DB, return params.channel } -func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel { +func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { // Simulate 1000 channel updates. producer, err := shachain.NewRevocationProducerFromBytes(key[:]) if err != nil { @@ -359,12 +359,14 @@ func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel { func TestOpenChannelPutGetDelete(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create the test channel state, with additional htlcs on the local // and remote commitment. localHtlcs := []HTLC{ @@ -508,12 +510,14 @@ func TestOptionalShutdown(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create a channel with upfront scripts set as // specified in the test. state := createTestChannel( @@ -565,12 +569,14 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { func TestChannelStateTransition(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // First create a minimal channel, then perform a full sync in order to // persist the data. channel := createTestChannel(t, cdb) @@ -842,7 +848,7 @@ func TestChannelStateTransition(t *testing.T) { } // At this point, we should have 2 forwarding packages added. - fwdPkgs := loadFwdPkgs(t, cdb, channel.Packager) + fwdPkgs := loadFwdPkgs(t, cdb.backend, channel.Packager) require.Len(t, fwdPkgs, 2, "wrong number of forwarding packages") // Now attempt to delete the channel from the database. @@ -877,19 +883,21 @@ func TestChannelStateTransition(t *testing.T) { } // All forwarding packages of this channel has been deleted too. - fwdPkgs = loadFwdPkgs(t, cdb, channel.Packager) + fwdPkgs = loadFwdPkgs(t, cdb.backend, channel.Packager) require.Empty(t, fwdPkgs, "no forwarding packages should exist") } func TestFetchPendingChannels(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create a pending channel that was broadcast at height 99. const broadcastHeight = 99 createTestChannel(t, cdb, pendingHeightOption(broadcastHeight)) @@ -963,12 +971,14 @@ func TestFetchPendingChannels(t *testing.T) { func TestFetchClosedChannels(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create an open channel in the database. state := createTestChannel(t, cdb, openChannelOption()) @@ -1054,18 +1064,20 @@ func TestFetchWaitingCloseChannels(t *testing.T) { // We'll start by creating two channels within our test database. 
One of // them will have their funding transaction confirmed on-chain, while // the other one will remain unconfirmed. - db, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + channels := make([]*OpenChannel, numChannels) for i := 0; i < numChannels; i++ { // Create a pending channel in the database at the broadcast // height. channels[i] = createTestChannel( - t, db, pendingHeightOption(broadcastHeight), + t, cdb, pendingHeightOption(broadcastHeight), ) } @@ -1116,7 +1128,7 @@ func TestFetchWaitingCloseChannels(t *testing.T) { // Now, we'll fetch all the channels waiting to be closed from the // database. We should expect to see both channels above, even if any of // them haven't had their funding transaction confirm on-chain. - waitingCloseChannels, err := db.FetchWaitingCloseChannels() + waitingCloseChannels, err := cdb.FetchWaitingCloseChannels() if err != nil { t.Fatalf("unable to fetch all waiting close channels: %v", err) } @@ -1169,12 +1181,14 @@ func TestFetchWaitingCloseChannels(t *testing.T) { func TestRefreshShortChanID(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // First create a test channel. state := createTestChannel(t, cdb) @@ -1317,13 +1331,15 @@ func TestCloseInitiator(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create an open channel. channel := createTestChannel( t, cdb, openChannelOption(), @@ -1362,13 +1378,15 @@ func TestCloseInitiator(t *testing.T) { // TestCloseChannelStatus tests setting of a channel status on the historical // channel on channel close. func TestCloseChannelStatus(t *testing.T) { - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create an open channel. channel := createTestChannel( t, cdb, openChannelOption(), @@ -1427,7 +1445,7 @@ func TestBalanceAtHeight(t *testing.T) { putRevokedState := func(c *OpenChannel, height uint64, local, remote lnwire.MilliSatoshi) error { - err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error { chanBucket, err := fetchChanBucketRw( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, @@ -1508,13 +1526,15 @@ func TestBalanceAtHeight(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create options to set the heights and balances of // our local and remote commitments. localCommitOpt := channelCommitmentOption( diff --git a/channeldb/db.go b/channeldb/db.go index 8b373d8d7b..3ceb93d936 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -222,8 +222,8 @@ var ( type DB struct { kvdb.Backend - // linkNodeDB separates all DB operations on LinkNodes. - linkNodeDB *LinkNodeDB + // channelStateDB separates all DB operations on channel state. 
+ channelStateDB *ChannelStateDB dbPath string graph *ChannelGraph @@ -273,13 +273,19 @@ func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB, chanDB := &DB{ Backend: backend, - linkNodeDB: &LinkNodeDB{ + channelStateDB: &ChannelStateDB{ + linkNodeDB: &LinkNodeDB{ + backend: backend, + }, backend: backend, }, clock: opts.clock, dryRun: opts.dryRun, } + // Set the parent pointer (only used in tests). + chanDB.channelStateDB.parent = chanDB + chanDB.graph = newChannelGraph( backend, opts.RejectCacheSize, opts.ChannelCacheSize, opts.BatchCommitInterval, @@ -339,10 +345,10 @@ func (d *DB) Wipe() error { return initChannelDB(d.Backend) } -// createChannelDB creates and initializes a fresh version of channeldb. In -// the case that the target path has not yet been created or doesn't yet exist, -// then the path is created. Additionally, all required top-level buckets used -// within the database are created. +// initChannelDB creates and initializes a fresh version of channeldb. In the +// case that the target path has not yet been created or doesn't yet exist, then +// the path is created. Additionally, all required top-level buckets used within +// the database are created. func initChannelDB(db kvdb.Backend) error { err := kvdb.Update(db, func(tx kvdb.RwTx) error { meta := &Meta{} @@ -409,15 +415,45 @@ func fileExists(path string) bool { return true } +// ChannelStateDB is a database that keeps track of all channel state. +type ChannelStateDB struct { + // linkNodeDB separates all DB operations on LinkNodes. + linkNodeDB *LinkNodeDB + + // parent holds a pointer to the "main" channeldb.DB object. This is + // only used for testing and should never be used in production code. + // For testing use the ChannelStateDB.GetParentDB() function to retrieve + // this pointer. + parent *DB + + // backend points to the actual backend holding the channel state + // database. This may be a real backend or a cache middleware. + backend kvdb.Backend +} + +// GetParentDB returns the "main" channeldb.DB object that is the owner of this +// ChannelStateDB instance. Use this function only in tests where passing around +// pointers makes testing less readable. Never to be used in production code! +func (c *ChannelStateDB) GetParentDB() *DB { + return c.parent +} + +// LinkNodeDB returns the current instance of the link node database. +func (c *ChannelStateDB) LinkNodeDB() *LinkNodeDB { + return c.linkNodeDB +} + // FetchOpenChannels starts a new database transaction and returns all stored // currently active/open channels associated with the target nodeID. In the case // that no active channels are known to have been created with this node, then a // zero-length slice is returned. -func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) { +func (c *ChannelStateDB) FetchOpenChannels(nodeID *btcec.PublicKey) ( + []*OpenChannel, error) { + var channels []*OpenChannel - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { var err error - channels, err = d.fetchOpenChannels(tx, nodeID) + channels, err = c.fetchOpenChannels(tx, nodeID) return err }, func() { channels = nil @@ -430,7 +466,7 @@ func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) // stored currently active/open channels associated with the target nodeID. In // the case that no active channels are known to have been created with this // node, then a zero-length slice is returned. 
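
Taken together, the layering reads as follows from a caller's point of view; a
usage sketch, where dbPath and peerPub stand in for a real database path and a
peer's public key:

fullDB, err := channeldb.Open(dbPath)
if err != nil {
        return err
}
defer fullDB.Close()

// Channel-state operations, including link-node access, now hang off the
// sub-database, while the graph keeps its own entry point on fullDB.
cdb := fullDB.ChannelStateDB()
openChans, err := cdb.FetchOpenChannels(peerPub)
if err != nil {
        return err
}
_ = openChans // e.g. hand these to the switch on startup
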
-func (d *DB) fetchOpenChannels(tx kvdb.RTx, +func (c *ChannelStateDB) fetchOpenChannels(tx kvdb.RTx, nodeID *btcec.PublicKey) ([]*OpenChannel, error) { // Get the bucket dedicated to storing the metadata for open channels. @@ -466,7 +502,7 @@ func (d *DB) fetchOpenChannels(tx kvdb.RTx, // Finally, we both of the necessary buckets retrieved, fetch // all the active channels related to this node. - nodeChannels, err := d.fetchNodeChannels(chainBucket) + nodeChannels, err := c.fetchNodeChannels(chainBucket) if err != nil { return fmt.Errorf("unable to read channel for "+ "chain_hash=%x, node_key=%x: %v", @@ -483,7 +519,8 @@ func (d *DB) fetchOpenChannels(tx kvdb.RTx, // fetchNodeChannels retrieves all active channels from the target chainBucket // which is under a node's dedicated channel bucket. This function is typically // used to fetch all the active channels related to a particular node. -func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error) { +func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) ( + []*OpenChannel, error) { var channels []*OpenChannel @@ -509,7 +546,7 @@ func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error) return fmt.Errorf("unable to read channel data for "+ "chan_point=%v: %v", outPoint, err) } - oChannel.Db = d + oChannel.Db = c channels = append(channels, oChannel) @@ -526,8 +563,8 @@ func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error) // point. If the channel cannot be found, then an error will be returned. // Optionally an existing db tx can be supplied. Optionally an existing db tx // can be supplied. -func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, - error) { +func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( + *OpenChannel, error) { var ( targetChan *OpenChannel @@ -603,7 +640,7 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, } targetChan = channel - targetChan.Db = d + targetChan.Db = c return nil }) @@ -612,7 +649,7 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, var err error if tx == nil { - err = kvdb.View(d, chanScan, func() {}) + err = kvdb.View(c.backend, chanScan, func() {}) } else { err = chanScan(tx) } @@ -632,16 +669,16 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, // FetchAllChannels attempts to retrieve all open channels currently stored // within the database, including pending open, fully open and channels waiting // for a closing transaction to confirm. -func (d *DB) FetchAllChannels() ([]*OpenChannel, error) { - return fetchChannels(d) +func (c *ChannelStateDB) FetchAllChannels() ([]*OpenChannel, error) { + return fetchChannels(c) } // FetchAllOpenChannels will return all channels that have the funding // transaction confirmed, and is not waiting for a closing transaction to be // confirmed. -func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { +func (c *ChannelStateDB) FetchAllOpenChannels() ([]*OpenChannel, error) { return fetchChannels( - d, + c, pendingChannelFilter(false), waitingCloseFilter(false), ) @@ -650,8 +687,8 @@ func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { // FetchPendingChannels will return channels that have completed the process of // generating and broadcasting funding transactions, but whose funding // transactions have yet to be confirmed on the blockchain. 
-func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { - return fetchChannels(d, +func (c *ChannelStateDB) FetchPendingChannels() ([]*OpenChannel, error) { + return fetchChannels(c, pendingChannelFilter(true), waitingCloseFilter(false), ) @@ -661,9 +698,9 @@ func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { // but are now waiting for a closing transaction to be confirmed. // // NOTE: This includes channels that are also pending to be opened. -func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { +func (c *ChannelStateDB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { return fetchChannels( - d, waitingCloseFilter(true), + c, waitingCloseFilter(true), ) } @@ -704,10 +741,12 @@ func waitingCloseFilter(waitingClose bool) fetchChannelsFilter { // which have a true value returned for *all* of the filters will be returned. // If no filters are provided, every channel in the open channels bucket will // be returned. -func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error) { +func fetchChannels(c *ChannelStateDB, filters ...fetchChannelsFilter) ( + []*OpenChannel, error) { + var channels []*OpenChannel - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { // Get the bucket dedicated to storing the metadata for open // channels. openChanBucket := tx.ReadBucket(openChannelBucket) @@ -749,7 +788,7 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error "bucket for chain=%x", chainHash[:]) } - nodeChans, err := d.fetchNodeChannels(chainBucket) + nodeChans, err := c.fetchNodeChannels(chainBucket) if err != nil { return fmt.Errorf("unable to read "+ "channel for chain_hash=%x, "+ @@ -798,10 +837,12 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error // it becomes fully closed after a single confirmation. When a channel was // forcibly closed, it will become fully closed after _all_ the pending funds // (if any) have been swept. -func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, error) { +func (c *ChannelStateDB) FetchClosedChannels(pendingOnly bool) ( + []*ChannelCloseSummary, error) { + var chanSummaries []*ChannelCloseSummary - if err := kvdb.View(d, func(tx kvdb.RTx) error { + if err := kvdb.View(c.backend, func(tx kvdb.RTx) error { closeBucket := tx.ReadBucket(closedChannelBucket) if closeBucket == nil { return ErrNoClosedChannels @@ -839,9 +880,11 @@ var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary // FetchClosedChannel queries for a channel close summary using the channel // point of the channel in question. -func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, error) { +func (c *ChannelStateDB) FetchClosedChannel(chanID *wire.OutPoint) ( + *ChannelCloseSummary, error) { + var chanSummary *ChannelCloseSummary - if err := kvdb.View(d, func(tx kvdb.RTx) error { + if err := kvdb.View(c.backend, func(tx kvdb.RTx) error { closeBucket := tx.ReadBucket(closedChannelBucket) if closeBucket == nil { return ErrClosedChannelNotFound @@ -873,11 +916,11 @@ func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, er // FetchClosedChannelForID queries for a channel close summary using the // channel ID of the channel in question. 
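
The filter plumbing composes, so further variants stay one-liners. A
hypothetical accessor for confirmed channels that are waiting to close would
read (sketch only, not part of the patch; the filter names mirror the helpers
above):

func (c *ChannelStateDB) fetchConfirmedWaitingClose() ([]*OpenChannel, error) {
        return fetchChannels(
                c,
                pendingChannelFilter(false),
                waitingCloseFilter(true),
        )
}
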
-func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( +func (c *ChannelStateDB) FetchClosedChannelForID(cid lnwire.ChannelID) ( *ChannelCloseSummary, error) { var chanSummary *ChannelCloseSummary - if err := kvdb.View(d, func(tx kvdb.RTx) error { + if err := kvdb.View(c.backend, func(tx kvdb.RTx) error { closeBucket := tx.ReadBucket(closedChannelBucket) if closeBucket == nil { return ErrClosedChannelNotFound @@ -926,12 +969,12 @@ func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( // cooperatively closed and it's reached a single confirmation, or after all // the pending funds in a channel that has been forcibly closed have been // swept. -func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { +func (c *ChannelStateDB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { var ( openChannels []*OpenChannel pruneLinkNode *btcec.PublicKey ) - err := kvdb.Update(d, func(tx kvdb.RwTx) error { + err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error { var b bytes.Buffer if err := writeOutpoint(&b, chanPoint); err != nil { return err @@ -978,7 +1021,9 @@ func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { // garbage collect it to ensure we don't establish persistent // connections to peers without open channels. pruneLinkNode = chanSummary.RemotePub - openChannels, err = d.fetchOpenChannels(tx, pruneLinkNode) + openChannels, err = c.fetchOpenChannels( + tx, pruneLinkNode, + ) if err != nil { return fmt.Errorf("unable to fetch open channels for "+ "peer %x: %v", @@ -996,13 +1041,13 @@ func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { // Decide whether we want to remove the link node, based upon the number // of still open channels. - return d.pruneLinkNode(openChannels, pruneLinkNode) + return c.pruneLinkNode(openChannels, pruneLinkNode) } // pruneLinkNode determines whether we should garbage collect a link node from // the database due to no longer having any open channels with it. If there are // any left, then this acts as a no-op. -func (d *DB) pruneLinkNode(openChannels []*OpenChannel, +func (c *ChannelStateDB) pruneLinkNode(openChannels []*OpenChannel, remotePub *btcec.PublicKey) error { if len(openChannels) > 0 { @@ -1012,13 +1057,13 @@ func (d *DB) pruneLinkNode(openChannels []*OpenChannel, log.Infof("Pruning link node %x with zero open channels from database", remotePub.SerializeCompressed()) - return d.linkNodeDB.DeleteLinkNode(remotePub) + return c.linkNodeDB.DeleteLinkNode(remotePub) } // PruneLinkNodes attempts to prune all link nodes found within the databse with // whom we no longer have any open channels with. 
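
In caller terms the close-and-collect flow only changes its receiver; a
sketch, with chanPoint standing in for the closed channel's outpoint:

if err := cdb.MarkChanFullyClosed(&chanPoint); err != nil {
        return err
}
// Optionally sweep every link node that no longer backs an open channel;
// MarkChanFullyClosed already prunes the peer of the closed channel itself.
if err := cdb.PruneLinkNodes(); err != nil {
        return err
}
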
-func (d *DB) PruneLinkNodes() error { - allLinkNodes, err := d.linkNodeDB.FetchAllLinkNodes() +func (c *ChannelStateDB) PruneLinkNodes() error { + allLinkNodes, err := c.linkNodeDB.FetchAllLinkNodes() if err != nil { return err } @@ -1028,9 +1073,9 @@ func (d *DB) PruneLinkNodes() error { openChannels []*OpenChannel linkNode = linkNode ) - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { var err error - openChannels, err = d.fetchOpenChannels( + openChannels, err = c.fetchOpenChannels( tx, linkNode.IdentityPub, ) return err @@ -1041,7 +1086,7 @@ func (d *DB) PruneLinkNodes() error { return err } - err = d.pruneLinkNode(openChannels, linkNode.IdentityPub) + err = c.pruneLinkNode(openChannels, linkNode.IdentityPub) if err != nil { return err } @@ -1069,8 +1114,8 @@ type ChannelShell struct { // addresses, and finally create an edge within the graph for the channel as // well. This method is idempotent, so repeated calls with the same set of // channel shells won't modify the database after the initial call. -func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { - err := kvdb.Update(d, func(tx kvdb.RwTx) error { +func (c *ChannelStateDB) RestoreChannelShells(channelShells ...*ChannelShell) error { + err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error { for _, channelShell := range channelShells { channel := channelShell.Chan @@ -1084,7 +1129,7 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { // and link node for this channel. If the channel // already exists, then in order to ensure this method // is idempotent, we'll continue to the next step. - channel.Db = d + channel.Db = c err := syncNewChannel( tx, channel, channelShell.NodeAddrs, ) @@ -1104,8 +1149,10 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { // AddrsForNode consults the graph and channel database for all addresses known // to the passed node public key. -func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { - linkNode, err := d.linkNodeDB.FetchLinkNode(nodePub) +func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, + error) { + + linkNode, err := d.channelStateDB.linkNodeDB.FetchLinkNode(nodePub) if err != nil { return nil, err } @@ -1157,16 +1204,18 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { // database. If the channel was already removed (has a closed channel entry), // then we'll return a nil error. Otherwise, we'll insert a new close summary // into the database. -func (d *DB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error { +func (c *ChannelStateDB) AbandonChannel(chanPoint *wire.OutPoint, + bestHeight uint32) error { + // With the chanPoint constructed, we'll attempt to find the target // channel in the database. If we can't find the channel, then we'll // return the error back to the caller. - dbChan, err := d.FetchChannel(nil, *chanPoint) + dbChan, err := c.FetchChannel(nil, *chanPoint) switch { // If the channel wasn't found, then it's possible that it was already // abandoned from the database. case err == ErrChannelNotFound: - _, closedErr := d.FetchClosedChannel(chanPoint) + _, closedErr := c.FetchClosedChannel(chanPoint) if closedErr != nil { return closedErr } @@ -1204,8 +1253,10 @@ func (d *DB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error { // SaveChannelOpeningState saves the serialized channel state for the provided // chanPoint to the channelOpeningStateBucket. 
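
On the restore side, the SCB flow now hands its shells to the channel-state
sub-database; a sketch, where shellChan and peerAddr stand in for values
decoded from a static channel backup:

shell := &channeldb.ChannelShell{
        Chan:      shellChan,
        NodeAddrs: []net.Addr{peerAddr},
}
if err := cdb.RestoreChannelShells(shell); err != nil {
        return err
}
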
-func (d *DB) SaveChannelOpeningState(outPoint, serializedState []byte) error { - return kvdb.Update(d, func(tx kvdb.RwTx) error { +func (c *ChannelStateDB) SaveChannelOpeningState(outPoint, + serializedState []byte) error { + + return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket) if err != nil { return err @@ -1218,9 +1269,9 @@ func (d *DB) SaveChannelOpeningState(outPoint, serializedState []byte) error { // GetChannelOpeningState fetches the serialized channel state for the provided // outPoint from the database, or returns ErrChannelNotFound if the channel // is not found. -func (d *DB) GetChannelOpeningState(outPoint []byte) ([]byte, error) { +func (c *ChannelStateDB) GetChannelOpeningState(outPoint []byte) ([]byte, error) { var serializedState []byte - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { bucket := tx.ReadBucket(channelOpeningStateBucket) if bucket == nil { // If the bucket does not exist, it means we never added @@ -1241,8 +1292,8 @@ func (d *DB) GetChannelOpeningState(outPoint []byte) ([]byte, error) { } // DeleteChannelOpeningState removes any state for outPoint from the database. -func (d *DB) DeleteChannelOpeningState(outPoint []byte) error { - return kvdb.Update(d, func(tx kvdb.RwTx) error { +func (c *ChannelStateDB) DeleteChannelOpeningState(outPoint []byte) error { + return kvdb.Update(c.backend, func(tx kvdb.RwTx) error { bucket := tx.ReadWriteBucket(channelOpeningStateBucket) if bucket == nil { return ErrChannelNotFound @@ -1330,9 +1381,10 @@ func (d *DB) ChannelGraph() *ChannelGraph { return d.graph } -// LinkNodeDB returns the current instance of the link node database. -func (d *DB) LinkNodeDB() *LinkNodeDB { - return d.linkNodeDB +// ChannelStateDB returns the sub database that is concerned with the channel +// state. +func (d *DB) ChannelStateDB() *ChannelStateDB { + return d.channelStateDB } func getLatestDBVersion(versions []version) uint32 { @@ -1384,9 +1436,11 @@ func fetchHistoricalChanBucket(tx kvdb.RTx, // FetchHistoricalChannel fetches open channel data from the historical channel // bucket. 
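
The funding manager's use of the opening-state helpers is unaffected by the
receiver change; the round trip still looks like this (a sketch, with
outPointBytes and stateBytes standing in for the funding manager's serialized
values):

if err := cdb.SaveChannelOpeningState(outPointBytes, stateBytes); err != nil {
        return err
}
state, err := cdb.GetChannelOpeningState(outPointBytes)
if err != nil {
        return err
}
_ = state // the funding state machine would resume from this blob
if err := cdb.DeleteChannelOpeningState(outPointBytes); err != nil {
        return err
}
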
-func (d *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, error) { +func (c *ChannelStateDB) FetchHistoricalChannel(outPoint *wire.OutPoint) ( + *OpenChannel, error) { + var channel *OpenChannel - err := kvdb.View(d, func(tx kvdb.RTx) error { + err := kvdb.View(c.backend, func(tx kvdb.RTx) error { chanBucket, err := fetchHistoricalChanBucket(tx, outPoint) if err != nil { return err @@ -1394,7 +1448,7 @@ func (d *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, erro channel, err = fetchOpenChannel(chanBucket, outPoint) - channel.Db = d + channel.Db = c return err }, func() { channel = nil diff --git a/channeldb/db_test.go b/channeldb/db_test.go index ef471c84c5..5731c03a8a 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -87,15 +87,18 @@ func TestWipe(t *testing.T) { } defer cleanup() - cdb, err := CreateWithBackend(backend) + fullDB, err := CreateWithBackend(backend) if err != nil { t.Fatalf("unable to create channeldb: %v", err) } - defer cdb.Close() + defer fullDB.Close() - if err := cdb.Wipe(); err != nil { + if err := fullDB.Wipe(); err != nil { t.Fatalf("unable to wipe channeldb: %v", err) } + + cdb := fullDB.ChannelStateDB() + // Check correct errors are returned openChannels, err := cdb.FetchAllOpenChannels() require.NoError(t, err, "fetching open channels") @@ -113,12 +116,14 @@ func TestFetchClosedChannelForID(t *testing.T) { const numChans = 101 - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create the test channel state, that we will mutate the index of the // funding point. state := createTestChannelState(t, cdb) @@ -184,18 +189,18 @@ func TestFetchClosedChannelForID(t *testing.T) { func TestAddrsForNode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() - graph := cdb.ChannelGraph() + graph := fullDB.ChannelGraph() // We'll make a test vertex to insert into the database, as the source // node, but this node will only have half the number of addresses it // usually does. - testNode, err := createTestVertex(cdb) + testNode, err := createTestVertex(fullDB) if err != nil { t.Fatalf("unable to create test node: %v", err) } @@ -211,7 +216,8 @@ func TestAddrsForNode(t *testing.T) { t.Fatalf("unable to recv node pub: %v", err) } linkNode := NewLinkNode( - cdb.linkNodeDB, wire.MainNet, nodePub, anotherAddr, + fullDB.channelStateDB.linkNodeDB, wire.MainNet, nodePub, + anotherAddr, ) if err := linkNode.Sync(); err != nil { t.Fatalf("unable to sync link node: %v", err) @@ -219,7 +225,7 @@ func TestAddrsForNode(t *testing.T) { // Now that we've created a link node, as well as a vertex for the // node, we'll query for all its addresses. - nodeAddrs, err := cdb.AddrsForNode(nodePub) + nodeAddrs, err := fullDB.AddrsForNode(nodePub) if err != nil { t.Fatalf("unable to obtain node addrs: %v", err) } @@ -245,12 +251,14 @@ func TestAddrsForNode(t *testing.T) { func TestFetchChannel(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create an open channel. 
channelState := createTestChannel(t, cdb, openChannelOption()) @@ -349,12 +357,14 @@ func genRandomChannelShell() (*ChannelShell, error) { func TestRestoreChannelShells(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // First, we'll make our channel shell, it will only have the minimal // amount of information required for us to initiate the data loss // protection feature. @@ -423,7 +433,7 @@ func TestRestoreChannelShells(t *testing.T) { // We should also be able to find the link node that was inserted by // its public key. - linkNode, err := cdb.linkNodeDB.FetchLinkNode( + linkNode, err := fullDB.channelStateDB.linkNodeDB.FetchLinkNode( channelShell.Chan.IdentityPub, ) if err != nil { @@ -445,12 +455,14 @@ func TestRestoreChannelShells(t *testing.T) { func TestAbandonChannel(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // If we attempt to abandon the state of a channel that doesn't exist // in the open or closed channel bucket, then we should receive an // error. @@ -618,13 +630,15 @@ func TestFetchChannels(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test "+ "database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create a pending channel that is not awaiting close. createTestChannel( t, cdb, channelIDOption(pendingChan), @@ -687,12 +701,14 @@ func TestFetchChannels(t *testing.T) { // TestFetchHistoricalChannel tests lookup of historical channels. func TestFetchHistoricalChannel(t *testing.T) { - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // Create a an open channel in the database. channel := createTestChannel(t, cdb, openChannelOption()) diff --git a/channeldb/nodes_test.go b/channeldb/nodes_test.go index 7e9231fc5f..8f60a79868 100644 --- a/channeldb/nodes_test.go +++ b/channeldb/nodes_test.go @@ -13,12 +13,14 @@ import ( func TestLinkNodeEncodeDecode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + // First we'll create some initial data to use for populating our test // LinkNode instances. 
_, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) @@ -110,12 +112,14 @@ func TestLinkNodeEncodeDecode(t *testing.T) { func TestDeleteLinkNode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := MakeTestDB() + fullDB, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() + cdb := fullDB.ChannelStateDB() + _, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) addr := &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), diff --git a/channeldb/waitingproof.go b/channeldb/waitingproof.go index e8a09b7581..7bb53e1798 100644 --- a/channeldb/waitingproof.go +++ b/channeldb/waitingproof.go @@ -36,12 +36,12 @@ type WaitingProofStore struct { // cache is used in order to reduce the number of redundant get // calls, when object isn't stored in it. cache map[WaitingProofKey]struct{} - db *DB + db kvdb.Backend mu sync.RWMutex } // NewWaitingProofStore creates new instance of proofs storage. -func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { +func NewWaitingProofStore(db kvdb.Backend) (*WaitingProofStore, error) { s := &WaitingProofStore{ db: db, } diff --git a/channelnotifier/channelnotifier.go b/channelnotifier/channelnotifier.go index 74c2b15eba..2cf6015c4d 100644 --- a/channelnotifier/channelnotifier.go +++ b/channelnotifier/channelnotifier.go @@ -17,7 +17,7 @@ type ChannelNotifier struct { ntfnServer *subscribe.Server - chanDB *channeldb.DB + chanDB *channeldb.ChannelStateDB } // PendingOpenChannelEvent represents a new event where a new channel has @@ -76,7 +76,7 @@ type FullyResolvedChannelEvent struct { // New creates a new channel notifier. The ChannelNotifier gets channel // events from peers and from the chain arbitrator, and dispatches them to // its clients. -func New(chanDB *channeldb.DB) *ChannelNotifier { +func New(chanDB *channeldb.ChannelStateDB) *ChannelNotifier { return &ChannelNotifier{ ntfnServer: subscribe.NewServer(), chanDB: chanDB, diff --git a/chanrestore.go b/chanrestore.go index 7527499cd3..cd68b50772 100644 --- a/chanrestore.go +++ b/chanrestore.go @@ -34,7 +34,7 @@ const ( // need the secret key chain in order obtain the prior shachain root so we can // verify the DLP protocol as initiated by the remote node. type chanDBRestorer struct { - db *channeldb.DB + db *channeldb.ChannelStateDB secretKeys keychain.SecretKeyRing diff --git a/contractcourt/breacharbiter.go b/contractcourt/breacharbiter.go index 112aa5bce5..3253e00094 100644 --- a/contractcourt/breacharbiter.go +++ b/contractcourt/breacharbiter.go @@ -136,7 +136,7 @@ type BreachConfig struct { // DB provides access to the user's channels, allowing the breach // arbiter to determine the current state of a user's channels, and how // it should respond to channel closure. - DB *channeldb.DB + DB *channeldb.ChannelStateDB // Estimator is used by the breach arbiter to determine an appropriate // fee level when generating, signing, and broadcasting sweep @@ -1432,11 +1432,11 @@ func (b *BreachArbiter) sweepSpendableOutputsTxn(txWeight int64, // store is to ensure that we can recover from a restart in the middle of a // breached contract retribution. type RetributionStore struct { - db *channeldb.DB + db kvdb.Backend } // NewRetributionStore creates a new instance of a RetributionStore. 
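
Because DB still embeds kvdb.Backend, the stores that were loosened to take a
bare backend can be fed directly from the full database handle; a sketch:

retStore := contractcourt.NewRetributionStore(fullDB.Backend)

proofStore, err := channeldb.NewWaitingProofStore(fullDB.Backend)
if err != nil {
        return err
}
// retStore and proofStore are then wired into the breach arbiter and the
// gossiper respectively.
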
-func NewRetributionStore(db *channeldb.DB) *RetributionStore { +func NewRetributionStore(db kvdb.Backend) *RetributionStore { return &RetributionStore{ db: db, } diff --git a/contractcourt/breacharbiter_test.go b/contractcourt/breacharbiter_test.go index 0d423584ca..61e8193915 100644 --- a/contractcourt/breacharbiter_test.go +++ b/contractcourt/breacharbiter_test.go @@ -987,7 +987,7 @@ func initBreachedState(t *testing.T) (*BreachArbiter, contractBreaches := make(chan *ContractBreachEvent) brar, cleanUpArb, err := createTestArbiter( - t, contractBreaches, alice.State().Db, + t, contractBreaches, alice.State().Db.GetParentDB(), ) if err != nil { t.Fatalf("unable to initialize test breach arbiter: %v", err) @@ -1164,7 +1164,7 @@ func TestBreachHandoffFail(t *testing.T) { assertNotPendingClosed(t, alice) brar, cleanUpArb, err := createTestArbiter( - t, contractBreaches, alice.State().Db, + t, contractBreaches, alice.State().Db.GetParentDB(), ) if err != nil { t.Fatalf("unable to initialize test breach arbiter: %v", err) @@ -2075,7 +2075,7 @@ func assertNoArbiterBreach(t *testing.T, brar *BreachArbiter, // assertBrarCleanup blocks until the given channel point has been removed the // retribution store and the channel is fully closed in the database. func assertBrarCleanup(t *testing.T, brar *BreachArbiter, - chanPoint *wire.OutPoint, db *channeldb.DB) { + chanPoint *wire.OutPoint, db *channeldb.ChannelStateDB) { t.Helper() @@ -2174,7 +2174,7 @@ func createTestArbiter(t *testing.T, contractBreaches chan *ContractBreachEvent, notifier := mock.MakeMockSpendNotifier() ba := NewBreachArbiter(&BreachConfig{ CloseLink: func(_ *wire.OutPoint, _ ChannelCloseType) {}, - DB: db, + DB: db.ChannelStateDB(), Estimator: chainfee.NewStaticEstimator(12500, 0), GenSweepScript: func() ([]byte, error) { return nil, nil }, ContractBreaches: contractBreaches, @@ -2375,7 +2375,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa RevocationStore: shachain.NewRevocationStore(), LocalCommitment: aliceCommit, RemoteCommitment: aliceCommit, - Db: dbAlice, + Db: dbAlice.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: channels.TestFundingTx, } @@ -2393,7 +2393,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa RevocationStore: shachain.NewRevocationStore(), LocalCommitment: bobCommit, RemoteCommitment: bobCommit, - Db: dbBob, + Db: dbBob.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), } diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 426382dd33..aeeff69f9a 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -258,7 +258,9 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions, // same instance that is used by the link. chanPoint := a.channel.FundingOutpoint - channel, err := a.c.chanSource.FetchChannel(nil, chanPoint) + channel, err := a.c.chanSource.ChannelStateDB().FetchChannel( + nil, chanPoint, + ) if err != nil { return nil, err } @@ -301,7 +303,9 @@ func (a *arbChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error) // Now that we know the link can't mutate the channel // state, we'll read the channel from disk the target // channel according to its channel point. 
- channel, err := a.c.chanSource.FetchChannel(nil, chanPoint) + channel, err := a.c.chanSource.ChannelStateDB().FetchChannel( + nil, chanPoint, + ) if err != nil { return nil, err } @@ -422,7 +426,7 @@ func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error { // First, we'll we'll mark the channel as fully closed from the PoV of // the channel source. - err := c.chanSource.MarkChanFullyClosed(&chanPoint) + err := c.chanSource.ChannelStateDB().MarkChanFullyClosed(&chanPoint) if err != nil { log.Errorf("ChainArbitrator: unable to mark ChannelPoint(%v) "+ "fully closed: %v", chanPoint, err) @@ -480,7 +484,7 @@ func (c *ChainArbitrator) Start() error { // First, we'll fetch all the channels that are still open, in order to // collect them within our set of active contracts. - openChannels, err := c.chanSource.FetchAllChannels() + openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels() if err != nil { return err } @@ -538,7 +542,9 @@ func (c *ChainArbitrator) Start() error { // In addition to the channels that we know to be open, we'll also // launch arbitrators to finishing resolving any channels that are in // the pending close state. - closingChannels, err := c.chanSource.FetchClosedChannels(true) + closingChannels, err := c.chanSource.ChannelStateDB().FetchClosedChannels( + true, + ) if err != nil { return err } diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index e197c0b091..cb1648065a 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -49,7 +49,7 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { // We manually set the db here to make sure all channels are // synced to the same db. - channel.Db = db + channel.Db = db.ChannelStateDB() addr := &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), @@ -165,7 +165,7 @@ func TestResolveContract(t *testing.T) { } defer cleanup() channel := newChannel.State() - channel.Db = db + channel.Db = db.ChannelStateDB() addr := &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 18556, @@ -205,7 +205,7 @@ func TestResolveContract(t *testing.T) { // While the resolver are active, we'll now remove the channel from the // database (mark is as closed). - err = db.AbandonChannel(&channel.FundingOutpoint, 4) + err = db.ChannelStateDB().AbandonChannel(&channel.FundingOutpoint, 4) if err != nil { t.Fatalf("unable to remove channel: %v", err) } diff --git a/contractcourt/utils_test.go b/contractcourt/utils_test.go index 11f23d8cc9..0023402c1f 100644 --- a/contractcourt/utils_test.go +++ b/contractcourt/utils_test.go @@ -58,7 +58,7 @@ func copyChannelState(state *channeldb.OpenChannel) ( *channeldb.OpenChannel, func(), error) { // Make a copy of the DB. 
- dbFile := filepath.Join(state.Db.Path(), "channel.db") + dbFile := filepath.Join(state.Db.GetParentDB().Path(), "channel.db") tempDbPath, err := ioutil.TempDir("", "past-state") if err != nil { return nil, nil, err @@ -81,7 +81,7 @@ func copyChannelState(state *channeldb.OpenChannel) ( return nil, nil, err } - chans, err := newDb.FetchAllChannels() + chans, err := newDb.ChannelStateDB().FetchAllChannels() if err != nil { cleanup() return nil, nil, err diff --git a/discovery/message_store.go b/discovery/message_store.go index 4d5f9b2054..40f2df78ac 100644 --- a/discovery/message_store.go +++ b/discovery/message_store.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" ) @@ -59,7 +58,7 @@ type GossipMessageStore interface { // version of a message (like in the case of multiple ChannelUpdate's) for a // channel with a peer. type MessageStore struct { - db *channeldb.DB + db kvdb.Backend } // A compile-time assertion to ensure messageStore implements the @@ -67,8 +66,8 @@ type MessageStore struct { var _ GossipMessageStore = (*MessageStore)(nil) // NewMessageStore creates a new message store backed by a channeldb instance. -func NewMessageStore(db *channeldb.DB) (*MessageStore, error) { - err := kvdb.Batch(db.Backend, func(tx kvdb.RwTx) error { +func NewMessageStore(db kvdb.Backend) (*MessageStore, error) { + err := kvdb.Batch(db, func(tx kvdb.RwTx) error { _, err := tx.CreateTopLevelBucket(messageStoreBucket) return err }) @@ -124,7 +123,7 @@ func (s *MessageStore) AddMessage(msg lnwire.Message, peerPubKey [33]byte) error return err } - return kvdb.Batch(s.db.Backend, func(tx kvdb.RwTx) error { + return kvdb.Batch(s.db, func(tx kvdb.RwTx) error { messageStore := tx.ReadWriteBucket(messageStoreBucket) if messageStore == nil { return ErrCorruptedMessageStore @@ -145,7 +144,7 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, return err } - return kvdb.Batch(s.db.Backend, func(tx kvdb.RwTx) error { + return kvdb.Batch(s.db, func(tx kvdb.RwTx) error { messageStore := tx.ReadWriteBucket(messageStoreBucket) if messageStore == nil { return ErrCorruptedMessageStore diff --git a/funding/manager_test.go b/funding/manager_test.go index acd7ca5147..64c66f92ee 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -261,7 +261,7 @@ func (n *testNode) AddNewChannel(channel *channeldb.OpenChannel, } } -func createTestWallet(cdb *channeldb.DB, netParams *chaincfg.Params, +func createTestWallet(cdb *channeldb.ChannelStateDB, netParams *chaincfg.Params, notifier chainntnfs.ChainNotifier, wc lnwallet.WalletController, signer input.Signer, keyRing keychain.SecretKeyRing, bio lnwallet.BlockChainIO, @@ -329,11 +329,13 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, } dbDir := filepath.Join(tempTestDir, "cdb") - cdb, err := channeldb.Open(dbDir) + fullDB, err := channeldb.Open(dbDir) if err != nil { return nil, err } + cdb := fullDB.ChannelStateDB() + keyRing := &mock.SecretKeyRing{ RootKey: alicePrivKey, } diff --git a/htlcswitch/circuit_map.go b/htlcswitch/circuit_map.go index 951c922f01..d5bb5f376d 100644 --- a/htlcswitch/circuit_map.go +++ b/htlcswitch/circuit_map.go @@ -199,9 +199,16 @@ type circuitMap struct { // parameterize an instance of circuitMap. type CircuitMapConfig struct { // DB provides the persistent storage engine for the circuit map. 
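
The gossip message store follows the same pattern and accepts any
kvdb.Backend; a sketch:

msgStore, err := discovery.NewMessageStore(fullDB.Backend)
if err != nil {
        return err
}
// msgStore can then be handed to the gossiper as its GossipMessageStore.
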
- // TODO(conner): create abstraction to allow for the substitution of - // other persistence engines. - DB *channeldb.DB + DB kvdb.Backend + + // FetchAllOpenChannels is a function that fetches all currently open + // channels from the channel database. + FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error) + + // FetchClosedChannels is a function that fetches all closed channels + // from the channel database. + FetchClosedChannels func( + pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error) // ExtractErrorEncrypter derives the shared secret used to encrypt // errors from the obfuscator's ephemeral public key. @@ -296,7 +303,7 @@ func (cm *circuitMap) cleanClosedChannels() error { // Find closed channels and cache their ShortChannelIDs into a map. // This map will be used for looking up relative circuits and keystones. - closedChannels, err := cm.cfg.DB.FetchClosedChannels(false) + closedChannels, err := cm.cfg.FetchClosedChannels(false) if err != nil { return err } @@ -629,7 +636,7 @@ func (cm *circuitMap) decodeCircuit(v []byte) (*PaymentCircuit, error) { // channels. Therefore, it must be called before any links are created to avoid // interfering with normal operation. func (cm *circuitMap) trimAllOpenCircuits() error { - activeChannels, err := cm.cfg.DB.FetchAllOpenChannels() + activeChannels, err := cm.cfg.FetchAllOpenChannels() if err != nil { return err } @@ -860,7 +867,7 @@ func (cm *circuitMap) CommitCircuits(circuits ...*PaymentCircuit) ( // Write the entire batch of circuits to the persistent circuit bucket // using bolt's Batch write. This method must be called from multiple, // distinct goroutines to have any impact on performance. - err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error { + err := kvdb.Batch(cm.cfg.DB, func(tx kvdb.RwTx) error { circuitBkt := tx.ReadWriteBucket(circuitAddKey) if circuitBkt == nil { return ErrCorruptedCircuitMap @@ -1091,7 +1098,7 @@ func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error { } cm.mtx.Unlock() - err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error { + err := kvdb.Batch(cm.cfg.DB, func(tx kvdb.RwTx) error { for _, circuit := range removedCircuits { // If this htlc made it to an outgoing link, load the // keystone bucket from which we will remove the diff --git a/htlcswitch/circuit_test.go b/htlcswitch/circuit_test.go index d3ee7b4fee..fed07958ba 100644 --- a/htlcswitch/circuit_test.go +++ b/htlcswitch/circuit_test.go @@ -103,8 +103,11 @@ func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig, onionProcessor := newOnionProcessor(t) + db := makeCircuitDB(t, "") circuitMapCfg := &htlcswitch.CircuitMapConfig{ - DB: makeCircuitDB(t, ""), + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter, } @@ -634,13 +637,17 @@ func makeCircuitDB(t *testing.T, path string) *channeldb.DB { func restartCircuitMap(t *testing.T, cfg *htlcswitch.CircuitMapConfig) ( *htlcswitch.CircuitMapConfig, htlcswitch.CircuitMap) { - // Record the current temp path and close current db. - dbPath := cfg.DB.Path() + // Record the current temp path and close current db. We know we have + // a full channeldb.DB here since we created it just above. + dbPath := cfg.DB.(*channeldb.DB).Path() cfg.DB.Close() // Reinitialize circuit map with same db path. 
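Besides narrowing DB to a kvdb.Backend, the CircuitMapConfig above replaces direct database method calls with injected closures. Here is the pattern in isolation as a runnable sketch (hypothetical names, not lnd's actual API):

package main

import "fmt"

// OpenChannel is a stand-in for channeldb.OpenChannel.
type OpenChannel struct {
	ID uint64
}

// CircuitMapConfig mirrors the shape of the refactored config: the
// channel queries arrive as function values, not as a concrete DB.
type CircuitMapConfig struct {
	FetchAllOpenChannels func() ([]*OpenChannel, error)
}

// trimAllOpenCircuits only depends on the injected closure, so it can
// run against a real database or a test stub unchanged.
func trimAllOpenCircuits(cfg *CircuitMapConfig) error {
	chans, err := cfg.FetchAllOpenChannels()
	if err != nil {
		return err
	}
	for _, c := range chans {
		fmt.Println("trimming circuits for channel", c.ID)
	}
	return nil
}

func main() {
	// In production the field would be bound to the real store; a
	// test injects a stub instead.
	cfg := &CircuitMapConfig{
		FetchAllOpenChannels: func() ([]*OpenChannel, error) {
			return []*OpenChannel{{ID: 42}}, nil
		},
	}
	if err := trimAllOpenCircuits(cfg); err != nil {
		fmt.Println("error:", err)
	}
}

In the patch itself the closure is simply bound to the real store, as in FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels above.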
+ db := makeCircuitDB(t, dbPath) cfg2 := &htlcswitch.CircuitMapConfig{ - DB: makeCircuitDB(t, dbPath), + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, } cm2, err := htlcswitch.NewCircuitMap(cfg2) diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 1f99a1d9d1..865f3afd14 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1938,7 +1938,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( pCache := newMockPreimageCache() - aliceDb := aliceLc.channel.State().Db + aliceDb := aliceLc.channel.State().Db.GetParentDB() aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb) if err != nil { return nil, nil, nil, nil, nil, nil, err @@ -4438,7 +4438,7 @@ func (h *persistentLinkHarness) restartLink( pCache = newMockPreimageCache() ) - aliceDb := aliceChannel.State().Db + aliceDb := aliceChannel.State().Db.GetParentDB() aliceSwitch := h.coreLink.cfg.Switch if restartSwitch { var err error diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index ce9b0f8381..578a92367c 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -169,8 +169,10 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) } cfg := Config{ - DB: db, - SwitchPackager: channeldb.NewSwitchPackager(), + DB: db, + FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, + FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, + SwitchPackager: channeldb.NewSwitchPackager(), FwdingLog: &mockForwardingLog{ events: make(map[time.Time]channeldb.ForwardingEvent), }, diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index 2bd35f60ae..8d6cb5b3af 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -83,7 +83,7 @@ func deserializeNetworkResult(r io.Reader) (*networkResult, error) { // is back. The Switch will checkpoint any received result to the store, and // the store will keep results and notify the callers about them. type networkResultStore struct { - db *channeldb.DB + backend kvdb.Backend // results is a map from paymentIDs to channels where subscribers to // payment results will be notified. 
@@ -96,9 +96,9 @@ type networkResultStore struct { paymentIDMtx *multimutex.Mutex } -func newNetworkResultStore(db *channeldb.DB) *networkResultStore { +func newNetworkResultStore(db kvdb.Backend) *networkResultStore { return &networkResultStore{ - db: db, + backend: db, results: make(map[uint64][]chan *networkResult), paymentIDMtx: multimutex.NewMutex(), } @@ -126,7 +126,7 @@ func (store *networkResultStore) storeResult(paymentID uint64, var paymentIDBytes [8]byte binary.BigEndian.PutUint64(paymentIDBytes[:], paymentID) - err := kvdb.Batch(store.db.Backend, func(tx kvdb.RwTx) error { + err := kvdb.Batch(store.backend, func(tx kvdb.RwTx) error { networkResults, err := tx.CreateTopLevelBucket( networkResultStoreBucketKey, ) @@ -171,7 +171,7 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) ( resultChan = make(chan *networkResult, 1) ) - err := kvdb.View(store.db, func(tx kvdb.RTx) error { + err := kvdb.View(store.backend, func(tx kvdb.RTx) error { var err error result, err = fetchResult(tx, paymentID) switch { @@ -219,7 +219,7 @@ func (store *networkResultStore) getResult(pid uint64) ( *networkResult, error) { var result *networkResult - err := kvdb.View(store.db, func(tx kvdb.RTx) error { + err := kvdb.View(store.backend, func(tx kvdb.RTx) error { var err error result, err = fetchResult(tx, pid) return err @@ -260,7 +260,7 @@ func fetchResult(tx kvdb.RTx, pid uint64) (*networkResult, error) { // concurrently while this process is ongoing, as its result might end up being // deleted. func (store *networkResultStore) cleanStore(keep map[uint64]struct{}) error { - return kvdb.Update(store.db.Backend, func(tx kvdb.RwTx) error { + return kvdb.Update(store.backend, func(tx kvdb.RwTx) error { networkResults, err := tx.CreateTopLevelBucket( networkResultStoreBucketKey, ) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index d367d5e6bd..17b4238573 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -121,9 +121,18 @@ type Config struct { // subsystem. LocalChannelClose func(pubKey []byte, request *ChanClose) - // DB is the channeldb instance that will be used to back the switch's + // DB is the database backend that will be used to back the switch's // persistent circuit map. - DB *channeldb.DB + DB kvdb.Backend + + // FetchAllOpenChannels is a function that fetches all currently open + // channels from the channel database. + FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error) + + // FetchClosedChannels is a function that fetches all closed channels + // from the channel database. + FetchClosedChannels func( + pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error) // SwitchPackager provides access to the forwarding packages of all // active channels. This gives the switch the ability to read arbitrary @@ -281,6 +290,8 @@ type Switch struct { func New(cfg Config, currentHeight uint32) (*Switch, error) { circuitMap, err := NewCircuitMap(&CircuitMapConfig{ DB: cfg.DB, + FetchAllOpenChannels: cfg.FetchAllOpenChannels, + FetchClosedChannels: cfg.FetchClosedChannels, ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, }) if err != nil { @@ -1374,7 +1385,7 @@ func (s *Switch) closeCircuit(pkt *htlcPacket) (*PaymentCircuit, error) { // we're the originator of the payment, so the link stops attempting to // re-broadcast. 
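The networkResultStore above persists results with kvdb.Batch and reads them back under kvdb.View; lnd's kvdb package is a thin wrapper over bolt-style backends. To make those read/write paths concrete, here is a self-contained sketch written directly against go.etcd.io/bbolt (hypothetical bucket and key layout, and a plain Update in place of batching):

package main

import (
	"encoding/binary"
	"errors"
	"fmt"
	"log"

	bolt "go.etcd.io/bbolt"
)

var resultBucket = []byte("network-result-store")

// storeResult persists a result keyed by its big-endian payment ID.
func storeResult(db *bolt.DB, paymentID uint64, result []byte) error {
	var key [8]byte
	binary.BigEndian.PutUint64(key[:], paymentID)

	return db.Update(func(tx *bolt.Tx) error {
		b, err := tx.CreateBucketIfNotExists(resultBucket)
		if err != nil {
			return err
		}
		return b.Put(key[:], result)
	})
}

// getResult reads a stored result back inside a read-only transaction.
func getResult(db *bolt.DB, paymentID uint64) ([]byte, error) {
	var key [8]byte
	binary.BigEndian.PutUint64(key[:], paymentID)

	var result []byte
	err := db.View(func(tx *bolt.Tx) error {
		b := tx.Bucket(resultBucket)
		if b == nil {
			return errors.New("result bucket not found")
		}
		v := b.Get(key[:])
		if v == nil {
			return errors.New("no result for payment ID")
		}
		// Copy: values returned by Get are only valid while the
		// transaction is open.
		result = append([]byte(nil), v...)
		return nil
	})
	return result, err
}

func main() {
	db, err := bolt.Open("results.db", 0600, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	if err := storeResult(db, 7, []byte("settled")); err != nil {
		log.Fatal(err)
	}
	res, err := getResult(db, 7)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("payment 7: %s\n", res)
}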
func (s *Switch) ackSettleFail(settleFailRefs ...channeldb.SettleFailRef) error { - return kvdb.Batch(s.cfg.DB.Backend, func(tx kvdb.RwTx) error { + return kvdb.Batch(s.cfg.DB, func(tx kvdb.RwTx) error { return s.cfg.SwitchPackager.AckSettleFails(tx, settleFailRefs...) }) } @@ -1778,7 +1789,7 @@ func (s *Switch) Start() error { // forwarding packages and reforwards any Settle or Fail HTLCs found. This is // used to resurrect the switch's mailboxes after a restart. func (s *Switch) reforwardResponses() error { - openChannels, err := s.cfg.DB.FetchAllOpenChannels() + openChannels, err := s.cfg.FetchAllOpenChannels() if err != nil { return err } diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index d33daff8f4..eaf2aa99c8 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -306,7 +306,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, LocalCommitment: aliceCommit, RemoteCommitment: aliceCommit, ShortChannelID: chanID, - Db: dbAlice, + Db: dbAlice.ChannelStateDB(), Packager: channeldb.NewChannelPackager(chanID), FundingTxn: channels.TestFundingTx, } @@ -325,7 +325,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, LocalCommitment: bobCommit, RemoteCommitment: bobCommit, ShortChannelID: chanID, - Db: dbBob, + Db: dbBob.ChannelStateDB(), Packager: channeldb.NewChannelPackager(chanID), } @@ -384,7 +384,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, } restoreAlice := func() (*lnwallet.LightningChannel, error) { - aliceStoredChannels, err := dbAlice.FetchOpenChannels(aliceKeyPub) + aliceStoredChannels, err := dbAlice.ChannelStateDB(). + FetchOpenChannels(aliceKeyPub) switch err { case nil: case kvdb.ErrDatabaseNotOpen: @@ -394,7 +395,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, "db: %v", err) } - aliceStoredChannels, err = dbAlice.FetchOpenChannels(aliceKeyPub) + aliceStoredChannels, err = dbAlice.ChannelStateDB(). + FetchOpenChannels(aliceKeyPub) if err != nil { return nil, errors.Errorf("unable to fetch alice "+ "channel: %v", err) @@ -428,7 +430,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, } restoreBob := func() (*lnwallet.LightningChannel, error) { - bobStoredChannels, err := dbBob.FetchOpenChannels(bobKeyPub) + bobStoredChannels, err := dbBob.ChannelStateDB(). + FetchOpenChannels(bobKeyPub) switch err { case nil: case kvdb.ErrDatabaseNotOpen: @@ -438,7 +441,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, "db: %v", err) } - bobStoredChannels, err = dbBob.FetchOpenChannels(bobKeyPub) + bobStoredChannels, err = dbBob.ChannelStateDB(). 
+ FetchOpenChannels(bobKeyPub) if err != nil { return nil, errors.Errorf("unable to fetch bob "+ "channel: %v", err) @@ -950,9 +954,9 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, secondBobChannel, carolChannel *lnwallet.LightningChannel, startingHeight uint32, opts ...serverOption) *threeHopNetwork { - aliceDb := aliceChannel.State().Db - bobDb := firstBobChannel.State().Db - carolDb := carolChannel.State().Db + aliceDb := aliceChannel.State().Db.GetParentDB() + bobDb := firstBobChannel.State().Db.GetParentDB() + carolDb := carolChannel.State().Db.GetParentDB() hopNetwork := newHopNetwork() @@ -1201,8 +1205,8 @@ func newTwoHopNetwork(t testing.TB, aliceChannel, bobChannel *lnwallet.LightningChannel, startingHeight uint32) *twoHopNetwork { - aliceDb := aliceChannel.State().Db - bobDb := bobChannel.State().Db + aliceDb := aliceChannel.State().Db.GetParentDB() + bobDb := bobChannel.State().Db.GetParentDB() hopNetwork := newHopNetwork() diff --git a/lnd.go b/lnd.go index 8bf4a9fc46..7dcc7bf55e 100644 --- a/lnd.go +++ b/lnd.go @@ -697,7 +697,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, interceptor signal.Interceptor) error BtcdMode: cfg.BtcdMode, LtcdMode: cfg.LtcdMode, HeightHintDB: dbs.heightHintDB, - ChanStateDB: dbs.chanStateDB, + ChanStateDB: dbs.chanStateDB.ChannelStateDB(), PrivateWalletPw: privateWalletPw, PublicWalletPw: publicWalletPw, Birthday: walletInitParams.Birthday, diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index 4e88ae0c16..193f3a63a9 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -56,7 +56,7 @@ type AddInvoiceConfig struct { // ChanDB is a global boltdb instance which is needed to access the // channel graph. - ChanDB *channeldb.DB + ChanDB *channeldb.ChannelStateDB // Graph holds a reference to the ChannelGraph database. Graph *channeldb.ChannelGraph diff --git a/lnrpc/invoicesrpc/config_active.go b/lnrpc/invoicesrpc/config_active.go index 3246f4b7f8..abe8c5565d 100644 --- a/lnrpc/invoicesrpc/config_active.go +++ b/lnrpc/invoicesrpc/config_active.go @@ -50,7 +50,7 @@ type Config struct { // ChanStateDB is a possibly replicated db instance which contains the // channel state - ChanStateDB *channeldb.DB + ChanStateDB *channeldb.ChannelStateDB // GenInvoiceFeatures returns a feature containing feature bits that // should be advertised on freshly generated invoices. diff --git a/lnwallet/config.go b/lnwallet/config.go index a73120c020..cf7f3f4b87 100644 --- a/lnwallet/config.go +++ b/lnwallet/config.go @@ -18,7 +18,7 @@ type Config struct { // Database is a wrapper around a namespace within boltdb reserved for // ln-based wallet metadata. See the 'channeldb' package for further // information. 
- Database *channeldb.DB + Database *channeldb.ChannelStateDB // Notifier is used by in order to obtain notifications about funding // transaction reaching a specified confirmation depth, and to catch diff --git a/lnwallet/test/test_interface.go b/lnwallet/test/test_interface.go index dd6bf1a958..0b2aecff69 100644 --- a/lnwallet/test/test_interface.go +++ b/lnwallet/test/test_interface.go @@ -327,13 +327,13 @@ func createTestWallet(tempTestDir string, miningNode *rpctest.Harness, signer input.Signer, bio lnwallet.BlockChainIO) (*lnwallet.LightningWallet, error) { dbDir := filepath.Join(tempTestDir, "cdb") - cdb, err := channeldb.Open(dbDir) + fullDB, err := channeldb.Open(dbDir) if err != nil { return nil, err } cfg := lnwallet.Config{ - Database: cdb, + Database: fullDB.ChannelStateDB(), Notifier: notifier, SecretKeyRing: keyRing, WalletController: wc, @@ -2944,11 +2944,11 @@ func clearWalletStates(a, b *lnwallet.LightningWallet) error { a.ResetReservations() b.ResetReservations() - if err := a.Cfg.Database.Wipe(); err != nil { + if err := a.Cfg.Database.GetParentDB().Wipe(); err != nil { return err } - return b.Cfg.Database.Wipe() + return b.Cfg.Database.GetParentDB().Wipe() } func waitForMempoolTx(r *rpctest.Harness, txid *chainhash.Hash) error { diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index bd048b2c0f..40af5201cf 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -322,7 +322,7 @@ func CreateTestChannels(chanType channeldb.ChannelType) ( RevocationStore: shachain.NewRevocationStore(), LocalCommitment: aliceLocalCommit, RemoteCommitment: aliceRemoteCommit, - Db: dbAlice, + Db: dbAlice.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: testTx, } @@ -340,7 +340,7 @@ func CreateTestChannels(chanType channeldb.ChannelType) ( RevocationStore: shachain.NewRevocationStore(), LocalCommitment: bobLocalCommit, RemoteCommitment: bobRemoteCommit, - Db: dbBob, + Db: dbBob.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), } diff --git a/lnwallet/transactions_test.go b/lnwallet/transactions_test.go index 696328cdb2..be5bf705a7 100644 --- a/lnwallet/transactions_test.go +++ b/lnwallet/transactions_test.go @@ -937,7 +937,7 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp RevocationStore: shachain.NewRevocationStore(), LocalCommitment: remoteCommit, RemoteCommitment: remoteCommit, - Db: dbRemote, + Db: dbRemote.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: tc.fundingTx.MsgTx(), } @@ -955,7 +955,7 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp RevocationStore: shachain.NewRevocationStore(), LocalCommitment: localCommit, RemoteCommitment: localCommit, - Db: dbLocal, + Db: dbLocal.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: tc.fundingTx.MsgTx(), } diff --git a/peer/brontide.go b/peer/brontide.go index 60c41af6f1..9c6df67342 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -185,7 +185,7 @@ type Config struct { InterceptSwitch *htlcswitch.InterceptableSwitch // ChannelDB is used to fetch opened channels, and closed channels. - ChannelDB *channeldb.DB + ChannelDB *channeldb.ChannelStateDB // ChannelGraph is a pointer to the channel graph which is used to // query information about the set of known active channels. 
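The clearWalletStates change above shows the flip side of the facet split: whole-database operations such as Wipe stay on the parent type, so a component that only holds the channel-state facet has to reach them through GetParentDB(). A toy sketch of why that back-pointer exists (simplified types):

package main

import "fmt"

// DB is the full database holding every top-level bucket.
type DB struct {
	buckets map[string][]byte
	facet   ChannelStateDB
}

// ChannelStateDB is the narrow facet; parent points back to the full DB.
type ChannelStateDB struct {
	parent *DB
}

// GetParentDB returns the full database this facet belongs to.
func (c *ChannelStateDB) GetParentDB() *DB {
	return c.parent
}

// Wipe drops all state, not just channel state, which is why it lives
// on the parent DB rather than on the facet.
func (d *DB) Wipe() error {
	d.buckets = make(map[string][]byte)
	return nil
}

func main() {
	db := &DB{buckets: map[string][]byte{
		"open-chan-bucket": nil,
		"invoice-bucket":   nil,
	}}
	db.facet.parent = db

	// What the wallet config now stores is just the facet...
	cfgDatabase := &db.facet

	// ...so a full wipe has to go through the parent pointer.
	if err := cfgDatabase.GetParentDB().Wipe(); err != nil {
		fmt.Println("wipe failed:", err)
		return
	}
	fmt.Println("buckets left after wipe:", len(db.buckets))
}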
diff --git a/peer/test_utils.go b/peer/test_utils.go index 3ce1cbe031..ac5f5f5ab5 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -229,7 +229,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, RevocationStore: shachain.NewRevocationStore(), LocalCommitment: aliceCommit, RemoteCommitment: aliceCommit, - Db: dbAlice, + Db: dbAlice.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), FundingTxn: channels.TestFundingTx, } @@ -246,7 +246,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, RevocationStore: shachain.NewRevocationStore(), LocalCommitment: bobCommit, RemoteCommitment: bobCommit, - Db: dbBob, + Db: dbBob.ChannelStateDB(), Packager: channeldb.NewChannelPackager(shortChanID), } @@ -321,7 +321,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, ChanStatusSampleInterval: 30 * time.Second, ChanEnableTimeout: chanActiveTimeout, ChanDisableTimeout: 2 * time.Minute, - DB: dbAlice, + DB: dbAlice.ChannelStateDB(), Graph: dbAlice.ChannelGraph(), MessageSigner: nodeSignerAlice, OurPubKey: aliceKeyPub, @@ -359,7 +359,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, ChanActiveTimeout: chanActiveTimeout, InterceptSwitch: htlcswitch.NewInterceptableSwitch(nil), - ChannelDB: dbAlice, + ChannelDB: dbAlice.ChannelStateDB(), FeeEstimator: estimator, Wallet: wallet, ChainNotifier: notifier, diff --git a/rpcserver.go b/rpcserver.go index 79f6553750..67934bac11 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -3979,7 +3979,7 @@ func (r *rpcServer) createRPCClosedChannel( CloseInitiator: closeInitiator, } - reports, err := r.server.chanStateDB.FetchChannelReports( + reports, err := r.server.miscDB.FetchChannelReports( *r.cfg.ActiveNetParams.GenesisHash, &dbChannel.ChanPoint, ) switch err { @@ -5142,7 +5142,7 @@ func (r *rpcServer) ListInvoices(ctx context.Context, PendingOnly: req.PendingOnly, Reversed: req.Reversed, } - invoiceSlice, err := r.server.chanStateDB.QueryInvoices(q) + invoiceSlice, err := r.server.miscDB.QueryInvoices(q) if err != nil { return nil, fmt.Errorf("unable to query invoices: %v", err) } @@ -5944,7 +5944,7 @@ func (r *rpcServer) ListPayments(ctx context.Context, query.MaxPayments = math.MaxUint64 } - paymentsQuerySlice, err := r.server.chanStateDB.QueryPayments(query) + paymentsQuerySlice, err := r.server.miscDB.QueryPayments(query) if err != nil { return nil, err } @@ -5985,9 +5985,7 @@ func (r *rpcServer) DeletePayment(ctx context.Context, rpcsLog.Infof("[DeletePayment] payment_identifier=%v, "+ "failed_htlcs_only=%v", hash, req.FailedHtlcsOnly) - err = r.server.chanStateDB.DeletePayment( - hash, req.FailedHtlcsOnly, - ) + err = r.server.miscDB.DeletePayment(hash, req.FailedHtlcsOnly) if err != nil { return nil, err } @@ -6004,7 +6002,7 @@ func (r *rpcServer) DeleteAllPayments(ctx context.Context, "failed_htlcs_only=%v", req.FailedPaymentsOnly, req.FailedHtlcsOnly) - err := r.server.chanStateDB.DeletePayments( + err := r.server.miscDB.DeletePayments( req.FailedPaymentsOnly, req.FailedHtlcsOnly, ) if err != nil { @@ -6166,7 +6164,7 @@ func (r *rpcServer) FeeReport(ctx context.Context, return nil, err } - fwdEventLog := r.server.chanStateDB.ForwardingLog() + fwdEventLog := r.server.miscDB.ForwardingLog() // computeFeeSum is a helper function that computes the total fees for // a particular time slice described by a forwarding event query. 
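To make that helper's job concrete, here is a self-contained sketch of summing fees over a paginated forwarding-event query (stub types and field names; not lnd's actual ForwardingLog API):

package main

import "fmt"

type forwardingEvent struct {
	amtInMsat  uint64 // incoming amount, assumed >= outgoing
	amtOutMsat uint64
}

type eventQuery struct {
	indexOffset  uint32
	numMaxEvents uint32
}

type timeSlice struct {
	events          []forwardingEvent
	lastIndexOffset uint32
}

// queryLog stands in for the persistent forwarding log query.
func queryLog(events []forwardingEvent, q eventQuery) timeSlice {
	start := int(q.indexOffset)
	end := start + int(q.numMaxEvents)
	if end > len(events) {
		end = len(events)
	}
	return timeSlice{
		events:          events[start:end],
		lastIndexOffset: uint32(end),
	}
}

// computeFeeSum pages through the log, accumulating the fee of every
// event (the difference between incoming and outgoing amounts).
func computeFeeSum(events []forwardingEvent, pageSize uint32) uint64 {
	var (
		total  uint64
		offset uint32
	)
	for {
		slice := queryLog(events, eventQuery{offset, pageSize})
		if len(slice.events) == 0 {
			return total
		}
		for _, e := range slice.events {
			total += e.amtInMsat - e.amtOutMsat
		}
		offset = slice.lastIndexOffset
	}
}

func main() {
	log := []forwardingEvent{
		{amtInMsat: 1000, amtOutMsat: 999},
		{amtInMsat: 5000, amtOutMsat: 4990},
		{amtInMsat: 2000, amtOutMsat: 1998},
	}
	fmt.Println("total fees (msat):", computeFeeSum(log, 2))
}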
@@ -6407,7 +6405,7 @@ func (r *rpcServer) ForwardingHistory(ctx context.Context, IndexOffset: req.IndexOffset, NumMaxEvents: numEvents, } - timeSlice, err := r.server.chanStateDB.ForwardingLog().Query(eventQuery) + timeSlice, err := r.server.miscDB.ForwardingLog().Query(eventQuery) if err != nil { return nil, fmt.Errorf("unable to query forwarding log: %v", err) } diff --git a/server.go b/server.go index 6b8830f03c..f8f1f53e85 100644 --- a/server.go +++ b/server.go @@ -222,10 +222,14 @@ type server struct { graphDB *channeldb.ChannelGraph - chanStateDB *channeldb.DB + chanStateDB *channeldb.ChannelStateDB addrSource chanbackup.AddressSource + // miscDB is the DB that contains all "other" databases within the main + // channel DB that haven't been separated out yet. + miscDB *channeldb.DB + htlcSwitch *htlcswitch.Switch interceptableSwitch *htlcswitch.InterceptableSwitch @@ -434,15 +438,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s := &server{ cfg: cfg, graphDB: dbs.graphDB.ChannelGraph(), - chanStateDB: dbs.chanStateDB, + chanStateDB: dbs.chanStateDB.ChannelStateDB(), addrSource: dbs.chanStateDB, + miscDB: dbs.chanStateDB, cc: cc, sigPool: lnwallet.NewSigPool(cfg.Workers.Sig, cc.Signer), writePool: writePool, readPool: readPool, chansToRestore: chansToRestore, - channelNotifier: channelnotifier.New(dbs.chanStateDB), + channelNotifier: channelnotifier.New( + dbs.chanStateDB.ChannelStateDB(), + ), identityECDH: nodeKeyECDH, nodeSigner: netann.NewNodeSigner(nodeKeySigner), @@ -494,7 +501,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.htlcNotifier = htlcswitch.NewHtlcNotifier(time.Now) s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{ - DB: dbs.chanStateDB, + DB: dbs.chanStateDB, + FetchAllOpenChannels: s.chanStateDB.FetchAllOpenChannels, + FetchClosedChannels: s.chanStateDB.FetchClosedChannels, LocalChannelClose: func(pubKey []byte, request *htlcswitch.ChanClose) { @@ -536,7 +545,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, MessageSigner: s.nodeSigner, IsChannelActive: s.htlcSwitch.HasActiveLink, ApplyChannelUpdate: s.applyChannelUpdate, - DB: dbs.chanStateDB, + DB: s.chanStateDB, Graph: dbs.graphDB.ChannelGraph(), } @@ -804,11 +813,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } chanSeries := discovery.NewChanSeries(s.graphDB) - gossipMessageStore, err := discovery.NewMessageStore(s.chanStateDB) + gossipMessageStore, err := discovery.NewMessageStore(dbs.chanStateDB) if err != nil { return nil, err } - waitingProofStore, err := channeldb.NewWaitingProofStore(s.chanStateDB) + waitingProofStore, err := channeldb.NewWaitingProofStore(dbs.chanStateDB) if err != nil { return nil, err } @@ -890,8 +899,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{ ChainIO: cc.ChainIO, ConfDepth: 1, - FetchClosedChannels: dbs.chanStateDB.FetchClosedChannels, - FetchClosedChannel: dbs.chanStateDB.FetchClosedChannel, + FetchClosedChannels: s.chanStateDB.FetchClosedChannels, + FetchClosedChannel: s.chanStateDB.FetchClosedChannel, Notifier: cc.ChainNotifier, PublishTransaction: cc.Wallet.PublishTransaction, Store: utxnStore, @@ -1017,7 +1026,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.breachArbiter = contractcourt.NewBreachArbiter(&contractcourt.BreachConfig{ CloseLink: closeLink, - DB: dbs.chanStateDB, + DB: s.chanStateDB, Estimator: s.cc.FeeEstimator, GenSweepScript: newSweepPkScriptGen(cc.Wallet), Notifier: cc.ChainNotifier, @@ -1074,7 +1083,7 @@ func newServer(cfg 
*Config, listenAddrs []net.Addr, FindChannel: func(chanID lnwire.ChannelID) ( *channeldb.OpenChannel, error) { - dbChannels, err := dbs.chanStateDB.FetchAllChannels() + dbChannels, err := s.chanStateDB.FetchAllChannels() if err != nil { return nil, err } @@ -1246,7 +1255,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // static backup of the latest channel state. chanNotifier := &channelNotifier{ chanNotifier: s.channelNotifier, - addrs: s.chanStateDB, + addrs: dbs.chanStateDB, } backupFile := chanbackup.NewMultiFile(cfg.BackupFilePath) startingChans, err := chanbackup.FetchStaticChanBackups( @@ -1276,8 +1285,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, }, GetOpenChannels: s.chanStateDB.FetchAllOpenChannels, Clock: clock.NewDefaultClock(), - ReadFlapCount: s.chanStateDB.ReadFlapCount, - WriteFlapCount: s.chanStateDB.WriteFlapCounts, + ReadFlapCount: s.miscDB.ReadFlapCount, + WriteFlapCount: s.miscDB.WriteFlapCounts, FlapCountTicker: ticker.New(chanfitness.FlapCountFlushRate), }) diff --git a/subrpcserver_config.go b/subrpcserver_config.go index bf5911ec21..04853db767 100644 --- a/subrpcserver_config.go +++ b/subrpcserver_config.go @@ -93,7 +93,7 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config, routerBackend *routerrpc.RouterBackend, nodeSigner *netann.NodeSigner, graphDB *channeldb.ChannelGraph, - chanStateDB *channeldb.DB, + chanStateDB *channeldb.ChannelStateDB, sweeper *sweep.UtxoSweeper, tower *watchtower.Standalone, towerClient wtclient.Client, From d6fa912188951d408a6402bb7fade529d4a18745 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Tue, 21 Sep 2021 19:18:18 +0200 Subject: [PATCH 07/15] multi: further decouple graph To further separate the channel graph from the channel state, we refactor the AddrsForNode method to use the graph's public methods instead of directly accessing any buckets. This makes sure that we can have the channel state cached with just its buckets while not using a kvdb level cache for the graph. At the same time we refactor the graph's test to also be less dependent upon the channel state DB. --- autopilot/graph.go | 2 +- channeldb/db.go | 80 +++++------------- channeldb/graph.go | 115 ++++++++++++++++++++----- channeldb/graph_test.go | 175 +++++++++++++++++++-------------------- routing/graph.go | 2 +- routing/pathfind_test.go | 60 +++++++++----- routing/router.go | 6 +- routing/router_test.go | 24 ++---- rpcserver.go | 2 +- server.go | 2 +- 10 files changed, 254 insertions(+), 214 deletions(-) diff --git a/autopilot/graph.go b/autopilot/graph.go index 2624aa79dd..e630f8d35d 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -148,7 +148,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, return nil, err } - dbNode, err := d.db.FetchLightningNode(nil, vertex) + dbNode, err := d.db.FetchLightningNode(vertex) switch { case err == channeldb.ErrGraphNodeNotFound: fallthrough diff --git a/channeldb/db.go b/channeldb/db.go index 3ceb93d936..633639892c 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -23,6 +23,7 @@ import ( "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" ) const ( @@ -286,10 +287,14 @@ func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB, // Set the parent pointer (only used in tests).
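The commit message describes the key move: callers fetch nodes through the graph's public API instead of opening its buckets themselves. A toy illustration of that boundary (simplified types; the real method deserializes from the node bucket):

package main

import (
	"errors"
	"fmt"
)

type Vertex [33]byte

type LightningNode struct {
	Pub   Vertex
	Alias string
}

// Graph hides its storage; callers can only go through its methods.
type Graph struct {
	nodes map[Vertex]LightningNode // stand-in for the node bucket
}

var ErrGraphNodeNotFound = errors.New("graph node not found")

// FetchLightningNode is the public accessor callers now use, instead
// of opening the node bucket themselves.
func (g *Graph) FetchLightningNode(pub Vertex) (*LightningNode, error) {
	n, ok := g.nodes[pub]
	if !ok {
		return nil, ErrGraphNodeNotFound
	}
	return &n, nil
}

func main() {
	var alice Vertex
	alice[0] = 0x02

	g := &Graph{nodes: map[Vertex]LightningNode{
		alice: {Pub: alice, Alias: "alice"},
	}}

	node, err := g.FetchLightningNode(alice)
	if err != nil {
		fmt.Println("lookup failed:", err)
		return
	}
	fmt.Println("found node:", node.Alias)
}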
chanDB.channelStateDB.parent = chanDB - chanDB.graph = newChannelGraph( + var err error + chanDB.graph, err = NewChannelGraph( backend, opts.RejectCacheSize, opts.ChannelCacheSize, opts.BatchCommitInterval, ) + if err != nil { + return nil, err + } // Synchronize the version of database and apply migrations if needed. if err := chanDB.syncVersions(dbVersions); err != nil { @@ -305,7 +310,7 @@ func (d *DB) Path() string { return d.dbPath } -var topLevelBuckets = [][]byte{ +var dbTopLevelBuckets = [][]byte{ openChannelBucket, closedChannelBucket, forwardingLogBucket, @@ -316,10 +321,6 @@ var topLevelBuckets = [][]byte{ paymentsIndexBucket, peersBucket, nodeInfoBucket, - nodeBucket, - edgeBucket, - edgeIndexBucket, - graphMetaBucket, metaBucket, closeSummaryBucket, outpointBucket, @@ -330,7 +331,7 @@ var topLevelBuckets = [][]byte{ // operation is fully atomic. func (d *DB) Wipe() error { err := kvdb.Update(d, func(tx kvdb.RwTx) error { - for _, tlb := range topLevelBuckets { + for _, tlb := range dbTopLevelBuckets { err := tx.DeleteTopLevelBucket(tlb) if err != nil && err != kvdb.ErrBucketNotFound { return err @@ -358,42 +359,12 @@ func initChannelDB(db kvdb.Backend) error { return nil } - for _, tlb := range topLevelBuckets { + for _, tlb := range dbTopLevelBuckets { if _, err := tx.CreateTopLevelBucket(tlb); err != nil { return err } } - nodes := tx.ReadWriteBucket(nodeBucket) - _, err = nodes.CreateBucket(aliasIndexBucket) - if err != nil { - return err - } - _, err = nodes.CreateBucket(nodeUpdateIndexBucket) - if err != nil { - return err - } - - edges := tx.ReadWriteBucket(edgeBucket) - if _, err := edges.CreateBucket(edgeIndexBucket); err != nil { - return err - } - if _, err := edges.CreateBucket(edgeUpdateIndexBucket); err != nil { - return err - } - if _, err := edges.CreateBucket(channelPointBucket); err != nil { - return err - } - if _, err := edges.CreateBucket(zombieBucket); err != nil { - return err - } - - graphMeta := tx.ReadWriteBucket(graphMetaBucket) - _, err = graphMeta.CreateBucket(pruneLogBucket) - if err != nil { - return err - } - meta.DbVersionNumber = getLatestDBVersion(dbVersions) return putMeta(meta, tx) }, func() {}) @@ -1157,30 +1128,21 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, return nil, err } - var graphNode LightningNode - err = kvdb.View(d, func(tx kvdb.RTx) error { - // We'll also query the graph for this peer to see if they have - // any addresses that we don't currently have stored within the - // link node database. - nodes := tx.ReadBucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - compressedPubKey := nodePub.SerializeCompressed() - graphNode, err = fetchLightningNode(nodes, compressedPubKey) - if err != nil && err != ErrGraphNodeNotFound { - // If the node isn't found, then that's OK, as we still - // have the link node data. - return err - } - - return nil - }, func() { - linkNode = nil - }) + // We'll also query the graph for this peer to see if they have any + // addresses that we don't currently have stored within the link node + // database. + pubKey, err := route.NewVertexFromBytes(nodePub.SerializeCompressed()) if err != nil { return nil, err } + graphNode, err := d.graph.FetchLightningNode(pubKey) + if err != nil && err != ErrGraphNodeNotFound { + return nil, err + } else if err == ErrGraphNodeNotFound { + // If the node isn't found, then that's OK, as we still have the + // link node data. But any other error needs to be returned. 
+ graphNode = &LightningNode{} + } // Now that we have both sources of addrs for this node, we'll use a // map to de-duplicate any addresses between the two sources, and diff --git a/channeldb/graph.go b/channeldb/graph.go index cb9268307d..92b04bd93d 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -184,10 +184,14 @@ type ChannelGraph struct { nodeScheduler batch.Scheduler } -// newChannelGraph allocates a new ChannelGraph backed by a DB instance. The +// NewChannelGraph allocates a new ChannelGraph backed by a DB instance. The // returned instance has its own unique reject cache and channel cache. -func newChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, - batchCommitInterval time.Duration) *ChannelGraph { +func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, + batchCommitInterval time.Duration) (*ChannelGraph, error) { + + if err := initChannelGraph(db); err != nil { + return nil, err + } g := &ChannelGraph{ db: db, @@ -201,7 +205,85 @@ func newChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, db, nil, batchCommitInterval, ) - return g + return g, nil +} + +var graphTopLevelBuckets = [][]byte{ + nodeBucket, + edgeBucket, + edgeIndexBucket, + graphMetaBucket, +} + +// Wipe completely deletes all saved state within all used buckets within the +// database. The deletion is done in a single transaction, therefore this +// operation is fully atomic. +func (c *ChannelGraph) Wipe() error { + err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { + for _, tlb := range graphTopLevelBuckets { + err := tx.DeleteTopLevelBucket(tlb) + if err != nil && err != kvdb.ErrBucketNotFound { + return err + } + } + return nil + }, func() {}) + if err != nil { + return err + } + + return initChannelGraph(c.db) +} + +// initChannelGraph creates and initializes a fresh version of the channel +// graph database. All required top-level buckets and sub-buckets used within +// the graph database are created. +func initChannelGraph(db kvdb.Backend) error { + err := kvdb.Update(db, func(tx kvdb.RwTx) error { + for _, tlb := range graphTopLevelBuckets { + if _, err := tx.CreateTopLevelBucket(tlb); err != nil { + return err + } + } + + nodes := tx.ReadWriteBucket(nodeBucket) + _, err := nodes.CreateBucketIfNotExists(aliasIndexBucket) + if err != nil { + return err + } + _, err = nodes.CreateBucketIfNotExists(nodeUpdateIndexBucket) + if err != nil { + return err + } + + edges := tx.ReadWriteBucket(edgeBucket) + _, err = edges.CreateBucketIfNotExists(edgeIndexBucket) + if err != nil { + return err + } + _, err = edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) + if err != nil { + return err + } + _, err = edges.CreateBucketIfNotExists(channelPointBucket) + if err != nil { + return err + } + _, err = edges.CreateBucketIfNotExists(zombieBucket) + if err != nil { + return err + } + + graphMeta := tx.ReadWriteBucket(graphMetaBucket) + _, err = graphMeta.CreateBucketIfNotExists(pruneLogBucket) + return err + }, func() {}) + if err != nil { + return fmt.Errorf("unable to create new channel graph: %v", err) + } + + return nil +} // Database returns a pointer to the underlying database. @@ -218,7 +300,9 @@ func (c *ChannelGraph) Database() kvdb.Backend { // NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer // for that particular channel edge routing policy will be passed into the // callback.
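Wipe and initChannelGraph above give the graph its own bucket lifecycle inside the shared backend: deletion tolerates missing buckets, and initialization is idempotent thanks to CreateBucketIfNotExists. A minimal sketch of the same wipe-then-reinitialize pattern written directly against go.etcd.io/bbolt (illustrative bucket names):

package main

import (
	"errors"
	"log"

	bolt "go.etcd.io/bbolt"
)

var graphTopLevelBuckets = [][]byte{
	[]byte("graph-nodes"),
	[]byte("graph-edges"),
	[]byte("graph-meta"),
}

// wipe deletes only the graph's buckets, leaving every other top-level
// bucket in the shared database untouched.
func wipe(db *bolt.DB) error {
	return db.Update(func(tx *bolt.Tx) error {
		for _, name := range graphTopLevelBuckets {
			err := tx.DeleteBucket(name)
			if err != nil && !errors.Is(err, bolt.ErrBucketNotFound) {
				return err
			}
		}
		return nil
	})
}

// initBuckets re-creates the graph's buckets; CreateBucketIfNotExists
// makes the call safe to run on every startup.
func initBuckets(db *bolt.DB) error {
	return db.Update(func(tx *bolt.Tx) error {
		for _, name := range graphTopLevelBuckets {
			if _, err := tx.CreateBucketIfNotExists(name); err != nil {
				return err
			}
		}
		return nil
	})
}

func main() {
	db, err := bolt.Open("graph.db", 0600, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	if err := wipe(db); err != nil {
		log.Fatal(err)
	}
	if err := initBuckets(db); err != nil {
		log.Fatal(err)
	}
}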
-func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { +func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, + *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + // TODO(roasbeef): ptr map to reduce # of allocs? no duplicates return kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -2356,17 +2440,11 @@ func (l *LightningNode) isPublic(tx kvdb.RTx, sourcePubKey []byte) (bool, error) // FetchLightningNode attempts to look up a target node by its identity public // key. If the node isn't found in the database, then ErrGraphNodeNotFound is // returned. -// -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal. -func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) ( +func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) ( *LightningNode, error) { var node *LightningNode - - fetchNode := func(tx kvdb.RTx) error { + err := kvdb.View(c.db, func(tx kvdb.RTx) error { // First grab the nodes bucket which stores the mapping from // pubKey to node information. nodes := tx.ReadBucket(nodeBucket) @@ -2393,14 +2471,9 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) ( node = &n return nil - } - - var err error - if tx == nil { - err = kvdb.View(c.db, fetchNode, func() {}) - } else { - err = fetchNode(tx) - } + }, func() { + node = nil + }) if err != nil { return nil, err } diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 6d04429732..d2953a523b 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -6,10 +6,12 @@ import ( "errors" "fmt" "image/color" + "io/ioutil" "math" "math/big" prand "math/rand" "net" + "os" "reflect" "runtime" "sync" @@ -45,6 +47,48 @@ var ( testPub = route.Vertex{2, 202, 4} ) +// MakeTestGraph creates a new instance of the ChannelGraph for testing +// purposes. A callback which cleans up the created temporary directories is +// also returned and intended to be executed after the test completes. +func MakeTestGraph(modifiers ...OptionModifier) (*ChannelGraph, func(), error) { + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "channelgraph") + if err != nil { + return nil, nil, err + } + + opts := DefaultOptions() + for _, modifier := range modifiers { + modifier(&opts) + } + + // Next, create channelgraph for the first time. 
+ backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cgr") + if err != nil { + backendCleanup() + return nil, nil, err + } + + graph, err := NewChannelGraph( + backend, opts.RejectCacheSize, opts.ChannelCacheSize, + opts.BatchCommitInterval, + ) + if err != nil { + backendCleanup() + _ = os.RemoveAll(tempDirName) + return nil, nil, err + } + + cleanUp := func() { + _ = backend.Close() + backendCleanup() + _ = os.RemoveAll(tempDirName) + } + + return graph, cleanUp, nil +} + func createLightningNode(db kvdb.Backend, priv *btcec.PrivateKey) (*LightningNode, error) { updateTime := prand.Int63() @@ -76,14 +120,12 @@ func createTestVertex(db kvdb.Backend) (*LightningNode, error) { func TestNodeInsertionAndDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'd like to test basic insertion/deletion for vertexes from the // graph, so we'll create a test vertex to start with. node := &LightningNode{ @@ -107,7 +149,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { // Next, fetch the node from the database to ensure everything was // serialized properly. - dbNode, err := graph.FetchLightningNode(nil, testPub) + dbNode, err := graph.FetchLightningNode(testPub) if err != nil { t.Fatalf("unable to locate node: %v", err) } @@ -131,7 +173,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. - _, err = graph.FetchLightningNode(nil, testPub) + _, err = graph.FetchLightningNode(testPub) if err != ErrGraphNodeNotFound { t.Fatalf("fetch after delete should fail!") } @@ -142,14 +184,12 @@ func TestNodeInsertionAndDeletion(t *testing.T) { func TestPartialNode(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We want to be able to insert nodes into the graph that only has the // PubKey set. node := &LightningNode{ @@ -163,7 +203,7 @@ func TestPartialNode(t *testing.T) { // Next, fetch the node from the database to ensure everything was // serialized properly. - dbNode, err := graph.FetchLightningNode(nil, testPub) + dbNode, err := graph.FetchLightningNode(testPub) if err != nil { t.Fatalf("unable to locate node: %v", err) } @@ -195,7 +235,7 @@ func TestPartialNode(t *testing.T) { // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. - _, err = graph.FetchLightningNode(nil, testPub) + _, err = graph.FetchLightningNode(testPub) if err != ErrGraphNodeNotFound { t.Fatalf("fetch after delete should fail!") } @@ -204,14 +244,12 @@ func TestPartialNode(t *testing.T) { func TestAliasLookup(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'd like to test the alias index within the database, so first // create a new test node. 
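MakeTestGraph above follows a common Go testing pattern: build the resource step by step, undo the steps that already succeeded if a later one fails, and hand back a single cleanup closure on success. The pattern distilled into a runnable sketch (hypothetical resource, temp-dir based like the test helper):

package main

import (
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
)

type resource struct {
	dir string
}

// makeTestResource builds the resource in steps; on a later failure
// the completed steps are rolled back before returning the error.
func makeTestResource() (*resource, func(), error) {
	dir, err := ioutil.TempDir("", "test-resource")
	if err != nil {
		return nil, nil, err
	}

	// A second construction step that may fail: on error, undo the
	// first step instead of leaking the temp directory.
	if err := os.Mkdir(filepath.Join(dir, "data"), 0700); err != nil {
		_ = os.RemoveAll(dir)
		return nil, nil, err
	}

	cleanup := func() {
		_ = os.RemoveAll(dir)
	}
	return &resource{dir: dir}, cleanup, nil
}

func main() {
	res, cleanup, err := makeTestResource()
	if err != nil {
		fmt.Println("setup failed:", err)
		return
	}
	defer cleanup()
	fmt.Println("resource at", res.dir)
}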
testNode, err := createTestVertex(graph.db) @@ -258,13 +296,11 @@ func TestAliasLookup(t *testing.T) { func TestSourceNode(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() - defer cleanUp() + graph, cleanUp, err := MakeTestGraph() if err != nil { t.Fatalf("unable to make test database: %v", err) } - - graph := db.ChannelGraph() + defer cleanUp() // We'd like to test the setting/getting of the source node, so we // first create a fake node to use within the test. @@ -299,14 +335,12 @@ func TestSourceNode(t *testing.T) { func TestEdgeInsertionDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'd like to test the insertion/deletion of edges, so we create two // vertexes to connect. node1, err := createTestVertex(graph.db) @@ -434,13 +468,12 @@ func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, func TestDisconnectBlockAtHeight(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() sourceNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create source node: %v", err) @@ -721,14 +754,12 @@ func createChannelEdge(db kvdb.Backend, node1, node2 *LightningNode) (*ChannelEd func TestEdgeInfoUpdates(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. node1, err := createTestVertex(graph.db) @@ -851,14 +882,12 @@ func newEdgePolicy(chanID uint64, db kvdb.Backend, func TestGraphTraversal(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'd like to test some of the graph traversal capabilities within // the DB, so we'll create a series of fake nodes to insert into the // graph. @@ -1112,13 +1141,12 @@ func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, b []*wire.OutPoi func TestGraphPruning(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() sourceNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create source node: %v", err) @@ -1320,14 +1348,12 @@ func TestGraphPruning(t *testing.T) { func TestHighestChanID(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // If we don't yet have any channels in the database, then we should // get a channel ID of zero if we ask for the highest channel ID. 
bestID, err := graph.HighestChanID() @@ -1397,14 +1423,12 @@ func TestHighestChanID(t *testing.T) { func TestChanUpdatesInHorizon(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // If we issue an arbitrary query before any channel updates are // inserted in the database, we should get zero results. chanUpdates, err := graph.ChanUpdatesInHorizon( @@ -1567,14 +1591,12 @@ func TestChanUpdatesInHorizon(t *testing.T) { func TestNodeUpdatesInHorizon(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - startTime := time.Unix(1234, 0) endTime := startTime @@ -1690,14 +1712,12 @@ func TestNodeUpdatesInHorizon(t *testing.T) { func TestFilterKnownChanIDs(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // If we try to filter out a set of channel ID's before we even know of // any channels, then we should get the entire set back. preChanIDs := []uint64{1, 2, 3, 4} @@ -1807,14 +1827,12 @@ func TestFilterKnownChanIDs(t *testing.T) { func TestFilterChannelRange(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. node1, err := createTestVertex(graph.db) @@ -1941,14 +1959,12 @@ func TestFilterChannelRange(t *testing.T) { func TestFetchChanInfos(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'll first populate our graph with two nodes. All channels created // below will be made between these two nodes. node1, err := createTestVertex(graph.db) @@ -2063,14 +2079,12 @@ func TestFetchChanInfos(t *testing.T) { func TestIncompleteChannelPolicies(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // Create two nodes. 
node1, err := createTestVertex(graph.db) if err != nil { @@ -2171,13 +2185,12 @@ func TestIncompleteChannelPolicies(t *testing.T) { func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() sourceNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create source node: %v", err) @@ -2326,7 +2339,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { func TestPruneGraphNodes(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2334,7 +2347,6 @@ func TestPruneGraphNodes(t *testing.T) { // We'll start off by inserting our source node, to ensure that it's // the only node left after we prune the graph. - graph := db.ChannelGraph() sourceNode, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create source node: %v", err) @@ -2398,7 +2410,7 @@ func TestPruneGraphNodes(t *testing.T) { // Finally, we'll ensure that node3, the only fully unconnected node as // properly deleted from the graph and not another node in its place. - _, err = graph.FetchLightningNode(nil, node3.PubKeyBytes) + _, err = graph.FetchLightningNode(node3.PubKeyBytes) if err == nil { t.Fatalf("node 3 should have been deleted!") } @@ -2410,14 +2422,12 @@ func TestPruneGraphNodes(t *testing.T) { func TestAddChannelEdgeShellNodes(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // To start, we'll create two nodes, and only add one of them to the // channel graph. node1, err := createTestVertex(graph.db) @@ -2441,7 +2451,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { // Ensure that node1 was inserted as a full node, while node2 only has // a shell node present. - node1, err = graph.FetchLightningNode(nil, node1.PubKeyBytes) + node1, err = graph.FetchLightningNode(node1.PubKeyBytes) if err != nil { t.Fatalf("unable to fetch node1: %v", err) } @@ -2449,7 +2459,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { t.Fatalf("have shell announcement for node1, shouldn't") } - node2, err = graph.FetchLightningNode(nil, node2.PubKeyBytes) + node2, err = graph.FetchLightningNode(node2.PubKeyBytes) if err != nil { t.Fatalf("unable to fetch node2: %v", err) } @@ -2464,14 +2474,12 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { func TestNodePruningUpdateIndexDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'll first populate our graph with a single node that will be // removed shortly. node1, err := createTestVertex(graph.db) @@ -2534,44 +2542,41 @@ func TestNodeIsPublic(t *testing.T) { // We'll need to create a separate database and channel graph for each // participant to replicate real-world scenarios (private edges being in // some graphs but not others, etc.). 
- aliceDB, cleanUp, err := MakeTestDB() + aliceGraph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - aliceNode, err := createTestVertex(aliceDB) + aliceNode, err := createTestVertex(aliceGraph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - aliceGraph := aliceDB.ChannelGraph() if err := aliceGraph.SetSourceNode(aliceNode); err != nil { t.Fatalf("unable to set source node: %v", err) } - bobDB, cleanUp, err := MakeTestDB() + bobGraph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - bobNode, err := createTestVertex(bobDB) + bobNode, err := createTestVertex(bobGraph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - bobGraph := bobDB.ChannelGraph() if err := bobGraph.SetSourceNode(bobNode); err != nil { t.Fatalf("unable to set source node: %v", err) } - carolDB, cleanUp, err := MakeTestDB() + carolGraph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - carolNode, err := createTestVertex(carolDB) + carolNode, err := createTestVertex(carolGraph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) } - carolGraph := carolDB.ChannelGraph() if err := carolGraph.SetSourceNode(carolNode); err != nil { t.Fatalf("unable to set source node: %v", err) } @@ -2683,14 +2688,12 @@ func TestNodeIsPublic(t *testing.T) { func TestDisabledChannelIDs(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() - graph := db.ChannelGraph() - // Create first node and add it to the graph. node1, err := createTestVertex(graph.db) if err != nil { @@ -2781,14 +2784,12 @@ func TestDisabledChannelIDs(t *testing.T) { func TestEdgePolicyMissingMaxHtcl(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } - graph := db.ChannelGraph() - // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. node1, err := createTestVertex(graph.db) @@ -2961,12 +2962,11 @@ func TestGraphZombieIndex(t *testing.T) { t.Parallel() // We'll start by creating our test graph along with a test edge. - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() defer cleanUp() if err != nil { t.Fatalf("unable to create test database: %v", err) } - graph := db.ChannelGraph() node1, err := createTestVertex(graph.db) if err != nil { @@ -3136,7 +3136,7 @@ func compareEdgePolicies(a, b *ChannelEdgePolicy) error { return nil } -// TestLightningNodeSigVerifcation checks that we can use the LightningNode's +// TestLightningNodeSigVerification checks that we can use the LightningNode's // pubkey to verify signatures. func TestLightningNodeSigVerification(t *testing.T) { t.Parallel() @@ -3164,13 +3164,13 @@ func TestLightningNodeSigVerification(t *testing.T) { } // Create a LightningNode from the same private key. 
- db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() if err != nil { t.Fatalf("unable to make test database: %v", err) } defer cleanUp() - node, err := createLightningNode(db, priv) + node, err := createLightningNode(graph.db, priv) if err != nil { t.Fatalf("unable to create node: %v", err) } @@ -3214,11 +3214,10 @@ func TestComputeFee(t *testing.T) { func TestBatchedAddChannelEdge(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() require.Nil(t, err) defer cleanUp() - graph := db.ChannelGraph() sourceNode, err := createTestVertex(graph.db) require.Nil(t, err) err = graph.SetSourceNode(sourceNode) @@ -3297,12 +3296,10 @@ func TestBatchedAddChannelEdge(t *testing.T) { func TestBatchedUpdateEdgePolicy(t *testing.T) { t.Parallel() - db, cleanUp, err := MakeTestDB() + graph, cleanUp, err := MakeTestGraph() require.Nil(t, err) defer cleanUp() - graph := db.ChannelGraph() - // We'd like to test the update of edges inserted into the database, so // we create two vertexes to connect. node1, err := createTestVertex(graph.db) diff --git a/routing/graph.go b/routing/graph.go index 83f06807ee..be58698f46 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -85,7 +85,7 @@ func (g *dbRoutingTx) sourceNode() route.Vertex { func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) ( *lnwire.FeatureVector, error) { - targetNode, err := g.graph.FetchLightningNode(g.tx, nodePub) + targetNode, err := g.graph.FetchLightningNode(nodePub) switch err { // If the node exists and has features, return them directly. diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 67fecb321e..d098429c1d 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -23,6 +23,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" @@ -148,26 +149,36 @@ type testChan struct { // makeTestGraph creates a new instance of a channeldb.ChannelGraph for testing // purposes. A callback which cleans up the created temporary directories is // also returned and intended to be executed after the test completes. -func makeTestGraph() (*channeldb.ChannelGraph, func(), error) { +func makeTestGraph() (*channeldb.ChannelGraph, kvdb.Backend, func(), error) { // First, create a temporary directory to be used for the duration of // this test. tempDirName, err := ioutil.TempDir("", "channeldb") if err != nil { - return nil, nil, err + return nil, nil, nil, err } - // Next, create channeldb for the first time. - cdb, err := channeldb.Open(tempDirName) + // Next, create channelgraph for the first time. 
+ backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cgr") if err != nil { - return nil, nil, err + return nil, nil, nil, err } cleanUp := func() { - cdb.Close() - os.RemoveAll(tempDirName) + backendCleanup() + _ = os.RemoveAll(tempDirName) } - return cdb.ChannelGraph(), cleanUp, nil + opts := channeldb.DefaultOptions() + graph, err := channeldb.NewChannelGraph( + backend, opts.RejectCacheSize, opts.ChannelCacheSize, + opts.BatchCommitInterval, + ) + if err != nil { + cleanUp() + return nil, nil, nil, err + } + + return graph, backend, cleanUp, nil } // parseTestGraph returns a fully populated ChannelGraph given a path to a JSON @@ -197,7 +208,7 @@ func parseTestGraph(path string) (*testGraphInstance, error) { testAddrs = append(testAddrs, testAddr) // Next, create a temporary graph database for usage within the test. - graph, cleanUp, err := makeTestGraph() + graph, graphBackend, cleanUp, err := makeTestGraph() if err != nil { return nil, err } @@ -381,11 +392,12 @@ func parseTestGraph(path string) (*testGraphInstance, error) { } return &testGraphInstance{ - graph: graph, - cleanUp: cleanUp, - aliasMap: aliasMap, - privKeyMap: privKeyMap, - channelIDs: channelIDs, + graph: graph, + graphBackend: graphBackend, + cleanUp: cleanUp, + aliasMap: aliasMap, + privKeyMap: privKeyMap, + channelIDs: channelIDs, }, nil } @@ -447,8 +459,9 @@ type testChannel struct { } type testGraphInstance struct { - graph *channeldb.ChannelGraph - cleanUp func() + graph *channeldb.ChannelGraph + graphBackend kvdb.Backend + cleanUp func() // aliasMap is a map from a node's alias to its public key. This type is // provided in order to allow easily look up from the human memorable alias @@ -482,7 +495,7 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( testAddrs = append(testAddrs, testAddr) // Next, create a temporary graph database for usage within the test. - graph, cleanUp, err := makeTestGraph() + graph, graphBackend, cleanUp, err := makeTestGraph() if err != nil { return nil, err } @@ -671,10 +684,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( } return &testGraphInstance{ - graph: graph, - cleanUp: cleanUp, - aliasMap: aliasMap, - privKeyMap: privKeyMap, + graph: graph, + graphBackend: graphBackend, + cleanUp: cleanUp, + aliasMap: aliasMap, + privKeyMap: privKeyMap, }, nil } @@ -2120,7 +2134,7 @@ func TestPathFindSpecExample(t *testing.T) { // Carol, so we set "B" as the source node so path finding starts from // Bob. bob := ctx.aliases["B"] - bobNode, err := ctx.graph.FetchLightningNode(nil, bob) + bobNode, err := ctx.graph.FetchLightningNode(bob) if err != nil { t.Fatalf("unable to find bob: %v", err) } @@ -2170,7 +2184,7 @@ func TestPathFindSpecExample(t *testing.T) { // Next, we'll set A as the source node so we can assert that we create // the proper route for any queries starting with Alice. alice := ctx.aliases["A"] - aliceNode, err := ctx.graph.FetchLightningNode(nil, alice) + aliceNode, err := ctx.graph.FetchLightningNode(alice) if err != nil { t.Fatalf("unable to find alice: %v", err) } diff --git a/routing/router.go b/routing/router.go index 6ebf86c19d..00fa4d316a 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2505,8 +2505,10 @@ func (r *ChannelRouter) GetChannelByID(chanID lnwire.ShortChannelID) ( // within the graph. // // NOTE: This method is part of the ChannelGraphSource interface. 
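The router hunk that follows adapts to the FetchLightningNode signature change: the graph method no longer accepts an optional caller-supplied transaction and always opens its own read transaction. A distilled before/after sketch of the two shapes (toy transaction type over an in-memory store):

package main

import "fmt"

// RTx is a toy read transaction over an in-memory store.
type RTx struct {
	data map[string]string
}

var store = map[string]string{"alice": "node-info"}

// view stands in for kvdb.View: open a fresh read transaction, run fn.
func view(fn func(tx *RTx) error) error {
	return fn(&RTx{data: store})
}

// fetchOld is the old shape: a nil tx means "open one for me", a
// non-nil tx is reused. Every caller has to know about both modes.
func fetchOld(tx *RTx, key string) (string, error) {
	var val string
	read := func(tx *RTx) error {
		val = tx.data[key]
		return nil
	}
	if tx == nil {
		if err := view(read); err != nil {
			return "", err
		}
		return val, nil
	}
	if err := read(tx); err != nil {
		return "", err
	}
	return val, nil
}

// fetchNew is the new shape: the method always opens its own read
// transaction, so callers simply pass the key.
func fetchNew(key string) (string, error) {
	var val string
	err := view(func(tx *RTx) error {
		val = tx.data[key]
		return nil
	})
	return val, err
}

func main() {
	v1, _ := fetchOld(nil, "alice")
	v2, _ := fetchNew("alice")
	fmt.Println(v1, v2)
}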
-func (r *ChannelRouter) FetchLightningNode(node route.Vertex) (*channeldb.LightningNode, error) { - return r.cfg.Graph.FetchLightningNode(nil, node) +func (r *ChannelRouter) FetchLightningNode( + node route.Vertex) (*channeldb.LightningNode, error) { + + return r.cfg.Graph.FetchLightningNode(node) } // ForEachNode is used to iterate over every node in router topology. diff --git a/routing/router_test.go b/routing/router_test.go index 2633bd5ab4..510d18bf59 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -125,8 +125,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, } mc, err := NewMissionControl( - graphInstance.graph.Database(), route.Vertex{}, - mcConfig, + graphInstance.graphBackend, route.Vertex{}, mcConfig, ) require.NoError(t, err, "failed to create missioncontrol") @@ -188,7 +187,6 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, cleanUp := func() { ctx.router.Stop() - graphInstance.cleanUp() } return ctx, cleanUp @@ -197,17 +195,10 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, func createTestCtxSingleNode(t *testing.T, startingHeight uint32) (*testCtx, func()) { - var ( - graph *channeldb.ChannelGraph - sourceNode *channeldb.LightningNode - cleanup func() - err error - ) - - graph, cleanup, err = makeTestGraph() + graph, graphBackend, cleanup, err := makeTestGraph() require.NoError(t, err, "failed to make test graph") - sourceNode, err = createTestNode() + sourceNode, err := createTestNode() require.NoError(t, err, "failed to create test node") require.NoError(t, @@ -215,8 +206,9 @@ func createTestCtxSingleNode(t *testing.T, ) graphInstance := &testGraphInstance{ - graph: graph, - cleanUp: cleanup, + graph: graph, + graphBackend: graphBackend, + cleanUp: cleanup, } return createTestCtxFromGraphInstance( @@ -1577,7 +1569,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { t.Fatalf("unable to find any routes: %v", err) } - copy1, err := ctx.graph.FetchLightningNode(nil, pub1) + copy1, err := ctx.graph.FetchLightningNode(pub1) if err != nil { t.Fatalf("unable to fetch node: %v", err) } @@ -1586,7 +1578,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { t.Fatalf("fetched node not equal to original") } - copy2, err := ctx.graph.FetchLightningNode(nil, pub2) + copy2, err := ctx.graph.FetchLightningNode(pub2) if err != nil { t.Fatalf("unable to fetch node: %v", err) } diff --git a/rpcserver.go b/rpcserver.go index 67934bac11..3cc48132d1 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -5539,7 +5539,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, // With the public key decoded, attempt to fetch the node corresponding // to this public key. If the node cannot be found, then an error will // be returned. 
- node, err := graph.FetchLightningNode(nil, pubKey) + node, err := graph.FetchLightningNode(pubKey) switch { case err == channeldb.ErrGraphNodeNotFound: return nil, status.Error(codes.NotFound, err.Error()) diff --git a/server.go b/server.go index f8f1f53e85..0b1afe4008 100644 --- a/server.go +++ b/server.go @@ -3921,7 +3921,7 @@ func (s *server) fetchNodeAdvertisedAddr(pub *btcec.PublicKey) (net.Addr, error) return nil, err } - node, err := s.graphDB.FetchLightningNode(nil, vertex) + node, err := s.graphDB.FetchLightningNode(vertex) if err != nil { return nil, err } From 369c09be6152b76a24915050a8a5fe6bccf2b8f0 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Tue, 21 Sep 2021 19:18:20 +0200 Subject: [PATCH 08/15] channeldb+routing: add in-memory graph Adds an in-memory channel graph cache for faster pathfinding. Original PoC by: Joost Jager Co-Authored by: Oliver Gugger --- channeldb/graph.go | 149 +++++++++++--- channeldb/graph_cache.go | 328 ++++++++++++++++++++++++++++++ channeldb/graph_cache_test.go | 110 ++++++++++ routing/graph.go | 44 +--- routing/mock_graph_test.go | 43 ++-- routing/pathfind.go | 15 +- routing/pathfind_test.go | 49 ++++- routing/payment_session_source.go | 7 +- routing/router.go | 12 -- routing/router_test.go | 12 ++ routing/unified_policies.go | 16 +- 11 files changed, 652 insertions(+), 133 deletions(-) create mode 100644 channeldb/graph_cache.go create mode 100644 channeldb/graph_cache_test.go diff --git a/channeldb/graph.go b/channeldb/graph.go index 92b04bd93d..e3ec83113f 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -179,6 +179,7 @@ type ChannelGraph struct { cacheMu sync.RWMutex rejectCache *rejectCache chanCache *channelCache + graphCache *GraphCache chanScheduler batch.Scheduler nodeScheduler batch.Scheduler @@ -197,6 +198,7 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, db: db, rejectCache: newRejectCache(rejectCacheSize), chanCache: newChannelCache(chanCacheSize), + graphCache: NewGraphCache(), } g.chanScheduler = batch.NewTimeScheduler( db, &g.cacheMu, batchCommitInterval, @@ -205,6 +207,19 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, db, nil, batchCommitInterval, ) + startTime := time.Now() + log.Debugf("Populating in-memory channel graph, this might take a " + + "while...") + err := g.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error { + return g.graphCache.AddNode(tx, &graphCacheNode{node}) + }) + if err != nil { + return nil, err + } + + log.Debugf("Finished populating in-memory channel graph (took %v, %s)", + time.Since(startTime), g.graphCache.Stats()) + return g, nil } @@ -286,11 +301,6 @@ func initChannelGraph(db kvdb.Backend) error { return nil } -// Database returns a pointer to the underlying database. -func (c *ChannelGraph) Database() kvdb.Backend { - return c.db -} - // ForEachChannel iterates through all the channel edges stored within the // graph and invokes the passed callback for each edge. The callback takes two // edges as since this is a directed graph, both the in/out edges are visited. @@ -354,23 +364,22 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, // ForEachNodeChannel iterates through all channels of a given node, executing the // passed callback with an edge info structure and the policies of each end // of the channel. 
The first edge policy is the outgoing edge *to* the -// the connecting node, while the second is the incoming edge *from* the +// connecting node, while the second is the incoming edge *from* the // connecting node. If the callback returns an error, then the iteration is // halted with the error propagated back up to the caller. // // Unknown policies are passed into the callback as nil values. -// -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal. -func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub []byte, - cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, - *ChannelEdgePolicy) error) error { +func (c *ChannelGraph) ForEachNodeChannel(node route.Vertex, + cb func(channel *DirectedChannel) error) error { - db := c.db + return c.graphCache.ForEachChannel(node, cb) +} - return nodeTraversal(tx, nodePub, db, cb) +// FetchNodeFeatures returns the features of a given node. +func (c *ChannelGraph) FetchNodeFeatures( + node route.Vertex) (*lnwire.FeatureVector, error) { + + return c.graphCache.GetFeatures(node), nil } // DisabledChannelIDs returns the channel ids of disabled channels. @@ -549,6 +558,11 @@ func (c *ChannelGraph) AddLightningNode(node *LightningNode, r := &batch.Request{ Update: func(tx kvdb.RwTx) error { + wNode := &graphCacheNode{node} + if err := c.graphCache.AddNode(tx, wNode); err != nil { + return err + } + return addLightningNode(tx, node) }, } @@ -627,6 +641,8 @@ func (c *ChannelGraph) DeleteLightningNode(nodePub route.Vertex) error { return ErrGraphNodeNotFound } + c.graphCache.RemoveNode(nodePub) + return c.deleteLightningNode(nodes, nodePub[:]) }, func() {}) } @@ -753,6 +769,8 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *ChannelEdgeInfo) error return ErrEdgeAlreadyExist } + c.graphCache.AddChannel(edge, nil, nil) + // Before we insert the channel into the database, we'll ensure that // both nodes already exist in the channel graph. If either node // doesn't, then we'll insert a "shell" node that just includes its @@ -952,6 +970,8 @@ func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error { return ErrEdgeNotFound } + c.graphCache.UpdateChannel(edge) + return putChanEdgeInfo(edgeIndex, edge, chanKey) }, func() {}) } @@ -1037,7 +1057,7 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, // will be returned if that outpoint isn't known to be // a channel. If no error is returned, then a channel // was successfully pruned. - err = delChannelEdge( + err = c.delChannelEdge( edges, edgeIndex, chanIndex, zombieIndex, nodes, chanID, false, false, ) @@ -1088,6 +1108,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, c.chanCache.remove(channel.ChannelID) } + log.Debugf("Pruned graph, cache now has %s", c.graphCache.Stats()) + return chansClosed, nil } @@ -1188,6 +1210,8 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket, continue } + c.graphCache.RemoveNode(nodePubKey) + // If we reach this point, then there are no longer any edges // that connect this node, so we can delete it. 
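 		// The in-memory entry for this node was already dropped via
 		// RemoveNode above, so the graph cache stays consistent with
 		// the on-disk deletion that follows.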
if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil { @@ -1286,7 +1310,7 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf } for _, k := range keys { - err = delChannelEdge( + err = c.delChannelEdge( edges, edgeIndex, chanIndex, zombieIndex, nodes, k, false, false, ) @@ -1394,7 +1418,9 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) { // true, then when we mark these edges as zombies, we'll set up the keys such // that we require the node that failed to send the fresh update to be the one // that resurrects the channel from its zombie state. -func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, chanIDs ...uint64) error { +func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, + chanIDs ...uint64) error { + // TODO(roasbeef): possibly delete from node bucket if node has no more // channels // TODO(roasbeef): don't delete both edges? @@ -1427,7 +1453,7 @@ func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, chanIDs ...u var rawChanID [8]byte for _, chanID := range chanIDs { byteOrder.PutUint64(rawChanID[:], chanID) - err := delChannelEdge( + err := c.delChannelEdge( edges, edgeIndex, chanIndex, zombieIndex, nodes, rawChanID[:], true, strictZombiePruning, ) @@ -1556,7 +1582,9 @@ type ChannelEdge struct { // ChanUpdatesInHorizon returns all the known channel edges which have at least // one edge that has an update timestamp within the specified horizon. -func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) { +func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, + endTime time.Time) ([]ChannelEdge, error) { + // To ensure we don't return duplicate ChannelEdges, we'll use an // additional map to keep track of the edges already seen to prevent // re-adding it. @@ -1689,7 +1717,9 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha // update timestamp within the passed range. This method can be used by two // nodes to quickly determine if they have the same set of up to date node // announcements. -func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) { +func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, + endTime time.Time) ([]LightningNode, error) { + var nodesInHorizon []LightningNode err := kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -2017,7 +2047,7 @@ func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64, return nil } -func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, +func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, nodes kvdb.RwBucket, chanID []byte, isZombie, strictZombie bool) error { edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) @@ -2025,6 +2055,11 @@ func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, return err } + c.graphCache.RemoveChannel( + edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes, + edgeInfo.ChannelID, + ) + // We'll also remove the entry in the edge update index bucket before // we delete the edges themselves so we can access their last update // times. @@ -2159,7 +2194,9 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy, }, Update: func(tx kvdb.RwTx) error { var err error - isUpdate1, err = updateEdgePolicy(tx, edge) + isUpdate1, err = updateEdgePolicy( + tx, edge, c.graphCache, + ) // Silence ErrEdgeNotFound so that the batch can // succeed, but propagate the error via local state. 
@@ -2222,7 +2259,9 @@ func (c *ChannelGraph) updateEdgeCache(e *ChannelEdgePolicy, isUpdate1 bool) { // buckets using an existing database transaction. The returned boolean will be // true if the updated policy belongs to node1, and false if the policy belonged // to node2. -func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) { +func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy, + graphCache *GraphCache) (bool, error) { + edges := tx.ReadWriteBucket(edgeBucket) if edges == nil { return false, ErrEdgeNotFound @@ -2270,6 +2309,14 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) { return false, err } + var ( + fromNodePubKey route.Vertex + toNodePubKey route.Vertex + ) + copy(fromNodePubKey[:], fromNode) + copy(toNodePubKey[:], toNode) + graphCache.UpdatePolicy(edge, fromNodePubKey, toNodePubKey, isUpdate1) + return isUpdate1, nil } @@ -2481,6 +2528,39 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) ( return node, nil } +// graphCacheNode is a struct that wraps a LightningNode in a way that it can be +// cached in the graph cache. +type graphCacheNode struct { + lnNode *LightningNode +} + +// PubKey returns the node's public identity key. +func (w *graphCacheNode) PubKey() route.Vertex { + return w.lnNode.PubKeyBytes +} + +// Features returns the node's features. +func (w *graphCacheNode) Features() *lnwire.FeatureVector { + return w.lnNode.Features +} + +// ForEachChannel iterates through all channels of this node, executing the +// passed callback with an edge info structure and the policies of each end +// of the channel. The first edge policy is the outgoing edge *to* the +// connecting node, while the second is the incoming edge *from* the +// connecting node. If the callback returns an error, then the iteration is +// halted with the error propagated back up to the caller. +// +// Unknown policies are passed into the callback as nil values. +func (w *graphCacheNode) ForEachChannel(tx kvdb.RTx, + cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { + + return w.lnNode.ForEachChannel(tx, cb) +} + +var _ GraphCacheNode = (*graphCacheNode)(nil) + // HasLightningNode determines if the graph has a vertex identified by the // target node identity public key. If the node exists in the database, a // timestamp of when the data for the node was lasted updated is returned along @@ -2621,7 +2701,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // ForEachChannel iterates through all channels of this node, executing the // passed callback with an edge info structure and the policies of each end // of the channel. The first edge policy is the outgoing edge *to* the -// the connecting node, while the second is the incoming edge *from* the +// connecting node, while the second is the incoming edge *from* the // connecting node. If the callback returns an error, then the iteration is // halted with the error propagated back up to the caller. // @@ -2632,7 +2712,8 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // be nil and a fresh transaction will be created to execute the graph // traversal. 
func (l *LightningNode) ForEachChannel(tx kvdb.RTx, - cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { nodePub := l.PubKeyBytes[:] db := l.db @@ -3490,6 +3571,8 @@ func (c *ChannelGraph) MarkEdgeZombie(chanID uint64, "bucket: %w", err) } + c.graphCache.RemoveChannel(pubKey1, pubKey2, chanID) + return markEdgeZombie(zombieIndex, chanID, pubKey1, pubKey2) }) if err != nil { @@ -3544,6 +3627,18 @@ func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { c.rejectCache.remove(chanID) c.chanCache.remove(chanID) + // We need to add the channel back into our graph cache, otherwise we + // won't use it for path finding. + edgeInfos, err := c.FetchChanInfos([]uint64{chanID}) + if err != nil { + return err + } + for _, edgeInfo := range edgeInfos { + c.graphCache.AddChannel( + edgeInfo.Info, edgeInfo.Policy1, edgeInfo.Policy2, + ) + } + return nil } diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go new file mode 100644 index 0000000000..d1ec6dd2a1 --- /dev/null +++ b/channeldb/graph_cache.go @@ -0,0 +1,328 @@ +package channeldb + +import ( + "fmt" + "sync" + + "github.com/btcsuite/btcutil" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// GraphCacheNode is an interface for all the information the cache needs to know +// about a lightning node. +type GraphCacheNode interface { + // PubKey is the node's public identity key. + PubKey() route.Vertex + + // Features returns the node's p2p features. + Features() *lnwire.FeatureVector + + // ForEachChannel iterates through all channels of a given node, + // executing the passed callback with an edge info structure and the + // policies of each end of the channel. The first edge policy is the + // outgoing edge *to* the connecting node, while the second is the + // incoming edge *from* the connecting node. If the callback returns an + // error, then the iteration is halted with the error propagated back up + // to the caller. + ForEachChannel(kvdb.RTx, + func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error +} + +// DirectedChannel is a type that stores the channel information as seen from +// one side of the channel. +type DirectedChannel struct { + // ChannelID is the unique identifier of this channel. + ChannelID uint64 + + // IsNode1 indicates if this is the node with the smaller public key. + IsNode1 bool + + // OtherNode is the public key of the node on the other end of this + // channel. + OtherNode route.Vertex + + // Capacity is the announced capacity of this channel in satoshis. + Capacity btcutil.Amount + + // OutPolicy is the outgoing policy from this node *to* the other node. + OutPolicy *ChannelEdgePolicy + + // InPolicy is the incoming policy *from* the other node to this node. + InPolicy *ChannelEdgePolicy +} + +// GraphCache is a type that holds a minimal set of information of the public +// channel graph that can be used for pathfinding. +type GraphCache struct { + nodeChannels map[route.Vertex]map[uint64]*DirectedChannel + nodeFeatures map[route.Vertex]*lnwire.FeatureVector + + mtx sync.RWMutex +} + +// NewGraphCache creates a new graphCache. 
+func NewGraphCache() *GraphCache { + return &GraphCache{ + nodeChannels: make(map[route.Vertex]map[uint64]*DirectedChannel), + nodeFeatures: make(map[route.Vertex]*lnwire.FeatureVector), + } +} + +// Stats returns statistics about the current cache size. +func (c *GraphCache) Stats() string { + c.mtx.RLock() + defer c.mtx.RUnlock() + + numChannels := 0 + for node := range c.nodeChannels { + numChannels += len(c.nodeChannels[node]) + } + return fmt.Sprintf("num_node_features=%d, num_nodes=%d, "+ + "num_channels=%d", len(c.nodeFeatures), len(c.nodeChannels), + numChannels) +} + +// AddNode adds a graph node, including all the (directed) channels of that +// node. +func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { + nodePubKey := node.PubKey() + + // Only hold the lock for a short time. The `ForEachChannel()` below is + // possibly slow as it has to go to the backend, so we can unlock + // between the calls. And the AddChannel() method will acquire its own + // lock anyway. + c.mtx.Lock() + c.nodeFeatures[nodePubKey] = node.Features() + c.mtx.Unlock() + + return node.ForEachChannel( + tx, func(tx kvdb.RTx, info *ChannelEdgeInfo, + outPolicy *ChannelEdgePolicy, + inPolicy *ChannelEdgePolicy) error { + + c.AddChannel(info, outPolicy, inPolicy) + + return nil + }, + ) +} + +// AddChannel adds a non-directed channel, meaning that the order of policy 1 +// and policy 2 does not matter, the directionality is extracted from the info +// and policy flags automatically. The policy will be set as the outgoing policy +// on one node and the incoming policy on the peer's side. +func (c *GraphCache) AddChannel(info *ChannelEdgeInfo, + policy1 *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) { + + if info == nil { + return + } + + if policy1 != nil && policy1.IsDisabled() && + policy2 != nil && policy2.IsDisabled() { + + return + } + + // Create the edge entry for both nodes. + c.mtx.Lock() + c.updateOrAddEdge(info.NodeKey1Bytes, &DirectedChannel{ + ChannelID: info.ChannelID, + IsNode1: true, + OtherNode: info.NodeKey2Bytes, + Capacity: info.Capacity, + }) + c.updateOrAddEdge(info.NodeKey2Bytes, &DirectedChannel{ + ChannelID: info.ChannelID, + IsNode1: false, + OtherNode: info.NodeKey1Bytes, + Capacity: info.Capacity, + }) + c.mtx.Unlock() + + // The policy's node is always the to_node. So if policy 1 has to_node + // of node 2 then we have the policy 1 as seen from node 1. + if policy1 != nil { + fromNode, toNode := info.NodeKey1Bytes, info.NodeKey2Bytes + if policy1.Node.PubKeyBytes != info.NodeKey2Bytes { + fromNode, toNode = toNode, fromNode + } + isEdge1 := policy1.ChannelFlags&lnwire.ChanUpdateDirection == 0 + c.UpdatePolicy(policy1, fromNode, toNode, isEdge1) + } + if policy2 != nil { + fromNode, toNode := info.NodeKey2Bytes, info.NodeKey1Bytes + if policy2.Node.PubKeyBytes != info.NodeKey1Bytes { + fromNode, toNode = toNode, fromNode + } + isEdge1 := policy2.ChannelFlags&lnwire.ChanUpdateDirection == 0 + c.UpdatePolicy(policy2, fromNode, toNode, isEdge1) + } +} + +// updateOrAddEdge makes sure the edge information for a node is either updated +// if it already exists or is added to that node's list of channels. +func (c *GraphCache) updateOrAddEdge(node route.Vertex, edge *DirectedChannel) { + if len(c.nodeChannels[node]) == 0 { + c.nodeChannels[node] = make(map[uint64]*DirectedChannel) + } + + c.nodeChannels[node][edge.ChannelID] = edge +} + +// UpdatePolicy updates a single policy on both the from and to node. The order +// of the from and to node is not strictly important. 
But we assume that a +// channel edge was added beforehand so that the directed channel struct already +// exists in the cache. +func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy, fromNode, + toNode route.Vertex, edge1 bool) { + + // If a policy's node is nil, we can't cache it yet as that would lead + // to problems in pathfinding. + if policy.Node == nil { + // TODO(guggero): Fix this problem! + log.Warnf("Cannot cache policy because of missing node (from "+ + "%x to %x)", fromNode[:], toNode[:]) + return + } + + c.mtx.Lock() + defer c.mtx.Unlock() + + updatePolicy := func(nodeKey route.Vertex) { + if len(c.nodeChannels[nodeKey]) == 0 { + return + } + + channel, ok := c.nodeChannels[nodeKey][policy.ChannelID] + if !ok { + return + } + + // Edge 1 is defined as the policy for the direction of node1 to + // node2. + switch { + // This is node 1, and it is edge 1, so this is the outgoing + // policy for node 1. + case channel.IsNode1 && edge1: + channel.OutPolicy = policy + + // This is node 2, and it is edge 2, so this is the outgoing + // policy for node 2. + case !channel.IsNode1 && !edge1: + channel.OutPolicy = policy + + // The other two cases left mean it's the inbound policy for the + // node. + default: + channel.InPolicy = policy + } + } + + updatePolicy(fromNode) + updatePolicy(toNode) +} + +// RemoveNode completely removes a node and all its channels (including the +// peer's side). +func (c *GraphCache) RemoveNode(node route.Vertex) { + c.mtx.Lock() + defer c.mtx.Unlock() + + delete(c.nodeFeatures, node) + + // First remove all channels from the other nodes' lists. + for _, channel := range c.nodeChannels[node] { + c.removeChannelIfFound(channel.OtherNode, channel.ChannelID) + } + + // Then remove our whole node completely. + delete(c.nodeChannels, node) +} + +// RemoveChannel removes a single channel between two nodes. +func (c *GraphCache) RemoveChannel(node1, node2 route.Vertex, chanID uint64) { + c.mtx.Lock() + defer c.mtx.Unlock() + + // Remove that one channel from both sides. + c.removeChannelIfFound(node1, chanID) + c.removeChannelIfFound(node2, chanID) +} + +// removeChannelIfFound removes a single channel from one side. +func (c *GraphCache) removeChannelIfFound(node route.Vertex, chanID uint64) { + if len(c.nodeChannels[node]) == 0 { + return + } + + delete(c.nodeChannels[node], chanID) +} + +// UpdateChannel updates the channel edge information for a specific edge. We +// expect the edge to already exist and be known. If it does not yet exist, this +// call is a no-op. +func (c *GraphCache) UpdateChannel(info *ChannelEdgeInfo) { + c.mtx.Lock() + defer c.mtx.Unlock() + + if len(c.nodeChannels[info.NodeKey1Bytes]) == 0 || + len(c.nodeChannels[info.NodeKey2Bytes]) == 0 { + + return + } + + channel, ok := c.nodeChannels[info.NodeKey1Bytes][info.ChannelID] + if ok { + // We only expect to be called when the channel is already + // known. + channel.Capacity = info.Capacity + channel.OtherNode = info.NodeKey2Bytes + } + + channel, ok = c.nodeChannels[info.NodeKey2Bytes][info.ChannelID] + if ok { + channel.Capacity = info.Capacity + channel.OtherNode = info.NodeKey1Bytes + } +} + +// ForEachChannel invokes the given callback for each channel of the given node. 
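+// A minimal usage sketch, collecting the channel IDs of a node:
+//
+//	var chanIDs []uint64
+//	_ = c.ForEachChannel(nodePub, func(ch *DirectedChannel) error {
+//		chanIDs = append(chanIDs, ch.ChannelID)
+//		return nil
+//	})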
+func (c *GraphCache) ForEachChannel(node route.Vertex, + cb func(channel *DirectedChannel) error) error { + + c.mtx.RLock() + defer c.mtx.RUnlock() + + channels, ok := c.nodeChannels[node] + if !ok { + return nil + } + + for _, channel := range channels { + if err := cb(channel); err != nil { + return err + } + } + + return nil +} + +// GetFeatures returns the features of the node with the given ID. +func (c *GraphCache) GetFeatures(node route.Vertex) *lnwire.FeatureVector { + c.mtx.RLock() + defer c.mtx.RUnlock() + + features, ok := c.nodeFeatures[node] + if !ok || features == nil { + // The router expects the features to never be nil, so we return + // an empty feature set instead. + return lnwire.EmptyFeatureVector() + } + + return features +} diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go new file mode 100644 index 0000000000..71967c68cb --- /dev/null +++ b/channeldb/graph_cache_test.go @@ -0,0 +1,110 @@ +package channeldb + +import ( + "encoding/hex" + "testing" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/require" +) + +var ( + pubKey1Bytes, _ = hex.DecodeString( + "0248f5cba4c6da2e4c9e01e81d1404dfac0cbaf3ee934a4fc117d2ea9a64" + + "22c91d", + ) + pubKey2Bytes, _ = hex.DecodeString( + "038155ba86a8d3b23c806c855097ca5c9fa0f87621f1e7a7d2835ad057f6" + + "f4484f", + ) + + pubKey1, _ = route.NewVertexFromBytes(pubKey1Bytes) + pubKey2, _ = route.NewVertexFromBytes(pubKey2Bytes) +) + +type node struct { + pubKey route.Vertex + features *lnwire.FeatureVector + + edgeInfos []*ChannelEdgeInfo + outPolicies []*ChannelEdgePolicy + inPolicies []*ChannelEdgePolicy +} + +func (n *node) PubKey() route.Vertex { + return n.pubKey +} +func (n *node) Features() *lnwire.FeatureVector { + return n.features +} + +func (n *node) ForEachChannel(tx kvdb.RTx, + cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { + + for idx := range n.edgeInfos { + err := cb( + tx, n.edgeInfos[idx], n.outPolicies[idx], + n.inPolicies[idx], + ) + if err != nil { + return err + } + } + + return nil +} + +// TestGraphCacheAddNode tests that a channel going from node A to node B can be +// cached correctly, independent of the direction we add the channel as. +func TestGraphCacheAddNode(t *testing.T) { + runTest := func(nodeA, nodeB route.Vertex) { + t.Helper() + + outPolicy1 := &ChannelEdgePolicy{ + ChannelID: 1000, + ChannelFlags: 0, + Node: &LightningNode{ + PubKeyBytes: nodeB, + }, + } + inPolicy1 := &ChannelEdgePolicy{ + ChannelID: 1000, + ChannelFlags: 1, + Node: &LightningNode{ + PubKeyBytes: nodeA, + }, + } + node := &node{ + pubKey: nodeA, + features: lnwire.EmptyFeatureVector(), + edgeInfos: []*ChannelEdgeInfo{{ + ChannelID: 1000, + // Those are direction independent! 
+ NodeKey1Bytes: pubKey1, + NodeKey2Bytes: pubKey2, + Capacity: 500, + }}, + outPolicies: []*ChannelEdgePolicy{outPolicy1}, + inPolicies: []*ChannelEdgePolicy{inPolicy1}, + } + cache := NewGraphCache() + require.NoError(t, cache.AddNode(nil, node)) + + fromChannels := cache.nodeChannels[nodeA] + toChannels := cache.nodeChannels[nodeB] + + require.Len(t, fromChannels, 1) + require.Len(t, toChannels, 1) + + require.Equal(t, outPolicy1, fromChannels[0].OutPolicy) + require.Equal(t, inPolicy1, fromChannels[0].InPolicy) + + require.Equal(t, inPolicy1, toChannels[0].OutPolicy) + require.Equal(t, outPolicy1, toChannels[0].InPolicy) + } + runTest(pubKey1, pubKey2) + runTest(pubKey2, pubKey1) +} diff --git a/routing/graph.go b/routing/graph.go index be58698f46..578f480abf 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -2,7 +2,6 @@ package routing import ( "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -12,8 +11,7 @@ import ( type routingGraph interface { // forEachNodeChannel calls the callback for every channel of the given node. forEachNodeChannel(nodePub route.Vertex, - cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy) error) error + cb func(channel *channeldb.DirectedChannel) error) error // sourceNode returns the source node of the graph. sourceNode() route.Vertex @@ -26,7 +24,6 @@ type routingGraph interface { // database. type dbRoutingTx struct { graph *channeldb.ChannelGraph - tx kvdb.RTx source route.Vertex } @@ -38,37 +35,19 @@ func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) { return nil, err } - tx, err := graph.Database().BeginReadTx() - if err != nil { - return nil, err - } - return &dbRoutingTx{ graph: graph, - tx: tx, source: sourceNode.PubKeyBytes, }, nil } -// close closes the underlying db transaction. -func (g *dbRoutingTx) close() error { - return g.tx.Rollback() -} - // forEachNodeChannel calls the callback for every channel of the given node. // // NOTE: Part of the routingGraph interface. func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex, - cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy) error) error { - - txCb := func(_ kvdb.RTx, info *channeldb.ChannelEdgeInfo, - p1, p2 *channeldb.ChannelEdgePolicy) error { + cb func(channel *channeldb.DirectedChannel) error) error { - return cb(info, p1, p2) - } - - return g.graph.ForEachNodeChannel(g.tx, nodePub[:], txCb) + return g.graph.ForEachNodeChannel(nodePub, cb) } // sourceNode returns the source node of the graph. @@ -85,20 +64,5 @@ func (g *dbRoutingTx) sourceNode() route.Vertex { func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) ( *lnwire.FeatureVector, error) { - targetNode, err := g.graph.FetchLightningNode(nodePub) - switch err { - - // If the node exists and has features, return them directly. - case nil: - return targetNode.Features, nil - - // If we couldn't find a node announcement, populate a blank feature - // vector. - case channeldb.ErrGraphNodeNotFound: - return lnwire.EmptyFeatureVector(), nil - - // Otherwise bubble the error up. 
- default: - return nil, err - } + return g.graph.FetchNodeFeatures(nodePub) } diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index 3834d9e51a..badeeebb99 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -159,8 +159,7 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte, // // NOTE: Part of the routingGraph interface. func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, - cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, - *channeldb.ChannelEdgePolicy) error) error { + cb func(channel *channeldb.DirectedChannel) error) error { // Look up the mock node. node, ok := m.nodes[nodePub] @@ -171,36 +170,38 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, // Iterate over all of its channels. for peer, channel := range node.channels { // Lexicographically sort the pubkeys. - var node1, node2 route.Vertex + var node1 route.Vertex if bytes.Compare(nodePub[:], peer[:]) == -1 { - node1, node2 = peer, nodePub + node1 = peer } else { - node1, node2 = nodePub, peer + node1 = nodePub } peerNode := m.nodes[peer] // Call the per channel callback. err := cb( - &channeldb.ChannelEdgeInfo{ - NodeKey1Bytes: node1, - NodeKey2Bytes: node2, - }, - &channeldb.ChannelEdgePolicy{ + &channeldb.DirectedChannel{ ChannelID: channel.id, - Node: &channeldb.LightningNode{ - PubKeyBytes: peer, - Features: lnwire.EmptyFeatureVector(), + IsNode1: nodePub == node1, + OtherNode: peer, + Capacity: channel.capacity, + OutPolicy: &channeldb.ChannelEdgePolicy{ + ChannelID: channel.id, + Node: &channeldb.LightningNode{ + PubKeyBytes: peer, + Features: lnwire.EmptyFeatureVector(), + }, + FeeBaseMSat: node.baseFee, }, - FeeBaseMSat: node.baseFee, - }, - &channeldb.ChannelEdgePolicy{ - ChannelID: channel.id, - Node: &channeldb.LightningNode{ - PubKeyBytes: nodePub, - Features: lnwire.EmptyFeatureVector(), + InPolicy: &channeldb.ChannelEdgePolicy{ + ChannelID: channel.id, + Node: &channeldb.LightningNode{ + PubKeyBytes: nodePub, + Features: lnwire.EmptyFeatureVector(), + }, + FeeBaseMSat: peerNode.baseFee, }, - FeeBaseMSat: peerNode.baseFee, }, ) if err != nil { diff --git a/routing/pathfind.go b/routing/pathfind.go index 3d722c8221..fc3be79425 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -359,14 +359,12 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { var max, total lnwire.MilliSatoshi - cb := func(edgeInfo *channeldb.ChannelEdgeInfo, outEdge, - _ *channeldb.ChannelEdgePolicy) error { - - if outEdge == nil { + cb := func(channel *channeldb.DirectedChannel) error { + if channel.OutPolicy == nil { return nil } - chanID := outEdge.ChannelID + chanID := channel.ChannelID // Enforce outgoing channel restriction. if outgoingChans != nil { @@ -381,9 +379,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, // This can happen when a channel is added to the graph after // we've already queried the bandwidth hints. if !ok { - bandwidth = lnwire.NewMSatFromSatoshis( - edgeInfo.Capacity, - ) + bandwidth = lnwire.NewMSatFromSatoshis(channel.Capacity) } if bandwidth > max { @@ -889,7 +885,8 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Determine the next hop forward using the next map. currentNodeWithDist, ok := distance[currentNode] if !ok { - // If the node doesnt have a next hop it means we didn't find a path. 
+ // If the node doesn't have a next hop it means we + // didn't find a path. return nil, errNoPathFound } diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index d098429c1d..b353c24eae 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -304,6 +304,16 @@ func parseTestGraph(path string) (*testGraphInstance, error) { } } + aliasForNode := func(node route.Vertex) string { + for alias, pubKey := range aliasMap { + if pubKey == node { + return alias + } + } + + return "" + } + // With all the vertexes inserted, we can now insert the edges into the // test graph. for _, edge := range g.Edges { @@ -353,10 +363,17 @@ func parseTestGraph(path string) (*testGraphInstance, error) { return nil, err } + channelFlags := lnwire.ChanUpdateChanFlags(edge.ChannelFlags) + isUpdate1 := channelFlags&lnwire.ChanUpdateDirection == 0 + targetNode := edgeInfo.NodeKey1Bytes + if isUpdate1 { + targetNode = edgeInfo.NodeKey2Bytes + } + edgePolicy := &channeldb.ChannelEdgePolicy{ SigBytes: testSig.Serialize(), MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags), - ChannelFlags: lnwire.ChanUpdateChanFlags(edge.ChannelFlags), + ChannelFlags: channelFlags, ChannelID: edge.ChannelID, LastUpdate: testTime, TimeLockDelta: edge.Expiry, @@ -364,6 +381,10 @@ func parseTestGraph(path string) (*testGraphInstance, error) { MaxHTLC: lnwire.MilliSatoshi(edge.MaxHTLC), FeeBaseMSat: lnwire.MilliSatoshi(edge.FeeBaseMsat), FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate), + Node: &channeldb.LightningNode{ + Alias: aliasForNode(targetNode), + PubKeyBytes: targetNode, + }, } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { return nil, err @@ -635,6 +656,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( channelFlags |= lnwire.ChanUpdateDisabled } + node2Features := lnwire.EmptyFeatureVector() + if node2.testChannelPolicy != nil { + node2Features = node2.Features + } + edgePolicy := &channeldb.ChannelEdgePolicy{ SigBytes: testSig.Serialize(), MessageFlags: msgFlags, @@ -646,6 +672,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( MaxHTLC: node1.MaxHTLC, FeeBaseMSat: node1.FeeBaseMsat, FeeProportionalMillionths: node1.FeeRate, + Node: &channeldb.LightningNode{ + Alias: node2.Alias, + PubKeyBytes: node2Vertex, + Features: node2Features, + }, } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { return nil, err @@ -663,6 +694,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( } channelFlags |= lnwire.ChanUpdateDirection + node1Features := lnwire.EmptyFeatureVector() + if node1.testChannelPolicy != nil { + node1Features = node1.Features + } + edgePolicy := &channeldb.ChannelEdgePolicy{ SigBytes: testSig.Serialize(), MessageFlags: msgFlags, @@ -674,6 +710,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( MaxHTLC: node2.MaxHTLC, FeeBaseMSat: node2.FeeBaseMsat, FeeProportionalMillionths: node2.FeeRate, + Node: &channeldb.LightningNode{ + Alias: node1.Alias, + PubKeyBytes: node1Vertex, + Features: node1Features, + }, } if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { return nil, err @@ -2980,12 +3021,6 @@ func dbFindPath(graph *channeldb.ChannelGraph, if err != nil { return nil, err } - defer func() { - err := routingTx.close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }() return findPath( &graphParams{ diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 
8122ff7117..f080059099 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -47,12 +47,7 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { if err != nil { return nil, nil, err } - return routingTx, func() { - err := routingTx.close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }, nil + return routingTx, func() {}, nil } // NewPaymentSession creates a new payment session backed by the latest prune diff --git a/routing/router.go b/routing/router.go index 00fa4d316a..9864a991d6 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1756,12 +1756,6 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, if err != nil { return nil, err } - defer func() { - err := routingTx.close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }() path, err := findPath( &graphParams{ @@ -2763,12 +2757,6 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, if err != nil { return nil, err } - defer func() { - err := routingTx.close() - if err != nil { - log.Errorf("Error closing db tx: %v", err) - } - }() // Traverse hops backwards to accumulate fees in the running amounts. source := r.selfNode.PubKeyBytes diff --git a/routing/router_test.go b/routing/router_test.go index 510d18bf59..1633d3810d 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -1393,6 +1393,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { MinHTLC: 1, FeeBaseMSat: 10, FeeProportionalMillionths: 10000, + Node: &channeldb.LightningNode{ + PubKeyBytes: edge.NodeKey2Bytes, + }, } edgePolicy.ChannelFlags = 0 @@ -1409,6 +1412,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { MinHTLC: 1, FeeBaseMSat: 10, FeeProportionalMillionths: 10000, + Node: &channeldb.LightningNode{ + PubKeyBytes: edge.NodeKey1Bytes, + }, } edgePolicy.ChannelFlags = 1 @@ -1490,6 +1496,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { MinHTLC: 1, FeeBaseMSat: 10, FeeProportionalMillionths: 10000, + Node: &channeldb.LightningNode{ + PubKeyBytes: edge.NodeKey2Bytes, + }, } edgePolicy.ChannelFlags = 0 @@ -1505,6 +1514,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { MinHTLC: 1, FeeBaseMSat: 10, FeeProportionalMillionths: 10000, + Node: &channeldb.LightningNode{ + PubKeyBytes: edge.NodeKey1Bytes, + }, } edgePolicy.ChannelFlags = 1 diff --git a/routing/unified_policies.go b/routing/unified_policies.go index 0ff509382e..4a6e5e00ba 100644 --- a/routing/unified_policies.go +++ b/routing/unified_policies.go @@ -69,24 +69,18 @@ func (u *unifiedPolicies) addPolicy(fromNode route.Vertex, // addGraphPolicies adds all policies that are known for the toNode in the // graph. func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error { - cb := func(edgeInfo *channeldb.ChannelEdgeInfo, _, - inEdge *channeldb.ChannelEdgePolicy) error { - + cb := func(channel *channeldb.DirectedChannel) error { // If there is no edge policy for this candidate node, skip. // Note that we are searching backwards so this node would have // come prior to the pivot node in the route. - if inEdge == nil { + if channel.InPolicy == nil { return nil } - // The node on the other end of this channel is the from node. - fromNode, err := edgeInfo.OtherNodeKeyBytes(u.toNode[:]) - if err != nil { - return err - } - // Add this policy to the unified policies map. 
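 		// Note that the DirectedChannel already carries the
 		// counterparty as OtherNode, which replaces the removed
 		// OtherNodeKeyBytes lookup.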
- u.addPolicy(fromNode, inEdge, edgeInfo.Capacity) + u.addPolicy( + channel.OtherNode, channel.InPolicy, channel.Capacity, + ) return nil } From 15d3f62d5e1b2f90728ddefc5f53d42735b193af Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Tue, 21 Sep 2021 19:18:21 +0200 Subject: [PATCH 09/15] multi: use cache for source channels --- channeldb/graph.go | 1 + htlcswitch/switch.go | 11 +++++++++++ routing/payment_lifecycle_test.go | 4 ++-- routing/payment_session_source.go | 6 ++++-- routing/router.go | 26 +++++++++++++------------- routing/router_test.go | 6 +++--- server.go | 6 +++--- 7 files changed, 37 insertions(+), 23 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index e3ec83113f..8806bcff10 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -2315,6 +2315,7 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy, ) copy(fromNodePubKey[:], fromNode) copy(toNodePubKey[:], toNode) + // TODO(guggero): Fetch lightning nodes before updating the cache! graphCache.UpdatePolicy(edge, fromNodePubKey, toNodePubKey, isUpdate1) return isUpdate1, nil diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 17b4238573..5f02d406cf 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -2052,6 +2052,17 @@ func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) { return link, nil } +// GetLinkByShortID attempts to return the link which possesses the target short +// channel ID. +func (s *Switch) GetLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, + error) { + + s.indexMtx.RLock() + defer s.indexMtx.RUnlock() + + return s.getLinkByShortID(chanID) +} + // getLinkByShortID attempts to return the link which possesses the target // short channel ID. // diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 6a35170235..d233d8bdea 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -472,8 +472,8 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase, Payer: payer, ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, - QueryBandwidth: func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { - return lnwire.NewMSatFromSatoshis(e.Capacity) + QueryBandwidth: func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi { + return lnwire.NewMSatFromSatoshis(c.Capacity) }, NextPaymentID: func() (uint64, error) { next := atomic.AddUint64(&uniquePaymentID, 1) diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index f080059099..661d5861db 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -24,7 +24,7 @@ type SessionSource struct { // to be traversed. If the link isn't available, then a value of zero // should be returned. Otherwise, the current up to date knowledge of // the available bandwidth of the link should be returned. 
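	//
	// For illustration, the tests in this patch satisfy the new
	// signature with a simple capacity-based stand-in:
	//
	//	func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi {
	//		return lnwire.NewMSatFromSatoshis(c.Capacity)
	//	}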
- QueryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi + QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi // MissionControl is a shared memory of sorts that executions of payment // path finding use in order to remember which vertexes/edges were @@ -65,7 +65,9 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi, error) { - return generateBandwidthHints(sourceNode, m.QueryBandwidth) + return generateBandwidthHints( + sourceNode.PubKeyBytes, m.Graph, m.QueryBandwidth, + ) } session, err := newPaymentSession( diff --git a/routing/router.go b/routing/router.go index 9864a991d6..aa034eea02 100644 --- a/routing/router.go +++ b/routing/router.go @@ -339,7 +339,7 @@ type Config struct { // a value of zero should be returned. Otherwise, the current up to // date knowledge of the available bandwidth of the link should be // returned. - QueryBandwidth func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi + QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi // NextPaymentID is a method that guarantees to return a new, unique ID // each time it is called. This is used by the router to generate a @@ -1735,7 +1735,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, // We'll attempt to obtain a set of bandwidth hints that can help us // eliminate certain routes early on in the path finding process. bandwidthHints, err := generateBandwidthHints( - r.selfNode, r.cfg.QueryBandwidth, + r.selfNode.PubKeyBytes, r.cfg.Graph, r.cfg.QueryBandwidth, ) if err != nil { return nil, err @@ -2657,19 +2657,19 @@ func (r *ChannelRouter) MarkEdgeLive(chanID lnwire.ShortChannelID) error { // these hints allows us to reduce the number of extraneous attempts as we can // skip channels that are inactive, or just don't have enough bandwidth to // carry the payment. -func generateBandwidthHints(sourceNode *channeldb.LightningNode, - queryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi) (map[uint64]lnwire.MilliSatoshi, error) { +func generateBandwidthHints(sourceNode route.Vertex, graph *channeldb.ChannelGraph, + queryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi) ( + map[uint64]lnwire.MilliSatoshi, error) { // First, we'll collect the set of outbound edges from the target // source node. - var localChans []*channeldb.ChannelEdgeInfo - err := sourceNode.ForEachChannel(nil, func(tx kvdb.RTx, - edgeInfo *channeldb.ChannelEdgeInfo, - _, _ *channeldb.ChannelEdgePolicy) error { - - localChans = append(localChans, edgeInfo) - return nil - }) + var localChans []*channeldb.DirectedChannel + err := graph.ForEachNodeChannel( + sourceNode, func(channel *channeldb.DirectedChannel) error { + localChans = append(localChans, channel) + return nil + }, + ) if err != nil { return nil, err } @@ -2722,7 +2722,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // We'll attempt to obtain a set of bandwidth hints that helps us select // the best outgoing channel to use in case no outgoing channel is set. 
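	// As in FindRoute above, the hints now come from the shared graph
	// cache, keyed by the source's pubkey, rather than from a
	// *LightningNode tied to a database transaction.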
bandwidthHints, err := generateBandwidthHints( - r.selfNode, r.cfg.QueryBandwidth, + r.selfNode.PubKeyBytes, r.cfg.Graph, r.cfg.QueryBandwidth, ) if err != nil { return nil, err diff --git a/routing/router_test.go b/routing/router_test.go index 1633d3810d..d263ce7384 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -132,9 +132,9 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, sessionSource := &SessionSource{ Graph: graphInstance.graph, QueryBandwidth: func( - e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { + c *channeldb.DirectedChannel) lnwire.MilliSatoshi { - return lnwire.NewMSatFromSatoshis(e.Capacity) + return lnwire.NewMSatFromSatoshis(c.Capacity) }, PathFindingConfig: pathFindingConfig, MissionControl: mc, @@ -158,7 +158,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, QueryBandwidth: func( - e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { + e *channeldb.DirectedChannel) lnwire.MilliSatoshi { return lnwire.NewMSatFromSatoshis(e.Capacity) }, diff --git a/server.go b/server.go index 0b1afe4008..55e49d9f26 100644 --- a/server.go +++ b/server.go @@ -710,9 +710,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return nil, err } - queryBandwidth := func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { - cid := lnwire.NewChanIDFromOutPoint(&edge.ChannelPoint) - link, err := s.htlcSwitch.GetLink(cid) + queryBandwidth := func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi { + cid := lnwire.NewShortChanIDFromInt(c.ChannelID) + link, err := s.htlcSwitch.GetLinkByShortID(cid) if err != nil { // If the link isn't online, then we'll report // that it has zero bandwidth to the router. From 1d1c42f9bae49ac5d3dc24eba7a1451668e83464 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Tue, 21 Sep 2021 19:18:22 +0200 Subject: [PATCH 10/15] multi: use minimal policy in cache --- channeldb/graph.go | 1 - channeldb/graph_cache.go | 155 ++++++++- channeldb/graph_cache_test.go | 53 ++- channeldb/graph_test.go | 462 +++++++++++++++++++------ lnrpc/routerrpc/router_backend.go | 2 +- lnrpc/routerrpc/router_backend_test.go | 2 +- routing/heap.go | 2 +- routing/mock_graph_test.go | 17 +- routing/mock_test.go | 10 +- routing/pathfind.go | 28 +- routing/pathfind_test.go | 58 ++-- routing/payment_lifecycle.go | 2 +- routing/payment_session.go | 10 +- routing/payment_session_source.go | 13 +- routing/payment_session_test.go | 13 +- routing/router.go | 4 +- routing/router_test.go | 4 +- routing/unified_policies.go | 14 +- routing/unified_policies_test.go | 6 +- 19 files changed, 629 insertions(+), 227 deletions(-) diff --git a/channeldb/graph.go b/channeldb/graph.go index 8806bcff10..e3ec83113f 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -2315,7 +2315,6 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy, ) copy(fromNodePubKey[:], fromNode) copy(toNodePubKey[:], toNode) - // TODO(guggero): Fetch lightning nodes before updating the cache! graphCache.UpdatePolicy(edge, fromNodePubKey, toNodePubKey, isUpdate1) return isUpdate1, nil diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go index d1ec6dd2a1..f36d022fbf 100644 --- a/channeldb/graph_cache.go +++ b/channeldb/graph_cache.go @@ -32,6 +32,92 @@ type GraphCacheNode interface { *ChannelEdgePolicy) error) error } +// CachedEdgePolicy is a struct that only caches the information of a +// ChannelEdgePolicy that we actually use for pathfinding and therefore need to +// store in the cache. 
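+// Compared to a full ChannelEdgePolicy, it drops the signature, the
+// last-update timestamp and the embedded *LightningNode, which is what
+// keeps the per-edge memory footprint small.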
+type CachedEdgePolicy struct { + // ChannelID is the unique channel ID for the channel. The first 3 + // bytes are the block height, the next 3 the index within the block, + // and the last 2 bytes are the output index for the channel. + ChannelID uint64 + + // MessageFlags is a bitfield which indicates the presence of optional + // fields (like max_htlc) in the policy. + MessageFlags lnwire.ChanUpdateMsgFlags + + // ChannelFlags is a bitfield which signals the capabilities of the + // channel as well as the directed edge this update applies to. + ChannelFlags lnwire.ChanUpdateChanFlags + + // TimeLockDelta is the number of blocks this node will subtract from + // the expiry of an incoming HTLC. This value expresses the time buffer + // the node would like to HTLC exchanges. + TimeLockDelta uint16 + + // MinHTLC is the smallest value HTLC this node will forward, expressed + // in millisatoshi. + MinHTLC lnwire.MilliSatoshi + + // MaxHTLC is the largest value HTLC this node will forward, expressed + // in millisatoshi. + MaxHTLC lnwire.MilliSatoshi + + // FeeBaseMSat is the base HTLC fee that will be charged for forwarding + // ANY HTLC, expressed in mSAT's. + FeeBaseMSat lnwire.MilliSatoshi + + // FeeProportionalMillionths is the rate that the node will charge for + // HTLCs for each millionth of a satoshi forwarded. + FeeProportionalMillionths lnwire.MilliSatoshi + + // ToNodePubKey is a function that returns the to node of a policy. + // Since we only ever store the inbound policy, this is always the node + // that we query the channels for in ForEachChannel(). Therefore, we can + // save a lot of space by not storing this information in the memory and + // instead just set this function when we copy the policy from cache in + // ForEachChannel(). + ToNodePubKey func() route.Vertex + + // ToNodeFeatures are the to node's features. They are never set while + // the edge is in the cache, only on the copy that is returned in + // ForEachChannel(). + ToNodeFeatures *lnwire.FeatureVector +} + +// ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over +// the passed active payment channel. This value is currently computed as +// specified in BOLT07, but will likely change in the near future. +func (c *CachedEdgePolicy) ComputeFee( + amt lnwire.MilliSatoshi) lnwire.MilliSatoshi { + + return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts +} + +// ComputeFeeFromIncoming computes the fee to forward an HTLC given the incoming +// amount. +func (c *CachedEdgePolicy) ComputeFeeFromIncoming( + incomingAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi { + + return incomingAmt - divideCeil( + feeRateParts*(incomingAmt-c.FeeBaseMSat), + feeRateParts+c.FeeProportionalMillionths, + ) +} + +// NewCachedPolicy turns a full policy into a minimal one that can be cached. +func NewCachedPolicy(policy *ChannelEdgePolicy) *CachedEdgePolicy { + return &CachedEdgePolicy{ + ChannelID: policy.ChannelID, + MessageFlags: policy.MessageFlags, + ChannelFlags: policy.ChannelFlags, + TimeLockDelta: policy.TimeLockDelta, + MinHTLC: policy.MinHTLC, + MaxHTLC: policy.MaxHTLC, + FeeBaseMSat: policy.FeeBaseMSat, + FeeProportionalMillionths: policy.FeeProportionalMillionths, + } +} + // DirectedChannel is a type that stores the channel information as seen from // one side of the channel. type DirectedChannel struct { @@ -48,11 +134,35 @@ type DirectedChannel struct { // Capacity is the announced capacity of this channel in satoshis. 
Capacity btcutil.Amount - // OutPolicy is the outgoing policy from this node *to* the other node. - OutPolicy *ChannelEdgePolicy + // OutPolicySet is a boolean that indicates whether the node has an + // outgoing policy set. For pathfinding only the existence of the policy + // is important to know, not the actual content. + OutPolicySet bool // InPolicy is the incoming policy *from* the other node to this node. - InPolicy *ChannelEdgePolicy + // In path finding, we're walking backward from the destination to the + // source, so we're always interested in the edge that arrives to us + // from the other node. + InPolicy *CachedEdgePolicy +} + +// DeepCopy creates a deep copy of the channel, including the incoming policy. +func (c *DirectedChannel) DeepCopy() *DirectedChannel { + channelCopy := *c + + if channelCopy.InPolicy != nil { + inPolicyCopy := *channelCopy.InPolicy + channelCopy.InPolicy = &inPolicyCopy + + // The fields for the ToNode can be overwritten by the path + // finding algorithm, which is why we need a deep copy in the + // first place. So we always start out with nil values, just to + // be sure they don't contain any old data. + channelCopy.InPolicy.ToNodePubKey = nil + channelCopy.InPolicy.ToNodeFeatures = nil + } + + return &channelCopy } // GraphCache is a type that holds a minimal set of information of the public @@ -181,15 +291,6 @@ func (c *GraphCache) updateOrAddEdge(node route.Vertex, edge *DirectedChannel) { func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy, fromNode, toNode route.Vertex, edge1 bool) { - // If a policy's node is nil, we can't cache it yet as that would lead - // to problems in pathfinding. - if policy.Node == nil { - // TODO(guggero): Fix this problem! - log.Warnf("Cannot cache policy because of missing node (from "+ - "%x to %x)", fromNode[:], toNode[:]) - return - } - c.mtx.Lock() defer c.mtx.Unlock() @@ -209,17 +310,17 @@ func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy, fromNode, // This is node 1, and it is edge 1, so this is the outgoing // policy for node 1. case channel.IsNode1 && edge1: - channel.OutPolicy = policy + channel.OutPolicySet = true // This is node 2, and it is edge 2, so this is the outgoing // policy for node 2. case !channel.IsNode1 && !edge1: - channel.OutPolicy = policy + channel.OutPolicySet = true // The other two cases left mean it's the inbound policy for the // node. default: - channel.InPolicy = policy + channel.InPolicy = NewCachedPolicy(policy) } } @@ -303,8 +404,30 @@ func (c *GraphCache) ForEachChannel(node route.Vertex, return nil } + features, ok := c.nodeFeatures[node] + if !ok { + log.Warnf("Node %v has no features defined, falling back to "+ + "default feature vector for path finding", node) + + features = lnwire.EmptyFeatureVector() + } + + toNodeCallback := func() route.Vertex { + return node + } + for _, channel := range channels { - if err := cb(channel); err != nil { + // We need to copy the channel and policy to avoid it being + // updated in the cache if the path finding algorithm sets + // fields on it (currently only the ToNodeFeatures of the + // policy). 
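+		// Handing the cached struct out directly would let two
+		// concurrent queries race on those fields, since this method
+		// only holds the read lock.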
+ channelCopy := channel.DeepCopy() + if channelCopy.InPolicy != nil { + channelCopy.InPolicy.ToNodePubKey = toNodeCallback + channelCopy.InPolicy.ToNodeFeatures = features + } + + if err := cb(channelCopy); err != nil { return err } } diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go index 71967c68cb..57666e1eb9 100644 --- a/channeldb/graph_cache_test.go +++ b/channeldb/graph_cache_test.go @@ -63,18 +63,25 @@ func TestGraphCacheAddNode(t *testing.T) { runTest := func(nodeA, nodeB route.Vertex) { t.Helper() + channelFlagA, channelFlagB := 0, 1 + if nodeA == pubKey2 { + channelFlagA, channelFlagB = 1, 0 + } + outPolicy1 := &ChannelEdgePolicy{ ChannelID: 1000, - ChannelFlags: 0, + ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagA), Node: &LightningNode{ PubKeyBytes: nodeB, + Features: lnwire.EmptyFeatureVector(), }, } inPolicy1 := &ChannelEdgePolicy{ ChannelID: 1000, - ChannelFlags: 1, + ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagB), Node: &LightningNode{ PubKeyBytes: nodeA, + Features: lnwire.EmptyFeatureVector(), }, } node := &node{ @@ -93,18 +100,48 @@ func TestGraphCacheAddNode(t *testing.T) { cache := NewGraphCache() require.NoError(t, cache.AddNode(nil, node)) - fromChannels := cache.nodeChannels[nodeA] - toChannels := cache.nodeChannels[nodeB] + var fromChannels, toChannels []*DirectedChannel + _ = cache.ForEachChannel(nodeA, func(c *DirectedChannel) error { + fromChannels = append(fromChannels, c) + return nil + }) + _ = cache.ForEachChannel(nodeB, func(c *DirectedChannel) error { + toChannels = append(toChannels, c) + return nil + }) require.Len(t, fromChannels, 1) require.Len(t, toChannels, 1) - require.Equal(t, outPolicy1, fromChannels[0].OutPolicy) - require.Equal(t, inPolicy1, fromChannels[0].InPolicy) + require.Equal(t, outPolicy1 != nil, fromChannels[0].OutPolicySet) + assertCachedPolicyEqual(t, inPolicy1, fromChannels[0].InPolicy) - require.Equal(t, inPolicy1, toChannels[0].OutPolicy) - require.Equal(t, outPolicy1, toChannels[0].InPolicy) + require.Equal(t, inPolicy1 != nil, toChannels[0].OutPolicySet) + assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy) } + runTest(pubKey1, pubKey2) runTest(pubKey2, pubKey1) } + +func assertCachedPolicyEqual(t *testing.T, original *ChannelEdgePolicy, + cached *CachedEdgePolicy) { + + require.Equal(t, original.ChannelID, cached.ChannelID) + require.Equal(t, original.MessageFlags, cached.MessageFlags) + require.Equal(t, original.ChannelFlags, cached.ChannelFlags) + require.Equal(t, original.TimeLockDelta, cached.TimeLockDelta) + require.Equal(t, original.MinHTLC, cached.MinHTLC) + require.Equal(t, original.MaxHTLC, cached.MaxHTLC) + require.Equal(t, original.FeeBaseMSat, cached.FeeBaseMSat) + require.Equal( + t, original.FeeProportionalMillionths, + cached.FeeProportionalMillionths, + ) + require.Equal( + t, + route.Vertex(original.Node.PubKeyBytes), + cached.ToNodePubKey(), + ) + require.Equal(t, original.Node.Features, cached.ToNodeFeatures) +} diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index d2953a523b..e624105a36 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -42,7 +42,10 @@ var ( _, _ = testSig.R.SetString("63724406601629180062774974542967536251589935445068131219452686511677818569431", 10) _, _ = testSig.S.SetString("18801056069249825825291287104931333862866033135609736119018462340006816851118", 10) - testFeatures = lnwire.NewFeatureVector(nil, lnwire.Features) + testFeatures = lnwire.NewFeatureVector( + 
lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired), + lnwire.Features, + ) testPub = route.Vertex{2, 202, 4} ) @@ -146,6 +149,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { if err := graph.AddLightningNode(node); err != nil { t.Fatalf("unable to add node: %v", err) } + assertNodeInCache(t, graph, node, testFeatures) // Next, fetch the node from the database to ensure everything was // serialized properly. @@ -170,6 +174,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { if err := graph.DeleteLightningNode(testPub); err != nil { t.Fatalf("unable to delete node; %v", err) } + assertNodeNotInCache(t, graph, testPub) // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. @@ -200,6 +205,7 @@ func TestPartialNode(t *testing.T) { if err := graph.AddLightningNode(node); err != nil { t.Fatalf("unable to add node: %v", err) } + assertNodeInCache(t, graph, node, nil) // Next, fetch the node from the database to ensure everything was // serialized properly. @@ -232,6 +238,7 @@ func TestPartialNode(t *testing.T) { if err := graph.DeleteLightningNode(testPub); err != nil { t.Fatalf("unable to delete node: %v", err) } + assertNodeNotInCache(t, graph, testPub) // Finally, attempt to fetch the node again. This should fail as the // node should have been deleted from the database. @@ -390,6 +397,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { if err := graph.AddChannelEdge(&edgeInfo); err != nil { t.Fatalf("unable to create channel edge: %v", err) } + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo) // Ensure that both policies are returned as unknown (nil). _, e1, e2, err := graph.FetchChannelEdgesByID(chanID) @@ -405,6 +413,7 @@ func TestEdgeInsertionDeletion(t *testing.T) { if err := graph.DeleteChannelEdges(false, chanID); err != nil { t.Fatalf("unable to delete edge: %v", err) } + assertNoEdge(t, graph, chanID) // Ensure that any query attempts to lookup the delete channel edge are // properly deleted. @@ -544,6 +553,9 @@ func TestDisconnectBlockAtHeight(t *testing.T) { if err := graph.AddChannelEdge(&edgeInfo3); err != nil { t.Fatalf("unable to create channel edge: %v", err) } + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo) + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo2) + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo3) // Call DisconnectBlockAtHeight, which should prune every channel // that has a funding height of 'height' or greater. @@ -551,6 +563,9 @@ func TestDisconnectBlockAtHeight(t *testing.T) { if err != nil { t.Fatalf("unable to prune %v", err) } + assertNoEdge(t, graph, edgeInfo.ChannelID) + assertNoEdge(t, graph, edgeInfo2.ChannelID) + assertEdgeWithNoPoliciesInCache(t, graph, &edgeInfo3) // The two edges should have been removed. if len(removed) != 2 { @@ -769,6 +784,7 @@ func TestEdgeInfoUpdates(t *testing.T) { if err := graph.AddLightningNode(node1); err != nil { t.Fatalf("unable to add node: %v", err) } + assertNodeInCache(t, graph, node1, testFeatures) node2, err := createTestVertex(graph.db) if err != nil { t.Fatalf("unable to create test node: %v", err) @@ -776,6 +792,7 @@ func TestEdgeInfoUpdates(t *testing.T) { if err := graph.AddLightningNode(node2); err != nil { t.Fatalf("unable to add node: %v", err) } + assertNodeInCache(t, graph, node2, testFeatures) // Create an edge and add it to the db. 
edgeInfo, edge1, edge2 := createChannelEdge(graph.db, node1, node2) @@ -785,11 +802,13 @@ func TestEdgeInfoUpdates(t *testing.T) { if err := graph.UpdateEdgePolicy(edge1); err != ErrEdgeNotFound { t.Fatalf("expected ErrEdgeNotFound, got: %v", err) } + require.Len(t, graph.graphCache.nodeChannels, 0) // Add the edge info. if err := graph.AddChannelEdge(edgeInfo); err != nil { t.Fatalf("unable to create channel edge: %v", err) } + assertEdgeWithNoPoliciesInCache(t, graph, edgeInfo) chanID := edgeInfo.ChannelID outpoint := edgeInfo.ChannelPoint @@ -799,9 +818,11 @@ func TestEdgeInfoUpdates(t *testing.T) { if err := graph.UpdateEdgePolicy(edge1); err != nil { t.Fatalf("unable to update edge: %v", err) } + assertEdgeWithPolicyInCache(t, graph, edgeInfo, edge1, true) if err := graph.UpdateEdgePolicy(edge2); err != nil { t.Fatalf("unable to update edge: %v", err) } + assertEdgeWithPolicyInCache(t, graph, edgeInfo, edge2, false) // Check for existence of the edge within the database, it should be // found. @@ -856,6 +877,191 @@ func TestEdgeInfoUpdates(t *testing.T) { assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) } +func assertNodeInCache(t *testing.T, g *ChannelGraph, n *LightningNode, + expectedFeatures *lnwire.FeatureVector) { + + // Let's check the internal view first. + require.Equal( + t, expectedFeatures, g.graphCache.nodeFeatures[n.PubKeyBytes], + ) + + // The external view should reflect this as well. Except when we expect + // the features to be nil internally, we return an empty feature vector + // on the public interface instead. + if expectedFeatures == nil { + expectedFeatures = lnwire.EmptyFeatureVector() + } + features := g.graphCache.GetFeatures(n.PubKeyBytes) + require.Equal(t, expectedFeatures, features) +} + +func assertNodeNotInCache(t *testing.T, g *ChannelGraph, n route.Vertex) { + _, ok := g.graphCache.nodeFeatures[n] + require.False(t, ok) + + _, ok = g.graphCache.nodeChannels[n] + require.False(t, ok) + + // We should get the default features for this node. + features := g.graphCache.GetFeatures(n) + require.Equal(t, lnwire.EmptyFeatureVector(), features) +} + +func assertEdgeWithNoPoliciesInCache(t *testing.T, g *ChannelGraph, + e *ChannelEdgeInfo) { + + // Let's check the internal view first. + require.NotEmpty(t, g.graphCache.nodeChannels[e.NodeKey1Bytes]) + require.NotEmpty(t, g.graphCache.nodeChannels[e.NodeKey2Bytes]) + + expectedNode1Channel := &DirectedChannel{ + ChannelID: e.ChannelID, + IsNode1: true, + OtherNode: e.NodeKey2Bytes, + Capacity: e.Capacity, + OutPolicySet: false, + InPolicy: nil, + } + require.Contains( + t, g.graphCache.nodeChannels[e.NodeKey1Bytes], e.ChannelID, + ) + require.Equal( + t, expectedNode1Channel, + g.graphCache.nodeChannels[e.NodeKey1Bytes][e.ChannelID], + ) + + expectedNode2Channel := &DirectedChannel{ + ChannelID: e.ChannelID, + IsNode1: false, + OtherNode: e.NodeKey1Bytes, + Capacity: e.Capacity, + OutPolicySet: false, + InPolicy: nil, + } + require.Contains( + t, g.graphCache.nodeChannels[e.NodeKey2Bytes], e.ChannelID, + ) + require.Equal( + t, expectedNode2Channel, + g.graphCache.nodeChannels[e.NodeKey2Bytes][e.ChannelID], + ) + + // The external view should reflect this as well. 
+ var foundChannel *DirectedChannel + err := g.graphCache.ForEachChannel( + e.NodeKey1Bytes, func(c *DirectedChannel) error { + if c.ChannelID == e.ChannelID { + foundChannel = c + } + + return nil + }, + ) + require.NoError(t, err) + require.NotNil(t, foundChannel) + require.Equal(t, expectedNode1Channel, foundChannel) + + err = g.graphCache.ForEachChannel( + e.NodeKey2Bytes, func(c *DirectedChannel) error { + if c.ChannelID == e.ChannelID { + foundChannel = c + } + + return nil + }, + ) + require.NoError(t, err) + require.NotNil(t, foundChannel) + require.Equal(t, expectedNode2Channel, foundChannel) +} + +func assertNoEdge(t *testing.T, g *ChannelGraph, chanID uint64) { + // Make sure no channel in the cache has the given channel ID. If there + // are no channels at all, that is fine as well. + for _, channels := range g.graphCache.nodeChannels { + for _, channel := range channels { + require.NotEqual(t, channel.ChannelID, chanID) + } + } +} + +func assertEdgeWithPolicyInCache(t *testing.T, g *ChannelGraph, + e *ChannelEdgeInfo, p *ChannelEdgePolicy, policy1 bool) { + + // Check the internal state first. + c1, ok := g.graphCache.nodeChannels[e.NodeKey1Bytes][e.ChannelID] + require.True(t, ok) + + if policy1 { + require.True(t, c1.OutPolicySet) + } else { + require.NotNil(t, c1.InPolicy) + require.Equal( + t, p.FeeProportionalMillionths, + c1.InPolicy.FeeProportionalMillionths, + ) + } + + c2, ok := g.graphCache.nodeChannels[e.NodeKey2Bytes][e.ChannelID] + require.True(t, ok) + + if policy1 { + require.NotNil(t, c2.InPolicy) + require.Equal( + t, p.FeeProportionalMillionths, + c2.InPolicy.FeeProportionalMillionths, + ) + } else { + require.True(t, c2.OutPolicySet) + } + + // Now for both nodes make sure that the external view is also correct. + var ( + c1Ext *DirectedChannel + c2Ext *DirectedChannel + ) + require.NoError(t, g.graphCache.ForEachChannel( + e.NodeKey1Bytes, func(c *DirectedChannel) error { + c1Ext = c + + return nil + }, + )) + require.NoError(t, g.graphCache.ForEachChannel( + e.NodeKey2Bytes, func(c *DirectedChannel) error { + c2Ext = c + + return nil + }, + )) + + // Only compare the fields that are actually copied, then compare the + // values of the functions separately. + require.Equal(t, c1, c1Ext.DeepCopy()) + require.Equal(t, c2, c2Ext.DeepCopy()) + if policy1 { + require.Equal( + t, p.FeeProportionalMillionths, + c2Ext.InPolicy.FeeProportionalMillionths, + ) + require.Equal( + t, route.Vertex(e.NodeKey2Bytes), + c2Ext.InPolicy.ToNodePubKey(), + ) + require.Equal(t, testFeatures, c2Ext.InPolicy.ToNodeFeatures) + } else { + require.Equal( + t, p.FeeProportionalMillionths, + c1Ext.InPolicy.FeeProportionalMillionths, + ) + require.Equal( + t, route.Vertex(e.NodeKey1Bytes), + c1Ext.InPolicy.ToNodePubKey(), + ) + require.Equal(t, testFeatures, c1Ext.InPolicy.ToNodeFeatures) + } +} + func randEdgePolicy(chanID uint64, db kvdb.Backend) *ChannelEdgePolicy { update := prand.Int63() @@ -890,106 +1096,10 @@ func TestGraphTraversal(t *testing.T) { // We'd like to test some of the graph traversal capabilities within // the DB, so we'll create a series of fake nodes to insert into the - // graph. + // graph. And we'll create 5 channels between each node pair. 
const numNodes = 20 - nodes := make([]*LightningNode, numNodes) - nodeIndex := map[string]struct{}{} - for i := 0; i < numNodes; i++ { - node, err := createTestVertex(graph.db) - if err != nil { - t.Fatalf("unable to create node: %v", err) - } - - nodes[i] = node - nodeIndex[node.Alias] = struct{}{} - } - - // Add each of the nodes into the graph, they should be inserted - // without error. - for _, node := range nodes { - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - } - - // Iterate over each node as returned by the graph, if all nodes are - // reached, then the map created above should be empty. - err = graph.ForEachNode(func(_ kvdb.RTx, node *LightningNode) error { - delete(nodeIndex, node.Alias) - return nil - }) - if err != nil { - t.Fatalf("for each failure: %v", err) - } - if len(nodeIndex) != 0 { - t.Fatalf("all nodes not reached within ForEach") - } - - // Determine which node is "smaller", we'll need this in order to - // properly create the edges for the graph. - var firstNode, secondNode *LightningNode - if bytes.Compare(nodes[0].PubKeyBytes[:], nodes[1].PubKeyBytes[:]) == -1 { - firstNode = nodes[0] - secondNode = nodes[1] - } else { - firstNode = nodes[0] - secondNode = nodes[1] - } - - // Create 5 channels between the first two nodes we generated above. const numChannels = 5 - chanIndex := map[uint64]struct{}{} - for i := 0; i < numChannels; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - chanID := uint64(i + 1) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - edgeInfo := ChannelEdgeInfo{ - ChannelID: chanID, - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: op, - Capacity: 1000, - } - copy(edgeInfo.NodeKey1Bytes[:], nodes[0].PubKeyBytes[:]) - copy(edgeInfo.NodeKey2Bytes[:], nodes[1].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey1Bytes[:], nodes[0].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey2Bytes[:], nodes[1].PubKeyBytes[:]) - err := graph.AddChannelEdge(&edgeInfo) - if err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Create and add an edge with random data that points from - // node1 -> node2. - edge := randEdgePolicy(chanID, graph.db) - edge.ChannelFlags = 0 - edge.Node = secondNode - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // Create another random edge that points from node2 -> node1 - // this time. - edge = randEdgePolicy(chanID, graph.db) - edge.ChannelFlags = 1 - edge.Node = firstNode - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - chanIndex[chanID] = struct{}{} - } + chanIndex, nodeList := fillTestGraph(t, graph, numNodes, numChannels) // Iterate through all the known channels within the graph DB, once // again if the map is empty that indicates that all edges have @@ -1000,16 +1110,13 @@ func TestGraphTraversal(t *testing.T) { delete(chanIndex, ei.ChannelID) return nil }) - if err != nil { - t.Fatalf("for each failure: %v", err) - } - if len(chanIndex) != 0 { - t.Fatalf("all edges not reached within ForEach") - } + require.NoError(t, err) + require.Len(t, chanIndex, 0) // Finally, we want to test the ability to iterate over all the // outgoing channels for a particular node. 
numNodeChans := 0 + firstNode, secondNode := nodeList[0], nodeList[1] err = firstNode.ForEachChannel(nil, func(_ kvdb.RTx, _ *ChannelEdgeInfo, outEdge, inEdge *ChannelEdgePolicy) error { @@ -1034,13 +1141,148 @@ func TestGraphTraversal(t *testing.T) { numNodeChans++ return nil }) - if err != nil { - t.Fatalf("for each failure: %v", err) + require.NoError(t, err) + require.Equal(t, numChannels, numNodeChans) +} + +func TestGraphCacheTraversal(t *testing.T) { + t.Parallel() + + graph, cleanUp, err := MakeTestGraph() + defer cleanUp() + require.NoError(t, err) + + // We'd like to test some of the graph traversal capabilities within + // the DB, so we'll create a series of fake nodes to insert into the + // graph. And we'll create 5 channels between each node pair. + const numNodes = 20 + const numChannels = 5 + chanIndex, nodeList := fillTestGraph(t, graph, numNodes, numChannels) + + // Iterate through all the known channels within the graph DB, once + // again if the map is empty that indicates that all edges have + // properly been reached. + numNodeChans := 0 + for _, node := range nodeList { + err = graph.graphCache.ForEachChannel( + node.PubKeyBytes, func(d *DirectedChannel) error { + delete(chanIndex, d.ChannelID) + + if !d.OutPolicySet || d.InPolicy == nil { + return fmt.Errorf("channel policy not " + + "present") + } + + // The incoming edge should also indicate that + // it's pointing to the origin node. + inPolicyNodeKey := d.InPolicy.ToNodePubKey() + if !bytes.Equal( + inPolicyNodeKey[:], node.PubKeyBytes[:], + ) { + return fmt.Errorf("wrong outgoing edge") + } + + numNodeChans++ + + return nil + }, + ) + require.NoError(t, err) } - if numNodeChans != numChannels { - t.Fatalf("all edges for node not reached within ForEach: "+ - "expected %v, got %v", numChannels, numNodeChans) + require.Len(t, chanIndex, 0) + + // We count the channels for both nodes, so there should be double the + // amount now. Except for the very last node, that doesn't have any + // channels to make the loop easier in fillTestGraph(). + require.Equal(t, numChannels*2*(numNodes-1), numNodeChans) +} + +func fillTestGraph(t *testing.T, graph *ChannelGraph, numNodes, + numChannels int) (map[uint64]struct{}, []*LightningNode) { + + nodes := make([]*LightningNode, numNodes) + nodeIndex := map[string]struct{}{} + for i := 0; i < numNodes; i++ { + node, err := createTestVertex(graph.db) + require.NoError(t, err) + + nodes[i] = node + nodeIndex[node.Alias] = struct{}{} + } + + // Add each of the nodes into the graph, they should be inserted + // without error. + for _, node := range nodes { + require.NoError(t, graph.AddLightningNode(node)) + } + + // Iterate over each node as returned by the graph, if all nodes are + // reached, then the map created above should be empty. + err := graph.ForEachNode(func(_ kvdb.RTx, node *LightningNode) error { + delete(nodeIndex, node.Alias) + return nil + }) + require.NoError(t, err) + require.Len(t, nodeIndex, 0) + + // Create a number of channels between each of the node pairs generated + // above. This will result in numChannels*(numNodes-1) channels. 
+ chanIndex := map[uint64]struct{}{} + for n := 0; n < numNodes-1; n++ { + node1 := nodes[n] + node2 := nodes[n+1] + if bytes.Compare(node1.PubKeyBytes[:], node2.PubKeyBytes[:]) == -1 { + node1, node2 = node2, node1 + } + + for i := 0; i < numChannels; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + chanID := uint64((n << 4) + i + 1) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + edgeInfo := ChannelEdgeInfo{ + ChannelID: chanID, + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: op, + Capacity: 1000, + } + copy(edgeInfo.NodeKey1Bytes[:], node1.PubKeyBytes[:]) + copy(edgeInfo.NodeKey2Bytes[:], node2.PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey1Bytes[:], node1.PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey2Bytes[:], node2.PubKeyBytes[:]) + err := graph.AddChannelEdge(&edgeInfo) + require.NoError(t, err) + + // Create and add an edge with random data that points + // from node1 -> node2. + edge := randEdgePolicy(chanID, graph.db) + edge.ChannelFlags = 0 + edge.Node = node2 + edge.SigBytes = testSig.Serialize() + require.NoError(t, graph.UpdateEdgePolicy(edge)) + + // Create another random edge that points from + // node2 -> node1 this time. + edge = randEdgePolicy(chanID, graph.db) + edge.ChannelFlags = 1 + edge.Node = node1 + edge.SigBytes = testSig.Serialize() + require.NoError(t, graph.UpdateEdgePolicy(edge)) + + chanIndex[chanID] = struct{}{} + } } + + return chanIndex, nodes } func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash, diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 1814c6358c..d39ff7a7d5 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -55,7 +55,7 @@ type RouterBackend struct { FindRoute func(source, target route.Vertex, amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams, destCustomRecords record.CustomSet, - routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy, + routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, finalExpiry uint16) (*route.Route, error) MissionControl MissionControl diff --git a/lnrpc/routerrpc/router_backend_test.go b/lnrpc/routerrpc/router_backend_test.go index 26a44cbb96..1b05d5f81e 100644 --- a/lnrpc/routerrpc/router_backend_test.go +++ b/lnrpc/routerrpc/router_backend_test.go @@ -126,7 +126,7 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, findRoute := func(source, target route.Vertex, amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams, _ record.CustomSet, - routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy, + routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, finalExpiry uint16) (*route.Route, error) { if int64(amt) != amtSat*1000 { diff --git a/routing/heap.go b/routing/heap.go index f6869663cd..36563bb661 100644 --- a/routing/heap.go +++ b/routing/heap.go @@ -39,7 +39,7 @@ type nodeWithDist struct { weight int64 // nextHop is the edge this route comes from. - nextHop *channeldb.ChannelEdgePolicy + nextHop *channeldb.CachedEdgePolicy // routingInfoSize is the total size requirement for the payloads field // in the onion packet from this hop towards the final destination. 
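As a quick cross-check of the channel count asserted in TestGraphCacheTraversal above: fillTestGraph only connects adjacent node pairs, so the expected totals work out as in this small sketch (illustrative only, not part of the patch):

package main

import "fmt"

func main() {
	// Mirrors the constants used by TestGraphCacheTraversal.
	const numNodes, numChannels = 20, 5

	// fillTestGraph creates numChannels channels between each pair of
	// adjacent nodes, so nodes 0..18 each open 5 channels to their
	// successor: 95 distinct channels in total.
	distinct := numChannels * (numNodes - 1)

	// Walking the cache per node visits every channel once from each of
	// its two endpoints, which is the doubled figure the test requires.
	perNodeVisits := 2 * distinct

	fmt.Println(distinct, perNodeVisits) // 95 190
}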
diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index badeeebb99..d29c096fd7 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -186,20 +186,13 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, IsNode1: nodePub == node1, OtherNode: peer, Capacity: channel.capacity, - OutPolicy: &channeldb.ChannelEdgePolicy{ + OutPolicySet: true, + InPolicy: &channeldb.CachedEdgePolicy{ ChannelID: channel.id, - Node: &channeldb.LightningNode{ - PubKeyBytes: peer, - Features: lnwire.EmptyFeatureVector(), - }, - FeeBaseMSat: node.baseFee, - }, - InPolicy: &channeldb.ChannelEdgePolicy{ - ChannelID: channel.id, - Node: &channeldb.LightningNode{ - PubKeyBytes: nodePub, - Features: lnwire.EmptyFeatureVector(), + ToNodePubKey: func() route.Vertex { + return nodePub }, + ToNodeFeatures: lnwire.EmptyFeatureVector(), FeeBaseMSat: peerNode.baseFee, }, }, diff --git a/routing/mock_test.go b/routing/mock_test.go index 383f891856..a59ae2aa4d 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -173,13 +173,13 @@ func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi, } func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate, - _ *btcec.PublicKey, _ *channeldb.ChannelEdgePolicy) bool { + _ *btcec.PublicKey, _ *channeldb.CachedEdgePolicy) bool { return false } func (m *mockPaymentSessionOld) GetAdditionalEdgePolicy(_ *btcec.PublicKey, - _ uint64) *channeldb.ChannelEdgePolicy { + _ uint64) *channeldb.CachedEdgePolicy { return nil } @@ -637,17 +637,17 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, } func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, - pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool { + pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool { args := m.Called(msg, pubKey, policy) return args.Bool(0) } func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, - channelID uint64) *channeldb.ChannelEdgePolicy { + channelID uint64) *channeldb.CachedEdgePolicy { args := m.Called(pubKey, channelID) - return args.Get(0).(*channeldb.ChannelEdgePolicy) + return args.Get(0).(*channeldb.CachedEdgePolicy) } type mockControlTower struct { diff --git a/routing/pathfind.go b/routing/pathfind.go index fc3be79425..27a67ea7a3 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -42,7 +42,7 @@ const ( type pathFinder = func(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ( - []*channeldb.ChannelEdgePolicy, error) + []*channeldb.CachedEdgePolicy, error) var ( // DefaultAttemptCost is the default fixed virtual cost in path finding @@ -76,7 +76,7 @@ var ( // of the edge. type edgePolicyWithSource struct { sourceNode route.Vertex - edge *channeldb.ChannelEdgePolicy + edge *channeldb.CachedEdgePolicy } // finalHopParams encapsulates various parameters for route construction that @@ -102,7 +102,7 @@ type finalHopParams struct { // any feature vectors on all hops have been validated for transitive // dependencies. func newRoute(sourceVertex route.Vertex, - pathEdges []*channeldb.ChannelEdgePolicy, currentHeight uint32, + pathEdges []*channeldb.CachedEdgePolicy, currentHeight uint32, finalHop finalHopParams) (*route.Route, error) { var ( @@ -147,10 +147,10 @@ func newRoute(sourceVertex route.Vertex, supports := func(feature lnwire.FeatureBit) bool { // If this edge comes from router hints, the features // could be nil. 
- if edge.Node.Features == nil { + if edge.ToNodeFeatures == nil { return false } - return edge.Node.Features.HasFeature(feature) + return edge.ToNodeFeatures.HasFeature(feature) } // We start by assuming the node doesn't support TLV. We'll now @@ -225,7 +225,7 @@ func newRoute(sourceVertex route.Vertex, // each new hop such that, the final slice of hops will be in // the forwards order. currentHop := &route.Hop{ - PubKeyBytes: edge.Node.PubKeyBytes, + PubKeyBytes: edge.ToNodePubKey(), ChannelID: edge.ChannelID, AmtToForward: amtToForward, OutgoingTimeLock: outgoingTimeLock, @@ -280,7 +280,7 @@ type graphParams struct { // additionalEdges is an optional set of edges that should be // considered during path finding, that is not already found in the // channel graph. - additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy + additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy // bandwidthHints is an optional map from channels to bandwidths that // can be populated if the caller has a better estimate of the current @@ -360,7 +360,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, var max, total lnwire.MilliSatoshi cb := func(channel *channeldb.DirectedChannel) error { - if channel.OutPolicy == nil { + if !channel.OutPolicySet { return nil } @@ -412,7 +412,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, // available bandwidth. func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, - finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) { + finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { // Pathfinding can be a significant portion of the total payment // latency, especially on low-powered devices. Log several metrics to @@ -519,7 +519,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Build reverse lookup to find incoming edges. Needed because // search is taken place from target to source. for _, outgoingEdgePolicy := range outgoingEdgePolicies { - toVertex := outgoingEdgePolicy.Node.PubKeyBytes + toVertex := outgoingEdgePolicy.ToNodePubKey() incomingEdgePolicy := &edgePolicyWithSource{ sourceNode: vertex, edge: outgoingEdgePolicy, @@ -583,7 +583,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // satisfy our specific requirements. processEdge := func(fromVertex route.Vertex, fromFeatures *lnwire.FeatureVector, - edge *channeldb.ChannelEdgePolicy, toNodeDist *nodeWithDist) { + edge *channeldb.CachedEdgePolicy, toNodeDist *nodeWithDist) { edgesExpanded++ @@ -879,7 +879,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Use the distance map to unravel the forward path from source to // target. - var pathEdges []*channeldb.ChannelEdgePolicy + var pathEdges []*channeldb.CachedEdgePolicy currentNode := source for { // Determine the next hop forward using the next map. @@ -894,7 +894,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, pathEdges = append(pathEdges, currentNodeWithDist.nextHop) // Advance current node. - currentNode = currentNodeWithDist.nextHop.Node.PubKeyBytes + currentNode = currentNodeWithDist.nextHop.ToNodePubKey() // Check stop condition at the end of this loop. 
This prevents // breaking out too soon for self-payments that have target set @@ -915,7 +915,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // route construction does not care where the features are actually // taken from. In the future we may wish to do route construction within // findPath, and avoid using ChannelEdgePolicy altogether. - pathEdges[len(pathEdges)-1].Node.Features = features + pathEdges[len(pathEdges)-1].ToNodeFeatures = features log.Debugf("Found route: probability=%v, hops=%v, fee=%v", distance[source].probability, len(pathEdges), diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index b353c24eae..7c7c7586b5 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -1099,20 +1099,23 @@ func TestPathFindingWithAdditionalEdges(t *testing.T) { // Create the channel edge going from songoku to doge and include it in // our map of additional edges. - songokuToDoge := &channeldb.ChannelEdgePolicy{ - Node: doge, + songokuToDoge := &channeldb.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return doge.PubKeyBytes + }, + ToNodeFeatures: lnwire.EmptyFeatureVector(), ChannelID: 1337, FeeBaseMSat: 1, FeeProportionalMillionths: 1000, TimeLockDelta: 9, } - additionalEdges := map[route.Vertex][]*channeldb.ChannelEdgePolicy{ + additionalEdges := map[route.Vertex][]*channeldb.CachedEdgePolicy{ graph.aliasMap["songoku"]: {songokuToDoge}, } find := func(r *RestrictParams) ( - []*channeldb.ChannelEdgePolicy, error) { + []*channeldb.CachedEdgePolicy, error) { return dbFindPath( graph.graph, additionalEdges, nil, @@ -1179,14 +1182,13 @@ func TestNewRoute(t *testing.T) { createHop := func(baseFee lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi, bandwidth lnwire.MilliSatoshi, - timeLockDelta uint16) *channeldb.ChannelEdgePolicy { + timeLockDelta uint16) *channeldb.CachedEdgePolicy { - return &channeldb.ChannelEdgePolicy{ - Node: &channeldb.LightningNode{ - Features: lnwire.NewFeatureVector( - nil, nil, - ), + return &channeldb.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return route.Vertex{} }, + ToNodeFeatures: lnwire.NewFeatureVector(nil, nil), FeeProportionalMillionths: feeRate, FeeBaseMSat: baseFee, TimeLockDelta: timeLockDelta, @@ -1199,7 +1201,7 @@ func TestNewRoute(t *testing.T) { // hops is the list of hops (the route) that gets passed into // the call to newRoute. - hops []*channeldb.ChannelEdgePolicy + hops []*channeldb.CachedEdgePolicy // paymentAmount is the amount that is send into the route // indicated by hops. @@ -1248,7 +1250,7 @@ func TestNewRoute(t *testing.T) { // For a single hop payment, no fees are expected to be paid. name: "single hop", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(100, 1000, 1000000, 10), }, expectedFees: []lnwire.MilliSatoshi{0}, @@ -1261,7 +1263,7 @@ func TestNewRoute(t *testing.T) { // a fee to receive the payment. 
name: "two hop", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 1000, 1000000, 10), createHop(30, 1000, 1000000, 5), }, @@ -1276,7 +1278,7 @@ func TestNewRoute(t *testing.T) { name: "two hop tlv onion feature", destFeatures: tlvFeatures, paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 1000, 1000000, 10), createHop(30, 1000, 1000000, 5), }, @@ -1293,7 +1295,7 @@ func TestNewRoute(t *testing.T) { destFeatures: tlvPayAddrFeatures, paymentAddr: &testPaymentAddr, paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 1000, 1000000, 10), createHop(30, 1000, 1000000, 5), }, @@ -1313,7 +1315,7 @@ func TestNewRoute(t *testing.T) { // gets rounded down to 1. name: "three hop", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 10, 1000000, 10), createHop(0, 10, 1000000, 5), createHop(0, 10, 1000000, 3), @@ -1328,7 +1330,7 @@ func TestNewRoute(t *testing.T) { // because of the increase amount to forward. name: "three hop with fee carry over", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 10000, 1000000, 10), createHop(0, 10000, 1000000, 5), createHop(0, 10000, 1000000, 3), @@ -1343,7 +1345,7 @@ func TestNewRoute(t *testing.T) { // effect. name: "three hop with minimal fees for carry over", paymentAmount: 100000, - hops: []*channeldb.ChannelEdgePolicy{ + hops: []*channeldb.CachedEdgePolicy{ createHop(0, 10000, 1000000, 10), // First hop charges 0.1% so the second hop fee @@ -1367,7 +1369,7 @@ func TestNewRoute(t *testing.T) { // custom feature vector. 
if testCase.destFeatures != nil { finalHop := testCase.hops[len(testCase.hops)-1] - finalHop.Node.Features = testCase.destFeatures + finalHop.ToNodeFeatures = testCase.destFeatures } assertRoute := func(t *testing.T, route *route.Route) { @@ -1594,7 +1596,7 @@ func TestDestTLVGraphFallback(t *testing.T) { } find := func(r *RestrictParams, - target route.Vertex) ([]*channeldb.ChannelEdgePolicy, error) { + target route.Vertex) ([]*channeldb.CachedEdgePolicy, error) { return dbFindPath( ctx.graph, nil, nil, @@ -2325,16 +2327,16 @@ func TestPathFindSpecExample(t *testing.T) { } func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex, - path []*channeldb.ChannelEdgePolicy, nodeAliases ...string) { + path []*channeldb.CachedEdgePolicy, nodeAliases ...string) { if len(path) != len(nodeAliases) { t.Fatal("number of hops and number of aliases do not match") } for i, hop := range path { - if hop.Node.PubKeyBytes != aliasMap[nodeAliases[i]] { + if hop.ToNodePubKey() != aliasMap[nodeAliases[i]] { t.Fatalf("expected %v to be pos #%v in hop, instead "+ - "%v was", nodeAliases[i], i, hop.Node.Alias) + "%v was", nodeAliases[i], i, hop.ToNodePubKey()) } } } @@ -2985,7 +2987,7 @@ func (c *pathFindingTestContext) cleanup() { } func (c *pathFindingTestContext) findPath(target route.Vertex, - amt lnwire.MilliSatoshi) ([]*channeldb.ChannelEdgePolicy, + amt lnwire.MilliSatoshi) ([]*channeldb.CachedEdgePolicy, error) { return dbFindPath( @@ -2994,7 +2996,9 @@ func (c *pathFindingTestContext) findPath(target route.Vertex, ) } -func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy, expected []uint64) { +func (c *pathFindingTestContext) assertPath(path []*channeldb.CachedEdgePolicy, + expected []uint64) { + if len(path) != len(expected) { c.t.Fatalf("expected path of length %v, but got %v", len(expected), len(path)) @@ -3011,11 +3015,11 @@ func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy, // dbFindPath calls findPath after getting a db transaction from the database // graph. func dbFindPath(graph *channeldb.ChannelGraph, - additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy, + additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy, bandwidthHints map[uint64]lnwire.MilliSatoshi, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, - finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) { + finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { routingTx, err := newDbRoutingTx(graph) if err != nil { diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index aa856e7b2f..945a534666 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -898,7 +898,7 @@ func (p *shardHandler) handleFailureMessage(rt *route.Route, var ( isAdditionalEdge bool - policy *channeldb.ChannelEdgePolicy + policy *channeldb.CachedEdgePolicy ) // Before we apply the channel update, we need to decide whether the diff --git a/routing/payment_session.go b/routing/payment_session.go index 22e88090ba..d3024d3ff5 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -144,13 +144,13 @@ type PaymentSession interface { // a boolean to indicate whether the update has been applied without // error. 
UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, pubKey *btcec.PublicKey, - policy *channeldb.ChannelEdgePolicy) bool + policy *channeldb.CachedEdgePolicy) bool // GetAdditionalEdgePolicy uses the public key and channel ID to query // the ephemeral channel edge policy for additional edges. Returns a nil // if nothing found. GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, - channelID uint64) *channeldb.ChannelEdgePolicy + channelID uint64) *channeldb.CachedEdgePolicy } // paymentSession is used during an HTLC routings session to prune the local @@ -162,7 +162,7 @@ type PaymentSession interface { // loop if payment attempts take long enough. An additional set of edges can // also be provided to assist in reaching the payment's destination. type paymentSession struct { - additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy + additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error) @@ -403,7 +403,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // updates to the supplied policy. It returns a boolean to indicate whether // there's an error when applying the updates. func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, - pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool { + pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool { // Validate the message signature. if err := VerifyChannelUpdateSignature(msg, pubKey); err != nil { @@ -428,7 +428,7 @@ func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, // ephemeral channel edge policy for additional edges. Returns a nil if nothing // found. func (p *paymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, - channelID uint64) *channeldb.ChannelEdgePolicy { + channelID uint64) *channeldb.CachedEdgePolicy { target := route.NewVertex(pubKey) diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 661d5861db..fdfccd5f1f 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -93,9 +93,9 @@ func (m *SessionSource) NewPaymentSessionEmpty() PaymentSession { // RouteHintsToEdges converts a list of invoice route hints to an edge map that // can be passed into pathfinding. func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( - map[route.Vertex][]*channeldb.ChannelEdgePolicy, error) { + map[route.Vertex][]*channeldb.CachedEdgePolicy, error) { - edges := make(map[route.Vertex][]*channeldb.ChannelEdgePolicy) + edges := make(map[route.Vertex][]*channeldb.CachedEdgePolicy) // Traverse through all of the available hop hints and include them in // our edges map, indexed by the public key of the channel's starting @@ -125,9 +125,12 @@ func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( // Finally, create the channel edge from the hop hint // and add it to list of edges corresponding to the node // at the start of the channel. 
- edge := &channeldb.ChannelEdgePolicy{ - Node: endNode, - ChannelID: hopHint.ChannelID, + edge := &channeldb.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return endNode.PubKeyBytes + }, + ToNodeFeatures: lnwire.EmptyFeatureVector(), + ChannelID: hopHint.ChannelID, FeeBaseMSat: lnwire.MilliSatoshi( hopHint.FeeBaseMSat, ), diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index edc4515b50..bcfc3b0e9a 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -217,7 +217,7 @@ func TestRequestRoute(t *testing.T) { session.pathFinder = func( g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, - finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) { + finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { // We expect find path to receive a cltv limit excluding the // final cltv delta (including the block padding). @@ -225,13 +225,14 @@ func TestRequestRoute(t *testing.T) { t.Fatal("wrong cltv limit") } - path := []*channeldb.ChannelEdgePolicy{ + path := []*channeldb.CachedEdgePolicy{ { - Node: &channeldb.LightningNode{ - Features: lnwire.NewFeatureVector( - nil, nil, - ), + ToNodePubKey: func() route.Vertex { + return route.Vertex{} }, + ToNodeFeatures: lnwire.NewFeatureVector( + nil, nil, + ), }, } diff --git a/routing/router.go b/routing/router.go index aa034eea02..1de1130563 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1727,7 +1727,7 @@ type routingMsg struct { func (r *ChannelRouter) FindRoute(source, target route.Vertex, amt lnwire.MilliSatoshi, restrictions *RestrictParams, destCustomRecords record.CustomSet, - routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy, + routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, finalExpiry uint16) (*route.Route, error) { log.Debugf("Searching for path to %v, sending %v", target, amt) @@ -2822,7 +2822,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // total amount, we make a forward pass. Because the amount may have // been increased in the backward pass, fees need to be recalculated and // amount ranges re-checked. - var pathEdges []*channeldb.ChannelEdgePolicy + var pathEdges []*channeldb.CachedEdgePolicy receiverAmt := runningAmt for i, edge := range edges { policy := edge.getPolicy(receiverAmt, bandwidthHints) diff --git a/routing/router_test.go b/routing/router_test.go index d263ce7384..ed6bfdc6a0 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -2478,8 +2478,8 @@ func TestFindPathFeeWeighting(t *testing.T) { if len(path) != 1 { t.Fatalf("expected path length of 1, instead was: %v", len(path)) } - if path[0].Node.Alias != "luoji" { - t.Fatalf("wrong node: %v", path[0].Node.Alias) + if path[0].ToNodePubKey() != ctx.aliases["luoji"] { + t.Fatalf("wrong node: %v", path[0].ToNodePubKey()) } } diff --git a/routing/unified_policies.go b/routing/unified_policies.go index 4a6e5e00ba..fe7cc1ec48 100644 --- a/routing/unified_policies.go +++ b/routing/unified_policies.go @@ -40,7 +40,7 @@ func newUnifiedPolicies(sourceNode, toNode route.Vertex, // addPolicy adds a single channel policy. Capacity may be zero if unknown // (light clients). 
func (u *unifiedPolicies) addPolicy(fromNode route.Vertex, - edge *channeldb.ChannelEdgePolicy, capacity btcutil.Amount) { + edge *channeldb.CachedEdgePolicy, capacity btcutil.Amount) { localChan := fromNode == u.sourceNode @@ -92,7 +92,7 @@ func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error { // unifiedPolicyEdge is the individual channel data that is kept inside an // unifiedPolicy object. type unifiedPolicyEdge struct { - policy *channeldb.ChannelEdgePolicy + policy *channeldb.CachedEdgePolicy capacity btcutil.Amount } @@ -133,7 +133,7 @@ type unifiedPolicy struct { // specific amount to send. It differentiates between local and network // channels. func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, - bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { + bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy { if u.localChan { return u.getPolicyLocal(amt, bandwidthHints) @@ -145,10 +145,10 @@ func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, // getPolicyLocal returns the optimal policy to use for this local connection // given a specific amount to send. func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, - bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { + bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy { var ( - bestPolicy *channeldb.ChannelEdgePolicy + bestPolicy *channeldb.CachedEdgePolicy maxBandwidth lnwire.MilliSatoshi ) @@ -200,10 +200,10 @@ func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, // a specific amount to send. The goal is to return a policy that maximizes the // probability of a successful forward in a non-strict forwarding context. func (u *unifiedPolicy) getPolicyNetwork( - amt lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { + amt lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy { var ( - bestPolicy *channeldb.ChannelEdgePolicy + bestPolicy *channeldb.CachedEdgePolicy maxFee lnwire.MilliSatoshi maxTimelock uint16 ) diff --git a/routing/unified_policies_test.go b/routing/unified_policies_test.go index e89a3cb122..ac915f99a7 100644 --- a/routing/unified_policies_test.go +++ b/routing/unified_policies_test.go @@ -20,7 +20,7 @@ func TestUnifiedPolicies(t *testing.T) { u := newUnifiedPolicies(source, toNode, nil) // Add two channels between the pair of nodes. 
-	p1 := channeldb.ChannelEdgePolicy{
+	p1 := channeldb.CachedEdgePolicy{
 		FeeProportionalMillionths: 100000,
 		FeeBaseMSat:               30,
 		TimeLockDelta:             60,
@@ -28,7 +28,7 @@ func TestUnifiedPolicies(t *testing.T) {
 		MaxHTLC: 500,
 		MinHTLC: 100,
 	}
-	p2 := channeldb.ChannelEdgePolicy{
+	p2 := channeldb.CachedEdgePolicy{
 		FeeProportionalMillionths: 190000,
 		FeeBaseMSat:               10,
 		TimeLockDelta:             40,
@@ -39,7 +39,7 @@ func TestUnifiedPolicies(t *testing.T) {
 	u.addPolicy(fromNode, &p1, 7)
 	u.addPolicy(fromNode, &p2, 7)
 
-	checkPolicy := func(policy *channeldb.ChannelEdgePolicy,
+	checkPolicy := func(policy *channeldb.CachedEdgePolicy,
 		feeBase lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi,
 		timeLockDelta uint16) {

From bf27d05aa8b03b4581a9a917c4fed7a747699eac Mon Sep 17 00:00:00 2001
From: Oliver Gugger
Date: Tue, 21 Sep 2021 19:18:24 +0200
Subject: [PATCH 11/15] routing+server: use cached graph interface

---
 routing/graph.go                           | 23 +++++++++------
 routing/integrated_routing_context_test.go |  6 +---
 routing/mock_graph_test.go                 | 10 +++----
 routing/pathfind_test.go                   |  4 +--
 routing/payment_session.go                 | 19 ++++--------
 routing/payment_session_source.go          | 21 ++-----------
 routing/payment_session_test.go            |  8 ++---
 routing/router.go                          | 34 ++++++++++------------
 routing/router_test.go                     |  5 +++-
 server.go                                  |  6 +++-
 10 files changed, 56 insertions(+), 80 deletions(-)

diff --git a/routing/graph.go b/routing/graph.go
index 578f480abf..7e0ba65b2d 100644
--- a/routing/graph.go
+++ b/routing/graph.go
@@ -9,7 +9,8 @@ import (
 // routingGraph is an abstract interface that provides information about nodes
 // and edges to pathfinding.
 type routingGraph interface {
-	// forEachNodeChannel calls the callback for every channel of the given node.
+	// forEachNodeChannel calls the callback for every channel of the given
+	// node.
 	forEachNodeChannel(nodePub route.Vertex,
 		cb func(channel *channeldb.DirectedChannel) error) error
 
@@ -20,22 +21,26 @@
 	fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error)
 }
 
-// dbRoutingTx is a routingGraph implementation that retrieves from the
+// CachedGraph is a routingGraph implementation that retrieves from the
 // database.
-type dbRoutingTx struct {
+type CachedGraph struct {
 	graph  *channeldb.ChannelGraph
 	source route.Vertex
 }
 
-// newDbRoutingTx instantiates a new db-connected routing graph. It implictly
+// A compile time assertion to make sure CachedGraph implements the routingGraph
+// interface.
+var _ routingGraph = (*CachedGraph)(nil)
+
+// NewCachedGraph instantiates a new db-connected routing graph. It implicitly
 // instantiates a new read transaction.
-func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) {
	sourceNode, err := graph.SourceNode()
	if err != nil {
		return nil, err
	}
 
-	return &dbRoutingTx{
+	return &CachedGraph{
		graph:  graph,
		source: sourceNode.PubKeyBytes,
	}, nil
@@ -44,7 +49,7 @@ func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) {
 // forEachNodeChannel calls the callback for every channel of the given node.
 //
 // NOTE: Part of the routingGraph interface.
-func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex,
+func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex,
 	cb func(channel *channeldb.DirectedChannel) error) error {
 
 	return g.graph.ForEachNodeChannel(nodePub, cb)
@@ -53,7 +58,7 @@ func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex,
 // sourceNode returns the source node of the graph. 
// // NOTE: Part of the routingGraph interface. -func (g *dbRoutingTx) sourceNode() route.Vertex { +func (g *CachedGraph) sourceNode() route.Vertex { return g.source } @@ -61,7 +66,7 @@ func (g *dbRoutingTx) sourceNode() route.Vertex { // unknown, assume no additional features are supported. // // NOTE: Part of the routingGraph interface. -func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) ( +func (g *CachedGraph) fetchNodeFeatures(nodePub route.Vertex) ( *lnwire.FeatureVector, error) { return g.graph.FetchNodeFeatures(nodePub) diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index 114b2272e6..d13b1c4324 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -162,11 +162,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, } session, err := newPaymentSession( - &payment, getBandwidthHints, - func() (routingGraph, func(), error) { - return c.graph, func() {}, nil - }, - mc, c.pathFindingCfg, + &payment, getBandwidthHints, c.graph, mc, c.pathFindingCfg, ) if err != nil { c.t.Fatal(err) diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go index d29c096fd7..6d01566665 100644 --- a/routing/mock_graph_test.go +++ b/routing/mock_graph_test.go @@ -182,10 +182,10 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, // Call the per channel callback. err := cb( &channeldb.DirectedChannel{ - ChannelID: channel.id, - IsNode1: nodePub == node1, - OtherNode: peer, - Capacity: channel.capacity, + ChannelID: channel.id, + IsNode1: nodePub == node1, + OtherNode: peer, + Capacity: channel.capacity, OutPolicySet: true, InPolicy: &channeldb.CachedEdgePolicy{ ChannelID: channel.id, @@ -193,7 +193,7 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, return nodePub }, ToNodeFeatures: lnwire.EmptyFeatureVector(), - FeeBaseMSat: peerNode.baseFee, + FeeBaseMSat: peerNode.baseFee, }, }, ) diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 7c7c7586b5..d438de8248 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -3021,7 +3021,7 @@ func dbFindPath(graph *channeldb.ChannelGraph, source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { - routingTx, err := newDbRoutingTx(graph) + routingGraph, err := NewCachedGraph(graph) if err != nil { return nil, err } @@ -3030,7 +3030,7 @@ func dbFindPath(graph *channeldb.ChannelGraph, &graphParams{ additionalEdges: additionalEdges, bandwidthHints: bandwidthHints, - graph: routingTx, + graph: routingGraph, }, r, cfg, source, target, amt, finalHtlcExpiry, ) diff --git a/routing/payment_session.go b/routing/payment_session.go index d3024d3ff5..bbf9b6f96a 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -172,7 +172,7 @@ type paymentSession struct { pathFinder pathFinder - getRoutingGraph func() (routingGraph, func(), error) + routingGraph routingGraph // pathFindingConfig defines global parameters that control the // trade-off in path finding between fees and probabiity. @@ -193,7 +193,7 @@ type paymentSession struct { // newPaymentSession instantiates a new payment session. 
func newPaymentSession(p *LightningPayment, getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error), - getRoutingGraph func() (routingGraph, func(), error), + routingGraph routingGraph, missionControl MissionController, pathFindingConfig PathFindingConfig) ( *paymentSession, error) { @@ -209,7 +209,7 @@ func newPaymentSession(p *LightningPayment, getBandwidthHints: getBandwidthHints, payment: p, pathFinder: findPath, - getRoutingGraph: getRoutingGraph, + routingGraph: routingGraph, pathFindingConfig: pathFindingConfig, missionControl: missionControl, minShardAmt: DefaultShardMinAmt, @@ -287,29 +287,20 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, p.log.Debugf("pathfinding for amt=%v", maxAmt) - // Get a routing graph. - routingGraph, cleanup, err := p.getRoutingGraph() - if err != nil { - return nil, err - } - - sourceVertex := routingGraph.sourceNode() + sourceVertex := p.routingGraph.sourceNode() // Find a route for the current amount. path, err := p.pathFinder( &graphParams{ additionalEdges: p.additionalEdges, bandwidthHints: bandwidthHints, - graph: routingGraph, + graph: p.routingGraph, }, restrictions, &p.pathFindingConfig, sourceVertex, p.payment.Target, maxAmt, finalHtlcExpiry, ) - // Close routing graph. - cleanup() - switch { case err == errNoPathFound: // Don't split if this is a legacy payment without mpp diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index fdfccd5f1f..d688f98148 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -17,7 +17,7 @@ var _ PaymentSessionSource = (*SessionSource)(nil) type SessionSource struct { // Graph is the channel graph that will be used to gather metrics from // and also to carry out path finding queries. - Graph *channeldb.ChannelGraph + Graph routingGraph // QueryBandwidth is a method that allows querying the lower link layer // to determine the up to date available bandwidth at a prospective link @@ -40,16 +40,6 @@ type SessionSource struct { PathFindingConfig PathFindingConfig } -// getRoutingGraph returns a routing graph and a clean-up function for -// pathfinding. -func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { - routingTx, err := newDbRoutingTx(m.Graph) - if err != nil { - return nil, nil, err - } - return routingTx, func() {}, nil -} - // NewPaymentSession creates a new payment session backed by the latest prune // view from Mission Control. 
An optional set of routing hints can be provided // in order to populate additional edges to explore when finding a path to the @@ -57,21 +47,16 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) { func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( PaymentSession, error) { - sourceNode, err := m.Graph.SourceNode() - if err != nil { - return nil, err - } - getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi, error) { return generateBandwidthHints( - sourceNode.PubKeyBytes, m.Graph, m.QueryBandwidth, + m.Graph.sourceNode(), m.Graph, m.QueryBandwidth, ) } session, err := newPaymentSession( - p, getBandwidthHints, m.getRoutingGraph, + p, getBandwidthHints, m.Graph, m.MissionControl, m.PathFindingConfig, ) if err != nil { diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index bcfc3b0e9a..dae331f840 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -121,9 +121,7 @@ func TestUpdateAdditionalEdge(t *testing.T) { return nil, nil }, - func() (routingGraph, func(), error) { - return &sessionGraph{}, func() {}, nil - }, + &sessionGraph{}, &MissionControl{}, PathFindingConfig{}, ) @@ -203,9 +201,7 @@ func TestRequestRoute(t *testing.T) { return nil, nil }, - func() (routingGraph, func(), error) { - return &sessionGraph{}, func() {}, nil - }, + &sessionGraph{}, &MissionControl{}, PathFindingConfig{}, ) diff --git a/routing/router.go b/routing/router.go index 1de1130563..dd8a375a2c 100644 --- a/routing/router.go +++ b/routing/router.go @@ -406,6 +406,10 @@ type ChannelRouter struct { // when doing any path finding. selfNode *channeldb.LightningNode + // cachedGraph is an instance of routingGraph that caches the source node as + // well as the channel graph itself in memory. + cachedGraph routingGraph + // newBlocks is a channel in which new blocks connected to the end of // the main chain are sent over, and blocks updated after a call to // UpdateFilter. @@ -460,14 +464,17 @@ var _ ChannelGraphSource = (*ChannelRouter)(nil) // channel graph is a subset of the UTXO set) set, then the router will proceed // to fully sync to the latest state of the UTXO set. func New(cfg Config) (*ChannelRouter, error) { - selfNode, err := cfg.Graph.SourceNode() if err != nil { return nil, err } r := &ChannelRouter{ - cfg: &cfg, + cfg: &cfg, + cachedGraph: &CachedGraph{ + graph: cfg.Graph, + source: selfNode.PubKeyBytes, + }, networkUpdates: make(chan *routingMsg), topologyClients: make(map[uint64]*topologyClient), ntfnClientUpdates: make(chan *topologyClientUpdate), @@ -1735,7 +1742,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, // We'll attempt to obtain a set of bandwidth hints that can help us // eliminate certain routes early on in the path finding process. bandwidthHints, err := generateBandwidthHints( - r.selfNode.PubKeyBytes, r.cfg.Graph, r.cfg.QueryBandwidth, + r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth, ) if err != nil { return nil, err @@ -1752,16 +1759,11 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, // execute our path finding algorithm. 
finalHtlcExpiry := currentHeight + int32(finalExpiry) - routingTx, err := newDbRoutingTx(r.cfg.Graph) - if err != nil { - return nil, err - } - path, err := findPath( &graphParams{ additionalEdges: routeHints, bandwidthHints: bandwidthHints, - graph: routingTx, + graph: r.cachedGraph, }, restrictions, &r.cfg.PathFindingConfig, @@ -2657,14 +2659,14 @@ func (r *ChannelRouter) MarkEdgeLive(chanID lnwire.ShortChannelID) error { // these hints allows us to reduce the number of extraneous attempts as we can // skip channels that are inactive, or just don't have enough bandwidth to // carry the payment. -func generateBandwidthHints(sourceNode route.Vertex, graph *channeldb.ChannelGraph, +func generateBandwidthHints(sourceNode route.Vertex, graph routingGraph, queryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi) ( map[uint64]lnwire.MilliSatoshi, error) { // First, we'll collect the set of outbound edges from the target // source node. var localChans []*channeldb.DirectedChannel - err := graph.ForEachNodeChannel( + err := graph.forEachNodeChannel( sourceNode, func(channel *channeldb.DirectedChannel) error { localChans = append(localChans, channel) return nil @@ -2722,7 +2724,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // We'll attempt to obtain a set of bandwidth hints that helps us select // the best outgoing channel to use in case no outgoing channel is set. bandwidthHints, err := generateBandwidthHints( - r.selfNode.PubKeyBytes, r.cfg.Graph, r.cfg.QueryBandwidth, + r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth, ) if err != nil { return nil, err @@ -2752,12 +2754,6 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, runningAmt = *amt } - // Open a transaction to execute the graph queries in. - routingTx, err := newDbRoutingTx(r.cfg.Graph) - if err != nil { - return nil, err - } - // Traverse hops backwards to accumulate fees in the running amounts. source := r.selfNode.PubKeyBytes for i := len(hops) - 1; i >= 0; i-- { @@ -2776,7 +2772,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // known in the graph. 
u := newUnifiedPolicies(source, toNode, outgoingChans) - err := u.addGraphPolicies(routingTx) + err := u.addGraphPolicies(r.cachedGraph) if err != nil { return nil, err } diff --git a/routing/router_test.go b/routing/router_test.go index ed6bfdc6a0..4b5dd505f1 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -129,8 +129,11 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, ) require.NoError(t, err, "failed to create missioncontrol") + cachedGraph, err := NewCachedGraph(graphInstance.graph) + require.NoError(t, err) + sessionSource := &SessionSource{ - Graph: graphInstance.graph, + Graph: cachedGraph, QueryBandwidth: func( c *channeldb.DirectedChannel) lnwire.MilliSatoshi { diff --git a/server.go b/server.go index 55e49d9f26..5531a6b33a 100644 --- a/server.go +++ b/server.go @@ -776,8 +776,12 @@ func newServer(cfg *Config, listenAddrs []net.Addr, MinProbability: routingConfig.MinRouteProbability, } + cachedGraph, err := routing.NewCachedGraph(chanGraph) + if err != nil { + return nil, err + } paymentSessionSource := &routing.SessionSource{ - Graph: chanGraph, + Graph: cachedGraph, MissionControl: s.missionControl, QueryBandwidth: queryBandwidth, PathFindingConfig: pathFindingConfig, From a95a3728b56f849dcad88780e2462debba35c16a Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Tue, 21 Sep 2021 19:18:25 +0200 Subject: [PATCH 12/15] lnd+channeldb: pre-allocate cache size To avoid the channel map needing to be re-grown while we fill the cache initially, we might as well pre-allocate it with a somewhat sane value to decrease the number of grow events. --- channeldb/db.go | 2 +- channeldb/graph.go | 5 +++-- channeldb/graph_cache.go | 15 ++++++++++++--- channeldb/graph_cache_test.go | 2 +- channeldb/graph_test.go | 2 +- channeldb/options.go | 24 +++++++++++++++++++++--- lnd.go | 22 ++++++++++++++++++---- routing/pathfind_test.go | 2 +- 8 files changed, 58 insertions(+), 16 deletions(-) diff --git a/channeldb/db.go b/channeldb/db.go index 633639892c..17275c2922 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -290,7 +290,7 @@ func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB, var err error chanDB.graph, err = NewChannelGraph( backend, opts.RejectCacheSize, opts.ChannelCacheSize, - opts.BatchCommitInterval, + opts.BatchCommitInterval, opts.PreAllocCacheNumNodes, ) if err != nil { return nil, err diff --git a/channeldb/graph.go b/channeldb/graph.go index e3ec83113f..69c86d1f92 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -188,7 +188,8 @@ type ChannelGraph struct { // NewChannelGraph allocates a new ChannelGraph backed by a DB instance. The // returned instance has its own unique reject cache and channel cache. 
 func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
-	batchCommitInterval time.Duration) (*ChannelGraph, error) {
+	batchCommitInterval time.Duration,
+	preAllocCacheNumNodes int) (*ChannelGraph, error) {
 
 	if err := initChannelGraph(db); err != nil {
 		return nil, err
@@ -198,7 +199,7 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
 		db:          db,
 		rejectCache: newRejectCache(rejectCacheSize),
 		chanCache:   newChannelCache(chanCacheSize),
-		graphCache:  NewGraphCache(),
+		graphCache:  NewGraphCache(preAllocCacheNumNodes),
 	}
 	g.chanScheduler = batch.NewTimeScheduler(
 		db, &g.cacheMu, batchCommitInterval,
diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go
index f36d022fbf..bec44c3e56 100644
--- a/channeldb/graph_cache.go
+++ b/channeldb/graph_cache.go
@@ -175,10 +175,19 @@ type GraphCache struct {
 }
 
 // NewGraphCache creates a new graphCache.
-func NewGraphCache() *GraphCache {
+func NewGraphCache(preAllocNumNodes int) *GraphCache {
 	return &GraphCache{
-		nodeChannels: make(map[route.Vertex]map[uint64]*DirectedChannel),
-		nodeFeatures: make(map[route.Vertex]*lnwire.FeatureVector),
+		nodeChannels: make(
+			map[route.Vertex]map[uint64]*DirectedChannel,
+			// A channel connects two nodes, so we can look it up
+			// from both sides, meaning we get double the number of
+			// entries.
+			preAllocNumNodes*2,
+		),
+		nodeFeatures: make(
+			map[route.Vertex]*lnwire.FeatureVector,
+			preAllocNumNodes,
+		),
 	}
 }
 
diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go
index 57666e1eb9..09cfbf2373 100644
--- a/channeldb/graph_cache_test.go
+++ b/channeldb/graph_cache_test.go
@@ -97,7 +97,7 @@ func TestGraphCacheAddNode(t *testing.T) {
 		outPolicies: []*ChannelEdgePolicy{outPolicy1},
 		inPolicies:  []*ChannelEdgePolicy{inPolicy1},
 	}
-	cache := NewGraphCache()
+	cache := NewGraphCache(10)
 	require.NoError(t, cache.AddNode(nil, node))
 
 	var fromChannels, toChannels []*DirectedChannel
diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go
index e624105a36..b43dbb972e 100644
--- a/channeldb/graph_test.go
+++ b/channeldb/graph_test.go
@@ -75,7 +75,7 @@ func MakeTestGraph(modifiers ...OptionModifier) (*ChannelGraph, func(), error) {
 
 	graph, err := NewChannelGraph(
 		backend, opts.RejectCacheSize, opts.ChannelCacheSize,
-		opts.BatchCommitInterval,
+		opts.BatchCommitInterval, opts.PreAllocCacheNumNodes,
 	)
 	if err != nil {
 		backendCleanup()
diff --git a/channeldb/options.go b/channeldb/options.go
index ceb29bf7b3..ad22fa8ed2 100644
--- a/channeldb/options.go
+++ b/channeldb/options.go
@@ -17,6 +17,12 @@ const (
 	// in order to reply to gossip queries. This produces a cache size of
 	// around 40MB.
 	DefaultChannelCacheSize = 20000
+
+	// DefaultPreAllocCacheNumNodes is the default number of nodes we
+	// assume for mainnet when pre-allocating the graph cache. As of
+	// September 2021, there currently are 14k nodes in a strictly pruned
+	// graph, so we choose a number that is slightly higher.
+	DefaultPreAllocCacheNumNodes = 15000
 )
 
 // Options holds parameters for tuning and customizing a channeldb.DB.
@@ -35,6 +41,10 @@ type Options struct {
 	// wait before attempting to commit a pending set of updates.
 	BatchCommitInterval time.Duration
 
+	// PreAllocCacheNumNodes is the number of nodes we expect to be in the
+	// graph cache, so we can pre-allocate the map accordingly.
+	PreAllocCacheNumNodes int
+
 	// clock is the time source used by the database.
clock clock.Clock @@ -52,9 +62,10 @@ func DefaultOptions() Options { AutoCompactMinAge: kvdb.DefaultBoltAutoCompactMinAge, DBTimeout: kvdb.DefaultDBTimeout, }, - RejectCacheSize: DefaultRejectCacheSize, - ChannelCacheSize: DefaultChannelCacheSize, - clock: clock.NewDefaultClock(), + RejectCacheSize: DefaultRejectCacheSize, + ChannelCacheSize: DefaultChannelCacheSize, + PreAllocCacheNumNodes: DefaultPreAllocCacheNumNodes, + clock: clock.NewDefaultClock(), } } @@ -75,6 +86,13 @@ func OptionSetChannelCacheSize(n int) OptionModifier { } } +// OptionSetPreAllocCacheNumNodes sets the PreAllocCacheNumNodes to n. +func OptionSetPreAllocCacheNumNodes(n int) OptionModifier { + return func(o *Options) { + o.PreAllocCacheNumNodes = n + } +} + // OptionSetSyncFreelist allows the database to sync its freelist. func OptionSetSyncFreelist(b bool) OptionModifier { return func(o *Options) { diff --git a/lnd.go b/lnd.go index 7dcc7bf55e..23f9b835ea 100644 --- a/lnd.go +++ b/lnd.go @@ -22,6 +22,7 @@ import ( "sync" "time" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcutil" "github.com/btcsuite/btcwallet/wallet" @@ -1679,14 +1680,27 @@ func initializeDatabases(ctx context.Context, "instances") } - // Otherwise, we'll open two instances, one for the state we only need - // locally, and the other for things we want to ensure are replicated. - dbs.graphDB, err = channeldb.CreateWithBackend( - databaseBackends.GraphDB, + dbOptions := []channeldb.OptionModifier{ channeldb.OptionSetRejectCacheSize(cfg.Caches.RejectCacheSize), channeldb.OptionSetChannelCacheSize(cfg.Caches.ChannelCacheSize), channeldb.OptionSetBatchCommitInterval(cfg.DB.BatchCommitInterval), channeldb.OptionDryRunMigration(cfg.DryRunMigration), + } + + // We want to pre-allocate the channel graph cache according to what we + // expect for mainnet to speed up memory allocation. + if cfg.ActiveNetParams.Name == chaincfg.MainNetParams.Name { + dbOptions = append( + dbOptions, channeldb.OptionSetPreAllocCacheNumNodes( + channeldb.DefaultPreAllocCacheNumNodes, + ), + ) + } + + // Otherwise, we'll open two instances, one for the state we only need + // locally, and the other for things we want to ensure are replicated. + dbs.graphDB, err = channeldb.CreateWithBackend( + databaseBackends.GraphDB, dbOptions..., ) switch { // Give the DB a chance to dry run the migration. Since we know that diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index d438de8248..426faa0998 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -171,7 +171,7 @@ func makeTestGraph() (*channeldb.ChannelGraph, kvdb.Backend, func(), error) { opts := channeldb.DefaultOptions() graph, err := channeldb.NewChannelGraph( backend, opts.RejectCacheSize, opts.ChannelCacheSize, - opts.BatchCommitInterval, + opts.BatchCommitInterval, opts.PreAllocCacheNumNodes, ) if err != nil { cleanUp() From a5202a89e65679dcea074f8db91772034119ac66 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Tue, 21 Sep 2021 19:18:26 +0200 Subject: [PATCH 13/15] docs: add release notes --- docs/release-notes/release-notes-0.14.0.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/release-notes/release-notes-0.14.0.md b/docs/release-notes/release-notes-0.14.0.md index ee748ec1a8..9ceb4a2c17 100644 --- a/docs/release-notes/release-notes-0.14.0.md +++ b/docs/release-notes/release-notes-0.14.0.md @@ -59,6 +59,18 @@ in `lnd`, saving developer time and limiting the potential for bugs. 
 Instructions for enabling Postgres can be found in
 [docs/postgres.md](../postgres.md).
 
+### In-memory path finding
+
+Finding a path through the channel graph for sending a payment doesn't involve
+any database queries anymore. The [channel graph is now kept fully
+in-memory](https://github.com/lightningnetwork/lnd/pull/5642) for a massive
+performance boost when calling `QueryRoutes` or any of the `SendPayment`
+variants. Keeping the full graph in memory naturally comes with increased RAM
+usage. Users running `lnd` on low-memory systems are advised to run with the
+`routing.strictgraphpruning=true` configuration option that more aggressively
+removes zombie channels from the graph, reducing the number of channels that
+need to be kept in memory.
+
 ## Protocol Extensions
 
 ### Explicit Channel Negotiation

From 6240851f93b8e58f671733dbe8d6205bec4c2c17 Mon Sep 17 00:00:00 2001
From: Oliver Gugger
Date: Tue, 28 Sep 2021 13:23:02 +0200
Subject: [PATCH 14/15] channeldb: optimize memory usage of initial cache fill

With this commit we use an optimized version of the node iteration that
causes fewer memory allocations by loading only the parts of each graph
node that the cache actually needs.
---
 channeldb/graph.go      | 125 ++++++++++++++++++++++++++++++++++++----
 channeldb/graph_test.go | 103 ++++++++++++++++++++++++++++++++-
 2 files changed, 215 insertions(+), 13 deletions(-)

diff --git a/channeldb/graph.go b/channeldb/graph.go
index 69c86d1f92..68f4fb537a 100644
--- a/channeldb/graph.go
+++ b/channeldb/graph.go
@@ -211,8 +211,8 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
 	startTime := time.Now()
 	log.Debugf("Populating in-memory channel graph, this might take a " +
 		"while...")
-	err := g.ForEachNode(func(tx kvdb.RTx, node *LightningNode) error {
-		return g.graphCache.AddNode(tx, &graphCacheNode{node})
+	err := g.ForEachNodeCacheable(func(tx kvdb.RTx, node GraphCacheNode) error {
+		return g.graphCache.AddNode(tx, node)
 	})
 	if err != nil {
 		return nil, err
@@ -468,6 +468,47 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro
 	return kvdb.View(c.db, traversal, func() {})
 }
 
+// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
+// graph, executing the passed callback with each node encountered. If the
+// callback returns an error, then the transaction is aborted and the iteration
+// stops early.
+func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx,
+	GraphCacheNode) error) error {
+
+	traversal := func(tx kvdb.RTx) error {
+		// First grab the nodes bucket which stores the mapping from
+		// pubKey to node information.
+		nodes := tx.ReadBucket(nodeBucket)
+		if nodes == nil {
+			return ErrGraphNotFound
+		}
+
+		cacheableNode := newGraphCacheNode(route.Vertex{}, nil)
+		return nodes.ForEach(func(pubKey, nodeBytes []byte) error {
+			// If this is the source key, then we skip this
+			// iteration as the value for this key is a pubKey
+			// rather than raw node information.
+			if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 {
+				return nil
+			}
+
+			nodeReader := bytes.NewReader(nodeBytes)
+			err := deserializeLightningNodeCacheable(
+				nodeReader, cacheableNode,
+			)
+			if err != nil {
+				return err
+			}
+
+			// Execute the callback, the transaction will abort if
+			// this returns an error.
+			return cb(tx, cacheableNode)
+		})
+	}
+
+	return kvdb.View(c.db, traversal, func() {})
+}
+
 // SourceNode returns the source node of the graph. The source node is treated
 // as the center node within a star-graph.
This method may be used to kick off // a path finding algorithm in order to explore the reachability of another @@ -559,8 +600,10 @@ func (c *ChannelGraph) AddLightningNode(node *LightningNode, r := &batch.Request{ Update: func(tx kvdb.RwTx) error { - wNode := &graphCacheNode{node} - if err := c.graphCache.AddNode(tx, wNode); err != nil { + cNode := newGraphCacheNode( + node.PubKeyBytes, node.Features, + ) + if err := c.graphCache.AddNode(tx, cNode); err != nil { return err } @@ -2532,17 +2575,30 @@ func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) ( // graphCacheNode is a struct that wraps a LightningNode in a way that it can be // cached in the graph cache. type graphCacheNode struct { - lnNode *LightningNode + pubKeyBytes route.Vertex + features *lnwire.FeatureVector + + nodeScratch [8]byte +} + +// newGraphCacheNode returns a new cache optimized node. +func newGraphCacheNode(pubKey route.Vertex, + features *lnwire.FeatureVector) *graphCacheNode { + + return &graphCacheNode{ + pubKeyBytes: pubKey, + features: features, + } } // PubKey returns the node's public identity key. -func (w *graphCacheNode) PubKey() route.Vertex { - return w.lnNode.PubKeyBytes +func (n *graphCacheNode) PubKey() route.Vertex { + return n.pubKeyBytes } // Features returns the node's features. -func (w *graphCacheNode) Features() *lnwire.FeatureVector { - return w.lnNode.Features +func (n *graphCacheNode) Features() *lnwire.FeatureVector { + return n.features } // ForEachChannel iterates through all channels of this node, executing the @@ -2553,11 +2609,11 @@ func (w *graphCacheNode) Features() *lnwire.FeatureVector { // halted with the error propagated back up to the caller. // // Unknown policies are passed into the callback as nil values. -func (w *graphCacheNode) ForEachChannel(tx kvdb.RTx, +func (n *graphCacheNode) ForEachChannel(tx kvdb.RTx, cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { - return w.lnNode.ForEachChannel(tx, cb) + return nodeTraversal(tx, n.pubKeyBytes[:], nil, cb) } var _ GraphCacheNode = (*graphCacheNode)(nil) @@ -3865,6 +3921,53 @@ func fetchLightningNode(nodeBucket kvdb.RBucket, return deserializeLightningNode(nodeReader) } +func deserializeLightningNodeCacheable(r io.Reader, node *graphCacheNode) error { + // Always populate a feature vector, even if we don't have a node + // announcement and short circuit below. + node.features = lnwire.EmptyFeatureVector() + + // Skip ahead: + // - LastUpdate (8 bytes) + if _, err := r.Read(node.nodeScratch[:]); err != nil { + return err + } + + if _, err := io.ReadFull(r, node.pubKeyBytes[:]); err != nil { + return err + } + + // Read the node announcement flag. + if _, err := r.Read(node.nodeScratch[:2]); err != nil { + return err + } + hasNodeAnn := byteOrder.Uint16(node.nodeScratch[:2]) + + // The rest of the data is optional, and will only be there if we got a + // node announcement for this node. + if hasNodeAnn == 0 { + return nil + } + + // We did get a node announcement for this node, so we'll have the rest + // of the data available. 
+ var rgb uint8 + if err := binary.Read(r, byteOrder, &rgb); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &rgb); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &rgb); err != nil { + return err + } + + if _, err := wire.ReadVarString(r, 0); err != nil { + return err + } + + return node.features.Decode(r) +} + func deserializeLightningNode(r io.Reader) (LightningNode, error) { var ( node LightningNode diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index b43dbb972e..27b979842a 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -21,6 +21,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -1145,6 +1146,58 @@ func TestGraphTraversal(t *testing.T) { require.Equal(t, numChannels, numNodeChans) } +// TestGraphTraversalCacheable tests that the memory optimized node traversal is +// working correctly. +func TestGraphTraversalCacheable(t *testing.T) { + t.Parallel() + + graph, cleanUp, err := MakeTestGraph() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + // We'd like to test some of the graph traversal capabilities within + // the DB, so we'll create a series of fake nodes to insert into the + // graph. And we'll create 5 channels between the first two nodes. + const numNodes = 20 + const numChannels = 5 + chanIndex, _ := fillTestGraph(t, graph, numNodes, numChannels) + + // Create a map of all nodes with the iteration we know works (because + // it is tested in another test). + nodeMap := make(map[route.Vertex]struct{}) + err = graph.ForEachNode(func(tx kvdb.RTx, n *LightningNode) error { + nodeMap[n.PubKeyBytes] = struct{}{} + + return nil + }) + require.NoError(t, err) + require.Len(t, nodeMap, numNodes) + + // Iterate through all the known channels within the graph DB by + // iterating over each node, once again if the map is empty that + // indicates that all edges have properly been reached. + err = graph.ForEachNodeCacheable( + func(tx kvdb.RTx, node GraphCacheNode) error { + delete(nodeMap, node.PubKey()) + + return node.ForEachChannel( + tx, func(tx kvdb.RTx, info *ChannelEdgeInfo, + policy *ChannelEdgePolicy, + policy2 *ChannelEdgePolicy) error { + + delete(chanIndex, info.ChannelID) + return nil + }, + ) + }, + ) + require.NoError(t, err) + require.Len(t, nodeMap, 0) + require.Len(t, chanIndex, 0) +} + func TestGraphCacheTraversal(t *testing.T) { t.Parallel() @@ -1164,6 +1217,8 @@ func TestGraphCacheTraversal(t *testing.T) { // properly been reached. 
numNodeChans := 0 for _, node := range nodeList { + node := node + err = graph.graphCache.ForEachChannel( node.PubKeyBytes, func(d *DirectedChannel) error { delete(chanIndex, d.ChannelID) @@ -1197,7 +1252,7 @@ func TestGraphCacheTraversal(t *testing.T) { require.Equal(t, numChannels*2*(numNodes-1), numNodeChans) } -func fillTestGraph(t *testing.T, graph *ChannelGraph, numNodes, +func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, numChannels int) (map[uint64]struct{}, []*LightningNode) { nodes := make([]*LightningNode, numNodes) @@ -1237,7 +1292,7 @@ func fillTestGraph(t *testing.T, graph *ChannelGraph, numNodes, for i := 0; i < numChannels; i++ { txHash := sha256.Sum256([]byte{byte(i)}) - chanID := uint64((n << 4) + i + 1) + chanID := uint64((n << 8) + i + 1) op := wire.OutPoint{ Hash: txHash, Index: 0, @@ -3592,3 +3647,47 @@ func TestBatchedUpdateEdgePolicy(t *testing.T) { require.Nil(t, err) } } + +// BenchmarkForEachChannel is a benchmark test that measures the number of +// allocations and the total memory consumed by the full graph traversal. +func BenchmarkForEachChannel(b *testing.B) { + graph, cleanUp, err := MakeTestGraph() + require.Nil(b, err) + defer cleanUp() + + const numNodes = 100 + const numChannels = 4 + _, _ = fillTestGraph(b, graph, numNodes, numChannels) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var ( + totalCapacity btcutil.Amount + maxHTLCs lnwire.MilliSatoshi + ) + err := graph.ForEachNodeCacheable( + func(tx kvdb.RTx, n GraphCacheNode) error { + return n.ForEachChannel( + tx, func(tx kvdb.RTx, + info *ChannelEdgeInfo, + policy *ChannelEdgePolicy, + policy2 *ChannelEdgePolicy) error { + + // We need to do something with + // the data here, otherwise the + // compiler is going to optimize + // this away, and we get bogus + // results. + totalCapacity += info.Capacity + maxHTLCs += policy.MaxHTLC + maxHTLCs += policy2.MaxHTLC + + return nil + }, + ) + }, + ) + require.NoError(b, err) + } +} From 493262e253921cb9a6bb9a144ff10a087286d86e Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 29 Sep 2021 17:48:01 +0200 Subject: [PATCH 15/15] itest: fix flake in update_channel_status itest This commit fixes a flake in the channel status update itest that occurred if Carol got a channel edge update for a channel before it heard of the channel in the first place. To avoid that, we wait for Carol to sync her graph before sending out channel edge or policy updates. As always when we touch itest code, we bring the formatting and use of the require library up to date. --- lntest/itest/lnd_channel_graph_test.go | 112 +++++++++++-------------- 1 file changed, 50 insertions(+), 62 deletions(-) diff --git a/lntest/itest/lnd_channel_graph_test.go b/lntest/itest/lnd_channel_graph_test.go index c7040e02e2..747f018fd4 100644 --- a/lntest/itest/lnd_channel_graph_test.go +++ b/lntest/itest/lnd_channel_graph_test.go @@ -25,24 +25,20 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { ctxb := context.Background() // Create two fresh nodes and open a channel between them. 
- alice := net.NewNode( - t.t, "Alice", []string{ - "--minbackoff=10s", - "--chan-enable-timeout=1.5s", - "--chan-disable-timeout=3s", - "--chan-status-sample-interval=.5s", - }, - ) + alice := net.NewNode(t.t, "Alice", []string{ + "--minbackoff=10s", + "--chan-enable-timeout=1.5s", + "--chan-disable-timeout=3s", + "--chan-status-sample-interval=.5s", + }) defer shutdownAndAssert(net, t, alice) - bob := net.NewNode( - t.t, "Bob", []string{ - "--minbackoff=10s", - "--chan-enable-timeout=1.5s", - "--chan-disable-timeout=3s", - "--chan-status-sample-interval=.5s", - }, - ) + bob := net.NewNode(t.t, "Bob", []string{ + "--minbackoff=10s", + "--chan-enable-timeout=1.5s", + "--chan-disable-timeout=3s", + "--chan-status-sample-interval=.5s", + }) defer shutdownAndAssert(net, t, bob) // Connect Alice to Bob. @@ -55,36 +51,32 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { // being the sole funder of the channel. chanAmt := btcutil.Amount(100000) chanPoint := openChannelAndAssert( - t, net, alice, bob, - lntest.OpenChannelParams{ + t, net, alice, bob, lntest.OpenChannelParams{ Amt: chanAmt, }, ) // Wait for Alice and Bob to receive the channel edge from the // funding manager. - ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) + ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout) + defer cancel() err := alice.WaitForNetworkChannelOpen(ctxt, chanPoint) - if err != nil { - t.Fatalf("alice didn't see the alice->bob channel before "+ - "timeout: %v", err) - } + require.NoError(t.t, err, "alice didn't see the alice->bob channel") - ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) err = bob.WaitForNetworkChannelOpen(ctxt, chanPoint) - if err != nil { - t.Fatalf("bob didn't see the bob->alice channel before "+ - "timeout: %v", err) - } + require.NoError(t.t, err, "bob didn't see the alice->bob channel") - // Launch a node for Carol which will connect to Alice and Bob in - // order to receive graph updates. This will ensure that the - // channel updates are propagated throughout the network. + // Launch a node for Carol which will connect to Alice and Bob in order + // to receive graph updates. This will ensure that the channel updates + // are propagated throughout the network. carol := net.NewNode(t.t, "Carol", nil) defer shutdownAndAssert(net, t, carol) + // Connect both Alice and Bob to the new node Carol, so she can sync her + // graph. net.ConnectNodes(t.t, alice, carol) net.ConnectNodes(t.t, bob, carol) + waitForGraphSync(t, carol) // assertChannelUpdate checks that the required policy update has // happened on the given node. 
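The new `waitForGraphSync` call above is the core of this flake fix, but the
helper's definition falls outside the quoted hunks. A minimal sketch of such a
helper, assuming lnd's `lntest/wait` package and the `SyncedToGraph` flag
returned by `GetInfo` (the name and shape are inferred from the call site; the
exact body in the patch may differ):

```go
// waitForGraphSync waits until the given node reports that it is synced to
// the channel graph, failing the test if that doesn't happen before the
// default timeout.
func waitForGraphSync(t *harnessTest, node *lntest.HarnessNode) {
	t.t.Helper()

	err := wait.Predicate(func() bool {
		ctxt, cancel := context.WithTimeout(
			context.Background(), defaultTimeout,
		)
		defer cancel()

		// GetInfo exposes SyncedToGraph, which flips to true once the
		// node has caught up with gossip for the channel graph.
		resp, err := node.GetInfo(ctxt, &lnrpc.GetInfoRequest{})
		require.NoError(t.t, err)

		return resp.SyncedToGraph
	}, defaultTimeout)
	require.NoError(t.t, err)
}
```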
@@ -109,12 +101,11 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { ChanPoint: chanPoint, Action: action, } - ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout) + defer cancel() + _, err = node.RouterClient.UpdateChanStatus(ctxt, req) - if err != nil { - t.Fatalf("unable to call UpdateChanStatus for %s's node: %v", - node.Name(), err) - } + require.NoErrorf(t.t, err, "UpdateChanStatus") } // assertEdgeDisabled ensures that a given node has the correct @@ -122,26 +113,30 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { assertEdgeDisabled := func(node *lntest.HarnessNode, chanPoint *lnrpc.ChannelPoint, disabled bool) { - var predErr error - err = wait.Predicate(func() bool { + outPoint, err := lntest.MakeOutpoint(chanPoint) + require.NoError(t.t, err) + + err = wait.NoError(func() error { req := &lnrpc.ChannelGraphRequest{ IncludeUnannounced: true, } - ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout) + defer cancel() + chanGraph, err := node.DescribeGraph(ctxt, req) if err != nil { - predErr = fmt.Errorf("unable to query node %v's graph: %v", node, err) - return false + return fmt.Errorf("unable to query node %v's "+ + "graph: %v", node, err) } numEdges := len(chanGraph.Edges) if numEdges != 1 { - predErr = fmt.Errorf("expected to find 1 edge in the graph, found %d", numEdges) - return false + return fmt.Errorf("expected to find 1 edge in "+ + "the graph, found %d", numEdges) } edge := chanGraph.Edges[0] - if edge.ChanPoint != chanPoint.GetFundingTxidStr() { - predErr = fmt.Errorf("expected chan_point %v, got %v", - chanPoint.GetFundingTxidStr(), edge.ChanPoint) + if edge.ChanPoint != outPoint.String() { + return fmt.Errorf("expected chan_point %v, "+ + "got %v", outPoint, edge.ChanPoint) } var policy *lnrpc.RoutingPolicy if node.PubKeyStr == edge.Node1Pub { @@ -150,15 +145,14 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { policy = edge.Node2Policy } if disabled != policy.Disabled { - predErr = fmt.Errorf("expected policy.Disabled to be %v, "+ - "but policy was %v", disabled, policy) - return false + return fmt.Errorf("expected policy.Disabled "+ + "to be %v, but policy was %v", disabled, + policy) } - return true + + return nil }, defaultTimeout) - if err != nil { - t.Fatalf("%v", predErr) - } + require.NoError(t.t, err) } // When updating the state of the channel between Alice and Bob, we @@ -193,9 +187,7 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { // disconnections from automatically disabling the channel again // (we don't want to clutter the network with channels that are // falsely advertised as enabled when they don't work). - if err := net.DisconnectNodes(alice, bob); err != nil { - t.Fatalf("unable to disconnect Alice from Bob: %v", err) - } + require.NoError(t.t, net.DisconnectNodes(alice, bob)) expectedPolicy.Disabled = true assertChannelUpdate(alice, expectedPolicy) assertChannelUpdate(bob, expectedPolicy) @@ -217,9 +209,7 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { expectedPolicy.Disabled = true assertChannelUpdate(alice, expectedPolicy) - if err := net.DisconnectNodes(alice, bob); err != nil { - t.Fatalf("unable to disconnect Alice from Bob: %v", err) - } + require.NoError(t.t, net.DisconnectNodes(alice, bob)) // Bob sends a "Disabled = true" update upon detecting the // disconnect. 
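The `assertEdgeDisabled` rewrite above also moves from the old `wait.Predicate`
plus captured `predErr` idiom to `wait.NoError`, where every polling attempt
returns a descriptive error that is surfaced directly when the timeout fires. A
minimal self-contained sketch of the idiom, assuming only lnd's `lntest/wait`
package and the standard library:

```go
package main

import (
	"fmt"
	"time"

	"github.com/lightningnetwork/lnd/lntest/wait"
)

func main() {
	start := time.Now()

	// wait.NoError polls the closure until it returns nil or the timeout
	// expires; on timeout the most recent error is reported, so no
	// captured predErr variable is needed.
	err := wait.NoError(func() error {
		if time.Since(start) < 2*time.Second {
			return fmt.Errorf("graph not synced yet")
		}

		return nil
	}, 10*time.Second)
	if err != nil {
		fmt.Println("timed out:", err)
		return
	}

	fmt.Println("condition met")
}
```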
@@ -237,9 +227,7 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) { // note the asymmetry between manual enable and manual disable! assertEdgeDisabled(alice, chanPoint, true) - if err := net.DisconnectNodes(alice, bob); err != nil { - t.Fatalf("unable to disconnect Alice from Bob: %v", err) - } + require.NoError(t.t, net.DisconnectNodes(alice, bob)) // Bob sends a "Disabled = true" update upon detecting the // disconnect.
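One closing note on the pre-allocation in patch 12: Go maps grow by rehashing
into progressively larger bucket arrays, so filling an unsized map triggers
several intermediate allocations that a sized `make` avoids. A standalone
benchmark sketch (illustrative only, not part of the series; package and
function names are made up) shows the effect the `preAllocNumNodes` argument to
`NewGraphCache` is after:

```go
package graphcache

import "testing"

// numNodes mirrors DefaultPreAllocCacheNumNodes from the series.
const numNodes = 15000

func fill(m map[int]struct{}) {
	for i := 0; i < numNodes; i++ {
		m[i] = struct{}{}
	}
}

// BenchmarkGrow fills a map that starts at the default size, forcing the
// runtime to re-grow (and rehash) it repeatedly along the way.
func BenchmarkGrow(b *testing.B) {
	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		fill(make(map[int]struct{}))
	}
}

// BenchmarkPreAlloc sizes the bucket array up front, mirroring what
// NewGraphCache now does with its preAllocNumNodes argument.
func BenchmarkPreAlloc(b *testing.B) {
	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		fill(make(map[int]struct{}, numNodes))
	}
}
```

Running `go test -bench=. -benchmem` on such a file should show the pre-allocated
variant performing a small constant number of allocations, while the unsized
variant allocates on every growth step.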