Skip to content

Commit

Permalink
Add goaway and reconnection mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
drpmma committed Sep 8, 2023
1 parent a2ca239 commit 65412b0
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public class NettyClientConfig {
private boolean disableCallbackExecutor = false;
private boolean disableNettyWorkerGroup = false;

private long maxReconnectIntervalTimeSeconds = 60;

public boolean isClientCloseSocketIfTimeout() {
return clientCloseSocketIfTimeout;
}
Expand Down Expand Up @@ -181,6 +183,14 @@ public void setDisableNettyWorkerGroup(boolean disableNettyWorkerGroup) {
this.disableNettyWorkerGroup = disableNettyWorkerGroup;
}

public long getMaxReconnectIntervalTimeSeconds() {
return maxReconnectIntervalTimeSeconds;
}

public void setMaxReconnectIntervalTimeSeconds(long maxReconnectIntervalTimeSeconds) {
this.maxReconnectIntervalTimeSeconds = maxReconnectIntervalTimeSeconds;
}

public String getSocksProxyConfig() {
return socksProxyConfig;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.TypeReference;
import com.google.common.base.Stopwatch;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
Expand Down Expand Up @@ -48,6 +49,7 @@
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.cert.CertificateException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand All @@ -57,6 +59,7 @@
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
Expand All @@ -67,6 +70,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.commons.lang3.StringUtils;
import org.apache.rocketmq.common.Pair;
import org.apache.rocketmq.common.ThreadFactoryImpl;
Expand All @@ -82,6 +86,7 @@
import org.apache.rocketmq.remoting.exception.RemotingTimeoutException;
import org.apache.rocketmq.remoting.exception.RemotingTooMuchRequestException;
import org.apache.rocketmq.remoting.protocol.RemotingCommand;
import org.apache.rocketmq.remoting.protocol.ResponseCode;
import org.apache.rocketmq.remoting.proxy.SocksProxyConfig;

public class NettyRemotingClient extends NettyRemotingAbstract implements RemotingClient {
Expand All @@ -97,6 +102,7 @@ public class NettyRemotingClient extends NettyRemotingAbstract implements Remoti
private final Map<String /* cidr */, SocksProxyConfig /* proxy */> proxyMap = new HashMap<>();
private final ConcurrentHashMap<String /* cidr */, Bootstrap> bootstrapMap = new ConcurrentHashMap<>();
private final ConcurrentMap<String /* addr */, ChannelWrapper> channelTables = new ConcurrentHashMap<>();
private final ConcurrentMap<Channel, ChannelWrapper> channelWrapperTables = new ConcurrentHashMap<>();

private final HashedWheelTimer timer = new HashedWheelTimer(r -> new Thread(r, "ClientHouseKeepingService"));

Expand Down Expand Up @@ -354,9 +360,10 @@ public void shutdown() {
this.timer.stop();

for (String addr : this.channelTables.keySet()) {
this.closeChannel(addr, this.channelTables.get(addr).getChannel());
this.channelTables.get(addr).close();
}

this.channelWrapperTables.clear();
this.channelTables.clear();

this.eventLoopGroupWorker.shutdownGracefully();
Expand Down Expand Up @@ -414,7 +421,10 @@ public void closeChannel(final String addr, final Channel channel) {
}

if (removeItemFromTable) {
this.channelTables.remove(addrRemote);
ChannelWrapper channelWrapper = this.channelWrapperTables.remove(channel);
if (channelWrapper != null && channelWrapper.tryClose(channel)) {
this.channelTables.remove(addrRemote);
}
LOGGER.info("closeChannel: the channel[{}] was removed from channel table", addrRemote);
}

Expand Down Expand Up @@ -461,7 +471,10 @@ public void closeChannel(final Channel channel) {
}

if (removeItemFromTable) {
this.channelTables.remove(addrRemote);
ChannelWrapper channelWrapper = this.channelWrapperTables.remove(channel);
if (channelWrapper != null && channelWrapper.tryClose(channel)) {
this.channelTables.remove(addrRemote);
}
LOGGER.info("closeChannel: the channel[{}] was removed from channel table", addrRemote);
RemotingHelper.closeChannel(channel);
}
Expand Down Expand Up @@ -509,7 +522,7 @@ public void updateNameServerAddressList(List<String> addrs) {
if (addr.contains(namesrvAddr)) {
ChannelWrapper channelWrapper = this.channelTables.get(addr);
if (channelWrapper != null) {
closeChannel(channelWrapper.getChannel());
channelWrapper.close();
}
}
}
Expand Down Expand Up @@ -687,8 +700,9 @@ private Channel createChannel(final String addr) throws InterruptedException {
ChannelFuture channelFuture = fetchBootstrap(addr)
.connect(hostAndPort[0], Integer.parseInt(hostAndPort[1]));
LOGGER.info("createChannel: begin to connect remote host[{}] asynchronously", addr);
cw = new ChannelWrapper(channelFuture);
cw = new ChannelWrapper(addr, channelFuture);
this.channelTables.put(addr, cw);
this.channelWrapperTables.put(channelFuture.channel(), cw);
}
} catch (Exception e) {
LOGGER.error("createChannel: create channel exception", e);
Expand Down Expand Up @@ -756,6 +770,60 @@ public void invokeOneway(String addr, RemotingCommand request, long timeoutMilli
}
}

@Override
public CompletableFuture<RemotingCommand> invoke(String addr, RemotingCommand request,
long timeoutMillis) {
CompletableFuture<RemotingCommand> future = new CompletableFuture<>();
try {
final Channel channel = this.getAndCreateChannel(addr);
if (channel != null && channel.isActive()) {
return invokeImpl(channel, request, timeoutMillis).whenComplete((v, t) -> {
if (t == null) {
updateChannelLastResponseTime(addr);
}
}).thenApply(ResponseFuture::getResponseCommand);
} else {
this.closeChannel(addr, channel);
future.completeExceptionally(new RemotingConnectException(addr));
}
} catch (Throwable t) {
future.completeExceptionally(t);
}
return future;
}

@Override
public CompletableFuture<ResponseFuture> invokeImpl(final Channel channel, final RemotingCommand request,
final long timeoutMillis) {
Stopwatch stopwatch = Stopwatch.createStarted();
return super.invokeImpl(channel, request, timeoutMillis).thenCompose(responseFuture -> {
RemotingCommand response = responseFuture.getResponseCommand();
if (response.getCode() == ResponseCode.GO_AWAY) {
ChannelWrapper channelWrapper = channelWrapperTables.computeIfPresent(channel, (channel0, channelWrapper0) -> {
try {
if (channelWrapper0.reconnect()) {
LOGGER.info("Receive go away from channel {}, recreate the channel", channel0);
channelWrapperTables.put(channelWrapper0.getChannel(), channelWrapper0);
}
} catch (Throwable t) {
LOGGER.error("Channel {} reconnect error", channelWrapper0, t);
}
return channelWrapper0;
});
if (channelWrapper != null) {
long duration = stopwatch.elapsed(TimeUnit.MILLISECONDS);
stopwatch.stop();
RemotingCommand retryRequest = RemotingCommand.createRequestCommand(request.getCode(), request.readCustomHeader());
Channel retryChannel = channelWrapper.getChannel();
if (channel != retryChannel) {
return super.invokeImpl(retryChannel, retryRequest, timeoutMillis - duration);
}
}
}
return CompletableFuture.completedFuture(responseFuture);
});
}

@Override
public void registerProcessor(int requestCode, NettyRequestProcessor processor, ExecutorService executor) {
ExecutorService executorThis = executor;
Expand Down Expand Up @@ -875,30 +943,41 @@ public void run() {
}
}

static class ChannelWrapper {
private final ChannelFuture channelFuture;
class ChannelWrapper {
private final ReentrantReadWriteLock lock;
private ChannelFuture channelFuture;
// only affected by sync or async request, oneway is not included.
private ChannelFuture channelToClose;
private long lastResponseTime;
private volatile long lastReconnectTimestamp = 0L;
private final String channelAddress;

public ChannelWrapper(ChannelFuture channelFuture) {
public ChannelWrapper(String address, ChannelFuture channelFuture) {
this.lock = new ReentrantReadWriteLock();
this.channelFuture = channelFuture;
this.lastResponseTime = System.currentTimeMillis();
this.channelAddress = address;
}

public boolean isOK() {
return this.channelFuture.channel() != null && this.channelFuture.channel().isActive();
return getChannel() != null && getChannel().isActive();
}

public boolean isWritable() {
return this.channelFuture.channel().isWritable();
return getChannel().isWritable();
}

private Channel getChannel() {
return this.channelFuture.channel();
return getChannelFuture().channel();
}

public ChannelFuture getChannelFuture() {
return channelFuture;
lock.readLock().lock();
try {
return this.channelFuture;
} finally {
lock.readLock().unlock();
}
}

public long getLastResponseTime() {
Expand All @@ -908,6 +987,52 @@ public long getLastResponseTime() {
public void updateLastResponseTime() {
this.lastResponseTime = System.currentTimeMillis();
}

public boolean reconnect() {
if (lock.writeLock().tryLock()) {
try {
if (lastReconnectTimestamp == 0L || System.currentTimeMillis() - lastReconnectTimestamp > Duration.ofSeconds(nettyClientConfig.getMaxReconnectIntervalTimeSeconds()).toMillis()) {
channelToClose = channelFuture;
String[] hostAndPort = getHostAndPort(channelAddress);
channelFuture = fetchBootstrap(channelAddress)
.connect(hostAndPort[0], Integer.parseInt(hostAndPort[1]));
lastReconnectTimestamp = System.currentTimeMillis();
return true;
}
} finally {
lock.writeLock().unlock();
}
}
return false;
}

public boolean tryClose(Channel channel) {
try {
lock.readLock().lock();
if (channelFuture != null) {
if (channelFuture.channel().equals(channel)) {
return true;
}
}
} finally {
lock.readLock().unlock();
}
return false;
}

public void close() {
try {
lock.writeLock().lock();
if (channelFuture != null) {
closeChannel(channelFuture.channel());
}
if (channelToClose != null) {
closeChannel(channelToClose.channel());
}
} finally {
lock.writeLock().unlock();
}
}
}

class InvokeCallbackWrapper implements InvokeCallback {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ public class ResponseCode extends RemotingSysResponseCode {
public static final int RPC_SEND_TO_CHANNEL_FAILED = -1004;
public static final int RPC_TIME_OUT = -1006;

public static final int GO_AWAY = 1500;

/**
* Controller response code
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.rocketmq.remoting.InvokeCallback;
import org.apache.rocketmq.remoting.exception.RemotingConnectException;
import org.apache.rocketmq.remoting.exception.RemotingException;
import org.apache.rocketmq.remoting.exception.RemotingSendRequestException;
Expand All @@ -37,7 +36,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;

@RunWith(MockitoJUnitRunner.class)
public class NettyRemotingClientTest {
Expand All @@ -57,13 +56,11 @@ public void testInvokeResponse() throws Exception {

RemotingCommand response = RemotingCommand.createResponseCommand(null);
response.setCode(ResponseCode.SUCCESS);
doAnswer(invocation -> {
InvokeCallback callback = invocation.getArgument(3);
ResponseFuture responseFuture = new ResponseFuture(null, request.getOpaque(), 3 * 1000, null, null);
responseFuture.setResponseCommand(response);
callback.operationSucceed(responseFuture.getResponseCommand());
return null;
}).when(remotingClient).invokeAsync(anyString(), any(RemotingCommand.class), anyLong(), any(InvokeCallback.class));
ResponseFuture responseFuture = new ResponseFuture(null, request.getOpaque(), 3 * 1000, null, null);
responseFuture.setResponseCommand(response);
CompletableFuture<RemotingCommand> future0 = new CompletableFuture<>();
future0.complete(responseFuture.getResponseCommand());
doReturn(future0).when(remotingClient).invoke(anyString(), any(RemotingCommand.class), anyLong());

CompletableFuture<RemotingCommand> future = remotingClient.invoke("0.0.0.0", request, 1000);
RemotingCommand actual = future.get();
Expand All @@ -76,11 +73,9 @@ public void testRemotingSendRequestException() throws Exception {

RemotingCommand response = RemotingCommand.createResponseCommand(null);
response.setCode(ResponseCode.SUCCESS);
doAnswer(invocation -> {
InvokeCallback callback = invocation.getArgument(3);
callback.operationFail(new RemotingSendRequestException(null));
return null;
}).when(remotingClient).invokeAsync(anyString(), any(RemotingCommand.class), anyLong(), any(InvokeCallback.class));
CompletableFuture<RemotingCommand> future0 = new CompletableFuture<>();
future0.completeExceptionally(new RemotingSendRequestException(null));
doReturn(future0).when(remotingClient).invoke(anyString(), any(RemotingCommand.class), anyLong());

CompletableFuture<RemotingCommand> future = remotingClient.invoke("0.0.0.0", request, 1000);
Throwable thrown = catchThrowable(future::get);
Expand All @@ -93,11 +88,9 @@ public void testRemotingTimeoutException() throws Exception {

RemotingCommand response = RemotingCommand.createResponseCommand(null);
response.setCode(ResponseCode.SUCCESS);
doAnswer(invocation -> {
InvokeCallback callback = invocation.getArgument(3);
callback.operationFail(new RemotingTimeoutException(""));
return null;
}).when(remotingClient).invokeAsync(anyString(), any(RemotingCommand.class), anyLong(), any(InvokeCallback.class));
CompletableFuture<RemotingCommand> future0 = new CompletableFuture<>();
future0.completeExceptionally(new RemotingTimeoutException(""));
doReturn(future0).when(remotingClient).invoke(anyString(), any(RemotingCommand.class), anyLong());

CompletableFuture<RemotingCommand> future = remotingClient.invoke("0.0.0.0", request, 1000);
Throwable thrown = catchThrowable(future::get);
Expand All @@ -110,12 +103,6 @@ public void testRemotingException() throws Exception {

RemotingCommand response = RemotingCommand.createResponseCommand(null);
response.setCode(ResponseCode.SUCCESS);
doAnswer(invocation -> {
InvokeCallback callback = invocation.getArgument(3);
ResponseFuture responseFuture = new ResponseFuture(null, request.getOpaque(), 3 * 1000, null, null);
callback.operationFail(new RemotingException(null));
return null;
}).when(remotingClient).invokeAsync(anyString(), any(RemotingCommand.class), anyLong(), any(InvokeCallback.class));

CompletableFuture<RemotingCommand> future = remotingClient.invoke("0.0.0.0", request, 1000);
Throwable thrown = catchThrowable(future::get);
Expand Down

0 comments on commit 65412b0

Please sign in to comment.