Skip to content

Commit

Permalink
Scan HTTP request byte-by-byte to prevent malicious OOM
Browse files Browse the repository at this point in the history
  • Loading branch information
sparrc committed Oct 14, 2016
1 parent c73964c commit bd0af9a
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 40 deletions.
138 changes: 101 additions & 37 deletions plugins/inputs/http_listener/http_listener.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package http_listener

import (
"bufio"
"bytes"
"compress/gzip"
"fmt"
"io"
"log"
"net"
"net/http"
Expand All @@ -17,15 +19,25 @@ import (
"github.com/influxdata/telegraf/plugins/parsers"
)

const MAX_REQUEST_BODY_SIZE = 50 * 1024 * 1024
const (
	// DEFAULT_REQUEST_BODY_MAX is the default maximum request body size, in megabytes.
	// If the request body exceeds this size, the listener responds with HTTP 413
	// (Request Entity Too Large).
	// 1024 MB
	DEFAULT_REQUEST_BODY_MAX = 1024

	// MAX_ALLOCATION_SIZE is the maximum size, in bytes, of a single allocation
	// of bytes that will be made handling a single HTTP request. Bounding the
	// per-request buffer prevents a malicious ContentLength from forcing an OOM.
	// 10MB
	MAX_ALLOCATION_SIZE = 10 * 1024 * 1024
)

type HttpListener struct {
ServiceAddress string
ReadTimeout internal.Duration
WriteTimeout internal.Duration
MaxBodySizeMb int64

sync.Mutex
wg sync.WaitGroup

listener *stoppableListener.StoppableListener

Expand All @@ -40,6 +52,10 @@ const sampleConfig = `
## timeouts
read_timeout = "10s"
write_timeout = "10s"
## Maximum allowed http request body size in megabytes.
## 0 means to use the default (1024 MB)
max_body_size_mb = 0
`

func (t *HttpListener) SampleConfig() string {
Expand All @@ -63,6 +79,10 @@ func (t *HttpListener) Start(acc telegraf.Accumulator) error {
t.Lock()
defer t.Unlock()

if t.MaxBodySizeMb == 0 {
t.MaxBodySizeMb = DEFAULT_REQUEST_BODY_MAX
}

t.acc = acc

var rawListener, err = net.Listen("tcp", t.ServiceAddress)
Expand All @@ -89,8 +109,6 @@ func (t *HttpListener) Stop() {
t.listener.Stop()
t.listener.Close()

t.wg.Wait()

log.Println("I! Stopped HTTP listener service on ", t.ServiceAddress)
}

Expand All @@ -113,58 +131,90 @@ func (t *HttpListener) httpListen() error {
}

func (t *HttpListener) ServeHTTP(res http.ResponseWriter, req *http.Request) {
t.wg.Add(1)
defer t.wg.Done()

switch req.URL.Path {
case "/write":
var http400msg bytes.Buffer
var msg413 bytes.Buffer
var msg400 bytes.Buffer
defer func() {
if http400msg.Len() > 0 {
if msg413.Len() > 0 {
res.WriteHeader(http.StatusRequestEntityTooLarge)
res.Write([]byte(fmt.Sprintf(`{"error":"%s"}`, msg413.String())))
} else if msg400.Len() > 0 {
res.Header().Set("Content-Type", "application/json")
res.Header().Set("X-Influxdb-Version", "1.0")
res.WriteHeader(http.StatusBadRequest)
res.Write([]byte(fmt.Sprintf(`{"error":"%s"}`, http400msg.String())))
res.Write([]byte(fmt.Sprintf(`{"error":"%s"}`, msg400.String())))
} else {
res.WriteHeader(http.StatusNoContent)
}
}()

body := req.Body
// Check that the content length is not too large for us to handle.
if req.ContentLength > t.MaxBodySizeMb*1024*1024 {
msg413.WriteString("http: request body too large")
return
}

// Handle gzip request bodies
var body io.ReadCloser
if req.Header.Get("Content-Encoding") == "gzip" {
b, err := gzip.NewReader(req.Body)
body, err := gzip.NewReader(http.MaxBytesReader(res, req.Body, t.MaxBodySizeMb*1024*1024))
if err != nil {
http400msg.WriteString(err.Error() + " ")
msg400.WriteString(err.Error() + " ")
return
}
defer b.Close()
body = b
} else {
body = http.MaxBytesReader(res, req.Body, t.MaxBodySizeMb*1024*1024)
}

allocSize := 512
if req.ContentLength < MAX_REQUEST_BODY_SIZE {
allocSize = int(req.ContentLength)
}
buf := bytes.NewBuffer(make([]byte, 0, allocSize))
_, err := buf.ReadFrom(http.MaxBytesReader(res, body, MAX_REQUEST_BODY_SIZE))
if err != nil {
log.Printf("E! HttpListener unable to read request body. error: %s\n", err.Error())
http400msg.WriteString("HttpHandler unable to read from request body: " + err.Error())
return
// Set the maximum size of the buffer that we will allocate at a time.
// The following loop goes through the request body byte-by-byte.
// if the body is larger than the max allocation size, then when we find
// a newline within 5% of the end of the body we will attempt to
// parse and write the buffer.
allocSize := MAX_ALLOCATION_SIZE
if req.ContentLength < MAX_ALLOCATION_SIZE {
// content length doesn't include trailing whitespace:
allocSize = int(req.ContentLength) + 1
}

metrics, err := t.parser.Parse(buf.Bytes())
if err != nil {
log.Printf("E! HttpListener unable to parse metrics. error: %s \n", err.Error())
if len(metrics) == 0 {
http400msg.WriteString(err.Error())
} else {
http400msg.WriteString("partial write: " + err.Error())
scanLim := int(float64(allocSize) * 0.95)
buffer := bytes.NewBuffer(make([]byte, 0, allocSize))
reader := bufio.NewReader(body)
for {
b, err := reader.ReadByte()
if err != nil {
if err != io.EOF {
log.Printf("E! %s", err)
// if it's not an EOF error, then it's almost certainly a
// tooLarge error coming from http.MaxBytesReader. It's unlikely
// that this code path will get hit because the client should
// be setting the ContentLength header, unless it's malicious.
msg413.WriteString(err.Error())
return
}
break
}
// returned error is always nil:
// https://golang.org/pkg/bytes/#Buffer.WriteByte
buffer.WriteByte(b)
// if we have a newline and we're nearing the end of the buffer,
// do a write and continue with a fresh buffer.
if buffer.Len() > scanLim && b == '\n' {
t.parse(buffer.Bytes(), &msg400)
buffer.Reset()
} else if buffer.Len() == buffer.Cap() {
// we've reached the end of our buffer without finding a newline
// in the body, so we insert a newline here and attempt to parse.
if buffer.Len() == 0 {
continue
}
buffer.WriteByte('\n')
t.parse(buffer.Bytes(), &msg400)
buffer.Reset()
}
}

for _, m := range metrics {
t.acc.AddFields(m.Name(), m.Fields(), m.Tags(), m.Time())
if buffer.Len() != 0 {
t.parse(buffer.Bytes(), &msg400)
}
case "/query":
// Deliver a dummy response to the query endpoint, as some InfluxDB
Expand All @@ -177,11 +227,25 @@ func (t *HttpListener) ServeHTTP(res http.ResponseWriter, req *http.Request) {
// respond to ping requests
res.WriteHeader(http.StatusNoContent)
default:
// Don't know how to respond to calls to other endpoints
http.NotFound(res, req)
}
}

// parse feeds b through the configured line-protocol parser and adds every
// metric that was successfully parsed to the accumulator. When the parser
// reports an error, a description is appended to errmsg: the bare error if
// nothing parsed, or a "partial write: " prefixed message if some metrics
// did parse (those are still accumulated).
func (t *HttpListener) parse(b []byte, errmsg *bytes.Buffer) {
	metrics, err := t.parser.Parse(b)
	if err != nil {
		msg := err.Error()
		if len(metrics) > 0 {
			msg = "partial write: " + msg
		}
		errmsg.WriteString(msg)
	}

	for _, m := range metrics {
		t.acc.AddFields(m.Name(), m.Fields(), m.Tags(), m.Time())
	}
}

func init() {
inputs.Add("http_listener", func() telegraf.Input {
return &HttpListener{}
Expand Down
21 changes: 18 additions & 3 deletions plugins/inputs/http_listener/http_listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"
"time"

"github.com/influxdata/telegraf/internal"
"github.com/influxdata/telegraf/plugins/parsers"
"github.com/influxdata/telegraf/testutil"

Expand Down Expand Up @@ -133,7 +134,11 @@ func TestWriteHTTPHighTraffic(t *testing.T) {

// writes 5000 metrics to the listener with 10 different writers
func TestWriteHTTPHighBatchSize(t *testing.T) {
listener := &HttpListener{ServiceAddress: ":8287"}
listener := &HttpListener{
ServiceAddress: ":8287",
ReadTimeout: internal.Duration{Duration: time.Second * 30},
WriteTimeout: internal.Duration{Duration: time.Second * 30},
}
parser, _ := parsers.NewInfluxParser()
listener.SetParser(parser)

Expand All @@ -143,19 +148,29 @@ func TestWriteHTTPHighBatchSize(t *testing.T) {

time.Sleep(time.Millisecond * 25)

type result struct {
err error
resp *http.Response
}
results := make(chan *result, 10)
// post many messages to listener
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
resp, err := http.Post("http://localhost:8287/write?db=mydb", "", bytes.NewBuffer(makeMetricsBatch(5000)))
require.NoError(t, err)
require.EqualValues(t, 204, resp.StatusCode)
results <- &result{err: err, resp: resp}
}()
}

wg.Wait()
close(results)
for result := range results {
require.NoError(t, result.err)
require.EqualValues(t, 204, result.resp.StatusCode)
}

time.Sleep(time.Millisecond * 50)
listener.Gather(acc)

Expand Down

0 comments on commit bd0af9a

Please sign in to comment.