Skip to content

Commit

Permalink
webrtc muxer: fix multiple race conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 committed Jan 8, 2023
1 parent 2de0941 commit f3f5545
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 28 deletions.
6 changes: 5 additions & 1 deletion internal/core/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package core
import (
"context"
"encoding/json"
"log"
"net"
"net/http"
"reflect"
Expand Down Expand Up @@ -201,7 +202,10 @@ func newAPI(
group.POST("/v1/webrtcconns/kick/:id", a.onWebRTCConnsKick)
}

a.s = &http.Server{Handler: router}
a.s = &http.Server{
Handler: router,
ErrorLog: log.New(&nilWriter{}, "", 0),
}

go a.s.Serve(ln)

Expand Down
25 changes: 25 additions & 0 deletions internal/core/http_requestpool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package core

import (
"sync"

"github.com/gin-gonic/gin"
)

type httpRequestPool struct {
wg sync.WaitGroup
}

func newHTTPRequestPool() *httpRequestPool {
return &httpRequestPool{}
}

func (rp *httpRequestPool) mw(ctx *gin.Context) {
rp.wg.Add(1)
ctx.Next()
rp.wg.Done()
}

func (rp *httpRequestPool) close() {
rp.wg.Wait()
}
6 changes: 5 additions & 1 deletion internal/core/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package core
import (
"context"
"io"
"log"
"net"
"net/http"
"strconv"
Expand Down Expand Up @@ -53,7 +54,10 @@ func newMetrics(
router.SetTrustedProxies(nil)
router.GET("/metrics", m.onMetrics)

m.server = &http.Server{Handler: router}
m.server = &http.Server{
Handler: router,
ErrorLog: log.New(&nilWriter{}, "", 0),
}

m.log(logger.Info, "listener opened on "+address)

Expand Down
4 changes: 3 additions & 1 deletion internal/core/pprof.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package core

import (
"context"
"log"
"net"
"net/http"

Expand Down Expand Up @@ -37,7 +38,8 @@ func newPPROF(
}

pp.server = &http.Server{
Handler: http.DefaultServeMux,
Handler: http.DefaultServeMux,
ErrorLog: log.New(&nilWriter{}, "", 0),
}

pp.log(logger.Info, "listener opened on "+address)
Expand Down
19 changes: 14 additions & 5 deletions internal/core/webrtc_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ type webRTCConn struct {
created time.Time
curPC *webrtc.PeerConnection
mutex sync.RWMutex

closed chan struct{}
}

func newWebRTCConn(
Expand Down Expand Up @@ -138,6 +140,7 @@ func newWebRTCConn(
iceUDPMux: iceUDPMux,
iceTCPMux: iceTCPMux,
iceHostNAT1To1IPs: iceHostNAT1To1IPs,
closed: make(chan struct{}),
}

c.log(logger.Info, "opened")
Expand All @@ -152,6 +155,10 @@ func (c *webRTCConn) close() {
c.ctxCancel()
}

func (c *webRTCConn) wait() {
<-c.closed
}

func (c *webRTCConn) remoteAddr() net.Addr {
return c.wsconn.RemoteAddr()
}
Expand Down Expand Up @@ -250,6 +257,7 @@ func (c *webRTCConn) log(level logger.Level, format string, args ...interface{})
}

func (c *webRTCConn) run() {
defer close(c.closed)
defer c.wg.Done()

innerCtx, innerCtxCancel := context.WithCancel(c.ctx)
Expand Down Expand Up @@ -277,11 +285,6 @@ func (c *webRTCConn) run() {
}

func (c *webRTCConn) runInner(ctx context.Context) error {
go func() {
<-ctx.Done()
c.wsconn.Close()
}()

res := c.pathManager.readerAdd(pathReaderAddReq{
author: c,
pathName: c.pathName,
Expand Down Expand Up @@ -348,6 +351,12 @@ func (c *webRTCConn) runInner(ctx context.Context) error {
pcClosed := make(chan struct{})

pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
select {
case <-pcClosed:
return
default:
}

c.log(logger.Debug, "peer connection state: "+state.String())

switch state {
Expand Down
50 changes: 38 additions & 12 deletions internal/core/webrtc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type webRTCServerAPIConnsKickReq struct {
type webRTCConnNewReq struct {
pathName string
wsconn *websocket.Conn
res chan *webRTCConn
}

type webRTCServerParent interface {
Expand All @@ -84,7 +85,6 @@ type webRTCServer struct {

ctx context.Context
ctxCancel func()
wg sync.WaitGroup
ln net.Listener
udpMuxLn net.PacketConn
tcpMuxLn net.Listener
Expand All @@ -99,6 +99,9 @@ type webRTCServer struct {
chConnClose chan *webRTCConn
chAPIConnsList chan webRTCServerAPIConnsListReq
chAPIConnsKick chan webRTCServerAPIConnsKickReq

// out
done chan struct{}
}

func newWebRTCServer(
Expand Down Expand Up @@ -182,6 +185,7 @@ func newWebRTCServer(
chConnClose: make(chan *webRTCConn),
chAPIConnsList: make(chan webRTCServerAPIConnsListReq),
chAPIConnsKick: make(chan webRTCServerAPIConnsKickReq),
done: make(chan struct{}),
}

str := "listener opened on " + address + " (HTTP)"
Expand All @@ -197,7 +201,6 @@ func newWebRTCServer(
s.metrics.webRTCServerSet(s)
}

s.wg.Add(1)
go s.run()

return s, nil
Expand All @@ -211,14 +214,17 @@ func (s *webRTCServer) log(level logger.Level, format string, args ...interface{
func (s *webRTCServer) close() {
s.log(logger.Info, "listener is closing")
s.ctxCancel()
s.wg.Wait()
<-s.done
}

func (s *webRTCServer) run() {
defer s.wg.Done()
defer close(s.done)

rp := newHTTPRequestPool()
defer rp.close()

router := gin.New()
router.NoRoute(httpLoggerMiddleware(s), s.onRequest)
router.NoRoute(rp.mw, httpLoggerMiddleware(s), s.onRequest)

tmp := make([]string, len(s.trustedProxies))
for i, entry := range s.trustedProxies {
Expand All @@ -238,6 +244,8 @@ func (s *webRTCServer) run() {
go hs.Serve(s.ln)
}

var wg sync.WaitGroup

outer:
for {
select {
Expand All @@ -248,14 +256,15 @@ outer:
req.pathName,
req.wsconn,
s.stunServers,
&s.wg,
&wg,
s.pathManager,
s,
s.iceHostNAT1To1IPs,
s.iceUDPMux,
s.iceTCPMux,
)
s.conns[c] = struct{}{}
req.res <- c

case conn := <-s.chConnClose:
delete(s.conns, conn)
Expand Down Expand Up @@ -306,6 +315,8 @@ outer:
hs.Shutdown(context.Background())
s.ln.Close() // in case Shutdown() is called before Serve()

wg.Wait()

if s.udpMuxLn != nil {
s.udpMuxLn.Close()
}
Expand Down Expand Up @@ -389,14 +400,29 @@ func (s *webRTCServer) onRequest(ctx *gin.Context) {
if err != nil {
return
}
defer wsconn.Close()

select {
case s.connNew <- webRTCConnNewReq{
pathName: dir,
wsconn: wsconn,
}:
case <-s.ctx.Done():
c := s.newConn(dir, wsconn)
if c == nil {
return
}

c.wait()
}
}

func (s *webRTCServer) newConn(dir string, wsconn *websocket.Conn) *webRTCConn {
req := webRTCConnNewReq{
pathName: dir,
wsconn: wsconn,
res: make(chan *webRTCConn),
}

select {
case s.connNew <- req:
return <-req.res
case <-s.ctx.Done():
return nil
}
}

Expand Down
32 changes: 24 additions & 8 deletions internal/core/webrtc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@ import (
)

type webRTCTestClient struct {
wc *websocket.Conn
pc *webrtc.PeerConnection
track chan *webrtc.TrackRemote
wc *websocket.Conn
pc *webrtc.PeerConnection
track chan *webrtc.TrackRemote
closed chan struct{}
}

func newWebRTCTestClient(addr string) (*webRTCTestClient, error) {
wc, _, err := websocket.DefaultDialer.Dial(addr, nil) //nolint:bodyclose
wc, res, err := websocket.DefaultDialer.Dial(addr, nil)
if err != nil {
return nil, err
}
defer res.Body.Close()

_, msg, err := wc.ReadMessage()
if err != nil {
Expand Down Expand Up @@ -55,13 +57,25 @@ func newWebRTCTestClient(addr string) (*webRTCTestClient, error) {
})

connected := make(chan struct{})
closed := make(chan struct{})

pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
if state == webrtc.PeerConnectionStateConnected {
switch state {
case webrtc.PeerConnectionStateConnected:
close(connected)

case webrtc.PeerConnectionStateClosed:
select {
case <-closed:
return
default:
}
close(closed)
}
})

track := make(chan *webrtc.TrackRemote, 1)

pc.OnTrack(func(trak *webrtc.TrackRemote, recv *webrtc.RTPReceiver) {
track <- trak
})
Expand Down Expand Up @@ -143,15 +157,17 @@ func newWebRTCTestClient(addr string) (*webRTCTestClient, error) {
<-connected

return &webRTCTestClient{
wc: wc,
pc: pc,
track: track,
wc: wc,
pc: pc,
track: track,
closed: closed,
}, nil
}

func (c *webRTCTestClient) close() {
c.pc.Close()
c.wc.Close()
<-c.closed
}

func TestWebRTCServer(t *testing.T) {
Expand Down

0 comments on commit f3f5545

Please sign in to comment.