Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a race condition in the TCP input #13038

Merged
merged 7 commits into from
Jul 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.next.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ https://github.com/elastic/beats/compare/v7.0.0-alpha2...master[Check the HEAD d

- Add read_buffer configuration option. {pull}11739[11739]
- `convert_timezone` option is removed and locale is always added to the event so timezone is used when parsing the timestamp, this behaviour can be overridden with processors. {pull}12410[12410]
- Fix a race condition in the TCP input when closing the client socket. {pull}13038[13038]

*Heartbeat*

Expand Down
130 changes: 130 additions & 0 deletions filebeat/inputsource/tcp/closeref.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package tcp

import (
"sync"

"github.com/pkg/errors"
)

// CloserFunc is the cleanup callback a Closer invokes from `Close()`, it takes no arguments and
// returns nothing (e.g. a closure that closes a socket).
type CloserFunc func()

// ErrClosed is the error stored by `Close()` and reported by `Err()` once a Closer has been closed.
var ErrClosed = errors.New("closer is closed")

// CloseRef implements a subset of the context.Context interface and is used to synchronize
// the shutdown of multiple goroutines. It is the read-only view of a Closer: it lets a consumer
// observe the shutdown without being able to trigger it.
type CloseRef interface {
// Done returns a channel that consumers can select on to detect the shutdown.
Done() <-chan struct{}
// Err reports whether the shutdown happened; Closer returns nil while open and ErrClosed after.
Err() error
}

// Closer implements a shutdown strategy when dealing with multiple goroutines. Closers form a
// tree: calling `Close()` on a node runs its callback, closes its Done channel and calls `Close()`
// on every child registered through `WithCloser`.
//
// NOTE: `Close()` may be called more than once, but the close is only propagated the first time.
// NOTE(review): the original comment also states that a closed node is removed from its parent;
// confirm that against the implementation of `Close()`.
type Closer struct {
mu sync.Mutex // guards err and children
done chan struct{} // closed by Close(); exposed through Done()
err error // nil while open, ErrClosed once closed
parent *Closer // set by WithCloser; nil for a root created by NewCloser
children map[*Closer]struct{} // closers registered through addChild; allocated lazily
callback CloserFunc // optional cleanup hook invoked by Close()
}

// Close closes the Closer and propagates the close to every registered child.
//
// The first call closes the Done channel, runs the cleanup callback (this can
// be used for custom cleanup like closing a TCP socket), closes all children
// and finally detaches this node from its parent so the parent no longer keeps
// a reference to it. Subsequent calls are no-ops; a concurrent call blocks
// until the first one has finished.
func (c *Closer) Close() {
	if !c.shutdown() {
		return
	}

	// Detach from the parent only when Close is called directly on this node.
	// When the parent is the one closing us it goes through shutdown() and
	// drops its whole children map itself; calling back into the parent from
	// there would deadlock on the parent's mutex, which it still holds.
	if c.parent != nil {
		c.parent.removeChild(c)
	}
}

// shutdown performs the one-time close of c and of its subtree while holding
// c.mu. It reports whether this call did the work (false when c was already
// closed).
func (c *Closer) shutdown() bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.err != nil {
		return false
	}

	// Signal first, so that goroutines woken up by the callback (e.g. a read
	// failing because its socket was just closed) already observe Done() as
	// closed and can tell a shutdown apart from a genuine I/O error.
	close(c.done)

	if c.callback != nil {
		c.callback()
	}

	// Propagate the close to the children. Ranging over a nil map is a no-op.
	for child := range c.children {
		child.shutdown()
	}
	c.children = nil

	c.err = ErrClosed
	return true
}

// Done returns the synchronization channel, the channel will be closed if `Close()` was called on
// the current node or on any parent it may have. The channel is created by the constructors
// (`NewCloser` / `WithCloser`) and is returned here without taking the mutex.
func (c *Closer) Done() <-chan struct{} {
return c.done
}

// Err returns an error if the Closer was already closed.
func (c *Closer) Err() error {
c.mu.Lock()
err := c.err
c.mu.Unlock()
return err
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any code that can set c.err besides close. That is, one could replace the mutex with a simple atomic.Bool and implement the body like: if !c.active.Load() { return ErrClosed }; return nil.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still have to lock the children map though? I can revisit it when we move it to an external repo.

}

// removeChild unregisters child from c's set of children. Removing a closer
// that is not registered, or removing from a node without children, is a no-op.
func (c *Closer) removeChild(child *Closer) {
	c.mu.Lock()
	defer c.mu.Unlock()

	delete(c.children, child)
}

