Skip to content

Commit

Permalink
server: fix panic when recording with wrong transport header (bluenvi…
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 committed Aug 23, 2024
1 parent 208d828 commit 8cd9e24
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 37 deletions.
83 changes: 51 additions & 32 deletions pkg/headers/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ const (
TransportProtocolTCP
)

// String implements fmt.Stringer.
func (p TransportProtocol) String() string {
if p == TransportProtocolUDP {
return "RTP/AVP"
}
return "RTP/AVP/TCP"
}

// TransportDelivery is a delivery method.
type TransportDelivery int

Expand All @@ -56,6 +64,14 @@ const (
TransportDeliveryMulticast
)

// String implements fmt.Stringer.
func (d TransportDelivery) String() string {
if d == TransportDeliveryUnicast {
return "unicast"
}
return "multicast"
}

// TransportMode is a transport mode.
type TransportMode int

Expand All @@ -67,6 +83,33 @@ const (
TransportModeRecord
)

func (m *TransportMode) unmarshal(v string) error {
str := strings.ToLower(v)

switch str {
case "play":
*m = TransportModePlay
return nil

// receive is an old alias for record, used by ffmpeg with the
// -listen flag, and by Darwin Streaming Server
case "record", "receive":
*m = TransportModeRecord
return nil

default:
return fmt.Errorf("invalid transport mode: '%s'", str)
}
}

// String implements fmt.Stringer.
func (m TransportMode) String() string {
if m == TransportModePlay {
return "play"
}
return "record"
}

// Transport is a Transport header.
type Transport struct {
// protocol of the stream
Expand Down Expand Up @@ -218,24 +261,12 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error {
}

case "mode":
str := strings.ToLower(v)
str = strings.TrimPrefix(str, "\"")
str = strings.TrimSuffix(str, "\"")

switch str {
case "play":
v := TransportModePlay
h.Mode = &v

// receive is an old alias for record, used by ffmpeg with the
// -listen flag, and by Darwin Streaming Server
case "record", "receive":
v := TransportModeRecord
h.Mode = &v

default:
return fmt.Errorf("invalid transport mode: '%s'", str)
var m TransportMode
err = m.unmarshal(v)
if err != nil {
return err
}
h.Mode = &m

default:
// ignore non-standard keys
Expand All @@ -253,18 +284,10 @@ func (h *Transport) Unmarshal(v base.HeaderValue) error {
func (h Transport) Marshal() base.HeaderValue {
var rets []string

if h.Protocol == TransportProtocolUDP {
rets = append(rets, "RTP/AVP")
} else {
rets = append(rets, "RTP/AVP/TCP")
}
rets = append(rets, h.Protocol.String())

if h.Delivery != nil {
if *h.Delivery == TransportDeliveryUnicast {
rets = append(rets, "unicast")
} else {
rets = append(rets, "multicast")
}
rets = append(rets, h.Delivery.String())
}

if h.Source != nil {
Expand Down Expand Up @@ -309,11 +332,7 @@ func (h Transport) Marshal() base.HeaderValue {
}

if h.Mode != nil {
if *h.Mode == TransportModePlay {
rets = append(rets, "mode=play")
} else {
rets = append(rets, "mode=record")
}
rets = append(rets, "mode="+h.Mode.String())
}

return base.HeaderValue{strings.Join(rets, ";")}
Expand Down
8 changes: 6 additions & 2 deletions pkg/liberrors/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,16 @@ func (e ErrServerMediaNotFound) Error() string {

// ErrServerTransportHeaderInvalidMode is an error that can be returned by a server.
type ErrServerTransportHeaderInvalidMode struct {
Mode headers.TransportMode
Mode *headers.TransportMode
}

// Error implements the error interface.
func (e ErrServerTransportHeaderInvalidMode) Error() string {
return fmt.Sprintf("transport header contains a invalid mode (%v)", e.Mode)
m := "null"
if e.Mode != nil {
m = e.Mode.String()
}
return fmt.Sprintf("transport header contains a invalid mode (%v)", m)

Check warning on line 96 in pkg/liberrors/server.go

View check run for this annotation

Codecov / codecov/patch

pkg/liberrors/server.go#L92-L96

Added lines #L92 - L96 were not covered by tests
}

// ErrServerTransportHeaderNoClientPorts is an error that can be returned by a server.
Expand Down
106 changes: 105 additions & 1 deletion server_record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func TestServerRecordErrorAnnounce(t *testing.T) {
"unsupported Content-Type header '[aa]'",
},
{
"invalid medias",
"invalid sdp",
base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Expand All @@ -122,6 +122,29 @@ func TestServerRecordErrorAnnounce(t *testing.T) {
},
"invalid SDP: invalid line: (\x01\x02\x03\x04)",
},
{
"invalid session",
base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Content-Type": base.HeaderValue{"application/sdp"},
},
Body: []byte("v=0\r\n" +
"o=- 0 0 IN IP4 127.0.0.1\r\n" +
"s=-\r\n" +
"c=IN IP4 0.0.0.0\r\n" +
"t=0 0\r\n" +
"m=video 0 RTP/AVP 96\r\n" +
"a=control\r\n" +
"a=rtpmap:97 H264/90000\r\n" +
"a=fmtp:aa packetization-mode=1; profile-level-id=4D002A; " +
"sprop-parameter-sets=Z00AKp2oHgCJ+WbgICAgQA==,aO48gA==\r\n",
),
},
"invalid SDP: media 1 is invalid: clock rate not found",
},
{
"invalid URL 1",
invalidURLAnnounceReq(t, "rtsp:// aaaaa"),
Expand Down Expand Up @@ -168,6 +191,87 @@ func TestServerRecordErrorAnnounce(t *testing.T) {
}
}

func TestServerRecordErrorSetup(t *testing.T) {
for _, ca := range []struct {
name string
err string
}{
{
"invalid transport",
"transport header contains a invalid mode (null)",
},
} {
t.Run(ca.name, func(t *testing.T) {
s := &Server{
Handler: &testServerHandler{
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
require.EqualError(t, ctx.Error, ca.err)
},
onAnnounce: func(_ *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil, nil
},
onRecord: func(_ *ServerHandlerOnRecordCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
RTSPAddress: "localhost:8554",
UDPRTPAddress: "127.0.0.1:8000",
UDPRTCPAddress: "127.0.0.1:8001",
}

err := s.Start()
require.NoError(t, err)
defer s.Close()

nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer nconn.Close()
conn := conn.NewConn(nconn)

medias := []*description.Media{testH264Media}

doAnnounce(t, conn, "rtsp://localhost:8554/teststream", medias)

var inTH *headers.Transport

switch ca.name {
case "invalid transport":
inTH = &headers.Transport{
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
Mode: nil,
Protocol: headers.TransportProtocolUDP,
ClientPorts: &[2]int{35466, 35467},
}
}

res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/" + medias[0].Control),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Transport": inTH.Marshal(),
},
})
require.NoError(t, err)
require.NotEqual(t, base.StatusOK, res.StatusCode)
})
}
}

func TestServerRecordPath(t *testing.T) {
for _, ca := range []struct {
name string
Expand Down
4 changes: 2 additions & 2 deletions server_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: *inTH.Mode}
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode}

Check warning on line 731 in server_session.go

View check run for this annotation

Codecov / codecov/patch

server_session.go#L731

Added line #L731 was not covered by tests
}

default: // record
Expand All @@ -741,7 +741,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
if inTH.Mode == nil || *inTH.Mode != headers.TransportModeRecord {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: *inTH.Mode}
}, liberrors.ErrServerTransportHeaderInvalidMode{Mode: inTH.Mode}
}
}

Expand Down

0 comments on commit 8cd9e24

Please sign in to comment.