Skip to content

Commit

Permalink
Fixed memory leak issue of ConnectionManager
Browse files Browse the repository at this point in the history
  • Loading branch information
sarutak committed Nov 17, 2014
1 parent 4b1c77c commit 786af91
Showing 1 changed file with 39 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
package org.apache.spark.network

import java.io.IOException
import java.lang.ref.WeakReference
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.net._
import java.util.{Timer, TimerTask}
import java.util.concurrent.atomic.AtomicInteger

import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
Expand All @@ -37,6 +37,8 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration._
import scala.language.postfixOps

import io.netty.util.{Timeout, TimerTask, HashedWheelTimer}

import org.apache.spark._
import org.apache.spark.util.{SystemClock, Utils}

Expand Down Expand Up @@ -68,7 +70,8 @@ private[spark] class ConnectionManager(
}

private val selector = SelectorProvider.provider.openSelector()
private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
private val ackTimeoutMonitor =
new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor"))

// default to 30 second timeout waiting for authentication
private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
Expand Down Expand Up @@ -105,7 +108,10 @@ private[spark] class ConnectionManager(
new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
with SynchronizedMap[ConnectionManagerId, SendingConnection]
private val messageStatuses = new HashMap[Int, MessageStatus]
// Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this
// map when messages are sent and are removed when acknowledgement messages are received or when
// acknowledgement timeouts expire
private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus]
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]

Expand Down Expand Up @@ -846,20 +852,41 @@ private[spark] class ConnectionManager(
: Future[Message] = {
val promise = Promise[Message]()

val timeoutTask = new TimerTask {
override def run(): Unit = {
// It's important that the TimerTask doesn't capture a reference to `message`, which can cause
// memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time
// at which they would originally be scheduled to run. Therefore, extract the message id
// from outside of the TimerTask closure (see SPARK-4393 for more context).
val messageId = message.id
// Keep a weak reference to the promise so that the completed promise may be garbage-collected
val promiseReference = new WeakReference(promise)
val timeoutTask: TimerTask = new TimerTask {
override def run(timeout: Timeout): Unit = {
messageStatuses.synchronized {
messageStatuses.remove(message.id).foreach ( s => {
promise.failure(
new IOException("sendMessageReliably failed because ack " +
s"was not received within $ackTimeout sec"))
})
messageStatuses.remove(messageId).foreach { s =>
val e = new IOException("sendMessageReliably failed because ack " +
s"was not received within $ackTimeout sec")
val p = promiseReference.get
if (p != null) {
// Attempt to fail the promise with the ack-timeout IOException
if (!p.tryFailure(e)) {
// If we reach here, then someone else has already signalled success or failure
// on this promise, so just log the timeout error:
logError("Ignore error because promise is completed", e)
}
} else {
// The WeakReference was empty, which should never happen because
// sendMessageReliably's caller should have a strong reference to promise.future.
logError("Promise was garbage collected; this should never happen!", e)
}
}
}
}
}

val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS)

val status = new MessageStatus(message, connectionManagerId, s => {
timeoutTask.cancel()
timeoutTaskHandle.cancel()
s.ackMessage match {
case None => // Indicates a failure where we either never sent or never got ACK'd
promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
Expand All @@ -876,7 +903,6 @@ private[spark] class ConnectionManager(
messageStatuses += ((message.id, status))
}

ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000)
sendMessage(connectionManagerId, message)
promise.future
}
Expand All @@ -886,7 +912,7 @@ private[spark] class ConnectionManager(
}

def stop() {
ackTimeoutMonitor.cancel()
ackTimeoutMonitor.stop()
selectorThread.interrupt()
selectorThread.join()
selector.close()
Expand Down

0 comments on commit 786af91

Please sign in to comment.