diff --git a/.rat-excludes b/.rat-excludes index bccb043c2bb55..eaefef1b0aa2e 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -25,6 +25,7 @@ log4j-defaults.properties bootstrap-tooltip.js jquery-1.11.1.min.js sorttable.js +.*avsc .*txt .*json .*data diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index bf3c3a6ceb5ef..3848734d6f639 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -66,10 +66,15 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** * Whether the cleaning thread will block on cleanup tasks. - * This is set to true only for tests. + * + * Due to SPARK-3015, this is set to true by default. This is intended to be only a temporary + * workaround for the issue, which is ultimately caused by the way the BlockManager actors + * issue inter-dependent blocking Akka messages to each other at high frequencies. This happens, + * for instance, when the driver performs a GC and cleans up all broadcast blocks that are no + * longer in scope. */ private val blockOnCleanupTasks = sc.conf.getBoolean( - "spark.cleaner.referenceTracking.blocking", false) + "spark.cleaner.referenceTracking.blocking", true) @volatile private var stopped = false @@ -174,9 +179,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] - - // Used for testing. These methods explicitly blocks until cleanup is completed - // to ensure that more reliable testing. } private object ContextCleaner { diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala index f40baa8e43592..5c262bcbddf76 100644 --- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -33,7 +33,7 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which // introduces an expensive read fence. - if (context.interrupted) { + if (context.isInterrupted) { throw new TaskKilledException } else { delegate.hasNext diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 22d8d1cb1ddcf..fc36e37c53f5e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -210,12 +210,22 @@ object SparkEnv extends Logging { "MapOutputTracker", new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + // Let the user specify short names for shuffle managers + val shortShuffleMgrNames = Map( + "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", + "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + val shuffleMgrName = conf.get("spark.shuffle.manager", "hash") + val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) + val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) + + val shuffleMemoryManager = new ShuffleMemoryManager(conf) + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf) val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, securityManager, mapOutputTracker) + serializer, conf, securityManager, mapOutputTracker, shuffleManager) val connectionManager = blockManager.connectionManager @@ -250,16 +260,6 @@ object SparkEnv extends Logging { "." } - // Let the user specify short names for shuffle managers - val shortShuffleMgrNames = Map( - "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", - "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") - val shuffleMgrName = conf.get("spark.shuffle.manager", "hash") - val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) - val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) - - val shuffleMemoryManager = new ShuffleMemoryManager(conf) - // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 51f40c339d13c..2b99b8a5af250 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,10 +21,18 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.TaskCompletionListener + /** * :: DeveloperApi :: * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + * @param runningLocally whether the task is running locally in the driver JVM + * @param taskMetrics performance metrics of the task */ @DeveloperApi class TaskContext( @@ -39,13 +47,45 @@ class TaskContext( def splitId = partitionId // List of callback functions to execute when the task completes. - @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit] + @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] // Whether the corresponding task has been killed. - @volatile var interrupted: Boolean = false + @volatile private var interrupted: Boolean = false + + // Whether the task has completed. + @volatile private var completed: Boolean = false + + /** Checks whether the task has completed. */ + def isCompleted: Boolean = completed - // Whether the task has completed, before the onCompleteCallbacks are executed. - @volatile var completed: Boolean = false + /** Checks whether the task has been killed. */ + def isInterrupted: Boolean = interrupted + + // TODO: Also track whether the task has completed successfully or with exception. + + /** + * Add a (Java friendly) listener to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { + onCompleteCallbacks += listener + this + } + + /** + * Add a listener in the form of a Scala closure to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(f: TaskContext => Unit): this.type = { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + } + this + } /** * Add a callback function to be executed on task completion. An example use @@ -53,13 +93,22 @@ class TaskContext( * Will be called in any situation - success, failure, or cancellation. * @param f Callback function. */ + @deprecated("use addTaskCompletionListener", "1.1.0") def addOnCompleteCallback(f: () => Unit) { - onCompleteCallbacks += f + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f() + } } - def executeOnCompleteCallbacks() { + /** Marks the task as completed and triggers the listeners. */ + private[spark] def markTaskCompleted(): Unit = { completed = true // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { _() } + onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) } + } + + /** Marks the task for interruption, i.e. cancellation. */ + private[spark] def markInterrupted(): Unit = { + interrupted = true } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index f3b05e1243045..49dc95f349eac 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -19,6 +19,7 @@ package org.apache.spark.api.python import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SerializableWritable, SparkException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ @@ -42,7 +43,7 @@ private[python] object Converter extends Logging { defaultConverter: Converter[Any, Any]): Converter[Any, Any] = { converterClass.map { cc => Try { - val c = Class.forName(cc).newInstance().asInstanceOf[Converter[Any, Any]] + val c = Utils.classForName(cc).newInstance().asInstanceOf[Converter[Any, Any]] logInfo(s"Loaded converter: $cc") c } match { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 0b5322c6fb965..10210a2927dcc 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -68,7 +68,7 @@ private[spark] class PythonRDD( // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) - context.addOnCompleteCallback { () => + context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() // Cleanup the worker socket. This will also cause the Python worker to exit. @@ -137,7 +137,7 @@ private[spark] class PythonRDD( } } catch { - case e: Exception if context.interrupted => + case e: Exception if context.isInterrupted => logDebug("Exception thrown after task interruption", e) throw new TaskKilledException @@ -176,7 +176,7 @@ private[spark] class PythonRDD( /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ def shutdownOnTaskCompletion() { - assert(context.completed) + assert(context.isCompleted) this.interrupt() } @@ -209,7 +209,7 @@ private[spark] class PythonRDD( PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) dataOut.flush() } catch { - case e: Exception if context.completed || context.interrupted => + case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) case e: Exception => @@ -235,10 +235,10 @@ private[spark] class PythonRDD( override def run() { // Kill the worker if it is interrupted, checking until task completion. // TODO: This has a race condition if interruption occurs, as completed may still become true. - while (!context.interrupted && !context.completed) { + while (!context.isInterrupted && !context.isCompleted) { Thread.sleep(2000) } - if (!context.completed) { + if (!context.isCompleted) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") env.destroyPythonWorker(pythonExec, envVars.toMap, worker) @@ -315,6 +315,14 @@ private[spark] object PythonRDD extends Logging { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } + def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = { + val file = new DataInputStream(new FileInputStream(filename)) + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + sc.broadcast(obj) + } + def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { // The right way to implement this would be to use TypeTags to get the full // type of T. Since I don't want to introduce breaking changes throughout the @@ -372,8 +380,8 @@ private[spark] object PythonRDD extends Logging { batchSize: Int) = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") - val kc = Class.forName(keyClass).asInstanceOf[Class[K]] - val vc = Class.forName(valueClass).asInstanceOf[Class[V]] + val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] + val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, @@ -440,9 +448,9 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - val kc = Class.forName(keyClass).asInstanceOf[Class[K]] - val vc = Class.forName(valueClass).asInstanceOf[Class[V]] - val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] + val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] + val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] if (path.isDefined) { sc.sc.newAPIHadoopFile[K, V, F](path.get, fc, kc, vc, conf) } else { @@ -509,9 +517,9 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - val kc = Class.forName(keyClass).asInstanceOf[Class[K]] - val vc = Class.forName(valueClass).asInstanceOf[Class[V]] - val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] + val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] + val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] if (path.isDefined) { sc.sc.hadoopFile(path.get, fc, kc, vc) } else { @@ -558,7 +566,7 @@ private[spark] object PythonRDD extends Logging { for { k <- Option(keyClass) v <- Option(valueClass) - } yield (Class.forName(k), Class.forName(v)) + } yield (Utils.classForName(k), Utils.classForName(v)) } private def getKeyValueConverters(keyConverterClass: String, valueConverterClass: String, @@ -621,10 +629,10 @@ private[spark] object PythonRDD extends Logging { val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) - val codec = Option(compressionCodecClass).map(Class.forName(_).asInstanceOf[Class[C]]) + val codec = Option(compressionCodecClass).map(Utils.classForName(_).asInstanceOf[Class[C]]) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) - val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) } @@ -653,7 +661,7 @@ private[spark] object PythonRDD extends Logging { val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) - val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] converted.saveAsNewAPIHadoopFile(path, kc, vc, fc, mergedConf) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 72d0589689e71..d3674427b1271 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -46,6 +46,11 @@ private[spark] class ApplicationInfo( init() + private def readObject(in: java.io.ObjectInputStream): Unit = { + in.defaultReadObject() + init() + } + private def init() { state = ApplicationState.WAITING executors = new mutable.HashMap[Int, ExecutorInfo] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala index c87b66f047dc8..38db02cd2421b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala @@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source class ApplicationSource(val application: ApplicationInfo) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.%s.%s".format("application", application.desc.name, + override val metricRegistry = new MetricRegistry() + override val sourceName = "%s.%s.%s".format("application", application.desc.name, System.currentTimeMillis()) metricRegistry.register(MetricRegistry.name("status"), new Gauge[String] { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala index 36c1b87b7f684..9c3f79f1244b7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala @@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source private[spark] class MasterSource(val master: Master) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "master" + override val metricRegistry = new MetricRegistry() + override val sourceName = "master" // Gauge for worker numbers in cluster metricRegistry.register(MetricRegistry.name("workers"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 80fde7e4b2624..81400af22c0bf 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -72,7 +72,6 @@ private[spark] class Worker( val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) val testing: Boolean = sys.props.contains("spark.testing") - val masterLock: Object = new Object() var master: ActorSelection = null var masterAddress: Address = null var activeMasterUrl: String = "" @@ -145,18 +144,16 @@ private[spark] class Worker( } def changeMaster(url: String, uiUrl: String) { - masterLock.synchronized { - activeMasterUrl = url - activeMasterWebUiUrl = uiUrl - master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) - masterAddress = activeMasterUrl match { - case Master.sparkUrlRegex(_host, _port) => - Address("akka.tcp", Master.systemName, _host, _port.toInt) - case x => - throw new SparkException("Invalid spark URL: " + x) - } - connected = true + activeMasterUrl = url + activeMasterWebUiUrl = uiUrl + master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) + masterAddress = activeMasterUrl match { + case Master.sparkUrlRegex(_host, _port) => + Address("akka.tcp", Master.systemName, _host, _port.toInt) + case x => + throw new SparkException("Invalid spark URL: " + x) } + connected = true } def tryRegisterAllMasters() { @@ -199,9 +196,7 @@ private[spark] class Worker( } case SendHeartbeat => - masterLock.synchronized { - if (connected) { master ! Heartbeat(workerId) } - } + if (connected) { master ! Heartbeat(workerId) } case WorkDirCleanup => // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor @@ -244,9 +239,7 @@ private[spark] class Worker( manager.start() coresUsed += cores_ memoryUsed += memory_ - masterLock.synchronized { - master ! ExecutorStateChanged(appId, execId, manager.state, None, None) - } + master ! ExecutorStateChanged(appId, execId, manager.state, None, None) } catch { case e: Exception => { logError("Failed to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) @@ -254,17 +247,13 @@ private[spark] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - masterLock.synchronized { - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None) - } + master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None) } } } case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - masterLock.synchronized { - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) - } + master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { executors.get(fullId) match { @@ -330,9 +319,7 @@ private[spark] class Worker( case _ => logDebug(s"Driver $driverId changed state to $state") } - masterLock.synchronized { - master ! DriverStateChanged(driverId, state, exception) - } + master ! DriverStateChanged(driverId, state, exception) val driver = drivers.remove(driverId).get finishedDrivers(driverId) = driver memoryUsed -= driver.driverDesc.mem diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala index b7ddd8c816cbc..df1e01b23b932 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala @@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source private[spark] class WorkerSource(val worker: Worker) extends Source { - val sourceName = "worker" - val metricRegistry = new MetricRegistry() + override val sourceName = "worker" + override val metricRegistry = new MetricRegistry() metricRegistry.register(MetricRegistry.name("executors"), new Gauge[Int] { override def getValue: Int = worker.executors.size diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index eac1f2326a29d..fb3f7bd54bbfa 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -99,6 +99,9 @@ private[spark] class Executor( private val urlClassLoader = createClassLoader() private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + // Set the classloader for serializer + env.serializer.setDefaultClassLoader(urlClassLoader) + // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 0ed52cfe9df61..d6721586566c2 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -35,9 +35,10 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String) }) } - val metricRegistry = new MetricRegistry() + override val metricRegistry = new MetricRegistry() + // TODO: It would be nice to pass the application name here - val sourceName = "executor.%s".format(executorId) + override val sourceName = "executor.%s".format(executorId) // Gauge for executor thread pool's actively executing task counts metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala index f865f9648a91e..635bff2cd7ec8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala @@ -21,12 +21,9 @@ import com.codahale.metrics.MetricRegistry import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet} private[spark] class JvmSource extends Source { - val sourceName = "jvm" - val metricRegistry = new MetricRegistry() + override val sourceName = "jvm" + override val metricRegistry = new MetricRegistry() - val gcMetricSet = new GarbageCollectorMetricSet - val memGaugeSet = new MemoryUsageGaugeSet - - metricRegistry.registerAll(gcMetricSet) - metricRegistry.registerAll(memGaugeSet) + metricRegistry.registerAll(new GarbageCollectorMetricSet) + metricRegistry.registerAll(new MemoryUsageGaugeSet) } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 95f96b8463a01..e77d762bdf221 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -22,6 +22,7 @@ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.net._ +import java.util.{Timer, TimerTask} import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} @@ -61,17 +62,17 @@ private[spark] class ConnectionManager( var ackMessage: Option[Message] = None def markDone(ackMessage: Option[Message]) { - this.synchronized { - this.ackMessage = ackMessage - completionHandler(this) - } + this.ackMessage = ackMessage + completionHandler(this) } } private val selector = SelectorProvider.provider.openSelector() + private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) // default to 30 second timeout waiting for authentication private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30) + private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), @@ -652,19 +653,27 @@ private[spark] class ConnectionManager( } } if (bufferMessage.hasAckId()) { - val sentMessageStatus = messageStatuses.synchronized { + messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { messageStatuses -= bufferMessage.ackId - status + status.markDone(Some(message)) } case None => { - throw new Exception("Could not find reference for received ack message " + - message.id) + /** + * We can fall down on this code because of following 2 cases + * + * (1) Invalid ack sent due to buggy code. + * + * (2) Late-arriving ack for a SendMessageStatus + * To avoid unwilling late-arriving ack + * caused by long pause like GC, you can set + * larger value than default to spark.core.connection.ack.wait.timeout + */ + logWarning(s"Could not find reference for received ack Message ${message.id}") } } } - sentMessageStatus.markDone(Some(message)) } else { var ackMessage : Option[Message] = None try { @@ -836,9 +845,23 @@ private[spark] class ConnectionManager( def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) : Future[Message] = { val promise = Promise[Message]() + + val timeoutTask = new TimerTask { + override def run(): Unit = { + messageStatuses.synchronized { + messageStatuses.remove(message.id).foreach ( s => { + promise.failure( + new IOException(s"sendMessageReliably failed because ack " + + "was not received within ${ackTimeout} sec")) + }) + } + } + } + val status = new MessageStatus(message, connectionManagerId, s => { + timeoutTask.cancel() s.ackMessage match { - case None => // Indicates a failure where we either never sent or never got ACK'd + case None => // Indicates a failure where we either never sent or never got ACK'd promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) case Some(ackMessage) => if (ackMessage.hasError) { @@ -852,6 +875,8 @@ private[spark] class ConnectionManager( messageStatuses.synchronized { messageStatuses += ((message.id, status)) } + + ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000) sendMessage(connectionManagerId, message) promise.future } @@ -861,6 +886,7 @@ private[spark] class ConnectionManager( } def stop() { + ackTimeoutMonitor.cancel() selectorThread.interrupt() selectorThread.join() selector.close() diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 34c51b833025e..20938781ac694 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -141,7 +141,7 @@ private[spark] object CheckpointRDD extends Logging { val deserializeStream = serializer.deserializeStream(fileInputStream) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => deserializeStream.close()) + context.addTaskCompletionListener(context => deserializeStream.close()) deserializeStream.asIterator.asInstanceOf[Iterator[T]] } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index f233544d128f5..e0494ee39657c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -95,7 +95,12 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * If the elements in RDD do not vary (max == min) always returns a single bucket. */ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { - // Compute the minimum and the maxium + // Scala's built-in range has issues. See #SI-8782 + def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = { + val span = max - min + Range.Int(0, steps, 1).map(s => min + (s * span) / steps) :+ max + } + // Compute the minimum and the maximum val (max: Double, min: Double) = self.mapPartitions { items => Iterator(items.foldRight(Double.NegativeInfinity, Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) => @@ -107,9 +112,11 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { throw new UnsupportedOperationException( "Histogram on either an empty RDD or RDD containing +/-infinity or NaN") } - val increment = (max-min)/bucketCount.toDouble - val range = if (increment != 0) { - Range.Double.inclusive(min, max, increment) + val range = if (min != max) { + // Range.Double.inclusive(min, max, increment) + // The above code doesn't always work. See Scala bug #SI-8782. + // https://issues.scala-lang.org/browse/SI-8782 + customRange(min, max, bucketCount) } else { List(min, min) } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 8d92ea01d9a3f..c8623314c98eb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -197,7 +197,7 @@ class HadoopRDD[K, V]( reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback{ () => closeIfNeeded() } + context.addTaskCompletionListener{ context => closeIfNeeded() } val key: K = reader.createKey() val value: V = reader.createValue() diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 8947e66f4577c..0e38f224ac81d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -68,7 +68,7 @@ class JdbcRDD[T: ClassTag]( } override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { - context.addOnCompleteCallback{ () => closeIfNeeded() } + context.addTaskCompletionListener{ context => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7dfec9a18ec67..58f707b9b4634 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -129,7 +129,7 @@ class NewHadoopRDD[K, V]( context.taskMetrics.inputMetrics = Some(inputMetrics) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => close()) + context.addTaskCompletionListener(context => close()) var havePair = false var finished = false diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 19e10bd04681b..daea2617e62ea 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1299,6 +1299,19 @@ abstract class RDD[T: ClassTag]( /** A description of this RDD and its recursive dependencies for debugging. */ def toDebugString: String = { + // Get a debug description of an rdd without its children + def debugSelf (rdd: RDD[_]): Seq[String] = { + import Utils.bytesToString + + val persistence = storageLevel.description + val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info => + " CachedPartitions: %d; MemorySize: %s; TachyonSize: %s; DiskSize: %s".format( + info.numCachedPartitions, bytesToString(info.memSize), + bytesToString(info.tachyonSize), bytesToString(info.diskSize))) + + s"$rdd [$persistence]" +: storageInfo + } + // Apply a different rule to the last child def debugChildren(rdd: RDD[_], prefix: String): Seq[String] = { val len = rdd.dependencies.length @@ -1324,7 +1337,11 @@ abstract class RDD[T: ClassTag]( val partitionStr = "(" + rdd.partitions.size + ")" val leftOffset = (partitionStr.length - 1) / 2 val nextPrefix = (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset)) - Seq(partitionStr + " " + rdd) ++ debugChildren(rdd, nextPrefix) + + debugSelf(rdd).zipWithIndex.map{ + case (desc: String, 0) => s"$partitionStr $desc" + case (desc: String, _) => s"$nextPrefix $desc" + } ++ debugChildren(rdd, nextPrefix) } def shuffleDebugString(rdd: RDD[_], prefix: String = "", isLastChild: Boolean): Seq[String] = { val partitionStr = "(" + rdd.partitions.size + ")" @@ -1334,7 +1351,11 @@ abstract class RDD[T: ClassTag]( thisPrefix + (if (isLastChild) " " else "| ") + (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset))) - Seq(thisPrefix + "+-" + partitionStr + " " + rdd) ++ debugChildren(rdd, nextPrefix) + + debugSelf(rdd).zipWithIndex.map{ + case (desc: String, 0) => s"$thisPrefix+-$partitionStr $desc" + case (desc: String, _) => s"$nextPrefix$desc" + } ++ debugChildren(rdd, nextPrefix) } def debugString(rdd: RDD[_], prefix: String = "", @@ -1342,9 +1363,8 @@ abstract class RDD[T: ClassTag]( isLastChild: Boolean = false): Seq[String] = { if (isShuffle) { shuffleDebugString(rdd, prefix, isLastChild) - } - else { - Seq(prefix + rdd) ++ debugChildren(rdd, prefix) + } else { + debugSelf(rdd).map(prefix + _) ++ debugChildren(rdd, prefix) } } firstDebugString(this).mkString("\n") diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 36bbaaa3f1c85..b86cfbfa48fbe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -634,7 +634,7 @@ class DAGScheduler( val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { - taskContext.executeOnCompleteCallbacks() + taskContext.markTaskCompleted() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 5878e733908f5..94944399b134a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -24,8 +24,8 @@ import org.apache.spark.metrics.source.Source private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: SparkContext) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.DAGScheduler".format(sc.appName) + override val metricRegistry = new MetricRegistry() + override val sourceName = "%s.DAGScheduler".format(sc.appName) metricRegistry.register(MetricRegistry.name("stage", "failedStages"), new Gauge[Int] { override def getValue: Int = dagScheduler.failedStages.size diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 406147f167bf3..7378ce923f0ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -127,6 +127,8 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) override def onApplicationEnd(event: SparkListenerApplicationEnd) = logEvent(event, flushLogger = true) + // No-op because logging every update would be overkill + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { } /** * Stop logging events. diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index d09fd7aa57642..2ccbd8edeb028 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -61,7 +61,7 @@ private[spark] class ResultTask[T, U]( try { func(context, rdd.iterator(partition, context)) } finally { - context.executeOnCompleteCallbacks() + context.markTaskCompleted() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 11255c07469d4..381eff2147e95 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -74,7 +74,7 @@ private[spark] class ShuffleMapTask( } throw e } finally { - context.executeOnCompleteCallbacks() + context.markTaskCompleted() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index cbe0bc0bcb0a5..6aa0cca06878d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -87,7 +87,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex def kill(interruptThread: Boolean) { _killed = true if (context != null) { - context.interrupted = true + context.markInterrupted() } if (interruptThread && taskThread != null) { taskThread.interrupt() diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 34bc3124097bb..554a33ce7f1a6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -63,8 +63,11 @@ extends DeserializationStream { def close() { objIn.close() } } -private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance { - def serialize[T: ClassTag](t: T): ByteBuffer = { + +private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader) + extends SerializerInstance { + + override def serialize[T: ClassTag](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) out.writeObject(t) @@ -72,23 +75,23 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize ByteBuffer.wrap(bos.toByteArray) } - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val bis = new ByteBufferInputStream(bytes) val in = deserializeStream(bis) - in.readObject().asInstanceOf[T] + in.readObject() } - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { val bis = new ByteBufferInputStream(bytes) val in = deserializeStream(bis, loader) - in.readObject().asInstanceOf[T] + in.readObject() } - def serializeStream(s: OutputStream): SerializationStream = { + override def serializeStream(s: OutputStream): SerializationStream = { new JavaSerializationStream(s, counterReset) } - def deserializeStream(s: InputStream): DeserializationStream = { + override def deserializeStream(s: InputStream): DeserializationStream = { new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader) } @@ -109,7 +112,10 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) - def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset) + override def newInstance(): SerializerInstance = { + val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) + new JavaSerializerInstance(counterReset, classLoader) + } override def writeExternal(out: ObjectOutput) { out.writeInt(counterReset) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 85944eabcfefc..87ef9bb0b43c6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf) val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() kryo.setRegistrationRequired(registrationRequired) - val classLoader = Thread.currentThread.getContextClassLoader + + val oldClassLoader = Thread.currentThread.getContextClassLoader + val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. // Do this before we invoke the user registrator so the user registrator can override this. @@ -84,10 +86,15 @@ class KryoSerializer(conf: SparkConf) try { val reg = Class.forName(regCls, true, classLoader).newInstance() .asInstanceOf[KryoRegistrator] + + // Use the default classloader when calling the user registrator. + Thread.currentThread.setContextClassLoader(classLoader) reg.registerClasses(kryo) } catch { - case e: Exception => + case e: Exception => throw new SparkException(s"Failed to invoke $regCls", e) + } finally { + Thread.currentThread.setContextClassLoader(oldClassLoader) } } @@ -99,7 +106,7 @@ class KryoSerializer(conf: SparkConf) kryo } - def newInstance(): SerializerInstance = { + override def newInstance(): SerializerInstance = { new KryoSerializerInstance(this) } } @@ -108,20 +115,20 @@ private[spark] class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { val output = new KryoOutput(outStream) - def writeObject[T: ClassTag](t: T): SerializationStream = { + override def writeObject[T: ClassTag](t: T): SerializationStream = { kryo.writeClassAndObject(output, t) this } - def flush() { output.flush() } - def close() { output.close() } + override def flush() { output.flush() } + override def close() { output.close() } } private[spark] class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - val input = new KryoInput(inStream) + private val input = new KryoInput(inStream) - def readObject[T: ClassTag](): T = { + override def readObject[T: ClassTag](): T = { try { kryo.readClassAndObject(input).asInstanceOf[T] } catch { @@ -131,31 +138,31 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser } } - def close() { + override def close() { // Kryo's Input automatically closes the input stream it is using. input.close() } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - val kryo = ks.newKryo() + private val kryo = ks.newKryo() // Make these lazy vals to avoid creating a buffer unless we use them - lazy val output = ks.newKryoOutput() - lazy val input = new KryoInput() + private lazy val output = ks.newKryoOutput() + private lazy val input = new KryoInput() - def serialize[T: ClassTag](t: T): ByteBuffer = { + override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() kryo.writeClassAndObject(output, t) ByteBuffer.wrap(output.toBytes) } - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { input.setBuffer(bytes.array) kryo.readClassAndObject(input).asInstanceOf[T] } - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { val oldClassLoader = kryo.getClassLoader kryo.setClassLoader(loader) input.setBuffer(bytes.array) @@ -164,11 +171,11 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ obj } - def serializeStream(s: OutputStream): SerializationStream = { + override def serializeStream(s: OutputStream): SerializationStream = { new KryoSerializationStream(kryo, s) } - def deserializeStream(s: InputStream): DeserializationStream = { + override def deserializeStream(s: InputStream): DeserializationStream = { new KryoDeserializationStream(kryo, s) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index f2f5cea469c61..a9144cdd97b8c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -43,11 +43,30 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator} * They are intended to be used to serialize/de-serialize data within a single Spark application. */ @DeveloperApi -trait Serializer { +abstract class Serializer { + + /** + * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should + * make sure it is using this when set. + */ + @volatile protected var defaultClassLoader: Option[ClassLoader] = None + + /** + * Sets a class loader for the serializer to use in deserialization. + * + * @return this Serializer object + */ + def setDefaultClassLoader(classLoader: ClassLoader): Serializer = { + defaultClassLoader = Some(classLoader) + this + } + + /** Creates a new [[SerializerInstance]]. */ def newInstance(): SerializerInstance } +@DeveloperApi object Serializer { def getSerializer(serializer: Serializer): Serializer = { if (serializer == null) SparkEnv.get.serializer else serializer @@ -64,7 +83,7 @@ object Serializer { * An instance of a serializer, for use by one thread at a time. */ @DeveloperApi -trait SerializerInstance { +abstract class SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer def deserialize[T: ClassTag](bytes: ByteBuffer): T @@ -74,21 +93,6 @@ trait SerializerInstance { def serializeStream(s: OutputStream): SerializationStream def deserializeStream(s: InputStream): DeserializationStream - - def serializeMany[T: ClassTag](iterator: Iterator[T]): ByteBuffer = { - // Default implementation uses serializeStream - val stream = new ByteArrayOutputStream() - serializeStream(stream).writeAll(iterator) - val buffer = ByteBuffer.wrap(stream.toByteArray) - buffer.flip() - buffer - } - - def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { - // Default implementation uses deserializeStream - buffer.rewind() - deserializeStream(new ByteBufferInputStream(buffer)).asIterator - } } /** @@ -96,7 +100,7 @@ trait SerializerInstance { * A stream for writing serialized objects. */ @DeveloperApi -trait SerializationStream { +abstract class SerializationStream { def writeObject[T: ClassTag](t: T): SerializationStream def flush(): Unit def close(): Unit @@ -115,7 +119,7 @@ trait SerializationStream { * A stream for reading serialized objects. */ @DeveloperApi -trait DeserializationStream { +abstract class DeserializationStream { def readObject[T: ClassTag](): T def close(): Unit diff --git a/core/src/main/scala/org/apache/spark/serializer/package-info.java b/core/src/main/scala/org/apache/spark/serializer/package-info.java index 4c0b73ab36a00..207c6e02e4293 100644 --- a/core/src/main/scala/org/apache/spark/serializer/package-info.java +++ b/core/src/main/scala/org/apache/spark/serializer/package-info.java @@ -18,4 +18,4 @@ /** * Pluggable serializers for RDD and shuffle data. */ -package org.apache.spark.serializer; \ No newline at end of file +package org.apache.spark.serializer; diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e8bbd298c631a..e4c3d58905e7f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -33,6 +33,7 @@ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.util._ private[spark] sealed trait BlockValues @@ -57,11 +58,12 @@ private[spark] class BlockManager( maxMemory: Long, val conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker) + mapOutputTracker: MapOutputTracker, + shuffleManager: ShuffleManager) extends Logging { private val port = conf.getInt("spark.blockManager.port", 0) - val shuffleBlockManager = new ShuffleBlockManager(this) + val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) val connectionManager = @@ -142,9 +144,10 @@ private[spark] class BlockManager( serializer: Serializer, conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker) = { + mapOutputTracker: MapOutputTracker, + shuffleManager: ShuffleManager) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, securityManager, mapOutputTracker) + conf, securityManager, mapOutputTracker, shuffleManager) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 3f14c40ec61cb..49fea6d9e2a76 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -24,8 +24,8 @@ import org.apache.spark.metrics.source.Source private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: SparkContext) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.BlockManager".format(sc.appName) + override val metricRegistry = new MetricRegistry() + override val sourceName = "%s.BlockManager".format(sc.appName) metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { override def getValue: Long = { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 3565719b54545..b8f5d3a5b02aa 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -25,6 +25,7 @@ import scala.collection.JavaConversions._ import org.apache.spark.Logging import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} @@ -62,7 +63,8 @@ private[spark] trait ShuffleWriterGroup { */ // TODO: Factor this into a separate class for each ShuffleManager implementation private[spark] -class ShuffleBlockManager(blockManager: BlockManager) extends Logging { +class ShuffleBlockManager(blockManager: BlockManager, + shuffleManager: ShuffleManager) extends Logging { def conf = blockManager.conf // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. @@ -71,8 +73,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { conf.getBoolean("spark.shuffle.consolidateFiles", false) // Are we using sort-based shuffle? - val sortBasedShuffle = - conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName + val sortBasedShuffle = shuffleManager.isInstanceOf[SortShuffleManager] private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 75c2e09a6bbb8..aa83ea90ee9ee 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.util.concurrent.ArrayBlockingQueue import akka.actor._ +import org.apache.spark.shuffle.hash.HashShuffleManager import util.Random import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} @@ -101,7 +102,7 @@ private[spark] object ThreadingTest { conf) val blockManager = new BlockManager( "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new SecurityManager(conf), new MapOutputTrackerMaster(conf)) + new SecurityManager(conf), new MapOutputTrackerMaster(conf), new HashShuffleManager(conf)) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 29e9cf947856f..6b4689291097f 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -93,7 +93,7 @@ private[spark] object JettyUtils extends Logging { def createServletHandler( path: String, servlet: HttpServlet, - basePath: String = ""): ServletContextHandler = { + basePath: String): ServletContextHandler = { val prefixedPath = attachPrefix(basePath, path) val contextHandler = new ServletContextHandler val holder = new ServletHolder(servlet) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index a3e9566832d06..74cd637d88155 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -200,6 +200,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.shuffleReadBytes += shuffleReadDelta execSummary.shuffleRead += shuffleReadDelta + val inputBytesDelta = + (taskMetrics.inputMetrics.map(_.bytesRead).getOrElse(0L) + - oldMetrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L)) + stageData.inputBytes += inputBytesDelta + execSummary.inputBytes += inputBytesDelta + val diskSpillDelta = taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) stageData.diskBytesSpilled += diskSpillDelta diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 6f8eb1ee12634..1e18ec688c40d 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -72,8 +72,9 @@ private[spark] object JsonProtocol { case applicationEnd: SparkListenerApplicationEnd => applicationEndToJson(applicationEnd) - // Not used, but keeps compiler happy + // These aren't used, but keeps compiler happy case SparkListenerShutdown => JNothing + case SparkListenerExecutorMetricsUpdate(_, _) => JNothing } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala similarity index 66% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala rename to core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala index 2deaf4ae8dcab..c1b8bf052c0ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala +++ b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala @@ -15,14 +15,19 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.model +package org.apache.spark.util + +import java.util.EventListener + +import org.apache.spark.TaskContext +import org.apache.spark.annotation.DeveloperApi /** - * Filter specifying a split and type of comparison to be applied on features - * @param split split specifying the feature index, type and threshold - * @param comparison integer specifying <,=,> + * :: DeveloperApi :: + * + * Listener providing a callback function to invoke when a task's execution completes. */ -private[tree] case class Filter(split: Split, comparison: Int) { - // Comparison -1,0,1 signifies <.=,> - override def toString = " split = " + split + "comparison = " + comparison +@DeveloperApi +trait TaskCompletionListener extends EventListener { + def onTaskCompletion(context: TaskContext) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 8cac5da644fa9..019f68b160894 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -146,6 +146,9 @@ private[spark] object Utils extends Logging { Try { Class.forName(clazz, false, getContextOrSparkClassLoader) }.isSuccess } + /** Preferred alternative to Class.forName(className) */ + def classForName(className: String) = Class.forName(className, true, getContextOrSparkClassLoader) + /** * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. */ diff --git a/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java b/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java new file mode 100644 index 0000000000000..3d50ab4fabe42 --- /dev/null +++ b/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer; + +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import scala.Option; +import scala.reflect.ClassTag; + + +/** + * A simple Serializer implementation to make sure the API is Java-friendly. + */ +class TestJavaSerializerImpl extends Serializer { + + @Override + public SerializerInstance newInstance() { + return null; + } + + static class SerializerInstanceImpl extends SerializerInstance { + @Override + public ByteBuffer serialize(T t, ClassTag evidence$1) { + return null; + } + + @Override + public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag evidence$1) { + return null; + } + + @Override + public T deserialize(ByteBuffer bytes, ClassTag evidence$1) { + return null; + } + + @Override + public SerializationStream serializeStream(OutputStream s) { + return null; + } + + @Override + public DeserializationStream deserializeStream(InputStream s) { + return null; + } + } + + static class SerializationStreamImpl extends SerializationStream { + + @Override + public SerializationStream writeObject(T t, ClassTag evidence$1) { + return null; + } + + @Override + public void flush() { + + } + + @Override + public void close() { + + } + } + + static class DeserializationStreamImpl extends DeserializationStream { + + @Override + public T readObject(ClassTag evidence$1) { + return null; + } + + @Override + public void close() { + + } + } +} diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java new file mode 100644 index 0000000000000..af34cdb03e4d1 --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util; + +import org.apache.spark.TaskContext; + + +/** + * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and + * TaskContext is Java friendly. + */ +public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { + + @Override + public void onTaskCompletion(TaskContext context) { + context.isCompleted(); + context.isInterrupted(); + context.stageId(); + context.partitionId(); + context.runningLocally(); + context.taskMetrics(); + context.addTaskCompletionListener(this); + } +} diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 846537df003df..e2f4d4c57cdb5 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -19,14 +19,19 @@ package org.apache.spark.network import java.io.IOException import java.nio._ +import java.util.concurrent.TimeoutException import org.apache.spark.{SecurityManager, SparkConf} import org.scalatest.FunSuite +import org.mockito.Mockito._ +import org.mockito.Matchers._ + +import scala.concurrent.TimeoutException import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration._ import scala.language.postfixOps -import scala.util.Try +import scala.util.{Failure, Success, Try} /** * Test the ConnectionManager with various security settings. @@ -255,5 +260,42 @@ class ConnectionManagerSuite extends FunSuite { } + test("sendMessageReliably timeout") { + val clientConf = new SparkConf + clientConf.set("spark.authenticate", "false") + val ackTimeout = 30 + clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeout}") + + val clientSecurityManager = new SecurityManager(clientConf) + val manager = new ConnectionManager(0, clientConf, clientSecurityManager) + + val serverConf = new SparkConf + serverConf.set("spark.authenticate", "false") + val serverSecurityManager = new SecurityManager(serverConf) + val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager) + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + // sleep 60 sec > ack timeout for simulating server slow down or hang up + Thread.sleep(ackTimeout * 3 * 1000) + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + + val future = manager.sendMessageReliably(managerServer.id, bufferMessage) + + // Future should throw IOException in 30 sec. + // Otherwise TimeoutExcepton is thrown from Await.result. + // We expect TimeoutException is not thrown. + intercept[IOException] { + Await.result(future, (ackTimeout * 2) second) + } + + manager.stop() + managerServer.stop() + } + } diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index a822bd18bfdbd..f89bdb6e07dea 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -245,6 +245,29 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } + test("WorksWithoutBucketsForLargerDatasets") { + // Verify the case of slighly larger datasets + val rdd = sc.parallelize(6 to 99) + val (histogramBuckets, histogramResults) = rdd.histogram(8) + val expectedHistogramResults = + Array(12, 12, 11, 12, 12, 11, 12, 12) + val expectedHistogramBuckets = + Array(6.0, 17.625, 29.25, 40.875, 52.5, 64.125, 75.75, 87.375, 99.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + test("WorksWithoutBucketsWithIrrationalBucketEdges") { + // Verify the case of buckets with irrational edges. See #SPARK-2862. + val rdd = sc.parallelize(6 to 99) + val (histogramBuckets, histogramResults) = rdd.histogram(9) + val expectedHistogramResults = + Array(11, 10, 11, 10, 10, 11, 10, 10, 11) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets(0) === 6.0) + assert(histogramBuckets(9) === 99.0) + } + // Test the failure mode with an invalid RDD test("ThrowsExceptionOnInvalidRDDs") { // infinity diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 270f7e661045a..db2ad829a48f9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -32,7 +32,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte val rdd = new RDD[String](sc, List()) { override def getPartitions = Array[Partition](StubPartition(0)) override def compute(split: Partition, context: TaskContext) = { - context.addOnCompleteCallback(() => TaskContextSuite.completed = true) + context.addTaskCompletionListener(context => TaskContextSuite.completed = true) sys.error("failed") } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala new file mode 100644 index 0000000000000..11e8c9c4cb37f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import org.apache.spark.util.Utils + +import com.esotericsoftware.kryo.Kryo +import org.scalatest.FunSuite + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils} +import org.apache.spark.SparkContext._ +import org.apache.spark.serializer.KryoDistributedTest._ + +class KryoSerializerDistributedSuite extends FunSuite { + + test("kryo objects are serialised consistently in different processes") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) + conf.set("spark.task.maxFailures", "1") + + val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) + conf.setJars(List(jar.getPath)) + + val sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + val original = Thread.currentThread.getContextClassLoader + val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) + SparkEnv.get.serializer.setDefaultClassLoader(loader) + + val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache() + + // Randomly mix the keys so that the join below will require a shuffle with each partition + // sending data to multiple other partitions. + val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o)} + + // Join the two RDDs, and force evaluation + assert(shuffledRDD.join(cachedRDD).collect().size == 1) + + LocalSparkContext.stop(sc) + } +} + +object KryoDistributedTest { + class MyCustomClass + + class AppJarRegistrator extends KryoRegistrator { + override def registerClasses(k: Kryo) { + val classLoader = Thread.currentThread.getContextClassLoader + k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader)) + } + } + + object AppJarRegistrator { + val customClassName = "KryoSerializerDistributedSuiteCustomClass" + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala new file mode 100644 index 0000000000000..967c9e9899c9d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.LocalSparkContext +import org.apache.spark.SparkException + + +class KryoSerializerResizableOutputSuite extends FunSuite { + + // trial and error showed this will not serialize with 1mb buffer + val x = (1 to 400000).toArray + + test("kryo without resizable output buffer should fail on large array") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer.max.mb", "1") + val sc = new SparkContext("local", "test", conf) + intercept[SparkException](sc.parallelize(x).collect()) + LocalSparkContext.stop(sc) + } + + test("kryo with resizable output buffer should succeed on large array") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer.max.mb", "2") + val sc = new SparkContext("local", "test", conf) + assert(sc.parallelize(x).collect() === x) + LocalSparkContext.stop(sc) + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 3bf9efebb39d2..e1e35b688d581 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -23,9 +23,10 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import org.scalatest.FunSuite -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SparkConf, SharedSparkContext} import org.apache.spark.serializer.KryoTest._ + class KryoSerializerSuite extends FunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) @@ -207,7 +208,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x assert(10 + control.sum === result) } - + test("kryo with nonexistent custom registrator should fail") { import org.apache.spark.{SparkConf, SparkException} @@ -217,39 +218,33 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance()) assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist")) } -} -class KryoSerializerResizableOutputSuite extends FunSuite { - import org.apache.spark.SparkConf - import org.apache.spark.SparkContext - import org.apache.spark.LocalSparkContext - import org.apache.spark.SparkException + test("default class loader can be set by a different thread") { + val ser = new KryoSerializer(new SparkConf) - // trial and error showed this will not serialize with 1mb buffer - val x = (1 to 400000).toArray + // First serialize the object + val serInstance = ser.newInstance() + val bytes = serInstance.serialize(new ClassLoaderTestingObject) - test("kryo without resizable output buffer should fail on large array") { - val conf = new SparkConf(false) - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "1") - val sc = new SparkContext("local", "test", conf) - intercept[SparkException](sc.parallelize(x).collect) - LocalSparkContext.stop(sc) - } + // Deserialize the object to make sure normal deserialization works + serInstance.deserialize[ClassLoaderTestingObject](bytes) - test("kryo with resizable output buffer should succeed on large array") { - val conf = new SparkConf(false) - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "2") - val sc = new SparkContext("local", "test", conf) - assert(sc.parallelize(x).collect === x) - LocalSparkContext.stop(sc) + // Set a special, broken ClassLoader and make sure we get an exception on deserialization + ser.setDefaultClassLoader(new ClassLoader() { + override def loadClass(name: String) = throw new UnsupportedOperationException + }) + intercept[UnsupportedOperationException] { + ser.newInstance().deserialize[ClassLoaderTestingObject](bytes) + } } } + +class ClassLoaderTestingObject + + object KryoTest { + case class CaseClass(i: Int, s: String) {} class ClassWithNoArgConstructor { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 94bb2c445d2e9..20bac66105a69 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -24,6 +24,7 @@ import java.util.concurrent.TimeUnit import akka.actor._ import akka.pattern.ask import akka.util.Timeout +import org.apache.spark.shuffle.hash.HashShuffleManager import org.mockito.invocation.InvocationOnMock import org.mockito.Matchers.any @@ -61,6 +62,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) + val shuffleManager = new HashShuffleManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") @@ -71,8 +73,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = { - new BlockManager( - name, actorSystem, master, serializer, maxMem, conf, securityMgr, mapOutputTracker) + new BlockManager(name, actorSystem, master, serializer, maxMem, conf, securityMgr, + mapOutputTracker, shuffleManager) } before { @@ -791,7 +793,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("block store put failure") { // Use Java serializer so we can create an unserializable error. store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, mapOutputTracker, shuffleManager) // The put should fail since a1 is not serializable. class UnserializableClass @@ -1007,7 +1009,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, mapOutputTracker, shuffleManager) val worker = spy(new BlockManagerWorker(store)) val connManagerId = mock(classOf[ConnectionManagerId]) @@ -1054,7 +1056,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, mapOutputTracker, shuffleManager) val worker = spy(new BlockManagerWorker(store)) val connManagerId = mock(classOf[ConnectionManagerId]) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index b8299e2ea187f..777579bc570db 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.storage import java.io.{File, FileWriter} +import org.apache.spark.shuffle.hash.HashShuffleManager + import scala.collection.mutable import scala.language.reflectiveCalls @@ -42,7 +44,9 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before // so we coerce consolidation if not already enabled. testConf.set("spark.shuffle.consolidateFiles", "true") - val shuffleBlockManager = new ShuffleBlockManager(null) { + private val shuffleManager = new HashShuffleManager(testConf.clone) + + val shuffleBlockManager = new ShuffleBlockManager(null, shuffleManager) { override def conf = testConf.clone var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]() override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id) @@ -148,7 +152,7 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))), confCopy) val store = new BlockManager("", actorSystem, master , serializer, confCopy, - securityManager, null) + securityManager, null, shuffleManager) try { diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index f5ba31c309277..147ec0bc52e39 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.{LocalSparkContext, SparkConf, Success} -import org.apache.spark.executor.{ShuffleWriteMetrics, ShuffleReadMetrics, TaskMetrics} +import org.apache.spark.executor._ import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -150,6 +150,9 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.executorRunTime = base + 4 taskMetrics.diskBytesSpilled = base + 5 taskMetrics.memoryBytesSpilled = base + 6 + val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + taskMetrics.inputMetrics = Some(inputMetrics) + inputMetrics.bytesRead = base + 7 taskMetrics } @@ -182,6 +185,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc assert(stage1Data.diskBytesSpilled == 205) assert(stage0Data.memoryBytesSpilled == 112) assert(stage1Data.memoryBytesSpilled == 206) + assert(stage0Data.inputBytes == 114) + assert(stage1Data.inputBytes == 207) assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get .totalBlocksFetched == 2) assert(stage0Data.taskData.get(1235L).get.taskMetrics.get.shuffleReadMetrics.get @@ -208,6 +213,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc assert(stage1Data.diskBytesSpilled == 610) assert(stage0Data.memoryBytesSpilled == 412) assert(stage1Data.memoryBytesSpilled == 612) + assert(stage0Data.inputBytes == 414) + assert(stage1Data.inputBytes == 614) assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get .totalBlocksFetched == 302) assert(stage1Data.taskData.get(1237L).get.taskMetrics.get.shuffleReadMetrics.get diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 1867cf4ec46ca..28f26d2368254 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -117,12 +117,13 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" -make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" +make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & +make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & make_binary_release "hadoop2" \ - "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" + "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & make_binary_release "hadoop2-without-hive" \ - "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" + "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & +wait # Copy data echo "Copying release tarballs" diff --git a/dev/mima b/dev/mima index 4c3e65039b160..09e4482af5f3d 100755 --- a/dev/mima +++ b/dev/mima @@ -26,7 +26,9 @@ cd "$FWDIR" echo -e "q\n" | sbt/sbt oldDeps/update -export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) -printf "%p:" ` +export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"` +echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" + ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" ret_val=$? diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3076eb847b420..31506e28e05af 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -19,67 +19,148 @@ # Wrapper script that runs the Spark tests then reports QA results # to github via its API. +# Environment variables are populated by the code here: +#+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 # Go to the Spark project root directory FWDIR="$(cd `dirname $0`/..; pwd)" cd "$FWDIR" COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" +PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId" -function post_message { - message=$1 - data="{\"body\": \"$message\"}" - echo "Attempting to post to Github:" - echo "$data" +COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" +# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( +SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" - curl -D- -u x-oauth-basic:$GITHUB_OAUTH_KEY -X POST --data "$data" -H \ - "Content-Type: application/json" \ - $COMMENTS_URL | head -n 8 +# NOTE: Jenkins will kill the whole build after 120 minutes. +# Tests are a large part of that, but not all of it. +TESTS_TIMEOUT="120m" + +function post_message () { + local message=$1 + local data="{\"body\": \"$message\"}" + local HTTP_CODE_HEADER="HTTP Response Code: " + + echo "Attempting to post to Github..." + + local curl_output=$( + curl `#--dump-header -` \ + --silent \ + --user x-oauth-basic:$GITHUB_OAUTH_KEY \ + --request POST \ + --data "$data" \ + --write-out "${HTTP_CODE_HEADER}%{http_code}\n" \ + --header "Content-Type: application/json" \ + "$COMMENTS_URL" #> /dev/null #| "$FWDIR/dev/jq" .id #| head -n 8 + ) + local curl_status=${PIPESTATUS[0]} + + if [ "$curl_status" -ne 0 ]; then + echo "Failed to post message to GitHub." >&2 + echo " > curl_status: ${curl_status}" >&2 + echo " > curl_output: ${curl_output}" >&2 + echo " > data: ${data}" >&2 + # exit $curl_status + fi + + local api_response=$( + echo "${curl_output}" \ + | grep -v -e "^${HTTP_CODE_HEADER}" + ) + + local http_code=$( + echo "${curl_output}" \ + | grep -e "^${HTTP_CODE_HEADER}" \ + | sed -r -e "s/^${HTTP_CODE_HEADER}//g" + ) + + if [ -n "$http_code" ] && [ "$http_code" -ne "201" ]; then + echo " > http_code: ${http_code}." >&2 + echo " > api_response: ${api_response}" >&2 + echo " > data: ${data}" >&2 + fi + + if [ "$curl_status" -eq 0 ] && [ "$http_code" -eq "201" ]; then + echo " > Post successful." + fi } -start_message="QA tests have started for PR $ghprbPullId." -if [ "$sha1" == "$ghprbActualCommit" ]; then - start_message="$start_message This patch DID NOT merge cleanly! " -else - start_message="$start_message This patch merges cleanly. " -fi -start_message="$start_message
View progress: " -start_message="$start_message${BUILD_URL}consoleFull" - -post_message "$start_message" - -./dev/run-tests -test_result="$?" - -result_message="QA results for PR $ghprbPullId:
" - -if [ "$test_result" -eq "0" ]; then - result_message="$result_message- This patch PASSES unit tests.
" -else - result_message="$result_message- This patch FAILED unit tests.
" -fi - -if [ "$sha1" != "$ghprbActualCommit" ]; then - result_message="$result_message- This patch merges cleanly
" - non_test_files=$(git diff master --name-only | grep -v "\/test" | tr "\n" " ") - new_public_classes=$(git diff master $non_test_files \ - | grep -e "trait " -e "class " \ - | grep -e "{" -e "(" \ - | grep -v -e \@\@ -e private \ - | grep \+ \ - | sed "s/\+ *//" \ - | tr "\n" "~" \ - | sed "s/~/
/g") - if [ "$new_public_classes" == "" ]; then - result_message="$result_message- This patch adds no public classes
" +# check PR merge-ability and check for new public classes +{ + if [ "$sha1" == "$ghprbActualCommit" ]; then + merge_note=" * This patch **does not** merge cleanly!" else - result_message="$result_message- This patch adds the following public classes (experimental):
" - result_message="$result_message$new_public_classes" + merge_note=" * This patch merges cleanly." + + non_test_files=$(git diff master --name-only | grep -v "\/test" | tr "\n" " ") + new_public_classes=$( + git diff master ${non_test_files} `# diff this patch against master and...` \ + | grep "^\+" `# filter in only added lines` \ + | sed -r -e "s/^\+//g" `# remove the leading +` \ + | grep -e "trait " -e "class " `# filter in lines with these key words` \ + | grep -e "{" -e "(" `# filter in lines with these key words, too` \ + | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ + | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ + | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ + | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ + | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ + | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ + | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ + | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ + | tr -d "\n" `# remove actual LF characters` + ) + + if [ "$new_public_classes" == "" ]; then + public_classes_note=" * This patch adds no public classes." + else + public_classes_note=" * This patch adds the following public classes _(experimental)_:" + public_classes_note="${public_classes_note}\n${new_public_classes}" + fi fi -fi -result_message="${result_message}
For more information see test ouptut:" -result_message="${result_message}
${BUILD_URL}consoleFull" +} -post_message "$result_message" +# post start message +{ + start_message="\ + [QA tests have started](${BUILD_URL}consoleFull) for \ + PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." + + start_message="${start_message}\n${merge_note}" + # start_message="${start_message}\n${public_classes_note}" + + post_message "$start_message" +} + +# run tests +{ + timeout "${TESTS_TIMEOUT}" ./dev/run-tests + test_result="$?" + + if [ "$test_result" -eq "124" ]; then + fail_message="**Tests timed out** after a configured wait of \`${TESTS_TIMEOUT}\`." + post_message "$fail_message" + exit $test_result + else + if [ "$test_result" -eq "0" ]; then + test_result_note=" * This patch **passes** unit tests." + else + test_result_note=" * This patch **fails** unit tests." + fi + fi +} + +# post end message +{ + result_message="\ + [QA tests have finished](${BUILD_URL}consoleFull) for \ + PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." + + result_message="${result_message}\n${test_result_note}" + result_message="${result_message}\n${merge_note}" + result_message="${result_message}\n${public_classes_note}" + + post_message "$result_message" +} exit $test_result diff --git a/docs/configuration.md b/docs/configuration.md index c408c468dcd94..981170d8b49b7 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -884,6 +884,15 @@ Apart from these, the following properties are also available, and may be useful out and giving up. + + spark.core.connection.ack.wait.timeout + 60 + + Number of seconds for the connection to wait for ack to occur before timing + out and giving up. To avoid unwilling timeout caused by long pause like GC, + you can set larger value. + + spark.ui.filters None diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 21453cb9cd8c9..4b3cb715c58c7 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -9,4 +9,65 @@ displayTitle: MLlib - Feature Extraction ## Word2Vec -## TFIDF +Word2Vec computes distributed vector representation of words. The main advantage of the distributed +representations is that similar words are close in the vector space, which makes generalization to +novel patterns easier and model estimation more robust. Distributed vector representation is +showed to be useful in many natural language processing applications such as named entity +recognition, disambiguation, parsing, tagging and machine translation. + +### Model + +In our implementation of Word2Vec, we used skip-gram model. The training objective of skip-gram is +to learn word vector representations that are good at predicting its context in the same sentence. +Mathematically, given a sequence of training words `$w_1, w_2, \dots, w_T$`, the objective of the +skip-gram model is to maximize the average log-likelihood +`\[ +\frac{1}{T} \sum_{t = 1}^{T}\sum_{j=-k}^{j=k} \log p(w_{t+j} | w_t) +\]` +where $k$ is the size of the training window. + +In the skip-gram model, every word $w$ is associated with two vectors $u_w$ and $v_w$ which are +vector representations of $w$ as word and context respectively. The probability of correctly +predicting word $w_i$ given word $w_j$ is determined by the softmax model, which is +`\[ +p(w_i | w_j ) = \frac{\exp(u_{w_i}^{\top}v_{w_j})}{\sum_{l=1}^{V} \exp(u_l^{\top}v_{w_j})} +\]` +where $V$ is the vocabulary size. + +The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ +is proportional to $V$, which can be easily in order of millions. To speed up training of Word2Vec, +we used hierarchical softmax, which reduced the complexity of computing of $\log p(w_i | w_j)$ to +$O(\log(V))$ + +### Example + +The example below demonstrates how to load a text file, parse it as an RDD of `Seq[String]`, +construct a `Word2Vec` instance and then fit a `Word2VecModel` with the input data. Finally, +we display the top 40 synonyms of the specified word. To run the example, first download +the [text8](http://mattmahoney.net/dc/text8.zip) data and extract it to your preferred directory. +Here we assume the extracted file is `text8` and in same directory as you run the spark shell. + +
+
+{% highlight scala %} +import org.apache.spark._ +import org.apache.spark.rdd._ +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.feature.Word2Vec + +val input = sc.textFile("text8").map(line => line.split(" ").toSeq) + +val word2vec = new Word2Vec() + +val model = word2vec.fit(input) + +val synonyms = model.findSynonyms("china", 40) + +for((synonym, cosineSimilarity) <- synonyms) { + println(s"$synonym $cosineSimilarity") +} +{% endhighlight %} +
+
+ +## TFIDF \ No newline at end of file diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cd6543945c385..34accade36ea9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -605,6 +605,11 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may also use the beeline script comes with Hive. +To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, +users can set the `spark.sql.thriftserver.scheduler.pool` variable: + + SET spark.sql.thriftserver.scheduler.pool=accounting; + ### Migration Guide for Shark Users #### Reducer number diff --git a/docs/streaming-kinesis.md b/docs/streaming-kinesis.md index 801c905c88df8..16ad3222105a2 100644 --- a/docs/streaming-kinesis.md +++ b/docs/streaming-kinesis.md @@ -3,56 +3,57 @@ layout: global title: Spark Streaming Kinesis Receiver --- -### Kinesis -Build notes: -
  • Spark supports a Kinesis Streaming Receiver which is not included in the default build due to licensing restrictions.
  • -
  • _**Note that by embedding this library you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • -
  • The Spark Kinesis Streaming Receiver source code, examples, tests, and artifacts live in $SPARK_HOME/extras/kinesis-asl.
  • -
  • To build with Kinesis, you must run the maven or sbt builds with -Pkinesis-asl`.
  • -
  • Applications will need to link to the 'spark-streaming-kinesis-asl` artifact.
  • +## Kinesis +###Design +
  • The KinesisReceiver uses the Kinesis Client Library (KCL) provided by Amazon under the Amazon Software License.
  • +
  • The KCL builds on top of the Apache 2.0 licensed AWS Java SDK and provides load-balancing, fault-tolerance, checkpointing through the concept of Workers, Checkpoints, and Shard Leases.
  • +
  • The KCL uses DynamoDB to maintain all state. A DynamoDB table is created in the us-east-1 region (regardless of Kinesis stream region) during KCL initialization for each Kinesis application name.
  • +
  • A single KinesisReceiver can process many shards of a stream by spinning up multiple KinesisRecordProcessor threads.
  • +
  • You never need more KinesisReceivers than the number of shards in your stream as each will spin up at least one KinesisRecordProcessor thread.
  • +
  • Horizontal scaling is achieved by autoscaling additional KinesisReceiver (separate processes) or spinning up new KinesisRecordProcessor threads within each KinesisReceiver - up to the number of current shards for a given stream, of course. Don't forget to autoscale back down!
  • -Kinesis examples notes: -
  • To build the Kinesis examples, you must run the maven or sbt builds with -Pkinesis-asl`.
  • -
  • These examples automatically determine the number of local threads and KinesisReceivers to spin up based on the number of shards for the stream.
  • -
  • KinesisWordCountProducerASL will generate random data to put onto the Kinesis stream for testing.
  • -
  • Checkpointing is disabled (no checkpoint dir is set). The examples as written will not recover from a driver failure.
  • +### Build +
  • Spark supports a Streaming KinesisReceiver, but it is not included in the default build due to Amazon Software Licensing (ASL) restrictions.
  • +
  • To build with the Kinesis Streaming Receiver and supporting ASL-licensed code, you must run the maven or sbt builds with the **-Pkinesis-asl** profile.
  • +
  • All KinesisReceiver-related code, examples, tests, and artifacts live in **$SPARK_HOME/extras/kinesis-asl/**.
  • +
  • Kinesis-based Spark Applications will need to link to the **spark-streaming-kinesis-asl** artifact that is built when **-Pkinesis-asl** is specified.
  • +
  • _**Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • -Deployment and runtime notes: -
  • A single KinesisReceiver can process many shards of a stream.
  • -
  • Each shard of a stream is processed by one or more KinesisReceiver's managed by the Kinesis Client Library (KCL) Worker.
  • -
  • You never need more KinesisReceivers than the number of shards in your stream.
  • -
  • You can horizontally scale the receiving by creating more KinesisReceiver/DStreams (up to the number of shards for a given stream)
  • -
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the Kinesis Client Library.
  • -
  • This code uses the DefaultAWSCredentialsProviderChain and searches for credentials in the following order of precedence:
    - 1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    - 2) Java System Properties - aws.accessKeyId and aws.secretKey
    - 3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    - 4) Instance profile credentials - delivered through the Amazon EC2 metadata service
    -
  • -
  • You need to setup a Kinesis stream with 1 or more shards per the following:
    - http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • -
  • Valid Kinesis endpoint urls can be found here: Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • -
  • When you first start up the KinesisReceiver, the Kinesis Client Library (KCL) needs ~30s to establish connectivity with the AWS Kinesis service, -retrieve any checkpoint data, and negotiate with other KCL's reading from the same stream.
  • -
  • Be careful when changing the app name. Kinesis maintains a mapping table in DynamoDB based on this app name (http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization). -Changing the app name could lead to Kinesis errors as only 1 logical application can process a stream. In order to start fresh, -it's always best to delete the DynamoDB table that matches your app name. This DynamoDB table lives in us-east-1 regardless of the Kinesis endpoint URL.
  • +###Example +
  • To build the Kinesis example, you must run the maven or sbt builds with the **-Pkinesis-asl** profile.
  • +
  • You need to setup a Kinesis stream at one of the valid Kinesis endpoints with 1 or more shards per the following: http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • +
  • Valid Kinesis endpoints can be found here: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • +
  • When running **locally**, the example automatically determines the number of threads and KinesisReceivers to spin up based on the number of shards configured for the stream. Therefore, **local[n]** is not needed when starting the example as with other streaming examples.
  • +
  • While this example could use a single KinesisReceiver which spins up multiple KinesisRecordProcessor threads to process multiple shards, I wanted to demonstrate unioning multiple KinesisReceivers as a single DStream. (It's a bit confusing in local mode.)
  • +
  • **KinesisWordCountProducerASL** is provided to generate random records into the Kinesis stream for testing.
  • +
  • The example has been configured to immediately replicate incoming stream data to another node by using (StorageLevel.MEMORY_AND_DISK_2) +
  • Spark checkpointing is disabled because the example does not use any stateful or window-based DStream operations such as updateStateByKey and reduceByWindow. If those operations are introduced, you would need to enable checkpointing or risk losing data in the case of a failure.
  • +
  • Kinesis checkpointing is enabled. This means that the example will recover from a Kinesis failure.
  • +
  • The example uses InitialPositionInStream.LATEST strategy to pull from the latest tip of the stream if no Kinesis checkpoint info exists.
  • +
  • In our example, **KinesisWordCount** is the Kinesis application name for both the Scala and Java versions. The use of this application name is described next.
  • -Failure recovery notes: -
  • The combination of Spark Streaming and Kinesis creates 3 different checkpoints as follows:
    - 1) RDD data checkpoint (Spark Streaming) - frequency is configurable with DStream.checkpoint(Duration)
    - 2) RDD metadata checkpoint (Spark Streaming) - frequency is every DStream batch
    - 3) Kinesis checkpointing (Kinesis) - frequency is controlled by the developer calling ICheckpointer.checkpoint() directly
    +###Deployment and Runtime +
  • A Kinesis application name must be unique for a given account and region.
  • +
  • A DynamoDB table and CloudWatch namespace are created during KCL initialization using this Kinesis application name. http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization
  • +
  • This DynamoDB table lives in the us-east-1 region regardless of the Kinesis endpoint URL.
  • +
  • Changing the app name or stream name could lead to Kinesis errors as only a single logical application can process a single stream.
  • +
  • If you are seeing errors after changing the app name or stream name, it may be necessary to manually delete the DynamoDB table and start from scratch.
  • +
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the KCL.
  • +
  • The KinesisReceiver uses the DefaultAWSCredentialsProviderChain for AWS credentials which searches for credentials in the following order of precedence:
    +1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    +2) Java System Properties - aws.accessKeyId and aws.secretKey
    +3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    +4) Instance profile credentials - delivered through the Amazon EC2 metadata service
  • -
  • Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling
  • -
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last checkpoint sequence number recorded per shard.
  • -
  • If no checkpoint info exists, the worker will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) -or from the tip/latest (InitialPostitionInStream.LATEST). This is configurable.
  • -
  • When pulling from the stream tip (InitialPositionInStream.LATEST), only new stream data will be picked up after the KinesisReceiver starts.
  • -
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running.
  • -
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data -depending on the checkpoint frequency.
  • -
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records depending on the checkpoint frequency.
  • + +###Fault-Tolerance +
  • The combination of Spark Streaming and Kinesis creates 2 different checkpoints that may occur at different intervals.
  • +
  • Checkpointing too frequently against Kinesis will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random backoff retry strategy.
  • +
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last Kinesis checkpoint sequence number recorded per shard (stored in the DynamoDB table).
  • +
  • If no Kinesis checkpoint info exists, the KinesisReceiver will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPostitionInStream.LATEST). This is configurable.
  • +
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running (and no checkpoint info is being stored.)
  • +
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data.
  • +
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency.
  • Record processing should be idempotent when possible.
  • -
  • Failed or latent KinesisReceivers will be detected and automatically shutdown/load-balanced by the KCL.
  • -
  • If possible, explicitly shutdown the worker if a failure occurs in order to trigger the final checkpoint.
  • +
  • A failed or latent KinesisRecordProcessor within the KinesisReceiver will be detected and automatically restarted by the KCL.
  • +
  • If possible, the KinesisReceiver should be shutdown cleanly in order to trigger a final checkpoint of all KinesisRecordProcessors to avoid duplicate record processing.
  • \ No newline at end of file diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py new file mode 100644 index 0000000000000..e902ae29753c0 --- /dev/null +++ b/examples/src/main/python/avro_inputformat.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext + +""" +Read data file users.avro in local Spark distro: + +$ cd $SPARK_HOME +$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \ +> examples/src/main/resources/users.avro +{u'favorite_color': None, u'name': u'Alyssa', u'favorite_numbers': [3, 9, 15, 20]} +{u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} + +To read name and favorite_color fields only, specify the following reader schema: + +$ cat examples/src/main/resources/user.avsc +{"namespace": "example.avro", + "type": "record", + "name": "User", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "favorite_color", "type": ["string", "null"]} + ] +} + +$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \ +> examples/src/main/resources/users.avro examples/src/main/resources/user.avsc +{u'favorite_color': None, u'name': u'Alyssa'} +{u'favorite_color': u'red', u'name': u'Ben'} +""" +if __name__ == "__main__": + if len(sys.argv) != 2 and len(sys.argv) != 3: + print >> sys.stderr, """ + Usage: avro_inputformat [reader_schema_file] + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/avro_inputformat.py [reader_schema_file] + Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. + """ + exit(-1) + + path = sys.argv[1] + sc = SparkContext(appName="AvroKeyInputFormat") + + conf = None + if len(sys.argv) == 3: + schema_rdd = sc.textFile(sys.argv[2], 1).collect() + conf = {"avro.schema.input.key" : reduce(lambda x, y: x+y, schema_rdd)} + + avro_rdd = sc.newAPIHadoopFile(path, + "org.apache.avro.mapreduce.AvroKeyInputFormat", + "org.apache.avro.mapred.AvroKey", + "org.apache.hadoop.io.NullWritable", + keyConverter="org.apache.spark.examples.pythonconverters.AvroWrapperToJavaConverter", + conf=conf) + output = avro_rdd.map(lambda x: x[0]).collect() + for k in output: + print k diff --git a/examples/src/main/resources/user.avsc b/examples/src/main/resources/user.avsc new file mode 100644 index 0000000000000..4995357ab3736 --- /dev/null +++ b/examples/src/main/resources/user.avsc @@ -0,0 +1,8 @@ +{"namespace": "example.avro", + "type": "record", + "name": "User", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "favorite_color", "type": ["string", "null"]} + ] +} diff --git a/examples/src/main/resources/users.avro b/examples/src/main/resources/users.avro new file mode 100644 index 0000000000000..27c526ab114b2 Binary files /dev/null and b/examples/src/main/resources/users.avro differ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index 56b02b65d8724..a6f78d2441db1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -21,7 +21,7 @@ import org.apache.log4j.{Level, Logger} import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD} +import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.util.MLUtils import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater} @@ -66,7 +66,8 @@ object BinaryClassification { .text("number of iterations") .action((x, c) => c.copy(numIterations = x)) opt[Double]("stepSize") - .text(s"initial step size, default: ${defaultParams.stepSize}") + .text("initial step size (ignored by logistic regression), " + + s"default: ${defaultParams.stepSize}") .action((x, c) => c.copy(stepSize = x)) opt[String]("algorithm") .text(s"algorithm (${Algorithm.values.mkString(",")}), " + @@ -125,10 +126,9 @@ object BinaryClassification { val model = params.algorithm match { case LR => - val algorithm = new LogisticRegressionWithSGD() + val algorithm = new LogisticRegressionWithLBFGS() algorithm.optimizer .setNumIterations(params.numIterations) - .setStepSize(params.stepSize) .setUpdater(updater) .setRegParam(params.regParam) algorithm.run(training).clearThreshold() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 1fd37edfa7427..0e992fa9967bb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -18,8 +18,7 @@ package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD +import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD} import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -56,8 +55,8 @@ object StreamingLinearRegression { val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression") val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) - val trainingData = MLUtils.loadStreamingLabeledPoints(ssc, args(0)) - val testData = MLUtils.loadStreamingLabeledPoints(ssc, args(1)) + val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse) + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0))) diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala new file mode 100644 index 0000000000000..1b25983a38453 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.pythonconverters + +import java.util.{Collection => JCollection, Map => JMap} + +import scala.collection.JavaConversions._ + +import org.apache.avro.generic.{GenericFixed, IndexedRecord} +import org.apache.avro.mapred.AvroWrapper +import org.apache.avro.Schema +import org.apache.avro.Schema.Type._ + +import org.apache.spark.api.python.Converter +import org.apache.spark.SparkException + + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts + * an Avro Record wrapped in an AvroKey (or AvroValue) to a Java Map. It tries + * to work with all 3 Avro data mappings (Generic, Specific and Reflect). + */ +class AvroWrapperToJavaConverter extends Converter[Any, Any] { + override def convert(obj: Any): Any = { + if (obj == null) { + return null + } + obj.asInstanceOf[AvroWrapper[_]].datum() match { + case null => null + case record: IndexedRecord => unpackRecord(record) + case other => throw new SparkException( + s"Unsupported top-level Avro data type ${other.getClass.getName}") + } + } + + def unpackRecord(obj: Any): JMap[String, Any] = { + val map = new java.util.HashMap[String, Any] + obj match { + case record: IndexedRecord => + record.getSchema.getFields.zipWithIndex.foreach { case (f, i) => + map.put(f.name, fromAvro(record.get(i), f.schema)) + } + case other => throw new SparkException( + s"Unsupported RECORD type ${other.getClass.getName}") + } + map + } + + def unpackMap(obj: Any, schema: Schema): JMap[String, Any] = { + obj.asInstanceOf[JMap[_, _]].map { case (key, value) => + (key.toString, fromAvro(value, schema.getValueType)) + } + } + + def unpackFixed(obj: Any, schema: Schema): Array[Byte] = { + unpackBytes(obj.asInstanceOf[GenericFixed].bytes()) + } + + def unpackBytes(obj: Any): Array[Byte] = { + val bytes: Array[Byte] = obj match { + case buf: java.nio.ByteBuffer => buf.array() + case arr: Array[Byte] => arr + case other => throw new SparkException( + s"Unknown BYTES type ${other.getClass.getName}") + } + val bytearray = new Array[Byte](bytes.length) + System.arraycopy(bytes, 0, bytearray, 0, bytes.length) + bytearray + } + + def unpackArray(obj: Any, schema: Schema): JCollection[Any] = obj match { + case c: JCollection[_] => + c.map(fromAvro(_, schema.getElementType)) + case arr: Array[_] if arr.getClass.getComponentType.isPrimitive => + arr.toSeq + case arr: Array[_] => + arr.map(fromAvro(_, schema.getElementType)).toSeq + case other => throw new SparkException( + s"Unknown ARRAY type ${other.getClass.getName}") + } + + def unpackUnion(obj: Any, schema: Schema): Any = { + schema.getTypes.toList match { + case List(s) => fromAvro(obj, s) + case List(n, s) if n.getType == NULL => fromAvro(obj, s) + case List(s, n) if n.getType == NULL => fromAvro(obj, s) + case _ => throw new SparkException( + "Unions may only consist of a concrete type and null") + } + } + + def fromAvro(obj: Any, schema: Schema): Any = { + if (obj == null) { + return null + } + schema.getType match { + case UNION => unpackUnion(obj, schema) + case ARRAY => unpackArray(obj, schema) + case FIXED => unpackFixed(obj, schema) + case MAP => unpackMap(obj, schema) + case BYTES => unpackBytes(obj) + case RECORD => unpackRecord(obj) + case STRING => obj.toString + case ENUM => obj.toString + case NULL => obj + case BOOLEAN => obj + case DOUBLE => obj + case FLOAT => obj + case INT => obj + case LONG => obj + case other => throw new SparkException( + s"Unknown Avro schema type ${other.getName}") + } + } +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 7b735133e3d14..948af5947f5e1 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -131,6 +131,14 @@ class SparkSink extends AbstractSink with Logging with Configurable { blockingLatch.await() Status.BACKOFF } + + private[flume] def getPort(): Int = { + serverOpt + .map(_.getPort) + .getOrElse( + throw new RuntimeException("Server was not started!") + ) + } } /** diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index a69baa16981a1..8a85b0f987e42 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -22,6 +22,8 @@ import java.net.InetSocketAddress import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} import java.util.Random +import org.apache.spark.TestUtils + import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} @@ -39,9 +41,6 @@ import org.apache.spark.util.Utils class FlumePollingStreamSuite extends TestSuiteBase { - val random = new Random() - /** Return a port in the ephemeral range. */ - def getTestPort = random.nextInt(16382) + 49152 val batchCount = 5 val eventsPerBatch = 100 val totalEventsPerChannel = batchCount * eventsPerBatch @@ -77,17 +76,6 @@ class FlumePollingStreamSuite extends TestSuiteBase { } private def testFlumePolling(): Unit = { - val testPort = getTestPort - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", testPort)), - StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) - outputStream.register() - // Start the channel and sink. val context = new Context() context.put("capacity", channelCapacity.toString) @@ -98,10 +86,19 @@ class FlumePollingStreamSuite extends TestSuiteBase { val sink = new SparkSink() context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) Configurables.configure(sink, context) sink.setChannel(channel) sink.start() + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", sink.getPort())), + StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() ssc.start() writeAndVerify(Seq(channel), ssc, outputBuffer) @@ -111,18 +108,6 @@ class FlumePollingStreamSuite extends TestSuiteBase { } private def testFlumePollingMultipleHost(): Unit = { - val testPort = getTestPort - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val addresses = Seq(testPort, testPort + 1).map(new InetSocketAddress("localhost", _)) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - eventsPerBatch, 5) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) - outputStream.register() - // Start the channel and sink. val context = new Context() context.put("capacity", channelCapacity.toString) @@ -136,17 +121,29 @@ class FlumePollingStreamSuite extends TestSuiteBase { val sink = new SparkSink() context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) Configurables.configure(sink, context) sink.setChannel(channel) sink.start() val sink2 = new SparkSink() context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort + 1)) + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) Configurables.configure(sink2, context) sink2.setChannel(channel2) sink2.start() + + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val addresses = Seq(sink.getPort(), sink2.getPort()).map(new InetSocketAddress("localhost", _)) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, + eventsPerBatch, 5) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() + ssc.start() writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) assertChannelIsEmpty(channel) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 18dc087856785..4343124f102a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} +import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 31d474a20fa85..486bdbfa9cb47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -62,7 +62,7 @@ class LogisticRegressionModel ( override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double) = { val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - val score = 1.0/ (1.0 + math.exp(-margin)) + val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { case Some(t) => if (score < t) 0.0 else 1.0 case None => score @@ -73,6 +73,8 @@ class LogisticRegressionModel ( /** * Train a classification model for Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * Using [[LogisticRegressionWithLBFGS]] is recommended over this. */ class LogisticRegressionWithSGD private ( private var stepSize: Double, @@ -191,49 +193,19 @@ object LogisticRegressionWithSGD { /** * Train a classification model for Logistic Regression using Limited-memory BFGS. + * Standard feature scaling and L2 regularization are used by default. * NOTE: Labels used in Logistic Regression should be {0, 1} */ -class LogisticRegressionWithLBFGS private ( - private var convergenceTol: Double, - private var maxNumIterations: Int, - private var regParam: Double) +class LogisticRegressionWithLBFGS extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { - /** - * Construct a LogisticRegression object with default parameters - */ - def this() = this(1E-4, 100, 0.0) + this.setFeatureScaling(true) - private val gradient = new LogisticGradient() - private val updater = new SimpleUpdater() - // Have to return new LBFGS object every time since users can reset the parameters anytime. - override def optimizer = new LBFGS(gradient, updater) - .setNumCorrections(10) - .setConvergenceTol(convergenceTol) - .setMaxNumIterations(maxNumIterations) - .setRegParam(regParam) + override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater) override protected val validators = List(DataValidators.binaryLabelValidator) - /** - * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4. - * Smaller value will lead to higher accuracy with the cost of more iterations. - */ - def setConvergenceTol(convergenceTol: Double): this.type = { - this.convergenceTol = convergenceTol - this - } - - /** - * Set the maximal number of iterations for L-BFGS. Default 100. - */ - def setNumIterations(numIterations: Int): this.type = { - this.maxNumIterations = numIterations - this - } - override protected def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index ecd49ea2ff533..1dcaa2cd2e630 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -34,6 +34,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap /** * Entry in vocabulary @@ -287,11 +288,12 @@ class Word2Vec extends Serializable with Logging { var syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) var syn1Global = new Array[Float](vocabSize * vectorSize) - var alpha = startingAlpha for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) + val syn0Modify = new Array[Int](vocabSize) + val syn1Modify = new Array[Int](vocabSize) val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount @@ -321,7 +323,8 @@ class Word2Vec extends Serializable with Logging { // Hierarchical softmax var d = 0 while (d < bcVocab.value(word).codeLen) { - val l2 = bcVocab.value(word).point(d) * vectorSize + val inner = bcVocab.value(word).point(d) + val l2 = inner * vectorSize // Propagate hidden -> output var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { @@ -330,10 +333,12 @@ class Word2Vec extends Serializable with Logging { val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) + syn1Modify(inner) += 1 } d += 1 } blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) + syn0Modify(lastWord) += 1 } } a += 1 @@ -342,21 +347,36 @@ class Word2Vec extends Serializable with Logging { } (syn0, syn1, lwc, wc) } - Iterator(model) + val syn0Local = model._1 + val syn1Local = model._2 + val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2) + var index = 0 + while(index < vocabSize) { + if (syn0Modify(index) != 0) { + synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)) + } + if (syn1Modify(index) != 0) { + synOut.update(index + vocabSize, + syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)) + } + index += 1 + } + Iterator(synOut) } - val (aggSyn0, aggSyn1, _, _) = - partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => - val n = syn0_1.length - val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) - val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) - blas.sscal(n, weight1, syn0_1, 1) - blas.sscal(n, weight1, syn1_1, 1) - blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) - blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) - (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) + val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) => + blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) + v1 + }.collect() + var i = 0 + while (i < synAgg.length) { + val index = synAgg(i)._1 + if (index < vocabSize) { + Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize) + } else { + Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize) } - syn0Global = aggSyn0 - syn1Global = aggSyn1 + i += 1 + } } newSentences.unpersist() @@ -414,15 +434,6 @@ class Word2VecModel private[mllib] ( } } - /** - * Transforms an RDD to its vector representation - * @param dataset a an RDD of words - * @return RDD of vector representation - */ - def transform(dataset: RDD[String]): RDD[Vector] = { - dataset.map(word => transform(word)) - } - /** * Find synonyms of a word * @param word a word diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 033fe44f34f3c..d16d0daf08565 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -69,8 +69,17 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) /** * Set the maximal number of iterations for L-BFGS. Default 100. + * @deprecated use [[LBFGS#setNumIterations]] instead */ + @deprecated("use setNumIterations instead", "1.1.0") def setMaxNumIterations(iters: Int): this.type = { + this.setNumIterations(iters) + } + + /** + * Set the maximal number of iterations for L-BFGS. Default 100. + */ + def setNumIterations(iters: Int): this.type = { this.maxNumIterations = iters this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala similarity index 99% rename from mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala rename to mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index b0a0593223910..36270369526cd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.random +import scala.reflect.ClassTag + import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.Vector @@ -24,14 +26,12 @@ import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -import scala.reflect.ClassTag - /** * :: Experimental :: * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. */ @Experimental -object RandomRDDGenerators { +object RandomRDDs { /** * :: Experimental :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 54854252d7477..20c1fdd2269ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.feature.StandardScaler import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ @@ -94,6 +95,22 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] protected var validateData: Boolean = true + /** + * Whether to perform feature scaling before model training to reduce the condition numbers + * which can significantly help the optimizer converging faster. The scaling correction will be + * translated back to resulting model weights, so it's transparent to users. + * Note: This technique is used in both libsvm and glmnet packages. Default false. + */ + private var useFeatureScaling = false + + /** + * Set if the algorithm should use feature scaling to improve the convergence during optimization. + */ + private[mllib] def setFeatureScaling(useFeatureScaling: Boolean): this.type = { + this.useFeatureScaling = useFeatureScaling + this + } + /** * Create a model given the weights and intercept */ @@ -137,11 +154,45 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] throw new SparkException("Input validation failed.") } + /** + * Scaling columns to unit variance as a heuristic to reduce the condition number: + * + * During the optimization process, the convergence (rate) depends on the condition number of + * the training dataset. Scaling the variables often reduces this condition number + * heuristically, thus improving the convergence rate. Without reducing the condition number, + * some training datasets mixing the columns with different scales may not be able to converge. + * + * GLMNET and LIBSVM packages perform the scaling to reduce the condition number, and return + * the weights in the original scale. + * See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf + * + * Here, if useFeatureScaling is enabled, we will standardize the training features by dividing + * the variance of each column (without subtracting the mean), and train the model in the + * scaled space. Then we transform the coefficients from the scaled space to the original scale + * as GLMNET and LIBSVM do. + * + * Currently, it's only enabled in LogisticRegressionWithLBFGS + */ + val scaler = if (useFeatureScaling) { + (new StandardScaler).fit(input.map(x => x.features)) + } else { + null + } + // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) + if(useFeatureScaling) { + input.map(labeledPoint => + (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features)))) + } else { + input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) + } } else { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + if (useFeatureScaling) { + input.map(labeledPoint => (labeledPoint.label, scaler.transform(labeledPoint.features))) + } else { + input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + } } val initialWeightsWithIntercept = if (addIntercept) { @@ -153,13 +204,25 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0 - val weights = + var weights = if (addIntercept) { Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)) } else { weightsWithIntercept } + /** + * The weights and intercept are trained in the scaled space; we're converting them back to + * the original scale. + * + * Math shows that if we only perform standardization without subtracting means, the intercept + * will not be changed. w_i = w_i' / v_i where w_i' is the coefficient in the scaled space, w_i + * is the coefficient in the original space, and v_i is the variance of the column i. + */ + if (useFeatureScaling) { + weights = scaler.transform(weights) + } + createModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 62a03af4a9964..17c753c56681f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -36,7 +36,7 @@ case class LabeledPoint(label: Double, features: Vector) { /** * Parser for [[org.apache.spark.mllib.regression.LabeledPoint]]. */ -private[mllib] object LabeledPointParser { +object LabeledPoint { /** * Parses a string resulted from `LabeledPoint#toString` into * an [[org.apache.spark.mllib.regression.LabeledPoint]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index 8851097050318..1d11fde24712c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.Vector /** * Train or predict a linear regression model on streaming data. Training uses diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 3cf1028fbc725..3cf4e807b4cf7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -155,7 +155,7 @@ object Statistics { * :: Experimental :: * Conduct Pearson's independence test for every feature against the label across the input RDD. * For each feature, the (feature, label) pairs are converted into a contingency matrix for which - * the chi-squared statistic is computed. + * the chi-squared statistic is computed. All label and feature values must be categorical. * * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features. * Real-valued features will be treated as categorical for each distinct value. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala index 9bd0c2cd05de4..4a6c677f06d28 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.stat.correlation import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, HashPartitioner} +import org.apache.spark.Logging import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.{DenseVector, Matrix, Vector} -import org.apache.spark.rdd.{CoGroupedRDD, RDD} +import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors} +import org.apache.spark.rdd.RDD /** * Compute Spearman's correlation for two RDDs of the type RDD[Double] or the correlation matrix @@ -43,87 +43,51 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging { /** * Compute Spearman's correlation matrix S, for the input matrix, where S(i, j) is the * correlation between column i and j. - * - * Input RDD[Vector] should be cached or checkpointed if possible since it would be split into - * numCol RDD[Double]s, each of which sorted, and the joined back into a single RDD[Vector]. */ override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = { - val indexed = X.zipWithUniqueId() - - val numCols = X.first.size - if (numCols > 50) { - logWarning("Computing the Spearman correlation matrix can be slow for large RDDs with more" - + " than 50 columns.") - } - val ranks = new Array[RDD[(Long, Double)]](numCols) - - // Note: we use a for loop here instead of a while loop with a single index variable - // to avoid race condition caused by closure serialization - for (k <- 0 until numCols) { - val column = indexed.map { case (vector, index) => (vector(k), index) } - ranks(k) = getRanks(column) + // ((columnIndex, value), rowUid) + val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) => + vec.toArray.view.zipWithIndex.map { case (v, j) => + ((j, v), uid) + } } - - val ranksMat: RDD[Vector] = makeRankMatrix(ranks, X) - PearsonCorrelation.computeCorrelationMatrix(ranksMat) - } - - /** - * Compute the ranks for elements in the input RDD, using the average method for ties. - * - * With the average method, elements with the same value receive the same rank that's computed - * by taking the average of their positions in the sorted list. - * e.g. ranks([2, 1, 0, 2]) = [2.5, 1.0, 0.0, 2.5] - * Note that positions here are 0-indexed, instead of the 1-indexed as in the definition for - * ranks in the standard definition for Spearman's correlation. This does not affect the final - * results and is slightly more performant. - * - * @param indexed RDD[(Double, Long)] containing pairs of the format (originalValue, uniqueId) - * @return RDD[(Long, Double)] containing pairs of the format (uniqueId, rank), where uniqueId is - * copied from the input RDD. - */ - private def getRanks(indexed: RDD[(Double, Long)]): RDD[(Long, Double)] = { - // Get elements' positions in the sorted list for computing average rank for duplicate values - val sorted = indexed.sortByKey().zipWithIndex() - - val ranks: RDD[(Long, Double)] = sorted.mapPartitions { iter => - // add an extra element to signify the end of the list so that flatMap can flush the last - // batch of duplicates - val end = -1L - val padded = iter ++ Iterator[((Double, Long), Long)](((Double.NaN, end), end)) - val firstEntry = padded.next() - var lastVal = firstEntry._1._1 - var firstRank = firstEntry._2.toDouble - val idBuffer = ArrayBuffer(firstEntry._1._2) - padded.flatMap { case ((v, id), rank) => - if (v == lastVal && id != end) { - idBuffer += id - Iterator.empty - } else { - val entries = if (idBuffer.size == 1) { - Iterator((idBuffer(0), firstRank)) - } else { - val averageRank = firstRank + (idBuffer.size - 1.0) / 2.0 - idBuffer.map(id => (id, averageRank)) - } - lastVal = v - firstRank = rank - idBuffer.clear() - idBuffer += id - entries + // global sort by (columnIndex, value) + val sorted = colBased.sortByKey() + // assign global ranks (using average ranks for tied values) + val globalRanks = sorted.zipWithIndex().mapPartitions { iter => + var preCol = -1 + var preVal = Double.NaN + var startRank = -1.0 + var cachedUids = ArrayBuffer.empty[Long] + val flush: () => Iterable[(Long, (Int, Double))] = () => { + val averageRank = startRank + (cachedUids.size - 1) / 2.0 + val output = cachedUids.map { uid => + (uid, (preCol, averageRank)) } + cachedUids.clear() + output } + iter.flatMap { case (((j, v), uid), rank) => + // If we see a new value or cachedUids is too big, we flush ids with their average rank. + if (j != preCol || v != preVal || cachedUids.size >= 10000000) { + val output = flush() + preCol = j + preVal = v + startRank = rank + cachedUids += uid + output + } else { + cachedUids += uid + Iterator.empty + } + } ++ flush() } - ranks - } - - private def makeRankMatrix(ranks: Array[RDD[(Long, Double)]], input: RDD[Vector]): RDD[Vector] = { - val partitioner = new HashPartitioner(input.partitions.size) - val cogrouped = new CoGroupedRDD[Long](ranks, partitioner) - cogrouped.map { - case (_, values: Array[Iterable[_]]) => - val doubles = values.asInstanceOf[Array[Iterable[Double]]] - new DenseVector(doubles.flatten.toArray) + // Replace values in the input matrix by their ranks compared with values in the same column. + // Note that shifting all ranks in a column by a constant value doesn't affect result. + val groupedRanks = globalRanks.groupByKey().map { case (uid, iter) => + // sort by column index and then convert values to a vector + Vectors.dense(iter.toSeq.sortBy(_._1).map(_._2).toArray) } + PearsonCorrelation.computeCorrelationMatrix(groupedRanks) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 8f6752737402e..0089419c2c5d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -20,11 +20,13 @@ package org.apache.spark.mllib.stat.test import breeze.linalg.{DenseMatrix => BDM} import cern.jet.stat.Probability.chiSquareComplemented -import org.apache.spark.Logging +import org.apache.spark.{SparkException, Logging} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD +import scala.collection.mutable + /** * Conduct the chi-squared test for the input RDDs using the specified method. * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted @@ -56,7 +58,7 @@ private[stat] object ChiSqTest extends Logging { object NullHypothesis extends Enumeration { type NullHypothesis = Value val goodnessOfFit = Value("observed follows the same distribution as expected.") - val independence = Value("observations in each column are statistically independent.") + val independence = Value("the occurrence of the outcomes is statistically independent.") } // Method identification based on input methodName string @@ -75,21 +77,42 @@ private[stat] object ChiSqTest extends Logging { */ def chiSquaredFeatures(data: RDD[LabeledPoint], methodName: String = PEARSON.name): Array[ChiSqTestResult] = { + val maxCategories = 10000 val numCols = data.first().features.size val results = new Array[ChiSqTestResult](numCols) var labels: Map[Double, Int] = null - // At most 100 columns at a time - val batchSize = 100 + // at most 1000 columns at a time + val batchSize = 1000 var batch = 0 while (batch * batchSize < numCols) { // The following block of code can be cleaned up and made public as // chiSquared(data: RDD[(V1, V2)]) val startCol = batch * batchSize val endCol = startCol + math.min(batchSize, numCols - startCol) - val pairCounts = data.flatMap { p => - // assume dense vectors - p.features.toArray.slice(startCol, endCol).zipWithIndex.map { case (feature, col) => - (col, feature, p.label) + val pairCounts = data.mapPartitions { iter => + val distinctLabels = mutable.HashSet.empty[Double] + val allDistinctFeatures: Map[Int, mutable.HashSet[Double]] = + Map((startCol until endCol).map(col => (col, mutable.HashSet.empty[Double])): _*) + var i = 1 + iter.flatMap { case LabeledPoint(label, features) => + if (i % 1000 == 0) { + if (distinctLabels.size > maxCategories) { + throw new SparkException(s"Chi-square test expect factors (categorical values) but " + + s"found more than $maxCategories distinct label values.") + } + allDistinctFeatures.foreach { case (col, distinctFeatures) => + if (distinctFeatures.size > maxCategories) { + throw new SparkException(s"Chi-square test expect factors (categorical values) but " + + s"found more than $maxCategories distinct values in column $col.") + } + } + } + i += 1 + distinctLabels += label + features.toArray.view.zipWithIndex.slice(startCol, endCol).map { case (feature, col) => + allDistinctFeatures(col) += feature + (col, feature, label) + } } }.countByValue() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index 2f278621335e1..4784f9e947908 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -44,6 +44,11 @@ trait TestResult[DF] { */ def statistic: Double + /** + * Null hypothesis of the test. + */ + def nullHypothesis: String + /** * String explaining the hypothesis test result. * Specific classes implementing this trait should override this method to output test-specific @@ -53,13 +58,13 @@ trait TestResult[DF] { // String explaining what the p-value indicates. val pValueExplain = if (pValue <= 0.01) { - "Very strong presumption against null hypothesis." + s"Very strong presumption against null hypothesis: $nullHypothesis." } else if (0.01 < pValue && pValue <= 0.05) { - "Strong presumption against null hypothesis." - } else if (0.05 < pValue && pValue <= 0.01) { - "Low presumption against null hypothesis." + s"Strong presumption against null hypothesis: $nullHypothesis." + } else if (0.05 < pValue && pValue <= 0.1) { + s"Low presumption against null hypothesis: $nullHypothesis." } else { - "No presumption against null hypothesis." + s"No presumption against null hypothesis: $nullHypothesis." } s"degrees of freedom = ${degreesOfFreedom.toString} \n" + @@ -70,19 +75,18 @@ trait TestResult[DF] { /** * :: Experimental :: - * Object containing the test results for the chi squared hypothesis test. + * Object containing the test results for the chi-squared hypothesis test. */ @Experimental -class ChiSqTestResult(override val pValue: Double, +class ChiSqTestResult private[stat] (override val pValue: Double, override val degreesOfFreedom: Int, override val statistic: Double, val method: String, - val nullHypothesis: String) extends TestResult[Int] { + override val nullHypothesis: String) extends TestResult[Int] { override def toString: String = { - "Chi squared test summary: \n" + - s"method: $method \n" + - s"null hypothesis: $nullHypothesis \n" + - super.toString + "Chi squared test summary:\n" + + s"method: $method\n" + + super.toString } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index bb50f07be5d7b..6b9a8f72c244e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,22 +17,24 @@ package org.apache.spark.mllib.tree -import org.apache.spark.api.java.JavaRDD - import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} +import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint} +import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom + /** * :: Experimental :: * A class which implements a decision tree learning algorithm for classification and regression. @@ -53,39 +55,45 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - // Cache input RDD for speedup during multiple passes. - val retaggedInput = input.retag(classOf[LabeledPoint]).cache() + val timer = new TimeTracker() + + timer.start("total") + + timer.start("init") + + val retaggedInput = input.retag(classOf[LabeledPoint]) + val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy) logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) + timer.start("findSplitsBins") + val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) val numBins = bins(0).length + timer.stop("findSplitsBins") logDebug("numBins = " + numBins) + // Bin feature values (TreePoint representation). + // Cache input RDD for speedup during multiple passes. + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) + .persist(StorageLevel.MEMORY_AND_DISK) + + val numFeatures = metadata.numFeatures // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 - // Initialize an array to hold filters applied to points for each node. - val filters = new Array[List[Filter]](maxNumNodes) - // The filter at the top node is an empty list. - filters(0) = List() + val maxNumNodes = (2 << maxDepth) - 1 // Initialize an array to hold parent impurity calculations for each node. val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) - // num features - val numFeatures = retaggedInput.take(1)(0).features.size // Calculate level for single group construction // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins, - strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures, - strategy.algo) + val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins) logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array @@ -96,12 +104,13 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) logDebug("max level for single group = " + maxLevelForSingleGroup) + timer.stop("init") + /* * The main idea here is to perform level-wise training of the decision tree nodes thus * reducing the passes over the data from l to log2(l) where l is the total number of nodes. - * Each data sample is checked for validity w.r.t to each node at a given level -- i.e., - * the sample is only used for the split calculation at the node if the sampled would have - * still survived the filters of the parent nodes. + * Each data sample is handled by a particular node at that level (or it reaches a leaf + * beforehand and is not used in later levels. */ var level = 0 @@ -113,18 +122,39 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities, - strategy, level, filters, splits, bins, maxLevelForSingleGroup) + timer.start("findBestSplits") + val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, + metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) + timer.stop("findBestSplits") + val levelNodeIndexOffset = (1 << level) - 1 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - // Extract info for nodes at the current level. + val nodeIndex = levelNodeIndexOffset + index + val isLeftChild = level != 0 && nodeIndex % 2 == 1 + val parentNodeIndex = if (isLeftChild) { // -1 for root node + (nodeIndex - 1) / 2 + } else { + (nodeIndex - 2) / 2 + } + // Extract info for this node (index) at the current level. + timer.start("extractNodeInfo") extractNodeInfo(nodeSplitStats, level, index, nodes) + timer.stop("extractNodeInfo") + if (level != 0) { + // Set parent. + if (isLeftChild) { + nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex)) + } else { + nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex)) + } + } // Extract info for nodes at the next lower level. - extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, - filters) + timer.start("extractInfoForLowerLevels") + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities) + timer.stop("extractInfoForLowerLevels") logDebug("final best split = " + nodeSplitStats._1) } - require(math.pow(2, level) == splitsStatsForLevel.length) + require((1 << level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -144,6 +174,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Build the full tree using the node info calculated in the level-wise best split calculations. topNode.build(nodes) + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + new DecisionTreeModel(topNode, strategy.algo) } @@ -157,7 +192,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo nodes: Array[Node]): Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 - val nodeIndex = math.pow(2, level).toInt - 1 + index + val nodeIndex = (1 << level) - 1 + index val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) @@ -172,31 +207,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), - parentImpurities: Array[Double], - filters: Array[List[Filter]]): Unit = { - // 0 corresponds to the left child node and 1 corresponds to the right child node. - var i = 0 - while (i <= 1) { - // Calculate the index of the node from the node level and the index at the current level. - val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth) { - val impurity = if (i == 0) { - nodeSplitStats._2.leftImpurity - } else { - nodeSplitStats._2.rightImpurity - } - logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) - // noting the parent impurities - parentImpurities(nodeIndex) = impurity - // noting the parents filters for the child nodes - val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) - filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) - for (filter <- filters(nodeIndex)) { - logDebug("Filter = " + filter) - } - } - i += 1 + parentImpurities: Array[Double]): Unit = { + + if (level >= maxDepth) { + return } + + val leftNodeIndex = (2 << level) - 1 + 2 * index + val leftImpurity = nodeSplitStats._2.leftImpurity + logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity) + parentImpurities(leftNodeIndex) = leftImpurity + + val rightNodeIndex = leftNodeIndex + 1 + val rightImpurity = nodeSplitStats._2.rightImpurity + logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity) + parentImpurities(rightNodeIndex) = rightImpurity } } @@ -406,72 +431,70 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for constructing the DecisionTree + * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param filters Filters for all nodes at a given level * @param splits possible splits for all features * @param bins possible bins for all features * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. - * @return array of splits with best splits for all nodes at a given level. + * @return array (over nodes) of splits with best split for each node at a given level. */ protected[tree] def findBestSplits( - input: RDD[LabeledPoint], + input: RDD[TreePoint], parentImpurities: Array[Double], - strategy: Strategy, + metadata: DecisionTreeMetadata, level: Int, - filters: Array[List[Filter]], + nodes: Array[Node], splits: Array[Array[Split]], bins: Array[Array[Bin]], - maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { + maxLevelForSingleGroup: Int, + timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, // the nodes are divided into multiple groups at each level with the number of groups // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt + val numGroups = 1 << level - maxLevelForSingleGroup logDebug("numGroups = " + numGroups) var bestSplits = new Array[(Split, InformationGainStats)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 while (groupIndex < numGroups) { - val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, - filters, splits, bins, numGroups, groupIndex) + val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level, + nodes, splits, bins, timer, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins) + findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer) } } - /** + /** * Returns an array of optimal splits for a group of nodes at a given level * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for constructing the DecisionTree + * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param filters Filters for all nodes at a given level * @param splits possible splits for all features - * @param bins possible bins for all features + * @param bins possible bins for all features, indexed as (numFeatures)(numBins) * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. */ private def findBestSplitsPerGroup( - input: RDD[LabeledPoint], + input: RDD[TreePoint], parentImpurities: Array[Double], - strategy: Strategy, + metadata: DecisionTreeMetadata, level: Int, - filters: Array[List[Filter]], + nodes: Array[Node], splits: Array[Array[Split]], bins: Array[Array[Bin]], + timer: TimeTracker, numGroups: Int = 1, groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { @@ -487,7 +510,7 @@ object DecisionTree extends Serializable with Logging { * We use a bin-wise best split computation strategy instead of a straightforward best split * computation strategy. Instead of analyzing each sample for contribution to the left/right * child node impurity of every split, we first categorize each feature of a sample into a - * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, + * bin. Each bin is an interval between a low and high split. Since each split, and thus bin, * is ordered (read ordering for categorical variables in the findSplitsBins method), * we exploit this structure to calculate aggregates for bins and then use these aggregates * to calculate information gain for each split. @@ -503,258 +526,124 @@ object DecisionTree extends Serializable with Logging { // numNodes: Number of nodes in this (level of tree, group), // where nodes at deeper (larger) levels may be divided into groups. - val numNodes = math.pow(2, level).toInt / numGroups + val numNodes = (1 << level) / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. - val numFeatures = input.first().features.size + val numFeatures = metadata.numFeatures logDebug("numFeatures = " + numFeatures) // numBins: Number of bins = 1 + number of possible splits val numBins = bins(0).length logDebug("numBins = " + numBins) - val numClasses = strategy.numClassesForClassification + val numClasses = metadata.numClasses logDebug("numClasses = " + numClasses) - val isMulticlassClassification = strategy.isMulticlassClassification - logDebug("isMulticlassClassification = " + isMulticlassClassification) + val isMulticlass = metadata.isMulticlass + logDebug("isMulticlass = " + isMulticlass) - val isMulticlassClassificationWithCategoricalFeatures - = strategy.isMulticlassWithCategoricalFeatures - logDebug("isMultiClassWithCategoricalFeatures = " + - isMulticlassClassificationWithCategoricalFeatures) + val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures + logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex - /** Find the filters used before reaching the current code. */ - def findParentFilters(nodeIndex: Int): List[Filter] = { - if (level == 0) { - List[Filter]() - } else { - val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift - filters(nodeFilterIndex) - } - } - /** - * Find whether the sample is valid input for the current node, i.e., whether it passes through - * all the filters for the current node. + * Get the node index corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a node + * at the current level being trained; that node's index is returned. + * + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. */ - def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { - // leaf - if ((level > 0) && (parentFilters.length == 0)) { - return false - } - - // Apply each filter and check sample validity. Return false when invalid condition found. - for (filter <- parentFilters) { - val features = labeledPoint.features - val featureIndex = filter.split.feature - val threshold = filter.split.threshold - val comparison = filter.comparison - val categories = filter.split.categories - val isFeatureContinuous = filter.split.featureType == Continuous - val feature = features(featureIndex) - if (isFeatureContinuous) { - comparison match { - case -1 => if (feature > threshold) return false - case 1 => if (feature <= threshold) return false + def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = { + if (node.isLeaf) { + node.id + } else { + val featureIndex = node.split.get.feature + val splitLeft = node.split.get.featureType match { + case Continuous => { + val binIndex = binnedFeatures(featureIndex) + val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold + // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] + // We do not need to check lowSplit since bins are separated by splits. + featureValueUpperBound <= node.split.get.threshold } - } else { - val containsFeature = categories.contains(feature) - comparison match { - case -1 => if (!containsFeature) return false - case 1 => if (containsFeature) return false + case Categorical => { + val featureValue = if (metadata.isUnordered(featureIndex)) { + binnedFeatures(featureIndex) + } else { + val binIndex = binnedFeatures(featureIndex) + bins(featureIndex)(binIndex).category + } + node.split.get.categories.contains(featureValue) } - + case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") } - } - - // Return true when the sample is valid for all filters. - true - } - - /** - * Find bin for one (labeledPoint, feature). - */ - def findBin( - featureIndex: Int, - labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, - isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { - val binForFeatures = bins(featureIndex) - val feature = labeledPoint.features(featureIndex) - - /** - * Binary search helper method for continuous feature. - */ - def binarySearchForBins(): Int = { - var left = 0 - var right = binForFeatures.length - 1 - while (left <= right) { - val mid = left + (right - left) / 2 - val bin = binForFeatures(mid) - val lowThreshold = bin.lowSplit.threshold - val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)) { - return mid - } - else if (lowThreshold >= feature) { - right = mid - 1 - } - else { - left = mid + 1 + if (node.leftNode.isEmpty || node.rightNode.isEmpty) { + // Return index from next layer of nodes to train + if (splitLeft) { + node.id * 2 + 1 // left + } else { + node.id * 2 + 2 // right } - } - -1 - } - - /** - * Sequential search helper method to find bin for categorical feature in multiclass - * classification. The category is returned since each category can belong to multiple - * splits. The actual left/right child allocation per split is performed in the - * sequential phase of the bin aggregate operation. - */ - def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { - labeledPoint.features(featureIndex).toInt - } - - /** - * Sequential search helper method to find bin for categorical feature - * (for classification and regression). - */ - def sequentialBinSearchForOrderedCategoricalFeature(): Int = { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val featureValue = labeledPoint.features(featureIndex) - var binIndex = 0 - while (binIndex < featureCategories) { - val bin = bins(featureIndex)(binIndex) - val categories = bin.highSplit.categories - if (categories.contains(featureValue)) { - return binIndex + } else { + if (splitLeft) { + predictNodeIndex(node.leftNode.get, binnedFeatures) + } else { + predictNodeIndex(node.rightNode.get, binnedFeatures) } - binIndex += 1 - } - if (featureValue < 0 || featureValue >= featureCategories) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureCategories - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) } - -1 } + } - if (isFeatureContinuous) { - // Perform binary search for finding bin for continuous features. - val binIndex = binarySearchForBins() - if (binIndex == -1) { - throw new UnknownError("no bin was found for continuous variable.") - } - binIndex + def nodeIndexToLevel(idx: Int): Int = { + if (idx == 0) { + 0 } else { - // Perform sequential search to find bin for categorical features. - val binIndex = { - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - if (isUnorderedFeature) { - sequentialBinSearchForUnorderedCategoricalFeatureInClassification() - } else { - sequentialBinSearchForOrderedCategoricalFeature() - } - } - if (binIndex == -1) { - throw new UnknownError("no bin was found for categorical variable.") - } - binIndex + math.floor(math.log(idx) / math.log(2)).toInt } } + // Used for treePointToNodeIndex + val levelOffset = (1 << level) - 1 + /** - * Finds bins for all nodes (and all features) at a given level. - * For l nodes, k features the storage is as follows: - * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, - * where b_ij is an integer between 0 and numBins - 1 for regressions and binary - * classification and the categorical feature value in multiclass classification. - * Invalid sample is denoted by noting bin for feature 1 as -1. - * - * For unordered features, the "bin index" returned is actually the feature value (category). - * - * @return Array of size 1 + numFeatures * numNodes, where - * arr(0) = label for labeledPoint, and - * arr(1 + numFeatures * nodeIndex + featureIndex) = - * bin index for this labeledPoint - * (or InvalidBinIndex if labeledPoint is not handled by this node) + * Find the node index for the given example. + * Nodes are indexed from 0 at the start of this (level, group). + * If the example does not reach this level, returns a value < 0. */ - def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { - // Calculate bin index and label per feature per node. - val arr = new Array[Double](1 + (numFeatures * numNodes)) - // First element of the array is the label of the instance. - arr(0) = labeledPoint.label - // Iterate over nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - val parentFilters = findParentFilters(nodeIndex) - // Find out whether the sample qualifies for the particular node. - val sampleValid = isSampleValid(parentFilters, labeledPoint) - val shift = 1 + numFeatures * nodeIndex - if (!sampleValid) { - // Mark one bin as -1 is sufficient. - arr(shift) = InvalidBinIndex - } else { - var featureIndex = 0 - while (featureIndex < numFeatures) { - val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex) - val isFeatureContinuous = featureInfo.isEmpty - if (isFeatureContinuous) { - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) - } else { - val featureCategories = featureInfo.get - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, - isSpaceSufficientForAllCategoricalSplits) - } - featureIndex += 1 - } - } - nodeIndex += 1 + def treePointToNodeIndex(treePoint: TreePoint): Int = { + if (level == 0) { + 0 + } else { + val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures) + // Get index for this (level, group). + globalNodeIndex - levelOffset - groupShift } - arr } - // Find feature bins for all nodes at a level. - val binMappedRDD = input.map(x => findBinsForLevel(x)) - /** * Increment aggregate in location for (node, feature, bin, label). * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. * @param agg Array storing aggregate calculation, of size: * numClasses * numBins * numFeatures * numNodes. * Indexed by (node, feature, bin, label) where label is the least significant bit. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ def updateBinForOrderedFeature( - arr: Array[Double], + treePoint: TreePoint, agg: Array[Double], nodeIndex: Int, - label: Double, featureIndex: Int): Unit = { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. val aggIndex = numClasses * numBins * numFeatures * nodeIndex + numClasses * numBins * featureIndex + - numClasses * arr(arrIndex).toInt + - label.toInt + numClasses * treePoint.binnedFeatures(featureIndex) + + treePoint.label.toInt agg(aggIndex) += 1 } @@ -763,8 +652,8 @@ object DecisionTree extends Serializable with Logging { * where [bins] ranges over all bins. * Updates left or right side of aggregate depending on split. * - * @param arr arr(0) = label. - * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category) + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @param treePoint Data point being aggregated. * @param agg Indexed by (left/right, node, feature, bin, label) * where label is the least significant bit. * The left/right specifier is a 0/1 index indicating left/right child info. @@ -773,21 +662,18 @@ object DecisionTree extends Serializable with Logging { def updateBinForUnorderedFeature( nodeIndex: Int, featureIndex: Int, - arr: Array[Double], - label: Double, + treePoint: TreePoint, agg: Array[Double], rightChildShift: Int): Unit = { - // Find the bin index for this feature. - val arrIndex = 1 + numFeatures * nodeIndex + featureIndex - val featureValue = arr(arrIndex).toInt + val featureValue = treePoint.binnedFeatures(featureIndex) // Update the left or right count for one bin. val aggShift = numClasses * numBins * numFeatures * nodeIndex + numClasses * numBins * featureIndex + - label.toInt + treePoint.label.toInt // Find all matching bins and increment their values - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + val featureCategories = metadata.featureArity(featureIndex) + val numCategoricalBins = (1 << featureCategories - 1) - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { val aggIndex = aggShift + binIndex * numClasses @@ -803,80 +689,51 @@ object DecisionTree extends Serializable with Logging { /** * Helper for binSeqOp. * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). * @param agg Array storing aggregate calculation, of size: * numClasses * numBins * numFeatures * numNodes. * Indexed by (node, feature, bin, label) where label is the least significant bit. + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ - def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - featureIndex += 1 - } - } - nodeIndex += 1 + def binaryOrNotCategoricalBinSeqOp( + agg: Array[Double], + treePoint: TreePoint, + nodeIndex: Int): Unit = { + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) + featureIndex += 1 } } + val rightChildShift = numClasses * numBins * numFeatures * numNodes + /** * Helper for binSeqOp. * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). - * For ordered features, - * arr(1 + featureIndex + nodeIndex * numFeatures) = bin index. - * For unordered features, - * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category). * @param agg Array storing aggregate calculation. * For ordered features, this is of size: * numClasses * numBins * numFeatures * numNodes. * For unordered features, this is of size: * 2 * numClasses * numBins * numFeatures * numNodes. + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ - def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - val rightChildShift = numClasses * numBins * numFeatures * numNodes - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } else { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isSpaceSufficientForAllCategoricalSplits) { - updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, - rightChildShift) - } else { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } - } - featureIndex += 1 - } + def multiclassWithCategoricalBinSeqOp( + agg: Array[Double], + treePoint: TreePoint, + nodeIndex: Int): Unit = { + val label = treePoint.label + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (metadata.isUnordered(featureIndex)) { + updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift) + } else { + updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) } - nodeIndex += 1 + featureIndex += 1 } } @@ -887,36 +744,25 @@ object DecisionTree extends Serializable with Logging { * * @param agg Array storing aggregate calculation, updated by this function. * Size: 3 * numBins * numFeatures * numNodes - * @param arr Bin mapping from findBinsForLevel. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). * @return agg */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update count, sum, and sum^2 for one bin. - val aggShift = 3 * numBins * numFeatures * nodeIndex - val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 - agg(aggIndex) = agg(aggIndex) + 1 - agg(aggIndex + 1) = agg(aggIndex + 1) + label - agg(aggIndex + 2) = agg(aggIndex + 2) + label*label - featureIndex += 1 - } - } - nodeIndex += 1 + def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = { + val label = treePoint.label + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Update count, sum, and sum^2 for one bin. + val binIndex = treePoint.binnedFeatures(featureIndex) + val aggIndex = + 3 * numBins * numFeatures * nodeIndex + + 3 * numBins * featureIndex + + 3 * binIndex + agg(aggIndex) += 1 + agg(aggIndex + 1) += label + agg(aggIndex + 2) += label * label + featureIndex += 1 } } @@ -935,26 +781,30 @@ object DecisionTree extends Serializable with Logging { * 2 * numClasses * numBins * numFeatures * numNodes for unordered features. * Size for regression: * 3 * numBins * numFeatures * numNodes. - * @param arr Bin mapping from findBinsForLevel. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. * @return agg */ - def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { - strategy.algo match { - case Classification => - if(isMulticlassClassificationWithCategoricalFeatures) { - multiclassWithCategoricalBinSeqOp(arr, agg) + def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = { + val nodeIndex = treePointToNodeIndex(treePoint) + // If the example does not reach this level, then nodeIndex < 0. + // If the example reaches this level but is handled in a different group, + // then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group). + if (nodeIndex >= 0 && nodeIndex < numNodes) { + if (metadata.isClassification) { + if (isMulticlassWithCategoricalFeatures) { + multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex) } else { - binaryOrNotCategoricalBinSeqOp(arr, agg) + binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex) } - case Regression => regressionBinSeqOp(arr, agg) + } else { + regressionBinSeqOp(agg, treePoint, nodeIndex) + } } agg } // Calculate bin aggregate length for classification or regression. - val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses, - isMulticlassClassificationWithCategoricalFeatures, strategy.algo) + val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins) logDebug("binAggregateLength = " + binAggregateLength) /** @@ -974,135 +824,134 @@ object DecisionTree extends Serializable with Logging { } // Calculate bin aggregates. + timer.start("aggregation") val binAggregates = { - binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) } + timer.stop("aggregation") logDebug("binAggregates.length = " + binAggregates.length) /** - * Calculates the information gain for all splits based upon left/right split aggregates. - * @param leftNodeAgg left node aggregates - * @param featureIndex feature index - * @param splitIndex split index - * @param rightNodeAgg right node aggregate + * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * @param leftNodeAgg left node aggregates for this (feature, split) + * @param rightNodeAgg right node aggregate for this (feature, split) * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ def calculateGainForSplit( - leftNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int, - splitIndex: Int, - rightNodeAgg: Array[Array[Array[Double]]], + leftNodeAgg: Array[Double], + rightNodeAgg: Array[Double], topImpurity: Double): InformationGainStats = { - strategy.algo match { - case Classification => - val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) - val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) - val leftTotalCount = leftCounts.sum - val rightTotalCount = rightCounts.sum - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val rootNodeCounts = new Array[Double](numClasses) - var classIndex = 0 - while (classIndex < numClasses) { - rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex) - classIndex += 1 - } - strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) + if (metadata.isClassification) { + val leftTotalCount = leftNodeAgg.sum + val rightTotalCount = rightNodeAgg.sum + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val rootNodeCounts = new Array[Double](numClasses) + var classIndex = 0 + while (classIndex < numClasses) { + rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex) + classIndex += 1 } + metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) } + } - val totalCount = leftTotalCount + rightTotalCount - if (totalCount == 0) { - // Return arbitrary prediction. - return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) - } + val totalCount = leftTotalCount + rightTotalCount + if (totalCount == 0) { + // Return arbitrary prediction. + return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } - // Sum of count for each label - val leftRightCounts: Array[Double] = - leftCounts.zip(rightCounts).map { case (leftCount, rightCount) => - leftCount + rightCount - } + // Sum of count for each label + val leftrightNodeAgg: Array[Double] = + leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) => + leftCount + rightCount + } - def indexOfLargestArrayElement(array: Array[Double]): Int = { - val result = array.foldLeft(-1, Double.MinValue, 0) { - case ((maxIndex, maxValue, currentIndex), currentValue) => - if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1) - else (maxIndex, maxValue, currentIndex + 1) - } - if (result._1 < 0) 0 else result._1 + def indexOfLargestArrayElement(array: Array[Double]): Int = { + val result = array.foldLeft(-1, Double.MinValue, 0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if (currentValue > maxValue) { + (currentIndex, currentValue, currentIndex + 1) + } else { + (maxIndex, maxValue, currentIndex + 1) + } } + if (result._1 < 0) { + throw new RuntimeException("DecisionTree internal error:" + + " calculateGainForSplit failed in indexOfLargestArrayElement") + } + result._1 + } - val predict = indexOfLargestArrayElement(leftRightCounts) - val prob = leftRightCounts(predict) / totalCount + val predict = indexOfLargestArrayElement(leftrightNodeAgg) + val prob = leftrightNodeAgg(predict) / totalCount - val leftImpurity = if (leftTotalCount == 0) { - topImpurity - } else { - strategy.impurity.calculate(leftCounts, leftTotalCount) - } - val rightImpurity = if (rightTotalCount == 0) { + val leftImpurity = if (leftTotalCount == 0) { + topImpurity + } else { + metadata.impurity.calculate(leftNodeAgg, leftTotalCount) + } + val rightImpurity = if (rightTotalCount == 0) { + topImpurity + } else { + metadata.impurity.calculate(rightNodeAgg, rightTotalCount) + } + + val leftWeight = leftTotalCount / totalCount + val rightWeight = rightTotalCount / totalCount + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + + } else { + // Regression + + val leftCount = leftNodeAgg(0) + val leftSum = leftNodeAgg(1) + val leftSumSquares = leftNodeAgg(2) + + val rightCount = rightNodeAgg(0) + val rightSum = rightNodeAgg(1) + val rightSumSquares = rightNodeAgg(2) + + val impurity = { + if (level > 0) { topImpurity } else { - strategy.impurity.calculate(rightCounts, rightTotalCount) - } - - val leftWeight = leftTotalCount / totalCount - val rightWeight = rightTotalCount / totalCount - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) - case Regression => - val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) - val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) - val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2) - - val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0) - val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1) - val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2) - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val count = leftCount + rightCount - val sum = leftSum + rightSum - val sumSquares = leftSumSquares + rightSumSquares - strategy.impurity.calculate(count, sum, sumSquares) - } + // Calculate impurity for root node. + val count = leftCount + rightCount + val sum = leftSum + rightSum + val sumSquares = leftSumSquares + rightSumSquares + metadata.impurity.calculate(count, sum, sumSquares) } + } - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum / rightCount) - } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity ,topImpurity, - Double.MinValue, leftSum / leftCount) - } + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, + rightSum / rightCount) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity, topImpurity, + Double.MinValue, leftSum / leftCount) + } - val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) - val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) + val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares) + val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares) - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - val predict = (leftSum + rightSum) / (leftCount + rightCount) - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + val predict = (leftSum + rightSum) / (leftCount + rightCount) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) } } @@ -1125,6 +974,19 @@ object DecisionTree extends Serializable with Logging { binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + /** + * The input binData is indexed as (feature, bin, class). + * This computes cumulative sums over splits. + * Each (feature, class) pair is handled separately. + * Note: numSplits = numBins - 1. + * @param leftNodeAgg Each (feature, class) slice is an array over splits. + * Element i (i = 0, ..., numSplits - 2) is set to be + * the cumulative sum (from left) over binData for bins 0, ..., i. + * @param rightNodeAgg Each (feature, class) slice is an array over splits. + * Element i (i = 1, ..., numSplits - 1) is set to be + * the cumulative sum (from right) over binData for bins + * numBins - 1, ..., numBins - 1 - i. + */ def findAggForOrderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], @@ -1229,45 +1091,32 @@ object DecisionTree extends Serializable with Logging { } } - strategy.algo match { - case Classification => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (isMulticlassClassificationWithCategoricalFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isSpaceSufficientForAllCategoricalSplits) { - findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - } - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - featureIndex += 1 - } - - (leftNodeAgg, rightNodeAgg) - case Regression => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) - featureIndex += 1 + if (metadata.isClassification) { + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (metadata.isUnordered(featureIndex)) { + findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } - (leftNodeAgg, rightNodeAgg) + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + } else { + // Regression + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) } } @@ -1280,15 +1129,38 @@ object DecisionTree extends Serializable with Logging { nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) - for (featureIndex <- 0 until numFeatures) { - for (splitIndex <- 0 until numBins - 1) { - gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, - splitIndex, rightNodeAgg, nodeImpurity) + var featureIndex = 0 + while (featureIndex < numFeatures) { + val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplitsForFeature) { + gains(featureIndex)(splitIndex) = + calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex), + rightNodeAgg(featureIndex)(splitIndex), nodeImpurity) + splitIndex += 1 } + featureIndex += 1 } gains } + /** + * Get the number of splits for a feature. + */ + def getNumSplitsForFeature(featureIndex: Int): Int = { + if (metadata.isContinuous(featureIndex)) { + numBins - 1 + } else { + // Categorical feature + val featureCategories = metadata.featureArity(featureIndex) + if (metadata.isUnordered(featureIndex)) { + (1 << featureCategories - 1) - 1 + } else { + featureCategories + } + } + } + /** * Find the best split for a node. * @param binData Bin data slice for this node, given by getBinDataForNode. @@ -1307,7 +1179,7 @@ object DecisionTree extends Serializable with Logging { // Calculate gains for all splits. val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - val (bestFeatureIndex,bestSplitIndex, gainStats) = { + val (bestFeatureIndex, bestSplitIndex, gainStats) = { // Initialize with infeasible values. var bestFeatureIndex = Int.MinValue var bestSplitIndex = Int.MinValue @@ -1317,22 +1189,8 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 - val maxSplitIndex: Double = { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - numBins - 1 - } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { - math.pow(2.0, featureCategories - 1).toInt - 1 - } else { // Binary classification - featureCategories - } - } - } - while (splitIndex < maxSplitIndex) { + val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + while (splitIndex < numSplitsForFeature) { val gainStats = gains(featureIndex)(splitIndex) if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats @@ -1356,38 +1214,39 @@ object DecisionTree extends Serializable with Logging { * Get bin data for one node. */ def getBinDataForNode(node: Int): Array[Double] = { - strategy.algo match { - case Classification => - if (isMulticlassClassificationWithCategoricalFeatures) { - val shift = numClasses * node * numBins * numFeatures - val rightChildShift = numClasses * numBins * numFeatures * numNodes - val binsForNode = { - val leftChildData - = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - val rightChildData - = binAggregates.slice(rightChildShift + shift, - rightChildShift + shift + numClasses * numBins * numFeatures) - leftChildData ++ rightChildData - } - binsForNode - } else { - val shift = numClasses * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - binsForNode + if (metadata.isClassification) { + if (isMulticlassWithCategoricalFeatures) { + val shift = numClasses * node * numBins * numFeatures + val rightChildShift = numClasses * numBins * numFeatures * numNodes + val binsForNode = { + val leftChildData + = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + val rightChildData + = binAggregates.slice(rightChildShift + shift, + rightChildShift + shift + numClasses * numBins * numFeatures) + leftChildData ++ rightChildData } - case Regression => - val shift = 3 * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) binsForNode + } else { + val shift = numClasses * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + binsForNode + } + } else { + // Regression + val shift = 3 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) + binsForNode } } // Calculate best splits for all nodes at a given level + timer.start("chooseSplits") val bestSplits = new Array[(Split, InformationGainStats)](numNodes) // Iterating over all nodes at this level var node = 0 while (node < numNodes) { - val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift + val nodeImpurityIndex = (1 << level) - 1 + node + groupShift val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) @@ -1395,6 +1254,8 @@ object DecisionTree extends Serializable with Logging { bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) node += 1 } + timer.stop("chooseSplits") + bestSplits } @@ -1403,20 +1264,15 @@ object DecisionTree extends Serializable with Logging { * * @param numBins Number of bins = 1 + number of possible splits. */ - private def getElementsPerNode( - numFeatures: Int, - numBins: Int, - numClasses: Int, - isMulticlassClassificationWithCategoricalFeatures: Boolean, - algo: Algo): Int = { - algo match { - case Classification => - if (isMulticlassClassificationWithCategoricalFeatures) { - 2 * numClasses * numBins * numFeatures - } else { - numClasses * numBins * numFeatures - } - case Regression => 3 * numBins * numFeatures + private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = { + if (metadata.isClassification) { + if (metadata.isMulticlassWithCategoricalFeatures) { + 2 * metadata.numClasses * numBins * metadata.numFeatures + } else { + metadata.numClasses * numBins * metadata.numFeatures + } + } else { + 3 * numBins * metadata.numFeatures } } @@ -1435,16 +1291,15 @@ object DecisionTree extends Serializable with Logging { * For multiclass classification with a low-arity feature * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), * the feature is split based on subsets of categories. - * There are math.pow(2, maxFeatureValue - 1) - 1 splits. + * There are (1 << maxFeatureValue - 1) - 1 splits. * (b) "ordered features" * For regression and binary classification, * and for multiclass classification with a high-arity feature, * there is one bin per category. * * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree - * @return A tuple of (splits,bins). + * @param metadata Learning and dataset metadata + * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] * of size (numFeatures, numBins - 1). * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] @@ -1452,19 +1307,18 @@ object DecisionTree extends Serializable with Logging { */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], - strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() // Find the number of features by looking at the first sample val numFeatures = input.take(1)(0).features.size - val maxBins = strategy.maxBins + val maxBins = metadata.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) - val isMulticlassClassification = strategy.isMulticlassClassification - logDebug("isMulticlassClassification = " + isMulticlassClassification) - + val isMulticlass = metadata.isMulticlass + logDebug("isMulticlass = " + isMulticlass) /* * Ensure numBins is always greater than the categories. For multiclass classification, @@ -1476,13 +1330,12 @@ object DecisionTree extends Serializable with Logging { * by the number of training examples. * TODO: Allow this case, where we simply will know nothing about some categories. */ - if (strategy.categoricalFeaturesInfo.size > 0) { - val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 + if (metadata.featureArity.size > 0) { + val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2 require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + "in categorical features") } - // Calculate the number of sample for approximate quantile calculation. val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 @@ -1496,7 +1349,7 @@ object DecisionTree extends Serializable with Logging { val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) - strategy.quantileCalculationStrategy match { + metadata.quantileStrategy match { case Sort => val splits = Array.ofDim[Split](numFeatures, numBins - 1) val bins = Array.ofDim[Bin](numFeatures, numBins) @@ -1507,7 +1360,7 @@ object DecisionTree extends Serializable with Logging { var featureIndex = 0 while (featureIndex < numFeatures) { // Check whether the feature is continuous. - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = metadata.isContinuous(featureIndex) if (isFeatureContinuous) { val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted val stride: Double = numSamples.toDouble / numBins @@ -1520,18 +1373,14 @@ object DecisionTree extends Serializable with Logging { splits(featureIndex)(index) = split } } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + val featureCategories = metadata.featureArity(featureIndex) // Use different bin/split calculation strategy for categorical features in multiclass // classification that satisfy the space constraint. - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - if (isUnorderedFeature) { + if (metadata.isUnordered(featureIndex)) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 - while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { + while (index < (1 << featureCategories - 1) - 1) { val categories: List[Double] = extractMultiClassCategories(index + 1, featureCategories) splits(featureIndex)(index) @@ -1561,7 +1410,7 @@ object DecisionTree extends Serializable with Logging { * centroidForCategories is a mapping: category (for the given feature) --> centroid */ val centroidForCategories = { - if (isMulticlassClassification) { + if (isMulticlass) { // For categorical variables in multiclass classification, // each bin is a category. The bins are sorted and they // are ordered by calculating the impurity of their corresponding labels. @@ -1569,7 +1418,7 @@ object DecisionTree extends Serializable with Logging { .groupBy(_._1) .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum))) + .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum))) } else { // regression or binary classification // For categorical variables in regression and binary classification, // each bin is a category. The bins are sorted and they @@ -1621,7 +1470,7 @@ object DecisionTree extends Serializable with Logging { // Find all bins. featureIndex = 0 while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = metadata.isContinuous(featureIndex) if (isFeatureContinuous) { // Bins for categorical variables are already assigned. bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) @@ -1635,7 +1484,7 @@ object DecisionTree extends Serializable with Logging { } featureIndex += 1 } - (splits,bins) + (splits, bins) case MinMax => throw new UnsupportedOperationException("minmax not supported yet.") case ApproxHist => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index f31a503608b22..cfc8192a85abd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -27,22 +27,30 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ /** * :: Experimental :: * Stores all the configuration options for tree construction - * @param algo classification or regression - * @param impurity criterion used for information gain calculation + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @param impurity Criterion used for information gain calculation. + * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], + * [[org.apache.spark.mllib.tree.impurity.Entropy]]. + * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClassesForClassification number of classes for classification. Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param numClassesForClassification Number of classes for classification. + * (Ignored for regression.) + * Default value is 2 (binary classification). + * @param maxBins Maximum number of bins used for discretizing continuous features and + * for choosing how to split on features at each node. + * More bins give higher granularity. + * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: + * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the * number of discrete values they take. For example, an entry (n -> * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. - * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. - * */ @Experimental class Strategy ( @@ -64,20 +72,7 @@ class Strategy ( = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** - * Java-friendly constructor. - * - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClassesForClassification number of classes for classification. Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, an entry - * (n -> k) implies the feature n is categorical with k categories - * 0, 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. + * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] */ def this( algo: Algo, @@ -90,6 +85,10 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + /** + * Check validity of parameters. + * Throws exception if invalid. + */ private[tree] def assertValid(): Unit = { algo match { case Classification => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala new file mode 100644 index 0000000000000..d9eda354dc986 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.rdd.RDD + + +/** + * Learning and dataset metadata for DecisionTree. + * + * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. + * For regression: fixed at 0 (no meaning). + * @param featureArity Map: categorical feature index --> arity. + * I.e., the feature takes values in {0, ..., arity - 1}. + */ +private[tree] class DecisionTreeMetadata( + val numFeatures: Int, + val numExamples: Long, + val numClasses: Int, + val maxBins: Int, + val featureArity: Map[Int, Int], + val unorderedFeatures: Set[Int], + val impurity: Impurity, + val quantileStrategy: QuantileStrategy) extends Serializable { + + def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) + + def isClassification: Boolean = numClasses >= 2 + + def isMulticlass: Boolean = numClasses > 2 + + def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0) + + def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex) + + def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) + +} + +private[tree] object DecisionTreeMetadata { + + def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = { + + val numFeatures = input.take(1)(0).features.size + val numExamples = input.count() + val numClasses = strategy.algo match { + case Classification => strategy.numClassesForClassification + case Regression => 0 + } + + val maxBins = math.min(strategy.maxBins, numExamples).toInt + val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0) + + val unorderedFeatures = new mutable.HashSet[Int]() + if (numClasses > 2) { + strategy.categoricalFeaturesInfo.foreach { case (f, k) => + if (k - 1 < log2MaxBinsp1) { + // Note: The above check is equivalent to checking: + // numUnorderedBins = (1 << k - 1) - 1 < maxBins + unorderedFeatures.add(f) + } else { + // TODO: Allow this case, where we simply will know nothing about some categories? + require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " + + s"in categorical features (>= $k)") + } + } + } else { + strategy.categoricalFeaturesInfo.foreach { case (f, k) => + require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " + + s"in categorical features (>= $k)") + } + } + + new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins, + strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, + strategy.impurity, strategy.quantileCalculationStrategy) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala new file mode 100644 index 0000000000000..d215d68c4279e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable.{HashMap => MutableHashMap} + +import org.apache.spark.annotation.Experimental + +/** + * Time tracker implementation which holds labeled timers. + */ +@Experimental +private[tree] class TimeTracker extends Serializable { + + private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + /** + * Starts a new timer, or re-starts a stopped timer. + */ + def start(timerLabel: String): Unit = { + val currentTime = System.nanoTime() + if (starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + + s" timerLabel = $timerLabel before that timer was stopped.") + } + starts(timerLabel) = currentTime + } + + /** + * Stops a timer and returns the elapsed time in seconds. + */ + def stop(timerLabel: String): Double = { + val currentTime = System.nanoTime() + if (!starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + + s" timerLabel = $timerLabel, but that timer was not started.") + } + val elapsed = currentTime - starts(timerLabel) + starts.remove(timerLabel) + if (totals.contains(timerLabel)) { + totals(timerLabel) += elapsed + } else { + totals(timerLabel) = elapsed + } + elapsed / 1e9 + } + + /** + * Print all timing results in seconds. + */ + override def toString: String = { + totals.map { case (label, elapsed) => + s" $label: ${elapsed / 1e9}" + }.mkString("\n") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala new file mode 100644 index 0000000000000..170e43e222083 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.Bin +import org.apache.spark.rdd.RDD + + +/** + * Internal representation of LabeledPoint for DecisionTree. + * This bins feature values based on a subsampled of data as follows: + * (a) Continuous features are binned into ranges. + * (b) Unordered categorical features are binned based on subsets of feature values. + * "Unordered categorical features" are categorical features with low arity used in + * multiclass classification. + * (c) Ordered categorical features are binned based on feature values. + * "Ordered categorical features" are categorical features with high arity, + * or any categorical feature used in regression or binary classification. + * + * @param label Label from LabeledPoint + * @param binnedFeatures Binned feature values. + * Same length as LabeledPoint.features, but values are bin indices. + */ +private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) + extends Serializable { +} + +private[tree] object TreePoint { + + /** + * Convert an input dataset into its TreePoint representation, + * binning feature values in preparation for DecisionTree training. + * @param input Input dataset. + * @param bins Bins for features, of size (numFeatures, numBins). + * @param metadata Learning and dataset metadata + * @return TreePoint dataset representation + */ + def convertToTreeRDD( + input: RDD[LabeledPoint], + bins: Array[Array[Bin]], + metadata: DecisionTreeMetadata): RDD[TreePoint] = { + input.map { x => + TreePoint.labeledPointToTreePoint(x, bins, metadata) + } + } + + /** + * Convert one LabeledPoint into its TreePoint representation. + * @param bins Bins for features, of size (numFeatures, numBins). + */ + private def labeledPointToTreePoint( + labeledPoint: LabeledPoint, + bins: Array[Array[Bin]], + metadata: DecisionTreeMetadata): TreePoint = { + + val numFeatures = labeledPoint.features.size + val numBins = bins(0).size + val arr = new Array[Int](numFeatures) + var featureIndex = 0 + while (featureIndex < numFeatures) { + arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex), + metadata.isUnordered(featureIndex), bins, metadata.featureArity) + featureIndex += 1 + } + + new TreePoint(labeledPoint.label, arr) + } + + /** + * Find bin for one (labeledPoint, feature). + * + * @param isUnorderedFeature (only applies if feature is categorical) + * @param bins Bins for features, of size (numFeatures, numBins). + * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity + */ + private def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, + isUnorderedFeature: Boolean, + bins: Array[Array[Bin]], + categoricalFeaturesInfo: Map[Int, Int]): Int = { + + /** + * Binary search helper method for continuous feature. + */ + def binarySearchForBins(): Int = { + val binForFeatures = bins(featureIndex) + val feature = labeledPoint.features(featureIndex) + var left = 0 + var right = binForFeatures.length - 1 + while (left <= right) { + val mid = left + (right - left) / 2 + val bin = binForFeatures(mid) + val lowThreshold = bin.lowSplit.threshold + val highThreshold = bin.highSplit.threshold + if ((lowThreshold < feature) && (highThreshold >= feature)) { + return mid + } else if (lowThreshold >= feature) { + right = mid - 1 + } else { + left = mid + 1 + } + } + -1 + } + + /** + * Sequential search helper method to find bin for categorical feature in multiclass + * classification. The category is returned since each category can belong to multiple + * splits. The actual left/right child allocation per split is performed in the + * sequential phase of the bin aggregate operation. + */ + def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { + labeledPoint.features(featureIndex).toInt + } + + /** + * Sequential search helper method to find bin for categorical feature + * (for classification and regression). + */ + def sequentialBinSearchForOrderedCategoricalFeature(): Int = { + val featureCategories = categoricalFeaturesInfo(featureIndex) + val featureValue = labeledPoint.features(featureIndex) + var binIndex = 0 + while (binIndex < featureCategories) { + val bin = bins(featureIndex)(binIndex) + val categories = bin.highSplit.categories + if (categories.contains(featureValue)) { + return binIndex + } + binIndex += 1 + } + if (featureValue < 0 || featureValue >= featureCategories) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureCategories - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) + } + -1 + } + + if (isFeatureContinuous) { + // Perform binary search for finding bin for continuous features. + val binIndex = binarySearchForBins() + if (binIndex == -1) { + throw new RuntimeException("No bin was found for continuous feature." + + " This error can occur when given invalid data values (such as NaN)." + + s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") + } + binIndex + } else { + // Perform sequential search to find bin for categorical features. + val binIndex = if (isUnorderedFeature) { + sequentialBinSearchForUnorderedCategoricalFeatureInClassification() + } else { + sequentialBinSearchForOrderedCategoricalFeature() + } + if (binIndex == -1) { + throw new RuntimeException("No bin was found for categorical feature." + + " This error can occur when given invalid data values (such as NaN)." + + s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") + } + binIndex + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index c89c1e371a40e..af35d88f713e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType._ /** - * Used for "binning" the features bins for faster best split calculation. For a continuous - * feature, a bin is determined by a low and a high "split". For a categorical feature, - * the a bin is determined using a single label value (category). + * Used for "binning" the features bins for faster best split calculation. + * + * For a continuous feature, the bin is determined by a low and a high split, + * where an example with featureValue falls into the bin s.t. + * lowSplit.threshold < featureValue <= highSplit.threshold. + * + * For ordered categorical features, there is a 1-1-1 correspondence between + * bins, splits, and feature values. The bin is determined by category/feature value. + * However, the bins are not necessarily ordered by feature value; + * they are ordered using impurity. + * For unordered categorical features, there is a 1-1 correspondence between bins, splits, + * where bins and splits correspond to subsets of feature values (in highSplit.categories). + * * @param lowSplit signifying the lower threshold for the continuous feature to be * accepted in the bin * @param highSplit signifying the upper threshold for the continuous feature to be * accepted in the bin * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin for binary classification + * @param category categorical label value accepted in the bin for ordered features */ private[tree] case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 3d3406b5d5f22..0594fd0749d21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @return Double prediction from the trained model */ def predict(features: Vector): Double = { - topNode.predictIfLeaf(features) + topNode.predict(features) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 944f11c2c2e4f..0eee6262781c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -69,24 +69,24 @@ class Node ( /** * predict value if node is not leaf - * @param feature feature value + * @param features feature value * @return predicted value */ - def predictIfLeaf(feature: Vector) : Double = { + def predict(features: Vector) : Double = { if (isLeaf) { predict } else{ if (split.get.featureType == Continuous) { - if (feature(split.get.feature) <= split.get.threshold) { - leftNode.get.predictIfLeaf(feature) + if (features(split.get.feature) <= split.get.threshold) { + leftNode.get.predict(features) } else { - rightNode.get.predictIfLeaf(feature) + rightNode.get.predict(features) } } else { - if (split.get.categories.contains(feature(split.get.feature))) { - leftNode.get.predictIfLeaf(feature) + if (split.get.categories.contains(features(split.get.feature))) { + leftNode.get.predict(features) } else { - rightNode.get.predictIfLeaf(feature) + rightNode.get.predict(features) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index d7ffd386c05ee..50fb48b40de3d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType * :: DeveloperApi :: * Split applied to a feature * @param feature feature index - * @param threshold threshold for continuous feature + * @param threshold Threshold for continuous feature. + * Split left if feature <= threshold, else right. * @param featureType type of feature -- categorical or continuous - * @param categories accepted values for categorical variables + * @param categories Split left if categorical feature value is in this set, else right. */ @DeveloperApi case class Split( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index f4cce86a65ba7..ca35100aa99c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD import org.apache.spark.util.random.BernoulliSampler -import org.apache.spark.mllib.regression.{LabeledPointParser, LabeledPoint} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext @@ -185,7 +185,7 @@ object MLUtils { * @return labeled points stored as an RDD[LabeledPoint] */ def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] = - sc.textFile(path, minPartitions).map(LabeledPointParser.parse) + sc.textFile(path, minPartitions).map(LabeledPoint.parse) /** * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of @@ -194,19 +194,6 @@ object MLUtils { def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) - /** - * Loads streaming labeled points from a stream of text files - * where points are in the same format as used in `RDD[LabeledPoint].saveAsTextFile`. - * See `StreamingContext.textFileStream` for more details on how to - * generate a stream from files - * - * @param ssc Streaming context - * @param dir Directory path in any Hadoop-supported file system URI - * @return Labeled points stored as a DStream[LabeledPoint] - */ - def loadStreamingLabeledPoints(ssc: StreamingContext, dir: String): DStream[LabeledPoint] = - ssc.textFileStream(dir).map(LabeledPointParser.parse) - /** * Load labeled data from a file. The data format used here is * , ... diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 2289c6cdc19de..862178694a50e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -185,6 +185,63 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("numerical stability of scaling features using logistic regression with LBFGS") { + /** + * If we rescale the features, the condition number will be changed so the convergence rate + * and the solution will not equal to the original solution multiple by the scaling factor + * which it should be. + * + * However, since in the LogisticRegressionWithLBFGS, we standardize the training dataset first, + * no matter how we multiple a scaling factor into the dataset, the convergence rate should be + * the same, and the solution should equal to the original solution multiple by the scaling + * factor. + */ + + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialWeights = Vectors.dense(0.0) + + val testRDD1 = sc.parallelize(testData, 2) + + val testRDD2 = sc.parallelize( + testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E3))), 2) + + val testRDD3 = sc.parallelize( + testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E6))), 2) + + testRDD1.cache() + testRDD2.cache() + testRDD3.cache() + + val lrA = new LogisticRegressionWithLBFGS().setIntercept(true) + val lrB = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + + val modelA1 = lrA.run(testRDD1, initialWeights) + val modelA2 = lrA.run(testRDD2, initialWeights) + val modelA3 = lrA.run(testRDD3, initialWeights) + + val modelB1 = lrB.run(testRDD1, initialWeights) + val modelB2 = lrB.run(testRDD2, initialWeights) + val modelB3 = lrB.run(testRDD3, initialWeights) + + // For model trained with feature standardization, the weights should + // be the same in the scaled space. Note that the weights here are already + // in the original space, we transform back to scaled space to compare. + assert(modelA1.weights(0) ~== modelA2.weights(0) * 1.0E3 absTol 0.01) + assert(modelA1.weights(0) ~== modelA3.weights(0) * 1.0E6 absTol 0.01) + + // Training data with different scales without feature standardization + // will not yield the same result in the scaled space due to poor + // convergence rate. + assert(modelB1.weights(0) !~== modelB2.weights(0) * 1.0E3 absTol 0.1) + assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1) + } + } class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { @@ -215,8 +272,9 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont }.cache() // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. - val model = - (new LogisticRegressionWithLBFGS().setIntercept(true).setNumIterations(2)).run(points) + val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + lr.optimizer.setNumIterations(2) + val model = lr.run(points) val predictions = model.predict(points.map(_.features)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 5f4c24115ac80..ccba004baa007 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -55,7 +55,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray) val convergenceTol = 1e-12 - val maxNumIterations = 10 + val numIterations = 10 val (_, loss) = LBFGS.runLBFGS( dataRDD, @@ -63,7 +63,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { simpleUpdater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -99,7 +99,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { // Prepare another non-zero weights to compare the loss in the first iteration. val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12) val convergenceTol = 1e-12 - val maxNumIterations = 10 + val numIterations = 10 val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS( dataRDD, @@ -107,7 +107,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -140,10 +140,10 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { /** * For the first run, we set the convergenceTol to 0.0, so that the algorithm will - * run up to the maxNumIterations which is 8 here. + * run up to the numIterations which is 8 here. */ val initialWeightsWithIntercept = Vectors.dense(0.0, 0.0) - val maxNumIterations = 8 + val numIterations = 8 var convergenceTol = 0.0 val (_, lossLBFGS1) = LBFGS.runLBFGS( @@ -152,7 +152,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -167,7 +167,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -182,7 +182,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -200,12 +200,12 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { // Prepare another non-zero weights to compare the loss in the first iteration. val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12) val convergenceTol = 1e-12 - val maxNumIterations = 10 + val numIterations = 10 val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater) .setNumCorrections(numCorrections) .setConvergenceTol(convergenceTol) - .setMaxNumIterations(maxNumIterations) + .setNumIterations(numIterations) .setRegParam(regParam) val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept) @@ -241,7 +241,7 @@ class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext { val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater) .setNumCorrections(1) .setConvergenceTol(1e-12) - .setMaxNumIterations(1) + .setNumIterations(1) .setRegParam(1.0) val random = new Random(0) // If we serialize data directly in the task closure, the size of the serialized task would be diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala similarity index 88% rename from mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 96e0bc63b0fa4..c50b78bcbcc61 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.StatCounter * * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Serializable { +class RandomRDDsSuite extends FunSuite with LocalSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, @@ -113,18 +113,18 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri val poissonMean = 100.0 for (seed <- 0 until 5) { - val uniform = RandomRDDGenerators.uniformRDD(sc, size, numPartitions, seed) + val uniform = RandomRDDs.uniformRDD(sc, size, numPartitions, seed) testGeneratedRDD(uniform, size, numPartitions, 0.5, 1 / math.sqrt(12)) - val normal = RandomRDDGenerators.normalRDD(sc, size, numPartitions, seed) + val normal = RandomRDDs.normalRDD(sc, size, numPartitions, seed) testGeneratedRDD(normal, size, numPartitions, 0.0, 1.0) - val poisson = RandomRDDGenerators.poissonRDD(sc, poissonMean, size, numPartitions, seed) + val poisson = RandomRDDs.poissonRDD(sc, poissonMean, size, numPartitions, seed) testGeneratedRDD(poisson, size, numPartitions, poissonMean, math.sqrt(poissonMean), 0.1) } // mock distribution to check that partitions have unique seeds - val random = RandomRDDGenerators.randomRDD(sc, new MockDistro(), 1000L, 1000, 0L) + val random = RandomRDDs.randomRDD(sc, new MockDistro(), 1000L, 1000, 0L) assert(random.collect.size === random.collect.distinct.size) } @@ -135,13 +135,13 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri val poissonMean = 100.0 for (seed <- 0 until 5) { - val uniform = RandomRDDGenerators.uniformVectorRDD(sc, rows, cols, parts, seed) + val uniform = RandomRDDs.uniformVectorRDD(sc, rows, cols, parts, seed) testGeneratedVectorRDD(uniform, rows, cols, parts, 0.5, 1 / math.sqrt(12)) - val normal = RandomRDDGenerators.normalVectorRDD(sc, rows, cols, parts, seed) + val normal = RandomRDDs.normalVectorRDD(sc, rows, cols, parts, seed) testGeneratedVectorRDD(normal, rows, cols, parts, 0.0, 1.0) - val poisson = RandomRDDGenerators.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) + val poisson = RandomRDDs.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index d9308aaba6ee1..110c44a7193fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -28,12 +28,12 @@ class LabeledPointSuite extends FunSuite { LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0)))) points.foreach { p => - assert(p === LabeledPointParser.parse(p.toString)) + assert(p === LabeledPoint.parse(p.toString)) } } test("parse labeled points with v0.9 format") { - val point = LabeledPointParser.parse("1.0,1.0 0.0 -2.0") + val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0") assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index ed21f84472c9a..45e25eecf508e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -26,7 +26,7 @@ import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext, MLUtils} +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils @@ -55,7 +55,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { val numBatches = 10 val batchDuration = Milliseconds(1000) val ssc = new StreamingContext(sc, batchDuration) - val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString) + val data = ssc.textFileStream(testDir.toString).map(LabeledPoint.parse) val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0, 0.0)) .setStepSize(0.1) @@ -97,7 +97,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { val batchDuration = Milliseconds(2000) val ssc = new StreamingContext(sc, batchDuration) val numBatches = 5 - val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString) + val data = ssc.textFileStream(testDir.toString()).map(LabeledPoint.parse) val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0)) .setStepSize(0.1) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 5bd0521298c14..6de3840b3f198 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.mllib.stat +import java.util.Random + import org.scalatest.FunSuite +import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.test.ChiSqTest @@ -107,12 +110,13 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext { // labels: 1.0 (2 / 6), 0.0 (4 / 6) // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6) // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6) - val data = Array(new LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), - new LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), - new LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), - new LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), - new LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), - new LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) + val data = Seq( + LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), + LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), + LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), + LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) for (numParts <- List(2, 4, 6, 8)) { val chi = Statistics.chiSqTest(sc.parallelize(data, numParts)) val feature1 = chi(0) @@ -130,10 +134,25 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext { } // Test that the right number of results is returned - val numCols = 321 - val sparseData = Array(new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), - new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((200, 1.0))))) + val numCols = 1001 + val sparseData = Array( + new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), + new LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0))))) val chi = Statistics.chiSqTest(sc.parallelize(sparseData)) assert(chi.size === numCols) + assert(chi(1000) != null) // SPARK-3087 + + // Detect continous features or labels + val random = new Random(11L) + val continuousLabel = + Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) + intercept[SparkException] { + Statistics.chiSqTest(sc.parallelize(continuousLabel, 2)) + } + val continuousFeature = + Seq.fill(100000)(LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) + intercept[SparkException] { + Statistics.chiSqTest(sc.parallelize(continuousFeature, 2)) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 70ca7c8a266f2..2f36fd907772c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -21,11 +21,12 @@ import scala.collection.JavaConverters._ import org.scalatest.FunSuite -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} -import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.regression.LabeledPoint @@ -41,7 +42,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { prediction != expected.label } val accuracy = (input.length - numOffPredictions).toDouble / input.length - assert(accuracy >= requiredAccuracy) + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") } def validateRegressor( @@ -54,7 +56,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { err * err }.sum val mse = squaredError / input.length - assert(mse <= requiredMSE) + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") } test("split and bin calculation") { @@ -62,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -80,7 +83,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -160,7 +164,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // Check splits. @@ -277,7 +282,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -371,7 +377,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // 2^10 - 1 > 100, so categorical variables will be ordered @@ -426,9 +433,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -453,9 +462,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -491,7 +502,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -499,8 +511,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -513,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -521,8 +535,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -536,7 +551,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -544,8 +560,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -559,7 +576,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -567,8 +585,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -582,7 +601,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -590,13 +610,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1) - val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1) - val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter)) + // Train a 1-node model + val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100) + val modelOneNode = DecisionTree.train(rdd, strategyOneNode) + val nodes: Array[Node] = new Array[Node](7) + nodes(0) = modelOneNode.topNode + nodes(0).leftNode = None + nodes(0).rightNode = None + val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. - val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes, splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) @@ -604,8 +630,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. - val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, - filters, splits, bins, 0) + val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, + nodes, splits, bins, 0) assert(bestSplitsWithGroups.length === 2) assert(bestSplitsWithGroups(0)._2.gain > 0) assert(bestSplitsWithGroups(1)._2.gain > 0) @@ -620,18 +646,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) } - } test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -647,11 +674,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 2) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) @@ -678,19 +705,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification, with just enough bins") { val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + numClassesForClassification = 3, maxBins = maxBins, + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -705,17 +735,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -729,17 +761,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -752,13 +786,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 diff --git a/pom.xml b/pom.xml index 920912353fe9c..ef12c8f1a5c49 100644 --- a/pom.xml +++ b/pom.xml @@ -316,7 +316,7 @@ org.xerial.snappy snappy-java - 1.0.5 + 1.1.1.3 net.jpountz.lz4 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 6e72035f2c15b..300589394b96f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -61,6 +61,17 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.storage.MemoryStore.Entry") ) ++ + Seq( + // Serializer interface change. See SPARK-3045. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.DeserializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.Serializer"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializerInstance") + )++ Seq( // Renamed putValues -> putArray + putIterator ProblemFilters.exclude[MissingMethodProblem]( @@ -117,6 +128,14 @@ object MimaExcludes { ) ++ Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") + ) ++ + Seq( // synthetic methods generated in LabeledPoint + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString") + ) ++ + Seq ( // Scala 2.11 compatibility fix + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") ) case v if v.startsWith("1.0") => Seq( diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index f3e64989ed564..675a2fcd2ff4e 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -21,18 +21,16 @@ >>> b = sc.broadcast([1, 2, 3, 4, 5]) >>> b.value [1, 2, 3, 4, 5] - ->>> from pyspark.broadcast import _broadcastRegistry ->>> _broadcastRegistry[b.bid] = b ->>> from cPickle import dumps, loads ->>> loads(dumps(b)).value -[1, 2, 3, 4, 5] - >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] +>>> b.unpersist() >>> large_broadcast = sc.broadcast(list(range(10000))) """ +import os + +from pyspark.serializers import CompressedSerializer, PickleSerializer + # Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} @@ -52,17 +50,38 @@ class Broadcast(object): Access its value through C{.value}. """ - def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): + def __init__(self, bid, value, java_broadcast=None, + pickle_registry=None, path=None): """ Should not be called directly by users -- use L{SparkContext.broadcast()} instead. """ - self.value = value self.bid = bid + if path is None: + self.value = value self._jbroadcast = java_broadcast self._pickle_registry = pickle_registry + self.path = path + + def unpersist(self, blocking=False): + self._jbroadcast.unpersist(blocking) + os.unlink(self.path) def __reduce__(self): self._pickle_registry.add(self) return (_from_id, (self.bid, )) + + def __getattr__(self, item): + if item == 'value' and self.path is not None: + ser = CompressedSerializer(PickleSerializer()) + value = ser.load_stream(open(self.path)).next() + self.value = value + return value + + raise AttributeError(item) + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 4001ecab5ea00..a90870ed3a353 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer + PairDeserializer, CompressedSerializer from pyspark.storagelevel import StorageLevel from pyspark import rdd from pyspark.rdd import RDD @@ -566,13 +566,19 @@ def broadcast(self, value): """ Broadcast a read-only variable to the cluster, returning a L{Broadcast} - object for reading it in distributed functions. The variable will be - sent to each cluster only once. + object for reading it in distributed functions. The variable will + be sent to each cluster only once. + + :keep: Keep the `value` in driver or not. """ - pickleSer = PickleSerializer() - pickled = pickleSer.dumps(value) - jbroadcast = self._jsc.broadcast(bytearray(pickled)) - return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) + ser = CompressedSerializer(PickleSerializer()) + # pass large object by py4j is very slow and need much memory + tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) + ser.dump_stream([value], tempFile) + tempFile.close() + jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name) + return Broadcast(jbroadcast.id(), None, jbroadcast, + self._pickled_broadcast_vars, tempFile.name) def accumulator(self, value, accum_param=None): """ @@ -613,7 +619,7 @@ def addFile(self, path): >>> def func(iterator): ... with open(SparkFiles.get("test.txt")) as testFile: ... fileVal = int(testFile.readline()) - ... return [x * 100 for x in iterator] + ... return [x * fileVal for x in iterator] >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect() [100, 200, 300, 400] """ diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index eb496688b6eef..3f3b19053d32e 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -25,8 +25,7 @@ from pyspark.serializers import NoOpSerializer -class RandomRDDGenerators: - +class RandomRDDs: """ Generator methods for creating RDDs comprised of i.i.d samples from some distribution. @@ -40,17 +39,17 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use - C{RandomRDDGenerators.uniformRDD(sc, n, p, seed)\ + C{RandomRDDs.uniformRDD(sc, n, p, seed)\ .map(lambda v: a + (b - a) * v)} - >>> x = RandomRDDGenerators.uniformRDD(sc, 100).collect() + >>> x = RandomRDDs.uniformRDD(sc, 100).collect() >>> len(x) 100 >>> max(x) <= 1.0 and min(x) >= 0.0 True - >>> RandomRDDGenerators.uniformRDD(sc, 100, 4).getNumPartitions() + >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions() 4 - >>> parts = RandomRDDGenerators.uniformRDD(sc, 100, seed=4).getNumPartitions() + >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions() >>> parts == sc.defaultParallelism True """ @@ -66,10 +65,10 @@ def normalRDD(sc, size, numPartitions=None, seed=None): To transform the distribution in the generated RDD from standard normal to some other normal N(mean, sigma), use - C{RandomRDDGenerators.normal(sc, n, p, seed)\ + C{RandomRDDs.normal(sc, n, p, seed)\ .map(lambda v: mean + sigma * v)} - >>> x = RandomRDDGenerators.normalRDD(sc, 1000, seed=1L) + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) >>> stats = x.stats() >>> stats.count() 1000L @@ -89,7 +88,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): distribution with the input mean. >>> mean = 100.0 - >>> x = RandomRDDGenerators.poissonRDD(sc, mean, 1000, seed=1L) + >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=1L) >>> stats = x.stats() >>> stats.count() 1000L @@ -110,12 +109,12 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): from the uniform distribution on [0.0 1.0]. >>> import numpy as np - >>> mat = np.matrix(RandomRDDGenerators.uniformVectorRDD(sc, 10, 10).collect()) + >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) >>> mat.shape (10, 10) >>> mat.max() <= 1.0 and mat.min() >= 0.0 True - >>> RandomRDDGenerators.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() + >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() 4 """ jrdd = sc._jvm.PythonMLLibAPI() \ @@ -130,7 +129,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): from the standard normal distribution. >>> import numpy as np - >>> mat = np.matrix(RandomRDDGenerators.normalVectorRDD(sc, 100, 100, seed=1L).collect()) + >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) >>> mat.shape (100, 100) >>> abs(mat.mean() - 0.0) < 0.1 @@ -151,7 +150,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): >>> import numpy as np >>> mean = 100.0 - >>> rdd = RandomRDDGenerators.poissonVectorRDD(sc, mean, 100, 100, seed=1L) + >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) >>> mat = np.mat(rdd.collect()) >>> mat.shape (100, 100) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 3934bdda0a466..240381e5bae12 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -36,7 +36,7 @@ from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long + PickleSerializer, pack_long, CompressedSerializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -1810,7 +1810,8 @@ def _jrdd(self): self._jrdd_deserializer = NoOpSerializer() command = (self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer) - pickled_command = CloudPickleSerializer().dumps(command) + ser = CompressedSerializer(CloudPickleSerializer()) + pickled_command = ser.dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index df90cafb245bf..74870c0edcf99 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -67,6 +67,7 @@ import sys import types import collections +import zlib from pyspark import cloudpickle @@ -403,6 +404,22 @@ def loads(self, obj): raise ValueError("invalid sevialization type: %s" % _type) +class CompressedSerializer(FramedSerializer): + """ + compress the serialized data + """ + + def __init__(self, serializer): + FramedSerializer.__init__(self) + self.serializer = serializer + + def dumps(self, obj): + return zlib.compress(self.serializer.dumps(obj), 1) + + def loads(self, obj): + return self.serializer.loads(zlib.decompress(obj)) + + class UTF8Deserializer(Serializer): """ diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 95086a2258222..d4ca0cc8f336e 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1093,8 +1093,8 @@ def applySchema(self, rdd, schema): >>> sqlCtx.sql( ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + - ... "float + 1.1 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1...)] + ... "float + 1.5 as float FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 22b51110ed671..f1fece998cd54 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -323,6 +323,13 @@ def test_namedtuple_in_rdd(self): theDoes = self.sc.parallelize([jon, jane]) self.assertEquals([jon, jane], theDoes.collect()) + def test_large_broadcast(self): + N = 100000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 270MB + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEquals(N, m) + class TestIO(PySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2770f63059853..77a9c4a0e0677 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,7 +30,8 @@ from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ + CompressedSerializer pickleSer = PickleSerializer() @@ -65,12 +66,13 @@ def main(infile, outfile): # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) + ser = CompressedSerializer(pickleSer) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = pickleSer._read_with_length(infile) + value = ser._read_with_length(infile) _broadcastRegistry[bid] = Broadcast(bid, value) - command = pickleSer._read_with_length(infile) + command = ser._read_with_length(infile) (func, deserializer, serializer) = command init_time = time.time() iterator = deserializer.load_stream(infile) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 58d44e7923bee..830711a46a35b 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -77,28 +77,28 @@ org.apache.maven.plugins maven-jar-plugin - - - test-jar - - - - test-jar-on-compile - compile - - test-jar - - + + + test-jar + + + + test-jar-on-test-compile + test-compile + + test-jar + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5b398695bf560..de2d67ce82ff1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -78,7 +78,12 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin .build( new CacheLoader[InType, OutType]() { override def load(in: InType): OutType = globalLock.synchronized { - create(in) + val startTime = System.nanoTime() + val result = create(in) + val endTime = System.nanoTime() + def timeMs = (endTime - startTime).toDouble / 1000000 + logInfo(s"Code generated expression $in in $timeMs ms") + result } }) @@ -413,7 +418,19 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin """.children } - EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) + // Only inject debugging code if debugging is turned on. + val debugCode = + if (log.isDebugEnabled) { + val localLogger = log + val localLoggerTree = reify { localLogger } + q""" + $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm)) + """ :: Nil + } else { + Nil + } + + EvaluatedExpression(code ++ debugCode, nullTerm, primitiveTerm, objectTerm) } protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index ce6d99c911ab3..e88c5d4fa178a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -60,6 +60,8 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr override def eval(input: Row): Any = { child.eval(input) == null } + + override def toString = s"IS NULL $child" } case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 90de11182e605..4f2adb006fbc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -32,6 +32,10 @@ private[spark] object SQLConf { val CODEGEN_ENABLED = "spark.sql.codegen" val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" + val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" + + // This is only used for the thriftserver + val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index c86811e838bd8..b08f9aacc1fcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -19,16 +19,15 @@ package org.apache.spark.sql.execution import java.util.{HashMap => JavaHashMap} -import scala.collection.mutable.{ArrayBuffer, BitSet} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent._ import scala.concurrent.duration._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.util.collection.CompactBuffer @DeveloperApi sealed abstract class BuildSide @@ -67,7 +66,7 @@ trait HashJoin { def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { // TODO: Use Spark's HashMap implementation. - val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() + val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() var currentRow: Row = null // Create a mapping of buildKeys -> rows @@ -77,7 +76,7 @@ trait HashJoin { if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) val matchList = if (existingMatchList == null) { - val newMatchList = new ArrayBuffer[Row]() + val newMatchList = new CompactBuffer[Row]() hashTable.put(rowKey, newMatchList) newMatchList } else { @@ -89,7 +88,7 @@ trait HashJoin { new Iterator[Row] { private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: ArrayBuffer[Row] = _ + private[this] var currentHashMatches: CompactBuffer[Row] = _ private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. @@ -140,7 +139,7 @@ trait HashJoin { /** * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using + * Performs a hash based outer join for two child relations by shuffling the data using * the join keys. This operator requires loading the associated partition in both side into memory. */ @DeveloperApi @@ -179,26 +178,26 @@ case class HashOuterJoin( @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. + // iterator for performance purpose. private[this] def leftOuterIterator( key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { val joinedRow = new JoinedRow() val rightNullRow = new GenericRow(right.output.length) - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - leftIter.iterator.flatMap { l => + leftIter.iterator.flatMap { l => joinedRow.withLeft(l) var matched = false - (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => matched = true joinedRow.copy } else { Nil }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the + // as we don't know whether we need to append it until finish iterating all of the // records in right side. // If we didn't get any proper row, then append a single row with empty right joinedRow.withRight(rightNullRow).copy @@ -210,20 +209,20 @@ case class HashOuterJoin( key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { val joinedRow = new JoinedRow() val leftNullRow = new GenericRow(left.output.length) - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - rightIter.iterator.flatMap { r => + rightIter.iterator.flatMap { r => joinedRow.withRight(r) var matched = false - (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => matched = true joinedRow.copy } else { Nil }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the + // as we don't know whether we need to append it until finish iterating all of the // records in left side. // If we didn't get any proper row, then append a single row with empty left. joinedRow.withLeft(leftNullRow).copy @@ -236,7 +235,7 @@ case class HashOuterJoin( val joinedRow = new JoinedRow() val leftNullRow = new GenericRow(left.output.length) val rightNullRow = new GenericRow(right.output.length) - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) if (!key.anyNull) { @@ -246,8 +245,8 @@ case class HashOuterJoin( leftIter.iterator.flatMap[Row] { l => joinedRow.withLeft(l) var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, // append them directly case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { @@ -260,7 +259,7 @@ case class HashOuterJoin( // 2. For those unmatched records in left, append additional records with empty right. // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all + // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. joinedRow.withRight(rightNullRow).copy @@ -268,8 +267,8 @@ case class HashOuterJoin( } ++ rightIter.zipWithIndex.collect { // 3. For those unmatched records in right, append additional records with empty left. - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. case (r, idx) if (!rightMatchedSet.contains(idx)) => { joinedRow(leftNullRow, r).copy } @@ -284,15 +283,15 @@ case class HashOuterJoin( } private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, ArrayBuffer[Row]] = { - val hashTable = new JavaHashMap[Row, ArrayBuffer[Row]]() + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() while (iter.hasNext) { val currentRow = iter.next() val rowKey = keyGenerator(currentRow) var existingMatchList = hashTable.get(rowKey) if (existingMatchList == null) { - existingMatchList = new ArrayBuffer[Row]() + existingMatchList = new CompactBuffer[Row]() hashTable.put(rowKey, existingMatchList) } @@ -311,20 +310,20 @@ case class HashOuterJoin( val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) import scala.collection.JavaConversions._ - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) joinType match { case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST)) } case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST)) } case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), rightHashTable.getOrElse(key, EMPTY_LIST)) } case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") @@ -424,7 +423,7 @@ case class BroadcastHashJoin( UnspecifiedDistribution :: UnspecifiedDistribution :: Nil @transient - lazy val broadcastFuture = future { + val broadcastFuture = future { sparkContext.broadcast(buildPlan.executeCollect()) } @@ -550,7 +549,7 @@ case class BroadcastNestedLoopJoin( /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new ArrayBuffer[Row] + val matchedRows = new CompactBuffer[Row] // TODO: Use Spark's BitSet. val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) @@ -602,20 +601,20 @@ case class BroadcastNestedLoopJoin( val rightNulls = new GenericMutableRow(right.output.size) /** Rows from broadcasted joined with nulls. */ val broadcastRowsWithNulls: Seq[Row] = { - val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer() + val buf: CompactBuffer[Row] = new CompactBuffer() var i = 0 val rel = broadcastedRelation.value while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls) + case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) case _ => } } i += 1 } - arrBuf.toSeq + buf.toSeq } // TODO: Breaks lineage. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 68141ce83c796..f6cfab736d98a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -17,17 +17,19 @@ package org.apache.spark.sql.parquet -import scala.collection.JavaConversions._ -import scala.collection.mutable -import scala.util.Try - import java.io.IOException import java.lang.{Long => JLong} import java.text.SimpleDateFormat -import java.util.{Date, List => JList} +import java.util.concurrent.{Callable, TimeUnit} +import java.util.{ArrayList, Collections, Date, List => JList} + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.util.Try +import com.google.common.cache.CacheBuilder import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} @@ -41,8 +43,9 @@ import parquet.io.ParquetDecodingException import parquet.schema.MessageType import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} import org.apache.spark.{Logging, SerializableWritable, TaskContext} @@ -106,6 +109,11 @@ case class ParquetTableScan( ParquetFilters.serializeFilterExpressions(columnPruningPred, conf) } + // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata + conf.set( + SQLConf.PARQUET_CACHE_METADATA, + sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "false")) + val baseRDD = new org.apache.spark.rdd.NewHadoopRDD( sc, @@ -361,10 +369,40 @@ private[parquet] class FilteringParquetRowInputFormat } override def getFooters(jobContext: JobContext): JList[Footer] = { + import FilteringParquetRowInputFormat.footerCache + if (footers eq null) { + val conf = ContextUtil.getConfiguration(jobContext) + val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) val statuses = listStatus(jobContext) fileStatuses = statuses.map(file => file.getPath -> file).toMap - footers = getFooters(ContextUtil.getConfiguration(jobContext), statuses) + if (statuses.isEmpty) { + footers = Collections.emptyList[Footer] + } else if (!cacheMetadata) { + // Read the footers from HDFS + footers = getFooters(conf, statuses) + } else { + // Read only the footers that are not in the footerCache + val foundFooters = footerCache.getAllPresent(statuses) + val toFetch = new ArrayList[FileStatus] + for (s <- statuses) { + if (!foundFooters.containsKey(s)) { + toFetch.add(s) + } + } + val newFooters = new mutable.HashMap[FileStatus, Footer] + if (toFetch.size > 0) { + val fetched = getFooters(conf, toFetch) + for ((status, i) <- toFetch.zipWithIndex) { + newFooters(status) = fetched.get(i) + } + footerCache.putAll(newFooters) + } + footers = new ArrayList[Footer](statuses.size) + for (status <- statuses) { + footers.add(newFooters.getOrElse(status, foundFooters.get(status))) + } + } } footers @@ -377,6 +415,10 @@ private[parquet] class FilteringParquetRowInputFormat configuration: Configuration, footers: JList[Footer]): JList[ParquetInputSplit] = { + import FilteringParquetRowInputFormat.blockLocationCache + + val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) val minSplitSize: JLong = Math.max(getFormatMinSplitSize(), configuration.getLong("mapred.min.split.size", 0L)) @@ -404,16 +446,23 @@ private[parquet] class FilteringParquetRowInputFormat for (footer <- footers) { val fs = footer.getFile.getFileSystem(configuration) val file = footer.getFile - val fileStatus = fileStatuses.getOrElse(file, fs.getFileStatus(file)) + val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) val parquetMetaData = footer.getParquetMetadata val blocks = parquetMetaData.getBlocks - val fileBlockLocations = fs.getFileBlockLocations(fileStatus, 0, fileStatus.getLen) + var blockLocations: Array[BlockLocation] = null + if (!cacheMetadata) { + blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) + } else { + blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] { + def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) + }) + } splits.addAll( generateSplits.invoke( null, blocks, - fileBlockLocations, - fileStatus, + blockLocations, + status, parquetMetaData.getFileMetaData, readContext.getRequestedSchema.toString, readContext.getReadSupportMetadata, @@ -425,6 +474,17 @@ private[parquet] class FilteringParquetRowInputFormat } } +private[parquet] object FilteringParquetRowInputFormat { + private val footerCache = CacheBuilder.newBuilder() + .maximumSize(20000) + .build[FileStatus, Footer]() + + private val blockLocationCache = CacheBuilder.newBuilder() + .maximumSize(20000) + .expireAfterWrite(15, TimeUnit.MINUTES) // Expire locations since HDFS files might move + .build[FileStatus, Array[BlockLocation]]() +} + private[parquet] object FileSystemHelper { def listFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index b0579f76da073..c79a9ac2dad81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -378,8 +378,7 @@ private[parquet] object ParquetTypesConverter extends Logging { val children = fs.listStatus(path).filterNot { status => val name = status.getPath.getName - name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME || - name == FileOutputCommitter.TEMP_DIR_NAME + name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME } // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index c16a7d3661c66..b092f42372171 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -26,8 +26,6 @@ import jline.{ConsoleReader, History} import org.apache.commons.lang.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} import org.apache.hadoop.hive.common.LogUtils.LogInitializationException import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} @@ -118,17 +116,13 @@ private[hive] object SparkSQLCLIDriver { SessionState.start(sessionState) // Clean up after we exit - /** - * This should be executed before shutdown hook of - * FileSystem to avoid race condition of FileSystem operation - */ - ShutdownHookManager.get.addShutdownHook( + Runtime.getRuntime.addShutdownHook( new Thread() { override def run() { SparkSQLEnv.stop() } } - , FileSystem.SHUTDOWN_HOOK_PRIORITY - 1) + ) // "-h" option has been passed, so connect to Hive thrift server. if (sessionState.getHost != null) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 9338e8121b0fe..699a1103f3248 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -17,24 +17,24 @@ package org.apache.spark.sql.hive.thriftserver.server -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.math.{random, round} - import java.sql.Timestamp import java.util.{Map => JMap} +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, Map} +import scala.math.{random, round} + import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession - import org.apache.spark.Logging +import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} +import org.apache.spark.sql.catalyst.plans.logical.SetCommand import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -43,6 +43,9 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage val handleToOperation = ReflectionUtils .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") + // TODO: Currenlty this will grow infinitely, even as sessions expire + val sessionToActivePool = Map[HiveSession, String]() + override def newExecuteStatementOperation( parentSession: HiveSession, statement: String, @@ -165,8 +168,18 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage try { result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) + result.queryExecution.logical match { + case SetCommand(Some(key), Some(value)) if (key == SQLConf.THRIFTSERVER_POOL) => + sessionToActivePool(parentSession) = value + logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") + case _ => + } + val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) + sessionToActivePool.get(parentSession).foreach { pool => + hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) + } iter = { val resultRdd = result.queryExecution.toRdd val useIncrementalCollect = diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 93d00f7c37c9b..30ff277e67c88 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -36,6 +36,11 @@ + + com.twitter + parquet-hive-bundle + 1.5.0 + org.apache.spark spark-core_${scala.binary.version} diff --git a/streaming/pom.xml b/streaming/pom.xml index 1072f74aea0d9..ce35520a28609 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -81,11 +81,11 @@ org.apache.maven.plugins @@ -97,8 +97,8 @@ - test-jar-on-compile - compile + test-jar-on-test-compile + test-compile test-jar diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index e0677b795cb94..101cec1c7a7c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -98,9 +98,15 @@ class StreamingContext private[streaming] ( * @param hadoopConf Optional, configuration object if necessary for reading from * HDFS compatible filesystems */ - def this(path: String, hadoopConf: Configuration = new Configuration) = + def this(path: String, hadoopConf: Configuration) = this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).get, null) + /** + * Recreate a StreamingContext from a checkpoint file. + * @param path Path to the directory that was specified as the checkpoint directory + */ + def this(path: String) = this(path, new Configuration) + if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala index 774adc3c23c21..75f0e8716dc7e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala @@ -23,10 +23,10 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.streaming.ui.StreamingJobProgressListener private[streaming] class StreamingSource(ssc: StreamingContext) extends Source { - val metricRegistry = new MetricRegistry - val sourceName = "%s.StreamingMetrics".format(ssc.sparkContext.appName) + override val metricRegistry = new MetricRegistry + override val sourceName = "%s.StreamingMetrics".format(ssc.sparkContext.appName) - val streamingListener = ssc.uiTab.listener + private val streamingListener = ssc.uiTab.listener private def registerGauge[T](name: String, f: StreamingJobProgressListener => T, defaultValue: T) {