diff --git a/CHANGELOG.next.asciidoc b/CHANGELOG.next.asciidoc index dcd0a1514a1..8aee8051709 100644 --- a/CHANGELOG.next.asciidoc +++ b/CHANGELOG.next.asciidoc @@ -27,6 +27,7 @@ https://github.com/elastic/beats/compare/v7.0.0-alpha2...master[Check the HEAD d - Add read_buffer configuration option. {pull}11739[11739] - `convert_timezone` option is removed and locale is always added to the event so timezone is used when parsing the timestamp, this behaviour can be overriden with processors. {pull}12410[12410] +- Fix a race condition in the TCP input when close the client socket. {pull}13038[13038] *Heartbeat* diff --git a/filebeat/inputsource/tcp/closeref.go b/filebeat/inputsource/tcp/closeref.go new file mode 100644 index 00000000000..d718df42343 --- /dev/null +++ b/filebeat/inputsource/tcp/closeref.go @@ -0,0 +1,130 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tcp + +import ( + "sync" + + "github.com/pkg/errors" +) + +// CloserFunc is the function called by the Closer on `Close()`. +type CloserFunc func() + +// ErrClosed is returned when the Closer is closed. +var ErrClosed = errors.New("closer is closed") + +// CloseRef implements a subset of the context.Context interface and it's use to synchronize +// the shutdown of multiple go-routines. +type CloseRef interface { + Done() <-chan struct{} + Err() error +} + +// Closer implements a shutdown strategy when dealing with multiples go-routines, it creates a tree +// of Closer, when you call `Close()` on a parent the `Close()` method will be called on the current +// closer and any of the childs it may have and will remove the current node from the parent. +// +// NOTE: The `Close()` is reentrant but will propage the close only once. +type Closer struct { + mu sync.Mutex + done chan struct{} + err error + parent *Closer + children map[*Closer]struct{} + callback CloserFunc +} + +// Close closes the closes and propagates the close to any child, on close the close callback will +// be called, this can be used for custom cleanup like closing a TCP socket. +func (c *Closer) Close() { + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return + } + + if c.callback != nil { + c.callback() + } + + close(c.done) + + // propagate close to children. + if c.children != nil { + for child := range c.children { + child.Close() + } + c.children = nil + } + + c.err = ErrClosed + c.mu.Unlock() + + if c.parent != nil { + c.removeChild(c) + } +} + +// Done returns the synchronization channel, the channel will be closed if `Close()` was called on +// the current node or any parent it may have. +func (c *Closer) Done() <-chan struct{} { + return c.done +} + +// Err returns an error if the Closer was already closed. +func (c *Closer) Err() error { + c.mu.Lock() + err := c.err + c.mu.Unlock() + return err +} + +func (c *Closer) removeChild(child *Closer) { + c.mu.Lock() + delete(c.children, child) + c.mu.Unlock() +} + +func (c *Closer) addChild(child *Closer) { + c.mu.Lock() + if c.children == nil { + c.children = make(map[*Closer]struct{}) + } + c.children[child] = struct{}{} + c.mu.Unlock() +} + +// WithCloser wraps a new closer into a child of an existing closer. +func WithCloser(parent *Closer, fn CloserFunc) *Closer { + child := &Closer{ + done: make(chan struct{}), + parent: parent, + callback: fn, + } + parent.addChild(child) + return child +} + +// NewCloser creates a new Closer. +func NewCloser(fn CloserFunc) *Closer { + return &Closer{ + done: make(chan struct{}), + callback: fn, + } +} diff --git a/filebeat/inputsource/tcp/client.go b/filebeat/inputsource/tcp/handler.go similarity index 89% rename from filebeat/inputsource/tcp/client.go rename to filebeat/inputsource/tcp/handler.go index 1e5769a797a..eeb5d5088b6 100644 --- a/filebeat/inputsource/tcp/client.go +++ b/filebeat/inputsource/tcp/handler.go @@ -33,7 +33,6 @@ import ( // splitHandler is a TCP client that has splitting capabilities. type splitHandler struct { - conn net.Conn callback inputsource.NetworkFunc done chan struct{} metadata inputsource.NetworkMetadata @@ -42,19 +41,23 @@ type splitHandler struct { timeout time.Duration } -// ClientFactory returns a ConnectionHandler func -type ClientFactory func(config Config) ConnectionHandler +// HandlerFactory returns a ConnectionHandler func +type HandlerFactory func(config Config) ConnectionHandler // ConnectionHandler interface provides mechanisms for handling of incoming TCP connections type ConnectionHandler interface { - Handle(conn net.Conn) error - Close() + Handle(CloseRef, net.Conn) error } // SplitHandlerFactory allows creation of a ConnectionHandler that can do splitting of messages received on a TCP connection. -func SplitHandlerFactory(callback inputsource.NetworkFunc, splitFunc bufio.SplitFunc) ClientFactory { +func SplitHandlerFactory(callback inputsource.NetworkFunc, splitFunc bufio.SplitFunc) HandlerFactory { return func(config Config) ConnectionHandler { - return newSplitHandler(callback, splitFunc, uint64(config.MaxMessageSize), config.Timeout) + return newSplitHandler( + callback, + splitFunc, + uint64(config.MaxMessageSize), + config.Timeout, + ) } } @@ -76,8 +79,7 @@ func newSplitHandler( } // Handle takes a connection as input and processes data received on it. -func (c *splitHandler) Handle(conn net.Conn) error { - c.conn = conn +func (c *splitHandler) Handle(closer CloseRef, conn net.Conn) error { c.metadata = inputsource.NetworkMetadata{ RemoteAddr: conn.RemoteAddr(), TLS: extractSSLInformation(conn), @@ -97,7 +99,7 @@ func (c *splitHandler) Handle(conn net.Conn) error { if err != nil { // we are forcing a Close on the socket, lets ignore any error that could happen. select { - case <-c.done: + case <-closer.Done(): break default: } @@ -121,12 +123,6 @@ func (c *splitHandler) Handle(conn net.Conn) error { return nil } -// Close is used to perform clean up before the client is released. -func (c *splitHandler) Close() { - close(c.done) - c.conn.Close() -} - func extractSSLInformation(c net.Conn) *inputsource.TLSMetadata { if tls, ok := c.(*tls.Conn); ok { state := tls.ConnectionState() diff --git a/filebeat/inputsource/tcp/server.go b/filebeat/inputsource/tcp/server.go index e1e6225cb38..df7343f9947 100644 --- a/filebeat/inputsource/tcp/server.go +++ b/filebeat/inputsource/tcp/server.go @@ -27,6 +27,7 @@ import ( "golang.org/x/net/netutil" + "github.com/elastic/beats/libbeat/common/atomic" "github.com/elastic/beats/libbeat/common/transport/tlscommon" "github.com/elastic/beats/libbeat/logp" "github.com/elastic/beats/libbeat/outputs/transport" @@ -34,21 +35,21 @@ import ( // Server represent a TCP server type Server struct { - sync.RWMutex - config *Config - Listener net.Listener - clients map[ConnectionHandler]struct{} - wg sync.WaitGroup - done chan struct{} - factory ClientFactory - log *logp.Logger - tlsConfig *transport.TLSConfig + config *Config + Listener net.Listener + wg sync.WaitGroup + done chan struct{} + factory HandlerFactory + log *logp.Logger + tlsConfig *transport.TLSConfig + closer *Closer + clientsCount atomic.Int } // New creates a new tcp server func New( config *Config, - factory ClientFactory, + factory HandlerFactory, ) (*Server, error) { tlsConfig, err := tlscommon.LoadTLSServerConfig(config.TLS) if err != nil { @@ -56,16 +57,16 @@ func New( } if factory == nil { - return nil, fmt.Errorf("ClientFactory can't be empty") + return nil, fmt.Errorf("HandlerFactory can't be empty") } return &Server{ config: config, - clients: make(map[ConnectionHandler]struct{}, 0), done: make(chan struct{}), factory: factory, log: logp.NewLogger("tcp").With("address", config.Host), tlsConfig: tlsConfig, + closer: NewCloser(nil), }, nil } @@ -77,6 +78,7 @@ func (s *Server) Start() error { return err } + s.closer.callback = func() { s.Listener.Close() } s.log.Info("Started listening for TCP connection") s.wg.Add(1) @@ -97,7 +99,7 @@ func (s *Server) run() { conn, err := s.Listener.Accept() if err != nil { select { - case <-s.done: + case <-s.closer.Done(): return default: s.log.Debugw("Can not accept the connection", "error", err) @@ -105,19 +107,20 @@ func (s *Server) run() { } } - client := s.factory(*s.config) + handler := s.factory(*s.config) + closer := WithCloser(s.closer, func() { conn.Close() }) s.wg.Add(1) go func() { defer logp.Recover("recovering from a tcp client crash") defer s.wg.Done() - defer conn.Close() + defer closer.Close() - s.registerClient(client) - defer s.unregisterClient(client) - s.log.Debugw("New client", "remote_address", conn.RemoteAddr(), "total", s.clientsCount()) + s.registerHandler() + defer s.unregisterHandler() + s.log.Debugw("New client", "remote_address", conn.RemoteAddr(), "total", s.clientsCount.Load()) - err := client.Handle(conn) + err := handler.Handle(closer, conn) if err != nil { s.log.Debugw("client error", "error", err) } @@ -127,7 +130,7 @@ func (s *Server) run() { "remote_address", conn.RemoteAddr(), "total", - s.clientsCount(), + s.clientsCount.Load(), ) }() } @@ -136,37 +139,17 @@ func (s *Server) run() { // Stop stops accepting new incoming TCP connection and Close any active clients func (s *Server) Stop() { s.log.Info("Stopping TCP server") - close(s.done) - s.Listener.Close() - for _, client := range s.allClients() { - client.Close() - } + s.closer.Close() s.wg.Wait() s.log.Info("TCP server stopped") } -func (s *Server) registerClient(client ConnectionHandler) { - s.Lock() - defer s.Unlock() - s.clients[client] = struct{}{} -} - -func (s *Server) unregisterClient(client ConnectionHandler) { - s.Lock() - defer s.Unlock() - delete(s.clients, client) +func (s *Server) registerHandler() { + s.clientsCount.Inc() } -func (s *Server) allClients() []ConnectionHandler { - s.RLock() - defer s.RUnlock() - currentClients := make([]ConnectionHandler, len(s.clients)) - idx := 0 - for client := range s.clients { - currentClients[idx] = client - idx++ - } - return currentClients +func (s *Server) unregisterHandler() { + s.clientsCount.Dec() } func (s *Server) createServer() (net.Listener, error) { @@ -192,12 +175,6 @@ func (s *Server) createServer() (net.Listener, error) { return l, nil } -func (s *Server) clientsCount() int { - s.RLock() - defer s.RUnlock() - return len(s.clients) -} - // SplitFunc allows to create a `bufio.SplitFunc` based on a delimiter provided. func SplitFunc(lineDelimiter []byte) bufio.SplitFunc { ld := []byte(lineDelimiter) diff --git a/filebeat/inputsource/tcp/server_test.go b/filebeat/inputsource/tcp/server_test.go index df63417a841..82e89fb72a7 100644 --- a/filebeat/inputsource/tcp/server_test.go +++ b/filebeat/inputsource/tcp/server_test.go @@ -28,6 +28,7 @@ import ( "github.com/dustin/go-humanize" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/elastic/beats/filebeat/inputsource" "github.com/elastic/beats/libbeat/common" @@ -180,7 +181,7 @@ func TestReceiveEventsAndMetadata(t *testing.T) { defer server.Stop() conn, err := net.Dial("tcp", server.Listener.Addr().String()) - assert.NoError(t, err) + require.NoError(t, err) fmt.Fprint(conn, test.messageSent) conn.Close()