From 075529e2fef98b295ba6e45a3f36d5ebd99b21a2 Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Wed, 5 Aug 2015 22:05:58 -0700 Subject: [PATCH] Properly handle and enforce max payload --- server/client.go | 19 +++++++++-- server/configs/test.conf | 15 ++++++++- server/errors.go | 5 ++- server/opts.go | 9 ++++++ server/opts_test.go | 64 +++++++++++++++++++++----------------- server/parser_test.go | 9 ++++++ server/server.go | 4 +-- test/bench_test.go | 8 +++-- test/cluster_test.go | 10 ++---- test/configs/override.conf | 9 ++++++ test/gosrv_test.go | 6 ++-- test/maxpayload_test.go | 36 +++++++++++++++++++++ test/opts_test.go | 21 +++++++++++++ test/test.go | 15 +++++++-- 14 files changed, 180 insertions(+), 50 deletions(-) create mode 100644 test/configs/override.conf create mode 100644 test/maxpayload_test.go create mode 100644 test/opts_test.go diff --git a/server/client.go b/server/client.go index 62d8235abe4..97a9d3a7088 100644 --- a/server/client.go +++ b/server/client.go @@ -39,6 +39,7 @@ type client struct { lang string opts clientOpts nc net.Conn + mpay int ncs string bw *bufio.Writer srv *Server @@ -153,9 +154,9 @@ func (c *client) readLoop() { return } if err := c.parse(b[:n]); err != nil { - c.Errorf("Error reading from client: %s", err.Error()) - // Auth was handled inline - if err != ErrAuthorization { + // handled inline + if err != ErrMaxPayload && err != ErrAuthorization { + c.Errorf("Error reading from client: %s", err.Error()) c.sendErr("Parser Error") c.closeConnection() } @@ -297,10 +298,17 @@ func (c *client) authTimeout() { } func (c *client) authViolation() { + c.Errorf(ErrAuthorization.Error()) c.sendErr("Authorization Violation") c.closeConnection() } +func (c *client) maxPayloadViolation(sz int) { + c.Errorf("%s: %d vs %d", ErrMaxPayload.Error(), sz, c.mpay) + c.sendErr("Maximum Payload Violation") + c.closeConnection() +} + func (c *client) sendErr(err string) { c.mu.Lock() if c.bw != nil { @@ -430,6 +438,11 @@ func (c *client) processPub(arg []byte) error { if c.pa.size < 0 { return fmt.Errorf("processPub Bad or Missing Size: '%s'", arg) } + if c.mpay > 0 && c.pa.size > c.mpay { + c.maxPayloadViolation(c.pa.size) + return ErrMaxPayload + } + if c.opts.Pedantic && !sublist.IsValidLiteralSubject(c.pa.subject) { c.sendErr("Invalid Subject") } diff --git a/server/configs/test.conf b/server/configs/test.conf index 62e6819cc07..b27215a25cd 100644 --- a/server/configs/test.conf +++ b/server/configs/test.conf @@ -20,8 +20,21 @@ log_file: "/tmp/gnatsd.log" syslog: true remote_syslog: "udp://foo.com:33" -#pid file +# pid file pid_file: "/tmp/gnatsd.pid" +# prof_port prof_port: 6543 +# max_connections +max_connections: 100 + +# maximum control line +max_control_line: 2048 + +# maximum payload +max_payload: 65536 + +# slow consumer threshold +max_pending_size: 10000000 + diff --git a/server/errors.go b/server/errors.go index 2fe22ba26a0..fec29c8ab5a 100644 --- a/server/errors.go +++ b/server/errors.go @@ -6,8 +6,11 @@ import "errors" var ( // ErrConnectionClosed represents error condition on a closed connection. - ErrConnectionClosed = errors.New("Connection closed") + ErrConnectionClosed = errors.New("Connection Closed") // ErrAuthorization represents error condition on failed authorization. ErrAuthorization = errors.New("Authorization Error") + + // ErrMaxPayload represents error condition when the payload is too big. + ErrMaxPayload = errors.New("Maximum Payload Exceeded") ) diff --git a/server/opts.go b/server/opts.go index e5a0a074314..20bfe6ec048 100644 --- a/server/opts.go +++ b/server/opts.go @@ -34,6 +34,7 @@ type Options struct { AuthTimeout float64 `json:"auth_timeout"` MaxControlLine int `json:"max_control_line"` MaxPayload int `json:"max_payload"` + MaxPending int `json:"max_pending_size"` ClusterHost string `json:"addr"` ClusterPort int `json:"port"` ClusterUsername string `json:"-"` @@ -107,6 +108,14 @@ func ProcessConfigFile(configFile string) (*Options, error) { opts.PidFile = v.(string) case "prof_port": opts.ProfPort = int(v.(int64)) + case "max_control_line": + opts.MaxControlLine = int(v.(int64)) + case "max_payload": + opts.MaxPayload = int(v.(int64)) + case "max_pending_size", "max_pending": + opts.MaxPending = int(v.(int64)) + case "max_connections", "max_conn": + opts.MaxConn = int(v.(int64)) } } return opts, nil diff --git a/server/opts_test.go b/server/opts_test.go index 05c70212cca..3ef83ac83c9 100644 --- a/server/opts_test.go +++ b/server/opts_test.go @@ -44,20 +44,24 @@ func TestOptions_RandomPort(t *testing.T) { func TestConfigFile(t *testing.T) { golden := &Options{ - Host: "apcera.me", - Port: 4242, - Username: "derek", - Password: "bella", - AuthTimeout: 1.0, - Debug: false, - Trace: true, - Logtime: false, - HTTPPort: 8222, - LogFile: "/tmp/gnatsd.log", - PidFile: "/tmp/gnatsd.pid", - ProfPort: 6543, - Syslog: true, - RemoteSyslog: "udp://foo.com:33", + Host: "apcera.me", + Port: 4242, + Username: "derek", + Password: "bella", + AuthTimeout: 1.0, + Debug: false, + Trace: true, + Logtime: false, + HTTPPort: 8222, + LogFile: "/tmp/gnatsd.log", + PidFile: "/tmp/gnatsd.pid", + ProfPort: 6543, + Syslog: true, + RemoteSyslog: "udp://foo.com:33", + MaxControlLine: 2048, + MaxPayload: 65536, + MaxConn: 100, + MaxPending: 10000000, } opts, err := ProcessConfigFile("./configs/test.conf") @@ -73,20 +77,24 @@ func TestConfigFile(t *testing.T) { func TestMergeOverrides(t *testing.T) { golden := &Options{ - Host: "apcera.me", - Port: 2222, - Username: "derek", - Password: "spooky", - AuthTimeout: 1.0, - Debug: true, - Trace: true, - Logtime: false, - HTTPPort: DEFAULT_HTTP_PORT, - LogFile: "/tmp/gnatsd.log", - PidFile: "/tmp/gnatsd.pid", - ProfPort: 6789, - Syslog: true, - RemoteSyslog: "udp://foo.com:33", + Host: "apcera.me", + Port: 2222, + Username: "derek", + Password: "spooky", + AuthTimeout: 1.0, + Debug: true, + Trace: true, + Logtime: false, + HTTPPort: DEFAULT_HTTP_PORT, + LogFile: "/tmp/gnatsd.log", + PidFile: "/tmp/gnatsd.pid", + ProfPort: 6789, + Syslog: true, + RemoteSyslog: "udp://foo.com:33", + MaxControlLine: 2048, + MaxPayload: 65536, + MaxConn: 100, + MaxPending: 10000000, } fopts, err := ProcessConfigFile("./configs/test.conf") if err != nil { diff --git a/server/parser_test.go b/server/parser_test.go index e2ce76a0e27..64e0283e7c4 100644 --- a/server/parser_test.go +++ b/server/parser_test.go @@ -223,6 +223,15 @@ func TestParsePubArg(t *testing.T) { testPubArg(c, t) } +func TestParsePubBadSize(t *testing.T) { + c := dummyClient() + // Setup localized max payload + c.mpay = 32768 + if err := c.processPub([]byte("foo 2222222222222222\r")); err == nil { + t.Fatalf("Expected parse error for size too large") + } +} + func TestParseMsg(t *testing.T) { c := dummyClient() diff --git a/server/server.go b/server/server.go index 500fd017ea6..3ba4f46b7af 100644 --- a/server/server.go +++ b/server/server.go @@ -81,7 +81,7 @@ func New(opts *Options) *Server { Port: opts.Port, AuthRequired: false, SslRequired: false, - MaxPayload: MAX_PAYLOAD_SIZE, + MaxPayload: opts.MaxPayload, } s := &Server{ @@ -380,7 +380,7 @@ func (s *Server) StartHTTPMonitoring() { } func (s *Server) createClient(conn net.Conn) *client { - c := &client{srv: s, nc: conn, opts: defaultOpts} + c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: s.info.MaxPayload} // Grab lock c.mu.Lock() diff --git a/test/bench_test.go b/test/bench_test.go index 75715e82e32..6056e4d7e65 100644 --- a/test/bench_test.go +++ b/test/bench_test.go @@ -60,12 +60,16 @@ func benchPub(b *testing.B, subject, payload string) { var ch = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()") -func sizedString(sz int) string { +func sizedBytes(sz int) []byte { b := make([]byte, sz) for i := range b { b[i] = ch[rand.Intn(len(ch))] } - return string(b) + return b +} + +func sizedString(sz int) string { + return string(sizedBytes(sz)) } func Benchmark___PubNo_Payload(b *testing.B) { diff --git a/test/cluster_test.go b/test/cluster_test.go index 74e8c25370f..12161e872e9 100644 --- a/test/cluster_test.go +++ b/test/cluster_test.go @@ -12,14 +12,8 @@ import ( ) func runServers(t *testing.T) (srvA, srvB *server.Server, optsA, optsB *server.Options) { - optsA, _ = server.ProcessConfigFile("./configs/srv_a.conf") - optsB, _ = server.ProcessConfigFile("./configs/srv_b.conf") - - optsA.NoSigs, optsA.NoLog = true, true - optsB.NoSigs, optsB.NoLog = true, true - - srvA = RunServer(optsA) - srvB = RunServer(optsB) + srvA, optsA = RunServerWithConfig("./configs/srv_a.conf") + srvB, optsB = RunServerWithConfig("./configs/srv_b.conf") return } diff --git a/test/configs/override.conf b/test/configs/override.conf new file mode 100644 index 00000000000..6bf4339ec7f --- /dev/null +++ b/test/configs/override.conf @@ -0,0 +1,9 @@ +# Copyright 2015 Apcera Inc. All rights reserved. + +# Config file to test overrides to client + +port: 4224 + +# maximum payload +max_payload: 2222 + diff --git a/test/gosrv_test.go b/test/gosrv_test.go index 104af5b3268..3be9d9181c2 100644 --- a/test/gosrv_test.go +++ b/test/gosrv_test.go @@ -10,7 +10,7 @@ import ( func TestSimpleGoServerShutdown(t *testing.T) { base := runtime.NumGoroutine() - s := runDefaultServer() + s := RunDefaultServer() s.Shutdown() time.Sleep(10 * time.Millisecond) delta := (runtime.NumGoroutine() - base) @@ -21,7 +21,7 @@ func TestSimpleGoServerShutdown(t *testing.T) { func TestGoServerShutdownWithClients(t *testing.T) { base := runtime.NumGoroutine() - s := runDefaultServer() + s := RunDefaultServer() for i := 0; i < 50; i++ { createClientConn(t, "localhost", 4222) } @@ -37,7 +37,7 @@ func TestGoServerShutdownWithClients(t *testing.T) { } func TestGoServerMultiShutdown(t *testing.T) { - s := runDefaultServer() + s := RunDefaultServer() s.Shutdown() s.Shutdown() } diff --git a/test/maxpayload_test.go b/test/maxpayload_test.go new file mode 100644 index 00000000000..f372fb0ba09 --- /dev/null +++ b/test/maxpayload_test.go @@ -0,0 +1,36 @@ +// Copyright 2015 Apcera Inc. All rights reserved. + +package test + +import ( + "fmt" + "strings" + "testing" + "time" + + "github.com/nats-io/nats" +) + +func TestMaxPayload(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/override.conf") + defer srv.Shutdown() + + nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d/", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Could not connect to server: %v", err) + } + defer nc.Close() + + big := sizedBytes(4 * 1024 * 1024) + nc.Publish("foo", big) + err = nc.FlushTimeout(1 * time.Second) + if err == nil { + t.Fatalf("Expected an error from flush") + } + if strings.Contains(err.Error(), "Maximum Payload Violation") != true { + t.Fatalf("Received wrong error message (%v)\n", err) + } + if !nc.IsClosed() { + t.Fatalf("Expected connection to be closed") + } +} diff --git a/test/opts_test.go b/test/opts_test.go new file mode 100644 index 00000000000..b94ac7c30e1 --- /dev/null +++ b/test/opts_test.go @@ -0,0 +1,21 @@ +// Copyright 2015 Apcera Inc. All rights reserved. + +package test + +import ( + "testing" +) + +func TestServerConfig(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/override.conf") + defer srv.Shutdown() + + c := createClientConn(t, opts.Host, opts.Port) + defer c.Close() + + sinfo := checkInfoMsg(t, c) + if sinfo.MaxPayload != opts.MaxPayload { + t.Fatalf("Expected max_payload from server, got %d vs %d", + opts.MaxPayload, sinfo.MaxPayload) + } +} diff --git a/test/test.go b/test/test.go index cd1e50b0c73..1e5b659a6e7 100644 --- a/test/test.go +++ b/test/test.go @@ -38,7 +38,7 @@ var DefaultTestOptions = server.Options{ NoSigs: true, } -func runDefaultServer() *server.Server { +func RunDefaultServer() *server.Server { return RunServer(&DefaultTestOptions) } @@ -47,6 +47,16 @@ func RunServer(opts *server.Options) *server.Server { return RunServerWithAuth(opts, nil) } +func RunServerWithConfig(configFile string) (srv *server.Server, opts *server.Options) { + opts, err := server.ProcessConfigFile(configFile) + if err != nil { + panic(fmt.Sprintf("Error processing configuration file: %v", err)) + } + opts.NoSigs, opts.NoLog = true, true + srv = RunServer(opts) + return +} + // New Go Routine based server with auth func RunServerWithAuth(opts *server.Options, auth server.Auth) *server.Server { if opts == nil { @@ -193,7 +203,7 @@ func checkSocket(t tLogger, addr string, wait time.Duration) { t.Fatalf("Failed to connect to the socket: %q", addr) } -func checkInfoMsg(t tLogger, c net.Conn) { +func checkInfoMsg(t tLogger, c net.Conn) server.Info { buf := expectResult(t, c, infoRe) js := infoRe.FindAllSubmatch(buf, 1)[0][1] var sinfo server.Info @@ -201,6 +211,7 @@ func checkInfoMsg(t tLogger, c net.Conn) { if err != nil { stackFatalf(t, "Could not unmarshal INFO json: %v\n", err) } + return sinfo } func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) {