Skip to content

Commit

Permalink
Fix a race condition in the TCP input (#13038)
Browse files Browse the repository at this point in the history
* Fix a race condition in the TCP input

Pass the `net.Conn` object when creating the client instead of passing it to the
`Handle()` method, which keeps a reference to it. By doing this we do not
have to worry about read or write races on the internal field. The
client still needs a reference to the connection when the out-of-band
call to `Close()` is executed, to make sure we get out of a
`Read()` call early.

Tested with :

```
while true; do go test -v -race; done
```

Fixes: #12982
  • Loading branch information
ph committed Jul 25, 2019
1 parent b23c3b7 commit 106ad06
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 68 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.next.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ https://github.com/elastic/beats/compare/v7.0.0-alpha2...master[Check the HEAD d

- Add read_buffer configuration option. {pull}11739[11739]
- `convert_timezone` option is removed and locale is always added to the event so timezone is used when parsing the timestamp, this behaviour can be overridden with processors. {pull}12410[12410]
- Fix a race condition in the TCP input when closing the client socket. {pull}13038[13038]

*Heartbeat*

Expand Down
130 changes: 130 additions & 0 deletions filebeat/inputsource/tcp/closeref.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package tcp

import (
"sync"

"github.com/pkg/errors"
)

// CloserFunc is the optional cleanup hook stored on a Closer. `Close()` invokes it, when it is
// non-nil, the first time the Closer is closed.
type CloserFunc func()

// ErrClosed is the error reported by `Err()` once `Close()` has been called on a Closer.
var ErrClosed = errors.New("closer is closed")

// CloseRef implements a subset of the context.Context interface and is used to synchronize
// the shutdown of multiple goroutines.
type CloseRef interface {
// Done returns a channel that is closed when the shutdown is requested.
Done() <-chan struct{}
// Err returns a non-nil error once the shutdown has happened.
Err() error
}

// Closer implements a shutdown strategy for coordinating multiple goroutines. Closers form a
// tree: calling `Close()` on a node runs its callback, closes its Done channel and then calls
// `Close()` on every child registered under it.
//
// NOTE: `Close()` may be called more than once, but the close is only propagated once.
type Closer struct {
// mu guards err and children.
mu sync.Mutex
// done is closed by Close() and exposed through Done().
done chan struct{}
// err is nil until Close() has run and ErrClosed afterwards.
err error
// parent is set by WithCloser and is nil for a Closer created with NewCloser.
parent *Closer
// children holds the closers registered through addChild; it is allocated lazily.
children map[*Closer]struct{}
// callback is the optional hook invoked by Close().
callback CloserFunc
}

// Close closes the Closer and propagates the close to every registered child.
//
// The first call invokes the callback (if any), closes the Done channel, marks
// the Closer as closed so that Err returns ErrClosed, closes all children and
// finally detaches the Closer from its parent. Subsequent calls are no-ops.
// The callback can be used for custom cleanup such as closing a TCP socket.
func (c *Closer) Close() {
	c.mu.Lock()
	if c.err != nil {
		c.mu.Unlock()
		return
	}

	// The callback and the channel close run under the lock so that a
	// concurrent Close only returns once both have happened.
	if c.callback != nil {
		c.callback()
	}
	close(c.done)
	c.err = ErrClosed

	// Take ownership of the children and release the lock before closing
	// them: each child calls removeChild on this Closer when it is done,
	// which takes c.mu, and sync.Mutex is not reentrant.
	children := c.children
	c.children = nil
	c.mu.Unlock()

	for child := range children {
		child.Close()
	}

	// Detach from the parent (not from ourselves) so that a long-lived parent
	// does not accumulate references to children that are already closed.
	if c.parent != nil {
		c.parent.removeChild(c)
	}
}

// Done returns the channel used for synchronization. The channel is closed by `Close()`, either
// called directly on this Closer or propagated from a parent it was registered with.
func (c *Closer) Done() <-chan struct{} {
return c.done
}

// Err reports whether the Closer has been closed: it returns nil while the
// Closer is still open and ErrClosed once `Close()` has run.
func (c *Closer) Err() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.err
}

// removeChild unregisters child from c so that closing c no longer closes it.
// Removing a child that is not registered is a no-op.
func (c *Closer) removeChild(child *Closer) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.children, child)
}

// addChild registers child under c so that closing c also closes child. The
// child set is created on first use.
func (c *Closer) addChild(child *Closer) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.children == nil {
		c.children = map[*Closer]struct{}{}
	}
	c.children[child] = struct{}{}
}

