From 786af9161b869b224f9ef314ba4b915a563acfb4 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 17 Nov 2014 11:38:35 -0800 Subject: [PATCH] Fixed memory leak issue of ConnectionManager --- .../spark/network/ConnectionManager.scala | 52 ++++++++++++++----- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 578d806263006..6d58129babc88 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -18,11 +18,11 @@ package org.apache.spark.network import java.io.IOException +import java.lang.ref.WeakReference import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.net._ -import java.util.{Timer, TimerTask} import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} @@ -37,6 +37,8 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.concurrent.duration._ import scala.language.postfixOps +import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} + import org.apache.spark._ import org.apache.spark.util.{SystemClock, Utils} @@ -68,7 +70,8 @@ private[spark] class ConnectionManager( } private val selector = SelectorProvider.provider.openSelector() - private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) + private val ackTimeoutMonitor = + new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) // default to 30 second timeout waiting for authentication private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30) @@ -105,7 +108,10 @@ private[spark] class ConnectionManager( new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - private val messageStatuses = new HashMap[Int, MessageStatus] + // Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this + // map when messages are sent and are removed when acknowledgement messages are received or when + // acknowledgement timeouts expire + private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus] private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] private val registerRequests = new SynchronizedQueue[SendingConnection] @@ -846,20 +852,41 @@ private[spark] class ConnectionManager( : Future[Message] = { val promise = Promise[Message]() - val timeoutTask = new TimerTask { - override def run(): Unit = { + // It's important that the TimerTask doesn't capture a reference to `message`, which can cause + // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time + // at which they would originally be scheduled to run. Therefore, extract the message id + // from outside of the TimerTask closure (see SPARK-4393 for more context). + val messageId = message.id + // Keep a weak reference to the promise so that the completed promise may be garbage-collected + val promiseReference = new WeakReference(promise) + val timeoutTask: TimerTask = new TimerTask { + override def run(timeout: Timeout): Unit = { messageStatuses.synchronized { - messageStatuses.remove(message.id).foreach ( s => { - promise.failure( - new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec")) - }) + messageStatuses.remove(messageId).foreach { s => + val e = new IOException("sendMessageReliably failed because ack " + + s"was not received within $ackTimeout sec") + val p = promiseReference.get + if (p != null) { + // Attempt to fail the promise with a Timeout exception + if (!p.tryFailure(e)) { + // If we reach here, then someone else has already signalled success or failure + // on this promise, so log a warning: + logError("Ignore error because promise is completed", e) + } + } else { + // The WeakReference was empty, which should never happen because + // sendMessageReliably's caller should have a strong reference to promise.future; + logError("Promise was garbage collected; this should never happen!", e) + } + } } } } + val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS) + val status = new MessageStatus(message, connectionManagerId, s => { - timeoutTask.cancel() + timeoutTaskHandle.cancel() s.ackMessage match { case None => // Indicates a failure where we either never sent or never got ACK'd promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) @@ -876,7 +903,6 @@ private[spark] class ConnectionManager( messageStatuses += ((message.id, status)) } - ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000) sendMessage(connectionManagerId, message) promise.future } @@ -886,7 +912,7 @@ private[spark] class ConnectionManager( } def stop() { - ackTimeoutMonitor.cancel() + ackTimeoutMonitor.stop() selectorThread.interrupt() selectorThread.join() selector.close()