Skip to content

Commit

Permalink
Merge pull request ipfs#405 from libp2p/feat/simplify-bootstrapping
Browse files Browse the repository at this point in the history
fix and simplify some bootstrapping logic
  • Loading branch information
Stebalien authored Nov 6, 2019
2 parents 4fd6498 + a33b0b9 commit 8ecf938
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 156 deletions.
18 changes: 9 additions & 9 deletions dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,10 @@ type IpfsDHT struct {

bucketSize int

bootstrapCfg opts.BootstrapConfig

triggerAutoBootstrap bool
triggerBootstrap chan struct{}
latestSelfWalk time.Time // the last time we looked-up our own peerID in the network
autoRefresh bool
rtRefreshQueryTimeout time.Duration
rtRefreshPeriod time.Duration
triggerRtRefresh chan struct{}
}

// Assert that IPFS assumptions about interfaces aren't broken. These aren't a
Expand All @@ -92,7 +91,9 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er
return nil, err
}
dht := makeDHT(ctx, h, cfg.Datastore, cfg.Protocols, cfg.BucketSize)
dht.bootstrapCfg = cfg.BootstrapConfig
dht.autoRefresh = cfg.RoutingTable.AutoRefresh
dht.rtRefreshPeriod = cfg.RoutingTable.RefreshPeriod
dht.rtRefreshQueryTimeout = cfg.RoutingTable.RefreshQueryTimeout

// register for network notifs.
dht.host.Network().Notify((*netNotifiee)(dht))
Expand All @@ -105,14 +106,13 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er

dht.proc.AddChild(dht.providers.Process())
dht.Validator = cfg.Validator
dht.triggerAutoBootstrap = cfg.TriggerAutoBootstrap

if !cfg.Client {
for _, p := range cfg.Protocols {
h.SetStreamHandler(p, dht.handleNewStream)
}
}
dht.startBootstrapping()
dht.startRefreshing()
return dht, nil
}

Expand Down Expand Up @@ -163,7 +163,7 @@ func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching, protocols []p
routingTable: rt,
protocols: protocols,
bucketSize: bucketSize,
triggerBootstrap: make(chan struct{}),
triggerRtRefresh: make(chan struct{}),
}

dht.ctx = dht.newContextWithLocalTags(ctx)
Expand Down
151 changes: 60 additions & 91 deletions dht_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,20 @@ package dht

import (
"context"
"fmt"
"strings"
"sync"
"time"

process "github.com/jbenet/goprocess"
processctx "github.com/jbenet/goprocess/context"
"github.com/libp2p/go-libp2p-core/routing"
"github.com/multiformats/go-multiaddr"
_ "github.com/multiformats/go-multiaddr-dns"
"github.com/pkg/errors"
)

var DefaultBootstrapPeers []multiaddr.Multiaddr

var minRTBootstrapThreshold = 4
// Minimum number of peers in the routing table. If we drop below this and we
// see a new peer, we trigger a bootstrap round.
var minRTRefreshThreshold = 4

