diff --git a/CHANGELOG.md b/CHANGELOG.md index 68dd1c9a2e..af063db3bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and what APIs have changed, if applicable. - Strictly enforce Gradle version compatibility in the `pegasus` Gradle plugin. - Minimum required Gradle version is now `1.0` (effectively backward-compatible). - Minimum suggested Gradle version is now `5.2.1` +- Fix TimingKey Memory Leak ## [29.18.2] - 2021-04-28 - Fix bug in generated fluent client APIs when typerefs are used as association key params diff --git a/d2/src/main/java/com/linkedin/d2/balancer/clients/DynamicClient.java b/d2/src/main/java/com/linkedin/d2/balancer/clients/DynamicClient.java index 8b7864808a..bc9467650e 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/clients/DynamicClient.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/clients/DynamicClient.java @@ -157,6 +157,7 @@ public void shutdown(final Callback callback) callback.onSuccess(None.none()); }); + TimingKey.unregisterKey(TIMING_KEY); } @Override diff --git a/r2-core/src/main/java/com/linkedin/r2/filter/FilterChain.java b/r2-core/src/main/java/com/linkedin/r2/filter/FilterChain.java index bd6273fb89..f2e9f142c3 100644 --- a/r2-core/src/main/java/com/linkedin/r2/filter/FilterChain.java +++ b/r2-core/src/main/java/com/linkedin/r2/filter/FilterChain.java @@ -23,7 +23,7 @@ import com.linkedin.r2.message.rest.RestResponse; import com.linkedin.r2.message.stream.StreamRequest; import com.linkedin.r2.message.stream.StreamResponse; - +import java.util.List; import java.util.Map; /** @@ -189,4 +189,14 @@ void onStreamResponse(StreamResponse res, void onStreamError(Exception ex, RequestContext requestContext, Map wireAttrs); + + /** + * Returns a copy of a list of RestFilters + */ + List getRestFilters(); + + /** + * Returns a copy of a list of StreamFilters + */ + List getStreamFilters(); } diff --git a/r2-core/src/main/java/com/linkedin/r2/filter/FilterChainImpl.java b/r2-core/src/main/java/com/linkedin/r2/filter/FilterChainImpl.java index dee5ff2ced..725f3193f7 100644 --- a/r2-core/src/main/java/com/linkedin/r2/filter/FilterChainImpl.java +++ b/r2-core/src/main/java/com/linkedin/r2/filter/FilterChainImpl.java @@ -78,6 +78,16 @@ public FilterChain addLast(StreamFilter filter) return new FilterChainImpl(_restFilters, doAddLast(_streamFilters, decorateStreamFilter(filter))); } + @Override + public List getRestFilters() { + return new ArrayList(_restFilters); + } + + @Override + public List getStreamFilters() { + return new ArrayList(_streamFilters); + } + private RestFilter decorateRestFilter(RestFilter filter) { return new TimedRestFilter(filter); diff --git a/r2-core/src/main/java/com/linkedin/r2/filter/TimedRestFilter.java b/r2-core/src/main/java/com/linkedin/r2/filter/TimedRestFilter.java index 0261b5a0f8..ee7cecfe1d 100644 --- a/r2-core/src/main/java/com/linkedin/r2/filter/TimedRestFilter.java +++ b/r2-core/src/main/java/com/linkedin/r2/filter/TimedRestFilter.java @@ -23,6 +23,8 @@ import com.linkedin.r2.message.rest.RestRequest; import com.linkedin.r2.message.rest.RestResponse; import com.linkedin.r2.message.timing.TimingImportance; +import java.util.Arrays; +import java.util.List; import java.util.Map; @@ -31,7 +33,7 @@ * * @author Xialin Zhu */ -/* package private */ class TimedRestFilter implements RestFilter +public class TimedRestFilter implements RestFilter { protected static final String ON_REQUEST_SUFFIX = "onRequest"; protected static final String ON_RESPONSE_SUFFIX = "onResponse"; @@ -41,6 +43,7 @@ private final TimingKey _onRequestTimingKey; private final TimingKey _onResponseTimingKey; private final TimingKey _onErrorTimingKey; + private boolean _shared; /** * Registers {@link TimingKey}s for {@link com.linkedin.r2.message.timing.TimingNameConstants#TIMED_REST_FILTER}. @@ -61,6 +64,7 @@ public TimedRestFilter(RestFilter restFilter) _restFilter.getClass().getSimpleName(), TimingImportance.LOW); _onErrorTimingKey = TimingKey.registerNewKey(timingKeyPrefix + ON_ERROR_SUFFIX + timingKeyPostfix, _restFilter.getClass().getSimpleName(), TimingImportance.LOW); + _shared = false; } @Override @@ -91,4 +95,16 @@ public void onRestError(Throwable ex, TimingContextUtil.markTiming(requestContext, _onErrorTimingKey); _restFilter.onRestError(ex, requestContext, wireAttrs, new TimedNextFilter<>(_onErrorTimingKey, nextFilter)); } + + public void setShared() { + _shared = true; + } + + public void onShutdown() { + if (!_shared) { + TimingKey.unregisterKey(_onErrorTimingKey); + TimingKey.unregisterKey(_onRequestTimingKey); + TimingKey.unregisterKey(_onResponseTimingKey); + } + } } diff --git a/r2-core/src/main/java/com/linkedin/r2/filter/TimedStreamFilter.java b/r2-core/src/main/java/com/linkedin/r2/filter/TimedStreamFilter.java index c1759decac..ec6a29112b 100644 --- a/r2-core/src/main/java/com/linkedin/r2/filter/TimedStreamFilter.java +++ b/r2-core/src/main/java/com/linkedin/r2/filter/TimedStreamFilter.java @@ -34,12 +34,13 @@ * * @author Xialin Zhu */ -/* package private */ class TimedStreamFilter implements StreamFilter +public class TimedStreamFilter implements StreamFilter { private final StreamFilter _streamFilter; private final TimingKey _onRequestTimingKey; private final TimingKey _onResponseTimingKey; private final TimingKey _onErrorTimingKey; + private boolean _shared; /** * Registers {@link TimingKey}s for {@link com.linkedin.r2.message.timing.TimingNameConstants#TIMED_STREAM_FILTER}. @@ -60,6 +61,7 @@ public TimedStreamFilter(StreamFilter streamFilter) filterClassName, TimingImportance.LOW); _onErrorTimingKey = TimingKey.registerNewKey(timingKeyPrefix + ON_ERROR_SUFFIX + timingKeyPostfix, filterClassName, TimingImportance.LOW); + _shared = false; } @Override @@ -91,4 +93,16 @@ public void onStreamError(Throwable ex, TimingContextUtil.markTiming(requestContext, _onErrorTimingKey); _streamFilter.onStreamError(ex, requestContext, wireAttrs, new TimedNextFilter<>(_onErrorTimingKey, nextFilter)); } + + public void setShared() { + _shared = true; + } + + public void onShutdown() { + if (!_shared) { + TimingKey.unregisterKey(_onErrorTimingKey); + TimingKey.unregisterKey(_onRequestTimingKey); + TimingKey.unregisterKey(_onResponseTimingKey); + } + } } diff --git a/r2-core/src/main/java/com/linkedin/r2/filter/transport/FilterChainClient.java b/r2-core/src/main/java/com/linkedin/r2/filter/transport/FilterChainClient.java index 11e4e15840..134200be56 100644 --- a/r2-core/src/main/java/com/linkedin/r2/filter/transport/FilterChainClient.java +++ b/r2-core/src/main/java/com/linkedin/r2/filter/transport/FilterChainClient.java @@ -17,10 +17,13 @@ /* $Id$ */ package com.linkedin.r2.filter.transport; - import com.linkedin.common.callback.Callback; import com.linkedin.common.util.None; import com.linkedin.r2.filter.FilterChain; +import com.linkedin.r2.filter.TimedRestFilter; +import com.linkedin.r2.filter.TimedStreamFilter; +import com.linkedin.r2.filter.message.rest.RestFilter; +import com.linkedin.r2.filter.message.stream.StreamFilter; import com.linkedin.r2.message.RequestContext; import com.linkedin.r2.message.Response; import com.linkedin.r2.message.rest.RestRequest; @@ -29,12 +32,15 @@ import com.linkedin.r2.message.stream.StreamResponse; import com.linkedin.r2.message.timing.FrameworkTimingKeys; import com.linkedin.r2.message.timing.TimingContextUtil; +import com.linkedin.r2.message.timing.TimingKey; import com.linkedin.r2.transport.common.bridge.client.TransportClient; import com.linkedin.r2.transport.common.bridge.common.TransportCallback; - import com.linkedin.r2.transport.common.bridge.common.TransportResponse; +import java.util.Collection; +import java.util.List; import java.util.Map; + /** * {@link TransportClient} adapter which composes a {@link TransportClient} * and a {@link FilterChain}. @@ -94,6 +100,12 @@ public void streamRequest(StreamRequest request, public void shutdown(Callback callback) { _client.shutdown(callback); + + _filters.getStreamFilters().stream().filter(TimedStreamFilter.class::isInstance) + .map(TimedStreamFilter.class::cast).forEach(TimedStreamFilter::onShutdown); + + _filters.getRestFilters().stream().filter(TimedRestFilter.class::isInstance) + .map(TimedRestFilter.class::cast).forEach(TimedRestFilter::onShutdown); } /** diff --git a/r2-core/src/main/java/com/linkedin/r2/message/timing/TimingKey.java b/r2-core/src/main/java/com/linkedin/r2/message/timing/TimingKey.java index 8d551af59f..7012121e83 100644 --- a/r2-core/src/main/java/com/linkedin/r2/message/timing/TimingKey.java +++ b/r2-core/src/main/java/com/linkedin/r2/message/timing/TimingKey.java @@ -17,9 +17,12 @@ package com.linkedin.r2.message.timing; import java.util.Map; +import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import com.linkedin.r2.message.RequestContext; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** @@ -32,6 +35,7 @@ public class TimingKey { private static final Map _pool = new ConcurrentHashMap<>(); + private static final ExecutorService _unregisterExecutor = Executors.newFixedThreadPool(1); private final String _name; private final String _type; @@ -130,4 +134,26 @@ public static TimingKey registerNewKey(String uniqueName, String type, TimingImp { return registerNewKey(new TimingKey(uniqueName, type, timingImportance)); } + + /** + * Unregister a TimingKey to reclaim the memory + * + */ + public static void unregisterKey(TimingKey key) + { + _unregisterExecutor.submit(new Callable() { + public Void call() throws Exception { + _pool.remove(key.getName()); + return null; + } + }); + } + + /** + * Return how many registered keys, for testing purpose. + */ + public static int getCount() { + return _pool.size(); + } + } diff --git a/r2-netty/src/main/java/com/linkedin/r2/netty/client/HttpNettyClient.java b/r2-netty/src/main/java/com/linkedin/r2/netty/client/HttpNettyClient.java index d458a2f1a1..e9d0ac9fac 100644 --- a/r2-netty/src/main/java/com/linkedin/r2/netty/client/HttpNettyClient.java +++ b/r2-netty/src/main/java/com/linkedin/r2/netty/client/HttpNettyClient.java @@ -246,6 +246,7 @@ public void onSuccess(None result) { callback.onError(new IllegalStateException("Shutdown has already been requested.")); } + TimingKey.unregisterKey(TIMING_KEY); } private void sendStreamRequestAsRestRequest(StreamRequest request, RequestContext requestContext, diff --git a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/HttpClientFactory.java b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/HttpClientFactory.java index 2be59979f3..6b1de01551 100644 --- a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/HttpClientFactory.java +++ b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/HttpClientFactory.java @@ -26,6 +26,8 @@ import com.linkedin.r2.filter.CompressionConfig; import com.linkedin.r2.filter.FilterChain; import com.linkedin.r2.filter.FilterChains; +import com.linkedin.r2.filter.TimedRestFilter; +import com.linkedin.r2.filter.TimedStreamFilter; import com.linkedin.r2.filter.compression.ClientCompressionFilter; import com.linkedin.r2.filter.compression.ClientCompressionHelper; import com.linkedin.r2.filter.compression.ClientStreamCompressionFilter; @@ -685,6 +687,11 @@ private HttpClientFactory(FilterChain filters, { _channelPoolManagerFactory = new ConnectionSharingChannelPoolManagerFactory(_channelPoolManagerFactory); } + + _filters.getStreamFilters().stream().filter(TimedStreamFilter.class::isInstance) + .map(TimedStreamFilter.class::cast).forEach(TimedStreamFilter::setShared); + _filters.getRestFilters().stream().filter(TimedRestFilter.class::isInstance) + .map(TimedRestFilter.class::cast).forEach(TimedRestFilter::setShared); } public static class Builder @@ -958,7 +965,6 @@ public TransportClient getClient(Map properties) properties = new HashMap(properties); sslContext = coerceAndRemoveFromMap(HTTP_SSL_CONTEXT, properties, SSLContext.class); sslParameters = coerceAndRemoveFromMap(HTTP_SSL_PARAMS, properties, SSLParameters.class); - return getClient(properties, sslContext, sslParameters); } diff --git a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/common/AbstractNettyClient.java b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/common/AbstractNettyClient.java index d08dbc1e8a..4db1b81ee9 100644 --- a/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/common/AbstractNettyClient.java +++ b/r2-netty/src/main/java/com/linkedin/r2/transport/http/client/common/AbstractNettyClient.java @@ -284,6 +284,7 @@ public void onSuccess(None result) _shutdownTimeout); _jmxManager.onProviderShutdown(_channelPoolManager); _jmxManager.onProviderShutdown(_sslChannelPoolManager); + TimingKey.unregisterKey(TIMING_KEY); } else { diff --git a/r2-netty/src/test/java/com/linkedin/r2/transport/http/client/TestHttpClientFactory.java b/r2-netty/src/test/java/com/linkedin/r2/transport/http/client/TestHttpClientFactory.java index 136b599c9d..9ca82c882b 100644 --- a/r2-netty/src/test/java/com/linkedin/r2/transport/http/client/TestHttpClientFactory.java +++ b/r2-netty/src/test/java/com/linkedin/r2/transport/http/client/TestHttpClientFactory.java @@ -27,6 +27,7 @@ import com.linkedin.r2.message.rest.RestRequest; import com.linkedin.r2.message.rest.RestRequestBuilder; import com.linkedin.r2.message.rest.RestResponse; +import com.linkedin.r2.message.timing.TimingKey; import com.linkedin.r2.testutils.server.HttpServerBuilder; import com.linkedin.r2.transport.common.Client; import com.linkedin.r2.transport.common.bridge.client.TransportClient; @@ -93,13 +94,17 @@ public void testSuccessfulRequest(boolean restOverStream, String protocolVersion { server.start(); List clients = new ArrayList<>(); + + int savedTimingKeyCount = TimingKey.getCount(); for (int i = 0; i < 100; i++) { HashMap properties = new HashMap<>(); properties.put(HttpClientFactory.HTTP_PROTOCOL_VERSION, protocolVersion); clients.add(new TransportClientAdapter(factory.getClient(properties), restOverStream)); } - + int addedTimingKeyCount = TimingKey.getCount() - savedTimingKeyCount; + // In current implementation, one client can have around 30 TimingKeys by default. + Assert.assertTrue(addedTimingKeyCount >= 30 * clients.size()); for (Client c : clients) { RestRequest r = new RestRequestBuilder(new URI(URI)).build(); @@ -107,6 +112,7 @@ public void testSuccessfulRequest(boolean restOverStream, String protocolVersion } Assert.assertEquals(httpServerStatsProvider.requestCount(), expectedRequests); + savedTimingKeyCount = TimingKey.getCount(); for (Client c : clients) { FutureCallback callback = new FutureCallback<>(); @@ -117,6 +123,8 @@ public void testSuccessfulRequest(boolean restOverStream, String protocolVersion FutureCallback factoryShutdown = new FutureCallback<>(); factory.shutdown(factoryShutdown); factoryShutdown.get(30, TimeUnit.SECONDS); + int removedTimingKeyCount = savedTimingKeyCount - TimingKey.getCount(); + Assert.assertEquals(addedTimingKeyCount, removedTimingKeyCount); } finally {