Skip to content

Commit

Permalink
webtransport: add PSK to constructor, and fail if it is used
Browse files Browse the repository at this point in the history
That way, it won't be possible to construct a host with a PSK
when WebTransport is enabled. This is desireable since WebTransport doesn't
support private network (same as QUIC).
  • Loading branch information
marten-seemann committed Dec 2, 2022
1 parent 1c8eaab commit 4bb466a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 33 deletions.
7 changes: 6 additions & 1 deletion p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/pnet"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"github.com/libp2p/go-libp2p/p2p/security/noise/pb"
Expand Down Expand Up @@ -93,7 +94,11 @@ var _ tpt.Transport = &transport{}
var _ tpt.Resolver = &transport{}
var _ io.Closer = &transport{}

func New(key ic.PrivKey, connManager *quicreuse.ConnManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) {
func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) {
if len(psk) > 0 {
log.Error("WebTransport doesn't support private networks yet.")
return nil, errors.New("WebTransport doesn't support private networks yet")
}
id, err := peer.IDFromPrivateKey(key)
if err != nil {
return nil, err
Expand Down
64 changes: 32 additions & 32 deletions p2p/transport/webtransport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func newConnManager(t *testing.T, opts ...quicreuse.Option) *quicreuse.ConnManag

func TestTransport(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -122,7 +122,7 @@ func TestTransport(t *testing.T) {
addrChan := make(chan ma.Multiaddr)
go func() {
_, clientKey := newIdentity(t)
tr2, err := libp2pwebtransport.New(clientKey, newConnManager(t), nil, &network.NullResourceManager{})
tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr2.(io.Closer).Close()

Expand Down Expand Up @@ -158,7 +158,7 @@ func TestTransport(t *testing.T) {

func TestHashVerification(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -171,7 +171,7 @@ func TestHashVerification(t *testing.T) {
}()

_, clientKey := newIdentity(t)
tr2, err := libp2pwebtransport.New(clientKey, newConnManager(t), nil, &network.NullResourceManager{})
tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr2.(io.Closer).Close()

Expand Down Expand Up @@ -209,7 +209,7 @@ func TestCanDial(t *testing.T) {
}

_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()

Expand All @@ -235,7 +235,7 @@ func TestListenAddrValidity(t *testing.T) {
}

_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()

Expand All @@ -252,7 +252,7 @@ func TestListenAddrValidity(t *testing.T) {

func TestListenerAddrs(t *testing.T) {
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()

Expand All @@ -275,7 +275,7 @@ func TestResourceManagerDialing(t *testing.T) {
p := peer.ID("foobar")

_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, rcmgr)
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr)
require.NoError(t, err)
defer tr.(io.Closer).Close()

Expand All @@ -290,7 +290,7 @@ func TestResourceManagerDialing(t *testing.T) {

func TestResourceManagerListening(t *testing.T) {
clientID, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()

Expand All @@ -299,7 +299,7 @@ func TestResourceManagerListening(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, rcmgr)
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr)
require.NoError(t, err)
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
Expand All @@ -325,7 +325,7 @@ func TestResourceManagerListening(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, rcmgr)
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, rcmgr)
require.NoError(t, err)
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
Expand Down Expand Up @@ -369,7 +369,7 @@ func TestConnectionGaterDialing(t *testing.T) {
connGater := NewMockConnectionGater(ctrl)

serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -380,7 +380,7 @@ func TestConnectionGaterDialing(t *testing.T) {
require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr())
})
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), connGater, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), connGater, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
Expand All @@ -393,7 +393,7 @@ func TestConnectionGaterInterceptAccept(t *testing.T) {
connGater := NewMockConnectionGater(ctrl)

serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), connGater, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), connGater, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -406,7 +406,7 @@ func TestConnectionGaterInterceptAccept(t *testing.T) {
})

_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
Expand All @@ -419,15 +419,15 @@ func TestConnectionGaterInterceptSecured(t *testing.T) {
connGater := NewMockConnectionGater(ctrl)

serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), connGater, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), connGater, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer ln.Close()

clientID, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()

Expand Down Expand Up @@ -485,7 +485,7 @@ func TestStaticTLSConf(t *testing.T) {
tlsConf := getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(365*24*time.Hour))

serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf))
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf))
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -495,7 +495,7 @@ func TestStaticTLSConf(t *testing.T) {

t.Run("fails when the certificate is invalid", func(t *testing.T) {
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()

Expand All @@ -509,7 +509,7 @@ func TestStaticTLSConf(t *testing.T) {

t.Run("fails when dialing with a wrong certhash", func(t *testing.T) {
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()

Expand All @@ -524,7 +524,7 @@ func TestStaticTLSConf(t *testing.T) {
store := x509.NewCertPool()
store.AddCert(tlsConf.Certificates[0].Leaf)
tlsConf := &tls.Config{RootCAs: store}
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSClientConfig(tlsConf))
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSClientConfig(tlsConf))
require.NoError(t, err)
defer cl.(io.Closer).Close()

Expand All @@ -537,7 +537,7 @@ func TestStaticTLSConf(t *testing.T) {

func TestAcceptQueueFilledUp(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, &network.NullResourceManager{})
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -547,7 +547,7 @@ func TestAcceptQueueFilledUp(t *testing.T) {
newConn := func() (tpt.CapableConn, error) {
t.Helper()
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{})
cl, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer cl.(io.Closer).Close()
return cl.Dial(context.Background(), ln.Multiaddr(), serverID)
Expand Down Expand Up @@ -577,15 +577,15 @@ func TestSNIIsSent(t *testing.T) {
return tlsConf, nil
},
}
tr, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf))
tr, err := libp2pwebtransport.New(key, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithTLSConfig(tlsConf))
require.NoError(t, err)
defer tr.(io.Closer).Close()

ln1, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)

_, key2 := newIdentity(t)
clientTr, err := libp2pwebtransport.New(key2, newConnManager(t), nil, &network.NullResourceManager{})
clientTr, err := libp2pwebtransport.New(key2, nil, newConnManager(t), nil, &network.NullResourceManager{})
require.NoError(t, err)
defer tr.(io.Closer).Close()

Expand Down Expand Up @@ -643,7 +643,7 @@ func TestFlowControlWindowIncrease(t *testing.T) {
serverID, serverKey := newIdentity(t)
serverWindowIncreases := make(chan int, 100)
serverRcmgr := &reportingRcmgr{report: serverWindowIncreases}
tr, err := libp2pwebtransport.New(serverKey, newConnManager(t), nil, serverRcmgr)
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, serverRcmgr)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -670,7 +670,7 @@ func TestFlowControlWindowIncrease(t *testing.T) {
_, clientKey := newIdentity(t)
clientWindowIncreases := make(chan int, 100)
clientRcmgr := &reportingRcmgr{report: clientWindowIncreases}
tr2, err := libp2pwebtransport.New(clientKey, newConnManager(t), nil, clientRcmgr)
tr2, err := libp2pwebtransport.New(clientKey, nil, newConnManager(t), nil, clientRcmgr)
require.NoError(t, err)
defer tr2.(io.Closer).Close()

Expand Down Expand Up @@ -754,7 +754,7 @@ func serverSendsBackValidCert(t *testing.T, timeSinceUnixEpoch time.Duration, ke

priv, _, err := test.SeededTestKeyPair(ic.Ed25519, 256, keySeed)
require.NoError(t, err)
tr, err := libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
require.NoError(t, err)
l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
Expand Down Expand Up @@ -833,7 +833,7 @@ func TestServerRotatesCertCorrectly(t *testing.T) {
if err != nil {
return false
}
tr, err := libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
if err != nil {
return false
}
Expand All @@ -847,7 +847,7 @@ func TestServerRotatesCertCorrectly(t *testing.T) {

// These two certificates together are valid for at most certValidity - (4*clockSkewAllowance)
cl.Add(certValidity - (4 * clockSkewAllowance) - time.Second)
tr, err = libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
tr, err = libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
if err != nil {
return false
}
Expand Down Expand Up @@ -883,7 +883,7 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) {

priv, _, err := test.RandTestKeyPair(ic.Ed25519, 256)
require.NoError(t, err)
tr, err := libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
require.NoError(t, err)

l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
Expand All @@ -896,7 +896,7 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) {
// e.g. certhash/A/certhash/B ... -> ... certhash/B/certhash/C ... -> ... certhash/C/certhash/D
for i := 0; i < 200; i++ {
cl.Add(24 * time.Hour)
tr, err := libp2pwebtransport.New(priv, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
tr, err := libp2pwebtransport.New(priv, nil, newConnManager(t), nil, &network.NullResourceManager{}, libp2pwebtransport.WithClock(cl))
require.NoError(t, err)
l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
Expand Down

0 comments on commit 4bb466a

Please sign in to comment.