diff --git a/internal/protocols/webrtc/peer_connection.go b/internal/protocols/webrtc/peer_connection.go index db1f1cb7258..850ca3456cb 100644 --- a/internal/protocols/webrtc/peer_connection.go +++ b/internal/protocols/webrtc/peer_connection.go @@ -10,6 +10,7 @@ import ( "github.com/pion/ice/v2" "github.com/pion/interceptor" + "github.com/pion/sdp/v3" "github.com/pion/webrtc/v3" "github.com/bluenviron/mediamtx/internal/conf" @@ -29,6 +30,37 @@ func stringInSlice(a string, list []string) bool { return false } +// TracksAreValid checks whether tracks in the SDP are valid +func TracksAreValid(medias []*sdp.MediaDescription) error { + videoTrack := false + audioTrack := false + + for _, media := range medias { + switch media.MediaName.Media { + case "video": + if videoTrack { + return fmt.Errorf("only a single video and a single audio track are supported") + } + videoTrack = true + + case "audio": + if audioTrack { + return fmt.Errorf("only a single video and a single audio track are supported") + } + audioTrack = true + + default: + return fmt.Errorf("unsupported media '%s'", media.MediaName.Media) + } + } + + if !videoTrack && !audioTrack { + return fmt.Errorf("no valid tracks count") + } + + return nil +} + type trackRecvPair struct { track *webrtc.TrackRemote receiver *webrtc.RTPReceiver @@ -334,10 +366,12 @@ outer: } // GatherIncomingTracks gathers incoming tracks. -func (co *PeerConnection) GatherIncomingTracks( - ctx context.Context, - maxCount int, -) ([]*IncomingTrack, error) { +func (co *PeerConnection) GatherIncomingTracks(ctx context.Context) ([]*IncomingTrack, error) { + var sdp sdp.SessionDescription + sdp.Unmarshal([]byte(co.wr.RemoteDescription().SDP)) //nolint:errcheck + + maxTrackCount := len(sdp.MediaDescriptions) + var tracks []*IncomingTrack t := time.NewTimer(time.Duration(co.TrackGatherTimeout)) @@ -346,7 +380,7 @@ func (co *PeerConnection) GatherIncomingTracks( for { select { case <-t.C: - if maxCount == 0 && len(tracks) != 0 { + if len(tracks) != 0 { return tracks, nil } return nil, fmt.Errorf("deadline exceeded while waiting tracks") @@ -358,7 +392,7 @@ func (co *PeerConnection) GatherIncomingTracks( } tracks = append(tracks, track) - if len(tracks) == maxCount || len(tracks) >= 2 { + if len(tracks) >= maxTrackCount { return tracks, nil } diff --git a/internal/protocols/webrtc/peer_connection_test.go b/internal/protocols/webrtc/peer_connection_test.go index 8c265dbd820..c85ff31c788 100644 --- a/internal/protocols/webrtc/peer_connection_test.go +++ b/internal/protocols/webrtc/peer_connection_test.go @@ -284,7 +284,7 @@ func TestPeerConnectionPublishRead(t *testing.T) { }) require.NoError(t, err) - inc, err := pc2.GatherIncomingTracks(context.Background(), 1) + inc, err := pc2.GatherIncomingTracks(context.Background()) require.NoError(t, err) require.Equal(t, ca.out, inc[0].Format()) diff --git a/internal/protocols/webrtc/track_count.go b/internal/protocols/webrtc/track_count.go deleted file mode 100644 index 99e9abea126..00000000000 --- a/internal/protocols/webrtc/track_count.go +++ /dev/null @@ -1,37 +0,0 @@ -package webrtc - -import ( - "fmt" - - "github.com/pion/sdp/v3" -) - -// TrackCount returns the track count. -func TrackCount(medias []*sdp.MediaDescription) (int, error) { - videoTrack := false - audioTrack := false - trackCount := 0 - - for _, media := range medias { - switch media.MediaName.Media { - case "video": - if videoTrack { - return 0, fmt.Errorf("only a single video and a single audio track are supported") - } - videoTrack = true - - case "audio": - if audioTrack { - return 0, fmt.Errorf("only a single video and a single audio track are supported") - } - audioTrack = true - - default: - return 0, fmt.Errorf("unsupported media '%s'", media.MediaName.Media) - } - - trackCount++ - } - - return trackCount, nil -} diff --git a/internal/protocols/webrtc/whip_client.go b/internal/protocols/webrtc/whip_client.go index 28c8cd0fc0a..5c1154b8e56 100644 --- a/internal/protocols/webrtc/whip_client.go +++ b/internal/protocols/webrtc/whip_client.go @@ -169,8 +169,7 @@ func (c *WHIPClient) Read(ctx context.Context) ([]*IncomingTrack, error) { return nil, err } - // check that there are at most two tracks - _, err = TrackCount(sdp.MediaDescriptions) + err = TracksAreValid(sdp.MediaDescriptions) if err != nil { c.deleteSession(context.Background()) //nolint:errcheck c.pc.Close() @@ -210,7 +209,7 @@ outer: } } - tracks, err := c.pc.GatherIncomingTracks(ctx, 0) + tracks, err := c.pc.GatherIncomingTracks(ctx) if err != nil { c.deleteSession(context.Background()) //nolint:errcheck c.pc.Close() diff --git a/internal/servers/webrtc/session.go b/internal/servers/webrtc/session.go index ed6a20664f8..2a574c21c37 100644 --- a/internal/servers/webrtc/session.go +++ b/internal/servers/webrtc/session.go @@ -461,7 +461,7 @@ func (s *session) runPublish() (int, error) { return http.StatusBadRequest, err } - trackCount, err := webrtc.TrackCount(sdp.MediaDescriptions) + err = webrtc.TracksAreValid(sdp.MediaDescriptions) if err != nil { // RFC draft-ietf-wish-whip // if the number of audio and or video @@ -489,7 +489,7 @@ func (s *session) runPublish() (int, error) { s.pc = pc s.mutex.Unlock() - tracks, err := pc.GatherIncomingTracks(s.ctx, trackCount) + tracks, err := pc.GatherIncomingTracks(s.ctx) if err != nil { return 0, err }