Skip to content

Commit

Permalink
Prevent request smuggling (#1719)
Browse files Browse the repository at this point in the history
* Prevent request smuggling

Prevent request smuggling when fasthttp is behind a reverse proxy that
might interprets headers differently by being stricter. Should also
prevent request smuggling when fasthttp is used as the reverse proxy.

* Make header value comparison case-insensitive
  • Loading branch information
erikdubbelboer authored Feb 11, 2024
1 parent 3327266 commit bce5766
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
19 changes: 18 additions & 1 deletion header.go
Original file line number Diff line number Diff line change
Expand Up @@ -3029,6 +3029,8 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) {
func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
h.contentLength = -2

contentLengthSeen := false

var s headerScanner
s.b = buf
s.disableNormalizing = h.disableNormalizing
Expand Down Expand Up @@ -3064,6 +3066,11 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
continue
}
if caseInsensitiveCompare(s.key, strContentLength) {
if contentLengthSeen {
return 0, fmt.Errorf("duplicate Content-Length header")
}
contentLengthSeen = true

if h.contentLength != -1 {
var nerr error
if h.contentLength, nerr = parseContentLength(s.value); nerr != nil {
Expand All @@ -3088,7 +3095,17 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
}
case 't':
if caseInsensitiveCompare(s.key, strTransferEncoding) {
if !bytes.Equal(s.value, strIdentity) {
isIdentity := caseInsensitiveCompare(s.value, strIdentity)
isChunked := caseInsensitiveCompare(s.value, strChunked)

if !isIdentity && !isChunked {
if h.secureErrorLogMessage {
return 0, fmt.Errorf("unsupported Transfer-Encoding")
}
return 0, fmt.Errorf("unsupported Transfer-Encoding: %q", s.value)
}

if isChunked {
h.contentLength = -1
h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue)
}
Expand Down
13 changes: 9 additions & 4 deletions header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2618,10 +2618,6 @@ func TestRequestHeaderReadSuccess(t *testing.T) {
testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: aa\r\nContent-Type: ab\r\nContent-Length: 123\r\nContent-Type: xx\r\n\r\n",
123, "/a", "aa", "", "xx", nil)

// post with duplicate content-length
testRequestHeaderReadSuccess(t, h, "POST /xx HTTP/1.1\r\nHost: aa\r\nContent-Type: s\r\nContent-Length: 13\r\nContent-Length: 1\r\n\r\n",
1, "/xx", "aa", "", "s", nil)

// non-post with content-type
testRequestHeaderReadSuccess(t, h, "GET /aaa HTTP/1.1\r\nHost: bbb.com\r\nContent-Type: aaab\r\n\r\n",
-2, "/aaa", "bbb.com", "", "aaab", nil)
Expand Down Expand Up @@ -2756,6 +2752,9 @@ func TestRequestHeaderReadError(t *testing.T) {

// forbidden trailer
testRequestHeaderReadError(t, h, "POST /a HTTP/1.1\r\nContent-Length: -1\r\nTrailer: Foo, Content-Length\r\n\r\n")

// post with duplicate content-length
testRequestHeaderReadError(t, h, "POST /xx HTTP/1.1\r\nHost: aa\r\nContent-Type: s\r\nContent-Length: 13\r\nContent-Length: 1\r\n\r\n")
}

func TestRequestHeaderReadSecuredError(t *testing.T) {
Expand Down Expand Up @@ -2805,6 +2804,8 @@ func testResponseHeaderReadSecuredError(t *testing.T, h *ResponseHeader, headers
}

func testRequestHeaderReadError(t *testing.T, h *RequestHeader, headers string) {
t.Helper()

r := bytes.NewBufferString(headers)
br := bufio.NewReader(r)
err := h.Read(br)
Expand Down Expand Up @@ -2835,6 +2836,8 @@ func testRequestHeaderReadSecuredError(t *testing.T, h *RequestHeader, headers s
func testResponseHeaderReadSuccess(t *testing.T, h *ResponseHeader, headers string, expectedStatusCode, expectedContentLength int,
expectedContentType string,
) {
t.Helper()

r := bytes.NewBufferString(headers)
br := bufio.NewReader(r)
err := h.Read(br)
Expand All @@ -2847,6 +2850,8 @@ func testResponseHeaderReadSuccess(t *testing.T, h *ResponseHeader, headers stri
func testRequestHeaderReadSuccess(t *testing.T, h *RequestHeader, headers string, expectedContentLength int,
expectedRequestURI, expectedHost, expectedReferer, expectedContentType string, expectedTrailer map[string]string,
) {
t.Helper()

r := bytes.NewBufferString(headers)
br := bufio.NewReader(r)
err := h.Read(br)
Expand Down

0 comments on commit bce5766

Please sign in to comment.