Skip to content

Commit

Permalink
webrtc: close connection when remote closes (#2914)
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt authored Aug 15, 2024
1 parent 8a11b7c commit fda0eca
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 45 deletions.
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ require (
github.com/multiformats/go-varint v0.0.7
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58
github.com/pion/datachannel v1.5.8
github.com/pion/ice/v2 v2.3.32
github.com/pion/ice/v2 v2.3.34
github.com/pion/logging v0.2.2
github.com/pion/sctp v1.8.20
github.com/pion/stun v0.6.1
github.com/pion/webrtc/v3 v3.2.50
github.com/pion/webrtc/v3 v3.3.0
github.com/prometheus/client_golang v1.19.1
github.com/prometheus/client_model v0.6.1
github.com/quic-go/quic-go v0.45.2
Expand Down Expand Up @@ -111,7 +111,7 @@ require (
github.com/pion/rtp v1.8.8 // indirect
github.com/pion/sdp/v3 v3.0.9 // indirect
github.com/pion/srtp/v2 v2.0.20 // indirect
github.com/pion/transport/v2 v2.2.9 // indirect
github.com/pion/transport/v2 v2.2.10 // indirect
github.com/pion/turn/v2 v2.1.6 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
Expand Down
13 changes: 6 additions & 7 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ github.com/pion/datachannel v1.5.8/go.mod h1:PgmdpoaNBLX9HNzNClmdki4DYW5JtI7Yibu
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk=
github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
github.com/pion/ice/v2 v2.3.32 h1:VwE/uEeqiMm0zUWpdt1DJtnqEkj3UjEbhX92/CurtWI=
github.com/pion/ice/v2 v2.3.32/go.mod h1:8fac0+qftclGy1tYd/nfwfHC729BLaxtVqMdMVCAVPU=
github.com/pion/ice/v2 v2.3.34 h1:Ic1ppYCj4tUOcPAp76U6F3fVrlSw8A9JtRXLqw6BbUM=
github.com/pion/ice/v2 v2.3.34/go.mod h1:mBF7lnigdqgtB+YHkaY/Y6s6tsyRyo4u4rPGRuOjUBQ=
github.com/pion/interceptor v0.1.29 h1:39fsnlP1U8gw2JzOFWdfCU82vHvhW9o0rZnZF56wF+M=
github.com/pion/interceptor v0.1.29/go.mod h1:ri+LGNjRUc5xUNtDEPzfdkmSqISixVTBF/z/Zms/6T4=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
Expand All @@ -307,17 +307,16 @@ github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
github.com/pion/transport/v2 v2.2.3/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
github.com/pion/transport/v2 v2.2.8/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E=
github.com/pion/transport/v2 v2.2.9 h1:WEDygVovkJlV2CCunM9KS2kds+kcl7zdIefQA5y/nkE=
github.com/pion/transport/v2 v2.2.9/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E=
github.com/pion/transport/v2 v2.2.10 h1:ucLBLE8nuxiHfvkFKnkDQRYWYfp8ejf4YBOPfaQpw6Q=
github.com/pion/transport/v2 v2.2.10/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E=
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
github.com/pion/transport/v3 v3.0.6 h1:k1mQU06bmmX143qSWgXFqSH1KUJceQvIUuVH/K5ELWw=
github.com/pion/transport/v3 v3.0.6/go.mod h1:HvJr2N/JwNJAfipsRleqwFoR3t/pWyHeZUs89v3+t5s=
github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY=
github.com/pion/turn/v2 v2.1.6 h1:Xr2niVsiPTB0FPtt+yAWKFUkU1eotQbGgpTIld4x1Gc=
github.com/pion/turn/v2 v2.1.6/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY=
github.com/pion/webrtc/v3 v3.2.50 h1:C/rwL2mBfCxHv6tlLzDAO3krJpQXfVx8A8WHnGJ2j34=
github.com/pion/webrtc/v3 v3.2.50/go.mod h1:dytYYoSBy7ZUWhJMbndx9UckgYvzNAfL7xgVnrIKxqo=
github.com/pion/webrtc/v3 v3.3.0 h1:Rf4u6n6U5t5sUxhYPQk/samzU/oDv7jk6BA5hyO2F9I=
github.com/pion/webrtc/v3 v3.3.0/go.mod h1:hVmrDJvwhEertRWObeb1xzulzHGeVUoPlWvxdGzcfU0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
Expand Down
28 changes: 28 additions & 0 deletions p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -770,3 +770,31 @@ func TestConnDroppedWhenBlocked(t *testing.T) {
})
}
}

