Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-4393] Fix memory leak in ConnectionManager ACK timeout TimerTasks; use HashedWheelTimer #3259

Closed
Closed
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,21 @@
package org.apache.spark.network.nio

import java.io.IOException
import java.lang.ref.WeakReference
import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
import java.util.{Timer, TimerTask}

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue}
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps

import com.google.common.base.Charsets.UTF_8
import io.netty.util.{Timeout, TimerTask, HashedWheelTimer}

import org.apache.spark._
import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
Expand Down Expand Up @@ -77,7 +78,8 @@ private[nio] class ConnectionManager(
}

private val selector = SelectorProvider.provider.openSelector()
private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
private val ackTimeoutMonitor =
new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor"))

private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)

Expand Down Expand Up @@ -139,7 +141,10 @@ private[nio] class ConnectionManager(
new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
with SynchronizedMap[ConnectionManagerId, SendingConnection]
private val messageStatuses = new HashMap[Int, MessageStatus]
// Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this
// map when messages are sent and are removed when acknowledgement messages are received or when
// acknowledgement timeouts expire
private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus]
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]

Expand Down Expand Up @@ -899,22 +904,41 @@ private[nio] class ConnectionManager(
: Future[Message] = {
val promise = Promise[Message]()

val timeoutTask = new TimerTask {
override def run(): Unit = {
// It's important that the TimerTask doesn't capture a reference to `message`, which can cause
// memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time
// at which they would originally be scheduled to run. Therefore, extract the message id
// from outside of the TimerTask closure (see SPARK-4393 for more context).
val messageId = message.id
// Keep a weak reference to the promise so that the completed promise may be garbage-collected
val promiseReference = new WeakReference(promise)
val timeoutTask: TimerTask = new TimerTask {
override def run(timeout: Timeout): Unit = {
messageStatuses.synchronized {
messageStatuses.remove(message.id).foreach ( s => {
messageStatuses.remove(messageId).foreach ( s => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick here - u can remove one layer of parenthesis/brackets

messageStatuses.remove(messageId).foreach { s =>

}

val e = new IOException("sendMessageReliably failed because ack " +
s"was not received within $ackTimeout sec")
if (!promise.tryFailure(e)) {
logWarning("Ignore error because promise is completed", e)
Option(promiseReference.get) match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not

val p = promiseReference.get
if (p == null) {
  ...
} else {
  ...
}

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was actually on the fence about this, but your comment tips me towards the == null camp since it removes a level of nesting / indentation.

case Some(p) =>
// Attempt to fail the promise with a Timeout exception
if (!p.tryFailure(e)) {
// If we reach here, then someone else has already signalled success or failure
// on this promise, so log a warning:
logError("Ignore error because promise is completed", e)
}
case None =>
// The WeakReference was empty, which should never happen because
// sendMessageReliably's caller should have a strong reference to promise.future;
logError("Promise was garbage collected; this should never happen!", e)
}
})
}
}
}

val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS)

val status = new MessageStatus(message, connectionManagerId, s => {
timeoutTask.cancel()
timeoutTaskHandle.cancel()
s match {
case scala.util.Failure(e) =>
// Indicates a failure where we either never sent or never got ACK'd
Expand Down Expand Up @@ -943,7 +967,6 @@ private[nio] class ConnectionManager(
messageStatuses += ((message.id, status))
}

ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000)
sendMessage(connectionManagerId, message)
promise.future
}
Expand All @@ -953,7 +976,7 @@ private[nio] class ConnectionManager(
}

def stop() {
ackTimeoutMonitor.cancel()
ackTimeoutMonitor.stop()
selectorThread.interrupt()
selectorThread.join()
selector.close()
Expand Down