Commit e6482fa: Thread-safety fixes.

JoshRosen committed Jan 2, 2016 (1 parent: 5ffe30f)
Showing 3 changed files with 16 additions and 8 deletions.
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -291,8 +291,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)

   // HashMaps for storing mapStatuses and cached serialized statuses in the driver.
   // Statuses are dropped only by explicit de-registering.
-  protected val mapStatuses = new HashMap[Int, Array[MapStatus]]()
-  private val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]()
+  protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
+  private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala

   def registerShuffle(shuffleId: Int, numMaps: Int) {
     if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
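Note: the fix here swaps Scala's non-thread-safe mutable.HashMap for a java.util.concurrent.ConcurrentHashMap exposed through scala.collection.JavaConverters, which keeps the existing Scala call sites (e.g. put returning an Option) unchanged. A minimal standalone sketch of the pattern; the names are illustrative, not from this commit:

    import java.util.concurrent.ConcurrentHashMap

    import scala.collection.JavaConverters._

    object RegistrySketch {
      // Backing store is a Java ConcurrentHashMap; .asScala wraps it in a
      // scala.collection.concurrent.Map view, so put/get/remove are thread-safe.
      private val statuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala

      def register(id: Int, payload: Array[Byte]): Unit = {
        // The wrapper keeps Scala's API: put returns Some(oldValue) if the key
        // was already present, mirroring the registerShuffle check above.
        if (statuses.put(id, payload).isDefined) {
          throw new IllegalStateException(s"Id $id registered twice")
        }
      }
    }
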
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -23,6 +23,7 @@ import java.io._
 import java.lang.reflect.Constructor
 import java.net.URI
 import java.util.{Arrays, Properties, UUID}
+import java.util.concurrent.ConcurrentMap
 import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger}
 import java.util.UUID.randomUUID

@@ -295,7 +296,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   private[spark] val addedJars = HashMap[String, Long]()

   // Keeps track of all persisted RDDs
-  private[spark] val persistentRdds = new MapMaker().weakValues().makeMap[Int, RDD[_]]().asScala
+  private[spark] val persistentRdds = {
+    val map: ConcurrentMap[Int, RDD[_]] = new MapMaker().weakValues().makeMap[Int, RDD[_]]()
+    map.asScala
+  }
   private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener

   def statusTracker: SparkStatusTracker = _statusTracker
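Note: .asScala picks its conversion from the static type of the receiver. Pinning the Guava MapMaker result to java.util.concurrent.ConcurrentMap makes the concurrent converter (which yields scala.collection.concurrent.Map) the one that applies, so the thread-safety of persistentRdds is explicit at the type level. A hedged sketch of the distinction, with illustrative names:

    import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}

    import scala.collection.JavaConverters._
    import scala.collection.concurrent

    object ConverterSketch {
      // Ascribing the receiver as ConcurrentMap makes JavaConverters produce a
      // scala.collection.concurrent.Map, whose putIfAbsent/replace/remove are
      // atomic, rather than the plain mutable.Map view used for a java.util.Map.
      val javaMap: ConcurrentMap[Int, String] = new ConcurrentHashMap[Int, String]()
      val scalaMap: concurrent.Map[Int, String] = javaMap.asScala

      // Atomic insert-if-missing: returns None if this call inserted the value.
      val previous: Option[String] = scalaMap.putIfAbsent(1, "first")
    }
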
14 changes: 9 additions & 5 deletions core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.shuffle

-import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}

 import scala.collection.JavaConverters._

@@ -63,7 +63,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
     val completedMapTasks = new ConcurrentLinkedQueue[Int]()
   }

-  private val shuffleStates = new scala.collection.mutable.HashMap[ShuffleId, ShuffleState]
+  private val shuffleStates = new ConcurrentHashMap[ShuffleId, ShuffleState]

  /**
   * Get a ShuffleWriterGroup for the given map task, which will register it as complete
@@ -72,8 +72,12 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
   def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer,
       writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = {
     new ShuffleWriterGroup {
-      private val shuffleState =
-        shuffleStates.getOrElseUpdate(shuffleId, new ShuffleState(numReducers))
+      private val shuffleState: ShuffleState = {
+        // Note: we do _not_ want to just wrap this java ConcurrentHashMap into a Scala map and use
+        // .getOrElseUpdate() because that's actually NOT atomic.
+        shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers))
+        shuffleStates.get(shuffleId)
+      }
       val openStartTime = System.nanoTime
       val serializerInstance = serializer.newInstance()
       val writers: Array[DiskBlockObjectWriter] = {
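Note: the in-diff comment is the heart of this change. On a ConcurrentHashMap viewed through .asScala, getOrElseUpdate (in the Scala versions of this era) is an inherited get-then-put, so two racing threads can each construct a ShuffleState and one can silently overwrite the other's. A self-contained sketch of the atomic replacement pattern, with hypothetical names:

    import java.util.concurrent.ConcurrentHashMap

    // Hypothetical stand-in for the shuffleStates registry above.
    final class StateRegistry[K, V] {
      private val states = new ConcurrentHashMap[K, V]()

      // Mirrors the commit's putIfAbsent-then-get pattern. putIfAbsent is a
      // single atomic operation, so when two threads race, exactly one value is
      // inserted and the follow-up get returns the winner for both threads.
      // As in the commit, `create` is evaluated even when the key already exists.
      def getOrCreate(key: K, create: => V): V = {
        states.putIfAbsent(key, create)
        states.get(key)
      }
    }
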
@@ -110,7 +114,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)

   /** Remove all the blocks / files related to a particular shuffle. */
   private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
-    shuffleStates.get(shuffleId) match {
+    Option(shuffleStates.get(shuffleId)) match {
       case Some(state) =>
         for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) {
           val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
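Note: once shuffleStates is a raw Java map, get returns null on a miss instead of None; wrapping it in Option(...) lifts the null back into an Option so the existing Some/None match is unchanged. A tiny illustrative example:

    import java.util.concurrent.ConcurrentHashMap

    object NullToOptionSketch {
      def main(args: Array[String]): Unit = {
        val m = new ConcurrentHashMap[Int, String]()
        m.put(1, "present")

        // Java's get returns null on a miss; Option(null) is None, so the
        // original Some/None match keeps working against the raw Java map.
        Option(m.get(2)) match {
          case Some(v) => println(s"found $v")
          case None    => println("missing")
        }
      }
    }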
