From 960580771994d7ad81e02480781d5d465a83b2f5 Mon Sep 17 00:00:00 2001 From: Corentin Chary Date: Tue, 2 Jan 2024 11:51:39 +0100 Subject: [PATCH] uds: implement a connect timeout option This is similar to https://github.com/DataDog/java-dogstatsd-client/pull/228 --- statsd/options.go | 13 +++++++++ statsd/statsd.go | 12 ++++---- statsd/telemetry.go | 6 ++-- statsd/uds.go | 71 +++++++++++++++++---------------------------- statsd/uds_test.go | 50 ++++++++++++++++++++++++++----- 5 files changed, 91 insertions(+), 61 deletions(-) diff --git a/statsd/options.go b/statsd/options.go index dc00b80d..29e09800 100644 --- a/statsd/options.go +++ b/statsd/options.go @@ -17,6 +17,7 @@ var ( defaultWorkerCount = 32 defaultSenderQueueSize = 0 defaultWriteTimeout = 100 * time.Millisecond + defaultConnectTimeout = 1000 * time.Millisecond defaultTelemetry = true defaultReceivingMode = mutexMode defaultChannelModeBufferSize = 4096 @@ -40,6 +41,7 @@ type Options struct { workersCount int senderQueueSize int writeTimeout time.Duration + connectTimeout time.Duration telemetry bool receiveMode receivingMode channelModeBufferSize int @@ -65,6 +67,7 @@ func resolveOptions(options []Option) (*Options, error) { workersCount: defaultWorkerCount, senderQueueSize: defaultSenderQueueSize, writeTimeout: defaultWriteTimeout, + connectTimeout: defaultConnectTimeout, telemetry: defaultTelemetry, receiveMode: defaultReceivingMode, channelModeBufferSize: defaultChannelModeBufferSize, @@ -206,6 +209,16 @@ func WithWriteTimeout(writeTimeout time.Duration) Option { } } +// WithConnectTimeout sets the timeout for network connection with the Agent, after this interval the connection +// attempt is aborted. This is only used for UDS connection. This will also reset the connection if nothing can be +// written to it for this duration. +func WithConnectTimeout(connectTimeout time.Duration) Option { + return func(o *Options) error { + o.connectTimeout = connectTimeout + return nil + } +} + // WithChannelMode make the client use channels to receive metrics // // This determines how the client receive metrics from the app (for example when calling the `Gauge()` method). diff --git a/statsd/statsd.go b/statsd/statsd.go index bbb14c88..33792a53 100644 --- a/statsd/statsd.go +++ b/statsd/statsd.go @@ -368,7 +368,7 @@ func parseAgentURL(agentURL string) string { return "" } -func createWriter(addr string, writeTimeout time.Duration) (Transport, string, error) { +func createWriter(addr string, writeTimeout time.Duration, connectTimeout time.Duration) (Transport, string, error) { addr = resolveAddr(addr) if addr == "" { return nil, "", errors.New("No address passed and autodetection from environment failed") @@ -379,13 +379,13 @@ func createWriter(addr string, writeTimeout time.Duration) (Transport, string, e w, err := newWindowsPipeWriter(addr, writeTimeout) return w, writerWindowsPipe, err case strings.HasPrefix(addr, UnixAddressPrefix): - w, err := newUDSWriter(addr[len(UnixAddressPrefix):], writeTimeout, "") + w, err := newUDSWriter(addr[len(UnixAddressPrefix):], writeTimeout, connectTimeout, "") return w, writerNameUDS, err case strings.HasPrefix(addr, UnixAddressDatagramPrefix): - w, err := newUDSWriter(addr[len(UnixAddressDatagramPrefix):], writeTimeout, "unixgram") + w, err := newUDSWriter(addr[len(UnixAddressDatagramPrefix):], writeTimeout, connectTimeout, "unixgram") return w, writerNameUDS, err case strings.HasPrefix(addr, UnixAddressStreamPrefix): - w, err := newUDSWriter(addr[len(UnixAddressStreamPrefix):], writeTimeout, "unix") + w, err := newUDSWriter(addr[len(UnixAddressStreamPrefix):], writeTimeout, connectTimeout, "unix") return w, writerNameUDS, err default: w, err := newUDPWriter(addr, writeTimeout) @@ -401,7 +401,7 @@ func New(addr string, options ...Option) (*Client, error) { return nil, err } - w, writerType, err := createWriter(addr, o.writeTimeout) + w, writerType, err := createWriter(addr, o.writeTimeout, o.connectTimeout) if err != nil { return nil, err } @@ -542,7 +542,7 @@ func newWithWriter(w Transport, o *Options, writerName string) (*Client, error) c.telemetryClient = newTelemetryClient(&c, c.agg != nil) } else { var err error - c.telemetryClient, err = newTelemetryClientWithCustomAddr(&c, o.telemetryAddr, c.agg != nil, bufferPool, o.writeTimeout) + c.telemetryClient, err = newTelemetryClientWithCustomAddr(&c, o.telemetryAddr, c.agg != nil, bufferPool, o.writeTimeout, o.connectTimeout) if err != nil { return nil, err } diff --git a/statsd/telemetry.go b/statsd/telemetry.go index 53c12116..61025c37 100644 --- a/statsd/telemetry.go +++ b/statsd/telemetry.go @@ -138,8 +138,10 @@ func newTelemetryClient(c *Client, aggregationEnabled bool) *telemetryClient { return t } -func newTelemetryClientWithCustomAddr(c *Client, telemetryAddr string, aggregationEnabled bool, pool *bufferPool, writeTimeout time.Duration) (*telemetryClient, error) { - telemetryWriter, _, err := createWriter(telemetryAddr, writeTimeout) +func newTelemetryClientWithCustomAddr(c *Client, telemetryAddr string, aggregationEnabled bool, pool *bufferPool, + writeTimeout time.Duration, connectTimeout time.Duration, +) (*telemetryClient, error) { + telemetryWriter, _, err := createWriter(telemetryAddr, writeTimeout, connectTimeout) if err != nil { return nil, fmt.Errorf("Could not resolve telemetry address: %v", err) } diff --git a/statsd/uds.go b/statsd/uds.go index 4abf3d30..09518992 100644 --- a/statsd/uds.go +++ b/statsd/uds.go @@ -21,13 +21,15 @@ type udsWriter struct { conn net.Conn // write timeout writeTimeout time.Duration - sync.RWMutex // used to lock conn / writer can replace it + // connect timeout + connectTimeout time.Duration + sync.RWMutex // used to lock conn / writer can replace it } // newUDSWriter returns a pointer to a new udsWriter given a socket file path as addr. -func newUDSWriter(addr string, writeTimeout time.Duration, transport string) (*udsWriter, error) { +func newUDSWriter(addr string, writeTimeout time.Duration, connectTimeout time.Duration, transport string) (*udsWriter, error) { // Defer connection to first Write - writer := &udsWriter{addr: addr, transport: transport, conn: nil, writeTimeout: writeTimeout} + writer := &udsWriter{addr: addr, transport: transport, conn: nil, writeTimeout: writeTimeout, connectTimeout: connectTimeout} return writer, nil } @@ -43,20 +45,11 @@ func (w *udsWriter) GetTransportName() string { } } -// retryOnWriteErr returns true if we should retry writing after a write error -func (w *udsWriter) retryOnWriteErr(err error, stream bool) bool { - // Never retry when using unixgram (to preserve the historical behavior) - if !stream { - return false - } - // Otherwise we retry on timeout because we might have written a partial packet - if networkError, ok := err.(net.Error); ok && networkError.Timeout() { +func (w *udsWriter) shouldCloseConnection(err error, partialWrite bool) bool { + if err != nil && partialWrite { + // We can't recover from a partial write return true } - return false -} - -func (w *udsWriter) shouldCloseConnection(err error) bool { if err, isNetworkErr := err.(net.Error); err != nil && (!isNetworkErr || !err.Timeout()) { // Statsd server disconnected, retry connecting at next packet return true @@ -64,35 +57,11 @@ func (w *udsWriter) shouldCloseConnection(err error) bool { return false } -// writeFull writes the whole data to the UDS connection -func (w *udsWriter) writeFull(data []byte, stopIfNoneWritten bool, stream bool) (int, error) { - written := 0 - for written < len(data) { - n, e := w.conn.Write(data[written:]) - written += n - - // If we haven't written anything, and we're supposed to stop if we can't write anything, return the error - if written == 0 && stopIfNoneWritten { - return written, e - } - - // If there's an error, check if it is retryable - if e != nil && !w.retryOnWriteErr(e, stream) { - return written, e - } - - // When using "unix" we need to be able to finish to write partially written packets once we have started. - if stream { - w.conn.SetWriteDeadline(time.Time{}) - } - } - return written, nil -} - // Write data to the UDS connection with write timeout and minimal error handling: // create the connection if nil, and destroy it if the statsd server has disconnected func (w *udsWriter) Write(data []byte) (int, error) { var n int + partialWrite := false conn, err := w.ensureConnection() if err != nil { return 0, err @@ -107,15 +76,26 @@ func (w *udsWriter) Write(data []byte) (int, error) { if stream { bs := []byte{0, 0, 0, 0} binary.LittleEndian.PutUint32(bs, uint32(len(data))) - _, err = w.writeFull(bs, true, true) + _, err = w.conn.Write(bs) + + partialWrite = true + + // W need to be able to finish to write partially written packets once we have started. + // But we will reset the connection if we can't write anything at all for a long time. + w.conn.SetWriteDeadline(time.Now().Add(w.connectTimeout)) + + // Continue writing only if we've written the length of the packet if err == nil { - n, err = w.writeFull(data, false, true) + n, err = w.conn.Write(data) + if err == nil { + partialWrite = false + } } } else { - n, err = w.writeFull(data, true, false) + n, err = w.conn.Write(data) } - if w.shouldCloseConnection(err) { + if w.shouldCloseConnection(err, partialWrite) { w.unsetConnection() } return n, err @@ -133,7 +113,7 @@ func (w *udsWriter) tryToDial(network string) (net.Conn, error) { if err != nil { return nil, err } - newConn, err := net.Dial(udsAddr.Network(), udsAddr.String()) + newConn, err := net.DialTimeout(udsAddr.Network(), udsAddr.String(), w.connectTimeout) if err != nil { return nil, err } @@ -182,5 +162,6 @@ func (w *udsWriter) ensureConnection() (net.Conn, error) { func (w *udsWriter) unsetConnection() { w.Lock() defer w.Unlock() + _ = w.conn.Close() w.conn = nil } diff --git a/statsd/uds_test.go b/statsd/uds_test.go index d0c139c3..1317a9c7 100644 --- a/statsd/uds_test.go +++ b/statsd/uds_test.go @@ -5,13 +5,14 @@ package statsd import ( "encoding/binary" - "golang.org/x/net/nettest" "math/rand" "net" "os" "testing" "time" + "golang.org/x/net/nettest" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -21,13 +22,13 @@ func init() { } func TestNewUDSWriter(t *testing.T) { - w, err := newUDSWriter("/tmp/test.socket", 100*time.Millisecond, "") + w, err := newUDSWriter("/tmp/test.socket", 100*time.Millisecond, 1000*time.Millisecond, "") assert.NotNil(t, w) assert.NoError(t, err) - w, err = newUDSWriter("/tmp/test.socket", 100*time.Millisecond, "unix") + w, err = newUDSWriter("/tmp/test.socket", 100*time.Millisecond, 1000*time.Millisecond, "unix") assert.NotNil(t, w) assert.NoError(t, err) - w, err = newUDSWriter("/tmp/test.socket", 100*time.Millisecond, "unixgram") + w, err = newUDSWriter("/tmp/test.socket", 100*time.Millisecond, 1000*time.Millisecond, "unixgram") assert.NotNil(t, w) assert.NoError(t, err) } @@ -44,7 +45,7 @@ func TestUDSDatagramWrite(t *testing.T) { err = os.Chmod(socketPath, 0722) require.NoError(t, err) - w, err := newUDSWriter(socketPath, 100*time.Millisecond, "") + w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "") require.Nil(t, err) require.NotNil(t, w) @@ -74,7 +75,7 @@ func TestUDSDatagramWriteUnsetConnection(t *testing.T) { err = os.Chmod(socketPath, 0722) require.NoError(t, err) - w, err := newUDSWriter(socketPath, 100*time.Millisecond, "") + w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "") require.Nil(t, err) require.NotNil(t, w) @@ -107,7 +108,7 @@ func TestUDSStreamWrite(t *testing.T) { err = os.Chmod(socketPath, 0722) require.NoError(t, err) - w, err := newUDSWriter(socketPath, 100*time.Millisecond, "") + w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "") require.Nil(t, err) require.NotNil(t, w) @@ -120,6 +121,7 @@ func TestUDSStreamWrite(t *testing.T) { require.NoError(t, err) assert.Equal(t, len(msg), n) + // This works because the kernel accepts sockets before the accept call if conn == nil { conn, err = listener.Accept() require.NoError(t, err) @@ -148,7 +150,7 @@ func TestUDSStreamWriteUnsetConnection(t *testing.T) { err = os.Chmod(socketPath, 0722) require.NoError(t, err) - w, err := newUDSWriter(socketPath, 100*time.Millisecond, "") + w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "") require.Nil(t, err) require.NotNil(t, w) @@ -161,6 +163,7 @@ func TestUDSStreamWriteUnsetConnection(t *testing.T) { require.NoError(t, err) assert.Equal(t, len(msg), n) + // This works because the kernel accepts sockets before the accept call if conn == nil { conn, err = listener.Accept() require.NoError(t, err) @@ -180,3 +183,34 @@ func TestUDSStreamWriteUnsetConnection(t *testing.T) { conn = nil } } + +func TestUDSStreamPartialWrite(t *testing.T) { + socketPath, err := nettest.LocalPath() + require.NoError(t, err) + defer os.Remove(socketPath) + + address, err := net.ResolveUnixAddr("unix", socketPath) + require.NoError(t, err) + listener, err := net.ListenUnix("unix", address) + defer listener.Close() + require.NoError(t, err) + err = os.Chmod(socketPath, 0722) + require.NoError(t, err) + + w, err := newUDSWriter(socketPath, 100*time.Millisecond, 1000*time.Millisecond, "") + require.Nil(t, err) + require.NotNil(t, w) + + // Force a connection + w.ensureConnection() + // Set a very low buffer size to force a partial write, but still enough to write the header + w.conn.(*net.UnixConn).SetWriteBuffer(8) + + msg := []byte("some data") + n, err := w.Write(msg) + require.Error(t, err) + assert.Lessf(t, n, len(msg), "n: %d, len(msg): %d", n, len(msg)) + + // The connection should be dropped + assert.Nil(t, w.conn) +}