Skip to content

Commit

Permalink
Implement connection timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
iksaif committed Dec 21, 2023
1 parent 4c55e60 commit aaee24f
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 21 deletions.
16 changes: 10 additions & 6 deletions src/main/java/com/timgroup/statsd/NonBlockingStatsDClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ String tag() {

public static final boolean DEFAULT_ENABLE_AGGREGATION = true;
public static final boolean DEFAULT_ENABLE_ORIGIN_DETECTION = true;
public static final int SOCKET_CONNECT_TIMEOUT_MS = 1000;

public static final String CLIENT_TAG = "client:java";
public static final String CLIENT_VERSION_TAG = "client_version:";
Expand Down Expand Up @@ -240,6 +241,9 @@ protected static String format(ThreadLocal<NumberFormat> formatter, Number value
* The client tries to read the container ID by parsing the file /proc/self/cgroup.
* This is not supported on Windows.
* The client prioritizes the value passed via or entityID or DD_ENTITY_ID (if set) over the container ID.
* @param connectionTimeout
* the timeout in milliseconds for connecting to the StatsD server. Applies to unix sockets only.
* It is also used to detect if a connection is still alive and re-establish a new one if needed.
* @throws StatsDClientException
* if the client could not be started
*/
Expand All @@ -249,7 +253,7 @@ private NonBlockingStatsDClient(final String prefix, final int queueSize, final
final int maxPacketSizeBytes, String entityID, final int poolSize, final int processorWorkers,
final int senderWorkers, boolean blocking, final boolean enableTelemetry, final int telemetryFlushInterval,
final int aggregationFlushInterval, final int aggregationShards, final ThreadFactory customThreadFactory,
String containerID, final boolean originDetectionEnabled)
String containerID, final boolean originDetectionEnabled, final int connectionTimeout)
throws StatsDClientException {

if ((prefix != null) && (!prefix.isEmpty())) {
Expand Down Expand Up @@ -296,7 +300,7 @@ private NonBlockingStatsDClient(final String prefix, final int queueSize, final
}

try {
clientChannel = createByteChannel(addressLookup, timeout, bufferSize);
clientChannel = createByteChannel(addressLookup, timeout, connectionTimeout, bufferSize);

ThreadFactory threadFactory = customThreadFactory != null ? customThreadFactory : new StatsDThreadFactory();

Expand All @@ -315,7 +319,7 @@ private NonBlockingStatsDClient(final String prefix, final int queueSize, final
telemetryClientChannel = clientChannel;
telemetryStatsDProcessor = statsDProcessor;
} else {
telemetryClientChannel = createByteChannel(telemetryAddressLookup, timeout, bufferSize);
telemetryClientChannel = createByteChannel(telemetryAddressLookup, timeout, connectionTimeout, bufferSize);

// similar settings, but a single worker and non-blocking.
telemetryStatsDProcessor = createProcessor(queueSize, handler, maxPacketSizeBytes,
Expand Down Expand Up @@ -376,7 +380,7 @@ public NonBlockingStatsDClient(final NonBlockingStatsDClientBuilder builder) thr
builder.blocking, builder.enableTelemetry, builder.telemetryFlushInterval,
(builder.enableAggregation ? builder.aggregationFlushInterval : 0),
builder.aggregationShards, builder.threadFactory, builder.containerID,
builder.originDetectionEnabled);
builder.originDetectionEnabled, builder.connectionTimeout);
}

protected StatsDProcessor createProcessor(final int queueSize, final StatsDClientErrorHandler handler,
Expand Down Expand Up @@ -477,7 +481,7 @@ StringBuilder tagString(final String[] tags, StringBuilder builder) {
return tagString(tags, constantTagsRendered, builder);
}

ClientChannel createByteChannel(Callable<SocketAddress> addressLookup, int timeout, int bufferSize) throws Exception {
ClientChannel createByteChannel(Callable<SocketAddress> addressLookup, int timeout, int connectionTimeout, int bufferSize) throws Exception {
final SocketAddress address = addressLookup.call();
if (address instanceof NamedPipeSocketAddress) {
return new NamedPipeClientChannel((NamedPipeSocketAddress) address);
Expand All @@ -489,7 +493,7 @@ ClientChannel createByteChannel(Callable<SocketAddress> addressLookup, int timeo
// Allow us to support `unix://` for both kind of sockets like in go.
switch (unixAddr.getTransportType()) {
case UDS_STREAM:
return new UnixStreamClientChannel(unixAddr.getAddress(), timeout, bufferSize);
return new UnixStreamClientChannel(unixAddr.getAddress(), timeout, connectionTimeout, bufferSize);
case UDS_DATAGRAM:
case UDS:
return new UnixDatagramClientChannel(unixAddr.getAddress(), timeout, bufferSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class NonBlockingStatsDClientBuilder implements Cloneable {
public int aggregationFlushInterval = StatsDAggregator.DEFAULT_FLUSH_INTERVAL;
public int aggregationShards = StatsDAggregator.DEFAULT_SHARDS;
public boolean originDetectionEnabled = NonBlockingStatsDClient.DEFAULT_ENABLE_ORIGIN_DETECTION;
public int connectionTimeout = NonBlockingStatsDClient.SOCKET_CONNECT_TIMEOUT_MS;

public Callable<SocketAddress> addressLookup;
public Callable<SocketAddress> telemetryAddressLookup;
Expand Down Expand Up @@ -71,6 +72,11 @@ public NonBlockingStatsDClientBuilder timeout(int val) {
return this;
}

public NonBlockingStatsDClientBuilder connectionTimeout(int val) {
connectionTimeout = val;
return this;
}

public NonBlockingStatsDClientBuilder bufferPoolSize(int val) {
bufferPoolSize = val;
return this;
Expand Down
67 changes: 54 additions & 13 deletions src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
public class UnixStreamClientChannel implements ClientChannel {
private final UnixSocketAddress address;
private final int timeout;
private final int connectionTimeout;
private final int bufferSize;


private SocketChannel delegate;
private final ByteBuffer delimiterBuffer = ByteBuffer.allocateDirect(Integer.SIZE / Byte.SIZE).order(ByteOrder.LITTLE_ENDIAN);

Expand All @@ -26,10 +28,11 @@ public class UnixStreamClientChannel implements ClientChannel {
*
* @param address Location of named pipe
*/
UnixStreamClientChannel(SocketAddress address, int timeout, int bufferSize) throws IOException {
UnixStreamClientChannel(SocketAddress address, int timeout, int connectionTimeout, int bufferSize) throws IOException {
this.delegate = null;
this.address = (UnixSocketAddress) address;
this.timeout = timeout;
this.connectionTimeout = connectionTimeout;
this.bufferSize = bufferSize;
}

Expand All @@ -39,10 +42,8 @@ public boolean isOpen() {
}

@Override
public int write(ByteBuffer src) throws IOException {
if (delegate == null || !delegate.isConnected()) {
connect();
}
synchronized public int write(ByteBuffer src) throws IOException {
connectIfNeeded();

int size = src.remaining();
if (size == 0) {
Expand All @@ -53,12 +54,12 @@ public int write(ByteBuffer src) throws IOException {
delimiterBuffer.flip();

try {
if (writeAll(delimiterBuffer, true) > 0) {
writeAll(src, false);
long deadline = System.currentTimeMillis() + timeout;
if (writeAll(delimiterBuffer, true, deadline) > 0) {
writeAll(src, false, deadline);
}
} catch (IOException e) {
delegate.close();
delegate = null;
disconnect();
throw e;
}

Expand All @@ -69,10 +70,11 @@ public int write(ByteBuffer src) throws IOException {
* Writes all bytes from the given buffer to the channel.
* @param bb buffer to write
* @param canReturnOnTimeout if true, we return if the channel is blocking and we haven't written anything yet
* @param deadline deadline for the write
* @return number of bytes written
* @throws IOException if the channel is closed or an error occurs
*/
public int writeAll(ByteBuffer bb, boolean canReturnOnTimeout) throws IOException {
public int writeAll(ByteBuffer bb, boolean canReturnOnTimeout, long deadline) throws IOException {
int remaining = bb.remaining();
int written = 0;
while (remaining > 0) {
Expand All @@ -85,26 +87,65 @@ public int writeAll(ByteBuffer bb, boolean canReturnOnTimeout) throws IOExceptio

remaining -= read;
written += read;

if (deadline > 0 && System.currentTimeMillis() > deadline) {
throw new IOException("Write timed out");
}
}
return written;
}

private void connectIfNeeded() throws IOException {
if (delegate == null) {
connect();
}
}

private void disconnect() throws IOException {
if (delegate != null) {
delegate.close();
delegate = null;
}
}

private void connect() throws IOException {
if (this.delegate != null) {
try {
this.delegate.close();
this.delegate = null;
disconnect();
} catch (IOException e) {
// ignore to be sure we don't stay with a broken delegate forever.
}
}
this.delegate = UnixSocketChannel.open(address);

UnixSocketChannel delegate = UnixSocketChannel.create();

long deadline = System.currentTimeMillis() + connectionTimeout;
if (connectionTimeout > 0) {
// Set connect timeout, this should work at least on linux
// https://elixir.bootlin.com/linux/v5.7.4/source/net/unix/af_unix.c#L1696
// We'd have better timeout support if we used Java 16's native Unix domain socket support (JEP 380)
delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, connectionTimeout);
}
delegate.connect(address);
while (!delegate.finishConnect()) {
// wait for connection to be established
try {
Thread.sleep(10);
} catch (InterruptedException e) {
throw new IOException("Interrupted while waiting for connection", e);
}
if (connectionTimeout > 0 && System.currentTimeMillis() > deadline) {
throw new IOException("Connection timed out");
}
}

if (timeout > 0) {
delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, timeout);
}
if (bufferSize > 0) {
delegate.setOption(UnixSocketOptions.SO_SNDBUF, bufferSize);
}
this.delegate = delegate;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1806,7 +1806,7 @@ public NonBlockingStatsDClient build() {
this.originDetectionEnabled(false);
return new NonBlockingStatsDClient(resolve()) {
@Override
ClientChannel createByteChannel(Callable<SocketAddress> addressLookup, int timeout, int bufferSize) throws Exception {
ClientChannel createByteChannel(Callable<SocketAddress> addressLookup, int timeout, int connectionTimeout, int bufferSize) throws Exception {
return new DatagramClientChannel(addressLookup.call()) {
@Override
public int write(ByteBuffer data) throws IOException {
Expand Down Expand Up @@ -1845,7 +1845,7 @@ public NonBlockingStatsDClient build() {
this.bufferPoolSize(1);
return new NonBlockingStatsDClient(resolve()) {
@Override
ClientChannel createByteChannel(Callable<SocketAddress> addressLookup, int timeout, int bufferSize) throws Exception {
ClientChannel createByteChannel(Callable<SocketAddress> addressLookup, int timeout, int connectionTimeout, int bufferSize) throws Exception {
return new DatagramClientChannel(addressLookup.call()) {
@Override
public int write(ByteBuffer data) throws IOException {
Expand Down
22 changes: 22 additions & 0 deletions src/test/java/com/timgroup/statsd/UnixSocketTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public SocketAddress call() throws Exception {
.port(0)
.queueSize(1)
.timeout(1) // non-zero timeout to ensure exception triggered if socket buffer full.
.connectionTimeout(1)
.socketBufferSize(1024 * 1024)
.enableAggregation(false)
.errorHandler(this)
Expand All @@ -96,6 +97,7 @@ public SocketAddress call() throws Exception {
.port(0)
.queueSize(1)
.timeout(1) // non-zero timeout to ensure exception triggered if socket buffer full.
.connectionTimeout(1)
.socketBufferSize(1024 * 1024)
.enableAggregation(false)
.errorHandler(this)
Expand Down Expand Up @@ -218,4 +220,24 @@ public void resist_dsd_timeout() throws Exception {
assertThat(server.messagesReceived(), hasItem("my.prefix.mycount:30|g"));
server.clear();
}

@Test(timeout = 10000L)
public void testConnectionTimeout() throws InterruptedException {
if (transport != "unixstream") {
// Connection timeout is not supported for unixgram
return;
}

// Delay the `accept()` on the server
server.freeze();
client.gauge("mycount", 10);
Thread.sleep(10);
server.unfreeze();

server.waitForMessage();
assertThat(server.messagesReceived(), contains("my.prefix.mycount:10|g"));
server.clear();
assertThat(lastException.getMessage(), nullValue());

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public UnixStreamSocketDummyStatsDServer(String socketPath) throws IOException {
server.socket().bind(new UnixSocketAddress(socketPath));
this.listen();
}

@Override
protected boolean isOpen() {
return server.isOpen();
Expand Down Expand Up @@ -53,6 +54,7 @@ protected void receive(ByteBuffer packet) throws IOException {
private boolean readPacket(SocketChannel channel, ByteBuffer packet) throws IOException {
try {
ByteBuffer delimiterBuffer = ByteBuffer.allocate(Integer.SIZE / Byte.SIZE).order(ByteOrder.LITTLE_ENDIAN);

int read = channel.read(delimiterBuffer);
delimiterBuffer.flip();
if (read <= 0) {
Expand Down

0 comments on commit aaee24f

Please sign in to comment.