diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index f198aa8564a54..df4b085d2251e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -18,13 +18,13 @@ package org.apache.spark.network.nio import java.io.IOException +import java.lang.ref.WeakReference import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} -import java.util.{Timer, TimerTask} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} import scala.concurrent.duration._ @@ -32,6 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 +import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} import org.apache.spark._ import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} @@ -77,7 +78,8 @@ private[nio] class ConnectionManager( } private val selector = SelectorProvider.provider.openSelector() - private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) + private val ackTimeoutMonitor = + new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) @@ -139,7 +141,10 @@ private[nio] class ConnectionManager( new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - private val messageStatuses = new HashMap[Int, MessageStatus] + // Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this + // map when messages are sent and are removed when acknowledgement messages are received or when + // acknowledgement timeouts expire + private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus] private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] private val registerRequests = new SynchronizedQueue[SendingConnection] @@ -899,22 +904,41 @@ private[nio] class ConnectionManager( : Future[Message] = { val promise = Promise[Message]() - val timeoutTask = new TimerTask { - override def run(): Unit = { + // It's important that the TimerTask doesn't capture a reference to `message`, which can cause + // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time + // at which they would originally be scheduled to run. Therefore, extract the message id + // from outside of the TimerTask closure (see SPARK-4393 for more context). + val messageId = message.id + // Keep a weak reference to the promise so that the completed promise may be garbage-collected + val promiseReference = new WeakReference(promise) + val timeoutTask: TimerTask = new TimerTask { + override def run(timeout: Timeout): Unit = { messageStatuses.synchronized { - messageStatuses.remove(message.id).foreach ( s => { + messageStatuses.remove(messageId).foreach { s => val e = new IOException("sendMessageReliably failed because ack " + s"was not received within $ackTimeout sec") - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) + val p = promiseReference.get + if (p != null) { + // Attempt to fail the promise with a Timeout exception + if (!p.tryFailure(e)) { + // If we reach here, then someone else has already signalled success or failure + // on this promise, so log a warning: + logError("Ignore error because promise is completed", e) + } + } else { + // The WeakReference was empty, which should never happen because + // sendMessageReliably's caller should have a strong reference to promise.future; + logError("Promise was garbage collected; this should never happen!", e) } - }) + } } } } + val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS) + val status = new MessageStatus(message, connectionManagerId, s => { - timeoutTask.cancel() + timeoutTaskHandle.cancel() s match { case scala.util.Failure(e) => // Indicates a failure where we either never sent or never got ACK'd @@ -943,7 +967,6 @@ private[nio] class ConnectionManager( messageStatuses += ((message.id, status)) } - ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000) sendMessage(connectionManagerId, message) promise.future } @@ -953,7 +976,7 @@ private[nio] class ConnectionManager( } def stop() { - ackTimeoutMonitor.cancel() + ackTimeoutMonitor.stop() selectorThread.interrupt() selectorThread.join() selector.close()