diff --git a/core/transport/transport.go b/core/transport/transport.go index d56a3cff06..23ecabb4dc 100644 --- a/core/transport/transport.go +++ b/core/transport/transport.go @@ -85,6 +85,16 @@ type Resolver interface { Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) } +// SkipResolver can be optionally implemented by transports that don't want to +// resolve or transform the multiaddr. Useful for transports that wrap other +// transports. This lets the inner transport specify how a multiaddr is +// resolved later. +// Also useful in cases where the transport doesn't need a resolved address to +// dial. +type SkipResolver interface { + SkipResolve(ctx context.Context, maddr ma.Multiaddr) bool +} + // Listener is an interface closely resembling the net.Listener interface. The // only real difference is that Accept() returns Conn's of the type in this // package, and also exposes a Multiaddr method as opposed to a regular Addr diff --git a/libp2p_test.go b/libp2p_test.go index a5803add4d..b290227fc1 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -2,10 +2,16 @@ package libp2p import ( "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "errors" "fmt" "io" + "math/big" "net" "net/netip" "regexp" @@ -26,11 +32,12 @@ import ( "github.com/libp2p/go-libp2p/p2p/net/swarm" "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" - tls "github.com/libp2p/go-libp2p/p2p/security/tls" + sectls "github.com/libp2p/go-libp2p/p2p/security/tls" quic "github.com/libp2p/go-libp2p/p2p/transport/quic" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" "github.com/libp2p/go-libp2p/p2p/transport/tcp" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + "github.com/libp2p/go-libp2p/p2p/transport/websocket" webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "go.uber.org/goleak" @@ -256,7 +263,7 @@ func TestSecurityConstructor(t *testing.T) { h, err := New( Transport(tcp.NewTCPTransport), Security("/noisy", noise.New), - Security("/tls", tls.New), + Security("/tls", sectls.New), DefaultListenAddrs, DisableRelay(), ) @@ -655,3 +662,92 @@ func TestUseCorrectTransportForDialOut(t *testing.T) { } } } + +func TestCircuitBehindWSS(t *testing.T) { + relayTLSConf := getTLSConf(t, net.IPv4(127, 0, 0, 1), time.Now(), time.Now().Add(time.Hour)) + serverNameChan := make(chan string, 2) // Channel that returns what server names the client hello specified + relayTLSConf.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) { + serverNameChan <- chi.ServerName + return relayTLSConf, nil + } + + relay, err := New( + EnableRelayService(), + ForceReachabilityPublic(), + Transport(websocket.New, websocket.WithTLSConfig(relayTLSConf)), + ListenAddrStrings("/ip4/127.0.0.1/tcp/0/wss"), + ) + require.NoError(t, err) + defer relay.Close() + + relayAddrPort, _ := relay.Addrs()[0].ValueForProtocol(ma.P_TCP) + relayAddrWithSNIString := fmt.Sprintf( + "/dns4/localhost/tcp/%s/wss", relayAddrPort, + ) + relayAddrWithSNI := []ma.Multiaddr{ma.StringCast(relayAddrWithSNIString)} + + h, err := New( + NoListenAddrs, + EnableRelay(), + Transport(websocket.New, websocket.WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})), + ForceReachabilityPrivate()) + require.NoError(t, err) + defer h.Close() + + peerBehindRelay, err := New( + NoListenAddrs, + Transport(websocket.New, websocket.WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true})), + EnableRelay(), + EnableAutoRelayWithStaticRelays([]peer.AddrInfo{{ID: relay.ID(), Addrs: relayAddrWithSNI}}), + ForceReachabilityPrivate()) + require.NoError(t, err) + defer peerBehindRelay.Close() + + require.Equal(t, + "localhost", + <-serverNameChan, // The server connects to the relay + ) + + // Connect to the peer behind the relay + h.Connect(context.Background(), peer.AddrInfo{ + ID: peerBehindRelay.ID(), + Addrs: []ma.Multiaddr{ma.StringCast( + fmt.Sprintf("%s/p2p/%s/p2p-circuit", relayAddrWithSNIString, relay.ID()), + )}, + }) + require.NoError(t, err) + + require.Equal(t, + "localhost", + <-serverNameChan, // The client connects to the relay and sends the SNI + ) +} + +// getTLSConf is a helper to generate a self-signed TLS config +func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config { + t.Helper() + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(1234), + Subject: pkix.Name{Organization: []string{"websocket"}}, + NotBefore: start, + NotAfter: end, + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IPAddresses: []net.IP{ip}, + } + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &priv.PublicKey, priv) + require.NoError(t, err) + cert, err := x509.ParseCertificate(caBytes) + require.NoError(t, err) + return &tls.Config{ + Certificates: []tls.Certificate{{ + Certificate: [][]byte{cert.Raw}, + PrivateKey: priv, + Leaf: cert, + }}, + } +} diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 085c0825a9..baa01e603d 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -401,6 +401,32 @@ func (s *Swarm) resolveAddrs(ctx context.Context, pi peer.AddrInfo) []ma.Multiad return s.multiaddrResolver.ResolveDNSAddr(ctx, pi.ID, maddr, maximumDNSADDRRecursion, outputLimit) }, } + + var skipped []ma.Multiaddr + skipResolver := resolver{ + canResolve: func(addr ma.Multiaddr) bool { + tpt := s.TransportForDialing(addr) + if tpt == nil { + return false + } + _, ok := tpt.(transport.SkipResolver) + return ok + + }, + resolve: func(ctx context.Context, addr ma.Multiaddr, outputLimit int) ([]ma.Multiaddr, error) { + tpt := s.TransportForDialing(addr) + resolver, ok := tpt.(transport.SkipResolver) + if !ok { + return []ma.Multiaddr{addr}, nil + } + if resolver.SkipResolve(ctx, addr) { + skipped = append(skipped, addr) + return nil, nil + } + return []ma.Multiaddr{addr}, nil + }, + } + tptResolver := resolver{ canResolve: func(addr ma.Multiaddr) bool { tpt := s.TransportForDialing(addr) @@ -426,14 +452,17 @@ func (s *Swarm) resolveAddrs(ctx context.Context, pi peer.AddrInfo) []ma.Multiad return addrs, nil }, } + dnsResolver := resolver{ canResolve: startsWithDNSComponent, resolve: s.multiaddrResolver.ResolveDNSComponent, } - addrs, errs := chainResolvers(ctx, pi.Addrs, maximumResolvedAddresses, []resolver{dnsAddrResolver, tptResolver, dnsResolver}) + addrs, errs := chainResolvers(ctx, pi.Addrs, maximumResolvedAddresses, []resolver{dnsAddrResolver, skipResolver, tptResolver, dnsResolver}) for _, err := range errs { log.Warnf("Failed to resolve addr %s: %v", err.addr, err.err) } + // Add skipped addresses back to the resolved addresses + addrs = append(addrs, skipped...) return stripP2PComponent(addrs) } diff --git a/p2p/protocol/circuitv2/client/transport.go b/p2p/protocol/circuitv2/client/transport.go index 2c9e49f509..694681af18 100644 --- a/p2p/protocol/circuitv2/client/transport.go +++ b/p2p/protocol/circuitv2/client/transport.go @@ -46,8 +46,16 @@ func AddTransport(h host.Host, upgrader transport.Upgrader) error { // Transport interface var _ transport.Transport = (*Client)(nil) +var _ transport.SkipResolver = (*Client)(nil) var _ io.Closer = (*Client)(nil) +// SkipResolve returns true since we always defer to the inner transport for +// the actual connection. By skipping resolution here, we let the inner +// transport decide how to resolve the multiaddr +func (c *Client) SkipResolve(ctx context.Context, maddr ma.Multiaddr) bool { + return true +} + func (c *Client) Dial(ctx context.Context, a ma.Multiaddr, p peer.ID) (transport.CapableConn, error) { connScope, err := c.host.Network().ResourceManager().OpenConnection(network.DirOutbound, false, a)