From c4a8403d41ef2624f65d70099a514f93ff4eee47 Mon Sep 17 00:00:00 2001 From: Piotr Kukielka Date: Mon, 21 May 2018 18:18:02 +0200 Subject: [PATCH] Changes after review --- src/compiler/scala/tools/nsc/Global.scala | 138 ++++++++---------- .../tools/nsc/backend/jvm/GenBCode.scala | 4 +- .../tools/nsc/profile/ThreadPoolFactory.scala | 3 +- .../nsc/reporters/BufferedReporter.scala | 26 ++++ .../tools/nsc/transform/SpecializeTypes.scala | 5 +- .../tools/nsc/typechecker/Analyzer.scala | 12 +- .../util/ThreadIdentityAwareThreadLocal.scala | 39 ----- src/reflect/scala/reflect/api/Trees.scala | 10 +- .../scala/reflect/internal/Positions.scala | 4 +- .../scala/reflect/internal/Reporting.scala | 4 +- .../scala/reflect/internal/Trees.scala | 5 +- .../reflect/internal/util/Parallel.scala | 86 +++++++++++ .../ThreadIdentityAwareThreadLocal.scala | 29 ---- 13 files changed, 199 insertions(+), 166 deletions(-) create mode 100644 src/compiler/scala/tools/nsc/reporters/BufferedReporter.scala delete mode 100644 src/compiler/scala/tools/nsc/util/ThreadIdentityAwareThreadLocal.scala create mode 100644 src/reflect/scala/reflect/internal/util/Parallel.scala delete mode 100644 src/reflect/scala/reflect/runtime/ThreadIdentityAwareThreadLocal.scala diff --git a/src/compiler/scala/tools/nsc/Global.scala b/src/compiler/scala/tools/nsc/Global.scala index 1093bf435b4c..368d24e4514e 100644 --- a/src/compiler/scala/tools/nsc/Global.scala +++ b/src/compiler/scala/tools/nsc/Global.scala @@ -10,14 +10,14 @@ package nsc import java.io.{File, FileNotFoundException, IOException} import java.net.URL import java.nio.charset.{Charset, CharsetDecoder, IllegalCharsetNameException, UnsupportedCharsetException} -import java.util.concurrent.atomic.AtomicInteger import scala.collection.{immutable, mutable} import io.{AbstractFile, Path, SourceReader} -import reporters.{Reporter, StoreReporter} -import util.{ClassPath, ThreadIdentityAwareThreadLocal, returning} +import reporters.{BufferedReporter, Reporter} +import util.{ClassPath, returning} import scala.reflect.ClassTag import scala.reflect.internal.util.{BatchSourceFile, NoSourceFile, ScalaClassLoader, ScriptSourceFile, SourceFile, StatisticsStatics} +import scala.reflect.internal.util.Parallel._ import scala.reflect.internal.pickling.PickleBuffer import symtab.{Flags, SymbolTable, SymbolTrackers} import symtab.classfile.Pickler @@ -78,10 +78,10 @@ class Global(var currentSettings: Settings, reporter0: Reporter) override def settings = currentSettings - private[this] val currentReporter: ThreadIdentityAwareThreadLocal[Reporter] = - ThreadIdentityAwareThreadLocal(new StoreReporter, reporter0) + private[this] val currentReporter: WorkerOrMainThreadLocal[Reporter] = + WorkerThreadLocal(new BufferedReporter, reporter0) - def reporter: Reporter = { reporter = reporter0 ; currentReporter.get } + def reporter: Reporter = currentReporter.get def reporter_=(newReporter: Reporter): Unit = currentReporter.set(newReporter match { @@ -392,81 +392,75 @@ class Global(var currentSettings: Settings, reporter0: Reporter) def apply(unit: CompilationUnit): Unit - def run() { - assertOnMainThread() - echoPhaseSummary(this) - Await.result(Future.sequence(currentRun.units map processUnit), Duration.Inf) - } + // Method added to allow stacking functionality on top of the`run` (e..g measuring run time). + // Overriding `run` is now not allowed since we want to be in charge of how units are processed. + def wrapRun(code: => Unit): Unit = code - final def applyPhase(unit: CompilationUnit): Unit = Await.result(processUnit(unit), Duration.Inf) + def afterUnit(unit: CompilationUnit): Unit = {} - implicit val ec: ExecutionContext = { - val threadPoolFactory = ThreadPoolFactory(Global.this, this) - val javaExecutor = threadPoolFactory.newUnboundedQueueFixedThreadPool(parallelThreads, "worker") - scala.concurrent.ExecutionContext.fromExecutorService(javaExecutor, (_) => ()) - } + final def run(): Unit = wrapRun { + assertOnMain() + + if (isDebugPrintEnabled) inform("[running phase " + name + " on " + currentRun.size + " compilation units]") - private def processUnit(unit: CompilationUnit)(implicit ec: ExecutionContext): Future[Unit] = { - if (settings.debug && (settings.verbose || currentRun.size < 5)) - inform("[running phase " + name + " on " + unit + "]") - - def runWithCurrentUnit(): Unit = { - val threadName = Thread.currentThread().getName - if (!threadName.contains("worker")) Thread.currentThread().setName(s"$threadName-worker") - val unit0 = currentUnit - - try { - if ((unit ne null) && unit.exists) lastSeenSourceFile = unit.source - currentRun.currentUnit = unit - apply(unit) - } finally { - currentRun.currentUnit = unit0 - currentRun.advanceUnit() - Thread.currentThread().setName(threadName) - - // If we are on main thread it means there are no worker threads at all. - // That in turn means we were already using main reporter all the time, so there is nothing more to do. - // Otherwise we have to forward messages from worker thread reporter to main one. - reporter match { - case rep: StoreReporter => - val mainReporter = currentReporter.main - mainReporter.synchronized { - rep.infos.foreach { info => - info.severity.toString match { - case "INFO" => mainReporter.info(info.pos, info.msg, force = false) - case "WARNING" => mainReporter.warning(info.pos, info.msg) - case "ERROR" => mainReporter.error(info.pos, info.msg) - } - } - } - case _ => + implicit val ec: ExecutionContextExecutorService = createExecutionContext() + val futures = currentRun.units.collect { + case unit if !cancelled(unit) => + Future { + processUnit(unit) + afterUnit(unit) + reporter } - } } - if (cancelled(unit)) Future.successful(()) - else if (isParallel) Future(runWithCurrentUnit()) - else Future.fromTry(scala.util.Try(runWithCurrentUnit())) + futures.foreach { future => + val workerReporter = Await.result(future, Duration.Inf) + workerReporter.asInstanceOf[BufferedReporter].flushTo(reporter) + } + } + + final def applyPhase(unit: CompilationUnit): Unit = { + assertOnWorker() + if (!cancelled(unit)) processUnit(unit) } - private def adjustWorkerThreadName(): Unit = { - val currentThreadName = Thread.currentThread().getName + private def processUnit(unit: CompilationUnit): Unit = { + assertOnWorker() + + reporter = new BufferedReporter + + if (isDebugPrintEnabled) inform("[running phase " + name + " on " + unit + "]") + + val unit0 = currentUnit + + try { + if ((unit ne null) && unit.exists) lastSeenSourceFile = unit.source + currentRun.currentUnit = unit + apply(unit) + } finally { + currentRun.currentUnit = unit0 + currentRun.advanceUnit() + } } - private def parallelThreads = settings.YparallelThreads.value + /* Only output a summary message under debug if we aren't echoing each file. */ + private def isDebugPrintEnabled: Boolean = settings.debug && !(settings.verbose || currentRun.size < 5) - private def isParallel = settings.YparallelPhases.containsPhase(this) + private def createExecutionContext(): ExecutionContextExecutorService = { + val isParallel = settings.YparallelPhases.containsPhase(this) + val parallelThreads = if (isParallel) settings.YparallelThreads.value else 1 + val threadPoolFactory = ThreadPoolFactory(Global.this, this) + val javaExecutor = threadPoolFactory.newUnboundedQueueFixedThreadPool(parallelThreads, "worker") + scala.concurrent.ExecutionContext.fromExecutorService(javaExecutor, _ => ()) + } /** Is current phase cancelled on this unit? */ private def cancelled(unit: CompilationUnit) = { + assertOnMain() // run the typer only if in `createJavadoc` mode val maxJavaPhase = if (createJavadoc) currentRun.typerPhase.id else currentRun.namerPhase.id reporter.cancelled || unit.isJava && this.id > maxJavaPhase } - - private def assertOnMainThread(): Unit = { - assert("main".equals(Thread.currentThread().getName), "") - } } // phaseName = "parser" @@ -996,7 +990,7 @@ class Global(var currentSettings: Settings, reporter0: Reporter) * of what file was being compiled when it broke. Since I really * really want to know, this hack. */ - protected var _lastSeenSourceFile: ThreadIdentityAwareThreadLocal[SourceFile] = ThreadIdentityAwareThreadLocal(NoSourceFile) + private[this] final val _lastSeenSourceFile: WorkerThreadLocal[SourceFile] = WorkerThreadLocal(NoSourceFile) @inline protected def lastSeenSourceFile: SourceFile = _lastSeenSourceFile.get @inline protected def lastSeenSourceFile_=(source: SourceFile): Unit = _lastSeenSourceFile.set(source) @@ -1103,12 +1097,6 @@ class Global(var currentSettings: Settings, reporter0: Reporter) */ override def currentRunId = curRunId - def echoPhaseSummary(ph: Phase) = { - /* Only output a summary message under debug if we aren't echoing each file. */ - if (settings.debug && !(settings.verbose || currentRun.size < 5)) - inform("[running phase " + ph.name + " on " + currentRun.size + " compilation units]") - } - def newSourceFile(code: String, filename: String = "") = new BatchSourceFile(filename, code) @@ -1135,7 +1123,7 @@ class Global(var currentSettings: Settings, reporter0: Reporter) */ var isDefined = false /** The currently compiled unit; set from GlobalPhase */ - private final val _currentUnit: ThreadIdentityAwareThreadLocal[CompilationUnit] = ThreadIdentityAwareThreadLocal(NoCompilationUnit, NoCompilationUnit) + private[this] final val _currentUnit: WorkerOrMainThreadLocal[CompilationUnit] = WorkerThreadLocal(NoCompilationUnit, NoCompilationUnit) def currentUnit: CompilationUnit = _currentUnit.get def currentUnit_=(unit: CompilationUnit): Unit = _currentUnit.set(unit) @@ -1175,8 +1163,8 @@ class Global(var currentSettings: Settings, reporter0: Reporter) /** A map from compiled top-level symbols to their picklers */ val symData = new mutable.AnyRefMap[Symbol, PickleBuffer] - private var phasec: Int = 0 // phases completed - private final val unitc: AtomicInteger = new AtomicInteger(0) // units completed this phase + private var phasec: Int = 0 // phases completed + private final val unitc: Counter = new Counter // units completed this phase def size = unitbuf.size override def toString = "scalac Run for:\n " + compiledFiles.toList.sorted.mkString("\n ") @@ -1297,7 +1285,7 @@ class Global(var currentSettings: Settings, reporter0: Reporter) * (for progress reporting) */ def advancePhase(): Unit = { - unitc.set(0) + unitc.reset() phasec += 1 refreshProgress() } @@ -1312,7 +1300,7 @@ class Global(var currentSettings: Settings, reporter0: Reporter) // for sbt def cancel(): Unit = { reporter.cancelled = true } - private def currentProgress = (phasec * size) + unitc.get() + private def currentProgress = (phasec * size) + unitc.get private def totalProgress = (phaseDescriptors.size - 1) * size // -1: drops terminal phase private def refreshProgress() = if (size > 0) progress(currentProgress, totalProgress) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala b/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala index 86660a5db232..b6b848bc2ade 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala @@ -66,11 +66,11 @@ abstract class GenBCode extends SubComponent { def apply(unit: CompilationUnit): Unit = codeGen.genUnit(unit) - override def run(): Unit = { + override def wrapRun(code: => Unit): Unit = { statistics.timed(bcodeTimer) { try { initialize() - super.run() // invokes `apply` for each compilation unit + code // invokes `apply` for each compilation unit generatedClassHandler.complete() } finally { this.close() diff --git a/src/compiler/scala/tools/nsc/profile/ThreadPoolFactory.scala b/src/compiler/scala/tools/nsc/profile/ThreadPoolFactory.scala index 33d8cefde10b..1caacd2224e0 100644 --- a/src/compiler/scala/tools/nsc/profile/ThreadPoolFactory.scala +++ b/src/compiler/scala/tools/nsc/profile/ThreadPoolFactory.scala @@ -4,6 +4,7 @@ import java.util.concurrent.ThreadPoolExecutor.AbortPolicy import java.util.concurrent._ import java.util.concurrent.atomic.AtomicInteger +import scala.reflect.internal.util.Parallel.WorkerThread import scala.tools.nsc.{Global, Phase} sealed trait ThreadPoolFactory { @@ -47,7 +48,7 @@ object ThreadPoolFactory { // the thread pool and executes them (on the thread created here). override def newThread(worker: Runnable): Thread = { val wrapped = wrapWorker(worker, shortId) - val t: Thread = new Thread(group, wrapped, namePrefix + threadNumber.getAndIncrement, 0) + val t: Thread = new WorkerThread(group, wrapped, namePrefix + threadNumber.getAndIncrement, 0) if (t.isDaemon != daemon) t.setDaemon(daemon) if (t.getPriority != priority) t.setPriority(priority) t diff --git a/src/compiler/scala/tools/nsc/reporters/BufferedReporter.scala b/src/compiler/scala/tools/nsc/reporters/BufferedReporter.scala new file mode 100644 index 000000000000..cfde8d9ed5bf --- /dev/null +++ b/src/compiler/scala/tools/nsc/reporters/BufferedReporter.scala @@ -0,0 +1,26 @@ +package scala.tools.nsc.reporters + +import scala.reflect.internal.util.Parallel.{assertOnMain, assertOnWorker} +import scala.reflect.internal.util.Position + +final class BufferedReporter extends Reporter { + private[this] var buffered = List.empty[BufferedMessage] + + protected def info0(pos: Position, msg: String, severity: Severity, force: Boolean): Unit = { + assertOnWorker() + buffered = BufferedMessage(pos, msg, severity, force) :: buffered + severity.count += 1 + } + + def flushTo(reporter: Reporter): Unit = { + assertOnMain() + val sev = Array(reporter.INFO, reporter.WARNING, reporter.ERROR) + buffered.reverse.foreach { + msg => + reporter.info1(msg.pos, msg.msg, sev(msg.severity.id), msg.force) + } + buffered = Nil + } + + private case class BufferedMessage(pos: Position, msg: String, severity: Severity, force: Boolean) +} \ No newline at end of file diff --git a/src/compiler/scala/tools/nsc/transform/SpecializeTypes.scala b/src/compiler/scala/tools/nsc/transform/SpecializeTypes.scala index c5de1789432c..9acde70122dc 100644 --- a/src/compiler/scala/tools/nsc/transform/SpecializeTypes.scala +++ b/src/compiler/scala/tools/nsc/transform/SpecializeTypes.scala @@ -198,8 +198,9 @@ abstract class SpecializeTypes extends InfoTransform with TypingTransformers { override def newPhase(prev: scala.tools.nsc.Phase): StdPhase = new SpecializationPhase(prev) class SpecializationPhase(prev: scala.tools.nsc.Phase) extends super.Phase(prev) { override def checkable = false - override def run(): Unit = { - super.run() + override def wrapRun(code: => Unit): Unit = { + code + exitingSpecialize { FunctionClass.seq.map(_.info) TupleClass.seq.map(_.info) diff --git a/src/compiler/scala/tools/nsc/typechecker/Analyzer.scala b/src/compiler/scala/tools/nsc/typechecker/Analyzer.scala index b25119d6ba30..5e7b349c0223 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Analyzer.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Analyzer.scala @@ -87,17 +87,17 @@ trait Analyzer extends AnyRef // Lacking a better fix, we clear it here (before the phase is created, meaning for each // compiler run). This is good enough for the resident compiler, which was the most affected. undoLog.clear() - override def run(): Unit = { + + override def afterUnit(unit: CompilationUnit): Unit = undoLog.clear() + + override def wrapRun(code: => Unit): Unit = { val start = if (StatisticsStatics.areSomeColdStatsEnabled) statistics.startTimer(statistics.typerNanos) else null - global.echoPhaseSummary(this) - for (unit <- currentRun.units) { - applyPhase(unit) - undoLog.clear() - } + code // defensive measure in case the bookkeeping in deferred macro expansion is buggy clearDelayed() if (StatisticsStatics.areSomeColdStatsEnabled) statistics.stopTimer(statistics.typerNanos, start) } + def apply(unit: CompilationUnit): Unit = { try { val typer = newTyper(rootContext(unit)) diff --git a/src/compiler/scala/tools/nsc/util/ThreadIdentityAwareThreadLocal.scala b/src/compiler/scala/tools/nsc/util/ThreadIdentityAwareThreadLocal.scala deleted file mode 100644 index 62f8db51917b..000000000000 --- a/src/compiler/scala/tools/nsc/util/ThreadIdentityAwareThreadLocal.scala +++ /dev/null @@ -1,39 +0,0 @@ -package scala.tools.nsc.util - -object ThreadIdentityAwareThreadLocal { - def apply[T](valueOnWorker: => T) = new ThreadIdentityAwareThreadLocal[T](valueOnWorker, null.asInstanceOf[T]) - def apply[T](valueOnWorker: => T, valueOnMain: => T) = new ThreadIdentityAwareThreadLocal[T](valueOnWorker, valueOnMain) -} - -// `ThreadIdentityAwareThreadLocal` allows us to have different (sub)type of values on main and worker threads. -// It's useful in cases like reporter, when on workers we want to just store messages and on main we want to print them, -// but also in the cases when we do not expect some value to be read/write on the main thread, -// and we want to discover violations of that rule. -class ThreadIdentityAwareThreadLocal[T](valueOnWorker: => T, valueOnMain: => T) { - var main: T = valueOnMain - - private val worker: ThreadLocal[T] = new ThreadLocal[T] { - override def initialValue(): T = valueOnWorker - } - - // That logic may look a little bit funky because we need to consider cases - // where there is only one thread which is both main and worker at a given time. - private def isOnMainThread: Boolean = { - val currentThreadName = Thread.currentThread().getName - def isMainDefined = currentThreadName.startsWith("main") && valueOnMain != null - def isWorkerDefined = currentThreadName.contains("-worker") - - assert(isMainDefined || isWorkerDefined, "Variable cannot be accessed on the main thread") - - isMainDefined - } - - def get: T = if (isOnMainThread) main else worker.get() - - def set(value: T): Unit = if (isOnMainThread) main = value else worker.set(value) - - def reset(): Unit = { - main = valueOnMain - worker.set(valueOnWorker) - } -} diff --git a/src/reflect/scala/reflect/api/Trees.scala b/src/reflect/scala/reflect/api/Trees.scala index a6e2a711bd9b..98a13f78d2aa 100644 --- a/src/reflect/scala/reflect/api/Trees.scala +++ b/src/reflect/scala/reflect/api/Trees.scala @@ -6,7 +6,7 @@ package scala package reflect package api -import scala.reflect.runtime.ThreadIdentityAwareThreadLocal +import scala.reflect.internal.util.Parallel.WorkerThreadLocal /** * EXPERIMENTAL @@ -2464,9 +2464,9 @@ trait Trees { self: Universe => * @group Traversal */ class Traverser { - protected[scala] def currentOwner: Symbol = _currentOwner.get - protected[scala] def currentOwner_=(sym: Symbol): Unit = _currentOwner.set(sym) - private val _currentOwner: ThreadIdentityAwareThreadLocal[Symbol] = ThreadIdentityAwareThreadLocal[Symbol](rootMirror.RootClass) + @inline final protected[scala] def currentOwner: Symbol = _currentOwner.get + @inline final protected[scala] def currentOwner_=(sym: Symbol): Unit = _currentOwner.set(sym) + private final val _currentOwner: WorkerThreadLocal[Symbol] = WorkerThreadLocal(rootMirror.RootClass) /** Traverse something which Trees contain, but which isn't a Tree itself. */ def traverseName(name: Name): Unit = () @@ -2540,7 +2540,7 @@ trait Trees { self: Universe => /** The current owner symbol. */ protected[scala] def currentOwner: Symbol = _currentOwner.get protected[scala] def currentOwner_=(sym: Symbol): Unit = _currentOwner.set(sym) - private val _currentOwner: ThreadIdentityAwareThreadLocal[Symbol] = ThreadIdentityAwareThreadLocal[Symbol](rootMirror.RootClass) + private final val _currentOwner: WorkerThreadLocal[Symbol] = WorkerThreadLocal(rootMirror.RootClass) /** The enclosing method of the currently transformed tree. */ protected def currentMethod = { diff --git a/src/reflect/scala/reflect/internal/Positions.scala b/src/reflect/scala/reflect/internal/Positions.scala index 8f08dbfcd521..a77a46ca3b1c 100644 --- a/src/reflect/scala/reflect/internal/Positions.scala +++ b/src/reflect/scala/reflect/internal/Positions.scala @@ -5,7 +5,7 @@ package internal import scala.collection.mutable import util._ import scala.collection.mutable.ListBuffer -import scala.reflect.runtime.ThreadIdentityAwareThreadLocal +import scala.reflect.internal.util.Parallel.WorkerThreadLocal /** Handling range positions * atPos, the main method in this trait, will add positions to a tree, @@ -279,7 +279,7 @@ trait Positions extends api.Positions { self: SymbolTable => } trait PosAssigner extends InternalTraverser { - private val _pos = ThreadIdentityAwareThreadLocal[Position](NoPosition) + private val _pos: WorkerThreadLocal[Position] = WorkerThreadLocal(NoPosition) @inline def pos: Position = _pos.get @inline def pos_=(position: Position): Unit = _pos.set(position) } diff --git a/src/reflect/scala/reflect/internal/Reporting.scala b/src/reflect/scala/reflect/internal/Reporting.scala index f7b46cdb656a..2c920def0746 100644 --- a/src/reflect/scala/reflect/internal/Reporting.scala +++ b/src/reflect/scala/reflect/internal/Reporting.scala @@ -78,8 +78,8 @@ import util.Position */ abstract class Reporter { protected def info0(pos: Position, msg: String, severity: Severity, force: Boolean): Unit - - def infoRaw(pos: Position, msg: String, severity: Severity, force: Boolean): Unit = info0(pos, msg, severity, force) + def info1(pos: Position, msg: String, severity: Severity, force: Boolean): Unit = + info0(pos, msg, severity, force) def echo(pos: Position, msg: String): Unit = info0(pos, msg, INFO, force = true) def warning(pos: Position, msg: String): Unit = info0(pos, msg, WARNING, force = false) diff --git a/src/reflect/scala/reflect/internal/Trees.scala b/src/reflect/scala/reflect/internal/Trees.scala index ad3143c7148e..9f0b551fc709 100644 --- a/src/reflect/scala/reflect/internal/Trees.scala +++ b/src/reflect/scala/reflect/internal/Trees.scala @@ -7,17 +7,16 @@ package scala package reflect package internal -import java.util.concurrent.atomic.AtomicInteger - import Flags._ import scala.collection.mutable +import scala.reflect.internal.util.Parallel.Counter import scala.reflect.macros.Attachments import util.{Statistics, StatisticsStatics} trait Trees extends api.Trees { self: SymbolTable => - private[scala] final val nodeCount: AtomicInteger = new AtomicInteger(0) + private[scala] final val nodeCount: Counter = new Counter protected def treeLine(t: Tree): String = if (t.pos.isDefined && t.pos.isRange) t.pos.lineContent.drop(t.pos.column - 1).take(t.pos.end - t.pos.start + 1) diff --git a/src/reflect/scala/reflect/internal/util/Parallel.scala b/src/reflect/scala/reflect/internal/util/Parallel.scala new file mode 100644 index 000000000000..bab9a6d0727d --- /dev/null +++ b/src/reflect/scala/reflect/internal/util/Parallel.scala @@ -0,0 +1,86 @@ +package scala.reflect.internal.util + +import java.util.concurrent.atomic.AtomicInteger + +object Parallel { + + class WorkerThread(group: ThreadGroup, target: Runnable, name: String, + stackSize: Long) extends Thread(group, target, name, stackSize) + + def WorkerThreadLocal[T <: AnyRef](valueOnWorker: => T, valueOnMain: => T) = new WorkerOrMainThreadLocal[T](valueOnWorker, valueOnMain) + + def WorkerThreadLocal[T <: AnyRef](valueOnWorker: => T) = new WorkerThreadLocal[T](valueOnWorker) + + // `WorkerOrMainThreadLocal` allows us to have different (sub)type of values on main and worker threads. + // It's useful in cases like reporter, when on workers we want to just store messages and on main we want to print them, + class WorkerOrMainThreadLocal[T](valueOnWorker: => T, valueOnMain: => T) { + + private var main: T = null.asInstanceOf[T] + + private val worker: ThreadLocal[T] = new ThreadLocal[T] { + override def initialValue(): T = valueOnWorker + } + + final def get: T = { + if (isWorkerThread) worker.get() + else { + if (main == null) main = valueOnMain + main + } + } + + final def set(value: T): Unit = if (isWorkerThread) worker.set(value) else main = value + + final def reset(): Unit = { + worker.remove() + main = valueOnMain + } + } + + + // `WorkerThreadLocal` allows us to detect some value to be read/write on the main thread, + // and we want to discover violations of that rule. + class WorkerThreadLocal[T](valueOnWorker: => T) + extends WorkerOrMainThreadLocal(valueOnWorker, throw new IllegalStateException("not allowed on main thread")) + + class Counter { + private val count = new AtomicInteger + + def get: Int = count.get() + + def reset(): Unit = { + assertOnMain() + count.set(0) + } + + def incrementAndGet(): Int = count.incrementAndGet + + def getAndIncrement(): Int = count.getAndIncrement + + override def toString: String = s"Counter[$count]" + } + + def assertOnMain(): Unit = { + assert(!isWorkerThread) + } + + def assertOnWorker(): Unit = { + assert(isWorkerThread) + } + + def isWorkerThread: Boolean = { + Thread.currentThread.isInstanceOf[WorkerThread] || isWorker.get() + } + + // This needs to be a ThreadLocal to support parallel compilation + val isWorker: ThreadLocal[Boolean] = new ThreadLocal[Boolean] { + override def initialValue(): Boolean = false + } + + @inline def asWorkerThread[T](fn: => T): T = { + assertOnMain() + isWorker.set(true) + try fn finally isWorker.set(false) + } + +} \ No newline at end of file diff --git a/src/reflect/scala/reflect/runtime/ThreadIdentityAwareThreadLocal.scala b/src/reflect/scala/reflect/runtime/ThreadIdentityAwareThreadLocal.scala deleted file mode 100644 index 06b09def99bc..000000000000 --- a/src/reflect/scala/reflect/runtime/ThreadIdentityAwareThreadLocal.scala +++ /dev/null @@ -1,29 +0,0 @@ -package scala.reflect.runtime - -object ThreadIdentityAwareThreadLocal { - def apply[T](valueOnWorker: => T) = new ThreadIdentityAwareThreadLocal[T](valueOnWorker, null.asInstanceOf[T]) - def apply[T](valueOnWorker: => T, valueOnMain: => T) = new ThreadIdentityAwareThreadLocal[T](valueOnWorker, valueOnMain) -} - -class ThreadIdentityAwareThreadLocal[T](valueOnWorker: => T, valueOnMain: => T) { - var main: T = valueOnMain - - private val worker: ThreadLocal[T] = new ThreadLocal[T] { - override def initialValue(): T = valueOnWorker - } - - private def isOnMainThread = { - val isMain = "main".equals(Thread.currentThread().getName) - assert(!(isMain && valueOnMain == null), "Variable cannot be accessed on the main thread") - isMain - } - - def get: T = if (isOnMainThread) main else worker.get() - - def set(value: T): Unit = if (isOnMainThread) main = value else worker.set(value) - - def reset(): Unit = { - main = valueOnMain - worker.set(valueOnWorker) - } -}