// TestConnClosedWhenRemoteCloses tests that a connection is closed locally when it's closed by remote
func TestConnClosedWhenRemoteCloses(t *testing.T) {
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
server := tc.HostGenerator(t, TransportTestCaseOpts{})
client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
defer server.Close()
defer client.Close()

client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := client.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: server.Addrs()})
require.NoError(t, err)

require.Eventually(t, func() bool {
return server.Network().Connectedness(client.ID()) != network.NotConnected
}, 5*time.Second, 50*time.Millisecond)
for _, c := range client.Network().ConnsToPeer(server.ID()) {
c.Close()
}
require.Eventually(t, func() bool {
return server.Network().Connectedness(client.ID()) == network.NotConnected
}, 5*time.Second, 50*time.Millisecond)
})
}
}
60 changes: 41 additions & 19 deletions p2p/transport/webrtc/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

ma "github.com/multiformats/go-multiaddr"
"github.com/pion/datachannel"
"github.com/pion/sctp"
"github.com/pion/webrtc/v3"
)

Expand All @@ -31,6 +32,8 @@ func (errConnectionTimeout) Error() string { return "connection timeout" }
func (errConnectionTimeout) Timeout() bool { return true }
func (errConnectionTimeout) Temporary() bool { return false }

var errConnClosed = errors.New("connection closed")

