[SPARK-18751][Core]Fix deadlock when SparkContext.stop is called in Utils.tryOrStopSparkContext #16178

Closed · wants to merge 3 commits into from

Changes from all commits
35 changes: 20 additions & 15 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1760,25 +1760,30 @@ class SparkContext(config: SparkConf) extends Logging {
   def listJars(): Seq[String] = addedJars.keySet.toSeq
 
   /**
-   * Shut down the SparkContext.
+   * When stopping SparkContext inside Spark components, it's easy to cause dead-lock since Spark
+   * may wait for some internal threads to finish. It's better to use this method to stop
+   * SparkContext instead.
    */
-  def stop(): Unit = {
-    if (env.rpcEnv.isInRPCThread) {
-      // `stop` will block until all RPC threads exit, so we cannot call stop inside a RPC thread.
-      // We should launch a new thread to call `stop` to avoid dead-lock.
-      new Thread("stop-spark-context") {
-        setDaemon(true)
-
-        override def run(): Unit = {
-          _stop()
-        }
-      }.start()
-    } else {
-      _stop()
-    }
-  }
-
-  private def _stop() {
+  private[spark] def stopInNewThread(): Unit = {
+    new Thread("stop-spark-context") {
+      setDaemon(true)
+
+      override def run(): Unit = {
+        try {
+          SparkContext.this.stop()
+        } catch {
+          case e: Throwable =>
+            logError(e.getMessage, e)
+            throw e
+        }
+      }
+    }.start()
+  }
+
+  /**
+   * Shut down the SparkContext.
+   */
+  def stop(): Unit = {
     if (LiveListenerBus.withinListenerThread.value) {
       throw new SparkException(
         s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}")
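For readers outside the Spark codebase, the sketch below illustrates the hazard this hunk addresses and the stop-in-a-new-thread pattern it adopts. The `EventLoop` class is hypothetical (not Spark code): its blocking `stop()` joins its own worker thread, so a task running on that thread must never call `stop()` directly, while `stopInNewThread()` is safe from any thread.

```scala
// Hypothetical illustration (not Spark's API): a blocking stop() that joins the
// worker thread, and a stopInNewThread() that defers it to a fresh daemon thread.
import java.util.concurrent.LinkedBlockingQueue

class EventLoop {
  private val queue = new LinkedBlockingQueue[Runnable]()
  @volatile private var running = true

  private val eventThread = new Thread("event-loop") {
    setDaemon(true)
    override def run(): Unit = while (running) queue.take().run()
  }
  eventThread.start()

  def post(task: Runnable): Unit = queue.put(task)

  // Blocking shutdown: waits for the event thread to exit. Calling this from
  // the event thread itself deadlocks, because the thread then joins itself.
  def stop(): Unit = {
    running = false
    post(() => ())      // wake the event thread so it re-checks `running`
    eventThread.join()
  }

  // Safe from any thread, including the event thread: the blocking stop()
  // runs on a new daemon thread, mirroring SparkContext.stopInNewThread().
  def stopInNewThread(): Unit = {
    new Thread("stop-event-loop") {
      setDaemon(true)
      override def run(): Unit = EventLoop.this.stop() // qualified, as with SparkContext.this.stop()
    }.start()
  }
}
```

Posting a task that calls `stop()` on this sketch hangs forever, while one that calls `stopInNewThread()` completes; that behavioural difference is why the internal callers below are switched to `stopInNewThread()`.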
5 changes: 0 additions & 5 deletions core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -146,11 +146,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
    * @param uri URI with location of the file.
    */
   def openChannel(uri: String): ReadableByteChannel
-
-  /**
-   * Return if the current thread is a RPC thread.
-   */
-  def isInRPCThread: Boolean
 }
 
 /**
core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -201,7 +201,6 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
   /** Message loop used for dispatching messages. */
   private class MessageLoop extends Runnable {
     override def run(): Unit = {
-      NettyRpcEnv.rpcThreadFlag.value = true
       try {
         while (true) {
           try {
core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -407,14 +407,9 @@ private[netty] class NettyRpcEnv(
     }
 
   }
-
-  override def isInRPCThread: Boolean = NettyRpcEnv.rpcThreadFlag.value
 }
 
 private[netty] object NettyRpcEnv extends Logging {
 
-  private[netty] val rpcThreadFlag = new DynamicVariable[Boolean](false)
-
   /**
    * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
    * Use `currentEnv` to wrap the deserialization codes. E.g.,
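The `DynamicVariable` deleted here was the mechanism behind `isInRPCThread`: each message-loop thread set a thread-scoped flag so that blocking code could detect it was running on an RPC thread. For background, below is a standalone sketch of that pattern with hypothetical names, assuming only `scala.util.DynamicVariable` from the standard library; the PR can drop it because `stopInNewThread()` no longer needs to know which thread it is called from.

```scala
// Standalone sketch (hypothetical object, not Spark's) of the DynamicVariable
// flag pattern removed by this change.
import scala.util.DynamicVariable

object RpcThreadTracking {
  // Defaults to false; true only for the dynamic extent of asRpcThread.
  private val rpcThreadFlag = new DynamicVariable[Boolean](false)

  // Run `body` with the current thread marked as an RPC thread.
  def asRpcThread[T](body: => T): T = rpcThreadFlag.withValue(true)(body)

  // True only while the current thread is inside asRpcThread.
  def isInRpcThread: Boolean = rpcThreadFlag.value
}

object RpcThreadTrackingDemo extends App {
  println(RpcThreadTracking.isInRpcThread)     // false
  RpcThreadTracking.asRpcThread {
    println(RpcThreadTracking.isInRpcThread)   // true
  }
}
```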
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1661,7 +1661,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     } catch {
       case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
     }
-    dagScheduler.sc.stop()
+    dagScheduler.sc.stopInNewThread()
   }
 
   override def onStop(): Unit = {
core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -139,7 +139,7 @@ private[spark] class StandaloneSchedulerBackend(
         scheduler.error(reason)
       } finally {
         // Ensure the application terminates, as we can no longer run jobs.
-        sc.stop()
+        sc.stopInNewThread()
       }
     }
   }
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1249,7 +1249,7 @@ private[spark] object Utils extends Logging {
         val currentThreadName = Thread.currentThread().getName
         if (sc != null) {
           logError(s"uncaught error in thread $currentThreadName, stopping SparkContext", t)
-          sc.stop()
+          sc.stopInNewThread()
         }
         if (!NonFatal(t)) {
           logError(s"throw uncaught fatal error in thread $currentThreadName", t)
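`Utils.tryOrStopSparkContext` is the call site named in the PR title: the wrapped block may be running on a thread that the blocking `SparkContext.stop()` would wait for, so stopping inline can deadlock (SPARK-18751). Below is a hedged sketch of that wrapper pattern against a hypothetical `Stoppable` trait rather than SparkContext; the point mirrored from Utils.scala is that the error handler only triggers the asynchronous shutdown and keeps the blocking stop off the failing thread.

```scala
// Hypothetical analogue of Utils.tryOrStopSparkContext, not the Spark code itself.
import scala.util.control.NonFatal

object SafeShutdown {
  // Stand-in for SparkContext: stop() blocks, stopInNewThread() defers it.
  trait Stoppable {
    def stop(): Unit

    def stopInNewThread(): Unit = {
      new Thread("stop-on-error") {
        setDaemon(true)
        override def run(): Unit = Stoppable.this.stop() // qualified to avoid Thread.stop()
      }.start()
    }
  }

  // Run `block`; on an uncaught error, shut the owner down without blocking
  // the failing thread, then rethrow fatal errors.
  def tryOrStop(owner: Stoppable)(block: => Unit): Unit = {
    try {
      block
    } catch {
      case t: Throwable =>
        val threadName = Thread.currentThread().getName
        System.err.println(s"uncaught error in thread $threadName, stopping owner: $t")
        owner.stopInNewThread() // the blocking stop() here could deadlock
        if (!NonFatal(t)) throw t
    }
  }
}
```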
13 changes: 0 additions & 13 deletions core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -870,19 +870,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     verify(endpoint, never()).onDisconnected(any())
     verify(endpoint, never()).onNetworkError(any(), any())
   }
-
-  test("isInRPCThread") {
-    val rpcEndpointRef = env.setupEndpoint("isInRPCThread", new RpcEndpoint {
-      override val rpcEnv = env
-
-      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-        case m => context.reply(rpcEnv.isInRPCThread)
-      }
-    })
-    assert(rpcEndpointRef.askWithRetry[Boolean]("hello") === true)
-    assert(env.isInRPCThread === false)
-    env.stop(rpcEndpointRef)
-  }
 }
 
 class UnserializableClass