From 6d501e9f31be91125d8c935ed3903dfa88fbb323 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Thu, 13 Apr 2023 11:56:48 +0200 Subject: [PATCH] server: support TCP read requests without interleaved IDs (https://github.com/aler9/mediamtx/issues/1650) --- pkg/liberrors/server.go | 8 ---- server_play_test.go | 99 +++++++++++++++++++++++++++++++++++++++++ server_session.go | 36 +++++++++------ 3 files changed, 122 insertions(+), 21 deletions(-) diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index 1253f2f9..f1f404ed 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -124,14 +124,6 @@ func (e ErrServerTransportHeaderNoClientPorts) Error() string { return "transport header does not contain client ports" } -// ErrServerTransportHeaderNoInterleavedIDs is an error that can be returned by a server. -type ErrServerTransportHeaderNoInterleavedIDs struct{} - -// Error implements the error interface. -func (e ErrServerTransportHeaderNoInterleavedIDs) Error() string { - return "transport header does not contain interleaved IDs" -} - // ErrServerTransportHeaderInvalidInterleavedIDs is an error that can be returned by a server. type ErrServerTransportHeaderInvalidInterleavedIDs struct{} diff --git a/server_play_test.go b/server_play_test.go index 9ca240e2..9692d087 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -1970,3 +1970,102 @@ func TestServerPlayAdditionalInfos(t *testing.T) { }(), }, ssrcs) } + +func TestServerPlayNoInterleavedIDs(t *testing.T) { + forma := &formats.Generic{ + PayloadTyp: 96, + RTPMap: "private/90000", + } + err := forma.Init() + require.NoError(t, err) + + stream := NewServerStream(media.Medias{ + &media.Media{ + Type: "application", + Formats: []formats.Format{forma}, + }, + &media.Media{ + Type: "application", + Formats: []formats.Format{forma}, + }, + }) + defer stream.Close() + + s := &Server{ + Handler: &testServerHandler{ + onDescribe: func(ctx *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + RTSPAddress: "localhost:8554", + } + + err = s.Start() + require.NoError(t, err) + defer s.Close() + + nconn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + desc := doDescribe(t, conn) + + inTH := &headers.Transport{ + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + Protocol: headers.TransportProtocolTCP, + } + + res, th := doSetup(t, conn, absoluteControlAttribute(desc.MediaDescriptions[0]), inTH) + + require.Equal(t, &[2]int{0, 1}, th.InterleavedIDs) + + session := readSession(t, res) + + res, err = writeReqReadRes(conn, base.Request{ + Method: base.Setup, + URL: mustParseURL(absoluteControlAttribute(desc.MediaDescriptions[1])), + Header: base.Header{ + "CSeq": base.HeaderValue{"3"}, + "Transport": inTH.Marshal(), + "Session": base.HeaderValue{session}, + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + err = th.Unmarshal(res.Header["Transport"]) + require.NoError(t, err) + + require.Equal(t, &[2]int{2, 3}, th.InterleavedIDs) + + doPlay(t, conn, "rtsp://localhost:8554/teststream", session) + + for i := 0; i < 2; i++ { + stream.WritePacketRTP(stream.Medias()[i], &testRTPPacket) + + f, err := conn.ReadInterleavedFrame() + require.NoError(t, err) + require.Equal(t, i*2, f.Channel) + require.Equal(t, testRTPPacketMarshaled, f.Payload) + } +} diff --git a/server_session.go b/server_session.go index cf40ec40..102ecbfb 100644 --- a/server_session.go +++ b/server_session.go @@ -112,22 +112,28 @@ func findAndValidateTransport(inTH *headers.Transport, return TransportUDP, nil } - if inTH.InterleavedIDs == nil { - return 0, liberrors.ErrServerTransportHeaderNoInterleavedIDs{} - } - - if (inTH.InterleavedIDs[0]%2) != 0 || - (inTH.InterleavedIDs[0]+1) != inTH.InterleavedIDs[1] { - return 0, liberrors.ErrServerTransportHeaderInvalidInterleavedIDs{} - } + if inTH.InterleavedIDs != nil { + if (inTH.InterleavedIDs[0]%2) != 0 || + (inTH.InterleavedIDs[0]+1) != inTH.InterleavedIDs[1] { + return 0, liberrors.ErrServerTransportHeaderInvalidInterleavedIDs{} + } - if _, ok := tcpMediasByChannel[inTH.InterleavedIDs[0]]; ok { - return 0, liberrors.ErrServerTransportHeaderInterleavedIDsAlreadyUsed{} + if _, ok := tcpMediasByChannel[inTH.InterleavedIDs[0]]; ok { + return 0, liberrors.ErrServerTransportHeaderInterleavedIDsAlreadyUsed{} + } } return TransportTCP, nil } +func findFreeChannel(tcpMediasByChannel map[int]*serverSessionMedia) int { + for i := 0; ; i += 2 { + if _, ok := tcpMediasByChannel[i]; !ok { + return i + } + } +} + // ServerSessionState is a state of a ServerSession. type ServerSessionState int @@ -823,18 +829,22 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort} default: // TCP - sm.tcpChannel = inTH.InterleavedIDs[0] + if inTH.InterleavedIDs != nil { + sm.tcpChannel = inTH.InterleavedIDs[0] + } else { + sm.tcpChannel = findFreeChannel(ss.tcpMediasByChannel) + } if ss.tcpMediasByChannel == nil { ss.tcpMediasByChannel = make(map[int]*serverSessionMedia) } - ss.tcpMediasByChannel[inTH.InterleavedIDs[0]] = sm + ss.tcpMediasByChannel[sm.tcpChannel] = sm th.Protocol = headers.TransportProtocolTCP de := headers.TransportDeliveryUnicast th.Delivery = &de - th.InterleavedIDs = inTH.InterleavedIDs + th.InterleavedIDs = &[2]int{sm.tcpChannel, sm.tcpChannel + 1} } if ss.setuppedMedias == nil {