type dataChannel struct {
stream datachannel.ReadWriteCloser
channel *webrtc.DataChannel
Expand Down Expand Up @@ -74,6 +77,7 @@ func newConnection(
remoteKey ic.PubKey,
remoteMultiaddr ma.Multiaddr,
incomingDataChannels chan dataChannel,
peerConnectionClosedCh chan struct{},
) (*connection, error) {
ctx, cancel := context.WithCancel(context.Background())
c := &connection{
Expand Down Expand Up @@ -102,6 +106,18 @@ func newConnection(
}

pc.OnConnectionStateChange(c.onConnectionStateChange)
pc.SCTP().OnClose(func(err error) {
if err != nil {
c.closeWithError(fmt.Errorf("%w: %w", errConnClosed, err))
}
c.closeWithError(errConnClosed)
})
select {
case <-peerConnectionClosedCh:
c.Close()
return nil, errConnClosed
default:
}
return c, nil
}

Expand All @@ -112,27 +128,29 @@ func (c *connection) ConnState() network.ConnectionState {

// Close closes the underlying peerconnection.
func (c *connection) Close() error {
c.closeOnce.Do(func() { c.closeWithError(errors.New("connection closed")) })
c.closeWithError(errConnClosed)
return nil
}

// closeWithError is used to Close the connection when the underlying DTLS connection fails
func (c *connection) closeWithError(err error) {
c.closeErr = err
// cancel must be called after closeErr is set. This ensures interested goroutines waiting on
// ctx.Done can read closeErr without holding the conn lock.
c.cancel()
// closing peerconnection will close the datachannels associated with the streams
c.pc.Close()

c.m.Lock()
streams := c.streams
c.streams = nil
c.m.Unlock()
for _, s := range streams {
s.closeForShutdown(err)
}
c.scope.Done()
c.closeOnce.Do(func() {
c.closeErr = err
// cancel must be called after closeErr is set. This ensures interested goroutines waiting on
// ctx.Done can read closeErr without holding the conn lock.
c.cancel()
// closing peerconnection will close the datachannels associated with the streams
c.pc.Close()

c.m.Lock()
streams := c.streams
c.streams = nil
c.m.Unlock()
for _, s := range streams {
s.closeForShutdown(err)
}
c.scope.Done()
})
}

func (c *connection) IsClosed() bool {
Expand All @@ -155,6 +173,12 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error
}
rwc, err := c.detachChannel(ctx, dc)
if err != nil {
// There's a race between webrtc.SCTP.OnClose callback and the underlying
// association closing. It's nicer to close the connection here.
if errors.Is(err, sctp.ErrStreamClosed) {
c.closeWithError(errConnClosed)
return nil, c.closeErr
}
dc.Close()
return nil, fmt.Errorf("detach channel failed for stream(%d): %w", streamID, err)
}
Expand Down Expand Up @@ -209,9 +233,7 @@ func (c *connection) removeStream(id uint16) {

func (c *connection) onConnectionStateChange(state webrtc.PeerConnectionState) {
if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed {
c.closeOnce.Do(func() {
c.closeWithError(errConnectionTimeout{})
})
c.closeWithError(errConnectionTimeout{})
}
}

Expand Down
1 change: 1 addition & 0 deletions p2p/transport/webrtc/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ func (l *listener) setupConnection(
remotePubKey,
remoteMultiaddr,
w.IncomingDataChannels,
w.PeerConnectionClosedCh,
)
if err != nil {
return nil, err
Expand Down
25 changes: 19 additions & 6 deletions p2p/transport/webrtc/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
}
if tConn != nil {
_ = tConn.Close()
tConn = nil
}
}
}()
Expand Down Expand Up @@ -399,6 +400,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
remotePubKey,
remoteMultiaddrWithoutCerthash,
w.IncomingDataChannels,
w.PeerConnectionClosedCh,
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -572,9 +574,10 @@ func detachHandshakeDataChannel(ctx context.Context, dc *webrtc.DataChannel) (da
// a small window of time where datachannels created by the peer may not surface to us and cause a
// memory leak.
type webRTCConnection struct {
PeerConnection *webrtc.PeerConnection
HandshakeDataChannel *webrtc.DataChannel
IncomingDataChannels chan dataChannel
PeerConnection *webrtc.PeerConnection
HandshakeDataChannel *webrtc.DataChannel
IncomingDataChannels chan dataChannel
PeerConnectionClosedCh chan struct{}
}

func newWebRTCConnection(settings webrtc.SettingEngine, config webrtc.Configuration) (webRTCConnection, error) {
Expand Down Expand Up @@ -613,10 +616,20 @@ func newWebRTCConnection(settings webrtc.SettingEngine, config webrtc.Configurat
}
})
})

connectionClosedCh := make(chan struct{}, 1)
pc.SCTP().OnClose(func(err error) {
// We only need one message. Closing a connection is a problem as pion might invoke the callback more than once.
select {
case connectionClosedCh <- struct{}{}:
default:
}
})
return webRTCConnection{
PeerConnection: pc,
HandshakeDataChannel: handshakeDataChannel,
IncomingDataChannels: incomingDataChannels,
PeerConnection: pc,
HandshakeDataChannel: handshakeDataChannel,
IncomingDataChannels: incomingDataChannels,
PeerConnectionClosedCh: connectionClosedCh,
}, nil
}

Expand Down
25 changes: 25 additions & 0 deletions p2p/transport/webrtc/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1009,3 +1009,28 @@ func TestManyConnections(t *testing.T) {
}
}
}

func TestConnectionClosedWhenRemoteCloses(t *testing.T) {
listenT, p := getTransport(t)
listener, err := listenT.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct"))
require.NoError(t, err)

dialer, _ := getTransport(t)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := listener.Accept()
if !assert.NoError(t, err) {
return
}
assert.Eventually(t, func() bool {
return c.IsClosed()
}, 5*time.Second, 50*time.Millisecond)
}()