func init() {
for _, s := range []string{
Expand All @@ -43,71 +41,53 @@ func init() {
}
}

// Start the bootstrap worker.
func (dht *IpfsDHT) startBootstrapping() error {
// Start the refresh worker.
func (dht *IpfsDHT) startRefreshing() error {
// scan the RT table periodically & do a random walk on k-buckets that haven't been queried since the given bucket period
dht.proc.Go(func(proc process.Process) {
ctx := processctx.OnClosingContext(proc)

scanInterval := time.NewTicker(dht.bootstrapCfg.BucketPeriod)
defer scanInterval.Stop()
refreshTicker := time.NewTicker(dht.rtRefreshPeriod)
defer refreshTicker.Stop()

// run bootstrap if option is set
if dht.triggerAutoBootstrap {
if err := dht.doBootstrap(ctx, true); err != nil {
logger.Warningf("bootstrap error: %s", err)
}
// refresh if option is set
if dht.autoRefresh {
dht.doRefresh(ctx)
} else {
// disable the "auto-bootstrap" ticker so that no more ticks are sent to this channel
scanInterval.Stop()
// disable the "auto-refresh" ticker so that no more ticks are sent to this channel
refreshTicker.Stop()
}

for {
select {
case now := <-scanInterval.C:
walkSelf := now.After(dht.latestSelfWalk.Add(dht.bootstrapCfg.SelfQueryInterval))
if err := dht.doBootstrap(ctx, walkSelf); err != nil {
logger.Warning("bootstrap error: %s", err)
}
case <-dht.triggerBootstrap:
logger.Infof("triggering a bootstrap: RT has %d peers", dht.routingTable.Size())
if err := dht.doBootstrap(ctx, true); err != nil {
logger.Warning("bootstrap error: %s", err)
}
case <-refreshTicker.C:
case <-dht.triggerRtRefresh:
logger.Infof("triggering a refresh: RT has %d peers", dht.routingTable.Size())
case <-ctx.Done():
return
}
dht.doRefresh(ctx)
}
})

return nil
}

func (dht *IpfsDHT) doBootstrap(ctx context.Context, walkSelf bool) error {
if walkSelf {
if err := dht.selfWalk(ctx); err != nil {
return fmt.Errorf("self walk: error: %s", err)
}
dht.latestSelfWalk = time.Now()
}

if err := dht.bootstrapBuckets(ctx); err != nil {
return fmt.Errorf("bootstrap buckets: error bootstrapping: %s", err)
}

return nil
func (dht *IpfsDHT) doRefresh(ctx context.Context) {
dht.selfWalk(ctx)
dht.refreshBuckets(ctx)
}

// bootstrapBuckets scans the routing table, and does a random walk on k-buckets that haven't been queried since the given bucket period
func (dht *IpfsDHT) bootstrapBuckets(ctx context.Context) error {
// refreshBuckets scans the routing table, and does a random walk on k-buckets that haven't been queried since the given bucket period
func (dht *IpfsDHT) refreshBuckets(ctx context.Context) {
doQuery := func(bucketId int, target string, f func(context.Context) error) error {
logger.Infof("starting bootstrap query for bucket %d to %s (routing table size was %d)",
logger.Infof("starting refreshing bucket %d to %s (routing table size was %d)",
bucketId, target, dht.routingTable.Size())
defer func() {
logger.Infof("finished bootstrap query for bucket %d to %s (routing table size is now %d)",
logger.Infof("finished refreshing bucket %d to %s (routing table size is now %d)",
bucketId, target, dht.routingTable.Size())
}()
queryCtx, cancel := context.WithTimeout(ctx, dht.bootstrapCfg.Timeout)
queryCtx, cancel := context.WithTimeout(ctx, dht.rtRefreshQueryTimeout)
defer cancel()
err := f(queryCtx)
if err == context.DeadlineExceeded && queryCtx.Err() == context.DeadlineExceeded && ctx.Err() == nil {
Expand All @@ -117,69 +97,58 @@ func (dht *IpfsDHT) bootstrapBuckets(ctx context.Context) error {
}

buckets := dht.routingTable.GetAllBuckets()
var wg sync.WaitGroup
errChan := make(chan error)

if len(buckets) > 16 {
// Don't bother bootstrapping more than 16 buckets.
// GenRandPeerID can't generate target peer IDs with more than
// 16 bits specified anyways.
buckets = buckets[:16]
}
for bucketID, bucket := range buckets {
if time.Since(bucket.RefreshedAt()) > dht.bootstrapCfg.BucketPeriod {
wg.Add(1)
go func(bucketID int, errChan chan<- error) {
defer wg.Done()
// gen rand peer in the bucket
randPeerInBucket := dht.routingTable.GenRandPeerID(bucketID)

// walk to the generated peer
walkFnc := func(c context.Context) error {
_, err := dht.FindPeer(ctx, randPeerInBucket)
if err == routing.ErrNotFound {
return nil
}
return err
}

if err := doQuery(bucketID, randPeerInBucket.String(), walkFnc); err != nil {
errChan <- errors.Wrapf(err, "failed to do a random walk on bucket %d", bucketID)
}
}(bucketID, errChan)
if time.Since(bucket.RefreshedAt()) <= dht.rtRefreshPeriod {
continue
}
// gen rand peer in the bucket
randPeerInBucket := dht.routingTable.GenRandPeerID(bucketID)

// walk to the generated peer
walkFnc := func(c context.Context) error {
_, err := dht.FindPeer(c, randPeerInBucket)
if err == routing.ErrNotFound {
return nil
}
return err
}
}

// wait for all walks to finish & close the error channel
go func() {
wg.Wait()
close(errChan)
}()

// accumulate errors from all go-routines. ensures wait group is completed by reading errChan until closure.
var errStrings []string
for err := range errChan {
errStrings = append(errStrings, err.Error())
}
if len(errStrings) == 0 {
return nil
} else {
return fmt.Errorf("errors encountered while running bootstrap on RT:\n%s", strings.Join(errStrings, "\n"))
if err := doQuery(bucketID, randPeerInBucket.String(), walkFnc); err != nil {
logger.Warningf("failed to do a random walk on bucket %d: %s", bucketID, err)
}
}
}

// Traverse the DHT toward the self ID
func (dht *IpfsDHT) selfWalk(ctx context.Context) error {
queryCtx, cancel := context.WithTimeout(ctx, dht.bootstrapCfg.Timeout)
func (dht *IpfsDHT) selfWalk(ctx context.Context) {
queryCtx, cancel := context.WithTimeout(ctx, dht.rtRefreshQueryTimeout)
defer cancel()
_, err := dht.FindPeer(queryCtx, dht.self)
if err == routing.ErrNotFound {
return nil
return
}
return err
logger.Warningf("failed to query self during routing table refresh: %s", err)
}

// Bootstrap tells the DHT to get into a bootstrapped state.
// Bootstrap tells the DHT to get into a bootstrapped state satisfying the
// IpfsRouter interface.
//
// Note: the context is ignored.
// This just calls `RefreshRoutingTable`.
func (dht *IpfsDHT) Bootstrap(_ context.Context) error {
dht.RefreshRoutingTable()
return nil
}

// RefreshRoutingTable tells the DHT to refresh it's routing tables.
func (dht *IpfsDHT) RefreshRoutingTable() {
select {
case dht.triggerBootstrap <- struct{}{}:
case dht.triggerRtRefresh <- struct{}{}:
default:
}
return nil
}
22 changes: 11 additions & 11 deletions dht_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func setupDHT(ctx context.Context, t *testing.T, client bool) *IpfsDHT {
bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)),
opts.Client(client),
opts.NamespacedValidator("v", blankValidator{}),
opts.DisableAutoBootstrap(),
opts.DisableAutoRefresh(),
)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -191,7 +191,7 @@ func bootstrap(t *testing.T, ctx context.Context, dhts []*IpfsDHT) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

logger.Debugf("Bootstrapping DHTs...")
logger.Debugf("refreshing DHTs routing tables...")

// tried async. sequential fares much better. compare:
// 100 async https://gist.github.com/jbenet/56d12f0578d5f34810b2
Expand All @@ -201,7 +201,7 @@ func bootstrap(t *testing.T, ctx context.Context, dhts []*IpfsDHT) {
start := rand.Intn(len(dhts)) // randomize to decrease bias.
for i := range dhts {
dht := dhts[(start+i)%len(dhts)]
dht.Bootstrap(ctx)
dht.RefreshRoutingTable()
}
}

Expand Down Expand Up @@ -639,7 +639,7 @@ func printRoutingTables(dhts []*IpfsDHT) {
}
}

func TestBootstrap(t *testing.T) {
func TestRefresh(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
Expand Down Expand Up @@ -689,7 +689,7 @@ func TestBootstrap(t *testing.T) {
}
}

func TestBootstrapBelowMinRTThreshold(t *testing.T) {
func TestRefreshBelowMinRTThreshold(t *testing.T) {
ctx := context.Background()

// enable auto bootstrap on A
Expand Down Expand Up @@ -721,7 +721,7 @@ func TestBootstrapBelowMinRTThreshold(t *testing.T) {
connect(t, ctx, dhtB, dhtC)

// we ONLY init bootstrap on A
dhtA.Bootstrap(ctx)
dhtA.RefreshRoutingTable()
// and wait for one round to complete i.e. A should be connected to both B & C
waitForWellFormedTables(t, []*IpfsDHT{dhtA}, 2, 2, 20*time.Second)

Expand Down Expand Up @@ -749,7 +749,7 @@ func TestBootstrapBelowMinRTThreshold(t *testing.T) {
assert.Equal(t, dhtE.self, dhtA.routingTable.Find(dhtE.self), "A's routing table should have peer E!")
}

func TestPeriodicBootstrap(t *testing.T) {
func TestPeriodicRefresh(t *testing.T) {
if ci.IsRunning() {
t.Skip("skipping on CI. highly timing dependent")
}
Expand Down Expand Up @@ -795,7 +795,7 @@ func TestPeriodicBootstrap(t *testing.T) {

t.Logf("bootstrapping them so they find each other. %d", nDHTs)
for _, dht := range dhts {
go dht.Bootstrap(ctx)
dht.RefreshRoutingTable()
}

// this is async, and we dont know when it's finished with one cycle, so keep checking
Expand Down Expand Up @@ -1428,7 +1428,7 @@ func TestGetSetPluggedProtocol(t *testing.T) {
opts.Protocols("/esh/dht"),
opts.Client(false),
opts.NamespacedValidator("v", blankValidator{}),
opts.DisableAutoBootstrap(),
opts.DisableAutoRefresh(),
}

dhtA, err := New(ctx, bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)), os...)
Expand Down Expand Up @@ -1467,7 +1467,7 @@ func TestGetSetPluggedProtocol(t *testing.T) {
opts.Protocols("/esh/dht"),
opts.Client(false),
opts.NamespacedValidator("v", blankValidator{}),
opts.DisableAutoBootstrap(),
opts.DisableAutoRefresh(),
}...)
if err != nil {
t.Fatal(err)
Expand All @@ -1477,7 +1477,7 @@ func TestGetSetPluggedProtocol(t *testing.T) {
opts.Protocols("/lsr/dht"),
opts.Client(false),
opts.NamespacedValidator("v", blankValidator{}),
opts.DisableAutoBootstrap(),
opts.DisableAutoRefresh(),
}...)
if err != nil {
t.Fatal(err)
Expand Down
Loading

0 comments on commit 8ecf938

Please sign in to comment.