// addChild registers child so that closing c also closes child.
//
// If c has already been closed nothing would ever propagate the close to the
// new child, so instead of tracking it the child is closed right away (the
// same way a context derived from an already cancelled parent is cancelled
// immediately).
func (c *Closer) addChild(child *Closer) {
	c.mu.Lock()
	if c.err != nil {
		// Release our lock first: Close acquires locks itself and must not
		// run while c.mu is held.
		c.mu.Unlock()
		child.Close()
		return
	}

	// The map is allocated lazily, on the first registration.
	if c.children == nil {
		c.children = make(map[*Closer]struct{})
	}
	c.children[child] = struct{}{}
	c.mu.Unlock()
}

// WithCloser creates a new Closer and registers it as a child of parent.
// fn is the optional cleanup callback of the new closer (it may be nil).
func WithCloser(parent *Closer, fn CloserFunc) *Closer {
	child := new(Closer)
	child.done = make(chan struct{})
	child.parent = parent
	child.callback = fn

	// Hook the new node into the tree before handing it out.
	parent.addChild(child)

	return child
}

// NewCloser creates a root Closer, i.e. one without a parent.
// fn is the optional cleanup callback of the closer (it may be nil).
func NewCloser(fn CloserFunc) *Closer {
	root := new(Closer)
	root.done = make(chan struct{})
	root.callback = fn
	return root
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import (

// splitHandler is a TCP client that has splitting capabilities.
type splitHandler struct {
conn net.Conn
callback inputsource.NetworkFunc
done chan struct{}
metadata inputsource.NetworkMetadata
Expand All @@ -42,19 +41,23 @@ type splitHandler struct {
timeout time.Duration
}

// ClientFactory returns a ConnectionHandler func
type ClientFactory func(config Config) ConnectionHandler
// HandlerFactory returns a ConnectionHandler func
type HandlerFactory func(config Config) ConnectionHandler

// ConnectionHandler interface provides mechanisms for handling of incoming TCP connections
type ConnectionHandler interface {
Handle(conn net.Conn) error
Close()
Handle(CloseRef, net.Conn) error
}

// SplitHandlerFactory allows creation of a ConnectionHandler that can do splitting of messages received on a TCP connection.
func SplitHandlerFactory(callback inputsource.NetworkFunc, splitFunc bufio.SplitFunc) ClientFactory {
func SplitHandlerFactory(callback inputsource.NetworkFunc, splitFunc bufio.SplitFunc) HandlerFactory {
return func(config Config) ConnectionHandler {
return newSplitHandler(callback, splitFunc, uint64(config.MaxMessageSize), config.Timeout)
return newSplitHandler(
callback,
splitFunc,
uint64(config.MaxMessageSize),
config.Timeout,
)
}
}

Expand All @@ -76,8 +79,7 @@ func newSplitHandler(
}

// Handle takes a connection as input and processes data received on it.
func (c *splitHandler) Handle(conn net.Conn) error {
c.conn = conn
func (c *splitHandler) Handle(closer CloseRef, conn net.Conn) error {
c.metadata = inputsource.NetworkMetadata{
RemoteAddr: conn.RemoteAddr(),
TLS: extractSSLInformation(conn),
Expand All @@ -97,7 +99,7 @@ func (c *splitHandler) Handle(conn net.Conn) error {
if err != nil {
// we are forcing a Close on the socket, lets ignore any error that could happen.
select {
case <-c.done:
case <-closer.Done():
break
default:
}
Expand All @@ -121,12 +123,6 @@ func (c *splitHandler) Handle(conn net.Conn) error {
return nil
}

// Close is used to perform clean up before the client is released.
func (c *splitHandler) Close() {
close(c.done)
c.conn.Close()
}

func extractSSLInformation(c net.Conn) *inputsource.TLSMetadata {
if tls, ok := c.(*tls.Conn); ok {
state := tls.ConnectionState()
Expand Down
79 changes: 28 additions & 51 deletions filebeat/inputsource/tcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,45 +27,46 @@ import (

"golang.org/x/net/netutil"

"github.com/elastic/beats/libbeat/common/atomic"
"github.com/elastic/beats/libbeat/common/transport/tlscommon"
"github.com/elastic/beats/libbeat/logp"
"github.com/elastic/beats/libbeat/outputs/transport"
)

// Server represents a TCP server
type Server struct {
sync.RWMutex
config *Config
Listener net.Listener
clients map[ConnectionHandler]struct{}
wg sync.WaitGroup
done chan struct{}
factory ClientFactory
log *logp.Logger
tlsConfig *transport.TLSConfig
config *Config
Listener net.Listener
wg sync.WaitGroup
done chan struct{}
factory HandlerFactory
log *logp.Logger
tlsConfig *transport.TLSConfig
closer *Closer
clientsCount atomic.Int
}

// New creates a new tcp server
func New(
config *Config,
factory ClientFactory,
factory HandlerFactory,
) (*Server, error) {
tlsConfig, err := tlscommon.LoadTLSServerConfig(config.TLS)
if err != nil {
return nil, err
}

if factory == nil {
return nil, fmt.Errorf("ClientFactory can't be empty")
return nil, fmt.Errorf("HandlerFactory can't be empty")
}

return &Server{
config: config,
clients: make(map[ConnectionHandler]struct{}, 0),
done: make(chan struct{}),
factory: factory,
log: logp.NewLogger("tcp").With("address", config.Host),
tlsConfig: tlsConfig,
closer: NewCloser(nil),
}, nil
}

Expand All @@ -77,6 +78,7 @@ func (s *Server) Start() error {
return err
}

s.closer.callback = func() { s.Listener.Close() }
s.log.Info("Started listening for TCP connection")

s.wg.Add(1)
Expand All @@ -97,27 +99,28 @@ func (s *Server) run() {
conn, err := s.Listener.Accept()
if err != nil {
select {
case <-s.done:
case <-s.closer.Done():
return
default:
s.log.Debugw("Can not accept the connection", "error", err)
continue
}
}

client := s.factory(*s.config)
handler := s.factory(*s.config)
closer := WithCloser(s.closer, func() { conn.Close() })

s.wg.Add(1)
go func() {
defer logp.Recover("recovering from a tcp client crash")
defer s.wg.Done()
defer conn.Close()
defer closer.Close()

s.registerClient(client)
defer s.unregisterClient(client)
s.log.Debugw("New client", "remote_address", conn.RemoteAddr(), "total", s.clientsCount())
s.registerHandler()
defer s.unregisterHandler()
s.log.Debugw("New client", "remote_address", conn.RemoteAddr(), "total", s.clientsCount.Load())

err := client.Handle(conn)
err := handler.Handle(closer, conn)
if err != nil {
s.log.Debugw("client error", "error", err)
}
Expand All @@ -127,7 +130,7 @@ func (s *Server) run() {
"remote_address",
conn.RemoteAddr(),
"total",
s.clientsCount(),
s.clientsCount.Load(),
)
}()
}
Expand All @@ -136,37 +139,17 @@ func (s *Server) run() {
// Stop stops accepting new incoming TCP connections and closes any active clients
func (s *Server) Stop() {
s.log.Info("Stopping TCP server")
close(s.done)
s.Listener.Close()
for _, client := range s.allClients() {
client.Close()
}
s.closer.Close()
s.wg.Wait()
s.log.Info("TCP server stopped")
}

func (s *Server) registerClient(client ConnectionHandler) {
s.Lock()
defer s.Unlock()
s.clients[client] = struct{}{}
}

func (s *Server) unregisterClient(client ConnectionHandler) {
s.Lock()
defer s.Unlock()
delete(s.clients, client)
func (s *Server) registerHandler() {
s.clientsCount.Inc()
}

func (s *Server) allClients() []ConnectionHandler {
s.RLock()
defer s.RUnlock()
currentClients := make([]ConnectionHandler, len(s.clients))
idx := 0
for client := range s.clients {
currentClients[idx] = client
idx++
}
return currentClients
func (s *Server) unregisterHandler() {
s.clientsCount.Dec()
}

func (s *Server) createServer() (net.Listener, error) {
Expand All @@ -192,12 +175,6 @@ func (s *Server) createServer() (net.Listener, error) {
return l, nil
}

func (s *Server) clientsCount() int {
s.RLock()
defer s.RUnlock()
return len(s.clients)
}

// SplitFunc allows to create a `bufio.SplitFunc` based on a delimiter provided.
func SplitFunc(lineDelimiter []byte) bufio.SplitFunc {
ld := []byte(lineDelimiter)
Expand Down
3 changes: 2 additions & 1 deletion filebeat/inputsource/tcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/dustin/go-humanize"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/elastic/beats/filebeat/inputsource"
"github.com/elastic/beats/libbeat/common"
Expand Down Expand Up @@ -180,7 +181,7 @@ func TestReceiveEventsAndMetadata(t *testing.T) {
defer server.Stop()

conn, err := net.Dial("tcp", server.Listener.Addr().String())
assert.NoError(t, err)
require.NoError(t, err)
fmt.Fprint(conn, test.messageSent)
conn.Close()

Expand Down