// WithCloser creates a new Closer registered as a child of parent, so that
// closing parent also closes the returned Closer. fn, if non-nil, is invoked
// when the returned Closer is closed.
//
// If parent has already been closed, the returned Closer is closed right away
// (running fn), because nothing else would ever close it.
func WithCloser(parent *Closer, fn CloserFunc) *Closer {
	child := &Closer{
		done:     make(chan struct{}),
		parent:   parent,
		callback: fn,
	}
	parent.addChild(child)

	// Check after registering: if parent closes between addChild and this
	// check, both sides may call child.Close(), which is safe because only the
	// first call has an effect.
	if parent.Err() != nil {
		child.Close()
	}
	return child
}

// NewCloser creates a root Closer, i.e. one without a parent. fn may be nil;
// when set it is invoked when the Closer is closed.
func NewCloser(fn CloserFunc) *Closer {
	root := new(Closer)
	root.done = make(chan struct{})
	root.callback = fn
	return root
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import (

// splitHandler is a TCP client that has splitting capabilities.
type splitHandler struct {
conn net.Conn
callback inputsource.NetworkFunc
done chan struct{}
metadata inputsource.NetworkMetadata
Expand All @@ -42,19 +41,23 @@ type splitHandler struct {
timeout time.Duration
}

// ClientFactory returns a ConnectionHandler func
type ClientFactory func(config Config) ConnectionHandler
// HandlerFactory returns a ConnectionHandler func
type HandlerFactory func(config Config) ConnectionHandler

// ConnectionHandler interface provides mechanisms for handling of incoming TCP connections
type ConnectionHandler interface {
Handle(conn net.Conn) error
Close()
Handle(CloseRef, net.Conn) error
}

// SplitHandlerFactory allows creation of a ConnectionHandler that can do splitting of messages received on a TCP connection.
func SplitHandlerFactory(callback inputsource.NetworkFunc, splitFunc bufio.SplitFunc) ClientFactory {
func SplitHandlerFactory(callback inputsource.NetworkFunc, splitFunc bufio.SplitFunc) HandlerFactory {
return func(config Config) ConnectionHandler {
return newSplitHandler(callback, splitFunc, uint64(config.MaxMessageSize), config.Timeout)
return newSplitHandler(
callback,
splitFunc,
uint64(config.MaxMessageSize),
config.Timeout,
)
}
}

Expand All @@ -76,8 +79,7 @@ func newSplitHandler(
}

// Handle takes a connection as input and processes data received on it.
func (c *splitHandler) Handle(conn net.Conn) error {
c.conn = conn
func (c *splitHandler) Handle(closer CloseRef, conn net.Conn) error {
c.metadata = inputsource.NetworkMetadata{
RemoteAddr: conn.RemoteAddr(),
TLS: extractSSLInformation(conn),
Expand All @@ -97,7 +99,7 @@ func (c *splitHandler) Handle(conn net.Conn) error {
if err != nil {
// we are forcing a Close on the socket, lets ignore any error that could happen.
select {
case <-c.done:
case <-closer.Done():
break
default:
}
Expand All @@ -121,12 +123,6 @@ func (c *splitHandler) Handle(conn net.Conn) error {
return nil
}

// Close is used to perform clean up before the client is released.
func (c *splitHandler) Close() {
close(c.done)
c.conn.Close()
}

func extractSSLInformation(c net.Conn) *inputsource.TLSMetadata {
if tls, ok := c.(*tls.Conn); ok {
state := tls.ConnectionState()
Expand Down
79 changes: 28 additions & 51 deletions filebeat/inputsource/tcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,45 +27,46 @@ import (

"golang.org/x/net/netutil"

"github.com/elastic/beats/libbeat/common/atomic"
"github.com/elastic/beats/libbeat/common/transport/tlscommon"
"github.com/elastic/beats/libbeat/logp"
"github.com/elastic/beats/libbeat/outputs/transport"
)

// Server represent a TCP server
type Server struct {
sync.RWMutex
config *Config
Listener net.Listener
clients map[ConnectionHandler]struct{}
wg sync.WaitGroup
done chan struct{}
factory ClientFactory
log *logp.Logger
tlsConfig *transport.TLSConfig
config *Config
Listener net.Listener
wg sync.WaitGroup
done chan struct{}
factory HandlerFactory
log *logp.Logger
tlsConfig *transport.TLSConfig
closer *Closer
clientsCount atomic.Int
}

// New creates a new tcp server
func New(
config *Config,
factory ClientFactory,
factory HandlerFactory,
) (*Server, error) {
tlsConfig, err := tlscommon.LoadTLSServerConfig(config.TLS)
if err != nil {
return nil, err
}

if factory == nil {
return nil, fmt.Errorf("ClientFactory can't be empty")
return nil, fmt.Errorf("HandlerFactory can't be empty")
}

return &Server{
config: config,
clients: make(map[ConnectionHandler]struct{}, 0),
done: make(chan struct{}),
factory: factory,
log: logp.NewLogger("tcp").With("address", config.Host),
tlsConfig: tlsConfig,
closer: NewCloser(nil),
}, nil
}

Expand All @@ -77,6 +78,7 @@ func (s *Server) Start() error {
return err
}

s.closer.callback = func() { s.Listener.Close() }
s.log.Info("Started listening for TCP connection")

s.wg.Add(1)
Expand All @@ -97,27 +99,28 @@ func (s *Server) run() {
conn, err := s.Listener.Accept()
if err != nil {
select {
case <-s.done:
case <-s.closer.Done():
return
default:
s.log.Debugw("Can not accept the connection", "error", err)
continue
}
}

client := s.factory(*s.config)
handler := s.factory(*s.config)
closer := WithCloser(s.closer, func() { conn.Close() })

s.wg.Add(1)
go func() {
defer logp.Recover("recovering from a tcp client crash")
defer s.wg.Done()
defer conn.Close()
defer closer.Close()

s.registerClient(client)
defer s.unregisterClient(client)
s.log.Debugw("New client", "remote_address", conn.RemoteAddr(), "total", s.clientsCount())
s.registerHandler()
defer s.unregisterHandler()
s.log.Debugw("New client", "remote_address", conn.RemoteAddr(), "total", s.clientsCount.Load())

err := client.Handle(conn)
err := handler.Handle(closer, conn)
if err != nil {
s.log.Debugw("client error", "error", err)
}
Expand All @@ -127,7 +130,7 @@ func (s *Server) run() {
"remote_address",
conn.RemoteAddr(),
"total",
s.clientsCount(),
s.clientsCount.Load(),
)
}()
}
Expand All @@ -136,37 +139,17 @@ func (s *Server) run() {
// Stop stops accepting new incoming TCP connection and Close any active clients
func (s *Server) Stop() {
s.log.Info("Stopping TCP server")
close(s.done)
s.Listener.Close()
for _, client := range s.allClients() {
client.Close()
}
s.closer.Close()
s.wg.Wait()
s.log.Info("TCP server stopped")
}

func (s *Server) registerClient(client ConnectionHandler) {
s.Lock()
defer s.Unlock()
s.clients[client] = struct{}{}
}

func (s *Server) unregisterClient(client ConnectionHandler) {
s.Lock()
defer s.Unlock()
delete(s.clients, client)
func (s *Server) registerHandler() {
s.clientsCount.Inc()
}

func (s *Server) allClients() []ConnectionHandler {
s.RLock()
defer s.RUnlock()
currentClients := make([]ConnectionHandler, len(s.clients))
idx := 0
for client := range s.clients {
currentClients[idx] = client
idx++
}
return currentClients
func (s *Server) unregisterHandler() {
s.clientsCount.Dec()
}

func (s *Server) createServer() (net.Listener, error) {
Expand All @@ -192,12 +175,6 @@ func (s *Server) createServer() (net.Listener, error) {
return l, nil
}

func (s *Server) clientsCount() int {
s.RLock()
defer s.RUnlock()
return len(s.clients)
}

// SplitFunc allows to create a `bufio.SplitFunc` based on a delimiter provided.
func SplitFunc(lineDelimiter []byte) bufio.SplitFunc {
ld := []byte(lineDelimiter)
Expand Down
3 changes: 2 additions & 1 deletion filebeat/inputsource/tcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/dustin/go-humanize"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/elastic/beats/filebeat/inputsource"
"github.com/elastic/beats/libbeat/common"
Expand Down Expand Up @@ -180,7 +181,7 @@ func TestReceiveEventsAndMetadata(t *testing.T) {
defer server.Stop()

conn, err := net.Dial("tcp", server.Listener.Addr().String())
assert.NoError(t, err)
require.NoError(t, err)
fmt.Fprint(conn, test.messageSent)
conn.Close()

Expand Down

0 comments on commit 106ad06

Please sign in to comment.