From e6482fad7da812fd3fe775f064e19893717f7a88 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sat, 2 Jan 2016 14:26:55 -0800
Subject: [PATCH] Thread-safety fixes.

---
 .../scala/org/apache/spark/MapOutputTracker.scala  |  4 ++--
 .../main/scala/org/apache/spark/SparkContext.scala |  6 +++++-
 .../spark/shuffle/FileShuffleBlockResolver.scala   | 14 +++++++++-----
 3 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 2e94ab844195d..a2675db7a5bd2 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -291,8 +291,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
 
   // HashMaps for storing mapStatuses and cached serialized statuses in the driver.
   // Statuses are dropped only by explicit de-registering.
-  protected val mapStatuses = new HashMap[Int, Array[MapStatus]]()
-  private val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]()
+  protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
+  private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala
 
   def registerShuffle(shuffleId: Int, numMaps: Int) {
     if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index c2ca31fc83f5a..26d018b048a17 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -23,6 +23,7 @@ import java.io._
 import java.lang.reflect.Constructor
 import java.net.URI
 import java.util.{Arrays, Properties, UUID}
+import java.util.concurrent.ConcurrentMap
 import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger}
 import java.util.UUID.randomUUID
 
@@ -295,7 +296,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   private[spark] val addedJars = HashMap[String, Long]()
 
   // Keeps track of all persisted RDDs
-  private[spark] val persistentRdds = new MapMaker().weakValues().makeMap[Int, RDD[_]]().asScala
+  private[spark] val persistentRdds = {
+    val map: ConcurrentMap[Int, RDD[_]] = new MapMaker().weakValues().makeMap[Int, RDD[_]]()
+    map.asScala
+  }
   private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener
 
   def statusTracker: SparkStatusTracker = _statusTracker
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index ed877aa0a7759..b2f4a730d5b32 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle
 
-import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
 
 import scala.collection.JavaConverters._
 
@@ -63,7 +63,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
     val completedMapTasks = new ConcurrentLinkedQueue[Int]()
   }
 
-  private val shuffleStates = new scala.collection.mutable.HashMap[ShuffleId, ShuffleState]
+  private val shuffleStates = new ConcurrentHashMap[ShuffleId, ShuffleState]
 
   /**
    * Get a ShuffleWriterGroup for the given map task, which will register it as complete
@@ -72,8 +72,12 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
   def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer,
       writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = {
     new ShuffleWriterGroup {
-      private val shuffleState =
-        shuffleStates.getOrElseUpdate(shuffleId, new ShuffleState(numReducers))
+      private val shuffleState: ShuffleState = {
+        // Note: we do _not_ want to just wrap this java ConcurrentHashMap into a Scala map and use
+        // .getOrElseUpdate() because that's actually NOT atomic.
+        shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers))
+        shuffleStates.get(shuffleId)
+      }
       val openStartTime = System.nanoTime
       val serializerInstance = serializer.newInstance()
       val writers: Array[DiskBlockObjectWriter] = {
@@ -110,7 +114,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
 
   /** Remove all the blocks / files related to a particular shuffle. */
   private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
-    shuffleStates.get(shuffleId) match {
+    Option(shuffleStates.get(shuffleId)) match {
      case Some(state) =>
        for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) {
          val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
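
A minimal standalone sketch of the two idioms this patch leans on: putIfAbsent followed by get as an atomic replacement for getOrElseUpdate, and Option(...) around java.util.Map#get to turn its null result into Some/None. "State" and the method names are hypothetical illustrations, not Spark code, and the sketch assumes the Scala versions Spark targeted at the time, where the .asScala wrapper inherits a non-atomic check-then-act getOrElseUpdate (later Scala releases give concurrent.Map an atomic default).

  import java.util.concurrent.ConcurrentHashMap
  import scala.collection.JavaConverters._

  // Hypothetical stand-in for the patch's ShuffleState.
  class State(val numReducers: Int)

  object Sketch {
    private val states = new ConcurrentHashMap[Int, State]()

    // Racy: the inherited getOrElseUpdate runs "contains? -> construct -> put"
    // as separate steps, so two racing threads can each construct a State and
    // return *different* instances even though only one lands in the map.
    def racyGetOrCreate(id: Int, n: Int): State =
      states.asScala.getOrElseUpdate(id, new State(n))

    // Atomic: putIfAbsent inserts in a single ConcurrentHashMap operation, so
    // at most one racing insert wins and every caller then reads that winner.
    // (A losing thread's freshly built State is simply discarded.)
    def safeGetOrCreate(id: Int, n: Int): State = {
      states.putIfAbsent(id, new State(n))
      states.get(id)
    }

    // Null-safe lookup: java.util.Map#get returns null for absent keys, so
    // Option(...) converts that to None for pattern matching, mirroring the
    // removeShuffleBlocks change above.
    def lookup(id: Int): Option[State] = Option(states.get(id))
  }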