diff --git a/dial_sync_test.go b/dial_sync_test.go index 441c8b7d..db480bd5 100644 --- a/dial_sync_test.go +++ b/dial_sync_test.go @@ -241,7 +241,7 @@ func TestDialSelf(t *testing.T) { defer cancel() self := peer.ID("ABC") - s := NewSwarm(ctx, self, nil, nil) + s := NewSwarm(self, nil, nil) defer s.Close() // this should fail diff --git a/dial_test.go b/dial_test.go index 6258d0ed..6b319e4e 100644 --- a/dial_test.go +++ b/dial_test.go @@ -7,28 +7,33 @@ import ( "testing" "time" - addrutil "github.com/libp2p/go-addr-util" + . "github.com/libp2p/go-libp2p-swarm" + addrutil "github.com/libp2p/go-addr-util" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" - "github.com/libp2p/go-libp2p-core/transport" - testutil "github.com/libp2p/go-libp2p-core/test" + "github.com/libp2p/go-libp2p-core/transport" swarmt "github.com/libp2p/go-libp2p-swarm/testing" "github.com/libp2p/go-libp2p-testing/ci" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" - . 
"github.com/libp2p/go-libp2p-swarm" + "github.com/stretchr/testify/require" ) func init() { transport.DialTimeout = time.Second } -func closeSwarms(swarms []*Swarm) { +type swarmWithBackoff interface { + network.Network + Backoff() *DialBackoff +} + +func closeSwarms(swarms []network.Network) { for _, s := range swarms { s.Close() } @@ -36,50 +41,37 @@ func closeSwarms(swarms []*Swarm) { func TestBasicDialPeer(t *testing.T) { t.Parallel() - ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) defer closeSwarms(swarms) s1 := swarms[0] s2 := swarms[1] s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) - c, err := s1.DialPeer(ctx, s2.LocalPeer()) - if err != nil { - t.Fatal(err) - } - - s, err := c.NewStream(ctx) - if err != nil { - t.Fatal(err) - } + c, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + s, err := c.NewStream(context.Background()) + require.NoError(t, err) s.Close() } func TestDialWithNoListeners(t *testing.T) { t.Parallel() - ctx := context.Background() - s1 := makeDialOnlySwarm(ctx, t) - - swarms := makeSwarms(ctx, t, 1) + s1 := makeDialOnlySwarm(t) + swarms := makeSwarms(t, 1) defer closeSwarms(swarms) s2 := swarms[0] s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) - c, err := s1.DialPeer(ctx, s2.LocalPeer()) - if err != nil { - t.Fatal(err) - } - - s, err := c.NewStream(ctx) - if err != nil { - t.Fatal(err) - } + c, err := s1.DialPeer(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + s, err := c.NewStream(context.Background()) + require.NoError(t, err) s.Close() } @@ -104,12 +96,12 @@ func TestSimultDials(t *testing.T) { t.Parallel() ctx := context.Background() - swarms := makeSwarms(ctx, t, 2, swarmt.OptDisableReuseport) + swarms := makeSwarms(t, 2, swarmt.OptDisableReuseport) // connect everyone { var wg sync.WaitGroup - connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { 
+ connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { // copy for other peer log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) s.Peerstore().AddAddr(dst, addr, peerstore.TempAddrTTL) @@ -175,7 +167,7 @@ func TestDialWait(t *testing.T) { t.Parallel() ctx := context.Background() - swarms := makeSwarms(ctx, t, 1) + swarms := makeSwarms(t, 1) s1 := swarms[0] defer s1.Close() @@ -201,7 +193,7 @@ func TestDialWait(t *testing.T) { t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts) } - if !s1.Backoff().Backoff(s2p, s2addr) { + if !s1.(swarmWithBackoff).Backoff().Backoff(s2p, s2addr) { t.Error("s2 should now be on backoff") } } @@ -215,7 +207,7 @@ func TestDialBackoff(t *testing.T) { t.Parallel() ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) s1 := swarms[0] s2 := swarms[1] defer s1.Close() @@ -338,10 +330,10 @@ func TestDialBackoff(t *testing.T) { } // check backoff state - if s1.Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { + if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { t.Error("s2 should not be on backoff") } - if !s1.Backoff().Backoff(s3p, s3addr) { + if !s1.(swarmWithBackoff).Backoff().Backoff(s3p, s3addr) { t.Error("s3 should be on backoff") } @@ -408,10 +400,10 @@ func TestDialBackoff(t *testing.T) { } // check backoff state (the same) - if s1.Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { + if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { t.Error("s2 should not be on backoff") } - if !s1.Backoff().Backoff(s3p, s3addr) { + if !s1.(swarmWithBackoff).Backoff().Backoff(s3p, s3addr) { t.Error("s3 should be on backoff") } } @@ -422,7 +414,7 @@ func TestDialBackoffClears(t *testing.T) { t.Parallel() ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) s1 := swarms[0] s2 := swarms[1] defer s1.Close() @@ -453,7 +445,7 @@ 
func TestDialBackoffClears(t *testing.T) { t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts) } - if !s1.Backoff().Backoff(s2.LocalPeer(), s2bad) { + if !s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2bad) { t.Error("s2 should now be on backoff") } else { t.Log("correctly added to backoff") @@ -480,7 +472,7 @@ func TestDialBackoffClears(t *testing.T) { t.Log("correctly connected") } - if s1.Backoff().Backoff(s2.LocalPeer(), s2bad) { + if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2bad) { t.Error("s2 should no longer be on backoff") } else { t.Log("correctly cleared backoff") @@ -491,7 +483,7 @@ func TestDialPeerFailed(t *testing.T) { t.Parallel() ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) defer closeSwarms(swarms) testedSwarm, targetSwarm := swarms[0], swarms[1] @@ -530,7 +522,7 @@ func TestDialPeerFailed(t *testing.T) { func TestDialExistingConnection(t *testing.T) { ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) defer closeSwarms(swarms) s1 := swarms[0] s2 := swarms[1] @@ -574,7 +566,7 @@ func TestDialSimultaneousJoin(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) s1 := swarms[0] s2 := swarms[1] defer s1.Close() @@ -676,12 +668,10 @@ func TestDialSelf2(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) s1 := swarms[0] defer s1.Close() _, err := s1.DialPeer(ctx, s1.LocalPeer()) - if err != ErrDialToSelf { - t.Fatal("expected error from self dial") - } + require.ErrorIs(t, err, ErrDialToSelf, "expected error from self dial") } diff --git a/go.mod b/go.mod index 28193a21..be770061 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,9 @@ go 1.16 require ( github.com/ipfs/go-log v1.0.5 - 
github.com/jbenet/goprocess v0.1.4 github.com/libp2p/go-addr-util v0.1.0 github.com/libp2p/go-conn-security-multistream v0.2.1 - github.com/libp2p/go-libp2p-core v0.8.6 + github.com/libp2p/go-libp2p-core v0.9.1-0.20210905173309-045ce33f287b github.com/libp2p/go-libp2p-peerstore v0.2.8 github.com/libp2p/go-libp2p-quic-transport v0.11.2 github.com/libp2p/go-libp2p-testing v0.4.2 diff --git a/go.sum b/go.sum index d9a16369..a76b260f 100644 --- a/go.sum +++ b/go.sum @@ -127,7 +127,6 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM39fZuwSd1LwSqqSW0hOdXCYYDX0R3I= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= @@ -234,7 +233,6 @@ github.com/jbenet/go-temp-err-catcher v0.1.0 h1:zpb3ZH6wIE8Shj2sKS+khgRvf7T7RABo github.com/jbenet/go-temp-err-catcher v0.1.0/go.mod h1:0kJRvmDZXNMIiJirNPEYfhpPwbGVtZVWC34vc5WLsDk= github.com/jbenet/goprocess v0.0.0-20160826012719-b497e2f366b8/go.mod h1:Ly/wlsjFq/qrU3Rar62tu1gASgGw6chQbSh/XgIIXCY= github.com/jbenet/goprocess v0.1.3/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4= -github.com/jbenet/goprocess v0.1.4 h1:DRGOFReOMqqDNXwW70QkacFW0YN9QnwLV0Vqk+3oU0o= github.com/jbenet/goprocess v0.1.4/go.mod h1:5yspPrukOVuOLORacaBi858NqyClJPQxYZlqdZVfqY4= github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod 
h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= @@ -288,8 +286,9 @@ github.com/libp2p/go-libp2p-core v0.5.1/go.mod h1:uN7L2D4EvPCvzSH5SrhR72UWbnSGpt github.com/libp2p/go-libp2p-core v0.7.0/go.mod h1:FfewUH/YpvWbEB+ZY9AQRQ4TAD8sJBt/G1rVvhz5XT8= github.com/libp2p/go-libp2p-core v0.8.0/go.mod h1:FfewUH/YpvWbEB+ZY9AQRQ4TAD8sJBt/G1rVvhz5XT8= github.com/libp2p/go-libp2p-core v0.8.1/go.mod h1:FfewUH/YpvWbEB+ZY9AQRQ4TAD8sJBt/G1rVvhz5XT8= -github.com/libp2p/go-libp2p-core v0.8.6 h1:3S8g006qG6Tjpj1JdRK2S+TWc2DJQKX/RG9fdLeiLSU= github.com/libp2p/go-libp2p-core v0.8.6/go.mod h1:dgHr0l0hIKfWpGpqAMbpo19pen9wJfdCGv51mTmdpmM= +github.com/libp2p/go-libp2p-core v0.9.1-0.20210905173309-045ce33f287b h1:Rc/KIaoWLFumEDUm0oMkpMKG+ASa5YzpZJQaF1zfZV0= +github.com/libp2p/go-libp2p-core v0.9.1-0.20210905173309-045ce33f287b/go.mod h1:j3WKs+bvJ5a1/WEe8IFouxQQltzZQWDOKL0MWT8C0Eo= github.com/libp2p/go-libp2p-mplex v0.4.1 h1:/pyhkP1nLwjG3OM+VuaNJkQT/Pqq73WzB3aDN3Fx1sc= github.com/libp2p/go-libp2p-mplex v0.4.1/go.mod h1:cmy+3GfqfM1PceHTLL7zQzAAYaryDu6iPSC+CIb094g= github.com/libp2p/go-libp2p-peerstore v0.2.8 h1:nJghUlUkFVvyk7ccsM67oFA6kqUkwyCM1G4WPVMCWYA= @@ -611,7 +610,6 @@ go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.4 h1:LYy1Hy3MJdrCdMwwzxA/dRok4ejH+RwNGbuoD9fCjto= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= diff --git a/peers_test.go b/peers_test.go index 8e82bf5b..908abe91 100644 --- 
a/peers_test.go +++ b/peers_test.go @@ -9,17 +9,15 @@ import ( "github.com/libp2p/go-libp2p-core/peerstore" ma "github.com/multiformats/go-multiaddr" - - . "github.com/libp2p/go-libp2p-swarm" ) func TestPeers(t *testing.T) { ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) s1 := swarms[0] s2 := swarms[1] - connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { // TODO: make a DialAddr func. s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) // t.Logf("connections from %s", s.LocalPeer()) @@ -55,7 +53,7 @@ func TestPeers(t *testing.T) { log.Infof("%s swarm routing table: %s", s.LocalPeer(), s.Peers()) } - test := func(s *Swarm) { + test := func(s network.Network) { expect := 1 actual := len(s.Peers()) if actual != expect { diff --git a/simul_test.go b/simul_test.go index 0373e37d..326c4e21 100644 --- a/simul_test.go +++ b/simul_test.go @@ -7,32 +7,29 @@ import ( "testing" "time" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" ma "github.com/multiformats/go-multiaddr" - . 
"github.com/libp2p/go-libp2p-swarm" swarmt "github.com/libp2p/go-libp2p-swarm/testing" "github.com/libp2p/go-libp2p-testing/ci" ) func TestSimultOpen(t *testing.T) { - t.Parallel() - - ctx := context.Background() - swarms := makeSwarms(ctx, t, 2, swarmt.OptDisableReuseport) + swarms := makeSwarms(t, 2, swarmt.OptDisableReuseport) // connect everyone { var wg sync.WaitGroup - connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { defer wg.Done() // copy for other peer log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) - if _, err := s.DialPeer(ctx, dst); err != nil { + if _, err := s.DialPeer(context.Background(), dst); err != nil { t.Error("error swarm dialing to peer", err) } } diff --git a/swarm.go b/swarm.go index 35eb4156..3980575d 100644 --- a/swarm.go +++ b/swarm.go @@ -18,8 +18,6 @@ import ( "github.com/libp2p/go-libp2p-core/transport" logging "github.com/ipfs/go-log" - "github.com/jbenet/goprocess" - goprocessctx "github.com/jbenet/goprocess/context" ma "github.com/multiformats/go-multiaddr" ) @@ -92,9 +90,10 @@ type Swarm struct { limiter *dialLimiter gater connmgr.ConnectionGater - proc goprocess.Process - ctx context.Context - bwc metrics.Reporter + ctx context.Context // is canceled when Close is called + ctxCancel context.CancelFunc + + bwc metrics.Reporter } // NewSwarm constructs a Swarm. @@ -103,11 +102,14 @@ type Swarm struct { // `extra` interface{} parameter facilitates the future migration. 
Supported // elements are: // - connmgr.ConnectionGater -func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter, extra ...interface{}) *Swarm { +func NewSwarm(local peer.ID, peers peerstore.Peerstore, bwc metrics.Reporter, extra ...interface{}) *Swarm { + ctx, cancel := context.WithCancel(context.Background()) s := &Swarm{ - local: local, - peers: peers, - bwc: bwc, + local: local, + peers: peers, + bwc: bwc, + ctx: ctx, + ctxCancel: cancel, } s.conns.m = make(map[peer.ID][]*Conn) @@ -124,23 +126,11 @@ func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc s.dsync = newDialSync(s.startDialWorker) s.limiter = newDialLimiter(s.dialAddr) - s.proc = goprocessctx.WithContext(ctx) - s.ctx = goprocessctx.OnClosingContext(s.proc) s.backf.init(s.ctx) - - // Set teardown after setting the context/process so we don't start the - // teardown process early. - s.proc.SetTeardown(s.teardown) - return s } -func (s *Swarm) teardown() error { - // Wait for the context to be canceled. - // This allows other parts of the swarm to detect that we're shutting - // down. - <-s.ctx.Done() - +func (s *Swarm) Close() error { // Prevents new connections and/or listeners from being added to the swarm. s.listeners.Lock() @@ -201,11 +191,6 @@ func (s *Swarm) teardown() error { return nil } -// Process returns the Process of the swarm -func (s *Swarm) Process() goprocess.Process { - return s.proc -} - func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, error) { var ( p = tc.RemotePeer() @@ -293,16 +278,6 @@ func (s *Swarm) Peerstore() peerstore.Peerstore { return s.peers } -// Context returns the context of the swarm -func (s *Swarm) Context() context.Context { - return s.ctx -} - -// Close stops the Swarm. -func (s *Swarm) Close() error { - return s.proc.Close() -} - // TODO: We probably don't need the conn handlers. // SetConnHandler assigns the handler for new connections. 
diff --git a/swarm_addr_test.go b/swarm_addr_test.go index baeac462..21ecafa2 100644 --- a/swarm_addr_test.go +++ b/swarm_addr_test.go @@ -6,6 +6,7 @@ import ( "github.com/libp2p/go-libp2p-core/peerstore" "github.com/libp2p/go-libp2p-core/test" + "github.com/stretchr/testify/require" ma "github.com/multiformats/go-multiaddr" @@ -13,7 +14,6 @@ import ( ) func TestDialBadAddrs(t *testing.T) { - m := func(s string) ma.Multiaddr { maddr, err := ma.NewMultiaddr(s) if err != nil { @@ -22,13 +22,12 @@ func TestDialBadAddrs(t *testing.T) { return maddr } - ctx := context.Background() - s := makeSwarms(ctx, t, 1)[0] + s := makeSwarms(t, 1)[0] test := func(a ma.Multiaddr) { p := test.RandPeerIDFatal(t) s.Peerstore().AddAddr(p, a, peerstore.PermanentAddrTTL) - if _, err := s.DialPeer(ctx, p); err == nil { + if _, err := s.DialPeer(context.Background(), p); err == nil { t.Errorf("swarm should not dial: %s", p) } } @@ -39,19 +38,13 @@ func TestDialBadAddrs(t *testing.T) { } func TestAddrRace(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s := makeSwarms(ctx, t, 1)[0] + s := makeSwarms(t, 1)[0] defer s.Close() a1, err := s.InterfaceListenAddresses() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) a2, err := s.InterfaceListenAddresses() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if len(a1) > 0 && len(a2) > 0 && &a1[0] == &a2[0] { t.Fatal("got the exact same address set twice; this could lead to data races") @@ -59,15 +52,8 @@ func TestAddrRace(t *testing.T) { } func TestAddressesWithoutListening(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s := swarmt.GenSwarm(t, ctx, swarmt.OptDialOnly) - + s := swarmt.GenSwarm(t, swarmt.OptDialOnly) a1, err := s.InterfaceListenAddresses() - if err != nil { - t.Fatal(err) - } - if len(a1) != 0 { - t.Fatalf("expected to be listening on no addresses, was listening on %d", len(a1)) - } + require.NoError(t, err) + 
require.Empty(t, a1, "expected to be listening on no addresses") } diff --git a/swarm_listen.go b/swarm_listen.go index c064ae85..ca54280c 100644 --- a/swarm_listen.go +++ b/swarm_listen.go @@ -46,7 +46,7 @@ func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { // // Distinguish between these two cases to avoid confusing users. select { - case <-s.proc.Closing(): + case <-s.ctx.Done(): return ErrSwarmClosed default: return ErrNoTransport diff --git a/swarm_net_test.go b/swarm_net_test.go index 05984f6b..1f1d0454 100644 --- a/swarm_net_test.go +++ b/swarm_net_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/libp2p/go-libp2p-core/network" . "github.com/libp2p/go-libp2p-swarm/testing" @@ -15,19 +17,16 @@ import ( // TestConnectednessCorrect starts a few networks, connects a few // and tests Connectedness value is correct. func TestConnectednessCorrect(t *testing.T) { - - ctx := context.Background() - nets := make([]network.Network, 4) for i := 0; i < 4; i++ { - nets[i] = GenSwarm(t, ctx) + nets[i] = GenSwarm(t) } // connect 0-1, 0-2, 0-3, 1-2, 2-3 dial := func(a, b network.Network) { DivulgeAddresses(b, a) - if _, err := a.DialPeer(ctx, b.LocalPeer()); err != nil { + if _, err := a.DialPeer(context.Background(), b.LocalPeer()); err != nil { t.Fatalf("Failed to dial: %s", err) } } @@ -54,33 +53,17 @@ func TestConnectednessCorrect(t *testing.T) { expectConnectedness(t, nets[0], nets[2], network.NotConnected) expectConnectedness(t, nets[1], nets[3], network.NotConnected) - if len(nets[0].Peers()) != 2 { - t.Fatal("expected net 0 to have two peers") - } - - if len(nets[2].Peers()) != 2 { - t.Fatal("expected net 2 to have two peers") - } - - if len(nets[1].ConnsToPeer(nets[3].LocalPeer())) != 0 { - t.Fatal("net 1 should have no connections to net 3") - } - - if err := nets[2].ClosePeer(nets[1].LocalPeer()); err != nil { - t.Fatal(err) - } + require.Len(t, nets[0].Peers(), 2, "expected net 0 to have two peers") + 
require.Len(t, nets[2].Peers(), 2, "expected net 2 to have two peers") + require.Empty(t, nets[1].ConnsToPeer(nets[3].LocalPeer()), "net 1 should have no connections to net 3") + require.NoError(t, nets[2].ClosePeer(nets[1].LocalPeer())) time.Sleep(time.Millisecond * 50) - expectConnectedness(t, nets[2], nets[1], network.NotConnected) for _, n := range nets { n.Close() } - - for _, n := range nets { - <-n.Process().Closed() - } } func expectConnectedness(t *testing.T, a, b network.Network, expected network.Connectedness) { @@ -113,7 +96,7 @@ func TestNetworkOpenStream(t *testing.T) { nets := make([]network.Network, 4) for i := 0; i < 4; i++ { - nets[i] = GenSwarm(t, ctx) + nets[i] = GenSwarm(t) } dial := func(a, b network.Network) { diff --git a/swarm_notif_test.go b/swarm_notif_test.go index 33836172..157405cb 100644 --- a/swarm_notif_test.go +++ b/swarm_notif_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" @@ -18,8 +20,7 @@ func TestNotifications(t *testing.T) { notifiees := make([]*netNotifiee, swarmSize) - ctx := context.Background() - swarms := makeSwarms(ctx, t, swarmSize) + swarms := makeSwarms(t, swarmSize) defer func() { for i, s := range swarms { select { @@ -27,10 +28,7 @@ func TestNotifications(t *testing.T) { t.Error("should not have been closed") default: } - err := s.Close() - if err != nil { - t.Error(err) - } + require.NoError(t, s.Close()) select { case <-notifiees[i].listenClose: default: @@ -48,7 +46,7 @@ func TestNotifications(t *testing.T) { notifiees[i] = n } - connectSwarms(t, ctx, swarms) + connectSwarms(t, context.Background(), swarms) time.Sleep(50 * time.Millisecond) // should've gotten 5 by now. 
@@ -96,7 +94,7 @@ func TestNotifications(t *testing.T) { } } - complement := func(c network.Conn) (*Swarm, *netNotifiee, *Conn) { + complement := func(c network.Conn) (network.Network, *netNotifiee, *Conn) { for i, s := range swarms { for _, c2 := range s.Conns() { if c.LocalMultiaddr().Equal(c2.RemoteMultiaddr()) && diff --git a/swarm_test.go b/swarm_test.go index a94281b1..e6cb0ba3 100644 --- a/swarm_test.go +++ b/swarm_test.go @@ -58,29 +58,25 @@ func EchoStreamHandler(stream network.Stream) { }() } -func makeDialOnlySwarm(ctx context.Context, t *testing.T) *Swarm { - swarm := GenSwarm(t, ctx, OptDialOnly) +func makeDialOnlySwarm(t *testing.T) network.Network { + swarm := GenSwarm(t, OptDialOnly) swarm.SetStreamHandler(EchoStreamHandler) - return swarm } -func makeSwarms(ctx context.Context, t *testing.T, num int, opts ...Option) []*Swarm { - swarms := make([]*Swarm, 0, num) - +func makeSwarms(t *testing.T, num int, opts ...Option) []network.Network { + swarms := make([]network.Network, 0, num) for i := 0; i < num; i++ { - swarm := GenSwarm(t, ctx, opts...) + swarm := GenSwarm(t, opts...) swarm.SetStreamHandler(EchoStreamHandler) swarms = append(swarms, swarm) } - return swarms } -func connectSwarms(t *testing.T, ctx context.Context, swarms []*Swarm) { - +func connectSwarms(t *testing.T, ctx context.Context, swarms []network.Network) { var wg sync.WaitGroup - connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { // TODO: make a DialAddr func. 
s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) if _, err := s.DialPeer(ctx, dst); err != nil { @@ -104,13 +100,10 @@ func connectSwarms(t *testing.T, ctx context.Context, swarms []*Swarm) { } func SubtestSwarm(t *testing.T, SwarmNum int, MsgNum int) { - // t.Skip("skipping for another test") - - ctx := context.Background() - swarms := makeSwarms(ctx, t, SwarmNum, OptDisableReuseport) + swarms := makeSwarms(t, SwarmNum, OptDisableReuseport) // connect everyone - connectSwarms(t, ctx, swarms) + connectSwarms(t, context.Background(), swarms) // ping/pong for _, s1 := range swarms { @@ -118,7 +111,7 @@ func SubtestSwarm(t *testing.T, SwarmNum int, MsgNum int) { log.Debugf("%s ping pong round", s1.LocalPeer()) log.Debugf("-------------------------------------------------------") - _, cancel := context.WithCancel(ctx) + _, cancel := context.WithCancel(context.Background()) got := map[peer.ID]int{} errChan := make(chan error, MsgNum*len(swarms)) streamChan := make(chan network.Stream, MsgNum) @@ -132,7 +125,7 @@ func SubtestSwarm(t *testing.T, SwarmNum int, MsgNum int) { defer wg.Done() // first, one stream per peer (nice) - stream, err := s1.NewStream(ctx, p) + stream, err := s1.NewStream(context.Background(), p) if err != nil { errChan <- err return @@ -253,7 +246,7 @@ func TestConnHandler(t *testing.T) { t.Parallel() ctx := context.Background() - swarms := makeSwarms(ctx, t, 5) + swarms := makeSwarms(t, 5) gotconn := make(chan struct{}, 10) swarms[0].SetConnHandler(func(conn network.Conn) { @@ -387,8 +380,8 @@ func TestConnectionGating(t *testing.T) { p2Gater = tc.p2Gater(p2Gater) } - sw1 := GenSwarm(t, ctx, OptConnGater(p1Gater), optTransport) - sw2 := GenSwarm(t, ctx, OptConnGater(p2Gater), optTransport) + sw1 := GenSwarm(t, OptConnGater(p1Gater), optTransport) + sw2 := GenSwarm(t, OptConnGater(p2Gater), optTransport) p1 := sw1.LocalPeer() p2 := sw2.LocalPeer() @@ -408,10 +401,9 @@ func TestConnectionGating(t *testing.T) { } func TestNoDial(t 
*testing.T) { - ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) - _, err := swarms[0].NewStream(network.WithNoDial(ctx, "swarm test"), swarms[1].LocalPeer()) + _, err := swarms[0].NewStream(network.WithNoDial(context.Background(), "swarm test"), swarms[1].LocalPeer()) if err != network.ErrNoConn { t.Fatal("should have failed with ErrNoConn") } @@ -419,36 +411,29 @@ func TestNoDial(t *testing.T) { func TestCloseWithOpenStreams(t *testing.T) { ctx := context.Background() - swarms := makeSwarms(ctx, t, 2) + swarms := makeSwarms(t, 2) connectSwarms(t, ctx, swarms) s, err := swarms[0].NewStream(ctx, swarms[1].LocalPeer()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer s.Close() // close swarm before stream. - err = swarms[0].Close() - if err != nil { - t.Fatal(err) - } + require.NoError(t, swarms[0].Close()) } func TestTypedNilConn(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s := GenSwarm(t, ctx) + s := GenSwarm(t) defer s.Close() // We can't dial ourselves. - c, err := s.DialPeer(ctx, s.LocalPeer()) + c, err := s.DialPeer(context.Background(), s.LocalPeer()) require.Error(t, err) // If we fail to dial, the connection should be nil. 
- require.True(t, c == nil) + require.Nil(t, c) } func TestPreventDialListenAddr(t *testing.T) { - s := GenSwarm(t, context.Background(), OptDialOnly) + s := GenSwarm(t, OptDialOnly) if err := s.Listen(ma.StringCast("/ip4/0.0.0.0/udp/0/quic")); err != nil { t.Fatal(err) } diff --git a/testing/testing.go b/testing/testing.go index ba517769..d6091354 100644 --- a/testing/testing.go +++ b/testing/testing.go @@ -1,7 +1,6 @@ package testing import ( - "context" "testing" csms "github.com/libp2p/go-conn-security-multistream" @@ -22,7 +21,6 @@ import ( msmux "github.com/libp2p/go-stream-muxer-multistream" "github.com/libp2p/go-tcp-transport" - "github.com/jbenet/goprocess" ma "github.com/multiformats/go-multiaddr" ) @@ -73,7 +71,7 @@ func OptPeerPrivateKey(sk crypto.PrivKey) Option { } // GenUpgrader creates a new connection upgrader for use with this swarm. -func GenUpgrader(n *swarm.Swarm) *tptu.Upgrader { +func GenUpgrader(n network.Network) *tptu.Upgrader { id := n.LocalPeer() pk := n.Peerstore().PrivKey(id) secMuxer := new(csms.SSMuxer) @@ -88,8 +86,18 @@ func GenUpgrader(n *swarm.Swarm) *tptu.Upgrader { } } +type mSwarm struct { + *swarm.Swarm + ps peerstore.Peerstore +} + +func (s *mSwarm) Close() error { + s.ps.Close() + return s.Swarm.Close() +} + // GenSwarm generates a new test swarm. -func GenSwarm(t *testing.T, ctx context.Context, opts ...Option) *swarm.Swarm { +func GenSwarm(t *testing.T, opts ...Option) network.Network { var cfg config for _, o := range opts { o(t, &cfg) @@ -113,11 +121,10 @@ func GenSwarm(t *testing.T, ctx context.Context, opts ...Option) *swarm.Swarm { ps := pstoremem.NewPeerstore() ps.AddPubKey(p.ID, p.PubKey) ps.AddPrivKey(p.ID, p.PrivKey) - s := swarm.NewSwarm(ctx, p.ID, ps, metrics.NewBandwidthCounter(), cfg.connectionGater) - - // Call AddChildNoWait because we can't call AddChild after the process - // may have been closed (e.g., if the context was canceled). 
- s.Process().AddChildNoWait(goprocess.WithTeardown(ps.Close)) + s := &mSwarm{ + Swarm: swarm.NewSwarm(p.ID, ps, metrics.NewBandwidthCounter(), cfg.connectionGater), + ps: ps, + } upgrader := GenUpgrader(s) upgrader.ConnGater = cfg.connectionGater diff --git a/testing/testing_test.go b/testing/testing_test.go index a80cca17..60cd2787 100644 --- a/testing/testing_test.go +++ b/testing/testing_test.go @@ -1,14 +1,13 @@ package testing import ( - "context" "testing" "github.com/stretchr/testify/require" ) func TestGenSwarm(t *testing.T) { - swarm := GenSwarm(t, context.Background()) + swarm := GenSwarm(t) require.NoError(t, swarm.Close()) GenUpgrader(swarm) } diff --git a/transport_test.go b/transport_test.go index 82225840..52726026 100644 --- a/transport_test.go +++ b/transport_test.go @@ -7,9 +7,13 @@ import ( swarm "github.com/libp2p/go-libp2p-swarm" swarmt "github.com/libp2p/go-libp2p-swarm/testing" + "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/transport" + ma "github.com/multiformats/go-multiaddr" + + "github.com/stretchr/testify/require" ) type dummyTransport struct { @@ -42,24 +46,23 @@ func (dt *dummyTransport) Close() error { return nil } +type swarmWithTransport interface { + network.Network + AddTransport(transport.Transport) error +} + func TestUselessTransport(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s := swarmt.GenSwarm(t, ctx) - err := s.AddTransport(new(dummyTransport)) + s := swarmt.GenSwarm(t) + err := s.(swarmWithTransport).AddTransport(new(dummyTransport)) if err == nil { t.Fatal("adding a transport that supports no protocols should have failed") } } func TestTransportClose(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s := swarmt.GenSwarm(t, ctx) + s := swarmt.GenSwarm(t) tpt := &dummyTransport{protocols: []int{1}} - if err := s.AddTransport(tpt); err != nil { - t.Fatal(err) - 
} + require.NoError(t, s.(swarmWithTransport).AddTransport(tpt)) _ = s.Close() if !tpt.closed { t.Fatal("expected transport to be closed") @@ -68,13 +71,11 @@ func TestTransportClose(t *testing.T) { } func TestTransportAfterClose(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - s := swarmt.GenSwarm(t, ctx) + s := swarmt.GenSwarm(t) s.Close() tpt := &dummyTransport{protocols: []int{1}} - if err := s.AddTransport(tpt); err != swarm.ErrSwarmClosed { + if err := s.(swarmWithTransport).AddTransport(tpt); err != swarm.ErrSwarmClosed { t.Fatal("expected swarm closed error, got: ", err) } }