diff --git a/p2p/security/noise/benchmark_test.go b/p2p/security/noise/benchmark_test.go index 52454f5959..836275b954 100644 --- a/p2p/security/noise/benchmark_test.go +++ b/p2p/security/noise/benchmark_test.go @@ -39,7 +39,7 @@ func makeTransport(b *testing.B) *Transport { if err != nil { b.Fatal(err) } - tpt, err := New(priv) + tpt, err := New(priv, nil) if err != nil { b.Fatalf("error constructing transport: %v", err) } diff --git a/p2p/security/noise/pb/payload.pb.go b/p2p/security/noise/pb/payload.pb.go index 84db783eff..fdebe4879d 100644 --- a/p2p/security/noise/pb/payload.pb.go +++ b/p2p/security/noise/pb/payload.pb.go @@ -24,6 +24,7 @@ const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package type NoiseExtensions struct { WebtransportCerthashes [][]byte `protobuf:"bytes,1,rep,name=webtransport_certhashes,json=webtransportCerthashes" json:"webtransport_certhashes,omitempty"` + StreamMuxers []string `protobuf:"bytes,2,rep,name=stream_muxers,json=streamMuxers" json:"stream_muxers,omitempty"` } func (m *NoiseExtensions) Reset() { *m = NoiseExtensions{} } @@ -66,6 +67,13 @@ func (m *NoiseExtensions) GetWebtransportCerthashes() [][]byte { return nil } +func (m *NoiseExtensions) GetStreamMuxers() []string { + if m != nil { + return m.StreamMuxers + } + return nil +} + type NoiseHandshakePayload struct { IdentityKey []byte `protobuf:"bytes,1,opt,name=identity_key,json=identityKey" json:"identity_key"` IdentitySig []byte `protobuf:"bytes,2,opt,name=identity_sig,json=identitySig" json:"identity_sig"` @@ -134,21 +142,23 @@ func init() { func init() { proto.RegisterFile("payload.proto", fileDescriptor_678c914f1bee6d56) } var fileDescriptor_678c914f1bee6d56 = []byte{ - // 221 bytes of a gzipped FileDescriptorProto + // 251 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x48, 0xac, 0xcc, - 0xc9, 0x4f, 0x4c, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0xf2, 0xe2, + 0xc9, 0x4f, 0x4c, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0xca, 0xe7, 0xe2, 0xf7, 0xcb, 0xcf, 0x2c, 0x4e, 0x75, 0xad, 0x28, 0x49, 0xcd, 0x2b, 0xce, 0xcc, 0xcf, 0x2b, 0x16, 0x32, 0xe7, 0x12, 0x2f, 0x4f, 0x4d, 0x2a, 0x29, 0x4a, 0xcc, 0x2b, 0x2e, 0xc8, 0x2f, 0x2a, 0x89, 0x4f, 0x4e, 0x2d, 0x2a, 0xc9, 0x48, 0x2c, 0xce, 0x48, 0x2d, 0x96, 0x60, 0x54, 0x60, 0xd6, - 0xe0, 0x09, 0x12, 0x43, 0x96, 0x76, 0x86, 0xcb, 0x2a, 0xcd, 0x63, 0xe4, 0x12, 0x05, 0x1b, 0xe6, - 0x91, 0x98, 0x97, 0x52, 0x9c, 0x91, 0x98, 0x9d, 0x1a, 0x00, 0xb1, 0x4f, 0x48, 0x9d, 0x8b, 0x27, - 0x33, 0x25, 0x35, 0xaf, 0x24, 0xb3, 0xa4, 0x32, 0x3e, 0x3b, 0xb5, 0x52, 0x82, 0x51, 0x81, 0x51, - 0x83, 0xc7, 0x89, 0xe5, 0xc4, 0x3d, 0x79, 0x86, 0x20, 0x6e, 0x98, 0x8c, 0x77, 0x6a, 0x25, 0x8a, - 0xc2, 0xe2, 0xcc, 0x74, 0x09, 0x26, 0x6c, 0x0a, 0x83, 0x33, 0xd3, 0x85, 0x8c, 0xb9, 0xb8, 0x52, - 0xe1, 0x4e, 0x96, 0x60, 0x51, 0x60, 0xd4, 0xe0, 0x36, 0x12, 0xd6, 0x2b, 0x48, 0xd2, 0x43, 0xf3, - 0x4d, 0x10, 0x92, 0x32, 0x27, 0x89, 0x13, 0x8f, 0xe4, 0x18, 0x2f, 0x3c, 0x92, 0x63, 0x7c, 0xf0, - 0x48, 0x8e, 0x71, 0xc2, 0x63, 0x39, 0x86, 0x0b, 0x8f, 0xe5, 0x18, 0x6e, 0x3c, 0x96, 0x63, 0x00, - 0x04, 0x00, 0x00, 0xff, 0xff, 0xb2, 0xb0, 0x39, 0x45, 0x1a, 0x01, 0x00, 0x00, + 0xe0, 0x09, 0x12, 0x43, 0x96, 0x76, 0x86, 0xcb, 0x0a, 0x29, 0x73, 0xf1, 0x16, 0x97, 0x14, 0xa5, + 0x26, 0xe6, 0xc6, 0xe7, 0x96, 0x56, 0xa4, 0x16, 0x15, 0x4b, 0x30, 0x29, 0x30, 0x6b, 0x70, 0x06, + 0xf1, 0x40, 0x04, 0x7d, 0xc1, 0x62, 0x4a, 0xf3, 0x18, 0xb9, 0x44, 0xc1, 0x36, 0x7a, 0x24, 0xe6, + 0xa5, 0x14, 0x67, 0x24, 0x66, 0xa7, 0x06, 0x40, 0x1c, 0x25, 0xa4, 0xce, 0xc5, 0x93, 0x99, 0x92, + 0x9a, 0x57, 0x92, 0x59, 0x52, 0x19, 0x9f, 0x9d, 0x5a, 0x29, 0xc1, 0xa8, 0xc0, 0xa8, 0xc1, 0xe3, + 0xc4, 0x72, 0xe2, 0x9e, 0x3c, 0x43, 0x10, 0x37, 0x4c, 0xc6, 0x3b, 0xb5, 0x12, 0x45, 0x61, 0x71, + 0x66, 0xba, 0x04, 0x13, 0x36, 0x85, 0xc1, 0x99, 0xe9, 0x42, 0xc6, 0x5c, 0x5c, 0xa9, 0x70, 0x7f, + 0x49, 0xb0, 0x28, 0x30, 0x6a, 0x70, 0x1b, 0x09, 0xeb, 0x15, 0x24, 0xe9, 0xa1, 0x79, 0x39, 0x08, + 0x49, 0x99, 0x93, 0xc4, 0x89, 0x47, 0x72, 0x8c, 0x17, 0x1e, 0xc9, 0x31, 0x3e, 0x78, 0x24, 0xc7, + 0x38, 0xe1, 0xb1, 0x1c, 0xc3, 0x85, 0xc7, 0x72, 0x0c, 0x37, 0x1e, 0xcb, 0x31, 0x00, 0x02, 0x00, + 0x00, 0xff, 0xff, 0x02, 0xdb, 0x23, 0xb3, 0x3f, 0x01, 0x00, 0x00, } func (m *NoiseExtensions) Marshal() (dAtA []byte, err error) { @@ -171,6 +181,15 @@ func (m *NoiseExtensions) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l + if len(m.StreamMuxers) > 0 { + for iNdEx := len(m.StreamMuxers) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.StreamMuxers[iNdEx]) + copy(dAtA[i:], m.StreamMuxers[iNdEx]) + i = encodeVarintPayload(dAtA, i, uint64(len(m.StreamMuxers[iNdEx]))) + i-- + dAtA[i] = 0x12 + } + } if len(m.WebtransportCerthashes) > 0 { for iNdEx := len(m.WebtransportCerthashes) - 1; iNdEx >= 0; iNdEx-- { i -= len(m.WebtransportCerthashes[iNdEx]) @@ -255,6 +274,12 @@ func (m *NoiseExtensions) Size() (n int) { n += 1 + l + sovPayload(uint64(l)) } } + if len(m.StreamMuxers) > 0 { + for _, s := range m.StreamMuxers { + l = len(s) + n += 1 + l + sovPayload(uint64(l)) + } + } return n } @@ -346,6 +371,38 @@ func (m *NoiseExtensions) Unmarshal(dAtA []byte) error { m.WebtransportCerthashes = append(m.WebtransportCerthashes, make([]byte, postIndex-iNdEx)) copy(m.WebtransportCerthashes[len(m.WebtransportCerthashes)-1], dAtA[iNdEx:postIndex]) iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field StreamMuxers", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowPayload + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthPayload + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthPayload + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.StreamMuxers = append(m.StreamMuxers, string(dAtA[iNdEx:postIndex])) + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipPayload(dAtA[iNdEx:]) diff --git a/p2p/security/noise/pb/payload.proto b/p2p/security/noise/pb/payload.proto index 7c1b0bdcae..ff303b0daf 100644 --- a/p2p/security/noise/pb/payload.proto +++ b/p2p/security/noise/pb/payload.proto @@ -3,6 +3,7 @@ package pb; message NoiseExtensions { repeated bytes webtransport_certhashes = 1; + repeated string stream_muxers = 2; } message NoiseHandshakePayload { diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index f1286b9ffb..ce8d97cdc8 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -40,6 +40,9 @@ type secureSession struct { prologue []byte initiatorEarlyDataHandler, responderEarlyDataHandler EarlyDataHandler + + // ConnectionState holds state information releated to the secureSession entity. + connectionState network.ConnectionState } // newSecureSession creates a Noise session over the given insecureConn Conn, using @@ -110,7 +113,7 @@ func (s *secureSession) RemotePublicKey() crypto.PubKey { } func (s *secureSession) ConnState() network.ConnectionState { - return network.ConnectionState{} + return s.connectionState } func (s *secureSession) SetDeadline(t time.Time) error { @@ -128,3 +131,10 @@ func (s *secureSession) SetWriteDeadline(t time.Time) error { func (s *secureSession) Close() error { return s.insecureConn.Close() } + +func SessionWithConnState(s *secureSession, muxer string) *secureSession { + if s != nil { + s.connectionState.NextProto = muxer + } + return s +} diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index c6923698cc..7ae4d7deb1 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -7,13 +7,18 @@ import ( "github.com/libp2p/go-libp2p/core/canonicallog" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" + "github.com/libp2p/go-libp2p/p2p/security/noise/pb" manet "github.com/multiformats/go-multiaddr/net" ) // ID is the protocol ID for noise -const ID = "/noise" +const ( + ID = "/noise" + maxProtoNum = 100 +) var _ sec.SecureTransport = &Transport{} @@ -22,38 +27,51 @@ var _ sec.SecureTransport = &Transport{} type Transport struct { localID peer.ID privateKey crypto.PrivKey + muxers []string } // New creates a new Noise transport using the given private key as its // libp2p identity key. -func New(privkey crypto.PrivKey) (*Transport, error) { +func New(privkey crypto.PrivKey, muxers []protocol.ID) (*Transport, error) { localID, err := peer.IDFromPrivateKey(privkey) if err != nil { return nil, err } + smuxers := make([]string, 0, len(muxers)) + for _, muxer := range muxers { + smuxers = append(smuxers, string(muxer)) + } + return &Transport{ localID: localID, privateKey: privkey, + muxers: smuxers, }, nil } // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - c, err := newSecureSession(t, ctx, insecure, p, nil, nil, nil, false) + responderEDH := newTransportEDH(t) + c, err := newSecureSession(t, ctx, insecure, p, nil, nil, responderEDH, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) if maErr == nil { canonicallog.LogPeerStatus(100, p, addr, "handshake_failure", "noise", "err", err.Error()) } } - return c, err + return SessionWithConnState(c, responderEDH.MatchMuxers(false)), err } // SecureOutbound runs the Noise handshake as the initiator. func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - return newSecureSession(t, ctx, insecure, p, nil, nil, nil, true) + initiatorEDH := newTransportEDH(t) + c, err := newSecureSession(t, ctx, insecure, p, nil, initiatorEDH, nil, true) + if err != nil { + return c, err + } + return SessionWithConnState(c, initiatorEDH.MatchMuxers(true)), err } func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTransport, error) { @@ -65,3 +83,46 @@ func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTranspo } return st, nil } + +func matchMuxers(initiatorMuxers, responderMuxers []string) string { + for _, muxer := range responderMuxers { + for _, initMuxer := range initiatorMuxers { + if initMuxer == muxer { + return muxer + } + } + } + return "" +} + +type transportEarlyDataHandler struct { + transport *Transport + receivedMuxers []string +} + +var _ EarlyDataHandler = &transportEarlyDataHandler{} + +func newTransportEDH(t *Transport) *transportEarlyDataHandler { + return &transportEarlyDataHandler{transport: t} +} + +func (i *transportEarlyDataHandler) Send(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions { + return &pb.NoiseExtensions{ + StreamMuxers: i.transport.muxers, + } +} + +func (i *transportEarlyDataHandler) Received(_ context.Context, _ net.Conn, extension *pb.NoiseExtensions) error { + // Discard messages with size or the number of protocols exceeding extension limit for security. + if extension != nil && len(extension.StreamMuxers) <= maxProtoNum { + i.receivedMuxers = extension.GetStreamMuxers() + } + return nil +} + +func (i *transportEarlyDataHandler) MatchMuxers(isInitiator bool) string { + if isInitiator { + return matchMuxers(i.transport.muxers, i.receivedMuxers) + } + return matchMuxers(i.receivedMuxers, i.transport.muxers) +} diff --git a/p2p/security/noise/transport_test.go b/p2p/security/noise/transport_test.go index 2fa90d06ef..c3b180470e 100644 --- a/p2p/security/noise/transport_test.go +++ b/p2p/security/noise/transport_test.go @@ -37,6 +37,12 @@ func newTestTransport(t *testing.T, typ, bits int) *Transport { } } +func newTestTransportWithMuxers(t *testing.T, typ, bits int, muxers []string) *Transport { + transport := newTestTransport(t, typ, bits) + transport.muxers = muxers + return transport +} + // Create a new pair of connected TCP sockets. func newConnPair(t *testing.T) (net.Conn, net.Conn) { lstnr, err := net.Listen("tcp", "localhost:0") @@ -586,3 +592,56 @@ func TestEarlyfffDataAcceptedWithNoHandler(t *testing.T) { require.NoError(t, err) } } + +type noiseEarlyDataTestCase struct { + initProtos []string + respProtos []string + expectedResult string +} + +func TestHandshakeWithTransportEarlyData(t *testing.T) { + tests := []noiseEarlyDataTestCase{ + {initProtos: nil, respProtos: nil, expectedResult: ""}, + {[]string{"muxer1"}, []string{"muxer1"}, "muxer1"}, + {[]string{"muxer1"}, []string{}, ""}, + {[]string{}, []string{"muxer2"}, ""}, + {[]string{"muxer2"}, []string{"muxer1"}, ""}, + {[]string{"muxer1/1.0.0", "muxer2/1.0.1"}, []string{"muxer2/1.0.1", "muxer1/1.0.0"}, "muxer2/1.0.1"}, + {[]string{"muxer1/1.0.0", "muxer2/1.0.1", "muxer3/1.0.0"}, []string{"muxer2/1.0.1", "muxer1/1.0.1", "muxer3/1.0.0"}, "muxer2/1.0.1"}, + {[]string{"muxer1/1.0.0", "muxer2/1.0.0"}, []string{"muxer3/1.0.0"}, ""}, + } + + noiseHandshake := func(t *testing.T, initProtos, respProtos []string, expectedProto string) { + initTransport := newTestTransportWithMuxers(t, crypto.Ed25519, 2048, initProtos) + respTransport := newTestTransportWithMuxers(t, crypto.Ed25519, 2048, respProtos) + + initConn, respConn := connect(t, initTransport, respTransport) + defer initConn.Close() + defer respConn.Close() + + require.Equal(t, expectedProto, initConn.connectionState.NextProto) + require.Equal(t, expectedProto, respConn.connectionState.NextProto) + + initData := []byte("Test data for noise transport") + _, err := initConn.Write(initData) + if err != nil { + t.Fatal(err) + } + + respData := make([]byte, len(initData)) + _, err = respConn.Read(respData) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(initData, respData) { + t.Errorf("Data transmitted mismatch over noise session. %v != %v", initData, respData) + } + } + + for _, test := range tests { + t.Run("Transport EarlyData Test", func(t *testing.T) { + noiseHandshake(t, test.initProtos, test.respProtos, test.expectedResult) + }) + } +} diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 1961e9cec9..e41a414a1c 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -79,7 +79,7 @@ func newSecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) { t.Fatal(err) } var secMuxer csms.SSMuxer - noiseTpt, err := noise.New(priv) + noiseTpt, err := noise.New(priv, nil) require.NoError(t, err) secMuxer.AddTransport(noise.ID, noiseTpt) return id, &secMuxer diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 5cd3a88170..c67bd960b6 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -102,7 +102,7 @@ func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceMa return nil, err } } - n, err := noise.New(key) + n, err := noise.New(key, nil) if err != nil { return nil, err }