diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index be4dae19df5f9..cf094d96d16e4 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1760,25 +1760,24 @@ class SparkContext(config: SparkConf) extends Logging { def listJars(): Seq[String] = addedJars.keySet.toSeq /** - * Shut down the SparkContext. + * When stopping SparkContext inside Spark components, it's easy to cause dead-lock since Spark + * may wait for some internal threads to finish. It's better to use this method to stop + * SparkContext instead. */ - def stop(): Unit = { - if (env.rpcEnv.isInRPCThread) { - // `stop` will block until all RPC threads exit, so we cannot call stop inside a RPC thread. - // We should launch a new thread to call `stop` to avoid dead-lock. - new Thread("stop-spark-context") { - setDaemon(true) + private[spark] def stopInNewThread(): Unit = { + new Thread("stop-spark-context") { + setDaemon(true) - override def run(): Unit = { - _stop() - } - }.start() - } else { - _stop() - } + override def run(): Unit = { + SparkContext.this.stop() + } + }.start() } - private def _stop() { + /** + * Shut down the SparkContext. + */ + def stop() { if (LiveListenerBus.withinListenerThread.value) { throw new SparkException( s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index bbc416381490b..530743c03640b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -146,11 +146,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * @param uri URI with location of the file. */ def openChannel(uri: String): ReadableByteChannel - - /** - * Return if the current thread is a RPC thread. - */ - def isInRPCThread: Boolean } /** diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 67baabd2cbff2..a02cf30a5d831 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -201,7 +201,6 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { /** Message loop used for dispatching messages. */ private class MessageLoop extends Runnable { override def run(): Unit = { - NettyRpcEnv.rpcThreadFlag.value = true try { while (true) { try { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 0b8cd144a2161..e56943da1303a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -407,14 +407,9 @@ private[netty] class NettyRpcEnv( } } - - override def isInRPCThread: Boolean = NettyRpcEnv.rpcThreadFlag.value } private[netty] object NettyRpcEnv extends Logging { - - private[netty] val rpcThreadFlag = new DynamicVariable[Boolean](false) - /** * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]]. * Use `currentEnv` to wrap the deserialization codes. E.g., diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 7fde34d8974c0..9378f15b7bc24 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1661,7 +1661,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler } catch { case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) } - dagScheduler.sc.stop() + dagScheduler.sc.stopInNewThread() } override def onStop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 368cd30a2e11a..7befdb0c1f64d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -139,7 +139,7 @@ private[spark] class StandaloneSchedulerBackend( scheduler.error(reason) } finally { // Ensure the application terminates, as we can no longer run jobs. - sc.stop() + sc.stopInNewThread() } } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 91f5606127f42..c6ad15416755d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1249,7 +1249,7 @@ private[spark] object Utils extends Logging { val currentThreadName = Thread.currentThread().getName if (sc != null) { logError(s"uncaught error in thread $currentThreadName, stopping SparkContext", t) - sc.stop() + sc.stopInNewThread() } if (!NonFatal(t)) { logError(s"throw uncaught fatal error in thread $currentThreadName", t) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index aa0705987d837..acdf21df9a161 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -870,19 +870,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { verify(endpoint, never()).onDisconnected(any()) verify(endpoint, never()).onNetworkError(any(), any()) } - - test("isInRPCThread") { - val rpcEndpointRef = env.setupEndpoint("isInRPCThread", new RpcEndpoint { - override val rpcEnv = env - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case m => context.reply(rpcEnv.isInRPCThread) - } - }) - assert(rpcEndpointRef.askWithRetry[Boolean]("hello") === true) - assert(env.isInRPCThread === false) - env.stop(rpcEndpointRef) - } } class UnserializableClass