Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RoundRobinLoadBalancer fix connection leak during graceful closure #2450

Merged
merged 5 commits into from
Dec 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions servicetalk-loadbalancer/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies {
testImplementation project(":servicetalk-concurrent-test-internal")
testImplementation project(":servicetalk-test-resources")
testImplementation "org.junit.jupiter:junit-jupiter-api"
testImplementation "org.junit.jupiter:junit-jupiter-params"
testImplementation "org.apache.logging.log4j:log4j-core"
testImplementation "org.hamcrest:hamcrest:$hamcrestVersion"
testImplementation "org.mockito:mockito-core:$mockitoCoreVersion"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,19 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map.Entry;
import java.util.Spliterator;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.UnaryOperator;
import java.util.stream.Stream;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -91,7 +97,6 @@ final class RoundRobinLoadBalancer<ResolvedAddress, C extends LoadBalancedConnec
implements LoadBalancer<C> {

private static final Logger LOGGER = LoggerFactory.getLogger(RoundRobinLoadBalancer.class);
private static final List<?> CLOSED_LIST = new ArrayList<>(0);
private static final Object[] EMPTY_ARRAY = new Object[0];

@SuppressWarnings("rawtypes")
Expand Down Expand Up @@ -176,7 +181,7 @@ public void onNext(final Collection<? extends ServiceDiscovererEvent<ResolvedAdd
@SuppressWarnings("unchecked")
final List<Host<ResolvedAddress, C>> usedAddresses =
usedHostsUpdater.updateAndGet(RoundRobinLoadBalancer.this, oldHosts -> {
if (oldHosts == CLOSED_LIST) {
if (isClosedList(oldHosts)) {
return oldHosts;
}
final ResolvedAddress addr = requireNonNull(event.address());
Expand Down Expand Up @@ -235,7 +240,7 @@ private List<Host<ResolvedAddress, C>> markHostAsExpired(

private Host<ResolvedAddress, C> createHost(ResolvedAddress addr) {
Host<ResolvedAddress, C> host = new Host<>(targetResource, addr, healthCheckConfig);
host.onClosing().afterFinally(() ->
host.onClose().afterFinally(() ->
usedHostsUpdater.updateAndGet(RoundRobinLoadBalancer.this, previousHosts -> {
@SuppressWarnings("unchecked")
List<Host<ResolvedAddress, C>> previousHostsTyped =
Expand Down Expand Up @@ -311,12 +316,24 @@ public void onComplete() {
}
});
asyncCloseable = toAsyncCloseable(graceful -> {
@SuppressWarnings("unchecked")
List<Host<ResolvedAddress, C>> currentList = usedHostsUpdater.getAndSet(this, CLOSED_LIST);
discoveryCancellable.cancel();
eventStreamProcessor.onComplete();
CompositeCloseable cc = newCompositeCloseable().appendAll(currentList).appendAll(connectionFactory);
return graceful ? cc.closeAsyncGracefully() : cc.closeAsync();
final CompositeCloseable compositeCloseable;
for (;;) {
List<Host<ResolvedAddress, C>> currentList = usedHosts;
if (isClosedList(currentList) ||
usedHostsUpdater.compareAndSet(this, currentList, new ClosedList<>(currentList))) {
compositeCloseable = newCompositeCloseable().appendAll(currentList).appendAll(connectionFactory);
break;
}
}
return (graceful ? compositeCloseable.closeAsyncGracefully() : compositeCloseable.closeAsync())
.beforeOnError(t -> {
if (!graceful) {
usedHosts = new ClosedList<>(emptyList());
}
})
.beforeOnComplete(() -> usedHosts = new ClosedList<>(emptyList()));
});
}

Expand Down Expand Up @@ -345,7 +362,7 @@ public String toString() {
private Single<C> selectConnection0(final Predicate<C> selector, @Nullable final ContextMap context) {
final List<Host<ResolvedAddress, C>> usedHosts = this.usedHosts;
if (usedHosts.isEmpty()) {
return usedHosts == CLOSED_LIST ? failedLBClosed(targetResource) :
return isClosedList(usedHosts) ? failedLBClosed(targetResource) :
// This is the case when SD has emitted some items but none of the hosts are available.
failed(StacklessNoAvailableHostException.newInstance(
"No hosts are available to connect for " + targetResource + ".",
Expand Down Expand Up @@ -433,7 +450,7 @@ private Single<C> selectConnection0(final Predicate<C> selector, @Nullable final
if (host.addConnection(newCnx)) {
return succeeded(newCnx);
}
return newCnx.closeAsync().concat(this.usedHosts == CLOSED_LIST ? failedLBClosed(targetResource) :
return newCnx.closeAsync().concat(isClosedList(this.usedHosts) ? failedLBClosed(targetResource) :
failed(StacklessConnectionRejectedException.newInstance(
"Failed to add newly created connection " + newCnx + " for " + targetResource
+ " for " + host, RoundRobinLoadBalancer.class, "selectConnection0(...)")));
Expand Down Expand Up @@ -528,7 +545,7 @@ boolean markActiveIfNotClosed() {
}

void markClosed() {
final ConnState oldState = connStateUpdater.getAndSet(this, CLOSED_CONN_STATE);
final ConnState oldState = closeConnState();
final Object[] toRemove = oldState.connections;
cancelIfHealthCheck(oldState.state);
LOGGER.debug("Load balancer for {}: closing {} connection(s) gracefully to the closed address: {}.",
Expand All @@ -540,6 +557,19 @@ void markClosed() {
}
}

private ConnState closeConnState() {
for (;;) {
// We need to keep the oldState.connections around even if we are closed because the user may do
// closeGracefully with a timeout, which fails, and then force close. If we discard connections when
// closeGracefully is started we may leak connections.
final ConnState oldState = connState;
if (oldState.state == State.CLOSED || connStateUpdater.compareAndSet(this, oldState,
new ConnState(oldState.connections, State.CLOSED))) {
Scottmitch marked this conversation as resolved.
Show resolved Hide resolved
return oldState;
}
}
}

void markExpired() {
for (;;) {
ConnState oldState = connStateUpdater.get(this);
Expand Down Expand Up @@ -625,11 +655,11 @@ boolean isActiveAndHealthy() {
boolean addConnection(C connection) {
int addAttempt = 0;
for (;;) {
++addAttempt;
final ConnState previous = connStateUpdater.get(this);
if (previous == CLOSED_CONN_STATE) {
if (previous.state == State.CLOSED) {
return false;
}
++addAttempt;

final Object[] existing = previous.connections;
// Brute force iteration to avoid duplicates. If connections grow larger and faster lookup is required
Expand All @@ -654,14 +684,14 @@ previous, new ConnState(newList, newState))) {
LOGGER.trace("Load balancer for {}: added a new connection {} to {} after {} attempt(s).",
targetResource, connection, this, addAttempt);
// Instrument the new connection so we prune it on close
connection.onClosing().beforeFinally(() -> {
connection.onClose().beforeFinally(() -> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this is preferred for removing connections from the selection process, we still need to manage the lifecycle of the connections. In the future we can use connection.onClosing to disable selection eligibility for the connection (by adding extra state or another collection), but this PR prioritizes correctness and we can follow up to optimize.

int removeAttempt = 0;
for (;;) {
++removeAttempt;
final ConnState currentConnState = this.connState;
if (currentConnState == CLOSED_CONN_STATE) {
if (currentConnState.state == State.CLOSED) {
break;
}
++removeAttempt;
int i = 0;
final Object[] connections = currentConnState.connections;
for (; i < connections.length; ++i) {
Expand Down Expand Up @@ -735,7 +765,7 @@ public Completable onClosing() {
@SuppressWarnings("unchecked")
private Completable doClose(final Function<? super C, Completable> closeFunction) {
return Completable.defer(() -> {
final ConnState oldState = connStateUpdater.getAndSet(this, CLOSED_CONN_STATE);
final ConnState oldState = closeConnState();
cancelIfHealthCheck(oldState.state);
final Object[] connections = oldState.connections;
return (connections.length == 0 ? completed() :
Expand Down Expand Up @@ -905,4 +935,166 @@ public static StacklessConnectionRejectedException newInstance(String message, C
return ThrowableUtils.unknownStackTrace(new StacklessConnectionRejectedException(message), clazz, method);
}
}

/**
 * Tells whether {@code list} is a {@link ClosedList}, the wrapper type installed in place of the host list
 * when the load balancer is closed.
 *
 * @param list the list to inspect; must not be {@code null}.
 * @return {@code true} if the runtime class of {@code list} is exactly {@link ClosedList}.
 */
private static boolean isClosedList(List<?> list) {
    final Class<?> runtimeType = list.getClass();
    return ClosedList.class.equals(runtimeType);
}

private static final class ClosedList<T> implements List<T> {
private final List<T> delegate;

private ClosedList(final List<T> delegate) {
this.delegate = requireNonNull(delegate);
}

@Override
public int size() {
return delegate.size();
}

@Override
public boolean isEmpty() {
return delegate.isEmpty();
}

@Override
public boolean contains(final Object o) {
return delegate.contains(o);
}

@Override
public Iterator<T> iterator() {
return delegate.iterator();
}

@Override
public void forEach(final Consumer<? super T> action) {
delegate.forEach(action);
}

@Override
public Object[] toArray() {
return delegate.toArray();
}

@Override
public <T1> T1[] toArray(final T1[] a) {
idelpivnitskiy marked this conversation as resolved.
Show resolved Hide resolved
return delegate.toArray(a);
}

@Override
public boolean add(final T t) {
return delegate.add(t);
}

@Override
public boolean remove(final Object o) {
return delegate.remove(o);
}

@Override
public boolean containsAll(final Collection<?> c) {
return delegate.containsAll(c);
}

@Override
public boolean addAll(final Collection<? extends T> c) {
return delegate.addAll(c);
}

@Override
public boolean addAll(final int index, final Collection<? extends T> c) {
return delegate.addAll(c);
}

@Override
public boolean removeAll(final Collection<?> c) {
return delegate.removeAll(c);
}

@Override
public boolean removeIf(final Predicate<? super T> filter) {
return delegate.removeIf(filter);
}

@Override
public boolean retainAll(final Collection<?> c) {
return delegate.retainAll(c);
}

@Override
public void replaceAll(final UnaryOperator<T> operator) {
delegate.replaceAll(operator);
}

@Override
public void sort(final Comparator<? super T> c) {
delegate.sort(c);
}

@Override
public void clear() {
delegate.clear();
}

@Override
public T get(final int index) {
return delegate.get(index);
}

@Override
public T set(final int index, final T element) {
return delegate.set(index, element);
}

@Override
public void add(final int index, final T element) {
delegate.add(index, element);
}

@Override
public T remove(final int index) {
return delegate.remove(index);
}

@Override
public int indexOf(final Object o) {
return delegate.indexOf(o);
}

@Override
public int lastIndexOf(final Object o) {
return delegate.lastIndexOf(o);
}

@Override
public ListIterator<T> listIterator() {
return delegate.listIterator();
}

@Override
public ListIterator<T> listIterator(final int index) {
return delegate.listIterator(index);
}

@Override
public List<T> subList(final int fromIndex, final int toIndex) {
return new ClosedList<>(delegate.subList(fromIndex, toIndex));
}

@Override
public Spliterator<T> spliterator() {
return delegate.spliterator();
}

@Override
public Stream<T> stream() {
return delegate.stream();
}

@Override
public Stream<T> parallelStream() {
return delegate.parallelStream();
}
}
}
Loading