c, err := dialer.Dial(context.Background(), listener.Multiaddr(), p)
require.NoError(t, err)
c.Close()
wg.Wait()
}
6 changes: 3 additions & 3 deletions test-plans/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ require (
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect
github.com/pion/datachannel v1.5.8 // indirect
github.com/pion/dtls/v2 v2.2.12 // indirect
github.com/pion/ice/v2 v2.3.32 // indirect
github.com/pion/ice/v2 v2.3.34 // indirect
github.com/pion/interceptor v0.1.29 // indirect
github.com/pion/logging v0.2.2 // indirect
github.com/pion/mdns v0.0.12 // indirect
Expand All @@ -78,9 +78,9 @@ require (
github.com/pion/sdp/v3 v3.0.9 // indirect
github.com/pion/srtp/v2 v2.0.20 // indirect
github.com/pion/stun v0.6.1 // indirect
github.com/pion/transport/v2 v2.2.9 // indirect
github.com/pion/transport/v2 v2.2.10 // indirect
github.com/pion/turn/v2 v2.1.6 // indirect
github.com/pion/webrtc/v3 v3.2.50 // indirect
github.com/pion/webrtc/v3 v3.3.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.19.1 // indirect
Expand Down
13 changes: 6 additions & 7 deletions test-plans/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ github.com/pion/datachannel v1.5.8/go.mod h1:PgmdpoaNBLX9HNzNClmdki4DYW5JtI7Yibu
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk=
github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
github.com/pion/ice/v2 v2.3.32 h1:VwE/uEeqiMm0zUWpdt1DJtnqEkj3UjEbhX92/CurtWI=
github.com/pion/ice/v2 v2.3.32/go.mod h1:8fac0+qftclGy1tYd/nfwfHC729BLaxtVqMdMVCAVPU=
github.com/pion/ice/v2 v2.3.34 h1:Ic1ppYCj4tUOcPAp76U6F3fVrlSw8A9JtRXLqw6BbUM=
github.com/pion/ice/v2 v2.3.34/go.mod h1:mBF7lnigdqgtB+YHkaY/Y6s6tsyRyo4u4rPGRuOjUBQ=
github.com/pion/interceptor v0.1.29 h1:39fsnlP1U8gw2JzOFWdfCU82vHvhW9o0rZnZF56wF+M=
github.com/pion/interceptor v0.1.29/go.mod h1:ri+LGNjRUc5xUNtDEPzfdkmSqISixVTBF/z/Zms/6T4=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
Expand All @@ -253,17 +253,16 @@ github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
github.com/pion/transport/v2 v2.2.3/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
github.com/pion/transport/v2 v2.2.8/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E=
github.com/pion/transport/v2 v2.2.9 h1:WEDygVovkJlV2CCunM9KS2kds+kcl7zdIefQA5y/nkE=
github.com/pion/transport/v2 v2.2.9/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E=
github.com/pion/transport/v2 v2.2.10 h1:ucLBLE8nuxiHfvkFKnkDQRYWYfp8ejf4YBOPfaQpw6Q=
github.com/pion/transport/v2 v2.2.10/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E=
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
github.com/pion/transport/v3 v3.0.6 h1:k1mQU06bmmX143qSWgXFqSH1KUJceQvIUuVH/K5ELWw=
github.com/pion/transport/v3 v3.0.6/go.mod h1:HvJr2N/JwNJAfipsRleqwFoR3t/pWyHeZUs89v3+t5s=
github.com/pion/turn/v2 v2.1.3/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY=
github.com/pion/turn/v2 v2.1.6 h1:Xr2niVsiPTB0FPtt+yAWKFUkU1eotQbGgpTIld4x1Gc=
github.com/pion/turn/v2 v2.1.6/go.mod h1:huEpByKKHix2/b9kmTAM3YoX6MKP+/D//0ClgUYR2fY=
github.com/pion/webrtc/v3 v3.2.50 h1:C/rwL2mBfCxHv6tlLzDAO3krJpQXfVx8A8WHnGJ2j34=
github.com/pion/webrtc/v3 v3.2.50/go.mod h1:dytYYoSBy7ZUWhJMbndx9UckgYvzNAfL7xgVnrIKxqo=
github.com/pion/webrtc/v3 v3.3.0 h1:Rf4u6n6U5t5sUxhYPQk/samzU/oDv7jk6BA5hyO2F9I=
github.com/pion/webrtc/v3 v3.3.0/go.mod h1:hVmrDJvwhEertRWObeb1xzulzHGeVUoPlWvxdGzcfU0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
Expand Down

0 comments on commit fda0eca

Please sign in to comment.