diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 009ed64775844..57874df3819b2 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -91,6 +91,9 @@ class SparkEnv ( // actorSystem.awaitTermination() // Note that blockTransferService is stopped by BlockManager since it is started by it. + + // clear all the references in ThreadLocal object + SparkEnv.reset() } private[spark] @@ -119,7 +122,7 @@ class SparkEnv ( } object SparkEnv extends Logging { - private val env = new ThreadLocal[SparkEnv] + @volatile private var env = new ThreadLocal[SparkEnv] @volatile private var lastSetSparkEnv : SparkEnv = _ private[spark] val driverActorSystemName = "sparkDriver" @@ -130,6 +133,12 @@ object SparkEnv extends Logging { env.set(e) } + // clear all the threadlocal references + private[spark] def reset(): Unit = { + env = new ThreadLocal[SparkEnv] + lastSetSparkEnv = null + } + /** * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv * previously set in any thread.