diff --git a/src/net/http/request.go b/src/net/http/request.go index d706d8e1b600a..d1793c75d7bfd 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -1121,3 +1121,24 @@ var validHostByte = [256]bool{ '_': true, // unreserved '~': true, // unreserved } + +func validHeaderName(v string) bool { + if len(v) == 0 { + return false + } + return strings.IndexFunc(v, isNotToken) == -1 +} + +func validHeaderValue(v string) bool { + for i := 0; i < len(v); i++ { + b := v[i] + if b == '\t' { + continue + } + if ' ' <= b && b <= '~' { + continue + } + return false + } + return true +} diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 31ba06a267497..0ce492c6dd930 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -3629,6 +3629,7 @@ func testHandlerSetsBodyNil(t *testing.T, h2 bool) { } // Test that we validate the Host header. +// Issue 11206 (invalid bytes in Host) and 13624 (Host present in HTTP/1.1) func TestServerValidatesHostHeader(t *testing.T) { tests := []struct { proto string @@ -3676,6 +3677,43 @@ func TestServerValidatesHostHeader(t *testing.T) { } } +// Test that we validate the valid bytes in HTTP/1 headers. +// Issue 11207. +func TestServerValidatesHeaders(t *testing.T) { + tests := []struct { + header string + want int + }{ + {"", 200}, + {"Foo: bar\r\n", 200}, + {"X-Foo: bar\r\n", 200}, + {"Foo: a space\r\n", 200}, + + {"A space: foo\r\n", 400}, // space in header + {"foo\xffbar: foo\r\n", 400}, // binary in header + {"foo\x00bar: foo\r\n", 400}, // binary in header + + {"foo: foo\x00foo\r\n", 400}, // binary in value + {"foo: foo\xfffoo\r\n", 400}, // binary in value + } + for _, tt := range tests { + conn := &testConn{closec: make(chan bool)} + io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n") + + ln := &oneConnListener{conn} + go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {})) + <-conn.closec + res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil) + if err != nil { + t.Errorf("For %q, ReadResponse: %v", tt.header, res) + continue + } + if res.StatusCode != tt.want { + t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want) + } + } +} + func BenchmarkClientServer(b *testing.B) { b.ReportAllocs() b.StopTimer() diff --git a/src/net/http/server.go b/src/net/http/server.go index 4f7fbae6005c1..f6428bcf18d6c 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -707,6 +707,16 @@ func (c *conn) readRequest() (w *response, err error) { if len(hosts) == 1 && !validHostHeader(hosts[0]) { return nil, badRequestError("malformed Host header") } + for k, vv := range req.Header { + if !validHeaderName(k) { + return nil, badRequestError("invalid header name") + } + for _, v := range vv { + if !validHeaderValue(v) { + return nil, badRequestError("invalid header value") + } + } + } delete(req.Header, "Host") req.RemoteAddr = c.remoteAddr